]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- fixed precedence of operators so that parenthesis are correctly applied
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 25 Jun 2007 17:07:25 +0000 (17:07 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 25 Jun 2007 17:07:25 +0000 (17:07 +0000)
[ticket:620]
- calling <column>.in_() (i.e. with no arguments) will return
"CASE WHEN (<column> IS NULL) THEN NULL ELSE 0 END = 1)", so that
NULL or False is returned in all cases, rather than throwing an error
[ticket:545]

CHANGES
lib/sqlalchemy/sql.py
test/sql/query.py
test/sql/select.py

diff --git a/CHANGES b/CHANGES
index dbeffbe5dae7a8e7276e2278c58d1631558132db..1b5af7a51cab04835dd67e5a5ac40dca8892ec67 100644 (file)
--- a/CHANGES
+++ b/CHANGES
       to polymorphic mappers that are using a straight "outerjoin"
       clause
 - sql
+    - fixed precedence of operators so that parenthesis are correctly applied
+      [ticket:620]
+    - calling <column>.in_() (i.e. with no arguments) will return 
+      "CASE WHEN (<column> IS NULL) THEN NULL ELSE 0 END = 1)", so that 
+      NULL or False is returned in all cases, rather than throwing an error
+      [ticket:545]
     - fixed "where"/"from" criterion of select() to accept a unicode string
       in addition to regular string - both convert to text()
     - added standalone distinct() function in addition to column.distinct()
index 5ceb9bdea3beaa0fdae8e403aac699107cc7a383..9bea33946e76d71465f15a2d6c3d714c3d0b832a 100644 (file)
@@ -40,22 +40,37 @@ __all__ = ['AbstractDialect', 'Alias', 'ClauseElement', 'ClauseParameters',
            'subquery', 'table', 'text', 'union', 'union_all', 'update',]
 
 # precedence ordering for common operators.  if an operator is not present in this list,
-# its precedence is assumed to be '0' which will cause it to be parenthesized when grouped against other operators
+# it will be parenthesized when grouped against other operators
 PRECEDENCE = {
     'FROM':15,
-    'AS':15,
-    'NOT':10,
+    '*':7,
+    '/':7,
+       '%':7,
+    '+':6,
+    '-':6,
+    'ILIKE':5,
+    'NOT ILIKE':5,
+    'LIKE':5,
+    'NOT LIKE':5,
+    'IN':5,
+    'NOT IN':5,
+    'IS':5,
+    'IS NOT':5,
+    '=':5,
+    '!=':5,
+    '>':5,
+    '<':5,
+    '>=':5,
+    '<=':5,
+    'NOT':4,
     'AND':3,
-    'OR':3,
-    '=':7,
-    '!=':7,
-    '>':7,
-    '<':7,
-    '+':5,
-    '-':5,
-    '*':5,
-    '/':5,
-    ',':0
+    'OR':2,
+    ',':-1,
+    'AS':-1,
+    'EXISTS':0,
+    'BETWEEN':0,
+    '_smallest': -1000,
+    '_largest': 1000
 }
 
 def desc(column):
@@ -1286,7 +1301,7 @@ class _CompareMixin(object):
     def in_(self, *other):
         """produce an ``IN`` clause."""
         if len(other) == 0:
-            return self.__eq__(None)
+            return _Grouping(case([(self.__eq__(None), text('NULL'))], else_=text('0')).__eq__(text('1')))
         elif len(other) == 1:
             o = other[0]
             if _is_literal(o) or isinstance( o, _CompareMixin):
@@ -1965,7 +1980,7 @@ class ClauseList(ClauseElement):
         return f
 
     def self_group(self, against=None):
-        if PRECEDENCE.get(self.operator, 0) <= PRECEDENCE.get(against, 0):
+        if self.operator != against and PRECEDENCE.get(self.operator, PRECEDENCE['_smallest']) <= PRECEDENCE.get(against, PRECEDENCE['_largest']):
             return _Grouping(self)
         else:
             return self
@@ -2122,6 +2137,12 @@ class _UnaryExpression(ColumnElement):
             return _UnaryExpression(self.element, operator=self.negate, negate=self.operator, modifier=self.modifier, type=self.type)
         else:
             return super(_UnaryExpression, self)._negate()
+    
+    def self_group(self, against):
+        if self.operator and PRECEDENCE.get(self.operator, PRECEDENCE['_smallest']) <= PRECEDENCE.get(against, PRECEDENCE['_largest']):
+            return _Grouping(self)
+        else:
+            return self
 
 
 class _BinaryExpression(ColumnElement):
@@ -2155,7 +2176,8 @@ class _BinaryExpression(ColumnElement):
         )
         
     def self_group(self, against=None):
-        if PRECEDENCE.get(self.operator, 0) <= PRECEDENCE.get(against, 0):
+        # use small/large defaults for comparison so that unknown operators are always parenthesized
+        if self.operator != against and (PRECEDENCE.get(self.operator, PRECEDENCE['_smallest']) <= PRECEDENCE.get(against, PRECEDENCE['_largest'])):
             return _Grouping(self)
         else:
             return self
index 1c63132b59900cfb98a4370f80023b17a2b271d0..593b392e83676dcabc5ee4c2fa071a08de0926cd 100644 (file)
@@ -480,7 +480,7 @@ class QueryTest(PersistTest):
             tr.commit()
             con.execute("""drop trigger paj""")
             meta.drop_all()
-
+    
     @testbase.supported('mssql')
     def test_insertid_schema(self):
         meta = BoundMetaData(testbase.db)
@@ -493,6 +493,55 @@ class QueryTest(PersistTest):
         finally:
             tbl.drop()
             con.execute('drop schema paj')
+    
+    def test_in_filtering(self):
+        """test the 'shortname' field on BindParamClause."""
+        self.users.insert().execute(user_id = 7, user_name = 'jack')
+        self.users.insert().execute(user_id = 8, user_name = 'fred')
+        self.users.insert().execute(user_id = 9, user_name = None)
+        
+        s = self.users.select(self.users.c.user_name.in_())
+        r = s.execute().fetchall()
+        # No username is in empty set
+        assert len(r) == 0
+        
+        s = self.users.select(not_(self.users.c.user_name.in_()))
+        r = s.execute().fetchall()
+        # All usernames with a value are outside an empty set
+        assert len(r) == 2
+        
+        s = self.users.select(self.users.c.user_name.in_('jack','fred'))
+        r = s.execute().fetchall()
+        assert len(r) == 2
+        
+        s = self.users.select(not_(self.users.c.user_name.in_('jack','fred')))
+        r = s.execute().fetchall()
+        # Null values are not outside any set
+        assert len(r) == 0
+        
+        u = bindparam('search_key')
+        
+        s = self.users.select(u.in_())
+        r = s.execute(search_key='john').fetchall()
+        assert len(r) == 0
+        r = s.execute(search_key=None).fetchall()
+        assert len(r) == 0
+        
+        s = self.users.select(not_(u.in_()))
+        r = s.execute(search_key='john').fetchall()
+        assert len(r) == 3
+        r = s.execute(search_key=None).fetchall()
+        assert len(r) == 0
+        
+        s = self.users.select(self.users.c.user_name.in_() == True)
+        r = s.execute().fetchall()
+        assert len(r) == 0
+        s = self.users.select(self.users.c.user_name.in_() == False)
+        r = s.execute().fetchall()
+        assert len(r) == 2
+        s = self.users.select(self.users.c.user_name.in_() == None)
+        r = s.execute().fetchall()
+        assert len(r) == 1
         
 
 class CompoundTest(PersistTest):
index 01fbd5cc851ac60a1e4b4317ea32da644baa8c90..7ae830e6aeda312bdc5697bdcefa80b131556756 100644 (file)
@@ -263,11 +263,11 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
     def testoperators(self):
         self.runtest(
             table1.select((table1.c.myid != 12) & ~(table1.c.name=='john')), 
-            "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid != :mytable_myid AND NOT (mytable.name = :mytable_name)"
+            "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid != :mytable_myid AND NOT mytable.name = :mytable_name"
         )
         
         self.runtest(
-            literal("a") + literal("b") * literal("c"), ":literal + (:literal_1 * :literal_2)"
+            literal("a") + literal("b") * literal("c"), ":literal + :literal_1 * :literal_2"
         )
 
         # exercise arithmetic operators
@@ -527,12 +527,12 @@ FROM mytable, myothertable WHERE foo.id = foofoo(lala) AND datetime(foo) = Today
          self.runtest(
              select([value_tbl.c.id], (value_tbl.c.val2 -
      value_tbl.c.val1)/value_tbl.c.val1 > 2.0),
-             "SELECT values.id FROM values WHERE ((values.val2 - values.val1) / values.val1) > :literal"
+             "SELECT values.id FROM values WHERE (values.val2 - values.val1) / values.val1 > :literal"
          )
 
          self.runtest(
              select([value_tbl.c.id], value_tbl.c.val1 / (value_tbl.c.val2 - value_tbl.c.val1) /value_tbl.c.val1 > 2.0),
-             "SELECT values.id FROM values WHERE ((values.val1 / (values.val2 - values.val1)) / values.val1) > :literal"
+             "SELECT values.id FROM values WHERE values.val1 / (values.val2 - values.val1) / values.val1 > :literal"
          )
          
     def testfunction(self):
@@ -809,7 +809,7 @@ myothertable.othername != :myothertable_othername OR EXISTS (select yay from foo
         "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:mytable_myid, :literal)")
 
         self.runtest(select([table1], table1.c.myid.in_(literal('a') + 'a')),
-        "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = (:literal + :literal_1)")
+        "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :literal + :literal_1")
 
         self.runtest(select([table1], table1.c.myid.in_(literal('a') +'a', 'b')),
         "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:literal + :literal_1, :mytable_myid)")
@@ -868,6 +868,10 @@ UNION SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE
             "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername FROM mytable JOIN myothertable ON mytable.myid = myothertable.otherid WHERE myothertable.otherid IN (SELECT myothertable.otherid FROM myothertable ORDER BY myothertable.othername  LIMIT 10) ORDER BY mytable.myid"
         )
         
+        # test empty in clause
+        self.runtest(select([table1], table1.c.myid.in_()),
+        "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE (CASE WHEN (mytable.myid IS NULL) THEN NULL ELSE 0 END = 1)")
+        
     
     def testlateargs(self):
         """tests that a SELECT clause will have extra "WHERE" clauses added to it at compile time if extra arguments
@@ -916,6 +920,26 @@ UNION SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE
         self.runtest(table.select(table.c.date.between(datetime.date(2006,6,1), datetime.date(2006,6,5))), "SELECT dt.date FROM dt WHERE dt.date BETWEEN :dt_date AND :dt_date_1", checkparams={'dt_date':datetime.date(2006,6,1), 'dt_date_1':datetime.date(2006,6,5)})
 
         self.runtest(table.select(sql.between(table.c.date, datetime.date(2006,6,1), datetime.date(2006,6,5))), "SELECT dt.date FROM dt WHERE dt.date BETWEEN :literal AND :literal_1", checkparams={'literal':datetime.date(2006,6,1), 'literal_1':datetime.date(2006,6,5)})
+    
+    def test_operator_precedence(self):
+        table = Table('op', metadata,
+            Column('field', Integer))
+        self.runtest(table.select((table.c.field == 5) == None),
+            "SELECT op.field FROM op WHERE (op.field = :op_field) IS NULL")
+        self.runtest(table.select((table.c.field + 5) == table.c.field),
+            "SELECT op.field FROM op WHERE op.field + :op_field = op.field")
+        self.runtest(table.select((table.c.field + 5) * 6),
+            "SELECT op.field FROM op WHERE (op.field + :op_field) * :literal")
+        self.runtest(table.select((table.c.field * 5) + 6),
+            "SELECT op.field FROM op WHERE op.field * :op_field + :literal")
+        self.runtest(table.select(5 + table.c.field.in_(5,6)),
+            "SELECT op.field FROM op WHERE :literal + (op.field IN (:op_field, :op_field_1))")
+        self.runtest(table.select((5 + table.c.field).in_(5,6)),
+            "SELECT op.field FROM op WHERE :op_field + op.field IN (:literal, :literal_1)")
+        self.runtest(table.select(not_(table.c.field == 5)),
+            "SELECT op.field FROM op WHERE NOT op.field = :op_field")
+        self.runtest(table.select(not_(table.c.field) == 5),
+            "SELECT op.field FROM op WHERE (NOT op.field) = :literal")
 
 class CRUDTest(SQLTest):
     def testinsert(self):
@@ -964,7 +988,7 @@ class CRUDTest(SQLTest):
             values = {
             table1.c.name : table1.c.name + "lala",
             table1.c.myid : func.do_stuff(table1.c.myid, literal('hoho'))
-            }), "UPDATE mytable SET myid=do_stuff(mytable.myid, :literal_2), name=mytable.name + :mytable_name WHERE mytable.myid = hoho(:hoho) AND mytable.name = ((:literal + mytable.name) + :literal_1)")
+            }), "UPDATE mytable SET myid=do_stuff(mytable.myid, :literal_2), name=mytable.name + :mytable_name WHERE mytable.myid = hoho(:hoho) AND mytable.name = :literal + mytable.name + :literal_1")
         
     def testcorrelatedupdate(self):
         # test against a straight text subquery