""" Context Management API Endpoints. Provides REST endpoints for context assembly and optimization for LLM requests using the ContextEngine. """ import logging from typing import Annotated, Any from fastapi import APIRouter, Depends, HTTPException, Query, status from pydantic import BaseModel, Field from app.api.dependencies.permissions import require_superuser from app.models.user import User from app.services.context import ( AssemblyTimeoutError, BudgetExceededError, ContextEngine, ContextSettings, create_context_engine, get_context_settings, ) from app.services.mcp import MCPClientManager, get_mcp_client logger = logging.getLogger(__name__) router = APIRouter() # ============================================================================ # Singleton Engine Management # ============================================================================ _context_engine: ContextEngine | None = None def _get_or_create_engine( mcp: MCPClientManager, settings: ContextSettings | None = None, ) -> ContextEngine: """Get or create the singleton ContextEngine.""" global _context_engine if _context_engine is None: _context_engine = create_context_engine( mcp_manager=mcp, redis=None, # Optional: add Redis caching later settings=settings or get_context_settings(), ) logger.info("ContextEngine initialized") else: # Ensure MCP manager is up to date _context_engine.set_mcp_manager(mcp) return _context_engine async def get_context_engine( mcp: MCPClientManager = Depends(get_mcp_client), ) -> ContextEngine: """FastAPI dependency to get the ContextEngine.""" return _get_or_create_engine(mcp) # ============================================================================ # Request/Response Schemas # ============================================================================ class ConversationTurn(BaseModel): """A single conversation turn.""" role: str = Field(..., description="Role: 'user' or 'assistant'") content: str = Field(..., description="Message content") class ToolResult(BaseModel): """A tool execution result.""" tool_name: str = Field(..., description="Name of the tool") content: str | dict[str, Any] = Field(..., description="Tool result content") status: str = Field(default="success", description="Execution status") class AssembleContextRequest(BaseModel): """Request to assemble context for an LLM request.""" project_id: str = Field(..., description="Project identifier") agent_id: str = Field(..., description="Agent identifier") query: str = Field(..., description="User's query or current request") model: str = Field( default="claude-3-sonnet", description="Target model name", ) max_tokens: int | None = Field( None, description="Maximum context tokens (uses model default if None)", ) system_prompt: str | None = Field( None, description="System prompt/instructions", ) task_description: str | None = Field( None, description="Current task description", ) knowledge_query: str | None = Field( None, description="Query for knowledge base search", ) knowledge_limit: int = Field( default=10, ge=1, le=50, description="Max number of knowledge results", ) conversation_history: list[ConversationTurn] | None = Field( None, description="Previous conversation turns", ) tool_results: list[ToolResult] | None = Field( None, description="Tool execution results to include", ) compress: bool = Field( default=True, description="Whether to apply compression", ) use_cache: bool = Field( default=True, description="Whether to use caching", ) class AssembledContextResponse(BaseModel): """Response containing assembled context.""" content: str = Field(..., description="Assembled context content") total_tokens: int = Field(..., description="Total token count") context_count: int = Field(..., description="Number of context items included") compressed: bool = Field(..., description="Whether compression was applied") budget_used_percent: float = Field( ..., description="Percentage of token budget used", ) metadata: dict[str, Any] = Field( default_factory=dict, description="Additional metadata", ) class TokenCountRequest(BaseModel): """Request to count tokens in content.""" content: str = Field(..., description="Content to count tokens in") model: str | None = Field( None, description="Model for model-specific tokenization", ) class TokenCountResponse(BaseModel): """Response containing token count.""" token_count: int = Field(..., description="Number of tokens") model: str | None = Field(None, description="Model used for counting") class BudgetInfoResponse(BaseModel): """Response containing budget information for a model.""" model: str = Field(..., description="Model name") total_tokens: int = Field(..., description="Total token budget") system_tokens: int = Field(..., description="Tokens reserved for system") knowledge_tokens: int = Field(..., description="Tokens for knowledge") conversation_tokens: int = Field(..., description="Tokens for conversation") tool_tokens: int = Field(..., description="Tokens for tool results") response_reserve: int = Field(..., description="Tokens reserved for response") class ContextEngineStatsResponse(BaseModel): """Response containing engine statistics.""" cache: dict[str, Any] = Field(..., description="Cache statistics") settings: dict[str, Any] = Field(..., description="Current settings") class HealthResponse(BaseModel): """Health check response.""" status: str = Field(..., description="Health status") mcp_connected: bool = Field(..., description="Whether MCP is connected") cache_enabled: bool = Field(..., description="Whether caching is enabled") # ============================================================================ # Endpoints # ============================================================================ @router.get( "/health", response_model=HealthResponse, summary="Context Engine Health", description="Check health status of the context engine.", ) async def health_check( engine: ContextEngine = Depends(get_context_engine), ) -> HealthResponse: """Check context engine health.""" stats = await engine.get_stats() return HealthResponse( status="healthy", mcp_connected=engine._mcp is not None, cache_enabled=stats.get("settings", {}).get("cache_enabled", False), ) @router.post( "/assemble", response_model=AssembledContextResponse, summary="Assemble Context", description="Assemble optimized context for an LLM request.", ) async def assemble_context( request: AssembleContextRequest, current_user: User = Depends(require_superuser), engine: ContextEngine = Depends(get_context_engine), ) -> AssembledContextResponse: """ Assemble optimized context for an LLM request. This endpoint gathers context from various sources, scores and ranks them, compresses if needed, and formats for the target model. """ logger.info( "Context assembly for project=%s agent=%s by user=%s", request.project_id, request.agent_id, current_user.id, ) # Convert conversation history to dict format conversation_history = None if request.conversation_history: conversation_history = [ {"role": turn.role, "content": turn.content} for turn in request.conversation_history ] # Convert tool results to dict format tool_results = None if request.tool_results: tool_results = [ { "tool_name": tr.tool_name, "content": tr.content, "status": tr.status, } for tr in request.tool_results ] try: result = await engine.assemble_context( project_id=request.project_id, agent_id=request.agent_id, query=request.query, model=request.model, max_tokens=request.max_tokens, system_prompt=request.system_prompt, task_description=request.task_description, knowledge_query=request.knowledge_query, knowledge_limit=request.knowledge_limit, conversation_history=conversation_history, tool_results=tool_results, compress=request.compress, use_cache=request.use_cache, ) # Calculate budget usage percentage budget = await engine.get_budget_for_model(request.model, request.max_tokens) budget_used_percent = (result.total_tokens / budget.total) * 100 # Check if compression was applied (from metadata if available) was_compressed = result.metadata.get("compressed_contexts", 0) > 0 return AssembledContextResponse( content=result.content, total_tokens=result.total_tokens, context_count=result.context_count, compressed=was_compressed, budget_used_percent=round(budget_used_percent, 2), metadata={ "model": request.model, "query": request.query, "knowledge_included": bool(request.knowledge_query), "conversation_turns": len(request.conversation_history or []), "excluded_count": result.excluded_count, "assembly_time_ms": result.assembly_time_ms, }, ) except AssemblyTimeoutError as e: raise HTTPException( status_code=status.HTTP_504_GATEWAY_TIMEOUT, detail=f"Context assembly timed out: {e}", ) from e except BudgetExceededError as e: raise HTTPException( status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, detail=f"Token budget exceeded: {e}", ) from e except Exception as e: logger.exception("Context assembly failed") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Context assembly failed: {e}", ) from e @router.post( "/count-tokens", response_model=TokenCountResponse, summary="Count Tokens", description="Count tokens in content using the LLM Gateway.", ) async def count_tokens( request: TokenCountRequest, engine: ContextEngine = Depends(get_context_engine), ) -> TokenCountResponse: """Count tokens in content.""" try: count = await engine.count_tokens( content=request.content, model=request.model, ) return TokenCountResponse( token_count=count, model=request.model, ) except Exception as e: logger.warning(f"Token counting failed: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Token counting failed: {e}", ) from e @router.get( "/budget/{model}", response_model=BudgetInfoResponse, summary="Get Token Budget", description="Get token budget allocation for a specific model.", ) async def get_budget( model: str, max_tokens: Annotated[int | None, Query(description="Custom max tokens")] = None, engine: ContextEngine = Depends(get_context_engine), ) -> BudgetInfoResponse: """Get token budget information for a model.""" budget = await engine.get_budget_for_model(model, max_tokens) return BudgetInfoResponse( model=model, total_tokens=budget.total, system_tokens=budget.system, knowledge_tokens=budget.knowledge, conversation_tokens=budget.conversation, tool_tokens=budget.tools, response_reserve=budget.response_reserve, ) @router.get( "/stats", response_model=ContextEngineStatsResponse, summary="Engine Statistics", description="Get context engine statistics and configuration.", ) async def get_stats( current_user: User = Depends(require_superuser), engine: ContextEngine = Depends(get_context_engine), ) -> ContextEngineStatsResponse: """Get engine statistics.""" stats = await engine.get_stats() return ContextEngineStatsResponse( cache=stats.get("cache", {}), settings=stats.get("settings", {}), ) @router.post( "/cache/invalidate", status_code=status.HTTP_204_NO_CONTENT, summary="Invalidate Cache (Admin Only)", description="Invalidate context cache entries.", ) async def invalidate_cache( project_id: Annotated[ str | None, Query(description="Project to invalidate") ] = None, pattern: Annotated[str | None, Query(description="Pattern to match")] = None, current_user: User = Depends(require_superuser), engine: ContextEngine = Depends(get_context_engine), ) -> None: """Invalidate cache entries.""" logger.info( "Cache invalidation by user %s: project=%s pattern=%s", current_user.id, project_id, pattern, ) await engine.invalidate_cache(project_id=project_id, pattern=pattern)