--- /dev/null
+import os
+import re
+from pathlib import Path
+from typing import Optional
+
+import typer
+from alembic.autogenerate import produce_migrations, render_python_code
+from alembic.config import Config
+from alembic.runtime.migration import MigrationContext
+from sqlalchemy import create_engine, pool
+
+migrations_app = typer.Typer()
+
+
+def get_migrations_dir(migrations_path: Optional[str] = None) -> Path:
+ """Get the migrations directory path."""
+ if migrations_path:
+ return Path(migrations_path)
+ return Path.cwd() / "migrations"
+
+
+def get_alembic_config(migrations_dir: Path) -> Config:
+ """Create an Alembic config object programmatically without ini file."""
+ config = Config()
+ config.set_main_option("script_location", str(migrations_dir))
+ config.set_main_option("sqlalchemy.url", "") # Will be set by env.py
+ return config
+
+
+def get_next_migration_number(migrations_dir: Path) -> str:
+ """Get the next sequential migration number."""
+ if not migrations_dir.exists():
+ return "0001"
+
+ migration_files = list(migrations_dir.glob("*.py"))
+ if not migration_files:
+ return "0001"
+
+ numbers = []
+ for f in migration_files:
+ match = re.match(r"^(\d{4})_", f.name)
+ if match:
+ numbers.append(int(match.group(1)))
+
+ if not numbers:
+ return "0001"
+
+ return f"{max(numbers) + 1:04d}"
+
+
+def get_metadata(models_path: str):
+ """Import and return SQLModel metadata."""
+ import sys
+ from importlib import import_module
+
+ # Add current directory to Python path
+ sys.path.insert(0, str(Path.cwd()))
+
+ try:
+ # Import the module containing the models
+ models_module = import_module(models_path)
+
+ # Get SQLModel from the module or import it
+ if hasattr(models_module, "SQLModel"):
+ return models_module.SQLModel.metadata
+ else:
+ # Try importing SQLModel from sqlmodel
+ from sqlmodel import SQLModel
+
+ return SQLModel.metadata
+ except ImportError as e:
+ raise ValueError(
+ f"Failed to import models from '{models_path}': {e}\n"
+ f"Make sure the module exists and is importable from the current directory."
+ )
+
+
+def get_current_revision(db_url: str) -> Optional[str]:
+ """Get the current revision from the database."""
+ import sqlalchemy as sa
+
+ engine = create_engine(db_url, poolclass=pool.NullPool)
+
+ # Create alembic_version table if it doesn't exist
+ with engine.begin() as connection:
+ connection.execute(
+ sa.text("""
+ CREATE TABLE IF NOT EXISTS alembic_version (
+ version_num VARCHAR(32) NOT NULL,
+ CONSTRAINT alembic_version_pkc PRIMARY KEY (version_num)
+ )
+ """)
+ )
+
+ # Get current revision
+ with engine.connect() as connection:
+ result = connection.execute(sa.text("SELECT version_num FROM alembic_version"))
+ row = result.first()
+ return row[0] if row else None
+
+
+def generate_migration_ops(db_url: str, metadata):
+ """Generate migration operations by comparing metadata to database."""
+
+ engine = create_engine(db_url, poolclass=pool.NullPool)
+
+ with engine.connect() as connection:
+ migration_context = MigrationContext.configure(connection)
+ # Use produce_migrations which returns actual operation objects
+ migration_script = produce_migrations(migration_context, metadata)
+
+ return migration_script.upgrade_ops, migration_script.downgrade_ops
+
+
+@migrations_app.command()
+def create(
+ message: str = typer.Option(..., "--message", "-m", help="Migration message"),
+ models: str = typer.Option(
+ ...,
+ "--models",
+ help="Python import path to models module (e.g., 'models' or 'app.models')",
+ ),
+ migrations_path: Optional[str] = typer.Option(
+ None, "--path", "-p", help="Path to migrations directory"
+ ),
+) -> None:
+ """Create a new migration with autogenerate."""
+ migrations_dir = get_migrations_dir(migrations_path)
+
+ # Create migrations directory if it doesn't exist
+ migrations_dir.mkdir(parents=True, exist_ok=True)
+
+ # Get next migration number
+ migration_number = get_next_migration_number(migrations_dir)
+
+ # Create slug from message
+ # TODO: truncate, handle special characters
+ slug = message.lower().replace(" ", "_")
+ slug = re.sub(r"[^a-z0-9_]", "", slug)
+
+ filename = f"{migration_number}_{slug}.py"
+ filepath = migrations_dir / filename
+
+ typer.echo(f"Creating migration: {filename}")
+
+ try:
+ # Get database URL
+ db_url = os.getenv("DATABASE_URL")
+ if not db_url:
+ raise ValueError(
+ "DATABASE_URL environment variable is not set. "
+ "Please set it to your database connection string."
+ )
+
+ # Get metadata
+ metadata = get_metadata(models)
+
+ # Check if there are pending migrations that need to be applied
+ current_revision = get_current_revision(db_url)
+ existing_migrations = sorted(
+ [f for f in migrations_dir.glob("*.py") if f.name != "__init__.py"]
+ )
+
+ if existing_migrations:
+ # Get the latest migration file
+ latest_migration_file = existing_migrations[-1]
+ content = latest_migration_file.read_text()
+ # Extract revision = "..." from the file
+ match = re.search(
+ r'^revision = ["\']([^"\']+)["\']', content, re.MULTILINE
+ )
+ if match:
+ latest_file_revision = match.group(1)
+ # Check if database is up to date
+ if current_revision != latest_file_revision:
+ typer.echo(
+ f"Error: Database is not up to date. Current revision: {current_revision or 'None'}, "
+ f"Latest migration: {latest_file_revision}",
+ err=True,
+ )
+ typer.echo(
+ "Please run 'sqlmodel migrations migrate' to apply pending migrations before creating a new one.",
+ err=True,
+ )
+ raise typer.Exit(1)
+
+ # Generate migration operations
+ upgrade_ops_obj, downgrade_ops_obj = generate_migration_ops(db_url, metadata)
+
+ # Render upgrade
+ if upgrade_ops_obj:
+ upgrade_code = render_python_code(upgrade_ops_obj).strip()
+ else:
+ upgrade_code = "pass"
+
+ # Render downgrade
+ if downgrade_ops_obj:
+ downgrade_code = render_python_code(downgrade_ops_obj).strip()
+ else:
+ # TODO: space :)
+ downgrade_code = "pass"
+
+ # Remove Alembic comments to check if migrations are actually empty
+ def extract_code_without_comments(code: str) -> str:
+ """Extract actual code, removing Alembic auto-generated comments."""
+ lines = code.split("\n")
+ actual_lines = []
+ for line in lines:
+ # Skip Alembic comment lines
+ if not line.strip().startswith("# ###"):
+ actual_lines.append(line)
+ return "\n".join(actual_lines).strip()
+
+ upgrade_code_clean = extract_code_without_comments(upgrade_code)
+ downgrade_code_clean = extract_code_without_comments(downgrade_code)
+
+ # Only reject empty migrations if there are already existing migrations
+ # (i.e., this is not the first migration)
+ if (
+ upgrade_code_clean == "pass"
+ and downgrade_code_clean == "pass"
+ and len(existing_migrations) > 0
+ ):
+ # TODO: better message
+ typer.echo(
+ "Empty migrations are not allowed"
+ ) # TODO: unless you pass `--empty`
+ raise typer.Exit(1)
+
+ # Generate revision ID from filename (without .py extension)
+ revision_id = f"{migration_number}_{slug}"
+
+ # Get previous revision by reading the last migration file's revision ID
+ down_revision = None
+ if existing_migrations:
+ # Read the last migration file to get its revision ID
+ last_migration = existing_migrations[-1]
+ content = last_migration.read_text()
+ # Extract revision = "..." from the file
+ import re as regex_module
+
+ match = regex_module.search(
+ r'^revision = ["\']([^"\']+)["\']', content, regex_module.MULTILINE
+ )
+ if match:
+ down_revision = match.group(1)
+
+ # Check if we need to import sqlmodel
+ needs_sqlmodel = "sqlmodel" in upgrade_code or "sqlmodel" in downgrade_code
+
+ # Generate migration file - build without f-strings to avoid % issues
+ lines: list[str] = []
+ lines.append(f'"""{message}"""')
+ lines.append("")
+ lines.append("import sqlalchemy as sa")
+ if needs_sqlmodel:
+ lines.append("import sqlmodel")
+ lines.append("from alembic import op")
+ lines.append("")
+ lines.append('revision = "' + revision_id + '"')
+ lines.append("down_revision = " + repr(down_revision))
+ lines.append("depends_on = None")
+ lines.append("")
+ lines.append("")
+ lines.append("def upgrade() -> None:")
+
+ # Add upgrade code with proper indentation
+ for line in upgrade_code.split("\n"):
+ lines.append(line)
+
+ lines.append("")
+ lines.append("")
+ lines.append("def downgrade() -> None:")
+
+ # Add downgrade code with proper indentation
+ for line in downgrade_code.split("\n"):
+ lines.append(line)
+
+ migration_content = "\n".join(lines)
+
+ filepath.write_text(migration_content)
+
+ typer.echo(f"✓ Created migration: {filename}")
+
+ except ValueError as e:
+ typer.echo(f"Error creating migration: {e}", err=True)
+ raise typer.Exit(1)
+ except Exception as e:
+ import traceback
+
+ typer.echo(f"Error creating migration: {e}", err=True)
+ typer.echo("\nFull traceback:", err=True)
+ traceback.print_exc()
+ raise typer.Exit(1)
+
+
+def get_pending_migrations(
+ migrations_dir: Path, current_revision: Optional[str]
+) -> list[Path]:
+ """Get list of pending migration files."""
+ all_migrations = sorted(
+ [
+ f
+ for f in migrations_dir.glob("*.py")
+ if f.name != "__init__.py" and f.name != "env.py"
+ ]
+ )
+
+ if not current_revision:
+ return all_migrations
+
+ # Find migrations after the current revision
+ pending = []
+ found_current = False
+ for migration_file in all_migrations:
+ if found_current:
+ pending.append(migration_file)
+ elif migration_file.stem == current_revision:
+ found_current = True
+
+ return pending
+
+
+def apply_migrations_programmatically(
+ migrations_dir: Path, db_url: str, models: str
+) -> None:
+ """Apply migrations programmatically without env.py."""
+ import importlib.util
+
+ import sqlalchemy as sa
+
+ # Get metadata
+ metadata = get_metadata(models)
+
+ # Create engine
+ engine = create_engine(db_url, poolclass=pool.NullPool)
+
+ # Create alembic_version table if it doesn't exist (outside transaction)
+ with engine.begin() as connection:
+ connection.execute(
+ sa.text("""
+ CREATE TABLE IF NOT EXISTS alembic_version (
+ version_num VARCHAR(32) NOT NULL,
+ CONSTRAINT alembic_version_pkc PRIMARY KEY (version_num)
+ )
+ """)
+ )
+
+ # Get current revision
+ with engine.connect() as connection:
+ result = connection.execute(sa.text("SELECT version_num FROM alembic_version"))
+ row = result.first()
+ current_revision = row[0] if row else None
+
+ # Get pending migrations
+ pending_migrations = get_pending_migrations(migrations_dir, current_revision)
+
+ if not pending_migrations:
+ typer.echo(" No pending migrations")
+ return
+
+ # Run each migration
+ for migration_file in pending_migrations:
+ revision_id = migration_file.stem
+ typer.echo(f" Applying: {revision_id}")
+
+ # Execute each migration in its own transaction
+ with engine.begin() as connection:
+ # Create migration context
+ migration_context = MigrationContext.configure(
+ connection, opts={"target_metadata": metadata}
+ )
+
+ # Load the migration module
+ spec = importlib.util.spec_from_file_location(revision_id, migration_file)
+ if not spec or not spec.loader:
+ raise ValueError(f"Could not load migration: {migration_file}")
+
+ module = importlib.util.module_from_spec(spec)
+
+ # Also make sqlalchemy and sqlmodel available
+ import sqlmodel
+
+ module.sa = sa # type: ignore
+ module.sqlmodel = sqlmodel # type: ignore
+
+ # Execute the module to define the functions
+ spec.loader.exec_module(module)
+
+ # Create operations context and run upgrade within ops.invoke_for_target
+ from alembic.operations import ops
+
+ with ops.Operations.context(migration_context):
+ # Now op proxy is available via alembic.op
+ from alembic import op as alembic_op
+
+ module.op = alembic_op # type: ignore
+
+ # Execute upgrade
+ module.upgrade()
+
+ # Update alembic_version table
+ if current_revision:
+ connection.execute(
+ sa.text("UPDATE alembic_version SET version_num = :version"),
+ {"version": revision_id},
+ )
+ else:
+ connection.execute(
+ sa.text(
+ "INSERT INTO alembic_version (version_num) VALUES (:version)"
+ ),
+ {"version": revision_id},
+ )
+
+ current_revision = revision_id
+
+
+@migrations_app.command()
+def migrate(
+ models: str = typer.Option(
+ ...,
+ "--models",
+ help="Python import path to models module (e.g., 'models' or 'app.models')",
+ ),
+ migrations_path: Optional[str] = typer.Option(
+ None, "--path", "-p", help="Path to migrations directory"
+ ),
+) -> None:
+ """Apply all pending migrations to the database."""
+ migrations_dir = get_migrations_dir(migrations_path)
+
+ if not migrations_dir.exists():
+ typer.echo(
+ f"Error: {migrations_dir} not found. Run 'sqlmodel migrations init' first.",
+ err=True,
+ )
+ raise typer.Exit(1)
+
+ # Get database URL
+ db_url = os.getenv("DATABASE_URL")
+ if not db_url:
+ typer.echo("Error: DATABASE_URL environment variable is not set.", err=True)
+ raise typer.Exit(1)
+
+ typer.echo("Applying migrations...")
+
+ try:
+ apply_migrations_programmatically(migrations_dir, db_url, models)
+ typer.echo("✓ Migrations applied successfully")
+ except Exception as e:
+ typer.echo(f"Error applying migrations: {e}", err=True)
+ raise typer.Exit(1)
--- /dev/null
+import shutil
+from pathlib import Path
+
+import pytest
+from inline_snapshot import external, register_format_alias
+from sqlmodel.cli import app
+from typer.testing import CliRunner
+
+register_format_alias(".py", ".txt")
+
+runner = CliRunner()
+
+
+HERE = Path(__file__).parent
+
+register_format_alias(".html", ".txt")
+
+
+def test_create_first_migration(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
+ """Test creating the first migration with an empty database."""
+ db_path = tmp_path / "test.db"
+ db_url = f"sqlite:///{db_path}"
+ migrations_dir = tmp_path / "migrations"
+
+ model_source = HERE / "./fixtures/models_initial.py"
+
+ models_dir = tmp_path / "test_models"
+ models_dir.mkdir()
+
+ (models_dir / "__init__.py").write_text("")
+ models_file = models_dir / "models.py"
+
+ shutil.copy(model_source, models_file)
+
+ monkeypatch.setenv("DATABASE_URL", db_url)
+ monkeypatch.chdir(tmp_path)
+
+ # Run the create command
+ result = runner.invoke(
+ app,
+ [
+ "migrations",
+ "create",
+ "-m",
+ "Initial migration",
+ "--models",
+ "test_models.models",
+ "--path",
+ str(migrations_dir),
+ ],
+ )
+
+ assert result.exit_code == 0, f"Command failed: {result.stdout}"
+ assert "✓ Created migration:" in result.stdout
+
+ migration_files = sorted(
+ [str(f.relative_to(tmp_path)) for f in migrations_dir.glob("*.py")]
+ )
+
+ assert migration_files == [
+ "migrations/0001_initial_migration.py",
+ ]
+
+ migration_file = migrations_dir / "0001_initial_migration.py"
+
+ assert migration_file.read_text() == external(
+ "uuid:f1182584-912e-4f31-9d79-2233e5a8a986.py"
+ )
+
+
+# TODO: to force migration you need to pass `--empty`s
+def test_running_migration_twice_only_generates_migration_once(
+ tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+):
+ db_path = tmp_path / "test.db"
+ db_url = f"sqlite:///{db_path}"
+ migrations_dir = tmp_path / "migrations"
+
+ model_source = HERE / "./fixtures/models_initial.py"
+
+ models_dir = tmp_path / "test_models"
+ models_dir.mkdir()
+
+ (models_dir / "__init__.py").write_text("")
+ models_file = models_dir / "models.py"
+
+ shutil.copy(model_source, models_file)
+
+ monkeypatch.setenv("DATABASE_URL", db_url)
+ monkeypatch.chdir(tmp_path)
+
+ # Run the create command
+ result = runner.invoke(
+ app,
+ [
+ "migrations",
+ "create",
+ "-m",
+ "Initial migration",
+ "--models",
+ "test_models.models",
+ "--path",
+ str(migrations_dir),
+ ],
+ )
+
+ assert result.exit_code == 0, f"Command failed: {result.stdout}"
+ assert "✓ Created migration:" in result.stdout
+
+ # Apply the first migration to the database
+ result = runner.invoke(
+ app,
+ [
+ "migrations",
+ "migrate",
+ "--models",
+ "test_models.models",
+ "--path",
+ str(migrations_dir),
+ ],
+ )
+
+ assert result.exit_code == 0, f"Migration failed: {result.stdout}"
+
+ # Run the create command again (should fail with empty migration)
+ result = runner.invoke(
+ app,
+ [
+ "migrations",
+ "create",
+ "-m",
+ "Initial migration",
+ "--models",
+ "test_models.models",
+ "--path",
+ str(migrations_dir),
+ ],
+ )
+
+ assert result.exit_code == 1
+ assert "Empty migrations are not allowed" in result.stdout
+
+ migration_files = sorted(
+ [str(f.relative_to(tmp_path)) for f in migrations_dir.glob("*.py")]
+ )
+
+ assert migration_files == [
+ "migrations/0001_initial_migration.py",
+ ]
+
+
+def test_cannot_create_migration_with_pending_migrations(
+ tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+):
+ """Test that creating a migration fails if there are unapplied migrations."""
+ db_path = tmp_path / "test.db"
+ db_url = f"sqlite:///{db_path}"
+ migrations_dir = tmp_path / "migrations"
+
+ model_source = HERE / "./fixtures/models_initial.py"
+
+ models_dir = tmp_path / "test_models"
+ models_dir.mkdir()
+
+ (models_dir / "__init__.py").write_text("")
+ models_file = models_dir / "models.py"
+
+ shutil.copy(model_source, models_file)
+
+ monkeypatch.setenv("DATABASE_URL", db_url)
+ monkeypatch.chdir(tmp_path)
+
+ # Run the create command to create first migration
+ result = runner.invoke(
+ app,
+ [
+ "migrations",
+ "create",
+ "-m",
+ "Initial migration",
+ "--models",
+ "test_models.models",
+ "--path",
+ str(migrations_dir),
+ ],
+ )
+
+ assert result.exit_code == 0, f"Command failed: {result.stdout}"
+ assert "✓ Created migration:" in result.stdout
+
+ # Try to create another migration WITHOUT applying the first one
+ result = runner.invoke(
+ app,
+ [
+ "migrations",
+ "create",
+ "-m",
+ "Second migration",
+ "--models",
+ "test_models.models",
+ "--path",
+ str(migrations_dir),
+ ],
+ )
+
+ # Should fail because database is not up to date
+ assert result.exit_code == 1
+ # Error messages are printed to stderr, which Typer's CliRunner combines into output
+ assert (
+ "Database is not up to date" in result.stdout
+ or "Database is not up to date" in str(result.output)
+ )
+ assert (
+ "Please run 'sqlmodel migrations migrate'" in result.stdout
+ or "Please run 'sqlmodel migrations migrate'" in str(result.output)
+ )
+
+ # Verify only one migration file exists
+ migration_files = sorted(
+ [str(f.relative_to(tmp_path)) for f in migrations_dir.glob("*.py")]
+ )
+
+ assert migration_files == [
+ "migrations/0001_initial_migration.py",
+ ]