embedding nv1 + search tests
parent
40890101b1
commit
2687d4ba17
@ -0,0 +1,18 @@
|
||||
package at.procon.dip.config;
|
||||
|
||||
import com.fasterxml.jackson.databind.SerializationFeature;
|
||||
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
|
||||
import org.springframework.boot.autoconfigure.jackson.Jackson2ObjectMapperBuilderCustomizer;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
@Configuration
|
||||
public class JacksonConfig {
|
||||
|
||||
@Bean
|
||||
public Jackson2ObjectMapperBuilderCustomizer jsonCustomizer() {
|
||||
return builder -> builder
|
||||
.modules(new JavaTimeModule())
|
||||
.featuresToDisable(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS);
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,89 @@
|
||||
package at.procon.dip.domain.ted.service;
|
||||
|
||||
import at.procon.dip.domain.access.DocumentVisibility;
|
||||
import at.procon.dip.domain.document.DocumentFamily;
|
||||
import at.procon.dip.domain.document.DocumentStatus;
|
||||
import at.procon.dip.domain.document.DocumentType;
|
||||
import at.procon.dip.domain.document.entity.Document;
|
||||
import at.procon.dip.domain.document.repository.DocumentRepository;
|
||||
import at.procon.dip.domain.document.service.DocumentService;
|
||||
import at.procon.dip.domain.document.service.command.CreateDocumentCommand;
|
||||
import at.procon.ted.model.entity.ProcurementDocument;
|
||||
import java.util.UUID;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
/**
|
||||
* Side-effect-free helper for TED projection flows.
|
||||
* <p>
|
||||
* Unlike the legacy Phase 2 bridge, this service only ensures that the canonical
|
||||
* DOC document root exists and is refreshed with TED metadata. It intentionally
|
||||
* does not create/update sources, contents, representations, or embeddings.
|
||||
*/
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
@Slf4j
|
||||
public class TedGenericDocumentRootService {
|
||||
|
||||
private final DocumentRepository documentRepository;
|
||||
private final DocumentService documentService;
|
||||
|
||||
@Transactional
|
||||
public UUID ensureGenericTedDocumentRoot(ProcurementDocument tedDocument) {
|
||||
return ensureGenericTedDocument(tedDocument).getId();
|
||||
}
|
||||
|
||||
@Transactional
|
||||
public Document ensureGenericTedDocument(ProcurementDocument tedDocument) {
|
||||
Document document = documentRepository.findByDedupHash(tedDocument.getDocumentHash())
|
||||
.orElseGet(() -> createGenericDocument(tedDocument));
|
||||
|
||||
document.setDocumentType(DocumentType.TED_NOTICE);
|
||||
document.setDocumentFamily(DocumentFamily.PROCUREMENT);
|
||||
document.setVisibility(DocumentVisibility.PUBLIC);
|
||||
document.setStatus(DocumentStatus.CLASSIFIED);
|
||||
document.setTitle(tedDocument.getProjectTitle());
|
||||
document.setSummary(tedDocument.getProjectDescription());
|
||||
document.setLanguageCode(tedDocument.getLanguageCode());
|
||||
document.setMimeType("application/xml");
|
||||
document.setBusinessKey(buildBusinessKey(tedDocument));
|
||||
document.setDedupHash(tedDocument.getDocumentHash());
|
||||
|
||||
Document saved = documentService.save(document);
|
||||
log.debug("Ensured side-effect-free generic TED document root {} for legacy TED document {}",
|
||||
saved.getId(), tedDocument.getId());
|
||||
return saved;
|
||||
}
|
||||
|
||||
private Document createGenericDocument(ProcurementDocument tedDocument) {
|
||||
return documentService.create(new CreateDocumentCommand(
|
||||
null,
|
||||
DocumentVisibility.PUBLIC,
|
||||
DocumentType.TED_NOTICE,
|
||||
DocumentFamily.PROCUREMENT,
|
||||
DocumentStatus.CLASSIFIED,
|
||||
tedDocument.getProjectTitle(),
|
||||
tedDocument.getProjectDescription(),
|
||||
tedDocument.getLanguageCode(),
|
||||
"application/xml",
|
||||
buildBusinessKey(tedDocument),
|
||||
tedDocument.getDocumentHash()
|
||||
));
|
||||
}
|
||||
|
||||
private String buildBusinessKey(ProcurementDocument tedDocument) {
|
||||
if (StringUtils.hasText(tedDocument.getPublicationId())) {
|
||||
return "TED:publication:" + tedDocument.getPublicationId();
|
||||
}
|
||||
if (StringUtils.hasText(tedDocument.getNoticeId())) {
|
||||
return "TED:notice:" + tedDocument.getNoticeId();
|
||||
}
|
||||
if (StringUtils.hasText(tedDocument.getNoticeUrl())) {
|
||||
return "TED:url:" + tedDocument.getNoticeUrl();
|
||||
}
|
||||
return "TED:hash:" + tedDocument.getDocumentHash();
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,44 @@
|
||||
package at.procon.dip.embedding.config;
|
||||
|
||||
import at.procon.dip.domain.document.DistanceMetric;
|
||||
import java.time.Duration;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.Map;
|
||||
import lombok.Data;
|
||||
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
@Configuration
|
||||
@ConfigurationProperties(prefix = "dip.embedding")
|
||||
@Data
|
||||
public class EmbeddingProperties {
|
||||
|
||||
private boolean enabled = false;
|
||||
private String defaultDocumentModel;
|
||||
private String defaultQueryModel;
|
||||
private Map<String, ProviderProperties> providers = new LinkedHashMap<>();
|
||||
private Map<String, ModelProperties> models = new LinkedHashMap<>();
|
||||
|
||||
@Data
|
||||
public static class ProviderProperties {
|
||||
private String type;
|
||||
private String baseUrl;
|
||||
private String apiKey;
|
||||
private Duration connectTimeout = Duration.ofSeconds(5);
|
||||
private Duration readTimeout = Duration.ofSeconds(60);
|
||||
private Map<String, String> headers = new LinkedHashMap<>();
|
||||
private Integer dimensions;
|
||||
}
|
||||
|
||||
@Data
|
||||
public static class ModelProperties {
|
||||
private String providerConfigKey;
|
||||
private String providerModelKey;
|
||||
private Integer dimensions;
|
||||
private DistanceMetric distanceMetric = DistanceMetric.COSINE;
|
||||
private boolean supportsQueryEmbeddingMode = true;
|
||||
private boolean supportsBatch = false;
|
||||
private Integer maxInputChars;
|
||||
private boolean active = true;
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,16 @@
|
||||
package at.procon.dip.embedding.model;
|
||||
|
||||
import at.procon.dip.domain.document.DistanceMetric;
|
||||
|
||||
public record EmbeddingModelDescriptor(
|
||||
String modelKey,
|
||||
String providerConfigKey,
|
||||
String providerModelKey,
|
||||
int dimensions,
|
||||
DistanceMetric distanceMetric,
|
||||
boolean supportsQueryEmbeddingMode,
|
||||
boolean supportsBatch,
|
||||
Integer maxInputChars,
|
||||
boolean active
|
||||
) {
|
||||
}
|
||||
@ -0,0 +1,12 @@
|
||||
package at.procon.dip.embedding.model;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public record EmbeddingProviderResult(
|
||||
EmbeddingModelDescriptor model,
|
||||
List<float[]> vectors,
|
||||
List<String> warnings,
|
||||
String providerRequestId,
|
||||
Integer tokenCount
|
||||
) {
|
||||
}
|
||||
@ -0,0 +1,14 @@
|
||||
package at.procon.dip.embedding.model;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.Builder;
|
||||
|
||||
@Builder
|
||||
public record EmbeddingRequest(
|
||||
String modelKey,
|
||||
EmbeddingUseCase useCase,
|
||||
List<String> texts,
|
||||
Map<String, Object> providerOptions
|
||||
) {
|
||||
}
|
||||
@ -0,0 +1,6 @@
|
||||
package at.procon.dip.embedding.model;
|
||||
|
||||
/**
 * Purpose an embedding is requested for.
 */
public enum EmbeddingUseCase {

    /** Embedding of document text. */
    DOCUMENT,

    /** Embedding of a search query. */
    QUERY
}
|
||||
@ -0,0 +1,18 @@
|
||||
package at.procon.dip.embedding.model;
|
||||
|
||||
import java.time.Duration;
|
||||
import java.util.Map;
|
||||
import lombok.Builder;
|
||||
|
||||
@Builder
|
||||
public record ResolvedEmbeddingProviderConfig(
|
||||
String key,
|
||||
String providerType,
|
||||
String baseUrl,
|
||||
String apiKey,
|
||||
Duration connectTimeout,
|
||||
Duration readTimeout,
|
||||
Map<String, String> headers,
|
||||
Integer dimensions
|
||||
) {
|
||||
}
|
||||
@ -0,0 +1,25 @@
|
||||
package at.procon.dip.embedding.provider;
|
||||
|
||||
import at.procon.dip.embedding.model.EmbeddingModelDescriptor;
|
||||
import at.procon.dip.embedding.model.EmbeddingProviderResult;
|
||||
import at.procon.dip.embedding.model.EmbeddingRequest;
|
||||
import at.procon.dip.embedding.model.ResolvedEmbeddingProviderConfig;
|
||||
|
||||
public interface EmbeddingProvider {
|
||||
|
||||
String providerType();
|
||||
|
||||
boolean supports(EmbeddingModelDescriptor model, ResolvedEmbeddingProviderConfig providerConfig);
|
||||
|
||||
EmbeddingProviderResult embedDocuments(
|
||||
ResolvedEmbeddingProviderConfig providerConfig,
|
||||
EmbeddingModelDescriptor model,
|
||||
EmbeddingRequest request
|
||||
);
|
||||
|
||||
EmbeddingProviderResult embedQuery(
|
||||
ResolvedEmbeddingProviderConfig providerConfig,
|
||||
EmbeddingModelDescriptor model,
|
||||
EmbeddingRequest request
|
||||
);
|
||||
}
|
||||
@ -0,0 +1,151 @@
|
||||
package at.procon.dip.embedding.provider.http;
|
||||
|
||||
import at.procon.dip.embedding.model.EmbeddingModelDescriptor;
|
||||
import at.procon.dip.embedding.model.EmbeddingProviderResult;
|
||||
import at.procon.dip.embedding.model.EmbeddingRequest;
|
||||
import at.procon.dip.embedding.model.EmbeddingUseCase;
|
||||
import at.procon.dip.embedding.model.ResolvedEmbeddingProviderConfig;
|
||||
import at.procon.dip.embedding.provider.EmbeddingProvider;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import java.io.IOException;
|
||||
import java.net.URI;
|
||||
import java.net.http.HttpClient;
|
||||
import java.net.http.HttpRequest;
|
||||
import java.net.http.HttpResponse;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.time.Duration;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Component
|
||||
@RequiredArgsConstructor
|
||||
public class ExternalHttpEmbeddingProvider implements EmbeddingProvider {
|
||||
|
||||
private static final String PROVIDER_TYPE = "http-json";
|
||||
|
||||
private final ObjectMapper objectMapper;
|
||||
private final HttpClient httpClient = HttpClient.newBuilder().version(HttpClient.Version.HTTP_1_1).build();
|
||||
|
||||
@Override
|
||||
public String providerType() {
|
||||
return PROVIDER_TYPE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean supports(EmbeddingModelDescriptor model, ResolvedEmbeddingProviderConfig providerConfig) {
|
||||
return PROVIDER_TYPE.equalsIgnoreCase(providerConfig.providerType());
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingProviderResult embedDocuments(ResolvedEmbeddingProviderConfig providerConfig,
|
||||
EmbeddingModelDescriptor model,
|
||||
EmbeddingRequest request) {
|
||||
return execute(providerConfig, model, request, EmbeddingUseCase.DOCUMENT);
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingProviderResult embedQuery(ResolvedEmbeddingProviderConfig providerConfig,
|
||||
EmbeddingModelDescriptor model,
|
||||
EmbeddingRequest request) {
|
||||
return execute(providerConfig, model, request, EmbeddingUseCase.QUERY);
|
||||
}
|
||||
|
||||
private EmbeddingProviderResult execute(ResolvedEmbeddingProviderConfig providerConfig,
|
||||
EmbeddingModelDescriptor model,
|
||||
EmbeddingRequest request,
|
||||
EmbeddingUseCase useCase) {
|
||||
try {
|
||||
var payload = new ProviderRequest(
|
||||
model.providerModelKey(),
|
||||
request.texts(),
|
||||
useCase == EmbeddingUseCase.QUERY,
|
||||
request.providerOptions() == null ? Map.of() : request.providerOptions()
|
||||
);
|
||||
|
||||
HttpRequest.Builder builder = HttpRequest.newBuilder()
|
||||
.uri(URI.create(trimTrailingSlash(providerConfig.baseUrl()) + "/embed"))
|
||||
.timeout(providerConfig.readTimeout() == null ? Duration.ofSeconds(60) : providerConfig.readTimeout())
|
||||
.header("Content-Type", "application/json")
|
||||
.POST(HttpRequest.BodyPublishers.ofString(objectMapper.writeValueAsString(payload), StandardCharsets.UTF_8));
|
||||
|
||||
if (providerConfig.apiKey() != null && !providerConfig.apiKey().isBlank()) {
|
||||
builder.header("Authorization", "Bearer " + providerConfig.apiKey());
|
||||
}
|
||||
if (providerConfig.headers() != null) {
|
||||
providerConfig.headers().forEach(builder::header);
|
||||
}
|
||||
|
||||
HttpResponse<String> response = httpClient.send(builder.build(), HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8));
|
||||
if (response.statusCode() / 100 != 2) {
|
||||
throw new IllegalStateException("Embedding provider returned status %d: %s".formatted(response.statusCode(), response.body()));
|
||||
}
|
||||
|
||||
ProviderResponse parsed = objectMapper.readValue(response.body(), ProviderResponse.class);
|
||||
List<float[]> vectors = new ArrayList<>();
|
||||
if (parsed.embeddings != null) {
|
||||
for (List<Float> embedding : parsed.embeddings) {
|
||||
vectors.add(toArray(embedding));
|
||||
}
|
||||
} else if (parsed.embedding != null) {
|
||||
vectors.add(toArray(parsed.embedding));
|
||||
}
|
||||
|
||||
return new EmbeddingProviderResult(
|
||||
model,
|
||||
vectors,
|
||||
parsed.warnings == null ? List.of() : parsed.warnings,
|
||||
parsed.requestId,
|
||||
parsed.tokenCount
|
||||
);
|
||||
} catch (InterruptedException e) {
|
||||
Thread.currentThread().interrupt();
|
||||
throw new IllegalStateException("Embedding provider call interrupted", e);
|
||||
} catch (IOException e) {
|
||||
throw new IllegalStateException("Failed to call external embedding provider", e);
|
||||
}
|
||||
}
|
||||
|
||||
private float[] toArray(List<Float> embedding) {
|
||||
float[] result = new float[embedding.size()];
|
||||
for (int i = 0; i < embedding.size(); i++) {
|
||||
result[i] = embedding.get(i);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private String trimTrailingSlash(String value) {
|
||||
if (value == null || value.isBlank()) {
|
||||
throw new IllegalArgumentException("Embedding provider baseUrl must be configured");
|
||||
}
|
||||
return value.endsWith("/") ? value.substring(0, value.length() - 1) : value;
|
||||
}
|
||||
|
||||
private record ProviderRequest(
|
||||
@JsonProperty("model") String model,
|
||||
@JsonProperty("texts") List<String> texts,
|
||||
@JsonProperty("is_query") boolean query,
|
||||
@JsonProperty("options") Map<String, Object> options
|
||||
) {
|
||||
}
|
||||
|
||||
private static class ProviderResponse {
|
||||
@JsonProperty("embedding")
|
||||
public List<Float> embedding;
|
||||
|
||||
@JsonProperty("embeddings")
|
||||
public List<List<Float>> embeddings;
|
||||
|
||||
@JsonProperty("warnings")
|
||||
public List<String> warnings;
|
||||
|
||||
@JsonProperty("request_id")
|
||||
public String requestId;
|
||||
|
||||
@JsonProperty("token_count")
|
||||
public Integer tokenCount;
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,72 @@
|
||||
package at.procon.dip.embedding.provider.mock;
|
||||
|
||||
import at.procon.dip.embedding.model.EmbeddingModelDescriptor;
|
||||
import at.procon.dip.embedding.model.EmbeddingProviderResult;
|
||||
import at.procon.dip.embedding.model.EmbeddingRequest;
|
||||
import at.procon.dip.embedding.model.ResolvedEmbeddingProviderConfig;
|
||||
import at.procon.dip.embedding.provider.EmbeddingProvider;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Component
|
||||
public class MockEmbeddingProvider implements EmbeddingProvider {
|
||||
|
||||
private static final String PROVIDER_TYPE = "mock";
|
||||
|
||||
@Override
|
||||
public String providerType() {
|
||||
return PROVIDER_TYPE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean supports(EmbeddingModelDescriptor model, ResolvedEmbeddingProviderConfig providerConfig) {
|
||||
return PROVIDER_TYPE.equalsIgnoreCase(providerConfig.providerType());
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingProviderResult embedDocuments(ResolvedEmbeddingProviderConfig providerConfig,
|
||||
EmbeddingModelDescriptor model,
|
||||
EmbeddingRequest request) {
|
||||
return execute(providerConfig, model, request);
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingProviderResult embedQuery(ResolvedEmbeddingProviderConfig providerConfig,
|
||||
EmbeddingModelDescriptor model,
|
||||
EmbeddingRequest request) {
|
||||
return execute(providerConfig, model, request);
|
||||
}
|
||||
|
||||
private EmbeddingProviderResult execute(ResolvedEmbeddingProviderConfig providerConfig,
|
||||
EmbeddingModelDescriptor model,
|
||||
EmbeddingRequest request) {
|
||||
int dimensions = model.dimensions() > 0
|
||||
? model.dimensions()
|
||||
: providerConfig.dimensions() == null ? 16 : providerConfig.dimensions();
|
||||
|
||||
List<float[]> vectors = new ArrayList<>();
|
||||
for (String text : request.texts()) {
|
||||
vectors.add(embedDeterministically(text, dimensions));
|
||||
}
|
||||
|
||||
return new EmbeddingProviderResult(
|
||||
model,
|
||||
vectors,
|
||||
List.of(),
|
||||
"mock-" + UUID.randomUUID(),
|
||||
request.texts().stream().mapToInt(text -> text == null ? 0 : text.length()).sum()
|
||||
);
|
||||
}
|
||||
|
||||
private float[] embedDeterministically(String text, int dimensions) {
|
||||
float[] vector = new float[dimensions];
|
||||
String value = text == null ? "" : text;
|
||||
for (int i = 0; i < value.length(); i++) {
|
||||
int bucket = i % dimensions;
|
||||
vector[bucket] += ((value.charAt(i) % 31) + 1) / 31.0f;
|
||||
}
|
||||
return vector;
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,63 @@
|
||||
package at.procon.dip.embedding.registry;
|
||||
|
||||
import at.procon.dip.embedding.config.EmbeddingProperties;
|
||||
import at.procon.dip.embedding.model.EmbeddingModelDescriptor;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Component
|
||||
@RequiredArgsConstructor
|
||||
public class EmbeddingModelRegistry {
|
||||
|
||||
private final EmbeddingProperties properties;
|
||||
|
||||
public EmbeddingModelDescriptor getRequired(String modelKey) {
|
||||
return find(modelKey)
|
||||
.orElseThrow(() -> new IllegalArgumentException("Unknown embedding model key: " + modelKey));
|
||||
}
|
||||
|
||||
public Optional<EmbeddingModelDescriptor> find(String modelKey) {
|
||||
EmbeddingProperties.ModelProperties model = properties.getModels().get(modelKey);
|
||||
if (model == null) {
|
||||
return Optional.empty();
|
||||
}
|
||||
return Optional.of(toDescriptor(modelKey, model));
|
||||
}
|
||||
|
||||
public List<EmbeddingModelDescriptor> getActiveModels() {
|
||||
return properties.getModels().entrySet().stream()
|
||||
.filter(entry -> entry.getValue().isActive())
|
||||
.map(entry -> toDescriptor(entry.getKey(), entry.getValue()))
|
||||
.toList();
|
||||
}
|
||||
|
||||
public String getRequiredDefaultQueryModelKey() {
|
||||
if (properties.getDefaultQueryModel() == null || properties.getDefaultQueryModel().isBlank()) {
|
||||
throw new IllegalStateException("dip.embedding.default-query-model is not configured");
|
||||
}
|
||||
return properties.getDefaultQueryModel();
|
||||
}
|
||||
|
||||
public String getRequiredDefaultDocumentModelKey() {
|
||||
if (properties.getDefaultDocumentModel() == null || properties.getDefaultDocumentModel().isBlank()) {
|
||||
throw new IllegalStateException("dip.embedding.default-document-model is not configured");
|
||||
}
|
||||
return properties.getDefaultDocumentModel();
|
||||
}
|
||||
|
||||
private EmbeddingModelDescriptor toDescriptor(String modelKey, EmbeddingProperties.ModelProperties model) {
|
||||
return new EmbeddingModelDescriptor(
|
||||
modelKey,
|
||||
model.getProviderConfigKey(),
|
||||
model.getProviderModelKey() == null || model.getProviderModelKey().isBlank() ? modelKey : model.getProviderModelKey(),
|
||||
model.getDimensions() == null ? 0 : model.getDimensions(),
|
||||
model.getDistanceMetric(),
|
||||
model.isSupportsQueryEmbeddingMode(),
|
||||
model.isSupportsBatch(),
|
||||
model.getMaxInputChars(),
|
||||
model.isActive()
|
||||
);
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,32 @@
|
||||
package at.procon.dip.embedding.registry;
|
||||
|
||||
import at.procon.dip.embedding.config.EmbeddingProperties;
|
||||
import at.procon.dip.embedding.model.ResolvedEmbeddingProviderConfig;
|
||||
import java.util.Map;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Component
|
||||
@RequiredArgsConstructor
|
||||
public class EmbeddingProviderConfigResolver {
|
||||
|
||||
private final EmbeddingProperties properties;
|
||||
|
||||
public ResolvedEmbeddingProviderConfig resolve(String providerConfigKey) {
|
||||
EmbeddingProperties.ProviderProperties provider = properties.getProviders().get(providerConfigKey);
|
||||
if (provider == null) {
|
||||
throw new IllegalArgumentException("Unknown embedding provider config key: " + providerConfigKey);
|
||||
}
|
||||
|
||||
return ResolvedEmbeddingProviderConfig.builder()
|
||||
.key(providerConfigKey)
|
||||
.providerType(provider.getType())
|
||||
.baseUrl(provider.getBaseUrl())
|
||||
.apiKey(provider.getApiKey())
|
||||
.connectTimeout(provider.getConnectTimeout())
|
||||
.readTimeout(provider.getReadTimeout())
|
||||
.headers(provider.getHeaders() == null ? Map.of() : Map.copyOf(provider.getHeaders()))
|
||||
.dimensions(provider.getDimensions())
|
||||
.build();
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,20 @@
|
||||
package at.procon.dip.embedding.registry;
|
||||
|
||||
import at.procon.dip.embedding.provider.EmbeddingProvider;
|
||||
import java.util.List;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Component
|
||||
@RequiredArgsConstructor
|
||||
public class EmbeddingProviderRegistry {
|
||||
|
||||
private final List<EmbeddingProvider> providers;
|
||||
|
||||
public EmbeddingProvider getRequired(String providerType) {
|
||||
return providers.stream()
|
||||
.filter(provider -> provider.providerType().equalsIgnoreCase(providerType))
|
||||
.findFirst()
|
||||
.orElseThrow(() -> new IllegalArgumentException("No embedding provider registered for type: " + providerType));
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,36 @@
|
||||
package at.procon.dip.embedding.service;
|
||||
|
||||
import at.procon.dip.embedding.model.EmbeddingProviderResult;
|
||||
import at.procon.dip.embedding.model.EmbeddingUseCase;
|
||||
import at.procon.dip.embedding.registry.EmbeddingModelRegistry;
|
||||
import java.util.List;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
public class DefaultQueryEmbeddingService implements QueryEmbeddingService {
|
||||
|
||||
private final EmbeddingExecutionService executionService;
|
||||
private final EmbeddingModelRegistry modelRegistry;
|
||||
|
||||
@Override
|
||||
public float[] embedQuery(String queryText) {
|
||||
return embedQuery(queryText, modelRegistry.getRequiredDefaultQueryModelKey());
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] embedQuery(String queryText, String modelKey) {
|
||||
EmbeddingProviderResult result = executionService.embedTexts(
|
||||
modelKey,
|
||||
EmbeddingUseCase.QUERY,
|
||||
List.of(queryText)
|
||||
);
|
||||
|
||||
if (result.vectors() == null || result.vectors().isEmpty()) {
|
||||
throw new IllegalStateException("Embedding provider returned no query vector for model " + modelKey);
|
||||
}
|
||||
|
||||
return result.vectors().getFirst();
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,51 @@
|
||||
package at.procon.dip.embedding.service;
|
||||
|
||||
import at.procon.dip.embedding.model.EmbeddingProviderResult;
|
||||
import at.procon.dip.embedding.model.EmbeddingRequest;
|
||||
import at.procon.dip.embedding.model.EmbeddingUseCase;
|
||||
import at.procon.dip.embedding.provider.EmbeddingProvider;
|
||||
import at.procon.dip.embedding.registry.EmbeddingModelRegistry;
|
||||
import at.procon.dip.embedding.registry.EmbeddingProviderConfigResolver;
|
||||
import at.procon.dip.embedding.registry.EmbeddingProviderRegistry;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
public class EmbeddingExecutionService {
|
||||
|
||||
private final EmbeddingModelRegistry modelRegistry;
|
||||
private final EmbeddingProviderConfigResolver providerConfigResolver;
|
||||
private final EmbeddingProviderRegistry providerRegistry;
|
||||
|
||||
public EmbeddingProviderResult embedTexts(String modelKey, EmbeddingUseCase useCase, List<String> texts) {
|
||||
return embedTexts(modelKey, useCase, texts, Map.of());
|
||||
}
|
||||
|
||||
public EmbeddingProviderResult embedTexts(String modelKey,
|
||||
EmbeddingUseCase useCase,
|
||||
List<String> texts,
|
||||
Map<String, Object> providerOptions) {
|
||||
var model = modelRegistry.getRequired(modelKey);
|
||||
var providerConfig = providerConfigResolver.resolve(model.providerConfigKey());
|
||||
EmbeddingProvider provider = providerRegistry.getRequired(providerConfig.providerType());
|
||||
|
||||
if (!provider.supports(model, providerConfig)) {
|
||||
throw new IllegalStateException("Provider %s does not support model %s".formatted(
|
||||
provider.providerType(), model.modelKey()));
|
||||
}
|
||||
|
||||
EmbeddingRequest request = EmbeddingRequest.builder()
|
||||
.modelKey(model.modelKey())
|
||||
.useCase(useCase)
|
||||
.texts(texts)
|
||||
.providerOptions(providerOptions)
|
||||
.build();
|
||||
|
||||
return useCase == EmbeddingUseCase.QUERY
|
||||
? provider.embedQuery(providerConfig, model, request)
|
||||
: provider.embedDocuments(providerConfig, model, request);
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,8 @@
|
||||
package at.procon.dip.embedding.service;
|
||||
|
||||
/**
 * Produces embedding vectors for search queries.
 */
public interface QueryEmbeddingService {

    /**
     * Embeds a query without naming a model; the implementation chooses the model.
     *
     * @param queryText query text to embed
     * @return the embedding vector
     */
    float[] embedQuery(String queryText);

    /**
     * Embeds a query with an explicitly chosen model.
     *
     * @param queryText query text to embed
     * @param modelKey  key of the embedding model to use
     * @return the embedding vector
     */
    float[] embedQuery(String queryText, String modelKey);
}
|
||||
@ -0,0 +1,45 @@
|
||||
package at.procon.dip.embedding.startup;
|
||||
|
||||
import at.procon.dip.embedding.config.EmbeddingProperties;
|
||||
import at.procon.dip.embedding.registry.EmbeddingModelRegistry;
|
||||
import at.procon.dip.embedding.registry.EmbeddingProviderConfigResolver;
|
||||
import at.procon.dip.embedding.registry.EmbeddingProviderRegistry;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.boot.ApplicationArguments;
|
||||
import org.springframework.boot.ApplicationRunner;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Component
|
||||
@RequiredArgsConstructor
|
||||
@ConditionalOnProperty(prefix = "dip.embedding", name = "enabled", havingValue = "true")
|
||||
@Slf4j
|
||||
public class EmbeddingSubsystemStartupValidator implements ApplicationRunner {
|
||||
|
||||
private final EmbeddingProperties properties;
|
||||
private final EmbeddingModelRegistry modelRegistry;
|
||||
private final EmbeddingProviderConfigResolver providerConfigResolver;
|
||||
private final EmbeddingProviderRegistry providerRegistry;
|
||||
|
||||
@Override
|
||||
public void run(ApplicationArguments args) {
|
||||
if (properties.getModels().isEmpty()) {
|
||||
throw new IllegalStateException("dip.embedding.enabled=true but no models are configured");
|
||||
}
|
||||
|
||||
modelRegistry.getActiveModels().forEach(model -> {
|
||||
var providerConfig = providerConfigResolver.resolve(model.providerConfigKey());
|
||||
providerRegistry.getRequired(providerConfig.providerType());
|
||||
log.info("Validated embedding model {} -> provider {} ({})",
|
||||
model.modelKey(), model.providerConfigKey(), providerConfig.providerType());
|
||||
});
|
||||
|
||||
if (properties.getDefaultDocumentModel() != null && !properties.getDefaultDocumentModel().isBlank()) {
|
||||
modelRegistry.getRequired(properties.getDefaultDocumentModel());
|
||||
}
|
||||
if (properties.getDefaultQueryModel() != null && !properties.getDefaultQueryModel().isBlank()) {
|
||||
modelRegistry.getRequired(properties.getDefaultQueryModel());
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,3 @@
|
||||
-- Bootstrap script; every statement is guarded by IF NOT EXISTS.
-- Create the DOC schema (presumably the generic document model - confirm).
CREATE SCHEMA IF NOT EXISTS DOC;
|
||||
-- Create the TED schema (presumably the TED-specific tables - confirm).
CREATE SCHEMA IF NOT EXISTS TED;
|
||||
-- Install the pg_trgm extension (PostgreSQL trigram matching).
CREATE EXTENSION IF NOT EXISTS pg_trgm;
|
||||
Loading…
Reference in New Issue