From 49359b141602297cdd81b71a7f6419f943d17779 Mon Sep 17 00:00:00 2001
From: Felipe Cardoso
Date: Mon, 5 Jan 2026 01:02:33 +0100
Subject: [PATCH] feat(api): add Context Management API and routes

- Introduced a new `context` module and its endpoints for Context Management.
- Added `/context` route to the API router for assembling LLM context,
  token counting, budget management, and cache invalidation.
- Implemented health checks, context assembly, token counting, and caching
  operations in the Context Management Engine.
- Included schemas for request/response models and tightened error handling
  for context-related operations.
---
 backend/app/api/main.py           |   4 +
 backend/app/api/routes/context.py | 426 ++++++++++++++++++++++++++++++
 2 files changed, 430 insertions(+)
 create mode 100644 backend/app/api/routes/context.py

diff --git a/backend/app/api/main.py b/backend/app/api/main.py
index 16e2594..6e2a2f7 100644
--- a/backend/app/api/main.py
+++ b/backend/app/api/main.py
@@ -5,6 +5,7 @@ from app.api.routes import (
     agent_types,
     agents,
     auth,
+    context,
     events,
     issues,
     mcp,
@@ -35,6 +36,9 @@ api_router.include_router(events.router, tags=["Events"])
 # MCP (Model Context Protocol) router
 api_router.include_router(mcp.router, prefix="/mcp", tags=["MCP"])
 
+# Context Management Engine router
+api_router.include_router(context.router, prefix="/context", tags=["Context"])
+
 # Syndarix domain routers
 api_router.include_router(projects.router, prefix="/projects", tags=["Projects"])
 api_router.include_router(
diff --git a/backend/app/api/routes/context.py b/backend/app/api/routes/context.py
new file mode 100644
index 0000000..88acfba
--- /dev/null
+++ b/backend/app/api/routes/context.py
@@ -0,0 +1,426 @@
+"""
+Context Management API Endpoints.
+
+Provides REST endpoints for context assembly and optimization
+for LLM requests using the ContextEngine.
+"""
+
+import logging
+from typing import Annotated, Any
+
+from fastapi import APIRouter, Depends, HTTPException, Query, status
+from pydantic import BaseModel, Field
+
+from app.api.dependencies.permissions import require_superuser
+from app.models.user import User
+from app.services.context import (
+    AssemblyTimeoutError,
+    BudgetExceededError,
+    ContextEngine,
+    ContextSettings,
+    create_context_engine,
+    get_context_settings,
+)
+from app.services.mcp import MCPClientManager, get_mcp_client
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter()
+
+
+# ============================================================================
+# Singleton Engine Management
+# ============================================================================
+
+_context_engine: ContextEngine | None = None
+
+
+def _get_or_create_engine(
+    mcp: MCPClientManager,
+    settings: ContextSettings | None = None,
+) -> ContextEngine:
+    """Get or create the singleton ContextEngine.
+
+    The engine is created on first use (without Redis caching); on later
+    calls the existing engine is re-pointed at the supplied MCP manager.
+    ``settings`` is only consulted when the engine is first created.
+    """
+    global _context_engine
+    if _context_engine is None:
+        _context_engine = create_context_engine(
+            mcp_manager=mcp,
+            redis=None,  # Optional: add Redis caching later
+            settings=settings or get_context_settings(),
+        )
+        logger.info("ContextEngine initialized")
+    else:
+        # Ensure MCP manager is up to date
+        _context_engine.set_mcp_manager(mcp)
+    return _context_engine
+
+
+async def get_context_engine(
+    mcp: MCPClientManager = Depends(get_mcp_client),
+) -> ContextEngine:
+    """FastAPI dependency to get the ContextEngine."""
+    return _get_or_create_engine(mcp)
+
+
+# ============================================================================
+# Request/Response Schemas
+# ============================================================================
+
+
+class ConversationTurn(BaseModel):
+    """A single conversation turn."""
+
+    role: str = Field(..., description="Role: 'user' or 'assistant'")
+    content: str = Field(..., description="Message content")
+
+
+class ToolResult(BaseModel):
+    """A tool execution result."""
+
+    tool_name: str = Field(..., description="Name of the tool")
+    content: str | dict[str, Any] = Field(..., description="Tool result content")
+    status: str = Field(default="success", description="Execution status")
+
+
+class AssembleContextRequest(BaseModel):
+    """Request to assemble context for an LLM request."""
+
+    project_id: str = Field(..., description="Project identifier")
+    agent_id: str = Field(..., description="Agent identifier")
+    query: str = Field(..., description="User's query or current request")
+    model: str = Field(
+        default="claude-3-sonnet",
+        description="Target model name",
+    )
+    max_tokens: int | None = Field(
+        None,
+        description="Maximum context tokens (uses model default if None)",
+    )
+    system_prompt: str | None = Field(
+        None,
+        description="System prompt/instructions",
+    )
+    task_description: str | None = Field(
+        None,
+        description="Current task description",
+    )
+    knowledge_query: str | None = Field(
+        None,
+        description="Query for knowledge base search",
+    )
+    knowledge_limit: int = Field(
+        default=10,
+        ge=1,
+        le=50,
+        description="Max number of knowledge results",
+    )
+    conversation_history: list[ConversationTurn] | None = Field(
+        None,
+        description="Previous conversation turns",
+    )
+    tool_results: list[ToolResult] | None = Field(
+        None,
+        description="Tool execution results to include",
+    )
+    compress: bool = Field(
+        default=True,
+        description="Whether to apply compression",
+    )
+    use_cache: bool = Field(
+        default=True,
+        description="Whether to use caching",
+    )
+
+
+class AssembledContextResponse(BaseModel):
+    """Response containing assembled context."""
+
+    content: str = Field(..., description="Assembled context content")
+    total_tokens: int = Field(..., description="Total token count")
+    context_count: int = Field(..., description="Number of context items included")
+    compressed: bool = Field(..., description="Whether compression was applied")
+    budget_used_percent: float = Field(
+        ...,
+        description="Percentage of token budget used",
+    )
+    metadata: dict[str, Any] = Field(
+        default_factory=dict,
+        description="Additional metadata",
+    )
+
+
+class TokenCountRequest(BaseModel):
+    """Request to count tokens in content."""
+
+    content: str = Field(..., description="Content to count tokens in")
+    model: str | None = Field(
+        None,
+        description="Model for model-specific tokenization",
+    )
+
+
+class TokenCountResponse(BaseModel):
+    """Response containing token count."""
+
+    token_count: int = Field(..., description="Number of tokens")
+    model: str | None = Field(None, description="Model used for counting")
+
+
+class BudgetInfoResponse(BaseModel):
+    """Response containing budget information for a model."""
+
+    model: str = Field(..., description="Model name")
+    total_tokens: int = Field(..., description="Total token budget")
+    system_tokens: int = Field(..., description="Tokens reserved for system")
+    knowledge_tokens: int = Field(..., description="Tokens for knowledge")
+    conversation_tokens: int = Field(..., description="Tokens for conversation")
+    tool_tokens: int = Field(..., description="Tokens for tool results")
+    response_reserve: int = Field(..., description="Tokens reserved for response")
+
+
+class ContextEngineStatsResponse(BaseModel):
+    """Response containing engine statistics."""
+
+    cache: dict[str, Any] = Field(..., description="Cache statistics")
+    settings: dict[str, Any] = Field(..., description="Current settings")
+
+
+class HealthResponse(BaseModel):
+    """Health check response."""
+
+    status: str = Field(..., description="Health status")
+    mcp_connected: bool = Field(..., description="Whether MCP is connected")
+    cache_enabled: bool = Field(..., description="Whether caching is enabled")
+
+
+# ============================================================================
+# Endpoints
+# ============================================================================
+
+
+@router.get(
+    "/health",
+    response_model=HealthResponse,
+    summary="Context Engine Health",
+    description="Check health status of the context engine.",
+)
+async def health_check(
+    engine: ContextEngine = Depends(get_context_engine),
+) -> HealthResponse:
+    """Check context engine health."""
+    stats = await engine.get_stats()
+    # NOTE(review): reads the engine's private ``_mcp`` attribute; switch to a
+    # public accessor if ContextEngine exposes one.
+    return HealthResponse(
+        status="healthy",
+        mcp_connected=engine._mcp is not None,
+        cache_enabled=stats.get("settings", {}).get("cache_enabled", False),
+    )
+
+
+@router.post(
+    "/assemble",
+    response_model=AssembledContextResponse,
+    summary="Assemble Context",
+    description="Assemble optimized context for an LLM request.",
+)
+async def assemble_context(
+    request: AssembleContextRequest,
+    current_user: User = Depends(require_superuser),
+    engine: ContextEngine = Depends(get_context_engine),
+) -> AssembledContextResponse:
+    """
+    Assemble optimized context for an LLM request.
+
+    This endpoint gathers context from various sources, scores and ranks them,
+    compresses if needed, and formats for the target model.
+
+    Raises:
+        HTTPException: 504 on assembly timeout, 413 when the token budget is
+            exceeded, 500 on any other failure.
+    """
+    logger.info(
+        "Context assembly for project=%s agent=%s by user=%s",
+        request.project_id,
+        request.agent_id,
+        current_user.id,
+    )
+
+    # Convert conversation history to dict format
+    conversation_history = None
+    if request.conversation_history:
+        conversation_history = [
+            {"role": turn.role, "content": turn.content}
+            for turn in request.conversation_history
+        ]
+
+    # Convert tool results to dict format
+    tool_results = None
+    if request.tool_results:
+        tool_results = [
+            {
+                "tool_name": tr.tool_name,
+                "content": tr.content,
+                "status": tr.status,
+            }
+            for tr in request.tool_results
+        ]
+
+    try:
+        result = await engine.assemble_context(
+            project_id=request.project_id,
+            agent_id=request.agent_id,
+            query=request.query,
+            model=request.model,
+            max_tokens=request.max_tokens,
+            system_prompt=request.system_prompt,
+            task_description=request.task_description,
+            knowledge_query=request.knowledge_query,
+            knowledge_limit=request.knowledge_limit,
+            conversation_history=conversation_history,
+            tool_results=tool_results,
+            compress=request.compress,
+            use_cache=request.use_cache,
+        )
+
+        # Calculate budget usage percentage
+        budget = await engine.get_budget_for_model(request.model, request.max_tokens)
+        # Guard against a zero budget (``max_tokens`` has no lower bound) so a
+        # successful assembly is not reported as a 500 via ZeroDivisionError.
+        budget_used_percent = (
+            (result.total_tokens / budget.total) * 100 if budget.total else 0.0
+        )
+
+        # Check if compression was applied (from metadata if available)
+        was_compressed = result.metadata.get("compressed_contexts", 0) > 0
+
+        return AssembledContextResponse(
+            content=result.content,
+            total_tokens=result.total_tokens,
+            context_count=result.context_count,
+            compressed=was_compressed,
+            budget_used_percent=round(budget_used_percent, 2),
+            metadata={
+                "model": request.model,
+                "query": request.query,
+                "knowledge_included": bool(request.knowledge_query),
+                "conversation_turns": len(request.conversation_history or []),
+                "excluded_count": result.excluded_count,
+                "assembly_time_ms": result.assembly_time_ms,
+            },
+        )
+
+    except AssemblyTimeoutError as e:
+        raise HTTPException(
+            status_code=status.HTTP_504_GATEWAY_TIMEOUT,
+            detail=f"Context assembly timed out: {e}",
+        ) from e
+    except BudgetExceededError as e:
+        raise HTTPException(
+            status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
+            detail=f"Token budget exceeded: {e}",
+        ) from e
+    except Exception as e:
+        logger.exception("Context assembly failed")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"Context assembly failed: {e}",
+        ) from e
+
+
+@router.post(
+    "/count-tokens",
+    response_model=TokenCountResponse,
+    summary="Count Tokens",
+    description="Count tokens in content using the LLM Gateway.",
+)
+async def count_tokens(
+    request: TokenCountRequest,
+    engine: ContextEngine = Depends(get_context_engine),
+) -> TokenCountResponse:
+    """Count tokens in content."""
+    try:
+        count = await engine.count_tokens(
+            content=request.content,
+            model=request.model,
+        )
+        return TokenCountResponse(
+            token_count=count,
+            model=request.model,
+        )
+    except Exception as e:
+        logger.warning("Token counting failed: %s", e)
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"Token counting failed: {e}",
+        ) from e
+
+
+@router.get(
+    "/budget/{model}",
+    response_model=BudgetInfoResponse,
+    summary="Get Token Budget",
+    description="Get token budget allocation for a specific model.",
+)
+async def get_budget(
+    model: str,
+    max_tokens: Annotated[int | None, Query(description="Custom max tokens")] = None,
+    engine: ContextEngine = Depends(get_context_engine),
+) -> BudgetInfoResponse:
+    """Get token budget information for a model."""
+    budget = await engine.get_budget_for_model(model, max_tokens)
+    return BudgetInfoResponse(
+        model=model,
+        total_tokens=budget.total,
+        system_tokens=budget.system,
+        knowledge_tokens=budget.knowledge,
+        conversation_tokens=budget.conversation,
+        tool_tokens=budget.tools,
+        response_reserve=budget.response_reserve,
+    )
+
+
+@router.get(
+    "/stats",
+    response_model=ContextEngineStatsResponse,
+    summary="Engine Statistics",
+    description="Get context engine statistics and configuration.",
+)
+async def get_stats(
+    current_user: User = Depends(require_superuser),
+    engine: ContextEngine = Depends(get_context_engine),
+) -> ContextEngineStatsResponse:
+    """Get engine statistics."""
+    stats = await engine.get_stats()
+    return ContextEngineStatsResponse(
+        cache=stats.get("cache", {}),
+        settings=stats.get("settings", {}),
+    )
+
+
+@router.post(
+    "/cache/invalidate",
+    status_code=status.HTTP_204_NO_CONTENT,
+    summary="Invalidate Cache (Admin Only)",
+    description="Invalidate context cache entries.",
+)
+async def invalidate_cache(
+    project_id: Annotated[
+        str | None, Query(description="Project to invalidate")
+    ] = None,
+    pattern: Annotated[str | None, Query(description="Pattern to match")] = None,
+    current_user: User = Depends(require_superuser),
+    engine: ContextEngine = Depends(get_context_engine),
+) -> None:
+    """Invalidate cache entries."""
+    logger.info(
+        "Cache invalidation by user %s: project=%s pattern=%s",
+        current_user.id,
+        project_id,
+        pattern,
+    )
+    await engine.invalidate_cache(project_id=project_id, pattern=pattern)