- Introduced a new `context` module and its endpoints for Context Management. - Added `/context` route to the API router for assembling LLM context, token counting, budget management, and cache invalidation. - Implemented health checks, context assembly, token counting, and caching operations in the Context Management Engine. - Included schemas for request/response models and tightened error handling for context-related operations.
412 lines
13 KiB
Python
412 lines
13 KiB
Python
"""
|
|
Context Management API Endpoints.
|
|
|
|
Provides REST endpoints for context assembly and optimization
|
|
for LLM requests using the ContextEngine.
|
|
"""
|
|
|
|
import logging
|
|
from typing import Annotated, Any
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
|
from pydantic import BaseModel, Field
|
|
|
|
from app.api.dependencies.permissions import require_superuser
|
|
from app.models.user import User
|
|
from app.services.context import (
|
|
AssemblyTimeoutError,
|
|
BudgetExceededError,
|
|
ContextEngine,
|
|
ContextSettings,
|
|
create_context_engine,
|
|
get_context_settings,
|
|
)
|
|
from app.services.mcp import MCPClientManager, get_mcp_client
|
|
|
|
# Module-level logger for the context endpoints.
logger = logging.getLogger(__name__)

# Router holding the context endpoints below; presumably mounted under
# "/context" by the API router (per the change notes) — confirm at the
# include site.
router = APIRouter()


# ============================================================================
# Singleton Engine Management
# ============================================================================

# Process-wide ContextEngine instance. Starts as None and is populated lazily
# by _get_or_create_engine(), which then reuses it on every later call.
_context_engine: ContextEngine | None = None
|
|
|
|
|
|
def _get_or_create_engine(
    mcp: MCPClientManager,
    settings: ContextSettings | None = None,
) -> ContextEngine:
    """Return the module-wide ContextEngine, building it on first use.

    Args:
        mcp: MCP client manager. Used to build the engine on the first call;
            on every later call it is pushed into the existing engine via
            ``set_mcp_manager`` so the engine tracks the current manager.
        settings: Settings used only when the engine is first created; when
            falsy, ``get_context_settings()`` is used instead. Ignored once the
            engine already exists.

    Returns:
        The shared engine stored in the module-level ``_context_engine``.
    """
    global _context_engine

    existing = _context_engine
    if existing is not None:
        # Engine already built: just refresh its MCP manager reference.
        existing.set_mcp_manager(mcp)
        return existing

    # First call: build the engine. No Redis backend is wired in yet
    # (optional caching layer to add later).
    _context_engine = create_context_engine(
        mcp_manager=mcp,
        redis=None,
        settings=settings or get_context_settings(),
    )
    logger.info("ContextEngine initialized")
    return _context_engine
|
|
|
|
|
|
async def get_context_engine(
    mcp: MCPClientManager = Depends(get_mcp_client),
) -> ContextEngine:
    """FastAPI dependency returning the shared ContextEngine.

    The MCP client is resolved through the ``get_mcp_client`` dependency and
    handed to ``_get_or_create_engine``, which creates the engine on first use
    and otherwise refreshes its MCP manager.
    """
    shared_engine = _get_or_create_engine(mcp)
    return shared_engine
|
|
|
|
|
|
# ============================================================================
|
|
# Request/Response Schemas
|
|
# ============================================================================
|
|
|
|
|
|
class ConversationTurn(BaseModel):
    """One message from earlier in the conversation."""

    # Plain string: the two roles named in the description are not enforced
    # by validation here.
    role: Annotated[str, Field(description="Role: 'user' or 'assistant'")]
    content: Annotated[str, Field(description="Message content")]
|
|
|
|
|
|
class ToolResult(BaseModel):
    """Result of a single tool execution supplied by the caller."""

    tool_name: Annotated[str, Field(description="Name of the tool")]
    # Either raw text or a structured mapping payload.
    content: Annotated[
        str | dict[str, Any],
        Field(description="Tool result content"),
    ]
    status: Annotated[str, Field(description="Execution status")] = "success"
|
|
|
|
|
|
class AssembleContextRequest(BaseModel):
    """Request body for assembling context for an LLM request.

    Only ``project_id``, ``agent_id`` and ``query`` are required; every other
    field has a default.
    """

    # --- required identifiers / query -------------------------------------
    project_id: Annotated[str, Field(description="Project identifier")]
    agent_id: Annotated[str, Field(description="Agent identifier")]
    query: Annotated[str, Field(description="User's query or current request")]

    # --- target model and token budget ------------------------------------
    model: Annotated[
        str, Field(description="Target model name")
    ] = "claude-3-sonnet"
    max_tokens: Annotated[
        int | None,
        Field(description="Maximum context tokens (uses model default if None)"),
    ] = None

    # --- optional context sources -----------------------------------------
    system_prompt: Annotated[
        str | None, Field(description="System prompt/instructions")
    ] = None
    task_description: Annotated[
        str | None, Field(description="Current task description")
    ] = None
    knowledge_query: Annotated[
        str | None, Field(description="Query for knowledge base search")
    ] = None
    # Bounded to 1..50 inclusive.
    knowledge_limit: Annotated[
        int,
        Field(ge=1, le=50, description="Max number of knowledge results"),
    ] = 10
    conversation_history: Annotated[
        list[ConversationTurn] | None,
        Field(description="Previous conversation turns"),
    ] = None
    tool_results: Annotated[
        list[ToolResult] | None,
        Field(description="Tool execution results to include"),
    ] = None

    # --- processing flags -------------------------------------------------
    compress: Annotated[
        bool, Field(description="Whether to apply compression")
    ] = True
    use_cache: Annotated[
        bool, Field(description="Whether to use caching")
    ] = True
|
|
|
|
|
|
class AssembledContextResponse(BaseModel):
    """Assembled context plus token accounting for the caller."""

    content: Annotated[str, Field(description="Assembled context content")]
    total_tokens: Annotated[int, Field(description="Total token count")]
    context_count: Annotated[
        int, Field(description="Number of context items included")
    ]
    compressed: Annotated[
        bool, Field(description="Whether compression was applied")
    ]
    budget_used_percent: Annotated[
        float, Field(description="Percentage of token budget used")
    ]
    # Free-form extras; default_factory gives each instance its own dict.
    metadata: dict[str, Any] = Field(
        default_factory=dict,
        description="Additional metadata",
    )
|
|
|
|
|
|
class TokenCountRequest(BaseModel):
    """Body for a token-counting request."""

    content: Annotated[str, Field(description="Content to count tokens in")]
    # None when the caller does not ask for model-specific tokenization.
    model: Annotated[
        str | None, Field(description="Model for model-specific tokenization")
    ] = None
|
|
|
|
|
|
class TokenCountResponse(BaseModel):
    """Number of tokens counted, with the model name (if any) used."""

    token_count: Annotated[int, Field(description="Number of tokens")]
    model: Annotated[
        str | None, Field(description="Model used for counting")
    ] = None
|
|
|
|
|
|
class BudgetInfoResponse(BaseModel):
    """How a model's token budget is divided between context sections."""

    model: Annotated[str, Field(description="Model name")]
    total_tokens: Annotated[int, Field(description="Total token budget")]
    system_tokens: Annotated[int, Field(description="Tokens reserved for system")]
    knowledge_tokens: Annotated[int, Field(description="Tokens for knowledge")]
    conversation_tokens: Annotated[
        int, Field(description="Tokens for conversation")
    ]
    tool_tokens: Annotated[int, Field(description="Tokens for tool results")]
    response_reserve: Annotated[
        int, Field(description="Tokens reserved for response")
    ]
|
|
|
|
|
|
class ContextEngineStatsResponse(BaseModel):
    """Snapshot of the engine's cache statistics and active settings."""

    # Both are passed through as opaque mappings; their keys are defined by
    # the engine, not by this schema.
    cache: Annotated[dict[str, Any], Field(description="Cache statistics")]
    settings: Annotated[dict[str, Any], Field(description="Current settings")]
|
|
|
|
|
|
class HealthResponse(BaseModel):
    """Result of the context engine health probe."""

    status: Annotated[str, Field(description="Health status")]
    # NOTE(review): health_check fills this from whether the engine holds an
    # MCP manager reference, not from a live connectivity check.
    mcp_connected: Annotated[bool, Field(description="Whether MCP is connected")]
    cache_enabled: Annotated[
        bool, Field(description="Whether caching is enabled")
    ]
|
|
|
|
|
|
# ============================================================================
|
|
# Endpoints
|
|
# ============================================================================
|
|
|
|
|
|
@router.get(
    "/health",
    response_model=HealthResponse,
    summary="Context Engine Health",
    description="Check health status of the context engine.",
)
async def health_check(
    engine: ContextEngine = Depends(get_context_engine),
) -> HealthResponse:
    """Report context engine health.

    ``status`` is always "healthy" when this returns; any error raised by
    ``engine.get_stats()`` propagates instead. ``mcp_connected`` only reflects
    whether the engine holds a non-None ``_mcp`` reference, and
    ``cache_enabled`` is read from the ``settings`` section of the stats
    (defaulting to False when absent).
    """
    engine_settings = (await engine.get_stats()).get("settings", {})
    has_mcp_manager = engine._mcp is not None
    return HealthResponse(
        status="healthy",
        mcp_connected=has_mcp_manager,
        cache_enabled=engine_settings.get("cache_enabled", False),
    )
|
|
|
|
|
|
@router.post(
    "/assemble",
    response_model=AssembledContextResponse,
    summary="Assemble Context",
    description="Assemble optimized context for an LLM request.",
)
async def assemble_context(
    request: AssembleContextRequest,
    current_user: User = Depends(require_superuser),
    engine: ContextEngine = Depends(get_context_engine),
) -> AssembledContextResponse:
    """
    Assemble optimized context for an LLM request.

    This endpoint gathers context from various sources, scores and ranks them,
    compresses if needed, and formats for the target model. Access is gated by
    the ``require_superuser`` dependency.

    Conversation turns and tool results are converted to plain dicts before
    being handed to ``engine.assemble_context``. The response also reports
    the share of the model's token budget (from
    ``engine.get_budget_for_model``) that the assembled context uses.

    Raises:
        HTTPException: 504 on ``AssemblyTimeoutError``, 413 on
            ``BudgetExceededError``, and 500 for any other failure while
            assembling or building the response.
    """
    logger.info(
        "Context assembly for project=%s agent=%s by user=%s",
        request.project_id,
        request.agent_id,
        current_user.id,
    )

    # The engine is called with plain dicts rather than the request models.
    conversation_history = None
    if request.conversation_history:
        conversation_history = [
            {"role": turn.role, "content": turn.content}
            for turn in request.conversation_history
        ]

    tool_results = None
    if request.tool_results:
        tool_results = [
            {
                "tool_name": tr.tool_name,
                "content": tr.content,
                "status": tr.status,
            }
            for tr in request.tool_results
        ]

    try:
        result = await engine.assemble_context(
            project_id=request.project_id,
            agent_id=request.agent_id,
            query=request.query,
            model=request.model,
            max_tokens=request.max_tokens,
            system_prompt=request.system_prompt,
            task_description=request.task_description,
            knowledge_query=request.knowledge_query,
            knowledge_limit=request.knowledge_limit,
            conversation_history=conversation_history,
            tool_results=tool_results,
            compress=request.compress,
            use_cache=request.use_cache,
        )

        # Calculate budget usage percentage. A non-positive budget total would
        # raise ZeroDivisionError and be reported by the catch-all below as
        # "Context assembly failed" even though assembly succeeded, so fall
        # back to 0% in that case.
        budget = await engine.get_budget_for_model(request.model, request.max_tokens)
        if budget.total > 0:
            budget_used_percent = (result.total_tokens / budget.total) * 100
        else:
            budget_used_percent = 0.0

        # Compression is inferred from the engine's result metadata.
        was_compressed = result.metadata.get("compressed_contexts", 0) > 0

        return AssembledContextResponse(
            content=result.content,
            total_tokens=result.total_tokens,
            context_count=result.context_count,
            compressed=was_compressed,
            budget_used_percent=round(budget_used_percent, 2),
            metadata={
                "model": request.model,
                "query": request.query,
                "knowledge_included": bool(request.knowledge_query),
                "conversation_turns": len(request.conversation_history or []),
                "excluded_count": result.excluded_count,
                "assembly_time_ms": result.assembly_time_ms,
            },
        )

    except AssemblyTimeoutError as e:
        raise HTTPException(
            status_code=status.HTTP_504_GATEWAY_TIMEOUT,
            detail=f"Context assembly timed out: {e}",
        ) from e
    except BudgetExceededError as e:
        raise HTTPException(
            status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
            detail=f"Token budget exceeded: {e}",
        ) from e
    except Exception as e:
        logger.exception("Context assembly failed")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Context assembly failed: {e}",
        ) from e
|
|
|
|
|
|
@router.post(
    "/count-tokens",
    response_model=TokenCountResponse,
    summary="Count Tokens",
    description="Count tokens in content using the LLM Gateway.",
)
async def count_tokens(
    request: TokenCountRequest,
    engine: ContextEngine = Depends(get_context_engine),
) -> TokenCountResponse:
    """Count tokens in ``request.content`` via ``engine.count_tokens``.

    ``request.model`` is forwarded to the engine and echoed in the response.

    Raises:
        HTTPException: 500 if counting (or building the response) fails for
            any reason; the failure is logged at WARNING level.
    """
    try:
        count = await engine.count_tokens(
            content=request.content,
            model=request.model,
        )
        return TokenCountResponse(
            token_count=count,
            model=request.model,
        )
    except Exception as e:
        # Lazy %-style args, matching the other log calls in this module, so
        # the message is only formatted if the record is actually emitted.
        logger.warning("Token counting failed: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Token counting failed: {e}",
        ) from e
|
|
|
|
|
|
@router.get(
    "/budget/{model}",
    response_model=BudgetInfoResponse,
    summary="Get Token Budget",
    description="Get token budget allocation for a specific model.",
)
async def get_budget(
    model: str,
    max_tokens: Annotated[int | None, Query(description="Custom max tokens")] = None,
    engine: ContextEngine = Depends(get_context_engine),
) -> BudgetInfoResponse:
    """Describe how the token budget for ``model`` is split across sections.

    ``max_tokens`` is forwarded unchanged to ``engine.get_budget_for_model``.
    Errors raised by the engine are not caught here.
    """
    allocation = await engine.get_budget_for_model(model, max_tokens)
    return BudgetInfoResponse(
        model=model,
        total_tokens=allocation.total,
        system_tokens=allocation.system,
        knowledge_tokens=allocation.knowledge,
        conversation_tokens=allocation.conversation,
        tool_tokens=allocation.tools,
        response_reserve=allocation.response_reserve,
    )
|
|
|
|
|
|
@router.get(
    "/stats",
    response_model=ContextEngineStatsResponse,
    summary="Engine Statistics",
    description="Get context engine statistics and configuration.",
)
async def get_stats(
    current_user: User = Depends(require_superuser),
    engine: ContextEngine = Depends(get_context_engine),
) -> ContextEngineStatsResponse:
    """Return the engine's cache statistics and settings.

    ``current_user`` is only used to gate access through ``require_superuser``.
    Missing ``cache`` or ``settings`` sections default to empty dicts.
    """
    snapshot = await engine.get_stats()
    cache_section = snapshot.get("cache", {})
    settings_section = snapshot.get("settings", {})
    return ContextEngineStatsResponse(
        cache=cache_section,
        settings=settings_section,
    )
|
|
|
|
|
|
@router.post(
    "/cache/invalidate",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Invalidate Cache (Admin Only)",
    description="Invalidate context cache entries.",
)
async def invalidate_cache(
    project_id: Annotated[
        str | None, Query(description="Project to invalidate")
    ] = None,
    pattern: Annotated[str | None, Query(description="Pattern to match")] = None,
    current_user: User = Depends(require_superuser),
    engine: ContextEngine = Depends(get_context_engine),
) -> None:
    """Drop cached context entries, optionally filtered by project and pattern.

    Both filters are passed straight to ``engine.invalidate_cache``; how the
    engine treats the case where both are None is not visible here. The
    caller's id is logged for auditing before invalidation runs.
    """
    logger.info(
        "Cache invalidation by user %s: project=%s pattern=%s",
        current_user.id, project_id, pattern,
    )
    await engine.invalidate_cache(pattern=pattern, project_id=project_id)
|