diff --git a/backend/app/api/routes/samples.py b/backend/app/api/routes/samples.py index fa8ed92..ce6d748 100644 --- a/backend/app/api/routes/samples.py +++ b/backend/app/api/routes/samples.py @@ -1,6 +1,7 @@ from typing import List from fastapi import APIRouter, HTTPException, Query +from fastapi.responses import Response from app.models.sample import Sample from app.services.sample_manager import SampleManager @@ -8,7 +9,6 @@ from app.services.sample_manager import SampleManager router = APIRouter() sample_manager = SampleManager() - @router.get("/list", response_model=List[Sample]) async def list_samples( limit: int = Query(20, ge=1, le=100), @@ -22,9 +22,10 @@ async def list_samples( except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - @router.get("/latest", response_model=List[Sample]) -async def get_latest_samples(count: int = Query(5, ge=1, le=20)): +async def get_latest_samples( + count: int = Query(5, ge=1, le=20) +): """ Get the most recent sample images """ @@ -32,3 +33,28 @@ async def get_latest_samples(count: int = Query(5, ge=1, le=20)): return await sample_manager.get_latest_samples(count) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/image/{filename}") +async def get_sample_image(filename: str): + """ + Get a specific sample image + """ + try: + image_data = await sample_manager.get_sample_data(filename) + + # Try to determine content type from filename + content_type = "image/jpeg" # default + if filename.lower().endswith('.png'): + content_type = "image/png" + elif filename.lower().endswith('.gif'): + content_type = "image/gif" + + return Response( + content=bytes(image_data), + media_type=content_type + ) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) diff --git a/backend/app/core/config.py b/backend/app/core/config.py index e990256..f6c273c 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -5,7 +5,7 @@ class Settings(BaseSettings): # SFTP Settings SFTP_HOST: str SFTP_USER: str - SFTP_PASSWORD: str + SFTP_KEY_PATH: str = "~/.ssh/id_rsa" # Default SSH key path SFTP_PATH: str SFTP_PORT: int = 22 diff --git a/backend/app/services/sample_manager.py b/backend/app/services/sample_manager.py index 8cfb468..e24ed8b 100644 --- a/backend/app/services/sample_manager.py +++ b/backend/app/services/sample_manager.py @@ -1,15 +1,159 @@ -from typing import List +import asyncio +import os +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime +from typing import List, Dict, Optional +import paramiko +from fastapi import HTTPException + +from app.core.config import settings from app.models.sample import Sample class SampleManager: - async def list_samples(self, limit: int, offset: int) -> List[Sample]: - # Implementation for listing samples from SFTP - # This is a placeholder - actual implementation needed - pass + def __init__(self): + self.sftp_client = None + self.cache_dir = "cache/samples" + self.last_sync = None + self.file_index: Dict[str, datetime] = {} + self.memory_cache: Dict[str, memoryview] = {} + self.executor = ThreadPoolExecutor(max_workers=4) + self._ensure_cache_dir() - async def get_latest_samples(self, count: int) -> List[Sample]: - # Implementation for getting latest samples - # This is a placeholder - actual implementation needed - pass + def _ensure_cache_dir(self): + """Ensure cache directory exists""" + os.makedirs(self.cache_dir, exist_ok=True) + + async def _connect_sftp(self): + """Create SFTP connection using SSH key""" + try: + # Expand the key path (handles ~/) + key_path = os.path.expanduser(settings.SFTP_KEY_PATH) + + # Create a new SSH client + ssh = paramiko.SSHClient() + ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + + # Connect using the SSH key + ssh.connect( + hostname=settings.SFTP_HOST, + username=settings.SFTP_USER, + port=settings.SFTP_PORT, + key_filename=key_path, + ) + + # Create SFTP client from the SSH client + self.sftp_client = ssh.open_sftp() + except Exception as e: + raise HTTPException(status_code=500, detail=f"SFTP Connection failed: {str(e)}") + + def _disconnect_sftp(self): + """Close SFTP connection""" + if self.sftp_client: + self.sftp_client.close() + self.sftp_client = None + + def _download_to_memory(self, remote_path: str) -> memoryview: + """Download file directly to memory""" + try: + with self.sftp_client.file(remote_path, 'rb') as remote_file: + # Read the entire file into memory + data = remote_file.read() + return memoryview(data) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Download failed: {str(e)}") + + async def _sync_files(self): + """Sync remote files to memory cache""" + if not self.sftp_client: + await self._connect_sftp() + + try: + # Get remote files list - using listdir_attr directly on sftp_client + remote_files = self.sftp_client.listdir_attr(settings.SFTP_PATH) + + # Update file index and download new files + for attr in remote_files: + remote_path = f"{settings.SFTP_PATH}/{attr.filename}" + + # Check if file is new or updated + if (attr.filename not in self.file_index or + datetime.fromtimestamp(attr.st_mtime) > self.file_index[attr.filename]): + # Download file to memory + loop = asyncio.get_event_loop() + self.memory_cache[attr.filename] = await loop.run_in_executor( + self.executor, + self._download_to_memory, + remote_path + ) + self.file_index[attr.filename] = datetime.fromtimestamp(attr.st_mtime) + + self.last_sync = datetime.now() + + except Exception as e: + raise HTTPException(status_code=500, detail=f"Sync failed: {str(e)}") + finally: + self._disconnect_sftp() + + async def ensure_synced(self, max_age_seconds: int = 30): + """Ensure memory cache is synced if too old""" + if (not self.last_sync or + (datetime.now() - self.last_sync).total_seconds() > max_age_seconds): + await self._sync_files() + + async def list_samples(self, limit: int = 20, offset: int = 0) -> List[Sample]: + """List sample images with pagination""" + await self.ensure_synced() + + # Get sorted list of files + files = sorted( + [(f, self.file_index[f]) for f in self.file_index], + key=lambda x: x[1], + reverse=True + ) + + # Apply pagination + files = files[offset:offset + limit] + + # Create Sample objects + return [ + Sample( + filename=filename, + url=f"/api/v1/samples/image/{filename}", + created_at=created_at + ) + for filename, created_at in files + ] + + async def get_latest_samples(self, count: int = 5) -> List[Sample]: + """Get most recent samples""" + return await self.list_samples(limit=count, offset=0) + + async def get_sample_data(self, filename: str) -> Optional[memoryview]: + """Get image data from memory cache""" + await self.ensure_synced() + + if filename not in self.memory_cache: + raise HTTPException(status_code=404, detail="Sample not found") + + return self.memory_cache[filename] + + def cleanup_old_files(self, max_files: int = 1000): + """Cleanup old files from memory cache""" + if len(self.memory_cache) > max_files: + # Sort files by date and keep only the newest + files = sorted( + [(f, self.file_index[f]) for f in self.file_index], + key=lambda x: x[1], + reverse=True + ) + + # Keep only max_files + files_to_keep = {f[0] for f in files[:max_files]} + + # Remove old files from cache + for filename in list(self.memory_cache.keys()): + if filename not in files_to_keep: + del self.memory_cache[filename] + del self.file_index[filename]