batch embedding support
This commit is contained in:
parent
52330d751d
commit
678db76415
|
|
@ -30,6 +30,14 @@ public class EmbeddingProperties {
|
|||
private Duration readTimeout = Duration.ofSeconds(60);
|
||||
private Map<String, String> headers = new LinkedHashMap<>();
|
||||
private Integer dimensions;
|
||||
private BatchRequestProperties batchRequest = new BatchRequestProperties();
|
||||
}
|
||||
|
||||
/**
 * Provider-level defaults for batch embedding requests.
 *
 * The three values are copied by the provider config resolver into the resolved
 * provider config ({@code batchTruncateText}, {@code batchTruncateLength},
 * {@code batchChunkSize}) and may be overridden per request through provider options.
 */
@Data
|
||||
public static class BatchRequestProperties {
|
||||
/** Sent as {@code truncate_text} in the batch request; off by default. */
private boolean truncateText = false;
|
||||
/** Sent as {@code truncate_length}; the provider rejects values {@code <= 0}. Unit is presumably tokens - TODO confirm against the remote service. */
private int truncateLength = 512;
|
||||
/** Sent as {@code chunk_size}; the provider rejects values {@code <= 0}. NOTE(review): exact server-side meaning is not visible here - confirm. */
private int chunkSize = 20;
|
||||
}
|
||||
|
||||
@Data
|
||||
|
|
@ -59,6 +67,8 @@ public class EmbeddingProperties {
|
|||
public static class JobsProperties {
|
||||
private boolean enabled = false;
|
||||
private int batchSize = 16;
|
||||
private boolean processInBatches = false;
|
||||
private int executionBatchSize = 8;
|
||||
private int maxRetries = 5;
|
||||
private Duration initialRetryDelay = Duration.ofSeconds(30);
|
||||
private Duration maxRetryDelay = Duration.ofHours(6);
|
||||
|
|
|
|||
|
|
@ -13,6 +13,9 @@ public record ResolvedEmbeddingProviderConfig(
|
|||
Duration connectTimeout,
|
||||
Duration readTimeout,
|
||||
Map<String, String> headers,
|
||||
Integer dimensions
|
||||
Integer dimensions,
|
||||
Boolean batchTruncateText,
|
||||
Integer batchTruncateLength,
|
||||
Integer batchChunkSize
|
||||
) {
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,8 +8,6 @@ import at.procon.dip.embedding.provider.EmbeddingProvider;
|
|||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import java.io.IOException;
|
||||
import java.lang.reflect.Field;
|
||||
import java.lang.reflect.Method;
|
||||
import java.net.http.HttpResponse;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
|
|
@ -24,16 +22,6 @@ import org.springframework.stereotype.Component;
|
|||
* Supported endpoints:
|
||||
* POST {baseUrl}/vector-sync - single text
|
||||
* POST {baseUrl}/vectorize-batch - multiple texts
|
||||
*
|
||||
* Batch settings are resolved from provider config if present, otherwise defaults are used.
|
||||
* Supported property keys:
|
||||
* - vectorize-batch.truncate-text
|
||||
* - vectorize-batch.truncate-length
|
||||
* - vectorize-batch.chunk-size
|
||||
*
|
||||
* Also accepted as fallbacks:
|
||||
* - truncate_text / truncate-length / chunk_size
|
||||
* - truncateText / truncateLength / chunkSize
|
||||
*/
|
||||
@Component
|
||||
public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProviderSupport implements EmbeddingProvider {
|
||||
|
|
@ -105,7 +93,7 @@ public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProvid
|
|||
try {
|
||||
return request.texts().size() == 1
|
||||
? executeSingle(providerConfig, model, request.texts().getFirst())
|
||||
: executeBatch(providerConfig, model, request.texts());
|
||||
: executeBatch(providerConfig, model, request);
|
||||
} catch (InterruptedException e) {
|
||||
Thread.currentThread().interrupt();
|
||||
throw new IllegalStateException("Embedding provider call interrupted", e);
|
||||
|
|
@ -137,22 +125,20 @@ public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProvid
|
|||
|
||||
private EmbeddingProviderResult executeBatch(ResolvedEmbeddingProviderConfig providerConfig,
|
||||
EmbeddingModelDescriptor model,
|
||||
List<String> texts) throws IOException, InterruptedException {
|
||||
boolean truncateText = resolveBooleanProperty(providerConfig, TRUNCATE_TEXT_KEYS, DEFAULT_TRUNCATE_TEXT);
|
||||
int truncateLength = resolveIntProperty(providerConfig, TRUNCATE_LENGTH_KEYS, DEFAULT_TRUNCATE_LENGTH);
|
||||
int chunkSize = resolveIntProperty(providerConfig, CHUNK_SIZE_KEYS, DEFAULT_CHUNK_SIZE);
|
||||
EmbeddingRequest request) throws IOException, InterruptedException {
|
||||
BatchRequestSettings settings = resolveBatchRequestSettings(providerConfig, request.providerOptions());
|
||||
|
||||
if (truncateLength <= 0) {
|
||||
if (settings.truncateLength() <= 0) {
|
||||
throw new IllegalArgumentException("Batch truncate length must be > 0");
|
||||
}
|
||||
if (chunkSize <= 0) {
|
||||
if (settings.chunkSize() <= 0) {
|
||||
throw new IllegalArgumentException("Batch chunk size must be > 0");
|
||||
}
|
||||
|
||||
List<String> requestOrder = new ArrayList<>(texts.size());
|
||||
List<VectorizeBatchItemRequest> items = new ArrayList<>(texts.size());
|
||||
List<String> requestOrder = new ArrayList<>(request.texts().size());
|
||||
List<VectorizeBatchItemRequest> items = new ArrayList<>(request.texts().size());
|
||||
|
||||
for (String text : texts) {
|
||||
for (String text : request.texts()) {
|
||||
String id = UUID.randomUUID().toString();
|
||||
requestOrder.add(id);
|
||||
items.add(new VectorizeBatchItemRequest(id, text));
|
||||
|
|
@ -163,9 +149,9 @@ public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProvid
|
|||
"/vectorize-batch",
|
||||
new VectorizeBatchRequest(
|
||||
model.providerModelKey(),
|
||||
truncateText,
|
||||
truncateLength,
|
||||
chunkSize,
|
||||
settings.truncateText(),
|
||||
settings.truncateLength(),
|
||||
settings.chunkSize(),
|
||||
items
|
||||
)
|
||||
);
|
||||
|
|
@ -180,7 +166,7 @@ public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProvid
|
|||
resultById.put(result.id, result);
|
||||
}
|
||||
|
||||
List<float[]> vectors = new ArrayList<>(texts.size());
|
||||
List<float[]> vectors = new ArrayList<>(request.texts().size());
|
||||
int totalTokenCount = 0;
|
||||
boolean hasAnyTokenCount = false;
|
||||
|
||||
|
|
@ -207,6 +193,72 @@ public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProvid
|
|||
);
|
||||
}
|
||||
|
||||
private BatchRequestSettings resolveBatchRequestSettings(ResolvedEmbeddingProviderConfig providerConfig,
|
||||
Map<String, Object> providerOptions) {
|
||||
boolean truncateText = resolveBooleanOption(
|
||||
providerOptions,
|
||||
TRUNCATE_TEXT_KEYS,
|
||||
providerConfig.batchTruncateText() != null ? providerConfig.batchTruncateText() : DEFAULT_TRUNCATE_TEXT
|
||||
);
|
||||
int truncateLength = resolveIntOption(
|
||||
providerOptions,
|
||||
TRUNCATE_LENGTH_KEYS,
|
||||
providerConfig.batchTruncateLength() != null ? providerConfig.batchTruncateLength() : DEFAULT_TRUNCATE_LENGTH
|
||||
);
|
||||
int chunkSize = resolveIntOption(
|
||||
providerOptions,
|
||||
CHUNK_SIZE_KEYS,
|
||||
providerConfig.batchChunkSize() != null ? providerConfig.batchChunkSize() : DEFAULT_CHUNK_SIZE
|
||||
);
|
||||
return new BatchRequestSettings(truncateText, truncateLength, chunkSize);
|
||||
}
|
||||
|
||||
private boolean resolveBooleanOption(Map<String, Object> providerOptions,
|
||||
List<String> keys,
|
||||
boolean defaultValue) {
|
||||
Object raw = resolveOption(providerOptions, keys);
|
||||
if (raw == null) {
|
||||
return defaultValue;
|
||||
}
|
||||
if (raw instanceof Boolean booleanValue) {
|
||||
return booleanValue;
|
||||
}
|
||||
String normalized = String.valueOf(raw).trim();
|
||||
if (normalized.isEmpty()) {
|
||||
return defaultValue;
|
||||
}
|
||||
return Boolean.parseBoolean(normalized);
|
||||
}
|
||||
|
||||
private int resolveIntOption(Map<String, Object> providerOptions,
|
||||
List<String> keys,
|
||||
int defaultValue) {
|
||||
Object raw = resolveOption(providerOptions, keys);
|
||||
if (raw == null) {
|
||||
return defaultValue;
|
||||
}
|
||||
if (raw instanceof Number number) {
|
||||
return number.intValue();
|
||||
}
|
||||
String normalized = String.valueOf(raw).trim();
|
||||
if (normalized.isEmpty()) {
|
||||
return defaultValue;
|
||||
}
|
||||
return Integer.parseInt(normalized);
|
||||
}
|
||||
|
||||
private Object resolveOption(Map<String, Object> providerOptions, List<String> keys) {
|
||||
if (providerOptions == null || providerOptions.isEmpty()) {
|
||||
return null;
|
||||
}
|
||||
for (String key : keys) {
|
||||
if (providerOptions.containsKey(key)) {
|
||||
return providerOptions.get(key);
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private float[] extractVector(List<Float> vector,
|
||||
List<Float> combinedVector,
|
||||
EmbeddingModelDescriptor model) {
|
||||
|
|
@ -230,123 +282,7 @@ public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProvid
|
|||
return resolved;
|
||||
}
|
||||
|
||||
private boolean resolveBooleanProperty(ResolvedEmbeddingProviderConfig providerConfig,
|
||||
List<String> keys,
|
||||
boolean defaultValue) {
|
||||
Object raw = resolveProviderConfigValue(providerConfig, keys);
|
||||
if (raw == null) {
|
||||
return defaultValue;
|
||||
}
|
||||
if (raw instanceof Boolean b) {
|
||||
return b;
|
||||
}
|
||||
String normalized = String.valueOf(raw).trim();
|
||||
if (normalized.isEmpty()) {
|
||||
return defaultValue;
|
||||
}
|
||||
return Boolean.parseBoolean(normalized);
|
||||
}
|
||||
|
||||
private int resolveIntProperty(ResolvedEmbeddingProviderConfig providerConfig,
|
||||
List<String> keys,
|
||||
int defaultValue) {
|
||||
Object raw = resolveProviderConfigValue(providerConfig, keys);
|
||||
if (raw == null) {
|
||||
return defaultValue;
|
||||
}
|
||||
if (raw instanceof Number n) {
|
||||
return n.intValue();
|
||||
}
|
||||
String normalized = String.valueOf(raw).trim();
|
||||
if (normalized.isEmpty()) {
|
||||
return defaultValue;
|
||||
}
|
||||
return Integer.parseInt(normalized);
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private Object resolveProviderConfigValue(ResolvedEmbeddingProviderConfig providerConfig,
|
||||
List<String> keys) {
|
||||
List<Object> containers = new ArrayList<>();
|
||||
containers.add(providerConfig);
|
||||
|
||||
addIfNonNull(containers, invokeNoArg(providerConfig, "properties"));
|
||||
addIfNonNull(containers, invokeNoArg(providerConfig, "providerProperties"));
|
||||
addIfNonNull(containers, invokeNoArg(providerConfig, "config"));
|
||||
addIfNonNull(containers, invokeNoArg(providerConfig, "settings"));
|
||||
addIfNonNull(containers, invokeNoArg(providerConfig, "options"));
|
||||
addIfNonNull(containers, readField(providerConfig, "properties"));
|
||||
addIfNonNull(containers, readField(providerConfig, "providerProperties"));
|
||||
addIfNonNull(containers, readField(providerConfig, "config"));
|
||||
addIfNonNull(containers, readField(providerConfig, "settings"));
|
||||
addIfNonNull(containers, readField(providerConfig, "options"));
|
||||
|
||||
for (Object container : containers) {
|
||||
if (container instanceof Map<?, ?> map) {
|
||||
for (String key : keys) {
|
||||
if (map.containsKey(key)) {
|
||||
return map.get(key);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (String key : keys) {
|
||||
Object value = invokeStringArg(container, "get", key);
|
||||
if (value != null) {
|
||||
return value;
|
||||
}
|
||||
value = invokeStringArg(container, "getProperty", key);
|
||||
if (value != null) {
|
||||
return value;
|
||||
}
|
||||
value = invokeStringArg(container, "property", key);
|
||||
if (value != null) {
|
||||
return value;
|
||||
}
|
||||
value = invokeStringArg(container, "option", key);
|
||||
if (value != null) {
|
||||
return value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
private void addIfNonNull(List<Object> containers, Object value) {
|
||||
if (value != null) {
|
||||
containers.add(value);
|
||||
}
|
||||
}
|
||||
|
||||
private Object invokeNoArg(Object target, String methodName) {
|
||||
try {
|
||||
Method method = target.getClass().getMethod(methodName);
|
||||
method.setAccessible(true);
|
||||
return method.invoke(target);
|
||||
} catch (Exception ignored) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private Object invokeStringArg(Object target, String methodName, String arg) {
|
||||
try {
|
||||
Method method = target.getClass().getMethod(methodName, String.class);
|
||||
method.setAccessible(true);
|
||||
return method.invoke(target, arg);
|
||||
} catch (Exception ignored) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private Object readField(Object target, String fieldName) {
|
||||
try {
|
||||
Field field = target.getClass().getDeclaredField(fieldName);
|
||||
field.setAccessible(true);
|
||||
return field.get(target);
|
||||
} catch (Exception ignored) {
|
||||
return null;
|
||||
}
|
||||
/**
 * Effective batch request settings for one call: the values that end up in the
 * {@code /vectorize-batch} request after per-request options, provider config and
 * built-in defaults have been merged.
 */
private record BatchRequestSettings(boolean truncateText, int truncateLength, int chunkSize) {
|
||||
}
|
||||
|
||||
private record VectorSyncRequest(
|
||||
|
|
|
|||
|
|
@ -18,6 +18,10 @@ public class EmbeddingProviderConfigResolver {
|
|||
throw new IllegalArgumentException("Unknown embedding provider config key: " + providerConfigKey);
|
||||
}
|
||||
|
||||
EmbeddingProperties.BatchRequestProperties batchRequest = provider.getBatchRequest() == null
|
||||
? new EmbeddingProperties.BatchRequestProperties()
|
||||
: provider.getBatchRequest();
|
||||
|
||||
return ResolvedEmbeddingProviderConfig.builder()
|
||||
.key(providerConfigKey)
|
||||
.providerType(provider.getType())
|
||||
|
|
@ -27,6 +31,9 @@ public class EmbeddingProviderConfigResolver {
|
|||
.readTimeout(provider.getReadTimeout())
|
||||
.headers(provider.getHeaders() == null ? Map.of() : Map.copyOf(provider.getHeaders()))
|
||||
.dimensions(provider.getDimensions())
|
||||
.batchTruncateText(batchRequest.isTruncateText())
|
||||
.batchTruncateLength(batchRequest.getTruncateLength())
|
||||
.batchChunkSize(batchRequest.getChunkSize())
|
||||
.build();
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ import at.procon.dip.domain.document.repository.DocumentEmbeddingRepository;
|
|||
import at.procon.dip.domain.document.repository.DocumentTextRepresentationRepository;
|
||||
import at.procon.dip.domain.document.service.DocumentEmbeddingService;
|
||||
import at.procon.dip.embedding.model.EmbeddingProviderResult;
|
||||
import at.procon.dip.embedding.support.EmbeddingVectorCodec;
|
||||
import java.time.OffsetDateTime;
|
||||
import java.util.UUID;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
|
|
@ -44,11 +43,17 @@ public class EmbeddingPersistenceService {
|
|||
if (result.vectors() == null || result.vectors().isEmpty()) {
|
||||
throw new IllegalArgumentException("Embedding provider result contains no vectors");
|
||||
}
|
||||
float[] vector = result.vectors().getFirst();
|
||||
saveCompleted(embeddingId, result.vectors().getFirst(), result.tokenCount());
|
||||
}
|
||||
|
||||
public void saveCompleted(UUID embeddingId, float[] vector, Integer tokenCount) {
|
||||
if (vector == null || vector.length == 0) {
|
||||
throw new IllegalArgumentException("Embedding vector must not be empty");
|
||||
}
|
||||
embeddingRepository.updateEmbeddingVector(
|
||||
embeddingId,
|
||||
vector, //EmbeddingVectorCodec.toPgVector(vector),
|
||||
result.tokenCount(),
|
||||
vector,
|
||||
tokenCount,
|
||||
vector.length
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,11 +7,14 @@ import at.procon.dip.embedding.config.EmbeddingProperties;
|
|||
import at.procon.dip.embedding.job.entity.EmbeddingJob;
|
||||
import at.procon.dip.embedding.job.service.EmbeddingJobService;
|
||||
import at.procon.dip.embedding.model.EmbeddingJobType;
|
||||
import at.procon.dip.embedding.model.EmbeddingModelDescriptor;
|
||||
import at.procon.dip.embedding.model.EmbeddingProviderResult;
|
||||
import at.procon.dip.embedding.model.EmbeddingUseCase;
|
||||
import at.procon.dip.embedding.policy.EmbeddingProfile;
|
||||
import at.procon.dip.embedding.policy.EmbeddingSelectionPolicy;
|
||||
import at.procon.dip.embedding.registry.EmbeddingModelRegistry;
|
||||
import java.util.ArrayList;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
|
|
@ -63,25 +66,138 @@ public class RepresentationEmbeddingOrchestrator {
|
|||
}
|
||||
|
||||
List<EmbeddingJob> jobs = jobService.claimNextReadyJobs(embeddingProperties.getJobs().getBatchSize());
|
||||
for (EmbeddingJob job : jobs) {
|
||||
processClaimedJob(job);
|
||||
if (jobs.isEmpty()) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (embeddingProperties.getJobs().isProcessInBatches()) {
|
||||
processClaimedJobsInBatches(jobs);
|
||||
} else {
|
||||
jobs.forEach(this::processClaimedJobSafely);
|
||||
}
|
||||
return jobs.size();
|
||||
}
|
||||
|
||||
@Transactional
|
||||
public void processClaimedJob(EmbeddingJob job) {
|
||||
EmbeddingModelDescriptor model = modelRegistry.getRequired(job.getModelKey());
|
||||
PreparedEmbedding prepared = prepareEmbedding(job, model);
|
||||
if (prepared == null) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
EmbeddingProviderResult result = executionService.embedTexts(
|
||||
job.getModelKey(),
|
||||
EmbeddingUseCase.DOCUMENT,
|
||||
List.of(prepared.text())
|
||||
);
|
||||
persistenceService.saveCompleted(prepared.embeddingId(), result);
|
||||
jobService.markCompleted(job.getId(), result.providerRequestId());
|
||||
} catch (RuntimeException ex) {
|
||||
persistenceService.markFailed(prepared.embeddingId(), ex.getMessage());
|
||||
jobService.markFailed(job.getId(), ex.getMessage(), true);
|
||||
throw ex;
|
||||
}
|
||||
}
|
||||
|
||||
private void processClaimedJobsInBatches(List<EmbeddingJob> jobs) {
|
||||
LinkedHashMap<String, List<EmbeddingJob>> jobsByModelKey = new LinkedHashMap<>();
|
||||
for (EmbeddingJob job : jobs) {
|
||||
jobsByModelKey.computeIfAbsent(job.getModelKey(), ignored -> new ArrayList<>()).add(job);
|
||||
}
|
||||
|
||||
int executionBatchSize = Math.max(1, embeddingProperties.getJobs().getExecutionBatchSize());
|
||||
for (var entry : jobsByModelKey.entrySet()) {
|
||||
EmbeddingModelDescriptor model = modelRegistry.getRequired(entry.getKey());
|
||||
if (!model.supportsBatch()) {
|
||||
entry.getValue().forEach(this::processClaimedJobSafely);
|
||||
continue;
|
||||
}
|
||||
|
||||
List<EmbeddingJob> sameModelJobs = entry.getValue();
|
||||
for (int start = 0; start < sameModelJobs.size(); start += executionBatchSize) {
|
||||
List<EmbeddingJob> partition = sameModelJobs.subList(start, Math.min(start + executionBatchSize, sameModelJobs.size()));
|
||||
if (partition.size() == 1) {
|
||||
processClaimedJobSafely(partition.getFirst());
|
||||
} else {
|
||||
processClaimedBatchSafely(partition, model);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void processClaimedBatchSafely(List<EmbeddingJob> jobs, EmbeddingModelDescriptor model) {
|
||||
try {
|
||||
processClaimedBatch(jobs, model);
|
||||
} catch (RuntimeException ex) {
|
||||
log.warn("Failed to process embedding batch for model {} ({} jobs): {}",
|
||||
model.modelKey(), jobs.size(), ex.getMessage(), ex);
|
||||
}
|
||||
}
|
||||
|
||||
private void processClaimedJobSafely(EmbeddingJob job) {
|
||||
try {
|
||||
processClaimedJob(job);
|
||||
} catch (RuntimeException ex) {
|
||||
log.warn("Failed to process embedding job {} for representation {}: {}",
|
||||
job.getId(), job.getRepresentationId(), ex.getMessage(), ex);
|
||||
}
|
||||
}
|
||||
|
||||
private void processClaimedBatch(List<EmbeddingJob> jobs, EmbeddingModelDescriptor model) {
|
||||
List<PreparedEmbedding> preparedItems = new ArrayList<>(jobs.size());
|
||||
for (EmbeddingJob job : jobs) {
|
||||
PreparedEmbedding prepared = prepareEmbedding(job, model);
|
||||
if (prepared != null) {
|
||||
preparedItems.add(prepared);
|
||||
}
|
||||
}
|
||||
|
||||
if (preparedItems.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
EmbeddingProviderResult result = executionService.embedTexts(
|
||||
model.modelKey(),
|
||||
EmbeddingUseCase.DOCUMENT,
|
||||
preparedItems.stream().map(PreparedEmbedding::text).toList()
|
||||
);
|
||||
|
||||
if (result.vectors() == null || result.vectors().size() != preparedItems.size()) {
|
||||
throw new IllegalStateException(
|
||||
"Embedding provider returned %d vectors for %d batch items"
|
||||
.formatted(result.vectors() == null ? 0 : result.vectors().size(), preparedItems.size())
|
||||
);
|
||||
}
|
||||
|
||||
for (int i = 0; i < preparedItems.size(); i++) {
|
||||
PreparedEmbedding prepared = preparedItems.get(i);
|
||||
persistenceService.saveCompleted(prepared.embeddingId(), result.vectors().get(i), null);
|
||||
jobService.markCompleted(prepared.job().getId(), result.providerRequestId());
|
||||
}
|
||||
} catch (RuntimeException ex) {
|
||||
for (PreparedEmbedding prepared : preparedItems) {
|
||||
persistenceService.markFailed(prepared.embeddingId(), ex.getMessage());
|
||||
jobService.markFailed(prepared.job().getId(), ex.getMessage(), true);
|
||||
}
|
||||
throw ex;
|
||||
}
|
||||
}
|
||||
|
||||
private PreparedEmbedding prepareEmbedding(EmbeddingJob job, EmbeddingModelDescriptor model) {
|
||||
DocumentTextRepresentation representation = representationRepository.findById(job.getRepresentationId())
|
||||
.orElseThrow(() -> new IllegalArgumentException("Unknown representation id: " + job.getRepresentationId()));
|
||||
|
||||
String text = representation.getTextBody();
|
||||
if (text == null || text.isBlank()) {
|
||||
jobService.markFailed(job.getId(), "No text representation available", false);
|
||||
return;
|
||||
return null;
|
||||
}
|
||||
|
||||
int maxChars = modelRegistry.getRequired(job.getModelKey()).maxInputChars() != null
|
||||
? modelRegistry.getRequired(job.getModelKey()).maxInputChars()
|
||||
int maxChars = model.maxInputChars() != null
|
||||
? model.maxInputChars()
|
||||
: embeddingProperties.getIndexing().getFallbackMaxInputChars();
|
||||
if (text.length() > maxChars) {
|
||||
text = text.substring(0, maxChars);
|
||||
|
|
@ -89,19 +205,9 @@ public class RepresentationEmbeddingOrchestrator {
|
|||
|
||||
DocumentEmbedding embedding = persistenceService.ensurePending(representation.getId(), job.getModelKey());
|
||||
persistenceService.markProcessing(embedding.getId());
|
||||
return new PreparedEmbedding(job, embedding.getId(), text);
|
||||
}
|
||||
|
||||
try {
|
||||
EmbeddingProviderResult result = executionService.embedTexts(
|
||||
job.getModelKey(),
|
||||
EmbeddingUseCase.DOCUMENT,
|
||||
List.of(text)
|
||||
);
|
||||
persistenceService.saveCompleted(embedding.getId(), result);
|
||||
jobService.markCompleted(job.getId(), result.providerRequestId());
|
||||
} catch (RuntimeException ex) {
|
||||
persistenceService.markFailed(embedding.getId(), ex.getMessage());
|
||||
jobService.markFailed(job.getId(), ex.getMessage(), true);
|
||||
throw ex;
|
||||
}
|
||||
/**
 * A claimed job that is ready to be embedded: the job itself, the id of the embedding
 * row that will receive the vector, and the text to send to the provider.
 */
private record PreparedEmbedding(EmbeddingJob job, UUID embeddingId, String text) {
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,12 @@
|
|||
dip:
|
||||
embedding:
|
||||
enabled: true
|
||||
|
||||
jobs:
|
||||
enabled: true
|
||||
process-in-batches: true
|
||||
execution-batch-size: 20
|
||||
|
||||
default-document-model: e5-default
|
||||
default-query-model: e5-default
|
||||
|
||||
|
|
@ -12,6 +18,10 @@ dip:
|
|||
read-timeout: 60s
|
||||
headers:
|
||||
X-Client: dip
|
||||
batch-request:
|
||||
truncate-text: false
|
||||
truncate-length: 512
|
||||
chunk-size: 20
|
||||
|
||||
models:
|
||||
e5-default:
|
||||
|
|
@ -20,6 +30,6 @@ dip:
|
|||
dimensions: 1024
|
||||
distance-metric: COSINE
|
||||
supports-query-embedding-mode: true
|
||||
supports-batch: false
|
||||
supports-batch: true
|
||||
max-input-chars: 8192
|
||||
active: true
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import at.procon.dip.embedding.model.EmbeddingModelDescriptor;
|
|||
import at.procon.dip.embedding.model.EmbeddingRequest;
|
||||
import at.procon.dip.embedding.model.EmbeddingUseCase;
|
||||
import at.procon.dip.embedding.model.ResolvedEmbeddingProviderConfig;
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.sun.net.httpserver.HttpExchange;
|
||||
import com.sun.net.httpserver.HttpServer;
|
||||
|
|
@ -17,12 +18,15 @@ import java.nio.charset.StandardCharsets;
|
|||
import java.time.Duration;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class VectorSyncHttpEmbeddingProviderTest {
|
||||
|
||||
private final ObjectMapper objectMapper = new ObjectMapper();
|
||||
private HttpServer server;
|
||||
private final AtomicReference<String> lastBatchBody = new AtomicReference<>();
|
||||
|
||||
@AfterEach
|
||||
void tearDown() {
|
||||
|
|
@ -37,13 +41,16 @@ class VectorSyncHttpEmbeddingProviderTest {
|
|||
server.createContext("/vector-sync", this::handleVectorSync);
|
||||
server.start();
|
||||
|
||||
var provider = new VectorSyncHttpEmbeddingProvider(new ObjectMapper());
|
||||
var provider = new VectorSyncHttpEmbeddingProvider(objectMapper);
|
||||
var config = ResolvedEmbeddingProviderConfig.builder()
|
||||
.key("vector-sync-local")
|
||||
.providerType("http-vector-sync")
|
||||
.baseUrl("http://localhost:" + server.getAddress().getPort())
|
||||
.readTimeout(Duration.ofSeconds(5))
|
||||
.headers(Map.of("X-Client", "dip-test"))
|
||||
.batchTruncateText(false)
|
||||
.batchTruncateLength(512)
|
||||
.batchChunkSize(20)
|
||||
.build();
|
||||
var model = new EmbeddingModelDescriptor(
|
||||
"e5-default",
|
||||
|
|
@ -70,6 +77,100 @@ class VectorSyncHttpEmbeddingProviderTest {
|
|||
assertThat(result.tokenCount()).isEqualTo(9);
|
||||
}
|
||||
|
||||
@Test
|
||||
void shouldCallVectorizeBatchEndpointUsingConfiguredBatchRequestSettings() throws Exception {
|
||||
server = HttpServer.create(new InetSocketAddress(0), 0);
|
||||
server.createContext("/vectorize-batch", this::handleVectorizeBatch);
|
||||
server.start();
|
||||
|
||||
var provider = new VectorSyncHttpEmbeddingProvider(objectMapper);
|
||||
var config = ResolvedEmbeddingProviderConfig.builder()
|
||||
.key("vector-sync-local")
|
||||
.providerType("http-vector-sync")
|
||||
.baseUrl("http://localhost:" + server.getAddress().getPort())
|
||||
.readTimeout(Duration.ofSeconds(5))
|
||||
.headers(Map.of("X-Client", "dip-test"))
|
||||
.batchTruncateText(true)
|
||||
.batchTruncateLength(768)
|
||||
.batchChunkSize(33)
|
||||
.build();
|
||||
var model = new EmbeddingModelDescriptor(
|
||||
"e5-default",
|
||||
"vector-sync-local",
|
||||
"intfloat/multilingual-e5-large",
|
||||
3,
|
||||
DistanceMetric.COSINE,
|
||||
true,
|
||||
true,
|
||||
8192,
|
||||
true
|
||||
);
|
||||
var request = EmbeddingRequest.builder()
|
||||
.modelKey("e5-default")
|
||||
.useCase(EmbeddingUseCase.DOCUMENT)
|
||||
.texts(List.of("First text", "Second text"))
|
||||
.providerOptions(Map.of())
|
||||
.build();
|
||||
|
||||
var result = provider.embedDocuments(config, model, request);
|
||||
|
||||
assertThat(result.vectors()).hasSize(2);
|
||||
assertThat(result.vectors().get(0)).containsExactly(0.1f, 0.2f, 0.3f);
|
||||
assertThat(result.vectors().get(1)).containsExactly(0.4f, 0.5f, 0.6f);
|
||||
assertThat(result.tokenCount()).isEqualTo(12);
|
||||
|
||||
JsonNode requestBody = objectMapper.readTree(lastBatchBody.get());
|
||||
assertThat(requestBody.get("truncate_text").asBoolean()).isTrue();
|
||||
assertThat(requestBody.get("truncate_length").asInt()).isEqualTo(768);
|
||||
assertThat(requestBody.get("chunk_size").asInt()).isEqualTo(33);
|
||||
}
|
||||
|
||||
@Test
|
||||
void shouldAllowPerRequestBatchOverrides() throws Exception {
|
||||
server = HttpServer.create(new InetSocketAddress(0), 0);
|
||||
server.createContext("/vectorize-batch", this::handleVectorizeBatch);
|
||||
server.start();
|
||||
|
||||
var provider = new VectorSyncHttpEmbeddingProvider(objectMapper);
|
||||
var config = ResolvedEmbeddingProviderConfig.builder()
|
||||
.key("vector-sync-local")
|
||||
.providerType("http-vector-sync")
|
||||
.baseUrl("http://localhost:" + server.getAddress().getPort())
|
||||
.readTimeout(Duration.ofSeconds(5))
|
||||
.batchTruncateText(false)
|
||||
.batchTruncateLength(512)
|
||||
.batchChunkSize(20)
|
||||
.build();
|
||||
var model = new EmbeddingModelDescriptor(
|
||||
"e5-default",
|
||||
"vector-sync-local",
|
||||
"intfloat/multilingual-e5-large",
|
||||
3,
|
||||
DistanceMetric.COSINE,
|
||||
true,
|
||||
true,
|
||||
8192,
|
||||
true
|
||||
);
|
||||
var request = EmbeddingRequest.builder()
|
||||
.modelKey("e5-default")
|
||||
.useCase(EmbeddingUseCase.DOCUMENT)
|
||||
.texts(List.of("First text", "Second text"))
|
||||
.providerOptions(Map.of(
|
||||
"truncate_text", true,
|
||||
"truncate_length", 1024,
|
||||
"chunk_size", 44
|
||||
))
|
||||
.build();
|
||||
|
||||
provider.embedDocuments(config, model, request);
|
||||
|
||||
JsonNode requestBody = objectMapper.readTree(lastBatchBody.get());
|
||||
assertThat(requestBody.get("truncate_text").asBoolean()).isTrue();
|
||||
assertThat(requestBody.get("truncate_length").asInt()).isEqualTo(1024);
|
||||
assertThat(requestBody.get("chunk_size").asInt()).isEqualTo(44);
|
||||
}
|
||||
|
||||
private void handleVectorSync(HttpExchange exchange) throws IOException {
|
||||
String body;
|
||||
try (InputStream in = exchange.getRequestBody()) {
|
||||
|
|
@ -81,16 +182,69 @@ class VectorSyncHttpEmbeddingProviderTest {
|
|||
assertThat(body).contains("\"text\":\"This is a sample text to vectorize\"");
|
||||
assertThat(exchange.getRequestHeaders().getFirst("X-Client")).isEqualTo("dip-test");
|
||||
|
||||
String json = "{"
|
||||
+ "\"runtime_ms\":472.49,"
|
||||
+ "\"vector\":[0.1,0.2,0.3],"
|
||||
+ "\"incomplete\":false,"
|
||||
+ "\"combined_vector\":null,"
|
||||
+ "\"token_count\":9,"
|
||||
+ "\"model\":\"intfloat/multilingual-e5-large\","
|
||||
+ "\"max_seq_length\":512"
|
||||
+ "}";
|
||||
respondJson(exchange, """
|
||||
{
|
||||
"runtime_ms": 472.49,
|
||||
"vector": [0.1, 0.2, 0.3],
|
||||
"incomplete": false,
|
||||
"combined_vector": null,
|
||||
"token_count": 9,
|
||||
"model": "intfloat/multilingual-e5-large",
|
||||
"max_seq_length": 512
|
||||
}
|
||||
""");
|
||||
}
|
||||
|
||||
private void handleVectorizeBatch(HttpExchange exchange) throws IOException {
|
||||
String body;
|
||||
try (InputStream in = exchange.getRequestBody()) {
|
||||
body = new String(in.readAllBytes(), StandardCharsets.UTF_8);
|
||||
}
|
||||
lastBatchBody.set(body);
|
||||
|
||||
assertThat(exchange.getRequestMethod()).isEqualTo("POST");
|
||||
JsonNode requestBody = objectMapper.readTree(body);
|
||||
assertThat(requestBody.get("model").asText()).isEqualTo("intfloat/multilingual-e5-large");
|
||||
assertThat(requestBody.get("items")).hasSize(2);
|
||||
|
||||
String firstId = requestBody.get("items").get(0).get("id").asText();
|
||||
String secondId = requestBody.get("items").get(1).get("id").asText();
|
||||
|
||||
respondJson(exchange, """
|
||||
{
|
||||
"model": "intfloat/multilingual-e5-large",
|
||||
"count": 2,
|
||||
"results": [
|
||||
{
|
||||
"id": "%s",
|
||||
"vector": [0.1, 0.2, 0.3],
|
||||
"token_count": 5,
|
||||
"runtime_ms": 0.0,
|
||||
"incomplete": false,
|
||||
"combined_vector": null,
|
||||
"truncated": false,
|
||||
"truncate_length": 512,
|
||||
"model": "intfloat/multilingual-e5-large",
|
||||
"max_seq_length": 512
|
||||
},
|
||||
{
|
||||
"id": "%s",
|
||||
"vector": [0.4, 0.5, 0.6],
|
||||
"token_count": 7,
|
||||
"runtime_ms": 0.0,
|
||||
"incomplete": false,
|
||||
"combined_vector": null,
|
||||
"truncated": false,
|
||||
"truncate_length": 512,
|
||||
"model": "intfloat/multilingual-e5-large",
|
||||
"max_seq_length": 512
|
||||
}
|
||||
]
|
||||
}
|
||||
""".formatted(firstId, secondId));
|
||||
}
|
||||
|
||||
private void respondJson(HttpExchange exchange, String json) throws IOException {
|
||||
byte[] response = json.getBytes(StandardCharsets.UTF_8);
|
||||
exchange.getResponseHeaders().add("Content-Type", "application/json");
|
||||
exchange.sendResponseHeaders(200, response.length);
|
||||
|
|
|
|||
|
|
@ -1,10 +1,6 @@
|
|||
package at.procon.dip.embedding.service;
|
||||
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import at.procon.dip.domain.document.DistanceMetric;
|
||||
import at.procon.dip.domain.document.entity.Document;
|
||||
import at.procon.dip.domain.document.entity.DocumentEmbedding;
|
||||
import at.procon.dip.domain.document.entity.DocumentTextRepresentation;
|
||||
|
|
@ -12,20 +8,25 @@ import at.procon.dip.domain.document.repository.DocumentTextRepresentationReposi
|
|||
import at.procon.dip.embedding.config.EmbeddingProperties;
|
||||
import at.procon.dip.embedding.job.entity.EmbeddingJob;
|
||||
import at.procon.dip.embedding.job.service.EmbeddingJobService;
|
||||
import at.procon.dip.embedding.model.EmbeddingJobStatus;
|
||||
import at.procon.dip.embedding.model.EmbeddingJobType;
|
||||
import at.procon.dip.embedding.model.EmbeddingModelDescriptor;
|
||||
import at.procon.dip.embedding.model.EmbeddingProviderResult;
|
||||
import at.procon.dip.embedding.model.*;
|
||||
import at.procon.dip.embedding.policy.EmbeddingSelectionPolicy;
|
||||
import at.procon.dip.embedding.registry.EmbeddingModelRegistry;
|
||||
import at.procon.dip.domain.document.DistanceMetric;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.mockito.Mock;
|
||||
import org.mockito.junit.jupiter.MockitoExtension;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.UUID;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.mockito.AdditionalMatchers.aryEq;
|
||||
import static org.mockito.ArgumentMatchers.*;
|
||||
import static org.mockito.Mockito.*;
|
||||
|
||||
@ExtendWith(MockitoExtension.class)
|
||||
class RepresentationEmbeddingOrchestratorTest {
|
||||
|
||||
|
|
@ -38,22 +39,27 @@ class RepresentationEmbeddingOrchestratorTest {
|
|||
@Mock
|
||||
private DocumentTextRepresentationRepository representationRepository;
|
||||
@Mock
|
||||
private EmbeddingSelectionPolicy selectionPolicy;
|
||||
@Mock
|
||||
private EmbeddingModelRegistry modelRegistry;
|
||||
|
||||
private EmbeddingProperties properties;
|
||||
private RepresentationEmbeddingOrchestrator orchestrator;
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
EmbeddingProperties properties = new EmbeddingProperties();
|
||||
properties = new EmbeddingProperties();
|
||||
properties.setEnabled(true);
|
||||
properties.getJobs().setEnabled(true);
|
||||
properties.getJobs().setBatchSize(10);
|
||||
properties.getJobs().setExecutionBatchSize(10);
|
||||
properties.getIndexing().setFallbackMaxInputChars(8192);
|
||||
orchestrator = new RepresentationEmbeddingOrchestrator(
|
||||
jobService,
|
||||
executionService,
|
||||
persistenceService,
|
||||
representationRepository,
|
||||
selectionPolicy,
|
||||
modelRegistry,
|
||||
properties
|
||||
);
|
||||
|
|
@ -80,16 +86,17 @@ class RepresentationEmbeddingOrchestratorTest {
|
|||
.document(document)
|
||||
.textBody("District heating optimization strategy")
|
||||
.build();
|
||||
when(representationRepository.findById(representationId)).thenReturn(java.util.Optional.of(representation));
|
||||
when(modelRegistry.getRequired("mock-search"))
|
||||
.thenReturn(new EmbeddingModelDescriptor("mock-search", "mock-provider", "mock-search", 16,
|
||||
DistanceMetric.COSINE, true, false, 4096, true));
|
||||
EmbeddingModelDescriptor model = new EmbeddingModelDescriptor(
|
||||
"mock-search", "mock-provider", "mock-search", 16,
|
||||
DistanceMetric.COSINE, true, false, 4096, true
|
||||
);
|
||||
when(representationRepository.findById(representationId)).thenReturn(Optional.of(representation));
|
||||
when(modelRegistry.getRequired("mock-search")).thenReturn(model);
|
||||
when(persistenceService.ensurePending(representationId, "mock-search"))
|
||||
.thenReturn(DocumentEmbedding.builder().id(embeddingId).build());
|
||||
when(executionService.embedTexts(eq("mock-search"), any(), any()))
|
||||
when(executionService.embedTexts(eq("mock-search"), eq(EmbeddingUseCase.DOCUMENT), any()))
|
||||
.thenReturn(new EmbeddingProviderResult(
|
||||
new EmbeddingModelDescriptor("mock-search", "mock-provider", "mock-search", 16,
|
||||
DistanceMetric.COSINE, true, false, 4096, true),
|
||||
model,
|
||||
List.of(new float[]{0.1f, 0.2f}),
|
||||
List.of(),
|
||||
"req-1",
|
||||
|
|
@ -102,4 +109,121 @@ class RepresentationEmbeddingOrchestratorTest {
|
|||
verify(persistenceService).saveCompleted(eq(embeddingId), any(EmbeddingProviderResult.class));
|
||||
verify(jobService).markCompleted(job.getId(), "req-1");
|
||||
}
|
||||
|
||||
@Test
|
||||
void processNextReadyBatch_should_group_batchable_jobs_and_call_provider_once() {
|
||||
properties.getJobs().setProcessInBatches(true);
|
||||
properties.getJobs().setExecutionBatchSize(10);
|
||||
|
||||
UUID documentId = UUID.randomUUID();
|
||||
UUID representationId1 = UUID.randomUUID();
|
||||
UUID representationId2 = UUID.randomUUID();
|
||||
UUID embeddingId1 = UUID.randomUUID();
|
||||
UUID embeddingId2 = UUID.randomUUID();
|
||||
|
||||
EmbeddingJob job1 = EmbeddingJob.builder()
|
||||
.id(UUID.randomUUID())
|
||||
.documentId(documentId)
|
||||
.representationId(representationId1)
|
||||
.modelKey("e5-default")
|
||||
.jobType(EmbeddingJobType.DOCUMENT_EMBED)
|
||||
.status(EmbeddingJobStatus.IN_PROGRESS)
|
||||
.attemptCount(1)
|
||||
.build();
|
||||
EmbeddingJob job2 = EmbeddingJob.builder()
|
||||
.id(UUID.randomUUID())
|
||||
.documentId(documentId)
|
||||
.representationId(representationId2)
|
||||
.modelKey("e5-default")
|
||||
.jobType(EmbeddingJobType.DOCUMENT_EMBED)
|
||||
.status(EmbeddingJobStatus.IN_PROGRESS)
|
||||
.attemptCount(1)
|
||||
.build();
|
||||
|
||||
when(jobService.claimNextReadyJobs(10)).thenReturn(List.of(job1, job2));
|
||||
|
||||
Document document = Document.builder().id(documentId).build();
|
||||
when(representationRepository.findById(representationId1)).thenReturn(Optional.of(
|
||||
DocumentTextRepresentation.builder().id(representationId1).document(document).textBody("alpha").build()
|
||||
));
|
||||
when(representationRepository.findById(representationId2)).thenReturn(Optional.of(
|
||||
DocumentTextRepresentation.builder().id(representationId2).document(document).textBody("beta").build()
|
||||
));
|
||||
|
||||
EmbeddingModelDescriptor model = new EmbeddingModelDescriptor(
|
||||
"e5-default", "vector-sync-e5", "intfloat/multilingual-e5-large", 3,
|
||||
DistanceMetric.COSINE, true, true, 4096, true
|
||||
);
|
||||
when(modelRegistry.getRequired("e5-default")).thenReturn(model);
|
||||
when(persistenceService.ensurePending(representationId1, "e5-default"))
|
||||
.thenReturn(DocumentEmbedding.builder().id(embeddingId1).build());
|
||||
when(persistenceService.ensurePending(representationId2, "e5-default"))
|
||||
.thenReturn(DocumentEmbedding.builder().id(embeddingId2).build());
|
||||
when(executionService.embedTexts(eq("e5-default"), eq(EmbeddingUseCase.DOCUMENT), any()))
|
||||
.thenReturn(new EmbeddingProviderResult(
|
||||
model,
|
||||
List.of(new float[]{0.1f, 0.2f, 0.3f}, new float[]{0.4f, 0.5f, 0.6f}),
|
||||
List.of(),
|
||||
"batch-req-1",
|
||||
21
|
||||
));
|
||||
|
||||
int processed = orchestrator.processNextReadyBatch();
|
||||
|
||||
assertThat(processed).isEqualTo(2);
|
||||
verify(persistenceService).markProcessing(embeddingId1);
|
||||
verify(persistenceService).markProcessing(embeddingId2);
|
||||
ArgumentCaptor<List<String>> textsCaptor = ArgumentCaptor.forClass(List.class);
|
||||
verify(executionService, times(1)).embedTexts(eq("e5-default"), eq(EmbeddingUseCase.DOCUMENT), textsCaptor.capture());
|
||||
assertThat(textsCaptor.getValue()).containsExactly("alpha", "beta");
|
||||
verify(persistenceService).saveCompleted(eq(embeddingId1), aryEq(new float[]{0.1f, 0.2f, 0.3f}), eq(null));
|
||||
verify(persistenceService).saveCompleted(eq(embeddingId2), aryEq(new float[]{0.4f, 0.5f, 0.6f}), eq(null));
|
||||
verify(jobService).markCompleted(job1.getId(), "batch-req-1");
|
||||
verify(jobService).markCompleted(job2.getId(), "batch-req-1");
|
||||
}
|
||||
|
||||
@Test
|
||||
void processNextReadyBatch_should_fall_back_to_single_processing_for_non_batch_model() {
|
||||
properties.getJobs().setProcessInBatches(true);
|
||||
|
||||
UUID documentId = UUID.randomUUID();
|
||||
UUID representationId = UUID.randomUUID();
|
||||
UUID embeddingId = UUID.randomUUID();
|
||||
EmbeddingJob job = EmbeddingJob.builder()
|
||||
.id(UUID.randomUUID())
|
||||
.documentId(documentId)
|
||||
.representationId(representationId)
|
||||
.modelKey("mock-search")
|
||||
.jobType(EmbeddingJobType.DOCUMENT_EMBED)
|
||||
.status(EmbeddingJobStatus.IN_PROGRESS)
|
||||
.attemptCount(1)
|
||||
.build();
|
||||
|
||||
when(jobService.claimNextReadyJobs(10)).thenReturn(List.of(job));
|
||||
Document document = Document.builder().id(documentId).build();
|
||||
when(representationRepository.findById(representationId)).thenReturn(Optional.of(
|
||||
DocumentTextRepresentation.builder().id(representationId).document(document).textBody("gamma").build()
|
||||
));
|
||||
EmbeddingModelDescriptor model = new EmbeddingModelDescriptor(
|
||||
"mock-search", "mock-provider", "mock-search", 2,
|
||||
DistanceMetric.COSINE, true, false, 4096, true
|
||||
);
|
||||
when(modelRegistry.getRequired("mock-search")).thenReturn(model);
|
||||
when(persistenceService.ensurePending(representationId, "mock-search"))
|
||||
.thenReturn(DocumentEmbedding.builder().id(embeddingId).build());
|
||||
when(executionService.embedTexts(eq("mock-search"), eq(EmbeddingUseCase.DOCUMENT), any()))
|
||||
.thenReturn(new EmbeddingProviderResult(
|
||||
model,
|
||||
List.of(new float[]{0.7f, 0.8f}),
|
||||
List.of(),
|
||||
"req-2",
|
||||
7
|
||||
));
|
||||
|
||||
orchestrator.processNextReadyBatch();
|
||||
|
||||
verify(executionService, times(1)).embedTexts(eq("mock-search"), eq(EmbeddingUseCase.DOCUMENT), eq(List.of("gamma")));
|
||||
verify(persistenceService, never()).saveCompleted(eq(embeddingId), any(float[].class), eq(null));
|
||||
verify(persistenceService).saveCompleted(eq(embeddingId), any(EmbeddingProviderResult.class));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue