Feat: add is_dag to Graph
This commit is contained in:
parent
9ff68cb285
commit
226ce84dce
@ -20,22 +20,36 @@ class Edge(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class Graph:
|
class Graph:
|
||||||
def __init__(self):
|
def __init__(self, nodes: list[Node] = [], edges: list[Edge] = []):
|
||||||
self._edges = []
|
self._edges = []
|
||||||
self._nodes = set()
|
self._nodes = set()
|
||||||
|
self.add_edges(edges)
|
||||||
def add_edge(self, edge: Edge):
|
self.add_nodes(nodes)
|
||||||
self._edges.append(edge)
|
|
||||||
self._nodes.add(edge.source)
|
|
||||||
self._nodes.add(edge.target)
|
|
||||||
|
|
||||||
def add_node(self, node: Node):
|
def add_node(self, node: Node):
|
||||||
self._nodes.add(node)
|
self._nodes.add(node)
|
||||||
|
|
||||||
|
def add_nodes(self, nodes: list[Node]):
|
||||||
|
for node in nodes:
|
||||||
|
self.add_node(node)
|
||||||
|
|
||||||
|
def add_edge(self, edge: Edge):
|
||||||
|
self._edges.append(edge)
|
||||||
|
self.add_node(edge.source)
|
||||||
|
self.add_node(edge.target)
|
||||||
|
|
||||||
|
def add_edges(self, edges: list[Edge]):
|
||||||
|
for edge in edges:
|
||||||
|
self.add_edge(edge)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def nodes(self):
|
def nodes(self):
|
||||||
return self._nodes
|
return self._nodes
|
||||||
|
|
||||||
|
@property
|
||||||
|
def edges(self):
|
||||||
|
return self._edges
|
||||||
|
|
||||||
def get_edges_from(self, node: Node) -> list[Edge]:
|
def get_edges_from(self, node: Node) -> list[Edge]:
|
||||||
"""Get all edges which have the node as source"""
|
"""Get all edges which have the node as source"""
|
||||||
return [edge for edge in self._edges if edge.source == node]
|
return [edge for edge in self._edges if edge.source == node]
|
||||||
@ -72,6 +86,13 @@ class Graph:
|
|||||||
|
|
||||||
return direct_sources.union(undirect_sources)
|
return direct_sources.union(undirect_sources)
|
||||||
|
|
||||||
def is_valid_dag(self):
|
def is_dag(self) -> bool:
|
||||||
|
visited = set()
|
||||||
for node in self._nodes:
|
for node in self._nodes:
|
||||||
pass
|
if node not in visited:
|
||||||
|
try:
|
||||||
|
targets = self.get_targets_from(node)
|
||||||
|
except RecursionError:
|
||||||
|
return False
|
||||||
|
visited.union(targets)
|
||||||
|
return True
|
||||||
|
@ -54,7 +54,7 @@ def nodes():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def edges(nodes):
|
def dag_edges(nodes):
|
||||||
return {
|
return {
|
||||||
"1": Edge(arrow_name="arrow", source=nodes["A"], target=nodes["C"]),
|
"1": Edge(arrow_name="arrow", source=nodes["A"], target=nodes["C"]),
|
||||||
"2": Edge(arrow_name="arrow", source=nodes["B"], target=nodes["C"]),
|
"2": Edge(arrow_name="arrow", source=nodes["B"], target=nodes["C"]),
|
||||||
@ -63,29 +63,45 @@ def edges(nodes):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def graph(nodes, edges):
|
def notdag_edges(nodes):
|
||||||
|
return {
|
||||||
graph = Graph()
|
"1": Edge(arrow_name="arrow", source=nodes["A"], target=nodes["C"]),
|
||||||
graph.add_edge(edges["1"])
|
"2": Edge(arrow_name="arrow", source=nodes["B"], target=nodes["C"]),
|
||||||
graph.add_edge(edges["2"])
|
"3": Edge(arrow_name="arrow", source=nodes["C"], target=nodes["D"]),
|
||||||
graph.add_edge(edges["3"])
|
"4": Edge(arrow_name="arrow", source=nodes["D"], target=nodes["B"]),
|
||||||
return graph
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_get_edges_from(nodes, edges, graph):
|
def test_get_edges_from(nodes, dag_edges):
|
||||||
|
edges = dag_edges
|
||||||
|
graph = Graph(edges=edges.values())
|
||||||
assert graph.get_edges_from(nodes["A"]) == [edges["1"]]
|
assert graph.get_edges_from(nodes["A"]) == [edges["1"]]
|
||||||
|
|
||||||
|
|
||||||
def test_get_targets_from(nodes, edges, graph):
|
def test_get_targets_from(nodes, dag_edges):
|
||||||
|
edges = dag_edges
|
||||||
|
graph = Graph(edges=edges.values())
|
||||||
assert graph.get_direct_targets_from(nodes["A"]) == set([nodes["C"]])
|
assert graph.get_direct_targets_from(nodes["A"]) == set([nodes["C"]])
|
||||||
assert graph.get_direct_targets_from(nodes["C"]) == set([nodes["D"]])
|
assert graph.get_direct_targets_from(nodes["C"]) == set([nodes["D"]])
|
||||||
assert graph.get_direct_targets_from(nodes["D"]) == set()
|
assert graph.get_direct_targets_from(nodes["D"]) == set()
|
||||||
assert graph.get_targets_from(nodes["A"]) == set([nodes["C"], nodes["D"]])
|
assert graph.get_targets_from(nodes["A"]) == set([nodes["C"], nodes["D"]])
|
||||||
|
|
||||||
|
|
||||||
def test_get_sources_from(nodes, edges, graph):
|
def test_get_sources_from(nodes, dag_edges):
|
||||||
|
edges = dag_edges
|
||||||
|
graph = Graph(edges=edges.values())
|
||||||
assert graph.get_direct_sources_from(nodes["A"]) == set()
|
assert graph.get_direct_sources_from(nodes["A"]) == set()
|
||||||
assert graph.get_direct_sources_from(nodes["C"]) == set([nodes["A"], nodes["B"]])
|
assert graph.get_direct_sources_from(nodes["C"]) == set([nodes["A"], nodes["B"]])
|
||||||
assert graph.get_direct_sources_from(nodes["D"]) == set([nodes["C"]])
|
assert graph.get_direct_sources_from(nodes["D"]) == set([nodes["C"]])
|
||||||
|
|
||||||
assert graph.get_sources_from(nodes["D"]) == set([nodes["A"], nodes["B"], nodes["C"]])
|
assert graph.get_sources_from(nodes["D"]) == set(
|
||||||
|
[nodes["A"], nodes["B"], nodes["C"]]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_valid_dage(dag_edges, notdag_edges):
|
||||||
|
graph = Graph(edges=dag_edges.values())
|
||||||
|
assert graph.is_dag()
|
||||||
|
|
||||||
|
graph = Graph(edges=notdag_edges.values())
|
||||||
|
assert not graph.is_dag()
|
||||||
|
Loading…
Reference in New Issue
Block a user