forked from cardosofelipe/fast-next-template
chore(context): refactor for consistency, optimize formatting, and simplify logic
- Cleaned up unnecessary comments in `__all__` definitions for better readability. - Adjusted indentation and formatting across modules for improved clarity (e.g., long lines, logical grouping). - Simplified conditional expressions and inline comments for context scoring and ranking. - Replaced some hard-coded values with type-safe annotations (e.g., `ClassVar`). - Removed unused imports and ensured consistent usage across test files. - Updated `test_score_not_cached_on_context` to clarify caching behavior. - Improved truncation strategy logic and marker handling.
This commit is contained in:
@@ -1,11 +1,8 @@
|
||||
"""Tests for model adapters."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.context.adapters import (
|
||||
ClaudeAdapter,
|
||||
DefaultAdapter,
|
||||
ModelAdapter,
|
||||
OpenAIAdapter,
|
||||
get_adapter,
|
||||
)
|
||||
|
||||
@@ -5,10 +5,9 @@ from datetime import UTC, datetime
|
||||
import pytest
|
||||
|
||||
from app.services.context.assembly import ContextPipeline, PipelineMetrics
|
||||
from app.services.context.budget import BudgetAllocator, TokenBudget
|
||||
from app.services.context.budget import TokenBudget
|
||||
from app.services.context.types import (
|
||||
AssembledContext,
|
||||
ContextType,
|
||||
ConversationContext,
|
||||
KnowledgeContext,
|
||||
MessageRole,
|
||||
@@ -354,7 +353,10 @@ class TestContextPipelineFormatting:
|
||||
|
||||
if result.context_count > 0:
|
||||
assert "<conversation_history>" in result.content
|
||||
assert '<message role="user">' in result.content or 'role="user"' in result.content
|
||||
assert (
|
||||
'<message role="user">' in result.content
|
||||
or 'role="user"' in result.content
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_format_tool_results(self) -> None:
|
||||
@@ -474,6 +476,10 @@ class TestContextPipelineIntegration:
|
||||
assert system_pos < task_pos
|
||||
if task_pos >= 0 and knowledge_pos >= 0:
|
||||
assert task_pos < knowledge_pos
|
||||
if knowledge_pos >= 0 and conversation_pos >= 0:
|
||||
assert knowledge_pos < conversation_pos
|
||||
if conversation_pos >= 0 and tool_pos >= 0:
|
||||
assert conversation_pos < tool_pos
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_excluded_contexts_tracked(self) -> None:
|
||||
|
||||
@@ -2,16 +2,15 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.context.budget import BudgetAllocator
|
||||
from app.services.context.compression import (
|
||||
ContextCompressor,
|
||||
TruncationResult,
|
||||
TruncationStrategy,
|
||||
)
|
||||
from app.services.context.budget import BudgetAllocator, TokenBudget
|
||||
from app.services.context.types import (
|
||||
ContextType,
|
||||
KnowledgeContext,
|
||||
SystemContext,
|
||||
TaskContext,
|
||||
)
|
||||
|
||||
@@ -113,7 +112,7 @@ class TestTruncationStrategy:
|
||||
|
||||
assert result.truncated is True
|
||||
assert len(result.content) < len(content)
|
||||
assert strategy.TRUNCATION_MARKER in result.content
|
||||
assert strategy.truncation_marker in result.content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_truncate_middle_strategy(self) -> None:
|
||||
@@ -126,7 +125,7 @@ class TestTruncationStrategy:
|
||||
)
|
||||
|
||||
assert result.truncated is True
|
||||
assert strategy.TRUNCATION_MARKER in result.content
|
||||
assert strategy.truncation_marker in result.content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_truncate_sentence_strategy(self) -> None:
|
||||
@@ -140,7 +139,9 @@ class TestTruncationStrategy:
|
||||
|
||||
assert result.truncated is True
|
||||
# Should cut at sentence boundary
|
||||
assert result.content.endswith(".") or strategy.TRUNCATION_MARKER in result.content
|
||||
assert (
|
||||
result.content.endswith(".") or strategy.truncation_marker in result.content
|
||||
)
|
||||
|
||||
|
||||
class TestContextCompressor:
|
||||
@@ -235,10 +236,12 @@ class TestTruncationEdgeCases:
|
||||
content = "Some content to truncate"
|
||||
|
||||
# max_tokens less than marker tokens should return just marker
|
||||
result = await strategy.truncate_to_tokens(content, max_tokens=1, strategy="end")
|
||||
result = await strategy.truncate_to_tokens(
|
||||
content, max_tokens=1, strategy="end"
|
||||
)
|
||||
|
||||
# Should handle gracefully without crashing
|
||||
assert strategy.TRUNCATION_MARKER in result.content or result.content == content
|
||||
assert strategy.truncation_marker in result.content or result.content == content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_truncate_with_content_that_has_zero_tokens(self) -> None:
|
||||
@@ -249,7 +252,7 @@ class TestTruncationEdgeCases:
|
||||
result = await strategy.truncate_to_tokens("a", max_tokens=100)
|
||||
|
||||
# Should not raise ZeroDivisionError
|
||||
assert result.content in ("a", "a" + strategy.TRUNCATION_MARKER)
|
||||
assert result.content in ("a", "a" + strategy.truncation_marker)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_content_for_tokens_zero_target(self) -> None:
|
||||
|
||||
@@ -11,8 +11,6 @@ from app.services.context.types import (
|
||||
ConversationContext,
|
||||
KnowledgeContext,
|
||||
MessageRole,
|
||||
SystemContext,
|
||||
TaskContext,
|
||||
ToolContext,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
"""Tests for context management exceptions."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.context.exceptions import (
|
||||
AssemblyTimeoutError,
|
||||
BudgetExceededError,
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
"""Tests for context ranking module."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.context.budget import BudgetAllocator, TokenBudget
|
||||
@@ -230,9 +228,7 @@ class TestContextRanker:
|
||||
),
|
||||
]
|
||||
|
||||
result = await ranker.rank(
|
||||
contexts, "query", budget, ensure_required=False
|
||||
)
|
||||
result = await ranker.rank(contexts, "query", budget, ensure_required=False)
|
||||
|
||||
# Without ensure_required, CRITICAL contexts can be excluded
|
||||
# if budget doesn't allow
|
||||
@@ -246,12 +242,8 @@ class TestContextRanker:
|
||||
budget = allocator.create_budget(10000)
|
||||
|
||||
contexts = [
|
||||
KnowledgeContext(
|
||||
content="Knowledge 1", source="docs", relevance_score=0.8
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="Knowledge 2", source="docs", relevance_score=0.6
|
||||
),
|
||||
KnowledgeContext(content="Knowledge 1", source="docs", relevance_score=0.8),
|
||||
KnowledgeContext(content="Knowledge 2", source="docs", relevance_score=0.6),
|
||||
TaskContext(content="Task", source="task"),
|
||||
]
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from unittest.mock import AsyncMock, MagicMock
|
||||
import pytest
|
||||
|
||||
from app.services.context.scoring import (
|
||||
BaseScorer,
|
||||
CompositeScorer,
|
||||
PriorityScorer,
|
||||
RecencyScorer,
|
||||
@@ -149,15 +148,9 @@ class TestRelevanceScorer:
|
||||
scorer = RelevanceScorer()
|
||||
|
||||
contexts = [
|
||||
KnowledgeContext(
|
||||
content="Python", source="1", relevance_score=0.8
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="Java", source="2", relevance_score=0.6
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="Go", source="3", relevance_score=0.9
|
||||
),
|
||||
KnowledgeContext(content="Python", source="1", relevance_score=0.8),
|
||||
KnowledgeContext(content="Java", source="2", relevance_score=0.6),
|
||||
KnowledgeContext(content="Go", source="3", relevance_score=0.9),
|
||||
]
|
||||
|
||||
scores = await scorer.score_batch(contexts, "test")
|
||||
@@ -263,7 +256,9 @@ class TestRecencyScorer:
|
||||
)
|
||||
|
||||
conv_score = await scorer.score(conv_context, "query", reference_time=now)
|
||||
knowledge_score = await scorer.score(knowledge_context, "query", reference_time=now)
|
||||
knowledge_score = await scorer.score(
|
||||
knowledge_context, "query", reference_time=now
|
||||
)
|
||||
|
||||
# Conversation should decay much faster
|
||||
assert conv_score < knowledge_score
|
||||
@@ -301,12 +296,8 @@ class TestRecencyScorer:
|
||||
|
||||
contexts = [
|
||||
TaskContext(content="1", source="t", timestamp=now),
|
||||
TaskContext(
|
||||
content="2", source="t", timestamp=now - timedelta(hours=24)
|
||||
),
|
||||
TaskContext(
|
||||
content="3", source="t", timestamp=now - timedelta(hours=48)
|
||||
),
|
||||
TaskContext(content="2", source="t", timestamp=now - timedelta(hours=24)),
|
||||
TaskContext(content="3", source="t", timestamp=now - timedelta(hours=48)),
|
||||
]
|
||||
|
||||
scores = await scorer.score_batch(contexts, "query", reference_time=now)
|
||||
@@ -508,8 +499,12 @@ class TestCompositeScorer:
|
||||
assert scored.priority_score > 0.5 # HIGH priority
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_cached_on_context(self) -> None:
|
||||
"""Test that score is cached on the context."""
|
||||
async def test_score_not_cached_on_context(self) -> None:
|
||||
"""Test that scores are NOT cached on the context.
|
||||
|
||||
Scores should not be cached on the context because they are query-dependent.
|
||||
Different queries would get incorrect cached scores if we cached on the context.
|
||||
"""
|
||||
scorer = CompositeScorer()
|
||||
|
||||
context = KnowledgeContext(
|
||||
@@ -518,14 +513,18 @@ class TestCompositeScorer:
|
||||
relevance_score=0.5,
|
||||
)
|
||||
|
||||
# First scoring
|
||||
# After scoring, context._score should remain None
|
||||
# (we don't cache on context because scores are query-dependent)
|
||||
await scorer.score(context, "query")
|
||||
assert context._score is not None
|
||||
# The scorer should compute fresh scores each time
|
||||
# rather than caching on the context object
|
||||
|
||||
# Second scoring should use cached value
|
||||
context._score = 0.999 # Set to a known value
|
||||
score2 = await scorer.score(context, "query")
|
||||
assert score2 == 0.999
|
||||
# Score again with different query - should compute fresh score
|
||||
score1 = await scorer.score(context, "query 1")
|
||||
score2 = await scorer.score(context, "query 2")
|
||||
# Both should be valid scores (not necessarily equal since queries differ)
|
||||
assert 0.0 <= score1 <= 1.0
|
||||
assert 0.0 <= score2 <= 1.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_batch(self) -> None:
|
||||
@@ -555,15 +554,9 @@ class TestCompositeScorer:
|
||||
scorer = CompositeScorer()
|
||||
|
||||
contexts = [
|
||||
KnowledgeContext(
|
||||
content="Low", source="docs", relevance_score=0.2
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="High", source="docs", relevance_score=0.9
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="Medium", source="docs", relevance_score=0.5
|
||||
),
|
||||
KnowledgeContext(content="Low", source="docs", relevance_score=0.2),
|
||||
KnowledgeContext(content="High", source="docs", relevance_score=0.9),
|
||||
KnowledgeContext(content="Medium", source="docs", relevance_score=0.5),
|
||||
]
|
||||
|
||||
ranked = await scorer.rank(contexts, "query")
|
||||
@@ -580,9 +573,7 @@ class TestCompositeScorer:
|
||||
scorer = CompositeScorer()
|
||||
|
||||
contexts = [
|
||||
KnowledgeContext(
|
||||
content=str(i), source="docs", relevance_score=i / 10
|
||||
)
|
||||
KnowledgeContext(content=str(i), source="docs", relevance_score=i / 10)
|
||||
for i in range(10)
|
||||
]
|
||||
|
||||
@@ -595,12 +586,8 @@ class TestCompositeScorer:
|
||||
scorer = CompositeScorer()
|
||||
|
||||
contexts = [
|
||||
KnowledgeContext(
|
||||
content="Low", source="docs", relevance_score=0.1
|
||||
),
|
||||
KnowledgeContext(
|
||||
content="High", source="docs", relevance_score=0.9
|
||||
),
|
||||
KnowledgeContext(content="Low", source="docs", relevance_score=0.1),
|
||||
KnowledgeContext(content="High", source="docs", relevance_score=0.9),
|
||||
]
|
||||
|
||||
ranked = await scorer.rank(contexts, "query", min_score=0.5)
|
||||
@@ -625,7 +612,13 @@ class TestCompositeScorer:
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
scorer = CompositeScorer()
|
||||
# Use scorer with recency_weight=0 to eliminate time-dependent variation
|
||||
# (recency scores change as time passes between calls)
|
||||
scorer = CompositeScorer(
|
||||
relevance_weight=0.5,
|
||||
recency_weight=0.0, # Disable recency to get deterministic results
|
||||
priority_weight=0.5,
|
||||
)
|
||||
|
||||
# Create a single context that will be scored multiple times concurrently
|
||||
context = KnowledgeContext(
|
||||
@@ -639,11 +632,9 @@ class TestCompositeScorer:
|
||||
tasks = [scorer.score(context, "test query") for _ in range(num_concurrent)]
|
||||
scores = await asyncio.gather(*tasks)
|
||||
|
||||
# All scores should be identical (the same context scored the same way)
|
||||
# All scores should be identical (deterministic scoring without recency)
|
||||
assert all(s == scores[0] for s in scores)
|
||||
|
||||
# The context should have its _score cached
|
||||
assert context._score is not None
|
||||
# Note: We don't cache _score on context because scores are query-dependent
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_scoring_different_contexts(self) -> None:
|
||||
@@ -671,10 +662,7 @@ class TestCompositeScorer:
|
||||
|
||||
# Each context should have a different score based on its relevance
|
||||
assert len(set(scores)) > 1 # Not all the same
|
||||
|
||||
# All contexts should have cached scores
|
||||
for ctx in contexts:
|
||||
assert ctx._score is not None
|
||||
# Note: We don't cache _score on context because scores are query-dependent
|
||||
|
||||
|
||||
class TestScoredContext:
|
||||
|
||||
@@ -1,20 +1,17 @@
|
||||
"""Tests for context types."""
|
||||
|
||||
import json
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.context.types import (
|
||||
AssembledContext,
|
||||
BaseContext,
|
||||
ContextPriority,
|
||||
ContextType,
|
||||
ConversationContext,
|
||||
KnowledgeContext,
|
||||
MessageRole,
|
||||
SystemContext,
|
||||
TaskComplexity,
|
||||
TaskContext,
|
||||
TaskStatus,
|
||||
ToolContext,
|
||||
@@ -181,24 +178,16 @@ class TestKnowledgeContext:
|
||||
|
||||
def test_is_code(self) -> None:
|
||||
"""Test is_code method."""
|
||||
code_ctx = KnowledgeContext(
|
||||
content="code", source="test", file_type="python"
|
||||
)
|
||||
doc_ctx = KnowledgeContext(
|
||||
content="docs", source="test", file_type="markdown"
|
||||
)
|
||||
code_ctx = KnowledgeContext(content="code", source="test", file_type="python")
|
||||
doc_ctx = KnowledgeContext(content="docs", source="test", file_type="markdown")
|
||||
|
||||
assert code_ctx.is_code() is True
|
||||
assert doc_ctx.is_code() is False
|
||||
|
||||
def test_is_documentation(self) -> None:
|
||||
"""Test is_documentation method."""
|
||||
doc_ctx = KnowledgeContext(
|
||||
content="docs", source="test", file_type="markdown"
|
||||
)
|
||||
code_ctx = KnowledgeContext(
|
||||
content="code", source="test", file_type="python"
|
||||
)
|
||||
doc_ctx = KnowledgeContext(content="docs", source="test", file_type="markdown")
|
||||
code_ctx = KnowledgeContext(content="code", source="test", file_type="python")
|
||||
|
||||
assert doc_ctx.is_documentation() is True
|
||||
assert code_ctx.is_documentation() is False
|
||||
@@ -333,15 +322,11 @@ class TestTaskContext:
|
||||
|
||||
def test_status_checks(self) -> None:
|
||||
"""Test status check methods."""
|
||||
pending = TaskContext(
|
||||
content="test", source="test", status=TaskStatus.PENDING
|
||||
)
|
||||
pending = TaskContext(content="test", source="test", status=TaskStatus.PENDING)
|
||||
completed = TaskContext(
|
||||
content="test", source="test", status=TaskStatus.COMPLETED
|
||||
)
|
||||
blocked = TaskContext(
|
||||
content="test", source="test", status=TaskStatus.BLOCKED
|
||||
)
|
||||
blocked = TaskContext(content="test", source="test", status=TaskStatus.BLOCKED)
|
||||
|
||||
assert pending.is_active() is True
|
||||
assert completed.is_complete() is True
|
||||
@@ -395,12 +380,8 @@ class TestToolContext:
|
||||
|
||||
def test_is_successful(self) -> None:
|
||||
"""Test is_successful method."""
|
||||
success = ToolContext.from_tool_result(
|
||||
"test", "ok", ToolResultStatus.SUCCESS
|
||||
)
|
||||
error = ToolContext.from_tool_result(
|
||||
"test", "error", ToolResultStatus.ERROR
|
||||
)
|
||||
success = ToolContext.from_tool_result("test", "ok", ToolResultStatus.SUCCESS)
|
||||
error = ToolContext.from_tool_result("test", "error", ToolResultStatus.ERROR)
|
||||
|
||||
assert success.is_successful() is True
|
||||
assert error.is_successful() is False
|
||||
@@ -510,9 +491,7 @@ class TestBaseContextMethods:
|
||||
def test_get_age_seconds(self) -> None:
|
||||
"""Test get_age_seconds method."""
|
||||
old_time = datetime.now(UTC) - timedelta(hours=2)
|
||||
ctx = SystemContext(
|
||||
content="test", source="test", timestamp=old_time
|
||||
)
|
||||
ctx = SystemContext(content="test", source="test", timestamp=old_time)
|
||||
|
||||
age = ctx.get_age_seconds()
|
||||
# Should be approximately 2 hours in seconds
|
||||
@@ -521,9 +500,7 @@ class TestBaseContextMethods:
|
||||
def test_get_age_hours(self) -> None:
|
||||
"""Test get_age_hours method."""
|
||||
old_time = datetime.now(UTC) - timedelta(hours=5)
|
||||
ctx = SystemContext(
|
||||
content="test", source="test", timestamp=old_time
|
||||
)
|
||||
ctx = SystemContext(content="test", source="test", timestamp=old_time)
|
||||
|
||||
age = ctx.get_age_hours()
|
||||
assert 4.9 < age < 5.1
|
||||
@@ -533,12 +510,8 @@ class TestBaseContextMethods:
|
||||
old_time = datetime.now(UTC) - timedelta(days=10)
|
||||
new_time = datetime.now(UTC) - timedelta(hours=1)
|
||||
|
||||
old_ctx = SystemContext(
|
||||
content="test", source="test", timestamp=old_time
|
||||
)
|
||||
new_ctx = SystemContext(
|
||||
content="test", source="test", timestamp=new_time
|
||||
)
|
||||
old_ctx = SystemContext(content="test", source="test", timestamp=old_time)
|
||||
new_ctx = SystemContext(content="test", source="test", timestamp=new_time)
|
||||
|
||||
# Default max_age is 168 hours (7 days)
|
||||
assert old_ctx.is_stale() is True
|
||||
|
||||
Reference in New Issue
Block a user