]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Ensure PostgreSQL network address types are not cast as VARCHAR
authorDenis Laxalde <denis@laxalde.org>
Mon, 10 Mar 2025 13:56:10 +0000 (14:56 +0100)
committerDenis Laxalde <denis@laxalde.org>
Mon, 10 Mar 2025 15:20:18 +0000 (16:20 +0100)
This is a similar change as for CITEXT in commit
d7ee73ff81ed69df43756240670bd98f3b1c3302.

Fix https://github.com/sqlalchemy/sqlalchemy/issues/12060

lib/sqlalchemy/dialects/postgresql/types.py
test/dialect/postgresql/test_types.py

index 6fe4f576ebd84e7f4aa768e4b07c32dfc2e9c2d1..bcf8f311aef22d081b96331673fdb63425a8995c 100644 (file)
@@ -55,6 +55,11 @@ class BYTEA(sqltypes.LargeBinary):
 class INET(sqltypes.TypeEngine[str]):
     __visit_name__ = "INET"
 
+    def coerce_compared_value(
+        self, op: Optional[OperatorType], value: Any
+    ) -> TypeEngine[Any]:
+        return self
+
 
 PGInet = INET
 
@@ -62,6 +67,11 @@ PGInet = INET
 class CIDR(sqltypes.TypeEngine[str]):
     __visit_name__ = "CIDR"
 
+    def coerce_compared_value(
+        self, op: Optional[OperatorType], value: Any
+    ) -> TypeEngine[Any]:
+        return self
+
 
 PGCidr = CIDR
 
@@ -69,6 +79,11 @@ PGCidr = CIDR
 class MACADDR(sqltypes.TypeEngine[str]):
     __visit_name__ = "MACADDR"
 
+    def coerce_compared_value(
+        self, op: Optional[OperatorType], value: Any
+    ) -> TypeEngine[Any]:
+        return self
+
 
 PGMacAddr = MACADDR
 
@@ -76,6 +91,11 @@ PGMacAddr = MACADDR
 class MACADDR8(sqltypes.TypeEngine[str]):
     __visit_name__ = "MACADDR8"
 
+    def coerce_compared_value(
+        self, op: Optional[OperatorType], value: Any
+    ) -> TypeEngine[Any]:
+        return self
+
 
 PGMacAddr8 = MACADDR8
 
index 5f39aa608c8ac5453a19731051fb2b7048d2664a..8437ee1720985217a4c3d55aa43138a6ab802149 100644 (file)
@@ -3448,7 +3448,9 @@ class SpecialTypesCompileTest(fixtures.TestBase, AssertsCompiledSQL):
         self.assert_compile(type_, expected)
 
 
-class SpecialTypesTest(fixtures.TablesTest, ComparesTables):
+class SpecialTypesTest(
+    AssertsCompiledSQL, fixtures.TablesTest, ComparesTables
+):
     """test DDL and reflection of PG-specific types"""
 
     __only_on__ = ("postgresql >= 8.3.0",)
@@ -3501,6 +3503,64 @@ class SpecialTypesTest(fixtures.TablesTest, ComparesTables):
         assert t.c.precision_interval.type.precision == 3
         assert t.c.bitstring.type.length == 4
 
+    @testing.combinations(
+        (psycopg.dialect(),),
+        (psycopg2.dialect(),),
+        (asyncpg.dialect(),),
+        (pg8000.dialect(),),
+        argnames="dialect",
+        id_="n",
+    )
+    def test_network_address_cast(self, special_types_table, dialect):
+        stmt = select(special_types_table.c.id).where(
+            special_types_table.c.addr == "127.0.0.1",
+            special_types_table.c.addr2 == "08:00:2b:01:02:03",
+            special_types_table.c.addr3 == "192.168.100.128/25",
+            special_types_table.c.addr4 == "08:00:2b:01:02:03:04:05",
+        )
+        param, param2, param3, param4 = {
+            "format": ("%s", "%s", "%s", "%s"),
+            "numeric_dollar": ("$1", "$2", "$3", "$4"),
+            "pyformat": (
+                "%(addr_1)s",
+                "%(addr2_1)s",
+                "%(addr3_1)s",
+                "%(addr4_1)s",
+            ),
+        }[dialect.paramstyle]
+        expected = (
+            "SELECT sometable.id FROM sometable "
+            f"WHERE sometable.addr = {param} "
+            f"AND sometable.addr2 = {param2} "
+            f"AND sometable.addr3 = {param3} "
+            f"AND sometable.addr4 = {param4}"
+        )
+        self.assert_compile(stmt, expected, dialect=dialect)
+
+    @testing.combinations(
+        (postgresql.INET, "127.0.0.1"),
+        (postgresql.CIDR, "192.168.100.128/25"),
+        (postgresql.MACADDR, "08:00:2b:01:02:03"),
+        (postgresql.MACADDR8, "08:00:2b:01:02:03:04:05"),
+        argnames="column_type, value",
+        id_="na",
+    )
+    def test_network_address_round_trip(
+        self, connection, metadata, column_type, value
+    ):
+        t = Table(
+            "addresses",
+            metadata,
+            Column("name", String),
+            Column("value", column_type),
+        )
+        t.create(connection)
+        connection.execute(t.insert(), {"name": "test", "value": value})
+        eq_(
+            connection.scalar(select(t.c.name).where(t.c.value == value)),
+            "test",
+        )
+
     def test_tsvector_round_trip(self, connection, metadata):
         t = Table("t1", metadata, Column("data", postgresql.TSVECTOR))
         t.create(connection)