import json
import logging
import math
from pathlib import Path
from typing import Union, Optional
import pandas as pd
import numpy as np
import torch
from tqdm.auto import tqdm
from ....helpers.embeddings import compute_embeddings
log = logging.getLogger(__name__)
[docs]
class EmbeddingPruner:
"""
Prune documents by distance from a reference-class centroid in embedding space
"""
def __init__(
self,
*,
embedding_model: str = "SCINCL",
distance_std_factor: float = 3.0,
overwrite_embeddings: bool = False,
use_gpu: Optional[bool] = None,
verbose: bool = True,
):
"""
Initialize the EmbeddingPruner.
Parameters
----------
embedding_model : str
Name of the embedding model to use.
distance_std_factor : float
Multiplier on standard deviation to set distance threshold.
overwrite_embeddings : bool
If True, always recompute embeddings even if cache exists.
use_gpu : bool or None
Whether to use GPU for embedding. If None, auto-detect.
verbose : bool
Whether to display progress bars during embedding.
"""
self.embedding_model = embedding_model
self.distance_std_factor = distance_std_factor
self.overwrite_embeddings = overwrite_embeddings
self.use_gpu = torch.cuda.is_available() if use_gpu is None else use_gpu
self.verbose = verbose
self.NAME = 'embed_prune'
[docs]
def load_or_compute_embeddings(self, df, output_dir, data_column) -> np.ndarray:
"""
Load or compute embeddings for the specified column in the DataFrame.
Parameters
----------
df : pd.DataFrame
The input DataFrame to be processed.
output_dir : str or Path
Directory to save the output files.
data_column : str
Column name containing the data to be voted on.
"""
cache_path = Path(output_dir) / "embeddings.npy"
if cache_path.exists() and not self.overwrite_embeddings:
emb = np.load(cache_path)
if emb.shape[0] == len(df):
return emb
log.warning("Cache size %d != rows %d, recomputing", emb.shape[0], len(df))
# Batch compute embeddings
def batch_embed(df_slice: pd.DataFrame) -> np.ndarray:
return compute_embeddings(
df_slice,
model_name=self.embedding_model,
cols=[data_column],
sep_token="[SEP]",
as_np=True,
use_gpu=self.use_gpu
).astype(np.float32)
try:
emb = batch_embed(df)
except Exception as e:
log.warning("Batch embed failed (%s), row-wise fallback", e)
emb = np.vstack([
batch_embed(df.iloc[i:i+1])[0][None, :]
for i in tqdm(range(len(df)), disable=not self.verbose)
])
if emb.shape[0] != len(df):
raise ValueError(f"Embeddings {emb.shape[0]} != rows {len(df)}")
np.save(cache_path, emb)
return emb
[docs]
def select_inliers(self, df, emb: np.ndarray, label_column: str, reference_label: Union[int, str]) -> np.ndarray:
"""
Compute which rows are within threshold distance to reference centroid
Parameters
----------
df : pd.DataFrame
DataFrame containing the dataset.
emb : np.ndarray
Embedding matrix of shape (n_samples, embedding_dim).
label_column : str
Column name indicating class labels.
reference_label : int | str
Label used as the reference class for centroid.
Returns
-------
inliers_mask : np.ndarray of bool
Mask indicating rows within distance threshold.
"""
# Compute centroid only on previously accepted
cluster_mask = df[label_column] == reference_label
centroid = emb[cluster_mask].mean(axis=0)
# Distances for all
dists = np.linalg.norm(emb - centroid, axis=1)
mu, sigma = dists[cluster_mask].mean(), dists[cluster_mask].std()
thresh = mu + self.distance_std_factor * sigma
# Inliers
inliers = (dists <= thresh)
log.info("Embed prune: thr=%.4f, kept %d ", thresh, inliers.sum() )
return inliers
def __call__(self, df, output_dir, label_column: str, reference_label: Union[int, str], data_column:str) -> pd.DataFrame:
"""
Execute pruning: annotate 'embed_accept' for rows that were inliers.
Saves the annotated DataFrame to CSV and returns it.
Parameters
----------
df : pd.DataFrame
The input DataFrame to be processed.
output_dir : str or Path
Directory to save the output files.
label_column : str
Column name indicating class labels.
reference_label : Union[int, str]
Label used as the reference class for centroid.
data_column : str
Column name containing the data to be voted on.
Returns
-------
df : pd.DataFrame
DataFrame with added 'embed_accept' column.
"""
output_dir = Path(output_dir) / "embed_outputs"
output_dir.mkdir(parents=True, exist_ok=True)
emb = self.load_or_compute_embeddings(df, output_dir, data_column)
inliers = self.select_inliers(df, emb, label_column, reference_label)
df[self.NAME] = inliers
out_csv = Path(output_dir) / "embed_pruned.csv"
df.to_csv(out_csv, index=False)
log.info("Saved embed-pruned DF → %s", out_csv)
return df