diff --git a/backend/app/api/routes/samples.py b/backend/app/api/routes/samples.py index ce6d748..e54c97e 100644 --- a/backend/app/api/routes/samples.py +++ b/backend/app/api/routes/samples.py @@ -1,16 +1,15 @@ from typing import List -from fastapi import APIRouter, HTTPException, Query +from fastapi import APIRouter, HTTPException, Query, Request from fastapi.responses import Response from app.models.sample import Sample -from app.services.sample_manager import SampleManager router = APIRouter() -sample_manager = SampleManager() @router.get("/list", response_model=List[Sample]) async def list_samples( + request: Request, limit: int = Query(20, ge=1, le=100), offset: int = Query(0, ge=0) ): @@ -18,31 +17,36 @@ async def list_samples( List sample images with pagination """ try: + sample_manager = request.app.state.sample_manager return await sample_manager.list_samples(limit, offset) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @router.get("/latest", response_model=List[Sample]) async def get_latest_samples( + request: Request, count: int = Query(5, ge=1, le=20) ): """ Get the most recent sample images """ try: + sample_manager = request.app.state.sample_manager return await sample_manager.get_latest_samples(count) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @router.get("/image/{filename}") -async def get_sample_image(filename: str): +async def get_sample_image( + request: Request, + filename: str): """ Get a specific sample image """ try: + sample_manager = request.app.state.sample_manager image_data = await sample_manager.get_sample_data(filename) - # Try to determine content type from filename content_type = "image/jpeg" # default if filename.lower().endswith('.png'): diff --git a/backend/app/main.py b/backend/app/main.py index 7348f87..e2831bd 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -1,3 +1,4 @@ +import logging import platform from datetime import datetime, timezone @@ -7,7 +8,14 
@@ from fastapi.middleware.cors import CORSMiddleware from app.api.routes import training, samples from app.core.config import settings +from app.services.sample_manager import SampleManager +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) app = FastAPI( title="Training Monitor API", description="API for monitoring ML training progress and samples", @@ -23,6 +31,24 @@ app.add_middleware( allow_headers=["*"], ) +# Create and store SampleManager instance +sample_manager = SampleManager() +app.state.sample_manager = sample_manager + + +@app.on_event("startup") +async def startup_event(): + """Initialize services on startup""" + logger.info("Starting up Training Monitor API") + await sample_manager.startup() + + +@app.on_event("shutdown") +async def shutdown_event(): + """Cleanup on shutdown""" + logger.info("Shutting down Training Monitor API") + await sample_manager.shutdown() + # Include routers with versioning app.include_router(training.router, prefix=f"{settings.API_VER_STR}/training", tags=["training"]) app.include_router(samples.router, prefix=f"{settings.API_VER_STR}/samples", tags=["samples"]) diff --git a/backend/app/services/sample_manager.py b/backend/app/services/sample_manager.py index e24ed8b..44c8464 100644 --- a/backend/app/services/sample_manager.py +++ b/backend/app/services/sample_manager.py @@ -1,4 +1,5 @@ import asyncio +import logging import os from concurrent.futures import ThreadPoolExecutor from datetime import datetime @@ -10,32 +11,55 @@ from fastapi import HTTPException from app.core.config import settings from app.models.sample import Sample +logger = logging.getLogger(__name__) + class SampleManager: def __init__(self): self.sftp_client = None - self.cache_dir = "cache/samples" - self.last_sync = None - self.file_index: Dict[str, datetime] = {} self.memory_cache: Dict[str, memoryview] = {} + self.file_index: Dict[str, 
datetime] = {} + self.last_sync = None self.executor = ThreadPoolExecutor(max_workers=4) - self._ensure_cache_dir() + self._sync_task = None + self._running = False - def _ensure_cache_dir(self): - """Ensure cache directory exists""" - os.makedirs(self.cache_dir, exist_ok=True) + async def startup(self): + """Initialize the manager and start periodic sync""" + logger.info("Starting SampleManager initialization...") + self._running = True + try: + # Start both initial sync and periodic sync as background tasks + self._sync_task = asyncio.create_task(self._periodic_sync()) + logger.info("SampleManager started, initial sync running in background") + except Exception as e: + logger.error(f"Startup failed with error: {str(e)}") + raise + + async def shutdown(self): + """Cleanup resources""" + self._running = False + if self._sync_task: + self._sync_task.cancel() + try: + await self._sync_task + except asyncio.CancelledError: + pass + + self.executor.shutdown(wait=True) + self._disconnect_sftp() + # Clear memory cache + self.memory_cache.clear() + self.file_index.clear() + logger.info("SampleManager shutdown completed") async def _connect_sftp(self): """Create SFTP connection using SSH key""" try: - # Expand the key path (handles ~/) key_path = os.path.expanduser(settings.SFTP_KEY_PATH) - - # Create a new SSH client ssh = paramiko.SSHClient() ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - - # Connect using the SSH key + logger.info(f"Attempting connection to {settings.SFTP_HOST} as {settings.SFTP_USER} with key {key_path}") ssh.connect( hostname=settings.SFTP_HOST, username=settings.SFTP_USER, @@ -43,9 +67,9 @@ class SampleManager: key_filename=key_path, ) - # Create SFTP client from the SSH client self.sftp_client = ssh.open_sftp() except Exception as e: + logger.error(f"SFTP connection failed: {str(e)}") raise HTTPException(status_code=500, detail=f"SFTP Connection failed: {str(e)}") def _disconnect_sftp(self): @@ -58,10 +82,10 @@ class SampleManager: 
"""Download file directly to memory""" try: with self.sftp_client.file(remote_path, 'rb') as remote_file: - # Read the entire file into memory data = remote_file.read() return memoryview(data) except Exception as e: + logger.error(f"File download failed: {str(e)}") raise HTTPException(status_code=500, detail=f"Download failed: {str(e)}") async def _sync_files(self): @@ -70,90 +94,96 @@ class SampleManager: await self._connect_sftp() try: - # Get remote files list - using listdir_attr directly on sftp_client remote_files = self.sftp_client.listdir_attr(settings.SFTP_PATH) - - # Update file index and download new files + logger.info(f"Found {len(remote_files)} files in remote directory at {settings.SFTP_PATH}") + # if there are files, log some sample names + if remote_files: + logger.info(f"Sample filenames: {[attr.filename for attr in remote_files[:3]]}") + # Track new and updated files + updates = 0 for attr in remote_files: remote_path = f"{settings.SFTP_PATH}/{attr.filename}" + file_time = datetime.fromtimestamp(attr.st_mtime) - # Check if file is new or updated if (attr.filename not in self.file_index or - datetime.fromtimestamp(attr.st_mtime) > self.file_index[attr.filename]): - # Download file to memory + file_time > self.file_index[attr.filename]): loop = asyncio.get_event_loop() self.memory_cache[attr.filename] = await loop.run_in_executor( self.executor, self._download_to_memory, remote_path ) - self.file_index[attr.filename] = datetime.fromtimestamp(attr.st_mtime) + self.file_index[attr.filename] = file_time + updates += 1 self.last_sync = datetime.now() + if updates > 0: + logger.info(f"Sync completed: {updates} files updated") except Exception as e: + logger.error(f"Sync failed: {str(e)}") raise HTTPException(status_code=500, detail=f"Sync failed: {str(e)}") finally: self._disconnect_sftp() - async def ensure_synced(self, max_age_seconds: int = 30): - """Ensure memory cache is synced if too old""" - if (not self.last_sync or - (datetime.now() - 
self.last_sync).total_seconds() > max_age_seconds): - await self._sync_files() + async def _periodic_sync(self, interval_seconds: int = 30): + """Periodically sync files""" + while self._running: + try: + await self._sync_files() + await asyncio.sleep(interval_seconds) + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Periodic sync error: {str(e)}") + # Wait a bit before retrying on error + await asyncio.sleep(5) async def list_samples(self, limit: int = 20, offset: int = 0) -> List[Sample]: """List sample images with pagination""" - await self.ensure_synced() + logger.info(f"Current file index has {len(self.file_index)} files") + logger.info(f"Memory cache has {len(self.memory_cache)} files") + + # Debug: print some keys + logger.info(f"File index keys: {list(self.file_index.keys())[:3]}") + logger.info(f"Memory cache keys: {list(self.memory_cache.keys())[:3]}") - # Get sorted list of files files = sorted( [(f, self.file_index[f]) for f in self.file_index], key=lambda x: x[1], reverse=True ) - # Apply pagination + logger.info(f"Sorted files list length: {len(files)}") + # Debug: print first few sorted items + if files: + logger.info(f"First few sorted items: {files[:3]}") + files = files[offset:offset + limit] - # Create Sample objects - return [ + return [ # This return statement was missing Sample( filename=filename, - url=f"/api/v1/samples/image/{filename}", + url=f"{settings.API_VER_STR}/samples/image/{filename}", created_at=created_at ) for filename, created_at in files ] - async def get_latest_samples(self, count: int = 5) -> List[Sample]: """Get most recent samples""" return await self.list_samples(limit=count, offset=0) async def get_sample_data(self, filename: str) -> Optional[memoryview]: """Get image data from memory cache""" - await self.ensure_synced() - if filename not in self.memory_cache: raise HTTPException(status_code=404, detail="Sample not found") return self.memory_cache[filename] - def cleanup_old_files(self, 
max_files: int = 1000): - """Cleanup old files from memory cache""" - if len(self.memory_cache) > max_files: - # Sort files by date and keep only the newest - files = sorted( - [(f, self.file_index[f]) for f in self.file_index], - key=lambda x: x[1], - reverse=True - ) - - # Keep only max_files - files_to_keep = {f[0] for f in files[:max_files]} - - # Remove old files from cache - for filename in list(self.memory_cache.keys()): - if filename not in files_to_keep: - del self.memory_cache[filename] - del self.file_index[filename] + def get_stats(self): + """Get cache statistics""" + return { + "cached_files": len(self.memory_cache), + "cache_size_mb": sum(len(mv) for mv in self.memory_cache.values()) / (1024 * 1024), + "last_sync": self.last_sync.isoformat() if self.last_sync else None, + }