2024-02-18 16:40:52 +00:00
|
|
|
from collections.abc import Callable
from functools import wraps

import pandas as pd
from pydantic import BaseModel, ValidationError
|
|
|
|
|
|
|
|
|
2024-02-21 07:46:11 +00:00
|
|
|
class ValidationInterseptor:
    """Decorator that validates each row of a returned DataFrame against a model.

    An instance wraps a function returning a ``pandas.DataFrame``. After the
    wrapped function runs, every row is passed to ``model`` as keyword
    arguments. Rows that raise ``ValidationError`` are removed from the
    result and stored on ``not_valid_rows`` together with details of the call
    that produced them; the remaining rows are returned as a new DataFrame.

    Attributes:
        model: Callable invoked as ``model(**row)`` to validate a single row.
        not_valid_rows: Rejected rows (as dicts). The list is shared by every
            function decorated with this instance and grows across calls.
    """

    def __init__(self, model: "type[BaseModel]") -> None:
        """Store the validating model and start with no rejected rows.

        Args:
            model: Model class instantiated with each row's fields; a
                ``ValidationError`` raised by it marks the row as invalid.
        """
        self.model = model
        self.not_valid_rows: list[dict] = []

    def __call__(
        self, func: Callable[..., pd.DataFrame]
    ) -> Callable[..., pd.DataFrame]:
        """Wrap ``func`` so that its result only contains rows the model accepts.

        Args:
            func: Function returning a DataFrame whose column names are valid
                keyword-argument names for the model.

        Returns:
            A function with the same call signature as ``func`` that returns
            a DataFrame of the valid rows. The frame keeps the columns of the
            original result (even when no row is valid) but not its index.
        """

        @wraps(func)  # keep the decorated function's name and docstring
        def wrapped(*args, **kwrds):
            res = func(*args, **kwrds)
            valid_rows = []
            for i, r in enumerate(res.to_dict(orient="records")):
                try:
                    self.model(**r)
                except ValidationError:
                    # Annotate the rejected row with the call that produced
                    # it; the index is the row's position in the result, not
                    # its DataFrame index label.
                    r["ValidationInterseptorFunc"] = func.__name__
                    r["ValidationInterseptorArgs"] = args
                    r["ValidationInterseptorKwrds"] = kwrds
                    r["ValidationInterseptorIndex"] = i
                    self.not_valid_rows.append(r)
                else:
                    valid_rows.append(r)
            # Passing the original columns keeps the schema intact when every
            # row was rejected or the wrapped function returned an empty frame.
            return pd.DataFrame.from_records(valid_rows, columns=res.columns)

        return wrapped
|