diff --git a/src/main/java/at/procon/dip/README_PHASE0.md b/docs/README_PHASE0.md similarity index 100% rename from src/main/java/at/procon/dip/README_PHASE0.md rename to docs/README_PHASE0.md diff --git a/src/main/java/at/procon/dip/README_PHASE1.md b/docs/README_PHASE1.md similarity index 100% rename from src/main/java/at/procon/dip/README_PHASE1.md rename to docs/README_PHASE1.md diff --git a/src/main/java/at/procon/dip/README_PHASE2.md b/docs/README_PHASE2.md similarity index 100% rename from src/main/java/at/procon/dip/README_PHASE2.md rename to docs/README_PHASE2.md diff --git a/src/main/java/at/procon/dip/README_PHASE3.md b/docs/README_PHASE3.md similarity index 100% rename from src/main/java/at/procon/dip/README_PHASE3.md rename to docs/README_PHASE3.md diff --git a/src/main/java/at/procon/dip/README_PHASE4.md b/docs/README_PHASE4.md similarity index 100% rename from src/main/java/at/procon/dip/README_PHASE4.md rename to docs/README_PHASE4.md diff --git a/src/main/java/at/procon/dip/README_PHASE4_1.md b/docs/README_PHASE4_1.md similarity index 100% rename from src/main/java/at/procon/dip/README_PHASE4_1.md rename to docs/README_PHASE4_1.md diff --git a/src/main/java/at/procon/dip/DocumentIntelligencePlatformApplication.java b/src/main/java/at/procon/dip/DocumentIntelligencePlatformApplication.java index a172f94..b7f2d7e 100644 --- a/src/main/java/at/procon/dip/DocumentIntelligencePlatformApplication.java +++ b/src/main/java/at/procon/dip/DocumentIntelligencePlatformApplication.java @@ -18,8 +18,8 @@ import org.springframework.scheduling.annotation.EnableAsync; @SpringBootApplication(scanBasePackages = {"at.procon.dip", "at.procon.ted"}) @EnableAsync //@EnableConfigurationProperties(TedProcessorProperties.class) -@EntityScan(basePackages = {"at.procon.ted.model.entity", "at.procon.dip.domain.document.entity", "at.procon.dip.domain.tenant.entity", "at.procon.dip.domain.ted.entity"}) -@EnableJpaRepositories(basePackages = {"at.procon.ted.repository", "at.procon.dip.domain.document.repository", "at.procon.dip.domain.tenant.repository", "at.procon.dip.domain.ted.repository"}) +@EntityScan(basePackages = {"at.procon.ted.model.entity", "at.procon.dip.domain.document.entity", "at.procon.dip.domain.tenant.entity", "at.procon.dip.domain.ted.entity", "at.procon.dip.embedding.job.entity"}) +@EnableJpaRepositories(basePackages = {"at.procon.ted.repository", "at.procon.dip.domain.document.repository", "at.procon.dip.domain.tenant.repository", "at.procon.dip.domain.ted.repository", "at.procon.dip.embedding.job.repository"}) public class DocumentIntelligencePlatformApplication { public static void main(String[] args) { diff --git a/src/main/java/at/procon/dip/embedding/config/EmbeddingProperties.java b/src/main/java/at/procon/dip/embedding/config/EmbeddingProperties.java index 3a34c91..56d2ad6 100644 --- a/src/main/java/at/procon/dip/embedding/config/EmbeddingProperties.java +++ b/src/main/java/at/procon/dip/embedding/config/EmbeddingProperties.java @@ -18,6 +18,8 @@ public class EmbeddingProperties { private String defaultQueryModel; private Map providers = new LinkedHashMap<>(); private Map models = new LinkedHashMap<>(); + private IndexingProperties indexing = new IndexingProperties(); + private JobsProperties jobs = new JobsProperties(); @Data public static class ProviderProperties { @@ -41,4 +43,24 @@ public class EmbeddingProperties { private Integer maxInputChars; private boolean active = true; } + + @Data + public static class IndexingProperties { + private boolean embedSemanticText = true; + private boolean embedTitleAbstract = true; + private boolean embedChunks = true; + private boolean embedFulltext = false; + private boolean embedSummary = false; + private int chunkMinLength = 300; + private int fallbackMaxInputChars = 8192; + } + + @Data + public static class JobsProperties { + private boolean enabled = false; + private int batchSize = 16; + private int maxRetries = 5; + private Duration initialRetryDelay = Duration.ofSeconds(30); + private Duration maxRetryDelay = Duration.ofHours(6); + } } diff --git a/src/main/java/at/procon/dip/embedding/job/entity/EmbeddingJob.java b/src/main/java/at/procon/dip/embedding/job/entity/EmbeddingJob.java new file mode 100644 index 0000000..9d3430f --- /dev/null +++ b/src/main/java/at/procon/dip/embedding/job/entity/EmbeddingJob.java @@ -0,0 +1,96 @@ +package at.procon.dip.embedding.job.entity; + +import at.procon.dip.architecture.SchemaNames; +import at.procon.dip.embedding.model.EmbeddingJobStatus; +import at.procon.dip.embedding.model.EmbeddingJobType; +import jakarta.persistence.Column; +import jakarta.persistence.Entity; +import jakarta.persistence.EnumType; +import jakarta.persistence.Enumerated; +import jakarta.persistence.GeneratedValue; +import jakarta.persistence.GenerationType; +import jakarta.persistence.Id; +import jakarta.persistence.Index; +import jakarta.persistence.PrePersist; +import jakarta.persistence.PreUpdate; +import jakarta.persistence.Table; +import java.time.OffsetDateTime; +import java.util.UUID; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; + +@Entity +@Table(schema = SchemaNames.DOC, name = "doc_embedding_job", indexes = { + @Index(name = "idx_doc_embedding_job_status_next_retry", columnList = "status,next_retry_at"), + @Index(name = "idx_doc_embedding_job_representation", columnList = "representation_id"), + @Index(name = "idx_doc_embedding_job_document", columnList = "document_id"), + @Index(name = "idx_doc_embedding_job_model_key", columnList = "model_key") +}) +@Getter +@Setter +@NoArgsConstructor +@AllArgsConstructor +@Builder +public class EmbeddingJob { + + @Id + @GeneratedValue(strategy = GenerationType.UUID) + private UUID id; + + @Column(name = "document_id", nullable = false) + private UUID documentId; + + @Column(name = "representation_id", nullable = false) + private UUID representationId; + + @Column(name = "model_key", nullable = false, length = 255) + private String modelKey; + + @Enumerated(EnumType.STRING) + @Column(name = "job_type", nullable = false, length = 32) + private EmbeddingJobType jobType; + + @Enumerated(EnumType.STRING) + @Column(name = "status", nullable = false, length = 32) + @Builder.Default + private EmbeddingJobStatus status = EmbeddingJobStatus.PENDING; + + @Column(name = "attempt_count", nullable = false) + @Builder.Default + private int attemptCount = 0; + + @Column(name = "next_retry_at") + private OffsetDateTime nextRetryAt; + + @Column(name = "priority", nullable = false) + @Builder.Default + private int priority = 0; + + @Column(name = "provider_request_id", length = 255) + private String providerRequestId; + + @Column(name = "last_error", columnDefinition = "TEXT") + private String lastError; + + @Column(name = "created_at", nullable = false, updatable = false) + @Builder.Default + private OffsetDateTime createdAt = OffsetDateTime.now(); + + @Column(name = "updated_at", nullable = false) + @Builder.Default + private OffsetDateTime updatedAt = OffsetDateTime.now(); + + @PrePersist + protected void onCreate() { + createdAt = OffsetDateTime.now(); + updatedAt = OffsetDateTime.now(); + } + + @PreUpdate + protected void onUpdate() { + updatedAt = OffsetDateTime.now(); + } +} diff --git a/src/main/java/at/procon/dip/embedding/job/repository/EmbeddingJobRepository.java b/src/main/java/at/procon/dip/embedding/job/repository/EmbeddingJobRepository.java new file mode 100644 index 0000000..646c6c3 --- /dev/null +++ b/src/main/java/at/procon/dip/embedding/job/repository/EmbeddingJobRepository.java @@ -0,0 +1,33 @@ +package at.procon.dip.embedding.job.repository; + +import at.procon.dip.embedding.job.entity.EmbeddingJob; +import at.procon.dip.embedding.model.EmbeddingJobStatus; +import at.procon.dip.embedding.model.EmbeddingJobType; +import jakarta.persistence.LockModeType; +import java.time.OffsetDateTime; +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import java.util.UUID; +import org.springframework.data.domain.Pageable; +import org.springframework.data.jpa.repository.JpaRepository; +import org.springframework.data.jpa.repository.Lock; +import org.springframework.data.jpa.repository.Query; +import org.springframework.data.repository.query.Param; + +public interface EmbeddingJobRepository extends JpaRepository { + + Optional findFirstByRepresentationIdAndModelKeyAndJobTypeAndStatusIn( + UUID representationId, + String modelKey, + EmbeddingJobType jobType, + Collection statuses); + + @Lock(LockModeType.PESSIMISTIC_WRITE) + @Query("SELECT j FROM EmbeddingJob j WHERE j.status IN :statuses AND (j.nextRetryAt IS NULL OR j.nextRetryAt <= :now) ORDER BY j.priority DESC, j.createdAt ASC") + List findReadyJobsForUpdate(@Param("statuses") Collection statuses, + @Param("now") OffsetDateTime now, + Pageable pageable); + + List findByDocumentId(UUID documentId); +} diff --git a/src/main/java/at/procon/dip/embedding/job/service/EmbeddingJobService.java b/src/main/java/at/procon/dip/embedding/job/service/EmbeddingJobService.java new file mode 100644 index 0000000..9cae485 --- /dev/null +++ b/src/main/java/at/procon/dip/embedding/job/service/EmbeddingJobService.java @@ -0,0 +1,116 @@ +package at.procon.dip.embedding.job.service; + +import at.procon.dip.embedding.config.EmbeddingProperties; +import at.procon.dip.embedding.job.entity.EmbeddingJob; +import at.procon.dip.embedding.job.repository.EmbeddingJobRepository; +import at.procon.dip.embedding.model.EmbeddingJobStatus; +import at.procon.dip.embedding.model.EmbeddingJobType; +import at.procon.dip.embedding.policy.EmbeddingSelectionPolicy; +import at.procon.dip.embedding.registry.EmbeddingModelRegistry; +import at.procon.dip.domain.document.entity.DocumentTextRepresentation; +import java.time.Duration; +import java.time.OffsetDateTime; +import java.util.List; +import java.util.Set; +import java.util.UUID; +import lombok.RequiredArgsConstructor; +import org.springframework.data.domain.PageRequest; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; + +@Service +@RequiredArgsConstructor +@Transactional +public class EmbeddingJobService { + + private static final Set ACTIVE_STATUSES = Set.of( + EmbeddingJobStatus.PENDING, + EmbeddingJobStatus.IN_PROGRESS, + EmbeddingJobStatus.RETRY_SCHEDULED + ); + + private final EmbeddingJobRepository jobRepository; + private final EmbeddingSelectionPolicy selectionPolicy; + private final EmbeddingModelRegistry modelRegistry; + private final EmbeddingProperties properties; + + public List enqueueForDocument(UUID documentId) { + return enqueueForDocument(documentId, modelRegistry.getRequiredDefaultDocumentModelKey()); + } + + public List enqueueForDocument(UUID documentId, String modelKey) { + var model = modelRegistry.getRequired(modelKey); + List selected = selectionPolicy.selectRepresentations(documentId, model); + return selected.stream() + .map(representation -> enqueueForRepresentation(documentId, representation.getId(), modelKey, EmbeddingJobType.DOCUMENT_EMBED)) + .toList(); + } + + public EmbeddingJob enqueueForRepresentation(UUID documentId, UUID representationId, String modelKey, EmbeddingJobType jobType) { + return jobRepository.findFirstByRepresentationIdAndModelKeyAndJobTypeAndStatusIn( + representationId, + modelKey, + jobType, + ACTIVE_STATUSES + ) + .orElseGet(() -> jobRepository.save(EmbeddingJob.builder() + .documentId(documentId) + .representationId(representationId) + .modelKey(modelKey) + .jobType(jobType) + .status(EmbeddingJobStatus.PENDING) + .priority(0) + .attemptCount(0) + .build())); + } + + public List claimNextReadyJobs(int limit) { + List jobs = jobRepository.findReadyJobsForUpdate( + Set.of(EmbeddingJobStatus.PENDING, EmbeddingJobStatus.RETRY_SCHEDULED), + OffsetDateTime.now(), + PageRequest.of(0, limit) + ); + jobs.forEach(job -> { + job.setStatus(EmbeddingJobStatus.IN_PROGRESS); + job.setAttemptCount(job.getAttemptCount() + 1); + job.setLastError(null); + job.setNextRetryAt(null); + }); + return jobRepository.saveAll(jobs); + } + + public EmbeddingJob getRequired(UUID jobId) { + return jobRepository.findById(jobId) + .orElseThrow(() -> new IllegalArgumentException("Unknown embedding job id: " + jobId)); + } + + public void markCompleted(UUID jobId, String providerRequestId) { + EmbeddingJob job = getRequired(jobId); + job.setStatus(EmbeddingJobStatus.DONE); + job.setProviderRequestId(providerRequestId); + job.setLastError(null); + job.setNextRetryAt(null); + jobRepository.save(job); + } + + public void markFailed(UUID jobId, String errorMessage, boolean retryable) { + EmbeddingJob job = getRequired(jobId); + job.setLastError(errorMessage); + if (retryable && job.getAttemptCount() < properties.getJobs().getMaxRetries()) { + job.setStatus(EmbeddingJobStatus.RETRY_SCHEDULED); + job.setNextRetryAt(OffsetDateTime.now().plus(nextDelay(job.getAttemptCount()))); + } else { + job.setStatus(EmbeddingJobStatus.FAILED); + job.setNextRetryAt(null); + } + jobRepository.save(job); + } + + private Duration nextDelay(int attemptCount) { + long factor = Math.max(0, attemptCount - 1); + Duration candidate = properties.getJobs().getInitialRetryDelay().multipliedBy(1L << Math.min(factor, 10)); + return candidate.compareTo(properties.getJobs().getMaxRetryDelay()) > 0 + ? properties.getJobs().getMaxRetryDelay() + : candidate; + } +} diff --git a/src/main/java/at/procon/dip/embedding/model/EmbeddingJobStatus.java b/src/main/java/at/procon/dip/embedding/model/EmbeddingJobStatus.java new file mode 100644 index 0000000..461cec0 --- /dev/null +++ b/src/main/java/at/procon/dip/embedding/model/EmbeddingJobStatus.java @@ -0,0 +1,9 @@ +package at.procon.dip.embedding.model; + +public enum EmbeddingJobStatus { + PENDING, + IN_PROGRESS, + DONE, + FAILED, + RETRY_SCHEDULED +} diff --git a/src/main/java/at/procon/dip/embedding/model/EmbeddingJobType.java b/src/main/java/at/procon/dip/embedding/model/EmbeddingJobType.java new file mode 100644 index 0000000..5323682 --- /dev/null +++ b/src/main/java/at/procon/dip/embedding/model/EmbeddingJobType.java @@ -0,0 +1,7 @@ +package at.procon.dip.embedding.model; + +public enum EmbeddingJobType { + DOCUMENT_EMBED, + QUERY_EMBED, + REEMBED +} diff --git a/src/main/java/at/procon/dip/embedding/policy/DefaultEmbeddingSelectionPolicy.java b/src/main/java/at/procon/dip/embedding/policy/DefaultEmbeddingSelectionPolicy.java new file mode 100644 index 0000000..033483b --- /dev/null +++ b/src/main/java/at/procon/dip/embedding/policy/DefaultEmbeddingSelectionPolicy.java @@ -0,0 +1,68 @@ +package at.procon.dip.embedding.policy; + +import at.procon.dip.domain.document.RepresentationType; +import at.procon.dip.domain.document.entity.DocumentTextRepresentation; +import at.procon.dip.domain.document.repository.DocumentTextRepresentationRepository; +import at.procon.dip.embedding.config.EmbeddingProperties; +import at.procon.dip.embedding.model.EmbeddingModelDescriptor; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Set; +import java.util.UUID; +import lombok.RequiredArgsConstructor; +import org.springframework.stereotype.Component; + +@Component +@RequiredArgsConstructor +public class DefaultEmbeddingSelectionPolicy implements EmbeddingSelectionPolicy { + + private final DocumentTextRepresentationRepository representationRepository; + private final EmbeddingProperties embeddingProperties; + + @Override + public List selectRepresentations(UUID documentId, EmbeddingModelDescriptor model) { + List representations = representationRepository.findByDocument_Id(documentId); + List selected = new ArrayList<>(); + EmbeddingProperties.IndexingProperties indexing = embeddingProperties.getIndexing(); + + for (DocumentTextRepresentation representation : representations) { + if (include(representation, indexing)) { + selected.add(representation); + } + } + + if (selected.isEmpty()) { + representationRepository.findFirstByDocument_IdAndPrimaryRepresentationTrue(documentId) + .ifPresent(selected::add); + } + + Set seen = new LinkedHashSet<>(); + return selected.stream() + .filter(rep -> seen.add(rep.getId())) + .sorted(Comparator + .comparing(DocumentTextRepresentation::isPrimaryRepresentation).reversed() + .thenComparing(rep -> rep.getRepresentationType().ordinal()) + .thenComparing(rep -> rep.getChunkIndex() == null ? Integer.MAX_VALUE : rep.getChunkIndex())) + .toList(); + } + + private boolean include(DocumentTextRepresentation representation, EmbeddingProperties.IndexingProperties indexing) { + return switch (representation.getRepresentationType()) { + case SEMANTIC_TEXT -> indexing.isEmbedSemanticText(); + case TITLE_ABSTRACT -> indexing.isEmbedTitleAbstract(); + case CHUNK -> indexing.isEmbedChunks() && effectiveLength(representation) >= indexing.getChunkMinLength(); + case FULLTEXT -> indexing.isEmbedFulltext(); + case SUMMARY -> indexing.isEmbedSummary(); + default -> representation.isPrimaryRepresentation(); + }; + } + + private int effectiveLength(DocumentTextRepresentation representation) { + if (representation.getCharCount() != null) { + return representation.getCharCount(); + } + return representation.getTextBody() == null ? 0 : representation.getTextBody().length(); + } +} diff --git a/src/main/java/at/procon/dip/embedding/policy/EmbeddingSelectionPolicy.java b/src/main/java/at/procon/dip/embedding/policy/EmbeddingSelectionPolicy.java new file mode 100644 index 0000000..1fbf538 --- /dev/null +++ b/src/main/java/at/procon/dip/embedding/policy/EmbeddingSelectionPolicy.java @@ -0,0 +1,11 @@ +package at.procon.dip.embedding.policy; + +import at.procon.dip.domain.document.entity.DocumentTextRepresentation; +import at.procon.dip.embedding.model.EmbeddingModelDescriptor; +import java.util.List; +import java.util.UUID; + +public interface EmbeddingSelectionPolicy { + + List selectRepresentations(UUID documentId, EmbeddingModelDescriptor model); +} diff --git a/src/main/java/at/procon/dip/embedding/service/EmbeddingModelCatalogService.java b/src/main/java/at/procon/dip/embedding/service/EmbeddingModelCatalogService.java new file mode 100644 index 0000000..77e10a7 --- /dev/null +++ b/src/main/java/at/procon/dip/embedding/service/EmbeddingModelCatalogService.java @@ -0,0 +1,33 @@ +package at.procon.dip.embedding.service; + +import at.procon.dip.domain.document.entity.DocumentEmbeddingModel; +import at.procon.dip.domain.document.service.DocumentEmbeddingService; +import at.procon.dip.domain.document.service.command.RegisterEmbeddingModelCommand; +import at.procon.dip.embedding.model.EmbeddingModelDescriptor; +import at.procon.dip.embedding.registry.EmbeddingModelRegistry; +import lombok.RequiredArgsConstructor; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; + +@Service +@RequiredArgsConstructor +@Transactional +public class EmbeddingModelCatalogService { + + private final EmbeddingModelRegistry modelRegistry; + private final DocumentEmbeddingService documentEmbeddingService; + + public DocumentEmbeddingModel ensureRegistered(String modelKey) { + EmbeddingModelDescriptor descriptor = modelRegistry.getRequired(modelKey); + documentEmbeddingService.registerModel(new RegisterEmbeddingModelCommand( + descriptor.modelKey(), + descriptor.providerConfigKey(), + descriptor.providerModelKey(), + descriptor.dimensions(), + descriptor.distanceMetric(), + descriptor.supportsQueryEmbeddingMode(), + descriptor.active() + )); + return documentEmbeddingService.findActiveModelByKey(modelKey); + } +} diff --git a/src/main/java/at/procon/dip/embedding/service/EmbeddingPersistenceService.java b/src/main/java/at/procon/dip/embedding/service/EmbeddingPersistenceService.java new file mode 100644 index 0000000..502b477 --- /dev/null +++ b/src/main/java/at/procon/dip/embedding/service/EmbeddingPersistenceService.java @@ -0,0 +1,73 @@ +package at.procon.dip.embedding.service; + +import at.procon.dip.domain.document.EmbeddingStatus; +import at.procon.dip.domain.document.entity.DocumentEmbedding; +import at.procon.dip.domain.document.entity.DocumentEmbeddingModel; +import at.procon.dip.domain.document.entity.DocumentTextRepresentation; +import at.procon.dip.domain.document.repository.DocumentEmbeddingRepository; +import at.procon.dip.domain.document.repository.DocumentTextRepresentationRepository; +import at.procon.dip.domain.document.service.DocumentEmbeddingService; +import at.procon.dip.embedding.model.EmbeddingProviderResult; +import at.procon.dip.embedding.support.EmbeddingVectorCodec; +import java.time.OffsetDateTime; +import java.util.UUID; +import lombok.RequiredArgsConstructor; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; + +@Service +@RequiredArgsConstructor +@Transactional +public class EmbeddingPersistenceService { + + private final DocumentTextRepresentationRepository representationRepository; + private final DocumentEmbeddingService documentEmbeddingService; + private final DocumentEmbeddingRepository embeddingRepository; + private final EmbeddingModelCatalogService modelCatalogService; + + public DocumentEmbedding ensurePending(UUID representationId, String modelKey) { + DocumentTextRepresentation representation = representationRepository.findById(representationId) + .orElseThrow(() -> new IllegalArgumentException("Unknown representation id: " + representationId)); + DocumentEmbeddingModel model = modelCatalogService.ensureRegistered(modelKey); + return documentEmbeddingService.ensurePendingEmbedding( + representation.getDocument().getId(), + representation.getId(), + model.getId() + ); + } + + public DocumentEmbedding markProcessing(UUID embeddingId) { + return documentEmbeddingService.markProcessing(embeddingId); + } + + public void saveCompleted(UUID embeddingId, EmbeddingProviderResult result) { + if (result.vectors() == null || result.vectors().isEmpty()) { + throw new IllegalArgumentException("Embedding provider result contains no vectors"); + } + float[] vector = result.vectors().getFirst(); + embeddingRepository.updateEmbeddingVector( + embeddingId, + EmbeddingVectorCodec.toPgVector(vector), + result.tokenCount(), + vector.length + ); + } + + public void markFailed(UUID embeddingId, String errorMessage) { + embeddingRepository.updateEmbeddingStatus( + embeddingId, + EmbeddingStatus.FAILED, + errorMessage, + null + ); + } + + public void markSkipped(UUID embeddingId, String reason) { + embeddingRepository.updateEmbeddingStatus( + embeddingId, + EmbeddingStatus.SKIPPED, + reason, + OffsetDateTime.now() + ); + } +} diff --git a/src/main/java/at/procon/dip/embedding/service/RepresentationEmbeddingOrchestrator.java b/src/main/java/at/procon/dip/embedding/service/RepresentationEmbeddingOrchestrator.java new file mode 100644 index 0000000..3222135 --- /dev/null +++ b/src/main/java/at/procon/dip/embedding/service/RepresentationEmbeddingOrchestrator.java @@ -0,0 +1,96 @@ +package at.procon.dip.embedding.service; + +import at.procon.dip.domain.document.entity.DocumentEmbedding; +import at.procon.dip.domain.document.entity.DocumentTextRepresentation; +import at.procon.dip.domain.document.repository.DocumentTextRepresentationRepository; +import at.procon.dip.embedding.config.EmbeddingProperties; +import at.procon.dip.embedding.job.entity.EmbeddingJob; +import at.procon.dip.embedding.job.service.EmbeddingJobService; +import at.procon.dip.embedding.model.EmbeddingJobType; +import at.procon.dip.embedding.model.EmbeddingProviderResult; +import at.procon.dip.embedding.model.EmbeddingUseCase; +import at.procon.dip.embedding.registry.EmbeddingModelRegistry; +import java.util.List; +import java.util.UUID; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; + +@Service +@RequiredArgsConstructor +@Slf4j +public class RepresentationEmbeddingOrchestrator { + + private final EmbeddingJobService jobService; + private final EmbeddingExecutionService executionService; + private final EmbeddingPersistenceService persistenceService; + private final DocumentTextRepresentationRepository representationRepository; + private final EmbeddingModelRegistry modelRegistry; + private final EmbeddingProperties embeddingProperties; + + @Transactional + public List enqueueDocument(UUID documentId) { + return jobService.enqueueForDocument(documentId); + } + + @Transactional + public List enqueueDocument(UUID documentId, String modelKey) { + return jobService.enqueueForDocument(documentId, modelKey); + } + + @Transactional + public EmbeddingJob enqueueRepresentation(UUID documentId, UUID representationId, String modelKey) { + return jobService.enqueueForRepresentation(documentId, representationId, modelKey, EmbeddingJobType.DOCUMENT_EMBED); + } + + @Transactional + public int processNextReadyBatch() { + if (!embeddingProperties.isEnabled() || !embeddingProperties.getJobs().isEnabled()) { + log.debug("New embedding subsystem jobs are disabled"); + return 0; + } + + List jobs = jobService.claimNextReadyJobs(embeddingProperties.getJobs().getBatchSize()); + for (EmbeddingJob job : jobs) { + processClaimedJob(job); + } + return jobs.size(); + } + + @Transactional + public void processClaimedJob(EmbeddingJob job) { + DocumentTextRepresentation representation = representationRepository.findById(job.getRepresentationId()) + .orElseThrow(() -> new IllegalArgumentException("Unknown representation id: " + job.getRepresentationId())); + + String text = representation.getTextBody(); + if (text == null || text.isBlank()) { + jobService.markFailed(job.getId(), "No text representation available", false); + return; + } + + int maxChars = modelRegistry.getRequired(job.getModelKey()).maxInputChars() != null + ? modelRegistry.getRequired(job.getModelKey()).maxInputChars() + : embeddingProperties.getIndexing().getFallbackMaxInputChars(); + if (text.length() > maxChars) { + text = text.substring(0, maxChars); + } + + DocumentEmbedding embedding = persistenceService.ensurePending(representation.getId(), job.getModelKey()); + persistenceService.markProcessing(embedding.getId()); + + try { + EmbeddingProviderResult result = executionService.embedTexts( + job.getModelKey(), + EmbeddingUseCase.DOCUMENT, + List.of(text) + ); + persistenceService.saveCompleted(embedding.getId(), result); + jobService.markCompleted(job.getId(), result.providerRequestId()); + } catch (RuntimeException ex) { + persistenceService.markFailed(embedding.getId(), ex.getMessage()); + jobService.markFailed(job.getId(), ex.getMessage(), true); + throw ex; + } + } +} diff --git a/src/main/java/at/procon/dip/embedding/support/EmbeddingVectorCodec.java b/src/main/java/at/procon/dip/embedding/support/EmbeddingVectorCodec.java new file mode 100644 index 0000000..a4e4f40 --- /dev/null +++ b/src/main/java/at/procon/dip/embedding/support/EmbeddingVectorCodec.java @@ -0,0 +1,23 @@ +package at.procon.dip.embedding.support; + +public final class EmbeddingVectorCodec { + + private EmbeddingVectorCodec() { + } + + public static String toPgVector(float[] vector) { + if (vector == null || vector.length == 0) { + throw new IllegalArgumentException("Embedding vector must not be null or empty"); + } + StringBuilder builder = new StringBuilder(); + builder.append('['); + for (int i = 0; i < vector.length; i++) { + if (i > 0) { + builder.append(','); + } + builder.append(Float.toString(vector[i])); + } + builder.append(']'); + return builder.toString(); + } +} diff --git a/src/main/resources/db/migration/V9__doc_nv2_embedding_jobs.sql b/src/main/resources/db/migration/V9__doc_nv2_embedding_jobs.sql new file mode 100644 index 0000000..31215f7 --- /dev/null +++ b/src/main/resources/db/migration/V9__doc_nv2_embedding_jobs.sql @@ -0,0 +1,30 @@ +-- NV2: Parallel embedding subsystem job table +-- Additive and independent from the legacy/transitional vectorization flow. + +CREATE TABLE IF NOT EXISTS DOC.doc_embedding_job ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + document_id UUID NOT NULL REFERENCES DOC.doc_document(id) ON DELETE CASCADE, + representation_id UUID NOT NULL REFERENCES DOC.doc_text_representation(id) ON DELETE CASCADE, + model_key VARCHAR(255) NOT NULL, + job_type VARCHAR(32) NOT NULL, + status VARCHAR(32) NOT NULL DEFAULT 'PENDING', + attempt_count INTEGER NOT NULL DEFAULT 0, + next_retry_at TIMESTAMP WITH TIME ZONE, + priority INTEGER NOT NULL DEFAULT 0, + provider_request_id VARCHAR(255), + last_error TEXT, + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +CREATE INDEX IF NOT EXISTS idx_doc_embedding_job_status_next_retry + ON DOC.doc_embedding_job(status, next_retry_at); + +CREATE INDEX IF NOT EXISTS idx_doc_embedding_job_representation + ON DOC.doc_embedding_job(representation_id); + +CREATE INDEX IF NOT EXISTS idx_doc_embedding_job_document + ON DOC.doc_embedding_job(document_id); + +CREATE INDEX IF NOT EXISTS idx_doc_embedding_job_model_key + ON DOC.doc_embedding_job(model_key); diff --git a/src/test/java/at/procon/dip/embedding/policy/DefaultEmbeddingSelectionPolicyTest.java b/src/test/java/at/procon/dip/embedding/policy/DefaultEmbeddingSelectionPolicyTest.java new file mode 100644 index 0000000..a049db3 --- /dev/null +++ b/src/test/java/at/procon/dip/embedding/policy/DefaultEmbeddingSelectionPolicyTest.java @@ -0,0 +1,78 @@ +package at.procon.dip.embedding.policy; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.when; + +import at.procon.dip.domain.document.RepresentationType; +import at.procon.dip.domain.document.entity.DocumentTextRepresentation; +import at.procon.dip.domain.document.repository.DocumentTextRepresentationRepository; +import at.procon.dip.embedding.config.EmbeddingProperties; +import at.procon.dip.embedding.model.EmbeddingModelDescriptor; +import at.procon.dip.domain.document.DistanceMetric; +import java.util.List; +import java.util.UUID; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +@ExtendWith(MockitoExtension.class) +class DefaultEmbeddingSelectionPolicyTest { + + @Mock + private DocumentTextRepresentationRepository representationRepository; + + private DefaultEmbeddingSelectionPolicy policy; + private EmbeddingModelDescriptor model; + + @BeforeEach + void setUp() { + EmbeddingProperties properties = new EmbeddingProperties(); + properties.getIndexing().setEmbedSemanticText(true); + properties.getIndexing().setEmbedTitleAbstract(true); + properties.getIndexing().setEmbedChunks(true); + properties.getIndexing().setChunkMinLength(20); + policy = new DefaultEmbeddingSelectionPolicy(representationRepository, properties); + model = new EmbeddingModelDescriptor("mock-search", "mock-provider", "mock-search", 16, + DistanceMetric.COSINE, true, false, 4096, true); + } + + @Test + void should_select_semantic_text_title_and_large_chunks() { + UUID documentId = UUID.randomUUID(); + DocumentTextRepresentation semantic = representation(RepresentationType.SEMANTIC_TEXT, true, null, + "semantic text"); + DocumentTextRepresentation titleAbstract = representation(RepresentationType.TITLE_ABSTRACT, false, null, + "title abstract text"); + DocumentTextRepresentation shortChunk = representation(RepresentationType.CHUNK, false, 0, + "too short"); + DocumentTextRepresentation longChunk = representation(RepresentationType.CHUNK, false, 1, + "This is a sufficiently long chunk that should be selected."); + when(representationRepository.findByDocument_Id(documentId)) + .thenReturn(List.of(longChunk, semantic, shortChunk, titleAbstract)); + + List selected = policy.selectRepresentations(documentId, model); + + assertThat(selected) + .extracting(DocumentTextRepresentation::getRepresentationType) + .containsExactly(RepresentationType.SEMANTIC_TEXT, RepresentationType.TITLE_ABSTRACT, RepresentationType.CHUNK); + assertThat(selected) + .extracting(DocumentTextRepresentation::getChunkIndex) + .containsExactly(null, null, 1); + } + + private DocumentTextRepresentation representation(RepresentationType type, + boolean primary, + Integer chunkIndex, + String text) { + return DocumentTextRepresentation.builder() + .id(UUID.randomUUID()) + .representationType(type) + .primaryRepresentation(primary) + .chunkIndex(chunkIndex) + .charCount(text.length()) + .textBody(text) + .build(); + } +} diff --git a/src/test/java/at/procon/dip/embedding/service/RepresentationEmbeddingOrchestratorTest.java b/src/test/java/at/procon/dip/embedding/service/RepresentationEmbeddingOrchestratorTest.java new file mode 100644 index 0000000..a4f6bb2 --- /dev/null +++ b/src/test/java/at/procon/dip/embedding/service/RepresentationEmbeddingOrchestratorTest.java @@ -0,0 +1,105 @@ +package at.procon.dip.embedding.service; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import at.procon.dip.domain.document.entity.Document; +import at.procon.dip.domain.document.entity.DocumentEmbedding; +import at.procon.dip.domain.document.entity.DocumentTextRepresentation; +import at.procon.dip.domain.document.repository.DocumentTextRepresentationRepository; +import at.procon.dip.embedding.config.EmbeddingProperties; +import at.procon.dip.embedding.job.entity.EmbeddingJob; +import at.procon.dip.embedding.job.service.EmbeddingJobService; +import at.procon.dip.embedding.model.EmbeddingJobStatus; +import at.procon.dip.embedding.model.EmbeddingJobType; +import at.procon.dip.embedding.model.EmbeddingModelDescriptor; +import at.procon.dip.embedding.model.EmbeddingProviderResult; +import at.procon.dip.embedding.registry.EmbeddingModelRegistry; +import at.procon.dip.domain.document.DistanceMetric; +import java.util.List; +import java.util.UUID; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +@ExtendWith(MockitoExtension.class) +class RepresentationEmbeddingOrchestratorTest { + + @Mock + private EmbeddingJobService jobService; + @Mock + private EmbeddingExecutionService executionService; + @Mock + private EmbeddingPersistenceService persistenceService; + @Mock + private DocumentTextRepresentationRepository representationRepository; + @Mock + private EmbeddingModelRegistry modelRegistry; + + private RepresentationEmbeddingOrchestrator orchestrator; + + @BeforeEach + void setUp() { + EmbeddingProperties properties = new EmbeddingProperties(); + properties.setEnabled(true); + properties.getJobs().setEnabled(true); + properties.getJobs().setBatchSize(10); + properties.getIndexing().setFallbackMaxInputChars(8192); + orchestrator = new RepresentationEmbeddingOrchestrator( + jobService, + executionService, + persistenceService, + representationRepository, + modelRegistry, + properties + ); + } + + @Test + void processClaimedJob_should_embed_representation_and_mark_job_completed() { + UUID documentId = UUID.randomUUID(); + UUID representationId = UUID.randomUUID(); + UUID embeddingId = UUID.randomUUID(); + EmbeddingJob job = EmbeddingJob.builder() + .id(UUID.randomUUID()) + .documentId(documentId) + .representationId(representationId) + .modelKey("mock-search") + .jobType(EmbeddingJobType.DOCUMENT_EMBED) + .status(EmbeddingJobStatus.IN_PROGRESS) + .attemptCount(1) + .build(); + + Document document = Document.builder().id(documentId).build(); + DocumentTextRepresentation representation = DocumentTextRepresentation.builder() + .id(representationId) + .document(document) + .textBody("District heating optimization strategy") + .build(); + when(representationRepository.findById(representationId)).thenReturn(java.util.Optional.of(representation)); + when(modelRegistry.getRequired("mock-search")) + .thenReturn(new EmbeddingModelDescriptor("mock-search", "mock-provider", "mock-search", 16, + DistanceMetric.COSINE, true, false, 4096, true)); + when(persistenceService.ensurePending(representationId, "mock-search")) + .thenReturn(DocumentEmbedding.builder().id(embeddingId).build()); + when(executionService.embedTexts(eq("mock-search"), any(), any())) + .thenReturn(new EmbeddingProviderResult( + new EmbeddingModelDescriptor("mock-search", "mock-provider", "mock-search", 16, + DistanceMetric.COSINE, true, false, 4096, true), + List.of(new float[]{0.1f, 0.2f}), + List.of(), + "req-1", + 12 + )); + + orchestrator.processClaimedJob(job); + + verify(persistenceService).markProcessing(embeddingId); + verify(persistenceService).saveCompleted(eq(embeddingId), any(EmbeddingProviderResult.class)); + verify(jobService).markCompleted(job.getId(), "req-1"); + } +} diff --git a/src/test/java/at/procon/dip/embedding/support/EmbeddingVectorCodecTest.java b/src/test/java/at/procon/dip/embedding/support/EmbeddingVectorCodecTest.java new file mode 100644 index 0000000..64578da --- /dev/null +++ b/src/test/java/at/procon/dip/embedding/support/EmbeddingVectorCodecTest.java @@ -0,0 +1,14 @@ +package at.procon.dip.embedding.support; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.Test; + +class EmbeddingVectorCodecTest { + + @Test + void should_render_pgvector_literal() { + String rendered = EmbeddingVectorCodec.toPgVector(new float[]{1.0f, 2.5f, -0.75f}); + assertThat(rendered).isEqualTo("[1.0,2.5,-0.75]"); + } +}