from __future__ import annotations import json from typing import Any from uuid import UUID import numpy as np import psycopg2 from .models import ClusteringAlgorithm, PythonClusteringItem, ReductionMethod, RunMetadata from .settings import ServiceSettings def load_run_and_embeddings(run_id: UUID) -> tuple[RunMetadata, list[PythonClusteringItem]]: settings = ServiceSettings.from_env() with psycopg2.connect(settings.db_dsn) as connection: run = _load_run_metadata(connection, run_id) items = _load_embeddings(connection, run.selection) return run, items def _load_run_metadata(connection, run_id: UUID) -> RunMetadata: with connection.cursor() as cursor: cursor.execute( """ select id, algorithm, coalesce(parameters_json::text, '{}'), reduction_method, reduction_dimensions, coalesce(selection_json::text, '{}') from doc.doc_embedding_cluster_run where id = %s """, (str(run_id),), ) row = cursor.fetchone() if row is None: raise ValueError(f"Cluster run not found: {run_id}") parameters = _json_to_dict(row[2]) selection = _json_to_dict(row[5]) return RunMetadata( runId=row[0], algorithm=ClusteringAlgorithm(row[1]), parameters=parameters, reductionMethod=ReductionMethod(row[3]) if row[3] else ReductionMethod.NONE, reductionDimensions=row[4], selection=selection, ) def _load_embeddings(connection, selection: dict[str, Any]) -> list[PythonClusteringItem]: sql_parts = [ """ select e.id as embedding_id, e.document_id, e.representation_id, e.embedding_vector::text as embedding_vector_text from doc.doc_embedding e join doc.doc_document d on d.id = e.document_id join doc.doc_text_representation r on r.id = e.representation_id where e.embedding_status = 'COMPLETED' and e.embedding_vector is not null """ ] params: list[Any] = [] _apply_selection_filters(selection, sql_parts, params) sql_parts.append(" order by e.created_at asc") sql = "".join(sql_parts) items: list[PythonClusteringItem] = [] with connection.cursor(name="cluster_embedding_selection") as cursor: cursor.itersize = 2000 cursor.execute(sql, params) for embedding_id, document_id, representation_id, vector_text in cursor: items.append( PythonClusteringItem( embeddingId=embedding_id, documentId=document_id, representationId=representation_id, vector=_parse_vector_text(vector_text), ) ) return items def _apply_selection_filters(selection: dict[str, Any], sql_parts: list[str], params: list[Any]) -> None: if not selection: return _append_in_filter(sql_parts, params, "documentTypes", "d.document_type", selection.get("documentTypes")) _append_in_filter(sql_parts, params, "documentFamilies", "d.document_family", selection.get("documentFamilies")) _append_in_filter(sql_parts, params, "representationTypes", "r.representation_type", selection.get("representationTypes")) _append_in_filter(sql_parts, params, "embeddingStatuses", "e.embedding_status", selection.get("embeddingStatuses")) _append_in_filter(sql_parts, params, "modelIds", "e.model_id", selection.get("modelIds")) _append_in_filter(sql_parts, params, "prefixProfileIds", "e.prefix_profile_id", selection.get("prefixProfileIds")) _append_in_filter(sql_parts, params, "builderKeys", "r.builder_key", selection.get("builderKeys")) _append_in_filter(sql_parts, params, "languageCodes", "r.language_code", selection.get("languageCodes")) _append_in_filter(sql_parts, params, "ownerTenantIds", "d.owner_tenant_id", selection.get("ownerTenantIds")) business_key_like = selection.get("businessKeyLike") if business_key_like: sql_parts.append(" and d.business_key like %s") params.append(business_key_like) created_from = selection.get("createdFrom") if created_from: sql_parts.append(" and d.created_at >= %s") params.append(created_from) created_to = selection.get("createdTo") if created_to: sql_parts.append(" and d.created_at < %s") params.append(created_to) if selection.get("primaryRepresentationOnly") is True: sql_parts.append(" and r.is_primary = true") def _append_in_filter( sql_parts: list[str], params: list[Any], _key: str, column_name: str, values: list[Any] | None, ) -> None: if not values: return placeholders = ", ".join(["%s"] * len(values)) sql_parts.append(f" and {column_name} in ({placeholders})") params.extend(values) def _parse_vector_text(raw_value: str) -> list[float]: if raw_value is None: return [] value = raw_value.strip() if value.startswith("[") and value.endswith("]"): value = value[1:-1] if not value: return [] vector = np.fromstring(value, sep=",", dtype=np.float32) return vector.astype(float).tolist() def _json_to_dict(raw_json: str | dict[str, Any] | None) -> dict[str, Any]: if raw_json is None: return {} if isinstance(raw_json, dict): return raw_json if not raw_json.strip(): return {} loaded = json.loads(raw_json) return loaded if isinstance(loaded, dict) else {}