- Removed outdated and redundant Alembic migration files to streamline the migration directory. This improves maintainability and eliminates duplicate or unused scripts.
459 lines
14 KiB
Python
Executable File
459 lines
14 KiB
Python
Executable File
#!/usr/bin/env python
|
|
"""
|
|
Database migration helper script.
|
|
Provides convenient commands for generating and applying Alembic migrations.
|
|
|
|
Usage:
|
|
# Generate migration (auto-increments revision ID: 0001, 0002, etc.)
|
|
python migrate.py --local generate "Add new field"
|
|
python migrate.py --local auto "Add new field"
|
|
|
|
# Apply migrations
|
|
python migrate.py --local apply
|
|
|
|
# Show next revision ID
|
|
python migrate.py next
|
|
|
|
# Reset after deleting migrations (clears alembic_version table)
|
|
python migrate.py --local reset
|
|
|
|
# Override auto-increment with custom revision ID
|
|
python migrate.py --local generate "initial_models" --rev-id custom_id
|
|
|
|
# Generate empty migration template without database (no autogenerate)
|
|
python migrate.py generate "Add performance indexes" --offline
|
|
|
|
# Inside Docker (without --local flag):
|
|
python migrate.py auto "Add new field"
|
|
"""
|
|
import argparse
|
|
import os
|
|
import subprocess
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
# Make this script's own directory importable so the ``app`` package
# (``app.core.config``, ``app.models``) resolves. It is appended, not
# prepended, so entries already on sys.path keep their precedence.
project_root = Path(__file__).resolve().parent
if not any(entry == str(project_root) for entry in sys.path):
    sys.path.append(str(project_root))
|
|
|
|
|
|
def setup_database_url(use_local: bool) -> str:
    """Resolve the database URL that migration commands should target.

    When ``use_local`` is true, a localhost URL is used instead of the Docker
    hostname. It comes from ``LOCAL_DATABASE_URL`` or falls back to a built-in
    default. The URL is also exported as ``DATABASE_URL``, so the alembic child
    processes started later inherit it. Presumably ``app.core.config`` also
    reads it if imported afterwards (TODO confirm).

    Otherwise the URL configured in ``app.core.config.settings`` is returned.
    """
    if not use_local:
        # Use the configured DATABASE_URL from environment/.env
        from app.core.config import settings
        return settings.database_url

    fallback = "postgresql://postgres:postgres@localhost:5432/app"
    chosen = os.environ.get("LOCAL_DATABASE_URL", fallback)
    os.environ["DATABASE_URL"] = chosen
    return chosen
|
|
|
|
|
|
def check_models():
    """Import ``app.models`` and print the names it exports via ``__all__``.

    Returns:
        True if the import and listing succeeded. False if anything raised;
        the error is printed rather than propagated.
    """
    print("Checking model imports...")

    try:
        # Go through the package so that every model it exposes is imported
        from app.models import __all__ as exported_names

        total = len(exported_names)
        print(f"Found {total} model(s):")
        for name in exported_names:
            print(f" - {name}")
    except Exception as err:
        # Deliberately broad: any import-time failure should be reported, not crash
        print(f"Error checking models: {err}")
        return False
    return True
|
|
|
|
|
|
def _revision_from_alembic_output(output):
    """Best-effort extraction of the new revision ID from ``alembic revision`` output.

    Alembic announces the new file as ``Generating <path> ...  done``. Under
    Alembic's default ``file_template`` (``%(rev)s_%(slug)s``), the file's
    basename starts with the revision ID. For IDs Alembic generates itself,
    that ID is 12 lowercase hex characters.

    Returns:
        The revision ID, or None when the output does not have that shape
        (for example, with a project-specific ``file_template``).
    """
    import re

    announced = re.search(r"Generating\s+(.+?\.py)\s+\.\.\.", output)
    if announced is None:
        return None
    basename = Path(announced.group(1)).name
    rev_match = re.match(r"([0-9a-f]{12})_", basename)
    return rev_match.group(1) if rev_match else None


def generate_migration(message, rev_id=None, auto_rev_id=True, offline=False):
    """Generate an Alembic migration with the given message.

    Args:
        message: Migration message.
        rev_id: Custom revision ID (overrides auto_rev_id).
        auto_rev_id: If True and rev_id is None, use the next sequential ID
            (see ``get_next_rev_id``).
        offline: If True, write an empty migration template without a database
            connection (no autogenerate). In this mode a sequential ID is
            always assigned when rev_id is None, because this script writes
            the file itself.

    Returns:
        The revision ID when known: the one passed or assigned here, or one
        parsed from Alembic's output. True if Alembic succeeded but the ID
        could not be determined. False on failure. In offline mode, whatever
        ``generate_offline_migration`` returns.
    """
    # Offline generation names the file after the ID. Without one it would
    # write ``None_<slug>.py`` with revision "None".
    if rev_id is None and (auto_rev_id or offline):
        rev_id = get_next_rev_id()

    print(f"Generating migration: {message}")
    if rev_id:
        print(f"Using revision ID: {rev_id}")

    if offline:
        # Generate migration file directly without database connection
        return generate_offline_migration(message, rev_id)

    cmd = ["alembic", "revision", "--autogenerate", "-m", message]
    if rev_id:
        cmd.extend(["--rev-id", rev_id])

    try:
        result = subprocess.run(cmd, capture_output=True, text=True)
    except OSError as exc:
        # Typically the ``alembic`` executable is not installed / not on PATH
        print(
            f"Error generating migration: could not run alembic ({exc})",
            file=sys.stderr,
        )
        return False

    print(result.stdout)
    if result.returncode != 0:
        print("Error generating migration:", file=sys.stderr)
        print(result.stderr, file=sys.stderr)
        return False

    # If we chose the ID we already know it. Only parse Alembic's output for
    # IDs Alembic generated itself.
    revision = rev_id or _revision_from_alembic_output(result.stdout)
    if revision:
        print(f"Generated revision: {revision}")
    else:
        print("Generated migration (revision ID not identified)")

    return revision or True
|
|
|
|
|
|
def apply_migration(revision=None):
    """Upgrade the database to ``revision`` via ``alembic upgrade``.

    Args:
        revision: Target revision; defaults to ``"head"`` when falsy.

    Returns:
        True on success. False if alembic could not be started or exited
        non-zero; in that case its stderr is echoed to our stderr.
    """
    target = revision or "head"
    print(f"Applying migration(s) to: {target}")

    cmd = ["alembic", "upgrade", target]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True)
    except OSError as exc:
        # Typically the ``alembic`` executable is not installed / not on PATH
        print(
            f"Error applying migration: could not run alembic ({exc})",
            file=sys.stderr,
        )
        return False

    print(result.stdout)
    if result.returncode != 0:
        print("Error applying migration:", file=sys.stderr)
        print(result.stderr, file=sys.stderr)
        return False

    print("Migration(s) applied successfully")
    return True
|
|
|
|
|
|
def show_current():
    """Print the database's current Alembic revision (``alembic current``).

    Returns:
        True on success. False if alembic could not be started or exited
        non-zero; in that case its stderr is echoed to our stderr.
    """
    print("Current database revision:")

    cmd = ["alembic", "current"]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True)
    except OSError as exc:
        # Typically the ``alembic`` executable is not installed / not on PATH
        print(
            f"Error getting current revision: could not run alembic ({exc})",
            file=sys.stderr,
        )
        return False

    print(result.stdout)
    if result.returncode != 0:
        print("Error getting current revision:", file=sys.stderr)
        print(result.stderr, file=sys.stderr)
        return False

    return True
|
|
|
|
|
|
def list_migrations():
    """Print the full migration history (``alembic history --verbose``).

    Returns:
        True on success. False if alembic could not be started or exited
        non-zero; in that case its stderr is echoed to our stderr.
    """
    print("Listing migrations:")

    cmd = ["alembic", "history", "--verbose"]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True)
    except OSError as exc:
        # Typically the ``alembic`` executable is not installed / not on PATH
        print(
            f"Error listing migrations: could not run alembic ({exc})",
            file=sys.stderr,
        )
        return False

    print(result.stdout)
    if result.returncode != 0:
        print("Error listing migrations:", file=sys.stderr)
        print(result.stderr, file=sys.stderr)
        return False

    return True
|
|
|
|
|
|
def check_database_connection():
    """Open (and immediately close) a connection to verify the database is reachable.

    The URL comes from ``DATABASE_URL`` in the environment, which
    ``setup_database_url`` sets in ``--local`` mode. If that is unset, it
    falls back to ``app.core.config.settings.database_url``.

    Returns:
        True if a connection could be opened. False on any SQLAlchemy error,
        or if the database driver module cannot be imported; the error is
        printed.
    """
    from sqlalchemy import create_engine
    from sqlalchemy.exc import SQLAlchemyError

    db_url = os.environ.get("DATABASE_URL")
    if not db_url:
        from app.core.config import settings
        db_url = settings.database_url

    engine = None
    try:
        engine = create_engine(db_url)
        with engine.connect():
            print("✓ Database connection successful!")
        return True
    except (SQLAlchemyError, ImportError) as e:
        # ImportError: the DBAPI driver (e.g. psycopg) named by the URL is missing
        print(f"✗ Error connecting to database: {e}")
        return False
    finally:
        # Release pooled connections instead of leaving them to interpreter exit
        if engine is not None:
            engine.dispose()
|
|
|
|
|
|
def get_next_rev_id():
    """Compute the sequential revision ID the next migration should use.

    Scans ``app/alembic/versions`` for ``.py`` files whose names start with
    exactly four digits and an underscore (e.g. ``0007_add_x.py``). It takes
    the largest of those numbers and adds one, falling back to ``"0001"``
    when the directory is missing or has no such files.

    Returns:
        The ID as a zero-padded string of at least four digits, e.g. ``"0008"``.
    """
    import re

    migrations_path = project_root / "app" / "alembic" / "versions"
    if not migrations_path.exists():
        return "0001"

    numbered_name = re.compile(r"^(\d{4})_.*\.py$")
    candidates = (
        numbered_name.match(entry.name)
        for entry in migrations_path.iterdir()
        if entry.is_file() and entry.suffix == ".py"
    )
    highest = max((int(m.group(1)) for m in candidates if m), default=0)
    return "{:04d}".format(highest + 1)
|
|
|
|
|
|
def get_current_rev_id():
    """Return the latest sequential revision ID found in ``app/alembic/versions``.

    Only ``.py`` files named ``NNNN_<anything>.py`` (exactly four digits and
    an underscore) are considered. Their numeric prefixes are compared and
    the highest one is returned as the original four-character string.
    ``"0000"`` counts as a valid first migration.

    Returns:
        The highest prefix (e.g. ``"0003"``), or None when the directory is
        missing or holds no numbered migrations.
    """
    import re

    versions_dir = project_root / "app" / "alembic" / "versions"
    if not versions_dir.exists():
        return None

    pattern = re.compile(r"^(\d{4})_.*\.py$")
    latest = None
    for f in versions_dir.iterdir():
        if not (f.is_file() and f.suffix == ".py"):
            continue
        match = pattern.match(f.name)
        # Compare against None rather than a 0 sentinel, so a lone
        # "0000_*.py" is still reported as the current revision
        if match and (latest is None or int(match.group(1)) > int(latest)):
            latest = match.group(1)

    return latest
|
|
|
|
|
|
def generate_offline_migration(message, rev_id):
    """Generate a migration file without a database connection.

    Creates an empty migration template to be filled in manually, at
    ``app/alembic/versions/<rev_id>_<slug>.py``. Its ``down_revision`` is the
    highest-numbered existing migration (``get_current_rev_id``). This is
    useful for hand-written changes such as performance indexes, or when no
    database is available.

    Args:
        message: Human-readable description. It becomes the module docstring
            and, slugified, part of the filename.
        rev_id: Revision ID for the new migration. If empty or None, the next
            sequential ID (``get_next_rev_id``) is used.

    Returns:
        The revision ID on success. False if a migration would be overwritten
        or a numeric revision ID is already in use; nothing is written then.
    """
    from datetime import datetime

    if not rev_id:
        rev_id = get_next_rev_id()

    versions_dir = project_root / "app" / "alembic" / "versions"
    versions_dir.mkdir(parents=True, exist_ok=True)

    # Slugify the message for filename (fall back if nothing usable remains)
    slug = message.lower().replace(" ", "_").replace("-", "_")
    slug = "".join(c for c in slug if c.isalnum() or c == "_") or "migration"

    filename = f"{rev_id}_{slug}.py"
    filepath = versions_dir / filename

    # Never overwrite an existing file. For the numeric sequential scheme,
    # also never create a second migration with the same revision ID.
    numeric_id = str(rev_id).isdigit()
    clashes = sorted(
        p.name
        for p in versions_dir.glob("*.py")
        if p.name == filename
        or (numeric_id and p.name.startswith(f"{rev_id}_"))
    )
    if clashes:
        print(
            f"Error: revision {rev_id} already exists ({', '.join(clashes)}); "
            "not writing a new migration.",
            file=sys.stderr,
        )
        return False

    # Get the previous revision ID
    down_revision = get_current_rev_id()
    down_rev_str = f'"{down_revision}"' if down_revision else "None"

    # Escape backslashes and triple quotes so the message cannot break or
    # alter the generated module docstring
    doc_message = message.replace("\\", "\\\\").replace('"""', '\\"\\"\\"')

    # Generate the migration file content
    content = f'''"""{doc_message}

Revision ID: {rev_id}
Revises: {down_revision or ''}
Create Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')}

"""

from collections.abc import Sequence

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "{rev_id}"
down_revision: str | None = {down_rev_str}
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
    # TODO: Add your upgrade operations here
    pass


def downgrade() -> None:
    # TODO: Add your downgrade operations here
    pass
'''

    # Python source files are decoded as UTF-8, so never rely on the
    # platform's default encoding for non-ASCII messages
    filepath.write_text(content, encoding="utf-8")
    print(f"Generated offline migration: {filepath}")
    return rev_id
|
|
|
|
|
|
def show_next_rev_id():
    """Print the upcoming sequential revision ID with example commands, and return it."""
    upcoming = get_next_rev_id()
    report = (
        f"Next revision ID: {upcoming}",
        "\nUsage:",
        f" python migrate.py --local generate 'your_message' --rev-id {upcoming}",
        f" python migrate.py --local auto 'your_message' --rev-id {upcoming}",
    )
    for line in report:
        print(line)
    return upcoming
|
|
|
|
|
|
def reset_alembic_version():
    """Drop the ``alembic_version`` table (for a fresh start after deleting migrations).

    Only the version-tracking table is dropped; no other schema objects are
    touched. A later ``apply`` will therefore run migrations from the base
    revision.

    NOTE(review): this assumes the default version table name. A custom
    ``version_table``/schema configured in env.py would not be affected.
    TODO confirm.

    Returns:
        True on success. False on any SQLAlchemy error, or if the database
        driver module cannot be imported; the error is printed.
    """
    from sqlalchemy import create_engine, text
    from sqlalchemy.exc import SQLAlchemyError

    db_url = os.environ.get("DATABASE_URL")
    if not db_url:
        from app.core.config import settings
        db_url = settings.database_url

    engine = None
    try:
        engine = create_engine(db_url)
        # begin() commits when the block exits cleanly and rolls back on error
        with engine.begin() as conn:
            conn.execute(text("DROP TABLE IF EXISTS alembic_version"))
    except (SQLAlchemyError, ImportError) as e:
        # ImportError: the DBAPI driver named by the URL is not installed
        print(f"✗ Error resetting alembic version: {e}")
        return False
    finally:
        # Release pooled connections instead of leaving them to interpreter exit
        if engine is not None:
            engine.dispose()

    print("✓ Alembic version table reset successfully")
    print(" You can now run migrations from scratch")
    return True
|
|
|
|
|
|
def main():
|
|
"""Main function"""
|
|
parser = argparse.ArgumentParser(
|
|
description='Database migration helper for Generative Models Arena'
|
|
)
|
|
|
|
# Global options
|
|
parser.add_argument(
|
|
'--local', '-l',
|
|
action='store_true',
|
|
help='Use localhost instead of Docker hostname (for local development)'
|
|
)
|
|
|
|
subparsers = parser.add_subparsers(dest='command', help='Command to run')
|
|
|
|
# Generate command
|
|
generate_parser = subparsers.add_parser('generate', help='Generate a migration')
|
|
generate_parser.add_argument('message', help='Migration message')
|
|
generate_parser.add_argument(
|
|
'--rev-id',
|
|
help='Custom revision ID (e.g., 0001, 0002 for sequential naming)'
|
|
)
|
|
generate_parser.add_argument(
|
|
'--offline',
|
|
action='store_true',
|
|
help='Generate empty migration template without database connection'
|
|
)
|
|
|
|
# Apply command
|
|
apply_parser = subparsers.add_parser('apply', help='Apply migrations')
|
|
apply_parser.add_argument('--revision', help='Specific revision to apply to')
|
|
|
|
# List command
|
|
subparsers.add_parser('list', help='List migrations')
|
|
|
|
# Current command
|
|
subparsers.add_parser('current', help='Show current revision')
|
|
|
|
# Check command
|
|
subparsers.add_parser('check', help='Check database connection and models')
|
|
|
|
# Next command (show next revision ID)
|
|
subparsers.add_parser('next', help='Show the next sequential revision ID')
|
|
|
|
# Reset command (clear alembic_version table)
|
|
subparsers.add_parser(
|
|
'reset',
|
|
help='Reset alembic_version table (use after deleting all migrations)'
|
|
)
|
|
|
|
# Auto command (generate and apply)
|
|
auto_parser = subparsers.add_parser('auto', help='Generate and apply migration')
|
|
auto_parser.add_argument('message', help='Migration message')
|
|
auto_parser.add_argument(
|
|
'--rev-id',
|
|
help='Custom revision ID (e.g., 0001, 0002 for sequential naming)'
|
|
)
|
|
auto_parser.add_argument(
|
|
'--offline',
|
|
action='store_true',
|
|
help='Generate empty migration template without database connection'
|
|
)
|
|
|
|
args = parser.parse_args()
|
|
|
|
# Commands that don't need database connection
|
|
if args.command == 'next':
|
|
show_next_rev_id()
|
|
return
|
|
|
|
# Check if offline mode is requested
|
|
offline = getattr(args, 'offline', False)
|
|
|
|
# Offline generate doesn't need database or model check
|
|
if args.command == 'generate' and offline:
|
|
generate_migration(args.message, rev_id=args.rev_id, offline=True)
|
|
return
|
|
|
|
if args.command == 'auto' and offline:
|
|
generate_migration(args.message, rev_id=args.rev_id, offline=True)
|
|
print("\nOffline migration generated. Apply it later with:")
|
|
print(f" python migrate.py --local apply")
|
|
return
|
|
|
|
# Setup database URL (must be done before importing settings elsewhere)
|
|
db_url = setup_database_url(args.local)
|
|
print(f"Using database URL: {db_url}")
|
|
|
|
if args.command == 'generate':
|
|
check_models()
|
|
generate_migration(args.message, rev_id=args.rev_id)
|
|
|
|
elif args.command == 'apply':
|
|
apply_migration(args.revision)
|
|
|
|
elif args.command == 'list':
|
|
list_migrations()
|
|
|
|
elif args.command == 'current':
|
|
show_current()
|
|
|
|
elif args.command == 'check':
|
|
check_database_connection()
|
|
check_models()
|
|
|
|
elif args.command == 'reset':
|
|
reset_alembic_version()
|
|
|
|
elif args.command == 'auto':
|
|
check_models()
|
|
revision = generate_migration(args.message, rev_id=args.rev_id)
|
|
if revision:
|
|
input("\nPress Enter to apply migration or Ctrl+C to abort... ")
|
|
apply_migration()
|
|
|
|
else:
|
|
parser.print_help()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|