From: Denis Laxalde Date: Tue, 11 Mar 2025 13:27:13 +0000 (-0400) Subject: Ensure PostgreSQL network address types are not cast as VARCHAR X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=f91e61e5c80004db6db47f4e13f37553ff22675a;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Ensure PostgreSQL network address types are not cast as VARCHAR Fixed issue in PostgreSQL network types :class:`_postgresql.INET`, :class:`_postgresql.CIDR`, :class:`_postgresql.MACADDR`, :class:`_postgresql.MACADDR8` where sending string values to compare to these types would render an explicit CAST to VARCHAR, causing some SQL / driver combinations to fail. Pull request courtesy Denis Laxalde. Fixes: #12060 Closes: #12412 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12412 Pull-request-sha: 029fda7f2d182af71ebc48aef191aa9114927f28 Change-Id: Id4a502ebc119775567cacddbabef2ce9715c1a9f --- diff --git a/doc/build/changelog/unreleased_20/12060.rst b/doc/build/changelog/unreleased_20/12060.rst new file mode 100644 index 0000000000..c215d3799f --- /dev/null +++ b/doc/build/changelog/unreleased_20/12060.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: bug, postgresql + :tickets: 12060 + + Fixed issue in PostgreSQL network types :class:`_postgresql.INET`, + :class:`_postgresql.CIDR`, :class:`_postgresql.MACADDR`, + :class:`_postgresql.MACADDR8` where sending string values to compare to + these types would render an explicit CAST to VARCHAR, causing some SQL / + driver combinations to fail. Pull request courtesy Denis Laxalde. diff --git a/lib/sqlalchemy/dialects/postgresql/types.py b/lib/sqlalchemy/dialects/postgresql/types.py index 6fe4f576eb..1aed2bf472 100644 --- a/lib/sqlalchemy/dialects/postgresql/types.py +++ b/lib/sqlalchemy/dialects/postgresql/types.py @@ -52,28 +52,38 @@ class BYTEA(sqltypes.LargeBinary): __visit_name__ = "BYTEA" -class INET(sqltypes.TypeEngine[str]): +class _NetworkAddressTypeMixin: + + def coerce_compared_value( + self, op: Optional[OperatorType], value: Any + ) -> TypeEngine[Any]: + if TYPE_CHECKING: + assert isinstance(self, TypeEngine) + return self + + +class INET(_NetworkAddressTypeMixin, sqltypes.TypeEngine[str]): __visit_name__ = "INET" PGInet = INET -class CIDR(sqltypes.TypeEngine[str]): +class CIDR(_NetworkAddressTypeMixin, sqltypes.TypeEngine[str]): __visit_name__ = "CIDR" PGCidr = CIDR -class MACADDR(sqltypes.TypeEngine[str]): +class MACADDR(_NetworkAddressTypeMixin, sqltypes.TypeEngine[str]): __visit_name__ = "MACADDR" PGMacAddr = MACADDR -class MACADDR8(sqltypes.TypeEngine[str]): +class MACADDR8(_NetworkAddressTypeMixin, sqltypes.TypeEngine[str]): __visit_name__ = "MACADDR8" diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 5f39aa608c..795a897699 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -3447,6 +3447,49 @@ class SpecialTypesCompileTest(fixtures.TestBase, AssertsCompiledSQL): def test_bit_compile(self, type_, expected): self.assert_compile(type_, expected) + @testing.combinations( + (psycopg.dialect(),), + (psycopg2.dialect(),), + (asyncpg.dialect(),), + (pg8000.dialect(),), + argnames="dialect", + id_="n", + ) + def test_network_address_cast(self, metadata, dialect): + t = Table( + "addresses", + metadata, + Column("id", Integer, primary_key=True), + Column("addr", postgresql.INET), + Column("addr2", postgresql.MACADDR), + Column("addr3", postgresql.CIDR), + Column("addr4", postgresql.MACADDR8), + ) + stmt = select(t.c.id).where( + t.c.addr == "127.0.0.1", + t.c.addr2 == "08:00:2b:01:02:03", + t.c.addr3 == "192.168.100.128/25", + t.c.addr4 == "08:00:2b:01:02:03:04:05", + ) + param, param2, param3, param4 = { + "format": ("%s", "%s", "%s", "%s"), + "numeric_dollar": ("$1", "$2", "$3", "$4"), + "pyformat": ( + "%(addr_1)s", + "%(addr2_1)s", + "%(addr3_1)s", + "%(addr4_1)s", + ), + }[dialect.paramstyle] + expected = ( + "SELECT addresses.id FROM addresses " + f"WHERE addresses.addr = {param} " + f"AND addresses.addr2 = {param2} " + f"AND addresses.addr3 = {param3} " + f"AND addresses.addr4 = {param4}" + ) + self.assert_compile(stmt, expected, dialect=dialect) + class SpecialTypesTest(fixtures.TablesTest, ComparesTables): """test DDL and reflection of PG-specific types""" @@ -3501,6 +3544,30 @@ class SpecialTypesTest(fixtures.TablesTest, ComparesTables): assert t.c.precision_interval.type.precision == 3 assert t.c.bitstring.type.length == 4 + @testing.combinations( + (postgresql.INET, "127.0.0.1"), + (postgresql.CIDR, "192.168.100.128/25"), + (postgresql.MACADDR, "08:00:2b:01:02:03"), + (postgresql.MACADDR8, "08:00:2b:01:02:03:04:05"), + argnames="column_type, value", + id_="na", + ) + def test_network_address_round_trip( + self, connection, metadata, column_type, value + ): + t = Table( + "addresses", + metadata, + Column("name", String), + Column("value", column_type), + ) + t.create(connection) + connection.execute(t.insert(), {"name": "test", "value": value}) + eq_( + connection.scalar(select(t.c.name).where(t.c.value == value)), + "test", + ) + def test_tsvector_round_trip(self, connection, metadata): t = Table("t1", metadata, Column("data", postgresql.TSVECTOR)) t.create(connection)