Feat: add Interseptor

This commit is contained in:
Bertrand Benjamin 2024-02-18 17:40:52 +01:00
parent 75e196e366
commit 2f77206b8f
4 changed files with 91 additions and 0 deletions

0
scripts/__init__.py Normal file
View File

View File

@ -0,0 +1,28 @@
from collections.abc import Callable
import pandas as pd
from pydantic import BaseModel, ValidationError
class Interseptor:
def __init__(self, model: BaseModel):
self.model = model
self.not_valid_rows = []
def __call__(self, func: Callable[..., pd.DataFrame]):
def wrapped(*args, **kwrds):
res = func(*args, **kwrds)
df_dict = res.to_dict(orient="records")
valid_rows = []
for i, r in enumerate(df_dict):
try:
self.model(**r)
except ValidationError:
r["InterseptorOrigin"] = func.__name__
r["InterseptorIndex"] = i
self.not_valid_rows.append(r)
else:
valid_rows.append(r)
return pd.DataFrame.from_records(valid_rows)
return wrapped

0
tests/__ini__.py Normal file
View File

View File

@ -0,0 +1,63 @@
import random
import pandas as pd
import pytest
from pydantic import BaseModel
from scripts.intersept_not_valid import Interseptor
class FakeModel(BaseModel):
name: str
age: int
def test_init_composed():
interceptor = Interseptor(FakeModel)
def df_generator(nrows=3):
records = [{"name": "plop", "age": random.randint(1, 50)} for _ in range(nrows)]
return pd.DataFrame.from_records(records)
df_generator_val = interceptor(df_generator)
df = df_generator_val(3)
assert len(df) == 3
assert interceptor.not_valid_rows == []
def test_init_decorator():
interceptor = Interseptor(FakeModel)
@interceptor
def df_generator(nrows=3):
records = [{"name": "plop", "age": random.randint(1, 50)} for _ in range(nrows)]
return pd.DataFrame.from_records(records)
df = df_generator(3)
assert len(df) == 3
assert interceptor.not_valid_rows == []
def test_intersept_not_valid():
interceptor = Interseptor(FakeModel)
@interceptor
def df_generator():
records = [
{"name": "plop", "age": 12},
{"name": "hop", "age": "ui"},
{"name": "pipo", "age": 12},
]
return pd.DataFrame.from_records(records)
df = df_generator()
assert len(df) == 2
assert interceptor.not_valid_rows == [
{
"name": "hop",
"age": "ui",
"InterseptorOrigin": "df_generator",
"InterseptorIndex": 1,
}
]