always request balanced chunking vectorization, vectorize newest embedding jobs first
This commit is contained in:
parent
8fddc2a429
commit
1500e84757
|
|
@ -152,6 +152,7 @@ public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProvid
|
|||
settings.truncateText(),
|
||||
settings.truncateLength(),
|
||||
settings.chunkSize(),
|
||||
true,
|
||||
items
|
||||
)
|
||||
);
|
||||
|
|
@ -296,6 +297,7 @@ public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProvid
|
|||
@JsonProperty("truncate_text") boolean truncateText,
|
||||
@JsonProperty("truncate_length") int truncateLength,
|
||||
@JsonProperty("chunk_size") int chunkSize,
|
||||
@JsonProperty("balance_chunks") boolean balanceChunks,
|
||||
@JsonProperty("items") List<VectorizeBatchItemRequest> items
|
||||
) {
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,8 +16,10 @@ import java.util.ArrayList;
|
|||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.CompletionException;
|
||||
import java.util.concurrent.Executor;
|
||||
import java.util.concurrent.ExecutorCompletionService;
|
||||
import java.util.concurrent.Future;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Qualifier;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
|
@ -80,20 +82,10 @@ public class RepresentationEmbeddingOrchestrator {
|
|||
|
||||
int batchSize = Math.max(1, embeddingProperties.getJobs().getBatchSize());
|
||||
int parallelBatchCount = Math.max(1, embeddingProperties.getJobs().getParallelBatchCount());
|
||||
int claimLimit = batchSize * parallelBatchCount;
|
||||
|
||||
List<EmbeddingJob> jobs = jobService.claimNextReadyJobs(claimLimit);
|
||||
if (jobs.isEmpty()) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
List<List<EmbeddingJob>> claimedBatches = partition(jobs, batchSize);
|
||||
if (claimedBatches.size() == 1) {
|
||||
processClaimedJobBatch(claimedBatches.getFirst());
|
||||
} else {
|
||||
runClaimedBatchesInParallel(claimedBatches);
|
||||
}
|
||||
return jobs.size();
|
||||
return parallelBatchCount == 1
|
||||
? processWithSingleBatchSlot(batchSize)
|
||||
: processWithSlidingParallelBatchWindow(batchSize, parallelBatchCount);
|
||||
}
|
||||
|
||||
public void processClaimedJob(EmbeddingJob job) {
|
||||
|
|
@ -126,6 +118,86 @@ public class RepresentationEmbeddingOrchestrator {
|
|||
}
|
||||
}
|
||||
|
||||
private int processWithSingleBatchSlot(int batchSize) {
|
||||
int claimedTotal = 0;
|
||||
while (true) {
|
||||
List<EmbeddingJob> claimedBatch = jobService.claimNextReadyJobs(batchSize);
|
||||
if (claimedBatch.isEmpty()) {
|
||||
return claimedTotal;
|
||||
}
|
||||
claimedTotal += claimedBatch.size();
|
||||
processClaimedJobBatch(claimedBatch);
|
||||
}
|
||||
}
|
||||
|
||||
private int processWithSlidingParallelBatchWindow(int batchSize, int parallelBatchCount) {
|
||||
ExecutorCompletionService<Integer> completionService = new ExecutorCompletionService<>(embeddingJobProcessingExecutor);
|
||||
int claimedTotal = 0;
|
||||
int activeBatchCount = 0;
|
||||
boolean noMoreReadyJobs = false;
|
||||
|
||||
while (activeBatchCount < parallelBatchCount && !noMoreReadyJobs) {
|
||||
List<EmbeddingJob> claimedBatch = jobService.claimNextReadyJobs(batchSize);
|
||||
if (claimedBatch.isEmpty()) {
|
||||
noMoreReadyJobs = true;
|
||||
} else {
|
||||
claimedTotal += claimedBatch.size();
|
||||
submitClaimedBatch(completionService, claimedBatch);
|
||||
activeBatchCount++;
|
||||
}
|
||||
}
|
||||
|
||||
while (activeBatchCount > 0) {
|
||||
waitForBatchCompletion(takeCompletedBatch(completionService));
|
||||
activeBatchCount--;
|
||||
|
||||
if (!noMoreReadyJobs) {
|
||||
List<EmbeddingJob> claimedBatch = jobService.claimNextReadyJobs(batchSize);
|
||||
if (claimedBatch.isEmpty()) {
|
||||
noMoreReadyJobs = true;
|
||||
} else {
|
||||
claimedTotal += claimedBatch.size();
|
||||
submitClaimedBatch(completionService, claimedBatch);
|
||||
activeBatchCount++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return claimedTotal;
|
||||
}
|
||||
|
||||
private void submitClaimedBatch(ExecutorCompletionService<Integer> completionService,
|
||||
List<EmbeddingJob> claimedBatch) {
|
||||
completionService.submit(() -> {
|
||||
processClaimedJobBatch(claimedBatch);
|
||||
return claimedBatch.size();
|
||||
});
|
||||
}
|
||||
|
||||
private Future<Integer> takeCompletedBatch(ExecutorCompletionService<Integer> completionService) {
|
||||
try {
|
||||
return completionService.take();
|
||||
} catch (InterruptedException ex) {
|
||||
Thread.currentThread().interrupt();
|
||||
throw new IllegalStateException("Interrupted while waiting for an embedding batch slot to become free", ex);
|
||||
}
|
||||
}
|
||||
|
||||
private void waitForBatchCompletion(Future<Integer> future) {
|
||||
try {
|
||||
future.get();
|
||||
} catch (InterruptedException ex) {
|
||||
Thread.currentThread().interrupt();
|
||||
throw new IllegalStateException("Interrupted while waiting for embedding batch completion", ex);
|
||||
} catch (java.util.concurrent.ExecutionException ex) {
|
||||
Throwable cause = ex.getCause();
|
||||
if (cause instanceof RuntimeException runtimeException) {
|
||||
throw runtimeException;
|
||||
}
|
||||
throw new CompletionException(cause);
|
||||
}
|
||||
}
|
||||
|
||||
private void processClaimedJobBatch(List<EmbeddingJob> jobs) {
|
||||
if (embeddingProperties.getJobs().isProcessInBatches()) {
|
||||
processClaimedJobsInExecutionBatches(jobs);
|
||||
|
|
@ -134,17 +206,6 @@ public class RepresentationEmbeddingOrchestrator {
|
|||
}
|
||||
}
|
||||
|
||||
private void runClaimedBatchesInParallel(List<List<EmbeddingJob>> claimedBatches) {
|
||||
List<CompletableFuture<Void>> futures = claimedBatches.stream()
|
||||
.map(batch -> CompletableFuture.runAsync(
|
||||
() -> processClaimedJobBatch(batch),
|
||||
embeddingJobProcessingExecutor
|
||||
))
|
||||
.toList();
|
||||
|
||||
futures.forEach(CompletableFuture::join);
|
||||
}
|
||||
|
||||
private void processClaimedJobsInExecutionBatches(List<EmbeddingJob> jobs) {
|
||||
LinkedHashMap<String, List<EmbeddingJob>> jobsByModelKey = new LinkedHashMap<>();
|
||||
for (EmbeddingJob job : jobs) {
|
||||
|
|
@ -171,13 +232,6 @@ public class RepresentationEmbeddingOrchestrator {
|
|||
}
|
||||
}
|
||||
|
||||
private List<List<EmbeddingJob>> partition(List<EmbeddingJob> jobs, int batchSize) {
|
||||
List<List<EmbeddingJob>> batches = new ArrayList<>();
|
||||
for (int start = 0; start < jobs.size(); start += batchSize) {
|
||||
batches.add(jobs.subList(start, Math.min(start + batchSize, jobs.size())));
|
||||
}
|
||||
return batches;
|
||||
}
|
||||
|
||||
private void processClaimedBatchSafely(List<EmbeddingJob> jobs, EmbeddingModelDescriptor model) {
|
||||
try {
|
||||
|
|
|
|||
|
|
@ -139,7 +139,7 @@ class RepresentationEmbeddingOrchestratorTest {
|
|||
.attemptCount(1)
|
||||
.build();
|
||||
|
||||
when(jobService.claimNextReadyJobs(10)).thenReturn(List.of(job1, job2));
|
||||
when(jobService.claimNextReadyJobs(10)).thenReturn(List.of(job1, job2), List.of());
|
||||
|
||||
Document document = Document.builder().id(documentId).build();
|
||||
when(representationRepository.findById(representationId1)).thenReturn(Optional.of(
|
||||
|
|
@ -196,7 +196,7 @@ class RepresentationEmbeddingOrchestratorTest {
|
|||
.attemptCount(1)
|
||||
.build();
|
||||
|
||||
when(jobService.claimNextReadyJobs(10)).thenReturn(List.of(job));
|
||||
when(jobService.claimNextReadyJobs(10)).thenReturn(List.of(job), List.of());
|
||||
Document document = Document.builder().id(documentId).build();
|
||||
when(representationRepository.findById(representationId)).thenReturn(Optional.of(
|
||||
DocumentTextRepresentation.builder().id(representationId).document(document).textBody("gamma").build()
|
||||
|
|
@ -225,7 +225,7 @@ class RepresentationEmbeddingOrchestratorTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
void processNextReadyBatch_should_claim_multiple_batches_and_start_them_in_parallel() {
|
||||
void processNextReadyBatch_should_keep_claiming_new_batches_until_all_parallel_slots_are_exhausted() {
|
||||
properties.getJobs().setParallelBatchCount(2);
|
||||
properties.getJobs().setBatchSize(2);
|
||||
properties.getJobs().setProcessInBatches(true);
|
||||
|
|
@ -236,23 +236,35 @@ class RepresentationEmbeddingOrchestratorTest {
|
|||
UUID representationId2 = UUID.randomUUID();
|
||||
UUID representationId3 = UUID.randomUUID();
|
||||
UUID representationId4 = UUID.randomUUID();
|
||||
UUID representationId5 = UUID.randomUUID();
|
||||
UUID representationId6 = UUID.randomUUID();
|
||||
UUID embeddingId1 = UUID.randomUUID();
|
||||
UUID embeddingId2 = UUID.randomUUID();
|
||||
UUID embeddingId3 = UUID.randomUUID();
|
||||
UUID embeddingId4 = UUID.randomUUID();
|
||||
UUID embeddingId5 = UUID.randomUUID();
|
||||
UUID embeddingId6 = UUID.randomUUID();
|
||||
|
||||
EmbeddingJob job1 = EmbeddingJob.builder().id(UUID.randomUUID()).documentId(documentId).representationId(representationId1).modelKey("e5-default").jobType(EmbeddingJobType.DOCUMENT_EMBED).status(EmbeddingJobStatus.IN_PROGRESS).attemptCount(1).build();
|
||||
EmbeddingJob job2 = EmbeddingJob.builder().id(UUID.randomUUID()).documentId(documentId).representationId(representationId2).modelKey("e5-default").jobType(EmbeddingJobType.DOCUMENT_EMBED).status(EmbeddingJobStatus.IN_PROGRESS).attemptCount(1).build();
|
||||
EmbeddingJob job3 = EmbeddingJob.builder().id(UUID.randomUUID()).documentId(documentId).representationId(representationId3).modelKey("e5-default").jobType(EmbeddingJobType.DOCUMENT_EMBED).status(EmbeddingJobStatus.IN_PROGRESS).attemptCount(1).build();
|
||||
EmbeddingJob job4 = EmbeddingJob.builder().id(UUID.randomUUID()).documentId(documentId).representationId(representationId4).modelKey("e5-default").jobType(EmbeddingJobType.DOCUMENT_EMBED).status(EmbeddingJobStatus.IN_PROGRESS).attemptCount(1).build();
|
||||
EmbeddingJob job5 = EmbeddingJob.builder().id(UUID.randomUUID()).documentId(documentId).representationId(representationId5).modelKey("e5-default").jobType(EmbeddingJobType.DOCUMENT_EMBED).status(EmbeddingJobStatus.IN_PROGRESS).attemptCount(1).build();
|
||||
EmbeddingJob job6 = EmbeddingJob.builder().id(UUID.randomUUID()).documentId(documentId).representationId(representationId6).modelKey("e5-default").jobType(EmbeddingJobType.DOCUMENT_EMBED).status(EmbeddingJobStatus.IN_PROGRESS).attemptCount(1).build();
|
||||
|
||||
when(jobService.claimNextReadyJobs(4)).thenReturn(List.of(job1, job2, job3, job4));
|
||||
when(jobService.claimNextReadyJobs(2))
|
||||
.thenReturn(List.of(job1, job2))
|
||||
.thenReturn(List.of(job3, job4))
|
||||
.thenReturn(List.of(job5, job6))
|
||||
.thenReturn(List.of());
|
||||
|
||||
Document document = Document.builder().id(documentId).build();
|
||||
when(representationRepository.findById(representationId1)).thenReturn(Optional.of(DocumentTextRepresentation.builder().id(representationId1).document(document).textBody("alpha").build()));
|
||||
when(representationRepository.findById(representationId2)).thenReturn(Optional.of(DocumentTextRepresentation.builder().id(representationId2).document(document).textBody("beta").build()));
|
||||
when(representationRepository.findById(representationId3)).thenReturn(Optional.of(DocumentTextRepresentation.builder().id(representationId3).document(document).textBody("gamma").build()));
|
||||
when(representationRepository.findById(representationId4)).thenReturn(Optional.of(DocumentTextRepresentation.builder().id(representationId4).document(document).textBody("delta").build()));
|
||||
when(representationRepository.findById(representationId5)).thenReturn(Optional.of(DocumentTextRepresentation.builder().id(representationId5).document(document).textBody("epsilon").build()));
|
||||
when(representationRepository.findById(representationId6)).thenReturn(Optional.of(DocumentTextRepresentation.builder().id(representationId6).document(document).textBody("zeta").build()));
|
||||
|
||||
EmbeddingModelDescriptor model = new EmbeddingModelDescriptor(
|
||||
"e5-default", "vector-sync-e5", "intfloat/multilingual-e5-large", 3,
|
||||
|
|
@ -263,6 +275,8 @@ class RepresentationEmbeddingOrchestratorTest {
|
|||
when(executionPersistenceService.startProcessing(representationId2, "e5-default")).thenReturn(embeddingId2);
|
||||
when(executionPersistenceService.startProcessing(representationId3, "e5-default")).thenReturn(embeddingId3);
|
||||
when(executionPersistenceService.startProcessing(representationId4, "e5-default")).thenReturn(embeddingId4);
|
||||
when(executionPersistenceService.startProcessing(representationId5, "e5-default")).thenReturn(embeddingId5);
|
||||
when(executionPersistenceService.startProcessing(representationId6, "e5-default")).thenReturn(embeddingId6);
|
||||
when(executionService.embedTexts(eq("e5-default"), eq(EmbeddingUseCase.DOCUMENT), any()))
|
||||
.thenReturn(new EmbeddingProviderResult(
|
||||
model,
|
||||
|
|
@ -277,16 +291,25 @@ class RepresentationEmbeddingOrchestratorTest {
|
|||
List.of(),
|
||||
"batch-req-2",
|
||||
25
|
||||
))
|
||||
.thenReturn(new EmbeddingProviderResult(
|
||||
model,
|
||||
List.of(new float[]{1.3f, 1.4f, 1.5f}, new float[]{1.6f, 1.7f, 1.8f}),
|
||||
List.of(),
|
||||
"batch-req-3",
|
||||
29
|
||||
));
|
||||
|
||||
int processed = orchestrator.processNextReadyBatch();
|
||||
|
||||
assertThat(processed).isEqualTo(4);
|
||||
verify(jobService).claimNextReadyJobs(4);
|
||||
verify(executionService, times(2)).embedTexts(eq("e5-default"), eq(EmbeddingUseCase.DOCUMENT), any());
|
||||
assertThat(processed).isEqualTo(6);
|
||||
verify(jobService, times(4)).claimNextReadyJobs(2);
|
||||
verify(executionService, times(3)).embedTexts(eq("e5-default"), eq(EmbeddingUseCase.DOCUMENT), any());
|
||||
verify(executionPersistenceService).completeJob(eq(embeddingId1), aryEq(new float[]{0.1f, 0.2f, 0.3f}), eq(null), eq(job1.getId()), eq("batch-req-1"));
|
||||
verify(executionPersistenceService).completeJob(eq(embeddingId2), aryEq(new float[]{0.4f, 0.5f, 0.6f}), eq(null), eq(job2.getId()), eq("batch-req-1"));
|
||||
verify(executionPersistenceService).completeJob(eq(embeddingId3), aryEq(new float[]{0.7f, 0.8f, 0.9f}), eq(null), eq(job3.getId()), eq("batch-req-2"));
|
||||
verify(executionPersistenceService).completeJob(eq(embeddingId4), aryEq(new float[]{1.0f, 1.1f, 1.2f}), eq(null), eq(job4.getId()), eq("batch-req-2"));
|
||||
verify(executionPersistenceService).completeJob(eq(embeddingId5), aryEq(new float[]{1.3f, 1.4f, 1.5f}), eq(null), eq(job5.getId()), eq("batch-req-3"));
|
||||
verify(executionPersistenceService).completeJob(eq(embeddingId6), aryEq(new float[]{1.6f, 1.7f, 1.8f}), eq(null), eq(job6.getId()), eq("batch-req-3"));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue