58 Commits
main ... dev

Author SHA1 Message Date
Felipe Cardoso
4ad3d20cf2 chore(agents): update sort_order values for agent types to improve logical grouping 2026-01-06 18:43:29 +01:00
Felipe Cardoso
8623eb56f5 feat(agents): add sorting by sort_order and include category & display fields in agent actions
- Implemented sorting of agent types by `sort_order` in Agents page.
- Added support for category, icon, color, sort_order, typical_tasks, and collaboration_hints fields in agent creation and update actions.
2026-01-06 18:20:04 +01:00
Felipe Cardoso
3cb6c8d13b feat(agents): implement grid/list view toggle and enhance filters
- Added grid and list view modes to AgentTypeList with user preference management.
- Enhanced filtering with category selection alongside existing search and status filters.
- Updated AgentTypeDetail with category badges and improved layout.
- Added unit tests for grid/list views and category filtering in AgentTypeList.
- Introduced `@radix-ui/react-toggle-group` for view mode toggle in AgentTypeList.
2026-01-06 18:17:46 +01:00
Felipe Cardoso
8e16e2645e test(forms): add unit tests for FormTextarea and FormSelect components
- Add comprehensive test coverage for FormTextarea and FormSelect components to validate rendering, accessibility, props forwarding, error handling, and behavior.
- Introduced function-scoped fixtures in e2e tests to ensure test isolation and address event loop issues with pytest-asyncio and SQLAlchemy.
2026-01-06 17:54:49 +01:00
Felipe Cardoso
82c3a6ba47 chore(makefiles): add format-check target and unify formatting logic
- Introduced `format-check` for verification without modification in `llm-gateway` and `knowledge-base` Makefiles.
- Updated `validate` to include `format-check`.
- Added `format-all` to root Makefile for consistent formatting across all components.
- Unexported `VIRTUAL_ENV` to prevent virtual environment warnings.
2026-01-06 17:25:21 +01:00
Felipe Cardoso
b6c38cac88 refactor(llm-gateway): adjust if-condition formatting for thread safety check
Updated line breaks and indentation for improved readability in circuit state recovery logic, ensuring consistent style.
2026-01-06 17:20:49 +01:00
Felipe Cardoso
51404216ae refactor(knowledge-base mcp server): adjust formatting for consistency and readability
Improved code formatting, line breaks, and indentation across chunking logic and multiple test modules to enhance code clarity and maintain consistent style. No functional changes made.
2026-01-06 17:20:31 +01:00
Felipe Cardoso
3f23bc3db3 refactor(migrations): replace hardcoded database URL with configurable environment variable and update command syntax to use consistent quoting style 2026-01-06 17:19:28 +01:00
Felipe Cardoso
a0ec5fa2cc test(agents): add validation tests for category and display fields
Added comprehensive unit and API tests to validate AgentType category and display fields:
- Category validation for valid, null, and invalid values
- Icon, color, and sort_order field constraints
- Typical tasks and collaboration hints handling (stripping, removing empty strings, normalization)
- New API tests for field creation, filtering, updating, and grouping
2026-01-06 17:19:21 +01:00
Felipe Cardoso
f262d08be2 test(project-events): add tests for demo configuration defaults
Added unit test cases to verify that the `demo.enabled` field is properly initialized to `false` in configurations and mock overrides.
2026-01-06 17:08:35 +01:00
Felipe Cardoso
b3f371e0a3 test(agents): add tests for AgentTypeForm enhancements
Added unit tests to cover new AgentTypeForm features:
- Category & Display fields (category select, sort order, icon, color)
- Typical Tasks management (add, remove, and prevent duplicates)
- Collaboration Hints management (add, remove, lowercase, and prevent duplicates)

This ensures thorough validation of recent form updates.
2026-01-06 17:07:21 +01:00
Felipe Cardoso
93cc37224c feat(agents): add category and display fields to AgentTypeForm
Add new "Category & Display" card in Basic Info tab with:
- Category dropdown to select agent category
- Sort order input for display ordering
- Icon text input with Lucide icon name
- Color picker with hex input and visual color selector
- Typical tasks tag input for agent capabilities
- Collaboration hints tag input for agent relationships

Updates include:
- TAB_FIELD_MAPPING with new field mappings
- State and handlers for typical_tasks and collaboration_hints
- Fix tests to use getAllByRole for multiple Add buttons

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 16:21:28 +01:00
Felipe Cardoso
5717bffd63 feat(agents): add frontend types and validation for category fields
Frontend changes to support new AgentType category and display fields:

Types (agentTypes.ts):
- Add AgentTypeCategory union type with 8 categories
- Add CATEGORY_METADATA constant with labels, descriptions, colors
- Update all interfaces with new fields (category, icon, color, etc.)
- Add AgentTypeGroupedResponse type

Validation (agentType.ts):
- Add AGENT_TYPE_CATEGORIES constant with metadata
- Add AVAILABLE_ICONS constant for icon picker
- Add COLOR_PALETTE constant for color selection
- Update agentTypeFormSchema with new field validators
- Update defaultAgentTypeValues with new fields

Form updates:
- Transform function now maps category and display fields from API

Test updates:
- Add new fields to mock AgentTypeResponse objects

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 16:16:21 +01:00
Felipe Cardoso
9339ea30a1 feat(agents): add category and display fields to AgentType model
Add 6 new fields to AgentType for better organization and UI display:
- category: enum for grouping (development, design, quality, etc.)
- icon: Lucide icon identifier for UI
- color: hex color code for visual distinction
- sort_order: display ordering within categories
- typical_tasks: list of tasks the agent excels at
- collaboration_hints: agent slugs that work well together

Backend changes:
- Add AgentTypeCategory enum to enums.py
- Update AgentType model with 6 new columns and indexes
- Update schemas with validators for new fields
- Add category filter and /grouped endpoint to routes
- Update CRUD with get_grouped_by_category method
- Update seed data with categories for all 27 agents
- Add migration 0007

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 16:11:22 +01:00
Felipe Cardoso
79cb6bfd7b feat(agents): comprehensive agent types with rich personalities
Major revamp of agent types based on SOTA personality design research:
- Expanded from 6 to 27 specialized agent types
- Rich personality prompts following Anthropic and CrewAI best practices
- Each agent has a structured prompt with Core Identity, Expertise,
  Principles, and Scenario Handling sections

Agent Categories:
- Core Development (8): Product Owner, PM, BA, Architect, Full Stack,
  Backend, Frontend, Mobile Engineers
- Design (2): UI/UX Designer, UX Researcher
- Quality & Operations (3): QA, DevOps, Security Engineers
- AI/ML (5): AI/ML Engineer, Researcher, CV, NLP, MLOps Engineers
- Data (2): Data Scientist, Data Engineer
- Leadership (2): Technical Lead, Scrum Master
- Domain Specialists (5): Financial, Healthcare, Scientific,
  Behavioral Psychology Experts, Technical Writer

Research applied:
- Anthropic Claude persona design guidelines
- CrewAI role/backstory/goal patterns
- Role prompting research on detailed vs generic personas
- Temperature tuning per agent type (0.2-0.7 based on role)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 14:25:13 +01:00
Felipe Cardoso
45025bb2f1 fix(forms): handle nullable fields in deepMergeWithDefaults
When the default value is null but the source has a value (e.g., the description
field), the merge was discarding the source value because typeof null
!== typeof string. The merge now properly accepts source values for nullable fields.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 13:54:18 +01:00
Felipe Cardoso
3c6b14d2bf refactor(forms): extract reusable form utilities and components
- Add getFirstValidationError utility for nested FieldErrors extraction
- Add mergeWithDefaults utilities (deepMergeWithDefaults, type guards)
- Add useValidationErrorHandler hook for toast + tab navigation
- Add FormSelect component with Controller integration
- Add FormTextarea component with register integration
- Refactor AgentTypeForm to use new utilities
- Remove verbose debug logging (now handled by hook)
- Add comprehensive tests (53 new tests, 100 total)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 13:50:36 +01:00
Felipe Cardoso
6b21a6fadd debug(agents): add comprehensive logging to form submission
Adds console.log statements throughout the form submission flow:
- Form submit triggered
- Current form values
- Form state (isDirty, isValid, isSubmitting, errors)
- Validation pass/fail
- onSubmit call and completion

This will help diagnose why the save button appears to do nothing.
Check browser console for '[AgentTypeForm]' logs.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 11:56:54 +01:00
Felipe Cardoso
600657adc4 fix(agents): properly initialize form with API data defaults
Root cause: The demo data's model_params was missing `top_p`, but the
Zod schema required all three fields (temperature, max_tokens, top_p).
This caused silent validation failures when editing agent types.

Fixes:
1. Add getInitialValues() that ensures all required fields have defaults
2. Handle nested validation errors in handleFormError (e.g., model_params.top_p)
3. Add useEffect to reset form when agentType changes
4. Add console.error logging for debugging validation failures
5. Update demo data to include top_p in all agent types

The form now properly initializes with safe defaults for any missing
fields from the API response, preventing silent validation failures.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 11:54:45 +01:00
Felipe Cardoso
c9d0d079b3 fix(frontend): show validation errors when agent type form fails
When form validation fails (e.g., personality_prompt is empty), the form
would silently not submit. Now it shows a toast with the first error
and navigates to the tab containing the error field.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 11:29:01 +01:00
Felipe Cardoso
4c8f81368c fix(docker): add NEXT_PUBLIC_API_BASE_URL to frontend containers
When running in Docker, the frontend needs to use 'http://backend:8000'
as the backend URL for Next.js rewrites. This env var is set to use
the Docker service name for proper container-to-container communication.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 09:23:50 +01:00
Felipe Cardoso
efbe91ce14 fix(frontend): use configurable backend URL in Next.js rewrite
The rewrite was using 'http://backend:8000', which only resolves inside the
Docker network. When running Next.js locally (npm run dev), the hostname
'backend' doesn't exist, causing ENOTFOUND errors.

Now uses NEXT_PUBLIC_API_BASE_URL env var with fallback to localhost:8000
for local development. In Docker, set NEXT_PUBLIC_API_BASE_URL=http://backend:8000.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 09:22:44 +01:00
Felipe Cardoso
5d646779c9 fix(frontend): preserve /api prefix in Next.js rewrite
The rewrite was incorrectly configured:
- Before: /api/:path* -> http://backend:8000/:path* (strips /api)
- After: /api/:path* -> http://backend:8000/api/:path* (preserves /api)

This was causing requests to /api/v1/agent-types to be sent to
http://backend:8000/v1/agent-types instead of the correct path.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 03:12:08 +01:00
Felipe Cardoso
5a4d93df26 feat(dashboard): use real API data and add 3 more demo projects
Dashboard changes:
- Update useDashboard hook to fetch real projects from API
- Calculate stats (active projects, agents, issues) from real data
- Keep pending approvals as mock (no backend endpoint yet)

Demo data additions:
- API Gateway Modernization project (active, complex)
- Customer Analytics Dashboard project (completed)
- DevOps Pipeline Automation project (active, complex)
- Added sprints, agent instances, and issues for each new project

Total demo data: 6 projects, 14 agents, 22 issues

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 03:10:10 +01:00
Felipe Cardoso
7ef217be39 feat(demo): tie all demo projects to admin user
- Update demo_data.json to use "__admin__" as owner_email for all projects
- Add admin user lookup in load_demo_data() with special "__admin__" key
- Remove notification_email from project settings (not a valid field)

This ensures demo projects are visible to the admin user when logged in.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 03:00:07 +01:00
Felipe Cardoso
20159c5865 fix(knowledge-base): ensure pgvector extension before pool creation
register_vector() requires the vector type to exist in PostgreSQL before
it can register the type codec. Move CREATE EXTENSION to a separate
_ensure_pgvector_extension() method that runs before pool creation.

This fixes the "unknown type: public.vector" error on fresh databases.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 02:55:02 +01:00
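A minimal sketch of the ordering this fix describes, assuming asyncpg plus the pgvector Python package; `_ensure_pgvector_extension` and the pool setup below are illustrative, not the repository's actual code:

```python
import asyncpg
from pgvector.asyncpg import register_vector


async def _ensure_pgvector_extension(dsn: str) -> None:
    # The vector type must exist before register_vector() can look up its codec.
    conn = await asyncpg.connect(dsn)
    try:
        await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
    finally:
        await conn.close()


async def create_pool(dsn: str) -> asyncpg.Pool:
    await _ensure_pgvector_extension(dsn)  # runs before the pool exists
    return await asyncpg.create_pool(
        dsn,
        init=register_vector,  # registers the vector codec on each new connection
    )
```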
Felipe Cardoso
f9a72fcb34 fix(models): use enum values instead of names for PostgreSQL
Add values_callable to all enum columns so SQLAlchemy serializes using
the enum's .value (lowercase) instead of .name (uppercase). PostgreSQL
enum types defined in migrations use lowercase values.

Fixes: invalid input value for enum autonomy_level: "MILESTONE"

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 02:53:45 +01:00
Felipe Cardoso
fcb0a5f86a fix(models): add explicit enum names to match migration types
SQLAlchemy's Enum() auto-generates type names from Python class names
(e.g., AutonomyLevel -> autonomylevel), but migrations defined them
with underscores (e.g., autonomy_level). This mismatch caused:

  "type 'autonomylevel' does not exist"

Added explicit name parameters to all enum columns to match the
migration-defined type names:
- autonomy_level, project_status, project_complexity, client_mode
- agent_status, sprint_status
- issue_type, issue_status, issue_priority, sync_status

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 02:48:10 +01:00
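The explicit `name=` fix above and the `values_callable=` fix from the commit before it combine on a single column roughly as follows; `AutonomyLevel` and its members are stand-ins for the project's real enum:

```python
import enum

import sqlalchemy as sa


class AutonomyLevel(enum.Enum):
    MILESTONE = "milestone"
    FULL = "full"


# name= matches the migration-defined PostgreSQL type; values_callable makes
# SQLAlchemy serialize the lowercase .value instead of the uppercase member .name.
autonomy_level_column = sa.Column(
    "autonomy_level",
    sa.Enum(
        AutonomyLevel,
        name="autonomy_level",
        values_callable=lambda e: [member.value for member in e],
    ),
    nullable=False,
)
```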
Felipe Cardoso
92782bcb05 refactor(init_db): remove demo data file and implement structured seeding
- Delete `demo_data.json`, replacing it with structured seeding logic for better modularity.
- Add support for seeding default agent types and new demo data structure.
- Ensure demo mode only executes when explicitly enabled (settings.DEMO_MODE).
- Enhance logging for improved debugging during DB initialization.
2026-01-06 02:34:34 +01:00
Felipe Cardoso
1dcf99ee38 fix(memory): use deque for metrics histograms to ensure bounded memory usage
- Replace default empty list with `deque` for `memory_retrieval_latency_seconds`
- Prevents unbounded memory growth by leveraging bounded circular buffer behavior
2026-01-06 02:34:28 +01:00
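A minimal illustration of the bounded-histogram idea; the 10k-sample cap echoes the circular-buffer size mentioned elsewhere in this range, and the names are hypothetical:

```python
from collections import deque

MAX_SAMPLES = 10_000

# A deque with maxlen behaves as a circular buffer: appending beyond the cap
# silently drops the oldest sample, so memory stays bounded.
memory_retrieval_latency_seconds: deque[float] = deque(maxlen=MAX_SAMPLES)


def observe_latency(seconds: float) -> None:
    memory_retrieval_latency_seconds.append(seconds)
```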
Felipe Cardoso
70009676a3 fix(dashboard): disable SSE in demo mode and remove unused hooks
- Skip SSE connection in demo mode (MSW doesn't support SSE).
- Remove unused `useProjectEvents` and related real-time hooks from `Dashboard`.
- Temporarily disable activity feed SSE until a global endpoint is available.
2026-01-06 02:29:00 +01:00
Felipe Cardoso
192237e69b fix(memory): unify Outcome enum and add ABANDONED support
- Add ABANDONED value to core Outcome enum in types.py
- Replace duplicate OutcomeType class in mcp/tools.py with alias to Outcome
- Simplify mcp/service.py to use outcome directly (no more silent mapping)
- Add migration 0006 to extend PostgreSQL episode_outcome enum
- Add missing constraints to migration 0005 (ix_facts_unique_triple_global)

This fixes the semantic issue where ABANDONED outcomes were silently
converted to FAILURE, losing information about task abandonment.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 01:46:48 +01:00
Felipe Cardoso
3edce9cd26 fix(memory): address critical bugs from multi-agent review
Bug Fixes:
- Remove singleton pattern from consolidation/reflection services to
  prevent stale database session bugs (session is now passed per-request)
- Add LRU eviction to MemoryToolService._working dict (max 1000 sessions)
  to prevent unbounded memory growth
- Replace O(n) list.remove() with O(1) OrderedDict.move_to_end() in
  RetrievalCache for better performance under load
- Use deque with maxlen for metrics histograms to prevent unbounded
  memory growth (circular buffer with 10k max samples)
- Use full UUID for checkpoint IDs instead of 8-char prefix to avoid
  collision risk at scale (birthday paradox at ~50k checkpoints)

Test Updates:
- Update checkpoint test to expect 36-char UUID
- Update reflection singleton tests to expect new factory behavior
- Add reset_memory_reflection() no-op for backwards compatibility

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 18:55:32 +01:00
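A sketch of the O(1) LRU bookkeeping described above, using `OrderedDict.move_to_end()` where `list.remove()` used to be; the class name and size limit are illustrative:

```python
from collections import OrderedDict
from typing import Any


class RetrievalCacheSketch:
    def __init__(self, max_size: int = 1000) -> None:
        self._max_size = max_size
        self._entries: OrderedDict[str, Any] = OrderedDict()

    def get(self, key: str) -> Any | None:
        if key not in self._entries:
            return None
        # O(1) recency update instead of list.remove() + append().
        self._entries.move_to_end(key)
        return self._entries[key]

    def put(self, key: str, value: Any) -> None:
        self._entries[key] = value
        self._entries.move_to_end(key)
        if len(self._entries) > self._max_size:
            # Evict the least recently used entry (the front of the dict).
            self._entries.popitem(last=False)
```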
Felipe Cardoso
35aea2d73a perf(mcp): optimize test performance with parallel connections and reduced retries
- Connect to MCP servers concurrently instead of sequentially
- Reduce retry settings in test mode (IS_TEST=True):
  - 1 attempt instead of 3
  - 100ms retry delay instead of 1s
  - 2s timeout instead of 30-120s

Reduces MCP E2E test time from ~16s to under 1s.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 18:33:38 +01:00
Felipe Cardoso
d0f32d04f7 fix(tests): reduce TTL durations to improve test reliability
- Adjusted TTL durations and sleep intervals across memory and cache tests for consistent expiration behavior.
- Prevented test flakiness caused by timing discrepancies in token expiration and cache cleanup.
2026-01-05 18:29:02 +01:00
Felipe Cardoso
da85a8aba8 fix(memory): prevent entry metadata mutation in vector search
- Create shallow copy of VectorIndexEntry when adding similarity score
- Prevents mutation of cached entries that could corrupt shared state

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 17:39:54 +01:00
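One way the copy-before-annotate step can look, assuming the index entry is a dataclass; `VectorIndexEntry` and its fields here are assumptions, not the actual model:

```python
from dataclasses import dataclass, replace


@dataclass
class VectorIndexEntry:
    memory_id: str
    metadata: dict


def with_score(entry: VectorIndexEntry, score: float) -> VectorIndexEntry:
    # Copy the entry (and its metadata dict) so the cached original is never mutated.
    return replace(entry, metadata={**entry.metadata, "similarity_score": score})
```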
Felipe Cardoso
f8bd1011e9 security(memory): escape SQL ILIKE patterns to prevent injection
- Add _escape_like_pattern() helper to escape SQL wildcards (%, _, \)
- Apply escaping in SemanticMemory.search_facts and get_by_entity
- Apply escaping in ProceduralMemory.search and find_best_for_task

Prevents attackers from injecting SQL wildcard patterns through
user-controlled search terms.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 17:39:47 +01:00
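A sketch of such an escaping helper and how it might feed an ILIKE filter; the exact signature in the repository may differ:

```python
from sqlalchemy import select


def _escape_like_pattern(value: str, escape: str = "\\") -> str:
    # Escape the escape character first, then the SQL wildcards % and _.
    return (
        value.replace(escape, escape + escape)
        .replace("%", escape + "%")
        .replace("_", escape + "_")
    )


def build_fact_search(fact_model, term: str):
    pattern = f"%{_escape_like_pattern(term)}%"
    # escape= tells PostgreSQL which character neutralizes wildcards in the pattern.
    return select(fact_model).where(fact_model.subject.ilike(pattern, escape="\\"))
```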
Felipe Cardoso
f057c2f0b6 fix(memory): add thread-safe singleton initialization
- Add threading.Lock with double-check locking to ScopeManager
- Add asyncio.Lock with double-check locking to MemoryReflection
- Make reset_memory_metrics async with proper locking
- Update test fixtures to handle async reset functions

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 17:39:39 +01:00
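The double-checked locking pattern referenced above, in minimal form; `ScopeManager` is a placeholder for the real class:

```python
import threading


class ScopeManager:
    """Placeholder for the real scope manager."""


_instance: ScopeManager | None = None
_lock = threading.Lock()


def get_scope_manager() -> ScopeManager:
    global _instance
    if _instance is None:  # fast path: no lock once initialized
        with _lock:
            if _instance is None:  # re-check while holding the lock
                _instance = ScopeManager()
    return _instance
```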
Felipe Cardoso
33ec889fc4 fix(memory): add data integrity constraints to Fact model
- Change source_episode_ids from JSON to JSONB for PostgreSQL consistency
- Add unique constraint for global facts (project_id IS NULL)
- Add CHECK constraint ensuring reinforcement_count >= 1

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 17:39:30 +01:00
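Declaratively, constraints like these are typically attached via `__table_args__`; the sketch below mirrors the migration shown further down but is not the model's verbatim source:

```python
from sqlalchemy import CheckConstraint, Index, Integer, String, Text, text
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class FactSketch(Base):
    __tablename__ = "facts_sketch"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    project_id: Mapped[str | None] = mapped_column(String(36), nullable=True)
    subject: Mapped[str] = mapped_column(String(500))
    predicate: Mapped[str] = mapped_column(String(255))
    object: Mapped[str] = mapped_column(Text())
    reinforcement_count: Mapped[int] = mapped_column(Integer, default=1)

    __table_args__ = (
        # Reinforcement count starts at 1 and only grows.
        CheckConstraint("reinforcement_count >= 1", name="ck_facts_reinforcement_positive"),
        # Global facts (project_id IS NULL) must still be unique per triple.
        Index(
            "ix_facts_unique_triple_global",
            "subject",
            "predicate",
            "object",
            unique=True,
            postgresql_where=text("project_id IS NULL"),
        ),
    )
```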
Felipe Cardoso
74b8c65741 fix(tests): move memory model tests to avoid import conflicts
Moved tests/unit/models/memory/ to tests/models/memory/ to avoid
Python import path conflicts when pytest collects all tests.

The conflict was caused by tests/models/ and tests/unit/models/ both
having __init__.py files, which led Python to resolve app.models.memory
imports incorrectly.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 15:45:30 +01:00
Felipe Cardoso
b232298c61 feat(memory): add memory consolidation task and switch source_episode_ids to JSON
- Added `memory_consolidation` to the task list and updated `__all__` in test files.
- Updated `source_episode_ids` in `Fact` model to use JSON for cross-database compatibility.
- Revised related database migrations to use JSONB instead of ARRAY.
- Adjusted test concurrency in Makefile for improved test performance.
2026-01-05 15:38:52 +01:00
Felipe Cardoso
cf6291ac8e style(memory): apply ruff formatting and linting fixes
Auto-fixed linting errors and formatting issues:
- Removed unused imports (F401): pytest, Any, AnalysisType, MemoryType, OutcomeType
- Removed unused variable (F841): hooks variable in test
- Applied consistent formatting across memory service and test files

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 14:07:48 +01:00
Felipe Cardoso
e3fe0439fd docs(memory): add comprehensive memory system documentation (#101)
Add complete documentation for the Agent Memory System including:
- Architecture overview with ASCII diagram
- Memory type descriptions (working, episodic, semantic, procedural)
- Usage examples for all memory operations
- Memory scoping hierarchy explanation
- Consolidation flow documentation
- MCP tools reference
- Reflection capabilities
- Configuration reference table
- Integration with Context Engine
- Metrics reference
- Performance targets
- Troubleshooting guide
- Directory structure

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 11:03:57 +01:00
Felipe Cardoso
57680c3772 feat(memory): implement metrics and observability (#100)
Add comprehensive metrics collector for memory system with:
- Counter metrics: operations, retrievals, cache hits/misses, consolidations,
  episodes recorded, patterns/anomalies/insights detected
- Gauge metrics: item counts, memory size, cache size, procedure success rates,
  active sessions, pending consolidations
- Histogram metrics: working memory latency, retrieval latency, consolidation
  duration, embedding latency
- Prometheus format export
- Summary and cache stats helpers

31 tests covering all metric types, singleton pattern, and edge cases.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 11:00:53 +01:00
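For the Prometheus export specifically, the text exposition format is just name/value lines with TYPE hints; a hand-rolled exporter might look like this (illustrative, not the project's collector):

```python
def to_prometheus_text(counters: dict[str, float], gauges: dict[str, float]) -> str:
    lines: list[str] = []
    for name, value in counters.items():
        lines.append(f"# TYPE {name} counter")
        lines.append(f"{name} {value}")
    for name, value in gauges.items():
        lines.append(f"# TYPE {name} gauge")
        lines.append(f"{name} {value}")
    return "\n".join(lines) + "\n"


print(to_prometheus_text({"memory_operations_total": 42.0}, {"memory_cache_size": 128.0}))
```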
Felipe Cardoso
997cfaa03a feat(memory): implement memory reflection service (#99)
Add reflection layer for memory system with pattern detection, success/failure
factor analysis, anomaly detection, and insights generation. Enables agents to
learn from past experiences and identify optimization opportunities.

Key components:
- Pattern detection: recurring success/failure, action sequences, temporal, efficiency
- Factor analysis: action, context, timing, resource, preceding state factors
- Anomaly detection: unusual duration, token usage, failure rates, action patterns
- Insight generation: optimization, warning, learning, recommendation, trend insights

Also fixes pre-existing timezone issues in test_types.py (datetime.now() -> datetime.now(UTC)).

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 04:22:23 +01:00
Felipe Cardoso
6954774e36 feat(memory): implement caching layer for memory operations (#98)
Add comprehensive caching layer for the Agent Memory System:

- HotMemoryCache: LRU cache for frequently accessed memories
  - Python 3.12 type parameter syntax
  - Thread-safe operations with RLock
  - TTL-based expiration
  - Access count tracking for hot memory identification
  - Scoped invalidation by type, scope, or pattern

- EmbeddingCache: Cache embeddings by content hash
  - Content-hash based deduplication
  - Optional Redis backing for persistence
  - LRU eviction with configurable max size
  - CachedEmbeddingGenerator wrapper for transparent caching

- CacheManager: Unified cache management
  - Coordinates hot cache, embedding cache, and retrieval cache
  - Centralized invalidation across all caches
  - Aggregated statistics and hit rate tracking
  - Automatic cleanup scheduling
  - Cache warmup support

Performance targets:
- Cache hit rate > 80% for hot memories
- Cache operations < 1ms (memory), < 5ms (Redis)

83 new tests with comprehensive coverage.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 04:04:13 +01:00
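A compact sketch of an LRU-plus-TTL cache guarded by an RLock, using the Python 3.12 type-parameter syntax the commit mentions; defaults and method names are assumptions:

```python
import threading
import time
from collections import OrderedDict


class HotMemoryCache[T]:
    def __init__(self, max_size: int = 1000, ttl_seconds: float = 300.0) -> None:
        self._max_size = max_size
        self._ttl = ttl_seconds
        self._lock = threading.RLock()
        self._items: OrderedDict[str, tuple[float, T]] = OrderedDict()

    def get(self, key: str) -> T | None:
        with self._lock:
            item = self._items.get(key)
            if item is None:
                return None
            expires_at, value = item
            if time.monotonic() > expires_at:
                del self._items[key]  # TTL expired
                return None
            self._items.move_to_end(key)  # mark as most recently used
            return value

    def put(self, key: str, value: T) -> None:
        with self._lock:
            self._items[key] = (time.monotonic() + self._ttl, value)
            self._items.move_to_end(key)
            if len(self._items) > self._max_size:
                self._items.popitem(last=False)  # evict least recently used
```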
Felipe Cardoso
30e5c68304 feat(memory): integrate memory system with context engine (#97)
## Changes

### New Context Type
- Add MEMORY to ContextType enum for agent memory context
- Create MemoryContext class with subtypes (working, episodic, semantic, procedural)
- Factory methods: from_working_memory, from_episodic_memory, from_semantic_memory, from_procedural_memory

### Memory Context Source
- MemoryContextSource service fetches relevant memories for context assembly
- Configurable fetch limits per memory type
- Parallel fetching from all memory types

### Agent Lifecycle Hooks
- AgentLifecycleManager handles spawn, pause, resume, terminate events
- spawn: Initialize working memory with optional initial state
- pause: Create checkpoint of working memory
- resume: Restore from checkpoint
- terminate: Consolidate working memory to episodic memory
- LifecycleHooks for custom extension points

### Context Engine Integration
- Add memory_query parameter to assemble_context()
- Add session_id and agent_type_id for memory scoping
- Memory budget allocation (15% by default)
- set_memory_source() for runtime configuration

### Tests
- 48 new tests for MemoryContext, MemoryContextSource, and lifecycle hooks
- All 108 memory-related tests passing
- mypy and ruff checks passing

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 03:49:22 +01:00
Felipe Cardoso
0b24d4c6cc feat(memory): implement MCP tools for agent memory operations (#96)
Add MCP-compatible tools that expose memory operations to agents:

Tools implemented:
- remember: Store data in working, episodic, semantic, or procedural memory
- recall: Retrieve memories by query across multiple memory types
- forget: Delete specific keys or bulk delete by pattern
- reflect: Analyze patterns in recent episodes (success/failure factors)
- get_memory_stats: Return usage statistics and breakdowns
- search_procedures: Find procedures matching trigger patterns
- record_outcome: Record task outcomes and update procedure success rates

Key components:
- tools.py: Pydantic schemas for tool argument validation with comprehensive
  field constraints (importance 0-1, TTL limits, limit ranges)
- service.py: MemoryToolService coordinating memory type operations with
  proper scoping via ToolContext (project_id, agent_instance_id, session_id)
- Lazy initialization of memory services (WorkingMemory, EpisodicMemory,
  SemanticMemory, ProceduralMemory)

Test coverage:
- 60 tests covering tool definitions, argument validation, and service
  execution paths
- Mock-based tests for all memory type interactions

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 03:32:10 +01:00
Felipe Cardoso
1670e05e0d feat(memory): implement memory consolidation service and tasks (#95)
- Add MemoryConsolidationService with Working→Episodic→Semantic/Procedural transfer
- Add Celery tasks for session and nightly consolidation
- Implement memory pruning with importance-based retention
- Add comprehensive test suite (32 tests)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 03:04:28 +01:00
Felipe Cardoso
999b7ac03f feat(memory): implement memory indexing and retrieval engine (#94)
Add comprehensive indexing and retrieval system for memory search:
- VectorIndex for semantic similarity search using cosine similarity
- TemporalIndex for time-based queries with range and recency support
- EntityIndex for entity-based lookups with multi-entity intersection
- OutcomeIndex for success/failure filtering on episodes
- MemoryIndexer as unified interface for all index types
- RetrievalEngine with hybrid search combining all indices
- RelevanceScorer for multi-signal relevance scoring
- RetrievalCache for LRU caching of search results

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 02:50:13 +01:00
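The core of cosine-similarity search is small; a dependency-free sketch (the real engine presumably operates on stored embeddings and richer entries):

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b, strict=True))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def top_k(query: list[float], entries: dict[str, list[float]], k: int = 5) -> list[tuple[str, float]]:
    # Score every stored embedding against the query and keep the best k.
    scored = [(memory_id, cosine_similarity(query, emb)) for memory_id, emb in entries.items()]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)[:k]
```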
Felipe Cardoso
48ecb40f18 feat(memory): implement memory scoping with hierarchy and access control (#93)
Add scope management system for hierarchical memory access:
- ScopeManager with hierarchy: Global → Project → Agent Type → Agent Instance → Session
- ScopePolicy for access control (read, write, inherit permissions)
- ScopeResolver for resolving queries across scope hierarchies with inheritance
- ScopeFilter for filtering scopes by type, project, or agent
- Access control enforcement with parent scope visibility
- Deduplication support during resolution across scopes

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 02:39:22 +01:00
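The hierarchy above reads as a parent chain from narrowest to broadest scope; a toy resolver under that assumption:

```python
# Narrowest to broadest; the commit's Global → Project → Agent Type →
# Agent Instance → Session hierarchy, read in reverse for inheritance.
_PARENTS = {
    "session": "agent_instance",
    "agent_instance": "agent_type",
    "agent_type": "project",
    "project": "global",
    "global": None,
}


def inherited_scopes(scope_type: str) -> list[str]:
    """Return the scope plus every ancestor it can read from via inheritance."""
    chain: list[str] = []
    current: str | None = scope_type
    while current is not None:
        chain.append(current)
        current = _PARENTS[current]
    return chain


print(inherited_scopes("session"))
# ['session', 'agent_instance', 'agent_type', 'project', 'global']
```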
Felipe Cardoso
b818f17418 feat(memory): add procedural memory implementation (Issue #92)
Implements procedural memory for learned skills and procedures:

Core functionality:
- ProceduralMemory class for procedure storage/retrieval
- record_procedure with duplicate detection and step merging
- find_matching for context-based procedure search
- record_outcome for success/failure tracking
- get_best_procedure for finding highest success rate
- update_steps for procedure refinement

Supporting modules:
- ProcedureMatcher: Keyword-based procedure matching
- MatchResult/MatchContext: Matching result types
- Success rate weighting in match scoring

Test coverage:
- 43 unit tests covering all modules
- matching.py: 97% coverage
- memory.py: 86% coverage

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 02:31:32 +01:00
Felipe Cardoso
e946787a61 feat(memory): add semantic memory implementation (Issue #91)
Implements semantic memory with fact storage, retrieval, and verification:

Core functionality:
- SemanticMemory class for fact storage/retrieval
- Fact storage as subject-predicate-object triples
- Duplicate detection with reinforcement
- Semantic search with text-based fallback
- Entity-based retrieval
- Confidence scoring and decay
- Conflict resolution

Supporting modules:
- FactExtractor: Pattern-based fact extraction from episodes
- FactVerifier: Contradiction detection and reliability scoring

Test coverage:
- 47 unit tests covering all modules
- extraction.py: 99% coverage
- verification.py: 95% coverage
- memory.py: 78% coverage

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 02:23:06 +01:00
Felipe Cardoso
3554efe66a feat(memory): add episodic memory implementation (Issue #90)
Implements the episodic memory service for storing and retrieving
agent task execution experiences. This enables learning from past
successes and failures.

Components:
- EpisodicMemory: Main service class combining recording and retrieval
- EpisodeRecorder: Handles episode creation, importance scoring
- EpisodeRetriever: Multiple retrieval strategies (recency, semantic,
  outcome, importance, task type)

Key features:
- Records task completions with context, actions, outcomes
- Calculates importance scores based on outcome, duration, lessons
- Semantic search with fallback to recency when embeddings unavailable
- Full CRUD operations with statistics and summarization
- Comprehensive unit tests (50 tests, all passing)

Closes #90

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 02:08:16 +01:00
Felipe Cardoso
bd988f76b0 fix(memory): address review findings from Issue #88
Fixes based on multi-agent review:

Model Improvements:
- Remove duplicate index ix_procedures_agent_type (already indexed via Column)
- Fix postgresql_where to use text() instead of string literal in Fact model
- Add thread-safety to Procedure.success_rate property (snapshot values)

Data Integrity Constraints:
- Add CheckConstraint for Episode: importance_score 0-1, duration >= 0, tokens >= 0
- Add CheckConstraint for Fact: confidence 0-1
- Add CheckConstraint for Procedure: success_count >= 0, failure_count >= 0

Migration Updates:
- Add check constraints creation in upgrade()
- Add check constraints removal in downgrade()

Note: SQLAlchemy Column default=list is correct (callable factory pattern)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 01:54:51 +01:00
Felipe Cardoso
4974233169 feat(memory): add working memory implementation (Issue #89)
Implements session-scoped ephemeral memory with:

Storage Backends:
- InMemoryStorage: Thread-safe fallback with TTL support and capacity limits
- RedisStorage: Primary storage with connection pooling and JSON serialization
- Auto-fallback from Redis to in-memory when unavailable

WorkingMemory Class:
- Key-value storage with TTL and reserved key protection
- Task state tracking with progress updates
- Scratchpad for reasoning steps with timestamps
- Checkpoint/snapshot support for recovery
- Factory methods for auto-configured storage

Tests:
- 55 unit tests covering all functionality
- Tests for basic ops, TTL, capacity, concurrency
- Tests for task state, scratchpad, checkpoints

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 01:51:03 +01:00
Felipe Cardoso
c9d8c0835c feat(memory): add database schema and storage layer (Issue #88)
Add SQLAlchemy models for the Agent Memory System:
- WorkingMemory: Key-value storage with TTL for active sessions
- Episode: Experiential memories from task executions
- Fact: Semantic knowledge triples with confidence scores
- Procedure: Learned skills and procedures with success tracking
- MemoryConsolidationLog: Tracks consolidation jobs between memory tiers

Create enums for memory system:
- ScopeType: global, project, agent_type, agent_instance, session
- EpisodeOutcome: success, failure, partial
- ConsolidationType: working_to_episodic, episodic_to_semantic, etc.
- ConsolidationStatus: pending, running, completed, failed

Add Alembic migration (0005) for all memory tables with:
- Foreign key relationships to projects, agent_instances, agent_types
- Comprehensive indexes for query patterns
- Unique constraints for key lookups and triple uniqueness
- Vector embedding column placeholders (Text fallback until pgvector enabled)

Fix timezone-naive datetime.now() in types.py TaskState (review feedback)

Includes 30 unit tests for models and enums.

Closes #88

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 01:37:58 +01:00
Felipe Cardoso
085a748929 feat(memory): #87 project setup & core architecture
Implements Sub-Issue #87 of Issue #62 (Agent Memory System).

Core infrastructure:
- memory/types.py: Type definitions for all memory types (Working, Episodic,
  Semantic, Procedural) with enums for MemoryType, ScopeLevel, Outcome
- memory/config.py: MemorySettings with MEM_ env prefix, thread-safe singleton
- memory/exceptions.py: Comprehensive exception hierarchy for memory operations
- memory/manager.py: MemoryManager facade with placeholder methods

Directory structure:
- working/: Working memory (Redis/in-memory) - to be implemented in #89
- episodic/: Episodic memory (experiences) - to be implemented in #90
- semantic/: Semantic memory (facts) - to be implemented in #91
- procedural/: Procedural memory (skills) - to be implemented in #92
- scoping/: Scope management - to be implemented in #93
- indexing/: Vector indexing - to be implemented in #94
- consolidation/: Memory consolidation - to be implemented in #95

Tests: 71 unit tests for config, types, and exceptions
Docs: Comprehensive implementation plan at docs/architecture/memory-system-plan.md

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 01:27:36 +01:00
188 changed files with 40275 additions and 961 deletions

View File

@@ -1,5 +1,5 @@
.PHONY: help dev dev-full prod down logs logs-dev clean clean-slate drop-db reset-db push-images deploy
-.PHONY: test test-backend test-mcp test-frontend test-all test-cov test-integration validate validate-all
+.PHONY: test test-backend test-mcp test-frontend test-all test-cov test-integration validate validate-all format-all
VERSION ?= latest
REGISTRY ?= ghcr.io/cardosofelipe/pragma-stack
@@ -22,6 +22,9 @@ help:
@echo " make test-cov - Run all tests with coverage reports"
@echo " make test-integration - Run MCP integration tests (requires running stack)"
@echo ""
@echo "Formatting:"
@echo " make format-all - Format code in backend + MCP servers + frontend"
@echo ""
@echo "Validation:"
@echo " make validate - Validate backend + MCP servers (lint, type-check, test)"
@echo " make validate-all - Validate everything including frontend"
@@ -161,6 +164,25 @@ test-integration:
@echo "Note: Requires running stack (make dev first)"
@cd backend && RUN_INTEGRATION_TESTS=true IS_TEST=True uv run pytest tests/integration/ -v
# ============================================================================
# Formatting
# ============================================================================
format-all:
@echo "Formatting backend..."
@cd backend && make format
@echo ""
@echo "Formatting LLM Gateway..."
@cd mcp-servers/llm-gateway && make format
@echo ""
@echo "Formatting Knowledge Base..."
@cd mcp-servers/knowledge-base && make format
@echo ""
@echo "Formatting frontend..."
@cd frontend && npm run format
@echo ""
@echo "All code formatted!"
# ============================================================================
# Validation (lint + type-check + test)
# ============================================================================

View File

@@ -80,7 +80,7 @@ test:
test-cov:
@echo "🧪 Running tests with coverage..."
-@IS_TEST=True PYTHONPATH=. uv run pytest --cov=app --cov-report=term-missing --cov-report=html -n 16
+@IS_TEST=True PYTHONPATH=. uv run pytest --cov=app --cov-report=term-missing --cov-report=html -n 20
@echo "📊 Coverage report generated in htmlcov/index.html"
# ============================================================================

View File

@@ -0,0 +1,512 @@
"""Add Agent Memory System tables
Revision ID: 0005
Revises: 0004
Create Date: 2025-01-05
This migration creates the Agent Memory System tables:
- working_memory: Key-value storage with TTL for active sessions
- episodes: Experiential memories from task executions
- facts: Semantic knowledge triples with confidence scores
- procedures: Learned skills and procedures
- memory_consolidation_log: Tracks consolidation jobs
See Issue #88: Database Schema & Storage Layer
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "0005"
down_revision: str | None = "0004"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Create Agent Memory System tables."""
# =========================================================================
# Create ENUM types for memory system
# =========================================================================
# Scope type enum
scope_type_enum = postgresql.ENUM(
"global",
"project",
"agent_type",
"agent_instance",
"session",
name="scope_type",
create_type=False,
)
scope_type_enum.create(op.get_bind(), checkfirst=True)
# Episode outcome enum
episode_outcome_enum = postgresql.ENUM(
"success",
"failure",
"partial",
name="episode_outcome",
create_type=False,
)
episode_outcome_enum.create(op.get_bind(), checkfirst=True)
# Consolidation type enum
consolidation_type_enum = postgresql.ENUM(
"working_to_episodic",
"episodic_to_semantic",
"episodic_to_procedural",
"pruning",
name="consolidation_type",
create_type=False,
)
consolidation_type_enum.create(op.get_bind(), checkfirst=True)
# Consolidation status enum
consolidation_status_enum = postgresql.ENUM(
"pending",
"running",
"completed",
"failed",
name="consolidation_status",
create_type=False,
)
consolidation_status_enum.create(op.get_bind(), checkfirst=True)
# =========================================================================
# Create working_memory table
# Key-value storage with TTL for active sessions
# =========================================================================
op.create_table(
"working_memory",
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column(
"scope_type",
scope_type_enum,
nullable=False,
),
sa.Column("scope_id", sa.String(255), nullable=False),
sa.Column("key", sa.String(255), nullable=False),
sa.Column("value", postgresql.JSONB(astext_type=sa.Text()), nullable=False),
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.PrimaryKeyConstraint("id"),
)
# Working memory indexes
op.create_index(
"ix_working_memory_scope_type",
"working_memory",
["scope_type"],
)
op.create_index(
"ix_working_memory_scope_id",
"working_memory",
["scope_id"],
)
op.create_index(
"ix_working_memory_scope_key",
"working_memory",
["scope_type", "scope_id", "key"],
unique=True,
)
op.create_index(
"ix_working_memory_expires",
"working_memory",
["expires_at"],
)
op.create_index(
"ix_working_memory_scope_list",
"working_memory",
["scope_type", "scope_id"],
)
# =========================================================================
# Create episodes table
# Experiential memories from task executions
# =========================================================================
op.create_table(
"episodes",
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("project_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("agent_instance_id", postgresql.UUID(as_uuid=True), nullable=True),
sa.Column("agent_type_id", postgresql.UUID(as_uuid=True), nullable=True),
sa.Column("session_id", sa.String(255), nullable=False),
sa.Column("task_type", sa.String(100), nullable=False),
sa.Column("task_description", sa.Text(), nullable=False),
sa.Column(
"actions",
postgresql.JSONB(astext_type=sa.Text()),
nullable=False,
server_default="[]",
),
sa.Column("context_summary", sa.Text(), nullable=False),
sa.Column(
"outcome",
episode_outcome_enum,
nullable=False,
),
sa.Column("outcome_details", sa.Text(), nullable=True),
sa.Column("duration_seconds", sa.Float(), nullable=False, server_default="0.0"),
sa.Column("tokens_used", sa.BigInteger(), nullable=False, server_default="0"),
sa.Column(
"lessons_learned",
postgresql.JSONB(astext_type=sa.Text()),
nullable=False,
server_default="[]",
),
sa.Column("importance_score", sa.Float(), nullable=False, server_default="0.5"),
# Vector embedding - using TEXT as fallback, will be VECTOR(1536) when pgvector is available
sa.Column("embedding", sa.Text(), nullable=True),
sa.Column("occurred_at", sa.DateTime(timezone=True), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(
["project_id"],
["projects.id"],
name="fk_episodes_project",
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["agent_instance_id"],
["agent_instances.id"],
name="fk_episodes_agent_instance",
ondelete="SET NULL",
),
sa.ForeignKeyConstraint(
["agent_type_id"],
["agent_types.id"],
name="fk_episodes_agent_type",
ondelete="SET NULL",
),
)
# Episode indexes
op.create_index("ix_episodes_project_id", "episodes", ["project_id"])
op.create_index("ix_episodes_agent_instance_id", "episodes", ["agent_instance_id"])
op.create_index("ix_episodes_agent_type_id", "episodes", ["agent_type_id"])
op.create_index("ix_episodes_session_id", "episodes", ["session_id"])
op.create_index("ix_episodes_task_type", "episodes", ["task_type"])
op.create_index("ix_episodes_outcome", "episodes", ["outcome"])
op.create_index("ix_episodes_importance_score", "episodes", ["importance_score"])
op.create_index("ix_episodes_occurred_at", "episodes", ["occurred_at"])
op.create_index("ix_episodes_project_task", "episodes", ["project_id", "task_type"])
op.create_index(
"ix_episodes_project_outcome", "episodes", ["project_id", "outcome"]
)
op.create_index(
"ix_episodes_agent_task", "episodes", ["agent_instance_id", "task_type"]
)
op.create_index(
"ix_episodes_project_time", "episodes", ["project_id", "occurred_at"]
)
op.create_index(
"ix_episodes_importance_time",
"episodes",
["importance_score", "occurred_at"],
)
# =========================================================================
# Create facts table
# Semantic knowledge triples with confidence scores
# =========================================================================
op.create_table(
"facts",
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column(
"project_id", postgresql.UUID(as_uuid=True), nullable=True
), # NULL for global facts
sa.Column("subject", sa.String(500), nullable=False),
sa.Column("predicate", sa.String(255), nullable=False),
sa.Column("object", sa.Text(), nullable=False),
sa.Column("confidence", sa.Float(), nullable=False, server_default="0.8"),
# Source episode IDs stored as JSON array of UUID strings for cross-db compatibility
sa.Column(
"source_episode_ids",
postgresql.JSONB(astext_type=sa.Text()),
nullable=False,
server_default="[]",
),
sa.Column("first_learned", sa.DateTime(timezone=True), nullable=False),
sa.Column("last_reinforced", sa.DateTime(timezone=True), nullable=False),
sa.Column(
"reinforcement_count", sa.Integer(), nullable=False, server_default="1"
),
# Vector embedding
sa.Column("embedding", sa.Text(), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(
["project_id"],
["projects.id"],
name="fk_facts_project",
ondelete="CASCADE",
),
)
# Fact indexes
op.create_index("ix_facts_project_id", "facts", ["project_id"])
op.create_index("ix_facts_subject", "facts", ["subject"])
op.create_index("ix_facts_predicate", "facts", ["predicate"])
op.create_index("ix_facts_confidence", "facts", ["confidence"])
op.create_index("ix_facts_subject_predicate", "facts", ["subject", "predicate"])
op.create_index("ix_facts_project_subject", "facts", ["project_id", "subject"])
op.create_index(
"ix_facts_confidence_time", "facts", ["confidence", "last_reinforced"]
)
# Unique constraint for triples within project scope
op.create_index(
"ix_facts_unique_triple",
"facts",
["project_id", "subject", "predicate", "object"],
unique=True,
postgresql_where=sa.text("project_id IS NOT NULL"),
)
# Unique constraint for global facts (project_id IS NULL)
op.create_index(
"ix_facts_unique_triple_global",
"facts",
["subject", "predicate", "object"],
unique=True,
postgresql_where=sa.text("project_id IS NULL"),
)
# =========================================================================
# Create procedures table
# Learned skills and procedures
# =========================================================================
op.create_table(
"procedures",
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("project_id", postgresql.UUID(as_uuid=True), nullable=True),
sa.Column("agent_type_id", postgresql.UUID(as_uuid=True), nullable=True),
sa.Column("name", sa.String(255), nullable=False),
sa.Column("trigger_pattern", sa.Text(), nullable=False),
sa.Column(
"steps",
postgresql.JSONB(astext_type=sa.Text()),
nullable=False,
server_default="[]",
),
sa.Column("success_count", sa.Integer(), nullable=False, server_default="0"),
sa.Column("failure_count", sa.Integer(), nullable=False, server_default="0"),
sa.Column("last_used", sa.DateTime(timezone=True), nullable=True),
# Vector embedding
sa.Column("embedding", sa.Text(), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(
["project_id"],
["projects.id"],
name="fk_procedures_project",
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["agent_type_id"],
["agent_types.id"],
name="fk_procedures_agent_type",
ondelete="SET NULL",
),
)
# Procedure indexes
op.create_index("ix_procedures_project_id", "procedures", ["project_id"])
op.create_index("ix_procedures_agent_type_id", "procedures", ["agent_type_id"])
op.create_index("ix_procedures_name", "procedures", ["name"])
op.create_index("ix_procedures_last_used", "procedures", ["last_used"])
op.create_index(
"ix_procedures_unique_name",
"procedures",
["project_id", "agent_type_id", "name"],
unique=True,
)
op.create_index("ix_procedures_project_name", "procedures", ["project_id", "name"])
# Note: agent_type_id already indexed via ix_procedures_agent_type_id (line 354)
op.create_index(
"ix_procedures_success_rate",
"procedures",
["success_count", "failure_count"],
)
# =========================================================================
# Add check constraints for data integrity
# =========================================================================
# Episode constraints
op.create_check_constraint(
"ck_episodes_importance_range",
"episodes",
"importance_score >= 0.0 AND importance_score <= 1.0",
)
op.create_check_constraint(
"ck_episodes_duration_positive",
"episodes",
"duration_seconds >= 0.0",
)
op.create_check_constraint(
"ck_episodes_tokens_positive",
"episodes",
"tokens_used >= 0",
)
# Fact constraints
op.create_check_constraint(
"ck_facts_confidence_range",
"facts",
"confidence >= 0.0 AND confidence <= 1.0",
)
op.create_check_constraint(
"ck_facts_reinforcement_positive",
"facts",
"reinforcement_count >= 1",
)
# Procedure constraints
op.create_check_constraint(
"ck_procedures_success_positive",
"procedures",
"success_count >= 0",
)
op.create_check_constraint(
"ck_procedures_failure_positive",
"procedures",
"failure_count >= 0",
)
# =========================================================================
# Create memory_consolidation_log table
# Tracks consolidation jobs
# =========================================================================
op.create_table(
"memory_consolidation_log",
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column(
"consolidation_type",
consolidation_type_enum,
nullable=False,
),
sa.Column("source_count", sa.Integer(), nullable=False, server_default="0"),
sa.Column("result_count", sa.Integer(), nullable=False, server_default="0"),
sa.Column("started_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"status",
consolidation_status_enum,
nullable=False,
server_default="pending",
),
sa.Column("error", sa.Text(), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.PrimaryKeyConstraint("id"),
)
# Consolidation log indexes
op.create_index(
"ix_consolidation_type",
"memory_consolidation_log",
["consolidation_type"],
)
op.create_index(
"ix_consolidation_status",
"memory_consolidation_log",
["status"],
)
op.create_index(
"ix_consolidation_type_status",
"memory_consolidation_log",
["consolidation_type", "status"],
)
op.create_index(
"ix_consolidation_started",
"memory_consolidation_log",
["started_at"],
)
def downgrade() -> None:
"""Drop Agent Memory System tables."""
# Drop check constraints first
op.drop_constraint("ck_procedures_failure_positive", "procedures", type_="check")
op.drop_constraint("ck_procedures_success_positive", "procedures", type_="check")
op.drop_constraint("ck_facts_reinforcement_positive", "facts", type_="check")
op.drop_constraint("ck_facts_confidence_range", "facts", type_="check")
op.drop_constraint("ck_episodes_tokens_positive", "episodes", type_="check")
op.drop_constraint("ck_episodes_duration_positive", "episodes", type_="check")
op.drop_constraint("ck_episodes_importance_range", "episodes", type_="check")
# Drop unique indexes for global facts
op.drop_index("ix_facts_unique_triple_global", "facts")
# Drop tables in reverse order (dependencies first)
op.drop_table("memory_consolidation_log")
op.drop_table("procedures")
op.drop_table("facts")
op.drop_table("episodes")
op.drop_table("working_memory")
# Drop ENUM types
op.execute("DROP TYPE IF EXISTS consolidation_status")
op.execute("DROP TYPE IF EXISTS consolidation_type")
op.execute("DROP TYPE IF EXISTS episode_outcome")
op.execute("DROP TYPE IF EXISTS scope_type")

View File

@@ -0,0 +1,52 @@
"""Add ABANDONED to episode_outcome enum
Revision ID: 0006
Revises: 0005
Create Date: 2025-01-06
This migration adds the 'abandoned' value to the episode_outcome enum type.
This allows episodes to track when a task was abandoned (not completed,
but not necessarily a failure either - e.g., user cancelled, session timeout).
"""
from collections.abc import Sequence
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "0006"
down_revision: str | None = "0005"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Add 'abandoned' value to episode_outcome enum."""
# PostgreSQL ALTER TYPE ADD VALUE is safe and non-blocking
op.execute("ALTER TYPE episode_outcome ADD VALUE IF NOT EXISTS 'abandoned'")
def downgrade() -> None:
"""Remove 'abandoned' from episode_outcome enum.
Note: PostgreSQL doesn't support removing values from enums directly.
This downgrade converts any 'abandoned' episodes to 'failure' and
recreates the enum without 'abandoned'.
"""
# Convert any abandoned episodes to failure first
op.execute("""
UPDATE episodes
SET outcome = 'failure'
WHERE outcome = 'abandoned'
""")
# Recreate the enum without abandoned
# This is complex in PostgreSQL - requires creating new type, updating columns, dropping old
op.execute("ALTER TYPE episode_outcome RENAME TO episode_outcome_old")
op.execute("CREATE TYPE episode_outcome AS ENUM ('success', 'failure', 'partial')")
op.execute("""
ALTER TABLE episodes
ALTER COLUMN outcome TYPE episode_outcome
USING outcome::text::episode_outcome
""")
op.execute("DROP TYPE episode_outcome_old")

View File

@@ -0,0 +1,90 @@
"""Add category and display fields to agent_types table
Revision ID: 0007
Revises: 0006
Create Date: 2026-01-06
This migration adds:
- category: String(50) for grouping agents by role type
- icon: String(50) for Lucide icon identifier
- color: String(7) for hex color code
- sort_order: Integer for display ordering within categories
- typical_tasks: JSONB list of tasks this agent excels at
- collaboration_hints: JSONB list of agent slugs that work well together
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "0007"
down_revision: str | None = "0006"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Add category and display fields to agent_types table."""
# Add new columns
op.add_column(
"agent_types",
sa.Column("category", sa.String(length=50), nullable=True),
)
op.add_column(
"agent_types",
sa.Column("icon", sa.String(length=50), nullable=True, server_default="bot"),
)
op.add_column(
"agent_types",
sa.Column(
"color", sa.String(length=7), nullable=True, server_default="#3B82F6"
),
)
op.add_column(
"agent_types",
sa.Column("sort_order", sa.Integer(), nullable=False, server_default="0"),
)
op.add_column(
"agent_types",
sa.Column(
"typical_tasks",
postgresql.JSONB(astext_type=sa.Text()),
nullable=False,
server_default="[]",
),
)
op.add_column(
"agent_types",
sa.Column(
"collaboration_hints",
postgresql.JSONB(astext_type=sa.Text()),
nullable=False,
server_default="[]",
),
)
# Add indexes for category and sort_order
op.create_index("ix_agent_types_category", "agent_types", ["category"])
op.create_index("ix_agent_types_sort_order", "agent_types", ["sort_order"])
op.create_index(
"ix_agent_types_category_sort", "agent_types", ["category", "sort_order"]
)
def downgrade() -> None:
"""Remove category and display fields from agent_types table."""
# Drop indexes
op.drop_index("ix_agent_types_category_sort", table_name="agent_types")
op.drop_index("ix_agent_types_sort_order", table_name="agent_types")
op.drop_index("ix_agent_types_category", table_name="agent_types")
# Drop columns
op.drop_column("agent_types", "collaboration_hints")
op.drop_column("agent_types", "typical_tasks")
op.drop_column("agent_types", "sort_order")
op.drop_column("agent_types", "color")
op.drop_column("agent_types", "icon")
op.drop_column("agent_types", "category")

View File

@@ -81,6 +81,13 @@ def _build_agent_type_response(
mcp_servers=agent_type.mcp_servers,
tool_permissions=agent_type.tool_permissions,
is_active=agent_type.is_active,
# Category and display fields
category=agent_type.category,
icon=agent_type.icon,
color=agent_type.color,
sort_order=agent_type.sort_order,
typical_tasks=agent_type.typical_tasks or [],
collaboration_hints=agent_type.collaboration_hints or [],
created_at=agent_type.created_at,
updated_at=agent_type.updated_at,
instance_count=instance_count,
@@ -300,6 +307,7 @@ async def list_agent_types(
request: Request,
pagination: PaginationParams = Depends(),
is_active: bool = Query(True, description="Filter by active status"),
category: str | None = Query(None, description="Filter by category"),
search: str | None = Query(None, description="Search by name, slug, description"),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
@@ -314,6 +322,7 @@ async def list_agent_types(
request: FastAPI request object
pagination: Pagination parameters (page, limit)
is_active: Filter by active status (default: True)
category: Filter by category (e.g., "development", "design")
search: Optional search term for name, slug, description
current_user: Authenticated user
db: Database session
@@ -328,6 +337,7 @@ async def list_agent_types(
skip=pagination.offset,
limit=pagination.limit,
is_active=is_active,
category=category,
search=search,
)
@@ -354,6 +364,51 @@ async def list_agent_types(
raise
@router.get(
"/grouped",
response_model=dict[str, list[AgentTypeResponse]],
summary="List Agent Types Grouped by Category",
description="Get all agent types organized by category",
operation_id="list_agent_types_grouped",
)
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
async def list_agent_types_grouped(
request: Request,
is_active: bool = Query(True, description="Filter by active status"),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Get agent types grouped by category.
Returns a dictionary where keys are category names and values
are lists of agent types, sorted by sort_order within each category.
Args:
request: FastAPI request object
is_active: Filter by active status (default: True)
current_user: Authenticated user
db: Database session
Returns:
Dictionary mapping category to list of agent types
"""
try:
grouped = await agent_type_crud.get_grouped_by_category(db, is_active=is_active)
# Transform to response objects
result: dict[str, list[AgentTypeResponse]] = {}
for category, types in grouped.items():
result[category] = [
_build_agent_type_response(t, instance_count=0) for t in types
]
return result
except Exception as e:
logger.error(f"Error getting grouped agent types: {e!s}", exc_info=True)
raise
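For reference, a client call against the new /grouped route might look like the sketch below. The host, API prefix, and auth header are placeholders, since they depend on how this router is mounted in the application:

import httpx

# Placeholders - adjust host, API prefix, and auth scheme to the deployment.
BASE_URL = "http://localhost:8000/api/v1/agent-types"
headers = {"Authorization": "Bearer <access-token>"}

resp = httpx.get(f"{BASE_URL}/grouped", params={"is_active": True}, headers=headers)
resp.raise_for_status()

for category, agent_types in resp.json().items():
    # Within each category the server returns items ordered by sort_order.
    print(category, [at["slug"] for at in agent_types])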
@router.get(
"/{agent_type_id}",
response_model=AgentTypeResponse,

View File

@@ -1,366 +0,0 @@
{
"organizations": [
{
"name": "Acme Corp",
"slug": "acme-corp",
"description": "A leading provider of coyote-catching equipment."
},
{
"name": "Globex Corporation",
"slug": "globex",
"description": "We own the East Coast."
},
{
"name": "Soylent Corp",
"slug": "soylent",
"description": "Making food for the future."
},
{
"name": "Initech",
"slug": "initech",
"description": "Software for the soul."
},
{
"name": "Umbrella Corporation",
"slug": "umbrella",
"description": "Our business is life itself."
},
{
"name": "Massive Dynamic",
"slug": "massive-dynamic",
"description": "What don't we do?"
}
],
"users": [
{
"email": "demo@example.com",
"password": "DemoPass1234!",
"first_name": "Demo",
"last_name": "User",
"is_superuser": false,
"organization_slug": "acme-corp",
"role": "member",
"is_active": true
},
{
"email": "alice@acme.com",
"password": "Demo123!",
"first_name": "Alice",
"last_name": "Smith",
"is_superuser": false,
"organization_slug": "acme-corp",
"role": "admin",
"is_active": true
},
{
"email": "bob@acme.com",
"password": "Demo123!",
"first_name": "Bob",
"last_name": "Jones",
"is_superuser": false,
"organization_slug": "acme-corp",
"role": "member",
"is_active": true
},
{
"email": "charlie@acme.com",
"password": "Demo123!",
"first_name": "Charlie",
"last_name": "Brown",
"is_superuser": false,
"organization_slug": "acme-corp",
"role": "member",
"is_active": false
},
{
"email": "diana@acme.com",
"password": "Demo123!",
"first_name": "Diana",
"last_name": "Prince",
"is_superuser": false,
"organization_slug": "acme-corp",
"role": "member",
"is_active": true
},
{
"email": "carol@globex.com",
"password": "Demo123!",
"first_name": "Carol",
"last_name": "Williams",
"is_superuser": false,
"organization_slug": "globex",
"role": "owner",
"is_active": true
},
{
"email": "dan@globex.com",
"password": "Demo123!",
"first_name": "Dan",
"last_name": "Miller",
"is_superuser": false,
"organization_slug": "globex",
"role": "member",
"is_active": true
},
{
"email": "ellen@globex.com",
"password": "Demo123!",
"first_name": "Ellen",
"last_name": "Ripley",
"is_superuser": false,
"organization_slug": "globex",
"role": "member",
"is_active": true
},
{
"email": "fred@globex.com",
"password": "Demo123!",
"first_name": "Fred",
"last_name": "Flintstone",
"is_superuser": false,
"organization_slug": "globex",
"role": "member",
"is_active": true
},
{
"email": "dave@soylent.com",
"password": "Demo123!",
"first_name": "Dave",
"last_name": "Brown",
"is_superuser": false,
"organization_slug": "soylent",
"role": "member",
"is_active": true
},
{
"email": "gina@soylent.com",
"password": "Demo123!",
"first_name": "Gina",
"last_name": "Torres",
"is_superuser": false,
"organization_slug": "soylent",
"role": "member",
"is_active": true
},
{
"email": "harry@soylent.com",
"password": "Demo123!",
"first_name": "Harry",
"last_name": "Potter",
"is_superuser": false,
"organization_slug": "soylent",
"role": "admin",
"is_active": true
},
{
"email": "eve@initech.com",
"password": "Demo123!",
"first_name": "Eve",
"last_name": "Davis",
"is_superuser": false,
"organization_slug": "initech",
"role": "admin",
"is_active": true
},
{
"email": "iris@initech.com",
"password": "Demo123!",
"first_name": "Iris",
"last_name": "West",
"is_superuser": false,
"organization_slug": "initech",
"role": "member",
"is_active": true
},
{
"email": "jack@initech.com",
"password": "Demo123!",
"first_name": "Jack",
"last_name": "Sparrow",
"is_superuser": false,
"organization_slug": "initech",
"role": "member",
"is_active": false
},
{
"email": "frank@umbrella.com",
"password": "Demo123!",
"first_name": "Frank",
"last_name": "Miller",
"is_superuser": false,
"organization_slug": "umbrella",
"role": "member",
"is_active": true
},
{
"email": "george@umbrella.com",
"password": "Demo123!",
"first_name": "George",
"last_name": "Costanza",
"is_superuser": false,
"organization_slug": "umbrella",
"role": "member",
"is_active": false
},
{
"email": "kate@umbrella.com",
"password": "Demo123!",
"first_name": "Kate",
"last_name": "Bishop",
"is_superuser": false,
"organization_slug": "umbrella",
"role": "member",
"is_active": true
},
{
"email": "leo@massive.com",
"password": "Demo123!",
"first_name": "Leo",
"last_name": "Messi",
"is_superuser": false,
"organization_slug": "massive-dynamic",
"role": "owner",
"is_active": true
},
{
"email": "mary@massive.com",
"password": "Demo123!",
"first_name": "Mary",
"last_name": "Jane",
"is_superuser": false,
"organization_slug": "massive-dynamic",
"role": "member",
"is_active": true
},
{
"email": "nathan@massive.com",
"password": "Demo123!",
"first_name": "Nathan",
"last_name": "Drake",
"is_superuser": false,
"organization_slug": "massive-dynamic",
"role": "member",
"is_active": true
},
{
"email": "olivia@massive.com",
"password": "Demo123!",
"first_name": "Olivia",
"last_name": "Dunham",
"is_superuser": false,
"organization_slug": "massive-dynamic",
"role": "admin",
"is_active": true
},
{
"email": "peter@massive.com",
"password": "Demo123!",
"first_name": "Peter",
"last_name": "Parker",
"is_superuser": false,
"organization_slug": "massive-dynamic",
"role": "member",
"is_active": true
},
{
"email": "quinn@massive.com",
"password": "Demo123!",
"first_name": "Quinn",
"last_name": "Mallory",
"is_superuser": false,
"organization_slug": "massive-dynamic",
"role": "member",
"is_active": true
},
{
"email": "grace@example.com",
"password": "Demo123!",
"first_name": "Grace",
"last_name": "Hopper",
"is_superuser": false,
"organization_slug": null,
"role": null,
"is_active": true
},
{
"email": "heidi@example.com",
"password": "Demo123!",
"first_name": "Heidi",
"last_name": "Klum",
"is_superuser": false,
"organization_slug": null,
"role": null,
"is_active": true
},
{
"email": "ivan@example.com",
"password": "Demo123!",
"first_name": "Ivan",
"last_name": "Drago",
"is_superuser": false,
"organization_slug": null,
"role": null,
"is_active": false
},
{
"email": "rachel@example.com",
"password": "Demo123!",
"first_name": "Rachel",
"last_name": "Green",
"is_superuser": false,
"organization_slug": null,
"role": null,
"is_active": true
},
{
"email": "sam@example.com",
"password": "Demo123!",
"first_name": "Sam",
"last_name": "Wilson",
"is_superuser": false,
"organization_slug": null,
"role": null,
"is_active": true
},
{
"email": "tony@example.com",
"password": "Demo123!",
"first_name": "Tony",
"last_name": "Stark",
"is_superuser": false,
"organization_slug": null,
"role": null,
"is_active": true
},
{
"email": "una@example.com",
"password": "Demo123!",
"first_name": "Una",
"last_name": "Chin-Riley",
"is_superuser": false,
"organization_slug": null,
"role": null,
"is_active": false
},
{
"email": "victor@example.com",
"password": "Demo123!",
"first_name": "Victor",
"last_name": "Von Doom",
"is_superuser": false,
"organization_slug": null,
"role": null,
"is_active": true
},
{
"email": "wanda@example.com",
"password": "Demo123!",
"first_name": "Wanda",
"last_name": "Maximoff",
"is_superuser": false,
"organization_slug": null,
"role": null,
"is_active": true
}
]
}

View File

@@ -43,6 +43,13 @@ class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
mcp_servers=obj_in.mcp_servers,
tool_permissions=obj_in.tool_permissions,
is_active=obj_in.is_active,
# Category and display fields
category=obj_in.category.value if obj_in.category else None,
icon=obj_in.icon,
color=obj_in.color,
sort_order=obj_in.sort_order,
typical_tasks=obj_in.typical_tasks,
collaboration_hints=obj_in.collaboration_hints,
)
db.add(db_obj)
await db.commit()
@@ -68,6 +75,7 @@ class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
skip: int = 0,
limit: int = 100,
is_active: bool | None = None,
category: str | None = None,
search: str | None = None,
sort_by: str = "created_at",
sort_order: str = "desc",
@@ -85,6 +93,9 @@ class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
if is_active is not None:
query = query.where(AgentType.is_active == is_active)
if category:
query = query.where(AgentType.category == category)
if search:
search_filter = or_(
AgentType.name.ilike(f"%{search}%"),
@@ -162,6 +173,7 @@ class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
skip: int = 0,
limit: int = 100,
is_active: bool | None = None,
category: str | None = None,
search: str | None = None,
) -> tuple[list[dict[str, Any]], int]:
"""
@@ -177,6 +189,7 @@ class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
skip=skip,
limit=limit,
is_active=is_active,
category=category,
search=search,
)
@@ -260,6 +273,44 @@ class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
)
raise
async def get_grouped_by_category(
self,
db: AsyncSession,
*,
is_active: bool = True,
) -> dict[str, list[AgentType]]:
"""
Get agent types grouped by category, sorted by sort_order within each group.
Args:
db: Database session
is_active: Filter by active status (default: True)
Returns:
Dictionary mapping category to list of agent types
"""
try:
query = (
select(AgentType)
.where(AgentType.is_active == is_active)
.order_by(AgentType.category, AgentType.sort_order, AgentType.name)
)
result = await db.execute(query)
agent_types = list(result.scalars().all())
# Group by category
grouped: dict[str, list[AgentType]] = {}
for at in agent_types:
cat: str = str(at.category) if at.category else "uncategorized"
if cat not in grouped:
grouped[cat] = []
grouped[cat].append(at)
return grouped
except Exception as e:
logger.error(f"Error getting grouped agent types: {e!s}", exc_info=True)
raise
# Create a singleton instance for use across the application
agent_type = CRUDAgentType(AgentType)
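Server-side code can reach the new grouping helper through this singleton. A minimal sketch, assuming SessionLocal is an async session factory as the init script suggests:

from app.core.database import SessionLocal
from app.crud.syndarix.agent_type import agent_type as agent_type_crud


async def print_agent_catalog() -> None:
    async with SessionLocal() as session:
        grouped = await agent_type_crud.get_grouped_by_category(session, is_active=True)
        for category, types in grouped.items():
            # Entries arrive sorted by sort_order, then name, within each category.
            print(category, [t.slug for t in types])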

View File

@@ -3,27 +3,48 @@
Async database initialization script.
Creates the first superuser if configured and doesn't already exist.
Seeds default agent types (production data) and demo data (when DEMO_MODE is enabled).
"""
import asyncio
import json
import logging
import random
from datetime import UTC, datetime, timedelta
from datetime import UTC, date, datetime, timedelta
from pathlib import Path
from sqlalchemy import select, text
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.database import SessionLocal, engine
from app.crud.syndarix.agent_type import agent_type as agent_type_crud
from app.crud.user import user as user_crud
from app.models.organization import Organization
from app.models.syndarix import AgentInstance, AgentType, Issue, Project, Sprint
from app.models.syndarix.enums import (
AgentStatus,
AutonomyLevel,
ClientMode,
IssuePriority,
IssueStatus,
IssueType,
ProjectComplexity,
ProjectStatus,
SprintStatus,
)
from app.models.user import User
from app.models.user_organization import UserOrganization
from app.schemas.syndarix import AgentTypeCreate
from app.schemas.users import UserCreate
logger = logging.getLogger(__name__)
# Data file paths
DATA_DIR = Path(__file__).parent.parent / "data"
DEFAULT_AGENT_TYPES_PATH = DATA_DIR / "default_agent_types.json"
DEMO_DATA_PATH = DATA_DIR / "demo_data.json"
async def init_db() -> User | None:
"""
@@ -54,28 +75,29 @@ async def init_db() -> User | None:
if existing_user:
logger.info(f"Superuser already exists: {existing_user.email}")
return existing_user
else:
# Create superuser if doesn't exist
user_in = UserCreate(
email=superuser_email,
password=superuser_password,
first_name="Admin",
last_name="User",
is_superuser=True,
)
# Create superuser if doesn't exist
user_in = UserCreate(
email=superuser_email,
password=superuser_password,
first_name="Admin",
last_name="User",
is_superuser=True,
)
existing_user = await user_crud.create(session, obj_in=user_in)
await session.commit()
await session.refresh(existing_user)
logger.info(f"Created first superuser: {existing_user.email}")
user = await user_crud.create(session, obj_in=user_in)
await session.commit()
await session.refresh(user)
# ALWAYS load default agent types (production data)
await load_default_agent_types(session)
logger.info(f"Created first superuser: {user.email}")
# Create demo data if in demo mode
# Only load demo data if in demo mode
if settings.DEMO_MODE:
await load_demo_data(session)
return user
return existing_user
except Exception as e:
await session.rollback()
@@ -88,26 +110,96 @@ def _load_json_file(path: Path):
return json.load(f)
async def load_demo_data(session):
"""Load demo data from JSON file."""
demo_data_path = Path(__file__).parent / "core" / "demo_data.json"
if not demo_data_path.exists():
logger.warning(f"Demo data file not found: {demo_data_path}")
async def load_default_agent_types(session: AsyncSession) -> None:
"""
Load default agent types from JSON file.
    These are production defaults: created only if they don't already exist and
    never overwritten, so user customizations survive restarts and repeated seeding.
"""
if not DEFAULT_AGENT_TYPES_PATH.exists():
logger.warning(
f"Default agent types file not found: {DEFAULT_AGENT_TYPES_PATH}"
)
return
try:
# Use asyncio.to_thread to avoid blocking the event loop
data = await asyncio.to_thread(_load_json_file, demo_data_path)
data = await asyncio.to_thread(_load_json_file, DEFAULT_AGENT_TYPES_PATH)
# Create Organizations
org_map = {}
for org_data in data.get("organizations", []):
# Check if org exists
result = await session.execute(
text("SELECT * FROM organizations WHERE slug = :slug"),
{"slug": org_data["slug"]},
for agent_type_data in data:
slug = agent_type_data["slug"]
# Check if agent type already exists
existing = await agent_type_crud.get_by_slug(session, slug=slug)
if existing:
logger.debug(f"Agent type already exists: {agent_type_data['name']}")
continue
# Create the agent type
agent_type_in = AgentTypeCreate(
name=agent_type_data["name"],
slug=slug,
description=agent_type_data.get("description"),
expertise=agent_type_data.get("expertise", []),
personality_prompt=agent_type_data["personality_prompt"],
primary_model=agent_type_data["primary_model"],
fallback_models=agent_type_data.get("fallback_models", []),
model_params=agent_type_data.get("model_params", {}),
mcp_servers=agent_type_data.get("mcp_servers", []),
tool_permissions=agent_type_data.get("tool_permissions", {}),
is_active=agent_type_data.get("is_active", True),
# Category and display fields
category=agent_type_data.get("category"),
icon=agent_type_data.get("icon", "bot"),
color=agent_type_data.get("color", "#3B82F6"),
sort_order=agent_type_data.get("sort_order", 0),
typical_tasks=agent_type_data.get("typical_tasks", []),
collaboration_hints=agent_type_data.get("collaboration_hints", []),
)
existing_org = result.first()
await agent_type_crud.create(session, obj_in=agent_type_in)
logger.info(f"Created default agent type: {agent_type_data['name']}")
logger.info("Default agent types loaded successfully")
except Exception as e:
logger.error(f"Error loading default agent types: {e}")
raise
async def load_demo_data(session: AsyncSession) -> None:
"""
Load demo data from JSON file.
Only runs when DEMO_MODE is enabled. Creates demo organizations, users,
projects, sprints, agent instances, and issues.
"""
if not DEMO_DATA_PATH.exists():
logger.warning(f"Demo data file not found: {DEMO_DATA_PATH}")
return
try:
data = await asyncio.to_thread(_load_json_file, DEMO_DATA_PATH)
# Build lookup maps for FK resolution
org_map: dict[str, Organization] = {}
user_map: dict[str, User] = {}
project_map: dict[str, Project] = {}
sprint_map: dict[str, Sprint] = {} # key: "project_slug:sprint_number"
agent_type_map: dict[str, AgentType] = {}
agent_instance_map: dict[
str, AgentInstance
] = {} # key: "project_slug:agent_name"
# ========================
# 1. Create Organizations
# ========================
for org_data in data.get("organizations", []):
org_result = await session.execute(
select(Organization).where(Organization.slug == org_data["slug"])
)
existing_org = org_result.scalar_one_or_none()
if not existing_org:
org = Organization(
@@ -117,29 +209,20 @@ async def load_demo_data(session):
is_active=True,
)
session.add(org)
await session.flush() # Flush to get ID
org_map[org.slug] = org
await session.flush()
org_map[str(org.slug)] = org
logger.info(f"Created demo organization: {org.name}")
else:
# We can't easily get the ORM object from raw SQL result for map without querying again or mapping
# So let's just query it properly if we need it for relationships
# But for simplicity in this script, let's just assume we created it or it exists.
# To properly map for users, we need the ID.
# Let's use a simpler approach: just try to create, if slug conflict, skip.
pass
org_map[str(existing_org.slug)] = existing_org
# Re-query all orgs to build map for users
result = await session.execute(select(Organization))
orgs = result.scalars().all()
org_map = {org.slug: org for org in orgs}
# Create Users
# ========================
# 2. Create Users
# ========================
for user_data in data.get("users", []):
existing_user = await user_crud.get_by_email(
session, email=user_data["email"]
)
if not existing_user:
# Create user
user_in = UserCreate(
email=user_data["email"],
password=user_data["password"],
@@ -151,17 +234,13 @@ async def load_demo_data(session):
user = await user_crud.create(session, obj_in=user_in)
# Randomize created_at for demo data (last 30 days)
# This makes the charts look more realistic
days_ago = random.randint(0, 30) # noqa: S311
random_time = datetime.now(UTC) - timedelta(days=days_ago)
# Add some random hours/minutes variation
random_time = random_time.replace(
hour=random.randint(0, 23), # noqa: S311
minute=random.randint(0, 59), # noqa: S311
)
                # Write the randomized created_at and the configured is_active
                # directly with SQL, so the seeded values replace the defaults
                # applied when the user was created
await session.execute(
text(
"UPDATE users SET created_at = :created_at, is_active = :is_active WHERE id = :user_id"
@@ -174,7 +253,7 @@ async def load_demo_data(session):
)
logger.info(
f"Created demo user: {user.email} (created {days_ago} days ago, active={user_data.get('is_active', True)})"
f"Created demo user: {user.email} (created {days_ago} days ago)"
)
# Add to organization if specified
@@ -182,19 +261,228 @@ async def load_demo_data(session):
role = user_data.get("role")
if org_slug and org_slug in org_map and role:
org = org_map[org_slug]
                    # The user was just created, so no membership can exist yet
member = UserOrganization(
user_id=user.id, organization_id=org.id, role=role
)
session.add(member)
logger.info(f"Added {user.email} to {org.name} as {role}")
user_map[str(user.email)] = user
else:
logger.info(f"Demo user already exists: {existing_user.email}")
user_map[str(existing_user.email)] = existing_user
logger.debug(f"Demo user already exists: {existing_user.email}")
await session.flush()
# Add admin user to map with special "__admin__" key
# This allows demo data to reference the admin user as owner
superuser_email = settings.FIRST_SUPERUSER_EMAIL or "admin@example.com"
admin_user = await user_crud.get_by_email(session, email=superuser_email)
if admin_user:
user_map["__admin__"] = admin_user
user_map[str(admin_user.email)] = admin_user
logger.debug(f"Added admin user to map: {admin_user.email}")
# ========================
# 3. Load Agent Types Map (for FK resolution)
# ========================
agent_types_result = await session.execute(select(AgentType))
for at in agent_types_result.scalars().all():
agent_type_map[str(at.slug)] = at
# ========================
# 4. Create Projects
# ========================
for project_data in data.get("projects", []):
project_result = await session.execute(
select(Project).where(Project.slug == project_data["slug"])
)
existing_project = project_result.scalar_one_or_none()
if not existing_project:
# Resolve owner email to user ID
owner_id = None
owner_email = project_data.get("owner_email")
if owner_email and owner_email in user_map:
owner_id = user_map[owner_email].id
project = Project(
name=project_data["name"],
slug=project_data["slug"],
description=project_data.get("description"),
owner_id=owner_id,
autonomy_level=AutonomyLevel(
project_data.get("autonomy_level", "milestone")
),
status=ProjectStatus(project_data.get("status", "active")),
complexity=ProjectComplexity(
project_data.get("complexity", "medium")
),
client_mode=ClientMode(project_data.get("client_mode", "auto")),
settings=project_data.get("settings", {}),
)
session.add(project)
await session.flush()
project_map[str(project.slug)] = project
logger.info(f"Created demo project: {project.name}")
else:
project_map[str(existing_project.slug)] = existing_project
logger.debug(f"Demo project already exists: {existing_project.name}")
# ========================
# 5. Create Sprints
# ========================
for sprint_data in data.get("sprints", []):
project_slug = sprint_data["project_slug"]
sprint_number = sprint_data["number"]
sprint_key = f"{project_slug}:{sprint_number}"
if project_slug not in project_map:
logger.warning(f"Project not found for sprint: {project_slug}")
continue
sprint_project = project_map[project_slug]
# Check if sprint exists
sprint_result = await session.execute(
select(Sprint).where(
Sprint.project_id == sprint_project.id,
Sprint.number == sprint_number,
)
)
existing_sprint = sprint_result.scalar_one_or_none()
if not existing_sprint:
sprint = Sprint(
project_id=sprint_project.id,
name=sprint_data["name"],
number=sprint_number,
goal=sprint_data.get("goal"),
start_date=date.fromisoformat(sprint_data["start_date"]),
end_date=date.fromisoformat(sprint_data["end_date"]),
status=SprintStatus(sprint_data.get("status", "planned")),
planned_points=sprint_data.get("planned_points"),
)
session.add(sprint)
await session.flush()
sprint_map[sprint_key] = sprint
logger.info(
f"Created demo sprint: {sprint.name} for {sprint_project.name}"
)
else:
sprint_map[sprint_key] = existing_sprint
logger.debug(f"Demo sprint already exists: {existing_sprint.name}")
# ========================
# 6. Create Agent Instances
# ========================
for agent_data in data.get("agent_instances", []):
project_slug = agent_data["project_slug"]
agent_type_slug = agent_data["agent_type_slug"]
agent_name = agent_data["name"]
agent_key = f"{project_slug}:{agent_name}"
if project_slug not in project_map:
logger.warning(f"Project not found for agent: {project_slug}")
continue
if agent_type_slug not in agent_type_map:
logger.warning(f"Agent type not found: {agent_type_slug}")
continue
agent_project = project_map[project_slug]
agent_type = agent_type_map[agent_type_slug]
# Check if agent instance exists (by name within project)
agent_result = await session.execute(
select(AgentInstance).where(
AgentInstance.project_id == agent_project.id,
AgentInstance.name == agent_name,
)
)
existing_agent = agent_result.scalar_one_or_none()
if not existing_agent:
agent_instance = AgentInstance(
project_id=agent_project.id,
agent_type_id=agent_type.id,
name=agent_name,
status=AgentStatus(agent_data.get("status", "idle")),
current_task=agent_data.get("current_task"),
)
session.add(agent_instance)
await session.flush()
agent_instance_map[agent_key] = agent_instance
logger.info(
f"Created demo agent: {agent_name} ({agent_type.name}) "
f"for {agent_project.name}"
)
else:
agent_instance_map[agent_key] = existing_agent
logger.debug(f"Demo agent already exists: {existing_agent.name}")
# ========================
# 7. Create Issues
# ========================
for issue_data in data.get("issues", []):
project_slug = issue_data["project_slug"]
if project_slug not in project_map:
logger.warning(f"Project not found for issue: {project_slug}")
continue
issue_project = project_map[project_slug]
# Check if issue exists (by title within project - simple heuristic)
issue_result = await session.execute(
select(Issue).where(
Issue.project_id == issue_project.id,
Issue.title == issue_data["title"],
)
)
existing_issue = issue_result.scalar_one_or_none()
if not existing_issue:
# Resolve sprint
sprint_id = None
sprint_number = issue_data.get("sprint_number")
if sprint_number:
sprint_key = f"{project_slug}:{sprint_number}"
if sprint_key in sprint_map:
sprint_id = sprint_map[sprint_key].id
# Resolve assigned agent
assigned_agent_id = None
assigned_agent_name = issue_data.get("assigned_agent_name")
if assigned_agent_name:
agent_key = f"{project_slug}:{assigned_agent_name}"
if agent_key in agent_instance_map:
assigned_agent_id = agent_instance_map[agent_key].id
issue = Issue(
project_id=issue_project.id,
sprint_id=sprint_id,
type=IssueType(issue_data.get("type", "task")),
title=issue_data["title"],
body=issue_data.get("body", ""),
status=IssueStatus(issue_data.get("status", "open")),
priority=IssuePriority(issue_data.get("priority", "medium")),
labels=issue_data.get("labels", []),
story_points=issue_data.get("story_points"),
assigned_agent_id=assigned_agent_id,
)
session.add(issue)
logger.info(f"Created demo issue: {issue.title[:50]}...")
else:
logger.debug(
f"Demo issue already exists: {existing_issue.title[:50]}..."
)
await session.commit()
logger.info("Demo data loaded successfully")
except Exception as e:
await session.rollback()
logger.error(f"Error loading demo data: {e}")
raise
@@ -210,12 +498,12 @@ async def main():
try:
user = await init_db()
if user:
print("Database initialized successfully")
print(f"Superuser: {user.email}")
print("Database initialized successfully")
print(f"Superuser: {user.email}")
else:
print("Failed to initialize database")
print("Failed to initialize database")
except Exception as e:
print(f"Error initializing database: {e}")
print(f"Error initializing database: {e}")
raise
finally:
# Close the engine

View File

@@ -8,6 +8,19 @@ from app.core.database import Base
from .base import TimestampMixin, UUIDMixin
# Memory system models
from .memory import (
ConsolidationStatus,
ConsolidationType,
Episode,
EpisodeOutcome,
Fact,
MemoryConsolidationLog,
Procedure,
ScopeType,
WorkingMemory,
)
# OAuth models (client mode - authenticate via Google/GitHub)
from .oauth_account import OAuthAccount
@@ -37,7 +50,14 @@ __all__ = [
"AgentInstance",
"AgentType",
"Base",
# Memory models
"ConsolidationStatus",
"ConsolidationType",
"Episode",
"EpisodeOutcome",
"Fact",
"Issue",
"MemoryConsolidationLog",
"OAuthAccount",
"OAuthAuthorizationCode",
"OAuthClient",
@@ -46,11 +66,14 @@ __all__ = [
"OAuthState",
"Organization",
"OrganizationRole",
"Procedure",
"Project",
"ScopeType",
"Sprint",
"TimestampMixin",
"UUIDMixin",
"User",
"UserOrganization",
"UserSession",
"WorkingMemory",
]

View File

@@ -0,0 +1,32 @@
# app/models/memory/__init__.py
"""
Memory System Database Models.
Provides SQLAlchemy models for the Agent Memory System:
- WorkingMemory: Key-value storage with TTL
- Episode: Experiential memories
- Fact: Semantic knowledge triples
- Procedure: Learned skills
- MemoryConsolidationLog: Consolidation job tracking
"""
from .consolidation import MemoryConsolidationLog
from .enums import ConsolidationStatus, ConsolidationType, EpisodeOutcome, ScopeType
from .episode import Episode
from .fact import Fact
from .procedure import Procedure
from .working_memory import WorkingMemory
__all__ = [
# Enums
"ConsolidationStatus",
"ConsolidationType",
# Models
"Episode",
"EpisodeOutcome",
"Fact",
"MemoryConsolidationLog",
"Procedure",
"ScopeType",
"WorkingMemory",
]

View File

@@ -0,0 +1,72 @@
# app/models/memory/consolidation.py
"""
Memory Consolidation Log database model.
Tracks memory consolidation jobs that transfer knowledge
between memory tiers.
"""
from sqlalchemy import Column, DateTime, Enum, Index, Integer, Text
from app.models.base import Base, TimestampMixin, UUIDMixin
from .enums import ConsolidationStatus, ConsolidationType
class MemoryConsolidationLog(Base, UUIDMixin, TimestampMixin):
"""
Memory consolidation job log.
Tracks consolidation operations:
- Working -> Episodic (session end)
- Episodic -> Semantic (fact extraction)
- Episodic -> Procedural (procedure learning)
- Pruning (removing low-value memories)
"""
__tablename__ = "memory_consolidation_log"
# Consolidation type
    consolidation_type: Column[ConsolidationType] = Column(
        # Store lowercase member values to match the consolidation_type enum
        # created by the migration
        Enum(
            ConsolidationType,
            name="consolidation_type",
            values_callable=lambda x: [e.value for e in x],
        ),
        nullable=False,
        index=True,
    )
# Counts
source_count = Column(Integer, nullable=False, default=0)
result_count = Column(Integer, nullable=False, default=0)
# Timing
started_at = Column(DateTime(timezone=True), nullable=False)
completed_at = Column(DateTime(timezone=True), nullable=True)
# Status
    status: Column[ConsolidationStatus] = Column(
        Enum(
            ConsolidationStatus,
            name="consolidation_status",
            values_callable=lambda x: [e.value for e in x],
        ),
        nullable=False,
        default=ConsolidationStatus.PENDING,
        index=True,
    )
# Error details if failed
error = Column(Text, nullable=True)
__table_args__ = (
# Query patterns
Index("ix_consolidation_type_status", "consolidation_type", "status"),
Index("ix_consolidation_started", "started_at"),
)
@property
def duration_seconds(self) -> float | None:
"""Calculate duration of the consolidation job."""
if self.completed_at is None or self.started_at is None:
return None
return (self.completed_at - self.started_at).total_seconds()
def __repr__(self) -> str:
return (
f"<MemoryConsolidationLog {self.id} "
f"type={self.consolidation_type.value} status={self.status.value}>"
)

View File

@@ -0,0 +1,73 @@
# app/models/memory/enums.py
"""
Enums for Memory System database models.
These enums define the database-level constraints for memory types
and scoping levels.
"""
from enum import Enum as PyEnum
class ScopeType(str, PyEnum):
"""
Memory scope levels matching the memory service types.
GLOBAL: System-wide memories accessible by all
PROJECT: Project-scoped memories
AGENT_TYPE: Type-specific memories (shared by instances of same type)
AGENT_INSTANCE: Instance-specific memories
SESSION: Session-scoped ephemeral memories
"""
GLOBAL = "global"
PROJECT = "project"
AGENT_TYPE = "agent_type"
AGENT_INSTANCE = "agent_instance"
SESSION = "session"
class EpisodeOutcome(str, PyEnum):
"""
Outcome of an episode (task execution).
SUCCESS: Task completed successfully
FAILURE: Task failed
PARTIAL: Task partially completed
"""
SUCCESS = "success"
FAILURE = "failure"
PARTIAL = "partial"
class ConsolidationType(str, PyEnum):
"""
Types of memory consolidation operations.
WORKING_TO_EPISODIC: Transfer session state to episodic
EPISODIC_TO_SEMANTIC: Extract facts from episodes
EPISODIC_TO_PROCEDURAL: Extract procedures from episodes
PRUNING: Remove low-value memories
"""
WORKING_TO_EPISODIC = "working_to_episodic"
EPISODIC_TO_SEMANTIC = "episodic_to_semantic"
EPISODIC_TO_PROCEDURAL = "episodic_to_procedural"
PRUNING = "pruning"
class ConsolidationStatus(str, PyEnum):
"""
Status of a consolidation job.
PENDING: Job is queued
RUNNING: Job is currently executing
COMPLETED: Job finished successfully
FAILED: Job failed with errors
"""
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"

View File

@@ -0,0 +1,139 @@
# app/models/memory/episode.py
"""
Episode database model.
Stores experiential memories - records of past task executions
with context, actions, outcomes, and lessons learned.
"""
from sqlalchemy import (
BigInteger,
CheckConstraint,
Column,
DateTime,
Enum,
Float,
ForeignKey,
Index,
String,
Text,
)
from sqlalchemy.dialects.postgresql import (
JSONB,
UUID as PGUUID,
)
from sqlalchemy.orm import relationship
from app.models.base import Base, TimestampMixin, UUIDMixin
from .enums import EpisodeOutcome
# Import pgvector type - will be available after migration enables extension
try:
from pgvector.sqlalchemy import Vector # type: ignore[import-not-found]
except ImportError:
# Fallback for environments without pgvector
Vector = None
class Episode(Base, UUIDMixin, TimestampMixin):
"""
Episodic memory model.
Records experiential memories from agent task execution:
- What task was performed
- What actions were taken
- What was the outcome
- What lessons were learned
"""
__tablename__ = "episodes"
# Foreign keys
project_id = Column(
PGUUID(as_uuid=True),
ForeignKey("projects.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
agent_instance_id = Column(
PGUUID(as_uuid=True),
ForeignKey("agent_instances.id", ondelete="SET NULL"),
nullable=True,
index=True,
)
agent_type_id = Column(
PGUUID(as_uuid=True),
ForeignKey("agent_types.id", ondelete="SET NULL"),
nullable=True,
index=True,
)
# Session reference
session_id = Column(String(255), nullable=False, index=True)
# Task information
task_type = Column(String(100), nullable=False, index=True)
task_description = Column(Text, nullable=False)
# Actions taken (list of action dictionaries)
actions = Column(JSONB, default=list, nullable=False)
# Context summary
context_summary = Column(Text, nullable=False)
# Outcome
    outcome: Column[EpisodeOutcome] = Column(
        # Store lowercase member values to match the episode_outcome enum
        # created by the migrations
        Enum(
            EpisodeOutcome,
            name="episode_outcome",
            values_callable=lambda x: [e.value for e in x],
        ),
        nullable=False,
        index=True,
    )
outcome_details = Column(Text, nullable=True)
# Metrics
duration_seconds = Column(Float, nullable=False, default=0.0)
tokens_used = Column(BigInteger, nullable=False, default=0)
# Learning
lessons_learned = Column(JSONB, default=list, nullable=False)
importance_score = Column(Float, nullable=False, default=0.5, index=True)
# Vector embedding for semantic search
# Using 1536 dimensions for OpenAI text-embedding-3-small
embedding = Column(Vector(1536) if Vector else Text, nullable=True)
# When the episode occurred
occurred_at = Column(DateTime(timezone=True), nullable=False, index=True)
# Relationships
project = relationship("Project", foreign_keys=[project_id])
agent_instance = relationship("AgentInstance", foreign_keys=[agent_instance_id])
agent_type = relationship("AgentType", foreign_keys=[agent_type_id])
__table_args__ = (
# Primary query patterns
Index("ix_episodes_project_task", "project_id", "task_type"),
Index("ix_episodes_project_outcome", "project_id", "outcome"),
Index("ix_episodes_agent_task", "agent_instance_id", "task_type"),
Index("ix_episodes_project_time", "project_id", "occurred_at"),
# For importance-based pruning
Index("ix_episodes_importance_time", "importance_score", "occurred_at"),
# Data integrity constraints
CheckConstraint(
"importance_score >= 0.0 AND importance_score <= 1.0",
name="ck_episodes_importance_range",
),
CheckConstraint(
"duration_seconds >= 0.0",
name="ck_episodes_duration_positive",
),
CheckConstraint(
"tokens_used >= 0",
name="ck_episodes_tokens_positive",
),
)
def __repr__(self) -> str:
return f"<Episode {self.id} task={self.task_type} outcome={self.outcome.value}>"

View File

@@ -0,0 +1,120 @@
# app/models/memory/fact.py
"""
Fact database model.
Stores semantic memories - learned facts in subject-predicate-object
triple format with confidence scores and source tracking.
"""
from sqlalchemy import (
CheckConstraint,
Column,
DateTime,
Float,
ForeignKey,
Index,
Integer,
String,
Text,
text,
)
from sqlalchemy.dialects.postgresql import (
JSONB,
UUID as PGUUID,
)
from sqlalchemy.orm import relationship
from app.models.base import Base, TimestampMixin, UUIDMixin
# Import pgvector type
try:
from pgvector.sqlalchemy import Vector # type: ignore[import-not-found]
except ImportError:
Vector = None
class Fact(Base, UUIDMixin, TimestampMixin):
"""
Semantic memory model.
Stores learned facts as subject-predicate-object triples:
- "FastAPI" - "uses" - "Starlette framework"
- "Project Alpha" - "requires" - "OAuth authentication"
Facts have confidence scores that decay over time and can be
reinforced when the same fact is learned again.
"""
__tablename__ = "facts"
# Scoping: project_id is NULL for global facts
project_id = Column(
PGUUID(as_uuid=True),
ForeignKey("projects.id", ondelete="CASCADE"),
nullable=True,
index=True,
)
# Triple format
subject = Column(String(500), nullable=False, index=True)
predicate = Column(String(255), nullable=False, index=True)
object = Column(Text, nullable=False)
# Confidence score (0.0 to 1.0)
confidence = Column(Float, nullable=False, default=0.8, index=True)
# Source tracking: which episodes contributed to this fact (stored as JSONB array of UUID strings)
source_episode_ids: Column[list] = Column(JSONB, default=list, nullable=False)
# Learning history
first_learned = Column(DateTime(timezone=True), nullable=False)
last_reinforced = Column(DateTime(timezone=True), nullable=False)
reinforcement_count = Column(Integer, nullable=False, default=1)
# Vector embedding for semantic search
embedding = Column(Vector(1536) if Vector else Text, nullable=True)
# Relationships
project = relationship("Project", foreign_keys=[project_id])
__table_args__ = (
# Unique constraint on triple within project scope
Index(
"ix_facts_unique_triple",
"project_id",
"subject",
"predicate",
"object",
unique=True,
postgresql_where=text("project_id IS NOT NULL"),
),
# Unique constraint on triple for global facts (project_id IS NULL)
Index(
"ix_facts_unique_triple_global",
"subject",
"predicate",
"object",
unique=True,
postgresql_where=text("project_id IS NULL"),
),
# Query patterns
Index("ix_facts_subject_predicate", "subject", "predicate"),
Index("ix_facts_project_subject", "project_id", "subject"),
Index("ix_facts_confidence_time", "confidence", "last_reinforced"),
# Note: subject already has index=True on Column definition, no need for explicit index
# Data integrity constraints
CheckConstraint(
"confidence >= 0.0 AND confidence <= 1.0",
name="ck_facts_confidence_range",
),
CheckConstraint(
"reinforcement_count >= 1",
name="ck_facts_reinforcement_positive",
),
)
def __repr__(self) -> str:
return (
f"<Fact {self.id} '{self.subject}' - '{self.predicate}' - "
f"'{self.object[:50]}...' conf={self.confidence:.2f}>"
)
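The first_learned / last_reinforced / reinforcement_count trio implies an upsert-style flow. The helper below is a hypothetical sketch of that flow for global facts, not the memory service's actual API:

from datetime import UTC, datetime

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.memory import Fact


async def learn_fact(
    session: AsyncSession, subject: str, predicate: str, obj: str
) -> Fact:
    """Create the fact if new, otherwise reinforce the existing triple."""
    now = datetime.now(UTC)
    result = await session.execute(
        select(Fact).where(
            Fact.project_id.is_(None),  # global scope for this sketch
            Fact.subject == subject,
            Fact.predicate == predicate,
            Fact.object == obj,
        )
    )
    fact = result.scalar_one_or_none()
    if fact is None:
        fact = Fact(
            subject=subject,
            predicate=predicate,
            object=obj,
            confidence=0.8,
            first_learned=now,
            last_reinforced=now,
            reinforcement_count=1,
        )
        session.add(fact)
    else:
        # Reinforcement: bump the counter and nudge confidence toward 1.0.
        fact.reinforcement_count += 1
        fact.last_reinforced = now
        fact.confidence = min(1.0, fact.confidence + 0.05)
    await session.flush()
    return fact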

View File

@@ -0,0 +1,129 @@
# app/models/memory/procedure.py
"""
Procedure database model.
Stores procedural memories - learned skills and procedures
derived from successful task execution patterns.
"""
from sqlalchemy import (
CheckConstraint,
Column,
DateTime,
ForeignKey,
Index,
Integer,
String,
Text,
)
from sqlalchemy.dialects.postgresql import (
JSONB,
UUID as PGUUID,
)
from sqlalchemy.orm import relationship
from app.models.base import Base, TimestampMixin, UUIDMixin
# Import pgvector type
try:
from pgvector.sqlalchemy import Vector # type: ignore[import-not-found]
except ImportError:
Vector = None
class Procedure(Base, UUIDMixin, TimestampMixin):
"""
Procedural memory model.
Stores learned procedures (skills) extracted from successful
task execution patterns:
- Name and trigger pattern for matching
- Step-by-step actions
- Success/failure tracking
"""
__tablename__ = "procedures"
# Scoping
project_id = Column(
PGUUID(as_uuid=True),
ForeignKey("projects.id", ondelete="CASCADE"),
nullable=True,
index=True,
)
agent_type_id = Column(
PGUUID(as_uuid=True),
ForeignKey("agent_types.id", ondelete="SET NULL"),
nullable=True,
index=True,
)
# Procedure identification
name = Column(String(255), nullable=False, index=True)
trigger_pattern = Column(Text, nullable=False)
# Steps as JSON array of step objects
# Each step: {order, action, parameters, expected_outcome, fallback_action}
steps = Column(JSONB, default=list, nullable=False)
# Success tracking
success_count = Column(Integer, nullable=False, default=0)
failure_count = Column(Integer, nullable=False, default=0)
# Usage tracking
last_used = Column(DateTime(timezone=True), nullable=True, index=True)
# Vector embedding for semantic matching
embedding = Column(Vector(1536) if Vector else Text, nullable=True)
# Relationships
project = relationship("Project", foreign_keys=[project_id])
agent_type = relationship("AgentType", foreign_keys=[agent_type_id])
__table_args__ = (
# Unique procedure name within scope
Index(
"ix_procedures_unique_name",
"project_id",
"agent_type_id",
"name",
unique=True,
),
# Query patterns
Index("ix_procedures_project_name", "project_id", "name"),
# Note: agent_type_id already has index=True on Column definition
# For finding best procedures
Index("ix_procedures_success_rate", "success_count", "failure_count"),
# Data integrity constraints
CheckConstraint(
"success_count >= 0",
name="ck_procedures_success_positive",
),
CheckConstraint(
"failure_count >= 0",
name="ck_procedures_failure_positive",
),
)
@property
def success_rate(self) -> float:
"""Calculate the success rate of this procedure."""
        # Read both counters once so the ratio is computed from a consistent pair
success = self.success_count
failure = self.failure_count
total = success + failure
if total == 0:
return 0.0
return success / total
@property
def total_uses(self) -> int:
"""Get total number of times this procedure was used."""
# Snapshot values for consistency
return self.success_count + self.failure_count
def __repr__(self) -> str:
return (
f"<Procedure {self.name} ({self.id}) success_rate={self.success_rate:.2%}>"
)
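Recording an outcome against a procedure only needs the counters, since success_rate is derived from them. A small hypothetical helper:

from app.models.memory import Procedure


def record_outcome(procedure: Procedure, succeeded: bool) -> float:
    """Increment the relevant counter and return the updated success rate."""
    if succeeded:
        procedure.success_count += 1
    else:
        procedure.failure_count += 1
    # success_rate and total_uses are derived, so no extra bookkeeping is needed.
    return procedure.success_rate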

View File

@@ -0,0 +1,58 @@
# app/models/memory/working_memory.py
"""
Working Memory database model.
Stores ephemeral key-value data for active sessions with TTL support.
Used as database backup when Redis is unavailable.
"""
from sqlalchemy import Column, DateTime, Enum, Index, String
from sqlalchemy.dialects.postgresql import JSONB
from app.models.base import Base, TimestampMixin, UUIDMixin
from .enums import ScopeType
class WorkingMemory(Base, UUIDMixin, TimestampMixin):
"""
Working memory storage table.
Provides database-backed working memory as fallback when
Redis is unavailable. Supports TTL-based expiration.
"""
__tablename__ = "working_memory"
# Scoping
    scope_type: Column[ScopeType] = Column(
        # Store lowercase member values to match the scope_type enum created
        # by the migration
        Enum(
            ScopeType,
            name="scope_type",
            values_callable=lambda x: [e.value for e in x],
        ),
        nullable=False,
        index=True,
    )
scope_id = Column(String(255), nullable=False, index=True)
# Key-value storage
key = Column(String(255), nullable=False)
value = Column(JSONB, nullable=False)
# TTL support
expires_at = Column(DateTime(timezone=True), nullable=True, index=True)
__table_args__ = (
# Primary lookup: scope + key
Index(
"ix_working_memory_scope_key",
"scope_type",
"scope_id",
"key",
unique=True,
),
# For cleanup of expired entries
Index("ix_working_memory_expires", "expires_at"),
# For listing all keys in a scope
Index("ix_working_memory_scope_list", "scope_type", "scope_id"),
)
def __repr__(self) -> str:
return f"<WorkingMemory {self.scope_type.value}:{self.scope_id}:{self.key}>"

View File

@@ -62,7 +62,11 @@ class AgentInstance(Base, UUIDMixin, TimestampMixin):
# Status tracking
status: Column[AgentStatus] = Column(
Enum(AgentStatus),
Enum(
AgentStatus,
name="agent_status",
values_callable=lambda x: [e.value for e in x],
),
default=AgentStatus.IDLE,
nullable=False,
index=True,
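The name= and values_callable arguments matter because SQLAlchemy's Enum otherwise derives the PostgreSQL type name from the class name and stores member names rather than values; the same pattern recurs in the Issue, Project, and Sprint models below. A minimal, self-contained illustration of the difference:

import enum

import sqlalchemy as sa


class Color(str, enum.Enum):
    RED = "red"


# Default behaviour: labels come from the member names -> ('RED',)
print(sa.Enum(Color).enums)

# With values_callable: labels come from the member values -> ('red',)
print(sa.Enum(Color, values_callable=lambda x: [e.value for e in x]).enums)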

View File

@@ -6,7 +6,7 @@ An AgentType is a template that defines the capabilities, personality,
and model configuration for agent instances.
"""
from sqlalchemy import Boolean, Column, Index, String, Text
from sqlalchemy import Boolean, Column, Index, Integer, String, Text
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import relationship
@@ -56,6 +56,24 @@ class AgentType(Base, UUIDMixin, TimestampMixin):
# Whether this agent type is available for new instances
is_active = Column(Boolean, default=True, nullable=False, index=True)
# Category for grouping agents (development, design, quality, etc.)
category = Column(String(50), nullable=True, index=True)
# Lucide icon identifier for UI display (e.g., "code", "palette", "shield")
icon = Column(String(50), nullable=True, default="bot")
# Hex color code for visual distinction (e.g., "#3B82F6")
color = Column(String(7), nullable=True, default="#3B82F6")
# Display ordering within category (lower = first)
sort_order = Column(Integer, nullable=False, default=0, index=True)
# List of typical tasks this agent excels at
typical_tasks = Column(JSONB, default=list, nullable=False)
# List of agent slugs that collaborate well with this type
collaboration_hints = Column(JSONB, default=list, nullable=False)
# Relationships
instances = relationship(
"AgentInstance",
@@ -66,6 +84,7 @@ class AgentType(Base, UUIDMixin, TimestampMixin):
__table_args__ = (
Index("ix_agent_types_slug_active", "slug", "is_active"),
Index("ix_agent_types_name_active", "name", "is_active"),
Index("ix_agent_types_category_sort", "category", "sort_order"),
)
def __repr__(self) -> str:

View File

@@ -167,3 +167,29 @@ class SprintStatus(str, PyEnum):
IN_REVIEW = "in_review"
COMPLETED = "completed"
CANCELLED = "cancelled"
class AgentTypeCategory(str, PyEnum):
"""
Category classification for agent types.
Used for grouping and filtering agents in the UI.
DEVELOPMENT: Product, project, and engineering roles
DESIGN: UI/UX and design research roles
QUALITY: QA and security engineering
OPERATIONS: DevOps and MLOps
AI_ML: Machine learning and AI specialists
DATA: Data science and engineering
LEADERSHIP: Technical leadership roles
DOMAIN_EXPERT: Industry and domain specialists
"""
DEVELOPMENT = "development"
DESIGN = "design"
QUALITY = "quality"
OPERATIONS = "operations"
AI_ML = "ai_ml"
DATA = "data"
LEADERSHIP = "leadership"
DOMAIN_EXPERT = "domain_expert"

View File

@@ -59,7 +59,9 @@ class Issue(Base, UUIDMixin, TimestampMixin):
# Issue type (Epic, Story, Task, Bug)
type: Column[IssueType] = Column(
Enum(IssueType),
Enum(
IssueType, name="issue_type", values_callable=lambda x: [e.value for e in x]
),
default=IssueType.TASK,
nullable=False,
index=True,
@@ -78,14 +80,22 @@ class Issue(Base, UUIDMixin, TimestampMixin):
# Status and priority
status: Column[IssueStatus] = Column(
Enum(IssueStatus),
Enum(
IssueStatus,
name="issue_status",
values_callable=lambda x: [e.value for e in x],
),
default=IssueStatus.OPEN,
nullable=False,
index=True,
)
priority: Column[IssuePriority] = Column(
Enum(IssuePriority),
Enum(
IssuePriority,
name="issue_priority",
values_callable=lambda x: [e.value for e in x],
),
default=IssuePriority.MEDIUM,
nullable=False,
index=True,
@@ -132,7 +142,11 @@ class Issue(Base, UUIDMixin, TimestampMixin):
# Sync status with external tracker
sync_status: Column[SyncStatus] = Column(
Enum(SyncStatus),
Enum(
SyncStatus,
name="sync_status",
values_callable=lambda x: [e.value for e in x],
),
default=SyncStatus.SYNCED,
nullable=False,
# Note: Index defined in __table_args__ as ix_issues_sync_status

View File

@@ -35,28 +35,44 @@ class Project(Base, UUIDMixin, TimestampMixin):
description = Column(Text, nullable=True)
autonomy_level: Column[AutonomyLevel] = Column(
Enum(AutonomyLevel),
Enum(
AutonomyLevel,
name="autonomy_level",
values_callable=lambda x: [e.value for e in x],
),
default=AutonomyLevel.MILESTONE,
nullable=False,
index=True,
)
status: Column[ProjectStatus] = Column(
Enum(ProjectStatus),
Enum(
ProjectStatus,
name="project_status",
values_callable=lambda x: [e.value for e in x],
),
default=ProjectStatus.ACTIVE,
nullable=False,
index=True,
)
complexity: Column[ProjectComplexity] = Column(
Enum(ProjectComplexity),
Enum(
ProjectComplexity,
name="project_complexity",
values_callable=lambda x: [e.value for e in x],
),
default=ProjectComplexity.MEDIUM,
nullable=False,
index=True,
)
client_mode: Column[ClientMode] = Column(
Enum(ClientMode),
Enum(
ClientMode,
name="client_mode",
values_callable=lambda x: [e.value for e in x],
),
default=ClientMode.AUTO,
nullable=False,
index=True,

View File

@@ -57,7 +57,11 @@ class Sprint(Base, UUIDMixin, TimestampMixin):
# Status
status: Column[SprintStatus] = Column(
Enum(SprintStatus),
Enum(
SprintStatus,
name="sprint_status",
values_callable=lambda x: [e.value for e in x],
),
default=SprintStatus.PLANNED,
nullable=False,
index=True,

View File

@@ -10,6 +10,8 @@ from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field, field_validator
from app.models.syndarix.enums import AgentTypeCategory
class AgentTypeBase(BaseModel):
"""Base agent type schema with common fields."""
@@ -26,6 +28,14 @@ class AgentTypeBase(BaseModel):
tool_permissions: dict[str, Any] = Field(default_factory=dict)
is_active: bool = True
# Category and display fields
category: AgentTypeCategory | None = None
icon: str | None = Field(None, max_length=50)
color: str | None = Field(None, pattern=r"^#[0-9A-Fa-f]{6}$")
sort_order: int = Field(default=0, ge=0, le=1000)
typical_tasks: list[str] = Field(default_factory=list)
collaboration_hints: list[str] = Field(default_factory=list)
@field_validator("slug")
@classmethod
def validate_slug(cls, v: str | None) -> str | None:
@@ -62,6 +72,18 @@ class AgentTypeBase(BaseModel):
"""Validate MCP server list."""
return [s.strip() for s in v if s.strip()]
@field_validator("typical_tasks")
@classmethod
def validate_typical_tasks(cls, v: list[str]) -> list[str]:
"""Validate and normalize typical tasks list."""
return [t.strip() for t in v if t.strip()]
@field_validator("collaboration_hints")
@classmethod
def validate_collaboration_hints(cls, v: list[str]) -> list[str]:
"""Validate and normalize collaboration hints (agent slugs)."""
return [h.strip().lower() for h in v if h.strip()]
class AgentTypeCreate(AgentTypeBase):
"""Schema for creating a new agent type."""
@@ -87,6 +109,14 @@ class AgentTypeUpdate(BaseModel):
tool_permissions: dict[str, Any] | None = None
is_active: bool | None = None
# Category and display fields (all optional for updates)
category: AgentTypeCategory | None = None
icon: str | None = Field(None, max_length=50)
color: str | None = Field(None, pattern=r"^#[0-9A-Fa-f]{6}$")
sort_order: int | None = Field(None, ge=0, le=1000)
typical_tasks: list[str] | None = None
collaboration_hints: list[str] | None = None
@field_validator("slug")
@classmethod
def validate_slug(cls, v: str | None) -> str | None:
@@ -119,6 +149,22 @@ class AgentTypeUpdate(BaseModel):
return v
return [e.strip().lower() for e in v if e.strip()]
@field_validator("typical_tasks")
@classmethod
def validate_typical_tasks(cls, v: list[str] | None) -> list[str] | None:
"""Validate and normalize typical tasks list."""
if v is None:
return v
return [t.strip() for t in v if t.strip()]
@field_validator("collaboration_hints")
@classmethod
def validate_collaboration_hints(cls, v: list[str] | None) -> list[str] | None:
"""Validate and normalize collaboration hints (agent slugs)."""
if v is None:
return v
return [h.strip().lower() for h in v if h.strip()]
class AgentTypeInDB(AgentTypeBase):
"""Schema for agent type in database."""

View File

@@ -114,6 +114,8 @@ from .types import (
ContextType,
ConversationContext,
KnowledgeContext,
MemoryContext,
MemorySubtype,
MessageRole,
SystemContext,
TaskComplexity,
@@ -149,6 +151,8 @@ __all__ = [
"FormattingError",
"InvalidContextError",
"KnowledgeContext",
"MemoryContext",
"MemorySubtype",
"MessageRole",
"ModelAdapter",
"OpenAIAdapter",

View File

@@ -30,6 +30,7 @@ class TokenBudget:
knowledge: int = 0
conversation: int = 0
tools: int = 0
memory: int = 0 # Agent memory (working, episodic, semantic, procedural)
response_reserve: int = 0
buffer: int = 0
@@ -60,6 +61,7 @@ class TokenBudget:
"knowledge": self.knowledge,
"conversation": self.conversation,
"tool": self.tools,
"memory": self.memory,
}
return allocation_map.get(context_type, 0)
@@ -211,6 +213,7 @@ class TokenBudget:
"knowledge": self.knowledge,
"conversation": self.conversation,
"tools": self.tools,
"memory": self.memory,
"response_reserve": self.response_reserve,
"buffer": self.buffer,
},
@@ -264,9 +267,10 @@ class BudgetAllocator:
total=total_tokens,
system=int(total_tokens * alloc.get("system", 0.05)),
task=int(total_tokens * alloc.get("task", 0.10)),
knowledge=int(total_tokens * alloc.get("knowledge", 0.40)),
conversation=int(total_tokens * alloc.get("conversation", 0.20)),
knowledge=int(total_tokens * alloc.get("knowledge", 0.30)),
conversation=int(total_tokens * alloc.get("conversation", 0.15)),
tools=int(total_tokens * alloc.get("tools", 0.05)),
memory=int(total_tokens * alloc.get("memory", 0.15)),
response_reserve=int(total_tokens * alloc.get("response", 0.15)),
buffer=int(total_tokens * alloc.get("buffer", 0.05)),
)
@@ -317,6 +321,8 @@ class BudgetAllocator:
budget.conversation = max(0, budget.conversation + actual_adjustment)
elif context_type == "tool":
budget.tools = max(0, budget.tools + actual_adjustment)
elif context_type == "memory":
budget.memory = max(0, budget.memory + actual_adjustment)
return budget
@@ -338,7 +344,12 @@ class BudgetAllocator:
Rebalanced budget
"""
if prioritize is None:
prioritize = [ContextType.KNOWLEDGE, ContextType.TASK, ContextType.SYSTEM]
prioritize = [
ContextType.KNOWLEDGE,
ContextType.MEMORY,
ContextType.TASK,
ContextType.SYSTEM,
]
# Calculate unused tokens per type
unused: dict[str, int] = {}
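For orientation, the revised default shares above (system 0.05, task 0.10, knowledge 0.30, conversation 0.15, tools 0.05, memory 0.15, response 0.15, buffer 0.05) still sum to 1.0. A small arithmetic sketch with a hypothetical total; the real allocator truncates each share with int(), so buckets can come out a token short:

# Default allocation fractions as in allocate() above, applied to an example budget.
DEFAULT_SHARES = {
    "system": 0.05, "task": 0.10, "knowledge": 0.30, "conversation": 0.15,
    "tools": 0.05, "memory": 0.15, "response": 0.15, "buffer": 0.05,
}
assert abs(sum(DEFAULT_SHARES.values()) - 1.0) < 1e-9

total_tokens = 32_000
split = {name: int(total_tokens * share) for name, share in DEFAULT_SHARES.items()}
# int() truncation can shave a token off some buckets, e.g. int(32_000 * 0.30) == 9_599,
# so the allocated buckets may sum to slightly less than total_tokens.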

View File

@@ -7,6 +7,7 @@ Provides a high-level API for assembling optimized context for LLM requests.
import logging
from typing import TYPE_CHECKING, Any
from uuid import UUID
from .assembly import ContextPipeline
from .budget import BudgetAllocator, TokenBudget, TokenCalculator
@@ -20,6 +21,7 @@ from .types import (
BaseContext,
ConversationContext,
KnowledgeContext,
MemoryContext,
MessageRole,
SystemContext,
TaskContext,
@@ -30,6 +32,7 @@ if TYPE_CHECKING:
from redis.asyncio import Redis
from app.services.mcp.client_manager import MCPClientManager
from app.services.memory.integration import MemoryContextSource
logger = logging.getLogger(__name__)
@@ -64,6 +67,7 @@ class ContextEngine:
mcp_manager: "MCPClientManager | None" = None,
redis: "Redis | None" = None,
settings: ContextSettings | None = None,
memory_source: "MemoryContextSource | None" = None,
) -> None:
"""
Initialize the context engine.
@@ -72,9 +76,11 @@ class ContextEngine:
mcp_manager: MCP client manager for LLM Gateway/Knowledge Base
redis: Redis connection for caching
settings: Context settings
memory_source: Optional memory context source for agent memory
"""
self._mcp = mcp_manager
self._settings = settings or get_context_settings()
self._memory_source = memory_source
# Initialize components
self._calculator = TokenCalculator(mcp_manager=mcp_manager)
@@ -115,6 +121,15 @@ class ContextEngine:
"""
self._cache.set_redis(redis)
def set_memory_source(self, memory_source: "MemoryContextSource") -> None:
"""
Set memory context source for agent memory integration.
Args:
memory_source: Memory context source
"""
self._memory_source = memory_source
async def assemble_context(
self,
project_id: str,
@@ -126,6 +141,10 @@ class ContextEngine:
task_description: str | None = None,
knowledge_query: str | None = None,
knowledge_limit: int = 10,
memory_query: str | None = None,
memory_limit: int = 20,
session_id: str | None = None,
agent_type_id: str | None = None,
conversation_history: list[dict[str, str]] | None = None,
tool_results: list[dict[str, Any]] | None = None,
custom_contexts: list[BaseContext] | None = None,
@@ -151,6 +170,10 @@ class ContextEngine:
task_description: Current task description
knowledge_query: Query for knowledge base search
knowledge_limit: Max number of knowledge results
memory_query: Query for agent memory search
memory_limit: Max number of memory results
session_id: Session ID for working memory access
agent_type_id: Agent type ID for procedural memory
conversation_history: List of {"role": str, "content": str}
tool_results: List of tool results to include
custom_contexts: Additional custom contexts
@@ -197,15 +220,27 @@ class ContextEngine:
)
contexts.extend(knowledge_contexts)
# 4. Conversation history
# 4. Memory context from Agent Memory System
if memory_query and self._memory_source:
memory_contexts = await self._fetch_memory(
project_id=project_id,
agent_id=agent_id,
query=memory_query,
limit=memory_limit,
session_id=session_id,
agent_type_id=agent_type_id,
)
contexts.extend(memory_contexts)
# 5. Conversation history
if conversation_history:
contexts.extend(self._convert_conversation(conversation_history))
# 5. Tool results
# 6. Tool results
if tool_results:
contexts.extend(self._convert_tool_results(tool_results))
# 6. Custom contexts
# 7. Custom contexts
if custom_contexts:
contexts.extend(custom_contexts)
@@ -308,6 +343,65 @@ class ContextEngine:
logger.warning(f"Failed to fetch knowledge: {e}")
return []
async def _fetch_memory(
self,
project_id: str,
agent_id: str,
query: str,
limit: int = 20,
session_id: str | None = None,
agent_type_id: str | None = None,
) -> list[MemoryContext]:
"""
Fetch relevant memories from Agent Memory System.
Args:
project_id: Project identifier
agent_id: Agent identifier
query: Search query
limit: Maximum results
session_id: Session ID for working memory
agent_type_id: Agent type ID for procedural memory
Returns:
List of MemoryContext instances
"""
if not self._memory_source:
return []
try:
# Import here to avoid circular imports
from app.services.memory.integration.context_source import MemoryFetchConfig
# Configure fetch limits
config = MemoryFetchConfig(
working_limit=min(limit // 4, 5),
episodic_limit=min(limit // 2, 10),
semantic_limit=min(limit // 2, 10),
procedural_limit=min(limit // 4, 5),
include_working=session_id is not None,
)
result = await self._memory_source.fetch_context(
query=query,
project_id=UUID(project_id),
agent_instance_id=UUID(agent_id) if agent_id else None,
agent_type_id=UUID(agent_type_id) if agent_type_id else None,
session_id=session_id,
config=config,
)
logger.debug(
f"Fetched {len(result.contexts)} memory contexts for query: {query}, "
f"by_type: {result.by_type}"
)
return result.contexts[:limit]
except Exception as e:
logger.warning(f"Failed to fetch memory: {e}")
return []
def _convert_conversation(
self,
history: list[dict[str, str]],
@@ -466,6 +560,7 @@ def create_context_engine(
mcp_manager: "MCPClientManager | None" = None,
redis: "Redis | None" = None,
settings: ContextSettings | None = None,
memory_source: "MemoryContextSource | None" = None,
) -> ContextEngine:
"""
Create a context engine instance.
@@ -474,6 +569,7 @@ def create_context_engine(
mcp_manager: MCP client manager
redis: Redis connection
settings: Context settings
memory_source: Optional memory context source
Returns:
Configured ContextEngine instance
@@ -482,4 +578,5 @@ def create_context_engine(
mcp_manager=mcp_manager,
redis=redis,
settings=settings,
memory_source=memory_source,
)
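As a usage sketch (a hypothetical call site; the IDs, the engine wiring, and the agent_id keyword are assumptions not shown verbatim in the hunks above): memory retrieval only runs when a memory_query is supplied and a memory source is attached, and with the default memory_limit=20 the MemoryFetchConfig above splits into working 5, episodic 10, semantic 10 and procedural 5, with working memory included only when a session_id is present.

from uuid import uuid4

async def build_prompt_context(engine: "ContextEngine") -> None:
    # Hypothetical call site; engine is assumed to have been created via
    # create_context_engine(..., memory_source=...) or set_memory_source().
    assembled = await engine.assemble_context(
        project_id=str(uuid4()),
        agent_id=str(uuid4()),  # assumed keyword; not shown in the hunk above
        task_description="Deploy the payments service",
        knowledge_query="payments deployment runbook",
        memory_query="previous payments deployment failures",
        memory_limit=20,  # -> working 5, episodic 10, semantic 10, procedural 5
        session_id="sess-123",  # enables working-memory inclusion
        agent_type_id=str(uuid4()),  # enables procedural-memory lookup
        conversation_history=[{"role": "user", "content": "Ship it"}],
    )
    _ = assembled  # downstream formatting/assembly is handled by the engine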

View File

@@ -15,6 +15,10 @@ from .conversation import (
MessageRole,
)
from .knowledge import KnowledgeContext
from .memory import (
MemoryContext,
MemorySubtype,
)
from .system import SystemContext
from .task import (
TaskComplexity,
@@ -33,6 +37,8 @@ __all__ = [
"ContextType",
"ConversationContext",
"KnowledgeContext",
"MemoryContext",
"MemorySubtype",
"MessageRole",
"SystemContext",
"TaskComplexity",

View File

@@ -26,6 +26,7 @@ class ContextType(str, Enum):
KNOWLEDGE = "knowledge"
CONVERSATION = "conversation"
TOOL = "tool"
MEMORY = "memory" # Agent memory (working, episodic, semantic, procedural)
@classmethod
def from_string(cls, value: str) -> "ContextType":

View File

@@ -0,0 +1,282 @@
"""
Memory Context Type.
Represents agent memory as context for LLM requests.
Includes working, episodic, semantic, and procedural memories.
"""
from dataclasses import dataclass, field
from datetime import UTC, datetime
from enum import Enum
from typing import Any
from .base import BaseContext, ContextPriority, ContextType
class MemorySubtype(str, Enum):
"""Types of agent memory."""
WORKING = "working" # Session-scoped temporary data
EPISODIC = "episodic" # Task history and outcomes
SEMANTIC = "semantic" # Facts and knowledge
PROCEDURAL = "procedural" # Learned procedures
@dataclass(eq=False)
class MemoryContext(BaseContext):
"""
Context from agent memory system.
Memory context represents data retrieved from the agent
memory system, including:
- Working memory: Current session state
- Episodic memory: Past task experiences
- Semantic memory: Learned facts and knowledge
- Procedural memory: Known procedures and workflows
Each memory item includes relevance scoring from search.
"""
# Memory-specific fields
memory_subtype: MemorySubtype = field(default=MemorySubtype.EPISODIC)
memory_id: str | None = field(default=None)
relevance_score: float = field(default=0.0)
importance: float = field(default=0.5)
search_query: str = field(default="")
# Type-specific fields (populated based on memory_subtype)
key: str | None = field(default=None) # For working memory
task_type: str | None = field(default=None) # For episodic
outcome: str | None = field(default=None) # For episodic
subject: str | None = field(default=None) # For semantic
predicate: str | None = field(default=None) # For semantic
object_value: str | None = field(default=None) # For semantic
trigger: str | None = field(default=None) # For procedural
success_rate: float | None = field(default=None) # For procedural
def get_type(self) -> ContextType:
"""Return MEMORY context type."""
return ContextType.MEMORY
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary with memory-specific fields."""
base = super().to_dict()
base.update(
{
"memory_subtype": self.memory_subtype.value,
"memory_id": self.memory_id,
"relevance_score": self.relevance_score,
"importance": self.importance,
"search_query": self.search_query,
"key": self.key,
"task_type": self.task_type,
"outcome": self.outcome,
"subject": self.subject,
"predicate": self.predicate,
"object_value": self.object_value,
"trigger": self.trigger,
"success_rate": self.success_rate,
}
)
return base
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "MemoryContext":
"""Create MemoryContext from dictionary."""
return cls(
id=data.get("id", ""),
content=data["content"],
source=data["source"],
timestamp=datetime.fromisoformat(data["timestamp"])
if isinstance(data.get("timestamp"), str)
else data.get("timestamp", datetime.now(UTC)),
priority=data.get("priority", ContextPriority.NORMAL.value),
metadata=data.get("metadata", {}),
memory_subtype=MemorySubtype(data.get("memory_subtype", "episodic")),
memory_id=data.get("memory_id"),
relevance_score=data.get("relevance_score", 0.0),
importance=data.get("importance", 0.5),
search_query=data.get("search_query", ""),
key=data.get("key"),
task_type=data.get("task_type"),
outcome=data.get("outcome"),
subject=data.get("subject"),
predicate=data.get("predicate"),
object_value=data.get("object_value"),
trigger=data.get("trigger"),
success_rate=data.get("success_rate"),
)
@classmethod
def from_working_memory(
cls,
key: str,
value: Any,
source: str = "working_memory",
query: str = "",
) -> "MemoryContext":
"""
Create MemoryContext from working memory entry.
Args:
key: Working memory key
value: Value stored at key
source: Source identifier
query: Search query used
Returns:
MemoryContext instance
"""
return cls(
content=str(value),
source=source,
memory_subtype=MemorySubtype.WORKING,
key=key,
relevance_score=1.0, # Working memory is always relevant
importance=0.8, # Higher importance for current session state
search_query=query,
priority=ContextPriority.HIGH.value,
)
@classmethod
def from_episodic_memory(
cls,
episode: Any,
query: str = "",
) -> "MemoryContext":
"""
Create MemoryContext from episodic memory episode.
Args:
episode: Episode object from episodic memory
query: Search query used
Returns:
MemoryContext instance
"""
outcome_val = None
if hasattr(episode, "outcome") and episode.outcome:
outcome_val = (
episode.outcome.value
if hasattr(episode.outcome, "value")
else str(episode.outcome)
)
return cls(
content=episode.task_description,
source=f"episodic:{episode.id}",
memory_subtype=MemorySubtype.EPISODIC,
memory_id=str(episode.id),
relevance_score=getattr(episode, "importance_score", 0.5),
importance=getattr(episode, "importance_score", 0.5),
search_query=query,
task_type=getattr(episode, "task_type", None),
outcome=outcome_val,
metadata={
"session_id": getattr(episode, "session_id", None),
"occurred_at": episode.occurred_at.isoformat()
if hasattr(episode, "occurred_at") and episode.occurred_at
else None,
"lessons_learned": getattr(episode, "lessons_learned", []),
},
)
@classmethod
def from_semantic_memory(
cls,
fact: Any,
query: str = "",
) -> "MemoryContext":
"""
Create MemoryContext from semantic memory fact.
Args:
fact: Fact object from semantic memory
query: Search query used
Returns:
MemoryContext instance
"""
triple = f"{fact.subject} {fact.predicate} {fact.object}"
return cls(
content=triple,
source=f"semantic:{fact.id}",
memory_subtype=MemorySubtype.SEMANTIC,
memory_id=str(fact.id),
relevance_score=getattr(fact, "confidence", 0.5),
importance=getattr(fact, "confidence", 0.5),
search_query=query,
subject=fact.subject,
predicate=fact.predicate,
object_value=fact.object,
priority=ContextPriority.NORMAL.value,
)
@classmethod
def from_procedural_memory(
cls,
procedure: Any,
query: str = "",
) -> "MemoryContext":
"""
Create MemoryContext from procedural memory procedure.
Args:
procedure: Procedure object from procedural memory
query: Search query used
Returns:
MemoryContext instance
"""
# Format steps as content
steps = getattr(procedure, "steps", [])
steps_content = "\n".join(
f" {i + 1}. {step.get('action', step) if isinstance(step, dict) else step}"
for i, step in enumerate(steps)
)
content = f"Procedure: {procedure.name}\nTrigger: {procedure.trigger_pattern}\nSteps:\n{steps_content}"
return cls(
content=content,
source=f"procedural:{procedure.id}",
memory_subtype=MemorySubtype.PROCEDURAL,
memory_id=str(procedure.id),
relevance_score=getattr(procedure, "success_rate", 0.5),
importance=0.7, # Procedures are moderately important
search_query=query,
trigger=procedure.trigger_pattern,
success_rate=getattr(procedure, "success_rate", None),
metadata={
"steps_count": len(steps),
"execution_count": getattr(procedure, "success_count", 0)
+ getattr(procedure, "failure_count", 0),
},
)
def is_working_memory(self) -> bool:
"""Check if this is working memory."""
return self.memory_subtype == MemorySubtype.WORKING
def is_episodic_memory(self) -> bool:
"""Check if this is episodic memory."""
return self.memory_subtype == MemorySubtype.EPISODIC
def is_semantic_memory(self) -> bool:
"""Check if this is semantic memory."""
return self.memory_subtype == MemorySubtype.SEMANTIC
def is_procedural_memory(self) -> bool:
"""Check if this is procedural memory."""
return self.memory_subtype == MemorySubtype.PROCEDURAL
def get_formatted_source(self) -> str:
"""
Get a formatted source string for display.
Returns:
Formatted source string
"""
parts = [f"[{self.memory_subtype.value}]", self.source]
if self.memory_id:
parts.append(f"({self.memory_id[:8]}...)")
return " ".join(parts)
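A short illustration of the factory classmethods above (the fact object is a stand-in built with SimpleNamespace; real Episode/Fact/Procedure types live in the memory service):

from types import SimpleNamespace

wm_ctx = MemoryContext.from_working_memory(
    key="current_branch",
    value={"name": "feature/memory-cache"},
    query="what branch am I on",
)
assert wm_ctx.is_working_memory()
assert wm_ctx.relevance_score == 1.0  # working memory is always relevant

fact = SimpleNamespace(
    id="3f6d1c2e-0000-0000-0000-000000000000",
    subject="llm-gateway",
    predicate="depends_on",
    object="redis",
    confidence=0.9,
)
fact_ctx = MemoryContext.from_semantic_memory(fact, query="gateway dependencies")
print(fact_ctx.get_formatted_source())
# -> [semantic] semantic:3f6d1c2e-0000-0000-0000-000000000000 (3f6d1c2e...)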

View File

@@ -122,16 +122,24 @@ class MCPClientManager:
)
async def _connect_all_servers(self) -> None:
"""Connect to all enabled MCP servers."""
"""Connect to all enabled MCP servers concurrently."""
import asyncio
enabled_servers = self._registry.get_enabled_configs()
for name, config in enabled_servers.items():
async def connect_server(name: str, config: "MCPServerConfig") -> None:
try:
await self._pool.get_connection(name, config)
logger.info("Connected to MCP server: %s", name)
except Exception as e:
logger.error("Failed to connect to MCP server %s: %s", name, e)
# Connect to all servers concurrently for faster startup
await asyncio.gather(
*(connect_server(name, config) for name, config in enabled_servers.items()),
return_exceptions=True,
)
async def shutdown(self) -> None:
"""
Shutdown the MCP client manager.

View File
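A minimal standalone sketch of the concurrency pattern introduced here, with generic names rather than the project's classes: asyncio.gather with return_exceptions=True lets every connection attempt proceed and surfaces failures as values instead of cancelling the rest.

import asyncio

async def _connect(name: str) -> str:
    if name == "bad":
        raise RuntimeError(f"cannot reach {name}")
    await asyncio.sleep(0.01)  # simulate connection latency
    return name

async def main() -> None:
    results = await asyncio.gather(
        *(_connect(n) for n in ("llm-gateway", "knowledge-base", "bad")),
        return_exceptions=True,
    )
    # Failures come back as exception objects; successful connections as values.
    print(results)  # ['llm-gateway', 'knowledge-base', RuntimeError('cannot reach bad')]

asyncio.run(main())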

@@ -179,6 +179,8 @@ def load_mcp_config(path: str | Path | None = None) -> MCPConfig:
2. MCP_CONFIG_PATH environment variable
3. Default path (backend/mcp_servers.yaml)
4. Empty config if no file exists
In test mode (IS_TEST=True), retry settings are reduced for faster tests.
"""
if path is None:
path = os.environ.get("MCP_CONFIG_PATH", str(DEFAULT_CONFIG_PATH))
@@ -189,7 +191,18 @@ def load_mcp_config(path: str | Path | None = None) -> MCPConfig:
# Return empty config if no file exists (allows runtime registration)
return MCPConfig()
return MCPConfig.from_yaml(path)
config = MCPConfig.from_yaml(path)
# In test mode, reduce retry settings to speed up tests
is_test = os.environ.get("IS_TEST", "").lower() in ("true", "1", "yes")
if is_test:
for server_config in config.mcp_servers.values():
server_config.retry_attempts = 1 # Single attempt
server_config.retry_delay = 0.1 # 100ms instead of 1s
server_config.retry_max_delay = 0.5 # 500ms max
server_config.timeout = 2 # 2s timeout instead of 30-120s
return config
def create_default_config() -> MCPConfig:

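A tiny standalone sketch of the test-mode check used above (the environment variable name IS_TEST comes from the diff; everything else is illustrative):

import os

def _is_test_mode() -> bool:
    # Same truthiness rule as in load_mcp_config: "true", "1" or "yes", case-insensitive.
    return os.environ.get("IS_TEST", "").lower() in ("true", "1", "yes")

os.environ["IS_TEST"] = "True"
assert _is_test_mode()

os.environ["IS_TEST"] = "0"
assert not _is_test_mode()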
View File

@@ -0,0 +1,141 @@
"""
Agent Memory System
Multi-tier cognitive memory for AI agents, providing:
- Working Memory: Session-scoped ephemeral state (Redis/In-memory)
- Episodic Memory: Experiential records of past tasks (PostgreSQL)
- Semantic Memory: Learned facts and knowledge (PostgreSQL + pgvector)
- Procedural Memory: Learned skills and procedures (PostgreSQL)
Usage:
from app.services.memory import (
MemoryManager,
MemorySettings,
get_memory_settings,
MemoryType,
ScopeLevel,
)
# Create a manager for a session
manager = MemoryManager.for_session(
session_id="sess-123",
project_id=uuid,
)
async with manager:
# Working memory
await manager.set_working("key", {"data": "value"})
value = await manager.get_working("key")
# Episodic memory
episode = await manager.record_episode(episode_data)
similar = await manager.search_episodes("query")
# Semantic memory
fact = await manager.store_fact(fact_data)
facts = await manager.search_facts("query")
# Procedural memory
procedure = await manager.record_procedure(procedure_data)
procedures = await manager.find_procedures("context")
"""
# Configuration
from .config import (
MemorySettings,
get_default_settings,
get_memory_settings,
reset_memory_settings,
)
# Exceptions
from .exceptions import (
CheckpointError,
EmbeddingError,
MemoryCapacityError,
MemoryConflictError,
MemoryConsolidationError,
MemoryError,
MemoryExpiredError,
MemoryNotFoundError,
MemoryRetrievalError,
MemoryScopeError,
MemorySerializationError,
MemoryStorageError,
)
# Manager
from .manager import MemoryManager
# Types
from .types import (
ConsolidationStatus,
ConsolidationType,
Episode,
EpisodeCreate,
Fact,
FactCreate,
MemoryItem,
MemoryStats,
MemoryStore,
MemoryType,
Outcome,
Procedure,
ProcedureCreate,
RetrievalResult,
ScopeContext,
ScopeLevel,
Step,
TaskState,
WorkingMemoryItem,
)
# Reflection (lazy import available)
# Import directly: from app.services.memory.reflection import MemoryReflection
__all__ = [
"CheckpointError",
"ConsolidationStatus",
"ConsolidationType",
"EmbeddingError",
"Episode",
"EpisodeCreate",
"Fact",
"FactCreate",
"MemoryCapacityError",
"MemoryConflictError",
"MemoryConsolidationError",
# Exceptions
"MemoryError",
"MemoryExpiredError",
"MemoryItem",
# Manager
"MemoryManager",
"MemoryNotFoundError",
"MemoryRetrievalError",
"MemoryScopeError",
"MemorySerializationError",
# Configuration
"MemorySettings",
"MemoryStats",
"MemoryStorageError",
# Types - Abstract
"MemoryStore",
# Types - Enums
"MemoryType",
"Outcome",
"Procedure",
"ProcedureCreate",
"RetrievalResult",
# Types - Data Classes
"ScopeContext",
"ScopeLevel",
"Step",
"TaskState",
"WorkingMemoryItem",
"get_default_settings",
"get_memory_settings",
"reset_memory_settings",
# MCP Tools - lazy import to avoid circular dependencies
# Import directly: from app.services.memory.mcp import MemoryToolService
]

View File

@@ -0,0 +1,21 @@
# app/services/memory/cache/__init__.py
"""
Memory Caching Layer.
Provides caching for memory operations:
- Hot Memory Cache: LRU cache for frequently accessed memories
- Embedding Cache: Cache embeddings by content hash
- Cache Manager: Unified cache management with invalidation
"""
from .cache_manager import CacheManager, CacheStats, get_cache_manager
from .embedding_cache import EmbeddingCache
from .hot_cache import HotMemoryCache
__all__ = [
"CacheManager",
"CacheStats",
"EmbeddingCache",
"HotMemoryCache",
"get_cache_manager",
]

View File

@@ -0,0 +1,505 @@
# app/services/memory/cache/cache_manager.py
"""
Cache Manager.
Unified cache management for memory operations.
Coordinates hot cache, embedding cache, and retrieval cache.
Provides centralized invalidation and statistics.
"""
import logging
import threading
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any
from uuid import UUID
from app.services.memory.config import get_memory_settings
from .embedding_cache import EmbeddingCache, create_embedding_cache
from .hot_cache import CacheKey, HotMemoryCache, create_hot_cache
if TYPE_CHECKING:
from redis.asyncio import Redis
from app.services.memory.indexing.retrieval import RetrievalCache
logger = logging.getLogger(__name__)
def _utcnow() -> datetime:
"""Get current UTC time as timezone-aware datetime."""
return datetime.now(UTC)
@dataclass
class CacheStats:
"""Aggregated cache statistics."""
hot_cache: dict[str, Any] = field(default_factory=dict)
embedding_cache: dict[str, Any] = field(default_factory=dict)
retrieval_cache: dict[str, Any] = field(default_factory=dict)
overall_hit_rate: float = 0.0
last_cleanup: datetime | None = None
cleanup_count: int = 0
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"hot_cache": self.hot_cache,
"embedding_cache": self.embedding_cache,
"retrieval_cache": self.retrieval_cache,
"overall_hit_rate": self.overall_hit_rate,
"last_cleanup": self.last_cleanup.isoformat()
if self.last_cleanup
else None,
"cleanup_count": self.cleanup_count,
}
class CacheManager:
"""
Unified cache manager for memory operations.
Provides:
- Centralized cache configuration
- Coordinated invalidation across caches
- Aggregated statistics
- Automatic cleanup scheduling
Performance targets:
- Overall cache hit rate > 80%
- Cache operations < 1ms (memory), < 5ms (Redis)
"""
def __init__(
self,
hot_cache: HotMemoryCache[Any] | None = None,
embedding_cache: EmbeddingCache | None = None,
retrieval_cache: "RetrievalCache | None" = None,
redis: "Redis | None" = None,
) -> None:
"""
Initialize the cache manager.
Args:
hot_cache: Optional pre-configured hot cache
embedding_cache: Optional pre-configured embedding cache
retrieval_cache: Optional pre-configured retrieval cache
redis: Optional Redis connection for persistence
"""
self._settings = get_memory_settings()
self._redis = redis
self._enabled = self._settings.cache_enabled
# Initialize caches
if hot_cache:
self._hot_cache = hot_cache
else:
self._hot_cache = create_hot_cache(
max_size=self._settings.cache_max_items,
default_ttl_seconds=self._settings.cache_ttl_seconds,
)
if embedding_cache:
self._embedding_cache = embedding_cache
else:
self._embedding_cache = create_embedding_cache(
max_size=self._settings.cache_max_items,
default_ttl_seconds=self._settings.cache_ttl_seconds
* 12, # 1hr for embeddings
redis=redis,
)
self._retrieval_cache = retrieval_cache
# Stats tracking
self._last_cleanup: datetime | None = None
self._cleanup_count = 0
self._lock = threading.RLock()
logger.info(
f"Initialized CacheManager: enabled={self._enabled}, "
f"redis={'connected' if redis else 'disabled'}"
)
def set_redis(self, redis: "Redis") -> None:
"""Set Redis connection for all caches."""
self._redis = redis
self._embedding_cache.set_redis(redis)
def set_retrieval_cache(self, cache: "RetrievalCache") -> None:
"""Set retrieval cache instance."""
self._retrieval_cache = cache
@property
def is_enabled(self) -> bool:
"""Check if caching is enabled."""
return self._enabled
@property
def hot_cache(self) -> HotMemoryCache[Any]:
"""Get the hot memory cache."""
return self._hot_cache
@property
def embedding_cache(self) -> EmbeddingCache:
"""Get the embedding cache."""
return self._embedding_cache
@property
def retrieval_cache(self) -> "RetrievalCache | None":
"""Get the retrieval cache."""
return self._retrieval_cache
# =========================================================================
# Hot Memory Cache Operations
# =========================================================================
def get_memory(
self,
memory_type: str,
memory_id: UUID | str,
scope: str | None = None,
) -> Any | None:
"""
Get a memory from hot cache.
Args:
memory_type: Type of memory
memory_id: Memory ID
scope: Optional scope
Returns:
Cached memory or None
"""
if not self._enabled:
return None
return self._hot_cache.get_by_id(memory_type, memory_id, scope)
def cache_memory(
self,
memory_type: str,
memory_id: UUID | str,
memory: Any,
scope: str | None = None,
ttl_seconds: float | None = None,
) -> None:
"""
Cache a memory in hot cache.
Args:
memory_type: Type of memory
memory_id: Memory ID
memory: Memory object
scope: Optional scope
ttl_seconds: Optional TTL override
"""
if not self._enabled:
return
self._hot_cache.put_by_id(memory_type, memory_id, memory, scope, ttl_seconds)
# =========================================================================
# Embedding Cache Operations
# =========================================================================
async def get_embedding(
self,
content: str,
model: str = "default",
) -> list[float] | None:
"""
Get a cached embedding.
Args:
content: Content text
model: Model name
Returns:
Cached embedding or None
"""
if not self._enabled:
return None
return await self._embedding_cache.get(content, model)
async def cache_embedding(
self,
content: str,
embedding: list[float],
model: str = "default",
ttl_seconds: float | None = None,
) -> str:
"""
Cache an embedding.
Args:
content: Content text
embedding: Embedding vector
model: Model name
ttl_seconds: Optional TTL override
Returns:
Content hash
"""
if not self._enabled:
return EmbeddingCache.hash_content(content)
return await self._embedding_cache.put(content, embedding, model, ttl_seconds)
# =========================================================================
# Invalidation
# =========================================================================
async def invalidate_memory(
self,
memory_type: str,
memory_id: UUID | str,
scope: str | None = None,
) -> int:
"""
Invalidate a memory across all caches.
Args:
memory_type: Type of memory
memory_id: Memory ID
scope: Optional scope
Returns:
Number of entries invalidated
"""
count = 0
# Invalidate hot cache
if self._hot_cache.invalidate_by_id(memory_type, memory_id, scope):
count += 1
# Invalidate retrieval cache
if self._retrieval_cache:
uuid_id = (
UUID(str(memory_id)) if not isinstance(memory_id, UUID) else memory_id
)
count += self._retrieval_cache.invalidate_by_memory(uuid_id)
logger.debug(f"Invalidated {count} cache entries for {memory_type}:{memory_id}")
return count
async def invalidate_by_type(self, memory_type: str) -> int:
"""
Invalidate all entries of a memory type.
Args:
memory_type: Type of memory
Returns:
Number of entries invalidated
"""
count = self._hot_cache.invalidate_by_type(memory_type)
if self._retrieval_cache:
count += self._retrieval_cache.clear()
logger.info(f"Invalidated {count} cache entries for type {memory_type}")
return count
async def invalidate_by_scope(self, scope: str) -> int:
"""
Invalidate all entries in a scope.
Args:
scope: Scope to invalidate (e.g., project_id)
Returns:
Number of entries invalidated
"""
count = self._hot_cache.invalidate_by_scope(scope)
# Retrieval cache doesn't support scope-based invalidation
# so we clear it entirely for safety
if self._retrieval_cache:
count += self._retrieval_cache.clear()
logger.info(f"Invalidated {count} cache entries for scope {scope}")
return count
async def invalidate_embedding(
self,
content: str,
model: str = "default",
) -> bool:
"""
Invalidate a cached embedding.
Args:
content: Content text
model: Model name
Returns:
True if entry was found and removed
"""
return await self._embedding_cache.invalidate(content, model)
async def clear_all(self) -> int:
"""
Clear all caches.
Returns:
Total number of entries cleared
"""
count = 0
count += self._hot_cache.clear()
count += await self._embedding_cache.clear()
if self._retrieval_cache:
count += self._retrieval_cache.clear()
logger.info(f"Cleared {count} entries from all caches")
return count
# =========================================================================
# Cleanup
# =========================================================================
async def cleanup_expired(self) -> int:
"""
Clean up expired entries from all caches.
Returns:
Number of entries cleaned up
"""
with self._lock:
count = 0
count += self._hot_cache.cleanup_expired()
count += self._embedding_cache.cleanup_expired()
# Retrieval cache doesn't have a cleanup method,
# but entries expire on access
self._last_cleanup = _utcnow()
self._cleanup_count += 1
if count > 0:
logger.info(f"Cleaned up {count} expired cache entries")
return count
# =========================================================================
# Statistics
# =========================================================================
def get_stats(self) -> CacheStats:
"""
Get aggregated cache statistics.
Returns:
CacheStats with all cache metrics
"""
hot_stats = self._hot_cache.get_stats().to_dict()
emb_stats = self._embedding_cache.get_stats().to_dict()
retrieval_stats: dict[str, Any] = {}
if self._retrieval_cache:
retrieval_stats = self._retrieval_cache.get_stats()
# Calculate overall hit rate
total_hits = hot_stats.get("hits", 0) + emb_stats.get("hits", 0)
total_misses = hot_stats.get("misses", 0) + emb_stats.get("misses", 0)
if retrieval_stats:
# Retrieval cache doesn't track hits/misses the same way
pass
total_requests = total_hits + total_misses
overall_hit_rate = total_hits / total_requests if total_requests > 0 else 0.0
return CacheStats(
hot_cache=hot_stats,
embedding_cache=emb_stats,
retrieval_cache=retrieval_stats,
overall_hit_rate=overall_hit_rate,
last_cleanup=self._last_cleanup,
cleanup_count=self._cleanup_count,
)
def get_hot_memories(self, limit: int = 10) -> list[tuple[CacheKey, int]]:
"""
Get the most frequently accessed memories.
Args:
limit: Maximum number to return
Returns:
List of (key, access_count) tuples
"""
return self._hot_cache.get_hot_memories(limit)
def reset_stats(self) -> None:
"""Reset all cache statistics."""
self._hot_cache.reset_stats()
self._embedding_cache.reset_stats()
# =========================================================================
# Warmup
# =========================================================================
async def warmup(
self,
memories: list[tuple[str, UUID | str, Any]],
scope: str | None = None,
) -> int:
"""
Warm up the hot cache with memories.
Args:
memories: List of (memory_type, memory_id, memory) tuples
scope: Optional scope for all memories
Returns:
Number of memories cached
"""
if not self._enabled:
return 0
for memory_type, memory_id, memory in memories:
self._hot_cache.put_by_id(memory_type, memory_id, memory, scope)
logger.info(f"Warmed up cache with {len(memories)} memories")
return len(memories)
# Singleton instance
_cache_manager: CacheManager | None = None
_cache_manager_lock = threading.Lock()
def get_cache_manager(
redis: "Redis | None" = None,
reset: bool = False,
) -> CacheManager:
"""
Get the global CacheManager instance.
Thread-safe with double-checked locking pattern.
Args:
redis: Optional Redis connection
reset: Force create a new instance
Returns:
CacheManager instance
"""
global _cache_manager
if reset or _cache_manager is None:
with _cache_manager_lock:
if reset or _cache_manager is None:
_cache_manager = CacheManager(redis=redis)
return _cache_manager
def reset_cache_manager() -> None:
"""Reset the global cache manager instance."""
global _cache_manager
with _cache_manager_lock:
_cache_manager = None
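As a hedged usage sketch of the singleton and hot-cache path above (assumes the default MemorySettings, i.e. caching enabled, and no Redis connection):

from uuid import uuid4

manager = get_cache_manager()

memory_id = uuid4()
manager.cache_memory("episodic", memory_id, {"task": "deploy"}, scope="project-1")
assert manager.get_memory("episodic", memory_id, scope="project-1") == {"task": "deploy"}

# Aggregated stats across hot and embedding caches; no retrieval cache is attached here.
print(manager.get_stats().to_dict()["overall_hit_rate"])  # 1.0 after the single hit above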

View File

@@ -0,0 +1,623 @@
# app/services/memory/cache/embedding_cache.py
"""
Embedding Cache.
Caches embeddings by content hash to avoid recomputing.
Provides significant performance improvement for repeated content.
"""
import hashlib
import logging
import threading
from collections import OrderedDict
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from redis.asyncio import Redis
logger = logging.getLogger(__name__)
def _utcnow() -> datetime:
"""Get current UTC time as timezone-aware datetime."""
return datetime.now(UTC)
@dataclass
class EmbeddingEntry:
"""A cached embedding entry."""
embedding: list[float]
content_hash: str
model: str
created_at: datetime
ttl_seconds: float = 3600.0 # 1 hour default
def is_expired(self) -> bool:
"""Check if this entry has expired."""
age = (_utcnow() - self.created_at).total_seconds()
return age > self.ttl_seconds
@dataclass
class EmbeddingCacheStats:
"""Statistics for the embedding cache."""
hits: int = 0
misses: int = 0
evictions: int = 0
expirations: int = 0
current_size: int = 0
max_size: int = 0
bytes_saved: int = 0 # Estimated bytes saved by caching
@property
def hit_rate(self) -> float:
"""Calculate cache hit rate."""
total = self.hits + self.misses
if total == 0:
return 0.0
return self.hits / total
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"hits": self.hits,
"misses": self.misses,
"evictions": self.evictions,
"expirations": self.expirations,
"current_size": self.current_size,
"max_size": self.max_size,
"hit_rate": self.hit_rate,
"bytes_saved": self.bytes_saved,
}
class EmbeddingCache:
"""
Cache for embeddings by content hash.
Features:
- Content-hash based deduplication
- LRU eviction
- TTL-based expiration
- Optional Redis backing for persistence
- Thread-safe operations
Performance targets:
- Cache hit rate > 90% for repeated content
- Get/put operations < 1ms (memory), < 5ms (Redis)
"""
def __init__(
self,
max_size: int = 50000,
default_ttl_seconds: float = 3600.0,
redis: "Redis | None" = None,
redis_prefix: str = "mem:emb",
) -> None:
"""
Initialize the embedding cache.
Args:
max_size: Maximum number of entries in memory cache
default_ttl_seconds: Default TTL for entries (1 hour)
redis: Optional Redis connection for persistence
redis_prefix: Prefix for Redis keys
"""
self._max_size = max_size
self._default_ttl = default_ttl_seconds
self._cache: OrderedDict[str, EmbeddingEntry] = OrderedDict()
self._lock = threading.RLock()
self._stats = EmbeddingCacheStats(max_size=max_size)
self._redis = redis
self._redis_prefix = redis_prefix
logger.info(
f"Initialized EmbeddingCache with max_size={max_size}, "
f"ttl={default_ttl_seconds}s, redis={'enabled' if redis else 'disabled'}"
)
def set_redis(self, redis: "Redis") -> None:
"""Set Redis connection for persistence."""
self._redis = redis
@staticmethod
def hash_content(content: str) -> str:
"""
Compute hash of content for cache key.
Args:
content: Content to hash
Returns:
32-character hex hash
"""
return hashlib.sha256(content.encode()).hexdigest()[:32]
def _cache_key(self, content_hash: str, model: str) -> str:
"""Build cache key from content hash and model."""
return f"{content_hash}:{model}"
def _redis_key(self, content_hash: str, model: str) -> str:
"""Build Redis key from content hash and model."""
return f"{self._redis_prefix}:{content_hash}:{model}"
async def get(
self,
content: str,
model: str = "default",
) -> list[float] | None:
"""
Get a cached embedding.
Args:
content: Content text
model: Model name
Returns:
Cached embedding or None if not found/expired
"""
content_hash = self.hash_content(content)
cache_key = self._cache_key(content_hash, model)
# Check memory cache first
with self._lock:
if cache_key in self._cache:
entry = self._cache[cache_key]
if entry.is_expired():
del self._cache[cache_key]
self._stats.expirations += 1
self._stats.current_size = len(self._cache)
else:
# Move to end (most recently used)
self._cache.move_to_end(cache_key)
self._stats.hits += 1
return entry.embedding
# Check Redis if available
if self._redis:
try:
redis_key = self._redis_key(content_hash, model)
data = await self._redis.get(redis_key)
if data:
import json
embedding = json.loads(data)
# Store in memory cache for faster access
self._put_memory(content_hash, model, embedding)
self._stats.hits += 1
return embedding
except Exception as e:
logger.warning(f"Redis get error: {e}")
self._stats.misses += 1
return None
async def get_by_hash(
self,
content_hash: str,
model: str = "default",
) -> list[float] | None:
"""
Get a cached embedding by hash.
Args:
content_hash: Content hash
model: Model name
Returns:
Cached embedding or None if not found/expired
"""
cache_key = self._cache_key(content_hash, model)
with self._lock:
if cache_key in self._cache:
entry = self._cache[cache_key]
if entry.is_expired():
del self._cache[cache_key]
self._stats.expirations += 1
self._stats.current_size = len(self._cache)
else:
self._cache.move_to_end(cache_key)
self._stats.hits += 1
return entry.embedding
# Check Redis
if self._redis:
try:
redis_key = self._redis_key(content_hash, model)
data = await self._redis.get(redis_key)
if data:
import json
embedding = json.loads(data)
self._put_memory(content_hash, model, embedding)
self._stats.hits += 1
return embedding
except Exception as e:
logger.warning(f"Redis get error: {e}")
self._stats.misses += 1
return None
async def put(
self,
content: str,
embedding: list[float],
model: str = "default",
ttl_seconds: float | None = None,
) -> str:
"""
Cache an embedding.
Args:
content: Content text
embedding: Embedding vector
model: Model name
ttl_seconds: Optional TTL override
Returns:
Content hash
"""
content_hash = self.hash_content(content)
ttl = ttl_seconds or self._default_ttl
# Store in memory
self._put_memory(content_hash, model, embedding, ttl)
# Store in Redis if available
if self._redis:
try:
import json
redis_key = self._redis_key(content_hash, model)
await self._redis.setex(
redis_key,
int(ttl),
json.dumps(embedding),
)
except Exception as e:
logger.warning(f"Redis put error: {e}")
return content_hash
def _put_memory(
self,
content_hash: str,
model: str,
embedding: list[float],
ttl_seconds: float | None = None,
) -> None:
"""Store in memory cache."""
with self._lock:
# Evict if at capacity
self._evict_if_needed()
cache_key = self._cache_key(content_hash, model)
entry = EmbeddingEntry(
embedding=embedding,
content_hash=content_hash,
model=model,
created_at=_utcnow(),
ttl_seconds=ttl_seconds or self._default_ttl,
)
self._cache[cache_key] = entry
self._cache.move_to_end(cache_key)
self._stats.current_size = len(self._cache)
def _evict_if_needed(self) -> None:
"""Evict entries if cache is at capacity."""
while len(self._cache) >= self._max_size:
if self._cache:
self._cache.popitem(last=False)
self._stats.evictions += 1
async def put_batch(
self,
items: list[tuple[str, list[float]]],
model: str = "default",
ttl_seconds: float | None = None,
) -> list[str]:
"""
Cache multiple embeddings.
Args:
items: List of (content, embedding) tuples
model: Model name
ttl_seconds: Optional TTL override
Returns:
List of content hashes
"""
hashes = []
for content, embedding in items:
content_hash = await self.put(content, embedding, model, ttl_seconds)
hashes.append(content_hash)
return hashes
async def invalidate(
self,
content: str,
model: str = "default",
) -> bool:
"""
Invalidate a cached embedding.
Args:
content: Content text
model: Model name
Returns:
True if entry was found and removed
"""
content_hash = self.hash_content(content)
return await self.invalidate_by_hash(content_hash, model)
async def invalidate_by_hash(
self,
content_hash: str,
model: str = "default",
) -> bool:
"""
Invalidate a cached embedding by hash.
Args:
content_hash: Content hash
model: Model name
Returns:
True if entry was found and removed
"""
cache_key = self._cache_key(content_hash, model)
removed = False
with self._lock:
if cache_key in self._cache:
del self._cache[cache_key]
self._stats.current_size = len(self._cache)
removed = True
# Remove from Redis
if self._redis:
try:
redis_key = self._redis_key(content_hash, model)
await self._redis.delete(redis_key)
removed = True
except Exception as e:
logger.warning(f"Redis delete error: {e}")
return removed
async def invalidate_by_model(self, model: str) -> int:
"""
Invalidate all embeddings for a model.
Args:
model: Model name
Returns:
Number of entries invalidated
"""
count = 0
with self._lock:
keys_to_remove = [k for k, v in self._cache.items() if v.model == model]
for key in keys_to_remove:
del self._cache[key]
count += 1
self._stats.current_size = len(self._cache)
# Note: Redis pattern deletion would require SCAN which is expensive
# For now, we only clear memory cache for model-based invalidation
return count
async def clear(self) -> int:
"""
Clear all cache entries.
Returns:
Number of entries cleared
"""
with self._lock:
count = len(self._cache)
self._cache.clear()
self._stats.current_size = 0
# Clear Redis entries
if self._redis:
try:
pattern = f"{self._redis_prefix}:*"
deleted = 0
async for key in self._redis.scan_iter(match=pattern):
await self._redis.delete(key)
deleted += 1
count = max(count, deleted)
except Exception as e:
logger.warning(f"Redis clear error: {e}")
logger.info(f"Cleared {count} entries from embedding cache")
return count
def cleanup_expired(self) -> int:
"""
Remove all expired entries from memory cache.
Returns:
Number of entries removed
"""
with self._lock:
keys_to_remove = [k for k, v in self._cache.items() if v.is_expired()]
for key in keys_to_remove:
del self._cache[key]
self._stats.expirations += 1
self._stats.current_size = len(self._cache)
if keys_to_remove:
logger.debug(f"Cleaned up {len(keys_to_remove)} expired embeddings")
return len(keys_to_remove)
def get_stats(self) -> EmbeddingCacheStats:
"""Get cache statistics."""
with self._lock:
self._stats.current_size = len(self._cache)
return self._stats
def reset_stats(self) -> None:
"""Reset cache statistics."""
with self._lock:
self._stats = EmbeddingCacheStats(
max_size=self._max_size,
current_size=len(self._cache),
)
@property
def size(self) -> int:
"""Get current cache size."""
return len(self._cache)
@property
def max_size(self) -> int:
"""Get maximum cache size."""
return self._max_size
class CachedEmbeddingGenerator:
"""
Wrapper for embedding generators with caching.
Wraps an embedding generator to cache results.
"""
def __init__(
self,
generator: Any,
cache: EmbeddingCache,
model: str = "default",
) -> None:
"""
Initialize the cached embedding generator.
Args:
generator: Underlying embedding generator
cache: Embedding cache
model: Model name for cache keys
"""
self._generator = generator
self._cache = cache
self._model = model
self._call_count = 0
self._cache_hit_count = 0
async def generate(self, text: str) -> list[float]:
"""
Generate embedding with caching.
Args:
text: Text to embed
Returns:
Embedding vector
"""
self._call_count += 1
# Check cache first
cached = await self._cache.get(text, self._model)
if cached is not None:
self._cache_hit_count += 1
return cached
# Generate and cache
embedding = await self._generator.generate(text)
await self._cache.put(text, embedding, self._model)
return embedding
async def generate_batch(
self,
texts: list[str],
) -> list[list[float]]:
"""
Generate embeddings for multiple texts with caching.
Args:
texts: Texts to embed
Returns:
List of embedding vectors
"""
results: list[list[float] | None] = [None] * len(texts)
to_generate: list[tuple[int, str]] = []
# Check cache for each text
for i, text in enumerate(texts):
cached = await self._cache.get(text, self._model)
if cached is not None:
results[i] = cached
self._cache_hit_count += 1
else:
to_generate.append((i, text))
self._call_count += len(texts)
# Generate missing embeddings
if to_generate:
if hasattr(self._generator, "generate_batch"):
texts_to_gen = [t for _, t in to_generate]
embeddings = await self._generator.generate_batch(texts_to_gen)
for (idx, text), embedding in zip(to_generate, embeddings, strict=True):
results[idx] = embedding
await self._cache.put(text, embedding, self._model)
else:
# Fallback to individual generation
for idx, text in to_generate:
embedding = await self._generator.generate(text)
results[idx] = embedding
await self._cache.put(text, embedding, self._model)
return results # type: ignore[return-value]
def get_stats(self) -> dict[str, Any]:
"""Get generator statistics."""
return {
"call_count": self._call_count,
"cache_hit_count": self._cache_hit_count,
"cache_hit_rate": (
self._cache_hit_count / self._call_count
if self._call_count > 0
else 0.0
),
"cache_stats": self._cache.get_stats().to_dict(),
}
# Factory function
def create_embedding_cache(
max_size: int = 50000,
default_ttl_seconds: float = 3600.0,
redis: "Redis | None" = None,
) -> EmbeddingCache:
"""
Create an embedding cache.
Args:
max_size: Maximum number of entries
default_ttl_seconds: Default TTL for entries
redis: Optional Redis connection
Returns:
Configured EmbeddingCache instance
"""
return EmbeddingCache(
max_size=max_size,
default_ttl_seconds=default_ttl_seconds,
redis=redis,
)
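A self-contained sketch of wrapping a generator with the cache above (the fake generator is purely illustrative):

import asyncio

class _FakeGenerator:
    """Stand-in embedding generator for illustration; returns a fixed-shape vector."""

    def __init__(self) -> None:
        self.calls = 0

    async def generate(self, text: str) -> list[float]:
        self.calls += 1
        return [float(len(text)), 0.0, 1.0]

async def main() -> None:
    cache = create_embedding_cache(max_size=100, default_ttl_seconds=60.0)
    gen = CachedEmbeddingGenerator(_FakeGenerator(), cache, model="fake")

    first = await gen.generate("hello world")
    second = await gen.generate("hello world")  # served from cache
    assert first == second
    assert gen.get_stats()["cache_hit_count"] == 1

asyncio.run(main())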

View File

@@ -0,0 +1,461 @@
# app/services/memory/cache/hot_cache.py
"""
Hot Memory Cache.
LRU cache for frequently accessed memories.
Provides fast access to recently used memories without database queries.
"""
import logging
import threading
from collections import OrderedDict
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Any
from uuid import UUID
logger = logging.getLogger(__name__)
def _utcnow() -> datetime:
"""Get current UTC time as timezone-aware datetime."""
return datetime.now(UTC)
@dataclass
class CacheEntry[T]:
"""A cached memory entry with metadata."""
value: T
created_at: datetime
last_accessed_at: datetime
access_count: int = 1
ttl_seconds: float = 300.0
def is_expired(self) -> bool:
"""Check if this entry has expired."""
age = (_utcnow() - self.created_at).total_seconds()
return age > self.ttl_seconds
def touch(self) -> None:
"""Update access time and count."""
self.last_accessed_at = _utcnow()
self.access_count += 1
@dataclass
class CacheKey:
"""A structured cache key with components."""
memory_type: str
memory_id: str
scope: str | None = None
def __hash__(self) -> int:
return hash((self.memory_type, self.memory_id, self.scope))
def __eq__(self, other: object) -> bool:
if not isinstance(other, CacheKey):
return False
return (
self.memory_type == other.memory_type
and self.memory_id == other.memory_id
and self.scope == other.scope
)
def __str__(self) -> str:
if self.scope:
return f"{self.memory_type}:{self.scope}:{self.memory_id}"
return f"{self.memory_type}:{self.memory_id}"
@dataclass
class HotCacheStats:
"""Statistics for the hot memory cache."""
hits: int = 0
misses: int = 0
evictions: int = 0
expirations: int = 0
current_size: int = 0
max_size: int = 0
@property
def hit_rate(self) -> float:
"""Calculate cache hit rate."""
total = self.hits + self.misses
if total == 0:
return 0.0
return self.hits / total
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"hits": self.hits,
"misses": self.misses,
"evictions": self.evictions,
"expirations": self.expirations,
"current_size": self.current_size,
"max_size": self.max_size,
"hit_rate": self.hit_rate,
}
class HotMemoryCache[T]:
"""
LRU cache for frequently accessed memories.
Features:
- LRU eviction when capacity is reached
- TTL-based expiration
- Access count tracking for hot memory identification
- Thread-safe operations
- Scoped invalidation
Performance targets:
- Cache hit rate > 80% for hot memories
- Get/put operations < 1ms
"""
def __init__(
self,
max_size: int = 10000,
default_ttl_seconds: float = 300.0,
) -> None:
"""
Initialize the hot memory cache.
Args:
max_size: Maximum number of entries
default_ttl_seconds: Default TTL for entries (5 minutes)
"""
self._max_size = max_size
self._default_ttl = default_ttl_seconds
self._cache: OrderedDict[CacheKey, CacheEntry[T]] = OrderedDict()
self._lock = threading.RLock()
self._stats = HotCacheStats(max_size=max_size)
logger.info(
f"Initialized HotMemoryCache with max_size={max_size}, "
f"ttl={default_ttl_seconds}s"
)
def get(self, key: CacheKey) -> T | None:
"""
Get a memory from cache.
Args:
key: Cache key
Returns:
Cached value or None if not found/expired
"""
with self._lock:
if key not in self._cache:
self._stats.misses += 1
return None
entry = self._cache[key]
# Check expiration
if entry.is_expired():
del self._cache[key]
self._stats.expirations += 1
self._stats.misses += 1
self._stats.current_size = len(self._cache)
return None
# Move to end (most recently used)
self._cache.move_to_end(key)
entry.touch()
self._stats.hits += 1
return entry.value
def get_by_id(
self,
memory_type: str,
memory_id: UUID | str,
scope: str | None = None,
) -> T | None:
"""
Get a memory by type and ID.
Args:
memory_type: Type of memory (episodic, semantic, procedural)
memory_id: Memory ID
scope: Optional scope (project_id, agent_id)
Returns:
Cached value or None if not found/expired
"""
key = CacheKey(
memory_type=memory_type,
memory_id=str(memory_id),
scope=scope,
)
return self.get(key)
def put(
self,
key: CacheKey,
value: T,
ttl_seconds: float | None = None,
) -> None:
"""
Put a memory into cache.
Args:
key: Cache key
value: Value to cache
ttl_seconds: Optional TTL override
"""
with self._lock:
# Evict if at capacity
self._evict_if_needed()
now = _utcnow()
entry = CacheEntry(
value=value,
created_at=now,
last_accessed_at=now,
access_count=1,
ttl_seconds=ttl_seconds or self._default_ttl,
)
self._cache[key] = entry
self._cache.move_to_end(key)
self._stats.current_size = len(self._cache)
def put_by_id(
self,
memory_type: str,
memory_id: UUID | str,
value: T,
scope: str | None = None,
ttl_seconds: float | None = None,
) -> None:
"""
Put a memory by type and ID.
Args:
memory_type: Type of memory
memory_id: Memory ID
value: Value to cache
scope: Optional scope
ttl_seconds: Optional TTL override
"""
key = CacheKey(
memory_type=memory_type,
memory_id=str(memory_id),
scope=scope,
)
self.put(key, value, ttl_seconds)
def _evict_if_needed(self) -> None:
"""Evict entries if cache is at capacity."""
while len(self._cache) >= self._max_size:
# Remove least recently used (first item)
if self._cache:
self._cache.popitem(last=False)
self._stats.evictions += 1
def invalidate(self, key: CacheKey) -> bool:
"""
Invalidate a specific cache entry.
Args:
key: Cache key to invalidate
Returns:
True if entry was found and removed
"""
with self._lock:
if key in self._cache:
del self._cache[key]
self._stats.current_size = len(self._cache)
return True
return False
def invalidate_by_id(
self,
memory_type: str,
memory_id: UUID | str,
scope: str | None = None,
) -> bool:
"""
Invalidate a memory by type and ID.
Args:
memory_type: Type of memory
memory_id: Memory ID
scope: Optional scope
Returns:
True if entry was found and removed
"""
key = CacheKey(
memory_type=memory_type,
memory_id=str(memory_id),
scope=scope,
)
return self.invalidate(key)
def invalidate_by_type(self, memory_type: str) -> int:
"""
Invalidate all entries of a memory type.
Args:
memory_type: Type of memory to invalidate
Returns:
Number of entries invalidated
"""
with self._lock:
keys_to_remove = [
k for k in self._cache.keys() if k.memory_type == memory_type
]
for key in keys_to_remove:
del self._cache[key]
self._stats.current_size = len(self._cache)
return len(keys_to_remove)
def invalidate_by_scope(self, scope: str) -> int:
"""
Invalidate all entries in a scope.
Args:
scope: Scope to invalidate (e.g., project_id)
Returns:
Number of entries invalidated
"""
with self._lock:
keys_to_remove = [k for k in self._cache.keys() if k.scope == scope]
for key in keys_to_remove:
del self._cache[key]
self._stats.current_size = len(self._cache)
return len(keys_to_remove)
def invalidate_pattern(self, pattern: str) -> int:
"""
Invalidate entries matching a pattern.
Pattern can include * as wildcard.
Args:
pattern: Pattern to match (e.g., "episodic:*")
Returns:
Number of entries invalidated
"""
import fnmatch
with self._lock:
keys_to_remove = [
k for k in self._cache.keys() if fnmatch.fnmatch(str(k), pattern)
]
for key in keys_to_remove:
del self._cache[key]
self._stats.current_size = len(self._cache)
return len(keys_to_remove)
def clear(self) -> int:
"""
Clear all cache entries.
Returns:
Number of entries cleared
"""
with self._lock:
count = len(self._cache)
self._cache.clear()
self._stats.current_size = 0
logger.info(f"Cleared {count} entries from hot cache")
return count
def cleanup_expired(self) -> int:
"""
Remove all expired entries.
Returns:
Number of entries removed
"""
with self._lock:
keys_to_remove = [k for k, v in self._cache.items() if v.is_expired()]
for key in keys_to_remove:
del self._cache[key]
self._stats.expirations += 1
self._stats.current_size = len(self._cache)
if keys_to_remove:
logger.debug(f"Cleaned up {len(keys_to_remove)} expired entries")
return len(keys_to_remove)
def get_hot_memories(self, limit: int = 10) -> list[tuple[CacheKey, int]]:
"""
Get the most frequently accessed memories.
Args:
limit: Maximum number of memories to return
Returns:
List of (key, access_count) tuples sorted by access count
"""
with self._lock:
entries = [
(k, v.access_count)
for k, v in self._cache.items()
if not v.is_expired()
]
entries.sort(key=lambda x: x[1], reverse=True)
return entries[:limit]
def get_stats(self) -> HotCacheStats:
"""Get cache statistics."""
with self._lock:
self._stats.current_size = len(self._cache)
return self._stats
def reset_stats(self) -> None:
"""Reset cache statistics."""
with self._lock:
self._stats = HotCacheStats(
max_size=self._max_size,
current_size=len(self._cache),
)
@property
def size(self) -> int:
"""Get current cache size."""
return len(self._cache)
@property
def max_size(self) -> int:
"""Get maximum cache size."""
return self._max_size
# Factory function for typed caches
def create_hot_cache(
max_size: int = 10000,
default_ttl_seconds: float = 300.0,
) -> HotMemoryCache[Any]:
"""
Create a hot memory cache.
Args:
max_size: Maximum number of entries
default_ttl_seconds: Default TTL for entries
Returns:
Configured HotMemoryCache instance
"""
return HotMemoryCache(
max_size=max_size,
default_ttl_seconds=default_ttl_seconds,
)
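A quick illustration of the LRU eviction and scope-based invalidation above (sizes and values are arbitrary):

from uuid import uuid4

cache = create_hot_cache(max_size=2, default_ttl_seconds=300.0)

first, second, third = uuid4(), uuid4(), uuid4()
cache.put_by_id("episodic", first, {"task": "a"}, scope="proj-1")
cache.put_by_id("episodic", second, {"task": "b"}, scope="proj-1")
cache.put_by_id("episodic", third, {"task": "c"}, scope="proj-1")  # evicts the LRU entry

assert cache.get_by_id("episodic", first, scope="proj-1") is None  # evicted
assert cache.get_by_id("episodic", third, scope="proj-1") == {"task": "c"}
assert cache.invalidate_by_scope("proj-1") == 2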

View File

@@ -0,0 +1,410 @@
"""
Memory System Configuration.
Provides Pydantic settings for the Agent Memory System,
including storage backends, capacity limits, and consolidation policies.
"""
import threading
from functools import lru_cache
from typing import Any
from pydantic import Field, field_validator, model_validator
from pydantic_settings import BaseSettings
class MemorySettings(BaseSettings):
"""
Configuration for the Agent Memory System.
All settings can be overridden via environment variables
with the MEM_ prefix.
"""
# Working Memory Settings
working_memory_backend: str = Field(
default="redis",
description="Backend for working memory: 'redis' or 'memory'",
)
working_memory_default_ttl_seconds: int = Field(
default=3600,
ge=60,
le=86400,
description="Default TTL for working memory items (1 hour default)",
)
working_memory_max_items_per_session: int = Field(
default=1000,
ge=100,
le=100000,
description="Maximum items per session in working memory",
)
working_memory_max_value_size_bytes: int = Field(
default=1048576, # 1MB
ge=1024,
le=104857600, # 100MB
description="Maximum size of a single value in working memory",
)
working_memory_checkpoint_enabled: bool = Field(
default=True,
description="Enable checkpointing for working memory recovery",
)
# Redis Settings (for working memory)
redis_url: str = Field(
default="redis://localhost:6379/0",
description="Redis connection URL",
)
redis_prefix: str = Field(
default="mem",
description="Redis key prefix for memory items",
)
redis_connection_timeout_seconds: int = Field(
default=5,
ge=1,
le=60,
description="Redis connection timeout",
)
# Episodic Memory Settings
episodic_max_episodes_per_project: int = Field(
default=10000,
ge=100,
le=1000000,
description="Maximum episodes to retain per project",
)
episodic_default_importance: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="Default importance score for new episodes",
)
episodic_retention_days: int = Field(
default=365,
ge=7,
le=3650,
description="Days to retain episodes before archival",
)
# Semantic Memory Settings
semantic_max_facts_per_project: int = Field(
default=50000,
ge=1000,
le=10000000,
description="Maximum facts to retain per project",
)
semantic_confidence_decay_days: int = Field(
default=90,
ge=7,
le=365,
description="Days until confidence decays by 50%",
)
semantic_min_confidence: float = Field(
default=0.1,
ge=0.0,
le=1.0,
description="Minimum confidence before fact is pruned",
)
# Procedural Memory Settings
procedural_max_procedures_per_project: int = Field(
default=1000,
ge=10,
le=100000,
description="Maximum procedures per project",
)
procedural_min_success_rate: float = Field(
default=0.3,
ge=0.0,
le=1.0,
description="Minimum success rate before procedure is pruned",
)
procedural_min_uses_before_suggest: int = Field(
default=3,
ge=1,
le=100,
description="Minimum uses before procedure is suggested",
)
# Embedding Settings
embedding_model: str = Field(
default="text-embedding-3-small",
description="Model to use for embeddings",
)
embedding_dimensions: int = Field(
default=1536,
ge=256,
le=4096,
description="Embedding vector dimensions",
)
embedding_batch_size: int = Field(
default=100,
ge=1,
le=1000,
description="Batch size for embedding generation",
)
embedding_cache_enabled: bool = Field(
default=True,
description="Enable caching of embeddings",
)
# Retrieval Settings
retrieval_default_limit: int = Field(
default=10,
ge=1,
le=100,
description="Default limit for retrieval queries",
)
retrieval_max_limit: int = Field(
default=100,
ge=10,
le=1000,
description="Maximum limit for retrieval queries",
)
retrieval_min_similarity: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="Minimum similarity score for retrieval",
)
# Consolidation Settings
consolidation_enabled: bool = Field(
default=True,
description="Enable automatic memory consolidation",
)
consolidation_batch_size: int = Field(
default=100,
ge=10,
le=1000,
description="Batch size for consolidation jobs",
)
consolidation_schedule_cron: str = Field(
default="0 3 * * *",
description="Cron expression for nightly consolidation (3 AM)",
)
consolidation_working_to_episodic_delay_minutes: int = Field(
default=30,
ge=5,
le=1440,
description="Minutes after session end before consolidating to episodic",
)
# Pruning Settings
pruning_enabled: bool = Field(
default=True,
description="Enable automatic memory pruning",
)
pruning_min_age_days: int = Field(
default=7,
ge=1,
le=365,
description="Minimum age before memory can be pruned",
)
pruning_importance_threshold: float = Field(
default=0.2,
ge=0.0,
le=1.0,
description="Importance threshold below which memory can be pruned",
)
# Caching Settings
cache_enabled: bool = Field(
default=True,
description="Enable caching for memory retrieval",
)
cache_ttl_seconds: int = Field(
default=300,
ge=10,
le=3600,
description="Cache TTL for retrieval results",
)
cache_max_items: int = Field(
default=10000,
ge=100,
le=1000000,
description="Maximum items in memory cache",
)
# Performance Settings
max_retrieval_time_ms: int = Field(
default=100,
ge=10,
le=5000,
description="Target maximum retrieval time in milliseconds",
)
parallel_retrieval: bool = Field(
default=True,
description="Enable parallel retrieval from multiple memory types",
)
max_parallel_retrievals: int = Field(
default=4,
ge=1,
le=10,
description="Maximum concurrent retrieval operations",
)
@field_validator("working_memory_backend")
@classmethod
def validate_backend(cls, v: str) -> str:
"""Validate working memory backend."""
valid_backends = {"redis", "memory"}
if v not in valid_backends:
raise ValueError(f"backend must be one of: {valid_backends}")
return v
@field_validator("embedding_model")
@classmethod
def validate_embedding_model(cls, v: str) -> str:
"""Validate embedding model name."""
valid_models = {
"text-embedding-3-small",
"text-embedding-3-large",
"text-embedding-ada-002",
}
if v not in valid_models:
raise ValueError(f"embedding_model must be one of: {valid_models}")
return v
@model_validator(mode="after")
def validate_limits(self) -> "MemorySettings":
"""Validate that limits are consistent."""
if self.retrieval_default_limit > self.retrieval_max_limit:
raise ValueError(
f"retrieval_default_limit ({self.retrieval_default_limit}) "
f"cannot exceed retrieval_max_limit ({self.retrieval_max_limit})"
)
return self
def get_working_memory_config(self) -> dict[str, Any]:
"""Get working memory configuration as a dictionary."""
return {
"backend": self.working_memory_backend,
"default_ttl_seconds": self.working_memory_default_ttl_seconds,
"max_items_per_session": self.working_memory_max_items_per_session,
"max_value_size_bytes": self.working_memory_max_value_size_bytes,
"checkpoint_enabled": self.working_memory_checkpoint_enabled,
}
def get_redis_config(self) -> dict[str, Any]:
"""Get Redis configuration as a dictionary."""
return {
"url": self.redis_url,
"prefix": self.redis_prefix,
"connection_timeout_seconds": self.redis_connection_timeout_seconds,
}
def get_embedding_config(self) -> dict[str, Any]:
"""Get embedding configuration as a dictionary."""
return {
"model": self.embedding_model,
"dimensions": self.embedding_dimensions,
"batch_size": self.embedding_batch_size,
"cache_enabled": self.embedding_cache_enabled,
}
def get_consolidation_config(self) -> dict[str, Any]:
"""Get consolidation configuration as a dictionary."""
return {
"enabled": self.consolidation_enabled,
"batch_size": self.consolidation_batch_size,
"schedule_cron": self.consolidation_schedule_cron,
"working_to_episodic_delay_minutes": (
self.consolidation_working_to_episodic_delay_minutes
),
}
def to_dict(self) -> dict[str, Any]:
"""Convert settings to dictionary for logging/debugging."""
return {
"working_memory": self.get_working_memory_config(),
"redis": self.get_redis_config(),
"episodic": {
"max_episodes_per_project": self.episodic_max_episodes_per_project,
"default_importance": self.episodic_default_importance,
"retention_days": self.episodic_retention_days,
},
"semantic": {
"max_facts_per_project": self.semantic_max_facts_per_project,
"confidence_decay_days": self.semantic_confidence_decay_days,
"min_confidence": self.semantic_min_confidence,
},
"procedural": {
"max_procedures_per_project": self.procedural_max_procedures_per_project,
"min_success_rate": self.procedural_min_success_rate,
"min_uses_before_suggest": self.procedural_min_uses_before_suggest,
},
"embedding": self.get_embedding_config(),
"retrieval": {
"default_limit": self.retrieval_default_limit,
"max_limit": self.retrieval_max_limit,
"min_similarity": self.retrieval_min_similarity,
},
"consolidation": self.get_consolidation_config(),
"pruning": {
"enabled": self.pruning_enabled,
"min_age_days": self.pruning_min_age_days,
"importance_threshold": self.pruning_importance_threshold,
},
"cache": {
"enabled": self.cache_enabled,
"ttl_seconds": self.cache_ttl_seconds,
"max_items": self.cache_max_items,
},
"performance": {
"max_retrieval_time_ms": self.max_retrieval_time_ms,
"parallel_retrieval": self.parallel_retrieval,
"max_parallel_retrievals": self.max_parallel_retrievals,
},
}
model_config = {
"env_prefix": "MEM_",
"env_file": ".env",
"env_file_encoding": "utf-8",
"case_sensitive": False,
"extra": "ignore",
}
# Thread-safe singleton pattern
_settings: MemorySettings | None = None
_settings_lock = threading.Lock()
def get_memory_settings() -> MemorySettings:
"""
Get the global MemorySettings instance.
Thread-safe via the double-checked locking pattern.
Returns:
MemorySettings instance
"""
global _settings
if _settings is None:
with _settings_lock:
if _settings is None:
_settings = MemorySettings()
return _settings
def reset_memory_settings() -> None:
"""
Reset the global settings instance.
Primarily used for testing.
"""
global _settings
with _settings_lock:
_settings = None
@lru_cache(maxsize=1)
def get_default_settings() -> MemorySettings:
"""
Get default settings (cached).
Use this for read-only access to default values.
For the resettable global instance, use get_memory_settings().
"""
return MemorySettings()
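# Minimal usage sketch, e.g. from a test or script: override settings through
# MEM_-prefixed environment variables and read the grouped config dicts. The
# import path app.services.memory.config matches the one used by the episodic
# recorder later in this diff.
import os

from app.services.memory.config import get_memory_settings, reset_memory_settings

os.environ["MEM_WORKING_MEMORY_BACKEND"] = "memory"
os.environ["MEM_RETRIEVAL_DEFAULT_LIMIT"] = "20"

reset_memory_settings()           # drop any cached instance (useful in tests)
settings = get_memory_settings()  # thread-safe singleton, now reflects the env
assert settings.working_memory_backend == "memory"
assert settings.to_dict()["retrieval"]["default_limit"] == 20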

View File

@@ -0,0 +1,29 @@
# app/services/memory/consolidation/__init__.py
"""
Memory Consolidation.
Transfers and extracts knowledge between memory tiers:
- Working -> Episodic (session end)
- Episodic -> Semantic (learn facts)
- Episodic -> Procedural (learn procedures)
Also handles memory pruning and importance-based retention.
"""
from .service import (
ConsolidationConfig,
ConsolidationResult,
MemoryConsolidationService,
NightlyConsolidationResult,
SessionConsolidationResult,
get_consolidation_service,
)
__all__ = [
"ConsolidationConfig",
"ConsolidationResult",
"MemoryConsolidationService",
"NightlyConsolidationResult",
"SessionConsolidationResult",
"get_consolidation_service",
]

View File

@@ -0,0 +1,913 @@
# app/services/memory/consolidation/service.py
"""
Memory Consolidation Service.
Transfers and extracts knowledge between memory tiers:
- Working -> Episodic (session end)
- Episodic -> Semantic (learn facts)
- Episodic -> Procedural (learn procedures)
Also handles memory pruning and importance-based retention.
"""
import logging
from dataclasses import dataclass, field
from datetime import UTC, datetime, timedelta
from typing import Any
from uuid import UUID
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.memory.episodic.memory import EpisodicMemory
from app.services.memory.procedural.memory import ProceduralMemory
from app.services.memory.semantic.extraction import FactExtractor, get_fact_extractor
from app.services.memory.semantic.memory import SemanticMemory
from app.services.memory.types import (
Episode,
EpisodeCreate,
Outcome,
ProcedureCreate,
TaskState,
)
from app.services.memory.working.memory import WorkingMemory
logger = logging.getLogger(__name__)
@dataclass
class ConsolidationConfig:
"""Configuration for memory consolidation."""
# Working -> Episodic thresholds
min_steps_for_episode: int = 2
min_duration_seconds: float = 5.0
# Episodic -> Semantic thresholds
min_confidence_for_fact: float = 0.6
max_facts_per_episode: int = 10
reinforce_existing_facts: bool = True
# Episodic -> Procedural thresholds
min_episodes_for_procedure: int = 3
min_success_rate_for_procedure: float = 0.7
min_steps_for_procedure: int = 2
# Pruning thresholds
max_episode_age_days: int = 90
min_importance_to_keep: float = 0.2
keep_all_failures: bool = True
keep_all_with_lessons: bool = True
# Batch sizes
batch_size: int = 100
@dataclass
class ConsolidationResult:
"""Result of a consolidation operation."""
source_type: str
target_type: str
items_processed: int = 0
items_created: int = 0
items_updated: int = 0
items_skipped: int = 0
items_pruned: int = 0
errors: list[str] = field(default_factory=list)
duration_seconds: float = 0.0
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"source_type": self.source_type,
"target_type": self.target_type,
"items_processed": self.items_processed,
"items_created": self.items_created,
"items_updated": self.items_updated,
"items_skipped": self.items_skipped,
"items_pruned": self.items_pruned,
"errors": self.errors,
"duration_seconds": self.duration_seconds,
}
@dataclass
class SessionConsolidationResult:
"""Result of consolidating a session's working memory to episodic."""
session_id: str
episode_created: bool = False
episode_id: UUID | None = None
scratchpad_entries: int = 0
variables_captured: int = 0
error: str | None = None
@dataclass
class NightlyConsolidationResult:
"""Result of nightly consolidation run."""
started_at: datetime
completed_at: datetime | None = None
episodic_to_semantic: ConsolidationResult | None = None
episodic_to_procedural: ConsolidationResult | None = None
pruning: ConsolidationResult | None = None
total_episodes_processed: int = 0
total_facts_created: int = 0
total_procedures_created: int = 0
total_pruned: int = 0
errors: list[str] = field(default_factory=list)
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"started_at": self.started_at.isoformat(),
"completed_at": self.completed_at.isoformat()
if self.completed_at
else None,
"episodic_to_semantic": (
self.episodic_to_semantic.to_dict()
if self.episodic_to_semantic
else None
),
"episodic_to_procedural": (
self.episodic_to_procedural.to_dict()
if self.episodic_to_procedural
else None
),
"pruning": self.pruning.to_dict() if self.pruning else None,
"total_episodes_processed": self.total_episodes_processed,
"total_facts_created": self.total_facts_created,
"total_procedures_created": self.total_procedures_created,
"total_pruned": self.total_pruned,
"errors": self.errors,
}
class MemoryConsolidationService:
"""
Service for consolidating memories between tiers.
Responsibilities:
- Transfer working memory to episodic at session end
- Extract facts from episodes to semantic memory
- Learn procedures from successful episode patterns
- Prune old/low-value memories
"""
def __init__(
self,
session: AsyncSession,
config: ConsolidationConfig | None = None,
embedding_generator: Any | None = None,
) -> None:
"""
Initialize consolidation service.
Args:
session: Database session
config: Consolidation configuration
embedding_generator: Optional embedding generator
"""
self._session = session
self._config = config or ConsolidationConfig()
self._embedding_generator = embedding_generator
self._fact_extractor: FactExtractor = get_fact_extractor()
# Memory services (lazy initialized)
self._episodic: EpisodicMemory | None = None
self._semantic: SemanticMemory | None = None
self._procedural: ProceduralMemory | None = None
async def _get_episodic(self) -> EpisodicMemory:
"""Get or create episodic memory service."""
if self._episodic is None:
self._episodic = await EpisodicMemory.create(
self._session, self._embedding_generator
)
return self._episodic
async def _get_semantic(self) -> SemanticMemory:
"""Get or create semantic memory service."""
if self._semantic is None:
self._semantic = await SemanticMemory.create(
self._session, self._embedding_generator
)
return self._semantic
async def _get_procedural(self) -> ProceduralMemory:
"""Get or create procedural memory service."""
if self._procedural is None:
self._procedural = await ProceduralMemory.create(
self._session, self._embedding_generator
)
return self._procedural
# =========================================================================
# Working -> Episodic Consolidation
# =========================================================================
async def consolidate_session(
self,
working_memory: WorkingMemory,
project_id: UUID,
session_id: str,
task_type: str = "session_task",
agent_instance_id: UUID | None = None,
agent_type_id: UUID | None = None,
) -> SessionConsolidationResult:
"""
Consolidate a session's working memory to episodic memory.
Called at session end to transfer relevant session data
into a persistent episode.
Args:
working_memory: The session's working memory
project_id: Project ID
session_id: Session ID
task_type: Type of task performed
agent_instance_id: Optional agent instance
agent_type_id: Optional agent type
Returns:
SessionConsolidationResult with outcome details
"""
result = SessionConsolidationResult(session_id=session_id)
try:
# Get task state
task_state = await working_memory.get_task_state()
# Check if there's enough content to consolidate
if not self._should_consolidate_session(task_state):
logger.debug(
f"Skipping consolidation for session {session_id}: insufficient content"
)
return result
# Gather scratchpad entries
scratchpad = await working_memory.get_scratchpad()
result.scratchpad_entries = len(scratchpad)
# Gather user variables
all_data = await working_memory.get_all()
result.variables_captured = len(all_data)
# Determine outcome
outcome = self._determine_session_outcome(task_state)
# Build actions from scratchpad and variables
actions = self._build_actions_from_session(scratchpad, all_data, task_state)
# Build context summary
context_summary = self._build_context_summary(task_state, all_data)
# Extract lessons learned
lessons = self._extract_session_lessons(task_state, outcome)
# Calculate importance
importance = self._calculate_session_importance(
task_state, outcome, actions
)
# Create episode
episode_data = EpisodeCreate(
project_id=project_id,
session_id=session_id,
task_type=task_type,
task_description=task_state.description
if task_state
else "Session task",
actions=actions,
context_summary=context_summary,
outcome=outcome,
outcome_details=task_state.status if task_state else "",
duration_seconds=self._calculate_duration(task_state),
tokens_used=0, # Would need to track this in working memory
lessons_learned=lessons,
importance_score=importance,
agent_instance_id=agent_instance_id,
agent_type_id=agent_type_id,
)
episodic = await self._get_episodic()
episode = await episodic.record_episode(episode_data)
result.episode_created = True
result.episode_id = episode.id
logger.info(
f"Consolidated session {session_id} to episode {episode.id} "
f"({len(actions)} actions, outcome={outcome.value})"
)
except Exception as e:
result.error = str(e)
logger.exception(f"Failed to consolidate session {session_id}")
return result
def _should_consolidate_session(self, task_state: TaskState | None) -> bool:
"""Check if session has enough content to consolidate."""
if task_state is None:
return False
# Check minimum steps
if task_state.current_step < self._config.min_steps_for_episode:
return False
return True
def _determine_session_outcome(self, task_state: TaskState | None) -> Outcome:
"""Determine outcome from task state."""
if task_state is None:
return Outcome.PARTIAL
status = task_state.status.lower() if task_state.status else ""
progress = task_state.progress_percent
if "success" in status or "complete" in status or progress >= 100:
return Outcome.SUCCESS
if "fail" in status or "error" in status:
return Outcome.FAILURE
if progress >= 50:
return Outcome.PARTIAL
return Outcome.FAILURE
def _build_actions_from_session(
self,
scratchpad: list[str],
variables: dict[str, Any],
task_state: TaskState | None,
) -> list[dict[str, Any]]:
"""Build action list from session data."""
actions: list[dict[str, Any]] = []
# Add scratchpad entries as actions
for i, entry in enumerate(scratchpad):
actions.append(
{
"step": i + 1,
"type": "reasoning",
"content": entry[:500], # Truncate long entries
}
)
# Add final state
if task_state:
actions.append(
{
"step": len(scratchpad) + 1,
"type": "final_state",
"current_step": task_state.current_step,
"total_steps": task_state.total_steps,
"progress": task_state.progress_percent,
"status": task_state.status,
}
)
return actions
def _build_context_summary(
self,
task_state: TaskState | None,
variables: dict[str, Any],
) -> str:
"""Build context summary from session data."""
parts = []
if task_state:
parts.append(f"Task: {task_state.description}")
parts.append(f"Progress: {task_state.progress_percent:.1f}%")
parts.append(f"Steps: {task_state.current_step}/{task_state.total_steps}")
# Include key variables
key_vars = {k: v for k, v in variables.items() if len(str(v)) < 100}
if key_vars:
var_str = ", ".join(f"{k}={v}" for k, v in list(key_vars.items())[:5])
parts.append(f"Variables: {var_str}")
return "; ".join(parts) if parts else "Session completed"
def _extract_session_lessons(
self,
task_state: TaskState | None,
outcome: Outcome,
) -> list[str]:
"""Extract lessons from session."""
lessons: list[str] = []
if task_state and task_state.status:
if outcome == Outcome.FAILURE:
lessons.append(
f"Task failed at step {task_state.current_step}: {task_state.status}"
)
elif outcome == Outcome.SUCCESS:
lessons.append(
f"Successfully completed in {task_state.current_step} steps"
)
return lessons
def _calculate_session_importance(
self,
task_state: TaskState | None,
outcome: Outcome,
actions: list[dict[str, Any]],
) -> float:
"""Calculate importance score for session."""
score = 0.5 # Base score
# Failures are important to learn from
if outcome == Outcome.FAILURE:
score += 0.3
# Many steps means complex task
if task_state and task_state.total_steps >= 5:
score += 0.1
# Many actions means detailed reasoning
if len(actions) >= 5:
score += 0.1
return min(1.0, score)
def _calculate_duration(self, task_state: TaskState | None) -> float:
"""Calculate session duration."""
if task_state is None:
return 0.0
if task_state.started_at and task_state.updated_at:
delta = task_state.updated_at - task_state.started_at
return delta.total_seconds()
return 0.0
# =========================================================================
# Episodic -> Semantic Consolidation
# =========================================================================
async def consolidate_episodes_to_facts(
self,
project_id: UUID,
since: datetime | None = None,
limit: int | None = None,
) -> ConsolidationResult:
"""
Extract facts from episodic memories to semantic memory.
Args:
project_id: Project to consolidate
since: Only process episodes since this time
limit: Maximum episodes to process
Returns:
ConsolidationResult with extraction statistics
"""
start_time = datetime.now(UTC)
result = ConsolidationResult(
source_type="episodic",
target_type="semantic",
)
try:
episodic = await self._get_episodic()
semantic = await self._get_semantic()
# Get episodes to process
since_time = since or datetime.now(UTC) - timedelta(days=1)
episodes = await episodic.get_recent(
project_id,
limit=limit or self._config.batch_size,
since=since_time,
)
for episode in episodes:
result.items_processed += 1
try:
# Extract facts using the extractor
extracted_facts = self._fact_extractor.extract_from_episode(episode)
for extracted_fact in extracted_facts:
if (
extracted_fact.confidence
< self._config.min_confidence_for_fact
):
result.items_skipped += 1
continue
# Create fact (store_fact handles deduplication/reinforcement)
fact_create = extracted_fact.to_fact_create(
project_id=project_id,
source_episode_ids=[episode.id],
)
# store_fact automatically reinforces if fact already exists
fact = await semantic.store_fact(fact_create)
# Check if this was a new fact or reinforced existing
if fact.reinforcement_count == 1:
result.items_created += 1
else:
result.items_updated += 1
except Exception as e:
result.errors.append(f"Episode {episode.id}: {e}")
logger.warning(
f"Failed to extract facts from episode {episode.id}: {e}"
)
except Exception as e:
result.errors.append(f"Consolidation failed: {e}")
logger.exception("Failed episodic -> semantic consolidation")
result.duration_seconds = (datetime.now(UTC) - start_time).total_seconds()
logger.info(
f"Episodic -> Semantic consolidation: "
f"{result.items_processed} processed, "
f"{result.items_created} created, "
f"{result.items_updated} reinforced"
)
return result
# =========================================================================
# Episodic -> Procedural Consolidation
# =========================================================================
async def consolidate_episodes_to_procedures(
self,
project_id: UUID,
agent_type_id: UUID | None = None,
since: datetime | None = None,
) -> ConsolidationResult:
"""
Learn procedures from patterns in episodic memories.
Identifies recurring successful patterns and creates/updates
procedures to capture them.
Args:
project_id: Project to consolidate
agent_type_id: Optional filter by agent type
since: Only process episodes since this time
Returns:
ConsolidationResult with procedure statistics
"""
start_time = datetime.now(UTC)
result = ConsolidationResult(
source_type="episodic",
target_type="procedural",
)
try:
episodic = await self._get_episodic()
procedural = await self._get_procedural()
# Get successful episodes
since_time = since or datetime.now(UTC) - timedelta(days=7)
episodes = await episodic.get_by_outcome(
project_id,
outcome=Outcome.SUCCESS,
limit=self._config.batch_size,
agent_instance_id=None, # Get all agent instances
)
# Group by task type
task_groups: dict[str, list[Episode]] = {}
for episode in episodes:
if episode.occurred_at >= since_time:
if episode.task_type not in task_groups:
task_groups[episode.task_type] = []
task_groups[episode.task_type].append(episode)
result.items_processed = len(episodes)
# Process each task type group
for task_type, group in task_groups.items():
if len(group) < self._config.min_episodes_for_procedure:
result.items_skipped += len(group)
continue
try:
procedure_result = await self._learn_procedure_from_episodes(
procedural,
project_id,
agent_type_id,
task_type,
group,
)
if procedure_result == "created":
result.items_created += 1
elif procedure_result == "updated":
result.items_updated += 1
else:
result.items_skipped += 1
except Exception as e:
result.errors.append(f"Task type '{task_type}': {e}")
logger.warning(f"Failed to learn procedure for '{task_type}': {e}")
except Exception as e:
result.errors.append(f"Consolidation failed: {e}")
logger.exception("Failed episodic -> procedural consolidation")
result.duration_seconds = (datetime.now(UTC) - start_time).total_seconds()
logger.info(
f"Episodic -> Procedural consolidation: "
f"{result.items_processed} processed, "
f"{result.items_created} created, "
f"{result.items_updated} updated"
)
return result
async def _learn_procedure_from_episodes(
self,
procedural: ProceduralMemory,
project_id: UUID,
agent_type_id: UUID | None,
task_type: str,
episodes: list[Episode],
) -> str:
"""Learn or update a procedure from a set of episodes."""
# Calculate success rate for this pattern
success_count = sum(1 for e in episodes if e.outcome == Outcome.SUCCESS)
total_count = len(episodes)
success_rate = success_count / total_count if total_count > 0 else 0
if success_rate < self._config.min_success_rate_for_procedure:
return "skipped"
# Extract common steps from episodes
steps = self._extract_common_steps(episodes)
if len(steps) < self._config.min_steps_for_procedure:
return "skipped"
# Check for existing procedure
matching = await procedural.find_matching(
context=task_type,
project_id=project_id,
agent_type_id=agent_type_id,
limit=1,
)
if matching:
# Update existing procedure with new success
await procedural.record_outcome(
matching[0].id,
success=True,
)
return "updated"
else:
# Create new procedure
# Note: success_count starts at 1 in record_procedure
procedure_data = ProcedureCreate(
project_id=project_id,
agent_type_id=agent_type_id,
name=f"Procedure for {task_type}",
trigger_pattern=task_type,
steps=steps,
)
await procedural.record_procedure(procedure_data)
return "created"
def _extract_common_steps(self, episodes: list[Episode]) -> list[dict[str, Any]]:
"""Extract common action steps from multiple episodes."""
# Simple heuristic: take the steps from the most successful episode
# with the most detailed actions
best_episode = max(
episodes,
key=lambda e: (
e.outcome == Outcome.SUCCESS,
len(e.actions),
e.importance_score,
),
)
steps: list[dict[str, Any]] = []
for i, action in enumerate(best_episode.actions):
step = {
"order": i + 1,
"action": action.get("type", "action"),
"description": action.get("content", str(action))[:500],
"parameters": action,
}
steps.append(step)
return steps
# =========================================================================
# Memory Pruning
# =========================================================================
async def prune_old_episodes(
self,
project_id: UUID,
max_age_days: int | None = None,
min_importance: float | None = None,
) -> ConsolidationResult:
"""
Prune old, low-value episodes.
Args:
project_id: Project to prune
max_age_days: Maximum age in days (default from config)
min_importance: Minimum importance to keep (default from config)
Returns:
ConsolidationResult with pruning statistics
"""
start_time = datetime.now(UTC)
result = ConsolidationResult(
source_type="episodic",
target_type="pruned",
)
max_age = max_age_days or self._config.max_episode_age_days
min_imp = min_importance or self._config.min_importance_to_keep
cutoff_date = datetime.now(UTC) - timedelta(days=max_age)
try:
episodic = await self._get_episodic()
# Get old episodes
# Note: In production, this would use a more efficient query
all_episodes = await episodic.get_recent(
project_id,
limit=self._config.batch_size * 10,
since=cutoff_date - timedelta(days=365), # Search past year
)
for episode in all_episodes:
result.items_processed += 1
# Check if should be pruned
if not self._should_prune_episode(episode, cutoff_date, min_imp):
result.items_skipped += 1
continue
try:
deleted = await episodic.delete(episode.id)
if deleted:
result.items_pruned += 1
else:
result.items_skipped += 1
except Exception as e:
result.errors.append(f"Episode {episode.id}: {e}")
except Exception as e:
result.errors.append(f"Pruning failed: {e}")
logger.exception("Failed episode pruning")
result.duration_seconds = (datetime.now(UTC) - start_time).total_seconds()
logger.info(
f"Episode pruning: {result.items_processed} processed, "
f"{result.items_pruned} pruned"
)
return result
def _should_prune_episode(
self,
episode: Episode,
cutoff_date: datetime,
min_importance: float,
) -> bool:
"""Determine if an episode should be pruned."""
# Keep recent episodes
if episode.occurred_at >= cutoff_date:
return False
# Keep failures if configured
if self._config.keep_all_failures and episode.outcome == Outcome.FAILURE:
return False
# Keep episodes with lessons if configured
if self._config.keep_all_with_lessons and episode.lessons_learned:
return False
# Keep high-importance episodes
if episode.importance_score >= min_importance:
return False
return True
# =========================================================================
# Nightly Consolidation
# =========================================================================
async def run_nightly_consolidation(
self,
project_id: UUID,
agent_type_id: UUID | None = None,
) -> NightlyConsolidationResult:
"""
Run full nightly consolidation workflow.
This includes:
1. Extract facts from recent episodes
2. Learn procedures from successful patterns
3. Prune old, low-value memories
Args:
project_id: Project to consolidate
agent_type_id: Optional agent type filter
Returns:
NightlyConsolidationResult with all outcomes
"""
result = NightlyConsolidationResult(started_at=datetime.now(UTC))
logger.info(f"Starting nightly consolidation for project {project_id}")
try:
# Step 1: Episodic -> Semantic (last 24 hours)
since_yesterday = datetime.now(UTC) - timedelta(days=1)
result.episodic_to_semantic = await self.consolidate_episodes_to_facts(
project_id=project_id,
since=since_yesterday,
)
result.total_facts_created = result.episodic_to_semantic.items_created
# Step 2: Episodic -> Procedural (last 7 days)
since_week = datetime.now(UTC) - timedelta(days=7)
result.episodic_to_procedural = (
await self.consolidate_episodes_to_procedures(
project_id=project_id,
agent_type_id=agent_type_id,
since=since_week,
)
)
result.total_procedures_created = (
result.episodic_to_procedural.items_created
)
# Step 3: Prune old memories
result.pruning = await self.prune_old_episodes(project_id=project_id)
result.total_pruned = result.pruning.items_pruned
# Calculate totals
result.total_episodes_processed = (
result.episodic_to_semantic.items_processed
if result.episodic_to_semantic
else 0
) + (
result.episodic_to_procedural.items_processed
if result.episodic_to_procedural
else 0
)
# Collect all errors
if result.episodic_to_semantic and result.episodic_to_semantic.errors:
result.errors.extend(result.episodic_to_semantic.errors)
if result.episodic_to_procedural and result.episodic_to_procedural.errors:
result.errors.extend(result.episodic_to_procedural.errors)
if result.pruning and result.pruning.errors:
result.errors.extend(result.pruning.errors)
except Exception as e:
result.errors.append(f"Nightly consolidation failed: {e}")
logger.exception("Nightly consolidation failed")
result.completed_at = datetime.now(UTC)
duration = (result.completed_at - result.started_at).total_seconds()
logger.info(
f"Nightly consolidation completed in {duration:.1f}s: "
f"{result.total_facts_created} facts, "
f"{result.total_procedures_created} procedures, "
f"{result.total_pruned} pruned"
)
return result
# Factory function - no singleton to avoid stale session issues
async def get_consolidation_service(
session: AsyncSession,
config: ConsolidationConfig | None = None,
) -> MemoryConsolidationService:
"""
Create a memory consolidation service for the given session.
Note: This creates a new instance each time to avoid stale session issues.
The service is lightweight and safe to recreate per request.
Args:
session: Database session (must be active)
config: Optional configuration
Returns:
MemoryConsolidationService instance
"""
return MemoryConsolidationService(session=session, config=config)
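# Minimal usage sketch, assuming an async_sessionmaker named `session_factory`
# and a real project UUID; only get_consolidation_service and
# run_nightly_consolidation come from this module.
async def _run_nightly_for_project(session_factory, project_id):
    async with session_factory() as session:
        service = await get_consolidation_service(session)
        report = await service.run_nightly_consolidation(project_id)
        # The code above only flushes the session, so committing stays with the caller.
        await session.commit()
        return report.to_dict()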

View File

@@ -0,0 +1,17 @@
# app/services/memory/episodic/__init__.py
"""
Episodic Memory Package.
Provides experiential memory storage and retrieval for agent learning.
"""
from .memory import EpisodicMemory
from .recorder import EpisodeRecorder
from .retrieval import EpisodeRetriever, RetrievalStrategy
__all__ = [
"EpisodeRecorder",
"EpisodeRetriever",
"EpisodicMemory",
"RetrievalStrategy",
]

View File

@@ -0,0 +1,490 @@
# app/services/memory/episodic/memory.py
"""
Episodic Memory Implementation.
Provides experiential memory storage and retrieval for agent learning.
Combines episode recording and retrieval into a unified interface.
"""
import logging
from datetime import datetime
from typing import Any
from uuid import UUID
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.memory.types import Episode, EpisodeCreate, Outcome, RetrievalResult
from .recorder import EpisodeRecorder
from .retrieval import EpisodeRetriever, RetrievalStrategy
logger = logging.getLogger(__name__)
class EpisodicMemory:
"""
Episodic Memory Service.
Provides experiential memory for agent learning:
- Record task completions with context
- Store failures with error context
- Retrieve by semantic similarity
- Retrieve by recency, outcome, task type
- Track importance scores
- Extract lessons learned
Performance target: <100ms P95 for retrieval
"""
def __init__(
self,
session: AsyncSession,
embedding_generator: Any | None = None,
) -> None:
"""
Initialize episodic memory.
Args:
session: Database session
embedding_generator: Optional embedding generator for semantic search
"""
self._session = session
self._embedding_generator = embedding_generator
self._recorder = EpisodeRecorder(session, embedding_generator)
self._retriever = EpisodeRetriever(session, embedding_generator)
@classmethod
async def create(
cls,
session: AsyncSession,
embedding_generator: Any | None = None,
) -> "EpisodicMemory":
"""
Factory method to create EpisodicMemory.
Args:
session: Database session
embedding_generator: Optional embedding generator
Returns:
Configured EpisodicMemory instance
"""
return cls(session=session, embedding_generator=embedding_generator)
# =========================================================================
# Recording Operations
# =========================================================================
async def record_episode(self, episode: EpisodeCreate) -> Episode:
"""
Record a new episode.
Args:
episode: Episode data to record
Returns:
The created episode with assigned ID
"""
return await self._recorder.record(episode)
async def record_success(
self,
project_id: UUID,
session_id: str,
task_type: str,
task_description: str,
actions: list[dict[str, Any]],
context_summary: str,
outcome_details: str = "",
duration_seconds: float = 0.0,
tokens_used: int = 0,
lessons_learned: list[str] | None = None,
agent_instance_id: UUID | None = None,
agent_type_id: UUID | None = None,
) -> Episode:
"""
Convenience method to record a successful episode.
Args:
project_id: Project ID
session_id: Session ID
task_type: Type of task
task_description: Task description
actions: Actions taken
context_summary: Context summary
outcome_details: Optional outcome details
duration_seconds: Task duration
tokens_used: Tokens consumed
lessons_learned: Optional lessons
agent_instance_id: Optional agent instance
agent_type_id: Optional agent type
Returns:
The created episode
"""
episode_data = EpisodeCreate(
project_id=project_id,
session_id=session_id,
task_type=task_type,
task_description=task_description,
actions=actions,
context_summary=context_summary,
outcome=Outcome.SUCCESS,
outcome_details=outcome_details,
duration_seconds=duration_seconds,
tokens_used=tokens_used,
lessons_learned=lessons_learned or [],
agent_instance_id=agent_instance_id,
agent_type_id=agent_type_id,
)
return await self.record_episode(episode_data)
async def record_failure(
self,
project_id: UUID,
session_id: str,
task_type: str,
task_description: str,
actions: list[dict[str, Any]],
context_summary: str,
error_details: str,
duration_seconds: float = 0.0,
tokens_used: int = 0,
lessons_learned: list[str] | None = None,
agent_instance_id: UUID | None = None,
agent_type_id: UUID | None = None,
) -> Episode:
"""
Convenience method to record a failed episode.
Args:
project_id: Project ID
session_id: Session ID
task_type: Type of task
task_description: Task description
actions: Actions taken before failure
context_summary: Context summary
error_details: Error details
duration_seconds: Task duration
tokens_used: Tokens consumed
lessons_learned: Optional lessons from failure
agent_instance_id: Optional agent instance
agent_type_id: Optional agent type
Returns:
The created episode
"""
episode_data = EpisodeCreate(
project_id=project_id,
session_id=session_id,
task_type=task_type,
task_description=task_description,
actions=actions,
context_summary=context_summary,
outcome=Outcome.FAILURE,
outcome_details=error_details,
duration_seconds=duration_seconds,
tokens_used=tokens_used,
lessons_learned=lessons_learned or [],
agent_instance_id=agent_instance_id,
agent_type_id=agent_type_id,
)
return await self.record_episode(episode_data)
# =========================================================================
# Retrieval Operations
# =========================================================================
async def search_similar(
self,
project_id: UUID,
query: str,
limit: int = 10,
agent_instance_id: UUID | None = None,
) -> list[Episode]:
"""
Search for semantically similar episodes.
Args:
project_id: Project to search within
query: Search query
limit: Maximum results
agent_instance_id: Optional filter by agent instance
Returns:
List of similar episodes
"""
result = await self._retriever.search_similar(
project_id, query, limit, agent_instance_id
)
return result.items
async def get_recent(
self,
project_id: UUID,
limit: int = 10,
since: datetime | None = None,
agent_instance_id: UUID | None = None,
) -> list[Episode]:
"""
Get recent episodes.
Args:
project_id: Project to search within
limit: Maximum results
since: Optional time filter
agent_instance_id: Optional filter by agent instance
Returns:
List of recent episodes
"""
result = await self._retriever.get_recent(
project_id, limit, since, agent_instance_id
)
return result.items
async def get_by_outcome(
self,
project_id: UUID,
outcome: Outcome,
limit: int = 10,
agent_instance_id: UUID | None = None,
) -> list[Episode]:
"""
Get episodes by outcome.
Args:
project_id: Project to search within
outcome: Outcome to filter by
limit: Maximum results
agent_instance_id: Optional filter by agent instance
Returns:
List of episodes with specified outcome
"""
result = await self._retriever.get_by_outcome(
project_id, outcome, limit, agent_instance_id
)
return result.items
async def get_by_task_type(
self,
project_id: UUID,
task_type: str,
limit: int = 10,
agent_instance_id: UUID | None = None,
) -> list[Episode]:
"""
Get episodes by task type.
Args:
project_id: Project to search within
task_type: Task type to filter by
limit: Maximum results
agent_instance_id: Optional filter by agent instance
Returns:
List of episodes with specified task type
"""
result = await self._retriever.get_by_task_type(
project_id, task_type, limit, agent_instance_id
)
return result.items
async def get_important(
self,
project_id: UUID,
limit: int = 10,
min_importance: float = 0.7,
agent_instance_id: UUID | None = None,
) -> list[Episode]:
"""
Get high-importance episodes.
Args:
project_id: Project to search within
limit: Maximum results
min_importance: Minimum importance score
agent_instance_id: Optional filter by agent instance
Returns:
List of important episodes
"""
result = await self._retriever.get_important(
project_id, limit, min_importance, agent_instance_id
)
return result.items
async def retrieve(
self,
project_id: UUID,
strategy: RetrievalStrategy = RetrievalStrategy.RECENCY,
limit: int = 10,
**kwargs: Any,
) -> RetrievalResult[Episode]:
"""
Retrieve episodes with full result metadata.
Args:
project_id: Project to search within
strategy: Retrieval strategy
limit: Maximum results
**kwargs: Strategy-specific parameters
Returns:
RetrievalResult with episodes and metadata
"""
return await self._retriever.retrieve(project_id, strategy, limit, **kwargs)
# =========================================================================
# Modification Operations
# =========================================================================
async def get_by_id(self, episode_id: UUID) -> Episode | None:
"""Get an episode by ID."""
return await self._recorder.get_by_id(episode_id)
async def update_importance(
self,
episode_id: UUID,
importance_score: float,
) -> Episode | None:
"""
Update an episode's importance score.
Args:
episode_id: Episode to update
importance_score: New importance score (0.0 to 1.0)
Returns:
Updated episode or None if not found
"""
return await self._recorder.update_importance(episode_id, importance_score)
async def add_lessons(
self,
episode_id: UUID,
lessons: list[str],
) -> Episode | None:
"""
Add lessons learned to an episode.
Args:
episode_id: Episode to update
lessons: Lessons to add
Returns:
Updated episode or None if not found
"""
return await self._recorder.add_lessons(episode_id, lessons)
async def delete(self, episode_id: UUID) -> bool:
"""
Delete an episode.
Args:
episode_id: Episode to delete
Returns:
True if deleted
"""
return await self._recorder.delete(episode_id)
# =========================================================================
# Summarization
# =========================================================================
async def summarize_episodes(
self,
episode_ids: list[UUID],
) -> str:
"""
Summarize multiple episodes into a consolidated view.
Args:
episode_ids: Episodes to summarize
Returns:
Summary text
"""
if not episode_ids:
return "No episodes to summarize."
episodes: list[Episode] = []
for episode_id in episode_ids:
episode = await self.get_by_id(episode_id)
if episode:
episodes.append(episode)
if not episodes:
return "No episodes found."
# Build summary
lines = [f"Summary of {len(episodes)} episodes:", ""]
# Outcome breakdown
success = sum(1 for e in episodes if e.outcome == Outcome.SUCCESS)
failure = sum(1 for e in episodes if e.outcome == Outcome.FAILURE)
partial = sum(1 for e in episodes if e.outcome == Outcome.PARTIAL)
lines.append(
f"Outcomes: {success} success, {failure} failure, {partial} partial"
)
# Task types
task_types = {e.task_type for e in episodes}
lines.append(f"Task types: {', '.join(sorted(task_types))}")
# Aggregate lessons
all_lessons: list[str] = []
for e in episodes:
all_lessons.extend(e.lessons_learned)
if all_lessons:
lines.append("")
lines.append("Key lessons learned:")
# Deduplicate lessons
unique_lessons = list(dict.fromkeys(all_lessons))
for lesson in unique_lessons[:10]: # Top 10
lines.append(f" - {lesson}")
# Duration and tokens
total_duration = sum(e.duration_seconds for e in episodes)
total_tokens = sum(e.tokens_used for e in episodes)
lines.append("")
lines.append(f"Total duration: {total_duration:.1f}s")
lines.append(f"Total tokens: {total_tokens:,}")
return "\n".join(lines)
# =========================================================================
# Statistics
# =========================================================================
async def get_stats(self, project_id: UUID) -> dict[str, Any]:
"""
Get episode statistics for a project.
Args:
project_id: Project to get stats for
Returns:
Dictionary with episode statistics
"""
return await self._recorder.get_stats(project_id)
async def count(
self,
project_id: UUID,
since: datetime | None = None,
) -> int:
"""
Count episodes for a project.
Args:
project_id: Project to count for
since: Optional time filter
Returns:
Number of episodes
"""
return await self._recorder.count_by_project(project_id, since)
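# Minimal usage sketch, assuming an active AsyncSession named `session`; the
# project identifier below is a random placeholder.
from uuid import uuid4


async def _demo_episodic_memory(session):
    memory = await EpisodicMemory.create(session)  # no embedding generator, so semantic search is skipped
    project_id = uuid4()
    await memory.record_success(
        project_id=project_id,
        session_id="demo-session",
        task_type="code_review",
        task_description="Review a small documentation change",
        actions=[{"step": 1, "type": "reasoning", "content": "Checked the rendered docs"}],
        context_summary="Docs-only change",
        lessons_learned=["Preview the docs build before approving"],
    )
    recent = await memory.get_recent(project_id, limit=5)
    return recent, await memory.get_stats(project_id)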

View File

@@ -0,0 +1,357 @@
# app/services/memory/episodic/recorder.py
"""
Episode Recording.
Handles the creation and storage of episodic memories
during agent task execution.
"""
import logging
from datetime import UTC, datetime
from typing import Any
from uuid import UUID, uuid4
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.memory.enums import EpisodeOutcome
from app.models.memory.episode import Episode as EpisodeModel
from app.services.memory.config import get_memory_settings
from app.services.memory.types import Episode, EpisodeCreate, Outcome
logger = logging.getLogger(__name__)
def _outcome_to_db(outcome: Outcome) -> EpisodeOutcome:
"""Convert service Outcome to database EpisodeOutcome."""
return EpisodeOutcome(outcome.value)
def _db_to_outcome(db_outcome: EpisodeOutcome) -> Outcome:
"""Convert database EpisodeOutcome to service Outcome."""
return Outcome(db_outcome.value)
def _model_to_episode(model: EpisodeModel) -> Episode:
"""Convert SQLAlchemy model to Episode dataclass."""
# SQLAlchemy Column types are inferred as Column[T] by mypy, but at runtime
# they return actual values. We use type: ignore to handle this mismatch.
return Episode(
id=model.id, # type: ignore[arg-type]
project_id=model.project_id, # type: ignore[arg-type]
agent_instance_id=model.agent_instance_id, # type: ignore[arg-type]
agent_type_id=model.agent_type_id, # type: ignore[arg-type]
session_id=model.session_id, # type: ignore[arg-type]
task_type=model.task_type, # type: ignore[arg-type]
task_description=model.task_description, # type: ignore[arg-type]
actions=model.actions or [], # type: ignore[arg-type]
context_summary=model.context_summary, # type: ignore[arg-type]
outcome=_db_to_outcome(model.outcome), # type: ignore[arg-type]
outcome_details=model.outcome_details or "", # type: ignore[arg-type]
duration_seconds=model.duration_seconds, # type: ignore[arg-type]
tokens_used=model.tokens_used, # type: ignore[arg-type]
lessons_learned=model.lessons_learned or [], # type: ignore[arg-type]
importance_score=model.importance_score, # type: ignore[arg-type]
embedding=None, # Don't expose raw embedding
occurred_at=model.occurred_at, # type: ignore[arg-type]
created_at=model.created_at, # type: ignore[arg-type]
updated_at=model.updated_at, # type: ignore[arg-type]
)
class EpisodeRecorder:
"""
Records episodes to the database.
Handles episode creation, importance scoring,
and lesson extraction.
"""
def __init__(
self,
session: AsyncSession,
embedding_generator: Any | None = None,
) -> None:
"""
Initialize recorder.
Args:
session: Database session
embedding_generator: Optional embedding generator for semantic indexing
"""
self._session = session
self._embedding_generator = embedding_generator
self._settings = get_memory_settings()
async def record(self, episode_data: EpisodeCreate) -> Episode:
"""
Record a new episode.
Args:
episode_data: Episode data to record
Returns:
The created episode
"""
now = datetime.now(UTC)
# Calculate importance score if not provided
importance = episode_data.importance_score
if importance == 0.5: # Default value, calculate
importance = self._calculate_importance(episode_data)
# Create the model
model = EpisodeModel(
id=uuid4(),
project_id=episode_data.project_id,
agent_instance_id=episode_data.agent_instance_id,
agent_type_id=episode_data.agent_type_id,
session_id=episode_data.session_id,
task_type=episode_data.task_type,
task_description=episode_data.task_description,
actions=episode_data.actions,
context_summary=episode_data.context_summary,
outcome=_outcome_to_db(episode_data.outcome),
outcome_details=episode_data.outcome_details,
duration_seconds=episode_data.duration_seconds,
tokens_used=episode_data.tokens_used,
lessons_learned=episode_data.lessons_learned,
importance_score=importance,
occurred_at=now,
created_at=now,
updated_at=now,
)
# Generate embedding if generator available
if self._embedding_generator is not None:
try:
text_for_embedding = self._create_embedding_text(episode_data)
embedding = await self._embedding_generator.generate(text_for_embedding)
model.embedding = embedding
except Exception as e:
logger.warning(f"Failed to generate embedding: {e}")
self._session.add(model)
await self._session.flush()
await self._session.refresh(model)
logger.debug(f"Recorded episode {model.id} for task {model.task_type}")
return _model_to_episode(model)
def _calculate_importance(self, episode_data: EpisodeCreate) -> float:
"""
Calculate importance score for an episode.
Factors:
- Outcome: Failures are more important to learn from
- Duration: Longer tasks may be more significant
- Token usage: Higher usage may indicate complexity
- Lessons learned: Episodes with lessons are more valuable
"""
score = 0.5 # Base score
# Outcome factor
if episode_data.outcome == Outcome.FAILURE:
score += 0.2 # Failures are important for learning
elif episode_data.outcome == Outcome.PARTIAL:
score += 0.1
# Success is default, no adjustment
# Lessons learned factor
if episode_data.lessons_learned:
score += min(0.15, len(episode_data.lessons_learned) * 0.05)
# Duration factor (longer tasks may be more significant)
if episode_data.duration_seconds > 60:
score += 0.05
if episode_data.duration_seconds > 300:
score += 0.05
# Token usage factor (complex tasks)
if episode_data.tokens_used > 1000:
score += 0.05
# Clamp to valid range
return min(1.0, max(0.0, score))
def _create_embedding_text(self, episode_data: EpisodeCreate) -> str:
"""Create text representation for embedding generation."""
parts = [
f"Task: {episode_data.task_type}",
f"Description: {episode_data.task_description}",
f"Context: {episode_data.context_summary}",
f"Outcome: {episode_data.outcome.value}",
]
if episode_data.outcome_details:
parts.append(f"Details: {episode_data.outcome_details}")
if episode_data.lessons_learned:
parts.append(f"Lessons: {', '.join(episode_data.lessons_learned)}")
return "\n".join(parts)
async def get_by_id(self, episode_id: UUID) -> Episode | None:
"""Get an episode by ID."""
query = select(EpisodeModel).where(EpisodeModel.id == episode_id)
result = await self._session.execute(query)
model = result.scalar_one_or_none()
if model is None:
return None
return _model_to_episode(model)
async def update_importance(
self,
episode_id: UUID,
importance_score: float,
) -> Episode | None:
"""
Update the importance score of an episode.
Args:
episode_id: Episode to update
importance_score: New importance score (0.0 to 1.0)
Returns:
Updated episode or None if not found
"""
# Validate score
importance_score = min(1.0, max(0.0, importance_score))
stmt = (
update(EpisodeModel)
.where(EpisodeModel.id == episode_id)
.values(
importance_score=importance_score,
updated_at=datetime.now(UTC),
)
.returning(EpisodeModel)
)
result = await self._session.execute(stmt)
model = result.scalar_one_or_none()
if model is None:
return None
await self._session.flush()
return _model_to_episode(model)
async def add_lessons(
self,
episode_id: UUID,
lessons: list[str],
) -> Episode | None:
"""
Add lessons learned to an episode.
Args:
episode_id: Episode to update
lessons: New lessons to add
Returns:
Updated episode or None if not found
"""
# Get current episode
query = select(EpisodeModel).where(EpisodeModel.id == episode_id)
result = await self._session.execute(query)
model = result.scalar_one_or_none()
if model is None:
return None
# Append lessons
current_lessons: list[str] = model.lessons_learned or [] # type: ignore[assignment]
updated_lessons = current_lessons + lessons
stmt = (
update(EpisodeModel)
.where(EpisodeModel.id == episode_id)
.values(
lessons_learned=updated_lessons,
updated_at=datetime.now(UTC),
)
.returning(EpisodeModel)
)
result = await self._session.execute(stmt)
model = result.scalar_one_or_none()
await self._session.flush()
return _model_to_episode(model) if model else None
async def delete(self, episode_id: UUID) -> bool:
"""
Delete an episode.
Args:
episode_id: Episode to delete
Returns:
True if deleted
"""
query = select(EpisodeModel).where(EpisodeModel.id == episode_id)
result = await self._session.execute(query)
model = result.scalar_one_or_none()
if model is None:
return False
await self._session.delete(model)
await self._session.flush()
return True
async def count_by_project(
self,
project_id: UUID,
since: datetime | None = None,
) -> int:
"""Count episodes for a project."""
query = select(EpisodeModel).where(EpisodeModel.project_id == project_id)
if since is not None:
query = query.where(EpisodeModel.occurred_at >= since)
result = await self._session.execute(query)
return len(list(result.scalars().all()))
async def get_stats(self, project_id: UUID) -> dict[str, Any]:
"""
Get statistics for a project's episodes.
Returns:
Dictionary with episode statistics
"""
query = select(EpisodeModel).where(EpisodeModel.project_id == project_id)
result = await self._session.execute(query)
episodes = list(result.scalars().all())
if not episodes:
return {
"total_count": 0,
"success_count": 0,
"failure_count": 0,
"partial_count": 0,
"avg_importance": 0.0,
"avg_duration": 0.0,
"total_tokens": 0,
}
success_count = sum(1 for e in episodes if e.outcome == EpisodeOutcome.SUCCESS)
failure_count = sum(1 for e in episodes if e.outcome == EpisodeOutcome.FAILURE)
partial_count = sum(1 for e in episodes if e.outcome == EpisodeOutcome.PARTIAL)
avg_importance = sum(e.importance_score for e in episodes) / len(episodes)
avg_duration = sum(e.duration_seconds for e in episodes) / len(episodes)
total_tokens = sum(e.tokens_used for e in episodes)
return {
"total_count": len(episodes),
"success_count": success_count,
"failure_count": failure_count,
"partial_count": partial_count,
"avg_importance": avg_importance,
"avg_duration": avg_duration,
"total_tokens": total_tokens,
}
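# Worked example of the importance heuristic in _calculate_importance above
# (a sketch of the arithmetic, not additional behavior): a failed task that ran
# for 6 minutes with two lessons learned scores
#   0.5 (base) + 0.2 (failure) + 0.10 (two lessons, capped at 0.15)
#   + 0.05 (>60s) + 0.05 (>300s) = 0.90,
# and the final value is clamped to the [0.0, 1.0] range.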

View File

@@ -0,0 +1,503 @@
# app/services/memory/episodic/retrieval.py
"""
Episode Retrieval Strategies.
Provides different retrieval strategies for finding relevant episodes:
- Semantic similarity (vector search)
- Recency-based
- Outcome-based filtering
- Importance-based ranking
"""
import logging
import time
from abc import ABC, abstractmethod
from datetime import datetime
from enum import Enum
from typing import Any
from uuid import UUID
from sqlalchemy import and_, desc, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.memory.enums import EpisodeOutcome
from app.models.memory.episode import Episode as EpisodeModel
from app.services.memory.types import Episode, Outcome, RetrievalResult
logger = logging.getLogger(__name__)
class RetrievalStrategy(str, Enum):
"""Retrieval strategy types."""
SEMANTIC = "semantic"
RECENCY = "recency"
OUTCOME = "outcome"
IMPORTANCE = "importance"
HYBRID = "hybrid"
def _model_to_episode(model: EpisodeModel) -> Episode:
"""Convert SQLAlchemy model to Episode dataclass."""
# SQLAlchemy Column types are inferred as Column[T] by mypy, but at runtime
# they return actual values. We use type: ignore to handle this mismatch.
return Episode(
id=model.id, # type: ignore[arg-type]
project_id=model.project_id, # type: ignore[arg-type]
agent_instance_id=model.agent_instance_id, # type: ignore[arg-type]
agent_type_id=model.agent_type_id, # type: ignore[arg-type]
session_id=model.session_id, # type: ignore[arg-type]
task_type=model.task_type, # type: ignore[arg-type]
task_description=model.task_description, # type: ignore[arg-type]
actions=model.actions or [], # type: ignore[arg-type]
context_summary=model.context_summary, # type: ignore[arg-type]
outcome=Outcome(model.outcome.value),
outcome_details=model.outcome_details or "", # type: ignore[arg-type]
duration_seconds=model.duration_seconds, # type: ignore[arg-type]
tokens_used=model.tokens_used, # type: ignore[arg-type]
lessons_learned=model.lessons_learned or [], # type: ignore[arg-type]
importance_score=model.importance_score, # type: ignore[arg-type]
embedding=None, # Don't expose raw embedding
occurred_at=model.occurred_at, # type: ignore[arg-type]
created_at=model.created_at, # type: ignore[arg-type]
updated_at=model.updated_at, # type: ignore[arg-type]
)
class BaseRetriever(ABC):
"""Abstract base class for episode retrieval strategies."""
@abstractmethod
async def retrieve(
self,
session: AsyncSession,
project_id: UUID,
limit: int = 10,
**kwargs: Any,
) -> RetrievalResult[Episode]:
"""Retrieve episodes based on the strategy."""
...
class RecencyRetriever(BaseRetriever):
"""Retrieves episodes by recency (most recent first)."""
async def retrieve(
self,
session: AsyncSession,
project_id: UUID,
limit: int = 10,
*,
since: datetime | None = None,
agent_instance_id: UUID | None = None,
**kwargs: Any,
) -> RetrievalResult[Episode]:
"""Retrieve most recent episodes."""
start_time = time.perf_counter()
query = (
select(EpisodeModel)
.where(EpisodeModel.project_id == project_id)
.order_by(desc(EpisodeModel.occurred_at))
.limit(limit)
)
if since is not None:
query = query.where(EpisodeModel.occurred_at >= since)
if agent_instance_id is not None:
query = query.where(EpisodeModel.agent_instance_id == agent_instance_id)
result = await session.execute(query)
models = list(result.scalars().all())
# Get total count
count_query = select(EpisodeModel).where(EpisodeModel.project_id == project_id)
if since is not None:
count_query = count_query.where(EpisodeModel.occurred_at >= since)
count_result = await session.execute(count_query)
total_count = len(list(count_result.scalars().all()))
latency_ms = (time.perf_counter() - start_time) * 1000
return RetrievalResult(
items=[_model_to_episode(m) for m in models],
total_count=total_count,
query="recency",
retrieval_type=RetrievalStrategy.RECENCY.value,
latency_ms=latency_ms,
metadata={"since": since.isoformat() if since else None},
)
class OutcomeRetriever(BaseRetriever):
"""Retrieves episodes filtered by outcome."""
async def retrieve(
self,
session: AsyncSession,
project_id: UUID,
limit: int = 10,
*,
outcome: Outcome | None = None,
agent_instance_id: UUID | None = None,
**kwargs: Any,
) -> RetrievalResult[Episode]:
"""Retrieve episodes by outcome."""
start_time = time.perf_counter()
query = (
select(EpisodeModel)
.where(EpisodeModel.project_id == project_id)
.order_by(desc(EpisodeModel.occurred_at))
.limit(limit)
)
if outcome is not None:
db_outcome = EpisodeOutcome(outcome.value)
query = query.where(EpisodeModel.outcome == db_outcome)
if agent_instance_id is not None:
query = query.where(EpisodeModel.agent_instance_id == agent_instance_id)
result = await session.execute(query)
models = list(result.scalars().all())
# Get total count
count_query = select(EpisodeModel).where(EpisodeModel.project_id == project_id)
if outcome is not None:
count_query = count_query.where(
EpisodeModel.outcome == EpisodeOutcome(outcome.value)
)
count_result = await session.execute(count_query)
total_count = len(list(count_result.scalars().all()))
latency_ms = (time.perf_counter() - start_time) * 1000
return RetrievalResult(
items=[_model_to_episode(m) for m in models],
total_count=total_count,
query=f"outcome:{outcome.value if outcome else 'all'}",
retrieval_type=RetrievalStrategy.OUTCOME.value,
latency_ms=latency_ms,
metadata={"outcome": outcome.value if outcome else None},
)
class TaskTypeRetriever(BaseRetriever):
"""Retrieves episodes filtered by task type."""
async def retrieve(
self,
session: AsyncSession,
project_id: UUID,
limit: int = 10,
*,
task_type: str | None = None,
agent_instance_id: UUID | None = None,
**kwargs: Any,
) -> RetrievalResult[Episode]:
"""Retrieve episodes by task type."""
start_time = time.perf_counter()
query = (
select(EpisodeModel)
.where(EpisodeModel.project_id == project_id)
.order_by(desc(EpisodeModel.occurred_at))
.limit(limit)
)
if task_type is not None:
query = query.where(EpisodeModel.task_type == task_type)
if agent_instance_id is not None:
query = query.where(EpisodeModel.agent_instance_id == agent_instance_id)
result = await session.execute(query)
models = list(result.scalars().all())
# Get total count
count_query = select(EpisodeModel).where(EpisodeModel.project_id == project_id)
if task_type is not None:
count_query = count_query.where(EpisodeModel.task_type == task_type)
count_result = await session.execute(count_query)
total_count = len(list(count_result.scalars().all()))
latency_ms = (time.perf_counter() - start_time) * 1000
return RetrievalResult(
items=[_model_to_episode(m) for m in models],
total_count=total_count,
query=f"task_type:{task_type or 'all'}",
retrieval_type="task_type",
latency_ms=latency_ms,
metadata={"task_type": task_type},
)
class ImportanceRetriever(BaseRetriever):
"""Retrieves episodes ranked by importance score."""
async def retrieve(
self,
session: AsyncSession,
project_id: UUID,
limit: int = 10,
*,
min_importance: float = 0.0,
agent_instance_id: UUID | None = None,
**kwargs: Any,
) -> RetrievalResult[Episode]:
"""Retrieve episodes by importance."""
start_time = time.perf_counter()
query = (
select(EpisodeModel)
.where(
and_(
EpisodeModel.project_id == project_id,
EpisodeModel.importance_score >= min_importance,
)
)
.order_by(desc(EpisodeModel.importance_score))
.limit(limit)
)
if agent_instance_id is not None:
query = query.where(EpisodeModel.agent_instance_id == agent_instance_id)
result = await session.execute(query)
models = list(result.scalars().all())
# Get total count
count_query = select(EpisodeModel).where(
and_(
EpisodeModel.project_id == project_id,
EpisodeModel.importance_score >= min_importance,
)
)
count_result = await session.execute(count_query)
total_count = len(list(count_result.scalars().all()))
latency_ms = (time.perf_counter() - start_time) * 1000
return RetrievalResult(
items=[_model_to_episode(m) for m in models],
total_count=total_count,
query=f"importance>={min_importance}",
retrieval_type=RetrievalStrategy.IMPORTANCE.value,
latency_ms=latency_ms,
metadata={"min_importance": min_importance},
)
class SemanticRetriever(BaseRetriever):
"""Retrieves episodes by semantic similarity using vector search."""
def __init__(self, embedding_generator: Any | None = None) -> None:
"""Initialize with optional embedding generator."""
self._embedding_generator = embedding_generator
async def retrieve(
self,
session: AsyncSession,
project_id: UUID,
limit: int = 10,
*,
query_text: str | None = None,
query_embedding: list[float] | None = None,
agent_instance_id: UUID | None = None,
**kwargs: Any,
) -> RetrievalResult[Episode]:
"""Retrieve episodes by semantic similarity."""
start_time = time.perf_counter()
# If no embedding provided, fall back to recency
if query_embedding is None and query_text is None:
logger.warning(
"No query provided for semantic search, falling back to recency"
)
recency = RecencyRetriever()
fallback_result = await recency.retrieve(
session, project_id, limit, agent_instance_id=agent_instance_id
)
latency_ms = (time.perf_counter() - start_time) * 1000
return RetrievalResult(
items=fallback_result.items,
total_count=fallback_result.total_count,
query="no_query",
retrieval_type=RetrievalStrategy.SEMANTIC.value,
latency_ms=latency_ms,
metadata={"fallback": "recency", "reason": "no_query"},
)
# Generate embedding if needed
embedding = query_embedding
if embedding is None and query_text is not None:
if self._embedding_generator is not None:
embedding = await self._embedding_generator.generate(query_text)
else:
logger.warning("No embedding generator, falling back to recency")
recency = RecencyRetriever()
fallback_result = await recency.retrieve(
session, project_id, limit, agent_instance_id=agent_instance_id
)
latency_ms = (time.perf_counter() - start_time) * 1000
return RetrievalResult(
items=fallback_result.items,
total_count=fallback_result.total_count,
query=query_text,
retrieval_type=RetrievalStrategy.SEMANTIC.value,
latency_ms=latency_ms,
metadata={
"fallback": "recency",
"reason": "no_embedding_generator",
},
)
# For now, use recency if vector search not available
# TODO: Implement proper pgvector similarity search when integrated
logger.debug("Vector search not yet implemented, using recency fallback")
recency = RecencyRetriever()
result = await recency.retrieve(
session, project_id, limit, agent_instance_id=agent_instance_id
)
latency_ms = (time.perf_counter() - start_time) * 1000
return RetrievalResult(
items=result.items,
total_count=result.total_count,
query=query_text or "embedding",
retrieval_type=RetrievalStrategy.SEMANTIC.value,
latency_ms=latency_ms,
metadata={"fallback": "recency"},
)
class EpisodeRetriever:
"""
Unified episode retrieval service.
Provides a single interface for all retrieval strategies.
"""
def __init__(
self,
session: AsyncSession,
embedding_generator: Any | None = None,
) -> None:
"""Initialize retriever with database session."""
self._session = session
self._retrievers: dict[RetrievalStrategy, BaseRetriever] = {
RetrievalStrategy.RECENCY: RecencyRetriever(),
RetrievalStrategy.OUTCOME: OutcomeRetriever(),
RetrievalStrategy.IMPORTANCE: ImportanceRetriever(),
RetrievalStrategy.SEMANTIC: SemanticRetriever(embedding_generator),
}
async def retrieve(
self,
project_id: UUID,
strategy: RetrievalStrategy = RetrievalStrategy.RECENCY,
limit: int = 10,
**kwargs: Any,
) -> RetrievalResult[Episode]:
"""
Retrieve episodes using the specified strategy.
Args:
project_id: Project to search within
strategy: Retrieval strategy to use
limit: Maximum number of episodes to return
**kwargs: Strategy-specific parameters
Returns:
RetrievalResult containing matching episodes
"""
retriever = self._retrievers.get(strategy)
if retriever is None:
raise ValueError(f"Unknown retrieval strategy: {strategy}")
return await retriever.retrieve(self._session, project_id, limit, **kwargs)
async def get_recent(
self,
project_id: UUID,
limit: int = 10,
since: datetime | None = None,
agent_instance_id: UUID | None = None,
) -> RetrievalResult[Episode]:
"""Get recent episodes."""
return await self.retrieve(
project_id,
RetrievalStrategy.RECENCY,
limit,
since=since,
agent_instance_id=agent_instance_id,
)
async def get_by_outcome(
self,
project_id: UUID,
outcome: Outcome,
limit: int = 10,
agent_instance_id: UUID | None = None,
) -> RetrievalResult[Episode]:
"""Get episodes by outcome."""
return await self.retrieve(
project_id,
RetrievalStrategy.OUTCOME,
limit,
outcome=outcome,
agent_instance_id=agent_instance_id,
)
async def get_by_task_type(
self,
project_id: UUID,
task_type: str,
limit: int = 10,
agent_instance_id: UUID | None = None,
) -> RetrievalResult[Episode]:
"""Get episodes by task type."""
retriever = TaskTypeRetriever()
return await retriever.retrieve(
self._session,
project_id,
limit,
task_type=task_type,
agent_instance_id=agent_instance_id,
)
async def get_important(
self,
project_id: UUID,
limit: int = 10,
min_importance: float = 0.7,
agent_instance_id: UUID | None = None,
) -> RetrievalResult[Episode]:
"""Get high-importance episodes."""
return await self.retrieve(
project_id,
RetrievalStrategy.IMPORTANCE,
limit,
min_importance=min_importance,
agent_instance_id=agent_instance_id,
)
async def search_similar(
self,
project_id: UUID,
query: str,
limit: int = 10,
agent_instance_id: UUID | None = None,
) -> RetrievalResult[Episode]:
"""Search for semantically similar episodes."""
return await self.retrieve(
project_id,
RetrievalStrategy.SEMANTIC,
limit,
query_text=query,
agent_instance_id=agent_instance_id,
)
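# Illustrative usage sketch (not part of the module): how a caller might combine these
# helpers. `async_session_factory` is an assumed session maker, not something defined here.
#
#     async def show_recent_failures(project_id: UUID) -> None:
#         async with async_session_factory() as session:
#             retriever = EpisodeRetriever(session)
#             failures = await retriever.get_by_outcome(project_id, Outcome.FAILURE, limit=5)
#             for episode in failures.items:
#                 print(episode.task_type, episode.outcome_details)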


@@ -0,0 +1,222 @@
"""
Memory System Exceptions
Custom exception classes for the Agent Memory System.
"""
from typing import Any
from uuid import UUID
class MemoryError(Exception):
"""Base exception for all memory-related errors."""
def __init__(
self,
message: str,
*,
memory_type: str | None = None,
scope_type: str | None = None,
scope_id: str | None = None,
details: dict[str, Any] | None = None,
) -> None:
super().__init__(message)
self.message = message
self.memory_type = memory_type
self.scope_type = scope_type
self.scope_id = scope_id
self.details = details or {}
class MemoryNotFoundError(MemoryError):
"""Raised when a memory item is not found."""
def __init__(
self,
message: str = "Memory not found",
*,
memory_id: UUID | str | None = None,
key: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.memory_id = memory_id
self.key = key
class MemoryCapacityError(MemoryError):
"""Raised when memory capacity limits are exceeded."""
def __init__(
self,
message: str = "Memory capacity exceeded",
*,
current_size: int = 0,
max_size: int = 0,
item_count: int = 0,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.current_size = current_size
self.max_size = max_size
self.item_count = item_count
class MemoryExpiredError(MemoryError):
"""Raised when attempting to access expired memory."""
def __init__(
self,
message: str = "Memory has expired",
*,
key: str | None = None,
expired_at: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.key = key
self.expired_at = expired_at
class MemoryStorageError(MemoryError):
"""Raised when memory storage operations fail."""
def __init__(
self,
message: str = "Memory storage operation failed",
*,
operation: str | None = None,
backend: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.operation = operation
self.backend = backend
class MemoryConnectionError(MemoryError):
"""Raised when memory storage connection fails."""
def __init__(
self,
message: str = "Memory connection failed",
*,
backend: str | None = None,
host: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.backend = backend
self.host = host
class MemorySerializationError(MemoryError):
"""Raised when memory serialization/deserialization fails."""
def __init__(
self,
message: str = "Memory serialization failed",
*,
content_type: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.content_type = content_type
class MemoryScopeError(MemoryError):
"""Raised when memory scope operations fail."""
def __init__(
self,
message: str = "Memory scope error",
*,
requested_scope: str | None = None,
allowed_scopes: list[str] | None = None,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.requested_scope = requested_scope
self.allowed_scopes = allowed_scopes or []
class MemoryConsolidationError(MemoryError):
"""Raised when memory consolidation fails."""
def __init__(
self,
message: str = "Memory consolidation failed",
*,
source_type: str | None = None,
target_type: str | None = None,
items_processed: int = 0,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.source_type = source_type
self.target_type = target_type
self.items_processed = items_processed
class MemoryRetrievalError(MemoryError):
"""Raised when memory retrieval fails."""
def __init__(
self,
message: str = "Memory retrieval failed",
*,
query: str | None = None,
retrieval_type: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.query = query
self.retrieval_type = retrieval_type
class EmbeddingError(MemoryError):
"""Raised when embedding generation fails."""
def __init__(
self,
message: str = "Embedding generation failed",
*,
content_length: int = 0,
model: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.content_length = content_length
self.model = model
class CheckpointError(MemoryError):
"""Raised when checkpoint operations fail."""
def __init__(
self,
message: str = "Checkpoint operation failed",
*,
checkpoint_id: str | None = None,
operation: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.checkpoint_id = checkpoint_id
self.operation = operation
class MemoryConflictError(MemoryError):
"""Raised when there's a conflict in memory (e.g., contradictory facts)."""
def __init__(
self,
message: str = "Memory conflict detected",
*,
conflicting_ids: list[str | UUID] | None = None,
conflict_type: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.conflicting_ids = conflicting_ids or []
self.conflict_type = conflict_type
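# Illustrative usage sketch: raising and handling a structured memory error. The key and
# scope values below are made up for the example.
#
#     try:
#         raise MemoryNotFoundError(
#             "No value stored under key",
#             key="session:abc/current_task",
#             memory_type="working",
#             scope_type="session",
#             scope_id="abc",
#         )
#     except MemoryError as exc:
#         # Every memory exception carries message, memory_type, scope_* and details.
#         print(exc.message, exc.memory_type, exc.scope_id, exc.details)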


@@ -0,0 +1,56 @@
# app/services/memory/indexing/__init__.py
"""
Memory Indexing & Retrieval.
Provides vector embeddings and multiple index types for efficient memory search:
- Vector index for semantic similarity
- Temporal index for time-based queries
- Entity index for entity lookups
- Outcome index for success/failure filtering
"""
from .index import (
EntityIndex,
EntityIndexEntry,
IndexEntry,
MemoryIndex,
MemoryIndexer,
OutcomeIndex,
OutcomeIndexEntry,
TemporalIndex,
TemporalIndexEntry,
VectorIndex,
VectorIndexEntry,
get_memory_indexer,
)
from .retrieval import (
CacheEntry,
RelevanceScorer,
RetrievalCache,
RetrievalEngine,
RetrievalQuery,
ScoredResult,
get_retrieval_engine,
)
__all__ = [
"CacheEntry",
"EntityIndex",
"EntityIndexEntry",
"IndexEntry",
"MemoryIndex",
"MemoryIndexer",
"OutcomeIndex",
"OutcomeIndexEntry",
"RelevanceScorer",
"RetrievalCache",
"RetrievalEngine",
"RetrievalQuery",
"ScoredResult",
"TemporalIndex",
"TemporalIndexEntry",
"VectorIndex",
"VectorIndexEntry",
"get_memory_indexer",
"get_retrieval_engine",
]


@@ -0,0 +1,858 @@
# app/services/memory/indexing/index.py
"""
Memory Indexing.
Provides multiple indexing strategies for efficient memory retrieval:
- Vector embeddings for semantic search
- Temporal index for time-based queries
- Entity index for entity-based lookups
- Outcome index for success/failure filtering
"""
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import UTC, datetime, timedelta
from typing import Any, Generic, TypeVar
from uuid import UUID
from app.services.memory.types import Episode, Fact, MemoryType, Outcome, Procedure
logger = logging.getLogger(__name__)
T = TypeVar("T", Episode, Fact, Procedure)
def _utcnow() -> datetime:
"""Get current UTC time as timezone-aware datetime."""
return datetime.now(UTC)
@dataclass
class IndexEntry:
"""A single entry in an index."""
memory_id: UUID
memory_type: MemoryType
indexed_at: datetime = field(default_factory=_utcnow)
metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
class VectorIndexEntry(IndexEntry):
"""An entry with vector embedding."""
embedding: list[float] = field(default_factory=list)
dimension: int = 0
def __post_init__(self) -> None:
"""Set dimension from embedding."""
if self.embedding:
self.dimension = len(self.embedding)
@dataclass
class TemporalIndexEntry(IndexEntry):
"""An entry indexed by time."""
timestamp: datetime = field(default_factory=_utcnow)
@dataclass
class EntityIndexEntry(IndexEntry):
"""An entry indexed by entity."""
entity_type: str = ""
entity_value: str = ""
@dataclass
class OutcomeIndexEntry(IndexEntry):
"""An entry indexed by outcome."""
outcome: Outcome = Outcome.SUCCESS
class MemoryIndex(ABC, Generic[T]):
"""Abstract base class for memory indices."""
@abstractmethod
async def add(self, item: T) -> IndexEntry:
"""Add an item to the index."""
...
@abstractmethod
async def remove(self, memory_id: UUID) -> bool:
"""Remove an item from the index."""
...
@abstractmethod
async def search(
self,
query: Any,
limit: int = 10,
**kwargs: Any,
) -> list[IndexEntry]:
"""Search the index."""
...
@abstractmethod
async def clear(self) -> int:
"""Clear all entries from the index."""
...
@abstractmethod
async def count(self) -> int:
"""Get the number of entries in the index."""
...
class VectorIndex(MemoryIndex[T]):
"""
Vector-based index using embeddings for semantic similarity search.
Uses cosine similarity for matching.
"""
def __init__(self, dimension: int = 1536) -> None:
"""
Initialize the vector index.
Args:
dimension: Embedding dimension (default 1536 for OpenAI)
"""
self._dimension = dimension
self._entries: dict[UUID, VectorIndexEntry] = {}
logger.info(f"Initialized VectorIndex with dimension={dimension}")
async def add(self, item: T) -> VectorIndexEntry:
"""
Add an item to the vector index.
Args:
item: Memory item with embedding
Returns:
The created index entry
"""
embedding = getattr(item, "embedding", None) or []
entry = VectorIndexEntry(
memory_id=item.id,
memory_type=self._get_memory_type(item),
embedding=embedding,
dimension=len(embedding),
)
self._entries[item.id] = entry
logger.debug(f"Added {item.id} to vector index")
return entry
async def remove(self, memory_id: UUID) -> bool:
"""Remove an item from the vector index."""
if memory_id in self._entries:
del self._entries[memory_id]
logger.debug(f"Removed {memory_id} from vector index")
return True
return False
async def search( # type: ignore[override]
self,
query: Any,
limit: int = 10,
min_similarity: float = 0.0,
**kwargs: Any,
) -> list[VectorIndexEntry]:
"""
Search for similar items using vector similarity.
Args:
query: Query embedding vector
limit: Maximum results to return
min_similarity: Minimum similarity threshold (0-1)
**kwargs: Additional filter parameters
Returns:
List of matching entries sorted by similarity
"""
if not isinstance(query, list) or not query:
return []
results: list[tuple[float, VectorIndexEntry]] = []
for entry in self._entries.values():
if not entry.embedding:
continue
similarity = self._cosine_similarity(query, entry.embedding)
if similarity >= min_similarity:
results.append((similarity, entry))
# Sort by similarity descending
results.sort(key=lambda x: x[0], reverse=True)
# Apply memory type filter if provided
memory_type = kwargs.get("memory_type")
if memory_type:
results = [(s, e) for s, e in results if e.memory_type == memory_type]
# Store similarity in metadata for the returned entries
# Use a copy of metadata to avoid mutating cached entries
output = []
for similarity, entry in results[:limit]:
# Create a shallow copy of the entry with updated metadata
entry_with_score = VectorIndexEntry(
memory_id=entry.memory_id,
memory_type=entry.memory_type,
embedding=entry.embedding,
metadata={**entry.metadata, "similarity": similarity},
)
output.append(entry_with_score)
logger.debug(f"Vector search returned {len(output)} results")
return output
async def clear(self) -> int:
"""Clear all entries from the index."""
count = len(self._entries)
self._entries.clear()
logger.info(f"Cleared {count} entries from vector index")
return count
async def count(self) -> int:
"""Get the number of entries in the index."""
return len(self._entries)
def _cosine_similarity(self, a: list[float], b: list[float]) -> float:
"""Calculate cosine similarity between two vectors."""
if len(a) != len(b) or len(a) == 0:
return 0.0
dot_product = sum(x * y for x, y in zip(a, b, strict=True))
norm_a = sum(x * x for x in a) ** 0.5
norm_b = sum(x * x for x in b) ** 0.5
if norm_a == 0 or norm_b == 0:
return 0.0
return dot_product / (norm_a * norm_b)
def _get_memory_type(self, item: T) -> MemoryType:
"""Get the memory type for an item."""
if isinstance(item, Episode):
return MemoryType.EPISODIC
elif isinstance(item, Fact):
return MemoryType.SEMANTIC
elif isinstance(item, Procedure):
return MemoryType.PROCEDURAL
return MemoryType.WORKING
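# Quick sanity check (illustrative): the cosine-similarity helper returns 1.0 for vectors
# pointing the same way and 0.0 for orthogonal or zero-length vectors.
#
#     idx: VectorIndex[Episode] = VectorIndex(dimension=3)
#     assert abs(idx._cosine_similarity([1.0, 0.0, 0.0], [2.0, 0.0, 0.0]) - 1.0) < 1e-9
#     assert idx._cosine_similarity([1.0, 0.0, 0.0], [0.0, 1.0, 0.0]) == 0.0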
class TemporalIndex(MemoryIndex[T]):
"""
Time-based index for efficient temporal queries.
Supports:
- Range queries (between timestamps)
- Recent items (within last N seconds/hours/days)
- Oldest/newest sorting
"""
def __init__(self) -> None:
"""Initialize the temporal index."""
self._entries: dict[UUID, TemporalIndexEntry] = {}
# Sorted list for efficient range queries
self._sorted_entries: list[tuple[datetime, UUID]] = []
logger.info("Initialized TemporalIndex")
async def add(self, item: T) -> TemporalIndexEntry:
"""
Add an item to the temporal index.
Args:
item: Memory item with timestamp
Returns:
The created index entry
"""
# Get timestamp from various possible fields
timestamp = self._get_timestamp(item)
entry = TemporalIndexEntry(
memory_id=item.id,
memory_type=self._get_memory_type(item),
timestamp=timestamp,
)
self._entries[item.id] = entry
self._insert_sorted(timestamp, item.id)
logger.debug(f"Added {item.id} to temporal index at {timestamp}")
return entry
async def remove(self, memory_id: UUID) -> bool:
"""Remove an item from the temporal index."""
if memory_id not in self._entries:
return False
self._entries.pop(memory_id)
self._sorted_entries = [
(ts, mid) for ts, mid in self._sorted_entries if mid != memory_id
]
logger.debug(f"Removed {memory_id} from temporal index")
return True
async def search( # type: ignore[override]
self,
query: Any,
limit: int = 10,
start_time: datetime | None = None,
end_time: datetime | None = None,
recent_seconds: float | None = None,
order: str = "desc",
**kwargs: Any,
) -> list[TemporalIndexEntry]:
"""
Search for items by time.
Args:
query: Ignored for temporal search
limit: Maximum results to return
start_time: Start of time range
end_time: End of time range
recent_seconds: Get items from last N seconds
order: Sort order ("asc" or "desc")
**kwargs: Additional filter parameters
Returns:
List of matching entries sorted by time
"""
if recent_seconds is not None:
start_time = _utcnow() - timedelta(seconds=recent_seconds)
end_time = _utcnow()
# Filter by time range
results: list[TemporalIndexEntry] = []
for entry in self._entries.values():
if start_time and entry.timestamp < start_time:
continue
if end_time and entry.timestamp > end_time:
continue
results.append(entry)
# Apply memory type filter if provided
memory_type = kwargs.get("memory_type")
if memory_type:
results = [e for e in results if e.memory_type == memory_type]
# Sort by timestamp
results.sort(key=lambda e: e.timestamp, reverse=(order == "desc"))
logger.debug(f"Temporal search returned {min(len(results), limit)} results")
return results[:limit]
async def clear(self) -> int:
"""Clear all entries from the index."""
count = len(self._entries)
self._entries.clear()
self._sorted_entries.clear()
logger.info(f"Cleared {count} entries from temporal index")
return count
async def count(self) -> int:
"""Get the number of entries in the index."""
return len(self._entries)
def _insert_sorted(self, timestamp: datetime, memory_id: UUID) -> None:
"""Insert entry maintaining sorted order."""
# Binary search insert for efficiency
low, high = 0, len(self._sorted_entries)
while low < high:
mid = (low + high) // 2
if self._sorted_entries[mid][0] < timestamp:
low = mid + 1
else:
high = mid
self._sorted_entries.insert(low, (timestamp, memory_id))
def _get_timestamp(self, item: T) -> datetime:
"""Get the relevant timestamp for an item."""
if hasattr(item, "occurred_at"):
return item.occurred_at
if hasattr(item, "first_learned"):
return item.first_learned
if hasattr(item, "last_used") and item.last_used:
return item.last_used
if hasattr(item, "created_at"):
return item.created_at
return _utcnow()
def _get_memory_type(self, item: T) -> MemoryType:
"""Get the memory type for an item."""
if isinstance(item, Episode):
return MemoryType.EPISODIC
elif isinstance(item, Fact):
return MemoryType.SEMANTIC
elif isinstance(item, Procedure):
return MemoryType.PROCEDURAL
return MemoryType.WORKING
class EntityIndex(MemoryIndex[T]):
"""
Entity-based index for lookups by entities mentioned in memories.
Supports:
- Single entity lookup
- Multi-entity intersection
- Entity type filtering
"""
def __init__(self) -> None:
"""Initialize the entity index."""
# Main storage
self._entries: dict[UUID, EntityIndexEntry] = {}
# Inverted index: entity -> set of memory IDs
self._entity_to_memories: dict[str, set[UUID]] = {}
# Memory to entities mapping
self._memory_to_entities: dict[UUID, set[str]] = {}
logger.info("Initialized EntityIndex")
async def add(self, item: T) -> EntityIndexEntry:
"""
Add an item to the entity index.
Args:
item: Memory item with entity information
Returns:
The created index entry
"""
entities = self._extract_entities(item)
# Create entry for the primary entity (or first one)
primary_entity = entities[0] if entities else ("unknown", "unknown")
entry = EntityIndexEntry(
memory_id=item.id,
memory_type=self._get_memory_type(item),
entity_type=primary_entity[0],
entity_value=primary_entity[1],
)
self._entries[item.id] = entry
# Update inverted indices
entity_keys = {f"{etype}:{evalue}" for etype, evalue in entities}
self._memory_to_entities[item.id] = entity_keys
for entity_key in entity_keys:
if entity_key not in self._entity_to_memories:
self._entity_to_memories[entity_key] = set()
self._entity_to_memories[entity_key].add(item.id)
logger.debug(f"Added {item.id} to entity index with {len(entities)} entities")
return entry
async def remove(self, memory_id: UUID) -> bool:
"""Remove an item from the entity index."""
if memory_id not in self._entries:
return False
# Remove from inverted index
if memory_id in self._memory_to_entities:
for entity_key in self._memory_to_entities[memory_id]:
if entity_key in self._entity_to_memories:
self._entity_to_memories[entity_key].discard(memory_id)
if not self._entity_to_memories[entity_key]:
del self._entity_to_memories[entity_key]
del self._memory_to_entities[memory_id]
del self._entries[memory_id]
logger.debug(f"Removed {memory_id} from entity index")
return True
async def search( # type: ignore[override]
self,
query: Any,
limit: int = 10,
entity_type: str | None = None,
entity_value: str | None = None,
entities: list[tuple[str, str]] | None = None,
match_all: bool = False,
**kwargs: Any,
) -> list[EntityIndexEntry]:
"""
Search for items by entity.
Args:
query: Entity value to search (if entity_type not specified)
limit: Maximum results to return
entity_type: Type of entity to filter
entity_value: Specific entity value
entities: List of (type, value) tuples to match
match_all: If True, require all entities to match
**kwargs: Additional filter parameters
Returns:
List of matching entries
"""
matching_ids: set[UUID] | None = None
# Handle single entity query
if entity_type and entity_value:
entities = [(entity_type, entity_value)]
elif entity_value is None and isinstance(query, str):
# Search across all entity types
entity_value = query
if entities:
for etype, evalue in entities:
entity_key = f"{etype}:{evalue}"
if entity_key in self._entity_to_memories:
ids = self._entity_to_memories[entity_key]
if matching_ids is None:
matching_ids = ids.copy()
elif match_all:
matching_ids &= ids
else:
matching_ids |= ids
elif match_all:
# Required entity not found
matching_ids = set()
break
elif entity_value:
# Search for value across all types
matching_ids = set()
for entity_key, ids in self._entity_to_memories.items():
if entity_value.lower() in entity_key.lower():
matching_ids |= ids
if matching_ids is None:
matching_ids = set(self._entries.keys())
# Apply memory type filter if provided
memory_type = kwargs.get("memory_type")
results = []
for mid in matching_ids:
if mid in self._entries:
entry = self._entries[mid]
if memory_type and entry.memory_type != memory_type:
continue
results.append(entry)
logger.debug(f"Entity search returned {min(len(results), limit)} results")
return results[:limit]
async def clear(self) -> int:
"""Clear all entries from the index."""
count = len(self._entries)
self._entries.clear()
self._entity_to_memories.clear()
self._memory_to_entities.clear()
logger.info(f"Cleared {count} entries from entity index")
return count
async def count(self) -> int:
"""Get the number of entries in the index."""
return len(self._entries)
async def get_entities(self, memory_id: UUID) -> list[tuple[str, str]]:
"""Get all entities for a memory item."""
if memory_id not in self._memory_to_entities:
return []
entities = []
for entity_key in self._memory_to_entities[memory_id]:
if ":" in entity_key:
etype, evalue = entity_key.split(":", 1)
entities.append((etype, evalue))
return entities
def _extract_entities(self, item: T) -> list[tuple[str, str]]:
"""Extract entities from a memory item."""
entities: list[tuple[str, str]] = []
if isinstance(item, Episode):
# Extract from task type and context
entities.append(("task_type", item.task_type))
if item.project_id:
entities.append(("project", str(item.project_id)))
if item.agent_instance_id:
entities.append(("agent_instance", str(item.agent_instance_id)))
if item.agent_type_id:
entities.append(("agent_type", str(item.agent_type_id)))
elif isinstance(item, Fact):
# Subject and object are entities
entities.append(("subject", item.subject))
entities.append(("object", item.object))
if item.project_id:
entities.append(("project", str(item.project_id)))
elif isinstance(item, Procedure):
entities.append(("procedure", item.name))
if item.project_id:
entities.append(("project", str(item.project_id)))
if item.agent_type_id:
entities.append(("agent_type", str(item.agent_type_id)))
return entities
def _get_memory_type(self, item: T) -> MemoryType:
"""Get the memory type for an item."""
if isinstance(item, Episode):
return MemoryType.EPISODIC
elif isinstance(item, Fact):
return MemoryType.SEMANTIC
elif isinstance(item, Procedure):
return MemoryType.PROCEDURAL
return MemoryType.WORKING
class OutcomeIndex(MemoryIndex[T]):
"""
Outcome-based index for filtering by success/failure.
Primarily used for episodes and procedures.
"""
def __init__(self) -> None:
"""Initialize the outcome index."""
self._entries: dict[UUID, OutcomeIndexEntry] = {}
# Inverted index by outcome
self._outcome_to_memories: dict[Outcome, set[UUID]] = {
Outcome.SUCCESS: set(),
Outcome.FAILURE: set(),
Outcome.PARTIAL: set(),
}
logger.info("Initialized OutcomeIndex")
async def add(self, item: T) -> OutcomeIndexEntry:
"""
Add an item to the outcome index.
Args:
item: Memory item with outcome information
Returns:
The created index entry
"""
outcome = self._get_outcome(item)
entry = OutcomeIndexEntry(
memory_id=item.id,
memory_type=self._get_memory_type(item),
outcome=outcome,
)
self._entries[item.id] = entry
self._outcome_to_memories[outcome].add(item.id)
logger.debug(f"Added {item.id} to outcome index with {outcome.value}")
return entry
async def remove(self, memory_id: UUID) -> bool:
"""Remove an item from the outcome index."""
if memory_id not in self._entries:
return False
entry = self._entries.pop(memory_id)
self._outcome_to_memories[entry.outcome].discard(memory_id)
logger.debug(f"Removed {memory_id} from outcome index")
return True
async def search( # type: ignore[override]
self,
query: Any,
limit: int = 10,
outcome: Outcome | None = None,
outcomes: list[Outcome] | None = None,
**kwargs: Any,
) -> list[OutcomeIndexEntry]:
"""
Search for items by outcome.
Args:
query: Ignored for outcome search
limit: Maximum results to return
outcome: Single outcome to filter
outcomes: Multiple outcomes to filter (OR)
**kwargs: Additional filter parameters
Returns:
List of matching entries
"""
if outcome:
outcomes = [outcome]
if outcomes:
matching_ids: set[UUID] = set()
for o in outcomes:
matching_ids |= self._outcome_to_memories.get(o, set())
else:
matching_ids = set(self._entries.keys())
# Apply memory type filter if provided
memory_type = kwargs.get("memory_type")
results = []
for mid in matching_ids:
if mid in self._entries:
entry = self._entries[mid]
if memory_type and entry.memory_type != memory_type:
continue
results.append(entry)
logger.debug(f"Outcome search returned {min(len(results), limit)} results")
return results[:limit]
async def clear(self) -> int:
"""Clear all entries from the index."""
count = len(self._entries)
self._entries.clear()
for outcome in self._outcome_to_memories:
self._outcome_to_memories[outcome].clear()
logger.info(f"Cleared {count} entries from outcome index")
return count
async def count(self) -> int:
"""Get the number of entries in the index."""
return len(self._entries)
async def get_outcome_stats(self) -> dict[Outcome, int]:
"""Get statistics on outcomes."""
return {outcome: len(ids) for outcome, ids in self._outcome_to_memories.items()}
def _get_outcome(self, item: T) -> Outcome:
"""Get the outcome for an item."""
if isinstance(item, Episode):
return item.outcome
elif isinstance(item, Procedure):
# Derive from success rate
if item.success_rate >= 0.8:
return Outcome.SUCCESS
elif item.success_rate <= 0.2:
return Outcome.FAILURE
return Outcome.PARTIAL
return Outcome.SUCCESS
def _get_memory_type(self, item: T) -> MemoryType:
"""Get the memory type for an item."""
if isinstance(item, Episode):
return MemoryType.EPISODIC
elif isinstance(item, Fact):
return MemoryType.SEMANTIC
elif isinstance(item, Procedure):
return MemoryType.PROCEDURAL
return MemoryType.WORKING
@dataclass
class MemoryIndexer:
"""
Unified indexer that manages all index types.
Provides a single interface for indexing and searching across
multiple index types.
"""
vector_index: VectorIndex[Any] = field(default_factory=VectorIndex)
temporal_index: TemporalIndex[Any] = field(default_factory=TemporalIndex)
entity_index: EntityIndex[Any] = field(default_factory=EntityIndex)
outcome_index: OutcomeIndex[Any] = field(default_factory=OutcomeIndex)
async def index(self, item: Episode | Fact | Procedure) -> dict[str, IndexEntry]:
"""
Index an item across all applicable indices.
Args:
item: Memory item to index
Returns:
Dictionary of index type to entry
"""
results: dict[str, IndexEntry] = {}
# Vector index (if embedding present)
if getattr(item, "embedding", None):
results["vector"] = await self.vector_index.add(item)
# Temporal index
results["temporal"] = await self.temporal_index.add(item)
# Entity index
results["entity"] = await self.entity_index.add(item)
# Outcome index (for episodes and procedures)
if isinstance(item, (Episode, Procedure)):
results["outcome"] = await self.outcome_index.add(item)
logger.info(
f"Indexed {item.id} across {len(results)} indices: {list(results.keys())}"
)
return results
async def remove(self, memory_id: UUID) -> dict[str, bool]:
"""
Remove an item from all indices.
Args:
memory_id: ID of the memory to remove
Returns:
Dictionary of index type to removal success
"""
results = {
"vector": await self.vector_index.remove(memory_id),
"temporal": await self.temporal_index.remove(memory_id),
"entity": await self.entity_index.remove(memory_id),
"outcome": await self.outcome_index.remove(memory_id),
}
removed_from = [k for k, v in results.items() if v]
if removed_from:
logger.info(f"Removed {memory_id} from indices: {removed_from}")
return results
async def clear_all(self) -> dict[str, int]:
"""
Clear all indices.
Returns:
Dictionary of index type to count cleared
"""
return {
"vector": await self.vector_index.clear(),
"temporal": await self.temporal_index.clear(),
"entity": await self.entity_index.clear(),
"outcome": await self.outcome_index.clear(),
}
async def get_stats(self) -> dict[str, int]:
"""
Get statistics for all indices.
Returns:
Dictionary of index type to entry count
"""
return {
"vector": await self.vector_index.count(),
"temporal": await self.temporal_index.count(),
"entity": await self.entity_index.count(),
"outcome": await self.outcome_index.count(),
}
# Singleton indexer instance
_indexer: MemoryIndexer | None = None
def get_memory_indexer() -> MemoryIndexer:
"""Get the singleton memory indexer instance."""
global _indexer
if _indexer is None:
_indexer = MemoryIndexer()
return _indexer
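# Illustrative usage sketch: index an existing Episode, then query two of the underlying
# indices directly. Assumes an `episode: Episode` instance is in scope.
#
#     async def index_and_lookup(episode: Episode) -> None:
#         indexer = get_memory_indexer()
#         await indexer.index(episode)
#         recent = await indexer.temporal_index.search(query=None, limit=5, recent_seconds=3600)
#         wins = await indexer.outcome_index.search(query=None, limit=5, outcome=Outcome.SUCCESS)
#         print(len(recent), len(wins), await indexer.get_stats())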


@@ -0,0 +1,742 @@
# app/services/memory/indexing/retrieval.py
"""
Memory Retrieval Engine.
Provides hybrid retrieval capabilities combining:
- Vector similarity search
- Temporal filtering
- Entity filtering
- Outcome filtering
- Relevance scoring
- Result caching
"""
import hashlib
import logging
from collections import OrderedDict
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import Any, TypeVar
from uuid import UUID
from app.services.memory.types import (
Episode,
Fact,
MemoryType,
Outcome,
Procedure,
RetrievalResult,
)
from .index import (
MemoryIndexer,
get_memory_indexer,
)
logger = logging.getLogger(__name__)
T = TypeVar("T", Episode, Fact, Procedure)
def _utcnow() -> datetime:
"""Get current UTC time as timezone-aware datetime."""
return datetime.now(UTC)
@dataclass
class RetrievalQuery:
"""Query parameters for memory retrieval."""
# Text/semantic query
query_text: str | None = None
query_embedding: list[float] | None = None
# Temporal filters
start_time: datetime | None = None
end_time: datetime | None = None
recent_seconds: float | None = None
# Entity filters
entities: list[tuple[str, str]] | None = None
entity_match_all: bool = False
# Outcome filters
outcomes: list[Outcome] | None = None
# Memory type filter
memory_types: list[MemoryType] | None = None
# Result options
limit: int = 10
min_relevance: float = 0.0
# Retrieval mode
use_vector: bool = True
use_temporal: bool = True
use_entity: bool = True
use_outcome: bool = True
def to_cache_key(self) -> str:
"""Generate a cache key for this query."""
key_parts = [
            self.query_text or "",
            # Include the embedding itself so two distinct query vectors cannot share a key
            str(self.query_embedding) if self.query_embedding else "",
str(self.start_time),
str(self.end_time),
str(self.recent_seconds),
str(self.entities),
str(self.outcomes),
str(self.memory_types),
str(self.limit),
str(self.min_relevance),
]
key_string = "|".join(key_parts)
return hashlib.sha256(key_string.encode()).hexdigest()[:32]
@dataclass
class ScoredResult:
"""A retrieval result with relevance score."""
memory_id: UUID
memory_type: MemoryType
relevance_score: float
score_breakdown: dict[str, float] = field(default_factory=dict)
metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
class CacheEntry:
"""A cached retrieval result."""
results: list[ScoredResult]
created_at: datetime
ttl_seconds: float
query_key: str
def is_expired(self) -> bool:
"""Check if this cache entry has expired."""
age = (_utcnow() - self.created_at).total_seconds()
return age > self.ttl_seconds
class RelevanceScorer:
"""
Calculates relevance scores for retrieved memories.
Combines multiple signals:
- Vector similarity (if available)
- Temporal recency
- Entity match count
- Outcome preference
- Importance/confidence
"""
def __init__(
self,
vector_weight: float = 0.4,
recency_weight: float = 0.2,
entity_weight: float = 0.2,
outcome_weight: float = 0.1,
importance_weight: float = 0.1,
) -> None:
"""
Initialize the relevance scorer.
Args:
vector_weight: Weight for vector similarity (0-1)
recency_weight: Weight for temporal recency (0-1)
entity_weight: Weight for entity matches (0-1)
outcome_weight: Weight for outcome preference (0-1)
importance_weight: Weight for importance score (0-1)
"""
total = (
vector_weight
+ recency_weight
+ entity_weight
+ outcome_weight
+ importance_weight
)
# Normalize weights
self.vector_weight = vector_weight / total
self.recency_weight = recency_weight / total
self.entity_weight = entity_weight / total
self.outcome_weight = outcome_weight / total
self.importance_weight = importance_weight / total
def score(
self,
memory_id: UUID,
memory_type: MemoryType,
vector_similarity: float | None = None,
timestamp: datetime | None = None,
entity_match_count: int = 0,
entity_total: int = 1,
outcome: Outcome | None = None,
importance: float = 0.5,
preferred_outcomes: list[Outcome] | None = None,
) -> ScoredResult:
"""
Calculate a relevance score for a memory.
Args:
memory_id: ID of the memory
memory_type: Type of memory
vector_similarity: Similarity score from vector search (0-1)
timestamp: Timestamp of the memory
entity_match_count: Number of matching entities
entity_total: Total entities in query
outcome: Outcome of the memory
importance: Importance score of the memory (0-1)
preferred_outcomes: Outcomes to prefer
Returns:
Scored result with breakdown
"""
breakdown: dict[str, float] = {}
# Vector similarity score
if vector_similarity is not None:
breakdown["vector"] = vector_similarity
else:
breakdown["vector"] = 0.5 # Neutral if no vector
# Recency score (exponential decay)
if timestamp:
age_hours = (_utcnow() - timestamp).total_seconds() / 3600
# Decay with half-life of 24 hours
breakdown["recency"] = 2 ** (-age_hours / 24)
else:
breakdown["recency"] = 0.5
# Entity match score
if entity_total > 0:
breakdown["entity"] = entity_match_count / entity_total
else:
breakdown["entity"] = 1.0 # No entity filter = full score
# Outcome score
if preferred_outcomes and outcome:
breakdown["outcome"] = 1.0 if outcome in preferred_outcomes else 0.0
else:
breakdown["outcome"] = 0.5 # Neutral if no preference
# Importance score
breakdown["importance"] = importance
# Calculate weighted sum
total_score = (
breakdown["vector"] * self.vector_weight
+ breakdown["recency"] * self.recency_weight
+ breakdown["entity"] * self.entity_weight
+ breakdown["outcome"] * self.outcome_weight
+ breakdown["importance"] * self.importance_weight
)
return ScoredResult(
memory_id=memory_id,
memory_type=memory_type,
relevance_score=total_score,
score_breakdown=breakdown,
)
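# Worked example (illustrative): the default weights (0.4, 0.2, 0.2, 0.1, 0.1) already sum
# to 1.0, so normalization leaves them unchanged. A memory with vector similarity 0.9,
# age 12 hours (recency = 2 ** (-12 / 24) ≈ 0.707), 1 of 2 query entities matched (0.5),
# a preferred outcome (1.0), and importance 0.8 scores about
# 0.9*0.4 + 0.707*0.2 + 0.5*0.2 + 1.0*0.1 + 0.8*0.1 ≈ 0.78. The snippet assumes
# `from uuid import uuid4` and `from datetime import timedelta`.
#
#     scorer = RelevanceScorer()
#     result = scorer.score(
#         memory_id=uuid4(),
#         memory_type=MemoryType.EPISODIC,
#         vector_similarity=0.9,
#         timestamp=_utcnow() - timedelta(hours=12),
#         entity_match_count=1,
#         entity_total=2,
#         outcome=Outcome.SUCCESS,
#         importance=0.8,
#         preferred_outcomes=[Outcome.SUCCESS],
#     )
#     assert 0.75 < result.relevance_score < 0.82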
class RetrievalCache:
"""
In-memory cache for retrieval results.
Supports TTL-based expiration and LRU eviction with O(1) operations.
Uses OrderedDict for efficient LRU tracking.
"""
def __init__(
self,
max_entries: int = 1000,
default_ttl_seconds: float = 300,
) -> None:
"""
Initialize the cache.
Args:
max_entries: Maximum cache entries
default_ttl_seconds: Default TTL for entries
"""
# OrderedDict maintains insertion order; we use move_to_end for O(1) LRU
self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
self._max_entries = max_entries
self._default_ttl = default_ttl_seconds
logger.info(
f"Initialized RetrievalCache with max_entries={max_entries}, "
f"ttl={default_ttl_seconds}s"
)
def get(self, query_key: str) -> list[ScoredResult] | None:
"""
Get cached results for a query.
Args:
query_key: Cache key for the query
Returns:
Cached results or None if not found/expired
"""
if query_key not in self._cache:
return None
entry = self._cache[query_key]
if entry.is_expired():
del self._cache[query_key]
return None
# Update access order (LRU) - O(1) with OrderedDict
self._cache.move_to_end(query_key)
logger.debug(f"Cache hit for {query_key}")
return entry.results
def put(
self,
query_key: str,
results: list[ScoredResult],
ttl_seconds: float | None = None,
) -> None:
"""
Cache results for a query.
Args:
query_key: Cache key for the query
results: Results to cache
ttl_seconds: TTL for this entry (or default)
"""
# Evict oldest entries if at capacity - O(1) with popitem(last=False)
while len(self._cache) >= self._max_entries:
self._cache.popitem(last=False)
entry = CacheEntry(
results=results,
created_at=_utcnow(),
ttl_seconds=ttl_seconds or self._default_ttl,
query_key=query_key,
)
self._cache[query_key] = entry
logger.debug(f"Cached {len(results)} results for {query_key}")
def invalidate(self, query_key: str) -> bool:
"""
Invalidate a specific cache entry.
Args:
query_key: Cache key to invalidate
Returns:
True if entry was found and removed
"""
if query_key in self._cache:
del self._cache[query_key]
return True
return False
def invalidate_by_memory(self, memory_id: UUID) -> int:
"""
Invalidate all cache entries containing a specific memory.
Args:
memory_id: Memory ID to invalidate
Returns:
Number of entries invalidated
"""
keys_to_remove = []
for key, entry in self._cache.items():
if any(r.memory_id == memory_id for r in entry.results):
keys_to_remove.append(key)
for key in keys_to_remove:
self.invalidate(key)
if keys_to_remove:
logger.debug(
f"Invalidated {len(keys_to_remove)} cache entries for {memory_id}"
)
return len(keys_to_remove)
def clear(self) -> int:
"""
Clear all cache entries.
Returns:
Number of entries cleared
"""
count = len(self._cache)
self._cache.clear()
logger.info(f"Cleared {count} cache entries")
return count
def get_stats(self) -> dict[str, Any]:
"""Get cache statistics."""
expired_count = sum(1 for e in self._cache.values() if e.is_expired())
return {
"total_entries": len(self._cache),
"expired_entries": expired_count,
"max_entries": self._max_entries,
"default_ttl_seconds": self._default_ttl,
}
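# Illustrative usage sketch: entries are keyed by RetrievalQuery.to_cache_key() and expire
# after their TTL.
#
#     cache = RetrievalCache(max_entries=100, default_ttl_seconds=60)
#     key = RetrievalQuery(query_text="deploy failures", limit=5).to_cache_key()
#     cache.put(key, results=[])  # normally a list of ScoredResult
#     assert cache.get(key) == []
#     cache.invalidate(key)
#     assert cache.get(key) is None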
class RetrievalEngine:
"""
Hybrid retrieval engine for memory search.
Combines multiple index types for comprehensive retrieval:
- Vector search for semantic similarity
- Temporal index for time-based filtering
- Entity index for entity-based lookups
- Outcome index for success/failure filtering
Results are scored and ranked using relevance scoring.
"""
def __init__(
self,
indexer: MemoryIndexer | None = None,
scorer: RelevanceScorer | None = None,
cache: RetrievalCache | None = None,
enable_cache: bool = True,
) -> None:
"""
Initialize the retrieval engine.
Args:
indexer: Memory indexer (defaults to singleton)
scorer: Relevance scorer (defaults to new instance)
cache: Retrieval cache (defaults to new instance)
enable_cache: Whether to enable result caching
"""
self._indexer = indexer or get_memory_indexer()
self._scorer = scorer or RelevanceScorer()
        self._cache = (cache or RetrievalCache()) if enable_cache else None
self._enable_cache = enable_cache
logger.info(f"Initialized RetrievalEngine with cache={enable_cache}")
async def retrieve(
self,
query: RetrievalQuery,
use_cache: bool = True,
) -> RetrievalResult[ScoredResult]:
"""
Retrieve relevant memories using hybrid search.
Args:
query: Retrieval query parameters
use_cache: Whether to use cached results
Returns:
Retrieval result with scored items
"""
start_time = _utcnow()
# Check cache
cache_key = query.to_cache_key()
if use_cache and self._cache:
cached = self._cache.get(cache_key)
if cached:
latency = (_utcnow() - start_time).total_seconds() * 1000
return RetrievalResult(
items=cached,
total_count=len(cached),
query=query.query_text or "",
retrieval_type="cached",
latency_ms=latency,
metadata={"cache_hit": True},
)
# Collect candidates from each index
candidates: dict[UUID, dict[str, Any]] = {}
# Vector search
if query.use_vector and query.query_embedding:
vector_results = await self._indexer.vector_index.search(
query=query.query_embedding,
limit=query.limit * 3, # Get more for filtering
min_similarity=query.min_relevance,
memory_type=query.memory_types[0] if query.memory_types else None,
)
for entry in vector_results:
if entry.memory_id not in candidates:
candidates[entry.memory_id] = {
"memory_type": entry.memory_type,
"sources": [],
}
candidates[entry.memory_id]["vector_similarity"] = entry.metadata.get(
"similarity", 0.5
)
candidates[entry.memory_id]["sources"].append("vector")
# Temporal search
if query.use_temporal and (
query.start_time or query.end_time or query.recent_seconds
):
temporal_results = await self._indexer.temporal_index.search(
query=None,
limit=query.limit * 3,
start_time=query.start_time,
end_time=query.end_time,
recent_seconds=query.recent_seconds,
memory_type=query.memory_types[0] if query.memory_types else None,
)
for temporal_entry in temporal_results:
if temporal_entry.memory_id not in candidates:
candidates[temporal_entry.memory_id] = {
"memory_type": temporal_entry.memory_type,
"sources": [],
}
candidates[temporal_entry.memory_id]["timestamp"] = (
temporal_entry.timestamp
)
candidates[temporal_entry.memory_id]["sources"].append("temporal")
# Entity search
if query.use_entity and query.entities:
entity_results = await self._indexer.entity_index.search(
query=None,
limit=query.limit * 3,
entities=query.entities,
match_all=query.entity_match_all,
memory_type=query.memory_types[0] if query.memory_types else None,
)
for entity_entry in entity_results:
if entity_entry.memory_id not in candidates:
candidates[entity_entry.memory_id] = {
"memory_type": entity_entry.memory_type,
"sources": [],
}
# Count entity matches
entity_count = candidates[entity_entry.memory_id].get(
"entity_match_count", 0
)
candidates[entity_entry.memory_id]["entity_match_count"] = (
entity_count + 1
)
candidates[entity_entry.memory_id]["sources"].append("entity")
# Outcome search
if query.use_outcome and query.outcomes:
outcome_results = await self._indexer.outcome_index.search(
query=None,
limit=query.limit * 3,
outcomes=query.outcomes,
memory_type=query.memory_types[0] if query.memory_types else None,
)
for outcome_entry in outcome_results:
if outcome_entry.memory_id not in candidates:
candidates[outcome_entry.memory_id] = {
"memory_type": outcome_entry.memory_type,
"sources": [],
}
candidates[outcome_entry.memory_id]["outcome"] = outcome_entry.outcome
candidates[outcome_entry.memory_id]["sources"].append("outcome")
# Score and rank candidates
scored_results: list[ScoredResult] = []
entity_total = len(query.entities) if query.entities else 1
for memory_id, data in candidates.items():
scored = self._scorer.score(
memory_id=memory_id,
memory_type=data["memory_type"],
vector_similarity=data.get("vector_similarity"),
timestamp=data.get("timestamp"),
entity_match_count=data.get("entity_match_count", 0),
entity_total=entity_total,
outcome=data.get("outcome"),
preferred_outcomes=query.outcomes,
)
scored.metadata["sources"] = data.get("sources", [])
# Filter by minimum relevance
if scored.relevance_score >= query.min_relevance:
scored_results.append(scored)
# Sort by relevance score
scored_results.sort(key=lambda x: x.relevance_score, reverse=True)
# Apply limit
final_results = scored_results[: query.limit]
# Cache results
if use_cache and self._cache and final_results:
self._cache.put(cache_key, final_results)
latency = (_utcnow() - start_time).total_seconds() * 1000
logger.info(
f"Retrieved {len(final_results)} results from {len(candidates)} candidates "
f"in {latency:.2f}ms"
)
return RetrievalResult(
items=final_results,
total_count=len(candidates),
query=query.query_text or "",
retrieval_type="hybrid",
latency_ms=latency,
metadata={
"cache_hit": False,
"candidates_count": len(candidates),
"filtered_count": len(scored_results),
},
)
async def retrieve_similar(
self,
embedding: list[float],
limit: int = 10,
min_similarity: float = 0.5,
memory_types: list[MemoryType] | None = None,
) -> RetrievalResult[ScoredResult]:
"""
Retrieve memories similar to a given embedding.
Args:
embedding: Query embedding
limit: Maximum results
min_similarity: Minimum similarity threshold
memory_types: Filter by memory types
Returns:
Retrieval result with scored items
"""
query = RetrievalQuery(
query_embedding=embedding,
limit=limit,
min_relevance=min_similarity,
memory_types=memory_types,
use_temporal=False,
use_entity=False,
use_outcome=False,
)
return await self.retrieve(query)
async def retrieve_recent(
self,
hours: float = 24,
limit: int = 10,
memory_types: list[MemoryType] | None = None,
) -> RetrievalResult[ScoredResult]:
"""
Retrieve recent memories.
Args:
hours: Number of hours to look back
limit: Maximum results
memory_types: Filter by memory types
Returns:
Retrieval result with scored items
"""
query = RetrievalQuery(
recent_seconds=hours * 3600,
limit=limit,
memory_types=memory_types,
use_vector=False,
use_entity=False,
use_outcome=False,
)
return await self.retrieve(query)
async def retrieve_by_entity(
self,
entity_type: str,
entity_value: str,
limit: int = 10,
memory_types: list[MemoryType] | None = None,
) -> RetrievalResult[ScoredResult]:
"""
Retrieve memories by entity.
Args:
entity_type: Type of entity
entity_value: Entity value
limit: Maximum results
memory_types: Filter by memory types
Returns:
Retrieval result with scored items
"""
query = RetrievalQuery(
entities=[(entity_type, entity_value)],
limit=limit,
memory_types=memory_types,
use_vector=False,
use_temporal=False,
use_outcome=False,
)
return await self.retrieve(query)
async def retrieve_successful(
self,
limit: int = 10,
memory_types: list[MemoryType] | None = None,
) -> RetrievalResult[ScoredResult]:
"""
Retrieve successful memories.
Args:
limit: Maximum results
memory_types: Filter by memory types
Returns:
Retrieval result with scored items
"""
query = RetrievalQuery(
outcomes=[Outcome.SUCCESS],
limit=limit,
memory_types=memory_types,
use_vector=False,
use_temporal=False,
use_entity=False,
)
return await self.retrieve(query)
def invalidate_cache(self) -> int:
"""
Invalidate all cached results.
Returns:
Number of entries invalidated
"""
if self._cache:
return self._cache.clear()
return 0
def invalidate_cache_for_memory(self, memory_id: UUID) -> int:
"""
Invalidate cache entries containing a specific memory.
Args:
memory_id: Memory ID to invalidate
Returns:
Number of entries invalidated
"""
if self._cache:
return self._cache.invalidate_by_memory(memory_id)
return 0
def get_cache_stats(self) -> dict[str, Any]:
"""Get cache statistics."""
if self._cache:
return self._cache.get_stats()
return {"enabled": False}
# Singleton retrieval engine instance
_engine: RetrievalEngine | None = None
def get_retrieval_engine() -> RetrievalEngine:
"""Get the singleton retrieval engine instance."""
global _engine
if _engine is None:
_engine = RetrievalEngine()
return _engine
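# End-to-end sketch (illustrative): hybrid retrieval over items previously added via
# get_memory_indexer().index(...). No embedding is available here, so vector and entity
# search are switched off.
#
#     async def recent_successes() -> None:
#         engine = get_retrieval_engine()
#         query = RetrievalQuery(
#             recent_seconds=24 * 3600,
#             outcomes=[Outcome.SUCCESS],
#             limit=5,
#             use_vector=False,
#             use_entity=False,
#         )
#         result = await engine.retrieve(query)
#         for item in result.items:
#             print(item.memory_id, round(item.relevance_score, 3), item.score_breakdown)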


@@ -0,0 +1,19 @@
# app/services/memory/integration/__init__.py
"""
Memory Integration Module.
Provides integration between the agent memory system and other Syndarix components:
- Context Engine: Memory as context source
- Agent Lifecycle: Spawn, pause, resume, terminate hooks
"""
from .context_source import MemoryContextSource, get_memory_context_source
from .lifecycle import AgentLifecycleManager, LifecycleHooks, get_lifecycle_manager
__all__ = [
"AgentLifecycleManager",
"LifecycleHooks",
"MemoryContextSource",
"get_lifecycle_manager",
"get_memory_context_source",
]


@@ -0,0 +1,399 @@
# app/services/memory/integration/context_source.py
"""
Memory Context Source.
Provides agent memory as a context source for the Context Engine.
Retrieves relevant memories based on query and converts them to MemoryContext objects.
"""
import logging
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from typing import Any
from uuid import UUID
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.context.types.memory import MemoryContext
from app.services.memory.episodic import EpisodicMemory
from app.services.memory.procedural import ProceduralMemory
from app.services.memory.semantic import SemanticMemory
from app.services.memory.working import WorkingMemory
logger = logging.getLogger(__name__)
@dataclass
class MemoryFetchConfig:
"""Configuration for memory fetching."""
# Limits per memory type
working_limit: int = 10
episodic_limit: int = 10
semantic_limit: int = 15
procedural_limit: int = 5
# Time ranges
episodic_days_back: int = 30
min_relevance: float = 0.3
# Which memory types to include
include_working: bool = True
include_episodic: bool = True
include_semantic: bool = True
include_procedural: bool = True
@dataclass
class MemoryFetchResult:
"""Result of memory fetch operation."""
contexts: list[MemoryContext]
by_type: dict[str, int]
fetch_time_ms: float
query: str
class MemoryContextSource:
"""
Source for memory context in the Context Engine.
This service retrieves relevant memories based on a query and
converts them to MemoryContext objects for context assembly.
It coordinates between all memory types (working, episodic,
semantic, procedural) to provide a comprehensive memory context.
"""
def __init__(
self,
session: AsyncSession,
embedding_generator: Any | None = None,
) -> None:
"""
Initialize the memory context source.
Args:
session: Database session
embedding_generator: Optional embedding generator for semantic search
"""
self._session = session
self._embedding_generator = embedding_generator
# Lazy-initialized memory services
self._episodic: EpisodicMemory | None = None
self._semantic: SemanticMemory | None = None
self._procedural: ProceduralMemory | None = None
async def _get_episodic(self) -> EpisodicMemory:
"""Get or create episodic memory service."""
if self._episodic is None:
self._episodic = await EpisodicMemory.create(
self._session,
self._embedding_generator,
)
return self._episodic
async def _get_semantic(self) -> SemanticMemory:
"""Get or create semantic memory service."""
if self._semantic is None:
self._semantic = await SemanticMemory.create(
self._session,
self._embedding_generator,
)
return self._semantic
async def _get_procedural(self) -> ProceduralMemory:
"""Get or create procedural memory service."""
if self._procedural is None:
self._procedural = await ProceduralMemory.create(
self._session,
self._embedding_generator,
)
return self._procedural
async def fetch_context(
self,
query: str,
project_id: UUID,
agent_instance_id: UUID | None = None,
agent_type_id: UUID | None = None,
session_id: str | None = None,
config: MemoryFetchConfig | None = None,
) -> MemoryFetchResult:
"""
Fetch relevant memories as context.
This is the main entry point for the Context Engine integration.
It searches across all memory types and returns relevant memories
as MemoryContext objects.
Args:
query: Search query for finding relevant memories
project_id: Project scope
agent_instance_id: Optional agent instance scope
agent_type_id: Optional agent type scope (for procedural)
session_id: Optional session ID (for working memory)
config: Optional fetch configuration
Returns:
MemoryFetchResult with contexts and metadata
"""
config = config or MemoryFetchConfig()
start_time = datetime.now(UTC)
contexts: list[MemoryContext] = []
by_type: dict[str, int] = {
"working": 0,
"episodic": 0,
"semantic": 0,
"procedural": 0,
}
# Fetch from working memory (session-scoped)
if config.include_working and session_id:
try:
working_contexts = await self._fetch_working(
query=query,
session_id=session_id,
project_id=project_id,
agent_instance_id=agent_instance_id,
limit=config.working_limit,
)
contexts.extend(working_contexts)
by_type["working"] = len(working_contexts)
except Exception as e:
logger.warning(f"Failed to fetch working memory: {e}")
# Fetch from episodic memory
if config.include_episodic:
try:
episodic_contexts = await self._fetch_episodic(
query=query,
project_id=project_id,
agent_instance_id=agent_instance_id,
limit=config.episodic_limit,
days_back=config.episodic_days_back,
)
contexts.extend(episodic_contexts)
by_type["episodic"] = len(episodic_contexts)
except Exception as e:
logger.warning(f"Failed to fetch episodic memory: {e}")
# Fetch from semantic memory
if config.include_semantic:
try:
semantic_contexts = await self._fetch_semantic(
query=query,
project_id=project_id,
limit=config.semantic_limit,
min_relevance=config.min_relevance,
)
contexts.extend(semantic_contexts)
by_type["semantic"] = len(semantic_contexts)
except Exception as e:
logger.warning(f"Failed to fetch semantic memory: {e}")
# Fetch from procedural memory
if config.include_procedural:
try:
procedural_contexts = await self._fetch_procedural(
query=query,
project_id=project_id,
agent_type_id=agent_type_id,
limit=config.procedural_limit,
)
contexts.extend(procedural_contexts)
by_type["procedural"] = len(procedural_contexts)
except Exception as e:
logger.warning(f"Failed to fetch procedural memory: {e}")
# Sort by relevance
contexts.sort(key=lambda c: c.relevance_score, reverse=True)
fetch_time = (datetime.now(UTC) - start_time).total_seconds() * 1000
logger.debug(
f"Fetched {len(contexts)} memory contexts for query '{query[:50]}...' "
f"in {fetch_time:.1f}ms"
)
return MemoryFetchResult(
contexts=contexts,
by_type=by_type,
fetch_time_ms=fetch_time,
query=query,
)
async def _fetch_working(
self,
query: str,
session_id: str,
project_id: UUID,
agent_instance_id: UUID | None,
limit: int,
) -> list[MemoryContext]:
"""Fetch from working memory."""
working = await WorkingMemory.for_session(
session_id=session_id,
project_id=str(project_id),
agent_instance_id=str(agent_instance_id) if agent_instance_id else None,
)
contexts: list[MemoryContext] = []
all_keys = await working.list_keys()
# Filter keys by query (simple substring match)
query_lower = query.lower()
matched_keys = [k for k in all_keys if query_lower in k.lower()]
# If no query match, include all keys (working memory is always relevant)
if not matched_keys and query:
matched_keys = all_keys
for key in matched_keys[:limit]:
value = await working.get(key)
if value is not None:
contexts.append(
MemoryContext.from_working_memory(
key=key,
value=value,
source=f"working:{session_id}",
query=query,
)
)
return contexts
async def _fetch_episodic(
self,
query: str,
project_id: UUID,
agent_instance_id: UUID | None,
limit: int,
days_back: int,
) -> list[MemoryContext]:
"""Fetch from episodic memory."""
episodic = await self._get_episodic()
# Search for similar episodes
episodes = await episodic.search_similar(
project_id=project_id,
query=query,
limit=limit,
agent_instance_id=agent_instance_id,
)
# Also get recent episodes if we didn't find enough
if len(episodes) < limit // 2:
since = datetime.now(UTC) - timedelta(days=days_back)
recent = await episodic.get_recent(
project_id=project_id,
limit=limit,
since=since,
)
# Deduplicate by ID
existing_ids = {e.id for e in episodes}
for ep in recent:
if ep.id not in existing_ids:
episodes.append(ep)
if len(episodes) >= limit:
break
return [
MemoryContext.from_episodic_memory(ep, query=query)
for ep in episodes[:limit]
]
async def _fetch_semantic(
self,
query: str,
project_id: UUID,
limit: int,
min_relevance: float,
) -> list[MemoryContext]:
"""Fetch from semantic memory."""
semantic = await self._get_semantic()
facts = await semantic.search_facts(
query=query,
project_id=project_id,
limit=limit,
min_confidence=min_relevance,
)
return [MemoryContext.from_semantic_memory(fact, query=query) for fact in facts]
async def _fetch_procedural(
self,
query: str,
project_id: UUID,
agent_type_id: UUID | None,
limit: int,
) -> list[MemoryContext]:
"""Fetch from procedural memory."""
procedural = await self._get_procedural()
procedures = await procedural.find_matching(
context=query,
project_id=project_id,
agent_type_id=agent_type_id,
limit=limit,
)
return [
MemoryContext.from_procedural_memory(proc, query=query)
for proc in procedures
]
async def fetch_all_working(
self,
session_id: str,
project_id: UUID,
agent_instance_id: UUID | None = None,
) -> list[MemoryContext]:
"""
Fetch all working memory for a session.
Useful for including entire session state in context.
Args:
session_id: Session ID
project_id: Project scope
agent_instance_id: Optional agent instance scope
Returns:
List of MemoryContext for all working memory items
"""
working = await WorkingMemory.for_session(
session_id=session_id,
project_id=str(project_id),
agent_instance_id=str(agent_instance_id) if agent_instance_id else None,
)
contexts: list[MemoryContext] = []
all_keys = await working.list_keys()
for key in all_keys:
value = await working.get(key)
if value is not None:
contexts.append(
MemoryContext.from_working_memory(
key=key,
value=value,
source=f"working:{session_id}",
)
)
return contexts
# Factory function
async def get_memory_context_source(
session: AsyncSession,
embedding_generator: Any | None = None,
) -> MemoryContextSource:
"""Create a memory context source instance."""
return MemoryContextSource(
session=session,
embedding_generator=embedding_generator,
)
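
A minimal usage sketch for the context source above, assuming the application already provides an AsyncSession (and, optionally, an embedding generator); the IDs and query below are placeholders:

# Sketch only: `session` comes from the application's own AsyncSession factory.
from uuid import uuid4

from sqlalchemy.ext.asyncio import AsyncSession

from app.services.memory.integration import get_memory_context_source
from app.services.memory.integration.context_source import MemoryFetchConfig


async def build_memory_context(session: AsyncSession) -> None:
    source = await get_memory_context_source(session)  # embedding_generator omitted
    result = await source.fetch_context(
        query="authentication errors",   # placeholder query
        project_id=uuid4(),              # placeholder project ID
        session_id="session-123",        # enables the working-memory branch
        config=MemoryFetchConfig(semantic_limit=5, include_procedural=False),
    )
    # Contexts arrive sorted by relevance_score (descending).
    for ctx in result.contexts[:3]:
        print(type(ctx).__name__)
    print(result.by_type, f"{result.fetch_time_ms:.1f}ms")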

View File

@@ -0,0 +1,635 @@
# app/services/memory/integration/lifecycle.py
"""
Agent Lifecycle Hooks for Memory System.
Provides memory management hooks for agent lifecycle events:
- spawn: Initialize working memory for new agent instance
- pause: Checkpoint working memory state
- resume: Restore working memory from checkpoint
- terminate: Consolidate session to episodic memory
"""
import logging
from collections.abc import Callable, Coroutine
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import Any
from uuid import UUID
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.memory.episodic import EpisodicMemory
from app.services.memory.types import EpisodeCreate, Outcome
from app.services.memory.working import WorkingMemory
logger = logging.getLogger(__name__)
@dataclass
class LifecycleEvent:
"""Event data for lifecycle hooks."""
event_type: str # spawn, pause, resume, terminate
project_id: UUID
agent_instance_id: UUID
agent_type_id: UUID | None = None
session_id: str | None = None
timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
class LifecycleResult:
"""Result of a lifecycle operation."""
success: bool
event_type: str
message: str | None = None
data: dict[str, Any] = field(default_factory=dict)
duration_ms: float = 0.0
# Type alias for lifecycle hooks
LifecycleHook = Callable[[LifecycleEvent], Coroutine[Any, Any, None]]
class LifecycleHooks:
"""
Collection of lifecycle hooks.
Allows registration of custom hooks for lifecycle events.
Hooks are called after the core memory operations.
"""
def __init__(self) -> None:
"""Initialize lifecycle hooks."""
self._spawn_hooks: list[LifecycleHook] = []
self._pause_hooks: list[LifecycleHook] = []
self._resume_hooks: list[LifecycleHook] = []
self._terminate_hooks: list[LifecycleHook] = []
def on_spawn(self, hook: LifecycleHook) -> LifecycleHook:
"""Register a spawn hook."""
self._spawn_hooks.append(hook)
return hook
def on_pause(self, hook: LifecycleHook) -> LifecycleHook:
"""Register a pause hook."""
self._pause_hooks.append(hook)
return hook
def on_resume(self, hook: LifecycleHook) -> LifecycleHook:
"""Register a resume hook."""
self._resume_hooks.append(hook)
return hook
def on_terminate(self, hook: LifecycleHook) -> LifecycleHook:
"""Register a terminate hook."""
self._terminate_hooks.append(hook)
return hook
async def run_spawn_hooks(self, event: LifecycleEvent) -> None:
"""Run all spawn hooks."""
for hook in self._spawn_hooks:
try:
await hook(event)
except Exception as e:
logger.warning(f"Spawn hook failed: {e}")
async def run_pause_hooks(self, event: LifecycleEvent) -> None:
"""Run all pause hooks."""
for hook in self._pause_hooks:
try:
await hook(event)
except Exception as e:
logger.warning(f"Pause hook failed: {e}")
async def run_resume_hooks(self, event: LifecycleEvent) -> None:
"""Run all resume hooks."""
for hook in self._resume_hooks:
try:
await hook(event)
except Exception as e:
logger.warning(f"Resume hook failed: {e}")
async def run_terminate_hooks(self, event: LifecycleEvent) -> None:
"""Run all terminate hooks."""
for hook in self._terminate_hooks:
try:
await hook(event)
except Exception as e:
logger.warning(f"Terminate hook failed: {e}")
class AgentLifecycleManager:
"""
Manager for agent lifecycle and memory integration.
Handles memory operations during agent lifecycle events:
- spawn: Creates new working memory for the session
- pause: Saves working memory state to checkpoint
- resume: Restores working memory from checkpoint
- terminate: Consolidates working memory to episodic memory
"""
# Key prefix for checkpoint storage
CHECKPOINT_PREFIX = "__checkpoint__"
def __init__(
self,
session: AsyncSession,
embedding_generator: Any | None = None,
hooks: LifecycleHooks | None = None,
) -> None:
"""
Initialize the lifecycle manager.
Args:
session: Database session
embedding_generator: Optional embedding generator
hooks: Optional lifecycle hooks
"""
self._session = session
self._embedding_generator = embedding_generator
self._hooks = hooks or LifecycleHooks()
# Lazy-initialized services
self._episodic: EpisodicMemory | None = None
async def _get_episodic(self) -> EpisodicMemory:
"""Get or create episodic memory service."""
if self._episodic is None:
self._episodic = await EpisodicMemory.create(
self._session,
self._embedding_generator,
)
return self._episodic
@property
def hooks(self) -> LifecycleHooks:
"""Get the lifecycle hooks."""
return self._hooks
async def spawn(
self,
project_id: UUID,
agent_instance_id: UUID,
session_id: str,
agent_type_id: UUID | None = None,
initial_state: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> LifecycleResult:
"""
Handle agent spawn - initialize working memory.
Creates a new working memory instance for the agent session
and optionally populates it with initial state.
Args:
project_id: Project scope
agent_instance_id: Agent instance ID
session_id: Session ID for working memory
agent_type_id: Optional agent type ID
initial_state: Optional initial state to populate
metadata: Optional metadata for the event
Returns:
LifecycleResult with spawn outcome
"""
start_time = datetime.now(UTC)
try:
# Create working memory for the session
working = await WorkingMemory.for_session(
session_id=session_id,
project_id=str(project_id),
agent_instance_id=str(agent_instance_id),
)
# Populate initial state if provided
items_set = 0
if initial_state:
for key, value in initial_state.items():
await working.set(key, value)
items_set += 1
# Create and run event hooks
event = LifecycleEvent(
event_type="spawn",
project_id=project_id,
agent_instance_id=agent_instance_id,
agent_type_id=agent_type_id,
session_id=session_id,
metadata=metadata or {},
)
await self._hooks.run_spawn_hooks(event)
duration_ms = (datetime.now(UTC) - start_time).total_seconds() * 1000
logger.info(
f"Agent {agent_instance_id} spawned with session {session_id}, "
f"initial state: {items_set} items"
)
return LifecycleResult(
success=True,
event_type="spawn",
message="Agent spawned successfully",
data={
"session_id": session_id,
"initial_items": items_set,
},
duration_ms=duration_ms,
)
except Exception as e:
logger.error(f"Spawn failed for agent {agent_instance_id}: {e}")
return LifecycleResult(
success=False,
event_type="spawn",
message=f"Spawn failed: {e}",
)
async def pause(
self,
project_id: UUID,
agent_instance_id: UUID,
session_id: str,
checkpoint_id: str | None = None,
metadata: dict[str, Any] | None = None,
) -> LifecycleResult:
"""
Handle agent pause - checkpoint working memory.
Saves the current working memory state to a checkpoint
that can be restored later with resume().
Args:
project_id: Project scope
agent_instance_id: Agent instance ID
session_id: Session ID
checkpoint_id: Optional checkpoint identifier
metadata: Optional metadata for the event
Returns:
LifecycleResult with checkpoint data
"""
start_time = datetime.now(UTC)
checkpoint_id = checkpoint_id or f"checkpoint_{int(start_time.timestamp())}"
try:
working = await WorkingMemory.for_session(
session_id=session_id,
project_id=str(project_id),
agent_instance_id=str(agent_instance_id),
)
# Get all current state
all_keys = await working.list_keys()
# Filter out checkpoint keys
state_keys = [
k for k in all_keys if not k.startswith(self.CHECKPOINT_PREFIX)
]
state: dict[str, Any] = {}
for key in state_keys:
value = await working.get(key)
if value is not None:
state[key] = value
# Store checkpoint
checkpoint_key = f"{self.CHECKPOINT_PREFIX}{checkpoint_id}"
await working.set(
checkpoint_key,
{
"state": state,
"timestamp": start_time.isoformat(),
"keys_count": len(state),
},
ttl_seconds=86400 * 7, # Keep checkpoint for 7 days
)
# Run hooks
event = LifecycleEvent(
event_type="pause",
project_id=project_id,
agent_instance_id=agent_instance_id,
session_id=session_id,
metadata={**(metadata or {}), "checkpoint_id": checkpoint_id},
)
await self._hooks.run_pause_hooks(event)
duration_ms = (datetime.now(UTC) - start_time).total_seconds() * 1000
logger.info(
f"Agent {agent_instance_id} paused, checkpoint {checkpoint_id} "
f"saved with {len(state)} items"
)
return LifecycleResult(
success=True,
event_type="pause",
message="Agent paused successfully",
data={
"checkpoint_id": checkpoint_id,
"items_saved": len(state),
"timestamp": start_time.isoformat(),
},
duration_ms=duration_ms,
)
except Exception as e:
logger.error(f"Pause failed for agent {agent_instance_id}: {e}")
return LifecycleResult(
success=False,
event_type="pause",
message=f"Pause failed: {e}",
)
async def resume(
self,
project_id: UUID,
agent_instance_id: UUID,
session_id: str,
checkpoint_id: str,
clear_current: bool = True,
metadata: dict[str, Any] | None = None,
) -> LifecycleResult:
"""
Handle agent resume - restore from checkpoint.
Restores working memory state from a previously saved checkpoint.
Args:
project_id: Project scope
agent_instance_id: Agent instance ID
session_id: Session ID
checkpoint_id: Checkpoint to restore from
clear_current: Whether to clear current state before restoring
metadata: Optional metadata for the event
Returns:
LifecycleResult with restore outcome
"""
start_time = datetime.now(UTC)
try:
working = await WorkingMemory.for_session(
session_id=session_id,
project_id=str(project_id),
agent_instance_id=str(agent_instance_id),
)
# Get checkpoint
checkpoint_key = f"{self.CHECKPOINT_PREFIX}{checkpoint_id}"
checkpoint = await working.get(checkpoint_key)
if checkpoint is None:
return LifecycleResult(
success=False,
event_type="resume",
message=f"Checkpoint '{checkpoint_id}' not found",
)
# Clear current state if requested
if clear_current:
all_keys = await working.list_keys()
for key in all_keys:
if not key.startswith(self.CHECKPOINT_PREFIX):
await working.delete(key)
# Restore state from checkpoint
state = checkpoint.get("state", {})
items_restored = 0
for key, value in state.items():
await working.set(key, value)
items_restored += 1
# Run hooks
event = LifecycleEvent(
event_type="resume",
project_id=project_id,
agent_instance_id=agent_instance_id,
session_id=session_id,
metadata={**(metadata or {}), "checkpoint_id": checkpoint_id},
)
await self._hooks.run_resume_hooks(event)
duration_ms = (datetime.now(UTC) - start_time).total_seconds() * 1000
logger.info(
f"Agent {agent_instance_id} resumed from checkpoint {checkpoint_id}, "
f"restored {items_restored} items"
)
return LifecycleResult(
success=True,
event_type="resume",
message="Agent resumed successfully",
data={
"checkpoint_id": checkpoint_id,
"items_restored": items_restored,
"checkpoint_timestamp": checkpoint.get("timestamp"),
},
duration_ms=duration_ms,
)
except Exception as e:
logger.error(f"Resume failed for agent {agent_instance_id}: {e}")
return LifecycleResult(
success=False,
event_type="resume",
message=f"Resume failed: {e}",
)
async def terminate(
self,
project_id: UUID,
agent_instance_id: UUID,
session_id: str,
task_description: str | None = None,
outcome: Outcome = Outcome.SUCCESS,
lessons_learned: list[str] | None = None,
consolidate_to_episodic: bool = True,
cleanup_working: bool = True,
metadata: dict[str, Any] | None = None,
) -> LifecycleResult:
"""
Handle agent termination - consolidate to episodic memory.
Consolidates the session's working memory into an episodic memory
entry, then optionally cleans up the working memory.
Args:
project_id: Project scope
agent_instance_id: Agent instance ID
session_id: Session ID
task_description: Description of what was accomplished
outcome: Task outcome (SUCCESS, FAILURE, PARTIAL)
lessons_learned: Optional list of lessons learned
consolidate_to_episodic: Whether to create episodic entry
cleanup_working: Whether to clear working memory
metadata: Optional metadata for the event
Returns:
LifecycleResult with termination outcome
"""
start_time = datetime.now(UTC)
try:
working = await WorkingMemory.for_session(
session_id=session_id,
project_id=str(project_id),
agent_instance_id=str(agent_instance_id),
)
# Gather session state for consolidation
all_keys = await working.list_keys()
state_keys = [
k for k in all_keys if not k.startswith(self.CHECKPOINT_PREFIX)
]
session_state: dict[str, Any] = {}
for key in state_keys:
value = await working.get(key)
if value is not None:
session_state[key] = value
episode_id: str | None = None
# Consolidate to episodic memory
if consolidate_to_episodic:
episodic = await self._get_episodic()
description = task_description or f"Session {session_id} completed"
episode_data = EpisodeCreate(
project_id=project_id,
agent_instance_id=agent_instance_id,
session_id=session_id,
task_type="session_completion",
task_description=description[:500],
outcome=outcome,
outcome_details=f"Session terminated with {len(session_state)} state items",
actions=[
{
"type": "session_terminate",
"state_keys": list(session_state.keys()),
"outcome": outcome.value,
}
],
context_summary=str(session_state)[:1000] if session_state else "",
lessons_learned=lessons_learned or [],
duration_seconds=0.0, # Unknown at this point
tokens_used=0,
importance_score=0.6, # Moderate importance for session ends
)
episode = await episodic.record_episode(episode_data)
episode_id = str(episode.id)
# Clean up working memory
items_cleared = 0
if cleanup_working:
for key in all_keys:
await working.delete(key)
items_cleared += 1
# Run hooks
event = LifecycleEvent(
event_type="terminate",
project_id=project_id,
agent_instance_id=agent_instance_id,
session_id=session_id,
metadata={**(metadata or {}), "episode_id": episode_id},
)
await self._hooks.run_terminate_hooks(event)
duration_ms = (datetime.now(UTC) - start_time).total_seconds() * 1000
logger.info(
f"Agent {agent_instance_id} terminated, session {session_id} "
f"consolidated to episode {episode_id}"
)
return LifecycleResult(
success=True,
event_type="terminate",
message="Agent terminated successfully",
data={
"episode_id": episode_id,
"state_items_consolidated": len(session_state),
"items_cleared": items_cleared,
"outcome": outcome.value,
},
duration_ms=duration_ms,
)
except Exception as e:
logger.error(f"Terminate failed for agent {agent_instance_id}: {e}")
return LifecycleResult(
success=False,
event_type="terminate",
message=f"Terminate failed: {e}",
)
async def list_checkpoints(
self,
project_id: UUID,
agent_instance_id: UUID,
session_id: str,
) -> list[dict[str, Any]]:
"""
List available checkpoints for a session.
Args:
project_id: Project scope
agent_instance_id: Agent instance ID
session_id: Session ID
Returns:
List of checkpoint metadata dicts
"""
working = await WorkingMemory.for_session(
session_id=session_id,
project_id=str(project_id),
agent_instance_id=str(agent_instance_id),
)
all_keys = await working.list_keys()
checkpoints: list[dict[str, Any]] = []
for key in all_keys:
if key.startswith(self.CHECKPOINT_PREFIX):
checkpoint_id = key[len(self.CHECKPOINT_PREFIX) :]
checkpoint = await working.get(key)
if checkpoint:
checkpoints.append(
{
"checkpoint_id": checkpoint_id,
"timestamp": checkpoint.get("timestamp"),
"keys_count": checkpoint.get("keys_count", 0),
}
)
# Sort by timestamp (newest first)
checkpoints.sort(
key=lambda c: c.get("timestamp", ""),
reverse=True,
)
return checkpoints
# Factory function
async def get_lifecycle_manager(
session: AsyncSession,
embedding_generator: Any | None = None,
hooks: LifecycleHooks | None = None,
) -> AgentLifecycleManager:
"""Create a lifecycle manager instance."""
return AgentLifecycleManager(
session=session,
embedding_generator=embedding_generator,
hooks=hooks,
)
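
A sketch of the spawn -> pause -> resume -> terminate flow, including a custom terminate hook; `session` is again an application-provided AsyncSession and the IDs are placeholders:

from uuid import uuid4

from sqlalchemy.ext.asyncio import AsyncSession

from app.services.memory.integration import LifecycleHooks, get_lifecycle_manager
from app.services.memory.types import Outcome

hooks = LifecycleHooks()


@hooks.on_terminate
async def log_termination(event) -> None:
    # Custom hook: runs after the core terminate logic (consolidation/cleanup).
    print("terminated", event.agent_instance_id, event.metadata.get("episode_id"))


async def run_agent_session(session: AsyncSession) -> None:
    manager = await get_lifecycle_manager(session, hooks=hooks)
    project_id, agent_id, sid = uuid4(), uuid4(), "session-123"  # placeholders

    await manager.spawn(project_id, agent_id, sid, initial_state={"step": 1})
    paused = await manager.pause(project_id, agent_id, sid)
    if paused.success:
        await manager.resume(
            project_id, agent_id, sid, checkpoint_id=paused.data["checkpoint_id"]
        )
    await manager.terminate(
        project_id,
        agent_id,
        sid,
        task_description="Demo run",
        outcome=Outcome.SUCCESS,
        lessons_learned=["Checkpoints are kept for 7 days"],
    )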

View File

@@ -0,0 +1,606 @@
"""
Memory Manager
Facade for the Agent Memory System providing unified access
to all memory types and operations.
"""
import logging
from typing import Any
from uuid import UUID
from .config import MemorySettings, get_memory_settings
from .types import (
Episode,
EpisodeCreate,
Fact,
FactCreate,
MemoryStats,
MemoryType,
Outcome,
Procedure,
ProcedureCreate,
RetrievalResult,
ScopeContext,
ScopeLevel,
TaskState,
)
logger = logging.getLogger(__name__)
class MemoryManager:
"""
Unified facade for the Agent Memory System.
Provides a single entry point for all memory operations across
working, episodic, semantic, and procedural memory types.
Usage:
manager = MemoryManager.create()
# Working memory
await manager.set_working("key", {"data": "value"})
value = await manager.get_working("key")
# Episodic memory
episode = await manager.record_episode(episode_data)
similar = await manager.search_episodes("query")
# Semantic memory
fact = await manager.store_fact(fact_data)
facts = await manager.search_facts("query")
# Procedural memory
procedure = await manager.record_procedure(procedure_data)
procedures = await manager.find_procedures("context")
"""
def __init__(
self,
settings: MemorySettings,
scope: ScopeContext,
) -> None:
"""
Initialize the MemoryManager.
Args:
settings: Memory configuration settings
scope: The scope context for this manager instance
"""
self._settings = settings
self._scope = scope
self._initialized = False
# These will be initialized when the respective sub-modules are implemented
self._working_memory: Any | None = None
self._episodic_memory: Any | None = None
self._semantic_memory: Any | None = None
self._procedural_memory: Any | None = None
logger.debug(
"MemoryManager created for scope %s:%s",
scope.scope_type.value,
scope.scope_id,
)
@classmethod
def create(
cls,
scope_type: ScopeLevel = ScopeLevel.SESSION,
scope_id: str = "default",
parent_scope: ScopeContext | None = None,
settings: MemorySettings | None = None,
) -> "MemoryManager":
"""
Create a new MemoryManager instance.
Args:
scope_type: The scope level for this manager
scope_id: The scope identifier
parent_scope: Optional parent scope for inheritance
settings: Optional custom settings (uses global if not provided)
Returns:
A new MemoryManager instance
"""
if settings is None:
settings = get_memory_settings()
scope = ScopeContext(
scope_type=scope_type,
scope_id=scope_id,
parent=parent_scope,
)
return cls(settings=settings, scope=scope)
@classmethod
def for_session(
cls,
session_id: str,
agent_instance_id: UUID | None = None,
project_id: UUID | None = None,
) -> "MemoryManager":
"""
Create a MemoryManager for a specific session.
Builds the appropriate scope hierarchy based on provided IDs.
Args:
session_id: The session identifier
agent_instance_id: Optional agent instance ID
project_id: Optional project ID
Returns:
A MemoryManager configured for the session scope
"""
settings = get_memory_settings()
# Build scope hierarchy
parent: ScopeContext | None = None
if project_id:
parent = ScopeContext(
scope_type=ScopeLevel.PROJECT,
scope_id=str(project_id),
parent=ScopeContext(
scope_type=ScopeLevel.GLOBAL,
scope_id="global",
),
)
if agent_instance_id:
parent = ScopeContext(
scope_type=ScopeLevel.AGENT_INSTANCE,
scope_id=str(agent_instance_id),
parent=parent,
)
scope = ScopeContext(
scope_type=ScopeLevel.SESSION,
scope_id=session_id,
parent=parent,
)
return cls(settings=settings, scope=scope)
@property
def scope(self) -> ScopeContext:
"""Get the current scope context."""
return self._scope
@property
def settings(self) -> MemorySettings:
"""Get the memory settings."""
return self._settings
# =========================================================================
# Working Memory Operations
# =========================================================================
async def set_working(
self,
key: str,
value: Any,
ttl_seconds: int | None = None,
) -> None:
"""
Set a value in working memory.
Args:
key: The key to store the value under
value: The value to store (must be JSON serializable)
ttl_seconds: Optional TTL (uses default if not provided)
"""
# Placeholder - will be implemented in #89
logger.debug("set_working called for key=%s (not yet implemented)", key)
raise NotImplementedError("Working memory not yet implemented")
async def get_working(
self,
key: str,
default: Any = None,
) -> Any:
"""
Get a value from working memory.
Args:
key: The key to retrieve
default: Default value if key not found
Returns:
The stored value or default
"""
# Placeholder - will be implemented in #89
logger.debug("get_working called for key=%s (not yet implemented)", key)
raise NotImplementedError("Working memory not yet implemented")
async def delete_working(self, key: str) -> bool:
"""
Delete a value from working memory.
Args:
key: The key to delete
Returns:
True if the key was deleted, False if not found
"""
# Placeholder - will be implemented in #89
logger.debug("delete_working called for key=%s (not yet implemented)", key)
raise NotImplementedError("Working memory not yet implemented")
async def set_task_state(self, state: TaskState) -> None:
"""
Set the current task state in working memory.
Args:
state: The task state to store
"""
# Placeholder - will be implemented in #89
logger.debug(
"set_task_state called for task=%s (not yet implemented)",
state.task_id,
)
raise NotImplementedError("Working memory not yet implemented")
async def get_task_state(self) -> TaskState | None:
"""
Get the current task state from working memory.
Returns:
The current task state or None
"""
# Placeholder - will be implemented in #89
logger.debug("get_task_state called (not yet implemented)")
raise NotImplementedError("Working memory not yet implemented")
async def create_checkpoint(self) -> str:
"""
Create a checkpoint of the current working memory state.
Returns:
The checkpoint ID
"""
# Placeholder - will be implemented in #89
logger.debug("create_checkpoint called (not yet implemented)")
raise NotImplementedError("Working memory not yet implemented")
async def restore_checkpoint(self, checkpoint_id: str) -> None:
"""
Restore working memory from a checkpoint.
Args:
checkpoint_id: The checkpoint to restore from
"""
# Placeholder - will be implemented in #89
logger.debug(
"restore_checkpoint called for id=%s (not yet implemented)",
checkpoint_id,
)
raise NotImplementedError("Working memory not yet implemented")
# =========================================================================
# Episodic Memory Operations
# =========================================================================
async def record_episode(self, episode: EpisodeCreate) -> Episode:
"""
Record a new episode in episodic memory.
Args:
episode: The episode data to record
Returns:
The created episode with ID
"""
# Placeholder - will be implemented in #90
logger.debug(
"record_episode called for task=%s (not yet implemented)",
episode.task_type,
)
raise NotImplementedError("Episodic memory not yet implemented")
async def search_episodes(
self,
query: str,
limit: int | None = None,
) -> RetrievalResult[Episode]:
"""
Search for similar episodes.
Args:
query: The search query
limit: Maximum results to return
Returns:
Retrieval result with matching episodes
"""
# Placeholder - will be implemented in #90
logger.debug(
"search_episodes called for query=%s (not yet implemented)",
query[:50],
)
raise NotImplementedError("Episodic memory not yet implemented")
async def get_recent_episodes(
self,
limit: int = 10,
) -> list[Episode]:
"""
Get the most recent episodes.
Args:
limit: Maximum episodes to return
Returns:
List of recent episodes
"""
# Placeholder - will be implemented in #90
logger.debug("get_recent_episodes called (not yet implemented)")
raise NotImplementedError("Episodic memory not yet implemented")
async def get_episodes_by_outcome(
self,
outcome: Outcome,
limit: int = 10,
) -> list[Episode]:
"""
Get episodes by outcome.
Args:
outcome: The outcome to filter by
limit: Maximum episodes to return
Returns:
List of episodes with the specified outcome
"""
# Placeholder - will be implemented in #90
logger.debug(
"get_episodes_by_outcome called for outcome=%s (not yet implemented)",
outcome.value,
)
raise NotImplementedError("Episodic memory not yet implemented")
# =========================================================================
# Semantic Memory Operations
# =========================================================================
async def store_fact(self, fact: FactCreate) -> Fact:
"""
Store a new fact in semantic memory.
Args:
fact: The fact data to store
Returns:
The created fact with ID
"""
# Placeholder - will be implemented in #91
logger.debug(
"store_fact called for %s %s %s (not yet implemented)",
fact.subject,
fact.predicate,
fact.object,
)
raise NotImplementedError("Semantic memory not yet implemented")
async def search_facts(
self,
query: str,
limit: int | None = None,
) -> RetrievalResult[Fact]:
"""
Search for facts matching a query.
Args:
query: The search query
limit: Maximum results to return
Returns:
Retrieval result with matching facts
"""
# Placeholder - will be implemented in #91
logger.debug(
"search_facts called for query=%s (not yet implemented)",
query[:50],
)
raise NotImplementedError("Semantic memory not yet implemented")
async def get_facts_by_entity(
self,
entity: str,
limit: int = 20,
) -> list[Fact]:
"""
Get facts related to an entity.
Args:
entity: The entity to search for
limit: Maximum facts to return
Returns:
List of facts mentioning the entity
"""
# Placeholder - will be implemented in #91
logger.debug(
"get_facts_by_entity called for entity=%s (not yet implemented)",
entity,
)
raise NotImplementedError("Semantic memory not yet implemented")
async def reinforce_fact(self, fact_id: UUID) -> Fact:
"""
Reinforce a fact (increase confidence from repeated learning).
Args:
fact_id: The fact to reinforce
Returns:
The updated fact
"""
# Placeholder - will be implemented in #91
logger.debug(
"reinforce_fact called for id=%s (not yet implemented)",
fact_id,
)
raise NotImplementedError("Semantic memory not yet implemented")
# =========================================================================
# Procedural Memory Operations
# =========================================================================
async def record_procedure(self, procedure: ProcedureCreate) -> Procedure:
"""
Record a new procedure.
Args:
procedure: The procedure data to record
Returns:
The created procedure with ID
"""
# Placeholder - will be implemented in #92
logger.debug(
"record_procedure called for name=%s (not yet implemented)",
procedure.name,
)
raise NotImplementedError("Procedural memory not yet implemented")
async def find_procedures(
self,
context: str,
limit: int = 5,
) -> list[Procedure]:
"""
Find procedures matching the current context.
Args:
context: The context to match against
limit: Maximum procedures to return
Returns:
List of matching procedures sorted by success rate
"""
# Placeholder - will be implemented in #92
logger.debug(
"find_procedures called for context=%s (not yet implemented)",
context[:50],
)
raise NotImplementedError("Procedural memory not yet implemented")
async def record_procedure_outcome(
self,
procedure_id: UUID,
success: bool,
) -> None:
"""
Record the outcome of using a procedure.
Args:
procedure_id: The procedure that was used
success: Whether the procedure succeeded
"""
# Placeholder - will be implemented in #92
logger.debug(
"record_procedure_outcome called for id=%s success=%s (not yet implemented)",
procedure_id,
success,
)
raise NotImplementedError("Procedural memory not yet implemented")
# =========================================================================
# Cross-Memory Operations
# =========================================================================
async def recall(
self,
query: str,
memory_types: list[MemoryType] | None = None,
limit: int = 10,
) -> dict[MemoryType, list[Any]]:
"""
Recall memories across multiple memory types.
Args:
query: The search query
memory_types: Memory types to search (all if not specified)
limit: Maximum results per type
Returns:
Dictionary mapping memory types to results
"""
# Placeholder - will be implemented in #97 (Component Integration)
logger.debug("recall called for query=%s (not yet implemented)", query[:50])
raise NotImplementedError("Cross-memory recall not yet implemented")
async def get_stats(
self,
memory_type: MemoryType | None = None,
) -> list[MemoryStats]:
"""
Get memory statistics.
Args:
memory_type: Specific type or all if not specified
Returns:
List of statistics for requested memory types
"""
# Placeholder - will be implemented in #100 (Metrics & Observability)
logger.debug("get_stats called (not yet implemented)")
raise NotImplementedError("Memory stats not yet implemented")
# =========================================================================
# Lifecycle Operations
# =========================================================================
async def initialize(self) -> None:
"""
Initialize the memory manager and its backends.
Should be called before using the manager.
"""
if self._initialized:
logger.debug("MemoryManager already initialized")
return
logger.info(
"Initializing MemoryManager for scope %s:%s",
self._scope.scope_type.value,
self._scope.scope_id,
)
# TODO: Initialize backends when implemented
self._initialized = True
logger.info("MemoryManager initialized successfully")
async def close(self) -> None:
"""
Close the memory manager and release resources.
Should be called when done using the manager.
"""
if not self._initialized:
return
logger.info(
"Closing MemoryManager for scope %s:%s",
self._scope.scope_type.value,
self._scope.scope_id,
)
# TODO: Close backends when implemented
self._initialized = False
logger.info("MemoryManager closed successfully")
async def __aenter__(self) -> "MemoryManager":
"""Async context manager entry."""
await self.initialize()
return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
"""Async context manager exit."""
await self.close()
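
Since most operations above are placeholders that raise NotImplementedError, the sketch below only exercises the parts that are already implemented: scope construction via for_session and the initialize/close lifecycle (module path assumed to be app.services.memory.manager):

import asyncio
from uuid import uuid4

from app.services.memory.manager import MemoryManager  # assumed module path


async def demo_manager() -> None:
    async with MemoryManager.for_session(
        session_id="session-123",     # placeholder IDs
        agent_instance_id=uuid4(),
        project_id=uuid4(),
    ) as manager:
        # Walk the scope chain: session -> agent_instance -> project -> global
        scope = manager.scope
        while scope is not None:
            print(scope.scope_type.value, scope.scope_id)
            scope = scope.parent


asyncio.run(demo_manager())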

View File

@@ -0,0 +1,40 @@
# app/services/memory/mcp/__init__.py
"""
MCP Tools for Agent Memory System.
Exposes memory operations as MCP-compatible tools that agents can invoke:
- remember: Store data in memory
- recall: Retrieve from memory
- forget: Remove from memory
- reflect: Analyze patterns
- get_memory_stats: Usage statistics
- search_procedures: Find relevant procedures
- record_outcome: Record task success/failure
"""
from .service import MemoryToolService, get_memory_tool_service
from .tools import (
MEMORY_TOOL_DEFINITIONS,
ForgetArgs,
GetMemoryStatsArgs,
MemoryToolDefinition,
RecallArgs,
RecordOutcomeArgs,
ReflectArgs,
RememberArgs,
SearchProceduresArgs,
)
__all__ = [
"MEMORY_TOOL_DEFINITIONS",
"ForgetArgs",
"GetMemoryStatsArgs",
"MemoryToolDefinition",
"MemoryToolService",
"RecallArgs",
"RecordOutcomeArgs",
"ReflectArgs",
"RememberArgs",
"SearchProceduresArgs",
"get_memory_tool_service",
]

File diff suppressed because it is too large

View File

@@ -0,0 +1,485 @@
# app/services/memory/mcp/tools.py
"""
MCP Tool Definitions for Agent Memory System.
Defines the schema and metadata for memory-related MCP tools.
These tools are invoked by AI agents to interact with the memory system.
"""
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field
# OutcomeType alias - uses core Outcome enum from types module for consistency
from app.services.memory.types import Outcome as OutcomeType
class MemoryType(str, Enum):
"""Types of memory for storage operations."""
WORKING = "working"
EPISODIC = "episodic"
SEMANTIC = "semantic"
PROCEDURAL = "procedural"
class AnalysisType(str, Enum):
"""Types of pattern analysis for the reflect tool."""
RECENT_PATTERNS = "recent_patterns"
SUCCESS_FACTORS = "success_factors"
FAILURE_PATTERNS = "failure_patterns"
COMMON_PROCEDURES = "common_procedures"
LEARNING_PROGRESS = "learning_progress"
# ============================================================================
# Tool Argument Schemas (Pydantic models for validation)
# ============================================================================
class RememberArgs(BaseModel):
"""Arguments for the 'remember' tool."""
memory_type: MemoryType = Field(
...,
description="Type of memory to store in: working, episodic, semantic, or procedural",
)
content: str = Field(
...,
description="The content to remember. Can be text, facts, or procedure steps.",
min_length=1,
max_length=10000,
)
key: str | None = Field(
None,
description="Optional key for working memory entries. Required for working memory type.",
max_length=256,
)
importance: float = Field(
0.5,
description="Importance score from 0.0 (low) to 1.0 (critical)",
ge=0.0,
le=1.0,
)
ttl_seconds: int | None = Field(
None,
description="Time-to-live in seconds for working memory. None for permanent storage.",
ge=1,
le=86400 * 30, # Max 30 days
)
metadata: dict[str, Any] = Field(
default_factory=dict,
description="Additional metadata to store with the memory",
)
# For semantic memory (facts)
subject: str | None = Field(
None,
description="Subject of the fact (for semantic memory)",
max_length=256,
)
predicate: str | None = Field(
None,
description="Predicate/relationship (for semantic memory)",
max_length=256,
)
object_value: str | None = Field(
None,
description="Object of the fact (for semantic memory)",
max_length=1000,
)
# For procedural memory
trigger: str | None = Field(
None,
description="Trigger condition for the procedure (for procedural memory)",
max_length=500,
)
steps: list[dict[str, Any]] | None = Field(
None,
description="Procedure steps as a list of action dictionaries",
)
class RecallArgs(BaseModel):
"""Arguments for the 'recall' tool."""
query: str = Field(
...,
description="Search query to find relevant memories",
min_length=1,
max_length=1000,
)
memory_types: list[MemoryType] = Field(
default_factory=lambda: [MemoryType.EPISODIC, MemoryType.SEMANTIC],
description="Types of memory to search in",
)
limit: int = Field(
10,
description="Maximum number of results to return",
ge=1,
le=100,
)
min_relevance: float = Field(
0.0,
description="Minimum relevance score (0.0-1.0) for results",
ge=0.0,
le=1.0,
)
filters: dict[str, Any] = Field(
default_factory=dict,
description="Additional filters (e.g., outcome, task_type, date range)",
)
include_context: bool = Field(
True,
description="Whether to include surrounding context in results",
)
class ForgetArgs(BaseModel):
"""Arguments for the 'forget' tool."""
memory_type: MemoryType = Field(
...,
description="Type of memory to remove from",
)
key: str | None = Field(
None,
description="Key to remove (for working memory)",
max_length=256,
)
memory_id: str | None = Field(
None,
description="Specific memory ID to remove (for episodic/semantic/procedural)",
)
pattern: str | None = Field(
None,
description="Pattern to match for bulk removal (use with caution)",
max_length=500,
)
confirm_bulk: bool = Field(
False,
description="Must be True to confirm bulk deletion when using pattern",
)
class ReflectArgs(BaseModel):
"""Arguments for the 'reflect' tool."""
analysis_type: AnalysisType = Field(
...,
description="Type of pattern analysis to perform",
)
scope: str | None = Field(
None,
description="Optional scope to limit analysis (e.g., task_type, time range)",
max_length=500,
)
depth: int = Field(
3,
description="Depth of analysis (1=surface, 5=deep)",
ge=1,
le=5,
)
include_examples: bool = Field(
True,
description="Whether to include example memories in the analysis",
)
max_items: int = Field(
10,
description="Maximum number of patterns/examples to analyze",
ge=1,
le=50,
)
class GetMemoryStatsArgs(BaseModel):
"""Arguments for the 'get_memory_stats' tool."""
include_breakdown: bool = Field(
True,
description="Include breakdown by memory type",
)
include_recent_activity: bool = Field(
True,
description="Include recent memory activity summary",
)
time_range_days: int = Field(
7,
description="Time range for activity analysis in days",
ge=1,
le=90,
)
class SearchProceduresArgs(BaseModel):
"""Arguments for the 'search_procedures' tool."""
trigger: str = Field(
...,
description="Trigger or situation to find procedures for",
min_length=1,
max_length=500,
)
task_type: str | None = Field(
None,
description="Optional task type to filter procedures",
max_length=100,
)
min_success_rate: float = Field(
0.5,
description="Minimum success rate (0.0-1.0) for returned procedures",
ge=0.0,
le=1.0,
)
limit: int = Field(
5,
description="Maximum number of procedures to return",
ge=1,
le=20,
)
include_steps: bool = Field(
True,
description="Whether to include detailed steps in the response",
)
class RecordOutcomeArgs(BaseModel):
"""Arguments for the 'record_outcome' tool."""
task_type: str = Field(
...,
description="Type of task that was executed",
min_length=1,
max_length=100,
)
outcome: OutcomeType = Field(
...,
description="Outcome of the task execution",
)
procedure_id: str | None = Field(
None,
description="ID of the procedure that was followed (if any)",
)
context: dict[str, Any] = Field(
default_factory=dict,
description="Context in which the task was executed",
)
lessons_learned: str | None = Field(
None,
description="What was learned from this execution",
max_length=2000,
)
duration_seconds: float | None = Field(
None,
description="How long the task took to execute",
ge=0.0,
)
error_details: str | None = Field(
None,
description="Details about any errors encountered (for failures)",
max_length=2000,
)
# ============================================================================
# Tool Definition Structure
# ============================================================================
@dataclass
class MemoryToolDefinition:
"""Definition of an MCP tool for the memory system."""
name: str
description: str
args_schema: type[BaseModel]
input_schema: dict[str, Any] = field(default_factory=dict)
def __post_init__(self) -> None:
"""Generate input schema from Pydantic model."""
if not self.input_schema:
self.input_schema = self.args_schema.model_json_schema()
def to_mcp_format(self) -> dict[str, Any]:
"""Convert to MCP tool format."""
return {
"name": self.name,
"description": self.description,
"inputSchema": self.input_schema,
}
def validate_args(self, args: dict[str, Any]) -> BaseModel:
"""Validate and parse arguments."""
return self.args_schema.model_validate(args)
# ============================================================================
# Tool Definitions
# ============================================================================
REMEMBER_TOOL = MemoryToolDefinition(
name="remember",
description="""Store information in the agent's memory system.
Use this tool to:
- Store temporary data in working memory (key-value with optional TTL)
- Record important events in episodic memory (automatically done on session end)
- Store facts/knowledge in semantic memory (subject-predicate-object triples)
- Save procedures in procedural memory (trigger conditions and steps)
Examples:
- Working memory: {"memory_type": "working", "key": "current_task", "content": "Implementing auth", "ttl_seconds": 3600}
- Semantic fact: {"memory_type": "semantic", "subject": "User", "predicate": "prefers", "object_value": "dark mode", "content": "User preference noted"}
- Procedure: {"memory_type": "procedural", "trigger": "When creating a new file", "steps": [{"action": "check_exists"}, {"action": "create"}], "content": "File creation procedure"}
""",
args_schema=RememberArgs,
)
RECALL_TOOL = MemoryToolDefinition(
name="recall",
description="""Retrieve information from the agent's memory system.
Use this tool to:
- Search for relevant past experiences (episodic)
- Look up known facts and knowledge (semantic)
- Find applicable procedures for current task (procedural)
- Get current session state (working)
The query supports semantic search - describe what you're looking for in natural language.
Examples:
- {"query": "How did I handle authentication errors before?", "memory_types": ["episodic"]}
- {"query": "What are the user's preferences?", "memory_types": ["semantic"], "limit": 5}
- {"query": "database connection", "memory_types": ["episodic", "semantic", "procedural"], "filters": {"outcome": "success"}}
""",
args_schema=RecallArgs,
)
FORGET_TOOL = MemoryToolDefinition(
name="forget",
description="""Remove information from the agent's memory system.
Use this tool to:
- Clear temporary working memory entries
- Remove specific memories by ID
- Bulk remove memories matching a pattern (requires confirmation)
WARNING: Deletion is permanent. Use with caution.
Examples:
- Working memory: {"memory_type": "working", "key": "temp_calculation"}
- Specific memory: {"memory_type": "episodic", "memory_id": "ep-123"}
- Bulk (requires confirm): {"memory_type": "working", "pattern": "cache_*", "confirm_bulk": true}
""",
args_schema=ForgetArgs,
)
REFLECT_TOOL = MemoryToolDefinition(
name="reflect",
description="""Analyze patterns in the agent's memory to gain insights.
Use this tool to:
- Identify patterns in recent work
- Understand what leads to success/failure
- Learn from past experiences
- Track learning progress over time
Analysis types:
- recent_patterns: What patterns appear in recent work
- success_factors: What conditions lead to success
- failure_patterns: What causes failures and how to avoid them
- common_procedures: Most frequently used procedures
- learning_progress: How knowledge has grown over time
Examples:
- {"analysis_type": "success_factors", "scope": "code_review", "depth": 3}
- {"analysis_type": "failure_patterns", "include_examples": true, "max_items": 5}
""",
args_schema=ReflectArgs,
)
GET_MEMORY_STATS_TOOL = MemoryToolDefinition(
name="get_memory_stats",
description="""Get statistics about the agent's memory usage.
Returns information about:
- Total memories stored by type
- Storage utilization
- Recent activity summary
- Memory health indicators
Use this to understand memory capacity and usage patterns.
Examples:
- {"include_breakdown": true, "include_recent_activity": true}
- {"time_range_days": 30, "include_breakdown": true}
""",
args_schema=GetMemoryStatsArgs,
)
SEARCH_PROCEDURES_TOOL = MemoryToolDefinition(
name="search_procedures",
description="""Find relevant procedures for a given situation.
Use this tool when you need to:
- Find the best way to handle a situation
- Look up proven approaches to problems
- Get step-by-step guidance for tasks
Returns procedures ranked by relevance and success rate.
Examples:
- {"trigger": "Deploying to production", "min_success_rate": 0.8}
- {"trigger": "Handling merge conflicts", "task_type": "git_operations", "limit": 3}
""",
args_schema=SearchProceduresArgs,
)
RECORD_OUTCOME_TOOL = MemoryToolDefinition(
name="record_outcome",
description="""Record the outcome of a task execution.
Use this tool after completing a task to:
- Update procedure success/failure rates
- Store lessons learned for future reference
- Improve procedure recommendations
This helps the memory system learn from experience.
Examples:
- {"task_type": "code_review", "outcome": "success", "lessons_learned": "Breaking changes caught early"}
- {"task_type": "deployment", "outcome": "failure", "error_details": "Database migration timeout", "lessons_learned": "Need to test migrations locally first"}
""",
args_schema=RecordOutcomeArgs,
)
# All tool definitions in a dictionary for easy lookup
MEMORY_TOOL_DEFINITIONS: dict[str, MemoryToolDefinition] = {
"remember": REMEMBER_TOOL,
"recall": RECALL_TOOL,
"forget": FORGET_TOOL,
"reflect": REFLECT_TOOL,
"get_memory_stats": GET_MEMORY_STATS_TOOL,
"search_procedures": SEARCH_PROCEDURES_TOOL,
"record_outcome": RECORD_OUTCOME_TOOL,
}
def get_all_tool_schemas() -> list[dict[str, Any]]:
"""Get MCP-formatted schemas for all memory tools."""
return [tool.to_mcp_format() for tool in MEMORY_TOOL_DEFINITIONS.values()]
def get_tool_definition(name: str) -> MemoryToolDefinition | None:
"""Get a specific tool definition by name."""
return MEMORY_TOOL_DEFINITIONS.get(name)
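
A short sketch of how a host might consume these definitions: advertise all tools in MCP format, then validate an incoming call before dispatching it (the argument payload is a made-up example):

from app.services.memory.mcp.tools import get_all_tool_schemas, get_tool_definition

# Advertise every memory tool as {"name", "description", "inputSchema"}.
tool_schemas = get_all_tool_schemas()

# Validate an incoming "recall" call against its Pydantic schema.
tool = get_tool_definition("recall")
if tool is not None:
    args = tool.validate_args(
        {
            "query": "How did I handle authentication errors before?",
            "memory_types": ["episodic", "semantic"],
            "limit": 5,
        }
    )
    print(type(args).__name__, args.limit)  # RecallArgs 5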

View File

@@ -0,0 +1,18 @@
# app/services/memory/metrics/__init__.py
"""Memory Metrics module."""
from .collector import (
MemoryMetrics,
get_memory_metrics,
record_memory_operation,
record_retrieval,
reset_memory_metrics,
)
__all__ = [
"MemoryMetrics",
"get_memory_metrics",
"record_memory_operation",
"record_retrieval",
"reset_memory_metrics",
]
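
A usage sketch against the MemoryMetrics class defined in the collector below; the module-level helpers (record_memory_operation, record_retrieval, get_memory_metrics) presumably wrap a shared instance, but their signatures are not reproduced here, so the sketch instantiates the class directly:

import asyncio

from app.services.memory.metrics import MemoryMetrics


async def demo_metrics() -> None:
    metrics = MemoryMetrics()
    await metrics.inc_operations("set", "working", scope="session", success=True)
    await metrics.inc_cache_hit("retrieval")
    await metrics.observe_retrieval_latency(0.042)
    print(await metrics.get_prometheus_format())


asyncio.run(demo_metrics())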

View File

@@ -0,0 +1,542 @@
# app/services/memory/metrics/collector.py
"""
Memory Metrics Collector
Collects and exposes metrics for the memory system.
"""
import asyncio
import logging
from collections import Counter, defaultdict, deque
from dataclasses import dataclass, field
from datetime import UTC, datetime
from enum import Enum
from typing import Any
logger = logging.getLogger(__name__)
class MetricType(str, Enum):
"""Types of metrics."""
COUNTER = "counter"
GAUGE = "gauge"
HISTOGRAM = "histogram"
@dataclass
class MetricValue:
"""A single metric value."""
name: str
metric_type: MetricType
value: float
labels: dict[str, str] = field(default_factory=dict)
timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
@dataclass
class HistogramBucket:
"""Histogram bucket for distribution metrics."""
le: float # Less than or equal
count: int = 0
class MemoryMetrics:
"""
Collects memory system metrics.
Metrics tracked:
- Memory operations (get/set/delete by type and scope)
- Retrieval operations and latencies
- Memory item counts by type
- Consolidation operations and durations
- Cache hit/miss rates
- Procedure success rates
- Embedding operations
"""
# Maximum samples to keep in histogram (circular buffer)
MAX_HISTOGRAM_SAMPLES = 10000
def __init__(self) -> None:
"""Initialize MemoryMetrics."""
self._counters: dict[str, Counter[str]] = defaultdict(Counter)
self._gauges: dict[str, dict[str, float]] = defaultdict(dict)
# Use deque with maxlen for bounded memory (circular buffer)
self._histograms: dict[str, deque[float]] = defaultdict(
lambda: deque(maxlen=self.MAX_HISTOGRAM_SAMPLES)
)
self._histogram_buckets: dict[str, list[HistogramBucket]] = {}
self._lock = asyncio.Lock()
# Initialize histogram buckets
self._init_histogram_buckets()
def _init_histogram_buckets(self) -> None:
"""Initialize histogram buckets for latency metrics."""
# Fast operations (working memory)
fast_buckets = [0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, float("inf")]
# Normal operations (retrieval)
normal_buckets = [0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, float("inf")]
# Slow operations (consolidation)
slow_buckets = [0.1, 0.5, 1.0, 2.5, 5.0, 10.0, 30.0, 60.0, float("inf")]
self._histogram_buckets["memory_working_latency_seconds"] = [
HistogramBucket(le=b) for b in fast_buckets
]
self._histogram_buckets["memory_retrieval_latency_seconds"] = [
HistogramBucket(le=b) for b in normal_buckets
]
self._histogram_buckets["memory_consolidation_duration_seconds"] = [
HistogramBucket(le=b) for b in slow_buckets
]
self._histogram_buckets["memory_embedding_latency_seconds"] = [
HistogramBucket(le=b) for b in normal_buckets
]
# Counter methods - Operations
async def inc_operations(
self,
operation: str,
memory_type: str,
scope: str | None = None,
success: bool = True,
) -> None:
"""Increment memory operation counter."""
async with self._lock:
labels = f"operation={operation},memory_type={memory_type}"
if scope:
labels += f",scope={scope}"
labels += f",success={str(success).lower()}"
self._counters["memory_operations_total"][labels] += 1
async def inc_retrieval(
self,
memory_type: str,
strategy: str,
results_count: int,
) -> None:
"""Increment retrieval counter."""
async with self._lock:
labels = f"memory_type={memory_type},strategy={strategy}"
self._counters["memory_retrievals_total"][labels] += 1
# Track result counts as a separate metric
self._counters["memory_retrieval_results_total"][labels] += results_count
async def inc_cache_hit(self, cache_type: str) -> None:
"""Increment cache hit counter."""
async with self._lock:
labels = f"cache_type={cache_type}"
self._counters["memory_cache_hits_total"][labels] += 1
async def inc_cache_miss(self, cache_type: str) -> None:
"""Increment cache miss counter."""
async with self._lock:
labels = f"cache_type={cache_type}"
self._counters["memory_cache_misses_total"][labels] += 1
async def inc_consolidation(
self,
consolidation_type: str,
success: bool = True,
) -> None:
"""Increment consolidation counter."""
async with self._lock:
labels = f"type={consolidation_type},success={str(success).lower()}"
self._counters["memory_consolidations_total"][labels] += 1
async def inc_procedure_execution(
self,
procedure_id: str | None = None,
success: bool = True,
) -> None:
"""Increment procedure execution counter."""
async with self._lock:
labels = f"success={str(success).lower()}"
self._counters["memory_procedure_executions_total"][labels] += 1
async def inc_embeddings_generated(self, memory_type: str) -> None:
"""Increment embeddings generated counter."""
async with self._lock:
labels = f"memory_type={memory_type}"
self._counters["memory_embeddings_generated_total"][labels] += 1
async def inc_fact_reinforcements(self) -> None:
"""Increment fact reinforcement counter."""
async with self._lock:
self._counters["memory_fact_reinforcements_total"][""] += 1
async def inc_episodes_recorded(self, outcome: str) -> None:
"""Increment episodes recorded counter."""
async with self._lock:
labels = f"outcome={outcome}"
self._counters["memory_episodes_recorded_total"][labels] += 1
async def inc_anomalies_detected(self, anomaly_type: str) -> None:
"""Increment anomaly detection counter."""
async with self._lock:
labels = f"anomaly_type={anomaly_type}"
self._counters["memory_anomalies_detected_total"][labels] += 1
async def inc_patterns_detected(self, pattern_type: str) -> None:
"""Increment pattern detection counter."""
async with self._lock:
labels = f"pattern_type={pattern_type}"
self._counters["memory_patterns_detected_total"][labels] += 1
async def inc_insights_generated(self, insight_type: str) -> None:
"""Increment insight generation counter."""
async with self._lock:
labels = f"insight_type={insight_type}"
self._counters["memory_insights_generated_total"][labels] += 1
# Gauge methods
async def set_memory_items_count(
self,
memory_type: str,
scope: str,
count: int,
) -> None:
"""Set memory item count gauge."""
async with self._lock:
labels = f"memory_type={memory_type},scope={scope}"
self._gauges["memory_items_count"][labels] = float(count)
async def set_memory_size_bytes(
self,
memory_type: str,
scope: str,
size_bytes: int,
) -> None:
"""Set memory size gauge in bytes."""
async with self._lock:
labels = f"memory_type={memory_type},scope={scope}"
self._gauges["memory_size_bytes"][labels] = float(size_bytes)
async def set_cache_size(self, cache_type: str, size: int) -> None:
"""Set cache size gauge."""
async with self._lock:
labels = f"cache_type={cache_type}"
self._gauges["memory_cache_size"][labels] = float(size)
async def set_procedure_success_rate(
self,
procedure_name: str,
rate: float,
) -> None:
"""Set procedure success rate gauge (0-1)."""
async with self._lock:
labels = f"procedure_name={procedure_name}"
self._gauges["memory_procedure_success_rate"][labels] = rate
async def set_active_sessions(self, count: int) -> None:
"""Set active working memory sessions gauge."""
async with self._lock:
self._gauges["memory_active_sessions"][""] = float(count)
async def set_pending_consolidations(self, count: int) -> None:
"""Set pending consolidations gauge."""
async with self._lock:
self._gauges["memory_pending_consolidations"][""] = float(count)
# Histogram methods
async def observe_working_latency(self, latency_seconds: float) -> None:
"""Observe working memory operation latency."""
async with self._lock:
self._observe_histogram("memory_working_latency_seconds", latency_seconds)
async def observe_retrieval_latency(self, latency_seconds: float) -> None:
"""Observe retrieval latency."""
async with self._lock:
self._observe_histogram("memory_retrieval_latency_seconds", latency_seconds)
async def observe_consolidation_duration(self, duration_seconds: float) -> None:
"""Observe consolidation duration."""
async with self._lock:
self._observe_histogram(
"memory_consolidation_duration_seconds", duration_seconds
)
async def observe_embedding_latency(self, latency_seconds: float) -> None:
"""Observe embedding generation latency."""
async with self._lock:
self._observe_histogram("memory_embedding_latency_seconds", latency_seconds)
def _observe_histogram(self, name: str, value: float) -> None:
"""Record a value in a histogram."""
self._histograms[name].append(value)
# Update buckets
if name in self._histogram_buckets:
for bucket in self._histogram_buckets[name]:
if value <= bucket.le:
bucket.count += 1
# Export methods
async def get_all_metrics(self) -> list[MetricValue]:
"""Get all metrics as MetricValue objects."""
metrics: list[MetricValue] = []
async with self._lock:
# Export counters
for name, counter in self._counters.items():
for labels_str, value in counter.items():
labels = self._parse_labels(labels_str)
metrics.append(
MetricValue(
name=name,
metric_type=MetricType.COUNTER,
value=float(value),
labels=labels,
)
)
# Export gauges
for name, gauge_dict in self._gauges.items():
for labels_str, gauge_value in gauge_dict.items():
gauge_labels = self._parse_labels(labels_str)
metrics.append(
MetricValue(
name=name,
metric_type=MetricType.GAUGE,
value=gauge_value,
labels=gauge_labels,
)
)
# Export histogram summaries
for name, values in self._histograms.items():
if values:
metrics.append(
MetricValue(
name=f"{name}_count",
metric_type=MetricType.COUNTER,
value=float(len(values)),
)
)
metrics.append(
MetricValue(
name=f"{name}_sum",
metric_type=MetricType.COUNTER,
value=sum(values),
)
)
return metrics
async def get_prometheus_format(self) -> str:
"""Export metrics in Prometheus text format."""
lines: list[str] = []
async with self._lock:
# Export counters
for name, counter in self._counters.items():
lines.append(f"# TYPE {name} counter")
for labels_str, value in counter.items():
if labels_str:
# Quote label values so each line is valid Prometheus exposition format
rendered = ",".join(
f'{k}="{v}"' for k, v in self._parse_labels(labels_str).items()
)
lines.append(f"{name}{{{rendered}}} {value}")
else:
lines.append(f"{name} {value}")
# Export gauges
for name, gauge_dict in self._gauges.items():
lines.append(f"# TYPE {name} gauge")
for labels_str, gauge_value in gauge_dict.items():
if labels_str:
rendered = ",".join(
f'{k}="{v}"' for k, v in self._parse_labels(labels_str).items()
)
lines.append(f"{name}{{{rendered}}} {gauge_value}")
else:
lines.append(f"{name} {gauge_value}")
# Export histograms
for name, buckets in self._histogram_buckets.items():
lines.append(f"# TYPE {name} histogram")
for bucket in buckets:
le_str = "+Inf" if bucket.le == float("inf") else str(bucket.le)
lines.append(f'{name}_bucket{{le="{le_str}"}} {bucket.count}')
if name in self._histograms:
values = self._histograms[name]
lines.append(f"{name}_count {len(values)}")
lines.append(f"{name}_sum {sum(values)}")
return "\n".join(lines)
async def get_summary(self) -> dict[str, Any]:
"""Get a summary of key metrics."""
async with self._lock:
total_operations = sum(self._counters["memory_operations_total"].values())
successful_operations = sum(
v
for k, v in self._counters["memory_operations_total"].items()
if "success=true" in k
)
total_retrievals = sum(self._counters["memory_retrievals_total"].values())
total_cache_hits = sum(self._counters["memory_cache_hits_total"].values())
total_cache_misses = sum(
self._counters["memory_cache_misses_total"].values()
)
cache_hit_rate = (
total_cache_hits / (total_cache_hits + total_cache_misses)
if (total_cache_hits + total_cache_misses) > 0
else 0.0
)
total_consolidations = sum(
self._counters["memory_consolidations_total"].values()
)
total_episodes = sum(
self._counters["memory_episodes_recorded_total"].values()
)
# Calculate average latencies
retrieval_latencies = list(
self._histograms.get("memory_retrieval_latency_seconds", deque())
)
avg_retrieval_latency = (
sum(retrieval_latencies) / len(retrieval_latencies)
if retrieval_latencies
else 0.0
)
return {
"total_operations": total_operations,
"successful_operations": successful_operations,
"operation_success_rate": (
successful_operations / total_operations
if total_operations > 0
else 1.0
),
"total_retrievals": total_retrievals,
"cache_hit_rate": cache_hit_rate,
"total_consolidations": total_consolidations,
"total_episodes_recorded": total_episodes,
"avg_retrieval_latency_ms": avg_retrieval_latency * 1000,
"patterns_detected": sum(
self._counters["memory_patterns_detected_total"].values()
),
"insights_generated": sum(
self._counters["memory_insights_generated_total"].values()
),
"anomalies_detected": sum(
self._counters["memory_anomalies_detected_total"].values()
),
"active_sessions": self._gauges.get("memory_active_sessions", {}).get(
"", 0
),
"pending_consolidations": self._gauges.get(
"memory_pending_consolidations", {}
).get("", 0),
}
async def get_cache_stats(self) -> dict[str, Any]:
"""Get detailed cache statistics."""
async with self._lock:
stats: dict[str, Any] = {}
# Get hits/misses by cache type
for labels_str, hits in self._counters["memory_cache_hits_total"].items():
cache_type = self._parse_labels(labels_str).get("cache_type", "unknown")
if cache_type not in stats:
stats[cache_type] = {"hits": 0, "misses": 0}
stats[cache_type]["hits"] = hits
for labels_str, misses in self._counters[
"memory_cache_misses_total"
].items():
cache_type = self._parse_labels(labels_str).get("cache_type", "unknown")
if cache_type not in stats:
stats[cache_type] = {"hits": 0, "misses": 0}
stats[cache_type]["misses"] = misses
# Calculate hit rates
for data in stats.values():
total = data["hits"] + data["misses"]
data["hit_rate"] = data["hits"] / total if total > 0 else 0.0
data["total"] = total
return stats
async def reset(self) -> None:
"""Reset all metrics."""
async with self._lock:
self._counters.clear()
self._gauges.clear()
self._histograms.clear()
self._init_histogram_buckets()
def _parse_labels(self, labels_str: str) -> dict[str, str]:
"""Parse labels string into dictionary."""
if not labels_str:
return {}
labels = {}
for pair in labels_str.split(","):
if "=" in pair:
key, value = pair.split("=", 1)
labels[key.strip()] = value.strip()
return labels
# Singleton instance
_metrics: MemoryMetrics | None = None
_lock = asyncio.Lock()
async def get_memory_metrics() -> MemoryMetrics:
"""Get the singleton MemoryMetrics instance."""
global _metrics
async with _lock:
if _metrics is None:
_metrics = MemoryMetrics()
return _metrics
async def reset_memory_metrics() -> None:
"""Reset the singleton instance (for testing)."""
global _metrics
async with _lock:
_metrics = None
# Convenience functions
async def record_memory_operation(
operation: str,
memory_type: str,
scope: str | None = None,
success: bool = True,
latency_ms: float | None = None,
) -> None:
"""Record a memory operation."""
metrics = await get_memory_metrics()
await metrics.inc_operations(operation, memory_type, scope, success)
if latency_ms is not None and memory_type == "working":
await metrics.observe_working_latency(latency_ms / 1000)
async def record_retrieval(
memory_type: str,
strategy: str,
results_count: int,
latency_ms: float,
) -> None:
"""Record a retrieval operation."""
metrics = await get_memory_metrics()
await metrics.inc_retrieval(memory_type, strategy, results_count)
await metrics.observe_retrieval_latency(latency_ms / 1000)
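Taken together, the counters, gauges, and histograms above are driven through the singleton accessor and the convenience wrappers. A minimal usage sketch, assuming the module is importable as app.services.memory.metrics (the exact path is not shown in this diff):

import asyncio

from app.services.memory.metrics import (  # assumed module path
    get_memory_metrics,
    record_memory_operation,
    record_retrieval,
)

async def main() -> None:
    # Record one working-memory write and one retrieval, then export.
    await record_memory_operation("store", "working", scope="session", latency_ms=12.5)
    await record_retrieval("semantic", strategy="hybrid", results_count=3, latency_ms=40.0)
    metrics = await get_memory_metrics()
    print(await metrics.get_prometheus_format())
    print(await metrics.get_summary())

asyncio.run(main())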

View File

@@ -0,0 +1,22 @@
# app/services/memory/procedural/__init__.py
"""
Procedural Memory
Learned skills and procedures from successful task patterns.
"""
from .matching import (
MatchContext,
MatchResult,
ProcedureMatcher,
get_procedure_matcher,
)
from .memory import ProceduralMemory
__all__ = [
"MatchContext",
"MatchResult",
"ProceduralMemory",
"ProcedureMatcher",
"get_procedure_matcher",
]

View File

@@ -0,0 +1,291 @@
# app/services/memory/procedural/matching.py
"""
Procedure Matching.
Provides utilities for matching procedures to contexts,
ranking procedures by relevance, and suggesting procedures.
"""
import logging
import re
from dataclasses import dataclass, field
from typing import Any, ClassVar
from app.services.memory.types import Procedure
logger = logging.getLogger(__name__)
@dataclass
class MatchResult:
"""Result of a procedure match."""
procedure: Procedure
score: float
matched_terms: list[str] = field(default_factory=list)
match_type: str = "keyword" # keyword, semantic, pattern
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"procedure_id": str(self.procedure.id),
"procedure_name": self.procedure.name,
"score": self.score,
"matched_terms": self.matched_terms,
"match_type": self.match_type,
"success_rate": self.procedure.success_rate,
}
@dataclass
class MatchContext:
"""Context for procedure matching."""
query: str
task_type: str | None = None
project_id: Any | None = None
agent_type_id: Any | None = None
max_results: int = 5
min_score: float = 0.3
require_success_rate: float | None = None
class ProcedureMatcher:
"""
Matches procedures to contexts using multiple strategies.
Matching strategies:
- Keyword matching on trigger pattern and name
- Pattern-based matching using regex
- Success rate weighting
In production, this would be augmented with vector similarity search.
"""
# Common task-related keywords for boosting
TASK_KEYWORDS: ClassVar[set[str]] = {
"create",
"update",
"delete",
"fix",
"implement",
"add",
"remove",
"refactor",
"test",
"deploy",
"configure",
"setup",
"build",
"debug",
"optimize",
}
def __init__(self) -> None:
"""Initialize the matcher."""
self._compiled_patterns: dict[str, re.Pattern[str]] = {}
def match(
self,
procedures: list[Procedure],
context: MatchContext,
) -> list[MatchResult]:
"""
Match procedures against a context.
Args:
procedures: List of procedures to match
context: Matching context
Returns:
List of match results, sorted by score (highest first)
"""
results: list[MatchResult] = []
query_terms = self._extract_terms(context.query)
query_lower = context.query.lower()
for procedure in procedures:
score, matched = self._calculate_match_score(
procedure=procedure,
query_terms=query_terms,
query_lower=query_lower,
context=context,
)
if score >= context.min_score:
# Apply success rate boost
if context.require_success_rate is not None:
if procedure.success_rate < context.require_success_rate:
continue
# Boost score based on success rate
success_boost = procedure.success_rate * 0.2
final_score = min(1.0, score + success_boost)
results.append(
MatchResult(
procedure=procedure,
score=final_score,
matched_terms=matched,
match_type="keyword",
)
)
# Sort by score descending
results.sort(key=lambda r: r.score, reverse=True)
return results[: context.max_results]
def _extract_terms(self, text: str) -> list[str]:
"""Extract searchable terms from text."""
# Remove special characters and split
clean = re.sub(r"[^\w\s-]", " ", text.lower())
terms = clean.split()
# Filter out very short terms
return [t for t in terms if len(t) >= 2]
def _calculate_match_score(
self,
procedure: Procedure,
query_terms: list[str],
query_lower: str,
context: MatchContext,
) -> tuple[float, list[str]]:
"""
Calculate match score between procedure and query.
Returns:
Tuple of (score, matched_terms)
"""
score = 0.0
matched: list[str] = []
trigger_lower = procedure.trigger_pattern.lower()
name_lower = procedure.name.lower()
# Exact name match - high score
if name_lower in query_lower or query_lower in name_lower:
score += 0.5
matched.append(f"name:{procedure.name}")
# Trigger pattern match
if trigger_lower in query_lower or query_lower in trigger_lower:
score += 0.4
matched.append(f"trigger:{procedure.trigger_pattern[:30]}")
# Term-by-term matching
for term in query_terms:
if term in trigger_lower:
score += 0.1
matched.append(term)
elif term in name_lower:
score += 0.08
matched.append(term)
# Boost for task keywords
if term in self.TASK_KEYWORDS:
if term in trigger_lower or term in name_lower:
score += 0.05
# Task type match if provided
if context.task_type:
task_type_lower = context.task_type.lower()
if task_type_lower in trigger_lower or task_type_lower in name_lower:
score += 0.3
matched.append(f"task_type:{context.task_type}")
# Regex pattern matching on trigger
try:
pattern = self._get_or_compile_pattern(trigger_lower)
if pattern and pattern.search(query_lower):
score += 0.25
matched.append("pattern_match")
except re.error:
pass # Invalid regex, skip pattern matching
return min(1.0, score), matched
def _get_or_compile_pattern(self, pattern: str) -> re.Pattern[str] | None:
"""Get or compile a regex pattern with caching."""
if pattern in self._compiled_patterns:
return self._compiled_patterns[pattern]
# Only compile if it looks like a regex pattern
if not any(c in pattern for c in r"\.*+?[]{}|()^$"):
return None
try:
compiled = re.compile(pattern, re.IGNORECASE)
self._compiled_patterns[pattern] = compiled
return compiled
except re.error:
return None
def rank_by_relevance(
self,
procedures: list[Procedure],
task_type: str,
) -> list[Procedure]:
"""
Rank procedures by relevance to a task type.
Args:
procedures: Procedures to rank
task_type: Task type for relevance
Returns:
Procedures sorted by relevance
"""
context = MatchContext(
query=task_type,
task_type=task_type,
min_score=0.0,
max_results=len(procedures),
)
results = self.match(procedures, context)
return [r.procedure for r in results]
def suggest_procedures(
self,
procedures: list[Procedure],
query: str,
min_success_rate: float = 0.5,
max_suggestions: int = 3,
) -> list[MatchResult]:
"""
Suggest the best procedures for a query.
Only suggests procedures with sufficient success rate.
Args:
procedures: Available procedures
query: Query/context
min_success_rate: Minimum success rate to suggest
max_suggestions: Maximum suggestions
Returns:
List of procedure suggestions
"""
context = MatchContext(
query=query,
max_results=max_suggestions,
min_score=0.2,
require_success_rate=min_success_rate,
)
return self.match(procedures, context)
# Singleton matcher instance
_matcher: ProcedureMatcher | None = None
def get_procedure_matcher() -> ProcedureMatcher:
"""Get the singleton procedure matcher instance."""
global _matcher
if _matcher is None:
_matcher = ProcedureMatcher()
return _matcher
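The matcher is purely in-memory, so it can be exercised against any list of Procedure objects already loaded elsewhere (for example via ProceduralMemory.find_matching below). A small sketch; the query and task_type values are illustrative:

from app.services.memory.procedural import MatchContext, get_procedure_matcher
from app.services.memory.types import Procedure

def pick_best(procedures: list[Procedure], query: str) -> dict | None:
    """Return the top suggestion for a query, or None if nothing scores high enough."""
    matcher = get_procedure_matcher()
    context = MatchContext(query=query, task_type="deploy", min_score=0.3)
    results = matcher.match(procedures, context)
    # Results arrive sorted by score; each carries the matched terms and success rate.
    return results[0].to_dict() if results else None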

View File

@@ -0,0 +1,749 @@
# app/services/memory/procedural/memory.py
"""
Procedural Memory Implementation.
Provides storage and retrieval for learned procedures (skills)
derived from successful task execution patterns.
"""
import logging
import time
from datetime import UTC, datetime
from typing import Any
from uuid import UUID
from sqlalchemy import and_, desc, or_, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.memory.procedure import Procedure as ProcedureModel
from app.services.memory.config import get_memory_settings
from app.services.memory.types import Procedure, ProcedureCreate, RetrievalResult, Step
logger = logging.getLogger(__name__)
def _escape_like_pattern(pattern: str) -> str:
"""
Escape SQL LIKE/ILIKE special characters to prevent pattern injection.
Characters escaped:
- % (matches zero or more characters)
- _ (matches exactly one character)
- \\ (escape character itself)
Args:
pattern: Raw search pattern from user input
Returns:
Escaped pattern safe for use in LIKE/ILIKE queries
"""
# Escape backslash first, then the wildcards
return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
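# Illustrative example: _escape_like_pattern("fix 100%_coverage") returns
# "fix 100\\%\\_coverage", so the % and _ are matched literally by the ILIKE
# queries below instead of acting as SQL wildcards.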
def _model_to_procedure(model: ProcedureModel) -> Procedure:
"""Convert SQLAlchemy model to Procedure dataclass."""
return Procedure(
id=model.id, # type: ignore[arg-type]
project_id=model.project_id, # type: ignore[arg-type]
agent_type_id=model.agent_type_id, # type: ignore[arg-type]
name=model.name, # type: ignore[arg-type]
trigger_pattern=model.trigger_pattern, # type: ignore[arg-type]
steps=model.steps or [], # type: ignore[arg-type]
success_count=model.success_count, # type: ignore[arg-type]
failure_count=model.failure_count, # type: ignore[arg-type]
last_used=model.last_used, # type: ignore[arg-type]
embedding=None, # Don't expose raw embedding
created_at=model.created_at, # type: ignore[arg-type]
updated_at=model.updated_at, # type: ignore[arg-type]
)
class ProceduralMemory:
"""
Procedural Memory Service.
Provides procedure storage and retrieval:
- Record procedures from successful task patterns
- Find matching procedures by trigger pattern
- Track success/failure rates
- Get best procedure for a task type
- Update procedure steps
Performance target: <50ms P95 for matching
"""
def __init__(
self,
session: AsyncSession,
embedding_generator: Any | None = None,
) -> None:
"""
Initialize procedural memory.
Args:
session: Database session
embedding_generator: Optional embedding generator for semantic matching
"""
self._session = session
self._embedding_generator = embedding_generator
self._settings = get_memory_settings()
@classmethod
async def create(
cls,
session: AsyncSession,
embedding_generator: Any | None = None,
) -> "ProceduralMemory":
"""
Factory method to create ProceduralMemory.
Args:
session: Database session
embedding_generator: Optional embedding generator
Returns:
Configured ProceduralMemory instance
"""
return cls(session=session, embedding_generator=embedding_generator)
# =========================================================================
# Procedure Recording
# =========================================================================
async def record_procedure(self, procedure: ProcedureCreate) -> Procedure:
"""
Record a new procedure or update an existing one.
If a procedure with the same name exists in the same scope,
its steps will be updated and success count incremented.
Args:
procedure: Procedure data to record
Returns:
The created or updated procedure
"""
# Check for existing procedure with same name
existing = await self._find_existing_procedure(
project_id=procedure.project_id,
agent_type_id=procedure.agent_type_id,
name=procedure.name,
)
if existing is not None:
# Update existing procedure
return await self._update_existing_procedure(
existing=existing,
new_steps=procedure.steps,
new_trigger=procedure.trigger_pattern,
)
# Create new procedure
now = datetime.now(UTC)
# Generate embedding if possible
embedding = None
if self._embedding_generator is not None:
embedding_text = self._create_embedding_text(procedure)
embedding = await self._embedding_generator.generate(embedding_text)
model = ProcedureModel(
project_id=procedure.project_id,
agent_type_id=procedure.agent_type_id,
name=procedure.name,
trigger_pattern=procedure.trigger_pattern,
steps=procedure.steps,
success_count=1, # New procedures start with 1 success (they worked)
failure_count=0,
last_used=now,
embedding=embedding,
)
self._session.add(model)
await self._session.flush()
await self._session.refresh(model)
logger.info(
f"Recorded new procedure: {procedure.name} with {len(procedure.steps)} steps"
)
return _model_to_procedure(model)
async def _find_existing_procedure(
self,
project_id: UUID | None,
agent_type_id: UUID | None,
name: str,
) -> ProcedureModel | None:
"""Find an existing procedure with the same name in the same scope."""
query = select(ProcedureModel).where(ProcedureModel.name == name)
if project_id is not None:
query = query.where(ProcedureModel.project_id == project_id)
else:
query = query.where(ProcedureModel.project_id.is_(None))
if agent_type_id is not None:
query = query.where(ProcedureModel.agent_type_id == agent_type_id)
else:
query = query.where(ProcedureModel.agent_type_id.is_(None))
result = await self._session.execute(query)
return result.scalar_one_or_none()
async def _update_existing_procedure(
self,
existing: ProcedureModel,
new_steps: list[dict[str, Any]],
new_trigger: str,
) -> Procedure:
"""Update an existing procedure with new steps."""
now = datetime.now(UTC)
# Merge steps intelligently - keep existing order, add new steps
merged_steps = self._merge_steps(
existing.steps or [], # type: ignore[arg-type]
new_steps,
)
stmt = (
update(ProcedureModel)
.where(ProcedureModel.id == existing.id)
.values(
steps=merged_steps,
trigger_pattern=new_trigger,
success_count=ProcedureModel.success_count + 1,
last_used=now,
updated_at=now,
)
.returning(ProcedureModel)
)
result = await self._session.execute(stmt)
updated_model = result.scalar_one()
await self._session.flush()
logger.info(f"Updated existing procedure: {existing.name}")
return _model_to_procedure(updated_model)
def _merge_steps(
self,
existing_steps: list[dict[str, Any]],
new_steps: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Merge steps from a new execution with existing steps."""
if not existing_steps:
return new_steps
if not new_steps:
return existing_steps
# For now, use the new steps if they differ significantly
# In production, this could use more sophisticated merging
if len(new_steps) != len(existing_steps):
# If structure changed, prefer newer steps
return new_steps
# Merge step-by-step, preferring new data where available
merged = []
for i, new_step in enumerate(new_steps):
if i < len(existing_steps):
# Merge with existing step
step = {**existing_steps[i], **new_step}
else:
step = new_step
merged.append(step)
return merged
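# Illustrative merge (hypothetical step dicts):
#   existing = [{"order": 1, "action": "run tests", "parameters": {"timeout": 60}}]
#   new      = [{"order": 1, "action": "run tests", "parameters": {"timeout": 120}}]
# Equal lengths, so steps are merged index-by-index and newer values win
# (parameters becomes {"timeout": 120}); lists of different lengths are
# replaced wholesale by the newer steps.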
def _create_embedding_text(self, procedure: ProcedureCreate) -> str:
"""Create text for embedding from procedure data."""
steps_text = " ".join(step.get("action", "") for step in procedure.steps)
return f"{procedure.name} {procedure.trigger_pattern} {steps_text}"
# =========================================================================
# Procedure Retrieval
# =========================================================================
async def find_matching(
self,
context: str,
project_id: UUID | None = None,
agent_type_id: UUID | None = None,
limit: int = 5,
) -> list[Procedure]:
"""
Find procedures matching the given context.
Args:
context: Context/trigger to match against
project_id: Optional project to search within
agent_type_id: Optional agent type filter
limit: Maximum results
Returns:
List of matching procedures
"""
result = await self._find_matching_with_metadata(
context=context,
project_id=project_id,
agent_type_id=agent_type_id,
limit=limit,
)
return result.items
async def _find_matching_with_metadata(
self,
context: str,
project_id: UUID | None = None,
agent_type_id: UUID | None = None,
limit: int = 5,
) -> RetrievalResult[Procedure]:
"""Find matching procedures with full result metadata."""
start_time = time.perf_counter()
# Build base query - prioritize by success rate
stmt = (
select(ProcedureModel)
.order_by(
desc(
ProcedureModel.success_count
/ (ProcedureModel.success_count + ProcedureModel.failure_count + 1)
),
desc(ProcedureModel.last_used),
)
.limit(limit)
)
# Apply scope filters
if project_id is not None:
stmt = stmt.where(
or_(
ProcedureModel.project_id == project_id,
ProcedureModel.project_id.is_(None),
)
)
if agent_type_id is not None:
stmt = stmt.where(
or_(
ProcedureModel.agent_type_id == agent_type_id,
ProcedureModel.agent_type_id.is_(None),
)
)
# Text-based matching on trigger pattern and name
# TODO: Implement proper vector similarity search when pgvector is integrated
search_terms = context.lower().split()[:5] # Limit to 5 terms
if search_terms:
conditions = []
for term in search_terms:
# Escape SQL wildcards to prevent pattern injection
escaped_term = _escape_like_pattern(term)
term_pattern = f"%{escaped_term}%"
conditions.append(
or_(
ProcedureModel.trigger_pattern.ilike(term_pattern),
ProcedureModel.name.ilike(term_pattern),
)
)
if conditions:
stmt = stmt.where(or_(*conditions))
result = await self._session.execute(stmt)
models = list(result.scalars().all())
latency_ms = (time.perf_counter() - start_time) * 1000
return RetrievalResult(
items=[_model_to_procedure(m) for m in models],
total_count=len(models),
query=context,
retrieval_type="procedural",
latency_ms=latency_ms,
metadata={"project_id": str(project_id) if project_id else None},
)
async def get_best_procedure(
self,
task_type: str,
project_id: UUID | None = None,
agent_type_id: UUID | None = None,
min_success_rate: float = 0.5,
min_uses: int = 1,
) -> Procedure | None:
"""
Get the best procedure for a given task type.
Returns the procedure with the highest success rate that
meets the minimum thresholds.
Args:
task_type: Task type to find procedure for
project_id: Optional project scope
agent_type_id: Optional agent type scope
min_success_rate: Minimum required success rate
min_uses: Minimum number of uses required
Returns:
Best matching procedure or None
"""
# Escape SQL wildcards to prevent pattern injection
escaped_task_type = _escape_like_pattern(task_type)
task_type_pattern = f"%{escaped_task_type}%"
# Build query for procedures matching task type
stmt = (
select(ProcedureModel)
.where(
and_(
(ProcedureModel.success_count + ProcedureModel.failure_count)
>= min_uses,
or_(
ProcedureModel.trigger_pattern.ilike(task_type_pattern),
ProcedureModel.name.ilike(task_type_pattern),
),
)
)
.order_by(
desc(
ProcedureModel.success_count
/ (ProcedureModel.success_count + ProcedureModel.failure_count + 1)
),
desc(ProcedureModel.last_used),
)
.limit(10)
)
# Apply scope filters
if project_id is not None:
stmt = stmt.where(
or_(
ProcedureModel.project_id == project_id,
ProcedureModel.project_id.is_(None),
)
)
if agent_type_id is not None:
stmt = stmt.where(
or_(
ProcedureModel.agent_type_id == agent_type_id,
ProcedureModel.agent_type_id.is_(None),
)
)
result = await self._session.execute(stmt)
models = list(result.scalars().all())
# Filter by success rate in Python (SQLAlchemy division in WHERE is complex)
for model in models:
success = float(model.success_count)
failure = float(model.failure_count)
total = success + failure
if total > 0 and (success / total) >= min_success_rate:
logger.debug(
f"Found best procedure for '{task_type}': {model.name} "
f"(success_rate={success / total:.2%})"
)
return _model_to_procedure(model)
return None
async def get_by_id(self, procedure_id: UUID) -> Procedure | None:
"""Get a procedure by ID."""
query = select(ProcedureModel).where(ProcedureModel.id == procedure_id)
result = await self._session.execute(query)
model = result.scalar_one_or_none()
return _model_to_procedure(model) if model else None
# =========================================================================
# Outcome Recording
# =========================================================================
async def record_outcome(
self,
procedure_id: UUID,
success: bool,
) -> Procedure:
"""
Record the outcome of using a procedure.
Updates the success or failure count and last_used timestamp.
Args:
procedure_id: Procedure that was used
success: Whether the procedure succeeded
Returns:
Updated procedure
Raises:
ValueError: If procedure not found
"""
query = select(ProcedureModel).where(ProcedureModel.id == procedure_id)
result = await self._session.execute(query)
model = result.scalar_one_or_none()
if model is None:
raise ValueError(f"Procedure not found: {procedure_id}")
now = datetime.now(UTC)
if success:
stmt = (
update(ProcedureModel)
.where(ProcedureModel.id == procedure_id)
.values(
success_count=ProcedureModel.success_count + 1,
last_used=now,
updated_at=now,
)
.returning(ProcedureModel)
)
else:
stmt = (
update(ProcedureModel)
.where(ProcedureModel.id == procedure_id)
.values(
failure_count=ProcedureModel.failure_count + 1,
last_used=now,
updated_at=now,
)
.returning(ProcedureModel)
)
result = await self._session.execute(stmt)
updated_model = result.scalar_one()
await self._session.flush()
outcome = "success" if success else "failure"
logger.info(
f"Recorded {outcome} for procedure {procedure_id}: "
f"success_rate={updated_model.success_rate:.2%}"
)
return _model_to_procedure(updated_model)
# =========================================================================
# Step Management
# =========================================================================
async def update_steps(
self,
procedure_id: UUID,
steps: list[Step],
) -> Procedure:
"""
Update the steps of a procedure.
Args:
procedure_id: Procedure to update
steps: New steps
Returns:
Updated procedure
Raises:
ValueError: If procedure not found
"""
query = select(ProcedureModel).where(ProcedureModel.id == procedure_id)
result = await self._session.execute(query)
model = result.scalar_one_or_none()
if model is None:
raise ValueError(f"Procedure not found: {procedure_id}")
# Convert Step objects to dictionaries
steps_dict = [
{
"order": step.order,
"action": step.action,
"parameters": step.parameters,
"expected_outcome": step.expected_outcome,
"fallback_action": step.fallback_action,
}
for step in steps
]
now = datetime.now(UTC)
stmt = (
update(ProcedureModel)
.where(ProcedureModel.id == procedure_id)
.values(
steps=steps_dict,
updated_at=now,
)
.returning(ProcedureModel)
)
result = await self._session.execute(stmt)
updated_model = result.scalar_one()
await self._session.flush()
logger.info(f"Updated steps for procedure {procedure_id}: {len(steps)} steps")
return _model_to_procedure(updated_model)
# =========================================================================
# Statistics & Management
# =========================================================================
async def get_stats(
self,
project_id: UUID | None = None,
agent_type_id: UUID | None = None,
) -> dict[str, Any]:
"""
Get statistics about procedural memory.
Args:
project_id: Optional project to get stats for
agent_type_id: Optional agent type filter
Returns:
Dictionary with statistics
"""
query = select(ProcedureModel)
if project_id is not None:
query = query.where(
or_(
ProcedureModel.project_id == project_id,
ProcedureModel.project_id.is_(None),
)
)
if agent_type_id is not None:
query = query.where(
or_(
ProcedureModel.agent_type_id == agent_type_id,
ProcedureModel.agent_type_id.is_(None),
)
)
result = await self._session.execute(query)
models = list(result.scalars().all())
if not models:
return {
"total_procedures": 0,
"avg_success_rate": 0.0,
"avg_steps_count": 0.0,
"total_uses": 0,
"high_success_count": 0,
"low_success_count": 0,
}
success_rates = [m.success_rate for m in models]
step_counts = [len(m.steps or []) for m in models]
total_uses = sum(m.total_uses for m in models)
return {
"total_procedures": len(models),
"avg_success_rate": sum(success_rates) / len(success_rates),
"avg_steps_count": sum(step_counts) / len(step_counts),
"total_uses": total_uses,
"high_success_count": sum(1 for r in success_rates if r >= 0.8),
"low_success_count": sum(1 for r in success_rates if r < 0.5),
}
async def count(
self,
project_id: UUID | None = None,
agent_type_id: UUID | None = None,
) -> int:
"""
Count procedures in scope.
Args:
project_id: Optional project to count for
agent_type_id: Optional agent type filter
Returns:
Number of procedures
"""
query = select(ProcedureModel)
if project_id is not None:
query = query.where(
or_(
ProcedureModel.project_id == project_id,
ProcedureModel.project_id.is_(None),
)
)
if agent_type_id is not None:
query = query.where(
or_(
ProcedureModel.agent_type_id == agent_type_id,
ProcedureModel.agent_type_id.is_(None),
)
)
result = await self._session.execute(query)
return len(list(result.scalars().all()))
async def delete(self, procedure_id: UUID) -> bool:
"""
Delete a procedure.
Args:
procedure_id: Procedure to delete
Returns:
True if deleted, False if not found
"""
query = select(ProcedureModel).where(ProcedureModel.id == procedure_id)
result = await self._session.execute(query)
model = result.scalar_one_or_none()
if model is None:
return False
await self._session.delete(model)
await self._session.flush()
logger.info(f"Deleted procedure {procedure_id}")
return True
async def get_procedures_by_success_rate(
self,
min_rate: float = 0.0,
max_rate: float = 1.0,
project_id: UUID | None = None,
limit: int = 20,
) -> list[Procedure]:
"""
Get procedures within a success rate range.
Args:
min_rate: Minimum success rate
max_rate: Maximum success rate
project_id: Optional project scope
limit: Maximum results
Returns:
List of procedures
"""
query = (
select(ProcedureModel)
.order_by(desc(ProcedureModel.last_used))
.limit(limit * 2) # Fetch more since we filter in Python
)
if project_id is not None:
query = query.where(
or_(
ProcedureModel.project_id == project_id,
ProcedureModel.project_id.is_(None),
)
)
result = await self._session.execute(query)
models = list(result.scalars().all())
# Filter by success rate in Python
filtered = [m for m in models if min_rate <= m.success_rate <= max_rate][:limit]
return [_model_to_procedure(m) for m in filtered]
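A rough end-to-end sketch of driving ProceduralMemory from an async SQLAlchemy session. The session_factory and the ProcedureCreate keyword arguments are assumptions; only the method calls themselves come from the code above:

from uuid import UUID

from app.services.memory.procedural import ProceduralMemory
from app.services.memory.types import ProcedureCreate

async def learn_and_reuse(session_factory, project_id: UUID) -> None:
    async with session_factory() as session:  # assumed async_sessionmaker
        memory = await ProceduralMemory.create(session)
        # Record a successful pattern as a reusable procedure.
        await memory.record_procedure(
            ProcedureCreate(
                project_id=project_id,
                agent_type_id=None,
                name="deploy service",
                trigger_pattern="deploy the api service to staging",
                steps=[{"order": 1, "action": "run tests", "parameters": {}}],
            )
        )
        # Later: look up candidates for a similar context and report the outcome.
        matches = await memory.find_matching("deploy to staging", project_id=project_id)
        if matches:
            await memory.record_outcome(matches[0].id, success=True)
        await session.commit()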

View File

@@ -0,0 +1,38 @@
# app/services/memory/reflection/__init__.py
"""
Memory Reflection Layer.
Analyzes patterns in agent experiences to generate actionable insights.
"""
from .service import (
MemoryReflection,
ReflectionConfig,
get_memory_reflection,
)
from .types import (
Anomaly,
AnomalyType,
Factor,
FactorType,
Insight,
InsightType,
Pattern,
PatternType,
TimeRange,
)
__all__ = [
"Anomaly",
"AnomalyType",
"Factor",
"FactorType",
"Insight",
"InsightType",
"MemoryReflection",
"Pattern",
"PatternType",
"ReflectionConfig",
"TimeRange",
"get_memory_reflection",
]

File diff suppressed because it is too large.

View File

@@ -0,0 +1,304 @@
# app/services/memory/reflection/types.py
"""
Memory Reflection Types.
Type definitions for pattern detection, anomaly detection, and insights.
"""
from dataclasses import dataclass, field
from datetime import UTC, datetime
from enum import Enum
from typing import Any
from uuid import UUID
def _utcnow() -> datetime:
"""Get current UTC time as timezone-aware datetime."""
return datetime.now(UTC)
class PatternType(str, Enum):
"""Types of patterns detected in episodic memory."""
RECURRING_SUCCESS = "recurring_success"
RECURRING_FAILURE = "recurring_failure"
ACTION_SEQUENCE = "action_sequence"
CONTEXT_CORRELATION = "context_correlation"
TEMPORAL = "temporal"
EFFICIENCY = "efficiency"
class FactorType(str, Enum):
"""Types of factors contributing to outcomes."""
ACTION = "action"
CONTEXT = "context"
TIMING = "timing"
RESOURCE = "resource"
PRECEDING_STATE = "preceding_state"
class AnomalyType(str, Enum):
"""Types of anomalies detected."""
UNUSUAL_DURATION = "unusual_duration"
UNEXPECTED_OUTCOME = "unexpected_outcome"
UNUSUAL_TOKEN_USAGE = "unusual_token_usage"
UNUSUAL_FAILURE_RATE = "unusual_failure_rate"
UNUSUAL_ACTION_PATTERN = "unusual_action_pattern"
class InsightType(str, Enum):
"""Types of insights generated."""
OPTIMIZATION = "optimization"
WARNING = "warning"
LEARNING = "learning"
RECOMMENDATION = "recommendation"
TREND = "trend"
@dataclass
class TimeRange:
"""Time range for reflection analysis."""
start: datetime
end: datetime
@classmethod
def last_hours(cls, hours: int = 24) -> "TimeRange":
"""Create time range for last N hours."""
from datetime import timedelta
end = _utcnow()
start = end - timedelta(hours=hours)
return cls(start=start, end=end)
@classmethod
def last_days(cls, days: int = 7) -> "TimeRange":
"""Create time range for last N days."""
from datetime import timedelta
end = _utcnow()
start = end - timedelta(days=days)
return cls(start=start, end=end)
@property
def duration_hours(self) -> float:
"""Get duration in hours."""
return (self.end - self.start).total_seconds() / 3600
@property
def duration_days(self) -> float:
"""Get duration in days."""
return (self.end - self.start).total_seconds() / 86400
@dataclass
class Pattern:
"""A detected pattern in episodic memory."""
id: UUID
pattern_type: PatternType
name: str
description: str
confidence: float
occurrence_count: int
episode_ids: list[UUID]
first_seen: datetime
last_seen: datetime
metadata: dict[str, Any] = field(default_factory=dict)
@property
def frequency(self) -> float:
"""Calculate pattern frequency per day."""
duration_days = (self.last_seen - self.first_seen).total_seconds() / 86400
if duration_days < 1:
duration_days = 1
return self.occurrence_count / duration_days
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"id": str(self.id),
"pattern_type": self.pattern_type.value,
"name": self.name,
"description": self.description,
"confidence": self.confidence,
"occurrence_count": self.occurrence_count,
"episode_ids": [str(eid) for eid in self.episode_ids],
"first_seen": self.first_seen.isoformat(),
"last_seen": self.last_seen.isoformat(),
"frequency": self.frequency,
"metadata": self.metadata,
}
@dataclass
class Factor:
"""A factor contributing to success or failure."""
id: UUID
factor_type: FactorType
name: str
description: str
impact_score: float
correlation: float
sample_size: int
positive_examples: list[UUID]
negative_examples: list[UUID]
metadata: dict[str, Any] = field(default_factory=dict)
@property
def net_impact(self) -> float:
"""Calculate net impact considering sample size."""
# Weight impact by sample confidence
confidence_weight = min(1.0, self.sample_size / 20)
return self.impact_score * self.correlation * confidence_weight
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"id": str(self.id),
"factor_type": self.factor_type.value,
"name": self.name,
"description": self.description,
"impact_score": self.impact_score,
"correlation": self.correlation,
"sample_size": self.sample_size,
"positive_examples": [str(eid) for eid in self.positive_examples],
"negative_examples": [str(eid) for eid in self.negative_examples],
"net_impact": self.net_impact,
"metadata": self.metadata,
}
@dataclass
class Anomaly:
"""An anomaly detected in memory patterns."""
id: UUID
anomaly_type: AnomalyType
description: str
severity: float
episode_ids: list[UUID]
detected_at: datetime
baseline_value: float
observed_value: float
deviation_factor: float
metadata: dict[str, Any] = field(default_factory=dict)
@property
def is_critical(self) -> bool:
"""Check if anomaly is critical (severity > 0.8)."""
return self.severity > 0.8
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"id": str(self.id),
"anomaly_type": self.anomaly_type.value,
"description": self.description,
"severity": self.severity,
"episode_ids": [str(eid) for eid in self.episode_ids],
"detected_at": self.detected_at.isoformat(),
"baseline_value": self.baseline_value,
"observed_value": self.observed_value,
"deviation_factor": self.deviation_factor,
"is_critical": self.is_critical,
"metadata": self.metadata,
}
@dataclass
class Insight:
"""An actionable insight generated from reflection."""
id: UUID
insight_type: InsightType
title: str
description: str
priority: float
confidence: float
source_patterns: list[UUID]
source_factors: list[UUID]
source_anomalies: list[UUID]
recommended_actions: list[str]
generated_at: datetime
metadata: dict[str, Any] = field(default_factory=dict)
@property
def actionable_score(self) -> float:
"""Calculate how actionable this insight is."""
action_weight = min(1.0, len(self.recommended_actions) / 3)
return self.priority * self.confidence * action_weight
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"id": str(self.id),
"insight_type": self.insight_type.value,
"title": self.title,
"description": self.description,
"priority": self.priority,
"confidence": self.confidence,
"source_patterns": [str(pid) for pid in self.source_patterns],
"source_factors": [str(fid) for fid in self.source_factors],
"source_anomalies": [str(aid) for aid in self.source_anomalies],
"recommended_actions": self.recommended_actions,
"generated_at": self.generated_at.isoformat(),
"actionable_score": self.actionable_score,
"metadata": self.metadata,
}
@dataclass
class ReflectionResult:
"""Result of a reflection operation."""
patterns: list[Pattern]
factors: list[Factor]
anomalies: list[Anomaly]
insights: list[Insight]
time_range: TimeRange
episodes_analyzed: int
analysis_duration_seconds: float
generated_at: datetime = field(default_factory=_utcnow)
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"patterns": [p.to_dict() for p in self.patterns],
"factors": [f.to_dict() for f in self.factors],
"anomalies": [a.to_dict() for a in self.anomalies],
"insights": [i.to_dict() for i in self.insights],
"time_range": {
"start": self.time_range.start.isoformat(),
"end": self.time_range.end.isoformat(),
"duration_hours": self.time_range.duration_hours,
},
"episodes_analyzed": self.episodes_analyzed,
"analysis_duration_seconds": self.analysis_duration_seconds,
"generated_at": self.generated_at.isoformat(),
}
@property
def summary(self) -> str:
"""Generate a summary of the reflection results."""
lines = [
f"Reflection Analysis ({self.time_range.duration_days:.1f} days)",
f"Episodes analyzed: {self.episodes_analyzed}",
"",
f"Patterns detected: {len(self.patterns)}",
f"Success/failure factors: {len(self.factors)}",
f"Anomalies found: {len(self.anomalies)}",
f"Insights generated: {len(self.insights)}",
]
if self.insights:
lines.append("")
lines.append("Top insights:")
for insight in sorted(self.insights, key=lambda i: -i.priority)[:3]:
lines.append(f" - [{insight.insight_type.value}] {insight.title}")
return "\n".join(lines)
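These dataclasses are plain containers, so they compose without any service wiring. A short sketch; the types and the actionable_score formula come from the code above, while the concrete values are illustrative only:

from datetime import UTC, datetime
from uuid import uuid4

from app.services.memory.reflection import Insight, InsightType, TimeRange

window = TimeRange.last_days(7)
print(f"window: {window.duration_hours:.0f}h")  # ~168h

insight = Insight(
    id=uuid4(),
    insight_type=InsightType.OPTIMIZATION,
    title="Batch embedding calls",
    description="Embedding latency dominates retrieval in this window.",
    priority=0.8,
    confidence=0.7,
    source_patterns=[],
    source_factors=[],
    source_anomalies=[],
    recommended_actions=["batch requests", "cache frequent queries"],
    generated_at=datetime.now(UTC),
)
# priority * confidence * min(1, 2/3) ~= 0.37
print(insight.actionable_score)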

View File

@@ -0,0 +1,33 @@
# app/services/memory/scoping/__init__.py
"""
Memory Scoping
Hierarchical scoping for memory with inheritance:
Global -> Project -> Agent Type -> Agent Instance -> Session
"""
from .resolver import (
ResolutionOptions,
ResolutionResult,
ScopeFilter,
ScopeResolver,
get_scope_resolver,
)
from .scope import (
ScopeInfo,
ScopeManager,
ScopePolicy,
get_scope_manager,
)
__all__ = [
"ResolutionOptions",
"ResolutionResult",
"ScopeFilter",
"ScopeInfo",
"ScopeManager",
"ScopePolicy",
"ScopeResolver",
"get_scope_manager",
"get_scope_resolver",
]

View File

@@ -0,0 +1,390 @@
# app/services/memory/scoping/resolver.py
"""
Scope Resolution.
Provides utilities for resolving memory queries across scope hierarchies,
implementing inheritance and aggregation of memories from parent scopes.
"""
import logging
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any, TypeVar
from app.services.memory.types import ScopeContext, ScopeLevel
from .scope import ScopeManager, get_scope_manager
logger = logging.getLogger(__name__)
T = TypeVar("T")
@dataclass
class ResolutionResult[T]:
"""Result of a scope resolution."""
items: list[T]
sources: list[ScopeContext]
total_from_each: dict[str, int] = field(default_factory=dict)
inherited_count: int = 0
own_count: int = 0
@property
def total_count(self) -> int:
"""Get total items from all sources."""
return len(self.items)
@dataclass
class ResolutionOptions:
"""Options for scope resolution."""
include_inherited: bool = True
max_inheritance_depth: int = 5
limit_per_scope: int = 100
total_limit: int = 500
deduplicate: bool = True
deduplicate_key: str | None = None # Field to use for deduplication
class ScopeResolver:
"""
Resolves memory queries across scope hierarchies.
Features:
- Traverse scope hierarchy for inherited memories
- Aggregate results from multiple scope levels
- Apply access control policies
- Support deduplication across scopes
"""
def __init__(
self,
manager: ScopeManager | None = None,
) -> None:
"""
Initialize the resolver.
Args:
manager: Scope manager to use (defaults to singleton)
"""
self._manager = manager or get_scope_manager()
def resolve(
self,
scope: ScopeContext,
fetcher: Callable[[ScopeContext, int], list[T]],
options: ResolutionOptions | None = None,
) -> ResolutionResult[T]:
"""
Resolve memories for a scope, including inherited memories.
Args:
scope: Starting scope
fetcher: Function to fetch items for a scope (scope, limit) -> items
options: Resolution options
Returns:
Resolution result with items from all scopes
"""
opts = options or ResolutionOptions()
all_items: list[T] = []
sources: list[ScopeContext] = []
counts: dict[str, int] = {}
seen_keys: set[Any] = set()
# Collect scopes to query (starting from current, going up to ancestors)
scopes_to_query = self._collect_queryable_scopes(
scope=scope,
max_depth=opts.max_inheritance_depth if opts.include_inherited else 0,
)
own_count = 0
inherited_count = 0
remaining_limit = opts.total_limit
for i, query_scope in enumerate(scopes_to_query):
if remaining_limit <= 0:
break
# Check access policy
policy = self._manager.get_policy(query_scope)
if not policy.allows_read():
continue
if i > 0 and not policy.allows_inherit():
continue
# Fetch items for this scope
scope_limit = min(opts.limit_per_scope, remaining_limit)
items = fetcher(query_scope, scope_limit)
# Apply deduplication
if opts.deduplicate and opts.deduplicate_key:
items = self._deduplicate(items, opts.deduplicate_key, seen_keys)
if items:
all_items.extend(items)
sources.append(query_scope)
key = f"{query_scope.scope_type.value}:{query_scope.scope_id}"
counts[key] = len(items)
if i == 0:
own_count = len(items)
else:
inherited_count += len(items)
remaining_limit -= len(items)
logger.debug(
f"Resolved {len(all_items)} items from {len(sources)} scopes "
f"(own={own_count}, inherited={inherited_count})"
)
return ResolutionResult(
items=all_items[: opts.total_limit],
sources=sources,
total_from_each=counts,
own_count=own_count,
inherited_count=inherited_count,
)
def _collect_queryable_scopes(
self,
scope: ScopeContext,
max_depth: int,
) -> list[ScopeContext]:
"""Collect scopes to query, from current to ancestors."""
scopes: list[ScopeContext] = [scope]
if max_depth <= 0:
return scopes
current = scope.parent
depth = 0
while current is not None and depth < max_depth:
scopes.append(current)
current = current.parent
depth += 1
return scopes
def _deduplicate(
self,
items: list[T],
key_field: str,
seen_keys: set[Any],
) -> list[T]:
"""Remove duplicate items based on a key field."""
unique: list[T] = []
for item in items:
key = getattr(item, key_field, None)
if key is None:
# If no key, include the item
unique.append(item)
elif key not in seen_keys:
seen_keys.add(key)
unique.append(item)
return unique
def get_visible_scopes(
self,
scope: ScopeContext,
) -> list[ScopeContext]:
"""
Get all scopes visible from a given scope.
A scope can see itself and all its ancestors (if inheritance allowed).
Args:
scope: Starting scope
Returns:
List of visible scopes (from most specific to most general)
"""
visible = [scope]
current = scope.parent
while current is not None:
policy = self._manager.get_policy(current)
if policy.allows_inherit():
visible.append(current)
else:
break # Stop at first non-inheritable scope
current = current.parent
return visible
def find_write_scope(
self,
target_level: ScopeLevel,
scope: ScopeContext,
) -> ScopeContext | None:
"""
Find the appropriate scope for writing at a target level.
Walks up the hierarchy to find a scope at the target level
that allows writing.
Args:
target_level: Desired scope level
scope: Starting scope
Returns:
Scope to write to, or None if not found/not allowed
"""
# First check if current scope is at target level
if scope.scope_type == target_level:
policy = self._manager.get_policy(scope)
return scope if policy.allows_write() else None
# Check ancestors
current = scope.parent
while current is not None:
if current.scope_type == target_level:
policy = self._manager.get_policy(current)
return current if policy.allows_write() else None
current = current.parent
return None
def resolve_scope_from_memory(
self,
memory_type: str,
project_id: str | None = None,
agent_type_id: str | None = None,
agent_instance_id: str | None = None,
session_id: str | None = None,
) -> tuple[ScopeContext, ScopeLevel]:
"""
Resolve the appropriate scope for a memory operation.
Different memory types have different scope requirements:
- working: Session or Agent Instance
- episodic: Agent Instance or Project
- semantic: Project or Global
- procedural: Agent Type or Project
Args:
memory_type: Type of memory
project_id: Project ID
agent_type_id: Agent type ID
agent_instance_id: Agent instance ID
session_id: Session ID
Returns:
Tuple of (scope context, recommended level)
"""
# Build full scope chain
scope = self._manager.create_scope_from_ids(
project_id=project_id if project_id else None, # type: ignore[arg-type]
agent_type_id=agent_type_id if agent_type_id else None, # type: ignore[arg-type]
agent_instance_id=agent_instance_id if agent_instance_id else None, # type: ignore[arg-type]
session_id=session_id,
)
# Determine recommended level based on memory type
recommended = self._get_recommended_level(memory_type)
return scope, recommended
def _get_recommended_level(self, memory_type: str) -> ScopeLevel:
"""Get recommended scope level for a memory type."""
recommendations = {
"working": ScopeLevel.SESSION,
"episodic": ScopeLevel.AGENT_INSTANCE,
"semantic": ScopeLevel.PROJECT,
"procedural": ScopeLevel.AGENT_TYPE,
}
return recommendations.get(memory_type, ScopeLevel.PROJECT)
def validate_write_access(
self,
scope: ScopeContext,
memory_type: str,
) -> bool:
"""
Validate that writing is allowed for the given scope and memory type.
Args:
scope: Scope to validate
memory_type: Type of memory to write
Returns:
True if write is allowed
"""
policy = self._manager.get_policy(scope)
if not policy.allows_write():
return False
if not policy.allows_memory_type(memory_type):
return False
return True
def get_scope_chain(
self,
scope: ScopeContext,
) -> list[tuple[ScopeLevel, str]]:
"""
Get the scope chain as a list of (level, id) tuples.
Args:
scope: Scope to get chain for
Returns:
List of (level, id) tuples from root to leaf
"""
chain: list[tuple[ScopeLevel, str]] = []
# Get full hierarchy
hierarchy = scope.get_hierarchy()
for ctx in hierarchy:
chain.append((ctx.scope_type, ctx.scope_id))
return chain
@dataclass
class ScopeFilter:
"""Filter for querying across scopes."""
scope_types: list[ScopeLevel] | None = None
project_ids: list[str] | None = None
agent_type_ids: list[str] | None = None
include_global: bool = True
def matches(self, scope: ScopeContext) -> bool:
"""Check if a scope matches this filter."""
if self.scope_types and scope.scope_type not in self.scope_types:
return False
if scope.scope_type == ScopeLevel.GLOBAL:
return self.include_global
if scope.scope_type == ScopeLevel.PROJECT:
if self.project_ids and scope.scope_id not in self.project_ids:
return False
if scope.scope_type == ScopeLevel.AGENT_TYPE:
if self.agent_type_ids and scope.scope_id not in self.agent_type_ids:
return False
return True
# Singleton resolver instance
_resolver: ScopeResolver | None = None
def get_scope_resolver() -> ScopeResolver:
"""Get the singleton scope resolver instance."""
global _resolver
if _resolver is None:
_resolver = ScopeResolver()
return _resolver
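A sketch of resolving memories with scope inheritance. The fetch_facts callable stands in for a real store lookup, and the UUID is purely illustrative:

from uuid import uuid4

from app.services.memory.scoping import ResolutionOptions, get_scope_resolver
from app.services.memory.types import ScopeContext

def fetch_facts(scope: ScopeContext, limit: int) -> list[dict]:
    # Placeholder: a real implementation would query semantic memory for this scope.
    return []

resolver = get_scope_resolver()
scope, recommended_level = resolver.resolve_scope_from_memory(
    memory_type="semantic",
    project_id=str(uuid4()),
)
result = resolver.resolve(
    scope,
    fetcher=fetch_facts,
    options=ResolutionOptions(include_inherited=True, total_limit=50),
)
print(result.own_count, result.inherited_count, result.total_count)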

View File

@@ -0,0 +1,472 @@
# app/services/memory/scoping/scope.py
"""
Scope Management.
Provides utilities for managing memory scopes with hierarchical inheritance:
Global -> Project -> Agent Type -> Agent Instance -> Session
"""
import logging
import threading
from dataclasses import dataclass, field
from typing import Any, ClassVar
from uuid import UUID
from app.services.memory.types import ScopeContext, ScopeLevel
logger = logging.getLogger(__name__)
@dataclass
class ScopePolicy:
"""Access control policy for a scope."""
scope_type: ScopeLevel
scope_id: str
can_read: bool = True
can_write: bool = True
can_inherit: bool = True
allowed_memory_types: list[str] = field(default_factory=lambda: ["all"])
metadata: dict[str, Any] = field(default_factory=dict)
def allows_read(self) -> bool:
"""Check if reading is allowed."""
return self.can_read
def allows_write(self) -> bool:
"""Check if writing is allowed."""
return self.can_write
def allows_inherit(self) -> bool:
"""Check if inheritance from parent is allowed."""
return self.can_inherit
def allows_memory_type(self, memory_type: str) -> bool:
"""Check if a specific memory type is allowed."""
return (
"all" in self.allowed_memory_types
or memory_type in self.allowed_memory_types
)
@dataclass
class ScopeInfo:
"""Information about a scope including its hierarchy."""
context: ScopeContext
policy: ScopePolicy
parent_info: "ScopeInfo | None" = None
child_count: int = 0
memory_count: int = 0
@property
def depth(self) -> int:
"""Get the depth of this scope in the hierarchy."""
count = 0
current = self.parent_info
while current is not None:
count += 1
current = current.parent_info
return count
class ScopeManager:
"""
Manages memory scopes and their hierarchies.
Provides:
- Scope creation and validation
- Hierarchy management
- Access control policy management
- Scope inheritance rules
"""
# Order of scope levels from root to leaf
SCOPE_ORDER: ClassVar[list[ScopeLevel]] = [
ScopeLevel.GLOBAL,
ScopeLevel.PROJECT,
ScopeLevel.AGENT_TYPE,
ScopeLevel.AGENT_INSTANCE,
ScopeLevel.SESSION,
]
def __init__(self) -> None:
"""Initialize the scope manager."""
# In-memory policy cache (would be backed by database in production)
self._policies: dict[str, ScopePolicy] = {}
self._default_policies = self._create_default_policies()
def _create_default_policies(self) -> dict[ScopeLevel, ScopePolicy]:
"""Create default policies for each scope level."""
return {
ScopeLevel.GLOBAL: ScopePolicy(
scope_type=ScopeLevel.GLOBAL,
scope_id="global",
can_read=True,
can_write=False, # Global writes require special permission
can_inherit=True,
),
ScopeLevel.PROJECT: ScopePolicy(
scope_type=ScopeLevel.PROJECT,
scope_id="default",
can_read=True,
can_write=True,
can_inherit=True,
),
ScopeLevel.AGENT_TYPE: ScopePolicy(
scope_type=ScopeLevel.AGENT_TYPE,
scope_id="default",
can_read=True,
can_write=True,
can_inherit=True,
),
ScopeLevel.AGENT_INSTANCE: ScopePolicy(
scope_type=ScopeLevel.AGENT_INSTANCE,
scope_id="default",
can_read=True,
can_write=True,
can_inherit=True,
),
ScopeLevel.SESSION: ScopePolicy(
scope_type=ScopeLevel.SESSION,
scope_id="default",
can_read=True,
can_write=True,
can_inherit=True,
allowed_memory_types=["working"], # Sessions only allow working memory
),
}
def create_scope(
self,
scope_type: ScopeLevel,
scope_id: str,
parent: ScopeContext | None = None,
) -> ScopeContext:
"""
Create a new scope context.
Args:
scope_type: Level of the scope
scope_id: Unique identifier within the level
parent: Optional parent scope
Returns:
Created scope context
Raises:
ValueError: If scope hierarchy is invalid
"""
# Validate hierarchy
if parent is not None:
self._validate_parent_child(parent.scope_type, scope_type)
# For non-global scopes without parent, auto-create parent chain
if parent is None and scope_type != ScopeLevel.GLOBAL:
parent = self._create_parent_chain(scope_type, scope_id)
context = ScopeContext(
scope_type=scope_type,
scope_id=scope_id,
parent=parent,
)
logger.debug(f"Created scope: {scope_type.value}:{scope_id}")
return context
def _validate_parent_child(
self,
parent_type: ScopeLevel,
child_type: ScopeLevel,
) -> None:
"""Validate that parent-child relationship is valid."""
parent_idx = self.SCOPE_ORDER.index(parent_type)
child_idx = self.SCOPE_ORDER.index(child_type)
if child_idx <= parent_idx:
raise ValueError(
f"Invalid scope hierarchy: {child_type.value} cannot be child of {parent_type.value}"
)
# Allow skipping levels (e.g., PROJECT -> SESSION is valid)
# This enables flexible scope structures
def _create_parent_chain(
self,
target_type: ScopeLevel,
scope_id: str,
) -> ScopeContext:
"""Create parent scope chain up to target type."""
target_idx = self.SCOPE_ORDER.index(target_type)
# Start from global and build chain
current: ScopeContext | None = None
for i in range(target_idx):
level = self.SCOPE_ORDER[i]
if level == ScopeLevel.GLOBAL:
level_id = "global"
else:
# Use a default ID for intermediate levels
level_id = f"default_{level.value}"
current = ScopeContext(
scope_type=level,
scope_id=level_id,
parent=current,
)
return current # type: ignore[return-value]
def create_scope_from_ids(
self,
project_id: UUID | None = None,
agent_type_id: UUID | None = None,
agent_instance_id: UUID | None = None,
session_id: str | None = None,
) -> ScopeContext:
"""
Create a scope context from individual IDs.
Automatically determines the most specific scope level
based on provided IDs.
Args:
project_id: Project UUID
agent_type_id: Agent type UUID
agent_instance_id: Agent instance UUID
session_id: Session identifier
Returns:
Scope context for the most specific level
"""
# Build scope chain from most general to most specific
current: ScopeContext = ScopeContext(
scope_type=ScopeLevel.GLOBAL,
scope_id="global",
parent=None,
)
if project_id is not None:
current = ScopeContext(
scope_type=ScopeLevel.PROJECT,
scope_id=str(project_id),
parent=current,
)
if agent_type_id is not None:
current = ScopeContext(
scope_type=ScopeLevel.AGENT_TYPE,
scope_id=str(agent_type_id),
parent=current,
)
if agent_instance_id is not None:
current = ScopeContext(
scope_type=ScopeLevel.AGENT_INSTANCE,
scope_id=str(agent_instance_id),
parent=current,
)
if session_id is not None:
current = ScopeContext(
scope_type=ScopeLevel.SESSION,
scope_id=session_id,
parent=current,
)
return current
def get_policy(
self,
scope: ScopeContext,
) -> ScopePolicy:
"""
Get the access policy for a scope.
Args:
scope: Scope to get policy for
Returns:
Policy for the scope
"""
key = self._scope_key(scope)
if key in self._policies:
return self._policies[key]
# Return default policy for the scope level
return self._default_policies.get(
scope.scope_type,
ScopePolicy(
scope_type=scope.scope_type,
scope_id=scope.scope_id,
),
)
def set_policy(
self,
scope: ScopeContext,
policy: ScopePolicy,
) -> None:
"""
Set the access policy for a scope.
Args:
scope: Scope to set policy for
policy: Policy to apply
"""
key = self._scope_key(scope)
self._policies[key] = policy
logger.info(f"Set policy for scope {key}")
def _scope_key(self, scope: ScopeContext) -> str:
"""Generate a unique key for a scope."""
return f"{scope.scope_type.value}:{scope.scope_id}"
def get_scope_depth(self, scope_type: ScopeLevel) -> int:
"""Get the depth of a scope level in the hierarchy."""
return self.SCOPE_ORDER.index(scope_type)
def get_parent_level(self, scope_type: ScopeLevel) -> ScopeLevel | None:
"""Get the parent scope level for a given level."""
idx = self.SCOPE_ORDER.index(scope_type)
if idx == 0:
return None
return self.SCOPE_ORDER[idx - 1]
def get_child_level(self, scope_type: ScopeLevel) -> ScopeLevel | None:
"""Get the child scope level for a given level."""
idx = self.SCOPE_ORDER.index(scope_type)
if idx >= len(self.SCOPE_ORDER) - 1:
return None
return self.SCOPE_ORDER[idx + 1]
def is_ancestor(
self,
potential_ancestor: ScopeContext,
descendant: ScopeContext,
) -> bool:
"""
Check if one scope is an ancestor of another.
Args:
potential_ancestor: Scope to check as ancestor
descendant: Scope to check as descendant
Returns:
True if ancestor relationship exists
"""
current = descendant.parent
while current is not None:
if (
current.scope_type == potential_ancestor.scope_type
and current.scope_id == potential_ancestor.scope_id
):
return True
current = current.parent
return False
def get_common_ancestor(
self,
scope_a: ScopeContext,
scope_b: ScopeContext,
) -> ScopeContext | None:
"""
Find the nearest common ancestor of two scopes.
Args:
scope_a: First scope
scope_b: Second scope
Returns:
Common ancestor or None if none exists
"""
# Get ancestors of scope_a
ancestors_a: set[str] = set()
current: ScopeContext | None = scope_a
while current is not None:
ancestors_a.add(self._scope_key(current))
current = current.parent
# Find first ancestor of scope_b that's in ancestors_a
current = scope_b
while current is not None:
if self._scope_key(current) in ancestors_a:
return current
current = current.parent
return None
def can_access(
self,
accessor_scope: ScopeContext,
target_scope: ScopeContext,
operation: str = "read",
) -> bool:
"""
Check if accessor scope can access target scope.
Access rules:
- A scope can always access itself
- A scope can access ancestors (if inheritance allowed)
- A scope CANNOT access descendants (privacy)
- Sibling scopes cannot access each other
Args:
accessor_scope: Scope attempting access
target_scope: Scope being accessed
operation: Type of operation (read/write)
Returns:
True if access is allowed
"""
# Same scope - always allowed
if (
accessor_scope.scope_type == target_scope.scope_type
and accessor_scope.scope_id == target_scope.scope_id
):
policy = self.get_policy(target_scope)
if operation == "write":
return policy.allows_write()
return policy.allows_read()
# Check if target is ancestor (inheritance)
if self.is_ancestor(target_scope, accessor_scope):
policy = self.get_policy(target_scope)
if not policy.allows_inherit():
return False
if operation == "write":
return policy.allows_write()
return policy.allows_read()
# Check if accessor is ancestor of target (downward access)
# This is NOT allowed - parents cannot access children's memories
if self.is_ancestor(accessor_scope, target_scope):
return False
# Sibling scopes cannot access each other
return False
# Singleton manager instance with thread-safe initialization
_manager: ScopeManager | None = None
_manager_lock = threading.Lock()
def get_scope_manager() -> ScopeManager:
"""Get the singleton scope manager instance (thread-safe)."""
global _manager
if _manager is None:
with _manager_lock:
# Double-check locking pattern
if _manager is None:
_manager = ScopeManager()
return _manager
def reset_scope_manager() -> None:
"""Reset the scope manager singleton (for testing)."""
global _manager
with _manager_lock:
_manager = None
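A minimal usage sketch of the scope manager defined above; the UUIDs are hypothetical placeholders, and only APIs from this module (get_scope_manager, create_scope_from_ids, can_access) are exercised:

from uuid import uuid4

manager = get_scope_manager()

# Hypothetical identifiers, for illustration only.
project_id = uuid4()
agent_instance_id = uuid4()

# Builds a SESSION scope whose parent chain runs
# SESSION -> AGENT_INSTANCE -> PROJECT -> GLOBAL.
session_scope = manager.create_scope_from_ids(
    project_id=project_id,
    agent_instance_id=agent_instance_id,
    session_id="session-123",
)
project_scope = session_scope.parent.parent  # the PROJECT ancestor

# Sessions may read inherited project memory, but a project scope
# cannot reach down into a session's memories.
assert manager.can_access(session_scope, project_scope, operation="read")
assert not manager.can_access(project_scope, session_scope, operation="read")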


@@ -0,0 +1,27 @@
# app/services/memory/semantic/__init__.py
"""
Semantic Memory
Fact storage with triple format (subject, predicate, object)
and semantic search capabilities.
"""
from .extraction import (
ExtractedFact,
ExtractionContext,
FactExtractor,
get_fact_extractor,
)
from .memory import SemanticMemory
from .verification import FactConflict, FactVerifier, VerificationResult
__all__ = [
"ExtractedFact",
"ExtractionContext",
"FactConflict",
"FactExtractor",
"FactVerifier",
"SemanticMemory",
"VerificationResult",
"get_fact_extractor",
]


@@ -0,0 +1,313 @@
# app/services/memory/semantic/extraction.py
"""
Fact Extraction from Episodes.
Provides utilities for extracting semantic facts (subject-predicate-object triples)
from episodic memories and other text sources.
"""
import logging
import re
from dataclasses import dataclass, field
from typing import Any, ClassVar
from app.services.memory.types import Episode, FactCreate, Outcome
logger = logging.getLogger(__name__)
@dataclass
class ExtractionContext:
"""Context for fact extraction."""
project_id: Any | None = None
source_episode_id: Any | None = None
min_confidence: float = 0.5
max_facts_per_source: int = 10
@dataclass
class ExtractedFact:
"""A fact extracted from text before storage."""
subject: str
predicate: str
object: str
confidence: float
source_text: str = ""
metadata: dict[str, Any] = field(default_factory=dict)
def to_fact_create(
self,
project_id: Any | None = None,
source_episode_ids: list[Any] | None = None,
) -> FactCreate:
"""Convert to FactCreate for storage."""
return FactCreate(
subject=self.subject,
predicate=self.predicate,
object=self.object,
confidence=self.confidence,
project_id=project_id,
source_episode_ids=source_episode_ids or [],
)
class FactExtractor:
"""
Extracts facts from episodes and text.
This is a rule-based extractor. In production, this would be
replaced or augmented with LLM-based extraction for better accuracy.
"""
# Common predicates we can detect
PREDICATE_PATTERNS: ClassVar[dict[str, str]] = {
"uses": r"(?:uses?|using|utilizes?)",
"requires": r"(?:requires?|needs?|depends?\s+on)",
"is_a": r"(?:is\s+a|is\s+an|are\s+a|are)",
"has": r"(?:has|have|contains?)",
"part_of": r"(?:part\s+of|belongs?\s+to|member\s+of)",
"causes": r"(?:causes?|leads?\s+to|results?\s+in)",
"prevents": r"(?:prevents?|avoids?|stops?)",
"solves": r"(?:solves?|fixes?|resolves?)",
}
def __init__(self) -> None:
"""Initialize extractor."""
self._compiled_patterns = {
pred: re.compile(pattern, re.IGNORECASE)
for pred, pattern in self.PREDICATE_PATTERNS.items()
}
def extract_from_episode(
self,
episode: Episode,
context: ExtractionContext | None = None,
) -> list[ExtractedFact]:
"""
Extract facts from an episode.
Args:
episode: Episode to extract from
context: Optional extraction context
Returns:
List of extracted facts
"""
ctx = context or ExtractionContext()
facts: list[ExtractedFact] = []
# Extract from task description
task_facts = self._extract_from_text(
episode.task_description,
source_prefix=episode.task_type,
)
facts.extend(task_facts)
# Extract from lessons learned
for lesson in episode.lessons_learned:
lesson_facts = self._extract_from_lesson(lesson, episode)
facts.extend(lesson_facts)
# Extract outcome-based facts
outcome_facts = self._extract_outcome_facts(episode)
facts.extend(outcome_facts)
# Limit and filter
facts = [f for f in facts if f.confidence >= ctx.min_confidence]
facts = facts[: ctx.max_facts_per_source]
logger.debug(f"Extracted {len(facts)} facts from episode {episode.id}")
return facts
def _extract_from_text(
self,
text: str,
source_prefix: str = "",
) -> list[ExtractedFact]:
"""Extract facts from free-form text using pattern matching."""
facts: list[ExtractedFact] = []
if not text or len(text) < 10:
return facts
# Split into sentences
sentences = re.split(r"[.!?]+", text)
for sentence in sentences:
sentence = sentence.strip()
if len(sentence) < 10:
continue
# Try to match predicate patterns
for predicate, pattern in self._compiled_patterns.items():
match = pattern.search(sentence)
if match:
# Extract subject (text before predicate)
subject = sentence[: match.start()].strip()
# Extract object (text after predicate)
obj = sentence[match.end() :].strip()
if len(subject) > 2 and len(obj) > 2:
facts.append(
ExtractedFact(
subject=subject[:200], # Limit length
predicate=predicate,
object=obj[:500],
confidence=0.6, # Medium confidence for pattern matching
source_text=sentence,
)
)
break # One fact per sentence
return facts
def _extract_from_lesson(
self,
lesson: str,
episode: Episode,
) -> list[ExtractedFact]:
"""Extract facts from a lesson learned."""
facts: list[ExtractedFact] = []
if not lesson or len(lesson) < 10:
return facts
# Lessons are typically in the form "Always do X" or "Never do Y"
# or "When X, do Y"
# Direct lesson fact
facts.append(
ExtractedFact(
subject=episode.task_type,
predicate="lesson_learned",
object=lesson,
confidence=0.8, # High confidence for explicit lessons
source_text=lesson,
metadata={"outcome": episode.outcome.value},
)
)
# Extract conditional patterns
conditional_match = re.match(
r"(?:when|if)\s+(.+?),\s*(.+)",
lesson,
re.IGNORECASE,
)
if conditional_match:
condition, action = conditional_match.groups()
facts.append(
ExtractedFact(
subject=condition.strip(),
predicate="requires_action",
object=action.strip(),
confidence=0.7,
source_text=lesson,
)
)
# Extract "always/never" patterns
always_match = re.match(
r"(?:always)\s+(.+)",
lesson,
re.IGNORECASE,
)
if always_match:
facts.append(
ExtractedFact(
subject=episode.task_type,
predicate="best_practice",
object=always_match.group(1).strip(),
confidence=0.85,
source_text=lesson,
)
)
never_match = re.match(
r"(?:never|avoid)\s+(.+)",
lesson,
re.IGNORECASE,
)
if never_match:
facts.append(
ExtractedFact(
subject=episode.task_type,
predicate="anti_pattern",
object=never_match.group(1).strip(),
confidence=0.85,
source_text=lesson,
)
)
return facts
def _extract_outcome_facts(
self,
episode: Episode,
) -> list[ExtractedFact]:
"""Extract facts based on episode outcome."""
facts: list[ExtractedFact] = []
# Create fact based on outcome
if episode.outcome == Outcome.SUCCESS:
if episode.outcome_details:
facts.append(
ExtractedFact(
subject=episode.task_type,
predicate="successful_approach",
object=episode.outcome_details[:500],
confidence=0.75,
source_text=episode.outcome_details,
)
)
elif episode.outcome == Outcome.FAILURE:
if episode.outcome_details:
facts.append(
ExtractedFact(
subject=episode.task_type,
predicate="known_failure_mode",
object=episode.outcome_details[:500],
confidence=0.8, # High confidence for failures
source_text=episode.outcome_details,
)
)
return facts
def extract_from_text(
self,
text: str,
context: ExtractionContext | None = None,
) -> list[ExtractedFact]:
"""
Extract facts from arbitrary text.
Args:
text: Text to extract from
context: Optional extraction context
Returns:
List of extracted facts
"""
ctx = context or ExtractionContext()
facts = self._extract_from_text(text)
# Filter by confidence
facts = [f for f in facts if f.confidence >= ctx.min_confidence]
return facts[: ctx.max_facts_per_source]
# Singleton extractor instance
_extractor: FactExtractor | None = None
def get_fact_extractor() -> FactExtractor:
"""Get the singleton fact extractor instance."""
global _extractor
if _extractor is None:
_extractor = FactExtractor()
return _extractor
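A short sketch of the rule-based extractor above on an illustrative sentence; the input text and the expected triple are examples, not project data:

extractor = get_fact_extractor()

facts = extractor.extract_from_text(
    "The deployment pipeline requires a valid staging environment.",
    context=ExtractionContext(min_confidence=0.5, max_facts_per_source=5),
)
for fact in facts:
    print(f"{fact.subject} --{fact.predicate}--> {fact.object} ({fact.confidence})")
# Expected: "The deployment pipeline --requires--> a valid staging environment (0.6)"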


@@ -0,0 +1,767 @@
# app/services/memory/semantic/memory.py
"""
Semantic Memory Implementation.
Provides fact storage and retrieval using subject-predicate-object triples.
Supports semantic search, confidence scoring, and fact reinforcement.
"""
import logging
import time
from datetime import UTC, datetime
from typing import Any
from uuid import UUID
from sqlalchemy import and_, desc, or_, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.memory.fact import Fact as FactModel
from app.services.memory.config import get_memory_settings
from app.services.memory.types import Episode, Fact, FactCreate, RetrievalResult
logger = logging.getLogger(__name__)
def _escape_like_pattern(pattern: str) -> str:
"""
Escape SQL LIKE/ILIKE special characters to prevent pattern injection.
Characters escaped:
- % (matches zero or more characters)
- _ (matches exactly one character)
- \\ (escape character itself)
Args:
pattern: Raw search pattern from user input
Returns:
Escaped pattern safe for use in LIKE/ILIKE queries
"""
# Escape backslash first, then the wildcards
return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
def _model_to_fact(model: FactModel) -> Fact:
"""Convert SQLAlchemy model to Fact dataclass."""
# SQLAlchemy Column types are inferred as Column[T] by mypy, but at runtime
# they return actual values. We use type: ignore to handle this mismatch.
return Fact(
id=model.id, # type: ignore[arg-type]
project_id=model.project_id, # type: ignore[arg-type]
subject=model.subject, # type: ignore[arg-type]
predicate=model.predicate, # type: ignore[arg-type]
object=model.object, # type: ignore[arg-type]
confidence=model.confidence, # type: ignore[arg-type]
source_episode_ids=model.source_episode_ids or [], # type: ignore[arg-type]
first_learned=model.first_learned, # type: ignore[arg-type]
last_reinforced=model.last_reinforced, # type: ignore[arg-type]
reinforcement_count=model.reinforcement_count, # type: ignore[arg-type]
embedding=None, # Don't expose raw embedding
created_at=model.created_at, # type: ignore[arg-type]
updated_at=model.updated_at, # type: ignore[arg-type]
)
class SemanticMemory:
"""
Semantic Memory Service.
Provides fact storage and retrieval:
- Store facts as subject-predicate-object triples
- Semantic search over facts
- Entity-based retrieval
- Confidence scoring and decay
- Fact reinforcement on repeated learning
- Conflict resolution
Performance target: <100ms P95 for retrieval
"""
def __init__(
self,
session: AsyncSession,
embedding_generator: Any | None = None,
) -> None:
"""
Initialize semantic memory.
Args:
session: Database session
embedding_generator: Optional embedding generator for semantic search
"""
self._session = session
self._embedding_generator = embedding_generator
self._settings = get_memory_settings()
@classmethod
async def create(
cls,
session: AsyncSession,
embedding_generator: Any | None = None,
) -> "SemanticMemory":
"""
Factory method to create SemanticMemory.
Args:
session: Database session
embedding_generator: Optional embedding generator
Returns:
Configured SemanticMemory instance
"""
return cls(session=session, embedding_generator=embedding_generator)
# =========================================================================
# Fact Storage
# =========================================================================
async def store_fact(self, fact: FactCreate) -> Fact:
"""
Store a new fact or reinforce an existing one.
If a fact with the same triple (subject, predicate, object) exists
in the same scope, it will be reinforced instead of duplicated.
Args:
fact: Fact data to store
Returns:
The created or reinforced fact
"""
# Check for existing fact with same triple
existing = await self._find_existing_fact(
project_id=fact.project_id,
subject=fact.subject,
predicate=fact.predicate,
object=fact.object,
)
if existing is not None:
# Reinforce existing fact
return await self.reinforce_fact(
existing.id, # type: ignore[arg-type]
source_episode_ids=fact.source_episode_ids,
)
# Create new fact
now = datetime.now(UTC)
# Generate embedding if possible
embedding = None
if self._embedding_generator is not None:
embedding_text = self._create_embedding_text(fact)
embedding = await self._embedding_generator.generate(embedding_text)
model = FactModel(
project_id=fact.project_id,
subject=fact.subject,
predicate=fact.predicate,
object=fact.object,
confidence=fact.confidence,
source_episode_ids=fact.source_episode_ids,
first_learned=now,
last_reinforced=now,
reinforcement_count=1,
embedding=embedding,
)
self._session.add(model)
await self._session.flush()
await self._session.refresh(model)
logger.info(
f"Stored new fact: {fact.subject} - {fact.predicate} - {fact.object[:50]}..."
)
return _model_to_fact(model)
async def _find_existing_fact(
self,
project_id: UUID | None,
subject: str,
predicate: str,
object: str,
) -> FactModel | None:
"""Find an existing fact with the same triple in the same scope."""
query = select(FactModel).where(
and_(
FactModel.subject == subject,
FactModel.predicate == predicate,
FactModel.object == object,
)
)
if project_id is not None:
query = query.where(FactModel.project_id == project_id)
else:
query = query.where(FactModel.project_id.is_(None))
result = await self._session.execute(query)
return result.scalar_one_or_none()
def _create_embedding_text(self, fact: FactCreate) -> str:
"""Create text for embedding from fact data."""
return f"{fact.subject} {fact.predicate} {fact.object}"
# =========================================================================
# Fact Retrieval
# =========================================================================
async def search_facts(
self,
query: str,
project_id: UUID | None = None,
limit: int = 10,
min_confidence: float | None = None,
) -> list[Fact]:
"""
Search for facts semantically similar to the query.
Args:
query: Search query
project_id: Optional project to search within
limit: Maximum results
min_confidence: Optional minimum confidence filter
Returns:
List of matching facts
"""
result = await self._search_facts_with_metadata(
query=query,
project_id=project_id,
limit=limit,
min_confidence=min_confidence,
)
return result.items
async def _search_facts_with_metadata(
self,
query: str,
project_id: UUID | None = None,
limit: int = 10,
min_confidence: float | None = None,
) -> RetrievalResult[Fact]:
"""Search facts with full result metadata."""
start_time = time.perf_counter()
min_conf = min_confidence if min_confidence is not None else self._settings.semantic_min_confidence
# Build base query
stmt = (
select(FactModel)
.where(FactModel.confidence >= min_conf)
.order_by(desc(FactModel.confidence), desc(FactModel.last_reinforced))
.limit(limit)
)
# Apply project filter
if project_id is not None:
# Include both project-specific and global facts
stmt = stmt.where(
or_(
FactModel.project_id == project_id,
FactModel.project_id.is_(None),
)
)
# TODO: Implement proper vector similarity search when pgvector is integrated
# For now, do text-based search on subject/predicate/object
search_terms = query.lower().split()
if search_terms:
conditions = []
for term in search_terms[:5]: # Limit to 5 terms
# Escape SQL wildcards to prevent pattern injection
escaped_term = _escape_like_pattern(term)
term_pattern = f"%{escaped_term}%"
conditions.append(
or_(
FactModel.subject.ilike(term_pattern),
FactModel.predicate.ilike(term_pattern),
FactModel.object.ilike(term_pattern),
)
)
if conditions:
stmt = stmt.where(or_(*conditions))
result = await self._session.execute(stmt)
models = list(result.scalars().all())
latency_ms = (time.perf_counter() - start_time) * 1000
return RetrievalResult(
items=[_model_to_fact(m) for m in models],
total_count=len(models),
query=query,
retrieval_type="semantic",
latency_ms=latency_ms,
metadata={"min_confidence": min_conf},
)
async def get_by_entity(
self,
entity: str,
project_id: UUID | None = None,
limit: int = 20,
) -> list[Fact]:
"""
Get facts related to an entity (as subject or object).
Args:
entity: Entity to search for
project_id: Optional project to search within
limit: Maximum results
Returns:
List of facts mentioning the entity
"""
start_time = time.perf_counter()
# Escape SQL wildcards to prevent pattern injection
escaped_entity = _escape_like_pattern(entity)
entity_pattern = f"%{escaped_entity}%"
stmt = (
select(FactModel)
.where(
or_(
FactModel.subject.ilike(entity_pattern),
FactModel.object.ilike(entity_pattern),
)
)
.order_by(desc(FactModel.confidence), desc(FactModel.last_reinforced))
.limit(limit)
)
if project_id is not None:
stmt = stmt.where(
or_(
FactModel.project_id == project_id,
FactModel.project_id.is_(None),
)
)
result = await self._session.execute(stmt)
models = list(result.scalars().all())
latency_ms = (time.perf_counter() - start_time) * 1000
logger.debug(
f"get_by_entity({entity}) returned {len(models)} facts in {latency_ms:.1f}ms"
)
return [_model_to_fact(m) for m in models]
async def get_by_subject(
self,
subject: str,
project_id: UUID | None = None,
predicate: str | None = None,
limit: int = 20,
) -> list[Fact]:
"""
Get facts with a specific subject.
Args:
subject: Subject to search for
project_id: Optional project to search within
predicate: Optional predicate filter
limit: Maximum results
Returns:
List of facts with matching subject
"""
stmt = (
select(FactModel)
.where(FactModel.subject == subject)
.order_by(desc(FactModel.confidence))
.limit(limit)
)
if predicate is not None:
stmt = stmt.where(FactModel.predicate == predicate)
if project_id is not None:
stmt = stmt.where(
or_(
FactModel.project_id == project_id,
FactModel.project_id.is_(None),
)
)
result = await self._session.execute(stmt)
models = list(result.scalars().all())
return [_model_to_fact(m) for m in models]
async def get_by_id(self, fact_id: UUID) -> Fact | None:
"""Get a fact by ID."""
query = select(FactModel).where(FactModel.id == fact_id)
result = await self._session.execute(query)
model = result.scalar_one_or_none()
return _model_to_fact(model) if model else None
# =========================================================================
# Fact Reinforcement
# =========================================================================
async def reinforce_fact(
self,
fact_id: UUID,
confidence_boost: float = 0.1,
source_episode_ids: list[UUID] | None = None,
) -> Fact:
"""
Reinforce a fact, increasing its confidence.
Args:
fact_id: Fact to reinforce
confidence_boost: Amount to increase confidence (default 0.1)
source_episode_ids: Additional source episodes
Returns:
Updated fact
Raises:
ValueError: If fact not found
"""
query = select(FactModel).where(FactModel.id == fact_id)
result = await self._session.execute(query)
model = result.scalar_one_or_none()
if model is None:
raise ValueError(f"Fact not found: {fact_id}")
# Calculate new confidence (max 1.0)
current_confidence: float = model.confidence # type: ignore[assignment]
new_confidence = min(1.0, current_confidence + confidence_boost)
# Merge source episode IDs
current_sources: list[UUID] = model.source_episode_ids or [] # type: ignore[assignment]
if source_episode_ids:
# Add new sources, avoiding duplicates
new_sources = list(set(current_sources + source_episode_ids))
else:
new_sources = current_sources
now = datetime.now(UTC)
stmt = (
update(FactModel)
.where(FactModel.id == fact_id)
.values(
confidence=new_confidence,
source_episode_ids=new_sources,
last_reinforced=now,
reinforcement_count=FactModel.reinforcement_count + 1,
updated_at=now,
)
.returning(FactModel)
)
result = await self._session.execute(stmt)
updated_model = result.scalar_one()
await self._session.flush()
logger.info(
f"Reinforced fact {fact_id}: confidence {current_confidence:.2f} -> {new_confidence:.2f}"
)
return _model_to_fact(updated_model)
async def deprecate_fact(
self,
fact_id: UUID,
reason: str,
new_confidence: float = 0.0,
) -> Fact | None:
"""
Deprecate a fact by lowering its confidence.
Args:
fact_id: Fact to deprecate
reason: Reason for deprecation
new_confidence: New confidence level (default 0.0)
Returns:
Updated fact or None if not found
"""
query = select(FactModel).where(FactModel.id == fact_id)
result = await self._session.execute(query)
model = result.scalar_one_or_none()
if model is None:
return None
now = datetime.now(UTC)
stmt = (
update(FactModel)
.where(FactModel.id == fact_id)
.values(
confidence=max(0.0, new_confidence),
updated_at=now,
)
.returning(FactModel)
)
result = await self._session.execute(stmt)
updated_model = result.scalar_one_or_none()
await self._session.flush()
logger.info(f"Deprecated fact {fact_id}: {reason}")
return _model_to_fact(updated_model) if updated_model else None
# =========================================================================
# Fact Extraction from Episodes
# =========================================================================
async def extract_facts_from_episode(
self,
episode: Episode,
) -> list[Fact]:
"""
Extract facts from an episode.
This is a placeholder for LLM-based fact extraction.
In production, this would call an LLM to analyze the episode
and extract subject-predicate-object triples.
Args:
episode: Episode to extract facts from
Returns:
List of extracted facts
"""
# For now, extract basic facts from lessons learned
extracted_facts: list[Fact] = []
for lesson in episode.lessons_learned:
if len(lesson) > 10: # Skip very short lessons
fact_create = FactCreate(
subject=episode.task_type,
predicate="lesson_learned",
object=lesson,
confidence=0.7, # Lessons start with moderate confidence
project_id=episode.project_id,
source_episode_ids=[episode.id],
)
fact = await self.store_fact(fact_create)
extracted_facts.append(fact)
logger.debug(
f"Extracted {len(extracted_facts)} facts from episode {episode.id}"
)
return extracted_facts
# =========================================================================
# Conflict Resolution
# =========================================================================
async def resolve_conflict(
self,
fact_ids: list[UUID],
keep_fact_id: UUID | None = None,
) -> Fact | None:
"""
Resolve a conflict between multiple facts.
If keep_fact_id is specified, that fact is kept and others are deprecated.
Otherwise, the fact with highest confidence is kept.
Args:
fact_ids: IDs of conflicting facts
keep_fact_id: Optional ID of fact to keep
Returns:
The winning fact, or None if no facts found
"""
if not fact_ids:
return None
# Load all facts
query = select(FactModel).where(FactModel.id.in_(fact_ids))
result = await self._session.execute(query)
models = list(result.scalars().all())
if not models:
return None
# Determine winner
if keep_fact_id is not None:
winner = next((m for m in models if m.id == keep_fact_id), None)
if winner is None:
# Fallback to highest confidence
winner = max(models, key=lambda m: m.confidence)
else:
# Keep the fact with highest confidence
winner = max(models, key=lambda m: m.confidence)
# Deprecate losers
for model in models:
if model.id != winner.id:
await self.deprecate_fact(
model.id, # type: ignore[arg-type]
reason=f"Conflict resolution: superseded by {winner.id}",
)
logger.info(
f"Resolved conflict between {len(fact_ids)} facts, keeping {winner.id}"
)
return _model_to_fact(winner)
# =========================================================================
# Confidence Decay
# =========================================================================
async def apply_confidence_decay(
self,
project_id: UUID | None = None,
decay_factor: float = 0.01,
) -> int:
"""
Apply confidence decay to facts that haven't been reinforced recently.
Args:
project_id: Optional project to apply decay to
decay_factor: Decay factor per day (default 0.01)
Returns:
Number of facts affected
"""
now = datetime.now(UTC)
decay_days = self._settings.semantic_confidence_decay_days
min_conf = self._settings.semantic_min_confidence
# Calculate cutoff date
from datetime import timedelta
cutoff = now - timedelta(days=decay_days)
# Find facts needing decay
query = select(FactModel).where(
and_(
FactModel.last_reinforced < cutoff,
FactModel.confidence > min_conf,
)
)
if project_id is not None:
query = query.where(FactModel.project_id == project_id)
result = await self._session.execute(query)
models = list(result.scalars().all())
# Apply decay
updated_count = 0
for model in models:
# Calculate days since last reinforcement
days_since: float = (now - model.last_reinforced).days
# Calculate decay: linear decay proportional to days beyond the decay window
decay = decay_factor * (days_since - decay_days)
new_confidence = max(min_conf, model.confidence - decay)
if new_confidence != model.confidence:
await self._session.execute(
update(FactModel)
.where(FactModel.id == model.id)
.values(confidence=new_confidence, updated_at=now)
)
updated_count += 1
await self._session.flush()
logger.info(f"Applied confidence decay to {updated_count} facts")
return updated_count
# =========================================================================
# Statistics
# =========================================================================
async def get_stats(self, project_id: UUID | None = None) -> dict[str, Any]:
"""
Get statistics about semantic memory.
Args:
project_id: Optional project to get stats for
Returns:
Dictionary with statistics
"""
# Get all facts for this scope
query = select(FactModel)
if project_id is not None:
query = query.where(
or_(
FactModel.project_id == project_id,
FactModel.project_id.is_(None),
)
)
result = await self._session.execute(query)
models = list(result.scalars().all())
if not models:
return {
"total_facts": 0,
"avg_confidence": 0.0,
"avg_reinforcement_count": 0.0,
"high_confidence_count": 0,
"low_confidence_count": 0,
}
confidences = [m.confidence for m in models]
reinforcements = [m.reinforcement_count for m in models]
return {
"total_facts": len(models),
"avg_confidence": sum(confidences) / len(confidences),
"avg_reinforcement_count": sum(reinforcements) / len(reinforcements),
"high_confidence_count": sum(1 for c in confidences if c >= 0.8),
"low_confidence_count": sum(1 for c in confidences if c < 0.5),
}
async def count(self, project_id: UUID | None = None) -> int:
"""
Count facts in scope.
Args:
project_id: Optional project to count for
Returns:
Number of facts
"""
query = select(FactModel)
if project_id is not None:
query = query.where(
or_(
FactModel.project_id == project_id,
FactModel.project_id.is_(None),
)
)
result = await self._session.execute(query)
return len(list(result.scalars().all()))
async def delete(self, fact_id: UUID) -> bool:
"""
Delete a fact.
Args:
fact_id: Fact to delete
Returns:
True if deleted, False if not found
"""
query = select(FactModel).where(FactModel.id == fact_id)
result = await self._session.execute(query)
model = result.scalar_one_or_none()
if model is None:
return False
await self._session.delete(model)
await self._session.flush()
logger.info(f"Deleted fact {fact_id}")
return True
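A hedged sketch of storing and searching a fact with the service above; async_session_factory stands in for an application-provided AsyncSession factory and is not defined in this module:

# Sketch only: async_session_factory is an assumed application-level
# AsyncSession factory; it is not part of this module.
async def remember_and_recall(async_session_factory, project_id: UUID):
    async with async_session_factory() as session:
        memory = await SemanticMemory.create(session=session)

        # Storing the same (subject, predicate, object) twice reinforces
        # the existing fact instead of creating a duplicate.
        fact = await memory.store_fact(
            FactCreate(
                subject="knowledge-base service",
                predicate="uses",
                object="pgvector for embeddings",
                confidence=0.8,
                project_id=project_id,
            )
        )

        # Text-based matching until vector similarity search is wired in.
        hits = await memory.search_facts("pgvector", project_id=project_id, limit=5)
        await session.commit()
        return fact, hits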


@@ -0,0 +1,363 @@
# app/services/memory/semantic/verification.py
"""
Fact Verification.
Provides utilities for verifying facts, detecting conflicts,
and managing fact consistency.
"""
import logging
from dataclasses import dataclass, field
from typing import Any, ClassVar
from uuid import UUID
from sqlalchemy import and_, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.memory.fact import Fact as FactModel
from app.services.memory.types import Fact
logger = logging.getLogger(__name__)
@dataclass
class VerificationResult:
"""Result of fact verification."""
is_valid: bool
confidence_adjustment: float = 0.0
conflicts: list["FactConflict"] = field(default_factory=list)
supporting_facts: list[Fact] = field(default_factory=list)
messages: list[str] = field(default_factory=list)
@dataclass
class FactConflict:
"""Represents a conflict between two facts."""
fact_a_id: UUID
fact_b_id: UUID
conflict_type: str # "contradiction", "superseded", "partial_overlap"
description: str
suggested_resolution: str | None = None
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"fact_a_id": str(self.fact_a_id),
"fact_b_id": str(self.fact_b_id),
"conflict_type": self.conflict_type,
"description": self.description,
"suggested_resolution": self.suggested_resolution,
}
class FactVerifier:
"""
Verifies facts and detects conflicts.
Provides methods to:
- Check if a fact conflicts with existing facts
- Find supporting evidence for a fact
- Detect contradictions in the fact base
"""
# Predicates that are opposites/contradictions
CONTRADICTORY_PREDICATES: ClassVar[set[tuple[str, str]]] = {
("uses", "does_not_use"),
("requires", "does_not_require"),
("is_a", "is_not_a"),
("causes", "prevents"),
("allows", "prevents"),
("supports", "does_not_support"),
("best_practice", "anti_pattern"),
}
def __init__(self, session: AsyncSession) -> None:
"""Initialize verifier with database session."""
self._session = session
async def verify_fact(
self,
subject: str,
predicate: str,
obj: str,
project_id: UUID | None = None,
) -> VerificationResult:
"""
Verify a fact against existing facts.
Args:
subject: Fact subject
predicate: Fact predicate
obj: Fact object
project_id: Optional project scope
Returns:
VerificationResult with verification details
"""
result = VerificationResult(is_valid=True)
# Check for direct contradictions
conflicts = await self._find_contradictions(
subject=subject,
predicate=predicate,
obj=obj,
project_id=project_id,
)
result.conflicts = conflicts
if conflicts:
result.is_valid = False
result.messages.append(f"Found {len(conflicts)} conflicting fact(s)")
# Reduce confidence based on conflicts
result.confidence_adjustment = -0.1 * len(conflicts)
# Find supporting facts
supporting = await self._find_supporting_facts(
subject=subject,
predicate=predicate,
project_id=project_id,
)
result.supporting_facts = supporting
if supporting:
result.messages.append(f"Found {len(supporting)} supporting fact(s)")
# Boost confidence based on support
result.confidence_adjustment += 0.05 * min(len(supporting), 3)
return result
async def _find_contradictions(
self,
subject: str,
predicate: str,
obj: str,
project_id: UUID | None = None,
) -> list[FactConflict]:
"""Find facts that contradict the given fact."""
conflicts: list[FactConflict] = []
# Find opposite predicates
opposite_predicates = self._get_opposite_predicates(predicate)
if not opposite_predicates:
return conflicts
# Search for contradicting facts
query = select(FactModel).where(
and_(
FactModel.subject == subject,
FactModel.predicate.in_(opposite_predicates),
)
)
if project_id is not None:
query = query.where(
or_(
FactModel.project_id == project_id,
FactModel.project_id.is_(None),
)
)
result = await self._session.execute(query)
models = list(result.scalars().all())
for model in models:
conflicts.append(
FactConflict(
fact_a_id=model.id, # type: ignore[arg-type]
fact_b_id=UUID(
"00000000-0000-0000-0000-000000000000"
), # Placeholder for new fact
conflict_type="contradiction",
description=(
f"'{subject} {predicate} {obj}' contradicts "
f"'{model.subject} {model.predicate} {model.object}'"
),
suggested_resolution="Keep fact with higher confidence",
)
)
return conflicts
def _get_opposite_predicates(self, predicate: str) -> list[str]:
"""Get predicates that are opposite to the given predicate."""
opposites: list[str] = []
for pair in self.CONTRADICTORY_PREDICATES:
if predicate in pair:
opposites.extend(p for p in pair if p != predicate)
return opposites
async def _find_supporting_facts(
self,
subject: str,
predicate: str,
project_id: UUID | None = None,
) -> list[Fact]:
"""Find facts that support the given fact."""
# Find facts with same subject and predicate
query = (
select(FactModel)
.where(
and_(
FactModel.subject == subject,
FactModel.predicate == predicate,
FactModel.confidence >= 0.5,
)
)
.limit(10)
)
if project_id is not None:
query = query.where(
or_(
FactModel.project_id == project_id,
FactModel.project_id.is_(None),
)
)
result = await self._session.execute(query)
models = list(result.scalars().all())
return [self._model_to_fact(m) for m in models]
async def find_all_conflicts(
self,
project_id: UUID | None = None,
) -> list[FactConflict]:
"""
Find all conflicts in the fact base.
Args:
project_id: Optional project scope
Returns:
List of all detected conflicts
"""
conflicts: list[FactConflict] = []
# Get all facts
query = select(FactModel)
if project_id is not None:
query = query.where(
or_(
FactModel.project_id == project_id,
FactModel.project_id.is_(None),
)
)
result = await self._session.execute(query)
models = list(result.scalars().all())
# Check each pair for conflicts
for i, fact_a in enumerate(models):
for fact_b in models[i + 1 :]:
conflict = self._check_pair_conflict(fact_a, fact_b)
if conflict:
conflicts.append(conflict)
logger.info(f"Found {len(conflicts)} conflicts in fact base")
return conflicts
def _check_pair_conflict(
self,
fact_a: FactModel,
fact_b: FactModel,
) -> FactConflict | None:
"""Check if two facts conflict."""
# Same subject?
if fact_a.subject != fact_b.subject:
return None
# Contradictory predicates?
opposite = self._get_opposite_predicates(fact_a.predicate) # type: ignore[arg-type]
if fact_b.predicate not in opposite:
return None
return FactConflict(
fact_a_id=fact_a.id, # type: ignore[arg-type]
fact_b_id=fact_b.id, # type: ignore[arg-type]
conflict_type="contradiction",
description=(
f"'{fact_a.subject} {fact_a.predicate} {fact_a.object}' "
f"contradicts '{fact_b.subject} {fact_b.predicate} {fact_b.object}'"
),
suggested_resolution="Deprecate fact with lower confidence",
)
async def get_fact_reliability_score(
self,
fact_id: UUID,
) -> float:
"""
Calculate a reliability score for a fact.
Based on:
- Confidence score
- Number of reinforcements
- Number of supporting facts
- Absence of conflicts
Args:
fact_id: Fact to score
Returns:
Reliability score (0.0 to 1.0)
"""
query = select(FactModel).where(FactModel.id == fact_id)
result = await self._session.execute(query)
model = result.scalar_one_or_none()
if model is None:
return 0.0
# Base score from confidence - explicitly typed to avoid Column type issues
score: float = float(model.confidence)
# Boost for reinforcements (diminishing returns)
reinforcement_boost = min(0.2, float(model.reinforcement_count) * 0.02)
score += reinforcement_boost
# Find supporting facts
supporting = await self._find_supporting_facts(
subject=model.subject, # type: ignore[arg-type]
predicate=model.predicate, # type: ignore[arg-type]
project_id=model.project_id, # type: ignore[arg-type]
)
support_boost = min(0.1, len(supporting) * 0.02)
score += support_boost
# Check for conflicts
conflicts = await self._find_contradictions(
subject=model.subject, # type: ignore[arg-type]
predicate=model.predicate, # type: ignore[arg-type]
obj=model.object, # type: ignore[arg-type]
project_id=model.project_id, # type: ignore[arg-type]
)
conflict_penalty = min(0.3, len(conflicts) * 0.1)
score -= conflict_penalty
# Clamp to valid range
return max(0.0, min(1.0, score))
def _model_to_fact(self, model: FactModel) -> Fact:
"""Convert SQLAlchemy model to Fact dataclass."""
return Fact(
id=model.id, # type: ignore[arg-type]
project_id=model.project_id, # type: ignore[arg-type]
subject=model.subject, # type: ignore[arg-type]
predicate=model.predicate, # type: ignore[arg-type]
object=model.object, # type: ignore[arg-type]
confidence=model.confidence, # type: ignore[arg-type]
source_episode_ids=model.source_episode_ids or [], # type: ignore[arg-type]
first_learned=model.first_learned, # type: ignore[arg-type]
last_reinforced=model.last_reinforced, # type: ignore[arg-type]
reinforcement_count=model.reinforcement_count, # type: ignore[arg-type]
embedding=None,
created_at=model.created_at, # type: ignore[arg-type]
updated_at=model.updated_at, # type: ignore[arg-type]
)
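A small sketch of verifying a candidate fact before storing it; the open AsyncSession and the example triple are assumptions for illustration:

# Sketch only: `session` is an assumed, already-open AsyncSession.
async def check_before_store(session: AsyncSession, project_id: UUID | None):
    verifier = FactVerifier(session)
    result = await verifier.verify_fact(
        subject="auth middleware",
        predicate="requires",
        obj="a signed JWT",
        project_id=project_id,
    )
    if not result.is_valid:
        # Contradicting facts exist; surface them for conflict resolution.
        for conflict in result.conflicts:
            print(conflict.to_dict()["description"])
    # Fold the adjustment into the candidate fact's confidence before storage.
    return max(0.0, min(1.0, 0.8 + result.confidence_adjustment))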


@@ -0,0 +1,328 @@
"""
Memory System Types
Core type definitions and interfaces for the Agent Memory System.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import UTC, datetime
from enum import Enum
from typing import Any
from uuid import UUID
def _utcnow() -> datetime:
"""Get current UTC time as timezone-aware datetime."""
return datetime.now(UTC)
class MemoryType(str, Enum):
"""Types of memory in the agent memory system."""
WORKING = "working"
EPISODIC = "episodic"
SEMANTIC = "semantic"
PROCEDURAL = "procedural"
class ScopeLevel(str, Enum):
"""Hierarchical scoping levels for memory."""
GLOBAL = "global"
PROJECT = "project"
AGENT_TYPE = "agent_type"
AGENT_INSTANCE = "agent_instance"
SESSION = "session"
class Outcome(str, Enum):
"""Outcome of a task or episode."""
SUCCESS = "success"
FAILURE = "failure"
PARTIAL = "partial"
ABANDONED = "abandoned"
class ConsolidationStatus(str, Enum):
"""Status of a memory consolidation job."""
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
class ConsolidationType(str, Enum):
"""Types of memory consolidation."""
WORKING_TO_EPISODIC = "working_to_episodic"
EPISODIC_TO_SEMANTIC = "episodic_to_semantic"
EPISODIC_TO_PROCEDURAL = "episodic_to_procedural"
PRUNING = "pruning"
@dataclass
class ScopeContext:
"""Represents a memory scope with its hierarchy."""
scope_type: ScopeLevel
scope_id: str
parent: "ScopeContext | None" = None
def get_hierarchy(self) -> list["ScopeContext"]:
"""Get the full scope hierarchy from root to this scope."""
hierarchy: list[ScopeContext] = []
current: ScopeContext | None = self
while current is not None:
hierarchy.insert(0, current)
current = current.parent
return hierarchy
def to_key_prefix(self) -> str:
"""Convert scope to a key prefix for storage."""
return f"{self.scope_type.value}:{self.scope_id}"
@dataclass
class MemoryItem:
"""Base class for all memory items."""
id: UUID
memory_type: MemoryType
scope_type: ScopeLevel
scope_id: str
created_at: datetime
updated_at: datetime
metadata: dict[str, Any] = field(default_factory=dict)
def get_age_seconds(self) -> float:
"""Get the age of this memory item in seconds."""
return (_utcnow() - self.created_at).total_seconds()
@dataclass
class WorkingMemoryItem:
"""A key-value item in working memory."""
id: UUID
scope_type: ScopeLevel
scope_id: str
key: str
value: Any
expires_at: datetime | None = None
created_at: datetime = field(default_factory=_utcnow)
updated_at: datetime = field(default_factory=_utcnow)
def is_expired(self) -> bool:
"""Check if this item has expired."""
if self.expires_at is None:
return False
return _utcnow() > self.expires_at
@dataclass
class TaskState:
"""Current state of a task in working memory."""
task_id: str
task_type: str
description: str
status: str = "in_progress"
current_step: int = 0
total_steps: int = 0
progress_percent: float = 0.0
context: dict[str, Any] = field(default_factory=dict)
started_at: datetime = field(default_factory=_utcnow)
updated_at: datetime = field(default_factory=_utcnow)
@dataclass
class Episode:
"""An episodic memory - a recorded experience."""
id: UUID
project_id: UUID
agent_instance_id: UUID | None
agent_type_id: UUID | None
session_id: str
task_type: str
task_description: str
actions: list[dict[str, Any]]
context_summary: str
outcome: Outcome
outcome_details: str
duration_seconds: float
tokens_used: int
lessons_learned: list[str]
importance_score: float
embedding: list[float] | None
occurred_at: datetime
created_at: datetime
updated_at: datetime
@dataclass
class EpisodeCreate:
"""Data required to create a new episode."""
project_id: UUID
session_id: str
task_type: str
task_description: str
actions: list[dict[str, Any]]
context_summary: str
outcome: Outcome
outcome_details: str
duration_seconds: float
tokens_used: int
lessons_learned: list[str] = field(default_factory=list)
importance_score: float = 0.5
agent_instance_id: UUID | None = None
agent_type_id: UUID | None = None
@dataclass
class Fact:
"""A semantic memory fact - a piece of knowledge."""
id: UUID
project_id: UUID | None # None for global facts
subject: str
predicate: str
object: str
confidence: float
source_episode_ids: list[UUID]
first_learned: datetime
last_reinforced: datetime
reinforcement_count: int
embedding: list[float] | None
created_at: datetime
updated_at: datetime
@dataclass
class FactCreate:
"""Data required to create a new fact."""
subject: str
predicate: str
object: str
confidence: float = 0.8
project_id: UUID | None = None
source_episode_ids: list[UUID] = field(default_factory=list)
@dataclass
class Procedure:
"""A procedural memory - a learned skill or procedure."""
id: UUID
project_id: UUID | None
agent_type_id: UUID | None
name: str
trigger_pattern: str
steps: list[dict[str, Any]]
success_count: int
failure_count: int
last_used: datetime | None
embedding: list[float] | None
created_at: datetime
updated_at: datetime
@property
def success_rate(self) -> float:
"""Calculate the success rate of this procedure."""
total = self.success_count + self.failure_count
if total == 0:
return 0.0
return self.success_count / total
@dataclass
class ProcedureCreate:
"""Data required to create a new procedure."""
name: str
trigger_pattern: str
steps: list[dict[str, Any]]
project_id: UUID | None = None
agent_type_id: UUID | None = None
@dataclass
class Step:
"""A single step in a procedure."""
order: int
action: str
parameters: dict[str, Any] = field(default_factory=dict)
expected_outcome: str = ""
fallback_action: str | None = None
class MemoryStore[T: MemoryItem](ABC):
"""Abstract base class for memory storage backends."""
@abstractmethod
async def store(self, item: T) -> T:
"""Store a memory item."""
...
@abstractmethod
async def get(self, item_id: UUID) -> T | None:
"""Get a memory item by ID."""
...
@abstractmethod
async def delete(self, item_id: UUID) -> bool:
"""Delete a memory item."""
...
@abstractmethod
async def list(
self,
scope_type: ScopeLevel | None = None,
scope_id: str | None = None,
limit: int = 100,
offset: int = 0,
) -> list[T]:
"""List memory items with optional scope filtering."""
...
@abstractmethod
async def count(
self,
scope_type: ScopeLevel | None = None,
scope_id: str | None = None,
) -> int:
"""Count memory items with optional scope filtering."""
...
@dataclass
class RetrievalResult[T]:
"""Result of a memory retrieval operation."""
items: list[T]
total_count: int
query: str
retrieval_type: str
latency_ms: float
metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
class MemoryStats:
"""Statistics about memory usage."""
memory_type: MemoryType
scope_type: ScopeLevel | None
scope_id: str | None
item_count: int
total_size_bytes: int
oldest_item_age_seconds: float
newest_item_age_seconds: float
avg_item_size_bytes: float
metadata: dict[str, Any] = field(default_factory=dict)
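A brief sketch exercising the ScopeContext dataclass above, showing root-first hierarchy traversal and key-prefix generation:

project = ScopeContext(scope_type=ScopeLevel.PROJECT, scope_id="proj-42")
session = ScopeContext(
    scope_type=ScopeLevel.SESSION,
    scope_id="sess-abc",
    parent=project,
)

# get_hierarchy() returns root-first; GLOBAL is absent because this
# chain was built by hand rather than by the scope manager.
assert [s.scope_type for s in session.get_hierarchy()] == [
    ScopeLevel.PROJECT,
    ScopeLevel.SESSION,
]
assert session.to_key_prefix() == "session:sess-abc"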


@@ -0,0 +1,16 @@
# app/services/memory/working/__init__.py
"""
Working Memory Implementation.
Provides short-term memory storage with Redis primary and in-memory fallback.
"""
from .memory import WorkingMemory
from .storage import InMemoryStorage, RedisStorage, WorkingMemoryStorage
__all__ = [
"InMemoryStorage",
"RedisStorage",
"WorkingMemory",
"WorkingMemoryStorage",
]
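A rough usage sketch of the session-scoped WorkingMemory re-exported here (implemented in working/memory.py below); the keys and values are illustrative only:

# Sketch only; WorkingMemory is implemented in working/memory.py below.
async def track_task(session_id: str) -> None:
    wm = await WorkingMemory.for_session(session_id)
    try:
        await wm.set("step", "analyze")
        await wm.append_scratchpad("Checked scope hierarchy invariants")

        checkpoint_id = await wm.create_checkpoint("before risky step")
        await wm.set("step", "rewrite")

        # Roll back to the captured state if the later step goes wrong.
        await wm.restore_checkpoint(checkpoint_id)
        assert await wm.get("step") == "analyze"
    finally:
        await wm.close()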


@@ -0,0 +1,544 @@
# app/services/memory/working/memory.py
"""
Working Memory Implementation.
Provides session-scoped ephemeral memory with:
- Key-value storage with TTL
- Task state tracking
- Scratchpad for reasoning steps
- Checkpoint/snapshot support
"""
import logging
import uuid
from dataclasses import asdict
from datetime import UTC, datetime
from typing import Any
from app.services.memory.config import get_memory_settings
from app.services.memory.exceptions import (
MemoryConnectionError,
MemoryNotFoundError,
)
from app.services.memory.types import ScopeContext, ScopeLevel, TaskState
from .storage import InMemoryStorage, RedisStorage, WorkingMemoryStorage
logger = logging.getLogger(__name__)
# Reserved key prefixes for internal use
_TASK_STATE_KEY = "_task_state"
_SCRATCHPAD_KEY = "_scratchpad"
_CHECKPOINT_PREFIX = "_checkpoint:"
_METADATA_KEY = "_metadata"
class WorkingMemory:
"""
Session-scoped working memory.
Provides ephemeral storage for agent's current task context:
- Variables and intermediate data
- Task state (current step, status, progress)
- Scratchpad for reasoning steps
- Checkpoints for recovery
Uses Redis as primary storage with in-memory fallback.
"""
def __init__(
self,
scope: ScopeContext,
storage: WorkingMemoryStorage,
default_ttl_seconds: int | None = None,
) -> None:
"""
Initialize working memory for a scope.
Args:
scope: The scope context (session, agent instance, etc.)
storage: Storage backend (use create() factory for auto-configuration)
default_ttl_seconds: Default TTL for keys (None = no expiration)
"""
self._scope = scope
self._storage: WorkingMemoryStorage = storage
self._default_ttl = default_ttl_seconds
self._using_fallback = False
self._initialized = False
@classmethod
async def create(
cls,
scope: ScopeContext,
default_ttl_seconds: int | None = None,
) -> "WorkingMemory":
"""
Factory method to create WorkingMemory with auto-configured storage.
Attempts Redis first, falls back to in-memory if unavailable.
"""
settings = get_memory_settings()
key_prefix = f"wm:{scope.to_key_prefix()}:"
storage: WorkingMemoryStorage
# Try Redis first
if settings.working_memory_backend == "redis":
redis_storage = RedisStorage(key_prefix=key_prefix)
try:
if await redis_storage.is_healthy():
logger.debug(f"Using Redis storage for scope {scope.scope_id}")
instance = cls(
scope=scope,
storage=redis_storage,
default_ttl_seconds=default_ttl_seconds
or settings.working_memory_default_ttl_seconds,
)
await instance._initialize()
return instance
except MemoryConnectionError:
logger.warning("Redis unavailable, falling back to in-memory storage")
await redis_storage.close()
# Fall back to in-memory
storage = InMemoryStorage(
max_keys=settings.working_memory_max_items_per_session
)
instance = cls(
scope=scope,
storage=storage,
default_ttl_seconds=default_ttl_seconds
or settings.working_memory_default_ttl_seconds,
)
instance._using_fallback = True
await instance._initialize()
return instance
@classmethod
async def for_session(
cls,
session_id: str,
project_id: str | None = None,
agent_instance_id: str | None = None,
) -> "WorkingMemory":
"""
Convenience factory for session-scoped working memory.
Args:
session_id: Unique session identifier
project_id: Optional project context
agent_instance_id: Optional agent instance context
"""
# Build scope hierarchy
parent = None
if project_id:
parent = ScopeContext(
scope_type=ScopeLevel.PROJECT,
scope_id=project_id,
)
if agent_instance_id:
parent = ScopeContext(
scope_type=ScopeLevel.AGENT_INSTANCE,
scope_id=agent_instance_id,
parent=parent,
)
scope = ScopeContext(
scope_type=ScopeLevel.SESSION,
scope_id=session_id,
parent=parent,
)
return await cls.create(scope=scope)
async def _initialize(self) -> None:
"""Initialize working memory metadata."""
if self._initialized:
return
metadata = {
"scope_type": self._scope.scope_type.value,
"scope_id": self._scope.scope_id,
"created_at": datetime.now(UTC).isoformat(),
"using_fallback": self._using_fallback,
}
await self._storage.set(_METADATA_KEY, metadata)
self._initialized = True
@property
def scope(self) -> ScopeContext:
"""Get the scope context."""
return self._scope
@property
def is_using_fallback(self) -> bool:
"""Check if using fallback in-memory storage."""
return self._using_fallback
# =========================================================================
# Basic Key-Value Operations
# =========================================================================
async def set(
self,
key: str,
value: Any,
ttl_seconds: int | None = None,
) -> None:
"""
Store a value.
Args:
key: The key to store under
value: The value to store (must be JSON-serializable)
ttl_seconds: Optional TTL (uses default if not specified)
"""
if key.startswith("_"):
raise ValueError("Keys starting with '_' are reserved for internal use")
ttl = ttl_seconds if ttl_seconds is not None else self._default_ttl
await self._storage.set(key, value, ttl)
async def get(self, key: str, default: Any = None) -> Any:
"""
Get a value.
Args:
key: The key to retrieve
default: Default value if key not found
Returns:
The stored value or default
"""
result = await self._storage.get(key)
return result if result is not None else default
async def delete(self, key: str) -> bool:
"""
Delete a key.
Args:
key: The key to delete
Returns:
True if the key existed
"""
if key.startswith("_"):
raise ValueError("Cannot delete internal keys directly")
return await self._storage.delete(key)
async def exists(self, key: str) -> bool:
"""
Check if a key exists.
Args:
key: The key to check
Returns:
True if the key exists
"""
return await self._storage.exists(key)
async def list_keys(self, pattern: str = "*") -> list[str]:
"""
List keys matching a pattern.
Args:
pattern: Glob-style pattern (default "*" for all)
Returns:
List of matching keys (excludes internal keys)
"""
all_keys = await self._storage.list_keys(pattern)
return [k for k in all_keys if not k.startswith("_")]
async def get_all(self) -> dict[str, Any]:
"""
Get all user key-value pairs.
Returns:
Dictionary of all key-value pairs (excludes internal keys)
"""
all_data = await self._storage.get_all()
return {k: v for k, v in all_data.items() if not k.startswith("_")}
async def clear(self) -> int:
"""
Clear all user keys (preserves internal state).
Returns:
Number of keys deleted
"""
# Save internal state
task_state = await self._storage.get(_TASK_STATE_KEY)
scratchpad = await self._storage.get(_SCRATCHPAD_KEY)
metadata = await self._storage.get(_METADATA_KEY)
count = await self._storage.clear()
# Restore internal state
if metadata is not None:
await self._storage.set(_METADATA_KEY, metadata)
if task_state is not None:
await self._storage.set(_TASK_STATE_KEY, task_state)
if scratchpad is not None:
await self._storage.set(_SCRATCHPAD_KEY, scratchpad)
# Adjust count for preserved keys
preserved = sum(1 for x in [task_state, scratchpad, metadata] if x is not None)
return max(0, count - preserved)
# =========================================================================
# Task State Operations
# =========================================================================
async def set_task_state(self, state: TaskState) -> None:
"""
Set the current task state.
Args:
state: The task state to store
"""
state.updated_at = datetime.now(UTC)
await self._storage.set(_TASK_STATE_KEY, asdict(state))
async def get_task_state(self) -> TaskState | None:
"""
Get the current task state.
Returns:
The current TaskState or None if not set
"""
data = await self._storage.get(_TASK_STATE_KEY)
if data is None:
return None
# Convert datetime strings back to datetime objects
if isinstance(data.get("started_at"), str):
data["started_at"] = datetime.fromisoformat(data["started_at"])
if isinstance(data.get("updated_at"), str):
data["updated_at"] = datetime.fromisoformat(data["updated_at"])
return TaskState(**data)
async def update_task_progress(
self,
current_step: int | None = None,
progress_percent: float | None = None,
status: str | None = None,
) -> TaskState | None:
"""
Update task progress fields.
Args:
current_step: New current step number
progress_percent: New progress percentage (0.0 to 100.0)
status: New status string
Returns:
Updated TaskState or None if no task state exists
"""
state = await self.get_task_state()
if state is None:
return None
if current_step is not None:
state.current_step = current_step
if progress_percent is not None:
state.progress_percent = min(100.0, max(0.0, progress_percent))
if status is not None:
state.status = status
await self.set_task_state(state)
return state
# =========================================================================
# Scratchpad Operations
# =========================================================================
async def append_scratchpad(self, content: str) -> None:
"""
Append content to the scratchpad.
Args:
content: Text to append
"""
settings = get_memory_settings()
entries = await self._storage.get(_SCRATCHPAD_KEY) or []
# Check capacity
if len(entries) >= settings.working_memory_max_items_per_session:
# Remove oldest entries
entries = entries[-(settings.working_memory_max_items_per_session - 1) :]
entry = {
"content": content,
"timestamp": datetime.now(UTC).isoformat(),
}
entries.append(entry)
await self._storage.set(_SCRATCHPAD_KEY, entries)
async def get_scratchpad(self) -> list[str]:
"""
Get all scratchpad entries.
Returns:
List of scratchpad content strings (ordered by time)
"""
entries = await self._storage.get(_SCRATCHPAD_KEY) or []
return [e["content"] for e in entries]
async def get_scratchpad_with_timestamps(self) -> list[dict[str, Any]]:
"""
Get all scratchpad entries with timestamps.
Returns:
List of dicts with 'content' and 'timestamp' keys
"""
return await self._storage.get(_SCRATCHPAD_KEY) or []
async def clear_scratchpad(self) -> int:
"""
Clear the scratchpad.
Returns:
Number of entries cleared
"""
entries = await self._storage.get(_SCRATCHPAD_KEY) or []
count = len(entries)
await self._storage.set(_SCRATCHPAD_KEY, [])
return count
# =========================================================================
# Checkpoint Operations
# =========================================================================
async def create_checkpoint(self, description: str = "") -> str:
"""
Create a checkpoint of current state.
Args:
description: Optional description of the checkpoint
Returns:
Checkpoint ID for later restoration
"""
# Use a full UUID; a truncated 8-char hex ID would reach ~50% collision probability after roughly 77k checkpoints (birthday bound)
checkpoint_id = str(uuid.uuid4())
checkpoint_key = f"{_CHECKPOINT_PREFIX}{checkpoint_id}"
# Capture all current state
all_data = await self._storage.get_all()
checkpoint = {
"id": checkpoint_id,
"description": description,
"created_at": datetime.now(UTC).isoformat(),
"data": all_data,
}
await self._storage.set(checkpoint_key, checkpoint)
logger.debug(f"Created checkpoint {checkpoint_id}")
return checkpoint_id
async def restore_checkpoint(self, checkpoint_id: str) -> None:
"""
Restore state from a checkpoint.
Args:
checkpoint_id: ID of the checkpoint to restore
Raises:
MemoryNotFoundError: If checkpoint not found
"""
checkpoint_key = f"{_CHECKPOINT_PREFIX}{checkpoint_id}"
checkpoint = await self._storage.get(checkpoint_key)
if checkpoint is None:
raise MemoryNotFoundError(f"Checkpoint {checkpoint_id} not found")
# Clear current state
await self._storage.clear()
# Restore all data from checkpoint
for key, value in checkpoint["data"].items():
await self._storage.set(key, value)
# Keep the checkpoint itself
await self._storage.set(checkpoint_key, checkpoint)
logger.debug(f"Restored checkpoint {checkpoint_id}")
async def list_checkpoints(self) -> list[dict[str, Any]]:
"""
List all available checkpoints.
Returns:
List of checkpoint metadata (id, description, created_at)
"""
checkpoint_keys = await self._storage.list_keys(f"{_CHECKPOINT_PREFIX}*")
checkpoints = []
for key in checkpoint_keys:
data = await self._storage.get(key)
if data:
checkpoints.append(
{
"id": data["id"],
"description": data["description"],
"created_at": data["created_at"],
}
)
# Sort by creation time
checkpoints.sort(key=lambda x: x["created_at"])
return checkpoints
async def delete_checkpoint(self, checkpoint_id: str) -> bool:
"""
Delete a checkpoint.
Args:
checkpoint_id: ID of the checkpoint to delete
Returns:
True if checkpoint existed
"""
checkpoint_key = f"{_CHECKPOINT_PREFIX}{checkpoint_id}"
return await self._storage.delete(checkpoint_key)
# =========================================================================
# Health and Lifecycle
# =========================================================================
async def is_healthy(self) -> bool:
"""Check if the working memory storage is healthy."""
return await self._storage.is_healthy()
async def close(self) -> None:
"""Close the working memory storage."""
if self._storage:
await self._storage.close()
async def get_stats(self) -> dict[str, Any]:
"""
Get working memory statistics.
Returns:
Dictionary with stats about current state
"""
all_keys = await self._storage.list_keys("*")
user_keys = [k for k in all_keys if not k.startswith("_")]
checkpoint_keys = [k for k in all_keys if k.startswith(_CHECKPOINT_PREFIX)]
scratchpad = await self._storage.get(_SCRATCHPAD_KEY) or []
return {
"scope_type": self._scope.scope_type.value,
"scope_id": self._scope.scope_id,
"using_fallback": self._using_fallback,
"total_keys": len(all_keys),
"user_keys": len(user_keys),
"checkpoint_count": len(checkpoint_keys),
"scratchpad_entries": len(scratchpad),
"has_task_state": await self._storage.exists(_TASK_STATE_KEY),
}


@@ -0,0 +1,406 @@
# app/services/memory/working/storage.py
"""
Working Memory Storage Backends.
Provides abstract storage interface and implementations:
- RedisStorage: Primary storage using Redis with connection pooling
- InMemoryStorage: Fallback storage when Redis is unavailable
"""
import asyncio
import fnmatch
import json
import logging
from abc import ABC, abstractmethod
from datetime import UTC, datetime, timedelta
from typing import Any
from app.services.memory.config import get_memory_settings
from app.services.memory.exceptions import (
MemoryConnectionError,
MemoryStorageError,
)
logger = logging.getLogger(__name__)
class WorkingMemoryStorage(ABC):
"""Abstract base class for working memory storage backends."""
@abstractmethod
async def set(
self,
key: str,
value: Any,
ttl_seconds: int | None = None,
) -> None:
"""Store a value with optional TTL."""
...
@abstractmethod
async def get(self, key: str) -> Any | None:
"""Get a value by key, returns None if not found or expired."""
...
@abstractmethod
async def delete(self, key: str) -> bool:
"""Delete a key, returns True if existed."""
...
@abstractmethod
async def exists(self, key: str) -> bool:
"""Check if a key exists and is not expired."""
...
@abstractmethod
async def list_keys(self, pattern: str = "*") -> list[str]:
"""List all keys matching a pattern."""
...
@abstractmethod
async def get_all(self) -> dict[str, Any]:
"""Get all key-value pairs."""
...
@abstractmethod
async def clear(self) -> int:
"""Clear all keys, returns count of deleted keys."""
...
@abstractmethod
async def is_healthy(self) -> bool:
"""Check if the storage backend is healthy."""
...
@abstractmethod
async def close(self) -> None:
"""Close the storage connection."""
...
class InMemoryStorage(WorkingMemoryStorage):
"""
In-memory storage backend for working memory.
Used as fallback when Redis is unavailable. Data is not persisted
across restarts and is not shared between processes.
"""
def __init__(self, max_keys: int = 10000) -> None:
"""Initialize in-memory storage."""
self._data: dict[str, Any] = {}
self._expirations: dict[str, datetime] = {}
self._max_keys = max_keys
self._lock = asyncio.Lock()
def _is_expired(self, key: str) -> bool:
"""Check if a key has expired."""
if key not in self._expirations:
return False
return datetime.now(UTC) > self._expirations[key]
def _cleanup_expired(self) -> None:
"""Remove all expired keys."""
now = datetime.now(UTC)
expired_keys = [
key for key, exp_time in self._expirations.items() if now > exp_time
]
for key in expired_keys:
self._data.pop(key, None)
self._expirations.pop(key, None)
async def set(
self,
key: str,
value: Any,
ttl_seconds: int | None = None,
) -> None:
"""Store a value with optional TTL."""
async with self._lock:
# Cleanup expired keys periodically
if len(self._data) % 100 == 0:
self._cleanup_expired()
# Check capacity
if key not in self._data and len(self._data) >= self._max_keys:
# Evict expired keys first
self._cleanup_expired()
if len(self._data) >= self._max_keys:
raise MemoryStorageError(
f"Working memory capacity exceeded: {self._max_keys} keys"
)
self._data[key] = value
if ttl_seconds is not None:
self._expirations[key] = datetime.now(UTC) + timedelta(
seconds=ttl_seconds
)
elif key in self._expirations:
# Remove existing expiration if no TTL specified
del self._expirations[key]
async def get(self, key: str) -> Any | None:
"""Get a value by key."""
async with self._lock:
if key not in self._data:
return None
if self._is_expired(key):
del self._data[key]
del self._expirations[key]
return None
return self._data[key]
async def delete(self, key: str) -> bool:
"""Delete a key."""
async with self._lock:
existed = key in self._data
self._data.pop(key, None)
self._expirations.pop(key, None)
return existed
async def exists(self, key: str) -> bool:
"""Check if a key exists and is not expired."""
async with self._lock:
if key not in self._data:
return False
if self._is_expired(key):
del self._data[key]
del self._expirations[key]
return False
return True
async def list_keys(self, pattern: str = "*") -> list[str]:
"""List all keys matching a pattern."""
async with self._lock:
self._cleanup_expired()
if pattern == "*":
return list(self._data.keys())
return [key for key in self._data.keys() if fnmatch.fnmatch(key, pattern)]
async def get_all(self) -> dict[str, Any]:
"""Get all key-value pairs."""
async with self._lock:
self._cleanup_expired()
return dict(self._data)
async def clear(self) -> int:
"""Clear all keys."""
async with self._lock:
count = len(self._data)
self._data.clear()
self._expirations.clear()
return count
async def is_healthy(self) -> bool:
"""In-memory storage is always healthy."""
return True
async def close(self) -> None:
"""No cleanup needed for in-memory storage."""
class RedisStorage(WorkingMemoryStorage):
"""
Redis storage backend for working memory.
Primary storage with connection pooling, automatic reconnection,
and proper serialization of Python objects.
"""
def __init__(
self,
key_prefix: str = "",
connection_timeout: float = 5.0,
socket_timeout: float = 5.0,
) -> None:
"""
Initialize Redis storage.
Args:
key_prefix: Prefix for all keys (e.g., "session:abc123:")
connection_timeout: Timeout for establishing connections
socket_timeout: Timeout for socket operations
"""
self._key_prefix = key_prefix
self._connection_timeout = connection_timeout
self._socket_timeout = socket_timeout
self._redis: Any = None
self._lock = asyncio.Lock()
def _make_key(self, key: str) -> str:
"""Add prefix to key."""
return f"{self._key_prefix}{key}"
def _strip_prefix(self, key: str) -> str:
"""Remove prefix from key."""
if key.startswith(self._key_prefix):
return key[len(self._key_prefix) :]
return key
def _serialize(self, value: Any) -> str:
"""Serialize a Python value to JSON string."""
return json.dumps(value, default=str)
def _deserialize(self, data: str | bytes | None) -> Any | None:
"""Deserialize a JSON string to Python value."""
if data is None:
return None
if isinstance(data, bytes):
data = data.decode("utf-8")
return json.loads(data)
async def _get_client(self) -> Any:
"""Get or create Redis client."""
if self._redis is not None:
return self._redis
async with self._lock:
if self._redis is not None:
return self._redis
try:
import redis.asyncio as aioredis
except ImportError as e:
raise MemoryConnectionError(
"redis package not installed. Install with: pip install redis"
) from e
settings = get_memory_settings()
redis_url = settings.redis_url
try:
self._redis = await aioredis.from_url(
redis_url,
encoding="utf-8",
decode_responses=True,
socket_connect_timeout=self._connection_timeout,
socket_timeout=self._socket_timeout,
)
# Test connection
await self._redis.ping()
logger.info("Connected to Redis for working memory")
return self._redis
except Exception as e:
self._redis = None
raise MemoryConnectionError(f"Failed to connect to Redis: {e}") from e
async def set(
self,
key: str,
value: Any,
ttl_seconds: int | None = None,
) -> None:
"""Store a value with optional TTL."""
try:
client = await self._get_client()
full_key = self._make_key(key)
serialized = self._serialize(value)
if ttl_seconds is not None:
await client.setex(full_key, ttl_seconds, serialized)
else:
await client.set(full_key, serialized)
except MemoryConnectionError:
raise
except Exception as e:
raise MemoryStorageError(f"Failed to set key {key}: {e}") from e
async def get(self, key: str) -> Any | None:
"""Get a value by key."""
try:
client = await self._get_client()
full_key = self._make_key(key)
data = await client.get(full_key)
return self._deserialize(data)
except MemoryConnectionError:
raise
except Exception as e:
raise MemoryStorageError(f"Failed to get key {key}: {e}") from e
async def delete(self, key: str) -> bool:
"""Delete a key."""
try:
client = await self._get_client()
full_key = self._make_key(key)
result = await client.delete(full_key)
return bool(result)
except MemoryConnectionError:
raise
except Exception as e:
raise MemoryStorageError(f"Failed to delete key {key}: {e}") from e
async def exists(self, key: str) -> bool:
"""Check if a key exists."""
try:
client = await self._get_client()
full_key = self._make_key(key)
result = await client.exists(full_key)
return bool(result)
except MemoryConnectionError:
raise
except Exception as e:
raise MemoryStorageError(f"Failed to check key {key}: {e}") from e
async def list_keys(self, pattern: str = "*") -> list[str]:
"""List all keys matching a pattern."""
try:
client = await self._get_client()
full_pattern = self._make_key(pattern)
keys = await client.keys(full_pattern)
return [self._strip_prefix(key) for key in keys]
except MemoryConnectionError:
raise
except Exception as e:
raise MemoryStorageError(f"Failed to list keys: {e}") from e
async def get_all(self) -> dict[str, Any]:
"""Get all key-value pairs."""
try:
client = await self._get_client()
full_pattern = self._make_key("*")
keys = await client.keys(full_pattern)
if not keys:
return {}
values = await client.mget(*keys)
result = {}
for key, value in zip(keys, values, strict=False):
stripped_key = self._strip_prefix(key)
result[stripped_key] = self._deserialize(value)
return result
except MemoryConnectionError:
raise
except Exception as e:
raise MemoryStorageError(f"Failed to get all keys: {e}") from e
async def clear(self) -> int:
"""Clear all keys with this prefix."""
try:
client = await self._get_client()
full_pattern = self._make_key("*")
keys = await client.keys(full_pattern)
if not keys:
return 0
return await client.delete(*keys)
except MemoryConnectionError:
raise
except Exception as e:
raise MemoryStorageError(f"Failed to clear keys: {e}") from e
async def is_healthy(self) -> bool:
"""Check if Redis connection is healthy."""
try:
client = await self._get_client()
await client.ping()
return True
except Exception:
return False
async def close(self) -> None:
"""Close the Redis connection."""
if self._redis is not None:
await self._redis.close()
self._redis = None


@@ -10,14 +10,16 @@ Modules:
sync: Issue synchronization tasks (incremental/full sync, webhooks)
workflow: Workflow state management tasks
cost: Cost tracking and budget monitoring tasks
memory_consolidation: Memory consolidation tasks
"""
from app.tasks import agent, cost, git, sync, workflow
from app.tasks import agent, cost, git, memory_consolidation, sync, workflow
__all__ = [
"agent",
"cost",
"git",
"memory_consolidation",
"sync",
"workflow",
]


@@ -0,0 +1,234 @@
# app/tasks/memory_consolidation.py
"""
Memory consolidation Celery tasks.
Handles scheduled and on-demand memory consolidation:
- Session consolidation (on session end)
- Nightly consolidation (scheduled)
- On-demand project consolidation
"""
import logging
from typing import Any
from app.celery_app import celery_app
logger = logging.getLogger(__name__)
@celery_app.task(
bind=True,
name="app.tasks.memory_consolidation.consolidate_session",
autoretry_for=(Exception,),
retry_backoff=True,
retry_kwargs={"max_retries": 3},
)
def consolidate_session(
self,
project_id: str,
session_id: str,
task_type: str = "session_task",
agent_instance_id: str | None = None,
agent_type_id: str | None = None,
) -> dict[str, Any]:
"""
Consolidate a session's working memory to episodic memory.
This task is triggered when an agent session ends to transfer
relevant session data into persistent episodic memory.
Args:
project_id: UUID of the project
session_id: Session identifier
task_type: Type of task performed
agent_instance_id: Optional agent instance UUID
agent_type_id: Optional agent type UUID
Returns:
dict with consolidation results
"""
logger.info(f"Consolidating session {session_id} for project {project_id}")
# TODO: Implement actual consolidation
# This will involve:
# 1. Getting database session from async context
# 2. Loading working memory for session
# 3. Calling consolidation service
# 4. Returning results
# Placeholder implementation
return {
"status": "pending",
"project_id": project_id,
"session_id": session_id,
"episode_created": False,
}
@celery_app.task(
bind=True,
name="app.tasks.memory_consolidation.run_nightly_consolidation",
autoretry_for=(Exception,),
retry_backoff=True,
retry_kwargs={"max_retries": 3},
)
def run_nightly_consolidation(
self,
project_id: str,
agent_type_id: str | None = None,
) -> dict[str, Any]:
"""
Run nightly memory consolidation for a project.
This task performs the full consolidation workflow:
1. Extract facts from recent episodes to semantic memory
2. Learn procedures from successful episode patterns
3. Prune old, low-value memories
Args:
project_id: UUID of the project to consolidate
agent_type_id: Optional agent type to filter by
Returns:
dict with consolidation results
"""
logger.info(f"Running nightly consolidation for project {project_id}")
# TODO: Implement actual consolidation
# This will involve:
# 1. Getting database session from async context
# 2. Creating consolidation service instance
# 3. Running run_nightly_consolidation
# 4. Returning results
# Placeholder implementation
return {
"status": "pending",
"project_id": project_id,
"total_facts_created": 0,
"total_procedures_created": 0,
"total_pruned": 0,
}
@celery_app.task(
bind=True,
name="app.tasks.memory_consolidation.consolidate_episodes_to_facts",
autoretry_for=(Exception,),
retry_backoff=True,
retry_kwargs={"max_retries": 3},
)
def consolidate_episodes_to_facts(
self,
project_id: str,
since_hours: int = 24,
limit: int | None = None,
) -> dict[str, Any]:
"""
Extract facts from episodic memories.
Args:
project_id: UUID of the project
since_hours: Process episodes from last N hours
limit: Maximum episodes to process
Returns:
dict with extraction results
"""
logger.info(f"Consolidating episodes to facts for project {project_id}")
# TODO: Implement actual consolidation
# Placeholder implementation
return {
"status": "pending",
"project_id": project_id,
"items_processed": 0,
"items_created": 0,
}
@celery_app.task(
bind=True,
name="app.tasks.memory_consolidation.consolidate_episodes_to_procedures",
autoretry_for=(Exception,),
retry_backoff=True,
retry_kwargs={"max_retries": 3},
)
def consolidate_episodes_to_procedures(
self,
project_id: str,
agent_type_id: str | None = None,
since_days: int = 7,
) -> dict[str, Any]:
"""
Learn procedures from episodic patterns.
Args:
project_id: UUID of the project
agent_type_id: Optional agent type filter
since_days: Process episodes from last N days
Returns:
dict with procedure learning results
"""
logger.info(f"Consolidating episodes to procedures for project {project_id}")
# TODO: Implement actual consolidation
# Placeholder implementation
return {
"status": "pending",
"project_id": project_id,
"items_processed": 0,
"items_created": 0,
}
@celery_app.task(
bind=True,
name="app.tasks.memory_consolidation.prune_old_memories",
autoretry_for=(Exception,),
retry_backoff=True,
retry_kwargs={"max_retries": 3},
)
def prune_old_memories(
self,
project_id: str,
max_age_days: int = 90,
min_importance: float = 0.2,
) -> dict[str, Any]:
"""
Prune old, low-value memories.
Args:
project_id: UUID of the project
max_age_days: Maximum age in days
min_importance: Minimum importance to keep
Returns:
dict with pruning results
"""
logger.info(f"Pruning old memories for project {project_id}")
# TODO: Implement actual pruning
# Placeholder implementation
return {
"status": "pending",
"project_id": project_id,
"items_pruned": 0,
}
# =========================================================================
# Celery Beat Schedule Configuration
# =========================================================================
# This would typically be configured in celery_app.py or a separate config file
# Example schedule for nightly consolidation:
#
# app.conf.beat_schedule = {
# 'nightly-memory-consolidation': {
# 'task': 'app.tasks.memory_consolidation.run_nightly_consolidation',
# 'schedule': crontab(hour=2, minute=0), # 2 AM daily
# 'args': (None,), # Will process all projects
# },
# }

File diff suppressed because it is too large

backend/data/demo_data.json Normal file

@@ -0,0 +1,879 @@
{
"organizations": [
{
"name": "Acme Corp",
"slug": "acme-corp",
"description": "A leading provider of coyote-catching equipment."
},
{
"name": "Globex Corporation",
"slug": "globex",
"description": "We own the East Coast."
},
{
"name": "Soylent Corp",
"slug": "soylent",
"description": "Making food for the future."
},
{
"name": "Initech",
"slug": "initech",
"description": "Software for the soul."
},
{
"name": "Umbrella Corporation",
"slug": "umbrella",
"description": "Our business is life itself."
},
{
"name": "Massive Dynamic",
"slug": "massive-dynamic",
"description": "What don't we do?"
}
],
"users": [
{
"email": "demo@example.com",
"password": "DemoPass1234!",
"first_name": "Demo",
"last_name": "User",
"is_superuser": false,
"organization_slug": "acme-corp",
"role": "member",
"is_active": true
},
{
"email": "alice@acme.com",
"password": "Demo123!",
"first_name": "Alice",
"last_name": "Smith",
"is_superuser": false,
"organization_slug": "acme-corp",
"role": "admin",
"is_active": true
},
{
"email": "bob@acme.com",
"password": "Demo123!",
"first_name": "Bob",
"last_name": "Jones",
"is_superuser": false,
"organization_slug": "acme-corp",
"role": "member",
"is_active": true
},
{
"email": "charlie@acme.com",
"password": "Demo123!",
"first_name": "Charlie",
"last_name": "Brown",
"is_superuser": false,
"organization_slug": "acme-corp",
"role": "member",
"is_active": false
},
{
"email": "diana@acme.com",
"password": "Demo123!",
"first_name": "Diana",
"last_name": "Prince",
"is_superuser": false,
"organization_slug": "acme-corp",
"role": "member",
"is_active": true
},
{
"email": "carol@globex.com",
"password": "Demo123!",
"first_name": "Carol",
"last_name": "Williams",
"is_superuser": false,
"organization_slug": "globex",
"role": "owner",
"is_active": true
},
{
"email": "dan@globex.com",
"password": "Demo123!",
"first_name": "Dan",
"last_name": "Miller",
"is_superuser": false,
"organization_slug": "globex",
"role": "member",
"is_active": true
},
{
"email": "ellen@globex.com",
"password": "Demo123!",
"first_name": "Ellen",
"last_name": "Ripley",
"is_superuser": false,
"organization_slug": "globex",
"role": "member",
"is_active": true
},
{
"email": "fred@globex.com",
"password": "Demo123!",
"first_name": "Fred",
"last_name": "Flintstone",
"is_superuser": false,
"organization_slug": "globex",
"role": "member",
"is_active": true
},
{
"email": "dave@soylent.com",
"password": "Demo123!",
"first_name": "Dave",
"last_name": "Brown",
"is_superuser": false,
"organization_slug": "soylent",
"role": "member",
"is_active": true
},
{
"email": "gina@soylent.com",
"password": "Demo123!",
"first_name": "Gina",
"last_name": "Torres",
"is_superuser": false,
"organization_slug": "soylent",
"role": "member",
"is_active": true
},
{
"email": "harry@soylent.com",
"password": "Demo123!",
"first_name": "Harry",
"last_name": "Potter",
"is_superuser": false,
"organization_slug": "soylent",
"role": "admin",
"is_active": true
},
{
"email": "eve@initech.com",
"password": "Demo123!",
"first_name": "Eve",
"last_name": "Davis",
"is_superuser": false,
"organization_slug": "initech",
"role": "admin",
"is_active": true
},
{
"email": "iris@initech.com",
"password": "Demo123!",
"first_name": "Iris",
"last_name": "West",
"is_superuser": false,
"organization_slug": "initech",
"role": "member",
"is_active": true
},
{
"email": "jack@initech.com",
"password": "Demo123!",
"first_name": "Jack",
"last_name": "Sparrow",
"is_superuser": false,
"organization_slug": "initech",
"role": "member",
"is_active": false
},
{
"email": "frank@umbrella.com",
"password": "Demo123!",
"first_name": "Frank",
"last_name": "Miller",
"is_superuser": false,
"organization_slug": "umbrella",
"role": "member",
"is_active": true
},
{
"email": "george@umbrella.com",
"password": "Demo123!",
"first_name": "George",
"last_name": "Costanza",
"is_superuser": false,
"organization_slug": "umbrella",
"role": "member",
"is_active": false
},
{
"email": "kate@umbrella.com",
"password": "Demo123!",
"first_name": "Kate",
"last_name": "Bishop",
"is_superuser": false,
"organization_slug": "umbrella",
"role": "member",
"is_active": true
},
{
"email": "leo@massive.com",
"password": "Demo123!",
"first_name": "Leo",
"last_name": "Messi",
"is_superuser": false,
"organization_slug": "massive-dynamic",
"role": "owner",
"is_active": true
},
{
"email": "mary@massive.com",
"password": "Demo123!",
"first_name": "Mary",
"last_name": "Jane",
"is_superuser": false,
"organization_slug": "massive-dynamic",
"role": "member",
"is_active": true
},
{
"email": "nathan@massive.com",
"password": "Demo123!",
"first_name": "Nathan",
"last_name": "Drake",
"is_superuser": false,
"organization_slug": "massive-dynamic",
"role": "member",
"is_active": true
},
{
"email": "olivia@massive.com",
"password": "Demo123!",
"first_name": "Olivia",
"last_name": "Dunham",
"is_superuser": false,
"organization_slug": "massive-dynamic",
"role": "admin",
"is_active": true
},
{
"email": "peter@massive.com",
"password": "Demo123!",
"first_name": "Peter",
"last_name": "Parker",
"is_superuser": false,
"organization_slug": "massive-dynamic",
"role": "member",
"is_active": true
},
{
"email": "quinn@massive.com",
"password": "Demo123!",
"first_name": "Quinn",
"last_name": "Mallory",
"is_superuser": false,
"organization_slug": "massive-dynamic",
"role": "member",
"is_active": true
},
{
"email": "grace@example.com",
"password": "Demo123!",
"first_name": "Grace",
"last_name": "Hopper",
"is_superuser": false,
"organization_slug": null,
"role": null,
"is_active": true
},
{
"email": "heidi@example.com",
"password": "Demo123!",
"first_name": "Heidi",
"last_name": "Klum",
"is_superuser": false,
"organization_slug": null,
"role": null,
"is_active": true
},
{
"email": "ivan@example.com",
"password": "Demo123!",
"first_name": "Ivan",
"last_name": "Drago",
"is_superuser": false,
"organization_slug": null,
"role": null,
"is_active": false
},
{
"email": "rachel@example.com",
"password": "Demo123!",
"first_name": "Rachel",
"last_name": "Green",
"is_superuser": false,
"organization_slug": null,
"role": null,
"is_active": true
},
{
"email": "sam@example.com",
"password": "Demo123!",
"first_name": "Sam",
"last_name": "Wilson",
"is_superuser": false,
"organization_slug": null,
"role": null,
"is_active": true
},
{
"email": "tony@example.com",
"password": "Demo123!",
"first_name": "Tony",
"last_name": "Stark",
"is_superuser": false,
"organization_slug": null,
"role": null,
"is_active": true
},
{
"email": "una@example.com",
"password": "Demo123!",
"first_name": "Una",
"last_name": "Chin-Riley",
"is_superuser": false,
"organization_slug": null,
"role": null,
"is_active": false
},
{
"email": "victor@example.com",
"password": "Demo123!",
"first_name": "Victor",
"last_name": "Von Doom",
"is_superuser": false,
"organization_slug": null,
"role": null,
"is_active": true
},
{
"email": "wanda@example.com",
"password": "Demo123!",
"first_name": "Wanda",
"last_name": "Maximoff",
"is_superuser": false,
"organization_slug": null,
"role": null,
"is_active": true
}
],
"projects": [
{
"name": "E-Commerce Platform Redesign",
"slug": "ecommerce-redesign",
"description": "Complete redesign of the e-commerce platform with modern UX, improved checkout flow, and mobile-first approach.",
"owner_email": "__admin__",
"autonomy_level": "milestone",
"status": "active",
"complexity": "complex",
"client_mode": "technical",
"settings": {
"mcp_servers": ["gitea", "knowledge-base"]
}
},
{
"name": "Mobile Banking App",
"slug": "mobile-banking",
"description": "Secure mobile banking application with biometric authentication, transaction history, and real-time notifications.",
"owner_email": "__admin__",
"autonomy_level": "full_control",
"status": "active",
"complexity": "complex",
"client_mode": "technical",
"settings": {
"mcp_servers": ["gitea", "knowledge-base"],
"security_level": "high"
}
},
{
"name": "Internal HR Portal",
"slug": "hr-portal",
"description": "Employee self-service portal for leave requests, performance reviews, and document management.",
"owner_email": "__admin__",
"autonomy_level": "autonomous",
"status": "active",
"complexity": "medium",
"client_mode": "auto",
"settings": {
"mcp_servers": ["gitea", "knowledge-base"]
}
},
{
"name": "API Gateway Modernization",
"slug": "api-gateway",
"description": "Migrate legacy REST API gateway to modern GraphQL-based architecture with improved caching and rate limiting.",
"owner_email": "__admin__",
"autonomy_level": "milestone",
"status": "active",
"complexity": "complex",
"client_mode": "technical",
"settings": {
"mcp_servers": ["gitea", "knowledge-base"]
}
},
{
"name": "Customer Analytics Dashboard",
"slug": "analytics-dashboard",
"description": "Real-time analytics dashboard for customer behavior insights, cohort analysis, and predictive modeling.",
"owner_email": "__admin__",
"autonomy_level": "autonomous",
"status": "completed",
"complexity": "medium",
"client_mode": "auto",
"settings": {
"mcp_servers": ["gitea", "knowledge-base"]
}
},
{
"name": "DevOps Pipeline Automation",
"slug": "devops-automation",
"description": "Automate CI/CD pipelines with AI-assisted deployments, rollback detection, and infrastructure as code.",
"owner_email": "__admin__",
"autonomy_level": "full_control",
"status": "active",
"complexity": "complex",
"client_mode": "technical",
"settings": {
"mcp_servers": ["gitea", "knowledge-base"]
}
}
],
"sprints": [
{
"project_slug": "ecommerce-redesign",
"name": "Sprint 1: Foundation",
"number": 1,
"goal": "Set up project infrastructure, design system, and core navigation components.",
"start_date": "2026-01-06",
"end_date": "2026-01-20",
"status": "active",
"planned_points": 21
},
{
"project_slug": "ecommerce-redesign",
"name": "Sprint 2: Product Catalog",
"number": 2,
"goal": "Implement product listing, filtering, search, and detail pages.",
"start_date": "2026-01-20",
"end_date": "2026-02-03",
"status": "planned",
"planned_points": 34
},
{
"project_slug": "mobile-banking",
"name": "Sprint 1: Authentication",
"number": 1,
"goal": "Implement secure login, biometric authentication, and session management.",
"start_date": "2026-01-06",
"end_date": "2026-01-20",
"status": "active",
"planned_points": 26
},
{
"project_slug": "hr-portal",
"name": "Sprint 1: Core Features",
"number": 1,
"goal": "Build employee dashboard, leave request system, and basic document management.",
"start_date": "2026-01-06",
"end_date": "2026-01-20",
"status": "active",
"planned_points": 18
},
{
"project_slug": "api-gateway",
"name": "Sprint 1: GraphQL Schema",
"number": 1,
"goal": "Define GraphQL schema and implement core resolvers for existing REST endpoints.",
"start_date": "2025-12-23",
"end_date": "2026-01-06",
"status": "completed",
"planned_points": 21
},
{
"project_slug": "api-gateway",
"name": "Sprint 2: Caching Layer",
"number": 2,
"goal": "Implement Redis-based caching layer and query batching.",
"start_date": "2026-01-06",
"end_date": "2026-01-20",
"status": "active",
"planned_points": 26
},
{
"project_slug": "analytics-dashboard",
"name": "Sprint 1: Data Pipeline",
"number": 1,
"goal": "Set up data ingestion pipeline and real-time event processing.",
"start_date": "2025-11-15",
"end_date": "2025-11-29",
"status": "completed",
"planned_points": 18
},
{
"project_slug": "analytics-dashboard",
"name": "Sprint 2: Dashboard UI",
"number": 2,
"goal": "Build interactive dashboard with charts and filtering capabilities.",
"start_date": "2025-11-29",
"end_date": "2025-12-13",
"status": "completed",
"planned_points": 21
},
{
"project_slug": "devops-automation",
"name": "Sprint 1: Pipeline Templates",
"number": 1,
"goal": "Create reusable CI/CD pipeline templates for common deployment patterns.",
"start_date": "2026-01-06",
"end_date": "2026-01-20",
"status": "active",
"planned_points": 24
}
],
"agent_instances": [
{
"project_slug": "ecommerce-redesign",
"agent_type_slug": "product-owner",
"name": "Aria",
"status": "idle"
},
{
"project_slug": "ecommerce-redesign",
"agent_type_slug": "solutions-architect",
"name": "Marcus",
"status": "idle"
},
{
"project_slug": "ecommerce-redesign",
"agent_type_slug": "senior-engineer",
"name": "Zara",
"status": "working",
"current_task": "Implementing responsive navigation component"
},
{
"project_slug": "mobile-banking",
"agent_type_slug": "product-owner",
"name": "Felix",
"status": "waiting",
"current_task": "Awaiting security requirements clarification"
},
{
"project_slug": "mobile-banking",
"agent_type_slug": "senior-engineer",
"name": "Luna",
"status": "working",
"current_task": "Implementing biometric authentication flow"
},
{
"project_slug": "mobile-banking",
"agent_type_slug": "qa-engineer",
"name": "Rex",
"status": "idle"
},
{
"project_slug": "hr-portal",
"agent_type_slug": "business-analyst",
"name": "Nova",
"status": "working",
"current_task": "Documenting leave request workflow"
},
{
"project_slug": "hr-portal",
"agent_type_slug": "senior-engineer",
"name": "Atlas",
"status": "working",
"current_task": "Building employee dashboard API"
},
{
"project_slug": "api-gateway",
"agent_type_slug": "solutions-architect",
"name": "Orion",
"status": "working",
"current_task": "Designing caching strategy for GraphQL queries"
},
{
"project_slug": "api-gateway",
"agent_type_slug": "senior-engineer",
"name": "Cleo",
"status": "working",
"current_task": "Implementing Redis cache invalidation"
},
{
"project_slug": "devops-automation",
"agent_type_slug": "devops-engineer",
"name": "Volt",
"status": "working",
"current_task": "Creating Terraform modules for AWS ECS"
},
{
"project_slug": "devops-automation",
"agent_type_slug": "senior-engineer",
"name": "Sage",
"status": "idle"
},
{
"project_slug": "devops-automation",
"agent_type_slug": "qa-engineer",
"name": "Echo",
"status": "waiting",
"current_task": "Waiting for pipeline templates to test"
}
],
"issues": [
{
"project_slug": "ecommerce-redesign",
"sprint_number": 1,
"type": "story",
"title": "Design responsive navigation component",
"body": "As a user, I want a navigation menu that works seamlessly on both desktop and mobile devices.\n\n## Acceptance Criteria\n- Hamburger menu on mobile viewports\n- Sticky header on scroll\n- Keyboard accessible\n- Screen reader compatible",
"status": "in_progress",
"priority": "high",
"labels": ["frontend", "design-system"],
"story_points": 5,
"assigned_agent_name": "Zara"
},
{
"project_slug": "ecommerce-redesign",
"sprint_number": 1,
"type": "task",
"title": "Set up Tailwind CSS configuration",
"body": "Configure Tailwind CSS with custom design tokens for the e-commerce platform.\n\n- Define color palette\n- Set up typography scale\n- Configure breakpoints\n- Add custom utilities",
"status": "closed",
"priority": "high",
"labels": ["frontend", "infrastructure"],
"story_points": 3
},
{
"project_slug": "ecommerce-redesign",
"sprint_number": 1,
"type": "task",
"title": "Create base component library structure",
"body": "Set up the foundational component library with:\n- Button variants\n- Form inputs\n- Card component\n- Modal system",
"status": "open",
"priority": "medium",
"labels": ["frontend", "design-system"],
"story_points": 8
},
{
"project_slug": "ecommerce-redesign",
"sprint_number": 1,
"type": "story",
"title": "Implement user authentication flow",
"body": "As a user, I want to sign up, log in, and manage my account.\n\n## Features\n- Email/password registration\n- Social login (Google, GitHub)\n- Password reset flow\n- Email verification",
"status": "open",
"priority": "critical",
"labels": ["auth", "backend", "frontend"],
"story_points": 13
},
{
"project_slug": "ecommerce-redesign",
"sprint_number": 2,
"type": "epic",
"title": "Product Catalog System",
"body": "Complete product catalog implementation including:\n- Product listing with pagination\n- Advanced filtering and search\n- Product detail pages\n- Category navigation",
"status": "open",
"priority": "high",
"labels": ["catalog", "backend", "frontend"],
"story_points": null
},
{
"project_slug": "mobile-banking",
"sprint_number": 1,
"type": "story",
"title": "Implement biometric authentication",
"body": "As a user, I want to log in using Face ID or Touch ID for quick and secure access.\n\n## Requirements\n- Support Face ID on iOS\n- Support fingerprint on Android\n- Fallback to PIN/password\n- Secure keychain storage",
"status": "in_progress",
"priority": "critical",
"labels": ["auth", "security", "mobile"],
"story_points": 8,
"assigned_agent_name": "Luna"
},
{
"project_slug": "mobile-banking",
"sprint_number": 1,
"type": "task",
"title": "Set up secure session management",
"body": "Implement secure session handling with:\n- JWT tokens with short expiry\n- Refresh token rotation\n- Session timeout handling\n- Multi-device session management",
"status": "open",
"priority": "critical",
"labels": ["auth", "security", "backend"],
"story_points": 5
},
{
"project_slug": "mobile-banking",
"sprint_number": 1,
"type": "bug",
"title": "Fix token refresh race condition",
"body": "When multiple API calls happen simultaneously after token expiry, multiple refresh requests are made causing 401 errors.\n\n## Steps to Reproduce\n1. Wait for token to expire\n2. Trigger multiple API calls at once\n3. Observe multiple 401 errors",
"status": "open",
"priority": "high",
"labels": ["bug", "auth", "backend"],
"story_points": 3
},
{
"project_slug": "mobile-banking",
"sprint_number": 1,
"type": "task",
"title": "Implement PIN entry screen",
"body": "Create secure PIN entry component with:\n- Masked input display\n- Haptic feedback\n- Brute force protection (lockout after 5 attempts)\n- Secure PIN storage",
"status": "open",
"priority": "high",
"labels": ["auth", "mobile", "frontend"],
"story_points": 5
},
{
"project_slug": "hr-portal",
"sprint_number": 1,
"type": "story",
"title": "Build employee dashboard",
"body": "As an employee, I want a dashboard showing my key information at a glance.\n\n## Dashboard Widgets\n- Leave balance\n- Pending approvals\n- Upcoming holidays\n- Recent announcements",
"status": "in_progress",
"priority": "high",
"labels": ["frontend", "dashboard"],
"story_points": 5,
"assigned_agent_name": "Atlas"
},
{
"project_slug": "hr-portal",
"sprint_number": 1,
"type": "story",
"title": "Implement leave request system",
"body": "As an employee, I want to submit and track leave requests.\n\n## Features\n- Submit leave request with date range\n- View leave balance by type\n- Track request status\n- Manager approval workflow",
"status": "in_progress",
"priority": "high",
"labels": ["backend", "frontend", "workflow"],
"story_points": 8,
"assigned_agent_name": "Nova"
},
{
"project_slug": "hr-portal",
"sprint_number": 1,
"type": "task",
"title": "Set up document storage integration",
"body": "Integrate with S3-compatible storage for employee documents:\n- Secure upload/download\n- File type validation\n- Size limits\n- Virus scanning",
"status": "open",
"priority": "medium",
"labels": ["backend", "infrastructure", "storage"],
"story_points": 5
},
{
"project_slug": "api-gateway",
"sprint_number": 2,
"type": "story",
"title": "Implement Redis caching layer",
"body": "As an API consumer, I want responses to be cached for improved performance.\n\n## Requirements\n- Cache GraphQL query results\n- Configurable TTL per query type\n- Cache invalidation on mutations\n- Cache hit/miss metrics",
"status": "in_progress",
"priority": "critical",
"labels": ["backend", "performance", "redis"],
"story_points": 8,
"assigned_agent_name": "Cleo"
},
{
"project_slug": "api-gateway",
"sprint_number": 2,
"type": "task",
"title": "Set up query batching and deduplication",
"body": "Implement DataLoader pattern for:\n- Batching multiple queries into single database calls\n- Deduplicating identical queries within request scope\n- N+1 query prevention",
"status": "open",
"priority": "high",
"labels": ["backend", "performance", "graphql"],
"story_points": 5
},
{
"project_slug": "api-gateway",
"sprint_number": 2,
"type": "task",
"title": "Implement rate limiting middleware",
"body": "Add rate limiting to prevent API abuse:\n- Per-user rate limits\n- Per-IP fallback for anonymous requests\n- Sliding window algorithm\n- Custom limits per operation type",
"status": "open",
"priority": "high",
"labels": ["backend", "security", "middleware"],
"story_points": 5,
"assigned_agent_name": "Orion"
},
{
"project_slug": "api-gateway",
"sprint_number": 2,
"type": "bug",
"title": "Fix N+1 query in user resolver",
"body": "The user resolver is making separate database calls for each user's organization.\n\n## Steps to Reproduce\n1. Query users with organization field\n2. Check database logs\n3. Observe N+1 queries",
"status": "open",
"priority": "high",
"labels": ["bug", "performance", "graphql"],
"story_points": 3
},
{
"project_slug": "analytics-dashboard",
"sprint_number": 2,
"type": "story",
"title": "Build cohort analysis charts",
"body": "As a product manager, I want to analyze user cohorts over time.\n\n## Features\n- Weekly/monthly cohort grouping\n- Retention curve visualization\n- Cohort comparison view",
"status": "closed",
"priority": "high",
"labels": ["frontend", "charts", "analytics"],
"story_points": 8
},
{
"project_slug": "analytics-dashboard",
"sprint_number": 2,
"type": "task",
"title": "Implement real-time event streaming",
"body": "Set up WebSocket connection for live event updates:\n- Event type filtering\n- Buffering for high-volume periods\n- Reconnection handling",
"status": "closed",
"priority": "high",
"labels": ["backend", "websocket", "realtime"],
"story_points": 5
},
{
"project_slug": "devops-automation",
"sprint_number": 1,
"type": "epic",
"title": "CI/CD Pipeline Templates",
"body": "Create reusable pipeline templates for common deployment patterns.\n\n## Templates Needed\n- Node.js applications\n- Python applications\n- Docker-based deployments\n- Kubernetes deployments",
"status": "in_progress",
"priority": "critical",
"labels": ["infrastructure", "cicd", "templates"],
"story_points": null
},
{
"project_slug": "devops-automation",
"sprint_number": 1,
"type": "story",
"title": "Create Terraform modules for AWS ECS",
"body": "As a DevOps engineer, I want Terraform modules for ECS deployments.\n\n## Modules\n- ECS cluster configuration\n- Service and task definitions\n- Load balancer integration\n- Auto-scaling policies",
"status": "in_progress",
"priority": "high",
"labels": ["terraform", "aws", "ecs"],
"story_points": 8,
"assigned_agent_name": "Volt"
},
{
"project_slug": "devops-automation",
"sprint_number": 1,
"type": "task",
"title": "Set up Gitea Actions runners",
"body": "Configure self-hosted Gitea Actions runners:\n- Docker-in-Docker support\n- Caching for npm/pip\n- Secrets management\n- Resource limits",
"status": "open",
"priority": "high",
"labels": ["infrastructure", "gitea", "cicd"],
"story_points": 5
},
{
"project_slug": "devops-automation",
"sprint_number": 1,
"type": "task",
"title": "Implement rollback detection system",
"body": "AI-assisted rollback detection:\n- Monitor deployment health metrics\n- Automatic rollback triggers\n- Notification system\n- Post-rollback analysis",
"status": "open",
"priority": "medium",
"labels": ["ai", "monitoring", "automation"],
"story_points": 8
}
]
}


@@ -0,0 +1,507 @@
# Agent Memory System
Comprehensive multi-tier cognitive memory for AI agents, enabling state persistence, experiential learning, and context continuity across sessions.
## Overview
The Agent Memory System implements a cognitive architecture inspired by human memory:
```
+------------------------------------------------------------------+
|                       Agent Memory System                         |
+------------------------------------------------------------------+
|                                                                  |
|  +------------------+                      +------------------+  |
|  |  Working Memory  |-----consolidate----->| Episodic Memory  |  |
|  |  (Redis/In-Mem)  |                      |   (PostgreSQL)   |  |
|  |                  |                      |                  |  |
|  |  - Current task  |                      |  - Past sessions |  |
|  |  - Variables     |                      |  - Experiences   |  |
|  |  - Scratchpad    |                      |  - Outcomes      |  |
|  +------------------+                      +--------+---------+  |
|                                                     |            |
|                                            extract  |            |
|                                                     v            |
|  +------------------+                      +------------------+  |
|  |Procedural Memory |<-----learn from------| Semantic Memory  |  |
|  |   (PostgreSQL)   |                      |  (PostgreSQL +   |  |
|  |                  |                      |   pgvector)      |  |
|  |  - Procedures    |                      |                  |  |
|  |  - Skills        |                      |  - Facts         |  |
|  |  - Patterns      |                      |  - Entities      |  |
|  +------------------+                      |  - Relationships |  |
|                                            +------------------+  |
+------------------------------------------------------------------+
```
## Memory Types
### Working Memory
Short-term, session-scoped memory for current task state.
**Features:**
- Key-value storage with TTL
- Task state tracking
- Scratchpad for reasoning
- Checkpoint/restore support
- Redis primary with in-memory fallback
**Usage:**
```python
from app.services.memory.working import WorkingMemory
memory = WorkingMemory(scope_context)
await memory.set("key", {"data": "value"}, ttl_seconds=3600)
value = await memory.get("key")
# Task state
await memory.set_task_state(TaskState(task_id="t1", status="running"))
state = await memory.get_task_state()
# Checkpoints
checkpoint_id = await memory.create_checkpoint()
await memory.restore_checkpoint(checkpoint_id)
```
### Episodic Memory
Experiential records of past agent actions and outcomes.
**Features:**
- Records task completions and failures
- Semantic similarity search (pgvector)
- Temporal and outcome-based retrieval
- Importance scoring
- Episode summarization
**Usage:**
```python
from app.services.memory.episodic import EpisodicMemory
memory = EpisodicMemory(session, embedder)
# Record an episode
episode = await memory.record_episode(
project_id=project_id,
episode=EpisodeCreate(
task_type="code_review",
task_description="Review PR #42",
outcome=Outcome.SUCCESS,
actions=[{"type": "analyze", "target": "src/"}],
)
)
# Search similar experiences
similar = await memory.search_similar(
project_id=project_id,
query="debugging memory leak",
limit=5
)
# Get recent episodes
recent = await memory.get_recent(project_id, limit=10)
```
### Semantic Memory
Learned facts and knowledge with confidence scoring.
**Features:**
- Triple format (subject, predicate, object)
- Confidence scoring with decay
- Fact extraction from episodes
- Conflict resolution
- Entity-based retrieval
**Usage:**
```python
from app.services.memory.semantic import SemanticMemory
memory = SemanticMemory(session, embedder)
# Store a fact
fact = await memory.store_fact(
project_id=project_id,
fact=FactCreate(
subject="UserService",
predicate="handles",
object="authentication",
confidence=0.9,
)
)
# Search facts
facts = await memory.search_facts(project_id, "authentication flow")
# Reinforce on repeated learning
await memory.reinforce_fact(fact.id)
```
### Procedural Memory
Learned skills and procedures from successful patterns.
**Features:**
- Procedure recording from task patterns
- Trigger-based matching
- Success rate tracking
- Procedure suggestions
- Step-by-step storage
**Usage:**
```python
from app.services.memory.procedural import ProceduralMemory
memory = ProceduralMemory(session, embedder)
# Record a procedure
procedure = await memory.record_procedure(
project_id=project_id,
procedure=ProcedureCreate(
name="PR Review Process",
trigger_pattern="code review requested",
steps=[
Step(action="fetch_diff"),
Step(action="analyze_changes"),
Step(action="check_tests"),
]
)
)
# Find matching procedures
matches = await memory.find_matching(project_id, "need to review code")
# Record outcomes
await memory.record_outcome(procedure.id, success=True)
```
## Memory Scoping
Memory is organized in a hierarchical scope structure:
```
Global Memory (shared by all)
└── Project Memory (per project)
    └── Agent Type Memory (per agent type)
        └── Agent Instance Memory (per instance)
            └── Session Memory (ephemeral)
```
**Usage:**
```python
from app.services.memory.scoping import ScopeManager, ScopeLevel
manager = ScopeManager(session)
# Get scoped memories with inheritance
memories = await manager.get_scoped_memories(
context=ScopeContext(
project_id=project_id,
agent_type_id=agent_type_id,
agent_instance_id=agent_instance_id,
session_id=session_id,
),
include_inherited=True, # Include parent scopes
)
```
## Memory Consolidation
Automatic background processes transfer and extract knowledge:
```
Working Memory ──> Episodic Memory ──> Semantic Memory
                   └──> Procedural Memory
```
**Consolidation Types:**
- `working_to_episodic`: Transfer session state to episodes (on session end)
- `episodic_to_semantic`: Extract facts from experiences
- `episodic_to_procedural`: Learn procedures from patterns
- `prune`: Remove low-value memories
**Celery Tasks:**
```python
from app.tasks.memory_consolidation import (
consolidate_session,
run_nightly_consolidation,
prune_old_memories,
)
# Manual consolidation
consolidate_session.delay(project_id, session_id)
# Scheduled nightly (3 AM by default)
run_nightly_consolidation.delay(project_id)
```
## Memory Retrieval
### Hybrid Retrieval
Combine multiple retrieval strategies:
```python
from app.services.memory.indexing import RetrievalEngine
engine = RetrievalEngine(session, embedder)
# Hybrid search across memory types
results = await engine.retrieve_hybrid(
project_id=project_id,
query="authentication error handling",
memory_types=["episodic", "semantic", "procedural"],
filters={"outcome": "success"},
limit=10,
)
```
### Index Types
- **Vector Index**: Semantic similarity (HNSW/pgvector)
- **Temporal Index**: Time-based retrieval
- **Entity Index**: Entity mention lookup
- **Outcome Index**: Success/failure filtering
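As a rough sketch of how these indexes surface through the API, the `retrieve_hybrid` call shown above accepts arguments that correspond to them; the `since_days` key below is an assumed filter name used only for illustration, while the `outcome` filter mirrors the hybrid example.
```python
# Sketch only: lean on the outcome and temporal indexes via filters.
# "since_days" is an assumed filter key, not a confirmed parameter.
recent_failures = await engine.retrieve_hybrid(
    project_id=project_id,
    query="deployment rollback",
    memory_types=["episodic"],      # restrict to episodic records
    filters={
        "outcome": "failure",       # outcome index: failed episodes only
        "since_days": 7,            # temporal index: last week (assumed key)
    },
    limit=5,
)
```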
## MCP Tools
The memory system exposes MCP tools for agent use:
### `remember`
Store information in memory.
```json
{
"memory_type": "working",
"content": {"key": "value"},
"importance": 0.8,
"ttl_seconds": 3600
}
```
### `recall`
Retrieve from memory.
```json
{
"query": "authentication patterns",
"memory_types": ["episodic", "semantic"],
"limit": 10,
"filters": {"outcome": "success"}
}
```
### `forget`
Remove from memory.
```json
{
"memory_type": "working",
"key": "temp_data"
}
```
### `reflect`
Analyze memory patterns.
```json
{
"analysis_type": "success_factors",
"task_type": "code_review",
"time_range_days": 30
}
```
### `get_memory_stats`
Get memory usage statistics.
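For working memory, equivalent statistics are also available directly on the `WorkingMemory` class via `get_stats()` (see `working/memory.py`); a minimal sketch using the instance from the Working Memory example above:
```python
# Minimal sketch: inspect working-memory usage for the current scope.
stats = await memory.get_stats()
print(stats["user_keys"], stats["checkpoint_count"], stats["scratchpad_entries"])
```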
### `record_outcome`
Record task success/failure for learning.
## Memory Reflection
Analyze patterns and generate insights from memory:
```python
from app.services.memory.reflection import MemoryReflection, TimeRange
reflection = MemoryReflection(session)
# Detect patterns
patterns = await reflection.analyze_patterns(
project_id=project_id,
time_range=TimeRange.last_days(30),
)
# Identify success factors
factors = await reflection.identify_success_factors(
project_id=project_id,
task_type="code_review",
)
# Detect anomalies
anomalies = await reflection.detect_anomalies(
project_id=project_id,
baseline_days=30,
)
# Generate insights
insights = await reflection.generate_insights(project_id)
# Comprehensive reflection
result = await reflection.reflect(project_id)
print(result.summary)
```
## Configuration
All settings use the `MEM_` environment variable prefix:
| Variable | Default | Description |
|----------|---------|-------------|
| `MEM_WORKING_MEMORY_BACKEND` | `redis` | Backend: `redis` or `memory` |
| `MEM_WORKING_MEMORY_DEFAULT_TTL_SECONDS` | `3600` | Default TTL (1 hour) |
| `MEM_REDIS_URL` | `redis://localhost:6379/0` | Redis connection URL |
| `MEM_EPISODIC_MAX_EPISODES_PER_PROJECT` | `10000` | Max episodes per project |
| `MEM_EPISODIC_RETENTION_DAYS` | `365` | Episode retention period |
| `MEM_SEMANTIC_MAX_FACTS_PER_PROJECT` | `50000` | Max facts per project |
| `MEM_SEMANTIC_CONFIDENCE_DECAY_DAYS` | `90` | Confidence half-life |
| `MEM_EMBEDDING_MODEL` | `text-embedding-3-small` | Embedding model |
| `MEM_EMBEDDING_DIMENSIONS` | `1536` | Vector dimensions |
| `MEM_RETRIEVAL_MIN_SIMILARITY` | `0.5` | Minimum similarity score |
| `MEM_CONSOLIDATION_ENABLED` | `true` | Enable auto-consolidation |
| `MEM_CONSOLIDATION_SCHEDULE_CRON` | `0 3 * * *` | Nightly schedule |
| `MEM_CACHE_ENABLED` | `true` | Enable retrieval caching |
| `MEM_CACHE_TTL_SECONDS` | `300` | Cache TTL (5 minutes) |
See `app/services/memory/config.py` for complete configuration options.
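Settings can also be read programmatically via `get_memory_settings()`, which the memory modules themselves use; a minimal sketch, where `redis_url` and `working_memory_max_items_per_session` are fields referenced by the working-memory code and other attribute names are assumed to follow the `MEM_`-prefixed variables above:
```python
from app.services.memory.config import get_memory_settings

# Settings are loaded from MEM_-prefixed environment variables (see table above).
settings = get_memory_settings()

print(settings.redis_url)                             # e.g. redis://localhost:6379/0
print(settings.working_memory_max_items_per_session)  # scratchpad capacity limit
```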
## Integration with Context Engine
Memory integrates with the Context Engine as a context source:
```python
from app.services.memory.integration import MemoryContextSource
# Register as context source
source = MemoryContextSource(memory_manager)
context_engine.register_source(source)
# Memory is automatically included in context assembly
context = await context_engine.assemble_context(
project_id=project_id,
session_id=session_id,
current_task="Review authentication code",
)
```
## Caching
Multi-layer caching for performance:
- **Hot Cache**: Frequently accessed memories (LRU)
- **Retrieval Cache**: Query result caching
- **Embedding Cache**: Pre-computed embeddings
```python
from app.services.memory.cache import CacheManager
cache = CacheManager(settings)
await cache.warm_hot_cache(project_id) # Pre-warm common memories
```
## Metrics
Prometheus-compatible metrics:
| Metric | Type | Labels |
|--------|------|--------|
| `memory_operations_total` | Counter | operation, memory_type, scope, success |
| `memory_retrievals_total` | Counter | memory_type, strategy |
| `memory_cache_hits_total` | Counter | cache_type |
| `memory_retrieval_latency_seconds` | Histogram | - |
| `memory_consolidation_duration_seconds` | Histogram | - |
| `memory_items_count` | Gauge | memory_type, scope |
```python
from app.services.memory.metrics import get_memory_metrics
metrics = await get_memory_metrics()
summary = await metrics.get_summary()
prometheus_output = await metrics.get_prometheus_format()
```
## Performance Targets
| Operation | Target P95 |
|-----------|------------|
| Working memory get/set | < 5ms |
| Episodic memory retrieval | < 100ms |
| Semantic memory search | < 100ms |
| Procedural memory matching | < 50ms |
| Consolidation batch (1000 items) | < 30s |
## Troubleshooting
### Redis Connection Issues
```bash
# Check Redis connectivity
redis-cli ping
# Verify memory settings
MEM_REDIS_URL=redis://localhost:6379/0
```
### Slow Retrieval
1. Check if caching is enabled: `MEM_CACHE_ENABLED=true`
2. Verify HNSW indexes exist on vector columns
3. Monitor `memory_retrieval_latency_seconds` metric
### High Memory Usage
1. Review `MEM_EPISODIC_MAX_EPISODES_PER_PROJECT` limit
2. Ensure pruning is enabled: `MEM_PRUNING_ENABLED=true`
3. Check consolidation is running (cron schedule)
### Embedding Errors
1. Verify LLM Gateway is accessible
2. Check embedding model is valid
3. Review batch size if hitting rate limits
## Directory Structure
```
app/services/memory/
├── __init__.py      # Public exports
├── config.py        # MemorySettings
├── exceptions.py    # Memory-specific errors
├── manager.py       # MemoryManager facade
├── types.py         # Core types
├── working/         # Working memory
│   ├── memory.py
│   └── storage.py
├── episodic/        # Episodic memory
│   ├── memory.py
│   ├── recorder.py
│   └── retrieval.py
├── semantic/        # Semantic memory
│   ├── memory.py
│   ├── extraction.py
│   └── verification.py
├── procedural/      # Procedural memory
│   ├── memory.py
│   └── matching.py
├── scoping/         # Memory scoping
│   ├── scope.py
│   └── resolver.py
├── indexing/        # Indexing & retrieval
│   ├── index.py
│   └── retrieval.py
├── consolidation/   # Memory consolidation
│   └── service.py
├── reflection/      # Memory reflection
│   ├── service.py
│   └── types.py
├── integration/     # External integrations
│   ├── context_source.py
│   └── lifecycle.py
├── cache/           # Caching layer
│   ├── cache_manager.py
│   ├── hot_cache.py
│   └── embedding_cache.py
├── mcp/             # MCP tools
│   ├── service.py
│   └── tools.py
└── metrics/         # Observability
    └── collector.py
```


@@ -26,6 +26,7 @@ Usage:
# Inside Docker (without --local flag):
python migrate.py auto "Add new field"
"""
import argparse
import os
import subprocess
@@ -44,13 +45,14 @@ def setup_database_url(use_local: bool) -> str:
# Override DATABASE_URL to use localhost instead of Docker hostname
local_url = os.environ.get(
"LOCAL_DATABASE_URL",
"postgresql://postgres:postgres@localhost:5432/app"
"postgresql://postgres:postgres@localhost:5432/syndarix",
)
os.environ["DATABASE_URL"] = local_url
return local_url
# Use the configured DATABASE_URL from environment/.env
from app.core.config import settings
return settings.database_url
@@ -61,6 +63,7 @@ def check_models():
try:
# Import all models through the models package
from app.models import __all__ as all_models
print(f"Found {len(all_models)} model(s):")
for model in all_models:
print(f" - {model}")
@@ -110,7 +113,9 @@ def generate_migration(message, rev_id=None, auto_rev_id=True, offline=False):
# Look for the revision ID, which is typically 12 hex characters
parts = line.split()
for part in parts:
if len(part) >= 12 and all(c in "0123456789abcdef" for c in part[:12]):
if len(part) >= 12 and all(
c in "0123456789abcdef" for c in part[:12]
):
revision = part[:12]
break
except Exception as e:
@@ -185,6 +190,7 @@ def check_database_connection():
db_url = os.environ.get("DATABASE_URL")
if not db_url:
from app.core.config import settings
db_url = settings.database_url
engine = create_engine(db_url)
@@ -270,8 +276,8 @@ def generate_offline_migration(message, rev_id):
content = f'''"""{message}
Revision ID: {rev_id}
Revises: {down_revision or ''}
Create Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')}
Revises: {down_revision or ""}
Create Date: {datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")}
"""
@@ -320,6 +326,7 @@ def reset_alembic_version():
db_url = os.environ.get("DATABASE_URL")
if not db_url:
from app.core.config import settings
db_url = settings.database_url
try:
@@ -338,82 +345,80 @@ def reset_alembic_version():
def main():
"""Main function"""
parser = argparse.ArgumentParser(
description='Database migration helper for Generative Models Arena'
description="Database migration helper for Generative Models Arena"
)
# Global options
parser.add_argument(
'--local', '-l',
action='store_true',
help='Use localhost instead of Docker hostname (for local development)'
"--local",
"-l",
action="store_true",
help="Use localhost instead of Docker hostname (for local development)",
)
subparsers = parser.add_subparsers(dest='command', help='Command to run')
subparsers = parser.add_subparsers(dest="command", help="Command to run")
# Generate command
generate_parser = subparsers.add_parser('generate', help='Generate a migration')
generate_parser.add_argument('message', help='Migration message')
generate_parser = subparsers.add_parser("generate", help="Generate a migration")
generate_parser.add_argument("message", help="Migration message")
generate_parser.add_argument(
'--rev-id',
help='Custom revision ID (e.g., 0001, 0002 for sequential naming)'
"--rev-id", help="Custom revision ID (e.g., 0001, 0002 for sequential naming)"
)
generate_parser.add_argument(
'--offline',
action='store_true',
help='Generate empty migration template without database connection'
"--offline",
action="store_true",
help="Generate empty migration template without database connection",
)
# Apply command
apply_parser = subparsers.add_parser('apply', help='Apply migrations')
apply_parser.add_argument('--revision', help='Specific revision to apply to')
apply_parser = subparsers.add_parser("apply", help="Apply migrations")
apply_parser.add_argument("--revision", help="Specific revision to apply to")
# List command
subparsers.add_parser('list', help='List migrations')
subparsers.add_parser("list", help="List migrations")
# Current command
subparsers.add_parser('current', help='Show current revision')
subparsers.add_parser("current", help="Show current revision")
# Check command
subparsers.add_parser('check', help='Check database connection and models')
subparsers.add_parser("check", help="Check database connection and models")
# Next command (show next revision ID)
subparsers.add_parser('next', help='Show the next sequential revision ID')
subparsers.add_parser("next", help="Show the next sequential revision ID")
# Reset command (clear alembic_version table)
subparsers.add_parser(
'reset',
help='Reset alembic_version table (use after deleting all migrations)'
"reset", help="Reset alembic_version table (use after deleting all migrations)"
)
# Auto command (generate and apply)
auto_parser = subparsers.add_parser('auto', help='Generate and apply migration')
auto_parser.add_argument('message', help='Migration message')
auto_parser = subparsers.add_parser("auto", help="Generate and apply migration")
auto_parser.add_argument("message", help="Migration message")
auto_parser.add_argument(
'--rev-id',
help='Custom revision ID (e.g., 0001, 0002 for sequential naming)'
"--rev-id", help="Custom revision ID (e.g., 0001, 0002 for sequential naming)"
)
auto_parser.add_argument(
'--offline',
action='store_true',
help='Generate empty migration template without database connection'
"--offline",
action="store_true",
help="Generate empty migration template without database connection",
)
args = parser.parse_args()
# Commands that don't need database connection
if args.command == 'next':
if args.command == "next":
show_next_rev_id()
return
# Check if offline mode is requested
offline = getattr(args, 'offline', False)
offline = getattr(args, "offline", False)
# Offline generate doesn't need database or model check
if args.command == 'generate' and offline:
if args.command == "generate" and offline:
generate_migration(args.message, rev_id=args.rev_id, offline=True)
return
if args.command == 'auto' and offline:
if args.command == "auto" and offline:
generate_migration(args.message, rev_id=args.rev_id, offline=True)
print("\nOffline migration generated. Apply it later with:")
print(" python migrate.py --local apply")
@@ -423,27 +428,27 @@ def main():
db_url = setup_database_url(args.local)
print(f"Using database URL: {db_url}")
if args.command == 'generate':
if args.command == "generate":
check_models()
generate_migration(args.message, rev_id=args.rev_id)
elif args.command == 'apply':
elif args.command == "apply":
apply_migration(args.revision)
elif args.command == 'list':
elif args.command == "list":
list_migrations()
elif args.command == 'current':
elif args.command == "current":
show_current()
elif args.command == 'check':
elif args.command == "check":
check_database_connection()
check_models()
elif args.command == 'reset':
elif args.command == "reset":
reset_alembic_version()
elif args.command == 'auto':
elif args.command == "auto":
check_models()
revision = generate_migration(args.message, rev_id=args.rev_id)
if revision:

View File

@@ -745,3 +745,230 @@ class TestAgentTypeInstanceCount:
for agent_type in data["data"]:
assert "instance_count" in agent_type
assert isinstance(agent_type["instance_count"], int)
@pytest.mark.asyncio
class TestAgentTypeCategoryFields:
"""Tests for agent type category and display fields."""
async def test_create_agent_type_with_category_fields(
self, client, superuser_token
):
"""Test creating agent type with all category and display fields."""
unique_slug = f"category-type-{uuid.uuid4().hex[:8]}"
response = await client.post(
"/api/v1/agent-types",
json={
"name": "Categorized Agent Type",
"slug": unique_slug,
"description": "An agent type with category fields",
"expertise": ["python"],
"personality_prompt": "You are a helpful assistant.",
"primary_model": "claude-opus-4-5-20251101",
# Category and display fields
"category": "development",
"icon": "code",
"color": "#3B82F6",
"sort_order": 10,
"typical_tasks": ["Write code", "Review PRs"],
"collaboration_hints": ["backend-engineer", "qa-engineer"],
},
headers={"Authorization": f"Bearer {superuser_token}"},
)
assert response.status_code == status.HTTP_201_CREATED
data = response.json()
assert data["category"] == "development"
assert data["icon"] == "code"
assert data["color"] == "#3B82F6"
assert data["sort_order"] == 10
assert data["typical_tasks"] == ["Write code", "Review PRs"]
assert data["collaboration_hints"] == ["backend-engineer", "qa-engineer"]
async def test_create_agent_type_with_nullable_category(
self, client, superuser_token
):
"""Test creating agent type with null category."""
unique_slug = f"null-category-{uuid.uuid4().hex[:8]}"
response = await client.post(
"/api/v1/agent-types",
json={
"name": "Uncategorized Agent",
"slug": unique_slug,
"expertise": ["general"],
"personality_prompt": "You are a helpful assistant.",
"primary_model": "claude-opus-4-5-20251101",
"category": None,
},
headers={"Authorization": f"Bearer {superuser_token}"},
)
assert response.status_code == status.HTTP_201_CREATED
data = response.json()
assert data["category"] is None
async def test_create_agent_type_invalid_color_format(
self, client, superuser_token
):
"""Test that invalid color format is rejected."""
unique_slug = f"invalid-color-{uuid.uuid4().hex[:8]}"
response = await client.post(
"/api/v1/agent-types",
json={
"name": "Invalid Color Agent",
"slug": unique_slug,
"expertise": ["python"],
"personality_prompt": "You are a helpful assistant.",
"primary_model": "claude-opus-4-5-20251101",
"color": "not-a-hex-color",
},
headers={"Authorization": f"Bearer {superuser_token}"},
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
async def test_create_agent_type_invalid_category(self, client, superuser_token):
"""Test that invalid category value is rejected."""
unique_slug = f"invalid-category-{uuid.uuid4().hex[:8]}"
response = await client.post(
"/api/v1/agent-types",
json={
"name": "Invalid Category Agent",
"slug": unique_slug,
"expertise": ["python"],
"personality_prompt": "You are a helpful assistant.",
"primary_model": "claude-opus-4-5-20251101",
"category": "not_a_valid_category",
},
headers={"Authorization": f"Bearer {superuser_token}"},
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
async def test_update_agent_type_category_fields(
self, client, superuser_token, test_agent_type
):
"""Test updating category and display fields."""
agent_type_id = test_agent_type["id"]
response = await client.patch(
f"/api/v1/agent-types/{agent_type_id}",
json={
"category": "ai_ml",
"icon": "brain",
"color": "#8B5CF6",
"sort_order": 50,
"typical_tasks": ["Train models", "Analyze data"],
"collaboration_hints": ["data-scientist"],
},
headers={"Authorization": f"Bearer {superuser_token}"},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["category"] == "ai_ml"
assert data["icon"] == "brain"
assert data["color"] == "#8B5CF6"
assert data["sort_order"] == 50
assert data["typical_tasks"] == ["Train models", "Analyze data"]
assert data["collaboration_hints"] == ["data-scientist"]
@pytest.mark.asyncio
class TestAgentTypeCategoryFilter:
"""Tests for agent type category filtering."""
async def test_list_agent_types_filter_by_category(
self, client, superuser_token, user_token
):
"""Test filtering agent types by category."""
# Create agent types in different categories
for cat in ["development", "design"]:
unique_slug = f"filter-test-{cat}-{uuid.uuid4().hex[:8]}"
await client.post(
"/api/v1/agent-types",
json={
"name": f"Filter Test {cat.capitalize()}",
"slug": unique_slug,
"expertise": ["python"],
"personality_prompt": "Test prompt",
"primary_model": "claude-opus-4-5-20251101",
"category": cat,
},
headers={"Authorization": f"Bearer {superuser_token}"},
)
# Filter by development category
response = await client.get(
"/api/v1/agent-types",
params={"category": "development"},
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
# All returned types should have development category
for agent_type in data["data"]:
assert agent_type["category"] == "development"
@pytest.mark.asyncio
class TestAgentTypeGroupedEndpoint:
"""Tests for the grouped by category endpoint."""
async def test_list_agent_types_grouped(self, client, superuser_token, user_token):
"""Test getting agent types grouped by category."""
# Create agent types in different categories
categories = ["development", "design", "quality"]
for cat in categories:
unique_slug = f"grouped-test-{cat}-{uuid.uuid4().hex[:8]}"
await client.post(
"/api/v1/agent-types",
json={
"name": f"Grouped Test {cat.capitalize()}",
"slug": unique_slug,
"expertise": ["python"],
"personality_prompt": "Test prompt",
"primary_model": "claude-opus-4-5-20251101",
"category": cat,
"sort_order": 10,
},
headers={"Authorization": f"Bearer {superuser_token}"},
)
# Get grouped agent types
response = await client.get(
"/api/v1/agent-types/grouped",
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
# Should be a dict with category keys
assert isinstance(data, dict)
# Check that at least one of our created categories exists
assert any(cat in data for cat in categories)
async def test_list_agent_types_grouped_filter_inactive(
self, client, superuser_token, user_token
):
"""Test grouped endpoint with is_active filter."""
response = await client.get(
"/api/v1/agent-types/grouped",
params={"is_active": False},
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert isinstance(data, dict)
async def test_list_agent_types_grouped_unauthenticated(self, client):
"""Test that unauthenticated users cannot access grouped endpoint."""
response = await client.get("/api/v1/agent-types/grouped")
assert response.status_code == status.HTTP_401_UNAUTHORIZED

View File

@@ -188,13 +188,14 @@ class TestPasswordResetConfirm:
@pytest.mark.asyncio
async def test_password_reset_confirm_expired_token(self, client, async_test_user):
"""Test password reset confirmation with expired token."""
import time as time_module
import asyncio
# Create token that expires immediately
token = create_password_reset_token(async_test_user.email, expires_in=1)
# Create token that expires at current second (expires_in=0)
# Token expires when exp < current_time, so we need to cross a second boundary
token = create_password_reset_token(async_test_user.email, expires_in=0)
# Wait for token to expire
time_module.sleep(2)
# Wait for token to expire (need to cross second boundary)
await asyncio.sleep(1.1)
response = await client.post(
"/api/v1/auth/password-reset/confirm",

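The timing comment in this hunk relies on the expiry check comparing whole seconds. A standalone sketch of that reasoning (independent of the project's token helper) shows why a 1.1 s sleep suffices once `expires_in=0`:
```python
# Standalone sketch of the second-boundary reasoning above (not the project's helper):
# with integer-second exp claims, expires_in=0 makes exp equal the issue second,
# and the token only reads as expired once the clock advances past that second.
import time

issued = int(time.time())
exp = issued + 0               # expires_in=0
time.sleep(1.1)                # guarantees int(time.time()) > exp
print(exp < int(time.time()))  # True: "exp < current_time" now reports expiry
```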
View File

@@ -368,3 +368,9 @@ async def e2e_org_with_members(e2e_client, e2e_superuser):
"user_id": member_id,
},
}
# NOTE: Class-scoped fixtures for E2E tests were attempted but have fundamental
# issues with pytest-asyncio + SQLAlchemy/asyncpg event loop management.
# The function-scoped fixtures above provide proper test isolation.
# Performance optimization would require significant infrastructure changes.
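For reference, the function-scoped pattern the note describes looks roughly like the sketch below; the fixture name, connection URL, and session setup are illustrative, not copied from the conftest:
```python
# Rough sketch of a function-scoped async fixture (illustrative names/URL only):
# each test gets its own engine and session, so asyncpg connections are created
# and disposed inside the event loop pytest-asyncio provides for that test.
import pytest_asyncio
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

@pytest_asyncio.fixture  # function scope is the default
async def e2e_session():
    engine = create_async_engine(
        "postgresql+asyncpg://postgres:postgres@localhost:5432/test"
    )
    factory = async_sessionmaker(engine, expire_on_commit=False)
    async with factory() as session:
        yield session
    await engine.dispose()  # dispose inside the same loop the engine was created in
```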

View File

@@ -0,0 +1,2 @@
# tests/unit/models/memory/__init__.py
"""Unit tests for memory database models."""

View File

@@ -0,0 +1,71 @@
# tests/unit/models/memory/test_enums.py
"""Unit tests for memory model enums."""
from app.models.memory.enums import (
ConsolidationStatus,
ConsolidationType,
EpisodeOutcome,
ScopeType,
)
class TestScopeType:
"""Tests for ScopeType enum."""
def test_all_values_exist(self) -> None:
"""Test all expected scope types exist."""
assert ScopeType.GLOBAL.value == "global"
assert ScopeType.PROJECT.value == "project"
assert ScopeType.AGENT_TYPE.value == "agent_type"
assert ScopeType.AGENT_INSTANCE.value == "agent_instance"
assert ScopeType.SESSION.value == "session"
def test_scope_count(self) -> None:
"""Test we have exactly 5 scope types."""
assert len(ScopeType) == 5
class TestEpisodeOutcome:
"""Tests for EpisodeOutcome enum."""
def test_all_values_exist(self) -> None:
"""Test all expected outcome values exist."""
assert EpisodeOutcome.SUCCESS.value == "success"
assert EpisodeOutcome.FAILURE.value == "failure"
assert EpisodeOutcome.PARTIAL.value == "partial"
def test_outcome_count(self) -> None:
"""Test we have exactly 3 outcome types."""
assert len(EpisodeOutcome) == 3
class TestConsolidationType:
"""Tests for ConsolidationType enum."""
def test_all_values_exist(self) -> None:
"""Test all expected consolidation types exist."""
assert ConsolidationType.WORKING_TO_EPISODIC.value == "working_to_episodic"
assert ConsolidationType.EPISODIC_TO_SEMANTIC.value == "episodic_to_semantic"
assert (
ConsolidationType.EPISODIC_TO_PROCEDURAL.value == "episodic_to_procedural"
)
assert ConsolidationType.PRUNING.value == "pruning"
def test_consolidation_count(self) -> None:
"""Test we have exactly 4 consolidation types."""
assert len(ConsolidationType) == 4
class TestConsolidationStatus:
"""Tests for ConsolidationStatus enum."""
def test_all_values_exist(self) -> None:
"""Test all expected status values exist."""
assert ConsolidationStatus.PENDING.value == "pending"
assert ConsolidationStatus.RUNNING.value == "running"
assert ConsolidationStatus.COMPLETED.value == "completed"
assert ConsolidationStatus.FAILED.value == "failed"
def test_status_count(self) -> None:
"""Test we have exactly 4 status types."""
assert len(ConsolidationStatus) == 4

View File

@@ -0,0 +1,249 @@
# tests/unit/models/memory/test_models.py
"""Unit tests for memory database models."""
from datetime import UTC, datetime, timedelta
import pytest
from app.models.memory import (
ConsolidationStatus,
ConsolidationType,
Episode,
EpisodeOutcome,
Fact,
MemoryConsolidationLog,
Procedure,
ScopeType,
WorkingMemory,
)
class TestWorkingMemoryModel:
"""Tests for WorkingMemory model."""
def test_tablename(self) -> None:
"""Test table name is correct."""
assert WorkingMemory.__tablename__ == "working_memory"
def test_has_required_columns(self) -> None:
"""Test all required columns exist."""
columns = WorkingMemory.__table__.columns
assert "id" in columns
assert "scope_type" in columns
assert "scope_id" in columns
assert "key" in columns
assert "value" in columns
assert "expires_at" in columns
assert "created_at" in columns
assert "updated_at" in columns
def test_has_unique_constraint(self) -> None:
"""Test unique constraint on scope+key."""
indexes = {idx.name: idx for idx in WorkingMemory.__table__.indexes}
assert "ix_working_memory_scope_key" in indexes
assert indexes["ix_working_memory_scope_key"].unique
class TestEpisodeModel:
"""Tests for Episode model."""
def test_tablename(self) -> None:
"""Test table name is correct."""
assert Episode.__tablename__ == "episodes"
def test_has_required_columns(self) -> None:
"""Test all required columns exist."""
columns = Episode.__table__.columns
required = [
"id",
"project_id",
"agent_instance_id",
"agent_type_id",
"session_id",
"task_type",
"task_description",
"actions",
"context_summary",
"outcome",
"outcome_details",
"duration_seconds",
"tokens_used",
"lessons_learned",
"importance_score",
"embedding",
"occurred_at",
"created_at",
"updated_at",
]
for col in required:
assert col in columns, f"Missing column: {col}"
def test_has_foreign_keys(self) -> None:
"""Test foreign key relationships exist."""
columns = Episode.__table__.columns
assert columns["project_id"].foreign_keys
assert columns["agent_instance_id"].foreign_keys
assert columns["agent_type_id"].foreign_keys
def test_has_relationships(self) -> None:
"""Test ORM relationships exist."""
mapper = Episode.__mapper__
assert "project" in mapper.relationships
assert "agent_instance" in mapper.relationships
assert "agent_type" in mapper.relationships
class TestFactModel:
"""Tests for Fact model."""
def test_tablename(self) -> None:
"""Test table name is correct."""
assert Fact.__tablename__ == "facts"
def test_has_required_columns(self) -> None:
"""Test all required columns exist."""
columns = Fact.__table__.columns
required = [
"id",
"project_id",
"subject",
"predicate",
"object",
"confidence",
"source_episode_ids",
"first_learned",
"last_reinforced",
"reinforcement_count",
"embedding",
"created_at",
"updated_at",
]
for col in required:
assert col in columns, f"Missing column: {col}"
def test_project_id_nullable(self) -> None:
"""Test project_id is nullable for global facts."""
columns = Fact.__table__.columns
assert columns["project_id"].nullable
class TestProcedureModel:
"""Tests for Procedure model."""
def test_tablename(self) -> None:
"""Test table name is correct."""
assert Procedure.__tablename__ == "procedures"
def test_has_required_columns(self) -> None:
"""Test all required columns exist."""
columns = Procedure.__table__.columns
required = [
"id",
"project_id",
"agent_type_id",
"name",
"trigger_pattern",
"steps",
"success_count",
"failure_count",
"last_used",
"embedding",
"created_at",
"updated_at",
]
for col in required:
assert col in columns, f"Missing column: {col}"
def test_success_rate_property(self) -> None:
"""Test success_rate calculated property."""
proc = Procedure()
proc.success_count = 8
proc.failure_count = 2
assert proc.success_rate == 0.8
def test_success_rate_zero_total(self) -> None:
"""Test success_rate with zero total uses."""
proc = Procedure()
proc.success_count = 0
proc.failure_count = 0
assert proc.success_rate == 0.0
def test_total_uses_property(self) -> None:
"""Test total_uses calculated property."""
proc = Procedure()
proc.success_count = 5
proc.failure_count = 3
assert proc.total_uses == 8
class TestMemoryConsolidationLogModel:
"""Tests for MemoryConsolidationLog model."""
def test_tablename(self) -> None:
"""Test table name is correct."""
assert MemoryConsolidationLog.__tablename__ == "memory_consolidation_log"
def test_has_required_columns(self) -> None:
"""Test all required columns exist."""
columns = MemoryConsolidationLog.__table__.columns
required = [
"id",
"consolidation_type",
"source_count",
"result_count",
"started_at",
"completed_at",
"status",
"error",
"created_at",
"updated_at",
]
for col in required:
assert col in columns, f"Missing column: {col}"
def test_duration_seconds_property_completed(self) -> None:
"""Test duration_seconds with completed job."""
log = MemoryConsolidationLog()
log.started_at = datetime.now(UTC)
log.completed_at = log.started_at + timedelta(seconds=10)
assert log.duration_seconds == pytest.approx(10.0)
def test_duration_seconds_property_incomplete(self) -> None:
"""Test duration_seconds with incomplete job."""
log = MemoryConsolidationLog()
log.started_at = datetime.now(UTC)
log.completed_at = None
assert log.duration_seconds is None
def test_default_status(self) -> None:
"""Test default status is PENDING."""
columns = MemoryConsolidationLog.__table__.columns
assert columns["status"].default.arg == ConsolidationStatus.PENDING
class TestModelExports:
"""Tests for model package exports."""
def test_all_models_exported(self) -> None:
"""Test all models are exported from package."""
from app.models.memory import (
Episode,
Fact,
MemoryConsolidationLog,
Procedure,
WorkingMemory,
)
# Verify these are the actual classes
assert Episode.__tablename__ == "episodes"
assert Fact.__tablename__ == "facts"
assert Procedure.__tablename__ == "procedures"
assert WorkingMemory.__tablename__ == "working_memory"
assert MemoryConsolidationLog.__tablename__ == "memory_consolidation_log"
def test_enums_exported(self) -> None:
"""Test all enums are exported."""
assert ScopeType.GLOBAL.value == "global"
assert EpisodeOutcome.SUCCESS.value == "success"
assert ConsolidationType.WORKING_TO_EPISODIC.value == "working_to_episodic"
assert ConsolidationStatus.PENDING.value == "pending"

View File

@@ -316,3 +316,325 @@ class TestAgentTypeJsonFields:
)
assert agent_type.fallback_models == models
class TestAgentTypeCategoryFieldsValidation:
"""Tests for AgentType category and display field validation."""
def test_valid_category_values(self):
"""Test that all valid category values are accepted."""
valid_categories = [
"development",
"design",
"quality",
"operations",
"ai_ml",
"data",
"leadership",
"domain_expert",
]
for category in valid_categories:
agent_type = AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
category=category,
)
assert agent_type.category.value == category
def test_category_null_allowed(self):
"""Test that null category is allowed."""
agent_type = AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
category=None,
)
assert agent_type.category is None
def test_invalid_category_rejected(self):
"""Test that invalid category values are rejected."""
with pytest.raises(ValidationError):
AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
category="invalid_category",
)
def test_valid_hex_color(self):
"""Test that valid hex colors are accepted."""
valid_colors = ["#3B82F6", "#EC4899", "#10B981", "#ffffff", "#000000"]
for color in valid_colors:
agent_type = AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
color=color,
)
assert agent_type.color == color
def test_invalid_hex_color_rejected(self):
"""Test that invalid hex colors are rejected."""
invalid_colors = [
"not-a-color",
"3B82F6", # Missing #
"#3B82F", # Too short
"#3B82F6A", # Too long
"#GGGGGG", # Invalid hex chars
"rgb(59, 130, 246)", # RGB format not supported
]
for color in invalid_colors:
with pytest.raises(ValidationError):
AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
color=color,
)
def test_color_null_allowed(self):
"""Test that null color is allowed."""
agent_type = AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
color=None,
)
assert agent_type.color is None
def test_sort_order_valid_range(self):
"""Test that valid sort_order values are accepted."""
for sort_order in [0, 1, 500, 1000]:
agent_type = AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
sort_order=sort_order,
)
assert agent_type.sort_order == sort_order
def test_sort_order_default_zero(self):
"""Test that sort_order defaults to 0."""
agent_type = AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
)
assert agent_type.sort_order == 0
def test_sort_order_negative_rejected(self):
"""Test that negative sort_order is rejected."""
with pytest.raises(ValidationError):
AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
sort_order=-1,
)
def test_sort_order_exceeds_max_rejected(self):
"""Test that sort_order > 1000 is rejected."""
with pytest.raises(ValidationError):
AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
sort_order=1001,
)
def test_icon_max_length(self):
"""Test that icon field respects max length."""
agent_type = AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
icon="x" * 50,
)
assert len(agent_type.icon) == 50
def test_icon_exceeds_max_length_rejected(self):
"""Test that icon exceeding max length is rejected."""
with pytest.raises(ValidationError):
AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
icon="x" * 51,
)
class TestAgentTypeTypicalTasksValidation:
"""Tests for typical_tasks field validation."""
def test_typical_tasks_list(self):
"""Test typical_tasks as a list."""
tasks = ["Write code", "Review PRs", "Debug issues"]
agent_type = AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
typical_tasks=tasks,
)
assert agent_type.typical_tasks == tasks
def test_typical_tasks_default_empty(self):
"""Test typical_tasks defaults to empty list."""
agent_type = AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
)
assert agent_type.typical_tasks == []
def test_typical_tasks_strips_whitespace(self):
"""Test that typical_tasks items are stripped."""
agent_type = AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
typical_tasks=[" Write code ", " Debug "],
)
assert agent_type.typical_tasks == ["Write code", "Debug"]
def test_typical_tasks_removes_empty_strings(self):
"""Test that empty strings are removed from typical_tasks."""
agent_type = AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
typical_tasks=["Write code", "", " ", "Debug"],
)
assert agent_type.typical_tasks == ["Write code", "Debug"]
class TestAgentTypeCollaborationHintsValidation:
"""Tests for collaboration_hints field validation."""
def test_collaboration_hints_list(self):
"""Test collaboration_hints as a list."""
hints = ["backend-engineer", "qa-engineer"]
agent_type = AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
collaboration_hints=hints,
)
assert agent_type.collaboration_hints == hints
def test_collaboration_hints_default_empty(self):
"""Test collaboration_hints defaults to empty list."""
agent_type = AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
)
assert agent_type.collaboration_hints == []
def test_collaboration_hints_normalized_lowercase(self):
"""Test that collaboration_hints are normalized to lowercase."""
agent_type = AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
collaboration_hints=["Backend-Engineer", "QA-ENGINEER"],
)
assert agent_type.collaboration_hints == ["backend-engineer", "qa-engineer"]
def test_collaboration_hints_strips_whitespace(self):
"""Test that collaboration_hints are stripped."""
agent_type = AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
collaboration_hints=[" backend-engineer ", " qa-engineer "],
)
assert agent_type.collaboration_hints == ["backend-engineer", "qa-engineer"]
def test_collaboration_hints_removes_empty_strings(self):
"""Test that empty strings are removed from collaboration_hints."""
agent_type = AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
collaboration_hints=["backend-engineer", "", " ", "qa-engineer"],
)
assert agent_type.collaboration_hints == ["backend-engineer", "qa-engineer"]
class TestAgentTypeUpdateCategoryFields:
"""Tests for AgentTypeUpdate category and display fields."""
def test_update_category_field(self):
"""Test updating category field."""
update = AgentTypeUpdate(category="ai_ml")
assert update.category.value == "ai_ml"
def test_update_icon_field(self):
"""Test updating icon field."""
update = AgentTypeUpdate(icon="brain")
assert update.icon == "brain"
def test_update_color_field(self):
"""Test updating color field."""
update = AgentTypeUpdate(color="#8B5CF6")
assert update.color == "#8B5CF6"
def test_update_sort_order_field(self):
"""Test updating sort_order field."""
update = AgentTypeUpdate(sort_order=50)
assert update.sort_order == 50
def test_update_typical_tasks_field(self):
"""Test updating typical_tasks field."""
update = AgentTypeUpdate(typical_tasks=["New task"])
assert update.typical_tasks == ["New task"]
def test_update_typical_tasks_strips_whitespace(self):
"""Test that typical_tasks are stripped on update."""
update = AgentTypeUpdate(typical_tasks=[" New task "])
assert update.typical_tasks == ["New task"]
def test_update_collaboration_hints_field(self):
"""Test updating collaboration_hints field."""
update = AgentTypeUpdate(collaboration_hints=["new-collaborator"])
assert update.collaboration_hints == ["new-collaborator"]
def test_update_collaboration_hints_normalized(self):
"""Test that collaboration_hints are normalized on update."""
update = AgentTypeUpdate(collaboration_hints=[" New-Collaborator "])
assert update.collaboration_hints == ["new-collaborator"]
def test_update_invalid_color_rejected(self):
"""Test that invalid color is rejected on update."""
with pytest.raises(ValidationError):
AgentTypeUpdate(color="invalid")
def test_update_invalid_sort_order_rejected(self):
"""Test that invalid sort_order is rejected on update."""
with pytest.raises(ValidationError):
AgentTypeUpdate(sort_order=-1)

View File

@@ -304,10 +304,18 @@ class TestTaskModuleExports:
assert hasattr(tasks, "sync")
assert hasattr(tasks, "workflow")
assert hasattr(tasks, "cost")
assert hasattr(tasks, "memory_consolidation")
def test_tasks_all_attribute_is_correct(self):
"""Test that __all__ contains all expected module names."""
from app import tasks
expected_modules = ["agent", "git", "sync", "workflow", "cost"]
expected_modules = [
"agent",
"git",
"sync",
"workflow",
"cost",
"memory_consolidation",
]
assert set(tasks.__all__) == set(expected_modules)

View File

@@ -42,6 +42,9 @@ class TestInitDb:
assert user.last_name == "User"
@pytest.mark.asyncio
@pytest.mark.skip(
reason="SQLite doesn't support UUID type binding - requires PostgreSQL"
)
async def test_init_db_returns_existing_superuser(
self, async_test_db, async_test_user
):

View File

@@ -0,0 +1,2 @@
# tests/unit/models/__init__.py
"""Unit tests for database models."""

View File

@@ -0,0 +1,260 @@
# tests/unit/services/context/types/test_memory.py
"""Tests for MemoryContext type."""
from datetime import UTC, datetime
from unittest.mock import MagicMock
from uuid import uuid4
from app.services.context.types import ContextType
from app.services.context.types.memory import MemoryContext, MemorySubtype
class TestMemorySubtype:
"""Tests for MemorySubtype enum."""
def test_all_types_defined(self) -> None:
"""All memory subtypes should be defined."""
assert MemorySubtype.WORKING == "working"
assert MemorySubtype.EPISODIC == "episodic"
assert MemorySubtype.SEMANTIC == "semantic"
assert MemorySubtype.PROCEDURAL == "procedural"
def test_enum_values(self) -> None:
"""Enum values should match strings."""
assert MemorySubtype.WORKING.value == "working"
assert MemorySubtype("episodic") == MemorySubtype.EPISODIC
class TestMemoryContext:
"""Tests for MemoryContext class."""
def test_get_type_returns_memory(self) -> None:
"""get_type should return MEMORY."""
ctx = MemoryContext(content="test", source="test_source")
assert ctx.get_type() == ContextType.MEMORY
def test_default_values(self) -> None:
"""Default values should be set correctly."""
ctx = MemoryContext(content="test", source="test_source")
assert ctx.memory_subtype == MemorySubtype.EPISODIC
assert ctx.memory_id is None
assert ctx.relevance_score == 0.0
assert ctx.importance == 0.5
def test_to_dict_includes_memory_fields(self) -> None:
"""to_dict should include memory-specific fields."""
ctx = MemoryContext(
content="test content",
source="test_source",
memory_subtype=MemorySubtype.SEMANTIC,
memory_id="mem-123",
relevance_score=0.8,
subject="User",
predicate="prefers",
object_value="dark mode",
)
data = ctx.to_dict()
assert data["memory_subtype"] == "semantic"
assert data["memory_id"] == "mem-123"
assert data["relevance_score"] == 0.8
assert data["subject"] == "User"
assert data["predicate"] == "prefers"
assert data["object_value"] == "dark mode"
def test_from_dict(self) -> None:
"""from_dict should create correct MemoryContext."""
data = {
"content": "test content",
"source": "test_source",
"timestamp": "2024-01-01T00:00:00+00:00",
"memory_subtype": "semantic",
"memory_id": "mem-123",
"relevance_score": 0.8,
"subject": "Test",
}
ctx = MemoryContext.from_dict(data)
assert ctx.content == "test content"
assert ctx.memory_subtype == MemorySubtype.SEMANTIC
assert ctx.memory_id == "mem-123"
assert ctx.subject == "Test"
class TestMemoryContextFromWorkingMemory:
"""Tests for MemoryContext.from_working_memory."""
def test_creates_working_memory_context(self) -> None:
"""Should create working memory context from key/value."""
ctx = MemoryContext.from_working_memory(
key="user_preferences",
value={"theme": "dark"},
source="working:sess-123",
query="preferences",
)
assert ctx.memory_subtype == MemorySubtype.WORKING
assert ctx.key == "user_preferences"
assert "{'theme': 'dark'}" in ctx.content
assert ctx.relevance_score == 1.0 # Working memory is always relevant
assert ctx.importance == 0.8 # Higher importance
def test_string_value(self) -> None:
"""Should handle string values."""
ctx = MemoryContext.from_working_memory(
key="current_task",
value="Build authentication",
)
assert ctx.content == "Build authentication"
class TestMemoryContextFromEpisodicMemory:
"""Tests for MemoryContext.from_episodic_memory."""
def test_creates_episodic_memory_context(self) -> None:
"""Should create episodic memory context from episode."""
episode = MagicMock()
episode.id = uuid4()
episode.task_description = "Implemented login feature"
episode.task_type = "feature_implementation"
episode.outcome = MagicMock(value="success")
episode.importance_score = 0.9
episode.session_id = "sess-123"
episode.occurred_at = datetime.now(UTC)
episode.lessons_learned = ["Use proper validation"]
ctx = MemoryContext.from_episodic_memory(episode, query="login")
assert ctx.memory_subtype == MemorySubtype.EPISODIC
assert ctx.memory_id == str(episode.id)
assert ctx.content == "Implemented login feature"
assert ctx.task_type == "feature_implementation"
assert ctx.outcome == "success"
assert ctx.importance == 0.9
def test_handles_missing_outcome(self) -> None:
"""Should handle episodes with no outcome."""
episode = MagicMock()
episode.id = uuid4()
episode.task_description = "WIP task"
episode.outcome = None
episode.importance_score = 0.5
episode.occurred_at = None
ctx = MemoryContext.from_episodic_memory(episode)
assert ctx.outcome is None
class TestMemoryContextFromSemanticMemory:
"""Tests for MemoryContext.from_semantic_memory."""
def test_creates_semantic_memory_context(self) -> None:
"""Should create semantic memory context from fact."""
fact = MagicMock()
fact.id = uuid4()
fact.subject = "User"
fact.predicate = "prefers"
fact.object = "dark mode"
fact.confidence = 0.95
ctx = MemoryContext.from_semantic_memory(fact, query="user preferences")
assert ctx.memory_subtype == MemorySubtype.SEMANTIC
assert ctx.memory_id == str(fact.id)
assert ctx.content == "User prefers dark mode"
assert ctx.subject == "User"
assert ctx.predicate == "prefers"
assert ctx.object_value == "dark mode"
assert ctx.relevance_score == 0.95
class TestMemoryContextFromProceduralMemory:
"""Tests for MemoryContext.from_procedural_memory."""
def test_creates_procedural_memory_context(self) -> None:
"""Should create procedural memory context from procedure."""
procedure = MagicMock()
procedure.id = uuid4()
procedure.name = "Deploy to Production"
procedure.trigger_pattern = "When deploying to production"
procedure.steps = [
{"action": "run_tests"},
{"action": "build_docker"},
{"action": "deploy"},
]
procedure.success_rate = 0.85
procedure.success_count = 10
procedure.failure_count = 2
ctx = MemoryContext.from_procedural_memory(procedure, query="deploy")
assert ctx.memory_subtype == MemorySubtype.PROCEDURAL
assert ctx.memory_id == str(procedure.id)
assert "Deploy to Production" in ctx.content
assert "When deploying to production" in ctx.content
assert ctx.trigger == "When deploying to production"
assert ctx.success_rate == 0.85
assert ctx.metadata["steps_count"] == 3
assert ctx.metadata["execution_count"] == 12
class TestMemoryContextHelpers:
"""Tests for MemoryContext helper methods."""
def test_is_working_memory(self) -> None:
"""is_working_memory should return True for working memory."""
ctx = MemoryContext(
content="test",
source="test",
memory_subtype=MemorySubtype.WORKING,
)
assert ctx.is_working_memory() is True
assert ctx.is_episodic_memory() is False
def test_is_episodic_memory(self) -> None:
"""is_episodic_memory should return True for episodic memory."""
ctx = MemoryContext(
content="test",
source="test",
memory_subtype=MemorySubtype.EPISODIC,
)
assert ctx.is_episodic_memory() is True
assert ctx.is_semantic_memory() is False
def test_is_semantic_memory(self) -> None:
"""is_semantic_memory should return True for semantic memory."""
ctx = MemoryContext(
content="test",
source="test",
memory_subtype=MemorySubtype.SEMANTIC,
)
assert ctx.is_semantic_memory() is True
assert ctx.is_procedural_memory() is False
def test_is_procedural_memory(self) -> None:
"""is_procedural_memory should return True for procedural memory."""
ctx = MemoryContext(
content="test",
source="test",
memory_subtype=MemorySubtype.PROCEDURAL,
)
assert ctx.is_procedural_memory() is True
assert ctx.is_working_memory() is False
def test_get_formatted_source(self) -> None:
"""get_formatted_source should return formatted string."""
ctx = MemoryContext(
content="test",
source="episodic:12345678-1234-1234-1234-123456789012",
memory_subtype=MemorySubtype.EPISODIC,
memory_id="12345678-1234-1234-1234-123456789012",
)
formatted = ctx.get_formatted_source()
assert "[episodic]" in formatted
assert "12345678..." in formatted

View File

@@ -0,0 +1 @@
"""Tests for the Agent Memory System."""

View File

@@ -0,0 +1,2 @@
# tests/unit/services/memory/cache/__init__.py
"""Tests for memory caching layer."""

View File

@@ -0,0 +1,331 @@
# tests/unit/services/memory/cache/test_cache_manager.py
"""Tests for CacheManager."""
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from app.services.memory.cache.cache_manager import (
CacheManager,
CacheStats,
get_cache_manager,
reset_cache_manager,
)
from app.services.memory.cache.embedding_cache import EmbeddingCache
from app.services.memory.cache.hot_cache import HotMemoryCache
pytestmark = pytest.mark.asyncio(loop_scope="function")
@pytest.fixture(autouse=True)
def reset_singleton() -> None:
"""Reset singleton before each test."""
reset_cache_manager()
class TestCacheStats:
"""Tests for CacheStats."""
def test_to_dict(self) -> None:
"""Should convert to dictionary."""
from datetime import UTC, datetime
stats = CacheStats(
hot_cache={"hits": 10},
embedding_cache={"hits": 20},
overall_hit_rate=0.75,
last_cleanup=datetime.now(UTC),
cleanup_count=5,
)
result = stats.to_dict()
assert result["hot_cache"] == {"hits": 10}
assert result["overall_hit_rate"] == 0.75
assert result["cleanup_count"] == 5
assert result["last_cleanup"] is not None
class TestCacheManager:
"""Tests for CacheManager."""
@pytest.fixture
def manager(self) -> CacheManager:
"""Create a cache manager."""
return CacheManager()
def test_is_enabled(self, manager: CacheManager) -> None:
"""Should check if caching is enabled."""
# Default is enabled from settings
assert manager.is_enabled is True
def test_has_hot_cache(self, manager: CacheManager) -> None:
"""Should have hot memory cache."""
assert manager.hot_cache is not None
assert isinstance(manager.hot_cache, HotMemoryCache)
def test_has_embedding_cache(self, manager: CacheManager) -> None:
"""Should have embedding cache."""
assert manager.embedding_cache is not None
assert isinstance(manager.embedding_cache, EmbeddingCache)
def test_cache_memory(self, manager: CacheManager) -> None:
"""Should cache memory in hot cache."""
memory_id = uuid4()
memory = {"task": "test", "data": "value"}
manager.cache_memory("episodic", memory_id, memory)
result = manager.get_memory("episodic", memory_id)
assert result == memory
def test_cache_memory_with_scope(self, manager: CacheManager) -> None:
"""Should cache memory with scope."""
memory_id = uuid4()
memory = {"task": "test"}
manager.cache_memory("semantic", memory_id, memory, scope="proj-123")
result = manager.get_memory("semantic", memory_id, scope="proj-123")
assert result == memory
async def test_cache_embedding(self, manager: CacheManager) -> None:
"""Should cache embedding."""
content = "test content"
embedding = [0.1, 0.2, 0.3]
content_hash = await manager.cache_embedding(content, embedding)
result = await manager.get_embedding(content)
assert result == embedding
assert len(content_hash) == 32
async def test_invalidate_memory(self, manager: CacheManager) -> None:
"""Should invalidate memory from hot cache."""
memory_id = uuid4()
manager.cache_memory("episodic", memory_id, {"data": "test"})
count = await manager.invalidate_memory("episodic", memory_id)
assert count >= 1
assert manager.get_memory("episodic", memory_id) is None
async def test_invalidate_by_type(self, manager: CacheManager) -> None:
"""Should invalidate all entries of a type."""
manager.cache_memory("episodic", uuid4(), {"data": "1"})
manager.cache_memory("episodic", uuid4(), {"data": "2"})
manager.cache_memory("semantic", uuid4(), {"data": "3"})
count = await manager.invalidate_by_type("episodic")
assert count >= 2
async def test_invalidate_by_scope(self, manager: CacheManager) -> None:
"""Should invalidate all entries in a scope."""
manager.cache_memory("episodic", uuid4(), {"data": "1"}, scope="proj-1")
manager.cache_memory("semantic", uuid4(), {"data": "2"}, scope="proj-1")
manager.cache_memory("episodic", uuid4(), {"data": "3"}, scope="proj-2")
count = await manager.invalidate_by_scope("proj-1")
assert count >= 2
async def test_invalidate_embedding(self, manager: CacheManager) -> None:
"""Should invalidate cached embedding."""
content = "test content"
await manager.cache_embedding(content, [0.1, 0.2])
result = await manager.invalidate_embedding(content)
assert result is True
assert await manager.get_embedding(content) is None
async def test_clear_all(self, manager: CacheManager) -> None:
"""Should clear all caches."""
manager.cache_memory("episodic", uuid4(), {"data": "test"})
await manager.cache_embedding("content", [0.1])
count = await manager.clear_all()
assert count >= 2
async def test_cleanup_expired(self, manager: CacheManager) -> None:
"""Should clean up expired entries."""
count = await manager.cleanup_expired()
# May be 0 if no expired entries
assert count >= 0
assert manager._cleanup_count == 1
assert manager._last_cleanup is not None
def test_get_stats(self, manager: CacheManager) -> None:
"""Should return aggregated statistics."""
manager.cache_memory("episodic", uuid4(), {"data": "test"})
stats = manager.get_stats()
assert "hot_cache" in stats.to_dict()
assert "embedding_cache" in stats.to_dict()
assert "overall_hit_rate" in stats.to_dict()
def test_get_hot_memories(self, manager: CacheManager) -> None:
"""Should return most accessed memories."""
id1 = uuid4()
id2 = uuid4()
manager.cache_memory("episodic", id1, {"data": "1"})
manager.cache_memory("episodic", id2, {"data": "2"})
# Access first multiple times
for _ in range(5):
manager.get_memory("episodic", id1)
hot = manager.get_hot_memories(limit=2)
assert len(hot) == 2
def test_reset_stats(self, manager: CacheManager) -> None:
"""Should reset all statistics."""
manager.cache_memory("episodic", uuid4(), {"data": "test"})
manager.get_memory("episodic", uuid4()) # Miss
manager.reset_stats()
stats = manager.get_stats()
assert stats.hot_cache.get("hits", 0) == 0
async def test_warmup(self, manager: CacheManager) -> None:
"""Should warm up cache with memories."""
memories = [
("episodic", uuid4(), {"data": "1"}),
("episodic", uuid4(), {"data": "2"}),
("semantic", uuid4(), {"data": "3"}),
]
count = await manager.warmup(memories)
assert count == 3
class TestCacheManagerWithRetrieval:
"""Tests for CacheManager with retrieval cache."""
@pytest.fixture
def mock_retrieval_cache(self) -> MagicMock:
"""Create mock retrieval cache."""
cache = MagicMock()
cache.invalidate_by_memory = MagicMock(return_value=1)
cache.clear = MagicMock(return_value=5)
cache.get_stats = MagicMock(return_value={"entries": 10})
return cache
@pytest.fixture
def manager_with_retrieval(
self,
mock_retrieval_cache: MagicMock,
) -> CacheManager:
"""Create manager with retrieval cache."""
manager = CacheManager()
manager.set_retrieval_cache(mock_retrieval_cache)
return manager
async def test_invalidate_clears_retrieval(
self,
manager_with_retrieval: CacheManager,
mock_retrieval_cache: MagicMock,
) -> None:
"""Should invalidate retrieval cache entries."""
memory_id = uuid4()
await manager_with_retrieval.invalidate_memory("episodic", memory_id)
mock_retrieval_cache.invalidate_by_memory.assert_called_once_with(memory_id)
def test_stats_includes_retrieval(
self,
manager_with_retrieval: CacheManager,
) -> None:
"""Should include retrieval cache stats."""
stats = manager_with_retrieval.get_stats()
assert "retrieval_cache" in stats.to_dict()
class TestCacheManagerDisabled:
"""Tests for CacheManager when disabled."""
@pytest.fixture
def disabled_manager(self) -> CacheManager:
"""Create a disabled cache manager."""
with patch(
"app.services.memory.cache.cache_manager.get_memory_settings"
) as mock_settings:
settings = MagicMock()
settings.cache_enabled = False
settings.cache_max_items = 1000
settings.cache_ttl_seconds = 300
mock_settings.return_value = settings
return CacheManager()
def test_get_memory_returns_none(self, disabled_manager: CacheManager) -> None:
"""Should return None when disabled."""
disabled_manager.cache_memory("episodic", uuid4(), {"data": "test"})
result = disabled_manager.get_memory("episodic", uuid4())
assert result is None
async def test_get_embedding_returns_none(
self,
disabled_manager: CacheManager,
) -> None:
"""Should return None for embeddings when disabled."""
result = await disabled_manager.get_embedding("content")
assert result is None
async def test_warmup_returns_zero(self, disabled_manager: CacheManager) -> None:
"""Should return 0 from warmup when disabled."""
count = await disabled_manager.warmup([("episodic", uuid4(), {})])
assert count == 0
class TestGetCacheManager:
"""Tests for get_cache_manager factory."""
def test_returns_singleton(self) -> None:
"""Should return same instance."""
manager1 = get_cache_manager()
manager2 = get_cache_manager()
assert manager1 is manager2
def test_reset_creates_new(self) -> None:
"""Should create new instance after reset."""
manager1 = get_cache_manager()
reset_cache_manager()
manager2 = get_cache_manager()
assert manager1 is not manager2
def test_reset_parameter(self) -> None:
"""Should create new instance with reset=True."""
manager1 = get_cache_manager()
manager2 = get_cache_manager(reset=True)
assert manager1 is not manager2
class TestResetCacheManager:
"""Tests for reset_cache_manager."""
def test_resets_singleton(self) -> None:
"""Should reset the singleton."""
get_cache_manager()
reset_cache_manager()
# Next call should create new instance
manager = get_cache_manager()
assert manager is not None

View File

@@ -0,0 +1,391 @@
# tests/unit/services/memory/cache/test_embedding_cache.py
"""Tests for EmbeddingCache."""
import time
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.services.memory.cache.embedding_cache import (
CachedEmbeddingGenerator,
EmbeddingCache,
EmbeddingCacheStats,
EmbeddingEntry,
create_embedding_cache,
)
pytestmark = pytest.mark.asyncio(loop_scope="function")
class TestEmbeddingEntry:
"""Tests for EmbeddingEntry."""
def test_creates_entry(self) -> None:
"""Should create entry with embedding."""
from datetime import UTC, datetime
entry = EmbeddingEntry(
embedding=[0.1, 0.2, 0.3],
content_hash="abc123",
model="text-embedding-3-small",
created_at=datetime.now(UTC),
)
assert entry.embedding == [0.1, 0.2, 0.3]
assert entry.content_hash == "abc123"
assert entry.ttl_seconds == 3600.0
def test_is_expired(self) -> None:
"""Should detect expired entries."""
from datetime import UTC, datetime, timedelta
old_time = datetime.now(UTC) - timedelta(seconds=4000)
entry = EmbeddingEntry(
embedding=[0.1],
content_hash="abc",
model="default",
created_at=old_time,
ttl_seconds=3600.0,
)
assert entry.is_expired() is True
def test_not_expired(self) -> None:
"""Should detect non-expired entries."""
from datetime import UTC, datetime
entry = EmbeddingEntry(
embedding=[0.1],
content_hash="abc",
model="default",
created_at=datetime.now(UTC),
)
assert entry.is_expired() is False
class TestEmbeddingCacheStats:
"""Tests for EmbeddingCacheStats."""
def test_hit_rate_calculation(self) -> None:
"""Should calculate hit rate correctly."""
stats = EmbeddingCacheStats(hits=90, misses=10)
assert stats.hit_rate == 0.9
def test_hit_rate_zero_requests(self) -> None:
"""Should return 0 for no requests."""
stats = EmbeddingCacheStats()
assert stats.hit_rate == 0.0
def test_to_dict(self) -> None:
"""Should convert to dictionary."""
stats = EmbeddingCacheStats(hits=10, misses=5, bytes_saved=1000)
result = stats.to_dict()
assert result["hits"] == 10
assert result["bytes_saved"] == 1000
class TestEmbeddingCache:
"""Tests for EmbeddingCache."""
@pytest.fixture
def cache(self) -> EmbeddingCache:
"""Create an embedding cache."""
return EmbeddingCache(max_size=100, default_ttl_seconds=300.0)
async def test_put_and_get(self, cache: EmbeddingCache) -> None:
"""Should store and retrieve embeddings."""
content = "Hello world"
embedding = [0.1, 0.2, 0.3, 0.4]
content_hash = await cache.put(content, embedding)
result = await cache.get(content)
assert result == embedding
assert len(content_hash) == 32
async def test_get_missing(self, cache: EmbeddingCache) -> None:
"""Should return None for missing content."""
result = await cache.get("nonexistent content")
assert result is None
async def test_get_by_hash(self, cache: EmbeddingCache) -> None:
"""Should get by content hash."""
content = "Test content"
embedding = [0.1, 0.2]
content_hash = await cache.put(content, embedding)
result = await cache.get_by_hash(content_hash)
assert result == embedding
async def test_model_separation(self, cache: EmbeddingCache) -> None:
"""Should separate embeddings by model."""
content = "Same content"
emb1 = [0.1, 0.2]
emb2 = [0.3, 0.4]
await cache.put(content, emb1, model="model-a")
await cache.put(content, emb2, model="model-b")
result1 = await cache.get(content, model="model-a")
result2 = await cache.get(content, model="model-b")
assert result1 == emb1
assert result2 == emb2
async def test_lru_eviction(self) -> None:
"""Should evict LRU entries when at capacity."""
cache = EmbeddingCache(max_size=3)
await cache.put("content1", [0.1])
await cache.put("content2", [0.2])
await cache.put("content3", [0.3])
# Access first to make it recent
await cache.get("content1")
# Add fourth, should evict second (LRU)
await cache.put("content4", [0.4])
assert await cache.get("content1") is not None
assert await cache.get("content2") is None # Evicted
assert await cache.get("content3") is not None
assert await cache.get("content4") is not None
async def test_ttl_expiration(self) -> None:
"""Should expire entries after TTL."""
cache = EmbeddingCache(max_size=100, default_ttl_seconds=0.05)
await cache.put("content", [0.1, 0.2])
time.sleep(0.06)
result = await cache.get("content")
assert result is None
async def test_put_batch(self, cache: EmbeddingCache) -> None:
"""Should cache multiple embeddings."""
items = [
("content1", [0.1]),
("content2", [0.2]),
("content3", [0.3]),
]
hashes = await cache.put_batch(items)
assert len(hashes) == 3
assert await cache.get("content1") == [0.1]
assert await cache.get("content2") == [0.2]
async def test_invalidate(self, cache: EmbeddingCache) -> None:
"""Should invalidate cached embedding."""
await cache.put("content", [0.1, 0.2])
result = await cache.invalidate("content")
assert result is True
assert await cache.get("content") is None
async def test_invalidate_by_hash(self, cache: EmbeddingCache) -> None:
"""Should invalidate by hash."""
content_hash = await cache.put("content", [0.1, 0.2])
result = await cache.invalidate_by_hash(content_hash)
assert result is True
assert await cache.get("content") is None
async def test_invalidate_by_model(self, cache: EmbeddingCache) -> None:
"""Should invalidate all embeddings for a model."""
await cache.put("content1", [0.1], model="model-a")
await cache.put("content2", [0.2], model="model-a")
await cache.put("content3", [0.3], model="model-b")
count = await cache.invalidate_by_model("model-a")
assert count == 2
assert await cache.get("content1", model="model-a") is None
assert await cache.get("content3", model="model-b") is not None
async def test_clear(self, cache: EmbeddingCache) -> None:
"""Should clear all entries."""
await cache.put("content1", [0.1])
await cache.put("content2", [0.2])
count = await cache.clear()
assert count == 2
assert cache.size == 0
def test_cleanup_expired(self) -> None:
"""Should remove expired entries."""
cache = EmbeddingCache(max_size=100, default_ttl_seconds=0.05)
# Use synchronous put for setup
cache._put_memory("hash1", "default", [0.1])
cache._put_memory("hash2", "default", [0.2], ttl_seconds=10)
time.sleep(0.06)
count = cache.cleanup_expired()
assert count == 1
def test_get_stats(self, cache: EmbeddingCache) -> None:
"""Should return accurate statistics."""
# Put synchronously for setup
cache._put_memory("hash1", "default", [0.1])
stats = cache.get_stats()
assert stats.current_size == 1
def test_hash_content(self) -> None:
"""Should produce consistent hashes."""
hash1 = EmbeddingCache.hash_content("test content")
hash2 = EmbeddingCache.hash_content("test content")
hash3 = EmbeddingCache.hash_content("different content")
assert hash1 == hash2
assert hash1 != hash3
assert len(hash1) == 32
class TestEmbeddingCacheWithRedis:
"""Tests for EmbeddingCache with Redis."""
@pytest.fixture
def mock_redis(self) -> MagicMock:
"""Create mock Redis."""
redis = MagicMock()
redis.get = AsyncMock(return_value=None)
redis.setex = AsyncMock()
redis.delete = AsyncMock()
redis.scan_iter = MagicMock(return_value=iter([]))
return redis
@pytest.fixture
def cache_with_redis(self, mock_redis: MagicMock) -> EmbeddingCache:
"""Create cache with mock Redis."""
return EmbeddingCache(
max_size=100,
default_ttl_seconds=300.0,
redis=mock_redis,
)
async def test_put_stores_in_redis(
self,
cache_with_redis: EmbeddingCache,
mock_redis: MagicMock,
) -> None:
"""Should store in Redis when available."""
await cache_with_redis.put("content", [0.1, 0.2])
mock_redis.setex.assert_called_once()
async def test_get_checks_redis_on_miss(
self,
cache_with_redis: EmbeddingCache,
mock_redis: MagicMock,
) -> None:
"""Should check Redis when memory cache misses."""
import json
mock_redis.get.return_value = json.dumps([0.1, 0.2])
result = await cache_with_redis.get("content")
assert result == [0.1, 0.2]
mock_redis.get.assert_called_once()
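# Sketch only (not part of this module, names invented): the read-through pattern the
# CachedEmbeddingGenerator tests below assume -- check the cache first, call the
# underlying generator only on a miss, and store the fresh embedding for next time.
class _ReadThroughSketch:
    def __init__(self, generator, cache) -> None:
        self._generator = generator
        self._cache = cache
        self.call_count = 0
        self.cache_hit_count = 0

    async def generate(self, text: str) -> list[float]:
        self.call_count += 1
        cached = await self._cache.get(text)
        if cached is not None:
            self.cache_hit_count += 1  # repeated text never reaches the model again
            return cached
        embedding = await self._generator.generate(text)  # model invoked only on a miss
        await self._cache.put(text, embedding)
        return embedding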
class TestCachedEmbeddingGenerator:
"""Tests for CachedEmbeddingGenerator."""
@pytest.fixture
def mock_generator(self) -> MagicMock:
"""Create mock embedding generator."""
gen = MagicMock()
gen.generate = AsyncMock(return_value=[0.1, 0.2, 0.3])
gen.generate_batch = AsyncMock(return_value=[[0.1], [0.2], [0.3]])
return gen
@pytest.fixture
def cache(self) -> EmbeddingCache:
"""Create embedding cache."""
return EmbeddingCache(max_size=100)
@pytest.fixture
def cached_gen(
self,
mock_generator: MagicMock,
cache: EmbeddingCache,
) -> CachedEmbeddingGenerator:
"""Create cached generator."""
return CachedEmbeddingGenerator(mock_generator, cache)
async def test_generate_caches_result(
self,
cached_gen: CachedEmbeddingGenerator,
mock_generator: MagicMock,
) -> None:
"""Should cache generated embedding."""
result1 = await cached_gen.generate("test text")
result2 = await cached_gen.generate("test text")
assert result1 == [0.1, 0.2, 0.3]
assert result2 == [0.1, 0.2, 0.3]
mock_generator.generate.assert_called_once() # Only called once
async def test_generate_batch_uses_cache(
self,
cached_gen: CachedEmbeddingGenerator,
mock_generator: MagicMock,
cache: EmbeddingCache,
) -> None:
"""Should use cache for batch generation."""
# Pre-cache one embedding
await cache.put("text1", [0.5])
# Mock returns 2 embeddings for the 2 uncached texts
mock_generator.generate_batch = AsyncMock(return_value=[[0.2], [0.3]])
results = await cached_gen.generate_batch(["text1", "text2", "text3"])
assert len(results) == 3
assert results[0] == [0.5] # From cache
assert results[1] == [0.2] # Generated
assert results[2] == [0.3] # Generated
async def test_get_stats(self, cached_gen: CachedEmbeddingGenerator) -> None:
"""Should return generator statistics."""
await cached_gen.generate("text1")
await cached_gen.generate("text1") # Cache hit
stats = cached_gen.get_stats()
assert stats["call_count"] == 2
assert stats["cache_hit_count"] == 1
class TestCreateEmbeddingCache:
"""Tests for factory function."""
def test_creates_cache(self) -> None:
"""Should create cache with defaults."""
cache = create_embedding_cache()
assert cache.max_size == 50000
def test_creates_cache_with_options(self) -> None:
"""Should create cache with custom options."""
cache = create_embedding_cache(max_size=1000, default_ttl_seconds=600.0)
assert cache.max_size == 1000
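The tests above pin down the embedding cache's contract: content is keyed by a 32-character content hash per model, reads refresh recency, the least recently used entry is evicted at capacity, and entries expire after a TTL. As a rough mental model only (synchronous, single-tier, and with every name below invented; the real EmbeddingCache is async and can sit in front of Redis), that behaviour looks roughly like:

import hashlib
import time
from collections import OrderedDict

class SketchLRUCache:
    """Illustrative LRU + TTL map keyed by (model, content hash); not the project's code."""

    def __init__(self, max_size: int = 100, default_ttl_seconds: float = 300.0) -> None:
        self.max_size = max_size
        self.default_ttl_seconds = default_ttl_seconds
        self._store: OrderedDict[tuple[str, str], tuple[list[float], float]] = OrderedDict()

    @staticmethod
    def hash_content(content: str) -> str:
        return hashlib.md5(content.encode("utf-8")).hexdigest()  # 32 hex chars

    def put(self, content: str, embedding: list[float], model: str = "default") -> str:
        key = (model, self.hash_content(content))
        if key in self._store:
            self._store.pop(key)  # re-insert to refresh recency
        elif len(self._store) >= self.max_size:
            self._store.popitem(last=False)  # evict the least recently used entry
        self._store[key] = (embedding, time.monotonic() + self.default_ttl_seconds)
        return key[1]

    def get(self, content: str, model: str = "default") -> list[float] | None:
        key = (model, self.hash_content(content))
        entry = self._store.get(key)
        if entry is None:
            return None
        embedding, expires_at = entry
        if time.monotonic() > expires_at:
            del self._store[key]  # lazily drop expired entries
            return None
        self._store.move_to_end(key)  # a hit makes the entry most recent
        return embedding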

View File

@@ -0,0 +1,355 @@
# tests/unit/services/memory/cache/test_hot_cache.py
"""Tests for HotMemoryCache."""
import time
from uuid import uuid4
import pytest
from app.services.memory.cache.hot_cache import (
CacheEntry,
CacheKey,
HotCacheStats,
HotMemoryCache,
create_hot_cache,
)
class TestCacheKey:
"""Tests for CacheKey."""
def test_creates_key(self) -> None:
"""Should create key with required fields."""
key = CacheKey(memory_type="episodic", memory_id="123")
assert key.memory_type == "episodic"
assert key.memory_id == "123"
assert key.scope is None
def test_creates_key_with_scope(self) -> None:
"""Should create key with scope."""
key = CacheKey(memory_type="semantic", memory_id="456", scope="proj-123")
assert key.scope == "proj-123"
def test_hash_and_equality(self) -> None:
"""Keys with same values should be equal and have same hash."""
key1 = CacheKey(memory_type="episodic", memory_id="123", scope="proj-1")
key2 = CacheKey(memory_type="episodic", memory_id="123", scope="proj-1")
assert key1 == key2
assert hash(key1) == hash(key2)
def test_str_representation(self) -> None:
"""Should produce readable string."""
key = CacheKey(memory_type="episodic", memory_id="123", scope="proj-1")
assert str(key) == "episodic:proj-1:123"
def test_str_without_scope(self) -> None:
"""Should produce string without scope."""
key = CacheKey(memory_type="episodic", memory_id="123")
assert str(key) == "episodic:123"
class TestCacheEntry:
"""Tests for CacheEntry."""
    def test_creates_entry(self) -> None:
        """Should create entry with value."""
        from datetime import UTC, datetime
        now = datetime.now(UTC)
        entry = CacheEntry(
            value={"data": "test"},
            created_at=now,
            last_accessed_at=now,
        )
        assert entry.value == {"data": "test"}
        assert entry.access_count == 1
        assert entry.ttl_seconds == 300.0
def test_is_expired(self) -> None:
"""Should detect expired entries."""
from datetime import UTC, datetime, timedelta
old_time = datetime.now(UTC) - timedelta(seconds=400)
entry = CacheEntry(
value="test",
created_at=old_time,
last_accessed_at=old_time,
ttl_seconds=300.0,
)
assert entry.is_expired() is True
def test_not_expired(self) -> None:
"""Should detect non-expired entries."""
from datetime import UTC, datetime
entry = CacheEntry(
value="test",
created_at=datetime.now(UTC),
last_accessed_at=datetime.now(UTC),
ttl_seconds=300.0,
)
assert entry.is_expired() is False
def test_touch_updates_access(self) -> None:
"""Touch should update access time and count."""
from datetime import UTC, datetime, timedelta
old_time = datetime.now(UTC) - timedelta(seconds=10)
entry = CacheEntry(
value="test",
created_at=old_time,
last_accessed_at=old_time,
access_count=5,
)
entry.touch()
assert entry.access_count == 6
assert entry.last_accessed_at > old_time
class TestHotCacheStats:
"""Tests for HotCacheStats."""
def test_hit_rate_calculation(self) -> None:
"""Should calculate hit rate correctly."""
stats = HotCacheStats(hits=80, misses=20)
assert stats.hit_rate == 0.8
def test_hit_rate_zero_requests(self) -> None:
"""Should return 0 for no requests."""
stats = HotCacheStats()
assert stats.hit_rate == 0.0
def test_to_dict(self) -> None:
"""Should convert to dictionary."""
stats = HotCacheStats(hits=10, misses=5, evictions=2)
result = stats.to_dict()
assert result["hits"] == 10
assert result["misses"] == 5
assert result["evictions"] == 2
assert "hit_rate" in result
class TestHotMemoryCache:
"""Tests for HotMemoryCache."""
@pytest.fixture
def cache(self) -> HotMemoryCache[dict]:
"""Create a hot memory cache."""
return HotMemoryCache[dict](max_size=100, default_ttl_seconds=300.0)
def test_put_and_get(self, cache: HotMemoryCache[dict]) -> None:
"""Should store and retrieve values."""
key = CacheKey(memory_type="episodic", memory_id="123")
value = {"data": "test"}
cache.put(key, value)
result = cache.get(key)
assert result == value
def test_get_missing_key(self, cache: HotMemoryCache[dict]) -> None:
"""Should return None for missing keys."""
key = CacheKey(memory_type="episodic", memory_id="nonexistent")
result = cache.get(key)
assert result is None
def test_put_by_id(self, cache: HotMemoryCache[dict]) -> None:
"""Should store by type and ID."""
memory_id = uuid4()
value = {"data": "test"}
cache.put_by_id("episodic", memory_id, value)
result = cache.get_by_id("episodic", memory_id)
assert result == value
def test_put_by_id_with_scope(self, cache: HotMemoryCache[dict]) -> None:
"""Should store with scope."""
memory_id = uuid4()
value = {"data": "test"}
cache.put_by_id("semantic", memory_id, value, scope="proj-123")
result = cache.get_by_id("semantic", memory_id, scope="proj-123")
assert result == value
def test_lru_eviction(self) -> None:
"""Should evict LRU entries when at capacity."""
cache = HotMemoryCache[str](max_size=3)
# Fill cache
cache.put_by_id("test", "1", "first")
cache.put_by_id("test", "2", "second")
cache.put_by_id("test", "3", "third")
# Access first to make it recent
cache.get_by_id("test", "1")
# Add fourth, should evict second (LRU)
cache.put_by_id("test", "4", "fourth")
assert cache.get_by_id("test", "1") is not None # Accessed, kept
assert cache.get_by_id("test", "2") is None # Evicted (LRU)
assert cache.get_by_id("test", "3") is not None
assert cache.get_by_id("test", "4") is not None
def test_ttl_expiration(self) -> None:
"""Should expire entries after TTL."""
cache = HotMemoryCache[str](max_size=100, default_ttl_seconds=0.05)
cache.put_by_id("test", "1", "value")
# Wait for expiration
time.sleep(0.06)
result = cache.get_by_id("test", "1")
assert result is None
def test_invalidate(self, cache: HotMemoryCache[dict]) -> None:
"""Should invalidate specific entry."""
key = CacheKey(memory_type="episodic", memory_id="123")
cache.put(key, {"data": "test"})
result = cache.invalidate(key)
assert result is True
assert cache.get(key) is None
def test_invalidate_by_id(self, cache: HotMemoryCache[dict]) -> None:
"""Should invalidate by ID."""
memory_id = uuid4()
cache.put_by_id("episodic", memory_id, {"data": "test"})
result = cache.invalidate_by_id("episodic", memory_id)
assert result is True
assert cache.get_by_id("episodic", memory_id) is None
def test_invalidate_by_type(self, cache: HotMemoryCache[dict]) -> None:
"""Should invalidate all entries of a type."""
cache.put_by_id("episodic", "1", {"data": "1"})
cache.put_by_id("episodic", "2", {"data": "2"})
cache.put_by_id("semantic", "3", {"data": "3"})
count = cache.invalidate_by_type("episodic")
assert count == 2
assert cache.get_by_id("episodic", "1") is None
assert cache.get_by_id("episodic", "2") is None
assert cache.get_by_id("semantic", "3") is not None
def test_invalidate_by_scope(self, cache: HotMemoryCache[dict]) -> None:
"""Should invalidate all entries in a scope."""
cache.put_by_id("episodic", "1", {"data": "1"}, scope="proj-1")
cache.put_by_id("semantic", "2", {"data": "2"}, scope="proj-1")
cache.put_by_id("episodic", "3", {"data": "3"}, scope="proj-2")
count = cache.invalidate_by_scope("proj-1")
assert count == 2
assert cache.get_by_id("episodic", "3", scope="proj-2") is not None
def test_invalidate_pattern(self, cache: HotMemoryCache[dict]) -> None:
"""Should invalidate entries matching pattern."""
cache.put_by_id("episodic", "123", {"data": "1"})
cache.put_by_id("episodic", "124", {"data": "2"})
cache.put_by_id("semantic", "125", {"data": "3"})
count = cache.invalidate_pattern("episodic:*")
assert count == 2
def test_clear(self, cache: HotMemoryCache[dict]) -> None:
"""Should clear all entries."""
cache.put_by_id("episodic", "1", {"data": "1"})
cache.put_by_id("semantic", "2", {"data": "2"})
count = cache.clear()
assert count == 2
assert cache.size == 0
def test_cleanup_expired(self) -> None:
"""Should remove expired entries."""
cache = HotMemoryCache[str](max_size=100, default_ttl_seconds=0.05)
cache.put_by_id("test", "1", "value1")
cache.put_by_id("test", "2", "value2", ttl_seconds=10)
time.sleep(0.06)
count = cache.cleanup_expired()
assert count == 1 # Only the first one expired
assert cache.size == 1
def test_get_hot_memories(self, cache: HotMemoryCache[dict]) -> None:
"""Should return most accessed memories."""
cache.put_by_id("episodic", "1", {"data": "1"})
cache.put_by_id("episodic", "2", {"data": "2"})
# Access first one multiple times
for _ in range(5):
cache.get_by_id("episodic", "1")
hot = cache.get_hot_memories(limit=2)
assert len(hot) == 2
assert hot[0][1] >= hot[1][1] # Sorted by access count
def test_get_stats(self, cache: HotMemoryCache[dict]) -> None:
"""Should return accurate statistics."""
cache.put_by_id("episodic", "1", {"data": "1"})
cache.get_by_id("episodic", "1") # Hit
cache.get_by_id("episodic", "2") # Miss
stats = cache.get_stats()
assert stats.hits == 1
assert stats.misses == 1
assert stats.current_size == 1
def test_reset_stats(self, cache: HotMemoryCache[dict]) -> None:
"""Should reset statistics."""
cache.put_by_id("episodic", "1", {"data": "1"})
cache.get_by_id("episodic", "1")
cache.reset_stats()
stats = cache.get_stats()
assert stats.hits == 0
assert stats.misses == 0
class TestCreateHotCache:
"""Tests for factory function."""
def test_creates_cache(self) -> None:
"""Should create cache with defaults."""
cache = create_hot_cache()
assert cache.max_size == 10000
def test_creates_cache_with_options(self) -> None:
"""Should create cache with custom options."""
cache = create_hot_cache(max_size=500, default_ttl_seconds=60.0)
assert cache.max_size == 500
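The tests above also fix the key format ("type:scope:id", with the scope segment omitted when absent) and glob-style invalidation ("episodic:*" clears both episodic entries while leaving the semantic one). A minimal sketch of that matching, assuming fnmatch-style globbing and using hypothetical names (the real CacheKey / HotMemoryCache may implement it differently):

from dataclasses import dataclass
from fnmatch import fnmatch

@dataclass(frozen=True)
class SketchKey:
    memory_type: str
    memory_id: str
    scope: str | None = None

    def __str__(self) -> str:
        parts = [self.memory_type, self.scope, self.memory_id]
        return ":".join(p for p in parts if p is not None)

def invalidate_matching(entries: dict[SketchKey, object], pattern: str) -> int:
    """Drop every entry whose rendered key matches the glob pattern; return the count."""
    doomed = [key for key in entries if fnmatch(str(key), pattern)]
    for key in doomed:
        del entries[key]
    return len(doomed)

entries = {
    SketchKey("episodic", "123"): {"data": "1"},
    SketchKey("episodic", "124"): {"data": "2"},
    SketchKey("semantic", "125"): {"data": "3"},
}
assert invalidate_matching(entries, "episodic:*") == 2
assert list(entries) == [SketchKey("semantic", "125")]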

View File

@@ -0,0 +1,2 @@
# tests/unit/services/memory/consolidation/__init__.py
"""Tests for memory consolidation."""

View File

@@ -0,0 +1,736 @@
# tests/unit/services/memory/consolidation/test_service.py
"""Unit tests for memory consolidation service."""
from datetime import UTC, datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4
import pytest
from app.services.memory.consolidation.service import (
ConsolidationConfig,
ConsolidationResult,
MemoryConsolidationService,
NightlyConsolidationResult,
SessionConsolidationResult,
)
from app.services.memory.types import Episode, Outcome, TaskState
def _utcnow() -> datetime:
"""Get current UTC time."""
return datetime.now(UTC)
def make_episode(
outcome: Outcome = Outcome.SUCCESS,
occurred_at: datetime | None = None,
task_type: str = "test_task",
lessons_learned: list[str] | None = None,
importance_score: float = 0.5,
actions: list[dict] | None = None,
) -> Episode:
"""Create a test episode."""
return Episode(
id=uuid4(),
project_id=uuid4(),
agent_instance_id=uuid4(),
agent_type_id=uuid4(),
session_id="test-session",
task_type=task_type,
task_description="Test task description",
actions=actions or [{"action": "test"}],
context_summary="Test context",
outcome=outcome,
outcome_details="Test outcome",
duration_seconds=10.0,
tokens_used=100,
lessons_learned=lessons_learned or [],
importance_score=importance_score,
embedding=None,
occurred_at=occurred_at or _utcnow(),
created_at=_utcnow(),
updated_at=_utcnow(),
)
def make_task_state(
current_step: int = 5,
total_steps: int = 10,
progress_percent: float = 50.0,
status: str = "in_progress",
description: str = "Test Task",
) -> TaskState:
"""Create a test task state."""
now = _utcnow()
return TaskState(
task_id="test-task-id",
task_type="test_task",
description=description,
current_step=current_step,
total_steps=total_steps,
status=status,
progress_percent=progress_percent,
started_at=now - timedelta(hours=1),
updated_at=now,
)
class TestConsolidationConfig:
"""Tests for ConsolidationConfig."""
def test_default_values(self) -> None:
"""Test default configuration values."""
config = ConsolidationConfig()
assert config.min_steps_for_episode == 2
assert config.min_duration_seconds == 5.0
assert config.min_confidence_for_fact == 0.6
assert config.max_facts_per_episode == 10
assert config.min_episodes_for_procedure == 3
assert config.max_episode_age_days == 90
assert config.batch_size == 100
def test_custom_values(self) -> None:
"""Test custom configuration values."""
config = ConsolidationConfig(
min_steps_for_episode=5,
batch_size=50,
)
assert config.min_steps_for_episode == 5
assert config.batch_size == 50
class TestConsolidationResult:
"""Tests for ConsolidationResult."""
def test_creation(self) -> None:
"""Test creating a consolidation result."""
result = ConsolidationResult(
source_type="episodic",
target_type="semantic",
items_processed=10,
items_created=5,
)
assert result.source_type == "episodic"
assert result.target_type == "semantic"
assert result.items_processed == 10
assert result.items_created == 5
assert result.items_skipped == 0
assert result.errors == []
def test_to_dict(self) -> None:
"""Test converting to dictionary."""
result = ConsolidationResult(
source_type="episodic",
target_type="semantic",
items_processed=10,
items_created=5,
errors=["test error"],
)
d = result.to_dict()
assert d["source_type"] == "episodic"
assert d["target_type"] == "semantic"
assert d["items_processed"] == 10
assert d["items_created"] == 5
assert "test error" in d["errors"]
class TestSessionConsolidationResult:
"""Tests for SessionConsolidationResult."""
def test_creation(self) -> None:
"""Test creating a session consolidation result."""
result = SessionConsolidationResult(
session_id="test-session",
episode_created=True,
episode_id=uuid4(),
scratchpad_entries=5,
)
assert result.session_id == "test-session"
assert result.episode_created is True
assert result.episode_id is not None
class TestNightlyConsolidationResult:
"""Tests for NightlyConsolidationResult."""
def test_creation(self) -> None:
"""Test creating a nightly consolidation result."""
result = NightlyConsolidationResult(
started_at=_utcnow(),
)
assert result.started_at is not None
assert result.completed_at is None
assert result.total_episodes_processed == 0
def test_to_dict(self) -> None:
"""Test converting to dictionary."""
result = NightlyConsolidationResult(
started_at=_utcnow(),
completed_at=_utcnow(),
total_facts_created=5,
total_procedures_created=2,
)
d = result.to_dict()
assert "started_at" in d
assert "completed_at" in d
assert d["total_facts_created"] == 5
assert d["total_procedures_created"] == 2
class TestMemoryConsolidationService:
"""Tests for MemoryConsolidationService."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
return AsyncMock()
@pytest.fixture
def service(self, mock_session: AsyncMock) -> MemoryConsolidationService:
"""Create a consolidation service with mocked dependencies."""
return MemoryConsolidationService(
session=mock_session,
config=ConsolidationConfig(),
)
# =========================================================================
# Session Consolidation Tests
# =========================================================================
@pytest.mark.asyncio
async def test_consolidate_session_insufficient_steps(
self, service: MemoryConsolidationService
) -> None:
"""Test session not consolidated when insufficient steps."""
mock_working_memory = AsyncMock()
task_state = make_task_state(current_step=1) # Less than min_steps_for_episode
mock_working_memory.get_task_state.return_value = task_state
result = await service.consolidate_session(
working_memory=mock_working_memory,
project_id=uuid4(),
session_id="test-session",
)
assert result.episode_created is False
assert result.episode_id is None
@pytest.mark.asyncio
async def test_consolidate_session_no_task_state(
self, service: MemoryConsolidationService
) -> None:
"""Test session not consolidated when no task state."""
mock_working_memory = AsyncMock()
mock_working_memory.get_task_state.return_value = None
result = await service.consolidate_session(
working_memory=mock_working_memory,
project_id=uuid4(),
session_id="test-session",
)
assert result.episode_created is False
@pytest.mark.asyncio
async def test_consolidate_session_success(
self, service: MemoryConsolidationService, mock_session: AsyncMock
) -> None:
"""Test successful session consolidation."""
mock_working_memory = AsyncMock()
task_state = make_task_state(
current_step=5,
progress_percent=100.0,
status="complete",
)
mock_working_memory.get_task_state.return_value = task_state
mock_working_memory.get_scratchpad.return_value = ["step1", "step2"]
mock_working_memory.get_all.return_value = {"key1": "value1"}
# Mock episodic memory
mock_episode = make_episode()
with patch.object(
service, "_get_episodic", new_callable=AsyncMock
) as mock_get_episodic:
mock_episodic = AsyncMock()
mock_episodic.record_episode.return_value = mock_episode
mock_get_episodic.return_value = mock_episodic
result = await service.consolidate_session(
working_memory=mock_working_memory,
project_id=uuid4(),
session_id="test-session",
)
assert result.episode_created is True
assert result.episode_id == mock_episode.id
assert result.scratchpad_entries == 2
# =========================================================================
# Outcome Determination Tests
# =========================================================================
def test_determine_session_outcome_success(
self, service: MemoryConsolidationService
) -> None:
"""Test outcome determination for successful session."""
task_state = make_task_state(status="complete", progress_percent=100.0)
outcome = service._determine_session_outcome(task_state)
assert outcome == Outcome.SUCCESS
def test_determine_session_outcome_failure(
self, service: MemoryConsolidationService
) -> None:
"""Test outcome determination for failed session."""
task_state = make_task_state(status="error", progress_percent=25.0)
outcome = service._determine_session_outcome(task_state)
assert outcome == Outcome.FAILURE
def test_determine_session_outcome_partial(
self, service: MemoryConsolidationService
) -> None:
"""Test outcome determination for partial session."""
task_state = make_task_state(status="stopped", progress_percent=60.0)
outcome = service._determine_session_outcome(task_state)
assert outcome == Outcome.PARTIAL
def test_determine_session_outcome_none(
self, service: MemoryConsolidationService
) -> None:
"""Test outcome determination with no task state."""
outcome = service._determine_session_outcome(None)
assert outcome == Outcome.PARTIAL
# =========================================================================
# Action Building Tests
# =========================================================================
def test_build_actions_from_session(
self, service: MemoryConsolidationService
) -> None:
"""Test building actions from session data."""
scratchpad = ["thought 1", "thought 2"]
variables = {"var1": "value1"}
task_state = make_task_state()
actions = service._build_actions_from_session(scratchpad, variables, task_state)
assert len(actions) == 3 # 2 scratchpad + 1 final state
assert actions[0]["type"] == "reasoning"
assert actions[2]["type"] == "final_state"
def test_build_context_summary(self, service: MemoryConsolidationService) -> None:
"""Test building context summary."""
task_state = make_task_state(
description="Test Task",
progress_percent=75.0,
)
variables = {"key": "value"}
summary = service._build_context_summary(task_state, variables)
assert "Test Task" in summary
assert "75.0%" in summary
# =========================================================================
# Importance Calculation Tests
# =========================================================================
def test_calculate_session_importance_base(
self, service: MemoryConsolidationService
) -> None:
"""Test base importance calculation."""
task_state = make_task_state(total_steps=3) # Below threshold
importance = service._calculate_session_importance(
task_state, Outcome.SUCCESS, []
)
assert importance == 0.5 # Base score
def test_calculate_session_importance_failure(
self, service: MemoryConsolidationService
) -> None:
"""Test importance boost for failures."""
task_state = make_task_state(total_steps=3) # Below threshold
importance = service._calculate_session_importance(
task_state, Outcome.FAILURE, []
)
assert importance == 0.8 # Base (0.5) + failure boost (0.3)
def test_calculate_session_importance_complex(
self, service: MemoryConsolidationService
) -> None:
"""Test importance for complex session."""
task_state = make_task_state(total_steps=10)
actions = [{"step": i} for i in range(6)]
importance = service._calculate_session_importance(
task_state, Outcome.SUCCESS, actions
)
# Base (0.5) + many steps (0.1) + many actions (0.1)
assert importance == 0.7
# =========================================================================
# Episode to Fact Consolidation Tests
# =========================================================================
@pytest.mark.asyncio
async def test_consolidate_episodes_to_facts_empty(
self, service: MemoryConsolidationService
) -> None:
"""Test consolidation with no episodes."""
with patch.object(
service, "_get_episodic", new_callable=AsyncMock
) as mock_get_episodic:
mock_episodic = AsyncMock()
mock_episodic.get_recent.return_value = []
mock_get_episodic.return_value = mock_episodic
result = await service.consolidate_episodes_to_facts(
project_id=uuid4(),
)
assert result.items_processed == 0
assert result.items_created == 0
@pytest.mark.asyncio
async def test_consolidate_episodes_to_facts_success(
self, service: MemoryConsolidationService
) -> None:
"""Test successful fact extraction."""
episode = make_episode(
lessons_learned=["Always check return values"],
)
mock_fact = MagicMock()
mock_fact.reinforcement_count = 1 # New fact
with (
patch.object(
service, "_get_episodic", new_callable=AsyncMock
) as mock_get_episodic,
patch.object(
service, "_get_semantic", new_callable=AsyncMock
) as mock_get_semantic,
):
mock_episodic = AsyncMock()
mock_episodic.get_recent.return_value = [episode]
mock_get_episodic.return_value = mock_episodic
mock_semantic = AsyncMock()
mock_semantic.store_fact.return_value = mock_fact
mock_get_semantic.return_value = mock_semantic
result = await service.consolidate_episodes_to_facts(
project_id=uuid4(),
)
assert result.items_processed == 1
            # Fact extraction from the lesson may be filtered by confidence, so only assert a non-negative count
assert result.items_created >= 0
# =========================================================================
# Episode to Procedure Consolidation Tests
# =========================================================================
@pytest.mark.asyncio
async def test_consolidate_episodes_to_procedures_insufficient(
self, service: MemoryConsolidationService
) -> None:
"""Test consolidation with insufficient episodes."""
# Only 1 episode - less than min_episodes_for_procedure (3)
episode = make_episode()
with patch.object(
service, "_get_episodic", new_callable=AsyncMock
) as mock_get_episodic:
mock_episodic = AsyncMock()
mock_episodic.get_by_outcome.return_value = [episode]
mock_get_episodic.return_value = mock_episodic
result = await service.consolidate_episodes_to_procedures(
project_id=uuid4(),
)
assert result.items_processed == 1
assert result.items_created == 0
assert result.items_skipped == 1
@pytest.mark.asyncio
async def test_consolidate_episodes_to_procedures_success(
self, service: MemoryConsolidationService
) -> None:
"""Test successful procedure creation."""
# Create enough episodes for a procedure
episodes = [
make_episode(
task_type="deploy",
actions=[{"type": "step1"}, {"type": "step2"}, {"type": "step3"}],
)
for _ in range(5)
]
mock_procedure = MagicMock()
with (
patch.object(
service, "_get_episodic", new_callable=AsyncMock
) as mock_get_episodic,
patch.object(
service, "_get_procedural", new_callable=AsyncMock
) as mock_get_procedural,
):
mock_episodic = AsyncMock()
mock_episodic.get_by_outcome.return_value = episodes
mock_get_episodic.return_value = mock_episodic
mock_procedural = AsyncMock()
mock_procedural.find_matching.return_value = [] # No existing procedure
mock_procedural.record_procedure.return_value = mock_procedure
mock_get_procedural.return_value = mock_procedural
result = await service.consolidate_episodes_to_procedures(
project_id=uuid4(),
)
assert result.items_processed == 5
assert result.items_created == 1
# =========================================================================
# Common Steps Extraction Tests
# =========================================================================
def test_extract_common_steps(self, service: MemoryConsolidationService) -> None:
"""Test extracting steps from episodes."""
episodes = [
make_episode(
outcome=Outcome.SUCCESS,
importance_score=0.8,
actions=[
{"type": "step1", "content": "First step"},
{"type": "step2", "content": "Second step"},
],
),
make_episode(
outcome=Outcome.SUCCESS,
importance_score=0.5,
actions=[{"type": "simple"}],
),
]
steps = service._extract_common_steps(episodes)
assert len(steps) == 2
assert steps[0]["order"] == 1
assert steps[0]["action"] == "step1"
# =========================================================================
# Pruning Tests
# =========================================================================
def test_should_prune_episode_old_low_importance(
self, service: MemoryConsolidationService
) -> None:
"""Test pruning old, low-importance episode."""
old_date = _utcnow() - timedelta(days=100)
episode = make_episode(
occurred_at=old_date,
importance_score=0.1,
outcome=Outcome.SUCCESS,
)
cutoff = _utcnow() - timedelta(days=90)
should_prune = service._should_prune_episode(episode, cutoff, 0.2)
assert should_prune is True
def test_should_prune_episode_recent(
self, service: MemoryConsolidationService
) -> None:
"""Test not pruning recent episode."""
recent_date = _utcnow() - timedelta(days=30)
episode = make_episode(
occurred_at=recent_date,
importance_score=0.1,
)
cutoff = _utcnow() - timedelta(days=90)
should_prune = service._should_prune_episode(episode, cutoff, 0.2)
assert should_prune is False
def test_should_prune_episode_failure_protected(
self, service: MemoryConsolidationService
) -> None:
"""Test not pruning failure (with keep_all_failures=True)."""
old_date = _utcnow() - timedelta(days=100)
episode = make_episode(
occurred_at=old_date,
importance_score=0.1,
outcome=Outcome.FAILURE,
)
cutoff = _utcnow() - timedelta(days=90)
should_prune = service._should_prune_episode(episode, cutoff, 0.2)
# Config has keep_all_failures=True by default
assert should_prune is False
def test_should_prune_episode_with_lessons_protected(
self, service: MemoryConsolidationService
) -> None:
"""Test not pruning episode with lessons."""
old_date = _utcnow() - timedelta(days=100)
episode = make_episode(
occurred_at=old_date,
importance_score=0.1,
lessons_learned=["Important lesson"],
)
cutoff = _utcnow() - timedelta(days=90)
should_prune = service._should_prune_episode(episode, cutoff, 0.2)
# Config has keep_all_with_lessons=True by default
assert should_prune is False
def test_should_prune_episode_high_importance_protected(
self, service: MemoryConsolidationService
) -> None:
"""Test not pruning high importance episode."""
old_date = _utcnow() - timedelta(days=100)
episode = make_episode(
occurred_at=old_date,
importance_score=0.8,
)
cutoff = _utcnow() - timedelta(days=90)
should_prune = service._should_prune_episode(episode, cutoff, 0.2)
assert should_prune is False
@pytest.mark.asyncio
async def test_prune_old_episodes(
self, service: MemoryConsolidationService
) -> None:
"""Test episode pruning."""
old_episode = make_episode(
occurred_at=_utcnow() - timedelta(days=100),
importance_score=0.1,
outcome=Outcome.SUCCESS,
lessons_learned=[],
)
with patch.object(
service, "_get_episodic", new_callable=AsyncMock
) as mock_get_episodic:
mock_episodic = AsyncMock()
mock_episodic.get_recent.return_value = [old_episode]
mock_episodic.delete.return_value = True
mock_get_episodic.return_value = mock_episodic
result = await service.prune_old_episodes(project_id=uuid4())
assert result.items_processed == 1
assert result.items_pruned == 1
# =========================================================================
# Nightly Consolidation Tests
# =========================================================================
@pytest.mark.asyncio
async def test_run_nightly_consolidation(
self, service: MemoryConsolidationService
) -> None:
"""Test nightly consolidation workflow."""
with (
patch.object(
service,
"consolidate_episodes_to_facts",
new_callable=AsyncMock,
) as mock_facts,
patch.object(
service,
"consolidate_episodes_to_procedures",
new_callable=AsyncMock,
) as mock_procedures,
patch.object(
service,
"prune_old_episodes",
new_callable=AsyncMock,
) as mock_prune,
):
mock_facts.return_value = ConsolidationResult(
source_type="episodic",
target_type="semantic",
items_processed=10,
items_created=5,
)
mock_procedures.return_value = ConsolidationResult(
source_type="episodic",
target_type="procedural",
items_processed=10,
items_created=2,
)
mock_prune.return_value = ConsolidationResult(
source_type="episodic",
target_type="pruned",
items_pruned=3,
)
result = await service.run_nightly_consolidation(project_id=uuid4())
assert result.completed_at is not None
assert result.total_facts_created == 5
assert result.total_procedures_created == 2
assert result.total_pruned == 3
assert result.total_episodes_processed == 20
@pytest.mark.asyncio
async def test_run_nightly_consolidation_with_errors(
self, service: MemoryConsolidationService
) -> None:
"""Test nightly consolidation handles errors."""
with (
patch.object(
service,
"consolidate_episodes_to_facts",
new_callable=AsyncMock,
) as mock_facts,
patch.object(
service,
"consolidate_episodes_to_procedures",
new_callable=AsyncMock,
) as mock_procedures,
patch.object(
service,
"prune_old_episodes",
new_callable=AsyncMock,
) as mock_prune,
):
mock_facts.return_value = ConsolidationResult(
source_type="episodic",
target_type="semantic",
errors=["fact error"],
)
mock_procedures.return_value = ConsolidationResult(
source_type="episodic",
target_type="procedural",
)
mock_prune.return_value = ConsolidationResult(
source_type="episodic",
target_type="pruned",
)
result = await service.run_nightly_consolidation(project_id=uuid4())
assert "fact error" in result.errors

View File

@@ -0,0 +1,2 @@
# tests/unit/services/memory/episodic/__init__.py
"""Unit tests for episodic memory service."""

View File

@@ -0,0 +1,359 @@
# tests/unit/services/memory/episodic/test_memory.py
"""Unit tests for EpisodicMemory class."""
from unittest.mock import AsyncMock, MagicMock
from uuid import uuid4
import pytest
from app.services.memory.episodic.memory import EpisodicMemory
from app.services.memory.episodic.retrieval import RetrievalStrategy
from app.services.memory.types import EpisodeCreate, Outcome, RetrievalResult
class TestEpisodicMemoryInit:
"""Tests for EpisodicMemory initialization."""
def test_init_creates_recorder_and_retriever(self) -> None:
"""Test that init creates recorder and retriever."""
mock_session = AsyncMock()
memory = EpisodicMemory(session=mock_session)
assert memory._recorder is not None
assert memory._retriever is not None
assert memory._session is mock_session
def test_init_with_embedding_generator(self) -> None:
"""Test init with embedding generator."""
mock_session = AsyncMock()
mock_embedding_gen = AsyncMock()
memory = EpisodicMemory(
session=mock_session, embedding_generator=mock_embedding_gen
)
assert memory._embedding_generator is mock_embedding_gen
@pytest.mark.asyncio
async def test_create_factory_method(self) -> None:
"""Test create factory method."""
mock_session = AsyncMock()
memory = await EpisodicMemory.create(session=mock_session)
assert memory is not None
assert memory._session is mock_session
class TestEpisodicMemoryRecording:
"""Tests for episode recording methods."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
session = AsyncMock()
session.add = MagicMock()
session.flush = AsyncMock()
session.refresh = AsyncMock()
return session
@pytest.fixture
def memory(self, mock_session: AsyncMock) -> EpisodicMemory:
"""Create an EpisodicMemory instance."""
return EpisodicMemory(session=mock_session)
@pytest.mark.asyncio
async def test_record_episode(
self,
memory: EpisodicMemory,
) -> None:
"""Test recording an episode."""
episode_data = EpisodeCreate(
project_id=uuid4(),
session_id="test-session",
task_type="test_task",
task_description="Test description",
actions=[{"action": "test"}],
context_summary="Test context",
outcome=Outcome.SUCCESS,
outcome_details="Success",
duration_seconds=30.0,
tokens_used=100,
)
result = await memory.record_episode(episode_data)
assert result.project_id == episode_data.project_id
assert result.task_type == "test_task"
assert result.outcome == Outcome.SUCCESS
@pytest.mark.asyncio
async def test_record_success(
self,
memory: EpisodicMemory,
) -> None:
"""Test convenience method for recording success."""
project_id = uuid4()
result = await memory.record_success(
project_id=project_id,
session_id="test-session",
task_type="deployment",
task_description="Deploy to production",
actions=[{"step": "deploy"}],
context_summary="Deploying v1.0",
outcome_details="Deployed successfully",
duration_seconds=60.0,
tokens_used=200,
)
assert result.outcome == Outcome.SUCCESS
assert result.task_type == "deployment"
@pytest.mark.asyncio
async def test_record_failure(
self,
memory: EpisodicMemory,
) -> None:
"""Test convenience method for recording failure."""
project_id = uuid4()
result = await memory.record_failure(
project_id=project_id,
session_id="test-session",
task_type="deployment",
task_description="Deploy to production",
actions=[{"step": "deploy"}],
context_summary="Deploying v1.0",
error_details="Connection timeout",
duration_seconds=30.0,
tokens_used=100,
)
assert result.outcome == Outcome.FAILURE
assert result.outcome_details == "Connection timeout"
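# Sketch only, with invented names: the delegation the two convenience-method tests
# above rely on -- record_success / record_failure pin the outcome and defer to
# record_episode, and a failure's error text is surfaced as outcome_details
# ("Connection timeout" in test_record_failure).
class _RecorderSketch:
    async def record_episode(self, **fields):
        return fields  # stand-in for the real build-and-persist path

    async def record_success(self, *, outcome_details: str, **fields):
        return await self.record_episode(
            outcome=Outcome.SUCCESS, outcome_details=outcome_details, **fields
        )

    async def record_failure(self, *, error_details: str, **fields):
        return await self.record_episode(
            outcome=Outcome.FAILURE, outcome_details=error_details, **fields
        )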
class TestEpisodicMemoryRetrieval:
"""Tests for episode retrieval methods."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
session = AsyncMock()
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = []
session.execute.return_value = mock_result
return session
@pytest.fixture
def memory(self, mock_session: AsyncMock) -> EpisodicMemory:
"""Create an EpisodicMemory instance."""
return EpisodicMemory(session=mock_session)
@pytest.mark.asyncio
async def test_search_similar(
self,
memory: EpisodicMemory,
) -> None:
"""Test semantic search."""
project_id = uuid4()
results = await memory.search_similar(project_id, "authentication bug")
assert isinstance(results, list)
@pytest.mark.asyncio
async def test_get_recent(
self,
memory: EpisodicMemory,
) -> None:
"""Test getting recent episodes."""
project_id = uuid4()
results = await memory.get_recent(project_id, limit=5)
assert isinstance(results, list)
@pytest.mark.asyncio
async def test_get_by_outcome(
self,
memory: EpisodicMemory,
) -> None:
"""Test getting episodes by outcome."""
project_id = uuid4()
results = await memory.get_by_outcome(project_id, Outcome.FAILURE, limit=5)
assert isinstance(results, list)
@pytest.mark.asyncio
async def test_get_by_task_type(
self,
memory: EpisodicMemory,
) -> None:
"""Test getting episodes by task type."""
project_id = uuid4()
results = await memory.get_by_task_type(project_id, "code_review", limit=5)
assert isinstance(results, list)
@pytest.mark.asyncio
async def test_get_important(
self,
memory: EpisodicMemory,
) -> None:
"""Test getting important episodes."""
project_id = uuid4()
results = await memory.get_important(project_id, limit=5, min_importance=0.8)
assert isinstance(results, list)
@pytest.mark.asyncio
async def test_retrieve_with_full_result(
self,
memory: EpisodicMemory,
) -> None:
"""Test retrieve with full result metadata."""
project_id = uuid4()
result = await memory.retrieve(project_id, RetrievalStrategy.RECENCY, limit=10)
assert isinstance(result, RetrievalResult)
assert result.retrieval_type == "recency"
class TestEpisodicMemorySummarization:
"""Tests for episode summarization."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
session = AsyncMock()
return session
@pytest.fixture
def memory(self, mock_session: AsyncMock) -> EpisodicMemory:
"""Create an EpisodicMemory instance."""
return EpisodicMemory(session=mock_session)
@pytest.mark.asyncio
async def test_summarize_empty_list(
self,
memory: EpisodicMemory,
) -> None:
"""Test summarizing empty episode list."""
summary = await memory.summarize_episodes([])
assert "No episodes to summarize" in summary
@pytest.mark.asyncio
async def test_summarize_not_found(
self,
memory: EpisodicMemory,
mock_session: AsyncMock,
) -> None:
"""Test summarizing when episodes not found."""
# Mock get_by_id to return None
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = None
mock_session.execute.return_value = mock_result
summary = await memory.summarize_episodes([uuid4(), uuid4()])
assert "No episodes found" in summary
class TestEpisodicMemoryStats:
"""Tests for episode statistics."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
session = AsyncMock()
return session
@pytest.fixture
def memory(self, mock_session: AsyncMock) -> EpisodicMemory:
"""Create an EpisodicMemory instance."""
return EpisodicMemory(session=mock_session)
@pytest.mark.asyncio
async def test_get_stats(
self,
memory: EpisodicMemory,
mock_session: AsyncMock,
) -> None:
"""Test getting episode statistics."""
# Mock empty result
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = []
mock_session.execute.return_value = mock_result
stats = await memory.get_stats(uuid4())
assert "total_count" in stats
assert "success_count" in stats
assert "failure_count" in stats
@pytest.mark.asyncio
async def test_count(
self,
memory: EpisodicMemory,
mock_session: AsyncMock,
) -> None:
"""Test counting episodes."""
# Mock result with 3 episodes
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = [1, 2, 3]
mock_session.execute.return_value = mock_result
count = await memory.count(uuid4())
assert count == 3
class TestEpisodicMemoryModification:
"""Tests for episode modification methods."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
session = AsyncMock()
return session
@pytest.fixture
def memory(self, mock_session: AsyncMock) -> EpisodicMemory:
"""Create an EpisodicMemory instance."""
return EpisodicMemory(session=mock_session)
@pytest.mark.asyncio
async def test_get_by_id_not_found(
self,
memory: EpisodicMemory,
mock_session: AsyncMock,
) -> None:
"""Test get_by_id returns None when not found."""
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = None
mock_session.execute.return_value = mock_result
result = await memory.get_by_id(uuid4())
assert result is None
@pytest.mark.asyncio
async def test_update_importance_not_found(
self,
memory: EpisodicMemory,
mock_session: AsyncMock,
) -> None:
"""Test update_importance returns None when not found."""
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = None
mock_session.execute.return_value = mock_result
result = await memory.update_importance(uuid4(), 0.9)
assert result is None
@pytest.mark.asyncio
async def test_delete_not_found(
self,
memory: EpisodicMemory,
mock_session: AsyncMock,
) -> None:
"""Test delete returns False when not found."""
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = None
mock_session.execute.return_value = mock_result
result = await memory.delete(uuid4())
assert result is False

Some files were not shown because too many files have changed in this diff.