99 lines
3.0 KiB
Python
99 lines
3.0 KiB
Python
from functools import reduce
|
|
from typing import Callable
|
|
|
|
from pydantic import BaseModel
|
|
|
|
|
|
class Node(BaseModel):
    """A graph vertex identified by its ``name``.

    ``__hash__`` is overridden to use only ``name`` so that instances can be
    stored in sets and used as dict keys; ``__eq__`` is inherited from
    ``BaseModel`` and is not overridden here.
    """

    # Identifier of the vertex; the only field that feeds ``__hash__``.
    name: str
    # Arbitrary metadata attached to the vertex.
    infos: dict = dict()

    def __hash__(self) -> int:
        """Hash the node by its name alone."""
        identifier = self.name
        return hash(identifier)
|
|
|
|
|
|
class Edge(BaseModel):
    """A directed, labelled edge running from ``source`` to ``target``."""

    # Label of the arrow.
    arrow_name: str
    # Node the edge starts from.
    source: Node
    # Node the edge points to.
    target: Node
    # Extra keyword data for the edge; it is not read anywhere in the code
    # shown here — presumably rendering/export options, TODO confirm.
    edge_kwrds: dict = dict()
|
|
|
|
|
|
class Graph:
    """A directed multigraph of ``Node`` objects connected by ``Edge`` objects.

    Nodes are stored in a set, so membership relies on the nodes' own
    ``__hash__``/``__eq__``. Edges are stored in insertion order in a list;
    parallel edges are kept.
    """

    def __init__(
        self,
        nodes: "list[Node] | None" = None,
        edges: "list[Edge] | None" = None,
    ):
        """Create a graph, optionally pre-populated with nodes and edges.

        Args:
            nodes: Nodes to add. Nodes given only here become isolated
                vertices.
            edges: Edges to add; each edge's endpoints are added as nodes.
        """
        # ``None`` defaults (instead of ``[]``) avoid the shared mutable
        # default pitfall; passing lists works exactly as before.
        self._edges = []
        self._nodes = set()
        if edges is not None:
            self.add_edges(edges)
        if nodes is not None:
            self.add_nodes(nodes)

    def add_node(self, node: "Node"):
        """Add a single node (no-op if an equal node is already present)."""
        self._nodes.add(node)

    def add_nodes(self, nodes: "list[Node]"):
        """Add every node in ``nodes``."""
        for node in nodes:
            self.add_node(node)

    def add_edge(self, edge: "Edge"):
        """Append ``edge`` and register both of its endpoints as nodes."""
        self._edges.append(edge)
        self.add_node(edge.source)
        self.add_node(edge.target)

    def add_edges(self, edges: "list[Edge]"):
        """Add every edge in ``edges``."""
        for edge in edges:
            self.add_edge(edge)

    @property
    def nodes(self):
        """The graph's node set (the internal set itself, not a copy)."""
        return self._nodes

    @property
    def edges(self):
        """The graph's edge list (the internal list itself, not a copy)."""
        return self._edges

    def get_edges_from(self, node: "Node") -> "list[Edge]":
        """Get all edges which have the node as source."""
        return [edge for edge in self._edges if edge.source == node]

    def get_edges_to(self, node: "Node") -> "list[Edge]":
        """Get all edges which have the node as target."""
        return [edge for edge in self._edges if edge.target == node]

    def get_direct_targets_from(self, node: "Node") -> "set[Node]":
        """Get the nodes reachable from ``node`` through exactly one edge."""
        return {edge.target for edge in self._edges if edge.source == node}

    def get_targets_from(self, node: "Node") -> "set[Node]":
        """Get all nodes reachable from ``node`` by following edges forward.

        Terminates on graphs with cycles: every node is expanded at most
        once. ``node`` itself is included only if it lies on a cycle.
        """
        return self._reachable(node, self._adjacency())

    def get_direct_sources_from(self, node: "Node") -> "set[Node]":
        """Get the nodes that reach ``node`` through exactly one edge."""
        return {edge.source for edge in self._edges if edge.target == node}

    def get_sources_from(self, node: "Node") -> "set[Node]":
        """Get all nodes from which ``node`` is reachable.

        Terminates on graphs with cycles: every node is expanded at most
        once. ``node`` itself is included only if it lies on a cycle.
        """
        return self._reachable(node, self._adjacency(reverse=True))

    def is_dag(self) -> bool:
        """Return True if the graph has no directed cycle (self-loops count).

        Uses Kahn's algorithm: repeatedly remove nodes with no remaining
        incoming edges; every node can be removed iff the graph is acyclic.
        It is iterative, runs in O(V + E) and does not depend on Python's
        recursion limit (deep DAGs are not misreported as cyclic).
        """
        adjacency = self._adjacency()
        in_degree = dict.fromkeys(self._nodes, 0)
        for start, ends in adjacency.items():
            # add_edge already registers endpoints in _nodes; setdefault/get
            # only guard against edges that bypassed add_edge.
            in_degree.setdefault(start, 0)
            for end in ends:
                in_degree[end] = in_degree.get(end, 0) + 1

        ready = [node for node, degree in in_degree.items() if degree == 0]
        removed = 0
        while ready:
            current = ready.pop()
            removed += 1
            for end in adjacency.get(current, ()):
                in_degree[end] -= 1
                if in_degree[end] == 0:
                    ready.append(end)
        return removed == len(in_degree)

    def _adjacency(self, reverse: bool = False) -> "dict[Node, list[Node]]":
        """Map each node to its neighbours one edge away.

        Forward (source -> targets) by default, backward (target -> sources)
        when ``reverse`` is true. Parallel edges produce repeated entries.
        """
        adjacency = {}
        for edge in self._edges:
            if reverse:
                start, end = edge.target, edge.source
            else:
                start, end = edge.source, edge.target
            adjacency.setdefault(start, []).append(end)
        return adjacency

    @staticmethod
    def _reachable(start: "Node", adjacency: "dict[Node, list[Node]]") -> "set[Node]":
        """Return every node reachable from ``start`` in ``adjacency`` (iterative DFS)."""
        reached = set()
        stack = list(adjacency.get(start, ()))
        while stack:
            current = stack.pop()
            if current in reached:
                continue
            reached.add(current)
            stack.extend(adjacency.get(current, ()))
        return reached
|