From: Mike Bayer Date: Mon, 25 Jun 2007 17:07:25 +0000 (+0000) Subject: - fixed precedence of operators so that parenthesis are correctly applied X-Git-Tag: rel_0_3_9~66 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=bc58df9c1f1f443b67a3312463df2c9425531503;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - fixed precedence of operators so that parenthesis are correctly applied [ticket:620] - calling .in_() (i.e. with no arguments) will return "CASE WHEN ( IS NULL) THEN NULL ELSE 0 END = 1)", so that NULL or False is returned in all cases, rather than throwing an error [ticket:545] --- diff --git a/CHANGES b/CHANGES index dbeffbe5da..1b5af7a51c 100644 --- a/CHANGES +++ b/CHANGES @@ -25,6 +25,12 @@ to polymorphic mappers that are using a straight "outerjoin" clause - sql + - fixed precedence of operators so that parenthesis are correctly applied + [ticket:620] + - calling .in_() (i.e. with no arguments) will return + "CASE WHEN ( IS NULL) THEN NULL ELSE 0 END = 1)", so that + NULL or False is returned in all cases, rather than throwing an error + [ticket:545] - fixed "where"/"from" criterion of select() to accept a unicode string in addition to regular string - both convert to text() - added standalone distinct() function in addition to column.distinct() diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 5ceb9bdea3..9bea33946e 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -40,22 +40,37 @@ __all__ = ['AbstractDialect', 'Alias', 'ClauseElement', 'ClauseParameters', 'subquery', 'table', 'text', 'union', 'union_all', 'update',] # precedence ordering for common operators. if an operator is not present in this list, -# its precedence is assumed to be '0' which will cause it to be parenthesized when grouped against other operators +# it will be parenthesized when grouped against other operators PRECEDENCE = { 'FROM':15, - 'AS':15, - 'NOT':10, + '*':7, + '/':7, + '%':7, + '+':6, + '-':6, + 'ILIKE':5, + 'NOT ILIKE':5, + 'LIKE':5, + 'NOT LIKE':5, + 'IN':5, + 'NOT IN':5, + 'IS':5, + 'IS NOT':5, + '=':5, + '!=':5, + '>':5, + '<':5, + '>=':5, + '<=':5, + 'NOT':4, 'AND':3, - 'OR':3, - '=':7, - '!=':7, - '>':7, - '<':7, - '+':5, - '-':5, - '*':5, - '/':5, - ',':0 + 'OR':2, + ',':-1, + 'AS':-1, + 'EXISTS':0, + 'BETWEEN':0, + '_smallest': -1000, + '_largest': 1000 } def desc(column): @@ -1286,7 +1301,7 @@ class _CompareMixin(object): def in_(self, *other): """produce an ``IN`` clause.""" if len(other) == 0: - return self.__eq__(None) + return _Grouping(case([(self.__eq__(None), text('NULL'))], else_=text('0')).__eq__(text('1'))) elif len(other) == 1: o = other[0] if _is_literal(o) or isinstance( o, _CompareMixin): @@ -1965,7 +1980,7 @@ class ClauseList(ClauseElement): return f def self_group(self, against=None): - if PRECEDENCE.get(self.operator, 0) <= PRECEDENCE.get(against, 0): + if self.operator != against and PRECEDENCE.get(self.operator, PRECEDENCE['_smallest']) <= PRECEDENCE.get(against, PRECEDENCE['_largest']): return _Grouping(self) else: return self @@ -2122,6 +2137,12 @@ class _UnaryExpression(ColumnElement): return _UnaryExpression(self.element, operator=self.negate, negate=self.operator, modifier=self.modifier, type=self.type) else: return super(_UnaryExpression, self)._negate() + + def self_group(self, against): + if self.operator and PRECEDENCE.get(self.operator, PRECEDENCE['_smallest']) <= PRECEDENCE.get(against, PRECEDENCE['_largest']): + return _Grouping(self) + else: + return self class _BinaryExpression(ColumnElement): @@ -2155,7 +2176,8 @@ class _BinaryExpression(ColumnElement): ) def self_group(self, against=None): - if PRECEDENCE.get(self.operator, 0) <= PRECEDENCE.get(against, 0): + # use small/large defaults for comparison so that unknown operators are always parenthesized + if self.operator != against and (PRECEDENCE.get(self.operator, PRECEDENCE['_smallest']) <= PRECEDENCE.get(against, PRECEDENCE['_largest'])): return _Grouping(self) else: return self diff --git a/test/sql/query.py b/test/sql/query.py index 1c63132b59..593b392e83 100644 --- a/test/sql/query.py +++ b/test/sql/query.py @@ -480,7 +480,7 @@ class QueryTest(PersistTest): tr.commit() con.execute("""drop trigger paj""") meta.drop_all() - + @testbase.supported('mssql') def test_insertid_schema(self): meta = BoundMetaData(testbase.db) @@ -493,6 +493,55 @@ class QueryTest(PersistTest): finally: tbl.drop() con.execute('drop schema paj') + + def test_in_filtering(self): + """test the 'shortname' field on BindParamClause.""" + self.users.insert().execute(user_id = 7, user_name = 'jack') + self.users.insert().execute(user_id = 8, user_name = 'fred') + self.users.insert().execute(user_id = 9, user_name = None) + + s = self.users.select(self.users.c.user_name.in_()) + r = s.execute().fetchall() + # No username is in empty set + assert len(r) == 0 + + s = self.users.select(not_(self.users.c.user_name.in_())) + r = s.execute().fetchall() + # All usernames with a value are outside an empty set + assert len(r) == 2 + + s = self.users.select(self.users.c.user_name.in_('jack','fred')) + r = s.execute().fetchall() + assert len(r) == 2 + + s = self.users.select(not_(self.users.c.user_name.in_('jack','fred'))) + r = s.execute().fetchall() + # Null values are not outside any set + assert len(r) == 0 + + u = bindparam('search_key') + + s = self.users.select(u.in_()) + r = s.execute(search_key='john').fetchall() + assert len(r) == 0 + r = s.execute(search_key=None).fetchall() + assert len(r) == 0 + + s = self.users.select(not_(u.in_())) + r = s.execute(search_key='john').fetchall() + assert len(r) == 3 + r = s.execute(search_key=None).fetchall() + assert len(r) == 0 + + s = self.users.select(self.users.c.user_name.in_() == True) + r = s.execute().fetchall() + assert len(r) == 0 + s = self.users.select(self.users.c.user_name.in_() == False) + r = s.execute().fetchall() + assert len(r) == 2 + s = self.users.select(self.users.c.user_name.in_() == None) + r = s.execute().fetchall() + assert len(r) == 1 class CompoundTest(PersistTest): diff --git a/test/sql/select.py b/test/sql/select.py index 01fbd5cc85..7ae830e6ae 100644 --- a/test/sql/select.py +++ b/test/sql/select.py @@ -263,11 +263,11 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A def testoperators(self): self.runtest( table1.select((table1.c.myid != 12) & ~(table1.c.name=='john')), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid != :mytable_myid AND NOT (mytable.name = :mytable_name)" + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid != :mytable_myid AND NOT mytable.name = :mytable_name" ) self.runtest( - literal("a") + literal("b") * literal("c"), ":literal + (:literal_1 * :literal_2)" + literal("a") + literal("b") * literal("c"), ":literal + :literal_1 * :literal_2" ) # exercise arithmetic operators @@ -527,12 +527,12 @@ FROM mytable, myothertable WHERE foo.id = foofoo(lala) AND datetime(foo) = Today self.runtest( select([value_tbl.c.id], (value_tbl.c.val2 - value_tbl.c.val1)/value_tbl.c.val1 > 2.0), - "SELECT values.id FROM values WHERE ((values.val2 - values.val1) / values.val1) > :literal" + "SELECT values.id FROM values WHERE (values.val2 - values.val1) / values.val1 > :literal" ) self.runtest( select([value_tbl.c.id], value_tbl.c.val1 / (value_tbl.c.val2 - value_tbl.c.val1) /value_tbl.c.val1 > 2.0), - "SELECT values.id FROM values WHERE ((values.val1 / (values.val2 - values.val1)) / values.val1) > :literal" + "SELECT values.id FROM values WHERE values.val1 / (values.val2 - values.val1) / values.val1 > :literal" ) def testfunction(self): @@ -809,7 +809,7 @@ myothertable.othername != :myothertable_othername OR EXISTS (select yay from foo "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:mytable_myid, :literal)") self.runtest(select([table1], table1.c.myid.in_(literal('a') + 'a')), - "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = (:literal + :literal_1)") + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :literal + :literal_1") self.runtest(select([table1], table1.c.myid.in_(literal('a') +'a', 'b')), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:literal + :literal_1, :mytable_myid)") @@ -868,6 +868,10 @@ UNION SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername FROM mytable JOIN myothertable ON mytable.myid = myothertable.otherid WHERE myothertable.otherid IN (SELECT myothertable.otherid FROM myothertable ORDER BY myothertable.othername LIMIT 10) ORDER BY mytable.myid" ) + # test empty in clause + self.runtest(select([table1], table1.c.myid.in_()), + "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE (CASE WHEN (mytable.myid IS NULL) THEN NULL ELSE 0 END = 1)") + def testlateargs(self): """tests that a SELECT clause will have extra "WHERE" clauses added to it at compile time if extra arguments @@ -916,6 +920,26 @@ UNION SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE self.runtest(table.select(table.c.date.between(datetime.date(2006,6,1), datetime.date(2006,6,5))), "SELECT dt.date FROM dt WHERE dt.date BETWEEN :dt_date AND :dt_date_1", checkparams={'dt_date':datetime.date(2006,6,1), 'dt_date_1':datetime.date(2006,6,5)}) self.runtest(table.select(sql.between(table.c.date, datetime.date(2006,6,1), datetime.date(2006,6,5))), "SELECT dt.date FROM dt WHERE dt.date BETWEEN :literal AND :literal_1", checkparams={'literal':datetime.date(2006,6,1), 'literal_1':datetime.date(2006,6,5)}) + + def test_operator_precedence(self): + table = Table('op', metadata, + Column('field', Integer)) + self.runtest(table.select((table.c.field == 5) == None), + "SELECT op.field FROM op WHERE (op.field = :op_field) IS NULL") + self.runtest(table.select((table.c.field + 5) == table.c.field), + "SELECT op.field FROM op WHERE op.field + :op_field = op.field") + self.runtest(table.select((table.c.field + 5) * 6), + "SELECT op.field FROM op WHERE (op.field + :op_field) * :literal") + self.runtest(table.select((table.c.field * 5) + 6), + "SELECT op.field FROM op WHERE op.field * :op_field + :literal") + self.runtest(table.select(5 + table.c.field.in_(5,6)), + "SELECT op.field FROM op WHERE :literal + (op.field IN (:op_field, :op_field_1))") + self.runtest(table.select((5 + table.c.field).in_(5,6)), + "SELECT op.field FROM op WHERE :op_field + op.field IN (:literal, :literal_1)") + self.runtest(table.select(not_(table.c.field == 5)), + "SELECT op.field FROM op WHERE NOT op.field = :op_field") + self.runtest(table.select(not_(table.c.field) == 5), + "SELECT op.field FROM op WHERE (NOT op.field) = :literal") class CRUDTest(SQLTest): def testinsert(self): @@ -964,7 +988,7 @@ class CRUDTest(SQLTest): values = { table1.c.name : table1.c.name + "lala", table1.c.myid : func.do_stuff(table1.c.myid, literal('hoho')) - }), "UPDATE mytable SET myid=do_stuff(mytable.myid, :literal_2), name=mytable.name + :mytable_name WHERE mytable.myid = hoho(:hoho) AND mytable.name = ((:literal + mytable.name) + :literal_1)") + }), "UPDATE mytable SET myid=do_stuff(mytable.myid, :literal_2), name=mytable.name + :mytable_name WHERE mytable.myid = hoho(:hoho) AND mytable.name = :literal + mytable.name + :literal_1") def testcorrelatedupdate(self): # test against a straight text subquery