chore(context): refactor for consistency, optimize formatting, and simplify logic

- Cleaned up unnecessary comments in `__all__` definitions for better readability.
- Adjusted indentation and formatting across modules for improved clarity (e.g., long lines, logical grouping).
- Simplified conditional expressions and inline comments for context scoring and ranking.
- Replaced some hard-coded values with type-safe annotations (e.g., `ClassVar`).
- Removed unused imports and ensured consistent usage across test files.
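- Renamed `test_score_cached_on_context` to `test_score_not_cached_on_context` and rewrote it to reflect that scores are query-dependent and therefore not cached on the context.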
- Improved truncation strategy logic and marker handling (tests now reference `truncation_marker` instead of `TRUNCATION_MARKER`).
This commit is contained in:
2026-01-04 15:23:14 +01:00
parent 9e54f16e56
commit 2bea057fb1
26 changed files with 226 additions and 273 deletions

View File

@@ -1,11 +1,8 @@
"""Tests for model adapters."""
import pytest
from app.services.context.adapters import (
ClaudeAdapter,
DefaultAdapter,
ModelAdapter,
OpenAIAdapter,
get_adapter,
)

View File

@@ -5,10 +5,9 @@ from datetime import UTC, datetime
import pytest
from app.services.context.assembly import ContextPipeline, PipelineMetrics
from app.services.context.budget import BudgetAllocator, TokenBudget
from app.services.context.budget import TokenBudget
from app.services.context.types import (
AssembledContext,
ContextType,
ConversationContext,
KnowledgeContext,
MessageRole,
@@ -354,7 +353,10 @@ class TestContextPipelineFormatting:
if result.context_count > 0:
assert "<conversation_history>" in result.content
assert '<message role="user">' in result.content or 'role="user"' in result.content
assert (
'<message role="user">' in result.content
or 'role="user"' in result.content
)
@pytest.mark.asyncio
async def test_format_tool_results(self) -> None:
@@ -474,6 +476,10 @@ class TestContextPipelineIntegration:
assert system_pos < task_pos
if task_pos >= 0 and knowledge_pos >= 0:
assert task_pos < knowledge_pos
if knowledge_pos >= 0 and conversation_pos >= 0:
assert knowledge_pos < conversation_pos
if conversation_pos >= 0 and tool_pos >= 0:
assert conversation_pos < tool_pos
@pytest.mark.asyncio
async def test_excluded_contexts_tracked(self) -> None:

View File

@@ -2,16 +2,15 @@
import pytest
from app.services.context.budget import BudgetAllocator
from app.services.context.compression import (
ContextCompressor,
TruncationResult,
TruncationStrategy,
)
from app.services.context.budget import BudgetAllocator, TokenBudget
from app.services.context.types import (
ContextType,
KnowledgeContext,
SystemContext,
TaskContext,
)
@@ -113,7 +112,7 @@ class TestTruncationStrategy:
assert result.truncated is True
assert len(result.content) < len(content)
assert strategy.TRUNCATION_MARKER in result.content
assert strategy.truncation_marker in result.content
@pytest.mark.asyncio
async def test_truncate_middle_strategy(self) -> None:
@@ -126,7 +125,7 @@ class TestTruncationStrategy:
)
assert result.truncated is True
assert strategy.TRUNCATION_MARKER in result.content
assert strategy.truncation_marker in result.content
@pytest.mark.asyncio
async def test_truncate_sentence_strategy(self) -> None:
@@ -140,7 +139,9 @@ class TestTruncationStrategy:
assert result.truncated is True
# Should cut at sentence boundary
assert result.content.endswith(".") or strategy.TRUNCATION_MARKER in result.content
assert (
result.content.endswith(".") or strategy.truncation_marker in result.content
)
class TestContextCompressor:
@@ -235,10 +236,12 @@ class TestTruncationEdgeCases:
content = "Some content to truncate"
# max_tokens less than marker tokens should return just marker
result = await strategy.truncate_to_tokens(content, max_tokens=1, strategy="end")
result = await strategy.truncate_to_tokens(
content, max_tokens=1, strategy="end"
)
# Should handle gracefully without crashing
assert strategy.TRUNCATION_MARKER in result.content or result.content == content
assert strategy.truncation_marker in result.content or result.content == content
@pytest.mark.asyncio
async def test_truncate_with_content_that_has_zero_tokens(self) -> None:
@@ -249,7 +252,7 @@ class TestTruncationEdgeCases:
result = await strategy.truncate_to_tokens("a", max_tokens=100)
# Should not raise ZeroDivisionError
assert result.content in ("a", "a" + strategy.TRUNCATION_MARKER)
assert result.content in ("a", "a" + strategy.truncation_marker)
@pytest.mark.asyncio
async def test_get_content_for_tokens_zero_target(self) -> None:

View File

@@ -11,8 +11,6 @@ from app.services.context.types import (
ConversationContext,
KnowledgeContext,
MessageRole,
SystemContext,
TaskContext,
ToolContext,
)

View File

@@ -1,7 +1,5 @@
"""Tests for context management exceptions."""
import pytest
from app.services.context.exceptions import (
AssemblyTimeoutError,
BudgetExceededError,

View File

@@ -1,7 +1,5 @@
"""Tests for context ranking module."""
from datetime import UTC, datetime
import pytest
from app.services.context.budget import BudgetAllocator, TokenBudget
@@ -230,9 +228,7 @@ class TestContextRanker:
),
]
result = await ranker.rank(
contexts, "query", budget, ensure_required=False
)
result = await ranker.rank(contexts, "query", budget, ensure_required=False)
# Without ensure_required, CRITICAL contexts can be excluded
# if budget doesn't allow
@@ -246,12 +242,8 @@ class TestContextRanker:
budget = allocator.create_budget(10000)
contexts = [
KnowledgeContext(
content="Knowledge 1", source="docs", relevance_score=0.8
),
KnowledgeContext(
content="Knowledge 2", source="docs", relevance_score=0.6
),
KnowledgeContext(content="Knowledge 1", source="docs", relevance_score=0.8),
KnowledgeContext(content="Knowledge 2", source="docs", relevance_score=0.6),
TaskContext(content="Task", source="task"),
]

View File

@@ -6,7 +6,6 @@ from unittest.mock import AsyncMock, MagicMock
import pytest
from app.services.context.scoring import (
BaseScorer,
CompositeScorer,
PriorityScorer,
RecencyScorer,
@@ -149,15 +148,9 @@ class TestRelevanceScorer:
scorer = RelevanceScorer()
contexts = [
KnowledgeContext(
content="Python", source="1", relevance_score=0.8
),
KnowledgeContext(
content="Java", source="2", relevance_score=0.6
),
KnowledgeContext(
content="Go", source="3", relevance_score=0.9
),
KnowledgeContext(content="Python", source="1", relevance_score=0.8),
KnowledgeContext(content="Java", source="2", relevance_score=0.6),
KnowledgeContext(content="Go", source="3", relevance_score=0.9),
]
scores = await scorer.score_batch(contexts, "test")
@@ -263,7 +256,9 @@ class TestRecencyScorer:
)
conv_score = await scorer.score(conv_context, "query", reference_time=now)
knowledge_score = await scorer.score(knowledge_context, "query", reference_time=now)
knowledge_score = await scorer.score(
knowledge_context, "query", reference_time=now
)
# Conversation should decay much faster
assert conv_score < knowledge_score
@@ -301,12 +296,8 @@ class TestRecencyScorer:
contexts = [
TaskContext(content="1", source="t", timestamp=now),
TaskContext(
content="2", source="t", timestamp=now - timedelta(hours=24)
),
TaskContext(
content="3", source="t", timestamp=now - timedelta(hours=48)
),
TaskContext(content="2", source="t", timestamp=now - timedelta(hours=24)),
TaskContext(content="3", source="t", timestamp=now - timedelta(hours=48)),
]
scores = await scorer.score_batch(contexts, "query", reference_time=now)
@@ -508,8 +499,12 @@ class TestCompositeScorer:
assert scored.priority_score > 0.5 # HIGH priority
@pytest.mark.asyncio
async def test_score_cached_on_context(self) -> None:
"""Test that score is cached on the context."""
async def test_score_not_cached_on_context(self) -> None:
"""Test that scores are NOT cached on the context.
Scores should not be cached on the context because they are query-dependent.
Different queries would get incorrect cached scores if we cached on the context.
"""
scorer = CompositeScorer()
context = KnowledgeContext(
@@ -518,14 +513,18 @@ class TestCompositeScorer:
relevance_score=0.5,
)
# First scoring
# After scoring, context._score should remain None
# (we don't cache on context because scores are query-dependent)
await scorer.score(context, "query")
assert context._score is not None
# The scorer should compute fresh scores each time
# rather than caching on the context object
# Second scoring should use cached value
context._score = 0.999 # Set to a known value
score2 = await scorer.score(context, "query")
assert score2 == 0.999
# Score again with different query - should compute fresh score
score1 = await scorer.score(context, "query 1")
score2 = await scorer.score(context, "query 2")
# Both should be valid scores (not necessarily equal since queries differ)
assert 0.0 <= score1 <= 1.0
assert 0.0 <= score2 <= 1.0
@pytest.mark.asyncio
async def test_score_batch(self) -> None:
@@ -555,15 +554,9 @@ class TestCompositeScorer:
scorer = CompositeScorer()
contexts = [
KnowledgeContext(
content="Low", source="docs", relevance_score=0.2
),
KnowledgeContext(
content="High", source="docs", relevance_score=0.9
),
KnowledgeContext(
content="Medium", source="docs", relevance_score=0.5
),
KnowledgeContext(content="Low", source="docs", relevance_score=0.2),
KnowledgeContext(content="High", source="docs", relevance_score=0.9),
KnowledgeContext(content="Medium", source="docs", relevance_score=0.5),
]
ranked = await scorer.rank(contexts, "query")
@@ -580,9 +573,7 @@ class TestCompositeScorer:
scorer = CompositeScorer()
contexts = [
KnowledgeContext(
content=str(i), source="docs", relevance_score=i / 10
)
KnowledgeContext(content=str(i), source="docs", relevance_score=i / 10)
for i in range(10)
]
@@ -595,12 +586,8 @@ class TestCompositeScorer:
scorer = CompositeScorer()
contexts = [
KnowledgeContext(
content="Low", source="docs", relevance_score=0.1
),
KnowledgeContext(
content="High", source="docs", relevance_score=0.9
),
KnowledgeContext(content="Low", source="docs", relevance_score=0.1),
KnowledgeContext(content="High", source="docs", relevance_score=0.9),
]
ranked = await scorer.rank(contexts, "query", min_score=0.5)
@@ -625,7 +612,13 @@ class TestCompositeScorer:
"""
import asyncio
scorer = CompositeScorer()
# Use scorer with recency_weight=0 to eliminate time-dependent variation
# (recency scores change as time passes between calls)
scorer = CompositeScorer(
relevance_weight=0.5,
recency_weight=0.0, # Disable recency to get deterministic results
priority_weight=0.5,
)
# Create a single context that will be scored multiple times concurrently
context = KnowledgeContext(
@@ -639,11 +632,9 @@ class TestCompositeScorer:
tasks = [scorer.score(context, "test query") for _ in range(num_concurrent)]
scores = await asyncio.gather(*tasks)
# All scores should be identical (the same context scored the same way)
# All scores should be identical (deterministic scoring without recency)
assert all(s == scores[0] for s in scores)
# The context should have its _score cached
assert context._score is not None
# Note: We don't cache _score on context because scores are query-dependent
@pytest.mark.asyncio
async def test_concurrent_scoring_different_contexts(self) -> None:
@@ -671,10 +662,7 @@ class TestCompositeScorer:
# Each context should have a different score based on its relevance
assert len(set(scores)) > 1 # Not all the same
# All contexts should have cached scores
for ctx in contexts:
assert ctx._score is not None
# Note: We don't cache _score on context because scores are query-dependent
class TestScoredContext:

View File

@@ -1,20 +1,17 @@
"""Tests for context types."""
import json
from datetime import UTC, datetime, timedelta
import pytest
from app.services.context.types import (
AssembledContext,
BaseContext,
ContextPriority,
ContextType,
ConversationContext,
KnowledgeContext,
MessageRole,
SystemContext,
TaskComplexity,
TaskContext,
TaskStatus,
ToolContext,
@@ -181,24 +178,16 @@ class TestKnowledgeContext:
def test_is_code(self) -> None:
"""Test is_code method."""
code_ctx = KnowledgeContext(
content="code", source="test", file_type="python"
)
doc_ctx = KnowledgeContext(
content="docs", source="test", file_type="markdown"
)
code_ctx = KnowledgeContext(content="code", source="test", file_type="python")
doc_ctx = KnowledgeContext(content="docs", source="test", file_type="markdown")
assert code_ctx.is_code() is True
assert doc_ctx.is_code() is False
def test_is_documentation(self) -> None:
"""Test is_documentation method."""
doc_ctx = KnowledgeContext(
content="docs", source="test", file_type="markdown"
)
code_ctx = KnowledgeContext(
content="code", source="test", file_type="python"
)
doc_ctx = KnowledgeContext(content="docs", source="test", file_type="markdown")
code_ctx = KnowledgeContext(content="code", source="test", file_type="python")
assert doc_ctx.is_documentation() is True
assert code_ctx.is_documentation() is False
@@ -333,15 +322,11 @@ class TestTaskContext:
def test_status_checks(self) -> None:
"""Test status check methods."""
pending = TaskContext(
content="test", source="test", status=TaskStatus.PENDING
)
pending = TaskContext(content="test", source="test", status=TaskStatus.PENDING)
completed = TaskContext(
content="test", source="test", status=TaskStatus.COMPLETED
)
blocked = TaskContext(
content="test", source="test", status=TaskStatus.BLOCKED
)
blocked = TaskContext(content="test", source="test", status=TaskStatus.BLOCKED)
assert pending.is_active() is True
assert completed.is_complete() is True
@@ -395,12 +380,8 @@ class TestToolContext:
def test_is_successful(self) -> None:
"""Test is_successful method."""
success = ToolContext.from_tool_result(
"test", "ok", ToolResultStatus.SUCCESS
)
error = ToolContext.from_tool_result(
"test", "error", ToolResultStatus.ERROR
)
success = ToolContext.from_tool_result("test", "ok", ToolResultStatus.SUCCESS)
error = ToolContext.from_tool_result("test", "error", ToolResultStatus.ERROR)
assert success.is_successful() is True
assert error.is_successful() is False
@@ -510,9 +491,7 @@ class TestBaseContextMethods:
def test_get_age_seconds(self) -> None:
"""Test get_age_seconds method."""
old_time = datetime.now(UTC) - timedelta(hours=2)
ctx = SystemContext(
content="test", source="test", timestamp=old_time
)
ctx = SystemContext(content="test", source="test", timestamp=old_time)
age = ctx.get_age_seconds()
# Should be approximately 2 hours in seconds
@@ -521,9 +500,7 @@ class TestBaseContextMethods:
def test_get_age_hours(self) -> None:
"""Test get_age_hours method."""
old_time = datetime.now(UTC) - timedelta(hours=5)
ctx = SystemContext(
content="test", source="test", timestamp=old_time
)
ctx = SystemContext(content="test", source="test", timestamp=old_time)
age = ctx.get_age_hours()
assert 4.9 < age < 5.1
@@ -533,12 +510,8 @@ class TestBaseContextMethods:
old_time = datetime.now(UTC) - timedelta(days=10)
new_time = datetime.now(UTC) - timedelta(hours=1)
old_ctx = SystemContext(
content="test", source="test", timestamp=old_time
)
new_ctx = SystemContext(
content="test", source="test", timestamp=new_time
)
old_ctx = SystemContext(content="test", source="test", timestamp=old_time)
new_ctx = SystemContext(content="test", source="test", timestamp=new_time)
# Default max_age is 168 hours (7 days)
assert old_ctx.is_stale() is True