Always request balanced chunking for vectorization; vectorize the newest embedding jobs first

This commit is contained in:
trifonovt 2026-04-20 12:26:30 +02:00
parent 8fddc2a429
commit 1500e84757
3 changed files with 118 additions and 39 deletions

View File

@ -152,6 +152,7 @@ public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProvid
settings.truncateText(),
settings.truncateLength(),
settings.chunkSize(),
true,
items
)
);
@ -296,6 +297,7 @@ public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProvid
@JsonProperty("truncate_text") boolean truncateText,
@JsonProperty("truncate_length") int truncateLength,
@JsonProperty("chunk_size") int chunkSize,
@JsonProperty("balance_chunks") boolean balanceChunks,
@JsonProperty("items") List<VectorizeBatchItemRequest> items
) {
}

View File

@ -16,8 +16,10 @@ import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.Future;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service;
@ -80,20 +82,10 @@ public class RepresentationEmbeddingOrchestrator {
int batchSize = Math.max(1, embeddingProperties.getJobs().getBatchSize());
int parallelBatchCount = Math.max(1, embeddingProperties.getJobs().getParallelBatchCount());
int claimLimit = batchSize * parallelBatchCount;
List<EmbeddingJob> jobs = jobService.claimNextReadyJobs(claimLimit);
if (jobs.isEmpty()) {
return 0;
}
List<List<EmbeddingJob>> claimedBatches = partition(jobs, batchSize);
if (claimedBatches.size() == 1) {
processClaimedJobBatch(claimedBatches.getFirst());
} else {
runClaimedBatchesInParallel(claimedBatches);
}
return jobs.size();
return parallelBatchCount == 1
? processWithSingleBatchSlot(batchSize)
: processWithSlidingParallelBatchWindow(batchSize, parallelBatchCount);
}
public void processClaimedJob(EmbeddingJob job) {
@ -126,6 +118,86 @@ public class RepresentationEmbeddingOrchestrator {
}
}
/**
 * Claims and processes batches one after another on the calling thread.
 *
 * <p>Each round asks {@code jobService.claimNextReadyJobs(batchSize)} for a batch and hands it
 * to {@code processClaimedJobBatch}. The loop ends as soon as a claim returns an empty list.
 *
 * @param batchSize value forwarded to every {@code claimNextReadyJobs} call
 * @return the summed size of all batches that were claimed
 */
private int processWithSingleBatchSlot(int batchSize) {
    int totalClaimed = 0;
    for (List<EmbeddingJob> batch = jobService.claimNextReadyJobs(batchSize);
            !batch.isEmpty();
            batch = jobService.claimNextReadyJobs(batchSize)) {
        // Count first, then process, so the bookkeeping order matches the claim order.
        totalClaimed += batch.size();
        processClaimedJobBatch(batch);
    }
    return totalClaimed;
}
/**
 * Keeps up to {@code parallelBatchCount} claimed batches running on
 * {@code embeddingJobProcessingExecutor} at the same time.
 *
 * <p>Batches are claimed via {@code jobService.claimNextReadyJobs(batchSize)} until the window is
 * full or a claim returns an empty list. Whenever one submitted batch finishes, another claim is
 * attempted to refill the freed slot. Once a claim has come back empty no further claims are made
 * and the method only waits for the batches that are still in flight.
 *
 * <p>A failure of a finished batch is rethrown by {@code waitForBatchCompletion}; in that case
 * this method exits without waiting for the remaining submitted batches.
 *
 * @param batchSize value forwarded to every {@code claimNextReadyJobs} call
 * @param parallelBatchCount maximum number of batches in flight at once
 * @return the summed size of all batches that were claimed
 */
private int processWithSlidingParallelBatchWindow(int batchSize, int parallelBatchCount) {
    ExecutorCompletionService<Integer> batchCompletions =
            new ExecutorCompletionService<>(embeddingJobProcessingExecutor);
    int totalClaimed = 0;
    int inFlight = 0;
    boolean queueDrained = false;

    while (true) {
        // Refill free slots. After the initial fill at most one slot is free per round,
        // so this claims exactly once per completed batch while jobs are still available.
        while (!queueDrained && inFlight < parallelBatchCount) {
            List<EmbeddingJob> batch = jobService.claimNextReadyJobs(batchSize);
            if (batch.isEmpty()) {
                queueDrained = true;
            } else {
                totalClaimed += batch.size();
                submitClaimedBatch(batchCompletions, batch);
                inFlight++;
            }
        }

        if (inFlight == 0) {
            return totalClaimed;
        }

        // Block until any one batch is done, surfacing its failure if it had one.
        waitForBatchCompletion(takeCompletedBatch(batchCompletions));
        inFlight--;
    }
}
/**
 * Schedules one claimed batch on the given completion service.
 *
 * <p>The submitted task runs {@code processClaimedJobBatch} for the batch and yields the batch
 * size as its result.
 *
 * @param completionService service the task is submitted to
 * @param claimedBatch jobs to process inside the task
 */
private void submitClaimedBatch(ExecutorCompletionService<Integer> completionService,
        List<EmbeddingJob> claimedBatch) {
    java.util.concurrent.Callable<Integer> batchTask = () -> {
        processClaimedJobBatch(claimedBatch);
        return claimedBatch.size();
    };
    completionService.submit(batchTask);
}
/**
 * Blocks until the completion service hands out the next finished batch.
 *
 * @param completionService service to take the finished task from
 * @return the future of the batch that finished
 * @throws IllegalStateException if the waiting thread is interrupted; the interrupt flag is
 *         restored before the exception is thrown
 */
private Future<Integer> takeCompletedBatch(ExecutorCompletionService<Integer> completionService) {
    Future<Integer> finishedBatch;
    try {
        finishedBatch = completionService.take();
    } catch (InterruptedException interrupted) {
        // Restore the flag so code further up the stack can still observe the interrupt.
        Thread.currentThread().interrupt();
        throw new IllegalStateException(
                "Interrupted while waiting for an embedding batch slot to become free", interrupted);
    }
    return finishedBatch;
}
/**
 * Waits for the given batch future and surfaces a failure of the batch to the caller.
 *
 * @param future future of a submitted batch
 * @throws IllegalStateException if the waiting thread is interrupted; the interrupt flag is
 *         restored before the exception is thrown
 * @throws RuntimeException the original cause when the batch failed with a runtime exception
 * @throws CompletionException wrapping the cause when the batch failed with anything else
 */
private void waitForBatchCompletion(Future<Integer> future) {
    try {
        future.get();
    } catch (java.util.concurrent.ExecutionException failure) {
        Throwable rootCause = failure.getCause();
        // Unchecked causes are rethrown as-is; everything else is wrapped.
        throw rootCause instanceof RuntimeException unchecked
                ? unchecked
                : new CompletionException(rootCause);
    } catch (InterruptedException interrupted) {
        Thread.currentThread().interrupt();
        throw new IllegalStateException(
                "Interrupted while waiting for embedding batch completion", interrupted);
    }
}
private void processClaimedJobBatch(List<EmbeddingJob> jobs) {
if (embeddingProperties.getJobs().isProcessInBatches()) {
processClaimedJobsInExecutionBatches(jobs);
@ -134,17 +206,6 @@ public class RepresentationEmbeddingOrchestrator {
}
}
/**
 * Starts every claimed batch on {@code embeddingJobProcessingExecutor} and then waits for all of
 * them in submission order.
 *
 * <p>All batches are started before the first {@code join}, so they can run concurrently.
 * A failed batch makes the corresponding {@code join} throw.
 *
 * @param claimedBatches batches to process, one asynchronous task per batch
 */
private void runClaimedBatchesInParallel(List<List<EmbeddingJob>> claimedBatches) {
    List<CompletableFuture<Void>> pendingBatches = new ArrayList<>(claimedBatches.size());
    for (List<EmbeddingJob> batch : claimedBatches) {
        Runnable batchWork = () -> processClaimedJobBatch(batch);
        pendingBatches.add(CompletableFuture.runAsync(batchWork, embeddingJobProcessingExecutor));
    }
    for (CompletableFuture<Void> pendingBatch : pendingBatches) {
        pendingBatch.join();
    }
}
private void processClaimedJobsInExecutionBatches(List<EmbeddingJob> jobs) {
LinkedHashMap<String, List<EmbeddingJob>> jobsByModelKey = new LinkedHashMap<>();
for (EmbeddingJob job : jobs) {
@ -171,13 +232,6 @@ public class RepresentationEmbeddingOrchestrator {
}
}
/**
 * Splits {@code jobs} into consecutive batches of at most {@code batchSize} elements.
 *
 * <p>The returned batches are {@link List#subList(int, int) subList} views of {@code jobs}, in
 * the original order; only the last batch may be shorter than {@code batchSize}. An empty input
 * yields an empty result regardless of {@code batchSize}.
 *
 * @param jobs jobs to split; not copied
 * @param batchSize maximum number of jobs per batch, must be at least 1 when {@code jobs} is
 *        non-empty
 * @return the batches, possibly empty
 * @throws IllegalArgumentException if {@code jobs} is non-empty and {@code batchSize} is below 1
 */
private List<List<EmbeddingJob>> partition(List<EmbeddingJob> jobs, int batchSize) {
    List<List<EmbeddingJob>> batches = new ArrayList<>();
    int total = jobs.size();
    if (total == 0) {
        return batches;
    }
    // A zero step would never advance the cursor and grow the result without bound.
    if (batchSize < 1) {
        throw new IllegalArgumentException("batchSize must be at least 1 but was " + batchSize);
    }
    int start = 0;
    while (start < total) {
        // Computing the end from the remaining count avoids int overflow of start + batchSize.
        int end = start + Math.min(batchSize, total - start);
        batches.add(jobs.subList(start, end));
        start = end;
    }
    return batches;
}
private void processClaimedBatchSafely(List<EmbeddingJob> jobs, EmbeddingModelDescriptor model) {
try {

View File

@ -139,7 +139,7 @@ class RepresentationEmbeddingOrchestratorTest {
.attemptCount(1)
.build();
when(jobService.claimNextReadyJobs(10)).thenReturn(List.of(job1, job2));
when(jobService.claimNextReadyJobs(10)).thenReturn(List.of(job1, job2), List.of());
Document document = Document.builder().id(documentId).build();
when(representationRepository.findById(representationId1)).thenReturn(Optional.of(
@ -196,7 +196,7 @@ class RepresentationEmbeddingOrchestratorTest {
.attemptCount(1)
.build();
when(jobService.claimNextReadyJobs(10)).thenReturn(List.of(job));
when(jobService.claimNextReadyJobs(10)).thenReturn(List.of(job), List.of());
Document document = Document.builder().id(documentId).build();
when(representationRepository.findById(representationId)).thenReturn(Optional.of(
DocumentTextRepresentation.builder().id(representationId).document(document).textBody("gamma").build()
@ -225,7 +225,7 @@ class RepresentationEmbeddingOrchestratorTest {
}
@Test
void processNextReadyBatch_should_claim_multiple_batches_and_start_them_in_parallel() {
void processNextReadyBatch_should_keep_claiming_new_batches_until_all_parallel_slots_are_exhausted() {
properties.getJobs().setParallelBatchCount(2);
properties.getJobs().setBatchSize(2);
properties.getJobs().setProcessInBatches(true);
@ -236,23 +236,35 @@ class RepresentationEmbeddingOrchestratorTest {
UUID representationId2 = UUID.randomUUID();
UUID representationId3 = UUID.randomUUID();
UUID representationId4 = UUID.randomUUID();
UUID representationId5 = UUID.randomUUID();
UUID representationId6 = UUID.randomUUID();
UUID embeddingId1 = UUID.randomUUID();
UUID embeddingId2 = UUID.randomUUID();
UUID embeddingId3 = UUID.randomUUID();
UUID embeddingId4 = UUID.randomUUID();
UUID embeddingId5 = UUID.randomUUID();
UUID embeddingId6 = UUID.randomUUID();
EmbeddingJob job1 = EmbeddingJob.builder().id(UUID.randomUUID()).documentId(documentId).representationId(representationId1).modelKey("e5-default").jobType(EmbeddingJobType.DOCUMENT_EMBED).status(EmbeddingJobStatus.IN_PROGRESS).attemptCount(1).build();
EmbeddingJob job2 = EmbeddingJob.builder().id(UUID.randomUUID()).documentId(documentId).representationId(representationId2).modelKey("e5-default").jobType(EmbeddingJobType.DOCUMENT_EMBED).status(EmbeddingJobStatus.IN_PROGRESS).attemptCount(1).build();
EmbeddingJob job3 = EmbeddingJob.builder().id(UUID.randomUUID()).documentId(documentId).representationId(representationId3).modelKey("e5-default").jobType(EmbeddingJobType.DOCUMENT_EMBED).status(EmbeddingJobStatus.IN_PROGRESS).attemptCount(1).build();
EmbeddingJob job4 = EmbeddingJob.builder().id(UUID.randomUUID()).documentId(documentId).representationId(representationId4).modelKey("e5-default").jobType(EmbeddingJobType.DOCUMENT_EMBED).status(EmbeddingJobStatus.IN_PROGRESS).attemptCount(1).build();
EmbeddingJob job5 = EmbeddingJob.builder().id(UUID.randomUUID()).documentId(documentId).representationId(representationId5).modelKey("e5-default").jobType(EmbeddingJobType.DOCUMENT_EMBED).status(EmbeddingJobStatus.IN_PROGRESS).attemptCount(1).build();
EmbeddingJob job6 = EmbeddingJob.builder().id(UUID.randomUUID()).documentId(documentId).representationId(representationId6).modelKey("e5-default").jobType(EmbeddingJobType.DOCUMENT_EMBED).status(EmbeddingJobStatus.IN_PROGRESS).attemptCount(1).build();
when(jobService.claimNextReadyJobs(4)).thenReturn(List.of(job1, job2, job3, job4));
when(jobService.claimNextReadyJobs(2))
.thenReturn(List.of(job1, job2))
.thenReturn(List.of(job3, job4))
.thenReturn(List.of(job5, job6))
.thenReturn(List.of());
Document document = Document.builder().id(documentId).build();
when(representationRepository.findById(representationId1)).thenReturn(Optional.of(DocumentTextRepresentation.builder().id(representationId1).document(document).textBody("alpha").build()));
when(representationRepository.findById(representationId2)).thenReturn(Optional.of(DocumentTextRepresentation.builder().id(representationId2).document(document).textBody("beta").build()));
when(representationRepository.findById(representationId3)).thenReturn(Optional.of(DocumentTextRepresentation.builder().id(representationId3).document(document).textBody("gamma").build()));
when(representationRepository.findById(representationId4)).thenReturn(Optional.of(DocumentTextRepresentation.builder().id(representationId4).document(document).textBody("delta").build()));
when(representationRepository.findById(representationId5)).thenReturn(Optional.of(DocumentTextRepresentation.builder().id(representationId5).document(document).textBody("epsilon").build()));
when(representationRepository.findById(representationId6)).thenReturn(Optional.of(DocumentTextRepresentation.builder().id(representationId6).document(document).textBody("zeta").build()));
EmbeddingModelDescriptor model = new EmbeddingModelDescriptor(
"e5-default", "vector-sync-e5", "intfloat/multilingual-e5-large", 3,
@ -263,6 +275,8 @@ class RepresentationEmbeddingOrchestratorTest {
when(executionPersistenceService.startProcessing(representationId2, "e5-default")).thenReturn(embeddingId2);
when(executionPersistenceService.startProcessing(representationId3, "e5-default")).thenReturn(embeddingId3);
when(executionPersistenceService.startProcessing(representationId4, "e5-default")).thenReturn(embeddingId4);
when(executionPersistenceService.startProcessing(representationId5, "e5-default")).thenReturn(embeddingId5);
when(executionPersistenceService.startProcessing(representationId6, "e5-default")).thenReturn(embeddingId6);
when(executionService.embedTexts(eq("e5-default"), eq(EmbeddingUseCase.DOCUMENT), any()))
.thenReturn(new EmbeddingProviderResult(
model,
@ -277,16 +291,25 @@ class RepresentationEmbeddingOrchestratorTest {
List.of(),
"batch-req-2",
25
))
.thenReturn(new EmbeddingProviderResult(
model,
List.of(new float[]{1.3f, 1.4f, 1.5f}, new float[]{1.6f, 1.7f, 1.8f}),
List.of(),
"batch-req-3",
29
));
int processed = orchestrator.processNextReadyBatch();
assertThat(processed).isEqualTo(4);
verify(jobService).claimNextReadyJobs(4);
verify(executionService, times(2)).embedTexts(eq("e5-default"), eq(EmbeddingUseCase.DOCUMENT), any());
assertThat(processed).isEqualTo(6);
verify(jobService, times(4)).claimNextReadyJobs(2);
verify(executionService, times(3)).embedTexts(eq("e5-default"), eq(EmbeddingUseCase.DOCUMENT), any());
verify(executionPersistenceService).completeJob(eq(embeddingId1), aryEq(new float[]{0.1f, 0.2f, 0.3f}), eq(null), eq(job1.getId()), eq("batch-req-1"));
verify(executionPersistenceService).completeJob(eq(embeddingId2), aryEq(new float[]{0.4f, 0.5f, 0.6f}), eq(null), eq(job2.getId()), eq("batch-req-1"));
verify(executionPersistenceService).completeJob(eq(embeddingId3), aryEq(new float[]{0.7f, 0.8f, 0.9f}), eq(null), eq(job3.getId()), eq("batch-req-2"));
verify(executionPersistenceService).completeJob(eq(embeddingId4), aryEq(new float[]{1.0f, 1.1f, 1.2f}), eq(null), eq(job4.getId()), eq("batch-req-2"));
verify(executionPersistenceService).completeJob(eq(embeddingId5), aryEq(new float[]{1.3f, 1.4f, 1.5f}), eq(null), eq(job5.getId()), eq("batch-req-3"));
verify(executionPersistenceService).completeJob(eq(embeddingId6), aryEq(new float[]{1.6f, 1.7f, 1.8f}), eq(null), eq(job6.getId()), eq("batch-req-3"));
}
}