Add major caching to sample manager

Signed-off-by: Felipe Cardoso <felipe.cardoso@hotmail.it>
2025-01-23 08:49:16 +01:00
parent 7fc8fa17d6
commit df5b42b9c9
3 changed files with 118 additions and 58 deletions


@@ -1,16 +1,15 @@
 from typing import List
-from fastapi import APIRouter, HTTPException, Query
+from fastapi import APIRouter, HTTPException, Query, Request
 from fastapi.responses import Response
 from app.models.sample import Sample
-from app.services.sample_manager import SampleManager
 
 router = APIRouter()
-sample_manager = SampleManager()
 
 @router.get("/list", response_model=List[Sample])
 async def list_samples(
+    request: Request,
     limit: int = Query(20, ge=1, le=100),
     offset: int = Query(0, ge=0)
 ):
@@ -18,31 +17,36 @@ async def list_samples(
     List sample images with pagination
     """
     try:
+        sample_manager = request.app.state.sample_manager
         return await sample_manager.list_samples(limit, offset)
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
 
 @router.get("/latest", response_model=List[Sample])
 async def get_latest_samples(
+    request: Request,
     count: int = Query(5, ge=1, le=20)
 ):
     """
     Get the most recent sample images
     """
     try:
+        sample_manager = request.app.state.sample_manager
         return await sample_manager.get_latest_samples(count)
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
 
 @router.get("/image/{filename}")
-async def get_sample_image(filename: str):
+async def get_sample_image(
+    request: Request,
+    filename: str):
     """
     Get a specific sample image
     """
     try:
+        sample_manager = request.app.state.sample_manager
         image_data = await sample_manager.get_sample_data(filename)
 
         # Try to determine content type from filename
         content_type = "image/jpeg"  # default
         if filename.lower().endswith('.png'):
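
Note on the routes module (apparently app/api/routes/samples.py, going by the imports): each handler now pulls the shared manager off request.app.state inside its try block. As an alternative sketch, not part of this commit, the lookup could be factored into a FastAPI dependency so the handlers stay declarative; get_sample_manager is a hypothetical helper name.

from typing import List
from fastapi import APIRouter, Depends, Query, Request
from app.models.sample import Sample

router = APIRouter()

def get_sample_manager(request: Request):
    # Reuse the single SampleManager instance stored on app.state at startup
    return request.app.state.sample_manager

@router.get("/latest", response_model=List[Sample])
async def get_latest_samples(
    count: int = Query(5, ge=1, le=20),
    sample_manager=Depends(get_sample_manager),
):
    return await sample_manager.get_latest_samples(count)

Either way, every handler shares the one instance created at application startup instead of constructing its own.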


@@ -1,3 +1,4 @@
+import logging
 import platform
 from datetime import datetime, timezone
@@ -7,7 +8,14 @@ from fastapi.middleware.cors import CORSMiddleware
 from app.api.routes import training, samples
 from app.core.config import settings
+from app.services.sample_manager import SampleManager
+
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
 
 app = FastAPI(
     title="Training Monitor API",
     description="API for monitoring ML training progress and samples",
@@ -23,6 +31,24 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+# Create and store SampleManager instance
+sample_manager = SampleManager()
+app.state.sample_manager = sample_manager
+
+@app.on_event("startup")
+async def startup_event():
+    """Initialize services on startup"""
+    logger.info("Starting up Training Monitor API")
+    await sample_manager.startup()
+
+@app.on_event("shutdown")
+async def shutdown_event():
+    """Cleanup on shutdown"""
+    logger.info("Shutting down Training Monitor API")
+    await sample_manager.shutdown()
+
 # Include routers with versioning
 app.include_router(training.router, prefix=f"{settings.API_VER_STR}/training", tags=["training"])
 app.include_router(samples.router, prefix=f"{settings.API_VER_STR}/samples", tags=["samples"])
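
The lifecycle wiring above uses @app.on_event, which newer FastAPI releases deprecate in favor of a lifespan context manager. A minimal sketch of equivalent wiring under that API, assuming the same SampleManager startup/shutdown methods introduced by this commit:

from contextlib import asynccontextmanager
from fastapi import FastAPI
from app.services.sample_manager import SampleManager

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Create the shared manager and start its background sync before serving
    app.state.sample_manager = SampleManager()
    await app.state.sample_manager.startup()
    yield
    # Stop the sync task and release SFTP/cache resources on shutdown
    await app.state.sample_manager.shutdown()

app = FastAPI(title="Training Monitor API", lifespan=lifespan)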


@@ -1,4 +1,5 @@
 import asyncio
+import logging
 import os
 from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime
@@ -10,32 +11,55 @@ from fastapi import HTTPException
 from app.core.config import settings
 from app.models.sample import Sample
 
+logger = logging.getLogger(__name__)
+
 class SampleManager:
     def __init__(self):
         self.sftp_client = None
-        self.cache_dir = "cache/samples"
-        self.last_sync = None
-        self.file_index: Dict[str, datetime] = {}
         self.memory_cache: Dict[str, memoryview] = {}
+        self.file_index: Dict[str, datetime] = {}
+        self.last_sync = None
         self.executor = ThreadPoolExecutor(max_workers=4)
-        self._ensure_cache_dir()
+        self._sync_task = None
+        self._running = False
 
-    def _ensure_cache_dir(self):
-        """Ensure cache directory exists"""
-        os.makedirs(self.cache_dir, exist_ok=True)
+    async def startup(self):
+        """Initialize the manager and start periodic sync"""
+        logger.info("Starting SampleManager initialization...")
+        self._running = True
+        try:
+            # Start both initial sync and periodic sync as background tasks
+            self._sync_task = asyncio.create_task(self._periodic_sync())
+            logger.info("SampleManager started, initial sync running in background")
+        except Exception as e:
+            logger.error(f"Startup failed with error: {str(e)}")
+            raise
+
+    async def shutdown(self):
+        """Cleanup resources"""
+        self._running = False
+        if self._sync_task:
+            self._sync_task.cancel()
+            try:
+                await self._sync_task
+            except asyncio.CancelledError:
+                pass
+        self.executor.shutdown(wait=True)
+        self._disconnect_sftp()
+        # Clear memory cache
+        self.memory_cache.clear()
+        self.file_index.clear()
+        logger.info("SampleManager shutdown completed")
 
     async def _connect_sftp(self):
         """Create SFTP connection using SSH key"""
         try:
-            # Expand the key path (handles ~/)
             key_path = os.path.expanduser(settings.SFTP_KEY_PATH)
-            # Create a new SSH client
             ssh = paramiko.SSHClient()
             ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
-            # Connect using the SSH key
+            logger.info(f"Attempting connection to {settings.SFTP_HOST} as {settings.SFTP_USER} with key {key_path}")
             ssh.connect(
                 hostname=settings.SFTP_HOST,
                 username=settings.SFTP_USER,
@@ -43,9 +67,9 @@ class SampleManager:
                 key_filename=key_path,
             )
-            # Create SFTP client from the SSH client
             self.sftp_client = ssh.open_sftp()
         except Exception as e:
+            logger.error(f"SFTP connection failed: {str(e)}")
             raise HTTPException(status_code=500, detail=f"SFTP Connection failed: {str(e)}")
 
     def _disconnect_sftp(self):
@@ -58,10 +82,10 @@ class SampleManager:
         """Download file directly to memory"""
         try:
             with self.sftp_client.file(remote_path, 'rb') as remote_file:
-                # Read the entire file into memory
                 data = remote_file.read()
                 return memoryview(data)
         except Exception as e:
+            logger.error(f"File download failed: {str(e)}")
             raise HTTPException(status_code=500, detail=f"Download failed: {str(e)}")
 
     async def _sync_files(self):
@@ -70,90 +94,96 @@ class SampleManager:
         await self._connect_sftp()
         try:
-            # Get remote files list - using listdir_attr directly on sftp_client
             remote_files = self.sftp_client.listdir_attr(settings.SFTP_PATH)
+            logger.info(f"Found {len(remote_files)} files in remote directory at {settings.SFTP_PATH}")
 
-            # Update file index and download new files
+            # if there are files, log some sample names
+            if remote_files:
+                logger.info(f"Sample filenames: {[attr.filename for attr in remote_files[:3]]}")
+
+            # Track new and updated files
+            updates = 0
             for attr in remote_files:
                 remote_path = f"{settings.SFTP_PATH}/{attr.filename}"
+                file_time = datetime.fromtimestamp(attr.st_mtime)
 
-                # Check if file is new or updated
                 if (attr.filename not in self.file_index or
-                        datetime.fromtimestamp(attr.st_mtime) > self.file_index[attr.filename]):
+                        file_time > self.file_index[attr.filename]):
-                    # Download file to memory
                     loop = asyncio.get_event_loop()
                     self.memory_cache[attr.filename] = await loop.run_in_executor(
                         self.executor,
                         self._download_to_memory,
                         remote_path
                     )
-                    self.file_index[attr.filename] = datetime.fromtimestamp(attr.st_mtime)
+                    self.file_index[attr.filename] = file_time
+                    updates += 1
 
             self.last_sync = datetime.now()
+            if updates > 0:
+                logger.info(f"Sync completed: {updates} files updated")
         except Exception as e:
+            logger.error(f"Sync failed: {str(e)}")
             raise HTTPException(status_code=500, detail=f"Sync failed: {str(e)}")
         finally:
             self._disconnect_sftp()
 
-    async def ensure_synced(self, max_age_seconds: int = 30):
-        """Ensure memory cache is synced if too old"""
-        if (not self.last_sync or
-                (datetime.now() - self.last_sync).total_seconds() > max_age_seconds):
-            await self._sync_files()
+    async def _periodic_sync(self, interval_seconds: int = 30):
+        """Periodically sync files"""
+        while self._running:
+            try:
+                await self._sync_files()
+                await asyncio.sleep(interval_seconds)
+            except asyncio.CancelledError:
+                break
+            except Exception as e:
+                logger.error(f"Periodic sync error: {str(e)}")
+                # Wait a bit before retrying on error
+                await asyncio.sleep(5)
 
     async def list_samples(self, limit: int = 20, offset: int = 0) -> List[Sample]:
         """List sample images with pagination"""
-        await self.ensure_synced()
+        logger.info(f"Current file index has {len(self.file_index)} files")
+        logger.info(f"Memory cache has {len(self.memory_cache)} files")
+        # Debug: print some keys
+        logger.info(f"File index keys: {list(self.file_index.keys())[:3]}")
+        logger.info(f"Memory cache keys: {list(self.memory_cache.keys())[:3]}")
 
-        # Get sorted list of files
         files = sorted(
             [(f, self.file_index[f]) for f in self.file_index],
             key=lambda x: x[1],
             reverse=True
         )
+        logger.info(f"Sorted files list length: {len(files)}")
+        # Debug: print first few sorted items
+        if files:
+            logger.info(f"First few sorted items: {files[:3]}")
 
-        # Apply pagination
         files = files[offset:offset + limit]
 
-        # Create Sample objects
-        return [
+        return [  # This return statement was missing
             Sample(
                 filename=filename,
-                url=f"/api/v1/samples/image/{filename}",
+                url=f"{settings.API_VER_STR}/samples/image/{filename}",
                 created_at=created_at
             )
             for filename, created_at in files
         ]
 
     async def get_latest_samples(self, count: int = 5) -> List[Sample]:
         """Get most recent samples"""
         return await self.list_samples(limit=count, offset=0)
 
     async def get_sample_data(self, filename: str) -> Optional[memoryview]:
         """Get image data from memory cache"""
-        await self.ensure_synced()
         if filename not in self.memory_cache:
             raise HTTPException(status_code=404, detail="Sample not found")
         return self.memory_cache[filename]
 
-    def cleanup_old_files(self, max_files: int = 1000):
-        """Cleanup old files from memory cache"""
-        if len(self.memory_cache) > max_files:
-            # Sort files by date and keep only the newest
-            files = sorted(
-                [(f, self.file_index[f]) for f in self.file_index],
-                key=lambda x: x[1],
-                reverse=True
-            )
-            # Keep only max_files
-            files_to_keep = {f[0] for f in files[:max_files]}
-            # Remove old files from cache
-            for filename in list(self.memory_cache.keys()):
-                if filename not in files_to_keep:
-                    del self.memory_cache[filename]
-                    del self.file_index[filename]
+    def get_stats(self):
+        """Get cache statistics"""
+        return {
+            "cached_files": len(self.memory_cache),
+            "cache_size_mb": sum(len(mv) for mv in self.memory_cache.values()) / (1024 * 1024),
+            "last_sync": self.last_sync.isoformat() if self.last_sync else None,
        }
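
For reference, a minimal standalone usage sketch of the new lifecycle, assuming SFTP settings are configured in app.core.config: it starts the manager, lets roughly one 30-second sync cycle complete, prints the new get_stats() output, and shuts down cleanly.

import asyncio

from app.services.sample_manager import SampleManager

async def main():
    manager = SampleManager()
    await manager.startup()      # kicks off the background periodic sync task
    await asyncio.sleep(35)      # allow at least one 30-second sync cycle
    print(manager.get_stats())   # {"cached_files": ..., "cache_size_mb": ..., "last_sync": ...}
    for sample in await manager.get_latest_samples(count=3):
        print(sample.filename, sample.url)
    await manager.shutdown()     # cancels the sync task and clears the cache

if __name__ == "__main__":
    asyncio.run(main())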