#!/usr/bin/env python3
"""
TED Procurement Document Embedding Service

HTTP API for generating text embeddings using sentence-transformers.
Model: intfloat/multilingual-e5-large (1024 dimensions)

Author: Martin.Schweitzer@procon.co.at and claude.ai

Usage:
    python embedding_service.py

Environment Variables:
    MODEL_NAME: Model to use (default: intfloat/multilingual-e5-large)
    MAX_LENGTH: Maximum token length (default: 512)
    HOST: Server host (default: 0.0.0.0)
    PORT: Server port (default: 8001)

API Endpoints:
    POST /embed - Generate embedding for single text
    POST /embed/batch - Generate embeddings for multiple texts
    GET /health - Health check
"""
# Standard library
import logging
import os
import threading
import time
from contextlib import asynccontextmanager
from typing import List

# Third-party
import numpy as np
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
# ---------------------------------------------------------------------------
# Logging setup
# ---------------------------------------------------------------------------
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Silence noisy HTTP / event-loop loggers coming from uvicorn and asyncio.
for _noisy in ("uvicorn.error", "asyncio"):
    logging.getLogger(_noisy).setLevel(logging.CRITICAL)
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)

# ---------------------------------------------------------------------------
# Configuration (overridable via environment variables)
# ---------------------------------------------------------------------------
MODEL_NAME = os.getenv("MODEL_NAME", "intfloat/multilingual-e5-large")
# Alternative model previously considered: "BAAI/bge-m3"
MAX_LENGTH = int(os.getenv("MAX_LENGTH", "512"))
HOST = os.getenv("HOST", "0.0.0.0")
PORT = int(os.getenv("PORT", "8001"))

# Single shared model instance; all access is serialized through model_lock.
model = None
model_lock = threading.Lock()
model_dimensions = None

# Running throughput statistics, guarded by their own lock.
embedding_count = 0
total_embedding_time = 0.0
stats_lock = threading.Lock()
class EmbedRequest(BaseModel):
    """Payload for embedding one text via POST /embed."""

    text: str = Field(..., description="Text to embed")
    is_query: bool = Field(
        False, description="If True, use 'query:' prefix for e5 models"
    )
class EmbedBatchRequest(BaseModel):
    """Payload for embedding several texts via POST /embed/batch."""

    texts: List[str] = Field(..., description="List of texts to embed")
    is_query: bool = Field(
        False, description="If True, use 'query:' prefix for e5 models"
    )
class EmbedResponse(BaseModel):
    """Result of a single-text embedding request."""

    embedding: List[float] = Field(..., description="Vector embedding")
    dimensions: int = Field(..., description="Number of dimensions")
    token_count: int = Field(..., description="Number of input tokens")
class EmbedBatchResponse(BaseModel):
    """Result of a batch embedding request."""

    embeddings: List[List[float]] = Field(..., description="List of vector embeddings")
    dimensions: int = Field(..., description="Number of dimensions")
    count: int = Field(..., description="Number of embeddings generated")
    token_counts: List[int] = Field(..., description="Number of input tokens for each text")
class HealthResponse(BaseModel):
    """Payload returned by GET /health."""

    status: str                    # "healthy" once the model is loaded
    model_name: str                # configured MODEL_NAME
    dimensions: int                # embedding vector size
    max_length: int                # configured MAX_LENGTH (tokens)
    embeddings_processed: int      # total embeddings served so far
    avg_time_ms: float             # mean wall-clock time per embedding
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load the model at startup, log final stats at shutdown."""
    global model, model_dimensions

    # Imported lazily so the module can be imported without the (heavy)
    # sentence-transformers dependency being loaded immediately.
    from sentence_transformers import SentenceTransformer

    logger.info(f"Loading single model: {MODEL_NAME}")
    try:
        model = SentenceTransformer(MODEL_NAME)
        model_dimensions = model.get_sentence_embedding_dimension()
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise
    logger.info(f"Model loaded successfully. Embedding dimension: {model_dimensions}")
    logger.info("Ready to process embeddings - statistics will be logged every 100 embeddings")

    yield

    # Shutdown: report aggregate throughput statistics.
    with stats_lock:
        if embedding_count > 0:
            avg_time_ms = total_embedding_time / embedding_count * 1000
        else:
            avg_time_ms = 0.0
        logger.info(f"Shutting down embedding service - Final statistics: {embedding_count} embeddings processed, average time: {avg_time_ms:.2f}ms per embedding")
# FastAPI application, wired to the lifespan handler above.
app = FastAPI(
    title="TED Embedding Service",
    description="Generate text embeddings using sentence-transformers for semantic search",
    version="1.0.0",
    lifespan=lifespan,
)
def add_prefix(text: str, is_query: bool) -> str:
    """Prepend the 'query: '/'passage: ' marker required by e5-family models.

    Non-e5 models get the text back unchanged.
    """
    if "e5" not in MODEL_NAME.lower():
        return text
    return ("query: " if is_query else "passage: ") + text
def check_token_length(text: str, model) -> tuple[int, bool]:
    """
    Count the tokens in *text* and flag whether it exceeds MAX_LENGTH.

    Best-effort: any tokenizer failure is logged at debug level and
    reported as (0, False) rather than raised.

    Returns:
        tuple: (token_count, is_truncated)
    """
    try:
        # The sentence-transformers model exposes its HF tokenizer directly.
        token_ids = model.tokenizer.encode(text, add_special_tokens=True)
        n_tokens = len(token_ids)

        if n_tokens <= MAX_LENGTH:
            return n_tokens, False

        logger.warning(
            f"Text exceeds MAX_LENGTH ({MAX_LENGTH} tokens). "
            f"Actual: {n_tokens} tokens, {len(text.encode('utf-8'))} bytes ({len(text)} chars). "
            f"Text will be truncated by the model. "
            f"Preview: {text[:100]}..."
        )
        return n_tokens, True

    except Exception as e:
        logger.debug(f"Could not check token length: {e}")
        return 0, False
@app.post("/embed", response_model=EmbedResponse)
async def embed_text(request: EmbedRequest) -> EmbedResponse:
    """Generate an embedding for one text, serialized on the shared model.

    Returns a 503 while the model is still loading and a 500 on any
    encoding failure.
    """
    global embedding_count, total_embedding_time

    if model is None:
        raise HTTPException(status_code=503, detail="Model not initialized")

    try:
        started = time.time()

        # The single model instance is not re-entrant: hold the lock for
        # prefixing, token accounting and encoding.
        with model_lock:
            prefixed = add_prefix(request.text, request.is_query)

            # check_token_length() already warns; this adds a processing note.
            token_count, is_truncated = check_token_length(prefixed, model)
            byte_count = len(prefixed.encode('utf-8'))
            if is_truncated:
                logger.info(f"Processing text: {token_count} tokens, {byte_count} bytes ({len(prefixed)} chars) - exceeds {MAX_LENGTH}, will be truncated")

            vector = model.encode(
                prefixed,
                normalize_embeddings=True,
                convert_to_numpy=True
            )

        # Fold this request into the running statistics.
        elapsed = time.time() - started
        with stats_lock:
            embedding_count += 1
            total_embedding_time += elapsed

            # Periodic throughput report.
            if embedding_count % 100 == 0:
                avg_time = total_embedding_time / embedding_count
                logger.info(f"Statistics: {embedding_count} embeddings processed, average time: {avg_time*1000:.2f}ms per embedding")

        return EmbedResponse(
            embedding=vector.tolist(),
            dimensions=len(vector),
            token_count=token_count,
        )

    except Exception as e:
        logger.error(f"Embedding failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/embed/batch", response_model=EmbedBatchResponse)
async def embed_batch(request: EmbedBatchRequest) -> EmbedBatchResponse:
    """Generate embeddings for a list of texts, serialized on the shared model.

    Returns 503 while the model is loading, 400 for an empty list and
    500 on any encoding failure.
    """
    global embedding_count, total_embedding_time

    if model is None:
        raise HTTPException(status_code=503, detail="Model not initialized")

    if not request.texts:
        raise HTTPException(status_code=400, detail="Empty text list")

    try:
        started = time.time()
        batch_count = len(request.texts)

        # Single non-re-entrant model: lock around prefixing, token
        # accounting and encoding.
        with model_lock:
            prefixed_texts = [add_prefix(t, request.is_query) for t in request.texts]

            # Per-item token counts; note each item that will be truncated.
            token_counts = []
            truncated_count = 0
            for idx, item in enumerate(prefixed_texts):
                n_tokens, is_truncated = check_token_length(item, model)
                token_counts.append(n_tokens)
                if not is_truncated:
                    continue
                truncated_count += 1
                byte_count = len(item.encode('utf-8'))
                logger.info(
                    f"Batch item {idx + 1}/{len(prefixed_texts)}: {n_tokens} tokens, "
                    f"{byte_count} bytes ({len(item)} chars) - exceeds {MAX_LENGTH}, will be truncated"
                )

            if truncated_count > 0:
                logger.warning(
                    f"Batch processing: {truncated_count}/{len(prefixed_texts)} texts exceed "
                    f"MAX_LENGTH ({MAX_LENGTH} tokens) and will be truncated"
                )

            vectors = model.encode(
                prefixed_texts,
                normalize_embeddings=True,
                convert_to_numpy=True,
                batch_size=16,
                show_progress_bar=False
            )

        # Fold the whole batch into the running statistics.
        elapsed = time.time() - started
        with stats_lock:
            embedding_count += batch_count
            total_embedding_time += elapsed

            # Report whenever this batch crossed a multiple-of-100 boundary.
            crossed_hundred = (embedding_count // 100) != ((embedding_count - batch_count) // 100)
            if embedding_count % 100 == 0 or crossed_hundred:
                avg_time = total_embedding_time / embedding_count
                logger.info(f"Statistics: {embedding_count} embeddings processed, average time: {avg_time*1000:.2f}ms per embedding")

        return EmbedBatchResponse(
            embeddings=[vec.tolist() for vec in vectors],
            # encode() on a list yields a 2-D array; fall back defensively.
            dimensions=vectors.shape[1] if len(vectors.shape) > 1 else len(vectors),
            count=len(vectors),
            token_counts=token_counts,
        )

    except Exception as e:
        logger.error(f"Batch embedding failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/health", response_model=HealthResponse)
async def health_check() -> HealthResponse:
    """Health check endpoint.

    Returns:
        HealthResponse with service status, model info and aggregate
        throughput statistics.

    Raises:
        HTTPException: 503 when the model has not finished loading.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not initialized")

    # Snapshot BOTH counters under the lock so the reported count and the
    # average come from a consistent view. (Previously embedding_count was
    # re-read outside the lock for `embeddings_processed`, which could race
    # with a concurrent /embed request and report mismatched values.)
    with stats_lock:
        processed = embedding_count
        avg_time_ms = (total_embedding_time / processed * 1000) if processed > 0 else 0.0

    return HealthResponse(
        status="healthy",
        model_name=MODEL_NAME,
        dimensions=model_dimensions,
        max_length=MAX_LENGTH,
        embeddings_processed=processed,
        avg_time_ms=round(avg_time_ms, 2)
    )
@app.get("/")
async def root():
    """Root endpoint with API info."""
    endpoint_index = {
        "embed": "POST /embed - Generate single embedding",
        "embed_batch": "POST /embed/batch - Generate batch embeddings",
        "health": "GET /health - Health check",
    }
    return {
        "service": "TED Embedding Service",
        "model": MODEL_NAME,
        "endpoints": endpoint_index,
    }
if __name__ == "__main__":
    # Serve via uvicorn; auto-reload is deliberately disabled.
    logger.info(f"Starting embedding service on {HOST}:{PORT}")
    uvicorn.run(
        "embedding_service:app",
        host=HOST,
        port=PORT,
        log_level="info",
        reload=False,
    )