Batch embedding support

master
trifonovt 2 weeks ago
parent 52330d751d
commit 678db76415

@ -30,6 +30,14 @@ public class EmbeddingProperties {
private Duration readTimeout = Duration.ofSeconds(60);
private Map<String, String> headers = new LinkedHashMap<>();
private Integer dimensions;
private BatchRequestProperties batchRequest = new BatchRequestProperties();
}
@Data
// Configured defaults for the provider's batch ("/vectorize-batch") request payload.
// These values are copied into ResolvedEmbeddingProviderConfig by the config resolver.
public static class BatchRequestProperties {
// Sent as the batch request's truncate flag; presumably instructs the provider to truncate over-long texts — TODO confirm provider semantics.
private boolean truncateText = false;
// Sent as the batch request's truncate length (units are provider-defined; not visible here).
private int truncateLength = 512;
// Sent as the batch request's chunk size (provider-side processing granularity; semantics provider-defined).
private int chunkSize = 20;
}
@Data
@ -59,6 +67,8 @@ public class EmbeddingProperties {
public static class JobsProperties {
private boolean enabled = false;
private int batchSize = 16;
private boolean processInBatches = false;
private int executionBatchSize = 8;
private int maxRetries = 5;
private Duration initialRetryDelay = Duration.ofSeconds(30);
private Duration maxRetryDelay = Duration.ofHours(6);

@ -13,6 +13,9 @@ public record ResolvedEmbeddingProviderConfig(
Duration connectTimeout,
Duration readTimeout,
Map<String, String> headers,
Integer dimensions
Integer dimensions,
Boolean batchTruncateText,
Integer batchTruncateLength,
Integer batchChunkSize
) {
}

@ -8,8 +8,6 @@ import at.procon.dip.embedding.provider.EmbeddingProvider;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.net.http.HttpResponse;
import java.util.ArrayList;
import java.util.HashMap;
@ -24,16 +22,6 @@ import org.springframework.stereotype.Component;
* Supported endpoints:
* POST {baseUrl}/vector-sync - single text
* POST {baseUrl}/vectorize-batch - multiple texts
*
* Batch settings are resolved from provider config if present, otherwise defaults are used.
* Supported property keys:
* - vectorize-batch.truncate-text
* - vectorize-batch.truncate-length
* - vectorize-batch.chunk-size
*
* Also accepted as fallbacks:
* - truncate_text / truncate-length / chunk_size
* - truncateText / truncateLength / chunkSize
*/
@Component
public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProviderSupport implements EmbeddingProvider {
@ -105,7 +93,7 @@ public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProvid
try {
return request.texts().size() == 1
? executeSingle(providerConfig, model, request.texts().getFirst())
: executeBatch(providerConfig, model, request.texts());
: executeBatch(providerConfig, model, request);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new IllegalStateException("Embedding provider call interrupted", e);
@ -137,22 +125,20 @@ public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProvid
private EmbeddingProviderResult executeBatch(ResolvedEmbeddingProviderConfig providerConfig,
EmbeddingModelDescriptor model,
List<String> texts) throws IOException, InterruptedException {
boolean truncateText = resolveBooleanProperty(providerConfig, TRUNCATE_TEXT_KEYS, DEFAULT_TRUNCATE_TEXT);
int truncateLength = resolveIntProperty(providerConfig, TRUNCATE_LENGTH_KEYS, DEFAULT_TRUNCATE_LENGTH);
int chunkSize = resolveIntProperty(providerConfig, CHUNK_SIZE_KEYS, DEFAULT_CHUNK_SIZE);
EmbeddingRequest request) throws IOException, InterruptedException {
BatchRequestSettings settings = resolveBatchRequestSettings(providerConfig, request.providerOptions());
if (truncateLength <= 0) {
if (settings.truncateLength() <= 0) {
throw new IllegalArgumentException("Batch truncate length must be > 0");
}
if (chunkSize <= 0) {
if (settings.chunkSize() <= 0) {
throw new IllegalArgumentException("Batch chunk size must be > 0");
}
List<String> requestOrder = new ArrayList<>(texts.size());
List<VectorizeBatchItemRequest> items = new ArrayList<>(texts.size());
List<String> requestOrder = new ArrayList<>(request.texts().size());
List<VectorizeBatchItemRequest> items = new ArrayList<>(request.texts().size());
for (String text : texts) {
for (String text : request.texts()) {
String id = UUID.randomUUID().toString();
requestOrder.add(id);
items.add(new VectorizeBatchItemRequest(id, text));
@ -163,9 +149,9 @@ public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProvid
"/vectorize-batch",
new VectorizeBatchRequest(
model.providerModelKey(),
truncateText,
truncateLength,
chunkSize,
settings.truncateText(),
settings.truncateLength(),
settings.chunkSize(),
items
)
);
@ -180,7 +166,7 @@ public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProvid
resultById.put(result.id, result);
}
List<float[]> vectors = new ArrayList<>(texts.size());
List<float[]> vectors = new ArrayList<>(request.texts().size());
int totalTokenCount = 0;
boolean hasAnyTokenCount = false;
@ -207,38 +193,35 @@ public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProvid
);
}
private float[] extractVector(List<Float> vector,
List<Float> combinedVector,
EmbeddingModelDescriptor model) {
float[] resolved;
if (combinedVector != null && !combinedVector.isEmpty()) {
resolved = toArray(combinedVector);
} else if (vector != null && !vector.isEmpty()) {
resolved = toArray(vector);
} else {
throw new IllegalStateException("Embedding provider returned no vector");
}
if (model.dimensions() > 0 && resolved.length != model.dimensions()) {
throw new IllegalStateException(
"Embedding provider returned dimension %d for model %s, expected %d"
.formatted(resolved.length, model.modelKey(), model.dimensions())
/**
 * Resolves the effective batch request settings for one embedding call.
 *
 * Precedence per setting: per-request provider option (if present under one of the
 * accepted keys) > provider-config value (if non-null) > compiled-in default.
 */
private BatchRequestSettings resolveBatchRequestSettings(ResolvedEmbeddingProviderConfig providerConfig,
Map<String, Object> providerOptions) {
// Request-level override first; provider config supplies the fallback passed as defaultValue.
boolean truncateText = resolveBooleanOption(
providerOptions,
TRUNCATE_TEXT_KEYS,
providerConfig.batchTruncateText() != null ? providerConfig.batchTruncateText() : DEFAULT_TRUNCATE_TEXT
);
int truncateLength = resolveIntOption(
providerOptions,
TRUNCATE_LENGTH_KEYS,
providerConfig.batchTruncateLength() != null ? providerConfig.batchTruncateLength() : DEFAULT_TRUNCATE_LENGTH
);
int chunkSize = resolveIntOption(
providerOptions,
CHUNK_SIZE_KEYS,
providerConfig.batchChunkSize() != null ? providerConfig.batchChunkSize() : DEFAULT_CHUNK_SIZE
);
return new BatchRequestSettings(truncateText, truncateLength, chunkSize);
}
return resolved;
}
private boolean resolveBooleanProperty(ResolvedEmbeddingProviderConfig providerConfig,
private boolean resolveBooleanOption(Map<String, Object> providerOptions,
List<String> keys,
boolean defaultValue) {
Object raw = resolveProviderConfigValue(providerConfig, keys);
Object raw = resolveOption(providerOptions, keys);
if (raw == null) {
return defaultValue;
}
if (raw instanceof Boolean b) {
return b;
if (raw instanceof Boolean booleanValue) {
return booleanValue;
}
String normalized = String.valueOf(raw).trim();
if (normalized.isEmpty()) {
@ -247,15 +230,15 @@ public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProvid
return Boolean.parseBoolean(normalized);
}
private int resolveIntProperty(ResolvedEmbeddingProviderConfig providerConfig,
private int resolveIntOption(Map<String, Object> providerOptions,
List<String> keys,
int defaultValue) {
Object raw = resolveProviderConfigValue(providerConfig, keys);
Object raw = resolveOption(providerOptions, keys);
if (raw == null) {
return defaultValue;
}
if (raw instanceof Number n) {
return n.intValue();
if (raw instanceof Number number) {
return number.intValue();
}
String normalized = String.valueOf(raw).trim();
if (normalized.isEmpty()) {
@ -264,89 +247,42 @@ public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProvid
return Integer.parseInt(normalized);
}
@SuppressWarnings("unchecked")
private Object resolveProviderConfigValue(ResolvedEmbeddingProviderConfig providerConfig,
List<String> keys) {
List<Object> containers = new ArrayList<>();
containers.add(providerConfig);
addIfNonNull(containers, invokeNoArg(providerConfig, "properties"));
addIfNonNull(containers, invokeNoArg(providerConfig, "providerProperties"));
addIfNonNull(containers, invokeNoArg(providerConfig, "config"));
addIfNonNull(containers, invokeNoArg(providerConfig, "settings"));
addIfNonNull(containers, invokeNoArg(providerConfig, "options"));
addIfNonNull(containers, readField(providerConfig, "properties"));
addIfNonNull(containers, readField(providerConfig, "providerProperties"));
addIfNonNull(containers, readField(providerConfig, "config"));
addIfNonNull(containers, readField(providerConfig, "settings"));
addIfNonNull(containers, readField(providerConfig, "options"));
for (Object container : containers) {
if (container instanceof Map<?, ?> map) {
for (String key : keys) {
if (map.containsKey(key)) {
return map.get(key);
}
}
private Object resolveOption(Map<String, Object> providerOptions, List<String> keys) {
if (providerOptions == null || providerOptions.isEmpty()) {
return null;
}
for (String key : keys) {
Object value = invokeStringArg(container, "get", key);
if (value != null) {
return value;
}
value = invokeStringArg(container, "getProperty", key);
if (value != null) {
return value;
}
value = invokeStringArg(container, "property", key);
if (value != null) {
return value;
if (providerOptions.containsKey(key)) {
return providerOptions.get(key);
}
value = invokeStringArg(container, "option", key);
if (value != null) {
return value;
}
}
}
return null;
}
private void addIfNonNull(List<Object> containers, Object value) {
if (value != null) {
containers.add(value);
}
}
private float[] extractVector(List<Float> vector,
List<Float> combinedVector,
EmbeddingModelDescriptor model) {
float[] resolved;
private Object invokeNoArg(Object target, String methodName) {
try {
Method method = target.getClass().getMethod(methodName);
method.setAccessible(true);
return method.invoke(target);
} catch (Exception ignored) {
return null;
}
if (combinedVector != null && !combinedVector.isEmpty()) {
resolved = toArray(combinedVector);
} else if (vector != null && !vector.isEmpty()) {
resolved = toArray(vector);
} else {
throw new IllegalStateException("Embedding provider returned no vector");
}
private Object invokeStringArg(Object target, String methodName, String arg) {
try {
Method method = target.getClass().getMethod(methodName, String.class);
method.setAccessible(true);
return method.invoke(target, arg);
} catch (Exception ignored) {
return null;
}
if (model.dimensions() > 0 && resolved.length != model.dimensions()) {
throw new IllegalStateException(
"Embedding provider returned dimension %d for model %s, expected %d"
.formatted(resolved.length, model.modelKey(), model.dimensions())
);
}
private Object readField(Object target, String fieldName) {
try {
Field field = target.getClass().getDeclaredField(fieldName);
field.setAccessible(true);
return field.get(target);
} catch (Exception ignored) {
return null;
return resolved;
}
// Immutable holder for the resolved per-call batch settings (truncation flag/length and chunk size).
private record BatchRequestSettings(boolean truncateText, int truncateLength, int chunkSize) {
}
private record VectorSyncRequest(

@ -18,6 +18,10 @@ public class EmbeddingProviderConfigResolver {
throw new IllegalArgumentException("Unknown embedding provider config key: " + providerConfigKey);
}
EmbeddingProperties.BatchRequestProperties batchRequest = provider.getBatchRequest() == null
? new EmbeddingProperties.BatchRequestProperties()
: provider.getBatchRequest();
return ResolvedEmbeddingProviderConfig.builder()
.key(providerConfigKey)
.providerType(provider.getType())
@ -27,6 +31,9 @@ public class EmbeddingProviderConfigResolver {
.readTimeout(provider.getReadTimeout())
.headers(provider.getHeaders() == null ? Map.of() : Map.copyOf(provider.getHeaders()))
.dimensions(provider.getDimensions())
.batchTruncateText(batchRequest.isTruncateText())
.batchTruncateLength(batchRequest.getTruncateLength())
.batchChunkSize(batchRequest.getChunkSize())
.build();
}
}

@ -8,7 +8,6 @@ import at.procon.dip.domain.document.repository.DocumentEmbeddingRepository;
import at.procon.dip.domain.document.repository.DocumentTextRepresentationRepository;
import at.procon.dip.domain.document.service.DocumentEmbeddingService;
import at.procon.dip.embedding.model.EmbeddingProviderResult;
import at.procon.dip.embedding.support.EmbeddingVectorCodec;
import java.time.OffsetDateTime;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
@ -44,11 +43,17 @@ public class EmbeddingPersistenceService {
if (result.vectors() == null || result.vectors().isEmpty()) {
throw new IllegalArgumentException("Embedding provider result contains no vectors");
}
float[] vector = result.vectors().getFirst();
saveCompleted(embeddingId, result.vectors().getFirst(), result.tokenCount());
}
public void saveCompleted(UUID embeddingId, float[] vector, Integer tokenCount) {
if (vector == null || vector.length == 0) {
throw new IllegalArgumentException("Embedding vector must not be empty");
}
embeddingRepository.updateEmbeddingVector(
embeddingId,
vector, //EmbeddingVectorCodec.toPgVector(vector),
result.tokenCount(),
vector,
tokenCount,
vector.length
);
}

@ -7,11 +7,14 @@ import at.procon.dip.embedding.config.EmbeddingProperties;
import at.procon.dip.embedding.job.entity.EmbeddingJob;
import at.procon.dip.embedding.job.service.EmbeddingJobService;
import at.procon.dip.embedding.model.EmbeddingJobType;
import at.procon.dip.embedding.model.EmbeddingModelDescriptor;
import at.procon.dip.embedding.model.EmbeddingProviderResult;
import at.procon.dip.embedding.model.EmbeddingUseCase;
import at.procon.dip.embedding.policy.EmbeddingProfile;
import at.procon.dip.embedding.policy.EmbeddingSelectionPolicy;
import at.procon.dip.embedding.registry.EmbeddingModelRegistry;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.UUID;
import lombok.RequiredArgsConstructor;
@ -63,25 +66,138 @@ public class RepresentationEmbeddingOrchestrator {
}
List<EmbeddingJob> jobs = jobService.claimNextReadyJobs(embeddingProperties.getJobs().getBatchSize());
for (EmbeddingJob job : jobs) {
processClaimedJob(job);
if (jobs.isEmpty()) {
return 0;
}
if (embeddingProperties.getJobs().isProcessInBatches()) {
processClaimedJobsInBatches(jobs);
} else {
jobs.forEach(this::processClaimedJobSafely);
}
return jobs.size();
}
@Transactional
// Processes a single claimed job end-to-end: prepare text, call the provider with a
// one-element batch, persist the vector, and transition job state. Runs in its own
// transaction so one job's outcome does not affect siblings.
public void processClaimedJob(EmbeddingJob job) {
EmbeddingModelDescriptor model = modelRegistry.getRequired(job.getModelKey());
PreparedEmbedding prepared = prepareEmbedding(job, model);
if (prepared == null) {
// prepareEmbedding already marked the job failed (e.g. no text representation).
return;
}
try {
EmbeddingProviderResult result = executionService.embedTexts(
job.getModelKey(),
EmbeddingUseCase.DOCUMENT,
List.of(prepared.text())
);
persistenceService.saveCompleted(prepared.embeddingId(), result);
jobService.markCompleted(job.getId(), result.providerRequestId());
} catch (RuntimeException ex) {
// Mark both the embedding row and the job failed; true = eligible for retry.
persistenceService.markFailed(prepared.embeddingId(), ex.getMessage());
jobService.markFailed(job.getId(), ex.getMessage(), true);
throw ex;
}
}
// Groups claimed jobs by model key (LinkedHashMap preserves claim order), then:
// - models without batch support fall back to per-job processing;
// - batch-capable models are partitioned into executionBatchSize slices, where a
//   one-element slice also takes the single-job path (avoids a batch call for one text).
private void processClaimedJobsInBatches(List<EmbeddingJob> jobs) {
LinkedHashMap<String, List<EmbeddingJob>> jobsByModelKey = new LinkedHashMap<>();
for (EmbeddingJob job : jobs) {
jobsByModelKey.computeIfAbsent(job.getModelKey(), ignored -> new ArrayList<>()).add(job);
}
// Guard against a misconfigured (<= 0) batch size.
int executionBatchSize = Math.max(1, embeddingProperties.getJobs().getExecutionBatchSize());
for (var entry : jobsByModelKey.entrySet()) {
EmbeddingModelDescriptor model = modelRegistry.getRequired(entry.getKey());
if (!model.supportsBatch()) {
entry.getValue().forEach(this::processClaimedJobSafely);
continue;
}
List<EmbeddingJob> sameModelJobs = entry.getValue();
for (int start = 0; start < sameModelJobs.size(); start += executionBatchSize) {
List<EmbeddingJob> partition = sameModelJobs.subList(start, Math.min(start + executionBatchSize, sameModelJobs.size()));
if (partition.size() == 1) {
processClaimedJobSafely(partition.getFirst());
} else {
processClaimedBatchSafely(partition, model);
}
}
}
}
// Wraps batch processing so one failed batch is logged but does not abort the poll cycle.
private void processClaimedBatchSafely(List<EmbeddingJob> jobs, EmbeddingModelDescriptor model) {
try {
processClaimedBatch(jobs, model);
} catch (RuntimeException ex) {
// Job/embedding state was already transitioned inside processClaimedBatch; just log here.
log.warn("Failed to process embedding batch for model {} ({} jobs): {}",
model.modelKey(), jobs.size(), ex.getMessage(), ex);
}
}
// Wraps single-job processing so one failed job is logged but does not abort the poll cycle.
private void processClaimedJobSafely(EmbeddingJob job) {
try {
processClaimedJob(job);
} catch (RuntimeException ex) {
// processClaimedJob already marked the job/embedding failed; just log here.
log.warn("Failed to process embedding job {} for representation {}: {}",
job.getId(), job.getRepresentationId(), ex.getMessage(), ex);
}
}
/**
 * Executes one provider call for a batch of claimed jobs that share the same model.
 *
 * Jobs whose representation has no usable text are marked failed inside prepareEmbedding
 * and skipped here. The provider must return exactly one vector per prepared item, in
 * request order; a count mismatch fails the whole batch.
 *
 * Failure handling: the original code marked EVERY prepared item failed in the catch
 * block, even items that had already been saved and markCompleted'd earlier in the same
 * persistence loop — an illegal completed -> failed double transition. We now track how
 * many items finished and only fail the not-yet-completed tail.
 */
private void processClaimedBatch(List<EmbeddingJob> jobs, EmbeddingModelDescriptor model) {
    List<PreparedEmbedding> preparedItems = new ArrayList<>(jobs.size());
    for (EmbeddingJob job : jobs) {
        PreparedEmbedding prepared = prepareEmbedding(job, model);
        if (prepared != null) {
            preparedItems.add(prepared);
        }
    }
    if (preparedItems.isEmpty()) {
        return;
    }
    // Number of items fully persisted (vector saved + job completed) before any failure.
    int completedCount = 0;
    try {
        EmbeddingProviderResult result = executionService.embedTexts(
                model.modelKey(),
                EmbeddingUseCase.DOCUMENT,
                preparedItems.stream().map(PreparedEmbedding::text).toList()
        );
        if (result.vectors() == null || result.vectors().size() != preparedItems.size()) {
            throw new IllegalStateException(
                    "Embedding provider returned %d vectors for %d batch items"
                            .formatted(result.vectors() == null ? 0 : result.vectors().size(), preparedItems.size())
            );
        }
        for (int i = 0; i < preparedItems.size(); i++) {
            PreparedEmbedding prepared = preparedItems.get(i);
            // Per-item token counts are not available from the aggregate batch result; pass null.
            persistenceService.saveCompleted(prepared.embeddingId(), result.vectors().get(i), null);
            jobService.markCompleted(prepared.job().getId(), result.providerRequestId());
            completedCount++;
        }
    } catch (RuntimeException ex) {
        // Only fail items that were not already completed above; re-failing a completed
        // item would clobber its terminal state.
        for (int i = completedCount; i < preparedItems.size(); i++) {
            PreparedEmbedding prepared = preparedItems.get(i);
            persistenceService.markFailed(prepared.embeddingId(), ex.getMessage());
            jobService.markFailed(prepared.job().getId(), ex.getMessage(), true);
        }
        throw ex;
    }
}
private PreparedEmbedding prepareEmbedding(EmbeddingJob job, EmbeddingModelDescriptor model) {
DocumentTextRepresentation representation = representationRepository.findById(job.getRepresentationId())
.orElseThrow(() -> new IllegalArgumentException("Unknown representation id: " + job.getRepresentationId()));
String text = representation.getTextBody();
if (text == null || text.isBlank()) {
jobService.markFailed(job.getId(), "No text representation available", false);
return;
return null;
}
int maxChars = modelRegistry.getRequired(job.getModelKey()).maxInputChars() != null
? modelRegistry.getRequired(job.getModelKey()).maxInputChars()
int maxChars = model.maxInputChars() != null
? model.maxInputChars()
: embeddingProperties.getIndexing().getFallbackMaxInputChars();
if (text.length() > maxChars) {
text = text.substring(0, maxChars);
@ -89,19 +205,9 @@ public class RepresentationEmbeddingOrchestrator {
DocumentEmbedding embedding = persistenceService.ensurePending(representation.getId(), job.getModelKey());
persistenceService.markProcessing(embedding.getId());
try {
EmbeddingProviderResult result = executionService.embedTexts(
job.getModelKey(),
EmbeddingUseCase.DOCUMENT,
List.of(text)
);
persistenceService.saveCompleted(embedding.getId(), result);
jobService.markCompleted(job.getId(), result.providerRequestId());
} catch (RuntimeException ex) {
persistenceService.markFailed(embedding.getId(), ex.getMessage());
jobService.markFailed(job.getId(), ex.getMessage(), true);
throw ex;
return new PreparedEmbedding(job, embedding.getId(), text);
}
// Carrier tying a claimed job to the id of its pending embedding row and the
// (possibly length-limited) text that will be sent to the provider.
private record PreparedEmbedding(EmbeddingJob job, UUID embeddingId, String text) {
}
}

@ -1,6 +1,12 @@
dip:
embedding:
enabled: true
jobs:
enabled: true
process-in-batches: true
execution-batch-size: 20
default-document-model: e5-default
default-query-model: e5-default
@ -12,6 +18,10 @@ dip:
read-timeout: 60s
headers:
X-Client: dip
batch-request:
truncate-text: false
truncate-length: 512
chunk-size: 20
models:
e5-default:
@ -20,6 +30,6 @@ dip:
dimensions: 1024
distance-metric: COSINE
supports-query-embedding-mode: true
supports-batch: false
supports-batch: true
max-input-chars: 8192
active: true

@ -7,6 +7,7 @@ import at.procon.dip.embedding.model.EmbeddingModelDescriptor;
import at.procon.dip.embedding.model.EmbeddingRequest;
import at.procon.dip.embedding.model.EmbeddingUseCase;
import at.procon.dip.embedding.model.ResolvedEmbeddingProviderConfig;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.sun.net.httpserver.HttpExchange;
import com.sun.net.httpserver.HttpServer;
@ -17,12 +18,15 @@ import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
class VectorSyncHttpEmbeddingProviderTest {
private final ObjectMapper objectMapper = new ObjectMapper();
private HttpServer server;
private final AtomicReference<String> lastBatchBody = new AtomicReference<>();
@AfterEach
void tearDown() {
@ -37,13 +41,16 @@ class VectorSyncHttpEmbeddingProviderTest {
server.createContext("/vector-sync", this::handleVectorSync);
server.start();
var provider = new VectorSyncHttpEmbeddingProvider(new ObjectMapper());
var provider = new VectorSyncHttpEmbeddingProvider(objectMapper);
var config = ResolvedEmbeddingProviderConfig.builder()
.key("vector-sync-local")
.providerType("http-vector-sync")
.baseUrl("http://localhost:" + server.getAddress().getPort())
.readTimeout(Duration.ofSeconds(5))
.headers(Map.of("X-Client", "dip-test"))
.batchTruncateText(false)
.batchTruncateLength(512)
.batchChunkSize(20)
.build();
var model = new EmbeddingModelDescriptor(
"e5-default",
@ -70,6 +77,100 @@ class VectorSyncHttpEmbeddingProviderTest {
assertThat(result.tokenCount()).isEqualTo(9);
}
// Verifies that a multi-text request hits /vectorize-batch and that the batch settings
// configured on the provider config (truncate flag/length, chunk size) are serialized
// into the outgoing request body when no per-request overrides are supplied.
@Test
void shouldCallVectorizeBatchEndpointUsingConfiguredBatchRequestSettings() throws Exception {
server = HttpServer.create(new InetSocketAddress(0), 0);
server.createContext("/vectorize-batch", this::handleVectorizeBatch);
server.start();
var provider = new VectorSyncHttpEmbeddingProvider(objectMapper);
var config = ResolvedEmbeddingProviderConfig.builder()
.key("vector-sync-local")
.providerType("http-vector-sync")
.baseUrl("http://localhost:" + server.getAddress().getPort())
.readTimeout(Duration.ofSeconds(5))
.headers(Map.of("X-Client", "dip-test"))
.batchTruncateText(true)
.batchTruncateLength(768)
.batchChunkSize(33)
.build();
var model = new EmbeddingModelDescriptor(
"e5-default",
"vector-sync-local",
"intfloat/multilingual-e5-large",
3,
DistanceMetric.COSINE,
true,
true,
8192,
true
);
var request = EmbeddingRequest.builder()
.modelKey("e5-default")
.useCase(EmbeddingUseCase.DOCUMENT)
.texts(List.of("First text", "Second text"))
.providerOptions(Map.of())
.build();
var result = provider.embedDocuments(config, model, request);
// Vectors come back in request order; token count 12 = 5 + 7 from the stub response.
assertThat(result.vectors()).hasSize(2);
assertThat(result.vectors().get(0)).containsExactly(0.1f, 0.2f, 0.3f);
assertThat(result.vectors().get(1)).containsExactly(0.4f, 0.5f, 0.6f);
assertThat(result.tokenCount()).isEqualTo(12);
// The stub handler recorded the raw body; assert the configured settings were sent.
JsonNode requestBody = objectMapper.readTree(lastBatchBody.get());
assertThat(requestBody.get("truncate_text").asBoolean()).isTrue();
assertThat(requestBody.get("truncate_length").asInt()).isEqualTo(768);
assertThat(requestBody.get("chunk_size").asInt()).isEqualTo(33);
}
// Verifies that per-request providerOptions (truncate_text / truncate_length / chunk_size)
// take precedence over the batch settings configured on the provider config.
@Test
void shouldAllowPerRequestBatchOverrides() throws Exception {
server = HttpServer.create(new InetSocketAddress(0), 0);
server.createContext("/vectorize-batch", this::handleVectorizeBatch);
server.start();
var provider = new VectorSyncHttpEmbeddingProvider(objectMapper);
var config = ResolvedEmbeddingProviderConfig.builder()
.key("vector-sync-local")
.providerType("http-vector-sync")
.baseUrl("http://localhost:" + server.getAddress().getPort())
.readTimeout(Duration.ofSeconds(5))
.batchTruncateText(false)
.batchTruncateLength(512)
.batchChunkSize(20)
.build();
var model = new EmbeddingModelDescriptor(
"e5-default",
"vector-sync-local",
"intfloat/multilingual-e5-large",
3,
DistanceMetric.COSINE,
true,
true,
8192,
true
);
var request = EmbeddingRequest.builder()
.modelKey("e5-default")
.useCase(EmbeddingUseCase.DOCUMENT)
.texts(List.of("First text", "Second text"))
// Overrides that differ from every config default above.
.providerOptions(Map.of(
"truncate_text", true,
"truncate_length", 1024,
"chunk_size", 44
))
.build();
provider.embedDocuments(config, model, request);
// The request body must carry the override values, not the config defaults.
JsonNode requestBody = objectMapper.readTree(lastBatchBody.get());
assertThat(requestBody.get("truncate_text").asBoolean()).isTrue();
assertThat(requestBody.get("truncate_length").asInt()).isEqualTo(1024);
assertThat(requestBody.get("chunk_size").asInt()).isEqualTo(44);
}
private void handleVectorSync(HttpExchange exchange) throws IOException {
String body;
try (InputStream in = exchange.getRequestBody()) {
@ -81,16 +182,69 @@ class VectorSyncHttpEmbeddingProviderTest {
assertThat(body).contains("\"text\":\"This is a sample text to vectorize\"");
assertThat(exchange.getRequestHeaders().getFirst("X-Client")).isEqualTo("dip-test");
String json = "{"
+ "\"runtime_ms\":472.49,"
+ "\"vector\":[0.1,0.2,0.3],"
+ "\"incomplete\":false,"
+ "\"combined_vector\":null,"
+ "\"token_count\":9,"
+ "\"model\":\"intfloat/multilingual-e5-large\","
+ "\"max_seq_length\":512"
+ "}";
respondJson(exchange, """
{
"runtime_ms": 472.49,
"vector": [0.1, 0.2, 0.3],
"incomplete": false,
"combined_vector": null,
"token_count": 9,
"model": "intfloat/multilingual-e5-large",
"max_seq_length": 512
}
""");
}
// Stub handler for POST /vectorize-batch: records the raw request body for later
// assertions, sanity-checks the request shape, and echoes back one result per item
// using the ids the provider generated (so the provider can re-associate vectors).
private void handleVectorizeBatch(HttpExchange exchange) throws IOException {
String body;
try (InputStream in = exchange.getRequestBody()) {
body = new String(in.readAllBytes(), StandardCharsets.UTF_8);
}
// Exposed to the test methods via an AtomicReference field.
lastBatchBody.set(body);
assertThat(exchange.getRequestMethod()).isEqualTo("POST");
JsonNode requestBody = objectMapper.readTree(body);
assertThat(requestBody.get("model").asText()).isEqualTo("intfloat/multilingual-e5-large");
assertThat(requestBody.get("items")).hasSize(2);
// Reuse the caller-generated item ids in the response so correlation works.
String firstId = requestBody.get("items").get(0).get("id").asText();
String secondId = requestBody.get("items").get(1).get("id").asText();
respondJson(exchange, """
{
"model": "intfloat/multilingual-e5-large",
"count": 2,
"results": [
{
"id": "%s",
"vector": [0.1, 0.2, 0.3],
"token_count": 5,
"runtime_ms": 0.0,
"incomplete": false,
"combined_vector": null,
"truncated": false,
"truncate_length": 512,
"model": "intfloat/multilingual-e5-large",
"max_seq_length": 512
},
{
"id": "%s",
"vector": [0.4, 0.5, 0.6],
"token_count": 7,
"runtime_ms": 0.0,
"incomplete": false,
"combined_vector": null,
"truncated": false,
"truncate_length": 512,
"model": "intfloat/multilingual-e5-large",
"max_seq_length": 512
}
]
}
""".formatted(firstId, secondId));
}
private void respondJson(HttpExchange exchange, String json) throws IOException {
byte[] response = json.getBytes(StandardCharsets.UTF_8);
exchange.getResponseHeaders().add("Content-Type", "application/json");
exchange.sendResponseHeaders(200, response.length);

@ -1,10 +1,6 @@
package at.procon.dip.embedding.service;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import at.procon.dip.domain.document.DistanceMetric;
import at.procon.dip.domain.document.entity.Document;
import at.procon.dip.domain.document.entity.DocumentEmbedding;
import at.procon.dip.domain.document.entity.DocumentTextRepresentation;
@ -12,20 +8,25 @@ import at.procon.dip.domain.document.repository.DocumentTextRepresentationReposi
import at.procon.dip.embedding.config.EmbeddingProperties;
import at.procon.dip.embedding.job.entity.EmbeddingJob;
import at.procon.dip.embedding.job.service.EmbeddingJobService;
import at.procon.dip.embedding.model.EmbeddingJobStatus;
import at.procon.dip.embedding.model.EmbeddingJobType;
import at.procon.dip.embedding.model.EmbeddingModelDescriptor;
import at.procon.dip.embedding.model.EmbeddingProviderResult;
import at.procon.dip.embedding.model.*;
import at.procon.dip.embedding.policy.EmbeddingSelectionPolicy;
import at.procon.dip.embedding.registry.EmbeddingModelRegistry;
import at.procon.dip.domain.document.DistanceMetric;
import java.util.List;
import java.util.UUID;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.AdditionalMatchers.aryEq;
import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.*;
@ExtendWith(MockitoExtension.class)
class RepresentationEmbeddingOrchestratorTest {
@ -38,22 +39,27 @@ class RepresentationEmbeddingOrchestratorTest {
@Mock
private DocumentTextRepresentationRepository representationRepository;
@Mock
private EmbeddingSelectionPolicy selectionPolicy;
@Mock
private EmbeddingModelRegistry modelRegistry;
private EmbeddingProperties properties;
private RepresentationEmbeddingOrchestrator orchestrator;
@BeforeEach
void setUp() {
EmbeddingProperties properties = new EmbeddingProperties();
properties = new EmbeddingProperties();
properties.setEnabled(true);
properties.getJobs().setEnabled(true);
properties.getJobs().setBatchSize(10);
properties.getJobs().setExecutionBatchSize(10);
properties.getIndexing().setFallbackMaxInputChars(8192);
orchestrator = new RepresentationEmbeddingOrchestrator(
jobService,
executionService,
persistenceService,
representationRepository,
selectionPolicy,
modelRegistry,
properties
);
@ -80,16 +86,17 @@ class RepresentationEmbeddingOrchestratorTest {
.document(document)
.textBody("District heating optimization strategy")
.build();
when(representationRepository.findById(representationId)).thenReturn(java.util.Optional.of(representation));
when(modelRegistry.getRequired("mock-search"))
.thenReturn(new EmbeddingModelDescriptor("mock-search", "mock-provider", "mock-search", 16,
DistanceMetric.COSINE, true, false, 4096, true));
EmbeddingModelDescriptor model = new EmbeddingModelDescriptor(
"mock-search", "mock-provider", "mock-search", 16,
DistanceMetric.COSINE, true, false, 4096, true
);
when(representationRepository.findById(representationId)).thenReturn(Optional.of(representation));
when(modelRegistry.getRequired("mock-search")).thenReturn(model);
when(persistenceService.ensurePending(representationId, "mock-search"))
.thenReturn(DocumentEmbedding.builder().id(embeddingId).build());
when(executionService.embedTexts(eq("mock-search"), any(), any()))
when(executionService.embedTexts(eq("mock-search"), eq(EmbeddingUseCase.DOCUMENT), any()))
.thenReturn(new EmbeddingProviderResult(
new EmbeddingModelDescriptor("mock-search", "mock-provider", "mock-search", 16,
DistanceMetric.COSINE, true, false, 4096, true),
model,
List.of(new float[]{0.1f, 0.2f}),
List.of(),
"req-1",
@ -102,4 +109,121 @@ class RepresentationEmbeddingOrchestratorTest {
verify(persistenceService).saveCompleted(eq(embeddingId), any(EmbeddingProviderResult.class));
verify(jobService).markCompleted(job.getId(), "req-1");
}
// Two jobs for the same batch-capable model must be grouped into ONE provider call,
// with texts in claim order, and each job/embedding completed from the shared result.
@Test
void processNextReadyBatch_should_group_batchable_jobs_and_call_provider_once() {
properties.getJobs().setProcessInBatches(true);
properties.getJobs().setExecutionBatchSize(10);
UUID documentId = UUID.randomUUID();
UUID representationId1 = UUID.randomUUID();
UUID representationId2 = UUID.randomUUID();
UUID embeddingId1 = UUID.randomUUID();
UUID embeddingId2 = UUID.randomUUID();
EmbeddingJob job1 = EmbeddingJob.builder()
.id(UUID.randomUUID())
.documentId(documentId)
.representationId(representationId1)
.modelKey("e5-default")
.jobType(EmbeddingJobType.DOCUMENT_EMBED)
.status(EmbeddingJobStatus.IN_PROGRESS)
.attemptCount(1)
.build();
EmbeddingJob job2 = EmbeddingJob.builder()
.id(UUID.randomUUID())
.documentId(documentId)
.representationId(representationId2)
.modelKey("e5-default")
.jobType(EmbeddingJobType.DOCUMENT_EMBED)
.status(EmbeddingJobStatus.IN_PROGRESS)
.attemptCount(1)
.build();
when(jobService.claimNextReadyJobs(10)).thenReturn(List.of(job1, job2));
Document document = Document.builder().id(documentId).build();
when(representationRepository.findById(representationId1)).thenReturn(Optional.of(
DocumentTextRepresentation.builder().id(representationId1).document(document).textBody("alpha").build()
));
when(representationRepository.findById(representationId2)).thenReturn(Optional.of(
DocumentTextRepresentation.builder().id(representationId2).document(document).textBody("beta").build()
));
// supportsBatch = true (second-to-last flag) so the batch path is taken.
EmbeddingModelDescriptor model = new EmbeddingModelDescriptor(
"e5-default", "vector-sync-e5", "intfloat/multilingual-e5-large", 3,
DistanceMetric.COSINE, true, true, 4096, true
);
when(modelRegistry.getRequired("e5-default")).thenReturn(model);
when(persistenceService.ensurePending(representationId1, "e5-default"))
.thenReturn(DocumentEmbedding.builder().id(embeddingId1).build());
when(persistenceService.ensurePending(representationId2, "e5-default"))
.thenReturn(DocumentEmbedding.builder().id(embeddingId2).build());
when(executionService.embedTexts(eq("e5-default"), eq(EmbeddingUseCase.DOCUMENT), any()))
.thenReturn(new EmbeddingProviderResult(
model,
List.of(new float[]{0.1f, 0.2f, 0.3f}, new float[]{0.4f, 0.5f, 0.6f}),
List.of(),
"batch-req-1",
21
));
int processed = orchestrator.processNextReadyBatch();
assertThat(processed).isEqualTo(2);
verify(persistenceService).markProcessing(embeddingId1);
verify(persistenceService).markProcessing(embeddingId2);
// Exactly one provider call, carrying both texts in claim order.
ArgumentCaptor<List<String>> textsCaptor = ArgumentCaptor.forClass(List.class);
verify(executionService, times(1)).embedTexts(eq("e5-default"), eq(EmbeddingUseCase.DOCUMENT), textsCaptor.capture());
assertThat(textsCaptor.getValue()).containsExactly("alpha", "beta");
// Batch path persists the raw vectors with a null token count (no per-item counts).
verify(persistenceService).saveCompleted(eq(embeddingId1), aryEq(new float[]{0.1f, 0.2f, 0.3f}), eq(null));
verify(persistenceService).saveCompleted(eq(embeddingId2), aryEq(new float[]{0.4f, 0.5f, 0.6f}), eq(null));
verify(jobService).markCompleted(job1.getId(), "batch-req-1");
verify(jobService).markCompleted(job2.getId(), "batch-req-1");
}
// A model with supportsBatch = false must be processed via the single-job path even
// when process-in-batches is enabled (single-result saveCompleted overload, not the
// per-vector batch overload).
@Test
void processNextReadyBatch_should_fall_back_to_single_processing_for_non_batch_model() {
properties.getJobs().setProcessInBatches(true);
UUID documentId = UUID.randomUUID();
UUID representationId = UUID.randomUUID();
UUID embeddingId = UUID.randomUUID();
EmbeddingJob job = EmbeddingJob.builder()
.id(UUID.randomUUID())
.documentId(documentId)
.representationId(representationId)
.modelKey("mock-search")
.jobType(EmbeddingJobType.DOCUMENT_EMBED)
.status(EmbeddingJobStatus.IN_PROGRESS)
.attemptCount(1)
.build();
when(jobService.claimNextReadyJobs(10)).thenReturn(List.of(job));
Document document = Document.builder().id(documentId).build();
when(representationRepository.findById(representationId)).thenReturn(Optional.of(
DocumentTextRepresentation.builder().id(representationId).document(document).textBody("gamma").build()
));
// supportsBatch = false (second-to-last flag) forces the single-job fallback.
EmbeddingModelDescriptor model = new EmbeddingModelDescriptor(
"mock-search", "mock-provider", "mock-search", 2,
DistanceMetric.COSINE, true, false, 4096, true
);
when(modelRegistry.getRequired("mock-search")).thenReturn(model);
when(persistenceService.ensurePending(representationId, "mock-search"))
.thenReturn(DocumentEmbedding.builder().id(embeddingId).build());
when(executionService.embedTexts(eq("mock-search"), eq(EmbeddingUseCase.DOCUMENT), any()))
.thenReturn(new EmbeddingProviderResult(
model,
List.of(new float[]{0.7f, 0.8f}),
List.of(),
"req-2",
7
));
orchestrator.processNextReadyBatch();
verify(executionService, times(1)).embedTexts(eq("mock-search"), eq(EmbeddingUseCase.DOCUMENT), eq(List.of("gamma")));
// The batch-style (float[], tokenCount) overload must NOT be used on this path...
verify(persistenceService, never()).saveCompleted(eq(embeddingId), any(float[].class), eq(null));
// ...only the single-result overload taking the whole provider result.
verify(persistenceService).saveCompleted(eq(embeddingId), any(EmbeddingProviderResult.class));
}
}

Loading…
Cancel
Save