From: Federico Caselli Date: Sat, 19 Sep 2020 20:29:38 +0000 (+0200) Subject: Add reflection for Identity columns X-Git-Tag: rel_1_4_0b1~67 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=7362d454f46107cae4076ce54e9fa430c3370734;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Add reflection for Identity columns Added support for reflecting "identity" columns, which are now returned as part of the structure returned by :meth:`_reflection.Inspector.get_columns`. When reflecting full :class:`_schema.Table` objects, identity columns will be represented using the :class:`_schema.Identity` construct. Fixed compilation error on oracle for sequence and identity column ``nominvalue`` and ``nomaxvalue`` options that require no space in them. Improved test compatibility with oracle 18. As part of the support for reflecting :class:`_schema.Identity` objects, the method :meth:`_reflection.Inspector.get_columns` no longer returns ``mssql_identity_start`` and ``mssql_identity_increment`` as part of the ``dialect_options``. Use the information in the ``identity`` key instead. The mssql dialect will assume that at least MSSQL 2005 is used. There is no hard exception raised if a previous version is detected, but operations may fail for older versions. Fixes: #5527 Fixes: #5324 Change-Id: If039fe637c46b424499e6bac54a2cbc0dc54cb57 --- diff --git a/doc/build/changelog/unreleased_14/5506.rst b/doc/build/changelog/unreleased_14/5506.rst index 71b57322d3..1614efecab 100644 --- a/doc/build/changelog/unreleased_14/5506.rst +++ b/doc/build/changelog/unreleased_14/5506.rst @@ -1,5 +1,5 @@ .. change:: - :tags: usecase, mssql + :tags: usecase, mssql, reflection :tickets: 5506 Added support for reflection of temporary tables with the SQL Server dialect. diff --git a/doc/build/changelog/unreleased_14/5527.rst b/doc/build/changelog/unreleased_14/5527.rst new file mode 100644 index 0000000000..e6721df121 --- /dev/null +++ b/doc/build/changelog/unreleased_14/5527.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: sql, reflection + :tickets: 5527, 5324 + + Added support for reflecting "identity" columns, which are now returned + as part of the structure returned by :meth:`_reflection.Inspector.get_columns`. + When reflecting full :class:`_schema.Table` objects, identity columns will + be represented using the :class:`_schema.Identity` construct. + Currently the supported backends are + PostgreSQL >= 10, Oracle >= 12 and MSSQL (with different syntax + and a subset of functionalities). diff --git a/doc/build/changelog/unreleased_14/mssql_identity_dialect_options.rst b/doc/build/changelog/unreleased_14/mssql_identity_dialect_options.rst new file mode 100644 index 0000000000..8676c7d252 --- /dev/null +++ b/doc/build/changelog/unreleased_14/mssql_identity_dialect_options.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: mssql, reflection + :tickets: 5527 + + As part of the support for reflecting :class:`_schema.Identity` objects, + the method :meth:`_reflection.Inspector.get_columns` no longer returns + ``mssql_identity_start`` and ``mssql_identity_increment`` as part of the + ``dialect_options``. Use the information in the ``identity`` key instead. diff --git a/doc/build/changelog/unreleased_14/mssql_version.rst b/doc/build/changelog/unreleased_14/mssql_version.rst new file mode 100644 index 0000000000..9454cc45a4 --- /dev/null +++ b/doc/build/changelog/unreleased_14/mssql_version.rst @@ -0,0 +1,6 @@ +.. change:: + :tags: mssql + + The mssql dialect will assume that at least MSSQL 2005 is used. + There is no hard exception raised if a previous version is detected, + but operations may fail for older versions. diff --git a/doc/build/changelog/unreleased_14/oracle_sequence_options.rst b/doc/build/changelog/unreleased_14/oracle_sequence_options.rst new file mode 100644 index 0000000000..738958d0c5 --- /dev/null +++ b/doc/build/changelog/unreleased_14/oracle_sequence_options.rst @@ -0,0 +1,6 @@ +.. change:: + :tags: oracle, bug + + Correctly render :class:`_schema.Sequence` and :class:`_schema.Identity` + column options ``nominvalue`` and ``nomaxvalue`` as ``NOMAXVALUE` and + ``NOMINVALUE`` on oracle database. diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 87ccc8427c..a2ed2b47da 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -76,6 +76,9 @@ is set to ``False`` on any integer primary key column:: to :class:`_schema.Column` are deprecated and should we replaced by an :class:`_schema.Identity` object. Specifying both ways of configuring an IDENTITY will result in a compile error. + These options are also no longer returned as part of the + ``dialect_options`` key in :meth:`_reflection.Inspector.get_columns`. + Use the information in the ``identity`` key instead. .. deprecated:: 1.3 @@ -770,6 +773,7 @@ from ...types import NVARCHAR from ...types import SMALLINT from ...types import TEXT from ...types import VARCHAR +from ...util import compat from ...util import update_wrapper from ...util.langhelpers import public_factory @@ -2524,6 +2528,7 @@ def _schema_elements(schema): class MSDialect(default.DefaultDialect): + # will assume it's at least mssql2005 name = "mssql" supports_default_values = True supports_empty_insert = False @@ -2649,11 +2654,6 @@ class MSDialect(default.DefaultDialect): connection.commit() def get_isolation_level(self, connection): - if self.server_version_info < MS_2005_VERSION: - raise NotImplementedError( - "Can't fetch isolation level prior to SQL Server 2005" - ) - last_error = None views = ("sys.dm_exec_sessions", "sys.dm_pdw_nodes_exec_sessions") @@ -2719,9 +2719,6 @@ class MSDialect(default.DefaultDialect): % ".".join(str(x) for x in self.server_version_info) ) - if self.server_version_info < MS_2005_VERSION: - self.implicit_returning = self.full_returning = False - if self.server_version_info >= MS_2008_VERSION: self.supports_multivalues_insert = True if self.deprecate_large_types is None: @@ -2744,17 +2741,14 @@ class MSDialect(default.DefaultDialect): self._supports_nvarchar_max = True def _get_default_schema_name(self, connection): - if self.server_version_info < MS_2005_VERSION: - return self.schema_name + query = sql.text("SELECT schema_name()") + default_schema_name = connection.scalar(query) + if default_schema_name is not None: + # guard against the case where the default_schema_name is being + # fed back into a table reflection function. + return quoted_name(default_schema_name, quote=True) else: - query = sql.text("SELECT schema_name()") - default_schema_name = connection.scalar(query) - if default_schema_name is not None: - # guard against the case where the default_schema_name is being - # fed back into a table reflection function. - return quoted_name(default_schema_name, quote=True) - else: - return self.schema_name + return self.schema_name @_db_plus_owner def has_table(self, connection, tablename, dbname, owner, schema): @@ -2860,11 +2854,6 @@ class MSDialect(default.DefaultDialect): @reflection.cache @_db_plus_owner def get_indexes(self, connection, tablename, dbname, owner, schema, **kw): - # using system catalogs, don't support index reflection - # below MS 2005 - if self.server_version_info < MS_2005_VERSION: - return [] - rp = connection.execution_options(future_result=True).execute( sql.text( "select ind.index_id, ind.is_unique, ind.name, " @@ -3002,25 +2991,32 @@ class MSDialect(default.DefaultDialect): columns = ischema.columns computed_cols = ischema.computed_columns + identity_cols = ischema.identity_columns if owner: whereclause = sql.and_( columns.c.table_name == tablename, columns.c.table_schema == owner, ) - table_fullname = "%s.%s" % (owner, tablename) full_name = columns.c.table_schema + "." + columns.c.table_name - join_on = computed_cols.c.object_id == func.object_id(full_name) else: whereclause = columns.c.table_name == tablename - table_fullname = tablename - join_on = computed_cols.c.object_id == func.object_id( - columns.c.table_name - ) + full_name = columns.c.table_name - join_on = sql.and_( - join_on, columns.c.column_name == computed_cols.c.name + join = columns.join( + computed_cols, + onclause=sql.and_( + computed_cols.c.object_id == func.object_id(full_name), + computed_cols.c.name == columns.c.column_name, + ), + isouter=True, + ).join( + identity_cols, + onclause=sql.and_( + identity_cols.c.object_id == func.object_id(full_name), + identity_cols.c.name == columns.c.column_name, + ), + isouter=True, ) - join = columns.join(computed_cols, onclause=join_on, isouter=True) if self._supports_nvarchar_max: computed_definition = computed_cols.c.definition @@ -3032,7 +3028,12 @@ class MSDialect(default.DefaultDialect): s = ( sql.select( - columns, computed_definition, computed_cols.c.is_persisted + columns, + computed_definition, + computed_cols.c.is_persisted, + identity_cols.c.is_identity, + identity_cols.c.seed_value, + identity_cols.c.increment_value, ) .where(whereclause) .select_from(join) @@ -3053,6 +3054,9 @@ class MSDialect(default.DefaultDialect): collation = row[columns.c.collation_name] definition = row[computed_definition] is_persisted = row[computed_cols.c.is_persisted] + is_identity = row[identity_cols.c.is_identity] + identity_start = row[identity_cols.c.seed_value] + identity_increment = row[identity_cols.c.increment_value] coltype = self.ischema_names.get(type_, None) @@ -3093,7 +3097,7 @@ class MSDialect(default.DefaultDialect): "type": coltype, "nullable": nullable, "default": default, - "autoincrement": False, + "autoincrement": is_identity is not None, } if definition is not None and is_persisted is not None: @@ -3102,50 +3106,28 @@ class MSDialect(default.DefaultDialect): "persisted": is_persisted, } - cols.append(cdict) - # autoincrement and identity - colmap = {} - for col in cols: - colmap[col["name"]] = col - # We also run an sp_columns to check for identity columns: - cursor = connection.execute( - sql.text( - "sp_columns @table_name = :table_name, " - "@table_owner = :table_owner", - ), - {"table_name": tablename, "table_owner": owner}, - ) - ic = None - while True: - row = cursor.fetchone() - if row is None: - break - (col_name, type_name) = row[3], row[5] - if type_name.endswith("identity") and col_name in colmap: - ic = col_name - colmap[col_name]["autoincrement"] = True - colmap[col_name]["dialect_options"] = { - "mssql_identity_start": 1, - "mssql_identity_increment": 1, - } - break - cursor.close() + if is_identity is not None: + # identity_start and identity_increment are Decimal or None + if identity_start is None or identity_increment is None: + cdict["identity"] = {} + else: + if isinstance(coltype, sqltypes.BigInteger): + start = compat.long_type(identity_start) + increment = compat.long_type(identity_increment) + elif isinstance(coltype, sqltypes.Integer): + start = int(identity_start) + increment = int(identity_increment) + else: + start = identity_start + increment = identity_increment + + cdict["identity"] = { + "start": start, + "increment": increment, + } - if ic is not None and self.server_version_info >= MS_2005_VERSION: - table_fullname = "%s.%s" % (owner, tablename) - cursor = connection.exec_driver_sql( - "select ident_seed('%s'), ident_incr('%s')" - % (table_fullname, table_fullname) - ) + cols.append(cdict) - row = cursor.first() - if row is not None and row[0] is not None: - colmap[ic]["dialect_options"].update( - { - "mssql_identity_start": int(row[0]), - "mssql_identity_increment": int(row[1]), - } - ) return cols @reflection.cache diff --git a/lib/sqlalchemy/dialects/mssql/information_schema.py b/lib/sqlalchemy/dialects/mssql/information_schema.py index f80110b7d4..974a55963e 100644 --- a/lib/sqlalchemy/dialects/mssql/information_schema.py +++ b/lib/sqlalchemy/dialects/mssql/information_schema.py @@ -14,6 +14,7 @@ from ...ext.compiler import compiles from ...sql import expression from ...types import Boolean from ...types import Integer +from ...types import Numeric from ...types import String from ...types import TypeDecorator from ...types import Unicode @@ -198,3 +199,32 @@ sequences = Table( Column("SEQUENCE_NAME", CoerceUnicode, key="sequence_name"), schema="INFORMATION_SCHEMA", ) + + +class IdentitySqlVariant(TypeDecorator): + r"""This type casts sql_variant columns in the identity_columns view + to numeric. This is required because: + + * pyodbc does not support sql_variant + * pymssql under python 2 return the byte representation of the number, + int 1 is returned as "\x01\x00\x00\x00". On python 3 it returns the + correct value as string. + """ + impl = Unicode + + def column_expression(self, colexpr): + return cast(colexpr, Numeric) + + +identity_columns = Table( + "identity_columns", + ischema, + Column("object_id", Integer), + Column("name", CoerceUnicode), + Column("is_identity", Boolean), + Column("seed_value", IdentitySqlVariant), + Column("increment_value", IdentitySqlVariant), + Column("last_value", IdentitySqlVariant), + Column("is_not_for_replication", Boolean), + schema="sys", +) diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 3afd6fc3a7..fad1b0bbe0 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -536,7 +536,7 @@ from ...types import NCHAR from ...types import NVARCHAR from ...types import TIMESTAMP from ...types import VARCHAR - +from ...util import compat RESERVED_WORDS = set( "SHARE RAW DROP BETWEEN FROM DESC OPTION PRIOR LONG THEN " @@ -979,7 +979,11 @@ class OracleCompiler(compiler.SQLCompiler): for i, column in enumerate( expression._select_iterables(returning_cols) ): - if self.isupdate and isinstance(column.server_default, Computed): + if ( + self.isupdate + and isinstance(column.server_default, Computed) + and not self.dialect._supports_update_returning_computed_cols + ): util.warn( "Computed columns don't work with Oracle UPDATE " "statements that use RETURNING; the value of the column " @@ -1317,6 +1321,14 @@ class OracleDDLCompiler(compiler.DDLCompiler): return "".join(table_opts) + def get_identity_options(self, identity_options): + text = super(OracleDDLCompiler, self).get_identity_options( + identity_options + ) + return text.replace("NO MINVALUE", "NOMINVALUE").replace( + "NO MAXVALUE", "NOMAXVALUE" + ) + def visit_computed_column(self, generated): text = "GENERATED ALWAYS AS (%s)" % self.sql_compiler.process( generated.sqltext, include_table=False, literal_binds=True @@ -1493,6 +1505,12 @@ class OracleDialect(default.DefaultDialect): def _supports_char_length(self): return not self._is_oracle_8 + @property + def _supports_update_returning_computed_cols(self): + # on version 18 this error is no longet present while it happens on 11 + # it may work also on versions before the 18 + return self.server_version_info and self.server_version_info >= (18,) + def do_release_savepoint(self, connection, name): # Oracle does not support RELEASE SAVEPOINT pass @@ -1825,11 +1843,32 @@ class OracleDialect(default.DefaultDialect): else: char_length_col = "data_length" + if self.server_version_info >= (12,): + identity_cols = """\ + col.default_on_null, + ( + SELECT id.generation_type || ',' || id.IDENTITY_OPTIONS + FROM ALL_TAB_IDENTITY_COLS id + WHERE col.table_name = id.table_name + AND col.column_name = id.column_name + AND col.owner = id.owner + ) AS identity_options""" + else: + identity_cols = "NULL as default_on_null, NULL as identity_options" + params = {"table_name": table_name} text = """ - SELECT col.column_name, col.data_type, col.%(char_length_col)s, - col.data_precision, col.data_scale, col.nullable, - col.data_default, com.comments, col.virtual_column\ + SELECT + col.column_name, + col.data_type, + col.%(char_length_col)s, + col.data_precision, + col.data_scale, + col.nullable, + col.data_default, + com.comments, + col.virtual_column, + %(identity_cols)s FROM all_tab_cols%(dblink)s col LEFT JOIN all_col_comments%(dblink)s com ON col.table_name = com.table_name @@ -1842,7 +1881,11 @@ class OracleDialect(default.DefaultDialect): params["owner"] = schema text += " AND col.owner = :owner " text += " ORDER BY col.column_id" - text = text % {"dblink": dblink, "char_length_col": char_length_col} + text = text % { + "dblink": dblink, + "char_length_col": char_length_col, + "identity_cols": identity_cols, + } c = connection.execute(sql.text(text), params) @@ -1857,6 +1900,8 @@ class OracleDialect(default.DefaultDialect): default = row[6] comment = row[7] generated = row[8] + default_on_nul = row[9] + identity_options = row[10] if coltype == "NUMBER": if precision is None and scale == 0: @@ -1887,6 +1932,14 @@ class OracleDialect(default.DefaultDialect): else: computed = None + if identity_options is not None: + identity = self._parse_identity_options( + identity_options, default_on_nul + ) + default = None + else: + identity = None + cdict = { "name": colname, "type": coltype, @@ -1899,10 +1952,44 @@ class OracleDialect(default.DefaultDialect): cdict["quote"] = True if computed is not None: cdict["computed"] = computed + if identity is not None: + cdict["identity"] = identity columns.append(cdict) return columns + def _parse_identity_options(self, identity_options, default_on_nul): + # identity_options is a string that starts with 'ALWAYS,' or + # 'BY DEFAULT,' and contues with + # START WITH: 1, INCREMENT BY: 1, MAX_VALUE: 123, MIN_VALUE: 1, + # CYCLE_FLAG: N, CACHE_SIZE: 1, ORDER_FLAG: N, SCALE_FLAG: N, + # EXTEND_FLAG: N, SESSION_FLAG: N, KEEP_VALUE: N + parts = [p.strip() for p in identity_options.split(",")] + identity = { + "always": parts[0] == "ALWAYS", + "on_null": default_on_nul == "YES", + } + + for part in parts[1:]: + option, value = part.split(":") + value = value.strip() + + if "START WITH" in option: + identity["start"] = compat.long_type(value) + elif "INCREMENT BY" in option: + identity["increment"] = compat.long_type(value) + elif "MAX_VALUE" in option: + identity["maxvalue"] = compat.long_type(value) + elif "MIN_VALUE" in option: + identity["minvalue"] = compat.long_type(value) + elif "CYCLE_FLAG" in option: + identity["cycle"] = value == "Y" + elif "CACHE_SIZE" in option: + identity["cache"] = compat.long_type(value) + elif "ORDER_FLAG" in option: + identity["order"] = value == "Y" + return identity + @reflection.cache def get_table_comment( self, diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 5ed56db560..82f0126a0f 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -3234,17 +3234,44 @@ class PGDialect(default.DefaultDialect): if self.server_version_info >= (12,) else "NULL as generated" ) - SQL_COLS = ( - """ + if self.server_version_info >= (10,): + # a.attidentity != '' is required or it will reflect also + # serial columns as identity. + identity = """\ + (SELECT json_build_object( + 'always', a.attidentity = 'a', + 'start', s.seqstart, + 'increment', s.seqincrement, + 'minvalue', s.seqmin, + 'maxvalue', s.seqmax, + 'cache', s.seqcache, + 'cycle', s.seqcycle) + FROM pg_catalog.pg_sequence s + JOIN pg_catalog.pg_class c on s.seqrelid = c."oid" + JOIN pg_catalog.pg_namespace n on n.oid = c.relnamespace + WHERE c.relkind = 'S' + AND a.attidentity != '' + AND n.nspname || '.' || c.relname = + pg_catalog.pg_get_serial_sequence( + a.attrelid::regclass::text, a.attname) + ) as identity_options\ + """ + else: + identity = "NULL as identity" + + SQL_COLS = """ SELECT a.attname, pg_catalog.format_type(a.atttypid, a.atttypmod), - (SELECT pg_catalog.pg_get_expr(d.adbin, d.adrelid) + ( + SELECT pg_catalog.pg_get_expr(d.adbin, d.adrelid) FROM pg_catalog.pg_attrdef d - WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum - AND a.atthasdef) - AS DEFAULT, - a.attnotnull, a.attnum, a.attrelid as table_oid, + WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum + AND a.atthasdef + ) AS DEFAULT, + a.attnotnull, + a.attrelid as table_oid, pgd.description as comment, + %s, %s FROM pg_catalog.pg_attribute a LEFT JOIN pg_catalog.pg_description pgd ON ( @@ -3252,8 +3279,9 @@ class PGDialect(default.DefaultDialect): WHERE a.attrelid = :table_oid AND a.attnum > 0 AND NOT a.attisdropped ORDER BY a.attnum - """ - % generated + """ % ( + generated, + identity, ) s = ( sql.text(SQL_COLS) @@ -3284,10 +3312,10 @@ class PGDialect(default.DefaultDialect): format_type, default_, notnull, - attnum, table_oid, comment, generated, + identity, ) in rows: column_info = self._get_column_info( name, @@ -3299,6 +3327,7 @@ class PGDialect(default.DefaultDialect): schema, comment, generated, + identity, ) columns.append(column_info) return columns @@ -3314,6 +3343,7 @@ class PGDialect(default.DefaultDialect): schema, comment, generated, + identity, ): def _handle_array_type(attype): return ( @@ -3428,7 +3458,6 @@ class PGDialect(default.DefaultDialect): # If a zero byte or blank string depending on driver (is also absent # for older PG versions), then not a generated column. Otherwise, s = # stored. (Other values might be added in the future.) - # if generated not in (None, "", b"\x00"): computed = dict( sqltext=default, persisted=generated in ("s", b"s") @@ -3463,11 +3492,13 @@ class PGDialect(default.DefaultDialect): type=coltype, nullable=nullable, default=default, - autoincrement=autoincrement, + autoincrement=autoincrement or identity is not None, comment=comment, ) if computed is not None: column_info["computed"] = computed + if identity is not None: + column_info["identity"] = identity return column_info @reflection.cache diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py index 812f7ceeca..92a85df2e6 100644 --- a/lib/sqlalchemy/engine/reflection.py +++ b/lib/sqlalchemy/engine/reflection.py @@ -472,8 +472,13 @@ class Inspector(object): .. versionadded:: 1.3.16 - added support for computed reflection. - * ``dialect_options`` - (optional) a dict with dialect specific options + * ``identity`` - (optional) when present it indicates that this column + is a generated always column. Only some dialects return this key. + For a list of keywords on this dict see :class:`_schema.Identity`. + + .. versionadded:: 1.4 - added support for identity column reflection. + * ``dialect_options`` - (optional) a dict with dialect specific options :param table_name: string name of the table. For special quoting, use :class:`.quoted_name`. @@ -880,6 +885,10 @@ class Inspector(object): computed = sa_schema.Computed(**col_d["computed"]) colargs.append(computed) + if "identity" in col_d: + computed = sa_schema.Identity(**col_d["identity"]) + colargs.append(computed) + if "sequence" in col_d: self._reflect_col_sequence(col_d, colargs) diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py index f728310d7c..f8f93b5635 100644 --- a/lib/sqlalchemy/testing/suite/test_reflection.py +++ b/lib/sqlalchemy/testing/suite/test_reflection.py @@ -15,6 +15,7 @@ from ..schema import Column from ..schema import Table from ... import event from ... import ForeignKey +from ... import Identity from ... import inspect from ... import Integer from ... import MetaData @@ -1442,6 +1443,144 @@ class ComputedReflectionTest(fixtures.ComputedReflectionFixtureTest): ) +class IdentityReflectionTest(fixtures.TablesTest): + run_inserts = run_deletes = None + + __backend__ = True + __requires__ = ("identity_columns", "table_reflection") + + @classmethod + def define_tables(cls, metadata): + Table( + "t1", + metadata, + Column("normal", Integer), + Column("id1", Integer, Identity()), + ) + Table( + "t2", + metadata, + Column( + "id2", + Integer, + Identity( + always=True, + start=2, + increment=3, + minvalue=-2, + maxvalue=42, + cycle=True, + cache=4, + ), + ), + ) + if testing.requires.schemas.enabled: + Table( + "t1", + metadata, + Column("normal", Integer), + Column("id1", Integer, Identity(always=True, start=20)), + schema=config.test_schema, + ) + + def check(self, value, exp, approx): + if testing.requires.identity_columns_standard.enabled: + common_keys = ( + "always", + "start", + "increment", + "minvalue", + "maxvalue", + "cycle", + "cache", + ) + for k in list(value): + if k not in common_keys: + value.pop(k) + if approx: + eq_(len(value), len(exp)) + for k in value: + if k == "minvalue": + is_true(value[k] <= exp[k]) + elif k in {"maxvalue", "cache"}: + is_true(value[k] >= exp[k]) + else: + eq_(value[k], exp[k], k) + else: + eq_(value, exp) + else: + eq_(value["start"], exp["start"]) + eq_(value["increment"], exp["increment"]) + + def test_reflect_identity(self): + insp = inspect(config.db) + + cols = insp.get_columns("t1") + insp.get_columns("t2") + for col in cols: + if col["name"] == "normal": + is_false("identity" in col) + elif col["name"] == "id1": + is_true(col["autoincrement"] in (True, "auto")) + eq_(col["default"], None) + is_true("identity" in col) + self.check( + col["identity"], + dict( + always=False, + start=1, + increment=1, + minvalue=1, + maxvalue=2147483647, + cycle=False, + cache=1, + ), + approx=True, + ) + elif col["name"] == "id2": + is_true(col["autoincrement"] in (True, "auto")) + eq_(col["default"], None) + is_true("identity" in col) + self.check( + col["identity"], + dict( + always=True, + start=2, + increment=3, + minvalue=-2, + maxvalue=42, + cycle=True, + cache=4, + ), + approx=False, + ) + + @testing.requires.schemas + def test_reflect_identity_schema(self): + insp = inspect(config.db) + + cols = insp.get_columns("t1", schema=config.test_schema) + for col in cols: + if col["name"] == "normal": + is_false("identity" in col) + elif col["name"] == "id1": + is_true(col["autoincrement"] in (True, "auto")) + eq_(col["default"], None) + is_true("identity" in col) + self.check( + col["identity"], + dict( + always=True, + start=20, + increment=1, + minvalue=1, + maxvalue=2147483647, + cycle=False, + cache=1, + ), + approx=True, + ) + + __all__ = ( "ComponentReflectionTest", "QuotedNameArgumentTest", @@ -1449,4 +1588,5 @@ __all__ = ( "HasIndexTest", "NormalizedNameTest", "ComputedReflectionTest", + "IdentityReflectionTest", ) diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py index b0fb60c5f6..224826c257 100644 --- a/lib/sqlalchemy/testing/suite/test_select.py +++ b/lib/sqlalchemy/testing/suite/test_select.py @@ -1047,7 +1047,9 @@ class IdentityColumnTest(fixtures.TablesTest): Column( "id", Integer, - Identity(always=True, start=42), + Identity( + always=True, start=42, nominvalue=True, nomaxvalue=True + ), primary_key=True, ), Column("desc", String(100)), @@ -1058,12 +1060,7 @@ class IdentityColumnTest(fixtures.TablesTest): Column( "id", Integer, - Identity( - increment=-5, - start=0, - minvalue=-1000, - maxvalue=0, - ), + Identity(increment=-5, start=0, minvalue=-1000, maxvalue=0), primary_key=True, ), Column("desc", String(100)), diff --git a/lib/sqlalchemy/testing/suite/test_sequence.py b/lib/sqlalchemy/testing/suite/test_sequence.py index de970da53c..d8c35ed0b6 100644 --- a/lib/sqlalchemy/testing/suite/test_sequence.py +++ b/lib/sqlalchemy/testing/suite/test_sequence.py @@ -105,7 +105,9 @@ class HasSequenceTest(fixtures.TablesTest): @classmethod def define_tables(cls, metadata): Sequence("user_id_seq", metadata=metadata) - Sequence("other_seq", metadata=metadata) + Sequence( + "other_seq", metadata=metadata, nomaxvalue=True, nominvalue=True + ) if testing.requires.schemas.enabled: Sequence( "user_id_seq", schema=config.test_schema, metadata=metadata diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py index 285f6c0216..c71deffbc2 100644 --- a/lib/sqlalchemy/util/compat.py +++ b/lib/sqlalchemy/util/compat.py @@ -122,6 +122,7 @@ if py3k: text_type = str int_types = (int,) iterbytes = iter + long_type = int itertools_filterfalse = itertools.filterfalse itertools_filter = filter @@ -226,6 +227,7 @@ else: binary_type = str text_type = unicode # noqa int_types = int, long # noqa + long_type = long # noqa callable = callable # noqa cmp = cmp # noqa diff --git a/test/dialect/mssql/test_reflection.py b/test/dialect/mssql/test_reflection.py index d33838b6a3..d791e9ec14 100644 --- a/test/dialect/mssql/test_reflection.py +++ b/test/dialect/mssql/test_reflection.py @@ -1,11 +1,13 @@ # -*- encoding: utf-8 import datetime +import decimal from sqlalchemy import Column from sqlalchemy import DDL from sqlalchemy import event from sqlalchemy import exc from sqlalchemy import ForeignKey +from sqlalchemy import Identity from sqlalchemy import Index from sqlalchemy import inspect from sqlalchemy import Integer @@ -29,6 +31,7 @@ from sqlalchemy.testing import expect_raises from sqlalchemy.testing import fixtures from sqlalchemy.testing import in_ from sqlalchemy.testing import is_ +from sqlalchemy.testing import is_true from sqlalchemy.testing import mock @@ -146,8 +149,13 @@ class ReflectionTest(fixtures.TestBase, ComparesTables, AssertsCompiledSQL): meta2 = MetaData(testing.db) table2 = Table("identity_test", meta2, autoload=True) - eq_(table2.c["col1"].dialect_options["mssql"]["identity_start"], 2) - eq_(table2.c["col1"].dialect_options["mssql"]["identity_increment"], 3) + eq_(table2.c["col1"].dialect_options["mssql"]["identity_start"], None) + eq_( + table2.c["col1"].dialect_options["mssql"]["identity_increment"], + None, + ) + eq_(table2.c["col1"].identity.start, 2) + eq_(table2.c["col1"].identity.increment, 3) @testing.provide_metadata def test_skip_types(self, connection): @@ -673,3 +681,76 @@ class OwnerPlusDBTest(fixtures.TestBase): [mock.call.scalar()], ) eq_(mock_lambda.mock_calls, [mock.call("x", y="bar")]) + + +class IdentityReflectionTest(fixtures.TablesTest): + __only_on__ = "mssql" + __backend__ = True + __requires__ = ("identity_columns",) + + @classmethod + def define_tables(cls, metadata): + + for i, col in enumerate( + [ + Column( + "id1", + Integer, + Identity( + always=True, + start=2, + increment=3, + minvalue=-2, + maxvalue=42, + cycle=True, + cache=4, + ), + ), + Column("id2", Integer, Identity()), + Column("id3", sqltypes.BigInteger, Identity()), + Column("id4", sqltypes.SmallInteger, Identity()), + Column("id5", sqltypes.Numeric, Identity()), + ] + ): + Table("t%s" % i, metadata, col) + + def test_reflect_identity(self): + insp = inspect(testing.db) + cols = [] + for t in self.metadata.tables.keys(): + cols.extend(insp.get_columns(t)) + for col in cols: + is_true("dialect_options" not in col) + is_true("identity" in col) + if col["name"] == "id1": + eq_(col["identity"], {"start": 2, "increment": 3}) + elif col["name"] == "id2": + eq_(col["identity"], {"start": 1, "increment": 1}) + eq_(type(col["identity"]["start"]), int) + eq_(type(col["identity"]["increment"]), int) + elif col["name"] == "id3": + eq_(col["identity"], {"start": 1, "increment": 1}) + eq_(type(col["identity"]["start"]), util.compat.long_type) + eq_(type(col["identity"]["increment"]), util.compat.long_type) + elif col["name"] == "id4": + eq_(col["identity"], {"start": 1, "increment": 1}) + eq_(type(col["identity"]["start"]), int) + eq_(type(col["identity"]["increment"]), int) + elif col["name"] == "id5": + eq_(col["identity"], {"start": 1, "increment": 1}) + eq_(type(col["identity"]["start"]), decimal.Decimal) + eq_(type(col["identity"]["increment"]), decimal.Decimal) + + @testing.requires.views + def test_reflect_views(self, connection): + try: + with testing.db.connect() as conn: + conn.exec_driver_sql("CREATE VIEW view1 AS SELECT * FROM t1") + insp = inspect(testing.db) + for col in insp.get_columns("view1"): + is_true("dialect_options" not in col) + is_true("identity" in col) + eq_(col["identity"], {}) + finally: + with testing.db.connect() as conn: + conn.exec_driver_sql("DROP VIEW view1") diff --git a/test/dialect/oracle/test_compiler.py b/test/dialect/oracle/test_compiler.py index cea2b51419..869cffe440 100644 --- a/test/dialect/oracle/test_compiler.py +++ b/test/dialect/oracle/test_compiler.py @@ -28,6 +28,7 @@ from sqlalchemy.dialects.oracle import base as oracle from sqlalchemy.dialects.oracle import cx_oracle from sqlalchemy.engine import default from sqlalchemy.sql import column +from sqlalchemy.sql import ddl from sqlalchemy.sql import quoted_name from sqlalchemy.sql import table from sqlalchemy.testing import assert_raises_message @@ -1258,12 +1259,22 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): t = Table( "t", m, - Column("y", Integer, Identity(always=True, start=4, increment=7)), + Column( + "y", + Integer, + Identity( + always=True, + start=4, + increment=7, + nominvalue=True, + nomaxvalue=True, + ), + ), ) self.assert_compile( schema.CreateTable(t), "CREATE TABLE t (y INTEGER GENERATED ALWAYS AS IDENTITY " - "(INCREMENT BY 7 START WITH 4))", + "(INCREMENT BY 7 START WITH 4 NOMINVALUE NOMAXVALUE))", ) def test_column_identity_no_generated(self): @@ -1310,6 +1321,15 @@ class SequenceTest(fixtures.TestBase, AssertsCompiledSQL): == '"Some_Schema"."My_Seq"' ) + def test_compile(self): + self.assert_compile( + ddl.CreateSequence( + Sequence("my_seq", nomaxvalue=True, nominvalue=True) + ), + "CREATE SEQUENCE my_seq START WITH 1 NOMINVALUE NOMAXVALUE", + dialect=oracle.OracleDialect(), + ) + class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ = "oracle" diff --git a/test/dialect/oracle/test_dialect.py b/test/dialect/oracle/test_dialect.py index cd0e11e588..128ecc573a 100644 --- a/test/dialect/oracle/test_dialect.py +++ b/test/dialect/oracle/test_dialect.py @@ -287,15 +287,21 @@ class ComputedReturningTest(fixtures.TablesTest): with testing.db.connect() as conn: conn.execute(test.insert(), {"id": 1, "foo": 5}) - with testing.expect_warnings( - "Computed columns don't work with Oracle UPDATE" - ): + if testing.db.dialect._supports_update_returning_computed_cols: result = conn.execute( test.update().values(foo=10).return_defaults() ) - - # returns the *old* value - eq_(result.returned_defaults, (47,)) + eq_(result.returned_defaults, (52,)) + else: + with testing.expect_warnings( + "Computed columns don't work with Oracle UPDATE" + ): + result = conn.execute( + test.update().values(foo=10).return_defaults() + ) + + # returns the *old* value + eq_(result.returned_defaults, (47,)) eq_(conn.scalar(select(test.c.bar)), 52) diff --git a/test/dialect/oracle/test_reflection.py b/test/dialect/oracle/test_reflection.py index d2780fa29c..bd881ac1c3 100644 --- a/test/dialect/oracle/test_reflection.py +++ b/test/dialect/oracle/test_reflection.py @@ -6,6 +6,7 @@ from sqlalchemy import FLOAT from sqlalchemy import ForeignKey from sqlalchemy import ForeignKeyConstraint from sqlalchemy import func +from sqlalchemy import Identity from sqlalchemy import Index from sqlalchemy import inspect from sqlalchemy import INTEGER @@ -27,6 +28,7 @@ from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ +from sqlalchemy.testing import is_true from sqlalchemy.testing.engines import testing_engine from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -448,7 +450,8 @@ def all_tables_compression_missing(): if ( "Enterprise Edition" not in exec_sql(testing.db, "select * from v$version").scalar() - ): + # this works in Oracle Database 18c Express Edition Release + ) and testing.db.dialect.server_version_info < (18,): return True return False except Exception: @@ -819,3 +822,39 @@ class TypeReflectionTest(fixtures.TestBase): # (FLOAT(5), oracle.FLOAT(binary_precision=126),), ] self._run_test(specs, ["precision"]) + + +class IdentityReflectionTest(fixtures.TablesTest): + __only_on__ = "oracle" + __backend__ = True + __requires__ = ("identity_columns",) + + @classmethod + def define_tables(cls, metadata): + Table("t1", metadata, Column("id1", Integer, Identity(on_null=True))) + Table("t2", metadata, Column("id2", Integer, Identity(order=True))) + + def test_reflect_identity(self): + insp = inspect(testing.db) + common = { + "always": False, + "start": 1, + "increment": 1, + "on_null": False, + "maxvalue": 10 ** 28 - 1, + "minvalue": 1, + "cycle": False, + "cache": 20, + "order": False, + } + for col in insp.get_columns("t1") + insp.get_columns("t2"): + if col["name"] == "id1": + is_true("identity" in col) + exp = common.copy() + exp["on_null"] = True + eq_(col["identity"], exp) + if col["name"] == "id2": + is_true("identity" in col) + exp = common.copy() + exp["order"] = True + eq_(col["identity"], exp) diff --git a/test/dialect/postgresql/test_reflection.py b/test/dialect/postgresql/test_reflection.py index 2c67957197..d85bdff77f 100644 --- a/test/dialect/postgresql/test_reflection.py +++ b/test/dialect/postgresql/test_reflection.py @@ -5,9 +5,11 @@ from operator import itemgetter import re import sqlalchemy as sa +from sqlalchemy import BigInteger from sqlalchemy import Column from sqlalchemy import exc from sqlalchemy import ForeignKey +from sqlalchemy import Identity from sqlalchemy import Index from sqlalchemy import inspect from sqlalchemy import Integer @@ -15,6 +17,7 @@ from sqlalchemy import join from sqlalchemy import MetaData from sqlalchemy import PrimaryKeyConstraint from sqlalchemy import Sequence +from sqlalchemy import SmallInteger from sqlalchemy import String from sqlalchemy import Table from sqlalchemy import testing @@ -33,6 +36,7 @@ from sqlalchemy.testing import mock from sqlalchemy.testing.assertions import assert_raises from sqlalchemy.testing.assertions import AssertsExecutionResults from sqlalchemy.testing.assertions import eq_ +from sqlalchemy.testing.assertions import is_true class ForeignTableReflectionTest(fixtures.TablesTest, AssertsExecutionResults): @@ -1770,7 +1774,7 @@ class CustomTypeReflectionTest(fixtures.TestBase): ("my_custom_type(ARG1, ARG2)", ("ARG1", "ARG2")), ]: column_info = dialect._get_column_info( - "colname", sch, None, False, {}, {}, "public", None, "" + "colname", sch, None, False, {}, {}, "public", None, "", None ) assert isinstance(column_info["type"], self.CustomType) eq_(column_info["type"].arg1, args[0]) @@ -1845,3 +1849,74 @@ class IntervalReflectionTest(fixtures.TestBase): assert isinstance(columns["data1"]["type"], INTERVAL) eq_(columns["data1"]["type"].fields, None) eq_(columns["data1"]["type"].precision, 6) + + +class IdentityReflectionTest(fixtures.TablesTest): + __only_on__ = "postgresql" + __backend__ = True + __requires__ = ("identity_columns",) + + @classmethod + def define_tables(cls, metadata): + Table( + "t1", + metadata, + Column( + "id1", + Integer, + Identity( + always=True, + start=2, + increment=3, + minvalue=-2, + maxvalue=42, + cycle=True, + cache=4, + ), + ), + Column("id2", Integer, Identity()), + Column("id3", BigInteger, Identity()), + Column("id4", SmallInteger, Identity()), + ) + + def test_reflect_identity(self): + insp = inspect(testing.db) + default = dict( + always=False, + start=1, + increment=1, + minvalue=1, + cycle=False, + cache=1, + ) + cols = insp.get_columns("t1") + for col in cols: + if col["name"] == "id1": + is_true("identity" in col) + eq_( + col["identity"], + dict( + always=True, + start=2, + increment=3, + minvalue=-2, + maxvalue=42, + cycle=True, + cache=4, + ), + ) + elif col["name"] == "id2": + is_true("identity" in col) + exp = default.copy() + exp.update(maxvalue=2 ** 31 - 1) + eq_(col["identity"], exp) + elif col["name"] == "id3": + is_true("identity" in col) + exp = default.copy() + exp.update(maxvalue=2 ** 63 - 1) + eq_(col["identity"], exp) + elif col["name"] == "id4": + is_true("identity" in col) + exp = default.copy() + exp.update(maxvalue=2 ** 15 - 1) + eq_(col["identity"], exp) diff --git a/test/engine/test_reflection.py b/test/engine/test_reflection.py index 194de9a7d2..26fb1b07fd 100644 --- a/test/engine/test_reflection.py +++ b/test/engine/test_reflection.py @@ -5,6 +5,7 @@ from sqlalchemy import Computed from sqlalchemy import DefaultClause from sqlalchemy import FetchedValue from sqlalchemy import ForeignKey +from sqlalchemy import Identity from sqlalchemy import Index from sqlalchemy import inspect from sqlalchemy import Integer @@ -2135,7 +2136,7 @@ class ColumnEventsTest(fixtures.RemovesEvents, fixtures.TestBase): cls.to_reflect = Table( "to_reflect", cls.metadata, - Column("x", sa.Integer, primary_key=True), + Column("x", sa.Integer, primary_key=True, autoincrement=False), Column("y", sa.Integer), test_needs_fk=True, ) @@ -2308,3 +2309,28 @@ class ComputedColumnTest(fixtures.ComputedReflectionFixtureTest): "normal-42", True, ) + + +class IdentityColumnTest(fixtures.TablesTest): + run_inserts = run_deletes = None + + __backend__ = True + __requires__ = ("identity_columns", "table_reflection") + + @classmethod + def define_tables(cls, metadata): + Table( + "t1", + metadata, + Column("normal", Integer), + Column("id1", Integer, Identity(start=2, increment=3)), + ) + + def test_table_reflection(self): + meta = MetaData() + table = Table("t1", meta, autoload_with=config.db) + + eq_(table.c.normal.identity, None) + is_true(table.c.id1.identity is not None) + eq_(table.c.id1.identity.start, 2) + eq_(table.c.id1.identity.increment, 3) diff --git a/test/sql/test_identity_column.py b/test/sql/test_identity_column.py index 2564022c2f..0b1bffd000 100644 --- a/test/sql/test_identity_column.py +++ b/test/sql/test_identity_column.py @@ -57,13 +57,17 @@ class _IdentityDDLFixture(testing.AssertsCompiledSQL): dict(always=False, cache=1000, order=True), "BY DEFAULT AS IDENTITY (CACHE 1000 ORDER)", ), - ( - dict(order=True), - "BY DEFAULT AS IDENTITY (ORDER)", - ), + (dict(order=True), "BY DEFAULT AS IDENTITY (ORDER)"), ) def test_create_ddl(self, identity_args, text): + if getattr(self, "__dialect__", None) != "default" and testing.against( + "oracle" + ): + text = text.replace("NO MINVALUE", "NOMINVALUE").replace( + "NO MAXVALUE", "NOMAXVALUE" + ) + t = Table( "foo_table", MetaData(),