from sqlalchemy.dialects.postgresql import aggregate_order_by
from sqlalchemy.dialects.postgresql import array
from sqlalchemy.dialects.postgresql import array_agg
+from sqlalchemy.dialects.postgresql import asyncpg
from sqlalchemy.dialects.postgresql import base
from sqlalchemy.dialects.postgresql import CITEXT
from sqlalchemy.dialects.postgresql import DATEMULTIRANGE
from sqlalchemy.dialects.postgresql import NUMMULTIRANGE
from sqlalchemy.dialects.postgresql import NUMRANGE
from sqlalchemy.dialects.postgresql import pg8000
+from sqlalchemy.dialects.postgresql import psycopg
from sqlalchemy.dialects.postgresql import psycopg2
from sqlalchemy.dialects.postgresql import psycopg2cffi
from sqlalchemy.dialects.postgresql import Range
datatype = JSONB
-class CITextTest(fixtures.TablesTest):
+class CITextTest(testing.AssertsCompiledSQL, fixtures.TablesTest):
__requires__ = ("citext",)
__only_on__ = "postgresql"
+ __backend__ = True
@classmethod
def define_tables(cls, metadata):
Column("caseignore_text", CITEXT),
)
- def test_citext(self, connection):
+ @testing.variation(
+ "inserts",
+ ["multiple", "single", "insertmanyvalues", "imv_deterministic"],
+ )
+ def test_citext_round_trip(self, connection, inserts):
ci_test_table = self.tables.ci_test_table
- connection.execute(
- ci_test_table.insert(),
+
+ data = [
{"caseignore_text": "Hello World"},
- )
+ {"caseignore_text": "greetings all"},
+ ]
+
+ if inserts.single:
+ for d in data:
+ connection.execute(
+ ci_test_table.insert(),
+ d,
+ )
+ elif inserts.multiple:
+ connection.execute(ci_test_table.insert(), data)
+ elif inserts.insertmanyvalues:
+ result = connection.execute(
+ ci_test_table.insert().returning(ci_test_table.c.id), data
+ )
+ result.all()
+ elif inserts.imv_deterministic:
+ result = connection.execute(
+ ci_test_table.insert().returning(
+ ci_test_table.c.id, sort_by_parameter_order=True
+ ),
+ data,
+ )
+ result.all()
+ else:
+ inserts.fail()
+
+ ret = connection.execute(
+ select(func.count(ci_test_table.c.id)).where(
+ ci_test_table.c.caseignore_text == "hello world"
+ )
+ ).scalar()
+
+ eq_(ret, 1)
ret = connection.execute(
- select(ci_test_table.c.caseignore_text == "hello world")
+ select(func.count(ci_test_table.c.id)).where(
+ ci_test_table.c.caseignore_text == "Greetings All"
+ )
).scalar()
- assert ret is not None
+ eq_(ret, 1)
+
+
+class CITextCastTest(testing.AssertsCompiledSQL, fixtures.TestBase):
+ @testing.combinations(
+ (psycopg.dialect(),),
+ (psycopg2.dialect(),),
+ (asyncpg.dialect(),),
+ (pg8000.dialect(),),
+ )
+ def test_cast(self, dialect):
+ ci_test_table = Table(
+ "ci_test_table",
+ MetaData(),
+ Column("id", Integer, primary_key=True),
+ Column("caseignore_text", CITEXT),
+ )
+
+ stmt = select(ci_test_table).where(
+ ci_test_table.c.caseignore_text == "xyz"
+ )
+
+ param = {
+ "format": "%s",
+ "numeric_dollar": "$1",
+ "pyformat": "%(caseignore_text_1)s",
+ }[dialect.paramstyle]
+ expected = (
+ "SELECT ci_test_table.id, ci_test_table.caseignore_text "
+ "FROM ci_test_table WHERE "
+ # currently CITEXT has render_bind_cast turned off.
+ # if there's a need to turn it on, change as follows:
+ # f"ci_test_table.caseignore_text = {param}::CITEXT"
+ f"ci_test_table.caseignore_text = {param}"
+ )
+ self.assert_compile(stmt, expected, dialect=dialect)
class InetRoundTripTests(fixtures.TestBase):