embedding nv2
parent
d7369c796c
commit
6a9696a6a7
@ -0,0 +1,96 @@
|
||||
package at.procon.dip.embedding.job.entity;
|
||||
|
||||
import at.procon.dip.architecture.SchemaNames;
|
||||
import at.procon.dip.embedding.model.EmbeddingJobStatus;
|
||||
import at.procon.dip.embedding.model.EmbeddingJobType;
|
||||
import jakarta.persistence.Column;
|
||||
import jakarta.persistence.Entity;
|
||||
import jakarta.persistence.EnumType;
|
||||
import jakarta.persistence.Enumerated;
|
||||
import jakarta.persistence.GeneratedValue;
|
||||
import jakarta.persistence.GenerationType;
|
||||
import jakarta.persistence.Id;
|
||||
import jakarta.persistence.Index;
|
||||
import jakarta.persistence.PrePersist;
|
||||
import jakarta.persistence.PreUpdate;
|
||||
import jakarta.persistence.Table;
|
||||
import java.time.OffsetDateTime;
|
||||
import java.util.UUID;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Getter;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.Setter;
|
||||
|
||||
@Entity
|
||||
@Table(schema = SchemaNames.DOC, name = "doc_embedding_job", indexes = {
|
||||
@Index(name = "idx_doc_embedding_job_status_next_retry", columnList = "status,next_retry_at"),
|
||||
@Index(name = "idx_doc_embedding_job_representation", columnList = "representation_id"),
|
||||
@Index(name = "idx_doc_embedding_job_document", columnList = "document_id"),
|
||||
@Index(name = "idx_doc_embedding_job_model_key", columnList = "model_key")
|
||||
})
|
||||
@Getter
|
||||
@Setter
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
@Builder
|
||||
public class EmbeddingJob {
|
||||
|
||||
@Id
|
||||
@GeneratedValue(strategy = GenerationType.UUID)
|
||||
private UUID id;
|
||||
|
||||
@Column(name = "document_id", nullable = false)
|
||||
private UUID documentId;
|
||||
|
||||
@Column(name = "representation_id", nullable = false)
|
||||
private UUID representationId;
|
||||
|
||||
@Column(name = "model_key", nullable = false, length = 255)
|
||||
private String modelKey;
|
||||
|
||||
@Enumerated(EnumType.STRING)
|
||||
@Column(name = "job_type", nullable = false, length = 32)
|
||||
private EmbeddingJobType jobType;
|
||||
|
||||
@Enumerated(EnumType.STRING)
|
||||
@Column(name = "status", nullable = false, length = 32)
|
||||
@Builder.Default
|
||||
private EmbeddingJobStatus status = EmbeddingJobStatus.PENDING;
|
||||
|
||||
@Column(name = "attempt_count", nullable = false)
|
||||
@Builder.Default
|
||||
private int attemptCount = 0;
|
||||
|
||||
@Column(name = "next_retry_at")
|
||||
private OffsetDateTime nextRetryAt;
|
||||
|
||||
@Column(name = "priority", nullable = false)
|
||||
@Builder.Default
|
||||
private int priority = 0;
|
||||
|
||||
@Column(name = "provider_request_id", length = 255)
|
||||
private String providerRequestId;
|
||||
|
||||
@Column(name = "last_error", columnDefinition = "TEXT")
|
||||
private String lastError;
|
||||
|
||||
@Column(name = "created_at", nullable = false, updatable = false)
|
||||
@Builder.Default
|
||||
private OffsetDateTime createdAt = OffsetDateTime.now();
|
||||
|
||||
@Column(name = "updated_at", nullable = false)
|
||||
@Builder.Default
|
||||
private OffsetDateTime updatedAt = OffsetDateTime.now();
|
||||
|
||||
@PrePersist
|
||||
protected void onCreate() {
|
||||
createdAt = OffsetDateTime.now();
|
||||
updatedAt = OffsetDateTime.now();
|
||||
}
|
||||
|
||||
@PreUpdate
|
||||
protected void onUpdate() {
|
||||
updatedAt = OffsetDateTime.now();
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,33 @@
|
||||
package at.procon.dip.embedding.job.repository;
|
||||
|
||||
import at.procon.dip.embedding.job.entity.EmbeddingJob;
|
||||
import at.procon.dip.embedding.model.EmbeddingJobStatus;
|
||||
import at.procon.dip.embedding.model.EmbeddingJobType;
|
||||
import jakarta.persistence.LockModeType;
|
||||
import java.time.OffsetDateTime;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
import org.springframework.data.domain.Pageable;
|
||||
import org.springframework.data.jpa.repository.JpaRepository;
|
||||
import org.springframework.data.jpa.repository.Lock;
|
||||
import org.springframework.data.jpa.repository.Query;
|
||||
import org.springframework.data.repository.query.Param;
|
||||
|
||||
public interface EmbeddingJobRepository extends JpaRepository<EmbeddingJob, UUID> {
|
||||
|
||||
Optional<EmbeddingJob> findFirstByRepresentationIdAndModelKeyAndJobTypeAndStatusIn(
|
||||
UUID representationId,
|
||||
String modelKey,
|
||||
EmbeddingJobType jobType,
|
||||
Collection<EmbeddingJobStatus> statuses);
|
||||
|
||||
@Lock(LockModeType.PESSIMISTIC_WRITE)
|
||||
@Query("SELECT j FROM EmbeddingJob j WHERE j.status IN :statuses AND (j.nextRetryAt IS NULL OR j.nextRetryAt <= :now) ORDER BY j.priority DESC, j.createdAt ASC")
|
||||
List<EmbeddingJob> findReadyJobsForUpdate(@Param("statuses") Collection<EmbeddingJobStatus> statuses,
|
||||
@Param("now") OffsetDateTime now,
|
||||
Pageable pageable);
|
||||
|
||||
List<EmbeddingJob> findByDocumentId(UUID documentId);
|
||||
}
|
||||
@ -0,0 +1,116 @@
|
||||
package at.procon.dip.embedding.job.service;
|
||||
|
||||
import at.procon.dip.embedding.config.EmbeddingProperties;
|
||||
import at.procon.dip.embedding.job.entity.EmbeddingJob;
|
||||
import at.procon.dip.embedding.job.repository.EmbeddingJobRepository;
|
||||
import at.procon.dip.embedding.model.EmbeddingJobStatus;
|
||||
import at.procon.dip.embedding.model.EmbeddingJobType;
|
||||
import at.procon.dip.embedding.policy.EmbeddingSelectionPolicy;
|
||||
import at.procon.dip.embedding.registry.EmbeddingModelRegistry;
|
||||
import at.procon.dip.domain.document.entity.DocumentTextRepresentation;
|
||||
import java.time.Duration;
|
||||
import java.time.OffsetDateTime;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.UUID;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.data.domain.PageRequest;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
@Transactional
|
||||
public class EmbeddingJobService {
|
||||
|
||||
private static final Set<EmbeddingJobStatus> ACTIVE_STATUSES = Set.of(
|
||||
EmbeddingJobStatus.PENDING,
|
||||
EmbeddingJobStatus.IN_PROGRESS,
|
||||
EmbeddingJobStatus.RETRY_SCHEDULED
|
||||
);
|
||||
|
||||
private final EmbeddingJobRepository jobRepository;
|
||||
private final EmbeddingSelectionPolicy selectionPolicy;
|
||||
private final EmbeddingModelRegistry modelRegistry;
|
||||
private final EmbeddingProperties properties;
|
||||
|
||||
public List<EmbeddingJob> enqueueForDocument(UUID documentId) {
|
||||
return enqueueForDocument(documentId, modelRegistry.getRequiredDefaultDocumentModelKey());
|
||||
}
|
||||
|
||||
public List<EmbeddingJob> enqueueForDocument(UUID documentId, String modelKey) {
|
||||
var model = modelRegistry.getRequired(modelKey);
|
||||
List<DocumentTextRepresentation> selected = selectionPolicy.selectRepresentations(documentId, model);
|
||||
return selected.stream()
|
||||
.map(representation -> enqueueForRepresentation(documentId, representation.getId(), modelKey, EmbeddingJobType.DOCUMENT_EMBED))
|
||||
.toList();
|
||||
}
|
||||
|
||||
public EmbeddingJob enqueueForRepresentation(UUID documentId, UUID representationId, String modelKey, EmbeddingJobType jobType) {
|
||||
return jobRepository.findFirstByRepresentationIdAndModelKeyAndJobTypeAndStatusIn(
|
||||
representationId,
|
||||
modelKey,
|
||||
jobType,
|
||||
ACTIVE_STATUSES
|
||||
)
|
||||
.orElseGet(() -> jobRepository.save(EmbeddingJob.builder()
|
||||
.documentId(documentId)
|
||||
.representationId(representationId)
|
||||
.modelKey(modelKey)
|
||||
.jobType(jobType)
|
||||
.status(EmbeddingJobStatus.PENDING)
|
||||
.priority(0)
|
||||
.attemptCount(0)
|
||||
.build()));
|
||||
}
|
||||
|
||||
public List<EmbeddingJob> claimNextReadyJobs(int limit) {
|
||||
List<EmbeddingJob> jobs = jobRepository.findReadyJobsForUpdate(
|
||||
Set.of(EmbeddingJobStatus.PENDING, EmbeddingJobStatus.RETRY_SCHEDULED),
|
||||
OffsetDateTime.now(),
|
||||
PageRequest.of(0, limit)
|
||||
);
|
||||
jobs.forEach(job -> {
|
||||
job.setStatus(EmbeddingJobStatus.IN_PROGRESS);
|
||||
job.setAttemptCount(job.getAttemptCount() + 1);
|
||||
job.setLastError(null);
|
||||
job.setNextRetryAt(null);
|
||||
});
|
||||
return jobRepository.saveAll(jobs);
|
||||
}
|
||||
|
||||
public EmbeddingJob getRequired(UUID jobId) {
|
||||
return jobRepository.findById(jobId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Unknown embedding job id: " + jobId));
|
||||
}
|
||||
|
||||
public void markCompleted(UUID jobId, String providerRequestId) {
|
||||
EmbeddingJob job = getRequired(jobId);
|
||||
job.setStatus(EmbeddingJobStatus.DONE);
|
||||
job.setProviderRequestId(providerRequestId);
|
||||
job.setLastError(null);
|
||||
job.setNextRetryAt(null);
|
||||
jobRepository.save(job);
|
||||
}
|
||||
|
||||
public void markFailed(UUID jobId, String errorMessage, boolean retryable) {
|
||||
EmbeddingJob job = getRequired(jobId);
|
||||
job.setLastError(errorMessage);
|
||||
if (retryable && job.getAttemptCount() < properties.getJobs().getMaxRetries()) {
|
||||
job.setStatus(EmbeddingJobStatus.RETRY_SCHEDULED);
|
||||
job.setNextRetryAt(OffsetDateTime.now().plus(nextDelay(job.getAttemptCount())));
|
||||
} else {
|
||||
job.setStatus(EmbeddingJobStatus.FAILED);
|
||||
job.setNextRetryAt(null);
|
||||
}
|
||||
jobRepository.save(job);
|
||||
}
|
||||
|
||||
private Duration nextDelay(int attemptCount) {
|
||||
long factor = Math.max(0, attemptCount - 1);
|
||||
Duration candidate = properties.getJobs().getInitialRetryDelay().multipliedBy(1L << Math.min(factor, 10));
|
||||
return candidate.compareTo(properties.getJobs().getMaxRetryDelay()) > 0
|
||||
? properties.getJobs().getMaxRetryDelay()
|
||||
: candidate;
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,9 @@
|
||||
package at.procon.dip.embedding.model;
|
||||
|
||||
/**
 * Lifecycle states of an embedding job.
 */
public enum EmbeddingJobStatus {

    /** Queued and waiting to be claimed. */
    PENDING,

    /** Claimed by a worker and currently being processed. */
    IN_PROGRESS,

    /** Finished successfully. */
    DONE,

    /** Finished unsuccessfully with no further retry planned. */
    FAILED,

    /** Failed, but another attempt has been scheduled. */
    RETRY_SCHEDULED
}
|
||||
@ -0,0 +1,7 @@
|
||||
package at.procon.dip.embedding.model;
|
||||
|
||||
/**
 * Kinds of embedding work a job can represent.
 */
public enum EmbeddingJobType {

    /** Embed a text representation of a document. */
    DOCUMENT_EMBED,

    /** Presumably embeds a search query; confirm against the code that creates such jobs. */
    QUERY_EMBED,

    /** Presumably recomputes an existing embedding; confirm against the code that creates such jobs. */
    REEMBED
}
|
||||
@ -0,0 +1,68 @@
|
||||
package at.procon.dip.embedding.policy;
|
||||
|
||||
import at.procon.dip.domain.document.RepresentationType;
|
||||
import at.procon.dip.domain.document.entity.DocumentTextRepresentation;
|
||||
import at.procon.dip.domain.document.repository.DocumentTextRepresentationRepository;
|
||||
import at.procon.dip.embedding.config.EmbeddingProperties;
|
||||
import at.procon.dip.embedding.model.EmbeddingModelDescriptor;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.UUID;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Component
|
||||
@RequiredArgsConstructor
|
||||
public class DefaultEmbeddingSelectionPolicy implements EmbeddingSelectionPolicy {
|
||||
|
||||
private final DocumentTextRepresentationRepository representationRepository;
|
||||
private final EmbeddingProperties embeddingProperties;
|
||||
|
||||
@Override
|
||||
public List<DocumentTextRepresentation> selectRepresentations(UUID documentId, EmbeddingModelDescriptor model) {
|
||||
List<DocumentTextRepresentation> representations = representationRepository.findByDocument_Id(documentId);
|
||||
List<DocumentTextRepresentation> selected = new ArrayList<>();
|
||||
EmbeddingProperties.IndexingProperties indexing = embeddingProperties.getIndexing();
|
||||
|
||||
for (DocumentTextRepresentation representation : representations) {
|
||||
if (include(representation, indexing)) {
|
||||
selected.add(representation);
|
||||
}
|
||||
}
|
||||
|
||||
if (selected.isEmpty()) {
|
||||
representationRepository.findFirstByDocument_IdAndPrimaryRepresentationTrue(documentId)
|
||||
.ifPresent(selected::add);
|
||||
}
|
||||
|
||||
Set<UUID> seen = new LinkedHashSet<>();
|
||||
return selected.stream()
|
||||
.filter(rep -> seen.add(rep.getId()))
|
||||
.sorted(Comparator
|
||||
.comparing(DocumentTextRepresentation::isPrimaryRepresentation).reversed()
|
||||
.thenComparing(rep -> rep.getRepresentationType().ordinal())
|
||||
.thenComparing(rep -> rep.getChunkIndex() == null ? Integer.MAX_VALUE : rep.getChunkIndex()))
|
||||
.toList();
|
||||
}
|
||||
|
||||
private boolean include(DocumentTextRepresentation representation, EmbeddingProperties.IndexingProperties indexing) {
|
||||
return switch (representation.getRepresentationType()) {
|
||||
case SEMANTIC_TEXT -> indexing.isEmbedSemanticText();
|
||||
case TITLE_ABSTRACT -> indexing.isEmbedTitleAbstract();
|
||||
case CHUNK -> indexing.isEmbedChunks() && effectiveLength(representation) >= indexing.getChunkMinLength();
|
||||
case FULLTEXT -> indexing.isEmbedFulltext();
|
||||
case SUMMARY -> indexing.isEmbedSummary();
|
||||
default -> representation.isPrimaryRepresentation();
|
||||
};
|
||||
}
|
||||
|
||||
private int effectiveLength(DocumentTextRepresentation representation) {
|
||||
if (representation.getCharCount() != null) {
|
||||
return representation.getCharCount();
|
||||
}
|
||||
return representation.getTextBody() == null ? 0 : representation.getTextBody().length();
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,11 @@
|
||||
package at.procon.dip.embedding.policy;
|
||||
|
||||
import at.procon.dip.domain.document.entity.DocumentTextRepresentation;
|
||||
import at.procon.dip.embedding.model.EmbeddingModelDescriptor;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
|
||||
public interface EmbeddingSelectionPolicy {
|
||||
|
||||
List<DocumentTextRepresentation> selectRepresentations(UUID documentId, EmbeddingModelDescriptor model);
|
||||
}
|
||||
@ -0,0 +1,33 @@
|
||||
package at.procon.dip.embedding.service;
|
||||
|
||||
import at.procon.dip.domain.document.entity.DocumentEmbeddingModel;
|
||||
import at.procon.dip.domain.document.service.DocumentEmbeddingService;
|
||||
import at.procon.dip.domain.document.service.command.RegisterEmbeddingModelCommand;
|
||||
import at.procon.dip.embedding.model.EmbeddingModelDescriptor;
|
||||
import at.procon.dip.embedding.registry.EmbeddingModelRegistry;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
@Transactional
|
||||
public class EmbeddingModelCatalogService {
|
||||
|
||||
private final EmbeddingModelRegistry modelRegistry;
|
||||
private final DocumentEmbeddingService documentEmbeddingService;
|
||||
|
||||
public DocumentEmbeddingModel ensureRegistered(String modelKey) {
|
||||
EmbeddingModelDescriptor descriptor = modelRegistry.getRequired(modelKey);
|
||||
documentEmbeddingService.registerModel(new RegisterEmbeddingModelCommand(
|
||||
descriptor.modelKey(),
|
||||
descriptor.providerConfigKey(),
|
||||
descriptor.providerModelKey(),
|
||||
descriptor.dimensions(),
|
||||
descriptor.distanceMetric(),
|
||||
descriptor.supportsQueryEmbeddingMode(),
|
||||
descriptor.active()
|
||||
));
|
||||
return documentEmbeddingService.findActiveModelByKey(modelKey);
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,73 @@
|
||||
package at.procon.dip.embedding.service;
|
||||
|
||||
import at.procon.dip.domain.document.EmbeddingStatus;
|
||||
import at.procon.dip.domain.document.entity.DocumentEmbedding;
|
||||
import at.procon.dip.domain.document.entity.DocumentEmbeddingModel;
|
||||
import at.procon.dip.domain.document.entity.DocumentTextRepresentation;
|
||||
import at.procon.dip.domain.document.repository.DocumentEmbeddingRepository;
|
||||
import at.procon.dip.domain.document.repository.DocumentTextRepresentationRepository;
|
||||
import at.procon.dip.domain.document.service.DocumentEmbeddingService;
|
||||
import at.procon.dip.embedding.model.EmbeddingProviderResult;
|
||||
import at.procon.dip.embedding.support.EmbeddingVectorCodec;
|
||||
import java.time.OffsetDateTime;
|
||||
import java.util.UUID;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
@Transactional
|
||||
public class EmbeddingPersistenceService {
|
||||
|
||||
private final DocumentTextRepresentationRepository representationRepository;
|
||||
private final DocumentEmbeddingService documentEmbeddingService;
|
||||
private final DocumentEmbeddingRepository embeddingRepository;
|
||||
private final EmbeddingModelCatalogService modelCatalogService;
|
||||
|
||||
public DocumentEmbedding ensurePending(UUID representationId, String modelKey) {
|
||||
DocumentTextRepresentation representation = representationRepository.findById(representationId)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Unknown representation id: " + representationId));
|
||||
DocumentEmbeddingModel model = modelCatalogService.ensureRegistered(modelKey);
|
||||
return documentEmbeddingService.ensurePendingEmbedding(
|
||||
representation.getDocument().getId(),
|
||||
representation.getId(),
|
||||
model.getId()
|
||||
);
|
||||
}
|
||||
|
||||
public DocumentEmbedding markProcessing(UUID embeddingId) {
|
||||
return documentEmbeddingService.markProcessing(embeddingId);
|
||||
}
|
||||
|
||||
public void saveCompleted(UUID embeddingId, EmbeddingProviderResult result) {
|
||||
if (result.vectors() == null || result.vectors().isEmpty()) {
|
||||
throw new IllegalArgumentException("Embedding provider result contains no vectors");
|
||||
}
|
||||
float[] vector = result.vectors().getFirst();
|
||||
embeddingRepository.updateEmbeddingVector(
|
||||
embeddingId,
|
||||
EmbeddingVectorCodec.toPgVector(vector),
|
||||
result.tokenCount(),
|
||||
vector.length
|
||||
);
|
||||
}
|
||||
|
||||
public void markFailed(UUID embeddingId, String errorMessage) {
|
||||
embeddingRepository.updateEmbeddingStatus(
|
||||
embeddingId,
|
||||
EmbeddingStatus.FAILED,
|
||||
errorMessage,
|
||||
null
|
||||
);
|
||||
}
|
||||
|
||||
public void markSkipped(UUID embeddingId, String reason) {
|
||||
embeddingRepository.updateEmbeddingStatus(
|
||||
embeddingId,
|
||||
EmbeddingStatus.SKIPPED,
|
||||
reason,
|
||||
OffsetDateTime.now()
|
||||
);
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,96 @@
|
||||
package at.procon.dip.embedding.service;
|
||||
|
||||
import at.procon.dip.domain.document.entity.DocumentEmbedding;
|
||||
import at.procon.dip.domain.document.entity.DocumentTextRepresentation;
|
||||
import at.procon.dip.domain.document.repository.DocumentTextRepresentationRepository;
|
||||
import at.procon.dip.embedding.config.EmbeddingProperties;
|
||||
import at.procon.dip.embedding.job.entity.EmbeddingJob;
|
||||
import at.procon.dip.embedding.job.service.EmbeddingJobService;
|
||||
import at.procon.dip.embedding.model.EmbeddingJobType;
|
||||
import at.procon.dip.embedding.model.EmbeddingProviderResult;
|
||||
import at.procon.dip.embedding.model.EmbeddingUseCase;
|
||||
import at.procon.dip.embedding.registry.EmbeddingModelRegistry;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
@Slf4j
|
||||
public class RepresentationEmbeddingOrchestrator {
|
||||
|
||||
private final EmbeddingJobService jobService;
|
||||
private final EmbeddingExecutionService executionService;
|
||||
private final EmbeddingPersistenceService persistenceService;
|
||||
private final DocumentTextRepresentationRepository representationRepository;
|
||||
private final EmbeddingModelRegistry modelRegistry;
|
||||
private final EmbeddingProperties embeddingProperties;
|
||||
|
||||
@Transactional
|
||||
public List<EmbeddingJob> enqueueDocument(UUID documentId) {
|
||||
return jobService.enqueueForDocument(documentId);
|
||||
}
|
||||
|
||||
@Transactional
|
||||
public List<EmbeddingJob> enqueueDocument(UUID documentId, String modelKey) {
|
||||
return jobService.enqueueForDocument(documentId, modelKey);
|
||||
}
|
||||
|
||||
@Transactional
|
||||
public EmbeddingJob enqueueRepresentation(UUID documentId, UUID representationId, String modelKey) {
|
||||
return jobService.enqueueForRepresentation(documentId, representationId, modelKey, EmbeddingJobType.DOCUMENT_EMBED);
|
||||
}
|
||||
|
||||
@Transactional
|
||||
public int processNextReadyBatch() {
|
||||
if (!embeddingProperties.isEnabled() || !embeddingProperties.getJobs().isEnabled()) {
|
||||
log.debug("New embedding subsystem jobs are disabled");
|
||||
return 0;
|
||||
}
|
||||
|
||||
List<EmbeddingJob> jobs = jobService.claimNextReadyJobs(embeddingProperties.getJobs().getBatchSize());
|
||||
for (EmbeddingJob job : jobs) {
|
||||
processClaimedJob(job);
|
||||
}
|
||||
return jobs.size();
|
||||
}
|
||||
|
||||
@Transactional
|
||||
public void processClaimedJob(EmbeddingJob job) {
|
||||
DocumentTextRepresentation representation = representationRepository.findById(job.getRepresentationId())
|
||||
.orElseThrow(() -> new IllegalArgumentException("Unknown representation id: " + job.getRepresentationId()));
|
||||
|
||||
String text = representation.getTextBody();
|
||||
if (text == null || text.isBlank()) {
|
||||
jobService.markFailed(job.getId(), "No text representation available", false);
|
||||
return;
|
||||
}
|
||||
|
||||
int maxChars = modelRegistry.getRequired(job.getModelKey()).maxInputChars() != null
|
||||
? modelRegistry.getRequired(job.getModelKey()).maxInputChars()
|
||||
: embeddingProperties.getIndexing().getFallbackMaxInputChars();
|
||||
if (text.length() > maxChars) {
|
||||
text = text.substring(0, maxChars);
|
||||
}
|
||||
|
||||
DocumentEmbedding embedding = persistenceService.ensurePending(representation.getId(), job.getModelKey());
|
||||
persistenceService.markProcessing(embedding.getId());
|
||||
|
||||
try {
|
||||
EmbeddingProviderResult result = executionService.embedTexts(
|
||||
job.getModelKey(),
|
||||
EmbeddingUseCase.DOCUMENT,
|
||||
List.of(text)
|
||||
);
|
||||
persistenceService.saveCompleted(embedding.getId(), result);
|
||||
jobService.markCompleted(job.getId(), result.providerRequestId());
|
||||
} catch (RuntimeException ex) {
|
||||
persistenceService.markFailed(embedding.getId(), ex.getMessage());
|
||||
jobService.markFailed(job.getId(), ex.getMessage(), true);
|
||||
throw ex;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,23 @@
|
||||
package at.procon.dip.embedding.support;
|
||||
|
||||
/**
 * Converts embedding vectors into the textual pgvector literal form, e.g. {@code [1.0,2.5,-0.75]}.
 */
public final class EmbeddingVectorCodec {

    private EmbeddingVectorCodec() {
    }

    /**
     * Renders a vector as a bracketed, comma-separated list of its components.
     *
     * <p>Components are formatted with {@link Float#toString(float)}. Non-finite components
     * (NaN or ±Infinity) are rejected here so the problem is reported with the offending
     * index instead of surfacing later as a database error.
     *
     * @param vector the embedding vector
     * @return the literal, without any surrounding quotes
     * @throws IllegalArgumentException when {@code vector} is {@code null}, empty, or contains
     *                                  a non-finite component
     */
    public static String toPgVector(float[] vector) {
        if (vector == null || vector.length == 0) {
            throw new IllegalArgumentException("Embedding vector must not be null or empty");
        }
        StringBuilder builder = new StringBuilder();
        builder.append('[');
        for (int i = 0; i < vector.length; i++) {
            float component = vector[i];
            if (!Float.isFinite(component)) {
                throw new IllegalArgumentException(
                        "Embedding vector component at index " + i + " is not finite: " + component);
            }
            if (i > 0) {
                builder.append(',');
            }
            builder.append(Float.toString(component));
        }
        builder.append(']');
        return builder.toString();
    }
}
|
||||
@ -0,0 +1,30 @@
|
||||
-- NV2: Parallel embedding subsystem job table
|
||||
-- Additive and independent from the legacy/transitional vectorization flow.
|
||||
|
||||
-- Job queue for the embedding subsystem: one row per requested embedding of a
-- text representation with a given model.
CREATE TABLE IF NOT EXISTS DOC.doc_embedding_job (
    id                  UUID                     PRIMARY KEY DEFAULT gen_random_uuid(),
    -- Jobs are removed together with their document or representation.
    document_id         UUID                     NOT NULL REFERENCES DOC.doc_document(id) ON DELETE CASCADE,
    representation_id   UUID                     NOT NULL REFERENCES DOC.doc_text_representation(id) ON DELETE CASCADE,
    model_key           VARCHAR(255)             NOT NULL,
    job_type            VARCHAR(32)              NOT NULL,
    status              VARCHAR(32)              NOT NULL DEFAULT 'PENDING',
    attempt_count       INTEGER                  NOT NULL DEFAULT 0,
    -- NULL means the job is not waiting for a retry time.
    next_retry_at       TIMESTAMP WITH TIME ZONE,
    priority            INTEGER                  NOT NULL DEFAULT 0,
    provider_request_id VARCHAR(255),
    last_error          TEXT,
    created_at          TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
    updated_at          TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP
);

-- Supports selecting ready jobs by status and retry time.
CREATE INDEX IF NOT EXISTS idx_doc_embedding_job_status_next_retry
    ON DOC.doc_embedding_job(status, next_retry_at);

-- Supports lookups by representation.
CREATE INDEX IF NOT EXISTS idx_doc_embedding_job_representation
    ON DOC.doc_embedding_job(representation_id);

-- Supports lookups by document.
CREATE INDEX IF NOT EXISTS idx_doc_embedding_job_document
    ON DOC.doc_embedding_job(document_id);

-- Supports lookups by model key.
CREATE INDEX IF NOT EXISTS idx_doc_embedding_job_model_key
    ON DOC.doc_embedding_job(model_key);
|
||||
@ -0,0 +1,78 @@
|
||||
package at.procon.dip.embedding.policy;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import at.procon.dip.domain.document.RepresentationType;
|
||||
import at.procon.dip.domain.document.entity.DocumentTextRepresentation;
|
||||
import at.procon.dip.domain.document.repository.DocumentTextRepresentationRepository;
|
||||
import at.procon.dip.embedding.config.EmbeddingProperties;
|
||||
import at.procon.dip.embedding.model.EmbeddingModelDescriptor;
|
||||
import at.procon.dip.domain.document.DistanceMetric;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
|
||||
@ExtendWith(MockitoExtension.class)
|
||||
class DefaultEmbeddingSelectionPolicyTest {
|
||||
|
||||
@Mock
|
||||
private DocumentTextRepresentationRepository representationRepository;
|
||||
|
||||
private DefaultEmbeddingSelectionPolicy policy;
|
||||
private EmbeddingModelDescriptor model;
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
EmbeddingProperties properties = new EmbeddingProperties();
|
||||
properties.getIndexing().setEmbedSemanticText(true);
|
||||
properties.getIndexing().setEmbedTitleAbstract(true);
|
||||
properties.getIndexing().setEmbedChunks(true);
|
||||
properties.getIndexing().setChunkMinLength(20);
|
||||
policy = new DefaultEmbeddingSelectionPolicy(representationRepository, properties);
|
||||
model = new EmbeddingModelDescriptor("mock-search", "mock-provider", "mock-search", 16,
|
||||
DistanceMetric.COSINE, true, false, 4096, true);
|
||||
}
|
||||
|
||||
@Test
|
||||
void should_select_semantic_text_title_and_large_chunks() {
|
||||
UUID documentId = UUID.randomUUID();
|
||||
DocumentTextRepresentation semantic = representation(RepresentationType.SEMANTIC_TEXT, true, null,
|
||||
"semantic text");
|
||||
DocumentTextRepresentation titleAbstract = representation(RepresentationType.TITLE_ABSTRACT, false, null,
|
||||
"title abstract text");
|
||||
DocumentTextRepresentation shortChunk = representation(RepresentationType.CHUNK, false, 0,
|
||||
"too short");
|
||||
DocumentTextRepresentation longChunk = representation(RepresentationType.CHUNK, false, 1,
|
||||
"This is a sufficiently long chunk that should be selected.");
|
||||
when(representationRepository.findByDocument_Id(documentId))
|
||||
.thenReturn(List.of(longChunk, semantic, shortChunk, titleAbstract));
|
||||
|
||||
List<DocumentTextRepresentation> selected = policy.selectRepresentations(documentId, model);
|
||||
|
||||
assertThat(selected)
|
||||
.extracting(DocumentTextRepresentation::getRepresentationType)
|
||||
.containsExactly(RepresentationType.SEMANTIC_TEXT, RepresentationType.TITLE_ABSTRACT, RepresentationType.CHUNK);
|
||||
assertThat(selected)
|
||||
.extracting(DocumentTextRepresentation::getChunkIndex)
|
||||
.containsExactly(null, null, 1);
|
||||
}
|
||||
|
||||
private DocumentTextRepresentation representation(RepresentationType type,
|
||||
boolean primary,
|
||||
Integer chunkIndex,
|
||||
String text) {
|
||||
return DocumentTextRepresentation.builder()
|
||||
.id(UUID.randomUUID())
|
||||
.representationType(type)
|
||||
.primaryRepresentation(primary)
|
||||
.chunkIndex(chunkIndex)
|
||||
.charCount(text.length())
|
||||
.textBody(text)
|
||||
.build();
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,105 @@
|
||||
package at.procon.dip.embedding.service;
|
||||
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import at.procon.dip.domain.document.entity.Document;
|
||||
import at.procon.dip.domain.document.entity.DocumentEmbedding;
|
||||
import at.procon.dip.domain.document.entity.DocumentTextRepresentation;
|
||||
import at.procon.dip.domain.document.repository.DocumentTextRepresentationRepository;
|
||||
import at.procon.dip.embedding.config.EmbeddingProperties;
|
||||
import at.procon.dip.embedding.job.entity.EmbeddingJob;
|
||||
import at.procon.dip.embedding.job.service.EmbeddingJobService;
|
||||
import at.procon.dip.embedding.model.EmbeddingJobStatus;
|
||||
import at.procon.dip.embedding.model.EmbeddingJobType;
|
||||
import at.procon.dip.embedding.model.EmbeddingModelDescriptor;
|
||||
import at.procon.dip.embedding.model.EmbeddingProviderResult;
|
||||
import at.procon.dip.embedding.registry.EmbeddingModelRegistry;
|
||||
import at.procon.dip.domain.document.DistanceMetric;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
|
||||
@ExtendWith(MockitoExtension.class)
|
||||
class RepresentationEmbeddingOrchestratorTest {
|
||||
|
||||
@Mock
|
||||
private EmbeddingJobService jobService;
|
||||
@Mock
|
||||
private EmbeddingExecutionService executionService;
|
||||
@Mock
|
||||
private EmbeddingPersistenceService persistenceService;
|
||||
@Mock
|
||||
private DocumentTextRepresentationRepository representationRepository;
|
||||
@Mock
|
||||
private EmbeddingModelRegistry modelRegistry;
|
||||
|
||||
private RepresentationEmbeddingOrchestrator orchestrator;
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
EmbeddingProperties properties = new EmbeddingProperties();
|
||||
properties.setEnabled(true);
|
||||
properties.getJobs().setEnabled(true);
|
||||
properties.getJobs().setBatchSize(10);
|
||||
properties.getIndexing().setFallbackMaxInputChars(8192);
|
||||
orchestrator = new RepresentationEmbeddingOrchestrator(
|
||||
jobService,
|
||||
executionService,
|
||||
persistenceService,
|
||||
representationRepository,
|
||||
modelRegistry,
|
||||
properties
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
void processClaimedJob_should_embed_representation_and_mark_job_completed() {
|
||||
UUID documentId = UUID.randomUUID();
|
||||
UUID representationId = UUID.randomUUID();
|
||||
UUID embeddingId = UUID.randomUUID();
|
||||
EmbeddingJob job = EmbeddingJob.builder()
|
||||
.id(UUID.randomUUID())
|
||||
.documentId(documentId)
|
||||
.representationId(representationId)
|
||||
.modelKey("mock-search")
|
||||
.jobType(EmbeddingJobType.DOCUMENT_EMBED)
|
||||
.status(EmbeddingJobStatus.IN_PROGRESS)
|
||||
.attemptCount(1)
|
||||
.build();
|
||||
|
||||
Document document = Document.builder().id(documentId).build();
|
||||
DocumentTextRepresentation representation = DocumentTextRepresentation.builder()
|
||||
.id(representationId)
|
||||
.document(document)
|
||||
.textBody("District heating optimization strategy")
|
||||
.build();
|
||||
when(representationRepository.findById(representationId)).thenReturn(java.util.Optional.of(representation));
|
||||
when(modelRegistry.getRequired("mock-search"))
|
||||
.thenReturn(new EmbeddingModelDescriptor("mock-search", "mock-provider", "mock-search", 16,
|
||||
DistanceMetric.COSINE, true, false, 4096, true));
|
||||
when(persistenceService.ensurePending(representationId, "mock-search"))
|
||||
.thenReturn(DocumentEmbedding.builder().id(embeddingId).build());
|
||||
when(executionService.embedTexts(eq("mock-search"), any(), any()))
|
||||
.thenReturn(new EmbeddingProviderResult(
|
||||
new EmbeddingModelDescriptor("mock-search", "mock-provider", "mock-search", 16,
|
||||
DistanceMetric.COSINE, true, false, 4096, true),
|
||||
List.of(new float[]{0.1f, 0.2f}),
|
||||
List.of(),
|
||||
"req-1",
|
||||
12
|
||||
));
|
||||
|
||||
orchestrator.processClaimedJob(job);
|
||||
|
||||
verify(persistenceService).markProcessing(embeddingId);
|
||||
verify(persistenceService).saveCompleted(eq(embeddingId), any(EmbeddingProviderResult.class));
|
||||
verify(jobService).markCompleted(job.getId(), "req-1");
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,14 @@
|
||||
package at.procon.dip.embedding.support;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class EmbeddingVectorCodecTest {
|
||||
|
||||
@Test
|
||||
void should_render_pgvector_literal() {
|
||||
String rendered = EmbeddingVectorCodec.toPgVector(new float[]{1.0f, 2.5f, -0.75f});
|
||||
assertThat(rendered).isEqualTo("[1.0,2.5,-0.75]");
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue