# Files
# ai-training-monitor/backend/app/services/sample_manager.py
# 2025-01-23 13:10:58 +01:00
#
# 225 lines
# 8.6 KiB
# Python

import asyncio
import logging
import os
import stat
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import List, Dict, Optional

import paramiko
from fastapi import HTTPException

from app.core.config import settings
from app.models.sample import Sample
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class SampleManager:
    """Mirror sample images from a local directory and/or an SFTP directory into memory.

    :meth:`startup` launches a background task that re-syncs the configured
    sources (every 30 s by default, 5 s after a failed sync). A file is
    (re)loaded when it is not cached yet or its modification time is newer
    than the cached ``created_at``. Image bytes live in ``memory_cache`` and
    their metadata in ``samples``, both keyed by bare filename -- so a local
    and a remote file with the same name overwrite each other.

    Blocking work (SSH connect, SFTP listing/downloads, local file reads) runs
    on ``executor``; the two dictionaries are only mutated on the event loop.
    """

    def __init__(self):
        self.sftp_client = None
        # SSH connection that carries ``sftp_client``. Closing the SFTP client
        # alone leaves this transport (socket + background thread) open.
        self._ssh_client = None
        self.memory_cache: Dict[str, memoryview] = {}  # filename -> file bytes
        self.samples: Dict[str, Sample] = {}  # filename -> Sample, stored directly
        # NOTE(review): nothing in this class writes file_index; kept for compatibility.
        self.file_index: Dict[str, datetime] = {}
        self.last_sync = None  # time of the last successful local or remote sync
        self.executor = ThreadPoolExecutor(max_workers=4)
        self._sync_task = None
        self._running = False
        self.remote_path = getattr(settings, 'SFTP_PATH', None)
        self.local_path = getattr(settings, 'LOCAL_PATH', None)

    async def startup(self):
        """Start the background sync loop and return immediately.

        The first sync runs inside the background task. Calling this again
        while the loop is still running does nothing.
        """
        logger.info("Starting SampleManager initialization...")
        if self._sync_task is not None and not self._sync_task.done():
            # A second loop would race the first one over the shared SFTP client.
            logger.info("SampleManager already running")
            return
        self._running = True
        try:
            # The periodic loop syncs on its first iteration, so the initial
            # sync also happens in the background.
            self._sync_task = asyncio.create_task(self._periodic_sync())
            logger.info("SampleManager started, initial sync running in background")
        except Exception as e:
            logger.error(f"Startup failed with error: {e}")
            raise

    async def shutdown(self):
        """Stop the sync loop and release threads, connections and cached data."""
        self._running = False
        if self._sync_task:
            self._sync_task.cancel()
            try:
                await self._sync_task
            except asyncio.CancelledError:
                pass
            self._sync_task = None
        # Let in-flight downloads/reads finish before closing the connection they use.
        self.executor.shutdown(wait=True)
        self._disconnect_sftp()
        # Clear the metadata together with the bytes it points to, so no sample
        # is listed (or skipped as "up to date" later) whose data is gone.
        self.memory_cache.clear()
        self.samples.clear()
        self.file_index.clear()
        logger.info("SampleManager shutdown completed")

    def _open_sftp_blocking(self):
        """Open the SSH connection and SFTP session; blocking, so run it on the executor.

        Returns:
            ``(ssh_client, sftp_client)``.
        """
        key_path = os.path.expanduser(settings.SFTP_KEY_PATH)
        ssh = paramiko.SSHClient()
        # NOTE(review): AutoAddPolicy accepts any unknown host key, which allows
        # man-in-the-middle attacks; consider load_system_host_keys() plus
        # RejectPolicy once the server's key is provisioned.
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        logger.info(f"Attempting connection to {settings.SFTP_HOST} as {settings.SFTP_USER} with key {key_path}")
        try:
            ssh.connect(
                hostname=settings.SFTP_HOST,
                username=settings.SFTP_USER,
                port=settings.SFTP_PORT,
                key_filename=key_path,
            )
            sftp = ssh.open_sftp()
        except Exception:
            # Don't leak a half-open transport when connecting fails.
            ssh.close()
            raise
        return ssh, sftp

    async def _connect_sftp(self):
        """Create an SFTP connection using the configured SSH key.

        Raises:
            HTTPException: 500 if the connection cannot be established.
        """
        loop = asyncio.get_running_loop()
        try:
            ssh, sftp = await loop.run_in_executor(self.executor, self._open_sftp_blocking)
        except Exception as e:
            logger.error(f"SFTP connection failed: {e}")
            raise HTTPException(status_code=500, detail=f"SFTP Connection failed: {e}") from e
        self._ssh_client = ssh
        self.sftp_client = sftp

    def _disconnect_sftp(self):
        """Close the SFTP session and the SSH connection carrying it (no-op if closed)."""
        sftp, self.sftp_client = self.sftp_client, None
        ssh, self._ssh_client = self._ssh_client, None
        try:
            if sftp is not None:
                sftp.close()
        finally:
            if ssh is not None:
                # Closing the SFTP session does not close the SSH transport.
                ssh.close()

    def _download_to_memory(self, remote_path: str) -> memoryview:
        """Read a whole remote file into memory; blocking, so run it on the executor.

        Raises:
            HTTPException: 500 if the download fails.
        """
        try:
            with self.sftp_client.file(remote_path, 'rb') as remote_file:
                data = remote_file.read()
            return memoryview(data)
        except Exception as e:
            logger.error(f"File download failed: {e}")
            raise HTTPException(status_code=500, detail=f"Download failed: {e}") from e

    def _read_local_file(self, path: str) -> memoryview:
        """Read a whole local file into memory; blocking, so run it on the executor."""
        with open(path, 'rb') as f:
            return memoryview(f.read())

    def _needs_refresh(self, filename: str, file_time: datetime) -> bool:
        """Return True if ``filename`` is not cached or ``file_time`` is newer than the cached copy."""
        cached = self.samples.get(filename)
        return cached is None or file_time > cached.created_at

    def _store_sample(self, filename: str, data: memoryview, created_at: datetime,
                      source: str, source_path: str) -> None:
        """Cache ``data`` under ``filename`` and record its Sample metadata."""
        self.memory_cache[filename] = data
        self.samples[filename] = Sample(
            filename=filename,
            url=f"{settings.API_VER_STR}/samples/image/{filename}",
            created_at=created_at,
            source=source,
            source_path=source_path,
            size=len(data),
        )

    async def _sync_files(self):
        """Sync every configured source: the local directory first, then SFTP.

        A failure in one source does not stop the other from syncing. Once both
        have been attempted, the first error is re-raised. Each source-specific
        method has already logged its own error by then.
        """
        sources = (
            (self.local_path, self._sync_local_files),
            (self.remote_path, self._sync_remote_files),
        )
        first_error = None
        for configured, sync in sources:
            if not configured:
                continue
            try:
                await sync()
            except Exception as e:
                if first_error is None:
                    first_error = e
        if first_error is not None:
            raise first_error

    async def _sync_local_files(self):
        """Load new or modified regular files from ``local_path`` into the cache.

        Files that vanish or cannot be read mid-scan are logged and skipped.

        Raises:
            HTTPException: 500 if the directory itself cannot be scanned.
        """
        if not self.local_path:
            return
        try:
            logger.info(f"Syncing local files from {self.local_path}")
            loop = asyncio.get_running_loop()
            new_files_count = 0
            for filename in os.listdir(self.local_path):
                full_path = os.path.join(self.local_path, filename)
                try:
                    if not os.path.isfile(full_path):
                        continue
                    file_time = datetime.fromtimestamp(os.path.getmtime(full_path))
                    if not self._needs_refresh(filename, file_time):
                        continue
                    data = await loop.run_in_executor(
                        self.executor, self._read_local_file, full_path
                    )
                except OSError as e:
                    # E.g. deleted or locked between listing and reading: one bad
                    # file should not abort the whole sync.
                    logger.warning(f"Skipping local file {full_path}: {e}")
                    continue
                self._store_sample(filename, data, file_time, 'local', full_path)
                new_files_count += 1
            self.last_sync = datetime.now()
            logger.info(f"Local sync completed for {self.local_path}")
            logger.info(f"Loaded {new_files_count} new or modified files from local directory at {self.local_path}")
        except Exception as e:
            logger.error(f"Local sync failed: {e}")
            raise HTTPException(status_code=500, detail=f"Local sync failed: {e}") from e

    async def _sync_remote_files(self):
        """Download new or modified files from ``remote_path`` over SFTP.

        A connection is opened if needed and is always closed when the sync ends.

        Raises:
            HTTPException: 500 if connecting, listing or downloading fails.
        """
        if not self.sftp_client:
            await self._connect_sftp()
        base_path = self.remote_path
        loop = asyncio.get_running_loop()
        try:
            remote_files = await loop.run_in_executor(
                self.executor, self.sftp_client.listdir_attr, base_path
            )
            logger.info(f"Found {len(remote_files)} files in remote directory at {base_path}")
            for attr in remote_files:
                # Sub-directories cannot be downloaded and would abort the whole
                # sync. st_mode may be None if the server does not report it.
                if attr.st_mode is not None and stat.S_ISDIR(attr.st_mode):
                    continue
                file_time = datetime.fromtimestamp(attr.st_mtime)
                if not self._needs_refresh(attr.filename, file_time):
                    continue
                file_path = f"{base_path}/{attr.filename}"
                data = await loop.run_in_executor(
                    self.executor, self._download_to_memory, file_path
                )
                self._store_sample(attr.filename, data, file_time, 'remote', file_path)
            self.last_sync = datetime.now()
            logger.info(f"Remote sync completed with {len(remote_files)} files")
        except Exception as e:
            logger.error(f"Remote sync failed: {e}")
            raise HTTPException(status_code=500, detail=f"Remote sync failed: {e}") from e
        finally:
            self._disconnect_sftp()

    async def _periodic_sync(self, interval_seconds: int = 30, retry_seconds: int = 5):
        """Run :meth:`_sync_files` repeatedly until :meth:`shutdown` stops it.

        Sleeps ``interval_seconds`` after a successful sync and ``retry_seconds``
        after a failed one.
        """
        while self._running:
            try:
                await self._sync_files()
                await asyncio.sleep(interval_seconds)
            except asyncio.CancelledError:
                # Propagate cancellation so the task is reported as cancelled;
                # shutdown() expects and handles this.
                raise
            except Exception as e:
                logger.error(f"Periodic sync error: {e}")
                # Wait a bit before retrying on error
                await asyncio.sleep(retry_seconds)

    async def list_samples(self, limit: int = 200, offset: int = 0) -> List[Sample]:
        """Return cached samples, newest ``created_at`` first, paginated.

        Negative ``limit``/``offset`` are treated as 0. Without this, Python
        slicing would count them from the end of the list.
        """
        logger.info(f"Total samples: {len(self.samples)}")
        offset = max(offset, 0)
        limit = max(limit, 0)
        newest_first = sorted(self.samples.values(), key=lambda s: s.created_at, reverse=True)
        return newest_first[offset:offset + limit]

    async def get_latest_samples(self, count: int = 20) -> List[Sample]:
        """Return the ``count`` most recent samples."""
        return await self.list_samples(limit=count, offset=0)

    async def get_sample_data(self, filename: str) -> Optional[memoryview]:
        """Return the cached bytes of ``filename``.

        Raises:
            HTTPException: 404 if the file is not in the cache.
        """
        try:
            return self.memory_cache[filename]
        except KeyError:
            raise HTTPException(status_code=404, detail="Sample not found") from None

    def get_stats(self):
        """Return the cached-file count, total cached size in MiB, and the last sync time (ISO 8601)."""
        return {
            "cached_files": len(self.memory_cache),
            "cache_size_mb": sum(len(mv) for mv in self.memory_cache.values()) / (1024 * 1024),
            "last_sync": self.last_sync.isoformat() if self.last_sync else None,
        }