Compare commits

...

3 Commits

Author SHA1 Message Date
Felipe Cardoso
df5b42b9c9 Add major caching to sampler manager
Signed-off-by: Felipe Cardoso <felipe.cardoso@hotmail.it>
2025-01-23 08:49:16 +01:00
Felipe Cardoso
7fc8fa17d6 Working Sampler and samples routes working
Signed-off-by: Felipe Cardoso <felipe.cardoso@hotmail.it>
2025-01-22 21:00:05 +01:00
Felipe Cardoso
2ece0b2d8f Add base backend api root page
Signed-off-by: Felipe Cardoso <felipe.cardoso@hotmail.it>
2025-01-22 18:04:50 +01:00
5 changed files with 278 additions and 17 deletions

View File

@@ -1,16 +1,15 @@
from typing import List
from fastapi import APIRouter, HTTPException, Query
from fastapi import APIRouter, HTTPException, Query, Request
from fastapi.responses import Response
from app.models.sample import Sample
from app.services.sample_manager import SampleManager
router = APIRouter()
sample_manager = SampleManager()
@router.get("/list", response_model=List[Sample])
async def list_samples(
request: Request,
limit: int = Query(20, ge=1, le=100),
offset: int = Query(0, ge=0)
):
@@ -18,17 +17,48 @@ async def list_samples(
List sample images with pagination
"""
try:
sample_manager = request.app.state.sample_manager
return await sample_manager.list_samples(limit, offset)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/latest", response_model=List[Sample])
async def get_latest_samples(count: int = Query(5, ge=1, le=20)):
async def get_latest_samples(
request: Request,
count: int = Query(5, ge=1, le=20)
):
"""
Get the most recent sample images
"""
try:
sample_manager = request.app.state.sample_manager
return await sample_manager.get_latest_samples(count)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/image/{filename}")
async def get_sample_image(
request: Request,
filename: str):
"""
Get a specific sample image
"""
try:
sample_manager = request.app.state.sample_manager
image_data = await sample_manager.get_sample_data(filename)
# Try to determine content type from filename
content_type = "image/jpeg" # default
if filename.lower().endswith('.png'):
content_type = "image/png"
elif filename.lower().endswith('.gif'):
content_type = "image/gif"
return Response(
content=bytes(image_data),
media_type=content_type
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -5,7 +5,7 @@ class Settings(BaseSettings):
# SFTP Settings
SFTP_HOST: str
SFTP_USER: str
SFTP_PASSWORD: str
SFTP_KEY_PATH: str = "~/.ssh/id_rsa" # Default SSH key path
SFTP_PATH: str
SFTP_PORT: int = 22

View File

@@ -1,9 +1,21 @@
import logging
import platform
from datetime import datetime, timezone
import psutil
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.api.routes import training, samples
from app.core.config import settings
from app.services.sample_manager import SampleManager
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
app = FastAPI(
title="Training Monitor API",
description="API for monitoring ML training progress and samples",
@@ -19,11 +31,55 @@ app.add_middleware(
allow_headers=["*"],
)
# Create and store SampleManager instance
sample_manager = SampleManager()
app.state.sample_manager = sample_manager
@app.on_event("startup")
async def startup_event():
"""Initialize services on startup"""
logger.info("Starting up Training Monitor API")
await sample_manager.startup()
@app.on_event("shutdown")
async def shutdown_event():
"""Cleanup on shutdown"""
logger.info("Shutting down Training Monitor API")
await sample_manager.shutdown()
# Include routers with versioning
app.include_router(training.router, prefix=f"{settings.API_VER_STR}/training", tags=["training"])
app.include_router(samples.router, prefix=f"{settings.API_VER_STR}/samples", tags=["samples"])
@app.get("/")
async def root():
"""
Root endpoint providing API status and system information
"""
return {
"name": "Training Monitor API",
"version": "1.0.0",
"status": "operational",
"timestamp": datetime.now(timezone.utc).isoformat(),
"system_info": {
"cpu_usage": f"{psutil.cpu_percent()}%",
"memory_usage": f"{psutil.virtual_memory().percent}%",
"platform": platform.platform(),
"python": platform.python_version(),
},
"endpoints": {
"docs": "/docs",
"health": "/health",
"training_status": "/api/v1/training/status",
"training_log": "/api/v1/training/log",
"samples_list": "/api/v1/samples/list",
"samples_latest": "/api/v1/samples/latest"
}
}
@app.get("/health")
async def health_check():
return {"status": "healthy"}

View File

@@ -1,15 +1,189 @@
from typing import List
import asyncio
import logging
import os
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import List, Dict, Optional
import paramiko
from fastapi import HTTPException
from app.core.config import settings
from app.models.sample import Sample
logger = logging.getLogger(__name__)
class SampleManager:
async def list_samples(self, limit: int, offset: int) -> List[Sample]:
# Implementation for listing samples from SFTP
# This is a placeholder - actual implementation needed
pass
def __init__(self):
self.sftp_client = None
self.memory_cache: Dict[str, memoryview] = {}
self.file_index: Dict[str, datetime] = {}
self.last_sync = None
self.executor = ThreadPoolExecutor(max_workers=4)
self._sync_task = None
self._running = False
async def get_latest_samples(self, count: int) -> List[Sample]:
# Implementation for getting latest samples
# This is a placeholder - actual implementation needed
pass
async def startup(self):
"""Initialize the manager and start periodic sync"""
logger.info("Starting SampleManager initialization...")
self._running = True
try:
# Start both initial sync and periodic sync as background tasks
self._sync_task = asyncio.create_task(self._periodic_sync())
logger.info("SampleManager started, initial sync running in background")
except Exception as e:
logger.error(f"Startup failed with error: {str(e)}")
raise
async def shutdown(self):
"""Cleanup resources"""
self._running = False
if self._sync_task:
self._sync_task.cancel()
try:
await self._sync_task
except asyncio.CancelledError:
pass
self.executor.shutdown(wait=True)
self._disconnect_sftp()
# Clear memory cache
self.memory_cache.clear()
self.file_index.clear()
logger.info("SampleManager shutdown completed")
async def _connect_sftp(self):
"""Create SFTP connection using SSH key"""
try:
key_path = os.path.expanduser(settings.SFTP_KEY_PATH)
ssh = paramiko.SSHClient()
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
logger.info(f"Attempting connection to {settings.SFTP_HOST} as {settings.SFTP_USER} with key {key_path}")
ssh.connect(
hostname=settings.SFTP_HOST,
username=settings.SFTP_USER,
port=settings.SFTP_PORT,
key_filename=key_path,
)
self.sftp_client = ssh.open_sftp()
except Exception as e:
logger.error(f"SFTP connection failed: {str(e)}")
raise HTTPException(status_code=500, detail=f"SFTP Connection failed: {str(e)}")
def _disconnect_sftp(self):
"""Close SFTP connection"""
if self.sftp_client:
self.sftp_client.close()
self.sftp_client = None
def _download_to_memory(self, remote_path: str) -> memoryview:
"""Download file directly to memory"""
try:
with self.sftp_client.file(remote_path, 'rb') as remote_file:
data = remote_file.read()
return memoryview(data)
except Exception as e:
logger.error(f"File download failed: {str(e)}")
raise HTTPException(status_code=500, detail=f"Download failed: {str(e)}")
async def _sync_files(self):
"""Sync remote files to memory cache"""
if not self.sftp_client:
await self._connect_sftp()
try:
remote_files = self.sftp_client.listdir_attr(settings.SFTP_PATH)
logger.info(f"Found {len(remote_files)} files in remote directory at {settings.SFTP_PATH}")
# if there are files, log some sample names
if remote_files:
logger.info(f"Sample filenames: {[attr.filename for attr in remote_files[:3]]}")
# Track new and updated files
updates = 0
for attr in remote_files:
remote_path = f"{settings.SFTP_PATH}/{attr.filename}"
file_time = datetime.fromtimestamp(attr.st_mtime)
if (attr.filename not in self.file_index or
file_time > self.file_index[attr.filename]):
loop = asyncio.get_event_loop()
self.memory_cache[attr.filename] = await loop.run_in_executor(
self.executor,
self._download_to_memory,
remote_path
)
self.file_index[attr.filename] = file_time
updates += 1
self.last_sync = datetime.now()
if updates > 0:
logger.info(f"Sync completed: {updates} files updated")
except Exception as e:
logger.error(f"Sync failed: {str(e)}")
raise HTTPException(status_code=500, detail=f"Sync failed: {str(e)}")
finally:
self._disconnect_sftp()
async def _periodic_sync(self, interval_seconds: int = 30):
"""Periodically sync files"""
while self._running:
try:
await self._sync_files()
await asyncio.sleep(interval_seconds)
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Periodic sync error: {str(e)}")
# Wait a bit before retrying on error
await asyncio.sleep(5)
async def list_samples(self, limit: int = 20, offset: int = 0) -> List[Sample]:
"""List sample images with pagination"""
logger.info(f"Current file index has {len(self.file_index)} files")
logger.info(f"Memory cache has {len(self.memory_cache)} files")
# Debug: print some keys
logger.info(f"File index keys: {list(self.file_index.keys())[:3]}")
logger.info(f"Memory cache keys: {list(self.memory_cache.keys())[:3]}")
files = sorted(
[(f, self.file_index[f]) for f in self.file_index],
key=lambda x: x[1],
reverse=True
)
logger.info(f"Sorted files list length: {len(files)}")
# Debug: print first few sorted items
if files:
logger.info(f"First few sorted items: {files[:3]}")
files = files[offset:offset + limit]
return [ # This return statement was missing
Sample(
filename=filename,
url=f"{settings.API_VER_STR}/samples/image/{filename}",
created_at=created_at
)
for filename, created_at in files
]
async def get_latest_samples(self, count: int = 5) -> List[Sample]:
"""Get most recent samples"""
return await self.list_samples(limit=count, offset=0)
async def get_sample_data(self, filename: str) -> Optional[memoryview]:
"""Get image data from memory cache"""
if filename not in self.memory_cache:
raise HTTPException(status_code=404, detail="Sample not found")
return self.memory_cache[filename]
def get_stats(self):
"""Get cache statistics"""
return {
"cached_files": len(self.memory_cache),
"cache_size_mb": sum(len(mv) for mv in self.memory_cache.values()) / (1024 * 1024),
"last_sync": self.last_sync.isoformat() if self.last_sync else None,
}

View File

@@ -10,4 +10,5 @@ python-dotenv>=1.0.0
aiofiles>=23.2.1
pytest>=7.4.3
httpx>=0.25.1
pytest-asyncio>=0.21.1
pytest-asyncio>=0.21.1
psutil>=5.9.8