You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

74 lines
3.0 KiB
Java

package at.procon.dip.embedding.service;
import at.procon.dip.domain.document.EmbeddingStatus;
import at.procon.dip.domain.document.entity.DocumentEmbedding;
import at.procon.dip.domain.document.entity.DocumentEmbeddingModel;
import at.procon.dip.domain.document.entity.DocumentTextRepresentation;
import at.procon.dip.domain.document.repository.DocumentEmbeddingRepository;
import at.procon.dip.domain.document.repository.DocumentTextRepresentationRepository;
import at.procon.dip.domain.document.service.DocumentEmbeddingService;
import at.procon.dip.embedding.model.EmbeddingProviderResult;
import at.procon.dip.embedding.support.EmbeddingVectorCodec;
import java.time.OffsetDateTime;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
/**
 * Persists the lifecycle of document embeddings: creating pending embedding records,
 * moving them to processing, storing computed vectors, and marking them failed or skipped.
 *
 * <p>Every public method runs inside a Spring-managed transaction (class-level
 * {@code @Transactional}).
 */
@Service
@RequiredArgsConstructor
@Transactional
public class EmbeddingPersistenceService {

    private final DocumentTextRepresentationRepository representationRepository;
    private final DocumentEmbeddingService documentEmbeddingService;
    private final DocumentEmbeddingRepository embeddingRepository;
    private final EmbeddingModelCatalogService modelCatalogService;

    /**
     * Ensures a pending embedding exists for the given text representation and model.
     *
     * <p>The model is registered in the catalog if needed. Creation is delegated to
     * {@link DocumentEmbeddingService#ensurePendingEmbedding}, keyed by the owning
     * document id, the representation id, and the model id.
     *
     * @param representationId id of the text representation to embed
     * @param modelKey catalog key of the embedding model
     * @return the embedding record returned by the document embedding service
     * @throws IllegalArgumentException if no representation exists for {@code representationId}
     */
    public DocumentEmbedding ensurePending(UUID representationId, String modelKey) {
        DocumentTextRepresentation representation = representationRepository.findById(representationId)
                .orElseThrow(() -> new IllegalArgumentException("Unknown representation id: " + representationId));
        DocumentEmbeddingModel model = modelCatalogService.ensureRegistered(modelKey);
        return documentEmbeddingService.ensurePendingEmbedding(
                representation.getDocument().getId(),
                representation.getId(),
                model.getId()
        );
    }

    /**
     * Marks the embedding as being processed.
     *
     * @param embeddingId id of the embedding record
     * @return the embedding record returned by the document embedding service
     */
    public DocumentEmbedding markProcessing(UUID embeddingId) {
        return documentEmbeddingService.markProcessing(embeddingId);
    }

    /**
     * Stores the computed vector for an embedding.
     *
     * <p>Only the first vector of the provider result is persisted. It is written together
     * with the provider's token count and the vector's dimension via
     * {@link DocumentEmbeddingRepository#updateEmbeddingVector}.
     * NOTE(review): whether that query also flips the status to completed is defined in
     * the repository. Confirm there.
     *
     * @param embeddingId id of the embedding record to update
     * @param result provider output; must contain at least one non-empty vector
     * @throws IllegalArgumentException if {@code result} is null, has no vectors, or its
     *     first vector is null or has zero dimensions
     */
    public void saveCompleted(UUID embeddingId, EmbeddingProviderResult result) {
        if (result == null) {
            throw new IllegalArgumentException("Embedding provider result must not be null");
        }
        var vectors = result.vectors();
        if (vectors == null || vectors.isEmpty()) {
            throw new IllegalArgumentException("Embedding provider result contains no vectors");
        }
        float[] vector = vectors.getFirst();
        // Guard here: a null vector would otherwise surface as an NPE at vector.length, and a
        // zero-dimension vector is not a storable embedding.
        if (vector == null || vector.length == 0) {
            throw new IllegalArgumentException(
                    "Embedding provider returned an empty first vector for embedding id: " + embeddingId);
        }
        embeddingRepository.updateEmbeddingVector(
                embeddingId,
                vector,
                result.tokenCount(),
                vector.length
        );
    }

    /**
     * Marks the embedding as {@link EmbeddingStatus#FAILED} and records the error message.
     *
     * <p>NOTE(review): the fourth argument is passed as {@code null} here, whereas
     * {@link #markSkipped} passes {@code OffsetDateTime.now()}. It is presumably a
     * processed/updated timestamp. Confirm whether failed rows should record it too.
     *
     * @param embeddingId id of the embedding record
     * @param errorMessage description of the failure to store
     */
    public void markFailed(UUID embeddingId, String errorMessage) {
        embeddingRepository.updateEmbeddingStatus(
                embeddingId,
                EmbeddingStatus.FAILED,
                errorMessage,
                null
        );
    }

    /**
     * Marks the embedding as {@link EmbeddingStatus#SKIPPED}, recording the reason and the
     * current time.
     *
     * @param embeddingId id of the embedding record
     * @param reason why the embedding was skipped
     */
    public void markSkipped(UUID embeddingId, String reason) {
        embeddingRepository.updateEmbeddingStatus(
                embeddingId,
                EmbeddingStatus.SKIPPED,
                reason,
                OffsetDateTime.now()
        );
    }
}