batch embedding support
This commit is contained in:
parent
52330d751d
commit
678db76415
|
|
@ -30,6 +30,14 @@ public class EmbeddingProperties {
|
||||||
private Duration readTimeout = Duration.ofSeconds(60);
|
private Duration readTimeout = Duration.ofSeconds(60);
|
||||||
private Map<String, String> headers = new LinkedHashMap<>();
|
private Map<String, String> headers = new LinkedHashMap<>();
|
||||||
private Integer dimensions;
|
private Integer dimensions;
|
||||||
|
private BatchRequestProperties batchRequest = new BatchRequestProperties();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public static class BatchRequestProperties {
|
||||||
|
private boolean truncateText = false;
|
||||||
|
private int truncateLength = 512;
|
||||||
|
private int chunkSize = 20;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
|
@ -59,6 +67,8 @@ public class EmbeddingProperties {
|
||||||
public static class JobsProperties {
|
public static class JobsProperties {
|
||||||
private boolean enabled = false;
|
private boolean enabled = false;
|
||||||
private int batchSize = 16;
|
private int batchSize = 16;
|
||||||
|
private boolean processInBatches = false;
|
||||||
|
private int executionBatchSize = 8;
|
||||||
private int maxRetries = 5;
|
private int maxRetries = 5;
|
||||||
private Duration initialRetryDelay = Duration.ofSeconds(30);
|
private Duration initialRetryDelay = Duration.ofSeconds(30);
|
||||||
private Duration maxRetryDelay = Duration.ofHours(6);
|
private Duration maxRetryDelay = Duration.ofHours(6);
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,9 @@ public record ResolvedEmbeddingProviderConfig(
|
||||||
Duration connectTimeout,
|
Duration connectTimeout,
|
||||||
Duration readTimeout,
|
Duration readTimeout,
|
||||||
Map<String, String> headers,
|
Map<String, String> headers,
|
||||||
Integer dimensions
|
Integer dimensions,
|
||||||
|
Boolean batchTruncateText,
|
||||||
|
Integer batchTruncateLength,
|
||||||
|
Integer batchChunkSize
|
||||||
) {
|
) {
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -8,8 +8,6 @@ import at.procon.dip.embedding.provider.EmbeddingProvider;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.lang.reflect.Field;
|
|
||||||
import java.lang.reflect.Method;
|
|
||||||
import java.net.http.HttpResponse;
|
import java.net.http.HttpResponse;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
|
|
@ -24,16 +22,6 @@ import org.springframework.stereotype.Component;
|
||||||
* Supported endpoints:
|
* Supported endpoints:
|
||||||
* POST {baseUrl}/vector-sync - single text
|
* POST {baseUrl}/vector-sync - single text
|
||||||
* POST {baseUrl}/vectorize-batch - multiple texts
|
* POST {baseUrl}/vectorize-batch - multiple texts
|
||||||
*
|
|
||||||
* Batch settings are resolved from provider config if present, otherwise defaults are used.
|
|
||||||
* Supported property keys:
|
|
||||||
* - vectorize-batch.truncate-text
|
|
||||||
* - vectorize-batch.truncate-length
|
|
||||||
* - vectorize-batch.chunk-size
|
|
||||||
*
|
|
||||||
* Also accepted as fallbacks:
|
|
||||||
* - truncate_text / truncate-length / chunk_size
|
|
||||||
* - truncateText / truncateLength / chunkSize
|
|
||||||
*/
|
*/
|
||||||
@Component
|
@Component
|
||||||
public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProviderSupport implements EmbeddingProvider {
|
public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProviderSupport implements EmbeddingProvider {
|
||||||
|
|
@ -105,7 +93,7 @@ public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProvid
|
||||||
try {
|
try {
|
||||||
return request.texts().size() == 1
|
return request.texts().size() == 1
|
||||||
? executeSingle(providerConfig, model, request.texts().getFirst())
|
? executeSingle(providerConfig, model, request.texts().getFirst())
|
||||||
: executeBatch(providerConfig, model, request.texts());
|
: executeBatch(providerConfig, model, request);
|
||||||
} catch (InterruptedException e) {
|
} catch (InterruptedException e) {
|
||||||
Thread.currentThread().interrupt();
|
Thread.currentThread().interrupt();
|
||||||
throw new IllegalStateException("Embedding provider call interrupted", e);
|
throw new IllegalStateException("Embedding provider call interrupted", e);
|
||||||
|
|
@ -137,22 +125,20 @@ public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProvid
|
||||||
|
|
||||||
private EmbeddingProviderResult executeBatch(ResolvedEmbeddingProviderConfig providerConfig,
|
private EmbeddingProviderResult executeBatch(ResolvedEmbeddingProviderConfig providerConfig,
|
||||||
EmbeddingModelDescriptor model,
|
EmbeddingModelDescriptor model,
|
||||||
List<String> texts) throws IOException, InterruptedException {
|
EmbeddingRequest request) throws IOException, InterruptedException {
|
||||||
boolean truncateText = resolveBooleanProperty(providerConfig, TRUNCATE_TEXT_KEYS, DEFAULT_TRUNCATE_TEXT);
|
BatchRequestSettings settings = resolveBatchRequestSettings(providerConfig, request.providerOptions());
|
||||||
int truncateLength = resolveIntProperty(providerConfig, TRUNCATE_LENGTH_KEYS, DEFAULT_TRUNCATE_LENGTH);
|
|
||||||
int chunkSize = resolveIntProperty(providerConfig, CHUNK_SIZE_KEYS, DEFAULT_CHUNK_SIZE);
|
|
||||||
|
|
||||||
if (truncateLength <= 0) {
|
if (settings.truncateLength() <= 0) {
|
||||||
throw new IllegalArgumentException("Batch truncate length must be > 0");
|
throw new IllegalArgumentException("Batch truncate length must be > 0");
|
||||||
}
|
}
|
||||||
if (chunkSize <= 0) {
|
if (settings.chunkSize() <= 0) {
|
||||||
throw new IllegalArgumentException("Batch chunk size must be > 0");
|
throw new IllegalArgumentException("Batch chunk size must be > 0");
|
||||||
}
|
}
|
||||||
|
|
||||||
List<String> requestOrder = new ArrayList<>(texts.size());
|
List<String> requestOrder = new ArrayList<>(request.texts().size());
|
||||||
List<VectorizeBatchItemRequest> items = new ArrayList<>(texts.size());
|
List<VectorizeBatchItemRequest> items = new ArrayList<>(request.texts().size());
|
||||||
|
|
||||||
for (String text : texts) {
|
for (String text : request.texts()) {
|
||||||
String id = UUID.randomUUID().toString();
|
String id = UUID.randomUUID().toString();
|
||||||
requestOrder.add(id);
|
requestOrder.add(id);
|
||||||
items.add(new VectorizeBatchItemRequest(id, text));
|
items.add(new VectorizeBatchItemRequest(id, text));
|
||||||
|
|
@ -163,9 +149,9 @@ public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProvid
|
||||||
"/vectorize-batch",
|
"/vectorize-batch",
|
||||||
new VectorizeBatchRequest(
|
new VectorizeBatchRequest(
|
||||||
model.providerModelKey(),
|
model.providerModelKey(),
|
||||||
truncateText,
|
settings.truncateText(),
|
||||||
truncateLength,
|
settings.truncateLength(),
|
||||||
chunkSize,
|
settings.chunkSize(),
|
||||||
items
|
items
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
|
@ -180,7 +166,7 @@ public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProvid
|
||||||
resultById.put(result.id, result);
|
resultById.put(result.id, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
List<float[]> vectors = new ArrayList<>(texts.size());
|
List<float[]> vectors = new ArrayList<>(request.texts().size());
|
||||||
int totalTokenCount = 0;
|
int totalTokenCount = 0;
|
||||||
boolean hasAnyTokenCount = false;
|
boolean hasAnyTokenCount = false;
|
||||||
|
|
||||||
|
|
@ -207,6 +193,72 @@ public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProvid
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private BatchRequestSettings resolveBatchRequestSettings(ResolvedEmbeddingProviderConfig providerConfig,
|
||||||
|
Map<String, Object> providerOptions) {
|
||||||
|
boolean truncateText = resolveBooleanOption(
|
||||||
|
providerOptions,
|
||||||
|
TRUNCATE_TEXT_KEYS,
|
||||||
|
providerConfig.batchTruncateText() != null ? providerConfig.batchTruncateText() : DEFAULT_TRUNCATE_TEXT
|
||||||
|
);
|
||||||
|
int truncateLength = resolveIntOption(
|
||||||
|
providerOptions,
|
||||||
|
TRUNCATE_LENGTH_KEYS,
|
||||||
|
providerConfig.batchTruncateLength() != null ? providerConfig.batchTruncateLength() : DEFAULT_TRUNCATE_LENGTH
|
||||||
|
);
|
||||||
|
int chunkSize = resolveIntOption(
|
||||||
|
providerOptions,
|
||||||
|
CHUNK_SIZE_KEYS,
|
||||||
|
providerConfig.batchChunkSize() != null ? providerConfig.batchChunkSize() : DEFAULT_CHUNK_SIZE
|
||||||
|
);
|
||||||
|
return new BatchRequestSettings(truncateText, truncateLength, chunkSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
private boolean resolveBooleanOption(Map<String, Object> providerOptions,
|
||||||
|
List<String> keys,
|
||||||
|
boolean defaultValue) {
|
||||||
|
Object raw = resolveOption(providerOptions, keys);
|
||||||
|
if (raw == null) {
|
||||||
|
return defaultValue;
|
||||||
|
}
|
||||||
|
if (raw instanceof Boolean booleanValue) {
|
||||||
|
return booleanValue;
|
||||||
|
}
|
||||||
|
String normalized = String.valueOf(raw).trim();
|
||||||
|
if (normalized.isEmpty()) {
|
||||||
|
return defaultValue;
|
||||||
|
}
|
||||||
|
return Boolean.parseBoolean(normalized);
|
||||||
|
}
|
||||||
|
|
||||||
|
private int resolveIntOption(Map<String, Object> providerOptions,
|
||||||
|
List<String> keys,
|
||||||
|
int defaultValue) {
|
||||||
|
Object raw = resolveOption(providerOptions, keys);
|
||||||
|
if (raw == null) {
|
||||||
|
return defaultValue;
|
||||||
|
}
|
||||||
|
if (raw instanceof Number number) {
|
||||||
|
return number.intValue();
|
||||||
|
}
|
||||||
|
String normalized = String.valueOf(raw).trim();
|
||||||
|
if (normalized.isEmpty()) {
|
||||||
|
return defaultValue;
|
||||||
|
}
|
||||||
|
return Integer.parseInt(normalized);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Object resolveOption(Map<String, Object> providerOptions, List<String> keys) {
|
||||||
|
if (providerOptions == null || providerOptions.isEmpty()) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
for (String key : keys) {
|
||||||
|
if (providerOptions.containsKey(key)) {
|
||||||
|
return providerOptions.get(key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
private float[] extractVector(List<Float> vector,
|
private float[] extractVector(List<Float> vector,
|
||||||
List<Float> combinedVector,
|
List<Float> combinedVector,
|
||||||
EmbeddingModelDescriptor model) {
|
EmbeddingModelDescriptor model) {
|
||||||
|
|
@ -230,123 +282,7 @@ public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProvid
|
||||||
return resolved;
|
return resolved;
|
||||||
}
|
}
|
||||||
|
|
||||||
private boolean resolveBooleanProperty(ResolvedEmbeddingProviderConfig providerConfig,
|
private record BatchRequestSettings(boolean truncateText, int truncateLength, int chunkSize) {
|
||||||
List<String> keys,
|
|
||||||
boolean defaultValue) {
|
|
||||||
Object raw = resolveProviderConfigValue(providerConfig, keys);
|
|
||||||
if (raw == null) {
|
|
||||||
return defaultValue;
|
|
||||||
}
|
|
||||||
if (raw instanceof Boolean b) {
|
|
||||||
return b;
|
|
||||||
}
|
|
||||||
String normalized = String.valueOf(raw).trim();
|
|
||||||
if (normalized.isEmpty()) {
|
|
||||||
return defaultValue;
|
|
||||||
}
|
|
||||||
return Boolean.parseBoolean(normalized);
|
|
||||||
}
|
|
||||||
|
|
||||||
private int resolveIntProperty(ResolvedEmbeddingProviderConfig providerConfig,
|
|
||||||
List<String> keys,
|
|
||||||
int defaultValue) {
|
|
||||||
Object raw = resolveProviderConfigValue(providerConfig, keys);
|
|
||||||
if (raw == null) {
|
|
||||||
return defaultValue;
|
|
||||||
}
|
|
||||||
if (raw instanceof Number n) {
|
|
||||||
return n.intValue();
|
|
||||||
}
|
|
||||||
String normalized = String.valueOf(raw).trim();
|
|
||||||
if (normalized.isEmpty()) {
|
|
||||||
return defaultValue;
|
|
||||||
}
|
|
||||||
return Integer.parseInt(normalized);
|
|
||||||
}
|
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
|
||||||
private Object resolveProviderConfigValue(ResolvedEmbeddingProviderConfig providerConfig,
|
|
||||||
List<String> keys) {
|
|
||||||
List<Object> containers = new ArrayList<>();
|
|
||||||
containers.add(providerConfig);
|
|
||||||
|
|
||||||
addIfNonNull(containers, invokeNoArg(providerConfig, "properties"));
|
|
||||||
addIfNonNull(containers, invokeNoArg(providerConfig, "providerProperties"));
|
|
||||||
addIfNonNull(containers, invokeNoArg(providerConfig, "config"));
|
|
||||||
addIfNonNull(containers, invokeNoArg(providerConfig, "settings"));
|
|
||||||
addIfNonNull(containers, invokeNoArg(providerConfig, "options"));
|
|
||||||
addIfNonNull(containers, readField(providerConfig, "properties"));
|
|
||||||
addIfNonNull(containers, readField(providerConfig, "providerProperties"));
|
|
||||||
addIfNonNull(containers, readField(providerConfig, "config"));
|
|
||||||
addIfNonNull(containers, readField(providerConfig, "settings"));
|
|
||||||
addIfNonNull(containers, readField(providerConfig, "options"));
|
|
||||||
|
|
||||||
for (Object container : containers) {
|
|
||||||
if (container instanceof Map<?, ?> map) {
|
|
||||||
for (String key : keys) {
|
|
||||||
if (map.containsKey(key)) {
|
|
||||||
return map.get(key);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (String key : keys) {
|
|
||||||
Object value = invokeStringArg(container, "get", key);
|
|
||||||
if (value != null) {
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
value = invokeStringArg(container, "getProperty", key);
|
|
||||||
if (value != null) {
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
value = invokeStringArg(container, "property", key);
|
|
||||||
if (value != null) {
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
value = invokeStringArg(container, "option", key);
|
|
||||||
if (value != null) {
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
private void addIfNonNull(List<Object> containers, Object value) {
|
|
||||||
if (value != null) {
|
|
||||||
containers.add(value);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private Object invokeNoArg(Object target, String methodName) {
|
|
||||||
try {
|
|
||||||
Method method = target.getClass().getMethod(methodName);
|
|
||||||
method.setAccessible(true);
|
|
||||||
return method.invoke(target);
|
|
||||||
} catch (Exception ignored) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private Object invokeStringArg(Object target, String methodName, String arg) {
|
|
||||||
try {
|
|
||||||
Method method = target.getClass().getMethod(methodName, String.class);
|
|
||||||
method.setAccessible(true);
|
|
||||||
return method.invoke(target, arg);
|
|
||||||
} catch (Exception ignored) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private Object readField(Object target, String fieldName) {
|
|
||||||
try {
|
|
||||||
Field field = target.getClass().getDeclaredField(fieldName);
|
|
||||||
field.setAccessible(true);
|
|
||||||
return field.get(target);
|
|
||||||
} catch (Exception ignored) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private record VectorSyncRequest(
|
private record VectorSyncRequest(
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,10 @@ public class EmbeddingProviderConfigResolver {
|
||||||
throw new IllegalArgumentException("Unknown embedding provider config key: " + providerConfigKey);
|
throw new IllegalArgumentException("Unknown embedding provider config key: " + providerConfigKey);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
EmbeddingProperties.BatchRequestProperties batchRequest = provider.getBatchRequest() == null
|
||||||
|
? new EmbeddingProperties.BatchRequestProperties()
|
||||||
|
: provider.getBatchRequest();
|
||||||
|
|
||||||
return ResolvedEmbeddingProviderConfig.builder()
|
return ResolvedEmbeddingProviderConfig.builder()
|
||||||
.key(providerConfigKey)
|
.key(providerConfigKey)
|
||||||
.providerType(provider.getType())
|
.providerType(provider.getType())
|
||||||
|
|
@ -27,6 +31,9 @@ public class EmbeddingProviderConfigResolver {
|
||||||
.readTimeout(provider.getReadTimeout())
|
.readTimeout(provider.getReadTimeout())
|
||||||
.headers(provider.getHeaders() == null ? Map.of() : Map.copyOf(provider.getHeaders()))
|
.headers(provider.getHeaders() == null ? Map.of() : Map.copyOf(provider.getHeaders()))
|
||||||
.dimensions(provider.getDimensions())
|
.dimensions(provider.getDimensions())
|
||||||
|
.batchTruncateText(batchRequest.isTruncateText())
|
||||||
|
.batchTruncateLength(batchRequest.getTruncateLength())
|
||||||
|
.batchChunkSize(batchRequest.getChunkSize())
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,6 @@ import at.procon.dip.domain.document.repository.DocumentEmbeddingRepository;
|
||||||
import at.procon.dip.domain.document.repository.DocumentTextRepresentationRepository;
|
import at.procon.dip.domain.document.repository.DocumentTextRepresentationRepository;
|
||||||
import at.procon.dip.domain.document.service.DocumentEmbeddingService;
|
import at.procon.dip.domain.document.service.DocumentEmbeddingService;
|
||||||
import at.procon.dip.embedding.model.EmbeddingProviderResult;
|
import at.procon.dip.embedding.model.EmbeddingProviderResult;
|
||||||
import at.procon.dip.embedding.support.EmbeddingVectorCodec;
|
|
||||||
import java.time.OffsetDateTime;
|
import java.time.OffsetDateTime;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
@ -44,11 +43,17 @@ public class EmbeddingPersistenceService {
|
||||||
if (result.vectors() == null || result.vectors().isEmpty()) {
|
if (result.vectors() == null || result.vectors().isEmpty()) {
|
||||||
throw new IllegalArgumentException("Embedding provider result contains no vectors");
|
throw new IllegalArgumentException("Embedding provider result contains no vectors");
|
||||||
}
|
}
|
||||||
float[] vector = result.vectors().getFirst();
|
saveCompleted(embeddingId, result.vectors().getFirst(), result.tokenCount());
|
||||||
|
}
|
||||||
|
|
||||||
|
public void saveCompleted(UUID embeddingId, float[] vector, Integer tokenCount) {
|
||||||
|
if (vector == null || vector.length == 0) {
|
||||||
|
throw new IllegalArgumentException("Embedding vector must not be empty");
|
||||||
|
}
|
||||||
embeddingRepository.updateEmbeddingVector(
|
embeddingRepository.updateEmbeddingVector(
|
||||||
embeddingId,
|
embeddingId,
|
||||||
vector, //EmbeddingVectorCodec.toPgVector(vector),
|
vector,
|
||||||
result.tokenCount(),
|
tokenCount,
|
||||||
vector.length
|
vector.length
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -7,11 +7,14 @@ import at.procon.dip.embedding.config.EmbeddingProperties;
|
||||||
import at.procon.dip.embedding.job.entity.EmbeddingJob;
|
import at.procon.dip.embedding.job.entity.EmbeddingJob;
|
||||||
import at.procon.dip.embedding.job.service.EmbeddingJobService;
|
import at.procon.dip.embedding.job.service.EmbeddingJobService;
|
||||||
import at.procon.dip.embedding.model.EmbeddingJobType;
|
import at.procon.dip.embedding.model.EmbeddingJobType;
|
||||||
|
import at.procon.dip.embedding.model.EmbeddingModelDescriptor;
|
||||||
import at.procon.dip.embedding.model.EmbeddingProviderResult;
|
import at.procon.dip.embedding.model.EmbeddingProviderResult;
|
||||||
import at.procon.dip.embedding.model.EmbeddingUseCase;
|
import at.procon.dip.embedding.model.EmbeddingUseCase;
|
||||||
import at.procon.dip.embedding.policy.EmbeddingProfile;
|
import at.procon.dip.embedding.policy.EmbeddingProfile;
|
||||||
import at.procon.dip.embedding.policy.EmbeddingSelectionPolicy;
|
import at.procon.dip.embedding.policy.EmbeddingSelectionPolicy;
|
||||||
import at.procon.dip.embedding.registry.EmbeddingModelRegistry;
|
import at.procon.dip.embedding.registry.EmbeddingModelRegistry;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.LinkedHashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.UUID;
|
import java.util.UUID;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
|
|
@ -63,25 +66,138 @@ public class RepresentationEmbeddingOrchestrator {
|
||||||
}
|
}
|
||||||
|
|
||||||
List<EmbeddingJob> jobs = jobService.claimNextReadyJobs(embeddingProperties.getJobs().getBatchSize());
|
List<EmbeddingJob> jobs = jobService.claimNextReadyJobs(embeddingProperties.getJobs().getBatchSize());
|
||||||
for (EmbeddingJob job : jobs) {
|
if (jobs.isEmpty()) {
|
||||||
processClaimedJob(job);
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (embeddingProperties.getJobs().isProcessInBatches()) {
|
||||||
|
processClaimedJobsInBatches(jobs);
|
||||||
|
} else {
|
||||||
|
jobs.forEach(this::processClaimedJobSafely);
|
||||||
}
|
}
|
||||||
return jobs.size();
|
return jobs.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Transactional
|
@Transactional
|
||||||
public void processClaimedJob(EmbeddingJob job) {
|
public void processClaimedJob(EmbeddingJob job) {
|
||||||
|
EmbeddingModelDescriptor model = modelRegistry.getRequired(job.getModelKey());
|
||||||
|
PreparedEmbedding prepared = prepareEmbedding(job, model);
|
||||||
|
if (prepared == null) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
EmbeddingProviderResult result = executionService.embedTexts(
|
||||||
|
job.getModelKey(),
|
||||||
|
EmbeddingUseCase.DOCUMENT,
|
||||||
|
List.of(prepared.text())
|
||||||
|
);
|
||||||
|
persistenceService.saveCompleted(prepared.embeddingId(), result);
|
||||||
|
jobService.markCompleted(job.getId(), result.providerRequestId());
|
||||||
|
} catch (RuntimeException ex) {
|
||||||
|
persistenceService.markFailed(prepared.embeddingId(), ex.getMessage());
|
||||||
|
jobService.markFailed(job.getId(), ex.getMessage(), true);
|
||||||
|
throw ex;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void processClaimedJobsInBatches(List<EmbeddingJob> jobs) {
|
||||||
|
LinkedHashMap<String, List<EmbeddingJob>> jobsByModelKey = new LinkedHashMap<>();
|
||||||
|
for (EmbeddingJob job : jobs) {
|
||||||
|
jobsByModelKey.computeIfAbsent(job.getModelKey(), ignored -> new ArrayList<>()).add(job);
|
||||||
|
}
|
||||||
|
|
||||||
|
int executionBatchSize = Math.max(1, embeddingProperties.getJobs().getExecutionBatchSize());
|
||||||
|
for (var entry : jobsByModelKey.entrySet()) {
|
||||||
|
EmbeddingModelDescriptor model = modelRegistry.getRequired(entry.getKey());
|
||||||
|
if (!model.supportsBatch()) {
|
||||||
|
entry.getValue().forEach(this::processClaimedJobSafely);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
List<EmbeddingJob> sameModelJobs = entry.getValue();
|
||||||
|
for (int start = 0; start < sameModelJobs.size(); start += executionBatchSize) {
|
||||||
|
List<EmbeddingJob> partition = sameModelJobs.subList(start, Math.min(start + executionBatchSize, sameModelJobs.size()));
|
||||||
|
if (partition.size() == 1) {
|
||||||
|
processClaimedJobSafely(partition.getFirst());
|
||||||
|
} else {
|
||||||
|
processClaimedBatchSafely(partition, model);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void processClaimedBatchSafely(List<EmbeddingJob> jobs, EmbeddingModelDescriptor model) {
|
||||||
|
try {
|
||||||
|
processClaimedBatch(jobs, model);
|
||||||
|
} catch (RuntimeException ex) {
|
||||||
|
log.warn("Failed to process embedding batch for model {} ({} jobs): {}",
|
||||||
|
model.modelKey(), jobs.size(), ex.getMessage(), ex);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void processClaimedJobSafely(EmbeddingJob job) {
|
||||||
|
try {
|
||||||
|
processClaimedJob(job);
|
||||||
|
} catch (RuntimeException ex) {
|
||||||
|
log.warn("Failed to process embedding job {} for representation {}: {}",
|
||||||
|
job.getId(), job.getRepresentationId(), ex.getMessage(), ex);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void processClaimedBatch(List<EmbeddingJob> jobs, EmbeddingModelDescriptor model) {
|
||||||
|
List<PreparedEmbedding> preparedItems = new ArrayList<>(jobs.size());
|
||||||
|
for (EmbeddingJob job : jobs) {
|
||||||
|
PreparedEmbedding prepared = prepareEmbedding(job, model);
|
||||||
|
if (prepared != null) {
|
||||||
|
preparedItems.add(prepared);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (preparedItems.isEmpty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
EmbeddingProviderResult result = executionService.embedTexts(
|
||||||
|
model.modelKey(),
|
||||||
|
EmbeddingUseCase.DOCUMENT,
|
||||||
|
preparedItems.stream().map(PreparedEmbedding::text).toList()
|
||||||
|
);
|
||||||
|
|
||||||
|
if (result.vectors() == null || result.vectors().size() != preparedItems.size()) {
|
||||||
|
throw new IllegalStateException(
|
||||||
|
"Embedding provider returned %d vectors for %d batch items"
|
||||||
|
.formatted(result.vectors() == null ? 0 : result.vectors().size(), preparedItems.size())
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < preparedItems.size(); i++) {
|
||||||
|
PreparedEmbedding prepared = preparedItems.get(i);
|
||||||
|
persistenceService.saveCompleted(prepared.embeddingId(), result.vectors().get(i), null);
|
||||||
|
jobService.markCompleted(prepared.job().getId(), result.providerRequestId());
|
||||||
|
}
|
||||||
|
} catch (RuntimeException ex) {
|
||||||
|
for (PreparedEmbedding prepared : preparedItems) {
|
||||||
|
persistenceService.markFailed(prepared.embeddingId(), ex.getMessage());
|
||||||
|
jobService.markFailed(prepared.job().getId(), ex.getMessage(), true);
|
||||||
|
}
|
||||||
|
throw ex;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private PreparedEmbedding prepareEmbedding(EmbeddingJob job, EmbeddingModelDescriptor model) {
|
||||||
DocumentTextRepresentation representation = representationRepository.findById(job.getRepresentationId())
|
DocumentTextRepresentation representation = representationRepository.findById(job.getRepresentationId())
|
||||||
.orElseThrow(() -> new IllegalArgumentException("Unknown representation id: " + job.getRepresentationId()));
|
.orElseThrow(() -> new IllegalArgumentException("Unknown representation id: " + job.getRepresentationId()));
|
||||||
|
|
||||||
String text = representation.getTextBody();
|
String text = representation.getTextBody();
|
||||||
if (text == null || text.isBlank()) {
|
if (text == null || text.isBlank()) {
|
||||||
jobService.markFailed(job.getId(), "No text representation available", false);
|
jobService.markFailed(job.getId(), "No text representation available", false);
|
||||||
return;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
int maxChars = modelRegistry.getRequired(job.getModelKey()).maxInputChars() != null
|
int maxChars = model.maxInputChars() != null
|
||||||
? modelRegistry.getRequired(job.getModelKey()).maxInputChars()
|
? model.maxInputChars()
|
||||||
: embeddingProperties.getIndexing().getFallbackMaxInputChars();
|
: embeddingProperties.getIndexing().getFallbackMaxInputChars();
|
||||||
if (text.length() > maxChars) {
|
if (text.length() > maxChars) {
|
||||||
text = text.substring(0, maxChars);
|
text = text.substring(0, maxChars);
|
||||||
|
|
@ -89,19 +205,9 @@ public class RepresentationEmbeddingOrchestrator {
|
||||||
|
|
||||||
DocumentEmbedding embedding = persistenceService.ensurePending(representation.getId(), job.getModelKey());
|
DocumentEmbedding embedding = persistenceService.ensurePending(representation.getId(), job.getModelKey());
|
||||||
persistenceService.markProcessing(embedding.getId());
|
persistenceService.markProcessing(embedding.getId());
|
||||||
|
return new PreparedEmbedding(job, embedding.getId(), text);
|
||||||
|
}
|
||||||
|
|
||||||
try {
|
private record PreparedEmbedding(EmbeddingJob job, UUID embeddingId, String text) {
|
||||||
EmbeddingProviderResult result = executionService.embedTexts(
|
|
||||||
job.getModelKey(),
|
|
||||||
EmbeddingUseCase.DOCUMENT,
|
|
||||||
List.of(text)
|
|
||||||
);
|
|
||||||
persistenceService.saveCompleted(embedding.getId(), result);
|
|
||||||
jobService.markCompleted(job.getId(), result.providerRequestId());
|
|
||||||
} catch (RuntimeException ex) {
|
|
||||||
persistenceService.markFailed(embedding.getId(), ex.getMessage());
|
|
||||||
jobService.markFailed(job.getId(), ex.getMessage(), true);
|
|
||||||
throw ex;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,12 @@
|
||||||
dip:
|
dip:
|
||||||
embedding:
|
embedding:
|
||||||
enabled: true
|
enabled: true
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
enabled: true
|
||||||
|
process-in-batches: true
|
||||||
|
execution-batch-size: 20
|
||||||
|
|
||||||
default-document-model: e5-default
|
default-document-model: e5-default
|
||||||
default-query-model: e5-default
|
default-query-model: e5-default
|
||||||
|
|
||||||
|
|
@ -12,6 +18,10 @@ dip:
|
||||||
read-timeout: 60s
|
read-timeout: 60s
|
||||||
headers:
|
headers:
|
||||||
X-Client: dip
|
X-Client: dip
|
||||||
|
batch-request:
|
||||||
|
truncate-text: false
|
||||||
|
truncate-length: 512
|
||||||
|
chunk-size: 20
|
||||||
|
|
||||||
models:
|
models:
|
||||||
e5-default:
|
e5-default:
|
||||||
|
|
@ -20,6 +30,6 @@ dip:
|
||||||
dimensions: 1024
|
dimensions: 1024
|
||||||
distance-metric: COSINE
|
distance-metric: COSINE
|
||||||
supports-query-embedding-mode: true
|
supports-query-embedding-mode: true
|
||||||
supports-batch: false
|
supports-batch: true
|
||||||
max-input-chars: 8192
|
max-input-chars: 8192
|
||||||
active: true
|
active: true
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ import at.procon.dip.embedding.model.EmbeddingModelDescriptor;
|
||||||
import at.procon.dip.embedding.model.EmbeddingRequest;
|
import at.procon.dip.embedding.model.EmbeddingRequest;
|
||||||
import at.procon.dip.embedding.model.EmbeddingUseCase;
|
import at.procon.dip.embedding.model.EmbeddingUseCase;
|
||||||
import at.procon.dip.embedding.model.ResolvedEmbeddingProviderConfig;
|
import at.procon.dip.embedding.model.ResolvedEmbeddingProviderConfig;
|
||||||
|
import com.fasterxml.jackson.databind.JsonNode;
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
import com.sun.net.httpserver.HttpExchange;
|
import com.sun.net.httpserver.HttpExchange;
|
||||||
import com.sun.net.httpserver.HttpServer;
|
import com.sun.net.httpserver.HttpServer;
|
||||||
|
|
@ -17,12 +18,15 @@ import java.nio.charset.StandardCharsets;
|
||||||
import java.time.Duration;
|
import java.time.Duration;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.concurrent.atomic.AtomicReference;
|
||||||
import org.junit.jupiter.api.AfterEach;
|
import org.junit.jupiter.api.AfterEach;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
class VectorSyncHttpEmbeddingProviderTest {
|
class VectorSyncHttpEmbeddingProviderTest {
|
||||||
|
|
||||||
|
private final ObjectMapper objectMapper = new ObjectMapper();
|
||||||
private HttpServer server;
|
private HttpServer server;
|
||||||
|
private final AtomicReference<String> lastBatchBody = new AtomicReference<>();
|
||||||
|
|
||||||
@AfterEach
|
@AfterEach
|
||||||
void tearDown() {
|
void tearDown() {
|
||||||
|
|
@ -37,13 +41,16 @@ class VectorSyncHttpEmbeddingProviderTest {
|
||||||
server.createContext("/vector-sync", this::handleVectorSync);
|
server.createContext("/vector-sync", this::handleVectorSync);
|
||||||
server.start();
|
server.start();
|
||||||
|
|
||||||
var provider = new VectorSyncHttpEmbeddingProvider(new ObjectMapper());
|
var provider = new VectorSyncHttpEmbeddingProvider(objectMapper);
|
||||||
var config = ResolvedEmbeddingProviderConfig.builder()
|
var config = ResolvedEmbeddingProviderConfig.builder()
|
||||||
.key("vector-sync-local")
|
.key("vector-sync-local")
|
||||||
.providerType("http-vector-sync")
|
.providerType("http-vector-sync")
|
||||||
.baseUrl("http://localhost:" + server.getAddress().getPort())
|
.baseUrl("http://localhost:" + server.getAddress().getPort())
|
||||||
.readTimeout(Duration.ofSeconds(5))
|
.readTimeout(Duration.ofSeconds(5))
|
||||||
.headers(Map.of("X-Client", "dip-test"))
|
.headers(Map.of("X-Client", "dip-test"))
|
||||||
|
.batchTruncateText(false)
|
||||||
|
.batchTruncateLength(512)
|
||||||
|
.batchChunkSize(20)
|
||||||
.build();
|
.build();
|
||||||
var model = new EmbeddingModelDescriptor(
|
var model = new EmbeddingModelDescriptor(
|
||||||
"e5-default",
|
"e5-default",
|
||||||
|
|
@ -70,6 +77,100 @@ class VectorSyncHttpEmbeddingProviderTest {
|
||||||
assertThat(result.tokenCount()).isEqualTo(9);
|
assertThat(result.tokenCount()).isEqualTo(9);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void shouldCallVectorizeBatchEndpointUsingConfiguredBatchRequestSettings() throws Exception {
|
||||||
|
server = HttpServer.create(new InetSocketAddress(0), 0);
|
||||||
|
server.createContext("/vectorize-batch", this::handleVectorizeBatch);
|
||||||
|
server.start();
|
||||||
|
|
||||||
|
var provider = new VectorSyncHttpEmbeddingProvider(objectMapper);
|
||||||
|
var config = ResolvedEmbeddingProviderConfig.builder()
|
||||||
|
.key("vector-sync-local")
|
||||||
|
.providerType("http-vector-sync")
|
||||||
|
.baseUrl("http://localhost:" + server.getAddress().getPort())
|
||||||
|
.readTimeout(Duration.ofSeconds(5))
|
||||||
|
.headers(Map.of("X-Client", "dip-test"))
|
||||||
|
.batchTruncateText(true)
|
||||||
|
.batchTruncateLength(768)
|
||||||
|
.batchChunkSize(33)
|
||||||
|
.build();
|
||||||
|
var model = new EmbeddingModelDescriptor(
|
||||||
|
"e5-default",
|
||||||
|
"vector-sync-local",
|
||||||
|
"intfloat/multilingual-e5-large",
|
||||||
|
3,
|
||||||
|
DistanceMetric.COSINE,
|
||||||
|
true,
|
||||||
|
true,
|
||||||
|
8192,
|
||||||
|
true
|
||||||
|
);
|
||||||
|
var request = EmbeddingRequest.builder()
|
||||||
|
.modelKey("e5-default")
|
||||||
|
.useCase(EmbeddingUseCase.DOCUMENT)
|
||||||
|
.texts(List.of("First text", "Second text"))
|
||||||
|
.providerOptions(Map.of())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
var result = provider.embedDocuments(config, model, request);
|
||||||
|
|
||||||
|
assertThat(result.vectors()).hasSize(2);
|
||||||
|
assertThat(result.vectors().get(0)).containsExactly(0.1f, 0.2f, 0.3f);
|
||||||
|
assertThat(result.vectors().get(1)).containsExactly(0.4f, 0.5f, 0.6f);
|
||||||
|
assertThat(result.tokenCount()).isEqualTo(12);
|
||||||
|
|
||||||
|
JsonNode requestBody = objectMapper.readTree(lastBatchBody.get());
|
||||||
|
assertThat(requestBody.get("truncate_text").asBoolean()).isTrue();
|
||||||
|
assertThat(requestBody.get("truncate_length").asInt()).isEqualTo(768);
|
||||||
|
assertThat(requestBody.get("chunk_size").asInt()).isEqualTo(33);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void shouldAllowPerRequestBatchOverrides() throws Exception {
|
||||||
|
server = HttpServer.create(new InetSocketAddress(0), 0);
|
||||||
|
server.createContext("/vectorize-batch", this::handleVectorizeBatch);
|
||||||
|
server.start();
|
||||||
|
|
||||||
|
var provider = new VectorSyncHttpEmbeddingProvider(objectMapper);
|
||||||
|
var config = ResolvedEmbeddingProviderConfig.builder()
|
||||||
|
.key("vector-sync-local")
|
||||||
|
.providerType("http-vector-sync")
|
||||||
|
.baseUrl("http://localhost:" + server.getAddress().getPort())
|
||||||
|
.readTimeout(Duration.ofSeconds(5))
|
||||||
|
.batchTruncateText(false)
|
||||||
|
.batchTruncateLength(512)
|
||||||
|
.batchChunkSize(20)
|
||||||
|
.build();
|
||||||
|
var model = new EmbeddingModelDescriptor(
|
||||||
|
"e5-default",
|
||||||
|
"vector-sync-local",
|
||||||
|
"intfloat/multilingual-e5-large",
|
||||||
|
3,
|
||||||
|
DistanceMetric.COSINE,
|
||||||
|
true,
|
||||||
|
true,
|
||||||
|
8192,
|
||||||
|
true
|
||||||
|
);
|
||||||
|
var request = EmbeddingRequest.builder()
|
||||||
|
.modelKey("e5-default")
|
||||||
|
.useCase(EmbeddingUseCase.DOCUMENT)
|
||||||
|
.texts(List.of("First text", "Second text"))
|
||||||
|
.providerOptions(Map.of(
|
||||||
|
"truncate_text", true,
|
||||||
|
"truncate_length", 1024,
|
||||||
|
"chunk_size", 44
|
||||||
|
))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
provider.embedDocuments(config, model, request);
|
||||||
|
|
||||||
|
JsonNode requestBody = objectMapper.readTree(lastBatchBody.get());
|
||||||
|
assertThat(requestBody.get("truncate_text").asBoolean()).isTrue();
|
||||||
|
assertThat(requestBody.get("truncate_length").asInt()).isEqualTo(1024);
|
||||||
|
assertThat(requestBody.get("chunk_size").asInt()).isEqualTo(44);
|
||||||
|
}
|
||||||
|
|
||||||
private void handleVectorSync(HttpExchange exchange) throws IOException {
|
private void handleVectorSync(HttpExchange exchange) throws IOException {
|
||||||
String body;
|
String body;
|
||||||
try (InputStream in = exchange.getRequestBody()) {
|
try (InputStream in = exchange.getRequestBody()) {
|
||||||
|
|
@ -81,16 +182,69 @@ class VectorSyncHttpEmbeddingProviderTest {
|
||||||
assertThat(body).contains("\"text\":\"This is a sample text to vectorize\"");
|
assertThat(body).contains("\"text\":\"This is a sample text to vectorize\"");
|
||||||
assertThat(exchange.getRequestHeaders().getFirst("X-Client")).isEqualTo("dip-test");
|
assertThat(exchange.getRequestHeaders().getFirst("X-Client")).isEqualTo("dip-test");
|
||||||
|
|
||||||
String json = "{"
|
respondJson(exchange, """
|
||||||
+ "\"runtime_ms\":472.49,"
|
{
|
||||||
+ "\"vector\":[0.1,0.2,0.3],"
|
"runtime_ms": 472.49,
|
||||||
+ "\"incomplete\":false,"
|
"vector": [0.1, 0.2, 0.3],
|
||||||
+ "\"combined_vector\":null,"
|
"incomplete": false,
|
||||||
+ "\"token_count\":9,"
|
"combined_vector": null,
|
||||||
+ "\"model\":\"intfloat/multilingual-e5-large\","
|
"token_count": 9,
|
||||||
+ "\"max_seq_length\":512"
|
"model": "intfloat/multilingual-e5-large",
|
||||||
+ "}";
|
"max_seq_length": 512
|
||||||
|
}
|
||||||
|
""");
|
||||||
|
}
|
||||||
|
|
||||||
|
private void handleVectorizeBatch(HttpExchange exchange) throws IOException {
|
||||||
|
String body;
|
||||||
|
try (InputStream in = exchange.getRequestBody()) {
|
||||||
|
body = new String(in.readAllBytes(), StandardCharsets.UTF_8);
|
||||||
|
}
|
||||||
|
lastBatchBody.set(body);
|
||||||
|
|
||||||
|
assertThat(exchange.getRequestMethod()).isEqualTo("POST");
|
||||||
|
JsonNode requestBody = objectMapper.readTree(body);
|
||||||
|
assertThat(requestBody.get("model").asText()).isEqualTo("intfloat/multilingual-e5-large");
|
||||||
|
assertThat(requestBody.get("items")).hasSize(2);
|
||||||
|
|
||||||
|
String firstId = requestBody.get("items").get(0).get("id").asText();
|
||||||
|
String secondId = requestBody.get("items").get(1).get("id").asText();
|
||||||
|
|
||||||
|
respondJson(exchange, """
|
||||||
|
{
|
||||||
|
"model": "intfloat/multilingual-e5-large",
|
||||||
|
"count": 2,
|
||||||
|
"results": [
|
||||||
|
{
|
||||||
|
"id": "%s",
|
||||||
|
"vector": [0.1, 0.2, 0.3],
|
||||||
|
"token_count": 5,
|
||||||
|
"runtime_ms": 0.0,
|
||||||
|
"incomplete": false,
|
||||||
|
"combined_vector": null,
|
||||||
|
"truncated": false,
|
||||||
|
"truncate_length": 512,
|
||||||
|
"model": "intfloat/multilingual-e5-large",
|
||||||
|
"max_seq_length": 512
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "%s",
|
||||||
|
"vector": [0.4, 0.5, 0.6],
|
||||||
|
"token_count": 7,
|
||||||
|
"runtime_ms": 0.0,
|
||||||
|
"incomplete": false,
|
||||||
|
"combined_vector": null,
|
||||||
|
"truncated": false,
|
||||||
|
"truncate_length": 512,
|
||||||
|
"model": "intfloat/multilingual-e5-large",
|
||||||
|
"max_seq_length": 512
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
""".formatted(firstId, secondId));
|
||||||
|
}
|
||||||
|
|
||||||
|
private void respondJson(HttpExchange exchange, String json) throws IOException {
|
||||||
byte[] response = json.getBytes(StandardCharsets.UTF_8);
|
byte[] response = json.getBytes(StandardCharsets.UTF_8);
|
||||||
exchange.getResponseHeaders().add("Content-Type", "application/json");
|
exchange.getResponseHeaders().add("Content-Type", "application/json");
|
||||||
exchange.sendResponseHeaders(200, response.length);
|
exchange.sendResponseHeaders(200, response.length);
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,6 @@
|
||||||
package at.procon.dip.embedding.service;
|
package at.procon.dip.embedding.service;
|
||||||
|
|
||||||
import static org.mockito.ArgumentMatchers.any;
|
import at.procon.dip.domain.document.DistanceMetric;
|
||||||
import static org.mockito.ArgumentMatchers.eq;
|
|
||||||
import static org.mockito.Mockito.verify;
|
|
||||||
import static org.mockito.Mockito.when;
|
|
||||||
|
|
||||||
import at.procon.dip.domain.document.entity.Document;
|
import at.procon.dip.domain.document.entity.Document;
|
||||||
import at.procon.dip.domain.document.entity.DocumentEmbedding;
|
import at.procon.dip.domain.document.entity.DocumentEmbedding;
|
||||||
import at.procon.dip.domain.document.entity.DocumentTextRepresentation;
|
import at.procon.dip.domain.document.entity.DocumentTextRepresentation;
|
||||||
|
|
@ -12,20 +8,25 @@ import at.procon.dip.domain.document.repository.DocumentTextRepresentationReposi
|
||||||
import at.procon.dip.embedding.config.EmbeddingProperties;
|
import at.procon.dip.embedding.config.EmbeddingProperties;
|
||||||
import at.procon.dip.embedding.job.entity.EmbeddingJob;
|
import at.procon.dip.embedding.job.entity.EmbeddingJob;
|
||||||
import at.procon.dip.embedding.job.service.EmbeddingJobService;
|
import at.procon.dip.embedding.job.service.EmbeddingJobService;
|
||||||
import at.procon.dip.embedding.model.EmbeddingJobStatus;
|
import at.procon.dip.embedding.model.*;
|
||||||
import at.procon.dip.embedding.model.EmbeddingJobType;
|
import at.procon.dip.embedding.policy.EmbeddingSelectionPolicy;
|
||||||
import at.procon.dip.embedding.model.EmbeddingModelDescriptor;
|
|
||||||
import at.procon.dip.embedding.model.EmbeddingProviderResult;
|
|
||||||
import at.procon.dip.embedding.registry.EmbeddingModelRegistry;
|
import at.procon.dip.embedding.registry.EmbeddingModelRegistry;
|
||||||
import at.procon.dip.domain.document.DistanceMetric;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.UUID;
|
|
||||||
import org.junit.jupiter.api.BeforeEach;
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.extension.ExtendWith;
|
import org.junit.jupiter.api.extension.ExtendWith;
|
||||||
|
import org.mockito.ArgumentCaptor;
|
||||||
import org.mockito.Mock;
|
import org.mockito.Mock;
|
||||||
import org.mockito.junit.jupiter.MockitoExtension;
|
import org.mockito.junit.jupiter.MockitoExtension;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.UUID;
|
||||||
|
|
||||||
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
import static org.mockito.AdditionalMatchers.aryEq;
|
||||||
|
import static org.mockito.ArgumentMatchers.*;
|
||||||
|
import static org.mockito.Mockito.*;
|
||||||
|
|
||||||
@ExtendWith(MockitoExtension.class)
|
@ExtendWith(MockitoExtension.class)
|
||||||
class RepresentationEmbeddingOrchestratorTest {
|
class RepresentationEmbeddingOrchestratorTest {
|
||||||
|
|
||||||
|
|
@ -38,22 +39,27 @@ class RepresentationEmbeddingOrchestratorTest {
|
||||||
@Mock
|
@Mock
|
||||||
private DocumentTextRepresentationRepository representationRepository;
|
private DocumentTextRepresentationRepository representationRepository;
|
||||||
@Mock
|
@Mock
|
||||||
|
private EmbeddingSelectionPolicy selectionPolicy;
|
||||||
|
@Mock
|
||||||
private EmbeddingModelRegistry modelRegistry;
|
private EmbeddingModelRegistry modelRegistry;
|
||||||
|
|
||||||
|
private EmbeddingProperties properties;
|
||||||
private RepresentationEmbeddingOrchestrator orchestrator;
|
private RepresentationEmbeddingOrchestrator orchestrator;
|
||||||
|
|
||||||
@BeforeEach
|
@BeforeEach
|
||||||
void setUp() {
|
void setUp() {
|
||||||
EmbeddingProperties properties = new EmbeddingProperties();
|
properties = new EmbeddingProperties();
|
||||||
properties.setEnabled(true);
|
properties.setEnabled(true);
|
||||||
properties.getJobs().setEnabled(true);
|
properties.getJobs().setEnabled(true);
|
||||||
properties.getJobs().setBatchSize(10);
|
properties.getJobs().setBatchSize(10);
|
||||||
|
properties.getJobs().setExecutionBatchSize(10);
|
||||||
properties.getIndexing().setFallbackMaxInputChars(8192);
|
properties.getIndexing().setFallbackMaxInputChars(8192);
|
||||||
orchestrator = new RepresentationEmbeddingOrchestrator(
|
orchestrator = new RepresentationEmbeddingOrchestrator(
|
||||||
jobService,
|
jobService,
|
||||||
executionService,
|
executionService,
|
||||||
persistenceService,
|
persistenceService,
|
||||||
representationRepository,
|
representationRepository,
|
||||||
|
selectionPolicy,
|
||||||
modelRegistry,
|
modelRegistry,
|
||||||
properties
|
properties
|
||||||
);
|
);
|
||||||
|
|
@ -80,16 +86,17 @@ class RepresentationEmbeddingOrchestratorTest {
|
||||||
.document(document)
|
.document(document)
|
||||||
.textBody("District heating optimization strategy")
|
.textBody("District heating optimization strategy")
|
||||||
.build();
|
.build();
|
||||||
when(representationRepository.findById(representationId)).thenReturn(java.util.Optional.of(representation));
|
EmbeddingModelDescriptor model = new EmbeddingModelDescriptor(
|
||||||
when(modelRegistry.getRequired("mock-search"))
|
"mock-search", "mock-provider", "mock-search", 16,
|
||||||
.thenReturn(new EmbeddingModelDescriptor("mock-search", "mock-provider", "mock-search", 16,
|
DistanceMetric.COSINE, true, false, 4096, true
|
||||||
DistanceMetric.COSINE, true, false, 4096, true));
|
);
|
||||||
|
when(representationRepository.findById(representationId)).thenReturn(Optional.of(representation));
|
||||||
|
when(modelRegistry.getRequired("mock-search")).thenReturn(model);
|
||||||
when(persistenceService.ensurePending(representationId, "mock-search"))
|
when(persistenceService.ensurePending(representationId, "mock-search"))
|
||||||
.thenReturn(DocumentEmbedding.builder().id(embeddingId).build());
|
.thenReturn(DocumentEmbedding.builder().id(embeddingId).build());
|
||||||
when(executionService.embedTexts(eq("mock-search"), any(), any()))
|
when(executionService.embedTexts(eq("mock-search"), eq(EmbeddingUseCase.DOCUMENT), any()))
|
||||||
.thenReturn(new EmbeddingProviderResult(
|
.thenReturn(new EmbeddingProviderResult(
|
||||||
new EmbeddingModelDescriptor("mock-search", "mock-provider", "mock-search", 16,
|
model,
|
||||||
DistanceMetric.COSINE, true, false, 4096, true),
|
|
||||||
List.of(new float[]{0.1f, 0.2f}),
|
List.of(new float[]{0.1f, 0.2f}),
|
||||||
List.of(),
|
List.of(),
|
||||||
"req-1",
|
"req-1",
|
||||||
|
|
@ -102,4 +109,121 @@ class RepresentationEmbeddingOrchestratorTest {
|
||||||
verify(persistenceService).saveCompleted(eq(embeddingId), any(EmbeddingProviderResult.class));
|
verify(persistenceService).saveCompleted(eq(embeddingId), any(EmbeddingProviderResult.class));
|
||||||
verify(jobService).markCompleted(job.getId(), "req-1");
|
verify(jobService).markCompleted(job.getId(), "req-1");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void processNextReadyBatch_should_group_batchable_jobs_and_call_provider_once() {
|
||||||
|
properties.getJobs().setProcessInBatches(true);
|
||||||
|
properties.getJobs().setExecutionBatchSize(10);
|
||||||
|
|
||||||
|
UUID documentId = UUID.randomUUID();
|
||||||
|
UUID representationId1 = UUID.randomUUID();
|
||||||
|
UUID representationId2 = UUID.randomUUID();
|
||||||
|
UUID embeddingId1 = UUID.randomUUID();
|
||||||
|
UUID embeddingId2 = UUID.randomUUID();
|
||||||
|
|
||||||
|
EmbeddingJob job1 = EmbeddingJob.builder()
|
||||||
|
.id(UUID.randomUUID())
|
||||||
|
.documentId(documentId)
|
||||||
|
.representationId(representationId1)
|
||||||
|
.modelKey("e5-default")
|
||||||
|
.jobType(EmbeddingJobType.DOCUMENT_EMBED)
|
||||||
|
.status(EmbeddingJobStatus.IN_PROGRESS)
|
||||||
|
.attemptCount(1)
|
||||||
|
.build();
|
||||||
|
EmbeddingJob job2 = EmbeddingJob.builder()
|
||||||
|
.id(UUID.randomUUID())
|
||||||
|
.documentId(documentId)
|
||||||
|
.representationId(representationId2)
|
||||||
|
.modelKey("e5-default")
|
||||||
|
.jobType(EmbeddingJobType.DOCUMENT_EMBED)
|
||||||
|
.status(EmbeddingJobStatus.IN_PROGRESS)
|
||||||
|
.attemptCount(1)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
when(jobService.claimNextReadyJobs(10)).thenReturn(List.of(job1, job2));
|
||||||
|
|
||||||
|
Document document = Document.builder().id(documentId).build();
|
||||||
|
when(representationRepository.findById(representationId1)).thenReturn(Optional.of(
|
||||||
|
DocumentTextRepresentation.builder().id(representationId1).document(document).textBody("alpha").build()
|
||||||
|
));
|
||||||
|
when(representationRepository.findById(representationId2)).thenReturn(Optional.of(
|
||||||
|
DocumentTextRepresentation.builder().id(representationId2).document(document).textBody("beta").build()
|
||||||
|
));
|
||||||
|
|
||||||
|
EmbeddingModelDescriptor model = new EmbeddingModelDescriptor(
|
||||||
|
"e5-default", "vector-sync-e5", "intfloat/multilingual-e5-large", 3,
|
||||||
|
DistanceMetric.COSINE, true, true, 4096, true
|
||||||
|
);
|
||||||
|
when(modelRegistry.getRequired("e5-default")).thenReturn(model);
|
||||||
|
when(persistenceService.ensurePending(representationId1, "e5-default"))
|
||||||
|
.thenReturn(DocumentEmbedding.builder().id(embeddingId1).build());
|
||||||
|
when(persistenceService.ensurePending(representationId2, "e5-default"))
|
||||||
|
.thenReturn(DocumentEmbedding.builder().id(embeddingId2).build());
|
||||||
|
when(executionService.embedTexts(eq("e5-default"), eq(EmbeddingUseCase.DOCUMENT), any()))
|
||||||
|
.thenReturn(new EmbeddingProviderResult(
|
||||||
|
model,
|
||||||
|
List.of(new float[]{0.1f, 0.2f, 0.3f}, new float[]{0.4f, 0.5f, 0.6f}),
|
||||||
|
List.of(),
|
||||||
|
"batch-req-1",
|
||||||
|
21
|
||||||
|
));
|
||||||
|
|
||||||
|
int processed = orchestrator.processNextReadyBatch();
|
||||||
|
|
||||||
|
assertThat(processed).isEqualTo(2);
|
||||||
|
verify(persistenceService).markProcessing(embeddingId1);
|
||||||
|
verify(persistenceService).markProcessing(embeddingId2);
|
||||||
|
ArgumentCaptor<List<String>> textsCaptor = ArgumentCaptor.forClass(List.class);
|
||||||
|
verify(executionService, times(1)).embedTexts(eq("e5-default"), eq(EmbeddingUseCase.DOCUMENT), textsCaptor.capture());
|
||||||
|
assertThat(textsCaptor.getValue()).containsExactly("alpha", "beta");
|
||||||
|
verify(persistenceService).saveCompleted(eq(embeddingId1), aryEq(new float[]{0.1f, 0.2f, 0.3f}), eq(null));
|
||||||
|
verify(persistenceService).saveCompleted(eq(embeddingId2), aryEq(new float[]{0.4f, 0.5f, 0.6f}), eq(null));
|
||||||
|
verify(jobService).markCompleted(job1.getId(), "batch-req-1");
|
||||||
|
verify(jobService).markCompleted(job2.getId(), "batch-req-1");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void processNextReadyBatch_should_fall_back_to_single_processing_for_non_batch_model() {
|
||||||
|
properties.getJobs().setProcessInBatches(true);
|
||||||
|
|
||||||
|
UUID documentId = UUID.randomUUID();
|
||||||
|
UUID representationId = UUID.randomUUID();
|
||||||
|
UUID embeddingId = UUID.randomUUID();
|
||||||
|
EmbeddingJob job = EmbeddingJob.builder()
|
||||||
|
.id(UUID.randomUUID())
|
||||||
|
.documentId(documentId)
|
||||||
|
.representationId(representationId)
|
||||||
|
.modelKey("mock-search")
|
||||||
|
.jobType(EmbeddingJobType.DOCUMENT_EMBED)
|
||||||
|
.status(EmbeddingJobStatus.IN_PROGRESS)
|
||||||
|
.attemptCount(1)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
when(jobService.claimNextReadyJobs(10)).thenReturn(List.of(job));
|
||||||
|
Document document = Document.builder().id(documentId).build();
|
||||||
|
when(representationRepository.findById(representationId)).thenReturn(Optional.of(
|
||||||
|
DocumentTextRepresentation.builder().id(representationId).document(document).textBody("gamma").build()
|
||||||
|
));
|
||||||
|
EmbeddingModelDescriptor model = new EmbeddingModelDescriptor(
|
||||||
|
"mock-search", "mock-provider", "mock-search", 2,
|
||||||
|
DistanceMetric.COSINE, true, false, 4096, true
|
||||||
|
);
|
||||||
|
when(modelRegistry.getRequired("mock-search")).thenReturn(model);
|
||||||
|
when(persistenceService.ensurePending(representationId, "mock-search"))
|
||||||
|
.thenReturn(DocumentEmbedding.builder().id(embeddingId).build());
|
||||||
|
when(executionService.embedTexts(eq("mock-search"), eq(EmbeddingUseCase.DOCUMENT), any()))
|
||||||
|
.thenReturn(new EmbeddingProviderResult(
|
||||||
|
model,
|
||||||
|
List.of(new float[]{0.7f, 0.8f}),
|
||||||
|
List.of(),
|
||||||
|
"req-2",
|
||||||
|
7
|
||||||
|
));
|
||||||
|
|
||||||
|
orchestrator.processNextReadyBatch();
|
||||||
|
|
||||||
|
verify(executionService, times(1)).embedTexts(eq("mock-search"), eq(EmbeddingUseCase.DOCUMENT), eq(List.of("gamma")));
|
||||||
|
verify(persistenceService, never()).saveCompleted(eq(embeddingId), any(float[].class), eq(null));
|
||||||
|
verify(persistenceService).saveCompleted(eq(embeddingId), any(EmbeddingProviderResult.class));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue