ai-training-monitor/backend/app/services/sample_manager.py

import asyncio
import os
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Dict, List

import paramiko
from fastapi import HTTPException

from app.core.config import settings
from app.models.sample import Sample


class SampleManager:
    """Syncs training sample images from an SFTP host into an in-memory cache."""

    def __init__(self):
        self.ssh_client = None
        self.sftp_client = None
        self.cache_dir = "cache/samples"
        self.last_sync = None
        self.file_index: Dict[str, datetime] = {}
        self.memory_cache: Dict[str, memoryview] = {}
        self.executor = ThreadPoolExecutor(max_workers=4)
        self._ensure_cache_dir()

    def _ensure_cache_dir(self):
        """Ensure cache directory exists"""
        os.makedirs(self.cache_dir, exist_ok=True)

    async def _connect_sftp(self):
        """Create SFTP connection using SSH key.

        paramiko's connect() is blocking, so the handshake runs on the
        thread pool instead of the event loop.
        """
        def _connect() -> paramiko.SFTPClient:
            # Expand the key path (handles ~/)
            key_path = os.path.expanduser(settings.SFTP_KEY_PATH)
            # Create a new SSH client
            ssh = paramiko.SSHClient()
            ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            # Connect using the SSH key
            ssh.connect(
                hostname=settings.SFTP_HOST,
                username=settings.SFTP_USER,
                port=settings.SFTP_PORT,
                key_filename=key_path,
            )
            # Keep a handle on the SSH client so the transport can be
            # closed later, then open an SFTP session on top of it
            self.ssh_client = ssh
            return ssh.open_sftp()

        try:
            loop = asyncio.get_running_loop()
            self.sftp_client = await loop.run_in_executor(self.executor, _connect)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"SFTP connection failed: {e}")

    def _disconnect_sftp(self):
        """Close the SFTP session and the underlying SSH connection"""
        if self.sftp_client:
            self.sftp_client.close()
            self.sftp_client = None
        if self.ssh_client:
            self.ssh_client.close()
            self.ssh_client = None

    def _download_to_memory(self, remote_path: str) -> memoryview:
        """Download a remote file directly into memory"""
        try:
            with self.sftp_client.file(remote_path, 'rb') as remote_file:
                # Read the entire file into memory
                data = remote_file.read()
                return memoryview(data)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Download failed: {e}")

    async def _sync_files(self):
        """Sync remote files into the in-memory cache"""
        if not self.sftp_client:
            await self._connect_sftp()
        try:
            loop = asyncio.get_running_loop()
            # List remote files; listdir_attr blocks, so run it on the pool
            remote_files = await loop.run_in_executor(
                self.executor,
                self.sftp_client.listdir_attr,
                settings.SFTP_PATH,
            )
            # Download files that are new or newer than the indexed mtime
            for attr in remote_files:
                remote_path = f"{settings.SFTP_PATH}/{attr.filename}"
                mtime = datetime.fromtimestamp(attr.st_mtime)
                if (attr.filename not in self.file_index
                        or mtime > self.file_index[attr.filename]):
                    self.memory_cache[attr.filename] = await loop.run_in_executor(
                        self.executor,
                        self._download_to_memory,
                        remote_path,
                    )
                    self.file_index[attr.filename] = mtime
            self.last_sync = datetime.now()
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Sync failed: {e}")
        finally:
            self._disconnect_sftp()

    async def ensure_synced(self, max_age_seconds: int = 30):
        """Re-sync the memory cache if the last sync is too old"""
        if (not self.last_sync
                or (datetime.now() - self.last_sync).total_seconds() > max_age_seconds):
            await self._sync_files()

    async def list_samples(self, limit: int = 20, offset: int = 0) -> List[Sample]:
        """List sample images with pagination, newest first"""
        await self.ensure_synced()
        # Sort files by modification time, newest first
        files = sorted(self.file_index.items(), key=lambda x: x[1], reverse=True)
        # Apply pagination
        files = files[offset:offset + limit]
        # Create Sample objects
        return [
            Sample(
                filename=filename,
                url=f"/api/v1/samples/image/{filename}",
                created_at=created_at,
            )
            for filename, created_at in files
        ]

    async def get_latest_samples(self, count: int = 5) -> List[Sample]:
        """Get the most recent samples"""
        return await self.list_samples(limit=count, offset=0)

    async def get_sample_data(self, filename: str) -> memoryview:
        """Get image data from the memory cache.

        Raises a 404 instead of returning None when the file is unknown,
        so the return type is a plain memoryview.
        """
        await self.ensure_synced()
        if filename not in self.memory_cache:
            raise HTTPException(status_code=404, detail="Sample not found")
        return self.memory_cache[filename]

    def cleanup_old_files(self, max_files: int = 1000):
        """Evict the oldest entries from the memory cache"""
        if len(self.memory_cache) > max_files:
            # Sort files by date and keep only the newest max_files
            files = sorted(self.file_index.items(), key=lambda x: x[1], reverse=True)
            files_to_keep = {f[0] for f in files[:max_files]}
            # Remove everything else from both the cache and the index
            for filename in list(self.memory_cache.keys()):
                if filename not in files_to_keep:
                    del self.memory_cache[filename]
                    del self.file_index[filename]
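

# --- Usage sketch (not part of the original file; names are assumptions) ---
# A minimal example of wiring this service into a FastAPI route that serves
# the cached image bytes. The route path matches the URLs produced by
# list_samples(); the "image/png" media type is an assumption about the
# sample format.
#
#   from fastapi import APIRouter, Response
#
#   router = APIRouter()
#   sample_manager = SampleManager()
#
#   @router.get("/api/v1/samples/image/{filename}")
#   async def get_sample_image(filename: str) -> Response:
#       data = await sample_manager.get_sample_data(filename)
#       return Response(content=bytes(data), media_type="image/png")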