]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
ensure CITEXT is not cast as VARCHAR
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 12 Jul 2023 13:32:10 +0000 (09:32 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 12 Jul 2023 15:14:09 +0000 (11:14 -0400)
Fixed issue where comparisons to the :class:`_postgresql.CITEXT` datatype
would cast the right side to ``VARCHAR``, leading to the right side not
being interpreted as a ``CITEXT`` datatype, for the asyncpg, psycopg3 and
pg80000 dialects.   This led to the :class:`_postgresql.CITEXT` type being
essentially unusable for practical use; this is now fixed and the test
suite has been corrected to properly assert that expressions are rendered
correctly.

Fixes: #10096
Change-Id: I49129e50261cf09942c0c339d581ce17a26d8181

doc/build/changelog/unreleased_20/10096.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/asyncpg.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/postgresql/pg8000.py
lib/sqlalchemy/dialects/postgresql/psycopg.py
lib/sqlalchemy/dialects/postgresql/types.py
test/dialect/postgresql/test_types.py

diff --git a/doc/build/changelog/unreleased_20/10096.rst b/doc/build/changelog/unreleased_20/10096.rst
new file mode 100644 (file)
index 0000000..2d8a8a5
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: bug, postgresql
+    :tickets: 10096
+
+    Fixed issue where comparisons to the :class:`_postgresql.CITEXT` datatype
+    would cast the right side to ``VARCHAR``, leading to the right side not
+    being interpreted as a ``CITEXT`` datatype, for the asyncpg, psycopg3 and
+    pg80000 dialects.   This led to the :class:`_postgresql.CITEXT` type being
+    essentially unusable for practical use; this is now fixed and the test
+    suite has been corrected to properly assert that expressions are rendered
+    correctly.
index d4350cc2892dc478f19bba368bb9c3974ada7c4f..1f6c0dcbe4091966301f4c6a41f7205fd1e4c043 100644 (file)
@@ -206,6 +206,7 @@ from .base import PGIdentifierPreparer
 from .base import REGCLASS
 from .base import REGCONFIG
 from .types import BYTEA
+from .types import CITEXT
 from ... import exc
 from ... import pool
 from ... import util
@@ -1001,6 +1002,7 @@ class PGDialect_asyncpg(PGDialect):
         {
             sqltypes.String: AsyncpgString,
             sqltypes.ARRAY: AsyncpgARRAY,
+            CITEXT: CITEXT,
             REGCONFIG: AsyncpgREGCONFIG,
             sqltypes.Time: AsyncpgTime,
             sqltypes.Date: AsyncpgDate,
index d221574888ad528a1986f11181fbd08bb851f2ab..11f7d171dac06727f989b6e61f5ede68b43362bc 100644 (file)
@@ -1701,7 +1701,7 @@ class PGCompiler(compiler.SQLCompiler):
         return f"{element.name}{self.function_argspec(element, **kw)}"
 
     def render_bind_cast(self, type_, dbapi_type, sqltext):
-        if dbapi_type._type_affinity is sqltypes.String:
+        if dbapi_type._type_affinity is sqltypes.String and dbapi_type.length:
             # use VARCHAR with no length for VARCHAR cast.
             # see #9511
             dbapi_type = sqltypes.STRINGTYPE
index e00628fbf251b1d54f65123b7c547780f86bf57d..71ee4ebd63e4e9604a9a7d5f306b18dca28cf496 100644 (file)
@@ -110,6 +110,7 @@ from .json import JSONB
 from .json import JSONPathType
 from .pg_catalog import _SpaceVector
 from .pg_catalog import OIDVECTOR
+from .types import CITEXT
 from ... import exc
 from ... import util
 from ...engine import processors
@@ -432,6 +433,7 @@ class PGDialect_pg8000(PGDialect):
             sqltypes.Boolean: _PGBoolean,
             sqltypes.NullType: _PGNullType,
             JSONB: _PGJSONB,
+            CITEXT: CITEXT,
             sqltypes.JSON.JSONPathType: _PGJSONPathType,
             sqltypes.JSON.JSONIndexType: _PGJSONIndexType,
             sqltypes.JSON.JSONIntIndexType: _PGJSONIntIndexType,
index f4a0fc0aa7ec22bae62213684d4fd7ff82a8bee4..dcd69ce6631fc8a1d58ddcd4f699015adf0ae12d 100644 (file)
@@ -68,6 +68,7 @@ from .base import REGCONFIG
 from .json import JSON
 from .json import JSONB
 from .json import JSONPathType
+from .types import CITEXT
 from ... import pool
 from ... import util
 from ...engine import AdaptedConnection
@@ -271,6 +272,7 @@ class PGDialect_psycopg(_PGDialect_common_psycopg):
             sqltypes.String: _PGString,
             REGCONFIG: _PGREGCONFIG,
             JSON: _PGJSON,
+            CITEXT: CITEXT,
             sqltypes.JSON: _PGJSON,
             JSONB: _PGJSONB,
             sqltypes.JSON.JSONPathType: _PGJSONPathType,
index 0db2721c8750d7b3d1ba575426548e213f02e015..2f49ff12a459723e1738014e1e699e8e3d844c7c 100644 (file)
@@ -17,6 +17,10 @@ from ...sql import sqltypes
 from ...sql import type_api
 from ...util.typing import Literal
 
+if TYPE_CHECKING:
+    from ...sql.operators import OperatorType
+    from ...sql.type_api import TypeEngine
+
 _DECIMAL_TYPES = (1231, 1700)
 _FLOAT_TYPES = (700, 701, 1021, 1022)
 _INT_TYPES = (20, 21, 23, 26, 1005, 1007, 1016)
@@ -291,3 +295,8 @@ class CITEXT(sqltypes.TEXT):
     """
 
     __visit_name__ = "CITEXT"
+
+    def coerce_compared_value(
+        self, op: Optional[OperatorType], value: Any
+    ) -> TypeEngine[Any]:
+        return self
index 269d0e8082b9b14494125c8c07fe003fceabdadc..422e735d3d4150c571c3c9b37ce8e7b14556a65e 100644 (file)
@@ -44,6 +44,7 @@ from sqlalchemy.dialects import postgresql
 from sqlalchemy.dialects.postgresql import aggregate_order_by
 from sqlalchemy.dialects.postgresql import array
 from sqlalchemy.dialects.postgresql import array_agg
+from sqlalchemy.dialects.postgresql import asyncpg
 from sqlalchemy.dialects.postgresql import base
 from sqlalchemy.dialects.postgresql import CITEXT
 from sqlalchemy.dialects.postgresql import DATEMULTIRANGE
@@ -62,6 +63,7 @@ from sqlalchemy.dialects.postgresql import NamedType
 from sqlalchemy.dialects.postgresql import NUMMULTIRANGE
 from sqlalchemy.dialects.postgresql import NUMRANGE
 from sqlalchemy.dialects.postgresql import pg8000
+from sqlalchemy.dialects.postgresql import psycopg
 from sqlalchemy.dialects.postgresql import psycopg2
 from sqlalchemy.dialects.postgresql import psycopg2cffi
 from sqlalchemy.dialects.postgresql import Range
@@ -5943,9 +5945,10 @@ class JSONBCastSuiteTest(suite.JSONLegacyStringCastIndexTest):
     datatype = JSONB
 
 
-class CITextTest(fixtures.TablesTest):
+class CITextTest(testing.AssertsCompiledSQL, fixtures.TablesTest):
     __requires__ = ("citext",)
     __only_on__ = "postgresql"
+    __backend__ = True
 
     @classmethod
     def define_tables(cls, metadata):
@@ -5956,18 +5959,92 @@ class CITextTest(fixtures.TablesTest):
             Column("caseignore_text", CITEXT),
         )
 
-    def test_citext(self, connection):
+    @testing.variation(
+        "inserts",
+        ["multiple", "single", "insertmanyvalues", "imv_deterministic"],
+    )
+    def test_citext_round_trip(self, connection, inserts):
         ci_test_table = self.tables.ci_test_table
-        connection.execute(
-            ci_test_table.insert(),
+
+        data = [
             {"caseignore_text": "Hello World"},
-        )
+            {"caseignore_text": "greetings all"},
+        ]
+
+        if inserts.single:
+            for d in data:
+                connection.execute(
+                    ci_test_table.insert(),
+                    d,
+                )
+        elif inserts.multiple:
+            connection.execute(ci_test_table.insert(), data)
+        elif inserts.insertmanyvalues:
+            result = connection.execute(
+                ci_test_table.insert().returning(ci_test_table.c.id), data
+            )
+            result.all()
+        elif inserts.imv_deterministic:
+            result = connection.execute(
+                ci_test_table.insert().returning(
+                    ci_test_table.c.id, sort_by_parameter_order=True
+                ),
+                data,
+            )
+            result.all()
+        else:
+            inserts.fail()
+
+        ret = connection.execute(
+            select(func.count(ci_test_table.c.id)).where(
+                ci_test_table.c.caseignore_text == "hello world"
+            )
+        ).scalar()
+
+        eq_(ret, 1)
 
         ret = connection.execute(
-            select(ci_test_table.c.caseignore_text == "hello world")
+            select(func.count(ci_test_table.c.id)).where(
+                ci_test_table.c.caseignore_text == "Greetings All"
+            )
         ).scalar()
 
-        assert ret is not None
+        eq_(ret, 1)
+
+
+class CITextCastTest(testing.AssertsCompiledSQL, fixtures.TestBase):
+    @testing.combinations(
+        (psycopg.dialect(),),
+        (psycopg2.dialect(),),
+        (asyncpg.dialect(),),
+        (pg8000.dialect(),),
+    )
+    def test_cast(self, dialect):
+        ci_test_table = Table(
+            "ci_test_table",
+            MetaData(),
+            Column("id", Integer, primary_key=True),
+            Column("caseignore_text", CITEXT),
+        )
+
+        stmt = select(ci_test_table).where(
+            ci_test_table.c.caseignore_text == "xyz"
+        )
+
+        param = {
+            "format": "%s",
+            "numeric_dollar": "$1",
+            "pyformat": "%(caseignore_text_1)s",
+        }[dialect.paramstyle]
+        expected = (
+            "SELECT ci_test_table.id, ci_test_table.caseignore_text "
+            "FROM ci_test_table WHERE "
+            # currently CITEXT has render_bind_cast turned off.
+            # if there's a need to turn it on, change as follows:
+            # f"ci_test_table.caseignore_text = {param}::CITEXT"
+            f"ci_test_table.caseignore_text = {param}"
+        )
+        self.assert_compile(stmt, expected, dialect=dialect)
 
 
 class InetRoundTripTests(fixtures.TestBase):