Files
ai-training-monitor/backend/app/services/sample_manager.py
2025-01-23 08:49:16 +01:00

190 lines
7.1 KiB
Python

import asyncio
import logging
import os
import stat
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import List, Dict, Optional

import paramiko
from fastapi import HTTPException

from app.core.config import settings
from app.models.sample import Sample
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class SampleManager:
    """Mirror sample files from an SFTP directory into an in-memory cache.

    A background asyncio task (started by ``startup``) repeatedly lists
    ``settings.SFTP_PATH``, downloads every entry that is new or has a newer
    modification time than the cached copy, and stores the bytes in
    ``memory_cache``. The read methods (``list_samples``,
    ``get_sample_data``, ``get_stats``) serve exclusively from that cache
    and never touch the network.

    Blocking paramiko calls are dispatched to ``executor`` so they do not
    stall the event loop.

    NOTE(review): entries are never evicted when a file disappears from the
    remote directory, so the cache only grows for the life of the process.
    Confirm whether that is intended before adding pruning.
    """

    # Seconds to wait before retrying after a failed sync pass.
    _ERROR_RETRY_SECONDS = 5

    def __init__(self):
        """Create an idle manager; no connection is opened until a sync runs."""
        # Open SFTP session, or None while disconnected.
        self.sftp_client = None
        # SSH client that owns the transport behind ``sftp_client``. It is
        # kept so the whole connection, not just the SFTP channel, can be
        # closed in ``_disconnect_sftp``.
        self._ssh_client = None
        # filename -> file contents.
        self.memory_cache: Dict[str, memoryview] = {}
        # filename -> remote modification time of the cached copy.
        self.file_index: Dict[str, datetime] = {}
        # Completion time of the last successful sync pass, or None.
        self.last_sync = None
        self.executor = ThreadPoolExecutor(max_workers=4)
        self._sync_task = None
        self._running = False

    async def startup(self):
        """Start the background sync loop.

        The first sync pass runs inside the background task, so this method
        returns immediately. Calling it again while the loop is still
        running is a no-op (a second loop would race on the shared SFTP
        connection).
        """
        logger.info("Starting SampleManager initialization...")
        if self._sync_task is not None and not self._sync_task.done():
            logger.warning("SampleManager already started; ignoring startup()")
            return
        self._running = True
        try:
            # The periodic loop performs the initial sync as its first pass.
            self._sync_task = asyncio.create_task(self._periodic_sync())
            logger.info("SampleManager started, initial sync running in background")
        except Exception as e:
            logger.error(f"Startup failed with error: {str(e)}")
            raise

    async def shutdown(self):
        """Stop the sync loop, release the connection and drop cached data."""
        self._running = False
        if self._sync_task:
            self._sync_task.cancel()
            try:
                await self._sync_task
            except asyncio.CancelledError:
                pass
            self._sync_task = None
        # Wait for in-flight worker calls first: they use the connection
        # that is closed on the next line.
        self.executor.shutdown(wait=True)
        self._disconnect_sftp()
        # Clear memory cache
        self.memory_cache.clear()
        self.file_index.clear()
        logger.info("SampleManager shutdown completed")

    def _open_sftp_blocking(self):
        """Open the SSH connection and SFTP session (blocking; worker thread).

        On success stores both clients on ``self``. On failure the SSH
        client is closed before the exception propagates, so a half-open
        connection is never left behind.
        """
        key_path = os.path.expanduser(settings.SFTP_KEY_PATH)
        ssh = paramiko.SSHClient()
        # SECURITY: AutoAddPolicy accepts unknown host keys without
        # verification, which permits man-in-the-middle attacks. Prefer
        # loading known host keys and using RejectPolicy where possible.
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        logger.info(f"Attempting connection to {settings.SFTP_HOST} as {settings.SFTP_USER} with key {key_path}")
        try:
            ssh.connect(
                hostname=settings.SFTP_HOST,
                username=settings.SFTP_USER,
                port=settings.SFTP_PORT,
                key_filename=key_path,
            )
            sftp = ssh.open_sftp()
        except Exception:
            ssh.close()
            raise
        self._ssh_client = ssh
        self.sftp_client = sftp

    async def _connect_sftp(self):
        """Create SFTP connection using SSH key.

        Raises:
            HTTPException: status 500 if the connection cannot be opened.
        """
        try:
            loop = asyncio.get_running_loop()
            # Connecting blocks on network I/O; keep it off the event loop.
            await loop.run_in_executor(self.executor, self._open_sftp_blocking)
        except Exception as e:
            logger.error(f"SFTP connection failed: {str(e)}")
            raise HTTPException(status_code=500, detail=f"SFTP Connection failed: {str(e)}") from e

    def _disconnect_sftp(self):
        """Close the SFTP session and its SSH connection; safe to call twice."""
        sftp, ssh = self.sftp_client, self._ssh_client
        # Reset first so the manager is left disconnected even if a close()
        # call raises.
        self.sftp_client = None
        self._ssh_client = None
        try:
            if sftp is not None:
                sftp.close()
        finally:
            # Closing only the SFTP channel would leave the SSH connection
            # open; close the owning client as well.
            if ssh is not None:
                ssh.close()

    def _download_to_memory(self, remote_path: str) -> memoryview:
        """Download file directly to memory (blocking; run in a worker thread).

        Args:
            remote_path: Path of the file on the SFTP server.

        Returns:
            A memoryview over the complete file contents.

        Raises:
            HTTPException: status 500 if the file cannot be read.
        """
        try:
            with self.sftp_client.file(remote_path, 'rb') as remote_file:
                data = remote_file.read()
            return memoryview(data)
        except Exception as e:
            logger.error(f"File download failed: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Download failed: {str(e)}") from e

    async def _sync_files(self):
        """Run one sync pass: download new or modified remote files.

        Connects if necessary, lists ``settings.SFTP_PATH`` and downloads
        every entry whose modification time is newer than the cached copy.
        Directory entries are skipped. The connection is closed again when
        the pass ends, whether it succeeded or not.

        Raises:
            HTTPException: status 500 if connecting, listing or any
                download fails.
        """
        if not self.sftp_client:
            await self._connect_sftp()
        loop = asyncio.get_running_loop()
        try:
            # Listing is a blocking network round trip; run it on a worker.
            remote_files = await loop.run_in_executor(
                self.executor,
                self.sftp_client.listdir_attr,
                settings.SFTP_PATH,
            )
            logger.info(f"Found {len(remote_files)} files in remote directory at {settings.SFTP_PATH}")
            # if there are files, log some sample names
            if remote_files:
                logger.info(f"Sample filenames: {[attr.filename for attr in remote_files[:3]]}")
            # Track new and updated files
            updates = 0
            for attr in remote_files:
                # A directory cannot be read as a file; attempting it would
                # fail the whole pass on every run. Entries without mode
                # information are still treated as files.
                if attr.st_mode is not None and stat.S_ISDIR(attr.st_mode):
                    continue
                file_time = datetime.fromtimestamp(attr.st_mtime)
                cached_time = self.file_index.get(attr.filename)
                if cached_time is not None and file_time <= cached_time:
                    continue
                remote_path = f"{settings.SFTP_PATH}/{attr.filename}"
                self.memory_cache[attr.filename] = await loop.run_in_executor(
                    self.executor,
                    self._download_to_memory,
                    remote_path,
                )
                self.file_index[attr.filename] = file_time
                updates += 1
            self.last_sync = datetime.now()
            if updates > 0:
                logger.info(f"Sync completed: {updates} files updated")
        except Exception as e:
            logger.error(f"Sync failed: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Sync failed: {str(e)}") from e
        finally:
            self._disconnect_sftp()

    async def _periodic_sync(self, interval_seconds: int = 30):
        """Run sync passes until stopped.

        Args:
            interval_seconds: Pause after a successful pass. After a failed
                pass the loop waits ``_ERROR_RETRY_SECONDS`` instead.
        """
        while self._running:
            try:
                await self._sync_files()
                await asyncio.sleep(interval_seconds)
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Periodic sync error: {str(e)}")
                # Wait a bit before retrying on error
                await asyncio.sleep(self._ERROR_RETRY_SECONDS)

    async def list_samples(self, limit: int = 20, offset: int = 0) -> List["Sample"]:
        """List cached samples, newest first, with pagination.

        Args:
            limit: Maximum number of samples to return; negative values are
                treated as 0.
            offset: Number of newest samples to skip; negative values are
                treated as 0.

        Returns:
            ``Sample`` objects ordered by modification time, newest first.
        """
        # A negative value would turn the slice below into a from-the-end
        # slice and silently return the wrong page.
        limit = max(limit, 0)
        offset = max(offset, 0)

        if logger.isEnabledFor(logging.DEBUG):
            logger.debug("Current file index has %d files", len(self.file_index))
            logger.debug("Memory cache has %d files", len(self.memory_cache))
            logger.debug("File index keys: %s", list(self.file_index)[:3])
            logger.debug("Memory cache keys: %s", list(self.memory_cache)[:3])

        newest_first = sorted(
            self.file_index.items(),
            key=lambda item: item[1],
            reverse=True,
        )
        page = newest_first[offset:offset + limit]
        return [
            Sample(
                filename=filename,
                url=f"{settings.API_VER_STR}/samples/image/{filename}",
                created_at=created_at,
            )
            for filename, created_at in page
        ]

    async def get_latest_samples(self, count: int = 5) -> List["Sample"]:
        """Get the ``count`` most recent samples."""
        return await self.list_samples(limit=count, offset=0)

    async def get_sample_data(self, filename: str) -> Optional[memoryview]:
        """Get image data from memory cache.

        Args:
            filename: Name of the cached file, as reported by ``list_samples``.

        Raises:
            HTTPException: status 404 if ``filename`` is not in the cache.
        """
        try:
            return self.memory_cache[filename]
        except KeyError:
            raise HTTPException(status_code=404, detail="Sample not found") from None

    def get_stats(self) -> dict:
        """Get cache statistics.

        Returns:
            A dict with ``cached_files`` (entry count), ``cache_size_mb``
            (total cached bytes divided by 1024 * 1024) and ``last_sync``
            (ISO-8601 string, or None if no sync has completed).
        """
        total_bytes = sum(view.nbytes for view in self.memory_cache.values())
        return {
            "cached_files": len(self.memory_cache),
            "cache_size_mb": total_bytes / (1024 * 1024),
            "last_sync": self.last_sync.isoformat() if self.last_sync else None,
        }