]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- [feature] Reworked the startswith(), endswith(),
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 27 Aug 2012 23:40:12 +0000 (19:40 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 27 Aug 2012 23:40:12 +0000 (19:40 -0400)
    contains() operators to do a better job with
    negation (NOT LIKE), and also to assemble them
    at compilation time so that their rendered SQL
    can be altered, such as in the case for Firebird
    STARTING WITH [ticket:2470]
  - [feature] firebird - The "startswith()" operator renders
    as "STARTING WITH", "~startswith()" renders
    as "NOT STARTING WITH", using FB's more efficient
    operator.  [ticket:2470]

CHANGES
lib/sqlalchemy/dialects/firebird/base.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/operators.py
test/sql/test_compiler.py
test/sql/test_operators.py

diff --git a/CHANGES b/CHANGES
index b8eff05552b58d733f5b5ead5ef9a888c27f5a62..f34edf19468882ac22ed297e8dbfc66735f86852 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -521,6 +521,13 @@ underneath "0.7.xx".
     name.  The deprecated fold_equivalents() feature is
     removed [ticket:1729].
 
+  - [feature] Reworked the startswith(), endswith(),
+    contains() operators to do a better job with
+    negation (NOT LIKE), and also to assemble them
+    at compilation time so that their rendered SQL
+    can be altered, such as in the case for Firebird
+    STARTING WITH [ticket:2470]
+
   - [bug] Fixes to the interpretation of the
     Column "default" parameter as a callable
     to not pass ExecutionContext into a keyword
@@ -600,6 +607,11 @@ underneath "0.7.xx".
     The phrase is established using with_hint().
     Courtesy Ryan Kelly [ticket:2506]
 
+- firebird
+  - [feature] The "startswith()" operator renders
+    as "STARTING WITH", "~startswith()" renders
+    as "NOT STARTING WITH", using FB's more efficient
+    operator.  [ticket:2470]
 
 - mysql
   - [bug] Dialect no longer emits expensive server
index f7877a901f3d5fa1fdacb31b0a3237562406deeb..b4b856804e8fa4a775ecd9f224e7c32ca3b7b5a8 100644 (file)
@@ -200,6 +200,22 @@ class FBTypeCompiler(compiler.GenericTypeCompiler):
 class FBCompiler(sql.compiler.SQLCompiler):
     """Firebird specific idiosyncrasies"""
 
+    #def visit_contains_op_binary(self, binary, operator, **kw):
+        # cant use CONTAINING b.c. it's case insensitive.
+
+    #def visit_notcontains_op_binary(self, binary, operator, **kw):
+        # cant use NOT CONTAINING b.c. it's case insensitive.
+
+    def visit_startswith_op_binary(self, binary, operator, **kw):
+        return '%s STARTING WITH %s' % (
+                            binary.left._compiler_dispatch(self, **kw),
+                            binary.right._compiler_dispatch(self, **kw))
+
+    def visit_notstartswith_op_binary(self, binary, operator, **kw):
+        return '%s NOT STARTING WITH %s' % (
+                            binary.left._compiler_dispatch(self, **kw),
+                            binary.right._compiler_dispatch(self, **kw))
+
     def visit_mod_binary(self, binary, operator, **kw):
         return "mod(%s, %s)" % (
                                 self.process(binary.left, **kw),
@@ -265,9 +281,9 @@ class FBCompiler(sql.compiler.SQLCompiler):
 
         result = ""
         if select._limit:
-            result += "FIRST %s "  % self.process(sql.literal(select._limit))
+            result += "FIRST %s " % self.process(sql.literal(select._limit))
         if select._offset:
-            result +="SKIP %s "  %  self.process(sql.literal(select._offset))
+            result += "SKIP %s " % self.process(sql.literal(select._offset))
         if select._distinct:
             result += "DISTINCT "
         return result
index 8e4f0288f8dd70efb7b5e64d28874e6a791ad289..297cd9adb0154c54e069bd9bc8602d177772ea07 100644 (file)
@@ -24,7 +24,7 @@ To generate user-defined SQL strings, see
 
 import re
 import sys
-from .. import schema, engine, util, exc
+from .. import schema, engine, util, exc, types
 from . import (
     operators, functions, util as sql_util, visitors, expression as sql
 )
@@ -670,6 +670,50 @@ class SQLCompiler(engine.Compiled):
     def _generate_generic_unary_modifier(self, unary, opstring, **kw):
         return unary.element._compiler_dispatch(self, **kw) + opstring
 
+    @util.memoized_property
+    def _like_percent_literal(self):
+        return sql.literal_column("'%'", type_=types.String())
+
+    def visit_contains_op_binary(self, binary, operator, **kw):
+        binary = binary._clone()
+        percent = self._like_percent_literal
+        binary.right = percent.__add__(binary.right).__add__(percent)
+        return self.visit_like_op_binary(binary, operator, **kw)
+
+    def visit_notcontains_op_binary(self, binary, operator, **kw):
+        binary = binary._clone()
+        percent = self._like_percent_literal
+        binary.right = percent.__add__(binary.right).__add__(percent)
+        return self.visit_notlike_op_binary(binary, operator, **kw)
+
+    def visit_startswith_op_binary(self, binary, operator, **kw):
+        binary = binary._clone()
+        percent = self._like_percent_literal
+        binary.right = percent.__radd__(
+                    binary.right
+                )
+        return self.visit_like_op_binary(binary, operator, **kw)
+
+    def visit_notstartswith_op_binary(self, binary, operator, **kw):
+        binary = binary._clone()
+        percent = self._like_percent_literal
+        binary.right = percent.__radd__(
+                    binary.right
+                )
+        return self.visit_notlike_op_binary(binary, operator, **kw)
+
+    def visit_endswith_op_binary(self, binary, operator, **kw):
+        binary = binary._clone()
+        percent = self._like_percent_literal
+        binary.right = percent.__add__(binary.right)
+        return self.visit_like_op_binary(binary, operator, **kw)
+
+    def visit_notendswith_op_binary(self, binary, operator, **kw):
+        binary = binary._clone()
+        percent = self._like_percent_literal
+        binary.right = percent.__add__(binary.right)
+        return self.visit_notlike_op_binary(binary, operator, **kw)
+
     def visit_like_op_binary(self, binary, operator, **kw):
         escape = binary.modifiers.get("escape", None)
         return '%s LIKE %s' % (
index 0e8a46b600f2edc80de8dca3c5e5d5d5f80b047c..2583e6510827ddf0afd28ebe0c9fb879562ca3de 100644 (file)
@@ -2049,37 +2049,6 @@ class _DefaultColumnComparator(operators.ColumnOperators):
         """See :meth:`.ColumnOperators.__neg__`."""
         return UnaryExpression(expr, operator=operators.neg)
 
-    def _startswith_impl(self, expr, op, other, escape=None, **kw):
-        """See :meth:`.ColumnOperators.startswith`."""
-        # use __radd__ to force string concat behavior
-        return self._boolean_compare(
-            expr,
-            operators.like_op,
-            literal_column("'%'", type_=sqltypes.String).__radd__(
-                                self._check_literal(expr,
-                                        operators.like_op, other)
-                            ),
-            escape=escape)
-
-    def _endswith_impl(self, expr, op, other, escape=None, **kw):
-        """See :meth:`.ColumnOperators.endswith`."""
-        return self._boolean_compare(
-            expr,
-            operators.like_op,
-            literal_column("'%'", type_=sqltypes.String) +
-                self._check_literal(expr, operators.like_op, other),
-            escape=escape)
-
-    def _contains_impl(self, expr, op, other, escape=None, **kw):
-        """See :meth:`.ColumnOperators.contains`."""
-        return self._boolean_compare(
-            expr,
-            operators.like_op,
-            literal_column("'%'", type_=sqltypes.String) +
-                self._check_literal(expr, operators.like_op, other) +
-                literal_column("'%'", type_=sqltypes.String),
-            escape=escape)
-
     def _match_impl(self, expr, op, other, **kw):
         """See :meth:`.ColumnOperators.match`."""
         return self._boolean_compare(expr, operators.match_op,
@@ -2124,6 +2093,9 @@ class _DefaultColumnComparator(operators.ColumnOperators):
         "eq": (_boolean_compare, operators.ne),
         "like_op": (_boolean_compare, operators.notlike_op),
         "ilike_op": (_boolean_compare, operators.notilike_op),
+        "contains_op": (_boolean_compare, operators.notcontains_op),
+        "startswith_op": (_boolean_compare, operators.notstartswith_op),
+        "endswith_op": (_boolean_compare, operators.notendswith_op),
         "desc_op": (_scalar, desc),
         "asc_op": (_scalar, asc),
         "nullsfirst_op": (_scalar, nullsfirst),
@@ -2133,9 +2105,6 @@ class _DefaultColumnComparator(operators.ColumnOperators):
         "match_op": (_match_impl,),
         "distinct_op": (_distinct_impl,),
         "between_op": (_between_impl, ),
-        "contains_op": (_contains_impl, ),
-        "startswith_op": (_startswith_impl,),
-        "endswith_op": (_endswith_impl,),
         "neg": (_neg_impl,),
         "getitem": (_unsupported_impl,),
     }
index f1607c884d0f2a3c680b8e45257b2eb0885e3580..ba33d016aa3a22945446b8af6e01cc3538c21c23 100644 (file)
@@ -558,12 +558,21 @@ def distinct_op(a):
 def startswith_op(a, b, escape=None):
     return a.startswith(b, escape=escape)
 
+def notstartswith_op(a, b, escape=None):
+    return ~a.startswith(b, escape=escape)
+
 def endswith_op(a, b, escape=None):
     return a.endswith(b, escape=escape)
 
+def notendswith_op(a, b, escape=None):
+    return ~a.endswith(b, escape=escape)
+
 def contains_op(a, b, escape=None):
     return a.contains(b, escape=escape)
 
+def notcontains_op(a, b, escape=None):
+    return ~a.contains(b, escape=escape)
+
 def match_op(a, b):
     return a.match(b)
 
index 40d29f2220799cf018ced3cc02f17341bbead03d..356f2e8b1f056646b5c519aa274dc7c42730ef2d 100644 (file)
@@ -1029,61 +1029,6 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL):
         ]:
             self.assert_compile(expr, check, dialect=dialect)
 
-    def test_composed_string_comparators(self):
-        self.assert_compile(
-            table1.c.name.contains('jo'),
-            "mytable.name LIKE '%%' || :name_1 || '%%'" ,
-            checkparams = {'name_1': u'jo'},
-        )
-        self.assert_compile(
-            table1.c.name.contains('jo'),
-            "mytable.name LIKE concat(concat('%%', %s), '%%')" ,
-            checkparams = {'name_1': u'jo'},
-            dialect=mysql.dialect()
-        )
-        self.assert_compile(
-            table1.c.name.contains('jo', escape='\\'),
-            "mytable.name LIKE '%%' || :name_1 || '%%' ESCAPE '\\'" ,
-            checkparams = {'name_1': u'jo'},
-        )
-        self.assert_compile(
-            table1.c.name.startswith('jo', escape='\\'),
-            "mytable.name LIKE :name_1 || '%%' ESCAPE '\\'" )
-        self.assert_compile(
-            table1.c.name.endswith('jo', escape='\\'),
-            "mytable.name LIKE '%%' || :name_1 ESCAPE '\\'" )
-        self.assert_compile(
-            table1.c.name.endswith('hn'),
-            "mytable.name LIKE '%%' || :name_1",
-            checkparams = {'name_1': u'hn'}, )
-        self.assert_compile(
-            table1.c.name.endswith('hn'),
-            "mytable.name LIKE concat('%%', %s)",
-            checkparams = {'name_1': u'hn'}, dialect=mysql.dialect()
-        )
-        self.assert_compile(
-            table1.c.name.startswith(u"hi \xf6 \xf5"),
-            "mytable.name LIKE :name_1 || '%%'",
-            checkparams = {'name_1': u'hi \xf6 \xf5'},
-        )
-        self.assert_compile(
-                column('name').endswith(text("'foo'")),
-                "name LIKE '%%' || 'foo'"  )
-        self.assert_compile(
-                column('name').endswith(literal_column("'foo'")),
-                "name LIKE '%%' || 'foo'"  )
-        self.assert_compile(
-                column('name').startswith(text("'foo'")),
-                "name LIKE 'foo' || '%%'"  )
-        self.assert_compile(
-                column('name').startswith(text("'foo'")),
-                 "name LIKE concat('foo', '%%')", dialect=mysql.dialect())
-        self.assert_compile(
-                column('name').startswith(literal_column("'foo'")),
-                "name LIKE 'foo' || '%%'"  )
-        self.assert_compile(
-                column('name').startswith(literal_column("'foo'")),
-                "name LIKE concat('foo', '%%')", dialect=mysql.dialect())
 
     def test_multiple_col_binds(self):
         self.assert_compile(
index 26a36fd3456af2fcac4b3de6b95c5b216ec225ce..69a22172f02e312d068c6ba498e60c46fd9ea174 100644 (file)
@@ -2,12 +2,15 @@ from test.lib import fixtures, testing
 from test.lib.testing import assert_raises_message
 from sqlalchemy.sql import column, desc, asc, literal, collate
 from sqlalchemy.sql.expression import BinaryExpression, \
-                ClauseList, Grouping, _DefaultColumnComparator,\
+                ClauseList, Grouping, \
                 UnaryExpression
 from sqlalchemy.sql import operators
 from sqlalchemy import exc
 from sqlalchemy.schema import Column, Table, MetaData
 from sqlalchemy.types import Integer, TypeEngine, TypeDecorator
+from sqlalchemy.dialects import mysql, firebird
+
+from sqlalchemy import text, literal_column
 
 class DefaultColumnComparatorTest(fixtures.TestBase):
 
@@ -320,3 +323,244 @@ class OperatorAssociativityTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         self.assert_compile(f / (f / (f - f)), "f / (f / (f - f))")
 
 
+class ComposedLikeOperatorsTest(fixtures.TestBase, testing.AssertsCompiledSQL):
+    __dialect__ = 'default'
+
+    def test_contains(self):
+        self.assert_compile(
+            column('x').contains('y'),
+            "x LIKE '%%' || :x_1 || '%%'",
+            checkparams={'x_1': 'y'}
+        )
+
+    def test_contains_escape(self):
+        self.assert_compile(
+            column('x').contains('y', escape='\\'),
+            "x LIKE '%%' || :x_1 || '%%' ESCAPE '\\'",
+            checkparams={'x_1': 'y'}
+        )
+
+    def test_contains_literal(self):
+        self.assert_compile(
+            column('x').contains(literal_column('y')),
+            "x LIKE '%%' || y || '%%'",
+            checkparams={}
+        )
+
+    def test_contains_text(self):
+        self.assert_compile(
+            column('x').contains(text('y')),
+            "x LIKE '%%' || y || '%%'",
+            checkparams={}
+        )
+
+    def test_not_contains(self):
+        self.assert_compile(
+            ~column('x').contains('y'),
+            "x NOT LIKE '%%' || :x_1 || '%%'",
+            checkparams={'x_1': 'y'}
+        )
+
+    def test_not_contains_escape(self):
+        self.assert_compile(
+            ~column('x').contains('y', escape='\\'),
+            "x NOT LIKE '%%' || :x_1 || '%%' ESCAPE '\\'",
+            checkparams={'x_1': 'y'}
+        )
+
+    def test_contains_concat(self):
+        self.assert_compile(
+            column('x').contains('y'),
+            "x LIKE concat(concat('%%', %s), '%%')",
+            checkparams={'x_1': 'y'},
+            dialect=mysql.dialect()
+        )
+
+    def test_not_contains_concat(self):
+        self.assert_compile(
+            ~column('x').contains('y'),
+            "x NOT LIKE concat(concat('%%', %s), '%%')",
+            checkparams={'x_1': 'y'},
+            dialect=mysql.dialect()
+        )
+
+    def test_contains_literal_concat(self):
+        self.assert_compile(
+            column('x').contains(literal_column('y')),
+            "x LIKE concat(concat('%%', y), '%%')",
+            checkparams={},
+            dialect=mysql.dialect()
+        )
+
+    def test_contains_text_concat(self):
+        self.assert_compile(
+            column('x').contains(text('y')),
+            "x LIKE concat(concat('%%', y), '%%')",
+            checkparams={},
+            dialect=mysql.dialect()
+        )
+
+    def test_startswith(self):
+        self.assert_compile(
+            column('x').startswith('y'),
+            "x LIKE :x_1 || '%%'",
+            checkparams={'x_1': 'y'}
+        )
+
+    def test_startswith_escape(self):
+        self.assert_compile(
+            column('x').startswith('y', escape='\\'),
+            "x LIKE :x_1 || '%%' ESCAPE '\\'",
+            checkparams={'x_1': 'y'}
+        )
+
+    def test_not_startswith(self):
+        self.assert_compile(
+            ~column('x').startswith('y'),
+            "x NOT LIKE :x_1 || '%%'",
+            checkparams={'x_1': 'y'}
+        )
+
+    def test_not_startswith_escape(self):
+        self.assert_compile(
+            ~column('x').startswith('y', escape='\\'),
+            "x NOT LIKE :x_1 || '%%' ESCAPE '\\'",
+            checkparams={'x_1': 'y'}
+        )
+
+    def test_startswith_literal(self):
+        self.assert_compile(
+            column('x').startswith(literal_column('y')),
+            "x LIKE y || '%%'",
+            checkparams={}
+        )
+
+    def test_startswith_text(self):
+        self.assert_compile(
+            column('x').startswith(text('y')),
+            "x LIKE y || '%%'",
+            checkparams={}
+        )
+
+    def test_startswith_concat(self):
+        self.assert_compile(
+            column('x').startswith('y'),
+            "x LIKE concat(%s, '%%')",
+            checkparams={'x_1': 'y'},
+            dialect=mysql.dialect()
+        )
+
+    def test_not_startswith_concat(self):
+        self.assert_compile(
+            ~column('x').startswith('y'),
+            "x NOT LIKE concat(%s, '%%')",
+            checkparams={'x_1': 'y'},
+            dialect=mysql.dialect()
+        )
+
+    def test_startswith_firebird(self):
+        self.assert_compile(
+            column('x').startswith('y'),
+            "x STARTING WITH :x_1",
+            checkparams={'x_1': 'y'},
+            dialect=firebird.dialect()
+        )
+
+    def test_not_startswith_firebird(self):
+        self.assert_compile(
+            ~column('x').startswith('y'),
+            "x NOT STARTING WITH :x_1",
+            checkparams={'x_1': 'y'},
+            dialect=firebird.dialect()
+        )
+
+    def test_startswith_literal_mysql(self):
+        self.assert_compile(
+            column('x').startswith(literal_column('y')),
+            "x LIKE concat(y, '%%')",
+            checkparams={},
+            dialect=mysql.dialect()
+        )
+
+    def test_startswith_text_mysql(self):
+        self.assert_compile(
+            column('x').startswith(text('y')),
+            "x LIKE concat(y, '%%')",
+            checkparams={},
+            dialect=mysql.dialect()
+        )
+
+    def test_endswith(self):
+        self.assert_compile(
+            column('x').endswith('y'),
+            "x LIKE '%%' || :x_1",
+            checkparams={'x_1': 'y'}
+        )
+
+    def test_endswith_escape(self):
+        self.assert_compile(
+            column('x').endswith('y', escape='\\'),
+            "x LIKE '%%' || :x_1 ESCAPE '\\'",
+            checkparams={'x_1': 'y'}
+        )
+
+    def test_not_endswith(self):
+        self.assert_compile(
+            ~column('x').endswith('y'),
+            "x NOT LIKE '%%' || :x_1",
+            checkparams={'x_1': 'y'}
+        )
+
+    def test_not_endswith_escape(self):
+        self.assert_compile(
+            ~column('x').endswith('y', escape='\\'),
+            "x NOT LIKE '%%' || :x_1 ESCAPE '\\'",
+            checkparams={'x_1': 'y'}
+        )
+
+    def test_endswith_literal(self):
+        self.assert_compile(
+            column('x').endswith(literal_column('y')),
+            "x LIKE '%%' || y",
+            checkparams={}
+        )
+
+    def test_endswith_text(self):
+        self.assert_compile(
+            column('x').endswith(text('y')),
+            "x LIKE '%%' || y",
+            checkparams={}
+        )
+
+    def test_endswith_mysql(self):
+        self.assert_compile(
+            column('x').endswith('y'),
+            "x LIKE concat('%%', %s)",
+            checkparams={'x_1': 'y'},
+            dialect=mysql.dialect()
+        )
+
+    def test_not_endswith_mysql(self):
+        self.assert_compile(
+            ~column('x').endswith('y'),
+            "x NOT LIKE concat('%%', %s)",
+            checkparams={'x_1': 'y'},
+            dialect=mysql.dialect()
+        )
+
+    def test_endswith_literal_mysql(self):
+        self.assert_compile(
+            column('x').endswith(literal_column('y')),
+            "x LIKE concat('%%', y)",
+            checkparams={},
+            dialect=mysql.dialect()
+        )
+
+    def test_endswith_text_mysql(self):
+        self.assert_compile(
+            column('x').endswith(text('y')),
+            "x LIKE concat('%%', y)",
+            checkparams={},
+            dialect=mysql.dialect()
+        )
+