From 66a86fe2e3be52f27a142aafcea2798911f7cc42 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sat, 6 Aug 2005 22:06:15 +0000 Subject: [PATCH] --- TODO | 20 ++++++++++++++++ lib/sqlalchemy/ansisql.py | 12 ++++++---- lib/sqlalchemy/databases/sqlite.py | 3 +-- lib/sqlalchemy/mapper.py | 2 +- lib/sqlalchemy/sql.py | 37 +++++++++++++++++++++++++----- 5 files changed, 61 insertions(+), 13 deletions(-) create mode 100644 TODO diff --git a/TODO b/TODO new file mode 100644 index 0000000000..4413543259 --- /dev/null +++ b/TODO @@ -0,0 +1,20 @@ +TODO: + +correlated subquery support, plus clauses like EXISTS, IN, etc: + + select foo from lala where g = (select x from y where lala.xx = y.bar) + select foo from lala where exists (select x from y where lala.xx = y.bar) + +table reflection, i.e. create tables with autoload = True + +sequences/autoincrement support + +Oracle module + +Postgres module + +MySQL module + +INSERT from a SELECT + + diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 93e7e737d6..a5e5a5b19f 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -116,11 +116,14 @@ class ANSICompiler(sql.Compiled): self.strings[list] = string.join([self.get_str(c) for c in list.clauses], ', ') def visit_binary(self, binary): - + if isinstance(binary.right, sql.Select): + s = self.get_str(binary.left) + " " + str(binary.operator) + " (" + self.get_str(binary.right) + ")" + else: + s = self.get_str(binary.left) + " " + str(binary.operator) + " " + self.get_str(binary.right) if binary.parens: - self.strings[binary] = "(" + self.get_str(binary.left) + " " + str(binary.operator) + " " + self.get_str(binary.right) + ")" + self.strings[binary] = "(" + s + ")" else: - self.strings[binary] = self.get_str(binary.left) + " " + str(binary.operator) + " " + self.get_str(binary.right) + self.strings[binary] = s def visit_bindparam(self, bindparam): self.binds[bindparam.shortname] = bindparam @@ -136,6 +139,7 @@ class ANSICompiler(sql.Compiled): def visit_alias(self, alias): self.froms[alias] = self.get_from_text(alias.selectable) + " " + alias.name + self.strings[alias] = self.get_str(alias.selectable) def visit_select(self, select): inner_columns = [] @@ -183,6 +187,7 @@ class ANSICompiler(sql.Compiled): def visit_table(self, table): self.froms[table] = table.name + self.strings[table] = "" def visit_join(self, join): if join.isouter: @@ -194,7 +199,6 @@ class ANSICompiler(sql.Compiled): def visit_insert(self, insert_stmt): colparams = insert_stmt.get_colparams(self._bindparams) - for c in colparams: b = c[1] self.binds[b.key] = b diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index fffda0916d..42b8c4c431 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -68,8 +68,7 @@ class SQLiteSQLEngine(ansisql.ANSISQLEngine): raise NotImplementedError() class SQLiteCompiler(ansisql.ANSICompiler): - def visit_insert(self, insert): - ansisql.ANSICompiler.visit_insert(self, insert) + pass class SQLiteColumnImpl(sql.ColumnSelectable): def _get_specification(self): diff --git a/lib/sqlalchemy/mapper.py b/lib/sqlalchemy/mapper.py index 561a0a4a7e..37cdf49317 100644 --- a/lib/sqlalchemy/mapper.py +++ b/lib/sqlalchemy/mapper.py @@ -70,7 +70,7 @@ def eagerload(name): def lazyload(name): return EagerLazySwitcher(name, toeager = False) -copy_containerclass Mapper(object): +class Mapper(object): def __init__(self, class_, selectable, table = None, properties = None, identitymap = None, use_smart_properties = True, isroot = True, echo = None): self.class_ = class_ self.selectable = selectable diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 643f9e3d33..a8b75b875d 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -101,6 +101,14 @@ def or_(*clauses): clause = _compound_clause('OR', *clauses) return clause +def exists(*args, **params): + s = select(*args, **params) + return BinaryClause(TextClause("EXISTS"), s, '') + +def in_(*args, **params): + s = select(*args, **params) + return BinaryClause(TextClause("IN"), s, '') + def union(*selects, **params): return _compound_select('UNION', *selects, **params) @@ -346,7 +354,9 @@ class ClauseList(ClauseElement): for c in self.clauses: c.accept_visitor(visitor) visitor.visit_clauselist(self) - + + def _get_from_objects(self): + return [] class BinaryClause(ClauseElement): """represents two clauses with an operator in between""" @@ -354,6 +364,8 @@ class BinaryClause(ClauseElement): def __init__(self, left, right, operator): self.left = left self.right = right + if isinstance(right, Select): + right._set_from_objects([]) self.operator = operator self.parens = False @@ -429,9 +441,11 @@ class Join(Selectable): return result class Alias(Selectable): - def __init__(self, selectable, alias): + def __init__(self, selectable, alias = None): self.selectable = selectable self.columns = util.OrderedProperties() + if alias is None: + alias = id(self) self.name = alias self.id = self.name self.count = 0 @@ -479,12 +493,12 @@ class ColumnSelectable(Selectable): return [self.column.table] def _compare(self, operator, obj): - if not isinstance(obj, BindParamClause) and not isinstance(obj, schema.Column): + if not isinstance(obj, ClauseElement) and not isinstance(obj, schema.Column): if self.column.table.name is None: obj = BindParamClause(self.name, obj, shortname = self.name) else: obj = BindParamClause(self.column.table.name + "_" + self.name, obj, shortname = self.name) - + return BinaryClause(self.column, obj, operator) def __lt__(self, other): @@ -605,6 +619,14 @@ class Select(Selectable): for f in self.whereclause._get_from_objects(): self.froms.setdefault(f.id, f) + class CorrelatedVisitor(ClauseVisitor): + def visit_select(s, select): + for f in self.froms.keys(): + select.clear_from(f) + self.whereclause.accept_visitor(CorrelatedVisitor()) + + def clear_from(self, id): + self.append_from(FromClause(from_name = None, from_key = id)) def append_from(self, fromclause): if type(fromclause) == str: fromclause = FromClause(from_name = fromclause) @@ -667,8 +689,11 @@ class Select(Selectable): return None + def _set_from_objects(self, obj): + self._from_obj = obj + def _get_from_objects(self): - return [self] + return getattr(self, '_from_obj', [self]) class UpdateBase(ClauseElement): @@ -722,7 +747,7 @@ class UpdateBase(ClauseElement): for c in self.table.columns: if d.has_key(c): value = d[c] - if isinstance(value, str): + if not isinstance(value, schema.Column) and not isinstance(value, ClauseElement): value = bindparam(c.name, value) values.append((c, value)) return values -- 2.47.2