]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Ensure PostgreSQL network address types are not cast as VARCHAR
authorDenis Laxalde <denis@laxalde.org>
Tue, 11 Mar 2025 13:27:13 +0000 (09:27 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 11 Mar 2025 14:05:39 +0000 (10:05 -0400)
Fixed issue in PostgreSQL network types :class:`_postgresql.INET`,
:class:`_postgresql.CIDR`, :class:`_postgresql.MACADDR`,
:class:`_postgresql.MACADDR8` where sending string values to compare to
these types would render an explicit CAST to VARCHAR, causing some SQL /
driver combinations to fail.  Pull request courtesy Denis Laxalde.

Fixes: #12060
Closes: #12412
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12412
Pull-request-sha: 029fda7f2d182af71ebc48aef191aa9114927f28

Change-Id: Id4a502ebc119775567cacddbabef2ce9715c1a9f

doc/build/changelog/unreleased_20/12060.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/types.py
test/dialect/postgresql/test_types.py

diff --git a/doc/build/changelog/unreleased_20/12060.rst b/doc/build/changelog/unreleased_20/12060.rst
new file mode 100644 (file)
index 0000000..c215d37
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, postgresql
+    :tickets: 12060
+
+    Fixed issue in PostgreSQL network types :class:`_postgresql.INET`,
+    :class:`_postgresql.CIDR`, :class:`_postgresql.MACADDR`,
+    :class:`_postgresql.MACADDR8` where sending string values to compare to
+    these types would render an explicit CAST to VARCHAR, causing some SQL /
+    driver combinations to fail.  Pull request courtesy Denis Laxalde.
index 6fe4f576ebd84e7f4aa768e4b07c32dfc2e9c2d1..1aed2bf4724077288c0df99abe464ed429c839f3 100644 (file)
@@ -52,28 +52,38 @@ class BYTEA(sqltypes.LargeBinary):
     __visit_name__ = "BYTEA"
 
 
-class INET(sqltypes.TypeEngine[str]):
+class _NetworkAddressTypeMixin:
+
+    def coerce_compared_value(
+        self, op: Optional[OperatorType], value: Any
+    ) -> TypeEngine[Any]:
+        if TYPE_CHECKING:
+            assert isinstance(self, TypeEngine)
+        return self
+
+
+class INET(_NetworkAddressTypeMixin, sqltypes.TypeEngine[str]):
     __visit_name__ = "INET"
 
 
 PGInet = INET
 
 
-class CIDR(sqltypes.TypeEngine[str]):
+class CIDR(_NetworkAddressTypeMixin, sqltypes.TypeEngine[str]):
     __visit_name__ = "CIDR"
 
 
 PGCidr = CIDR
 
 
-class MACADDR(sqltypes.TypeEngine[str]):
+class MACADDR(_NetworkAddressTypeMixin, sqltypes.TypeEngine[str]):
     __visit_name__ = "MACADDR"
 
 
 PGMacAddr = MACADDR
 
 
-class MACADDR8(sqltypes.TypeEngine[str]):
+class MACADDR8(_NetworkAddressTypeMixin, sqltypes.TypeEngine[str]):
     __visit_name__ = "MACADDR8"
 
 
index 5f39aa608c8ac5453a19731051fb2b7048d2664a..795a897699b331b4c994c594cffdebce610d6cc5 100644 (file)
@@ -3447,6 +3447,49 @@ class SpecialTypesCompileTest(fixtures.TestBase, AssertsCompiledSQL):
     def test_bit_compile(self, type_, expected):
         self.assert_compile(type_, expected)
 
+    @testing.combinations(
+        (psycopg.dialect(),),
+        (psycopg2.dialect(),),
+        (asyncpg.dialect(),),
+        (pg8000.dialect(),),
+        argnames="dialect",
+        id_="n",
+    )
+    def test_network_address_cast(self, metadata, dialect):
+        t = Table(
+            "addresses",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("addr", postgresql.INET),
+            Column("addr2", postgresql.MACADDR),
+            Column("addr3", postgresql.CIDR),
+            Column("addr4", postgresql.MACADDR8),
+        )
+        stmt = select(t.c.id).where(
+            t.c.addr == "127.0.0.1",
+            t.c.addr2 == "08:00:2b:01:02:03",
+            t.c.addr3 == "192.168.100.128/25",
+            t.c.addr4 == "08:00:2b:01:02:03:04:05",
+        )
+        param, param2, param3, param4 = {
+            "format": ("%s", "%s", "%s", "%s"),
+            "numeric_dollar": ("$1", "$2", "$3", "$4"),
+            "pyformat": (
+                "%(addr_1)s",
+                "%(addr2_1)s",
+                "%(addr3_1)s",
+                "%(addr4_1)s",
+            ),
+        }[dialect.paramstyle]
+        expected = (
+            "SELECT addresses.id FROM addresses "
+            f"WHERE addresses.addr = {param} "
+            f"AND addresses.addr2 = {param2} "
+            f"AND addresses.addr3 = {param3} "
+            f"AND addresses.addr4 = {param4}"
+        )
+        self.assert_compile(stmt, expected, dialect=dialect)
+
 
 class SpecialTypesTest(fixtures.TablesTest, ComparesTables):
     """test DDL and reflection of PG-specific types"""
@@ -3501,6 +3544,30 @@ class SpecialTypesTest(fixtures.TablesTest, ComparesTables):
         assert t.c.precision_interval.type.precision == 3
         assert t.c.bitstring.type.length == 4
 
+    @testing.combinations(
+        (postgresql.INET, "127.0.0.1"),
+        (postgresql.CIDR, "192.168.100.128/25"),
+        (postgresql.MACADDR, "08:00:2b:01:02:03"),
+        (postgresql.MACADDR8, "08:00:2b:01:02:03:04:05"),
+        argnames="column_type, value",
+        id_="na",
+    )
+    def test_network_address_round_trip(
+        self, connection, metadata, column_type, value
+    ):
+        t = Table(
+            "addresses",
+            metadata,
+            Column("name", String),
+            Column("value", column_type),
+        )
+        t.create(connection)
+        connection.execute(t.insert(), {"name": "test", "value": value})
+        eq_(
+            connection.scalar(select(t.c.name).where(t.c.value == value)),
+            "test",
+        )
+
     def test_tsvector_round_trip(self, connection, metadata):
         t = Table("t1", metadata, Column("data", postgresql.TSVECTOR))
         t.create(connection)