]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Propertly ignore ``Identity`` in MySQL and MariaDb.
authorFederico Caselli <cfederico87@gmail.com>
Wed, 21 Apr 2021 20:49:09 +0000 (22:49 +0200)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 29 Apr 2021 00:03:27 +0000 (20:03 -0400)
Ensure that the MySQL and MariaDB dialect ignore the
:class:`_sql.Identity` construct while rendering the
``AUTO_INCREMENT`` keyword in a create table.

The Oracle and PostgreSQL compiler was updated to not render
:class:`_sql.Identity` if the database version does not support it
(Oracle < 12 and PostgreSQL < 10). Previously it was rendered regardless
of the database version.

Fixes: #6338
Change-Id: I2ca0902fdd7b4be4fc1a563cf5585504cbea9360

12 files changed:
doc/build/changelog/unreleased_14/6338.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/testing/requirements.py
lib/sqlalchemy/testing/suite/test_select.py
test/dialect/oracle/test_compiler.py
test/dialect/postgresql/test_compiler.py
test/requirements.py
test/sql/test_identity_column.py

diff --git a/doc/build/changelog/unreleased_14/6338.rst b/doc/build/changelog/unreleased_14/6338.rst
new file mode 100644 (file)
index 0000000..2ad45db
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: bug, schema, mysql, mariadb, oracle, postgresql
+    :tickets: 6338
+
+    Ensure that the MySQL and MariaDB dialect ignore the
+    :class:`_sql.Identity` construct while rendering the ``AUTO_INCREMENT``
+    keyword in a create table.
+
+    The Oracle and PostgreSQL compiler was updated to not render
+    :class:`_sql.Identity` if the database version does not support it
+    (Oracle < 12 and PostgreSQL < 10). Previously it was rendered regardless
+    of the database version.
index d4c70a78e6b013ad6b7b300abcc464928e92c59f..88c03a0113ae37b96f73e5455b15a5cbace351e5 100644 (file)
@@ -1933,7 +1933,10 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
         if (
             column.table is not None
             and column is column.table._autoincrement_column
-            and column.server_default is None
+            and (
+                column.server_default is None
+                or isinstance(column.server_default, sa_schema.Identity)
+            )
             and not (
                 self.dialect.supports_sequences
                 and isinstance(column.default, sa_schema.Sequence)
index 9571eb46f0604c1e57a4bede02f131baeb4d5217..57f55c1824bc69eb017aac742c5e6d64340bc02c 100644 (file)
@@ -1457,6 +1457,7 @@ class OracleDialect(default.DefaultDialect):
     supports_default_values = False
     supports_default_metavalue = True
     supports_empty_insert = False
+    supports_identity_columns = True
 
     statement_compiler = OracleCompiler
     ddl_compiler = OracleDDLCompiler
@@ -1513,6 +1514,8 @@ class OracleDialect(default.DefaultDialect):
             self.colspecs.pop(sqltypes.Interval)
             self.use_ansi = False
 
+        self.supports_identity_columns = self.server_version_info >= (12,)
+
     def _get_effective_compat_server_version_info(self, connection):
         # dialect does not need compat levels below 12.2, so don't query
         # in those cases
index 47a933479f781d7e6fc61205d9b50c786ec19654..26b025b1a4a4d979e9e4b34b23d282e81f3af64d 100644 (file)
@@ -2464,6 +2464,11 @@ class PGDDLCompiler(compiler.DDLCompiler):
         if isinstance(impl_type, sqltypes.TypeDecorator):
             impl_type = impl_type.impl
 
+        has_identity = (
+            column.identity is not None
+            and self.dialect.supports_identity_columns
+        )
+
         if (
             column.primary_key
             and column is column.table._autoincrement_column
@@ -2471,7 +2476,7 @@ class PGDDLCompiler(compiler.DDLCompiler):
                 self.dialect.supports_smallserial
                 or not isinstance(impl_type, sqltypes.SmallInteger)
             )
-            and column.identity is None
+            and not has_identity
             and (
                 column.default is None
                 or (
@@ -2498,12 +2503,12 @@ class PGDDLCompiler(compiler.DDLCompiler):
 
         if column.computed is not None:
             colspec += " " + self.process(column.computed)
-        if column.identity is not None:
+        if has_identity:
             colspec += " " + self.process(column.identity)
 
-        if not column.nullable and not column.identity:
+        if not column.nullable and not has_identity:
             colspec += " NOT NULL"
-        elif column.nullable and column.identity:
+        elif column.nullable and has_identity:
             colspec += " NULL"
         return colspec
 
@@ -3086,6 +3091,8 @@ class PGDialect(default.DefaultDialect):
 
     supports_empty_insert = False
     supports_multivalues_insert = True
+    supports_identity_columns = True
+
     default_paramstyle = "pyformat"
     ischema_names = ischema_names
     colspecs = colspecs
@@ -3193,6 +3200,7 @@ class PGDialect(default.DefaultDialect):
             9,
             2,
         )
+        self.supports_identity_columns = self.server_version_info >= (10,)
 
     def on_connect(self):
         if self.isolation_level is not None:
index a917228adbf601a4ad08f660e09f026fd4150a5d..aa7e4e5e9de8fa414b936162beec07a4f7dc754f 100644 (file)
@@ -74,6 +74,7 @@ class DefaultDialect(interfaces.Dialect):
     supports_sequences = False
     sequences_optional = False
     preexecute_autoincrement_sequences = False
+    supports_identity_columns = False
     postfetch_lastrowid = True
     implicit_returning = False
     full_returning = False
@@ -808,6 +809,8 @@ class StrCompileDialect(DefaultDialect):
 
     supports_statement_cache = True
 
+    supports_identity_columns = True
+
     supports_sequences = True
     sequences_optional = True
     preexecute_autoincrement_sequences = False
index 6168248ff7918fc3f1575d05a0cdaf483991ce26..bd93f5199eaa6f0bf70255b35fcd4baaa782f85b 100644 (file)
@@ -4257,10 +4257,15 @@ class DDLCompiler(Compiled):
         if column.computed is not None:
             colspec += " " + self.process(column.computed)
 
-        if column.identity is not None:
+        if (
+            column.identity is not None
+            and self.dialect.supports_identity_columns
+        ):
             colspec += " " + self.process(column.identity)
 
-        if not column.nullable and not column.identity:
+        if not column.nullable and (
+            not column.identity or not self.dialect.supports_identity_columns
+        ):
             colspec += " NOT NULL"
         return colspec
 
index 8a70cc69247196faf84f8a828e2de6100567885e..673fa15cddb94cdb4a55743fc541e123c4e04473 100644 (file)
@@ -1417,3 +1417,10 @@ class SuiteRequirements(Requirements):
         or ties. basically this is "not mssql"
         """
         return exclusions.closed()
+
+    @property
+    def autoincrement_without_sequence(self):
+        """If autoincrement=True on a column does not require an explicit
+        sequence. This should be false only for oracle.
+        """
+        return exclusions.open()
index 7b35dc3fa360a9b82df7415af8a0e05f7a8db7c9..8f34129299966992751b90cbe159f69a7dc14dcd 100644 (file)
@@ -1442,6 +1442,31 @@ class IdentityColumnTest(fixtures.TablesTest):
         assert_raises((DatabaseError, ProgrammingError), fn)
 
 
+class IdentityAutoincrementTest(fixtures.TablesTest):
+    __backend__ = True
+    __requires__ = ("autoincrement_without_sequence",)
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table(
+            "tbl",
+            metadata,
+            Column(
+                "id",
+                Integer,
+                Identity(),
+                primary_key=True,
+                autoincrement=True,
+            ),
+            Column("desc", String(100)),
+        )
+
+    def test_autoincrement_with_identity(self, connection):
+        res = connection.execute(self.tables.tbl.insert(), {"desc": "row"})
+        res = connection.execute(self.tables.tbl.select()).first()
+        eq_(res, (1, "row"))
+
+
 class ExistsTest(fixtures.TablesTest):
     __backend__ = True
 
index 5e9f46e1a451c459ceb21393911588ca14527061..e198fa48a92c54d20e63bca7d12609cb05a30003 100644 (file)
@@ -1333,6 +1333,17 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             "CREATE TABLE t (y INTEGER GENERATED %s AS IDENTITY)" % text,
         )
 
+    def test_column_identity_not_supported(self):
+        m = MetaData()
+        t = Table("t", m, Column("y", Integer, Identity(always=None)))
+        dd = oracle.OracleDialect()
+        dd.supports_identity_columns = False
+        self.assert_compile(
+            schema.CreateTable(t),
+            "CREATE TABLE t (y INTEGER NOT NULL)",
+            dialect=dd,
+        )
+
 
 class SequenceTest(fixtures.TestBase, AssertsCompiledSQL):
     def test_basic(self):
index 4b2004a5ffac761c89d88b62ced6ac556d07a24f..a517ad1ac057ea9c375010d74556424ec9c970f9 100644 (file)
@@ -1857,18 +1857,50 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             dialect=postgresql.dialect(),
         )
 
-    def test_column_identity(self):
+    @testing.combinations(True, False)
+    def test_column_identity(self, pk):
         # all other tests are in test_identity_column.py
         m = MetaData()
         t = Table(
             "t",
             m,
-            Column("y", Integer, Identity(always=True, start=4, increment=7)),
+            Column(
+                "y",
+                Integer,
+                Identity(always=True, start=4, increment=7),
+                primary_key=pk,
+            ),
         )
         self.assert_compile(
             schema.CreateTable(t),
             "CREATE TABLE t (y INTEGER GENERATED ALWAYS AS IDENTITY "
-            "(INCREMENT BY 7 START WITH 4))",
+            "(INCREMENT BY 7 START WITH 4)%s)"
+            % (", PRIMARY KEY (y)" if pk else ""),
+        )
+
+    @testing.combinations(True, False)
+    def test_column_identity_no_support(self, pk):
+        m = MetaData()
+        t = Table(
+            "t",
+            m,
+            Column(
+                "y",
+                Integer,
+                Identity(always=True, start=4, increment=7),
+                primary_key=pk,
+            ),
+        )
+        dd = PGDialect()
+        dd.supports_identity_columns = False
+        self.assert_compile(
+            schema.CreateTable(t),
+            "CREATE TABLE t (y %s%s)"
+            % (
+                "SERIAL NOT NULL" if pk else "INTEGER NOT NULL",
+                ", PRIMARY KEY (y)" if pk else "",
+            ),
+            dialect=dd,
         )
 
     def test_column_identity_null(self):
index 98dca6124986cc8ef044a04e8db95446a8cefe12..6628d9ef3895432d32a8f82db497e1018d72c3e8 100644 (file)
@@ -1776,3 +1776,7 @@ class DefaultRequirements(SuiteRequirements):
     @property
     def fetch_offset_with_options(self):
         return skip_if("mssql")
+
+    @property
+    def autoincrement_without_sequence(self):
+        return skip_if("oracle")
index 1ce15f38c9fcd8a0662f1b849666e428193e8801..00404dae791b59c2c8384188aebab96caca76f6e 100644 (file)
@@ -1,3 +1,5 @@
+import re
+
 from sqlalchemy import Column
 from sqlalchemy import Identity
 from sqlalchemy import Integer
@@ -5,6 +7,7 @@ from sqlalchemy import MetaData
 from sqlalchemy import Sequence
 from sqlalchemy import Table
 from sqlalchemy import testing
+from sqlalchemy.engine import URL
 from sqlalchemy.exc import ArgumentError
 from sqlalchemy.schema import CreateTable
 from sqlalchemy.testing import assert_raises_message
@@ -63,9 +66,9 @@ class _IdentityDDLFixture(testing.AssertsCompiledSQL):
     )
     def test_create_ddl(self, identity_args, text):
 
-        if getattr(self, "__dialect__", None) != "default" and testing.against(
-            "oracle"
-        ):
+        if getattr(
+            self, "__dialect__", None
+        ) != "default_enhanced" and testing.against("oracle"):
             text = text.replace("NO MINVALUE", "NOMINVALUE")
             text = text.replace("NO MAXVALUE", "NOMAXVALUE")
             text = text.replace("NO CYCLE", "NOCYCLE")
@@ -138,9 +141,9 @@ class _IdentityDDLFixture(testing.AssertsCompiledSQL):
         is_(t.c.c.nullable, False)
 
         nullable = ""
-        if getattr(self, "__dialect__", None) != "default" and testing.against(
-            "postgresql"
-        ):
+        if getattr(
+            self, "__dialect__", None
+        ) != "default_enhanced" and testing.against("postgresql"):
             nullable = " NULL"
 
         self.assert_compile(
@@ -183,22 +186,67 @@ class IdentityDDL(_IdentityDDLFixture, fixtures.TestBase):
 
 class DefaultDialectIdentityDDL(_IdentityDDLFixture, fixtures.TestBase):
     # this uses the default dialect
-    __dialect__ = "default"
+    __dialect__ = "default_enhanced"
 
 
 class NotSupportingIdentityDDL(testing.AssertsCompiledSQL, fixtures.TestBase):
-    # a dialect that doesn't render IDENTITY
-    __dialect__ = "sqlite"
+    def get_dialect(self, dialect):
+        dd = URL.create(dialect).get_dialect()()
+        if dialect in {"oracle", "postgresql"}:
+            dd.supports_identity_columns = False
+        return dd
+
+    @testing.combinations("sqlite", "mysql", "mariadb", "postgresql", "oracle")
+    def test_identity_is_ignored(self, dialect):
 
-    @testing.skip_if(testing.requires.identity_columns)
-    def test_identity_is_ignored(self):
         t = Table(
             "foo_table",
             MetaData(),
             Column("foo", Integer(), Identity("always", start=3)),
         )
+        t_exp = Table(
+            "foo_table",
+            MetaData(),
+            Column("foo", Integer(), nullable=False),
+        )
+        dialect = self.get_dialect(dialect)
+        exp = CreateTable(t_exp).compile(dialect=dialect).string
+        self.assert_compile(
+            CreateTable(t), re.sub(r"[\n\t]", "", exp), dialect=dialect
+        )
+
+    @testing.combinations(
+        "sqlite",
+        "mysql",
+        "mariadb",
+        "postgresql",
+        "oracle",
+        argnames="dialect",
+    )
+    @testing.combinations(True, "auto", argnames="autoincrement")
+    def test_identity_is_ignored_in_pk(self, dialect, autoincrement):
+        t = Table(
+            "foo_table",
+            MetaData(),
+            Column(
+                "foo",
+                Integer(),
+                Identity("always", start=3),
+                primary_key=True,
+                autoincrement=autoincrement,
+            ),
+        )
+        t_exp = Table(
+            "foo_table",
+            MetaData(),
+            Column(
+                "foo", Integer(), primary_key=True, autoincrement=autoincrement
+            ),
+        )
+        dialect = self.get_dialect(dialect)
+        exp = CreateTable(t_exp).compile(dialect=dialect).string
         self.assert_compile(
-            CreateTable(t), "CREATE TABLE foo_table (foo INTEGER NOT NULL)"
+            CreateTable(t), re.sub(r"[\n\t]", "", exp), dialect=dialect
         )