From: FeeeeK <26704473+FeeeeK@users.noreply.github.com> Date: Sat, 14 Dec 2024 08:03:24 +0000 (-0500) Subject: Add missing `SmallInteger` column spec for `asyncpg` X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=c5abd84a2c3c7a1f4e733dbee387aae939464f3e;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Add missing `SmallInteger` column spec for `asyncpg` Adds missing column spec for `SmallInteger` in `asyncpg` driver Fixes: #12170 Closes: #12171 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12171 Pull-request-sha: 82886d8521cb4e78822d685a864a9af438f6ea6b Change-Id: I2cb15f066de756d4e3f21bcac6af2cf03bd25a1c --- diff --git a/doc/build/changelog/unreleased_20/12170.rst b/doc/build/changelog/unreleased_20/12170.rst new file mode 100644 index 0000000000..452181efa3 --- /dev/null +++ b/doc/build/changelog/unreleased_20/12170.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: bug, postgresql + :tickets: 12170 + + Fixed issue where creating a table with a primary column of + :class:`_sql.SmallInteger` and using the asyncpg driver would result in + the type being compiled to ``SERIAL`` rather than ``SMALLSERIAL``. diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index 4e89d5c94a..a4909b74ea 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -275,6 +275,10 @@ class AsyncpgInteger(sqltypes.Integer): render_bind_cast = True +class AsyncpgSmallInteger(sqltypes.SmallInteger): + render_bind_cast = True + + class AsyncpgBigInteger(sqltypes.BigInteger): render_bind_cast = True @@ -1062,6 +1066,7 @@ class PGDialect_asyncpg(PGDialect): INTERVAL: AsyncPgInterval, sqltypes.Boolean: AsyncpgBoolean, sqltypes.Integer: AsyncpgInteger, + sqltypes.SmallInteger: AsyncpgSmallInteger, sqltypes.BigInteger: AsyncpgBigInteger, sqltypes.Numeric: AsyncpgNumeric, sqltypes.Float: AsyncpgFloat, diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py index 3f55c085fb..892e2abc9b 100644 --- a/test/dialect/postgresql/test_dialect.py +++ b/test/dialect/postgresql/test_dialect.py @@ -1573,61 +1573,62 @@ $$ LANGUAGE plpgsql; stmt = text("select cast('hi' as char) as hi").columns(hi=Numeric) assert_raises(exc.InvalidRequestError, connection.execute, stmt) - @testing.only_on("postgresql+psycopg2") - def test_serial_integer(self): - class BITD(TypeDecorator): - impl = Integer - - cache_ok = True - - def load_dialect_impl(self, dialect): - if dialect.name == "postgresql": - return BigInteger() - else: - return Integer() - - for version, type_, expected in [ - (None, Integer, "SERIAL"), - (None, BigInteger, "BIGSERIAL"), - ((9, 1), SmallInteger, "SMALLINT"), - ((9, 2), SmallInteger, "SMALLSERIAL"), - (None, postgresql.INTEGER, "SERIAL"), - (None, postgresql.BIGINT, "BIGSERIAL"), - ( - None, - Integer().with_variant(BigInteger(), "postgresql"), - "BIGSERIAL", - ), - ( - None, - Integer().with_variant(postgresql.BIGINT, "postgresql"), - "BIGSERIAL", - ), - ( - (9, 2), - Integer().with_variant(SmallInteger, "postgresql"), - "SMALLSERIAL", - ), - (None, BITD(), "BIGSERIAL"), - ]: - m = MetaData() + @testing.combinations( + (None, Integer, "SERIAL"), + (None, BigInteger, "BIGSERIAL"), + ((9, 1), SmallInteger, "SMALLINT"), + ((9, 2), SmallInteger, "SMALLSERIAL"), + (None, SmallInteger, "SMALLSERIAL"), + (None, postgresql.INTEGER, "SERIAL"), + (None, postgresql.BIGINT, "BIGSERIAL"), + ( + None, + Integer().with_variant(BigInteger(), "postgresql"), + "BIGSERIAL", + ), + ( + None, + Integer().with_variant(postgresql.BIGINT, "postgresql"), + "BIGSERIAL", + ), + ( + (9, 2), + Integer().with_variant(SmallInteger, "postgresql"), + "SMALLSERIAL", + ), + (None, "BITD()", "BIGSERIAL"), + argnames="version, type_, expected", + ) + def test_serial_integer(self, version, type_, expected, testing_engine): + if type_ == "BITD()": - t = Table("t", m, Column("c", type_, primary_key=True)) + class BITD(TypeDecorator): + impl = Integer - if version: - dialect = testing.db.dialect.__class__() - dialect._get_server_version_info = mock.Mock( - return_value=version - ) - dialect.initialize(testing.db.connect()) - else: - dialect = testing.db.dialect + cache_ok = True - ddl_compiler = dialect.ddl_compiler(dialect, schema.CreateTable(t)) - eq_( - ddl_compiler.get_column_specification(t.c.c), - "c %s NOT NULL" % expected, - ) + def load_dialect_impl(self, dialect): + if dialect.name == "postgresql": + return BigInteger() + else: + return Integer() + + type_ = BITD() + t = Table("t", MetaData(), Column("c", type_, primary_key=True)) + + if version: + engine = testing_engine() + dialect = engine.dialect + dialect._get_server_version_info = mock.Mock(return_value=version) + engine.connect().close() # initialize the dialect + else: + dialect = testing.db.dialect + + ddl_compiler = dialect.ddl_compiler(dialect, schema.CreateTable(t)) + eq_( + ddl_compiler.get_column_specification(t.c.c), + "c %s NOT NULL" % expected, + ) @testing.requires.psycopg2_compatibility def test_initial_transaction_state_psycopg2(self):