From 48cc2c0535cd7821ff186c03b6eb08399c0330df Mon Sep 17 00:00:00 2001 From: Anton Kovalevich Date: Wed, 31 Mar 2021 15:22:39 +0300 Subject: [PATCH] Change API --- lib/sqlalchemy/dialects/mysql/base.py | 48 ++++++-- lib/sqlalchemy/dialects/mysql/expression.py | 103 +++++++++++++----- .../dialects/mysql/expression_enum.py | 11 -- 3 files changed, 110 insertions(+), 52 deletions(-) delete mode 100644 lib/sqlalchemy/dialects/mysql/expression_enum.py diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 8eb2250ac7..264d238f20 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -928,13 +928,13 @@ output:: from array import array as _array from collections import defaultdict +from itertools import compress import re from sqlalchemy import literal_column from sqlalchemy import text from sqlalchemy.sql import visitors from . import reflection as _reflection -from .expression_enum import MatchExpressionModifier from .enumerated import ENUM from .enumerated import SET from .json import JSON @@ -1588,24 +1588,50 @@ class MySQLCompiler(compiler.SQLCompiler): self.process(binary.right, **kw), ) + match_valid_flag_combinations = frozenset(( + # (boolean_mode, natural_language, query_expansion) + (False, False, False), + (True, False, False), + (False, True, False), + (False, False, True), + (False, True, True), + )) + + match_flag_expressions = ( + 'IN BOOLEAN MODE', + 'IN NATURAL LANGUAGE MODE', + 'WITH QUERY EXPANSION', + ) + def visit_match_op_binary(self, binary, operator, **kw): - modifier = kw.pop('modifier', MatchExpressionModifier.in_boolean_mode) + """ + Note that `mysql_boolean_mode` is enabled by default because of + backward compatibility + """ - match_clause = self.process(binary.left, **kw) - against_clause = self.process(binary.right, **kw) + boolean_mode = kw.pop('mysql_boolean_mode', True) + natural_language = kw.pop('mysql_natural_language', False) + query_expansion = kw.pop('mysql_query_expansion', False) - if modifier: - if not isinstance(modifier, MatchExpressionModifier): - raise exc.CompileError( + flag_combination = (boolean_mode, natural_language, query_expansion) + + if flag_combination not in self.match_valid_flag_combinations: + raise exc.CompileError( "The `modifier` keyword argument must be a member of " "`sqlalchemy.mysql.expression_enum." "MatchExpressionModifier` enum or `None`" ) - against_clause = ' '.join(( - against_clause, - modifier.value, - )) + match_clause = self.process(binary.left, **kw) + against_clause = self.process(binary.right, **kw) + + if any(flag_combination): + flag_expressions = compress( + self.match_flag_expressions, + flag_combination, + ) + against_clause = (against_clause, *flag_expressions) + against_clause = ' '.join(against_clause) return "MATCH (%s) AGAINST (%s)" % (match_clause, against_clause) diff --git a/lib/sqlalchemy/dialects/mysql/expression.py b/lib/sqlalchemy/dialects/mysql/expression.py index d3e1c14c7b..1568258bf8 100644 --- a/lib/sqlalchemy/dialects/mysql/expression.py +++ b/lib/sqlalchemy/dialects/mysql/expression.py @@ -1,57 +1,69 @@ -from ...sql.elements import ClauseElementBatch +from functools import wraps +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.sql.elements import ( + ColumnElement, + ClauseElementBatch, +) -def match(*clauselist, against, modifier=None, **kwargs): + +def property_enables_flag(flag_name): + def wrapper(target): + @property + @wraps(target) + def inner(self): + new_flags = self.flags.copy() + new_flags[flag_name] = True + + return match( + self.clause, + against=self.against, + flags=new_flags, + ) + + return inner + return wrapper + + +class match(ColumnElement): """Produce a ``MATCH (X, Y) AGAINST ('TEXT')`` clause. E.g.:: + from sqlalchemy import desc from sqlalchemy.mysql.dialects.mysql.expression import match - from sqlalchemy.mysql.dialects.mysql.expression_enum \ - import MatchExpressionModifier - - match_columns_where = match( - users_table.c.firstname, - users_table.c.lastname, - against="John Connor", - modifier=MatchExpressionModifier.in_boolean_mode, - ) - - match_columns_order = match( + match_expr = match( users_table.c.firstname, users_table.c.lastname, against="John Connor", ) stmt = select(users_table)\ - .where(match_columns_where)\ - .order_by(match_columns_order) + .where(match_expr.in_boolean_mode)\ + .order_by(desc(match_expr)) Would produce SQL resembling:: - SELECT id, firstname, lastname FROM user - WHERE MATCH(firstname, lastname) - AGAINST (:param_1 IN BOOLEAN MODE) - ORDER BY MATCH(firstname, lastname) AGAINST (:param_2) + SELECT id, firstname, lastname + FROM user + WHERE MATCH(firstname, lastname) AGAINST (:param_1 IN BOOLEAN MODE) + ORDER BY MATCH(firstname, lastname) AGAINST (:param_2) DESC The :func:`.match` function is a standalone version of the :meth:`_expression.ColumnElement.match` method available on all SQL expressions, as when :meth:`_expression.ColumnElement.match` is used, but allows to pass multiple columns - All positional arguments passed to :func:`.match`, should - be :class:`_expression.ColumnElement` subclass. - - :param clauselist: a column iterator, typically a + All positional arguments passed to :func:`.match`, typically should be a :class:`_expression.ColumnElement` instances or alternatively a Python - scalar expression to be coerced into a column expression, - serving as the ``MATCH`` side of expression. + scalar expression to be coerced into a column expression, serving as + the ``MATCH`` side of expression. - :param modifier: ``None`` or member of - :class:`.expression_enum.MatchExpressionModifier`. + :param against: typically scalar expression to be coerced into a ``str``, + but may be a :class:`_expression.ColumnElement` instance - : + :param flags: optional ``dict`` .. versionadded:: 1.4.4 @@ -61,5 +73,36 @@ def match(*clauselist, against, modifier=None, **kwargs): """ - clause_batch = ClauseElementBatch(*clauselist, group=False) - return clause_batch.match(against, modifier=modifier, **kwargs) + default_flags = { + 'mysql_boolean_mode': False, + 'mysql_natural_language': False, + 'mysql_query_expansion': False, + } + + def __init__(self, *clauselist, against, flags=None): + if len(clauselist) == 1: + self.clause = clauselist[0] + else: + self.clause = ClauseElementBatch(*clauselist, group=False) + + self.against = against + self.flags = flags or self.default_flags.copy() + + @property_enables_flag('mysql_boolean_mode') + def in_boolean_mode(self): ... + + @property_enables_flag('mysql_natural_language') + def in_natural_language_mode(self): ... + + @property_enables_flag('mysql_query_expansion') + def with_query_expansion(self): ... + + +@compiles(match, "mysql") +def visit_match(element: match, compiler, **kw): + target = element.clause.match( + element.against, + **element.flags + ) + + return compiler.process(target, **kw) diff --git a/lib/sqlalchemy/dialects/mysql/expression_enum.py b/lib/sqlalchemy/dialects/mysql/expression_enum.py deleted file mode 100644 index 1548dba07a..0000000000 --- a/lib/sqlalchemy/dialects/mysql/expression_enum.py +++ /dev/null @@ -1,11 +0,0 @@ -import enum - - -class MatchExpressionModifier(enum.Enum): - in_natural_language_mode = 'IN NATURAL LANGUAGE MODE' - - in_natural_language_mode_with_query_expansion = \ - 'IN NATURAL LANGUAGE MODE WITH QUERY EXPANSION' - - in_boolean_mode = 'IN BOOLEAN MODE' - with_query_expansion = 'WITH QUERY EXPANSION' -- 2.47.3