"""Source code for TELF.post_processing.ArcticFox.arcticfox."""

from .helpers.local_labels import ClusterLabeler  
from .helpers.intial_post_process import HNMFkPostProcessor
from .helpers.post_statistics import HNMFkStatsGenerator
from pathlib import Path
import pandas as pd

class ArcticFox:
    """High-level pipeline around an HNMFk model.

    Wires together three helpers:
      * ``HNMFkPostProcessor`` — writes per-node cluster CSVs and top-words files,
      * ``ClusterLabeler`` — LLM (Ollama) cluster labeling,
      * ``HNMFkStatsGenerator`` — Peacock visual statistics.

    Parameters
    ----------
    model : object
        A fitted HNMFk model exposing ``traverse_nodes()`` (nodes are dicts with
        ``'leaf'``, ``'W'``, ``'signature'`` and ``'node_save_path'`` keys).
    embedding_model : str
        Embedding model name passed to ``ClusterLabeler``.
    distance_metric, center_metric : str
        Distance / centroid configuration for the labeler.
    text_cols : list[str] | None
        Text columns the labeler should use; ``None`` keeps its default.
    top_n_words : int
        Number of top words per cluster kept by the postprocessor.
    clean_cols_name : str
        Default cleaned-text column name used across all helpers.
    col_year, col_type, col_cluster, col_cluster_coords, col_similarity : str
        Column-name configuration forwarded to the postprocessor.
    """

    def __init__(
        self,
        model,
        embedding_model="SCINCL",
        distance_metric="cosine",
        center_metric="centroid",
        text_cols=None,
        top_n_words=50,
        clean_cols_name="clean_title_abstract",
        # Postprocessor column config
        col_year='year',
        col_type='type',
        col_cluster='cluster',
        col_cluster_coords='cluster_coordinates',
        col_similarity='similarity_to_cluster_centroid',
    ):
        self.model = model
        self.clean_cols_name = clean_cols_name

        # Embedding-based labeler used for LLM cluster annotation.
        self.labeler = ClusterLabeler(
            embedding_model=embedding_model,
            distance_metric=distance_metric,
            center_metric=center_metric,
            text_cols=text_cols,
        )

        # Produces cluster_for_k=*.csv / top_words.csv per node.
        self.postprocessor = HNMFkPostProcessor(
            top_n_words=top_n_words,
            default_clean_col=clean_cols_name,
            col_year=col_year,
            col_type=col_type,
            col_cluster=col_cluster,
            col_cluster_coords=col_cluster_coords,
            col_similarity=col_similarity,
        )

        self.stats_generator = HNMFkStatsGenerator(clean_cols_name=clean_cols_name)

        # Allow reuse of these columns in other methods
        self.col_cluster = col_cluster

    def run_full_pipeline(
        self,
        vocab,
        data_df,
        text_column=None,
        ollama_model="llama3.2:3b-instruct-fp16",
        label_clusters=True,
        generate_stats=True,
        generate_visuals=True,
        process_parents=True,
        skip_completed=True,
        label_criteria=None,
        label_info=None,
        number_of_labels=5,
    ):
        """Run post-processing, optional labeling, and optional stats generation.

        Parameters
        ----------
        vocab : sequence
            Vocabulary ``V`` for the post-processor.
        data_df : pandas.DataFrame
            Document dataframe ``D``.
        text_column : str | None
            Text column to process; defaults to ``self.clean_cols_name``.
        ollama_model : str
            Ollama model name used for cluster labeling.
        label_clusters, generate_stats : bool
            Toggle the labeling / statistics stages.
        generate_visuals : bool
            NOTE(review): accepted for interface compatibility but not referenced
            anywhere in this method — visuals are currently produced (if at all)
            by the stats stage. Confirm against the helpers.
        process_parents : bool
            Also process non-leaf nodes, not just leaves.
        skip_completed : bool
            Skip nodes whose outputs already exist.
        label_criteria, label_info : optional
            Extra guidance forwarded to the labeler.
        number_of_labels : int
            Number of candidate labels requested per cluster.
        """
        text_column = text_column or self.clean_cols_name

        print("Step 1: Post-processing W/H matrix and cluster data...")
        self.postprocessor.post_process_hnmfk(
            hnmfk_model=self.model,
            V=vocab,
            D=data_df,
            col_name=text_column,
            skip_completed=skip_completed,
            process_parents=process_parents,
        )

        if label_clusters:
            print("Step 2: Labeling clusters with LLM...")
            self._label_all_clusters(
                vocab=vocab,
                data_df=data_df,
                text_column=text_column,
                ollama_model=ollama_model,
                label_criteria=label_criteria,
                label_info=label_info,
                number_of_labels=number_of_labels,
                process_parents=process_parents,
            )

        if generate_stats:
            print("Step 3: Generating Peacock visual stats...")
            self.stats_generator.generate_cluster_stats(
                model=self.model,
                process_parents=process_parents,
                skip_completed=skip_completed,
            )

    def _label_all_clusters(
        self,
        vocab,
        data_df,
        text_column,
        ollama_model,
        label_criteria,
        label_info,
        number_of_labels,
        process_parents,
    ):
        """Label every (leaf, and optionally parent) node's clusters via Ollama.

        For each node, reads the postprocessor's ``cluster_for_k=<k>.csv`` and
        ``top_words.csv`` from the node directory, asks the labeler for
        annotations, and writes them to ``cluster_summaries.csv`` in that
        directory. Nodes missing either input file are silently skipped.
        """
        for node in self.model.traverse_nodes():
            if not (node["leaf"] or process_parents):
                continue

            w = node['W']
            if w is None:
                # Leaf nodes may carry only a signature vector; treat it as a
                # single-column W so k == 1 below.
                w = node['signature'].reshape(-1, 1)

            node_dir = Path(node["node_save_path"]).resolve().parent
            cluster_file = node_dir / f"cluster_for_k={w.shape[1]}.csv"
            top_words_file = node_dir / "top_words.csv"

            # Only label nodes the postprocessor has already written out.
            if not (cluster_file.exists() and top_words_file.exists()):
                continue

            df = pd.read_csv(cluster_file)
            top_words_df = pd.read_csv(top_words_file)
            print(node_dir)

            annotations = self.labeler.label_clusters_ollama(
                top_words_df=top_words_df,
                ollama_model_name=ollama_model,
                embedding_model=self.labeler.embedding_model,
                df=df,
                criteria=label_criteria,
                additional_information=label_info,
                number_of_labels=number_of_labels,
                embedds_use_gpu=False,
            )

            # Persist one row per cluster; 'summary' is reserved but unused here.
            pd.DataFrame([
                {self.col_cluster: k, 'label': v, 'summary': ""}
                for k, v in annotations.items()
            ]).to_csv(node_dir / "cluster_summaries.csv", index=False)

    def run_labeling(self, df, top_words_df, ollama_model_name,
                     label_criteria=None, additional_info=None, number_of_labels=5):
        """Label a single cluster dataframe directly; returns the annotations."""
        return self.labeler.label_clusters_ollama(
            top_words_df=top_words_df,
            ollama_model_name=ollama_model_name,
            embedding_model=self.labeler.embedding_model,
            df=df,
            criteria=label_criteria,
            additional_information=additional_info,
            number_of_labels=number_of_labels,
        )

    def run_postprocessing(self, V, D, col_name=None, **kwargs):
        """Run only the W/H post-processing stage on vocab ``V`` and data ``D``."""
        return self.postprocessor.post_process_hnmfk(
            hnmfk_model=self.model,
            V=V,
            D=D,
            col_name=col_name or self.clean_cols_name,
            **kwargs,
        )

    def run_stats(self, process_parents=True, skip_completed=True):
        """Run only the Peacock statistics stage."""
        return self.stats_generator.generate_cluster_stats(
            model=self.model,
            process_parents=process_parents,
            skip_completed=skip_completed,
        )