From: Federico Caselli Date: Thu, 10 Aug 2023 21:54:43 +0000 (+0200) Subject: adapt identity logic to support dialect kwags X-Git-Tag: rel_1_12_1~12^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ccae936643641c83233138f3713263a955cf51da;p=thirdparty%2Fsqlalchemy%2Falembic.git adapt identity logic to support dialect kwags Alembic now accommodates for Sequence and Identity that support dialect kwargs. This is a change that will be added to SQLAlchemy v2.1. Fixes: #1304 Change-Id: I68d46426296931dee68eeb909cbe17d1c48a5899 --- diff --git a/alembic/autogenerate/render.py b/alembic/autogenerate/render.py index 3729a486..fa24c397 100644 --- a/alembic/autogenerate/render.py +++ b/alembic/autogenerate/render.py @@ -1,6 +1,5 @@ from __future__ import annotations -from collections import OrderedDict from io import StringIO import re from typing import Any @@ -760,11 +759,9 @@ def _render_computed( def _render_identity( identity: Identity, autogen_context: AutogenContext ) -> str: - # always=None means something different than always=False - kwargs = OrderedDict(always=identity.always) - if identity.on_null is not None: - kwargs["on_null"] = identity.on_null - kwargs.update(_get_identity_options(identity)) + kwargs = sqla_compat._get_identity_options_dict( + identity, dialect_kwargs=True + ) return "%(prefix)sIdentity(%(kwargs)s)" % { "prefix": _sqlalchemy_autogenerate_prefix(autogen_context), @@ -772,15 +769,6 @@ def _render_identity( } -def _get_identity_options(identity_options: Identity) -> OrderedDict: - kwargs = OrderedDict() - for attr in sqla_compat._identity_options_attrs: - value = getattr(identity_options, attr, None) - if value is not None: - kwargs[attr] = value - return kwargs - - def _repr_type( type_: TypeEngine, autogen_context: AutogenContext, diff --git a/alembic/ddl/impl.py b/alembic/ddl/impl.py index 38827092..5ae5f2f9 100644 --- a/alembic/ddl/impl.py +++ b/alembic/ddl/impl.py @@ -5,7 +5,9 @@ import re from typing import Any from typing import Callable from typing import Dict +from typing import Iterable from typing import List +from typing import Mapping from typing import Optional from typing import Sequence from typing import Set @@ -86,8 +88,11 @@ class DefaultImpl(metaclass=ImplMeta): command_terminator = ";" type_synonyms: Tuple[Set[str], ...] = ({"NUMERIC", "DECIMAL"},) type_arg_extract: Sequence[str] = () - # on_null is known to be supported only by oracle - identity_attrs_ignore: Tuple[str, ...] = ("on_null",) + # These attributes are deprecated in SQLAlchemy via #10247. They need to + # be ignored to support older version that did not use dialect kwargs. + # They only apply to Oracle and are replaced by oracle_order, + # oracle_on_null + identity_attrs_ignore: Tuple[str, ...] = ("order", "on_null") def __init__( self, @@ -638,10 +643,10 @@ class DefaultImpl(metaclass=ImplMeta): # ignored contains the attributes that were not considered # because assumed to their default values in the db. diff, ignored = _compare_identity_options( - sqla_compat._identity_attrs, metadata_identity, inspector_identity, sqla_compat.Identity(), + skip={"always"}, ) meta_always = getattr(metadata_identity, "always", None) @@ -696,20 +701,50 @@ class DefaultImpl(metaclass=ImplMeta): def _compare_identity_options( - attributes, metadata_io, inspector_io, default_io + metadata_io: Union[schema.Identity, schema.Sequence, None], + inspector_io: Union[schema.Identity, schema.Sequence, None], + default_io: Union[schema.Identity, schema.Sequence], + skip: Set[str], ): # this can be used for identity or sequence compare. # default_io is an instance of IdentityOption with all attributes to the # default value. + meta_d = sqla_compat._get_identity_options_dict(metadata_io) + insp_d = sqla_compat._get_identity_options_dict(inspector_io) + diff = set() ignored_attr = set() - for attr in attributes: - meta_value = getattr(metadata_io, attr, None) - default_value = getattr(default_io, attr, None) - conn_value = getattr(inspector_io, attr, None) - if conn_value != meta_value: - if meta_value == default_value: - ignored_attr.add(attr) - else: - diff.add(attr) + + def check_dicts( + meta_dict: Mapping[str, Any], + insp_dict: Mapping[str, Any], + default_dict: Mapping[str, Any], + attrs: Iterable[str], + ): + for attr in set(attrs).difference(skip): + meta_value = meta_dict.get(attr) + insp_value = insp_dict.get(attr) + if insp_value != meta_value: + default_value = default_dict.get(attr) + if meta_value == default_value: + ignored_attr.add(attr) + else: + diff.add(attr) + + check_dicts( + meta_d, + insp_d, + sqla_compat._get_identity_options_dict(default_io), + set(meta_d).union(insp_d), + ) + if sqla_compat.identity_has_dialect_kwargs: + # use only the dialect kwargs in inspector_io since metadata_io + # can have options for many backends + check_dicts( + getattr(metadata_io, "dialect_kwargs", {}), + getattr(inspector_io, "dialect_kwargs", {}), + default_io.dialect_kwargs, # type: ignore[union-attr] + getattr(inspector_io, "dialect_kwargs", {}), + ) + return diff, ignored_attr diff --git a/alembic/ddl/mssql.py b/alembic/ddl/mssql.py index dbd8de6c..9b0fff88 100644 --- a/alembic/ddl/mssql.py +++ b/alembic/ddl/mssql.py @@ -51,16 +51,13 @@ class MSSQLImpl(DefaultImpl): batch_separator = "GO" type_synonyms = DefaultImpl.type_synonyms + ({"VARCHAR", "NVARCHAR"},) - identity_attrs_ignore = ( + identity_attrs_ignore = DefaultImpl.identity_attrs_ignore + ( "minvalue", "maxvalue", "nominvalue", "nomaxvalue", "cycle", "cache", - "order", - "on_null", - "order", ) def __init__(self, *arg, **kw) -> None: diff --git a/alembic/ddl/postgresql.py b/alembic/ddl/postgresql.py index b63938ac..6c2ab64c 100644 --- a/alembic/ddl/postgresql.py +++ b/alembic/ddl/postgresql.py @@ -79,7 +79,6 @@ class PostgresqlImpl(DefaultImpl): type_synonyms = DefaultImpl.type_synonyms + ( {"FLOAT", "DOUBLE PRECISION"}, ) - identity_attrs_ignore = ("on_null", "order") def create_index(self, index: Index, **kw: Any) -> None: # this likely defaults to None if not present, so get() diff --git a/alembic/testing/requirements.py b/alembic/testing/requirements.py index 40de4cd2..2107da46 100644 --- a/alembic/testing/requirements.py +++ b/alembic/testing/requirements.py @@ -196,7 +196,3 @@ class SuiteRequirements(Requirements): return exclusions.only_if( exclusions.BooleanPredicate(sqla_compat.has_identity) ) - - @property - def supports_identity_on_null(self): - return exclusions.closed() diff --git a/alembic/testing/suite/test_autogen_identity.py b/alembic/testing/suite/test_autogen_identity.py index 9aedf9e9..3dee9fc9 100644 --- a/alembic/testing/suite/test_autogen_identity.py +++ b/alembic/testing/suite/test_autogen_identity.py @@ -4,6 +4,7 @@ from sqlalchemy import Integer from sqlalchemy import MetaData from sqlalchemy import Table +from alembic.util import sqla_compat from ._autogen_fixtures import AutogenFixtureTest from ... import testing from ...testing import config @@ -78,16 +79,33 @@ class AutogenerateIdentityTest(AutogenFixtureTest, TestBase): m2 = MetaData() for m in (m1, m2): - Table( - "user", - m, - Column("id", Integer, sa.Identity(start=2)), - ) + id_ = sa.Identity(start=2) + Table("user", m, Column("id", Integer, id_)) diffs = self._fixture(m1, m2) eq_(diffs, []) + def test_dialect_kwargs_changes(self): + m1 = MetaData() + m2 = MetaData() + + if sqla_compat.identity_has_dialect_kwargs: + args = {"oracle_on_null": True, "oracle_order": True} + else: + args = {"on_null": True, "order": True} + + Table("user", m1, Column("id", Integer, sa.Identity(start=2))) + id_ = sa.Identity(start=2, **args) + Table("user", m2, Column("id", Integer, id_)) + + diffs = self._fixture(m1, m2) + if config.db.name == "oracle": + is_true(len(diffs), 1) + eq_(diffs[0][0][0], "modify_default") + else: + eq_(diffs, []) + @testing.combinations( (None, dict(start=2)), (dict(start=2), None), @@ -206,36 +224,3 @@ class AutogenerateIdentityTest(AutogenFixtureTest, TestBase): removed = diffs[5] is_true(isinstance(removed, sa.Identity)) - - def test_identity_on_null(self): - m1 = MetaData() - m2 = MetaData() - - Table( - "user", - m1, - Column("id", Integer, sa.Identity(start=2, on_null=True)), - Column("other", sa.Text), - ) - - Table( - "user", - m2, - Column("id", Integer, sa.Identity(start=2, on_null=False)), - Column("other", sa.Text), - ) - - diffs = self._fixture(m1, m2) - if not config.requirements.supports_identity_on_null.enabled: - eq_(diffs, []) - else: - eq_(len(diffs[0]), 1) - diffs = diffs[0][0] - eq_(diffs[0], "modify_default") - eq_(diffs[2], "user") - eq_(diffs[3], "id") - old = diffs[5] - new = diffs[6] - - is_true(isinstance(old, sa.Identity)) - is_true(isinstance(new, sa.Identity)) diff --git a/alembic/util/sqla_compat.py b/alembic/util/sqla_compat.py index d356abcd..3f175cf5 100644 --- a/alembic/util/sqla_compat.py +++ b/alembic/util/sqla_compat.py @@ -3,6 +3,7 @@ from __future__ import annotations import contextlib import re from typing import Any +from typing import Dict from typing import Iterable from typing import Iterator from typing import Mapping @@ -22,6 +23,7 @@ from sqlalchemy.schema import CheckConstraint from sqlalchemy.schema import Column from sqlalchemy.schema import ForeignKeyConstraint from sqlalchemy.sql import visitors +from sqlalchemy.sql.base import DialectKWArgs from sqlalchemy.sql.elements import BindParameter from sqlalchemy.sql.elements import ColumnClause from sqlalchemy.sql.elements import quoted_name @@ -80,9 +82,10 @@ class _Unsupported: try: from sqlalchemy import Computed except ImportError: + if not TYPE_CHECKING: - class Computed(_Unsupported): # type: ignore - pass + class Computed(_Unsupported): + pass has_computed = False has_computed_reflection = False @@ -93,26 +96,54 @@ else: try: from sqlalchemy import Identity except ImportError: + if not TYPE_CHECKING: - class Identity(_Unsupported): # type: ignore - pass + class Identity(_Unsupported): + pass has_identity = False else: - # attributes common to Identity and Sequence - _identity_options_attrs = ( - "start", - "increment", - "minvalue", - "maxvalue", - "nominvalue", - "nomaxvalue", - "cycle", - "cache", - "order", - ) - # attributes of Identity - _identity_attrs = _identity_options_attrs + ("on_null",) + identity_has_dialect_kwargs = issubclass(Identity, DialectKWArgs) + + def _get_identity_options_dict( + identity: Union[Identity, schema.Sequence, None], + dialect_kwargs: bool = False, + ) -> Dict[str, Any]: + if identity is None: + return {} + elif identity_has_dialect_kwargs: + as_dict = identity._as_dict() # type: ignore + if dialect_kwargs: + assert isinstance(identity, DialectKWArgs) + as_dict.update(identity.dialect_kwargs) + else: + as_dict = {} + if isinstance(identity, Identity): + # always=None means something different than always=False + as_dict["always"] = identity.always + if identity.on_null is not None: + as_dict["on_null"] = identity.on_null + # attributes common to Identity and Sequence + attrs = ( + "start", + "increment", + "minvalue", + "maxvalue", + "nominvalue", + "nomaxvalue", + "cycle", + "cache", + "order", + ) + as_dict.update( + { + key: getattr(identity, key, None) + for key in attrs + if getattr(identity, key, None) is not None + } + ) + return as_dict + has_identity = True if sqla_2: diff --git a/docs/build/unreleased/1304.rst b/docs/build/unreleased/1304.rst new file mode 100644 index 00000000..089adbb2 --- /dev/null +++ b/docs/build/unreleased/1304.rst @@ -0,0 +1,6 @@ +.. change:: + :tags: usecase + :tickets: 1304 + + Alembic now accommodates for Sequence and Identity that support dialect kwargs. + This is a change that will be added to SQLAlchemy v2.1. diff --git a/tests/requirements.py b/tests/requirements.py index 347d119a..2f259aa1 100644 --- a/tests/requirements.py +++ b/tests/requirements.py @@ -384,10 +384,6 @@ class DefaultRequirements(SuiteRequirements): ["postgresql >= 10", "oracle >= 12"] ) - @property - def supports_identity_on_null(self): - return self.identity_columns + exclusions.only_on(["oracle"]) - @property def legacy_engine(self): return exclusions.only_if( diff --git a/tests/test_autogen_render.py b/tests/test_autogen_render.py index cd0b8b4a..b00fdc1f 100644 --- a/tests/test_autogen_render.py +++ b/tests/test_autogen_render.py @@ -49,6 +49,7 @@ from alembic.testing import eq_ignore_whitespace from alembic.testing import mock from alembic.testing import TestBase from alembic.testing.fixtures import op_fixture +from alembic.util import sqla_compat class AutogenRenderTest(TestBase): @@ -2150,15 +2151,13 @@ class AutogenRenderTest(TestBase): % persisted, ) - @config.requirements.identity_columns_api - @testing.combinations( + identity_comb = testing.combinations( ({}, "sa.Identity(always=False)"), (dict(always=None), "sa.Identity(always=None)"), (dict(always=True), "sa.Identity(always=True)"), ( dict( always=False, - on_null=True, start=2, increment=4, minvalue=-3, @@ -2167,13 +2166,30 @@ class AutogenRenderTest(TestBase): nomaxvalue=True, cycle=True, cache=42, - order=True, ), - "sa.Identity(always=False, on_null=True, start=2, increment=4, " + "sa.Identity(always=False, start=2, increment=4, " "minvalue=-3, maxvalue=99, nominvalue=True, nomaxvalue=True, " - "cycle=True, cache=42, order=True)", + "cycle=True, cache=42)", + ), + ( + dict(start=42, oracle_on_null=True, oracle_order=False), + "sa.Identity(always=False, start=42, oracle_on_null=True, " + "oracle_order=False)", + testing.exclusions.only_if( + lambda: sqla_compat.identity_has_dialect_kwargs + ), + ), + ( + dict(start=42, on_null=True, order=False), + "sa.Identity(always=False, on_null=True, start=42, order=False)", + testing.exclusions.only_if( + lambda: not sqla_compat.identity_has_dialect_kwargs + ), ), ) + + @config.requirements.identity_columns_api + @identity_comb def test_render_add_column_identity(self, kw, text): col = Column("x", Integer, sa.Identity(**kw)) op_obj = ops.AddColumnOp("foo", col) @@ -2184,29 +2200,7 @@ class AutogenRenderTest(TestBase): ) @config.requirements.identity_columns_api - @testing.combinations( - ({}, "sa.Identity(always=False)"), - (dict(always=None), "sa.Identity(always=None)"), - (dict(always=True), "sa.Identity(always=True)"), - ( - dict( - always=False, - on_null=True, - start=2, - increment=4, - minvalue=-3, - maxvalue=99, - nominvalue=True, - nomaxvalue=True, - cycle=True, - cache=42, - order=True, - ), - "sa.Identity(always=False, on_null=True, start=2, increment=4, " - "minvalue=-3, maxvalue=99, nominvalue=True, nomaxvalue=True, " - "cycle=True, cache=42, order=True)", - ), - ) + @identity_comb def test_render_alter_column_add_identity(self, kw, text): op_obj = ops.AlterColumnOp( "foo", diff --git a/tests/test_oracle.py b/tests/test_oracle.py index 63ac1c4b..3b6ddcac 100644 --- a/tests/test_oracle.py +++ b/tests/test_oracle.py @@ -1,12 +1,17 @@ +import sqlalchemy as sa from sqlalchemy import Column from sqlalchemy import exc from sqlalchemy import Integer +from sqlalchemy import MetaData +from sqlalchemy import Table from alembic import command from alembic import op from alembic.testing import assert_raises_message from alembic.testing import combinations from alembic.testing import config +from alembic.testing import eq_ +from alembic.testing import is_true from alembic.testing.env import _no_sql_testing_config from alembic.testing.env import clear_staging_env from alembic.testing.env import staging_env @@ -14,6 +19,7 @@ from alembic.testing.env import three_rev_fixture from alembic.testing.fixtures import capture_context_buffer from alembic.testing.fixtures import op_fixture from alembic.testing.fixtures import TestBase +from alembic.testing.suite._autogen_fixtures import AutogenFixtureTest from alembic.util import sqla_compat @@ -253,31 +259,48 @@ class OpTest(TestBase): # 'ALTER TABLE y.t RENAME COLUMN c TO c2' # ) + +class IdentityTest(AutogenFixtureTest, TestBase): + __requires__ = ("identity_columns",) + __backend__ = True + __only_on__ = "oracle" + def _identity_qualification(self, kw): always = kw.get("always", False) if always is None: return "" qualification = "ALWAYS" if always else "BY DEFAULT" - if kw.get("on_null", False): + if kw.get("oracle_on_null", False): qualification += " ON NULL" return qualification - @config.requirements.identity_columns + def _adapt_identity_kw(self, data): + res = data.copy() + if not sqla_compat.identity_has_dialect_kwargs: + for k in data: + if k.startswith("oracle_"): + res[k[7:]] = res.pop(k) + return res + @combinations( ({}, None), (dict(always=True), None), - (dict(always=None, order=True), "ORDER"), + (dict(always=None, oracle_order=True), "ORDER"), ( dict(start=3, increment=33, maxvalue=99, cycle=True), "INCREMENT BY 33 START WITH 3 MAXVALUE 99 CYCLE", ), - (dict(on_null=True, start=42), "START WITH 42"), + (dict(oracle_on_null=True, start=42), "START WITH 42"), ) def test_add_column_identity(self, kw, text): context = op_fixture("oracle") op.add_column( "t1", - Column("some_column", Integer, sqla_compat.Identity(**kw)), + Column( + "some_column", + Integer, + sqla_compat.Identity(**self._adapt_identity_kw(kw)), + ), ) qualification = self._identity_qualification(kw) options = " (%s)" % text if text else "" @@ -286,7 +309,6 @@ class OpTest(TestBase): "INTEGER GENERATED %s AS IDENTITY%s" % (qualification, options) ) - @config.requirements.identity_columns @combinations( ({}, None), (dict(always=True), None), @@ -295,14 +317,14 @@ class OpTest(TestBase): dict(start=3, increment=33, maxvalue=99, cycle=True), "INCREMENT BY 33 START WITH 3 MAXVALUE 99 CYCLE", ), - (dict(on_null=True, start=42), "START WITH 42"), + (dict(oracle_on_null=True, start=42), "START WITH 42"), ) def test_add_identity_to_column(self, kw, text): context = op_fixture("oracle") op.alter_column( "t1", "some_column", - server_default=sqla_compat.Identity(**kw), + server_default=sqla_compat.Identity(**self._adapt_identity_kw(kw)), existing_server_default=None, ) qualification = self._identity_qualification(kw) @@ -312,7 +334,6 @@ class OpTest(TestBase): "GENERATED %s AS IDENTITY%s" % (qualification, options) ) - @config.requirements.identity_columns def test_remove_identity_from_column(self): context = op_fixture("oracle") op.alter_column( @@ -323,7 +344,6 @@ class OpTest(TestBase): ) context.assert_("ALTER TABLE t1 MODIFY some_column DROP IDENTITY") - @config.requirements.identity_columns @combinations( ({}, dict(always=True), None), ( @@ -350,7 +370,13 @@ class OpTest(TestBase): maxvalue=9999, minvalue=0, ), - dict(always=False, start=3, order=True, on_null=False, cache=2), + dict( + always=False, + start=3, + oracle_order=True, + oracle_on_null=False, + cache=2, + ), "START WITH 3 CACHE 2 ORDER", ), ( @@ -364,8 +390,12 @@ class OpTest(TestBase): op.alter_column( "t1", "some_column", - server_default=sqla_compat.Identity(**updated), - existing_server_default=sqla_compat.Identity(**existing), + server_default=sqla_compat.Identity( + **self._adapt_identity_kw(updated) + ), + existing_server_default=sqla_compat.Identity( + **self._adapt_identity_kw(existing) + ), ) qualification = self._identity_qualification(updated) @@ -374,3 +404,49 @@ class OpTest(TestBase): "ALTER TABLE t1 MODIFY some_column " "GENERATED %s AS IDENTITY%s" % (qualification, options) ) + + def test_identity_on_null(self): + m1 = MetaData() + m2 = MetaData() + + Table( + "user", + m1, + Column( + "id", + Integer, + sqla_compat.Identity( + **self._adapt_identity_kw( + dict(start=2, oracle_on_null=True) + ) + ), + ), + Column("other", sa.Text), + ) + + Table( + "user", + m2, + Column( + "id", + Integer, + sa.Identity( + **self._adapt_identity_kw( + dict(start=2, oracle_on_null=False) + ) + ), + ), + Column("other", sa.Text), + ) + + diffs = self._fixture(m1, m2) + eq_(len(diffs[0]), 1) + diffs = diffs[0][0] + eq_(diffs[0], "modify_default") + eq_(diffs[2], "user") + eq_(diffs[3], "id") + old = diffs[5] + new = diffs[6] + + is_true(isinstance(old, sa.Identity)) + is_true(isinstance(new, sa.Identity)) diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index 99f6e9fc..4b328e55 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -444,7 +444,7 @@ class PostgresqlOpTest(TestBase): maxvalue=9999, minvalue=0, ), - dict(always=False, start=3, order=True, on_null=False, cache=2), + dict(always=False, start=3, cache=2), "SET CACHE 2", ), (