You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
74 lines
3.0 KiB
Java
74 lines
3.0 KiB
Java
package at.procon.dip.embedding.service;
|
|
|
|
import at.procon.dip.domain.document.EmbeddingStatus;
|
|
import at.procon.dip.domain.document.entity.DocumentEmbedding;
|
|
import at.procon.dip.domain.document.entity.DocumentEmbeddingModel;
|
|
import at.procon.dip.domain.document.entity.DocumentTextRepresentation;
|
|
import at.procon.dip.domain.document.repository.DocumentEmbeddingRepository;
|
|
import at.procon.dip.domain.document.repository.DocumentTextRepresentationRepository;
|
|
import at.procon.dip.domain.document.service.DocumentEmbeddingService;
|
|
import at.procon.dip.embedding.model.EmbeddingProviderResult;
|
|
import at.procon.dip.embedding.support.EmbeddingVectorCodec;
|
|
import java.time.OffsetDateTime;
|
|
import java.util.UUID;
|
|
import lombok.RequiredArgsConstructor;
|
|
import org.springframework.stereotype.Service;
|
|
import org.springframework.transaction.annotation.Transactional;
|
|
|
|
@Service
|
|
@RequiredArgsConstructor
|
|
@Transactional
|
|
public class EmbeddingPersistenceService {
|
|
|
|
private final DocumentTextRepresentationRepository representationRepository;
|
|
private final DocumentEmbeddingService documentEmbeddingService;
|
|
private final DocumentEmbeddingRepository embeddingRepository;
|
|
private final EmbeddingModelCatalogService modelCatalogService;
|
|
|
|
public DocumentEmbedding ensurePending(UUID representationId, String modelKey) {
|
|
DocumentTextRepresentation representation = representationRepository.findById(representationId)
|
|
.orElseThrow(() -> new IllegalArgumentException("Unknown representation id: " + representationId));
|
|
DocumentEmbeddingModel model = modelCatalogService.ensureRegistered(modelKey);
|
|
return documentEmbeddingService.ensurePendingEmbedding(
|
|
representation.getDocument().getId(),
|
|
representation.getId(),
|
|
model.getId()
|
|
);
|
|
}
|
|
|
|
public DocumentEmbedding markProcessing(UUID embeddingId) {
|
|
return documentEmbeddingService.markProcessing(embeddingId);
|
|
}
|
|
|
|
public void saveCompleted(UUID embeddingId, EmbeddingProviderResult result) {
|
|
if (result.vectors() == null || result.vectors().isEmpty()) {
|
|
throw new IllegalArgumentException("Embedding provider result contains no vectors");
|
|
}
|
|
float[] vector = result.vectors().getFirst();
|
|
embeddingRepository.updateEmbeddingVector(
|
|
embeddingId,
|
|
vector, //EmbeddingVectorCodec.toPgVector(vector),
|
|
result.tokenCount(),
|
|
vector.length
|
|
);
|
|
}
|
|
|
|
public void markFailed(UUID embeddingId, String errorMessage) {
|
|
embeddingRepository.updateEmbeddingStatus(
|
|
embeddingId,
|
|
EmbeddingStatus.FAILED,
|
|
errorMessage,
|
|
null
|
|
);
|
|
}
|
|
|
|
public void markSkipped(UUID embeddingId, String reason) {
|
|
embeddingRepository.updateEmbeddingStatus(
|
|
embeddingId,
|
|
EmbeddingStatus.SKIPPED,
|
|
reason,
|
|
OffsetDateTime.now()
|
|
);
|
|
}
|
|
}
|