]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- the startswith(), endswith(), and contains() operators
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 31 Jan 2008 03:57:20 +0000 (03:57 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 31 Jan 2008 03:57:20 +0000 (03:57 +0000)
now concatenate the wildcard operator with the given
operand in SQL, i.e. "'%' || <bindparam>" in all cases,
accept text('something') operands properly [ticket:962]

- cast() accepts text('something') and other non-literal
operands properly [ticket:962]

CHANGES
lib/sqlalchemy/sql/expression.py
test/sql/select.py

diff --git a/CHANGES b/CHANGES
index 01f04acd75e61d57e5382a3634be5873850d9762..97f480eb9f37f1b750e4a925c6b7d23d44f41d77 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -9,6 +9,14 @@ CHANGES
       to ILIKE on postgres, lower(x) LIKE lower(y) on all
       others. [ticket:727]
 
+    - the startswith(), endswith(), and contains() operators
+      now concatenate the wildcard operator with the given
+      operand in SQL, i.e. "'%' || <bindparam>" in all cases,
+      accept text('something') operands properly [ticket:962]
+      
+    - cast() accepts text('something') and other non-literal
+      operands properly [ticket:962]
+      
     - The '.c.' attribute on a selectable now gets an entry
       for every column expression in its columns clause.
       Previously, "unnamed" columns like functions and CASE
index aff8654f256e987b81907e968c24808dcf0450b8..6c0c4659ec9a991419340ada0127931bf0ca3a9f 100644 (file)
@@ -640,21 +640,26 @@ def column(text, type_=None):
     return _ColumnClause(text, type_=type_)
 
 def literal_column(text, type_=None):
-    """Return a textual column clause, as would be in the columns clause of a ``SELECT`` statement.
+    """Return a textual column expression, as would be in the columns 
+    clause of a ``SELECT`` statement.
 
-    The object returned is an instance of [sqlalchemy.sql.expression#_ColumnClause],
-    which represents the "syntactical" portion of the schema-level
-    [sqlalchemy.schema#Column] object.
+    The object returned supports further expressions in the same way
+    as any other column object, including comparison, math and string
+    operations.  The type_ parameter is important to determine proper
+    expression behavior (such as, '+' means string concatenation or
+    numerical addition based on the type).
 
     text
-      the name of the column.  Quoting rules will not be applied to
-      the column.  For textual column constructs that should be quoted
-      like any other column construct, use the
-      [sqlalchemy.sql.expression#column()] function.
+      the text of the expression; can be any SQL expression.  Quoting rules 
+      will not be applied.  To specify a column-name expression which should
+      be subject to quoting rules, use the [sqlalchemy.sql.expression#column()] 
+      function.
 
-    type
+    type_
       an optional [sqlalchemy.types#TypeEngine] object which will
-      provide result-set translation for this column.
+      provide result-set translation and additional expression 
+      semantics for this column.  If left as None the type will be
+      NullType.
     """
 
     return _ColumnClause(text, type_=type_, is_literal=True)
@@ -1173,7 +1178,7 @@ class ColumnOperators(Operators):
 class _CompareMixin(ColumnOperators):
     """Defines comparison and math operations for ``ClauseElement`` instances."""
 
-    def __compare(self, op, obj, negate=None):
+    def __compare(self, op, obj, negate=None, reverse=False):
         if obj is None or isinstance(obj, _Null):
             if op == operators.eq:
                 return _BinaryExpression(self.expression_element(), null(), operators.is_, negate=operators.isnot)
@@ -1183,14 +1188,21 @@ class _CompareMixin(ColumnOperators):
                 raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL")
         else:
             obj = self._check_literal(obj)
-        return _BinaryExpression(self.expression_element(), obj, op, type_=sqltypes.Boolean, negate=negate)
+            
+        if reverse:
+            return _BinaryExpression(obj, self.expression_element(), op, type_=sqltypes.Boolean, negate=negate)
+        else:
+            return _BinaryExpression(self.expression_element(), obj, op, type_=sqltypes.Boolean, negate=negate)
 
-    def __operate(self, op, obj):
+    def __operate(self, op, obj, reverse=False):
         obj = self._check_literal(obj)
 
         type_ = self._compare_type(obj)
 
-        return _BinaryExpression(self.expression_element(), obj, type_.adapt_operator(op), type_=type_)
+        if reverse:
+            return _BinaryExpression(obj, self.expression_element(), type_.adapt_operator(op), type_=type_)
+        else:
+            return _BinaryExpression(self.expression_element(), obj, type_.adapt_operator(op), type_=type_)
 
     # a mapping of operators with the method they use, along with their negated
     # operator for comparison operators
@@ -1216,7 +1228,8 @@ class _CompareMixin(ColumnOperators):
         return o[0](self, op, other[0], *o[1:])
 
     def reverse_operate(self, op, other):
-        return self._bind_param(other).operate(op, self)
+        o = _CompareMixin.operators[op]
+        return o[0](self, op, other, reverse=True, *o[1:])
 
     def in_(self, *other):
         return self._in_impl(operators.in_op, operators.notin_op, *other)
@@ -1251,29 +1264,18 @@ class _CompareMixin(ColumnOperators):
     def startswith(self, other):
         """Produce the clause ``LIKE '<other>%'``"""
 
-        perc = isinstance(other, basestring) and '%' or literal('%', type_=sqltypes.String)
-        return self.__compare(operators.like_op, other + perc)
+        # use __radd__ to force string concat behavior
+        return self.__compare(operators.like_op, literal_column("'%'", type_=sqltypes.String).__radd__(self._check_literal(other)))
 
     def endswith(self, other):
         """Produce the clause ``LIKE '%<other>'``"""
 
-        if isinstance(other, basestring):
-            po = '%' + other
-        else:
-            po = literal('%', type_=sqltypes.String) + other
-            po.type = sqltypes.to_instance(sqltypes.String)     #force!
-        return self.__compare(operators.like_op, po)
+        return self.__compare(operators.like_op, literal_column("'%'", type_=sqltypes.String) + self._check_literal(other))
 
     def contains(self, other):
         """Produce the clause ``LIKE '%<other>%'``"""
 
-        if isinstance(other, basestring):
-            po = '%' + other + '%'
-        else:
-            perc = literal('%', type_=sqltypes.String)
-            po = perc + other + perc
-            po.type = sqltypes.to_instance(sqltypes.String)     #force!
-        return self.__compare(operators.like_op, po)
+        return self.__compare(operators.like_op, literal_column("'%'", type_=sqltypes.String) + self._check_literal(other) + literal_column("'%'", type_=sqltypes.String))
 
     def label(self, name):
         """Produce a column label, i.e. ``<columnname> AS <name>``.
@@ -2030,10 +2032,8 @@ class _Cast(ColumnElement):
 
     def __init__(self, clause, totype, **kwargs):
         ColumnElement.__init__(self)
-        if not hasattr(clause, 'label'):
-            clause = literal(clause)
         self.type = sqltypes.to_instance(totype)
-        self.clause = clause
+        self.clause = _literal_as_binds(clause, None)
         self.typeclause = _TypeClause(self.type)
 
     def _copy_internals(self, clone=_clone):
index c34cec7c516c18ebed4062ce6710edf8d3a326f5..522f9a2ffd5d80ed4196e70a6e083084d65cadba 100644 (file)
@@ -451,29 +451,33 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
         clause = (table1.c.myid == 12) & table1.c.myid.between(15, 20) & table1.c.myid.like('hoho')
         assert str(clause) == str(util.pickle.loads(util.pickle.dumps(clause)))
 
-
-
-    def testextracomparisonoperators(self):
+    def test_composed_string_comparators(self):
         self.assert_compile(
-            table1.select(table1.c.name.contains('jo')),
-            "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.name LIKE :mytable_name_1",
-            checkparams = {'mytable_name_1': u'%jo%'},
+            table1.c.name.contains('jo'), "mytable.name LIKE '%' || :mytable_name_1 || '%'" , checkparams = {'mytable_name_1': u'jo'},
         )
         self.assert_compile(
-            table1.select(table1.c.name.endswith('hn')),
-            "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.name LIKE :mytable_name_1",
-            checkparams = {'mytable_name_1': u'%hn'},
+            table1.c.name.contains('jo'), "mytable.name LIKE concat(concat('%', %s), '%')" , checkparams = {'mytable_name_1': u'jo'},
+            dialect=mysql.dialect()
+        )
+        self.assert_compile(
+            table1.c.name.endswith('hn'), "mytable.name LIKE '%' || :mytable_name_1", checkparams = {'mytable_name_1': u'hn'},
+        )
+        self.assert_compile(
+            table1.c.name.endswith('hn'), "mytable.name LIKE concat('%', %s)",
+            checkparams = {'mytable_name_1': u'hn'}, dialect=mysql.dialect()
         )
-
-    def testunicodestartswith(self):
-        string = u"hi \xf6 \xf5"
         self.assert_compile(
-            table1.select(table1.c.name.startswith(string)),
-            "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.name LIKE :mytable_name_1",
-            checkparams = {'mytable_name_1': u'hi \xf6 \xf5%'},
+            table1.c.name.startswith(u"hi \xf6 \xf5"), "mytable.name LIKE :mytable_name_1 || '%'",
+            checkparams = {'mytable_name_1': u'hi \xf6 \xf5'},
         )
+        self.assert_compile(column('name').endswith(text("'foo'")), "name LIKE '%' || 'foo'"  )
+        self.assert_compile(column('name').endswith(literal_column("'foo'")), "name LIKE '%' || 'foo'"  )
+        self.assert_compile(column('name').startswith(text("'foo'")), "name LIKE 'foo' || '%'"  )
+        self.assert_compile(column('name').startswith(text("'foo'")), "name LIKE concat('foo', '%')", dialect=mysql.dialect())
+        self.assert_compile(column('name').startswith(literal_column("'foo'")), "name LIKE 'foo' || '%'"  )
+        self.assert_compile(column('name').startswith(literal_column("'foo'")), "name LIKE concat('foo', '%')", dialect=mysql.dialect())
 
-    def testmultiparam(self):
+    def test_multiple_col_binds(self):
         self.assert_compile(
             select(["*"], or_(table1.c.myid == 12, table1.c.myid=='asdf', table1.c.myid == 'foo')),
             "SELECT * FROM mytable WHERE mytable.myid = :mytable_myid_1 OR mytable.myid = :mytable_myid_2 OR mytable.myid = :mytable_myid_3"
@@ -1067,14 +1071,14 @@ EXISTS (select yay from foo where boo = lar)",
             assert str(err) == "Bind parameter 'mytable_myid_1' conflicts with unique bind parameter of the same name"
 
 
-    def testbindascol(self):
+    def test_bind_as_col(self):
         t = table('foo', column('id'))
 
         s = select([t, literal('lala').label('hoho')])
         self.assert_compile(s, "SELECT foo.id, :param_1 AS hoho FROM foo")
         assert [str(c) for c in s.c] == ["id", "hoho"]
 
-    def testin(self):
+    def test_in(self):
         self.assert_compile(select([table1], table1.c.myid.in_(['a'])),
         "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:mytable_myid_1)")
 
@@ -1179,7 +1183,7 @@ UNION SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE
         self.assert_compile(select([table1], table1.c.myid.in_()),
         "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE (CASE WHEN (mytable.myid IS NULL) THEN NULL ELSE 0 END = 1)")
 
-    def testcast(self):
+    def test_cast(self):
         tbl = table('casttest',
                     column('id', Integer),
                     column('v1', Float),
@@ -1215,7 +1219,11 @@ UNION SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE
         # then the MySQL engine
         check_results(mysql.dialect(), ['DECIMAL(10, 2)', 'DECIMAL(12, 9)', 'DATE', 'CHAR', 'CHAR(20)'], '%s')
 
-    def testdatebetween(self):
+        self.assert_compile(cast(text('NULL'), Integer), "CAST(NULL AS INTEGER)", dialect=sqlite.dialect())
+        self.assert_compile(cast(null(), Integer), "CAST(NULL AS INTEGER)", dialect=sqlite.dialect())
+        self.assert_compile(cast(literal_column('NULL'), Integer), "CAST(NULL AS INTEGER)", dialect=sqlite.dialect())
+        
+    def test_date_between(self):
         import datetime
         table = Table('dt', metadata,
             Column('date', Date))