From 226ce84dcea7e0c400d808ec559e386f2f93c82e Mon Sep 17 00:00:00 2001 From: Bertrand Benjamin Date: Sun, 27 Oct 2024 14:10:33 +0100 Subject: [PATCH] Feat: add is_dag to Graph --- plesna/graph.py | 37 +++++++++++++++++++++++++++-------- tests/graphs/test_graph.py | 40 ++++++++++++++++++++++++++------------ 2 files changed, 57 insertions(+), 20 deletions(-) diff --git a/plesna/graph.py b/plesna/graph.py index 4e04baa..2764bb2 100644 --- a/plesna/graph.py +++ b/plesna/graph.py @@ -20,22 +20,36 @@ class Edge(BaseModel): class Graph: - def __init__(self): + def __init__(self, nodes: list[Node] = [], edges: list[Edge] = []): self._edges = [] self._nodes = set() - - def add_edge(self, edge: Edge): - self._edges.append(edge) - self._nodes.add(edge.source) - self._nodes.add(edge.target) + self.add_edges(edges) + self.add_nodes(nodes) def add_node(self, node: Node): self._nodes.add(node) + def add_nodes(self, nodes: list[Node]): + for node in nodes: + self.add_node(node) + + def add_edge(self, edge: Edge): + self._edges.append(edge) + self.add_node(edge.source) + self.add_node(edge.target) + + def add_edges(self, edges: list[Edge]): + for edge in edges: + self.add_edge(edge) + @property def nodes(self): return self._nodes + @property + def edges(self): + return self._edges + def get_edges_from(self, node: Node) -> list[Edge]: """Get all edges which have the node as source""" return [edge for edge in self._edges if edge.source == node] @@ -72,6 +86,13 @@ class Graph: return direct_sources.union(undirect_sources) - def is_valid_dag(self): + def is_dag(self) -> bool: + visited = set() for node in self._nodes: - pass + if node not in visited: + try: + targets = self.get_targets_from(node) + except RecursionError: + return False + visited.union(targets) + return True diff --git a/tests/graphs/test_graph.py b/tests/graphs/test_graph.py index 9729ac9..8f2c406 100644 --- a/tests/graphs/test_graph.py +++ b/tests/graphs/test_graph.py @@ -54,7 +54,7 @@ def nodes(): @pytest.fixture -def edges(nodes): +def dag_edges(nodes): return { "1": Edge(arrow_name="arrow", source=nodes["A"], target=nodes["C"]), "2": Edge(arrow_name="arrow", source=nodes["B"], target=nodes["C"]), @@ -63,29 +63,45 @@ def edges(nodes): @pytest.fixture -def graph(nodes, edges): - - graph = Graph() - graph.add_edge(edges["1"]) - graph.add_edge(edges["2"]) - graph.add_edge(edges["3"]) - return graph +def notdag_edges(nodes): + return { + "1": Edge(arrow_name="arrow", source=nodes["A"], target=nodes["C"]), + "2": Edge(arrow_name="arrow", source=nodes["B"], target=nodes["C"]), + "3": Edge(arrow_name="arrow", source=nodes["C"], target=nodes["D"]), + "4": Edge(arrow_name="arrow", source=nodes["D"], target=nodes["B"]), + } -def test_get_edges_from(nodes, edges, graph): +def test_get_edges_from(nodes, dag_edges): + edges = dag_edges + graph = Graph(edges=edges.values()) assert graph.get_edges_from(nodes["A"]) == [edges["1"]] -def test_get_targets_from(nodes, edges, graph): +def test_get_targets_from(nodes, dag_edges): + edges = dag_edges + graph = Graph(edges=edges.values()) assert graph.get_direct_targets_from(nodes["A"]) == set([nodes["C"]]) assert graph.get_direct_targets_from(nodes["C"]) == set([nodes["D"]]) assert graph.get_direct_targets_from(nodes["D"]) == set() assert graph.get_targets_from(nodes["A"]) == set([nodes["C"], nodes["D"]]) -def test_get_sources_from(nodes, edges, graph): +def test_get_sources_from(nodes, dag_edges): + edges = dag_edges + graph = Graph(edges=edges.values()) assert graph.get_direct_sources_from(nodes["A"]) == set() assert graph.get_direct_sources_from(nodes["C"]) == set([nodes["A"], nodes["B"]]) assert graph.get_direct_sources_from(nodes["D"]) == set([nodes["C"]]) - assert graph.get_sources_from(nodes["D"]) == set([nodes["A"], nodes["B"], nodes["C"]]) + assert graph.get_sources_from(nodes["D"]) == set( + [nodes["A"], nodes["B"], nodes["C"]] + ) + + +def test_valid_dage(dag_edges, notdag_edges): + graph = Graph(edges=dag_edges.values()) + assert graph.is_dag() + + graph = Graph(edges=notdag_edges.values()) + assert not graph.is_dag()