]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add missing `SmallInteger` column spec for `asyncpg`
authorFeeeeK <26704473+FeeeeK@users.noreply.github.com>
Sat, 14 Dec 2024 08:03:24 +0000 (03:03 -0500)
committerFederico Caselli <cfederico87@gmail.com>
Tue, 17 Dec 2024 19:27:22 +0000 (20:27 +0100)
Adds missing column spec for `SmallInteger` in `asyncpg` driver

Fixes: #12170
Closes: #12171
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12171
Pull-request-sha: 82886d8521cb4e78822d685a864a9af438f6ea6b

Change-Id: I2cb15f066de756d4e3f21bcac6af2cf03bd25a1c

doc/build/changelog/unreleased_20/12170.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/asyncpg.py
test/dialect/postgresql/test_dialect.py

diff --git a/doc/build/changelog/unreleased_20/12170.rst b/doc/build/changelog/unreleased_20/12170.rst
new file mode 100644 (file)
index 0000000..452181e
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, postgresql
+    :tickets: 12170
+
+    Fixed issue where creating a table with a primary column of
+    :class:`_sql.SmallInteger` and using the asyncpg driver would result in
+    the type being compiled to ``SERIAL`` rather than ``SMALLSERIAL``.
index 4e89d5c94a99f5e7c9d6cca937701e755558f7b8..a4909b74ea50d208a669fd5bd7b32b683328d25e 100644 (file)
@@ -275,6 +275,10 @@ class AsyncpgInteger(sqltypes.Integer):
     render_bind_cast = True
 
 
+class AsyncpgSmallInteger(sqltypes.SmallInteger):
+    render_bind_cast = True
+
+
 class AsyncpgBigInteger(sqltypes.BigInteger):
     render_bind_cast = True
 
@@ -1062,6 +1066,7 @@ class PGDialect_asyncpg(PGDialect):
             INTERVAL: AsyncPgInterval,
             sqltypes.Boolean: AsyncpgBoolean,
             sqltypes.Integer: AsyncpgInteger,
+            sqltypes.SmallInteger: AsyncpgSmallInteger,
             sqltypes.BigInteger: AsyncpgBigInteger,
             sqltypes.Numeric: AsyncpgNumeric,
             sqltypes.Float: AsyncpgFloat,
index 3f55c085fb4cfaa96a68a088d75c7e183bd38382..892e2abc9be02a450c954ae5bca7700389b73996 100644 (file)
@@ -1573,61 +1573,62 @@ $$ LANGUAGE plpgsql;
         stmt = text("select cast('hi' as char) as hi").columns(hi=Numeric)
         assert_raises(exc.InvalidRequestError, connection.execute, stmt)
 
-    @testing.only_on("postgresql+psycopg2")
-    def test_serial_integer(self):
-        class BITD(TypeDecorator):
-            impl = Integer
-
-            cache_ok = True
-
-            def load_dialect_impl(self, dialect):
-                if dialect.name == "postgresql":
-                    return BigInteger()
-                else:
-                    return Integer()
-
-        for version, type_, expected in [
-            (None, Integer, "SERIAL"),
-            (None, BigInteger, "BIGSERIAL"),
-            ((9, 1), SmallInteger, "SMALLINT"),
-            ((9, 2), SmallInteger, "SMALLSERIAL"),
-            (None, postgresql.INTEGER, "SERIAL"),
-            (None, postgresql.BIGINT, "BIGSERIAL"),
-            (
-                None,
-                Integer().with_variant(BigInteger(), "postgresql"),
-                "BIGSERIAL",
-            ),
-            (
-                None,
-                Integer().with_variant(postgresql.BIGINT, "postgresql"),
-                "BIGSERIAL",
-            ),
-            (
-                (9, 2),
-                Integer().with_variant(SmallInteger, "postgresql"),
-                "SMALLSERIAL",
-            ),
-            (None, BITD(), "BIGSERIAL"),
-        ]:
-            m = MetaData()
+    @testing.combinations(
+        (None, Integer, "SERIAL"),
+        (None, BigInteger, "BIGSERIAL"),
+        ((9, 1), SmallInteger, "SMALLINT"),
+        ((9, 2), SmallInteger, "SMALLSERIAL"),
+        (None, SmallInteger, "SMALLSERIAL"),
+        (None, postgresql.INTEGER, "SERIAL"),
+        (None, postgresql.BIGINT, "BIGSERIAL"),
+        (
+            None,
+            Integer().with_variant(BigInteger(), "postgresql"),
+            "BIGSERIAL",
+        ),
+        (
+            None,
+            Integer().with_variant(postgresql.BIGINT, "postgresql"),
+            "BIGSERIAL",
+        ),
+        (
+            (9, 2),
+            Integer().with_variant(SmallInteger, "postgresql"),
+            "SMALLSERIAL",
+        ),
+        (None, "BITD()", "BIGSERIAL"),
+        argnames="version, type_, expected",
+    )
+    def test_serial_integer(self, version, type_, expected, testing_engine):
+        if type_ == "BITD()":
 
-            t = Table("t", m, Column("c", type_, primary_key=True))
+            class BITD(TypeDecorator):
+                impl = Integer
 
-            if version:
-                dialect = testing.db.dialect.__class__()
-                dialect._get_server_version_info = mock.Mock(
-                    return_value=version
-                )
-                dialect.initialize(testing.db.connect())
-            else:
-                dialect = testing.db.dialect
+                cache_ok = True
 
-            ddl_compiler = dialect.ddl_compiler(dialect, schema.CreateTable(t))
-            eq_(
-                ddl_compiler.get_column_specification(t.c.c),
-                "c %s NOT NULL" % expected,
-            )
+                def load_dialect_impl(self, dialect):
+                    if dialect.name == "postgresql":
+                        return BigInteger()
+                    else:
+                        return Integer()
+
+            type_ = BITD()
+        t = Table("t", MetaData(), Column("c", type_, primary_key=True))
+
+        if version:
+            engine = testing_engine()
+            dialect = engine.dialect
+            dialect._get_server_version_info = mock.Mock(return_value=version)
+            engine.connect().close()  # initialize the dialect
+        else:
+            dialect = testing.db.dialect
+
+        ddl_compiler = dialect.ddl_compiler(dialect, schema.CreateTable(t))
+        eq_(
+            ddl_compiler.get_column_specification(t.c.c),
+            "c %s NOT NULL" % expected,
+        )
 
     @testing.requires.psycopg2_compatibility
     def test_initial_transaction_state_psycopg2(self):