diff --git a/src/main/java/at/procon/dip/embedding/config/EmbeddingProperties.java b/src/main/java/at/procon/dip/embedding/config/EmbeddingProperties.java index ace2954..a7a5328 100644 --- a/src/main/java/at/procon/dip/embedding/config/EmbeddingProperties.java +++ b/src/main/java/at/procon/dip/embedding/config/EmbeddingProperties.java @@ -30,6 +30,14 @@ public class EmbeddingProperties { private Duration readTimeout = Duration.ofSeconds(60); private Map headers = new LinkedHashMap<>(); private Integer dimensions; + private BatchRequestProperties batchRequest = new BatchRequestProperties(); + } + + @Data + public static class BatchRequestProperties { + private boolean truncateText = false; + private int truncateLength = 512; + private int chunkSize = 20; } @Data @@ -59,6 +67,8 @@ public class EmbeddingProperties { public static class JobsProperties { private boolean enabled = false; private int batchSize = 16; + private boolean processInBatches = false; + private int executionBatchSize = 8; private int maxRetries = 5; private Duration initialRetryDelay = Duration.ofSeconds(30); private Duration maxRetryDelay = Duration.ofHours(6); diff --git a/src/main/java/at/procon/dip/embedding/model/ResolvedEmbeddingProviderConfig.java b/src/main/java/at/procon/dip/embedding/model/ResolvedEmbeddingProviderConfig.java index 4d0e8e6..0477dc4 100644 --- a/src/main/java/at/procon/dip/embedding/model/ResolvedEmbeddingProviderConfig.java +++ b/src/main/java/at/procon/dip/embedding/model/ResolvedEmbeddingProviderConfig.java @@ -13,6 +13,9 @@ public record ResolvedEmbeddingProviderConfig( Duration connectTimeout, Duration readTimeout, Map headers, - Integer dimensions + Integer dimensions, + Boolean batchTruncateText, + Integer batchTruncateLength, + Integer batchChunkSize ) { } diff --git a/src/main/java/at/procon/dip/embedding/provider/http/VectorSyncHttpEmbeddingProvider.java b/src/main/java/at/procon/dip/embedding/provider/http/VectorSyncHttpEmbeddingProvider.java index 603d18b..6aa8ae6 100644 --- a/src/main/java/at/procon/dip/embedding/provider/http/VectorSyncHttpEmbeddingProvider.java +++ b/src/main/java/at/procon/dip/embedding/provider/http/VectorSyncHttpEmbeddingProvider.java @@ -8,8 +8,6 @@ import at.procon.dip.embedding.provider.EmbeddingProvider; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.databind.ObjectMapper; import java.io.IOException; -import java.lang.reflect.Field; -import java.lang.reflect.Method; import java.net.http.HttpResponse; import java.util.ArrayList; import java.util.HashMap; @@ -24,16 +22,6 @@ import org.springframework.stereotype.Component; * Supported endpoints: * POST {baseUrl}/vector-sync - single text * POST {baseUrl}/vectorize-batch - multiple texts - * - * Batch settings are resolved from provider config if present, otherwise defaults are used. - * Supported property keys: - * - vectorize-batch.truncate-text - * - vectorize-batch.truncate-length - * - vectorize-batch.chunk-size - * - * Also accepted as fallbacks: - * - truncate_text / truncate-length / chunk_size - * - truncateText / truncateLength / chunkSize */ @Component public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProviderSupport implements EmbeddingProvider { @@ -105,7 +93,7 @@ public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProvid try { return request.texts().size() == 1 ? executeSingle(providerConfig, model, request.texts().getFirst()) - : executeBatch(providerConfig, model, request.texts()); + : executeBatch(providerConfig, model, request); } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new IllegalStateException("Embedding provider call interrupted", e); @@ -137,22 +125,20 @@ public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProvid private EmbeddingProviderResult executeBatch(ResolvedEmbeddingProviderConfig providerConfig, EmbeddingModelDescriptor model, - List texts) throws IOException, InterruptedException { - boolean truncateText = resolveBooleanProperty(providerConfig, TRUNCATE_TEXT_KEYS, DEFAULT_TRUNCATE_TEXT); - int truncateLength = resolveIntProperty(providerConfig, TRUNCATE_LENGTH_KEYS, DEFAULT_TRUNCATE_LENGTH); - int chunkSize = resolveIntProperty(providerConfig, CHUNK_SIZE_KEYS, DEFAULT_CHUNK_SIZE); + EmbeddingRequest request) throws IOException, InterruptedException { + BatchRequestSettings settings = resolveBatchRequestSettings(providerConfig, request.providerOptions()); - if (truncateLength <= 0) { + if (settings.truncateLength() <= 0) { throw new IllegalArgumentException("Batch truncate length must be > 0"); } - if (chunkSize <= 0) { + if (settings.chunkSize() <= 0) { throw new IllegalArgumentException("Batch chunk size must be > 0"); } - List requestOrder = new ArrayList<>(texts.size()); - List items = new ArrayList<>(texts.size()); + List requestOrder = new ArrayList<>(request.texts().size()); + List items = new ArrayList<>(request.texts().size()); - for (String text : texts) { + for (String text : request.texts()) { String id = UUID.randomUUID().toString(); requestOrder.add(id); items.add(new VectorizeBatchItemRequest(id, text)); @@ -163,9 +149,9 @@ public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProvid "/vectorize-batch", new VectorizeBatchRequest( model.providerModelKey(), - truncateText, - truncateLength, - chunkSize, + settings.truncateText(), + settings.truncateLength(), + settings.chunkSize(), items ) ); @@ -180,7 +166,7 @@ public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProvid resultById.put(result.id, result); } - List vectors = new ArrayList<>(texts.size()); + List vectors = new ArrayList<>(request.texts().size()); int totalTokenCount = 0; boolean hasAnyTokenCount = false; @@ -207,38 +193,35 @@ public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProvid ); } - private float[] extractVector(List vector, - List combinedVector, - EmbeddingModelDescriptor model) { - float[] resolved; - - if (combinedVector != null && !combinedVector.isEmpty()) { - resolved = toArray(combinedVector); - } else if (vector != null && !vector.isEmpty()) { - resolved = toArray(vector); - } else { - throw new IllegalStateException("Embedding provider returned no vector"); - } - - if (model.dimensions() > 0 && resolved.length != model.dimensions()) { - throw new IllegalStateException( - "Embedding provider returned dimension %d for model %s, expected %d" - .formatted(resolved.length, model.modelKey(), model.dimensions()) - ); - } - - return resolved; + private BatchRequestSettings resolveBatchRequestSettings(ResolvedEmbeddingProviderConfig providerConfig, + Map providerOptions) { + boolean truncateText = resolveBooleanOption( + providerOptions, + TRUNCATE_TEXT_KEYS, + providerConfig.batchTruncateText() != null ? providerConfig.batchTruncateText() : DEFAULT_TRUNCATE_TEXT + ); + int truncateLength = resolveIntOption( + providerOptions, + TRUNCATE_LENGTH_KEYS, + providerConfig.batchTruncateLength() != null ? providerConfig.batchTruncateLength() : DEFAULT_TRUNCATE_LENGTH + ); + int chunkSize = resolveIntOption( + providerOptions, + CHUNK_SIZE_KEYS, + providerConfig.batchChunkSize() != null ? providerConfig.batchChunkSize() : DEFAULT_CHUNK_SIZE + ); + return new BatchRequestSettings(truncateText, truncateLength, chunkSize); } - private boolean resolveBooleanProperty(ResolvedEmbeddingProviderConfig providerConfig, - List keys, - boolean defaultValue) { - Object raw = resolveProviderConfigValue(providerConfig, keys); + private boolean resolveBooleanOption(Map providerOptions, + List keys, + boolean defaultValue) { + Object raw = resolveOption(providerOptions, keys); if (raw == null) { return defaultValue; } - if (raw instanceof Boolean b) { - return b; + if (raw instanceof Boolean booleanValue) { + return booleanValue; } String normalized = String.valueOf(raw).trim(); if (normalized.isEmpty()) { @@ -247,15 +230,15 @@ public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProvid return Boolean.parseBoolean(normalized); } - private int resolveIntProperty(ResolvedEmbeddingProviderConfig providerConfig, - List keys, - int defaultValue) { - Object raw = resolveProviderConfigValue(providerConfig, keys); + private int resolveIntOption(Map providerOptions, + List keys, + int defaultValue) { + Object raw = resolveOption(providerOptions, keys); if (raw == null) { return defaultValue; } - if (raw instanceof Number n) { - return n.intValue(); + if (raw instanceof Number number) { + return number.intValue(); } String normalized = String.valueOf(raw).trim(); if (normalized.isEmpty()) { @@ -264,89 +247,42 @@ public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProvid return Integer.parseInt(normalized); } - @SuppressWarnings("unchecked") - private Object resolveProviderConfigValue(ResolvedEmbeddingProviderConfig providerConfig, - List keys) { - List containers = new ArrayList<>(); - containers.add(providerConfig); - - addIfNonNull(containers, invokeNoArg(providerConfig, "properties")); - addIfNonNull(containers, invokeNoArg(providerConfig, "providerProperties")); - addIfNonNull(containers, invokeNoArg(providerConfig, "config")); - addIfNonNull(containers, invokeNoArg(providerConfig, "settings")); - addIfNonNull(containers, invokeNoArg(providerConfig, "options")); - addIfNonNull(containers, readField(providerConfig, "properties")); - addIfNonNull(containers, readField(providerConfig, "providerProperties")); - addIfNonNull(containers, readField(providerConfig, "config")); - addIfNonNull(containers, readField(providerConfig, "settings")); - addIfNonNull(containers, readField(providerConfig, "options")); - - for (Object container : containers) { - if (container instanceof Map map) { - for (String key : keys) { - if (map.containsKey(key)) { - return map.get(key); - } - } - } - - for (String key : keys) { - Object value = invokeStringArg(container, "get", key); - if (value != null) { - return value; - } - value = invokeStringArg(container, "getProperty", key); - if (value != null) { - return value; - } - value = invokeStringArg(container, "property", key); - if (value != null) { - return value; - } - value = invokeStringArg(container, "option", key); - if (value != null) { - return value; - } + private Object resolveOption(Map providerOptions, List keys) { + if (providerOptions == null || providerOptions.isEmpty()) { + return null; + } + for (String key : keys) { + if (providerOptions.containsKey(key)) { + return providerOptions.get(key); } } - return null; } - private void addIfNonNull(List containers, Object value) { - if (value != null) { - containers.add(value); - } - } + private float[] extractVector(List vector, + List combinedVector, + EmbeddingModelDescriptor model) { + float[] resolved; - private Object invokeNoArg(Object target, String methodName) { - try { - Method method = target.getClass().getMethod(methodName); - method.setAccessible(true); - return method.invoke(target); - } catch (Exception ignored) { - return null; + if (combinedVector != null && !combinedVector.isEmpty()) { + resolved = toArray(combinedVector); + } else if (vector != null && !vector.isEmpty()) { + resolved = toArray(vector); + } else { + throw new IllegalStateException("Embedding provider returned no vector"); } - } - private Object invokeStringArg(Object target, String methodName, String arg) { - try { - Method method = target.getClass().getMethod(methodName, String.class); - method.setAccessible(true); - return method.invoke(target, arg); - } catch (Exception ignored) { - return null; + if (model.dimensions() > 0 && resolved.length != model.dimensions()) { + throw new IllegalStateException( + "Embedding provider returned dimension %d for model %s, expected %d" + .formatted(resolved.length, model.modelKey(), model.dimensions()) + ); } + + return resolved; } - private Object readField(Object target, String fieldName) { - try { - Field field = target.getClass().getDeclaredField(fieldName); - field.setAccessible(true); - return field.get(target); - } catch (Exception ignored) { - return null; - } + private record BatchRequestSettings(boolean truncateText, int truncateLength, int chunkSize) { } private record VectorSyncRequest( diff --git a/src/main/java/at/procon/dip/embedding/registry/EmbeddingProviderConfigResolver.java b/src/main/java/at/procon/dip/embedding/registry/EmbeddingProviderConfigResolver.java index 7c3ece8..5d826b4 100644 --- a/src/main/java/at/procon/dip/embedding/registry/EmbeddingProviderConfigResolver.java +++ b/src/main/java/at/procon/dip/embedding/registry/EmbeddingProviderConfigResolver.java @@ -18,6 +18,10 @@ public class EmbeddingProviderConfigResolver { throw new IllegalArgumentException("Unknown embedding provider config key: " + providerConfigKey); } + EmbeddingProperties.BatchRequestProperties batchRequest = provider.getBatchRequest() == null + ? new EmbeddingProperties.BatchRequestProperties() + : provider.getBatchRequest(); + return ResolvedEmbeddingProviderConfig.builder() .key(providerConfigKey) .providerType(provider.getType()) @@ -27,6 +31,9 @@ public class EmbeddingProviderConfigResolver { .readTimeout(provider.getReadTimeout()) .headers(provider.getHeaders() == null ? Map.of() : Map.copyOf(provider.getHeaders())) .dimensions(provider.getDimensions()) + .batchTruncateText(batchRequest.isTruncateText()) + .batchTruncateLength(batchRequest.getTruncateLength()) + .batchChunkSize(batchRequest.getChunkSize()) .build(); } } diff --git a/src/main/java/at/procon/dip/embedding/service/EmbeddingPersistenceService.java b/src/main/java/at/procon/dip/embedding/service/EmbeddingPersistenceService.java index de63639..f26bb04 100644 --- a/src/main/java/at/procon/dip/embedding/service/EmbeddingPersistenceService.java +++ b/src/main/java/at/procon/dip/embedding/service/EmbeddingPersistenceService.java @@ -8,7 +8,6 @@ import at.procon.dip.domain.document.repository.DocumentEmbeddingRepository; import at.procon.dip.domain.document.repository.DocumentTextRepresentationRepository; import at.procon.dip.domain.document.service.DocumentEmbeddingService; import at.procon.dip.embedding.model.EmbeddingProviderResult; -import at.procon.dip.embedding.support.EmbeddingVectorCodec; import java.time.OffsetDateTime; import java.util.UUID; import lombok.RequiredArgsConstructor; @@ -44,11 +43,17 @@ public class EmbeddingPersistenceService { if (result.vectors() == null || result.vectors().isEmpty()) { throw new IllegalArgumentException("Embedding provider result contains no vectors"); } - float[] vector = result.vectors().getFirst(); + saveCompleted(embeddingId, result.vectors().getFirst(), result.tokenCount()); + } + + public void saveCompleted(UUID embeddingId, float[] vector, Integer tokenCount) { + if (vector == null || vector.length == 0) { + throw new IllegalArgumentException("Embedding vector must not be empty"); + } embeddingRepository.updateEmbeddingVector( embeddingId, - vector, //EmbeddingVectorCodec.toPgVector(vector), - result.tokenCount(), + vector, + tokenCount, vector.length ); } diff --git a/src/main/java/at/procon/dip/embedding/service/RepresentationEmbeddingOrchestrator.java b/src/main/java/at/procon/dip/embedding/service/RepresentationEmbeddingOrchestrator.java index 083e793..45fadb4 100644 --- a/src/main/java/at/procon/dip/embedding/service/RepresentationEmbeddingOrchestrator.java +++ b/src/main/java/at/procon/dip/embedding/service/RepresentationEmbeddingOrchestrator.java @@ -7,11 +7,14 @@ import at.procon.dip.embedding.config.EmbeddingProperties; import at.procon.dip.embedding.job.entity.EmbeddingJob; import at.procon.dip.embedding.job.service.EmbeddingJobService; import at.procon.dip.embedding.model.EmbeddingJobType; +import at.procon.dip.embedding.model.EmbeddingModelDescriptor; import at.procon.dip.embedding.model.EmbeddingProviderResult; import at.procon.dip.embedding.model.EmbeddingUseCase; import at.procon.dip.embedding.policy.EmbeddingProfile; import at.procon.dip.embedding.policy.EmbeddingSelectionPolicy; import at.procon.dip.embedding.registry.EmbeddingModelRegistry; +import java.util.ArrayList; +import java.util.LinkedHashMap; import java.util.List; import java.util.UUID; import lombok.RequiredArgsConstructor; @@ -63,25 +66,138 @@ public class RepresentationEmbeddingOrchestrator { } List jobs = jobService.claimNextReadyJobs(embeddingProperties.getJobs().getBatchSize()); - for (EmbeddingJob job : jobs) { - processClaimedJob(job); + if (jobs.isEmpty()) { + return 0; + } + + if (embeddingProperties.getJobs().isProcessInBatches()) { + processClaimedJobsInBatches(jobs); + } else { + jobs.forEach(this::processClaimedJobSafely); } return jobs.size(); } @Transactional public void processClaimedJob(EmbeddingJob job) { + EmbeddingModelDescriptor model = modelRegistry.getRequired(job.getModelKey()); + PreparedEmbedding prepared = prepareEmbedding(job, model); + if (prepared == null) { + return; + } + + try { + EmbeddingProviderResult result = executionService.embedTexts( + job.getModelKey(), + EmbeddingUseCase.DOCUMENT, + List.of(prepared.text()) + ); + persistenceService.saveCompleted(prepared.embeddingId(), result); + jobService.markCompleted(job.getId(), result.providerRequestId()); + } catch (RuntimeException ex) { + persistenceService.markFailed(prepared.embeddingId(), ex.getMessage()); + jobService.markFailed(job.getId(), ex.getMessage(), true); + throw ex; + } + } + + private void processClaimedJobsInBatches(List jobs) { + LinkedHashMap> jobsByModelKey = new LinkedHashMap<>(); + for (EmbeddingJob job : jobs) { + jobsByModelKey.computeIfAbsent(job.getModelKey(), ignored -> new ArrayList<>()).add(job); + } + + int executionBatchSize = Math.max(1, embeddingProperties.getJobs().getExecutionBatchSize()); + for (var entry : jobsByModelKey.entrySet()) { + EmbeddingModelDescriptor model = modelRegistry.getRequired(entry.getKey()); + if (!model.supportsBatch()) { + entry.getValue().forEach(this::processClaimedJobSafely); + continue; + } + + List sameModelJobs = entry.getValue(); + for (int start = 0; start < sameModelJobs.size(); start += executionBatchSize) { + List partition = sameModelJobs.subList(start, Math.min(start + executionBatchSize, sameModelJobs.size())); + if (partition.size() == 1) { + processClaimedJobSafely(partition.getFirst()); + } else { + processClaimedBatchSafely(partition, model); + } + } + } + } + + private void processClaimedBatchSafely(List jobs, EmbeddingModelDescriptor model) { + try { + processClaimedBatch(jobs, model); + } catch (RuntimeException ex) { + log.warn("Failed to process embedding batch for model {} ({} jobs): {}", + model.modelKey(), jobs.size(), ex.getMessage(), ex); + } + } + + private void processClaimedJobSafely(EmbeddingJob job) { + try { + processClaimedJob(job); + } catch (RuntimeException ex) { + log.warn("Failed to process embedding job {} for representation {}: {}", + job.getId(), job.getRepresentationId(), ex.getMessage(), ex); + } + } + + private void processClaimedBatch(List jobs, EmbeddingModelDescriptor model) { + List preparedItems = new ArrayList<>(jobs.size()); + for (EmbeddingJob job : jobs) { + PreparedEmbedding prepared = prepareEmbedding(job, model); + if (prepared != null) { + preparedItems.add(prepared); + } + } + + if (preparedItems.isEmpty()) { + return; + } + + try { + EmbeddingProviderResult result = executionService.embedTexts( + model.modelKey(), + EmbeddingUseCase.DOCUMENT, + preparedItems.stream().map(PreparedEmbedding::text).toList() + ); + + if (result.vectors() == null || result.vectors().size() != preparedItems.size()) { + throw new IllegalStateException( + "Embedding provider returned %d vectors for %d batch items" + .formatted(result.vectors() == null ? 0 : result.vectors().size(), preparedItems.size()) + ); + } + + for (int i = 0; i < preparedItems.size(); i++) { + PreparedEmbedding prepared = preparedItems.get(i); + persistenceService.saveCompleted(prepared.embeddingId(), result.vectors().get(i), null); + jobService.markCompleted(prepared.job().getId(), result.providerRequestId()); + } + } catch (RuntimeException ex) { + for (PreparedEmbedding prepared : preparedItems) { + persistenceService.markFailed(prepared.embeddingId(), ex.getMessage()); + jobService.markFailed(prepared.job().getId(), ex.getMessage(), true); + } + throw ex; + } + } + + private PreparedEmbedding prepareEmbedding(EmbeddingJob job, EmbeddingModelDescriptor model) { DocumentTextRepresentation representation = representationRepository.findById(job.getRepresentationId()) .orElseThrow(() -> new IllegalArgumentException("Unknown representation id: " + job.getRepresentationId())); String text = representation.getTextBody(); if (text == null || text.isBlank()) { jobService.markFailed(job.getId(), "No text representation available", false); - return; + return null; } - int maxChars = modelRegistry.getRequired(job.getModelKey()).maxInputChars() != null - ? modelRegistry.getRequired(job.getModelKey()).maxInputChars() + int maxChars = model.maxInputChars() != null + ? model.maxInputChars() : embeddingProperties.getIndexing().getFallbackMaxInputChars(); if (text.length() > maxChars) { text = text.substring(0, maxChars); @@ -89,19 +205,9 @@ public class RepresentationEmbeddingOrchestrator { DocumentEmbedding embedding = persistenceService.ensurePending(representation.getId(), job.getModelKey()); persistenceService.markProcessing(embedding.getId()); + return new PreparedEmbedding(job, embedding.getId(), text); + } - try { - EmbeddingProviderResult result = executionService.embedTexts( - job.getModelKey(), - EmbeddingUseCase.DOCUMENT, - List.of(text) - ); - persistenceService.saveCompleted(embedding.getId(), result); - jobService.markCompleted(job.getId(), result.providerRequestId()); - } catch (RuntimeException ex) { - persistenceService.markFailed(embedding.getId(), ex.getMessage()); - jobService.markFailed(job.getId(), ex.getMessage(), true); - throw ex; - } + private record PreparedEmbedding(EmbeddingJob job, UUID embeddingId, String text) { } } diff --git a/src/main/resources/application-new-example-vector-sync-provider.yml b/src/main/resources/application-new-example-vector-sync-provider.yml index 892e940..434d353 100644 --- a/src/main/resources/application-new-example-vector-sync-provider.yml +++ b/src/main/resources/application-new-example-vector-sync-provider.yml @@ -1,6 +1,12 @@ dip: embedding: enabled: true + + jobs: + enabled: true + process-in-batches: true + execution-batch-size: 20 + default-document-model: e5-default default-query-model: e5-default @@ -12,6 +18,10 @@ dip: read-timeout: 60s headers: X-Client: dip + batch-request: + truncate-text: false + truncate-length: 512 + chunk-size: 20 models: e5-default: @@ -20,6 +30,6 @@ dip: dimensions: 1024 distance-metric: COSINE supports-query-embedding-mode: true - supports-batch: false + supports-batch: true max-input-chars: 8192 active: true diff --git a/src/test/java/at/procon/dip/embedding/provider/http/VectorSyncHttpEmbeddingProviderTest.java b/src/test/java/at/procon/dip/embedding/provider/http/VectorSyncHttpEmbeddingProviderTest.java index d794d3f..889e99a 100644 --- a/src/test/java/at/procon/dip/embedding/provider/http/VectorSyncHttpEmbeddingProviderTest.java +++ b/src/test/java/at/procon/dip/embedding/provider/http/VectorSyncHttpEmbeddingProviderTest.java @@ -7,6 +7,7 @@ import at.procon.dip.embedding.model.EmbeddingModelDescriptor; import at.procon.dip.embedding.model.EmbeddingRequest; import at.procon.dip.embedding.model.EmbeddingUseCase; import at.procon.dip.embedding.model.ResolvedEmbeddingProviderConfig; +import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.sun.net.httpserver.HttpExchange; import com.sun.net.httpserver.HttpServer; @@ -17,12 +18,15 @@ import java.nio.charset.StandardCharsets; import java.time.Duration; import java.util.List; import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; class VectorSyncHttpEmbeddingProviderTest { + private final ObjectMapper objectMapper = new ObjectMapper(); private HttpServer server; + private final AtomicReference lastBatchBody = new AtomicReference<>(); @AfterEach void tearDown() { @@ -37,13 +41,16 @@ class VectorSyncHttpEmbeddingProviderTest { server.createContext("/vector-sync", this::handleVectorSync); server.start(); - var provider = new VectorSyncHttpEmbeddingProvider(new ObjectMapper()); + var provider = new VectorSyncHttpEmbeddingProvider(objectMapper); var config = ResolvedEmbeddingProviderConfig.builder() .key("vector-sync-local") .providerType("http-vector-sync") .baseUrl("http://localhost:" + server.getAddress().getPort()) .readTimeout(Duration.ofSeconds(5)) .headers(Map.of("X-Client", "dip-test")) + .batchTruncateText(false) + .batchTruncateLength(512) + .batchChunkSize(20) .build(); var model = new EmbeddingModelDescriptor( "e5-default", @@ -70,6 +77,100 @@ class VectorSyncHttpEmbeddingProviderTest { assertThat(result.tokenCount()).isEqualTo(9); } + @Test + void shouldCallVectorizeBatchEndpointUsingConfiguredBatchRequestSettings() throws Exception { + server = HttpServer.create(new InetSocketAddress(0), 0); + server.createContext("/vectorize-batch", this::handleVectorizeBatch); + server.start(); + + var provider = new VectorSyncHttpEmbeddingProvider(objectMapper); + var config = ResolvedEmbeddingProviderConfig.builder() + .key("vector-sync-local") + .providerType("http-vector-sync") + .baseUrl("http://localhost:" + server.getAddress().getPort()) + .readTimeout(Duration.ofSeconds(5)) + .headers(Map.of("X-Client", "dip-test")) + .batchTruncateText(true) + .batchTruncateLength(768) + .batchChunkSize(33) + .build(); + var model = new EmbeddingModelDescriptor( + "e5-default", + "vector-sync-local", + "intfloat/multilingual-e5-large", + 3, + DistanceMetric.COSINE, + true, + true, + 8192, + true + ); + var request = EmbeddingRequest.builder() + .modelKey("e5-default") + .useCase(EmbeddingUseCase.DOCUMENT) + .texts(List.of("First text", "Second text")) + .providerOptions(Map.of()) + .build(); + + var result = provider.embedDocuments(config, model, request); + + assertThat(result.vectors()).hasSize(2); + assertThat(result.vectors().get(0)).containsExactly(0.1f, 0.2f, 0.3f); + assertThat(result.vectors().get(1)).containsExactly(0.4f, 0.5f, 0.6f); + assertThat(result.tokenCount()).isEqualTo(12); + + JsonNode requestBody = objectMapper.readTree(lastBatchBody.get()); + assertThat(requestBody.get("truncate_text").asBoolean()).isTrue(); + assertThat(requestBody.get("truncate_length").asInt()).isEqualTo(768); + assertThat(requestBody.get("chunk_size").asInt()).isEqualTo(33); + } + + @Test + void shouldAllowPerRequestBatchOverrides() throws Exception { + server = HttpServer.create(new InetSocketAddress(0), 0); + server.createContext("/vectorize-batch", this::handleVectorizeBatch); + server.start(); + + var provider = new VectorSyncHttpEmbeddingProvider(objectMapper); + var config = ResolvedEmbeddingProviderConfig.builder() + .key("vector-sync-local") + .providerType("http-vector-sync") + .baseUrl("http://localhost:" + server.getAddress().getPort()) + .readTimeout(Duration.ofSeconds(5)) + .batchTruncateText(false) + .batchTruncateLength(512) + .batchChunkSize(20) + .build(); + var model = new EmbeddingModelDescriptor( + "e5-default", + "vector-sync-local", + "intfloat/multilingual-e5-large", + 3, + DistanceMetric.COSINE, + true, + true, + 8192, + true + ); + var request = EmbeddingRequest.builder() + .modelKey("e5-default") + .useCase(EmbeddingUseCase.DOCUMENT) + .texts(List.of("First text", "Second text")) + .providerOptions(Map.of( + "truncate_text", true, + "truncate_length", 1024, + "chunk_size", 44 + )) + .build(); + + provider.embedDocuments(config, model, request); + + JsonNode requestBody = objectMapper.readTree(lastBatchBody.get()); + assertThat(requestBody.get("truncate_text").asBoolean()).isTrue(); + assertThat(requestBody.get("truncate_length").asInt()).isEqualTo(1024); + assertThat(requestBody.get("chunk_size").asInt()).isEqualTo(44); + } + private void handleVectorSync(HttpExchange exchange) throws IOException { String body; try (InputStream in = exchange.getRequestBody()) { @@ -81,16 +182,69 @@ class VectorSyncHttpEmbeddingProviderTest { assertThat(body).contains("\"text\":\"This is a sample text to vectorize\""); assertThat(exchange.getRequestHeaders().getFirst("X-Client")).isEqualTo("dip-test"); - String json = "{" - + "\"runtime_ms\":472.49," - + "\"vector\":[0.1,0.2,0.3]," - + "\"incomplete\":false," - + "\"combined_vector\":null," - + "\"token_count\":9," - + "\"model\":\"intfloat/multilingual-e5-large\"," - + "\"max_seq_length\":512" - + "}"; + respondJson(exchange, """ + { + "runtime_ms": 472.49, + "vector": [0.1, 0.2, 0.3], + "incomplete": false, + "combined_vector": null, + "token_count": 9, + "model": "intfloat/multilingual-e5-large", + "max_seq_length": 512 + } + """); + } + + private void handleVectorizeBatch(HttpExchange exchange) throws IOException { + String body; + try (InputStream in = exchange.getRequestBody()) { + body = new String(in.readAllBytes(), StandardCharsets.UTF_8); + } + lastBatchBody.set(body); + + assertThat(exchange.getRequestMethod()).isEqualTo("POST"); + JsonNode requestBody = objectMapper.readTree(body); + assertThat(requestBody.get("model").asText()).isEqualTo("intfloat/multilingual-e5-large"); + assertThat(requestBody.get("items")).hasSize(2); + + String firstId = requestBody.get("items").get(0).get("id").asText(); + String secondId = requestBody.get("items").get(1).get("id").asText(); + + respondJson(exchange, """ + { + "model": "intfloat/multilingual-e5-large", + "count": 2, + "results": [ + { + "id": "%s", + "vector": [0.1, 0.2, 0.3], + "token_count": 5, + "runtime_ms": 0.0, + "incomplete": false, + "combined_vector": null, + "truncated": false, + "truncate_length": 512, + "model": "intfloat/multilingual-e5-large", + "max_seq_length": 512 + }, + { + "id": "%s", + "vector": [0.4, 0.5, 0.6], + "token_count": 7, + "runtime_ms": 0.0, + "incomplete": false, + "combined_vector": null, + "truncated": false, + "truncate_length": 512, + "model": "intfloat/multilingual-e5-large", + "max_seq_length": 512 + } + ] + } + """.formatted(firstId, secondId)); + } + private void respondJson(HttpExchange exchange, String json) throws IOException { byte[] response = json.getBytes(StandardCharsets.UTF_8); exchange.getResponseHeaders().add("Content-Type", "application/json"); exchange.sendResponseHeaders(200, response.length); diff --git a/src/test/java/at/procon/dip/embedding/service/RepresentationEmbeddingOrchestratorTest.java b/src/test/java/at/procon/dip/embedding/service/RepresentationEmbeddingOrchestratorTest.java index a4f6bb2..c2050c8 100644 --- a/src/test/java/at/procon/dip/embedding/service/RepresentationEmbeddingOrchestratorTest.java +++ b/src/test/java/at/procon/dip/embedding/service/RepresentationEmbeddingOrchestratorTest.java @@ -1,10 +1,6 @@ package at.procon.dip.embedding.service; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - +import at.procon.dip.domain.document.DistanceMetric; import at.procon.dip.domain.document.entity.Document; import at.procon.dip.domain.document.entity.DocumentEmbedding; import at.procon.dip.domain.document.entity.DocumentTextRepresentation; @@ -12,20 +8,25 @@ import at.procon.dip.domain.document.repository.DocumentTextRepresentationReposi import at.procon.dip.embedding.config.EmbeddingProperties; import at.procon.dip.embedding.job.entity.EmbeddingJob; import at.procon.dip.embedding.job.service.EmbeddingJobService; -import at.procon.dip.embedding.model.EmbeddingJobStatus; -import at.procon.dip.embedding.model.EmbeddingJobType; -import at.procon.dip.embedding.model.EmbeddingModelDescriptor; -import at.procon.dip.embedding.model.EmbeddingProviderResult; +import at.procon.dip.embedding.model.*; +import at.procon.dip.embedding.policy.EmbeddingSelectionPolicy; import at.procon.dip.embedding.registry.EmbeddingModelRegistry; -import at.procon.dip.domain.document.DistanceMetric; -import java.util.List; -import java.util.UUID; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import java.util.List; +import java.util.Optional; +import java.util.UUID; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.AdditionalMatchers.aryEq; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; + @ExtendWith(MockitoExtension.class) class RepresentationEmbeddingOrchestratorTest { @@ -38,22 +39,27 @@ class RepresentationEmbeddingOrchestratorTest { @Mock private DocumentTextRepresentationRepository representationRepository; @Mock + private EmbeddingSelectionPolicy selectionPolicy; + @Mock private EmbeddingModelRegistry modelRegistry; + private EmbeddingProperties properties; private RepresentationEmbeddingOrchestrator orchestrator; @BeforeEach void setUp() { - EmbeddingProperties properties = new EmbeddingProperties(); + properties = new EmbeddingProperties(); properties.setEnabled(true); properties.getJobs().setEnabled(true); properties.getJobs().setBatchSize(10); + properties.getJobs().setExecutionBatchSize(10); properties.getIndexing().setFallbackMaxInputChars(8192); orchestrator = new RepresentationEmbeddingOrchestrator( jobService, executionService, persistenceService, representationRepository, + selectionPolicy, modelRegistry, properties ); @@ -80,16 +86,17 @@ class RepresentationEmbeddingOrchestratorTest { .document(document) .textBody("District heating optimization strategy") .build(); - when(representationRepository.findById(representationId)).thenReturn(java.util.Optional.of(representation)); - when(modelRegistry.getRequired("mock-search")) - .thenReturn(new EmbeddingModelDescriptor("mock-search", "mock-provider", "mock-search", 16, - DistanceMetric.COSINE, true, false, 4096, true)); + EmbeddingModelDescriptor model = new EmbeddingModelDescriptor( + "mock-search", "mock-provider", "mock-search", 16, + DistanceMetric.COSINE, true, false, 4096, true + ); + when(representationRepository.findById(representationId)).thenReturn(Optional.of(representation)); + when(modelRegistry.getRequired("mock-search")).thenReturn(model); when(persistenceService.ensurePending(representationId, "mock-search")) .thenReturn(DocumentEmbedding.builder().id(embeddingId).build()); - when(executionService.embedTexts(eq("mock-search"), any(), any())) + when(executionService.embedTexts(eq("mock-search"), eq(EmbeddingUseCase.DOCUMENT), any())) .thenReturn(new EmbeddingProviderResult( - new EmbeddingModelDescriptor("mock-search", "mock-provider", "mock-search", 16, - DistanceMetric.COSINE, true, false, 4096, true), + model, List.of(new float[]{0.1f, 0.2f}), List.of(), "req-1", @@ -102,4 +109,121 @@ class RepresentationEmbeddingOrchestratorTest { verify(persistenceService).saveCompleted(eq(embeddingId), any(EmbeddingProviderResult.class)); verify(jobService).markCompleted(job.getId(), "req-1"); } + + @Test + void processNextReadyBatch_should_group_batchable_jobs_and_call_provider_once() { + properties.getJobs().setProcessInBatches(true); + properties.getJobs().setExecutionBatchSize(10); + + UUID documentId = UUID.randomUUID(); + UUID representationId1 = UUID.randomUUID(); + UUID representationId2 = UUID.randomUUID(); + UUID embeddingId1 = UUID.randomUUID(); + UUID embeddingId2 = UUID.randomUUID(); + + EmbeddingJob job1 = EmbeddingJob.builder() + .id(UUID.randomUUID()) + .documentId(documentId) + .representationId(representationId1) + .modelKey("e5-default") + .jobType(EmbeddingJobType.DOCUMENT_EMBED) + .status(EmbeddingJobStatus.IN_PROGRESS) + .attemptCount(1) + .build(); + EmbeddingJob job2 = EmbeddingJob.builder() + .id(UUID.randomUUID()) + .documentId(documentId) + .representationId(representationId2) + .modelKey("e5-default") + .jobType(EmbeddingJobType.DOCUMENT_EMBED) + .status(EmbeddingJobStatus.IN_PROGRESS) + .attemptCount(1) + .build(); + + when(jobService.claimNextReadyJobs(10)).thenReturn(List.of(job1, job2)); + + Document document = Document.builder().id(documentId).build(); + when(representationRepository.findById(representationId1)).thenReturn(Optional.of( + DocumentTextRepresentation.builder().id(representationId1).document(document).textBody("alpha").build() + )); + when(representationRepository.findById(representationId2)).thenReturn(Optional.of( + DocumentTextRepresentation.builder().id(representationId2).document(document).textBody("beta").build() + )); + + EmbeddingModelDescriptor model = new EmbeddingModelDescriptor( + "e5-default", "vector-sync-e5", "intfloat/multilingual-e5-large", 3, + DistanceMetric.COSINE, true, true, 4096, true + ); + when(modelRegistry.getRequired("e5-default")).thenReturn(model); + when(persistenceService.ensurePending(representationId1, "e5-default")) + .thenReturn(DocumentEmbedding.builder().id(embeddingId1).build()); + when(persistenceService.ensurePending(representationId2, "e5-default")) + .thenReturn(DocumentEmbedding.builder().id(embeddingId2).build()); + when(executionService.embedTexts(eq("e5-default"), eq(EmbeddingUseCase.DOCUMENT), any())) + .thenReturn(new EmbeddingProviderResult( + model, + List.of(new float[]{0.1f, 0.2f, 0.3f}, new float[]{0.4f, 0.5f, 0.6f}), + List.of(), + "batch-req-1", + 21 + )); + + int processed = orchestrator.processNextReadyBatch(); + + assertThat(processed).isEqualTo(2); + verify(persistenceService).markProcessing(embeddingId1); + verify(persistenceService).markProcessing(embeddingId2); + ArgumentCaptor> textsCaptor = ArgumentCaptor.forClass(List.class); + verify(executionService, times(1)).embedTexts(eq("e5-default"), eq(EmbeddingUseCase.DOCUMENT), textsCaptor.capture()); + assertThat(textsCaptor.getValue()).containsExactly("alpha", "beta"); + verify(persistenceService).saveCompleted(eq(embeddingId1), aryEq(new float[]{0.1f, 0.2f, 0.3f}), eq(null)); + verify(persistenceService).saveCompleted(eq(embeddingId2), aryEq(new float[]{0.4f, 0.5f, 0.6f}), eq(null)); + verify(jobService).markCompleted(job1.getId(), "batch-req-1"); + verify(jobService).markCompleted(job2.getId(), "batch-req-1"); + } + + @Test + void processNextReadyBatch_should_fall_back_to_single_processing_for_non_batch_model() { + properties.getJobs().setProcessInBatches(true); + + UUID documentId = UUID.randomUUID(); + UUID representationId = UUID.randomUUID(); + UUID embeddingId = UUID.randomUUID(); + EmbeddingJob job = EmbeddingJob.builder() + .id(UUID.randomUUID()) + .documentId(documentId) + .representationId(representationId) + .modelKey("mock-search") + .jobType(EmbeddingJobType.DOCUMENT_EMBED) + .status(EmbeddingJobStatus.IN_PROGRESS) + .attemptCount(1) + .build(); + + when(jobService.claimNextReadyJobs(10)).thenReturn(List.of(job)); + Document document = Document.builder().id(documentId).build(); + when(representationRepository.findById(representationId)).thenReturn(Optional.of( + DocumentTextRepresentation.builder().id(representationId).document(document).textBody("gamma").build() + )); + EmbeddingModelDescriptor model = new EmbeddingModelDescriptor( + "mock-search", "mock-provider", "mock-search", 2, + DistanceMetric.COSINE, true, false, 4096, true + ); + when(modelRegistry.getRequired("mock-search")).thenReturn(model); + when(persistenceService.ensurePending(representationId, "mock-search")) + .thenReturn(DocumentEmbedding.builder().id(embeddingId).build()); + when(executionService.embedTexts(eq("mock-search"), eq(EmbeddingUseCase.DOCUMENT), any())) + .thenReturn(new EmbeddingProviderResult( + model, + List.of(new float[]{0.7f, 0.8f}), + List.of(), + "req-2", + 7 + )); + + orchestrator.processNextReadyBatch(); + + verify(executionService, times(1)).embedTexts(eq("mock-search"), eq(EmbeddingUseCase.DOCUMENT), eq(List.of("gamma"))); + verify(persistenceService, never()).saveCompleted(eq(embeddingId), any(float[].class), eq(null)); + verify(persistenceService).saveCompleted(eq(embeddingId), any(EmbeddingProviderResult.class)); + } }