225 lines
8.6 KiB
Python
225 lines
8.6 KiB
Python
import asyncio
|
|
import logging
|
|
import os
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from datetime import datetime
|
|
from typing import List, Dict, Optional
|
|
|
|
import paramiko
|
|
from fastapi import HTTPException
|
|
|
|
from app.core.config import settings
|
|
from app.models.sample import Sample
|
|
|
|
# Per-module logger; its name is this module's import path, so output can be filtered per module.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class SampleManager:
    """Keep an in-memory cache of sample files synced from local and/or SFTP sources.

    A background task started by :meth:`startup` repeatedly rescans the
    configured sources. It loads new or modified files into ``memory_cache``
    (raw bytes, keyed by filename) and ``samples`` (``Sample`` metadata, keyed
    by filename). Local and remote files share one filename namespace. A file
    is (re)loaded only when it is not cached yet or its modification time is
    newer than the cached entry's ``created_at``.
    """

    def __init__(self):
        # Owns the SSH transport behind sftp_client; both are closed together.
        self.ssh_client: Optional[paramiko.SSHClient] = None
        self.sftp_client = None

        self.memory_cache: Dict[str, memoryview] = {}

        self.samples: Dict[str, Sample] = {}  # Store Sample instances directly

        # Not populated anywhere in this class; only cleared on shutdown.
        self.file_index: Dict[str, datetime] = {}

        # Time of the most recent successful local or remote sync.
        self.last_sync: Optional[datetime] = None

        # Runs blocking paramiko and filesystem calls off the event loop.
        self.executor = ThreadPoolExecutor(max_workers=4)

        self._sync_task: Optional[asyncio.Task] = None

        self._running = False

        # A source that is not configured in settings stays None and is skipped.
        self.remote_path = getattr(settings, 'SFTP_PATH', None)
        self.local_path = getattr(settings, 'LOCAL_PATH', None)

    async def startup(self):
        """Initialize the manager and start periodic sync.

        Returns immediately. The first sync runs inside the background task,
        so samples become available asynchronously.
        """
        logger.info("Starting SampleManager initialization...")
        self._running = True
        try:
            # _periodic_sync syncs once right away, so this single task covers
            # both the initial load and the recurring refresh.
            self._sync_task = asyncio.create_task(self._periodic_sync())
            logger.info("SampleManager started, initial sync running in background")
        except Exception as e:
            logger.error(f"Startup failed with error: {str(e)}")
            raise

    async def shutdown(self):
        """Stop the sync task, release the thread pool and SFTP connection, and drop cached data."""
        self._running = False

        if self._sync_task:
            self._sync_task.cancel()
            try:
                await self._sync_task
            except asyncio.CancelledError:
                pass
            self._sync_task = None

        # Waits for any in-flight download/read before the connection is closed.
        self.executor.shutdown(wait=True)
        self._disconnect_sftp()

        # Clear cached bytes *and* their metadata. Otherwise list_samples would
        # keep returning entries whose data is gone (get_sample_data -> 404).
        self.memory_cache.clear()
        self.samples.clear()
        self.file_index.clear()
        logger.info("SampleManager shutdown completed")

    async def _connect_sftp(self):
        """Create SFTP connection using SSH key.

        The blocking paramiko handshake runs on ``self.executor`` so the event
        loop is not stalled. On success both ``self.ssh_client`` and
        ``self.sftp_client`` are set. On failure the partially opened SSH
        client is closed.

        Raises:
            HTTPException: status 500 if the connection or SFTP session cannot be opened.
        """
        ssh = None
        try:
            key_path = os.path.expanduser(settings.SFTP_KEY_PATH)
            ssh = paramiko.SSHClient()
            # NOTE(review): AutoAddPolicy accepts any unknown host key, which
            # leaves the connection open to MITM. Consider loading a
            # known_hosts file and rejecting unknown keys instead.
            ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            logger.info(f"Attempting connection to {settings.SFTP_HOST} as {settings.SFTP_USER} with key {key_path}")

            loop = asyncio.get_running_loop()
            await loop.run_in_executor(
                self.executor,
                lambda: ssh.connect(
                    hostname=settings.SFTP_HOST,
                    username=settings.SFTP_USER,
                    port=settings.SFTP_PORT,
                    key_filename=key_path,
                ),
            )
            sftp = await loop.run_in_executor(self.executor, ssh.open_sftp)
        except Exception as e:
            # Do not leak a half-open transport when the handshake fails.
            if ssh is not None:
                ssh.close()
            logger.error(f"SFTP connection failed: {str(e)}")
            raise HTTPException(status_code=500, detail=f"SFTP Connection failed: {str(e)}") from e

        self.ssh_client = ssh
        self.sftp_client = sftp

    def _disconnect_sftp(self):
        """Close the SFTP session and the SSH connection behind it. Safe to call repeatedly."""
        sftp, ssh = self.sftp_client, self.ssh_client
        self.sftp_client = None
        self.ssh_client = None
        try:
            if sftp is not None:
                sftp.close()
        finally:
            # Closing only the SFTP channel leaves the SSH transport (socket and
            # its background thread) alive. Close the client as well.
            if ssh is not None:
                ssh.close()

    def _download_to_memory(self, remote_path: str) -> memoryview:
        """Download *remote_path* over the open SFTP session and return its bytes.

        Blocking; callers run it on ``self.executor``.

        Raises:
            HTTPException: status 500 if the download fails.
        """
        try:
            with self.sftp_client.file(remote_path, 'rb') as remote_file:
                data = remote_file.read()
            return memoryview(data)
        except Exception as e:
            logger.error(f"File download failed: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Download failed: {str(e)}") from e

    async def _sync_files(self):
        """Sync files from all configured sources: the local directory first, then SFTP.

        If the local sync raises, the remote sync is not attempted.
        """
        if self.local_path:
            await self._sync_local_files()

        if self.remote_path:
            await self._sync_remote_files()

    def _needs_update(self, filename: str, file_time: datetime) -> bool:
        """Return True if *filename* is not cached yet or *file_time* is newer than the cached copy."""
        existing = self.samples.get(filename)
        return existing is None or file_time > existing.created_at

    def _store_sample(self, filename: str, data: memoryview, file_time: datetime,
                      source: str, source_path: str) -> None:
        """Cache *data* under *filename* and (re)build its ``Sample`` metadata."""
        self.memory_cache[filename] = data
        self.samples[filename] = Sample(
            filename=filename,
            url=f"{settings.API_VER_STR}/samples/image/{filename}",
            created_at=file_time,
            source=source,
            source_path=source_path,
            size=len(data),
        )

    @staticmethod
    def _scan_local_dir(directory: str) -> List[tuple]:
        """Return ``(filename, full_path, mtime)`` for each regular file directly inside *directory*.

        Blocking; callers run it on the executor. Symlinks to files are
        followed, as with ``os.path.isfile`` / ``os.path.getmtime``.
        """
        with os.scandir(directory) as entries:
            return [
                (entry.name, entry.path, datetime.fromtimestamp(entry.stat().st_mtime))
                for entry in entries
                if entry.is_file()
            ]

    @staticmethod
    def _read_local_file(path: str) -> bytes:
        """Return the full contents of the file at *path* (blocking)."""
        with open(path, 'rb') as f:
            return f.read()

    async def _sync_local_files(self):
        """Load new or modified regular files from ``self.local_path`` into the cache.

        Raises:
            HTTPException: status 500 if the directory scan or any file read fails.
        """
        if not self.local_path:
            return

        loop = asyncio.get_running_loop()
        try:
            logger.info(f"Syncing local files from {self.local_path}")
            # Directory scan and file reads are blocking; keep them off the event loop.
            entries = await loop.run_in_executor(self.executor, self._scan_local_dir, self.local_path)

            new_files_count = 0
            for filename, full_path, file_time in entries:
                if not self._needs_update(filename, file_time):
                    continue
                data = await loop.run_in_executor(self.executor, self._read_local_file, full_path)
                self._store_sample(filename, memoryview(data), file_time, 'local', full_path)
                new_files_count += 1

            self.last_sync = datetime.now()
            logger.info(f"Local sync completed for {self.local_path}")
            logger.info(f"Loaded {new_files_count} new or updated files from local directory at {self.local_path}")
        except Exception as e:
            logger.error(f"Local sync failed: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Local sync failed: {str(e)}") from e

    async def _sync_remote_files(self):
        """Download new or modified files from ``self.remote_path`` over SFTP.

        Opens a connection if none is open and always closes it afterwards.

        Raises:
            HTTPException: status 500 if connecting, listing, or any download fails.
        """
        import stat  # only needed to recognise directory entries below

        if not self.sftp_client:
            await self._connect_sftp()

        loop = asyncio.get_running_loop()
        try:
            remote_files = await loop.run_in_executor(
                self.executor, self.sftp_client.listdir_attr, self.remote_path
            )
            logger.info(f"Found {len(remote_files)} files in remote directory at {self.remote_path}")

            base_path = self.remote_path.rstrip('/')
            for attr in remote_files:
                # listdir_attr also lists sub-directories, which cannot be opened
                # as files. Skip them so one directory does not fail the whole sync.
                if attr.st_mode is not None and stat.S_ISDIR(attr.st_mode):
                    continue

                remote_path = f"{base_path}/{attr.filename}"
                file_time = datetime.fromtimestamp(attr.st_mtime)
                if not self._needs_update(attr.filename, file_time):
                    continue

                data = await loop.run_in_executor(self.executor, self._download_to_memory, remote_path)
                self._store_sample(attr.filename, data, file_time, 'remote', remote_path)

            self.last_sync = datetime.now()
            logger.info(f"Remote sync completed with {len(remote_files)} files")
        except Exception as e:
            logger.error(f"Remote sync failed: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Remote sync failed: {str(e)}") from e
        finally:
            self._disconnect_sftp()

    async def _periodic_sync(self, interval_seconds: int = 30, retry_seconds: int = 5):
        """Run ``_sync_files`` repeatedly while the manager is running.

        Args:
            interval_seconds: delay after a successful sync.
            retry_seconds: shorter delay after a failed sync before retrying.
        """
        while self._running:
            try:
                await self._sync_files()
            except asyncio.CancelledError:
                # Propagate so the task ends as cancelled; shutdown() expects that.
                raise
            except Exception as e:
                logger.exception(f"Periodic sync error: {str(e)}")
                # Wait a bit before retrying on error
                await asyncio.sleep(retry_seconds)
                continue
            await asyncio.sleep(interval_seconds)

    async def list_samples(self, limit: int = 20, offset: int = 0) -> List[Sample]:
        """Return up to *limit* samples after skipping *offset*, newest ``created_at`` first."""
        logger.info(f"Total samples: {len(self.samples)}")
        newest_first = sorted(self.samples.values(), key=lambda s: s.created_at, reverse=True)
        return newest_first[offset:offset + limit]

    async def get_latest_samples(self, count: int = 5) -> List[Sample]:
        """Return the *count* most recent samples."""
        return await self.list_samples(limit=count, offset=0)

    async def get_sample_data(self, filename: str) -> Optional[memoryview]:
        """Return the cached bytes for *filename*.

        Raises:
            HTTPException: status 404 if *filename* is not cached.
        """
        try:
            return self.memory_cache[filename]
        except KeyError:
            raise HTTPException(status_code=404, detail="Sample not found") from None

    def get_stats(self):
        """Return the number of cached files, total cached size in MiB, and the last sync time (ISO string or None)."""
        return {
            "cached_files": len(self.memory_cache),
            "cache_size_mb": sum(len(mv) for mv in self.memory_cache.values()) / (1024 * 1024),
            "last_sync": self.last_sync.isoformat() if self.last_sync else None,
        }
|