]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
Add support identity columns.
authorCaselIT <cfederico87@gmail.com>
Tue, 17 Nov 2020 21:32:36 +0000 (22:32 +0100)
committerCaselIT <cfederico87@gmail.com>
Tue, 15 Dec 2020 22:27:06 +0000 (23:27 +0100)
Added support for rendering of "identity" elements on
:class:`.Column` objects, supported in SQLAlchemy via
the :class:`.Identity` element introduced in version 1.4.

Adding columns with identity is supported on PostgreSQL,
MSSQL and Oracle. Changing the identity options or removing
it is supported only on PostgreSQL and Oracle.

Fixes: #730
Change-Id: I184d8bd30ef5c6286f3263b4033ca4eb359c02e2

20 files changed:
alembic/autogenerate/api.py
alembic/autogenerate/compare.py
alembic/autogenerate/render.py
alembic/ddl/base.py
alembic/ddl/impl.py
alembic/ddl/mssql.py
alembic/ddl/mysql.py
alembic/ddl/oracle.py
alembic/ddl/postgresql.py
alembic/testing/requirements.py
alembic/util/sqla_compat.py
docs/build/unreleased/730.rst [new file with mode: 0644]
tests/requirements.py
tests/test_autogen_identity.py [new file with mode: 0644]
tests/test_autogen_render.py
tests/test_mssql.py
tests/test_mysql.py
tests/test_op.py
tests/test_oracle.py
tests/test_postgresql.py

index 2902f80a9fdf3e0e2283ada3818b87bc5ee34a54..5d1e84816f799c7f4e3e9fae4fbceaface6cf279 100644 (file)
@@ -220,7 +220,7 @@ class AutogenContext(object):
     connected to the database backend being compared.
 
     This is obtained from the :attr:`.MigrationContext.bind` and is
-    utimately set up in the ``env.py`` script.
+    ultimately set up in the ``env.py`` script.
 
     """
 
index 08d3b3bb71e3e31b4decfd8ad8b04e593c69ec77..b82225d9476a322c118f130213fe9238c0e2e9d0 100644 (file)
@@ -161,7 +161,7 @@ def _compare_tables(
                 (inspector),
                 # fmt: on
             )
-            inspector.reflecttable(t, None)
+            sqla_compat._reflect_table(inspector, t, None)
         if autogen_context.run_filters(t, tname, "table", True, None):
 
             modify_table_ops = ops.ModifyTableOps(tname, [], schema=s)
@@ -192,7 +192,7 @@ def _compare_tables(
                 _compat_autogen_column_reflect(inspector),
                 # fmt: on
             )
-            inspector.reflecttable(t, None)
+            sqla_compat._reflect_table(inspector, t, None)
         conn_column_info[(s, tname)] = t
 
     for s, tname in sorted(existing_tables, key=lambda x: (x[0] or "", x[1])):
@@ -810,13 +810,22 @@ def _compare_nullable(
     alter_column_op.existing_nullable = conn_col_nullable
 
     if conn_col_nullable is not metadata_col_nullable:
-        alter_column_op.modify_nullable = metadata_col_nullable
-        log.info(
-            "Detected %s on column '%s.%s'",
-            "NULL" if metadata_col_nullable else "NOT NULL",
-            tname,
-            cname,
-        )
+        if sqla_compat._server_default_is_identity(
+            metadata_col.server_default, conn_col.server_default
+        ):
+            log.info(
+                "Ignoring nullable change on identity column '%s.%s'",
+                tname,
+                cname,
+            )
+        else:
+            alter_column_op.modify_nullable = metadata_col_nullable
+            log.info(
+                "Detected %s on column '%s.%s'",
+                "NULL" if metadata_col_nullable else "NOT NULL",
+                tname,
+                cname,
+            )
 
 
 @comparators.dispatch_for("column")
@@ -969,6 +978,23 @@ def _warn_computed_not_supported(tname, cname):
     util.warn("Computed default on %s.%s cannot be modified" % (tname, cname))
 
 
+def _compare_identity_default(
+    autogen_context,
+    alter_column_op,
+    schema,
+    tname,
+    cname,
+    conn_col,
+    metadata_col,
+):
+    impl = autogen_context.migration_context.impl
+    diff, ignored_attr = impl._compare_identity_default(
+        metadata_col.server_default, conn_col.server_default
+    )
+
+    return diff
+
+
 @comparators.dispatch_for("column")
 def _compare_server_default(
     autogen_context,
@@ -985,9 +1011,7 @@ def _compare_server_default(
     if conn_col_default is None and metadata_default is None:
         return False
 
-    if sqla_compat.has_computed and isinstance(
-        metadata_default, sa_schema.Computed
-    ):
+    if sqla_compat._server_default_is_computed(metadata_default):
         # return False in case of a computed column as the server
         # default. Note that DDL for adding or removing "GENERATED AS" from
         # an existing column is not currently known for any backend.
@@ -1007,33 +1031,53 @@ def _compare_server_default(
                 conn_col,
                 metadata_col,
             )
-    rendered_metadata_default = _render_server_default_for_compare(
-        metadata_default, metadata_col, autogen_context
-    )
-
-    if sqla_compat.has_computed_reflection and isinstance(
-        conn_col.server_default, sa_schema.Computed
-    ):
+    if sqla_compat._server_default_is_computed(conn_col_default):
         _warn_computed_not_supported(tname, cname)
         return False
+
+    if sqla_compat._server_default_is_identity(
+        metadata_default, conn_col_default
+    ):
+        alter_column_op.existing_server_default = conn_col_default
+        is_diff = _compare_identity_default(
+            autogen_context,
+            alter_column_op,
+            schema,
+            tname,
+            cname,
+            conn_col,
+            metadata_col,
+        )
+        if is_diff or (bool(conn_col_default) != bool(metadata_default)):
+            alter_column_op.modify_server_default = metadata_default
+            if is_diff:
+                log.info(
+                    "Detected server default on column '%s.%s': "
+                    "identity options attributes %s",
+                    tname,
+                    cname,
+                    sorted(is_diff),
+                )
     else:
+        rendered_metadata_default = _render_server_default_for_compare(
+            metadata_default, metadata_col, autogen_context
+        )
+
         rendered_conn_default = (
-            conn_col.server_default.arg.text
-            if conn_col.server_default
-            else None
+            conn_col_default.arg.text if conn_col_default else None
         )
 
-    alter_column_op.existing_server_default = conn_col_default
+        alter_column_op.existing_server_default = conn_col_default
 
-    isdiff = autogen_context.migration_context._compare_server_default(
-        conn_col,
-        metadata_col,
-        rendered_metadata_default,
-        rendered_conn_default,
-    )
-    if isdiff:
-        alter_column_op.modify_server_default = metadata_default
-        log.info("Detected server default on column '%s.%s'", tname, cname)
+        is_diff = autogen_context.migration_context._compare_server_default(
+            conn_col,
+            metadata_col,
+            rendered_metadata_default,
+            rendered_conn_default,
+        )
+        if is_diff:
+            alter_column_op.modify_server_default = metadata_default
+            log.info("Detected server default on column '%s.%s'", tname, cname)
 
 
 @comparators.dispatch_for("column")
index a5e32639ca0cc2689dd3f8cc48566989311d204b..23890fb22f7785329a90339f41da4f3e0035daeb 100644 (file)
@@ -1,3 +1,4 @@
+from collections import OrderedDict
 import re
 
 from mako.pygen import PythonPrinter
@@ -602,15 +603,16 @@ def _render_column(column, autogen_context):
     opts = []
 
     if column.server_default:
-        if sqla_compat._server_default_is_computed(column):
-            rendered = _render_computed(column.computed, autogen_context)
-            if rendered:
+
+        rendered = _render_server_default(
+            column.server_default, autogen_context
+        )
+        if rendered:
+            if _should_render_server_default_positionally(
+                column.server_default
+            ):
                 args.append(rendered)
-        else:
-            rendered = _render_server_default(
-                column.server_default, autogen_context
-            )
-            if rendered:
+            else:
                 opts.append(("server_default", rendered))
 
     if (
@@ -648,13 +650,21 @@ def _render_column(column, autogen_context):
     }
 
 
+def _should_render_server_default_positionally(server_default):
+    return sqla_compat._server_default_is_computed(
+        server_default
+    ) or sqla_compat._server_default_is_identity(server_default)
+
+
 def _render_server_default(default, autogen_context, repr_=True):
     rendered = _user_defined_render("server_default", default, autogen_context)
     if rendered is not False:
         return rendered
 
-    if sqla_compat.has_computed and isinstance(default, sa_schema.Computed):
+    if sqla_compat._server_default_is_computed(default):
         return _render_computed(default, autogen_context)
+    elif sqla_compat._server_default_is_identity(default):
+        return _render_identity(default, autogen_context)
     elif isinstance(default, sa_schema.DefaultClause):
         if isinstance(default.arg, compat.string_types):
             default = default.arg
@@ -684,6 +694,28 @@ def _render_computed(computed, autogen_context):
     }
 
 
+def _render_identity(identity, autogen_context):
+    # always=None means something different than always=False
+    kwargs = OrderedDict(always=identity.always)
+    if identity.on_null is not None:
+        kwargs["on_null"] = identity.on_null
+    kwargs.update(_get_identity_options(identity))
+
+    return "%(prefix)sIdentity(%(kwargs)s)" % {
+        "prefix": _sqlalchemy_autogenerate_prefix(autogen_context),
+        "kwargs": (", ".join("%s=%s" % pair for pair in kwargs.items())),
+    }
+
+
+def _get_identity_options(identity_options):
+    kwargs = OrderedDict()
+    for attr in sqla_compat._identity_options_attrs:
+        value = getattr(identity_options, attr, None)
+        if value is not None:
+            kwargs[attr] = value
+    return kwargs
+
+
 def _repr_type(type_, autogen_context):
     rendered = _user_defined_render("type", type_, autogen_context)
     if rendered is not False:
index b8d9dce41cf4c4881e9e64a45c7c1498f5ddb477..da81c72206cbab03fbe4061df636b100a3c454ab 100644 (file)
@@ -8,7 +8,6 @@ from sqlalchemy.schema import Column
 from sqlalchemy.schema import DDLElement
 from sqlalchemy.sql.elements import quoted_name
 
-from ..util import sqla_compat
 from ..util.sqla_compat import _columns_for_constraint  # noqa
 from ..util.sqla_compat import _find_columns  # noqa
 from ..util.sqla_compat import _fk_spec  # noqa
@@ -83,6 +82,19 @@ class ColumnDefault(AlterColumn):
         self.default = default
 
 
+class ComputedColumnDefault(AlterColumn):
+    def __init__(self, name, column_name, default, **kw):
+        super(ComputedColumnDefault, self).__init__(name, column_name, **kw)
+        self.default = default
+
+
+class IdentityColumnDefault(AlterColumn):
+    def __init__(self, name, column_name, default, impl, **kw):
+        super(IdentityColumnDefault, self).__init__(name, column_name, **kw)
+        self.default = default
+        self.impl = impl
+
+
 class AddColumn(AlterTable):
     def __init__(self, name, column, schema=None):
         super(AddColumn, self).__init__(name, schema=schema)
@@ -154,15 +166,6 @@ def visit_column_name(element, compiler, **kw):
 
 @compiles(ColumnDefault)
 def visit_column_default(element, compiler, **kw):
-    if sqla_compat.has_computed and (
-        isinstance(element.default, sqla_compat.Computed)
-        or isinstance(element.existing_server_default, sqla_compat.Computed)
-    ):
-        raise exc.CompileError(
-            'Adding or removing a "computed" construct, e.g. GENERATED '
-            "ALWAYS AS, to or from an existing column is not supported."
-        )
-
     return "%s %s %s" % (
         alter_table(compiler, element.table_name, element.schema),
         alter_column(compiler, element.column_name),
@@ -172,6 +175,23 @@ def visit_column_default(element, compiler, **kw):
     )
 
 
+@compiles(ComputedColumnDefault)
+def visit_computed_column(element, compiler, **kw):
+    raise exc.CompileError(
+        'Adding or removing a "computed" construct, e.g. GENERATED '
+        "ALWAYS AS, to or from an existing column is not supported."
+    )
+
+
+@compiles(IdentityColumnDefault)
+def visit_identity_column(element, compiler, **kw):
+    raise exc.CompileError(
+        'Adding, removing or modifying an "identity" construct, '
+        "e.g. GENERATED AS IDENTITY, to or from an existing "
+        "column is not supported in this dialect."
+    )
+
+
 def quote_dotted(name, quote):
     """quote the elements of a dotted name"""
 
index d0e64585520b0315ff68f26f7c24eaeaa4548ab4..08d980b7f00eecc362a218ec1e941dcc7ff0438b 100644 (file)
@@ -46,6 +46,8 @@ class DefaultImpl(with_metaclass(ImplMeta)):
     command_terminator = ";"
     type_synonyms = ({"NUMERIC", "DECIMAL"},)
     type_arg_extract = ()
+    # on_null is known to be supported only by oracle
+    identity_attrs_ignore = ("on_null",)
 
     def __init__(
         self,
@@ -180,8 +182,20 @@ class DefaultImpl(with_metaclass(ImplMeta)):
                 )
             )
         if server_default is not False:
+            kw = {}
+            if sqla_compat._server_default_is_computed(
+                server_default, existing_server_default
+            ):
+                cls_ = base.ComputedColumnDefault
+            elif sqla_compat._server_default_is_identity(
+                server_default, existing_server_default
+            ):
+                cls_ = base.IdentityColumnDefault
+                kw["impl"] = self
+            else:
+                cls_ = base.ColumnDefault
             self._exec(
-                base.ColumnDefault(
+                cls_(
                     table_name,
                     column_name,
                     server_default,
@@ -190,6 +204,7 @@ class DefaultImpl(with_metaclass(ImplMeta)):
                     existing_server_default=existing_server_default,
                     existing_nullable=existing_nullable,
                     existing_comment=existing_comment,
+                    **kw
                 )
             )
         if type_ is not None:
@@ -513,3 +528,44 @@ class DefaultImpl(with_metaclass(ImplMeta)):
 
     def render_type(self, type_obj, autogen_context):
         return False
+
+    def _compare_identity_default(self, metadata_identity, inspector_identity):
+
+        # ignored contains the attributes that were not considered
+        # because assumed to their default values in the db.
+        diff, ignored = _compare_identity_options(
+            sqla_compat._identity_attrs,
+            metadata_identity,
+            inspector_identity,
+            sqla_compat.Identity(),
+        )
+
+        meta_always = getattr(metadata_identity, "always", None)
+        inspector_always = getattr(inspector_identity, "always", None)
+        # None and False are the same in this comparison
+        if bool(meta_always) != bool(inspector_always):
+            diff.add("always")
+
+        diff.difference_update(self.identity_attrs_ignore)
+
+        return diff, ignored
+
+
+def _compare_identity_options(
+    attributes, metadata_io, inspector_io, default_io
+):
+    # this can be used for identity or sequence compare.
+    # default_io is an instance of IdentityOption with all attributes to the
+    # default value.
+    diff = set()
+    ignored_attr = set()
+    for attr in attributes:
+        meta_value = getattr(metadata_io, attr, None)
+        default_value = getattr(default_io, attr, None)
+        conn_value = getattr(inspector_io, attr, None)
+        if conn_value != meta_value:
+            if meta_value == default_value:
+                ignored_attr.add(attr)
+            else:
+                diff.add(attr)
+    return diff, ignored_attr
index e20e069ddd8fb3f6281bc32ac05f45f70a4d5783..916011aa8faa0f297dbe0086693f9f5937cbdef6 100644 (file)
@@ -19,6 +19,7 @@ from .base import format_type
 from .base import RenameTable
 from .impl import DefaultImpl
 from .. import util
+from ..util import sqla_compat
 
 
 class MSSQLImpl(DefaultImpl):
@@ -27,6 +28,17 @@ class MSSQLImpl(DefaultImpl):
     batch_separator = "GO"
 
     type_synonyms = DefaultImpl.type_synonyms + ({"VARCHAR", "NVARCHAR"},)
+    identity_attrs_ignore = (
+        "minvalue",
+        "maxvalue",
+        "nominvalue",
+        "nomaxvalue",
+        "cycle",
+        "cache",
+        "order",
+        "on_null",
+        "order",
+    )
 
     def __init__(self, *arg, **kw):
         super(MSSQLImpl, self).__init__(*arg, **kw)
@@ -76,6 +88,16 @@ class MSSQLImpl(DefaultImpl):
                     "existing_type or a new type_ be passed."
                 )
 
+        used_default = False
+        if sqla_compat._server_default_is_identity(
+            server_default, existing_server_default
+        ) or sqla_compat._server_default_is_computed(
+            server_default, existing_server_default
+        ):
+            used_default = True
+            kw["server_default"] = server_default
+            kw["existing_server_default"] = existing_server_default
+
         super(MSSQLImpl, self).alter_column(
             table_name,
             column_name,
@@ -87,7 +109,7 @@ class MSSQLImpl(DefaultImpl):
             **kw
         )
 
-        if server_default is not False:
+        if server_default is not False and used_default is False:
             if existing_server_default is not False or server_default is None:
                 self._exec(
                     _ExecDropConstraint(
index 446eeb146ee290fe4b9188baa1cbfa47396f7b56..4761f75edd9e4b5c592fdc0378765d311efb0e7b 100644 (file)
@@ -15,6 +15,7 @@ from .base import format_server_default
 from .impl import DefaultImpl
 from .. import util
 from ..autogenerate import compare
+from ..util import sqla_compat
 from ..util.sqla_compat import _is_mariadb
 from ..util.sqla_compat import _is_type_bound
 
@@ -44,6 +45,25 @@ class MySQLImpl(DefaultImpl):
         existing_comment=None,
         **kw
     ):
+        if sqla_compat._server_default_is_identity(
+            server_default, existing_server_default
+        ) or sqla_compat._server_default_is_computed(
+            server_default, existing_server_default
+        ):
+            # modifying computed or identity columns is not supported
+            # the default will raise
+            super(MySQLImpl, self).alter_column(
+                table_name,
+                column_name,
+                nullable=nullable,
+                type_=type_,
+                schema=schema,
+                existing_type=existing_type,
+                existing_nullable=existing_nullable,
+                server_default=server_default,
+                existing_server_default=existing_server_default,
+                **kw
+            )
         if name is not None or self._is_mysql_allowed_functional_default(
             type_ if type_ is not None else existing_type, server_default
         ):
index aaf759adcb2d8298fd56e0ad515c013352e88ef3..90f93d27cfbbc4de5a1a530b974ed5b06397542e 100644 (file)
@@ -12,6 +12,7 @@ from .base import format_column_name
 from .base import format_server_default
 from .base import format_table_name
 from .base import format_type
+from .base import IdentityColumnDefault
 from .base import RenameTable
 from .impl import DefaultImpl
 
@@ -25,6 +26,7 @@ class OracleImpl(DefaultImpl):
         {"VARCHAR", "VARCHAR2"},
         {"BIGINT", "INTEGER", "SMALLINT", "DECIMAL", "NUMERIC", "NUMBER"},
     )
+    identity_attrs_ignore = ()
 
     def __init__(self, *arg, **kw):
         super(OracleImpl, self).__init__(*arg, **kw)
@@ -121,3 +123,18 @@ def alter_column(compiler, name):
 
 def add_column(compiler, column, **kw):
     return "ADD %s" % compiler.get_column_specification(column, **kw)
+
+
+@compiles(IdentityColumnDefault, "oracle")
+def visit_identity_column(element, compiler, **kw):
+    text = "%s %s " % (
+        alter_table(compiler, element.table_name, element.schema),
+        alter_column(compiler, element.column_name),
+    )
+    if element.default is None:
+        # drop identity
+        text += "DROP IDENTITY"
+        return text
+    else:
+        text += compiler.visit_identity_column(element.default)
+        return text
index 4ddc0ed9701f3b8dccc45826b0473f443b0ae53e..5fd6c91981d2a70efe6166f35a2201a251e3c93e 100644 (file)
@@ -20,6 +20,7 @@ from .base import compiles
 from .base import format_column_name
 from .base import format_table_name
 from .base import format_type
+from .base import IdentityColumnDefault
 from .base import RenameTable
 from .impl import DefaultImpl
 from .. import util
@@ -41,6 +42,7 @@ class PostgresqlImpl(DefaultImpl):
     type_synonyms = DefaultImpl.type_synonyms + (
         {"FLOAT", "DOUBLE PRECISION"},
     )
+    identity_attrs_ignore = ("on_null", "order")
 
     def prep_table_for_batch(self, table):
         for constraint in table.constraints:
@@ -296,6 +298,39 @@ def visit_column_comment(element, compiler, **kw):
     )
 
 
+@compiles(IdentityColumnDefault, "postgresql")
+def visit_identity_column(element, compiler, **kw):
+    text = "%s %s " % (
+        alter_table(compiler, element.table_name, element.schema),
+        alter_column(compiler, element.column_name),
+    )
+    if element.default is None:
+        # drop identity
+        text += "DROP IDENTITY"
+        return text
+    elif element.existing_server_default is None:
+        # add identity options
+        text += "ADD "
+        text += compiler.visit_identity_column(element.default)
+        return text
+    else:
+        # alter identity
+        diff, _ = element.impl._compare_identity_default(
+            element.default, element.existing_server_default
+        )
+        identity = element.default
+        for attr in sorted(diff):
+            if attr == "always":
+                text += "SET GENERATED %s " % (
+                    "ALWAYS" if identity.always else "BY DEFAULT"
+                )
+            else:
+                text += "SET %s " % compiler.get_identity_options(
+                    sqla_compat.Identity(**{attr: getattr(identity, attr)})
+                )
+        return text
+
+
 @Operations.register_operation("create_exclude_constraint")
 @BatchOperations.register_operation(
     "create_exclude_constraint", "batch_create_exclude_constraint"
index a9a1059d1a94e362c348b97d3b3013d9271c5ba3..456002d825c01332577e1ea1720bde95c6c47eb3 100644 (file)
@@ -106,3 +106,21 @@ class SuiteRequirements(Requirements):
         return exclusions.only_if(
             exclusions.BooleanPredicate(sqla_compat.has_computed)
         )
+
+    @property
+    def identity_columns(self):
+        return exclusions.closed()
+
+    @property
+    def identity_columns_alter(self):
+        return exclusions.closed()
+
+    @property
+    def identity_columns_api(self):
+        return exclusions.only_if(
+            exclusions.BooleanPredicate(sqla_compat.has_identity)
+        )
+
+    @property
+    def supports_identity_on_null(self):
+        return exclusions.closed()
index 5ac8d5d660409d80d3e1f7f9b16f7287a45a8325..159d0f09b10b56aeff9309441fb99284626187a2 100644 (file)
@@ -30,15 +30,36 @@ _vers = tuple(
 )
 sqla_13 = _vers >= (1, 3)
 sqla_14 = _vers >= (1, 4)
+
 try:
     from sqlalchemy import Computed  # noqa
-
-    has_computed = True
-
-    has_computed_reflection = _vers >= (1, 3, 16)
 except ImportError:
     has_computed = False
     has_computed_reflection = False
+else:
+    has_computed = True
+    has_computed_reflection = _vers >= (1, 3, 16)
+
+try:
+    from sqlalchemy import Identity  # noqa
+except ImportError:
+    has_identity = False
+else:
+    # attributes common to Indentity and Sequence
+    _identity_options_attrs = (
+        "start",
+        "increment",
+        "minvalue",
+        "maxvalue",
+        "nominvalue",
+        "nomaxvalue",
+        "cycle",
+        "cache",
+        "order",
+    )
+    # attributes of Indentity
+    _identity_attrs = _identity_options_attrs + ("on_null",)
+    has_identity = True
 
 AUTOINCREMENT_DEFAULT = "auto"
 
@@ -67,11 +88,18 @@ def _exec_on_inspector(inspector, statement, **params):
         return inspector.bind.execute(statement, params)
 
 
-def _server_default_is_computed(column):
+def _server_default_is_computed(*server_default):
     if not has_computed:
         return False
     else:
-        return isinstance(column.computed, Computed)
+        return any(isinstance(sd, Computed) for sd in server_default)
+
+
+def _server_default_is_identity(*server_default):
+    if not sqla_14:
+        return False
+    else:
+        return any(isinstance(sd, Identity) for sd in server_default)
 
 
 def _table_for_constraint(constraint):
@@ -90,6 +118,13 @@ def _columns_for_constraint(constraint):
         return list(constraint.columns)
 
 
+def _reflect_table(inspector, table, include_cols):
+    if sqla_14:
+        return inspector.reflect_table(table, None)
+    else:
+        return inspector.reflecttable(table, None)
+
+
 def _fk_spec(constraint):
     source_columns = [
         constraint.columns[key].name for key in constraint.column_keys
diff --git a/docs/build/unreleased/730.rst b/docs/build/unreleased/730.rst
new file mode 100644 (file)
index 0000000..e9b967f
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: usecase, operations
+    :tickets: 730
+
+    Added support for rendering of "identity" elements on
+    :class:`.Column` objects, supported in SQLAlchemy via
+    the :class:`.Identity` element introduced in version 1.4.
+
+    Adding columns with identity is supported on PostgreSQL,
+    MSSQL and Oracle. Changing the identity options or removing
+    it is supported only on PostgreSQL and Oracle.
index eb6066db6c70a5188b3a92bfd28303b07242f2ec..6f171f61cea733201d77cf9e6f9fa8244ad49a93 100644 (file)
@@ -251,3 +251,23 @@ class DefaultRequirements(SuiteRequirements):
                 return norm_version_info >= (8, 0, 16)
         else:
             return False
+
+    @property
+    def identity_columns(self):
+        # TODO: in theory if these could come from SQLAlchemy dialects
+        # that would be helpful
+        return self.identity_columns_api + exclusions.only_on(
+            ["postgresql >= 10", "oracle >= 12", "mssql"]
+        )
+
+    @property
+    def identity_columns_alter(self):
+        # TODO: in theory if these could come from SQLAlchemy dialects
+        # that would be helpful
+        return self.identity_columns_api + exclusions.only_on(
+            ["postgresql >= 10", "oracle >= 12"]
+        )
+
+    @property
+    def supports_identity_on_null(self):
+        return self.identity_columns + exclusions.only_on(["oracle"])
diff --git a/tests/test_autogen_identity.py b/tests/test_autogen_identity.py
new file mode 100644 (file)
index 0000000..9d8253a
--- /dev/null
@@ -0,0 +1,241 @@
+import sqlalchemy as sa
+from sqlalchemy import Column
+from sqlalchemy import Integer
+from sqlalchemy import MetaData
+from sqlalchemy import Table
+
+from alembic import testing
+from alembic.testing import config
+from alembic.testing import eq_
+from alembic.testing import is_true
+from alembic.testing import TestBase
+from ._autogen_fixtures import AutogenFixtureTest
+
+
+class AutogenerateIdentityTest(AutogenFixtureTest, TestBase):
+    __requires__ = ("identity_columns",)
+    __backend__ = True
+
+    def test_add_identity_column(self):
+        m1 = MetaData()
+        m2 = MetaData()
+
+        Table("user", m1, Column("other", sa.Text))
+
+        Table(
+            "user",
+            m2,
+            Column("other", sa.Text),
+            Column(
+                "id",
+                Integer,
+                sa.Identity(start=5, increment=7),
+                primary_key=True,
+            ),
+        )
+
+        diffs = self._fixture(m1, m2)
+
+        eq_(diffs[0][0], "add_column")
+        eq_(diffs[0][2], "user")
+        eq_(diffs[0][3].name, "id")
+        i = diffs[0][3].identity
+
+        is_true(isinstance(i, sa.Identity))
+        eq_(i.start, 5)
+        eq_(i.increment, 7)
+
+    def test_remove_identity_column(self):
+        m1 = MetaData()
+        m2 = MetaData()
+
+        Table(
+            "user",
+            m1,
+            Column(
+                "id",
+                Integer,
+                sa.Identity(start=2, increment=3),
+                primary_key=True,
+            ),
+        )
+
+        Table("user", m2)
+
+        diffs = self._fixture(m1, m2)
+
+        eq_(diffs[0][0], "remove_column")
+        eq_(diffs[0][2], "user")
+        c = diffs[0][3]
+        eq_(c.name, "id")
+
+        is_true(isinstance(c.identity, sa.Identity))
+        eq_(c.identity.start, 2)
+        eq_(c.identity.increment, 3)
+
+    def test_no_change_identity_column(self):
+        m1 = MetaData()
+        m2 = MetaData()
+
+        for m in (m1, m2):
+            Table(
+                "user",
+                m,
+                Column("id", Integer, sa.Identity(start=2)),
+            )
+
+        diffs = self._fixture(m1, m2)
+
+        eq_(diffs, [])
+
+    @testing.combinations(
+        (None, dict(start=2)),
+        (dict(start=2), None),
+        (dict(start=2), dict(start=2, increment=7)),
+        (dict(always=False), dict(always=True)),
+        (
+            dict(start=1, minvalue=0, maxvalue=100, cycle=True),
+            dict(start=1, minvalue=0, maxvalue=100, cycle=False),
+        ),
+        (
+            dict(start=10, increment=3, maxvalue=9999),
+            dict(start=10, increment=1, maxvalue=3333),
+        ),
+    )
+    @config.requirements.identity_columns_alter
+    def test_change_identity(self, before, after):
+        arg_before = (sa.Identity(**before),) if before else ()
+        arg_after = (sa.Identity(**after),) if after else ()
+
+        m1 = MetaData()
+        m2 = MetaData()
+
+        Table(
+            "user",
+            m1,
+            Column("id", Integer, *arg_before),
+            Column("other", sa.Text),
+        )
+
+        Table(
+            "user",
+            m2,
+            Column("id", Integer, *arg_after),
+            Column("other", sa.Text),
+        )
+
+        diffs = self._fixture(m1, m2)
+
+        eq_(len(diffs[0]), 1)
+        diffs = diffs[0][0]
+        eq_(diffs[0], "modify_default")
+        eq_(diffs[2], "user")
+        eq_(diffs[3], "id")
+        old = diffs[5]
+        new = diffs[6]
+
+        def check(kw, idt):
+            if kw:
+                is_true(isinstance(idt, sa.Identity))
+                for k, v in kw.items():
+                    eq_(getattr(idt, k), v)
+            else:
+                is_true(idt in (None, False))
+
+        check(before, old)
+        check(after, new)
+
+    def test_add_identity_to_column(self):
+        m1 = MetaData()
+        m2 = MetaData()
+
+        Table(
+            "user",
+            m1,
+            Column("id", Integer),
+            Column("other", sa.Text),
+        )
+
+        Table(
+            "user",
+            m2,
+            Column("id", Integer, sa.Identity(start=2, maxvalue=1000)),
+            Column("other", sa.Text),
+        )
+
+        diffs = self._fixture(m1, m2)
+
+        eq_(len(diffs[0]), 1)
+        diffs = diffs[0][0]
+        eq_(diffs[0], "modify_default")
+        eq_(diffs[2], "user")
+        eq_(diffs[3], "id")
+        eq_(diffs[5], None)
+        added = diffs[6]
+
+        is_true(isinstance(added, sa.Identity))
+        eq_(added.start, 2)
+        eq_(added.maxvalue, 1000)
+
+    def test_remove_identity_from_column(self):
+        m1 = MetaData()
+        m2 = MetaData()
+
+        Table(
+            "user",
+            m1,
+            Column("id", Integer, sa.Identity(start=2, maxvalue=1000)),
+            Column("other", sa.Text),
+        )
+
+        Table(
+            "user",
+            m2,
+            Column("id", Integer),
+            Column("other", sa.Text),
+        )
+
+        diffs = self._fixture(m1, m2)
+
+        eq_(len(diffs[0]), 1)
+        diffs = diffs[0][0]
+        eq_(diffs[0], "modify_default")
+        eq_(diffs[2], "user")
+        eq_(diffs[3], "id")
+        eq_(diffs[6], None)
+        removed = diffs[5]
+
+        is_true(isinstance(removed, sa.Identity))
+
+    def test_identity_on_null(self):
+        m1 = MetaData()
+        m2 = MetaData()
+
+        Table(
+            "user",
+            m1,
+            Column("id", Integer, sa.Identity(start=2, on_null=True)),
+            Column("other", sa.Text),
+        )
+
+        Table(
+            "user",
+            m2,
+            Column("id", Integer, sa.Identity(start=2, on_null=False)),
+            Column("other", sa.Text),
+        )
+
+        diffs = self._fixture(m1, m2)
+        if not config.requirements.supports_identity_on_null.enabled:
+            eq_(diffs, [])
+        else:
+            eq_(len(diffs[0]), 1)
+            diffs = diffs[0][0]
+            eq_(diffs[0], "modify_default")
+            eq_(diffs[2], "user")
+            eq_(diffs[3], "id")
+            old = diffs[5]
+            new = diffs[6]
+
+            is_true(isinstance(old, sa.Identity))
+            is_true(isinstance(new, sa.Identity))
index e1f32697ad9da5e4aa849d328b7e0af7a00be08f..d2dcc34995e0c19d8a153675f2aff1c449aebcfd 100644 (file)
@@ -2123,6 +2123,93 @@ class AutogenRenderTest(TestBase):
             % persisted,
         )
 
+    @config.requirements.identity_columns
+    @testing.combinations(
+        ({}, "sa.Identity(always=False)"),
+        (dict(always=None), "sa.Identity(always=None)"),
+        (dict(always=True), "sa.Identity(always=True)"),
+        (
+            dict(
+                always=False,
+                on_null=True,
+                start=2,
+                increment=4,
+                minvalue=-3,
+                maxvalue=99,
+                nominvalue=True,
+                nomaxvalue=True,
+                cycle=True,
+                cache=42,
+                order=True,
+            ),
+            "sa.Identity(always=False, on_null=True, start=2, increment=4, "
+            "minvalue=-3, maxvalue=99, nominvalue=True, nomaxvalue=True, "
+            "cycle=True, cache=42, order=True)",
+        ),
+    )
+    def test_render_add_column_identity(self, kw, text):
+        op_obj = ops.AddColumnOp(
+            "foo", Column("x", Integer, sa.Identity(**kw))
+        )
+        eq_ignore_whitespace(
+            autogenerate.render_op_text(self.autogen_context, op_obj),
+            "op.add_column('foo', sa.Column('x', sa.Integer(), "
+            "%s, nullable=True))" % text,
+        )
+
+    @config.requirements.identity_columns
+    @testing.combinations(
+        ({}, "sa.Identity(always=False)"),
+        (dict(always=None), "sa.Identity(always=None)"),
+        (dict(always=True), "sa.Identity(always=True)"),
+        (
+            dict(
+                always=False,
+                on_null=True,
+                start=2,
+                increment=4,
+                minvalue=-3,
+                maxvalue=99,
+                nominvalue=True,
+                nomaxvalue=True,
+                cycle=True,
+                cache=42,
+                order=True,
+            ),
+            "sa.Identity(always=False, on_null=True, start=2, increment=4, "
+            "minvalue=-3, maxvalue=99, nominvalue=True, nomaxvalue=True, "
+            "cycle=True, cache=42, order=True)",
+        ),
+    )
+    def test_render_alter_column_add_identity(self, kw, text):
+        op_obj = ops.AlterColumnOp(
+            "foo",
+            "x",
+            existing_type=Integer(),
+            existing_server_default=None,
+            modify_server_default=sa.Identity(**kw),
+        )
+        eq_ignore_whitespace(
+            autogenerate.render_op_text(self.autogen_context, op_obj),
+            "op.alter_column('foo', 'x', existing_type=sa.Integer(), "
+            "server_default=%s)" % text,
+        )
+
+    @config.requirements.identity_columns
+    def test_render_alter_column_drop_identity(self):
+        op_obj = ops.AlterColumnOp(
+            "foo",
+            "x",
+            existing_type=Integer(),
+            existing_server_default=sa.Identity(),
+            modify_server_default=None,
+        )
+        eq_ignore_whitespace(
+            autogenerate.render_op_text(self.autogen_context, op_obj),
+            "op.alter_column('foo', 'x', existing_type=sa.Integer(), "
+            "server_default=None)",
+        )
+
 
 class RenderNamingConventionTest(TestBase):
     def setUp(self):
index 36b90cb1c8710c1e13ebee73d98d97697f3bba87..0c8b3ae63a88d3494c8585286a83482a046660bb 100644 (file)
@@ -1,12 +1,14 @@
 """Test op functions against MSSQL."""
 
 from sqlalchemy import Column
+from sqlalchemy import exc
 from sqlalchemy import Integer
 
 from alembic import command
 from alembic import op
 from alembic import util
 from alembic.testing import assert_raises_message
+from alembic.testing import combinations
 from alembic.testing import config
 from alembic.testing import eq_
 from alembic.testing.env import _no_sql_testing_config
@@ -354,3 +356,74 @@ class OpTest(TestBase):
         context.assert_contains(
             "CREATE INDEX ix_mytable_a_b ON mytable " "(col_a, col_b)"
         )
+
+    @combinations(
+        (lambda: sqla_compat.Computed("foo * 5"), lambda: None),
+        (lambda: None, lambda: sqla_compat.Computed("foo * 5")),
+        (
+            lambda: sqla_compat.Computed("foo * 42"),
+            lambda: sqla_compat.Computed("foo * 5"),
+        ),
+    )
+    @config.requirements.computed_columns
+    def test_alter_column_computed_not_supported(self, sd, esd):
+        op_fixture("mssql")
+        assert_raises_message(
+            exc.CompileError,
+            'Adding or removing a "computed" construct, e.g. '
+            "GENERATED ALWAYS AS, to or from an existing column is not "
+            "supported.",
+            op.alter_column,
+            "t1",
+            "c1",
+            server_default=sd(),
+            existing_server_default=esd(),
+        )
+
+    @config.requirements.identity_columns
+    @combinations(
+        ({},),
+        (dict(always=True),),
+        (dict(start=3),),
+        (dict(start=3, increment=3),),
+    )
+    def test_add_column_identity(self, kw):
+        context = op_fixture("mssql")
+        op.add_column(
+            "t1",
+            Column("some_column", Integer, sqla_compat.Identity(**kw)),
+        )
+        if "start" in kw or "increment" in kw:
+            options = "(%s,%s)" % (
+                kw.get("start", 1),
+                kw.get("increment", 1),
+            )
+        else:
+            options = ""
+        context.assert_(
+            "ALTER TABLE t1 ADD some_column INTEGER NOT NULL IDENTITY%s"
+            % options
+        )
+
+    @combinations(
+        (lambda: sqla_compat.Identity(), lambda: None),
+        (lambda: None, lambda: sqla_compat.Identity()),
+        (
+            lambda: sqla_compat.Identity(),
+            lambda: sqla_compat.Identity(),
+        ),
+    )
+    @config.requirements.identity_columns
+    def test_alter_column_identity_add_not_supported(self, sd, esd):
+        op_fixture("mssql")
+        assert_raises_message(
+            exc.CompileError,
+            'Adding, removing or modifying an "identity" construct, '
+            "e.g. GENERATED AS IDENTITY, to or from an existing "
+            "column is not supported in this dialect.",
+            op.alter_column,
+            "t1",
+            "c1",
+            server_default=sd(),
+            existing_server_default=esd(),
+        )
index 6fcc35a92805d5a48e616e291d0abdc274b0ec9f..caef197f611677b481123b460d39d8fe3a767cab 100644 (file)
@@ -1,6 +1,7 @@
 from sqlalchemy import Boolean
 from sqlalchemy import Column
 from sqlalchemy import DATETIME
+from sqlalchemy import exc
 from sqlalchemy import Float
 from sqlalchemy import func
 from sqlalchemy import inspect
@@ -17,6 +18,7 @@ from alembic.autogenerate import compare
 from alembic.migration import MigrationContext
 from alembic.operations import ops
 from alembic.testing import assert_raises_message
+from alembic.testing import combinations
 from alembic.testing import config
 from alembic.testing.env import clear_staging_env
 from alembic.testing.env import staging_env
@@ -460,6 +462,52 @@ class MySQLOpTest(TestBase):
             "t1",
         )
 
+    @combinations(
+        (lambda: sqla_compat.Computed("foo * 5"), lambda: None),
+        (lambda: None, lambda: sqla_compat.Computed("foo * 5")),
+        (
+            lambda: sqla_compat.Computed("foo * 42"),
+            lambda: sqla_compat.Computed("foo * 5"),
+        ),
+    )
+    @config.requirements.computed_columns_api
+    def test_alter_column_computed_not_supported(self, sd, esd):
+        op_fixture("mssql")
+        assert_raises_message(
+            exc.CompileError,
+            'Adding or removing a "computed" construct, e.g. '
+            "GENERATED ALWAYS AS, to or from an existing column is not "
+            "supported.",
+            op.alter_column,
+            "t1",
+            "c1",
+            server_default=sd(),
+            existing_server_default=esd(),
+        )
+
+    @combinations(
+        (lambda: sqla_compat.Identity(), lambda: None),
+        (lambda: None, lambda: sqla_compat.Identity()),
+        (
+            lambda: sqla_compat.Identity(),
+            lambda: sqla_compat.Identity(),
+        ),
+    )
+    @config.requirements.identity_columns_api
+    def test_alter_column_identity_not_supported(self, sd, esd):
+        op_fixture()
+        assert_raises_message(
+            exc.CompileError,
+            'Adding, removing or modifying an "identity" construct, '
+            "e.g. GENERATED AS IDENTITY, to or from an existing "
+            "column is not supported in this dialect.",
+            op.alter_column,
+            "t1",
+            "c1",
+            server_default=sd(),
+            existing_server_default=esd(),
+        )
+
 
 class MySQLBackendOpTest(AlterColRoundTripFixture, TestBase):
     __only_on__ = "mysql", "mariadb"
@@ -578,7 +626,7 @@ class MySQLDefaultCompareTest(TestBase):
         insp = inspect(self.bind)
         cols = insp.get_columns(t1.name)
         refl = Table(t1.name, MetaData())
-        insp.reflecttable(refl, None)
+        sqla_compat._reflect_table(insp, refl, None)
         ctx = self.autogen_context["context"]
         return ctx.impl.compare_server_default(
             refl.c[cols[0]["name"]], col, rendered, cols[0]["default"]
index 7b8dc02766e90992aa65fc04ebe991581fb99a59..58a6a860a728b645face22e2dbe907bc676fe9b9 100644 (file)
@@ -19,6 +19,7 @@ from alembic import op
 from alembic.operations import ops
 from alembic.operations import schemaobj
 from alembic.testing import assert_raises_message
+from alembic.testing import combinations
 from alembic.testing import config
 from alembic.testing import eq_
 from alembic.testing import is_
@@ -374,8 +375,16 @@ class OpTest(TestBase):
         op.alter_column("t", "c", server_default=None, schema="foo")
         context.assert_("ALTER TABLE foo.t ALTER COLUMN c DROP DEFAULT")
 
+    @combinations(
+        (lambda: sqla_compat.Computed("foo * 5"), lambda: None),
+        (lambda: None, lambda: sqla_compat.Computed("foo * 5")),
+        (
+            lambda: sqla_compat.Computed("foo * 42"),
+            lambda: sqla_compat.Computed("foo * 5"),
+        ),
+    )
     @config.requirements.computed_columns_api
-    def test_alter_column_computed_add_not_supported(self):
+    def test_alter_column_computed_not_supported(self, sd, esd):
         op_fixture()
         assert_raises_message(
             exc.CompileError,
@@ -385,22 +394,31 @@ class OpTest(TestBase):
             op.alter_column,
             "t1",
             "c1",
-            server_default=sqla_compat.Computed("foo * 5"),
-        )
-
-    @config.requirements.computed_columns_api
-    def test_alter_column_computed_remove_not_supported(self):
+            server_default=sd(),
+            existing_server_default=esd(),
+        )
+
+    @combinations(
+        (lambda: sqla_compat.Identity(), lambda: None),
+        (lambda: None, lambda: sqla_compat.Identity()),
+        (
+            lambda: sqla_compat.Identity(),
+            lambda: sqla_compat.Identity(),
+        ),
+    )
+    @config.requirements.identity_columns_api
+    def test_alter_column_identity_not_supported(self, sd, esd):
         op_fixture()
         assert_raises_message(
             exc.CompileError,
-            'Adding or removing a "computed" construct, e.g. '
-            "GENERATED ALWAYS AS, to or from an existing column is not "
-            "supported.",
+            'Adding, removing or modifying an "identity" construct, '
+            "e.g. GENERATED AS IDENTITY, to or from an existing "
+            "column is not supported in this dialect.",
             op.alter_column,
             "t1",
             "c1",
-            server_default=None,
-            existing_server_default=sqla_compat.Computed("foo * 5"),
+            server_default=sd(),
+            existing_server_default=esd(),
         )
 
     def test_alter_column_schema_type_unnamed(self):
index f7061903a264680f9b39d3a390c04ca8597fdfc7..f84c0e12f83f0a4b055f9410cbdfbda900b2e800 100644 (file)
@@ -1,8 +1,11 @@
 from sqlalchemy import Column
+from sqlalchemy import exc
 from sqlalchemy import Integer
 
 from alembic import command
 from alembic import op
+from alembic.testing import assert_raises_message
+from alembic.testing import combinations
 from alembic.testing import config
 from alembic.testing.env import _no_sql_testing_config
 from alembic.testing.env import clear_staging_env
@@ -68,7 +71,7 @@ class OpTest(TestBase):
             "COMMENT ON COLUMN t1.c1 IS 'c1 comment'",
         )
 
-    @config.requirements.computed_columns_api
+    @config.requirements.computed_columns
     def test_add_column_computed(self):
         context = op_fixture("oracle")
         op.add_column(
@@ -80,6 +83,29 @@ class OpTest(TestBase):
             "INTEGER GENERATED ALWAYS AS (foo * 5)"
         )
 
+    @combinations(
+        (lambda: sqla_compat.Computed("foo * 5"), lambda: None),
+        (lambda: None, lambda: sqla_compat.Computed("foo * 5")),
+        (
+            lambda: sqla_compat.Computed("foo * 42"),
+            lambda: sqla_compat.Computed("foo * 5"),
+        ),
+    )
+    @config.requirements.computed_columns
+    def test_alter_column_computed_not_supported(self, sd, esd):
+        op_fixture("oracle")
+        assert_raises_message(
+            exc.CompileError,
+            'Adding or removing a "computed" construct, e.g. '
+            "GENERATED ALWAYS AS, to or from an existing column is not "
+            "supported.",
+            op.alter_column,
+            "t1",
+            "c1",
+            server_default=sd(),
+            existing_server_default=esd(),
+        )
+
     def test_alter_table_rename_oracle(self):
         context = op_fixture("oracle")
         op.rename_table("s", "t")
@@ -226,3 +252,125 @@ class OpTest(TestBase):
     #    context.assert_(
     #        'ALTER TABLE y.t RENAME COLUMN c TO c2'
     #    )
+
+    def _identity_qualification(self, kw):
+        always = kw.get("always", False)
+        if always is None:
+            return ""
+        qualification = "ALWAYS" if always else "BY DEFAULT"
+        if kw.get("on_null", False):
+            qualification += " ON NULL"
+        return qualification
+
+    @config.requirements.identity_columns
+    @combinations(
+        ({}, None),
+        (dict(always=True), None),
+        (dict(always=None, order=True), "ORDER"),
+        (
+            dict(start=3, increment=33, maxvalue=99, cycle=True),
+            "INCREMENT BY 33 START WITH 3 MAXVALUE 99 CYCLE",
+        ),
+        (dict(on_null=True, start=42), "START WITH 42"),
+    )
+    def test_add_column_identity(self, kw, text):
+        context = op_fixture("oracle")
+        op.add_column(
+            "t1",
+            Column("some_column", Integer, sqla_compat.Identity(**kw)),
+        )
+        qualification = self._identity_qualification(kw)
+        options = " (%s)" % text if text else ""
+        context.assert_(
+            "ALTER TABLE t1 ADD some_column "
+            "INTEGER GENERATED %s AS IDENTITY%s" % (qualification, options)
+        )
+
+    @config.requirements.identity_columns
+    @combinations(
+        ({}, None),
+        (dict(always=True), None),
+        (dict(always=None, cycle=True), "CYCLE"),
+        (
+            dict(start=3, increment=33, maxvalue=99, cycle=True),
+            "INCREMENT BY 33 START WITH 3 MAXVALUE 99 CYCLE",
+        ),
+        (dict(on_null=True, start=42), "START WITH 42"),
+    )
+    def test_add_identity_to_column(self, kw, text):
+        context = op_fixture("oracle")
+        op.alter_column(
+            "t1",
+            "some_column",
+            server_default=sqla_compat.Identity(**kw),
+            existing_server_default=None,
+        )
+        qualification = self._identity_qualification(kw)
+        options = " (%s)" % text if text else ""
+        context.assert_(
+            "ALTER TABLE t1 MODIFY some_column "
+            "GENERATED %s AS IDENTITY%s" % (qualification, options)
+        )
+
+    @config.requirements.identity_columns
+    def test_remove_identity_from_column(self):
+        context = op_fixture("oracle")
+        op.alter_column(
+            "t1",
+            "some_column",
+            server_default=None,
+            existing_server_default=sqla_compat.Identity(),
+        )
+        context.assert_("ALTER TABLE t1 MODIFY some_column DROP IDENTITY")
+
+    @config.requirements.identity_columns
+    @combinations(
+        ({}, dict(always=True), None),
+        (
+            dict(always=True),
+            dict(always=False, start=3),
+            "START WITH 3",
+        ),
+        (
+            dict(always=True, start=3, increment=2, minvalue=-3, maxvalue=99),
+            dict(
+                always=True,
+                start=3,
+                increment=1,
+                minvalue=-3,
+                maxvalue=99,
+                cycle=True,
+            ),
+            "INCREMENT BY 1 START WITH 3 MINVALUE -3 MAXVALUE 99 CYCLE",
+        ),
+        (
+            dict(
+                always=False,
+                start=3,
+                maxvalue=9999,
+                minvalue=0,
+            ),
+            dict(always=False, start=3, order=True, on_null=False, cache=2),
+            "START WITH 3 CACHE 2 ORDER",
+        ),
+        (
+            dict(always=False),
+            dict(always=None, minvalue=0),
+            "MINVALUE 0",
+        ),
+    )
+    def test_change_identity_in_column(self, existing, updated, text):
+        context = op_fixture("oracle")
+        op.alter_column(
+            "t1",
+            "some_column",
+            server_default=sqla_compat.Identity(**updated),
+            existing_server_default=sqla_compat.Identity(**existing),
+        )
+
+        qualification = self._identity_qualification(updated)
+        options = " (%s)" % text if text else ""
+        context.assert_(
+            "ALTER TABLE t1 MODIFY some_column "
+            "GENERATED %s AS IDENTITY%s" % (qualification, options)
+        )
index f928868fd6d1fc5897c520e34ea527e535a5c3be..08f70d84310ba9b853a3d4c2391720b8c1af2d9e 100644 (file)
@@ -2,6 +2,7 @@ from sqlalchemy import BigInteger
 from sqlalchemy import Boolean
 from sqlalchemy import Column
 from sqlalchemy import DateTime
+from sqlalchemy import exc
 from sqlalchemy import Float
 from sqlalchemy import func
 from sqlalchemy import Index
@@ -37,6 +38,8 @@ from alembic.migration import MigrationContext
 from alembic.operations import Operations
 from alembic.operations import ops
 from alembic.script import ScriptDirectory
+from alembic.testing import assert_raises_message
+from alembic.testing import combinations
 from alembic.testing import config
 from alembic.testing import eq_
 from alembic.testing import eq_ignore_whitespace
@@ -276,7 +279,7 @@ class PostgresqlOpTest(TestBase):
         op.drop_table_comment("t2", existing_comment="t2 table", schema="foo")
         context.assert_("COMMENT ON TABLE foo.t2 IS NULL")
 
-    @config.requirements.computed_columns_api
+    @config.requirements.computed_columns
     def test_add_column_computed(self):
         context = op_fixture("postgresql")
         op.add_column(
@@ -288,6 +291,134 @@ class PostgresqlOpTest(TestBase):
             "INTEGER GENERATED ALWAYS AS (foo * 5) STORED"
         )
 
+    @combinations(
+        (lambda: sqla_compat.Computed("foo * 5"), lambda: None),
+        (lambda: None, lambda: sqla_compat.Computed("foo * 5")),
+        (
+            lambda: sqla_compat.Computed("foo * 42"),
+            lambda: sqla_compat.Computed("foo * 5"),
+        ),
+    )
+    @config.requirements.computed_columns
+    def test_alter_column_computed_not_supported(self, sd, esd):
+        op_fixture("postgresql")
+        assert_raises_message(
+            exc.CompileError,
+            'Adding or removing a "computed" construct, e.g. '
+            "GENERATED ALWAYS AS, to or from an existing column is not "
+            "supported.",
+            op.alter_column,
+            "t1",
+            "c1",
+            server_default=sd(),
+            existing_server_default=esd(),
+        )
+
+    @config.requirements.identity_columns
+    @combinations(
+        ({}, None),
+        (dict(always=True), None),
+        (
+            dict(start=3, increment=33, maxvalue=99, cycle=True),
+            "INCREMENT BY 33 START WITH 3 MAXVALUE 99 CYCLE",
+        ),
+    )
+    def test_add_column_identity(self, kw, text):
+        context = op_fixture("postgresql")
+        op.add_column(
+            "t1",
+            Column("some_column", Integer, sqla_compat.Identity(**kw)),
+        )
+        qualification = "ALWAYS" if kw.get("always", False) else "BY DEFAULT"
+        options = " (%s)" % text if text else ""
+        context.assert_(
+            "ALTER TABLE t1 ADD COLUMN some_column "
+            "INTEGER GENERATED %s AS IDENTITY%s" % (qualification, options)
+        )
+
+    @config.requirements.identity_columns
+    @combinations(
+        ({}, None),
+        (dict(always=True), None),
+        (
+            dict(start=3, increment=33, maxvalue=99, cycle=True),
+            "INCREMENT BY 33 START WITH 3 MAXVALUE 99 CYCLE",
+        ),
+    )
+    def test_add_identity_to_column(self, kw, text):
+        context = op_fixture("postgresql")
+        op.alter_column(
+            "t1",
+            "some_column",
+            server_default=sqla_compat.Identity(**kw),
+            existing_server_default=None,
+        )
+        qualification = "ALWAYS" if kw.get("always", False) else "BY DEFAULT"
+        options = " (%s)" % text if text else ""
+        context.assert_(
+            "ALTER TABLE t1 ALTER COLUMN some_column ADD "
+            "GENERATED %s AS IDENTITY%s" % (qualification, options)
+        )
+
+    @config.requirements.identity_columns
+    def test_remove_identity_from_column(self):
+        context = op_fixture("postgresql")
+        op.alter_column(
+            "t1",
+            "some_column",
+            server_default=None,
+            existing_server_default=sqla_compat.Identity(),
+        )
+        context.assert_(
+            "ALTER TABLE t1 ALTER COLUMN some_column DROP IDENTITY"
+        )
+
+    @config.requirements.identity_columns
+    @combinations(
+        ({}, dict(always=True), "SET GENERATED ALWAYS"),
+        (
+            dict(always=True),
+            dict(always=False, start=3),
+            "SET GENERATED BY DEFAULT SET START WITH 3",
+        ),
+        (
+            dict(always=True, start=3, increment=2, minvalue=-3, maxvalue=99),
+            dict(
+                always=True,
+                start=3,
+                increment=1,
+                minvalue=-3,
+                maxvalue=99,
+                cycle=True,
+            ),
+            "SET CYCLE SET INCREMENT BY 1",
+        ),
+        (
+            dict(
+                always=False,
+                start=3,
+                maxvalue=9999,
+                minvalue=0,
+            ),
+            dict(always=False, start=3, order=True, on_null=False, cache=2),
+            "SET CACHE 2",
+        ),
+        (
+            dict(always=False),
+            dict(always=None, minvalue=0),
+            "SET MINVALUE 0",
+        ),
+    )
+    def test_change_identity_in_column(self, existing, updated, text):
+        context = op_fixture("postgresql")
+        op.alter_column(
+            "t1",
+            "some_column",
+            server_default=sqla_compat.Identity(**updated),
+            existing_server_default=sqla_compat.Identity(**existing),
+        )
+        context.assert_("ALTER TABLE t1 ALTER COLUMN some_column %s" % text)
+
 
 class PGAutocommitBlockTest(TestBase):
     __only_on__ = "postgresql"