From: Mike Bayer Date: Wed, 12 Jul 2023 13:32:10 +0000 (-0400) Subject: ensure CITEXT is not cast as VARCHAR X-Git-Tag: rel_2_0_19~7^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=d7ee73ff81ed69df43756240670bd98f3b1c3302;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git ensure CITEXT is not cast as VARCHAR Fixed issue where comparisons to the :class:`_postgresql.CITEXT` datatype would cast the right side to ``VARCHAR``, leading to the right side not being interpreted as a ``CITEXT`` datatype, for the asyncpg, psycopg3 and pg8000 dialects. This led to the :class:`_postgresql.CITEXT` type being essentially unusable for practical use; this is now fixed and the test suite has been corrected to properly assert that expressions are rendered correctly. Fixes: #10096 Change-Id: I49129e50261cf09942c0c339d581ce17a26d8181 --- diff --git a/doc/build/changelog/unreleased_20/10096.rst b/doc/build/changelog/unreleased_20/10096.rst new file mode 100644 index 0000000000..2d8a8a5413 --- /dev/null +++ b/doc/build/changelog/unreleased_20/10096.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: bug, postgresql + :tickets: 10096 + + Fixed issue where comparisons to the :class:`_postgresql.CITEXT` datatype + would cast the right side to ``VARCHAR``, leading to the right side not + being interpreted as a ``CITEXT`` datatype, for the asyncpg, psycopg3 and + pg8000 dialects. This led to the :class:`_postgresql.CITEXT` type being + essentially unusable for practical use; this is now fixed and the test + suite has been corrected to properly assert that expressions are rendered + correctly. 
diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index d4350cc289..1f6c0dcbe4 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -206,6 +206,7 @@ from .base import PGIdentifierPreparer from .base import REGCLASS from .base import REGCONFIG from .types import BYTEA +from .types import CITEXT from ... import exc from ... import pool from ... import util @@ -1001,6 +1002,7 @@ class PGDialect_asyncpg(PGDialect): { sqltypes.String: AsyncpgString, sqltypes.ARRAY: AsyncpgARRAY, + CITEXT: CITEXT, REGCONFIG: AsyncpgREGCONFIG, sqltypes.Time: AsyncpgTime, sqltypes.Date: AsyncpgDate, diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index d221574888..11f7d171da 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1701,7 +1701,7 @@ class PGCompiler(compiler.SQLCompiler): return f"{element.name}{self.function_argspec(element, **kw)}" def render_bind_cast(self, type_, dbapi_type, sqltext): - if dbapi_type._type_affinity is sqltypes.String: + if dbapi_type._type_affinity is sqltypes.String and dbapi_type.length: # use VARCHAR with no length for VARCHAR cast. # see #9511 dbapi_type = sqltypes.STRINGTYPE diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py index e00628fbf2..71ee4ebd63 100644 --- a/lib/sqlalchemy/dialects/postgresql/pg8000.py +++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py @@ -110,6 +110,7 @@ from .json import JSONB from .json import JSONPathType from .pg_catalog import _SpaceVector from .pg_catalog import OIDVECTOR +from .types import CITEXT from ... import exc from ... 
import util from ...engine import processors @@ -432,6 +433,7 @@ class PGDialect_pg8000(PGDialect): sqltypes.Boolean: _PGBoolean, sqltypes.NullType: _PGNullType, JSONB: _PGJSONB, + CITEXT: CITEXT, sqltypes.JSON.JSONPathType: _PGJSONPathType, sqltypes.JSON.JSONIndexType: _PGJSONIndexType, sqltypes.JSON.JSONIntIndexType: _PGJSONIntIndexType, diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg.py b/lib/sqlalchemy/dialects/postgresql/psycopg.py index f4a0fc0aa7..dcd69ce663 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg.py @@ -68,6 +68,7 @@ from .base import REGCONFIG from .json import JSON from .json import JSONB from .json import JSONPathType +from .types import CITEXT from ... import pool from ... import util from ...engine import AdaptedConnection @@ -271,6 +272,7 @@ class PGDialect_psycopg(_PGDialect_common_psycopg): sqltypes.String: _PGString, REGCONFIG: _PGREGCONFIG, JSON: _PGJSON, + CITEXT: CITEXT, sqltypes.JSON: _PGJSON, JSONB: _PGJSONB, sqltypes.JSON.JSONPathType: _PGJSONPathType, diff --git a/lib/sqlalchemy/dialects/postgresql/types.py b/lib/sqlalchemy/dialects/postgresql/types.py index 0db2721c87..2f49ff12a4 100644 --- a/lib/sqlalchemy/dialects/postgresql/types.py +++ b/lib/sqlalchemy/dialects/postgresql/types.py @@ -17,6 +17,10 @@ from ...sql import sqltypes from ...sql import type_api from ...util.typing import Literal +if TYPE_CHECKING: + from ...sql.operators import OperatorType + from ...sql.type_api import TypeEngine + _DECIMAL_TYPES = (1231, 1700) _FLOAT_TYPES = (700, 701, 1021, 1022) _INT_TYPES = (20, 21, 23, 26, 1005, 1007, 1016) @@ -291,3 +295,8 @@ class CITEXT(sqltypes.TEXT): """ __visit_name__ = "CITEXT" + + def coerce_compared_value( + self, op: Optional[OperatorType], value: Any + ) -> TypeEngine[Any]: + return self diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 269d0e8082..422e735d3d 100644 --- 
a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -44,6 +44,7 @@ from sqlalchemy.dialects import postgresql from sqlalchemy.dialects.postgresql import aggregate_order_by from sqlalchemy.dialects.postgresql import array from sqlalchemy.dialects.postgresql import array_agg +from sqlalchemy.dialects.postgresql import asyncpg from sqlalchemy.dialects.postgresql import base from sqlalchemy.dialects.postgresql import CITEXT from sqlalchemy.dialects.postgresql import DATEMULTIRANGE @@ -62,6 +63,7 @@ from sqlalchemy.dialects.postgresql import NamedType from sqlalchemy.dialects.postgresql import NUMMULTIRANGE from sqlalchemy.dialects.postgresql import NUMRANGE from sqlalchemy.dialects.postgresql import pg8000 +from sqlalchemy.dialects.postgresql import psycopg from sqlalchemy.dialects.postgresql import psycopg2 from sqlalchemy.dialects.postgresql import psycopg2cffi from sqlalchemy.dialects.postgresql import Range @@ -5943,9 +5945,10 @@ class JSONBCastSuiteTest(suite.JSONLegacyStringCastIndexTest): datatype = JSONB -class CITextTest(fixtures.TablesTest): +class CITextTest(testing.AssertsCompiledSQL, fixtures.TablesTest): __requires__ = ("citext",) __only_on__ = "postgresql" + __backend__ = True @classmethod def define_tables(cls, metadata): @@ -5956,18 +5959,92 @@ class CITextTest(fixtures.TablesTest): Column("caseignore_text", CITEXT), ) - def test_citext(self, connection): + @testing.variation( + "inserts", + ["multiple", "single", "insertmanyvalues", "imv_deterministic"], + ) + def test_citext_round_trip(self, connection, inserts): ci_test_table = self.tables.ci_test_table - connection.execute( - ci_test_table.insert(), + + data = [ {"caseignore_text": "Hello World"}, - ) + {"caseignore_text": "greetings all"}, + ] + + if inserts.single: + for d in data: + connection.execute( + ci_test_table.insert(), + d, + ) + elif inserts.multiple: + connection.execute(ci_test_table.insert(), data) + elif inserts.insertmanyvalues: + result = 
connection.execute( + ci_test_table.insert().returning(ci_test_table.c.id), data + ) + result.all() + elif inserts.imv_deterministic: + result = connection.execute( + ci_test_table.insert().returning( + ci_test_table.c.id, sort_by_parameter_order=True + ), + data, + ) + result.all() + else: + inserts.fail() + + ret = connection.execute( + select(func.count(ci_test_table.c.id)).where( + ci_test_table.c.caseignore_text == "hello world" + ) + ).scalar() + + eq_(ret, 1) ret = connection.execute( - select(ci_test_table.c.caseignore_text == "hello world") + select(func.count(ci_test_table.c.id)).where( + ci_test_table.c.caseignore_text == "Greetings All" + ) ).scalar() - assert ret is not None + eq_(ret, 1) + + +class CITextCastTest(testing.AssertsCompiledSQL, fixtures.TestBase): + @testing.combinations( + (psycopg.dialect(),), + (psycopg2.dialect(),), + (asyncpg.dialect(),), + (pg8000.dialect(),), + ) + def test_cast(self, dialect): + ci_test_table = Table( + "ci_test_table", + MetaData(), + Column("id", Integer, primary_key=True), + Column("caseignore_text", CITEXT), + ) + + stmt = select(ci_test_table).where( + ci_test_table.c.caseignore_text == "xyz" + ) + + param = { + "format": "%s", + "numeric_dollar": "$1", + "pyformat": "%(caseignore_text_1)s", + }[dialect.paramstyle] + expected = ( + "SELECT ci_test_table.id, ci_test_table.caseignore_text " + "FROM ci_test_table WHERE " + # currently CITEXT has render_bind_cast turned off. + # if there's a need to turn it on, change as follows: + # f"ci_test_table.caseignore_text = {param}::CITEXT" + f"ci_test_table.caseignore_text = {param}" + ) + self.assert_compile(stmt, expected, dialect=dialect) class InetRoundTripTests(fixtures.TestBase):