#!/usr/bin/env python3
"""
TED Procurement Document Embedding Service

HTTP API for generating text embeddings using sentence-transformers.
Model: intfloat/multilingual-e5-large (1024 dimensions)

Author: Martin.Schweitzer@procon.co.at and claude.ai

Usage:
    python embedding_service.py

Environment Variables:
    MODEL_NAME: Model to use (default: intfloat/multilingual-e5-large)
    MAX_LENGTH: Maximum token length (default: 512)
    HOST: Server host (default: 0.0.0.0)
    PORT: Server port (default: 8001)

API Endpoints:
    POST /embed - Generate embedding for single text
    POST /embed/batch - Generate embeddings for multiple texts
    GET /health - Health check
"""
import os
import logging
import threading
import time
from typing import List
from contextlib import asynccontextmanager
import numpy as np
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
import uvicorn
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Suppress noisy HTTP warnings from uvicorn and asyncio
logging.getLogger("uvicorn.error").setLevel(logging.CRITICAL)
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
logging.getLogger("asyncio").setLevel(logging.CRITICAL)

# Configuration from environment (defaults documented in the module docstring)
MODEL_NAME = os.getenv("MODEL_NAME", "intfloat/multilingual-e5-large")
MAX_LENGTH = int(os.getenv("MAX_LENGTH", "512"))
HOST = os.getenv("HOST", "0.0.0.0")
PORT = int(os.getenv("PORT", "8001"))

# Global model instance (single model) with thread-safe access:
# `model_lock` serializes all access to the shared SentenceTransformer.
model = None
model_lock = threading.Lock()
# Embedding dimension, filled in once the model is loaded at startup.
model_dimensions = None

# Rolling service statistics, guarded by `stats_lock`.
embedding_count = 0
total_embedding_time = 0.0
stats_lock = threading.Lock()
class EmbedRequest(BaseModel):
    """Request model for single text embedding (POST /embed)."""
    # Raw text to embed; an e5 prefix may be prepended server-side.
    text: str = Field(..., description="Text to embed")
    # Controls the e5 prefix: "query: " when True, "passage: " when False.
    is_query: bool = Field(False, description="If True, use 'query:' prefix for e5 models")
class EmbedBatchRequest(BaseModel):
    """Request model for batch text embedding (POST /embed/batch)."""
    # Texts to embed in one call; an empty list is rejected with HTTP 400.
    texts: List[str] = Field(..., description="List of texts to embed")
    # Applied uniformly to every text in the batch.
    is_query: bool = Field(False, description="If True, use 'query:' prefix for e5 models")
class EmbedResponse(BaseModel):
    """Response model for a single embedding result."""
    embedding: List[float] = Field(..., description="Vector embedding")
    dimensions: int = Field(..., description="Number of dimensions")
    # Token count as measured before encoding; 0 if the check failed.
    token_count: int = Field(..., description="Number of input tokens")
class EmbedBatchResponse(BaseModel):
    """Response model for a batch embedding result."""
    embeddings: List[List[float]] = Field(..., description="List of vector embeddings")
    dimensions: int = Field(..., description="Number of dimensions")
    count: int = Field(..., description="Number of embeddings generated")
    # Parallel to `embeddings`: per-text token counts (0 where the check failed).
    token_counts: List[int] = Field(..., description="Number of input tokens for each text")
class HealthResponse(BaseModel):
    """Health check response (GET /health)."""

    # The `model_name` field collides with pydantic v2's protected `model_`
    # namespace and triggers a UserWarning at import time; clearing the
    # protected namespaces keeps the existing wire field name intact.
    model_config = {"protected_namespaces": ()}

    status: str                 # "healthy" once the model is loaded
    model_name: str             # configured MODEL_NAME
    dimensions: int             # embedding vector size
    max_length: int             # configured MAX_LENGTH token limit
    embeddings_processed: int   # total embeddings served since startup
    avg_time_ms: float          # mean wall-clock time per embedding
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load the embedding model once at startup.

    Loads the SentenceTransformer named by MODEL_NAME into the module-level
    `model` global and records its embedding dimension, then yields for the
    app's lifetime. On shutdown, logs final throughput statistics.

    Raises:
        Exception: re-raised if the model fails to load, aborting startup.
    """
    global model, model_dimensions
    # Imported lazily so the heavy sentence-transformers package is only
    # loaded when the service actually starts.
    from sentence_transformers import SentenceTransformer
    logger.info(f"Loading single model: {MODEL_NAME}")
    try:
        model = SentenceTransformer(MODEL_NAME)
        model_dimensions = model.get_sentence_embedding_dimension()
        logger.info(f"Model loaded successfully. Embedding dimension: {model_dimensions}")
        logger.info("Ready to process embeddings - statistics will be logged every 100 embeddings")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise
    yield
    # Cleanup: snapshot the counters under the stats lock and log a summary.
    with stats_lock:
        avg_time_ms = (total_embedding_time / embedding_count * 1000) if embedding_count > 0 else 0.0
        logger.info(f"Shutting down embedding service - Final statistics: {embedding_count} embeddings processed, average time: {avg_time_ms:.2f}ms per embedding")
# Create FastAPI app; `lifespan` loads the model before requests are served.
app = FastAPI(
    title="TED Embedding Service",
    description="Generate text embeddings using sentence-transformers for semantic search",
    version="1.0.0",
    lifespan=lifespan
)
def add_prefix(text: str, is_query: bool, model_name=None) -> str:
    """Prepend the e5-style instruction prefix when the model requires it.

    e5-family models are trained with "query: " / "passage: " prefixes;
    all other models receive the text unchanged.

    Args:
        text: Raw input text.
        is_query: True for search queries ("query: "), False for documents
            ("passage: ").
        model_name: Model identifier to inspect; defaults to the service-wide
            MODEL_NAME. Parameterized for reuse and testability.

    Returns:
        The (possibly prefixed) text.
    """
    name = MODEL_NAME if model_name is None else model_name
    if "e5" in name.lower():
        return ("query: " if is_query else "passage: ") + text
    return text
def check_token_length(text: str, model, max_length=None) -> tuple[int, bool]:
    """Count tokens for *text* and report whether it exceeds the limit.

    Best-effort: any tokenizer failure is logged at debug level and reported
    as (0, False) rather than raised, so embedding can still proceed.

    Args:
        text: Text to measure (already prefixed, if applicable).
        model: SentenceTransformer-like object exposing a ``tokenizer`` with
            an ``encode(text, add_special_tokens=...)`` method.
        max_length: Token limit to check against; defaults to the service-wide
            MAX_LENGTH. Parameterized for reuse and testability.

    Returns:
        tuple: (token_count, is_truncated).
    """
    limit = MAX_LENGTH if max_length is None else max_length
    try:
        tokens = model.tokenizer.encode(text, add_special_tokens=True)
        token_count = len(tokens)
        if token_count > limit:
            # Only compute byte length when we actually emit the warning.
            byte_count = len(text.encode('utf-8'))
            logger.warning(
                f"Text exceeds MAX_LENGTH ({limit} tokens). "
                f"Actual: {token_count} tokens, {byte_count} bytes ({len(text)} chars). "
                f"Text will be truncated by the model. "
                f"Preview: {text[:100]}..."
            )
            return token_count, True
        return token_count, False
    except Exception as e:
        logger.debug(f"Could not check token length: {e}")
        return 0, False
@app.post("/embed", response_model=EmbedResponse)
async def embed_text(request: EmbedRequest) -> EmbedResponse:
    """Generate embedding for a single text using the thread-safe single model.

    Returns HTTP 503 before the model is loaded and HTTP 500 on any
    encoding failure. The returned vector is L2-normalized.
    """
    global embedding_count, total_embedding_time
    if model is None:
        raise HTTPException(status_code=503, detail="Model not initialized")
    try:
        start_time = time.time()
        # Thread-safe access to single model: tokenization and encoding run
        # under the lock so concurrent requests never interleave on the model.
        with model_lock:
            # Add prefix for e5 models
            text = add_prefix(request.text, request.is_query)
            # Check token length and warn if exceeding MAX_LENGTH
            token_count, is_truncated = check_token_length(text, model)
            byte_count = len(text.encode('utf-8'))
            if is_truncated:
                logger.info(f"Processing text: {token_count} tokens, {byte_count} bytes ({len(text)} chars) - exceeds {MAX_LENGTH}, will be truncated")
            # Generate embedding (normalized, returned as a numpy array)
            embedding = model.encode(
                text,
                normalize_embeddings=True,
                convert_to_numpy=True
            )
        # Update statistics under their own lock (separate from model_lock)
        elapsed_time = time.time() - start_time
        with stats_lock:
            embedding_count += 1
            total_embedding_time += elapsed_time
            # Log statistics every 100 embeddings
            if embedding_count % 100 == 0:
                avg_time = total_embedding_time / embedding_count
                logger.info(f"Statistics: {embedding_count} embeddings processed, average time: {avg_time*1000:.2f}ms per embedding")
        return EmbedResponse(
            embedding=embedding.tolist(),
            dimensions=len(embedding),
            token_count=token_count
        )
    except Exception as e:
        logger.error(f"Embedding failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/embed/batch", response_model=EmbedBatchResponse)
async def embed_batch(request: EmbedBatchRequest) -> EmbedBatchResponse:
    """Generate embeddings for multiple texts using the thread-safe single model.

    Returns HTTP 503 before the model is loaded, HTTP 400 for an empty text
    list, and HTTP 500 on any encoding failure. Vectors are L2-normalized
    and `token_counts` is parallel to the input order.
    """
    global embedding_count, total_embedding_time
    if model is None:
        raise HTTPException(status_code=503, detail="Model not initialized")
    if not request.texts:
        raise HTTPException(status_code=400, detail="Empty text list")
    try:
        start_time = time.time()
        batch_count = len(request.texts)
        # Thread-safe access to single model
        with model_lock:
            # Add prefixes (same is_query flag applies to the whole batch)
            texts = [add_prefix(text, request.is_query) for text in request.texts]
            # Check token length for each text and warn if exceeding MAX_LENGTH
            truncated_count = 0
            token_counts = []
            for i, text in enumerate(texts):
                token_count, is_truncated = check_token_length(text, model)
                token_counts.append(token_count)
                if is_truncated:
                    truncated_count += 1
                    byte_count = len(text.encode('utf-8'))
                    logger.info(
                        f"Batch item {i + 1}/{len(texts)}: {token_count} tokens, "
                        f"{byte_count} bytes ({len(text)} chars) - exceeds {MAX_LENGTH}, will be truncated"
                    )
            if truncated_count > 0:
                logger.warning(
                    f"Batch processing: {truncated_count}/{len(texts)} texts exceed "
                    f"MAX_LENGTH ({MAX_LENGTH} tokens) and will be truncated"
                )
            # Generate embeddings in sub-batches of 16
            embeddings = model.encode(
                texts,
                normalize_embeddings=True,
                convert_to_numpy=True,
                batch_size=16,
                show_progress_bar=False
            )
        # Update statistics under their own lock (separate from model_lock)
        elapsed_time = time.time() - start_time
        with stats_lock:
            embedding_count += batch_count
            total_embedding_time += elapsed_time
            # Log statistics every 100 embeddings; the second clause fires
            # when a batch increment jumps over a multiple of 100.
            if embedding_count % 100 == 0 or (embedding_count // 100) != ((embedding_count - batch_count) // 100):
                avg_time = total_embedding_time / embedding_count
                logger.info(f"Statistics: {embedding_count} embeddings processed, average time: {avg_time*1000:.2f}ms per embedding")
        return EmbedBatchResponse(
            embeddings=[emb.tolist() for emb in embeddings],
            # encode() returns a 2-D array for a list input; the fallback
            # guards the degenerate 1-D case.
            dimensions=embeddings.shape[1] if len(embeddings.shape) > 1 else len(embeddings),
            count=len(embeddings),
            token_counts=token_counts
        )
    except Exception as e:
        logger.error(f"Batch embedding failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/health", response_model=HealthResponse)
async def health_check() -> HealthResponse:
    """Report service status, model configuration, and timing statistics."""
    if model is None:
        raise HTTPException(status_code=503, detail="Model not initialized")
    # Compute the running average under the stats lock.
    with stats_lock:
        if embedding_count > 0:
            avg_time_ms = total_embedding_time / embedding_count * 1000
        else:
            avg_time_ms = 0.0
    return HealthResponse(
        status="healthy",
        model_name=MODEL_NAME,
        dimensions=model_dimensions,
        max_length=MAX_LENGTH,
        embeddings_processed=embedding_count,
        avg_time_ms=round(avg_time_ms, 2)
    )
@app.get("/")
async def root():
    """Root endpoint with API info."""
    endpoint_summaries = {
        "embed": "POST /embed - Generate single embedding",
        "embed_batch": "POST /embed/batch - Generate batch embeddings",
        "health": "GET /health - Health check",
    }
    return {
        "service": "TED Embedding Service",
        "model": MODEL_NAME,
        "endpoints": endpoint_summaries,
    }
if __name__ == "__main__":
    logger.info(f"Starting embedding service on {HOST}:{PORT}")
    # Run via the import-string form so uvicorn owns the app lifecycle;
    # auto-reload is disabled for production-style deployment.
    uvicorn.run(
        "embedding_service:app",
        host=HOST,
        port=PORT,
        log_level="info",
        reload=False
    )