]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Fixed operator precedence rules for multiple
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 25 Nov 2010 17:20:13 +0000 (12:20 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 25 Nov 2010 17:20:13 +0000 (12:20 -0500)
chains of a single non-associative operator.
I.e. "x - (y - z)" will compile as "x - (y - z)"
and not "x - y - z".  Also works with labels,
i.e. "x - (y - z).label('foo')"
[ticket:1984]
- Single element tuple expressions inside an IN clause
parenthesize correctly, also from [ticket:1984],
added tests for PG
- re-fix again importlater, [ticket:1983]

CHANGES
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/operators.py
lib/sqlalchemy/util.py
test/base/test_utils.py
test/dialect/test_postgresql.py
test/sql/test_compiler.py

diff --git a/CHANGES b/CHANGES
index 5bf836ac02534f7a9443823a5b43a12145ba3cae..8e3d0666629a0286096b7843909c9a8689e1f5d0 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -45,6 +45,13 @@ CHANGES
     [ticket:1976]
     
 - sql
+  - Fixed operator precedence rules for multiple
+    chains of a single non-associative operator.
+    I.e. "x - (y - z)" will compile as "x - (y - z)"
+    and not "x - y - z".  Also works with labels,
+    i.e. "x - (y - z).label('foo')"
+    [ticket:1984]
+    
   - The 'info' attribute of Column is copied during 
     Column.copy(), i.e. as occurs when using columns
     in declarative mixins.  [ticket:1967]
@@ -60,6 +67,9 @@ CHANGES
     [ticket:1871]
 
 - postgresql
+  - Single element tuple expressions inside an IN clause
+    parenthesize correctly, also from [ticket:1984]
+    
   - Ensured every numeric, float, int code, scalar + array,
     are recognized by psycopg2 and pg8000's "numeric" 
     base type. [ticket:1955]
index c3dc339a50dd9adaac92e502228e2b0c9767da94..6e22e7ca5753b402faa482b0955a109ee45e5486 100644 (file)
@@ -1289,6 +1289,29 @@ class ClauseElement(Visitable):
         return []
 
     def self_group(self, against=None):
+        """Apply a 'grouping' to this :class:`.ClauseElement`.
+        
+        This method is overridden by subclasses to return a 
+        "grouping" construct, i.e. parenthesis.   In particular
+        it's used by "binary" expressions to provide a grouping
+        around themselves when placed into a larger expression, 
+        as well as by :func:`.select` constructs when placed into
+        the FROM clause of another :func:`.select`.  (Note that 
+        subqueries should be normally created using the 
+        :func:`.Select.alias` method, as many platforms require
+        nested SELECT statements to be named).
+        
+        As expressions are composed together, the application of
+        :meth:`self_group` is automatic - end-user code should never 
+        need to use this method directly.  Note that SQLAlchemy's
+        clause constructs take operator precedence into account - 
+        so parenthesis might not be needed, for example, in 
+        an expression like ``x OR (y AND z)`` - AND takes precedence
+        over OR.
+        
+        The base :meth:`self_group` method of :class:`.ClauseElement`
+        just returns self.
+        """
         return self
 
     # TODO: remove .bind as a method from the root ClauseElement.
@@ -2657,8 +2680,7 @@ class ClauseList(ClauseElement):
         return list(itertools.chain(*[c._from_objects for c in self.clauses]))
 
     def self_group(self, against=None):
-        if self.group and self.operator is not against and \
-                operators.is_precedent(self.operator, against):
+        if self.group and operators.is_precedent(self.operator, against):
             return _Grouping(self)
         else:
             return self
@@ -2984,10 +3006,7 @@ class _BinaryExpression(ColumnElement):
         )
 
     def self_group(self, against=None):
-        # use small/large defaults for comparison so that unknown
-        # operators are always parenthesized
-        if self.operator is not against and \
-                operators.is_precedent(self.operator, against):
+        if operators.is_precedent(self.operator, against):
             return _Grouping(self)
         else:
             return self
@@ -3343,7 +3362,16 @@ class _Label(ColumnElement):
     @util.memoized_property
     def element(self):
         return self._element.self_group(against=operators.as_)
-
+    
+    def self_group(self, against=None):
+        sub_element = self._element.self_group(against=against)
+        if sub_element is not self._element:
+            return _Label(self.name, 
+                        sub_element, 
+                        type_=self._type)
+        else:
+            return self._element
+        
     @property
     def primary_key(self):
         return self.element.primary_key
index 6f70b1778d2cfeb7dfbdee363ac346a759c94f01..68e30e646d88e9a41c3ca0b471b24f0f41d47cd1 100644 (file)
@@ -83,10 +83,14 @@ def desc_op(a):
 def asc_op(a):
     return a.asc()
 
+
 _commutative = set([eq, ne, add, mul])
 def is_commutative(op):
     return op in _commutative
 
+_associative = _commutative.union([concat_op, and_, or_])
+    
+
 _smallest = symbol('_smallest')
 _largest = symbol('_largest')
 
@@ -131,5 +135,8 @@ _PRECEDENCE = {
 }
 
 def is_precedent(operator, against):
-    return (_PRECEDENCE.get(operator, _PRECEDENCE[_smallest]) <=
+    if operator is against and operator in _associative:
+        return False
+    else:
+        return (_PRECEDENCE.get(operator, _PRECEDENCE[_smallest]) <=
             _PRECEDENCE.get(against, _PRECEDENCE[_largest]))
index dafba8250242441dd8422cdae2bd812d64c93d97..59704e41bddea8d6655e186ce1b46a5ac6251991 100644 (file)
@@ -1579,12 +1579,8 @@ class importlater(object):
     @memoized_property
     def _il_module(self):
         if self._il_addtl:
-            m = __import__(self._il_path + "." + self._il_addtl)
-        else:
-            m = __import__(self._il_path)
-        for token in self._il_path.split(".")[1:]:
-            m = getattr(m, token)
-        if self._il_addtl:
+            m = __import__(self._il_path, globals(), locals(), 
+                                [self._il_addtl])
             try:
                 return getattr(m, self._il_addtl)
             except AttributeError:
@@ -1593,6 +1589,9 @@ class importlater(object):
                         (self._il_path, self._il_addtl)
                     )
         else:
+            m = __import__(self._il_path)
+            for token in self._il_path.split(".")[1:]:
+                m = getattr(m, token)
             return m
         
     def __getattr__(self, key):
index d083a8458e9a78717449f8c611e54827262378b7..e7ecbec5126f83517c41c8ff0fc9919437dd0864 100644 (file)
@@ -149,8 +149,6 @@ class ColumnCollectionTest(TestBase):
         assert (cc1==cc2).compare(c1 == c2)
         assert not (cc1==cc3).compare(c2 == c3)
 
-
-
 class LRUTest(TestBase):
 
     def test_lru(self):                
index 150dacf180a99f0e1e5a2ee02e4b0a57dabe1294..92c0894806a303dca1c718febf37b9b230b91d4b 100644 (file)
@@ -2046,3 +2046,34 @@ class MatchTest(TestBase, AssertsCompiledSQL):
                 matchtable.c.title.match('nutshells'
                 )))).order_by(matchtable.c.id).execute().fetchall()
         eq_([1, 3, 5], [r.id for r in results])
+
+
+class TupleTest(TestBase):
+    __only_on__ = 'postgresql'
+    
+    def test_tuple_containment(self):
+        
+        for test, exp in [
+            ([('a', 'b')], True),
+            ([('a', 'c')], False),
+            ([('f', 'q'), ('a', 'b')], True),
+            ([('f', 'q'), ('a', 'c')], False)
+        ]:
+            eq_(
+                testing.db.execute(
+                    select([
+                            tuple_(
+                                literal_column("'a'"), 
+                                literal_column("'b'")
+                            ).\
+                                in_([
+                                    tuple_(*[
+                                            literal_column("'%s'" % letter) 
+                                            for letter in elem
+                                        ]) for elem in test
+                                ])
+                            ])
+                ).scalar(),
+                exp
+            )
+
index 338a5491eed51b200bc51398fe3c9c89ab220190..93c0d6587a99c4b6711a2f648ae219a27fe47e4f 100644 (file)
@@ -678,7 +678,7 @@ class SelectTest(TestBase, AssertsCompiledSQL):
             select([func.count(distinct(table1.c.myid))]), 
             "SELECT count(DISTINCT mytable.myid) AS count_1 FROM mytable"
         )
-
+    
     def test_operators(self):
         for (py_op, sql_op) in ((operator.add, '+'), (operator.mul, '*'),
                                 (operator.sub, '-'), 
@@ -1293,7 +1293,7 @@ class SelectTest(TestBase, AssertsCompiledSQL):
 
          self.assert_compile(
              select([value_tbl.c.id], value_tbl.c.val1 / (value_tbl.c.val2 - value_tbl.c.val1) /value_tbl.c.val1 > 2.0),
-             "SELECT values.id FROM values WHERE values.val1 / (values.val2 - values.val1) / values.val1 > :param_1"
+             "SELECT values.id FROM values WHERE (values.val1 / (values.val2 - values.val1)) / values.val1 > :param_1"
          )
 
     def test_collate(self):
@@ -1925,7 +1925,7 @@ class SelectTest(TestBase, AssertsCompiledSQL):
             tuple_(table1.c.myid, table1.c.name).in_(
                         [tuple_(table2.c.otherid, table2.c.othername)]
                     ),
-            "(mytable.myid, mytable.name) IN (myothertable.otherid, myothertable.othername)"
+            "(mytable.myid, mytable.name) IN ((myothertable.otherid, myothertable.othername))"
         )
         
         self.assert_compile(
@@ -2044,6 +2044,42 @@ class SelectTest(TestBase, AssertsCompiledSQL):
         self.assert_compile(table.select(between((table.c.field == table.c.field), False, True)),
             "SELECT op.field FROM op WHERE (op.field = op.field) BETWEEN :param_1 AND :param_2")
     
+    def test_associativity(self):
+        f = column('f')
+        self.assert_compile( f - f, "f - f" )
+        self.assert_compile( f - f - f, "(f - f) - f" )
+        
+        self.assert_compile( (f - f) - f, "(f - f) - f" )
+        self.assert_compile( (f - f).label('foo') - f, "(f - f) - f" )
+        
+        self.assert_compile( f - (f - f), "f - (f - f)" )
+        self.assert_compile( f - (f - f).label('foo'), "f - (f - f)" )
+
+        # because - less precedent than /
+        self.assert_compile( f / (f - f), "f / (f - f)" )
+        self.assert_compile( f / (f - f).label('foo'), "f / (f - f)" )
+
+        self.assert_compile( f / f - f, "f / f - f" )
+        self.assert_compile( (f / f) - f, "f / f - f" )
+        self.assert_compile( (f / f).label('foo') - f, "f / f - f" )
+        
+        # because / more precedent than -
+        self.assert_compile( f - (f / f), "f - f / f" )
+        self.assert_compile( f - (f / f).label('foo'), "f - f / f" )
+        self.assert_compile( f - f / f, "f - f / f" )
+        self.assert_compile( (f - f) / f, "(f - f) / f" )
+        
+        self.assert_compile( ((f - f) / f) - f, "(f - f) / f - f")
+        self.assert_compile( (f - f) / (f - f), "(f - f) / (f - f)")
+        
+        # higher precedence
+        self.assert_compile( (f / f) - (f / f), "f / f - f / f")
+
+        self.assert_compile( (f / f) - (f - f), "f / f - (f - f)")
+        self.assert_compile( (f / f) / (f - f), "(f / f) / (f - f)")
+        self.assert_compile( f / (f / (f - f)), "f / (f / (f - f))")
+        
+    
     def test_delayed_col_naming(self):
         my_str = Column(String)