From: Mike Bayer Date: Wed, 14 Dec 2022 01:07:14 +0000 (-0500) Subject: add explicit REGCONFIG, pg full text functions X-Git-Tag: rel_2_0_0rc1~27^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=7b84c850606c7b093b4260c08ff4636ff1bdbfef;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git add explicit REGCONFIG, pg full text functions Added support for explicit use of PG full text functions with asyncpg and psycopg (SQLAlchemy 2.0 only), with regards to the ``REGCONFIG`` type cast for the first argument, which previously would be incorrectly cast to a VARCHAR, causing failures on these dialects that rely upon explicit type casts. This includes support for :class:`_postgresql.to_tsvector`, :class:`_postgresql.to_tsquery`, :class:`_postgresql.plainto_tsquery`, :class:`_postgresql.phraseto_tsquery`, :class:`_postgresql.websearch_to_tsquery`, :class:`_postgresql.ts_headline`, each of which will determine based on number of arguments passed if the first string argument should be interpreted as a PostgreSQL "REGCONFIG" value; if so, the argument is typed using a newly added type object :class:`_postgresql.REGCONFIG` which is then explicitly cast in the SQL expression. Fixes: #8977 Change-Id: Ib36698a984fd4194bd6e0eb663105f790f3db7d3 --- diff --git a/doc/build/changelog/unreleased_20/8977.rst b/doc/build/changelog/unreleased_20/8977.rst new file mode 100644 index 0000000000..904e08bf39 --- /dev/null +++ b/doc/build/changelog/unreleased_20/8977.rst @@ -0,0 +1,19 @@ +.. change:: + :tags: postgresql, bug + :tickets: 8977 + :versions: 2.0.0b5 + + Added support for explicit use of PG full text functions with asyncpg and + psycopg (SQLAlchemy 2.0 only), with regards to the ``REGCONFIG`` type cast + for the first argument, which previously would be incorrectly cast to a + VARCHAR, causing failures on these dialects that rely upon explicit type + casts. 
This includes support for :class:`_postgresql.to_tsvector`, + :class:`_postgresql.to_tsquery`, :class:`_postgresql.plainto_tsquery`, + :class:`_postgresql.phraseto_tsquery`, + :class:`_postgresql.websearch_to_tsquery`, + :class:`_postgresql.ts_headline`, each of which will determine based on + number of arguments passed if the first string argument should be + interpreted as a PostgreSQL "REGCONFIG" value; if so, the argument is typed + using a newly added type object :class:`_postgresql.REGCONFIG` which is + then explicitly cast in the SQL expression. + diff --git a/doc/build/dialects/postgresql.rst b/doc/build/dialects/postgresql.rst index 2f541b5abd..8f4dc1e72d 100644 --- a/doc/build/dialects/postgresql.rst +++ b/doc/build/dialects/postgresql.rst @@ -339,6 +339,9 @@ they originate from :mod:`sqlalchemy.types` or from the local dialect:: DATERANGE, TSRANGE, TSTZRANGE, + REGCONFIG, + REGCLASS, + TSQUERY, TSVECTOR, ) @@ -356,18 +359,10 @@ construction arguments, are as follows: .. autoclass:: sqlalchemy.dialects.postgresql.AbstractMultiRange -.. autoclass:: aggregate_order_by - -.. autoclass:: array .. autoclass:: ARRAY :members: __init__, Comparator -.. autofunction:: array_agg - -.. autofunction:: Any - -.. autofunction:: All .. autoclass:: BIT @@ -393,10 +388,6 @@ construction arguments, are as follows: :members: -.. autoclass:: hstore - :members: - - .. autoclass:: INET .. autoclass:: INTERVAL @@ -420,6 +411,9 @@ construction arguments, are as follows: :members: __init__ :noindex: + +.. autoclass:: REGCONFIG + .. autoclass:: REGCLASS .. autoclass:: TIMESTAMP @@ -428,6 +422,8 @@ construction arguments, are as follows: .. autoclass:: TIME :members: __init__ +.. autoclass:: TSQUERY + .. autoclass:: TSVECTOR .. autoclass:: UUID @@ -471,6 +467,33 @@ construction arguments, are as follows: .. autoclass:: TSTZMULTIRANGE +PostgreSQL SQL Elements and Functions +-------------------------------------- + +.. autoclass:: aggregate_order_by + +.. autoclass:: array + +.. 
autofunction:: array_agg + +.. autofunction:: Any + +.. autofunction:: All + +.. autoclass:: hstore + :members: + +.. autoclass:: to_tsvector + +.. autoclass:: to_tsquery + +.. autoclass:: plainto_tsquery + +.. autoclass:: phraseto_tsquery + +.. autoclass:: websearch_to_tsquery + +.. autoclass:: ts_headline PostgreSQL Constraint Types --------------------------- diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py index 7890541ffd..d2e213bbc7 100644 --- a/lib/sqlalchemy/dialects/postgresql/__init__.py +++ b/lib/sqlalchemy/dialects/postgresql/__init__.py @@ -37,6 +37,12 @@ from .dml import insert from .ext import aggregate_order_by from .ext import array_agg from .ext import ExcludeConstraint +from .ext import phraseto_tsquery +from .ext import plainto_tsquery +from .ext import to_tsquery +from .ext import to_tsvector +from .ext import ts_headline +from .ext import websearch_to_tsquery from .hstore import HSTORE from .hstore import hstore from .json import JSON @@ -72,8 +78,10 @@ from .types import MACADDR from .types import MONEY from .types import OID from .types import REGCLASS +from .types import REGCONFIG from .types import TIME from .types import TIMESTAMP +from .types import TSQUERY from .types import TSVECTOR # Alias psycopg also as psycopg_async @@ -102,6 +110,9 @@ __all__ = ( "MONEY", "OID", "REGCLASS", + "REGCONFIG", + "TSQUERY", + "TSVECTOR", "DOUBLE_PRECISION", "TIMESTAMP", "TIME", diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index b8f614eba5..3c1eaf918d 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -142,6 +142,7 @@ from .base import PGDialect from .base import PGExecutionContext from .base import PGIdentifierPreparer from .base import REGCLASS +from .base import REGCONFIG from ... import exc from ... import pool from ... 
import util @@ -160,6 +161,10 @@ class AsyncpgString(sqltypes.String): render_bind_cast = True +class AsyncpgREGCONFIG(REGCONFIG): + render_bind_cast = True + + class AsyncpgTime(sqltypes.Time): render_bind_cast = True @@ -899,6 +904,7 @@ class PGDialect_asyncpg(PGDialect): PGDialect.colspecs, { sqltypes.String: AsyncpgString, + REGCONFIG: AsyncpgREGCONFIG, sqltypes.Time: AsyncpgTime, sqltypes.Date: AsyncpgDate, sqltypes.DateTime: AsyncpgDateTime, diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index f9108094f2..8287e828a7 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -827,6 +827,8 @@ For example, the query:: would generate: +.. sourcecode:: sql + SELECT to_tsquery('cat') @> to_tsquery('cat & rat') @@ -840,6 +842,20 @@ produces a statement equivalent to:: SELECT CAST('some text' AS TSVECTOR) AS anon_1 +The ``func`` namespace is augmented by the PostgreSQL dialect to set up +correct argument and return types for most full text search functions. +These functions are used automatically by the :attr:`_sql.func` namespace +assuming the ``sqlalchemy.dialects.postgresql`` package has been imported, +or :func:`_sa.create_engine` has been invoked using a ``postgresql`` +dialect. These functions are documented at: + +* :class:`_postgresql.to_tsvector` +* :class:`_postgresql.to_tsquery` +* :class:`_postgresql.plainto_tsquery` +* :class:`_postgresql.phraseto_tsquery` +* :class:`_postgresql.websearch_to_tsquery` +* :class:`_postgresql.ts_headline` + Specifying the "regconfig" with ``match()`` or custom operators ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -1402,6 +1418,7 @@ from . import hstore as _hstore from . import json as _json from . import pg_catalog from . 
import ranges as _ranges +from .ext import _regconfig_fn from .ext import aggregate_order_by from .named_types import CreateDomainType as CreateDomainType # noqa: F401 from .named_types import CreateEnumType as CreateEnumType # noqa: F401 @@ -1428,6 +1445,7 @@ from .types import PGInterval as PGInterval # noqa: F401 from .types import PGMacAddr as PGMacAddr # noqa: F401 from .types import PGUuid as PGUuid from .types import REGCLASS as REGCLASS +from .types import REGCONFIG as REGCONFIG # noqa: F401 from .types import TIME as TIME from .types import TIMESTAMP as TIMESTAMP from .types import TSVECTOR as TSVECTOR @@ -1636,6 +1654,45 @@ ischema_names = { class PGCompiler(compiler.SQLCompiler): + def visit_to_tsvector_func(self, element, **kw): + return self._assert_pg_ts_ext(element, **kw) + + def visit_to_tsquery_func(self, element, **kw): + return self._assert_pg_ts_ext(element, **kw) + + def visit_plainto_tsquery_func(self, element, **kw): + return self._assert_pg_ts_ext(element, **kw) + + def visit_phraseto_tsquery_func(self, element, **kw): + return self._assert_pg_ts_ext(element, **kw) + + def visit_websearch_to_tsquery_func(self, element, **kw): + return self._assert_pg_ts_ext(element, **kw) + + def visit_ts_headline_func(self, element, **kw): + return self._assert_pg_ts_ext(element, **kw) + + def _assert_pg_ts_ext(self, element, **kw): + if not isinstance(element, _regconfig_fn): + # other options here include trying to rewrite the function + # with the correct types. however, that means we have to + # "un-SQL-ize" the first argument, which can't work in a + # generalized way. Also, parent compiler class has already added + # the incorrect return type to the result map. So let's just + # make sure the function we want is used up front. + + raise exc.CompileError( + f'Can\'t compile "{element.name}()" full text search ' + f"function construct that does not originate from the " + f'"sqlalchemy.dialects.postgresql" package. 
' + f'Please ensure "import sqlalchemy.dialects.postgresql" is ' + f"called before constructing " + f'"sqlalchemy.func.{element.name}()" to ensure registration ' + f"of the correct argument and return types." + ) + + return f"{element.name}{self.function_argspec(element, **kw)}" + def render_bind_cast(self, type_, dbapi_type, sqltext): return f"""{sqltext}::{ self.dialect.type_compiler_instance.process( @@ -2381,6 +2438,9 @@ class PGTypeCompiler(compiler.GenericTypeCompiler): def visit_TSVECTOR(self, type_, **kw): return "TSVECTOR" + def visit_TSQUERY(self, type_, **kw): + return "TSQUERY" + def visit_INET(self, type_, **kw): return "INET" @@ -2396,6 +2456,9 @@ class PGTypeCompiler(compiler.GenericTypeCompiler): def visit_OID(self, type_, **kw): return "OID" + def visit_REGCONFIG(self, type_, **kw): + return "REGCONFIG" + def visit_REGCLASS(self, type_, **kw): return "REGCLASS" diff --git a/lib/sqlalchemy/dialects/postgresql/ext.py b/lib/sqlalchemy/dialects/postgresql/ext.py index b0d8ef3457..31fbf203be 100644 --- a/lib/sqlalchemy/dialects/postgresql/ext.py +++ b/lib/sqlalchemy/dialects/postgresql/ext.py @@ -8,8 +8,11 @@ from __future__ import annotations from itertools import zip_longest +from typing import Any from typing import TYPE_CHECKING +from typing import TypeVar +from . 
import types from .array import ARRAY from ...sql import coercions from ...sql import elements @@ -18,8 +21,11 @@ from ...sql import functions from ...sql import roles from ...sql import schema from ...sql.schema import ColumnCollectionConstraint +from ...sql.sqltypes import TEXT from ...sql.visitors import InternalTraversal +_T = TypeVar("_T", bound=Any) + if TYPE_CHECKING: from ...sql.visitors import _TraverseInternalsType @@ -287,3 +293,205 @@ def array_agg(*arg, **kw): """ kw["_default_array_type"] = ARRAY return functions.func.array_agg(*arg, **kw) + + +class _regconfig_fn(functions.GenericFunction[_T]): + inherit_cache = True + + def __init__(self, *args, **kwargs): + args = list(args) + if len(args) > 1: + + initial_arg = coercions.expect( + roles.ExpressionElementRole, + args.pop(0), + name=getattr(self, "name", None), + apply_propagate_attrs=self, + type_=types.REGCONFIG, + ) + initial_arg = [initial_arg] + else: + initial_arg = [] + + addtl_args = [ + coercions.expect( + roles.ExpressionElementRole, + c, + name=getattr(self, "name", None), + apply_propagate_attrs=self, + ) + for c in args + ] + super().__init__(*(initial_arg + addtl_args), **kwargs) + + +class to_tsvector(_regconfig_fn): + """The PostgreSQL ``to_tsvector`` SQL function. + + This function applies automatic casting of the REGCONFIG argument + to use the :class:`_postgresql.REGCONFIG` datatype automatically, + and applies a return type of :class:`_postgresql.TSVECTOR`. + + Assuming the PostgreSQL dialect has been imported, either by invoking + ``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL + engine using ``create_engine("postgresql...")``, + :class:`_postgresql.to_tsvector` will be used automatically when invoking + ``sqlalchemy.func.to_tsvector()``, ensuring the correct argument and return + type handlers are used at compile and execution time. + + .. 
versionadded:: 2.0.0b5 + + """ + + inherit_cache = True + type = types.TSVECTOR + + +class to_tsquery(_regconfig_fn): + """The PostgreSQL ``to_tsquery`` SQL function. + + This function applies automatic casting of the REGCONFIG argument + to use the :class:`_postgresql.REGCONFIG` datatype automatically, + and applies a return type of :class:`_postgresql.TSQUERY`. + + Assuming the PostgreSQL dialect has been imported, either by invoking + ``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL + engine using ``create_engine("postgresql...")``, + :class:`_postgresql.to_tsquery` will be used automatically when invoking + ``sqlalchemy.func.to_tsquery()``, ensuring the correct argument and return + type handlers are used at compile and execution time. + + .. versionadded:: 2.0.0b5 + + """ + + inherit_cache = True + type = types.TSQUERY + + +class plainto_tsquery(_regconfig_fn): + """The PostgreSQL ``plainto_tsquery`` SQL function. + + This function applies automatic casting of the REGCONFIG argument + to use the :class:`_postgresql.REGCONFIG` datatype automatically, + and applies a return type of :class:`_postgresql.TSQUERY`. + + Assuming the PostgreSQL dialect has been imported, either by invoking + ``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL + engine using ``create_engine("postgresql...")``, + :class:`_postgresql.plainto_tsquery` will be used automatically when + invoking ``sqlalchemy.func.plainto_tsquery()``, ensuring the correct + argument and return type handlers are used at compile and execution time. + + .. versionadded:: 2.0.0b5 + + """ + + inherit_cache = True + type = types.TSQUERY + + +class phraseto_tsquery(_regconfig_fn): + """The PostgreSQL ``phraseto_tsquery`` SQL function. + + This function applies automatic casting of the REGCONFIG argument + to use the :class:`_postgresql.REGCONFIG` datatype automatically, + and applies a return type of :class:`_postgresql.TSQUERY`. 
+ + Assuming the PostgreSQL dialect has been imported, either by invoking + ``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL + engine using ``create_engine("postgresql...")``, + :class:`_postgresql.phraseto_tsquery` will be used automatically when + invoking ``sqlalchemy.func.phraseto_tsquery()``, ensuring the correct + argument and return type handlers are used at compile and execution time. + + .. versionadded:: 2.0.0b5 + + """ + + inherit_cache = True + type = types.TSQUERY + + +class websearch_to_tsquery(_regconfig_fn): + """The PostgreSQL ``websearch_to_tsquery`` SQL function. + + This function applies automatic casting of the REGCONFIG argument + to use the :class:`_postgresql.REGCONFIG` datatype automatically, + and applies a return type of :class:`_postgresql.TSQUERY`. + + Assuming the PostgreSQL dialect has been imported, either by invoking + ``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL + engine using ``create_engine("postgresql...")``, + :class:`_postgresql.websearch_to_tsquery` will be used automatically when + invoking ``sqlalchemy.func.websearch_to_tsquery()``, ensuring the correct + argument and return type handlers are used at compile and execution time. + + .. versionadded:: 2.0.0b5 + + """ + + inherit_cache = True + type = types.TSQUERY + + +class ts_headline(_regconfig_fn): + """The PostgreSQL ``ts_headline`` SQL function. + + This function applies automatic casting of the REGCONFIG argument + to use the :class:`_postgresql.REGCONFIG` datatype automatically, + and applies a return type of :class:`_types.TEXT`. 
+ + Assuming the PostgreSQL dialect has been imported, either by invoking + ``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL + engine using ``create_engine("postgresql...")``, + :class:`_postgresql.ts_headline` will be used automatically when invoking + ``sqlalchemy.func.ts_headline()``, ensuring the correct argument and return + type handlers are used at compile and execution time. + + .. versionadded:: 2.0.0b5 + + """ + + inherit_cache = True + type = TEXT + + def __init__(self, *args, **kwargs): + args = list(args) + + # parse types according to + # https://www.postgresql.org/docs/current/textsearch-controls.html#TEXTSEARCH-HEADLINE + if len(args) < 2: + # invalid args; don't do anything + has_regconfig = False + elif ( + isinstance(args[1], elements.ColumnElement) + and args[1].type._type_affinity is types.TSQUERY + ): + # tsquery is second argument, no regconfig argument + has_regconfig = False + else: + has_regconfig = True + + if has_regconfig: + initial_arg = coercions.expect( + roles.ExpressionElementRole, + args.pop(0), + apply_propagate_attrs=self, + name=getattr(self, "name", None), + type_=types.REGCONFIG, + ) + initial_arg = [initial_arg] + else: + initial_arg = [] + + addtl_args = [ + coercions.expect( + roles.ExpressionElementRole, + c, + name=getattr(self, "name", None), + apply_propagate_attrs=self, + ) + for c in args + ] + super().__init__(*(initial_arg + addtl_args), **kwargs) diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg.py b/lib/sqlalchemy/dialects/postgresql/psycopg.py index 400c3186ec..67d1370f53 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg.py @@ -70,6 +70,7 @@ from ._psycopg_common import _PGExecutionContext_common_psycopg from .base import INTERVAL from .base import PGCompiler from .base import PGIdentifierPreparer +from .base import REGCONFIG from .json import JSON from .json import JSONB from .json import JSONPathType @@ -90,6 +91,10 @@ 
class _PGString(sqltypes.String): render_bind_cast = True +class _PGREGCONFIG(REGCONFIG): + render_bind_cast = True + + class _PGJSON(JSON): render_bind_cast = True @@ -270,6 +275,7 @@ class PGDialect_psycopg(_PGDialect_common_psycopg): _PGDialect_common_psycopg.colspecs, { sqltypes.String: _PGString, + REGCONFIG: _PGREGCONFIG, JSON: _PGJSON, sqltypes.JSON: _PGJSON, JSONB: _PGJSONB, diff --git a/lib/sqlalchemy/dialects/postgresql/types.py b/lib/sqlalchemy/dialects/postgresql/types.py index 72703ff814..49fc70ba39 100644 --- a/lib/sqlalchemy/dialects/postgresql/types.py +++ b/lib/sqlalchemy/dialects/postgresql/types.py @@ -6,7 +6,6 @@ # mypy: ignore-errors import datetime as dt -from typing import Any from ...sql import sqltypes @@ -102,6 +101,28 @@ class OID(sqltypes.TypeEngine[int]): __visit_name__ = "OID" +class REGCONFIG(sqltypes.TypeEngine[str]): + + """Provide the PostgreSQL REGCONFIG type. + + .. versionadded:: 2.0.0b5 + + """ + + __visit_name__ = "REGCONFIG" + + +class TSQUERY(sqltypes.TypeEngine[str]): + + """Provide the PostgreSQL TSQUERY type. + + .. versionadded:: 2.0.0b5 + + """ + + __visit_name__ = "TSQUERY" + + class REGCLASS(sqltypes.TypeEngine[str]): """Provide the PostgreSQL REGCLASS type. @@ -207,7 +228,7 @@ class BIT(sqltypes.TypeEngine[int]): PGBit = BIT -class TSVECTOR(sqltypes.TypeEngine[Any]): +class TSVECTOR(sqltypes.TypeEngine[str]): """The :class:`_postgresql.TSVECTOR` type implements the PostgreSQL text search type TSVECTOR. 
diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index cf5f1c8267..57b147c90d 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -1,6 +1,7 @@ from sqlalchemy import and_ from sqlalchemy import BigInteger from sqlalchemy import bindparam +from sqlalchemy import case from sqlalchemy import cast from sqlalchemy import CheckConstraint from sqlalchemy import Column @@ -28,6 +29,7 @@ from sqlalchemy import Table from sqlalchemy import testing from sqlalchemy import Text from sqlalchemy import text +from sqlalchemy import true from sqlalchemy import tuple_ from sqlalchemy import types as sqltypes from sqlalchemy import UniqueConstraint @@ -43,6 +45,8 @@ from sqlalchemy.dialects.postgresql import insert from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.dialects.postgresql import JSONPATH from sqlalchemy.dialects.postgresql import Range +from sqlalchemy.dialects.postgresql import REGCONFIG +from sqlalchemy.dialects.postgresql import TSQUERY from sqlalchemy.dialects.postgresql import TSRANGE from sqlalchemy.dialects.postgresql.base import PGDialect from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2 @@ -54,6 +58,8 @@ from sqlalchemy.sql import literal_column from sqlalchemy.sql import operators from sqlalchemy.sql import table from sqlalchemy.sql import util as sql_util +from sqlalchemy.sql.functions import GenericFunction +from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures from sqlalchemy.testing.assertions import assert_raises from sqlalchemy.testing.assertions import assert_raises_message @@ -3183,6 +3189,12 @@ class FullTextSearchTest(fixtures.TestBase, AssertsCompiledSQL): column("title", String(128)), column("body", String(128)), ) + self.matchtable = Table( + "matchtable", + MetaData(), + Column("id", Integer, primary_key=True), + Column("title", String(200)), + ) def _raise_query(self, q): """ 
@@ -3287,6 +3299,173 @@ class FullTextSearchTest(fixtures.TestBase, AssertsCompiledSQL): """plainto_tsquery('english', %(to_tsvector_2)s)""", ) + @testing.combinations( + ("to_tsvector",), + ("to_tsquery",), + ("plainto_tsquery",), + ("phraseto_tsquery",), + ("websearch_to_tsquery",), + ("ts_headline",), + argnames="to_ts_name", + ) + def test_dont_compile_non_imported(self, to_ts_name): + new_func = type( + to_ts_name, + (GenericFunction,), + { + "_register": False, + "inherit_cache": True, + }, + ) + + with expect_raises_message( + exc.CompileError, + rf"Can't compile \"{to_ts_name}\(\)\" full text search " + f"function construct that does not originate from the " + f'"sqlalchemy.dialects.postgresql" package. ' + f'Please ensure "import sqlalchemy.dialects.postgresql" is ' + f"called before constructing " + rf"\"sqlalchemy.func.{to_ts_name}\(\)\" to ensure " + f"registration of the correct " + f"argument and return types.", + ): + select(new_func("x", "y")).compile(dialect=postgresql.dialect()) + + @testing.combinations( + (func.to_tsvector,), + (func.to_tsquery,), + (func.plainto_tsquery,), + (func.phraseto_tsquery,), + (func.websearch_to_tsquery,), + argnames="to_ts_func", + ) + @testing.variation("use_regconfig", [True, False, "literal"]) + def test_to_regconfig_fns(self, to_ts_func, use_regconfig): + """test #8977""" + matchtable = self.matchtable + + fn_name = to_ts_func().name + + if use_regconfig.literal: + regconfig = literal("english", REGCONFIG) + elif use_regconfig: + regconfig = "english" + else: + regconfig = None + + if regconfig is None: + if fn_name == "to_tsvector": + fn = to_ts_func(matchtable.c.title).match("python") + expected = ( + "to_tsvector(matchtable.title) @@ " + "plainto_tsquery($1::VARCHAR)" + ) + else: + fn = func.to_tsvector(matchtable.c.title).op("@@")( + to_ts_func("python") + ) + expected = ( + f"to_tsvector(matchtable.title) @@ {fn_name}($1::VARCHAR)" + ) + else: + if fn_name == "to_tsvector": + fn = to_ts_func(regconfig, 
matchtable.c.title).match("python") + expected = ( + "to_tsvector($1::REGCONFIG, matchtable.title) @@ " + "plainto_tsquery($2::VARCHAR)" + ) + else: + fn = func.to_tsvector(matchtable.c.title).op("@@")( + to_ts_func(regconfig, "python") + ) + expected = ( + f"to_tsvector(matchtable.title) @@ " + f"{fn_name}($1::REGCONFIG, $2::VARCHAR)" + ) + + stmt = matchtable.select().where(fn) + + self.assert_compile( + stmt, + "SELECT matchtable.id, matchtable.title " + f"FROM matchtable WHERE {expected}", + dialect="postgresql+asyncpg", + ) + + @testing.variation("use_regconfig", [True, False, "literal"]) + @testing.variation("include_options", [True, False]) + @testing.variation("tsquery_in_expr", [True, False]) + def test_ts_headline( + self, connection, use_regconfig, include_options, tsquery_in_expr + ): + """test #8977""" + if use_regconfig.literal: + regconfig = literal("english", REGCONFIG) + elif use_regconfig: + regconfig = "english" + else: + regconfig = None + + text = ( + "The most common type of search is to find all documents " + "containing given query terms and return them in order of " + "their similarity to the query." 
+ ) + tsquery = func.to_tsquery("english", "query & similarity") + + if regconfig is None: + tsquery_str = "to_tsquery($2::REGCONFIG, $3::VARCHAR)" + else: + tsquery_str = "to_tsquery($3::REGCONFIG, $4::VARCHAR)" + + if tsquery_in_expr: + tsquery = case((true(), tsquery), else_=null()) + tsquery_str = f"CASE WHEN true THEN {tsquery_str} ELSE NULL END" + + is_(tsquery.type._type_affinity, TSQUERY) + + args = [text, tsquery] + if regconfig is not None: + args.insert(0, regconfig) + if include_options: + args.append( + "MaxFragments=10, MaxWords=7, " + "MinWords=3, StartSel=<<, StopSel=>>" + ) + + fn = func.ts_headline(*args) + stmt = select(fn) + + if regconfig is None and not include_options: + self.assert_compile( + stmt, + f"SELECT ts_headline($1::VARCHAR, " + f"{tsquery_str}) AS ts_headline_1", + dialect="postgresql+asyncpg", + ) + elif regconfig is None and include_options: + self.assert_compile( + stmt, + f"SELECT ts_headline($1::VARCHAR, " + f"{tsquery_str}, $4::VARCHAR) AS ts_headline_1", + dialect="postgresql+asyncpg", + ) + elif regconfig is not None and not include_options: + self.assert_compile( + stmt, + f"SELECT ts_headline($1::REGCONFIG, $2::VARCHAR, " + f"{tsquery_str}) AS ts_headline_1", + dialect="postgresql+asyncpg", + ) + else: + self.assert_compile( + stmt, + f"SELECT ts_headline($1::REGCONFIG, $2::VARCHAR, " + f"{tsquery_str}, $5::VARCHAR) " + "AS ts_headline_1", + dialect="postgresql+asyncpg", + ) + class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ = "postgresql" diff --git a/test/dialect/postgresql/test_query.py b/test/dialect/postgresql/test_query.py index 2b32d6db7f..7ef7033b29 100644 --- a/test/dialect/postgresql/test_query.py +++ b/test/dialect/postgresql/test_query.py @@ -28,6 +28,7 @@ from sqlalchemy import true from sqlalchemy import tuple_ from sqlalchemy.dialects import postgresql from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.dialects.postgresql import REGCONFIG from 
sqlalchemy.sql.expression import type_coerce from sqlalchemy.testing import assert_raises from sqlalchemy.testing import AssertsCompiledSQL @@ -1005,6 +1006,110 @@ class MatchTest(fixtures.TablesTest, AssertsCompiledSQL): "matchtable.title @@ plainto_tsquery($1)", ) + @testing.combinations( + (func.to_tsvector,), + (func.to_tsquery,), + (func.plainto_tsquery,), + (func.phraseto_tsquery,), + (func.websearch_to_tsquery,), + argnames="to_ts_func", + ) + @testing.variation("use_regconfig", [True, False, "literal"]) + def test_to_regconfig_fns(self, connection, to_ts_func, use_regconfig): + """test #8977""" + + matchtable = self.tables.matchtable + + fn_name = to_ts_func().name + + if use_regconfig.literal: + regconfig = literal("english", REGCONFIG) + elif use_regconfig: + regconfig = "english" + else: + regconfig = None + + if regconfig is None: + if fn_name == "to_tsvector": + fn = to_ts_func(matchtable.c.title).match("python") + else: + fn = func.to_tsvector(matchtable.c.title).op("@@")( + to_ts_func("python") + ) + else: + if fn_name == "to_tsvector": + fn = to_ts_func(regconfig, matchtable.c.title).match("python") + else: + fn = func.to_tsvector(matchtable.c.title).op("@@")( + to_ts_func(regconfig, "python") + ) + + stmt = matchtable.select().where(fn).order_by(matchtable.c.id) + results = connection.execute(stmt).fetchall() + eq_([2, 5], [r.id for r in results]) + + @testing.variation("use_regconfig", [True, False, "literal"]) + @testing.variation("include_options", [True, False]) + def test_ts_headline(self, connection, use_regconfig, include_options): + """test #8977""" + if use_regconfig.literal: + regconfig = literal("english", REGCONFIG) + elif use_regconfig: + regconfig = "english" + else: + regconfig = None + + text = ( + "The most common type of search is to find all documents " + "containing given query terms and return them in order of " + "their similarity to the query." 
+ ) + tsquery = func.to_tsquery("english", "query & similarity") + + if regconfig is None: + if include_options: + fn = func.ts_headline( + text, + tsquery, + "MaxFragments=10, MaxWords=7, MinWords=3, " + "StartSel=<<, StopSel=>>", + ) + else: + fn = func.ts_headline( + text, + tsquery, + ) + else: + if include_options: + fn = func.ts_headline( + regconfig, + text, + tsquery, + "MaxFragments=10, MaxWords=7, MinWords=3, " + "StartSel=<<, StopSel=>>", + ) + else: + fn = func.ts_headline( + regconfig, + text, + tsquery, + ) + + stmt = select(fn) + + if include_options: + eq_( + connection.scalar(stmt), + "documents containing given <<query>> terms and return ... " + "their <<similarity>> to the <<query>>", + ) + else: + eq_( + connection.scalar(stmt), + "containing given query terms and return them in " + "order of their similarity to the query.", + ) + def test_simple_match(self, connection): matchtable = self.tables.matchtable results = connection.execute(