git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Implement MySQL-specific MATCH
author Anton Kovalevich <kai3341@gmail.com>
Fri, 18 Jun 2021 14:33:48 +0000 (10:33 -0400)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Mon, 21 Jun 2021 20:39:52 +0000 (16:39 -0400)
Added new construct :class:`_mysql.match`, which provides for the full
range of MySQL's MATCH operator including multiple column support and
modifiers. Pull request courtesy Anton Kovalevich.

Fixes: #6132
Closes: #6133
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/6133
Pull-request-sha: dc6842f13688849a848e2ecbb81600e6edf8b3a9

Change-Id: I66bbfd7947aa2e43a031772e9b5ae238d94e5223

doc/build/changelog/unreleased_14/6132.rst [new file with mode: 0644]
doc/build/dialects/mysql.rst
lib/sqlalchemy/dialects/mysql/__init__.py
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/mysql/expression.py [new file with mode: 0644]
lib/sqlalchemy/sql/operators.py
test/dialect/mysql/test_compiler.py

diff --git a/doc/build/changelog/unreleased_14/6132.rst b/doc/build/changelog/unreleased_14/6132.rst
new file mode 100644 (file)
index 0000000..59964e6
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: usecase, mysql
+    :tickets: 6132
+
+    Added new construct :class:`_mysql.match`, which provides for the full
+    range of MySQL's MATCH operator including multiple column support and
+    modifiers. Pull request courtesy Anton Kovalevich.
+
+    .. seealso::
+
+        :class:`_mysql.match`
index c0bfa7bc6220ef55582e7dbe91d873782e70c42f..573c2598c0eda2fc8ab23763660c490ffcce55a0 100644 (file)
@@ -5,6 +5,14 @@ MySQL and MariaDB
 
 .. automodule:: sqlalchemy.dialects.mysql.base
 
+MySQL SQL Constructs
+--------------------
+
+.. currentmodule:: sqlalchemy.dialects.mysql
+
+.. autoclass:: match
+    :members:
+
 MySQL Data Types
 ----------------
 
index 20dd68d8f0bfebd38cc11dc8fe9a10c818f91f08..4db05984c25257da90088427cc526f8fd2e758b7 100644 (file)
@@ -49,6 +49,7 @@ from .base import VARCHAR
 from .base import YEAR
 from .dml import Insert
 from .dml import insert
+from .expression import match
 from ...util import compat
 
 if compat.py3k:
@@ -99,4 +100,5 @@ __all__ = (
     "dialect",
     "insert",
     "Insert",
+    "match",
 )
index 92023b3b2db2a217f3716c982186306437c26f23..5ebc83a75bd3fc3face257c22cddc3a60c8b7264 100644 (file)
@@ -443,6 +443,15 @@ available.
 
     select(...).with_hint(some_table, "USE INDEX xyz")
 
+* MATCH operator support::
+
+    from sqlalchemy.dialects.mysql import match
+    select(...).where(match(col1, col2, against="some expr").in_boolean_mode())
+
+  .. seealso::
+
+      :class:`_mysql.match`
+
 .. _mysql_insert_on_duplicate_key_update:
 
 INSERT...ON DUPLICATE KEY UPDATE (Upsert)
@@ -928,6 +937,7 @@ output::
 
 from array import array as _array
 from collections import defaultdict
+from itertools import compress
 import re
 
 from sqlalchemy import literal_column
@@ -1583,11 +1593,67 @@ class MySQLCompiler(compiler.SQLCompiler):
             self.process(binary.right, **kw),
         )
 
-    def visit_match_op_binary(self, binary, operator, **kw):
-        return "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % (
-            self.process(binary.left, **kw),
-            self.process(binary.right, **kw),
+    # Flag tuples that visit_match_op_binary accepts; any other combination
+    # makes it raise CompileError.
+    _match_valid_flag_combinations = frozenset(
+        (
+            # (boolean_mode, natural_language, query_expansion)
+            (False, False, False),
+            (True, False, False),
+            (False, True, False),
+            (False, False, True),
+            (False, True, True),
         )
+    )
+
+    # SQL keyword for each flag, in the same order as the flag tuples above,
+    # so that itertools.compress can pick the enabled ones positionally.
+    _match_flag_expressions = (
+        "IN BOOLEAN MODE",
+        "IN NATURAL LANGUAGE MODE",
+        "WITH QUERY EXPANSION",
+    )
+
+    def visit_mysql_match(self, element, **kw):
+        """Compile a ``mysql_match`` element by delegating to
+        :meth:`visit_match_op_binary` with the element's own operator.
+        """
+        match_operator = element.operator
+        return self.visit_match_op_binary(element, match_operator, **kw)
+
+    def visit_match_op_binary(self, binary, operator, **kw):
+        """Compile a MATCH expression to ``MATCH (...) AGAINST (...)``.
+
+        The modifier flags are read from ``binary.modifiers``.  When the
+        ``mysql_boolean_mode`` key is absent it is treated as enabled, for
+        backward compatibility with the previous ``IN BOOLEAN MODE``
+        rendering.
+
+        Raises ``exc.CompileError`` if the resulting flag combination is
+        not listed in ``_match_valid_flag_combinations``.
+        """
+
+        opts = binary.modifiers
+        bool_mode = opts.get("mysql_boolean_mode", True)
+        nat_lang = opts.get("mysql_natural_language", False)
+        expansion = opts.get("mysql_query_expansion", False)
+        chosen = (bool_mode, nat_lang, expansion)
+
+        if chosen not in self._match_valid_flag_combinations:
+            described = ", ".join(
+                (
+                    "in_boolean_mode=%s" % bool_mode,
+                    "in_natural_language_mode=%s" % nat_lang,
+                    "with_query_expansion=%s" % expansion,
+                )
+            )
+            raise exc.CompileError(
+                "Invalid MySQL match flags: %s" % described
+            )
+
+        # the column list is compiled before the AGAINST expression
+        columns_sql = self.process(binary.left, **kw)
+        against_sql = self.process(binary.right, **kw)
+
+        if any(chosen):
+            # keep only the keywords whose flag is set, in declaration order
+            keywords = compress(self._match_flag_expressions, chosen)
+            against_sql = " ".join([against_sql] + list(keywords))
+
+        return "MATCH (%s) AGAINST (%s)" % (columns_sql, against_sql)
 
     def get_from_hint_text(self, table, text):
         return text
diff --git a/lib/sqlalchemy/dialects/mysql/expression.py b/lib/sqlalchemy/dialects/mysql/expression.py
new file mode 100644 (file)
index 0000000..d6ef80e
--- /dev/null
@@ -0,0 +1,130 @@
+from ... import exc
+from ... import util
+from ...sql import coercions
+from ...sql import elements
+from ...sql import operators
+from ...sql import roles
+from ...sql.base import _generative
+from ...sql.base import Generative
+
+
+class match(Generative, elements.BinaryExpression):
+    """Produce a ``MATCH (X, Y) AGAINST ('TEXT')`` clause.
+
+    E.g.::
+
+        from sqlalchemy import desc
+        from sqlalchemy.dialects.mysql import match
+
+        match_expr = match(
+            users_table.c.firstname,
+            users_table.c.lastname,
+            against="Firstname Lastname",
+        )
+
+        stmt = (
+            select(users_table)
+            .where(match_expr.in_boolean_mode())
+            .order_by(desc(match_expr))
+        )
+
+    Would produce SQL resembling::
+
+        SELECT id, firstname, lastname
+        FROM user
+        WHERE MATCH (firstname, lastname) AGAINST (:param_1 IN BOOLEAN MODE)
+        ORDER BY MATCH (firstname, lastname) AGAINST (:param_2) DESC
+
+    The :func:`_mysql.match` function is a standalone version of the
+    :meth:`_expression.ColumnElement.match` method available on all
+    SQL expressions.  Unlike that method, it accepts multiple columns and
+    offers generative methods for the MySQL-specific modifiers.
+
+    :param cols: column expressions to match against; at least one is
+     required
+
+    :param against: expression to be compared against; required, and
+     coerced to a SQL expression element
+
+    :param in_boolean_mode: boolean, set "boolean mode" to true
+
+    :param in_natural_language_mode: boolean, set "natural language" to true
+
+    :param with_query_expansion: boolean, set "query expansion" to true
+
+    :raises ArgumentError: if no columns are given, if ``against`` is
+     omitted, or if unknown keyword arguments are passed
+
+    .. versionadded:: 1.4.19
+
+    .. seealso::
+
+        :meth:`_expression.ColumnElement.match`
+
+    """
+
+    __visit_name__ = "mysql_match"
+
+    inherit_cache = True
+
+    def __init__(self, *cols, **kw):
+        if not cols:
+            raise exc.ArgumentError("columns are required")
+
+        against = kw.pop("against", None)
+
+        # Compare with None explicitly: truth-testing would reject an empty
+        # search string and is not defined for SQL expression objects.
+        if against is None:
+            raise exc.ArgumentError("against is required")
+        against = coercions.expect(
+            roles.ExpressionElementRole,
+            against,
+        )
+
+        left = elements.BooleanClauseList._construct_raw(
+            operators.comma_op,
+            clauses=cols,
+        )
+        # NOTE(review): presumably avoids an extra pair of parentheses
+        # around the column list, since the compiler's "MATCH (%s)"
+        # template supplies its own -- confirm.
+        left.group = False
+
+        # every flag defaults to False here, so a bare match() renders
+        # without any modifier
+        flags = util.immutabledict(
+            {
+                "mysql_boolean_mode": kw.pop("in_boolean_mode", False),
+                "mysql_natural_language": kw.pop(
+                    "in_natural_language_mode", False
+                ),
+                "mysql_query_expansion": kw.pop("with_query_expansion", False),
+            }
+        )
+
+        if kw:
+            raise exc.ArgumentError("unknown arguments: %s" % (", ".join(kw)))
+
+        super(match, self).__init__(
+            left, against, operators.match_op, modifiers=flags
+        )
+
+    @_generative
+    def in_boolean_mode(self):
+        """Apply the "IN BOOLEAN MODE" modifier to the MATCH expression.
+
+        :return: a new :class:`_mysql.match` instance with modifications
+         applied.
+        """
+
+        self.modifiers = self.modifiers.union({"mysql_boolean_mode": True})
+
+    @_generative
+    def in_natural_language_mode(self):
+        """Apply the "IN NATURAL LANGUAGE MODE" modifier to the MATCH
+        expression.
+
+        :return: a new :class:`_mysql.match` instance with modifications
+         applied.
+        """
+
+        self.modifiers = self.modifiers.union({"mysql_natural_language": True})
+
+    @_generative
+    def with_query_expansion(self):
+        """Apply the "WITH QUERY EXPANSION" modifier to the MATCH expression.
+
+        :return: a new :class:`_mysql.match` instance with modifications
+         applied.
+        """
+
+        self.modifiers = self.modifiers.union({"mysql_query_expansion": True})
index 60f03195cda8d9cedcf514a22297fe043088ac53..408a505aafbffa1bcd0f401d6634e97a2ae8d26a 100644 (file)
@@ -954,6 +954,12 @@ class ColumnOperators(Operators):
 
         * PostgreSQL - renders ``x @@ to_tsquery(y)``
         * MySQL - renders ``MATCH (x) AGAINST (y IN BOOLEAN MODE)``
+
+          .. seealso::
+
+                :class:`_mysql.match` - MySQL specific construct with
+                additional features.
+
         * Oracle - renders ``CONTAINS(x, y)``
         * other backends may provide special implementations.
         * Backends without any special implementation will emit
index 8d311fb6c8f668692a4a8d9e02d294c0054b6bc9..b0a6ee333cfbae4e0d5078c8730016e67b25efa9 100644 (file)
@@ -50,6 +50,7 @@ from sqlalchemy import UnicodeText
 from sqlalchemy import VARCHAR
 from sqlalchemy.dialects.mysql import base as mysql
 from sqlalchemy.dialects.mysql import insert
+from sqlalchemy.dialects.mysql import match
 from sqlalchemy.sql import column
 from sqlalchemy.sql import table
 from sqlalchemy.sql.expression import literal_column
@@ -415,21 +416,6 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             dialect=mysql.dialect(),
         )
 
-    def test_match(self):
-        matchtable = table("matchtable", column("title", String))
-        self.assert_compile(
-            matchtable.c.title.match("somstr"),
-            "MATCH (matchtable.title) AGAINST (%s IN BOOLEAN MODE)",
-        )
-
-    def test_match_compile_kw(self):
-        expr = literal("x").match(literal("y"))
-        self.assert_compile(
-            expr,
-            "MATCH ('x') AGAINST ('y' IN BOOLEAN MODE)",
-            literal_binds=True,
-        )
-
     def test_concat_compile_kw(self):
         expr = literal("x", type_=String) + literal("y", type_=String)
         self.assert_compile(expr, "concat('x', 'y')", literal_binds=True)
@@ -1207,3 +1193,151 @@ class RegexpTestMariaDb(fixtures.TestBase, RegexpCommon):
             "REGEXP_REPLACE(mytable.myid, CONCAT('(?', %s, ')', %s), %s)",
             checkpositional=("ig", "pattern", "replacement"),
         )
+
+
+class MatchExpressionTest(fixtures.TestBase, AssertsCompiledSQL):
+    """Compilation tests for MATCH ... AGAINST on the MySQL dialect.
+
+    Covers both the generic ``ColumnElement.match()`` operator with
+    ``mysql_*`` modifiers and the multi-column ``mysql.match`` construct.
+    """
+
+    __dialect__ = mysql.dialect()
+
+    match_table = table(
+        "user",
+        column("firstname", String),
+        column("lastname", String),
+    )
+
+    @testing.combinations(
+        (
+            lambda title: title.match("somstr", mysql_boolean_mode=False),
+            "MATCH (matchtable.title) AGAINST (%s)",
+        ),
+        (
+            lambda title: title.match(
+                "somstr",
+                mysql_boolean_mode=False,
+                mysql_natural_language=True,
+            ),
+            "MATCH (matchtable.title) AGAINST (%s IN NATURAL LANGUAGE MODE)",
+        ),
+        (
+            lambda title: title.match(
+                "somstr",
+                mysql_boolean_mode=False,
+                mysql_query_expansion=True,
+            ),
+            "MATCH (matchtable.title) AGAINST (%s WITH QUERY EXPANSION)",
+        ),
+        (
+            lambda title: title.match(
+                "somstr",
+                mysql_boolean_mode=False,
+                mysql_natural_language=True,
+                mysql_query_expansion=True,
+            ),
+            "MATCH (matchtable.title) AGAINST "
+            "(%s IN NATURAL LANGUAGE MODE WITH QUERY EXPANSION)",
+        ),
+    )
+    def test_match_expression_single_col(self, case, expected):
+        matchtable = table("matchtable", column("title", String))
+        title = matchtable.c.title
+
+        expr = case(title)
+        self.assert_compile(expr, expected)
+
+    @testing.combinations(
+        (
+            lambda expr: expr,
+            "MATCH (user.firstname, user.lastname) AGAINST (%s)",
+        ),
+        (
+            lambda expr: expr.in_boolean_mode(),
+            "MATCH (user.firstname, user.lastname) AGAINST "
+            "(%s IN BOOLEAN MODE)",
+        ),
+        (
+            lambda expr: expr.in_natural_language_mode(),
+            "MATCH (user.firstname, user.lastname) AGAINST "
+            "(%s IN NATURAL LANGUAGE MODE)",
+        ),
+        (
+            lambda expr: expr.with_query_expansion(),
+            "MATCH (user.firstname, user.lastname) AGAINST "
+            "(%s WITH QUERY EXPANSION)",
+        ),
+        (
+            lambda expr: (
+                expr.in_natural_language_mode().with_query_expansion()
+            ),
+            "MATCH (user.firstname, user.lastname) AGAINST "
+            "(%s IN NATURAL LANGUAGE MODE WITH QUERY EXPANSION)",
+        ),
+    )
+    def test_match_expression_multiple_cols(self, case, expected):
+        firstname = self.match_table.c.firstname
+        lastname = self.match_table.c.lastname
+
+        expr = match(firstname, lastname, against="Firstname Lastname")
+
+        expr = case(expr)
+        self.assert_compile(expr, expected)
+
+    def test_cols_required(self):
+        assert_raises_message(
+            exc.ArgumentError,
+            "columns are required",
+            match,
+            against="Firstname Lastname",
+        )
+
+    @testing.combinations(
+        (True, False, True), (True, True, False), (True, True, True)
+    )
+    def test_invalid_combinations(
+        self, boolean_mode, natural_language, query_expansion
+    ):
+        firstname = self.match_table.c.firstname
+        lastname = self.match_table.c.lastname
+
+        # construction succeeds; the flag combination is only rejected
+        # when the expression is compiled
+        expr = match(
+            firstname,
+            lastname,
+            against="Firstname Lastname",
+            in_boolean_mode=boolean_mode,
+            in_natural_language_mode=natural_language,
+            with_query_expansion=query_expansion,
+        )
+        msg = (
+            "Invalid MySQL match flags: "
+            "in_boolean_mode=%s, "
+            "in_natural_language_mode=%s, "
+            "with_query_expansion=%s"
+        ) % (boolean_mode, natural_language, query_expansion)
+
+        assert_raises_message(
+            exc.CompileError,
+            msg,
+            expr.compile,
+            dialect=self.__dialect__,
+        )
+
+    def test_match_operator(self):
+        # the generic operator renders IN BOOLEAN MODE when no modifiers
+        # are given
+        matchtable = table("matchtable", column("title", String))
+        self.assert_compile(
+            matchtable.c.title.match("somstr"),
+            "MATCH (matchtable.title) AGAINST (%s IN BOOLEAN MODE)",
+        )
+
+    def test_literal_binds(self):
+        expr = literal("x").match(literal("y"))
+        self.assert_compile(
+            expr,
+            "MATCH ('x') AGAINST ('y' IN BOOLEAN MODE)",
+            literal_binds=True,
+        )