]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
merged sql_rearrangement branch , refactors sql package to work standalone with
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 25 Feb 2006 07:12:50 +0000 (07:12 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 25 Feb 2006 07:12:50 +0000 (07:12 +0000)
clause elements including tables and columns, schema package deals with "physical"
representations

15 files changed:
CHANGES
lib/sqlalchemy/__init__.py
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/engine.py
lib/sqlalchemy/ext/proxy.py
lib/sqlalchemy/mapping/mapper.py
lib/sqlalchemy/mapping/objectstore.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql.py
lib/sqlalchemy/util.py
test/objectstore.py
test/select.py
test/testbase.py

diff --git a/CHANGES b/CHANGES
index 6b85598d9ec2c0a755d5f07a06da6ac3df09cb74..52b2846fffa229b9746c604080ca6d64ff899299 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -2,6 +2,9 @@
 - fix to Oracle "row_number over" clause with mulitple tables
 - mapper.get() was not selecting multiple-keyed objects if the mapper's table was a join,
 such as in an inheritance relationship, this is fixed.
+- overhaul to sql/schema packages so that the sql package can run all on its own,
+producing selects, inserts, etc. without any engine dependencies.  Table/Column
+are the "physical" subclasses of TableClause/ColumnClause.
 0.1.2
 - fixed a recursive call in schema that was somehow running 994 times then returning
 normally.  broke nothing, slowed down everything.  thanks to jpellerin for finding this.
index 0c8aa2fe018ba92e5054ff08d51927a045b42f1c..d38a557f97769de55dc6f499466f19c4f12d1275 100644 (file)
@@ -6,8 +6,8 @@
 
 from engine import *
 from types import *
+from sql import *
 from schema import *
 from exceptions import *
-from sql import *
 import mapping as mapperlib
 from mapping import *
index ac10d27f16858c45a501b21f59a981e01fbc1493..c25a55c7acf5938b28236c6e71689bc4c103cce2 100644 (file)
@@ -152,16 +152,11 @@ class ANSICompiler(sql.Compiled):
             # if we are within a visit to a Select, set up the "typemap"
             # for this column which is used to translate result set values
             self.typemap.setdefault(column.key.lower(), column.type)
-        if column.table.name is None:
+        if column.table is not None and column.table.name is None:
             self.strings[column] = column.name
         else:
             self.strings[column] = "%s.%s" % (column.table.name, column.name)
 
-    def visit_columnclause(self, column):
-        if column.table is not None and column.table.name is not None:
-            self.strings[column] = "%s.%s" % (column.table.name, column.text)
-        else:
-            self.strings[column] = column.text
 
     def visit_fromclause(self, fromclause):
         self.froms[fromclause] = fromclause.from_name
@@ -257,11 +252,13 @@ class ANSICompiler(sql.Compiled):
                         l = co.label(co._label)
                         l.accept_visitor(self)
                         inner_columns[co._label] = l
-                    elif select.issubquery and isinstance(co, Column):
+                    # TODO: figure this out, a ColumnClause with a select as a parent
+                    # is different from any other kind of parent
+                    elif select.issubquery and isinstance(co, sql.ColumnClause) and co.table is not None and not isinstance(co.table, sql.Select):
                         # SQLite doesnt like selecting from a subquery where the column
                         # names look like table.colname, so add a label synonomous with
                         # the column name
-                        l = co.label(co.key)
+                        l = co.label(co.text)
                         l.accept_visitor(self)
                         inner_columns[self.get_str(l.obj)] = l
                     else:
@@ -379,7 +376,7 @@ class ANSICompiler(sql.Compiled):
         contains a Sequence object."""
         pass
     
-    def visit_insert_column(selef, column):
+    def visit_insert_column(self, column):
         """called when visiting an Insert statement, for each column in the table
         that is a NULL insert into the table"""
         pass
@@ -395,8 +392,8 @@ class ANSICompiler(sql.Compiled):
                 self.visit_insert_sequence(c, seq)
         vis = DefaultVisitor()
         for c in insert_stmt.table.c:
-            if (self.parameters is None or self.parameters.get(c.key, None) is None):
-                c.accept_visitor(vis)
+            if (isinstance(c, schema.SchemaItem) and (self.parameters is None or self.parameters.get(c.key, None) is None)):
+                c.accept_schema_visitor(vis)
         
         self.isinsert = True
         colparams = self._get_colparams(insert_stmt)
@@ -419,7 +416,7 @@ class ANSICompiler(sql.Compiled):
                 return self.bindparam_string(p.key)
             else:
                 p.accept_visitor(self)
-                if isinstance(p, sql.ClauseElement):
+                if isinstance(p, sql.ClauseElement) and not isinstance(p, sql.ColumnClause):
                     return "(" + self.get_str(p) + ")"
                 else:
                     return self.get_str(p)
@@ -466,7 +463,7 @@ class ANSICompiler(sql.Compiled):
         # now go thru compiled params, get the Column object for each key
         d = {}
         for key, value in parameters.iteritems():
-            if isinstance(key, schema.Column):
+            if isinstance(key, sql.ColumnClause):
                 d[key] = value
             else:
                 try:
index 04bdc24fa4e46313ad4bae44b56feaf037107b19..d660db7bdc69aa09f65cd3eb7e3bacc85d34802e 100644 (file)
@@ -131,11 +131,6 @@ class MySQLEngine(ansisql.ANSISQLEngine):
     def supports_sane_rowcount(self):
         return False
 
-    def tableimpl(self, table, **kwargs):
-        """returns a new sql.TableImpl object to correspond to the given Table object."""
-        mysql_engine = kwargs.pop('mysql_engine', None)
-        return MySQLTableImpl(table, mysql_engine=mysql_engine)
-
     def compiler(self, statement, bindparams, **kwargs):
         return MySQLCompiler(self, statement, bindparams, **kwargs)
 
@@ -175,7 +170,7 @@ class MySQLEngine(ansisql.ANSISQLEngine):
         #ischema.reflecttable(self, table, ischema_names, use_mysql=True)
         
         tabletype, foreignkeyD = self.moretableinfo(table=table)
-        table._impl.mysql_engine = tabletype
+        table.kwargs['mysql_engine'] = tabletype
         
         c = self.execute("describe " + table.name, {})
         while True:
@@ -235,14 +230,6 @@ class MySQLEngine(ansisql.ANSISQLEngine):
         return (tabletype, foreignkeyD)
         
 
-class MySQLTableImpl(sql.TableImpl):
-    """attached to a schema.Table to provide it with a Selectable interface
-    as well as other functions
-    """
-    def __init__(self, table, mysql_engine=None):
-        super(MySQLTableImpl, self).__init__(table)
-        self.mysql_engine = mysql_engine
-
 class MySQLCompiler(ansisql.ANSICompiler):
 
     def visit_function(self, func):
@@ -277,12 +264,13 @@ class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator):
             if first_pk and isinstance(column.type, types.Integer):
                 colspec += " AUTO_INCREMENT"
         if column.foreign_key:
-            colspec += ", FOREIGN KEY (%s) REFERENCES %s(%s)" % (column.name, column.column.foreign_key.column.table.name, column.column.foreign_key.column.name) 
+            colspec += ", FOREIGN KEY (%s) REFERENCES %s(%s)" % (column.name, column.foreign_key.column.table.name, column.foreign_key.column.name) 
         return colspec
 
     def post_create_table(self, table):
-        if table.mysql_engine is not None:
-            return " ENGINE=%s" % table.mysql_engine
+        mysql_engine = table.kwargs.get('mysql_engine', None)
+        if mysql_engine is not None:
+            return " ENGINE=%s" % mysql_engine
         else:
             return ""
 
index 2115f5d568192fe607514ba8795b013ace9d2917..238310b1b1daa01ab910ecdf866e67f0e69fbba0 100644 (file)
@@ -312,7 +312,7 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
         if column.primary_key and not override_pk:
             colspec += " PRIMARY KEY"
         if column.foreign_key:
-            colspec += " REFERENCES %s(%s)" % (column.column.foreign_key.column.table.fullname, column.column.foreign_key.column.name) 
+            colspec += " REFERENCES %s(%s)" % (column.foreign_key.column.table.fullname, column.foreign_key.column.name) 
         return colspec
 
     def visit_sequence(self, sequence):
index 73b8769f257dffb247bd0fc141cf3558a994dad3..56488a197c8a172852eb19c22acf65345e8c6605 100644 (file)
@@ -16,8 +16,7 @@ A SQLEngine is provided to an application as a subclass that is specific to a pa
 of DBAPI, and is the central switching point for abstracting different kinds of database
 behavior into a consistent set of behaviors.  It provides a variety of factory methods 
 to produce everything specific to a certain kind of database, including a Compiler, 
-schema creation/dropping objects, and TableImpl and ColumnImpl objects to augment the
-behavior of table metadata objects.
+schema creation/dropping objects.
 
 The term "database-specific" will be used to describe any object or function that has behavior
 corresponding to a particular vendor, such as mysql-specific, sqlite-specific, etc.
@@ -131,7 +130,7 @@ class DefaultRunner(schema.SchemaVisitor):
 
     def get_column_default(self, column):
         if column.default is not None:
-            return column.default.accept_visitor(self)
+            return column.default.accept_schema_visitor(self)
         else:
             return None
 
@@ -296,11 +295,11 @@ class SQLEngine(schema.SchemaEngine):
         
     def create(self, entity, **params):
         """creates a table or index within this engine's database connection given a schema.Table object."""
-        entity.accept_visitor(self.schemagenerator(**params))
+        entity.accept_schema_visitor(self.schemagenerator(**params))
 
     def drop(self, entity, **params):
         """drops a table or index within this engine's database connection given a schema.Table object."""
-        entity.accept_visitor(self.schemadropper(**params))
+        entity.accept_schema_visitor(self.schemadropper(**params))
 
     def compile(self, statement, parameters, **kwargs):
         """given a sql.ClauseElement statement plus optional bind parameters, creates a new
@@ -315,28 +314,6 @@ class SQLEngine(schema.SchemaEngine):
         """given a Table object, reflects its columns and properties from the database."""
         raise NotImplementedError()
 
-    def tableimpl(self, table, **kwargs):
-        """returns a new sql.TableImpl object to correspond to the given Table object.
-        A TableImpl provides SQL statement builder operations on a Table metadata object, 
-        and a subclass of this object may be provided by a SQLEngine subclass to provide
-        database-specific behavior."""
-        return sql.TableImpl(table)
-
-    def columnimpl(self, column):
-        """returns a new sql.ColumnImpl object to correspond to the given Column object.
-        A ColumnImpl provides SQL statement builder operations on a Column metadata object, 
-        and a subclass of this object may be provided by a SQLEngine subclass to provide
-        database-specific behavior."""
-        return sql.ColumnImpl(column)
-
-    def indeximpl(self, index):
-        """returns a new sql.IndexImpl object to correspond to the given Index
-        object. An IndexImpl provides SQL statement builder operations on an
-        Index metadata object, and a subclass of this object may be provided
-        by a SQLEngine subclass to provide database-specific behavior.
-        """
-        return sql.IndexImpl(index)
-    
     def get_default_schema_name(self):
         """returns the currently selected schema in the current connection."""
         return None
index c1bdd9fa534fb5dbda7944f12a46e87c3787545e..4f1f9e4010c13fafbcec596960e66cb2b203ea8e 100644 (file)
@@ -13,7 +13,7 @@ class ProxyEngine(object):
     """
     SQLEngine proxy. Supports lazy and late initialization by
     delegating to a real engine (set with connect()), and using proxy
-    classes for TableImpl, ColumnImpl and TypeEngine.
+    classes for TypeEngine.
     """
 
     def __init__(self):
@@ -61,15 +61,6 @@ class ProxyEngine(object):
             return None
         return self.get_engine().oid_column_name()
     
-    def columnimpl(self, column):
-        """Proxy point: return a ProxyColumnImpl
-        """
-        return ProxyColumnImpl(self, column)
-
-    def tableimpl(self, table):
-        """Proxy point: return a ProxyTableImpl
-        """
-        return ProxyTableImpl(self, table)
         
     def type_descriptor(self, typeobj):
         """Proxy point: return a ProxyTypeEngine 
@@ -84,45 +75,6 @@ class ProxyEngine(object):
         raise AttributeError('No connection established in ProxyEngine: '
                              ' no access to %s' % attr)
 
-        
-class ProxyColumnImpl(sql.ColumnImpl):
-    """Proxy column; defers engine access to ProxyEngine
-    """
-    def __init__(self, engine, column):
-        sql.ColumnImpl.__init__(self, column)
-        self._engine = engine
-        self.impls = weakref.WeakKeyDictionary()
-    def _get_impl(self):
-        e = self._engine.engine
-        try:
-            return self.impls[e]
-        except KeyError:
-            impl = e.columnimpl(self.column)
-            self.impls[e] = impl
-    def __getattr__(self, key):
-        return getattr(self._get_impl(), key)
-    engine = property(lambda self: self._engine.engine)
-
-class ProxyTableImpl(sql.TableImpl):
-    """Proxy table; defers engine access to ProxyEngine
-    """
-    def __init__(self, engine, table):
-        sql.TableImpl.__init__(self, table)
-        self._engine = engine
-        self.impls = weakref.WeakKeyDictionary()
-    def _get_impl(self):
-        e = self._engine.engine
-        try:
-            return self.impls[e]
-        except KeyError:
-            impl = e.tableimpl(self.table)
-            self.impls[e] = impl
-            return impl
-    def __getattr__(self, key):
-        return getattr(self._get_impl(), key)
-
-    engine = property(lambda self: self._engine.engine)
-
 class ProxyType(object):
     """ProxyType base class; used by ProxyTypeEngine to construct proxying
     types    
index 61d9a3c2ef0d2e8b452d5bec9dd82d12ff7fb79a..33bec863e82ec74f09ecf2f8bb0e29b986018678 100644 (file)
@@ -262,7 +262,7 @@ class Mapper(object):
         """returns an instance of the object based on the given identifier, or None
         if not found.  The *ident argument is a 
         list of primary key columns in the order of the table def's primary key columns."""
-        key = objectstore.get_id_key(ident, self.class_, self.primarytable)
+        key = objectstore.get_id_key(ident, self.class_)
         #print "key: " + repr(key) + " ident: " + repr(ident)
         return self._get(key, ident)
         
@@ -284,7 +284,7 @@ class Mapper(object):
 
         
     def identity_key(self, *primary_key):
-        return objectstore.get_id_key(tuple(primary_key), self.class_, self.primarytable)
+        return objectstore.get_id_key(tuple(primary_key), self.class_)
     
     def instance_key(self, instance):
         return self.identity_key(*[self._getattrbycolumn(instance, column) for column in self.pks_by_table[self.table]])
@@ -683,7 +683,7 @@ class Mapper(object):
         return statement
         
     def _identity_key(self, row):
-        return objectstore.get_row_key(row, self.class_, self.identitytable, self.pks_by_table[self.table])
+        return objectstore.get_row_key(row, self.class_, self.pks_by_table[self.table])
 
     def _instance(self, row, imap, result = None, populate_existing = False):
         """pulls an object instance from the given row and appends it to the given result
index 311a6c54206968671e4f0f32ad741f3fe56e77b2..4f0dc4dafd30ae6b07d210ec8edafd2f146bde6b 100644 (file)
@@ -48,7 +48,7 @@ class Session(object):
             self.hash_key = hash_key
         _sessions[self.hash_key] = self
         
-    def get_id_key(ident, class_, table):
+    def get_id_key(ident, class_):
         """returns an identity-map key for use in storing/retrieving an item from the identity
         map, given a tuple of the object's primary key values.
 
@@ -62,10 +62,10 @@ class Session(object):
         selectable - a Selectable object which represents all the object's column-based fields.
         this Selectable may be synonymous with the table argument or can be a larger construct
         containing that table. return value: a tuple object which is used as an identity key. """
-        return (class_, table.hash_key(), tuple(ident))
+        return (class_, tuple(ident))
     get_id_key = staticmethod(get_id_key)
 
-    def get_row_key(row, class_, table, primary_key):
+    def get_row_key(row, class_, primary_key):
         """returns an identity-map key for use in storing/retrieving an item from the identity
         map, given a result set row.
 
@@ -80,7 +80,7 @@ class Session(object):
         this Selectable may be synonymous with the table argument or can be a larger construct
         containing that table. return value: a tuple object which is used as an identity key.
         """
-        return (class_, table.hash_key(), tuple([row[column] for column in primary_key]))
+        return (class_, tuple([row[column] for column in primary_key]))
     get_row_key = staticmethod(get_row_key)
 
     class SessionTrans(object):
@@ -181,7 +181,6 @@ class Session(object):
             return None
         key = getattr(instance, '_instance_key', None)
         mapper = object_mapper(instance)
-        key = (key[0], mapper.table.hash_key(), key[2])
         u = self.uow
         if key is not None:
             if u.identity_map.has_key(key):
@@ -194,11 +193,11 @@ class Session(object):
             u.register_new(instance)
         return instance
 
-def get_id_key(ident, class_, table):
-    return Session.get_id_key(ident, class_, table)
+def get_id_key(ident, class_):
+    return Session.get_id_key(ident, class_)
 
-def get_row_key(row, class_, table, primary_key):
-    return Session.get_row_key(row, class_, table, primary_key)
+def get_row_key(row, class_, primary_key):
+    return Session.get_row_key(row, class_, primary_key)
 
 def begin():
     """begins a new UnitOfWork transaction.  the next commit will affect only
index 8f5523138087be84c7f89b0ca825bf45dc9f2867..8e9a434825e3de5b9bf64919c75ef074b54a7441 100644 (file)
@@ -14,7 +14,7 @@ structure with its own clause-specific objects as well as the visitor interface,
 the schema package "plugs in" to the SQL package.
 
 """
-
+import sql
 from util import *
 from types import *
 from exceptions import *
@@ -29,30 +29,12 @@ class SchemaItem(object):
         for item in args:
             if item is not None:
                 item._set_parent(self)
-
-    def accept_visitor(self, visitor):
-        """all schema items implement an accept_visitor method that should call the appropriate
-        visit_XXXX method upon the given visitor object."""
-        raise NotImplementedError()
-
     def _set_parent(self, parent):
         """a child item attaches itself to its parent via this method."""
         raise NotImplementedError()
-
-    def hash_key(self):
-        """returns a string that identifies this SchemaItem uniquely"""
-        return "%s(%d)" % (self.__class__.__name__, id(self))
-
     def __repr__(self):
         return "%s()" % self.__class__.__name__
 
-    def __getattr__(self, key):
-        """proxies method calls to an underlying implementation object for methods not found
-        locally"""
-        if not self.__dict__.has_key('_impl'):
-            raise AttributeError(key)
-        return getattr(self._impl, key)
-
 def _get_table_key(engine, name, schema):
     if schema is not None and schema == engine.get_default_schema_name():
         schema = None
@@ -95,8 +77,10 @@ class TableSingleton(type):
             return table
 
         
-class Table(SchemaItem):
-    """represents a relational database table.  
+class Table(sql.TableClause, SchemaItem):
+    """represents a relational database table.  This subclasses sql.TableClause to provide
+    a table that is "wired" to an engine.  Whereas TableClause represents a table as its 
+    used in a SQL expression, Table represents a table as its created in the database.  
     
     Be sure to look at sqlalchemy.sql.TableImpl for additional methods defined on a Table."""
     __metaclass__ = TableSingleton
@@ -134,19 +118,15 @@ class Table(SchemaItem):
         the same table twice will result in an exception.
         
         """
-        self.name = name
-        self.columns = OrderedProperties()
-        self.c = self.columns
-        self.foreign_keys = []
-        self.primary_key = []
-        self.engine = engine
+        super(Table, self).__init__(name)
+        self._engine = engine
         self.schema = kwargs.pop('schema', None)
-        self._impl = self.engine.tableimpl(self, **kwargs)
         if self.schema is not None:
             self.fullname = "%s.%s" % (self.schema, self.name)
         else:
             self.fullname = self.name
-
+        self.kwargs = kwargs
+        
     def __repr__(self):
         return "Table(%s)" % string.join(
         [repr(self.name)] + [repr(self.engine)] +
@@ -160,44 +140,45 @@ class Table(SchemaItem):
         else:
             return self.schema + "." + self.name
         
-    def hash_key(self):
-        return "Table(%s)" % string.join(
-        [repr(self.name)] + [self.engine.hash_key()] +
-        ["%s=%s" % (k, repr(getattr(self, k))) for k in ['schema']], ','
-        )
-        
     def reload_values(self, *args):
         """clears out the columns and other properties of this Table, and reloads them from the 
         given argument list.  This is used with the "redefine" keyword argument sent to the
         metaclass constructor."""
-        self.columns = OrderedProperties()
-        self.c = self.columns
-        self.foreign_keys = []
-        self.primary_key = []
-        self._impl = self.engine.tableimpl(self)
+        self._clear()
+        
+        print "RELOAD VALUES", args
         self._init_items(*args)
 
     def append_item(self, item):
         """appends a Column item or other schema item to this Table."""
         self._init_items(item)
-        
+    
+    def append_column(self, column):
+        if not column.hidden:
+            self._columns[column.key] = column
+        if column.primary_key:
+            self.primary_key.append(column)
+        column.table = self
+        column.type = self.engine.type_descriptor(column.type)
+            
     def _set_parent(self, schema):
         schema.tables[self.name] = self
         self.schema = schema
-
-    def accept_visitor(self, visitor): 
+    def accept_schema_visitor(self, visitor): 
         """traverses the given visitor across the Column objects inside this Table,
         then calls the visit_table method on the visitor."""
         for c in self.columns:
-            c.accept_visitor(visitor)
+            c.accept_schema_visitor(visitor)
         return visitor.visit_table(self)
-    
     def deregister(self):
         """removes this table from it's engines table registry.  this does not
         issue a SQL DROP statement."""
         key = _get_table_key(self.engine, self.name, self.schema)
         del self.engine.tables[key]
-        
+    def create(self, **params):
+        self.engine.create(self)
+    def drop(self, **params):
+        self.engine.drop(self)
     def toengine(self, engine, schema=None):
         """returns a singleton instance of this Table with a different engine"""
         try:
@@ -211,8 +192,9 @@ class Table(SchemaItem):
                 args.append(c.copy())
             return Table(self.name, engine, schema=schema, *args)
 
-class Column(SchemaItem):
-    """represents a column in a database table."""
+class Column(sql.ColumnClause, SchemaItem):
+    """represents a column in a database table.  this is a subclass of sql.ColumnClause and
+    represents an actual existing table in the database, in a similar fashion as TableClause/Table."""
     def __init__(self, name, type, *args, **kwargs):
         """constructs a new Column object.  Arguments are:
         
@@ -244,24 +226,27 @@ class Column(SchemaItem):
         hidden=False : indicates this column should not be listed in the table's list of columns.  Used for the "oid" 
         column, which generally isnt in column lists.
         """
-        self.name = str(name) # in case of incoming unicode
-        self.type = type
+        name = str(name) # in case of incoming unicode
+        super(Column, self).__init__(name, None, type)
         self.args = args
         self.key = kwargs.pop('key', name)
-        self.primary_key = kwargs.pop('primary_key', False)
+        self._primary_key = kwargs.pop('primary_key', False)
         self.nullable = kwargs.pop('nullable', not self.primary_key)
         self.hidden = kwargs.pop('hidden', False)
         self.default = kwargs.pop('default', None)
-        self.foreign_key = None
+        self._foreign_key = None
         self._orig = None
         self._parent = None
         if len(kwargs):
             raise ArgumentError("Unknown arguments passed to Column: " + repr(kwargs.keys()))
-        
+
+    primary_key = AttrProp('_primary_key')
+    foreign_key = AttrProp('_foreign_key')
     original = property(lambda s: s._orig or s)
     parent = property(lambda s:s._parent or s)
     engine = property(lambda s: s.table.engine)
-     
+    columns = property(lambda self:[self])
+
     def __repr__(self):
        return "Column(%s)" % string.join(
         [repr(self.name)] + [repr(self.type)] +
@@ -282,16 +267,7 @@ class Column(SchemaItem):
     def _set_parent(self, table):
         if getattr(self, 'table', None) is not None:
             raise ArgumentError("this Column already has a table!")
-        if not self.hidden:
-            table.columns[self.key] = self
-            if self.primary_key:
-                table.primary_key.append(self)
-        self.table = table
-        if self.table.engine is not None:
-            self.type = self.table.engine.type_descriptor(self.type)
-            
-        self._impl = self.table.engine.columnimpl(self)
-
+        table.append_column(self)
         if self.default is not None:
             self.default = ColumnDefault(self.default)
             self._init_items(self.default)
@@ -320,35 +296,19 @@ class Column(SchemaItem):
             selectable.columns[c.key] = c
             if self.primary_key:
                 selectable.primary_key.append(c)
-        c._impl = self.engine.columnimpl(c)
         if fk is not None:
             c._init_items(fk)
         return c
 
-    def accept_visitor(self, visitor):
+    def accept_schema_visitor(self, visitor):
         """traverses the given visitor to this Column's default and foreign key object,
         then calls visit_column on the visitor."""
         if self.default is not None:
-            self.default.accept_visitor(visitor)
+            self.default.accept_schema_visitor(visitor)
         if self.foreign_key is not None:
-            self.foreign_key.accept_visitor(visitor)
+            self.foreign_key.accept_schema_visitor(visitor)
         visitor.visit_column(self)
 
-    def __lt__(self, other): return self._impl.__lt__(other)
-    def __le__(self, other): return self._impl.__le__(other)
-    def __eq__(self, other): return self._impl.__eq__(other)
-    def __ne__(self, other): return self._impl.__ne__(other)
-    def __gt__(self, other): return self._impl.__gt__(other)
-    def __ge__(self, other): return self._impl.__ge__(other)
-    def __add__(self, other): return self._impl.__add__(other)
-    def __sub__(self, other): return self._impl.__sub__(other)
-    def __mul__(self, other): return self._impl.__mul__(other)
-    def __and__(self, other): return self._impl.__and__(other)
-    def __or__(self, other): return self._impl.__or__(other)
-    def __div__(self, other): return self._impl.__div__(other)
-    def __truediv__(self, other): return self._impl.__truediv__(other)
-    def __invert__(self, other): return self._impl.__invert__(other)
-    def __str__(self): return self._impl.__str__()
 
 class ForeignKey(SchemaItem):
     """defines a ForeignKey constraint between two columns.  ForeignKey is 
@@ -374,7 +334,7 @@ class ForeignKey(SchemaItem):
         elif self._colspec.table.schema is not None:
             return "%s.%s.%s" % (self._colspec.table.schema, self._colspec.table.name, self._colspec.column.key)
         else:
-            return "%s.%s" % (self._colspec.table.name, self._colspec.column.key)
+            return "%s.%s" % (self._colspec.table.name, self._colspec.key)
         
     def references(self, table):
         """returns True if the given table is referenced by this ForeignKey."""
@@ -406,7 +366,7 @@ class ForeignKey(SchemaItem):
             
     column = property(lambda s: s._init_column())
 
-    def accept_visitor(self, visitor):
+    def accept_schema_visitor(self, visitor):
         """calls the visit_foreign_key method on the given visitor."""
         visitor.visit_foreign_key(self)
         
@@ -432,7 +392,7 @@ class PassiveDefault(DefaultGenerator):
     """a default that takes effect on the database side"""
     def __init__(self, arg):
         self.arg = arg
-    def accept_visitor(self, visitor):
+    def accept_schema_visitor(self, visitor):
         return visitor.visit_passive_default(self)
     def __repr__(self):
         return "PassiveDefault(%s)" % repr(self.arg)
@@ -442,7 +402,7 @@ class ColumnDefault(DefaultGenerator):
     a callable function, or a SQL clause."""
     def __init__(self, arg):
         self.arg = arg
-    def accept_visitor(self, visitor):
+    def accept_schema_visitor(self, visitor):
         """calls the visit_column_default method on the given visitor."""
         return visitor.visit_column_default(self)
     def __repr__(self):
@@ -461,7 +421,7 @@ class Sequence(DefaultGenerator):
              ["%s=%s" % (k, repr(getattr(self, k))) for k in ['start', 'increment', 'optional']]
             , ',')
     
-    def accept_visitor(self, visitor):
+    def accept_schema_visitor(self, visitor):
         """calls the visit_seauence method on the given visitor."""
         return visitor.visit_sequence(self)
 
@@ -486,6 +446,7 @@ class Index(SchemaItem):
         self.unique = kw.pop('unique', False)
         self._init_items()
 
+    engine = property(lambda s:s.table.engine)
     def _init_items(self):
         # make sure all columns are from the same table
         # FIXME: and no column is repeated
@@ -499,10 +460,13 @@ class Index(SchemaItem):
                                  "%s is from %s not %s" % (column,
                                                            column.table,
                                                            self.table))
-        # set my _impl from col.table.engine
-        self._impl = self.table.engine.indeximpl(self)
-        
-    def accept_visitor(self, visitor):
+    def create(self):
+       self.engine.create(self)
+    def drop(self):
+       self.engine.drop(self)
+    def execute(self):
+       self.create()
+    def accept_schema_visitor(self, visitor):
         visitor.visit_index(self)
     def __str__(self):
         return repr(self)
@@ -515,24 +479,13 @@ class Index(SchemaItem):
 class SchemaEngine(object):
     """a factory object used to create implementations for schema objects.  This object
     is the ultimate base class for the engine.SQLEngine class."""
-    def tableimpl(self, table):
-        """returns a new implementation object for a Table (usually sql.TableImpl)"""
-        raise NotImplementedError()
-    def columnimpl(self, column):
-        """returns a new implementation object for a Column (usually sql.ColumnImpl)"""
-        raise NotImplementedError()
-    def indeximpl(self, index):
-        """returns a new implementation object for an Index (usually
-        sql.IndexImpl)
-        """
-        raise NotImplementedError()
     def reflecttable(self, table):
         """given a table, will query the database and populate its Column and ForeignKey 
         objects."""
         raise NotImplementedError()
         
-class SchemaVisitor(object):
-    """base class for an object that traverses across Schema structures."""
+class SchemaVisitor(sql.ClauseVisitor):
+    """defines the visiting for SchemaItem objects"""
     def visit_schema(self, schema):
         """visit a generic SchemaItem"""
         pass
index cbd9a82f31463b4c8e01032ef97c1e5d1672aaf8..8ebf7624efebe370928f3580c707c54b6ac9b3a4 100644 (file)
@@ -13,7 +13,7 @@ from exceptions import *
 import string, re, random
 types = __import__('types')
 
-__all__ = ['text', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'union', 'union_all', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'literal', 'bindparam', 'exists']
+__all__ = ['text', 'table', 'column', 'func', 'select', 'update', 'insert', 'delete', 'join', 'and_', 'or_', 'not_', 'union', 'union_all', 'desc', 'asc', 'outerjoin', 'alias', 'subquery', 'literal', 'bindparam', 'exists']
 
 def desc(column):
     """returns a descending ORDER BY clause element, e.g.:
@@ -160,11 +160,15 @@ def label(name, obj):
     """returns a Label object for the given selectable, used in the column list for a select statement."""
     return Label(name, obj)
     
-def column(table, text):
-    """returns a textual column clause, relative to a table.  this differs from using straight text
-    or text() in that the column is treated like a regular column, i.e. gets added to a Selectable's list
-    of columns."""
-    return ColumnClause(text, table)
+def column(text, table=None, type=None):
+    """returns a textual column clause, relative to a table.  this is also the primitive version of
+    a schema.Column which is a subclass. """
+    return ColumnClause(text, table, type)
+
+def table(name, *columns):
+    """returns a table clause.  this is a primitive version of the schema.Table object, which is a subclass
+    of this object."""
+    return TableClause(name, *columns)
     
 def bindparam(key, value = None, type=None):
     """creates a bind parameter clause with the given key.  
@@ -172,7 +176,7 @@ def bindparam(key, value = None, type=None):
     An optional default value can be specified by the value parameter, and the optional type parameter
     is a sqlalchemy.types.TypeEngine object which indicates bind-parameter and result-set translation for
     this bind parameter."""
-    if isinstance(key, schema.Column):
+    if isinstance(key, ColumnClause):
         return BindParamClause(key.name, value, type=key.type)
     else:
         return BindParamClause(key, value, type=type)
@@ -190,7 +194,7 @@ def text(text, engine=None, *args, **kwargs):
     text - the text of the SQL statement to be created.  use :<param> to specify
     bind parameters; they will be compiled to their engine-specific format.
 
-    engine - the engine to be used for this text query.  Alternatively, call the
+    engine - an optional engine to be used for this text query.  Alternatively, call the
     text() method off the engine directly.
 
     bindparams - a list of bindparam() instances which can be used to define the
@@ -222,15 +226,15 @@ def _compound_select(keyword, *selects, **kwargs):
     return CompoundSelect(keyword, *selects, **kwargs)
 
 def _is_literal(element):
-    return not isinstance(element, ClauseElement) and not isinstance(element, schema.SchemaItem)
+    return not isinstance(element, ClauseElement)
 
 def is_column(col):
-    return isinstance(col, schema.Column) or isinstance(col, ColumnElement)
+    return isinstance(col, ColumnElement)
 
-class ClauseVisitor(schema.SchemaVisitor):
-    """builds upon SchemaVisitor to define the visiting of SQL statement elements in 
-    addition to Schema elements."""
-    def visit_columnclause(self, column):pass
+class ClauseVisitor(object):
+    """Defines the visiting of ClauseElements."""
+    def visit_column(self, column):pass
+    def visit_table(self, column):pass
     def visit_fromclause(self, fromclause):pass
     def visit_bindparam(self, bindparam):pass
     def visit_textclause(self, textclause):pass
@@ -309,18 +313,6 @@ class Compiled(ClauseVisitor):
         
 class ClauseElement(object):
     """base class for elements of a programmatically constructed SQL expression."""
-    def hash_key(self):
-        """returns a string that uniquely identifies the concept this ClauseElement
-        represents.
-
-        two ClauseElements can have the same value for hash_key() iff they both correspond to
-        the exact same generated SQL.  This allows the hash_key() values of a collection of
-        ClauseElements to be constructed into a larger identifying string for the purpose of
-        caching a SQL expression.
-
-        Note that since ClauseElements may be mutable, the hash_key() value is subject to
-        change if the underlying structure of the ClauseElement changes.""" 
-        raise NotImplementedError(repr(self))
     def _get_from_objects(self):
         """returns objects represented in this ClauseElement that should be added to the
         FROM list of a query."""
@@ -357,19 +349,24 @@ class ClauseElement(object):
         return False
 
     def _find_engine(self):
+        """default strategy for locating an engine within the clause element.
+        relies upon a local engine property, or looks in the "from" objects which 
+        ultimately have to contain Tables or TableClauses. """
         try:
             if self._engine is not None:
                 return self._engine
         except AttributeError:
             pass
         for f in self._get_from_objects():
+            if f is self:
+                continue
             engine = f.engine
             if engine is not None: 
                 return engine
         else:
             return None
             
-    engine = property(lambda s: s._find_engine())
+    engine = property(lambda s: s._find_engine(), doc="attempts to locate a SQLEngine within this ClauseElement structure, or returns None if none found.")
     
     def compile(self, engine = None, parameters = None, typemap=None):
         """compiles this SQL expression using its underlying SQLEngine to produce
@@ -380,16 +377,13 @@ class ClauseElement(object):
             engine = self.engine
 
         if engine is None:
-            raise InvalidRequestError("no SQLEngine could be located within this ClauseElement.")
+            import sqlalchemy.ansisql as ansisql
+            engine = ansisql.engine()
 
         return engine.compile(self, parameters=parameters, typemap=typemap)
 
     def __str__(self):
-        e = self.engine
-        if e is None:
-            import sqlalchemy.ansisql as ansisql
-            e = ansisql.engine()
-        return str(self.compile(e))
+        return str(self.compile())
         
     def execute(self, *multiparams, **params):
         """compiles and executes this SQL expression using its underlying SQLEngine. the
@@ -425,6 +419,7 @@ class ClauseElement(object):
         return not_(self)
 
 class CompareMixin(object):
+    """defines comparison operations for ClauseElements."""
     def __lt__(self, other):
         return self._compare('<', other)
     def __le__(self, other):
@@ -500,19 +495,15 @@ class Selectable(ClauseElement):
 
     def accept_visitor(self, visitor):
         raise NotImplementedError(repr(self))
-
     def is_selectable(self):
         return True
-
     def select(self, whereclauses = None, **params):
         return select([self], whereclauses, **params)
-
     def _group_parenthesized(self):
         """indicates if this Selectable requires parenthesis when grouped into a compound
         statement"""
         return True
 
-
 class ColumnElement(Selectable, CompareMixin):
     """represents a column element within the list of a Selectable's columns.  Provides 
     default implementations for the things a "column" needs, including a "primary_key" flag,
@@ -552,8 +543,6 @@ class FromClause(Selectable):
             return [self.oid_column]    
         else:
             return self.primary_key
-    def hash_key(self):
-        return "FromClause(%s, %s)" % (repr(self.id), repr(self.from_name))
     def accept_visitor(self, visitor): 
         visitor.visit_fromclause(self)
     def count(self, whereclause=None, **params):
@@ -627,8 +616,6 @@ class BindParamClause(ClauseElement, CompareMixin):
         visitor.visit_bindparam(self)
     def _get_from_objects(self):
         return []
-    def hash_key(self):
-        return "BindParam(%s, %s, %s)" % (repr(self.key), repr(self.value), repr(self.shortname))
     def typeprocess(self, value, engine):
         return self._get_convert_type(engine).convert_bind_param(value, engine)
     def compare(self, other):
@@ -674,8 +661,6 @@ class TextClause(ClauseElement):
         for item in self.bindparams.values():
             item.accept_visitor(visitor)
         visitor.visit_textclause(self)
-    def hash_key(self):
-        return "TextClause(%s)" % repr(self.text)
     def _get_from_objects(self):
         return []
 
@@ -686,8 +671,6 @@ class Null(ClauseElement):
         visitor.visit_null(self)
     def _get_from_objects(self):
         return []
-    def hash_key(self):
-        return "Null"
 
 class ClauseList(ClauseElement):
     """describes a list of clauses.  by default, is comma-separated, 
@@ -698,8 +681,6 @@ class ClauseList(ClauseElement):
             if c is None: continue
             self.append(c)
         self.parens = kwargs.get('parens', False)
-    def hash_key(self):
-        return string.join([c.hash_key() for c in self.clauses], ",")
     def copy_container(self):
         clauses = [clause.copy_container() for clause in self.clauses]
         return ClauseList(parens=self.parens, *clauses)
@@ -753,8 +734,6 @@ class CompoundClause(ClauseList):
         for c in self.clauses:
             f += c._get_from_objects()
         return f
-    def hash_key(self):
-        return string.join([c.hash_key() for c in self.clauses], self.operator or " ")
     def compare(self, other):
         """compares this CompoundClause to the given item.  
         
@@ -794,8 +773,6 @@ class Function(ClauseList, ColumnElement):
         return BindParamClause(self.name, obj, shortname=self.name, type=self.type)
     def select(self):
         return select([self])
-    def hash_key(self):
-        return self.name + "(" + string.join([c.hash_key() for c in self.clauses], ", ") + ")"
     def _compare_type(self, obj):
         return self.type
                 
@@ -811,8 +788,6 @@ class BinaryClause(ClauseElement):
         return BinaryClause(self.left.copy_container(), self.right.copy_container(), self.operator)
     def _get_from_objects(self):
         return self.left._get_from_objects() + self.right._get_from_objects()
-    def hash_key(self):
-        return self.left.hash_key() + (self.operator or " ") + self.right.hash_key()
     def accept_visitor(self, visitor):
         self.left.accept_visitor(visitor)
         self.right.accept_visitor(visitor)
@@ -879,16 +854,9 @@ class Join(FromClause):
             return and_(*crit)
             
     def _group_parenthesized(self):
-        """indicates if this Selectable requires parenthesis when grouped into a compound
-        statement"""
         return True
-
-    def hash_key(self):
-        return "Join(%s, %s, %s, %s)" % (repr(self.left.hash_key()), repr(self.right.hash_key()), repr(self.onclause.hash_key()), repr(self.isouter))
-
     def select(self, whereclauses = None, **params):
         return select([self.left, self.right], whereclauses, from_obj=[self], **params)
-
     def accept_visitor(self, visitor):
         self.left.accept_visitor(visitor)
         self.right.accept_visitor(visitor)
@@ -941,9 +909,6 @@ class Alias(FromClause):
     def _exportable_columns(self):
         return self.selectable.columns
 
-    def hash_key(self):
-        return "Alias(%s, %s)" % (self.selectable.hash_key(), repr(self.name))
-
     def accept_visitor(self, visitor):
         self.selectable.accept_visitor(visitor)
         visitor.visit_alias(self)
@@ -975,35 +940,27 @@ class Label(ColumnElement):
         return self.obj._get_from_objects()
     def _make_proxy(self, selectable, name = None):
         return self.obj._make_proxy(selectable, name=self.name)
-        
-    def hash_key(self):
-        return "Label(%s, %s)" % (self.name, self.obj.hash_key())
      
 class ColumnClause(ColumnElement):
-    """represents a textual column clause in a SQL statement. allows the creation
-    of an additional ad-hoc column that is compiled against a particular table."""
-
-    def __init__(self, text, selectable=None):
-        self.text = text
+    """represents a textual column clause in a SQL statement.  May or may not
+    be bound to an underlying Selectable."""
+    def __init__(self, text, selectable=None, type=None):
+        self.key = self.name = self.text = text
         self.table = selectable
-        self.type = sqltypes.NullTypeEngine()
-
-    name = property(lambda self:self.text)
-    key = property(lambda self:self.text)
-    _label = property(lambda self:self.text)
-    
-    def accept_visitor(self, visitor): 
-        visitor.visit_columnclause(self)
-
-    def hash_key(self):
+        self.type = type or sqltypes.NullTypeEngine()
+    def _get_label(self):
         if self.table is not None:
-            return "ColumnClause(%s, %s)" % (self.text, util.hash_key(self.table))
+            return self.table.name + "_" + self.text
         else:
-            return "ColumnClause(%s)" % self.text
-
+            return self.text
+    _label = property(_get_label)
+    def accept_visitor(self, visitor): 
+        visitor.visit_column(self)
     def _get_from_objects(self):
-        return []
-
+        if self.table is not None:
+            return [self.table]
+        else:
+            return []
     def _bind_param(self, obj):
         if self.table.name is None:
             return BindParamClause(self.text, obj, shortname=self.text, type=self.type)
@@ -1013,79 +970,35 @@ class ColumnClause(ColumnElement):
         c = ColumnClause(name or self.text, selectable)
         selectable.columns[c.key] = c
         return c
-
-class ColumnImpl(ColumnElement):
-    """gets attached to a schema.Column object."""
-    
-    def __init__(self, column):
-        self.column = column
-        self.name = column.name
-        
-        if column.table.name:
-            self._label = column.table.name + "_" + self.column.name
-        else:
-            self._label = self.column.name
-
-    engine = property(lambda s: s.column.engine)
-    default_label = property(lambda s:s._label)
-    original = property(lambda self:self.column.original)
-    parent = property(lambda self:self.column.parent)
-    columns = property(lambda self:[self.column])
-    
-    def label(self, name):
-        return Label(name, self.column)
-        
-    def copy_container(self):
-        return self.column
-
-    def compare(self, other):
-        """compares this ColumnImpl's column to the other given Column"""
-        return self.column is other
-        
+    def _compare_type(self, obj):
+        return self.type
     def _group_parenthesized(self):
         return False
-        
-    def _get_from_objects(self):
-        return [self.column.table]
-    
-    def _bind_param(self, obj):
-        if self.column.table.name is None:
-            return BindParamClause(self.name, obj, shortname = self.name, type = self.column.type)
-        else:
-            return BindParamClause(self.column.table.name + "_" + self.name, obj, shortname = self.name, type = self.column.type)
-    def _compare_self(self):
-        """allows ColumnImpl to return its Column object for usage in ClauseElements, all others to
-        just return self"""
-        return self.column
-    def _compare_type(self, obj):
-        return self.column.type
-        
-    def compile(self, engine = None, parameters = None, typemap=None):
-        if engine is None:
-            engine = self.engine
-        if engine is None:
-            raise InvalidRequestError("no SQLEngine could be located within this ClauseElement.")
-        return engine.compile(self.column, parameters=parameters, typemap=typemap)
 
-class TableImpl(FromClause):
-    """attached to a schema.Table to provide it with a Selectable interface
-    as well as other functions
-    """
-
-    def __init__(self, table):
-        self.table = table
-        self.id = self.table.name
+class TableClause(FromClause):
+    def __init__(self, name, *columns):
+        super(TableClause, self).__init__(name)
+        self.name = self.id = self.fullname = name
+        self._columns = util.OrderedProperties()
+        self._foreign_keys = []
+        self._primary_key = []
+        for c in columns:
+            self.append_column(c)
 
+    def append_column(self, c):
+        self._columns[c.text] = c
+        c.table = self
     def _oid_col(self):
+        if self.engine is None:
+            return None
         # OID remains a little hackish so far
         if not hasattr(self, '_oid_column'):
-            if self.table.engine.oid_column_name() is not None:
-                self._oid_column = schema.Column(self.table.engine.oid_column_name(), sqltypes.Integer, hidden=True)
-                self._oid_column._set_parent(self.table)
+            if self.engine.oid_column_name() is not None:
+                self._oid_column = schema.Column(self.engine.oid_column_name(), sqltypes.Integer, hidden=True)
+                self._oid_column._set_parent(self)
             else:
                 self._oid_column = None
         return self._oid_column
-
     def _orig_columns(self):
         try:
             return self._orig_cols
@@ -1097,47 +1010,52 @@ class TableImpl(FromClause):
             if oid is not None:
                 self._orig_cols[oid.original] = oid
             return self._orig_cols
-            
-    oid_column = property(_oid_col)
-    engine = property(lambda s: s.table.engine)
-    columns = property(lambda self: self.table.columns)
-    primary_key = property(lambda self:self.table.primary_key)
-    foreign_keys = property(lambda self:self.table.foreign_keys)
+    columns = property(lambda s:s._columns)
+    c = property(lambda s:s._columns)
+    primary_key = property(lambda s:s._primary_key)
+    foreign_keys = property(lambda s:s._foreign_keys)
     original_columns = property(_orig_columns)
+    oid_column = property(_oid_col)
+
+    def _clear(self):
+        """clears all attributes on this TableClause so that new items can be added again"""
+        self.columns.clear()
+        self.foreign_keys[:] = []
+        self.primary_key[:] = []
+        try:
+            delattr(self, '_orig_cols')
+        except AttributeError:
+            pass
 
+    def accept_visitor(self, visitor):
+        visitor.visit_table(self)
     def _exportable_columns(self):
         raise NotImplementedError()
-        
     def _group_parenthesized(self):
         return False
-
     def _process_from_dict(self, data, asfrom):
         for f in self._get_from_objects():
             data.setdefault(f.id, f)
         if asfrom:
-            data[self.id] = self.table
+            data[self.id] = self
     def count(self, whereclause=None, **params):
-        return select([func.count(1).label('count')], whereclause, from_obj=[self.table], **params)
+        return select([func.count(1).label('count')], whereclause, from_obj=[self], **params)
     def join(self, right, *args, **kwargs):
-        return Join(self.table, right, *args, **kwargs)
+        return Join(self, right, *args, **kwargs)
     def outerjoin(self, right, *args, **kwargs):
-        return Join(self.table, right, isouter = True, *args, **kwargs)
+        return Join(self, right, isouter = True, *args, **kwargs)
     def alias(self, name=None):
-        return Alias(self.table, name)
+        return Alias(self, name)
     def select(self, whereclause = None, **params):
-        return select([self.table], whereclause, **params)
+        return select([self], whereclause, **params)
     def insert(self, values = None):
-        return insert(self.table, values=values)
+        return insert(self, values=values)
     def update(self, whereclause = None, values = None):
-        return update(self.table, whereclause, values)
+        return update(self, whereclause, values)
     def delete(self, whereclause = None):
-        return delete(self.table, whereclause)
-    def create(self, **params):
-        self.table.engine.create(self.table)
-    def drop(self, **params):
-        self.table.engine.drop(self.table)
+        return delete(self, whereclause)
     def _get_from_objects(self):
-        return [self.table]
+        return [self]
 
 class SelectBaseMixin(object):
     """base class for Select and CompoundSelects"""
@@ -1191,11 +1109,6 @@ class CompoundSelect(SelectBaseMixin, FromClause):
         order_by = kwargs.get('order_by', None)
         if order_by:
             self.order_by(*order_by)
-    def hash_key(self):
-        return "CompoundSelect(%s)" % string.join(
-            [util.hash_key(s) for s in self.selects] + 
-            ["%s=%s" % (k, repr(getattr(self, k))) for k in ['use_labels', 'keyword']],
-            ",")
     def _exportable_columns(self):
         return self.selects[0].columns
     def _proxy_column(self, column):
@@ -1271,6 +1184,8 @@ class Select(SelectBaseMixin, FromClause):
             self.is_where = is_where
         def visit_compound_select(self, cs):
             self.visit_select(cs)
+        def visit_column(self, c):pass
+        def visit_table(self, c):pass
         def visit_select(self, select):
             if select is self.select:
                 return
@@ -1288,7 +1203,6 @@ class Select(SelectBaseMixin, FromClause):
         for f in column._get_from_objects():
             f.accept_visitor(self._correlator)
         column._process_from_dict(self._froms, False)
-        
     def _exportable_columns(self):
         return self._raw_columns
     def _proxy_column(self, column):
@@ -1313,24 +1227,6 @@ class Select(SelectBaseMixin, FromClause):
 
     _hash_recursion = util.RecursionStack()
     
-    def hash_key(self):
-        # selects call alot of stuff so we do some "recursion checking"
-        # to eliminate loops
-        if Select._hash_recursion.push(self):
-            return "recursive_select()"
-        try:
-            return "Select(%s)" % string.join(
-                [
-                    "columns=" + string.join([util.hash_key(c) for c in self._raw_columns],','),
-                    "where=" + util.hash_key(self.whereclause),
-                    "from=" + string.join([util.hash_key(f) for f in self.froms],','),
-                    "having=" + util.hash_key(self.having),
-                    "clauses=" + string.join([util.hash_key(c) for c in self.clauses], ',')
-                ] + ["%s=%s" % (k, repr(getattr(self, k))) for k in ['use_labels', 'distinct', 'limit', 'offset']], ","
-            ) 
-        finally:
-            Select._hash_recursion.pop(self)
-        
     def clear_from(self, id):
         self.append_from(FromClause(from_name = None, from_key = id))
         
@@ -1342,7 +1238,7 @@ class Select(SelectBaseMixin, FromClause):
         fromclause._process_from_dict(self._froms, True)
 
     def _get_froms(self):
-        return [f for f in self._froms.values() if self._correlated is None or not self._correlated.has_key(f.id)]
+        return [f for f in self._froms.values() if f is not self and (self._correlated is None or not self._correlated.has_key(f.id))]
     froms = property(lambda s: s._get_froms())
 
     def accept_visitor(self, visitor):
@@ -1388,9 +1284,6 @@ class Select(SelectBaseMixin, FromClause):
 class UpdateBase(ClauseElement):
     """forms the base for INSERT, UPDATE, and DELETE statements."""
     
-    def hash_key(self):
-        return str(id(self))
-        
     def _process_colparams(self, parameters):
         """receives the "values" of an INSERT or UPDATE statement and constructs
         appropriate ind parameters."""
@@ -1419,6 +1312,9 @@ class UpdateBase(ClauseElement):
                 except KeyError:
                     del parameters[key]
         return parameters
+
+    def _find_engine(self):
+        return self._engine
         
 
 class Insert(UpdateBase):
@@ -1457,25 +1353,3 @@ class Delete(UpdateBase):
             self.whereclause.accept_visitor(visitor)
         visitor.visit_delete(self)
 
-class IndexImpl(ClauseElement):
-
-    def __init__(self, index):
-        self.index = index
-        self.name = index.name
-        self._engine = self.index.table.engine
-
-    table = property(lambda s: s.index.table)
-    columns = property(lambda s: s.index.columns)
-        
-    def hash_key(self):
-        return self.index.hash_key()
-    def accept_visitor(self, visitor):
-        visitor.visit_index(self.index)
-    def compare(self, other):
-        return self.index is other
-    def create(self):
-        self._engine.create(self.index)
-    def drop(self):
-        self._engine.drop(self.index)
-    def execute(self):
-        self.create()
index 301db0ec4da3c2a82dab4c6aa8e651df0314b3ea..fccb2f3bd3e3a0af44e340e65c48e1113f35fc3d 100644 (file)
@@ -4,7 +4,7 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-__all__ = ['OrderedProperties', 'OrderedDict', 'generic_repr', 'HashSet']
+__all__ = ['OrderedProperties', 'OrderedDict', 'generic_repr', 'HashSet', 'AttrProp']
 import thread, weakref, UserList,string, inspect
 from exceptions import *
 
@@ -23,7 +23,21 @@ def to_set(x):
         return HashSet(to_list(x))
     else:
         return x
-        
+
+class AttrProp(object):
+    """a quick way to stick a property accessor on an object"""
+    def __init__(self, key):
+        self.key = key
+    def __set__(self, obj, value):
+        setattr(obj, self.key, value)
+    def __delete__(self, obj):
+        delattr(obj, self.key)
+    def __get__(self, obj, owner):
+        if obj is None:
+            return self
+        else:
+            return getattr(obj, self.key)
+    
 def generic_repr(obj, exclude=None):
     L = ['%s=%s' % (a, repr(getattr(obj, a))) for a in dir(obj) if not callable(getattr(obj, a)) and not a.startswith('_') and (exclude is None or not exclude.has_key(a))]
     return '%s(%s)' % (obj.__class__.__name__, ','.join(L))
@@ -65,9 +79,10 @@ class OrderedProperties(object):
     def __setattr__(self, key, object):
         if not hasattr(self, key):
             self._list.append(key)
-    
         self.__dict__[key] = object
-    
+    def clear(self):
+        for key in self._list[:]:
+            del self[key]
 class RecursionStack(object):
     """a thread-local stack used to detect recursive object traversals."""
     def __init__(self):
index 687c9b10283effbb1b75f980a656e14b0368e3da..63a39641e829b96a11f767d934795fb764517a0b 100644 (file)
@@ -961,6 +961,9 @@ class SaveTest2(AssertMixin):
             Column('email_address', String(20)),
             redefine=True
         )
+        x = sql.Join(self.users, self.addresses)
+#        raise repr(self.users) + repr(self.users.primary_key)
+#        raise repr(self.addresses) + repr(self.addresses.foreign_keys)
         self.users.create()
         self.addresses.create()
         db.echo = testbase.echo
index 625a1ec7cf459405c179514524482c3411f37304..788e39f7b5e07a62c04c207678d26484e7f69c80 100644 (file)
@@ -10,23 +10,26 @@ db = ansisql.engine()
 from testbase import PersistTest
 import unittest, re
 
-
-table = Table('mytable', db,
-    Column('myid', Integer, key = 'id'),
-    Column('name', String, key = 'name'),
-    Column('description', String, key = 'description'),
+# the select test now tests almost completely with TableClause/ColumnClause objects,
+# which are free-roaming table/column objects not attached to any database.  
+# so SQLAlchemy's SQL construction engine can be used with no database dependencies at all.
+
+table1 = table('mytable', 
+    column('myid'),
+    column('name'),
+    column('description'),
 )
 
-table2 = Table(
-    'myothertable', db,
-    Column('otherid', Integer, key='id'),
-    Column('othername', String, key='name'),
+table2 = table(
+    'myothertable', 
+    column('otherid'),
+    column('othername'),
 )
 
-table3 = Table(
-    'thirdtable', db,
-    Column('userid', Integer, key='id'),
-    Column('otherstuff', Integer),
+table3 = table(
+    'thirdtable', 
+    column('userid'),
+    column('otherstuff'),
 )
 
 table4 = Table(
@@ -37,27 +40,27 @@ table4 = Table(
     schema = 'remote_owner'
 )
 
-users = Table('users', db,
-    Column('user_id', Integer, primary_key = True),
-    Column('user_name', String(40)),
-    Column('password', String(10)),
+users = table('users', 
+    column('user_id'),
+    column('user_name'),
+    column('password'),
 )
 
-addresses = Table('addresses', db,
-    Column('address_id', Integer, primary_key = True),
-    Column('user_id', Integer, ForeignKey("users.user_id")),
-    Column('street', String(100)),
-    Column('city', String(80)),
-    Column('state', String(2)),
-    Column('zip', String(10))
+addresses = table('addresses', 
+    column('address_id'),
+    column('user_id'),
+    column('street'),
+    column('city'),
+    column('state'),
+    column('zip')
 )
 
-
 class SQLTest(PersistTest):
     def runtest(self, clause, result, engine = None, params = None, checkparams = None):
+        if engine is None:
+            engine = db
         c = clause.compile(engine, params)
         self.echo("\nSQL String:\n" + str(c) + repr(c.get_params()))
-        self.echo("\nHash Key:\n" + clause.hash_key())
         cc = re.sub(r'\n', '', str(c))
         self.assert_(cc == result, str(c) + "\n does not match \n" + result)
         if checkparams is not None:
@@ -67,53 +70,44 @@ class SQLTest(PersistTest):
                 self.assert_(c.get_params() == checkparams, "params dont match")
             
 class SelectTest(SQLTest):
-
-
     def testtableselect(self):
-        self.runtest(table.select(), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable")
+        self.runtest(table1.select(), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable")
 
-        self.runtest(select([table, table2]), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, \
+        self.runtest(select([table1, table2]), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, \
 myothertable.othername FROM mytable, myothertable")
 
     def testsubquery(self):
-
-        # TODO: a subquery in a column clause.
-        #self.runtest(
-        #    select([table, select([table2.c.id])]),
-        #    """"""
-        #)
-
-        s = select([table], table.c.name == 'jack')
+        s = select([table1], table1.c.name == 'jack')
         print [key for key in s.c.keys()]
         self.runtest(
             select(
                 [s],
-                s.c.id == 7
+                s.c.myid == 7
             )
             ,
-        "SELECT id, name, description FROM (SELECT mytable.myid AS id, mytable.name AS name, mytable.description AS description FROM mytable WHERE mytable.name = :mytable_name) WHERE id = :id")
+        "SELECT myid, name, description FROM (SELECT mytable.myid AS myid, mytable.name AS name, mytable.description AS description FROM mytable WHERE mytable.name = :mytable_name) WHERE myid = :myid")
         
-        sq = select([table])
+        sq = select([table1])
         self.runtest(
             sq.select(),
-            "SELECT id, name, description FROM (SELECT mytable.myid AS id, mytable.name AS name, mytable.description AS description FROM mytable)"
+            "SELECT myid, name, description FROM (SELECT mytable.myid AS myid, mytable.name AS name, mytable.description AS description FROM mytable)"
         )
         
         sq = subquery(
             'sq',
-            [table],
+            [table1],
         )
 
         self.runtest(
-            sq.select(sq.c.id == 7), 
-            "SELECT sq.id, sq.name, sq.description FROM \
-(SELECT mytable.myid AS id, mytable.name AS name, mytable.description AS description FROM mytable) AS sq WHERE sq.id = :sq_id"
+            sq.select(sq.c.myid == 7), 
+            "SELECT sq.myid, sq.name, sq.description FROM \
+(SELECT mytable.myid AS myid, mytable.name AS name, mytable.description AS description FROM mytable) AS sq WHERE sq.myid = :sq_myid"
         )
         
         sq = subquery(
             'sq',
-            [table, table2],
-            and_(table.c.id ==7, table2.c.id==table.c.id),
+            [table1, table2],
+            and_(table1.c.myid ==7, table2.c.otherid==table1.c.myid),
             use_labels = True
         )
         
@@ -140,15 +134,15 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
         
     def testand(self):
         self.runtest(
-            select(['*'], and_(table.c.id == 12, table.c.name=='asdf', table2.c.name == 'foo', "sysdate() = today()")), 
+            select(['*'], and_(table1.c.myid == 12, table1.c.name=='asdf', table2.c.othername == 'foo', "sysdate() = today()")), 
             "SELECT * FROM mytable, myothertable WHERE mytable.myid = :mytable_myid AND mytable.name = :mytable_name AND myothertable.othername = :myothertable_othername AND sysdate() = today()"
         )
 
     def testor(self):
         self.runtest(
-            select([table], and_(
-                table.c.id == 12,
-                or_(table2.c.name=='asdf', table2.c.name == 'foo', table2.c.id == 9),
+            select([table1], and_(
+                table1.c.myid == 12,
+                or_(table2.c.othername=='asdf', table2.c.othername == 'foo', table2.c.otherid == 9),
                 "sysdate() = today()", 
             )),
             "SELECT mytable.myid, mytable.name, mytable.description FROM mytable, myothertable WHERE mytable.myid = :mytable_myid AND (myothertable.othername = :myothertable_othername OR myothertable.othername = :myothertable_othername_1 OR myothertable.otherid = :myothertable_otherid) AND sysdate() = today()",
@@ -157,7 +151,7 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
 
     def testoperators(self):
         self.runtest(
-            table.select((table.c.id != 12) & ~(table.c.name=='john')), 
+            table1.select((table1.c.myid != 12) & ~(table1.c.name=='john')), 
             "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid != :mytable_myid AND NOT (mytable.name = :mytable_name)"
         )
         
@@ -167,35 +161,35 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
 
     def testmultiparam(self):
         self.runtest(
-            select(["*"], or_(table.c.id == 12, table.c.id=='asdf', table.c.id == 'foo')), 
+            select(["*"], or_(table1.c.myid == 12, table1.c.myid=='asdf', table1.c.myid == 'foo')), 
             "SELECT * FROM mytable WHERE mytable.myid = :mytable_myid OR mytable.myid = :mytable_myid_1 OR mytable.myid = :mytable_myid_2"
         )
 
     def testorderby(self):
         self.runtest(
-            table2.select(order_by = [table2.c.id, asc(table2.c.name)]),
+            table2.select(order_by = [table2.c.otherid, asc(table2.c.othername)]),
             "SELECT myothertable.otherid, myothertable.othername FROM myothertable ORDER BY myothertable.otherid, myothertable.othername ASC"
         )
     def testgroupby(self):
         self.runtest(
-            select([table2.c.name, func.count(table2.c.id)], group_by = [table2.c.name]),
+            select([table2.c.othername, func.count(table2.c.otherid)], group_by = [table2.c.othername]),
             "SELECT myothertable.othername, count(myothertable.otherid) FROM myothertable GROUP BY myothertable.othername"
         )
     def testgroupby_and_orderby(self):
         self.runtest(
-            select([table2.c.name, func.count(table2.c.id)], group_by = [table2.c.name], order_by = [table2.c.name]),
+            select([table2.c.othername, func.count(table2.c.otherid)], group_by = [table2.c.othername], order_by = [table2.c.othername]),
             "SELECT myothertable.othername, count(myothertable.otherid) FROM myothertable GROUP BY myothertable.othername ORDER BY myothertable.othername"
         )
     def testalias(self):
-        # test the alias for a table.  column names stay the same, table name "changes" to "foo".
+        # test the alias for a table1.  column names stay the same, table name "changes" to "foo".
         self.runtest(
-        select([alias(table, 'foo')])
+        select([alias(table1, 'foo')])
         ,"SELECT foo.myid, foo.name, foo.description FROM mytable AS foo")
     
         # create a select for a join of two tables.  use_labels means the column names will have
         # labels tablename_columnname, which become the column keys accessible off the Selectable object.
-        # also, only use one column from the second table and all columns from the first table.
-        q = select([table, table2.c.id], table.c.id == table2.c.id, use_labels = True)
+        # also, only use one column from the second table and all columns from the first table1.
+        q = select([table1, table2.c.otherid], table1.c.myid == table2.c.otherid, use_labels = True)
         
         # make an alias of the "selectable".  column names stay the same (i.e. the labels), table name "changes" to "t2view".
         a = alias(q, 't2view')
@@ -265,11 +259,11 @@ WHERE mytable.myid = myothertable.otherid) AS t2view WHERE t2view.mytable_myid =
         
     def testtextmix(self):
         self.runtest(select(
-            [table, table2.c.id, "sysdate()", "foo, bar, lala"],
+            [table1, table2.c.otherid, "sysdate()", "foo, bar, lala"],
             and_(
                 "foo.id = foofoo(lala)",
                 "datetime(foo) = Today",
-                table.c.id == table2.c.id,
+                table1.c.myid == table2.c.otherid,
             )
         ), 
         "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, sysdate(), foo, bar, lala \
@@ -277,68 +271,68 @@ FROM mytable, myothertable WHERE foo.id = foofoo(lala) AND datetime(foo) = Today
 
     def testtextualsubquery(self):
         self.runtest(select(
-            [alias(table, 't'), "foo.f"],
+            [alias(table1, 't'), "foo.f"],
             "foo.f = t.id",
             from_obj = ["(select f from bar where lala=heyhey) foo"]
         ), 
         "SELECT t.myid, t.name, t.description, foo.f FROM mytable AS t, (select f from bar where lala=heyhey) foo WHERE foo.f = t.id")
 
     def testliteral(self):
-        self.runtest(select([literal("foo") + literal("bar")], from_obj=[table]), 
+        self.runtest(select([literal("foo") + literal("bar")], from_obj=[table1]), 
             "SELECT :literal + :literal_1 FROM mytable", engine=db)
 
     def testfunction(self):
-        self.runtest(func.lala(3, 4, literal("five"), table.c.id) * table2.c.id, 
+        self.runtest(func.lala(3, 4, literal("five"), table1.c.myid) * table2.c.otherid, 
             "lala(:lala, :lala_1, :literal, mytable.myid) * myothertable.otherid", engine=db)
 
     def testjoin(self):
         self.runtest(
-            join(table2, table, table.c.id == table2.c.id).select(),
+            join(table2, table1, table1.c.myid == table2.c.otherid).select(),
             "SELECT myothertable.otherid, myothertable.othername, mytable.myid, mytable.name, \
 mytable.description FROM myothertable JOIN mytable ON mytable.myid = myothertable.otherid"
         )
 
         self.runtest(
             select(
-             [table],
-                from_obj = [join(table, table2, table.c.id == table2.c.id)]
+             [table1],
+                from_obj = [join(table1, table2, table1.c.myid == table2.c.otherid)]
             ),
         "SELECT mytable.myid, mytable.name, mytable.description FROM mytable JOIN myothertable ON mytable.myid = myothertable.otherid")
 
         self.runtest(
             select(
-                [join(join(table, table2, table.c.id == table2.c.id), table3, table.c.id == table3.c.id)
+                [join(join(table1, table2, table1.c.myid == table2.c.otherid), table3, table1.c.myid == table3.c.userid)
             ]),
             "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable JOIN myothertable ON mytable.myid = myothertable.otherid JOIN thirdtable ON mytable.myid = thirdtable.userid"
         )
         
         self.runtest(
-            join(users, addresses).select(),
+            join(users, addresses, users.c.user_id==addresses.c.user_id).select(),
             "SELECT users.user_id, users.user_name, users.password, addresses.address_id, addresses.user_id, addresses.street, addresses.city, addresses.state, addresses.zip FROM users JOIN addresses ON users.user_id = addresses.user_id"
         )
         
     def testmultijoin(self):
         self.runtest(
-                select([table, table2, table3],
+                select([table1, table2, table3],
                 
-                from_obj = [join(table, table2, table.c.id == table2.c.id).outerjoin(table3, table.c.id==table3.c.id)]
+                from_obj = [join(table1, table2, table1.c.myid == table2.c.otherid).outerjoin(table3, table1.c.myid==table3.c.userid)]
                 
-                #from_obj = [outerjoin(join(table, table2, table.c.id == table2.c.id), table3, table.c.id==table3.c.id)]
+                #from_obj = [outerjoin(join(table, table2, table1.c.myid == table2.c.otherid), table3, table1.c.myid==table3.c.userid)]
                 )
                 ,"SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable JOIN myothertable ON mytable.myid = myothertable.otherid LEFT OUTER JOIN thirdtable ON mytable.myid = thirdtable.userid"
             )
         self.runtest(
-                select([table, table2, table3],
-                from_obj = [outerjoin(table, join(table2, table3, table2.c.id == table3.c.id), table.c.id==table2.c.id)]
+                select([table1, table2, table3],
+                from_obj = [outerjoin(table1, join(table2, table3, table2.c.otherid == table3.c.userid), table1.c.myid==table2.c.otherid)]
                 )
                 ,"SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable LEFT OUTER JOIN (myothertable JOIN thirdtable ON myothertable.otherid = thirdtable.userid) ON mytable.myid = myothertable.otherid"
             )
             
     def testunion(self):
             x = union(
-                  select([table], table.c.id == 5),
-                  select([table], table.c.id == 12),
-                  order_by = [table.c.id],
+                  select([table1], table1.c.myid == 5),
+                  select([table1], table1.c.myid == 12),
+                  order_by = [table1.c.myid],
             )
   
             self.runtest(x, "SELECT mytable.myid, mytable.name, mytable.description \
@@ -348,7 +342,7 @@ FROM mytable WHERE mytable.myid = :mytable_myid_1 ORDER BY mytable.myid")
   
             self.runtest(
                     union(
-                        select([table]),
+                        select([table1]),
                         select([table2]),
                         select([table3])
                     )
@@ -365,14 +359,14 @@ FROM myothertable UNION SELECT thirdtable.userid, thirdtable.otherstuff FROM thi
         # parameters.
         
         query = select(
-                [table, table2],
+                [table1, table2],
                 and_(
-                    table.c.name == 'fred',
-                    table.c.id == 10,
-                    table2.c.name != 'jack',
+                    table1.c.name == 'fred',
+                    table1.c.myid == 10,
+                    table2.c.othername != 'jack',
                     "EXISTS (select yay from foo where boo = lar)"
                 ),
-                from_obj = [ outerjoin(table, table2, table.c.id == table2.c.id) ]
+                from_obj = [ outerjoin(table1, table2, table1.c.myid == table2.c.otherid) ]
                 )
                 
         self.runtest(query, 
@@ -393,9 +387,9 @@ myothertable.othername != :myothertable_othername AND EXISTS (select yay from fo
 
     def testbindparam(self):
         self.runtest(select(
-                    [table, table2],
-                    and_(table.c.id == table2.c.id,
-                    table.c.name == bindparam('mytablename'),
+                    [table1, table2],
+                    and_(table1.c.myid == table2.c.otherid,
+                    table1.c.name == bindparam('mytablename'),
                     )
                 ),
                 "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername \
@@ -404,26 +398,26 @@ FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid AND mytable
 
         # check that the bind params sent along with a compile() call
         # get preserved when the params are retreived later
-        s = select([table], table.c.id == bindparam('test'))
-        c = s.compile(parameters = {'test' : 7})
+        s = select([table1], table1.c.myid == bindparam('test'))
+        c = s.compile(parameters = {'test' : 7}, engine=db)
         self.assert_(c.get_params() == {'test' : 7})
 
     def testcorrelatedsubquery(self):
         self.runtest(
-            table.select(table.c.id == select([table2.c.id], table.c.name == table2.c.name)),
-            "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = (SELECT myothertable.otherid AS id FROM myothertable WHERE mytable.name = myothertable.othername)"
+            table1.select(table1.c.myid == select([table2.c.otherid], table1.c.name == table2.c.othername)),
+            "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = (SELECT myothertable.otherid AS otherid FROM myothertable WHERE mytable.name = myothertable.othername)"
         )
 
         self.runtest(
-            table.select(exists([1], table2.c.id == table.c.id)),
+            table1.select(exists([1], table2.c.otherid == table1.c.myid)),
             "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE EXISTS (SELECT 1 FROM myothertable WHERE myothertable.otherid = mytable.myid)"
         )
 
-        talias = table.alias('ta')
-        s = subquery('sq2', [talias], exists([1], table2.c.id == talias.c.id))
+        talias = table1.alias('ta')
+        s = subquery('sq2', [talias], exists([1], table2.c.otherid == talias.c.myid))
         self.runtest(
-            select([s, table])
-            ,"SELECT sq2.id, sq2.name, sq2.description, mytable.myid, mytable.name, mytable.description FROM (SELECT ta.myid AS id, ta.name AS name, ta.description AS description FROM mytable AS ta WHERE EXISTS (SELECT 1 FROM myothertable WHERE myothertable.otherid = ta.myid)) AS sq2, mytable")
+            select([s, table1])
+            ,"SELECT sq2.myid, sq2.name, sq2.description, mytable.myid, mytable.name, mytable.description FROM (SELECT ta.myid AS myid, ta.name AS name, ta.description AS description FROM mytable AS ta WHERE EXISTS (SELECT 1 FROM myothertable WHERE myothertable.otherid = ta.myid)) AS sq2, mytable")
 
         s = select([addresses.c.street], addresses.c.user_id==users.c.user_id).alias('s')
         self.runtest(
@@ -431,81 +425,80 @@ FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid AND mytable
             """SELECT users.user_id, users.user_name, users.password, s.street FROM users, (SELECT addresses.street AS street FROM addresses WHERE addresses.user_id = users.user_id) AS s""")
 
     def testin(self):
-        self.runtest(select([table], table.c.id.in_(1, 2, 3)),
+        self.runtest(select([table1], table1.c.myid.in_(1, 2, 3)),
         "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:mytable_myid, :mytable_myid_1, :mytable_myid_2)")
 
-        self.runtest(select([table], table.c.id.in_(select([table2.c.id]))),
-        "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (SELECT myothertable.otherid AS id FROM myothertable)")
+        self.runtest(select([table1], table1.c.myid.in_(select([table2.c.otherid]))),
+        "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (SELECT myothertable.otherid AS otherid FROM myothertable)")
     
     def testlateargs(self):
         """tests that a SELECT clause will have extra "WHERE" clauses added to it at compile time if extra arguments
         are sent"""
         
-        self.runtest(table.select(), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.name = :mytable_name AND mytable.myid = :mytable_myid", params={'id':'3', 'name':'jack'})
+        self.runtest(table1.select(), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.name = :mytable_name AND mytable.myid = :mytable_myid", params={'myid':'3', 'name':'jack'})
 
-        self.runtest(table.select(table.c.name=='jack'), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :mytable_myid AND mytable.name = :mytable_name", params={'id':'3'})
+        self.runtest(table1.select(table1.c.name=='jack'), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :mytable_myid AND mytable.name = :mytable_name", params={'myid':'3'})
 
-        self.runtest(table.select(table.c.name=='jack'), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :mytable_myid AND mytable.name = :mytable_name", params={'id':'3', 'name':'fred'})
+        self.runtest(table1.select(table1.c.name=='jack'), "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid = :mytable_myid AND mytable.name = :mytable_name", params={'myid':'3', 'name':'fred'})
         
 class CRUDTest(SQLTest):
     def testinsert(self):
         # generic insert, will create bind params for all columns
-        self.runtest(insert(table), "INSERT INTO mytable (myid, name, description) VALUES (:myid, :name, :description)")
+        self.runtest(insert(table1), "INSERT INTO mytable (myid, name, description) VALUES (:myid, :name, :description)")
 
         # insert with user-supplied bind params for specific columns,
         # cols provided literally
         self.runtest(
-            insert(table, {table.c.id : bindparam('userid'), table.c.name : bindparam('username')}), 
+            insert(table1, {table1.c.myid : bindparam('userid'), table1.c.name : bindparam('username')}), 
             "INSERT INTO mytable (myid, name) VALUES (:userid, :username)")
         
         # insert with user-supplied bind params for specific columns, cols
         # provided as strings
         self.runtest(
-            insert(table, dict(id = 3, name = 'jack')), 
+            insert(table1, dict(myid = 3, name = 'jack')), 
             "INSERT INTO mytable (myid, name) VALUES (:myid, :name)"
         )
 
         # test with a tuple of params instead of named
         self.runtest(
-            insert(table, (3, 'jack', 'mydescription')), 
+            insert(table1, (3, 'jack', 'mydescription')), 
             "INSERT INTO mytable (myid, name, description) VALUES (:myid, :name, :description)",
             checkparams = {'myid':3, 'name':'jack', 'description':'mydescription'}
         )
         
     def testupdate(self):
-        self.runtest(update(table, table.c.id == 7), "UPDATE mytable SET name=:name WHERE mytable.myid = :mytable_myid", params = {table.c.name:'fred'})
-        self.runtest(update(table, table.c.id == 7), "UPDATE mytable SET name=:name WHERE mytable.myid = :mytable_myid", params = {'name':'fred'})
-        self.runtest(update(table, values = {table.c.name : table.c.id}), "UPDATE mytable SET name=mytable.myid")
-        self.runtest(update(table, whereclause = table.c.name == bindparam('crit'), values = {table.c.name : 'hi'}), "UPDATE mytable SET name=:name WHERE mytable.name = :crit", params = {'crit' : 'notthere'})
-        self.runtest(update(table, table.c.id == 12, values = {table.c.name : table.c.id}), "UPDATE mytable SET name=mytable.myid, description=:description WHERE mytable.myid = :mytable_myid", params = {'description':'test'})
-        self.runtest(update(table, table.c.id == 12, values = {table.c.id : 9}), "UPDATE mytable SET myid=:myid, description=:description WHERE mytable.myid = :mytable_myid", params = {'mytable_myid': 12, 'myid': 9, 'description': 'test'})
-        s = table.update(table.c.id == 12, values = {table.c.name : 'lala'})
-        print str(s)
-        c = s.compile(parameters = {'mytable_id':9,'name':'h0h0'})
+        self.runtest(update(table1, table1.c.myid == 7), "UPDATE mytable SET name=:name WHERE mytable.myid = :mytable_myid", params = {table1.c.name:'fred'})
+        self.runtest(update(table1, table1.c.myid == 7), "UPDATE mytable SET name=:name WHERE mytable.myid = :mytable_myid", params = {'name':'fred'})
+        self.runtest(update(table1, values = {table1.c.name : table1.c.myid}), "UPDATE mytable SET name=mytable.myid")
+        self.runtest(update(table1, whereclause = table1.c.name == bindparam('crit'), values = {table1.c.name : 'hi'}), "UPDATE mytable SET name=:name WHERE mytable.name = :crit", params = {'crit' : 'notthere'})
+        self.runtest(update(table1, table1.c.myid == 12, values = {table1.c.name : table1.c.myid}), "UPDATE mytable SET name=mytable.myid, description=:description WHERE mytable.myid = :mytable_myid", params = {'description':'test'})
+        self.runtest(update(table1, table1.c.myid == 12, values = {table1.c.myid : 9}), "UPDATE mytable SET myid=:myid, description=:description WHERE mytable.myid = :mytable_myid", params = {'mytable_myid': 12, 'myid': 9, 'description': 'test'})
+        s = table1.update(table1.c.myid == 12, values = {table1.c.name : 'lala'})
+        c = s.compile(parameters = {'mytable_id':9,'name':'h0h0'}, engine=db)
         print str(c)
         self.assert_(str(s) == str(c))
         
     def testupdateexpression(self):
-        self.runtest(update(table, 
-            (table.c.id == func.hoho(4)) &
-            (table.c.name == literal('foo') + table.c.name + literal('lala')),
+        self.runtest(update(table1
+            (table1.c.myid == func.hoho(4)) &
+            (table1.c.name == literal('foo') + table1.c.name + literal('lala')),
             values = {
-            table.c.name : table.c.name + "lala",
-            table.c.id : func.do_stuff(table.c.id, literal('hoho'))
+            table1.c.name : table1.c.name + "lala",
+            table1.c.myid : func.do_stuff(table1.c.myid, literal('hoho'))
             }), "UPDATE mytable SET myid=(do_stuff(mytable.myid, :literal_2)), name=(mytable.name + :mytable_name) WHERE mytable.myid = hoho(:hoho) AND mytable.name = :literal + mytable.name + :literal_1")
         
     def testcorrelatedupdate(self):
         # test against a straight text subquery
-        u = update(table, values = {table.c.name : text("select name from mytable where id=mytable.id")})
+        u = update(table1, values = {table1.c.name : text("select name from mytable where id=mytable.id")})
         self.runtest(u, "UPDATE mytable SET name=(select name from mytable where id=mytable.id)")
         
         # test against a regular constructed subquery
-        s = select([table2], table2.c.id == table.c.id)
-        u = update(table, table.c.name == 'jack', values = {table.c.name : s})
+        s = select([table2], table2.c.otherid == table1.c.myid)
+        u = update(table1, table1.c.name == 'jack', values = {table1.c.name : s})
         self.runtest(u, "UPDATE mytable SET name=(SELECT myothertable.otherid, myothertable.othername FROM myothertable WHERE myothertable.otherid = mytable.myid) WHERE mytable.name = :mytable_name")
         
     def testdelete(self):
-        self.runtest(delete(table, table.c.id == 7), "DELETE FROM mytable WHERE mytable.myid = :mytable_myid")
+        self.runtest(delete(table1, table1.c.myid == 7), "DELETE FROM mytable WHERE mytable.myid = :mytable_myid")
         
 class SchemaTest(SQLTest):
     def testselect(self):
index a26b87bd4f2dcfa6f01a994c5a26338746fe3067..afdca47382cce4a8f1f51d34b47c11c9138d99f4 100644 (file)
@@ -2,9 +2,6 @@ import unittest
 import StringIO
 import sqlalchemy.engine as engine
 import re, sys
-import sqlalchemy.databases.sqlite as sqlite
-import sqlalchemy.databases.postgres as postgres
-#import sqlalchemy.databases.mysql as mysql
 
 echo = True
 #echo = False