]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Implementation of CITEXT , unittest and documentation
authorJulian David Rath <julian.rath@semadox.com>
Fri, 3 Mar 2023 17:06:21 +0000 (18:06 +0100)
committerJulian David Rath <julian.rath@semadox.com>
Fri, 3 Mar 2023 17:06:21 +0000 (18:06 +0100)
README.unittests.rst
doc/build/dialects/postgresql.rst
lib/sqlalchemy/dialects/postgresql/__init__.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/postgresql/types.py
test/dialect/postgresql/test_types.py
test/requirements.py

index 9d63d238fb4372ed9b89caa07df68ba8bed6f441..4fcd7ed991efdd86f069cf54e9d4eabd0f23a74a 100644 (file)
@@ -10,7 +10,6 @@ a single Python interpreter::
 
     tox
 
-
 Advanced Tox Options
 ====================
 
@@ -50,7 +49,7 @@ database options and test selection.
 
 A generic pytest run looks like::
 
-    pytest -n4
+    pytest - n4
 
 Above, the full test suite will run against SQLite, using four processes.
 If the "-n" flag is not used, the pytest-xdist is skipped and the tests will
@@ -199,6 +198,13 @@ Additional steps specific to individual databases are as follows::
         test=# create extension hstore;
         CREATE EXTENSION
 
+    To include tests for CITEXT, create the CITEXT extension::
+
+        postgres=# \c test;
+        You are now connected to database "test" as user "postgresql".
+        test=# create extension citext;
+        CREATE EXTENSION
+
     Full-text search configuration should be set to English, else
     several tests of ``.match()`` will fail. This can be set (if it isn't so
     already) with:
index ce6022c55921564fbd7b76374dba3d9898a23816..fce0e4610e80b225c06c742b79cba250407e17e7 100644 (file)
@@ -312,6 +312,7 @@ they originate from :mod:`sqlalchemy.types` or from the local dialect::
         BYTEA,
         CHAR,
         CIDR,
+        CITEXT,
         DATE,
         DOUBLE_PRECISION,
         ENUM,
@@ -372,6 +373,7 @@ construction arguments, are as follows:
 
 .. autoclass:: CIDR
 
+.. autoclass:: CITEXT
 
 .. autoclass:: DOMAIN
     :members: __init__, create, drop
index b68bd0502bd15820187b349a7c237ba2b5f9c1ff..c3ed7c1fc00e668365691548c9acfac8c0ff71e8 100644 (file)
@@ -72,6 +72,7 @@ from .ranges import TSTZRANGE
 from .types import BIT
 from .types import BYTEA
 from .types import CIDR
+from .types import CITEXT
 from .types import INET
 from .types import INTERVAL
 from .types import MACADDR
@@ -105,6 +106,7 @@ __all__ = (
     "REAL",
     "INET",
     "CIDR",
+    "CITEXT",
     "UUID",
     "BIT",
     "MACADDR",
index 3ba10380264a20cda62e6967387945c5f1734573..a50eb253ead7c848118c800e04511aae67b4e3d0 100644 (file)
@@ -1450,6 +1450,7 @@ from .types import _INT_TYPES  # noqa: F401
 from .types import BIT as BIT
 from .types import BYTEA as BYTEA
 from .types import CIDR as CIDR
+from .types import CITEXT as CITEXT
 from .types import INET as INET
 from .types import INTERVAL as INTERVAL
 from .types import MACADDR as MACADDR
@@ -1651,6 +1652,7 @@ ischema_names = {
     "real": REAL,
     "inet": INET,
     "cidr": CIDR,
+    "citext": CITEXT,
     "uuid": UUID,
     "bit": BIT,
     "bit varying": BIT,
@@ -1920,7 +1922,6 @@ class PGCompiler(compiler.SQLCompiler):
             return ""
 
     def for_update_clause(self, select, **kw):
-
         if select._for_update_arg.read:
             if select._for_update_arg.key_share:
                 tmp = " FOR KEY SHARE"
@@ -1932,7 +1933,6 @@ class PGCompiler(compiler.SQLCompiler):
             tmp = " FOR UPDATE"
 
         if select._for_update_arg.of:
-
             tables = util.OrderedSet()
             for c in select._for_update_arg.of:
                 tables.update(sql_util.surface_selectables_only(c))
@@ -1959,7 +1959,6 @@ class PGCompiler(compiler.SQLCompiler):
             return "SUBSTRING(%s FROM %s)" % (s, start)
 
     def _on_conflict_target(self, clause, **kw):
-
         if clause.constraint_target is not None:
             # target may be a name of an Index, UniqueConstraint or
             # ExcludeConstraint.  While there is a separate
@@ -1993,7 +1992,6 @@ class PGCompiler(compiler.SQLCompiler):
         return target_text
 
     def visit_on_conflict_do_nothing(self, on_conflict, **kw):
-
         target_text = self._on_conflict_target(on_conflict, **kw)
 
         if target_text:
@@ -2002,7 +2000,6 @@ class PGCompiler(compiler.SQLCompiler):
             return "ON CONFLICT DO NOTHING"
 
     def visit_on_conflict_do_update(self, on_conflict, **kw):
-
         clause = on_conflict
 
         target_text = self._on_conflict_target(on_conflict, **kw)
@@ -2110,7 +2107,6 @@ class PGCompiler(compiler.SQLCompiler):
 
 class PGDDLCompiler(compiler.DDLCompiler):
     def get_column_specification(self, column, **kwargs):
-
         colspec = self.preparer.format_column(column)
         impl_type = column.type.dialect_impl(self.dialect)
         if isinstance(impl_type, sqltypes.TypeDecorator):
@@ -2472,6 +2468,9 @@ class PGTypeCompiler(compiler.GenericTypeCompiler):
     def visit_CIDR(self, type_, **kw):
         return "CIDR"
 
+    def visit_CITEXT(self, type_, **kw):
+        return "CITEXT"
+
     def visit_MACADDR(self, type_, **kw):
         return "MACADDR"
 
@@ -2621,7 +2620,6 @@ class PGTypeCompiler(compiler.GenericTypeCompiler):
         return "BYTEA"
 
     def visit_ARRAY(self, type_, **kw):
-
         inner = self.process(type_.item_type, **kw)
         return re.sub(
             r"((?: COLLATE.*)?)$",
@@ -2644,7 +2642,6 @@ class PGTypeCompiler(compiler.GenericTypeCompiler):
 
 
 class PGIdentifierPreparer(compiler.IdentifierPreparer):
-
     reserved_words = RESERVED_WORDS
 
     def _unquote_identifier(self, value):
@@ -2843,7 +2840,6 @@ class PGExecutionContext(default.DefaultExecutionContext):
     def get_insert_default(self, column):
         if column.primary_key and column is column.table._autoincrement_column:
             if column.server_default and column.server_default.has_argument:
-
                 # pre-execute passive defaults on primary key columns
                 return self._execute_scalar(
                     "select %s" % column.server_default.arg, column.type
@@ -4222,7 +4218,6 @@ class PGDialect(default.DefaultDialect):
     def get_multi_indexes(
         self, connection, schema, filter_names, scope, kind, **kw
     ):
-
         table_oids = self._get_table_oids(
             connection, schema, filter_names, scope, kind, **kw
         )
index a03fcaa392ef82acb2b8aa9f6eae5ee467ff5e3e..95b4368a34ee1367c3b7836025ffd6dbc888c9dc 100644 (file)
@@ -255,3 +255,10 @@ class TSVECTOR(sqltypes.TypeEngine[str]):
     """
 
     __visit_name__ = "TSVECTOR"
+
+
+class CITEXT(sqltypes.TypeEngine[str]):
+
+    """The :class:`_postgresql.CITEXT` type implements the PostgreSQL"""
+
+    __visit_name__ = "CITEXT"
index 2b15c7d735af2c9bffc3a5bc7c3cad701a800d64..61d2a3107fb6c88e64d50462e1c24f2a0aacdac6 100644 (file)
@@ -40,6 +40,7 @@ from sqlalchemy.dialects.postgresql import aggregate_order_by
 from sqlalchemy.dialects.postgresql import array
 from sqlalchemy.dialects.postgresql import array_agg
 from sqlalchemy.dialects.postgresql import base
+from sqlalchemy.dialects.postgresql import CITEXT
 from sqlalchemy.dialects.postgresql import DATEMULTIRANGE
 from sqlalchemy.dialects.postgresql import DATERANGE
 from sqlalchemy.dialects.postgresql import DOMAIN
@@ -5748,3 +5749,30 @@ class JSONBCastSuiteTest(suite.JSONLegacyStringCastIndexTest):
     __requires__ = ("postgresql_jsonb",)
 
     datatype = JSONB
+
+
+class CITextTest(fixtures.TablesTest):
+    __requires__ = ("citext",)
+    __only_on__ = "postgresql"
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table(
+            "ci_test_table",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("caseignore_text", CITEXT),
+        )
+
+    def test_citext(self, connection):
+        ci_test_table = self.tables.ci_test_table
+        connection.execute(
+            ci_test_table.insert(),
+            {"caseignore_text": "Hello World"},
+        )
+
+        ret = connection.execute(
+            select(ci_test_table.c.caseignore_text == "hello world")
+        ).scalar()
+
+        assert ret is not None
index 923d98b4626a2e70a97515192c077add9091a226..2b0944fb907920c49bf1f5ea0e94d5b135c04304 100644 (file)
@@ -1416,6 +1416,10 @@ class DefaultRequirements(SuiteRequirements):
     def hstore(self):
         return self._has_pg_extension("hstore")
 
+    @property
+    def citext(self):
+        return self._has_pg_extension("citext")
+
     @property
     def btree_gist(self):
         return self._has_pg_extension("btree_gist")