]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
add explicit REGCONFIG, pg full text functions
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 14 Dec 2022 01:07:14 +0000 (20:07 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 15 Dec 2022 15:36:48 +0000 (10:36 -0500)
Added support for explicit use of PG full text functions with asyncpg and
psycopg (SQLAlchemy 2.0 only), with regards to the ``REGCONFIG`` type cast
for the first argument, which previously would be incorrectly cast to a
VARCHAR, causing failures on these dialects that rely upon explicit type
casts. This includes support for :class:`_postgresql.to_tsvector`,
:class:`_postgresql.to_tsquery`, :class:`_postgresql.plainto_tsquery`,
:class:`_postgresql.phraseto_tsquery`,
:class:`_postgresql.websearch_to_tsquery`,
:class:`_postgresql.ts_headline`, each of which will determine based on
number of arguments passed if the first string argument should be
interpreted as a PostgreSQL "REGCONFIG" value; if so, the argument is typed
using a newly added type object :class:`_postgresql.REGCONFIG` which is
then explicitly cast in the SQL expression.

Fixes: #8977
Change-Id: Ib36698a984fd4194bd6e0eb663105f790f3db7d3

doc/build/changelog/unreleased_20/8977.rst [new file with mode: 0644]
doc/build/dialects/postgresql.rst
lib/sqlalchemy/dialects/postgresql/__init__.py
lib/sqlalchemy/dialects/postgresql/asyncpg.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/postgresql/ext.py
lib/sqlalchemy/dialects/postgresql/psycopg.py
lib/sqlalchemy/dialects/postgresql/types.py
test/dialect/postgresql/test_compiler.py
test/dialect/postgresql/test_query.py

diff --git a/doc/build/changelog/unreleased_20/8977.rst b/doc/build/changelog/unreleased_20/8977.rst
new file mode 100644 (file)
index 0000000..904e08b
--- /dev/null
@@ -0,0 +1,19 @@
+.. change::
+    :tags: postgresql, bug
+    :tickets: 8977
+    :versions: 2.0.0b5
+
+    Added support for explicit use of PG full text functions with asyncpg and
+    psycopg (SQLAlchemy 2.0 only), with regards to the ``REGCONFIG`` type cast
+    for the first argument, which previously would be incorrectly cast to a
+    VARCHAR, causing failures on these dialects that rely upon explicit type
+    casts. This includes support for :class:`_postgresql.to_tsvector`,
+    :class:`_postgresql.to_tsquery`, :class:`_postgresql.plainto_tsquery`,
+    :class:`_postgresql.phraseto_tsquery`,
+    :class:`_postgresql.websearch_to_tsquery`,
+    :class:`_postgresql.ts_headline`, each of which will determine based on
+    number of arguments passed if the first string argument should be
+    interpreted as a PostgreSQL "REGCONFIG" value; if so, the argument is typed
+    using a newly added type object :class:`_postgresql.REGCONFIG` which is
+    then explicitly cast in the SQL expression.
+
index 2f541b5abd2d6af1574ecf8358dfa6ff6c8c839c..8f4dc1e72dd7a6bf9cf89389117d135aeea38bbe 100644 (file)
@@ -339,6 +339,9 @@ they originate from :mod:`sqlalchemy.types` or from the local dialect::
         DATERANGE,
         TSRANGE,
         TSTZRANGE,
+        REGCONFIG,
+        REGCLASS,
+        TSQUERY,
         TSVECTOR,
     )
 
@@ -356,18 +359,10 @@ construction arguments, are as follows:
 
 .. autoclass:: sqlalchemy.dialects.postgresql.AbstractMultiRange
 
-.. autoclass:: aggregate_order_by
-
-.. autoclass:: array
 
 .. autoclass:: ARRAY
     :members: __init__, Comparator
 
-.. autofunction:: array_agg
-
-.. autofunction:: Any
-
-.. autofunction:: All
 
 .. autoclass:: BIT
 
@@ -393,10 +388,6 @@ construction arguments, are as follows:
     :members:
 
 
-.. autoclass:: hstore
-    :members:
-
-
 .. autoclass:: INET
 
 .. autoclass:: INTERVAL
@@ -420,6 +411,9 @@ construction arguments, are as follows:
     :members: __init__
     :noindex:
 
+
+.. autoclass:: REGCONFIG
+
 .. autoclass:: REGCLASS
 
 .. autoclass:: TIMESTAMP
@@ -428,6 +422,8 @@ construction arguments, are as follows:
 .. autoclass:: TIME
     :members: __init__
 
+.. autoclass:: TSQUERY
+
 .. autoclass:: TSVECTOR
 
 .. autoclass:: UUID
@@ -471,6 +467,33 @@ construction arguments, are as follows:
 .. autoclass:: TSTZMULTIRANGE
 
 
+PostgreSQL SQL Elements and Functions
+--------------------------------------
+
+.. autoclass:: aggregate_order_by
+
+.. autoclass:: array
+
+.. autofunction:: array_agg
+
+.. autofunction:: Any
+
+.. autofunction:: All
+
+.. autoclass:: hstore
+    :members:
+
+.. autoclass:: to_tsvector
+
+.. autoclass:: to_tsquery
+
+.. autoclass:: plainto_tsquery
+
+.. autoclass:: phraseto_tsquery
+
+.. autoclass:: websearch_to_tsquery
+
+.. autoclass:: ts_headline
 
 PostgreSQL Constraint Types
 ---------------------------
index 7890541ffde05c3afdae4ad7d6bebea54ae0dbe9..d2e213bbc716dcb8e46a4e4d56791f0feeaa1f2e 100644 (file)
@@ -37,6 +37,12 @@ from .dml import insert
 from .ext import aggregate_order_by
 from .ext import array_agg
 from .ext import ExcludeConstraint
+from .ext import phraseto_tsquery
+from .ext import plainto_tsquery
+from .ext import to_tsquery
+from .ext import to_tsvector
+from .ext import ts_headline
+from .ext import websearch_to_tsquery
 from .hstore import HSTORE
 from .hstore import hstore
 from .json import JSON
@@ -72,8 +78,10 @@ from .types import MACADDR
 from .types import MONEY
 from .types import OID
 from .types import REGCLASS
+from .types import REGCONFIG
 from .types import TIME
 from .types import TIMESTAMP
+from .types import TSQUERY
 from .types import TSVECTOR
 
 # Alias psycopg also as psycopg_async
@@ -102,6 +110,9 @@ __all__ = (
     "MONEY",
     "OID",
     "REGCLASS",
+    "REGCONFIG",
+    "TSQUERY",
+    "TSVECTOR",
     "DOUBLE_PRECISION",
     "TIMESTAMP",
     "TIME",
index b8f614eba5edfb692ad4218593c2bb88c731035d..3c1eaf918dfca6111ed38fab4a1560698564e939 100644 (file)
@@ -142,6 +142,7 @@ from .base import PGDialect
 from .base import PGExecutionContext
 from .base import PGIdentifierPreparer
 from .base import REGCLASS
+from .base import REGCONFIG
 from ... import exc
 from ... import pool
 from ... import util
@@ -160,6 +161,10 @@ class AsyncpgString(sqltypes.String):
     render_bind_cast = True
 
 
+class AsyncpgREGCONFIG(REGCONFIG):
+    render_bind_cast = True
+
+
 class AsyncpgTime(sqltypes.Time):
     render_bind_cast = True
 
@@ -899,6 +904,7 @@ class PGDialect_asyncpg(PGDialect):
         PGDialect.colspecs,
         {
             sqltypes.String: AsyncpgString,
+            REGCONFIG: AsyncpgREGCONFIG,
             sqltypes.Time: AsyncpgTime,
             sqltypes.Date: AsyncpgDate,
             sqltypes.DateTime: AsyncpgDateTime,
index f9108094f21ec4bdc090b32d576694cd04f5c12b..8287e828a74c5608b8da1dd587e9a172b02ff905 100644 (file)
@@ -827,6 +827,8 @@ For example, the query::
 
 would generate:
 
+.. sourcecode:: sql
+
     SELECT to_tsquery('cat') @> to_tsquery('cat & rat')
 
 
@@ -840,6 +842,20 @@ produces a statement equivalent to::
 
     SELECT CAST('some text' AS TSVECTOR) AS anon_1
 
+The ``func`` namespace is augmented by the PostgreSQL dialect to set up
+correct argument and return types for most full text search functions.
+These functions are used automatically by the :attr:`_sql.func` namespace
+assuming the ``sqlalchemy.dialects.postgresql`` package has been imported,
+or :func:`_sa.create_engine` has been invoked using a ``postgresql``
+dialect.  These functions are documented at:
+
+* :class:`_postgresql.to_tsvector`
+* :class:`_postgresql.to_tsquery`
+* :class:`_postgresql.plainto_tsquery`
+* :class:`_postgresql.phraseto_tsquery`
+* :class:`_postgresql.websearch_to_tsquery`
+* :class:`_postgresql.ts_headline`
+
 Specifying the "regconfig" with ``match()`` or custom operators
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
@@ -1402,6 +1418,7 @@ from . import hstore as _hstore
 from . import json as _json
 from . import pg_catalog
 from . import ranges as _ranges
+from .ext import _regconfig_fn
 from .ext import aggregate_order_by
 from .named_types import CreateDomainType as CreateDomainType  # noqa: F401
 from .named_types import CreateEnumType as CreateEnumType  # noqa: F401
@@ -1428,6 +1445,7 @@ from .types import PGInterval as PGInterval  # noqa: F401
 from .types import PGMacAddr as PGMacAddr  # noqa: F401
 from .types import PGUuid as PGUuid
 from .types import REGCLASS as REGCLASS
+from .types import REGCONFIG as REGCONFIG  # noqa: F401
 from .types import TIME as TIME
 from .types import TIMESTAMP as TIMESTAMP
 from .types import TSVECTOR as TSVECTOR
@@ -1636,6 +1654,45 @@ ischema_names = {
 
 
 class PGCompiler(compiler.SQLCompiler):
+    def visit_to_tsvector_func(self, element, **kw):
+        return self._assert_pg_ts_ext(element, **kw)
+
+    def visit_to_tsquery_func(self, element, **kw):
+        return self._assert_pg_ts_ext(element, **kw)
+
+    def visit_plainto_tsquery_func(self, element, **kw):
+        return self._assert_pg_ts_ext(element, **kw)
+
+    def visit_phraseto_tsquery_func(self, element, **kw):
+        return self._assert_pg_ts_ext(element, **kw)
+
+    def visit_websearch_to_tsquery_func(self, element, **kw):
+        return self._assert_pg_ts_ext(element, **kw)
+
+    def visit_ts_headline_func(self, element, **kw):
+        return self._assert_pg_ts_ext(element, **kw)
+
+    def _assert_pg_ts_ext(self, element, **kw):
+        if not isinstance(element, _regconfig_fn):
+            # other options here include trying to rewrite the function
+            # with the correct types.  however, that means we have to
+            # "un-SQL-ize" the first argument, which can't work in a
+            # generalized way. Also, parent compiler class has already added
+            # the incorrect return type to the result map.   So let's just
+            # make sure the function we want is used up front.
+
+            raise exc.CompileError(
+                f'Can\'t compile "{element.name}()" full text search '
+                f"function construct that does not originate from the "
+                f'"sqlalchemy.dialects.postgresql" package.  '
+                f'Please ensure "import sqlalchemy.dialects.postgresql" is '
+                f"called before constructing "
+                f'"sqlalchemy.func.{element.name}()" to ensure registration '
+                f"of the correct argument and return types."
+            )
+
+        return f"{element.name}{self.function_argspec(element, **kw)}"
+
     def render_bind_cast(self, type_, dbapi_type, sqltext):
         return f"""{sqltext}::{
                 self.dialect.type_compiler_instance.process(
@@ -2381,6 +2438,9 @@ class PGTypeCompiler(compiler.GenericTypeCompiler):
     def visit_TSVECTOR(self, type_, **kw):
         return "TSVECTOR"
 
+    def visit_TSQUERY(self, type_, **kw):
+        return "TSQUERY"
+
     def visit_INET(self, type_, **kw):
         return "INET"
 
@@ -2396,6 +2456,9 @@ class PGTypeCompiler(compiler.GenericTypeCompiler):
     def visit_OID(self, type_, **kw):
         return "OID"
 
+    def visit_REGCONFIG(self, type_, **kw):
+        return "REGCONFIG"
+
     def visit_REGCLASS(self, type_, **kw):
         return "REGCLASS"
 
index b0d8ef345703cc3331661b7ce170eb8b1cb95c2c..31fbf203be16e3d7edef0453a6a6de25f440fdb0 100644 (file)
@@ -8,8 +8,11 @@
 from __future__ import annotations
 
 from itertools import zip_longest
+from typing import Any
 from typing import TYPE_CHECKING
+from typing import TypeVar
 
+from . import types
 from .array import ARRAY
 from ...sql import coercions
 from ...sql import elements
@@ -18,8 +21,11 @@ from ...sql import functions
 from ...sql import roles
 from ...sql import schema
 from ...sql.schema import ColumnCollectionConstraint
+from ...sql.sqltypes import TEXT
 from ...sql.visitors import InternalTraversal
 
+_T = TypeVar("_T", bound=Any)
+
 if TYPE_CHECKING:
     from ...sql.visitors import _TraverseInternalsType
 
@@ -287,3 +293,205 @@ def array_agg(*arg, **kw):
     """
     kw["_default_array_type"] = ARRAY
     return functions.func.array_agg(*arg, **kw)
+
+
+class _regconfig_fn(functions.GenericFunction[_T]):
+    inherit_cache = True
+
+    def __init__(self, *args, **kwargs):
+        args = list(args)
+        if len(args) > 1:
+
+            initial_arg = coercions.expect(
+                roles.ExpressionElementRole,
+                args.pop(0),
+                name=getattr(self, "name", None),
+                apply_propagate_attrs=self,
+                type_=types.REGCONFIG,
+            )
+            initial_arg = [initial_arg]
+        else:
+            initial_arg = []
+
+        addtl_args = [
+            coercions.expect(
+                roles.ExpressionElementRole,
+                c,
+                name=getattr(self, "name", None),
+                apply_propagate_attrs=self,
+            )
+            for c in args
+        ]
+        super().__init__(*(initial_arg + addtl_args), **kwargs)
+
+
+class to_tsvector(_regconfig_fn):
+    """The PostgreSQL ``to_tsvector`` SQL function.
+
+    This function applies automatic casting of the REGCONFIG argument
+    to use the :class:`_postgresql.REGCONFIG` datatype automatically,
+    and applies a return type of :class:`_postgresql.TSVECTOR`.
+
+    Assuming the PostgreSQL dialect has been imported, either by invoking
+    ``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL
+    engine using ``create_engine("postgresql...")``,
+    :class:`_postgresql.to_tsvector` will be used automatically when invoking
+    ``sqlalchemy.func.to_tsvector()``, ensuring the correct argument and return
+    type handlers are used at compile and execution time.
+
+    .. versionadded:: 2.0.0b5
+
+    """
+
+    inherit_cache = True
+    type = types.TSVECTOR
+
+
+class to_tsquery(_regconfig_fn):
+    """The PostgreSQL ``to_tsquery`` SQL function.
+
+    This function applies automatic casting of the REGCONFIG argument
+    to use the :class:`_postgresql.REGCONFIG` datatype automatically,
+    and applies a return type of :class:`_postgresql.TSQUERY`.
+
+    Assuming the PostgreSQL dialect has been imported, either by invoking
+    ``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL
+    engine using ``create_engine("postgresql...")``,
+    :class:`_postgresql.to_tsquery` will be used automatically when invoking
+    ``sqlalchemy.func.to_tsquery()``, ensuring the correct argument and return
+    type handlers are used at compile and execution time.
+
+    .. versionadded:: 2.0.0b5
+
+    """
+
+    inherit_cache = True
+    type = types.TSQUERY
+
+
+class plainto_tsquery(_regconfig_fn):
+    """The PostgreSQL ``plainto_tsquery`` SQL function.
+
+    This function applies automatic casting of the REGCONFIG argument
+    to use the :class:`_postgresql.REGCONFIG` datatype automatically,
+    and applies a return type of :class:`_postgresql.TSQUERY`.
+
+    Assuming the PostgreSQL dialect has been imported, either by invoking
+    ``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL
+    engine using ``create_engine("postgresql...")``,
+    :class:`_postgresql.plainto_tsquery` will be used automatically when
+    invoking ``sqlalchemy.func.plainto_tsquery()``, ensuring the correct
+    argument and return type handlers are used at compile and execution time.
+
+    .. versionadded:: 2.0.0b5
+
+    """
+
+    inherit_cache = True
+    type = types.TSQUERY
+
+
+class phraseto_tsquery(_regconfig_fn):
+    """The PostgreSQL ``phraseto_tsquery`` SQL function.
+
+    This function applies automatic casting of the REGCONFIG argument
+    to use the :class:`_postgresql.REGCONFIG` datatype automatically,
+    and applies a return type of :class:`_postgresql.TSQUERY`.
+
+    Assuming the PostgreSQL dialect has been imported, either by invoking
+    ``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL
+    engine using ``create_engine("postgresql...")``,
+    :class:`_postgresql.phraseto_tsquery` will be used automatically when
+    invoking ``sqlalchemy.func.phraseto_tsquery()``, ensuring the correct
+    argument and return type handlers are used at compile and execution time.
+
+    .. versionadded:: 2.0.0b5
+
+    """
+
+    inherit_cache = True
+    type = types.TSQUERY
+
+
+class websearch_to_tsquery(_regconfig_fn):
+    """The PostgreSQL ``websearch_to_tsquery`` SQL function.
+
+    This function applies automatic casting of the REGCONFIG argument
+    to use the :class:`_postgresql.REGCONFIG` datatype automatically,
+    and applies a return type of :class:`_postgresql.TSQUERY`.
+
+    Assuming the PostgreSQL dialect has been imported, either by invoking
+    ``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL
+    engine using ``create_engine("postgresql...")``,
+    :class:`_postgresql.websearch_to_tsquery` will be used automatically when
+    invoking ``sqlalchemy.func.websearch_to_tsquery()``, ensuring the correct
+    argument and return type handlers are used at compile and execution time.
+
+    .. versionadded:: 2.0.0b5
+
+    """
+
+    inherit_cache = True
+    type = types.TSQUERY
+
+
+class ts_headline(_regconfig_fn):
+    """The PostgreSQL ``ts_headline`` SQL function.
+
+    This function applies automatic casting of the REGCONFIG argument
+    to use the :class:`_postgresql.REGCONFIG` datatype automatically,
+    and applies a return type of :class:`_types.TEXT`.
+
+    Assuming the PostgreSQL dialect has been imported, either by invoking
+    ``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL
+    engine using ``create_engine("postgresql...")``,
+    :class:`_postgresql.ts_headline` will be used automatically when invoking
+    ``sqlalchemy.func.ts_headline()``, ensuring the correct argument and return
+    type handlers are used at compile and execution time.
+
+    .. versionadded:: 2.0.0b5
+
+    """
+
+    inherit_cache = True
+    type = TEXT
+
+    def __init__(self, *args, **kwargs):
+        args = list(args)
+
+        # parse types according to
+        # https://www.postgresql.org/docs/current/textsearch-controls.html#TEXTSEARCH-HEADLINE
+        if len(args) < 2:
+            # invalid args; don't do anything
+            has_regconfig = False
+        elif (
+            isinstance(args[1], elements.ColumnElement)
+            and args[1].type._type_affinity is types.TSQUERY
+        ):
+            # tsquery is second argument, no regconfig argument
+            has_regconfig = False
+        else:
+            has_regconfig = True
+
+        if has_regconfig:
+            initial_arg = coercions.expect(
+                roles.ExpressionElementRole,
+                args.pop(0),
+                apply_propagate_attrs=self,
+                name=getattr(self, "name", None),
+                type_=types.REGCONFIG,
+            )
+            initial_arg = [initial_arg]
+        else:
+            initial_arg = []
+
+        addtl_args = [
+            coercions.expect(
+                roles.ExpressionElementRole,
+                c,
+                name=getattr(self, "name", None),
+                apply_propagate_attrs=self,
+            )
+            for c in args
+        ]
+        super().__init__(*(initial_arg + addtl_args), **kwargs)
index 400c3186ec7ee50e5248951e334f2e04389e3e10..67d1370f530bfd786215f1d249b6f34fa9de5a99 100644 (file)
@@ -70,6 +70,7 @@ from ._psycopg_common import _PGExecutionContext_common_psycopg
 from .base import INTERVAL
 from .base import PGCompiler
 from .base import PGIdentifierPreparer
+from .base import REGCONFIG
 from .json import JSON
 from .json import JSONB
 from .json import JSONPathType
@@ -90,6 +91,10 @@ class _PGString(sqltypes.String):
     render_bind_cast = True
 
 
+class _PGREGCONFIG(REGCONFIG):
+    render_bind_cast = True
+
+
 class _PGJSON(JSON):
     render_bind_cast = True
 
@@ -270,6 +275,7 @@ class PGDialect_psycopg(_PGDialect_common_psycopg):
         _PGDialect_common_psycopg.colspecs,
         {
             sqltypes.String: _PGString,
+            REGCONFIG: _PGREGCONFIG,
             JSON: _PGJSON,
             sqltypes.JSON: _PGJSON,
             JSONB: _PGJSONB,
index 72703ff814a59dd69a08755ed4cf8b558d1fda2c..49fc70ba3944996cea68e656b9c53cef6020e2ce 100644 (file)
@@ -6,7 +6,6 @@
 # mypy: ignore-errors
 
 import datetime as dt
-from typing import Any
 
 from ...sql import sqltypes
 
@@ -102,6 +101,28 @@ class OID(sqltypes.TypeEngine[int]):
     __visit_name__ = "OID"
 
 
+class REGCONFIG(sqltypes.TypeEngine[str]):
+
+    """Provide the PostgreSQL REGCONFIG type.
+
+    .. versionadded:: 2.0.0b5
+
+    """
+
+    __visit_name__ = "REGCONFIG"
+
+
+class TSQUERY(sqltypes.TypeEngine[str]):
+
+    """Provide the PostgreSQL TSQUERY type.
+
+    .. versionadded:: 2.0.0b5
+
+    """
+
+    __visit_name__ = "TSQUERY"
+
+
 class REGCLASS(sqltypes.TypeEngine[str]):
 
     """Provide the PostgreSQL REGCLASS type.
@@ -207,7 +228,7 @@ class BIT(sqltypes.TypeEngine[int]):
 PGBit = BIT
 
 
-class TSVECTOR(sqltypes.TypeEngine[Any]):
+class TSVECTOR(sqltypes.TypeEngine[str]):
 
     """The :class:`_postgresql.TSVECTOR` type implements the PostgreSQL
     text search type TSVECTOR.
index cf5f1c82674f6f3ab9705054be714a05278c4ed9..57b147c90d9f309234c46b7500b4075b7f513a53 100644 (file)
@@ -1,6 +1,7 @@
 from sqlalchemy import and_
 from sqlalchemy import BigInteger
 from sqlalchemy import bindparam
+from sqlalchemy import case
 from sqlalchemy import cast
 from sqlalchemy import CheckConstraint
 from sqlalchemy import Column
@@ -28,6 +29,7 @@ from sqlalchemy import Table
 from sqlalchemy import testing
 from sqlalchemy import Text
 from sqlalchemy import text
+from sqlalchemy import true
 from sqlalchemy import tuple_
 from sqlalchemy import types as sqltypes
 from sqlalchemy import UniqueConstraint
@@ -43,6 +45,8 @@ from sqlalchemy.dialects.postgresql import insert
 from sqlalchemy.dialects.postgresql import JSONB
 from sqlalchemy.dialects.postgresql import JSONPATH
 from sqlalchemy.dialects.postgresql import Range
+from sqlalchemy.dialects.postgresql import REGCONFIG
+from sqlalchemy.dialects.postgresql import TSQUERY
 from sqlalchemy.dialects.postgresql import TSRANGE
 from sqlalchemy.dialects.postgresql.base import PGDialect
 from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
@@ -54,6 +58,8 @@ from sqlalchemy.sql import literal_column
 from sqlalchemy.sql import operators
 from sqlalchemy.sql import table
 from sqlalchemy.sql import util as sql_util
+from sqlalchemy.sql.functions import GenericFunction
+from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing.assertions import assert_raises
 from sqlalchemy.testing.assertions import assert_raises_message
@@ -3183,6 +3189,12 @@ class FullTextSearchTest(fixtures.TestBase, AssertsCompiledSQL):
             column("title", String(128)),
             column("body", String(128)),
         )
+        self.matchtable = Table(
+            "matchtable",
+            MetaData(),
+            Column("id", Integer, primary_key=True),
+            Column("title", String(200)),
+        )
 
     def _raise_query(self, q):
         """
@@ -3287,6 +3299,173 @@ class FullTextSearchTest(fixtures.TestBase, AssertsCompiledSQL):
             """plainto_tsquery('english', %(to_tsvector_2)s)""",
         )
 
+    @testing.combinations(
+        ("to_tsvector",),
+        ("to_tsquery",),
+        ("plainto_tsquery",),
+        ("phraseto_tsquery",),
+        ("websearch_to_tsquery",),
+        ("ts_headline",),
+        argnames="to_ts_name",
+    )
+    def test_dont_compile_non_imported(self, to_ts_name):
+        new_func = type(
+            to_ts_name,
+            (GenericFunction,),
+            {
+                "_register": False,
+                "inherit_cache": True,
+            },
+        )
+
+        with expect_raises_message(
+            exc.CompileError,
+            rf"Can't compile \"{to_ts_name}\(\)\" full text search "
+            f"function construct that does not originate from the "
+            f'"sqlalchemy.dialects.postgresql" package.  '
+            f'Please ensure "import sqlalchemy.dialects.postgresql" is '
+            f"called before constructing "
+            rf"\"sqlalchemy.func.{to_ts_name}\(\)\" to ensure "
+            f"registration of the correct "
+            f"argument and return types.",
+        ):
+            select(new_func("x", "y")).compile(dialect=postgresql.dialect())
+
+    @testing.combinations(
+        (func.to_tsvector,),
+        (func.to_tsquery,),
+        (func.plainto_tsquery,),
+        (func.phraseto_tsquery,),
+        (func.websearch_to_tsquery,),
+        argnames="to_ts_func",
+    )
+    @testing.variation("use_regconfig", [True, False, "literal"])
+    def test_to_regconfig_fns(self, to_ts_func, use_regconfig):
+        """test #8977"""
+        matchtable = self.matchtable
+
+        fn_name = to_ts_func().name
+
+        if use_regconfig.literal:
+            regconfig = literal("english", REGCONFIG)
+        elif use_regconfig:
+            regconfig = "english"
+        else:
+            regconfig = None
+
+        if regconfig is None:
+            if fn_name == "to_tsvector":
+                fn = to_ts_func(matchtable.c.title).match("python")
+                expected = (
+                    "to_tsvector(matchtable.title) @@ "
+                    "plainto_tsquery($1::VARCHAR)"
+                )
+            else:
+                fn = func.to_tsvector(matchtable.c.title).op("@@")(
+                    to_ts_func("python")
+                )
+                expected = (
+                    f"to_tsvector(matchtable.title) @@ {fn_name}($1::VARCHAR)"
+                )
+        else:
+            if fn_name == "to_tsvector":
+                fn = to_ts_func(regconfig, matchtable.c.title).match("python")
+                expected = (
+                    "to_tsvector($1::REGCONFIG, matchtable.title) @@ "
+                    "plainto_tsquery($2::VARCHAR)"
+                )
+            else:
+                fn = func.to_tsvector(matchtable.c.title).op("@@")(
+                    to_ts_func(regconfig, "python")
+                )
+                expected = (
+                    f"to_tsvector(matchtable.title) @@ "
+                    f"{fn_name}($1::REGCONFIG, $2::VARCHAR)"
+                )
+
+        stmt = matchtable.select().where(fn)
+
+        self.assert_compile(
+            stmt,
+            "SELECT matchtable.id, matchtable.title "
+            f"FROM matchtable WHERE {expected}",
+            dialect="postgresql+asyncpg",
+        )
+
+    @testing.variation("use_regconfig", [True, False, "literal"])
+    @testing.variation("include_options", [True, False])
+    @testing.variation("tsquery_in_expr", [True, False])
+    def test_ts_headline(
+        self, connection, use_regconfig, include_options, tsquery_in_expr
+    ):
+        """test #8977"""
+        if use_regconfig.literal:
+            regconfig = literal("english", REGCONFIG)
+        elif use_regconfig:
+            regconfig = "english"
+        else:
+            regconfig = None
+
+        text = (
+            "The most common type of search is to find all documents "
+            "containing given query terms and return them in order of "
+            "their similarity to the query."
+        )
+        tsquery = func.to_tsquery("english", "query & similarity")
+
+        if regconfig is None:
+            tsquery_str = "to_tsquery($2::REGCONFIG, $3::VARCHAR)"
+        else:
+            tsquery_str = "to_tsquery($3::REGCONFIG, $4::VARCHAR)"
+
+        if tsquery_in_expr:
+            tsquery = case((true(), tsquery), else_=null())
+            tsquery_str = f"CASE WHEN true THEN {tsquery_str} ELSE NULL END"
+
+        is_(tsquery.type._type_affinity, TSQUERY)
+
+        args = [text, tsquery]
+        if regconfig is not None:
+            args.insert(0, regconfig)
+        if include_options:
+            args.append(
+                "MaxFragments=10, MaxWords=7, "
+                "MinWords=3, StartSel=<<, StopSel=>>"
+            )
+
+        fn = func.ts_headline(*args)
+        stmt = select(fn)
+
+        if regconfig is None and not include_options:
+            self.assert_compile(
+                stmt,
+                f"SELECT ts_headline($1::VARCHAR, "
+                f"{tsquery_str}) AS ts_headline_1",
+                dialect="postgresql+asyncpg",
+            )
+        elif regconfig is None and include_options:
+            self.assert_compile(
+                stmt,
+                f"SELECT ts_headline($1::VARCHAR, "
+                f"{tsquery_str}, $4::VARCHAR) AS ts_headline_1",
+                dialect="postgresql+asyncpg",
+            )
+        elif regconfig is not None and not include_options:
+            self.assert_compile(
+                stmt,
+                f"SELECT ts_headline($1::REGCONFIG, $2::VARCHAR, "
+                f"{tsquery_str}) AS ts_headline_1",
+                dialect="postgresql+asyncpg",
+            )
+        else:
+            self.assert_compile(
+                stmt,
+                f"SELECT ts_headline($1::REGCONFIG, $2::VARCHAR, "
+                f"{tsquery_str}, $5::VARCHAR) "
+                "AS ts_headline_1",
+                dialect="postgresql+asyncpg",
+            )
+
 
 class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
     __dialect__ = "postgresql"
index 2b32d6db7fbcdcbe9d51356bb8d052def2196963..7ef7033b2929dd2228faa0554f9c755d488d2353 100644 (file)
@@ -28,6 +28,7 @@ from sqlalchemy import true
 from sqlalchemy import tuple_
 from sqlalchemy.dialects import postgresql
 from sqlalchemy.dialects.postgresql import JSONB
+from sqlalchemy.dialects.postgresql import REGCONFIG
 from sqlalchemy.sql.expression import type_coerce
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import AssertsCompiledSQL
@@ -1005,6 +1006,110 @@ class MatchTest(fixtures.TablesTest, AssertsCompiledSQL):
                 "matchtable.title @@ plainto_tsquery($1)",
             )
 
+    @testing.combinations(
+        (func.to_tsvector,),
+        (func.to_tsquery,),
+        (func.plainto_tsquery,),
+        (func.phraseto_tsquery,),
+        (func.websearch_to_tsquery,),
+        argnames="to_ts_func",
+    )
+    @testing.variation("use_regconfig", [True, False, "literal"])
+    def test_to_regconfig_fns(self, connection, to_ts_func, use_regconfig):
+        """test #8977"""
+
+        matchtable = self.tables.matchtable
+
+        fn_name = to_ts_func().name
+
+        if use_regconfig.literal:
+            regconfig = literal("english", REGCONFIG)
+        elif use_regconfig:
+            regconfig = "english"
+        else:
+            regconfig = None
+
+        if regconfig is None:
+            if fn_name == "to_tsvector":
+                fn = to_ts_func(matchtable.c.title).match("python")
+            else:
+                fn = func.to_tsvector(matchtable.c.title).op("@@")(
+                    to_ts_func("python")
+                )
+        else:
+            if fn_name == "to_tsvector":
+                fn = to_ts_func(regconfig, matchtable.c.title).match("python")
+            else:
+                fn = func.to_tsvector(matchtable.c.title).op("@@")(
+                    to_ts_func(regconfig, "python")
+                )
+
+        stmt = matchtable.select().where(fn).order_by(matchtable.c.id)
+        results = connection.execute(stmt).fetchall()
+        eq_([2, 5], [r.id for r in results])
+
+    @testing.variation("use_regconfig", [True, False, "literal"])
+    @testing.variation("include_options", [True, False])
+    def test_ts_headline(self, connection, use_regconfig, include_options):
+        """test #8977"""
+        if use_regconfig.literal:
+            regconfig = literal("english", REGCONFIG)
+        elif use_regconfig:
+            regconfig = "english"
+        else:
+            regconfig = None
+
+        text = (
+            "The most common type of search is to find all documents "
+            "containing given query terms and return them in order of "
+            "their similarity to the query."
+        )
+        tsquery = func.to_tsquery("english", "query & similarity")
+
+        if regconfig is None:
+            if include_options:
+                fn = func.ts_headline(
+                    text,
+                    tsquery,
+                    "MaxFragments=10, MaxWords=7, MinWords=3, "
+                    "StartSel=<<, StopSel=>>",
+                )
+            else:
+                fn = func.ts_headline(
+                    text,
+                    tsquery,
+                )
+        else:
+            if include_options:
+                fn = func.ts_headline(
+                    regconfig,
+                    text,
+                    tsquery,
+                    "MaxFragments=10, MaxWords=7, MinWords=3, "
+                    "StartSel=<<, StopSel=>>",
+                )
+            else:
+                fn = func.ts_headline(
+                    regconfig,
+                    text,
+                    tsquery,
+                )
+
+        stmt = select(fn)
+
+        if include_options:
+            eq_(
+                connection.scalar(stmt),
+                "documents containing given <<query>> terms and return ... "
+                "their <<similarity>> to the <<query>>",
+            )
+        else:
+            eq_(
+                connection.scalar(stmt),
+                "containing given <b>query</b> terms and return them in "
+                "order of their <b>similarity</b> to the <b>query</b>.",
+            )
+
     def test_simple_match(self, connection):
         matchtable = self.tables.matchtable
         results = connection.execute(