import logging
from abc import abstractmethod
from collections.abc import Callable
from pathlib import Path

import pandas as pd
from pydantic import BaseModel, Field


class Source(BaseModel):
    """Abstract description of a tabular data source.

    Concrete subclasses implement ``get_df`` to load the file named by
    ``filename`` (resolved against a base directory) into a DataFrame.
    """

    # Path of the source file, relative to the base directory given to get_df.
    # May contain a "*" glob pattern (expanded by extract_sources).
    filename: str

    @abstractmethod
    def get_df(self, base_path: Path) -> pd.DataFrame:
        """Read the source file under *base_path* and return it as a DataFrame.

        NOTE: the original abstract signature took no ``base_path``, but every
        subclass and every call site passes one; the parameter is declared here
        so the interface matches the actual contract.
        """
        raise NotImplementedError


class ExcelSource(Source):
    """Source backed by a single worksheet of an Excel workbook."""

    # Name of the worksheet to read from the workbook.
    sheet_name: str

    def get_df(self, base_path: Path) -> pd.DataFrame:
        """Load ``self.sheet_name`` from the workbook located under *base_path*."""
        filepath = base_path / self.filename
        logging.debug(f"Get content of {filepath}")
        df = pd.read_excel(filepath, sheet_name=self.sheet_name)
        return df


class CSVSource(Source):
    """Source backed by a CSV file."""

    # Extra keyword arguments forwarded verbatim to pandas.read_csv
    # (pydantic deep-copies this default per instance, so sharing is safe).
    options: dict = {}

    def get_df(self, base_path: Path) -> pd.DataFrame:
        """Read ``self.filename`` under *base_path* with ``pandas.read_csv``."""
        filepath = base_path / self.filename
        logging.debug(f"Get content of {filepath}")
        df = pd.read_csv(filepath, **self.options)
        return df


class Transformation(BaseModel):
    """A transformation step: a callable plus the extra keyword args to call it with."""

    # Callable applied to the list of extracted DataFrames; it is invoked as
    # function(src_df, **extra_kwrds) — see consume_flux.
    function: Callable
    # Extra keyword arguments forwarded to `function`. The misspelled name
    # ("kwrds") is part of the public field interface and must be kept.
    extra_kwrds: dict = {}


def to_csv(df, dest_basename: Path) -> Path:
    """Write *df* to ``<dest_basename>.csv`` and return the written path.

    If the file already exists, rows are appended without repeating the
    header; otherwise a fresh file (with header) is created. The index is
    never written.
    """
    dest = dest_basename.parent / f"{dest_basename.stem}.csv"
    if not dest.exists():
        df.to_csv(dest, index=False)
    else:
        # Header was already written on a previous call; append data rows only.
        df.to_csv(dest, mode="a", header=False, index=False)
    return dest


def to_excel(df, dest_basename: Path) -> Path:
    """Write *df* to ``<dest_basename>.xlsx`` and return the written path.

    Raises:
        ValueError: if the destination file already exists — unlike
            :func:`to_csv`, appending to an already-written workbook is
            not supported, so overwriting is refused.
    """
    dest = dest_basename.parent / (dest_basename.stem + ".xlsx")
    if dest.exists():
        # Fixed typo in the original message ("exits" -> "exists").
        raise ValueError(f"The destination exists {dest}")
    df.to_excel(dest)
    return dest


class Destination(BaseModel):
    """Where a flux's result is written: a base file name plus a writer callable."""

    # Base file name (without extension) of the output file(s).
    name: str
    # Writer used when no override is supplied to write(); defaults to CSV output.
    writer: Callable = Field(to_csv)

    def _write(
        self,
        df: pd.DataFrame,
        dest_basename: Path,
        writing_func: Callable | None = None,
    ) -> Path:
        """Write *df* with *writing_func*, falling back to ``self.writer``."""
        func = self.writer if writing_func is None else writing_func
        return func(df, dest_basename)

    def write(
        self, df: pd.DataFrame, dest_path: Path, writing_func: Callable | None = None
    ) -> list[Path]:
        """Write *df* under ``dest_path / self.name``; return written paths as a list."""
        return [self._write(df, dest_path / self.name, writing_func)]


class SplitDestination(Destination):
    """Destination that writes one file per distinct value of *split_column*."""

    # Column whose distinct values partition the output into separate files.
    split_column: str

    def write(
        self, df: pd.DataFrame, dest_path: Path, writing_func: Callable | None = None
    ) -> list[Path]:
        """Write one file per unique value of ``self.split_column``.

        Each file is named ``<name>-<value>`` and contains only the rows whose
        split column equals that value. Files are written in first-seen order
        of the column values (``unique()`` preserves encounter order).
        """
        return [
            self._write(
                df[df[self.split_column] == value],
                dest_path / f"{self.name}-{value}",
                writing_func,
            )
            for value in df[self.split_column].unique()
        ]


class Flux(BaseModel):
    """One pipeline: extract `sources`, apply `transformation`, write to `destination`."""

    # Input files to extract (globs allowed in filenames; see extract_sources).
    sources: list[Source]
    # Callable (plus kwargs) combining the extracted DataFrames into one.
    transformation: Transformation
    # Where the transformed DataFrame is written.
    destination: Destination


def write_split_by(
    df: pd.DataFrame, column: str, dest_path: Path, name: str, writing_func
) -> list[Path]:
    """Write *df* as one file per distinct value of *column*.

    For each unique value (in first-seen order), the matching rows are passed
    to *writing_func* with the basename ``dest_path / f"{name}-{value}"``.
    Returns the paths reported by *writing_func*, in the same order.
    """
    return [
        writing_func(df[df[column] == value], dest_path / f"{name}-{value}")
        for value in df[column].unique()
    ]


def extract_sources(sources: list[Source], base_path: Path = Path()):
    """Yield ``(filename, DataFrame)`` pairs for every source.

    A ``filename`` containing ``*`` is treated as a glob pattern relative to
    *base_path* and expanded recursively into one concrete source per match.

    Raises:
        FileNotFoundError: when a non-glob source file does not exist.
    """
    for src in sources:
        if "*" in src.filename:
            # Expand the glob into one concrete source per matching file,
            # then recurse so each gets the normal (non-glob) treatment.
            expanded_src = [
                src.model_copy(update={"filename": str(p.relative_to(base_path))})
                for p in base_path.glob(src.filename)
            ]
            yield from extract_sources(expanded_src, base_path)
        else:
            filepath = base_path / src.filename
            # The original `assert filepath.exists` was missing the call: a
            # bound method is always truthy, so the check never fired. Use an
            # explicit exception (asserts are also stripped under -O).
            if not filepath.exists():
                raise FileNotFoundError(filepath)
            yield src.filename, src.get_df(base_path)


def split_duplicates(
    df, origin: str, duplicated: dict[str, pd.DataFrame]
) -> tuple[pd.DataFrame, dict[str, pd.DataFrame]]:
    """Split *df* into unique rows and duplicated rows.

    Rows flagged by ``DataFrame.duplicated()`` (every repeat of an earlier
    row) are stored in *duplicated* under the *origin* key; the dict is
    mutated in place and also returned for convenience.

    Returns:
        The de-duplicated DataFrame and the updated registry.

    Note: the original return annotation ``[pd.DataFrame, dict[...]]`` was a
    list literal, not a valid type — fixed to ``tuple[...]``.
    """
    dup_mask = df.duplicated()
    duplicated[origin] = df[dup_mask]
    return df[~dup_mask], duplicated


def consume_flux(
    name: str,
    flux: Flux,
    origin_path: Path,
    dest_path: Path,
    duplicated: dict[str, pd.DataFrame] | None = None,
):
    """Run one flux: extract its sources, transform them, write the result.

    Args:
        name: label of the flux, used for logging only.
        flux: the flux definition (sources, transformation, destination).
        origin_path: base directory the source filenames are resolved against.
        dest_path: directory the destination writes into.
        duplicated: optional shared registry of duplicated rows, updated in
            place per source file. A fresh dict is created when omitted.

    Returns:
        The list of written file paths.
    """
    # The original default was a mutable `duplicated={}`, silently shared
    # across calls; create a fresh dict per call instead.
    if duplicated is None:
        duplicated = {}

    logging.info(f"Consume {name}")
    src_df = []
    for filename, df in extract_sources(flux.sources, origin_path):
        # Original message was f"Extracting (unknown)" — an f-string with no
        # placeholder; log the actual filename being extracted.
        logging.info(f"Extracting {filename}")
        df, duplicated = split_duplicates(df, str(filename), duplicated)
        src_df.append(df)

    logging.info(f"Execute {flux.transformation.function.__name__}")
    df = flux.transformation.function(src_df, **flux.transformation.extra_kwrds)

    files = flux.destination.write(df, dest_path)

    logging.info(f"{files} written")
    return files


def consume_fluxes(
    fluxes: dict[str, Flux],
    origin_path: Path,
    dest_path: Path,
):
    """Consume every flux in *fluxes*, in dict order.

    A single duplicate registry is shared across all fluxes so duplicates are
    tracked per source file across the whole run. Returns every written path.
    """
    duplicated: dict = {}
    wrote_files: list[Path] = []
    for name, flux in fluxes.items():
        wrote_files.extend(consume_flux(name, flux, origin_path, dest_path, duplicated))
    return wrote_files