# Source code for TELF.pre_processing.Squirrel.squirrel

import pandas as pd
import logging
from pathlib import Path
from typing import Union, List
from .pruners.embed_prune import EmbeddingPruner

log = logging.getLogger(__name__)

class Squirrel:
    """
    Orchestrate a sequence of pruners that each take a DataFrame and return a DataFrame.
    """

    def __init__(
        self,
        data_source: Union[str, Path, pd.DataFrame],
        output_dir: Union[str, Path],
        pipeline: List,
        label_column='type',
        reference_label=0,
        aggregrate_prune=True,
        data_column='title_abstract',
    ):
        """
        Parameters
        ----------
        data_source : str | Path | pd.DataFrame
            CSV path (read with ``pd.read_csv``) or initial DataFrame to process.
            A DataFrame argument is copied, so the caller's object is not modified
            by this constructor.
        output_dir : str | Path
            Base directory for pruner outputs; created (with parents) if missing.
        pipeline : list
            List of pruner instances; each ``__call__`` must return a DataFrame.
            If empty or None, a single ``EmbeddingPruner`` built from the loaded
            DataFrame is used instead.
        label_column : str, default 'type'
            Forwarded to every pruner as its label-column argument.
        reference_label : default 0
            Forwarded to every pruner as its reference-label argument.
        aggregrate_prune : bool, default True
            If True, after each pruner only the rows whose ``pruner.NAME`` column
            is True are passed to the next stage. If False, the full DataFrame
            returned by each pruner is passed on. (The parameter's spelling is
            kept for backward compatibility with existing callers.)
        data_column : str, default 'title_abstract'
            Forwarded to every pruner as its data-column argument.
        """
        self.label_column = label_column
        self.reference_label = reference_label
        self.aggregrate_prune = aggregrate_prune
        self.data_column = data_column

        # Load from disk, or copy so later stages never touch the caller's frame.
        if isinstance(data_source, (str, Path)):
            self._df = pd.read_csv(data_source)
        else:
            self._df = data_source.copy()

        self._output_dir = Path(output_dir)
        self._output_dir.mkdir(parents=True, exist_ok=True)

        # Previously the default was assigned and then unconditionally
        # overwritten by `pipeline`, so an empty/None pipeline never ran anything.
        # NOTE(review): constructor call kept as originally written — confirm
        # EmbeddingPruner accepts the DataFrame as its first argument.
        if pipeline:
            self._pipeline = pipeline
        else:
            self._pipeline = [EmbeddingPruner(self._df)]

    def __call__(self) -> pd.DataFrame:
        """
        Run each pruner in sequence, feeding each one the DataFrame produced
        by the previous stage.

        Each pruner is called as
        ``pruner(df, output_dir, label_column, reference_label, data_column)``.
        When ``aggregrate_prune`` is True, the returned frame is filtered to the
        rows whose ``pruner.NAME`` column is True before being handed on;
        otherwise the returned frame is handed on unfiltered.

        Returns
        -------
        pd.DataFrame
            The DataFrame after the last stage. It is also written to
            ``<output_dir>/squirrel_pruned.csv`` without the index.
        """
        df = self._df
        for pruner in self._pipeline:
            result_df = pruner(
                df,
                self._output_dir,
                self.label_column,
                self.reference_label,
                self.data_column,
            )
            if self.aggregrate_prune:
                # Boolean-mask selection: keep only rows this pruner accepted.
                df = result_df[result_df[pruner.NAME]]
            else:
                # Chain the full result; previously it was silently discarded.
                df = result_df
        df.to_csv(self._output_dir / "squirrel_pruned.csv", index=False)
        return df