Feat: add is_dag to Graph

This commit is contained in:
Bertrand Benjamin 2024-10-27 14:10:33 +01:00
parent 9ff68cb285
commit 226ce84dce
2 changed files with 57 additions and 20 deletions

View File

@ -20,22 +20,36 @@ class Edge(BaseModel):
class Graph: class Graph:
def __init__(self): def __init__(self, nodes: list[Node] = [], edges: list[Edge] = []):
self._edges = [] self._edges = []
self._nodes = set() self._nodes = set()
self.add_edges(edges)
def add_edge(self, edge: Edge): self.add_nodes(nodes)
self._edges.append(edge)
self._nodes.add(edge.source)
self._nodes.add(edge.target)
def add_node(self, node: Node): def add_node(self, node: Node):
self._nodes.add(node) self._nodes.add(node)
def add_nodes(self, nodes: list[Node]):
for node in nodes:
self.add_node(node)
def add_edge(self, edge: Edge):
self._edges.append(edge)
self.add_node(edge.source)
self.add_node(edge.target)
def add_edges(self, edges: list[Edge]):
for edge in edges:
self.add_edge(edge)
@property @property
def nodes(self): def nodes(self):
return self._nodes return self._nodes
@property
def edges(self):
return self._edges
def get_edges_from(self, node: Node) -> list[Edge]: def get_edges_from(self, node: Node) -> list[Edge]:
"""Get all edges which have the node as source""" """Get all edges which have the node as source"""
return [edge for edge in self._edges if edge.source == node] return [edge for edge in self._edges if edge.source == node]
@ -72,6 +86,13 @@ class Graph:
return direct_sources.union(undirect_sources) return direct_sources.union(undirect_sources)
def is_valid_dag(self): def is_dag(self) -> bool:
visited = set()
for node in self._nodes: for node in self._nodes:
pass if node not in visited:
try:
targets = self.get_targets_from(node)
except RecursionError:
return False
visited.union(targets)
return True

View File

@ -54,7 +54,7 @@ def nodes():
@pytest.fixture @pytest.fixture
def edges(nodes): def dag_edges(nodes):
return { return {
"1": Edge(arrow_name="arrow", source=nodes["A"], target=nodes["C"]), "1": Edge(arrow_name="arrow", source=nodes["A"], target=nodes["C"]),
"2": Edge(arrow_name="arrow", source=nodes["B"], target=nodes["C"]), "2": Edge(arrow_name="arrow", source=nodes["B"], target=nodes["C"]),
@ -63,29 +63,45 @@ def edges(nodes):
@pytest.fixture @pytest.fixture
def graph(nodes, edges): def notdag_edges(nodes):
return {
graph = Graph() "1": Edge(arrow_name="arrow", source=nodes["A"], target=nodes["C"]),
graph.add_edge(edges["1"]) "2": Edge(arrow_name="arrow", source=nodes["B"], target=nodes["C"]),
graph.add_edge(edges["2"]) "3": Edge(arrow_name="arrow", source=nodes["C"], target=nodes["D"]),
graph.add_edge(edges["3"]) "4": Edge(arrow_name="arrow", source=nodes["D"], target=nodes["B"]),
return graph }
def test_get_edges_from(nodes, edges, graph): def test_get_edges_from(nodes, dag_edges):
edges = dag_edges
graph = Graph(edges=edges.values())
assert graph.get_edges_from(nodes["A"]) == [edges["1"]] assert graph.get_edges_from(nodes["A"]) == [edges["1"]]
def test_get_targets_from(nodes, edges, graph): def test_get_targets_from(nodes, dag_edges):
edges = dag_edges
graph = Graph(edges=edges.values())
assert graph.get_direct_targets_from(nodes["A"]) == set([nodes["C"]]) assert graph.get_direct_targets_from(nodes["A"]) == set([nodes["C"]])
assert graph.get_direct_targets_from(nodes["C"]) == set([nodes["D"]]) assert graph.get_direct_targets_from(nodes["C"]) == set([nodes["D"]])
assert graph.get_direct_targets_from(nodes["D"]) == set() assert graph.get_direct_targets_from(nodes["D"]) == set()
assert graph.get_targets_from(nodes["A"]) == set([nodes["C"], nodes["D"]]) assert graph.get_targets_from(nodes["A"]) == set([nodes["C"], nodes["D"]])
def test_get_sources_from(nodes, edges, graph): def test_get_sources_from(nodes, dag_edges):
edges = dag_edges
graph = Graph(edges=edges.values())
assert graph.get_direct_sources_from(nodes["A"]) == set() assert graph.get_direct_sources_from(nodes["A"]) == set()
assert graph.get_direct_sources_from(nodes["C"]) == set([nodes["A"], nodes["B"]]) assert graph.get_direct_sources_from(nodes["C"]) == set([nodes["A"], nodes["B"]])
assert graph.get_direct_sources_from(nodes["D"]) == set([nodes["C"]]) assert graph.get_direct_sources_from(nodes["D"]) == set([nodes["C"]])
assert graph.get_sources_from(nodes["D"]) == set([nodes["A"], nodes["B"], nodes["C"]]) assert graph.get_sources_from(nodes["D"]) == set(
[nodes["A"], nodes["B"], nodes["C"]]
)
def test_valid_dage(dag_edges, notdag_edges):
graph = Graph(edges=dag_edges.values())
assert graph.is_dag()
graph = Graph(edges=notdag_edges.values())
assert not graph.is_dag()