From 5ebde14be9f6131caa51de7a3d1752b7f6bc09af Mon Sep 17 00:00:00 2001 From: Bertrand Benjamin Date: Sun, 5 Jan 2025 16:42:57 +0100 Subject: [PATCH] Feat: add to_graph and is_valid_dag for graph_set --- plesna/graph/graph_set.py | 24 ++++++++++++++++++++---- plesna/models/graphs.py | 3 ++- tests/graphs/test_graph.py | 20 ++++++++++---------- tests/graphs/test_graph_set.py | 30 +++++++++++++++++++++++++++++- 4 files changed, 61 insertions(+), 16 deletions(-) diff --git a/plesna/graph/graph_set.py b/plesna/graph/graph_set.py index 8809399..df856a5 100644 --- a/plesna/graph/graph_set.py +++ b/plesna/graph/graph_set.py @@ -1,4 +1,7 @@ -from plesna.models.graphs import EdgeOnSet +from typing import Set +from plesna.graph.graph import Graph +from plesna.models.graphs import Edge, EdgeOnSet +from itertools import product class GraphSet: @@ -12,8 +15,21 @@ class GraphSet: self._node_sets.add(frozenset(edge.targets)) @property - def node_sets(self): + def node_sets(self) -> Set[frozenset]: return self._node_sets - def is_valid_dag(self): - pass + def to_graph(self) -> Graph: + graph = Graph() + for node_set in self.node_sets: + graph.add_nodes(node_set) + for edge in self._edges: + flatten_edge = [ + Edge(arrow=edge.arrow, source=s, target=t, edge_kwrds=edge.edge_kwrds) + for (s, t) in product(edge.sources, edge.targets) + ] + graph.add_edges(flatten_edge) + + return graph + + def is_valid_dag(self) -> bool: + return self.to_graph().is_dag() diff --git a/plesna/models/graphs.py b/plesna/models/graphs.py index f2d6895..3506005 100644 --- a/plesna/models/graphs.py +++ b/plesna/models/graphs.py @@ -9,7 +9,7 @@ class Node(BaseModel): class Edge(BaseModel): - arrow_name: str + arrow: str source: Node target: Node edge_kwrds: dict = {} @@ -19,3 +19,4 @@ class EdgeOnSet(BaseModel): arrow: str sources: list[Node] targets: list[Node] + edge_kwrds: dict = {} diff --git a/tests/graphs/test_graph.py b/tests/graphs/test_graph.py index 0830ad0..b1b6096 100644 --- a/tests/graphs/test_graph.py +++ b/tests/graphs/test_graph.py @@ -20,8 +20,8 @@ def test_append_edges(): nodeB = Node(name="B") nodeC = Node(name="C") - edge1 = Edge(arrow_name="arrow", source=nodeA, target=nodeC) - edge2 = Edge(arrow_name="arrow", source=nodeB, target=nodeC) + edge1 = Edge(arrow="arrow", source=nodeA, target=nodeC) + edge2 = Edge(arrow="arrow", source=nodeB, target=nodeC) graph = Graph() graph.add_edge(edge1) @@ -35,7 +35,7 @@ def test_init_edges_nodes(): nodeB = Node(name="B") nodeC = Node(name="C") - edge1 = Edge(arrow_name="arrow", source=nodeB, target=nodeC) + edge1 = Edge(arrow="arrow", source=nodeB, target=nodeC) graph = Graph() graph.add_node(nodeA) @@ -57,19 +57,19 @@ def nodes(): @pytest.fixture def dag_edges(nodes): return { - "1": Edge(arrow_name="arrow", source=nodes["A"], target=nodes["C"]), - "2": Edge(arrow_name="arrow", source=nodes["B"], target=nodes["C"]), - "3": Edge(arrow_name="arrow", source=nodes["C"], target=nodes["D"]), + "1": Edge(arrow="arrow", source=nodes["A"], target=nodes["C"]), + "2": Edge(arrow="arrow", source=nodes["B"], target=nodes["C"]), + "3": Edge(arrow="arrow", source=nodes["C"], target=nodes["D"]), } @pytest.fixture def notdag_edges(nodes): return { - "1": Edge(arrow_name="arrow", source=nodes["A"], target=nodes["C"]), - "2": Edge(arrow_name="arrow", source=nodes["B"], target=nodes["C"]), - "3": Edge(arrow_name="arrow", source=nodes["C"], target=nodes["D"]), - "4": Edge(arrow_name="arrow", source=nodes["D"], target=nodes["B"]), + "1": Edge(arrow="arrow", source=nodes["A"], target=nodes["C"]), + "2": Edge(arrow="arrow", source=nodes["B"], target=nodes["C"]), + "3": Edge(arrow="arrow", source=nodes["C"], target=nodes["D"]), + "4": Edge(arrow="arrow", source=nodes["D"], target=nodes["B"]), } diff --git a/tests/graphs/test_graph_set.py b/tests/graphs/test_graph_set.py index a93eb81..180541a 100644 --- a/tests/graphs/test_graph_set.py +++ b/tests/graphs/test_graph_set.py @@ -1,5 +1,6 @@ +from plesna.graph.graph import Graph from plesna.graph.graph_set import GraphSet -from plesna.models.graphs import EdgeOnSet, Node +from plesna.models.graphs import Edge, EdgeOnSet, Node def test_init(): @@ -13,3 +14,30 @@ def test_init(): graph_set.append(edge1) assert graph_set.node_sets == {frozenset([nodeA, nodeB]), frozenset([nodeC])} + + +def test_to_graph(): + graph_set = GraphSet() + + nodeA = Node(name="A") + nodeB = Node(name="B") + nodeC = Node(name="C") + nodeD = Node(name="D") + edge1 = EdgeOnSet(arrow="arrow-AB-C", sources=[nodeA, nodeB], targets=[nodeC]) + edge2 = EdgeOnSet(arrow="arrow-C-D", sources=[nodeC], targets=[nodeD]) + + graph_set.append(edge1) + graph_set.append(edge2) + + graph = graph_set.to_graph() + assert graph.nodes == { + nodeA, + nodeB, + nodeC, + nodeD, + } + assert graph.edges == [ + Edge(arrow="arrow-AB-C", source=nodeA, target=nodeC), + Edge(arrow="arrow-AB-C", source=nodeB, target=nodeC), + Edge(arrow="arrow-C-D", source=nodeC, target=nodeD), + ]