Compare commits
5 Commits
Author | SHA1 | Date | |
---|---|---|---|
226ce84dce | |||
9ff68cb285 | |||
5c69bb5503 | |||
c90f407cfc | |||
867747d748 |
98
plesna/graph.py
Normal file
98
plesna/graph.py
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
from functools import reduce
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class Node(BaseModel):
|
||||||
|
name: str
|
||||||
|
infos: dict = {}
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return hash(self.name)
|
||||||
|
|
||||||
|
|
||||||
|
class Edge(BaseModel):
|
||||||
|
arrow_name: str
|
||||||
|
source: Node
|
||||||
|
target: Node
|
||||||
|
edge_kwrds: dict = {}
|
||||||
|
|
||||||
|
|
||||||
|
class Graph:
|
||||||
|
def __init__(self, nodes: list[Node] = [], edges: list[Edge] = []):
|
||||||
|
self._edges = []
|
||||||
|
self._nodes = set()
|
||||||
|
self.add_edges(edges)
|
||||||
|
self.add_nodes(nodes)
|
||||||
|
|
||||||
|
def add_node(self, node: Node):
|
||||||
|
self._nodes.add(node)
|
||||||
|
|
||||||
|
def add_nodes(self, nodes: list[Node]):
|
||||||
|
for node in nodes:
|
||||||
|
self.add_node(node)
|
||||||
|
|
||||||
|
def add_edge(self, edge: Edge):
|
||||||
|
self._edges.append(edge)
|
||||||
|
self.add_node(edge.source)
|
||||||
|
self.add_node(edge.target)
|
||||||
|
|
||||||
|
def add_edges(self, edges: list[Edge]):
|
||||||
|
for edge in edges:
|
||||||
|
self.add_edge(edge)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def nodes(self):
|
||||||
|
return self._nodes
|
||||||
|
|
||||||
|
@property
|
||||||
|
def edges(self):
|
||||||
|
return self._edges
|
||||||
|
|
||||||
|
def get_edges_from(self, node: Node) -> list[Edge]:
|
||||||
|
"""Get all edges which have the node as source"""
|
||||||
|
return [edge for edge in self._edges if edge.source == node]
|
||||||
|
|
||||||
|
def get_edges_to(self, node: Node) -> list[Edge]:
|
||||||
|
"""Get all edges which have the node as target"""
|
||||||
|
return [edge for edge in self._edges if edge.target == node]
|
||||||
|
|
||||||
|
def get_direct_targets_from(self, node: Node) -> set[Node]:
|
||||||
|
"""Get direct nodes that are accessible from the node"""
|
||||||
|
return set(edge.target for edge in self._edges if edge.source == node)
|
||||||
|
|
||||||
|
def get_targets_from(self, node: Node) -> set[Node]:
|
||||||
|
"""Get all nodes that are accessible from the node
|
||||||
|
|
||||||
|
If the graph have a loop, the procedure be in an infinite loop!
|
||||||
|
|
||||||
|
"""
|
||||||
|
direct_targets = self.get_direct_targets_from(node)
|
||||||
|
undirect_targets = [self.get_targets_from(n) for n in direct_targets]
|
||||||
|
undirect_targets = reduce(lambda x, y: x.union(y), undirect_targets, set())
|
||||||
|
|
||||||
|
return direct_targets.union(undirect_targets)
|
||||||
|
|
||||||
|
def get_direct_sources_from(self, node: Node) -> set[Node]:
|
||||||
|
"""Get direct nodes that are targeted the node"""
|
||||||
|
return set(edge.source for edge in self._edges if edge.target == node)
|
||||||
|
|
||||||
|
def get_sources_from(self, node: Node) -> set[Node]:
|
||||||
|
"""Get all nodes that are targeted the node"""
|
||||||
|
direct_sources = self.get_direct_sources_from(node)
|
||||||
|
undirect_sources = [self.get_sources_from(n) for n in direct_sources]
|
||||||
|
undirect_sources = reduce(lambda x, y: x.union(y), undirect_sources, set())
|
||||||
|
|
||||||
|
return direct_sources.union(undirect_sources)
|
||||||
|
|
||||||
|
def is_dag(self) -> bool:
|
||||||
|
visited = set()
|
||||||
|
for node in self._nodes:
|
||||||
|
if node not in visited:
|
||||||
|
try:
|
||||||
|
targets = self.get_targets_from(node)
|
||||||
|
except RecursionError:
|
||||||
|
return False
|
||||||
|
visited.union(targets)
|
||||||
|
return True
|
33
plesna/graph_set.py
Normal file
33
plesna/graph_set.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class Node(BaseModel):
|
||||||
|
name: str
|
||||||
|
infos: dict = {}
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return hash(self.name)
|
||||||
|
|
||||||
|
|
||||||
|
class EdgeOnSet(BaseModel):
|
||||||
|
arrow: Callable
|
||||||
|
sources: dict[str, Node]
|
||||||
|
targets: dict[str, Node]
|
||||||
|
edge_kwrds: dict = {}
|
||||||
|
|
||||||
|
|
||||||
|
class GraphSet:
|
||||||
|
def __init__(self):
|
||||||
|
self._edges = []
|
||||||
|
self._node_sets = set()
|
||||||
|
|
||||||
|
def append(self, edge: EdgeOnSet):
|
||||||
|
self._edges.append(edge)
|
||||||
|
self._node_sets.add(frozenset(edge.sources.values()))
|
||||||
|
self._node_sets.add(frozenset(edge.targets.values()))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def node_sets(self):
|
||||||
|
return self._node_sets
|
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
0
tests/graphs/__init__.py
Normal file
0
tests/graphs/__init__.py
Normal file
107
tests/graphs/test_graph.py
Normal file
107
tests/graphs/test_graph.py
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from plesna.graph import Edge, Graph, Node
|
||||||
|
|
||||||
|
|
||||||
|
def test_append_nodess():
|
||||||
|
nodeA = Node(name="A")
|
||||||
|
nodeB = Node(name="B")
|
||||||
|
|
||||||
|
graph = Graph()
|
||||||
|
graph.add_node(nodeA)
|
||||||
|
graph.add_node(nodeB)
|
||||||
|
|
||||||
|
assert graph.nodes == {nodeA, nodeB}
|
||||||
|
|
||||||
|
|
||||||
|
def test_append_edges():
|
||||||
|
nodeA = Node(name="A")
|
||||||
|
nodeB = Node(name="B")
|
||||||
|
nodeC = Node(name="C")
|
||||||
|
|
||||||
|
edge1 = Edge(arrow_name="arrow", source=nodeA, target=nodeC)
|
||||||
|
edge2 = Edge(arrow_name="arrow", source=nodeB, target=nodeC)
|
||||||
|
|
||||||
|
graph = Graph()
|
||||||
|
graph.add_edge(edge1)
|
||||||
|
graph.add_edge(edge2)
|
||||||
|
|
||||||
|
assert graph.nodes == {nodeA, nodeB, nodeC}
|
||||||
|
|
||||||
|
|
||||||
|
def test_init_edges_nodes():
|
||||||
|
nodeA = Node(name="A")
|
||||||
|
nodeB = Node(name="B")
|
||||||
|
nodeC = Node(name="C")
|
||||||
|
|
||||||
|
edge1 = Edge(arrow_name="arrow", source=nodeB, target=nodeC)
|
||||||
|
|
||||||
|
graph = Graph()
|
||||||
|
graph.add_node(nodeA)
|
||||||
|
graph.add_edge(edge1)
|
||||||
|
|
||||||
|
assert graph.nodes == {nodeA, nodeB, nodeC}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def nodes():
|
||||||
|
return {
|
||||||
|
"A": Node(name="A"),
|
||||||
|
"B": Node(name="B"),
|
||||||
|
"C": Node(name="C"),
|
||||||
|
"D": Node(name="D"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def dag_edges(nodes):
|
||||||
|
return {
|
||||||
|
"1": Edge(arrow_name="arrow", source=nodes["A"], target=nodes["C"]),
|
||||||
|
"2": Edge(arrow_name="arrow", source=nodes["B"], target=nodes["C"]),
|
||||||
|
"3": Edge(arrow_name="arrow", source=nodes["C"], target=nodes["D"]),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def notdag_edges(nodes):
|
||||||
|
return {
|
||||||
|
"1": Edge(arrow_name="arrow", source=nodes["A"], target=nodes["C"]),
|
||||||
|
"2": Edge(arrow_name="arrow", source=nodes["B"], target=nodes["C"]),
|
||||||
|
"3": Edge(arrow_name="arrow", source=nodes["C"], target=nodes["D"]),
|
||||||
|
"4": Edge(arrow_name="arrow", source=nodes["D"], target=nodes["B"]),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_edges_from(nodes, dag_edges):
|
||||||
|
edges = dag_edges
|
||||||
|
graph = Graph(edges=edges.values())
|
||||||
|
assert graph.get_edges_from(nodes["A"]) == [edges["1"]]
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_targets_from(nodes, dag_edges):
|
||||||
|
edges = dag_edges
|
||||||
|
graph = Graph(edges=edges.values())
|
||||||
|
assert graph.get_direct_targets_from(nodes["A"]) == set([nodes["C"]])
|
||||||
|
assert graph.get_direct_targets_from(nodes["C"]) == set([nodes["D"]])
|
||||||
|
assert graph.get_direct_targets_from(nodes["D"]) == set()
|
||||||
|
assert graph.get_targets_from(nodes["A"]) == set([nodes["C"], nodes["D"]])
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_sources_from(nodes, dag_edges):
|
||||||
|
edges = dag_edges
|
||||||
|
graph = Graph(edges=edges.values())
|
||||||
|
assert graph.get_direct_sources_from(nodes["A"]) == set()
|
||||||
|
assert graph.get_direct_sources_from(nodes["C"]) == set([nodes["A"], nodes["B"]])
|
||||||
|
assert graph.get_direct_sources_from(nodes["D"]) == set([nodes["C"]])
|
||||||
|
|
||||||
|
assert graph.get_sources_from(nodes["D"]) == set(
|
||||||
|
[nodes["A"], nodes["B"], nodes["C"]]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_valid_dage(dag_edges, notdag_edges):
|
||||||
|
graph = Graph(edges=dag_edges.values())
|
||||||
|
assert graph.is_dag()
|
||||||
|
|
||||||
|
graph = Graph(edges=notdag_edges.values())
|
||||||
|
assert not graph.is_dag()
|
18
tests/graphs/test_graph_set.py
Normal file
18
tests/graphs/test_graph_set.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
from plesna.graph_set import EdgeOnSet, GraphSet, Node
|
||||||
|
|
||||||
|
|
||||||
|
def test_init():
|
||||||
|
nodeA = Node(name="A")
|
||||||
|
nodeB = Node(name="B")
|
||||||
|
nodeC = Node(name="C")
|
||||||
|
|
||||||
|
def arrow(sources, targets):
|
||||||
|
targets["C"].infos["res"] = sources["A"].name + sources["B"].name
|
||||||
|
|
||||||
|
edge1 = EdgeOnSet(
|
||||||
|
arrow=arrow, sources={"A": nodeA, "B": nodeB}, targets={"C": nodeC}
|
||||||
|
)
|
||||||
|
graph_set = GraphSet()
|
||||||
|
graph_set.append(edge1)
|
||||||
|
|
||||||
|
assert graph_set.node_sets == {frozenset([nodeA, nodeB]), frozenset([nodeC])}
|
Loading…
Reference in New Issue
Block a user