DIP/python/dip-clustering-service/app/run_db_loader.py

164 lines
5.5 KiB
Python

from __future__ import annotations
import json
from typing import Any
from uuid import UUID
import numpy as np
import psycopg2
from .models import ClusteringAlgorithm, PythonClusteringItem, ReductionMethod, RunMetadata
from .settings import ServiceSettings
def load_run_and_embeddings(run_id: UUID) -> tuple[RunMetadata, list[PythonClusteringItem]]:
    """Load a cluster run's metadata and the embeddings its selection matches.

    A new database connection is opened with the DSN taken from
    ``ServiceSettings.from_env()``. The run row is read first, and its
    ``selection`` is then used to filter the embeddings that are loaded.

    Args:
        run_id: Identifier of the cluster run to load.

    Returns:
        A ``(run, items)`` tuple holding the run metadata and the list of
        selected embedding items.

    Raises:
        ValueError: If no cluster run with ``run_id`` exists.
    """
    settings = ServiceSettings.from_env()
    connection = psycopg2.connect(settings.db_dsn)
    try:
        # psycopg2's connection context manager only terminates the
        # transaction (commit on success, rollback on error); it does NOT
        # close the connection. Close it explicitly so each call does not
        # leak a server connection.
        with connection:
            run = _load_run_metadata(connection, run_id)
            items = _load_embeddings(connection, run.selection)
    finally:
        connection.close()
    return run, items
def _load_run_metadata(connection, run_id: UUID) -> RunMetadata:
    """Fetch a single cluster-run row and convert it into ``RunMetadata``.

    Args:
        connection: Open database connection used to create the cursor.
        run_id: Identifier of the run; it is passed to the query as a string.

    Returns:
        The run's metadata. The two JSON columns are decoded into dicts and a
        falsy reduction method is mapped to ``ReductionMethod.NONE``.

    Raises:
        ValueError: If the query returns no row for ``run_id``.
    """
    # Statement text is kept verbatim. Both JSON columns are cast to text and
    # defaulted to '{}' in SQL, so they arrive here as strings.
    query = """
select
id,
algorithm,
coalesce(parameters_json::text, '{}'),
reduction_method,
reduction_dimensions,
coalesce(selection_json::text, '{}')
from doc.doc_embedding_cluster_run
where id = %s
"""
    with connection.cursor() as cur:
        cur.execute(query, (str(run_id),))
        record = cur.fetchone()

    if record is None:
        raise ValueError(f"Cluster run not found: {run_id}")

    # Positional layout mirrors the select list above.
    (
        found_id,
        algorithm_name,
        parameters_text,
        reduction_name,
        reduction_dims,
        selection_text,
    ) = record

    decoded_parameters = _json_to_dict(parameters_text)
    decoded_selection = _json_to_dict(selection_text)
    algorithm = ClusteringAlgorithm(algorithm_name)
    if reduction_name:
        reduction = ReductionMethod(reduction_name)
    else:
        reduction = ReductionMethod.NONE

    return RunMetadata(
        runId=found_id,
        algorithm=algorithm,
        parameters=decoded_parameters,
        reductionMethod=reduction,
        reductionDimensions=reduction_dims,
        selection=decoded_selection,
    )
def _load_embeddings(connection, selection: dict[str, Any]) -> list[PythonClusteringItem]:
    """Load every completed, non-null embedding that matches ``selection``.

    The base query joins embeddings to their document and text
    representation. Optional filters derived from ``selection`` are appended
    as extra ``and ...`` clauses, and rows are ordered by ``e.created_at``.

    Args:
        connection: Open database connection used to create the cursor.
        selection: Filter criteria of the run; may be empty.

    Returns:
        One ``PythonClusteringItem`` per row, with the vector parsed from its
        text form.
    """
    base_query = """
select
e.id as embedding_id,
e.document_id,
e.representation_id,
e.embedding_vector::text as embedding_vector_text
from doc.doc_embedding e
join doc.doc_document d on d.id = e.document_id
join doc.doc_text_representation r on r.id = e.representation_id
where e.embedding_status = 'COMPLETED'
and e.embedding_vector is not null
"""
    fragments: list[str] = [base_query]
    bind_values: list[Any] = []
    # Mutates both lists in place: adds filter clauses and their values.
    _apply_selection_filters(selection, fragments, bind_values)
    fragments.append(" order by e.created_at asc")
    statement = "".join(fragments)

    # Passing a name makes psycopg2 use a server-side cursor, so rows are
    # fetched in batches of ``itersize`` while iterating.
    with connection.cursor(name="cluster_embedding_selection") as cur:
        cur.itersize = 2000
        cur.execute(statement, bind_values)
        return [
            PythonClusteringItem(
                embeddingId=row_embedding_id,
                documentId=row_document_id,
                representationId=row_representation_id,
                vector=_parse_vector_text(row_vector_text),
            )
            for row_embedding_id, row_document_id, row_representation_id, row_vector_text in cur
        ]
def _apply_selection_filters(selection: dict[str, Any], sql_parts: list[str], params: list[Any]) -> None:
if not selection:
return
_append_in_filter(sql_parts, params, "documentTypes", "d.document_type", selection.get("documentTypes"))
_append_in_filter(sql_parts, params, "documentFamilies", "d.document_family", selection.get("documentFamilies"))
_append_in_filter(sql_parts, params, "representationTypes", "r.representation_type", selection.get("representationTypes"))
_append_in_filter(sql_parts, params, "embeddingStatuses", "e.embedding_status", selection.get("embeddingStatuses"))
_append_in_filter(sql_parts, params, "modelIds", "e.model_id", selection.get("modelIds"))
_append_in_filter(sql_parts, params, "prefixProfileIds", "e.prefix_profile_id", selection.get("prefixProfileIds"))
_append_in_filter(sql_parts, params, "builderKeys", "r.builder_key", selection.get("builderKeys"))
_append_in_filter(sql_parts, params, "languageCodes", "r.language_code", selection.get("languageCodes"))
_append_in_filter(sql_parts, params, "ownerTenantIds", "d.owner_tenant_id", selection.get("ownerTenantIds"))
business_key_like = selection.get("businessKeyLike")
if business_key_like:
sql_parts.append(" and d.business_key like %s")
params.append(business_key_like)
created_from = selection.get("createdFrom")
if created_from:
sql_parts.append(" and d.created_at >= %s")
params.append(created_from)
created_to = selection.get("createdTo")
if created_to:
sql_parts.append(" and d.created_at < %s")
params.append(created_to)
if selection.get("primaryRepresentationOnly") is True:
sql_parts.append(" and r.is_primary = true")
def _append_in_filter(
sql_parts: list[str],
params: list[Any],
_key: str,
column_name: str,
values: list[Any] | None,
) -> None:
if not values:
return
placeholders = ", ".join(["%s"] * len(values))
sql_parts.append(f" and {column_name} in ({placeholders})")
params.extend(values)
def _parse_vector_text(raw_value: str) -> list[float]:
if raw_value is None:
return []
value = raw_value.strip()
if value.startswith("[") and value.endswith("]"):
value = value[1:-1]
if not value:
return []
vector = np.fromstring(value, sep=",", dtype=np.float32)
return vector.astype(float).tolist()
def _json_to_dict(raw_json: str | dict[str, Any] | None) -> dict[str, Any]:
if raw_json is None:
return {}
if isinstance(raw_json, dict):
return raw_json
if not raw_json.strip():
return {}
loaded = json.loads(raw_json)
return loaded if isinstance(loaded, dict) else {}