]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
(no commit message)
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 6 Aug 2005 20:32:42 +0000 (20:32 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 6 Aug 2005 20:32:42 +0000 (20:32 +0000)
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/mapper.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql.py
test/select.py

index 3f6cbb835df6d699732b9d88580fffc01e2515eb..93e7e737d6a8bb1c3e2aa6aaee3f6eeb4435518c 100644 (file)
@@ -207,13 +207,19 @@ class ANSICompiler(sql.Compiled):
 
     def visit_update(self, update_stmt):
         colparams = update_stmt.get_colparams(self._bindparams)
-        
-        for c in colparams:
-            b = c[1]
-            self.binds[b.key] = b
-            self.binds[b.shortname] = b
-            
-        text = "UPDATE " + update_stmt.table.name + " SET " + string.join(["%s=:%s" % (c[0].name, c[1].key) for c in colparams], ', ')
+        def create_param(p):
+            if isinstance(p, BindParamClause):
+                self.binds[p.key] = p
+                self.binds[p.shortname] = p
+                return ":" + p.key
+            else:
+                p.accept_visitor(self)
+                if isinstance(p, ClauseElement):
+                    return "(" + self.get_str(p) + ")"
+                else:
+                    return self.get_str(p)
+                
+        text = "UPDATE " + update_stmt.table.name + " SET " + string.join(["%s=%s" % (c[0].name, create_param(c[1])) for c in colparams], ', ')
         
         if update_stmt.whereclause:
             text += " WHERE " + self.get_str(update_stmt.whereclause)
index 28d42792f834b18b50f42eeca37dd30bc9325ecb..561a0a4a7ed59dd703c1143a3a79b993e836ab65 100644 (file)
@@ -70,7 +70,7 @@ def eagerload(name):
 def lazyload(name):
     return EagerLazySwitcher(name, toeager = False)
 
-class Mapper(object):
+copy_containerclass Mapper(object):
     def __init__(self, class_, selectable, table = None, properties = None, identitymap = None, use_smart_properties = True, isroot = True, echo = None):
         self.class_ = class_
         self.selectable = selectable
@@ -408,7 +408,7 @@ class LazyLoader(PropertyLoader):
             self.lazywhere = sql.and_(self.primaryjoin, self.secondaryjoin)
         else:
             self.lazywhere = self.primaryjoin
-        self.lazywhere = self.lazywhere.copy_structure()
+        self.lazywhere = self.lazywhere.copy_container()
         li = LazyIzer(primarytable)
         self.lazywhere.accept_visitor(li)
         self.binds = li.binds
index 41cb166212febccc1c3af69df7e9db7cecb54b8a..74bb2e3c5e33cf2ed2965399ec700e1de359e763 100644 (file)
 from sqlalchemy.util import *
 import copy
 
-engine = None
-
-
-__ALL__ = ['Table', 'Column', 'Relation', 'Sequence', 
+__ALL__ = ['Table', 'Column', 'Sequence', 
             'INT', 'CHAR', 'VARCHAR', 'TEXT', 'FLOAT', 'DECIMAL', 
             'TIMESTAMP', 'DATETIME', 'CLOB', 'BLOB', 'BOOLEAN'
             ]
 
 
-class INT: pass
+class INT:
+    """integer datatype"""
+    pass
 
 class CHAR:
+    """character datatype"""
     def __init__(self, length):
         self.length = length
-        
+
 class VARCHAR:
     def __init__(self, length):
         self.length = length
-        
-class TEXT: pass
+
+
 class FLOAT:
     def __init__(self, precision, length):
         self.precision = precision
         self.length = length
-        
+
+class TEXT: pass
 class DECIMAL: pass
 class TIMESTAMP: pass
 class DATETIME: pass
@@ -56,12 +57,18 @@ class SchemaItem(object):
     def _init_items(self, *args):
         for item in args:
             item._set_parent(self)
-            
-    def accept_visitor(self, visitor): raise NotImplementedError()
-    def _set_parent(self, parent): raise NotImplementedError()
+
+    def accept_visitor(self, visitor):
+        raise NotImplementedError()
+
+    def _set_parent(self, parent):
+        """a child item attaches itself to its parent via this method."""
+        raise NotImplementedError()
+
     def hash_key(self):
+        """returns a string that identifies this SchemaItem uniquely"""
         return repr(self)
-        
+
     def __getattr__(self, key):
         return getattr(self._impl, key)
 
@@ -69,27 +76,27 @@ class SchemaItem(object):
 class Table(SchemaItem):
     """represents a relational database table."""
     
-    def __init__(self, name, engine, *args, **params):
+    def __init__(self, name, engine, *args, **kwargs):
         self.name = name
         self.columns = OrderedProperties()
         self.c = self.columns
         self.relations = []
+        self.primary_keys = []
         self.engine = engine
         self._impl = self.engine.tableimpl(self)
         self._init_items(*args)
-        
-        if params.get('autoload', False):
+
+        # load column definitions from the database if 'autoload' is defined
+        if kwargs.get('autoload', False):
             self.engine.reflecttable(self)
 
     def append_item(self, item):
         self._init_items(item)
-        
+
     def _set_parent(self, schema):
         schema.tables[self.name] = self
         self.schema = schema
 
-    primary_keys = property (lambda self: [c for c in self.columns if c.primary_key])
-
     def accept_visitor(self, visitor): 
         for c in self.columns:
             c.accept_visitor(visitor)
@@ -97,29 +104,27 @@ class Table(SchemaItem):
 
 class Column(SchemaItem):
     """represents a column in a database table."""
-    def __init__(self, name, type, reference = None, key = None, primary_key = False, *args, **params):
+    def __init__(self, name, type, key = None, primary_key = False, *args):
         self.name = name
         self.type = type
         self.sequences = OrderedProperties()
-        self.reference = reference
         self.key = key or name
         self.primary_key = primary_key
         self._items = args
 
     def _set_parent(self, table):
         table.columns[self.key] = self
+        if self.primary_key:
+            table.primary_keys.append(self)
         self.table = table
         self.engine = table.engine
 
         self._impl = self.engine.columnimpl(self)
-                        
         self._init_items(*self._items)
-        if self.reference is not None:
-            Relation(self.table, self.reference.table, self == self.reference)
 
     def _make_proxy(self, selectable, name = None):
-        # wow! using copy.copy(c) adds a full second to the select.py unittest package
+        """creates a copy of this Column for use in a new selectable unit"""
+        # using copy.copy(c) seems to add a full second to the select.py unittest package
         #c = copy.copy(self)
         #if name is not None:
          #   c.name = name
@@ -131,7 +136,7 @@ class Column(SchemaItem):
         selectable.columns[c.key] = c
         c._impl = self.engine.columnimpl(c)
         return c
-                
+
     def accept_visitor(self, visitor): 
         return visitor.visit_column(self)
 
@@ -144,42 +149,24 @@ class Column(SchemaItem):
     def __str__(self): return self._impl.__str__()
 
 
-class Relation(SchemaItem):
-    def __init__(self, parent, child, relationship, association = None, lazy = True):
-        self.parent = parent
-        self.child = child
-        self.relationship = relationship
-        self.lazy = lazy
-        self.association = association
-
-        self._set_parent(parent)
-
-    def _set_parent(self, table):
-        table.relations.append(self)
-        self.table = table
-
-    def accept_visitor(self, visitor):
-        visitor.visit_relation(self)
-            
 class Sequence(SchemaItem):
-    """represents a sequence."""
+    """represents a sequence, which applies to Oracle and Postgres databases."""
     def set_parent(self, column, key):
         column.sequences[key] = self
         self.column = column
-        
-    def accept_visitor(self, visitor): 
+    def accept_visitor(self, visitor):
         return visitor.visit_sequence(self)
-        
+
 class SchemaEngine(object):
+    """a factory object used to create implementations for schema objects"""
     def tableimpl(self, table):
         raise NotImplementedError()
-        
+
     def columnimpl(self, column):
         raise NotImplementedError()
 
 class SchemaVisitor(object):
-        """base class for an object that traverses a Schema object structure,
-        or sub-objects within one, and acts upon each node."""
+        """base class for an object that traverses across Schema objects"""
 
         def visit_schema(self, schema):pass
         def visit_table(self, table):pass
index 2c133f6c166375f987bd1e9c75030d63019a9274..643f9e3d335ff3c716106173bdc38302ffc61307 100644 (file)
@@ -16,9 +16,8 @@
 # Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
 
 
-"""base sql module used by all sql implementations.  defines abstract units which construct
-expression trees that generate into text strings + bind parameters.
-"""
+"""defines the base components of SQL expression trees."""
+
 import sqlalchemy.schema as schema
 import sqlalchemy.util as util
 import string
@@ -26,28 +25,74 @@ import string
 __ALL__ = ['textclause', 'select', 'join', 'and_', 'or_', 'union', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'bindparam', 'sequence']
 
 def desc(column):
+    """returns a descending ORDER BY clause element"""
     return CompoundClause(None, column, "DESC")
 
 def asc(column):
+    """returns an ascending ORDER BY clause element"""
     return CompoundClause(None, column, "ASC")
 
-def outerjoin(left, right, onclause, **params):
-    return Join(left, right, onclause, isouter = True, **params)
+def outerjoin(left, right, onclause, **kwargs):
+    """returns an OUTER JOIN clause element, given the left and right hand expressions,
+    as well as the ON condition's expression.  When chaining joins together, the previous JOIN
+    expression should be specified as the left side of this JOIN expression."""
+    return Join(left, right, onclause, isouter = True, **kwargs)
+
+def join(left, right, onclause, **kwargs):
+    """returns a JOIN clause element (regular inner join), given the left and right hand expressions,
+    as well as the ON condition's expression.  When chaining joins together, the previous JOIN
+    expression should be specified as the left side of this JOIN expression."""
+    return Join(left, right, onclause, **kwargs)
+
+def select(columns, whereclause = None, from_obj = [], **kwargs):
+    """returns a SELECT clause element, given a list of columns and/or selectable items to select
+    columns from, an optional expression for the WHERE clause, an optional list of "FROM" objects
+    to select from, and additional parameters."""
+    return Select(columns, whereclause = whereclause, from_obj = from_obj, **kwargs)
+
+def insert(table, values = None, **kwargs):
+    """returns an INSERT clause element.
     
-def join(left, right, onclause, **params):
-    return Join(left, right, onclause, **params)
-
-def select(columns, whereclause = None, from_obj = [], **params):
-    return Select(columns, whereclause = whereclause, from_obj = from_obj, **params)
-
-def insert(table, values = None, **params):
-    return Insert(table, values, **params)
+    'table' is the table to be inserted into.
+    'values' is a dictionary which specifies the column specifications of the INSERT, and is optional.  
+    If left as None, the
+    column specifications are determined from the bind parameters used during the compile phase of the
+    INSERT statement.  If the bind parameters also are None during the compile phase, then the column
+    specifications will be generated from the full list of table columns.
+
+    If both 'values' and compile-time bind parameters are present, the compile-time bind parameters
+    override the information specified within 'values' on a per-key basis.
+
+    The keys within 'values' can be either Column objects or their string identifiers.  
+    Each key may reference one of: a literal data value (i.e. string, number, etc.), a Column object,
+    or a SELECT statement.  If a SELECT statement is specified which references this INSERT statement's
+    table, the statement will be correlated against the INSERT statement.  
+    """
+    return Insert(table, values, **kwargs)
 
-def update(table, whereclause = None, values = None, **params):
-    return Update(table, whereclause, values, **params)
+def update(table, whereclause = None, values = None, **kwargs):
+    """returns an UPDATE clause element.  
+    
+    'table' is the table to be updated.
+    'whereclause' is a ClauseElement describing the WHERE condition of the UPDATE statement.
+    'values' is a dictionary which specifies the SET conditions of the UPDATE, and is optional.  
+    If left as None, the
+    SET conditions are determined from the bind parameters used during the compile phase of the
+    UPDATE statement.  If the bind parameters also are None during the compile phase, then the SET
+    conditions will be generated from the full list of table columns.
+    
+    If both 'values' and compile-time bind parameters are present, the compile-time bind parameters
+    override the information specified within 'values' on a per-key basis.
+    
+    The keys within 'values' can be either Column objects or their string identifiers.  
+    Each key may reference one of: a literal data value (i.e. string, number, etc.), a Column object,
+    or a SELECT statement.  If a SELECT statement is specified which references this UPDATE statement's
+    table, the statement will be correlated against the UPDATE statement.  
+    """
+    return Update(table, whereclause, values, **kwargs)
 
-def delete(table, whereclause = None, **params):
-    return Delete(table, whereclause, **params)
+def delete(table, whereclause = None, **kwargs):
+    return Delete(table, whereclause, **kwargs)
 
 def and_(*clauses):
     return _compound_clause('AND', *clauses)
@@ -55,7 +100,7 @@ def and_(*clauses):
 def or_(*clauses):
     clause = _compound_clause('OR', *clauses)
     return clause
-    
+
 def union(*selects, **params):
     return _compound_select('UNION', *selects, **params)
 
@@ -73,23 +118,25 @@ def textclause(text):
 
 def sequence():
     return Sequence()
-    
+
 def _compound_clause(keyword, *clauses):
     return CompoundClause(keyword, *clauses)
 
 def _compound_select(keyword, *selects, **params):
-    if len(selects) == 0: return None
-    
+    if len(selects) == 0:
+        return None
     s = selects[0]
     for n in selects[1:]:
         s.append_clause(keyword, n)
-        
+
     if params.get('order_by', None) is not None:
         s.order_by(*params['order_by'])
 
     return s
 
 class ClauseVisitor(schema.SchemaVisitor):
+    """builds upon SchemaVisitor to define the visiting of SQL statement elements in 
+    addition to Schema elements."""
     def visit_columnclause(self, column):pass
     def visit_fromclause(self, fromclause):pass
     def visit_bindparam(self, bindparam):pass
@@ -101,13 +148,23 @@ class ClauseVisitor(schema.SchemaVisitor):
     def visit_join(self, join):pass
     
 class Compiled(ClauseVisitor):
-    pass
-    
+    """represents a compiled SQL expression.  the __str__ method of the Compiled object
+    should produce the actual text of the statement.  Compiled objects are specific to the database
+    library that created them, and also may or may not be specific to the columns referenced 
+    within a particular set of bind parameters.  In no case should the Compiled object be dependent
+    on the actual values of those bind parameters, even though it may reference those values
+    as defaults."""
+    def __str__(self):
+        raise NotImplementedError()
+    def get_params(self, **params):
+        """returns the bind params for this compiled object, with values overridden by 
+        those given in the **params dictionary"""
+        raise NotImplementedError()
+        
 class ClauseElement(object):
-    """base class for elements of a generated SQL statement.
+    """base class for elements of a programmatically constructed SQL expression.
     
-    includes a parameter hash to store bind parameter key/value pairs,
-    as well as a list of 'from objects' which collects items to be placed
+    includes a list of 'from objects' which collects items to be placed
     in the FROM clause of a SQL statement.
     
     when many ClauseElements are attached together, the from objects and bind
@@ -115,26 +172,46 @@ class ClauseElement(object):
     """
 
     def hash_key(self):
+        """returns a string that uniquely identifies the concept this ClauseElement represents.
+        
+        two ClauseElements can have the same value for hash_key() iff they both correspond to the 
+        exact same generated SQL.  This allows the hash_key() values of a collection of ClauseElements
+        to be constructed into a larger identifying string for the purpose of caching a SQL expression.
+        
+        Note that since ClauseElements may be mutable, the hash_key() value is subject to change
+        if the underlying structure of the ClauseElement changes."""
         raise NotImplementedError(repr(self))
     def _get_from_objects(self):
         raise NotImplementedError(repr(self))
     def accept_visitor(self, visitor):
         raise NotImplementedError(repr(self))
 
-    def compile(self, engine, bindparams = None):
-        return engine.compile(self, bindparams = bindparams)
 
-    def copy_structure(self):
-        """allows the copying of a statement's containers, so that a modified statement
-        can be produced without affecting the original.  containing clauseelements,
-        like Select, Join, CompoundClause, BinaryClause, etc.,  should produce a copy of 
-        themselves, whereas "leaf-node" clauseelements should return themselves."""
+    def copy_container(self):
+        """should return a copy of this ClauseElement, iff this ClauseElement contains other 
+        ClauseElements.  Otherwise, it should be left alone to return self.  This is used to create
+        copies of expression trees that still reference the same "leaf nodes".  The new structure
+        can then be restructured without affecting the original."""
         return self
-        
+
     def _engine(self):
+        """should return a SQLEngine instance that is associated with this expression tree.
+        this engine is usually attached to one of the underlying Table objects within the expression."""
         raise NotImplementedError("Object %s has no built-in SQLEngine." % repr(self))
-        
+
+    def compile(self, engine, bindparams = None):
+        """compiles this SQL expression using its underlying SQLEngine to produce
+        a Compiled object.  The actual SQL statement is the Compiled object's string representation.   
+        bindparams is an optional dictionary representing the bind parameters to be used with 
+        the statement.  Currently, only the compilations of INSERT and UPDATE statements
+        use the bind parameters, in order to determine which
+        table columns should be used in the statement."""
+        return engine.compile(self, bindparams = bindparams)
+
     def execute(self, **params):
+        """compiles and executes this SQL expression using its underlying SQLEngine.
+        the given **params are used as bind parameters when compiling and executing the expression. 
+        the DBAPI cursor object is returned."""
         e = self._engine()
         c = self.compile(e, bindparams = params)
         # TODO: do pre-execute right here, for sequences, if the compiled object
@@ -142,13 +219,14 @@ class ClauseElement(object):
         return e.execute(str(c), c.get_params(), echo = getattr(self, 'echo', None))
 
     def result(self, **params):
+        """the same as execute(), except a RowProxy object is returned instead of a DBAPI cursor."""
         e = self._engine()
         c = self.compile(e, bindparams = params)
-        return e.result(str(c), c.binds)
-        
+        return e.result(str(c), c.get_params(), echo = getattr(self, 'echo', None))
+
 class ColumnClause(ClauseElement):
-    """represents a column clause element in a SQL statement."""
-    
+    """represents a textual column clause in a SQL statement."""
+
     def __init__(self, text, selectable):
         self.text = text
         self.table = selectable
@@ -165,7 +243,7 @@ class ColumnClause(ClauseElement):
 
     def hash_key(self):
         return "ColumnClause(%s, %s)" % (self.text, self.table.hash_key())
-        
+
     def _get_from_objects(self):
         return []
 
@@ -236,8 +314,8 @@ class CompoundClause(ClauseElement):
             if c is None: continue
             self.append(c)
     
-    def copy_structure(self):
-        clauses = [clause.copy_structure() for clause in self.clauses]
+    def copy_container(self):
+        clauses = [clause.copy_container() for clause in self.clauses]
         return CompoundClause(self.operator, *clauses)
         
     def append(self, clause):
@@ -279,8 +357,8 @@ class BinaryClause(ClauseElement):
         self.operator = operator
         self.parens = False
 
-    def copy_structure(self):
-        return BinaryClause(self.left.copy_structure(), self.right.copy_structure(), self.operator)
+    def copy_container(self):
+        return BinaryClause(self.left.copy_container(), self.right.copy_container(), self.operator)
         
     def _get_from_objects(self):
         return self.left._get_from_objects() + self.right._get_from_objects()
@@ -394,7 +472,7 @@ class ColumnSelectable(Selectable):
             self.label = self.column.name
             self.fullname = self.column.name
 
-    def copy_structure(self):
+    def copy_container(self):
         return self.column
     
     def _get_from_objects(self):
@@ -594,29 +672,59 @@ class Select(Selectable):
 
 
 class UpdateBase(ClauseElement):
+    """forms the base for INSERT, UPDATE, and DELETE statements.  
+    Deals with the special needs of INSERT and UPDATE parameter lists -  
+    these statements have two separate lists of parameters, those
+    defined when the statement is constructed, and those specified at compile time."""
+    
+    def _process_colparams(self, parameters):
+        if parameters is None:
+            return None
+
+        for key in parameters.keys():
+            value = parameters[key]
+            if isinstance(value, Select):
+                value.append_from(FromClause(from_key=self.table.id))
+            elif not isinstance(value, schema.Column) and not isinstance(value, ClauseElement):
+                try:
+                    col = self.table.c[key]
+                    parameters[key] = bindparam(col.name, value)
+                except KeyError:
+                    del parameters[key]
+
+        return parameters
+        
     def get_colparams(self, parameters):
-        values = []
+        # case one: no parameters in the statement, no parameters in the 
+        # compiled params - just return binds for all the table columns
+        if parameters is None and self.parameters is None:
+            return [(c, bindparam(c.name)) for c in self.table.columns]
 
+        # if we have statement parameters - set defaults in the 
+        # compiled params
         if parameters is None:
-            parameters = self.parameters
+            parameters = {}
             
-        if parameters is None:
-            for c in self.table.columns:
-                values.append((c, bindparam(c.name)))                
-        else:
-            d = {}
-            for key, value in parameters.iteritems():
-                if isinstance(key, schema.Column):
-                    d[key] = value
-                else:
-                    d[self.table.columns[str(key)]] = value
-                
-            for c in self.table.columns:
-                if d.has_key(c):
-                    value = d[c]
-                    if not isinstance(value, BindParamClause):
-                        value = bindparam(c.name, value)
-                    values.append((c, value))
+        if self.parameters is not None:
+            for k, v in self.parameters.iteritems():
+                parameters.setdefault(k, v)
+
+        # now go thru compiled params, get the Column object for each key
+        d = {}
+        for key, value in parameters.iteritems():
+            if isinstance(key, schema.Column):
+                d[key] = value
+            else:
+                d[self.table.columns[str(key)]] = value
+
+        # create a list of column assignment clauses as tuples
+        values = []
+        for c in self.table.columns:
+            if d.has_key(c):
+                value = d[c]
+                if isinstance(value, str):
+                    value = bindparam(c.name, value)
+                values.append((c, value))
         return values
 
     def _engine(self):
@@ -635,7 +743,7 @@ class Insert(UpdateBase):
     def __init__(self, table, parameters = None, **params):
         self.table = table
         self.select = None
-        self.parameters = parameters
+        self.parameters = self._process_colparams(parameters)
         self.engine = self.table._engine()
         
     def accept_visitor(self, visitor):
@@ -650,41 +758,33 @@ class Insert(UpdateBase):
     def compile(self, engine = None, bindparams = None):
         if engine is None:
             engine = self.engine
-            
         if engine is None:
             raise "no engine supplied, and no engine could be located within the clauses!"
-
         return engine.compile(self, bindparams)
 
 class Update(UpdateBase):
     def __init__(self, table, whereclause, parameters = None, **params):
         self.table = table
         self.whereclause = whereclause
-
-        self.parameters = parameters
+        self.parameters = self._process_colparams(parameters)
         self.engine = self.table._engine()
 
-    
     def accept_visitor(self, visitor):
         if self.whereclause is not None:
             self.whereclause.accept_visitor(visitor)
-
         visitor.visit_update(self)
 
 class Delete(UpdateBase):
     def __init__(self, table, whereclause, **params):
         self.table = table
         self.whereclause = whereclause
-
         self.engine = self.table._engine()
 
-    
     def accept_visitor(self, visitor):
         if self.whereclause is not None:
             self.whereclause.accept_visitor(visitor)
-
         visitor.visit_delete(self)
-        
+
 class Sequence(BindParamClause):
     def __init__(self):
         BindParamClause.__init__(self, 'sequence')
index b1e2ed13caeaf2eea7293738076c295f4ddce953..2d3f23eb6199370edd8db8065b06a844243154f0 100644 (file)
@@ -11,38 +11,26 @@ from sqlalchemy.schema import *
 from testbase import PersistTest
 import unittest, re
 
-class SelectTest(PersistTest):
-    
-    def setUp(self):
-
-        self.table = Table('mytable', db,
-            Column('myid', 3, key = 'id'),
-            Column('name', 4, key = 'name'),
-            Column('description', 4, key = 'description'),
-        )
 
-        self.table2 = Table(
-            'myothertable', db,
-            Column('otherid',3, key='id'),
-            Column('othername', 4, key='name'),
-        )
+table = Table('mytable', db,
+    Column('myid', 3, key = 'id'),
+    Column('name', 4, key = 'name'),
+    Column('description', 4, key = 'description'),
+)
 
-        self.table3 = Table(
-            'thirdtable', db,
-            Column('userid', 5, key='id'),
-            Column('otherstuff', 5),
-        )
+table2 = Table(
+    'myothertable', db,
+    Column('otherid',3, key='id'),
+    Column('othername', 4, key='name'),
+)
 
-    
-    def testoperator(self):
-        return
-        table = Table(
-            'mytable',
-            Column('myid',3, key='id'),
-            Column('name', 4)
-        )
+table3 = Table(
+    'thirdtable', db,
+    Column('userid', 5, key='id'),
+    Column('otherstuff', 5),
+)
 
-        print (table.c.id == 5)
+class SelectTest(PersistTest):
 
     def testtext(self):
         self.runtest(
@@ -52,14 +40,14 @@ class SelectTest(PersistTest):
         )
     
     def testtableselect(self):
-        self.runtest(self.table.select(), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable")
+        self.runtest(table.select(), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable")
 
-        self.runtest(select([self.table, self.table2]), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, \
+        self.runtest(select([table, table2]), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, \
 myothertable.othername FROM mytable, myothertable")
         
     def testsubquery(self):
     
-        s = select([self.table], self.table.c.name == 'jack')    
+        s = select([table], table.c.name == 'jack')    
         self.runtest(
             select(
                 [s],
@@ -68,7 +56,7 @@ myothertable.othername FROM mytable, myothertable")
             ,
         "SELECT myid, name, description FROM (SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.name = :mytable_name) WHERE myid = :myid")
         
-        sq = Select([self.table])
+        sq = Select([table])
         self.runtest(
             sq.select(),
             "SELECT myid, name, description FROM (SELECT mytable.myid, mytable.name, mytable.description FROM mytable)"
@@ -76,7 +64,7 @@ myothertable.othername FROM mytable, myothertable")
         
         sq = subquery(
             'sq',
-            [self.table],
+            [table],
         )
 
         self.runtest(
@@ -87,8 +75,8 @@ myothertable.othername FROM mytable, myothertable")
         
         sq = subquery(
             'sq',
-            [self.table, self.table2],
-            and_(self.table.c.id ==7, self.table2.c.id==self.table.c.id),
+            [table, table2],
+            and_(table.c.id ==7, table2.c.id==table.c.id),
             use_labels = True
         )
         
@@ -115,15 +103,15 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") s
         
     def testand(self):
         self.runtest(
-            select(['*'], and_(self.table.c.id == 12, self.table.c.name=='asdf', self.table2.c.name == 'foo', "sysdate() = today()")), 
+            select(['*'], and_(table.c.id == 12, table.c.name=='asdf', table2.c.name == 'foo', "sysdate() = today()")), 
             "SELECT * FROM mytable, myothertable WHERE mytable.myid = :mytable_myid AND mytable.name = :mytable_name AND myothertable.othername = :myothertable_othername AND sysdate() = today()"
         )
 
     def testor(self):
         self.runtest(
-            select([self.table], and_(
-                self.table.c.id == 12,
-                or_(self.table2.c.name=='asdf', self.table2.c.name == 'foo', self.table2.c.id == 9),
+            select([table], and_(
+                table.c.id == 12,
+                or_(table2.c.name=='asdf', table2.c.name == 'foo', table2.c.id == 9),
                 "sysdate() = today()", 
             )),
             "SELECT mytable.myid, mytable.name, mytable.description FROM mytable, myothertable WHERE mytable.myid = :mytable_myid AND (myothertable.othername = :myothertable_othername OR myothertable.othername = :myothertable_othername_1 OR myothertable.otherid = :myothertable_otherid) AND sysdate() = today()"
@@ -132,25 +120,25 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") s
 
     def testmultiparam(self):
         self.runtest(
-            select(["*"], or_(self.table.c.id == 12, self.table.c.id=='asdf', self.table.c.id == 'foo')), 
+            select(["*"], or_(table.c.id == 12, table.c.id=='asdf', table.c.id == 'foo')), 
             "SELECT * FROM mytable WHERE mytable.myid = :mytable_myid OR mytable.myid = :mytable_myid_1 OR mytable.myid = :mytable_myid_2"
         )
 
     def testorderby(self):
         self.runtest(
-            self.table2.select(order_by = [self.table2.c.id, asc(self.table2.c.name)]),
+            table2.select(order_by = [table2.c.id, asc(table2.c.name)]),
             "SELECT myothertable.otherid, myothertable.othername FROM myothertable ORDER BY myothertable.otherid, myothertable.othername ASC"
         )
     def testalias(self):
         # test the alias for a table.  column names stay the same, table name "changes" to "foo".
         self.runtest(
-        select([alias(self.table, 'foo')])
+        select([alias(table, 'foo')])
         ,"SELECT foo.myid, foo.name, foo.description FROM mytable foo")
     
         # create a select for a join of two tables.  use_labels means the column names will have
         # labels tablename_columnname, which become the column keys accessible off the Selectable object.
         # also, only use one column from the second table and all columns from the first table.
-        q = select([self.table, self.table2.c.id], self.table.c.id == self.table2.c.id, use_labels = True)
+        q = select([table, table2.c.id], table.c.id == table2.c.id, use_labels = True)
         
         # make an alias of the "selectable".  column names stay the same (i.e. the labels), table name "changes" to "t2view".
         a = alias(q, 't2view')
@@ -177,11 +165,11 @@ WHERE mytable.myid = myothertable.otherid) t2view WHERE t2view.mytable_myid = :t
 
     def testliteralmix(self):
         self.runtest(select(
-            [self.table, self.table2.c.id, "sysdate()", "foo, bar, lala"],
+            [table, table2.c.id, "sysdate()", "foo, bar, lala"],
             and_(
                 "foo.id = foofoo(lala)",
                 "datetime(foo) = Today",
-                self.table.c.id == self.table2.c.id,
+                table.c.id == table2.c.id,
             )
         ), 
         "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, sysdate(), foo, bar, lala \
@@ -189,7 +177,7 @@ FROM mytable, myothertable WHERE foo.id = foofoo(lala) AND datetime(foo) = Today
 
     def testliteralsubquery(self):
         self.runtest(select(
-            [alias(self.table, 't'), "foo.f"],
+            [alias(table, 't'), "foo.f"],
             "foo.f = t.id",
             from_obj = ["(select f from bar where lala=heyhey) foo"]
         ), 
@@ -197,38 +185,38 @@ FROM mytable, myothertable WHERE foo.id = foofoo(lala) AND datetime(foo) = Today
 
     def testjoin(self):
         self.runtest(
-            join(self.table2, self.table, self.table.c.id == self.table2.c.id).select(),
+            join(table2, table, table.c.id == table2.c.id).select(),
             "SELECT myothertable.otherid, myothertable.othername, mytable.myid, mytable.name, mytable.description \
 FROM myothertable, mytable WHERE mytable.myid = myothertable.otherid"
         )
         
         self.runtest(
             select(
-                [self.table],
-                from_obj = [join(self.table, self.table2, self.table.c.id == self.table2.c.id)]
+                [table],
+                from_obj = [join(table, table2, table.c.id == table2.c.id)]
             ),
         "SELECT mytable.myid, mytable.name, mytable.description FROM mytable JOIN myothertable ON mytable.myid = myothertable.otherid")
         
         self.runtest(
             select(
-                [join(join(self.table, self.table2, self.table.c.id == self.table2.c.id), self.table3, self.table.c.id == self.table3.c.id)
+                [join(join(table, table2, table.c.id == table2.c.id), table3, table.c.id == table3.c.id)
             ]),
             "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable JOIN myothertable ON mytable.myid = myothertable.otherid JOIN thirdtable ON mytable.myid = thirdtable.userid"
         )
         
     def testmultijoin(self):
         self.runtest(
-                select([self.table, self.table2, self.table3],
-                from_obj = [outerjoin(join(self.table, self.table2, self.table.c.id == self.table2.c.id), self.table3, self.table.c.id==self.table3.c.id)]
+                select([table, table2, table3],
+                from_obj = [outerjoin(join(table, table2, table.c.id == table2.c.id), table3, table.c.id==table3.c.id)]
                 )
                 ,"SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable JOIN myothertable ON mytable.myid = myothertable.otherid LEFT OUTER JOIN thirdtable ON mytable.myid = thirdtable.userid"
             )
             
     def testunion(self):
             x = union(
-                  select([self.table], self.table.c.id == 5),
-                  select([self.table], self.table.c.id == 12),
-                  order_by = [self.table.c.id],
+                  select([table], table.c.id == 5),
+                  select([table], table.c.id == 12),
+                  order_by = [table.c.id],
             )
   
             self.runtest(x, "SELECT mytable.myid, mytable.name, mytable.description \
@@ -238,9 +226,9 @@ FROM mytable WHERE mytable.myid = :mytable_myid_1 ORDER BY mytable.myid")
   
             self.runtest(
                     union(
-                        select([self.table]),
-                        select([self.table2]),
-                        select([self.table3])
+                        select([table]),
+                        select([table2]),
+                        select([table3])
                     )
             ,
             "SELECT mytable.myid, mytable.name, mytable.description \
@@ -255,14 +243,14 @@ FROM myothertable UNION SELECT thirdtable.userid, thirdtable.otherstuff FROM thi
         # parameters.
         
         query = select(
-                [self.table, self.table2],
+                [table, table2],
                 and_(
-                    self.table.c.name == 'fred',
-                    self.table.c.id == 10,
-                    self.table2.c.name != 'jack',
+                    table.c.name == 'fred',
+                    table.c.id == 10,
+                    table2.c.name != 'jack',
                     "EXISTS (select yay from foo where boo = lar)"
                 ),
-                from_obj = [ outerjoin(self.table, self.table2, self.table.c.id == self.table2.c.id) ]
+                from_obj = [ outerjoin(table, table2, table.c.id == table2.c.id) ]
                 )
                 
         self.runtest(query, 
@@ -286,9 +274,9 @@ myothertable.othername != :myothertable_othername AND EXISTS (select yay from fo
     def testbindparam(self):
         #return
         self.runtest(select(
-                    [self.table, self.table2],
-                    and_(self.table.c.id == self.table2.c.id,
-                    self.table.c.name == bindparam('mytablename'),
+                    [table, table2],
+                    and_(table.c.id == table2.c.id,
+                    table.c.name == bindparam('mytablename'),
                     )
                 ),
                 "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername \
@@ -298,33 +286,45 @@ FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid AND mytable
 
     def testinsert(self):
         # generic insert, will create bind params for all columns
-        self.runtest(insert(self.table), "INSERT INTO mytable (myid, name, description) VALUES (:myid, :name, :description)")
+        self.runtest(insert(table), "INSERT INTO mytable (myid, name, description) VALUES (:myid, :name, :description)")
 
         # insert with user-supplied bind params for specific columns,
         # cols provided literally
         self.runtest(
-            insert(self.table, {self.table.c.id : bindparam('userid'), self.table.c.name : bindparam('username')}), 
+            insert(table, {table.c.id : bindparam('userid'), table.c.name : bindparam('username')}), 
             "INSERT INTO mytable (myid, name) VALUES (:userid, :username)")
         
         # insert with user-supplied bind params for specific columns, cols
         # provided as strings
         self.runtest(
-            insert(self.table, dict(id = 3, name = 'jack')), 
+            insert(table, dict(id = 3, name = 'jack')), 
             "INSERT INTO mytable (myid, name) VALUES (:myid, :name)"
         )
         
         # insert with a subselect provided 
         #self.runtest(
-         #   insert(self.table, select([self.table2])),
+         #   insert(table, select([table2])),
          #   ""
         #)
 
     def testupdate(self):
-        self.runtest(update(self.table, self.table.c.id == 7), "UPDATE mytable SET name=:name WHERE mytable.myid = :mytable_myid", params = {self.table.c.name:'fred'})
-        self.runtest(update(self.table, self.table.c.id == 7), "UPDATE mytable SET name=:name WHERE mytable.myid = :mytable_myid", params = {'name':'fred'})
-
+        self.runtest(update(table, table.c.id == 7), "UPDATE mytable SET name=:name WHERE mytable.myid = :mytable_myid", params = {table.c.name:'fred'})
+        self.runtest(update(table, table.c.id == 7), "UPDATE mytable SET name=:name WHERE mytable.myid = :mytable_myid", params = {'name':'fred'})
+        self.runtest(update(table, values = {table.c.name : table.c.id}), "UPDATE mytable SET name=mytable.myid")
+        self.runtest(update(table, table.c.id == 12, values = {table.c.name : table.c.id}), "UPDATE mytable SET name=mytable.myid, description=:description WHERE mytable.myid = :mytable_myid", params = {'description':'test'})
+
+    def testcorrelatedupdate(self):
+        # test against a straight text subquery
+        u = update(table, values = {table.c.name : TextClause("select name from mytable where id=mytable.id")})
+        self.runtest(u, "UPDATE mytable SET name=(select name from mytable where id=mytable.id)")
+        
+        # test against a regular constructed subquery
+        s = select([table2], table2.c.id == table.c.id)
+        u = update(table, table.c.name == 'jack', values = {table.c.name : s})
+        self.runtest(u, "UPDATE mytable SET name=(SELECT myothertable.otherid, myothertable.othername FROM myothertable WHERE myothertable.otherid = mytable.myid) WHERE mytable.name = :mytable_name")
+        
     def testdelete(self):
-        self.runtest(delete(self.table, self.table.c.id == 7), "DELETE FROM mytable WHERE mytable.myid = :mytable_myid")
+        self.runtest(delete(table, table.c.id == 7), "DELETE FROM mytable WHERE mytable.myid = :mytable_myid")
         
         
     def runtest(self, clause, result, engine = None, params = None):