diff --git a/src/main/java/at/procon/dip/embedding/provider/http/VectorSyncHttpEmbeddingProvider.java b/src/main/java/at/procon/dip/embedding/provider/http/VectorSyncHttpEmbeddingProvider.java index 6aa8ae6..c0f2804 100644 --- a/src/main/java/at/procon/dip/embedding/provider/http/VectorSyncHttpEmbeddingProvider.java +++ b/src/main/java/at/procon/dip/embedding/provider/http/VectorSyncHttpEmbeddingProvider.java @@ -152,6 +152,7 @@ public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProvid settings.truncateText(), settings.truncateLength(), settings.chunkSize(), + true, items ) ); @@ -296,6 +297,7 @@ public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProvid @JsonProperty("truncate_text") boolean truncateText, @JsonProperty("truncate_length") int truncateLength, @JsonProperty("chunk_size") int chunkSize, + @JsonProperty("balance_chunks") boolean balanceChunks, @JsonProperty("items") List items ) { } diff --git a/src/main/java/at/procon/dip/embedding/service/RepresentationEmbeddingOrchestrator.java b/src/main/java/at/procon/dip/embedding/service/RepresentationEmbeddingOrchestrator.java index daccdcd..b0581fe 100644 --- a/src/main/java/at/procon/dip/embedding/service/RepresentationEmbeddingOrchestrator.java +++ b/src/main/java/at/procon/dip/embedding/service/RepresentationEmbeddingOrchestrator.java @@ -16,8 +16,10 @@ import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.List; import java.util.UUID; -import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorCompletionService; +import java.util.concurrent.Future; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.stereotype.Service; @@ -80,20 +82,10 @@ public class RepresentationEmbeddingOrchestrator { int batchSize = Math.max(1, 
embeddingProperties.getJobs().getBatchSize()); int parallelBatchCount = Math.max(1, embeddingProperties.getJobs().getParallelBatchCount()); - int claimLimit = batchSize * parallelBatchCount; - List<EmbeddingJob> jobs = jobService.claimNextReadyJobs(claimLimit); - if (jobs.isEmpty()) { - return 0; - } - - List<List<EmbeddingJob>> claimedBatches = partition(jobs, batchSize); - if (claimedBatches.size() == 1) { - processClaimedJobBatch(claimedBatches.getFirst()); - } else { - runClaimedBatchesInParallel(claimedBatches); - } - return jobs.size(); + return parallelBatchCount == 1 + ? processWithSingleBatchSlot(batchSize) + : processWithSlidingParallelBatchWindow(batchSize, parallelBatchCount); } public void processClaimedJob(EmbeddingJob job) { @@ -126,6 +118,86 @@ public class RepresentationEmbeddingOrchestrator { } } + private int processWithSingleBatchSlot(int batchSize) { + int claimedTotal = 0; + while (true) { + List<EmbeddingJob> claimedBatch = jobService.claimNextReadyJobs(batchSize); + if (claimedBatch.isEmpty()) { + return claimedTotal; + } + claimedTotal += claimedBatch.size(); + processClaimedJobBatch(claimedBatch); + } + } + + private int processWithSlidingParallelBatchWindow(int batchSize, int parallelBatchCount) { + ExecutorCompletionService<Integer> completionService = new ExecutorCompletionService<>(embeddingJobProcessingExecutor); + int claimedTotal = 0; + int activeBatchCount = 0; + boolean noMoreReadyJobs = false; + + while (activeBatchCount < parallelBatchCount && !noMoreReadyJobs) { + List<EmbeddingJob> claimedBatch = jobService.claimNextReadyJobs(batchSize); + if (claimedBatch.isEmpty()) { + noMoreReadyJobs = true; + } else { + claimedTotal += claimedBatch.size(); + submitClaimedBatch(completionService, claimedBatch); + activeBatchCount++; + } + } + + while (activeBatchCount > 0) { + waitForBatchCompletion(takeCompletedBatch(completionService)); + activeBatchCount--; + + if (!noMoreReadyJobs) { + List<EmbeddingJob> claimedBatch = jobService.claimNextReadyJobs(batchSize); + if (claimedBatch.isEmpty()) { + noMoreReadyJobs = true; + 
} else { + claimedTotal += claimedBatch.size(); + submitClaimedBatch(completionService, claimedBatch); + activeBatchCount++; + } + } + } + + return claimedTotal; + } + + private void submitClaimedBatch(ExecutorCompletionService<Integer> completionService, + List<EmbeddingJob> claimedBatch) { + completionService.submit(() -> { + processClaimedJobBatch(claimedBatch); + return claimedBatch.size(); + }); + } + + private Future<Integer> takeCompletedBatch(ExecutorCompletionService<Integer> completionService) { + try { + return completionService.take(); + } catch (InterruptedException ex) { + Thread.currentThread().interrupt(); + throw new IllegalStateException("Interrupted while waiting for an embedding batch slot to become free", ex); + } + } + + private void waitForBatchCompletion(Future<Integer> future) { + try { + future.get(); + } catch (InterruptedException ex) { + Thread.currentThread().interrupt(); + throw new IllegalStateException("Interrupted while waiting for embedding batch completion", ex); + } catch (java.util.concurrent.ExecutionException ex) { + Throwable cause = ex.getCause(); + if (cause instanceof RuntimeException runtimeException) { + throw runtimeException; + } + throw new CompletionException(cause); + } + } + private void processClaimedJobBatch(List<EmbeddingJob> jobs) { if (embeddingProperties.getJobs().isProcessInBatches()) { processClaimedJobsInExecutionBatches(jobs); @@ -134,17 +206,6 @@ public class RepresentationEmbeddingOrchestrator { } } - private void runClaimedBatchesInParallel(List<List<EmbeddingJob>> claimedBatches) { - List<CompletableFuture<Void>> futures = claimedBatches.stream() - .map(batch -> CompletableFuture.runAsync( - () -> processClaimedJobBatch(batch), - embeddingJobProcessingExecutor - )) - .toList(); - - futures.forEach(CompletableFuture::join); - } - private void processClaimedJobsInExecutionBatches(List<EmbeddingJob> jobs) { LinkedHashMap<String, List<EmbeddingJob>> jobsByModelKey = new LinkedHashMap<>(); for (EmbeddingJob job : jobs) { @@ -171,13 +232,6 @@ public class RepresentationEmbeddingOrchestrator { } } - private List<List<EmbeddingJob>> partition(List<EmbeddingJob> jobs, int batchSize) { - 
List<List<EmbeddingJob>> batches = new ArrayList<>(); - for (int start = 0; start < jobs.size(); start += batchSize) { - batches.add(jobs.subList(start, Math.min(start + batchSize, jobs.size()))); - } - return batches; - } private void processClaimedBatchSafely(List<EmbeddingJob> jobs, EmbeddingModelDescriptor model) { try { diff --git a/src/test/java/at/procon/dip/embedding/service/RepresentationEmbeddingOrchestratorTest.java b/src/test/java/at/procon/dip/embedding/service/RepresentationEmbeddingOrchestratorTest.java index d94a75c..873e6c0 100644 --- a/src/test/java/at/procon/dip/embedding/service/RepresentationEmbeddingOrchestratorTest.java +++ b/src/test/java/at/procon/dip/embedding/service/RepresentationEmbeddingOrchestratorTest.java @@ -139,7 +139,7 @@ class RepresentationEmbeddingOrchestratorTest { .attemptCount(1) .build(); - when(jobService.claimNextReadyJobs(10)).thenReturn(List.of(job1, job2)); + when(jobService.claimNextReadyJobs(10)).thenReturn(List.of(job1, job2), List.of()); Document document = Document.builder().id(documentId).build(); when(representationRepository.findById(representationId1)).thenReturn(Optional.of( @@ -196,7 +196,7 @@ class RepresentationEmbeddingOrchestratorTest { .attemptCount(1) .build(); - when(jobService.claimNextReadyJobs(10)).thenReturn(List.of(job)); + when(jobService.claimNextReadyJobs(10)).thenReturn(List.of(job), List.of()); Document document = Document.builder().id(documentId).build(); when(representationRepository.findById(representationId)).thenReturn(Optional.of( DocumentTextRepresentation.builder().id(representationId).document(document).textBody("gamma").build() @@ -225,7 +225,7 @@ class RepresentationEmbeddingOrchestratorTest { } @Test - void processNextReadyBatch_should_claim_multiple_batches_and_start_them_in_parallel() { + void processNextReadyBatch_should_keep_claiming_new_batches_until_all_parallel_slots_are_exhausted() { properties.getJobs().setParallelBatchCount(2); properties.getJobs().setBatchSize(2); 
properties.getJobs().setProcessInBatches(true); @@ -236,23 +236,35 @@ class RepresentationEmbeddingOrchestratorTest { UUID representationId2 = UUID.randomUUID(); UUID representationId3 = UUID.randomUUID(); UUID representationId4 = UUID.randomUUID(); + UUID representationId5 = UUID.randomUUID(); + UUID representationId6 = UUID.randomUUID(); UUID embeddingId1 = UUID.randomUUID(); UUID embeddingId2 = UUID.randomUUID(); UUID embeddingId3 = UUID.randomUUID(); UUID embeddingId4 = UUID.randomUUID(); + UUID embeddingId5 = UUID.randomUUID(); + UUID embeddingId6 = UUID.randomUUID(); EmbeddingJob job1 = EmbeddingJob.builder().id(UUID.randomUUID()).documentId(documentId).representationId(representationId1).modelKey("e5-default").jobType(EmbeddingJobType.DOCUMENT_EMBED).status(EmbeddingJobStatus.IN_PROGRESS).attemptCount(1).build(); EmbeddingJob job2 = EmbeddingJob.builder().id(UUID.randomUUID()).documentId(documentId).representationId(representationId2).modelKey("e5-default").jobType(EmbeddingJobType.DOCUMENT_EMBED).status(EmbeddingJobStatus.IN_PROGRESS).attemptCount(1).build(); EmbeddingJob job3 = EmbeddingJob.builder().id(UUID.randomUUID()).documentId(documentId).representationId(representationId3).modelKey("e5-default").jobType(EmbeddingJobType.DOCUMENT_EMBED).status(EmbeddingJobStatus.IN_PROGRESS).attemptCount(1).build(); EmbeddingJob job4 = EmbeddingJob.builder().id(UUID.randomUUID()).documentId(documentId).representationId(representationId4).modelKey("e5-default").jobType(EmbeddingJobType.DOCUMENT_EMBED).status(EmbeddingJobStatus.IN_PROGRESS).attemptCount(1).build(); + EmbeddingJob job5 = EmbeddingJob.builder().id(UUID.randomUUID()).documentId(documentId).representationId(representationId5).modelKey("e5-default").jobType(EmbeddingJobType.DOCUMENT_EMBED).status(EmbeddingJobStatus.IN_PROGRESS).attemptCount(1).build(); + EmbeddingJob job6 = 
EmbeddingJob.builder().id(UUID.randomUUID()).documentId(documentId).representationId(representationId6).modelKey("e5-default").jobType(EmbeddingJobType.DOCUMENT_EMBED).status(EmbeddingJobStatus.IN_PROGRESS).attemptCount(1).build(); - when(jobService.claimNextReadyJobs(4)).thenReturn(List.of(job1, job2, job3, job4)); + when(jobService.claimNextReadyJobs(2)) + .thenReturn(List.of(job1, job2)) + .thenReturn(List.of(job3, job4)) + .thenReturn(List.of(job5, job6)) + .thenReturn(List.of()); Document document = Document.builder().id(documentId).build(); when(representationRepository.findById(representationId1)).thenReturn(Optional.of(DocumentTextRepresentation.builder().id(representationId1).document(document).textBody("alpha").build())); when(representationRepository.findById(representationId2)).thenReturn(Optional.of(DocumentTextRepresentation.builder().id(representationId2).document(document).textBody("beta").build())); when(representationRepository.findById(representationId3)).thenReturn(Optional.of(DocumentTextRepresentation.builder().id(representationId3).document(document).textBody("gamma").build())); when(representationRepository.findById(representationId4)).thenReturn(Optional.of(DocumentTextRepresentation.builder().id(representationId4).document(document).textBody("delta").build())); + when(representationRepository.findById(representationId5)).thenReturn(Optional.of(DocumentTextRepresentation.builder().id(representationId5).document(document).textBody("epsilon").build())); + when(representationRepository.findById(representationId6)).thenReturn(Optional.of(DocumentTextRepresentation.builder().id(representationId6).document(document).textBody("zeta").build())); EmbeddingModelDescriptor model = new EmbeddingModelDescriptor( "e5-default", "vector-sync-e5", "intfloat/multilingual-e5-large", 3, @@ -263,6 +275,8 @@ class RepresentationEmbeddingOrchestratorTest { when(executionPersistenceService.startProcessing(representationId2, "e5-default")).thenReturn(embeddingId2); 
when(executionPersistenceService.startProcessing(representationId3, "e5-default")).thenReturn(embeddingId3); when(executionPersistenceService.startProcessing(representationId4, "e5-default")).thenReturn(embeddingId4); + when(executionPersistenceService.startProcessing(representationId5, "e5-default")).thenReturn(embeddingId5); + when(executionPersistenceService.startProcessing(representationId6, "e5-default")).thenReturn(embeddingId6); when(executionService.embedTexts(eq("e5-default"), eq(EmbeddingUseCase.DOCUMENT), any())) .thenReturn(new EmbeddingProviderResult( model, @@ -277,16 +291,25 @@ class RepresentationEmbeddingOrchestratorTest { List.of(), "batch-req-2", 25 + )) + .thenReturn(new EmbeddingProviderResult( + model, + List.of(new float[]{1.3f, 1.4f, 1.5f}, new float[]{1.6f, 1.7f, 1.8f}), + List.of(), + "batch-req-3", + 29 )); int processed = orchestrator.processNextReadyBatch(); - assertThat(processed).isEqualTo(4); - verify(jobService).claimNextReadyJobs(4); - verify(executionService, times(2)).embedTexts(eq("e5-default"), eq(EmbeddingUseCase.DOCUMENT), any()); + assertThat(processed).isEqualTo(6); + verify(jobService, times(4)).claimNextReadyJobs(2); + verify(executionService, times(3)).embedTexts(eq("e5-default"), eq(EmbeddingUseCase.DOCUMENT), any()); verify(executionPersistenceService).completeJob(eq(embeddingId1), aryEq(new float[]{0.1f, 0.2f, 0.3f}), eq(null), eq(job1.getId()), eq("batch-req-1")); verify(executionPersistenceService).completeJob(eq(embeddingId2), aryEq(new float[]{0.4f, 0.5f, 0.6f}), eq(null), eq(job2.getId()), eq("batch-req-1")); verify(executionPersistenceService).completeJob(eq(embeddingId3), aryEq(new float[]{0.7f, 0.8f, 0.9f}), eq(null), eq(job3.getId()), eq("batch-req-2")); verify(executionPersistenceService).completeJob(eq(embeddingId4), aryEq(new float[]{1.0f, 1.1f, 1.2f}), eq(null), eq(job4.getId()), eq("batch-req-2")); + verify(executionPersistenceService).completeJob(eq(embeddingId5), aryEq(new float[]{1.3f, 1.4f, 1.5f}), 
eq(null), eq(job5.getId()), eq("batch-req-3")); + verify(executionPersistenceService).completeJob(eq(embeddingId6), aryEq(new float[]{1.6f, 1.7f, 1.8f}), eq(null), eq(job6.getId()), eq("batch-req-3")); } }