]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Allow NUMERIC()/DECIMAL() IDENTITY columns
authorGord Thompson <gord@gordthompson.com>
Tue, 14 Jun 2022 16:09:04 +0000 (10:09 -0600)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 18 Jun 2022 17:57:54 +0000 (13:57 -0400)
Fixed issue where :class:`.Table` objects that made use of IDENTITY columns
with a :class:`.Numeric` datatype would produce errors when attempting to
reconcile the "autoincrement" column, preventing construction of the
:class:`.Column` from using the :paramref:`.Column.autoincrement` parameter
as well as emitting errors when attempting to invoke an :class:`.Insert`
construct.

Fixes: #8111
Change-Id: Iaacc4eebfbafb42fa18f9a1a4f43cb2b6b91d28a

doc/build/changelog/unreleased_14/8111.rst [new file with mode: 0644]
lib/sqlalchemy/sql/schema.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/sql/type_api.py
test/dialect/mssql/test_query.py

diff --git a/doc/build/changelog/unreleased_14/8111.rst b/doc/build/changelog/unreleased_14/8111.rst
new file mode 100644 (file)
index 0000000..ac43297
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: bug, schema, mssql
+    :tickets: 8111
+
+    Fixed issue where :class:`.Table` objects that made use of IDENTITY columns
+    with a :class:`.Numeric` datatype would produce errors when attempting to
+    reconcile the "autoincrement" column, preventing construction of the
+    :class:`.Column` from using the :paramref:`.Column.autoincrement` parameter
+    as well as emitting errors when attempting to invoke an :class:`.Insert`
+    construct.
+
index c37b6000391a15de24fba3fe48736a96d52ca003..2414d9235712e58bcf7ddc3b837dc3cda326a751 100644 (file)
@@ -4540,7 +4540,11 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint):
     def _autoincrement_column(self) -> Optional[Column[Any]]:
         def _validate_autoinc(col: Column[Any], autoinc_true: bool) -> bool:
             if col.type._type_affinity is None or not issubclass(
-                col.type._type_affinity, type_api.INTEGERTYPE._type_affinity
+                col.type._type_affinity,
+                (
+                    type_api.INTEGERTYPE._type_affinity,
+                    type_api.NUMERICTYPE._type_affinity,
+                ),
             ):
                 if autoinc_true:
                     raise exc.ArgumentError(
index 32f0813f5d9e95e50e5b28cac7e4235d97687f22..2bf83b678705740e9b97b21b6f0a6de98bbfd9ef 100644 (file)
@@ -459,6 +459,12 @@ class Numeric(HasExpressionLookup, TypeEngine[_N]):
 
     __visit_name__ = "numeric"
 
+    if TYPE_CHECKING:
+
+        @util.ro_memoized_property
+        def _type_affinity(self) -> Type[Numeric[_N]]:
+            ...
+
     _default_decimal_return_scale = 10
 
     def __init__(
@@ -3553,6 +3559,7 @@ NULLTYPE = NullType()
 BOOLEANTYPE = Boolean()
 STRINGTYPE = String()
 INTEGERTYPE = Integer()
+NUMERICTYPE: Numeric[decimal.Decimal] = Numeric()
 MATCHTYPE = MatchType()
 TABLEVALUE = TableValueType()
 DATETIME_TIMEZONE = DateTime(timezone=True)
@@ -3610,6 +3617,7 @@ type_api.BOOLEANTYPE = BOOLEANTYPE
 type_api.STRINGTYPE = STRINGTYPE
 type_api.INTEGERTYPE = INTEGERTYPE
 type_api.NULLTYPE = NULLTYPE
+type_api.NUMERICTYPE = NUMERICTYPE
 type_api.MATCHTYPE = MATCHTYPE
 type_api.INDEXABLE = INDEXABLE = Indexable
 type_api.TABLEVALUE = TABLEVALUE
index 00bae17bc57fba44b542f72a0e7443d01392cf7d..d8f1e92c493423cb7aca8fb90bcb8356a69e1264 100644 (file)
@@ -50,6 +50,7 @@ if typing.TYPE_CHECKING:
     from .sqltypes import INTEGERTYPE as INTEGERTYPE  # noqa: F401
     from .sqltypes import MATCHTYPE as MATCHTYPE  # noqa: F401
     from .sqltypes import NULLTYPE as NULLTYPE
+    from .sqltypes import NUMERICTYPE as NUMERICTYPE  # noqa: F401
     from .sqltypes import STRINGTYPE as STRINGTYPE  # noqa: F401
     from .sqltypes import TABLEVALUE as TABLEVALUE  # noqa: F401
     from ..engine.interfaces import Dialect
index 3576a9fc2a6b1c5eb48bff3a986f0b92fab8ae50..29bf4c812e8fb4fae4002fe4d2d6d6ba5ef14f0a 100644 (file)
@@ -1,4 +1,6 @@
 # -*- encoding: utf-8
+import decimal
+
 from sqlalchemy import and_
 from sqlalchemy import Column
 from sqlalchemy import DDL
@@ -9,6 +11,7 @@ from sqlalchemy import func
 from sqlalchemy import Identity
 from sqlalchemy import Integer
 from sqlalchemy import literal
+from sqlalchemy import Numeric
 from sqlalchemy import or_
 from sqlalchemy import PrimaryKeyConstraint
 from sqlalchemy import select
@@ -39,6 +42,13 @@ class IdentityInsertTest(fixtures.TablesTest, AssertsCompiledSQL):
             Column("description", String(50)),
             PrimaryKeyConstraint("id", name="PK_cattable"),
         )
+        Table(
+            "numeric_identity",
+            metadata,
+            Column("id", Numeric(18, 0), autoincrement=True),
+            Column("description", String(50)),
+            PrimaryKeyConstraint("id", name="PK_numeric_identity"),
+        )
 
     def test_compiled(self):
         cattable = self.tables.cattable
@@ -61,6 +71,13 @@ class IdentityInsertTest(fixtures.TablesTest, AssertsCompiledSQL):
         lastcat = conn.execute(cattable.select().order_by(desc(cattable.c.id)))
         eq_((10, "PHP"), lastcat.first())
 
+        numeric_identity = self.tables.numeric_identity
+        # for some reason, T-SQL does not like .values(), but this works
+        result = conn.execute(
+            numeric_identity.insert(), dict(description="T-SQL")
+        )
+        eq_(result.inserted_primary_key, (decimal.Decimal("1"),))
+
     def test_executemany(self, connection):
         conn = connection
         cattable = self.tables.cattable