Feat: add is_dag to Graph

This commit is contained in:
2024-10-27 14:10:33 +01:00
parent 9ff68cb285
commit 226ce84dce
2 changed files with 57 additions and 20 deletions

View File

@@ -54,7 +54,7 @@ def nodes():
@pytest.fixture
def edges(nodes):
def dag_edges(nodes):
return {
"1": Edge(arrow_name="arrow", source=nodes["A"], target=nodes["C"]),
"2": Edge(arrow_name="arrow", source=nodes["B"], target=nodes["C"]),
@@ -63,29 +63,45 @@ def edges(nodes):
@pytest.fixture
def graph(nodes, edges):
graph = Graph()
graph.add_edge(edges["1"])
graph.add_edge(edges["2"])
graph.add_edge(edges["3"])
return graph
def notdag_edges(nodes):
return {
"1": Edge(arrow_name="arrow", source=nodes["A"], target=nodes["C"]),
"2": Edge(arrow_name="arrow", source=nodes["B"], target=nodes["C"]),
"3": Edge(arrow_name="arrow", source=nodes["C"], target=nodes["D"]),
"4": Edge(arrow_name="arrow", source=nodes["D"], target=nodes["B"]),
}
def test_get_edges_from(nodes, edges, graph):
def test_get_edges_from(nodes, dag_edges):
edges = dag_edges
graph = Graph(edges=edges.values())
assert graph.get_edges_from(nodes["A"]) == [edges["1"]]
def test_get_targets_from(nodes, edges, graph):
def test_get_targets_from(nodes, dag_edges):
edges = dag_edges
graph = Graph(edges=edges.values())
assert graph.get_direct_targets_from(nodes["A"]) == set([nodes["C"]])
assert graph.get_direct_targets_from(nodes["C"]) == set([nodes["D"]])
assert graph.get_direct_targets_from(nodes["D"]) == set()
assert graph.get_targets_from(nodes["A"]) == set([nodes["C"], nodes["D"]])
def test_get_sources_from(nodes, edges, graph):
def test_get_sources_from(nodes, dag_edges):
edges = dag_edges
graph = Graph(edges=edges.values())
assert graph.get_direct_sources_from(nodes["A"]) == set()
assert graph.get_direct_sources_from(nodes["C"]) == set([nodes["A"], nodes["B"]])
assert graph.get_direct_sources_from(nodes["D"]) == set([nodes["C"]])
assert graph.get_sources_from(nodes["D"]) == set([nodes["A"], nodes["B"], nodes["C"]])
assert graph.get_sources_from(nodes["D"]) == set(
[nodes["A"], nodes["B"], nodes["C"]]
)
def test_valid_dage(dag_edges, notdag_edges):
graph = Graph(edges=dag_edges.values())
assert graph.is_dag()
graph = Graph(edges=notdag_edges.values())
assert not graph.is_dag()