Add support for local files and optimized caching

Signed-off-by: Felipe Cardoso <felipe.cardoso@hotmail.it>
This commit is contained in:
2025-01-23 09:27:40 +01:00
parent df5b42b9c9
commit 36ce6ac5ef
4 changed files with 88 additions and 47 deletions

View File

@@ -10,7 +10,7 @@ router = APIRouter()
@router.get("/list", response_model=List[Sample])
async def list_samples(
request: Request,
limit: int = Query(20, ge=1, le=100),
limit: int = Query(20, ge=1, le=1000),
offset: int = Query(0, ge=0)
):
"""

View File

@@ -1,17 +1,22 @@
from typing import Optional
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
    """Application configuration, loaded from the environment / `.env` file.

    SFTP settings are optional: remote syncing is only used when they are
    provided. LOCAL_PATH enables syncing samples from a local directory.
    """

    # SFTP Settings (Optional) — all None by default so the app can run
    # without any remote source configured.
    SFTP_HOST: Optional[str] = None
    SFTP_USER: Optional[str] = None
    SFTP_KEY_PATH: Optional[str] = "~/.ssh/id_rsa"  # default SSH key path
    SFTP_PATH: Optional[str] = None
    SFTP_PORT: int = 22

    # Local Settings (Optional)
    LOCAL_PATH: Optional[str] = None

    # API Settings
    PROJECT_NAME: str = "Training Monitor"
    API_VER_STR: str = "/api/v1"

    class Config:
        # pydantic-settings reads values from this dotenv file.
        env_file = ".env"

View File

@@ -1,5 +1,4 @@
from datetime import datetime
from typing import Optional
from pydantic import BaseModel
@@ -8,4 +7,6 @@ class Sample(BaseModel):
filename: str
url: str
created_at: datetime
step: Optional[int] = None
source: str # 'local' or 'remote'
source_path: str # full path to the file
size: int

View File

@@ -18,11 +18,15 @@ class SampleManager:
def __init__(self):
# Lazily-created SFTP client; connected on first remote sync.
self.sftp_client = None
# filename -> raw file bytes (zero-copy views) served by the image endpoint.
self.memory_cache: Dict[str, memoryview] = {}
self.samples: Dict[str, Sample] = {} # Store Sample instances directly
# filename -> mtime; NOTE(review): appears superseded by self.samples — confirm it is still needed.
self.file_index: Dict[str, datetime] = {}
# Timestamp of the last completed sync, None until the first run.
self.last_sync = None
# Thread pool for blocking SFTP downloads (see run_in_executor in remote sync).
self.executor = ThreadPoolExecutor(max_workers=4)
# Handle for the periodic-sync task and its run flag.
self._sync_task = None
self._running = False
# Source roots; either may be absent, in which case that source is skipped.
self.remote_path = settings.SFTP_PATH if hasattr(settings, 'SFTP_PATH') else None
self.local_path = settings.LOCAL_PATH if hasattr(settings, 'LOCAL_PATH') else None
async def startup(self):
"""Initialize the manager and start periodic sync"""
@@ -89,43 +93,91 @@ class SampleManager:
raise HTTPException(status_code=500, detail=f"Download failed: {str(e)}")
async def _sync_files(self):
"""Sync remote files to memory cache"""
"""Sync files from all configured sources"""
if self.local_path:
await self._sync_local_files()
if self.remote_path:
await self._sync_remote_files() # Rename the existing _sync_files to _sync_remote_files
async def _sync_local_files(self):
    """Sync files from the configured local directory into the memory cache.

    Reads every regular file under self.local_path, and for each file that
    is new or has a newer mtime than the cached Sample, loads its bytes
    into self.memory_cache and records/updates its Sample metadata.

    Raises:
        HTTPException: 500 if listing or reading the directory fails.
    """
    if not self.local_path:
        return
    try:
        logger.info(f"Syncing local files from {self.local_path}")
        new_files_count = 0
        for filename in os.listdir(self.local_path):
            full_path = os.path.join(self.local_path, filename)
            # Skip subdirectories and other non-regular entries.
            if not os.path.isfile(full_path):
                continue
            file_time = datetime.fromtimestamp(os.path.getmtime(full_path))
            # Only update if file is new or modified
            if (filename not in self.samples or
                    file_time > self.samples[filename].created_at):
                with open(full_path, 'rb') as f:
                    data = f.read()
                self.memory_cache[filename] = memoryview(data)
                self.samples[filename] = Sample(
                    filename=filename,
                    # Fixed: URL must embed the filename (was a placeholder),
                    # matching the remote-sync URL format.
                    url=f"{settings.API_VER_STR}/samples/image/{filename}",
                    created_at=file_time,
                    source='local',
                    source_path=full_path,
                    size=len(data)
                )
                new_files_count += 1
        logger.info(f"Local sync completed for {self.local_path}")
        logger.info(f"Found {new_files_count} files in local directory at {self.local_path}")
    except Exception as e:
        logger.error(f"Local sync failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Local sync failed: {str(e)}")
async def _sync_remote_files(self):
    """Sync remote files via SFTP into the memory cache.

    Lists the remote directory, downloads every file that is new or has a
    newer mtime than its cached Sample (blocking download runs in the
    thread pool), and records the bytes and Sample metadata. Always
    disconnects the SFTP client afterwards.

    Raises:
        HTTPException: 500 if listing or downloading fails.
    """
    if not self.sftp_client:
        await self._connect_sftp()
    try:
        remote_files = self.sftp_client.listdir_attr(settings.SFTP_PATH)
        logger.info(f"Found {len(remote_files)} files in remote directory at {settings.SFTP_PATH}")
        # if there are files, log some sample names
        if remote_files:
            logger.info(f"Sample filenames: {[attr.filename for attr in remote_files[:3]]}")
        for attr in remote_files:
            remote_path = f"{settings.SFTP_PATH}/{attr.filename}"
            file_time = datetime.fromtimestamp(attr.st_mtime)
            # Only download files that are new or modified since last sync.
            if (attr.filename not in self.samples or
                    file_time > self.samples[attr.filename].created_at):
                loop = asyncio.get_event_loop()
                # Blocking paramiko download — keep it off the event loop.
                data = await loop.run_in_executor(
                    self.executor,
                    self._download_to_memory,
                    remote_path
                )
                self.memory_cache[attr.filename] = data
                self.samples[attr.filename] = Sample(
                    filename=attr.filename,
                    url=f"{settings.API_VER_STR}/samples/image/{attr.filename}",
                    created_at=file_time,
                    source='remote',
                    source_path=remote_path,
                    size=len(data)
                )
        self.last_sync = datetime.now()
        logger.info(f"Remote sync completed with {len(remote_files)} files")
    except Exception as e:
        logger.error(f"Remote sync failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Remote sync failed: {str(e)}")
    finally:
        # Reconnect on each sync cycle; always release the connection.
        self._disconnect_sftp()
async def _periodic_sync(self, interval_seconds: int = 30):
"""Periodically sync files"""
while self._running:
@@ -141,34 +193,17 @@ class SampleManager:
async def list_samples(self, limit: int = 20, offset: int = 0) -> List[Sample]:
    """List sample images sorted newest-first, with pagination.

    Args:
        limit: Maximum number of samples to return.
        offset: Number of samples to skip from the newest end.

    Returns:
        A page of Sample objects ordered by created_at, descending.
    """
    logger.info(f"Total samples: {len(self.samples)}")
    # Sort samples by created_at, newest first.
    sorted_samples = sorted(
        self.samples.values(),
        key=lambda x: x.created_at,
        reverse=True
    )
    return sorted_samples[offset:offset + limit]
async def get_latest_samples(self, count: int = 5) -> List[Sample]:
    """Return the *count* most recent samples (newest first)."""
    latest = await self.list_samples(limit=count, offset=0)
    return latest