#!/usr/bin/env python3 """ TED Procurement Document Embedding Service HTTP API for generating text embeddings using sentence-transformers. Model: intfloat/multilingual-e5-large (1024 dimensions) Author: Martin.Schweitzer@procon.co.at and claude.ai Usage: python embedding_service.py Environment Variables: MODEL_NAME: Model to use (default: intfloat/multilingual-e5-large) MAX_LENGTH: Maximum token length (default: 512) HOST: Server host (default: 0.0.0.0) PORT: Server port (default: 8001) API Endpoints: POST /embed - Generate embedding for single text POST /embed/batch - Generate embeddings for multiple texts GET /health - Health check """ import os import logging import threading import time from typing import List from contextlib import asynccontextmanager import numpy as np from fastapi import FastAPI, HTTPException from pydantic import BaseModel, Field import uvicorn # Configure logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) logger = logging.getLogger(__name__) # Suppress noisy HTTP warnings from uvicorn and asyncio logging.getLogger("uvicorn.error").setLevel(logging.CRITICAL) logging.getLogger("uvicorn.access").setLevel(logging.WARNING) logging.getLogger("asyncio").setLevel(logging.CRITICAL) # Configuration from environment MODEL_NAME = os.getenv("MODEL_NAME", "intfloat/multilingual-e5-large") #MODEL_NAME = os.getenv("MODEL_NAME", "BAAI/bge-m3") MAX_LENGTH = int(os.getenv("MAX_LENGTH", "512")) HOST = os.getenv("HOST", "0.0.0.0") PORT = int(os.getenv("PORT", "8001")) # Global model instance (single model) with thread-safe access model = None model_lock = threading.Lock() model_dimensions = None # Statistics embedding_count = 0 total_embedding_time = 0.0 stats_lock = threading.Lock() class EmbedRequest(BaseModel): """Request model for single text embedding.""" text: str = Field(..., description="Text to embed") is_query: bool = Field(False, description="If True, use 'query:' prefix for e5 models") 
class EmbedBatchRequest(BaseModel):
    """Request model for batch text embedding."""
    texts: List[str] = Field(..., description="List of texts to embed")
    is_query: bool = Field(False, description="If True, use 'query:' prefix for e5 models")


class EmbedResponse(BaseModel):
    """Response model for embedding result."""
    embedding: List[float] = Field(..., description="Vector embedding")
    dimensions: int = Field(..., description="Number of dimensions")
    token_count: int = Field(..., description="Number of input tokens")


class EmbedBatchResponse(BaseModel):
    """Response model for batch embedding result."""
    embeddings: List[List[float]] = Field(..., description="List of vector embeddings")
    dimensions: int = Field(..., description="Number of dimensions")
    count: int = Field(..., description="Number of embeddings generated")
    token_counts: List[int] = Field(..., description="Number of input tokens for each text")


class HealthResponse(BaseModel):
    """Health check response."""
    status: str
    model_name: str
    dimensions: int
    max_length: int
    embeddings_processed: int
    avg_time_ms: float


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize the single model on startup; log final statistics on shutdown."""
    global model, model_dimensions
    # Imported lazily so importing this module does not pull in the heavy
    # sentence-transformers dependency until the server actually starts.
    from sentence_transformers import SentenceTransformer

    logger.info(f"Loading single model: {MODEL_NAME}")
    try:
        model = SentenceTransformer(MODEL_NAME)
        model_dimensions = model.get_sentence_embedding_dimension()
        logger.info(f"Model loaded successfully. Embedding dimension: {model_dimensions}")
        logger.info("Ready to process embeddings - statistics will be logged every 100 embeddings")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise

    yield

    # Cleanup: emit final throughput statistics.
    with stats_lock:
        avg_time_ms = (total_embedding_time / embedding_count * 1000) if embedding_count > 0 else 0.0
        logger.info(f"Shutting down embedding service - Final statistics: {embedding_count} embeddings processed, average time: {avg_time_ms:.2f}ms per embedding")


# Create FastAPI app
app = FastAPI(
    title="TED Embedding Service",
    description="Generate text embeddings using sentence-transformers for semantic search",
    version="1.0.0",
    lifespan=lifespan
)


def add_prefix(text: str, is_query: bool) -> str:
    """Add the 'query: '/'passage: ' prefix required by e5-family models.

    Non-e5 models receive the text unchanged.
    """
    if "e5" in MODEL_NAME.lower():
        prefix = "query: " if is_query else "passage: "
        return prefix + text
    return text


def check_token_length(text: str, model) -> tuple[int, bool]:
    """
    Check if text exceeds MAX_LENGTH tokens and return token count.

    Args:
        text: The (already prefixed) input text.
        model: The loaded SentenceTransformer; its tokenizer is used.

    Returns:
        tuple: (token_count, is_truncated). token_count is 0 when the
        tokenizer is unavailable or fails - the check is advisory only.
    """
    try:
        # Get tokenizer from model
        tokenizer = model.tokenizer
        tokens = tokenizer.encode(text, add_special_tokens=True)
        token_count = len(tokens)
        byte_count = len(text.encode('utf-8'))

        if token_count > MAX_LENGTH:
            logger.warning(
                f"Text exceeds MAX_LENGTH ({MAX_LENGTH} tokens). "
                f"Actual: {token_count} tokens, {byte_count} bytes ({len(text)} chars). "
                f"Text will be truncated by the model. "
                f"Preview: {text[:100]}..."
            )
            return token_count, True
        return token_count, False
    except Exception as e:
        # Never fail the request because the advisory length check failed.
        logger.debug(f"Could not check token length: {e}")
        return 0, False


@app.post("/embed", response_model=EmbedResponse)
def embed_text(request: EmbedRequest) -> EmbedResponse:
    """Generate embedding for a single text using thread-safe single model.

    NOTE: deliberately a sync endpoint (``def``, not ``async def``) so FastAPI
    runs it in its threadpool - ``model.encode()`` is CPU-bound and, held under
    ``model_lock``, would otherwise block the event loop for every concurrent
    request (including /health).
    """
    global embedding_count, total_embedding_time

    if model is None:
        raise HTTPException(status_code=503, detail="Model not initialized")

    try:
        start_time = time.time()

        # Thread-safe access to single model
        with model_lock:
            # Add prefix for e5 models
            text = add_prefix(request.text, request.is_query)

            # Check token length and warn if exceeding MAX_LENGTH
            token_count, is_truncated = check_token_length(text, model)
            byte_count = len(text.encode('utf-8'))
            if is_truncated:
                logger.info(f"Processing text: {token_count} tokens, {byte_count} bytes ({len(text)} chars) - exceeds {MAX_LENGTH}, will be truncated")

            # Generate embedding
            embedding = model.encode(
                text,
                normalize_embeddings=True,
                convert_to_numpy=True
            )

        # Update statistics
        elapsed_time = time.time() - start_time
        with stats_lock:
            embedding_count += 1
            total_embedding_time += elapsed_time
            # Log statistics every 100 embeddings
            if embedding_count % 100 == 0:
                avg_time = total_embedding_time / embedding_count
                logger.info(f"Statistics: {embedding_count} embeddings processed, average time: {avg_time*1000:.2f}ms per embedding")

        return EmbedResponse(
            embedding=embedding.tolist(),
            dimensions=len(embedding),
            token_count=token_count
        )
    except Exception as e:
        logger.error(f"Embedding failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/embed/batch", response_model=EmbedBatchResponse)
def embed_batch(request: EmbedBatchRequest) -> EmbedBatchResponse:
    """Generate embeddings for multiple texts using thread-safe single model.

    Sync endpoint for the same reason as embed_text: encoding is CPU-bound
    and must not run on the event loop.
    """
    global embedding_count, total_embedding_time

    if model is None:
        raise HTTPException(status_code=503, detail="Model not initialized")
    if not request.texts:
        raise HTTPException(status_code=400, detail="Empty text list")

    try:
        start_time = time.time()
        batch_count = len(request.texts)

        # Thread-safe access to single model
        with model_lock:
            # Add prefixes
            texts = [add_prefix(text, request.is_query) for text in request.texts]

            # Check token length for each text and warn if exceeding MAX_LENGTH
            truncated_count = 0
            token_counts = []
            for i, text in enumerate(texts):
                token_count, is_truncated = check_token_length(text, model)
                token_counts.append(token_count)
                if is_truncated:
                    truncated_count += 1
                    byte_count = len(text.encode('utf-8'))
                    logger.info(
                        f"Batch item {i + 1}/{len(texts)}: {token_count} tokens, "
                        f"{byte_count} bytes ({len(text)} chars) - exceeds {MAX_LENGTH}, will be truncated"
                    )
            if truncated_count > 0:
                logger.warning(
                    f"Batch processing: {truncated_count}/{len(texts)} texts exceed "
                    f"MAX_LENGTH ({MAX_LENGTH} tokens) and will be truncated"
                )

            # Generate embeddings
            embeddings = model.encode(
                texts,
                normalize_embeddings=True,
                convert_to_numpy=True,
                batch_size=16,
                show_progress_bar=False
            )

        # Update statistics
        elapsed_time = time.time() - start_time
        with stats_lock:
            embedding_count += batch_count
            total_embedding_time += elapsed_time
            # Log whenever this batch pushed the total across a multiple of 100.
            # (The old additional "% 100 == 0" test was redundant: with
            # batch_count >= 1 it is implied by the boundary-crossing check.)
            if (embedding_count // 100) != ((embedding_count - batch_count) // 100):
                avg_time = total_embedding_time / embedding_count
                logger.info(f"Statistics: {embedding_count} embeddings processed, average time: {avg_time*1000:.2f}ms per embedding")

        return EmbedBatchResponse(
            embeddings=[emb.tolist() for emb in embeddings],
            # encode() on a list returns a 2-D array; the fallback guards the
            # degenerate 1-D case defensively.
            dimensions=embeddings.shape[1] if len(embeddings.shape) > 1 else len(embeddings),
            count=len(embeddings),
            token_counts=token_counts
        )
    except Exception as e:
        logger.error(f"Batch embedding failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/health", response_model=HealthResponse)
async def health_check() -> HealthResponse:
    """Health check endpoint."""
    if model is None:
        raise HTTPException(status_code=503, detail="Model not initialized")

    with stats_lock:
        avg_time_ms = (total_embedding_time / embedding_count * 1000) if embedding_count > 0 else 0.0

    return HealthResponse(
        status="healthy",
        model_name=MODEL_NAME,
        dimensions=model_dimensions,
        max_length=MAX_LENGTH,
        embeddings_processed=embedding_count,
        avg_time_ms=round(avg_time_ms, 2)
    )


@app.get("/")
async def root():
    """Root endpoint with API info."""
    return {
        "service": "TED Embedding Service",
        "model": MODEL_NAME,
        "endpoints": {
            "embed": "POST /embed - Generate single embedding",
            "embed_batch": "POST /embed/batch - Generate batch embeddings",
            "health": "GET /health - Health check"
        }
    }


if __name__ == "__main__":
    logger.info(f"Starting embedding service on {HOST}:{PORT}")
    uvicorn.run(
        "embedding_service:app",
        host=HOST,
        port=PORT,
        log_level="info",
        reload=False
    )