]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- merged trunk r2880-r2901 (slightly manually for 2900-2901)
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 14 Jul 2007 23:36:17 +0000 (23:36 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 14 Jul 2007 23:36:17 +0000 (23:36 +0000)
- merges "bind" argument change
- merges join fixes for [ticket:185]
- removed all "engine"/"connectable"/"bind_to"/"engine_or_url" arguments/attributes

26 files changed:
CHANGES
doc/build/content/tutorial.txt
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/databases/sqlite.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/ext/sqlsoup.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql.py
lib/sqlalchemy/util.py
test/engine/alltests.py
test/engine/bind.py [new file with mode: 0644]
test/engine/metadata.py
test/engine/reflection.py
test/orm/inheritance/basic.py
test/orm/mapper.py
test/orm/query.py
test/orm/relationships.py
test/sql/defaults.py
test/sql/labels.py
test/sql/query.py
test/sql/selectable.py
test/sql/testtypes.py
test/sql/unicode.py
test/testbase.py

diff --git a/CHANGES b/CHANGES
index 194fefc5b0ba02bb9eb906af8746dc8741e166d0..229239fa42f9c1ba118bbc7b1df02711d70d0624 100644 (file)
--- a/CHANGES
+++ b/CHANGES
     - better error message for NoSuchColumnError [ticket:607]
     - finally figured out how to get setuptools version in, available
       as sqlalchemy.__version__ [ticket:428]
+    - the various "engine" arguments, such as "engine", "connectable",
+      "engine_or_url", "bind_to", etc. are all present, but deprecated.
+      they all get replaced by the single term "bind".  you also
+      set the "bind" of MetaData using 
+      metadata.bind = <engine or connection>
 - ext
     - iteration over dict association proxies is now dict-like, not
       InstrumentedList-like (e.g. over keys instead of values)
     - association proxies no longer bind tightly to source collections
       [ticket:597], and are constructed with a thunk instead
+    - added selectone_by() to assignmapper
 - orm
     - forwards-compatibility with 0.4: added one(), first(), and 
-      all() to Query
+      all() to Query.  almost all Query functionality from 0.4 is
+      present in 0.3.9 for forwards-compat purposes.
+    - reset_joinpoint() really really works this time, promise ! lets
+      you re-join from the root:
+      query.join(['a', 'b']).filter(<crit>).reset_joinpoint().\
+      join(['a', 'c']).filter(<some other crit>).all()
+      in 0.4 all join() calls start from the "root"
     - added synchronization to the mapper() construction step, to avoid
       thread collections when pre-existing mappers are compiling in a 
       different thread [ticket:613]
+    - a warning is issued by Mapper when two primary key columns of the
+      same name are munged into a single attribute.  this happens frequently
+      when mapping to joins (or inheritance). 
     - synonym() properties are fully supported by all Query joining/
       with_parent operations [ticket:598]
     - fixed very stupid bug when deleting items with many-to-many
     - DynamicMetaData has been renamed to ThreadLocalMetaData.  the
       DynamicMetaData name is deprecated and is an alias for ThreadLocalMetaData
       or a regular MetaData if threadlocal=False
+    - composite primary key is represented as a non-keyed set to allow for 
+      composite keys consisting of cols with the same name; occurs within a
+      Join.  helps inheritance scenarios formulate correct PK.
+    - improved ability to get the "correct" and most minimal set of primary key 
+      columns from a join, equating foreign keys and otherwise equated columns.
+      this is also mostly to help inheritance scenarios formulate the best 
+      choice of primary key columns.  [ticket:185]
+    - added 'bind' argument to Sequence.create()/drop(), ColumnDefault.execute()
     - some enhancements to "column targeting", the ability to match a column
       to a "corresponding" column in another selectable.  this affects mostly
       ORM ability to map to complex joins
     - the fix in "schema" above fixes reflection of foreign keys from an
       alt-schema table to a public schema table
 - sqlite
+    - rearranged dialect initialization so it has time to warn about pysqlite1
+      being too old.
     - sqlite better handles datetime/date/time objects mixed and matched
       with various Date/Time/DateTime columns
     - string PK column inserts dont get overwritten with OID [ticket:603] 
     - fix port option handling for pyodbc [ticket:634]
     - now able to reflect start and increment values for identity columns
     - preliminary support for using scope_identity() with pyodbc
-
-- extensions
-    - added selectone_by() to assignmapper
     
 0.3.8
 - engines
index 464d3044bcd97ed193af398a3a554f2628de64bf..615f275f9ac41ed67cd29d96b9fef6df45149193 100644 (file)
@@ -105,7 +105,7 @@ With `metadata` as our established home for tables, lets make a Table for it:
     >>> users_table = Table('users', metadata,
     ...     Column('user_id', Integer, primary_key=True),
     ...     Column('user_name', String(40)),
-    ...     Column('password', String(10))
+    ...     Column('password', String(15))
     ... )
 
 As you might have guessed, we have just defined a table named `users` which has three columns: `user_id` (which is a primary key column), `user_name` and `password`. Currently it is just an object that doesn't necessarily correspond to an existing table in our database.  To actually create the table, we use the `create()` method.  To make it interesting, we will have SQLAlchemy echo the SQL statements it sends to the database, by setting the `echo` flag on the `Engine` associated with our `MetaData`:
@@ -116,7 +116,7 @@ As you might have guessed, we have just defined a table named `users` which has
     CREATE TABLE users (
         user_id INTEGER NOT NULL,
         user_name VARCHAR(40),
-        password VARCHAR(10),
+        password VARCHAR(15),
         PRIMARY KEY (user_id)
     )
     ...
index 7e5956444591d19889556ad18fc8b89db16f9797..233fe050acd920d14104d3ec11374984e379a764 100644 (file)
@@ -1068,25 +1068,11 @@ class MySQLDialect(ansisql.ANSIDialect):
         return self._default_schema_name
 
     def has_table(self, connection, table_name, schema=None):
-        # TODO: this does not work for table names that contain multibyte characters.
-
-        # http://dev.mysql.com/doc/refman/5.0/en/error-messages-server.html
-
-        # Error: 1146 SQLSTATE: 42S02 (ER_NO_SUCH_TABLE)
-        # Message: Table '%s.%s' doesn't exist
-
-        # Error: 1046 SQLSTATE: 3D000 (ER_NO_DB_ERROR)
-        # Message: No database selected
-
-        try:
-            name = schema and ("%s.%s" % (schema, table_name)) or table_name
-            connection.execute("DESCRIBE `%s`" % name)
-            return True
-        except exceptions.SQLError, e:
-            if e.orig.args[0] in (1146, 1046): 
-                return False
-            else:
-                raise
+        if schema is not None:
+            st = 'SHOW TABLE STATUS FROM `%s` LIKE %%s' % schema
+        else:
+            st = 'SHOW TABLE STATUS LIKE %s'
+        return connection.execute(st, table_name).rowcount != 0
 
     def get_version_info(self, connectable):
         if hasattr(connectable, 'connect'):
@@ -1102,34 +1088,36 @@ class MySQLDialect(ansisql.ANSIDialect):
         return tuple(version)
 
     def reflecttable(self, connection, table):
-        # reference:  http://dev.mysql.com/doc/refman/5.0/en/name-case-sensitivity.html
-        cs = connection.execute("show variables like 'lower_case_table_names'").fetchone()[1]
-        if isinstance(cs, array):
-            cs = cs.tostring()
-        case_sensitive = int(cs) == 0
+        """Load column definitions from the server."""
 
-        decode_from = connection.execute("show variables like 'character_set_results'").fetchone()[1]
+        decode_from = self._detect_charset(connection)
+
+        # reference:
+        # http://dev.mysql.com/doc/refman/5.0/en/name-case-sensitivity.html
+        row = _compat_fetch(connection.execute(
+            "SHOW VARIABLES LIKE 'lower_case_table_names'"),
+                            one=True, charset=decode_from)
+        if not row:
+            case_sensitive = True
+        else:
+            case_sensitive = row[1] in ('0', 'OFF', 'off')
 
         if not case_sensitive:
             table.name = table.name.lower()
             table.metadata.tables[table.name]= table
+
         try:
-            c = connection.execute("describe " + table.fullname, {})
+            rp = connection.execute("describe " + self._escape_table_name(table),
+                                   {})
         except:
-            raise exceptions.NoSuchTableError(table.name)
-        found_table = False
-        while True:
-            row = c.fetchone()
-            if row is None:
-                break
-            #print "row! " + repr(row)
-            if not found_table:
-                found_table = True
-
-            # these can come back as unicode if use_unicode=1 in the mysql connection
-            (name, type, nullable, primary_key, default) = (row[0], str(row[1]), row[2] == 'YES', row[3] == 'PRI', row[4])
-            if not isinstance(name, unicode):
-                name = name.decode(decode_from)
+            raise exceptions.NoSuchTableError(table.fullname)
+
+        for row in _compat_fetch(rp, charset=decode_from):
+            (name, type, nullable, primary_key, default) = \
+                   (row[0], row[1], row[2] == 'YES', row[3] == 'PRI', row[4])
+
+            # leave column names as unicode
+            name = name.decode(decode_from)
 
             match = re.match(r'(\w+)(\(.*?\))?\s*(\w+)?\s*(\w+)?', type)
             col_type = match.group(1)
@@ -1137,7 +1125,6 @@ class MySQLDialect(ansisql.ANSIDialect):
             extra_1 = match.group(3)
             extra_2 = match.group(4)
 
-            #print "coltype: " + repr(col_type) + " args: " + repr(args) + "extras:" + repr(extra_1) + ' ' + repr(extra_2)
             try:
                 coltype = ischema_names[col_type]
             except KeyError:
@@ -1162,32 +1149,24 @@ class MySQLDialect(ansisql.ANSIDialect):
             colargs= []
             if default:
                 if col_type == 'timestamp' and default == 'CURRENT_TIMESTAMP':
-                    arg = sql.text(default)
-                else:
-                    arg = default
-                colargs.append(schema.PassiveDefault(arg))
+                    default = sql.text(default)
+                colargs.append(schema.PassiveDefault(default))
             table.append_column(schema.Column(name, coltype, *colargs,
                                             **dict(primary_key=primary_key,
                                                    nullable=nullable,
                                                    )))
 
-        tabletype = self.moretableinfo(connection, table=table)
+        tabletype = self.moretableinfo(connection, table, decode_from)
         table.kwargs['mysql_engine'] = tabletype
 
-        if not found_table:
-            raise exceptions.NoSuchTableError(table.name)
+    def moretableinfo(self, connection, table, charset=None):
+        """SHOW CREATE TABLE to get foreign key/table options."""
 
-    def moretableinfo(self, connection, table):
-        """runs SHOW CREATE TABLE to get foreign key/options information about the table.
-        
-        """
-        c = connection.execute("SHOW CREATE TABLE " + table.fullname, {})
-        desc_fetched = c.fetchone()[1]
-
-        if not isinstance(desc_fetched, basestring):
-            # may get array.array object here, depending on version (such as mysql 4.1.14 vs. 4.1.11)
-            desc_fetched = desc_fetched.tostring()
-        desc = desc_fetched.strip()
+        rp = connection.execute("SHOW CREATE TABLE " + self._escape_table_name(table), {})
+        row = _compat_fetch(rp, one=True, charset=charset)
+        if not row:
+            raise exceptions.NoSuchTableError(table.fullname)
+        desc = row[1].strip()
 
         tabletype = ''
         lastparen = re.search(r'\)[^\)]*\Z', desc)
@@ -1207,9 +1186,68 @@ class MySQLDialect(ansisql.ANSIDialect):
 
         return tabletype
 
+    def _escape_table_name(self, table):
+        if table.schema is not None:
+            return '`%s`.`%s`' % (table.schema, table.name)
+        else:
+            return '`%s`' % table.name
+
+    def _detect_charset(self, connection):
+        """Sniff out the character set in use for connection results."""
+
+        # Note: MySQL-python 1.2.1c7 seems to ignore changes made
+        # on a connection via set_character_set()
+        
+        rs = connection.execute("show variables like 'character_set%%'")
+        opts = dict([(row[0], row[1]) for row in _compat_fetch(rs)])
+
+        if 'character_set_results' in opts:
+            return opts['character_set_results']
+        try:
+            return connection.connection.character_set_name()
+        except AttributeError:
+            # < 1.2.1 final MySQL-python drivers have no charset support
+            if 'character_set' in opts:
+                return opts['character_set']
+            else:
+                warnings.warn(RuntimeWarning("Could not detect the connection character set with this combination of MySQL server and MySQL-python.  MySQL-python >= 1.2.2 is recommended.  Assuming latin1."))
+                return 'latin1'
+
+def _compat_fetch(rp, one=False, charset=None):
+    """Proxy result rows to smooth over MySQL-Python driver inconsistencies."""
+
+    if one:
+        return _MySQLPythonRowProxy(rp.fetchone(), charset)
+    else:
+        return [_MySQLPythonRowProxy(row, charset) for row in rp.fetchall()]
+        
+
+class _MySQLPythonRowProxy(object):
+    """Return consistent column values for all versions of MySQL-python (esp. alphas) and unicode settings."""
+
+    def __init__(self, rowproxy, charset):
+        self.rowproxy = rowproxy
+        self.charset = charset
+    def __getitem__(self, index):
+        item = self.rowproxy[index]
+        if isinstance(item, array):
+            item = item.tostring()
+        if self.charset and isinstance(item, unicode):
+            return item.encode(self.charset)
+        else:
+            return item
+    def __getattr__(self, attr):
+        item = getattr(self.rowproxy, attr)
+        if isinstance(item, array):
+            item = item.tostring()
+        if self.charset and isinstance(item, unicode):
+            return item.encode(self.charset)
+        else:
+            return item
+
+
 class MySQLCompiler(ansisql.ANSICompiler):
     def visit_cast(self, cast):
-        """hey ho MySQL supports almost no types at all for CAST"""
         if isinstance(cast.type, (sqltypes.Date, sqltypes.Time, sqltypes.DateTime)):
             return super(MySQLCompiler, self).visit_cast(cast)
         else:
index 0fd31928af838bd2528d030c0e6133991c5683f7..7ccf38c4bb25a94060d9eff3ad19a33d374e12b3 100644 (file)
@@ -161,13 +161,13 @@ class SQLiteDialect(ansisql.ANSIDialect):
         ansisql.ANSIDialect.__init__(self, default_paramstyle='qmark', **kwargs)
         def vers(num):
             return tuple([int(x) for x in num.split('.')])
-        self.supports_cast = (self.dbapi is None or vers(self.dbapi.sqlite_version) >= vers("3.2.3"))
         if self.dbapi is not None:
             sqlite_ver = self.dbapi.version_info
             if sqlite_ver < (2,1,'3'):
                 warnings.warn(RuntimeWarning("The installed version of pysqlite2 (%s) is out-dated, and will cause errors in some cases.  Version 2.1.3 or greater is recommended." % '.'.join([str(subver) for subver in sqlite_ver])))
             if vers(self.dbapi.sqlite_version) < vers("3.3.13"):
                 warnings.warn(RuntimeWarning("The installed version of sqlite (%s) is out-dated, and will cause errors in some cases.  Version 3.3.13 or greater is recommended." % self.dbapi.sqlite_version))
+        self.supports_cast = (self.dbapi is None or vers(self.dbapi.sqlite_version) >= vers("3.2.3"))
         
     def dbapi(cls):
         try:
index 1143fbf59a047ee0d838cc3f087eeb84bde660d0..6c23422a1f37e3f3580dcf25482fe8b32a700d7b 100644 (file)
@@ -406,7 +406,7 @@ class Compiled(sql.ClauseVisitor):
     defaults.
     """
 
-    def __init__(self, dialect, statement, parameters, engine=None):
+    def __init__(self, dialect, statement, parameters, bind=None):
         """Construct a new ``Compiled`` object.
 
         statement
@@ -426,13 +426,13 @@ class Compiled(sql.ClauseVisitor):
           can either be the string names of columns or
           ``_ColumnClause`` objects.
 
-        engine
-          Optional Engine to compile this statement against.
+        bind
+          Optional Engine or Connection to compile this statement against.
         """
         self.dialect = dialect
         self.statement = statement
         self.parameters = parameters
-        self.engine = engine
+        self.bind = bind
         self.can_execute = statement.supports_execution()
 
     def compile(self):
@@ -465,9 +465,9 @@ class Compiled(sql.ClauseVisitor):
     def execute(self, *multiparams, **params):
         """Execute this compiled object."""
 
-        e = self.engine
+        e = self.bind
         if e is None:
-            raise exceptions.InvalidRequestError("This Compiled object is not bound to any engine.")
+            raise exceptions.InvalidRequestError("This Compiled object is not bound to any Engine or Connection.")
         return e.execute_compiled(self, *multiparams, **params)
 
     def scalar(self, *multiparams, **params):
@@ -691,7 +691,7 @@ class Connection(Connectable):
         return self.execute(object, *multiparams, **params).scalar()
 
     def compiler(self, statement, parameters, **kwargs):
-        return self.dialect.compiler(statement, parameters, engine=self.engine, **kwargs)
+        return self.dialect.compiler(statement, parameters, bind=self.engine, **kwargs)
 
     def execute(self, object, *multiparams, **params):
         for c in type(object).__mro__:
@@ -945,14 +945,14 @@ class Engine(Connectable):
             connection.close()
 
     def _func(self):
-        return sql._FunctionGenerator(engine=self)
+        return sql._FunctionGenerator(bind=self)
 
     func = property(_func)
 
     def text(self, text, *args, **kwargs):
         """Return a sql.text() object for performing literal queries."""
 
-        return sql.text(text, engine=self, *args, **kwargs)
+        return sql.text(text, bind=self, *args, **kwargs)
 
     def _run_visitor(self, visitorcallable, element, connection=None, **kwargs):
         if connection is None:
@@ -1014,7 +1014,7 @@ class Engine(Connectable):
         return connection.execute_compiled(compiled, *multiparams, **params)
 
     def compiler(self, statement, parameters, **kwargs):
-        return self.dialect.compiler(statement, parameters, engine=self, **kwargs)
+        return self.dialect.compiler(statement, parameters, bind=self, **kwargs)
 
     def connect(self, **kwargs):
         """Return a newly allocated Connection object."""
@@ -1510,7 +1510,7 @@ class DefaultRunner(schema.SchemaVisitor):
         return None
 
     def exec_default_sql(self, default):
-        c = sql.select([default.arg]).compile(engine=self.connection)
+        c = sql.select([default.arg]).compile(bind=self.connection)
         return self.connection.execute_compiled(c).scalar()
 
     def visit_column_onupdate(self, onupdate):
index 9940610f754b8877ec6b13dfd7b27835d3802fff..90399b7b5c3b11940dfd18471eeacd6b2115bd61 100644 (file)
@@ -266,7 +266,7 @@ directly.  The engine's ``execute`` method corresponds to the one of a
 DBAPI cursor, and returns a ``ResultProxy`` that has ``fetch`` methods
 you would also see on a cursor::
 
-    >>> rp = db.engine.execute('select name, email from users order by name')
+    >>> rp = db.bind.execute('select name, email from users order by name')
     >>> for name, email in rp.fetchall(): print name, email
     Bhargan Basepair basepair+nospam@example.edu
     Joe Student student@example.edu
@@ -497,9 +497,10 @@ class SqlSoup:
         self.schema = None
 
     def engine(self):
-        return self._metadata._engine
+        return self._metadata.bind
 
     engine = property(engine)
+    bind = engine
 
     def delete(self, *args, **kwargs):
         objectstore.delete(*args, **kwargs)
index 3a5f60c2726080288d1573af0003698509e23b33..5718d49dd5312e6bb2be226981e12a61ae5be141 100644 (file)
@@ -10,7 +10,7 @@ from sqlalchemy.orm import util as mapperutil
 from sqlalchemy.orm.util import ExtensionCarrier
 from sqlalchemy.orm import sync
 from sqlalchemy.orm.interfaces import MapperProperty, MapperOption, OperationContext, EXT_PASS, MapperExtension, SynonymProperty
-import weakref
+import weakref, warnings
 
 __all__ = ['Mapper', 'class_mapper', 'object_mapper', 'mapper_registry']
 
@@ -543,9 +543,11 @@ class Mapper(object):
         # against the "mapped_table" of this mapper.
         equivalent_columns = self._get_equivalent_columns()
         
-        primary_key = sql.ColumnCollection()
+        primary_key = sql.ColumnSet()
 
         for col in (self.primary_key_argument or self.pks_by_table[self.mapped_table]):
+            #primary_key.add(col)
+            #continue
             c = self.mapped_table.corresponding_column(col, raiseerr=False)
             if c is None:
                 for cc in equivalent_columns[col]:
@@ -690,6 +692,8 @@ class Mapper(object):
                     prop = prop.copy()
                     prop.set_parent(self)
                     self.__props[column_key] = prop
+                if column in self.primary_key and prop.columns[-1] in self.primary_key:
+                    warnings.warn(RuntimeWarning("On mapper %s, primary key column '%s' is being combined with distinct primary key column '%s' in attribute '%s'.  Use explicit properties to give each column its own mapped attribute name." % (str(self), str(column), str(prop.columns[-1]), column_key)))
                 prop.columns.append(column)
                 self.__log("appending to existing ColumnProperty %s" % (column_key))
             else:
@@ -1360,7 +1364,7 @@ class Mapper(object):
                 statement = table.delete(clause)
                 c = connection.execute(statement, delete)
                 if c.supports_sane_rowcount() and c.rowcount != len(delete):
-                    raise exceptions.ConcurrentModificationError("Updated rowcount %d does not match number of objects updated %d" % (c.cursor.rowcount, len(delete)))
+                    raise exceptions.ConcurrentModificationError("Updated rowcount %d does not match number of objects updated %d" % (c.rowcount, len(delete)))
 
         for obj in deleted_objects:
             for mapper in object_mapper(obj).iterate_to_root():
index 46ce21a7934e942cdca593c0c9f1ad9d6c126737..17d8feabb14401f0e2dace67fe08d97ce45d3fee 100644 (file)
@@ -38,21 +38,21 @@ class SessionTransaction(object):
     def _begin(self):
         return SessionTransaction(self.session, self)
 
-    def add(self, connectable):
-        if self.connections.has_key(connectable.engine):
+    def add(self, bind):
+        if self.connections.has_key(bind.engine):
             raise exceptions.InvalidRequestError("Session already has a Connection associated for the given Connection's Engine")
-        return self.get_or_add(connectable)
+        return self.get_or_add(bind)
 
-    def get_or_add(self, connectable):
+    def get_or_add(self, bind):
         # we reference the 'engine' attribute on the given object, which in the case of
         # Connection, ProxyEngine, Engine, whatever, should return the original
         # "Engine" object that is handling the connection.
-        if self.connections.has_key(connectable.engine):
-            return self.connections[connectable.engine][0]
-        e = connectable.engine
-        c = connectable.contextual_connect()
+        if self.connections.has_key(bind.engine):
+            return self.connections[bind.engine][0]
+        e = bind.engine
+        c = bind.contextual_connect()
         if not self.connections.has_key(e):
-            self.connections[e] = (c, c.begin(), c is not connectable)
+            self.connections[e] = (c, c.begin(), c is not bind)
         return self.connections[e][0]
 
     def commit(self):
@@ -99,13 +99,13 @@ class Session(object):
     of Sessions, see the ``sqlalchemy.ext.sessioncontext`` module.
     """
 
-    def __init__(self, bind_to=None, hash_key=None, import_session=None, echo_uow=False, weak_identity_map=False):
+    def __init__(self, bind=None, bind_to=None, hash_key=None, import_session=None, echo_uow=False, weak_identity_map=False):
         if import_session is not None:
             self.uow = unitofwork.UnitOfWork(identity_map=import_session.uow.identity_map, weak_identity_map=weak_identity_map)
         else:
             self.uow = unitofwork.UnitOfWork(weak_identity_map=weak_identity_map)
 
-        self.bind_to = bind_to
+        self.bind = bind or bind_to
         self.binds = {}
         self.echo_uow = echo_uow
         self.weak_identity_map = weak_identity_map
@@ -122,6 +122,8 @@ class Session(object):
     def _set_echo_uow(self, value):
         self.uow.echo = value
     echo_uow = property(_get_echo_uow,_set_echo_uow)
+    
+    bind_to = property(lambda self:self.bind)
 
     def create_transaction(self, **kwargs):
         """Return a new ``SessionTransaction`` corresponding to an
@@ -213,23 +215,23 @@ class Session(object):
 
         return _class_mapper(class_, entity_name = entity_name)
 
-    def bind_mapper(self, mapper, bindto):
+    def bind_mapper(self, mapper, bind):
         """Bind the given `mapper` to the given ``Engine`` or ``Connection``.
 
         All subsequent operations involving this ``Mapper`` will use the
-        given `bindto`.
+        given `bind`.
         """
 
-        self.binds[mapper] = bindto
+        self.binds[mapper] = bind
 
-    def bind_table(self, table, bindto):
+    def bind_table(self, table, bind):
         """Bind the given `table` to the given ``Engine`` or ``Connection``.
 
         All subsequent operations involving this ``Table`` will use the
-        given `bindto`.
+        given `bind`.
         """
 
-        self.binds[table] = bindto
+        self.binds[table] = bind
 
     def get_bind(self, mapper):
         """Return the ``Engine`` or ``Connection`` which is used to execute
@@ -259,17 +261,17 @@ class Session(object):
         """
 
         if mapper is None:
-            return self.bind_to
+            return self.bind
         elif self.binds.has_key(mapper):
             return self.binds[mapper]
         elif self.binds.has_key(mapper.mapped_table):
             return self.binds[mapper.mapped_table]
-        elif self.bind_to is not None:
-            return self.bind_to
+        elif self.bind is not None:
+            return self.bind
         else:
-            e = mapper.mapped_table.engine
+            e = mapper.mapped_table.bind
             if e is None:
-                raise exceptions.InvalidRequestError("Could not locate any Engine bound to mapper '%s'" % str(mapper))
+                raise exceptions.InvalidRequestError("Could not locate any Engine or Connection bound to mapper '%s'" % str(mapper))
             return e
 
     def query(self, mapper_or_class, entity_name=None, **kwargs):
index 9b9ae801a911d9d1de46aa683efe21127e473c14..897f397b61bdff4400a8f0cffbfd07ecbc8f01dd 100644 (file)
@@ -57,21 +57,27 @@ class SchemaItem(object):
 
         return None
 
-    def _get_engine(self):
+    def _get_engine(self, raiseerr=False):
         """Return the engine or None if no engine."""
 
-        return self._derived_metadata().engine
-
-    def get_engine(self, connectable=None):
-        """Return the engine or raise an error if no engine."""
-
-        if connectable is not None:
-            return connectable
-        e = self._get_engine()
-        if e is not None:
-            return e
+        if raiseerr:
+            m = self._derived_metadata()
+            e = m and m.bind or None
+            if e is None:
+                raise exceptions.InvalidRequestError("This SchemaItem is not connected to any Engine or Connection.")
+            else:
+                return e
         else:
-            raise exceptions.InvalidRequestError("This SchemaItem is not connected to any Engine")
+            m = self._derived_metadata()
+            return m and m.bind or None
+
+    def get_engine(self):
+        """Return the engine or raise an error if no engine.
+        
+        Deprecated.  use the "bind" attribute.
+        """
+        
+        return self._get_engine(raiseerr=True)
 
     def _set_casing_strategy(self, kwargs, keyname='case_sensitive'):
         """Set the "case_sensitive" argument sent via keywords to the item's constructor.
@@ -121,9 +127,9 @@ class SchemaItem(object):
             return self.__case_sensitive
     case_sensitive = property(_get_case_sensitive)
 
-    engine = property(lambda s:s._get_engine())
     metadata = property(lambda s:s._derived_metadata())
-
+    bind = property(lambda s:s._get_engine())
+    
 def _get_table_key(name, schema):
     if schema is None:
         return name
@@ -159,7 +165,7 @@ class _TableSingleton(sql._FigureVisitName):
                     if autoload_with:
                         autoload_with.reflecttable(table)
                     else:
-                        metadata.get_engine().reflecttable(table)
+                        metadata._get_engine(raiseerr=True).reflecttable(table)
                 except exceptions.NoSuchTableError:
                     del metadata.tables[key]
                     raise
@@ -269,7 +275,9 @@ class Table(SchemaItem, sql.TableClause):
         self.schema = kwargs.pop('schema', None)
         self.indexes = util.Set()
         self.constraints = util.Set()
+        self._columns = sql.ColumnCollection()
         self.primary_key = PrimaryKeyConstraint()
+        self._foreign_keys = util.OrderedSet()
         self.quote = kwargs.pop('quote', False)
         self.quote_schema = kwargs.pop('quote_schema', False)
         if self.schema is not None:
@@ -289,6 +297,11 @@ class Table(SchemaItem, sql.TableClause):
 
     key = property(lambda self:_get_table_key(self.name, self.schema))
     
+    def _export_columns(self, columns=None):
+        # override FromClause's collection initialization logic; TableClause and Table
+        # implement it differently
+        pass
+
     def _get_case_sensitive_schema(self):
         try:
             return getattr(self, '_case_sensitive_schema')
@@ -343,30 +356,37 @@ class Table(SchemaItem, sql.TableClause):
             else:
                 return []
 
-    def exists(self, connectable=None):
+    def exists(self, bind=None, connectable=None):
         """Return True if this table exists."""
 
-        if connectable is None:
-            connectable = self.get_engine()
+        if connectable is not None:
+            bind = connectable
+            
+        if bind is None:
+            bind = self._get_engine(raiseerr=True)
 
         def do(conn):
             e = conn.engine
             return e.dialect.has_table(conn, self.name, schema=self.schema)
-        return connectable.run_callable(do)
+        return bind.run_callable(do)
 
-    def create(self, connectable=None, checkfirst=False):
+    def create(self, bind=None, checkfirst=False, connectable=None):
         """Issue a ``CREATE`` statement for this table.
 
         See also ``metadata.create_all()``."""
 
-        self.metadata.create_all(connectable=connectable, checkfirst=checkfirst, tables=[self])
+        if connectable is not None:
+            bind = connectable
+        self.metadata.create_all(bind=bind, checkfirst=checkfirst, tables=[self])
 
-    def drop(self, connectable=None, checkfirst=False):
+    def drop(self, bind=None, checkfirst=False, connectable=None):
         """Issue a ``DROP`` statement for this table.
 
         See also ``metadata.drop_all()``."""
 
-        self.metadata.drop_all(connectable=connectable, checkfirst=checkfirst, tables=[self])
+        if connectable is not None:
+            bind = connectable
+        self.metadata.drop_all(bind=bind, checkfirst=checkfirst, tables=[self])
 
     def tometadata(self, metadata, schema=None):
         """Return a copy of this ``Table`` associated with a different ``MetaData``."""
@@ -527,8 +547,16 @@ class Column(SchemaItem, sql._ColumnClause):
         return self.table.metadata
 
     def _get_engine(self):
-        return self.table.engine
+        return self.table.bind
 
+    def references(self, column):
+        """Return True if this column references the given column via foreign key."""
+        for fk in self.foreign_keys:
+            if fk.column is column:
+                return True
+        else:
+            return False
+            
     def append_foreign_key(self, fk):
         fk._set_parent(self)
 
@@ -744,7 +772,7 @@ class DefaultGenerator(SchemaItem):
 
     def __init__(self, for_update=False, metadata=None):
         self.for_update = for_update
-        self._metadata = metadata
+        self._metadata = util.assert_arg_type(metadata, (MetaData, type(None)), 'metadata')
 
     def _derived_metadata(self):
         try:
@@ -763,8 +791,10 @@ class DefaultGenerator(SchemaItem):
         else:
             self.column.default = self
 
-    def execute(self, connectable=None, **kwargs):
-        return self.get_engine(connectable=connectable).execute_default(self, **kwargs)
+    def execute(self, bind=None, **kwargs):
+        if bind is None:
+            bind = self._get_engine(raiseerr=True)
+        return bind.execute_default(self, **kwargs)
 
     def __repr__(self):
         return "DefaultGenerator()"
@@ -822,12 +852,15 @@ class Sequence(DefaultGenerator):
         super(Sequence, self)._set_parent(column)
         column.sequence = self
 
-    def create(self, connectable=None, checkfirst=True):
-       self.get_engine(connectable=connectable).create(self, checkfirst=checkfirst)
-       return self
+    def create(self, bind=None, checkfirst=True):
+        if bind is None:
+            bind = self._get_engine(raiseerr=True)
+        bind.create(self, checkfirst=checkfirst)
 
-    def drop(self, connectable=None, checkfirst=True):
-       self.get_engine(connectable=connectable).drop(self, checkfirst=checkfirst)
+    def drop(self, bind=None, checkfirst=True):
+        if bind is None:
+            bind = self._get_engine(raiseerr=True)
+        bind.drop(self, checkfirst=checkfirst)
 
 
 class Constraint(SchemaItem):
@@ -1022,14 +1055,14 @@ class Index(SchemaItem):
         if connectable is not None:
             connectable.create(self)
         else:
-            self.get_engine().create(self)
+            self._get_engine(raiseerr=True).create(self)
         return self
 
     def drop(self, connectable=None):
         if connectable is not None:
             connectable.drop(self)
         else:
-            self.get_engine().drop(self)
+            self._get_engine(raiseerr=True).drop(self)
 
     def __str__(self):
         return repr(self)
@@ -1045,31 +1078,23 @@ class MetaData(SchemaItem):
 
     __visit_name__ = 'metadata'
     
-    def __init__(self, engine_or_url=None, **kwargs):
+    def __init__(self, bind=None, **kwargs):
         """create a new MetaData object.
-        
-            url
-                a string or URL instance which will be passed to create_engine(),
-                along with \**kwargs - this MetaData will be bound to the resulting
-                engine.
             
-            engine
-                an Engine instance to which this MetaData will be bound.
-                
+            bind
+                an Engine, or a string or URL instance which will be passed
+                to create_engine(), along with \**kwargs - this MetaData will
+                be bound to the resulting engine.
+
             case_sensitive
                 popped from \**kwargs, indicates default case sensitive setting for
                 all contained objects.  defaults to True.
             
         """        
-        
-        if engine_or_url is None:
-            # limited backwards compatability
-            engine_or_url = kwargs.get('url', None) or kwargs.get('engine', None)
+
         self.tables = {}
-        self._engine = None
         self._set_casing_strategy(kwargs)
-        if engine_or_url:
-            self.connect(engine_or_url, **kwargs)
+        self.bind = bind
 
     def __getstate__(self):
         return {'tables':self.tables, 'casesensitive':self._case_sensitive_setting}
@@ -1077,38 +1102,32 @@ class MetaData(SchemaItem):
     def __setstate__(self, state):
         self.tables = state['tables']
         self._case_sensitive_setting = state['casesensitive']
-        self._engine = None
+        self._bind = None
         
     def is_bound(self):
         """return True if this MetaData is bound to an Engine."""
-        return self._engine is not None
+        return self._bind is not None
 
-    def connect(self, engine_or_url, **kwargs):
+    def connect(self, bind, **kwargs):
         """bind this MetaData to an Engine.
+            
+            DEPRECATED.  Use metadata.bind = <engine> or metadata.bind = <url>.
         
-            engine_or_url
+            bind
                 a string, URL or Engine instance.  If a string or URL,
                 will be passed to create_engine() along with \**kwargs to
                 produce the engine which to connect to.  otherwise connects
                 directly to the given Engine.
-                
+
         """
         
         from sqlalchemy.engine.url import URL
-        if isinstance(engine_or_url, (basestring, URL)):
-            self._engine = sqlalchemy.create_engine(engine_or_url, **kwargs)
+        if isinstance(bind, (basestring, URL)):
+            self._bind = sqlalchemy.create_engine(bind, **kwargs)
         else:
-            self._engine = engine_or_url
+            self._bind = bind
 
-    def _get_engine(self):
-        # we are checking is_bound() because _engine wires 
-        # into SchemaItem's _engine mechanism, which raises an error,
-        # whereas we just want to return None.
-        if not self.is_bound():
-            return None
-        return self._engine
-
-    engine = property(_get_engine, connect)
+    bind = property(lambda self:self._bind, connect, doc="""an Engine or Connection to which this MetaData is bound.  this is a settable property as well.""")
     
     def clear(self):
         self.tables.clear()
@@ -1129,47 +1148,64 @@ class MetaData(SchemaItem):
     def _get_parent(self):
         return None
 
-    def create_all(self, connectable=None, tables=None, checkfirst=True):
+    def create_all(self, bind=None, tables=None, checkfirst=True, connectable=None):
         """Create all tables stored in this metadata.
 
         This will conditionally create tables depending on if they do
         not yet exist in the database.
 
+        bind
+          A ``Connectable`` used to access the database; if None, uses
+          the existing bind on this ``MetaData``, if any.
+
         connectable
-          A ``Connectable`` used to access the database; or use the engine
-          bound to this ``MetaData``.
+          Deprecated.  Synonymous with "bind".
 
         tables
           Optional list of tables, which is a subset of the total
           tables in the ``MetaData`` (others are ignored).
         """
 
-        if connectable is None:
-            connectable = self.get_engine()
-        connectable.create(self, checkfirst=checkfirst, tables=tables)
+        if connectable is not None:
+            bind = connectable
+        if bind is None:
+            bind = self._get_engine(raiseerr=True)
+        bind.create(self, checkfirst=checkfirst, tables=tables)
 
-    def drop_all(self, connectable=None, tables=None, checkfirst=True):
+    def drop_all(self, bind=None, tables=None, checkfirst=True, connectable=None):
         """Drop all tables stored in this metadata.
 
         This will conditionally drop tables depending on if they
         currently exist in the database.
 
+        bind
+          A ``Connectable`` used to access the database; if None, uses
+          the existing bind on this ``MetaData``, if any.
+          
         connectable
-          A ``Connectable`` used to access the database; or use the engine
-          bound to this ``MetaData``.
+          Deprecated.  Synonymous with "bind".
 
         tables
           Optional list of tables, which is a subset of the total
           tables in the ``MetaData`` (others are ignored).
         """
 
-        if connectable is None:
-            connectable = self.get_engine()
-        connectable.drop(self, checkfirst=checkfirst, tables=tables)
+        if connectable is not None:
+            bind = connectable
+        if bind is None:
+            bind = self._get_engine(raiseerr=True)
+        bind.drop(self, checkfirst=checkfirst, tables=tables)
 
     def _derived_metadata(self):
         return self
 
+    def _get_engine(self, raiseerr=False):
+        if not self.is_bound():
+            if raiseerr:
+                raise exceptions.InvalidRequestError("This SchemaItem is not connected to any Engine or Connection.")
+            else:
+                return None
+        return self._bind
 
 class ThreadLocalMetaData(MetaData):
     """Build upon ``MetaData`` to provide the capability to bind to 
@@ -1209,13 +1245,15 @@ thread-local basis.
         for e in self.__engines.values():
             e.dispose()
 
-    def _get_engine(self):
+    def _get_engine(self, raiseerr=False):
         if hasattr(self.context, '_engine'):
             return self.context._engine
         else:
-            return None
-
-    engine = property(_get_engine, connect)
+            if raiseerr:
+                raise exceptions.InvalidRequestError("This SchemaItem is not connected to any Engine or Connection.")
+            else: 
+                return None
+    bind = property(_get_engine, connect)
 
 
 class SchemaVisitor(sql.ClauseVisitor):
index 32c20bc10ff47ab0fe2874a876b3bf614585e9d8..c5eeda9c935af0b838c1df6d761e09194a4c4b79 100644 (file)
@@ -218,12 +218,12 @@ def select(columns=None, whereclause=None, from_obj=[], **kwargs):
           and oracle supports "nowait" which translates to 
           ``FOR UPDATE NOWAIT``.
         
-        engine=None
-          an ``Engine`` instance to which the resulting ``Select`` 
+        bind=None
+          an ``Engine`` or ``Connection`` instance to which the resulting ``Select`` 
           object will be bound.  The ``Select`` object will otherwise
-          automatically bind to whatever ``Engine`` instances can be located
+          automatically bind to whatever ``Connectable`` instances can be located
           within its contained ``ClauseElement`` members.
-      
+        
         limit=None
           a numerical value which usually compiles to a ``LIMIT`` expression
           in the resulting select.  Databases that don't support ``LIMIT``
@@ -708,7 +708,7 @@ def bindparam(key, value=None, type=None, shortname=None, unique=False):
     else:
         return _BindParamClause(key, value, type=type, shortname=shortname, unique=unique)
 
-def text(text, engine=None, *args, **kwargs):
+def text(text, bind=None, *args, **kwargs):
     """Create literal text to be inserted into a query.
 
     When constructing a query from a ``select()``, ``update()``,
@@ -723,9 +723,9 @@ def text(text, engine=None, *args, **kwargs):
         to specify bind parameters; they will be compiled to their
         engine-specific format.
 
-      engine
-        An optional engine to be used for this text query.
-
+      bind
+        An optional connection or engine to be used for this text query.
+        
       bindparams
         A list of ``bindparam()`` instances which can be used to define
         the types and/or initial values for the bind parameters within
@@ -742,7 +742,7 @@ def text(text, engine=None, *args, **kwargs):
 
     """
 
-    return _TextClause(text, engine=engine, *args, **kwargs)
+    return _TextClause(text, bind=bind, *args, **kwargs)
 
 def null():
     """Return a ``_Null`` object, which compiles to ``NULL`` in a sql statement."""
@@ -1040,22 +1040,20 @@ class ClauseElement(object):
         """
 
         try:
-            if self._engine is not None:
-                return self._engine
+            if self._bind is not None:
+                return self._bind
         except AttributeError:
             pass
         for f in self._get_from_objects():
             if f is self:
                 continue
-            engine = f.engine
+            engine = f.bind
             if engine is not None:
                 return engine
         else:
             return None
-
-    engine = property(lambda s: s._find_engine(),
-                      doc="""Attempts to locate a Engine within this ClauseElement
-                      structure, or returns None if none found.""")
+    
+    bind = property(lambda s:s._find_engine(), doc="""Returns the Engine or Connection to which this ClauseElement is bound, or None if none found.""")
 
     def execute(self, *multiparams, **params):
         """Compile and execute this ``ClauseElement``."""
@@ -1064,7 +1062,7 @@ class ClauseElement(object):
             compile_params = multiparams[0]
         else:
             compile_params = params
-        return self.compile(engine=self.engine, parameters=compile_params).execute(*multiparams, **params)
+        return self.compile(bind=self.bind, parameters=compile_params).execute(*multiparams, **params)
 
     def scalar(self, *multiparams, **params):
         """Compile and execute this ``ClauseElement``, returning the
@@ -1073,7 +1071,7 @@ class ClauseElement(object):
 
         return self.execute(*multiparams, **params).scalar()
 
-    def compile(self, engine=None, parameters=None, compiler=None, dialect=None):
+    def compile(self, bind=None, parameters=None, compiler=None, dialect=None):
         """Compile this SQL expression.
 
         Uses the given ``Compiler``, or the given ``AbstractDialect``
@@ -1102,10 +1100,10 @@ class ClauseElement(object):
         if compiler is None:
             if dialect is not None:
                 compiler = dialect.compiler(self, parameters)
-            elif engine is not None:
-                compiler = engine.compiler(self, parameters)
-            elif self.engine is not None:
-                compiler = self.engine.compiler(self, parameters)
+            elif bind is not None:
+                compiler = bind.compiler(self, parameters)
+            elif self.bind is not None:
+                compiler = self.bind.compiler(self, parameters)
 
         if compiler is None:
             import sqlalchemy.ansisql as ansisql
@@ -1473,6 +1471,25 @@ class ColumnCollection(util.OrderedProperties):
         # "True" value (i.e. a BinaryClause...)
         return col in util.Set(self)
 
+class ColumnSet(util.OrderedSet):
+    def contains_column(self, col):
+        return col in self
+        
+    def extend(self, cols):
+        for col in cols:
+            self.add(col)
+
+    def __add__(self, other):
+        return list(self) + list(other)
+
+    def __eq__(self, other):
+        l = []
+        for c in other:
+            for local in self:
+                if c.shares_lineage(local):
+                    l.append(c==local)
+        return and_(*l)
+            
 class FromClause(Selectable):
     """Represent an element that can be used within the ``FROM``
     clause of a ``SELECT`` statement.
@@ -1616,7 +1633,7 @@ class FromClause(Selectable):
         """)
     oid_column = property(_get_oid_column)
 
-    def _export_columns(self):
+    def _export_columns(self, columns=None):
         """Initialize column collections.
 
         The collections include the primary key, foreign keys, list of
@@ -1629,14 +1646,17 @@ class FromClause(Selectable):
         its parent ``Selectable`` is this ``FromClause``.
         """
 
-        if hasattr(self, '_columns'):
+        if hasattr(self, '_columns') and columns is None:
             # TODO: put a mutex here ?  this is a key place for threading probs
             return
         self._columns = ColumnCollection()
-        self._primary_key = ColumnCollection()
+        self._primary_key = ColumnSet()
         self._foreign_keys = util.Set()
         self._orig_cols = {}
-        for co in self._flatten_exportable_columns():
+
+        if columns is None:
+            columns = self._flatten_exportable_columns()
+        for co in columns:
             cp = self._proxy_column(co)
             for ci in cp.orig_set:
                 cx = self._orig_cols.get(ci)
@@ -1756,8 +1776,8 @@ class _TextClause(ClauseElement):
 
     __visit_name__ = 'textclause'
     
-    def __init__(self, text = "", engine=None, bindparams=None, typemap=None):
-        self._engine = engine
+    def __init__(self, text = "", bind=None, bindparams=None, typemap=None):
+        self._bind = bind
         self.bindparams = {}
         self.typemap = typemap
         if typemap is not None:
@@ -1883,7 +1903,7 @@ class _CalculatedClause(ColumnElement):
     def __init__(self, name, *clauses, **kwargs):
         self.name = name
         self.type = sqltypes.to_instance(kwargs.get('type', None))
-        self._engine = kwargs.get('engine', None)
+        self._bind = kwargs.get('bind', None)
         self.group = kwargs.pop('group', True)
         self.clauses = ClauseList(operator=kwargs.get('operator', None), group_contents=kwargs.get('group_contents', True), *clauses)
         if self.group:
@@ -1928,7 +1948,7 @@ class _Function(_CalculatedClause, FromClause):
         self.type = sqltypes.to_instance(kwargs.get('type', None))
         self.packagenames = kwargs.get('packagenames', None) or []
         kwargs['operator'] = ','
-        self._engine = kwargs.get('engine', None)
+        self._bind = kwargs.get('bind', None)
         _CalculatedClause.__init__(self, name, **kwargs)
         for c in clauses:
             self.append(c)
@@ -2091,15 +2111,38 @@ class Join(FromClause):
     encodedname = property(lambda s: s.name.encode('ascii', 'backslashreplace'))
 
     def _init_primary_key(self):
-        pkcol = util.OrderedSet()
-        for col in self._flatten_exportable_columns():
-            if col.primary_key:
-                pkcol.add(col)
-        for col in list(pkcol):
-            for f in col.foreign_keys:
-                if f.column in pkcol:
-                    pkcol.remove(col)
-        self.primary_key.extend(pkcol)
+        pkcol = util.Set([c for c in self._flatten_exportable_columns() if c.primary_key])
+    
+        equivs = {}
+        def add_equiv(a, b):
+            for x, y in ((a, b), (b, a)):
+                if x in equivs:
+                    equivs[x].add(y)
+                else:
+                    equivs[x] = util.Set([y])
+                    
+        class BinaryVisitor(ClauseVisitor):
+            def visit_binary(self, binary):
+                if binary.operator == '=':
+                    add_equiv(binary.left, binary.right)
+        BinaryVisitor().traverse(self.onclause)
+        
+        for col in pkcol:
+            for fk in col.foreign_keys:
+                if fk.column in pkcol:
+                    add_equiv(col, fk.column)
+                    
+        omit = util.Set()
+        for col in pkcol:
+            p = col
+            for c in equivs.get(col, util.Set()):
+                if p.references(c) or (c.primary_key and not p.primary_key):
+                    omit.add(p)
+                    p = c
+            
+        self.__primary_key = ColumnSet([c for c in self._flatten_exportable_columns() if c.primary_key and c not in omit])
+
+    primary_key = property(lambda s:s.__primary_key)
         
     def _locate_oid_column(self):
         return self.left.oid_column
@@ -2185,7 +2228,11 @@ class Join(FromClause):
                 collist.append(c)
         self.__folded_equivalents = collist
         return self.__folded_equivalents
-        
+
+    folded_equivalents = property(_get_folded_equivalents, doc="Returns the column list of this Join with all equivalently-named, "
+                                                            "equated columns folded into one column, where 'equated' means they are "
+                                                            "equated to each other in the ON clause of this join.")    
+    
     def select(self, whereclause = None, fold_equivalents=False, **kwargs):
         """Create a ``Select`` from this ``Join``.
         
@@ -2205,13 +2252,13 @@ class Join(FromClause):
           
         """
         if fold_equivalents:
-            collist = self._get_folded_equivalents()
+            collist = self.folded_equivalents
         else:
             collist = [self.left, self.right]
             
         return select(collist, whereclause, from_obj=[self], **kwargs)
 
-    engine = property(lambda s:s.left.engine or s.right.engine)
+    bind = property(lambda s:s.left.bind or s.right.bind)
 
     def alias(self, name=None):
         """Create a ``Select`` out of this ``Join`` clause and return an ``Alias`` of it.
@@ -2299,7 +2346,7 @@ class Alias(FromClause):
     def _group_parenthesized(self):
         return False
 
-    engine = property(lambda s: s.selectable.engine)
+    bind = property(lambda s: s.selectable.bind)
 
 class _Grouping(ColumnElement):
     def __init__(self, elem):
@@ -2492,12 +2539,8 @@ class TableClause(FromClause):
         super(TableClause, self).__init__(name)
         self.name = self.fullname = name
         self.encodedname = self.name.encode('ascii', 'backslashreplace')
-        self._columns = ColumnCollection()
-        self._foreign_keys = util.OrderedSet()
-        self._primary_key = ColumnCollection()
-        for c in columns:
-            self.append_column(c)
         self._oid_column = _ColumnClause('oid', self, _is_oid=True)
+        self._export_columns(columns)
 
     def _clone(self):
         # TableClause is immutable
@@ -2513,6 +2556,10 @@ class TableClause(FromClause):
     def _locate_oid_column(self):
         return self._oid_column
 
+    def _proxy_column(self, c):
+        self.append_column(c)
+        return c
+
     def _orig_columns(self):
         try:
             return self._orig_cols
@@ -2530,7 +2577,7 @@ class TableClause(FromClause):
             return [c for c in self.c]
         else:
             return []
-            
+
     def _exportable_columns(self):
         raise NotImplementedError()
 
@@ -2571,12 +2618,12 @@ class TableClause(FromClause):
 class _SelectBaseMixin(object):
     """Base class for ``Select`` and ``CompoundSelects``."""
 
-    def __init__(self, use_labels=False, for_update=False, limit=None, offset=None, order_by=None, group_by=None, connectable=None, scalar=False, engine=None):
+    def __init__(self, use_labels=False, for_update=False, limit=None, offset=None, order_by=None, group_by=None, bind=None, scalar=False):
         self.use_labels = use_labels
         self.for_update = for_update
         self._limit = limit
         self._offset = offset
-        self._engine = connectable or engine
+        self._bind = bind
         self.is_scalar = scalar
         if self.is_scalar:
             # allow corresponding_column to return None
@@ -3001,14 +3048,14 @@ class Select(_SelectBaseMixin, FromClause):
         object, or searched within the from clauses for one.
         """
 
-        if self._engine is not None:
-            return self._engine
+        if self._bind is not None:
+            return self._bind
         for f in self._froms:
             if f is self:
                 continue
-            e = f.engine
+            e = f.bind
             if e is not None:
-                self._engine = e
+                self._bind = e
                 return e
         # look through the columns (largely synomous with looking
         # through the FROMs except in the case of _CalculatedClause/_Function)
@@ -3016,9 +3063,9 @@ class Select(_SelectBaseMixin, FromClause):
             for c in cc.columns:
                 if getattr(c, 'table', None) is self:
                     continue
-                e = c.engine
+                e = c.bind
                 if e is not None:
-                    self._engine = e
+                    self._bind = e
                     return e
         return None
 
@@ -3078,7 +3125,7 @@ class _UpdateBase(ClauseElement):
         return parameters
 
     def _find_engine(self):
-        return self.table.engine
+        return self.table.bind
 
 class Insert(_UpdateBase):
     def __init__(self, table, values=None):
index 028baf8aaa7e46e77dc0180a877c02597e7d8376..e711de3a3e93e16debf4c00ea72ef97068c0d7e2 100644 (file)
@@ -10,6 +10,7 @@ except ImportError:
     import dummy_thread as thread
     import dummy_threading as threading
 
+from sqlalchemy import exceptions
 import md5
 import sys
 import warnings
@@ -159,6 +160,15 @@ def duck_type_collection(specimen, default=None):
     else:
         return default
 
+def assert_arg_type(arg, argtype, name):
+    if isinstance(arg, argtype):
+        return arg
+    else:
+        if isinstance(argtype, tuple):
+            raise exceptions.ArgumentError("Argument '%s' is expected to be one of type %s, got '%s'" % (name, ' or '.join(["'%s'" % str(a) for a in argtype]), str(type(arg))))
+        else:
+            raise exceptions.ArgumentError("Argument '%s' is expected to be of type '%s', got '%s'" % (name, str(argtype), str(type(arg))))
+
 def warn_exception(func, *args, **kwargs):
     """executes the given function, catches all exceptions and converts to a warning."""
     try:
index 722f06256d592503b56cf885c138a73da78f6372..a34a82ed75c7b092d4c17a26bf484e199eb9e480 100644 (file)
@@ -7,6 +7,7 @@ def suite():
         # connectivity, execution
            'engine.parseconnect',
         'engine.pool', 
+        'engine.bind',
         'engine.reconnect',
         'engine.execute',
         'engine.metadata',
diff --git a/test/engine/bind.py b/test/engine/bind.py
new file mode 100644 (file)
index 0000000..b928d3d
--- /dev/null
@@ -0,0 +1,167 @@
+"""tests the "bind" attribute/argument across schema, SQL, and ORM sessions,
+including the deprecated versions of these arguments"""
+
+import testbase
+import unittest, sys, datetime
+import tables
+db = testbase.db
+from sqlalchemy import *
+
+class BindTest(testbase.PersistTest):
+    def test_create_drop_explicit(self):
+        metadata = MetaData()
+        table = Table('test_table', metadata,   
+            Column('foo', Integer))
+        for bind in (
+            testbase.db,
+            testbase.db.connect()
+        ):
+            for args in [
+                ([], {'bind':bind}),
+                ([bind], {})
+            ]:
+                metadata.create_all(*args[0], **args[1])
+                assert table.exists(*args[0], **args[1])
+                metadata.drop_all(*args[0], **args[1])
+                table.create(*args[0], **args[1])
+                table.drop(*args[0], **args[1])
+                assert not table.exists(*args[0], **args[1])
+    
+    def test_create_drop_err(self):
+        metadata = MetaData()
+        table = Table('test_table', metadata,   
+            Column('foo', Integer))
+
+        for meth in [
+            metadata.create_all,
+            table.exists,
+            metadata.drop_all,
+            table.create,
+            table.drop,
+        ]:
+            try:
+                meth()
+                assert False
+            except exceptions.InvalidRequestError, e:
+                assert str(e)  == "This SchemaItem is not connected to any Engine or Connection."
+        
+    def test_create_drop_bound(self):
+        
+        for meta in (MetaData,ThreadLocalMetaData):
+            for bind in (
+                testbase.db,
+                testbase.db.connect()
+            ):
+                metadata = meta()
+                table = Table('test_table', metadata,   
+                Column('foo', Integer))
+                metadata.bind = bind
+                assert metadata.bind is table.bind is bind
+                metadata.create_all()
+                assert table.exists()
+                metadata.drop_all()
+                table.create()
+                table.drop()
+                assert not table.exists()
+
+                metadata = meta()
+                table = Table('test_table', metadata,   
+                    Column('foo', Integer))
+
+                metadata.connect(bind)
+                assert metadata.bind is table.bind is bind
+                metadata.create_all()
+                assert table.exists()
+                metadata.drop_all()
+                table.create()
+                table.drop()
+                assert not table.exists()
+
+    def test_create_drop_constructor_bound(self):
+        for bind in (
+            testbase.db,
+            testbase.db.connect()
+        ):
+            for args in (
+                ([bind], {}),
+                ([], {'bind':bind}),
+            ):
+                metadata = MetaData(*args[0], **args[1])
+                table = Table('test_table', metadata,   
+                    Column('foo', Integer))
+
+                assert metadata.bind is table.bind is bind
+                metadata.create_all()
+                assert table.exists()
+                metadata.drop_all()
+                table.create()
+                table.drop()
+                assert not table.exists()
+
+
+    def test_clauseelement(self):
+        metadata = MetaData()
+        table = Table('test_table', metadata,   
+            Column('foo', Integer))
+        metadata.create_all(bind=testbase.db)
+        try:
+            for elem in [
+                table.select,
+                lambda **kwargs:func.current_timestamp(**kwargs).select(),
+#                func.current_timestamp().select,
+                lambda **kwargs:text("select * from test_table", **kwargs)
+            ]:
+                for bind in (
+                    testbase.db,
+                    testbase.db.connect()
+                ):
+                    e = elem(bind=bind)
+                    assert e.bind is bind
+                    e.execute()
+
+                try:
+                    e = elem()
+                    assert e.bind is None
+                    e.execute()
+                    assert False
+                except exceptions.InvalidRequestError, e:
+                    assert str(e) == "This Compiled object is not bound to any Engine or Connection."
+                
+        finally:
+            metadata.drop_all(bind=testbase.db)
+    
+    def test_session(self):
+        from sqlalchemy.orm import create_session, mapper
+        metadata = MetaData()
+        table = Table('test_table', metadata,   
+            Column('foo', Integer, primary_key=True),
+            Column('data', String(30)))
+        class Foo(object):
+            pass
+        mapper(Foo, table)
+        metadata.create_all(bind=testbase.db)
+        try:
+            for bind in (testbase.db, testbase.db.connect()):
+                for args in ({'bind':bind},):
+                    sess = create_session(**args)
+                    assert sess.bind is bind
+                    f = Foo()
+                    sess.save(f)
+                    sess.flush()
+                    assert sess.get(Foo, f.foo) is f
+                    
+            sess = create_session()
+            f = Foo()
+            sess.save(f)
+            try:
+                sess.flush()
+                assert False
+            except exceptions.InvalidRequestError, e:
+                assert str(e).startswith("Could not locate any Engine or Connection bound to mapper")
+                
+        finally:
+            metadata.drop_all(bind=testbase.db)
+        
+               
+if __name__ == '__main__':
+    testbase.main()
\ No newline at end of file
index c3c6441ef180cdb233f73ac510269823c725acd2..28b0535a5761adcec7b00408516068e90a5a638a 100644 (file)
@@ -7,7 +7,7 @@ class MetaDataTest(testbase.PersistTest):
         metadata = MetaData()
         t1 = Table('table1', metadata, Column('col1', Integer, primary_key=True),
             Column('col2', String(20)))
-        metadata.engine = testbase.db
+        metadata.bind = testbase.db
         metadata.create_all()
         try:
             assert t1.count().scalar() == 0
index 672d1bcd7cffe2c31242e778584892b3ae21b110..842be682d615f085e6f4a320c8acfba658a71285 100644 (file)
@@ -237,7 +237,7 @@ class ReflectionTest(PersistTest):
             PRIMARY KEY(id)
         )""")
         try:
-            metadata = MetaData(engine=testbase.db)
+            metadata = MetaData(bind=testbase.db)
             book = Table('book', metadata, autoload=True)
             assert book.c.id  in book.primary_key
             assert book.c.series not in book.primary_key
@@ -258,7 +258,7 @@ class ReflectionTest(PersistTest):
             PRIMARY KEY(id, isbn)
         )""")
         try:
-            metadata = MetaData(engine=testbase.db)
+            metadata = MetaData(bind=testbase.db)
             book = Table('book', metadata, autoload=True)
             assert book.c.id  in book.primary_key
             assert book.c.isbn  in book.primary_key
@@ -363,17 +363,17 @@ class ReflectionTest(PersistTest):
         def test_pickle():
             meta.connect(testbase.db)
             meta2 = pickle.loads(pickle.dumps(meta))
-            assert meta2.engine is None
+            assert meta2.bind is None
             return (meta2.tables['mytable'], meta2.tables['othertable'])
 
         def test_pickle_via_reflect():
             # this is the most common use case, pickling the results of a
             # database reflection
-            meta2 = MetaData(engine=testbase.db)
+            meta2 = MetaData(bind=testbase.db)
             t1 = Table('mytable', meta2, autoload=True)
             t2 = Table('othertable', meta2, autoload=True)
             meta3 = pickle.loads(pickle.dumps(meta2))
-            assert meta3.engine is None
+            assert meta3.bind is None
             assert meta3.tables['mytable'] is not t1
             return (meta3.tables['mytable'], meta3.tables['othertable'])
             
index 1437cde1fbcdfd1d21d65708a8dc883e08cd2545..3f61ec36913ea05f0c98793b6883d9aa19b0df4b 100644 (file)
@@ -170,7 +170,7 @@ class FlushTest(testbase.ORMTest):
         )
 
         admins = Table('admin', metadata,
-            Column('id', Integer, primary_key=True),
+            Column('admin_id', Integer, primary_key=True),
             Column('user_id', Integer, ForeignKey('users.id'))
         )
             
@@ -237,6 +237,86 @@ class FlushTest(testbase.ORMTest):
         a.password = 'sadmin'
         sess.flush()
         assert user_roles.count().scalar() == 1
-        
+
+class DistinctPKTest(testbase.ORMTest):
+    """test the construction of mapper.primary_key when an inheriting relationship
+    joins on a column other than primary key column."""
+    keep_data = True
+
+    def define_tables(self, metadata):
+        global person_table, employee_table, Person, Employee
+
+        person_table = Table("persons", metadata,
+                Column("id", Integer, primary_key=True),
+                Column("name", String(80)),
+                )
+
+        employee_table = Table("employees", metadata,
+                Column("id", Integer, primary_key=True),
+                Column("salary", Integer),
+                Column("person_id", Integer, ForeignKey("persons.id")),
+                )
+
+        class Person(object):
+            def __init__(self, name):
+                self.name = name
+
+        class Employee(Person): pass
+
+    def insert_data(self):
+        person_insert = person_table.insert()
+        person_insert.execute(id=1, name='alice')
+        person_insert.execute(id=2, name='bob')
+
+        employee_insert = employee_table.insert()
+        employee_insert.execute(id=2, salary=250, person_id=1) # alice
+        employee_insert.execute(id=3, salary=200, person_id=2) # bob
+
+    def test_implicit(self):
+        person_mapper = mapper(Person, person_table)
+        mapper(Employee, employee_table, inherits=person_mapper)
+        try:
+            print class_mapper(Employee).primary_key
+            assert list(class_mapper(Employee).primary_key) == [person_table.c.id, employee_table.c.id]
+            assert False
+        except RuntimeWarning, e:
+            assert str(e) == "On mapper Mapper|Employee|employees, primary key column 'employees.id' is being combined with distinct primary key column 'persons.id' in attribute 'id'.  Use explicit properties to give each column its own mapped attribute name."
+
+    def test_explicit_props(self):
+        person_mapper = mapper(Person, person_table)
+        mapper(Employee, employee_table, inherits=person_mapper, properties={'pid':person_table.c.id, 'eid':employee_table.c.id})
+        self._do_test(True)
+
+    def test_explicit_composite_pk(self):
+        person_mapper = mapper(Person, person_table)
+        mapper(Employee, employee_table, inherits=person_mapper, primary_key=[person_table.c.id, employee_table.c.id])
+        try:
+            self._do_test(True)
+            assert False
+        except RuntimeWarning, e:
+            assert str(e) == "On mapper Mapper|Employee|employees, primary key column 'employees.id' is being combined with distinct primary key column 'persons.id' in attribute 'id'.  Use explicit properties to give each column its own mapped attribute name."
+
+    def test_explicit_pk(self):
+        person_mapper = mapper(Person, person_table)
+        mapper(Employee, employee_table, inherits=person_mapper, primary_key=[person_table.c.id])
+        self._do_test(False)
+
+    def _do_test(self, composite):
+        session = create_session()
+        query = session.query(Employee)
+
+        if composite:
+            alice1 = query.get([1,2])
+            bob = query.get([2,3])
+            alice2 = query.get([1,2])
+        else:
+            alice1 = query.get(1)
+            bob = query.get(2)
+            alice2 = query.get(1)
+
+            assert alice1.name == alice2.name == 'alice'
+            assert bob.name == 'bob'
+
+
 if __name__ == "__main__":    
     testbase.main()
index 26a0b80a44f4d69d67072fbecfe6303b814a762a..eb0d110a16118400a840bb40b3c2ee3f703b20d5 100644 (file)
@@ -150,6 +150,9 @@ class MapperTest(MapperSuperTest):
         def bad_expunge(foo):
             raise Exception("this exception should be stated as a warning")
 
+        import warnings
+        warnings.filterwarnings("always", r".*this exception should be stated as a warning")
+
         sess.expunge = bad_expunge
         try:
             Foo(_sa_session=sess)
@@ -660,7 +663,7 @@ class DeferredTest(MapperSuperTest):
             o2 = l[2]
             print o2.description
 
-        orderby = str(orders.default_order_by()[0].compile(engine=db))
+        orderby = str(orders.default_order_by()[0].compile(bind=db))
         self.assert_sql(db, go, [
             ("SELECT orders.order_id AS orders_order_id, orders.user_id AS orders_user_id, orders.isopen AS orders_isopen FROM orders ORDER BY %s" % orderby, {}),
             ("SELECT orders.description AS orders_description FROM orders WHERE orders.order_id = :orders_order_id", {'orders_order_id':3})
index f52d90c5d8f7042b51113581fbd9d6ae92b0f895..c0b0c8f84bfb04b2a87fe523816855cd15376d14 100644 (file)
@@ -54,7 +54,7 @@ class QueryTest(testbase.ORMTest):
     def define_tables(self, meta):
         # a slight dirty trick here. 
         meta.tables = metadata.tables
-        metadata.connect(meta.engine)
+        metadata.connect(meta.bind)
         
     def setup_mappers(self):
         mapper(User, users, properties={
index d95b7b4adf36baf3366842428892bcb0d99514d4..80fe147275ab5ef3bb2bbaf96c56e2507eeb9176 100644 (file)
@@ -561,19 +561,19 @@ class TypeMatchTest(testbase.ORMTest):
     def define_tables(self, metadata):
         global a, b, c, d
         a = Table("a", metadata, 
-            Column('id', Integer, primary_key=True),
+            Column('aid', Integer, primary_key=True),
             Column('data', String(30)))
         b = Table("b", metadata, 
-            Column('id', Integer, primary_key=True),
-            Column("a_id", Integer, ForeignKey("a.id")),
+            Column('bid', Integer, primary_key=True),
+            Column("a_id", Integer, ForeignKey("a.aid")),
             Column('data', String(30)))
         c = Table("c", metadata, 
-            Column('id', Integer, primary_key=True),
-            Column("b_id", Integer, ForeignKey("b.id")),
+            Column('cid', Integer, primary_key=True),
+            Column("b_id", Integer, ForeignKey("b.bid")),
             Column('data', String(30)))
         d = Table("d", metadata, 
-            Column('id', Integer, primary_key=True),
-            Column("a_id", Integer, ForeignKey("a.id")),
+            Column('did', Integer, primary_key=True),
+            Column("a_id", Integer, ForeignKey("a.aid")),
             Column('data', String(30)))
     def test_o2m_oncascade(self):
         class A(object):pass
index 2eeaef7cc15613572e217896fd2dd0275c116bc1..07363a402e67546810bc236199e74426f087f89e 100644 (file)
@@ -25,26 +25,26 @@ class DefaultTest(PersistTest):
  
         # select "count(1)" returns different results on different DBs
         # also correct for "current_date" compatible as column default, value differences
-        currenttime = func.current_date(type=Date, engine=db);
+        currenttime = func.current_date(type=Date, bind=db);
         if is_oracle:
             ts = db.func.trunc(func.sysdate(), literal_column("'DAY'")).scalar()
-            f = select([func.count(1) + 5], engine=db).scalar()
-            f2 = select([func.count(1) + 14], engine=db).scalar()
+            f = select([func.count(1) + 5], bind=db).scalar()
+            f2 = select([func.count(1) + 14], bind=db).scalar()
+            # TODO: engine propagation across nested functions not working
-            currenttime = func.trunc(currenttime, literal_column("'DAY'"), engine=db)
+            currenttime = func.trunc(currenttime, literal_column("'DAY'"), bind=db)
             def1 = currenttime
             def2 = func.trunc(text("sysdate"), literal_column("'DAY'"))
             deftype = Date
         elif use_function_defaults:
-            f = select([func.count(1) + 5], engine=db).scalar()
-            f2 = select([func.count(1) + 14], engine=db).scalar()
+            f = select([func.count(1) + 5], bind=db).scalar()
+            f2 = select([func.count(1) + 14], bind=db).scalar()
             def1 = currenttime
             def2 = text("current_date")
             deftype = Date
             ts = db.func.current_date().scalar()
         else:
-            f = select([func.count(1) + 5], engine=db).scalar()
-            f2 = select([func.count(1) + 14], engine=db).scalar()
+            f = select([func.count(1) + 5], bind=db).scalar()
+            f2 = select([func.count(1) + 14], bind=db).scalar()
             def1 = def2 = "3"
             ts = 3
             deftype = Integer
@@ -257,7 +257,7 @@ class SequenceTest(PersistTest):
    
     @testbase.supported('postgres', 'oracle')
     def test_implicit_sequence_exec(self):
-        s = Sequence("my_sequence", metadata=testbase.db)
+        s = Sequence("my_sequence", metadata=MetaData(testbase.db))
         s.create()
         try:
             x = s.execute()
@@ -266,9 +266,9 @@ class SequenceTest(PersistTest):
             s.drop()
 
     @testbase.supported('postgres', 'oracle')
-    def test_explicit_sequence_exec(self):
+    def teststandalone_explicit(self):
         s = Sequence("my_sequence")
-        s.create(testbase.db)
+        s.create(bind=testbase.db)
         try:
             x = s.execute(testbase.db)
             self.assert_(x == 1)
index 968e75dfc330d319aacdc83ce77d4ef747748dcb..384fead50b4fc4936aa59d61880a233d133143f9 100644 (file)
@@ -17,7 +17,7 @@ class LabelTypeTest(testbase.PersistTest):
 class LongLabelsTest(testbase.PersistTest):
     def setUpAll(self):
         global metadata, table1
-        metadata = MetaData(engine=testbase.db)
+        metadata = MetaData(testbase.db)
         table1 = Table("some_large_named_table", metadata,
             Column("this_is_the_primarykey_column", Integer, Sequence("this_is_some_large_seq"), primary_key=True),
             Column("this_is_the_data_column", String(30))
index 76e07881b0a2490ec49ce95bb7357f06f8885652..77cf91d434cfe8850b566eae737f7f8bf6d4d4a7 100644 (file)
@@ -235,7 +235,7 @@ class QueryTest(PersistTest):
         self.assert_(r.user_id == r['user_id'] == r[users.c.user_id] == 2)
         self.assert_(r.user_name == r['user_name'] == r[users.c.user_name] == 'jack')
 
-        r = text("select * from query_users where user_id=2", engine=testbase.db).execute().fetchone()
+        r = text("select * from query_users where user_id=2", bind=testbase.db).execute().fetchone()
         self.assert_(r.user_id == r['user_id'] == r[users.c.user_id] == 2)
         self.assert_(r.user_name == r['user_name'] == r[users.c.user_name] == 'jack')
         
index 50b2fa6b4a3685c16f27e18eac2e40ea6bf34d2c..bcf70bd2062baefe4749d7cf6386fb557f39a258 100755 (executable)
-"""tests that various From objects properly export their columns, as well as
--useable primary keys and foreign keys.  Full relational algebra depends on
--every selectable unit behaving nicely with others.."""
-import testbase
-import unittest, sys, datetime
-from sqlalchemy import *
-from testbase import Table, Column
-
-db = testbase.db
-metadata = MetaData(db)
-
-
-table = Table('table1', metadata, 
-    Column('col1', Integer, primary_key=True),
-    Column('col2', String(20)),
-    Column('col3', Integer),
-    Column('colx', Integer),
-    
-)
-
-table2 = Table('table2', metadata,
-    Column('col1', Integer, primary_key=True),
-    Column('col2', Integer, ForeignKey('table1.col1')),
-    Column('col3', String(20)),
-    Column('coly', Integer),
-)
-
-class SelectableTest(testbase.AssertMixin):
-    def testdistance(self):
-        s = select([table.c.col1.label('c2'), table.c.col1, table.c.col1.label('c1')])
-
-        # didnt do this yet...col.label().make_proxy() has same "distance" as col.make_proxy() so far
-        #assert s.corresponding_column(table.c.col1) is s.c.col1
-        assert s.corresponding_column(s.c.col1) is s.c.col1
-        assert s.corresponding_column(s.c.c1) is s.c.c1
-        
-    def testjoinagainstself(self):
-        jj = select([table.c.col1.label('bar_col1')])
-        jjj = join(table, jj, table.c.col1==jj.c.bar_col1)
-        
-        # test column directly agaisnt itself
-        assert jjj.corresponding_column(jjj.c.table1_col1) is jjj.c.table1_col1
-
-        assert jjj.corresponding_column(jj.c.bar_col1) is jjj.c.bar_col1
-        
-        # test alias of the join, targets the column with the least 
-        # "distance" between the requested column and the returned column
-        # (i.e. there is less indirection between j2.c.table1_col1 and table.c.col1, than
-        # there is from j2.c.bar_col1 to table.c.col1)
-        j2 = jjj.alias('foo')
-        assert j2.corresponding_column(table.c.col1) is j2.c.table1_col1
-        
-
-    def testjoinagainstjoin(self):
-        j  = outerjoin(table, table2, table.c.col1==table2.c.col2)
-        jj = select([ table.c.col1.label('bar_col1')],from_obj=[j]).alias('foo')
-        jjj = join(table, jj, table.c.col1==jj.c.bar_col1)
-        assert jjj.corresponding_column(jjj.c.table1_col1) is jjj.c.table1_col1
-        
-        j2 = jjj.alias('foo')
-        print j2.corresponding_column(jjj.c.table1_col1)
-        assert j2.corresponding_column(jjj.c.table1_col1) is j2.c.table1_col1
-        
-        assert jjj.corresponding_column(jj.c.bar_col1) is jj.c.bar_col1
-        
-    def testtablealias(self):
-        a = table.alias('a')
-        
-        j = join(a, table2)
-        
-        criterion = a.c.col1 == table2.c.col2
-        print
-        print str(j)
-        self.assert_(criterion.compare(j.onclause))
-
-    def testunion(self):
-        # tests that we can correspond a column in a Select statement with a certain Table, against
-        # a column in a Union where one of its underlying Selects matches to that same Table
-        u = select([table.c.col1, table.c.col2, table.c.col3, table.c.colx, null().label('coly')]).union(
-                select([table2.c.col1, table2.c.col2, table2.c.col3, null().label('colx'), table2.c.coly])
-            )
-        s1 = table.select(use_labels=True)
-        s2 = table2.select(use_labels=True)
-        print ["%d %s" % (id(c),c.key) for c in u.c]
-        c = u.corresponding_column(s1.c.table1_col2)
-        print "%d %s" % (id(c), c.key)
-        print id(u.corresponding_column(s1.c.table1_col2).table)
-        print id(u.c.col2.table)
-        assert u.corresponding_column(s1.c.table1_col2) is u.c.col2
-        assert u.corresponding_column(s2.c.table2_col2) is u.c.col2
-
-    def testaliasunion(self):
-        # same as testunion, except its an alias of the union
-        u = select([table.c.col1, table.c.col2, table.c.col3, table.c.colx, null().label('coly')]).union(
-                select([table2.c.col1, table2.c.col2, table2.c.col3, null().label('colx'), table2.c.coly])
-            ).alias('analias')
-        s1 = table.select(use_labels=True)
-        s2 = table2.select(use_labels=True)
-        assert u.corresponding_column(s1.c.table1_col2) is u.c.col2
-        assert u.corresponding_column(s2.c.table2_col2) is u.c.col2
-        assert u.corresponding_column(s2.c.table2_coly) is u.c.coly
-        assert s2.corresponding_column(u.c.coly) is s2.c.table2_coly
-
-    def testselectunion(self):
-        # like testaliasunion, but off a Select off the union.
-        u = select([table.c.col1, table.c.col2, table.c.col3, table.c.colx, null().label('coly')]).union(
-                select([table2.c.col1, table2.c.col2, table2.c.col3, null().label('colx'), table2.c.coly])
-            ).alias('analias')
-        s = select([u])
-        s1 = table.select(use_labels=True)
-        s2 = table2.select(use_labels=True)
-        assert s.corresponding_column(s1.c.table1_col2) is s.c.col2
-        assert s.corresponding_column(s2.c.table2_col2) is s.c.col2
-
-    def testunionagainstjoin(self):
-        # same as testunion, except its an alias of the union
-        u = select([table.c.col1, table.c.col2, table.c.col3, table.c.colx, null().label('coly')]).union(
-                select([table2.c.col1, table2.c.col2, table2.c.col3, null().label('colx'), table2.c.coly])
-            ).alias('analias')
-        j1 = table.join(table2)
-        assert u.corresponding_column(j1.c.table1_colx) is u.c.colx
-        assert j1.corresponding_column(u.c.colx) is j1.c.table1_colx
-        
-    def testjoin(self):
-        a = join(table, table2)
-        print str(a.select(use_labels=True))
-        b = table2.alias('b')
-        j = join(a, b)
-        print str(j)
-        criterion = a.c.table1_col1 == b.c.col2
-        self.assert_(criterion.compare(j.onclause))
-
-    def testselectalias(self):
-        a = table.select().alias('a')
-        print str(a.select())
-        j = join(a, table2)
-        
-        criterion = a.c.col1 == table2.c.col2
-        print criterion
-        print j.onclause
-        self.assert_(criterion.compare(j.onclause))
-
-    def testselectlabels(self):
-        a = table.select(use_labels=True)
-        print str(a.select())
-        j = join(a, table2)
-        
-        criterion = a.c.table1_col1 == table2.c.col2
-        print
-        print str(j)
-        self.assert_(criterion.compare(j.onclause))
-
-    def testcolumnlabels(self):
-        a = select([table.c.col1.label('acol1'), table.c.col2.label('acol2'), table.c.col3.label('acol3')])
-        print str(a)
-        print [c for c in a.columns]
-        print str(a.select())
-        j = join(a, table2)
-        criterion = a.c.acol1 == table2.c.col2
-        print str(j)
-        self.assert_(criterion.compare(j.onclause))
-        
-    def testselectaliaslabels(self):
-        a = table2.select(use_labels=True).alias('a')
-        print str(a.select())
-        j = join(a, table)
-        
-        criterion =  table.c.col1 == a.c.table2_col2
-        print str(criterion)
-        print str(j.onclause)
-        self.assert_(criterion.compare(j.onclause))
-        
-if __name__ == "__main__":
-    testbase.main()
-    
+"""tests that various From objects properly export their columns, as well as\r
+useable primary keys and foreign keys.  Full relational algebra depends on\r
+every selectable unit behaving nicely with others.."""\r
\r
+import testbase\r
+import unittest, sys, datetime\r
+from sqlalchemy import *\r
+from testbase import Table, Column\r
+\r
+db = testbase.db\r
+metadata = MetaData(db)\r
+\r
+\r
+table = Table('table1', metadata, \r
+    Column('col1', Integer, primary_key=True),\r
+    Column('col2', String(20)),\r
+    Column('col3', Integer),\r
+    Column('colx', Integer),\r
+    \r
+)\r
+\r
+table2 = Table('table2', metadata,\r
+    Column('col1', Integer, primary_key=True),\r
+    Column('col2', Integer, ForeignKey('table1.col1')),\r
+    Column('col3', String(20)),\r
+    Column('coly', Integer),\r
+)\r
+\r
+class SelectableTest(testbase.AssertMixin):\r
+    def testdistance(self):\r
+        s = select([table.c.col1.label('c2'), table.c.col1, table.c.col1.label('c1')])\r
+\r
+        # didn't do this yet...col.label().make_proxy() has same "distance" as col.make_proxy() so far\r
+        #assert s.corresponding_column(table.c.col1) is s.c.col1\r
+        assert s.corresponding_column(s.c.col1) is s.c.col1\r
+        assert s.corresponding_column(s.c.c1) is s.c.c1\r
+        \r
+    def testjoinagainstself(self):\r
+        jj = select([table.c.col1.label('bar_col1')])\r
+        jjj = join(table, jj, table.c.col1==jj.c.bar_col1)\r
+        \r
+        # test column directly against itself\r
+        assert jjj.corresponding_column(jjj.c.table1_col1) is jjj.c.table1_col1\r
+\r
+        assert jjj.corresponding_column(jj.c.bar_col1) is jjj.c.bar_col1\r
+        \r
+        # test alias of the join, targets the column with the least \r
+        # "distance" between the requested column and the returned column\r
+        # (i.e. there is less indirection between j2.c.table1_col1 and table.c.col1, than\r
+        # there is from j2.c.bar_col1 to table.c.col1)\r
+        j2 = jjj.alias('foo')\r
+        assert j2.corresponding_column(table.c.col1) is j2.c.table1_col1\r
+        \r
+\r
+    def testjoinagainstjoin(self):\r
+        j  = outerjoin(table, table2, table.c.col1==table2.c.col2)\r
+        jj = select([ table.c.col1.label('bar_col1')],from_obj=[j]).alias('foo')\r
+        jjj = join(table, jj, table.c.col1==jj.c.bar_col1)\r
+        assert jjj.corresponding_column(jjj.c.table1_col1) is jjj.c.table1_col1\r
+        \r
+        j2 = jjj.alias('foo')\r
+        print j2.corresponding_column(jjj.c.table1_col1)\r
+        assert j2.corresponding_column(jjj.c.table1_col1) is j2.c.table1_col1\r
+        \r
+        assert jjj.corresponding_column(jj.c.bar_col1) is jj.c.bar_col1\r
+        \r
+    def testtablealias(self):\r
+        a = table.alias('a')\r
+        \r
+        j = join(a, table2)\r
+        \r
+        criterion = a.c.col1 == table2.c.col2\r
+        print\r
+        print str(j)\r
+        self.assert_(criterion.compare(j.onclause))\r
+\r
+    def testunion(self):\r
+        # tests that we can correspond a column in a Select statement with a certain Table, against\r
+        # a column in a Union where one of its underlying Selects matches to that same Table\r
+        u = select([table.c.col1, table.c.col2, table.c.col3, table.c.colx, null().label('coly')]).union(\r
+                select([table2.c.col1, table2.c.col2, table2.c.col3, null().label('colx'), table2.c.coly])\r
+            )\r
+        s1 = table.select(use_labels=True)\r
+        s2 = table2.select(use_labels=True)\r
+        print ["%d %s" % (id(c),c.key) for c in u.c]\r
+        c = u.corresponding_column(s1.c.table1_col2)\r
+        print "%d %s" % (id(c), c.key)\r
+        print id(u.corresponding_column(s1.c.table1_col2).table)\r
+        print id(u.c.col2.table)\r
+        assert u.corresponding_column(s1.c.table1_col2) is u.c.col2\r
+        assert u.corresponding_column(s2.c.table2_col2) is u.c.col2\r
+\r
+    def testaliasunion(self):\r
+        # same as testunion, except its an alias of the union\r
+        u = select([table.c.col1, table.c.col2, table.c.col3, table.c.colx, null().label('coly')]).union(\r
+                select([table2.c.col1, table2.c.col2, table2.c.col3, null().label('colx'), table2.c.coly])\r
+            ).alias('analias')\r
+        s1 = table.select(use_labels=True)\r
+        s2 = table2.select(use_labels=True)\r
+        assert u.corresponding_column(s1.c.table1_col2) is u.c.col2\r
+        assert u.corresponding_column(s2.c.table2_col2) is u.c.col2\r
+        assert u.corresponding_column(s2.c.table2_coly) is u.c.coly\r
+        assert s2.corresponding_column(u.c.coly) is s2.c.table2_coly\r
+\r
+    def testselectunion(self):\r
+        # like testaliasunion, but off a Select off the union.\r
+        u = select([table.c.col1, table.c.col2, table.c.col3, table.c.colx, null().label('coly')]).union(\r
+                select([table2.c.col1, table2.c.col2, table2.c.col3, null().label('colx'), table2.c.coly])\r
+            ).alias('analias')\r
+        s = select([u])\r
+        s1 = table.select(use_labels=True)\r
+        s2 = table2.select(use_labels=True)\r
+        assert s.corresponding_column(s1.c.table1_col2) is s.c.col2\r
+        assert s.corresponding_column(s2.c.table2_col2) is s.c.col2\r
+\r
+    def testunionagainstjoin(self):\r
+        # same as testunion, except its an alias of the union\r
+        u = select([table.c.col1, table.c.col2, table.c.col3, table.c.colx, null().label('coly')]).union(\r
+                select([table2.c.col1, table2.c.col2, table2.c.col3, null().label('colx'), table2.c.coly])\r
+            ).alias('analias')\r
+        j1 = table.join(table2)\r
+        assert u.corresponding_column(j1.c.table1_colx) is u.c.colx\r
+        assert j1.corresponding_column(u.c.colx) is j1.c.table1_colx\r
+        \r
+    def testjoin(self):\r
+        a = join(table, table2)\r
+        print str(a.select(use_labels=True))\r
+        b = table2.alias('b')\r
+        j = join(a, b)\r
+        print str(j)\r
+        criterion = a.c.table1_col1 == b.c.col2\r
+        self.assert_(criterion.compare(j.onclause))\r
+\r
+    def testselectalias(self):\r
+        a = table.select().alias('a')\r
+        print str(a.select())\r
+        j = join(a, table2)\r
+        \r
+        criterion = a.c.col1 == table2.c.col2\r
+        print criterion\r
+        print j.onclause\r
+        self.assert_(criterion.compare(j.onclause))\r
+\r
+    def testselectlabels(self):\r
+        a = table.select(use_labels=True)\r
+        print str(a.select())\r
+        j = join(a, table2)\r
+        \r
+        criterion = a.c.table1_col1 == table2.c.col2\r
+        print\r
+        print str(j)\r
+        self.assert_(criterion.compare(j.onclause))\r
+\r
+    def testcolumnlabels(self):\r
+        a = select([table.c.col1.label('acol1'), table.c.col2.label('acol2'), table.c.col3.label('acol3')])\r
+        print str(a)\r
+        print [c for c in a.columns]\r
+        print str(a.select())\r
+        j = join(a, table2)\r
+        criterion = a.c.acol1 == table2.c.col2\r
+        print str(j)\r
+        self.assert_(criterion.compare(j.onclause))\r
+        \r
+    def testselectaliaslabels(self):\r
+        a = table2.select(use_labels=True).alias('a')\r
+        print str(a.select())\r
+        j = join(a, table)\r
+        \r
+        criterion =  table.c.col1 == a.c.table2_col2\r
+        print str(criterion)\r
+        print str(j.onclause)\r
+        self.assert_(criterion.compare(j.onclause))\r
+        \r
+\r
+class PrimaryKeyTest(testbase.AssertMixin):\r
+    def test_join_pk_collapse_implicit(self):\r
+        """test that redundant columns in a join get 'collapsed' into a minimal primary key, \r
+        which is the root column along a chain of foreign key relationships."""\r
+        \r
+        meta = MetaData()\r
+        a = Table('a', meta, Column('id', Integer, primary_key=True))\r
+        b = Table('b', meta, Column('id', Integer, ForeignKey('a.id'), primary_key=True))\r
+        c = Table('c', meta, Column('id', Integer, ForeignKey('b.id'), primary_key=True))\r
+        d = Table('d', meta, Column('id', Integer, ForeignKey('c.id'), primary_key=True))\r
+\r
+        assert c.c.id.references(b.c.id)\r
+        assert not d.c.id.references(a.c.id)\r
+        \r
+        assert list(a.join(b).primary_key) == [a.c.id]\r
+        assert list(b.join(c).primary_key) == [b.c.id]\r
+        assert list(a.join(b).join(c).primary_key) == [a.c.id]\r
+        assert list(b.join(c).join(d).primary_key) == [b.c.id]\r
+        assert list(d.join(c).join(b).primary_key) == [b.c.id]\r
+        assert list(a.join(b).join(c).join(d).primary_key) == [a.c.id]\r
+\r
+    def test_join_pk_collapse_explicit(self):\r
+        """test that redundant columns in a join get 'collapsed' into a minimal primary key, \r
+        which is the root column along a chain of explicit join conditions."""\r
+\r
+        meta = MetaData()\r
+        a = Table('a', meta, Column('id', Integer, primary_key=True), Column('x', Integer))\r
+        b = Table('b', meta, Column('id', Integer, ForeignKey('a.id'), primary_key=True), Column('x', Integer))\r
+        c = Table('c', meta, Column('id', Integer, ForeignKey('b.id'), primary_key=True), Column('x', Integer))\r
+        d = Table('d', meta, Column('id', Integer, ForeignKey('c.id'), primary_key=True), Column('x', Integer))\r
+\r
+        print list(a.join(b, a.c.x==b.c.id).primary_key)\r
+        assert list(a.join(b, a.c.x==b.c.id).primary_key) == [b.c.id]\r
+        assert list(b.join(c, b.c.x==c.c.id).primary_key) == [b.c.id]\r
+        assert list(a.join(b).join(c, c.c.id==b.c.x).primary_key) == [a.c.id]\r
+        assert list(b.join(c, c.c.x==b.c.id).join(d).primary_key) == [c.c.id]\r
+        assert list(b.join(c, c.c.id==b.c.x).join(d).primary_key) == [b.c.id]\r
+        assert list(d.join(b, d.c.id==b.c.id).join(c, b.c.id==c.c.x).primary_key) == [c.c.id]\r
+        assert list(a.join(b).join(c, c.c.id==b.c.x).join(d).primary_key) == [a.c.id]\r
+        \r
+        assert list(a.join(b, and_(a.c.id==b.c.id, a.c.x==b.c.id)).primary_key) == [a.c.id]\r
+    \r
+    def test_init_doesnt_blowitaway(self):\r
+        meta = MetaData()\r
+        a = Table('a', meta, Column('id', Integer, primary_key=True), Column('x', Integer))\r
+        b = Table('b', meta, Column('id', Integer, ForeignKey('a.id'), primary_key=True), Column('x', Integer))\r
+\r
+        j = a.join(b)\r
+        assert list(j.primary_key) == [a.c.id]\r
+        \r
+        j.foreign_keys\r
+        assert list(j.primary_key) == [a.c.id]\r
+\r
+\r
+if __name__ == "__main__":\r
+    testbase.main()\r
+\r
index 8d5848d1cd241c777be4923d496b5dc7ebd2d27d..24fbde3a259ceb9a8552df59ac82232615925de7 100644 (file)
@@ -247,7 +247,7 @@ class BinaryTest(AssertMixin):
         
         for stmt in (
             binary_table.select(order_by=binary_table.c.primary_id),
-            text("select * from binary_table order by binary_table.primary_id", typemap={'pickled':PickleType}, engine=testbase.db)
+            text("select * from binary_table order by binary_table.primary_id", typemap={'pickled':PickleType}, bind=testbase.db)
         ):
             l = stmt.execute().fetchall()
             print type(stream1), type(l[0]['data']), type(l[0]['data_slice'])
index c426a258c7a7387d018b0c352dc9963603b5e963..f885dc56ba7b1460d0bb088930dd4320fe74715f 100644 (file)
@@ -10,7 +10,7 @@ from testbase import Table, Column
 class UnicodeSchemaTest(testbase.PersistTest):
     def setUpAll(self):
         global metadata, t1, t2
-        metadata = MetaData(engine=testbase.db)
+        metadata = MetaData(testbase.db)
         t1 = Table('unitable1', metadata,
             Column(u'méil', Integer, primary_key=True),
             Column(u'éXXm', Integer),
index 41eb38ddfc4fe8b2d6a4bcc226e1f4773e20b751..d1e901a2e55bace3b65f29a5abc24b26d84bec34 100644 (file)
@@ -10,6 +10,9 @@ import sqlalchemy
 from sqlalchemy import sql, schema, engine, pool, MetaData
 from sqlalchemy.orm import clear_mappers
 
+import warnings
+warnings.filterwarnings("error")
+
 db = None
 metadata = None
 db_uri = None
@@ -312,8 +315,11 @@ class ORMTest(AssertMixin):
         _otest_metadata = MetaData(db)
         self.define_tables(_otest_metadata)
         _otest_metadata.create_all()
+        self.insert_data()
     def define_tables(self, _otest_metadata):
         raise NotImplementedError()
+    def insert_data(self):
+        pass
     def get_metadata(self):
         return _otest_metadata
     def tearDownAll(self):