164 lines
5.5 KiB
Python
164 lines
5.5 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from typing import Any
|
|
from uuid import UUID
|
|
|
|
import numpy as np
|
|
import psycopg2
|
|
|
|
from .models import ClusteringAlgorithm, PythonClusteringItem, ReductionMethod, RunMetadata
|
|
from .settings import ServiceSettings
|
|
|
|
|
|
def load_run_and_embeddings(run_id: UUID) -> tuple[RunMetadata, list[PythonClusteringItem]]:
    """Load a clustering run's metadata and the embeddings selected by that run.

    A connection is opened from ``ServiceSettings.from_env().db_dsn``. The run row is
    read, and then the embeddings matching the run's ``selection`` are loaded, all
    within a single transaction.

    Args:
        run_id: Identifier of the ``doc.doc_embedding_cluster_run`` row to load.

    Returns:
        A ``(run, items)`` tuple with the run metadata and the parsed embedding items.

    Raises:
        ValueError: If no cluster run with ``run_id`` exists.
    """
    settings = ServiceSettings.from_env()
    connection = psycopg2.connect(settings.db_dsn)
    try:
        # ``with connection`` only scopes the transaction (commit on success,
        # rollback on error). psycopg2 does NOT close the connection when the
        # block exits, so it is closed explicitly in ``finally`` to avoid leaking
        # one connection per call.
        with connection:
            run = _load_run_metadata(connection, run_id)
            items = _load_embeddings(connection, run.selection)
    finally:
        connection.close()
    return run, items
|
|
|
|
|
|
def _load_run_metadata(connection, run_id: UUID) -> RunMetadata:
    """Read one ``doc.doc_embedding_cluster_run`` row and convert it to ``RunMetadata``.

    In SQL, NULL ``parameters_json`` / ``selection_json`` columns are coalesced to
    ``'{}'`` before being decoded by ``_json_to_dict``. A NULL or empty
    ``reduction_method`` maps to ``ReductionMethod.NONE``.

    Raises:
        ValueError: If no row has the given ``run_id``.
    """
    query = """
        select
            id,
            algorithm,
            coalesce(parameters_json::text, '{}'),
            reduction_method,
            reduction_dimensions,
            coalesce(selection_json::text, '{}')
        from doc.doc_embedding_cluster_run
        where id = %s
    """
    with connection.cursor() as cursor:
        cursor.execute(query, (str(run_id),))
        record = cursor.fetchone()

    if record is None:
        raise ValueError(f"Cluster run not found: {run_id}")

    (
        run_pk,
        algorithm_name,
        parameters_text,
        reduction_name,
        reduction_dims,
        selection_text,
    ) = record

    # Decode JSON first, then the enums, keeping the original conversion order.
    parameters = _json_to_dict(parameters_text)
    selection = _json_to_dict(selection_text)
    algorithm = ClusteringAlgorithm(algorithm_name)
    if reduction_name:
        reduction = ReductionMethod(reduction_name)
    else:
        reduction = ReductionMethod.NONE

    return RunMetadata(
        runId=run_pk,
        algorithm=algorithm,
        parameters=parameters,
        reductionMethod=reduction,
        reductionDimensions=reduction_dims,
        selection=selection,
    )
|
|
|
|
|
|
def _load_embeddings(connection, selection: dict[str, Any]) -> list[PythonClusteringItem]:
    """Load embeddings matching ``selection`` and parse their vectors.

    Only rows with ``embedding_status = 'COMPLETED'`` and a non-null vector are
    considered. Additional predicates come from ``_apply_selection_filters``.

    Args:
        connection: An open psycopg2 connection. It must be inside a transaction,
            because a named cursor is used.
        selection: Selection criteria decoded from the run's ``selection_json``.

    Returns:
        One ``PythonClusteringItem`` per row, ordered by ``created_at`` and then ``id``.
    """
    sql_parts = [
        """
        select
            e.id as embedding_id,
            e.document_id,
            e.representation_id,
            e.embedding_vector::text as embedding_vector_text
        from doc.doc_embedding e
        join doc.doc_document d on d.id = e.document_id
        join doc.doc_text_representation r on r.id = e.representation_id
        where e.embedding_status = 'COMPLETED'
          and e.embedding_vector is not null
        """
    ]
    params: list[Any] = []
    _apply_selection_filters(selection, sql_parts, params)
    # Tie-break on the primary key so rows that share a created_at timestamp come
    # back in a stable order. Downstream clustering may be sensitive to input order.
    sql_parts.append(" order by e.created_at asc, e.id asc")
    sql = "".join(sql_parts)

    # A named cursor (psycopg2 server-side cursor) fetches rows in batches of
    # ``itersize`` while iterating, rather than all at once on execute.
    with connection.cursor(name="cluster_embedding_selection") as cursor:
        cursor.itersize = 2000
        cursor.execute(sql, params)
        return [
            PythonClusteringItem(
                embeddingId=embedding_id,
                documentId=document_id,
                representationId=representation_id,
                vector=_parse_vector_text(vector_text),
            )
            for embedding_id, document_id, representation_id, vector_text in cursor
        ]
|
|
|
|
|
|
def _apply_selection_filters(selection: dict[str, Any], sql_parts: list[str], params: list[Any]) -> None:
|
|
if not selection:
|
|
return
|
|
|
|
_append_in_filter(sql_parts, params, "documentTypes", "d.document_type", selection.get("documentTypes"))
|
|
_append_in_filter(sql_parts, params, "documentFamilies", "d.document_family", selection.get("documentFamilies"))
|
|
_append_in_filter(sql_parts, params, "representationTypes", "r.representation_type", selection.get("representationTypes"))
|
|
_append_in_filter(sql_parts, params, "embeddingStatuses", "e.embedding_status", selection.get("embeddingStatuses"))
|
|
_append_in_filter(sql_parts, params, "modelIds", "e.model_id", selection.get("modelIds"))
|
|
_append_in_filter(sql_parts, params, "prefixProfileIds", "e.prefix_profile_id", selection.get("prefixProfileIds"))
|
|
_append_in_filter(sql_parts, params, "builderKeys", "r.builder_key", selection.get("builderKeys"))
|
|
_append_in_filter(sql_parts, params, "languageCodes", "r.language_code", selection.get("languageCodes"))
|
|
_append_in_filter(sql_parts, params, "ownerTenantIds", "d.owner_tenant_id", selection.get("ownerTenantIds"))
|
|
|
|
business_key_like = selection.get("businessKeyLike")
|
|
if business_key_like:
|
|
sql_parts.append(" and d.business_key like %s")
|
|
params.append(business_key_like)
|
|
|
|
created_from = selection.get("createdFrom")
|
|
if created_from:
|
|
sql_parts.append(" and d.created_at >= %s")
|
|
params.append(created_from)
|
|
|
|
created_to = selection.get("createdTo")
|
|
if created_to:
|
|
sql_parts.append(" and d.created_at < %s")
|
|
params.append(created_to)
|
|
|
|
if selection.get("primaryRepresentationOnly") is True:
|
|
sql_parts.append(" and r.is_primary = true")
|
|
|
|
|
|
def _append_in_filter(
|
|
sql_parts: list[str],
|
|
params: list[Any],
|
|
_key: str,
|
|
column_name: str,
|
|
values: list[Any] | None,
|
|
) -> None:
|
|
if not values:
|
|
return
|
|
placeholders = ", ".join(["%s"] * len(values))
|
|
sql_parts.append(f" and {column_name} in ({placeholders})")
|
|
params.extend(values)
|
|
|
|
|
|
def _parse_vector_text(raw_value: str) -> list[float]:
|
|
if raw_value is None:
|
|
return []
|
|
|
|
value = raw_value.strip()
|
|
if value.startswith("[") and value.endswith("]"):
|
|
value = value[1:-1]
|
|
|
|
if not value:
|
|
return []
|
|
|
|
vector = np.fromstring(value, sep=",", dtype=np.float32)
|
|
return vector.astype(float).tolist()
|
|
|
|
|
|
def _json_to_dict(raw_json: str | dict[str, Any] | None) -> dict[str, Any]:
|
|
if raw_json is None:
|
|
return {}
|
|
if isinstance(raw_json, dict):
|
|
return raw_json
|
|
if not raw_json.strip():
|
|
return {}
|
|
loaded = json.loads(raw_json)
|
|
return loaded if isinstance(loaded, dict) else {}
|