From 1702222e2b1f91b0e0851e11ccb20ed034621b3c Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 3 Dec 2005 08:41:18 +0000 Subject: [PATCH] refactorings to sql generation, unions, engine location --- lib/sqlalchemy/ansisql.py | 11 +- lib/sqlalchemy/databases/oracle.py | 4 +- lib/sqlalchemy/sql.py | 163 ++++++++++++++++------------- 3 files changed, 103 insertions(+), 75 deletions(-) diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 99402aebb5..85b687ad15 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -170,7 +170,13 @@ class ANSICompiler(sql.Compiled): def visit_function(self, func): self.strings[func] = func.name + "(" + string.join([self.get_str(c) for c in func.clauses], ', ') + ")" - + + def visit_compound_select(self, cs): + text = string.join([self.get_str(c) for c in cs.selects], " " + cs.keyword + " ") + for tup in cs.clauses: + text += " " + tup[0] + " " + self.get_str(tup[1]) + self.strings[cs] = text + def visit_binary(self, binary): result = self.get_str(binary.left) if binary.operator is not None: @@ -245,7 +251,7 @@ class ANSICompiler(sql.Compiled): if t: text += " \nWHERE " + t - for tup in select._clauses: + for tup in select.clauses: text += " " + tup[0] + " " + self.get_str(tup[1]) if select.having is not None: @@ -275,6 +281,7 @@ class ANSICompiler(sql.Compiled): else: self.froms[join] = (self.get_from_text(join.left) + " JOIN " + righttext + " ON " + self.get_str(join.onclause)) + self.strings[join] = self.froms[join] def visit_insert(self, insert_stmt): self.isinsert = True diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index 9763c69784..91547d12c8 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -151,10 +151,10 @@ class OracleCompiler(ansisql.ANSICompiler): """oracle compiler modifies the lexical structure of Select statements to work under non-ANSI configured Oracle databases, if the use_ansi flag is False.""" - def __init__(self, engine, statement, bindparams, use_ansi = True): + def __init__(self, engine, statement, bindparams, use_ansi = True, **kwargs): self._outertable = None self._use_ansi = use_ansi - ansisql.ANSICompiler.__init__(self, engine, statement, bindparams) + ansisql.ANSICompiler.__init__(self, engine, statement, bindparams, **kwargs) def visit_join(self, join): if self._use_ansi: diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 486404f33a..6a17b44552 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -168,22 +168,12 @@ func = FunctionGateway() def _compound_clause(keyword, *clauses): return CompoundClause(keyword, *clauses) -def _compound_select(keyword, *selects, **params): - if len(selects) == 0: - return None - s = selects[0] - for n in selects[1:]: - s.append_clause(keyword, n) - - if params.get('order_by', None) is not None: - s.order_by(*params['order_by']) - - return s +def _compound_select(keyword, *selects, **kwargs): + return CompoundSelect(keyword, *selects, **kwargs) def _is_literal(element): return not isinstance(element, ClauseElement) and not isinstance(element, schema.SchemaItem) - class ClauseVisitor(schema.SchemaVisitor): """builds upon SchemaVisitor to define the visiting of SQL statement elements in addition to Schema elements.""" @@ -192,6 +182,7 @@ class ClauseVisitor(schema.SchemaVisitor): def visit_bindparam(self, bindparam):pass def visit_textclause(self, textclause):pass def visit_compound(self, compound):pass + def visit_compound_select(self, compound):pass def visit_binary(self, binary):pass def visit_alias(self, alias):pass def visit_select(self, select):pass @@ -258,8 +249,13 @@ class ClauseElement(object): change if the underlying structure of the ClauseElement changes.""" raise NotImplementedError(repr(self)) def _get_from_objects(self): + """returns objects represented in this ClauseElement that should be added to the + FROM list of a query.""" raise NotImplementedError(repr(self)) def _process_from_dict(self, data, asfrom): + """given a dictionary attached to a Select object, places the appropriate + FROM objects in the dictionary corresponding to this ClauseElement, + and possibly removes or modifies others.""" for f in self._get_from_objects(): data.setdefault(f.id, f) if asfrom: @@ -274,6 +270,20 @@ class ClauseElement(object): new structure can then be restructured without affecting the original.""" return self + def _find_engine(self): + try: + if self._engine is not None: + return self._engine + except AttributeError: + pass + for f in self._get_from_objects(): + engine = f.engine + if engine is not None: + return engine + else: + return None + + engine = property(lambda s: s._find_engine()) def compile(self, engine = None, bindparams = None, typemap=None): """compiles this SQL expression using its underlying SQLEngine to produce @@ -281,18 +291,19 @@ class ClauseElement(object): bindparams is a dictionary representing the default bind parameters to be used with the statement. """ if engine is None: - for f in self._get_from_objects(): - engine = f.engine - if engine is not None: break - else: - import sqlalchemy.ansisql as ansisql - engine = ansisql.engine() - #raise "no engine supplied, and no engine could be located within the clauses!" + engine = self.engine + + if engine is None: + raise "no SQLEngine could be located within this ClauseElement." return engine.compile(self, bindparams = bindparams, typemap=typemap) def __str__(self): - return str(self.compile()) + e = self.engine + if e is None: + import sqlalchemy.ansisql as ansisql + e = ansisql.engine() + return str(self.compile(e)) def execute(self, *multiparams, **params): """compiles and executes this SQL expression using its underlying SQLEngine. the @@ -410,8 +421,6 @@ class FromClause(ClauseElement): # this could also be [self], at the moment it doesnt matter to the Select object return [] - engine = property(lambda s: None) - def hash_key(self): return "FromClause(%s, %s)" % (repr(self.id), repr(self.from_name)) @@ -444,7 +453,7 @@ class TextClause(ClauseElement): def __init__(self, text = "", engine=None, isliteral=False): self.text = text self.parens = False - self.engine = engine + self._engine = engine self.id = id(self) if isliteral: if isinstance(text, int) or isinstance(text, long): @@ -489,7 +498,10 @@ class ClauseList(ClauseElement): c.accept_visitor(visitor) visitor.visit_clauselist(self) def _get_from_objects(self): - return [] + f = [] + for c in self.clauses: + f += c._get_from_objects() + return f class CompoundClause(ClauseList): """represents a list of clauses joined by an operator, such as AND or OR. @@ -600,10 +612,6 @@ class Selectable(FromClause): def alias(self, name): return Alias(self, name) - def union(self, other, **kwargs): - return union(self, other, **kwargs) - def union_all(self, other, **kwargs): - return union_all(self, other, **kwargs) def group_parenthesized(self): """indicates if this Selectable requires parenthesis when grouped into a compound statement""" @@ -644,7 +652,7 @@ class Join(Selectable): return "Join(%s, %s, %s, %s)" % (repr(self.left.hash_key()), repr(self.right.hash_key()), repr(self.onclause.hash_key()), repr(self.isouter)) def select(self, whereclauses = None, **params): - return select([self.left, self.right], and_(self.onclause, whereclauses), **params) + return select([self.left, self.right], whereclauses, from_obj=[self], **params) def accept_visitor(self, visitor): self.left.accept_visitor(visitor) @@ -706,8 +714,6 @@ class Alias(Selectable): engine = property(lambda s: s.selectable.engine) - - class ColumnImpl(Selectable, CompareMixin): """Selectable implementation that gets attached to a schema.Column object.""" @@ -821,9 +827,52 @@ class TableImpl(Selectable): def drop(self, **params): self.table.engine.drop(self.table) + +class TailClauseMixin(object): + def order_by(self, *clauses): + self._append_clause('order_by_clause', "ORDER BY", *clauses) + def group_by(self, *clauses): + self._append_clause('group_by_clause', "GROUP BY", *clauses) + def _append_clause(self, attribute, prefix, *clauses): + if not hasattr(self, attribute): + l = ClauseList(*clauses) + setattr(self, attribute, l) + self.append_clause(prefix, l) + else: + getattr(self, attribute).clauses += clauses + def append_clause(self, keyword, clause): + if type(clause) == str: + clause = TextClause(clause) + self.clauses.append((keyword, clause)) + +class CompoundSelect(Selectable, TailClauseMixin): + def __init__(self, keyword, *selects, **kwargs): + self.keyword = keyword + self.selects = selects + self.clauses = [] + order_by = kwargs.get('order_by', None) + if order_by: + self.order_by(*order_by) + group_by = kwargs.get('group_by', None) + if group_by: + self.group_by(*group_by) + + columns = property(lambda s:s.selects[0].columns) + def accept_visitor(self, visitor): + for tup in self.clauses: + tup[1].accept_visitor(visitor) + for s in self.selects: + s.accept_visitor(visitor) + visitor.visit_compound_select(self) + def _find_engine(self): + for s in self.selects: + e = s._find_engine() + if e: + return e + else: + return None - -class Select(Selectable): +class Select(Selectable, TailClauseMixin): """finally, represents a SELECT statement, with appendable clauses, as well as the ability to execute itself and return a result set.""" def __init__(self, columns, whereclause = None, from_obj = [], order_by = None, group_by=None, having=None, use_labels = False, distinct=False, engine = None): @@ -842,11 +891,11 @@ class Select(Selectable): # indicates if this select statement is a subquery as a criterion # inside of a WHERE clause self.is_where = False + self.clauses = [] self.distinct = distinct self._text = None self._raw_columns = [] - self._clauses = [] self._correlated = None self._correlator = Select.CorrelatedVisitor(self, False) self._wherecorrelator = Select.CorrelatedVisitor(self, True) @@ -929,18 +978,6 @@ class Select(Selectable): fromclause.accept_visitor(self._correlator) fromclause._process_from_dict(self._froms, True) - - def append_clause(self, keyword, clause): - if type(clause) == str: - clause = TextClause(clause) - self._clauses.append((keyword, clause)) - - def compile(self, engine = None, bindparams = None): - if engine is None: - engine = self.engine - if engine is None: - raise "no engine supplied, and no engine could be located within the clauses!" - return engine.compile(self, bindparams) def _get_froms(self): return [f for f in self._froms.values() if self._correlated is None or not self._correlated.has_key(f.id)] @@ -953,25 +990,17 @@ class Select(Selectable): self.whereclause.accept_visitor(visitor) if self.having is not None: self.having.accept_visitor(visitor) - for tup in self._clauses: + for tup in self.clauses: tup[1].accept_visitor(visitor) visitor.visit_select(self) - def order_by(self, *clauses): - self._append_clause('order_by_clause', "ORDER BY", *clauses) - def group_by(self, *clauses): - self._append_clause('group_by_clause', "GROUP BY", *clauses) - def _append_clause(self, attribute, prefix, *clauses): - if not hasattr(self, attribute): - l = ClauseList(*clauses) - setattr(self, attribute, l) - self.append_clause(prefix, l) - else: - getattr(self, attribute).clauses += clauses - def select(self, whereclauses = None, **params): return select([self], whereclauses, **params) + def union(self, other, **kwargs): + return union(self, other, **kwargs) + def union_all(self, other, **kwargs): + return union_all(self, other, **kwargs) def _find_engine(self): """tries to return a SQLEngine, either explicitly set in this object, or searched @@ -988,7 +1017,7 @@ class Select(Selectable): return None - engine = property(lambda s: s._find_engine()) +# engine = property(lambda s: s._find_engine()) def _get_from_objects(self): if self.is_where: @@ -1068,21 +1097,13 @@ class UpdateBase(ClauseElement): values.append((c, value)) return values - def compile(self, engine = None, bindparams = None): - if engine is None: - engine = self.engine - - if engine is None: - raise "no engine supplied, and no engine could be located within the clauses!" - - return engine.compile(self, bindparams) class Insert(UpdateBase): def __init__(self, table, values=None, **params): self.table = table self.select = None self.parameters = self._process_colparams(values) - self.engine = self.table.engine + self._engine = self.table.engine def accept_visitor(self, visitor): if self.select is not None: @@ -1095,7 +1116,7 @@ class Update(UpdateBase): self.table = table self.whereclause = whereclause self.parameters = self._process_colparams(values) - self.engine = self.table.engine + self._engine = self.table.engine def accept_visitor(self, visitor): if self.whereclause is not None: @@ -1106,7 +1127,7 @@ class Delete(UpdateBase): def __init__(self, table, whereclause, **params): self.table = table self.whereclause = whereclause - self.engine = self.table.engine + self._engine = self.table.engine def accept_visitor(self, visitor): if self.whereclause is not None: -- 2.47.2