From ae4b954b1a6baf5a58c0e00e382196b581a7f06a Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 14 May 2007 22:25:36 +0000 Subject: [PATCH] - parenthesis are applied to clauses via a new _Grouping construct. uses operator precedence to more intelligently apply parenthesis to clauses, provides cleaner nesting of clauses (doesnt mutate clauses placed in other clauses, i.e. no 'parens' flag) - added 'modifier' keyword, works like func. except does not add parenthesis. e.g. select([modifier.DISTINCT(...)]) etc. --- CHANGES | 6 + lib/sqlalchemy/ansisql.py | 80 +++--- lib/sqlalchemy/databases/oracle.py | 9 +- lib/sqlalchemy/engine/base.py | 2 +- lib/sqlalchemy/orm/session.py | 2 + lib/sqlalchemy/sql.py | 410 ++++++++++++++++------------- lib/sqlalchemy/sql_util.py | 17 +- test/sql/query.py | 2 +- test/sql/select.py | 30 ++- 9 files changed, 304 insertions(+), 254 deletions(-) diff --git a/CHANGES b/CHANGES index a7d1477c86..6e44a2e09c 100644 --- a/CHANGES +++ b/CHANGES @@ -11,6 +11,12 @@ behave more properly with regards to FROM clause #574 - fix to long name generation when using oid_column as an order by (oids used heavily in mapper queries) + - parenthesis are applied to clauses via a new _Grouping construct. + uses operator precedence to more intelligently apply parenthesis + to clauses, provides cleaner nesting of clauses (doesnt mutate + clauses placed in other clauses, i.e. no 'parens' flag) + - added 'modifier' keyword, works like func. except does not + add parenthesis. e.g. select([modifier.DISTINCT(...)]) etc. - orm - "delete-orphan" no longer implies "delete". ongoing effort to separate the behavior of these two operations. diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index ab043f3ec9..28dd0866ca 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -245,7 +245,10 @@ class ANSICompiler(sql.Compiled): """ return "" - + + def visit_grouping(self, grouping): + self.strings[grouping] = "(" + self.strings[grouping.elem] + ")" + def visit_label(self, label): labelname = self._truncated_identifier("colident", label.name) @@ -298,10 +301,7 @@ class ANSICompiler(sql.Compiled): self.strings[typeclause] = typeclause.type.dialect_impl(self.dialect).get_col_spec() def visit_textclause(self, textclause): - if textclause.parens and len(textclause.text): - self.strings[textclause] = "(" + textclause.text + ")" - else: - self.strings[textclause] = textclause.text + self.strings[textclause] = textclause.text self.froms[textclause] = textclause.text if textclause.typemap is not None: self.typemap.update(textclause.typemap) @@ -309,32 +309,21 @@ class ANSICompiler(sql.Compiled): def visit_null(self, null): self.strings[null] = 'NULL' - def visit_compound(self, compound): - if compound.operator is None: - sep = " " - else: - sep = " " + compound.operator + " " - - s = string.join([self.get_str(c) for c in compound.clauses], sep) - if compound.parens: - self.strings[compound] = "(" + s + ")" - else: - self.strings[compound] = s - def visit_clauselist(self, list): - if list.parens: - self.strings[list] = "(" + string.join([s for s in [self.get_str(c) for c in list.clauses] if s is not None], ', ') + ")" + sep = list.operator + if sep == ',': + sep = ', ' + elif sep is None or sep == " ": + sep = " " else: - self.strings[list] = string.join([s for s in [self.get_str(c) for c in list.clauses] if s is not None], ', ') + sep = " " + sep + " " + self.strings[list] = string.join([s for s in [self.get_str(c) for c in list.clauses] if s is not None], sep) def apply_function_parens(self, func): return func.name.upper() not in ANSI_FUNCS or len(func.clauses) > 0 - def visit_calculatedclause(self, list): - if list.parens: - self.strings[list] = "(" + string.join([self.get_str(c) for c in list.clauses], ' ') + ")" - else: - self.strings[list] = string.join([self.get_str(c) for c in list.clauses], ' ') + def visit_calculatedclause(self, clause): + self.strings[clause] = self.get_str(clause.clause_expr) def visit_cast(self, cast): if len(self.select_stack): @@ -349,7 +338,7 @@ class ANSICompiler(sql.Compiled): self.strings[func] = ".".join(func.packagenames + [func.name]) self.froms[func] = self.strings[func] else: - self.strings[func] = ".".join(func.packagenames + [func.name]) + "(" + string.join([self.get_str(c) for c in func.clauses], ', ') + ")" + self.strings[func] = ".".join(func.packagenames + [func.name]) + (not func.group and " " or "") + self.get_str(func.clause_expr) self.froms[func] = self.strings[func] def visit_compound_select(self, cs): @@ -359,19 +348,22 @@ class ANSICompiler(sql.Compiled): text += " GROUP BY " + group_by text += self.order_by_clause(cs) text += self.visit_select_postclauses(cs) - if cs.parens: - self.strings[cs] = "(" + text + ")" - else: - self.strings[cs] = text + self.strings[cs] = text self.froms[cs] = "(" + text + ")" + def visit_unary(self, unary): + s = self.get_str(unary.element) + if unary.operator: + s = unary.operator + " " + s + if unary.modifier: + s = s + " " + unary.modifier + self.strings[unary] = s + def visit_binary(self, binary): result = self.get_str(binary.left) if binary.operator is not None: result += " " + self.binary_operator_string(binary) result += " " + self.get_str(binary.right) - if binary.parens: - result = "(" + result + ")" self.strings[binary] = result def binary_operator_string(self, binary): @@ -438,10 +430,6 @@ class ANSICompiler(sql.Compiled): self.select_stack.append(select) for c in select._raw_columns: - if isinstance(c, sql.Select) and c.is_scalar: - self.traverse(c) - inner_columns[self.get_str(c)] = c - continue if hasattr(c, '_selectable'): s = c._selectable() else: @@ -484,6 +472,7 @@ class ANSICompiler(sql.Compiled): for f in select.froms: if self.parameters is not None: + # TODO: whack this feature in 0.4 # look at our own parameters, see if they # are all present in the form of BindParamClauses. if # not, then append to the above whereclause column conditions @@ -494,16 +483,20 @@ class ANSICompiler(sql.Compiled): else: continue clause = c==value - self.traverse(clause) - whereclause = sql.and_(clause, whereclause) - self.visit_compound(whereclause) + if whereclause is not None: + whereclause = self.traverse(sql.and_(clause, whereclause), stop_on=util.Set([whereclause])) + else: + whereclause = clause + self.traverse(whereclause) # special thingy used by oracle to redefine a join w = self.get_whereclause(f) if w is not None: # TODO: move this more into the oracle module - whereclause = sql.and_(w, whereclause) - self.visit_compound(whereclause) + if whereclause is not None: + whereclause = self.traverse(sql.and_(w, whereclause), stop_on=util.Set([whereclause, w])) + else: + whereclause = w t = self.get_from_text(f) if t is not None: @@ -533,10 +526,7 @@ class ANSICompiler(sql.Compiled): text += self.visit_select_postclauses(select) text += self.for_update_clause(select) - if getattr(select, 'parens', False): - self.strings[select] = "(" + text + ")" - else: - self.strings[select] = text + self.strings[select] = text self.froms[select] = "(" + text + ")" def visit_select_precolumns(self, select): diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index 0c1e2ff62a..7e79a889f8 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -441,7 +441,12 @@ class OracleCompiler(ansisql.ANSICompiler): return ansisql.ANSICompiler.visit_join(self, join) self.froms[join] = self.get_from_text(join.left) + ", " + self.get_from_text(join.right) - self.wheres[join] = sql.and_(self.wheres.get(join.left, None), join.onclause) + where = self.wheres.get(join.left, None) + if where is not None: + self.wheres[join] = sql.and_(where, join.onclause) + else: + self.wheres[join] = join.onclause +# self.wheres[join] = sql.and_(self.wheres.get(join.left, None), join.onclause) self.strings[join] = self.froms[join] if join.isouter: @@ -454,7 +459,7 @@ class OracleCompiler(ansisql.ANSICompiler): self._outertable = None - self.visit_compound(self.wheres[join]) + self.wheres[join].accept_visitor(self) def visit_insert_sequence(self, column, sequence, parameters): """This is the `sequence` equivalent to ``ANSICompiler``'s diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 1dfe829347..ca18d1e26b 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -714,7 +714,7 @@ class Engine(Connectable): connection.close() def _func(self): - return sql._FunctionGenerator(self) + return sql._FunctionGenerator(engine=self) func = property(_func) diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 00ca7cde78..ddf7d6251c 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -479,6 +479,8 @@ class Session(object): merged = self.identity_map[key] else: merged = self.get(mapper.class_, key[1]) + if merged is None: + raise exceptions.AssertionError("Instance %s has an instance key but is not persisted" % mapperutil.instance_str(object)) for prop in mapper.props.values(): prop.merge(self, object, merged, _recursive) if key is None: diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 5adef46f27..e27181a9a7 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -34,11 +34,30 @@ __all__ = ['AbstractDialect', 'Alias', 'ClauseElement', 'ClauseParameters', 'Compiled', 'CompoundSelect', 'Executor', 'FromClause', 'Join', 'Select', 'Selectable', 'TableClause', 'alias', 'and_', 'asc', 'between_', 'bindparam', 'case', 'cast', 'column', 'delete', - 'desc', 'except_', 'except_all', 'exists', 'extract', 'func', + 'desc', 'except_', 'except_all', 'exists', 'extract', 'func', 'modifier', 'insert', 'intersect', 'intersect_all', 'join', 'literal', 'literal_column', 'not_', 'null', 'or_', 'outerjoin', 'select', 'subquery', 'table', 'text', 'union', 'union_all', 'update',] +# precedence ordering for common operators. if an operator is not present in this list, +# its precedence is assumed to be '0' which will cause it to be parenthesized when grouped against other operators +PRECEDENCE = { + 'FROM':15, + 'AS':15, + 'NOT':10, + 'AND':3, + 'OR':3, + '=':7, + '!=':7, + '>':7, + '<':7, + '+':5, + '-':5, + '*':5, + '/':5, + ',':0 +} + def desc(column): """Return a descending ``ORDER BY`` clause element. @@ -46,7 +65,7 @@ def desc(column): order_by = [desc(table1.mycol)] """ - return _CompoundClause(None, column, "DESC") + return _UnaryExpression(column, modifier="DESC") def asc(column): """Return an ascending ``ORDER BY`` clause element. @@ -55,7 +74,7 @@ def asc(column): order_by = [asc(table1.mycol)] """ - return _CompoundClause(None, column, "ASC") + return _UnaryExpression(column, modifier="ASC") def outerjoin(left, right, onclause=None, **kwargs): """Return an ``OUTER JOIN`` clause element. @@ -332,8 +351,9 @@ def and_(*clauses): The ``&`` operator is also overloaded on all [sqlalchemy.sql#_CompareMixin] subclasses to produce the same result. """ - - return _compound_clause('AND', *clauses) + if len(clauses) == 1: + return clauses[0] + return ClauseList(operator='AND', *clauses) def or_(*clauses): """Join a list of clauses together using the ``OR`` operator. @@ -342,7 +362,9 @@ def or_(*clauses): subclasses to produce the same result. """ - return _compound_clause('OR', *clauses) + if len(clauses) == 1: + return clauses[0] + return ClauseList(operator='OR', *clauses) def not_(clause): """Return a negation of the given clause, i.e. ``NOT(clause)``. @@ -362,7 +384,7 @@ def between(ctest, cleft, cright): provides similar functionality. """ - return _BooleanExpression(ctest, and_(_check_literal(cleft, ctest.type), _check_literal(cright, ctest.type)), 'BETWEEN') + return _BinaryExpression(ctest, and_(_literals_as_binds(cleft, type=ctest.type), _literals_as_binds(cright, type=ctest.type)), 'BETWEEN') def between_(*args, **kwargs): """synonym for [sqlalchemy.sql#between()] (deprecated).""" @@ -383,16 +405,14 @@ def case(whens, value=None, else_=None): """ - whenlist = [_CompoundClause(None, 'WHEN', c, 'THEN', r) for (c,r) in whens] + whenlist = [ClauseList('WHEN', c, 'THEN', r, operator=None) for (c,r) in whens] if not else_ is None: - whenlist.append(_CompoundClause(None, 'ELSE', else_)) + whenlist.append(ClauseList('ELSE', else_, operator=None)) if len(whenlist): type = list(whenlist[-1])[-1].type else: type = None - cc = _CalculatedClause(None, 'CASE', value, type=type, *whenlist + ['END']) - for c in cc.clauses: - c.parens = False + cc = _CalculatedClause(None, 'CASE', value, type=type, operator=None, group_contents=False, *whenlist + ['END']) return cc def cast(clause, totype, **kwargs): @@ -414,7 +434,7 @@ def cast(clause, totype, **kwargs): def extract(field, expr): """Return the clause ``extract(field FROM expr)``.""" - expr = _BinaryClause(text(field), expr, "FROM") + expr = _BinaryExpression(text(field), expr, "FROM") return func.extract(expr) def exists(*args, **kwargs): @@ -543,11 +563,6 @@ def alias(selectable, alias=None): return Alias(selectable, alias=alias) -def _check_literal(value, type): - if _is_literal(value): - return literal(value, type) - else: - return value def literal(value, type=None): """Return a literal clause, bound to a bind parameter. @@ -714,20 +729,27 @@ def null(): return _Null() -class _FunctionGateway(object): - """Return a callable based on an attribute name, which then - returns a ``_Function`` object with that name. - """ +class _FunctionGenerator(object): + """Generate ``_Function`` objects based on getattr calls.""" + + def __init__(self, **opts): + self.__names = [] + self.opts = opts def __getattr__(self, name): if name[-1] == '_': name = name[0:-1] - return getattr(_FunctionGenerator(), name) + f = _FunctionGenerator(**self.opts) + f.__names = list(self.__names) + [name] + return f -func = _FunctionGateway() + def __call__(self, *c, **kwargs): + o = self.opts.copy() + o.update(kwargs) + return _Function(self.__names[-1], packagenames=self.__names[0:-1], *c, **o) -def _compound_clause(keyword, *clauses): - return _CompoundClause(keyword, *clauses) +func = _FunctionGenerator() +modifier = _FunctionGenerator(group=False) def _compound_select(keyword, *selects, **kwargs): return CompoundSelect(keyword, *selects, **kwargs) @@ -735,6 +757,21 @@ def _compound_select(keyword, *selects, **kwargs): def _is_literal(element): return not isinstance(element, ClauseElement) +def _literals_as_text(element): + if _is_literal(element): + return _TextClause(unicode(element)) + else: + return element + +def _literals_as_binds(element, name='literal', type=None): + if _is_literal(element): + if element is None: + return null() + else: + return _BindParamClause(name, element, shortname=name, type=type, unique=True) + else: + return element + def is_column(col): return isinstance(col, ColumnElement) @@ -825,13 +862,23 @@ class ClauseVisitor(object): (column_collections=False) or to return Schema-level items (schema_visitor=True).""" __traverse_options__ = {} - def traverse(self, obj): - for n in obj.get_children(**self.__traverse_options__): - self.traverse(n) - v = self - while v is not None: - obj.accept_visitor(v) - v = getattr(v, '_next', None) + def traverse(self, obj, stop_on=None, echo=False): + stack = [obj] + traversal = [] + while len(stack) > 0: + t = stack.pop() + if stop_on is None or t not in stop_on: + traversal.insert(0, t) + for c in t.get_children(**self.__traverse_options__): + stack.append(c) + for target in traversal: + v = self + if echo: + print "VISITING", repr(target), "STOP ON", stop_on + while v is not None: + target.accept_visitor(v) + v = getattr(v, '_next', None) + return obj def chain(self, visitor): """'chain' an additional ClauseVisitor onto this ClauseVisitor. @@ -859,6 +906,8 @@ class ClauseVisitor(object): pass def visit_binary(self, binary): pass + def visit_unary(self, unary): + pass def visit_alias(self, alias): pass def visit_select(self, select): @@ -871,6 +920,8 @@ class ClauseVisitor(object): pass def visit_calculatedclause(self, calcclause): pass + def visit_grouping(self, gr): + pass def visit_function(self, func): pass def visit_cast(self, cast): @@ -1060,7 +1111,10 @@ class ClauseElement(object): child items from a different context (such as schema-level collections instead of clause-level).""" return [] - + + def self_group(self, against=None): + return self + def supports_execution(self): """Return True if this clause element represents a complete executable statement. @@ -1175,8 +1229,7 @@ class ClauseElement(object): return self._negate() def _negate(self): - self.parens=True - return _BooleanExpression(_TextClause("NOT"), self, None) + return _UnaryExpression(self.self_group(against="NOT"), operator="NOT", negate=None) class _CompareMixin(object): """Defines comparison operations for ``ClauseElement`` instances. @@ -1237,7 +1290,7 @@ class _CompareMixin(object): else: o = self._bind_param(o) args.append(o) - return self._compare( 'IN', ClauseList( parens=True, *args), negate='NOT IN') + return self._compare( 'IN', ClauseList(*args).self_group(against='IN'), negate='NOT IN') def startswith(self, other): """produce the clause ``LIKE '%'``""" @@ -1267,11 +1320,11 @@ class _CompareMixin(object): def distinct(self): """produce a DISTINCT clause, i.e. ``DISTINCT ``""" - return _CompoundClause(None,"DISTINCT", self) + return _UnaryExpression(self, operator="DISTINCT") def between(self, cleft, cright): """produce a BETWEEN clause, i.e. `` BETWEEN AND ``""" - return _BooleanExpression(self, and_(self._check_literal(cleft), self._check_literal(cright)), 'BETWEEN') + return _BinaryExpression(self, and_(self._check_literal(cleft), self._check_literal(cright)), 'BETWEEN') def op(self, operator): """produce a generic operator function. @@ -1324,15 +1377,15 @@ class _CompareMixin(object): def _compare(self, operator, obj, negate=None): if obj is None or isinstance(obj, _Null): if operator == '=': - return _BooleanExpression(self._compare_self(), null(), 'IS', negate='IS NOT') + return _BinaryExpression(self._compare_self(), null(), 'IS', negate='IS NOT') elif operator == '!=': - return _BooleanExpression(self._compare_self(), null(), 'IS NOT', negate='IS') + return _BinaryExpression(self._compare_self(), null(), 'IS NOT', negate='IS') else: raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL") else: obj = self._check_literal(obj) - return _BooleanExpression(self._compare_self(), obj, operator, type=self._compare_type(obj), negate=negate) + return _BinaryExpression(self._compare_self(), obj, operator, type=self._compare_type(obj), negate=negate) def _operate(self, operator, obj): if _is_literal(obj): @@ -1348,7 +1401,7 @@ class _CompareMixin(object): def _compare_type(self, obj): """Allow subclasses to override the type used in constructing - ``_BinaryClause`` objects. + ``_BinaryExpression`` objects. Default return value is the type of the given object. """ @@ -1384,6 +1437,7 @@ class Selectable(ClauseElement): return True + class ColumnElement(Selectable, _CompareMixin): """Represent an element that is useable within the "column clause" portion of a ``SELECT`` statement. @@ -1789,7 +1843,6 @@ class _TextClause(ClauseElement): """ def __init__(self, text = "", engine=None, bindparams=None, typemap=None): - self.parens = False self._engine = engine self.bindparams = {} self.typemap = typemap @@ -1845,29 +1898,42 @@ class _Null(ColumnElement): return [] class ClauseList(ClauseElement): - """Describe a list of clauses. + """Describe a list of clauses, separated by an operator. By default, is comma-separated, such as a column listing. """ def __init__(self, *clauses, **kwargs): self.clauses = [] + self.operator = kwargs.pop('operator', ',') + self.group = kwargs.pop('group', True) + self.group_contents = kwargs.pop('group_contents', True) for c in clauses: if c is None: continue self.append(c) - self.parens = kwargs.get('parens', False) def __iter__(self): return iter(self.clauses) - + def __len__(self): + return len(self.clauses) + def copy_container(self): clauses = [clause.copy_container() for clause in self.clauses] - return ClauseList(parens=self.parens, *clauses) + return ClauseList(operator=self.operator, *clauses) + + def self_group(self, against=None): + if self.group: + return _Grouping(self) + else: + return self def append(self, clause): - if _is_literal(clause): - clause = _TextClause(unicode(clause)) - self.clauses.append(clause) + # TODO: not sure if i like the 'group_contents' flag. need to define the difference between + # a ClauseList of ClauseLists, and a "flattened" ClauseList of ClauseLists. flatten() method ? + if self.group_contents: + self.clauses.append(_literals_as_text(clause).self_group(against=self.operator)) + else: + self.clauses.append(_literals_as_text(clause)) def get_children(self, **kwargs): return self.clauses @@ -1881,6 +1947,12 @@ class ClauseList(ClauseElement): f += c._get_from_objects() return f + def self_group(self, against=None): + if PRECEDENCE.get(self.operator, 0) <= PRECEDENCE.get(against, 0): + return _Grouping(self) + else: + return self + def compare(self, other): """Compare this ``ClauseList`` to the given ``ClauseList``, including a comparison of all the clause items. @@ -1891,59 +1963,11 @@ class ClauseList(ClauseElement): if not self.clauses[i].compare(other.clauses[i]): return False else: - return True - else: - return False - -class _CompoundClause(ClauseList): - """Represent a list of clauses joined by an operator, such as ``AND`` or ``OR``. - - Extends ``ClauseList`` to add the operator as well as a - `from_objects` accessor to help determine ``FROM`` objects in a - ``SELECT`` statement. - """ - - def __init__(self, operator, *clauses, **kwargs): - ClauseList.__init__(self, *clauses, **kwargs) - self.operator = operator - - def copy_container(self): - clauses = [clause.copy_container() for clause in self.clauses] - return _CompoundClause(self.operator, *clauses) - - def append(self, clause): - if isinstance(clause, _CompoundClause): - clause.parens = True - ClauseList.append(self, clause) - - def get_children(self, **kwargs): - return self.clauses - def accept_visitor(self, visitor): - visitor.visit_compound(self) - - def _get_from_objects(self): - f = [] - for c in self.clauses: - f += c._get_from_objects() - return f - - def compare(self, other): - """Compare this ``_CompoundClause`` to the given item. - - In addition to the regular comparison, has the special case - that it returns True if this ``_CompoundClause`` has only one - item, and that item matches the given item. - """ - - if not isinstance(other, _CompoundClause): - if len(self.clauses) == 1: - return self.clauses[0].compare(other) - if ClauseList.compare(self, other): - return self.operator == other.operator + return self.operator == other.operator else: return False -class _CalculatedClause(ClauseList, ColumnElement): +class _CalculatedClause(ColumnElement): """Describe a calculated SQL expression that has a type, like ``CASE``. Extends ``ColumnElement`` to provide column-level comparison @@ -1954,8 +1978,13 @@ class _CalculatedClause(ClauseList, ColumnElement): self.name = name self.type = sqltypes.to_instance(kwargs.get('type', None)) self._engine = kwargs.get('engine', None) - ClauseList.__init__(self, *clauses) - + self.group = kwargs.pop('group', True) + self.clauses = ClauseList(operator=kwargs.get('operator', None), group_contents=kwargs.get('group_contents', True), *clauses) + if self.group: + self.clause_expr = self.clauses.self_group() + else: + self.clause_expr = self.clauses + key = property(lambda self:self.name or "_calc_") def copy_container(self): @@ -1963,9 +1992,12 @@ class _CalculatedClause(ClauseList, ColumnElement): return _CalculatedClause(type=self.type, engine=self._engine, *clauses) def get_children(self, **kwargs): - return self.clauses + return self.clause_expr, + def accept_visitor(self, visitor): visitor.visit_calculatedclause(self) + def _get_from_objects(self): + return self.clauses._get_from_objects() def _bind_param(self, obj): return _BindParamClause(self.name, obj, type=self.type, unique=True) @@ -1990,28 +2022,24 @@ class _Function(_CalculatedClause, FromClause): """ def __init__(self, name, *clauses, **kwargs): - self.name = name self.type = sqltypes.to_instance(kwargs.get('type', None)) self.packagenames = kwargs.get('packagenames', None) or [] + kwargs['operator'] = ',' self._engine = kwargs.get('engine', None) - ClauseList.__init__(self, parens=True, *[c is None and _Null() or c for c in clauses]) + _CalculatedClause.__init__(self, name, **kwargs) + for c in clauses: + self.append(c) key = property(lambda self:self.name) + def append(self, clause): - if _is_literal(clause): - if clause is None: - clause = null() - else: - clause = _BindParamClause(self.name, clause, shortname=self.name, type=None, unique=True) - self.clauses.append(clause) + self.clauses.append(_literals_as_binds(clause, self.name)) def copy_container(self): clauses = [clause.copy_container() for clause in self.clauses] return _Function(self.name, type=self.type, packagenames=self.packagenames, engine=self._engine, *clauses) - - def get_children(self, **kwargs): - return self.clauses + def accept_visitor(self, visitor): visitor.visit_function(self) @@ -2040,39 +2068,52 @@ class _Cast(ColumnElement): else: return self -class _FunctionGenerator(object): - """Generate ``_Function`` objects based on getattr calls.""" - def __init__(self, engine=None): - self.__engine = engine - self.__names = [] +class _UnaryExpression(ColumnElement): + def __init__(self, element, operator=None, modifier=None, type=None, negate=None): + self.operator = operator + self.modifier = modifier + + self.element = _literals_as_text(element).self_group(against=self.operator or self.modifier) + self.type = sqltypes.to_instance(type) + self.negate = negate + + def copy_container(self): + return self.__class__(self.element.copy_container(), operator=self.operator, modifier=self.modifier, type=self.type, negate=self.negate) - def __getattr__(self, name): - self.__names.append(name) - return self + def _get_from_objects(self): + return self.element._get_from_objects() - def __call__(self, *c, **kwargs): - kwargs.setdefault('engine', self.__engine) - return _Function(self.__names[-1], packagenames=self.__names[0:-1], *c, **kwargs) + def get_children(self, **kwargs): + return self.element, -class _BinaryClause(ClauseElement): - """Represent two clauses with an operator in between. - - This class serves as the base class for ``_BinaryExpression`` - and ``_BooleanExpression``, both of which add additional - semantics to the base ``_BinaryClause`` construct. - """ + def accept_visitor(self, visitor): + visitor.visit_unary(self) + + def compare(self, other): + """Compare this ``_UnaryClause`` against the given ``ClauseElement``.""" - def __init__(self, left, right, operator, type=None): - self.left = left - self.right = right + return ( + isinstance(other, _UnaryClause) and self.operator == other.operator and + self.modifier == other.modifier and + self.element.compare(other.element) + ) + def _negate(self): + if self.negate is not None: + return _UnaryExpression(self.element, operator=self.negate, negate=self.operator, modifier=self.modifier, type=self.type) + else: + return super(_UnaryExpression, self)._negate() + + +class _BinaryExpression(ColumnElement): + """Represent an expression that is ``LEFT RIGHT``.""" + + def __init__(self, left, right, operator, type=None, negate=None): + self.left = _literals_as_text(left).self_group(against=operator) + self.right = _literals_as_text(right).self_group(against=operator) self.operator = operator self.type = sqltypes.to_instance(type) - self.parens = False - if isinstance(self.left, _BinaryClause) or hasattr(self.left, '_selectable'): - self.left.parens = True - if isinstance(self.right, _BinaryClause) or hasattr(self.right, '_selectable'): - self.right.parens = True + self.negate = negate def copy_container(self): return self.__class__(self.left.copy_container(), self.right.copy_container(), self.operator) @@ -2086,58 +2127,31 @@ class _BinaryClause(ClauseElement): def accept_visitor(self, visitor): visitor.visit_binary(self) - def swap(self): - c = self.left - self.left = self.right - self.right = c - def compare(self, other): - """Compare this ``_BinaryClause`` against the given ``_BinaryClause``.""" + """Compare this ``_BinaryExpression`` against the given ``_BinaryExpression``.""" return ( - isinstance(other, _BinaryClause) and self.operator == other.operator and + isinstance(other, _BinaryExpression) and self.operator == other.operator and self.left.compare(other.left) and self.right.compare(other.right) ) - -class _BinaryExpression(_BinaryClause, ColumnElement): - """Represent a binary expression, which can be in a ``WHERE`` - criterion or in the column list of a ``SELECT``. - - This class differs from ``_BinaryClause`` in that it mixes - in ``ColumnElement``. The effect is that elements of this - type become ``Selectable`` units which can be placed in the - column list of a ``select()`` construct. - - """ - - pass - -class _BooleanExpression(_BinaryExpression): - """Represent a boolean expression. - - ``_BooleanExpression`` is constructed as the result of compare operations - involving ``CompareMixin`` subclasses, such as when comparing a ``ColumnElement`` - to a scalar value via the ``==`` operator, ``CompareMixin``'s ``__eq__()`` method - produces a ``_BooleanExpression`` consisting of the ``ColumnElement`` and a - ``_BindParamClause``. + + def self_group(self, against=None): + if PRECEDENCE.get(self.operator, 0) <= PRECEDENCE.get(against, 0): + return _Grouping(self) + else: + return self - """ - - def __init__(self, *args, **kwargs): - self.negate = kwargs.pop('negate', None) - super(_BooleanExpression, self).__init__(*args, **kwargs) - def _negate(self): if self.negate is not None: - return _BooleanExpression(self.left, self.right, self.negate, negate=self.operator, type=self.type) + return _BinaryExpression(self.left, self.right, self.negate, negate=self.operator, type=self.type) else: - return super(_BooleanExpression, self)._negate() + return super(_BinaryExpression, self)._negate() -class _Exists(_BooleanExpression): +class _Exists(_UnaryExpression): def __init__(self, *args, **kwargs): kwargs['correlate'] = True - s = select(*args, **kwargs) - _BooleanExpression.__init__(self, _TextClause("EXISTS"), s, None) + s = select(*args, **kwargs).self_group() + _UnaryExpression.__init__(self, s, operator="EXISTS") def _hide_froms(self): return self._get_from_objects() @@ -2358,6 +2372,27 @@ class Alias(FromClause): engine = property(lambda s: s.selectable.engine) +class _Grouping(ColumnElement): + def __init__(self, elem): + self.elem = elem + self.type = getattr(elem, 'type', None) + + key = property(lambda s: s.elem.key) + _label = property(lambda s: s.elem._label) + orig_set = property(lambda s:s.elem.orig_set) + + def copy_container(self): + return _Grouping(self.elem.copy_container()) + + def accept_visitor(self, visitor): + visitor.visit_grouping(self) + def get_children(self, **kwargs): + return self.elem, + def _hide_froms(self): + return self.elem._hide_froms() + def _get_from_objects(self): + return self.elem._get_from_objects() + class _Label(ColumnElement): """represent a label, as typically applied to any column-level element using the ``AS`` sql keyword. @@ -2372,10 +2407,9 @@ class _Label(ColumnElement): self.name = name while isinstance(obj, _Label): obj = obj.obj - self.obj = obj + self.obj = obj.self_group(against='AS') self.case_sensitive = getattr(obj, "case_sensitive", True) self.type = sqltypes.to_instance(type or getattr(obj, 'type', None)) - obj.parens=True key = property(lambda s: s.name) _label = property(lambda s: s.name) @@ -2633,7 +2667,6 @@ class CompoundSelect(_SelectBaseMixin, FromClause): _SelectBaseMixin.__init__(self) self.keyword = keyword self.use_labels = kwargs.pop('use_labels', False) - self.parens = kwargs.pop('parens', False) self.should_correlate = kwargs.pop('correlate', False) self.for_update = kwargs.pop('for_update', False) self.nowait = kwargs.pop('nowait', False) @@ -2794,8 +2827,6 @@ class Select(_SelectBaseMixin, FromClause): def visit_compound_select(self, cs): self.visit_select(cs) - for s in cs.selects: - s.parens = False def visit_column(self, c): pass @@ -2808,7 +2839,6 @@ class Select(_SelectBaseMixin, FromClause): return select.is_where = self.is_where select.is_subquery = True - select.parens = True if not select.should_correlate: return [select.correlate(x) for x in self.select._Select__froms] @@ -2828,6 +2858,9 @@ class Select(_SelectBaseMixin, FromClause): if _is_literal(column): column = literal_column(str(column)) + if isinstance(column, Select) and column.is_scalar: + column = column.self_group(against=',') + self._raw_columns.append(column) if self.is_scalar and not hasattr(self, 'type'): @@ -2873,6 +2906,9 @@ class Select(_SelectBaseMixin, FromClause): for f in elem._hide_froms(): self.__hide_froms.add(f) + def self_group(self, against=None): + return _Grouping(self) + def append_whereclause(self, whereclause): self._append_condition('whereclause', whereclause) diff --git a/lib/sqlalchemy/sql_util.py b/lib/sqlalchemy/sql_util.py index dcd19f8915..1f5ac16811 100644 --- a/lib/sqlalchemy/sql_util.py +++ b/lib/sqlalchemy/sql_util.py @@ -133,16 +133,23 @@ class AbstractClauseProcessor(sql.NoColumnVisitor): list_[i] = elem else: self.traverse(list_[i]) - - def visit_compound(self, compound): - self.visit_clauselist(compound) - + + def visit_grouping(self, grouping): + elem = self.convert_element(grouping.elem) + if elem is not None: + grouping.elem = elem + def visit_clauselist(self, clist): for i in range(0, len(clist.clauses)): n = self.convert_element(clist.clauses[i]) if n is not None: clist.clauses[i] = n - + + def visit_unary(self, unary): + elem = self.convert_element(unary.element) + if elem is not None: + unary.element = elem + def visit_binary(self, binary): elem = self.convert_element(binary.left) if elem is not None: diff --git a/test/sql/query.py b/test/sql/query.py index eb2012887a..d788544c09 100644 --- a/test/sql/query.py +++ b/test/sql/query.py @@ -470,7 +470,7 @@ class CompoundTest(PersistTest): select([t1.c.col3, t1.c.col4]), select([t2.c.col3, t2.c.col4]), select([t3.c.col3, t3.c.col4]), - parens=True), select([t2.c.col3, t2.c.col4])) + ), select([t2.c.col3, t2.c.col4])) assert e.alias('bar').select().execute().fetchall() == [('aaa', 'aaa'), ('aaa', 'ccc'), ('bbb', 'aaa'), ('bbb', 'bbb'), ('ccc', 'bbb'), ('ccc', 'ccc')] @testbase.unsupported('mysql', 'oracle') diff --git a/test/sql/select.py b/test/sql/select.py index 34a37c9784..6d0d3b04cc 100644 --- a/test/sql/select.py +++ b/test/sql/select.py @@ -139,7 +139,7 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A self.runtest(select([table1, exists([1], from_obj=[table2])]), "SELECT mytable.myid, mytable.name, mytable.description, EXISTS (SELECT 1 FROM myothertable) FROM mytable", params={}) - self.runtest(select([table1, exists([1], from_obj=[table2]).label('foo')]), "SELECT mytable.myid, mytable.name, mytable.description, (EXISTS (SELECT 1 FROM myothertable)) AS foo FROM mytable", params={}) + self.runtest(select([table1, exists([1], from_obj=[table2]).label('foo')]), "SELECT mytable.myid, mytable.name, mytable.description, EXISTS (SELECT 1 FROM myothertable) AS foo FROM mytable", params={}) def testwheresubquery(self): # TODO: this tests that you dont get a "SELECT column" without a FROM but its not working yet. @@ -426,6 +426,9 @@ WHERE mytable.myid = myothertable.otherid) AS t2view WHERE t2view.mytable_myid = "SELECT column1 AS foobar, column2 AS hoho, mytable.myid AS mytable_myid FROM mytable" ) + print "---------------------------------------------" + s1 = select(["column1 AS foobar", "column2 AS hoho", table1.c.myid], from_obj=[table1]) + print "---------------------------------------------" # test that "auto-labeling of subquery columns" doesnt interfere with literal columns, # exported columns dont get quoted self.runtest( @@ -633,7 +636,7 @@ FROM myothertable ORDER BY myid \ query = select( [table1, table2], - and_( + or_( table1.c.name == 'fred', table1.c.myid == 10, table2.c.othername != 'jack', @@ -641,21 +644,22 @@ FROM myothertable ORDER BY myid \ ), from_obj = [ outerjoin(table1, table2, table1.c.myid == table2.c.otherid) ] ) - - self.runtest(query, - "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername \ -FROM mytable LEFT OUTER JOIN myothertable ON mytable.myid = myothertable.otherid \ -WHERE mytable.name = %(mytable_name)s AND mytable.myid = %(mytable_myid)s AND \ -myothertable.othername != %(myothertable_othername)s AND \ -EXISTS (select yay from foo where boo = lar)", - dialect=postgres.dialect() - ) + if False: + self.runtest(query, + "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername \ + FROM mytable LEFT OUTER JOIN myothertable ON mytable.myid = myothertable.otherid \ + WHERE mytable.name = %(mytable_name)s OR mytable.myid = %(mytable_myid)s OR \ + myothertable.othername != %(myothertable_othername)s OR \ + EXISTS (select yay from foo where boo = lar)", + dialect=postgres.dialect() + ) + print "-------------------------------------------------" self.runtest(query, "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername \ FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid(+) AND \ -mytable.name = :mytable_name AND mytable.myid = :mytable_myid AND \ -myothertable.othername != :myothertable_othername AND EXISTS (select yay from foo where boo = lar)", +(mytable.name = :mytable_name OR mytable.myid = :mytable_myid OR \ +myothertable.othername != :myothertable_othername OR EXISTS (select yay from foo where boo = lar))", dialect=oracle.OracleDialect(use_ansi = False)) query = table1.outerjoin(table2, table1.c.myid==table2.c.otherid).outerjoin(table3, table3.c.userid==table2.c.otherid) -- 2.47.2