From 09783f9c1e20318bffb2e61d9ff12ab1752870f3 Mon Sep 17 00:00:00 2001
From: Bertrand Benjamin
Date: Sun, 5 Jan 2025 15:31:40 +0100
Subject: [PATCH] Feat: flux takes list of tables for sources and targets

Flux.sources and Flux.targets are now lists of Table instead of dicts.
The id-keyed mappings handed to the transformation function are exposed
through the sources_dict / targets_dict computed fields, and
sources_id / targets_id list the table ids.

---
 plesna/compute/consume_flux.py           |  2 +-
 plesna/models/flux.py                    | 30 +++++++++++++++++++++++++++---
 tests/compute/test_consume_flux.py       | 24 ++++++++--------------
 tests/dataplatform/test_dataplateform.py | 26 ++++++++++++------------
 4 files changed, 49 insertions(+), 33 deletions(-)

diff --git a/plesna/compute/consume_flux.py b/plesna/compute/consume_flux.py
index 2a176ea..207415d 100644
--- a/plesna/compute/consume_flux.py
+++ b/plesna/compute/consume_flux.py
@@ -3,6 +3,6 @@ from plesna.models.flux import Flux, FluxMetaData
 
 def consume_flux(flux: Flux) -> FluxMetaData:
     metadata = flux.transformation.function(
-        sources=flux.sources, targets=flux.targets, **flux.transformation.extra_kwrds
+        sources=flux.sources_dict, targets=flux.targets_dict, **flux.transformation.extra_kwrds
     )
     return FluxMetaData(data=metadata)
diff --git a/plesna/models/flux.py b/plesna/models/flux.py
index 28dcc8f..42783d5 100644
--- a/plesna/models/flux.py
+++ b/plesna/models/flux.py
@@ -1,14 +1,38 @@
-from pydantic import BaseModel
+from pydantic import BaseModel, computed_field
 
 from plesna.models.storage import Table
 from plesna.models.transformation import Transformation
 
 
 class Flux(BaseModel):
-    sources: dict[str, Table]
-    targets: dict[str, Table]
+    sources: list[Table]
+    targets: list[Table]
     transformation: Transformation
 
+    @computed_field
+    @property
+    def sources_dict(self) -> dict[str, Table]:
+        """Source tables indexed by their id."""
+        return {s.id: s for s in self.sources}
+
+    @computed_field
+    @property
+    def sources_id(self) -> list[str]:
+        """Ids of the source tables, in list order."""
+        return [s.id for s in self.sources]
+
+    @computed_field
+    @property
+    def targets_id(self) -> list[str]:
+        """Ids of the target tables, in list order."""
+        return [t.id for t in self.targets]
+
+    @computed_field
+    @property
+    def targets_dict(self) -> dict[str, Table]:
+        """Target tables indexed by their id."""
+        return {t.id: t for t in self.targets}
+
 
 class FluxMetaData(BaseModel):
     data: dict
diff --git a/tests/compute/test_consume_flux.py b/tests/compute/test_consume_flux.py
index 9582d81..85ecf9d 100644
--- a/tests/compute/test_consume_flux.py
+++ b/tests/compute/test_consume_flux.py
@@ -5,22 +5,14 @@ from plesna.models.transformation import Transformation
 
 
 def test_consume_flux():
-    sources = {
-        "src1": Table(
-            id="src1", repo_id="test", schema_id="test", name="test", value="here", datas=["d"]
-        ),
-        "src2": Table(
-            id="src2", repo_id="test", schema_id="test", name="test", value="here", datas=["d"]
-        ),
-    }
-    targets = {
-        "tgt1": Table(
-            id="tgt1", repo_id="test", schema_id="test", name="test", value="this", datas=["d"]
-        ),
-        "tgt2": Table(
-            id="tgt2", repo_id="test", schema_id="test", name="test", value="that", datas=["d"]
-        ),
-    }
+    sources = [
+        Table(id="src1", repo_id="test", schema_id="test", name="test", value="here", datas=["d"]),
+        Table(id="src2", repo_id="test", schema_id="test", name="test", value="here", datas=["d"]),
+    ]
+    targets = [
+        Table(id="tgt1", repo_id="test", schema_id="test", name="test", value="this", datas=["d"]),
+        Table(id="tgt2", repo_id="test", schema_id="test", name="test", value="that", datas=["d"]),
+    ]
 
     def func(sources, targets, **kwrds):
         return {
diff --git a/tests/dataplatform/test_dataplateform.py b/tests/dataplatform/test_dataplateform.py
index 0867eef..f583013 100644
--- a/tests/dataplatform/test_dataplateform.py
+++ b/tests/dataplatform/test_dataplateform.py
@@ -41,12 +41,12 @@ def test_add_repository(
 
 @pytest.fixture
 def copy_flux(repository: FSRepository) -> Flux:
-    raw_username = {"username": repository.table("test-raw-username")}
-    bronze_username = {"username": repository.table("test-bronze-username")}
+    raw_username = [repository.table("test-raw-username")]
+    bronze_username = [repository.table("test-bronze-username")]
 
     def copy(sources, targets):
-        src_path = Path(sources["username"].datas[0])
-        tgt_path = Path(targets["username"].datas[0])
+        src_path = Path(sources["test-raw-username"].datas[0])
+        tgt_path = Path(targets["test-bronze-username"].datas[0])
         shutil.copy(src_path, tgt_path)
         return {"src_size": src_path.stat().st_size, "tgt_size": tgt_path.stat().st_size}
 
@@ -62,11 +62,11 @@ def copy_flux(repository: FSRepository) -> Flux:
 
 @pytest.fixture
 def foo_flux(repository: FSRepository) -> Flux:
-    src = {
-        "username": repository.table("test-raw-username"),
-        "recovery": repository.table("test-raw-recovery"),
-    }
-    targets = {"username_foo": repository.table("test-bronze-foo")}
+    src = [
+        repository.table("test-raw-username"),
+        repository.table("test-raw-recovery"),
+    ]
+    targets = [repository.table("test-bronze-foo")]
 
     def foo(sources, targets):
         return {"who": "foo"}
@@ -131,10 +131,10 @@ def test_content_from_graph(dataplatform: DataPlateform):
         Node(name="test-raw-username", infos={}),
     }
 
-    # assert dataplatform.graphset.node_sets == {
-    #     Node(name="test-raw-username", infos={}),
-    #     Node(name="test-bronze-username", infos={}),
-    # }
+    assert dataplatform.graphset.node_sets == {
+        Node(name="test-raw-username", infos={}),
+        Node(name="test-bronze-username", infos={}),
+    }
 
 
 def test_execute_flux(dataplatform: DataPlateform):