import io
import re
+from sqlalchemy import Column
from sqlalchemy import create_engine
+from sqlalchemy import inspect
from sqlalchemy import MetaData
+from sqlalchemy import String
+from sqlalchemy import Table
from sqlalchemy import text
import alembic
+from . import config
from . import mock
from .assertions import _get_dialect
from .assertions import eq_
alembic.op._proxy = Operations(context)
return context
+
+
+class AlterColRoundTripFixture(object):
+
+ # since these tests are about syntax, use more recent SQLAlchemy as some of
+ # the type / server default compare logic might not work on older
+ # SQLAlchemy versions as seems to be the case for SQLAlchemy 1.1 on Oracle
+
+ __requires__ = ("alter_column", "sqlachemy_12")
+
+ def setUp(self):
+ self.conn = config.db.connect()
+ self.ctx = MigrationContext.configure(self.conn)
+ self.op = Operations(self.ctx)
+ self.metadata = MetaData()
+
+ def _compare_type(self, t1, t2):
+ c1 = Column("q", t1)
+ c2 = Column("q", t2)
+ assert not self.ctx.impl.compare_type(
+ c1, c2
+ ), "Type objects %r and %r didn't compare as equivalent" % (t1, t2)
+
+ def _compare_server_default(self, t1, s1, t2, s2):
+ c1 = Column("q", t1, server_default=s1)
+ c2 = Column("q", t2, server_default=s2)
+ assert not self.ctx.impl.compare_server_default(
+ c1, c2, s2, s1
+ ), "server defaults %r and %r didn't compare as equivalent" % (s1, s2)
+
+ def tearDown(self):
+ self.metadata.drop_all(self.conn)
+ self.conn.close()
+
+ def _run_alter_col(self, from_, to_):
+ column = Column(
+ from_.get("name", "colname"),
+ from_.get("type", String(10)),
+ nullable=from_.get("nullable", True),
+ server_default=from_.get("server_default", None),
+ # comment=from_.get("comment", None)
+ )
+ t = Table("x", self.metadata, column)
+
+ t.create(self.conn)
+ insp = inspect(self.conn)
+ old_col = insp.get_columns("x")[0]
+
+ # TODO: conditional comment support
+ self.op.alter_column(
+ "x",
+ column.name,
+ existing_type=column.type,
+ existing_server_default=column.server_default
+ if column.server_default is not None
+ else False,
+ existing_nullable=True if column.nullable else False,
+ # existing_comment=column.comment,
+ nullable=to_.get("nullable", None),
+ # modify_comment=False,
+ server_default=to_.get("server_default", False),
+ new_column_name=to_.get("name", None),
+ type_=to_.get("type", None),
+ )
+
+ insp = inspect(self.conn)
+ new_col = insp.get_columns("x")[0]
+
+ eq_(new_col["name"], to_["name"] if "name" in to_ else column.name)
+ self._compare_type(new_col["type"], to_.get("type", old_col["type"]))
+ eq_(new_col["nullable"], to_.get("nullable", column.nullable))
+ self._compare_server_default(
+ new_col["type"],
+ new_col.get("default", None),
+ to_.get("type", old_col["type"]),
+ to_["server_default"].text
+ if "server_default" in to_
+ else column.server_default.arg.text
+ if column.server_default is not None
+ else None,
+ )
def reflects_fk_options(self):
return exclusions.closed()
+ @property
+ def sqlachemy_12(self):
+ return exclusions.skip_if(
+ lambda config: not util.sqla_1216,
+ "SQLAlchemy 1.2.16 or greater required",
+ )
+
@property
def fail_before_sqla_100(self):
return exclusions.fails_if(
def check(config):
vers = sqla_compat._vers
- if vers == (1, 3, 0, 'b1'):
+ if vers == (1, 3, 0, "b1"):
return True
elif vers >= (1, 2, 16):
return False
return True
return exclusions.skip_if(
- check,
- "SQLAlchemy 1.2.16, 1.3.0b2 or greater required",
+ check, "SQLAlchemy 1.2.16, 1.3.0b2 or greater required"
)
@property
@property
def comments_api(self):
return exclusions.only_if(lambda config: util.sqla_120)
+
+ @property
+ def alter_column(self):
+ return exclusions.open()
from alembic.testing import eq_
from alembic.testing import is_
from alembic.testing import mock
+from alembic.testing.fixtures import AlterColRoundTripFixture
from alembic.testing.fixtures import op_fixture
from alembic.testing.fixtures import TestBase
assert_raises_message(
ValueError, "constraint cannot be produced", op.to_constraint
)
+
+
+# MARKMARK
+class BackendAlterColumnTest(AlterColRoundTripFixture, TestBase):
+ __backend__ = True
+
+ def test_rename_column(self):
+ self._run_alter_col({}, {"name": "newname"})
+
+ def test_modify_type_int_str(self):
+ self._run_alter_col({"type": Integer()}, {"type": String(50)})
+
+ def test_add_server_default_int(self):
+ self._run_alter_col({"type": Integer}, {"server_default": text("5")})
+
+ def test_modify_server_default_int(self):
+ self._run_alter_col(
+ {"type": Integer, "server_default": text("2")},
+ {"server_default": text("5")},
+ )
+
+ def test_modify_nullable_to_non(self):
+ self._run_alter_col({}, {"nullable": False})
+
+ def test_modify_non_nullable_to_nullable(self):
+ self._run_alter_col({"nullable": False}, {"nullable": True})