Add support for local files and optimized caching
Signed-off-by: Felipe Cardoso <felipe.cardoso@hotmail.it>
This commit is contained in:
@@ -10,7 +10,7 @@ router = APIRouter()
|
||||
@router.get("/list", response_model=List[Sample])
|
||||
async def list_samples(
|
||||
request: Request,
|
||||
limit: int = Query(20, ge=1, le=100),
|
||||
limit: int = Query(20, ge=1, le=1000),
|
||||
offset: int = Query(0, ge=0)
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -1,17 +1,22 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Application configuration, loaded from the environment and ``.env``.

    Both file sources are optional: the SFTP fields and ``LOCAL_PATH`` all
    default to ``None`` (or a fixed default), so the settings object can be
    built with only one source, or neither, configured.
    """

    # SFTP Settings (Optional)
    SFTP_HOST: Optional[str] = None
    SFTP_USER: Optional[str] = None
    SFTP_KEY_PATH: Optional[str] = "~/.ssh/id_rsa"  # Default SSH key path
    SFTP_PATH: Optional[str] = None
    SFTP_PORT: int = 22

    # Local Settings (Optional)
    LOCAL_PATH: Optional[str] = None

    # API Settings
    API_VER_STR: str = "/api/v1"
    PROJECT_NAME: str = "Training Monitor"

    class Config:
        env_file = ".env"
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -8,4 +7,6 @@ class Sample(BaseModel):
|
||||
filename: str
|
||||
url: str
|
||||
created_at: datetime
|
||||
step: Optional[int] = None
|
||||
source: str # 'local' or 'remote'
|
||||
source_path: str # full path to the file
|
||||
size: int
|
||||
|
||||
@@ -18,11 +18,15 @@ class SampleManager:
|
||||
def __init__(self):
    """Create empty caches, the worker pool, and read the configured paths.

    Nothing is connected or read from disk here; the attributes only hold
    the state that the sync and listing methods operate on.
    """
    # SFTP client handle; unset until a connection is made.
    self.sftp_client = None
    # File contents keyed by filename.
    self.memory_cache: Dict[str, memoryview] = {}
    # Sample metadata keyed by filename (Sample instances stored directly).
    self.samples: Dict[str, Sample] = {}
    # Modification time recorded per filename.
    self.file_index: Dict[str, datetime] = {}
    self.last_sync = None
    # Pool used to run blocking work off the event loop.
    self.executor = ThreadPoolExecutor(max_workers=4)
    self._sync_task = None
    self._running = False
    # Either path may be absent or None; a falsy value disables that source.
    self.remote_path = getattr(settings, 'SFTP_PATH', None)
    self.local_path = getattr(settings, 'LOCAL_PATH', None)
|
||||
|
||||
|
||||
async def startup(self):
|
||||
"""Initialize the manager and start periodic sync"""
|
||||
@@ -89,43 +93,91 @@ class SampleManager:
|
||||
raise HTTPException(status_code=500, detail=f"Download failed: {str(e)}")
|
||||
|
||||
async def _sync_files(self):
    """Sync files from all configured sources.

    Runs the local sync when ``self.local_path`` is set, then the remote
    sync when ``self.remote_path`` is set. A source whose path is falsy is
    skipped. Exceptions raised by either sync propagate to the caller.
    """
    if self.local_path:
        await self._sync_local_files()

    if self.remote_path:
        await self._sync_remote_files()
|
||||
|
||||
async def _sync_local_files(self):
    """Sync files from the local directory into the in-memory cache.

    Every regular file directly inside ``self.local_path`` that is not yet
    in ``self.samples``, or whose modification time is newer than the cached
    sample's ``created_at``, is read and stored in ``self.memory_cache``
    together with a fresh ``Sample`` entry. Does nothing when
    ``self.local_path`` is not set.

    Raises:
        HTTPException: status 500 if listing or reading fails.
    """
    if not self.local_path:
        return

    try:
        logger.info(f"Syncing local files from {self.local_path}")
        loop = asyncio.get_running_loop()

        # Directory listing and file reads block, so they run on the
        # worker pool instead of stalling the event loop.
        entries = await loop.run_in_executor(self.executor, self._list_local_files)

        new_files_count = 0
        for filename, full_path, file_time in entries:
            # Only update if file is new or modified
            known = self.samples.get(filename)
            if known is not None and file_time <= known.created_at:
                continue

            data = await loop.run_in_executor(
                self.executor, self._read_local_file, full_path
            )
            self.memory_cache[filename] = memoryview(data)
            self.samples[filename] = Sample(
                filename=filename,
                url=f"{settings.API_VER_STR}/samples/image/{filename}",
                created_at=file_time,
                source='local',
                source_path=full_path,
                size=len(data)
            )
            new_files_count += 1

        logger.info(f"Local sync completed for {self.local_path}")
        # The counter only covers files that were (re)loaded in this pass.
        logger.info(
            f"Loaded {new_files_count} new or updated files "
            f"from local directory at {self.local_path}"
        )
    except Exception as e:
        logger.error(f"Local sync failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Local sync failed: {str(e)}") from e


def _list_local_files(self):
    """Return (filename, full_path, mtime) for each regular file in local_path."""
    entries = []
    for filename in os.listdir(self.local_path):
        full_path = os.path.join(self.local_path, filename)
        if not os.path.isfile(full_path):
            continue
        file_time = datetime.fromtimestamp(os.path.getmtime(full_path))
        entries.append((filename, full_path, file_time))
    return entries


def _read_local_file(self, full_path):
    """Return the full binary contents of the file at full_path."""
    with open(full_path, 'rb') as f:
        return f.read()
|
||||
|
||||
async def _sync_remote_files(self):
    """Sync remote files via SFTP into the in-memory cache.

    Connects first if there is no SFTP client, lists ``settings.SFTP_PATH``,
    and downloads every entry that is not yet in ``self.samples`` or whose
    modification time is newer than the cached sample's ``created_at``.
    Downloads run on ``self.executor``. The SFTP connection is always
    closed afterwards, whether the sync succeeded or not.

    Raises:
        HTTPException: status 500 if listing or downloading fails.
    """
    if not self.sftp_client:
        await self._connect_sftp()

    try:
        remote_files = self.sftp_client.listdir_attr(settings.SFTP_PATH)
        logger.info(f"Found {len(remote_files)} files in remote directory at {settings.SFTP_PATH}")
        # if there are files, log some sample names
        if remote_files:
            logger.info(f"Sample filenames: {[attr.filename for attr in remote_files[:3]]}")

        loop = asyncio.get_running_loop()
        # Track new and updated files
        updates = 0

        for attr in remote_files:
            remote_path = f"{settings.SFTP_PATH}/{attr.filename}"
            file_time = datetime.fromtimestamp(attr.st_mtime)

            # Skip files whose cached sample is already up to date.
            known = self.samples.get(attr.filename)
            if known is not None and file_time <= known.created_at:
                continue

            data = await loop.run_in_executor(
                self.executor,
                self._download_to_memory,
                remote_path
            )
            self.memory_cache[attr.filename] = data
            self.file_index[attr.filename] = file_time
            self.samples[attr.filename] = Sample(
                filename=attr.filename,
                url=f"{settings.API_VER_STR}/samples/image/{attr.filename}",
                created_at=file_time,
                source='remote',
                source_path=remote_path,
                size=len(data)
            )
            updates += 1

        self.last_sync = datetime.now()
        if updates > 0:
            logger.info(f"Sync completed: {updates} files updated")
        logger.info(f"Remote sync completed with {len(remote_files)} files")

    except Exception as e:
        logger.error(f"Remote sync failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Remote sync failed: {str(e)}") from e
    finally:
        self._disconnect_sftp()
|
||||
|
||||
|
||||
async def _periodic_sync(self, interval_seconds: int = 30):
|
||||
"""Periodically sync files"""
|
||||
while self._running:
|
||||
@@ -141,34 +193,17 @@ class SampleManager:
|
||||
|
||||
async def list_samples(self, limit: int = 20, offset: int = 0) -> List[Sample]:
    """List sample images with pagination, newest first.

    Args:
        limit: Maximum number of samples to return.
        offset: Number of samples to skip from the start of the
            newest-first ordering.

    Returns:
        The cached ``Sample`` objects in positions ``offset`` to
        ``offset + limit`` after sorting by ``created_at`` descending.
        Empty when the offset is past the end.
    """
    logger.info(f"Memory cache has {len(self.memory_cache)} files")
    logger.info(f"Total samples: {len(self.samples)}")

    # Sort samples by created_at; the stored Sample objects are returned
    # as-is so every required field is already populated.
    sorted_samples = sorted(
        self.samples.values(),
        key=lambda x: x.created_at,
        reverse=True
    )

    return sorted_samples[offset:offset + limit]
|
||||
async def get_latest_samples(self, count: int = 5) -> List[Sample]:
    """Return the first ``count`` entries of ``list_samples`` (offset 0)."""
    latest = await self.list_samples(limit=count, offset=0)
    return latest
|
||||
|
||||
Reference in New Issue
Block a user