From d1c1b7420d2451c955b86d9bd086a457fce89cb5 Mon Sep 17 00:00:00 2001 From: Bertrand Benjamin Date: Sun, 5 Jan 2025 06:51:14 +0100 Subject: [PATCH] refact: replace callback with str for arrow in graph_set --- plesna/graph/graph.py | 1 - plesna/graph/graph_set.py | 11 +++++------ tests/graphs/test_graph_set.py | 10 +++------- 3 files changed, 8 insertions(+), 14 deletions(-) diff --git a/plesna/graph/graph.py b/plesna/graph/graph.py index 2764bb2..ca09e82 100644 --- a/plesna/graph/graph.py +++ b/plesna/graph/graph.py @@ -1,5 +1,4 @@ from functools import reduce -from typing import Callable from pydantic import BaseModel diff --git a/plesna/graph/graph_set.py b/plesna/graph/graph_set.py index a52fdfc..e03d654 100644 --- a/plesna/graph/graph_set.py +++ b/plesna/graph/graph_set.py @@ -12,10 +12,9 @@ class Node(BaseModel): class EdgeOnSet(BaseModel): - arrow: Callable - sources: dict[str, Node] - targets: dict[str, Node] - edge_kwrds: dict = {} + arrow: str + sources: list[Node] + targets: list[Node] class GraphSet: @@ -25,8 +24,8 @@ class GraphSet: def append(self, edge: EdgeOnSet): self._edges.append(edge) - self._node_sets.add(frozenset(edge.sources.values())) - self._node_sets.add(frozenset(edge.targets.values())) + self._node_sets.add(frozenset(edge.sources)) + self._node_sets.add(frozenset(edge.targets)) @property def node_sets(self): diff --git a/tests/graphs/test_graph_set.py b/tests/graphs/test_graph_set.py index c402202..1806f7e 100644 --- a/tests/graphs/test_graph_set.py +++ b/tests/graphs/test_graph_set.py @@ -2,17 +2,13 @@ from plesna.graph.graph_set import EdgeOnSet, GraphSet, Node def test_init(): + graph_set = GraphSet() + nodeA = Node(name="A") nodeB = Node(name="B") nodeC = Node(name="C") + edge1 = EdgeOnSet(arrow="arrow", sources=[nodeA, nodeB], targets=[nodeC]) - def arrow(sources, targets): - targets["C"].infos["res"] = sources["A"].name + sources["B"].name - - edge1 = EdgeOnSet( - arrow=arrow, sources={"A": nodeA, "B": nodeB}, targets={"C": nodeC} - ) - graph_set = GraphSet() graph_set.append(edge1) assert graph_set.node_sets == {frozenset([nodeA, nodeB]), frozenset([nodeC])}