import asyncio
import logging
import os
import stat
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import List, Dict, Optional

import paramiko
from fastapi import HTTPException

from app.core.config import settings
from app.models.sample import Sample

logger = logging.getLogger(__name__)


class SampleManager:
    """Keep sample images from a local directory and/or an SFTP directory in memory.

    A background task started by :meth:`startup` re-scans the configured
    sources every 30 seconds (5 seconds after a failed sync). It (re)loads any
    file that is not cached yet, or whose modification time is newer than the
    cached ``created_at``. File bytes are kept in ``memory_cache`` and the
    matching :class:`Sample` metadata in ``samples``. Both dicts are keyed by
    bare filename, so a local and a remote file with the same name share one
    cache entry.
    """

    def __init__(self):
        # SSH connection and the SFTP session opened on top of it. Both are
        # opened for each remote sync and closed again by _disconnect_sftp().
        self.ssh_client: Optional[paramiko.SSHClient] = None
        self.sftp_client = None
        # filename -> full file contents
        self.memory_cache: Dict[str, memoryview] = {}
        # filename -> metadata for the cached file
        self.samples: Dict[str, Sample] = {}
        # Not written anywhere in this class; only cleared on shutdown.
        self.file_index: Dict[str, datetime] = {}
        self.last_sync = None
        # Threads for blocking paramiko / filesystem work, so the event loop
        # keeps serving requests while a sync is running.
        self.executor = ThreadPoolExecutor(max_workers=4)
        self._sync_task = None
        self._running = False
        # A source is synced only when its setting exists and is truthy.
        self.remote_path = getattr(settings, 'SFTP_PATH', None)
        self.local_path = getattr(settings, 'LOCAL_PATH', None)

    async def startup(self):
        """Mark the manager as running and launch the periodic sync task.

        Returns immediately. The first sync happens inside the background
        task, so the cache may still be empty right after this call.
        """
        logger.info("Starting SampleManager initialization...")
        self._running = True
        try:
            # The periodic task performs the initial sync on its first iteration.
            self._sync_task = asyncio.create_task(self._periodic_sync())
            logger.info("SampleManager started, initial sync running in background")
        except Exception as e:
            logger.error(f"Startup failed with error: {str(e)}")
            raise

    async def shutdown(self):
        """Stop the background sync and release threads, connections and cached data."""
        self._running = False
        if self._sync_task:
            self._sync_task.cancel()
            try:
                await self._sync_task
            except asyncio.CancelledError:
                pass
            self._sync_task = None
        # Waits for any in-flight executor work (e.g. a download) to finish
        # before the connection it may be using is closed below.
        self.executor.shutdown(wait=True)
        self._disconnect_sftp()
        # Drop cached bytes and their metadata together, so no listed sample
        # is left pointing at data that is gone.
        self.memory_cache.clear()
        self.samples.clear()
        self.file_index.clear()
        logger.info("SampleManager shutdown completed")

    def _open_sftp_blocking(self):
        """Open the SSH connection and SFTP session (blocking; run in the executor).

        ``self.ssh_client`` and ``self.sftp_client`` are set only when both
        steps succeed. On failure the half-open SSH connection is closed
        before the exception propagates, so it cannot leak.
        """
        key_path = os.path.expanduser(settings.SFTP_KEY_PATH)
        ssh = paramiko.SSHClient()
        # NOTE(review): AutoAddPolicy accepts any unknown host key, which leaves
        # the connection open to man-in-the-middle attacks. Consider loading
        # known_hosts and using RejectPolicy once host keys are provisioned.
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        logger.info(f"Attempting connection to {settings.SFTP_HOST} as {settings.SFTP_USER} with key {key_path}")
        try:
            ssh.connect(
                hostname=settings.SFTP_HOST,
                username=settings.SFTP_USER,
                port=settings.SFTP_PORT,
                key_filename=key_path,
            )
            sftp = ssh.open_sftp()
        except Exception:
            ssh.close()
            raise
        self.ssh_client = ssh
        self.sftp_client = sftp

    async def _connect_sftp(self):
        """Create the SFTP connection using the configured SSH key.

        Raises:
            HTTPException: status 500 if connecting or opening the SFTP session fails.
        """
        loop = asyncio.get_running_loop()
        try:
            # paramiko is synchronous; keep the handshake off the event loop.
            await loop.run_in_executor(self.executor, self._open_sftp_blocking)
        except Exception as e:
            logger.error(f"SFTP connection failed: {str(e)}")
            raise HTTPException(status_code=500, detail=f"SFTP Connection failed: {str(e)}") from e

    def _disconnect_sftp(self):
        """Close the SFTP session and the SSH connection underneath it, if open.

        Safe to call repeatedly. The SSH connection is closed even if closing
        the SFTP session raises.
        """
        sftp, self.sftp_client = self.sftp_client, None
        ssh, self.ssh_client = self.ssh_client, None
        try:
            if sftp is not None:
                sftp.close()
        finally:
            if ssh is not None:
                ssh.close()

    def _download_to_memory(self, remote_path: str) -> memoryview:
        """Read a whole remote file into memory (blocking; run in the executor).

        Raises:
            HTTPException: status 500 if the file cannot be read.
        """
        try:
            with self.sftp_client.file(remote_path, 'rb') as remote_file:
                data = remote_file.read()
            return memoryview(data)
        except Exception as e:
            logger.error(f"File download failed: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Download failed: {str(e)}") from e

    async def _sync_files(self):
        """Sync every configured source: the local directory first, then SFTP.

        ``last_sync`` is updated only after all configured sources synced
        successfully. It is also maintained when only a local directory is
        configured.
        """
        if self.local_path:
            await self._sync_local_files()
        if self.remote_path:
            await self._sync_remote_files()
        self.last_sync = datetime.now()

    @staticmethod
    def _read_local_file(path: str) -> bytes:
        """Return the full contents of ``path`` (blocking; run in the executor)."""
        with open(path, 'rb') as f:
            return f.read()

    async def _sync_local_files(self):
        """Load new or modified regular files from ``self.local_path`` into the cache.

        A file is (re)loaded when its name is not cached yet, or its mtime is
        newer than the cached ``created_at``. A file that disappears or cannot
        be read between listing and loading is skipped with a warning, rather
        than aborting the whole sync.

        Raises:
            HTTPException: status 500 if the directory itself cannot be synced.
        """
        if not self.local_path:
            return
        try:
            logger.info(f"Syncing local files from {self.local_path}")
            loop = asyncio.get_running_loop()
            new_files_count = 0
            for filename in os.listdir(self.local_path):
                full_path = os.path.join(self.local_path, filename)
                try:
                    if not os.path.isfile(full_path):
                        continue
                    file_time = datetime.fromtimestamp(os.path.getmtime(full_path))
                    cached = self.samples.get(filename)
                    # Skip files that are already cached and unchanged.
                    if cached is not None and file_time <= cached.created_at:
                        continue
                    # Read in a worker thread so large files don't stall the event loop.
                    data = await loop.run_in_executor(self.executor, self._read_local_file, full_path)
                except OSError as e:
                    logger.warning(f"Skipping local file {full_path}: {str(e)}")
                    continue

                self.memory_cache[filename] = memoryview(data)
                self.samples[filename] = Sample(
                    filename=filename,
                    url=f"{settings.API_VER_STR}/samples/image/{filename}",
                    created_at=file_time,
                    source='local',
                    source_path=full_path,
                    size=len(data),
                )
                new_files_count += 1

            logger.info(f"Local sync completed for {self.local_path}")
            logger.info(f"Loaded {new_files_count} new or modified files from local directory at {self.local_path}")
        except Exception as e:
            logger.error(f"Local sync failed: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Local sync failed: {str(e)}") from e

    async def _sync_remote_files(self):
        """Load new or modified regular files from the SFTP directory into the cache.

        Opens an SFTP connection if none is open, and always closes it when
        the sync finishes. Entries that are not regular files (e.g.
        subdirectories) are skipped, mirroring the ``os.path.isfile`` filter
        used for local files.

        Raises:
            HTTPException: status 500 if connecting, listing or downloading fails.
        """
        if not self.sftp_client:
            await self._connect_sftp()
        loop = asyncio.get_running_loop()
        try:
            remote_files = await loop.run_in_executor(
                self.executor, self.sftp_client.listdir_attr, self.remote_path
            )
            logger.info(f"Found {len(remote_files)} files in remote directory at {self.remote_path}")
            for attr in remote_files:
                # st_mode may be None if the server does not report it; only
                # skip entries known to be something other than a regular file.
                if attr.st_mode is not None and not stat.S_ISREG(attr.st_mode):
                    continue
                remote_path = f"{self.remote_path}/{attr.filename}"
                file_time = datetime.fromtimestamp(attr.st_mtime)
                cached = self.samples.get(attr.filename)
                # Skip files that are already cached and unchanged.
                if cached is not None and file_time <= cached.created_at:
                    continue

                data = await loop.run_in_executor(
                    self.executor, self._download_to_memory, remote_path
                )
                self.memory_cache[attr.filename] = data
                self.samples[attr.filename] = Sample(
                    filename=attr.filename,
                    url=f"{settings.API_VER_STR}/samples/image/{attr.filename}",
                    created_at=file_time,
                    source='remote',
                    source_path=remote_path,
                    size=len(data),
                )

            logger.info(f"Remote sync completed with {len(remote_files)} files")
        except Exception as e:
            logger.error(f"Remote sync failed: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Remote sync failed: {str(e)}") from e
        finally:
            self._disconnect_sftp()

    async def _periodic_sync(self, interval_seconds: int = 30):
        """Sync files every ``interval_seconds`` while the manager is running.

        After a failed sync the error is logged and the next attempt happens
        after 5 seconds instead of the full interval. Cancellation ends the loop.
        """
        while self._running:
            try:
                await self._sync_files()
                await asyncio.sleep(interval_seconds)
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Periodic sync error: {str(e)}")
                # Wait a bit before retrying on error
                await asyncio.sleep(5)

    async def list_samples(self, limit: int = 200, offset: int = 0) -> List[Sample]:
        """Return cached samples, newest ``created_at`` first, paginated by offset/limit."""
        logger.info(f"Total samples: {len(self.samples)}")
        sorted_samples = sorted(
            self.samples.values(),
            key=lambda x: x.created_at,
            reverse=True,
        )
        return sorted_samples[offset:offset + limit]

    async def get_latest_samples(self, count: int = 20) -> List[Sample]:
        """Return the ``count`` most recent samples."""
        return await self.list_samples(limit=count, offset=0)

    async def get_sample_data(self, filename: str) -> Optional[memoryview]:
        """Return the cached bytes for ``filename``.

        Raises:
            HTTPException: status 404 if the file is not cached.
        """
        if filename not in self.memory_cache:
            raise HTTPException(status_code=404, detail="Sample not found")
        return self.memory_cache[filename]

    def get_stats(self):
        """Return the number of cached files, their total size in MiB, and the last sync time."""
        return {
            "cached_files": len(self.memory_cache),
            "cache_size_mb": sum(len(mv) for mv in self.memory_cache.values()) / (1024 * 1024),
            "last_sync": self.last_sync.isoformat() if self.last_sync else None,
        }