vectorization using py temporal service
parent
177c61803e
commit
3913ed83e8
@ -0,0 +1,39 @@
|
||||
# Vector-sync HTTP embedding provider
|
||||
|
||||
This patch adds a new provider type:
|
||||
|
||||
- `http-vector-sync`
|
||||
|
||||
## Request
|
||||
Endpoint:
|
||||
- `POST {baseUrl}/vector-sync`
|
||||
|
||||
Request body:
|
||||
```json
|
||||
{
|
||||
"model": "intfloat/multilingual-e5-large",
|
||||
"text": "This is a sample text to vectorize"
|
||||
}
|
||||
```
|
||||
|
||||
## Response
|
||||
```json
|
||||
{
|
||||
"runtime_ms": 472.49,
|
||||
"vector": [0.1, 0.2, 0.3],
|
||||
"incomplete": false,
|
||||
"combined_vector": null,
|
||||
"token_count": 9,
|
||||
"model": "intfloat/multilingual-e5-large",
|
||||
"max_seq_length": 512
|
||||
}
|
||||
```
|
||||
|
||||
## Notes
|
||||
- supports a single text per request
|
||||
- works for both document and query embeddings
|
||||
- validates returned vector dimension against the configured embedding model
|
||||
- keeps the existing `/embed` provider in place as `http-json`
|
||||
|
||||
## Example config
|
||||
See `application-new-example-vector-sync-provider.yml`.
|
||||
@ -0,0 +1,68 @@
|
||||
package at.procon.dip.embedding.provider.http;
|
||||
|
||||
import at.procon.dip.embedding.model.ResolvedEmbeddingProviderConfig;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import java.io.IOException;
|
||||
import java.net.URI;
|
||||
import java.net.http.HttpClient;
|
||||
import java.net.http.HttpRequest;
|
||||
import java.net.http.HttpResponse;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.time.Duration;
|
||||
import java.util.List;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
|
||||
@RequiredArgsConstructor
|
||||
abstract class AbstractHttpEmbeddingProviderSupport {
|
||||
|
||||
protected final ObjectMapper objectMapper;
|
||||
protected final HttpClient httpClient = HttpClient.newBuilder()
|
||||
.version(HttpClient.Version.HTTP_1_1)
|
||||
.build();
|
||||
|
||||
protected String trimTrailingSlash(String value) {
|
||||
if (value == null || value.isBlank()) {
|
||||
throw new IllegalArgumentException("Embedding provider baseUrl must be configured");
|
||||
}
|
||||
return value.endsWith("/") ? value.substring(0, value.length() - 1) : value;
|
||||
}
|
||||
|
||||
protected HttpResponse<String> postJson(ResolvedEmbeddingProviderConfig providerConfig,
|
||||
String path,
|
||||
Object body) throws IOException, InterruptedException {
|
||||
HttpRequest.Builder builder = HttpRequest.newBuilder()
|
||||
.uri(URI.create(trimTrailingSlash(providerConfig.baseUrl()) + path))
|
||||
.timeout(providerConfig.readTimeout() == null ? Duration.ofSeconds(60) : providerConfig.readTimeout())
|
||||
.header("Content-Type", "application/json")
|
||||
.POST(HttpRequest.BodyPublishers.ofString(
|
||||
objectMapper.writeValueAsString(body),
|
||||
StandardCharsets.UTF_8
|
||||
));
|
||||
|
||||
if (providerConfig.apiKey() != null && !providerConfig.apiKey().isBlank()) {
|
||||
builder.header("Authorization", "Bearer " + providerConfig.apiKey());
|
||||
}
|
||||
if (providerConfig.headers() != null) {
|
||||
providerConfig.headers().forEach(builder::header);
|
||||
}
|
||||
|
||||
HttpResponse<String> response = httpClient.send(
|
||||
builder.build(),
|
||||
HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)
|
||||
);
|
||||
if (response.statusCode() / 100 != 2) {
|
||||
throw new IllegalStateException(
|
||||
"Embedding provider returned status %d: %s".formatted(response.statusCode(), response.body())
|
||||
);
|
||||
}
|
||||
return response;
|
||||
}
|
||||
|
||||
protected float[] toArray(List<Float> embedding) {
|
||||
float[] result = new float[embedding.size()];
|
||||
for (int i = 0; i < embedding.size(); i++) {
|
||||
result[i] = embedding.get(i);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,438 @@
|
||||
package at.procon.dip.embedding.provider.http;
|
||||
|
||||
import at.procon.dip.embedding.model.EmbeddingModelDescriptor;
|
||||
import at.procon.dip.embedding.model.EmbeddingProviderResult;
|
||||
import at.procon.dip.embedding.model.EmbeddingRequest;
|
||||
import at.procon.dip.embedding.model.ResolvedEmbeddingProviderConfig;
|
||||
import at.procon.dip.embedding.provider.EmbeddingProvider;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import java.io.IOException;
|
||||
import java.lang.reflect.Field;
|
||||
import java.lang.reflect.Method;
|
||||
import java.net.http.HttpResponse;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
/**
|
||||
* HTTP provider for vector APIs.
|
||||
*
|
||||
* Supported endpoints:
|
||||
* POST {baseUrl}/vector-sync - single text
|
||||
* POST {baseUrl}/vectorize-batch - multiple texts
|
||||
*
|
||||
* Batch settings are resolved from provider config if present, otherwise defaults are used.
|
||||
* Supported property keys:
|
||||
* - vectorize-batch.truncate-text
|
||||
* - vectorize-batch.truncate-length
|
||||
* - vectorize-batch.chunk-size
|
||||
*
|
||||
* Also accepted as fallbacks:
|
||||
* - truncate_text / truncate-length / chunk_size
|
||||
* - truncateText / truncateLength / chunkSize
|
||||
*/
|
||||
@Component
|
||||
public class VectorSyncHttpEmbeddingProvider extends AbstractHttpEmbeddingProviderSupport implements EmbeddingProvider {
|
||||
private static final String PROVIDER_TYPE = "http-vector-sync";
|
||||
|
||||
private static final boolean DEFAULT_TRUNCATE_TEXT = false;
|
||||
private static final int DEFAULT_TRUNCATE_LENGTH = 512;
|
||||
private static final int DEFAULT_CHUNK_SIZE = 20;
|
||||
|
||||
private static final List<String> TRUNCATE_TEXT_KEYS = List.of(
|
||||
"vectorize-batch.truncate-text",
|
||||
"vectorize-batch.truncate_text",
|
||||
"truncate_text",
|
||||
"truncate-text",
|
||||
"truncateText"
|
||||
);
|
||||
|
||||
private static final List<String> TRUNCATE_LENGTH_KEYS = List.of(
|
||||
"vectorize-batch.truncate-length",
|
||||
"vectorize-batch.truncate_length",
|
||||
"truncate_length",
|
||||
"truncate-length",
|
||||
"truncateLength"
|
||||
);
|
||||
|
||||
private static final List<String> CHUNK_SIZE_KEYS = List.of(
|
||||
"vectorize-batch.chunk-size",
|
||||
"vectorize-batch.chunk_size",
|
||||
"chunk_size",
|
||||
"chunk-size",
|
||||
"chunkSize"
|
||||
);
|
||||
|
||||
public VectorSyncHttpEmbeddingProvider(ObjectMapper objectMapper) {
|
||||
super(objectMapper);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String providerType() {
|
||||
return PROVIDER_TYPE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean supports(EmbeddingModelDescriptor model, ResolvedEmbeddingProviderConfig providerConfig) {
|
||||
return PROVIDER_TYPE.equalsIgnoreCase(providerConfig.providerType());
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingProviderResult embedDocuments(ResolvedEmbeddingProviderConfig providerConfig,
|
||||
EmbeddingModelDescriptor model,
|
||||
EmbeddingRequest request) {
|
||||
return execute(providerConfig, model, request);
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingProviderResult embedQuery(ResolvedEmbeddingProviderConfig providerConfig,
|
||||
EmbeddingModelDescriptor model,
|
||||
EmbeddingRequest request) {
|
||||
return execute(providerConfig, model, request);
|
||||
}
|
||||
|
||||
private EmbeddingProviderResult execute(ResolvedEmbeddingProviderConfig providerConfig,
|
||||
EmbeddingModelDescriptor model,
|
||||
EmbeddingRequest request) {
|
||||
if (request.texts() == null || request.texts().isEmpty()) {
|
||||
throw new IllegalArgumentException("Embedding request texts must not be empty");
|
||||
}
|
||||
|
||||
try {
|
||||
return request.texts().size() == 1
|
||||
? executeSingle(providerConfig, model, request.texts().getFirst())
|
||||
: executeBatch(providerConfig, model, request.texts());
|
||||
} catch (InterruptedException e) {
|
||||
Thread.currentThread().interrupt();
|
||||
throw new IllegalStateException("Embedding provider call interrupted", e);
|
||||
} catch (IOException e) {
|
||||
throw new IllegalStateException("Failed to call embedding provider", e);
|
||||
}
|
||||
}
|
||||
|
||||
private EmbeddingProviderResult executeSingle(ResolvedEmbeddingProviderConfig providerConfig,
|
||||
EmbeddingModelDescriptor model,
|
||||
String text) throws IOException, InterruptedException {
|
||||
HttpResponse<String> response = postJson(
|
||||
providerConfig,
|
||||
"/vector-sync",
|
||||
new VectorSyncRequest(model.providerModelKey(), text)
|
||||
);
|
||||
|
||||
VectorSyncResponse parsed = objectMapper.readValue(response.body(), VectorSyncResponse.class);
|
||||
float[] vector = extractVector(parsed.vector, parsed.combinedVector, model);
|
||||
|
||||
return new EmbeddingProviderResult(
|
||||
model,
|
||||
List.of(vector),
|
||||
List.of(),
|
||||
null,
|
||||
parsed.tokenCount
|
||||
);
|
||||
}
|
||||
|
||||
private EmbeddingProviderResult executeBatch(ResolvedEmbeddingProviderConfig providerConfig,
|
||||
EmbeddingModelDescriptor model,
|
||||
List<String> texts) throws IOException, InterruptedException {
|
||||
boolean truncateText = resolveBooleanProperty(providerConfig, TRUNCATE_TEXT_KEYS, DEFAULT_TRUNCATE_TEXT);
|
||||
int truncateLength = resolveIntProperty(providerConfig, TRUNCATE_LENGTH_KEYS, DEFAULT_TRUNCATE_LENGTH);
|
||||
int chunkSize = resolveIntProperty(providerConfig, CHUNK_SIZE_KEYS, DEFAULT_CHUNK_SIZE);
|
||||
|
||||
if (truncateLength <= 0) {
|
||||
throw new IllegalArgumentException("Batch truncate length must be > 0");
|
||||
}
|
||||
if (chunkSize <= 0) {
|
||||
throw new IllegalArgumentException("Batch chunk size must be > 0");
|
||||
}
|
||||
|
||||
List<String> requestOrder = new ArrayList<>(texts.size());
|
||||
List<VectorizeBatchItemRequest> items = new ArrayList<>(texts.size());
|
||||
|
||||
for (String text : texts) {
|
||||
String id = UUID.randomUUID().toString();
|
||||
requestOrder.add(id);
|
||||
items.add(new VectorizeBatchItemRequest(id, text));
|
||||
}
|
||||
|
||||
HttpResponse<String> response = postJson(
|
||||
providerConfig,
|
||||
"/vectorize-batch",
|
||||
new VectorizeBatchRequest(
|
||||
model.providerModelKey(),
|
||||
truncateText,
|
||||
truncateLength,
|
||||
chunkSize,
|
||||
items
|
||||
)
|
||||
);
|
||||
|
||||
VectorizeBatchResponse parsed = objectMapper.readValue(response.body(), VectorizeBatchResponse.class);
|
||||
if (parsed.results == null || parsed.results.isEmpty()) {
|
||||
throw new IllegalStateException("Vectorize-batch provider returned no results");
|
||||
}
|
||||
|
||||
Map<String, VectorizeBatchItemResponse> resultById = new HashMap<>();
|
||||
for (VectorizeBatchItemResponse result : parsed.results) {
|
||||
resultById.put(result.id, result);
|
||||
}
|
||||
|
||||
List<float[]> vectors = new ArrayList<>(texts.size());
|
||||
int totalTokenCount = 0;
|
||||
boolean hasAnyTokenCount = false;
|
||||
|
||||
for (String id : requestOrder) {
|
||||
VectorizeBatchItemResponse item = resultById.get(id);
|
||||
if (item == null) {
|
||||
throw new IllegalStateException("Vectorize-batch provider response is missing item for id " + id);
|
||||
}
|
||||
|
||||
vectors.add(extractVector(item.vector, item.combinedVector, model));
|
||||
|
||||
if (item.tokenCount != null) {
|
||||
totalTokenCount += item.tokenCount;
|
||||
hasAnyTokenCount = true;
|
||||
}
|
||||
}
|
||||
|
||||
return new EmbeddingProviderResult(
|
||||
model,
|
||||
vectors,
|
||||
List.of(),
|
||||
null,
|
||||
hasAnyTokenCount ? totalTokenCount : null
|
||||
);
|
||||
}
|
||||
|
||||
private float[] extractVector(List<Float> vector,
|
||||
List<Float> combinedVector,
|
||||
EmbeddingModelDescriptor model) {
|
||||
float[] resolved;
|
||||
|
||||
if (combinedVector != null && !combinedVector.isEmpty()) {
|
||||
resolved = toArray(combinedVector);
|
||||
} else if (vector != null && !vector.isEmpty()) {
|
||||
resolved = toArray(vector);
|
||||
} else {
|
||||
throw new IllegalStateException("Embedding provider returned no vector");
|
||||
}
|
||||
|
||||
if (model.dimensions() > 0 && resolved.length != model.dimensions()) {
|
||||
throw new IllegalStateException(
|
||||
"Embedding provider returned dimension %d for model %s, expected %d"
|
||||
.formatted(resolved.length, model.modelKey(), model.dimensions())
|
||||
);
|
||||
}
|
||||
|
||||
return resolved;
|
||||
}
|
||||
|
||||
private boolean resolveBooleanProperty(ResolvedEmbeddingProviderConfig providerConfig,
|
||||
List<String> keys,
|
||||
boolean defaultValue) {
|
||||
Object raw = resolveProviderConfigValue(providerConfig, keys);
|
||||
if (raw == null) {
|
||||
return defaultValue;
|
||||
}
|
||||
if (raw instanceof Boolean b) {
|
||||
return b;
|
||||
}
|
||||
String normalized = String.valueOf(raw).trim();
|
||||
if (normalized.isEmpty()) {
|
||||
return defaultValue;
|
||||
}
|
||||
return Boolean.parseBoolean(normalized);
|
||||
}
|
||||
|
||||
private int resolveIntProperty(ResolvedEmbeddingProviderConfig providerConfig,
|
||||
List<String> keys,
|
||||
int defaultValue) {
|
||||
Object raw = resolveProviderConfigValue(providerConfig, keys);
|
||||
if (raw == null) {
|
||||
return defaultValue;
|
||||
}
|
||||
if (raw instanceof Number n) {
|
||||
return n.intValue();
|
||||
}
|
||||
String normalized = String.valueOf(raw).trim();
|
||||
if (normalized.isEmpty()) {
|
||||
return defaultValue;
|
||||
}
|
||||
return Integer.parseInt(normalized);
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private Object resolveProviderConfigValue(ResolvedEmbeddingProviderConfig providerConfig,
|
||||
List<String> keys) {
|
||||
List<Object> containers = new ArrayList<>();
|
||||
containers.add(providerConfig);
|
||||
|
||||
addIfNonNull(containers, invokeNoArg(providerConfig, "properties"));
|
||||
addIfNonNull(containers, invokeNoArg(providerConfig, "providerProperties"));
|
||||
addIfNonNull(containers, invokeNoArg(providerConfig, "config"));
|
||||
addIfNonNull(containers, invokeNoArg(providerConfig, "settings"));
|
||||
addIfNonNull(containers, invokeNoArg(providerConfig, "options"));
|
||||
addIfNonNull(containers, readField(providerConfig, "properties"));
|
||||
addIfNonNull(containers, readField(providerConfig, "providerProperties"));
|
||||
addIfNonNull(containers, readField(providerConfig, "config"));
|
||||
addIfNonNull(containers, readField(providerConfig, "settings"));
|
||||
addIfNonNull(containers, readField(providerConfig, "options"));
|
||||
|
||||
for (Object container : containers) {
|
||||
if (container instanceof Map<?, ?> map) {
|
||||
for (String key : keys) {
|
||||
if (map.containsKey(key)) {
|
||||
return map.get(key);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (String key : keys) {
|
||||
Object value = invokeStringArg(container, "get", key);
|
||||
if (value != null) {
|
||||
return value;
|
||||
}
|
||||
value = invokeStringArg(container, "getProperty", key);
|
||||
if (value != null) {
|
||||
return value;
|
||||
}
|
||||
value = invokeStringArg(container, "property", key);
|
||||
if (value != null) {
|
||||
return value;
|
||||
}
|
||||
value = invokeStringArg(container, "option", key);
|
||||
if (value != null) {
|
||||
return value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
private void addIfNonNull(List<Object> containers, Object value) {
|
||||
if (value != null) {
|
||||
containers.add(value);
|
||||
}
|
||||
}
|
||||
|
||||
private Object invokeNoArg(Object target, String methodName) {
|
||||
try {
|
||||
Method method = target.getClass().getMethod(methodName);
|
||||
method.setAccessible(true);
|
||||
return method.invoke(target);
|
||||
} catch (Exception ignored) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private Object invokeStringArg(Object target, String methodName, String arg) {
|
||||
try {
|
||||
Method method = target.getClass().getMethod(methodName, String.class);
|
||||
method.setAccessible(true);
|
||||
return method.invoke(target, arg);
|
||||
} catch (Exception ignored) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private Object readField(Object target, String fieldName) {
|
||||
try {
|
||||
Field field = target.getClass().getDeclaredField(fieldName);
|
||||
field.setAccessible(true);
|
||||
return field.get(target);
|
||||
} catch (Exception ignored) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private record VectorSyncRequest(
|
||||
@JsonProperty("model") String model,
|
||||
@JsonProperty("text") String text
|
||||
) {
|
||||
}
|
||||
|
||||
private record VectorizeBatchRequest(
|
||||
@JsonProperty("model") String model,
|
||||
@JsonProperty("truncate_text") boolean truncateText,
|
||||
@JsonProperty("truncate_length") int truncateLength,
|
||||
@JsonProperty("chunk_size") int chunkSize,
|
||||
@JsonProperty("items") List<VectorizeBatchItemRequest> items
|
||||
) {
|
||||
}
|
||||
|
||||
private record VectorizeBatchItemRequest(
|
||||
@JsonProperty("id") String id,
|
||||
@JsonProperty("text") String text
|
||||
) {
|
||||
}
|
||||
|
||||
static class VectorSyncResponse {
|
||||
@JsonProperty("runtime_ms")
|
||||
public Double runtimeMs;
|
||||
|
||||
@JsonProperty("vector")
|
||||
public List<Float> vector;
|
||||
|
||||
@JsonProperty("incomplete")
|
||||
public Boolean incomplete;
|
||||
|
||||
@JsonProperty("combined_vector")
|
||||
public List<Float> combinedVector;
|
||||
|
||||
@JsonProperty("token_count")
|
||||
public Integer tokenCount;
|
||||
|
||||
@JsonProperty("model")
|
||||
public String model;
|
||||
|
||||
@JsonProperty("max_seq_length")
|
||||
public Integer maxSeqLength;
|
||||
}
|
||||
|
||||
static class VectorizeBatchResponse {
|
||||
@JsonProperty("model")
|
||||
public String model;
|
||||
|
||||
@JsonProperty("count")
|
||||
public Integer count;
|
||||
|
||||
@JsonProperty("results")
|
||||
public List<VectorizeBatchItemResponse> results;
|
||||
}
|
||||
|
||||
static class VectorizeBatchItemResponse {
|
||||
@JsonProperty("id")
|
||||
public String id;
|
||||
|
||||
@JsonProperty("vector")
|
||||
public List<Float> vector;
|
||||
|
||||
@JsonProperty("token_count")
|
||||
public Integer tokenCount;
|
||||
|
||||
@JsonProperty("runtime_ms")
|
||||
public Double runtimeMs;
|
||||
|
||||
@JsonProperty("incomplete")
|
||||
public Boolean incomplete;
|
||||
|
||||
@JsonProperty("combined_vector")
|
||||
public List<Float> combinedVector;
|
||||
|
||||
@JsonProperty("truncated")
|
||||
public Boolean truncated;
|
||||
|
||||
@JsonProperty("truncate_length")
|
||||
public Integer truncateLength;
|
||||
|
||||
@JsonProperty("model")
|
||||
public String model;
|
||||
|
||||
@JsonProperty("max_seq_length")
|
||||
public Integer maxSeqLength;
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,25 @@
|
||||
dip:
|
||||
embedding:
|
||||
enabled: true
|
||||
default-document-model: e5-default
|
||||
default-query-model: e5-default
|
||||
|
||||
providers:
|
||||
vector-sync-local:
|
||||
type: http-vector-sync
|
||||
base-url: http://localhost:8001
|
||||
connect-timeout: 5s
|
||||
read-timeout: 60s
|
||||
headers:
|
||||
X-Client: dip
|
||||
|
||||
models:
|
||||
e5-default:
|
||||
provider-config-key: vector-sync-local
|
||||
provider-model-key: intfloat/multilingual-e5-large
|
||||
dimensions: 1024
|
||||
distance-metric: COSINE
|
||||
supports-query-embedding-mode: true
|
||||
supports-batch: false
|
||||
max-input-chars: 8192
|
||||
active: true
|
||||
@ -0,0 +1,100 @@
|
||||
package at.procon.dip.embedding.provider.http;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
import at.procon.dip.domain.document.DistanceMetric;
|
||||
import at.procon.dip.embedding.model.EmbeddingModelDescriptor;
|
||||
import at.procon.dip.embedding.model.EmbeddingRequest;
|
||||
import at.procon.dip.embedding.model.EmbeddingUseCase;
|
||||
import at.procon.dip.embedding.model.ResolvedEmbeddingProviderConfig;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.sun.net.httpserver.HttpExchange;
|
||||
import com.sun.net.httpserver.HttpServer;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.net.InetSocketAddress;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.time.Duration;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class VectorSyncHttpEmbeddingProviderTest {
|
||||
|
||||
private HttpServer server;
|
||||
|
||||
@AfterEach
|
||||
void tearDown() {
|
||||
if (server != null) {
|
||||
server.stop(0);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void shouldCallVectorSyncEndpointAndParseVector() throws Exception {
|
||||
server = HttpServer.create(new InetSocketAddress(0), 0);
|
||||
server.createContext("/vector-sync", this::handleVectorSync);
|
||||
server.start();
|
||||
|
||||
var provider = new VectorSyncHttpEmbeddingProvider(new ObjectMapper());
|
||||
var config = ResolvedEmbeddingProviderConfig.builder()
|
||||
.key("vector-sync-local")
|
||||
.providerType("http-vector-sync")
|
||||
.baseUrl("http://localhost:" + server.getAddress().getPort())
|
||||
.readTimeout(Duration.ofSeconds(5))
|
||||
.headers(Map.of("X-Client", "dip-test"))
|
||||
.build();
|
||||
var model = new EmbeddingModelDescriptor(
|
||||
"e5-default",
|
||||
"vector-sync-local",
|
||||
"intfloat/multilingual-e5-large",
|
||||
3,
|
||||
DistanceMetric.COSINE,
|
||||
true,
|
||||
false,
|
||||
8192,
|
||||
true
|
||||
);
|
||||
var request = EmbeddingRequest.builder()
|
||||
.modelKey("e5-default")
|
||||
.useCase(EmbeddingUseCase.DOCUMENT)
|
||||
.texts(List.of("This is a sample text to vectorize"))
|
||||
.providerOptions(Map.of())
|
||||
.build();
|
||||
|
||||
var result = provider.embedDocuments(config, model, request);
|
||||
|
||||
assertThat(result.vectors()).hasSize(1);
|
||||
assertThat(result.vectors().getFirst()).containsExactly(0.1f, 0.2f, 0.3f);
|
||||
assertThat(result.tokenCount()).isEqualTo(9);
|
||||
}
|
||||
|
||||
private void handleVectorSync(HttpExchange exchange) throws IOException {
|
||||
String body;
|
||||
try (InputStream in = exchange.getRequestBody()) {
|
||||
body = new String(in.readAllBytes(), StandardCharsets.UTF_8);
|
||||
}
|
||||
|
||||
assertThat(exchange.getRequestMethod()).isEqualTo("POST");
|
||||
assertThat(body).contains("\"model\":\"intfloat/multilingual-e5-large\"");
|
||||
assertThat(body).contains("\"text\":\"This is a sample text to vectorize\"");
|
||||
assertThat(exchange.getRequestHeaders().getFirst("X-Client")).isEqualTo("dip-test");
|
||||
|
||||
String json = "{"
|
||||
+ "\"runtime_ms\":472.49,"
|
||||
+ "\"vector\":[0.1,0.2,0.3],"
|
||||
+ "\"incomplete\":false,"
|
||||
+ "\"combined_vector\":null,"
|
||||
+ "\"token_count\":9,"
|
||||
+ "\"model\":\"intfloat/multilingual-e5-large\","
|
||||
+ "\"max_seq_length\":512"
|
||||
+ "}";
|
||||
|
||||
byte[] response = json.getBytes(StandardCharsets.UTF_8);
|
||||
exchange.getResponseHeaders().add("Content-Type", "application/json");
|
||||
exchange.sendResponseHeaders(200, response.length);
|
||||
exchange.getResponseBody().write(response);
|
||||
exchange.close();
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue