"""
Utility functions for data loading, saving, and visualization.
"""
from __future__ import annotations
import csv
import json
from collections import defaultdict
from typing import TYPE_CHECKING
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from clusx.errors import MissingClusterColumnError, MissingParametersError
from clusx.logging import get_logger
from clusx.utils import to_numpy
if TYPE_CHECKING:
from typing import Optional, Union
import torch
from numpy.typing import NDArray
EmbeddingTensor = Union[torch.Tensor, NDArray[np.float32]]
logger = get_logger(__name__)
[docs]
def is_csv_file(input_file: str) -> bool:
    """
    Determine if a file is a CSV file based on extension and content.

    A path ending in ``.csv`` (case-insensitive) is accepted without opening
    the file. Any other path is opened as UTF-8 text and its first 4096
    characters are handed to ``csv.Sniffer``; the file counts as CSV only
    when the detected delimiter is a comma, semicolon, or tab.

    Args:
        input_file: Path to the input file

    Returns:
        bool: True if the file is likely a CSV, False otherwise. A file that
        cannot be opened, is empty, or is not valid UTF-8 text yields False.
    """
    # The extension alone is treated as sufficient evidence.
    if input_file.lower().endswith(".csv"):
        return True

    try:
        with open(input_file, "r", encoding="utf-8") as f:
            # A bounded sample is enough for delimiter detection.
            sample = f.read(4096)
    except OSError:
        logger.warning("Error accessing file %s", input_file)
        return False
    except UnicodeDecodeError:
        # UnicodeDecodeError is a ValueError, not an OSError: binary or
        # non-UTF-8 content must answer "not a CSV" rather than crash.
        return False

    if not sample:
        return False

    try:
        dialect = csv.Sniffer().sniff(sample)
    except csv.Error:
        # The sniffer could not settle on a dialect, so this is not CSV.
        return False
    return dialect.delimiter in (",", ";", "\t")
[docs]
def load_data(input_file: str, column: Optional[str] = None) -> list[str]:
    """
    Read a collection of texts from a plain-text or CSV file.

    The file type is decided by ``is_csv_file``. For a CSV file the non-null
    values of ``column`` are returned; for any other file every non-blank
    line, stripped of surrounding whitespace, becomes one text.

    Args:
        input_file: Path to the input file (text or CSV)
        column: Name of the CSV column holding the texts (required for CSV
            files; not consulted for plain-text files)

    Returns:
        list[str]: A list of texts

    Raises:
        ValueError: If the file is a CSV and ``column`` is None or is not
            one of the CSV's columns
    """
    if not is_csv_file(input_file):
        # Plain-text input: one text per line, blank lines skipped.
        with open(input_file, "r", encoding="utf-8") as f:
            return [stripped for stripped in map(str.strip, f) if stripped]

    if column is None:
        raise ValueError("Column name must be specified when using a CSV file")

    frame = pd.read_csv(input_file, skip_blank_lines=True)
    if column not in frame.columns:
        logger.warning(
            "Column '%s' not found in CSV. Available columns: %s",
            column,
            ", ".join(frame.columns),
        )
        raise ValueError(f"Column '{column}' not found in the CSV file")

    # NaN cells are dropped before the column is converted to a list.
    return frame[column].dropna().tolist()
[docs]
def save_clusters_to_csv(
    output_file: str,
    texts: list[str],
    clusters: list[int],
    model_name: str,
    alpha: float,
    sigma: float,
    kappa: float,
) -> None:
    """
    Write clustering results to a CSV file, one row per text.

    The file has the columns ``Text``, ``Cluster_<model_name>``, ``Alpha``,
    ``Sigma`` and ``Kappa``; the three parameter values are repeated on
    every row.

    Args:
        output_file: Path to the output CSV file
        texts: List of text strings
        clusters: List of cluster assignments
        model_name: Name of the clustering model (used in the cluster
            column's name)
        alpha: Concentration parameter
        sigma: Discount parameter
        kappa: Kappa parameter for likelihood model
    """
    row_count = len(texts)
    columns = {
        "Text": texts,
        f"Cluster_{model_name}": clusters,
    }
    # Each scalar parameter becomes a constant column of the same length.
    for label, value in (("Alpha", alpha), ("Sigma", sigma), ("Kappa", kappa)):
        columns[label] = [value] * row_count

    frame = pd.DataFrame(columns)
    frame.to_csv(
        output_file,
        index=False,
        encoding="utf-8",
        quoting=csv.QUOTE_MINIMAL,
    )
    logger.debug("Clustering results saved to %s", output_file)
[docs]
def save_clusters_to_json(
    output_file: str,
    texts: list[str],
    clusters: list[int],
    model_name: str,
    alpha: float,
    sigma: float,
    kappa: float,
) -> None:
    """
    Write clustering results to a JSON file, grouped by cluster.

    The document has a ``clusters`` list and a ``metadata`` object. Each
    cluster entry carries a 1-based ``id`` assigned in order of first
    appearance (not the original cluster label), a ``representative`` (the
    first text seen for that cluster) and the full ``members`` list.

    Args:
        output_file: Path to the output JSON file
        texts: List of text strings
        clusters: List of cluster assignments
        model_name: Name of the clustering model
        alpha: Concentration parameter
        sigma: Discount parameter
        kappa: Kappa parameter for likelihood model
    """
    # Bucket the texts under their cluster label, keeping first-seen order.
    members_by_cluster = defaultdict(list)
    for text, cluster_id in zip(texts, clusters):
        members_by_cluster[cluster_id].append(text)

    payload = {
        "clusters": [
            {
                "id": position,
                "representative": members[0],
                "members": members,
            }
            for position, members in enumerate(
                members_by_cluster.values(), start=1
            )
        ],
        "metadata": {
            "model_name": model_name,
            "alpha": alpha,
            "sigma": sigma,
            "kappa": kappa,
        },
    }

    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)
    logger.debug("JSON clusters saved to %s", output_file)
[docs]
def get_embeddings(texts: list[str]) -> np.ndarray:
    """
    Compute one embedding per text and collect them in a NumPy array.

    Every text is passed to ``DirichletProcess.get_embedding`` and the
    result is converted with ``to_numpy``. Progress is reported through a
    tqdm bar whose label carries the time the computation started.

    Args:
        texts: List of text strings

    Returns:
        Numpy array built from the per-text embeddings, in input order
    """
    from datetime import datetime

    from clusx.clustering import DirichletProcess

    # TODO: Extract embedding generation to a separate function/class
    # The process object is only needed for its embedding method here, so
    # it is created with default parameters.
    embedder = DirichletProcess(alpha=1.0, kappa=1.0)

    started_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    progress = tqdm(
        texts,
        desc=f"{started_at} - INFO - Computing embeddings",
        total=len(texts),
        disable=None,  # Disable on non-TTY
        unit=" texts",
    )
    return np.array([to_numpy(embedder.get_embedding(text)) for text in progress])
[docs]
def load_cluster_assignments(csv_path: str) -> tuple[list[int], dict[str, float]]:
    """
    Load cluster assignments and parameters from a CSV file.

    The assignments come from the first column whose name starts with
    ``cluster_`` (case-insensitive, e.g. ``Cluster_PYP`` or ``Cluster_DP``).
    The parameters are read from the first row of the ``Alpha``, ``Sigma``
    and ``Kappa`` columns.

    Args:
        csv_path: Path to the CSV file containing cluster assignments

    Returns:
        tuple[list[int], dict[str, float]]: A tuple containing:
            - List of cluster assignments
            - Dictionary of parameters (alpha, sigma, kappa)

    Raises:
        MissingClusterColumnError: If no cluster column is found in the file
        MissingParametersError: If a parameter column is absent, the file
            has no data rows, or the first-row value is blank (NaN)
    """
    df = pd.read_csv(csv_path)

    # Take the first column that looks like a cluster-assignment column.
    cluster_column = next(
        (col for col in df.columns if col.lower().startswith("cluster_")),
        None,
    )
    if cluster_column is None:
        raise MissingClusterColumnError(csv_path)

    cluster_assignments = df[cluster_column].tolist()

    params = {}
    for column, key in (("Alpha", "alpha"), ("Sigma", "sigma"), ("Kappa", "kappa")):
        if column not in df.columns:
            continue
        values = df[column]
        # A header-only file has no first row; report the parameter as
        # missing instead of failing with IndexError on .iloc[0].
        if values.empty:
            continue
        first_value = values.iloc[0]
        # Blank cells are read as NaN, which float() would pass through
        # silently; treat them as missing.
        if pd.isna(first_value):
            continue
        params[key] = float(first_value)

    missing_params = [
        key for key in ("alpha", "sigma", "kappa") if key not in params
    ]
    if missing_params:
        raise MissingParametersError(csv_path, missing_params)

    return cluster_assignments, params
[docs]
def load_parameters_from_json(json_path: str) -> dict[str, float]:
    """
    Read the clustering parameters stored in a results JSON file.

    Values are taken from the top-level ``metadata`` object and converted
    with ``float``. Any parameter not present there keeps its default
    (alpha 1.0, sigma 0.0, kappa 1.0). If the file cannot be opened or is
    not valid JSON, the error is logged and the defaults are returned.

    Args:
        json_path: Path to the JSON file containing clustering results

    Returns:
        dict[str, float]: A dictionary of parameters (alpha, sigma, kappa)
    """
    # TODO: Do I really need defaults?
    params = {"alpha": 1.0, "sigma": 0.0, "kappa": 1.0}  # Default values

    try:
        with open(json_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        # TODO: Should I throw an error if metadata or its keys are missing?
        # Only override the defaults the metadata actually provides.
        if "metadata" in data:
            metadata = data["metadata"]
            for key in ("alpha", "sigma", "kappa"):
                if key in metadata:
                    params[key] = float(metadata[key])
    except (OSError, json.JSONDecodeError) as err:
        logger.error("Error loading parameters from JSON: %s", err)

    return params