]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- parenthesis are applied to clauses via a new _Grouping construct.
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 14 May 2007 22:25:36 +0000 (22:25 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 14 May 2007 22:25:36 +0000 (22:25 +0000)
uses operator precedence to more intelligently apply parenthesis
to clauses, provides cleaner nesting of clauses (doesnt mutate
clauses placed in other clauses, i.e. no 'parens' flag)
- added 'modifier' keyword, works like func.<foo> except does not
add parenthesis.  e.g. select([modifier.DISTINCT(...)]) etc.

CHANGES
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/sql.py
lib/sqlalchemy/sql_util.py
test/sql/query.py
test/sql/select.py

diff --git a/CHANGES b/CHANGES
index a7d1477c8613b98aa550194e53a4a3feb50a9895..6e44a2e09cd07261b5b06e2f7eabf01ada4c7364 100644 (file)
--- a/CHANGES
+++ b/CHANGES
       behave more properly with regards to FROM clause #574 
     - fix to long name generation when using oid_column as an order by
       (oids used heavily in mapper queries)
+    - parenthesis are applied to clauses via a new _Grouping construct.
+      uses operator precedence to more intelligently apply parenthesis 
+      to clauses, provides cleaner nesting of clauses (doesnt mutate
+      clauses placed in other clauses, i.e. no 'parens' flag)
+    - added 'modifier' keyword, works like func.<foo> except does not
+      add parenthesis.  e.g. select([modifier.DISTINCT(...)]) etc.
 - orm
     - "delete-orphan" no longer implies "delete". ongoing effort to 
       separate the behavior of these two operations.
index ab043f3ec963a88dbac1456e8a53ffa8ae1d8dca..28dd0866ca4eff1c72682f8f6999828df11d2d70 100644 (file)
@@ -245,7 +245,10 @@ class ANSICompiler(sql.Compiled):
         """
 
         return ""
-
+    
+    def visit_grouping(self, grouping):
+        self.strings[grouping] = "(" + self.strings[grouping.elem] + ")"
+        
     def visit_label(self, label):
         labelname = self._truncated_identifier("colident", label.name)
         
@@ -298,10 +301,7 @@ class ANSICompiler(sql.Compiled):
         self.strings[typeclause] = typeclause.type.dialect_impl(self.dialect).get_col_spec()
 
     def visit_textclause(self, textclause):
-        if textclause.parens and len(textclause.text):
-            self.strings[textclause] = "(" + textclause.text + ")"
-        else:
-            self.strings[textclause] = textclause.text
+        self.strings[textclause] = textclause.text
         self.froms[textclause] = textclause.text
         if textclause.typemap is not None:
             self.typemap.update(textclause.typemap)
@@ -309,32 +309,21 @@ class ANSICompiler(sql.Compiled):
     def visit_null(self, null):
         self.strings[null] = 'NULL'
 
-    def visit_compound(self, compound):
-        if compound.operator is None:
-            sep = " "
-        else:
-            sep = " " + compound.operator + " "
-
-        s = string.join([self.get_str(c) for c in compound.clauses], sep)
-        if compound.parens:
-            self.strings[compound] = "(" + s + ")"
-        else:
-            self.strings[compound] = s
-
     def visit_clauselist(self, list):
-        if list.parens:
-            self.strings[list] = "(" + string.join([s for s in [self.get_str(c) for c in list.clauses] if s is not None], ', ') + ")"
+        sep = list.operator
+        if sep == ',':
+            sep = ', '
+        elif sep is None or sep == " ":
+            sep = " "
         else:
-            self.strings[list] = string.join([s for s in [self.get_str(c) for c in list.clauses] if s is not None], ', ')
+            sep = " " + sep + " "
+        self.strings[list] = string.join([s for s in [self.get_str(c) for c in list.clauses] if s is not None], sep)
 
     def apply_function_parens(self, func):
         return func.name.upper() not in ANSI_FUNCS or len(func.clauses) > 0
 
-    def visit_calculatedclause(self, list):
-        if list.parens:
-            self.strings[list] = "(" + string.join([self.get_str(c) for c in list.clauses], ' ') + ")"
-        else:
-            self.strings[list] = string.join([self.get_str(c) for c in list.clauses], ' ')
+    def visit_calculatedclause(self, clause):
+        self.strings[clause] = self.get_str(clause.clause_expr)
 
     def visit_cast(self, cast):
         if len(self.select_stack):
@@ -349,7 +338,7 @@ class ANSICompiler(sql.Compiled):
             self.strings[func] = ".".join(func.packagenames + [func.name])
             self.froms[func] = self.strings[func]
         else:
-            self.strings[func] = ".".join(func.packagenames + [func.name]) + "(" + string.join([self.get_str(c) for c in func.clauses], ', ') + ")"
+            self.strings[func] = ".".join(func.packagenames + [func.name]) + (not func.group and " " or "") + self.get_str(func.clause_expr)
             self.froms[func] = self.strings[func]
 
     def visit_compound_select(self, cs):
@@ -359,19 +348,22 @@ class ANSICompiler(sql.Compiled):
             text += " GROUP BY " + group_by
         text += self.order_by_clause(cs)            
         text += self.visit_select_postclauses(cs)
-        if cs.parens:
-            self.strings[cs] = "(" + text + ")"
-        else:
-            self.strings[cs] = text
+        self.strings[cs] = text
         self.froms[cs] = "(" + text + ")"
 
+    def visit_unary(self, unary):
+        s = self.get_str(unary.element)
+        if unary.operator:
+            s = unary.operator + " " + s
+        if unary.modifier:
+            s = s + " " + unary.modifier
+        self.strings[unary] = s
+        
     def visit_binary(self, binary):
         result = self.get_str(binary.left)
         if binary.operator is not None:
             result += " " + self.binary_operator_string(binary)
         result += " " + self.get_str(binary.right)
-        if binary.parens:
-            result = "(" + result + ")"
         self.strings[binary] = result
 
     def binary_operator_string(self, binary):
@@ -438,10 +430,6 @@ class ANSICompiler(sql.Compiled):
 
         self.select_stack.append(select)
         for c in select._raw_columns:
-            if isinstance(c, sql.Select) and c.is_scalar:
-                self.traverse(c)
-                inner_columns[self.get_str(c)] = c
-                continue
             if hasattr(c, '_selectable'):
                 s = c._selectable()
             else:
@@ -484,6 +472,7 @@ class ANSICompiler(sql.Compiled):
         for f in select.froms:
 
             if self.parameters is not None:
+                # TODO: whack this feature in 0.4
                 # look at our own parameters, see if they
                 # are all present in the form of BindParamClauses.  if
                 # not, then append to the above whereclause column conditions
@@ -494,16 +483,20 @@ class ANSICompiler(sql.Compiled):
                     else:
                         continue
                     clause = c==value
-                    self.traverse(clause)
-                    whereclause = sql.and_(clause, whereclause)
-                    self.visit_compound(whereclause)
+                    if whereclause is not None:
+                        whereclause = self.traverse(sql.and_(clause, whereclause), stop_on=util.Set([whereclause]))
+                    else:
+                        whereclause = clause
+                        self.traverse(whereclause)
 
             # special thingy used by oracle to redefine a join
             w = self.get_whereclause(f)
             if w is not None:
                 # TODO: move this more into the oracle module
-                whereclause = sql.and_(w, whereclause)
-                self.visit_compound(whereclause)
+                if whereclause is not None:
+                    whereclause = self.traverse(sql.and_(w, whereclause), stop_on=util.Set([whereclause, w]))
+                else:
+                    whereclause = w
 
             t = self.get_from_text(f)
             if t is not None:
@@ -533,10 +526,7 @@ class ANSICompiler(sql.Compiled):
         text += self.visit_select_postclauses(select)
         text += self.for_update_clause(select)
 
-        if getattr(select, 'parens', False):
-            self.strings[select] = "(" + text + ")"
-        else:
-            self.strings[select] = text
+        self.strings[select] = text
         self.froms[select] = "(" + text + ")"
 
     def visit_select_precolumns(self, select):
index 0c1e2ff62a8df69b9d398860dd73cbfd2fe12938..7e79a889f8eaf9d1e19808b233b9184c713cef05 100644 (file)
@@ -441,7 +441,12 @@ class OracleCompiler(ansisql.ANSICompiler):
             return ansisql.ANSICompiler.visit_join(self, join)
 
         self.froms[join] = self.get_from_text(join.left) + ", " + self.get_from_text(join.right)
-        self.wheres[join] = sql.and_(self.wheres.get(join.left, None), join.onclause)
+        where = self.wheres.get(join.left, None)
+        if where is not None:
+            self.wheres[join] = sql.and_(where, join.onclause)
+        else:
+            self.wheres[join] = join.onclause
+#        self.wheres[join] = sql.and_(self.wheres.get(join.left, None), join.onclause)
         self.strings[join] = self.froms[join]
 
         if join.isouter:
@@ -454,7 +459,7 @@ class OracleCompiler(ansisql.ANSICompiler):
 
             self._outertable = None
 
-        self.visit_compound(self.wheres[join])
+        self.wheres[join].accept_visitor(self)
 
     def visit_insert_sequence(self, column, sequence, parameters):
         """This is the `sequence` equivalent to ``ANSICompiler``'s
index 1dfe8293474eca429f8c2fb01d21705628e19eac..ca18d1e26b691e418665986730ee4bdf4f543d88 100644 (file)
@@ -714,7 +714,7 @@ class Engine(Connectable):
             connection.close()
 
     def _func(self):
-        return sql._FunctionGenerator(self)
+        return sql._FunctionGenerator(engine=self)
 
     func = property(_func)
 
index 00ca7cde7862a2876f08f8b1a1366973cd80a399..ddf7d6251c8bb1af496178a65f1c1d6b4fae6858 100644 (file)
@@ -479,6 +479,8 @@ class Session(object):
                     merged = self.identity_map[key]
                 else:
                     merged = self.get(mapper.class_, key[1])
+                    if merged is None:
+                        raise exceptions.AssertionError("Instance %s has an instance key but is not persisted" % mapperutil.instance_str(object))
             for prop in mapper.props.values():
                 prop.merge(self, object, merged, _recursive)
             if key is None:
index 5adef46f2776ce53f995d6b757b3350420e780f9..e27181a9a7547015b8381434c1855b25e302af0b 100644 (file)
@@ -34,11 +34,30 @@ __all__ = ['AbstractDialect', 'Alias', 'ClauseElement', 'ClauseParameters',
            'Compiled', 'CompoundSelect', 'Executor', 'FromClause', 'Join',
            'Select', 'Selectable', 'TableClause', 'alias', 'and_', 'asc',
            'between_', 'bindparam', 'case', 'cast', 'column', 'delete',
-           'desc', 'except_', 'except_all', 'exists', 'extract', 'func',
+           'desc', 'except_', 'except_all', 'exists', 'extract', 'func', 'modifier',
            'insert', 'intersect', 'intersect_all', 'join', 'literal',
            'literal_column', 'not_', 'null', 'or_', 'outerjoin', 'select',
            'subquery', 'table', 'text', 'union', 'union_all', 'update',]
 
+# precedence ordering for common operators.  if an operator is not present in this list,
+# its precedence is assumed to be '0' which will cause it to be parenthesized when grouped against other operators
+PRECEDENCE = {
+    'FROM':15,
+    'AS':15,
+    'NOT':10,
+    'AND':3,
+    'OR':3,
+    '=':7,
+    '!=':7,
+    '>':7,
+    '<':7,
+    '+':5,
+    '-':5,
+    '*':5,
+    '/':5,
+    ',':0
+}
+
 def desc(column):
     """Return a descending ``ORDER BY`` clause element.
 
@@ -46,7 +65,7 @@ def desc(column):
 
       order_by = [desc(table1.mycol)]
     """
-    return _CompoundClause(None, column, "DESC")
+    return _UnaryExpression(column, modifier="DESC")
 
 def asc(column):
     """Return an ascending ``ORDER BY`` clause element.
@@ -55,7 +74,7 @@ def asc(column):
 
       order_by = [asc(table1.mycol)]
     """
-    return _CompoundClause(None, column, "ASC")
+    return _UnaryExpression(column, modifier="ASC")
 
 def outerjoin(left, right, onclause=None, **kwargs):
     """Return an ``OUTER JOIN`` clause element.
@@ -332,8 +351,9 @@ def and_(*clauses):
     The ``&`` operator is also overloaded on all [sqlalchemy.sql#_CompareMixin]
     subclasses to produce the same result.
     """
-
-    return _compound_clause('AND', *clauses)
+    if len(clauses) == 1:
+        return clauses[0]
+    return ClauseList(operator='AND', *clauses)
 
 def or_(*clauses):
     """Join a list of clauses together using the ``OR`` operator.
@@ -342,7 +362,9 @@ def or_(*clauses):
     subclasses to produce the same result.
     """
 
-    return _compound_clause('OR', *clauses)
+    if len(clauses) == 1:
+        return clauses[0]
+    return ClauseList(operator='OR', *clauses)
 
 def not_(clause):
     """Return a negation of the given clause, i.e. ``NOT(clause)``.
@@ -362,7 +384,7 @@ def between(ctest, cleft, cright):
     provides similar functionality.
     """
 
-    return _BooleanExpression(ctest, and_(_check_literal(cleft, ctest.type), _check_literal(cright, ctest.type)), 'BETWEEN')
+    return _BinaryExpression(ctest, and_(_literals_as_binds(cleft, type=ctest.type), _literals_as_binds(cright, type=ctest.type)), 'BETWEEN')
 
 def between_(*args, **kwargs):
     """synonym for [sqlalchemy.sql#between()] (deprecated)."""
@@ -383,16 +405,14 @@ def case(whens, value=None, else_=None):
 
     """
 
-    whenlist = [_CompoundClause(None, 'WHEN', c, 'THEN', r) for (c,r) in whens]
+    whenlist = [ClauseList('WHEN', c, 'THEN', r, operator=None) for (c,r) in whens]
     if not else_ is None:
-        whenlist.append(_CompoundClause(None, 'ELSE', else_))
+        whenlist.append(ClauseList('ELSE', else_, operator=None))
     if len(whenlist):
         type = list(whenlist[-1])[-1].type
     else:
         type = None
-    cc = _CalculatedClause(None, 'CASE', value, type=type, *whenlist + ['END'])
-    for c in cc.clauses:
-        c.parens = False
+    cc = _CalculatedClause(None, 'CASE', value, type=type, operator=None, group_contents=False, *whenlist + ['END'])
     return cc
 
 def cast(clause, totype, **kwargs):
@@ -414,7 +434,7 @@ def cast(clause, totype, **kwargs):
 def extract(field, expr):
     """Return the clause ``extract(field FROM expr)``."""
 
-    expr = _BinaryClause(text(field), expr, "FROM")
+    expr = _BinaryExpression(text(field), expr, "FROM")
     return func.extract(expr)
 
 def exists(*args, **kwargs):
@@ -543,11 +563,6 @@ def alias(selectable, alias=None):
         
     return Alias(selectable, alias=alias)
 
-def _check_literal(value, type):
-    if _is_literal(value):
-        return literal(value, type)
-    else:
-        return value
 
 def literal(value, type=None):
     """Return a literal clause, bound to a bind parameter.
@@ -714,20 +729,27 @@ def null():
 
     return _Null()
 
-class _FunctionGateway(object):
-    """Return a callable based on an attribute name, which then
-    returns a ``_Function`` object with that name.
-    """
+class _FunctionGenerator(object):
+    """Generate ``_Function`` objects based on getattr calls."""
+
+    def __init__(self, **opts):
+        self.__names = []
+        self.opts = opts
 
     def __getattr__(self, name):
         if name[-1] == '_':
             name = name[0:-1]
-        return getattr(_FunctionGenerator(), name)
+        f = _FunctionGenerator(**self.opts)
+        f.__names = list(self.__names) + [name]
+        return f
 
-func = _FunctionGateway()
+    def __call__(self, *c, **kwargs):
+        o = self.opts.copy()
+        o.update(kwargs)
+        return _Function(self.__names[-1], packagenames=self.__names[0:-1], *c, **o)
 
-def _compound_clause(keyword, *clauses):
-    return _CompoundClause(keyword, *clauses)
+func = _FunctionGenerator()
+modifier = _FunctionGenerator(group=False)
 
 def _compound_select(keyword, *selects, **kwargs):
     return CompoundSelect(keyword, *selects, **kwargs)
@@ -735,6 +757,21 @@ def _compound_select(keyword, *selects, **kwargs):
 def _is_literal(element):
     return not isinstance(element, ClauseElement)
 
+def _literals_as_text(element):
+    if _is_literal(element):
+        return _TextClause(unicode(element))
+    else:
+        return element
+
+def _literals_as_binds(element, name='literal', type=None):
+    if _is_literal(element):
+        if element is None:
+            return null()
+        else:
+            return _BindParamClause(name, element, shortname=name, type=type, unique=True)
+    else:
+        return element
+        
 def is_column(col):
     return isinstance(col, ColumnElement)
 
@@ -825,13 +862,23 @@ class ClauseVisitor(object):
     (column_collections=False) or to return Schema-level items
     (schema_visitor=True)."""
     __traverse_options__ = {}
-    def traverse(self, obj):
-        for n in obj.get_children(**self.__traverse_options__):
-            self.traverse(n)
-        v = self
-        while v is not None:
-            obj.accept_visitor(v)
-            v = getattr(v, '_next', None)
+    def traverse(self, obj, stop_on=None, echo=False):
+        stack = [obj]
+        traversal = []
+        while len(stack) > 0:
+            t = stack.pop()
+            if stop_on is None or t not in stop_on:
+                traversal.insert(0, t)
+                for c in t.get_children(**self.__traverse_options__):
+                    stack.append(c)
+        for target in traversal:
+            v = self
+            if echo:
+                print "VISITING", repr(target), "STOP ON", stop_on
+            while v is not None:
+                target.accept_visitor(v)
+                v = getattr(v, '_next', None)
+        return obj
         
     def chain(self, visitor):
         """'chain' an additional ClauseVisitor onto this ClauseVisitor.
@@ -859,6 +906,8 @@ class ClauseVisitor(object):
         pass
     def visit_binary(self, binary):
         pass
+    def visit_unary(self, unary):
+        pass
     def visit_alias(self, alias):
         pass
     def visit_select(self, select):
@@ -871,6 +920,8 @@ class ClauseVisitor(object):
         pass
     def visit_calculatedclause(self, calcclause):
         pass
+    def visit_grouping(self, gr):
+        pass
     def visit_function(self, func):
         pass
     def visit_cast(self, cast):
@@ -1060,7 +1111,10 @@ class ClauseElement(object):
         child items from a different context (such as schema-level
         collections instead of clause-level)."""
         return []
-        
+    
+    def self_group(self, against=None):
+        return self
+
     def supports_execution(self):
         """Return True if this clause element represents a complete
         executable statement.
@@ -1175,8 +1229,7 @@ class ClauseElement(object):
         return self._negate()
 
     def _negate(self):
-        self.parens=True
-        return _BooleanExpression(_TextClause("NOT"), self, None)
+        return _UnaryExpression(self.self_group(against="NOT"), operator="NOT", negate=None)
 
 class _CompareMixin(object):
     """Defines comparison operations for ``ClauseElement`` instances.
@@ -1237,7 +1290,7 @@ class _CompareMixin(object):
             else:
                 o = self._bind_param(o)
             args.append(o)
-        return self._compare( 'IN', ClauseList( parens=True, *args), negate='NOT IN')
+        return self._compare( 'IN', ClauseList(*args).self_group(against='IN'), negate='NOT IN')
 
     def startswith(self, other):
         """produce the clause ``LIKE '<other>%'``"""
@@ -1267,11 +1320,11 @@ class _CompareMixin(object):
 
     def distinct(self):
         """produce a DISTINCT clause, i.e. ``DISTINCT <columnname>``"""
-        return _CompoundClause(None,"DISTINCT", self)
+        return _UnaryExpression(self, operator="DISTINCT")
 
     def between(self, cleft, cright):
         """produce a BETWEEN clause, i.e. ``<column> BETWEEN <cleft> AND <cright>``"""
-        return _BooleanExpression(self, and_(self._check_literal(cleft), self._check_literal(cright)), 'BETWEEN')
+        return _BinaryExpression(self, and_(self._check_literal(cleft), self._check_literal(cright)), 'BETWEEN')
 
     def op(self, operator):
         """produce a generic operator function.
@@ -1324,15 +1377,15 @@ class _CompareMixin(object):
     def _compare(self, operator, obj, negate=None):
         if obj is None or isinstance(obj, _Null):
             if operator == '=':
-                return _BooleanExpression(self._compare_self(), null(), 'IS', negate='IS NOT')
+                return _BinaryExpression(self._compare_self(), null(), 'IS', negate='IS NOT')
             elif operator == '!=':
-                return _BooleanExpression(self._compare_self(), null(), 'IS NOT', negate='IS')
+                return _BinaryExpression(self._compare_self(), null(), 'IS NOT', negate='IS')
             else:
                 raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL")
         else:
             obj = self._check_literal(obj)
 
-        return _BooleanExpression(self._compare_self(), obj, operator, type=self._compare_type(obj), negate=negate)
+        return _BinaryExpression(self._compare_self(), obj, operator, type=self._compare_type(obj), negate=negate)
 
     def _operate(self, operator, obj):
         if _is_literal(obj):
@@ -1348,7 +1401,7 @@ class _CompareMixin(object):
 
     def _compare_type(self, obj):
         """Allow subclasses to override the type used in constructing
-        ``_BinaryClause`` objects.
+        ``_BinaryExpression`` objects.
 
         Default return value is the type of the given object.
         """
@@ -1384,6 +1437,7 @@ class Selectable(ClauseElement):
 
         return True
 
+        
 class ColumnElement(Selectable, _CompareMixin):
     """Represent an element that is useable within the 
     "column clause" portion of a ``SELECT`` statement. 
@@ -1789,7 +1843,6 @@ class _TextClause(ClauseElement):
     """
 
     def __init__(self, text = "", engine=None, bindparams=None, typemap=None):
-        self.parens = False
         self._engine = engine
         self.bindparams = {}
         self.typemap = typemap
@@ -1845,29 +1898,42 @@ class _Null(ColumnElement):
         return []
 
 class ClauseList(ClauseElement):
-    """Describe a list of clauses.
+    """Describe a list of clauses, separated by an operator.
 
     By default, is comma-separated, such as a column listing.
     """
 
     def __init__(self, *clauses, **kwargs):
         self.clauses = []
+        self.operator = kwargs.pop('operator', ',')
+        self.group = kwargs.pop('group', True)
+        self.group_contents = kwargs.pop('group_contents', True)
         for c in clauses:
             if c is None: continue
             self.append(c)
-        self.parens = kwargs.get('parens', False)
 
     def __iter__(self):
         return iter(self.clauses)
-
+    def __len__(self):
+        return len(self.clauses)
+        
     def copy_container(self):
         clauses = [clause.copy_container() for clause in self.clauses]
-        return ClauseList(parens=self.parens, *clauses)
+        return ClauseList(operator=self.operator, *clauses)
+
+    def self_group(self, against=None):
+        if self.group:
+            return _Grouping(self)
+        else:
+            return self
 
     def append(self, clause):
-        if _is_literal(clause):
-            clause = _TextClause(unicode(clause))
-        self.clauses.append(clause)
+        # TODO: not sure if i like the 'group_contents' flag.  need to define the difference between
+        # a ClauseList of ClauseLists, and a "flattened" ClauseList of ClauseLists.  flatten() method ?
+        if self.group_contents:
+            self.clauses.append(_literals_as_text(clause).self_group(against=self.operator))
+        else:
+            self.clauses.append(_literals_as_text(clause))
 
     def get_children(self, **kwargs):
         return self.clauses
@@ -1881,6 +1947,12 @@ class ClauseList(ClauseElement):
             f += c._get_from_objects()
         return f
 
+    def self_group(self, against=None):
+        if PRECEDENCE.get(self.operator, 0) <= PRECEDENCE.get(against, 0):
+            return _Grouping(self)
+        else:
+            return self
+
     def compare(self, other):
         """Compare this ``ClauseList`` to the given ``ClauseList``,
         including a comparison of all the clause items.
@@ -1891,59 +1963,11 @@ class ClauseList(ClauseElement):
                 if not self.clauses[i].compare(other.clauses[i]):
                     return False
             else:
-                return True
-        else:
-            return False
-
-class _CompoundClause(ClauseList):
-    """Represent a list of clauses joined by an operator, such as ``AND`` or ``OR``.
-
-    Extends ``ClauseList`` to add the operator as well as a
-    `from_objects` accessor to help determine ``FROM`` objects in a
-    ``SELECT`` statement.
-    """
-
-    def __init__(self, operator, *clauses, **kwargs):
-        ClauseList.__init__(self, *clauses, **kwargs)
-        self.operator = operator
-
-    def copy_container(self):
-        clauses = [clause.copy_container() for clause in self.clauses]
-        return _CompoundClause(self.operator, *clauses)
-
-    def append(self, clause):
-        if isinstance(clause, _CompoundClause):
-            clause.parens = True
-        ClauseList.append(self, clause)
-
-    def get_children(self, **kwargs):
-        return self.clauses
-    def accept_visitor(self, visitor):
-        visitor.visit_compound(self)
-
-    def _get_from_objects(self):
-        f = []
-        for c in self.clauses:
-            f += c._get_from_objects()
-        return f
-
-    def compare(self, other):
-        """Compare this ``_CompoundClause`` to the given item.
-
-        In addition to the regular comparison, has the special case
-        that it returns True if this ``_CompoundClause`` has only one
-        item, and that item matches the given item.
-        """
-
-        if not isinstance(other, _CompoundClause):
-            if len(self.clauses) == 1:
-                return self.clauses[0].compare(other)
-        if ClauseList.compare(self, other):
-            return self.operator == other.operator
+                return self.operator == other.operator
         else:
             return False
 
-class _CalculatedClause(ClauseList, ColumnElement):
+class _CalculatedClause(ColumnElement):
     """Describe a calculated SQL expression that has a type, like ``CASE``.
 
     Extends ``ColumnElement`` to provide column-level comparison
@@ -1954,8 +1978,13 @@ class _CalculatedClause(ClauseList, ColumnElement):
         self.name = name
         self.type = sqltypes.to_instance(kwargs.get('type', None))
         self._engine = kwargs.get('engine', None)
-        ClauseList.__init__(self, *clauses)
-
+        self.group = kwargs.pop('group', True)
+        self.clauses = ClauseList(operator=kwargs.get('operator', None), group_contents=kwargs.get('group_contents', True), *clauses)
+        if self.group:
+            self.clause_expr = self.clauses.self_group()
+        else:
+            self.clause_expr = self.clauses
+            
     key = property(lambda self:self.name or "_calc_")
 
     def copy_container(self):
@@ -1963,9 +1992,12 @@ class _CalculatedClause(ClauseList, ColumnElement):
         return _CalculatedClause(type=self.type, engine=self._engine, *clauses)
 
     def get_children(self, **kwargs):
-        return self.clauses
+        return self.clause_expr,
+        
     def accept_visitor(self, visitor):
         visitor.visit_calculatedclause(self)
+    def _get_from_objects(self):
+        return self.clauses._get_from_objects()
 
     def _bind_param(self, obj):
         return _BindParamClause(self.name, obj, type=self.type, unique=True)
@@ -1990,28 +2022,24 @@ class _Function(_CalculatedClause, FromClause):
     """
 
     def __init__(self, name, *clauses, **kwargs):
-        self.name = name
         self.type = sqltypes.to_instance(kwargs.get('type', None))
         self.packagenames = kwargs.get('packagenames', None) or []
+        kwargs['operator'] = ','
         self._engine = kwargs.get('engine', None)
-        ClauseList.__init__(self, parens=True, *[c is None and _Null() or c for c in clauses])
+        _CalculatedClause.__init__(self, name, **kwargs)
+        for c in clauses:
+            self.append(c)
 
     key = property(lambda self:self.name)
 
+
     def append(self, clause):
-        if _is_literal(clause):
-            if clause is None:
-                clause = null()
-            else:
-                clause = _BindParamClause(self.name, clause, shortname=self.name, type=None, unique=True)
-        self.clauses.append(clause)
+        self.clauses.append(_literals_as_binds(clause, self.name))
 
     def copy_container(self):
         clauses = [clause.copy_container() for clause in self.clauses]
         return _Function(self.name, type=self.type, packagenames=self.packagenames, engine=self._engine, *clauses)
-
-    def get_children(self, **kwargs):
-        return self.clauses
+        
     def accept_visitor(self, visitor):
         visitor.visit_function(self)
 
@@ -2040,39 +2068,52 @@ class _Cast(ColumnElement):
         else:
             return self
 
-class _FunctionGenerator(object):
-    """Generate ``_Function`` objects based on getattr calls."""
 
-    def __init__(self, engine=None):
-        self.__engine = engine
-        self.__names = []
+class _UnaryExpression(ColumnElement):
+    def __init__(self, element, operator=None, modifier=None, type=None, negate=None):
+        self.operator = operator
+        self.modifier = modifier
+        
+        self.element = _literals_as_text(element).self_group(against=self.operator or self.modifier)
+        self.type = sqltypes.to_instance(type)
+        self.negate = negate
+        
+    def copy_container(self):
+        return self.__class__(self.element.copy_container(), operator=self.operator, modifier=self.modifier, type=self.type, negate=self.negate)
 
-    def __getattr__(self, name):
-        self.__names.append(name)
-        return self
+    def _get_from_objects(self):
+        return self.element._get_from_objects()
 
-    def __call__(self, *c, **kwargs):
-        kwargs.setdefault('engine', self.__engine)
-        return _Function(self.__names[-1], packagenames=self.__names[0:-1], *c, **kwargs)
+    def get_children(self, **kwargs):
+        return self.element,
 
-class _BinaryClause(ClauseElement):
-    """Represent two clauses with an operator in between.
-    
-    This class serves as the base class for ``_BinaryExpression``
-    and ``_BooleanExpression``, both of which add additional 
-    semantics to the base ``_BinaryClause`` construct.
-    """
+    def accept_visitor(self, visitor):
+        visitor.visit_unary(self)
+
+    def compare(self, other):
+        """Compare this ``_UnaryClause`` against the given ``ClauseElement``."""
 
-    def __init__(self, left, right, operator, type=None):
-        self.left = left
-        self.right = right
+        return (
+            isinstance(other, _UnaryClause) and self.operator == other.operator and
+            self.modifier == other.modifier and 
+            self.element.compare(other.element)
+        )
+    def _negate(self):
+        if self.negate is not None:
+            return _UnaryExpression(self.element, operator=self.negate, negate=self.operator, modifier=self.modifier, type=self.type)
+        else:
+            return super(_UnaryExpression, self)._negate()
+
+
+class _BinaryExpression(ColumnElement):
+    """Represent an expression that is ``LEFT <operator> RIGHT``."""
+    
+    def __init__(self, left, right, operator, type=None, negate=None):
+        self.left = _literals_as_text(left).self_group(against=operator)
+        self.right = _literals_as_text(right).self_group(against=operator)
         self.operator = operator
         self.type = sqltypes.to_instance(type)
-        self.parens = False
-        if isinstance(self.left, _BinaryClause) or hasattr(self.left, '_selectable'):
-            self.left.parens = True
-        if isinstance(self.right, _BinaryClause) or hasattr(self.right, '_selectable'):
-            self.right.parens = True
+        self.negate = negate
 
     def copy_container(self):
         return self.__class__(self.left.copy_container(), self.right.copy_container(), self.operator)
@@ -2086,58 +2127,31 @@ class _BinaryClause(ClauseElement):
     def accept_visitor(self, visitor):
         visitor.visit_binary(self)
 
-    def swap(self):
-        c = self.left
-        self.left = self.right
-        self.right = c
-
     def compare(self, other):
-        """Compare this ``_BinaryClause`` against the given ``_BinaryClause``."""
+        """Compare this ``_BinaryExpression`` against the given ``_BinaryExpression``."""
 
         return (
-            isinstance(other, _BinaryClause) and self.operator == other.operator and
+            isinstance(other, _BinaryExpression) and self.operator == other.operator and
             self.left.compare(other.left) and self.right.compare(other.right)
         )
-
-class _BinaryExpression(_BinaryClause, ColumnElement):
-    """Represent a binary expression, which can be in a ``WHERE``
-    criterion or in the column list of a ``SELECT``.
-
-    This class differs from ``_BinaryClause`` in that it mixes
-    in ``ColumnElement``.  The effect is that elements of this 
-    type become ``Selectable`` units which can be placed in the 
-    column list of a ``select()`` construct.
-    
-    """
-
-    pass
-
-class _BooleanExpression(_BinaryExpression):
-    """Represent a boolean expression.
-    
-    ``_BooleanExpression`` is constructed as the result of compare operations
-    involving ``CompareMixin`` subclasses, such as when comparing a ``ColumnElement``
-    to a scalar value via the ``==`` operator, ``CompareMixin``'s ``__eq__()`` method
-    produces a ``_BooleanExpression`` consisting of the ``ColumnElement`` and a
-    ``_BindParamClause``.
+        
+    def self_group(self, against=None):
+        if PRECEDENCE.get(self.operator, 0) <= PRECEDENCE.get(against, 0):
+            return _Grouping(self)
+        else:
+            return self
     
-    """
-
-    def __init__(self, *args, **kwargs):
-        self.negate = kwargs.pop('negate', None)
-        super(_BooleanExpression, self).__init__(*args, **kwargs)
-
     def _negate(self):
         if self.negate is not None:
-            return _BooleanExpression(self.left, self.right, self.negate, negate=self.operator, type=self.type)
+            return _BinaryExpression(self.left, self.right, self.negate, negate=self.operator, type=self.type)
         else:
-            return super(_BooleanExpression, self)._negate()
+            return super(_BinaryExpression, self)._negate()
 
-class _Exists(_BooleanExpression):
+class _Exists(_UnaryExpression):
     def __init__(self, *args, **kwargs):
         kwargs['correlate'] = True
-        s = select(*args, **kwargs)
-        _BooleanExpression.__init__(self, _TextClause("EXISTS"), s, None)
+        s = select(*args, **kwargs).self_group()
+        _UnaryExpression.__init__(self, s, operator="EXISTS")
 
     def _hide_froms(self):
         return self._get_from_objects()
@@ -2358,6 +2372,27 @@ class Alias(FromClause):
 
     engine = property(lambda s: s.selectable.engine)
 
+class _Grouping(ColumnElement):
+    def __init__(self, elem):
+        self.elem = elem
+        self.type = getattr(elem, 'type', None)
+
+    key = property(lambda s: s.elem.key)
+    _label = property(lambda s: s.elem._label)
+    orig_set = property(lambda s:s.elem.orig_set)
+    
+    def copy_container(self):
+        return _Grouping(self.elem.copy_container())
+        
+    def accept_visitor(self, visitor):
+        visitor.visit_grouping(self)
+    def get_children(self, **kwargs):
+        return self.elem,
+    def _hide_froms(self):
+        return self.elem._hide_froms()
+    def _get_from_objects(self):
+        return self.elem._get_from_objects()
+        
 class _Label(ColumnElement):
     """represent a label, as typically applied to any column-level element
     using the ``AS`` sql keyword.
@@ -2372,10 +2407,9 @@ class _Label(ColumnElement):
         self.name = name
         while isinstance(obj, _Label):
             obj = obj.obj
-        self.obj = obj
+        self.obj = obj.self_group(against='AS')
         self.case_sensitive = getattr(obj, "case_sensitive", True)
         self.type = sqltypes.to_instance(type or getattr(obj, 'type', None))
-        obj.parens=True
 
     key = property(lambda s: s.name)
     _label = property(lambda s: s.name)
@@ -2633,7 +2667,6 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
         _SelectBaseMixin.__init__(self)
         self.keyword = keyword
         self.use_labels = kwargs.pop('use_labels', False)
-        self.parens = kwargs.pop('parens', False)
         self.should_correlate = kwargs.pop('correlate', False)
         self.for_update = kwargs.pop('for_update', False)
         self.nowait = kwargs.pop('nowait', False)
@@ -2794,8 +2827,6 @@ class Select(_SelectBaseMixin, FromClause):
 
         def visit_compound_select(self, cs):
             self.visit_select(cs)
-            for s in cs.selects:
-                s.parens = False
 
         def visit_column(self, c):
             pass
@@ -2808,7 +2839,6 @@ class Select(_SelectBaseMixin, FromClause):
                 return
             select.is_where = self.is_where
             select.is_subquery = True
-            select.parens = True
             if not select.should_correlate:
                 return
             [select.correlate(x) for x in self.select._Select__froms]
@@ -2828,6 +2858,9 @@ class Select(_SelectBaseMixin, FromClause):
         if _is_literal(column):
             column = literal_column(str(column))
 
+        if isinstance(column, Select) and column.is_scalar:
+            column = column.self_group(against=',')
+
         self._raw_columns.append(column)
 
         if self.is_scalar and not hasattr(self, 'type'):
@@ -2873,6 +2906,9 @@ class Select(_SelectBaseMixin, FromClause):
         for f in elem._hide_froms():
             self.__hide_froms.add(f)
 
+    def self_group(self, against=None):
+        return _Grouping(self)
+    
     def append_whereclause(self, whereclause):
         self._append_condition('whereclause', whereclause)
 
index dcd19f891569196e76124a0a4af3a505354657dc..1f5ac168118d59c04f803968ae4011d2958b4a8f 100644 (file)
@@ -133,16 +133,23 @@ class AbstractClauseProcessor(sql.NoColumnVisitor):
                 list_[i] = elem
             else:
                 self.traverse(list_[i])
-
-    def visit_compound(self, compound):
-        self.visit_clauselist(compound)
-
+    
+    def visit_grouping(self, grouping):
+        elem = self.convert_element(grouping.elem)
+        if elem is not None:
+            grouping.elem = elem
+            
     def visit_clauselist(self, clist):
         for i in range(0, len(clist.clauses)):
             n = self.convert_element(clist.clauses[i])
             if n is not None:
                 clist.clauses[i] = n
-
+    
+    def visit_unary(self, unary):
+        elem = self.convert_element(unary.element)
+        if elem is not None:
+            unary.element = elem
+            
     def visit_binary(self, binary):
         elem = self.convert_element(binary.left)
         if elem is not None:
index eb2012887ad6253fc11e951c52a5adf197124fd2..d788544c096646a6404ef0f1cdcbd39a978633f8 100644 (file)
@@ -470,7 +470,7 @@ class CompoundTest(PersistTest):
             select([t1.c.col3, t1.c.col4]),
             select([t2.c.col3, t2.c.col4]),
             select([t3.c.col3, t3.c.col4]),
-        parens=True), select([t2.c.col3, t2.c.col4]))
+        ), select([t2.c.col3, t2.c.col4]))
         assert e.alias('bar').select().execute().fetchall() == [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'), ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')]
 
     @testbase.unsupported('mysql', 'oracle')
index 34a37c97846dd9976173391bbc5c08335f33521c..6d0d3b04cc11a32ec42ac668bd57d924fdad80ca 100644 (file)
@@ -139,7 +139,7 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
 
         self.runtest(select([table1, exists([1], from_obj=[table2])]), "SELECT mytable.myid, mytable.name, mytable.description, EXISTS (SELECT 1 FROM myothertable) FROM mytable", params={})
 
-        self.runtest(select([table1, exists([1], from_obj=[table2]).label('foo')]), "SELECT mytable.myid, mytable.name, mytable.description, (EXISTS (SELECT 1 FROM myothertable)) AS foo FROM mytable", params={})
+        self.runtest(select([table1, exists([1], from_obj=[table2]).label('foo')]), "SELECT mytable.myid, mytable.name, mytable.description, EXISTS (SELECT 1 FROM myothertable) AS foo FROM mytable", params={})
         
     def testwheresubquery(self):
         # TODO: this tests that you dont get a "SELECT column" without a FROM but its not working yet.
@@ -426,6 +426,9 @@ WHERE mytable.myid = myothertable.otherid) AS t2view WHERE t2view.mytable_myid =
             "SELECT column1 AS foobar, column2 AS hoho, mytable.myid AS mytable_myid FROM mytable"
         )
         
+        print "---------------------------------------------"
+        s1 = select(["column1 AS foobar", "column2 AS hoho", table1.c.myid], from_obj=[table1])
+        print "---------------------------------------------"
         # test that "auto-labeling of subquery columns" doesnt interfere with literal columns,
         # exported columns dont get quoted
         self.runtest(
@@ -633,7 +636,7 @@ FROM myothertable ORDER BY myid \
         
         query = select(
                 [table1, table2],
-                and_(
+                or_(
                     table1.c.name == 'fred',
                     table1.c.myid == 10,
                     table2.c.othername != 'jack',
@@ -641,21 +644,22 @@ FROM myothertable ORDER BY myid \
                 ),
                 from_obj = [ outerjoin(table1, table2, table1.c.myid == table2.c.otherid) ]
                 )
-                
-        self.runtest(query, 
-            "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername \
-FROM mytable LEFT OUTER JOIN myothertable ON mytable.myid = myothertable.otherid \
-WHERE mytable.name = %(mytable_name)s AND mytable.myid = %(mytable_myid)s AND \
-myothertable.othername != %(myothertable_othername)s AND \
-EXISTS (select yay from foo where boo = lar)",
-            dialect=postgres.dialect()
-            )
+        if False:
+            self.runtest(query, 
+                "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername \
+    FROM mytable LEFT OUTER JOIN myothertable ON mytable.myid = myothertable.otherid \
+    WHERE mytable.name = %(mytable_name)s OR mytable.myid = %(mytable_myid)s OR \
+    myothertable.othername != %(myothertable_othername)s OR \
+    EXISTS (select yay from foo where boo = lar)",
+                dialect=postgres.dialect()
+                )
 
+        print "-------------------------------------------------"
         self.runtest(query, 
             "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername \
 FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid(+) AND \
-mytable.name = :mytable_name AND mytable.myid = :mytable_myid AND \
-myothertable.othername != :myothertable_othername AND EXISTS (select yay from foo where boo = lar)",
+(mytable.name = :mytable_name OR mytable.myid = :mytable_myid OR \
+myothertable.othername != :myothertable_othername OR EXISTS (select yay from foo where boo = lar))",
             dialect=oracle.OracleDialect(use_ansi = False))
 
         query = table1.outerjoin(table2, table1.c.myid==table2.c.otherid).outerjoin(table3, table3.c.userid==table2.c.otherid)