From 095d8864fd5d5a89f3e85c0da40e192332f12815 Mon Sep 17 00:00:00 2001 From: Julian David Rath Date: Fri, 3 Mar 2023 18:06:21 +0100 Subject: [PATCH] Implementation of CITEXT , unittest and documentation --- README.unittests.rst | 10 +++++-- doc/build/dialects/postgresql.rst | 2 ++ .../dialects/postgresql/__init__.py | 2 ++ lib/sqlalchemy/dialects/postgresql/base.py | 15 ++++------ lib/sqlalchemy/dialects/postgresql/types.py | 7 +++++ test/dialect/postgresql/test_types.py | 28 +++++++++++++++++++ test/requirements.py | 4 +++ 7 files changed, 56 insertions(+), 12 deletions(-) diff --git a/README.unittests.rst b/README.unittests.rst index 9d63d238fb..4fcd7ed991 100644 --- a/README.unittests.rst +++ b/README.unittests.rst @@ -10,7 +10,6 @@ a single Python interpreter:: tox - Advanced Tox Options ==================== @@ -50,7 +49,7 @@ database options and test selection. A generic pytest run looks like:: - pytest -n4 + pytest - n4 Above, the full test suite will run against SQLite, using four processes. If the "-n" flag is not used, the pytest-xdist is skipped and the tests will @@ -199,6 +198,13 @@ Additional steps specific to individual databases are as follows:: test=# create extension hstore; CREATE EXTENSION + To include tests for CITEXT, create the CITEXT extension:: + + postgres=# \c test; + You are now connected to database "test" as user "postgresql". + test=# create extension citext; + CREATE EXTENSION + Full-text search configuration should be set to English, else several tests of ``.match()`` will fail. This can be set (if it isn't so already) with: diff --git a/doc/build/dialects/postgresql.rst b/doc/build/dialects/postgresql.rst index ce6022c559..fce0e4610e 100644 --- a/doc/build/dialects/postgresql.rst +++ b/doc/build/dialects/postgresql.rst @@ -312,6 +312,7 @@ they originate from :mod:`sqlalchemy.types` or from the local dialect:: BYTEA, CHAR, CIDR, + CITEXT, DATE, DOUBLE_PRECISION, ENUM, @@ -372,6 +373,7 @@ construction arguments, are as follows: .. autoclass:: CIDR +.. autoclass:: CITEXT .. autoclass:: DOMAIN :members: __init__, create, drop diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py index b68bd0502b..c3ed7c1fc0 100644 --- a/lib/sqlalchemy/dialects/postgresql/__init__.py +++ b/lib/sqlalchemy/dialects/postgresql/__init__.py @@ -72,6 +72,7 @@ from .ranges import TSTZRANGE from .types import BIT from .types import BYTEA from .types import CIDR +from .types import CITEXT from .types import INET from .types import INTERVAL from .types import MACADDR @@ -105,6 +106,7 @@ __all__ = ( "REAL", "INET", "CIDR", + "CITEXT", "UUID", "BIT", "MACADDR", diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 3ba1038026..a50eb253ea 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1450,6 +1450,7 @@ from .types import _INT_TYPES # noqa: F401 from .types import BIT as BIT from .types import BYTEA as BYTEA from .types import CIDR as CIDR +from .types import CITEXT as CITEXT from .types import INET as INET from .types import INTERVAL as INTERVAL from .types import MACADDR as MACADDR @@ -1651,6 +1652,7 @@ ischema_names = { "real": REAL, "inet": INET, "cidr": CIDR, + "citext": CITEXT, "uuid": UUID, "bit": BIT, "bit varying": BIT, @@ -1920,7 +1922,6 @@ class PGCompiler(compiler.SQLCompiler): return "" def for_update_clause(self, select, **kw): - if select._for_update_arg.read: if select._for_update_arg.key_share: tmp = " FOR KEY SHARE" @@ -1932,7 +1933,6 @@ class PGCompiler(compiler.SQLCompiler): tmp = " FOR UPDATE" if select._for_update_arg.of: - tables = util.OrderedSet() for c in select._for_update_arg.of: tables.update(sql_util.surface_selectables_only(c)) @@ -1959,7 +1959,6 @@ class PGCompiler(compiler.SQLCompiler): return "SUBSTRING(%s FROM %s)" % (s, start) def _on_conflict_target(self, clause, **kw): - if clause.constraint_target is not None: # target may be a name of an Index, UniqueConstraint or # ExcludeConstraint. While there is a separate @@ -1993,7 +1992,6 @@ class PGCompiler(compiler.SQLCompiler): return target_text def visit_on_conflict_do_nothing(self, on_conflict, **kw): - target_text = self._on_conflict_target(on_conflict, **kw) if target_text: @@ -2002,7 +2000,6 @@ class PGCompiler(compiler.SQLCompiler): return "ON CONFLICT DO NOTHING" def visit_on_conflict_do_update(self, on_conflict, **kw): - clause = on_conflict target_text = self._on_conflict_target(on_conflict, **kw) @@ -2110,7 +2107,6 @@ class PGCompiler(compiler.SQLCompiler): class PGDDLCompiler(compiler.DDLCompiler): def get_column_specification(self, column, **kwargs): - colspec = self.preparer.format_column(column) impl_type = column.type.dialect_impl(self.dialect) if isinstance(impl_type, sqltypes.TypeDecorator): @@ -2472,6 +2468,9 @@ class PGTypeCompiler(compiler.GenericTypeCompiler): def visit_CIDR(self, type_, **kw): return "CIDR" + def visit_CITEXT(self, type_, **kw): + return "CITEXT" + def visit_MACADDR(self, type_, **kw): return "MACADDR" @@ -2621,7 +2620,6 @@ class PGTypeCompiler(compiler.GenericTypeCompiler): return "BYTEA" def visit_ARRAY(self, type_, **kw): - inner = self.process(type_.item_type, **kw) return re.sub( r"((?: COLLATE.*)?)$", @@ -2644,7 +2642,6 @@ class PGTypeCompiler(compiler.GenericTypeCompiler): class PGIdentifierPreparer(compiler.IdentifierPreparer): - reserved_words = RESERVED_WORDS def _unquote_identifier(self, value): @@ -2843,7 +2840,6 @@ class PGExecutionContext(default.DefaultExecutionContext): def get_insert_default(self, column): if column.primary_key and column is column.table._autoincrement_column: if column.server_default and column.server_default.has_argument: - # pre-execute passive defaults on primary key columns return self._execute_scalar( "select %s" % column.server_default.arg, column.type @@ -4222,7 +4218,6 @@ class PGDialect(default.DefaultDialect): def get_multi_indexes( self, connection, schema, filter_names, scope, kind, **kw ): - table_oids = self._get_table_oids( connection, schema, filter_names, scope, kind, **kw ) diff --git a/lib/sqlalchemy/dialects/postgresql/types.py b/lib/sqlalchemy/dialects/postgresql/types.py index a03fcaa392..95b4368a34 100644 --- a/lib/sqlalchemy/dialects/postgresql/types.py +++ b/lib/sqlalchemy/dialects/postgresql/types.py @@ -255,3 +255,10 @@ class TSVECTOR(sqltypes.TypeEngine[str]): """ __visit_name__ = "TSVECTOR" + + +class CITEXT(sqltypes.TypeEngine[str]): + + """The :class:`_postgresql.CITEXT` type implements the PostgreSQL""" + + __visit_name__ = "CITEXT" diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 2b15c7d735..61d2a3107f 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -40,6 +40,7 @@ from sqlalchemy.dialects.postgresql import aggregate_order_by from sqlalchemy.dialects.postgresql import array from sqlalchemy.dialects.postgresql import array_agg from sqlalchemy.dialects.postgresql import base +from sqlalchemy.dialects.postgresql import CITEXT from sqlalchemy.dialects.postgresql import DATEMULTIRANGE from sqlalchemy.dialects.postgresql import DATERANGE from sqlalchemy.dialects.postgresql import DOMAIN @@ -5748,3 +5749,30 @@ class JSONBCastSuiteTest(suite.JSONLegacyStringCastIndexTest): __requires__ = ("postgresql_jsonb",) datatype = JSONB + + +class CITextTest(fixtures.TablesTest): + __requires__ = ("citext",) + __only_on__ = "postgresql" + + @classmethod + def define_tables(cls, metadata): + Table( + "ci_test_table", + metadata, + Column("id", Integer, primary_key=True), + Column("caseignore_text", CITEXT), + ) + + def test_citext(self, connection): + ci_test_table = self.tables.ci_test_table + connection.execute( + ci_test_table.insert(), + {"caseignore_text": "Hello World"}, + ) + + ret = connection.execute( + select(ci_test_table.c.caseignore_text == "hello world") + ).scalar() + + assert ret is not None diff --git a/test/requirements.py b/test/requirements.py index 923d98b462..2b0944fb90 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -1416,6 +1416,10 @@ class DefaultRequirements(SuiteRequirements): def hstore(self): return self._has_pg_extension("hstore") + @property + def citext(self): + return self._has_pg_extension("citext") + @property def btree_gist(self): return self._has_pg_extension("btree_gist") -- 2.47.3