"""Tests for ``plesna.graph``: ``Node``, ``Edge`` and ``Graph``."""

import pytest

from plesna.graph import Edge, Graph, Node


def test_append_nodes():
    """Nodes added one at a time are all reported by ``Graph.nodes``."""
    node_a = Node(name="A")
    node_b = Node(name="B")

    graph = Graph()
    graph.add_node(node_a)
    graph.add_node(node_b)

    assert graph.nodes == {node_a, node_b}


def test_append_edges():
    """Adding edges registers their sources and targets as graph nodes."""
    node_a = Node(name="A")
    node_b = Node(name="B")
    node_c = Node(name="C")
    edge_ac = Edge(arrow_name="arrow", source=node_a, target=node_c)
    edge_bc = Edge(arrow_name="arrow", source=node_b, target=node_c)

    graph = Graph()
    graph.add_edge(edge_ac)
    graph.add_edge(edge_bc)

    # No node was added explicitly: all three come from the edge endpoints.
    assert graph.nodes == {node_a, node_b, node_c}


def test_init_edges_nodes():
    """Explicitly added nodes and edge endpoints are merged into one node set."""
    node_a = Node(name="A")
    node_b = Node(name="B")
    node_c = Node(name="C")
    edge_bc = Edge(arrow_name="arrow", source=node_b, target=node_c)

    graph = Graph()
    graph.add_node(node_a)  # isolated node, not touched by any edge
    graph.add_edge(edge_bc)

    assert graph.nodes == {node_a, node_b, node_c}


@pytest.fixture
def nodes():
    """Four independent nodes, keyed by their name."""
    return {
        "A": Node(name="A"),
        "B": Node(name="B"),
        "C": Node(name="C"),
        "D": Node(name="D"),
    }


@pytest.fixture
def dag_edges(nodes):
    """Acyclic edge set A->C, B->C, C->D, keyed "1".."3"."""
    return {
        "1": Edge(arrow_name="arrow", source=nodes["A"], target=nodes["C"]),
        "2": Edge(arrow_name="arrow", source=nodes["B"], target=nodes["C"]),
        "3": Edge(arrow_name="arrow", source=nodes["C"], target=nodes["D"]),
    }


@pytest.fixture
def notdag_edges(nodes):
    """The ``dag_edges`` set plus D->B, which closes the cycle B->C->D->B."""
    return {
        "1": Edge(arrow_name="arrow", source=nodes["A"], target=nodes["C"]),
        "2": Edge(arrow_name="arrow", source=nodes["B"], target=nodes["C"]),
        "3": Edge(arrow_name="arrow", source=nodes["C"], target=nodes["D"]),
        "4": Edge(arrow_name="arrow", source=nodes["D"], target=nodes["B"]),
    }


def test_get_edges_from(nodes, dag_edges):
    """``get_edges_from`` returns a list of the edges leaving a node."""
    graph = Graph(edges=dag_edges.values())

    assert graph.get_edges_from(nodes["A"]) == [dag_edges["1"]]


def test_get_targets_from(nodes, dag_edges):
    """Direct targets are one hop downstream; ``get_targets_from`` is transitive."""
    graph = Graph(edges=dag_edges.values())

    assert graph.get_direct_targets_from(nodes["A"]) == {nodes["C"]}
    assert graph.get_direct_targets_from(nodes["C"]) == {nodes["D"]}
    # D is a sink in the DAG fixture.
    assert graph.get_direct_targets_from(nodes["D"]) == set()

    assert graph.get_targets_from(nodes["A"]) == {nodes["C"], nodes["D"]}


def test_get_sources_from(nodes, dag_edges):
    """Direct sources are one hop upstream; ``get_sources_from`` is transitive."""
    graph = Graph(edges=dag_edges.values())

    # A is a root in the DAG fixture.
    assert graph.get_direct_sources_from(nodes["A"]) == set()
    assert graph.get_direct_sources_from(nodes["C"]) == {nodes["A"], nodes["B"]}
    assert graph.get_direct_sources_from(nodes["D"]) == {nodes["C"]}

    assert graph.get_sources_from(nodes["D"]) == {
        nodes["A"],
        nodes["B"],
        nodes["C"],
    }


def test_valid_dag(dag_edges, notdag_edges):
    """``is_dag`` accepts the acyclic fixture and rejects the cyclic one."""
    dag = Graph(edges=dag_edges.values())
    assert dag.is_dag()

    cyclic = Graph(edges=notdag_edges.values())
    assert not cyclic.is_dag()