diff --git a/plesna/graph/graph.py b/plesna/graph/graph.py index ca09e82..be5f28b 100644 --- a/plesna/graph/graph.py +++ b/plesna/graph/graph.py @@ -1,21 +1,6 @@ -from functools import reduce - from pydantic import BaseModel - - -class Node(BaseModel): - name: str - infos: dict = {} - - def __hash__(self): - return hash(self.name) - - -class Edge(BaseModel): - arrow_name: str - source: Node - target: Node - edge_kwrds: dict = {} +from functools import reduce +from plesna.models.graphs import Node, Edge class Graph: diff --git a/plesna/graph/graph_set.py b/plesna/graph/graph_set.py index f79f1cd..8809399 100644 --- a/plesna/graph/graph_set.py +++ b/plesna/graph/graph_set.py @@ -1,19 +1,4 @@ -from typing import Callable - -from pydantic import BaseModel - - -class Node(BaseModel): - name: str - - def __hash__(self): - return hash(self.name) - - -class EdgeOnSet(BaseModel): - arrow: str - sources: list[Node] - targets: list[Node] +from plesna.models.graphs import EdgeOnSet class GraphSet: diff --git a/plesna/models/graphs.py b/plesna/models/graphs.py new file mode 100644 index 0000000..f2d6895 --- /dev/null +++ b/plesna/models/graphs.py @@ -0,0 +1,21 @@ +from pydantic import BaseModel + + +class Node(BaseModel): + name: str + + def __hash__(self): + return hash(self.name) + + +class Edge(BaseModel): + arrow_name: str + source: Node + target: Node + edge_kwrds: dict = {} + + +class EdgeOnSet(BaseModel): + arrow: str + sources: list[Node] + targets: list[Node] diff --git a/tests/graphs/test_graph.py b/tests/graphs/test_graph.py index 24f5eb8..0830ad0 100644 --- a/tests/graphs/test_graph.py +++ b/tests/graphs/test_graph.py @@ -1,6 +1,7 @@ import pytest -from plesna.graph.graph import Edge, Graph, Node +from plesna.graph.graph import Graph +from plesna.models.graphs import Edge, Node def test_append_nodess(): @@ -94,9 +95,7 @@ def test_get_sources_from(nodes, dag_edges): assert graph.get_direct_sources_from(nodes["C"]) == set([nodes["A"], nodes["B"]]) assert graph.get_direct_sources_from(nodes["D"]) == set([nodes["C"]]) - assert graph.get_sources_from(nodes["D"]) == set( - [nodes["A"], nodes["B"], nodes["C"]] - ) + assert graph.get_sources_from(nodes["D"]) == set([nodes["A"], nodes["B"], nodes["C"]]) def test_valid_dage(dag_edges, notdag_edges): diff --git a/tests/graphs/test_graph_set.py b/tests/graphs/test_graph_set.py index 1806f7e..a93eb81 100644 --- a/tests/graphs/test_graph_set.py +++ b/tests/graphs/test_graph_set.py @@ -1,4 +1,5 @@ -from plesna.graph.graph_set import EdgeOnSet, GraphSet, Node +from plesna.graph.graph_set import GraphSet +from plesna.models.graphs import EdgeOnSet, Node def test_init():