From 972d1dfcd2051fe4793849540a381339d8dba809 Mon Sep 17 00:00:00 2001 From: Denis Laxalde Date: Mon, 10 Mar 2025 14:56:10 +0100 Subject: [PATCH] Ensure PostgreSQL network address types are not cast as VARCHAR This is a similar change as for CITEXT in commit d7ee73ff81ed69df43756240670bd98f3b1c3302. Fix https://github.com/sqlalchemy/sqlalchemy/issues/12060 --- lib/sqlalchemy/dialects/postgresql/types.py | 20 +++++++ test/dialect/postgresql/test_types.py | 62 ++++++++++++++++++++- 2 files changed, 81 insertions(+), 1 deletion(-) diff --git a/lib/sqlalchemy/dialects/postgresql/types.py b/lib/sqlalchemy/dialects/postgresql/types.py index 6fe4f576eb..bcf8f311ae 100644 --- a/lib/sqlalchemy/dialects/postgresql/types.py +++ b/lib/sqlalchemy/dialects/postgresql/types.py @@ -55,6 +55,11 @@ class BYTEA(sqltypes.LargeBinary): class INET(sqltypes.TypeEngine[str]): __visit_name__ = "INET" + def coerce_compared_value( + self, op: Optional[OperatorType], value: Any + ) -> TypeEngine[Any]: + return self + PGInet = INET @@ -62,6 +67,11 @@ PGInet = INET class CIDR(sqltypes.TypeEngine[str]): __visit_name__ = "CIDR" + def coerce_compared_value( + self, op: Optional[OperatorType], value: Any + ) -> TypeEngine[Any]: + return self + PGCidr = CIDR @@ -69,6 +79,11 @@ PGCidr = CIDR class MACADDR(sqltypes.TypeEngine[str]): __visit_name__ = "MACADDR" + def coerce_compared_value( + self, op: Optional[OperatorType], value: Any + ) -> TypeEngine[Any]: + return self + PGMacAddr = MACADDR @@ -76,6 +91,11 @@ PGMacAddr = MACADDR class MACADDR8(sqltypes.TypeEngine[str]): __visit_name__ = "MACADDR8" + def coerce_compared_value( + self, op: Optional[OperatorType], value: Any + ) -> TypeEngine[Any]: + return self + PGMacAddr8 = MACADDR8 diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 5f39aa608c..8437ee1720 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -3448,7 +3448,9 @@ class SpecialTypesCompileTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile(type_, expected) -class SpecialTypesTest(fixtures.TablesTest, ComparesTables): +class SpecialTypesTest( + AssertsCompiledSQL, fixtures.TablesTest, ComparesTables +): """test DDL and reflection of PG-specific types""" __only_on__ = ("postgresql >= 8.3.0",) @@ -3501,6 +3503,64 @@ class SpecialTypesTest(fixtures.TablesTest, ComparesTables): assert t.c.precision_interval.type.precision == 3 assert t.c.bitstring.type.length == 4 + @testing.combinations( + (psycopg.dialect(),), + (psycopg2.dialect(),), + (asyncpg.dialect(),), + (pg8000.dialect(),), + argnames="dialect", + id_="n", + ) + def test_network_address_cast(self, special_types_table, dialect): + stmt = select(special_types_table.c.id).where( + special_types_table.c.addr == "127.0.0.1", + special_types_table.c.addr2 == "08:00:2b:01:02:03", + special_types_table.c.addr3 == "192.168.100.128/25", + special_types_table.c.addr4 == "08:00:2b:01:02:03:04:05", + ) + param, param2, param3, param4 = { + "format": ("%s", "%s", "%s", "%s"), + "numeric_dollar": ("$1", "$2", "$3", "$4"), + "pyformat": ( + "%(addr_1)s", + "%(addr2_1)s", + "%(addr3_1)s", + "%(addr4_1)s", + ), + }[dialect.paramstyle] + expected = ( + "SELECT sometable.id FROM sometable " + f"WHERE sometable.addr = {param} " + f"AND sometable.addr2 = {param2} " + f"AND sometable.addr3 = {param3} " + f"AND sometable.addr4 = {param4}" + ) + self.assert_compile(stmt, expected, dialect=dialect) + + @testing.combinations( + (postgresql.INET, "127.0.0.1"), + (postgresql.CIDR, "192.168.100.128/25"), + (postgresql.MACADDR, "08:00:2b:01:02:03"), + (postgresql.MACADDR8, "08:00:2b:01:02:03:04:05"), + argnames="column_type, value", + id_="na", + ) + def test_network_address_round_trip( + self, connection, metadata, column_type, value + ): + t = Table( + "addresses", + metadata, + Column("name", String), + Column("value", column_type), + ) + t.create(connection) + connection.execute(t.insert(), {"name": "test", "value": value}) + eq_( + connection.scalar(select(t.c.name).where(t.c.value == value)), + "test", + ) + def test_tsvector_round_trip(self, connection, metadata): t = Table("t1", metadata, Column("data", postgresql.TSVECTOR)) t.create(connection) -- 2.47.3