|
|
|
|
@ -8,8 +8,6 @@ import at.procon.dip.embedding.provider.EmbeddingProvider;
|
|
|
|
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
|
|
|
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
|
|
|
|
import java.io.IOException;
|
|
|
|
|
import java.lang.reflect.Field;
|
|
|
|
|
import java.lang.reflect.Method;
|
|
|
|
|
import java.net.http.HttpResponse;
|
|
|
|
|
import java.util.ArrayList;
|
|
|
|
|
import java.util.HashMap;
|
|
|
|
|
@ -24,16 +22,6 @@ import org.springframework.stereotype.Component;
|
|
|
|
|
* Supported endpoints:
|
|
|
|
|
* POST {baseUrl}/vector-sync - single text
|
|
|
|
|
* POST {baseUrl}/vectorize-batch - multiple texts
|
|
|
|
|
*
|
|
|
|
|
* Batch settings are resolved from provider config if present, otherwise defaults are used.
|
|
|
|
|
* Supported property keys:
|
|
|
|
|
* - vectorize-batch.truncate-text
|
|
|
|
|
* - vectorize-batch.truncate-length
|
|
|
|
|
* - vectorize-batch.chunk-size
|
|
|
|
|
*
|
|
|
|
|
* Also accepted as fallbacks:
|
|
|
|
|
* - truncate_text / truncate-length / chunk_size
|
|
|
|
|
* - truncateText / truncateLength / chunkSize
|
|
|
|
|
*/
|
|
|
|
|
@Component
|
|
|
|
|
public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProviderSupport implements EmbeddingProvider {
|
|
|
|
|
@ -105,7 +93,7 @@ public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProvid
|
|
|
|
|
try {
|
|
|
|
|
return request.texts().size() == 1
|
|
|
|
|
? executeSingle(providerConfig, model, request.texts().getFirst())
|
|
|
|
|
: executeBatch(providerConfig, model, request.texts());
|
|
|
|
|
: executeBatch(providerConfig, model, request);
|
|
|
|
|
} catch (InterruptedException e) {
|
|
|
|
|
Thread.currentThread().interrupt();
|
|
|
|
|
throw new IllegalStateException("Embedding provider call interrupted", e);
|
|
|
|
|
@ -137,22 +125,20 @@ public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProvid
|
|
|
|
|
|
|
|
|
|
private EmbeddingProviderResult executeBatch(ResolvedEmbeddingProviderConfig providerConfig,
|
|
|
|
|
EmbeddingModelDescriptor model,
|
|
|
|
|
List<String> texts) throws IOException, InterruptedException {
|
|
|
|
|
boolean truncateText = resolveBooleanProperty(providerConfig, TRUNCATE_TEXT_KEYS, DEFAULT_TRUNCATE_TEXT);
|
|
|
|
|
int truncateLength = resolveIntProperty(providerConfig, TRUNCATE_LENGTH_KEYS, DEFAULT_TRUNCATE_LENGTH);
|
|
|
|
|
int chunkSize = resolveIntProperty(providerConfig, CHUNK_SIZE_KEYS, DEFAULT_CHUNK_SIZE);
|
|
|
|
|
EmbeddingRequest request) throws IOException, InterruptedException {
|
|
|
|
|
BatchRequestSettings settings = resolveBatchRequestSettings(providerConfig, request.providerOptions());
|
|
|
|
|
|
|
|
|
|
if (truncateLength <= 0) {
|
|
|
|
|
if (settings.truncateLength() <= 0) {
|
|
|
|
|
throw new IllegalArgumentException("Batch truncate length must be > 0");
|
|
|
|
|
}
|
|
|
|
|
if (chunkSize <= 0) {
|
|
|
|
|
if (settings.chunkSize() <= 0) {
|
|
|
|
|
throw new IllegalArgumentException("Batch chunk size must be > 0");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
List<String> requestOrder = new ArrayList<>(texts.size());
|
|
|
|
|
List<VectorizeBatchItemRequest> items = new ArrayList<>(texts.size());
|
|
|
|
|
List<String> requestOrder = new ArrayList<>(request.texts().size());
|
|
|
|
|
List<VectorizeBatchItemRequest> items = new ArrayList<>(request.texts().size());
|
|
|
|
|
|
|
|
|
|
for (String text : texts) {
|
|
|
|
|
for (String text : request.texts()) {
|
|
|
|
|
String id = UUID.randomUUID().toString();
|
|
|
|
|
requestOrder.add(id);
|
|
|
|
|
items.add(new VectorizeBatchItemRequest(id, text));
|
|
|
|
|
@ -163,9 +149,9 @@ public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProvid
|
|
|
|
|
"/vectorize-batch",
|
|
|
|
|
new VectorizeBatchRequest(
|
|
|
|
|
model.providerModelKey(),
|
|
|
|
|
truncateText,
|
|
|
|
|
truncateLength,
|
|
|
|
|
chunkSize,
|
|
|
|
|
settings.truncateText(),
|
|
|
|
|
settings.truncateLength(),
|
|
|
|
|
settings.chunkSize(),
|
|
|
|
|
items
|
|
|
|
|
)
|
|
|
|
|
);
|
|
|
|
|
@ -180,7 +166,7 @@ public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProvid
|
|
|
|
|
resultById.put(result.id, result);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
List<float[]> vectors = new ArrayList<>(texts.size());
|
|
|
|
|
List<float[]> vectors = new ArrayList<>(request.texts().size());
|
|
|
|
|
int totalTokenCount = 0;
|
|
|
|
|
boolean hasAnyTokenCount = false;
|
|
|
|
|
|
|
|
|
|
@ -207,38 +193,35 @@ public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProvid
|
|
|
|
|
);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private float[] extractVector(List<Float> vector,
|
|
|
|
|
List<Float> combinedVector,
|
|
|
|
|
EmbeddingModelDescriptor model) {
|
|
|
|
|
float[] resolved;
|
|
|
|
|
|
|
|
|
|
if (combinedVector != null && !combinedVector.isEmpty()) {
|
|
|
|
|
resolved = toArray(combinedVector);
|
|
|
|
|
} else if (vector != null && !vector.isEmpty()) {
|
|
|
|
|
resolved = toArray(vector);
|
|
|
|
|
} else {
|
|
|
|
|
throw new IllegalStateException("Embedding provider returned no vector");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (model.dimensions() > 0 && resolved.length != model.dimensions()) {
|
|
|
|
|
throw new IllegalStateException(
|
|
|
|
|
"Embedding provider returned dimension %d for model %s, expected %d"
|
|
|
|
|
.formatted(resolved.length, model.modelKey(), model.dimensions())
|
|
|
|
|
private BatchRequestSettings resolveBatchRequestSettings(ResolvedEmbeddingProviderConfig providerConfig,
|
|
|
|
|
Map<String, Object> providerOptions) {
|
|
|
|
|
boolean truncateText = resolveBooleanOption(
|
|
|
|
|
providerOptions,
|
|
|
|
|
TRUNCATE_TEXT_KEYS,
|
|
|
|
|
providerConfig.batchTruncateText() != null ? providerConfig.batchTruncateText() : DEFAULT_TRUNCATE_TEXT
|
|
|
|
|
);
|
|
|
|
|
int truncateLength = resolveIntOption(
|
|
|
|
|
providerOptions,
|
|
|
|
|
TRUNCATE_LENGTH_KEYS,
|
|
|
|
|
providerConfig.batchTruncateLength() != null ? providerConfig.batchTruncateLength() : DEFAULT_TRUNCATE_LENGTH
|
|
|
|
|
);
|
|
|
|
|
int chunkSize = resolveIntOption(
|
|
|
|
|
providerOptions,
|
|
|
|
|
CHUNK_SIZE_KEYS,
|
|
|
|
|
providerConfig.batchChunkSize() != null ? providerConfig.batchChunkSize() : DEFAULT_CHUNK_SIZE
|
|
|
|
|
);
|
|
|
|
|
return new BatchRequestSettings(truncateText, truncateLength, chunkSize);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return resolved;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private boolean resolveBooleanProperty(ResolvedEmbeddingProviderConfig providerConfig,
|
|
|
|
|
private boolean resolveBooleanOption(Map<String, Object> providerOptions,
|
|
|
|
|
List<String> keys,
|
|
|
|
|
boolean defaultValue) {
|
|
|
|
|
Object raw = resolveProviderConfigValue(providerConfig, keys);
|
|
|
|
|
Object raw = resolveOption(providerOptions, keys);
|
|
|
|
|
if (raw == null) {
|
|
|
|
|
return defaultValue;
|
|
|
|
|
}
|
|
|
|
|
if (raw instanceof Boolean b) {
|
|
|
|
|
return b;
|
|
|
|
|
if (raw instanceof Boolean booleanValue) {
|
|
|
|
|
return booleanValue;
|
|
|
|
|
}
|
|
|
|
|
String normalized = String.valueOf(raw).trim();
|
|
|
|
|
if (normalized.isEmpty()) {
|
|
|
|
|
@ -247,15 +230,15 @@ public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProvid
|
|
|
|
|
return Boolean.parseBoolean(normalized);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private int resolveIntProperty(ResolvedEmbeddingProviderConfig providerConfig,
|
|
|
|
|
private int resolveIntOption(Map<String, Object> providerOptions,
|
|
|
|
|
List<String> keys,
|
|
|
|
|
int defaultValue) {
|
|
|
|
|
Object raw = resolveProviderConfigValue(providerConfig, keys);
|
|
|
|
|
Object raw = resolveOption(providerOptions, keys);
|
|
|
|
|
if (raw == null) {
|
|
|
|
|
return defaultValue;
|
|
|
|
|
}
|
|
|
|
|
if (raw instanceof Number n) {
|
|
|
|
|
return n.intValue();
|
|
|
|
|
if (raw instanceof Number number) {
|
|
|
|
|
return number.intValue();
|
|
|
|
|
}
|
|
|
|
|
String normalized = String.valueOf(raw).trim();
|
|
|
|
|
if (normalized.isEmpty()) {
|
|
|
|
|
@ -264,89 +247,42 @@ public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProvid
|
|
|
|
|
return Integer.parseInt(normalized);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@SuppressWarnings("unchecked")
|
|
|
|
|
private Object resolveProviderConfigValue(ResolvedEmbeddingProviderConfig providerConfig,
|
|
|
|
|
List<String> keys) {
|
|
|
|
|
List<Object> containers = new ArrayList<>();
|
|
|
|
|
containers.add(providerConfig);
|
|
|
|
|
|
|
|
|
|
addIfNonNull(containers, invokeNoArg(providerConfig, "properties"));
|
|
|
|
|
addIfNonNull(containers, invokeNoArg(providerConfig, "providerProperties"));
|
|
|
|
|
addIfNonNull(containers, invokeNoArg(providerConfig, "config"));
|
|
|
|
|
addIfNonNull(containers, invokeNoArg(providerConfig, "settings"));
|
|
|
|
|
addIfNonNull(containers, invokeNoArg(providerConfig, "options"));
|
|
|
|
|
addIfNonNull(containers, readField(providerConfig, "properties"));
|
|
|
|
|
addIfNonNull(containers, readField(providerConfig, "providerProperties"));
|
|
|
|
|
addIfNonNull(containers, readField(providerConfig, "config"));
|
|
|
|
|
addIfNonNull(containers, readField(providerConfig, "settings"));
|
|
|
|
|
addIfNonNull(containers, readField(providerConfig, "options"));
|
|
|
|
|
|
|
|
|
|
for (Object container : containers) {
|
|
|
|
|
if (container instanceof Map<?, ?> map) {
|
|
|
|
|
for (String key : keys) {
|
|
|
|
|
if (map.containsKey(key)) {
|
|
|
|
|
return map.get(key);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
private Object resolveOption(Map<String, Object> providerOptions, List<String> keys) {
|
|
|
|
|
if (providerOptions == null || providerOptions.isEmpty()) {
|
|
|
|
|
return null;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (String key : keys) {
|
|
|
|
|
Object value = invokeStringArg(container, "get", key);
|
|
|
|
|
if (value != null) {
|
|
|
|
|
return value;
|
|
|
|
|
}
|
|
|
|
|
value = invokeStringArg(container, "getProperty", key);
|
|
|
|
|
if (value != null) {
|
|
|
|
|
return value;
|
|
|
|
|
}
|
|
|
|
|
value = invokeStringArg(container, "property", key);
|
|
|
|
|
if (value != null) {
|
|
|
|
|
return value;
|
|
|
|
|
if (providerOptions.containsKey(key)) {
|
|
|
|
|
return providerOptions.get(key);
|
|
|
|
|
}
|
|
|
|
|
value = invokeStringArg(container, "option", key);
|
|
|
|
|
if (value != null) {
|
|
|
|
|
return value;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return null;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private void addIfNonNull(List<Object> containers, Object value) {
|
|
|
|
|
if (value != null) {
|
|
|
|
|
containers.add(value);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
private float[] extractVector(List<Float> vector,
|
|
|
|
|
List<Float> combinedVector,
|
|
|
|
|
EmbeddingModelDescriptor model) {
|
|
|
|
|
float[] resolved;
|
|
|
|
|
|
|
|
|
|
private Object invokeNoArg(Object target, String methodName) {
|
|
|
|
|
try {
|
|
|
|
|
Method method = target.getClass().getMethod(methodName);
|
|
|
|
|
method.setAccessible(true);
|
|
|
|
|
return method.invoke(target);
|
|
|
|
|
} catch (Exception ignored) {
|
|
|
|
|
return null;
|
|
|
|
|
}
|
|
|
|
|
if (combinedVector != null && !combinedVector.isEmpty()) {
|
|
|
|
|
resolved = toArray(combinedVector);
|
|
|
|
|
} else if (vector != null && !vector.isEmpty()) {
|
|
|
|
|
resolved = toArray(vector);
|
|
|
|
|
} else {
|
|
|
|
|
throw new IllegalStateException("Embedding provider returned no vector");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private Object invokeStringArg(Object target, String methodName, String arg) {
|
|
|
|
|
try {
|
|
|
|
|
Method method = target.getClass().getMethod(methodName, String.class);
|
|
|
|
|
method.setAccessible(true);
|
|
|
|
|
return method.invoke(target, arg);
|
|
|
|
|
} catch (Exception ignored) {
|
|
|
|
|
return null;
|
|
|
|
|
}
|
|
|
|
|
if (model.dimensions() > 0 && resolved.length != model.dimensions()) {
|
|
|
|
|
throw new IllegalStateException(
|
|
|
|
|
"Embedding provider returned dimension %d for model %s, expected %d"
|
|
|
|
|
.formatted(resolved.length, model.modelKey(), model.dimensions())
|
|
|
|
|
);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private Object readField(Object target, String fieldName) {
|
|
|
|
|
try {
|
|
|
|
|
Field field = target.getClass().getDeclaredField(fieldName);
|
|
|
|
|
field.setAccessible(true);
|
|
|
|
|
return field.get(target);
|
|
|
|
|
} catch (Exception ignored) {
|
|
|
|
|
return null;
|
|
|
|
|
return resolved;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private record BatchRequestSettings(boolean truncateText, int truncateLength, int chunkSize) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private record VectorSyncRequest(
|
|
|
|
|
|