diff --git a/backend/app/api/routes/samples.py b/backend/app/api/routes/samples.py
index e54c97e..9857791 100644
--- a/backend/app/api/routes/samples.py
+++ b/backend/app/api/routes/samples.py
@@ -10,7 +10,7 @@ router = APIRouter()
 @router.get("/list", response_model=List[Sample])
 async def list_samples(
     request: Request,
-    limit: int = Query(20, ge=1, le=100),
+    limit: int = Query(20, ge=1, le=1000),
     offset: int = Query(0, ge=0)
 ):
     """
diff --git a/backend/app/core/config.py b/backend/app/core/config.py
index f6c273c..f5360de 100644
--- a/backend/app/core/config.py
+++ b/backend/app/core/config.py
@@ -1,17 +1,22 @@
+from typing import Optional
+
 from pydantic_settings import BaseSettings
 
 
 class Settings(BaseSettings):
-    # SFTP Settings
-    SFTP_HOST: str
-    SFTP_USER: str
-    SFTP_KEY_PATH: str = "~/.ssh/id_rsa"  # Default SSH key path
-    SFTP_PATH: str
+    # SFTP Settings (Optional)
+    SFTP_HOST: Optional[str] = None
+    SFTP_USER: Optional[str] = None
+    SFTP_KEY_PATH: Optional[str] = "~/.ssh/id_rsa"
+    SFTP_PATH: Optional[str] = None
     SFTP_PORT: int = 22
 
+    # Local Settings (Optional)
+    LOCAL_PATH: Optional[str] = None
+
     # API Settings
-    API_VER_STR: str = "/api/v1"
     PROJECT_NAME: str = "Training Monitor"
+    API_VER_STR: str = "/api/v1"
 
     class Config:
         env_file = ".env"
diff --git a/backend/app/models/sample.py b/backend/app/models/sample.py
index 164dd00..0a4389c 100644
--- a/backend/app/models/sample.py
+++ b/backend/app/models/sample.py
@@ -1,5 +1,4 @@
 from datetime import datetime
-from typing import Optional
 
 from pydantic import BaseModel
 
@@ -8,4 +7,6 @@ class Sample(BaseModel):
     filename: str
     url: str
     created_at: datetime
-    step: Optional[int] = None
+    source: str  # 'local' or 'remote'
+    source_path: str  # full path to the file
+    size: int
diff --git a/backend/app/services/sample_manager.py b/backend/app/services/sample_manager.py
index 44c8464..28c9fa7 100644
--- a/backend/app/services/sample_manager.py
+++ b/backend/app/services/sample_manager.py
@@ -18,11 +18,15 @@ class SampleManager:
     def __init__(self):
         self.sftp_client = None
         self.memory_cache: Dict[str, memoryview] = {}
+        self.samples: Dict[str, Sample] = {}  # Store Sample instances directly
         self.file_index: Dict[str, datetime] = {}
         self.last_sync = None
         self.executor = ThreadPoolExecutor(max_workers=4)
         self._sync_task = None
         self._running = False
+        self.remote_path = settings.SFTP_PATH if hasattr(settings, 'SFTP_PATH') else None
+        self.local_path = settings.LOCAL_PATH if hasattr(settings, 'LOCAL_PATH') else None
+
 
     async def startup(self):
         """Initialize the manager and start periodic sync"""
@@ -89,43 +93,91 @@ class SampleManager:
             raise HTTPException(status_code=500, detail=f"Download failed: {str(e)}")
 
     async def _sync_files(self):
-        """Sync remote files to memory cache"""
+        """Sync files from all configured sources"""
+        if self.local_path:
+            await self._sync_local_files()
+
+        if self.remote_path:
+            await self._sync_remote_files()
+
+    async def _sync_local_files(self):
+        """Sync files from local directory"""
+        if not self.local_path:
+            return
+
+        try:
+            logger.info(f"Syncing local files from {self.local_path}")
+            new_files_count = 0
+            for filename in os.listdir(self.local_path):
+                full_path = os.path.join(self.local_path, filename)
+                if not os.path.isfile(full_path):
+                    continue
+
+                file_time = datetime.fromtimestamp(os.path.getmtime(full_path))
+
+                # Only update if file is new or modified
+                if (filename not in self.samples or
+                    file_time > self.samples[filename].created_at):
+                    with open(full_path, 'rb') as f:
+                        data = f.read()
+                    self.memory_cache[filename] = memoryview(data)
+                    self.samples[filename] = Sample(
+                        filename=filename,
+                        url=f"{settings.API_VER_STR}/samples/image/{filename}",
+                        created_at=file_time,
+                        source='local',
+                        source_path=full_path,
+                        size=len(data)
+                    )
+                    new_files_count += 1
+
+            logger.info(f"Local sync completed for {self.local_path}")
+            logger.info(f"Found {new_files_count} new or updated files in {self.local_path}")
+        except Exception as e:
+            logger.error(f"Local sync failed: {str(e)}")
+            raise HTTPException(status_code=500, detail=f"Local sync failed: {str(e)}")
+
+    async def _sync_remote_files(self):
+        """Sync remote files via SFTP"""
         if not self.sftp_client:
             await self._connect_sftp()
 
         try:
             remote_files = self.sftp_client.listdir_attr(settings.SFTP_PATH)
             logger.info(f"Found {len(remote_files)} files in remote directory at {settings.SFTP_PATH}")
-            # if there are files, log some sample names
-            if remote_files:
-                logger.info(f"Sample filenames: {[attr.filename for attr in remote_files[:3]]}")
-            # Track new and updated files
-            updates = 0
+
             for attr in remote_files:
                 remote_path = f"{settings.SFTP_PATH}/{attr.filename}"
                 file_time = datetime.fromtimestamp(attr.st_mtime)
 
-                if (attr.filename not in self.file_index or
-                    file_time > self.file_index[attr.filename]):
+                if (attr.filename not in self.samples or
+                    file_time > self.samples[attr.filename].created_at):
                     loop = asyncio.get_event_loop()
-                    self.memory_cache[attr.filename] = await loop.run_in_executor(
+                    data = await loop.run_in_executor(
                         self.executor,
                         self._download_to_memory,
                         remote_path
                     )
-                    self.file_index[attr.filename] = file_time
-                    updates += 1
+                    self.memory_cache[attr.filename] = data
+                    self.samples[attr.filename] = Sample(
+                        filename=attr.filename,
+                        url=f"{settings.API_VER_STR}/samples/image/{attr.filename}",
+                        created_at=file_time,
+                        source='remote',
+                        source_path=remote_path,
+                        size=len(data)
+                    )
 
             self.last_sync = datetime.now()
-            if updates > 0:
-                logger.info(f"Sync completed: {updates} files updated")
+            logger.info(f"Remote sync completed with {len(remote_files)} files")
 
         except Exception as e:
-            logger.error(f"Sync failed: {str(e)}")
-            raise HTTPException(status_code=500, detail=f"Sync failed: {str(e)}")
+            logger.error(f"Remote sync failed: {str(e)}")
+            raise HTTPException(status_code=500, detail=f"Remote sync failed: {str(e)}")
         finally:
             self._disconnect_sftp()
+
 
     async def _periodic_sync(self, interval_seconds: int = 30):
         """Periodically sync files"""
         while self._running:
@@ -141,34 +193,17 @@
     async def list_samples(self, limit: int = 20, offset: int = 0) -> List[Sample]:
         """List sample images with pagination"""
-        logger.info(f"Current file index has {len(self.file_index)} files")
-        logger.info(f"Memory cache has {len(self.memory_cache)} files")
+        logger.info(f"Total samples: {len(self.samples)}")
 
-        # Debug: print some keys
-        logger.info(f"File index keys: {list(self.file_index.keys())[:3]}")
-        logger.info(f"Memory cache keys: {list(self.memory_cache.keys())[:3]}")
-
-        files = sorted(
-            [(f, self.file_index[f]) for f in self.file_index],
-            key=lambda x: x[1],
+        # Sort samples by created_at
+        sorted_samples = sorted(
+            self.samples.values(),
+            key=lambda x: x.created_at,
             reverse=True
         )
 
-        logger.info(f"Sorted files list length: {len(files)}")
-        # Debug: print first few sorted items
-        if files:
-            logger.info(f"First few sorted items: {files[:3]}")
+        return sorted_samples[offset:offset + limit]
 
-        files = files[offset:offset + limit]
-
-        return [  # This return statement was missing
-            Sample(
-                filename=filename,
-                url=f"{settings.API_VER_STR}/samples/image/{filename}",
-                created_at=created_at
-            )
-            for filename, created_at in files
-        ]
 
     async def get_latest_samples(self, count: int = 5) -> List[Sample]:
         """Get most recent samples"""
         return await self.list_samples(limit=count, offset=0)
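
For reference, a minimal sketch (not part of the diff) of the payload shape the expanded Sample model now produces. It assumes Pydantic v2 (implied by the pydantic_settings import), an app.models.sample import path matching the repo layout, and invented example values:

from datetime import datetime

from app.models.sample import Sample  # import path assumed from backend/app/models/sample.py

# Build a Sample the way _sync_local_files now does for a locally discovered file.
example = Sample(
    filename="step_000100.png",                           # hypothetical file name
    url="/api/v1/samples/image/step_000100.png",          # API_VER_STR + image route
    created_at=datetime.now(),                            # the real code uses the file's mtime
    source="local",                                       # 'local' or 'remote'
    source_path="/data/outputs/samples/step_000100.png",  # hypothetical file under LOCAL_PATH
    size=123456,                                          # len(data) in bytes
)

# Items returned by GET /api/v1/samples/list serialize to this shape.
print(example.model_dump_json(indent=2))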