]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
(no commit message)
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 6 Aug 2005 22:06:15 +0000 (22:06 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 6 Aug 2005 22:06:15 +0000 (22:06 +0000)
TODO [new file with mode: 0644]
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/databases/sqlite.py
lib/sqlalchemy/mapper.py
lib/sqlalchemy/sql.py

diff --git a/TODO b/TODO
new file mode 100644 (file)
index 0000000..4413543
--- /dev/null
+++ b/TODO
@@ -0,0 +1,20 @@
+TODO:
+
+correlated subquery support, plus clauses like EXISTS, IN, etc:
+
+    select foo from lala where g = (select x from y where lala.xx = y.bar)
+    select foo from lala where exists (select x from y where lala.xx = y.bar)
+
+table reflection, i.e. create tables with autoload = True
+
+sequences/autoincrement support
+
+Oracle module
+
+Postgres module
+
+MySQL module
+
+INSERT from a SELECT
+
+
index 93e7e737d6a8bb1c3e2aa6aaee3f6eeb4435518c..a5e5a5b19f2b5a47af96ce22cbbb2fdfbecb5292 100644 (file)
@@ -116,11 +116,14 @@ class ANSICompiler(sql.Compiled):
         self.strings[list] = string.join([self.get_str(c) for c in list.clauses], ', ')
         
     def visit_binary(self, binary):
-        
+        if isinstance(binary.right, sql.Select):
+            s = self.get_str(binary.left) + " " + str(binary.operator) + " (" + self.get_str(binary.right) + ")"
+        else:
+            s = self.get_str(binary.left) + " " + str(binary.operator) + " " + self.get_str(binary.right)
         if binary.parens:
-           self.strings[binary] = "(" + self.get_str(binary.left) + " " + str(binary.operator) + " " + self.get_str(binary.right) + ")"
+           self.strings[binary] = "(" + s + ")"
         else:
-            self.strings[binary] = self.get_str(binary.left) + " " + str(binary.operator) + " " + self.get_str(binary.right)
+            self.strings[binary] = s
         
     def visit_bindparam(self, bindparam):
         self.binds[bindparam.shortname] = bindparam
@@ -136,6 +139,7 @@ class ANSICompiler(sql.Compiled):
 
     def visit_alias(self, alias):
         self.froms[alias] = self.get_from_text(alias.selectable) + " " + alias.name
+        self.strings[alias] = self.get_str(alias.selectable)
 
     def visit_select(self, select):
         inner_columns = []
@@ -183,6 +187,7 @@ class ANSICompiler(sql.Compiled):
 
     def visit_table(self, table):
         self.froms[table] = table.name
+        self.strings[table] = ""
 
     def visit_join(self, join):
         if join.isouter:
@@ -194,7 +199,6 @@ class ANSICompiler(sql.Compiled):
 
     def visit_insert(self, insert_stmt):
         colparams = insert_stmt.get_colparams(self._bindparams)
-
         for c in colparams:
             b = c[1]
             self.binds[b.key] = b
index fffda0916daab4f3c2ed5ebc9bdd55901e7c86b2..42b8c4c4315452d1fa81fdb6ef62002b8370c309 100644 (file)
@@ -68,8 +68,7 @@ class SQLiteSQLEngine(ansisql.ANSISQLEngine):
         raise NotImplementedError()
 
 class SQLiteCompiler(ansisql.ANSICompiler):
-    def visit_insert(self, insert):
-        ansisql.ANSICompiler.visit_insert(self, insert)
+    pass
         
 class SQLiteColumnImpl(sql.ColumnSelectable):
     def _get_specification(self):
index 561a0a4a7ed59dd703c1143a3a79b993e836ab65..37cdf49317226073ef680dce9d94ffd30ffd7d00 100644 (file)
@@ -70,7 +70,7 @@ def eagerload(name):
 def lazyload(name):
     return EagerLazySwitcher(name, toeager = False)
 
-copy_containerclass Mapper(object):
+class Mapper(object):
     def __init__(self, class_, selectable, table = None, properties = None, identitymap = None, use_smart_properties = True, isroot = True, echo = None):
         self.class_ = class_
         self.selectable = selectable
index 643f9e3d335ff3c716106173bdc38302ffc61307..a8b75b875dc1585628476a51fe06c79d1e1feab2 100644 (file)
@@ -101,6 +101,14 @@ def or_(*clauses):
     clause = _compound_clause('OR', *clauses)
     return clause
 
+def exists(*args, **params):
+    s = select(*args, **params)
+    return BinaryClause(TextClause("EXISTS"), s, '')
+
+def in_(*args, **params):
+    s = select(*args, **params)
+    return BinaryClause(TextClause("IN"), s, '')
+    
 def union(*selects, **params):
     return _compound_select('UNION', *selects, **params)
 
@@ -346,7 +354,9 @@ class ClauseList(ClauseElement):
         for c in self.clauses:
             c.accept_visitor(visitor)
         visitor.visit_clauselist(self)
-        
+    
+    def _get_from_objects(self):
+        return []
         
 class BinaryClause(ClauseElement):
     """represents two clauses with an operator in between"""
@@ -354,6 +364,8 @@ class BinaryClause(ClauseElement):
     def __init__(self, left, right, operator):
         self.left = left
         self.right = right
+        if isinstance(right, Select):
+            right._set_from_objects([])
         self.operator = operator
         self.parens = False
 
@@ -429,9 +441,11 @@ class Join(Selectable):
         return result
         
 class Alias(Selectable):
-    def __init__(self, selectable, alias):
+    def __init__(self, selectable, alias = None):
         self.selectable = selectable
         self.columns = util.OrderedProperties()
+        if alias is None:
+            alias = id(self)
         self.name = alias
         self.id = self.name
         self.count = 0
@@ -479,12 +493,12 @@ class ColumnSelectable(Selectable):
         return [self.column.table]
     
     def _compare(self, operator, obj):
-        if not isinstance(obj, BindParamClause) and not isinstance(obj, schema.Column):
+        if not isinstance(obj, ClauseElement) and not isinstance(obj, schema.Column):
             if self.column.table.name is None:
                 obj = BindParamClause(self.name, obj, shortname = self.name)
             else:
                 obj = BindParamClause(self.column.table.name + "_" + self.name, obj, shortname = self.name)
-        
+
         return BinaryClause(self.column, obj, operator)
 
     def __lt__(self, other):
@@ -605,6 +619,14 @@ class Select(Selectable):
         for f in self.whereclause._get_from_objects():
             self.froms.setdefault(f.id, f)
 
+        class CorrelatedVisitor(ClauseVisitor):
+            def visit_select(s, select):
+                for f in self.froms.keys():
+                    select.clear_from(f)
+        self.whereclause.accept_visitor(CorrelatedVisitor())
+   
+    def clear_from(self, id):
+        self.append_from(FromClause(from_name = None, from_key = id))
     def append_from(self, fromclause):
         if type(fromclause) == str:
             fromclause = FromClause(from_name = fromclause)
@@ -667,8 +689,11 @@ class Select(Selectable):
             
         return None
 
+    def _set_from_objects(self, obj):
+        self._from_obj = obj
+        
     def _get_from_objects(self):
-        return [self]
+        return getattr(self, '_from_obj', [self])
 
 
 class UpdateBase(ClauseElement):
@@ -722,7 +747,7 @@ class UpdateBase(ClauseElement):
         for c in self.table.columns:
             if d.has_key(c):
                 value = d[c]
-                if isinstance(value, str):
+                if not isinstance(value, schema.Column) and not isinstance(value, ClauseElement):
                     value = bindparam(c.name, value)
                 values.append((c, value))
         return values