Add major caching to sample manager
Signed-off-by: Felipe Cardoso <felipe.cardoso@hotmail.it>
app/api/routes/samples.py
@@ -1,16 +1,15 @@
 from typing import List
 
-from fastapi import APIRouter, HTTPException, Query
+from fastapi import APIRouter, HTTPException, Query, Request
 from fastapi.responses import Response
 
 from app.models.sample import Sample
-from app.services.sample_manager import SampleManager
 
 router = APIRouter()
-sample_manager = SampleManager()
 
 @router.get("/list", response_model=List[Sample])
 async def list_samples(
+    request: Request,
     limit: int = Query(20, ge=1, le=100),
     offset: int = Query(0, ge=0)
 ):
@@ -18,31 +17,36 @@ async def list_samples(
     List sample images with pagination
     """
     try:
+        sample_manager = request.app.state.sample_manager
         return await sample_manager.list_samples(limit, offset)
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
 
 @router.get("/latest", response_model=List[Sample])
 async def get_latest_samples(
+    request: Request,
     count: int = Query(5, ge=1, le=20)
 ):
     """
     Get the most recent sample images
     """
     try:
+        sample_manager = request.app.state.sample_manager
         return await sample_manager.get_latest_samples(count)
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
 
 
 @router.get("/image/{filename}")
-async def get_sample_image(filename: str):
+async def get_sample_image(
+    request: Request,
+    filename: str):
     """
     Get a specific sample image
     """
     try:
+        sample_manager = request.app.state.sample_manager
         image_data = await sample_manager.get_sample_data(filename)
 
         # Try to determine content type from filename
         content_type = "image/jpeg"  # default
         if filename.lower().endswith('.png'):
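The routes above stop importing a module-level SampleManager and instead read the shared instance off request.app.state on every call. One practical payoff is testability: the manager can be swapped for a stub without monkeypatching imports. A minimal sketch, assuming the app object lives in app.main and that settings.API_VER_STR is "/api/v1" (the value the old hard-coded sample URL used); StubSampleManager and the test name are hypothetical:

from fastapi.testclient import TestClient

from app.main import app  # assumed entrypoint module


class StubSampleManager:
    async def list_samples(self, limit, offset):
        # No SFTP involved; return an empty page.
        return []


def test_list_samples_returns_empty_page():
    app.state.sample_manager = StubSampleManager()
    # Plain calls (no `with` block) skip the lifespan/startup hooks,
    # so the real manager's background sync never starts in the test.
    client = TestClient(app)
    resp = client.get("/api/v1/samples/list")
    assert resp.status_code == 200
    assert resp.json() == []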
app/main.py
@@ -1,3 +1,4 @@
+import logging
 import platform
 from datetime import datetime, timezone
 
@@ -7,7 +8,14 @@ from fastapi.middleware.cors import CORSMiddleware
 
 from app.api.routes import training, samples
 from app.core.config import settings
+from app.services.sample_manager import SampleManager
 
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
 app = FastAPI(
     title="Training Monitor API",
     description="API for monitoring ML training progress and samples",
@@ -23,6 +31,24 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+# Create and store SampleManager instance
+sample_manager = SampleManager()
+app.state.sample_manager = sample_manager
+
+
+@app.on_event("startup")
+async def startup_event():
+    """Initialize services on startup"""
+    logger.info("Starting up Training Monitor API")
+    await sample_manager.startup()
+
+
+@app.on_event("shutdown")
+async def shutdown_event():
+    """Cleanup on shutdown"""
+    logger.info("Shutting down Training Monitor API")
+    await sample_manager.shutdown()
+
 # Include routers with versioning
 app.include_router(training.router, prefix=f"{settings.API_VER_STR}/training", tags=["training"])
 app.include_router(samples.router, prefix=f"{settings.API_VER_STR}/samples", tags=["samples"])
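The startup/shutdown wiring uses @app.on_event, which FastAPI has deprecated in favor of a lifespan context manager. A sketch of the equivalent wiring in that style, assuming the entrypoint path app/main.py (the diff extraction does not show the filename):

from contextlib import asynccontextmanager

from fastapi import FastAPI

from app.services.sample_manager import SampleManager


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Setup: create the shared manager and start its background sync.
    app.state.sample_manager = SampleManager()
    await app.state.sample_manager.startup()
    yield
    # Teardown: cancel the sync task and release the in-memory cache.
    await app.state.sample_manager.shutdown()


app = FastAPI(
    title="Training Monitor API",
    description="API for monitoring ML training progress and samples",
    lifespan=lifespan,
)

Either way, the point of the commit stands: exactly one SampleManager is created, and every request reaches it through app.state.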
app/services/sample_manager.py
@@ -1,4 +1,5 @@
 import asyncio
+import logging
 import os
 from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime
@@ -10,32 +11,55 @@ from fastapi import HTTPException
 from app.core.config import settings
 from app.models.sample import Sample
 
+logger = logging.getLogger(__name__)
+
 
 class SampleManager:
     def __init__(self):
         self.sftp_client = None
-        self.cache_dir = "cache/samples"
-        self.last_sync = None
-        self.file_index: Dict[str, datetime] = {}
         self.memory_cache: Dict[str, memoryview] = {}
+        self.file_index: Dict[str, datetime] = {}
+        self.last_sync = None
         self.executor = ThreadPoolExecutor(max_workers=4)
-        self._ensure_cache_dir()
+        self._sync_task = None
+        self._running = False
 
-    def _ensure_cache_dir(self):
-        """Ensure cache directory exists"""
-        os.makedirs(self.cache_dir, exist_ok=True)
+    async def startup(self):
+        """Initialize the manager and start periodic sync"""
+        logger.info("Starting SampleManager initialization...")
+        self._running = True
+        try:
+            # Start both initial sync and periodic sync as background tasks
+            self._sync_task = asyncio.create_task(self._periodic_sync())
+            logger.info("SampleManager started, initial sync running in background")
+        except Exception as e:
+            logger.error(f"Startup failed with error: {str(e)}")
+            raise
+
+    async def shutdown(self):
+        """Cleanup resources"""
+        self._running = False
+        if self._sync_task:
+            self._sync_task.cancel()
+            try:
+                await self._sync_task
+            except asyncio.CancelledError:
+                pass
+
+        self.executor.shutdown(wait=True)
+        self._disconnect_sftp()
+        # Clear memory cache
+        self.memory_cache.clear()
+        self.file_index.clear()
+        logger.info("SampleManager shutdown completed")
 
     async def _connect_sftp(self):
         """Create SFTP connection using SSH key"""
         try:
-            # Expand the key path (handles ~/)
             key_path = os.path.expanduser(settings.SFTP_KEY_PATH)
 
-            # Create a new SSH client
             ssh = paramiko.SSHClient()
             ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
-            # Connect using the SSH key
+            logger.info(f"Attempting connection to {settings.SFTP_HOST} as {settings.SFTP_USER} with key {key_path}")
             ssh.connect(
                 hostname=settings.SFTP_HOST,
                 username=settings.SFTP_USER,
@@ -43,9 +67,9 @@ class SampleManager:
                 key_filename=key_path,
             )
 
-            # Create SFTP client from the SSH client
             self.sftp_client = ssh.open_sftp()
         except Exception as e:
+            logger.error(f"SFTP connection failed: {str(e)}")
             raise HTTPException(status_code=500, detail=f"SFTP Connection failed: {str(e)}")
 
     def _disconnect_sftp(self):
@@ -58,10 +82,10 @@ class SampleManager:
         """Download file directly to memory"""
         try:
             with self.sftp_client.file(remote_path, 'rb') as remote_file:
-                # Read the entire file into memory
                 data = remote_file.read()
                 return memoryview(data)
         except Exception as e:
+            logger.error(f"File download failed: {str(e)}")
             raise HTTPException(status_code=500, detail=f"Download failed: {str(e)}")
 
     async def _sync_files(self):
@@ -70,90 +94,96 @@ class SampleManager:
         await self._connect_sftp()
 
         try:
-            # Get remote files list - using listdir_attr directly on sftp_client
             remote_files = self.sftp_client.listdir_attr(settings.SFTP_PATH)
-            # Update file index and download new files
+            logger.info(f"Found {len(remote_files)} files in remote directory at {settings.SFTP_PATH}")
+            # If there are files, log some sample names
+            if remote_files:
+                logger.info(f"Sample filenames: {[attr.filename for attr in remote_files[:3]]}")
+
+            # Track new and updated files
+            updates = 0
             for attr in remote_files:
                 remote_path = f"{settings.SFTP_PATH}/{attr.filename}"
+                file_time = datetime.fromtimestamp(attr.st_mtime)
 
-                # Check if file is new or updated
                 if (attr.filename not in self.file_index or
-                        datetime.fromtimestamp(attr.st_mtime) > self.file_index[attr.filename]):
-                    # Download file to memory
+                        file_time > self.file_index[attr.filename]):
                     loop = asyncio.get_event_loop()
                     self.memory_cache[attr.filename] = await loop.run_in_executor(
                         self.executor,
                         self._download_to_memory,
                         remote_path
                     )
-                    self.file_index[attr.filename] = datetime.fromtimestamp(attr.st_mtime)
+                    self.file_index[attr.filename] = file_time
+                    updates += 1
 
             self.last_sync = datetime.now()
+            if updates > 0:
+                logger.info(f"Sync completed: {updates} files updated")
 
         except Exception as e:
+            logger.error(f"Sync failed: {str(e)}")
             raise HTTPException(status_code=500, detail=f"Sync failed: {str(e)}")
         finally:
             self._disconnect_sftp()
 
-    async def ensure_synced(self, max_age_seconds: int = 30):
-        """Ensure memory cache is synced if too old"""
-        if (not self.last_sync or
-                (datetime.now() - self.last_sync).total_seconds() > max_age_seconds):
-            await self._sync_files()
+    async def _periodic_sync(self, interval_seconds: int = 30):
+        """Periodically sync files"""
+        while self._running:
+            try:
+                await self._sync_files()
+                await asyncio.sleep(interval_seconds)
+            except asyncio.CancelledError:
+                break
+            except Exception as e:
+                logger.error(f"Periodic sync error: {str(e)}")
+                # Wait a bit before retrying on error
+                await asyncio.sleep(5)
 
     async def list_samples(self, limit: int = 20, offset: int = 0) -> List[Sample]:
         """List sample images with pagination"""
-        await self.ensure_synced()
+        logger.info(f"Current file index has {len(self.file_index)} files")
+        logger.info(f"Memory cache has {len(self.memory_cache)} files")
+
+        # Debug: print some keys
+        logger.info(f"File index keys: {list(self.file_index.keys())[:3]}")
+        logger.info(f"Memory cache keys: {list(self.memory_cache.keys())[:3]}")
 
-        # Get sorted list of files
         files = sorted(
             [(f, self.file_index[f]) for f in self.file_index],
             key=lambda x: x[1],
             reverse=True
         )
 
-        # Apply pagination
+        logger.info(f"Sorted files list length: {len(files)}")
+        # Debug: print first few sorted items
+        if files:
+            logger.info(f"First few sorted items: {files[:3]}")
+
         files = files[offset:offset + limit]
 
-        # Create Sample objects
         return [
             Sample(
                 filename=filename,
-                url=f"/api/v1/samples/image/{filename}",
+                url=f"{settings.API_VER_STR}/samples/image/{filename}",
                 created_at=created_at
             )
             for filename, created_at in files
        ]
 
     async def get_latest_samples(self, count: int = 5) -> List[Sample]:
         """Get most recent samples"""
         return await self.list_samples(limit=count, offset=0)
 
     async def get_sample_data(self, filename: str) -> Optional[memoryview]:
         """Get image data from memory cache"""
-        await self.ensure_synced()
-
         if filename not in self.memory_cache:
             raise HTTPException(status_code=404, detail="Sample not found")
 
         return self.memory_cache[filename]
 
-    def cleanup_old_files(self, max_files: int = 1000):
-        """Cleanup old files from memory cache"""
-        if len(self.memory_cache) > max_files:
-            # Sort files by date and keep only the newest
-            files = sorted(
-                [(f, self.file_index[f]) for f in self.file_index],
-                key=lambda x: x[1],
-                reverse=True
-            )
-
-            # Keep only max_files
-            files_to_keep = {f[0] for f in files[:max_files]}
-
-            # Remove old files from cache
-            for filename in list(self.memory_cache.keys()):
-                if filename not in files_to_keep:
-                    del self.memory_cache[filename]
-                    del self.file_index[filename]
+    def get_stats(self):
+        """Get cache statistics"""
+        return {
+            "cached_files": len(self.memory_cache),
+            "cache_size_mb": sum(len(mv) for mv in self.memory_cache.values()) / (1024 * 1024),
+            "last_sync": self.last_sync.isoformat() if self.last_sync else None,
+        }