diff --git a/plesna/graph.py b/plesna/graph.py index 00b8aff..4e04baa 100644 --- a/plesna/graph.py +++ b/plesna/graph.py @@ -1,3 +1,4 @@ +from functools import reduce from typing import Callable from pydantic import BaseModel @@ -23,11 +24,54 @@ class Graph: self._edges = [] self._nodes = set() - def append(self, edge: Edge): + def add_edge(self, edge: Edge): self._edges.append(edge) self._nodes.add(edge.source) self._nodes.add(edge.target) + def add_node(self, node: Node): + self._nodes.add(node) + @property def nodes(self): return self._nodes + + def get_edges_from(self, node: Node) -> list[Edge]: + """Get all edges which have the node as source""" + return [edge for edge in self._edges if edge.source == node] + + def get_edges_to(self, node: Node) -> list[Edge]: + """Get all edges which have the node as target""" + return [edge for edge in self._edges if edge.target == node] + + def get_direct_targets_from(self, node: Node) -> set[Node]: + """Get direct nodes that are accessible from the node""" + return set(edge.target for edge in self._edges if edge.source == node) + + def get_targets_from(self, node: Node) -> set[Node]: + """Get all nodes that are accessible from the node + + If the graph have a loop, the procedure be in an infinite loop! + + """ + direct_targets = self.get_direct_targets_from(node) + undirect_targets = [self.get_targets_from(n) for n in direct_targets] + undirect_targets = reduce(lambda x, y: x.union(y), undirect_targets, set()) + + return direct_targets.union(undirect_targets) + + def get_direct_sources_from(self, node: Node) -> set[Node]: + """Get direct nodes that are targeted the node""" + return set(edge.source for edge in self._edges if edge.target == node) + + def get_sources_from(self, node: Node) -> set[Node]: + """Get all nodes that are targeted the node""" + direct_sources = self.get_direct_sources_from(node) + undirect_sources = [self.get_sources_from(n) for n in direct_sources] + undirect_sources = reduce(lambda x, y: x.union(y), undirect_sources, set()) + + return direct_sources.union(undirect_sources) + + def is_valid_dag(self): + for node in self._nodes: + pass diff --git a/tests/graphs/test_graph.py b/tests/graphs/test_graph.py index 1fbe18b..9729ac9 100644 --- a/tests/graphs/test_graph.py +++ b/tests/graphs/test_graph.py @@ -1,7 +1,20 @@ +import pytest + from plesna.graph import Edge, Graph, Node -def test_init(): +def test_append_nodess(): + nodeA = Node(name="A") + nodeB = Node(name="B") + + graph = Graph() + graph.add_node(nodeA) + graph.add_node(nodeB) + + assert graph.nodes == {nodeA, nodeB} + + +def test_append_edges(): nodeA = Node(name="A") nodeB = Node(name="B") nodeC = Node(name="C") @@ -10,7 +23,69 @@ def test_init(): edge2 = Edge(arrow_name="arrow", source=nodeB, target=nodeC) graph = Graph() - graph.append(edge1) - graph.append(edge2) + graph.add_edge(edge1) + graph.add_edge(edge2) assert graph.nodes == {nodeA, nodeB, nodeC} + + +def test_init_edges_nodes(): + nodeA = Node(name="A") + nodeB = Node(name="B") + nodeC = Node(name="C") + + edge1 = Edge(arrow_name="arrow", source=nodeB, target=nodeC) + + graph = Graph() + graph.add_node(nodeA) + graph.add_edge(edge1) + + assert graph.nodes == {nodeA, nodeB, nodeC} + + +@pytest.fixture +def nodes(): + return { + "A": Node(name="A"), + "B": Node(name="B"), + "C": Node(name="C"), + "D": Node(name="D"), + } + + +@pytest.fixture +def edges(nodes): + return { + "1": Edge(arrow_name="arrow", source=nodes["A"], target=nodes["C"]), + "2": Edge(arrow_name="arrow", source=nodes["B"], target=nodes["C"]), + "3": Edge(arrow_name="arrow", source=nodes["C"], target=nodes["D"]), + } + + +@pytest.fixture +def graph(nodes, edges): + + graph = Graph() + graph.add_edge(edges["1"]) + graph.add_edge(edges["2"]) + graph.add_edge(edges["3"]) + return graph + + +def test_get_edges_from(nodes, edges, graph): + assert graph.get_edges_from(nodes["A"]) == [edges["1"]] + + +def test_get_targets_from(nodes, edges, graph): + assert graph.get_direct_targets_from(nodes["A"]) == set([nodes["C"]]) + assert graph.get_direct_targets_from(nodes["C"]) == set([nodes["D"]]) + assert graph.get_direct_targets_from(nodes["D"]) == set() + assert graph.get_targets_from(nodes["A"]) == set([nodes["C"], nodes["D"]]) + + +def test_get_sources_from(nodes, edges, graph): + assert graph.get_direct_sources_from(nodes["A"]) == set() + assert graph.get_direct_sources_from(nodes["C"]) == set([nodes["A"], nodes["B"]]) + assert graph.get_direct_sources_from(nodes["D"]) == set([nodes["C"]]) + + assert graph.get_sources_from(nodes["D"]) == set([nodes["A"], nodes["B"], nodes["C"]])