Feat: add to_graph and is_valid_dag for graph_set
This commit is contained in:
parent
44a7eed5b4
commit
5ebde14be9
@ -1,4 +1,7 @@
|
|||||||
from plesna.models.graphs import EdgeOnSet
|
from typing import Set
|
||||||
|
from plesna.graph.graph import Graph
|
||||||
|
from plesna.models.graphs import Edge, EdgeOnSet
|
||||||
|
from itertools import product
|
||||||
|
|
||||||
|
|
||||||
class GraphSet:
|
class GraphSet:
|
||||||
@ -12,8 +15,21 @@ class GraphSet:
|
|||||||
self._node_sets.add(frozenset(edge.targets))
|
self._node_sets.add(frozenset(edge.targets))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def node_sets(self):
|
def node_sets(self) -> Set[frozenset]:
|
||||||
return self._node_sets
|
return self._node_sets
|
||||||
|
|
||||||
def is_valid_dag(self):
|
def to_graph(self) -> Graph:
|
||||||
pass
|
graph = Graph()
|
||||||
|
for node_set in self.node_sets:
|
||||||
|
graph.add_nodes(node_set)
|
||||||
|
for edge in self._edges:
|
||||||
|
flatten_edge = [
|
||||||
|
Edge(arrow=edge.arrow, source=s, target=t, edge_kwrds=edge.edge_kwrds)
|
||||||
|
for (s, t) in product(edge.sources, edge.targets)
|
||||||
|
]
|
||||||
|
graph.add_edges(flatten_edge)
|
||||||
|
|
||||||
|
return graph
|
||||||
|
|
||||||
|
def is_valid_dag(self) -> bool:
|
||||||
|
return self.to_graph().is_dag()
|
||||||
|
@ -9,7 +9,7 @@ class Node(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class Edge(BaseModel):
|
class Edge(BaseModel):
|
||||||
arrow_name: str
|
arrow: str
|
||||||
source: Node
|
source: Node
|
||||||
target: Node
|
target: Node
|
||||||
edge_kwrds: dict = {}
|
edge_kwrds: dict = {}
|
||||||
@ -19,3 +19,4 @@ class EdgeOnSet(BaseModel):
|
|||||||
arrow: str
|
arrow: str
|
||||||
sources: list[Node]
|
sources: list[Node]
|
||||||
targets: list[Node]
|
targets: list[Node]
|
||||||
|
edge_kwrds: dict = {}
|
||||||
|
@ -20,8 +20,8 @@ def test_append_edges():
|
|||||||
nodeB = Node(name="B")
|
nodeB = Node(name="B")
|
||||||
nodeC = Node(name="C")
|
nodeC = Node(name="C")
|
||||||
|
|
||||||
edge1 = Edge(arrow_name="arrow", source=nodeA, target=nodeC)
|
edge1 = Edge(arrow="arrow", source=nodeA, target=nodeC)
|
||||||
edge2 = Edge(arrow_name="arrow", source=nodeB, target=nodeC)
|
edge2 = Edge(arrow="arrow", source=nodeB, target=nodeC)
|
||||||
|
|
||||||
graph = Graph()
|
graph = Graph()
|
||||||
graph.add_edge(edge1)
|
graph.add_edge(edge1)
|
||||||
@ -35,7 +35,7 @@ def test_init_edges_nodes():
|
|||||||
nodeB = Node(name="B")
|
nodeB = Node(name="B")
|
||||||
nodeC = Node(name="C")
|
nodeC = Node(name="C")
|
||||||
|
|
||||||
edge1 = Edge(arrow_name="arrow", source=nodeB, target=nodeC)
|
edge1 = Edge(arrow="arrow", source=nodeB, target=nodeC)
|
||||||
|
|
||||||
graph = Graph()
|
graph = Graph()
|
||||||
graph.add_node(nodeA)
|
graph.add_node(nodeA)
|
||||||
@ -57,19 +57,19 @@ def nodes():
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def dag_edges(nodes):
|
def dag_edges(nodes):
|
||||||
return {
|
return {
|
||||||
"1": Edge(arrow_name="arrow", source=nodes["A"], target=nodes["C"]),
|
"1": Edge(arrow="arrow", source=nodes["A"], target=nodes["C"]),
|
||||||
"2": Edge(arrow_name="arrow", source=nodes["B"], target=nodes["C"]),
|
"2": Edge(arrow="arrow", source=nodes["B"], target=nodes["C"]),
|
||||||
"3": Edge(arrow_name="arrow", source=nodes["C"], target=nodes["D"]),
|
"3": Edge(arrow="arrow", source=nodes["C"], target=nodes["D"]),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def notdag_edges(nodes):
|
def notdag_edges(nodes):
|
||||||
return {
|
return {
|
||||||
"1": Edge(arrow_name="arrow", source=nodes["A"], target=nodes["C"]),
|
"1": Edge(arrow="arrow", source=nodes["A"], target=nodes["C"]),
|
||||||
"2": Edge(arrow_name="arrow", source=nodes["B"], target=nodes["C"]),
|
"2": Edge(arrow="arrow", source=nodes["B"], target=nodes["C"]),
|
||||||
"3": Edge(arrow_name="arrow", source=nodes["C"], target=nodes["D"]),
|
"3": Edge(arrow="arrow", source=nodes["C"], target=nodes["D"]),
|
||||||
"4": Edge(arrow_name="arrow", source=nodes["D"], target=nodes["B"]),
|
"4": Edge(arrow="arrow", source=nodes["D"], target=nodes["B"]),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
|
from plesna.graph.graph import Graph
|
||||||
from plesna.graph.graph_set import GraphSet
|
from plesna.graph.graph_set import GraphSet
|
||||||
from plesna.models.graphs import EdgeOnSet, Node
|
from plesna.models.graphs import Edge, EdgeOnSet, Node
|
||||||
|
|
||||||
|
|
||||||
def test_init():
|
def test_init():
|
||||||
@ -13,3 +14,30 @@ def test_init():
|
|||||||
graph_set.append(edge1)
|
graph_set.append(edge1)
|
||||||
|
|
||||||
assert graph_set.node_sets == {frozenset([nodeA, nodeB]), frozenset([nodeC])}
|
assert graph_set.node_sets == {frozenset([nodeA, nodeB]), frozenset([nodeC])}
|
||||||
|
|
||||||
|
|
||||||
|
def test_to_graph():
|
||||||
|
graph_set = GraphSet()
|
||||||
|
|
||||||
|
nodeA = Node(name="A")
|
||||||
|
nodeB = Node(name="B")
|
||||||
|
nodeC = Node(name="C")
|
||||||
|
nodeD = Node(name="D")
|
||||||
|
edge1 = EdgeOnSet(arrow="arrow-AB-C", sources=[nodeA, nodeB], targets=[nodeC])
|
||||||
|
edge2 = EdgeOnSet(arrow="arrow-C-D", sources=[nodeC], targets=[nodeD])
|
||||||
|
|
||||||
|
graph_set.append(edge1)
|
||||||
|
graph_set.append(edge2)
|
||||||
|
|
||||||
|
graph = graph_set.to_graph()
|
||||||
|
assert graph.nodes == {
|
||||||
|
nodeA,
|
||||||
|
nodeB,
|
||||||
|
nodeC,
|
||||||
|
nodeD,
|
||||||
|
}
|
||||||
|
assert graph.edges == [
|
||||||
|
Edge(arrow="arrow-AB-C", source=nodeA, target=nodeC),
|
||||||
|
Edge(arrow="arrow-AB-C", source=nodeB, target=nodeC),
|
||||||
|
Edge(arrow="arrow-C-D", source=nodeC, target=nodeD),
|
||||||
|
]
|
||||||
|
Loading…
Reference in New Issue
Block a user