vectorization using py temporal service
This commit is contained in:
parent
177c61803e
commit
3913ed83e8
|
|
@ -0,0 +1,39 @@
|
|||
# Vector-sync HTTP embedding provider
|
||||
|
||||
This patch adds a new provider type:
|
||||
|
||||
- `http-vector-sync`
|
||||
|
||||
## Request
|
||||
Endpoint:
|
||||
- `POST {baseUrl}/vector-sync`
|
||||
|
||||
Request body:
|
||||
```json
|
||||
{
|
||||
"model": "intfloat/multilingual-e5-large",
|
||||
"text": "This is a sample text to vectorize"
|
||||
}
|
||||
```
|
||||
|
||||
## Response
|
||||
```json
|
||||
{
|
||||
"runtime_ms": 472.49,
|
||||
"vector": [0.1, 0.2, 0.3],
|
||||
"incomplete": false,
|
||||
"combined_vector": null,
|
||||
"token_count": 9,
|
||||
"model": "intfloat/multilingual-e5-large",
|
||||
"max_seq_length": 512
|
||||
}
|
||||
```
|
||||
|
||||
## Notes
|
||||
- supports a single text per request
|
||||
- works for both document and query embeddings
|
||||
- validates returned vector dimension against the configured embedding model
|
||||
- keeps the existing `/embed` provider in place as `http-json`
|
||||
|
||||
## Example config
|
||||
See `application-new-example-vector-sync-provider.yml`.
|
||||
|
|
@ -0,0 +1,68 @@
|
|||
package at.procon.dip.embedding.provider.http;
|
||||
|
||||
import at.procon.dip.embedding.model.ResolvedEmbeddingProviderConfig;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import java.io.IOException;
|
||||
import java.net.URI;
|
||||
import java.net.http.HttpClient;
|
||||
import java.net.http.HttpRequest;
|
||||
import java.net.http.HttpResponse;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.time.Duration;
|
||||
import java.util.List;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
|
||||
@RequiredArgsConstructor
|
||||
abstract class AbstractHttpEmbeddingProviderSupport {
|
||||
|
||||
protected final ObjectMapper objectMapper;
|
||||
protected final HttpClient httpClient = HttpClient.newBuilder()
|
||||
.version(HttpClient.Version.HTTP_1_1)
|
||||
.build();
|
||||
|
||||
protected String trimTrailingSlash(String value) {
|
||||
if (value == null || value.isBlank()) {
|
||||
throw new IllegalArgumentException("Embedding provider baseUrl must be configured");
|
||||
}
|
||||
return value.endsWith("/") ? value.substring(0, value.length() - 1) : value;
|
||||
}
|
||||
|
||||
protected HttpResponse<String> postJson(ResolvedEmbeddingProviderConfig providerConfig,
|
||||
String path,
|
||||
Object body) throws IOException, InterruptedException {
|
||||
HttpRequest.Builder builder = HttpRequest.newBuilder()
|
||||
.uri(URI.create(trimTrailingSlash(providerConfig.baseUrl()) + path))
|
||||
.timeout(providerConfig.readTimeout() == null ? Duration.ofSeconds(60) : providerConfig.readTimeout())
|
||||
.header("Content-Type", "application/json")
|
||||
.POST(HttpRequest.BodyPublishers.ofString(
|
||||
objectMapper.writeValueAsString(body),
|
||||
StandardCharsets.UTF_8
|
||||
));
|
||||
|
||||
if (providerConfig.apiKey() != null && !providerConfig.apiKey().isBlank()) {
|
||||
builder.header("Authorization", "Bearer " + providerConfig.apiKey());
|
||||
}
|
||||
if (providerConfig.headers() != null) {
|
||||
providerConfig.headers().forEach(builder::header);
|
||||
}
|
||||
|
||||
HttpResponse<String> response = httpClient.send(
|
||||
builder.build(),
|
||||
HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)
|
||||
);
|
||||
if (response.statusCode() / 100 != 2) {
|
||||
throw new IllegalStateException(
|
||||
"Embedding provider returned status %d: %s".formatted(response.statusCode(), response.body())
|
||||
);
|
||||
}
|
||||
return response;
|
||||
}
|
||||
|
||||
protected float[] toArray(List<Float> embedding) {
|
||||
float[] result = new float[embedding.size()];
|
||||
for (int i = 0; i < embedding.size(); i++) {
|
||||
result[i] = embedding.get(i);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
|
@ -6,30 +6,25 @@ import at.procon.dip.embedding.model.EmbeddingRequest;
|
|||
import at.procon.dip.embedding.model.EmbeddingUseCase;
|
||||
import at.procon.dip.embedding.model.ResolvedEmbeddingProviderConfig;
|
||||
import at.procon.dip.embedding.provider.EmbeddingProvider;
|
||||
import at.procon.ted.camel.VectorizationRoute;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import java.io.IOException;
|
||||
import java.net.URI;
|
||||
import java.net.http.HttpClient;
|
||||
import java.net.http.HttpRequest;
|
||||
import java.net.http.HttpResponse;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.time.Duration;
|
||||
import java.util.*;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.apache.camel.Exchange;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
/**
|
||||
* Existing HTTP/JSON embedding provider using the /embed contract.
|
||||
*/
|
||||
@Component
|
||||
@RequiredArgsConstructor
|
||||
public class ExternalHttpEmbeddingProvider implements EmbeddingProvider {
|
||||
|
||||
public class ExternalHttpEmbeddingProvider extends AbstractHttpEmbeddingProviderSupport implements EmbeddingProvider {
|
||||
private static final String PROVIDER_TYPE = "http-json";
|
||||
|
||||
private final ObjectMapper objectMapper;
|
||||
private final HttpClient httpClient = HttpClient.newBuilder().version(HttpClient.Version.HTTP_1_1).build();
|
||||
public ExternalHttpEmbeddingProvider(ObjectMapper objectMapper, ObjectMapper mapper) {
|
||||
super(objectMapper);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String providerType() {
|
||||
|
|
@ -45,63 +40,43 @@ public class ExternalHttpEmbeddingProvider implements EmbeddingProvider {
|
|||
public EmbeddingProviderResult embedDocuments(ResolvedEmbeddingProviderConfig providerConfig,
|
||||
EmbeddingModelDescriptor model,
|
||||
EmbeddingRequest request) {
|
||||
return execute(providerConfig, model, request, EmbeddingUseCase.DOCUMENT);
|
||||
return execute(providerConfig, request, EmbeddingUseCase.DOCUMENT);
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingProviderResult embedQuery(ResolvedEmbeddingProviderConfig providerConfig,
|
||||
EmbeddingModelDescriptor model,
|
||||
EmbeddingRequest request) {
|
||||
return execute(providerConfig, model, request, EmbeddingUseCase.QUERY);
|
||||
return execute(providerConfig, request, EmbeddingUseCase.QUERY);
|
||||
}
|
||||
|
||||
private EmbeddingProviderResult execute(ResolvedEmbeddingProviderConfig providerConfig,
|
||||
EmbeddingModelDescriptor model,
|
||||
EmbeddingRequest request,
|
||||
EmbeddingUseCase useCase) {
|
||||
if (request.texts() == null || request.texts().isEmpty()) {
|
||||
throw new IllegalArgumentException("Embedding request texts must not be empty");
|
||||
}
|
||||
|
||||
try {
|
||||
var payload = new ProviderRequest(
|
||||
model.providerModelKey(),
|
||||
request.texts(),
|
||||
useCase == EmbeddingUseCase.QUERY,
|
||||
request.providerOptions() == null ? Map.of() : request.providerOptions()
|
||||
HttpResponse<String> response = postJson(
|
||||
providerConfig,
|
||||
"/embed",
|
||||
Map.of(
|
||||
"text", request.texts().getFirst(),
|
||||
"isQuery", useCase == EmbeddingUseCase.QUERY
|
||||
)
|
||||
);
|
||||
|
||||
// Prepare request object
|
||||
var map = new HashMap<>();
|
||||
map.put("text", request.texts().getFirst());
|
||||
map.put("isQuery", false);
|
||||
|
||||
HttpRequest.Builder builder = HttpRequest.newBuilder()
|
||||
.uri(URI.create(trimTrailingSlash(providerConfig.baseUrl()) + "/embed"))
|
||||
.timeout(providerConfig.readTimeout() == null ? Duration.ofSeconds(60) : providerConfig.readTimeout())
|
||||
.header("Content-Type", "application/json")
|
||||
.header("documentId", UUID.randomUUID().toString())
|
||||
.POST(HttpRequest.BodyPublishers.ofString(objectMapper.writeValueAsString(map), StandardCharsets.UTF_8));
|
||||
|
||||
if (providerConfig.apiKey() != null && !providerConfig.apiKey().isBlank()) {
|
||||
builder.header("Authorization", "Bearer " + providerConfig.apiKey());
|
||||
}
|
||||
if (providerConfig.headers() != null) {
|
||||
providerConfig.headers().forEach(builder::header);
|
||||
}
|
||||
|
||||
HttpResponse<String> response = httpClient.send(builder.build(), HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8));
|
||||
if (response.statusCode() / 100 != 2) {
|
||||
throw new IllegalStateException("Embedding provider returned status %d: %s".formatted(response.statusCode(), response.body()));
|
||||
}
|
||||
|
||||
EmbedResponse parsed = objectMapper.readValue(response.body(), EmbedResponse.class);
|
||||
List<float[]> vectors = new ArrayList<>();
|
||||
if (parsed.embedding != null) {
|
||||
vectors.add(toArray(toList(parsed.embedding)));
|
||||
if (parsed.embedding == null) {
|
||||
throw new IllegalStateException("Embedding provider returned no embedding");
|
||||
}
|
||||
|
||||
return new EmbeddingProviderResult(
|
||||
model,
|
||||
vectors,
|
||||
null, //parsed.warnings == null ? List.of() : parsed.warnings,
|
||||
null, //parsed.requestId,
|
||||
null,
|
||||
List.of(parsed.embedding),
|
||||
List.of(),
|
||||
null,
|
||||
parsed.tokenCount
|
||||
);
|
||||
} catch (InterruptedException e) {
|
||||
|
|
@ -112,124 +87,14 @@ public class ExternalHttpEmbeddingProvider implements EmbeddingProvider {
|
|||
}
|
||||
}
|
||||
|
||||
public static List<Float> toList(float[] arr) {
|
||||
if (arr == null) {
|
||||
return null;
|
||||
}
|
||||
List<Float> list = new ArrayList<>(arr.length);
|
||||
for (float v : arr) {
|
||||
list.add(v);
|
||||
}
|
||||
return list;
|
||||
}
|
||||
|
||||
private float[] toArray(List<Float> embedding) {
|
||||
float[] result = new float[embedding.size()];
|
||||
for (int i = 0; i < embedding.size(); i++) {
|
||||
result[i] = embedding.get(i);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private String trimTrailingSlash(String value) {
|
||||
if (value == null || value.isBlank()) {
|
||||
throw new IllegalArgumentException("Embedding provider baseUrl must be configured");
|
||||
}
|
||||
return value.endsWith("/") ? value.substring(0, value.length() - 1) : value;
|
||||
}
|
||||
|
||||
private record ProviderRequest(
|
||||
@JsonProperty("model") String model,
|
||||
@JsonProperty("texts") List<String> texts,
|
||||
@JsonProperty("is_query") boolean query,
|
||||
@JsonProperty("options") Map<String, Object> options
|
||||
) {
|
||||
}
|
||||
|
||||
private static class ProviderResponse {
|
||||
public static class EmbedResponse {
|
||||
@JsonProperty("embedding")
|
||||
public List<Float> embedding;
|
||||
public float[] embedding;
|
||||
|
||||
@JsonProperty("embeddings")
|
||||
public List<List<Float>> embeddings;
|
||||
|
||||
@JsonProperty("warnings")
|
||||
public List<String> warnings;
|
||||
|
||||
@JsonProperty("request_id")
|
||||
public String requestId;
|
||||
@JsonProperty("dimensions")
|
||||
public Integer dimensions;
|
||||
|
||||
@JsonProperty("token_count")
|
||||
public Integer tokenCount;
|
||||
}
|
||||
|
||||
/**
|
||||
* Request model for embedding service.
|
||||
* Matches Python FastAPI EmbedRequest model with snake_case field names.
|
||||
*/
|
||||
public static class EmbedRequest {
|
||||
@JsonProperty("text")
|
||||
public String text;
|
||||
|
||||
@JsonProperty("is_query")
|
||||
public boolean isQuery;
|
||||
|
||||
public EmbedRequest() {}
|
||||
|
||||
public String getText() {
|
||||
return text;
|
||||
}
|
||||
|
||||
public void setText(String text) {
|
||||
this.text = text;
|
||||
}
|
||||
|
||||
@JsonProperty("is_query")
|
||||
public boolean isIsQuery() {
|
||||
return isQuery;
|
||||
}
|
||||
|
||||
@JsonProperty("is_query")
|
||||
public void setIsQuery(boolean isQuery) {
|
||||
this.isQuery = isQuery;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Response model for embedding service.
|
||||
*/
|
||||
public static class EmbedResponse {
|
||||
public float[] embedding;
|
||||
public int dimensions;
|
||||
@JsonProperty("token_count")
|
||||
public int tokenCount;
|
||||
|
||||
public EmbedResponse() {}
|
||||
|
||||
public float[] getEmbedding() {
|
||||
return embedding;
|
||||
}
|
||||
|
||||
public void setEmbedding(float[] embedding) {
|
||||
this.embedding = embedding;
|
||||
}
|
||||
|
||||
public int getDimensions() {
|
||||
return dimensions;
|
||||
}
|
||||
|
||||
public void setDimensions(int dimensions) {
|
||||
this.dimensions = dimensions;
|
||||
}
|
||||
|
||||
@JsonProperty("token_count")
|
||||
public int getTokenCount() {
|
||||
return tokenCount;
|
||||
}
|
||||
|
||||
@JsonProperty("token_count")
|
||||
public void setTokenCount(int tokenCount) {
|
||||
this.tokenCount = tokenCount;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,438 @@
|
|||
package at.procon.dip.embedding.provider.http;
|
||||
|
||||
import at.procon.dip.embedding.model.EmbeddingModelDescriptor;
|
||||
import at.procon.dip.embedding.model.EmbeddingProviderResult;
|
||||
import at.procon.dip.embedding.model.EmbeddingRequest;
|
||||
import at.procon.dip.embedding.model.ResolvedEmbeddingProviderConfig;
|
||||
import at.procon.dip.embedding.provider.EmbeddingProvider;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import java.io.IOException;
|
||||
import java.lang.reflect.Field;
|
||||
import java.lang.reflect.Method;
|
||||
import java.net.http.HttpResponse;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
/**
|
||||
* HTTP provider for vector APIs.
|
||||
*
|
||||
* Supported endpoints:
|
||||
* POST {baseUrl}/vector-sync - single text
|
||||
* POST {baseUrl}/vectorize-batch - multiple texts
|
||||
*
|
||||
* Batch settings are resolved from provider config if present, otherwise defaults are used.
|
||||
* Supported property keys:
|
||||
* - vectorize-batch.truncate-text
|
||||
* - vectorize-batch.truncate-length
|
||||
* - vectorize-batch.chunk-size
|
||||
*
|
||||
* Also accepted as fallbacks:
|
||||
* - truncate_text / truncate-length / chunk_size
|
||||
* - truncateText / truncateLength / chunkSize
|
||||
*/
|
||||
@Component
|
||||
public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProviderSupport implements EmbeddingProvider {
|
||||
private static final String PROVIDER_TYPE = "http-vector-sync";
|
||||
|
||||
private static final boolean DEFAULT_TRUNCATE_TEXT = false;
|
||||
private static final int DEFAULT_TRUNCATE_LENGTH = 512;
|
||||
private static final int DEFAULT_CHUNK_SIZE = 20;
|
||||
|
||||
private static final List<String> TRUNCATE_TEXT_KEYS = List.of(
|
||||
"vectorize-batch.truncate-text",
|
||||
"vectorize-batch.truncate_text",
|
||||
"truncate_text",
|
||||
"truncate-text",
|
||||
"truncateText"
|
||||
);
|
||||
|
||||
private static final List<String> TRUNCATE_LENGTH_KEYS = List.of(
|
||||
"vectorize-batch.truncate-length",
|
||||
"vectorize-batch.truncate_length",
|
||||
"truncate_length",
|
||||
"truncate-length",
|
||||
"truncateLength"
|
||||
);
|
||||
|
||||
private static final List<String> CHUNK_SIZE_KEYS = List.of(
|
||||
"vectorize-batch.chunk-size",
|
||||
"vectorize-batch.chunk_size",
|
||||
"chunk_size",
|
||||
"chunk-size",
|
||||
"chunkSize"
|
||||
);
|
||||
|
||||
public VectorSyncHttpEmbeddingProvider(ObjectMapper objectMapper) {
|
||||
super(objectMapper);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String providerType() {
|
||||
return PROVIDER_TYPE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean supports(EmbeddingModelDescriptor model, ResolvedEmbeddingProviderConfig providerConfig) {
|
||||
return PROVIDER_TYPE.equalsIgnoreCase(providerConfig.providerType());
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingProviderResult embedDocuments(ResolvedEmbeddingProviderConfig providerConfig,
|
||||
EmbeddingModelDescriptor model,
|
||||
EmbeddingRequest request) {
|
||||
return execute(providerConfig, model, request);
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingProviderResult embedQuery(ResolvedEmbeddingProviderConfig providerConfig,
|
||||
EmbeddingModelDescriptor model,
|
||||
EmbeddingRequest request) {
|
||||
return execute(providerConfig, model, request);
|
||||
}
|
||||
|
||||
private EmbeddingProviderResult execute(ResolvedEmbeddingProviderConfig providerConfig,
|
||||
EmbeddingModelDescriptor model,
|
||||
EmbeddingRequest request) {
|
||||
if (request.texts() == null || request.texts().isEmpty()) {
|
||||
throw new IllegalArgumentException("Embedding request texts must not be empty");
|
||||
}
|
||||
|
||||
try {
|
||||
return request.texts().size() == 1
|
||||
? executeSingle(providerConfig, model, request.texts().getFirst())
|
||||
: executeBatch(providerConfig, model, request.texts());
|
||||
} catch (InterruptedException e) {
|
||||
Thread.currentThread().interrupt();
|
||||
throw new IllegalStateException("Embedding provider call interrupted", e);
|
||||
} catch (IOException e) {
|
||||
throw new IllegalStateException("Failed to call embedding provider", e);
|
||||
}
|
||||
}
|
||||
|
||||
private EmbeddingProviderResult executeSingle(ResolvedEmbeddingProviderConfig providerConfig,
|
||||
EmbeddingModelDescriptor model,
|
||||
String text) throws IOException, InterruptedException {
|
||||
HttpResponse<String> response = postJson(
|
||||
providerConfig,
|
||||
"/vector-sync",
|
||||
new VectorSyncRequest(model.providerModelKey(), text)
|
||||
);
|
||||
|
||||
VectorSyncResponse parsed = objectMapper.readValue(response.body(), VectorSyncResponse.class);
|
||||
float[] vector = extractVector(parsed.vector, parsed.combinedVector, model);
|
||||
|
||||
return new EmbeddingProviderResult(
|
||||
model,
|
||||
List.of(vector),
|
||||
List.of(),
|
||||
null,
|
||||
parsed.tokenCount
|
||||
);
|
||||
}
|
||||
|
||||
private EmbeddingProviderResult executeBatch(ResolvedEmbeddingProviderConfig providerConfig,
|
||||
EmbeddingModelDescriptor model,
|
||||
List<String> texts) throws IOException, InterruptedException {
|
||||
boolean truncateText = resolveBooleanProperty(providerConfig, TRUNCATE_TEXT_KEYS, DEFAULT_TRUNCATE_TEXT);
|
||||
int truncateLength = resolveIntProperty(providerConfig, TRUNCATE_LENGTH_KEYS, DEFAULT_TRUNCATE_LENGTH);
|
||||
int chunkSize = resolveIntProperty(providerConfig, CHUNK_SIZE_KEYS, DEFAULT_CHUNK_SIZE);
|
||||
|
||||
if (truncateLength <= 0) {
|
||||
throw new IllegalArgumentException("Batch truncate length must be > 0");
|
||||
}
|
||||
if (chunkSize <= 0) {
|
||||
throw new IllegalArgumentException("Batch chunk size must be > 0");
|
||||
}
|
||||
|
||||
List<String> requestOrder = new ArrayList<>(texts.size());
|
||||
List<VectorizeBatchItemRequest> items = new ArrayList<>(texts.size());
|
||||
|
||||
for (String text : texts) {
|
||||
String id = UUID.randomUUID().toString();
|
||||
requestOrder.add(id);
|
||||
items.add(new VectorizeBatchItemRequest(id, text));
|
||||
}
|
||||
|
||||
HttpResponse<String> response = postJson(
|
||||
providerConfig,
|
||||
"/vectorize-batch",
|
||||
new VectorizeBatchRequest(
|
||||
model.providerModelKey(),
|
||||
truncateText,
|
||||
truncateLength,
|
||||
chunkSize,
|
||||
items
|
||||
)
|
||||
);
|
||||
|
||||
VectorizeBatchResponse parsed = objectMapper.readValue(response.body(), VectorizeBatchResponse.class);
|
||||
if (parsed.results == null || parsed.results.isEmpty()) {
|
||||
throw new IllegalStateException("Vectorize-batch provider returned no results");
|
||||
}
|
||||
|
||||
Map<String, VectorizeBatchItemResponse> resultById = new HashMap<>();
|
||||
for (VectorizeBatchItemResponse result : parsed.results) {
|
||||
resultById.put(result.id, result);
|
||||
}
|
||||
|
||||
List<float[]> vectors = new ArrayList<>(texts.size());
|
||||
int totalTokenCount = 0;
|
||||
boolean hasAnyTokenCount = false;
|
||||
|
||||
for (String id : requestOrder) {
|
||||
VectorizeBatchItemResponse item = resultById.get(id);
|
||||
if (item == null) {
|
||||
throw new IllegalStateException("Vectorize-batch provider response is missing item for id " + id);
|
||||
}
|
||||
|
||||
vectors.add(extractVector(item.vector, item.combinedVector, model));
|
||||
|
||||
if (item.tokenCount != null) {
|
||||
totalTokenCount += item.tokenCount;
|
||||
hasAnyTokenCount = true;
|
||||
}
|
||||
}
|
||||
|
||||
return new EmbeddingProviderResult(
|
||||
model,
|
||||
vectors,
|
||||
List.of(),
|
||||
null,
|
||||
hasAnyTokenCount ? totalTokenCount : null
|
||||
);
|
||||
}
|
||||
|
||||
private float[] extractVector(List<Float> vector,
|
||||
List<Float> combinedVector,
|
||||
EmbeddingModelDescriptor model) {
|
||||
float[] resolved;
|
||||
|
||||
if (combinedVector != null && !combinedVector.isEmpty()) {
|
||||
resolved = toArray(combinedVector);
|
||||
} else if (vector != null && !vector.isEmpty()) {
|
||||
resolved = toArray(vector);
|
||||
} else {
|
||||
throw new IllegalStateException("Embedding provider returned no vector");
|
||||
}
|
||||
|
||||
if (model.dimensions() > 0 && resolved.length != model.dimensions()) {
|
||||
throw new IllegalStateException(
|
||||
"Embedding provider returned dimension %d for model %s, expected %d"
|
||||
.formatted(resolved.length, model.modelKey(), model.dimensions())
|
||||
);
|
||||
}
|
||||
|
||||
return resolved;
|
||||
}
|
||||
|
||||
private boolean resolveBooleanProperty(ResolvedEmbeddingProviderConfig providerConfig,
|
||||
List<String> keys,
|
||||
boolean defaultValue) {
|
||||
Object raw = resolveProviderConfigValue(providerConfig, keys);
|
||||
if (raw == null) {
|
||||
return defaultValue;
|
||||
}
|
||||
if (raw instanceof Boolean b) {
|
||||
return b;
|
||||
}
|
||||
String normalized = String.valueOf(raw).trim();
|
||||
if (normalized.isEmpty()) {
|
||||
return defaultValue;
|
||||
}
|
||||
return Boolean.parseBoolean(normalized);
|
||||
}
|
||||
|
||||
private int resolveIntProperty(ResolvedEmbeddingProviderConfig providerConfig,
|
||||
List<String> keys,
|
||||
int defaultValue) {
|
||||
Object raw = resolveProviderConfigValue(providerConfig, keys);
|
||||
if (raw == null) {
|
||||
return defaultValue;
|
||||
}
|
||||
if (raw instanceof Number n) {
|
||||
return n.intValue();
|
||||
}
|
||||
String normalized = String.valueOf(raw).trim();
|
||||
if (normalized.isEmpty()) {
|
||||
return defaultValue;
|
||||
}
|
||||
return Integer.parseInt(normalized);
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private Object resolveProviderConfigValue(ResolvedEmbeddingProviderConfig providerConfig,
|
||||
List<String> keys) {
|
||||
List<Object> containers = new ArrayList<>();
|
||||
containers.add(providerConfig);
|
||||
|
||||
addIfNonNull(containers, invokeNoArg(providerConfig, "properties"));
|
||||
addIfNonNull(containers, invokeNoArg(providerConfig, "providerProperties"));
|
||||
addIfNonNull(containers, invokeNoArg(providerConfig, "config"));
|
||||
addIfNonNull(containers, invokeNoArg(providerConfig, "settings"));
|
||||
addIfNonNull(containers, invokeNoArg(providerConfig, "options"));
|
||||
addIfNonNull(containers, readField(providerConfig, "properties"));
|
||||
addIfNonNull(containers, readField(providerConfig, "providerProperties"));
|
||||
addIfNonNull(containers, readField(providerConfig, "config"));
|
||||
addIfNonNull(containers, readField(providerConfig, "settings"));
|
||||
addIfNonNull(containers, readField(providerConfig, "options"));
|
||||
|
||||
for (Object container : containers) {
|
||||
if (container instanceof Map<?, ?> map) {
|
||||
for (String key : keys) {
|
||||
if (map.containsKey(key)) {
|
||||
return map.get(key);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (String key : keys) {
|
||||
Object value = invokeStringArg(container, "get", key);
|
||||
if (value != null) {
|
||||
return value;
|
||||
}
|
||||
value = invokeStringArg(container, "getProperty", key);
|
||||
if (value != null) {
|
||||
return value;
|
||||
}
|
||||
value = invokeStringArg(container, "property", key);
|
||||
if (value != null) {
|
||||
return value;
|
||||
}
|
||||
value = invokeStringArg(container, "option", key);
|
||||
if (value != null) {
|
||||
return value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
private void addIfNonNull(List<Object> containers, Object value) {
|
||||
if (value != null) {
|
||||
containers.add(value);
|
||||
}
|
||||
}
|
||||
|
||||
private Object invokeNoArg(Object target, String methodName) {
|
||||
try {
|
||||
Method method = target.getClass().getMethod(methodName);
|
||||
method.setAccessible(true);
|
||||
return method.invoke(target);
|
||||
} catch (Exception ignored) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private Object invokeStringArg(Object target, String methodName, String arg) {
|
||||
try {
|
||||
Method method = target.getClass().getMethod(methodName, String.class);
|
||||
method.setAccessible(true);
|
||||
return method.invoke(target, arg);
|
||||
} catch (Exception ignored) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private Object readField(Object target, String fieldName) {
|
||||
try {
|
||||
Field field = target.getClass().getDeclaredField(fieldName);
|
||||
field.setAccessible(true);
|
||||
return field.get(target);
|
||||
} catch (Exception ignored) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private record VectorSyncRequest(
|
||||
@JsonProperty("model") String model,
|
||||
@JsonProperty("text") String text
|
||||
) {
|
||||
}
|
||||
|
||||
private record VectorizeBatchRequest(
|
||||
@JsonProperty("model") String model,
|
||||
@JsonProperty("truncate_text") boolean truncateText,
|
||||
@JsonProperty("truncate_length") int truncateLength,
|
||||
@JsonProperty("chunk_size") int chunkSize,
|
||||
@JsonProperty("items") List<VectorizeBatchItemRequest> items
|
||||
) {
|
||||
}
|
||||
|
||||
private record VectorizeBatchItemRequest(
|
||||
@JsonProperty("id") String id,
|
||||
@JsonProperty("text") String text
|
||||
) {
|
||||
}
|
||||
|
||||
static class VectorSyncResponse {
|
||||
@JsonProperty("runtime_ms")
|
||||
public Double runtimeMs;
|
||||
|
||||
@JsonProperty("vector")
|
||||
public List<Float> vector;
|
||||
|
||||
@JsonProperty("incomplete")
|
||||
public Boolean incomplete;
|
||||
|
||||
@JsonProperty("combined_vector")
|
||||
public List<Float> combinedVector;
|
||||
|
||||
@JsonProperty("token_count")
|
||||
public Integer tokenCount;
|
||||
|
||||
@JsonProperty("model")
|
||||
public String model;
|
||||
|
||||
@JsonProperty("max_seq_length")
|
||||
public Integer maxSeqLength;
|
||||
}
|
||||
|
||||
static class VectorizeBatchResponse {
|
||||
@JsonProperty("model")
|
||||
public String model;
|
||||
|
||||
@JsonProperty("count")
|
||||
public Integer count;
|
||||
|
||||
@JsonProperty("results")
|
||||
public List<VectorizeBatchItemResponse> results;
|
||||
}
|
||||
|
||||
static class VectorizeBatchItemResponse {
|
||||
@JsonProperty("id")
|
||||
public String id;
|
||||
|
||||
@JsonProperty("vector")
|
||||
public List<Float> vector;
|
||||
|
||||
@JsonProperty("token_count")
|
||||
public Integer tokenCount;
|
||||
|
||||
@JsonProperty("runtime_ms")
|
||||
public Double runtimeMs;
|
||||
|
||||
@JsonProperty("incomplete")
|
||||
public Boolean incomplete;
|
||||
|
||||
@JsonProperty("combined_vector")
|
||||
public List<Float> combinedVector;
|
||||
|
||||
@JsonProperty("truncated")
|
||||
public Boolean truncated;
|
||||
|
||||
@JsonProperty("truncate_length")
|
||||
public Integer truncateLength;
|
||||
|
||||
@JsonProperty("model")
|
||||
public String model;
|
||||
|
||||
@JsonProperty("max_seq_length")
|
||||
public Integer maxSeqLength;
|
||||
}
|
||||
}
|
||||
|
|
@ -13,6 +13,8 @@ import at.procon.dip.ingestion.spi.SourceDescriptor;
|
|||
import at.procon.dip.processing.spi.DocumentProcessingPolicy;
|
||||
import at.procon.dip.processing.spi.StructuredDocumentProcessor;
|
||||
import at.procon.dip.processing.spi.StructuredProcessingRequest;
|
||||
import at.procon.dip.runtime.condition.ConditionalOnRuntimeMode;
|
||||
import at.procon.dip.runtime.config.RuntimeMode;
|
||||
import at.procon.ted.model.entity.ProcurementDocument;
|
||||
import at.procon.ted.service.XmlParserService;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
|
|
@ -25,6 +27,7 @@ import org.springframework.stereotype.Component;
|
|||
import org.springframework.util.StringUtils;
|
||||
|
||||
@Component
|
||||
@ConditionalOnRuntimeMode(RuntimeMode.NEW)
|
||||
@RequiredArgsConstructor
|
||||
@Slf4j
|
||||
public class TedStructuredDocumentProcessor implements StructuredDocumentProcessor {
|
||||
|
|
|
|||
|
|
@ -44,8 +44,6 @@ public class BatchDocumentProcessingService {
|
|||
private final ProcurementDocumentRepository documentRepository;
|
||||
private final ProcessingLogService processingLogService;
|
||||
private final TedProcessorProperties properties;
|
||||
private final TedPhase2GenericDocumentService tedPhase2GenericDocumentService;
|
||||
private final TedNoticeProjectionService tedNoticeProjectionService;
|
||||
|
||||
/**
|
||||
* Process a batch of XML files from a Daily Package.
|
||||
|
|
@ -137,12 +135,6 @@ public class BatchDocumentProcessingService {
|
|||
ProcessingLog.EventStatus.SUCCESS,
|
||||
"Document parsed and stored successfully (batch)", null,
|
||||
doc.getSourceFilename(), 0);
|
||||
|
||||
if (doc.getDocumentHash() != null) {
|
||||
if (properties.getProjection().isEnabled()) {
|
||||
tedNoticeProjectionService.registerOrRefreshProjection(doc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.info("Successfully inserted {} documents in batch", savedDocuments.size());
|
||||
|
|
|
|||
|
|
@ -40,8 +40,6 @@ public class DocumentProcessingService {
|
|||
private final ProcessingLogService processingLogService;
|
||||
private final TedProcessorProperties properties;
|
||||
private final ApplicationEventPublisher eventPublisher;
|
||||
private final TedPhase2GenericDocumentService tedPhase2GenericDocumentService;
|
||||
private final TedNoticeProjectionService tedNoticeProjectionService;
|
||||
|
||||
/**
|
||||
* Process an XML document from the file system.
|
||||
|
|
@ -93,20 +91,10 @@ public class DocumentProcessingService {
|
|||
"Document parsed and stored successfully", null, filename,
|
||||
(int) (System.currentTimeMillis() - startTime));
|
||||
|
||||
if (properties.getProjection().isEnabled()) {
|
||||
tedNoticeProjectionService.registerOrRefreshProjection(document);
|
||||
log.debug("Document saved successfully, Phase 3 TED projection ensured: {}", document.getId());
|
||||
|
||||
// Keep legacy vectorization behavior when the generic embedding pipeline is disabled.
|
||||
eventPublisher.publishEvent(new DocumentSavedEvent(document.getId(), document.getPublicationId()));
|
||||
log.debug("Document saved successfully, legacy vectorization event published: {}", document.getId());
|
||||
|
||||
} else {
|
||||
// Publish event to trigger vectorization AFTER transaction commit
|
||||
// This ensures document is visible in DB and avoids transaction isolation issues
|
||||
eventPublisher.publishEvent(new DocumentSavedEvent(document.getId(), document.getPublicationId()));
|
||||
log.debug("Document saved successfully, vectorization event published: {}", document.getId());
|
||||
}
|
||||
// Publish event to trigger vectorization AFTER transaction commit
|
||||
// This ensures document is visible in DB and avoids transaction isolation issues
|
||||
eventPublisher.publishEvent(new DocumentSavedEvent(document.getId(), document.getPublicationId()));
|
||||
log.debug("Document saved successfully, vectorization event published: {}", document.getId());
|
||||
|
||||
return ProcessingResult.success(document.getId(), documentHash, document.getPublicationId());
|
||||
|
||||
|
|
@ -157,10 +145,6 @@ public class DocumentProcessingService {
|
|||
|
||||
documentRepository.save(updated);
|
||||
|
||||
if (properties.getProjection().isEnabled()) {
|
||||
tedNoticeProjectionService.registerOrRefreshProjection(updated);
|
||||
}
|
||||
|
||||
// Note: Re-vectorization will be triggered automatically by the active scheduler
|
||||
return updated;
|
||||
} catch (Exception e) {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,25 @@
|
|||
dip:
|
||||
embedding:
|
||||
enabled: true
|
||||
default-document-model: e5-default
|
||||
default-query-model: e5-default
|
||||
|
||||
providers:
|
||||
vector-sync-local:
|
||||
type: http-vector-sync
|
||||
base-url: http://localhost:8001
|
||||
connect-timeout: 5s
|
||||
read-timeout: 60s
|
||||
headers:
|
||||
X-Client: dip
|
||||
|
||||
models:
|
||||
e5-default:
|
||||
provider-config-key: vector-sync-local
|
||||
provider-model-key: intfloat/multilingual-e5-large
|
||||
dimensions: 1024
|
||||
distance-metric: COSINE
|
||||
supports-query-embedding-mode: true
|
||||
supports-batch: false
|
||||
max-input-chars: 8192
|
||||
active: true
|
||||
|
|
@ -0,0 +1,100 @@
|
|||
package at.procon.dip.embedding.provider.http;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
import at.procon.dip.domain.document.DistanceMetric;
|
||||
import at.procon.dip.embedding.model.EmbeddingModelDescriptor;
|
||||
import at.procon.dip.embedding.model.EmbeddingRequest;
|
||||
import at.procon.dip.embedding.model.EmbeddingUseCase;
|
||||
import at.procon.dip.embedding.model.ResolvedEmbeddingProviderConfig;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.sun.net.httpserver.HttpExchange;
|
||||
import com.sun.net.httpserver.HttpServer;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.time.Duration;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class VectorSyncHttpEmbeddingProviderTest {
|
||||
|
||||
private HttpServer server;
|
||||
|
||||
@AfterEach
|
||||
void tearDown() {
|
||||
if (server != null) {
|
||||
server.stop(0);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void shouldCallVectorSyncEndpointAndParseVector() throws Exception {
|
||||
server = HttpServer.create(new InetSocketAddress(0), 0);
|
||||
server.createContext("/vector-sync", this::handleVectorSync);
|
||||
server.start();
|
||||
|
||||
var provider = new VectorSyncHttpEmbeddingProvider(new ObjectMapper());
|
||||
var config = ResolvedEmbeddingProviderConfig.builder()
|
||||
.key("vector-sync-local")
|
||||
.providerType("http-vector-sync")
|
||||
.baseUrl("http://localhost:" + server.getAddress().getPort())
|
||||
.readTimeout(Duration.ofSeconds(5))
|
||||
.headers(Map.of("X-Client", "dip-test"))
|
||||
.build();
|
||||
var model = new EmbeddingModelDescriptor(
|
||||
"e5-default",
|
||||
"vector-sync-local",
|
||||
"intfloat/multilingual-e5-large",
|
||||
3,
|
||||
DistanceMetric.COSINE,
|
||||
true,
|
||||
false,
|
||||
8192,
|
||||
true
|
||||
);
|
||||
var request = EmbeddingRequest.builder()
|
||||
.modelKey("e5-default")
|
||||
.useCase(EmbeddingUseCase.DOCUMENT)
|
||||
.texts(List.of("This is a sample text to vectorize"))
|
||||
.providerOptions(Map.of())
|
||||
.build();
|
||||
|
||||
var result = provider.embedDocuments(config, model, request);
|
||||
|
||||
assertThat(result.vectors()).hasSize(1);
|
||||
assertThat(result.vectors().getFirst()).containsExactly(0.1f, 0.2f, 0.3f);
|
||||
assertThat(result.tokenCount()).isEqualTo(9);
|
||||
}
|
||||
|
||||
private void handleVectorSync(HttpExchange exchange) throws IOException {
|
||||
String body;
|
||||
try (InputStream in = exchange.getRequestBody()) {
|
||||
body = new String(in.readAllBytes(), StandardCharsets.UTF_8);
|
||||
}
|
||||
|
||||
assertThat(exchange.getRequestMethod()).isEqualTo("POST");
|
||||
assertThat(body).contains("\"model\":\"intfloat/multilingual-e5-large\"");
|
||||
assertThat(body).contains("\"text\":\"This is a sample text to vectorize\"");
|
||||
assertThat(exchange.getRequestHeaders().getFirst("X-Client")).isEqualTo("dip-test");
|
||||
|
||||
String json = "{"
|
||||
+ "\"runtime_ms\":472.49,"
|
||||
+ "\"vector\":[0.1,0.2,0.3],"
|
||||
+ "\"incomplete\":false,"
|
||||
+ "\"combined_vector\":null,"
|
||||
+ "\"token_count\":9,"
|
||||
+ "\"model\":\"intfloat/multilingual-e5-large\","
|
||||
+ "\"max_seq_length\":512"
|
||||
+ "}";
|
||||
|
||||
byte[] response = json.getBytes(StandardCharsets.UTF_8);
|
||||
exchange.getResponseHeaders().add("Content-Type", "application/json");
|
||||
exchange.sendResponseHeaders(200, response.length);
|
||||
exchange.getResponseBody().write(response);
|
||||
exchange.close();
|
||||
}
|
||||
}
|
||||
|
|
@ -48,7 +48,7 @@ class DefaultEmbeddingPolicyResolverTest {
|
|||
|
||||
DefaultEmbeddingPolicyResolver resolver = new DefaultEmbeddingPolicyResolver(properties);
|
||||
|
||||
SourceDescriptor descriptor = sourceDescriptor(SourceType.MAIL_MESSAGE, "message/rfc822", Map.of(
|
||||
SourceDescriptor descriptor = sourceDescriptor(SourceType.MAIL, "message/rfc822", Map.of(
|
||||
"embeddingPolicyHint", "mail-default",
|
||||
"embeddingPolicyKey", "ted-default"
|
||||
));
|
||||
|
|
@ -76,8 +76,8 @@ class DefaultEmbeddingPolicyResolverTest {
|
|||
|
||||
DefaultEmbeddingPolicyResolver resolver = new DefaultEmbeddingPolicyResolver(properties);
|
||||
|
||||
var policy = resolver.resolve(document(DocumentFamily.GENERIC, DocumentType.FILE, "en"),
|
||||
sourceDescriptor(SourceType.MAIL_ATTACHMENT, "application/pdf", Map.of()));
|
||||
var policy = resolver.resolve(document(DocumentFamily.GENERIC, DocumentType.UNKNOWN, "en"),
|
||||
sourceDescriptor(SourceType.MAIL, "application/pdf", Map.of()));
|
||||
|
||||
assertThat(policy.policyKey()).isEqualTo("mail-attachment-pdf");
|
||||
}
|
||||
|
|
@ -85,10 +85,10 @@ class DefaultEmbeddingPolicyResolverTest {
|
|||
@Test
|
||||
void shouldFailForUnknownOverridePolicy() {
|
||||
DefaultEmbeddingPolicyResolver resolver = new DefaultEmbeddingPolicyResolver(baseProperties());
|
||||
SourceDescriptor descriptor = sourceDescriptor(SourceType.FILE_IMPORT, "application/pdf", Map.of(
|
||||
SourceDescriptor descriptor = sourceDescriptor(SourceType.FILE_SYSTEM, "application/pdf", Map.of(
|
||||
"embeddingPolicyKey", "missing-policy"
|
||||
));
|
||||
assertThatThrownBy(() -> resolver.resolve(document(DocumentFamily.GENERIC, DocumentType.FILE, "en"), descriptor))
|
||||
assertThatThrownBy(() -> resolver.resolve(document(DocumentFamily.GENERIC, DocumentType.UNKNOWN, "en"), descriptor))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessageContaining("Unknown embedding policy key");
|
||||
}
|
||||
|
|
@ -111,7 +111,7 @@ class DefaultEmbeddingPolicyResolverTest {
|
|||
.documentFamily(family)
|
||||
.documentType(type)
|
||||
.languageCode(language)
|
||||
.status(DocumentStatus.IMPORTED)
|
||||
.status(DocumentStatus.RECEIVED)
|
||||
.visibility(DocumentVisibility.PUBLIC)
|
||||
.title("Test document")
|
||||
.build();
|
||||
|
|
|
|||
Loading…
Reference in New Issue