Feat: add get functions on sources and targets

This commit is contained in:
Bertrand Benjamin 2024-10-27 13:48:37 +01:00
parent 5c69bb5503
commit 9ff68cb285
2 changed files with 123 additions and 4 deletions

View File

@ -1,3 +1,4 @@
from functools import reduce
from typing import Callable
from pydantic import BaseModel
@ -23,11 +24,54 @@ class Graph:
self._edges = []
self._nodes = set()
def append(self, edge: Edge):
def add_edge(self, edge: Edge):
self._edges.append(edge)
self._nodes.add(edge.source)
self._nodes.add(edge.target)
def add_node(self, node: Node):
self._nodes.add(node)
@property
def nodes(self):
return self._nodes
def get_edges_from(self, node: Node) -> list[Edge]:
"""Get all edges which have the node as source"""
return [edge for edge in self._edges if edge.source == node]
def get_edges_to(self, node: Node) -> list[Edge]:
"""Get all edges which have the node as target"""
return [edge for edge in self._edges if edge.target == node]
def get_direct_targets_from(self, node: Node) -> set[Node]:
"""Get direct nodes that are accessible from the node"""
return set(edge.target for edge in self._edges if edge.source == node)
def get_targets_from(self, node: Node) -> set[Node]:
"""Get all nodes that are accessible from the node
If the graph have a loop, the procedure be in an infinite loop!
"""
direct_targets = self.get_direct_targets_from(node)
undirect_targets = [self.get_targets_from(n) for n in direct_targets]
undirect_targets = reduce(lambda x, y: x.union(y), undirect_targets, set())
return direct_targets.union(undirect_targets)
def get_direct_sources_from(self, node: Node) -> set[Node]:
"""Get direct nodes that are targeted the node"""
return set(edge.source for edge in self._edges if edge.target == node)
def get_sources_from(self, node: Node) -> set[Node]:
"""Get all nodes that are targeted the node"""
direct_sources = self.get_direct_sources_from(node)
undirect_sources = [self.get_sources_from(n) for n in direct_sources]
undirect_sources = reduce(lambda x, y: x.union(y), undirect_sources, set())
return direct_sources.union(undirect_sources)
def is_valid_dag(self):
for node in self._nodes:
pass

View File

@ -1,7 +1,20 @@
import pytest
from plesna.graph import Edge, Graph, Node
def test_init():
def test_append_nodess():
nodeA = Node(name="A")
nodeB = Node(name="B")
graph = Graph()
graph.add_node(nodeA)
graph.add_node(nodeB)
assert graph.nodes == {nodeA, nodeB}
def test_append_edges():
nodeA = Node(name="A")
nodeB = Node(name="B")
nodeC = Node(name="C")
@ -10,7 +23,69 @@ def test_init():
edge2 = Edge(arrow_name="arrow", source=nodeB, target=nodeC)
graph = Graph()
graph.append(edge1)
graph.append(edge2)
graph.add_edge(edge1)
graph.add_edge(edge2)
assert graph.nodes == {nodeA, nodeB, nodeC}
def test_init_edges_nodes():
nodeA = Node(name="A")
nodeB = Node(name="B")
nodeC = Node(name="C")
edge1 = Edge(arrow_name="arrow", source=nodeB, target=nodeC)
graph = Graph()
graph.add_node(nodeA)
graph.add_edge(edge1)
assert graph.nodes == {nodeA, nodeB, nodeC}
@pytest.fixture
def nodes():
return {
"A": Node(name="A"),
"B": Node(name="B"),
"C": Node(name="C"),
"D": Node(name="D"),
}
@pytest.fixture
def edges(nodes):
return {
"1": Edge(arrow_name="arrow", source=nodes["A"], target=nodes["C"]),
"2": Edge(arrow_name="arrow", source=nodes["B"], target=nodes["C"]),
"3": Edge(arrow_name="arrow", source=nodes["C"], target=nodes["D"]),
}
@pytest.fixture
def graph(nodes, edges):
graph = Graph()
graph.add_edge(edges["1"])
graph.add_edge(edges["2"])
graph.add_edge(edges["3"])
return graph
def test_get_edges_from(nodes, edges, graph):
assert graph.get_edges_from(nodes["A"]) == [edges["1"]]
def test_get_targets_from(nodes, edges, graph):
assert graph.get_direct_targets_from(nodes["A"]) == set([nodes["C"]])
assert graph.get_direct_targets_from(nodes["C"]) == set([nodes["D"]])
assert graph.get_direct_targets_from(nodes["D"]) == set()
assert graph.get_targets_from(nodes["A"]) == set([nodes["C"], nodes["D"]])
def test_get_sources_from(nodes, edges, graph):
assert graph.get_direct_sources_from(nodes["A"]) == set()
assert graph.get_direct_sources_from(nodes["C"]) == set([nodes["A"], nodes["B"]])
assert graph.get_direct_sources_from(nodes["D"]) == set([nodes["C"]])
assert graph.get_sources_from(nodes["D"]) == set([nodes["A"], nodes["B"], nodes["C"]])