108 lines
2.9 KiB
Python
108 lines
2.9 KiB
Python
import pytest
|
|
|
|
from plesna.graph import Edge, Graph, Node
|
|
|
|
|
|
def test_append_nodess():
    """Nodes added through ``add_node`` are exposed by ``graph.nodes``."""
    first = Node(name="A")
    second = Node(name="B")

    g = Graph()
    for member in (first, second):
        g.add_node(member)

    # Compared as a set: only membership is checked, not insertion order.
    assert g.nodes == {first, second}
|
|
|
|
|
|
def test_append_edges():
    """Edges added through ``add_edge`` bring their endpoints into ``graph.nodes``."""
    src_a = Node(name="A")
    src_b = Node(name="B")
    sink = Node(name="C")

    # Two edges that share the same target: A -> C and B -> C.
    incoming = [
        Edge(arrow_name="arrow", source=origin, target=sink)
        for origin in (src_a, src_b)
    ]

    g = Graph()
    for link in incoming:
        g.add_edge(link)

    # No explicit add_node call was made; all three nodes come from the edges.
    assert g.nodes == {src_a, src_b, sink}
|
|
|
|
|
|
def test_init_edges_nodes():
    """Combining ``add_node`` and ``add_edge`` yields every node involved."""
    isolated = Node(name="A")
    tail = Node(name="B")
    head = Node(name="C")
    link = Edge(arrow_name="arrow", source=tail, target=head)

    g = Graph()
    g.add_node(isolated)  # added directly, not attached to any edge
    g.add_edge(link)  # contributes both of its endpoints

    assert g.nodes == {isolated, tail, head}
|
|
|
|
|
|
@pytest.fixture
def nodes():
    """Return a dict of four nodes, each keyed by its own name ("A" to "D")."""
    return {label: Node(name=label) for label in ("A", "B", "C", "D")}
|
|
|
|
|
|
@pytest.fixture
def dag_edges(nodes):
    """Return the edges A -> C, B -> C and C -> D, keyed "1" to "3".

    The edges are built on top of the ``nodes`` fixture and contain no cycle.
    """
    # edge key -> (source node name, target node name)
    endpoints = {
        "1": ("A", "C"),
        "2": ("B", "C"),
        "3": ("C", "D"),
    }
    return {
        key: Edge(arrow_name="arrow", source=nodes[src], target=nodes[dst])
        for key, (src, dst) in endpoints.items()
    }
|
|
|
|
|
|
@pytest.fixture
def notdag_edges(nodes):
    """Return edges keyed "1" to "4" that contain the cycle B -> C -> D -> B.

    The first three edges are A -> C, B -> C and C -> D; edge "4" (D -> B)
    closes the loop.
    """
    # edge key -> (source node name, target node name)
    endpoints = {
        "1": ("A", "C"),
        "2": ("B", "C"),
        "3": ("C", "D"),
        "4": ("D", "B"),
    }
    return {
        key: Edge(arrow_name="arrow", source=nodes[src], target=nodes[dst])
        for key, (src, dst) in endpoints.items()
    }
|
|
|
|
|
|
def test_get_edges_from(nodes, dag_edges):
    """``get_edges_from`` for node A returns just the edge keyed "1" (A -> C)."""
    g = Graph(edges=dag_edges.values())
    outgoing = g.get_edges_from(nodes["A"])
    assert outgoing == [dag_edges["1"]]
|
|
|
|
|
|
def test_get_targets_from(nodes, dag_edges):
    """Check direct targets for A, C and D, and all targets reachable from A."""
    g = Graph(edges=dag_edges.values())
    a, c, d = nodes["A"], nodes["C"], nodes["D"]

    # One hop only.
    assert g.get_direct_targets_from(a) == {c}
    assert g.get_direct_targets_from(c) == {d}
    assert g.get_direct_targets_from(d) == set()

    # From A the expected result also includes D, which is two hops away.
    assert g.get_targets_from(a) == {c, d}
|
|
|
|
|
|
def test_get_sources_from(nodes, dag_edges):
    """Check direct sources for A, C and D, and all sources leading into D."""
    g = Graph(edges=dag_edges.values())
    a, b, c, d = nodes["A"], nodes["B"], nodes["C"], nodes["D"]

    # One hop only.
    assert g.get_direct_sources_from(a) == set()
    assert g.get_direct_sources_from(c) == {a, b}
    assert g.get_direct_sources_from(d) == {c}

    # For D the expected result also includes A and B, which reach it via C.
    assert g.get_sources_from(d) == {a, b, c}
|
|
|
|
|
|
def test_valid_dage(dag_edges, notdag_edges):
    """``is_dag`` is truthy for the acyclic edges and falsy for the cyclic ones."""
    acyclic = Graph(edges=dag_edges.values())
    assert acyclic.is_dag()

    cyclic = Graph(edges=notdag_edges.values())
    assert not cyclic.is_dag()
|