]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
r4695 merged to trunk; trunk now becomes 0.5.
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 9 May 2008 16:34:10 +0000 (16:34 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 9 May 2008 16:34:10 +0000 (16:34 +0000)
0.4 development continues at /sqlalchemy/branches/rel_0_4

164 files changed:
CHANGES
README_THIS_IS_NOW_VERSION_0.5.txt [new file with mode: 0644]
VERSION
attributes_rollback_test.py [new file with mode: 0644]
doc/build/content/intro.txt
doc/build/content/mappers.txt
doc/build/content/ormtutorial.txt
doc/build/content/session.txt
doc/build/content/sqlexpression.txt
doc/build/gen_docstrings.py
doc/build/genhtml.py
doc/build/testdocs.py
examples/custom_attributes/custom_management.py [new file with mode: 0644]
examples/dynamic_dict/dynamic_dict.py
lib/sqlalchemy/__init__.py
lib/sqlalchemy/databases/access.py
lib/sqlalchemy/databases/firebird.py
lib/sqlalchemy/databases/information_schema.py
lib/sqlalchemy/databases/informix.py
lib/sqlalchemy/databases/maxdb.py
lib/sqlalchemy/databases/mssql.py
lib/sqlalchemy/databases/mxODBC.py
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/databases/sqlite.py
lib/sqlalchemy/databases/sybase.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/engine/strategies.py
lib/sqlalchemy/engine/threadlocal.py
lib/sqlalchemy/engine/url.py
lib/sqlalchemy/exc.py [moved from lib/sqlalchemy/exceptions.py with 75% similarity]
lib/sqlalchemy/ext/activemapper.py [deleted file]
lib/sqlalchemy/ext/assignmapper.py [deleted file]
lib/sqlalchemy/ext/associationproxy.py
lib/sqlalchemy/ext/declarative.py
lib/sqlalchemy/ext/orderinglist.py
lib/sqlalchemy/ext/selectresults.py [deleted file]
lib/sqlalchemy/ext/sessioncontext.py [deleted file]
lib/sqlalchemy/ext/sqlsoup.py
lib/sqlalchemy/interfaces.py
lib/sqlalchemy/log.py [moved from lib/sqlalchemy/logging.py with 90% similarity]
lib/sqlalchemy/mods/__init__.py [deleted file]
lib/sqlalchemy/mods/selectresults.py [deleted file]
lib/sqlalchemy/orm/__init__.py
lib/sqlalchemy/orm/attributes.py
lib/sqlalchemy/orm/collections.py
lib/sqlalchemy/orm/dependency.py
lib/sqlalchemy/orm/dynamic.py
lib/sqlalchemy/orm/exc.py [new file with mode: 0644]
lib/sqlalchemy/orm/identity.py [new file with mode: 0644]
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/scoping.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/shard.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/sync.py
lib/sqlalchemy/orm/unitofwork.py
lib/sqlalchemy/orm/uowdumper.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/pool.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql/__init__.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/operators.py
lib/sqlalchemy/sql/util.py
lib/sqlalchemy/sql/visitors.py
lib/sqlalchemy/topological.py
lib/sqlalchemy/types.py
lib/sqlalchemy/util.py
test/base/dependency.py
test/base/except.py
test/base/utils.py
test/dialect/firebird.py
test/dialect/maxdb.py
test/dialect/mssql.py
test/dialect/mysql.py
test/dialect/oracle.py
test/dialect/postgres.py
test/dialect/sqlite.py
test/engine/bind.py
test/engine/ddlevents.py
test/engine/execute.py
test/engine/metadata.py
test/engine/parseconnect.py
test/engine/pool.py
test/engine/reconnect.py
test/engine/reflection.py
test/engine/transaction.py
test/ext/activemapper.py [deleted file]
test/ext/alltests.py
test/ext/assignmapper.py [deleted file]
test/ext/declarative.py
test/orm/alltests.py
test/orm/association.py
test/orm/assorted_eager.py
test/orm/attributes.py
test/orm/cascade.py
test/orm/collection.py
test/orm/compile.py
test/orm/cycles.py
test/orm/deprecations.py [new file with mode: 0644]
test/orm/dynamic.py
test/orm/eager_relations.py
test/orm/entity.py
test/orm/expire.py
test/orm/extendedattr.py [new file with mode: 0644]
test/orm/generative.py
test/orm/inheritance/abc_inheritance.py
test/orm/inheritance/abc_polymorphic.py
test/orm/inheritance/basic.py
test/orm/inheritance/concrete.py
test/orm/inheritance/poly_linked_list.py
test/orm/inheritance/polymorph.py
test/orm/inheritance/polymorph2.py
test/orm/inheritance/query.py
test/orm/inheritance/single.py
test/orm/instrumentation.py [new file with mode: 0644]
test/orm/lazy_relations.py
test/orm/manytomany.py
test/orm/mapper.py
test/orm/merge.py
test/orm/naturalpks.py
test/orm/onetoone.py
test/orm/pickled.py
test/orm/query.py
test/orm/relationships.py
test/orm/scoping.py [new file with mode: 0644]
test/orm/selectable.py
test/orm/session.py
test/orm/sessioncontext.py [deleted file]
test/orm/sharding/shard.py
test/orm/transaction.py [new file with mode: 0644]
test/orm/unitofwork.py
test/orm/utils.py [new file with mode: 0644]
test/perf/masseagerload.py
test/profiling/compiler.py
test/profiling/zoomark.py
test/sql/case_statement.py
test/sql/columns.py
test/sql/constraints.py
test/sql/defaults.py
test/sql/functions.py
test/sql/generative.py
test/sql/query.py
test/sql/quote.py
test/sql/select.py
test/sql/selectable.py
test/sql/testtypes.py
test/testlib/__init__.py
test/testlib/compat.py
test/testlib/engines.py
test/testlib/filters.py
test/testlib/fixtures.py
test/testlib/profiling.py
test/testlib/requires.py [new file with mode: 0644]
test/testlib/schema.py
test/testlib/tables.py
test/testlib/testing.py

diff --git a/CHANGES b/CHANGES
index 35d53ab61815ae534f581315052f51dfcf63c0a1..894be6116d8522fbb93a317f18a3bb91c9e82682 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -1,11 +1,76 @@
+-*- coding: utf-8; fill-column: 68 -*-
+
 =======
 CHANGES
 =======
 
+user_defined_state
+==================
+
+    - The "__init__" trigger/decorator added by mapper now attempts
+      to exactly mirror the argument signature of the original
+      __init__.  The pass-through for '_sa_session' is no longer
+      implicit- you must allow for this keyword argument in your
+      constructor.
+
+    - ClassState is renamed to ClassManager.
+
+    - Classes may supply their own InstrumentationManager by
+      providing a __sa_instrumentation_manager__ property.
+
+    - Custom instrumentation may use any mechanism to associate a
+      ClassManager with a class and an InstanceState with an
+      instance.  Attributes on those objects are still the default
+      association mechanism used by SQLAlchemy's native
+      instrumentation.
+
+    - Moved entity_name, _sa_session_id, and _instance_key from the
+      instance object to the instance state.  These values are still
+      available in the old way, which is now deprecated, using
+      descriptors attached to the class.  A deprecation warning will
+      be issued when accessed.
+
+    - attribute savepoint/rollback capability has been added.  For
+      starters, this takes effect within the flush() call, so that
+      attribute changes which occur within flush() are rolled back
+      when the flush fails.  Since it's primarily new primary key
+      values that get assigned within flush(), expiring those
+      attributes is not an option.  The next place we might use
+      savepoints is within SAVEPOINT transactions, since rolling
+      back to a savepoint is a transaction-contained operation.
+
+    - The _prepare_instrumentation alias for prepare_instrumentation
+      has been removed.
+
+    - sqlalchemy.exceptions has been renamed to sqlalchemy.exc.  The
+      module may be imported under either name.
+
+    - ORM-related exceptions are now defined in sqlalchemy.orm.exc.
+      ConcurrentModificationError, FlushError, and
+      UnmappedColumnError compatibility aliases are installed in
+      sqlalchemy.exc during the import of sqlalchemy.orm.
+
+    - sqlalchemy.logging has been renamed to sqlalchemy.log.
+
+    - The transitional sqlalchemy.log.SADeprecationWarning alias for
+      the warning's definition in sqlalchemy.exc has been removed.
+
+    - exc.AssertionError has been removed and usage replaced with
+      Python's built-in AssertionError.
+
+    - The behavior of MapperExtensions attached to multiple,
+      entity_name= primary mappers for a single class has been
+      altered.  The first mapper() defined for a class is the only
+      mapper eligible for the MapperExtension 'instrument_class',
+      'init_instance' and 'init_failed' events.  This is backwards
+      incompatible; previously the extensions of last mapper defined
+      would receive these events.
+
+
 0.4.6
 =====
 - orm
-    - A fix to the recent relation() refactoring which fixes
+    - Fix to the recent relation() refactoring which fixes
       exotic viewonly relations which join between local and
       remote table multiple times, with a common column shared
       between the joins.
@@ -13,12 +78,7 @@ CHANGES
     - Also re-established viewonly relation() configurations
       that join across multiple tables.
 
-    - contains_eager(), the hot function of the week, suppresses
-      the eager loader's own generation of the LEFT OUTER JOIN,
-      so that it is reasonable to use any Query, not just those
-      which use from_statement().  
-      
-    - Added an experimental relation() flag to help with
+    - Added experimental relation() flag to help with
       primaryjoins across functions, etc.,
       _local_remote_pairs=[tuples].  This complements a complex
       primaryjoin condition allowing you to provide the
@@ -32,44 +92,39 @@ CHANGES
       Query.order_by() if clause adaption had taken place.
       [ticket:1027]
       
-    - Removed an ancient assertion that mapped selectables
-      require "alias names" - the mapper creates its own alias
-      now if none is present.  Though in this case you need to
-      use the class, not the mapped selectable, as the source of
+    - Removed ancient assertion that mapped selectables require
+      "alias names" - the mapper creates its own alias now if
+      none is present.  Though in this case you need to use the
+      class, not the mapped selectable, as the source of column
+      attributes - so a warning is still issued.
+
+    - fixes to the "exists" function involving inheritance (any(), has(),
+      ~contains()); the full target join will be rendered into the
+      EXISTS clause for relations that link to subclasses.
+      
+    - restored usage of append_result() extension method for primary 
+      query rows, when the extension is present and only a single-
+      entity result is being returned.
+    
+    - Also re-established viewonly relation() configurations that
+      join across multiple tables.
+      
+    - removed ancient assertion that mapped selectables require
+      "alias names" - the mapper creates its own alias now if
+      none is present.  Though in this case you need to use 
+      the class, not the mapped selectable, as the source of
       column attributes - so a warning is still issued.
-
-    - Fixes to the "exists" function involving inheritance
-      (any(), has(), ~contains()); the full target join will be
-      rendered into the EXISTS clause for relations that link to
-      subclasses.
-
-    - Restored usage of append_result() extension method for
-      primary query rows, when the extension is present and only
-      a single- entity result is being returned.
-
-    - Fixed Class.collection==None for m2m relationships
-      [ticket:4213]
-
-    - Refined mapper._save_obj() which was unnecessarily calling
+      
+    - refined mapper._save_obj() which was unnecessarily calling
       __ne__() on scalar values during flush [ticket:1015]
-
-    - Added a feature to eager loading whereby subqueries set as
-      column_property() with explicit label names (which is not
-      necessary, btw) will have the label anonymized when the
-      instance is part of the eager join, to prevent conflicts
-      with a subquery or column of the same name on the parent
-      object.  [ticket:1019]
-
-    - Same as [ticket:1019] but repaired the non-labeled use
-      case [ticket:1022]
-
-    - Adjusted class-member inspection during attribute and
-      collection instrumentation that could be problematic when
-      integrating with other frameworks.
-
-    - Fixed duplicate append event emission on repeated
-      instrumented set.add() operations.
-
+      
+    - added a feature to eager loading whereby subqueries set
+      as column_property() with explicit label names (which is not
+      necessary, btw) will have the label anonymized when
+      the instance is part of the eager join, to prevent
+      conflicts with a subquery or column of the same name 
+      on the parent object.  [ticket:1019]
+      
     - set-based collections |=, -=, ^= and &= are stricter about
       their operands and only operate on sets, frozensets or
       subclasses of the collection type. Previously, they would
@@ -79,6 +134,16 @@ CHANGES
       a simple way to place dictionary behavior on top of 
       a dynamic_loader.
 
+- declarative extension
+    - Joined table inheritance mappers use a slightly relaxed
+      function to create the "inherit condition" to the parent
+      table, so that other foreign keys to not-yet-declared 
+      Table objects don't trigger an error.
+      
+    - fixed reentrant mapper compile hang when 
+      a declared attribute is used within ForeignKey, 
+      ie. ForeignKey(MyOtherClass.someattribute)
+      
 - sql
     - Added COLLATE support via the .collate(<collation>)
       expression operator and collate(<expr>, <collation>) sql
@@ -87,38 +152,29 @@ CHANGES
     - Fixed bug with union() when applied to non-Table connected
       select statements
 
-    - Improved behavior of text() expressions when used as FROM
-      clauses, such as select().select_from(text("sometext"))
+    - improved behavior of text() expressions when used as 
+      FROM clauses, such as select().select_from(text("sometext"))
       [ticket:1014]
 
-    - Column.copy() respects the value of "autoincrement", fixes
-      usage with Migrate [ticket:1021]
-
+    - Column.copy() respects the value of "autoincrement",
+      fixes usage with Migrate [ticket:1021]
+      
 - engines
     - Pool listeners can now be provided as a dictionary of
       callables or a (possibly partial) duck-type of
       PoolListener, your choice.
+      
+    - added "rollback_returned" option to Pool which will 
+      disable the rollback() issued when connections are 
+      returned.  This flag is only safe to use with a database
+      which does not support transactions (i.e. MySQL/MyISAM).
 
-    - Added "reset_on_return" option to Pool which will disable
-      the database state cleanup step (e.g. issuing a
-      rollback()) when connections are returned to the pool.
-
--extensions
+-ext
     - set-based association proxies |=, -=, ^= and &= are
       stricter about their operands and only operate on sets,
       frozensets or other association proxies. Previously, they
       would accept any duck-typed set.
 
-- declarative extension
-    - Joined table inheritance mappers use a slightly relaxed
-      function to create the "inherit condition" to the parent
-      table, so that other foreign keys to not-yet-declared
-      Table objects don't trigger an error.
-
-    - Fixed re-entrant mapper compile hang when a declared
-      attribute is used within ForeignKey,
-      i.e. ForeignKey(MyOtherClass.someattribute)
-
 - mssql
     - Added "odbc_autotranslate" parameter to engine / dburi
       parameters. Any given string will be passed through to the
@@ -135,11 +191,11 @@ CHANGES
       This should obviate the need of adding a myriad of ODBC
       options in the future.
 
+
 - firebird
     - Handle the "SUBSTRING(:string FROM :start FOR :length)"
       builtin.
 
-
 0.4.5
 =====
 - orm
@@ -175,81 +231,78 @@ CHANGES
     - Added comparable_property(), adds query Comparator
       behavior to regular, unmanaged Python properties
 
-    - The functionality of query.with_polymorphic() has been
-      added to mapper() as a configuration option.
+    - the functionality of query.with_polymorphic() has 
+      been added to mapper() as a configuration option.  
 
       It's set via several forms:
-
             with_polymorphic='*'
             with_polymorphic=[mappers]
             with_polymorphic=('*', selectable)
             with_polymorphic=([mappers], selectable)
-
-      This controls the default polymorphic loading strategy for
-      inherited mappers. When a selectable is not given, outer
-      joins are created for all joined-table inheriting mappers
-      requested. Note that the auto-create of joins is not
-      compatible with concrete table inheritance.
-
-      The existing select_table flag on mapper() is now
-      deprecated and is synonymous with:
-
-        with_polymorphic('*', select_table).
-
-      Note that the underlying "guts" of select_table have been
-      completely removed and replaced with the newer, more
-      flexible approach.
-
-      The new approach also automatically allows eager loads to
-      work for subclasses, if they are present, for example
-
+    
+      This controls the default polymorphic loading strategy
+      for inherited mappers. When a selectable is not given,
+      outer joins are created for all joined-table inheriting
+      mappers requested. Note that the auto-create of joins
+      is not compatible with concrete table inheritance.
+
+      The existing select_table flag on mapper() is now 
+      deprecated and is synonymous with 
+      with_polymorphic('*', select_table).  Note that the
+      underlying "guts" of select_table have been 
+      completely removed and replaced with the newer,
+      more flexible approach.  
+      
+      The new approach also automatically allows eager loads
+      to work for subclasses, if they are present, for
+      example
         sess.query(Company).options(
          eagerload_all(
           [Company.employees.of_type(Engineer), 'machines']
         ))
-
       to load Company objects, their employees, and the
       'machines' collection of employees who happen to be
       Engineers. A "with_polymorphic" Query option should be
       introduced soon as well which would allow per-Query
       control of with_polymorphic() on relations.
-
-    - Added two "experimental" features to Query, "experimental"
-      in that their specific name/behavior is not carved in
-      stone just yet: _values() and _from_self().  We'd like
-      feedback on these.
-
-      - _values(*columns) is given a list of column expressions,
-        and returns a new Query that only returns those
-        columns. When evaluated, the return value is a list of
-        tuples just like when using add_column() or
-        add_entity(), the only difference is that "entity zero",
-        i.e. the mapped class, is not included in the
-        results. This means it finally makes sense to use
-        group_by() and having() on Query, which have been
-        sitting around uselessly until now.
-
+    
+    - added two "experimental" features to Query, 
+      "experimental" in that their specific name/behavior
+      is not carved in stone just yet:  _values() and
+      _from_self().  We'd like feedback on these.
+      
+      - _values(*columns) is given a list of column
+        expressions, and returns a new Query that only
+        returns those columns. When evaluated, the return
+        value is a list of tuples just like when using
+        add_column() or add_entity(), the only difference is
+        that "entity zero", i.e. the mapped class, is not
+        included in the results. This means it finally makes
+        sense to use group_by() and having() on Query, which
+        have been sitting around uselessly until now.  
+        
         A future change to this method may include that its
         ability to join, filter and allow other options not
         related to a "resultset" are removed, so the feedback
         we're looking for is how people want to use
-        _values()...i.e. at the very end, or do people prefer to
-        continue generating after it's called.
-
-      - _from_self() compiles the SELECT statement for the Query
-        (minus any eager loaders), and returns a new Query that
-        selects from that SELECT. So basically you can query
-        from a Query without needing to extract the SELECT
-        statement manually. This gives meaning to operations
-        like query[3:5]._from_self().filter(some
-        criterion). There's not much controversial here except
-        that you can quickly create highly nested queries that
-        are less efficient, and we want feedback on the naming
-        choice.
-
-    - query.order_by() and query.group_by() will accept multiple
-      arguments using *args (like select() already does).
-
+        _values()...i.e. at the very end, or do people prefer
+        to continue generating after it's called.
+
+      - _from_self() compiles the SELECT statement for the
+        Query (minus any eager loaders), and returns a new
+        Query that selects from that SELECT. So basically you
+        can query from a Query without needing to extract the
+        SELECT statement manually. This gives meaning to
+        operations like query[3:5]._from_self().filter(some
+        criterion). There's not much controversial here
+        except that you can quickly create highly nested
+        queries that are less efficient, and we want feedback
+        on the naming choice.
+      
+    - query.order_by() and query.group_by() will accept
+      multiple arguments using *args (like select() 
+      already does).
+      
     - Added some convenience descriptors to Query:
       query.statement returns the full SELECT construct,
       query.whereclause returns just the WHERE part of the
@@ -293,10 +346,9 @@ CHANGES
         - Delete cascade with delete-orphan will delete orphans
           whether or not it remains attached to its also-deleted
           parent.
-
-        - delete-orphan casacde is properly detected on
-          relations that are present on superclasses when using
-          inheritance.
+          
+        - delete-orphan casacde is properly detected on relations
+          that are present on superclasses when using inheritance.
 
     - Fixed order_by calculation in Query to properly alias
       mapper-config'ed order_by when using select_from()
@@ -309,16 +361,16 @@ CHANGES
       iterative to support deep object graphs.
 
 - sql
-    - Schema-qualified tables now will place the schemaname
+    - schema-qualified tables now will place the schemaname
       ahead of the tablename in all column expressions as well
       as when generating column labels.  This prevents cross-
       schema name collisions in all cases [ticket:999]
-
-    - Can now allow selects which correlate all FROM clauses and
-      have no FROM themselves.  These are typically used in a
-      scalar context, i.e. SELECT x, (SELECT x WHERE y) FROM
-      table.  Requires explicit correlate() call.
-
+      
+    - can now allow selects which correlate all FROM clauses
+      and have no FROM themselves.  These are typically
+      used in a scalar context, i.e. SELECT x, (SELECT x WHERE y)
+      FROM table.  Requires explicit correlate() call.
+      
     - 'name' is no longer a required constructor argument for
       Column().  It (and .key) may now be deferred until the
       column is added to a Table.
@@ -350,6 +402,24 @@ CHANGES
       SA will force explicit usage of either text() or
       literal().
 
+- oracle
+    - The "owner" keyword on Table is now deprecated, and is
+      exactly synonymous with the "schema" keyword.  Tables can
+      now be reflected with alternate "owner" attributes,
+      explicitly stated on the Table object or not using
+      "schema".
+
+    - All of the "magic" searching for synonyms, DBLINKs etc.
+      during table reflection are disabled by default unless you
+      specify "oracle_resolve_synonyms=True" on the Table
+      object.  Resolving synonyms necessarily leads to some
+      messy guessing which we'd rather leave off by default.
+      When the flag is set, tables and related tables will be
+      resolved against synonyms in all cases, meaning if a
+      synonym exists for a particular table, reflection will use
+      it when reflecting related tables.  This is stickier
+      behavior than before which is why it's off by default.
+
 - declarative extension
     - The "synonym" function is now directly usable with
       "declarative".  Pass in the decorated property using the
@@ -378,10 +448,10 @@ CHANGES
      - inheritance in declarative can be disabled when sending
        "inherits=None" to __mapper_args__.
 
-     - declarative_base() takes optional kwarg "mapper", which
-       is any callable/class/method that produces a mapper, such
-       as declarative_base(mapper=scopedsession.mapper).  This
-       property can also be set on individual declarative
+     - declarative_base() takes optional kwarg "mapper", which 
+       is any callable/class/method that produces a mapper,
+       such as declarative_base(mapper=scopedsession.mapper).
+       This property can also be set on individual declarative
        classes using the "__mapper_cls__" property.
 
 - postgres
@@ -408,18 +478,17 @@ CHANGES
       behavior than before which is why it's off by default.
 
 - mssql
-     - Reflected tables will now automatically load other tables
+     - Reflected tables will now automatically load other tables 
        which are referenced by Foreign keys in the auto-loaded
-       table, [ticket:979].
+       table, [ticket:979]. 
 
-     - Added executemany check to skip identity fetch,
-       [ticket:916].
+     - Added executemany check to skip identity fetch, [ticket:916].
 
      - Added stubs for small date type, [ticket:884]
 
-     - Added a new 'driver' keyword parameter for the pyodbc
-       dialect.  Will substitute into the ODBC connection string
-       if given, defaults to 'SQL Server'.
+     - Added a new 'driver' keyword parameter for the pyodbc dialect.
+       Will substitute into the ODBC connection string if given,
+       defaults to 'SQL Server'.
 
      - Added a new 'max_identifier_length' keyword parameter for
        the pyodbc dialect.
diff --git a/README_THIS_IS_NOW_VERSION_0.5.txt b/README_THIS_IS_NOW_VERSION_0.5.txt
new file mode 100644 (file)
index 0000000..4586e32
--- /dev/null
@@ -0,0 +1,14 @@
+Trunk of SQLAlchemy is now on the 0.5 version.  This version 
+removes many things which were deprecated in 0.4 and therefore 
+is not backwards compatible with all 0.4 appliactions.
+
+A work in progress describing the changes from 0.4 is at:
+
+    http://www.sqlalchemy.org/trac/wiki/05Migration
+
+To continue working with the current development revision of 
+version 0.4, switch this working copy to the 0.4 maintenance branch:
+
+    svn switch http://svn.sqlalchemy.org/sqlalchemy/branches/rel_0_4
+
+
diff --git a/VERSION b/VERSION
index ef52a648073dd38aebdd7505edb3ba36e8bfd230..f9a891bcab4d9b687eaff1d6954cd7d7c7143f9d 100644 (file)
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-0.4.6
+0.5.0alpha1
diff --git a/attributes_rollback_test.py b/attributes_rollback_test.py
new file mode 100644 (file)
index 0000000..ab0705a
--- /dev/null
@@ -0,0 +1,31 @@
+from sqlalchemy.orm import attributes
+class Foo(object):pass
+attributes.register_class(Foo)
+attributes.register_attribute(Foo, 'x', uselist=False, useobject=False, mutable_scalars=True, copy_function=lambda x:x.copy())
+
+f = Foo()
+f._foostate.set_savepoint()
+print f._foostate.get_history('x')
+
+f.x = {'1':15}
+
+
+print f._foostate.get_history('x')
+f._foostate.commit_all()
+
+print f._foostate.get_history('x')
+
+f.x['2'] = 40
+print f._foostate.get_history('x')
+
+f._foostate.rollback()
+
+print f._foostate.get_history('x')
+
+#import pdb
+#pdb.Pdb().break_here()
+
+print f.x
+f.x['2'] = 40
+print f._foostate.get_history('x')
+
index d6ded5bdf89a1266450fedce46eac1d53cd502be..c45577d808c961b90e2e8b069a853a0e35b1ac41 100644 (file)
@@ -64,103 +64,22 @@ SQLAlchemy is designed to operate with a [DB-API](http://www.python.org/doc/peps
 * SQLite:  [pysqlite](http://initd.org/tracker/pysqlite), [sqlite3](http://docs.python.org/lib/module-sqlite3.html) (included with Python 2.5 or greater)
 * MySQL:   [MySQLdb](http://sourceforge.net/projects/mysql-python)
 * Oracle:  [cx_Oracle](http://www.cxtools.net/default.aspx?nav=home)
-* MS-SQL:  [pyodbc](http://pyodbc.sourceforge.net/) (recommended), [adodbapi](http://adodbapi.sourceforge.net/)  or [pymssql](http://pymssql.sourceforge.net/)
+* MS-SQL, MSAccess:  [pyodbc](http://pyodbc.sourceforge.net/) (recommended), [adodbapi](http://adodbapi.sourceforge.net/)  or [pymssql](http://pymssql.sourceforge.net/)
 * Firebird:  [kinterbasdb](http://kinterbasdb.sourceforge.net/)
 * Informix:  [informixdb](http://informixdb.sourceforge.net/)
+* DB2/Informix IDS: [ibm-db](http://code.google.com/p/ibm-db/)
+* Sybase:   TODO
+* MAXDB:    TODO
 
 ### Checking the Installed SQLAlchemy Version
  
-This documentation covers SQLAlchemy version 0.4.  If you're working on a system that already has SQLAlchemy installed, check the version from your Python prompt like this:
+This documentation covers SQLAlchemy version 0.5.  If you're working on a system that already has SQLAlchemy installed, check the version from your Python prompt like this:
 
      {python}
      >>> import sqlalchemy
      >>> sqlalchemy.__version__ # doctest: +SKIP
-     0.4.0
+     0.5.0
 
-## 0.3 to 0.4 Migration {@name=migration}
+## 0.4 to 0.5 Migration {@name=migration}
 
-From version 0.3 to version 0.4 of SQLAlchemy, some conventions have changed.  Most of these conventions are available in the most recent releases of the 0.3 series starting with version 0.3.9, so that you can make a 0.3 application compatible with 0.4 in most cases.
-
-This section will detail only those things that have changed in a backwards-incompatible manner.  For a full overview of everything that's new and changed, see [WhatsNewIn04](http://www.sqlalchemy.org/trac/wiki/WhatsNewIn04).
-
-### ORM Package is now sqlalchemy.orm {@name=imports}
-
-All symbols related to the SQLAlchemy Object Relational Mapper, i.e. names like `mapper()`, `relation()`, `backref()`, `create_session()` `synonym()`, `eagerload()`, etc. are now only in the `sqlalchemy.orm` package, and **not** in `sqlalchemy`.  So if you were previously importing everything on an asterisk:
-
-    {python}
-    from sqlalchemy import *
-    
-You should now import separately from orm:
-
-    {python}
-    from sqlalchemy import *
-    from sqlalchemy.orm import *
-    
-Or more commonly, just pull in the names you'll need:
-
-    {python}
-    from sqlalchemy import create_engine, MetaData, Table, Column, types
-    from sqlalchemy.orm import mapper, relation, backref, create_session
-
-### BoundMetaData is now MetaData {@name=metadata}
-
-The `BoundMetaData` name is removed.  Now, you just use `MetaData`.  Additionally, the `engine` parameter/attribute is now called `bind`, and `connect()` is deprecated:
-
-    {python}
-    # plain metadata
-    meta = MetaData()
-    
-    # metadata bound to an engine
-    meta = MetaData(engine)
-    
-    # bind metadata to an engine later
-    meta.bind = engine
-    
-Additionally, `DynamicMetaData` is now known as `ThreadLocalMetaData`.
-
-### "Magic" Global MetaData removed {@name=global}
-
-There was an old way to specify `Table` objects using an implicit, global `MetaData` object.  To do this you'd omit the second positional argument, and specify `Table('tablename', Column(...))`.  This no longer exists in 0.4 and the second `MetaData` positional argument is required, i.e. `Table('tablename', meta, Column(...))`.
-
-### Some existing select() methods become generative {@name=generative}
-
-The methods `correlate()`, `order_by()`, and `group_by()` on the `select()` construct now return a **new** select object, and do not change the original one.  Additionally, the generative methods `where()`, `column()`, `distinct()`, and several others have been added:
-
-    {python}
-    s = table.select().order_by(table.c.id).where(table.c.x==7)
-    result = engine.execute(s)
-
-### collection_class behavior is changed {@name=collection}
-
-If you've been using the `collection_class` option on `mapper()`, the requirements for instrumented collections have changed.  For an overview, see [advdatamapping_relation_collections](rel:advdatamapping_relation_collections).
-
-### All "engine", "bind_to", "connectable" Keyword Arguments Changed to "bind" {@name=bind}
-
-This is for create/drop statements, sessions, SQL constructs, metadatas:
-
-    {python}
-    myengine = create_engine('sqlite://')
-
-    meta = MetaData(myengine)
-
-    meta2 = MetaData()
-    meta2.bind = myengine
-
-    session = create_session(bind=myengine)
-
-    statement = select([table], bind=myengine)
-    
-    meta.create_all(bind=myengine)
-    
-### All "type" Keyword Arguments Changed to "type_" {@name=type}
-
-This mostly applies to SQL constructs where you pass a type in:
-
-    {python}
-    s = select([mytable], mytable.c.x=bindparam(y, type_=DateTime))
-    
-    func.now(type_=DateTime)
-    
-### Mapper Extensions must return EXT_CONTINUE to continue execution to the next mapper
-
-If you extend the mapper, the methods in your mapper extension must return EXT_CONTINUE to continue executing additional mappers.
+Notes on what's changed from 0.4 to 0.5 is available on the SQLAlchemy wiki at [05Migration](http://www.sqlalchemy.org/trac/wiki/05Migration).
index fca2076bc7577e00f5fbdf6f6598fc21fa5d5065..1035b89dfcfbcf9ab222e52d5cc531de4a8c95c9 100644 (file)
@@ -1105,9 +1105,9 @@ When using `primaryjoin` and `secondaryjoin`, SQLAlchemy also needs to be aware
     {python}
     mapper(Address, addresses_table)
     mapper(User, users_table, properties={
-        'addresses' : relation(Address, 
-             primaryjoin=users_table.c.user_id==addresses_table.c.user_id,
-             foreign_keys=[addresses_table.c.user_id])
+        'addresses' : relation(Address, primaryjoin=
+                    users_table.c.user_id==addresses_table.c.user_id,
+                    foreign_keys=[addresses_table.c.user_id])
     })
 
 ##### Building Query-Enabled Properties {@name=properties}
@@ -1361,9 +1361,11 @@ or more simply just use `eagerload_all()`:
 
 There are two other loader strategies available, **dynamic loading** and **no loading**; these are described in [advdatamapping_relation_largecollections](rel:advdatamapping_relation_largecollections).
 
-##### Combining Eager Loads with Statement/Result Set Queries
+##### Routing Explicit Joins/Statements into Eagerly Loaded Collections {@name=containseager}
 
-When full statement or result-set loads are used with `Query`, SQLAlchemy does not affect the SQL query itself, and therefore has no way of tacking on its own `LEFT [OUTER] JOIN` conditions that are normally used to eager load relationships.  If the query being constructed is created in such a way that it returns rows not just from a parent table (or tables) but also returns rows from child tables, the result-set mapping can be notified as to which additional properties are contained within the result set.  This is done using the `contains_eager()` query option, which specifies the name of the relationship to be eagerly loaded.
+When full statement loads are used with `Query`, the user defined SQL is used verbatim and the `Query` does not play any role in generating it.  In this scenario, if eager loading is desired, the `Query` should be informed as to what collections should also be loaded from the result set.  Similarly, Queries which compile their statement in the usual way may also have user-defined joins built in which are synonymous with what eager loading would normally produce, and it improves performance to utilize those same JOINs for both purposes, instead of allowing the eager load mechanism to generate essentially the same JOIN redundantly.   Yet another use case for such a feature is a Query which returns instances with a filtered view of their collections loaded, in which case the default eager load mechanisms need to be bypassed.
+
+The single option `Query` provides to control this is the `contains_eager()` option, which specifies the path of a single relationship to be eagerly loaded.  Like all relation-oriented options, it takes a string or Python descriptor as an argument.  Below it's used with a `from_statement` load:
 
     {python}
     # mapping is the users->addresses mapping
@@ -1372,7 +1374,7 @@ When full statement or result-set loads are used with `Query`, SQLAlchemy does n
     })
     
     # define a query on USERS with an outer join to ADDRESSES
-    statement = users_table.outerjoin(addresses_table).select(use_labels=True)
+    statement = users_table.outerjoin(addresses_table).select().apply_labels()
     
     # construct a Query object which expects the "addresses" results 
     query = session.query(User).options(contains_eager('addresses'))
@@ -1380,36 +1382,48 @@ When full statement or result-set loads are used with `Query`, SQLAlchemy does n
     # get results normally
     r = query.from_statement(statement)
 
+It works just as well with an inline `Query.join()` or `Query.outerjoin()`:
+
+    {python}
+    session.query(User).outerjoin(User.addresses).options(contains_eager(User.addresses)).all()
+
 If the "eager" portion of the statement is "aliased", the `alias` keyword argument to `contains_eager()` may be used to indicate it.  This is a string alias name or reference to an actual `Alias` object:
 
     {python}
-    # use an alias of the addresses table
-    adalias = addresses_table.alias('adalias')
+    # use an alias of the Address entity
+    adalias = aliased(Address)
     
-    # define a query on USERS with an outer join to adalias
-    statement = users_table.outerjoin(adalias).select(use_labels=True)
-
     # construct a Query object which expects the "addresses" results 
-    query = session.query(User).options(contains_eager('addresses', alias=adalias))
+    query = session.query(User).outerjoin((adalias, User.addresses)).options(contains_eager(User.addresses, alias=adalias))
 
     # get results normally
-    {sql}r = query.from_statement(statement).all()
+    {sql}r = query.all()
     SELECT users.user_id AS users_user_id, users.user_name AS users_user_name, adalias.address_id AS adalias_address_id, 
     adalias.user_id AS adalias_user_id, adalias.email_address AS adalias_email_address, (...other columns...)
-    FROM users LEFT OUTER JOIN email_addresses AS adalias ON users.user_id = adalias.user_id
+    FROM users LEFT OUTER JOIN email_addresses AS email_addresses_1 ON users.user_id = email_addresses_1.user_id
+
+The path given as the argument to `contains_eager()` needs to be a full path from the starting entity.  For example if we were loading `Users->orders->Order->items->Item`, the string version would look like:
 
-In the case that the main table itself is also aliased, the `contains_alias()` option can be used:
+    {python}
+    query(User).options(contains_eager('orders', 'items'))
+    
+The descriptor version like:
+
+    {python}
+    query(User).options(contains_eager(User.orders, Order.items))
+    
+A variant on `contains_eager()` is the `contains_alias()` option, which is used in the rare case that the parent object is loaded from an alias within a user-defined SELECT statement:
 
     {python}
     # define an aliased UNION called 'ulist'
     statement = users.select(users.c.user_id==7).union(users.select(users.c.user_id>7)).alias('ulist')
 
     # add on an eager load of "addresses"
-    statement = statement.outerjoin(addresses).select(use_labels=True)
+    statement = statement.outerjoin(addresses).select().apply_labels()
     
     # create query, indicating "ulist" is an alias for the main table, "addresses" property should
     # be eager loaded
-    query = create_session().query(User).options(contains_alias('ulist'), contains_eager('addresses'))
+    query = session.query(User).options(contains_alias('ulist'), contains_eager('addresses'))
     
     # results
     r = query.from_statement(statement)
index 85ffa0ff93a6a9ae54621a2383c8673858257934..f784d57cf65109cb64619141e8a227bf0dd6dc35 100644 (file)
@@ -1,23 +1,20 @@
-[alpha_api]: javascript:alphaApi()
-[alpha_implementation]: javascript:alphaImplementation()
-
 Object Relational Tutorial {@name=datamapping}
 ============
 
-In this tutorial we will cover a basic SQLAlchemy object-relational mapping scenario, where we store and retrieve Python objects from a database representation.  The database schema will begin with one table, and will later develop into several.  The tutorial is in doctest format, meaning each `>>>` line represents something you can type at a Python command prompt, and the following text represents the expected return value.  The tutorial has no prerequisites.
+In this tutorial we will cover a basic SQLAlchemy object-relational mapping scenario, where we store and retrieve Python objects from a database representation.  The tutorial is in doctest format, meaning each `>>>` line represents something you can type at a Python command prompt, and the following text represents the expected return value.
 
 ## Version Check
 
-A quick check to verify that we are on at least **version 0.4** of SQLAlchemy:
+A quick check to verify that we are on at least **version 0.5** of SQLAlchemy:
 
     {python}
     >>> import sqlalchemy
     >>> sqlalchemy.__version__ # doctest:+SKIP
-    0.4.0
+    0.5.0
     
 ## Connecting
 
-For this tutorial we will use an in-memory-only SQLite database.   This is an easy way to test things without needing to have an actual database defined anywhere.  To connect we use `create_engine()`:
+For this tutorial we will use an in-memory-only SQLite database.  To connect we use `create_engine()`:
 
     {python}
     >>> from sqlalchemy import create_engine
@@ -27,21 +24,21 @@ The `echo` flag is a shortcut to setting up SQLAlchemy logging, which is accompl
     
 ## Define and Create a Table {@name=tables}
 
-Next we want to tell SQLAlchemy about our tables.  We will start with just a single table called `users`, which will store records for the end-users using our application (lets assume it's a website).  We define our tables all within a catalog called `MetaData`, using the `Table` construct, which resembles regular SQL CREATE TABLE syntax:
+Next we want to tell SQLAlchemy about our tables.  We will start with just a single table called `users`, which will store records for the end-users using our application (lets assume it's a website).  We define our tables within a catalog called `MetaData`, using the `Table` construct, which is used in a manner similar to SQL's CREATE TABLE syntax:
 
     {python}
     >>> from sqlalchemy import Table, Column, Integer, String, MetaData, ForeignKey    
     >>> metadata = MetaData()
     >>> users_table = Table('users', metadata,
     ...     Column('id', Integer, primary_key=True),
-    ...     Column('name', String(40)),
-    ...     Column('fullname', String(100)),
-    ...     Column('password', String(15))
+    ...     Column('name', String),
+    ...     Column('fullname', String),
+    ...     Column('password', String)
     ... )
 
-All about how to define `Table` objects, as well as how to create them from an existing database automatically, is described in [metadata](rel:metadata).
+All about how to define `Table` objects, as well as how to load their definition from an existing database (known as **reflection**), is described in [metadata](rel:metadata).
 
-Next, to tell the `MetaData` we'd actually like to create our `users_table` for real inside the SQLite database, we use `create_all()`, passing it the `engine` instance which points to our database.  This will check for the presence of a table first before creating, so it's safe to call multiple times:
+Next, we can issue CREATE TABLE statements derived from our table metadata, by calling `create_all()` and passing it the `engine` instance which points to our database.  This will check for the presence of a table first before creating, so it's safe to call multiple times:
 
     {python}
     {sql}>>> metadata.create_all(engine) # doctest:+ELLIPSIS,+NORMALIZE_WHITESPACE
@@ -49,19 +46,24 @@ Next, to tell the `MetaData` we'd actually like to create our `users_table` for
     {}
     CREATE TABLE users (
         id INTEGER NOT NULL, 
-        name VARCHAR(40)
-        fullname VARCHAR(100)
-        password VARCHAR(15)
+        name VARCHAR, 
+        fullname VARCHAR, 
+        password VARCHAR, 
         PRIMARY KEY (id)
     )
     {}
     COMMIT
 
-So now our database is created, our initial schema is present, and our SQLAlchemy application knows all about the tables and columns in the database; this information is to be re-used by the Object Relational Mapper, as we'll see now.
+Users familiar with the syntax of CREATE TABLE may notice that the VARCHAR columns were generated without a length; on SQLite, this is a valid datatype, but on most databases it's not allowed.  So if running this tutorial on a database such as Postgres or MySQL, and you wish to use SQLAlchemy to generate the tables, a "length" may be provided to the `String` type as below:
+
+    {python}
+    Column('name', String(50))
+    
+The length field on `String`, as well as similar precision/scale fields available on `Integer`, `Numeric`, etc. are not referenced by SQLAlchemy other than when creating tables.
+
 ## Define a Python Class to be Mapped {@name=mapping}
 
-So lets create a rudimentary `User` object to be mapped in the database.  This object will for starters have three attributes, `name`, `fullname` and `password`.  It only need subclass Python's built-in `object` class (i.e. it's a new style class).  We will give it a constructor so that it may conveniently be instantiated with its attributes at once, as well as a `__repr__` method so that we can get a nice string representation of it:
+While the `Table` object defines information about our database, it does not say anything about the definition or behavior of the business objects used by our application;  SQLAlchemy views this as a separate concern.  To correspond to our `users` table, let's create a rudimentary `User` class.  It only need subclass Python's built-in `object` class (i.e. it's a new style class):
 
     {python}
     >>> class User(object):
@@ -73,6 +75,8 @@ So lets create a rudimentary `User` object to be mapped in the database.  This o
     ...     def __repr__(self):
     ...        return "<User('%s','%s', '%s')>" % (self.name, self.fullname, self.password)
 
+The class has an `__init__()` and a `__repr__()` method for convenience.  These methods are both entirely optional, and can be of any form.  SQLAlchemy never calls `__init__()` directly.
+
 ## Setting up the Mapping
 
 With our `users_table` and `User` class, we now want to map the two together.  That's where the SQLAlchemy ORM package comes in.  We'll use the `mapper` function to create a **mapping** between `users_table` and `User`:
@@ -80,9 +84,9 @@ With our `users_table` and `User` class, we now want to map the two together.  T
     {python}
     >>> from sqlalchemy.orm import mapper
     >>> mapper(User, users_table) # doctest:+ELLIPSIS,+NORMALIZE_WHITESPACE
-    <sqlalchemy.orm.mapper.Mapper object at 0x...>
+    <Mapper at 0x...; User>
     
-The `mapper()` function creates a new `Mapper` object and stores it away for future reference.  It also **instruments** the attributes on our `User` class, corresponding to the `users_table` table.  The `id`, `name`, `fullname`, and `password` columns in our `users_table` are now instrumented upon our `User` class, meaning it will keep track of all changes to these attributes, and can save and load their values to/from the database.  Lets create our first user, 'Ed Jones', and ensure that the object has all three of these attributes:
+The `mapper()` function creates a new `Mapper` object and stores it away for future reference, associated with our class.  Let's now create and inspect a `User` object:
 
     {python}
     >>> ed_user = User('ed', 'Ed Jones', 'edspassword')
@@ -93,53 +97,87 @@ The `mapper()` function creates a new `Mapper` object and stores it away for fut
     >>> str(ed_user.id)
     'None'
     
-What was that last `id` attribute?  That was placed there by the `Mapper`, to track the value of the `id` column in the `users_table`.  Since our `User` doesn't exist in the database, its id is `None`.  When we save the object, it will get populated automatically with its new id.
+The `id` attribute, which while not defined by our `__init__()` method, exists due to the `id` column present within the `users_table` object.  By default, the `mapper` creates class attributes for all columns present within the `Table`.  These class attributes exist as Python descriptors, and define **instrumentation** for the mapped class.  The functionality of this instrumentation is very rich and includes the ability to track modifications and automatically load new data from the database when needed.
 
-## Too Verbose ?  There are alternatives
+Since we have not yet told SQLAlchemy to persist `Ed Jones` within the database, its id is `None`.  When we persist the object later, this attribute will be populated with a newly generated value.
 
-The full set of steps to map a class, which are to define a `Table`, define a class, and then define a `mapper()`, are fairly verbose and for simple cases may appear overly disjoint.   Most popular object relational products use the so-called "active record" approach, where the table definition and its class mapping are all defined at once.  With SQLAlchemy, there are two excellent alternatives to its usual configuration which provide this approach:
+## Creating Table, Class and Mapper All at Once Declaratively {@name=declarative}
 
-  * [Elixir](http://elixir.ematia.de/) is a "sister" product to SQLAlchemy, which is a full "declarative" layer built on top of SQLAlchemy.  It has existed almost as long as SA itself and defines a rich featureset on top of SA's normal configuration, adding many new capabilities such as plugins, automatic generation of table and column names based on configurations, and an intuitive system of defining relations.
-  * [declarative](rel:plugins_declarative) is a so-called "micro-declarative" plugin included with SQLAlchemy 0.4.4 and above.  In contrast to Elixir, it maintains the use of the same configurational constructs outlined in this tutorial, except it allows the `Column`, `relation()`, and other constructs to be defined "inline" with the mapped class itself, so that explicit calls to `Table` and `mapper()` are not needed in most cases.
+The preceding approach to configuration involving a `Table`, user-defined class, and `mapper()` call illustrate classical SQLAlchemy usage, which values the highest separation of concerns possible.  A large number of applications don't require this degree of separation, and for those SQLAlchemy offers an alternate "shorthand" configurational style called **declarative**.  For many applications, this is the only style of configuration needed.  Our above example using this style is as follows:
 
-With either declarative layer it's a good idea to be familiar with SQLAlchemy's "base" configurational style in any case.  But now that we have our configuration started, we're ready to look at how to build sessions and query the database; this process is the same regardless of configurational style.
+    {python}
+    >>> from sqlalchemy.ext.declarative import declarative_base
+    
+    >>> Base = declarative_base()
+    >>> class User(Base):
+    ...     __tablename__ = 'users'
+    ...
+    ...     id = Column(Integer, primary_key=True)
+    ...     name = Column(String)
+    ...     fullname = Column(String)
+    ...     password = Column(String)
+    ...
+    ...     def __init__(self, name, fullname, password):
+    ...         self.name = name
+    ...         self.fullname = fullname
+    ...         self.password = password
+    ...
+    ...     def __repr__(self):
+    ...        return "<User('%s','%s', '%s')>" % (self.name, self.fullname, self.password)
 
-## Creating a Session
+Above, the `declarative_base()` function defines a new class which we name `Base`, from which all of our ORM-enabled classes will derive.  Note that we define `Column` objects with no "name" field, since it's inferred from the given attribute name.
+
+The underlying `Table` object created by our `declarative_base()` version of `User` is accessible via the `__table__` attribute:
 
-We're now ready to start talking to the database.  The ORM's "handle" to the database is the `Session`.  When we first set up the application, at the same level as our `create_engine()` statement, we define a second object called `Session` (or whatever you want to call it, `create_session`, etc.) which is configured by the `sessionmaker()` function.  This function is configurational and need only be called once.  
+    {python}
+    >>> users_table = User.__table__
     
+and the owning `MetaData` object is available as well:
+
+    {python}
+    >>> metadata = Base.metadata
+
+Yet another "declarative" method is available for SQLAlchemy as a third party library called [Elixir](http://elixir.ematia.de/).  This is a full-featured configurational product which also includes many higher level mapping configurations built in.  Like declarative, once classes and mappings are defined, ORM usage is the same as with a classical SQLAlchemy configuration.
+
+## Creating a Session
+
+We're now ready to start talking to the database.  The ORM's "handle" to the database is the `Session`.  When we first set up the application, at the same level as our `create_engine()` statement, we define a `Session` class which will serve as a factory for new `Session` objects:
+
     {python}
     >>> from sqlalchemy.orm import sessionmaker
-    >>> Session = sessionmaker(bind=engine, autoflush=True, transactional=True)
+    >>> Session = sessionmaker(bind=engine)
 
 In the case where your application does not yet have an `Engine` when you define your module-level objects, just set it up like this:
 
     {python}
-    >>> Session = sessionmaker(autoflush=True, transactional=True)
+    >>> Session = sessionmaker()
 
 Later, when you create your engine with `create_engine()`, connect it to the `Session` using `configure()`:
 
     {python}
     >>> Session.configure(bind=engine)  # once engine is available
     
-This `Session` class will create new `Session` objects which are bound to our database and have the transactional characteristics we've configured.  Whenever you need to have a conversation with the database, you instantiate a `Session`:
+This custom-made `Session` class will create new `Session` objects which are bound to our database.  Other transactional characteristics may be defined when calling `sessionmaker()` as well; these are described in a later chapter.  Then, whenever you need to have a conversation with the database, you instantiate a `Session`:
 
     {python}
     >>> session = Session()
     
-The above `Session` is associated with our SQLite `engine`, but it hasn't opened any connections yet.  When it's first used, it retrieves a connection from a pool of connections maintained by the `engine`, and holds onto it until we commit all changes and/or close the session object.  Because we configured `transactional=True`, there's also a transaction in progress (one notable exception to this is MySQL, when you use its default table style of MyISAM).  There's options available to modify this behavior but we'll go with this straightforward version to start.    
+The above `Session` is associated with our SQLite `engine`, but it hasn't opened any connections yet.  When it's first used, it retrieves a connection from a pool of connections maintained by the `engine`, and holds onto it until we commit all changes and/or close the session object.
 
-## Saving Objects
+## Adding new Objects
 
-So saving our `User` is as easy as issuing `save()`:
+To persist our `User` object, we `add()` it to our `Session`:
 
     {python}
-    >>> session.save(ed_user)
+    >>> ed_user = User('ed', 'Ed Jones', 'edspassword')
+    >>> session.add(ed_user)
     
-But you'll notice nothing has happened yet.  Well, lets pretend something did, and try to query for our user.  This is done using the `query()` method on `Session`.  We create a new query representing the set of all `User` objects first.  Then we narrow the results by "filtering" down to the user we want; that is, the user whose `name` attribute is `"ed"`.  Finally we call `first()` which tells `Query`, "we'd like the first result in this list".
+At this point, the instance is **pending**; no SQL has yet been issued.  The `Session` will issue the SQL to persist `Ed Jones` as soon as is needed, using a process known as a **flush**.  If we query the database for `Ed Jones`, all pending information will first be flushed, and the query is issued afterwards.
+
+For example, below we create a new `Query` object which loads instances of `User`.  We "filter by" the `name` attribute of `ed`, and indicate that we'd like only the first result in the full list of rows.  A `User` instance is returned which is equivalent to that which we've added:
 
     {python}
-    {sql}>>> session.query(User).filter_by(name='ed').first() # doctest:+ELLIPSIS,+NORMALIZE_WHITESPACE
+    {sql}>>> our_user = session.query(User).filter_by(name='ed').first() # doctest:+ELLIPSIS,+NORMALIZE_WHITESPACE
     BEGIN
     INSERT INTO users (name, fullname, password) VALUES (?, ?, ?)
     ['ed', 'Ed Jones', 'edspassword']
@@ -148,23 +186,45 @@ But you'll notice nothing has happened yet.  Well, lets pretend something did, a
     WHERE users.name = ? ORDER BY users.oid 
      LIMIT 1 OFFSET 0
     ['ed']
-    {stop}<User('ed','Ed Jones', 'edspassword')>
+    {stop}>>> our_user
+    <User('ed','Ed Jones', 'edspassword')>
+
+In fact, the `Session` has identified that the row returned is the **same** row as one already represented within its internal map of objects, so we actually got back the identical instance as that which we just added:
+
+    {python}
+    >>> ed_user is our_user
+    True
 
-And we get back our new user.  If you view the generated SQL, you'll see that the `Session` issued an `INSERT` statement before querying.  The `Session` stores whatever you put into it in memory, and at certain points it issues a **flush**, which issues SQL to the database to store all pending new objects and changes to existing objects.  You can manually invoke the flush operation using `flush()`; however when the `Session` is configured to `autoflush`, it's usually not needed.
+The ORM concept at work here is known as an **identity map** and ensures that all operations upon a particular row within a `Session` operate upon the same set of data.  Once an object with a particular primary key is present in the `Session`, all SQL queries on that `Session` will always return the same Python object for that particular primary key; it also will raise an error if an attempt is made to place a second, already-persisted object with the same primary key within the session.
 
-OK, let's do some more operations.  We'll create and save three more users:
+We can add more `User` objects at once using `add_all()`:
 
     {python}
-    >>> session.save(User('wendy', 'Wendy Williams', 'foobar'))
-    >>> session.save(User('mary', 'Mary Contrary', 'xxg527'))
-    >>> session.save(User('fred', 'Fred Flinstone', 'blah'))
+    >>> session.add_all([
+    ...     User('wendy', 'Wendy Williams', 'foobar'),
+    ...     User('mary', 'Mary Contrary', 'xxg527'),
+    ...     User('fred', 'Fred Flinstone', 'blah')])
 
 Also, Ed has already decided his password isn't too secure, so lets change it:
     
     {python}
     >>> ed_user.password = 'f8s7ccs'
+
+The `Session` is paying attention.  It knows, for example, that `Ed Jones` has been modified:
+    
+    {python}
+    >>> session.dirty
+    IdentitySet([<User('ed','Ed Jones', 'f8s7ccs')>])
     
-Then we'll permanently store everything thats been changed and added to the database.  We do this via `commit()`:
+and that three new `User` objects are pending:
+
+    {python}
+    >>> session.new  # doctest: +NORMALIZE_WHITESPACE
+    IdentitySet([<User('wendy','Wendy Williams', 'foobar')>, 
+    <User('mary','Mary Contrary', 'xxg527')>, 
+    <User('fred','Fred Flinstone', 'blah')>])
+    
+We tell the `Session` that we'd like to issue all remaining changes to the database and commit the transaction, which has been in progress throughout.  We do this via `commit()`:
 
     {python}
     {sql}>>> session.commit()
@@ -183,114 +243,139 @@ Then we'll permanently store everything thats been changed and added to the data
 If we look at Ed's `id` attribute, which earlier was `None`, it now has a value:
 
     {python}
-    >>> ed_user.id
-    1
-
-After each `INSERT` operation, the `Session` assigns all newly generated ids and column defaults to the mapped object instance.  For column defaults which are database-generated and are not part of the table's primary key, they'll be loaded when you first reference the attribute on the instance.
-
-One crucial thing to note about the `Session` is that each object instance is cached within the Session, based on its primary key identifier.  The reason for this cache is not as much for performance as it is for maintaining an **identity map** of instances.  This map guarantees that whenever you work with a particular `User` object in a session, **you always get the same instance back**.  As below, reloading Ed gives us the same instance back:
-
-    {python}
-    {sql}>>> ed_user is session.query(User).filter_by(name='ed').one() # doctest: +NORMALIZE_WHITESPACE
+    {sql}>>> ed_user.id # doctest: +NORMALIZE_WHITESPACE
     BEGIN
     SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password 
     FROM users 
-    WHERE users.name = ? ORDER BY users.oid 
-    LIMIT 2 OFFSET 0
-    ['ed']
-    {stop}True
+    WHERE users.id = ?
+    [1]
+    {stop}1
 
-The `get()` method, which queries based on primary key, will not issue any SQL to the database if the given key is already present:
+After the `Session` inserts new rows in the database, all newly generated identifiers and database-generated defaults become available on the instance, either immediately or via load-on-first-access.  In this case, the entire row was re-loaded on access because a new transaction was begun after we issued `commit()`.  SQLAlchemy by default refreshes data from a previous transaction the first time it's accessed within a new transaction, so that the most recent state is available.  The level of reloading is configurable as is described in the chapter on Sessions.
 
-    {python}
-    >>> ed_user is session.query(User).get(ed_user.id)
-    True
-    
 ## Querying
 
-A whirlwind tour through querying.
-
-A `Query` is created from the `Session`, relative to a particular class we wish to load.
+A `Query` is created using the `query()` function on `Session`.  This function takes a variable number of arguments, which can be any combination of classes and class-instrumented descriptors.  Below, we indicate a `Query` which loads `User` instances.  When evaluated in an iterative context, the list of `User` objects present is returned:
 
     {python}
-    >>> query = session.query(User)
+    {sql}>>> for instance in session.query(User): # doctest: +NORMALIZE_WHITESPACE
+    ...     print instance.name, instance.fullname 
+    SELECT users.id AS users_id, users.name AS users_name, 
+    users.fullname AS users_fullname, users.password AS users_password 
+    FROM users ORDER BY users.oid
+    []
+    {stop}ed Ed Jones
+    wendy Wendy Williams
+    mary Mary Contrary
+    fred Fred Flinstone
 
-Once we have a query, we can start loading objects.  The Query object, when first created, represents all the instances of its main class.  You can iterate through it directly:
+The `Query` also accepts ORM-instrumented descriptors as arguments.  Any time multiple class entities or column-based entities are expressed as arguments to the `query()` function, the return result is expressed as tuples:
 
     {python}
-    {sql}>>> for user in session.query(User):
-    ...     print user.name
-    SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password 
-    FROM users ORDER BY users.oid
+    {sql}>>> for name, fullname in session.query(User.name, User.fullname): # doctest: +NORMALIZE_WHITESPACE
+    ...     print name, fullname
+    SELECT users.name AS users_name, users.fullname AS users_fullname
+    FROM users
     []
-    {stop}ed
-    wendy
-    mary
-    fred
+    {stop}ed Ed Jones
+    wendy Wendy Williams
+    mary Mary Contrary
+    fred Fred Flinstone
 
-...and the SQL will be issued at the point where the query is evaluated as a list.  If you apply array slices before iterating, LIMIT and OFFSET are applied to the query:
+Basic operations with `Query` include issuing LIMIT and OFFSET, most conveniently using Python array slices and typically in conjunction with ORDER BY:
 
     {python}
-    {sql}>>> for u in session.query(User)[1:3]: #doctest: +NORMALIZE_WHITESPACE
+    {sql}>>> for u in session.query(User).order_by(User.id)[1:3]: #doctest: +NORMALIZE_WHITESPACE
     ...    print u
     SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password 
-    FROM users ORDER BY users.oid 
+    FROM users ORDER BY users.id 
     LIMIT 2 OFFSET 1
     []
     {stop}<User('wendy','Wendy Williams', 'foobar')>
     <User('mary','Mary Contrary', 'xxg527')>
 
-Narrowing the results down is accomplished either with `filter_by()`, which uses keyword arguments:
+and filtering results, which is accomplished either with `filter_by()`, which uses keyword arguments:
 
     {python}
-    {sql}>>> for user in session.query(User).filter_by(name='ed', fullname='Ed Jones'):
-    ...    print user
-    SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password 
-    FROM users 
-    WHERE users.fullname = ? AND users.name = ? ORDER BY users.oid
-    ['Ed Jones', 'ed']
-    {stop}<User('ed','Ed Jones', 'f8s7ccs')>
+    {sql}>>> for name, in session.query(User.name).filter_by(fullname='Ed Jones'): # doctest: +NORMALIZE_WHITESPACE
+    ...    print name
+    SELECT users.name AS users_name FROM users 
+    WHERE users.fullname = ?
+    ['Ed Jones']
+    {stop}ed
 
-...or `filter()`, which uses SQL expression language constructs.  These allow you to use regular Python operators with the class-level attributes on your mapped class:
+...or `filter()`, which uses more flexible SQL expression language constructs.  These allow you to use regular Python operators with the class-level attributes on your mapped class:
 
     {python}
-    {sql}>>> for user in session.query(User).filter(User.name=='ed'):
-    ...    print user
-    SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password 
-    FROM users 
-    WHERE users.name = ? ORDER BY users.oid
-    ['ed']
-    {stop}<User('ed','Ed Jones', 'f8s7ccs')>
+    {sql}>>> for name, in session.query(User.name).filter(User.fullname=='Ed Jones'): # doctest: +NORMALIZE_WHITESPACE
+    ...    print name
+    SELECT users.name AS users_name FROM users 
+    WHERE users.fullname = ?
+    ['Ed Jones']
+    {stop}ed
 
-You can also use the `Column` constructs attached to the `users_table` object to construct SQL expressions:
+The `Query` object is fully *generative*, meaning that most method calls return a new `Query` object upon which further criteria may be added.  For example, to query for users named "ed" with a full name of "Ed Jones", you can call `filter()` twice, which joins criteria using `AND`:
 
     {python}
-    {sql}>>> for user in session.query(User).filter(users_table.c.name=='ed'):
+    {sql}>>> for user in session.query(User).filter(User.name=='ed').filter(User.fullname=='Ed Jones'): # doctest: +NORMALIZE_WHITESPACE
     ...    print user
     SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password 
     FROM users 
-    WHERE users.name = ? ORDER BY users.oid
-    ['ed']
+    WHERE users.name = ? AND users.fullname = ? ORDER BY users.oid
+    ['ed', 'Ed Jones']
     {stop}<User('ed','Ed Jones', 'f8s7ccs')>
 
-Most common SQL operators are available, such as `LIKE`:
 
-    {python}
-    {sql}>>> session.query(User).filter(User.name.like('%ed'))[1] # doctest: +NORMALIZE_WHITESPACE
-    SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password 
-    FROM users 
-    WHERE users.name LIKE ? ORDER BY users.oid 
-     LIMIT 1 OFFSET 1
-    ['%ed']
-    {stop}<User('fred','Fred Flinstone', 'blah')>
+### Common Filter Operators
 
-Note above our array index of `1` placed the appropriate LIMIT/OFFSET and returned a scalar result immediately.
+Here's a rundown of some of the most common operators used in `filter()`:
 
-The `all()`, `one()`, and `first()` methods immediately issue SQL without using an iterative context or array index.  `all()` returns a list:
+  * equals
+
+        {python}
+        query.filter(User.name == 'ed')
+    
+  * not equals
+    
+        {python}
+        query.filter(User.name != 'ed')
+    
+  * LIKE
+    
+        {python}
+        query.filter(User.name.like('%ed%'))
+        
+  * IN
+    
+        {python}
+        query.filter(User.name.in_(['ed', 'wendy', 'jack']))
+        
+  * IS NULL
+    
+        {python}
+        filter(User.name == None)
+        
+  * AND
+    
+        {python}
+        from sqlalchemy import and_
+        filter(and_(User.name == 'ed', User.fullname == 'Ed Jones'))
+        
+        # or call filter()/filter_by() multiple times
+        filter(User.name == 'ed').filter(User.fullname == 'Ed Jones')
+    
+  * OR
+        
+        {python}
+        from sqlalchemy import or_
+        filter(or_(User.name == 'ed', User.name == 'wendy'))
+        
+### Returning Lists and Scalars {@name=scalars}
+
+The `all()`, `one()`, and `first()` methods of `Query` immediately issue SQL and return a non-iterator value.  `all()` returns a list:
 
     {python}
     >>> query = session.query(User).filter(User.name.like('%ed'))
-
     {sql}>>> query.all()
     SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password 
     FROM users 
@@ -309,7 +394,7 @@ The `all()`, `one()`, and `first()` methods immediately issue SQL without using
     ['%ed']
     {stop}<User('ed','Ed Jones', 'f8s7ccs')>
 
-and `one()`, applies a limit of *two*, and if not exactly one row returned (no more, no less), raises an error:
+`one()`, applies a limit of *two*, and if not exactly one row returned, raises an error:
 
     {python}
     {sql}>>> try:  
@@ -323,32 +408,9 @@ and `one()`, applies a limit of *two*, and if not exactly one row returned (no m
     ['%ed']
     {stop}Multiple rows returned for one()
 
-All `Query` methods that don't return a result instead return a new `Query` object, with modifications applied.  Therefore you can call many query methods successively to build up the criterion you want:
-
-    {python}
-    {sql}>>> session.query(User).filter(User.id<2).filter_by(name='ed').\
-    ...     filter(User.fullname=='Ed Jones').all()
-    SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password 
-    FROM users 
-    WHERE users.id < ? AND users.name = ? AND users.fullname = ? ORDER BY users.oid
-    [2, 'ed', 'Ed Jones']
-    {stop}[<User('ed','Ed Jones', 'f8s7ccs')>]
-
-If you need to use other conjunctions besides `AND`, all SQL conjunctions are available explicitly within expressions, such as `and_()` and `or_()`, when using `filter()`:
+### Using Literal SQL {@naqme=literal}
 
-    {python}
-    >>> from sqlalchemy import and_, or_
-    
-    {sql}>>> session.query(User).filter(
-    ...    and_(User.id<224, or_(User.name=='ed', User.name=='wendy'))
-    ...    ).all()
-    SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password 
-    FROM users 
-    WHERE users.id < ? AND (users.name = ? OR users.name = ?) ORDER BY users.oid
-    [224, 'ed', 'wendy']
-    {stop}[<User('ed','Ed Jones', 'f8s7ccs')>, <User('wendy','Wendy Williams', 'foobar')>]
-    
-You also have full ability to use literal strings to construct SQL.  For a single criterion, use a string with `filter()`:
+Literal strings can be used flexibly with `Query`.  Most methods accept strings in addition to SQLAlchemy clause constructs.  For example, `filter()`:
 
     {python}
     {sql}>>> for user in session.query(User).filter("id<224").all():
@@ -374,8 +436,6 @@ Bind parameters can be specified with string-based SQL, using a colon.  To speci
     [224, 'fred']
     {stop}<User('fred','Fred Flinstone', 'blah')>
 
-Note that when we use constructed SQL expressions, bind parameters are generated for us automatically; we don't need to worry about them.
-       
 To use an entirely string-based statement, using `from_statement()`; just ensure that the columns clause of the statement contains the column names normally used by the mapper (below illustrated using an asterisk):
 
     {python}
@@ -384,51 +444,39 @@ To use an entirely string-based statement, using `from_statement()`; just ensure
     ['ed']
     {stop}[<User('ed','Ed Jones', 'f8s7ccs')>]
 
-`from_statement()` can also accomodate full `select()` constructs.  These are described in the [sql](rel:sql):
+## Building a Relation {@name=relation}
 
-    {python}
-    >>> from sqlalchemy import select, func
-    
-    {sql}>>> session.query(User).from_statement(
-    ...     select(
-    ...            [users_table], 
-    ...            select([func.max(users_table.c.name)]).label('maxuser')==users_table.c.name) 
-    ...    ).all() # doctest: +NORMALIZE_WHITESPACE
-    SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password 
-    FROM users 
-    WHERE (SELECT max(users.name) AS max_1
-    FROM users) = users.name
-    []
-    {stop}[<User('wendy','Wendy Williams', 'foobar')>]
-    
-There's also a way to combine scalar results with objects, using `add_column()`.  This is often used for functions and aggregates.  When `add_column()` (or its cousin `add_entity()`, described later) is used, tuples are returned:
+Now let's consider a second table to be dealt with.  Users in our system also can store any number of email addresses associated with their username.  This implies a basic one to many association from the `users_table` to a new table which stores email addresses, which we will call `addresses`.  Using declarative, we define this table along with its mapped class, `Address`:
 
     {python}
-    {sql}>>> for r in session.query(User).\
-    ...     add_column(select([func.max(users_table.c.name)]).label('maxuser')):
-    ...     print r # doctest: +NORMALIZE_WHITESPACE
-    SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password, (SELECT max(users.name) AS max_1
-    FROM users) AS maxuser 
-    FROM users ORDER BY users.oid
-    []
-    {stop}(<User('ed','Ed Jones', 'f8s7ccs')>, u'wendy')
-    (<User('wendy','Wendy Williams', 'foobar')>, u'wendy')
-    (<User('mary','Mary Contrary', 'xxg527')>, u'wendy')
-    (<User('fred','Fred Flinstone', 'blah')>, u'wendy')
+    >>> from sqlalchemy import ForeignKey
+    >>> from sqlalchemy.orm import relation
+    >>> class Address(Base):
+    ...     __tablename__ = 'addresses'
+    ...     id = Column(Integer, primary_key=True)
+    ...     email_address = Column(String, nullable=False)
+    ...     user_id = Column(Integer, ForeignKey('users.id'))
+    ...
+    ...     user = relation(User, backref='addresses')
+    ...
+    ...     def __init__(self, email_address):
+    ...         self.email_address = email_address
+    ...
+    ...     def __repr__(self):
+    ...         return "<Address('%s')>" % self.email_address
 
-## Building a One-to-Many Relation {@name=onetomany}
+The above class introduces a **foreign key** constraint which references the `users` table.  This defines for SQLAlchemy the relationship between the two tables at the database level.  The relationship between the `User` and `Address` classes is defined separately using the `relation()` function, which defines an attribute `user` to be placed on the `Address` class, as well as an `addresses` collection to be placed on the `User` class.  Such a relation is known as a **bidirectional** relationship.   Because of the placement of the foreign key, from `Address` to `User` it is **many to one**, and from `User` to `Address` it is **one to many**.  SQLAlchemy is automatically aware of many-to-one/one-to-many based on foreign keys.
 
-We've spent a lot of time dealing with just one class, and one table.  Let's now look at how SQLAlchemy deals with two tables, which have a relationship to each other.   Let's say that the users in our system also can store any number of email addresses associated with their username.  This implies a basic one to many association from the `users_table` to a new table which stores email addresses, which we will call `addresses`.  We will also create a relationship between this new table to the users table, using a `ForeignKey`:
+The `relation()` function is extremely flexible, and could just have easily been defined on the `User` class:
 
     {python}
-    >>> from sqlalchemy import ForeignKey
-    
-    >>> addresses_table = Table('addresses', metadata, 
-    ...     Column('id', Integer, primary_key=True),
-    ...     Column('email_address', String(100), nullable=False),
-    ...     Column('user_id', Integer, ForeignKey('users.id')))
-    
-Another call to `create_all()` will skip over our `users` table and build just the new `addresses` table:
+    class User(Base):
+        ....
+        addresses = relation("Address", backref="user")
+        
+Where above we used the string name `"Addresses"` in the event that the `Address` class was not yet defined.   We are also free to not define a backref, and to define the `relation()` only on one class and not the other.   It is also possible to define two separate `relation()`s for either direction, which is generally safe for many-to-one and one-to-many relations, but not for many-to-many relations.
+
+We'll need to create the `addresses` table in the database, so we will issue another CREATE from our metadata, which will skip over tables which have already been created:
 
     {python}
     {sql}>>> metadata.create_all(engine) # doctest: +NORMALIZE_WHITESPACE
@@ -438,7 +486,7 @@ Another call to `create_all()` will skip over our `users` table and build just t
     {}
     CREATE TABLE addresses (
         id INTEGER NOT NULL, 
-        email_address VARCHAR(100) NOT NULL, 
+        email_address VARCHAR NOT NULL, 
         user_id INTEGER, 
         PRIMARY KEY (id), 
          FOREIGN KEY(user_id) REFERENCES users (id)
@@ -446,54 +494,21 @@ Another call to `create_all()` will skip over our `users` table and build just t
     {}
     COMMIT
 
-For our ORM setup, we're going to start all over again.  We will first close out our `Session` and clear all `Mapper` objects:
-
-    {python}
-    >>> from sqlalchemy.orm import clear_mappers
-    >>> session.close()
-    >>> clear_mappers()
-    
-Our `User` class, still around, reverts to being just a plain old class.  Lets create an `Address` class to represent a user's email address:
-
-    {python}
-    >>> class Address(object):
-    ...     def __init__(self, email_address):
-    ...         self.email_address = email_address
-    ...
-    ...     def __repr__(self):
-    ...         return "<Address('%s')>" % self.email_address
-
-Now comes the fun part.  We define a mapper for each class, and associate them using a function called `relation()`.  We can define each mapper in any order we want:
-
-    {python}
-    >>> from sqlalchemy.orm import relation
-    
-    >>> mapper(User, users_table, properties={    # doctest: +ELLIPSIS
-    ...     'addresses':relation(Address, backref='user')
-    ... })
-    <sqlalchemy.orm.mapper.Mapper object at 0x...>
-    
-    >>> mapper(Address, addresses_table) # doctest: +ELLIPSIS
-    <sqlalchemy.orm.mapper.Mapper object at 0x...>
-
-Above, the new thing we see is that `User` has defined a relation named `addresses`, which will reference a list of `Address` objects.  How does it know it's a list ?  SQLAlchemy figures it out for you, based on the foreign key relationship between `users_table` and `addresses_table`.  
+## Working with Related Objects {@name=related_objects}
 
-## Working with Related Objects and Backreferences {@name=relation_backref}
-
-Now when we create a `User`, it automatically has this collection present:
+Now when we create a `User`, a blank `addresses` collection will be present.  By default, the collection is a Python list.  Other collection types, such as sets and dictionaries, are available as well:
 
     {python}
     >>> jack = User('jack', 'Jack Bean', 'gjffdd')
     >>> jack.addresses
     []
     
-We are free to add `Address` objects, and the `session` will take care of everything for us.
+We are free to add `Address` objects on our `User` object.  In this case we just assign a full list directly:
 
     {python}
-    >>> jack.addresses.append(Address(email_address='jack@google.com'))
-    >>> jack.addresses.append(Address(email_address='j25@yahoo.com'))
-    
-Before we save into the `Session`, lets examine one other thing that's happened here.  The `addresses` collection is present on our `User` because we added a `relation()` with that name.  But also within the `relation()` function is the keyword `backref`.  This keyword indicates that we wish to make a **bi-directional relationship**.  What this basically means is that not only did we generate a one-to-many relationship called `addresses` on the `User` class, we also generated a **many-to-one** relationship on the `Address` class.  This relationship is self-updating, without any data being flushed to the database, as we can see on one of Jack's addresses:
+    >>> jack.addresses = [Address(email_address='jack@google.com'), Address(email_address='j25@yahoo.com')]
+
+When using a bidirectional relationship, elements added in one direction automatically become visible in the other direction.  This is the basic behavior of the **backref** keyword, which maintains the relationship purely in memory, without using any SQL:
 
     {python}
     >>> jack.addresses[1]
@@ -501,13 +516,12 @@ Before we save into the `Session`, lets examine one other thing that's happened
     
     >>> jack.addresses[1].user
     <User('jack','Jack Bean', 'gjffdd')>
-    
-Let's save into the session, then close out the session and create a new one...so that we can see how `Jack` and his email addresses come back to us:
+
+Let's add and commit `Jack Bean` to the database.  `jack` as well as the two `Address` members in his `addresses` collection are both added to the session at once, using a process known as **cascading**:
 
     {python}
-    >>> session.save(jack)
+    >>> session.add(jack)
     {sql}>>> session.commit()
-    BEGIN
     INSERT INTO users (name, fullname, password) VALUES (?, ?, ?)
     ['jack', 'Jack Bean', 'gjffdd']
     INSERT INTO addresses (email_address, user_id) VALUES (?, ?)
@@ -516,8 +530,6 @@ Let's save into the session, then close out the session and create a new one...s
     ['j25@yahoo.com', 5]
     COMMIT
     
-    >>> session = Session()
-    
 Querying for Jack, we get just Jack back.  No SQL is yet issued for for Jack's addresses:
 
     {python}
@@ -542,14 +554,9 @@ Let's look at the `addresses` collection.  Watch the SQL:
     [5]
     {stop}[<Address('jack@google.com')>, <Address('j25@yahoo.com')>]
     
-When we accessed the `addresses` collection, SQL was suddenly issued.  This is an example of a **lazy loading relation**.
-
-If you want to reduce the number of queries (dramatically, in many cases), we can apply an **eager load** to the query operation.  We clear out the session to ensure that a full reload occurs:
-
-    {python}
-    >>> session.clear()
+When we accessed the `addresses` collection, SQL was suddenly issued.  This is an example of a **lazy loading relation**.  The `addresses` collection is now loaded and behaves just like an ordinary list.  
     
-Then apply an **option** to the query, indicating that we'd like `addresses` to load "eagerly".  SQLAlchemy then constructs a join between the `users` and `addresses` tables:
+If you want to reduce the number of queries (dramatically, in many cases), we can apply an **eager load** to the query operation.   With the same query, we may apply an **option** to the query, indicating that we'd like `addresses` to load "eagerly".  SQLAlchemy then constructs an outer join between the `users` and `addresses` tables, and loads them at once, populating the `addresses` collection on each `User` object if it's not already populated:
 
     {python}
     >>> from sqlalchemy.orm import eagerload
@@ -572,38 +579,30 @@ Then apply an **option** to the query, indicating that we'd like `addresses` to
     
     >>> jack.addresses
     [<Address('jack@google.com')>, <Address('j25@yahoo.com')>]
-    
-If you think that query is elaborate, it is !  But SQLAlchemy is just getting started.  Note that when using eager loading, *nothing* changes as far as the ultimate results returned.  The "loading strategy", as it's called, is designed to be completely transparent in all cases, and is for optimization purposes only.  Any query criterion you use to load objects, including ordering, limiting, other joins, etc., should return identical results regardless of the combination of lazily- and eagerly- loaded relationships present.
 
-An eagerload targeting across multiple relations can use dot separated names:
+SQLAlchemy has the ability to control exactly which attributes and how many levels deep should be joined together in a single SQL query.  More information on this feature is available in [advdatamapping_relation](rel:advdatamapping_relation).
 
-    {python}
-    query.options(eagerload('orders'), eagerload('orders.items'), eagerload('orders.items.keywords'))
-    
-To roll up the above three individual `eagerload()` calls into one, use `eagerload_all()`:
-
-    {python}
-    query.options(eagerload_all('orders.items.keywords'))
-    
 ## Querying with Joins {@name=joins}
 
-Which brings us to the next big topic.  What if we want to create joins that *do* change the results ?  For that, another `Query` tornado is coming....
-
-One way to join two tables together is just to compose a SQL expression.   Below we make one up using the `id` and `user_id` attributes on our mapped classes:
+While the eager load created a JOIN specifically to populate a collection, we can also work explicitly with joins in many ways.  For example, to construct a simple inner join between `User` and `Address`, we can just `filter()` their related columns together.  Below we load the `User` and `Address` entities at once using this method:
 
     {python}
-    {sql}>>> session.query(User).filter(User.id==Address.user_id).\
-    ...         filter(Address.email_address=='jack@google.com').all()
-    SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password 
+    {sql}>>> for u, a in session.query(User, Address).filter(User.id==Address.user_id).\
+    ...         filter(Address.email_address=='jack@google.com').all():   # doctest: +NORMALIZE_WHITESPACE
+    ...     print u, a
+    SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, 
+    users.password AS users_password, addresses.id AS addresses_id, 
+    addresses.email_address AS addresses_email_address, addresses.user_id AS addresses_user_id 
     FROM users, addresses 
     WHERE users.id = addresses.user_id AND addresses.email_address = ? ORDER BY users.oid
     ['jack@google.com']
-    {stop}[<User('jack','Jack Bean', 'gjffdd')>]
+    {stop}<User('jack','Jack Bean', 'gjffdd')> <Address('jack@google.com')>
 
-Or we can make a real JOIN construct; below we use the `join()` function available on `Table` to create a `Join` object, then tell the `Query` to use it as our FROM clause:
+Or we can make a real JOIN construct; one way to do so is to use the ORM `join()` function, and tell `Query` to "select from" this join:
 
     {python}
-    {sql}>>> session.query(User).select_from(users_table.join(addresses_table)).\
+    >>> from sqlalchemy.orm import join
+    {sql}>>> session.query(User).select_from(join(User, Address)).\
     ...         filter(Address.email_address=='jack@google.com').all()
     SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password 
     FROM users JOIN addresses ON users.id = addresses.user_id 
@@ -611,12 +610,17 @@ Or we can make a real JOIN construct; below we use the `join()` function availab
     ['jack@google.com']
     {stop}[<User('jack','Jack Bean', 'gjffdd')>]
 
-Note that the `join()` construct has no problem figuring out the correct join condition between `users_table` and `addresses_table`..the `ForeignKey` we constructed says it all.
+`join()` knows how to join between `User` and `Address` because there's only one foreign key between them.  If there were no foreign keys, or several, `join()` would require a third argument indicating the ON clause of the join, in one of the following forms:
 
-The easiest way to join is automatically, using the `join()` method on `Query`.  Just give this method the path from A to B, using the name of a mapped relationship directly:
+    {python}
+    join(User, Address, User.id==Address.user_id)  # explicit condition
+    join(User, Address, User.addresses)            # specify relation from left to right
+    join(User, Address, 'addresses')               # same, using a string
+    
+The functionality of `join()` is also available generatively from `Query` itself using `Query.join`.  This is most easily used with just the "ON" clause portion of the join, such as:
 
     {python}
-    {sql}>>> session.query(User).join('addresses').\
+    {sql}>>> session.query(User).join(User.addresses).\
     ...     filter(Address.email_address=='jack@google.com').all()
     SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password 
     FROM users JOIN addresses ON users.id = addresses.user_id 
@@ -624,238 +628,173 @@ The easiest way to join is automatically, using the `join()` method on `Query`.
     ['jack@google.com']
     {stop}[<User('jack','Jack Bean', 'gjffdd')>]
 
-By "A to B", we mean a single relation name or a path of relations.  In our case we only have `User->addresses->Address` configured, but if we had a setup like `A->bars->B->bats->C->widgets->D`, a join along all four entities would look like:
+To explicitly specify the target of the join, use tuples to form an argument list similar to the standalone join.  This becomes more important when using aliases and similar constructs:
 
     {python}
-    session.query(Foo).join(['bars', 'bats', 'widgets']).filter(...)
+    session.query(User).join((Address, User.addresses))
     
-Each time `join()` is called on `Query`, the **joinpoint** of the query is moved to be that of the endpoint of the join.  As above, when we joined from `users_table` to `addresses_table`, all subsequent criterion used by `filter_by()` are against the `addresses` table.  When you `join()` again, the joinpoint starts back from the root.  We can also backtrack to the beginning explicitly using `reset_joinpoint()`.  This instruction will place the joinpoint back at the root `users` table, where subsequent `filter_by()` criterion are again against `users`:
+Multiple joins can be created by passing a list of arguments:
 
     {python}
-    {sql}>>> session.query(User).join('addresses').\
-    ...     filter_by(email_address='jack@google.com').\
-    ...     reset_joinpoint().filter_by(name='jack').all()
-    SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password 
-    FROM users JOIN addresses ON users.id = addresses.user_id 
-    WHERE addresses.email_address = ? AND users.name = ? ORDER BY users.oid
-    ['jack@google.com', 'jack']
-    {stop}[<User('jack','Jack Bean', 'gjffdd')>]
-
-In all cases, we can get the `User` and the matching `Address` objects back at the same time, by telling the session we want both.  This returns the results as a list of tuples:
-
-    {python}
-    {sql}>>> session.query(User).add_entity(Address).join('addresses').\
-    ...     filter(Address.email_address=='jack@google.com').all()
-    SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password, addresses.id AS addresses_id, addresses.email_address AS addresses_email_address, addresses.user_id AS addresses_user_id 
-    FROM users JOIN addresses ON users.id = addresses.user_id 
-    WHERE addresses.email_address = ? ORDER BY users.oid
-    ['jack@google.com']
-    {stop}[(<User('jack','Jack Bean', 'gjffdd')>, <Address('jack@google.com')>)]
+    session.query(Foo).join(Foo.bars, Bar.bats, (Bat, 'widgets'))
+    
+The above would produce SQL something like `foo JOIN bars ON <onclause> JOIN bats ON <onclause> JOIN widgets ON <onclause>`.
+    
+### Using Aliases {@name=aliases}
 
-Another common scenario is the need to join on the same table more than once.  For example, if we want to find a `User` who has two distinct email addresses, both `jack@google.com` as well as `j25@yahoo.com`, we need to join to the `Addresses` table twice.  SQLAlchemy does provide `Alias` objects which can accomplish this; but far easier is just to tell `join()` to alias for you:
+When querying across multiple tables, if the same table needs to be referenced more than once, SQL typically requires that the table be *aliased* with another name, so that it can be distinguished against other occurences of that table.  The `Query` supports this most expicitly using the `aliased` construct.  Below we join to the `Address` entity twice, to locate a user who has two distinct email addresses at the same time:
 
     {python}
-    {sql}>>> session.query(User).\
-    ...     join('addresses', aliased=True).filter(Address.email_address=='jack@google.com').\
-    ...     join('addresses', aliased=True).filter(Address.email_address=='j25@yahoo.com').all()
-    SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password 
-    FROM users JOIN addresses AS addresses_1 ON users.id = addresses_1.user_id JOIN addresses AS addresses_2 ON users.id = addresses_2.user_id 
-    WHERE addresses_1.email_address = ? AND addresses_2.email_address = ? ORDER BY users.oid
+    >>> from sqlalchemy.orm import aliased
+    >>> adalias1 = aliased(Address)
+    >>> adalias2 = aliased(Address)
+    {sql}>>> for username, email1, email2 in session.query(User.name, adalias1.email_address, adalias2.email_address).\
+    ...     join((adalias1, User.addresses), (adalias2, User.addresses)).\
+    ...     filter(adalias1.email_address=='jack@google.com').\
+    ...     filter(adalias2.email_address=='j25@yahoo.com'):
+    ...     print username, email1, email2      # doctest: +NORMALIZE_WHITESPACE
+    SELECT users.name AS users_name, addresses_1.email_address AS addresses_1_email_address, 
+    addresses_2.email_address AS addresses_2_email_address 
+    FROM users JOIN addresses AS addresses_1 ON users.id = addresses_1.user_id 
+    JOIN addresses AS addresses_2 ON users.id = addresses_2.user_id 
+    WHERE addresses_1.email_address = ? AND addresses_2.email_address = ?
     ['jack@google.com', 'j25@yahoo.com']
-    {stop}[<User('jack','Jack Bean', 'gjffdd')>]
+    {stop}jack jack@google.com j25@yahoo.com
 
-The key thing which occurred above is that our SQL criterion were **aliased** as appropriate corresponding to the alias generated in the most recent `join()` call.
+### Using Subqueries {@name=subqueries}
 
-The next section describes some "higher level" operators, including `any()` and `has()`, which make patterns like joining to multiple aliases unnecessary in most cases.
+The `Query` is suitable for generating statements which can be used as subqueries.  Suppose we wanted to load `User` objects along with a count of how many `Address` records each user has.  The best way to generate SQL like this is to get the count of addresses grouped by user ids, and JOIN to the parent.  In this case we use a LEFT OUTER JOIN so that we get rows back for those users who don't have any addresses, e.g.:
 
-### Relation Operators
+    {code}
+    SELECT users.*, adr_count.address_count FROM users LEFT OUTER JOIN
+        (SELECT user_id, count(*) AS address_count FROM addresses GROUP BY user_id) AS adr_count
+        ON users.id=adr_count.user_id
 
-A summary of all operators usable on relations:
+Using the `Query`, we build a statement like this from the inside out.  The `statement` accessor returns a SQL expression representing the statement generated by a particular `Query` - this is an instance of a `select()` construct, which are described in [sql](rel:sql):
+    
+    {python}
+    >>> from sqlalchemy.sql import func
+    >>> stmt = session.query(Address.user_id, func.count('*').label('address_count')).group_by(Address.user_id).statement.alias()
+    
+The `func` keyword generates SQL functions, and the `alias()` method on `Select` (the return value of `query.statement`) creates a SQL alias, in this case an anonymous one which will have a generated name.
 
-* Filter on explicit column criterion, combined with a join.  Column criterion can make usage of all supported SQL operators and expression constructs:
+Once we have our statement, it behaves like a `Table` construct, which we created for `users` at the top of this tutorial.  The columns on the statement are accessible through an attribute called `c`:
 
-        {python}
-        {sql}>>> session.query(User).join('addresses').\
-        ...    filter(Address.email_address=='jack@google.com').all()
-        SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password 
-        FROM users JOIN addresses ON users.id = addresses.user_id 
-        WHERE addresses.email_address = ? ORDER BY users.oid
-        ['jack@google.com']
-        {stop}[<User('jack','Jack Bean', 'gjffdd')>]
+    {python}
+    {sql}>>> for u, count in session.query(User, stmt.c.address_count).outerjoin((stmt, User.id==stmt.c.user_id)): # doctest: +NORMALIZE_WHITESPACE
+    ...     print u, count
+    SELECT users.id AS users_id, users.name AS users_name, 
+    users.fullname AS users_fullname, users.password AS users_password, 
+    anon_1.address_count AS anon_1_address_count 
+    FROM users LEFT OUTER JOIN (SELECT addresses.user_id AS user_id, count(?) AS address_count 
+    FROM addresses GROUP BY addresses.user_id) AS anon_1 ON users.id = anon_1.user_id 
+    ORDER BY users.oid
+    ['*']
+    {stop}<User('ed','Ed Jones', 'f8s7ccs')> None
+    <User('wendy','Wendy Williams', 'foobar')> None
+    <User('mary','Mary Contrary', 'xxg527')> None
+    <User('fred','Fred Flinstone', 'blah')> None
+    <User('jack','Jack Bean', 'gjffdd')> 2
 
-    Criterion placed in `filter()` usually correspond to the last `join()` call; if the join was specified with `aliased=True`, class-level criterion against the join's target (or targets) will be appropriately aliased as well.  
+### Using EXISTS
 
-        {python}
-        {sql}>>> session.query(User).join('addresses', aliased=True).\
-        ...    filter(Address.email_address=='jack@google.com').all()
-        SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password 
-        FROM users JOIN addresses AS addresses_1 ON users.id = addresses_1.user_id 
-        WHERE addresses_1.email_address = ? ORDER BY users.oid
-        ['jack@google.com']
-        {stop}[<User('jack','Jack Bean', 'gjffdd')>]
+The EXISTS keyword in SQL is a boolean operator which returns True if the given expression contains any rows.  It may be used in many scenarios in place of joins, and is also useful for locating rows which do not have a corresponding row in a related table.
 
-* Filter_by on key=value criterion, combined with a join.  Same as `filter()` on column criterion except keyword arguments are used.
+There is an explicit EXISTS construct, which looks like this:
 
-        {python}
-        {sql}>>> session.query(User).join('addresses').\
-        ...    filter_by(email_address='jack@google.com').all()
-        SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password 
-        FROM users JOIN addresses ON users.id = addresses.user_id 
-        WHERE addresses.email_address = ? ORDER BY users.oid
-        ['jack@google.com']
-        {stop}[<User('jack','Jack Bean', 'gjffdd')>]
-    
-* Filter on explicit column criterion using `any()` (for collections) or `has()` (for scalar relations).  This is a more succinct method than joining, as an `EXISTS` subquery is generated automatically.  `any()` means, "find all parent items where any child item of its collection meets this criterion":
+    {python}
+    >>> from sqlalchemy.sql import exists
+    >>> stmt = exists().where(Address.user_id==User.id)
+    {sql}>>> for name, in session.query(User.name).filter(stmt):   # doctest: +NORMALIZE_WHITESPACE
+    ...     print name
+    SELECT users.name AS users_name 
+    FROM users 
+    WHERE EXISTS (SELECT * 
+    FROM addresses 
+    WHERE addresses.user_id = users.id)
+    []
+    {stop}jack
 
-        {python}
-        {sql}>>> session.query(User).\
-        ...    filter(User.addresses.any(Address.email_address=='jack@google.com')).all()
-        SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password 
-        FROM users 
-        WHERE EXISTS (SELECT 1 
-        FROM addresses 
-        WHERE users.id = addresses.user_id AND addresses.email_address = ?) ORDER BY users.oid
-        ['jack@google.com']
-        {stop}[<User('jack','Jack Bean', 'gjffdd')>]
+The `Query` features several operators which make usage of EXISTS automatically.  Above, the statement can be expressed along the `User.addresses` relation using `any()`:
+
+    {python}
+    {sql}>>> for name, in session.query(User.name).filter(User.addresses.any()):   # doctest: +NORMALIZE_WHITESPACE
+    ...     print name
+    SELECT users.name AS users_name 
+    FROM users 
+    WHERE EXISTS (SELECT 1 
+    FROM addresses 
+    WHERE users.id = addresses.user_id)
+    []
+    {stop}jack
 
-    `has()` means, "find all parent items where the child item meets this criterion":
+`any()` takes criterion as well, to limit the rows matched:
 
-        {python}
-        {sql}>>> session.query(Address).\
-        ...    filter(Address.user.has(User.name=='jack')).all()
-        SELECT addresses.id AS addresses_id, addresses.email_address AS addresses_email_address, addresses.user_id AS addresses_user_id 
-        FROM addresse
-        WHERE EXISTS (SELECT 1 
-        FROM user
-        WHERE users.id = addresses.user_id AND users.name = ?) ORDER BY addresses.oid
-        ['jack']
-        {stop}[<Address('jack@google.com')>, <Address('j25@yahoo.com')>]
+    {python}
+    {sql}>>> for name, in session.query(User.name).filter(User.addresses.any(Address.email_address.like('%google%'))):   # doctest: +NORMALIZE_WHITESPACE
+    ...     print name
+    SELECT users.name AS users_name 
+    FROM user
+    WHERE EXISTS (SELECT 1 
+    FROM addresse
+    WHERE users.id = addresses.user_id AND addresses.email_address LIKE ?)
+    ['%google%']
+    {stop}jack
 
-    Both `has()` and `any()` also accept keyword arguments which are interpreted against the child classes' attributes:
+`has()` is the same operator as `any()` for many-to-one relations (note the `~` operator here too, which means "NOT"):
 
-        {python}
-        {sql}>>> session.query(User).\
-        ...    filter(User.addresses.any(email_address='jack@google.com')).all()
-        SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_passwor
-        FROM user
-        WHERE EXISTS (SELECT 1 
-        FROM addresse
-        WHERE users.id = addresses.user_id AND addresses.email_address = ?) ORDER BY users.oid
-        ['jack@google.com']
-        {stop}[<User('jack','Jack Bean', 'gjffdd')>]
+    {python}
+    {sql}>>> session.query(Address).filter(~Address.user.has(User.name=='jack')).all() # doctest: +NORMALIZE_WHITESPACE
+    SELECT addresses.id AS addresses_id, addresses.email_address AS addresses_email_address, 
+    addresses.user_id AS addresses_user_i
+    FROM addresse
+    WHERE NOT (EXISTS (SELECT 1 
+    FROM user
+    WHERE users.id = addresses.user_id AND users.name = ?)) ORDER BY addresses.oid
+    ['jack']
+    {stop}[]
     
-* Filter_by on instance identity criterion.  When comparing to a related instance, `filter_by()` will in most cases not need to reference the child table, since a child instance already contains enough information with which to generate criterion against the parent table.  `filter_by()` uses an equality comparison for all relationship types.  For many-to-one and one-to-one, this represents all objects which reference the given child object:
-
-        {python}
-        # locate a user
-        {sql}>>> user = session.query(User).filter(User.name=='jack').one() #doctest: +NORMALIZE_WHITESPACE
-        SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password 
-        FROM users 
-        WHERE users.name = ? ORDER BY users.oid 
-        LIMIT 2 OFFSET 0
-        ['jack']
-        {stop}
-        
-        # use the user in a filter_by() expression
-        {sql}>>> session.query(Address).filter_by(user=user).all()
-        SELECT addresses.id AS addresses_id, addresses.email_address AS addresses_email_address, addresses.user_id AS addresses_user_id 
-        FROM addresses 
-        WHERE ? = addresses.user_id ORDER BY addresses.oid
-        [5]
-        {stop}[<Address('jack@google.com')>, <Address('j25@yahoo.com')>]
+### Common Relation Operators {@name=relationop}
 
-    For one-to-many and many-to-many, it represents all objects which contain the given child object in the related collection:
+Here's all the operators which build on relations:
 
+  * equals (used for many-to-one)
+    
         {python}
-        # locate an address
-        {sql}>>> address = session.query(Address).\
-        ...    filter(Address.email_address=='jack@google.com').one() #doctest: +NORMALIZE_WHITESPACE
-        SELECT addresses.id AS addresses_id, addresses.email_address AS addresses_email_address, addresses.user_id AS addresses_user_id 
-        FROM addresses 
-        WHERE addresses.email_address = ? ORDER BY addresses.oid 
-        LIMIT 2 OFFSET 0
-        {stop}['jack@google.com']
-    
-        # use the address in a filter_by expression
-        {sql}>>> session.query(User).filter_by(addresses=address).all()
-        SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password 
-        FROM users 
-        WHERE users.id = ? ORDER BY users.oid
-        [5]
-        {stop}[<User('jack','Jack Bean', 'gjffdd')>]
-
-* Select instances with a particular parent.  This is the "reverse" operation of filtering by instance identity criterion; the criterion is against a relation pointing *to* the desired class, instead of one pointing *from* it.  This will utilize the same "optimized" query criterion, usually not requiring any joins:
+        query.filter(Address.user == someuser)
+    
+  * not equals (used for many-to-one)
 
         {python}
-        {sql}>>> session.query(Address).with_parent(user, property='addresses').all()
-        SELECT addresses.id AS addresses_id, addresses.email_address AS addresses_email_address, addresses.user_id AS addresses_user_id 
-        FROM addresses 
-        WHERE ? = addresses.user_id ORDER BY addresses.oid
-        [5]
-        {stop}[<Address('jack@google.com')>, <Address('j25@yahoo.com')>]
-        
-* Filter on a many-to-one/one-to-one instance identity criterion.  The class-level `==` operator will act the same as `filter_by()` for a scalar relation:
+        query.filter(Address.user != someuser)
 
+  * IS NULL (used for many-to-one)
+    
         {python}
-        {sql}>>> session.query(Address).filter(Address.user==user).all()
-        SELECT addresses.id AS addresses_id, addresses.email_address AS addresses_email_address, addresses.user_id AS addresses_user_id 
-        FROM addresses 
-        WHERE ? = addresses.user_id ORDER BY addresses.oid
-        [5]
-        {stop}[<Address('jack@google.com')>, <Address('j25@yahoo.com')>]
-
-    whereas the `!=` operator will generate a negated EXISTS clause:
-
+        query.filter(Address.user == None)
+        
+  * contains (used for one-to-many and many-to-many collections)
+    
         {python}
-        {sql}>>> session.query(Address).filter(Address.user!=user).all()
-        SELECT addresses.id AS addresses_id, addresses.email_address AS addresses_email_address, addresses.user_id AS addresses_user_id 
-        FROM addresses 
-        WHERE NOT (EXISTS (SELECT 1 
-        FROM users 
-        WHERE users.id = addresses.user_id AND users.id = ?)) ORDER BY addresses.oid
-        [5]
-        {stop}[]
-
-    a comparison to `None` also generates an IS NULL clause for a many-to-one relation:
-
+        query.filter(User.addresses.contains(someaddress))
+    
+  * any (used for one-to-many and many-to-many collections)
+    
         {python}
-        {sql}>>> session.query(Address).filter(Address.user==None).all()
-        SELECT addresses.id AS addresses_id, addresses.email_address AS addresses_email_address, addresses.user_id AS addresses_user_id 
-        FROM addresses 
-        WHERE addresses.user_id IS NULL ORDER BY addresses.oid
-        []
-        {stop}[]
-
-* Filter on a one-to-many instance identity criterion.  The `contains()` operator returns all parent objects which contain the given object as one of its collection members:
-
+        query.filter(User.addresses.any(Address.email_address == 'bar'))
+        
+        # also takes keyword arguments:
+        query.filter(User.addresses.any(email_address='bar'))
+    
+  * has (used for many-to-one)
+    
         {python}
-        {sql}>>> session.query(User).filter(User.addresses.contains(address)).all()
-        SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password 
-        FROM users 
-        WHERE users.id = ? ORDER BY users.oid
-        [5]
-        {stop}[<User('jack','Jack Bean', 'gjffdd')>]
-
-* Filter on a multiple one-to-many instance identity criterion.  The `==` operator can be used with a collection-based attribute against a list of items, which will generate multiple `EXISTS` clauses:
-
+        query.filter(Address.user.has(name='ed'))
+    
+  * with_parent (used for any relation)
+    
         {python}
-        {sql}>>> addresses = session.query(Address).filter(Address.user==user).all()
-        SELECT addresses.id AS addresses_id, addresses.email_address AS addresses_email_address, addresses.user_id AS addresses_user_id 
-        FROM addresses 
-        WHERE ? = addresses.user_id ORDER BY addresses.oid
-        [5]
-        {stop}
-        
-        {sql}>>> session.query(User).filter(User.addresses == addresses).all()
-        SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password 
-        FROM users 
-        WHERE (EXISTS (SELECT 1 
-        FROM addresses 
-        WHERE users.id = addresses.user_id AND addresses.id = ?)) AND (EXISTS (SELECT 1 
-        FROM addresses 
-        WHERE users.id = addresses.user_id AND addresses.id = ?)) ORDER BY users.oid
-        [1, 2]
-        {stop}[<User('jack','Jack Bean', 'gjffdd')>]
+        session.query(Address).with_parent(someuser, 'addresses')
 
 ## Deleting
 
@@ -870,7 +809,7 @@ Let's try to delete `jack` and see how that goes.  We'll mark as deleted in the
     [None, 2]
     DELETE FROM users WHERE users.id = ?
     [5]
-    SELECT count(users.id) AS count_1
+    SELECT count(1) AS count_1 
     FROM users 
     WHERE users.name = ?
     ['jack']
@@ -882,39 +821,42 @@ So far, so good.  How about Jack's `Address` objects ?
     {sql}>>> session.query(Address).filter(
     ...     Address.email_address.in_(['jack@google.com', 'j25@yahoo.com'])
     ...  ).count() # doctest: +NORMALIZE_WHITESPACE
-    SELECT count(addresses.id) AS count_1
+    SELECT count(1) AS count_1
     FROM addresses 
     WHERE addresses.email_address IN (?, ?)
     ['jack@google.com', 'j25@yahoo.com']
     {stop}2
     
-Uh oh, they're still there !  Analyzing the flush SQL, we can see that the `user_id` column of each address was set to NULL, but the rows weren't deleted.  SQLAlchemy doesn't assume that deletes cascade, you have to tell it so.
+Uh oh, they're still there !  Analyzing the flush SQL, we can see that the `user_id` column of each address was set to NULL, but the rows weren't deleted.  SQLAlchemy doesn't assume that deletes cascade, you have to tell it to do so.
 
-So let's rollback our work, and start fresh with new mappers that express the relationship the way we want:
+### Configuring delete/delete-orphan Cascade {@name=cascade}
+
+We will configure **cascade** options on the `User.addresses` relation to change the behavior.  While SQLAlchemy allows you to add new attributes and relations to mappings at any point in time, in this case the existing relation needs to be removed, so we need to tear down the mappings completely and start again.  This is not a typical operation and is here just for illustrative purposes.
+
+Removing all ORM state is as follows:
 
     {python}
-    {sql}>>> session.rollback()  # roll back the transaction
-    ROLLBACK
-    
-    >>> session.clear() # clear the session
+    >>> session.close()  # roll back and close the transaction
+    >>> from sqlalchemy.orm import clear_mappers
     >>> clear_mappers() # clear mappers
     
-We need to tell the `addresses` relation on `User` that we'd like session.delete() operations to cascade down to the child `Address` objects.  Further, we also want `Address` objects which get detached from their parent `User`, whether or not the parent is deleted, to be deleted.  For these behaviors we use two **cascade options** `delete` and `delete-orphan`, using the string-based `cascade` option to the `relation()` function:
+Below, we use `mapper()` to reconfigure an ORM mapping for `User` and `Address`, on our existing but currently un-mapped classes.  The `User.addresses` relation now has `delete, delete-orphan` cascade on it, which indicates that DELETE operations will cascade to attached `Address` objects as well as `Address` objects which are removed from their parent:
 
     {python}
     >>> mapper(User, users_table, properties={    # doctest: +ELLIPSIS
     ...     'addresses':relation(Address, backref='user', cascade="all, delete, delete-orphan")
     ... })
-    <sqlalchemy.orm.mapper.Mapper object at 0x...>
+    <Mapper at 0x...; User>
     
+    >>> addresses_table = Address.__table__
     >>> mapper(Address, addresses_table) # doctest: +ELLIPSIS
-    <sqlalchemy.orm.mapper.Mapper object at 0x...>
+    <Mapper at 0x...; Address>
 
-Now when we load Jack, removing an address from his `addresses` collection will result in that `Address` being deleted:
+Now when we load Jack (below using `get()`, which loads by primary key), removing an address from his `addresses` collection will result in that `Address` being deleted:
 
     {python}
     # load Jack by primary key
-    {sql}>>> jack = session.query(User).get(jack.id)    #doctest: +NORMALIZE_WHITESPACE
+    {sql}>>> jack = session.query(User).get(5)    #doctest: +NORMALIZE_WHITESPACE
     BEGIN
     SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, users.password AS users_password 
     FROM users 
@@ -929,14 +871,14 @@ Now when we load Jack, removing an address from his `addresses` collection will
     WHERE ? = addresses.user_id ORDER BY addresses.oid
     [5]
     {stop}
-    
+
     # only one address remains
     {sql}>>> session.query(Address).filter(
     ...     Address.email_address.in_(['jack@google.com', 'j25@yahoo.com'])
     ... ).count() # doctest: +NORMALIZE_WHITESPACE
     DELETE FROM addresses WHERE addresses.id = ?
     [2]
-    SELECT count(addresses.id) AS count_1
+    SELECT count(1) AS count_1
     FROM addresses 
     WHERE addresses.email_address IN (?, ?)
     ['jack@google.com', 'j25@yahoo.com']
@@ -947,17 +889,12 @@ Deleting Jack will delete both Jack and his remaining `Address`:
     {python}
     >>> session.delete(jack)
     
-    {sql}>>> session.commit()
+    {sql}>>> session.query(User).filter_by(name='jack').count() # doctest: +NORMALIZE_WHITESPACE
     DELETE FROM addresses WHERE addresses.id = ?
     [1]
     DELETE FROM users WHERE users.id = ?
     [5]
-    COMMIT
-    {stop}
-    
-    {sql}>>> session.query(User).filter_by(name='jack').count() # doctest: +NORMALIZE_WHITESPACE
-    BEGIN
-    SELECT count(users.id) AS count_1
+    SELECT count(1) AS count_1
     FROM users 
     WHERE users.name = ?
     ['jack']
@@ -966,7 +903,7 @@ Deleting Jack will delete both Jack and his remaining `Address`:
     {sql}>>> session.query(Address).filter(
     ...    Address.email_address.in_(['jack@google.com', 'j25@yahoo.com'])
     ... ).count() # doctest: +NORMALIZE_WHITESPACE
-    SELECT count(addresses.id) AS count_1
+    SELECT count(1) AS count_1
     FROM addresses 
     WHERE addresses.email_address IN (?, ?)
     ['jack@google.com', 'j25@yahoo.com']
@@ -976,25 +913,59 @@ Deleting Jack will delete both Jack and his remaining `Address`:
 
 We're moving into the bonus round here, but lets show off a many-to-many relationship.  We'll sneak in some other features too, just to take a tour.  We'll make our application a blog application, where users can write `BlogPost`s, which have `Keywords` associated with them.
 
-First some new tables:
+The declarative setup is as follows:
 
     {python}
     >>> from sqlalchemy import Text
-    >>> post_table = Table('posts', metadata, 
-    ...        Column('id', Integer, primary_key=True),
-    ...        Column('user_id', Integer, ForeignKey('users.id')),
-    ...        Column('headline', String(255), nullable=False),
-    ...        Column('body', Text)
-    ...        )
-    
+
+    >>> # association table
     >>> post_keywords = Table('post_keywords', metadata,
-    ...        Column('post_id', Integer, ForeignKey('posts.id')),
-    ...        Column('keyword_id', Integer, ForeignKey('keywords.id')))
-    
-    >>> keywords_table = Table('keywords', metadata,
-    ...        Column('id', Integer, primary_key=True),
-    ...        Column('keyword', String(50), nullable=False, unique=True))
+    ...     Column('post_id', Integer, ForeignKey('posts.id')),
+    ...     Column('keyword_id', Integer, ForeignKey('keywords.id'))
+    ... )
+
+    >>> class BlogPost(Base):
+    ...     __tablename__ = 'posts'
+    ...
+    ...     id = Column(Integer, primary_key=True)
+    ...     user_id = Column(Integer, ForeignKey('users.id'))
+    ...     headline = Column(String(255), nullable=False)
+    ...     body = Column(Text)
+    ...
+    ...     # many to many BlogPost<->Keyword
+    ...     keywords = relation('Keyword', secondary=post_keywords, backref='posts')
+    ...
+    ...     def __init__(self, headline, body, author):
+    ...         self.author = author
+    ...         self.headline = headline
+    ...         self.body = body
+    ...
+    ...     def __repr__(self):
+    ...         return "BlogPost(%r, %r, %r)" % (self.headline, self.body, self.author)
+
+    >>> class Keyword(Base):
+    ...     __tablename__ = 'keywords'
+    ...
+    ...     id = Column(Integer, primary_key=True)
+    ...     keyword = Column(String(50), nullable=False, unique=True)
+    ...
+    ...     def __init__(self, keyword):
+    ...         self.keyword = keyword
+
+Above, the many-to-many relation above is `BlogPost.keywords`.  The defining feature of a many to many relation is the `secondary` keyword argument which references a `Table` object representing the association table.  This table only contains columns which reference the two sides of the relation; if it has *any* other columns, such as its own primary key, or foreign keys to other tables, SQLAlchemy requires a different usage pattern called the "association object", described at [advdatamapping_relation_patterns_association](rel:advdatamapping_relation_patterns_association).
+
+The many-to-many relation is also bi-directional using the `backref` keyword.  This is the one case where usage of `backref` is generally required, since if a separate `posts` relation were added to the `Keyword` entity, both relations would independently add and remove rows from the `post_keywords` table and produce conflicts.
+
+We would also like our `BlogPost` class to have an `author` field.  We will add this as another bidirectional relationship, except one issue we'll have is that a single user might have lots of blog posts.  When we access `User.posts`, we'd like to be able to filter results further so as not to load the entire collection.  For this we use a setting accepted by `relation()` called `lazy='dynamic'`, which configures an alternate **loader strategy** on the attribute.  To use it on the "reverse" side of a `relation()`, we use the `backref()` function:
+
+    {python}
+    >>> from sqlalchemy.orm import backref
+    >>> # "dynamic" loading relation to User
+    >>> BlogPost.author = relation(User, backref=backref('posts', lazy='dynamic'))
+
+Create new tables:
     
+    {python}
     {sql}>>> metadata.create_all(engine) # doctest: +NORMALIZE_WHITESPACE
     PRAGMA table_info("users")
     {}
@@ -1033,41 +1004,6 @@ First some new tables:
     {}
     COMMIT
 
-Then some classes:
-
-    {python}
-    >>> class BlogPost(object):
-    ...     def __init__(self, headline, body, author):
-    ...         self.author = author
-    ...         self.headline = headline
-    ...         self.body = body
-    ...     def __repr__(self):
-    ...         return "BlogPost(%r, %r, %r)" % (self.headline, self.body, self.author)
-    
-    >>> class Keyword(object):
-    ...     def __init__(self, keyword):
-    ...         self.keyword = keyword
-            
-And the mappers.  `BlogPost` will reference `User` via its `author` attribute:
-
-    {python}
-    >>> from sqlalchemy.orm import backref
-    
-    >>> mapper(Keyword, keywords_table) # doctest: +ELLIPSIS
-    <sqlalchemy.orm.mapper.Mapper object at 0x...>
-    
-    >>> mapper(BlogPost, post_table, properties={   # doctest: +ELLIPSIS
-    ...    'author':relation(User, backref=backref('posts', lazy='dynamic')),
-    ...    'keywords':relation(Keyword, secondary=post_keywords)
-    ... }) 
-    <sqlalchemy.orm.mapper.Mapper object at 0x...>
-    
-There's three new things in the above mapper:
-
-  * the `User` relation has a backref, like we've used before, except this time it references a function called `backref()`.  This function is used when yo'd like to specify keyword options for the backwards relationship.
-  * the keyword option we specified to `backref()` is `lazy="dynamic"`.  This sets a default **loader strategy** on the attribute, in this case a special strategy that allows partial loading of results.
-  * The `keywords` relation uses a keyword argument `secondary` to indicate the **association table** for the many to many relationship from `BlogPost` to `Keyword`.
-    
 Usage is not too different from what we've been doing.  Let's give Wendy some blog posts:
 
     {python}
@@ -1079,7 +1015,7 @@ Usage is not too different from what we've been doing.  Let's give Wendy some bl
     ['wendy']
     
     >>> post = BlogPost("Wendy's Blog Post", "This is a test", wendy)
-    >>> session.save(post)
+    >>> session.add(post)
     
 We're storing keywords uniquely in the database, but we know that we don't have any yet, so we can just create them:
 
@@ -1087,7 +1023,7 @@ We're storing keywords uniquely in the database, but we know that we don't have
     >>> post.keywords.append(Keyword('wendy'))
     >>> post.keywords.append(Keyword('firstpost'))
     
-We can now look up all blog posts with the keyword 'firstpost'.   We'll use a special collection operator `any` to locate "blog posts where any of its keywords has the keyword string 'firstpost'":
+We can now look up all blog posts with the keyword 'firstpost'.   We'll use the `any` operator to locate "blog posts where any of its keywords has the keyword string 'firstpost'":
 
     {python}
     {sql}>>> session.query(BlogPost).filter(BlogPost.keywords.any(keyword='firstpost')).all()
@@ -1110,7 +1046,7 @@ We can now look up all blog posts with the keyword 'firstpost'.   We'll use a sp
 If we want to look up just Wendy's posts, we can tell the query to narrow down to her as a parent:
 
     {python}
-    {sql}>>> session.query(BlogPost).with_parent(wendy).\
+    {sql}>>> session.query(BlogPost).filter(BlogPost.author==wendy).\
     ... filter(BlogPost.keywords.any(keyword='firstpost')).all()
     SELECT posts.id AS posts_id, posts.user_id AS posts_user_id, posts.headline AS posts_headline, posts.body AS posts_body 
     FROM posts 
index 0e94cef244d3d6a670dd35db9bdeea6b75e9f536..f9d8f4ed99ccad453fe9fbfd618fb31c0e9ff4ea 100644 (file)
@@ -1,4 +1,4 @@
-    Using the Session {@name=unitofwork}
+Using the Session {@name=unitofwork}
 ============
 
 The [Mapper](rel:advdatamapping) is the entrypoint to the configurational API of the SQLAlchemy object relational mapper.  But the primary object one works with when using the ORM is the [Session](rel:docstrings_sqlalchemy.orm.session_Session).
@@ -7,11 +7,11 @@ The [Mapper](rel:advdatamapping) is the entrypoint to the configurational API of
 
 In the most general sense, the `Session` establishes all conversations with the database and represents a "holding zone" for all the mapped instances which you've loaded or created during its lifespan.  It implements the [Unit of Work](http://martinfowler.com/eaaCatalog/unitOfWork.html) pattern, which means it keeps track of all changes which occur, and is capable of **flushing** those changes to the database as appropriate.   Another important facet of the `Session` is that it's also maintaining **unique** copies of each instance, where "unique" means "only one object with a particular primary key" - this pattern is called the [Identity Map](http://martinfowler.com/eaaCatalog/identityMap.html).
 
-Beyond that, the `Session` implements an interface which let's you move objects in or out of the session in a variety of ways, it provides the entryway to a `Query` object which is used to query the database for data, it is commonly used to provide transactional boundaries (though this is optional), and it also can serve as a configurational "home base" for one or more `Engine` objects, which allows various vertical and horizontal partitioning strategies to be achieved.
+Beyond that, the `Session` implements an interface which let's you move objects in or out of the session in a variety of ways, it provides the entryway to a `Query` object which is used to query the database for data, and it also provides a transactional context for SQL operations which rides on top of the transactional capabilities of `Engine` and `Connection` objects.
 
 ## Getting a Session
 
-The `Session` object exists just as a regular Python object, which can be directly instantiated.  However, it takes a fair amount of keyword options, several of which you probably want to set explicitly.  It's fairly inconvenient to deal with the "configuration" of a session every time you want to create one.  Therefore, SQLAlchemy recommends the usage of a helper function called `sessionmaker()`, which typically you call only once for the lifespan of an application.  This function creates a customized `Session` subclass for you, with your desired configurational arguments pre-loaded.  Then, whenever you need a new `Session`, you use your custom `Session` class with no arguments to create the session.
+`Session` is a regular Python class which can be directly instantiated.  However, to standardize how sessions are configured and acquired, the `sessionmaker()` function is normally used to create a top level `Session` configuration which can then be used throughout an application without the need to repeat the configurational arguments.
 
 ### Using a sessionmaker() Configuration {@name=sessionmaker}
 
@@ -21,43 +21,30 @@ The usage of `sessionmaker()` is illustrated below:
     from sqlalchemy.orm import sessionmaker
     
     # create a configured "Session" class
-    Session = sessionmaker(autoflush=True, transactional=True)
+    Session = sessionmaker(bind=some_engine)
 
     # create a Session
     sess = Session()
     
     # work with sess
-    sess.save(x)
+    myobject = MyObject('foo', 'bar')
+    sess.add(myobject)
     sess.commit()
     
     # close when finished
     sess.close()
 
-Above, the `sessionmaker` call creates a class for us, which we assign to the name `Session`.  This class is a subclass of the actual `sqlalchemy.orm.session.Session` class, which will instantiate with the arguments of `autoflush=True` and `transactional=True`.
+Above, the `sessionmaker` call creates a class for us, which we assign to the name `Session`.  This class is a subclass of the actual `sqlalchemy.orm.session.Session` class, which will instantiate with a particular bound engine.
 
 When you write your application, place the call to `sessionmaker()` somewhere global, and then make your new `Session` class available to the rest of your application.
 
-### Binding Session to an Engine or Connection {@name=binding}
+### Binding Session to an Engine {@name=binding}
 
-In our previous example regarding `sessionmaker()`, nowhere did we specify how our session would connect to our database.  When the session is configured in this manner, it will look for a database engine to connect with via the `Table` objects that it works with - the chapter called [metadata_tables_binding](rel:metadata_tables_binding) describes how to associate `Table` objects directly with a source of database connections.
-
-However, it is often more straightforward to explicitly tell the session what database engine (or engines) you'd like it to communicate with.  This is particularly handy with multiple-database scenarios where the session can be used as the central point of configuration.  To achieve this, the constructor keyword `bind` is used for a basic single-database configuration:
-
-    {python}
-    # create engine
-    engine = create_engine('postgres://...')
-    
-    # bind custom Session class to the engine
-    Session = sessionmaker(bind=engine, autoflush=True, transactional=True)
-    
-    # work with the session
-    sess = Session()
-    
-One common issue with the above scenario is that an application will often organize its global imports before it ever connects to a database.  Since the `Session` class created by `sessionmaker()` is meant to be a global application object (note we are saying the session *class*, not a session *instance*), we may not have a `bind` argument available.  For this, the `Session` class returned by `sessionmaker()` supports post-configuration of all options, through its method `configure()`:
+In our previous example regarding `sessionmaker()`, we specified a `bind` for a particular `Engine`.  If we'd like to construct a `sessionmaker()` without an engine available and bind it later on, or to specify other options to an existing `sessionmaker()`, we may use the `configure()` method:
 
     {python}
     # configure Session class with desired options
-    Session = sessionmaker(autoflush=True, transactional=True)
+    Session = sessionmaker()
 
     # later, we create the engine
     engine = create_engine('postgres://...')
@@ -68,16 +55,17 @@ One common issue with the above scenario is that an application will often organ
     # work with the session
     sess = Session()
 
-The `Session` also has the ability to be bound to multiple engines.   Descriptions of these scenarios are described in [unitofwork_partitioning](rel:unitofwork_partitioning).
+It's actually entirely optional to bind a Session to an engine.  If the underlying mapped `Table` objects use "bound" metadata, the `Session` will make use of the bound engine instead (or will even use multiple engines if multiple binds are present within the mapped tables).  "Bound" metadata is described at [metadata_tables_binding](rel:metadata_tables_binding).
 
+The `Session` also has the ability to be bound to multiple engines explicitly.   Descriptions of these scenarios are described in [unitofwork_partitioning](rel:unitofwork_partitioning).
 
-#### Binding Session to a Connection {@name=connection}
+### Binding Session to a Connection {@name=connection}
 
-The examples involving `bind` so far are dealing with the `Engine` object, which is, like the `Session` class itself, a global configurational object.  The `Session` can also be bound to an individual database `Connection`.  The reason you might want to do this is if your application controls the boundaries of transactions using distinct `Transaction` objects (these objects are described in [dbengine_transactions](rel:dbengine_transactions)).  You'd have a transactional `Connection`, and then you'd want to work with an ORM-level `Session` which participates in that transaction.  Since `Connection` is definitely not a globally-scoped object in all but the most rudimental commandline applications, you can bind an individual `Session()` instance to a particular `Connection` not at class configuration time, but at session instance construction time:
+The `Session` can also be explicitly bound to an individual database `Connection`.  Reasons for doing this may include to join a `Session` with an ongoing transaction local to a specific `Connection` object, or to bypass connection pooling by just having connections persistently checked out and associated with distinct, long running sessions:
 
     {python}
     # global application scope.  create Session class, engine
-    Session = sessionmaker(autoflush=True, transactional=True)
+    Session = sessionmaker()
 
     engine = create_engine('postgres://...')
     
@@ -93,44 +81,24 @@ The examples involving `bind` so far are dealing with the `Engine` object, which
 
 ### Using create_session() {@name=createsession}
 
-As an alternative to `sessionmaker()`, `create_session()` exists literally as a function which calls the normal `Session` constructor directly.  All arguments are passed through and the new `Session` object is returned:
+As an alternative to `sessionmaker()`, `create_session()` is a function which calls the normal `Session` constructor directly.  All arguments are passed through and the new `Session` object is returned:
 
     {python}
-    session = create_session(bind=myengine)
-    
-The `create_session()` function doesn't add any functionality to the regular `Session`, it just sets up a default argument set of `autoflush=False, transactional=False`.  But also, by calling `create_session()` instead of instantiating `Session` directly, you leave room in your application to change the type of session which the function creates.  For example, an application which is calling `create_session()` in many places, which is typical for a pre-0.4 application, can be changed to use a `sessionmaker()` by just assigning the return of `sessionmaker()` to the `create_session` name:
+    session = create_session(bind=myengine, autocommit=True, autoflush=False)
 
-    {python}
-    # change from:
-    from sqlalchemy.orm import create_session
+### Configurational Arguments {@name=configuration}
 
-    # to:
-    create_session = sessionmaker()
+Configurational arguments accepted by `sessionmaker()` and `create_session()` are the same as that of the `Session` class itself, and are described at [docstrings_sqlalchemy.orm_modfunc_sessionmaker](rel:docstrings_sqlalchemy.orm_modfunc_sessionmaker).
 
 ## Using the Session 
 
-A typical session conversation starts with creating a new session, or acquiring one from an ongoing context.    You save new objects and load existing ones, make changes, mark some as deleted, and then persist your changes to the database.  If your session is transactional, you use `commit()` to persist any remaining changes and to commit the transaction.  If not, you call `flush()` which will flush any remaining data to the database.
-
-Below, we open a new `Session` using a configured `sessionmaker()`, make some changes, and commit:
-
-    {python}
-    # configured Session class
-    Session = sessionmaker(autoflush=True, transactional=True)
-    
-    sess = Session()
-    d = Data(value=10)
-    sess.save(d)
-    d2 = sess.query(Data).filter(Data.value==15).one()
-    d2.value = 19
-    sess.commit()
-
 ### Quickie Intro to Object States {@name=states}
 
 It's helpful to know the states which an instance can have within a session:
 
 * *Transient* - an instance that's not in a session, and is not saved to the database; i.e. it has no database identity.  The only relationship such an object has to the ORM is that its class has a `mapper()` associated with it.
 
-* *Pending* - when you `save()` a transient instance, it becomes pending.  It still wasn't actually flushed to the database yet, but it will be when the next flush occurs.
+* *Pending* - when you `add()` a transient instance, it becomes pending.  It still wasn't actually flushed to the database yet, but it will be when the next flush occurs.
 
 * *Persistent* - An instance which is present in the session and has a record in the database.  You get persistent instances by either flushing so that the pending instances become persistent, or by querying the database for existing instances (or moving persistent instances from other sessions into your local session).
 
@@ -152,7 +120,7 @@ Knowing these states is important, since the `Session` tries to be strict about
 
     You typically invoke `Session()` when you first need to talk to your database, and want to save some objects or load some existing ones.  Then, you work with it, save your changes, and then dispose of it....or at the very least `close()` it.  It's not a "global" kind of object, and should be handled more like a "local variable", as it's generally **not** safe to use with concurrent threads.  Sessions are very inexpensive to make, and don't use any resources whatsoever until they are first used...so create some !
 
-    There is also a pattern whereby you're using a **contextual session**, this is described later in [unitofwork_contextual](rel:unitofwork_contextual).  In this pattern, a helper object is maintaining a `Session` for you, most commonly one that is local to the current thread (and sometimes also local to an application instance).  SQLAlchemy 0.4 has worked this pattern out such that it still *looks* like you're creating a new session as you need one...so in that case, it's still a guaranteed win to just say `Session()` whenever you want a session.  
+    There is also a pattern whereby you're using a **contextual session**, this is described later in [unitofwork_contextual](rel:unitofwork_contextual).  In this pattern, a helper object is maintaining a `Session` for you, most commonly one that is local to the current thread (and sometimes also local to an application instance).  SQLAlchemy has worked this pattern out such that it still *looks* like you're creating a new session as you need one...so in that case, it's still a guaranteed win to just say `Session()` whenever you want a session.  
 
 * Is the Session a cache ? 
 
@@ -175,121 +143,66 @@ Knowing these states is important, since the `Session` tries to be strict about
 
     But the bigger point here is, you should not *want* to use the session with multiple concurrent threads.  That would be like having everyone at a restaurant all eat from the same plate.  The session is a local "workspace" that you use for a specific set of tasks; you don't want to, or need to, share that session with other threads who are doing some other task.  If, on the other hand, there are other threads  participating in the same task you are, such as in a desktop graphical application, then you would be sharing the session with those threads, but you also will have implemented a proper locking scheme (or your graphical framework does) so that those threads do not collide.
   
-### Session Attributes {@name=attributes} 
-
-The session provides a set of attributes and collection-oriented methods which allow you to view the current state of the session.
-
-The **identity map** is accessed by the `identity_map` attribute, which provides a dictionary interface.  The keys are "identity keys", which are attached to all persistent objects by the attribute `_instance_key`:
-
-    {python}
-    >>> myobject._instance_key 
-    (<class 'test.tables.User'>, (7,))
-
-    >>> myobject._instance_key in session.identity_map
-    True
-
-    >>> session.identity_map.values()
-    [<__main__.User object at 0x712630>, <__main__.Address object at 0x712a70>]
-
-The identity map is a weak-referencing dictionary by default.  This means that objects which are dereferenced on the outside will be removed from the session automatically.  Note that objects which are marked as "dirty" will not fall out of scope until after changes on them have been flushed; special logic kicks in at the point of auto-removal which ensures that no pending changes remain on the object, else a temporary strong reference is created to the object.
-
-Some people prefer objects to stay in the session until explicitly removed in all cases; for this,  you can specify the flag `weak_identity_map=False` to the `create_session` or `sessionmaker` functions so that the `Session` will use a regular dictionary.
-
-While the `identity_map` accessor is currently the actual dictionary used by the `Session` to store instances, you should not add or remove items from this dictionary.  Use the session methods `save_or_update()` and `expunge()` to add or remove items.
-
-The Session also supports an iterator interface in order to see all objects in the identity map:
-
-    {python}
-    for obj in session:
-        print obj
-
-As well as `__contains__()`:
-
-    {python}
-    if obj in session:
-        print "Object is present"
-
-The session is also keeping track of all newly created (i.e. pending) objects, all objects which have had changes since they were last loaded or saved (i.e. "dirty"), and everything that's been marked as deleted.  
-
-    {python}
-    # pending objects recently added to the Session
-    session.new
-    
-    # persistent objects which currently have changes detected
-    # (this collection is now created on the fly each time the property is called)
-    session.dirty
-
-    # persistent objects that have been marked as deleted via session.delete(obj)
-    session.deleted
-
 ### Querying
 
-The `query()` function takes one or more classes and/or mappers, along with an optional `entity_name` parameter, and returns a new `Query` object which will issue mapper queries within the context of this Session.  For each mapper is passed, the Query uses that mapper.  For each class, the Query will locate the primary mapper for the class using `class_mapper()`.
+The `query()` function takes one or more *entities* and returns a new `Query` object which will issue mapper queries within the context of this Session.  An entity is defined as a mapped class, a `Mapper` object, an orm-enabled *descriptor*, or an `AliasedClass` object (a future release will also include an `Entity` object for use with entity_name mappers).
 
     {python}
     # query from a class
     session.query(User).filter_by(name='ed').all()
 
     # query with multiple classes, returns tuples
-    session.query(User).add_entity(Address).join('addresses').filter_by(name='ed').all()
+    session.query(User, Address).join('addresses').filter_by(name='ed').all()
+
+    # query using orm-enabled descriptors
+    session.query(User.name, User.fullname).all()
     
     # query from a mapper
-    query = session.query(usermapper)
-    x = query.get(1)
-    
-    # query from a class mapped with entity name 'alt_users'
-    q = session.query(User, entity_name='alt_users')
-    y = q.options(eagerload('orders')).all()
-    
-`entity_name` is an optional keyword argument sent with a class object, in order to further qualify which primary mapper to be used; this only applies if there was a `Mapper` created with that particular class/entity name combination, else an exception is raised.  All of the methods on Session which take a class or mapper argument also take the `entity_name` argument, so that a given class can be properly matched to the desired primary mapper.
+    user_mapper = class_mapper(User)
+    session.query(user_mapper)
 
-All instances retrieved by the returned `Query` object will be stored as persistent instances within the originating `Session`.
+When `Query` returns results, each object instantiated is stored within the identity map.   When a row matches an object which is already present, the same object is returned.  In the latter case, whether or not the row is populated onto an existing object depends upon whether the attributes of the instance have been *expired* or not.  As of 0.5, a default-configured `Session` automatically expires all instances along transaction boundaries, so that with a normally isolated transaction, there shouldn't be any issue of instances representing data which is stale with regards to the current transaction.
 
-### Saving New Instances
+### Adding New or Existing Items
 
-`save()` is called with a single transient instance as an argument, which is then added to the Session and becomes pending.  When the session is next flushed, the instance will be saved to the database.  If the given instance is not transient, meaning it is either attached to an existing Session or it has a database identity, an exception is raised.
+`add()` is used to place instances in the session.  For *transient* (i.e. brand new) instances, this will have the effect of an INSERT taking place for those instances upon the next flush.  For instances which are *persistent* (i.e. were loaded by this session), they are already present and do not need to be added.  Instances which are *detached* (i.e. have been removed from a session) may be re-associated with a session using this method:
 
     {python}
     user1 = User(name='user1')
     user2 = User(name='user2')
-    session.save(user1)
-    session.save(user2)
+    session.add(user1)
+    session.add(user2)
     
     session.commit()     # write changes to the database
 
-There's also other ways to have objects saved to the session automatically; one is by using cascade rules, and the other is by using a contextual session.  Both of these are described later.
+To add a list of items to the session at once, use `add_all()`:
 
-### Updating/Merging Existing Instances
+    {python}
+    session.add_all([item1, item2, item3])
 
-The `update()` method is used when you have a detached instance, and you want to put it back into a `Session`.  Recall that "detached" means the object has a database identity.
+The `add()` operation **cascades** along the `save-update` cascade.  For more details see the section [unitofwork_cascades](rel:unitofwork_cascades).
 
-Since `update()` is a little picky that way, most people use `save_or_update()`, which checks for an `_instance_key` attribute, and based on whether it's there or not, calls either `save()` or `update()`:
+### Merging
 
-    {python}
-    # load user1 using session 1
-    user1 = sess1.query(User).get(5)
-    
-    # remove it from session 1
-    sess1.expunge(user1)
-    
-    # move it into session 2
-    sess2.save_or_update(user1)
+`merge()` reconciles the current state of an instance and its associated children with existing data in the database, and returns a copy of the instance associated with the session.  Usage is as follows:
 
-`update()` is also an operation that can happen automatically using cascade rules, just like `save()`.  
+    {python}
+    merged_object = session.merge(existing_object)
 
-`merge()` on the other hand is a little like `update()`, except it creates a **copy** of the given instance in the session, and returns to you that instance; the instance you send it never goes into the session.  `merge()` is much fancier than `update()`; it will actually look to see if an object with the same primary key is already present in the session, and if not will load it by primary key.  Then, it will merge the attributes of the given object into the one which it just located.
+When given an instance, it follows these steps:
 
-This method is useful for bringing in objects which may have been restored from a serialization, such as those stored in an HTTP session, where the object may be present in the session already:
+  * It examines the primary key of the instance.  If it's present, it attempts to load an instance with that primary key (or pulls from the local identity map).
+  * If there's no primary key on the given instance, or the given primary key does not exist in the database, a new instance is created.
+  * The state of the given instance is then copied onto the located/newly created instance.
+  * The operation is cascaded to associated child items along the `merge` cascade.  Note that all changes present on the given instance, including changes to collections, are merged.
+  * The new instance is returned.
 
-    {python}
-    # deserialize an object
-    myobj = pickle.loads(mystring)
+With `merge()`, the given instance is not placed within the session, and can be associated with a different session or detached.  `merge()` is very useful for taking the state of any kind of object structure without regard for its origins or current session associations and placing that state within a session.   Here's two examples:
 
-    # "merge" it.  if the session already had this object in the 
-    # identity map, then you get back the one from the current session.
-    myobj = session.merge(myobj)
+  * An application which reads an object structure from a file and wishes to save it to the database might parse the file, build up the structure, and then use `merge()` to save it to the database, ensuring that the data within the file is used to formulate the primary key of each element of the structure.  Later, when the file has changed, the same process can be re-run, producing a slightly different object structure, which can then be `merged()` in again, and the `Session` will automatically update the database to reflect those changes.
+  * A web application stores mapped entities within an HTTP session object.  When each request starts up, the serialized data can be merged into the session, so that the original entity may be safely shared among requests and threads.
 
-`merge()` includes an important option called `dont_load`.  When this boolean flag is set to `True`, the merge of a detached object will not force a `get()` of that object from the database.  Normally, `merge()` issues a `get()` for every existing object so that it can load the most recent state of the object, which is then modified according to the state of the given object.  With `dont_load=True`, the `get()` is skipped and `merge()` places an exact copy of the given object in the session.  This allows objects which were retrieved from a caching system to be copied back into a session without any SQL overhead being added.
+`merge()` is frequently used by applications which implement their own second level caches.  This refers to an application which uses an in memory dictionary, or an tool like Memcached to store objects over long running spans of time.  When such an object needs to exist within a `Session`, `merge()` is a good choice since it leaves the original cached object untouched.  For this use case, merge provides a keyword option called `dont_load=True`.  When this boolean flag is set to `True`, `merge()` will not issue any SQL to reconcile the given object against the current state of the database, thereby reducing query overhead.   The limitation is that the given object and all of its children may not contain any pending changes, and it's also of course possible that newer information in the database will not be present on the merged object, since no load is issued.
 
 ### Deleting
 
@@ -323,74 +236,58 @@ The solution is to use proper cascading:
 
 ### Flushing
 
-This is the main gateway to what the `Session` does best, which is save everything !  It should be clear by now what a flush looks like:
+When the `Session` is used with its default configuration, the flush step is nearly always done transparently.  Specifically, the flush occurs before any individual `Query` is issued, as well as within the `commit()` call before the transaction is committed.  This behavior can be disabled by constructing `sessionmaker()` with the flag `autoflush=False`.
+
+Regardless of the autoflush setting, a flush can always be forced by issing `flush()`:
     
     {python}
     session.flush()
     
-It also can be called with a list of objects; in this form, the flush operation will be limited only to the objects specified in the list:
+`flush()` also supports the ability to flush a subset of objects which are present in the session, by passing a list of objects:
 
     {python}
     # saves only user1 and address2.  all other modified
     # objects remain present in the session.
     session.flush([user1, address2])
     
-This second form of flush should be used carefully as it will not necessarily locate other dependent objects within the session, whose database representation may have foreign constraint relationships with the objects being operated upon.
-
-Theres also a way to have `flush()` called automatically before each query; this is called "autoflush" and is described below.
-
-Note that when using a `Session` that has been placed into a transaction, the `commit()` method will also `flush()` the `Session` unconditionally before committing the transaction.  
+This second form of flush should be used carefully as it currently does not cascade, meaning that it will not necessarily affect other objects directly associated with the objects given.
 
-Note that flush **does not change** the state of any collections or entity relationships in memory; for example, if you set a foreign key attribute `b_id` on object `A` with the identifier `B.id`, the change will be flushed to the database, but `A` will not have `B` added to its collection.  If you want to manipulate foreign key attributes directly, `refresh()` or `expire()` the objects whose state needs to be refreshed subsequent to flushing.
+The flush process *always* occurs within a transaction, even if the `Session` has been configured with `autocommit=True`, a setting that disables the session's persistent transactional state.  If no transaction is present, `flush()` creates its own transaction and commits it.  Any failures during flush will always result in a rollback of whatever transaction is present.
 
-### Autoflush
+### Committing
 
-A session can be configured to issue `flush()` calls before each query.  This allows you to immediately have DB access to whatever has been saved to the session.  It's recommended to use autoflush with `transactional=True`, that way an unexpected flush call won't permanently save to the database:
+`commit()` is used to commit the current transaction.  It always issues `flush()` beforehand to flush any remaining state to the database; this is independent of the "autoflush" setting.   If no transaction is present, it raises an error.  Note that the default behavior of the `Session` is that a transaction is always present; this behavior can be disabled by setting `autocommit=True`.  In autocommit mode, a transaction can be initiated by calling the `begin()` method.
 
-    {python}
-    Session = sessionmaker(autoflush=True, transactional=True)
-    sess = Session()
-    u1 = User(name='jack')
-    sess.save(u1)
-    
-    # reload user1
-    u2 = sess.query(User).filter_by(name='jack').one()
-    assert u2 is u1
+Another behavior of `commit()` is that by default it expires the state of all instances present after the commit is complete.  This is so that when the instances are next accessed, either through attribute access or by them being present in a `Query` result set, they receive the most recent state.  To disable this behavior, configure `sessionmaker()` with `autoexpire=False`.
 
-    # commit session, flushes whatever is remaining
-    sess.commit()
+Normally, instances loaded into the `Session` are never changed by subsequent queries; the assumption is that the current transaction is isolated so the state most recently loaded is correct as long as the transaction continues.  Setting `autocommit=True` works against this model to some degree since the `Session` behaves in exactly the same way with regard to attribute state, except no transaction is present.
 
-Autoflush is particularly handy when using "dynamic" mapper relations, so that changes to the underlying collection are immediately available via its query interface.
+### Rolling Back
 
-### Committing
+`rollback()` rolls back the current transaction.   With a default configured session, the post-rollback state of the session is as follows:
 
-The `commit()` method on `Session` is used specifically when the `Session` is in a transactional state.  The two ways that a session may be placed in a transactional state are to create it using the `transactional=True` option, or to call the `begin()` method.  
+  * All connections are rolled back and returned to the connection pool, unless the Session was bound directly to 
+  a Connection, in which case the connection is still maintained (but still rolled back).
+  * Objects which were initially in the *pending* state when they were added to the `Session` within the lifespan of the transaction are expunged, corresponding to their INSERT statement being rolled back.  The state of their attributes remains unchanged.
+  * Objects which were marked as *deleted* within the lifespan of the transaction are promoted back to the *persistent* state, corresponding to their DELETE statement being rolled back.  Note that if those objects were first *pending* within the transaction, that operation takes precedence instead.
+  * All objects not expunged are fully expired.  This aspect of the behavior may be disabled by configuring `sessionmaker()` with `autoexpire=False`.
 
-`commit()` serves **two** purposes; it issues a `flush()` unconditionally to persist any remaining pending changes, and it issues a commit to all currently managed database connections.  In the typical case this is just a single connection.  After the commit, connection resources which were allocated by the `Session` are released.  This holds true even for a `Session` which specifies `transactional=True`; when such a session is committed, the next transaction is not "begun" until the next database operation occurs.
+With that state understood, the `Session` may safely continue usage after a rollback occurs (note that this is a new feature as of version 0.5).
 
-See the section below on "Managing Transactions" for further detail.
+When a `flush()` fails, typically for reasons like primary key, foreign key, or "not nullable" constraint violations, a `rollback()` is issued automatically (it's currently not possible for a flush to continue after a partial failure).  However, the flush process always uses its own transactional demarcator called a *subtransaction*, which is described more fully in the docstrings for `Session`.  What it means here is that even though the database transaction has been rolled back, the end user must still issue `rollback()` to fully reset the state of the `Session`.
 
-### Expunge / Clear
+### Expunging
 
 Expunge removes an object from the Session, sending persistent instances to the detached state, and pending instances to the transient state:
 
     {python}
     session.expunge(obj1)
     
-Use `expunge` when you'd like to remove an object altogether from memory, such as before calling `del` on it, which will prevent any "ghost" operations occurring when the session is flushed.
-
-This `clear()` method is equivalent to `expunge()`-ing everything from the Session:
-    
-    {python}
-    session.clear()
-
-However note that the `clear()` method does not reset any transactional state or connection resources; therefore what you usually want to call instead of `clear()` is `close()`.    
+To remove all items, call `session.expunge_all()`.
 
 ### Closing
 
-The `close()` method issues a `clear()`, and releases any transactional/connection resources.  When connections are returned to the connection pool, whatever transactional state exists is rolled back.
-
-When `close()` is called, the `Session` is in the same state as when it was first created, and is safe to be used again.  `close()` is especially important when using a contextual session, which remains in memory after usage.  By issuing `close()`, the session will be clean for the next request that makes use of it.
+The `close()` method issues a `expunge_alll()`, and releases any transactional/connection resources.  When connections are returned to the connection pool, transactional state is rolled back as well.
 
 ### Refreshing / Expiring
 
@@ -418,6 +315,42 @@ To assist with the Session's "sticky" behavior of instances which are present, i
     session.expire(obj1, ['hello', 'world'])
     session.expire(obj2, ['hello', 'world'])
 
+The full contents of the session may be expired at once using `expire_all()`:
+
+    {python}
+    session.expire_all()
+
+`refresh()` and `expire()` are usually not needed when working with a default-configured `Session`.  The usual need is when an UPDATE or DELETE has been issued manually within the transaction using `Session.execute()`.
+
+### Session Attributes {@name=attributes} 
+
+The `Session` itself acts somewhat like a set-like collection.  All items present may be accessed using the iterator interface:
+
+    {python}
+    for obj in session:
+        print obj
+
+And presence may be tested for using regular "contains" semantics:
+
+    {python}
+    if obj in session:
+        print "Object is present"
+
+The session is also keeping track of all newly created (i.e. pending) objects, all objects which have had changes since they were last loaded or saved (i.e. "dirty"), and everything that's been marked as deleted.  
+
+    {python}
+    # pending objects recently added to the Session
+    session.new
+
+    # persistent objects which currently have changes detected
+    # (this collection is now created on the fly each time the property is called)
+    session.dirty
+
+    # persistent objects that have been marked as deleted via session.delete(obj)
+    session.deleted
+
+Note that objects within the session are by default *weakly referenced*.  This means that when they are dereferenced in the outside application, they fall out of scope from within the `Session` as well and are subject to garbage collection by the Python interpreter.  The exceptions to this include objects which are pending, objects which are marked as deleted, or persistent objects which have pending changes on them.  After a full flush, these collections are all empty, and all objects are again weakly referenced.  To disable the weak referencing behavior and force all objects within the session to remain until explicitly expunged, configure `sessionmaker()` with the `weak_identity_map=False` setting.
+
 ## Cascades
 
 Mappers support the concept of configurable *cascade* behavior on `relation()`s.  This behavior controls how the Session should treat the instances that have a parent-child relationship with another instance that is operated upon by the Session.  Cascade is indicated as a comma-separated list of string keywords, with the possible values `all`, `delete`, `save-update`, `refresh-expire`, `merge`, `expunge`, and `delete-orphan`.
@@ -430,25 +363,20 @@ Cascading is configured by setting the `cascade` keyword argument on a `relation
         'customer' : relation(User, users_table, user_orders_table, cascade="save-update"),
     })
 
-The above mapper specifies two relations, `items` and `customer`.  The `items` relationship specifies "all, delete-orphan" as its `cascade` value, indicating that all  `save`, `update`, `merge`, `expunge`, `refresh` `delete` and `expire` operations performed on a parent `Order` instance should also be performed on the child `Item` instances attached to it (`save` and `update` are cascaded using the `save_or_update()` method, so that the database identity of the instance doesn't matter).  The `delete-orphan` cascade value additionally indicates that if an `Item` instance is no longer associated with an `Order`, it should also be deleted.  The "all, delete-orphan" cascade argument allows a so-called *lifecycle* relationship between an `Order` and an `Item` object.
-
-The `customer` relationship specifies only the "save-update" cascade value, indicating most operations will not be cascaded from a parent `Order` instance to a child `User` instance, except for if the `Order` is attached with a particular session, either via the `save()`, `update()`, or `save-update()` method.
-
-Additionally, when a child item is attached to a parent item that specifies the "save-update" cascade value on the relationship, the child is automatically passed to `save_or_update()` (and the operation is further cascaded to the child item).
+The above mapper specifies two relations, `items` and `customer`.  The `items` relationship specifies "all, delete-orphan" as its `cascade` value, indicating that all  `add`, `merge`, `expunge`, `refresh` `delete` and `expire` operations performed on a parent `Order` instance should also be performed on the child `Item` instances attached to it.  The `delete-orphan` cascade value additionally indicates that if an `Item` instance is no longer associated with an `Order`, it should also be deleted.  The "all, delete-orphan" cascade argument allows a so-called *lifecycle* relationship between an `Order` and an `Item` object.
 
-Note that cascading doesn't do anything that isn't possible by manually calling Session methods on individual instances within a hierarchy, it merely automates common operations on a group of associated instances.
+The `customer` relationship specifies only the "save-update" cascade value, indicating most operations will not be cascaded from a parent `Order` instance to a child `User` instance except for the `add()` operation.  "save-update" cascade indicates that an `add()` on the parent will casade to all child items, and also that items added to a parent which is already present in the sessio will also be added.
 
 The default value for `cascade` on `relation()`s is `save-update, merge`.
 
 ## Managing Transactions
 
-The Session can manage transactions automatically, including across multiple engines.  When the Session is in a transaction, as it receives requests to execute SQL statements, it adds each individual Connection/Engine encountered to its transactional state.  At commit time, all unflushed data is flushed, and each individual transaction is committed.  If the underlying databases support two-phase semantics, this may be used by the Session as well if two-phase transactions are enabled.
+The `Session` manages transactions across all engines associated with it.  As the `Session` receives requests to execute SQL statements using a particular `Engine` or `Connection`, it adds each individual `Engine` encountered to its transactional state and maintains an open connection for each one (note that a simple application normally has just one `Engine`).  At commit time, all unflushed data is flushed, and each individual transaction is committed.  If the underlying databases support two-phase semantics, this may be used by the Session as well if two-phase transactions are enabled.
 
-The easiest way to use a Session with transactions is just to declare it as transactional.  The session will remain in a transaction at all times:
+Normal operation ends the transactional state using the `rolback()` or `commit()` methods.  After either is called, the `Session` starts a new transaction.
 
     {python}
-    # transactional session
-    Session = sessionmaker(transactional=True)
+    Session = sessionmaker()
     sess = Session()
     try:
         item1 = sess.query(Item).get(1)
@@ -462,16 +390,10 @@ The easiest way to use a Session with transactions is just to declare it as tran
         # rollback - will immediately go into a new transaction afterwards.
         sess.rollback()
 
-Things to note above:
-
-  * When using a transactional session, either a `rollback()` or a `close()` call **is required** when an error is raised by `flush()` or `commit()`.  The `flush()` error condition will issue a ROLLBACK to the database automatically, but the state of the `Session` itself remains in an "undefined" state until the user decides whether to rollback or close.
-  * The `commit()` call unconditionally issues a `flush()`.  Particularly when using `transactional=True` in conjunction with `autoflush=True`, explicit `flush()` calls are usually not needed.
-
-Alternatively, a transaction can be begun explicitly using `begin()`:
+A session which is configured with `autocommit=True` may be placed into a transaction using `begin()`.  With an `autocommit=True` session that's been placed into a transaction using `begin()`, the session releases all connection resources after a `commit()` or `rollback()` and remains transaction-less (with the exception of flushes) until the next `begin()` call:
 
     {python}
-    # non transactional session
-    Session = sessionmaker(transactional=False)
+    Session = sessionmaker(autocommit=True)
     sess = Session()
     sess.begin()
     try:
@@ -484,12 +406,10 @@ Alternatively, a transaction can be begun explicitly using `begin()`:
         sess.rollback()
         raise
 
-Like the `transactional` example, the same rules apply; an explicit `rollback()` or `close()` is required when an error occurs, and the `commit()` call issues a `flush()` as well.
-
-Session also supports Python 2.5's with statement so that the example above can be written as:
+The `begin()` method also returns a transactional token which is compatible with the Python 2.6 `with` statement:
 
     {python}
-    Session = sessionmaker(transactional=False)
+    Session = sessionmaker(autocommit=True)
     sess = Session()
     with sess.begin():
         item1 = sess.query(Item).get(1)
@@ -497,29 +417,31 @@ Session also supports Python 2.5's with statement so that the example above can
         item1.foo = 'bar'
         item2.bar = 'foo'
 
-Subtransactions can be created by calling the `begin()` method repeatedly. For each transaction you `begin()` you must always call either `commit()` or `rollback()`. Note that this includes the implicit transaction created by the transactional session. When a subtransaction is created the current transaction of the session is set to that transaction. Commiting the subtransaction will return you to the next outer transaction. Rolling it back will also return you to the next outer transaction, but in addition it will roll back database state to the innermost transaction that supports rolling back to. Usually this means the root transaction, unless you use the nested transaction functionality via the `begin_nested()` method. MySQL and Postgres (and soon Oracle) support using "nested" transactions by creating SAVEPOINTs, :
+SAVEPOINT transactions, if supported by the underlying engine, may be delineated using the `begin_nested()` method:
 
     {python}
-    Session = sessionmaker(transactional=False)
+    Session = sessionmaker()
     sess = Session()
-    sess.begin()
-    sess.save(u1)
-    sess.save(u2)
-    sess.flush()
+    sess.add(u1)
+    sess.add(u2)
 
     sess.begin_nested() # establish a savepoint
-    sess.save(u3)
+    sess.add(u3)
     sess.rollback()  # rolls back u3, keeps u1 and u2
 
     sess.commit() # commits u1 and u2
 
+`begin_nested()` may be called any number of times, which will issue a new SAVEPOINT with a unique identifier for each call.  For each `begin_nested()` call, a corresponding `rollback()` or `commit()` must be issued.  
+
+When `begin_nested()` is called, a `flush()` is unconditionally issued (regardless of the `autoflush` setting).  This is so that when a `rollback()` occurs, the full state of the session is expired, thus causing all subsequent attribute/instance access to reference the full state of the `Session` right before `begin_nested()` was called.
+
 Finally, for MySQL, Postgres, and soon Oracle as well, the session can be instructed to use two-phase commit semantics. This will coordinate the commiting of transactions across databases so that the transaction is either committed or rolled back in all databases. You can also `prepare()` the session for interacting with transactions not managed by SQLAlchemy. To use two phase transactions set the flag `twophase=True` on the session:
 
     {python}
     engine1 = create_engine('postgres://db1')
     engine2 = create_engine('postgres://db2')
     
-    Session = sessionmaker(twophase=True, transactional=True)
+    Session = sessionmaker(twophase=True)
 
     # bind User operations to engine 1, Account operations to engine 2
     Session.configure(binds={User:engine1, Account:engine2})
@@ -532,8 +454,6 @@ Finally, for MySQL, Postgres, and soon Oracle as well, the session can be instru
     # before committing both transactions
     sess.commit()
 
-Be aware that when a crash occurs in one of the databases while the the transactions are prepared you have to manually commit or rollback the prepared transactions in your database as appropriate.
-
 ## Embedding SQL Insert/Update Expressions into a Flush {@name=flushsql}
 
 This feature allows the value of a database column to be set to a SQL expression instead of a literal value.  It's especially useful for atomic updates, calling stored procedures, etc.  All you do is assign an expression to an attribute:
@@ -551,61 +471,64 @@ This feature allows the value of a database column to be set to a SQL expression
     # issues "UPDATE some_table SET value=value+1"
     session.commit()
     
-This works both for INSERT and UPDATE statements.  After the flush/commit operation, the `value` attribute on `someobject` gets "deferred", so that when you again access it the newly generated value will be loaded from the database.  This is the same mechanism at work when database-side column defaults fire off.
+This technique works both for INSERT and UPDATE statements.  After the flush/commit operation, the `value` attribute on `someobject` above is expired, so that when next accessed the newly generated value will be loaded from the database. 
 
 ## Using SQL Expressions with Sessions {@name=sql}
 
-SQL constructs and string statements can be executed via the `Session`.  You'd want to do this normally when your `Session` is transactional and you'd like your free-standing SQL statements to participate in the same transaction.
-
-The two ways to do this are to use the connection/execution services of the Session, or to have your Session participate in a regular SQL transaction.
-
-First, a Session thats associated with an Engine or Connection can execute statements immediately (whether or not it's transactional):
+SQL expressions and strings can be executed via the `Session` within its transactional context.  This is most easily accomplished using the `execute()` method, which returns a `ResultProxy` in the same manner as an `Engine` or `Connection`:
 
     {python}
-    Session = sessionmaker(bind=engine, transactional=True)
+    Session = sessionmaker(bind=engine)
     sess = Session()
+    
+    # execute a string statement
     result = sess.execute("select * from table where id=:id", {'id':7})
-    result2 = sess.execute(select([mytable], mytable.c.id==7))
+    
+    # execute a SQL expression construct
+    result = sess.execute(select([mytable]).where(mytable.c.id==7))
 
-To get at the current connection used by the session, which will be part of the current transaction if one is in progress, use `connection()`:
+The current `Connection` held by the `Session` is accessible using the `connection()` method:
 
     {python}
     connection = sess.connection()
-    
-A second scenario is that of a Session which is not directly bound to a connectable.  This session executes statements relative to a particular `Mapper`, since the mappers are bound to tables which are in turn bound to connectables via their `MetaData` (either the session or the mapped tables need to be bound).  In this case, the Session can conceivably be associated with multiple databases through different mappers; so it wants you to send along a `mapper` argument, which can be any mapped class or mapper instance:
 
+The examples above deal with a `Session` that's bound to a single `Engine` or `Connection`.  To execute statements using a `Session` which is bound either to multiple engines, or none at all (i.e. relies upon bound metadata), both `execute()` and `connection()` accept a `mapper` keyword argument, which is passed a mapped class or `Mapper` instance, which is used to locate the proper context for the desired engine:
+    
     {python}
-    # session is *not* bound to an engine or connection
-    Session = sessionmaker(transactional=True)
+    Session = sessionmaker()
     sess = Session()
     
     # need to specify mapper or class when executing
     result = sess.execute("select * from table where id=:id", {'id':7}, mapper=MyMappedClass)
-    result2 = sess.execute(select([mytable], mytable.c.id==7), mapper=MyMappedClass)
 
-    # need to specify mapper or class when you call connection()
+    result = sess.execute(select([mytable], mytable.c.id==7), mapper=MyMappedClass)
+
     connection = sess.connection(MyMappedClass)
 
-The third scenario is when you are using `Connection` and `Transaction` yourself, and want the `Session` to participate.  This is easy, as you just bind the `Session` to the connection:
+## Joining a Session into an External Transaction {@name=joining}
+
+If a `Connection` is being used which is already in a transactional state (i.e. has a `Transaction`), a `Session` can be made to participate within that transaction by just binding the `Session` to that `Connection`:
 
     {python}
-    # non-transactional session
-    Session = sessionmaker(transactional=False)
+    Session = sessionmaker()
     
     # non-ORM connection + transaction
     conn = engine.connect()
     trans = conn.begin()
     
-    # bind the Session *instance* to the connection
+    # create a Session, bind to the connection
     sess = Session(bind=conn)
     
-    # ... etc
+    # ... work with session
     
-    trans.commit()
+    sess.commit() # commit the session
+    sess.close()  # close it out, prohibit further actions
     
-It's safe to use a `Session` which is transactional or autoflushing, as well as to call `begin()`/`commit()` on the session too; the outermost Transaction object, the one we declared explicitly, controls the scope of the transaction.
+    trans.commit() # commit the actual transaction
 
-When using the `threadlocal` engine context, things are that much easier; the `Session` uses the same connection/transaction as everyone else in the current thread, whether or not you explicitly bind it:
+Note that above, we issue a `commit()` both on the `Session` as well as the `Transaction`.  This is an example of where we take advantage of `Connection`'s ability to maintain *subtransactions*, or nested begin/commit pairs.  The `Session` is used exactly as though it were managing the transaction on its own; its `commit()` method issues its `flush()`, and commits the subtransaction.   The subsequent transaction the `Session` starts after commit will not begin until it's next used.  Above we issue a `close()` to prevent this from occuring.  Finally, the actual transaction is committed using `Transaction.commit()`.
+
+When using the `threadlocal` engine context, the process above is simplified; the `Session` uses the same connection/transaction as everyone else in the current thread, whether or not you explicitly bind it:
 
     {python}
     engine = create_engine('postgres://mydb', strategy="threadlocal")
@@ -627,7 +550,7 @@ The `scoped_session()` function wraps around the `sessionmaker()` function, and
 
     {python}
     from sqlalchemy.orm import scoped_session, sessionmaker
-    Session = scoped_session(sessionmaker(autoflush=True, transactional=True))
+    Session = scoped_session(sessionmaker())
     
 However, when you instantiate this `Session` "class", in reality the object is pulled from a threadlocal variable, or if it doesn't exist yet, it's created using the underlying class generated by `sessionmaker()`:
 
@@ -650,38 +573,22 @@ Since the `Session()` constructor now returns the same `Session` object every ti
     u2 = User()
     
     # save to the contextual session, without instantiating
-    Session.save(u1)
-    Session.save(u2)
+    Session.add(u1)
+    Session.add(u2)
     
     # view the "new" attribute
     assert u1 in Session.new
     
-    # flush changes (if not using autoflush)
-    Session.flush()
-    
-    # commit transaction (if using a transactional session)
+    # commit changes
     Session.commit()
 
-To "dispose" of the `Session`, there's two general approaches.  One is to close out the current session, but to leave it assigned to the current context.  This allows the same object to be re-used on another operation.  This may be called from a current, instantiated `Session`:
-
-    {python}
-    sess.close()
-    
-Or, when using `scoped_session()`, the `close()` method may also be called as a classmethod on the `Session` "class":
-
-    {python}
-    Session.close()
-
-When the `Session` is closed, it remains attached, but clears all of its contents and releases any ongoing transactional resources, including rolling back any remaining transactional state.  The `Session` can then be used again.
-
-The other method is to remove the current session from the current context altogether.  This is accomplished using the classmethod `remove()`:
+The contextual session may be disposed of by calling `Session.remove()`:
 
     {python}
+    # remove current contextual session
     Session.remove()
-    
-After `remove()`  is called, the next call to `Session()` will create a *new* `Session` object which then becomes the contextual session.
 
-That, in a nutshell, is all there really is to it.  Now for all the extra things one should know.
+After `remove()` is called, the next operation with the contextual session will start a new `Session` for the current thread.
 
 ### Lifespan of a Contextual Session {@name=lifespan}
 
@@ -701,7 +608,7 @@ A (really, really) common question is when does the contextual session get creat
                                              # some other code calls Session, it's the 
                                              # same contextual session as "sess"
                                              sess2 = Session()
-                                             sess2.save(foo)
+                                             sess2.add(foo)
                                              sess2.commit()
                                              
                                              # generate content to be returned
@@ -709,65 +616,14 @@ A (really, really) common question is when does the contextual session get creat
                         Session.remove() <-
     web response   <-  
 
-Above, we illustrate a *typical* organization of duties, where the "Web Framework" layer has some integration built-in to manage the span of ORM sessions.  Upon the initial handling of an incoming web request, the framework passes control to a controller.  The controller then calls `Session()` when it wishes to work with the ORM; this method establishes the contextual Session which will remain until it's removed.  Disparate parts of the controller code may all call `Session()` and will get the same session object.  Then, when the controller has completed and the response is to be sent to the web server, the framework **closes out** the current contextual session, above using the `remove()` method which removes the session from the context altogether.
-
-As an alternative, the "finalization" step can also call `Session.close()`, which will leave the same session object in place.  Which one is better ?  For a web framework which runs from a fixed pool of threads, it doesn't matter much.  For a framework which runs a **variable** number of threads, or which **creates and disposes** of a thread for each request, `remove()` is better, since it leaves no resources associated with the thread which might not exist.
-
-* Why close out the session at all ?  Why not just leave it going so the next request doesn't have to do as many queries ?
-
-    There are some cases where you may actually want to do this.  However, this is a special case where you are dealing with data which **does not change** very often, or you don't care about the "freshness" of the data.  In reality, a single thread of a web server may, on a slow day, sit around for many minutes or even hours without being accessed.  When it's next accessed, if data from the previous request still exists in the session, that data may be very stale indeed.  So it's generally better to have an empty session at the start of a web request.
-
-### Associating Classes and Mappers with a Contextual Session {@name=associating}
-
-Another luxury we gain, when we've established a `Session()` that can be globally accessed, is the ability for mapped classes and objects to provide us with session-oriented functionality automatically.  When using the `scoped_session()` function, we access this feature using the `mapper` attribute on the object in place of the normal `sqlalchemy.orm.mapper` function:
-
-    {python}
-    # "contextual" mapper function
-    mapper = Session.mapper
-    
-    # use normally
-    mapper(User, users_table, properties={
-        relation(Address)
-    })
-    mapper(Address, addresses_table)
-
-When we use the contextual `mapper()` function, our `User` and `Address` now gain a new attribute `query`, which will create a `Query` object for us against the contextual session:
-
-    {python}
-    wendy = User.query.filter_by(name='wendy').one()
-
-#### Auto-Save Behavior with Contextual Session's Mapper {@name=autosave}
-
-By default, when using Session.mapper, **new instances are saved into the contextual session automatically upon construction;** there is no longer a need to call `save()`:
-
-    {python}
-    >>> newuser = User(name='ed')
-    >>> assert newuser in Session.new
-    True
-
-The auto-save functionality can cause problems, namely that any `flush()` which occurs before a newly constructed object is fully populated will result in that object being INSERTed without all of its attributes completed.  As a `flush()` is more frequent when using sessions with `autoflush=True`, **the auto-save behavior can be disabled**, using the `save_on_init=False` flag:
-
-    {python}
-    # "contextual" mapper function
-    mapper = Session.mapper
+The above example illustrates an explicit call to `Session.remove()`.  This has the effect such that each web request starts fresh with a brand new session.   When integrating with a web framework, there's actually many options on how to proceed for this step, particularly as of version 0.5:
 
-    # use normally, specify no save on init:
-    mapper(User, users_table, properties={
-        relation(Address)
-    }, save_on_init=False)
-    mapper(Address, addresses_table, save_on_init=False)
+ * Session.remove() - this is the most cut and dry approach; the `Session` is thrown away, all of its transactional/connection resources are closed out, everything within it is explicitly gone.  A new `Session` will be used on the next request.
+ * Session.close() - Similar to calling `remove()`, in that all objects are explicitly expunged and all transactional/connection resources closed, except the actual `Session` object hangs around.  It doesn't make too much difference here unless the start of the web request would like to pass specific options to the initial construction of `Session()`, such as a specific `Engine` to bind to.
+ * Session.commit() - In this case, the behavior is that any remaining changes pending are flushed, and the transaction is committed.  The full state of the session is expired, so that when the next web request is started, all data will be reloaded.  In reality, the contents of the `Session` are weakly referenced anyway so its likely that it will be empty on the next request in any case.
+ * Session.rollback() - Similar to calling commit, except we assume that the user would have called commit explicitly if that was desired; the `rollback()` ensures that no transactional state remains and expires all data, in the case that the request was aborted and did not roll back itself.
+ * do nothing - this is a valid option as well.  The controller code is responsible for doing one of the above steps at the end of the request.
 
-    # objects now again require explicit "save"
-    >>> newuser = User(name='ed')
-    >>> assert newuser in Session.new
-    False
-    
-    >>> Session.save(newuser)
-    >>> assert newuser in Session.new
-    True
-
-The functionality of `Session.mapper` is an updated version of what used to be accomplished by the `assignmapper()` SQLAlchemy extension.
-    
 [Generated docstrings for scoped_session()](rel:docstrings_sqlalchemy.orm_modfunc_scoped_session)
 
 ## Partitioning Strategies
@@ -782,7 +638,7 @@ Vertical partitioning places different kinds of objects, or different tables, ac
     engine1 = create_engine('postgres://db1')
     engine2 = create_engine('postgres://db2')
 
-    Session = sessionmaker(twophase=True, transactional=True)
+    Session = sessionmaker(twophase=True)
 
     # bind User operations to engine 1, Account operations to engine 2
     Session.configure(binds={User:engine1, Account:engine2})
index ec7e92c24594af8924545c4f8a2c2b37554bdd5a..bc46607d77e482db4c0c047b99afd2646916bc60 100644 (file)
@@ -5,12 +5,12 @@ This tutorial will cover SQLAlchemy SQL Expressions, which are Python constructs
 
 ## Version Check
 
-A quick check to verify that we are on at least **version 0.4** of SQLAlchemy:
+A quick check to verify that we are on at least **version 0.5** of SQLAlchemy:
 
     {python}
     >>> import sqlalchemy
     >>> sqlalchemy.__version__ # doctest:+SKIP
-    0.4.0
+    0.5.0
     
 ## Connecting
 
@@ -33,14 +33,14 @@ We define our tables all within a catalog called `MetaData`, using the `Table` c
     >>> metadata = MetaData()
     >>> users = Table('users', metadata,
     ...     Column('id', Integer, primary_key=True),
-    ...     Column('name', String(40)),
-    ...     Column('fullname', String(100)),
+    ...     Column('name', String),
+    ...     Column('fullname', String),
     ... )
 
     >>> addresses = Table('addresses', metadata, 
     ...   Column('id', Integer, primary_key=True),
     ...   Column('user_id', None, ForeignKey('users.id')),
-    ...   Column('email_address', String(50), nullable=False)
+    ...   Column('email_address', String, nullable=False)
     ...  )
 
 All about how to define `Table` objects, as well as how to create them from an existing database automatically, is described in [metadata](rel:metadata).
@@ -55,8 +55,8 @@ Next, to tell the `MetaData` we'd actually like to create our selection of table
     {}
     CREATE TABLE users (
         id INTEGER NOT NULL, 
-        name VARCHAR(40)
-        fullname VARCHAR(100)
+        name VARCHAR, 
+        fullname VARCHAR, 
         PRIMARY KEY (id)
     )
     {}
@@ -64,13 +64,20 @@ Next, to tell the `MetaData` we'd actually like to create our selection of table
     CREATE TABLE addresses (
         id INTEGER NOT NULL, 
         user_id INTEGER, 
-        email_address VARCHAR(50) NOT NULL, 
+        email_address VARCHAR NOT NULL, 
         PRIMARY KEY (id), 
          FOREIGN KEY(user_id) REFERENCES users (id)
     )
     {}
     COMMIT
 
+Users familiar with the syntax of CREATE TABLE may notice that the VARCHAR columns were generated without a length; on SQLite, this is a valid datatype, but on most databases it's not allowed.  So if running this tutorial on a database such as Postgres or MySQL, and you wish to use SQLAlchemy to generate the tables, a "length" may be provided to the `String` type as below:
+
+    {python}
+    Column('name', String(50))
+
+The length field on `String`, as well as similar fields available on `Integer`, `Numeric`, etc. are not referenced by SQLAlchemy other than when creating tables.
+
 ## Insert Expressions
 
 The first SQL expression we'll create is the `Insert` construct, which represents an INSERT statement.   This is typically created relative to its target table:
@@ -327,19 +334,19 @@ If we use a literal value (a literal meaning, not a SQLAlchemy clause object), w
 
     {python}
     >>> print users.c.id==7
-    users.id = :users_id_1
+    users.id = :id_1
     
 The `7` literal is embedded in `ClauseElement`; we can use the same trick we did with the `Insert` object to see it:
 
     {python}
     >>> (users.c.id==7).compile().params
-    {'users_id_1': 7}
+    {'id_1': 7}
     
 Most Python operators, as it turns out, produce a SQL expression here, like equals, not equals, etc.:
 
     {python}
     >>> print users.c.id != 7
-    users.id != :users_id_1
+    users.id != :id_1
     
     >>> # None converts to IS NULL
     >>> print users.c.name == None
@@ -347,7 +354,7 @@ Most Python operators, as it turns out, produce a SQL expression here, like equa
      
     >>> # reverse works too 
     >>> print 'fred' > users.c.name
-    users.name < :users_name_1
+    users.name < :name_1
     
 If we add two integer columns together, we get an addition expression:
 
@@ -373,7 +380,7 @@ If you have come across an operator which really isn't available, you can always
 
     {python}
     >>> print users.c.name.op('tiddlywinks')('foo')
-    users.name tiddlywinks :users_name_1
+    users.name tiddlywinks :name_1
     
 ## Conjunctions {@name=conjunctions}
 
@@ -384,9 +391,9 @@ We'd like to show off some of our operators inside of `select()` constructs.  Bu
     >>> print and_(users.c.name.like('j%'), users.c.id==addresses.c.user_id, #doctest: +NORMALIZE_WHITESPACE  
     ...     or_(addresses.c.email_address=='wendy@aol.com', addresses.c.email_address=='jack@yahoo.com'),
     ...     not_(users.c.id>5))
-    users.name LIKE :users_name_1 AND users.id = addresses.user_id AND 
-    (addresses.email_address = :addresses_email_address_1 OR addresses.email_address = :addresses_email_address_2) 
-    AND users.id <= :users_id_1
+    users.name LIKE :name_1 AND users.id = addresses.user_id AND 
+    (addresses.email_address = :email_address_1 OR addresses.email_address = :email_address_2) 
+    AND users.id <= :id_1
 
 And you can also use the re-jiggered bitwise AND, OR and NOT operators, although because of Python operator precedence you have to watch your parenthesis:
 
@@ -394,9 +401,9 @@ And you can also use the re-jiggered bitwise AND, OR and NOT operators, although
     >>> print users.c.name.like('j%') & (users.c.id==addresses.c.user_id) &  \
     ...     ((addresses.c.email_address=='wendy@aol.com') | (addresses.c.email_address=='jack@yahoo.com')) \
     ...     & ~(users.c.id>5) # doctest: +NORMALIZE_WHITESPACE
-    users.name LIKE :users_name_1 AND users.id = addresses.user_id AND 
-    (addresses.email_address = :addresses_email_address_1 OR addresses.email_address = :addresses_email_address_2) 
-    AND users.id <= :users_id_1
+    users.name LIKE :name_1 AND users.id = addresses.user_id AND 
+    (addresses.email_address = :email_address_1 OR addresses.email_address = :email_address_2) 
+    AND users.id <= :id_1
 
 So with all of this vocabulary, let's select all users who have an email address at AOL or MSN, whose name starts with a letter between "m" and "z", and we'll also generate a column containing their full name combined with their email address.  We will add two new constructs to this statement, `between()` and `label()`.  `between()` produces a BETWEEN clause, and `label()` is used in a column expression to produce labels using the `AS` keyword; it's recommended when selecting from expressions that otherwise would not have a name:
 
@@ -528,7 +535,7 @@ Of course you can join on whatever expression you want, such as if we want to jo
 
     {python}
     >>> print users.join(addresses, addresses.c.email_address.like(users.c.name + '%'))
-    users JOIN addresses ON addresses.email_address LIKE users.name || :users_name_1
+    users JOIN addresses ON addresses.email_address LIKE users.name || :name_1
 
 When we create a `select()` construct, SQLAlchemy looks around at the tables we've mentioned and then places them in the FROM clause of the statement.  When we use JOINs however, we know what FROM clause we want, so here we make usage of the `from_obj` keyword argument:
 
@@ -617,9 +624,9 @@ So we started small, added one little thing at a time, and at the end we have a
     >>> print query
     {opensql}SELECT users.id AS users_id, users.name AS users_name, users.fullname AS users_fullname, addresses_1.id AS addresses_1_id, addresses_1.user_id AS addresses_1_user_id, addresses_1.email_address AS addresses_1_email_address 
     FROM users LEFT OUTER JOIN addresses AS addresses_1 ON users.id = addresses_1.user_id 
-    WHERE users.name = :users_name_1 AND (EXISTS (SELECT addresses_1.id 
+    WHERE users.name = :name_1 AND (EXISTS (SELECT addresses_1.id 
     FROM addresses AS addresses_1 
-    WHERE addresses_1.user_id = users.id AND addresses_1.email_address LIKE :addresses_email_address_1)) ORDER BY users.fullname DESC
+    WHERE addresses_1.user_id = users.id AND addresses_1.email_address LIKE :email_address_1)) ORDER BY users.fullname DESC
 
 One more thing though, with automatic labeling applied as well as anonymous aliasing, how do we retrieve the columns from the rows for this thing ?  The label for the `email_addresses` column is now the generated name `addresses_1_email_address`; and in another statement might be something different !  This is where accessing by result columns by `Column` object becomes very useful:
 
@@ -783,11 +790,11 @@ Also available, though not supported on all databases, are `intersect()`, `inter
 To embed a SELECT in a column expression, use `as_scalar()`:
 
     {python}
-    {sql}>>> print conn.execute(select([
+    {sql}>>> print conn.execute(select([   # doctest: +NORMALIZE_WHITESPACE
     ...       users.c.name, 
     ...       select([func.count(addresses.c.id)], users.c.id==addresses.c.user_id).as_scalar()
     ...    ])).fetchall()
-    SELECT users.name, (SELECT count(addresses.id) 
+    SELECT users.name, (SELECT count(addresses.id) AS count_1
     FROM addresses 
     WHERE users.id = addresses.user_id) AS anon_1 
     FROM users
@@ -797,11 +804,11 @@ To embed a SELECT in a column expression, use `as_scalar()`:
 Alternatively, applying a `label()` to a select evaluates it as a scalar as well:
 
     {python}
-    {sql}>>> print conn.execute(select([
+    {sql}>>> print conn.execute(select([    # doctest: +NORMALIZE_WHITESPACE
     ...       users.c.name, 
     ...       select([func.count(addresses.c.id)], users.c.id==addresses.c.user_id).label('address_count')
     ...    ])).fetchall()
-    SELECT users.name, (SELECT count(addresses.id) 
+    SELECT users.name, (SELECT count(addresses.id) AS count_1
     FROM addresses 
     WHERE users.id = addresses.user_id) AS address_count 
     FROM users
@@ -839,7 +846,7 @@ The `select()` function can take keyword arguments `order_by`, `group_by` (as we
     >>> s = select([addresses.c.user_id, func.count(addresses.c.id)]).\
     ...     group_by(addresses.c.user_id).having(func.count(addresses.c.id)>1)
     {opensql}>>> print conn.execute(s).fetchall()
-    SELECT addresses.user_id, count(addresses.id) 
+    SELECT addresses.user_id, count(addresses.id) AS count_1 
     FROM addresses GROUP BY addresses.user_id 
     HAVING count(addresses.id) > ?
     [1]
index d9bad13841d88e9aabde56937e51ec41b94fbd8f..8ea4c765294ed6ad009b28bb35ab51875d42ad65 100644 (file)
@@ -6,11 +6,8 @@ from sqlalchemy import schema, types, engine, sql, pool, orm, exceptions, databa
 from sqlalchemy.sql import compiler, expression
 from sqlalchemy.engine import default, strategies, threadlocal, url
 import sqlalchemy.orm.shard
-import sqlalchemy.ext.sessioncontext as sessioncontext
-import sqlalchemy.ext.selectresults as selectresults
 import sqlalchemy.ext.orderinglist as orderinglist
 import sqlalchemy.ext.associationproxy as associationproxy
-import sqlalchemy.ext.assignmapper as assignmapper
 import sqlalchemy.ext.sqlsoup as sqlsoup
 import sqlalchemy.ext.declarative as declarative
 
index e28f86609527753f7bb243da20694f565d03bbe6..f53332e6419df5da4696c4e4394cc82b3bbee34f 100644 (file)
@@ -47,7 +47,7 @@ if options.file:
 else:
     to_gen = files + post_files
 
-title='SQLAlchemy 0.4 Documentation'
+title='SQLAlchemy 0.5 Documentation'
 version = options.version
 
 
index 998320fb0d1dc7b9c975d4a93517f5dee96b66da..cb6499d5f42566179297eb7cd1b6a16347e6090f 100644 (file)
@@ -5,7 +5,7 @@ import os
 import re
 import doctest
 import sqlalchemy.util as util
-import sqlalchemy.logging as salog
+import sqlalchemy.log as salog
 import logging
 
 salog.default_enabled=True
diff --git a/examples/custom_attributes/custom_management.py b/examples/custom_attributes/custom_management.py
new file mode 100644 (file)
index 0000000..707e182
--- /dev/null
@@ -0,0 +1,190 @@
+"""this example illustrates how to replace SQLAlchemy's class descriptors with a user-defined system.
+
+This sort of thing is appropriate for integration with frameworks that redefine class behaviors
+in their own way, such that SQLA's default instrumentation is not compatible.   
+
+The example illustrates redefinition of instrumentation at the class level as well as the collection
+level, and redefines the storage of the class to store state within "instance._goofy_dict" instead
+of "instance.__dict__".  Note that the default collection implementations can be used 
+with a custom attribute system as well.
+
+"""
+from sqlalchemy import *
+from sqlalchemy.orm import *
+
+from sqlalchemy.orm.attributes import set_attribute, get_attribute, del_attribute, is_instrumented
+from sqlalchemy.orm.collections import collection_adapter
+
+
+class MyClassState(InstrumentationManager):
+    def __init__(self, cls):
+        self.states = {}
+
+    def instrument_attribute(self, class_, key, attr):
+        pass
+
+    def install_descriptor(self, class_, key, attr):
+        pass
+
+    def uninstall_descriptor(self, class_, key, attr):
+        pass
+
+    def instrument_collection_class(self, class_, key, collection_class):
+        return MyCollection
+
+    def get_instance_dict(self, class_, instance):
+        return instance._goofy_dict
+
+    def initialize_instance_dict(self, class_, instance):
+        instance.__dict__['_goofy_dict'] = {}
+
+    def initialize_collection(self, key, state, factory):
+        data = factory()
+        return MyCollectionAdapter(key, state, data), data
+
+    def install_state(self, class_, instance, state):
+        self.states[id(instance)] = state
+
+    def state_getter(self, class_):
+        def find(instance):
+            return self.states[id(instance)]
+        return find
+
+class MyClass(object):
+    __sa_instrumentation_manager__ = MyClassState
+
+    def __init__(self, **kwargs):
+        for k in kwargs:
+            setattr(self, k, kwargs[k])
+
+    def __getattr__(self, key):
+        if is_instrumented(self, key):
+            return get_attribute(self, key)
+        else:
+            try:
+                return self._goofy_dict[key]
+            except KeyError:
+                raise AttributeError(key)
+
+    def __setattr__(self, key, value):
+        if is_instrumented(self, key):
+            set_attribute(self, key, value)
+        else:
+            self._goofy_dict[key] = value
+
+    def __delattr__(self, key):
+        if is_instrumented(self, key):
+            del_attribute(self, key)
+        else:
+            del self._goofy_dict[key]
+
+class MyCollectionAdapter(object):
+    """An wholly alternative instrumentation implementation."""
+    def __init__(self, key, state, collection):
+        self.key = key
+        self.state = state
+        self.collection = collection
+        setattr(collection, '_sa_adapter', self)
+
+    def unlink(self, data):
+        setattr(data, '_sa_adapter', None)
+
+    def adapt_like_to_iterable(self, obj):
+        return iter(obj)
+
+    def append_with_event(self, item, initiator=None):
+        self.collection.add(item, emit=initiator)
+
+    def append_without_event(self, item):
+        self.collection.add(item, emit=False)
+
+    def remove_with_event(self, item, initiator=None):
+        self.collection.remove(item, emit=initiator)
+
+    def remove_without_event(self, item):
+        self.collection.remove(item, emit=False)
+
+    def clear_with_event(self, initiator=None):
+        for item in list(self):
+            self.remove_with_event(item, initiator)
+    def clear_without_event(self):
+        for item in list(self):
+            self.remove_without_event(item)
+    def __iter__(self):
+        return iter(self.collection)
+
+    def fire_append_event(self, item, initiator=None):
+        if initiator is not False and item is not None:
+            self.state.get_impl(self.key).fire_append_event(self.state, item,
+                                                            initiator)
+
+    def fire_remove_event(self, item, initiator=None):
+        if initiator is not False and item is not None:
+            self.state.get_impl(self.key).fire_remove_event(self.state, item,
+                                                            initiator)
+
+    def fire_pre_remove_event(self, initiator=None):
+        self.state.get_impl(self.key).fire_pre_remove_event(self.state,
+                                                            initiator)
+
+class MyCollection(object):
+    def __init__(self):
+        self.members = list()
+    def add(self, object, emit=None):
+        self.members.append(object)
+        collection_adapter(self).fire_append_event(object, emit)
+    def remove(self, object, emit=None):
+        collection_adapter(self).fire_pre_remove_event(object)
+        self.members.remove(object)
+        collection_adapter(self).fire_remove_event(object, emit)
+    def __getitem__(self, index):
+        return self.members[index]
+    def __iter__(self):
+        return iter(self.members)
+    def __len__(self):
+        return len(self.members)
+
+if __name__ == '__main__':
+    meta = MetaData(create_engine('sqlite://'))
+
+    table1 = Table('table1', meta, Column('id', Integer, primary_key=True), Column('name', Text))
+    table2 = Table('table2', meta, Column('id', Integer, primary_key=True), Column('name', Text), Column('t1id', Integer, ForeignKey('table1.id')))
+    meta.create_all()
+
+    class A(MyClass):
+        pass
+
+    class B(MyClass):
+        pass
+
+    mapper(A, table1, properties={
+        'bs':relation(B)
+    })
+
+    mapper(B, table2)
+
+    a1 = A(name='a1', bs=[B(name='b1'), B(name='b2')])
+
+    assert a1.name == 'a1'
+    assert a1.bs[0].name == 'b1'
+    assert isinstance(a1.bs, MyCollection)
+
+    sess = create_session()
+    sess.save(a1)
+
+    sess.flush()
+    sess.clear()
+
+    a1 = sess.query(A).get(a1.id)
+
+    assert a1.name == 'a1'
+    assert a1.bs[0].name == 'b1'
+    assert isinstance(a1.bs, MyCollection)
+
+    a1.bs.remove(a1.bs[0])
+
+    sess.flush()
+    sess.clear()
+
+    a1 = sess.query(A).get(a1.id)
+    assert len(a1.bs) == 1
index 682def78c3724a6d6442f915d369e15eb458a926..b47a6d68f6313333b600820eee45f5899c5ba3de 100644 (file)
@@ -10,9 +10,10 @@ although the hash policy of the members would need to be distilled into a filter
 """\r
 \r
 class MyProxyDict(object):\r
-    def __init__(self, parent, collection_name, keyname):\r
+    def __init__(self, parent, collection_name, childclass, keyname):\r
         self.parent = parent\r
         self.collection_name = collection_name\r
+        self.childclass = childclass\r
         self.keyname = keyname\r
         \r
     def collection(self):\r
@@ -20,8 +21,8 @@ class MyProxyDict(object):
     collection = property(collection)\r
     \r
     def keys(self):\r
-        # this can be improved to not query all columns\r
-        return [getattr(x, self.keyname) for x in self.collection.all()]\r
+        descriptor = getattr(self.childclass, self.keyname)\r
+        return [x[0] for x in self.collection.values(descriptor)]\r
         \r
     def __getitem__(self, key):\r
         x = self.collection.filter_by(**{self.keyname:key}).first()\r
@@ -51,7 +52,7 @@ class MyParent(Base):
     _collection = dynamic_loader("MyChild", cascade="all, delete-orphan")\r
     \r
     def child_map(self):\r
-        return MyProxyDict(self, '_collection', 'key')\r
+        return MyProxyDict(self, '_collection', MyChild, 'key')\r
     child_map = property(child_map)\r
     \r
 class MyChild(Base):\r
@@ -63,15 +64,14 @@ class MyChild(Base):
     \r
 Base.metadata.create_all()\r
 \r
-sess = create_session(autoflush=True, transactional=True)\r
+sess = sessionmaker()()\r
 \r
 p1 = MyParent(name='p1')\r
-sess.save(p1)\r
+sess.add(p1)\r
 \r
 p1.child_map['k1'] = k1 = MyChild(key='k1')\r
 p1.child_map['k2'] = k2 = MyChild(key='k2')\r
 \r
-\r
 assert p1.child_map.keys() == ['k1', 'k2']\r
 \r
 assert p1.child_map['k1'] is k1\r
index 343a0cac8c90331e0f82b414b7a1c0205eb5ac10..e33451611f939bee16959ab679b3b990003bac9c 100644 (file)
@@ -5,6 +5,11 @@
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
 import inspect
+import sys
+
+import sqlalchemy.exc as exceptions
+sys.modules['sqlalchemy.exceptions'] = exceptions
+
 from sqlalchemy.types import \
     BLOB, BOOLEAN, CHAR, CLOB, DATE, DATETIME, DECIMAL, FLOAT, INT, \
     NCHAR, NUMERIC, SMALLINT, TEXT, TIME, TIMESTAMP, VARCHAR, \
@@ -32,3 +37,5 @@ __all__ = [ name for name, obj in locals().items()
             if not (name.startswith('_') or inspect.ismodule(obj)) ]
 
 __version__ = 'svn'
+
+del inspect, sys
index 38dba17a5a18fb7d61a89357a5da6e2d3427948a..aa65985d40715190ed44bb6b5ce34641ab71a513 100644 (file)
@@ -5,7 +5,7 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-from sqlalchemy import sql, schema, types, exceptions, pool
+from sqlalchemy import sql, schema, types, exc, pool
 from sqlalchemy.sql import compiler, expression
 from sqlalchemy.engine import default, base
 
@@ -202,7 +202,7 @@ class AccessDialect(default.DefaultDialect):
                 except pythoncom.com_error:
                     pass
             else:
-                raise exceptions.InvalidRequestError("Can't find a DB engine. Check http://support.microsoft.com/kb/239114 for details.")
+                raise exc.InvalidRequestError("Can't find a DB engine. Check http://support.microsoft.com/kb/239114 for details.")
 
         import pyodbc as module
         return module
@@ -236,7 +236,7 @@ class AccessDialect(default.DefaultDialect):
             c.execute(statement, parameters)
             self.context.rowcount = c.rowcount
         except Exception, e:
-            raise exceptions.DBAPIError.instance(statement, parameters, e)
+            raise exc.DBAPIError.instance(statement, parameters, e)
 
     def has_table(self, connection, tablename, schema=None):
         # This approach seems to be more reliable that using DAO
@@ -272,7 +272,7 @@ class AccessDialect(default.DefaultDialect):
                 if tbl.Name.lower() == table.name.lower():
                     break
             else:
-                raise exceptions.NoSuchTableError(table.name)
+                raise exc.NoSuchTableError(table.name)
 
             for col in tbl.Fields:
                 coltype = self.ischema_names[col.Type]
@@ -333,7 +333,7 @@ class AccessDialect(default.DefaultDialect):
         # This is necessary, so we get the latest updates
         dtbs = daoEngine.OpenDatabase(connection.engine.url.database)
 
-        names = [t.Name for t in dtbs.TableDefs if t.Name[:4] != "MSys" and t.Name[:4] <> "~TMP"]
+        names = [t.Name for t in dtbs.TableDefs if t.Name[:4] != "MSys" and t.Name[:4] != "~TMP"]
         dtbs.Close()
         return names
 
@@ -345,7 +345,7 @@ class AccessCompiler(compiler.DefaultCompiler):
         if select.limit:
             s += "TOP %s " % (select.limit)
         if select.offset:
-            raise exceptions.InvalidRequestError('Access does not support LIMIT with an offset')
+            raise exc.InvalidRequestError('Access does not support LIMIT with an offset')
         return s
 
     def limit_clause(self, select):
@@ -378,14 +378,14 @@ class AccessCompiler(compiler.DefaultCompiler):
     # Strip schema
     def visit_table(self, table, asfrom=False, **kwargs):
         if asfrom:
-            return self.preparer.quote(table, table.name)
+            return self.preparer.quote(table.name, table.quote)
         else:
             return ""
 
 
 class AccessSchemaGenerator(compiler.SchemaGenerator):
     def get_column_specification(self, column, **kwargs):
-        colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec()
+        colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
 
         # install a sequence if we have an implicit IDENTITY column
         if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \
index 5e1dd72bb04c04bcdce5693029a2eac6f2bd3a5d..098759d183d2b22b14e43958ba6152c7ca77b5fd 100644 (file)
@@ -89,7 +89,7 @@ connections are active, the following setting may alleviate the problem::
 
 import datetime
 
-from sqlalchemy import exceptions, schema, types as sqltypes, sql, util
+from sqlalchemy import exc, schema, types as sqltypes, sql, util
 from sqlalchemy.engine import base, default
 
 
@@ -272,7 +272,7 @@ class FBDialect(default.DefaultDialect):
         default.DefaultDialect.__init__(self, **kwargs)
 
         self.type_conv = type_conv
-        self.concurrency_level= concurrency_level
+        self.concurrency_level = concurrency_level
 
     def dbapi(cls):
         import kinterbasdb
@@ -320,7 +320,7 @@ class FBDialect(default.DefaultDialect):
         version = fbconn.server_version
         m = match('\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+) \w+ (\d+)\.(\d+)', version)
         if not m:
-            raise exceptions.AssertionError("Could not determine version from string '%s'" % version)
+            raise AssertionError("Could not determine version from string '%s'" % version)
         return tuple([int(x) for x in m.group(5, 6, 4)])
 
     def _normalize_name(self, name):
@@ -455,7 +455,7 @@ class FBDialect(default.DefaultDialect):
 
         # get primary key fields
         c = connection.execute(keyqry, ["PRIMARY KEY", tablename])
-        pkfields =[self._normalize_name(r['fname']) for r in c.fetchall()]
+        pkfields = [self._normalize_name(r['fname']) for r in c.fetchall()]
 
         # get all of the fields for this table
         c = connection.execute(tblqry, [tablename])
@@ -509,14 +509,15 @@ class FBDialect(default.DefaultDialect):
             table.append_column(col)
 
         if not found_table:
-            raise exceptions.NoSuchTableError(table.name)
+            raise exc.NoSuchTableError(table.name)
 
         # get the foreign keys
         c = connection.execute(fkqry, ["FOREIGN KEY", tablename])
         fks = {}
         while True:
             row = c.fetchone()
-            if not row: break
+            if not row:
+                break
 
             cname = self._normalize_name(row['cname'])
             try:
@@ -530,7 +531,7 @@ class FBDialect(default.DefaultDialect):
             fk[0].append(fname)
             fk[1].append(refspec)
 
-        for name,value in fks.iteritems():
+        for name, value in fks.iteritems():
             table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], name=name))
 
     def do_execute(self, cursor, statement, parameters, **kwargs):
@@ -626,7 +627,7 @@ class FBSchemaGenerator(sql.compiler.SchemaGenerator):
 
     def get_column_specification(self, column, **kwargs):
         colspec = self.preparer.format_column(column)
-        colspec += " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec()
+        colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
 
         default = self.get_column_default_string(column)
         if default is not None:
@@ -711,7 +712,7 @@ class FBIdentifierPreparer(sql.compiler.IdentifierPreparer):
     reserved_words = RESERVED_WORDS
 
     def __init__(self, dialect):
-        super(FBIdentifierPreparer,self).__init__(dialect, omit_schema=True)
+        super(FBIdentifierPreparer, self).__init__(dialect, omit_schema=True)
 
 
 dialect = FBDialect
index 1b3b3838abf0cb74ef024db7d11edf8f27106843..20929cf1e9c5a695edd5c405f8482cac3ef8de4e 100644 (file)
@@ -1,5 +1,5 @@
 import sqlalchemy.sql as sql
-import sqlalchemy.exceptions as exceptions
+import sqlalchemy.exc as exc
 from sqlalchemy import select, MetaData, Table, Column, String, Integer
 from sqlalchemy.schema import PassiveDefault, ForeignKeyConstraint
 
@@ -124,13 +124,13 @@ def reflecttable(connection, table, include_columns, ischema_names):
         coltype = ischema_names[type]
         #print "coltype " + repr(coltype) + " args " +  repr(args)
         coltype = coltype(*args)
-        colargs= []
+        colargs = []
         if default is not None:
             colargs.append(PassiveDefault(sql.text(default)))
         table.append_column(Column(name, coltype, nullable=nullable, *colargs))
     
     if not found_table:
-        raise exceptions.NoSuchTableError(table.name)
+        raise exc.NoSuchTableError(table.name)
 
     # we are relying on the natural ordering of the constraint_column_usage table to return the referenced columns
     # in an order that corresponds to the ordinal_position in the key_constraints table, otherwise composite foreign keys
@@ -157,13 +157,13 @@ def reflecttable(connection, table, include_columns, ischema_names):
             row[colmap[6]]
         )
         #print "type %s on column %s to remote %s.%s.%s" % (type, constrained_column, referred_schema, referred_table, referred_column) 
-        if type=='PRIMARY KEY':
+        if type == 'PRIMARY KEY':
             table.primary_key.add(table.c[constrained_column])
-        elif type=='FOREIGN KEY':
+        elif type == 'FOREIGN KEY':
             try:
                 fk = fks[constraint_name]
             except KeyError:
-                fk = ([],[])
+                fk = ([], [])
                 fks[constraint_name] = fk
             if current_schema == referred_schema:
                 referred_schema = table.schema
index 2e1f19de96c98473104ce68f90cb4a8cd07bb7cc..c7bc49dbe8b6506ae696c5bd67848e1308e7a243 100644 (file)
@@ -8,7 +8,7 @@
 
 import datetime
 
-from sqlalchemy import sql, schema, exceptions, pool, util
+from sqlalchemy import sql, schema, exc, pool, util
 from sqlalchemy.sql import compiler
 from sqlalchemy.engine import default
 from sqlalchemy import types as sqltypes
@@ -197,7 +197,7 @@ class InfoExecutionContext(default.DefaultExecutionContext):
     # 5 - rowid after insert
     def post_exec(self):
         if getattr(self.compiled, "isinsert", False) and self.last_inserted_ids() is None:
-            self._last_inserted_ids = [self.cursor.sqlerrd[1],]
+            self._last_inserted_ids = [self.cursor.sqlerrd[1]]
         elif hasattr( self.compiled , 'offset' ):
             self.cursor.offset( self.compiled.offset )
         super(InfoExecutionContext, self).post_exec()
@@ -210,7 +210,7 @@ class InfoDialect(default.DefaultDialect):
     # for informix 7.31
     max_identifier_length = 18
 
-    def __init__(self, use_ansi=True,**kwargs):
+    def __init__(self, use_ansi=True, **kwargs):
         self.use_ansi = use_ansi
         default.DefaultDialect.__init__(self, **kwargs)
 
@@ -244,19 +244,19 @@ class InfoDialect(default.DefaultDialect):
         else:
             opt = {}
 
-        return ([dsn,], opt )
+        return ([dsn], opt)
 
     def create_execution_context(self , *args, **kwargs):
         return InfoExecutionContext(self, *args, **kwargs)
 
-    def oid_column_name(self,column):
+    def oid_column_name(self, column):
         return "rowid"
 
     def table_names(self, connection, schema):
         s = "select tabname from systables"
         return [row[0] for row in connection.execute(s)]
 
-    def has_table(self, connection, table_name,schema=None):
+    def has_table(self, connection, table_name, schema=None):
         cursor = connection.execute("""select tabname from systables where tabname=?""", table_name.lower() )
         return bool( cursor.fetchone() is not None )
 
@@ -264,18 +264,18 @@ class InfoDialect(default.DefaultDialect):
         c = connection.execute ("select distinct OWNER from systables where tabname=?", table.name.lower() )
         rows = c.fetchall()
         if not rows :
-            raise exceptions.NoSuchTableError(table.name)
+            raise exc.NoSuchTableError(table.name)
         else:
             if table.owner is not None:
                 if table.owner.lower() in [r[0] for r in rows]:
                     owner = table.owner.lower()
                 else:
-                    raise exceptions.AssertionError("Specified owner %s does not own table %s"%(table.owner, table.name))
+                    raise AssertionError("Specified owner %s does not own table %s"%(table.owner, table.name))
             else:
                 if len(rows)==1:
                     owner = rows[0][0]
                 else:
-                    raise exceptions.AssertionError("There are multiple tables with name %s in the schema, you must specifie owner"%table.name)
+                    raise AssertionError("There are multiple tables with name %s in the schema, you must specifie owner"%table.name)
 
         c = connection.execute ("""select colname , coltype , collength , t3.default , t1.colno from syscolumns as t1 , systables as t2 , OUTER sysdefaults as t3
                                     where t1.tabid = t2.tabid and t2.tabname=? and t2.owner=?
@@ -284,7 +284,7 @@ class InfoDialect(default.DefaultDialect):
         rows = c.fetchall()
 
         if not rows:
-            raise exceptions.NoSuchTableError(table.name)
+            raise exc.NoSuchTableError(table.name)
 
         for name , colattr , collength , default , colno in rows:
             name = name.lower()
@@ -341,8 +341,8 @@ class InfoDialect(default.DefaultDialect):
             try:
                 fk = fks[cons_name]
             except KeyError:
-               fk = ([], [])
-               fks[cons_name] = fk
+                fk = ([], [])
+                fks[cons_name] = fk
             refspec = ".".join([remote_table, remote_column])
             schema.Table(remote_table, table.metadata, autoload=True, autoload_with=connection)
             if local_column not in fk[0]:
@@ -436,7 +436,7 @@ class InfoSchemaGenerator(compiler.SchemaGenerator):
             colspec += " SERIAL"
             self.has_serial = True
         else:
-            colspec += " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec()
+            colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
             default = self.get_column_default_string(column)
             if default is not None:
                 colspec += " DEFAULT " + default
index 23ff1f4a000f4d3d89c5629f1e9846cb35ea04b0..392cde61fbbc14a7e90e1542f3a68c8c18527531 100644 (file)
@@ -58,7 +58,7 @@ this.
 
 import datetime, itertools, re
 
-from sqlalchemy import exceptions, schema, sql, util
+from sqlalchemy import exc, schema, sql, util
 from sqlalchemy.sql import operators as sql_operators, expression as sql_expr
 from sqlalchemy.sql import compiler, visitors
 from sqlalchemy.engine import base as engine_base, default
@@ -213,7 +213,7 @@ class MaxTimestamp(sqltypes.DateTime):
                 ms = getattr(value, 'microsecond', 0)
                 return value.strftime("%Y-%m-%d %H:%M:%S." + ("%06u" % ms))
             else:
-                raise exceptions.InvalidRequestError(
+                raise exc.InvalidRequestError(
                     "datetimeformat '%s' is not supported." % (
                     dialect.datetimeformat,))
         return process
@@ -235,7 +235,7 @@ class MaxTimestamp(sqltypes.DateTime):
                                 value[11:13], value[14:16], value[17:19],
                                 value[20:])])
             else:
-                raise exceptions.InvalidRequestError(
+                raise exc.InvalidRequestError(
                     "datetimeformat '%s' is not supported." % (
                     dialect.datetimeformat,))
         return process
@@ -256,7 +256,7 @@ class MaxDate(sqltypes.Date):
             elif dialect.datetimeformat == 'iso':
                 return value.strftime("%Y-%m-%d")
             else:
-                raise exceptions.InvalidRequestError(
+                raise exc.InvalidRequestError(
                     "datetimeformat '%s' is not supported." % (
                     dialect.datetimeformat,))
         return process
@@ -272,7 +272,7 @@ class MaxDate(sqltypes.Date):
                 return datetime.date(
                     *[int(v) for v in (value[0:4], value[5:7], value[8:10])])
             else:
-                raise exceptions.InvalidRequestError(
+                raise exc.InvalidRequestError(
                     "datetimeformat '%s' is not supported." % (
                     dialect.datetimeformat,))
         return process
@@ -293,7 +293,7 @@ class MaxTime(sqltypes.Time):
             elif dialect.datetimeformat == 'iso':
                 return value.strftime("%H-%M-%S")
             else:
-                raise exceptions.InvalidRequestError(
+                raise exc.InvalidRequestError(
                     "datetimeformat '%s' is not supported." % (
                     dialect.datetimeformat,))
         return process
@@ -310,7 +310,7 @@ class MaxTime(sqltypes.Time):
                 return datetime.time(
                     *[int(v) for v in (value[0:4], value[5:7], value[8:10])])
             else:
-                raise exceptions.InvalidRequestError(
+                raise exc.InvalidRequestError(
                     "datetimeformat '%s' is not supported." % (
                     dialect.datetimeformat,))
         return process
@@ -599,7 +599,7 @@ class MaxDBDialect(default.DefaultDialect):
 
         rows = connection.execute(st, params).fetchall()
         if not rows:
-            raise exceptions.NoSuchTableError(table.fullname)
+            raise exc.NoSuchTableError(table.fullname)
 
         include_columns = util.Set(include_columns or [])
 
@@ -833,7 +833,7 @@ class MaxDBCompiler(compiler.DefaultCompiler):
                 # LIMIT.  Right?  Other dialects seem to get away with
                 # dropping order.
                 if select._limit:
-                    raise exceptions.InvalidRequestError(
+                    raise exc.InvalidRequestError(
                         "MaxDB does not support ORDER BY in subqueries")
                 else:
                     return ""
@@ -846,7 +846,7 @@ class MaxDBCompiler(compiler.DefaultCompiler):
         sql = select._distinct and 'DISTINCT ' or ''
         if self.is_subquery(select) and select._limit:
             if select._offset:
-                raise exceptions.InvalidRequestError(
+                raise exc.InvalidRequestError(
                     'MaxDB does not support LIMIT with an offset.')
             sql += 'TOP %s ' % select._limit
         return sql
@@ -858,7 +858,7 @@ class MaxDBCompiler(compiler.DefaultCompiler):
             # sub queries need TOP
             return ''
         elif select._offset:
-            raise exceptions.InvalidRequestError(
+            raise exc.InvalidRequestError(
                 'MaxDB does not support LIMIT with an offset.')
         else:
             return ' \n LIMIT %s' % (select._limit,)
@@ -952,7 +952,7 @@ class MaxDBIdentifierPreparer(compiler.IdentifierPreparer):
 class MaxDBSchemaGenerator(compiler.SchemaGenerator):
     def get_column_specification(self, column, **kw):
         colspec = [self.preparer.format_column(column),
-                   column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec()]
+                   column.type.dialect_impl(self.dialect).get_col_spec()]
 
         if not column.nullable:
             colspec.append('NOT NULL')
index ab5a968716b7a8f54a9d27b2e943de06ec61b299..4e129952fa7367dc320c7670c75aca32fea6cafa 100644 (file)
@@ -40,7 +40,7 @@ Known issues / TODO:
 
 import datetime, operator, re, sys
 
-from sqlalchemy import sql, schema, exceptions, util
+from sqlalchemy import sql, schema, exc, util
 from sqlalchemy.sql import compiler, expression, operators as sqlops, functions as sql_functions
 from sqlalchemy.engine import default, base
 from sqlalchemy import types as sqltypes
@@ -440,7 +440,7 @@ class MSSQLDialect(default.DefaultDialect):
                 dialect_cls = dialect_mapping[module_name]
                 return dialect_cls.import_dbapi()
             except KeyError:
-                raise exceptions.InvalidRequestError("Unsupported MSSQL module '%s' requested (must be adodbpi, pymssql or pyodbc)" % module_name)
+                raise exc.InvalidRequestError("Unsupported MSSQL module '%s' requested (must be adodbpi, pymssql or pyodbc)" % module_name)
         else:
             for dialect_cls in [MSSQLDialect_pyodbc, MSSQLDialect_pymssql, MSSQLDialect_adodbapi]:
                 try:
@@ -512,7 +512,7 @@ class MSSQLDialect(default.DefaultDialect):
             self.context.rowcount = c.rowcount
             c.DBPROP_COMMITPRESERVE = "Y"
         except Exception, e:
-            raise exceptions.DBAPIError.instance(statement, parameters, e)
+            raise exc.DBAPIError.instance(statement, parameters, e)
 
     def table_names(self, connection, schema):
         from sqlalchemy.databases import information_schema as ischema
@@ -602,14 +602,14 @@ class MSSQLDialect(default.DefaultDialect):
                 elif coltype in (MSNVarchar, AdoMSNVarchar) and charlen == -1:
                     args[0] = None
                 coltype = coltype(*args)
-            colargs= []
+            colargs = []
             if default is not None:
                 colargs.append(schema.PassiveDefault(sql.text(default)))
 
             table.append_column(schema.Column(name, coltype, nullable=nullable, autoincrement=False, *colargs))
 
         if not found_table:
-            raise exceptions.NoSuchTableError(table.name)
+            raise exc.NoSuchTableError(table.name)
 
         # We also run an sp_columns to check for identity columns:
         cursor = connection.execute("sp_columns @table_name = '%s', @table_owner = '%s'" % (table.name, current_schema))
@@ -633,8 +633,8 @@ class MSSQLDialect(default.DefaultDialect):
                 row = cursor.fetchone()
                 cursor.close()
                 if not row is None:
-                    ic.sequence.start=int(row[0])
-                    ic.sequence.increment=int(row[1])
+                    ic.sequence.start = int(row[0])
+                    ic.sequence.increment = int(row[1])
             except:
                 # ignoring it, works just like before
                 pass
@@ -684,13 +684,15 @@ class MSSQLDialect(default.DefaultDialect):
                 
             if rfknm != fknm:
                 if fknm:
-                    table.append_constraint(schema.ForeignKeyConstraint(scols, [_gen_fkref(table,s,t,c) for s,t,c in rcols], fknm))
+                    table.append_constraint(schema.ForeignKeyConstraint(scols, [_gen_fkref(table, s, t, c) for s, t, c in rcols], fknm))
                 fknm, scols, rcols = (rfknm, [], [])
-            if (not scol in scols): scols.append(scol)
-            if (not (rschema, rtbl, rcol) in rcols): rcols.append((rschema, rtbl, rcol))
+            if not scol in scols:
+                scols.append(scol)
+            if not (rschema, rtbl, rcol) in rcols:
+                rcols.append((rschema, rtbl, rcol))
 
         if fknm and scols:
-            table.append_constraint(schema.ForeignKeyConstraint(scols, [_gen_fkref(table,s,t,c) for s,t,c in rcols], fknm))
+            table.append_constraint(schema.ForeignKeyConstraint(scols, [_gen_fkref(table, s, t, c) for s, t, c in rcols], fknm))
 
 
 class MSSQLDialect_pymssql(MSSQLDialect):
@@ -895,7 +897,7 @@ class MSSQLCompiler(compiler.DefaultCompiler):
             if select._limit:
                 s += "TOP %s " % (select._limit,)
             if select._offset:
-                raise exceptions.InvalidRequestError('MSSQL does not support LIMIT with an offset')
+                raise exc.InvalidRequestError('MSSQL does not support LIMIT with an offset')
             return s
         return compiler.DefaultCompiler.get_select_precolumns(self, select)
 
@@ -1005,7 +1007,7 @@ class MSSQLCompiler(compiler.DefaultCompiler):
 
 class MSSQLSchemaGenerator(compiler.SchemaGenerator):
     def get_column_specification(self, column, **kwargs):
-        colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec()
+        colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
 
         # install a IDENTITY Sequence if we have an implicit IDENTITY column
         if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \
index a3acac587026f0d550edf26c7835dacc213dfa16..92f533633c2b990a4f859543f3d22dd32e5ef31a 100644 (file)
@@ -53,8 +53,8 @@ class Connection:
 
 # override 'connect' call
 def connect(*args, **kwargs):
-        import mx.ODBC.Windows
-        conn = mx.ODBC.Windows.Connect(*args, **kwargs)
-        conn.datetimeformat = mx.ODBC.Windows.PYDATETIME_DATETIMEFORMAT
-        return Connection(conn)
+    import mx.ODBC.Windows
+    conn = mx.ODBC.Windows.Connect(*args, **kwargs)
+    conn.datetimeformat = mx.ODBC.Windows.PYDATETIME_DATETIMEFORMAT
+    return Connection(conn)
 Connect = connect
index a86035be50754d0c0feda4c9a0c687fd9b03b895..9cc5c38a692f8d2fae81d800cc766894fd8664e8 100644 (file)
@@ -156,7 +156,7 @@ timely information affecting MySQL in SQLAlchemy.
 import datetime, inspect, re, sys
 from array import array as _array
 
-from sqlalchemy import exceptions, logging, schema, sql, util
+from sqlalchemy import exc, log, schema, sql, util
 from sqlalchemy.sql import operators as sql_operators
 from sqlalchemy.sql import functions as sql_functions
 from sqlalchemy.sql import compiler
@@ -404,7 +404,7 @@ class MSDouble(sqltypes.Float, _NumericType):
 
         if ((precision is None and length is not None) or
             (precision is not None and length is None)):
-            raise exceptions.ArgumentError(
+            raise exc.ArgumentError(
                 "You must specify both precision and length or omit "
                 "both altogether.")
 
@@ -1188,7 +1188,7 @@ class MSEnum(MSString):
         super_convert = super(MSEnum, self).bind_processor(dialect)
         def process(value):
             if self.strict and value is not None and value not in self.enums:
-                raise exceptions.InvalidRequestError('"%s" not a valid value for '
+                raise exc.InvalidRequestError('"%s" not a valid value for '
                                                      'this enum' % value)
             if super_convert:
                 return super_convert(value)
@@ -1588,7 +1588,7 @@ class MySQLDialect(default.DefaultDialect):
                 have = rs.rowcount > 0
                 rs.close()
                 return have
-            except exceptions.SQLError, e:
+            except exc.SQLError, e:
                 if e.orig.args[0] == 1146:
                     return False
                 raise
@@ -1823,14 +1823,14 @@ class MySQLDialect(default.DefaultDialect):
         try:
             try:
                 rp = connection.execute(st)
-            except exceptions.SQLError, e:
+            except exc.SQLError, e:
                 if e.orig.args[0] == 1146:
-                    raise exceptions.NoSuchTableError(full_name)
+                    raise exc.NoSuchTableError(full_name)
                 else:
                     raise
             row = _compat_fetchone(rp, charset=charset)
             if not row:
-                raise exceptions.NoSuchTableError(full_name)
+                raise exc.NoSuchTableError(full_name)
             return row[1].strip()
         finally:
             if rp:
@@ -1850,9 +1850,9 @@ class MySQLDialect(default.DefaultDialect):
         try:
             try:
                 rp = connection.execute(st)
-            except exceptions.SQLError, e:
+            except exc.SQLError, e:
                 if e.orig.args[0] == 1146:
-                    raise exceptions.NoSuchTableError(full_name)
+                    raise exc.NoSuchTableError(full_name)
                 else:
                     raise
             rows = _compat_fetchall(rp, charset=charset)
@@ -1966,7 +1966,7 @@ class MySQLCompiler(compiler.DefaultCompiler):
 
     def for_update_clause(self, select):
         if select.for_update == 'read':
-             return ' LOCK IN SHARE MODE'
+            return ' LOCK IN SHARE MODE'
         else:
             return super(MySQLCompiler, self).for_update_clause(select)
 
@@ -2022,8 +2022,7 @@ class MySQLSchemaGenerator(compiler.SchemaGenerator):
         """Builds column DDL."""
 
         colspec = [self.preparer.format_column(column),
-                   column.type.dialect_impl(self.dialect,
-                                            _for_ddl=column).get_col_spec()]
+                   column.type.dialect_impl(self.dialect).get_col_spec()]
 
         default = self.get_column_default_string(column)
         if default is not None:
@@ -2308,7 +2307,7 @@ class MySQLSchemaReflector(object):
             ref_names = spec['foreign']
             if not util.Set(ref_names).issubset(
                 util.Set([c.name for c in ref_table.c])):
-                raise exceptions.InvalidRequestError(
+                raise exc.InvalidRequestError(
                     "Foreign key columns (%s) are not present on "
                     "foreign table %s" %
                     (', '.join(ref_names), ref_table.fullname()))
@@ -2643,7 +2642,7 @@ class MySQLSchemaReflector(object):
 
         return self._re_keyexprs.findall(identifiers)
 
-MySQLSchemaReflector.logger = logging.class_logger(MySQLSchemaReflector)
+MySQLSchemaReflector.logger = log.class_logger(MySQLSchemaReflector)
 
 
 class _MySQLIdentifierPreparer(compiler.IdentifierPreparer):
index 734ad58d108993b69a7829f8c5a8bb0e800edbfe..5bc8a186faa1b6735fbf785d6c981659749d518e 100644 (file)
@@ -7,7 +7,7 @@
 
 import datetime, random, re
 
-from sqlalchemy import util, sql, schema, exceptions, logging
+from sqlalchemy import util, sql, schema, log
 from sqlalchemy.engine import default, base
 from sqlalchemy.sql import compiler, visitors
 from sqlalchemy.sql import operators as sql_operators, functions as sql_functions
@@ -49,11 +49,11 @@ class OracleDateTime(sqltypes.DateTime):
 
     def result_processor(self, dialect):
         def process(value):
-            if value is None or isinstance(value,datetime.datetime):
+            if value is None or isinstance(value, datetime.datetime):
                 return value
             else:
                 # convert cx_oracle datetime object returned pre-python 2.4
-                return datetime.datetime(value.year,value.month,
+                return datetime.datetime(value.year, value.month,
                     value.day,value.hour, value.minute, value.second)
         return process
 
@@ -72,11 +72,11 @@ class OracleTimestamp(sqltypes.TIMESTAMP):
 
     def result_processor(self, dialect):
         def process(value):
-            if value is None or isinstance(value,datetime.datetime):
+            if value is None or isinstance(value, datetime.datetime):
                 return value
             else:
                 # convert cx_oracle datetime object returned pre-python 2.4
-                return datetime.datetime(value.year,value.month,
+                return datetime.datetime(value.year, value.month,
                     value.day,value.hour, value.minute, value.second)
         return process
 
@@ -216,13 +216,13 @@ class OracleExecutionContext(default.DefaultExecutionContext):
     def get_result_proxy(self):
         if hasattr(self, 'out_parameters'):
             if self.compiled_parameters is not None and len(self.compiled_parameters) == 1:
-                 for bind, name in self.compiled.bind_names.iteritems():
-                     if name in self.out_parameters:
-                         type = bind.type
-                         self.out_parameters[name] = type.dialect_impl(self.dialect).result_processor(self.dialect)(self.out_parameters[name].getvalue())
+                for bind, name in self.compiled.bind_names.iteritems():
+                    if name in self.out_parameters:
+                        type = bind.type
+                        self.out_parameters[name] = type.dialect_impl(self.dialect).result_processor(self.dialect)(self.out_parameters[name].getvalue())
             else:
-                 for k in self.out_parameters:
-                     self.out_parameters[k] = self.out_parameters[k].getvalue()
+                for k in self.out_parameters:
+                    self.out_parameters[k] = self.out_parameters[k].getvalue()
 
         if self.cursor.description is not None:
             for column in self.cursor.description:
@@ -331,7 +331,7 @@ class OracleDialect(default.DefaultDialect):
         this id will be passed to do_begin_twophase(), do_rollback_twophase(),
         do_commit_twophase().  its format is unspecified."""
 
-        id = random.randint(0,2**128)
+        id = random.randint(0, 2 ** 128)
         return (0x1234, "%032x" % 9, "%032x" % id)
 
     def do_release_savepoint(self, connection, name):
@@ -392,7 +392,7 @@ class OracleDialect(default.DefaultDialect):
             cursor = connection.execute(s)
         else:
             s = "select table_name from all_tables where tablespace_name NOT IN ('SYSTEM','SYSAUX') AND OWNER = :owner"
-            cursor = connection.execute(s,{'owner':self._denormalize_name(schema)})
+            cursor = connection.execute(s, {'owner': self._denormalize_name(schema)})
         return [self._normalize_name(row[0]) for row in cursor]
 
     def _resolve_synonym(self, connection, desired_owner=None, desired_synonym=None, desired_table=None):
@@ -400,11 +400,11 @@ class OracleDialect(default.DefaultDialect):
 
         if desired_owner is None, attempts to locate a distinct owner.
 
-       returns the actual name, owner, dblink name, and synonym name if found.
+        returns the actual name, owner, dblink name, and synonym name if found.
         """
 
-       sql = """select OWNER, TABLE_OWNER, TABLE_NAME, DB_LINK, SYNONYM_NAME
-                  from   ALL_SYNONYMS WHERE """
+        sql = """select OWNER, TABLE_OWNER, TABLE_NAME, DB_LINK, SYNONYM_NAME
+                   from   ALL_SYNONYMS WHERE """
 
         clauses = []
         params = {}
@@ -418,9 +418,9 @@ class OracleDialect(default.DefaultDialect):
             clauses.append("TABLE_NAME=:tname")
             params['tname'] = desired_table
 
-        sql += " AND ".join(clauses) 
+        sql += " AND ".join(clauses)
 
-       result = connection.execute(sql, **params)
+        result = connection.execute(sql, **params)
         if desired_owner:
             row = result.fetchone()
             if row:
@@ -430,7 +430,7 @@ class OracleDialect(default.DefaultDialect):
         else:
             rows = result.fetchall()
             if len(rows) > 1:
-                raise exceptions.AssertionError("There are multiple tables visible to the schema, you must specify owner")
+                raise AssertionError("There are multiple tables visible to the schema, you must specify owner")
             elif len(rows) == 1:
                 row = rows[0]
                 return row['TABLE_NAME'], row['TABLE_OWNER'], row['DB_LINK'], row['SYNONYM_NAME']
@@ -442,7 +442,7 @@ class OracleDialect(default.DefaultDialect):
 
         resolve_synonyms = table.kwargs.get('oracle_resolve_synonyms', False)
 
-       if resolve_synonyms:
+        if resolve_synonyms:
             actual_name, owner, dblink, synonym = self._resolve_synonym(connection, desired_owner=self._denormalize_name(table.schema), desired_synonym=self._denormalize_name(table.name))
         else:
             actual_name, owner, dblink, synonym = None, None, None, None
@@ -473,7 +473,7 @@ class OracleDialect(default.DefaultDialect):
             # NUMBER(9,2) if the precision is 9 and the scale is 2
             # NUMBER(3) if the precision is 3 and scale is 0
             #length is ignored except for CHAR and VARCHAR2
-            if coltype=='NUMBER' :
+            if coltype == 'NUMBER' :
                 if precision is None and scale is None:
                     coltype = OracleNumeric
                 elif precision is None and scale == 0  :
@@ -498,7 +498,7 @@ class OracleDialect(default.DefaultDialect):
             table.append_column(schema.Column(colname, coltype, nullable=nullable, *colargs))
 
         if not table.columns:
-           raise exceptions.AssertionError("Couldn't find any column information for table %s" % actual_name)
+            raise AssertionError("Couldn't find any column information for table %s" % actual_name)
 
         c = connection.execute("""SELECT
              ac.constraint_name,
@@ -534,8 +534,8 @@ class OracleDialect(default.DefaultDialect):
                 try:
                     fk = fks[cons_name]
                 except KeyError:
-                   fk = ([], [])
-                   fks[cons_name] = fk
+                    fk = ([], [])
+                    fks[cons_name] = fk
                 if remote_table is None:
                     # ticket 363
                     util.warn(
@@ -551,7 +551,7 @@ class OracleDialect(default.DefaultDialect):
                         remote_owner = self._normalize_name(ref_remote_owner)
 
                 if not table.schema and self._denormalize_name(remote_owner) == owner:
-                    refspec =  ".".join([remote_table, remote_column])               
+                    refspec =  ".".join([remote_table, remote_column])
                     t = schema.Table(remote_table, table.metadata, autoload=True, autoload_with=connection, oracle_resolve_synonyms=resolve_synonyms, useexisting=True)
                 else:
                     refspec =  ".".join([x for x in [remote_owner, remote_table, remote_column] if x])
@@ -566,7 +566,7 @@ class OracleDialect(default.DefaultDialect):
             table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], name=name))
 
 
-OracleDialect.logger = logging.class_logger(OracleDialect)
+OracleDialect.logger = log.class_logger(OracleDialect)
 
 class _OuterJoinColumn(sql.ClauseElement):
     __visit_name__ = 'outer_join_column'
@@ -574,7 +574,7 @@ class _OuterJoinColumn(sql.ClauseElement):
         self.column = column
     def _get_from_objects(self, **kwargs):
         return []
-    
+
 class OracleCompiler(compiler.DefaultCompiler):
     """Oracle compiler modifies the lexical structure of Select
     statements to work under non-ANSI configured Oracle databases, if
@@ -615,10 +615,10 @@ class OracleCompiler(compiler.DefaultCompiler):
             return compiler.DefaultCompiler.visit_join(self, join, **kwargs)
         else:
             return self.process(join.left, asfrom=True) + ", " + self.process(join.right, asfrom=True)
-    
+
     def _get_nonansi_join_whereclause(self, froms):
         clauses = []
-        
+
         def visit_join(join):
             if join.isouter:
                 def visit_binary(binary):
@@ -627,14 +627,14 @@ class OracleCompiler(compiler.DefaultCompiler):
                             binary.left = _OuterJoinColumn(binary.left)
                         elif binary.right.table is join.right:
                             binary.right = _OuterJoinColumn(binary.right)
-                clauses.append(visitors.traverse(join.onclause, visit_binary=visit_binary, clone=True))
+                clauses.append(visitors.cloned_traverse(join.onclause, {}, {'binary':visit_binary}))
             else:
                 clauses.append(join.onclause)
-        
+
         for f in froms:
-            visitors.traverse(f, visit_join=visit_join)
+            visitors.traverse(f, {}, {'join':visit_join})
         return sql.and_(*clauses)
-        
+
     def visit_outer_join_column(self, vc):
         return self.process(vc.column) + "(+)"
 
@@ -670,7 +670,7 @@ class OracleCompiler(compiler.DefaultCompiler):
                 if whereclause:
                     select = select.where(whereclause)
                     select._oracle_visit = True
-                
+
             if select._limit is not None or select._offset is not None:
                 # to use ROW_NUMBER(), an ORDER BY is required.
                 orderby = self.process(select._order_by_clause)
@@ -680,11 +680,11 @@ class OracleCompiler(compiler.DefaultCompiler):
 
                 select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn")).order_by(None)
                 select._oracle_visit = True
-                
+
                 limitselect = sql.select([c for c in select.c if c.key!='ora_rn'])
                 limitselect._oracle_visit = True
                 limitselect._is_wrapper = True
-                
+
                 if select._offset is not None:
                     limitselect.append_whereclause("ora_rn>%d" % select._offset)
                     if select._limit is not None:
@@ -692,7 +692,7 @@ class OracleCompiler(compiler.DefaultCompiler):
                 else:
                     limitselect.append_whereclause("ora_rn<=%d" % select._limit)
                 select = limitselect
-        
+
         kwargs['iswrapper'] = getattr(select, '_is_wrapper', False)
         return compiler.DefaultCompiler.visit_select(self, select, **kwargs)
 
@@ -700,7 +700,7 @@ class OracleCompiler(compiler.DefaultCompiler):
         return ""
 
     def for_update_clause(self, select):
-        if select.for_update=="nowait":
+        if select.for_update == "nowait":
             return " FOR UPDATE NOWAIT"
         else:
             return super(OracleCompiler, self).for_update_clause(select)
@@ -709,7 +709,7 @@ class OracleCompiler(compiler.DefaultCompiler):
 class OracleSchemaGenerator(compiler.SchemaGenerator):
     def get_column_specification(self, column, **kwargs):
         colspec = self.preparer.format_column(column)
-        colspec += " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec()
+        colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
         default = self.get_column_default_string(column)
         if default is not None:
             colspec += " DEFAULT " + default
index 605ce7272b08788b44ec44d69c1e8736c3cfcd9c..23b0a273e9a2b2a85857a9e4c3c0f5cf6aa645c7 100644 (file)
@@ -21,7 +21,7 @@ parameter when creating the queries::
 
 import random, re, string
 
-from sqlalchemy import sql, schema, exceptions, util
+from sqlalchemy import sql, schema, exc, util
 from sqlalchemy.engine import base, default
 from sqlalchemy.sql import compiler, expression
 from sqlalchemy.sql import operators as sql_operators
@@ -99,11 +99,17 @@ class PGText(sqltypes.Text):
 
 class PGString(sqltypes.String):
     def get_col_spec(self):
-        return "VARCHAR(%(length)s)" % {'length' : self.length}
+        if self.length:
+            return "VARCHAR(%(length)d)" % {'length' : self.length}
+        else:
+            return "VARCHAR"
 
 class PGChar(sqltypes.CHAR):
     def get_col_spec(self):
-        return "CHAR(%(length)s)" % {'length' : self.length}
+        if self.length:
+            return "CHAR(%(length)d)" % {'length' : self.length}
+        else:
+            return "CHAR"
 
 class PGBinary(sqltypes.Binary):
     def get_col_spec(self):
@@ -146,7 +152,7 @@ class PGArray(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine):
             if value is None:
                 return value
             def convert_item(item):
-                if isinstance(item, (list,tuple)):
+                if isinstance(item, (list, tuple)):
                     return [convert_item(child) for child in item]
                 else:
                     if item_proc:
@@ -373,7 +379,7 @@ class PGDialect(default.DefaultDialect):
 
     def last_inserted_ids(self):
         if self.context.last_inserted_ids is None:
-            raise exceptions.InvalidRequestError("no INSERT executed, or can't use cursor.lastrowid without Postgres OIDs enabled")
+            raise exc.InvalidRequestError("no INSERT executed, or can't use cursor.lastrowid without Postgres OIDs enabled")
         else:
             return self.context.last_inserted_ids
 
@@ -419,7 +425,7 @@ class PGDialect(default.DefaultDialect):
         v = connection.execute("select version()").scalar()
         m = re.match('PostgreSQL (\d+)\.(\d+)\.(\d+)', v)
         if not m:
-            raise exceptions.AssertionError("Could not determine version from string '%s'" % v)
+            raise AssertionError("Could not determine version from string '%s'" % v)
         return tuple([int(x) for x in m.group(1, 2, 3)])
 
     def reflecttable(self, connection, table, include_columns):
@@ -459,7 +465,7 @@ class PGDialect(default.DefaultDialect):
         rows = c.fetchall()
 
         if not rows:
-            raise exceptions.NoSuchTableError(table.name)
+            raise exc.NoSuchTableError(table.name)
 
         domains = self._load_domains(connection)
 
@@ -519,7 +525,7 @@ class PGDialect(default.DefaultDialect):
                             default = domain['default']
                         coltype = ischema_names[domain['attype']]
                 else:
-                    coltype=None
+                    coltype = None
 
             if coltype:
                 coltype = coltype(*args, **kwargs)
@@ -530,7 +536,7 @@ class PGDialect(default.DefaultDialect):
                           (attype, name))
                 coltype = sqltypes.NULLTYPE
 
-            colargs= []
+            colargs = []
             if default is not None:
                 match = re.search(r"""(nextval\(')([^']+)('.*$)""", default)
                 if match is not None:
@@ -560,7 +566,7 @@ class PGDialect(default.DefaultDialect):
             col = table.c[pk]
             table.primary_key.add(col)
             if col.default is None:
-                col.autoincrement=False
+                col.autoincrement = False
 
         # Foreign keys
         FK_SQL = """
@@ -697,7 +703,7 @@ class PGCompiler(compiler.DefaultCompiler):
                         yield co
                 else:
                     yield c
-        columns = [self.process(c) for c in flatten_columnlist(returning_cols)]
+        columns = [self.process(c, render_labels=True) for c in flatten_columnlist(returning_cols)]
         text += ' RETURNING ' + string.join(columns, ', ')
         return text
 
@@ -724,7 +730,7 @@ class PGSchemaGenerator(compiler.SchemaGenerator):
             else:
                 colspec += " SERIAL"
         else:
-            colspec += " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec()
+            colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
             default = self.get_column_default_string(column)
             if default is not None:
                 colspec += " DEFAULT " + default
index f8bea90ebcf90d81639babd48436423713833c24..a63741cf7c8404fad49e7a719ad425d1d1284945 100644 (file)
@@ -7,7 +7,7 @@
 
 import datetime, re, time
 
-from sqlalchemy import schema, exceptions, pool, PassiveDefault
+from sqlalchemy import schema, exc, pool, PassiveDefault
 from sqlalchemy.engine import default
 import sqlalchemy.types as sqltypes
 import sqlalchemy.util as util
@@ -67,7 +67,7 @@ class DateTimeMixin(object):
             microsecond = 0
         return time.strptime(value, self.__format__)[0:6] + (microsecond,)
 
-class SLDateTime(DateTimeMixin,sqltypes.DateTime):
+class SLDateTime(DateTimeMixin, sqltypes.DateTime):
     __format__ = "%Y-%m-%d %H:%M:%S"
     __microsecond__ = True
 
@@ -112,11 +112,11 @@ class SLText(sqltypes.Text):
 
 class SLString(sqltypes.String):
     def get_col_spec(self):
-        return "VARCHAR(%(length)s)" % {'length' : self.length}
+        return "VARCHAR" + (self.length and "(%d)" % self.length or "")
 
 class SLChar(sqltypes.CHAR):
     def get_col_spec(self):
-        return "CHAR(%(length)s)" % {'length' : self.length}
+        return "CHAR" + (self.length and "(%d)" % self.length or "")
 
 class SLBinary(sqltypes.Binary):
     def get_col_spec(self):
@@ -203,7 +203,7 @@ class SQLiteDialect(default.DefaultDialect):
             return tuple([int(x) for x in num.split('.')])
         if self.dbapi is not None:
             sqlite_ver = self.dbapi.version_info
-            if sqlite_ver < (2,1,'3'):
+            if sqlite_ver < (2, 1, '3'):
                 util.warn(
                     ("The installed version of pysqlite2 (%s) is out-dated "
                      "and will cause errors in some cases.  Version 2.1.3 "
@@ -227,7 +227,7 @@ class SQLiteDialect(default.DefaultDialect):
 
     def create_connect_args(self, url):
         if url.username or url.password or url.host or url.port:
-            raise exceptions.ArgumentError(
+            raise exc.ArgumentError(
                 "Invalid SQLite URL: %s\n"
                 "Valid SQLite URL forms are:\n"
                 " sqlite:///:memory: (or, sqlite://)\n"
@@ -270,7 +270,7 @@ class SQLiteDialect(default.DefaultDialect):
                      "  SELECT * FROM sqlite_temp_master) "
                      "WHERE type='table' ORDER BY name")
                 rs = connection.execute(s)
-            except exceptions.DBAPIError:
+            except exc.DBAPIError:
                 raise
                 s = ("SELECT name FROM sqlite_master "
                      "WHERE type='table' ORDER BY name")
@@ -334,13 +334,13 @@ class SQLiteDialect(default.DefaultDialect):
                 args = re.findall(r'(\d+)', args)
                 coltype = coltype(*[int(a) for a in args])
 
-            colargs= []
+            colargs = []
             if has_default:
                 colargs.append(PassiveDefault('?'))
             table.append_column(schema.Column(name, coltype, primary_key = primary_key, nullable = nullable, *colargs))
 
         if not found_table:
-            raise exceptions.NoSuchTableError(table.name)
+            raise exc.NoSuchTableError(table.name)
 
         c = connection.execute("%sforeign_key_list(%s)" % (pragma, qtable))
         fks = {}
@@ -355,7 +355,7 @@ class SQLiteDialect(default.DefaultDialect):
             try:
                 fk = fks[constraint_name]
             except KeyError:
-                fk = ([],[])
+                fk = ([], [])
                 fks[constraint_name] = fk
 
             # look up the table based on the given table's engine, not 'self',
@@ -438,7 +438,7 @@ class SQLiteCompiler(compiler.DefaultCompiler):
 class SQLiteSchemaGenerator(compiler.SchemaGenerator):
 
     def get_column_specification(self, column, **kwargs):
-        colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec()
+        colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
         default = self.get_column_default_string(column)
         if default is not None:
             colspec += " DEFAULT " + default
index 2551e90c53a4eb11141d0b529300a45419ec562e..14734c6e0e2717fc6d050a09dd3a916c3cc9fa1a 100644 (file)
@@ -24,7 +24,7 @@ Known issues / TODO:
 
 import datetime, operator
 
-from sqlalchemy import util, sql, schema, exceptions
+from sqlalchemy import util, sql, schema, exc
 from sqlalchemy.sql import compiler, expression
 from sqlalchemy.engine import default, base
 from sqlalchemy import types as sqltypes
@@ -160,11 +160,11 @@ class SybaseTypeError(sqltypes.TypeEngine):
 
     def bind_processor(self, dialect):
         def process(value):
-            raise exceptions.NotSupportedError("Data type not supported", [value])
+            raise exc.NotSupportedError("Data type not supported", [value])
         return process
 
     def get_col_spec(self):
-        raise exceptions.NotSupportedError("Data type not supported")
+        raise exc.NotSupportedError("Data type not supported")
 
 class SybaseNumeric(sqltypes.Numeric):
     def get_col_spec(self):
@@ -487,7 +487,7 @@ class SybaseSQLDialect(default.DefaultDialect):
                 dialect_cls = dialect_mapping[module_name]
                 return dialect_cls.import_dbapi()
             except KeyError:
-                raise exceptions.InvalidRequestError("Unsupported SybaseSQL module '%s' requested (must be " + " or ".join([x for x in dialect_mapping.keys()]) + ")" % module_name)
+                raise exc.InvalidRequestError("Unsupported SybaseSQL module '%s' requested (must be " + " or ".join([x for x in dialect_mapping.keys()]) + ")" % module_name)
         else:
             for dialect_cls in dialect_mapping.values():
                 try:
@@ -527,7 +527,7 @@ class SybaseSQLDialect(default.DefaultDialect):
             self.context.rowcount = c.rowcount
             c.DBPROP_COMMITPRESERVE = "Y"
         except Exception, e:
-            raise exceptions.DBAPIError.instance(statement, parameters, e)
+            raise exc.DBAPIError.instance(statement, parameters, e)
 
     def table_names(self, connection, schema):
         """Ignore the schema and the charset for now."""
@@ -597,7 +597,7 @@ class SybaseSQLDialect(default.DefaultDialect):
                               (type, name))
                     coltype = sqltypes.NULLTYPE
                 coltype = coltype(*args)
-            colargs= []
+            colargs = []
             if default is not None:
                 colargs.append(schema.PassiveDefault(sql.text(default)))
 
@@ -624,16 +624,16 @@ class SybaseSQLDialect(default.DefaultDialect):
                 row[0], row[1], row[2], row[3],
             )
             if not primary_table in foreignKeys.keys():
-                foreignKeys[primary_table] = [['%s'%(foreign_column)], ['%s.%s'%(primary_table,primary_column)]]
+                foreignKeys[primary_table] = [['%s' % (foreign_column)], ['%s.%s'%(primary_table, primary_column)]]
             else:
                 foreignKeys[primary_table][0].append('%s'%(foreign_column))
-                foreignKeys[primary_table][1].append('%s.%s'%(primary_table,primary_column))
+                foreignKeys[primary_table][1].append('%s.%s'%(primary_table, primary_column))
         for primary_table in foreignKeys.keys():
             #table.append_constraint(schema.ForeignKeyConstraint(['%s.%s'%(foreign_table, foreign_column)], ['%s.%s'%(primary_table,primary_column)]))
             table.append_constraint(schema.ForeignKeyConstraint(foreignKeys[primary_table][0], foreignKeys[primary_table][1]))
 
         if not found_table:
-            raise exceptions.NoSuchTableError(table.name)
+            raise exc.NoSuchTableError(table.name)
 
 
 class SybaseSQLDialect_mxodbc(SybaseSQLDialect):
@@ -749,7 +749,7 @@ class SybaseSQLCompiler(compiler.DefaultCompiler):
     def bindparam_string(self, name):
         res = super(SybaseSQLCompiler, self).bindparam_string(name)
         if name.lower().startswith('literal'):
-            res = 'STRING(%s)'%res
+            res = 'STRING(%s)' % res
         return res
 
     def get_select_precolumns(self, select):
@@ -828,7 +828,7 @@ class SybaseSQLSchemaGenerator(compiler.SchemaGenerator):
             #colspec += " numeric(30,0) IDENTITY"
             colspec += " Integer IDENTITY"
         else:
-            colspec += " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec()
+            colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec()
 
         if not column.nullable:
             colspec += " NOT NULL"
index 583a02763824944407a2119d443c0cc366757077..2ca2ac5f7150ae9ee9032de070f7202565beaab3 100644 (file)
@@ -13,7 +13,7 @@ and result contexts.
 """
 
 import inspect, StringIO, sys
-from sqlalchemy import exceptions, schema, util, types, logging
+from sqlalchemy import exc, schema, util, types, log
 from sqlalchemy.sql import expression
 
 
@@ -451,7 +451,7 @@ class Compiled(object):
         self.statement = statement
         self.column_keys = column_keys
         self.bind = bind
-        self.can_execute = statement.supports_execution()
+        self.can_execute = statement.supports_execution
 
     def compile(self):
         """Produce the internal string representation of this element."""
@@ -482,7 +482,7 @@ class Compiled(object):
 
         e = self.bind
         if e is None:
-            raise exceptions.UnboundExecutionError("This Compiled object is not bound to any Engine or Connection.")
+            raise exc.UnboundExecutionError("This Compiled object is not bound to any Engine or Connection.")
         return e._execute_compiled(self, multiparams, params)
 
     def scalar(self, *multiparams, **params):
@@ -541,7 +541,7 @@ class Connection(Connectable):
         self.__savepoint_seq = 0
         self.__branch = _branch
         self.__invalid = False
-
+        
     def _branch(self):
         """Return a new Connection which references this Connection's
         engine and connection; but does not have close_with_result enabled,
@@ -550,7 +550,7 @@ class Connection(Connectable):
         This is used to execute "sub" statements within a single execution,
         usually an INSERT statement.
         """
-        return Connection(self.engine, self.__connection, _branch=True)
+        return self.engine.Connection(self.engine, self.__connection, _branch=True)
 
     def dialect(self):
         "Dialect used by this Connection."
@@ -578,11 +578,11 @@ class Connection(Connectable):
         except AttributeError:
             if self.__invalid:
                 if self.__transaction is not None:
-                    raise exceptions.InvalidRequestError("Can't reconnect until invalid transaction is rolled back")
+                    raise exc.InvalidRequestError("Can't reconnect until invalid transaction is rolled back")
                 self.__connection = self.engine.raw_connection()
                 self.__invalid = False
                 return self.__connection
-            raise exceptions.InvalidRequestError("This Connection is closed")
+            raise exc.InvalidRequestError("This Connection is closed")
     connection = property(connection)
 
     def should_close_with_result(self):
@@ -702,7 +702,7 @@ class Connection(Connectable):
         """
 
         if self.__transaction is not None:
-            raise exceptions.InvalidRequestError(
+            raise exc.InvalidRequestError(
                 "Cannot start a two phase transaction when a transaction "
                 "is already in progress.")
         if xid is None:
@@ -843,7 +843,7 @@ class Connection(Connectable):
             if c in Connection.executors:
                 return Connection.executors[c](self, object, multiparams, params)
         else:
-            raise exceptions.InvalidRequestError("Unexecutable object type: " + str(type(object)))
+            raise exc.InvalidRequestError("Unexecutable object type: " + str(type(object)))
 
     def _execute_default(self, default, multiparams=None, params=None):
         return self.engine.dialect.defaultrunner(self.__create_execution_context()).traverse_single(default)
@@ -862,7 +862,7 @@ class Connection(Connectable):
         in the case of 'raw' execution which accepts positional parameters,
         it may be a list of tuples or lists."""
 
-        if multiparams is None or len(multiparams) == 0:
+        if not multiparams:
             if params:
                 return [params]
             else:
@@ -897,7 +897,7 @@ class Connection(Connectable):
     def _execute_compiled(self, compiled, multiparams=None, params=None, distilled_params=None):
         """Execute a sql.Compiled object."""
         if not compiled.can_execute:
-            raise exceptions.ArgumentError("Not an executable clause: %s" % (str(compiled)))
+            raise exc.ArgumentError("Not an executable clause: %s" % (str(compiled)))
 
         if distilled_params is None:
             distilled_params = self.__distill_params(multiparams, params)
@@ -924,7 +924,7 @@ class Connection(Connectable):
 
     def _handle_dbapi_exception(self, e, statement, parameters, cursor):
         if getattr(self, '_reentrant_error', False):
-            raise exceptions.DBAPIError.instance(None, None, e)
+            raise exc.DBAPIError.instance(None, None, e)
         self._reentrant_error = True
         try:
             if not isinstance(e, self.dialect.dbapi.Error):
@@ -939,7 +939,7 @@ class Connection(Connectable):
                 self._autorollback()
                 if self.__close_with_result:
                     self.close()
-            raise exceptions.DBAPIError.instance(statement, parameters, e, connection_invalidated=is_disconnect)
+            raise exc.DBAPIError.instance(statement, parameters, e, connection_invalidated=is_disconnect)
         finally:
             del self._reentrant_error
 
@@ -1047,7 +1047,7 @@ class Transaction(object):
 
     def commit(self):
         if not self._parent._is_active:
-            raise exceptions.InvalidRequestError("This transaction is inactive")
+            raise exc.InvalidRequestError("This transaction is inactive")
         self._do_commit()
         self._is_active = False
 
@@ -1094,7 +1094,7 @@ class TwoPhaseTransaction(Transaction):
 
     def prepare(self):
         if not self._parent._is_active:
-            raise exceptions.InvalidRequestError("This transaction is inactive")
+            raise exc.InvalidRequestError("This transaction is inactive")
         self._connection._prepare_twophase_impl(self.xid)
         self._is_prepared = True
 
@@ -1110,13 +1110,17 @@ class Engine(Connectable):
     provide a default implementation of SchemaEngine.
     """
 
-    def __init__(self, pool, dialect, url, echo=None):
+    def __init__(self, pool, dialect, url, echo=None, proxy=None):
         self.pool = pool
         self.url = url
-        self.dialect=dialect
+        self.dialect = dialect
         self.echo = echo
         self.engine = self
-        self.logger = logging.instance_logger(self, echoflag=echo)
+        self.logger = log.instance_logger(self, echoflag=echo)
+        if proxy:
+            self.Connection = _proxy_connection_cls(Connection, proxy)
+        else:
+            self.Connection = Connection
 
     def name(self):
         "String name of the [sqlalchemy.engine#Dialect] in use by this ``Engine``."
@@ -1124,7 +1128,7 @@ class Engine(Connectable):
         return sys.modules[self.dialect.__module__].descriptor()['name']
     name = property(name)
 
-    echo = logging.echo_property()
+    echo = log.echo_property()
 
     def __repr__(self):
         return 'Engine(%s)' % str(self.url)
@@ -1228,7 +1232,7 @@ class Engine(Connectable):
     def connect(self, **kwargs):
         """Return a newly allocated Connection object."""
 
-        return Connection(self, **kwargs)
+        return self.Connection(self, **kwargs)
 
     def contextual_connect(self, close_with_result=False, **kwargs):
         """Return a Connection object which may be newly allocated, or may be part of some ongoing context.
@@ -1236,7 +1240,7 @@ class Engine(Connectable):
         This Connection is meant to be used by the various "auto-connecting" operations.
         """
 
-        return Connection(self, self.pool.connect(), close_with_result=close_with_result, **kwargs)
+        return self.Connection(self, self.pool.connect(), close_with_result=close_with_result, **kwargs)
 
     def table_names(self, schema=None, connection=None):
         """Return a list of all table names available in the database.
@@ -1286,6 +1290,22 @@ class Engine(Connectable):
         return self.pool.unique_connection()
 
 
+def _proxy_connection_cls(cls, proxy):
+    class ProxyConnection(cls):
+        def execute(self, object, *multiparams, **params):
+            return proxy.execute(self, super(ProxyConnection, self).execute, object, *multiparams, **params)
+        def execute_clauseelement(self, elem, multiparams=None, params=None):
+            return proxy.execute(self, super(ProxyConnection, self).execute, elem, *(multiparams or []), **(params or {}))
+            
+        def _cursor_execute(self, cursor, statement, parameters, context=None):
+            return proxy.cursor_execute(super(ProxyConnection, self)._cursor_execute, cursor, statement, parameters, context, False)
+        def _cursor_executemany(self, cursor, statement, parameters, context=None):
+            return proxy.cursor_execute(super(ProxyConnection, self)._cursor_executemany, cursor, statement, parameters, context, True)
+
+    return ProxyConnection
+
 class RowProxy(object):
     """Proxy a single cursor row for a parent ResultProxy.
 
@@ -1296,6 +1316,8 @@ class RowProxy(object):
     results that correspond to constructed SQL expressions).
     """
 
+    __slots__ = ['__parent', '__row']
+    
     def __init__(self, parent, row):
         """RowProxy objects are constructed by ResultProxy objects."""
 
@@ -1488,14 +1510,14 @@ class ResultProxy(object):
                         return props[key._label.lower()]
                     elif hasattr(key, 'name') and key.name.lower() in props:
                         return props[key.name.lower()]
-                raise exceptions.NoSuchColumnError("Could not locate column in row for column '%s'" % (str(key)))
+                raise exc.NoSuchColumnError("Could not locate column in row for column '%s'" % (str(key)))
 
             return rec
         return util.PopulateDict(lookup_key)
 
     def __ambiguous_processor(self, colname):
         def process(value):
-            raise exceptions.InvalidRequestError("Ambiguous column name '%s' in result set! try 'use_labels' option on select statement." % colname)
+            raise exc.InvalidRequestError("Ambiguous column name '%s' in result set! try 'use_labels' option on select statement." % colname)
         return process
 
     def close(self):
index 3c1721f9d9e39ea396127e7e63a6c2ddaa6f2c1e..e39cbdd39dfbb22e133776cb5c198412d945af74 100644 (file)
@@ -12,7 +12,6 @@ as the base class for their own corresponding classes.
 
 """
 
-
 import re, random
 from sqlalchemy.engine import base
 from sqlalchemy.sql import compiler, expression
@@ -112,7 +111,7 @@ class DefaultDialect(base.Dialect):
         This id will be passed to do_begin_twophase(), do_rollback_twophase(),
         do_commit_twophase().  Its format is unspecified."""
 
-        return "_sa_%032x" % random.randint(0,2**128)
+        return "_sa_%032x" % random.randint(0, 2 ** 128)
 
     def do_savepoint(self, connection, name):
         connection.execute(expression.SavepointClause(name))
@@ -331,9 +330,9 @@ class DefaultExecutionContext(base.ExecutionContext):
         if self.dialect.positional:
             inputsizes = []
             for key in self.compiled.positiontup:
-               typeengine = types[key]
-               dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
-               if dbtype is not None:
+                typeengine = types[key]
+                dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi)
+                if dbtype is not None:
                     inputsizes.append(dbtype)
             try:
                 self.cursor.setinputsizes(*inputsizes)
@@ -395,4 +394,4 @@ class DefaultExecutionContext(base.ExecutionContext):
                 self._last_updated_params = compiled_parameters
 
             self.postfetch_cols = self.compiled.postfetch
-            self.prefetch_cols = self.compiled.prefetch
\ No newline at end of file
+            self.prefetch_cols = self.compiled.prefetch
index d4a0ad841881d4af47ba06c257a6fccf8a1a7fad..aab191231ce85635e7a9ef2189ef13719b9cda05 100644 (file)
@@ -12,7 +12,7 @@ classes.
 
 
 from sqlalchemy.engine import base, threadlocal, url
-from sqlalchemy import util, exceptions
+from sqlalchemy import util, exc
 from sqlalchemy import pool as poollib
 
 strategies = {}
@@ -77,7 +77,7 @@ class DefaultEngineStrategy(EngineStrategy):
                 try:
                     return dbapi.connect(*cargs, **cparams)
                 except Exception, e:
-                    raise exceptions.DBAPIError.instance(None, None, e)
+                    raise exc.DBAPIError.instance(None, None, e)
             creator = kwargs.pop('creator', connect)
 
             poolclass = (kwargs.pop('poolclass', None) or
@@ -200,7 +200,7 @@ class MockEngineStrategy(EngineStrategy):
 
         def create(self, entity, **kwargs):
             kwargs['checkfirst'] = False
-            self.dialect.schemagenerator(self.dialect ,self, **kwargs).traverse(entity)
+            self.dialect.schemagenerator(self.dialectself, **kwargs).traverse(entity)
 
         def drop(self, entity, **kwargs):
             kwargs['checkfirst'] = False
index e4b2859dc50283d146fecaf0009d7bc2d7da2c53..91b16ed5fa710b8fea115ae57145fc5459cf6448 100644 (file)
@@ -17,7 +17,7 @@ class TLSession(object):
         try:
             return self.__transaction._increment_connect()
         except AttributeError:
-            return TLConnection(self, self.engine.pool.connect(), close_with_result=close_with_result)
+            return self.engine.TLConnection(self, self.engine.pool.connect(), close_with_result=close_with_result)
 
     def reset(self):
         try:
@@ -81,11 +81,14 @@ class TLSession(object):
 
 
 class TLConnection(base.Connection):
-    def __init__(self, session, connection, close_with_result):
-        base.Connection.__init__(self, session.engine, connection, close_with_result=close_with_result)
+    def __init__(self, session, connection, **kwargs):
+        base.Connection.__init__(self, session.engine, connection, **kwargs)
         self.__session = session
         self.__opencount = 1
 
+    def _branch(self):
+        return self.engine.Connection(self.engine, self.connection, _branch=True)
+
     def session(self):
         return self.__session
     session = property(session)
@@ -168,6 +171,12 @@ class TLEngine(base.Engine):
         super(TLEngine, self).__init__(*args, **kwargs)
         self.context = util.ThreadLocal()
 
+        proxy = kwargs.get('proxy')
+        if proxy:
+            self.TLConnection = base._proxy_connection_cls(TLConnection, proxy)
+        else:
+            self.TLConnection = TLConnection
+
     def session(self):
         "Returns the current thread's TLSession"
         if not hasattr(self.context, 'session'):
index 7364f0227ca82f6dd739c273837803a539debc5b..72d09bf8589a08af1fb0d238d7c2daca799b74d1 100644 (file)
@@ -7,7 +7,7 @@ be used directly and is also accepted directly by ``create_engine()``.
 """
 
 import re, cgi, sys, urllib
-from sqlalchemy import exceptions
+from sqlalchemy import exc
 
 
 class URL(object):
@@ -53,7 +53,7 @@ class URL(object):
             self.port = int(port)
         else:
             self.port = None
-        self.database= database
+        self.database = database
         self.query = query or {}
 
     def __str__(self):
@@ -180,7 +180,7 @@ def _parse_rfc1738_args(name):
         name = components.pop('name')
         return URL(name, **components)
     else:
-        raise exceptions.ArgumentError(
+        raise exc.ArgumentError(
             "Could not parse rfc1738 URL from string '%s'" % name)
 
 def _parse_keyvalue_args(name):
similarity index 75%
rename from lib/sqlalchemy/exceptions.py
rename to lib/sqlalchemy/exc.py
index 43623df93fa0ce1efb4c239c19c059ef8c3c9802..71b46ca114590e93809876aa482dca1d62adbec7 100644 (file)
@@ -3,78 +3,78 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
+
 """Exceptions used with SQLAlchemy.
 
-The base exception class is SQLAlchemyError.  Exceptions which are raised as a result
-of DBAPI exceptions are all subclasses of [sqlalchemy.exceptions#DBAPIError]."""
+The base exception class is SQLAlchemyError.  Exceptions which are raised as a
+result of DBAPI exceptions are all subclasses of
+[sqlalchemy.exceptions#DBAPIError].
+
+"""
+
 
 class SQLAlchemyError(Exception):
     """Generic error class."""
 
 
 class ArgumentError(SQLAlchemyError):
-    """Raised for all those conditions where invalid arguments are
-    sent to constructed objects.  This error generally corresponds to
-    construction time state errors.
+    """Raised when an invalid or conflicting function argument is supplied.
+
+    This error generally corresponds to construction time state errors.
+
     """
 
 
+class CircularDependencyError(SQLAlchemyError):
+    """Raised by topological sorts when a circular dependency is detected"""
+
+
 class CompileError(SQLAlchemyError):
     """Raised when an error occurs during SQL compilation"""
 
 
-class TimeoutError(SQLAlchemyError):
-    """Raised when a connection pool times out on getting a connection."""
+# Moved to orm.exc; compatability definition installed by orm import until 0.6
+ConcurrentModificationError = None
 
+class DisconnectionError(SQLAlchemyError):
+    """A disconnect is detected on a raw DB-API connection.
 
-class ConcurrentModificationError(SQLAlchemyError):
-    """Raised when a concurrent modification condition is detected."""
+    This error is raised and consumed internally by a connection pool.  It can
+    be raised by a ``PoolListener`` so that the host pool forces a disconnect.
 
+    """
 
-class CircularDependencyError(SQLAlchemyError):
-    """Raised by topological sorts when a circular dependency is detected"""
 
+# Moved to orm.exc; compatability definition installed by orm import until 0.6
+FlushError = None
 
-class FlushError(SQLAlchemyError):
-    """Raised when an invalid condition is detected upon a ``flush()``."""
+class TimeoutError(SQLAlchemyError):
+    """Raised when a connection pool times out on getting a connection."""
 
 
 class InvalidRequestError(SQLAlchemyError):
-    """SQLAlchemy was asked to do something it can't do, return
-    nonexistent data, etc.
+    """SQLAlchemy was asked to do something it can't do.
 
     This error generally corresponds to runtime state errors.
-    """
-
-class UnmappedColumnError(InvalidRequestError):
-    """A mapper was asked to return mapped information about a column
-    which it does not map"""
 
-class NoSuchTableError(InvalidRequestError):
-    """SQLAlchemy was asked to load a table's definition from the
-    database, but the table doesn't exist.
     """
 
-class UnboundExecutionError(InvalidRequestError):
-    """SQL was attempted without a database connection to execute it on."""
+class NoSuchColumnError(KeyError, InvalidRequestError):
+    """A nonexistent column is requested from a ``RowProxy``."""
 
-class AssertionError(SQLAlchemyError):
-    """Corresponds to internal state being detected in an invalid state."""
-
-
-class NoSuchColumnError(KeyError, SQLAlchemyError):
-    """Raised by ``RowProxy`` when a nonexistent column is requested from a row."""
-    
 class NoReferencedTableError(InvalidRequestError):
     """Raised by ``ForeignKey`` when the referred ``Table`` cannot be located."""
 
-class DisconnectionError(SQLAlchemyError):
-    """Raised within ``Pool`` when a disconnect is detected on a raw DB-API connection.
+class NoSuchTableError(InvalidRequestError):
+    """Table does not exist or is not visible to a connection."""
 
-    This error is consumed internally by a connection pool.  It can be raised by
-    a ``PoolListener`` so that the host pool forces a disconnect.
-    """
 
+class UnboundExecutionError(InvalidRequestError):
+    """SQL was attempted without a database connection to execute it on."""
+
+
+# Moved to orm.exc; compatability definition installed by orm import until 0.6
+UnmappedColumnError = None
 
 class DBAPIError(SQLAlchemyError):
     """Raised when the execution of a database operation fails.
@@ -93,6 +93,7 @@ class DBAPIError(SQLAlchemyError):
 
     The wrapped exception object is available in the ``orig`` attribute.
     Its type and properties are DB-API implementation specific.
+
     """
 
     def instance(cls, statement, params, orig, connection_invalidated=False):
@@ -117,7 +118,7 @@ class DBAPIError(SQLAlchemyError):
         except Exception, e:
             text = 'Error in str() of DB-API-generated exception: ' + str(e)
         SQLAlchemyError.__init__(
-            self, "(%s) %s" % (orig.__class__.__name__, text))
+            self, '(%s) %s' % (orig.__class__.__name__, text))
         self.statement = statement
         self.params = params
         self.orig = orig
@@ -128,39 +129,51 @@ class DBAPIError(SQLAlchemyError):
                          repr(self.statement), repr(self.params)])
 
 
-# As of 0.4, SQLError is now DBAPIError
+# As of 0.4, SQLError is now DBAPIError.
+# SQLError alias will be removed in 0.6.
 SQLError = DBAPIError
 
 class InterfaceError(DBAPIError):
     """Wraps a DB-API InterfaceError."""
 
+
 class DatabaseError(DBAPIError):
     """Wraps a DB-API DatabaseError."""
 
+
 class DataError(DatabaseError):
     """Wraps a DB-API DataError."""
 
+
 class OperationalError(DatabaseError):
     """Wraps a DB-API OperationalError."""
 
+
 class IntegrityError(DatabaseError):
     """Wraps a DB-API IntegrityError."""
 
+
 class InternalError(DatabaseError):
     """Wraps a DB-API InternalError."""
 
+
 class ProgrammingError(DatabaseError):
     """Wraps a DB-API ProgrammingError."""
 
+
 class NotSupportedError(DatabaseError):
     """Wraps a DB-API NotSupportedError."""
 
+
 # Warnings
+
 class SADeprecationWarning(DeprecationWarning):
     """Issued once per usage of a deprecated API."""
 
+
 class SAPendingDeprecationWarning(PendingDeprecationWarning):
     """Issued once per usage of a deprecated API."""
 
+
 class SAWarning(RuntimeWarning):
     """Issued at runtime."""
diff --git a/lib/sqlalchemy/ext/activemapper.py b/lib/sqlalchemy/ext/activemapper.py
deleted file mode 100644 (file)
index 02f4b5b..0000000
+++ /dev/null
@@ -1,298 +0,0 @@
-from sqlalchemy             import ThreadLocalMetaData, util, Integer
-from sqlalchemy             import Table, Column, ForeignKey
-from sqlalchemy.orm         import class_mapper, relation, scoped_session
-from sqlalchemy.orm         import sessionmaker
-                                   
-from sqlalchemy.orm import backref as create_backref
-
-import inspect
-import sys
-
-#
-# the "proxy" to the database engine... this can be swapped out at runtime
-#
-metadata = ThreadLocalMetaData()
-Objectstore = scoped_session
-objectstore = scoped_session(sessionmaker(autoflush=True, transactional=False))
-
-#
-# declarative column declaration - this is so that we can infer the colname
-#
-class column(object):
-    def __init__(self, coltype, colname=None, foreign_key=None,
-                 primary_key=False, *args, **kwargs):
-        if isinstance(foreign_key, basestring): 
-            foreign_key = ForeignKey(foreign_key)
-        
-        self.coltype     = coltype
-        self.colname     = colname
-        self.foreign_key = foreign_key
-        self.primary_key = primary_key
-        self.kwargs      = kwargs
-        self.args        = args
-
-#
-# declarative relationship declaration
-#
-class relationship(object):
-    def __init__(self, classname, colname=None, backref=None, private=False,
-                 lazy=True, uselist=True, secondary=None, order_by=False, viewonly=False):
-        self.classname = classname
-        self.colname   = colname
-        self.backref   = backref
-        self.private   = private
-        self.lazy      = lazy
-        self.uselist   = uselist
-        self.secondary = secondary
-        self.order_by  = order_by
-        self.viewonly  = viewonly
-    
-    def process(self, klass, propname, relations):
-        relclass = ActiveMapperMeta.classes[self.classname]
-        
-        if isinstance(self.order_by, str):
-            self.order_by = [ self.order_by ]
-        
-        if isinstance(self.order_by, list):
-            for itemno in range(len(self.order_by)):
-                if isinstance(self.order_by[itemno], str):
-                    self.order_by[itemno] = \
-                        getattr(relclass.c, self.order_by[itemno])
-        
-        backref = self.create_backref(klass)
-        relations[propname] = relation(relclass.mapper,
-                                       secondary=self.secondary,
-                                       backref=backref, 
-                                       private=self.private, 
-                                       lazy=self.lazy, 
-                                       uselist=self.uselist,
-                                       order_by=self.order_by, 
-                                       viewonly=self.viewonly)
-    
-    def create_backref(self, klass):
-        if self.backref is None:
-            return None
-        
-        relclass = ActiveMapperMeta.classes[self.classname]
-        
-        if klass.__name__ == self.classname:
-            class_mapper(relclass).compile()
-            br_fkey = relclass.c[self.colname]
-        else:
-            br_fkey = None
-        
-        return create_backref(self.backref, remote_side=br_fkey)
-
-
-class one_to_many(relationship):
-    def __init__(self, *args, **kwargs):
-        kwargs['uselist'] = True
-        relationship.__init__(self, *args, **kwargs)
-
-class one_to_one(relationship):
-    def __init__(self, *args, **kwargs):
-        kwargs['uselist'] = False
-        relationship.__init__(self, *args, **kwargs)
-    
-    def create_backref(self, klass):
-        if self.backref is None:
-            return None
-        
-        relclass = ActiveMapperMeta.classes[self.classname]
-        
-        if klass.__name__ == self.classname:
-            br_fkey = getattr(relclass.c, self.colname)
-        else:
-            br_fkey = None
-        
-        return create_backref(self.backref, foreignkey=br_fkey, uselist=False)
-
-
-class many_to_many(relationship):
-    def __init__(self, classname, secondary, backref=None, lazy=True,
-                 order_by=False):
-        relationship.__init__(self, classname, None, backref, False, lazy,
-                              uselist=True, secondary=secondary,
-                              order_by=order_by)
-
-
-# 
-# SQLAlchemy metaclass and superclass that can be used to do SQLAlchemy 
-# mapping in a declarative way, along with a function to process the 
-# relationships between dependent objects as they come in, without blowing
-# up if the classes aren't specified in a proper order
-# 
-
-__deferred_classes__ = {}
-__processed_classes__ = {}
-def process_relationships(klass, was_deferred=False):
-    # first, we loop through all of the relationships defined on the
-    # class, and make sure that the related class already has been
-    # completely processed and defer processing if it has not
-    defer = False
-    for propname, reldesc in klass.relations.items():
-        found = (reldesc.classname == klass.__name__ or reldesc.classname in __processed_classes__)
-        if not found:
-            defer = True
-            break
-    
-    # next, we loop through all the columns looking for foreign keys
-    # and make sure that we can find the related tables (they do not 
-    # have to be processed yet, just defined), and we defer if we are 
-    # not able to find any of the related tables
-    if not defer:
-        for col in klass.columns:
-            if col.foreign_keys:
-                found = False
-                cn = col.foreign_keys[0]._colspec
-                table_name = cn[:cn.rindex('.')]
-                for other_klass in ActiveMapperMeta.classes.values():
-                    if other_klass.table.fullname.lower() == table_name.lower():
-                        found = True
-                        
-                if not found:
-                    defer = True
-                    break
-
-    if defer and not was_deferred:
-        __deferred_classes__[klass.__name__] = klass
-        
-    # if we are able to find all related and referred to tables, then
-    # we can go ahead and assign the relationships to the class
-    if not defer:
-        relations = {}
-        for propname, reldesc in klass.relations.items():
-            reldesc.process(klass, propname, relations)
-        
-        class_mapper(klass).add_properties(relations)
-        if klass.__name__ in __deferred_classes__: 
-            del __deferred_classes__[klass.__name__]
-        __processed_classes__[klass.__name__] = klass
-    
-    # finally, loop through the deferred classes and attempt to process
-    # relationships for them
-    if not was_deferred:
-        # loop through the list of deferred classes, processing the
-        # relationships, until we can make no more progress
-        last_count = len(__deferred_classes__) + 1
-        while last_count > len(__deferred_classes__):
-            last_count = len(__deferred_classes__)
-            deferred = __deferred_classes__.copy()
-            for deferred_class in deferred.values():
-                process_relationships(deferred_class, was_deferred=True)
-
-
-class ActiveMapperMeta(type):
-    classes = {}
-    metadatas = util.Set()
-    def __init__(cls, clsname, bases, dict):
-        table_name = clsname.lower()
-        columns    = []
-        relations  = {}
-        autoload   = False
-        _metadata  = getattr(sys.modules[cls.__module__], 
-                             "__metadata__", metadata)
-        version_id_col = None
-        version_id_col_object = None
-        table_opts = {}
-
-        if 'mapping' in dict:
-            found_pk = False
-            
-            members = inspect.getmembers(dict.get('mapping'))
-            for name, value in members:
-                if name == '__table__':
-                    table_name = value
-                    continue
-                
-                if '__metadata__' == name:
-                    _metadata= value
-                    continue
-                
-                if '__autoload__' == name:
-                    autoload = True
-                    continue
-                
-                if '__version_id_col__' == name:
-                    version_id_col = value
-                
-                if '__table_opts__' == name:
-                    table_opts = value
-
-                if name.startswith('__'): continue
-                
-                if isinstance(value, column):
-                    if value.primary_key == True: found_pk = True
-                        
-                    if value.foreign_key:
-                        col = Column(value.colname or name, 
-                                     value.coltype,
-                                     value.foreign_key, 
-                                     primary_key=value.primary_key,
-                                     *value.args, **value.kwargs)
-                    else:
-                        col = Column(value.colname or name,
-                                     value.coltype,
-                                     primary_key=value.primary_key,
-                                     *value.args, **value.kwargs)
-                    columns.append(col)
-                    continue
-                
-                if isinstance(value, relationship):
-                    relations[name] = value
-            
-            if not found_pk and not autoload:
-                col = Column('id', Integer, primary_key=True)
-                cls.mapping.id = col
-                columns.append(col)
-            
-            assert _metadata is not None, "No MetaData specified"
-            
-            ActiveMapperMeta.metadatas.add(_metadata)
-            
-            if not autoload:
-                cls.table = Table(table_name, _metadata, *columns, **table_opts)
-                cls.columns = columns
-            else:
-                cls.table = Table(table_name, _metadata, autoload=True, **table_opts)
-                cls.columns = cls.table._columns
-            
-            if version_id_col is not None:
-                version_id_col_object = getattr(cls.table.c, version_id_col, None)
-                assert(version_id_col_object is not None, "version_id_col (%s) does not exist." % version_id_col)
-
-            # check for inheritence
-            if hasattr(bases[0], "mapping"):
-                cls._base_mapper= bases[0].mapper
-                cls.mapper = objectstore.mapper(cls, cls.table, 
-                              inherits=cls._base_mapper, version_id_col=version_id_col_object)
-            else:
-                cls.mapper = objectstore.mapper(cls, cls.table, version_id_col=version_id_col_object)
-            cls.relations = relations
-            ActiveMapperMeta.classes[clsname] = cls
-            
-            process_relationships(cls)
-        
-        super(ActiveMapperMeta, cls).__init__(clsname, bases, dict)
-
-
-
-class ActiveMapper(object):
-    __metaclass__ = ActiveMapperMeta
-    
-    def set(self, **kwargs):
-        for key, value in kwargs.items():
-            setattr(self, key, value)
-
-
-#
-# a utility function to create all tables for all ActiveMapper classes
-#
-
-def create_tables():
-    for metadata in ActiveMapperMeta.metadatas:
-        metadata.create_all()
-
-def drop_tables():
-    for metadata in ActiveMapperMeta.metadatas:
-        metadata.drop_all()
diff --git a/lib/sqlalchemy/ext/assignmapper.py b/lib/sqlalchemy/ext/assignmapper.py
deleted file mode 100644 (file)
index 5a28fbe..0000000
+++ /dev/null
@@ -1,72 +0,0 @@
-from sqlalchemy import util, exceptions
-import types
-from sqlalchemy.orm import mapper, Query
-
-def _monkeypatch_query_method(name, ctx, class_):
-    def do(self, *args, **kwargs):
-        query = Query(class_, session=ctx.current)
-        util.warn_deprecated('Query methods on the class are deprecated; use %s.query.%s instead' % (class_.__name__, name))
-        return getattr(query, name)(*args, **kwargs)
-    try:
-        do.__name__ = name
-    except:
-        pass
-    if not hasattr(class_, name):
-        setattr(class_, name, classmethod(do))
-
-def _monkeypatch_session_method(name, ctx, class_):
-    def do(self, *args, **kwargs):
-        session = ctx.current
-        return getattr(session, name)(self, *args, **kwargs)
-    try:
-        do.__name__ = name
-    except:
-        pass
-    if not hasattr(class_, name):
-        setattr(class_, name, do)
-
-def assign_mapper(ctx, class_, *args, **kwargs):
-    extension = kwargs.pop('extension', None)
-    if extension is not None:
-        extension = util.to_list(extension)
-        extension.append(ctx.mapper_extension)
-    else:
-        extension = ctx.mapper_extension
-
-    validate = kwargs.pop('validate', False)
-
-    if not isinstance(getattr(class_, '__init__'), types.MethodType):
-        def __init__(self, **kwargs):
-             for key, value in kwargs.items():
-                 if validate:
-                     if not self.mapper.get_property(key,
-                                                     resolve_synonyms=False,
-                                                     raiseerr=False):
-                         raise exceptions.ArgumentError(
-                             "Invalid __init__ argument: '%s'" % key)
-                 setattr(self, key, value)
-        class_.__init__ = __init__
-
-    class query(object):
-        def __getattr__(self, key):
-            return getattr(ctx.current.query(class_), key)
-        def __call__(self):
-            return ctx.current.query(class_)
-
-    if not hasattr(class_, 'query'):
-        class_.query = query()
-
-    for name in ('get', 'filter', 'filter_by', 'select', 'select_by',
-                 'selectfirst', 'selectfirst_by', 'selectone', 'selectone_by',
-                 'get_by', 'join_to', 'join_via', 'count', 'count_by',
-                 'options', 'instances'):
-        _monkeypatch_query_method(name, ctx, class_)
-    for name in ('refresh', 'expire', 'delete', 'expunge', 'update'):
-        _monkeypatch_session_method(name, ctx, class_)
-
-    m = mapper(class_, extension=extension, *args, **kwargs)
-    class_.mapper = m
-    return m
-
-assign_mapper = util.deprecated(
-    "assign_mapper is deprecated. Use scoped_session() instead.")(assign_mapper)
index d878f7b9b002961234cd89756100e28330263df8..4d54f6072d286d98ad32477919c9a12889467b63 100644 (file)
@@ -406,13 +406,26 @@ class _AssociationList(object):
     def clear(self):
         del self.col[0:len(self.col)]
 
-    def __eq__(self, other): return list(self) == other
-    def __ne__(self, other): return list(self) != other
-    def __lt__(self, other): return list(self) < other
-    def __le__(self, other): return list(self) <= other
-    def __gt__(self, other): return list(self) > other
-    def __ge__(self, other): return list(self) >= other
-    def __cmp__(self, other): return cmp(list(self), other)
+    def __eq__(self, other):
+        return list(self) == other
+
+    def __ne__(self, other):
+        return list(self) != other
+
+    def __lt__(self, other):
+        return list(self) < other
+
+    def __le__(self, other):
+        return list(self) <= other
+
+    def __gt__(self, other):
+        return list(self) > other
+
+    def __ge__(self, other):
+        return list(self) >= other
+
+    def __cmp__(self, other):
+        return cmp(list(self), other)
 
     def __add__(self, iterable):
         try:
@@ -534,13 +547,26 @@ class _AssociationDict(object):
     def clear(self):
         self.col.clear()
 
-    def __eq__(self, other): return dict(self) == other
-    def __ne__(self, other): return dict(self) != other
-    def __lt__(self, other): return dict(self) < other
-    def __le__(self, other): return dict(self) <= other
-    def __gt__(self, other): return dict(self) > other
-    def __ge__(self, other): return dict(self) >= other
-    def __cmp__(self, other): return cmp(dict(self), other)
+    def __eq__(self, other):
+        return dict(self) == other
+
+    def __ne__(self, other):
+        return dict(self) != other
+
+    def __lt__(self, other):
+        return dict(self) < other
+
+    def __le__(self, other):
+        return dict(self) <= other
+
+    def __gt__(self, other):
+        return dict(self) > other
+
+    def __ge__(self, other):
+        return dict(self) >= other
+
+    def __cmp__(self, other):
+        return cmp(dict(self), other)
 
     def __repr__(self):
         return repr(dict(self.items()))
@@ -802,12 +828,23 @@ class _AssociationSet(object):
     def copy(self):
         return util.Set(self)
 
-    def __eq__(self, other): return util.Set(self) == other
-    def __ne__(self, other): return util.Set(self) != other
-    def __lt__(self, other): return util.Set(self) < other
-    def __le__(self, other): return util.Set(self) <= other
-    def __gt__(self, other): return util.Set(self) > other
-    def __ge__(self, other): return util.Set(self) >= other
+    def __eq__(self, other):
+        return util.Set(self) == other
+
+    def __ne__(self, other):
+        return util.Set(self) != other
+
+    def __lt__(self, other):
+        return util.Set(self) < other
+
+    def __le__(self, other):
+        return util.Set(self) <= other
+
+    def __gt__(self, other):
+        return util.Set(self) > other
+
+    def __ge__(self, other):
+        return util.Set(self) >= other
 
     def __repr__(self):
         return repr(util.Set(self))
index d736736e953ab72822097a0d46f50ee2a63e7d85..f06f16059fae54314b826d27e10a141ed81faf34 100644 (file)
@@ -213,6 +213,9 @@ class DeclarativeMeta(type):
                 continue
             prop = _deferred_relation(cls, value)
             our_stuff[k] = prop
+        
+        # set up attributes in the order they were created
+        our_stuff.sort(lambda x, y: cmp(our_stuff[x]._creation_order, our_stuff[y]._creation_order))
 
         table = None
         if '__table__' not in cls.__dict__:
@@ -254,6 +257,7 @@ class DeclarativeMeta(type):
             mapper_cls = util.unbound_method_to_callable(cls.__mapper_cls__)
         else:
             mapper_cls = mapper
+
         cls.__mapper__ = mapper_cls(cls, table, properties=our_stuff, **mapper_args)
         return type.__init__(cls, classname, bases, dict_)
 
index e7464b0bdd2884b3a9357f9e34cdfeb5ff6350bd..21adc85a8fede308d867446f0f7b6b5554d2b2f6 100644 (file)
@@ -34,7 +34,7 @@ which have a user-defined, serialized order::
   u = User()
   u.topten.append(Blurb('Number one!'))
   u.topten.append(Blurb('Number two!'))
-  
+
   # Like magic.
   assert [blurb.position for blurb in u.topten] == [0, 1]
 
@@ -60,7 +60,7 @@ __all__ = [ 'ordering_list' ]
 
 def ordering_list(attr, count_from=None, **kw):
     """Prepares an OrderingList factory for use in mapper definitions.
-    
+
     Returns an object suitable for use as an argument to a Mapper relation's
     ``collection_class`` option.  Arguments are:
 
@@ -73,7 +73,7 @@ def ordering_list(attr, count_from=None, **kw):
       example, ``ordering_list('pos', count_from=1)`` would create a 1-based
       list in SQL, storing the value in the 'pos' column.  Ignored if
       ``ordering_func`` is supplied.
-      
+
     Passes along any keyword arguments to ``OrderingList`` constructor.
     """
 
@@ -108,7 +108,7 @@ def _unsugar_count_from(**kw):
     Keyword argument filter, prepares a simple ``ordering_func`` from a
     ``count_from`` argument, otherwise passes ``ordering_func`` on unchanged.
     """
-    
+
     count_from = kw.pop('count_from', None)
     if kw.get('ordering_func', None) is None and count_from is not None:
         if count_from == 0:
@@ -126,11 +126,11 @@ class OrderingList(list):
     ``ordering_list`` function is used to configure ``OrderingList``
     collections in ``mapper`` relation definitions.
     """
-    
+
     def __init__(self, ordering_attr=None, ordering_func=None,
                  reorder_on_append=False):
         """A custom list that manages position information for its children.
-        
+
         ``OrderingList`` is a ``collection_class`` list implementation that
         syncs position in a Python list with a position attribute on the
         mapped objects.
@@ -148,7 +148,7 @@ class OrderingList(list):
 
           An ``ordering_func`` is called with two positional parameters: the
           index of the element in the list, and the list itself.
-          
+
           If omitted, Python list indexes are used for the attribute values.
           Two basic pre-built numbering functions are provided in this module:
           ``count_from_0`` and ``count_from_1``.  For more exotic examples
@@ -194,7 +194,7 @@ class OrderingList(list):
     def _reorder(self):
         """Sweep through the list and ensure that each object has accurate
         ordering information set."""
-        
+
         for index, entity in enumerate(self):
             self._order_entity(index, entity, True)
 
@@ -206,7 +206,7 @@ class OrderingList(list):
             return
 
         should_be = self.ordering_func(index, self)
-        if have <> should_be:
+        if have != should_be:
             self._set_order_value(entity, should_be)
 
     def append(self, entity):
@@ -229,7 +229,7 @@ class OrderingList(list):
         entity = super(OrderingList, self).pop(index)
         self._reorder()
         return entity
-        
+
     def __setitem__(self, index, entity):
         if isinstance(index, slice):
             for i in range(index.start or 0, index.stop or 0, index.step or 1):
@@ -237,7 +237,7 @@ class OrderingList(list):
         else:
             self._order_entity(index, entity, True)
             super(OrderingList, self).__setitem__(index, entity)
-            
+
     def __delitem__(self, index):
         super(OrderingList, self).__delitem__(index)
         self._reorder()
diff --git a/lib/sqlalchemy/ext/selectresults.py b/lib/sqlalchemy/ext/selectresults.py
deleted file mode 100644 (file)
index 4462282..0000000
+++ /dev/null
@@ -1,28 +0,0 @@
-"""SelectResults has been rolled into Query.  This class is now just a placeholder."""
-
-import sqlalchemy.sql as sql
-import sqlalchemy.orm as orm
-
-class SelectResultsExt(orm.MapperExtension):
-    """a MapperExtension that provides SelectResults functionality for the
-    results of query.select_by() and query.select()"""
-    
-    def select_by(self, query, *args, **params):
-        q = query
-        for a in args:
-            q = q.filter(a)
-        return q.filter_by(**params)
-        
-    def select(self, query, arg=None, **kwargs):
-        if isinstance(arg, sql.FromClause) and arg.supports_execution():
-            return orm.EXT_CONTINUE
-        else:
-            if arg is not None:
-                query = query.filter(arg)
-            return query._legacy_select_kwargs(**kwargs)
-
-def SelectResults(query, clause=None, ops={}):
-    if clause is not None:
-        query = query.filter(clause)
-    query = query.options(orm.extension(SelectResultsExt()))
-    return query._legacy_select_kwargs(**ops)
diff --git a/lib/sqlalchemy/ext/sessioncontext.py b/lib/sqlalchemy/ext/sessioncontext.py
deleted file mode 100644 (file)
index 5ac8acb..0000000
+++ /dev/null
@@ -1,50 +0,0 @@
-from sqlalchemy.orm.scoping import ScopedSession, _ScopedExt
-from sqlalchemy.util import warn_deprecated
-from sqlalchemy.orm import create_session
-
-__all__ = ['SessionContext', 'SessionContextExt']
-
-
-class SessionContext(ScopedSession):
-    """Provides thread-local management of Sessions.
-
-    Usage::
-
-      context = SessionContext(sessionmaker(autoflush=True))
-
-    """
-
-    def __init__(self, session_factory=None, scopefunc=None):
-        warn_deprecated("SessionContext is deprecated.  Use scoped_session().")
-        if session_factory is None:
-            session_factory=create_session
-        super(SessionContext, self).__init__(session_factory, scopefunc=scopefunc)
-
-    def get_current(self):
-        return self.registry()
-
-    def set_current(self, session):
-        self.registry.set(session)
-
-    def del_current(self):
-        self.registry.clear()
-
-    current = property(get_current, set_current, del_current,
-                       """Property used to get/set/del the session in the current scope.""")
-
-    def _get_mapper_extension(self):
-        try:
-            return self._extension
-        except AttributeError:
-            self._extension = ext = SessionContextExt(self)
-            return ext
-
-    mapper_extension = property(_get_mapper_extension,
-                                doc="""Get a mapper extension that implements `get_session` using this context.  Deprecated.""")
-
-
-class SessionContextExt(_ScopedExt):
-    def __init__(self, *args, **kwargs):
-        warn_deprecated("SessionContextExt is deprecated.  Use ScopedSession(enhance_classes=True)")
-        super(SessionContextExt, self).__init__(*args, **kwargs)
-
index bad9ba5a80062dbbb06bb2833f21d9cc5b6168da..95971f78786c2e392a26a63b2a42a6e7498548e6 100644 (file)
@@ -210,7 +210,7 @@ Advanced Use
 Accessing the Session
 ---------------------
 
-SqlSoup uses a SessionContext to provide thread-local sessions.  You
+SqlSoup uses a ScopedSession to provide thread-local sessions.  You
 can get a reference to the current one like this::
 
     >>> from sqlalchemy.ext.sqlsoup import objectstore
@@ -325,7 +325,7 @@ Boring tests here.  Nothing of real expository value.
 from sqlalchemy import *
 from sqlalchemy import schema, sql
 from sqlalchemy.orm import *
-from sqlalchemy.ext.sessioncontext import SessionContext
+from sqlalchemy.orm.scoping import ScopedSession
 from sqlalchemy.exceptions import *
 from sqlalchemy.sql import expression
 
@@ -379,15 +379,24 @@ __all__ = ['PKNotFoundError', 'SqlSoup']
 #
 # thread local SessionContext
 #
-class Objectstore(SessionContext):
+class Objectstore(ScopedSession):
     def __getattr__(self, key):
-        return getattr(self.current, key)
+        if key.startswith('__'):        # dont trip the registry for module-level sweeps of things
+                                        # like '__bases__'.  the session gets bound to the
+                                        # module which is interfered with by other unit tests.
+                                        # (removal of mapper.get_session() revealed the issue)
+            raise AttributeError()
+        return getattr(self.registry(), key)
+    def current(self):
+        return self.registry()
+    current = property(current)
     def get_session(self):
-        return self.current
+        return self.registry()
 
 objectstore = Objectstore(create_session)
 
-class PKNotFoundError(SQLAlchemyError): pass
+class PKNotFoundError(SQLAlchemyError):
+    pass
 
 def _ddl_error(cls):
     msg = 'SQLSoup can only modify mapped Tables (found: %s)' \
@@ -439,7 +448,7 @@ def _is_outer_join(selectable):
 
 def _selectable_name(selectable):
     if isinstance(selectable, sql.Alias):
-        return _selectable_name(selectable.selectable)
+        return _selectable_name(selectable.element)
     elif isinstance(selectable, sql.Select):
         return ''.join([_selectable_name(s) for s in selectable.froms])
     elif isinstance(selectable, schema.Table):
@@ -457,7 +466,7 @@ def class_for_table(selectable, **mapper_kwargs):
         klass = TableClassType(mapname, (object,), {})
     else:
         klass = SelectableClassType(mapname, (object,), {})
-
+    
     def __cmp__(self, o):
         L = self.__class__.c.keys()
         L.sort()
@@ -482,12 +491,17 @@ def class_for_table(selectable, **mapper_kwargs):
     for m in ['__cmp__', '__repr__']:
         setattr(klass, m, eval(m))
     klass._table = selectable
+    klass.c = expression.ColumnCollection()
     mappr = mapper(klass,
                    selectable,
-                   extension=objectstore.mapper_extension,
+                   extension=objectstore.extension,
                    allow_null_pks=_is_outer_join(selectable),
                    **mapper_kwargs)
-    klass._query = Query(mappr)
+                   
+    for k in mappr.iterate_properties:
+        klass.c[k.key] = k.columns[0]
+
+    klass._query = objectstore.query_property()
     return klass
 
 class SqlSoup:
index eaad2576988fdb0bf58b970a633bd34a91ac5034..959989662d8115b871f1dc9aa54e0c4a75ecc419 100644 (file)
@@ -67,7 +67,7 @@ class PoolListener(object):
           The ``_ConnectionFairy`` which manages the connection for the span of
           the current checkout.
 
-        If you raise an ``exceptions.DisconnectionError``, the current
+        If you raise an ``exc.DisconnectionError``, the current
         connection will be disposed and a fresh connection retrieved.
         Processing of all checkout listeners will abort and restart
         using the new connection.
@@ -87,3 +87,24 @@ class PoolListener(object):
           The ``_ConnectionRecord`` that persistently manages the connection
 
         """
+
+class ConnectionProxy(object):
+    """Allows interception of statement execution by Connections.
+    
+    Subclass ``ConnectionProxy``, overriding either or both of 
+    ``execute()`` and ``cursor_execute()``  The default behavior is provided,
+    which is to call the given executor function with the remaining 
+    arguments.  The proxy is then connected to an engine via
+    ``create_engine(url, proxy=MyProxy())`` where ``MyProxy`` is
+    the user-defined ``ConnectionProxy`` class.
+    
+    """
+    def execute(self, conn, execute, clauseelement, *multiparams, **params):
+        """"""
+        return execute(clauseelement, *multiparams, **params)
+
+    def cursor_execute(self, execute, cursor, statement, parameters, context, executemany):
+        """"""
+        return execute(cursor, statement, parameters, context)
+
+        
similarity index 90%
rename from lib/sqlalchemy/logging.py
rename to lib/sqlalchemy/log.py
index 13872caa3850da70dba89e7f542d95e92ead3398..65100d4695571c45323021b29629f74e2bd0a7c8 100644 (file)
@@ -1,4 +1,4 @@
-# logging.py - adapt python logging module to SQLAlchemy
+# log.py - adapt python logging module to SQLAlchemy
 # Copyright (C) 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com
 #
 # This module is part of SQLAlchemy and is released under
@@ -28,25 +28,19 @@ is equivalent to::
     logger.setLevel(logging.DEBUG)
 """
 
-import sys, warnings
-import sqlalchemy.exceptions as sa_exc
+import logging
+import sys
 
-# py2.5 absolute imports will fix....
-logging = __import__('logging')
-
-# moved to sqlalchemy.exceptions.  this alias will be removed in 0.5.
-SADeprecationWarning = sa_exc.SADeprecationWarning
 
 rootlogger = logging.getLogger('sqlalchemy')
 if rootlogger.level == logging.NOTSET:
     rootlogger.setLevel(logging.WARN)
-warnings.filterwarnings("once", category=sa_exc.SADeprecationWarning)
 
 default_enabled = False
 def default_logging(name):
     global default_enabled
     if logging.getLogger(name).getEffectiveLevel() < logging.WARN:
-        default_enabled=True
+        default_enabled = True
     if not default_enabled:
         default_enabled = True
         handler = logging.StreamHandler(sys.stdout)
diff --git a/lib/sqlalchemy/mods/__init__.py b/lib/sqlalchemy/mods/__init__.py
deleted file mode 100644 (file)
index e69de29..0000000
diff --git a/lib/sqlalchemy/mods/selectresults.py b/lib/sqlalchemy/mods/selectresults.py
deleted file mode 100644 (file)
index 25bfa28..0000000
+++ /dev/null
@@ -1,7 +0,0 @@
-from sqlalchemy.ext.selectresults import SelectResultsExt
-from sqlalchemy.orm.mapper import global_extensions
-
-def install_plugin():
-    global_extensions.append(SelectResultsExt)
-
-install_plugin()
index 2466a27637bf9b0e16541d5a600312b5b8d3e136..9c23fd409cfce5dc2b822dcffdb75f3842b46c90 100644 (file)
@@ -9,63 +9,98 @@ Functional constructs for ORM configuration.
 
 See the SQLAlchemy object relational tutorial and mapper configuration
 documentation for an overview of how this module is used.
+
 """
 
-from sqlalchemy.orm.mapper import Mapper, object_mapper, class_mapper, _mapper_registry
-from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE, EXT_STOP, EXT_PASS, ExtensionOption, PropComparator
-from sqlalchemy.orm.properties import SynonymProperty, ComparableProperty, PropertyLoader, ColumnProperty, CompositeProperty, BackRef
+from sqlalchemy.orm import exc
+from sqlalchemy.orm.mapper import \
+     Mapper, _mapper_registry, class_mapper, object_mapper
+from sqlalchemy.orm.interfaces import \
+     EXT_CONTINUE, EXT_STOP, ExtensionOption, InstrumentationManager, \
+     MapperExtension, PropComparator, SessionExtension
+from sqlalchemy.orm.properties import \
+     BackRef, ColumnProperty, ComparableProperty, CompositeProperty, \
+     PropertyLoader, SynonymProperty
 from sqlalchemy.orm import mapper as mapperlib
 from sqlalchemy.orm import strategies
-from sqlalchemy.orm.query import Query, aliased
-from sqlalchemy.orm.util import polymorphic_union, create_row_adapter
+from sqlalchemy.orm.query import AliasOption, Query
+from sqlalchemy.orm.util import \
+     AliasedClass as aliased, join, outerjoin, polymorphic_union, with_parent
+from sqlalchemy.sql import util as sql_util
 from sqlalchemy.orm.session import Session as _Session
 from sqlalchemy.orm.session import object_session, sessionmaker
 from sqlalchemy.orm.scoping import ScopedSession
-
-
-__all__ = [ 'relation', 'column_property', 'composite', 'backref', 'eagerload',
-            'eagerload_all', 'lazyload', 'noload', 'deferred', 'defer',
-            'undefer', 'undefer_group', 'extension', 'mapper', 'clear_mappers',
-            'compile_mappers', 'class_mapper', 'object_mapper', 'sessionmaker',
-            'scoped_session', 'dynamic_loader', 'MapperExtension',
-            'polymorphic_union', 'comparable_property',
-            'create_session', 'synonym', 'contains_alias', 'Query', 'aliased',
-            'contains_eager', 'EXT_CONTINUE', 'EXT_STOP', 'EXT_PASS',
-            'object_session', 'PropComparator' ]
+from sqlalchemy import util as sa_util
+
+__all__ = (
+    'EXT_CONTINUE',
+    'EXT_STOP',
+    'InstrumentationManager',
+    'MapperExtension',
+    'PropComparator',
+    'Query',
+    'aliased',
+    'backref',
+    'class_mapper',
+    'clear_mappers',
+    'column_property',
+    'comparable_property',
+    'compile_mappers',
+    'composite',
+    'contains_alias',
+    'contains_eager',
+    'create_session',
+    'defer',
+    'deferred',
+    'dynamic_loader',
+    'eagerload',
+    'eagerload_all',
+    'extension',
+    'lazyload',
+    'mapper',
+    'noload',
+    'object_mapper',
+    'object_session',
+    'polymorphic_union',
+    'relation',
+    'scoped_session',
+    'sessionmaker',
+    'synonym',
+    'undefer',
+    'undefer_group',
+    )
 
 
 def scoped_session(session_factory, scopefunc=None):
-  """Provides thread-local management of Sessions.
-
-  This is a front-end function to the [sqlalchemy.orm.scoping#ScopedSession]
-  class.
+    """Provides thread-local management of Sessions.
 
-  Usage::
+    This is a front-end function to the [sqlalchemy.orm.scoping#ScopedSession]
+    class.
 
-    Session = scoped_session(sessionmaker(autoflush=True))
+    Usage::
 
-  To instantiate a Session object which is part of the scoped
-  context, instantiate normally::
+      Session = scoped_session(sessionmaker(autoflush=True))
 
-    session = Session()
+    To instantiate a Session object which is part of the scoped context,
+    instantiate normally::
 
-  Most session methods are available as classmethods from
-  the scoped session::
+      session = Session()
 
-    Session.commit()
-    Session.close()
+    Most session methods are available as classmethods from the scoped
+    session::
 
-  To map classes so that new instances are saved in the current
-  Session automatically, as well as to provide session-aware
-  class attributes such as "query", use the `mapper` classmethod
-  from the scoped session::
+      Session.commit()
+      Session.close()
 
-    mapper = Session.mapper
-    mapper(Class, table, ...)
+    To map classes so that new instances are saved in the current Session
+    automatically, as well as to provide session-aware class attributes such
+    as "query", use the `mapper` classmethod from the scoped session::
 
-  """
+      mapper = Session.mapper
+      mapper(Class, table, ...)
 
-  return ScopedSession(session_factory, scopefunc=scopefunc)
+    """
+    return ScopedSession(session_factory, scopefunc=scopefunc)
 
 def create_session(bind=None, **kwargs):
     """create a new [sqlalchemy.orm.session#Session].
@@ -76,26 +111,36 @@ def create_session(bind=None, **kwargs):
     It is recommended to use the [sqlalchemy.orm#sessionmaker()] function
     instead of create_session().
     """
+
+    if 'transactional' in kwargs:
+        sa_util.warn_deprecated(
+            "The 'transactional' argument to sessionmaker() is deprecated; "
+            "use autocommit=True|False instead.")
+        if 'autocommit' in kwargs:
+            raise TypeError('Specify autocommit *or* transactional, not both.')
+        kwargs['autocommit'] = not kwargs.pop('transactional')
+
     kwargs.setdefault('autoflush', False)
-    kwargs.setdefault('transactional', False)
+    kwargs.setdefault('autocommit', True)
+    kwargs.setdefault('autoexpire', False)
     return _Session(bind=bind, **kwargs)
 
 def relation(argument, secondary=None, **kwargs):
     """Provide a relationship of a primary Mapper to a secondary Mapper.
 
-    This corresponds to a parent-child or associative table relationship.
-    The constructed class is an instance of [sqlalchemy.orm.properties#PropertyLoader].
+    This corresponds to a parent-child or associative table relationship.  The
+    constructed class is an instance of
+    [sqlalchemy.orm.properties#PropertyLoader].
 
       argument
           a class or Mapper instance, representing the target of the relation.
 
       secondary
         for a many-to-many relationship, specifies the intermediary table. The
-        ``secondary`` keyword argument should generally only be used for a table
-        that is not otherwise expressed in any class mapping. In particular,
-        using the Association Object Pattern is
-        generally mutually exclusive against using the ``secondary`` keyword
-        argument.
+        ``secondary`` keyword argument should generally only be used for a
+        table that is not otherwise expressed in any class mapping. In
+        particular, using the Association Object Pattern is generally mutually
+        exclusive against using the ``secondary`` keyword argument.
 
       \**kwargs follow:
 
@@ -482,8 +527,8 @@ def mapper(class_, local_table=None, *args, **params):
         which will identify the class/mapper combination to be used
         with a particular row.  Requires the ``polymorphic_identity``
         value to be set for all mappers in the inheritance
-        hierarchy.  The column specified by ``polymorphic_on`` is 
-        usually a column that resides directly within the base 
+        hierarchy.  The column specified by ``polymorphic_on`` is
+        usually a column that resides directly within the base
         mapper's mapped table; alternatively, it may be a column
         that is only present within the <selectable> portion
         of the ``with_polymorphic`` argument.
@@ -532,7 +577,7 @@ def mapper(class_, local_table=None, *args, **params):
         to be used against this mapper's selectable unit.  This is
         normally simply the primary key of the `local_table`, but
         can be overridden here.
-    
+
       with_polymorphic
         A tuple in the form ``(<classes>, <selectable>)`` indicating the
         default style of "polymorphic" loading, that is, which tables
@@ -549,9 +594,9 @@ def mapper(class_, local_table=None, *args, **params):
         which load from a "concrete" inheriting table, the <selectable>
         argument is required, since it usually requires more complex
         UNION queries.
-        
+
       select_table
-        Deprecated.  Synonymous with 
+        Deprecated.  Synonymous with
         ``with_polymorphic=('*', <selectable>)``.
 
       version_id_col
@@ -677,15 +722,16 @@ def extension(ext):
 
     return ExtensionOption(ext)
 
-def eagerload(name, mapper=None):
+def eagerload(*keys):
     """Return a ``MapperOption`` that will convert the property of the given name into an eager load.
 
     Used with ``query.options()``.
     """
 
-    return strategies.EagerLazyOption(name, lazy=False, mapper=mapper)
+    return strategies.EagerLazyOption(keys, lazy=False)
+eagerload = sa_util.array_as_starargs_fn_decorator(eagerload)
 
-def eagerload_all(name, mapper=None):
+def eagerload_all(*keys):
     """Return a ``MapperOption`` that will convert all properties along the given dot-separated path into an eager load.
 
     For example, this::
@@ -698,25 +744,27 @@ def eagerload_all(name, mapper=None):
     Used with ``query.options()``.
     """
 
-    return strategies.EagerLazyOption(name, lazy=False, chained=True, mapper=mapper)
+    return strategies.EagerLazyOption(keys, lazy=False, chained=True)
+eagerload_all = sa_util.array_as_starargs_fn_decorator(eagerload_all)
 
-def lazyload(name, mapper=None):
+def lazyload(*keys):
     """Return a ``MapperOption`` that will convert the property of the
     given name into a lazy load.
 
     Used with ``query.options()``.
     """
 
-    return strategies.EagerLazyOption(name, lazy=True, mapper=mapper)
+    return strategies.EagerLazyOption(keys, lazy=True)
+lazyload = sa_util.array_as_starargs_fn_decorator(lazyload)
 
-def noload(name):
+def noload(*keys):
     """Return a ``MapperOption`` that will convert the property of the
     given name into a non-load.
 
     Used with ``query.options()``.
     """
 
-    return strategies.EagerLazyOption(name, lazy=None)
+    return strategies.EagerLazyOption(keys, lazy=None)
 
 def contains_alias(alias):
     """Return a ``MapperOption`` that will indicate to the query that
@@ -726,22 +774,9 @@ def contains_alias(alias):
     alias.
     """
 
-    class AliasedRow(MapperExtension):
-        def __init__(self, alias):
-            self.alias = alias
-            if isinstance(self.alias, basestring):
-                self.translator = None
-            else:
-                self.translator = create_row_adapter(alias)
-        
-        def translate_row(self, mapper, context, row):
-            if not self.translator:
-                self.translator = create_row_adapter(mapper.mapped_table.alias(self.alias))
-            return self.translator(row)
-
-    return ExtensionOption(AliasedRow(alias))
+    return AliasOption(alias)
 
-def contains_eager(key, alias=None, decorator=None):
+def contains_eager(*keys, **kwargs):
     """Return a ``MapperOption`` that will indicate to the query that
     the given attribute will be eagerly loaded.
 
@@ -752,30 +787,31 @@ def contains_eager(key, alias=None, decorator=None):
     `alias` is the string name of an alias, **or** an ``sql.Alias``
     object, which represents the aliased columns in the query.  This
     argument is optional.
-
-    `decorator` is mutually exclusive of `alias` and is a
-    row-processing function which will be applied to the incoming row
-    before sending to the eager load handler.  use this for more
-    sophisticated row adjustments beyond a straight alias.
     """
+    alias = kwargs.pop('alias', None)
+    if kwargs:
+        raise exceptions.ArgumentError("Invalid kwargs for contains_eager: %r" % kwargs.keys())
+        
+    return (strategies.EagerLazyOption(keys, lazy=False), strategies.LoadEagerFromAliasOption(keys, alias=alias))
+contains_eager = sa_util.array_as_starargs_fn_decorator(contains_eager)
 
-    return (strategies.EagerLazyOption(key, lazy=False), strategies.RowDecorateOption(key, alias=alias, decorator=decorator))
-
-def defer(name):
+def defer(*keys):
     """Return a ``MapperOption`` that will convert the column property
     of the given name into a deferred load.
 
     Used with ``query.options()``"""
-    return strategies.DeferredOption(name, defer=True)
+    return strategies.DeferredOption(keys, defer=True)
+defer = sa_util.array_as_starargs_fn_decorator(defer)
 
-def undefer(name):
+def undefer(*keys):
     """Return a ``MapperOption`` that will convert the column property
     of the given name into a non-deferred (regular column) load.
 
     Used with ``query.options()``.
     """
 
-    return strategies.DeferredOption(name, defer=False)
+    return strategies.DeferredOption(keys, defer=False)
+undefer = sa_util.array_as_starargs_fn_decorator(undefer)
 
 def undefer_group(name):
     """Return a ``MapperOption`` that will convert the given
index fb0621a70f5320b861f17cd49e59e7867cba1c6b..7ce825c9d759d3e21ccb0dfc72bcd9db82cf6950 100644 (file)
@@ -4,26 +4,69 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-import operator, weakref
-from itertools import chain
-import UserDict
+import operator
+import weakref
+
 from sqlalchemy import util
+from sqlalchemy.util import attrgetter, itemgetter, EMPTY_SET
 from sqlalchemy.orm import interfaces, collections
-from sqlalchemy.orm.util import identity_equal
-from sqlalchemy import exceptions
+import sqlalchemy.exceptions as sa_exc
+
+# lazy imports
+_entity_info = None
+identity_equal = None
 
 PASSIVE_NORESULT = util.symbol('PASSIVE_NORESULT')
 ATTR_WAS_SET = util.symbol('ATTR_WAS_SET')
 NO_VALUE = util.symbol('NO_VALUE')
 NEVER_SET = util.symbol('NEVER_SET')
+NO_ENTITY_NAME = util.symbol('NO_ENTITY_NAME')
 
-class InstrumentedAttribute(interfaces.PropComparator):
-    """public-facing instrumented attribute, placed in the
-    class dictionary.
+INSTRUMENTATION_MANAGER = '__sa_instrumentation_manager__'
+"""Attribute, elects custom instrumentation when present on a mapped class.
 
-    """
+Allows a class to specify a slightly or wildly different technique for
+tracking changes made to mapped attributes and collections.
+
+Only one instrumentation implementation is allowed in a given object
+inheritance hierarchy.
+
+The value of this attribute must be a callable and will be passed a class
+object.  The callable must return one of:
+
+  - An instance of an interfaces.InstrumentationManager or subclass
+  - An object implementing all or some of InstrumentationManager (todo)
+  - A dictionary of callables, implementing all or some of the above (todo)
+  - An instance of a ClassManager or subclass
+
+interfaces.InstrumentationManager is public API and will remain stable
+between releases.  ClassManager is not public and no guarantees are made
+about stability.  Caveat emptor.
+
+This attribute is consulted by the default SQLAlchemy instrumentation
+resultion code.  If custom finders are installed in the global
+instrumentation_finders list, they may or may not choose to honor this
+attribute.
+
+"""
 
-    def __init__(self, impl, comparator=None):
+instrumentation_finders = []
+"""An extensible sequence of instrumentation implementation finding callables.
+
+Finders callables will be passed a class object.  If None is returned, the
+next finder in the sequence is consulted.  Otherwise the return must be an
+instrumentation factory that follows the same guidelines as
+INSTRUMENTATION_MANAGER.
+
+By default, the only finder is find_native_user_instrumentation_hook, which
+searches for INSTRUMENTATION_MANAGER.  If all finders return None, standard
+ClassManager instrumentation is used.
+
+"""
+    
+class QueryableAttribute(interfaces.PropComparator):
+
+    def __init__(self, impl, comparator=None, parententity=None):
         """Construct an InstrumentedAttribute.
         comparator
           a sql.Comparator to which class-level compare/math events will be sent
@@ -31,76 +74,52 @@ class InstrumentedAttribute(interfaces.PropComparator):
 
         self.impl = impl
         self.comparator = comparator
+        self.parententity = parententity
 
-    def __set__(self, instance, value):
-        self.impl.set(instance._state, value, None)
-
-    def __delete__(self, instance):
-        self.impl.delete(instance._state)
-
-    def __get__(self, instance, owner):
-        if instance is None:
-            return self
-        return self.impl.get(instance._state)
+        if parententity:
+            mapper, selectable, is_aliased_class = _entity_info(parententity, compile=False)
+            self.property = mapper._get_property(self.impl.key)
+        else:
+            self.property = None
 
     def get_history(self, instance, **kwargs):
-        return self.impl.get_history(instance._state, **kwargs)
-
-    def clause_element(self):
-        return self.comparator.clause_element()
-
-    def expression_element(self):
-        return self.comparator.expression_element()
-
+        return self.impl.get_history(instance_state(instance), **kwargs)
+    
+    def __selectable__(self):
+        # TODO: conditionally attach this method based on clause_element ?
+        return self
+    
+    def __clause_element__(self):
+        return self.comparator.__clause_element__()
+    
+    def label(self, name):
+        return self.__clause_element__().label(name)
+        
     def operate(self, op, *other, **kwargs):
         return op(self.comparator, *other, **kwargs)
 
     def reverse_operate(self, op, other, **kwargs):
         return op(other, self.comparator, **kwargs)
 
-    def hasparent(self, instance, optimistic=False):
-        return self.impl.hasparent(instance._state, optimistic=optimistic)
-
-    def _property(self):
-        from sqlalchemy.orm.mapper import class_mapper
-        return class_mapper(self.impl.class_).get_property(self.impl.key)
-    property = property(_property, doc="the MapperProperty object associated with this attribute")
-
-class ProxiedAttribute(InstrumentedAttribute):
-    """Adds InstrumentedAttribute class-level behavior to a regular descriptor.
-
-    Obsoleted by proxied_attribute_factory.
-    """
+    def hasparent(self, state, optimistic=False):
+        return self.impl.hasparent(state, optimistic=optimistic)
 
-    class ProxyImpl(object):
-        accepts_scalar_loader = False
+    def __str__(self):
+        return repr(self.parententity) + "." + self.property.key
 
-        def __init__(self, key):
-            self.key = key
+class InstrumentedAttribute(QueryableAttribute):
+    """Public-facing descriptor, placed in the mapped class dictionary."""
 
-    def __init__(self, key, user_prop, comparator=None):
-        self.user_prop = user_prop
-        self._comparator = comparator
-        self.key = key
-        self.impl = ProxiedAttribute.ProxyImpl(key)
+    def __set__(self, instance, value):
+        self.impl.set(instance_state(instance), value, None)
 
-    def comparator(self):
-        if callable(self._comparator):
-            self._comparator = self._comparator()
-        return self._comparator
-    comparator = property(comparator)
+    def __delete__(self, instance):
+        self.impl.delete(instance_state(instance))
 
     def __get__(self, instance, owner):
         if instance is None:
-            self.user_prop.__get__(instance, owner)
             return self
-        return self.user_prop.__get__(instance, owner)
-
-    def __set__(self, instance, value):
-        return self.user_prop.__set__(instance, value)
-
-    def __delete__(self, instance):
-        return self.user_prop.__delete__(instance)
+        return self.impl.get(instance_state(instance))
 
 def proxied_attribute_factory(descriptor):
     """Create an InstrumentedAttribute / user descriptor hybrid.
@@ -111,17 +130,19 @@ def proxied_attribute_factory(descriptor):
 
     class ProxyImpl(object):
         accepts_scalar_loader = False
+        
         def __init__(self, key):
             self.key = key
-
+        
     class Proxy(InstrumentedAttribute):
         """A combination of InsturmentedAttribute and a regular descriptor."""
 
-        def __init__(self, key, descriptor, comparator):
+        def __init__(self, key, descriptor, comparator, parententity):
             self.key = key
             # maintain ProxiedAttribute.user_prop compatability.
             self.descriptor = self.user_prop = descriptor
             self._comparator = comparator
+            self._parententity = parententity
             self.impl = ProxyImpl(key)
 
         def comparator(self):
@@ -148,6 +169,11 @@ def proxied_attribute_factory(descriptor):
         def __getattr__(self, attribute):
             """Delegate __getattr__ to the original descriptor."""
             return getattr(descriptor, attribute)
+            
+        def _property(self):
+            return self._parententity.get_property(self.key, resolve_synonyms=True)
+        property = property(_property)
+        
     Proxy.__name__ = type(descriptor).__name__ + 'Proxy'
 
     util.monkeypatch_proxied_specials(Proxy, type(descriptor),
@@ -158,7 +184,7 @@ def proxied_attribute_factory(descriptor):
 class AttributeImpl(object):
     """internal implementation for instrumented attributes."""
 
-    def __init__(self, class_, key, callable_, trackparent=False, extension=None, compare_function=None, **kwargs):
+    def __init__(self, class_, key, callable_, class_manager, trackparent=False, extension=None, compare_function=None, **kwargs):
         """Construct an AttributeImpl.
 
         class_
@@ -190,6 +216,7 @@ class AttributeImpl(object):
         self.class_ = class_
         self.key = key
         self.callable_ = callable_
+        self.class_manager = class_manager
         self.trackparent = trackparent
         if compare_function is None:
             self.is_equal = operator.eq
@@ -210,16 +237,16 @@ class AttributeImpl(object):
 
         An instance attribute that is loaded by a callable function
         will also not have a `hasparent` flag.
-        """
 
+        """
         return state.parents.get(id(self), optimistic)
 
     def sethasparent(self, state, value):
         """Set a boolean flag on the given item corresponding to
         whether or not it is attached to a parent object via the
         attribute represented by this ``InstrumentedAttribute``.
-        """
 
+        """
         state.parents[id(self)] = value
 
     def set_callable(self, state, callable_):
@@ -235,8 +262,8 @@ class AttributeImpl(object):
 
         The callable overrides the class level callable set in the
         ``InstrumentedAttribute` constructor.
-        """
 
+        """
         if callable_ is None:
             self.initialize(state)
         else:
@@ -249,7 +276,7 @@ class AttributeImpl(object):
         if self.key in state.callables:
             return state.callables[self.key]
         elif self.callable_ is not None:
-            return self.callable_(state.obj())
+            return self.callable_(state)
         else:
             return None
 
@@ -271,7 +298,7 @@ class AttributeImpl(object):
             return state.dict[self.key]
         except KeyError:
             # if no history, check for lazy callables, etc.
-            if self.key not in state.committed_state:
+            if state.committed_state.get(self.key, NEVER_SET) is NEVER_SET:
                 callable_ = self._get_callable(state)
                 if callable_ is not None:
                     if passive:
@@ -310,34 +337,54 @@ class AttributeImpl(object):
     def set_committed_value(self, state, value):
         """set an attribute value on the given instance and 'commit' it."""
 
-        state.commit_attr(self, value)
+        state.commit([self.key])
+        
+        state.callables.pop(self.key, None)
+        state.dict[self.key] = value
+        
         return value
 
 class ScalarAttributeImpl(AttributeImpl):
     """represents a scalar value-holding InstrumentedAttribute."""
 
     accepts_scalar_loader = True
+    uses_objects = False
 
     def delete(self, state):
-        if self.key not in state.committed_state:
-            state.committed_state[self.key] = state.dict.get(self.key, NO_VALUE)
+        state.modified_event(self, False, state.dict.get(self.key, NO_VALUE))
 
         # TODO: catch key errors, convert to attributeerror?
-        del state.dict[self.key]
-        state.modified=True
+        if self.extensions:
+            old = self.get(state)
+            del state.dict[self.key]
+            self.fire_remove_event(state, old, None)
+        else:
+            del state.dict[self.key]
 
     def get_history(self, state, passive=False):
-        return _create_history(self, state, state.dict.get(self.key, NO_VALUE))
+        return History.from_attribute(
+            self, state, state.dict.get(self.key, NO_VALUE))
 
     def set(self, state, value, initiator):
         if initiator is self:
             return
 
-        if self.key not in state.committed_state:
-            state.committed_state[self.key] = state.dict.get(self.key, NO_VALUE)
+        state.modified_event(self, False, state.dict.get(self.key, NO_VALUE))
 
-        state.dict[self.key] = value
-        state.modified=True
+        if self.extensions:
+            old = self.get(state)
+            state.dict[self.key] = value
+            self.fire_replace_event(state, value, old, initiator)
+        else:
+            state.dict[self.key] = value
+
+    def fire_replace_event(self, state, value, previous, initiator):
+        for ext in self.extensions:
+            ext.set(state, value, previous, initiator or self)
+
+    def fire_remove_event(self, state, value, initiator):
+        for ext in self.extensions:
+            ext.remove(state, value, initiator or self)
 
     def type(self):
         self.property.columns[0].type
@@ -348,39 +395,38 @@ class MutableScalarAttributeImpl(ScalarAttributeImpl):
     changes within the value itself.
     """
 
-    def __init__(self, class_, key, callable_, copy_function=None, compare_function=None, **kwargs):
-        super(ScalarAttributeImpl, self).__init__(class_, key, callable_, compare_function=compare_function, **kwargs)
-        class_._class_state.has_mutable_scalars = True
+    uses_objects = False
+
+    def __init__(self, class_, key, callable_, class_manager, copy_function=None, compare_function=None, **kwargs):
+        super(ScalarAttributeImpl, self).__init__(class_, key, callable_, class_manager, compare_function=compare_function, **kwargs)
+        class_manager.mutable_attributes.add(key)
         if copy_function is None:
-            raise exceptions.ArgumentError("MutableScalarAttributeImpl requires a copy function")
+            raise sa_exc.ArgumentError("MutableScalarAttributeImpl requires a copy function")
         self.copy = copy_function
 
     def get_history(self, state, passive=False):
-        return _create_history(self, state, state.dict.get(self.key, NO_VALUE))
+        return History.from_attribute(
+            self, state, state.dict.get(self.key, NO_VALUE))
 
-    def commit_to_state(self, state, value):
-        state.committed_state[self.key] = self.copy(value)
+    def commit_to_state(self, state, dest):
+        dest[self.key] = self.copy(state.dict[self.key])
 
     def check_mutable_modified(self, state):
         (added, unchanged, deleted) = self.get_history(state, passive=True)
-        if added or deleted:
-            state.modified = True
-            return True
-        else:
-            return False
+        return bool(added or deleted)
 
     def set(self, state, value, initiator):
         if initiator is self:
             return
 
-        if self.key not in state.committed_state:
-            if self.key in state.dict:
-                state.committed_state[self.key] = self.copy(state.dict[self.key])
-            else:
-                state.committed_state[self.key] = NO_VALUE
+        state.modified_event(self, True, NEVER_SET)
 
-        state.dict[self.key] = value
-        state.modified=True
+        if self.extensions:
+            old = self.get(state)
+            state.dict[self.key] = value
+            self.fire_replace_event(state, value, old, initiator)
+        else:
+            state.dict[self.key] = value
 
 
 class ScalarObjectAttributeImpl(ScalarAttributeImpl):
@@ -390,10 +436,11 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
     """
 
     accepts_scalar_loader = False
+    uses_objects = True
 
-    def __init__(self, class_, key, callable_, trackparent=False, extension=None, copy_function=None, compare_function=None, **kwargs):
+    def __init__(self, class_, key, callable_, class_manager, trackparent=False, extension=None, copy_function=None, compare_function=None, **kwargs):
         super(ScalarObjectAttributeImpl, self).__init__(class_, key,
-          callable_, trackparent=trackparent, extension=extension,
+          callable_, class_manager, trackparent=trackparent, extension=extension,
           compare_function=compare_function, **kwargs)
         if compare_function is None:
             self.is_equal = identity_equal
@@ -406,13 +453,13 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
 
     def get_history(self, state, passive=False):
         if self.key in state.dict:
-            return _create_history(self, state, state.dict[self.key])
+            return History.from_attribute(self, state, state.dict[self.key])
         else:
             current = self.get(state, passive=passive)
             if current is PASSIVE_NORESULT:
                 return (None, None, None)
             else:
-                return _create_history(self, state, current)
+                return History.from_attribute(self, state, current)
 
     def set(self, state, value, initiator):
         """Set a value on the given InstanceState.
@@ -424,43 +471,33 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl):
 
         if initiator is self:
             return
-
-        if value is not None and not hasattr(value, '_state'):
-            raise TypeError("Can not assign %s instance to %s's %r attribute, "
-                            "a mapped instance was expected." % (
-                type(value).__name__, type(state.obj()).__name__, self.key))
-
-        # TODO: add options to allow the get() to be passive
+        
+        # may want to add options to allow the get() here to be passive
         old = self.get(state)
         state.dict[self.key] = value
         self.fire_replace_event(state, value, old, initiator)
 
     def fire_remove_event(self, state, value, initiator):
-        if self.key not in state.committed_state:
-            state.committed_state[self.key] = value
-        state.modified = True
+        state.modified_event(self, False, value)
 
         if self.trackparent and value is not None:
-            self.sethasparent(value._state, False)
+            self.sethasparent(instance_state(value), False)
 
-        instance = state.obj()
         for ext in self.extensions:
-            ext.remove(instance, value, initiator or self)
+            ext.remove(state, value, initiator or self)
 
     def fire_replace_event(self, state, value, previous, initiator):
-        if self.key not in state.committed_state:
-            state.committed_state[self.key] = previous
-        state.modified = True
+        state.modified_event(self, False, previous)
 
         if self.trackparent:
             if value is not None:
-                self.sethasparent(value._state, True)
+                self.sethasparent(instance_state(value), True)
             if previous is not value and previous is not None:
-                self.sethasparent(previous._state, False)
+                self.sethasparent(instance_state(previous), False)
 
-        instance = state.obj()
         for ext in self.extensions:
-            ext.set(instance, value, previous, initiator or self)
+            ext.set(state, value, previous, initiator or self)
+
 
 class CollectionAttributeImpl(AttributeImpl):
     """A collection-holding attribute that instruments changes in membership.
@@ -471,22 +508,21 @@ class CollectionAttributeImpl(AttributeImpl):
     container object (defaulting to a list) and brokers access to the
     CollectionAdapter, a "view" onto that object that presents consistent
     bag semantics to the orm layer independent of the user data implementation.
+    
     """
     accepts_scalar_loader = False
+    uses_objects = True
 
-    def __init__(self, class_, key, callable_, typecallable=None, trackparent=False, extension=None, copy_function=None, compare_function=None, **kwargs):
+    def __init__(self, class_, key, callable_, class_manager, typecallable=None, trackparent=False, extension=None, copy_function=None, compare_function=None, **kwargs):
         super(CollectionAttributeImpl, self).__init__(class_,
-          key, callable_, trackparent=trackparent, extension=extension,
-          compare_function=compare_function, **kwargs)
+          key, callable_, class_manager, trackparent=trackparent,
+          extension=extension, compare_function=compare_function, **kwargs)
 
         if copy_function is None:
             copy_function = self.__copy
         self.copy = copy_function
 
-        if typecallable is None:
-            typecallable = list
-        self.collection_factory = \
-          collections._prepare_instrumentation(typecallable)
+        self.collection_factory = typecallable
         # may be removed in 0.5:
         self.collection_interface = \
           util.duck_type_collection(self.collection_factory())
@@ -499,42 +535,34 @@ class CollectionAttributeImpl(AttributeImpl):
         if current is PASSIVE_NORESULT:
             return (None, None, None)
         else:
-            return _create_history(self, state, current)
+            return History.from_attribute(self, state, current)
 
     def fire_append_event(self, state, value, initiator):
-        if self.key not in state.committed_state and self.key in state.dict:
-            state.committed_state[self.key] = self.copy(state.dict[self.key])
-
-        state.modified = True
+        state.modified_event(self, True, NEVER_SET, passive=True)
 
         if self.trackparent and value is not None:
-            self.sethasparent(value._state, True)
-        instance = state.obj()
+            self.sethasparent(instance_state(value), True)
+
         for ext in self.extensions:
-            ext.append(instance, value, initiator or self)
+            ext.append(state, value, initiator or self)
 
     def fire_pre_remove_event(self, state, initiator):
-        if self.key not in state.committed_state and self.key in state.dict:
-            state.committed_state[self.key] = self.copy(state.dict[self.key])
+        state.modified_event(self, True, NEVER_SET, passive=True)
 
     def fire_remove_event(self, state, value, initiator):
-        if self.key not in state.committed_state and self.key in state.dict:
-            state.committed_state[self.key] = self.copy(state.dict[self.key])
-
-        state.modified = True
+        state.modified_event(self, True, NEVER_SET, passive=True)
 
         if self.trackparent and value is not None:
-            self.sethasparent(value._state, False)
+            self.sethasparent(instance_state(value), False)
 
-        instance = state.obj()
         for ext in self.extensions:
-            ext.remove(instance, value, initiator or self)
+            ext.remove(state, value, initiator or self)
 
     def delete(self, state):
         if self.key not in state.dict:
             return
 
-        state.modified = True
+        state.modified_event(self, True, NEVER_SET)
 
         collection = self.get_collection(state)
         collection.clear_with_event()
@@ -544,10 +572,14 @@ class CollectionAttributeImpl(AttributeImpl):
     def initialize(self, state):
         """Initialize this attribute on the given object instance with an empty collection."""
 
-        _, user_data = self._build_collection(state)
+        _, user_data = self._initialize_collection(state)
         state.dict[self.key] = user_data
         return user_data
 
+    def _initialize_collection(self, state):
+        return state.manager.initialize_collection(
+            self.key, state, self.collection_factory)
+
     def append(self, state, value, initiator, passive=False):
         if initiator is self:
             return
@@ -597,7 +629,7 @@ class CollectionAttributeImpl(AttributeImpl):
         """
         # pulling a new collection first so that an adaptation exception does
         # not trigger a lazy load of the old collection.
-        new_collection, user_data = self._build_collection(state)
+        new_collection, user_data = self._initialize_collection(state)
         if adapter:
             new_values = list(adapter(new_collection, iterable))
         else:
@@ -610,25 +642,20 @@ class CollectionAttributeImpl(AttributeImpl):
         if old is iterable:
             return
 
-        if self.key not in state.committed_state:
-            state.committed_state[self.key] = self.copy(old)
+        state.modified_event(self, True, old)
 
         old_collection = self.get_collection(state, old)
 
         state.dict[self.key] = user_data
-        state.modified = True
 
         collections.bulk_replace(new_values, old_collection, new_collection)
         old_collection.unlink(old)
 
 
     def set_committed_value(self, state, value):
-        """Set an attribute value on the given instance and 'commit' it.
-
-        Loads the existing collection from lazy callables in all cases.
-        """
+        """Set an attribute value on the given instance and 'commit' it."""
 
-        collection, user_data = self._build_collection(state)
+        collection, user_data = self._initialize_collection(state)
 
         if value:
             for item in value:
@@ -637,30 +664,23 @@ class CollectionAttributeImpl(AttributeImpl):
         state.callables.pop(self.key, None)
         state.dict[self.key] = user_data
 
+        state.commit([self.key])
+
         if self.key in state.pending:
-            # pending items.  commit loaded data, add/remove new data
-            state.committed_state[self.key] = list(value or [])
-            added = state.pending[self.key].added_items
-            removed = state.pending[self.key].deleted_items
+            # pending items exist.  issue a modified event,
+            # add/remove new items.
+            state.modified_event(self, True, user_data)
+
+            pending = state.pending.pop(self.key)
+            added = pending.added_items
+            removed = pending.deleted_items
             for item in added:
                 collection.append_without_event(item)
             for item in removed:
                 collection.remove_without_event(item)
-            del state.pending[self.key]
-        elif self.key in state.committed_state:
-            # no pending items.  remove committed state if any.
-            # (this can occur with an expired attribute)
-            del state.committed_state[self.key]
 
         return user_data
 
-    def _build_collection(self, state):
-        """build a new, blank collection and return it wrapped in a CollectionAdapter."""
-
-        user_data = self.collection_factory()
-        collection = collections.CollectionAdapter(self, state, user_data)
-        return collection, user_data
-
     def get_collection(self, state, user_data=None, passive=False):
         """retrieve the CollectionAdapter associated with the given state.
 
@@ -672,13 +692,8 @@ class CollectionAttributeImpl(AttributeImpl):
             user_data = self.get(state, passive=passive)
             if user_data is PASSIVE_NORESULT:
                 return user_data
-        try:
-            return getattr(user_data, '_sa_adapter')
-        except AttributeError:
-            # TODO: this codepath never occurs, and this
-            # except/initialize should be removed
-            collections.CollectionAdapter(self, state, user_data)
-            return getattr(user_data, '_sa_adapter')
+
+        return getattr(user_data, '_sa_adapter')
 
 class GenericBackrefExtension(interfaces.AttributeExtension):
     """An extension which synchronizes a two-way relationship.
@@ -692,134 +707,150 @@ class GenericBackrefExtension(interfaces.AttributeExtension):
     def __init__(self, key):
         self.key = key
 
-    def set(self, instance, child, oldchild, initiator):
+    def set(self, state, child, oldchild, initiator):
         if oldchild is child:
             return
         if oldchild is not None:
             # With lazy=None, there's no guarantee that the full collection is
             # present when updating via a backref.
-            impl = getattr(oldchild.__class__, self.key).impl
+            old_state = instance_state(oldchild)
+            impl = old_state.get_impl(self.key)
             try:
-                impl.remove(oldchild._state, instance, initiator, passive=True)
+                impl.remove(old_state, state.obj(), initiator, passive=True)
             except (ValueError, KeyError, IndexError):
                 pass
         if child is not None:
-            getattr(child.__class__, self.key).impl.append(child._state, instance, initiator, passive=True)
+            new_state = instance_state(child)
+            new_state.get_impl(self.key).append(new_state, state.obj(), initiator, passive=True)
 
-    def append(self, instance, child, initiator):
-        getattr(child.__class__, self.key).impl.append(child._state, instance, initiator, passive=True)
+    def append(self, state, child, initiator):
+        child_state = instance_state(child)
+        child_state.get_impl(self.key).append(child_state, state.obj(), initiator, passive=True)
 
-    def remove(self, instance, child, initiator):
+    def remove(self, state, child, initiator):
         if child is not None:
-            getattr(child.__class__, self.key).impl.remove(child._state, instance, initiator, passive=True)
-
-class ClassState(object):
-    """tracks state information at the class level."""
-    def __init__(self):
-        self.mappers = {}
-        self.attrs = {}
-        self.has_mutable_scalars = False
+            child_state = instance_state(child)
+            child_state.get_impl(self.key).remove(child_state, state.obj(), initiator, passive=True)
 
-import sets
-_empty_set = sets.ImmutableSet()
 
 class InstanceState(object):
     """tracks state information at the instance level."""
 
-    def __init__(self, obj):
+    _cleanup = None
+    session_id = None
+    key = None
+    runid = None
+    entity_name = NO_ENTITY_NAME
+    expired_attributes = EMPTY_SET
+    
+    def __init__(self, obj, manager):
         self.class_ = obj.__class__
-        self.obj = weakref.ref(obj, self.__cleanup)
+        self.manager = manager
+        self.obj = weakref.ref(obj, self._cleanup)
         self.dict = obj.__dict__
         self.committed_state = {}
         self.modified = False
         self.callables = {}
         self.parents = {}
         self.pending = {}
-        self.appenders = {}
-        self.instance_dict = None
-        self.runid = None
-        self.expired_attributes = _empty_set
-
-    def __cleanup(self, ref):
-        # tiptoe around Python GC unpredictableness
-        instance_dict = self.instance_dict
-        if instance_dict is None:
-            return
+        self.expired = False
+    
+    def dispose(self):
+        del self.session_id
+        
+    def check_modified(self):
+        if self.modified:
+            return True
+        else:
+            for key in self.manager.mutable_attributes:
+                if self.manager[key].impl.check_mutable_modified(self):
+                    return True
+            else:
+                return False
 
-        instance_dict = instance_dict()
-        if instance_dict is None or instance_dict._mutex is None:
-            return
+    def initialize_instance(*mixed, **kwargs):
+        self, instance, args = mixed[0], mixed[1], mixed[2:]
+        manager = self.manager
 
-        # the mutexing here is based on the assumption that gc.collect()
-        # may be firing off cleanup handlers in a different thread than that
-        # which is normally operating upon the instance dict.
-        instance_dict._mutex.acquire()
+        for fn in manager.events.on_init:
+            fn(self, instance, args, kwargs)
         try:
-            try:
-                self.__resurrect(instance_dict)
-            except:
-                # catch app cleanup exceptions.  no other way around this
-                # without warnings being produced
-                pass
-        finally:
-            instance_dict._mutex.release()
+            return manager.events.original_init(*mixed[1:], **kwargs)
+        except:
+            for fn in manager.events.on_init_failure:
+                fn(self, instance, args, kwargs)
+            raise
 
-    def _check_resurrect(self, instance_dict):
-        instance_dict._mutex.acquire()
-        try:
-            return self.obj() or self.__resurrect(instance_dict)
-        finally:
-            instance_dict._mutex.release()
+    def get_history(self, key, **kwargs):
+        return self.manager.get_impl(key).get_history(self, **kwargs)
+
+    def get_impl(self, key):
+        return self.manager.get_impl(key)
+
+    def get_inst(self, key):
+        return self.manager.get_inst(key)
 
     def get_pending(self, key):
         if key not in self.pending:
             self.pending[key] = PendingCollection()
         return self.pending[key]
 
-    def is_modified(self):
-        if self.modified:
-            return True
-        elif self.class_._class_state.has_mutable_scalars:
-            for attr in _managed_attributes(self.class_):
-                if hasattr(attr.impl, 'check_mutable_modified') and attr.impl.check_mutable_modified(self):
-                    return True
-            else:
-                return False
-        else:
-            return False
+    def value_as_iterable(self, key, passive=False):
+        """return an InstanceState attribute as a list,
+        regardless of it being a scalar or collection-based
+        attribute.
 
-    def __resurrect(self, instance_dict):
-        if self.is_modified():
-            # store strong ref'ed version of the object; will revert
-            # to weakref when changes are persisted
-            obj = new_instance(self.class_, state=self)
-            self.obj = weakref.ref(obj, self.__cleanup)
-            self._strong_obj = obj
-            obj.__dict__.update(self.dict)
-            self.dict = obj.__dict__
-            return obj
-        else:
-            del instance_dict[self.dict['_instance_key']]
+        returns None if passive=True and the getter returns
+        PASSIVE_NORESULT.
+        """
+
+        impl = self.get_impl(key)
+        x = impl.get(self, passive=passive)
+        if x is PASSIVE_NORESULT:
             return None
+        elif hasattr(impl, 'get_collection'):
+            return impl.get_collection(self, x, passive=passive)
+        elif isinstance(x, list):
+            return x
+        else:
+            return [x]
+
+    def _run_on_load(self, instance=None):
+        if instance is None:
+            instance = self.obj()
+        self.manager.events.run('on_load', instance)
 
     def __getstate__(self):
-        return {'committed_state':self.committed_state, 'pending':self.pending, 'parents':self.parents, 'modified':self.modified, 'instance':self.obj(), 'expired_attributes':self.expired_attributes, 'callables':self.callables}
+        return {'key': self.key,
+                'entity_name': self.entity_name,
+                'committed_state': self.committed_state,
+                'pending': self.pending,
+                'parents': self.parents,
+                'modified': self.modified,
+                'expired':self.expired,
+                'instance': self.obj(),
+                'expired_attributes':self.expired_attributes,
+                'callables': self.callables}
 
     def __setstate__(self, state):
         self.committed_state = state['committed_state']
         self.parents = state['parents']
+        self.key = state['key']
+        self.session_id = None
+        self.entity_name = state['entity_name']
         self.pending = state['pending']
         self.modified = state['modified']
         self.obj = weakref.ref(state['instance'])
         self.class_ = self.obj().__class__
+        self.manager = manager_of_class(self.class_)
         self.dict = self.obj().__dict__
         self.callables = state['callables']
         self.runid = None
-        self.appenders = {}
+        self.expired = state['expired']
         self.expired_attributes = state['expired_attributes']
 
     def initialize(self, key):
-        getattr(self.class_, key).impl.initialize(self)
+        self.manager.get_impl(key).initialize(self)
 
     def set_callable(self, key, callable_):
         self.dict.pop(key, None)
@@ -829,70 +860,70 @@ class InstanceState(object):
         """__call__ allows the InstanceState to act as a deferred
         callable for loading expired attributes, which is also
         serializable.
+        
         """
-        instance = self.obj()
         unmodified = self.unmodified
-        self.class_._class_state.deferred_scalar_loader(instance, [
-            attr.impl.key for attr in _managed_attributes(self.class_) if
+        class_manager = self.manager
+        class_manager.deferred_scalar_loader(self, [
+            attr.impl.key for attr in class_manager.attributes if
                 attr.impl.accepts_scalar_loader and
                 attr.impl.key in self.expired_attributes and
                 attr.impl.key in unmodified
             ])
         for k in self.expired_attributes:
             self.callables.pop(k, None)
-        self.expired_attributes.clear()
+        del self.expired_attributes
         return ATTR_WAS_SET
 
     def unmodified(self):
         """a set of keys which have no uncommitted changes"""
 
         return util.Set([
-            attr.impl.key for attr in _managed_attributes(self.class_) if
-            attr.impl.key not in self.committed_state
-            and (not hasattr(attr.impl, 'commit_to_state') or not attr.impl.check_mutable_modified(self))
-        ])
+            key for key in self.manager.keys() if 
+            key not in self.committed_state
+            or (key in self.manager.mutable_attributes and not self.manager[key].impl.check_mutable_modified(self))
+        ])  
     unmodified = property(unmodified)
 
     def expire_attributes(self, attribute_names):
         self.expired_attributes = util.Set(self.expired_attributes)
 
         if attribute_names is None:
-            for attr in _managed_attributes(self.class_):
-                self.dict.pop(attr.impl.key, None)
-                self.expired_attributes.add(attr.impl.key)
-                if attr.impl.accepts_scalar_loader:
-                    self.callables[attr.impl.key] = self
-
-            self.committed_state = {}
-        else:
-            for key in attribute_names:
-                self.dict.pop(key, None)
-                self.committed_state.pop(key, None)
-                self.expired_attributes.add(key)
-                if getattr(self.class_, key).impl.accepts_scalar_loader:
-                    self.callables[key] = self
+            attribute_names = self.manager.keys()
+            self.expired = True
+            self.modified = False
+        for key in attribute_names:
+            self.dict.pop(key, None)
+            self.committed_state.pop(key, None)
+            self.expired_attributes.add(key)
+            if self.manager.get_impl(key).accepts_scalar_loader:
+                self.callables[key] = self
 
     def reset(self, key):
         """remove the given attribute and any callables associated with it."""
+
         self.dict.pop(key, None)
         self.callables.pop(key, None)
-
-    def commit_attr(self, attr, value):
-        """set the value of an attribute and mark it 'committed'."""
-
-        if hasattr(attr, 'commit_to_state'):
-            attr.commit_to_state(self, value)
-        else:
-            self.committed_state.pop(attr.key, None)
-        self.dict[attr.key] = value
-        self.pending.pop(attr.key, None)
-        self.appenders.pop(attr.key, None)
-
-        # we have a value so we can also unexpire it
-        self.callables.pop(attr.key, None)
-        if attr.key in self.expired_attributes:
-            self.expired_attributes.remove(attr.key)
-
+    
+    def modified_event(self, attr, should_copy, previous, passive=False):
+        needs_committed = attr.key not in self.committed_state
+    
+        if needs_committed:
+            if previous is NEVER_SET:
+                if passive:
+                    if attr.key in self.dict:
+                        previous = self.dict[attr.key]
+                else:
+                    previous = attr.get(self)
+                
+            if should_copy and previous not in (None, NO_VALUE, NEVER_SET):
+                previous = attr.copy(previous)
+            
+            if needs_committed:
+                self.committed_state[attr.key] = previous
+                
+        self.modified = True
+    
     def commit(self, keys):
         """commit all attributes named in the given list of key names.
 
@@ -903,219 +934,405 @@ class InstanceState(object):
         if a value was not populated in state.dict.
         """
 
-        if self.class_._class_state.has_mutable_scalars:
-            for key in keys:
-                attr = getattr(self.class_, key).impl
-                if hasattr(attr, 'commit_to_state') and attr.key in self.dict:
-                    attr.commit_to_state(self, self.dict[attr.key])
-                else:
-                    self.committed_state.pop(attr.key, None)
-                self.pending.pop(key, None)
-                self.appenders.pop(key, None)
-        else:
-            for key in keys:
+        class_manager = self.manager
+        for key in keys:
+            if key in self.dict and key in class_manager.mutable_attributes:
+                class_manager[key].impl.commit_to_state(self, self.committed_state)
+            else:
                 self.committed_state.pop(key, None)
-                self.pending.pop(key, None)
-                self.appenders.pop(key, None)
 
+        self.expired = False
         # unexpire attributes which have loaded
         for key in self.expired_attributes.intersection(keys):
             if key in self.dict:
                 self.expired_attributes.remove(key)
                 self.callables.pop(key, None)
 
-
     def commit_all(self):
         """commit all attributes unconditionally.
 
-        This is used after a flush() or a regular instance load or refresh operation
-        to mark committed all populated attributes.
+        This is used after a flush() or a full load/refresh
+        to remove all pending state from the instance.
+        
+         - all attributes are marked as "committed"
+         - the "strong dirty reference" is removed
+         - the "modified" flag is set to False
+         - any "expired" markers/callables are removed.
 
         Attributes marked as "expired" can potentially remain "expired" after this step
         if a value was not populated in state.dict.
+        
         """
-
         self.committed_state = {}
-        self.modified = False
-        self.pending = {}
-        self.appenders = {}
-
+        
         # unexpire attributes which have loaded
-        for key in list(self.expired_attributes):
-            if key in self.dict:
-                self.expired_attributes.remove(key)
+        if self.expired_attributes:
+            for key in self.expired_attributes.intersection(self.dict):
                 self.callables.pop(key, None)
+            self.expired_attributes.difference_update(self.dict)
+        
+        for key in self.manager.mutable_attributes:
+            if key in self.dict:
+                self.manager[key].impl.commit_to_state(self, self.committed_state)
 
-        if self.class_._class_state.has_mutable_scalars:
-            for attr in _managed_attributes(self.class_):
-                if hasattr(attr.impl, 'commit_to_state') and attr.impl.key in self.dict:
-                    attr.impl.commit_to_state(self, self.dict[attr.impl.key])
-
-        # remove strong ref
+        self.modified = self.expired = False
         self._strong_obj = None
 
 
-class WeakInstanceDict(UserDict.UserDict):
-    """similar to WeakValueDictionary, but wired towards 'state' objects."""
+class Events(object):
+    def __init__(self):
+        self.original_init = object.__init__
+        self.on_init = ()
+        self.on_init_failure = ()
+        self.on_load = ()
+
+    def run(self, event, *args, **kwargs):
+        for fn in getattr(self, event):
+            fn(*args, **kwargs)
+
+    def add_listener(self, event, listener):
+        # not thread safe... problem?
+        bucket = getattr(self, event)
+        if bucket == ():
+            setattr(self, event, [listener])
+        else:
+            bucket.append(listener)
 
-    def __init__(self, *args, **kw):
-        self._wr = weakref.ref(self)
-        # RLock because the mutex is used by a cleanup handler, which can be
-        # called at any time (including within an already mutexed block)
-        self._mutex = util.threading.RLock()
-        UserDict.UserDict.__init__(self, *args, **kw)
+    def remove_listener(self, event, listener):
+        bucket = getattr(self, event)
+        bucket.remove(listener)
 
-    def __getitem__(self, key):
-        state = self.data[key]
-        o = state.obj()
-        if o is None:
-            o = state._check_resurrect(self)
-        if o is None:
-            raise KeyError, key
-        return o
 
-    def __contains__(self, key):
-        try:
-            state = self.data[key]
-            o = state.obj()
-            if o is None:
-                o = state._check_resurrect(self)
-        except KeyError:
-            return False
-        return o is not None
+class ClassManager(dict):
+    """tracks state information at the class level."""
 
-    def has_key(self, key):
-        return key in self
+    MANAGER_ATTR = '_fooclass_manager'
+    STATE_ATTR = '_foostate'
 
-    def __repr__(self):
-        return "<InstanceDict at %s>" % id(self)
+    event_registry_factory = Events
+    instance_state_factory = InstanceState
 
-    def __setitem__(self, key, value):
-        if key in self.data:
-            self._mutex.acquire()
-            try:
-                if key in self.data:
-                    self.data[key].instance_dict = None
-            finally:
-                self._mutex.release()
-        self.data[key] = value._state
-        value._state.instance_dict = self._wr
-
-    def __delitem__(self, key):
-        state = self.data[key]
-        state.instance_dict = None
-        del self.data[key]
-
-    def get(self, key, default=None):
-        try:
-            state = self.data[key]
-        except KeyError:
-            return default
+    def __init__(self, class_):
+        self.class_ = class_
+        self.factory = None  # where we came from, for inheritance bookkeeping
+        self.info = {}
+        self.mappers = {}
+        self.mutable_attributes = util.Set()
+        self.local_attrs = {}
+        self.originals = {}
+        for base in class_.__mro__[-2:0:-1]:   # reverse, skipping 1st and last
+            cls_state = manager_of_class(base)
+            if cls_state:
+                self.update(cls_state)
+        self.registered = False
+        self._instantiable = False
+        self.events = self.event_registry_factory()
+
+    def instantiable(self, boolean):
+        # experiment, probably won't stay in this form
+        assert boolean ^ self._instantiable, (boolean, self._instantiable)
+        if boolean:
+            self.events.original_init = self.class_.__init__
+            new_init = _generate_init(self.class_, self)
+            self.install_member('__init__', new_init)
         else:
-            o = state.obj()
-            if o is None:
-                # This should only happen
-                return default
-            else:
-                return o
-
-    def items(self):
-        L = []
-        for key, state in self.data.items():
-            o = state.obj()
-            if o is not None:
-                L.append((key, o))
-        return L
-
-    def iteritems(self):
-        for state in self.data.itervalues():
-            value = state.obj()
-            if value is not None:
-                yield value._instance_key, value
+            self.uninstall_member('__init__')
+        self._instantiable = bool(boolean)
+    instantiable = property(lambda s: s._instantiable, instantiable)
 
-    def iterkeys(self):
-        return self.data.iterkeys()
+    def manage(self):
+        """Mark this instance as the manager for its class."""
+        setattr(self.class_, self.MANAGER_ATTR, self)
 
-    def __iter__(self):
-        return self.data.iterkeys()
+    def dispose(self):
+        """Dissasociate this instance from its class."""
+        delattr(self.class_, self.MANAGER_ATTR)
 
-    def __len__(self):
-        return len(self.values())
+    def manager_getter(self):
+        return attrgetter(self.MANAGER_ATTR)
 
-    def itervalues(self):
-        for state in self.data.itervalues():
-            instance = state.obj()
-            if instance is not None:
-                yield instance
+    def instrument_attribute(self, key, inst, propagated=False):
+        if propagated:
+            if key in self.local_attrs:
+                return  # don't override local attr with inherited attr
+        else:
+            self.local_attrs[key] = inst
+            self.install_descriptor(key, inst)
+        self[key] = inst
+        for cls in self.class_.__subclasses__():
+            manager = manager_of_class(cls)
+            if manager is None:
+                manager = create_manager_for_cls(cls)
+            manager.instrument_attribute(key, inst, True)
+
+    def uninstrument_attribute(self, key, propagated=False):
+        if key not in self:
+            return
+        if propagated:
+            if key in self.local_attrs:
+                return  # don't get rid of local attr
+        else:
+            del self.local_attrs[key]
+            self.uninstall_descriptor(key)
+        del self[key]
+        if key in self.mutable_attributes:
+            self.mutable_attributes.remove(key)
+        for cls in self.class_.__subclasses__():
+            manager = manager_of_class(cls)
+            if manager is None:
+                manager = create_manager_for_cls(cls)
+            manager.uninstrument_attribute(key, True)
+
+    def unregister(self):
+        for key in list(self):
+            if key in self.local_attrs:
+                self.uninstrument_attribute(key)
+        self.registered = False
+
+    def install_descriptor(self, key, inst):
+        if key in (self.STATE_ATTR, self.MANAGER_ATTR):
+            raise KeyError("%r: requested attribute name conflicts with "
+                           "instrumentation attribute of the same name." % key)
+        setattr(self.class_, key, inst)
+
+    def uninstall_descriptor(self, key):
+        delattr(self.class_, key)
+
+    def install_member(self, key, implementation):
+        if key in (self.STATE_ATTR, self.MANAGER_ATTR):
+            raise KeyError("%r: requested attribute name conflicts with "
+                           "instrumentation attribute of the same name." % key)
+        self.originals.setdefault(key, getattr(self.class_, key, None))
+        setattr(self.class_, key, implementation)
+
+    def uninstall_member(self, key):
+        original = self.originals.pop(key, None)
+        if original is not None:
+            setattr(self.class_, key, original)
+
+    def instrument_collection_class(self, key, collection_class):
+        return collections.prepare_instrumentation(collection_class)
+
+    def initialize_collection(self, key, state, factory):
+        user_data = factory()
+        adapter = collections.CollectionAdapter(
+            self.get_impl(key), state, user_data)
+        return adapter, user_data
+
+    def is_instrumented(self, key, search=False):
+        if search:
+            return key in self
+        else:
+            return key in self.local_attrs
 
-    def values(self):
-        L = []
-        for state in self.data.values():
-            o = state.obj()
-            if o is not None:
-                L.append(o)
-        return L
+    def get_impl(self, key):
+        return self[key].impl
 
-    def popitem(self):
-        raise NotImplementedError()
+    get_inst = dict.__getitem__
 
-    def pop(self, key, *args):
-        raise NotImplementedError()
+    def attributes(self):
+        return self.itervalues()
+    attributes = property(attributes)
 
-    def setdefault(self, key, default=None):
-        raise NotImplementedError()
+    def deferred_scalar_loader(cls, state, keys):
+        """TODO"""
+    deferred_scalar_loader = classmethod(deferred_scalar_loader)
 
-    def update(self, dict=None, **kwargs):
-        raise NotImplementedError()
+    ## InstanceState management
 
-    def copy(self):
-        raise NotImplementedError()
+    def new_instance(self, state=None):
+        instance = self.class_.__new__(self.class_)
+        self.setup_instance(instance, state)
+        return instance
+
+    def setup_instance(self, instance, with_state=None):
+        """Register an InstanceState with an instance."""
+        if self.has_state(instance):
+            state = self.state_of(instance)
+            if with_state:
+                assert state is with_state
+            return state
+        if with_state is None:
+            with_state = self.instance_state_factory(instance, self)
+        self.install_state(instance, with_state)
+        return with_state
+
+    def install_state(self, instance, state):
+        setattr(instance, self.STATE_ATTR, state)
+
+    def has_state(self, instance):
+        """True if an InstanceState is installed on the instance."""
+        return bool(getattr(instance, self.STATE_ATTR, False))
+
+    def state_of(self, instance):
+        """Retrieve the InstanceState of an instance.
+
+        May raise KeyError or AttributeError if no state is available.
+        """
+        return getattr(instance, self.STATE_ATTR)
 
-    def all_states(self):
-        return self.data.values()
+    def state_getter(self):
+        """Return a (instance) -> InstanceState callable.
 
-class StrongInstanceDict(dict):
-    def all_states(self):
-        return [o._state for o in self.values()]
+        "state getter" callables should raise either KeyError or
+        AttributeError if no InstanceState could be found for the
+        instance.
+        """
+        return attrgetter(self.STATE_ATTR)
 
-def _create_history(attr, state, current):
-    original = state.committed_state.get(attr.key, NEVER_SET)
+    def _new_state_if_none(self, instance):
+        """Install a default InstanceState if none is present.
 
-    if hasattr(attr, 'get_collection'):
-        current = attr.get_collection(state, current)
-        if original is NO_VALUE:
-            return (list(current), [], [])
-        elif original is NEVER_SET:
-            return ([], list(current), [])
+        A private convenience method used by the __init__ decorator.
+        """
+        if self.has_state(instance):
+            return False
         else:
-            collection = util.OrderedIdentitySet(current)
-            s = util.OrderedIdentitySet(original)
-            return (list(collection.difference(s)), list(collection.intersection(s)), list(s.difference(collection)))
-    else:
-        if current is NO_VALUE:
-            if original not in [None, NEVER_SET, NO_VALUE]:
-                deleted = [original]
+            new_state = self.instance_state_factory(instance, self)
+            self.install_state(instance, new_state)
+            return new_state
+
+    def has_parent(self, state, key, optimistic=False):
+        """TODO"""
+        return self.get_impl(key).hasparent(state, optimistic=optimistic)
+
+    def __nonzero__(self):
+        """All ClassManagers are non-zero regardless of attribute state."""
+        return True
+
+    def __repr__(self):
+        return '<%s of %r at %x>' % (
+            self.__class__.__name__, self.class_, id(self))
+
+class _ClassInstrumentationAdapter(ClassManager):
+    """Adapts a user-defined InstrumentationManager to a ClassManager."""
+
+    def __init__(self, class_, override):
+        ClassManager.__init__(self, class_)
+        self._adapted = override
+
+    def manage(self):
+        self._adapted.manage(self.class_, self)
+
+    def dispose(self):
+        self._adapted.dispose(self.class_)
+
+    def manager_getter(self):
+        return self._adapted.manager_getter(self.class_)
+
+    def instrument_attribute(self, key, inst, propagated=False):
+        ClassManager.instrument_attribute(self, key, inst, propagated)
+        if not propagated:
+            self._adapted.instrument_attribute(self.class_, key, inst)
+
+    def install_descriptor(self, key, inst):
+        self._adapted.install_descriptor(self.class_, key, inst)
+
+    def uninstall_descriptor(self, key):
+        self._adapted.uninstall_descriptor(self.class_, key)
+
+    def install_member(self, key, implementation):
+        self._adapted.install_member(self.class_, key, implementation)
+
+    def uninstall_member(self, key):
+        self._adapted.uninstall_member(self.class_, key)
+
+    def instrument_collection_class(self, key, collection_class):
+        return self._adapted.instrument_collection_class(
+            self.class_, key, collection_class)
+
+    def initialize_collection(self, key, state, factory):
+        delegate = getattr(self._adapted, 'initialize_collection', None)
+        if delegate:
+            return delegate(key, state, factory)
+        else:
+            return ClassManager.initialize_collection(self, key, state, factory)
+            
+    def setup_instance(self, instance, state=None):
+        self._adapted.initialize_instance_dict(self.class_, instance)
+        state = ClassManager.setup_instance(self, instance, with_state=state)
+        state.dict = self._adapted.get_instance_dict(self.class_, instance)
+        return state
+
+    def install_state(self, instance, state):
+        self._adapted.install_state(self.class_, instance, state)
+
+    def state_of(self, instance):
+        if hasattr(self._adapted, 'state_of'):
+            return self._adapted.state_of(self.class_, instance)
+        else:
+            getter = self._adapted.state_getter(self.class_)
+            return getter(instance)
+
+    def has_state(self, instance):
+        if hasattr(self._adapted, 'has_state'):
+            return self._adapted.has_state(self.class_, instance)
+        else:
+            try:
+                state = self.state_of(instance)
+                return True
+            except (KeyError, AttributeError):
+                return False
+
+    def state_getter(self):
+        return self._adapted.state_getter(self.class_)
+
+
+class History(tuple):
+    # TODO: migrate [] marker for empty slots to ()
+    __slots__ = ()
+
+    added = property(itemgetter(0))
+    unchanged = property(itemgetter(1))
+    deleted = property(itemgetter(2))
+
+    def __new__(cls, added, unchanged, deleted):
+        return tuple.__new__(cls, (added, unchanged, deleted))
+
+    def from_attribute(cls, attribute, state, current):
+        original = state.committed_state.get(attribute.key, NEVER_SET)
+
+        if hasattr(attribute, 'get_collection'):
+            current = attribute.get_collection(state, current)
+            if original is NO_VALUE:
+                return cls(list(current), [], [])
+            elif original is NEVER_SET:
+                return cls([], list(current), [])
             else:
-                deleted = []
-            return ([], [], deleted)
-        elif original is NO_VALUE:
-            return ([current], [], [])
-        elif original is NEVER_SET or attr.is_equal(current, original) is True:   # dont let ClauseElement expressions here trip things up
-            return ([], [current], [])
+                collection = util.OrderedIdentitySet(current)
+                s = util.OrderedIdentitySet(original)
+                return cls(list(collection.difference(s)),
+                           list(collection.intersection(s)),
+                           list(s.difference(collection)))
         else:
-            if original is not None:
-                deleted = [original]
+            if current is NO_VALUE:
+                if original not in [None, NEVER_SET, NO_VALUE]:
+                    deleted = [original]
+                else:
+                    deleted = []
+                return cls([], [], deleted)
+            elif original is NO_VALUE:
+                return cls([current], [], [])
+            elif (original is NEVER_SET or
+                  attribute.is_equal(current, original) is True):
+                # dont let ClauseElement expressions here trip things up
+                return cls([], [current], [])
             else:
-                deleted = []
-            return ([current], [], deleted)
+                if original is not None:
+                    deleted = [original]
+                else:
+                    deleted = []
+                return cls([current], [], deleted)
+    from_attribute = classmethod(from_attribute)
+
 
 class PendingCollection(object):
     """stores items appended and removed from a collection that has not been loaded yet.
 
     When the collection is loaded, the changes present in PendingCollection are applied
     to produce the final result.
+    
     """
-
     def __init__(self):
         self.deleted_items = util.IdentitySet()
         self.added_items = util.OrderedIdentitySet()
@@ -1130,166 +1347,280 @@ class PendingCollection(object):
             self.added_items.remove(value)
         self.deleted_items.add(value)
 
-def _managed_attributes(class_):
-    """return all InstrumentedAttributes associated with the given class_ and its superclasses."""
-
-    return chain(*[cl._class_state.attrs.values() for cl in class_.__mro__[:-1] if hasattr(cl, '_class_state')])
 
 def get_history(state, key, **kwargs):
-    return getattr(state.class_, key).impl.get_history(state, **kwargs)
+    return state.get_history(key, **kwargs)
 
-def get_as_list(state, key, passive=False):
-    """return an InstanceState attribute as a list,
-    regardless of it being a scalar or collection-based
-    attribute.
 
-    returns None if passive=True and the getter returns
-    PASSIVE_NORESULT.
-    """
+def has_parent(cls, obj, key, optimistic=False):
+    """TODO"""
+    manager = manager_of_class(cls)
+    state = instance_state(obj)
+    return manager.has_parent(state, key, optimistic)
 
-    attr = getattr(state.class_, key).impl
-    x = attr.get(state, passive=passive)
-    if x is PASSIVE_NORESULT:
-        return None
-    elif hasattr(attr, 'get_collection'):
-        return attr.get_collection(state, x, passive=passive)
-    elif isinstance(x, list):
-        return x
-    else:
-        return [x]
-
-def has_parent(class_, instance, key, optimistic=False):
-    return getattr(class_, key).impl.hasparent(instance._state, optimistic=optimistic)
+def register_class(class_):
+    """TODO"""
+    
+    # TODO: what's this function for ?  why would I call this and not create_manager_for_cls ?
+    
+    manager = manager_of_class(class_)
+    if manager is None:
+        manager = create_manager_for_cls(class_)
+    if not manager.instantiable:
+        manager.instantiable = True
 
-def _create_prop(class_, key, uselist, callable_, typecallable, useobject, mutable_scalars, impl_class, **kwargs):
-    if impl_class:
-        return impl_class(class_, key, typecallable, **kwargs)
-    elif uselist:
-        return CollectionAttributeImpl(class_, key, callable_, typecallable, **kwargs)
-    elif useobject:
-        return ScalarObjectAttributeImpl(class_, key, callable_,**kwargs)
-    elif mutable_scalars:
-        return MutableScalarAttributeImpl(class_, key, callable_, **kwargs)
-    else:
-        return ScalarAttributeImpl(class_, key, callable_, **kwargs)
+def unregister_class(class_):
+    """TODO"""
+    manager = manager_of_class(class_)
+    assert manager
+    assert manager.instantiable
+    manager.instantiable = False
+    manager.unregister()
 
-def manage(instance):
-    """initialize an InstanceState on the given instance."""
+def register_attribute(class_, key, uselist, useobject, callable_=None, proxy_property=None, mutable_scalars=False, impl_class=None, **kwargs):
 
-    if not hasattr(instance, '_state'):
-        instance._state = InstanceState(instance)
+    manager = manager_of_class(class_)
+    if manager.is_instrumented(key):
+        # this currently only occurs if two primary mappers are made for the
+        # same class.  TODO: possibly have InstrumentedAttribute check
+        # "entity_name" when searching for impl.  raise an error if two
+        # attrs attached simultaneously otherwise
+        return
 
-def new_instance(class_, state=None):
-    """create a new instance of class_ without its __init__() method being called.
+    if uselist:
+        factory = kwargs.pop('typecallable', None)
+        typecallable = manager.instrument_collection_class(
+            key, factory or list)
+    else:
+        typecallable = kwargs.pop('typecallable', None)
 
-    Also initializes an InstanceState on the new instance.
-    """
+    comparator = kwargs.pop('comparator', None)
+    parententity = kwargs.pop('parententity', None)
 
-    s = class_.__new__(class_)
-    if state:
-        s._state = state
+    if proxy_property:
+        proxy_type = proxied_attribute_factory(proxy_property)
+        descriptor = proxy_type(key, proxy_property, comparator, parententity)
     else:
-        s._state = InstanceState(s)
-    return s
+        descriptor = InstrumentedAttribute(
+            _create_prop(class_, key, uselist, callable_, 
+                    class_manager=manager,
+                    useobject=useobject,
+                    typecallable=typecallable, 
+                    mutable_scalars=mutable_scalars, 
+                    impl_class=impl_class, 
+                    **kwargs), 
+                comparator=comparator, parententity=parententity)
+
+    manager.instrument_attribute(key, descriptor)
 
-def _init_class_state(class_):
-    if not '_class_state' in class_.__dict__:
-        class_._class_state = ClassState()
+def unregister_attribute(class_, key):
+    manager_of_class(class_).uninstrument_attribute(key)
 
-def register_class(class_, extra_init=None, on_exception=None, deferred_scalar_loader=None):
-    _init_class_state(class_)
-    class_._class_state.deferred_scalar_loader=deferred_scalar_loader
+def init_collection(state, key):
+    """Initialize a collection attribute and return the collection adapter."""
+    attr = state.get_impl(key)
+    user_data = attr.initialize(state)
+    return attr.get_collection(state, user_data)
 
-    oldinit = None
-    doinit = False
+def set_attribute(instance, key, value):
+    state = instance_state(instance)
+    state.get_impl(key).set(state, value, None)
 
-    def init(instance, *args, **kwargs):
-        if not hasattr(instance, '_state'):
-            instance._state = InstanceState(instance)
+def get_attribute(instance, key):
+    state = instance_state(instance)
+    return state.get_impl(key).get(state)
 
-        if extra_init:
-            extra_init(class_, oldinit, instance, args, kwargs)
+def del_attribute(instance, key):
+    state = instance_state(instance)
+    state.get_impl(key).delete(state)
 
-        try:
-            if doinit:
-                oldinit(instance, *args, **kwargs)
-            elif args or kwargs:
-                # simulate error message raised by object(), but don't copy
-                # the text verbatim
-                raise TypeError("default constructor for object() takes no parameters")
-        except:
-            if on_exception:
-                on_exception(class_, oldinit, instance, args, kwargs)
-            raise
+def is_instrumented(instance, key):
+    return manager_of_class(instance.__class__).is_instrumented(key, search=True)
+
+class InstrumentationRegistry(object):
+    """Private instrumentation registration singleton."""
 
+    manager_finders = weakref.WeakKeyDictionary()
+    state_finders = util.WeakIdentityMapping()
+    extended = False
 
-    # override oldinit
-    oldinit = class_.__init__
-    if oldinit is None or not hasattr(oldinit, '_oldinit'):
-        init._oldinit = oldinit
-        class_.__init__ = init
-    # if oldinit is already one of our 'init' methods, replace it
-    elif hasattr(oldinit, '_oldinit'):
-        init._oldinit = oldinit._oldinit
-        class_.__init = init
-        oldinit = oldinit._oldinit
+    def create_manager_for_cls(self, class_):
+        assert class_ is not None
+        assert manager_of_class(class_) is None
 
-    if oldinit is not None:
-        doinit = oldinit is not object.__init__
+        for finder in instrumentation_finders:
+            factory = finder(class_)
+            if factory is not None:
+                break
+        else:
+            factory = ClassManager
+
+        existing_factories = collect_management_factories_for(class_)
+        existing_factories.add(factory)
+        if len(existing_factories) > 1:
+            raise TypeError(
+                "multiple instrumentation implementations specified "
+                "in %s inheritance hierarchy: %r" % (
+                    class_.__name__, list(existing_factories)))
+
+        manager = factory(class_)
+        if not isinstance(manager, ClassManager):
+            manager = _ClassInstrumentationAdapter(class_, manager)
+        if factory != ClassManager and not self.extended:
+            self.extended = True
+            _install_lookup_strategy(self)
+
+        manager.factory = factory
+        manager.manage()
+        self.manager_finders[class_] = manager.manager_getter()
+        self.state_finders[class_] = manager.state_getter()
+        return manager
+
+    def manager_of_class(self, cls):
+        if cls is None:
+            return None
         try:
-            init.__name__ = oldinit.__name__
-            init.__doc__ = oldinit.__doc__
-        except:
-            # cant set __name__ in py 2.3 !
-            pass
+            finder = self.manager_finders[cls]
+        except KeyError:
+            return None
+        else:
+            return finder(cls)
 
-def unregister_class(class_):
-    if hasattr(class_, '__init__') and hasattr(class_.__init__, '_oldinit'):
-        if class_.__init__._oldinit is not None:
-            class_.__init__ = class_.__init__._oldinit
+    def state_of(self, instance):
+        if instance is None:
+            raise AttributeError("None has no persistent state.")
+        return self.state_finders[instance.__class__](instance)
+
+    def state_or_default(self, instance, default=None):
+        if instance is None:
+            return default
+        try:
+            finder = self.state_finders[instance.__class__]
+        except KeyError:
+            return default
         else:
-            delattr(class_, '__init__')
+            try:
+                return finder(instance)
+            except (KeyError, AttributeError):
+                return default
+            except:
+                raise
+
+    def unregister(self, class_):
+        if class_ in self.manager_finders:
+            manager = self.manager_of_class(class_)
+            manager.dispose()
+            del self.manager_finders[class_]
+            del self.state_finders[class_]
+
+# Create a registry singleton and prepare placeholders for lookup functions.
+
+instrumentation_registry = InstrumentationRegistry()
+create_manager_for_cls = None
+manager_of_class = None
+instance_state = None
+_lookup_strategy = None
+    
+def _install_lookup_strategy(implementation):
+    """Switch between native and extended instrumentation modes.
 
-    if '_class_state' in class_.__dict__:
-        _class_state = class_.__dict__['_class_state']
-        for key, attr in _class_state.attrs.iteritems():
-            if key in class_.__dict__:
-                delattr(class_, attr.impl.key)
-        delattr(class_, '_class_state')
+    Completely private.  Use the instrumentation_finders interface to
+    inject global instrumentation behavior.
 
-def register_attribute(class_, key, uselist, useobject, callable_=None, proxy_property=None, mutable_scalars=False, impl_class=None, **kwargs):
-    _init_class_state(class_)
+    """
+    global manager_of_class, instance_state, create_manager_for_cls
+    global _lookup_strategy
+
+    # Using a symbol here to make debugging a little friendlier.
+    if implementation is not util.symbol('native'):
+        manager_of_class = implementation.manager_of_class
+        instance_state = implementation.state_of
+        create_manager_for_cls = implementation.create_manager_for_cls
+    else:
+        def manager_of_class(class_):
+            return getattr(class_, ClassManager.MANAGER_ATTR, None)
+        manager_of_class = instrumentation_registry.manager_of_class
+        instance_state = attrgetter(ClassManager.STATE_ATTR)
+        create_manager_for_cls = instrumentation_registry.create_manager_for_cls
+    # TODO: maybe log an event when setting a strategy.
+    _lookup_strategy = implementation
 
-    typecallable = kwargs.pop('typecallable', None)
-    if isinstance(typecallable, InstrumentedAttribute):
-        typecallable = None
-    comparator = kwargs.pop('comparator', None)
+_install_lookup_strategy(util.symbol('native'))
 
-    if key in class_.__dict__ and isinstance(class_.__dict__[key], InstrumentedAttribute):
-        # this currently only occurs if two primary mappers are made for the same class.
-        # TODO:  possibly have InstrumentedAttribute check "entity_name" when searching for impl.
-        # raise an error if two attrs attached simultaneously otherwise
-        return
+def find_native_user_instrumentation_hook(cls):
+    """Find user-specified instrumentation management for a class."""
+    return getattr(cls, INSTRUMENTATION_MANAGER, None)
+instrumentation_finders.append(find_native_user_instrumentation_hook)
 
-    if proxy_property:
-        proxy_type = proxied_attribute_factory(proxy_property)
-        inst = proxy_type(key, proxy_property, comparator)
-    else:
-        inst = InstrumentedAttribute(_create_prop(class_, key, uselist, callable_, useobject=useobject,
-                                       typecallable=typecallable, mutable_scalars=mutable_scalars, impl_class=impl_class, **kwargs), comparator=comparator)
+def collect_management_factories_for(cls):
+    """Return a collection of factories in play or specified for a hierarchy.
 
-    setattr(class_, key, inst)
-    class_._class_state.attrs[key] = inst
+    Traverses the entire inheritance graph of a cls and returns a collection
+    of instrumentation factories for those classes.  Factories are extracted
+    from active ClassManagers, if available, otherwise
+    instrumentation_finders is consulted.
 
-def unregister_attribute(class_, key):
-    class_state = class_._class_state
-    if key in class_state.attrs:
-        del class_._class_state.attrs[key]
-        delattr(class_, key)
+    """
+    hierarchy = util.class_hierarchy(cls)
+    factories = util.Set()
+    for member in hierarchy:
+        manager = manager_of_class(member)
+        if manager is not None:
+            factories.add(manager.factory)
+        else:
+            for finder in instrumentation_finders:
+                factory = finder(member)
+                if factory is not None:
+                    break
+            else:
+                factory = None
+            factories.add(factory)
+    factories.discard(None)
+    return factories
+
+
+def _create_prop(class_, key, uselist, callable_, class_manager, typecallable, useobject, mutable_scalars, impl_class, **kwargs):
+    if impl_class:
+        return impl_class(class_, key, typecallable, class_manager=class_manager, **kwargs)
+    elif uselist:
+        return CollectionAttributeImpl(class_, key, callable_,
+                                       typecallable=typecallable,
+                                       class_manager=class_manager, **kwargs)
+    elif useobject:
+        return ScalarObjectAttributeImpl(class_, key, callable_,
+                                         class_manager=class_manager, **kwargs)
+    elif mutable_scalars:
+        return MutableScalarAttributeImpl(class_, key, callable_,
+                                          class_manager=class_manager, **kwargs)
+    else:
+        return ScalarAttributeImpl(class_, key, callable_,
+                                   class_manager=class_manager, **kwargs)
+
+def _generate_init(class_, class_manager):
+    """Build an __init__ decorator that triggers ClassManager events."""
+
+    original__init__ = class_.__init__
+    assert original__init__
+
+    # Go through some effort here and don't change the user's __init__
+    # calling signature.
+    # FIXME: need to juggle local names to avoid constructor argument
+    # clashes.
+    func_body = """\
+def __init__(%(args)s):
+    new_state = class_manager._new_state_if_none(%(self_arg)s)
+    if new_state:
+        return new_state.initialize_instance(%(apply_kw)s)
+    else:
+        return original__init__(%(apply_kw)s)
+"""
+    func_vars = util.format_argspec_init(original__init__, grouped=False)
+    func_text = func_body % func_vars
+    #TODO: log debug #print func_text
+
+    env = locals().copy()
+    exec func_text in env
+    __init__ = env['__init__']
+    __init__.__doc__ = original__init__.__doc__
+    return __init__
 
-def init_collection(instance, key):
-    """Initialize a collection attribute and return the collection adapter."""
-    attr = getattr(instance.__class__, key).impl
-    state = instance._state
-    user_data = attr.initialize(state)
-    return attr.get_collection(state, user_data)
index c8fc2f189a3758bebbfb5c414bf8acad7a51b972..13204e8c12f44f528eef220b276c22565647c618 100644 (file)
@@ -93,6 +93,7 @@ explicit control over triggering append and remove events.
 
 The owning object and InstrumentedCollectionAttribute are also reachable
 through the adapter, allowing for some very sophisticated behavior.
+
 """
 
 import copy
@@ -101,7 +102,9 @@ import sets
 import sys
 import weakref
 
-from sqlalchemy import exceptions, schema, util as sautil
+import sqlalchemy.exceptions as sa_exc
+from sqlalchemy import schema
+import sqlalchemy.util as sautil
 from sqlalchemy.util import attrgetter, Set
 
 
@@ -109,6 +112,9 @@ __all__ = ['collection', 'collection_adapter',
            'mapped_collection', 'column_mapped_collection',
            'attribute_mapped_collection']
 
+__instrumentation_mutex = sautil.threading.Lock()
+
+
 def column_mapped_collection(mapping_spec):
     """A dictionary-based collection type with column-based keying.
 
@@ -119,25 +125,29 @@ def column_mapped_collection(mapping_spec):
     can not, for example, map on foreign key values if those key values will
     change during the session, i.e. from None to a database-assigned integer
     after a session flush.
-    """
 
-    from sqlalchemy.orm import object_mapper
+    """
+    from sqlalchemy.orm.util import _state_mapper
+    from sqlalchemy.orm.attributes import instance_state
 
     if isinstance(mapping_spec, schema.Column):
         def keyfunc(value):
-            m = object_mapper(value)
-            return m._get_attr_by_column(value, mapping_spec)
+            state = instance_state(value)
+            m = _state_mapper(state)
+            return m._get_state_attr_by_column(state, mapping_spec)
     else:
         cols = []
         for c in mapping_spec:
             if not isinstance(c, schema.Column):
-                raise exceptions.ArgumentError(
+                raise sa_exc.ArgumentError(
                     "mapping_spec tuple may only contain columns")
             cols.append(c)
         mapping_spec = tuple(cols)
         def keyfunc(value):
-            m = object_mapper(value)
-            return tuple([m._get_attr_by_column(value, c) for c in mapping_spec])
+            state = instance_state(value)
+            m = _state_mapper(state)
+            return tuple([m._get_state_attr_by_column(state, c)
+                          for c in mapping_spec])
     return lambda: MappedCollection(keyfunc)
 
 def attribute_mapped_collection(attr_name):
@@ -150,8 +160,8 @@ def attribute_mapped_collection(attr_name):
     can not, for example, map on foreign key values if those key values will
     change during the session, i.e. from None to a database-assigned integer
     after a session flush.
-    """
 
+    """
     return lambda: MappedCollection(attrgetter(attr_name))
 
 
@@ -165,8 +175,8 @@ def mapped_collection(keyfunc):
     can not, for example, map on foreign key values if those key values will
     change during the session, i.e. from None to a database-assigned integer
     after a session flush.
-    """
 
+    """
     return lambda: MappedCollection(keyfunc)
 
 class collection(object):
@@ -193,8 +203,8 @@ class collection(object):
     Decorators can be specified in long-hand for Python 2.3, or with
     the class-level dict attribute '__instrumentation__'- see the source
     for details.
-    """
 
+    """
     # Bundled as a class solely for ease of use: packaging, doc strings,
     # importability.
 
@@ -236,8 +246,8 @@ class collection(object):
         If the appender method is internally instrumented, you must also
         receive the keyword argument '_sa_initiator' and ensure its
         promulgation to collection events.
-        """
 
+        """
         setattr(fn, '_sa_instrument_role', 'appender')
         return fn
     appender = classmethod(appender)
@@ -263,8 +273,8 @@ class collection(object):
         If the remove method is internally instrumented, you must also
         receive the keyword argument '_sa_initiator' and ensure its
         promulgation to collection events.
-        """
 
+        """
         setattr(fn, '_sa_instrument_role', 'remover')
         return fn
     remover = classmethod(remover)
@@ -277,8 +287,8 @@ class collection(object):
 
             @collection.iterator
             def __iter__(self): ...
-        """
 
+        """
         setattr(fn, '_sa_instrument_role', 'iterator')
         return fn
     iterator = classmethod(iterator)
@@ -297,8 +307,8 @@ class collection(object):
             # never be called, unless:
             @collection.internally_instrumented
             def extend(self, items): ...
-        """
 
+        """
         setattr(fn, '_sa_instrumented', True)
         return fn
     internally_instrumented = classmethod(internally_instrumented)
@@ -311,8 +321,8 @@ class collection(object):
         invoked immediately after the '_sa_adapter' property is set on
         the instance.  A single argument is passed: the collection adapter
         that has been linked, or None if unlinking.
-        """
 
+        """
         setattr(fn, '_sa_instrument_role', 'on_link')
         return fn
     on_link = classmethod(on_link)
@@ -344,8 +354,8 @@ class collection(object):
         Supply an implementation of this method if you want to expand the
         range of possible types that can be assigned in bulk or perform
         validation on the values about to be assigned.
-        """
 
+        """
         setattr(fn, '_sa_instrument_role', 'converter')
         return fn
     converter = classmethod(converter)
@@ -362,8 +372,8 @@ class collection(object):
 
             @collection.adds('entity')
             def do_stuff(self, thing, entity=None): ...
-        """
 
+        """
         def decorator(fn):
             setattr(fn, '_sa_instrument_before', ('fire_append_event', arg))
             return fn
@@ -382,8 +392,8 @@ class collection(object):
 
             @collection.replaces(2)
             def __setitem__(self, index, item): ...
-        """
 
+        """
         def decorator(fn):
             setattr(fn, '_sa_instrument_before', ('fire_append_event', arg))
             setattr(fn, '_sa_instrument_after', 'fire_remove_event')
@@ -404,8 +414,8 @@ class collection(object):
 
         For methods where the value to remove is not known at call-time, use
         collection.removes_return.
-        """
 
+        """
         def decorator(fn):
             setattr(fn, '_sa_instrument_before', ('fire_remove_event', arg))
             return fn
@@ -424,8 +434,8 @@ class collection(object):
 
         For methods where the value to remove is known at call-time, use
         collection.remove.
-        """
 
+        """
         def decorator(fn):
             setattr(fn, '_sa_instrument_after', 'fire_remove_event')
             return fn
@@ -437,7 +447,6 @@ class collection(object):
 # implementations
 def collection_adapter(collection):
     """Fetch the CollectionAdapter for a collection."""
-
     return getattr(collection, '_sa_adapter', None)
 
 def collection_iter(collection):
@@ -445,8 +454,8 @@ def collection_iter(collection):
 
     If the collection is an ORM collection, it need not be attached to an
     object to be iterable.
-    """
 
+    """
     try:
         return getattr(collection, '_sa_iterator',
                        getattr(collection, '__iter__'))()
@@ -464,8 +473,8 @@ class CollectionAdapter(object):
 
     The ORM uses an CollectionAdapter exclusively for interaction with
     entity collections.
-    """
 
+    """
     def __init__(self, attr, owner_state, data):
         self.attr = attr
         self._data = weakref.ref(data)
@@ -477,14 +486,12 @@ class CollectionAdapter(object):
 
     def link_to_self(self, data):
         """Link a collection to this adapter, and fire a link event."""
-
         setattr(data, '_sa_adapter', self)
         if hasattr(data, '_sa_on_link'):
             getattr(data, '_sa_on_link')(self)
 
     def unlink(self, data):
         """Unlink a collection from any adapter, and fire a link event."""
-
         setattr(data, '_sa_adapter', None)
         if hasattr(data, '_sa_on_link'):
             getattr(data, '_sa_on_link')(None)
@@ -501,8 +508,8 @@ class CollectionAdapter(object):
 
         If a converter implementation is not supplied on the collection,
         a default duck-typing-based implementation is used.
-        """
 
+        """
         converter = getattr(self._data(), '_sa_converter', None)
         if converter is not None:
             return converter(obj)
@@ -531,44 +538,36 @@ class CollectionAdapter(object):
 
     def append_with_event(self, item, initiator=None):
         """Add an entity to the collection, firing mutation events."""
-
         getattr(self._data(), '_sa_appender')(item, _sa_initiator=initiator)
 
     def append_without_event(self, item):
         """Add or restore an entity to the collection, firing no events."""
-
         getattr(self._data(), '_sa_appender')(item, _sa_initiator=False)
 
     def remove_with_event(self, item, initiator=None):
         """Remove an entity from the collection, firing mutation events."""
-
         getattr(self._data(), '_sa_remover')(item, _sa_initiator=initiator)
 
     def remove_without_event(self, item):
         """Remove an entity from the collection, firing no events."""
-
         getattr(self._data(), '_sa_remover')(item, _sa_initiator=False)
 
     def clear_with_event(self, initiator=None):
         """Empty the collection, firing a mutation event for each entity."""
-
         for item in list(self):
             self.remove_with_event(item, initiator)
 
     def clear_without_event(self):
         """Empty the collection, firing no events."""
-
         for item in list(self):
             self.remove_without_event(item)
 
     def __iter__(self):
         """Iterate over entities in the collection."""
-
         return getattr(self._data(), '_sa_iterator')()
 
     def __len__(self):
         """Count entities in the collection."""
-
         return len(list(getattr(self._data(), '_sa_iterator')()))
 
     def __nonzero__(self):
@@ -580,8 +579,8 @@ class CollectionAdapter(object):
         Initiator is the InstrumentedAttribute that initiated the membership
         mutation, and should be left as None unless you are passing along
         an initiator value from a chained operation.
-        """
 
+        """
         if initiator is not False and item is not None:
             self.attr.fire_append_event(self.owner_state, item, initiator)
 
@@ -591,8 +590,8 @@ class CollectionAdapter(object):
         Initiator is the InstrumentedAttribute that initiated the membership
         mutation, and should be left as None unless you are passing along
         an initiator value from a chained operation.
-        """
 
+        """
         if initiator is not False and item is not None:
             self.attr.fire_remove_event(self.owner_state, item, initiator)
 
@@ -601,8 +600,8 @@ class CollectionAdapter(object):
 
         Only called if the entity cannot be removed after calling
         fire_remove_event().
-        """
 
+        """
         self.attr.fire_pre_remove_event(self.owner_state, initiator=initiator)
 
     def __getstate__(self):
@@ -653,8 +652,7 @@ def bulk_replace(values, existing_adapter, new_adapter):
         for member in removals:
             existing_adapter.remove_with_event(member)
 
-__instrumentation_mutex = sautil.threading.Lock()
-def _prepare_instrumentation(factory):
+def prepare_instrumentation(factory):
     """Prepare a callable for future use as a collection class factory.
 
     Given a collection class factory (either a type or no-arg callable),
@@ -663,8 +661,8 @@ def _prepare_instrumentation(factory):
 
     This function is responsible for converting collection_class=list
     into the run-time behavior of collection_class=InstrumentedList.
-    """
 
+    """
     # Convert a builtin to 'Instrumented*'
     if factory in __canned_instrumentation:
         factory = __canned_instrumentation[factory]
@@ -694,8 +692,8 @@ def __converting_factory(original_factory):
     Given a collection factory that returns a builtin type (e.g. a list),
     return a wrapped function that converts that type to one of our
     instrumented types.
-    """
 
+    """
     def wrapper():
         collection = original_factory()
         type_ = type(collection)
@@ -704,7 +702,7 @@ def __converting_factory(original_factory):
             # collection
             return __canned_instrumentation[type_](collection)
         else:
-            raise exceptions.InvalidRequestError(
+            raise sa_exc.InvalidRequestError(
                 "Collection class factories must produce instances of a "
                 "single class.")
     try:
@@ -717,7 +715,6 @@ def __converting_factory(original_factory):
 
 def _instrument_class(cls):
     """Modify methods in a class and install instrumentation."""
-
     # FIXME: more formally document this as a decoratorless/Python 2.3
     # option for specifying instrumentation.  (likely doc'd here in code only,
     # not in online docs.)
@@ -737,7 +734,7 @@ def _instrument_class(cls):
     # types is transformed into one of our trivial subclasses
     # (e.g. InstrumentedList).  Catch anything else that sneaks in here...
     if cls.__module__ == '__builtin__':
-        raise exceptions.ArgumentError(
+        raise sa_exc.ArgumentError(
             "Can not instrument a built-in type. Use a "
             "subclass, even a trivial one.")
 
@@ -790,7 +787,7 @@ def _instrument_class(cls):
     # ensure all roles are present, and apply implicit instrumentation if
     # needed
     if 'appender' not in roles or not hasattr(cls, roles['appender']):
-        raise exceptions.ArgumentError(
+        raise sa_exc.ArgumentError(
             "Type %s must elect an appender method to be "
             "a collection class" % cls.__name__)
     elif (roles['appender'] not in methods and
@@ -798,7 +795,7 @@ def _instrument_class(cls):
         methods[roles['appender']] = ('fire_append_event', 1, None)
 
     if 'remover' not in roles or not hasattr(cls, roles['remover']):
-        raise exceptions.ArgumentError(
+        raise sa_exc.ArgumentError(
             "Type %s must elect a remover method to be "
             "a collection class" % cls.__name__)
     elif (roles['remover'] not in methods and
@@ -806,7 +803,7 @@ def _instrument_class(cls):
         methods[roles['remover']] = ('fire_remove_event', 1, None)
 
     if 'iterator' not in roles or not hasattr(cls, roles['iterator']):
-        raise exceptions.ArgumentError(
+        raise sa_exc.ArgumentError(
             "Type %s must elect an iterator method to be "
             "a collection class" % cls.__name__)
 
@@ -824,7 +821,6 @@ def _instrument_class(cls):
 
 def _instrument_membership_mutator(method, before, argument, after):
     """Route method args and/or return value through the collection adapter."""
-
     # This isn't smart enough to handle @adds(1) for 'def fn(self, (a, b))'
     if before:
         fn_args = list(sautil.flatten_iterator(inspect.getargspec(method)[0]))
@@ -843,7 +839,7 @@ def _instrument_membership_mutator(method, before, argument, after):
         if before:
             if pos_arg is None:
                 if named_arg not in kw:
-                    raise exceptions.ArgumentError(
+                    raise sa_exc.ArgumentError(
                         "Missing argument %s" % argument)
                 value = kw[named_arg]
             else:
@@ -852,7 +848,7 @@ def _instrument_membership_mutator(method, before, argument, after):
                 elif named_arg in kw:
                     value = kw[named_arg]
                 else:
-                    raise exceptions.ArgumentError(
+                    raise sa_exc.ArgumentError(
                         "Missing argument %s" % argument)
 
         initiator = kw.pop('_sa_initiator', None)
@@ -881,7 +877,6 @@ def _instrument_membership_mutator(method, before, argument, after):
 
 def __set(collection, item, _sa_initiator=None):
     """Run set events, may eventually be inlined into decorators."""
-
     if _sa_initiator is not False and item is not None:
         executor = getattr(collection, '_sa_adapter', None)
         if executor:
@@ -889,7 +884,6 @@ def __set(collection, item, _sa_initiator=None):
 
 def __del(collection, item, _sa_initiator=None):
     """Run del events, may eventually be inlined into decorators."""
-
     if _sa_initiator is not False and item is not None:
         executor = getattr(collection, '_sa_adapter', None)
         if executor:
@@ -897,14 +891,12 @@ def __del(collection, item, _sa_initiator=None):
 
 def __before_delete(collection, _sa_initiator=None):
     """Special method to run 'commit existing value' methods"""
-
     executor = getattr(collection, '_sa_adapter', None)
     if executor:
         getattr(executor, 'fire_pre_remove_event')(_sa_initiator)
 
 def _list_decorators():
-    """Hand-turned instrumentation wrappers that can decorate any list-like
-    class."""
+    """Tailored instrumentation wrappers for any list-like class."""
 
     def _tidy(fn):
         setattr(fn, '_sa_instrumented', True)
@@ -1045,14 +1037,13 @@ def _list_decorators():
     return l
 
 def _dict_decorators():
-    """Hand-turned instrumentation wrappers that can decorate any dict-like
-    mapping class."""
+    """Tailored instrumentation wrappers for any dict-like mapping class."""
 
     def _tidy(fn):
         setattr(fn, '_sa_instrumented', True)
         fn.__doc__ = getattr(getattr(dict, fn.__name__), '__doc__')
 
-    Unspecified=sautil.symbol('Unspecified')
+    Unspecified = sautil.symbol('Unspecified')
 
     def __setitem__(fn):
         def __setitem__(self, key, value, _sa_initiator=None):
@@ -1157,14 +1148,13 @@ def _set_binops_check_loose(self, obj):
 
 
 def _set_decorators():
-    """Hand-turned instrumentation wrappers that can decorate any set-like
-    sequence class."""
+    """Tailored instrumentation wrappers for any set-like class."""
 
     def _tidy(fn):
         setattr(fn, '_sa_instrumented', True)
         fn.__doc__ = getattr(getattr(Set, fn.__name__), '__doc__')
 
-    Unspecified=sautil.symbol('Unspecified')
+    Unspecified = sautil.symbol('Unspecified')
 
     def add(fn):
         def add(self, value, _sa_initiator=None):
@@ -1365,6 +1355,7 @@ class MappedCollection(dict):
     ``set`` and ``remove`` are implemented in terms of a keying function: any
     callable that takes an object and returns an object for use as a dictionary
     key.
+
     """
 
     def __init__(self, keyfunc):
@@ -1374,16 +1365,17 @@ class MappedCollection(dict):
         returns an object for use as a dictionary key.
 
         The keyfunc will be called every time the ORM needs to add a member by
-        value-only (such as when loading instances from the database) or remove
-        a member.  The usual cautions about dictionary keying apply-
+        value-only (such as when loading instances from the database) or
+        remove a member.  The usual cautions about dictionary keying apply-
         ``keyfunc(object)`` should return the same output for the life of the
         collection.  Keying based on mutable properties can result in
         unreachable instances "lost" in the collection.
+
         """
         self.keyfunc = keyfunc
 
     def set(self, value, _sa_initiator=None):
-        """Add an item to the collection, with a key provided by this instance's keyfunc."""
+        """Add an item by value, consulting the keyfunc for the key."""
 
         key = self.keyfunc(value)
         self.__setitem__(key, value, _sa_initiator)
@@ -1391,13 +1383,13 @@ class MappedCollection(dict):
     set = collection.appender(set)
 
     def remove(self, value, _sa_initiator=None):
-        """Remove an item from the collection by value, consulting this instance's keyfunc for the key."""
+        """Remove an item by value, consulting the keyfunc for the key."""
 
         key = self.keyfunc(value)
         # Let self[key] raise if key is not in this collection
         # testlib.pragma exempt:__ne__
         if self[key] != value:
-            raise exceptions.InvalidRequestError(
+            raise sa_exc.InvalidRequestError(
                 "Can not remove '%s': collection holds '%s' for key '%s'. "
                 "Possible cause: is the MappedCollection key function "
                 "based on mutable properties or properties that only obtain "
@@ -1418,8 +1410,8 @@ class MappedCollection(dict):
         Raises a TypeError if the key in any (key, value) pair in the dictlike
         object does not match the key that this collection's keyfunc would
         have assigned for that value.
-        """
 
+        """
         for incoming_key, value in sautil.dictlike_iteritems(dictlike):
             new_key = self.keyfunc(value)
             if incoming_key != new_key:
index c667460a71796bece5c567ab23a36b95dc0e674c..24bbdadcee08ab65333cebcf4acd19b53d7adfa9 100644 (file)
@@ -4,14 +4,17 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
+"""Relationship dependencies.
 
-"""Bridge the ``PropertyLoader`` (i.e. a ``relation()``) and the
+Bridges the ``PropertyLoader`` (i.e. a ``relation()``) and the
 ``UOWTransaction`` together to allow processing of relation()-based
- dependencies at flush time.
+dependencies at flush time.
+
 """
 
-from sqlalchemy.orm import sync
-from sqlalchemy import sql, util, exceptions
+from sqlalchemy import sql, util
+import sqlalchemy.exceptions as sa_exc
+from sqlalchemy.orm import attributes, exc, sync
 from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE, MANYTOMANY
 
 
@@ -21,10 +24,7 @@ def create_dependency_processor(prop):
         MANYTOONE: ManyToOneDP,
         MANYTOMANY : ManyToManyDP,
     }
-    if prop.association is not None:
-        return AssociationDP(prop)
-    else:
-        return types[prop.direction](prop)
+    return types[prop.direction](prop)
 
 class DependencyProcessor(object):
     no_dependencies = False
@@ -36,7 +36,7 @@ class DependencyProcessor(object):
         self.parent = prop.parent
         self.secondary = prop.secondary
         self.direction = prop.direction
-        self.is_backref = prop.is_backref
+        self.is_backref = prop._is_backref
         self.post_update = prop.post_update
         self.foreign_keys = prop.foreign_keys
         self.passive_deletes = prop.passive_deletes
@@ -44,21 +44,21 @@ class DependencyProcessor(object):
         self.enable_typechecks = prop.enable_typechecks
         self.key = prop.key
         if not self.prop.synchronize_pairs:
-            raise exceptions.ArgumentError("Can't build a DependencyProcessor for relation %s.  No target attributes to populate between parent and child are present" % self.prop)
+            raise sa_exc.ArgumentError("Can't build a DependencyProcessor for relation %s.  No target attributes to populate between parent and child are present" % self.prop)
 
     def _get_instrumented_attribute(self):
         """Return the ``InstrumentedAttribute`` handled by this
         ``DependencyProecssor``.
         """
 
-        return getattr(self.parent.class_, self.key)
+        return self.parent.class_manager.get_impl(self.key)
 
     def hasparent(self, state):
         """return True if the given object instance has a parent,
         according to the ``InstrumentedAttribute`` handled by this ``DependencyProcessor``."""
 
         # TODO: use correct API for this
-        return self._get_instrumented_attribute().impl.hasparent(state)
+        return self._get_instrumented_attribute().hasparent(state)
 
     def register_dependencies(self, uowcommit):
         """Tell a ``UOWTransaction`` what mappers are dependent on
@@ -111,7 +111,7 @@ class DependencyProcessor(object):
         if not self.enable_typechecks:
             return
         if state is not None and not self.mapper._canload(state):
-            raise exceptions.FlushError("Attempting to flush an item of type %s on collection '%s', which is handled by mapper '%s' and does not load items of that type.  Did you mean to use a polymorphic mapper for this relationship ?  Set 'enable_typechecks=False' on the relation() to disable this exception.  Mismatched typeloading may cause bi-directional relationships (backrefs) to not function properly." % (state.class_, self.prop, self.mapper))
+            raise exc.FlushError("Attempting to flush an item of type %s on collection '%s', which is handled by mapper '%s' and does not load items of that type.  Did you mean to use a polymorphic mapper for this relationship ?  Set 'enable_typechecks=False' on the relation() to disable this exception.  Mismatched typeloading may cause bi-directional relationships (backrefs) to not function properly." % (state.class_, self.prop, self.mapper))
 
     def _synchronize(self, state, child, associationrow, clearkeys, uowcommit):
         """Called during a flush to synchronize primary key identifier
@@ -167,9 +167,9 @@ class OneToManyDP(DependencyProcessor):
             # the child objects have to have their foreign key to the parent set to NULL
             # this phase can be called safely for any cascade but is unnecessary if delete cascade
             # is on.
-            if self.post_update or not self.passive_deletes=='all':
+            if self.post_update or not self.passive_deletes == 'all':
                 for state in deplist:
-                    (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key,passive=self.passive_deletes)
+                    (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes)
                     if unchanged or deleted:
                         for child in deleted:
                             if child is not None and self.hasparent(child) is False:
@@ -204,9 +204,9 @@ class OneToManyDP(DependencyProcessor):
             # head object is being deleted, and we manage its list of child objects
             # the child objects have to have their foreign key to the parent set to NULL
             if not self.post_update:
-                should_null_fks = not self.cascade.delete and not self.passive_deletes=='all'
+                should_null_fks = not self.cascade.delete and not self.passive_deletes == 'all'
                 for state in deplist:
-                    (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key,passive=self.passive_deletes)
+                    (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes)
                     if unchanged or deleted:
                         for child in deleted:
                             if child is not None and self.hasparent(child) is False:
@@ -220,7 +220,7 @@ class OneToManyDP(DependencyProcessor):
                                     uowcommit.register_object(child)
         else:
             for state in deplist:
-                (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key,passive=True)
+                (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key, passive=True)
                 if added or deleted:
                     for child in added:
                         if child is not None:
@@ -231,7 +231,9 @@ class OneToManyDP(DependencyProcessor):
                         elif self.hasparent(child) is False:
                             uowcommit.register_object(child, isdelete=True)
                             for c, m in self.mapper.cascade_iterator('delete', child):
-                                uowcommit.register_object(c._state, isdelete=True)
+                                uowcommit.register_object(
+                                    attributes.instance_state(c),
+                                    isdelete=True)
                 if not self.passive_updates and self._pks_changed(uowcommit, state):
                     if not unchanged:
                         (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key, passive=False)
@@ -287,10 +289,10 @@ class DetectKeySwitch(DependencyProcessor):
             for s in [elem for elem in uowcommit.session.identity_map.all_states()
                 if issubclass(elem.class_, self.parent.class_) and
                     self.key in elem.dict and
-                    elem.dict[self.key]._state in switchers
+                    attributes.instance_state(elem.dict[self.key]) in switchers
                 ]:
                 uowcommit.register_object(s, listonly=self.passive_updates)
-                sync.populate(s.dict[self.key]._state, self.mapper, s, self.parent, self.prop.synchronize_pairs)
+                sync.populate(attributes.instance_state(s.dict[self.key]), self.mapper, s, self.parent, self.prop.synchronize_pairs)
                 #self.syncrules.execute(s.dict[self.key]._state, s, None, None, False)
 
     def _pks_changed(self, uowcommit, state):
@@ -316,17 +318,17 @@ class ManyToOneDP(DependencyProcessor):
     def process_dependencies(self, task, deplist, uowcommit, delete = False):
         #print self.mapper.mapped_table.name + " " + self.key + " " + repr(len(deplist)) + " process_dep isdelete " + repr(delete) + " direction " + repr(self.direction)
         if delete:
-            if self.post_update and not self.cascade.delete_orphan and not self.passive_deletes=='all':
+            if self.post_update and not self.cascade.delete_orphan and not self.passive_deletes == 'all':
                 # post_update means we have to update our row to not reference the child object
                 # before we can DELETE the row
                 for state in deplist:
                     self._synchronize(state, None, None, True, uowcommit)
-                    (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key,passive=self.passive_deletes)
+                    (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes)
                     if added or unchanged or deleted:
                         self._conditional_post_update(state, uowcommit, deleted + unchanged + added)
         else:
             for state in deplist:
-                (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key,passive=True)
+                (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key, passive=True)
                 if added or deleted or unchanged:
                     for child in added:
                         self._synchronize(state, child, None, False, uowcommit)
@@ -339,7 +341,7 @@ class ManyToOneDP(DependencyProcessor):
         if delete:
             if self.cascade.delete or self.cascade.delete_orphan:
                 for state in deplist:
-                    (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key,passive=self.passive_deletes)
+                    (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes)
                     if self.cascade.delete_orphan:
                         todelete = added + unchanged + deleted
                     else:
@@ -349,18 +351,21 @@ class ManyToOneDP(DependencyProcessor):
                             continue
                         uowcommit.register_object(child, isdelete=True)
                         for c, m in self.mapper.cascade_iterator('delete', child):
-                            uowcommit.register_object(c._state, isdelete=True)
+                            uowcommit.register_object(
+                                attributes.instance_state(c), isdelete=True)
         else:
             for state in deplist:
                 uowcommit.register_object(state)
                 if self.cascade.delete_orphan:
-                    (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key,passive=self.passive_deletes)
+                    (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes)
                     if deleted:
                         for child in deleted:
                             if self.hasparent(child) is False:
                                 uowcommit.register_object(child, isdelete=True)
                                 for c, m in self.mapper.cascade_iterator('delete', child):
-                                    uowcommit.register_object(c._state, isdelete=True)
+                                    uowcommit.register_object(
+                                        attributes.instance_state(c),
+                                        isdelete=True)
 
 
     def _synchronize(self, state, child, associationrow, clearkeys, uowcommit):
@@ -400,7 +405,7 @@ class ManyToManyDP(DependencyProcessor):
 
         if delete:
             for state in deplist:
-                (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key,passive=self.passive_deletes)
+                (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes)
                 if deleted or unchanged:
                     for child in deleted + unchanged:
                         if child is None or (reverse_dep and (reverse_dep, "manytomany", child, state) in uowcommit.attributes):
@@ -443,13 +448,13 @@ class ManyToManyDP(DependencyProcessor):
             statement = self.secondary.delete(sql.and_(*[c == sql.bindparam(c.key, type_=c.type) for c in self.secondary.c if c.key in associationrow]))
             result = connection.execute(statement, secondary_delete)
             if result.supports_sane_multi_rowcount() and result.rowcount != len(secondary_delete):
-                raise exceptions.ConcurrentModificationError("Deleted rowcount %d does not match number of secondary table rows deleted from table '%s': %d" % (result.rowcount, self.secondary.description, len(secondary_delete)))
+                raise exc.ConcurrentModificationError("Deleted rowcount %d does not match number of secondary table rows deleted from table '%s': %d" % (result.rowcount, self.secondary.description, len(secondary_delete)))
 
         if secondary_update:
             statement = self.secondary.update(sql.and_(*[c == sql.bindparam("old_" + c.key, type_=c.type) for c in self.secondary.c if c.key in associationrow]))
             result = connection.execute(statement, secondary_update)
             if result.supports_sane_multi_rowcount() and result.rowcount != len(secondary_update):
-                raise exceptions.ConcurrentModificationError("Updated rowcount %d does not match number of secondary table rows updated from table '%s': %d" % (result.rowcount, self.secondary.description, len(secondary_update)))
+                raise exc.ConcurrentModificationError("Updated rowcount %d does not match number of secondary table rows updated from table '%s': %d" % (result.rowcount, self.secondary.description, len(secondary_update)))
 
         if secondary_insert:
             statement = self.secondary.insert()
@@ -459,13 +464,14 @@ class ManyToManyDP(DependencyProcessor):
         #print self.mapper.mapped_table.name + " " + self.key + " " + repr(len(deplist)) + " preprocess_dep isdelete " + repr(delete) + " direction " + repr(self.direction)
         if not delete:
             for state in deplist:
-                (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key,passive=True)
+                (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key, passive=True)
                 if deleted:
                     for child in deleted:
                         if self.cascade.delete_orphan and self.hasparent(child) is False:
                             uowcommit.register_object(child, isdelete=True)
                             for c, m in self.mapper.cascade_iterator('delete', child):
-                                uowcommit.register_object(c._state, isdelete=True)
+                                uowcommit.register_object(
+                                    attributes.instance_state(c), isdelete=True)
 
     def _synchronize(self, state, child, associationrow, clearkeys, uowcommit):
         if associationrow is None:
@@ -478,12 +484,6 @@ class ManyToManyDP(DependencyProcessor):
     def _pks_changed(self, uowcommit, state):
         return sync.source_changes(uowcommit, state, self.parent, self.prop.synchronize_pairs)
 
-class AssociationDP(OneToManyDP):
-    def __init__(self, *args, **kwargs):
-        super(AssociationDP, self).__init__(*args, **kwargs)
-        self.cascade.delete = True
-        self.cascade.delete_orphan = True
-
 class MapperStub(object):
     """Pose as a Mapper representing the association table in a
     many-to-many join, when performing a ``flush()``.
index 133ad99c897912be7c37e5b654901a604f9c7a4e..08e6a57f401c3df62a1fca938535de4ddccf6b49 100644 (file)
@@ -1,8 +1,21 @@
-"""'dynamic' collection API.  returns Query() objects on the 'read' side, alters
-a special AttributeHistory on the 'write' side."""
+# dynamic.py
+# Copyright (C) the SQLAlchemy authors and contributors
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-from sqlalchemy import exceptions, util, logging
-from sqlalchemy.orm import attributes, object_session, util as mapperutil, strategies
+"""Dynamic collection API.
+
+Dynamic collections act like Query() objects for read operations and support
+basic add/delete mutation.
+
+"""
+
+from sqlalchemy import log, util
+import sqlalchemy.exceptions as sa_exc
+
+from sqlalchemy.orm import attributes, object_session, \
+     util as mapperutil, strategies
 from sqlalchemy.orm.query import Query
 from sqlalchemy.orm.mapper import has_identity, object_mapper
 
@@ -12,16 +25,19 @@ class DynaLoader(strategies.AbstractRelationLoader):
         self.is_class_level = True
         self._register_attribute(self.parent.class_, impl_class=DynamicAttributeImpl, target_mapper=self.parent_property.mapper, order_by=self.parent_property.order_by)
 
-    def create_row_processor(self, selectcontext, mapper, row):
-        return (None, None, None)
+    def create_row_processor(self, selectcontext, path, mapper, row, adapter):
+        return (None, None)
 
-DynaLoader.logger = logging.class_logger(DynaLoader)
+DynaLoader.logger = log.class_logger(DynaLoader)
 
 class DynamicAttributeImpl(attributes.AttributeImpl):
-    def __init__(self, class_, key, typecallable, target_mapper, order_by, **kwargs):
-        super(DynamicAttributeImpl, self).__init__(class_, key, typecallable, **kwargs)
+    uses_objects = True
+    accepts_scalar_loader = False
+    
+    def __init__(self, class_, key, typecallable, class_manager, target_mapper, order_by, **kwargs):
+        super(DynamicAttributeImpl, self).__init__(class_, key, typecallable, class_manager, **kwargs)
         self.target_mapper = target_mapper
-        self.order_by=order_by
+        self.order_by = order_by
         self.query_class = AppenderQuery
 
     def get(self, state, passive=False):
@@ -41,20 +57,18 @@ class DynamicAttributeImpl(attributes.AttributeImpl):
         state.modified = True
 
         if self.trackparent and value is not None:
-            self.sethasparent(value._state, True)
-        instance = state.obj()
+            self.sethasparent(attributes.instance_state(value), True)
         for ext in self.extensions:
-            ext.append(instance, value, initiator or self)
+            ext.append(state, value, initiator or self)
 
     def fire_remove_event(self, state, value, initiator):
         state.modified = True
 
         if self.trackparent and value is not None:
-            self.sethasparent(value._state, False)
+            self.sethasparent(attributes.instance_state(value), False)
 
-        instance = state.obj()
         for ext in self.extensions:
-            ext.remove(instance, value, initiator or self)
+            ext.remove(state, value, initiator or self)
         
     def set(self, state, value, initiator):
         if initiator is self:
@@ -111,26 +125,32 @@ class AppenderQuery(Query):
     
     def session(self):
         return self.__session()
-    session = property(session)
+    session = property(session, lambda s, x:None)
     
     def __iter__(self):
         sess = self.__session()
         if sess is None:
-            return iter(self.attr._get_collection_history(self.instance._state, passive=True).added_items)
+            return iter(self.attr._get_collection_history(
+                attributes.instance_state(self.instance),
+                passive=True).added_items)
         else:
             return iter(self._clone(sess))
 
     def __getitem__(self, index):
         sess = self.__session()
         if sess is None:
-            return self.attr._get_collection_history(self.instance._state, passive=True).added_items.__getitem__(index)
+            return self.attr._get_collection_history(
+                attributes.instance_state(self.instance),
+                passive=True).added_items.__getitem__(index)
         else:
             return self._clone(sess).__getitem__(index)
     
     def count(self):
         sess = self.__session()
         if sess is None:
-            return len(self.attr._get_collection_history(self.instance._state, passive=True).added_items)
+            return len(self.attr._get_collection_history(
+                attributes.instance_state(self.instance),
+                passive=True).added_items)
         else:
             return self._clone(sess).count()
     
@@ -142,10 +162,7 @@ class AppenderQuery(Query):
         if sess is None:
             sess = object_session(instance)
             if sess is None:
-                try:
-                    sess = object_mapper(instance).get_session()
-                except exceptions.InvalidRequestError:
-                    raise exceptions.UnboundExecutionError("Parent instance %s is not bound to a Session, and no contextual session is established; lazy load operation of attribute '%s' cannot proceed" % (mapperutil.instance_str(instance), self.attr.key))
+                raise sa_exc.UnboundExecutionError("Parent instance %s is not bound to a Session, and no contextual session is established; lazy load operation of attribute '%s' cannot proceed" % (mapperutil.instance_str(instance), self.attr.key))
 
         q = sess.query(self.attr.target_mapper).with_parent(instance, self.attr.key)
         if self.attr.order_by:
@@ -158,14 +175,14 @@ class AppenderQuery(Query):
             oldlist = list(self)
         else:
             oldlist = []
-        self.attr._get_collection_history(self.instance._state, passive=True).replace(oldlist, collection)
+        self.attr._get_collection_history(attributes.instance_state(self.instance), passive=True).replace(oldlist, collection)
         return oldlist
         
     def append(self, item):
-        self.attr.append(self.instance._state, item, None)
+        self.attr.append(attributes.instance_state(self.instance), item, None)
 
     def remove(self, item):
-        self.attr.remove(self.instance._state, item, None)
+        self.attr.remove(attributes.instance_state(self.instance), item, None)
 
             
 class CollectionHistory(object): 
diff --git a/lib/sqlalchemy/orm/exc.py b/lib/sqlalchemy/orm/exc.py
new file mode 100644 (file)
index 0000000..2d1d2b1
--- /dev/null
@@ -0,0 +1,31 @@
+# exc.py - ORM exceptions
+# Copyright (C) the SQLAlchemy authors and contributors
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+"""SQLAlchemy ORM exceptions."""
+
+import sqlalchemy.exceptions as sa_exc
+
+
+class ConcurrentModificationError(sa_exc.SQLAlchemyError):
+    """Rows have been modified outside of the unit of work."""
+
+
+class FlushError(sa_exc.SQLAlchemyError):
+    """A invalid condition was detected during flush()."""
+
+
+class ObjectDeletedError(sa_exc.InvalidRequestError):
+    """An refresh() operation failed to re-retrieve an object's row."""
+
+
+class UnmappedColumnError(sa_exc.InvalidRequestError):
+    """Mapping operation was requested on an unknown column."""
+
+
+# Legacy compat until 0.6.
+sa_exc.ConcurrentModificationError = ConcurrentModificationError
+sa_exc.FlushError = FlushError
+sa_exc.UnmappedColumnError
diff --git a/lib/sqlalchemy/orm/identity.py b/lib/sqlalchemy/orm/identity.py
new file mode 100644 (file)
index 0000000..4487e21
--- /dev/null
@@ -0,0 +1,250 @@
+# identity.py
+# Copyright (C) the SQLAlchemy authors and contributors
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+import weakref
+
+from sqlalchemy import util as base_util
+from sqlalchemy.orm import attributes
+
+
+class IdentityMap(dict):
+    def __init__(self):
+        self._mutable_attrs = weakref.WeakKeyDictionary()
+        self.modified = False
+        
+    def add(self, state):
+        raise NotImplementedError()
+    
+    def remove(self, state):
+        raise NotImplementedError()
+    
+    def update(self, dict):
+        raise NotImplementedError("IdentityMap uses add() to insert data")
+    
+    def clear(self):
+        raise NotImplementedError("IdentityMap uses remove() to remove data")
+        
+    def _manage_incoming_state(self, state):
+        if state.modified:  
+            self.modified = True
+        if state.manager.mutable_attributes:
+            self._mutable_attrs[state] = True
+    
+    def _manage_removed_state(self, state):
+        if state in self._mutable_attrs:
+            del self._mutable_attrs[state]
+            
+    def check_modified(self):
+        """return True if any InstanceStates present have been marked as 'modified'."""
+        
+        if not self.modified:
+            for state in self._mutable_attrs:
+                if state.check_modified():
+                    return True
+            else:
+                return False
+        else:
+            return True
+            
+    def has_key(self, key):
+        return key in self
+        
+    def popitem(self):
+        raise NotImplementedError("IdentityMap uses remove() to remove data")
+
+    def pop(self, key, *args):
+        raise NotImplementedError("IdentityMap uses remove() to remove data")
+
+    def setdefault(self, key, default=None):
+        raise NotImplementedError("IdentityMap uses add() to insert data")
+
+    def copy(self):
+        raise NotImplementedError()
+
+    def __setitem__(self, key, value):
+        raise NotImplementedError("IdentityMap uses add() to insert data")
+
+    def __delitem__(self, key):
+        raise NotImplementedError("IdentityMap uses remove() to remove data")
+        
+class WeakInstanceDict(IdentityMap):
+
+    def __init__(self):
+        IdentityMap.__init__(self)
+        self._wr = weakref.ref(self)
+        # RLock because the mutex is used by a cleanup
+        # handler, which can be called at any time (including within an already mutexed block)
+        self._mutex = base_util.threading.RLock()
+
+    def __getitem__(self, key):
+        state = dict.__getitem__(self, key)
+        o = state.obj()
+        if o is None:
+            o = state._check_resurrect(self)
+        if o is None:
+            raise KeyError, key
+        return o
+
+    def __contains__(self, key):
+        try:
+            state = dict.__getitem__(self, key)
+            o = state.obj()
+            if o is None:
+                o = state._check_resurrect(self)
+        except KeyError:
+            return False
+        return o is not None
+    
+    def contains_state(self, state):
+        return dict.get(self, state.key) is state
+        
+    def add(self, state):
+        if state.key in self:
+            if dict.__getitem__(self, state.key) is not state:
+                raise AssertionError("A conflicting state is already present in the identity map for key %r" % state.key)
+        else:
+            dict.__setitem__(self, state.key, state)
+            state._instance_dict = self._wr
+            self._manage_incoming_state(state)
+    
+    def remove_key(self, key):
+        state = dict.__getitem__(self, key)
+        self.remove(state)
+        
+    def remove(self, state):
+        if not self.contains_state(state):
+            raise AssertionError("State %s is not present in this identity map" % state)
+        dict.__delitem__(self, state.key)
+        del state._instance_dict
+        self._manage_removed_state(state)
+    
+    def discard(self, state):
+        if self.contains_state(state):
+            dict.__delitem__(self, state.key)
+            del state._instance_dict
+            self._manage_removed_state(state)
+        
+    def get(self, key, default=None):
+        try:
+            return self[key]
+        except KeyError:
+            return default
+            
+    def items(self):
+        return list(self.iteritems())
+
+    def iteritems(self):
+        for state in dict.itervalues(self):
+            value = state.obj()
+            if value is not None:
+                yield state.key, value
+
+    def itervalues(self):
+        for state in dict.itervalues(self):
+            instance = state.obj()
+            if instance is not None:
+                yield instance
+
+    def values(self):
+        return list(self.itervalues())
+
+    def all_states(self):
+        return dict.values(self)
+    
+    def prune(self):
+        return 0
+        
+class StrongInstanceDict(IdentityMap):
+    def all_states(self):
+        return [attributes.instance_state(o) for o in self.values()]
+    
+    def contains_state(self, state):
+        return state.key in self and attributes.instance_state(self[state.key]) is state
+    
+    def add(self, state):
+        dict.__setitem__(self, state.key, state.obj())
+        self._manage_incoming_state(state)
+    
+    def remove(self, state):
+        if not self.contains_state(state):
+            raise AssertionError("State %s is not present in this identity map" % state)
+        dict.__delitem__(self, state.key)
+        self._manage_removed_state(state)
+    
+    def discard(self, state):
+        if self.contains_state(state):
+            dict.__delitem__(self, state.key)
+            self._manage_removed_state(state)
+            
+    def remove_key(self, key):
+        state = dict.__getitem__(self, key)
+        self.remove(state)
+
+    def prune(self):
+        """prune unreferenced, non-dirty states."""
+        
+        ref_count = len(self)
+        dirty = [s.obj() for s in self.all_states() if s.check_modified()]
+        keepers = weakref.WeakValueDictionary(self)
+        dict.clear(self)
+        dict.update(self, keepers)
+        self.modified = bool(dirty)
+        return ref_count - len(self)
+        
+class IdentityManagedState(attributes.InstanceState):
+    def _instance_dict(self):
+        return None
+    
+    def _check_resurrect(self, instance_dict):
+        instance_dict._mutex.acquire()
+        try:
+            return self.obj() or self.__resurrect(instance_dict)
+        finally:
+            instance_dict._mutex.release()
+    
+    def modified_event(self, attr, should_copy, previous, passive=False):
+        attributes.InstanceState.modified_event(self, attr, should_copy, previous, passive)
+        
+        instance_dict = self._instance_dict()
+        if instance_dict:
+            instance_dict.modified = True
+        
+    def _cleanup(self, ref):
+        # tiptoe around Python GC unpredictableness
+        try:
+            instance_dict = self._instance_dict()
+            instance_dict._mutex.acquire()
+        except:
+            return
+        # the mutexing here is based on the assumption that gc.collect()
+        # may be firing off cleanup handlers in a different thread than that
+        # which is normally operating upon the instance dict.
+        try:
+            try:
+                self.__resurrect(instance_dict)
+            except:
+                # catch app cleanup exceptions.  no other way around this
+                # without warnings being produced
+                pass
+        finally:
+            instance_dict._mutex.release()
+
+    def __resurrect(self, instance_dict):
+        if self.check_modified():
+            # store strong ref'ed version of the object; will revert
+            # to weakref when changes are persisted
+            obj = self.manager.new_instance(state=self)
+            self.obj = weakref.ref(obj, self._cleanup)
+            self._strong_obj = obj
+            # todo: revisit this wrt user-defined-state
+            obj.__dict__.update(self.dict)
+            self.dict = obj.__dict__
+            self._run_on_load(obj)
+            return obj
+        else:
+            instance_dict.remove(self)
+            self.dispose()
+            return None
index d61ebe9603f7ab1962c0dc3355b76bb40a5eed32..6c9fe775331c8ba6da0efa571c0a4b87134e153b 100644 (file)
@@ -4,27 +4,45 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-"""Semi-private implementation objects which form the basis
-of ORM-mapped attributes, query options and mapper extension.
+"""
+
+Semi-private implementation objects which form the basis of ORM-mapped
+attributes, query options and mapper extension.
+
+Defines the [sqlalchemy.orm.interfaces#MapperExtension] class, which can be
+end-user subclassed to add event-based functionality to mappers.  The
+remainder of this module is generally private to the ORM.
 
-Defines the [sqlalchemy.orm.interfaces#MapperExtension] class,
-which can be end-user subclassed to add event-based functionality
-to mappers.  The remainder of this module is generally private to the
-ORM.
 """
 
 from itertools import chain
-from sqlalchemy import exceptions, logging, util
+
+import sqlalchemy.exceptions as sa_exc
+from sqlalchemy import log, util
 from sqlalchemy.sql import expression
-class_mapper = None
 
-__all__ = ['EXT_CONTINUE', 'EXT_STOP', 'EXT_PASS', 'MapperExtension',
-           'MapperProperty', 'PropComparator', 'StrategizedProperty',
-           'build_path', 'MapperOption',
-           'ExtensionOption', 'PropertyOption',
-           'AttributeExtension', 'StrategizedOption', 'LoaderStrategy' ]
+class_mapper = None
+collections = None
+
+__all__ = (
+    'AttributeExtension',
+    'EXT_CONTINUE',
+    'EXT_STOP',
+    'ExtensionOption',
+    'InstrumentationManager',
+    'LoaderStrategy',
+    'MapperExtension',
+    'MapperOption',
+    'MapperProperty',
+    'PropComparator',
+    'PropertyOption',
+    'SessionExtension',
+    'StrategizedOption',
+    'StrategizedProperty',
+    'build_path',
+    )
 
-EXT_CONTINUE = EXT_PASS = util.symbol('EXT_CONTINUE')
+EXT_CONTINUE = util.symbol('EXT_CONTINUE')
 EXT_STOP = util.symbol('EXT_STOP')
 
 ONETOMANY = util.symbol('ONETOMANY')
@@ -44,10 +62,7 @@ class MapperExtension(object):
     these exception cases, any return value other than EXT_CONTINUE or
     EXT_STOP will be interpreted as equivalent to EXT_STOP.
 
-    EXT_PASS is a synonym for EXT_CONTINUE and is provided for backward
-    compatibility.
     """
-
     def instrument_class(self, mapper, class_):
         return EXT_CONTINUE
 
@@ -57,16 +72,6 @@ class MapperExtension(object):
     def init_failed(self, mapper, class_, oldinit, instance, args, kwargs):
         return EXT_CONTINUE
 
-    def get_session(self):
-        """Retrieve a contextual Session instance with which to
-        register a new object.
-
-        Note: this is not called if a session is provided with the
-        `__init__` params (i.e. `_sa_session`).
-        """
-
-        return EXT_CONTINUE
-
     def load(self, query, *args, **kwargs):
         """Override the `load` method of the Query object.
 
@@ -85,43 +90,6 @@ class MapperExtension(object):
 
         return EXT_CONTINUE
 
-    def get_by(self, query, *args, **kwargs):
-        """Override the `get_by` method of the Query object.
-
-        The return value of this method is used as the result of
-        ``query.get_by()`` if the value is anything other than
-        EXT_CONTINUE.
-
-        DEPRECATED.
-        """
-
-        return EXT_CONTINUE
-
-    def select_by(self, query, *args, **kwargs):
-        """Override the `select_by` method of the Query object.
-
-        The return value of this method is used as the result of
-        ``query.select_by()`` if the value is anything other than
-        EXT_CONTINUE.
-
-        DEPRECATED.
-        """
-
-        return EXT_CONTINUE
-
-    def select(self, query, *args, **kwargs):
-        """Override the `select` method of the Query object.
-
-        The return value of this method is used as the result of
-        ``query.select()`` if the value is anything other than
-        EXT_CONTINUE.
-
-        DEPRECATED.
-        """
-
-        return EXT_CONTINUE
-
-
     def translate_row(self, mapper, context, row):
         """Perform pre-processing on the given result row and return a
         new row instance.
@@ -276,6 +244,56 @@ class MapperExtension(object):
 
         return EXT_CONTINUE
 
+class SessionExtension(object):
+    """An extension hook object for Sessions.  Subclasses may be installed into a Session
+    (or sessionmaker) using the ``extension`` keyword argument.
+    """
+
+    def before_commit(self, session):
+        """Execute right before commit is called.
+
+        Note that this may not be per-flush if a longer running transaction is ongoing."""
+
+    def after_commit(self, session):
+        """Execute after a commit has occured.
+
+        Note that this may not be per-flush if a longer running transaction is ongoing."""
+
+    def after_rollback(self, session):
+        """Execute after a rollback has occured.
+
+        Note that this may not be per-flush if a longer running transaction is ongoing."""
+
+    def before_flush(self, session, flush_context, instances):
+        """Execute before flush process has started.
+
+        `instances` is an optional list of objects which were passed to the ``flush()``
+        method.
+        """
+
+    def after_flush(self, session, flush_context):
+        """Execute after flush has completed, but before commit has been called.
+
+        Note that the session's state is still in pre-flush, i.e. 'new', 'dirty',
+        and 'deleted' lists still show pre-flush state as well as the history
+        settings on instance attributes."""
+
+    def after_flush_postexec(self, session, flush_context):
+        """Execute after flush has completed, and after the post-exec state occurs.
+
+        This will be when the 'new', 'dirty', and 'deleted' lists are in their final
+        state.  An actual commit() may or may not have occured, depending on whether or not
+        the flush started its own transaction or participated in a larger transaction.
+        """
+
+    def after_begin(self, session, transaction, connection):
+        """Execute after a transaction is begun on a connection
+
+        `transaction` is the SessionTransaction. This method is called after an
+        engine level transaction is begun on a connection.
+        """
+
+
 class MapperProperty(object):
     """Manage the relationship of a ``Mapper`` to a single class
     attribute, as well as that attribute as it appears on individual
@@ -283,7 +301,7 @@ class MapperProperty(object):
     attribute access, loading behavior, and dependency calculations.
     """
 
-    def setup(self, querycontext, **kwargs):
+    def setup(self, context, entity, path, adapter, **kwargs):
         """Called by Query for the purposes of constructing a SQL statement.
 
         Each MapperProperty associated with the target mapper processes the
@@ -293,8 +311,8 @@ class MapperProperty(object):
 
         pass
 
-    def create_row_processor(self, selectcontext, mapper, row):
-        """Return a 3-tuple consiting of two row processing functions and an instance post-processing function.
+    def create_row_processor(self, selectcontext, path, mapper, row, adapter):
+        """Return a 2-tuple consiting of two row processing functions and an instance post-processing function.
 
         Input arguments are the query.SelectionContext and the *first*
         applicable row of a result set obtained within
@@ -305,32 +323,24 @@ class MapperProperty(object):
         columns present in the row (which will be the same columns present in
         all rows) are used to determine the presence and behavior of the
         returned callables.  The callables will then be used to process all
-        rows and to post-process all instances, respectively.
+        rows and instances.
 
         Callables are of the following form::
 
-            def new_execute(instance, row, **flags):
-                # process incoming instance and given row.  the instance is
+            def new_execute(state, row, **flags):
+                # process incoming instance state and given row.  the instance is
                 # "new" and was just created upon receipt of this row.
                 # flags is a dictionary containing at least the following
                 # attributes:
                 #   isnew - indicates if the instance was newly created as a
                 #           result of reading this row
                 #   instancekey - identity key of the instance
-                # optional attribute:
-                #   ispostselect - indicates if this row resulted from a
-                #                  'post' select of additional tables/columns
 
-            def existing_execute(instance, row, **flags):
-                # process incoming instance and given row.  the instance is
+            def existing_execute(state, row, **flags):
+                # process incoming instance state and given row.  the instance is
                 # "existing" and was created based on a previous row.
 
-            def post_execute(instance, **flags):
-                # process instance after all result rows have been processed.
-                # this function should be used to issue additional selections
-                # in order to eagerly load additional properties.
-
-            return (new_execute, existing_execute, post_execute)
+            return (new_execute, existing_execute)
 
         Either of the three tuples can be ``None`` in which case no function
         is called.
@@ -347,20 +357,6 @@ class MapperProperty(object):
 
         return iter([])
 
-    def get_criterion(self, query, key, value):
-        """Return a ``WHERE`` clause suitable for this
-        ``MapperProperty`` corresponding to the given key/value pair,
-        where the key is a column or object property name, and value
-        is a value to be matched.  This is only picked up by
-        ``PropertyLoaders``.
-
-        This is called by a ``Query``'s ``join_by`` method to formulate a set
-        of key/value pairs into a ``WHERE`` criterion that spans multiple
-        tables if needed.
-        """
-
-        return None
-
     def set_parent(self, parent):
         self.parent = parent
 
@@ -427,10 +423,10 @@ class PropComparator(expression.ColumnOperators):
     which returns the MapperProperty associated with this
     PropComparator.
     """
-
-    def expression_element(self):
-        return self.clause_element()
-
+    
+    def __clause_element__(self):
+        raise NotImplementedError("%r" % self)
+        
     def contains_op(a, b):
         return a.contains(b)
     contains_op = staticmethod(contains_op)
@@ -511,37 +507,44 @@ class StrategizedProperty(MapperProperty):
     ``StrategizedOption`` objects via the Query.options() method.
     """
 
-    def _get_context_strategy(self, context):
-        path = context.path
-        return self._get_strategy(context.attributes.get(("loaderstrategy", path), self.strategy.__class__))
-
+    def __get_context_strategy(self, context, path):
+        cls = context.attributes.get(("loaderstrategy", path), None)
+        if cls:
+            try:
+                return self.__all_strategies[cls]
+            except KeyError:
+                return self.__init_strategy(cls)
+        else:
+            return self.strategy
+    
     def _get_strategy(self, cls):
         try:
-            return self._all_strategies[cls]
+            return self.__all_strategies[cls]
         except KeyError:
-            # cache the located strategy per class for faster re-lookup
-            strategy = cls(self)
-            strategy.init()
-            self._all_strategies[cls] = strategy
-            return strategy
-
-    def setup(self, querycontext, **kwargs):
-        self._get_context_strategy(querycontext).setup_query(querycontext, **kwargs)
+            return self.__init_strategy(cls)
+    
+    def __init_strategy(self, cls):
+        self.__all_strategies[cls] = strategy = cls(self)
+        strategy.init()
+        return strategy
+        
+    def setup(self, context, entity, path, adapter, **kwargs):
+        self.__get_context_strategy(context, path + (self.key,)).setup_query(context, entity, path, adapter, **kwargs)
 
-    def create_row_processor(self, selectcontext, mapper, row):
-        return self._get_context_strategy(selectcontext).create_row_processor(selectcontext, mapper, row)
+    def create_row_processor(self, context, path, mapper, row, adapter):
+        return self.__get_context_strategy(context, path + (self.key,)).create_row_processor(context, path, mapper, row, adapter)
 
     def do_init(self):
-        self._all_strategies = {}
-        self.strategy = self._get_strategy(self.strategy_class)
+        self.__all_strategies = {}
+        self.strategy = self.__init_strategy(self.strategy_class)
         if self.is_primary():
             self.strategy.init_class_attribute()
 
-def build_path(mapper, key, prev=None):
+def build_path(entity, key, prev=None):
     if prev:
-        return prev + (mapper.base_mapper, key)
+        return prev + (entity, key)
     else:
-        return (mapper.base_mapper, key)
+        return (entity, key)
 
 def serialize_path(path):
     if path is None:
@@ -585,9 +588,9 @@ class ExtensionOption(MapperOption):
         self.ext = ext
 
     def process_query(self, query):
-        query._extension = query._extension.copy()
-        query._extension.insert(self.ext)
-
+        entity = query._generate_mapper_zero()
+        entity.extension = entity.extension.copy()
+        entity.extension.push(self.ext)
 
 class PropertyOption(MapperOption):
     """A MapperOption that is applied to a property off the mapper or
@@ -607,60 +610,86 @@ class PropertyOption(MapperOption):
     def _process(self, query, raiseerr):
         if self._should_log_debug:
             self.logger.debug("applying option to Query, property key '%s'" % self.key)
-        paths = self._get_paths(query, raiseerr)
+        paths = self.__get_paths(query, raiseerr)
         if paths:
             self.process_query_property(query, paths)
 
     def process_query_property(self, query, paths):
         pass
+    
+    def __find_entity(self, query, mapper, raiseerr):
+        from sqlalchemy.orm.util import _class_to_mapper, _is_aliased_class
+        
+        if _is_aliased_class(mapper):
+            searchfor = mapper
+        else:
+            searchfor = _class_to_mapper(mapper).base_mapper
 
-    def _get_paths(self, query, raiseerr):
+        for ent in query._mapper_entities:
+            if ent.path_entity is searchfor:
+                return ent
+        else:
+            if raiseerr:
+                raise sa_exc.ArgumentError("Can't find entity %s in Query.  Current list: %r" % (searchfor, [str(m.path_entity) for m in query._entities]))
+            else:
+                return None
+            
+    def __get_paths(self, query, raiseerr):
         path = None
+        entity = None
         l = []
+        
         current_path = list(query._current_path)
-
+        
         if self.mapper:
-            global class_mapper
-            if class_mapper is None:
-                from sqlalchemy.orm import class_mapper
-            mapper = self.mapper
-            if isinstance(self.mapper, type):
-                mapper = class_mapper(mapper)
-            if mapper is not query.mapper and mapper not in [q.mapper for q in query._entities]:
-                raise exceptions.ArgumentError("Can't find entity %s in Query.  Current list: %r" % (str(mapper), [str(m) for m in query._entities]))
-        else:
-            mapper = query.mapper
-        if isinstance(self.key, basestring):
-            tokens = self.key.split('.')
-        else:
-            tokens = util.to_list(self.key)
+            entity = self.__find_entity(query, self.mapper, raiseerr)
+            mapper = entity.mapper
+            path_element = entity.path_entity
             
-        for token in tokens:
-            if isinstance(token, basestring):
-                prop = mapper.get_property(token, resolve_synonyms=True, raiseerr=raiseerr)
-            elif isinstance(token, PropComparator):
-                prop = token.property
-                token = prop.key
-                    
+        for key in util.to_list(self.key):
+            if isinstance(key, basestring):
+                tokens = key.split('.')
             else:
-                raise exceptions.ArgumentError("mapper option expects string key or list of attributes")
-                
-            if current_path and token == current_path[1]:
-                current_path = current_path[2:]
-                continue
+                tokens = [key]
+            for token in tokens:
+                if isinstance(token, basestring):
+                    if not entity:
+                        entity = query._entity_zero()
+                        path_element = entity.path_entity
+                        mapper = entity.mapper
+                    prop = mapper.get_property(token, resolve_synonyms=True, raiseerr=raiseerr)
+                    key = token
+                elif isinstance(token, PropComparator):
+                    prop = token.property
+                    if not entity:
+                        entity = self.__find_entity(query, token.parententity, raiseerr)
+                        if not entity:
+                            return []
+                        path_element = entity.path_entity
+                    key = prop.key
+                else:
+                    raise sa_exc.ArgumentError("mapper option expects string key or list of attributes")
+            
+                if current_path and key == current_path[1]:
+                    current_path = current_path[2:]
+                    continue
                 
-            if prop is None:
-                return []
-            path = build_path(mapper, prop.key, path)
-            l.append(path)
-            if getattr(token, '_of_type', None):
-                mapper = token._of_type
-            else:
-                mapper = getattr(prop, 'mapper', None)
+                if prop is None:
+                    return []
+
+                path = build_path(path_element, prop.key, path)
+                l.append(path)
+                if getattr(token, '_of_type', None):
+                    path_element = mapper = token._of_type
+                else:
+                    path_element = mapper = getattr(prop, 'mapper', None)
+                if path_element:
+                    path_element = path_element.base_mapper
+            
         return l
 
-PropertyOption.logger = logging.class_logger(PropertyOption)
-PropertyOption._should_log_debug = logging.is_debug_enabled(PropertyOption.logger)
+PropertyOption.logger = log.class_logger(PropertyOption)
+PropertyOption._should_log_debug = log.is_debug_enabled(PropertyOption.logger)
 
 class AttributeExtension(object):
     """An abstract class which specifies `append`, `delete`, and `set`
@@ -732,10 +761,10 @@ class LoaderStrategy(object):
     def init_class_attribute(self):
         pass
 
-    def setup_query(self, context, **kwargs):
+    def setup_query(self, context, entity, path, adapter, **kwargs):
         pass
 
-    def create_row_processor(self, selectcontext, mapper, row):
+    def create_row_processor(self, selectcontext, path, mapper, row, adapter):
         """Return row processing functions which fulfill the contract specified
         by MapperProperty.create_row_processor.
 
@@ -744,3 +773,71 @@ class LoaderStrategy(object):
         """
 
         raise NotImplementedError()
+
+    def __str__(self):
+        return str(self.parent_property)
+
+    def debug_callable(self, fn, logger, announcement, logfn):
+        if announcement:
+            logger.debug(announcement)
+        if logfn:
+            def call(*args, **kwargs):
+                logger.debug(logfn(*args, **kwargs))
+                return fn(*args, **kwargs)
+            return call
+        else:
+            return fn
+
+class InstrumentationManager(object):
+    """User-defined class instrumentation extension."""
+
+    # r4361 added a mandatory (cls) constructor to this interface.
+    # given that, perhaps class_ should be dropped from all of these
+    # signatures.
+
+    def __init__(self, class_):
+        pass
+
+    def manage(self, class_, manager):
+        setattr(class_, '_default_class_manager', manager)
+
+    def dispose(self, class_, manager):
+        delattr(class_, '_default_class_manager')
+
+    def manager_getter(self, class_):
+        def get(cls):
+            return cls._default_class_manager
+        return get
+
+    def instrument_attribute(self, class_, key, inst):
+        pass
+
+    def install_descriptor(self, class_, key, inst):
+        setattr(class_, key, inst)
+
+    def uninstall_descriptor(self, class_, key):
+        delattr(class_, key)
+
+    def install_member(self, class_, key, implementation):
+        setattr(class_, key, implementation)
+
+    def uninstall_member(self, class_, key):
+        delattr(class_, key)
+
+    def instrument_collection_class(self, class_, key, collection_class):
+        global collections
+        if collections is None:
+            from sqlalchemy.orm import collections
+        return collections.prepare_instrumentation(collection_class)
+
+    def get_instance_dict(self, class_, instance):
+        return instance.__dict__
+
+    def initialize_instance_dict(self, class_, instance):
+        pass
+
+    def install_state(self, class_, instance, state):
+        setattr(instance, '_default_state', state)
+
+    def state_getter(self, class_):
+        return lambda instance: getattr(instance, '_default_state')
index ba0644758f9aae6189137dba57599ae5df4b1a92..6d79f6cd502372bb85750056aacdb7e07f851275 100644 (file)
@@ -1,27 +1,44 @@
-# orm/mapper.py
+# mapper.py
 # Copyright (C) 2005, 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-"""Defines the [sqlalchemy.orm.mapper#Mapper] class, the central configurational
+"""Logic to map Python classes to and from selectables.
+
+Defines the [sqlalchemy.orm.mapper#Mapper] class, the central configurational
 unit which associates a class with a database table.
 
 This is a semi-private module; the main configurational API of the ORM is
 available in [sqlalchemy.orm#].
+
 """
 
 import weakref
 from itertools import chain
-from sqlalchemy import sql, util, exceptions, logging
-from sqlalchemy.sql import expression, visitors, operators, util as sqlutil
-from sqlalchemy.orm import sync, attributes
-from sqlalchemy.orm.interfaces import MapperProperty, EXT_CONTINUE, PropComparator
-from sqlalchemy.orm.util import has_identity, _state_has_identity, _is_mapped_class, has_mapper, \
-    _state_mapper, class_mapper, object_mapper, _class_to_mapper,\
-    ExtensionCarrier, state_str, instance_str
-    
-__all__ = ['Mapper', 'class_mapper', 'object_mapper', '_mapper_registry']
+
+from sqlalchemy import sql, util, log
+import sqlalchemy.exceptions as sa_exc
+from sqlalchemy.sql import expression, visitors, operators
+import sqlalchemy.sql.util as sqlutil
+from sqlalchemy.orm import attributes
+from sqlalchemy.orm import exc
+from sqlalchemy.orm import sync
+from sqlalchemy.orm.identity import IdentityManagedState
+from sqlalchemy.orm.interfaces import MapperProperty, EXT_CONTINUE, \
+     PropComparator
+from sqlalchemy.orm.util import \
+     ExtensionCarrier, _INSTRUMENTOR, _class_to_mapper, _is_mapped_class, \
+     _state_has_identity, _state_mapper, class_mapper, has_identity, \
+     has_mapper, instance_str, object_mapper, state_str
+
+
+__all__ = (
+    'Mapper',
+    '_mapper_registry',
+    'class_mapper',
+    'object_mapper',
+    )
 
 _mapper_registry = weakref.WeakKeyDictionary()
 _new_mappers = False
@@ -43,6 +60,7 @@ ColumnProperty = None
 SynonymProperty = None
 ComparableProperty = None
 _expire_state = None
+_state_session = None
 
 
 class Mapper(object):
@@ -85,10 +103,11 @@ class Mapper(object):
 
         Mappers are normally constructed via the [sqlalchemy.orm#mapper()]
         function.  See for details.
-        
+
         """
 
         self.class_ = class_
+        self.class_manager = None
         self.entity_name = entity_name
         self.primary_key_argument = primary_key
         self.non_primary = non_primary
@@ -110,19 +129,18 @@ class Mapper(object):
         self.eager_defaults = eager_defaults
         self.column_prefix = column_prefix
         self.polymorphic_on = polymorphic_on
-        self._eager_loaders = util.Set()
         self._dependency_processors = []
         self._clause_adapter = None
         self._requires_row_aliasing = False
         self.__inherits_equated_pairs = None
-        
+
         if not issubclass(class_, object):
-            raise exceptions.ArgumentError("Class '%s' is not a new-style class" % class_.__name__)
+            raise sa_exc.ArgumentError("Class '%s' is not a new-style class" % class_.__name__)
 
         self.select_table = select_table
         if select_table:
             if with_polymorphic:
-                raise exceptions.ArgumentError("select_table can't be used with with_polymorphic (they define conflicting settings)")
+                raise sa_exc.ArgumentError("select_table can't be used with with_polymorphic (they define conflicting settings)")
             self.with_polymorphic = ('*', select_table)
         else:
             if with_polymorphic == '*':
@@ -133,14 +151,14 @@ class Mapper(object):
                 else:
                     self.with_polymorphic = (with_polymorphic, None)
             elif with_polymorphic is not None:
-                raise exceptions.ArgumentError("Invalid setting for with_polymorphic")
+                raise sa_exc.ArgumentError("Invalid setting for with_polymorphic")
             else:
                 self.with_polymorphic = None
-        
+
         if isinstance(self.local_table, expression._SelectBaseMixin):
             util.warn("mapper %s creating an alias for the given selectable - use Class attributes for queries." % self)
             self.local_table = self.local_table.alias()
-        
+
         if self.with_polymorphic and isinstance(self.with_polymorphic[1], expression._SelectBaseMixin):
             self.with_polymorphic[1] = self.with_polymorphic[1].alias()
 
@@ -148,12 +166,8 @@ class Mapper(object):
         # indicates this Mapper should be used to construct the object instance for that row.
         self.polymorphic_identity = polymorphic_identity
 
-        if polymorphic_fetch not in (None, 'union', 'select', 'deferred'):
-            raise exceptions.ArgumentError("Invalid option for 'polymorphic_fetch': '%s'" % polymorphic_fetch)
-        if polymorphic_fetch is None:
-            self.polymorphic_fetch = (self.with_polymorphic is None) and 'select' or 'union'
-        else:
-            self.polymorphic_fetch = polymorphic_fetch
+        if polymorphic_fetch:
+            util.warn_deprecated('polymorphic_fetch option is deprecated.  Unloaded columns load as deferred in all cases; loading can be controlled using the "with_polymorphic" option.')
 
         # a dictionary of 'polymorphic identity' names, associating those names with
         # Mappers that will be used to construct object instances upon a select operation.
@@ -170,14 +184,14 @@ class Mapper(object):
         # a set of all mappers which inherit from this one.
         self._inheriting_mappers = util.Set()
 
-        self.__props_init = False
+        self.compiled = False
 
-        self.__should_log_info = logging.is_info_enabled(self.logger)
-        self.__should_log_debug = logging.is_debug_enabled(self.logger)
+        self.__should_log_info = log.is_info_enabled(self.logger)
+        self.__should_log_debug = log.is_debug_enabled(self.logger)
 
-        self.__compile_class()
         self.__compile_inheritance()
         self.__compile_extensions()
+        self.__compile_class()
         self.__compile_properties()
         self.__compile_pks()
         global _new_mappers
@@ -192,11 +206,12 @@ class Mapper(object):
         if self.__should_log_debug:
             self.logger.debug("(" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + (self.local_table and self.local_table.description or str(self.local_table)) + (self.non_primary and "|non-primary" or "") + ") " + msg)
 
-    def _is_orphan(self, obj):
+    def _is_orphan(self, state):
         o = False
         for mapper in self.iterate_to_root():
-            for (key,klass) in mapper.delete_orphans:
-                if attributes.has_parent(klass, obj, key, optimistic=has_identity(obj)):
+            for (key, cls) in mapper.delete_orphans:
+                if attributes.manager_of_class(cls).has_parent(
+                    state, key, optimistic=_state_has_identity(state)):
                     return False
             o = o or bool(mapper.delete_orphans)
         return o
@@ -208,41 +223,26 @@ class Mapper(object):
         return self._get_property(key, resolve_synonyms=resolve_synonyms, raiseerr=raiseerr)
 
     def _get_property(self, key, resolve_synonyms=False, raiseerr=True):
-        """private in-compilation version of get_property()."""
-
         prop = self.__props.get(key, None)
         if resolve_synonyms:
             while isinstance(prop, SynonymProperty):
                 prop = self.__props.get(prop.name, None)
         if prop is None and raiseerr:
-            raise exceptions.InvalidRequestError("Mapper '%s' has no property '%s'" % (str(self), key))
+            raise sa_exc.InvalidRequestError("Mapper '%s' has no property '%s'" % (str(self), key))
         return prop
 
     def iterate_properties(self):
+        """return an iterator of all MapperProperty objects."""
         self.compile()
         return self.__props.itervalues()
-    iterate_properties = property(iterate_properties, doc="returns an iterator of all MapperProperty objects.")
+    iterate_properties = property(iterate_properties)
 
-    def __adjust_wp_selectable(self, spec=None, selectable=False):
-        """given a with_polymorphic() argument, resolve it against this mapper's with_polymorphic setting"""
-        
-        isdefault = False
-        if self.with_polymorphic:
-            isdefault = not spec and selectable is False
-
-            if not spec:
-                spec = self.with_polymorphic[0]
-            if selectable is False:
-                selectable = self.with_polymorphic[1]
-                
-        return spec, selectable, isdefault
-        
     def __mappers_from_spec(self, spec, selectable):
         """given a with_polymorphic() argument, return the set of mappers it represents.
-        
+
         Trims the list of mappers to just those represented within the given selectable, if present.
         This helps some more legacy-ish mappings.
-        
+
         """
         if spec == '*':
             mappers = list(self.polymorphic_iterator())
@@ -250,86 +250,98 @@ class Mapper(object):
             mappers = [_class_to_mapper(m) for m in util.to_list(spec)]
         else:
             mappers = []
-        
+
         if selectable:
-            tables = util.Set(sqlutil.find_tables(selectable))
+            tables = util.Set(sqlutil.find_tables(selectable, include_aliases=True))
             mappers = [m for m in mappers if m.local_table in tables]
-            
+
         return mappers
-    __mappers_from_spec = util.conditional_cache_decorator(__mappers_from_spec)
-    
+
     def __selectable_from_mappers(self, mappers):
         """given a list of mappers (assumed to be within this mapper's inheritance hierarchy),
         construct an outerjoin amongst those mapper's mapped tables.
-        
+
         """
         from_obj = self.mapped_table
         for m in mappers:
             if m is self:
                 continue
             if m.concrete:
-                raise exceptions.InvalidRequestError("'with_polymorphic()' requires 'selectable' argument when concrete-inheriting mappers are used.")
+                raise sa_exc.InvalidRequestError("'with_polymorphic()' requires 'selectable' argument when concrete-inheriting mappers are used.")
             elif not m.single:
                 from_obj = from_obj.outerjoin(m.local_table, m.inherit_condition)
-        
+
         return from_obj
-    __selectable_from_mappers = util.conditional_cache_decorator(__selectable_from_mappers)
-    
-    def _with_polymorphic_mappers(self, spec=None, selectable=False):
-        spec, selectable, isdefault = self.__adjust_wp_selectable(spec, selectable)
-        return self.__mappers_from_spec(spec, selectable, cache=isdefault)
-        
-    def _with_polymorphic_selectable(self, spec=None, selectable=False):
-        spec, selectable, isdefault = self.__adjust_wp_selectable(spec, selectable)
+
+    def _with_polymorphic_mappers(self):
+        if not self.with_polymorphic:
+            return [self]
+        return self.__mappers_from_spec(*self.with_polymorphic)
+    _with_polymorphic_mappers = property(util.cache_decorator(_with_polymorphic_mappers))
+
+    def _with_polymorphic_selectable(self):
+        if not self.with_polymorphic:
+            return self.mapped_table
+
+        spec, selectable = self.with_polymorphic
         if selectable:
             return selectable
         else:
-            return self.__selectable_from_mappers(self.__mappers_from_spec(spec, selectable, cache=isdefault), cache=isdefault)
-    
+            return self.__selectable_from_mappers(self.__mappers_from_spec(spec, selectable))
+    _with_polymorphic_selectable = property(util.cache_decorator(_with_polymorphic_selectable))
+
     def _with_polymorphic_args(self, spec=None, selectable=False):
-        spec, selectable, isdefault = self.__adjust_wp_selectable(spec, selectable)
-        mappers = self.__mappers_from_spec(spec, selectable, cache=isdefault)
+        if self.with_polymorphic:
+            if not spec:
+                spec = self.with_polymorphic[0]
+            if selectable is False:
+                selectable = self.with_polymorphic[1]
+
+        mappers = self.__mappers_from_spec(spec, selectable)
         if selectable:
             return mappers, selectable
         else:
-            return mappers, self.__selectable_from_mappers(mappers, cache=isdefault)
-        
-    def _iterate_polymorphic_properties(self, spec=None, selectable=False):
+            return mappers, self.__selectable_from_mappers(mappers)
+
+    def _iterate_polymorphic_properties(self, mappers=None):
+        if mappers is None:
+            mappers = self._with_polymorphic_mappers
         return iter(util.OrderedSet(
-            chain(*[list(mapper.iterate_properties) for mapper in [self] + self._with_polymorphic_mappers(spec, selectable)])
+            chain(*[list(mapper.iterate_properties) for mapper in [self] + mappers])
         ))
 
     def properties(self):
         raise NotImplementedError("Public collection of MapperProperty objects is provided by the get_property() and iterate_properties accessors.")
     properties = property(properties)
 
-    def compiled(self):
-        """return True if this mapper is compiled"""
-        return self.__props_init
-    compiled = property(compiled)
-
     def dispose(self):
-        # disaable any attribute-based compilation
-        self.__props_init = True
-        try:
-            del self.class_.c
-        except AttributeError:
-            pass
-        if not self.non_primary and self.entity_name in self._class_state.mappers:
-            del self._class_state.mappers[self.entity_name]
-        if not self._class_state.mappers:
+        # Disable any attribute-based compilation.
+        self.compiled = True
+
+        manager = self.class_manager
+        mappers = manager.mappers
+
+        if not self.non_primary and self.entity_name in mappers:
+            del mappers[self.entity_name]
+        if not mappers and manager.info.get(_INSTRUMENTOR, False):
+            for legacy in _legacy_descriptors.keys():
+                manager.uninstall_member(legacy)
+            manager.events.remove_listener('on_init', _event_on_init)
+            manager.events.remove_listener('on_init_failure',
+                                           _event_on_init_failure)
+            manager.uninstall_member('__init__')
+            del manager.info[_INSTRUMENTOR]
             attributes.unregister_class(self.class_)
 
     def compile(self):
         """Compile this mapper and all other non-compiled mappers.
-        
+
         This method checks the local compiled status as well as for
-        any new mappers that have been defined, and is safe to call 
+        any new mappers that have been defined, and is safe to call
         repeatedly.
         """
-        
         global _new_mappers
-        if self.__props_init and not _new_mappers:
+        if self.compiled and not _new_mappers:
             return self
             
         _COMPILE_MUTEX.acquire()
@@ -341,12 +353,12 @@ class Mapper(object):
         try:
 
             # double-check inside mutex
-            if self.__props_init and not _new_mappers:
+            if self.compiled and not _new_mappers:
                 return self
 
             # initialize properties on all mappers
             for mapper in list(_mapper_registry):
-                if not mapper.__props_init:
+                if not mapper.compiled:
                     mapper.__initialize_properties()
 
             _new_mappers = False
@@ -358,7 +370,7 @@ class Mapper(object):
     def __initialize_properties(self):
         """Call the ``init()`` method on all ``MapperProperties``
         attached to this mapper.
-        
+
         This is a deferred configuration step which is intended
         to execute once all mappers have been constructed.
         """
@@ -370,8 +382,7 @@ class Mapper(object):
             if getattr(prop, 'key', None) is None:
                 prop.init(key, self)
         self.__log("__initialize_properties() complete")
-        self.__props_init = True
-
+        self.compiled = True
 
     def __compile_extensions(self):
         """Go through the global_extensions list as well as the list
@@ -391,14 +402,12 @@ class Mapper(object):
             for ext in self.inherits.extension:
                 if ext not in extlist:
                     extlist.add(ext)
-                    ext.instrument_class(self, self.class_)
         else:
             for ext in global_extensions:
                 if isinstance(ext, type):
                     ext = ext()
                 if ext not in extlist:
                     extlist.add(ext)
-                    ext.instrument_class(self, self.class_)
 
         self.extension = ExtensionCarrier()
         for ext in extlist:
@@ -410,13 +419,11 @@ class Mapper(object):
         if self.inherits:
             if isinstance(self.inherits, type):
                 self.inherits = class_mapper(self.inherits, compile=False)
-            else:
-                self.inherits = self.inherits
             if not issubclass(self.class_, self.inherits.class_):
-                raise exceptions.ArgumentError("Class '%s' does not inherit from '%s'" % (self.class_.__name__, self.inherits.class_.__name__))
+                raise sa_exc.ArgumentError("Class '%s' does not inherit from '%s'" % (self.class_.__name__, self.inherits.class_.__name__))
             if self.non_primary != self.inherits.non_primary:
                 np = not self.non_primary and "primary" or "non-primary"
-                raise exceptions.ArgumentError("Inheritance of %s mapper for class '%s' is only allowed from a %s mapper" % (np, self.class_.__name__, np))
+                raise sa_exc.ArgumentError("Inheritance of %s mapper for class '%s' is only allowed from a %s mapper" % (np, self.class_.__name__, np))
             # inherit_condition is optional.
             if self.local_table is None:
                 self.local_table = self.inherits.local_table
@@ -428,29 +435,17 @@ class Mapper(object):
                         if mapper.polymorphic_on:
                             mapper._requires_row_aliasing = True
                 else:
-                    if self.inherit_condition is None:
+                    if not self.inherit_condition:
                         # figure out inherit condition from our table to the immediate table
                         # of the inherited mapper, not its full table which could pull in other
                         # stuff we dont want (allows test/inheritance.InheritTest4 to pass)
                         self.inherit_condition = sqlutil.join_condition(self.inherits.local_table, self.local_table)
                     self.mapped_table = sql.join(self.inherits.mapped_table, self.local_table, self.inherit_condition)
-                    
+
                     fks = util.to_set(self.inherit_foreign_keys)
                     self.__inherits_equated_pairs = sqlutil.criterion_as_pairs(self.mapped_table.onclause, consider_as_foreign_keys=fks)
             else:
                 self.mapped_table = self.local_table
-            if self.polymorphic_identity is not None:
-                self.inherits.polymorphic_map[self.polymorphic_identity] = self
-                if self.polymorphic_on is None:
-                    for mapper in self.iterate_to_root():
-                        # try to set up polymorphic on using correesponding_column(); else leave
-                        # as None
-                        if mapper.polymorphic_on:
-                            self.polymorphic_on = self.mapped_table.corresponding_column(mapper.polymorphic_on)
-                            break
-                    else:
-                        # TODO: this exception not covered
-                        raise exceptions.ArgumentError("Mapper '%s' specifies a polymorphic_identity of '%s', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument" % (str(self), self.polymorphic_identity))
 
             if self.polymorphic_identity and not self.concrete:
                 self._identity_class = self.inherits._identity_class
@@ -470,25 +465,38 @@ class Mapper(object):
             self.inherits._inheriting_mappers.add(self)
             self.base_mapper = self.inherits.base_mapper
             self._all_tables = self.inherits._all_tables
+
+            if self.polymorphic_identity is not None:
+                self.polymorphic_map[self.polymorphic_identity] = self
+                if not self.polymorphic_on:
+                    for mapper in self.iterate_to_root():
+                        # try to set up polymorphic on using correesponding_column(); else leave
+                        # as None
+                        if mapper.polymorphic_on:
+                            self.polymorphic_on = self.mapped_table.corresponding_column(mapper.polymorphic_on)
+                            break
+                    else:
+                        # TODO: this exception not covered
+                        raise sa_exc.ArgumentError("Mapper '%s' specifies a polymorphic_identity of '%s', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument" % (str(self), self.polymorphic_identity))
         else:
             self._all_tables = util.Set()
             self.base_mapper = self
             self.mapped_table = self.local_table
             if self.polymorphic_identity:
                 if self.polymorphic_on is None:
-                    raise exceptions.ArgumentError("Mapper '%s' specifies a polymorphic_identity of '%s', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument" % (str(self), self.polymorphic_identity))
+                    raise sa_exc.ArgumentError("Mapper '%s' specifies a polymorphic_identity of '%s', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument" % (str(self), self.polymorphic_identity))
                 self.polymorphic_map[self.polymorphic_identity] = self
             self._identity_class = self.class_
-        
+
         if self.mapped_table is None:
-            raise exceptions.ArgumentError("Mapper '%s' does not have a mapped_table specified.  (Are you using the return value of table.create()?  It no longer has a return value.)" % str(self))
+            raise sa_exc.ArgumentError("Mapper '%s' does not have a mapped_table specified.  (Are you using the return value of table.create()?  It no longer has a return value.)" % str(self))
 
     def __compile_pks(self):
 
         self.tables = sqlutil.find_tables(self.mapped_table)
 
         if not self.tables:
-            raise exceptions.InvalidRequestError("Could not find any Table objects in mapped table '%s'" % str(self.mapped_table))
+            raise sa_exc.InvalidRequestError("Could not find any Table objects in mapped table '%s'" % str(self.mapped_table))
 
         self._pks_by_table = {}
         self._cols_by_table = {}
@@ -512,7 +520,7 @@ class Mapper(object):
                 self._pks_by_table[k.table].add(k)
 
         if self.mapped_table not in self._pks_by_table or len(self._pks_by_table[self.mapped_table]) == 0:
-            raise exceptions.ArgumentError("Mapper %s could not assemble any primary key columns for mapped table '%s'" % (self, self.mapped_table.description))
+            raise sa_exc.ArgumentError("Mapper %s could not assemble any primary key columns for mapped table '%s'" % (self, self.mapped_table.description))
 
         if self.inherits and not self.concrete and not self.primary_key_argument:
             # if inheriting, the "primary key" for this mapper is that of the inheriting (unless concrete or explicit)
@@ -525,7 +533,7 @@ class Mapper(object):
                 primary_key = sqlutil.reduce_columns(self._pks_by_table[self.mapped_table])
 
             if len(primary_key) == 0:
-                raise exceptions.ArgumentError("Mapper %s could not assemble any primary key columns for mapped table '%s'" % (self, self.mapped_table.description))
+                raise sa_exc.ArgumentError("Mapper %s could not assemble any primary key columns for mapped table '%s'" % (self, self.mapped_table.description))
 
             self.primary_key = primary_key
             self.__log("Identified primary key columns: " + str(primary_key))
@@ -534,25 +542,18 @@ class Mapper(object):
         """create a "get clause" based on the primary key.  this is used
         by query.get() and many-to-one lazyloads to load this item
         by primary key.
-        
+
         """
         params = [(primary_key, sql.bindparam(None, type_=primary_key.type)) for primary_key in self.primary_key]
         return sql.and_(*[k==v for (k, v) in params]), dict(params)
     _get_clause = property(util.cache_decorator(_get_clause))
-    
+
     def _equivalent_columns(self):
         """Create a map of all *equivalent* columns, based on
         the determination of column pairs that are equated to
         one another either by an established foreign key relationship
         or by a joined-table inheritance join.
 
-        This is used to determine the minimal set of primary key
-        columns for the mapper, as well as when relating
-        columns to those of a polymorphic selectable (i.e. a UNION of
-        several mapped tables), as that selectable usually only contains
-        one column in its columns clause out of a group of several which
-        are equated to each other.
-
         The resulting structure is a dictionary of columns mapped
         to lists of equivalent columns, i.e.
 
@@ -578,7 +579,7 @@ class Mapper(object):
                     result[binary.right] = util.Set([binary.left])
         for mapper in self.base_mapper.polymorphic_iterator():
             if mapper.inherit_condition:
-                visitors.traverse(mapper.inherit_condition, visit_binary=visit_binary)
+                visitors.traverse(mapper.inherit_condition, {}, {'binary':visit_binary})
 
         # TODO: matching of cols to foreign keys might better be generalized
         # into general column translation (i.e. corresponding_column)
@@ -619,7 +620,7 @@ class Mapper(object):
             cls = object.__getattribute__(self, 'class_')
             clskey = object.__getattribute__(self, 'key')
 
-            if key.startswith('__'):
+            if key.startswith('__') and key != '__clause_element__':
                 return object.__getattribute__(self, key)
 
             class_mapper(cls)
@@ -676,13 +677,13 @@ class Mapper(object):
             column_key = (self.column_prefix or '') + column.key
 
             self._compile_property(column_key, column, init=False, setparent=True)
-        
+
         # do a special check for the "discriminiator" column, as it may only be present
         # in the 'with_polymorphic' selectable but we need it for the base mapper
-        if self.polymorphic_on and self.polymorphic_on not in self._columntoproperty: 
-           col = self.mapped_table.corresponding_column(self.polymorphic_on) or self.polymorphic_on
-           self._compile_property(col.key, ColumnProperty(col), init=False, setparent=True)
-            
+        if self.polymorphic_on and self.polymorphic_on not in self._columntoproperty:
+            col = self.mapped_table.corresponding_column(self.polymorphic_on) or self.polymorphic_on
+            self._compile_property(col.key, ColumnProperty(col), init=False, setparent=True)
+
     def _adapt_inherited_property(self, key, prop):
         if not self.concrete:
             self._compile_property(key, prop, init=False, setparent=False)
@@ -696,7 +697,7 @@ class Mapper(object):
             columns = util.to_list(prop)
             column = columns[0]
             if not expression.is_column(column):
-                raise exceptions.ArgumentError("%s=%r is not an instance of MapperProperty or Column" % (key, prop))
+                raise sa_exc.ArgumentError("%s=%r is not an instance of MapperProperty or Column" % (key, prop))
 
             prop = self.__props.get(key, None)
 
@@ -715,12 +716,12 @@ class Mapper(object):
                 for c in columns:
                     mc = self.mapped_table.corresponding_column(c)
                     if not mc:
-                        raise exceptions.ArgumentError("Column '%s' is not represented in mapper's table.  Use the `column_property()` function to force this column to be mapped as a read-only attribute." % str(c))
+                        raise sa_exc.ArgumentError("Column '%s' is not represented in mapper's table.  Use the `column_property()` function to force this column to be mapped as a read-only attribute." % str(c))
                     mapped_column.append(mc)
                 prop = ColumnProperty(*mapped_column)
             else:
                 if not self.allow_column_override:
-                    raise exceptions.ArgumentError("WARNING: column '%s' not being added due to property '%s'.  Specify 'allow_column_override=True' to mapper() to ignore this condition." % (column.key, repr(prop)))
+                    raise sa_exc.ArgumentError("WARNING: column '%s' not being added due to property '%s'.  Specify 'allow_column_override=True' to mapper() to ignore this condition." % (column.key, repr(prop)))
                 else:
                     return
 
@@ -731,7 +732,7 @@ class Mapper(object):
             if col is None:
                 col = prop.columns[0]
             else:
-                # if column is coming in after _cols_by_table was initialized, ensure the col is in the 
+                # if column is coming in after _cols_by_table was initialized, ensure the col is in the
                 # right set
                 if hasattr(self, '_cols_by_table') and col.table in self._cols_by_table and col not in self._cols_by_table[col.table]:
                     self._cols_by_table[col.table].add(col)
@@ -740,35 +741,28 @@ class Mapper(object):
             for col in prop.columns:
                 for col in col.proxy_set:
                     self._columntoproperty[col] = prop
-            
-                
-        elif isinstance(prop, SynonymProperty) and setparent:
+
+        elif isinstance(prop, (ComparableProperty, SynonymProperty)) and setparent:
             if prop.descriptor is None:
                 prop.descriptor = getattr(self.class_, key, None)
                 if isinstance(prop.descriptor, Mapper._CompileOnAttr):
                     prop.descriptor = object.__getattribute__(prop.descriptor, 'existing_prop')
-            if prop.map_column:
-                if not key in self.mapped_table.c:
-                    raise exceptions.ArgumentError("Can't compile synonym '%s': no column on table '%s' named '%s'"  % (prop.name, self.mapped_table.description, key))
+            if getattr(prop, 'map_column', False):
+                if key not in self.mapped_table.c:
+                    raise sa_exc.ArgumentError("Can't compile synonym '%s': no column on table '%s' named '%s'"  % (prop.name, self.mapped_table.description, key))
                 self._compile_property(prop.name, ColumnProperty(self.mapped_table.c[key]), init=init, setparent=setparent)
-        elif isinstance(prop, ComparableProperty) and setparent:
-            # refactor me
-            if prop.descriptor is None:
-                prop.descriptor = getattr(self.class_, key, None)
-                if isinstance(prop.descriptor, Mapper._CompileOnAttr):
-                    prop.descriptor = object.__getattribute__(prop.descriptor,
-                                                              'existing_prop')
+
         self.__props[key] = prop
 
         if setparent:
             prop.set_parent(self)
 
             if not self.non_primary:
-                setattr(self.class_, key, Mapper._CompileOnAttr(self.class_, key))
-
+                self.class_manager.install_descriptor(
+                    key, Mapper._CompileOnAttr(self.class_, key))
         if init:
             prop.init(key, self)
-        
+
         for mapper in self._inheriting_mappers:
             mapper._adapt_inherited_property(key, prop)
 
@@ -783,49 +777,78 @@ class Mapper(object):
         auto-session attachment logic.
         """
 
+        manager = attributes.manager_of_class(self.class_)
+
         if self.non_primary:
-            if not hasattr(self.class_, '_class_state'):
-                raise exceptions.InvalidRequestError("Class %s has no primary mapper configured.  Configure a primary mapper first before setting up a non primary Mapper.")
-            self._class_state = self.class_._class_state
+            if not manager or None not in manager.mappers:
+                raise sa_exc.InvalidRequestError(
+                    "Class %s has no primary mapper configured.  Configure "
+                    "a primary mapper first before setting up a non primary "
+                    "Mapper.")
+            self.class_manager = manager
             _mapper_registry[self] = True
             return
 
-        if not self.non_primary and '_class_state' in self.class_.__dict__ and (self.entity_name in self.class_._class_state.mappers):
-             raise exceptions.ArgumentError("Class '%s' already has a primary mapper defined with entity name '%s'.  Use non_primary=True to create a non primary Mapper.  clear_mappers() will remove *all* current mappers from all classes." % (self.class_, self.entity_name))
+        if manager is not None:
+            if manager.class_ is not self.class_:
+                # An inherited manager.  Install one for this subclass.
+                manager = None
+            elif self.entity_name in manager.mappers:
+                raise sa_exc.ArgumentError(
+                    "Class '%s' already has a primary mapper defined "
+                    "with entity name '%s'.  Use non_primary=True to "
+                    "create a non primary Mapper.  clear_mappers() will "
+                    "remove *all* current mappers from all classes." %
+                    (self.class_, self.entity_name))
 
-        def extra_init(class_, oldinit, instance, args, kwargs):
-            self.compile()
-            if 'init_instance' in self.extension.methods:
-                self.extension.init_instance(self, class_, oldinit, instance, args, kwargs)
+        _mapper_registry[self] = True
 
-        def on_exception(class_, oldinit, instance, args, kwargs):
-            util.warn_exception(self.extension.init_failed, self, class_, oldinit, instance, args, kwargs)
+        if manager is None:
+            manager = attributes.create_manager_for_cls(self.class_)
 
-        attributes.register_class(self.class_, extra_init=extra_init, on_exception=on_exception, deferred_scalar_loader=_load_scalar_attributes)
+        self.class_manager = manager
 
-        self._class_state = self.class_._class_state
-        _mapper_registry[self] = True
+        has_been_initialized = bool(manager.info.get(_INSTRUMENTOR, False))
+        manager.mappers[self.entity_name] = self
 
-        self.class_._class_state.mappers[self.entity_name] = self
+        # The remaining members can be added by any mapper, e_name None or not.
+        if has_been_initialized:
+            return
 
-        for ext in util.to_list(self.extension, []):
-            ext.instrument_class(self, self.class_)
+        self.extension.instrument_class(self, self.class_)
 
-        if self.entity_name is None:
-            self.class_.c = self.c
+        manager.instantiable = True
+        manager.instance_state_factory = IdentityManagedState
+        manager.deferred_scalar_loader = _load_scalar_attributes
+
+        event_registry = manager.events
+        event_registry.add_listener('on_init', _event_on_init)
+        event_registry.add_listener('on_init_failure', _event_on_init_failure)
+
+        for key, impl in _legacy_descriptors.items():
+            manager.install_member(key, impl)
+
+        manager.info[_INSTRUMENTOR] = self
 
     def common_parent(self, other):
         """Return true if the given mapper shares a common inherited parent as this mapper."""
 
         return self.base_mapper is other.base_mapper
 
+    def _canload(self, state):
+        s = self.primary_mapper()
+        if s.polymorphic_on:
+            return _state_mapper(state).isa(s)
+        else:
+            return _state_mapper(state) is s
+
     def isa(self, other):
-        """Return True if the given mapper inherits from this mapper."""
+        """Return True if the this mapper inherits from the given mapper."""
 
-        m = other
-        while m is not self and m.inherits:
+        m = self
+        while m and m is not other:
             m = m.inherits
-        return m is self
+        return bool(m)
 
     def iterate_to_root(self):
         m = self
@@ -867,42 +890,20 @@ class Mapper(object):
         """
 
         self._init_properties[key] = prop
-        self._compile_property(key, prop, init=self.__props_init)
+        self._compile_property(key, prop, init=self.compiled)
+
+    def __repr__(self):
+        return '<Mapper at 0x%x; %s>' % (
+            id(self), self.class_.__name__)
 
     def __str__(self):
         return "Mapper|" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + (self.local_table and self.local_table.description or str(self.local_table)) + (self.non_primary and "|non-primary" or "")
 
     def primary_mapper(self):
         """Return the primary mapper corresponding to this mapper's class key (class + entity_name)."""
-        return self._class_state.mappers[self.entity_name]
-
-    def get_session(self):
-        """Return the contextual session provided by the mapper
-        extension chain, if any.
-
-        Raise ``InvalidRequestError`` if a session cannot be retrieved
-        from the extension chain.
-        """
-
-        if 'get_session' in self.extension.methods:
-            s = self.extension.get_session()
-            if s is not EXT_CONTINUE:
-                return s
-
-        raise exceptions.InvalidRequestError("No contextual Session is established.")
-
-    def instances(self, cursor, session, *mappers, **kwargs):
-        """Return a list of mapped instances corresponding to the rows
-        in a given ResultProxy.
-
-        DEPRECATED.
-        """
+        return self.class_manager.mappers[self.entity_name]
 
-        import sqlalchemy.orm.query
-        return sqlalchemy.orm.Query(self, session).instances(cursor, *mappers, **kwargs)
-    instances = util.deprecated(None, False)(instances)
-
-    def identity_key_from_row(self, row):
+    def identity_key_from_row(self, row, adapter=None):
         """Return an identity-map key for use in storing/retrieving an
         item from the identity map.
 
@@ -911,7 +912,12 @@ class Mapper(object):
           dictionary corresponding result-set ``ColumnElement``
           instances to their values within a row.
         """
-        return (self._identity_class, tuple([row[column] for column in self.primary_key]), self.entity_name)
+
+        pk_cols = self.primary_key
+        if adapter:
+            pk_cols = [adapter.columns[c] for c in pk_cols]
+
+        return (self._identity_class, tuple([row[column] for column in pk_cols]), self.entity_name)
 
     def identity_key_from_primary_key(self, primary_key):
         """Return an identity-map key for use in storing/retrieving an
@@ -926,8 +932,9 @@ class Mapper(object):
         """Return the identity key for the given instance, based on
         its primary key attributes.
 
-        This value is typically also found on the instance itself
-        under the attribute name `_instance_key`.
+        This value is typically also found on the instance state under the
+        attribute name `key`.
+
         """
         return self.identity_key_from_primary_key(self.primary_key_from_instance(instance))
 
@@ -938,17 +945,12 @@ class Mapper(object):
         """Return the list of primary key values for the given
         instance.
         """
-
-        return [self._get_state_attr_by_column(instance._state, column) for column in self.primary_key]
+        state = attributes.instance_state(instance)
+        return self._primary_key_from_state(state)
 
     def _primary_key_from_state(self, state):
         return [self._get_state_attr_by_column(state, column) for column in self.primary_key]
 
-    def _canload(self, state):
-        if self.polymorphic_on:
-            return issubclass(state.class_, self.class_)
-        else:
-            return state.class_ is self.class_
 
     def _get_col_to_prop(self, column):
         try:
@@ -956,24 +958,23 @@ class Mapper(object):
         except KeyError:
             prop = self.__props.get(column.key, None)
             if prop:
-                raise exceptions.UnmappedColumnError("Column '%s.%s' is not available, due to conflicting property '%s':%s" % (column.table.name, column.name, column.key, repr(prop)))
+                raise exc.UnmappedColumnError("Column '%s.%s' is not available, due to conflicting property '%s':%s" % (column.table.name, column.name, column.key, repr(prop)))
             else:
-                raise exceptions.UnmappedColumnError("No column %s.%s is configured on mapper %s..." % (column.table.name, column.name, str(self)))
+                raise exc.UnmappedColumnError("No column %s.%s is configured on mapper %s..." % (column.table.name, column.name, str(self)))
 
+    # TODO: improve names
     def _get_state_attr_by_column(self, state, column):
         return self._get_col_to_prop(column).getattr(state, column)
 
     def _set_state_attr_by_column(self, state, column, value):
         return self._get_col_to_prop(column).setattr(state, value, column)
 
-    def _get_attr_by_column(self, obj, column):
-        return self._get_col_to_prop(column).getattr(obj._state, column)
-
     def _get_committed_attr_by_column(self, obj, column):
-        return self._get_col_to_prop(column).getcommitted(obj._state, column)
+        state = attributes.instance_state(obj)
+        return self._get_committed_state_attr_by_column(state, column)
 
-    def _set_attr_by_column(self, obj, column, value):
-        self._get_col_to_prop(column).setattr(obj._state, column, value)
+    def _get_committed_state_attr_by_column(self, state, column):
+        return self._get_col_to_prop(column).getcommitted(state, column)
 
     def _save_obj(self, states, uowtransaction, postupdate=False, post_update_cols=None, single=False):
         """Issue ``INSERT`` and/or ``UPDATE`` statements for a list of objects.
@@ -1002,15 +1003,14 @@ class Mapper(object):
         # organize individual states with the connection to use for insert/update
         if 'connection_callable' in uowtransaction.mapper_flush_opts:
             connection_callable = uowtransaction.mapper_flush_opts['connection_callable']
-            tups = [(state, connection_callable(self, state.obj()), _state_has_identity(state)) for state in states]
+            tups = [(state, _state_mapper(state), connection_callable(self, state.obj()), _state_has_identity(state)) for state in states]
         else:
             connection = uowtransaction.transaction.connection(self)
-            tups = [(state, connection, _state_has_identity(state)) for state in states]
+            tups = [(state, _state_mapper(state), connection, _state_has_identity(state)) for state in states]
 
         if not postupdate:
             # call before_XXX extensions
-            for state, connection, has_identity in tups:
-                mapper = _state_mapper(state)
+            for state, mapper, connection, has_identity in tups:
                 if not has_identity:
                     if 'before_insert' in mapper.extension.methods:
                         mapper.extension.before_insert(mapper, connection, state.obj())
@@ -1018,16 +1018,16 @@ class Mapper(object):
                     if 'before_update' in mapper.extension.methods:
                         mapper.extension.before_update(mapper, connection, state.obj())
 
-        for state, connection, has_identity in tups:
+        for state, mapper, connection, has_identity in tups:
             # detect if we have a "pending" instance (i.e. has no instance_key attached to it),
             # and another instance with the same identity key already exists as persistent.  convert to an
             # UPDATE if so.
-            mapper = _state_mapper(state)
             instance_key = mapper._identity_key_from_state(state)
-            if not postupdate and not has_identity and instance_key in uowtransaction.uow.identity_map:
-                existing = uowtransaction.uow.identity_map[instance_key]._state
+            if not postupdate and not has_identity and instance_key in uowtransaction.session.identity_map:
+                instance = uowtransaction.session.identity_map[instance_key]
+                existing = attributes.instance_state(instance)
                 if not uowtransaction.is_deleted(existing):
-                    raise exceptions.FlushError("New instance %s with identity key %s conflicts with persistent instance %s" % (state_str(state), str(instance_key), state_str(existing)))
+                    raise exc.FlushError("New instance %s with identity key %s conflicts with persistent instance %s" % (state_str(state), str(instance_key), state_str(existing)))
                 if self.__should_log_debug:
                     self.__log_debug("detected row switch for identity %s.  will update %s, remove %s from transaction" % (instance_key, state_str(state), state_str(existing)))
                 uowtransaction.set_row_switch(existing)
@@ -1044,8 +1044,7 @@ class Mapper(object):
             insert = []
             update = []
 
-            for state, connection, has_identity in tups:
-                mapper = _state_mapper(state)
+            for state, mapper, connection, has_identity in tups:
                 if table not in mapper._pks_by_table:
                     continue
                 pks = mapper._pks_by_table[table]
@@ -1054,7 +1053,7 @@ class Mapper(object):
                 if self.__should_log_debug:
                     self.__log_debug("_save_obj() table '%s' instance %s identity %s" % (table.name, state_str(state), str(instance_key)))
 
-                isinsert = not instance_key in uowtransaction.uow.identity_map and not postupdate and not has_identity
+                isinsert = not instance_key in uowtransaction.session.identity_map and not postupdate and not has_identity
                 params = {}
                 value_params = {}
                 hasdata = False
@@ -1131,7 +1130,7 @@ class Mapper(object):
                 pks = mapper._pks_by_table[table]
                 def comparator(a, b):
                     for col in pks:
-                        x = cmp(a[1][col._label],b[1][col._label])
+                        x = cmp(a[1][col._label], b[1][col._label])
                         if x != 0:
                             return x
                     return 0
@@ -1148,7 +1147,7 @@ class Mapper(object):
                     rows += c.rowcount
 
                 if c.supports_sane_rowcount() and rows != len(update):
-                    raise exceptions.ConcurrentModificationError("Updated rowcount %d does not match number of objects updated %d" % (rows, len(update)))
+                    raise exc.ConcurrentModificationError("Updated rowcount %d does not match number of objects updated %d" % (rows, len(update)))
 
             if insert:
                 statement = table.insert()
@@ -1179,8 +1178,7 @@ class Mapper(object):
 
         if not postupdate:
             # call after_XXX extensions
-            for state, connection, has_identity in tups:
-                mapper = _state_mapper(state)
+            for state, mapper, connection, has_identity in tups:
                 if not has_identity:
                     if 'after_insert' in mapper.extension.methods:
                         mapper.extension.after_insert(mapper, connection, state.obj())
@@ -1216,9 +1214,10 @@ class Mapper(object):
 
         if deferred_props:
             if self.eager_defaults:
-                _instance_key = self._identity_key_from_state(state)
-                state.dict['_instance_key'] = _instance_key
-                uowtransaction.session.query(self)._get(_instance_key, refresh_instance=state, only_load_props=deferred_props)
+                state.key = self._identity_key_from_state(state)
+                uowtransaction.session.query(self)._get(
+                    state.key, refresh_instance=state,
+                    only_load_props=deferred_props)
             else:
                 _expire_state(state, deferred_props)
 
@@ -1234,17 +1233,15 @@ class Mapper(object):
 
         if 'connection_callable' in uowtransaction.mapper_flush_opts:
             connection_callable = uowtransaction.mapper_flush_opts['connection_callable']
-            tups = [(state, connection_callable(self, state.obj())) for state in states]
+            tups = [(state, _state_mapper(state), connection_callable(self, state.obj())) for state in states]
         else:
             connection = uowtransaction.transaction.connection(self)
-            tups = [(state, connection) for state in states]
+            tups = [(state, _state_mapper(state), connection) for state in states]
 
-        for (state, connection) in tups:
-            mapper = _state_mapper(state)
+        for state, mapper, connection in tups:
             if 'before_delete' in mapper.extension.methods:
                 mapper.extension.before_delete(mapper, connection, state.obj())
 
-        deleted_objects = util.Set()
         table_to_mapper = {}
         for mapper in self.base_mapper.polymorphic_iterator():
             for t in mapper.tables:
@@ -1252,8 +1249,7 @@ class Mapper(object):
 
         for table in sqlutil.sort_tables(table_to_mapper.keys(), reverse=True):
             delete = {}
-            for (state, connection) in tups:
-                mapper = _state_mapper(state)
+            for state, mapper, connection in tups:
                 if table not in mapper._pks_by_table:
                     continue
 
@@ -1266,13 +1262,12 @@ class Mapper(object):
                     params[col.key] = mapper._get_state_attr_by_column(state, col)
                 if mapper.version_id_col and table.c.contains_column(mapper.version_id_col):
                     params[mapper.version_id_col.key] = mapper._get_state_attr_by_column(state, mapper.version_id_col)
-                # testlib.pragma exempt:__hash__
-                deleted_objects.add((state, connection))
+
             for connection, del_objects in delete.iteritems():
                 mapper = table_to_mapper[table]
                 def comparator(a, b):
                     for col in mapper._pks_by_table[table]:
-                        x = cmp(a[col.key],b[col.key])
+                        x = cmp(a[col.key], b[col.key])
                         if x != 0:
                             return x
                     return 0
@@ -1285,10 +1280,9 @@ class Mapper(object):
                 statement = table.delete(clause)
                 c = connection.execute(statement, del_objects)
                 if c.supports_sane_multi_rowcount() and c.rowcount != len(del_objects):
-                    raise exceptions.ConcurrentModificationError("Deleted rowcount %d does not match number of objects deleted %d" % (c.rowcount, len(del_objects)))
+                    raise exc.ConcurrentModificationError("Deleted rowcount %d does not match number of objects deleted %d" % (c.rowcount, len(del_objects)))
 
-        for state, connection in deleted_objects:
-            mapper = _state_mapper(state)
+        for state, mapper, connection in tups:
             if 'after_delete' in mapper.extension.methods:
                 mapper.extension.after_delete(mapper, connection, state.obj())
 
@@ -1325,7 +1319,7 @@ class Mapper(object):
         visitables = [(self.__props.itervalues(), 'property', state)]
 
         while visitables:
-            iterator,item_type,parent_state = visitables[-1]
+            iterator, item_type, parent_state = visitables[-1]
             try:
                 if item_type == 'property':
                     prop = iterator.next()
@@ -1337,291 +1331,315 @@ class Mapper(object):
             except StopIteration:
                 visitables.pop()
 
-    def _instance(self, context, row, result=None, polymorphic_from=None, extension=None, only_load_props=None, refresh_instance=None):
-        if not extension:
-            extension = self.extension
+    def _instance_processor(self, context, path, adapter, polymorphic_from=None, extension=None, only_load_props=None, refresh_instance=None):
+        pk_cols = self.primary_key
 
-        if 'translate_row' in extension.methods:
-            ret = extension.translate_row(self, context, row)
-            if ret is not EXT_CONTINUE:
-                row = ret
-
-        if polymorphic_from:
-            # if we are called from a base mapper doing a polymorphic load, figure out what tables,
-            # if any, will need to be "post-fetched" based on the tables present in the row,
-            # or from the options set up on the query
-            if ('polymorphic_fetch', self) not in context.attributes:
-                if self in context.query._with_polymorphic:
-                    context.attributes[('polymorphic_fetch', self)] = (polymorphic_from, [])
-                else:
-                    context.attributes[('polymorphic_fetch', self)] = (polymorphic_from, [t for t in self.tables if t not in polymorphic_from.tables])
-                
-        elif not refresh_instance and self.polymorphic_on:
-            discriminator = row[self.polymorphic_on]
-            if discriminator is not None:
-                try:
-                    mapper = self.polymorphic_map[discriminator]
-                except KeyError:
-                    raise exceptions.AssertionError("No such polymorphic_identity %r is defined" % discriminator)
-                if mapper is not self:
-                    return mapper._instance(context, row, result=result, polymorphic_from=self)
-
-        # determine identity key
-        if refresh_instance:
-            try:
-                identitykey = refresh_instance.dict['_instance_key']
-            except KeyError:
-                # super-rare condition; a refresh is being called
-                # on a non-instance-key instance; this is meant to only
-                # occur wihtin a flush()
-                identitykey = self._identity_key_from_state(refresh_instance)
+        if polymorphic_from or refresh_instance:
+            polymorphic_on = None
         else:
-            identitykey = self.identity_key_from_row(row)
-        
-        session_identity_map = context.session.identity_map
+            polymorphic_on = self.polymorphic_on
+            polymorphic_instances = util.PopulateDict(self.__configure_subclass_mapper(context, path, adapter))
 
-        if identitykey in session_identity_map:
-            instance = session_identity_map[identitykey]
-            state = instance._state
-
-            if self.__should_log_debug:
-                self.__log_debug("_instance(): using existing instance %s identity %s" % (instance_str(instance), str(identitykey)))
-
-            isnew = state.runid != context.runid
-            currentload = not isnew
-
-            if not currentload and context.version_check and self.version_id_col and self._get_attr_by_column(instance, self.version_id_col) != row[self.version_id_col]:
-                raise exceptions.ConcurrentModificationError("Instance '%s' version of %s does not match %s" % (instance, self._get_attr_by_column(instance, self.version_id_col), row[self.version_id_col]))
-        elif refresh_instance:
-            # out of band refresh_instance detected (i.e. its not in the session.identity_map)
-            # honor it anyway.  this can happen if a _get() occurs within save_obj(), such as
-            # when eager_defaults is True.
-            state = refresh_instance
-            instance = state.obj()
-            isnew = state.runid != context.runid
-            currentload = True
-        else:
-            if self.__should_log_debug:
-                self.__log_debug("_instance(): identity key %s not in session" % str(identitykey))
+        version_id_col = self.version_id_col
 
-            if self.allow_null_pks:
-                for x in identitykey[1]:
-                    if x is not None:
-                        break
-                else:
-                    return None
-            else:
-                if None in identitykey[1]:
-                    return None
-            isnew = True
-            currentload = True
-
-            if 'create_instance' in extension.methods:
-                instance = extension.create_instance(self, context, row, self.class_)
-                if instance is EXT_CONTINUE:
-                    instance = attributes.new_instance(self.class_)
-                else:
-                    attributes.manage(instance)
-            else:
-                instance = attributes.new_instance(self.class_)
+        if adapter:
+            pk_cols = [adapter.columns[c] for c in pk_cols]
+            if polymorphic_on:
+                polymorphic_on = adapter.columns[polymorphic_on]
+            if version_id_col:
+                version_id_col = adapter.columns[version_id_col]
+
+        identity_class, entity_name = self._identity_class, self.entity_name
+        def identity_key(row):
+            return (identity_class, tuple([row[column] for column in pk_cols]), entity_name)
 
-            if self.__should_log_debug:
-                self.__log_debug("_instance(): created new instance %s identity %s" % (instance_str(instance), str(identitykey)))
+        new_populators = []
+        existing_populators = []
 
-            state = instance._state
-            instance._entity_name = self.entity_name
-            instance._instance_key = identitykey
-            instance._sa_session_id = context.session.hash_key
-            session_identity_map[identitykey] = instance
+        def populate_state(state, row, isnew, only_load_props, **flags):
+            if not new_populators:
+                new_populators[:], existing_populators[:] = self.__populators(context, path, row, adapter)
 
-        if currentload or context.populate_existing or self.always_refresh:
             if isnew:
-                state.runid = context.runid
-                context.progress.add(state)
+                populators = new_populators
+            else:
+                populators = existing_populators
 
-            if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, only_load_props=only_load_props, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
-                self.populate_instance(context, instance, row, only_load_props=only_load_props, instancekey=identitykey, isnew=isnew)
-        
-        else:
-            # populate attributes on non-loading instances which have been expired
-            # TODO: also support deferred attributes here [ticket:870]
-            if state.expired_attributes: 
-                if state in context.partials:
-                    isnew = False
-                    attrs = context.partials[state]
-                else:
-                    isnew = True
-                    attrs = state.expired_attributes.intersection(state.unmodified)
-                    context.partials[state] = attrs  #<-- allow query.instances to commit the subset of attrs
+            if only_load_props:
+                populators = [p for p in populators if p[0] in only_load_props]
 
-                if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, only_load_props=attrs, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
-                    self.populate_instance(context, instance, row, only_load_props=attrs, instancekey=identitykey, isnew=isnew)
-            
-        if result is not None and ('append_result' not in extension.methods or extension.append_result(self, context, row, instance, result, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE):
-            result.append(instance)
-
-        return instance
-
-    def populate_instance(self, selectcontext, instance, row, ispostselect=None, isnew=False, only_load_props=None, **flags):
-        """populate an instance from a result row."""
-
-        snapshot = selectcontext.path + (self,)
-        # retrieve a set of "row population" functions derived from the MapperProperties attached
-        # to this Mapper.  These are keyed in the select context based primarily off the
-        # "snapshot" of the stack, which represents a path from the lead mapper in the query to this one,
-        # including relation() names.  the key also includes "self", and allows us to distinguish between
-        # other mappers within our inheritance hierarchy
-        (new_populators, existing_populators) = selectcontext.attributes.get(('populators', self, snapshot, ispostselect), (None, None))
-        if new_populators is None:
-            # no populators; therefore this is the first time we are receiving a row for
-            # this result set.  issue create_row_processor() on all MapperProperty objects
-            # and cache in the select context.
-            new_populators = []
-            existing_populators = []
-            post_processors = []
-            for prop in self.__props.values():
-                (newpop, existingpop, post_proc) = selectcontext.exec_with_path(self, prop.key, prop.create_row_processor, selectcontext, self, row)
-                if newpop:
-                    new_populators.append((prop.key, newpop))
-                if existingpop:
-                    existing_populators.append((prop.key, existingpop))
-                if post_proc:
-                    post_processors.append(post_proc)
-
-            # install a post processor for immediate post-load of joined-table inheriting mappers
-            poly_select_loader = self._get_poly_select_loader(selectcontext, row)
-            if poly_select_loader:
-                post_processors.append(poly_select_loader)
-
-            selectcontext.attributes[('populators', self, snapshot, ispostselect)] = (new_populators, existing_populators)
-            selectcontext.attributes[('post_processors', self, ispostselect)] = post_processors
-
-        if isnew or ispostselect:
-            populators = new_populators
-        else:
-            populators = existing_populators
+            for key, populator in populators:
+                populator(state, row, isnew=isnew, **flags)
 
-        if only_load_props:
-            populators = [p for p in populators if p[0] in only_load_props]
+        session_identity_map = context.session.identity_map
 
-        for (key, populator) in populators:
-            selectcontext.exec_with_path(self, key, populator, instance, row, ispostselect=ispostselect, isnew=isnew, **flags)
+        if not extension:
+            extension = self.extension
 
-        if self.non_primary:
-            selectcontext.attributes[('populating_mapper', instance._state)] = self
+        translate_row = 'translate_row' in extension.methods
+        create_instance = 'create_instance' in extension.methods
+        populate_instance = 'populate_instance' in extension.methods
+        append_result = 'append_result' in extension.methods
+        populate_existing = context.populate_existing or self.always_refresh
+
+        def _instance(row, result):
+            if translate_row:
+                ret = extension.translate_row(self, context, row)
+                if ret is not EXT_CONTINUE:
+                    row = ret
+
+            if polymorphic_on:
+                discriminator = row[polymorphic_on]
+                if discriminator is not None:
+                    _instance = polymorphic_instances[discriminator]
+                    if _instance:
+                        return _instance(row, result)
+
+            # determine identity key
+            if refresh_instance:
+                # TODO: refresh_instance seems to be named wrongly -- it is always an instance state.
+                refresh_state = refresh_instance
+                identitykey = refresh_state.key
+                if identitykey is None:
+                    # super-rare condition; a refresh is being called
+                    # on a non-instance-key instance; this is meant to only
+                    # occur within a flush()
+                    identitykey = self._identity_key_from_state(refresh_state)
+            else:
+                identitykey = identity_key(row)
 
-    def _post_instance(self, selectcontext, state, **kwargs):
-        post_processors = selectcontext.attributes[('post_processors', self, None)]
-        for p in post_processors:
-            p(state.obj(), **kwargs)
+            if identitykey in session_identity_map:
+                instance = session_identity_map[identitykey]
+                state = attributes.instance_state(instance)
 
-    def _get_poly_select_loader(self, selectcontext, row):
-        """set up attribute loaders for 'select' and 'deferred' polymorphic loading.
+                if self.__should_log_debug:
+                    self.__log_debug("_instance(): using existing instance %s identity %s" % (instance_str(instance), str(identitykey)))
+
+                isnew = state.runid != context.runid
+                currentload = not isnew
+                loaded_instance = False
+
+                if not currentload and version_id_col and context.version_check and self._get_state_attr_by_column(state, self.version_id_col) != row[version_id_col]:
+                    raise exc.ConcurrentModificationError("Instance '%s' version of %s does not match %s" % (state_str(state), self._get_state_attr_by_column(state, self.version_id_col), row[version_id_col]))
+            elif refresh_instance:
+                # out of band refresh_instance detected (i.e. its not in the session.identity_map)
+                # honor it anyway.  this can happen if a _get() occurs within save_obj(), such as
+                # when eager_defaults is True.
+                state = refresh_instance
+                instance = state.obj()
+                isnew = state.runid != context.runid
+                currentload = True
+                loaded_instance = False
+            else:
+                if self.__should_log_debug:
+                    self.__log_debug("_instance(): identity key %s not in session" % str(identitykey))
 
-        this loading uses a second SELECT statement to load additional tables,
-        either immediately after loading the main table or via a deferred attribute trigger.
-        """
+                if self.allow_null_pks:
+                    for x in identitykey[1]:
+                        if x is not None:
+                            break
+                    else:
+                        return None
+                else:
+                    if None in identitykey[1]:
+                        return None
+                isnew = True
+                currentload = True
+                loaded_instance = True
+
+                if create_instance:
+                    instance = extension.create_instance(self, context, row, self.class_)
+                    if instance is EXT_CONTINUE:
+                        instance = self.class_manager.new_instance()
+                    else:
+                        manager = attributes.manager_for_cls(instance.__class__)
+                        # TODO: if manager is None, raise a friendly error about
+                        # returning instances of unmapped types
+                        manager.setup_instance(instance)
+                else:
+                    instance = self.class_manager.new_instance()
 
-        (hosted_mapper, needs_tables) = selectcontext.attributes.get(('polymorphic_fetch', self), (None, None))
+                if self.__should_log_debug:
+                    self.__log_debug("_instance(): created new instance %s identity %s" % (instance_str(instance), str(identitykey)))
 
-        if hosted_mapper is None or not needs_tables:
-            return
+                state = attributes.instance_state(instance)
+                state.entity_name = self.entity_name
+                state.key = identitykey
+                # manually adding instance to session.  for a complete add,
+                # session._finalize_loaded() must be called.
+                state.session_id = context.session.hash_key
+                session_identity_map.add(state)
 
-        cond, param_names = self._deferred_inheritance_condition(hosted_mapper, needs_tables)
-        statement = sql.select(needs_tables, cond, use_labels=True)
+            if currentload or populate_existing:
+                if isnew:
+                    state.runid = context.runid
+                    context.progress.add(state)
 
-        if hosted_mapper.polymorphic_fetch == 'select':
-            def post_execute(instance, **flags):
-                if self.__should_log_debug:
-                    self.__log_debug("Post query loading instance " + instance_str(instance))
+                if not populate_instance or extension.populate_instance(self, context, row, instance, only_load_props=only_load_props, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
+                    populate_state(state, row, isnew, only_load_props)
 
-                identitykey = self.identity_key_from_instance(instance)
-                
-                only_load_props = flags.get('only_load_props', None)
+            else:
+                # populate attributes on non-loading instances which have been expired
+                # TODO: also support deferred attributes here [ticket:870]
+                # TODO: apply eager loads to un-lazy loaded collections ?
+                # we might want to create an expanded form of 'state.expired_attributes' which includes deferred/un-lazy loaded
+                if state.expired_attributes:
+                    if state in context.partials:
+                        isnew = False
+                        attrs = context.partials[state]
+                    else:
+                        isnew = True
+                        attrs = state.expired_attributes.intersection(state.unmodified)
+                        context.partials[state] = attrs  #<-- allow query.instances to commit the subset of attrs
 
-                params = {}
-                for c, bind in param_names:
-                    params[bind] = self._get_attr_by_column(instance, c)
-                row = selectcontext.session.connection(self).execute(statement, params).fetchone()
-                self.populate_instance(selectcontext, instance, row, isnew=False, instancekey=identitykey, ispostselect=True, only_load_props=only_load_props)
-            return post_execute
-        elif hosted_mapper.polymorphic_fetch == 'deferred':
-            from sqlalchemy.orm.strategies import DeferredColumnLoader
-
-            def post_execute(instance, **flags):
-                def create_statement(instance):
-                    params = {}
-                    for (c, bind) in param_names:
-                        # use the "committed" (database) version to get query column values
-                        params[bind] = self._get_committed_attr_by_column(instance, c)
-                    return (statement, params)
-
-                props = [prop for prop in [self._get_col_to_prop(col) for col in statement.inner_columns] if prop.key not in instance.__dict__]
-                keys = [p.key for p in props]
-                
-                only_load_props = flags.get('only_load_props', None)
-                if only_load_props:
-                    keys = util.Set(keys).difference(only_load_props)
-                    props = [p for p in props if p.key in only_load_props]
-                    
-                for prop in props:
-                    strategy = prop._get_strategy(DeferredColumnLoader)
-                    instance._state.set_callable(prop.key, strategy.setup_loader(instance, props=keys, create_statement=create_statement))
-            return post_execute
-        else:
-            return None
+                    if not populate_instance or extension.populate_instance(self, context, row, instance, only_load_props=attrs, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
+                        populate_state(state, row, isnew, attrs, instancekey=identitykey)
+
+            if result is not None and (not append_result or extension.append_result(self, context, row, instance, result, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE):
+                result.append(instance)
 
-    def _deferred_inheritance_condition(self, base_mapper, needs_tables):
-        base_mapper = base_mapper.primary_mapper()
+            if loaded_instance:
+                state._run_on_load(instance)
+
+            return instance
+        return _instance
+
+    def __populators(self, context, path, row, adapter):
+        new_populators, existing_populators = [], []
+        for prop in self.__props.values():
+            newpop, existingpop = prop.create_row_processor(context, path, self, row, adapter)
+            if newpop:
+                new_populators.append((prop.key, newpop))
+            if existingpop:
+                existing_populators.append((prop.key, existingpop))
+        return new_populators, existing_populators
+
+    def __configure_subclass_mapper(self, context, path, adapter):
+        def configure_subclass_mapper(discriminator):
+            try:
+                mapper = self.polymorphic_map[discriminator]
+            except KeyError:
+                raise AssertionError("No such polymorphic_identity %r is defined" % discriminator)
+            if mapper is self:
+                return None
+            return mapper._instance_processor(context, path, adapter, polymorphic_from=self)
+        return configure_subclass_mapper
+
+    def _optimized_get_statement(self, state, attribute_names):
+        props = self.__props
+        tables = util.Set([props[key].parent.local_table for key in attribute_names])
+        if self.base_mapper.local_table in tables:
+            return None
 
         def visit_binary(binary):
             leftcol = binary.left
             rightcol = binary.right
             if leftcol is None or rightcol is None:
                 return
-            if leftcol.table not in needs_tables:
-                binary.left = sql.bindparam(None, None, type_=binary.right.type)
-                param_names.append((leftcol, binary.left))
-            elif rightcol not in needs_tables:
-                binary.right = sql.bindparam(None, None, type_=binary.right.type)
-                param_names.append((rightcol, binary.right))
+
+            if leftcol.table not in tables:
+                binary.left = sql.bindparam(None, self._get_committed_state_attr_by_column(state, leftcol), type_=binary.right.type)
+            elif rightcol.table not in tables:
+                binary.right = sql.bindparam(None, self._get_committed_state_attr_by_column(state, rightcol), type_=binary.right.type)
 
         allconds = []
-        param_names = []
 
-        for mapper in self.iterate_to_root():
-            if mapper is base_mapper:
-                break
-            allconds.append(visitors.traverse(mapper.inherit_condition, clone=True, visit_binary=visit_binary))
+        start = False
+        for mapper in util.reversed(list(self.iterate_to_root())):
+            if mapper.local_table in tables:
+                start = True
+            if start:
+                allconds.append(visitors.cloned_traverse(mapper.inherit_condition, {}, {'binary':visit_binary}))
+
+        cond = sql.and_(*allconds)
+        return sql.select(tables, cond, use_labels=True)
+
+Mapper.logger = log.class_logger(Mapper)
+
 
-        return sql.and_(*allconds), param_names
+def _event_on_init(state, instance, args, kwargs):
+    """Trigger mapper compilation and run init_instance hooks."""
 
-Mapper.logger = logging.class_logger(Mapper)
+    instrumenting_mapper = state.manager.info[_INSTRUMENTOR]
+    # compile() always compiles all mappers
+    instrumenting_mapper.compile()
+    if 'init_instance' in instrumenting_mapper.extension.methods:
+        instrumenting_mapper.extension.init_instance(
+            instrumenting_mapper, instrumenting_mapper.class_,
+            state.manager.events.original_init,
+            instance, args, kwargs)
 
+def _event_on_init_failure(state, instance, args, kwargs):
+    """Run init_failed hooks."""
 
-object_session = None
+    instrumenting_mapper = state.manager.info[_INSTRUMENTOR]
+    if 'init_failed' in instrumenting_mapper.extension.methods:
+        util.warn_exception(
+            instrumenting_mapper.extension.init_failed,
+            instrumenting_mapper, instrumenting_mapper.class_,
+            state.manager.events.original_init, instance, args, kwargs)
 
-def _load_scalar_attributes(instance, attribute_names):
-    mapper = object_mapper(instance)
-    global object_session
-    if not object_session:
-        from sqlalchemy.orm.session import object_session
-    session = object_session(instance)
+def _legacy_descriptors():
+    """Build compatibility descriptors mapping legacy to InstanceState.
+
+    These are slated for removal in 0.5.  They were never part of the
+    official public API but were suggested as temporary workarounds in a
+    number of mailing list posts.  Permanent and public solutions for those
+    needs should be available now.  Consult the applicable mailing list
+    threads for details.
+
+    """
+    def _instance_key(self):
+        state = attributes.instance_state(self)
+        if state.key is not None:
+            return state.key
+        else:
+            raise AttributeError("_instance_key")
+    _instance_key = util.deprecated(None, False)(_instance_key)
+    _instance_key = property(_instance_key)
+
+    def _sa_session_id(self):
+        state = attributes.instance_state(self)
+        if state.session_id is not None:
+            return state.session_id
+        else:
+            raise AttributeError("_sa_session_id")
+    _sa_session_id = util.deprecated(None, False)(_sa_session_id)
+    _sa_session_id = property(_sa_session_id)
+
+    def _entity_name(self):
+        state = attributes.instance_state(self)
+        if state.entity_name is attributes.NO_ENTITY_NAME:
+            return None
+        else:
+            return state.entity_name
+    _entity_name = util.deprecated(None, False)(_entity_name)
+    _entity_name = property(_entity_name)
+
+    return dict(locals())
+_legacy_descriptors = _legacy_descriptors()
+
+def _load_scalar_attributes(state, attribute_names):
+    mapper = _state_mapper(state)
+    session = _state_session(state)
     if not session:
-        try:
-            session = mapper.get_session()
-        except exceptions.InvalidRequestError:
-            raise exceptions.UnboundExecutionError("Instance %s is not bound to a Session, and no contextual session is established; attribute refresh operation cannot proceed" % (instance.__class__))
-
-    state = instance._state
-    if '_instance_key' in state.dict:
-        identity_key = state.dict['_instance_key']
-        shouldraise = True
-    else:
-        # if instance is pending, a refresh operation may not complete (even if PK attributes are assigned)
-        shouldraise = False
-        identity_key = mapper._identity_key_from_state(state)
-
-    if session.query(mapper)._get(identity_key, refresh_instance=state, only_load_props=attribute_names) is None and shouldraise:
-        raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % instance_str(instance))
+        raise sa_exc.UnboundExecutionError("Instance %s is not bound to a Session; attribute refresh operation cannot proceed" % (state_str(state)))
+
+    has_key = _state_has_identity(state)
+
+    result = False
+    if mapper.inherits and not mapper.concrete:
+        statement = mapper._optimized_get_statement(state, attribute_names)
+        if statement:
+            result = session.query(mapper).from_statement(statement)._get(None, only_load_props=attribute_names, refresh_instance=state)
+
+    if result is False:
+        if has_key:
+            identity_key = state.key
+        else:
+            identity_key = mapper._identity_key_from_state(state)
+        result = session.query(mapper)._get(identity_key, refresh_instance=state, only_load_props=attribute_names)
 
+    # if instance is pending, a refresh operation may not complete (even if PK attributes are assigned)
+    if has_key and result is None:
+        raise exc.ObjectDeletedError("Instance '%s' has been deleted." % state_str(state))
index 33a0ff4326dad1ce67121c2d1fd5fc545904d8ba..fc2e901892b1bf0d534f30d9aa5b9ff89533fedb 100644 (file)
@@ -6,19 +6,20 @@
 
 """MapperProperty implementations.
 
-This is a private module which defines the behavior of
-invidual ORM-mapped attributes.
+This is a private module which defines the behavior of invidual ORM-mapped
+attributes.
+
 """
 
-from sqlalchemy import sql, schema, util, exceptions, logging
-from sqlalchemy.sql.util import ClauseAdapter, criterion_as_pairs, find_columns
-from sqlalchemy.sql import visitors, operators, ColumnElement
-from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency, object_mapper
-from sqlalchemy.orm import session as sessionlib
-from sqlalchemy.orm.mapper import _class_to_mapper
-from sqlalchemy.orm.util import CascadeOptions, PropertyAliasedClauses
-from sqlalchemy.orm.interfaces import StrategizedProperty, PropComparator, MapperProperty, ONETOMANY, MANYTOONE, MANYTOMANY
-from sqlalchemy.exceptions import ArgumentError
+from sqlalchemy import sql, util, log
+import sqlalchemy.exceptions as sa_exc
+from sqlalchemy.sql.util import ClauseAdapter, criterion_as_pairs
+from sqlalchemy.sql import operators, ColumnElement, expression
+from sqlalchemy.orm import mapper, strategies, attributes, dependency, \
+     object_mapper, session as sessionlib
+from sqlalchemy.orm.util import CascadeOptions, _class_to_mapper, _orm_annotate
+from sqlalchemy.orm.interfaces import StrategizedProperty, PropComparator, \
+     MapperProperty, ONETOMANY, MANYTOONE, MANYTOMANY
 
 __all__ = ('ColumnProperty', 'CompositeProperty', 'SynonymProperty',
            'ComparableProperty', 'PropertyLoader', 'BackRef')
@@ -34,18 +35,15 @@ class ColumnProperty(StrategizedProperty):
         appears across each table.
         """
 
-        self.columns = list(columns)
+        self.columns = [expression._labeled(c) for c in columns]
         self.group = kwargs.pop('group', None)
         self.deferred = kwargs.pop('deferred', False)
         self.comparator = ColumnProperty.ColumnComparator(self)
+        util.set_creation_order(self)
         if self.deferred:
             self.strategy_class = strategies.DeferredColumnLoader
         else:
             self.strategy_class = strategies.ColumnLoader
-        # sanity check
-        for col in columns:
-            if not isinstance(col, ColumnElement):
-                raise ArgumentError('column_property() must be given a ColumnElement as its argument.  Try .label() or .as_scalar() for Selectables to fix this.')
 
     def do_init(self):
         super(ColumnProperty, self).do_init()
@@ -61,37 +59,41 @@ class ColumnProperty(StrategizedProperty):
         return ColumnProperty(deferred=self.deferred, group=self.group, *self.columns)
 
     def getattr(self, state, column):
-        return getattr(state.class_, self.key).impl.get(state)
+        return state.get_impl(self.key).get(state)
 
     def getcommitted(self, state, column):
-        return getattr(state.class_, self.key).impl.get_committed_value(state)
+        return state.get_impl(self.key).get_committed_value(state)
 
     def setattr(self, state, value, column):
-        getattr(state.class_, self.key).impl.set(state, value, None)
+        state.get_impl(self.key).set(state, value, None)
 
     def merge(self, session, source, dest, dont_load, _recursive):
-        value = attributes.get_as_list(source._state, self.key, passive=True)
+        value = attributes.instance_state(source).value_as_iterable(
+            self.key, passive=True)
         if value:
             setattr(dest, self.key, value[0])
         else:
-            # TODO: lazy callable should merge to the new instance
-            dest._state.expire_attributes([self.key])
+            attributes.instance_state(dest).expire_attributes([self.key])
 
     def get_col_value(self, column, value):
         return value
 
     class ColumnComparator(PropComparator):
-        def clause_element(self):
-            return self.prop.columns[0]
-
+        def __clause_element__(self):
+            return self.prop.columns[0]._annotate({"parententity": self.prop.parent})
+        __clause_element__ = util.cache_decorator(__clause_element__)
+        
         def operate(self, op, *other, **kwargs):
-            return op(self.prop.columns[0], *other, **kwargs)
+            return op(self.__clause_element__(), *other, **kwargs)
 
         def reverse_operate(self, op, other, **kwargs):
-            col = self.prop.columns[0]
+            col = self.__clause_element__()
             return op(col._bind_param(other), col, **kwargs)
 
-ColumnProperty.logger = logging.class_logger(ColumnProperty)
+    def __str__(self):
+        return str(self.parent.class_.__name__) + "." + self.key
+
+ColumnProperty.logger = log.class_logger(ColumnProperty)
 
 class CompositeProperty(ColumnProperty):
     """subclasses ColumnProperty to provide composite type support."""
@@ -100,6 +102,7 @@ class CompositeProperty(ColumnProperty):
         super(CompositeProperty, self).__init__(*columns, **kwargs)
         self.composite_class = class_
         self.comparator = kwargs.pop('comparator', CompositeProperty.Comparator)(self)
+        self.strategy_class = strategies.CompositeColumnLoader
 
     def do_init(self):
         super(ColumnProperty, self).do_init()
@@ -109,19 +112,19 @@ class CompositeProperty(ColumnProperty):
         return CompositeProperty(deferred=self.deferred, group=self.group, composite_class=self.composite_class, *self.columns)
 
     def getattr(self, state, column):
-        obj = getattr(state.class_, self.key).impl.get(state)
+        obj = state.get_impl(self.key).get(state)
         return self.get_col_value(column, obj)
 
     def getcommitted(self, state, column):
-        obj = getattr(state.class_, self.key).impl.get_committed_value(state)
+        obj = state.get_impl(self.key).get_committed_value(state)
         return self.get_col_value(column, obj)
 
     def setattr(self, state, value, column):
         # TODO: test coverage for this method
-        obj = getattr(state.class_, self.key).impl.get(state)
+        obj = state.get_impl(self.key).get(state)
         if obj is None:
             obj = self.composite_class(*[None for c in self.columns])
-            getattr(state.class_, self.key).impl.set(state, obj, None)
+            state.get_impl(self.key).set(state, obj, None)
 
         for a, b in zip(self.columns, value.__composite_values__()):
             if a is column:
@@ -133,6 +136,9 @@ class CompositeProperty(ColumnProperty):
                 return b
 
     class Comparator(PropComparator):
+        def __clause_element__(self):
+            return expression.ClauseList(*self.prop.columns)
+
         def __eq__(self, other):
             if other is None:
                 return sql.and_(*[a==None for a in self.prop.columns])
@@ -146,17 +152,21 @@ class CompositeProperty(ColumnProperty):
                              zip(self.prop.columns,
                                  other.__composite_values__())])
 
+    def __str__(self):
+        return str(self.parent.class_.__name__) + "." + self.key
+
 class SynonymProperty(MapperProperty):
     def __init__(self, name, map_column=None, descriptor=None):
         self.name = name
-        self.map_column=map_column
+        self.map_column = map_column
         self.descriptor = descriptor
+        util.set_creation_order(self)
 
-    def setup(self, querycontext, **kwargs):
+    def setup(self, context, entity, path, adapter, **kwargs):
         pass
 
-    def create_row_processor(self, selectcontext, mapper, row):
-        return (None, None, None)
+    def create_row_processor(self, selectcontext, path, mapper, row, adapter):
+        return (None, None)
 
     def do_init(self):
         class_ = self.parent.class_
@@ -174,12 +184,11 @@ class SynonymProperty(MapperProperty):
                         return s
                     return getattr(obj, self.name)
             self.descriptor = SynonymProp()
-        sessionlib.register_attribute(class_, self.key, uselist=False, proxy_property=self.descriptor, useobject=False, comparator=comparator)
+        sessionlib.register_attribute(class_, self.key, uselist=False, proxy_property=self.descriptor, useobject=False, comparator=comparator, parententity=self.parent)
 
     def merge(self, session, source, dest, _recursive):
         pass
-SynonymProperty.logger = logging.class_logger(SynonymProperty)
-
+SynonymProperty.logger = log.class_logger(SynonymProperty)
 
 class ComparableProperty(MapperProperty):
     """Instruments a Python property for use in query expressions."""
@@ -187,6 +196,7 @@ class ComparableProperty(MapperProperty):
     def __init__(self, comparator_factory, descriptor=None):
         self.descriptor = descriptor
         self.comparator = comparator_factory(self)
+        util.set_creation_order(self)
 
     def do_init(self):
         """Set up a proxy to the unmanaged descriptor."""
@@ -198,11 +208,11 @@ class ComparableProperty(MapperProperty):
                                       useobject=False,
                                       comparator=self.comparator)
 
-    def setup(self, querycontext, **kwargs):
+    def setup(self, context, entity, path, adapter, **kwargs):
         pass
 
-    def create_row_processor(self, selectcontext, mapper, row):
-        return (None, None, None)
+    def create_row_processor(self, selectcontext, path, mapper, row, adapter):
+        return (None, None)
 
 
 class PropertyLoader(StrategizedProperty):
@@ -210,7 +220,22 @@ class PropertyLoader(StrategizedProperty):
     of items that correspond to a related database table.
     """
 
-    def __init__(self, argument, secondary=None, primaryjoin=None, secondaryjoin=None, entity_name=None, foreign_keys=None, foreignkey=None, uselist=None, private=False, association=None, order_by=False, attributeext=None, backref=None, is_backref=False, post_update=False, cascade=None, viewonly=False, lazy=True, collection_class=None, passive_deletes=False, passive_updates=True, remote_side=None, enable_typechecks=True, join_depth=None, strategy_class=None, _local_remote_pairs=None):
+    def __init__(self, argument, 
+        secondary=None, primaryjoin=None, 
+        secondaryjoin=None, entity_name=None, 
+        foreign_keys=None, 
+        uselist=None, 
+        order_by=False, 
+        backref=None, 
+        _is_backref=False, 
+        post_update=False, 
+        cascade=None, 
+        viewonly=False, lazy=True, 
+        collection_class=None, passive_deletes=False, 
+        passive_updates=True, remote_side=None, 
+        enable_typechecks=True, join_depth=None, 
+        strategy_class=None, _local_remote_pairs=None):
+        
         self.uselist = uselist
         self.argument = argument
         self.entity_name = entity_name
@@ -222,9 +247,6 @@ class PropertyLoader(StrategizedProperty):
         self.viewonly = viewonly
         self.lazy = lazy
         self.foreign_keys = util.to_set(foreign_keys)
-        self._legacy_foreignkey = util.to_set(foreignkey)
-        if foreignkey:
-            util.warn_deprecated('foreignkey option is deprecated; see docs for details')
         self.collection_class = collection_class
         self.passive_deletes = passive_deletes
         self.passive_updates = passive_updates
@@ -233,6 +255,8 @@ class PropertyLoader(StrategizedProperty):
         self.comparator = PropertyLoader.Comparator(self)
         self.join_depth = join_depth
         self._arg_local_remote_pairs = _local_remote_pairs
+        self.__join_cache = {}
+        util.set_creation_order(self)
         
         if strategy_class:
             self.strategy_class = strategy_class
@@ -251,20 +275,13 @@ class PropertyLoader(StrategizedProperty):
         if cascade is not None:
             self.cascade = CascadeOptions(cascade)
         else:
-            if private:
-                util.warn_deprecated('private option is deprecated; see docs for details')
-                self.cascade = CascadeOptions("all, delete-orphan")
-            else:
-                self.cascade = CascadeOptions("save-update, merge")
+            self.cascade = CascadeOptions("save-update, merge")
 
         if self.passive_deletes == 'all' and ("delete" in self.cascade or "delete-orphan" in self.cascade):
-            raise exceptions.ArgumentError("Can't set passive_deletes='all' in conjunction with 'delete' or 'delete-orphan' cascade")
+            raise sa_exc.ArgumentError("Can't set passive_deletes='all' in conjunction with 'delete' or 'delete-orphan' cascade")
 
-        self.association = association
-        if association:
-            util.warn_deprecated('association option is deprecated; see docs for details')
         self.order_by = order_by
-        self.attributeext=attributeext
+
         if isinstance(backref, str):
             # propigate explicitly sent primary/secondary join conditions to the BackRef object if
             # just a string was sent
@@ -275,14 +292,21 @@ class PropertyLoader(StrategizedProperty):
                 self.backref = BackRef(backref, primaryjoin=primaryjoin, secondaryjoin=secondaryjoin, passive_updates=self.passive_updates)
         else:
             self.backref = backref
-        self.is_backref = is_backref
-
+        self._is_backref = _is_backref
+    
     class Comparator(PropComparator):
         def __init__(self, prop, of_type=None):
             self.prop = self.property = prop
             if of_type:
                 self._of_type = _class_to_mapper(of_type)
         
+        def parententity(self):
+            return self.prop.parent
+        parententity = property(parententity)
+        
+        def __clause_element__(self):
+            return self.prop.parent._with_polymorphic_selectable
+            
         def of_type(self, cls):
             return PropertyLoader.Comparator(self.prop, cls)
             
@@ -294,7 +318,7 @@ class PropertyLoader(StrategizedProperty):
                     return self.prop._optimized_compare(None)
             elif self.prop.uselist:
                 if not hasattr(other, '__iter__'):
-                    raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object.  Use contains().")
+                    raise sa_exc.InvalidRequestError("Can only compare a collection to an iterable object.  Use contains().")
                 else:
                     j = self.prop.primaryjoin
                     if self.prop.secondaryjoin:
@@ -308,60 +332,62 @@ class PropertyLoader(StrategizedProperty):
             else:
                 return self.prop._optimized_compare(other)
 
-        def _join_and_criterion(self, criterion=None, **kwargs):
+        def __criterion_exists(self, criterion=None, **kwargs):
             if getattr(self, '_of_type', None):
                 target_mapper = self._of_type
-                to_selectable = target_mapper._with_polymorphic_selectable() #mapped_table
+                to_selectable = target_mapper._with_polymorphic_selectable
             else:
                 to_selectable = None
             
-            pj, sj, source, dest, target_adapter = self.prop._create_joins(dest_polymorphic=True, dest_selectable=to_selectable)
+            pj, sj, source, dest, secondary, target_adapter = self.prop._create_joins(dest_polymorphic=True, dest_selectable=to_selectable)
 
             for k in kwargs:
-                crit = (getattr(self.prop.mapper.class_, k) == kwargs[k])
+                crit = self.prop.mapper.class_manager.get_inst(k) == kwargs[k]
                 if criterion is None:
                     criterion = crit
                 else:
                     criterion = criterion & crit
             
             if sj:
-                j = pj & sj
+                j = _orm_annotate(pj) & sj
             else:
-                j = pj
+                j = _orm_annotate(pj, exclude=self.prop.remote_side)
                 
             if criterion and target_adapter:
+                # limit this adapter to annotated only?
                 criterion = target_adapter.traverse(criterion)
             
-            return j, criterion, dest
+            # only have the "joined left side" of what we return be subject to Query adaption.  The right
+            # side of it is used for an exists() subquery and should not correlate or otherwise reach out
+            # to anything in the enclosing query.
+            if criterion:
+                criterion = criterion._annotate({'_halt_adapt': True})
+            return sql.exists([1], j & criterion, from_obj=dest).correlate(source)
             
         def any(self, criterion=None, **kwargs):
             if not self.prop.uselist:
-                raise exceptions.InvalidRequestError("'any()' not implemented for scalar attributes. Use has().")
-            j, criterion, from_obj = self._join_and_criterion(criterion, **kwargs)
+                raise sa_exc.InvalidRequestError("'any()' not implemented for scalar attributes. Use has().")
 
-            return sql.exists([1], j & criterion, from_obj=from_obj)
+            return self.__criterion_exists(criterion, **kwargs)
 
         def has(self, criterion=None, **kwargs):
             if self.prop.uselist:
-                raise exceptions.InvalidRequestError("'has()' not implemented for collections.  Use any().")
-            j, criterion, from_obj = self._join_and_criterion(criterion, **kwargs)
-
-            return sql.exists([1], j & criterion, from_obj=from_obj)
+                raise sa_exc.InvalidRequestError("'has()' not implemented for collections.  Use any().")
+            return self.__criterion_exists(criterion, **kwargs)
 
         def contains(self, other):
             if not self.prop.uselist:
-                raise exceptions.InvalidRequestError("'contains' not implemented for scalar attributes.  Use ==")
+                raise sa_exc.InvalidRequestError("'contains' not implemented for scalar attributes.  Use ==")
             clause = self.prop._optimized_compare(other)
 
             if self.prop.secondaryjoin:
-                clause.negation_clause = self._negated_contains_or_equals(other)
+                clause.negation_clause = self.__negated_contains_or_equals(other)
 
             return clause
 
-        def _negated_contains_or_equals(self, other):
+        def __negated_contains_or_equals(self, other):
             criterion = sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))])
-            j, criterion, from_obj = self._join_and_criterion(criterion)
-            return ~sql.exists([1], j & criterion, from_obj=from_obj)
+            return ~self.__criterion_exists(criterion)
             
         def __ne__(self, other):
             if other is None:
@@ -373,9 +399,9 @@ class PropertyLoader(StrategizedProperty):
                     return self.has()
 
             if self.prop.uselist and not hasattr(other, '__iter__'):
-                raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object")
+                raise sa_exc.InvalidRequestError("Can only compare a collection to an iterable object")
 
-            return self._negated_contains_or_equals(other)
+            return self.__negated_contains_or_equals(other)
 
     def compare(self, op, value, value_is_parent=False):
         if op == operators.eq:
@@ -390,27 +416,29 @@ class PropertyLoader(StrategizedProperty):
             return op(self.comparator, value)
 
     def _optimized_compare(self, value, value_is_parent=False):
+        if value is not None:
+            value = attributes.instance_state(value)
         return self._get_strategy(strategies.LazyLoader).lazy_clause(value, reverse_direction=not value_is_parent)
 
-    def private(self):
-        return self.cascade.delete_orphan
-    private = property(private)
-
     def __str__(self):
-        return str(self.parent.class_.__name__) + "." + self.key + " (" + str(self.mapper.class_.__name__)  + ")"
+        return str(self.parent.class_.__name__) + "." + self.key
 
     def merge(self, session, source, dest, dont_load, _recursive):
         if not dont_load and self._reverse_property and (source, self._reverse_property) in _recursive:
             return
-            
+
+        source_state = attributes.instance_state(source)
+        dest_state = attributes.instance_state(dest)
+
         if not "merge" in self.cascade:
-            dest._state.expire_attributes([self.key])
+            dest_state.expire_attributes([self.key])
             return
 
-        instances = attributes.get_as_list(source._state, self.key, passive=True)
+        instances = source_state.value_as_iterable(self.key, passive=True)
+
         if not instances:
             return
-        
+
         if self.uselist:
             dest_list = []
             for current in instances:
@@ -419,11 +447,11 @@ class PropertyLoader(StrategizedProperty):
                 if obj is not None:
                     dest_list.append(obj)
             if dont_load:
-                coll = attributes.init_collection(dest, self.key)
+                coll = attributes.init_collection(dest_state, self.key)
                 for c in dest_list:
                     coll.append_without_event(c) 
             else:
-                getattr(dest.__class__, self.key).impl._set_iterable(dest._state, dest_list)
+                getattr(dest.__class__, self.key).impl._set_iterable(dest_state, dest_list)
         else:
             current = instances[0]
             if current is not None:
@@ -440,17 +468,17 @@ class PropertyLoader(StrategizedProperty):
             return
         passive = type_ != 'delete' or self.passive_deletes
         mapper = self.mapper.primary_mapper()
-        instances = attributes.get_as_list(state, self.key, passive=passive)
+        instances = state.value_as_iterable(self.key, passive=passive)
         if instances:
             for c in instances:
                 if c is not None and c not in visited_instances and (halt_on is None or not halt_on(c)):
                     if not isinstance(c, self.mapper.class_):
-                        raise exceptions.AssertionError("Attribute '%s' on class '%s' doesn't handle objects of type '%s'" % (self.key, str(self.parent.class_), str(c.__class__)))
+                        raise AssertionError("Attribute '%s' on class '%s' doesn't handle objects of type '%s'" % (self.key, str(self.parent.class_), str(c.__class__)))
                     visited_instances.add(c)
 
                     # cascade using the mapper local to this object, so that its individual properties are located
                     instance_mapper = object_mapper(c, entity_name=mapper.entity_name)
-                    yield (c, instance_mapper, c._state)
+                    yield (c, instance_mapper, attributes.instance_state(c))
 
     def _get_target_class(self):
         """Return the target class of the relation, even if the
@@ -479,7 +507,8 @@ class PropertyLoader(StrategizedProperty):
             # accept a callable to suit various deferred-configurational schemes
             self.mapper = mapper.class_mapper(self.argument(), entity_name=self.entity_name, compile=False)
         else:
-            raise exceptions.ArgumentError("relation '%s' expects a class or a mapper argument (received: %s)" % (self.key, type(self.argument)))
+            raise sa_exc.ArgumentError("relation '%s' expects a class or a mapper argument (received: %s)" % (self.key, type(self.argument)))
+        assert isinstance(self.mapper, mapper.Mapper), self.mapper
 
         if not self.parent.concrete:
             for inheriting in self.parent.iterate_to_root():
@@ -495,14 +524,14 @@ class PropertyLoader(StrategizedProperty):
 
         if self.cascade.delete_orphan:
             if self.parent.class_ is self.mapper.class_:
-                raise exceptions.ArgumentError("In relationship '%s', can't establish 'delete-orphan' cascade "
+                raise sa_exc.ArgumentError("In relationship '%s', can't establish 'delete-orphan' cascade "
                             "rule on a self-referential relationship.  "
                             "You probably want cascade='all', which includes delete cascading but not orphan detection." %(str(self)))
             self.mapper.primary_mapper().delete_orphans.append((self.key, self.parent.class_))
 
     def __determine_joins(self):
         if self.secondaryjoin is not None and self.secondary is None:
-            raise exceptions.ArgumentError("Property '" + self.key + "' specified with secondary join condition but no secondary argument")
+            raise sa_exc.ArgumentError("Property '" + self.key + "' specified with secondary join condition but no secondary argument")
         # if join conditions were not specified, figure them out based on foreign keys
 
         def _search_for_join(mapper, table):
@@ -512,7 +541,7 @@ class PropertyLoader(StrategizedProperty):
             is a join."""
             try:
                 return sql.join(mapper.local_table, table)
-            except exceptions.ArgumentError, e:
+            except sa_exc.ArgumentError, e:
                 return sql.join(mapper.mapped_table, table)
 
         try:
@@ -524,8 +553,8 @@ class PropertyLoader(StrategizedProperty):
             else:
                 if self.primaryjoin is None:
                     self.primaryjoin = _search_for_join(self.parent, self.target).onclause
-        except exceptions.ArgumentError, e:
-            raise exceptions.ArgumentError("Could not determine join condition between parent/child tables on relation %s.  "
+        except sa_exc.ArgumentError, e:
+            raise sa_exc.ArgumentError("Could not determine join condition between parent/child tables on relation %s.  "
                         "Specify a 'primaryjoin' expression.  If this is a many-to-many relation, 'secondaryjoin' is needed as well." % (self))
 
 
@@ -540,14 +569,11 @@ class PropertyLoader(StrategizedProperty):
         
     def __determine_fks(self):
 
-        if self._legacy_foreignkey and not self._refers_to_parent_table():
-            self.foreign_keys = self._legacy_foreignkey
-
         arg_foreign_keys = self.foreign_keys
 
         if self._arg_local_remote_pairs:
             if not arg_foreign_keys:
-                raise exceptions.ArgumentError("foreign_keys argument is required with _local_remote_pairs argument")
+                raise sa_exc.ArgumentError("foreign_keys argument is required with _local_remote_pairs argument")
             self.foreign_keys = util.OrderedSet(arg_foreign_keys)
             self._opposite_side = util.OrderedSet()
             for l, r in self._arg_local_remote_pairs:
@@ -562,15 +588,15 @@ class PropertyLoader(StrategizedProperty):
 
             if not eq_pairs:
                 if not self.viewonly and criterion_as_pairs(self.primaryjoin, consider_as_foreign_keys=arg_foreign_keys, any_operator=True):
-                    raise exceptions.ArgumentError("Could not locate any equated, locally mapped column pairs for primaryjoin condition '%s' on relation %s. "
+                    raise sa_exc.ArgumentError("Could not locate any equated, locally mapped column pairs for primaryjoin condition '%s' on relation %s. "
                         "For more relaxed rules on join conditions, the relation may be marked as viewonly=True." % (self.primaryjoin, self)
                     )
                 else:
                     if arg_foreign_keys:
-                        raise exceptions.ArgumentError("Could not determine relation direction for primaryjoin condition '%s', on relation %s. "
+                        raise sa_exc.ArgumentError("Could not determine relation direction for primaryjoin condition '%s', on relation %s. "
                             "Specify _local_remote_pairs=[(local, remote), (local, remote), ...] to explicitly establish the local/remote column pairs." % (self.primaryjoin, self))
                     else:
-                        raise exceptions.ArgumentError("Could not determine relation direction for primaryjoin condition '%s', on relation %s. "
+                        raise sa_exc.ArgumentError("Could not determine relation direction for primaryjoin condition '%s', on relation %s. "
                             "Specify the foreign_keys argument to indicate which columns on the relation are foreign." % (self.primaryjoin, self))
         
             self.foreign_keys = util.OrderedSet([r for l, r in eq_pairs])
@@ -583,11 +609,11 @@ class PropertyLoader(StrategizedProperty):
             
             if not sq_pairs:
                 if not self.viewonly and criterion_as_pairs(self.secondaryjoin, consider_as_foreign_keys=arg_foreign_keys, any_operator=True):
-                    raise exceptions.ArgumentError("Could not locate any equated, locally mapped column pairs for secondaryjoin condition '%s' on relation %s. "
+                    raise sa_exc.ArgumentError("Could not locate any equated, locally mapped column pairs for secondaryjoin condition '%s' on relation %s. "
                         "For more relaxed rules on join conditions, the relation may be marked as viewonly=True." % (self.secondaryjoin, self)
                     )
                 else:
-                    raise exceptions.ArgumentError("Could not determine relation direction for secondaryjoin condition '%s', on relation %s. "
+                    raise sa_exc.ArgumentError("Could not determine relation direction for secondaryjoin condition '%s', on relation %s. "
                     "Specify the foreign_keys argument to indicate which columns on the relation are foreign." % (self.secondaryjoin, self))
 
             self.foreign_keys.update([r for l, r in sq_pairs])
@@ -599,7 +625,7 @@ class PropertyLoader(StrategizedProperty):
     def __determine_remote_side(self):
         if self._arg_local_remote_pairs:
             if self.remote_side:
-                raise exceptions.ArgumentError("remote_side argument is redundant against more detailed _local_remote_side argument.")
+                raise sa_exc.ArgumentError("remote_side argument is redundant against more detailed _local_remote_side argument.")
             if self.direction is MANYTOONE:
                 eq_pairs = [(r, l) for l, r in self._arg_local_remote_pairs]
             else:
@@ -629,11 +655,11 @@ class PropertyLoader(StrategizedProperty):
         if self.direction is ONETOMANY:
             for l in self.local_side:
                 if not self.__col_is_part_of_mappings(l):
-                    raise exceptions.ArgumentError("Local column '%s' is not part of mapping %s.  Specify remote_side argument to indicate which column lazy join condition should compare against." % (l, self.parent))
+                    raise sa_exc.ArgumentError("Local column '%s' is not part of mapping %s.  Specify remote_side argument to indicate which column lazy join condition should compare against." % (l, self.parent))
         elif self.direction is MANYTOONE:
             for r in self.remote_side:
                 if not self.__col_is_part_of_mappings(r):
-                    raise exceptions.ArgumentError("Remote column '%s' is not part of mapping %s.  Specify remote_side argument to indicate which column lazy join condition should bind." % (r, self.mapper))
+                    raise sa_exc.ArgumentError("Remote column '%s' is not part of mapping %s.  Specify remote_side argument to indicate which column lazy join condition should bind." % (r, self.mapper))
             
     def __determine_direction(self):
         """Determine our *direction*, i.e. do we represent one to
@@ -646,13 +672,7 @@ class PropertyLoader(StrategizedProperty):
             # for a self referential mapper, if the "foreignkey" is a single or composite primary key,
             # then we are "many to one", since the remote site of the relationship identifies a singular entity.
             # otherwise we are "one to many".
-            if self._legacy_foreignkey:
-                for f in self._legacy_foreignkey:
-                    if not f.primary_key:
-                        self.direction = ONETOMANY
-                    else:
-                        self.direction = MANYTOONE
-            elif self._arg_local_remote_pairs:
+            if self._arg_local_remote_pairs:
                 remote = util.Set([r for l, r in self._arg_local_remote_pairs])
                 if self.foreign_keys.intersection(remote):
                     self.direction = ONETOMANY
@@ -671,7 +691,7 @@ class PropertyLoader(StrategizedProperty):
                 manytoone = [c for c in self.foreign_keys if parenttable.c.contains_column(c)]
 
                 if not onetomany and not manytoone:
-                    raise exceptions.ArgumentError(
+                    raise sa_exc.ArgumentError(
                         "Can't determine relation direction for relationship '%s' "
                         "- foreign key columns are present in neither the "
                         "parent nor the child's mapped tables" %(str(self)))
@@ -684,14 +704,14 @@ class PropertyLoader(StrategizedProperty):
                     self.direction = MANYTOONE
                     break
             else:
-                raise exceptions.ArgumentError(
+                raise sa_exc.ArgumentError(
                     "Can't determine relation direction for relationship '%s' "
                     "- foreign key columns are present in both the parent and "
                     "the child's mapped tables.  Specify 'foreign_keys' "
                     "argument." % (str(self)))
 
     def _post_init(self):
-        if logging.is_info_enabled(self.logger):
+        if log.is_info_enabled(self.logger):
             self.logger.info(str(self) + " setup primary join %s" % self.primaryjoin)
             self.logger.info(str(self) + " setup secondary join %s" % self.secondaryjoin)
             self.logger.info(str(self) + " synchronize pairs [%s]" % ",".join(["(%s => %s)" % (l, r) for l, r in self.synchronize_pairs]))
@@ -710,15 +730,10 @@ class PropertyLoader(StrategizedProperty):
 
         # primary property handler, set up class attributes
         if self.is_primary():
-            # if a backref name is defined, set up an extension to populate
-            # attributes in the other direction
-            if self.backref is not None:
-                self.attributeext = self.backref.get_extension()
-
             if self.backref is not None:
                 self.backref.compile(self)
         elif not mapper.class_mapper(self.parent.class_, compile=False)._get_property(self.key, raiseerr=False):
-            raise exceptions.ArgumentError("Attempting to assign a new relation '%s' to a non-primary mapper on class '%s'.  New relations can only be added to the primary mapper, i.e. the very first mapper created for class '%s' " % (self.key, self.parent.class_.__name__, self.parent.class_.__name__))
+            raise sa_exc.ArgumentError("Attempting to assign a new relation '%s' to a non-primary mapper on class '%s'.  New relations can only be added to the primary mapper, i.e. the very first mapper created for class '%s' " % (self.key, self.parent.class_.__name__, self.parent.class_.__name__))
 
         super(PropertyLoader, self).do_init()
 
@@ -729,50 +744,69 @@ class PropertyLoader(StrategizedProperty):
         return self.mapper.common_parent(self.parent)
     
     def _create_joins(self, source_polymorphic=False, source_selectable=None, dest_polymorphic=False, dest_selectable=None):
+        key = util.WeakCompositeKey(source_polymorphic, source_selectable, dest_polymorphic, dest_selectable)
+        try:
+            return self.__join_cache[key]
+        except KeyError:
+            pass
+
         if source_selectable is None:
             if source_polymorphic and self.parent.with_polymorphic:
-                source_selectable = self.parent._with_polymorphic_selectable()
-            else:
-                source_selectable = None
+                source_selectable = self.parent._with_polymorphic_selectable
+
+        aliased = False
         if dest_selectable is None:
             if dest_polymorphic and self.mapper.with_polymorphic:
-                dest_selectable = self.mapper._with_polymorphic_selectable()
+                dest_selectable = self.mapper._with_polymorphic_selectable
+                aliased = True
             else:
                 dest_selectable = self.mapper.mapped_table
-            if self._is_self_referential():
+                
+            if self._is_self_referential() and source_selectable is None:
+                dest_selectable = dest_selectable.alias()
+                aliased = True
+        else:
+            aliased = True
+            
+        aliased = aliased or bool(source_selectable)
+
+        primaryjoin, secondaryjoin, secondary = self.primaryjoin, self.secondaryjoin, self.secondary
+        if aliased:
+            if secondary:
+                secondary = secondary.alias()
+                primary_aliasizer = ClauseAdapter(secondary)
                 if dest_selectable:
-                    dest_selectable = dest_selectable.alias()
+                    secondary_aliasizer = ClauseAdapter(dest_selectable, equivalents=self.mapper._equivalent_columns).chain(primary_aliasizer)
                 else:
-                    dest_selectable = self.mapper.mapped_table.alias()
-                
-        primaryjoin = self.primaryjoin
-        if source_selectable:
-            if self.direction in (ONETOMANY, MANYTOMANY):
-                primaryjoin = ClauseAdapter(source_selectable, exclude=self.foreign_keys, equivalents=self.parent._equivalent_columns).traverse(primaryjoin)
+                    secondary_aliasizer = primary_aliasizer
+
+                if source_selectable:
+                    primary_aliasizer = ClauseAdapter(secondary).chain(ClauseAdapter(source_selectable, equivalents=self.parent._equivalent_columns))
+
+                secondaryjoin = secondary_aliasizer.traverse(secondaryjoin)
             else:
-                primaryjoin = ClauseAdapter(source_selectable, include=self.foreign_keys, equivalents=self.parent._equivalent_columns).traverse(primaryjoin)
+                if dest_selectable:
+                    primary_aliasizer = ClauseAdapter(dest_selectable, exclude=self.local_side, equivalents=self.mapper._equivalent_columns)
+                    if source_selectable: 
+                        primary_aliasizer.chain(ClauseAdapter(source_selectable, exclude=self.remote_side, equivalents=self.parent._equivalent_columns))
+                elif source_selectable:
+                    primary_aliasizer = ClauseAdapter(source_selectable, exclude=self.remote_side, equivalents=self.parent._equivalent_columns)
+
+                secondary_aliasizer = None
         
-        secondaryjoin = self.secondaryjoin
-        target_adapter = None
-        if dest_selectable:
-            if self.direction == ONETOMANY:
-                target_adapter = ClauseAdapter(dest_selectable, include=self.foreign_keys, equivalents=self.mapper._equivalent_columns)
-            elif self.direction == MANYTOMANY:
-                target_adapter = ClauseAdapter(dest_selectable, equivalents=self.mapper._equivalent_columns)
-            else:
-                target_adapter = ClauseAdapter(dest_selectable, exclude=self.foreign_keys, equivalents=self.mapper._equivalent_columns)
-            if secondaryjoin:
-                secondaryjoin = target_adapter.traverse(secondaryjoin)
-            else:
-                primaryjoin = target_adapter.traverse(primaryjoin)
+            primaryjoin = primary_aliasizer.traverse(primaryjoin)
+            target_adapter = secondary_aliasizer or primary_aliasizer
             target_adapter.include = target_adapter.exclude = None
-            
-        return primaryjoin, secondaryjoin, source_selectable or self.parent.local_table, dest_selectable or self.mapper.local_table, target_adapter
+        else:
+            target_adapter = None
+
+        self.__join_cache[key] = ret = (primaryjoin, secondaryjoin, (source_selectable or self.parent.local_table), (dest_selectable or self.mapper.local_table), secondary, target_adapter)
+        return ret
         
     def _get_join(self, parent, primary=True, secondary=True, polymorphic_parent=True):
         """deprecated.  use primary_join_against(), secondary_join_against(), full_join_against()"""
         
-        pj, sj, source, dest, adapter = self._create_joins(source_polymorphic=polymorphic_parent)
+        pj, sj, source, dest, secondarytable, adapter = self._create_joins(source_polymorphic=polymorphic_parent)
         
         if primary and secondary:
             return pj & sj
@@ -788,7 +822,7 @@ class PropertyLoader(StrategizedProperty):
         if not self.viewonly:
             self._dependency_processor.register_dependencies(uowcommit)
 
-PropertyLoader.logger = logging.class_logger(PropertyLoader)
+PropertyLoader.logger = log.class_logger(PropertyLoader)
 
 class BackRef(object):
     """Attached to a PropertyLoader to indicate a complementary reverse relationship.
@@ -799,7 +833,8 @@ class BackRef(object):
         self.key = key
         self.kwargs = kwargs
         self.prop = _prop
-
+        self.extension = attributes.GenericBackrefExtension(self.key)
+        
     def compile(self, prop):
         if self.prop:
             return
@@ -817,7 +852,7 @@ class BackRef(object):
 
             relation = PropertyLoader(parent, prop.secondary, pj, sj,
                                       backref=BackRef(prop.key, _prop=prop),
-                                      is_backref=True,
+                                      _is_backref=True,
                                       **self.kwargs)
 
             mapper._compile_property(self.key, relation);
@@ -826,12 +861,7 @@ class BackRef(object):
             mapper._get_property(self.key)._reverse_property = prop
 
         else:
-            raise exceptions.ArgumentError("Error creating backref '%s' on relation '%s': property of that name exists on mapper '%s'" % (self.key, prop, mapper))
-
-    def get_extension(self):
-        """Return an attribute extension to use with this backreference."""
-
-        return attributes.GenericBackrefExtension(self.key)
+            raise sa_exc.ArgumentError("Error creating backref '%s' on relation '%s': property of that name exists on mapper '%s'" % (self.key, prop, mapper))
 
 mapper.ColumnProperty = ColumnProperty
 mapper.SynonymProperty = SynonymProperty
index 8996a758e6c2954490fa7d76e10ff108762fc130..dfa24efee6f3ab69ceb044434781d7dd4d33e788 100644 (file)
@@ -15,35 +15,54 @@ operations at the SQL (non-ORM) level.  ``Query`` differs from ``Select`` in
 that it returns ORM-mapped objects and interacts with an ORM session, whereas
 the ``Select`` construct interacts directly with the database to return
 iterable result sets.
+
 """
 
 from itertools import chain
-from sqlalchemy import sql, util, exceptions, logging
+
+from sqlalchemy import sql, util, log
+from sqlalchemy import exc as sa_exc
+from sqlalchemy.orm import exc as orm_exc
 from sqlalchemy.sql import util as sql_util
 from sqlalchemy.sql import expression, visitors, operators
-from sqlalchemy.orm import mapper, object_mapper
+from sqlalchemy.orm import attributes, interfaces, mapper, object_mapper
+from sqlalchemy.orm.util import _state_mapper, _is_mapped_class, \
+     _is_aliased_class, _entity_descriptor, _entity_info, _class_to_mapper, \
+     _orm_columns, AliasedClass, _orm_selectable, join as orm_join, ORMAdapter
 
-from sqlalchemy.orm.util import _state_mapper, _class_to_mapper, _is_mapped_class, _is_aliased_class
-from sqlalchemy.orm import util as mapperutil
-from sqlalchemy.orm import interfaces
-from sqlalchemy.orm import attributes
-from sqlalchemy.orm.util import AliasedClass
+__all__ = ['Query', 'QueryContext', 'aliased']
 
-aliased = AliasedClass
 
-__all__ = ['Query', 'QueryContext', 'aliased']
+aliased = AliasedClass
 
+def _generative(*assertions):
+    """mark a method as generative."""
+
+    def decorate(fn):
+        argspec = util.format_argspec_plus(fn)
+        run_assertions = assertions
+        code = "\n".join([
+            "def %s%s:",
+            "    %r",
+            "    self = self._clone()",
+            "    for a in run_assertions:",
+            "        a(self, %r)",
+            "    fn%s",
+            "    return self"
+        ]) % (fn.__name__, argspec['args'], fn.__doc__, fn.__name__, argspec['apply_pos'])
+        env = locals().copy()
+        exec code in env
+        return env[fn.__name__]
+    return decorate
 
 class Query(object):
     """Encapsulates the object-fetching operations provided by Mappers."""
 
-    def __init__(self, class_or_mapper, session=None, entity_name=None):
-        self._session = session
-        
+    def __init__(self, entities, session=None, entity_name=None):
+        self.session = session
+
         self._with_options = []
         self._lockmode = None
-        
-        self._entities = []
         self._order_by = False
         self._group_by = False
         self._distinct = False
@@ -53,51 +72,53 @@ class Query(object):
         self._params = {}
         self._yield_per = None
         self._criterion = None
+        self._correlate = util.Set()
+        self._joinpoint = None
+        self._with_labels = False
         self.__joinable_tables = None
         self._having = None
-        self._column_aggregate = None
         self._populate_existing = False
         self._version_check = False
         self._autoflush = True
-        
         self._attributes = {}
         self._current_path = ()
         self._only_load_props = None
         self._refresh_instance = None
-        
-        self.__init_mapper(_class_to_mapper(class_or_mapper, entity_name=entity_name))
-
-    def __init_mapper(self, mapper):
-        """populate all instance variables derived from this Query's mapper."""
-        
-        self.mapper = mapper
-        self.table = self._from_obj = self.mapper.mapped_table
-        self._eager_loaders = util.Set(chain(*[mp._eager_loaders for mp in [m for m in self.mapper.iterate_to_root()]]))
-        self._extension = self.mapper.extension
-        self._aliases_head = self._aliases_tail = None
-        self._alias_ids = {}
-        self._joinpoint = self.mapper
-        self._entities.append(_PrimaryMapperEntity(self.mapper))
-        if self.mapper.with_polymorphic:
-            self.__set_with_polymorphic(*self.mapper.with_polymorphic)
-        else:
-            self._with_polymorphic = []
-
-    def __generate_alias_ids(self):
-        self._alias_ids = dict([
-            (k, list(v)) for k, v in self._alias_ids.iteritems()
-        ])
+        self._from_obj = None
+        self._entities = []
+        self._polymorphic_adapters = {}
+        self._filter_aliases = None
+        self._from_obj_alias = None
+        self.__currenttables = util.Set()
+
+        for ent in util.to_list(entities):
+            _QueryEntity(self, ent, entity_name=entity_name)
+
+        self.__setup_aliasizers(self._entities)
+
+    def __setup_aliasizers(self, entities):
+        d = {}
+        for ent in entities:
+            for entity in ent.entities:
+                if entity not in d:
+                    mapper, selectable, is_aliased_class = _entity_info(entity, ent.entity_name)
+                    if not is_aliased_class and mapper.with_polymorphic:
+                        with_polymorphic = mapper._with_polymorphic_mappers
+                        self.__mapper_loads_polymorphically_with(mapper, sql_util.ColumnAdapter(selectable, mapper._equivalent_columns))
+                        adapter = None
+                    elif is_aliased_class:
+                        adapter = sql_util.ColumnAdapter(selectable, mapper._equivalent_columns)
+                        with_polymorphic = None
+                    else:
+                        with_polymorphic = adapter = None
 
-    def __no_criterion(self, meth):
-        return self.__conditional_clone(meth, [self.__no_criterion_condition])
+                    d[entity] = (mapper, adapter, selectable, is_aliased_class, with_polymorphic)
+                ent.setup_entity(entity, *d[entity])
 
-    def __no_statement(self, meth):
-        return self.__conditional_clone(meth, [self.__no_statement_condition])
-
-    def __reset_all(self, mapper, meth):
-        q = self.__conditional_clone(meth, [self.__no_criterion_condition])
-        q.__init_mapper(mapper, mapper)
-        return q
+    def __mapper_loads_polymorphically_with(self, mapper, adapter):
+        for m2 in mapper._with_polymorphic_mappers:
+            for m in m2.iterate_to_root():
+                self._polymorphic_adapters[m.mapped_table] = self._polymorphic_adapters[m.local_table] = adapter
 
     def __set_select_from(self, from_obj):
         if isinstance(from_obj, expression._SelectBaseMixin):
@@ -105,54 +126,168 @@ class Query(object):
             from_obj = from_obj.alias()
 
         self._from_obj = from_obj
-        self._alias_ids = {}
+        equivs = self.__all_equivs()
+
+        if isinstance(from_obj, expression.Alias):
+            # dont alias a regular join (since its not an alias itself)
+            self._from_obj_alias = sql_util.ColumnAdapter(self._from_obj, equivs)
+
+    def _get_polymorphic_adapter(self, entity, selectable):
+        self.__mapper_loads_polymorphically_with(entity.mapper, sql_util.ColumnAdapter(selectable, entity.mapper._equivalent_columns))
+
+    def _reset_polymorphic_adapter(self, mapper):
+        for m2 in mapper._with_polymorphic_mappers:
+            for m in m2.iterate_to_root():
+                self._polymorphic_adapters.pop(m.mapped_table, None)
+                self._polymorphic_adapters.pop(m.local_table, None)
+
+    def __reset_joinpoint(self):
+        self._joinpoint = None
+        self._filter_aliases = None
+
+    def __adapt_polymorphic_element(self, element):
+        if isinstance(element, expression.FromClause):
+            search = element
+        elif hasattr(element, 'table'):
+            search = element.table
+        else:
+            search = None
+
+        if search:
+            alias = self._polymorphic_adapters.get(search, None)
+            if alias:
+                return alias.adapt_clause(element)
+    
+    def __replace_element(self, adapters):
+        def replace(elem):
+            if '_halt_adapt' in elem._annotations:
+                return elem
+
+            for adapter in adapters:
+                e = adapter(elem)
+                if e:
+                    return e
+        return replace
+    
+    def __replace_orm_element(self, adapters):
+        def replace(elem):
+            if '_halt_adapt' in elem._annotations:
+                return elem
+
+            if "_orm_adapt" in elem._annotations or "parententity" in elem._annotations:
+                for adapter in adapters:
+                    e = adapter(elem)
+                    if e:
+                        return e
+        return replace
+
+    def _adapt_all_clauses(self):
+        self._disable_orm_filtering = True
+    _adapt_all_clauses = _generative()(_adapt_all_clauses)
+    
+    def _adapt_clause(self, clause, as_filter, orm_only):
+        adapters = []    
+        if as_filter and self._filter_aliases:
+            adapters.append(self._filter_aliases.replace)
+
+        if self._polymorphic_adapters:
+            adapters.append(self.__adapt_polymorphic_element)
+
+        if self._from_obj_alias:
+            adapters.append(self._from_obj_alias.replace)
+
+        if not adapters:
+            return clause
+            
+        if getattr(self, '_disable_orm_filtering', not orm_only):
+            return visitors.replacement_traverse(clause, {'column_collections':False}, self.__replace_element(adapters))
+        else:
+            return visitors.replacement_traverse(clause, {'column_collections':False}, self.__replace_orm_element(adapters))
         
-        if self.table not in self._get_joinable_tables():
-            self._aliases_head = self._aliases_tail = mapperutil.AliasedClauses(self._from_obj, equivalents=self.mapper._equivalent_columns)
-            self._alias_ids.setdefault(self.table, []).append(self._aliases_head)
+    def _entity_zero(self):
+        return self._entities[0]
+
+    def _mapper_zero(self):
+        return self._entity_zero().entity_zero
+
+    def _extension_zero(self):
+        ent = self._entity_zero()
+        return getattr(ent, 'extension', ent.mapper.extension)
+
+    def _mapper_entities(self):
+        for ent in self._entities:
+            if hasattr(ent, 'primary_entity'):
+                yield ent
+    _mapper_entities = property(_mapper_entities)
+
+    def _joinpoint_zero(self):
+        return self._joinpoint or self._entity_zero().entity_zero
+
+    def _mapper_zero_or_none(self):
+        if not getattr(self._entities[0], 'primary_entity', False):
+            return None
+        return self._entities[0].mapper
+
+    def _only_mapper_zero(self):
+        if len(self._entities) > 1:
+            raise sa_exc.InvalidRequestError("This operation requires a Query against a single mapper.")
+        return self._mapper_zero()
+
+    def _only_entity_zero(self):
+        if len(self._entities) > 1:
+            raise sa_exc.InvalidRequestError("This operation requires a Query against a single mapper.")
+        return self._entity_zero()
+
+    def _generate_mapper_zero(self):
+        if not getattr(self._entities[0], 'primary_entity', False):
+            raise sa_exc.InvalidRequestError("No primary mapper set up for this Query.")
+        entity = self._entities[0]._clone()
+        self._entities = [entity] + self._entities[1:]
+        return entity
+
+    def __mapper_zero_from_obj(self):
+        if self._from_obj:
+            return self._from_obj
         else:
-            self._aliases_head = self._aliases_tail = None
+            return self._entity_zero().selectable
 
-    def __set_with_polymorphic(self, cls_or_mappers, selectable=None):
-        mappers, from_obj = self.mapper._with_polymorphic_args(cls_or_mappers, selectable)
-        self._with_polymorphic = mappers
-        self.__set_select_from(from_obj)
+    def __all_equivs(self):
+        equivs = {}
+        for ent in self._mapper_entities:
+            equivs.update(ent.mapper._equivalent_columns)
+        return equivs
 
-    def __no_criterion_condition(self, q, meth):
-        if q._criterion or q._statement:
+    def __no_criterion_condition(self, meth):
+        if self._criterion or self._statement or self._from_obj:
             util.warn(
                 ("Query.%s() being called on a Query with existing criterion; "
-                 "criterion is being ignored.") % meth)
-
-        q._joinpoint = self.mapper
-        q._statement = q._criterion = None
-        q._order_by = q._group_by = q._distinct = False
-        q._aliases_tail = q._aliases_head
-        q.table = q._from_obj = q.mapper.mapped_table
-        if q.mapper.with_polymorphic:
-            q.__set_with_polymorphic(*q.mapper.with_polymorphic)
-
-    def __no_entities(self, meth):
-        q = self.__no_statement(meth)
-        if len(q._entities) > 1 and not isinstance(q._entities[0], _PrimaryMapperEntity):
-            raise exceptions.InvalidRequestError(
-                ("Query.%s() being called on a Query with existing  "
-                 "additional entities or columns - can't replace columns") % meth)
-        q._entities = []
-        return q
+                 "criterion is being ignored.  This usage is deprecated.") % meth)
 
-    def __no_statement_condition(self, q, meth):
-        if q._statement:
-            raise exceptions.InvalidRequestError(
+        self._statement = self._criterion = self._from_obj = None
+        self._order_by = self._group_by = self._distinct = False
+        self.__joined_tables = {}
+
+    def __no_from_condition(self, meth):
+        if self._from_obj:
+            raise sa_exc.InvalidRequestError("Query.%s() being called on a Query which already has a FROM clause established.  This usage is deprecated." % meth)
+
+    def __no_statement_condition(self, meth):
+        if self._statement:
+            raise sa_exc.InvalidRequestError(
                 ("Query.%s() being called on a Query with an existing full "
                  "statement - can't apply criterion.") % meth)
 
-    def __conditional_clone(self, methname=None, conditions=None):
-        q = self._clone()
-        if conditions:
-            for condition in conditions:
-                condition(q, methname)
-        return q
+    def __no_limit_offset(self, meth):
+        if self._limit or self._offset:
+            util.warn("Query.%s() being called on a Query which already has LIMIT or OFFSET applied. "
+            "This usage is deprecated. Apply filtering and joins before LIMIT or OFFSET are applied, "
+            "or to filter/join to the row-limited results of the query, call from_self() first."
+            "In release 0.5, from_self() will be called automatically in this scenario."
+            )
+
+    def __no_criterion(self):
+        """generate a Query with no criterion, warn if criterion was present"""
+    __no_criterion = _generative(__no_criterion_condition)(__no_criterion)
 
     def __get_options(self, populate_existing=None, version_check=None, only_load_props=None, refresh_instance=None):
         if populate_existing:
@@ -170,18 +305,27 @@ class Query(object):
         q.__dict__ = self.__dict__.copy()
         return q
 
-    def session(self):
-        if self._session is None:
-            return self.mapper.get_session()
-        else:
-            return self._session
-    session = property(session)
-
     def statement(self):
         """return the full SELECT statement represented by this Query."""
-        return self._compile_context().statement
+        return self._compile_context(labels=self._with_labels).statement
     statement = property(statement)
 
+    def with_labels(self):
+        """Apply column labels to the return value of Query.statement.
+        
+        Indicates that this Query's `statement` accessor should return a SELECT statement
+        that applies labels to all columns in the form <tablename>_<columnname>; this
+        is commonly used to disambiguate columns from multiple tables which have the
+        same name.
+        
+        When the `Query` actually issues SQL to load rows, it always uses 
+        column labeling.
+        
+        """
+        self._with_labels = True
+    with_labels = _generative()(with_labels)
+    
+    
     def whereclause(self):
         """return the WHERE criterion for this Query."""
         return self._criterion
@@ -189,48 +333,44 @@ class Query(object):
 
     def _with_current_path(self, path):
         """indicate that this query applies to objects loaded within a certain path.
-        
-        Used by deferred loaders (see strategies.py) which transfer query 
+
+        Used by deferred loaders (see strategies.py) which transfer query
         options from an originating query to a newly generated query intended
         for the deferred load.
-        
+
         """
-        q = self._clone()
-        q._current_path = path
-        return q
+        self._current_path = path
+    _with_current_path = _generative()(_with_current_path)
 
     def with_polymorphic(self, cls_or_mappers, selectable=None):
         """Load columns for descendant mappers of this Query's mapper.
-        
+
         Using this method will ensure that each descendant mapper's
-        tables are included in the FROM clause, and will allow filter() 
-        criterion to be used against those tables.  The resulting 
+        tables are included in the FROM clause, and will allow filter()
+        criterion to be used against those tables.  The resulting
         instances will also have those columns already loaded so that
         no "post fetch" of those columns will be required.
-        
+
         ``cls_or_mappers`` is a single class or mapper, or list of class/mappers,
         which inherit from this Query's mapper.  Alternatively, it
-        may also be the string ``'*'``, in which case all descending 
+        may also be the string ``'*'``, in which case all descending
         mappers will be added to the FROM clause.
-        
-        ``selectable`` is a table or select() statement that will 
+
+        ``selectable`` is a table or select() statement that will
         be used in place of the generated FROM clause.  This argument
-        is required if any of the desired mappers use concrete table 
-        inheritance, since SQLAlchemy currently cannot generate UNIONs 
-        among tables automatically.  If used, the ``selectable`` 
-        argument must represent the full set of tables and columns mapped 
+        is required if any of the desired mappers use concrete table
+        inheritance, since SQLAlchemy currently cannot generate UNIONs
+        among tables automatically.  If used, the ``selectable``
+        argument must represent the full set of tables and columns mapped
         by every desired mapper.  Otherwise, the unaccounted mapped columns
-        will result in their table being appended directly to the FROM 
+        will result in their table being appended directly to the FROM
         clause which will usually lead to incorrect results.
 
         """
-        q = self.__no_criterion('with_polymorphic')
-
-        q.__set_with_polymorphic(cls_or_mappers, selectable=selectable)
+        entity = self._generate_mapper_zero()
+        entity.set_with_polymorphic(self, cls_or_mappers, selectable=selectable)
+    with_polymorphic = _generative(__no_from_condition, __no_criterion_condition)(with_polymorphic)
 
-        return q
-    
-        
     def yield_per(self, count):
         """Yield only ``count`` rows at a time.
 
@@ -242,30 +382,28 @@ class Query(object):
         eagerly loaded collections (i.e. any lazy=False) since those
         collections will be cleared for a new load when encountered in a
         subsequent result batch.
-        """
 
-        q = self._clone()
-        q._yield_per = count
-        return q
+        """
+        self._yield_per = count
+    yield_per = _generative()(yield_per)
 
     def get(self, ident, **kwargs):
         """Return an instance of the object based on the given identifier, or None if not found.
 
         The `ident` argument is a scalar or tuple of primary key column values
         in the order of the table def's primary key columns.
+
         """
 
-        ret = self._extension.get(self, ident, **kwargs)
+        ret = self._extension_zero().get(self, ident, **kwargs)
         if ret is not mapper.EXT_CONTINUE:
             return ret
 
         # convert composite types to individual args
-        # TODO: account for the order of columns in the
-        # ColumnProperty it corresponds to
         if hasattr(ident, '__composite_values__'):
             ident = ident.__composite_values__()
 
-        key = self.mapper.identity_key_from_primary_key(ident)
+        key = self._only_mapper_zero().identity_key_from_primary_key(ident)
         return self._get(key, ident, **kwargs)
 
     def load(self, ident, raiseerr=True, **kwargs):
@@ -275,15 +413,20 @@ class Query(object):
         pending changes** to the object already existing in the Session.  The
         `ident` argument is a scalar or tuple of primary key column values in
         the order of the table def's primary key columns.
-        """
 
-        ret = self._extension.load(self, ident, **kwargs)
+        """
+        ret = self._extension_zero().load(self, ident, **kwargs)
         if ret is not mapper.EXT_CONTINUE:
             return ret
-        key = self.mapper.identity_key_from_primary_key(ident)
+
+        # convert composite types to individual args
+        if hasattr(ident, '__composite_values__'):
+            ident = ident.__composite_values__()
+
+        key = self._only_mapper_zero().identity_key_from_primary_key(ident)
         instance = self.populate_existing()._get(key, ident, **kwargs)
         if instance is None and raiseerr:
-            raise exceptions.InvalidRequestError("No instance found for identity %s" % repr(ident))
+            raise sa_exc.InvalidRequestError("No instance found for identity %s" % repr(ident))
         return instance
 
     def query_from_parent(cls, instance, property, **kwargs):
@@ -303,27 +446,33 @@ class Query(object):
          \**kwargs
            all extra keyword arguments are propagated to the constructor of
            Query.
-        """
 
+       deprecated.  use sqlalchemy.orm.with_parent in conjunction with
+       filter().
+
+        """
         mapper = object_mapper(instance)
         prop = mapper.get_property(property, resolve_synonyms=True)
         target = prop.mapper
         criterion = prop.compare(operators.eq, instance, value_is_parent=True)
         return Query(target, **kwargs).filter(criterion)
-    query_from_parent = classmethod(query_from_parent)
+    query_from_parent = classmethod(util.deprecated(None, False)(query_from_parent))
+
+    def correlate(self, *args):
+        self._correlate = self._correlate.union([_orm_selectable(s) for s in args])
+    correlate = _generative()(correlate)
 
     def autoflush(self, setting):
         """Return a Query with a specific 'autoflush' setting.
 
         Note that a Session with autoflush=False will
-        not autoflush, even if this flag is set to True at the 
+        not autoflush, even if this flag is set to True at the
         Query level.  Therefore this flag is usually used only
         to disable autoflush for a specific Query.
-        
+
         """
-        q = self._clone()
-        q._autoflush = setting
-        return q
+        self._autoflush = setting
+    autoflush = _generative()(autoflush)
 
     def populate_existing(self):
         """Return a Query that will refresh all instances loaded.
@@ -336,11 +485,10 @@ class Query(object):
 
         An alternative to populate_existing() is to expire the Session
         fully using session.expire_all().
-        
+
         """
-        q = self._clone()
-        q._populate_existing = True
-        return q
+        self._populate_existing = True
+    populate_existing = _generative()(populate_existing)
 
     def with_parent(self, instance, property=None):
         """add a join criterion corresponding to a relationship to the given parent instance.
@@ -361,140 +509,98 @@ class Query(object):
         mapper = object_mapper(instance)
         if property is None:
             for prop in mapper.iterate_properties:
-                if isinstance(prop, properties.PropertyLoader) and prop.mapper is self.mapper:
+                if isinstance(prop, properties.PropertyLoader) and prop.mapper is self._mapper_zero():
                     break
             else:
-                raise exceptions.InvalidRequestError("Could not locate a property which relates instances of class '%s' to instances of class '%s'" % (self.mapper.class_.__name__, instance.__class__.__name__))
+                raise sa_exc.InvalidRequestError("Could not locate a property which relates instances of class '%s' to instances of class '%s'" % (self._mapper_zero().class_.__name__, instance.__class__.__name__))
         else:
             prop = mapper.get_property(property, resolve_synonyms=True)
         return self.filter(prop.compare(operators.eq, instance, value_is_parent=True))
 
-    def add_entity(self, entity, alias=None, id=None):
-        """add a mapped entity to the list of result columns to be returned.
-
-        This will have the effect of all result-returning methods returning a tuple
-        of results, the first element being an instance of the primary class for this
-        Query, and subsequent elements matching columns or entities which were
-        specified via add_column or add_entity.
-
-        When adding entities to the result, its generally desirable to add
-        limiting criterion to the query which can associate the primary entity
-        of this Query along with the additional entities.  The Query selects
-        from all tables with no joining criterion by default.
+    def add_entity(self, entity, alias=None):
+        """add a mapped entity to the list of result columns to be returned."""
 
-            entity
-                a class or mapper which will be added to the results.
+        if alias:
+            entity = aliased(entity, alias)
 
-            alias
-                a sqlalchemy.sql.Alias object which will be used to select rows.  this
-                will match the usage of the given Alias in filter(), order_by(), etc. expressions
+        self._entities = list(self._entities)
+        m = _MapperEntity(self, entity)
+        self.__setup_aliasizers([m])
+    add_entity = _generative()(add_entity)
 
-            id
-                a string ID matching that given to query.join() or query.outerjoin(); rows will be
-                selected from the aliased join created via those methods.
+    def from_self(self, *entities):
+        """return a Query that selects from this Query's SELECT statement.
 
+        \*entities - optional list of entities which will replace
+        those being selected.
         """
-        q = self._clone()
-
-        if not alias and _is_aliased_class(entity):
-            alias = entity.alias
 
-        if isinstance(entity, type):
-            entity = mapper.class_mapper(entity)
+        fromclause = self.compile().correlate(None)
+        self._statement = self._criterion = None
+        self._order_by = self._group_by = self._distinct = False
+        self._limit = self._offset = None
+        self.__set_select_from(fromclause)
+        if entities:
+            self._entities = []
+            for ent in entities:
+                _QueryEntity(self, ent)
+            self.__setup_aliasizers(self._entities)
 
-        if alias is not None:
-            alias = mapperutil.AliasedClauses(alias)
+    from_self = _generative()(from_self)
+    _from_self = from_self
 
-        q._entities = q._entities + [_MapperEntity(mapper=entity, alias=alias, id=id)]
-        return q
-    
-    def _from_self(self):
-        """return a Query that selects from this Query's SELECT statement.
-        
-        The API for this method hasn't been decided yet and is subject to change.
-
-        """
-        q = self._clone()
-        q._eager_loaders = util.Set()
-        fromclause = q.compile().correlate(None)
-        return Query(self.mapper, self.session).select_from(fromclause)
-        
     def values(self, *columns):
         """Return an iterator yielding result tuples corresponding to the given list of columns"""
-        
-        q = self.__no_entities('_values')
-        q._only_load_props = q._eager_loaders = util.Set()
-        q._no_filters = True
+
+        if not columns:
+            return iter(())
+        q = self._clone()
+        q._entities = []
         for column in columns:
-            q._entities.append(self._add_column(column, None, False))
+            _ColumnEntity(q, column)
+        q.__setup_aliasizers(q._entities)
         if not q._yield_per:
-            q = q.yield_per(10)
+            q._yield_per = 10
         return iter(q)
     _values = values
-    
-    def add_column(self, column, id=None):
-        """Add a SQL ColumnElement to the list of result columns to be returned.
 
-        This will have the effect of all result-returning methods returning a
-        tuple of results, the first element being an instance of the primary
-        class for this Query, and subsequent elements matching columns or
-        entities which were specified via add_column or add_entity.
+    def add_column(self, column):
+        """Add a SQL ColumnElement to the list of result columns to be returned."""
 
-        When adding columns to the result, its generally desirable to add
-        limiting criterion to the query which can associate the primary entity
-        of this Query along with the additional columns, if the column is
-        based on a table or selectable that is not the primary mapped
-        selectable.  The Query selects from all tables with no joining
-        criterion by default.
+        self._entities = list(self._entities)
+        c = _ColumnEntity(self, column)
+        self.__setup_aliasizers([c])
+    add_column = _generative()(add_column)
 
-        column
-          a string column name or sql.ColumnElement to be added to the results.
-
-        """
-        q = self._clone()
-        q._entities = q._entities + [self._add_column(column, id, True)]
-        return q
-    
-    def _add_column(self, column, id, looks_for_aliases):
-        if isinstance(column, interfaces.PropComparator):
-            column = column.clause_element()
-
-        elif not isinstance(column, (sql.ColumnElement, basestring)):
-            raise exceptions.InvalidRequestError("Invalid column expression '%r'" % column)
-
-        return _ColumnEntity(column, id)
-        
     def options(self, *args):
         """Return a new Query object, applying the given list of
         MapperOptions.
 
         """
-        return self._options(False, *args)
+        return self.__options(False, *args)
 
     def _conditional_options(self, *args):
-        return self._options(True, *args)
+        return self.__options(True, *args)
 
-    def _options(self, conditional, *args):
-        q = self._clone()
+    def __options(self, conditional, *args):
         # most MapperOptions write to the '_attributes' dictionary,
         # so copy that as well
-        q._attributes = q._attributes.copy()
+        self._attributes = self._attributes.copy()
         opts = [o for o in util.flatten_iterator(args)]
-        q._with_options = q._with_options + opts
+        self._with_options = self._with_options + opts
         if conditional:
             for opt in opts:
-                opt.process_query_conditionally(q)
+                opt.process_query_conditionally(self)
         else:
             for opt in opts:
-                opt.process_query(q)
-        return q
+                opt.process_query(self)
+    __options = _generative()(__options)
 
     def with_lockmode(self, mode):
         """Return a new Query object with the specified locking mode."""
-        
-        q = self._clone()
-        q._lockmode = mode
-        return q
+
+        self._lockmode = mode
+    with_lockmode = _generative()(with_lockmode)
 
     def params(self, *args, **kwargs):
         """add values for bind parameters which may have been specified in filter().
@@ -505,14 +611,13 @@ class Query(object):
         \**kwargs cannot be used.
 
         """
-        q = self._clone()
         if len(args) == 1:
             kwargs.update(args[0])
         elif len(args) > 0:
-            raise exceptions.ArgumentError("params() takes zero or one positional argument, which is a dictionary.")
-        q._params = q._params.copy()
-        q._params.update(kwargs)
-        return q
+            raise sa_exc.ArgumentError("params() takes zero or one positional argument, which is a dictionary.")
+        self._params = self._params.copy()
+        self._params.update(kwargs)
+    params = _generative()(params)
 
     def filter(self, criterion):
         """apply the given filtering criterion to the query and return the newly resulting ``Query``
@@ -524,22 +629,20 @@ class Query(object):
             criterion = sql.text(criterion)
 
         if criterion is not None and not isinstance(criterion, sql.ClauseElement):
-            raise exceptions.ArgumentError("filter() argument must be of type sqlalchemy.sql.ClauseElement or string")
+            raise sa_exc.ArgumentError("filter() argument must be of type sqlalchemy.sql.ClauseElement or string")
 
-        if self._aliases_tail:
-            criterion = self._aliases_tail.adapt_clause(criterion)
+        criterion = self._adapt_clause(criterion, True, True)
 
-        q = self.__no_statement("filter")
-        if q._criterion is not None:
-            q._criterion = q._criterion & criterion
+        if self._criterion is not None:
+            self._criterion = self._criterion & criterion
         else:
-            q._criterion = criterion
-        return q
+            self._criterion = criterion
+    filter = _generative(__no_statement_condition, __no_limit_offset)(filter)
 
     def filter_by(self, **kwargs):
         """apply the given filtering criterion to the query and return the newly resulting ``Query``."""
 
-        clauses = [self._joinpoint.get_property(key, resolve_synonyms=True).compare(operators.eq, value)
+        clauses = [_entity_descriptor(self._joinpoint_zero(), key)[0] == value
             for key, value in kwargs.iteritems()]
 
         return self.filter(sql.and_(*clauses))
@@ -568,31 +671,27 @@ class Query(object):
     def order_by(self, *criterion):
         """apply one or more ORDER BY criterion to the query and return the newly resulting ``Query``"""
 
-        q = self.__no_statement("order_by")
+        criterion = [self._adapt_clause(expression._literal_as_text(o), True, True) for o in criterion]
 
-        if self._aliases_tail:
-            criterion = tuple(self._aliases_tail.adapt_list(
-                    [expression._literal_as_text(o) for o in criterion]
-                    ))
-
-        if q._order_by is False:
-            q._order_by = criterion
+        if self._order_by is False:
+            self._order_by = criterion
         else:
-            q._order_by = q._order_by + criterion
-        return q
+            self._order_by = self._order_by + criterion
     order_by = util.array_as_starargs_decorator(order_by)
-    
+    order_by = _generative(__no_statement_condition)(order_by)
+
     def group_by(self, *criterion):
         """apply one or more GROUP BY criterion to the query and return the newly resulting ``Query``"""
 
-        q = self.__no_statement("group_by")
-        if q._group_by is False:
-            q._group_by = criterion
+        criterion = list(chain(*[_orm_columns(c) for c in criterion]))
+
+        if self._group_by is False:
+            self._group_by = criterion
         else:
-            q._group_by = q._group_by + criterion
-        return q
+            self._group_by = self._group_by + criterion
     group_by = util.array_as_starargs_decorator(group_by)
-    
+    group_by = _generative(__no_statement_condition)(group_by)
+
     def having(self, criterion):
         """apply a HAVING criterion to the query and return the newly resulting ``Query``."""
 
@@ -600,190 +699,225 @@ class Query(object):
             criterion = sql.text(criterion)
 
         if criterion is not None and not isinstance(criterion, sql.ClauseElement):
-            raise exceptions.ArgumentError("having() argument must be of type sqlalchemy.sql.ClauseElement or string")
+            raise sa_exc.ArgumentError("having() argument must be of type sqlalchemy.sql.ClauseElement or string")
 
-        if self._aliases_tail:
-            criterion = self._aliases_tail.adapt_clause(criterion)
+        criterion = self._adapt_clause(criterion, True, True)
 
-        q = self.__no_statement("having")
-        if q._having is not None:
-            q._having = q._having & criterion
+        if self._having is not None:
+            self._having = self._having & criterion
         else:
-            q._having = criterion
-        return q
+            self._having = criterion
+    having = _generative(__no_statement_condition)(having)
 
-    def join(self, prop, id=None, aliased=False, from_joinpoint=False):
+    def join(self, *props, **kwargs):
         """Create a join against this ``Query`` object's criterion
-        and apply generatively, retunring the newly resulting ``Query``.
-
-        'prop' may be one of:
-          * a string property name, i.e. "rooms"
-          * a class-mapped attribute, i.e. Houses.rooms
-          * a 2-tuple containing one of the above, combined with a selectable
-            which derives from the properties' mapped table
-          * a list (not a tuple) containing a combination of any of the above.
+        and apply generatively, returning the newly resulting ``Query``.
 
+        each element in \*props may be:
+        
+          * a string property name, i.e. "rooms".  This will join along
+            the relation of the same name from this Query's "primary"
+            mapper, if one is present.
+          
+          * a class-mapped attribute, i.e. Houses.rooms.  This will create a
+            join from "Houses" table to that of the "rooms" relation.
+          
+          * a 2-tuple containing a target class or selectable, and 
+            an "ON" clause.  The ON clause can be the property name/
+            attribute like above, or a SQL expression.
+          
+          
         e.g.::
 
+            # join along string attribute names
             session.query(Company).join('employees')
-            session.query(Company).join(['employees', 'tasks'])
-            session.query(Houses).join([Colonials.rooms, Room.closets])
-            session.query(Company).join([('employees', people.join(engineers)), Engineer.computers])
+            session.query(Company).join('employees', 'tasks')
+
+            # join the Person entity to an alias of itself,
+            # along the "friends" relation
+            PAlias = aliased(Person)
+            session.query(Person).join((Palias, Person.friends))
+
+            # join from Houses to the "rooms" attribute on the
+            # "Colonials" subclass of Houses, then join to the 
+            # "closets" relation on Room
+            session.query(Houses).join(Colonials.rooms, Room.closets)
+            
+            # join from Company entities to the "employees" collection,
+            # using "people JOIN engineers" as the target.  Then join
+            # to the "computers" collection on the Engineer entity.
+            session.query(Company).join((people.join(engineers), 'employees'), Engineer.computers)
+            
+            # join from Articles to Keywords, using the "keywords" attribute.
+            # assume this is a many-to-many relation.
+            session.query(Article).join(Article.keywords)
+            
+            # same thing, but spelled out entirely explicitly 
+            # including the association table.
+            session.query(Article).join(
+                (article_keywords, Articles.id==article_keywords.c.article_id),
+                (Keyword, Keyword.id==article_keywords.c.keyword_id)
+                )
+
+        \**kwargs include:
+
+            aliased - when joining, create anonymous aliases of each table.  This is
+            used for self-referential joins or multiple joins to the same table.
+            Consider usage of the aliased(SomeClass) construct as a more explicit
+            approach to this.
+
+            from_joinpoint - when joins are specified using string property names,
+            locate the property from the mapper found in the most recent previous 
+            join() call, instead of from the root entity.
 
         """
-        return self._join(prop, id=id, outerjoin=False, aliased=aliased, from_joinpoint=from_joinpoint)
+        aliased, from_joinpoint = kwargs.pop('aliased', False), kwargs.pop('from_joinpoint', False)
+        if kwargs:
+            raise TypeError("unknown arguments: %s" % ','.join(kwargs.keys()))
+        return self.__join(props, outerjoin=False, create_aliases=aliased, from_joinpoint=from_joinpoint)
+    join = util.array_as_starargs_decorator(join)
 
-    def outerjoin(self, prop, id=None, aliased=False, from_joinpoint=False):
+    def outerjoin(self, *props, **kwargs):
         """Create a left outer join against this ``Query`` object's criterion
         and apply generatively, retunring the newly resulting ``Query``.
+        
+        Usage is the same as the ``join()`` method.
 
-        'prop' may be one of:
-          * a string property name, i.e. "rooms"
-          * a class-mapped attribute, i.e. Houses.rooms
-          * a 2-tuple containing one of the above, combined with a selectable
-            which derives from the properties' mapped table
-          * a list (not a tuple) containing a combination of any of the above.
+        """
+        aliased, from_joinpoint = kwargs.pop('aliased', False), kwargs.pop('from_joinpoint', False)
+        if kwargs:
+            raise TypeError("unknown arguments: %s" % ','.join(kwargs.keys()))
+        return self.__join(props, outerjoin=True, create_aliases=aliased, from_joinpoint=from_joinpoint)
+    outerjoin = util.array_as_starargs_decorator(outerjoin)
 
-        e.g.::
+    def __join(self, keys, outerjoin, create_aliases, from_joinpoint):
+        self.__currenttables = util.Set(self.__currenttables)
+        self._polymorphic_adapters = self._polymorphic_adapters.copy()
 
-            session.query(Company).outerjoin('employees')
-            session.query(Company).outerjoin(['employees', 'tasks'])
-            session.query(Houses).outerjoin([Colonials.rooms, Room.closets])
-            session.query(Company).join([('employees', people.join(engineers)), Engineer.computers])
+        if not from_joinpoint:
+            self.__reset_joinpoint()
 
-        """
-        return self._join(prop, id=id, outerjoin=True, aliased=aliased, from_joinpoint=from_joinpoint)
-    
-    def _join(self, prop, id, outerjoin, aliased, from_joinpoint):
-        (clause, mapper, aliases) = self._join_to(prop, outerjoin=outerjoin, start=from_joinpoint and self._joinpoint or self.mapper, create_aliases=aliased)
-        # TODO: improve the generative check here to look for primary mapped entity, etc.
-        q = self.__no_statement("join")
-        q._from_obj = clause
-        q._joinpoint = mapper
-        q._aliases = aliases
-        q.__generate_alias_ids()
-        
-        if aliases:
-            q._aliases_tail = aliases
-
-        a = aliases
-        while a is not None:
-            if isinstance(a, mapperutil.PropertyAliasedClauses):
-                q._alias_ids.setdefault(a.mapper, []).append(a)
-                q._alias_ids.setdefault(a.table, []).append(a)
-                a = a.parentclauses
+        clause = self._from_obj
+        right_entity = None
+
+        for arg1 in util.to_list(keys):
+            prop =  None
+            aliased_entity = False
+            alias_criterion = False
+            left_entity = right_entity
+            right_entity = right_mapper = None
+            
+            if isinstance(arg1, tuple):
+                arg1, arg2 = arg1
             else:
-                break
+                arg2 = None
+            
+            if isinstance(arg2, (interfaces.PropComparator, basestring)):
+                onclause = arg2
+                right_entity = arg1
+            elif isinstance(arg1, (interfaces.PropComparator, basestring)):
+                onclause = arg1
+                right_entity = arg2 
+            else:
+                onclause = arg2
+                right_entity = arg1
 
-        if id:
-            q._alias_ids[id] = [aliases]
-        return q
+            if isinstance(onclause, interfaces.PropComparator):
+                of_type = getattr(onclause, '_of_type', None)
+                prop = onclause.property
+                descriptor = onclause
+                
+                if not left_entity:
+                    left_entity = onclause.parententity
+                    
+                if of_type:
+                    right_mapper = of_type
+                else:
+                    right_mapper = prop.mapper
+                    
+                if not right_entity:
+                    right_entity = right_mapper
+                    
+            elif isinstance(onclause, basestring):
+                if not left_entity:
+                    left_entity = self._joinpoint_zero()
+                    
+                descriptor, prop = _entity_descriptor(left_entity, onclause)
+                right_mapper = prop.mapper
+                if not right_entity:
+                    right_entity = right_mapper
+            elif onclause is None:
+                if not left_entity:
+                    left_entity = self._joinpoint_zero()
+            else:
+                if not left_entity:
+                    left_entity = self._joinpoint_zero()
+                    
+            if not clause:
+                if isinstance(onclause, interfaces.PropComparator):
+                    clause = onclause.__clause_element__()
 
-    def _get_joinable_tables(self):
-        if not self.__joinable_tables or self.__joinable_tables[0] is not self._from_obj:
-            currenttables = [self._from_obj]
-            def visit_join(join):
-                currenttables.append(join.left)
-                currenttables.append(join.right)
-            visitors.traverse(self._from_obj, visit_join=visit_join, traverse_options={'column_collections':False, 'aliased_selectables':False})
-            self.__joinable_tables = (self._from_obj, currenttables)
-            return currenttables
-        else:
-            return self.__joinable_tables[1]
+                for ent in self._mapper_entities:
+                    if ent.corresponds_to(left_entity):
+                        clause = ent.selectable
+                        break
 
-    def _join_to(self, keys, outerjoin=False, start=None, create_aliases=True):
-        if start is None:
-            start = self._joinpoint
+            if not clause:
+                raise exc.InvalidRequestError("Could not find a FROM clause to join from")
 
-        clause = self._from_obj
+            bogus, right_selectable, is_aliased_class = _entity_info(right_entity)
 
-        currenttables = self._get_joinable_tables()
+            if right_mapper and not is_aliased_class:
+                if right_entity is right_selectable:
 
-        # determine if generated joins need to be aliased on the left
-        # hand side.
-        if self._aliases_head is self._aliases_tail is not None:
-            adapt_against = self._aliases_tail.alias
-        elif start is not self.mapper and self._aliases_tail:
-            adapt_against = self._aliases_tail.alias
-        else:
-            adapt_against = None
+                    if not right_selectable.is_derived_from(right_mapper.mapped_table):
+                        raise sa_exc.InvalidRequestError("Selectable '%s' is not derived from '%s'" % (right_selectable.description, right_mapper.mapped_table.description))
 
-        mapper = start
-        alias = self._aliases_tail
+                    if not isinstance(right_selectable, expression.Alias):
+                        right_selectable = right_selectable.alias()
 
-        if not isinstance(keys, list):
-            keys = [keys]
-            
-        for key in keys:
-            use_selectable = None
-            of_type = None
-            is_aliased_class = False
-            
-            if isinstance(key, tuple):
-                key, use_selectable = key
-
-            if isinstance(key, interfaces.PropComparator):
-                prop = key.property
-                if getattr(key, '_of_type', None):
-                    of_type = key._of_type
-                    if not use_selectable:
-                        use_selectable = key._of_type.mapped_table
-            else:
-                prop = mapper.get_property(key, resolve_synonyms=True)
-
-            if use_selectable:
-                if _is_aliased_class(use_selectable):
-                    use_selectable = use_selectable.alias
-                    is_aliased_class = True
-                if not use_selectable.is_derived_from(prop.mapper.mapped_table):
-                    raise exceptions.InvalidRequestError("Selectable '%s' is not derived from '%s'" % (use_selectable.description, prop.mapper.mapped_table.description))
-                if not isinstance(use_selectable, expression.Alias):
-                    use_selectable = use_selectable.alias()
-            elif prop.mapper.with_polymorphic:
-                use_selectable = prop.mapper._with_polymorphic_selectable()
-                if not isinstance(use_selectable, expression.Alias):
-                    use_selectable = use_selectable.alias()
-
-            if prop._is_self_referential() and not create_aliases and not use_selectable:
-                raise exceptions.InvalidRequestError("Self-referential query on '%s' property requires aliased=True argument." % str(prop))
-
-            if prop.table not in currenttables or create_aliases or use_selectable:
+                    right_entity = aliased(right_mapper, right_selectable)
+                    alias_criterion = True
+
+                elif right_mapper.with_polymorphic or isinstance(right_mapper.mapped_table, expression.Join):
+                    aliased_entity = True
+                    right_entity = aliased(right_mapper)
+                    alias_criterion = True
                 
-                if use_selectable or create_aliases:
-                    alias = mapperutil.PropertyAliasedClauses(prop,
-                        prop.primaryjoin, 
-                        prop.secondaryjoin, 
-                        alias,
-                        alias=use_selectable,
-                        should_adapt=not is_aliased_class
-                    )
-                    crit = alias.primaryjoin
+                elif create_aliases:
+                    right_entity = aliased(right_mapper)
+                    alias_criterion = True
+                    
+                elif prop:
+                    if prop.table in self.__currenttables:
+                        if prop.secondary is not None and prop.secondary not in self.__currenttables:
+                            # TODO: this check is not strong enough for different paths to the same endpoint which
+                            # does not use secondary tables
+                            raise sa_exc.InvalidRequestError("Can't join to property '%s'; a path to this table along a different secondary table already exists.  Use the `alias=True` argument to `join()`." % descriptor)
+
+                        continue
+
                     if prop.secondary:
-                        clause = clause.join(alias.secondary, crit, isouter=outerjoin)
-                        clause = clause.join(alias.alias, alias.secondaryjoin, isouter=outerjoin)
-                    else:
-                        clause = clause.join(alias.alias, crit, isouter=outerjoin)
-                else:
-                    assert not prop.mapper.with_polymorphic
-                    pj, sj, source, dest, target_adapter = prop._create_joins(source_selectable=adapt_against)
-                    if sj:
-                        clause = clause.join(prop.secondary, pj, isouter=outerjoin)
-                        clause = clause.join(prop.table, sj, isouter=outerjoin)
-                    else:
-                        clause = clause.join(prop.table, pj, isouter=outerjoin)
-                        
-            elif not create_aliases and prop.secondary is not None and prop.secondary not in currenttables:
-                # TODO: this check is not strong enough for different paths to the same endpoint which
-                # does not use secondary tables
-                raise exceptions.InvalidRequestError("Can't join to property '%s'; a path to this table along a different secondary table already exists.  Use the `alias=True` argument to `join()`." % prop.key)
+                        self.__currenttables.add(prop.secondary)
+                    self.__currenttables.add(prop.table)
 
-            mapper = of_type or prop.mapper
+                    right_entity = prop.mapper
 
-            if use_selectable:
-                adapt_against = use_selectable
-        
-        return (clause, mapper, alias)
+            if prop:
+                onclause = prop
+            
+            clause = orm_join(clause, right_entity, onclause, isouter=outerjoin)
+            if alias_criterion: 
+                self._filter_aliases = ORMAdapter(right_entity, 
+                        equivalents=right_mapper._equivalent_columns, chain_to=self._filter_aliases)
+
+                if aliased_entity:
+                    self.__mapper_loads_polymorphically_with(right_mapper, ORMAdapter(right_entity, equivalents=right_mapper._equivalent_columns))
+
+        self._from_obj = clause
+        self._joinpoint = right_entity
 
+    __join = _generative(__no_statement_condition, __no_limit_offset)(__join)
 
     def reset_joinpoint(self):
         """return a new Query reset the 'joinpoint' of this Query reset
@@ -794,13 +928,8 @@ class Query(object):
         the root.
 
         """
-        q = self.__no_statement("reset_joinpoint")
-        q._joinpoint = q.mapper
-        if q.table not in q._get_joinable_tables():
-            q._aliases_head = q._aliases_tail = mapperutil.AliasedClauses(q._from_obj, equivalents=q.mapper._equivalent_columns)
-        else:
-            q._aliases_head = q._aliases_tail = None
-        return q
+        self.__reset_joinpoint()
+    reset_joinpoint = _generative(__no_statement_condition)(reset_joinpoint)
 
     def select_from(self, from_obj):
         """Set the `from_obj` parameter of the query and return the newly
@@ -811,14 +940,13 @@ class Query(object):
         `from_obj` is a single table or selectable.
 
         """
-        new = self.__no_criterion('select_from')
         if isinstance(from_obj, (tuple, list)):
             util.warn_deprecated("select_from() now accepts a single Selectable as its argument, which replaces any existing FROM criterion.")
             from_obj = from_obj[-1]
+        
+        self.__set_select_from(from_obj)
+    select_from = _generative(__no_from_condition, __no_criterion_condition)(select_from)
 
-        new.__set_select_from(from_obj)
-        return new
-    
     def __getitem__(self, item):
         if isinstance(item, slice):
             start = item.start
@@ -863,9 +991,8 @@ class Query(object):
         ``Query``.
 
         """
-        new = self.__no_statement("distinct")
-        new._distinct = True
-        return new
+        self._distinct = True
+    distinct = _generative(__no_statement_condition)(distinct)
 
     def all(self):
         """Return the results represented by this ``Query`` as a list.
@@ -875,7 +1002,6 @@ class Query(object):
         """
         return list(self)
 
-
     def from_statement(self, statement):
         """Execute the given SELECT statement and return results.
 
@@ -891,9 +1017,8 @@ class Query(object):
         """
         if isinstance(statement, basestring):
             statement = sql.text(statement)
-        q = self.__no_criterion('from_statement')
-        q._statement = statement
-        return q
+        self._statement = statement
+    from_statement = _generative(__no_criterion_condition)(from_statement)
 
     def first(self):
         """Return the first result of this ``Query`` or None if the result doesn't contain any row.
@@ -901,9 +1026,6 @@ class Query(object):
         This results in an execution of the underlying query.
 
         """
-        if self._column_aggregate is not None:
-            return self._col_aggregate(*self._column_aggregate)
-
         ret = list(self[0:1])
         if len(ret) > 0:
             return ret[0]
@@ -916,17 +1038,14 @@ class Query(object):
         This results in an execution of the underlying query.
 
         """
-        if self._column_aggregate is not None:
-            return self._col_aggregate(*self._column_aggregate)
-
         ret = list(self[0:2])
 
         if len(ret) == 1:
             return ret[0]
         elif len(ret) == 0:
-            raise exceptions.InvalidRequestError('No rows returned for one()')
+            raise sa_exc.InvalidRequestError('No rows returned for one()')
         else:
-            raise exceptions.InvalidRequestError('Multiple rows returned for one()')
+            raise sa_exc.InvalidRequestError('Multiple rows returned for one()')
 
     def __iter__(self):
         context = self._compile_context()
@@ -936,37 +1055,41 @@ class Query(object):
         return self._execute_and_instances(context)
 
     def _execute_and_instances(self, querycontext):
-        result = self.session.execute(querycontext.statement, params=self._params, mapper=self.mapper, instance=self._refresh_instance)
-        return self.iterate_instances(result, querycontext=querycontext)
+        result = self.session.execute(querycontext.statement, params=self._params, mapper=self._mapper_zero_or_none(), instance=self._refresh_instance)
+        return self.iterate_instances(result, querycontext)
 
-    def instances(self, cursor, *mappers_or_columns, **kwargs):
-        return list(self.iterate_instances(cursor, *mappers_or_columns, **kwargs))
+    def instances(self, cursor, __context=None):
+        return list(self.iterate_instances(cursor, __context))
 
-    def iterate_instances(self, cursor, *mappers_or_columns, **kwargs):
+    def iterate_instances(self, cursor, __context=None):
         session = self.session
 
-        context = kwargs.pop('querycontext', None)
+        context = __context
         if context is None:
             context = QueryContext(self)
 
         context.runid = _new_runid()
 
-        entities = self._entities + [_QueryEntity.legacy_guess_type(mc) for mc in mappers_or_columns]
-        
-        if getattr(self, '_no_filters', False):
-            filter = None
-            single_entity = custom_rows = False
-        else:
-            single_entity = isinstance(entities[0], _PrimaryMapperEntity) and len(entities) == 1
-            custom_rows = single_entity and 'append_result' in context.extension.methods
-            
+        filtered = bool(list(self._mapper_entities))
+        single_entity = filtered and len(self._entities) == 1
+
+        if filtered:
             if single_entity:
                 filter = util.OrderedIdentitySet
             else:
                 filter = util.OrderedSet
-        
-        process = [query_entity.row_processor(self, context, single_entity) for query_entity in entities]
+        else:
+            filter = None
+
+        custom_rows = single_entity and 'append_result' in self._entities[0].extension.methods
 
+        (process, labels) = zip(*[query_entity.row_processor(self, context, custom_rows) for query_entity in self._entities])
+
+        if not single_entity:
+            labels = dict([(label, property(util.itemgetter(i))) for i, label in enumerate(labels) if label])
+            rowtuple = type.__new__(type, "RowTuple", (tuple,), labels)
+            rowtuple.keys = labels.keys
+            
         while True:
             context.progress = util.Set()
             context.partials = {}
@@ -974,7 +1097,7 @@ class Query(object):
             if self._yield_per:
                 fetch = cursor.fetchmany(self._yield_per)
                 if not fetch:
-                    return
+                    break
             else:
                 fetch = cursor.fetchall()
             
@@ -985,23 +1108,20 @@ class Query(object):
             elif single_entity:
                 rows = [process[0](context, row) for row in fetch]
             else:
-                rows = [tuple([proc(context, row) for proc in process]) for row in fetch]
+                rows = [rowtuple([proc(context, row) for proc in process]) for row in fetch]
 
             if filter:
                 rows = filter(rows)
 
-            if context.refresh_instance and context.only_load_props and context.refresh_instance in context.progress:
-                context.refresh_instance.commit(context.only_load_props)
+            if context.refresh_instance and self._only_load_props and context.refresh_instance in context.progress:
+                context.refresh_instance.commit(self._only_load_props)
                 context.progress.remove(context.refresh_instance)
 
-            for ii in context.progress:
-                context.attributes.get(('populating_mapper', ii), _state_mapper(ii))._post_instance(context, ii)
-                ii.commit_all()
-                
+            session._finalize_loaded(context.progress)
+
             for ii, attrs in context.partials.items():
-                context.attributes.get(('populating_mapper', ii), _state_mapper(ii))._post_instance(context, ii, only_load_props=attrs)
                 ii.commit(attrs)
-                
+
             for row in rows:
                 yield row
 
@@ -1010,9 +1130,18 @@ class Query(object):
 
     def _get(self, key=None, ident=None, refresh_instance=None, lockmode=None, only_load_props=None):
         lockmode = lockmode or self._lockmode
-        if not self._populate_existing and not refresh_instance and not self.mapper.always_refresh and lockmode is None:
+        if not self._populate_existing and not refresh_instance and not self._mapper_zero().always_refresh and lockmode is None:
             try:
-                return self.session.identity_map[key]
+                instance = self.session.identity_map[key]
+                state = attributes.instance_state(instance)
+                if state.expired:
+                    try:
+                        state()
+                    except orm_exc.ObjectDeletedError:
+                        # TODO: should we expunge ?  if so, should we expunge here ? or in mapper._load_scalar_attributes ?
+                        self.session.expunge(instance)
+                        return None
+                return instance
             except KeyError:
                 pass
 
@@ -1022,27 +1151,29 @@ class Query(object):
         else:
             ident = util.to_list(ident)
 
-        q = self
-        
-        # dont use 'polymorphic' mapper if we are refreshing an instance
-        if refresh_instance and q.mapper is not q.mapper:
-            q = q.__reset_all(q.mapper, '_get')
+        if refresh_instance is None:
+            q = self.__no_criterion()
+        else:
+            q = self._clone()
 
         if ident is not None:
-            q = q.__no_criterion('get')
+            mapper = q._mapper_zero()
             params = {}
-            (_get_clause, _get_params) = q.mapper._get_clause
-            q = q.filter(_get_clause)
-            for i, primary_key in enumerate(q.mapper.primary_key):
+            (_get_clause, _get_params) = mapper._get_clause
+
+            _get_clause = q._adapt_clause(_get_clause, True, False)
+            q._criterion = _get_clause
+
+            for i, primary_key in enumerate(mapper.primary_key):
                 try:
                     params[_get_params[primary_key].key] = ident[i]
                 except IndexError:
-                    raise exceptions.InvalidRequestError("Could not find enough values to formulate primary key for query.get(); primary key columns are %s" % ', '.join(["'%s'" % str(c) for c in q.mapper.primary_key]))
-            q = q.params(params)
+                    raise sa_exc.InvalidRequestError("Could not find enough values to formulate primary key for query.get(); primary key columns are %s" % ', '.join(["'%s'" % str(c) for c in q.mapper.primary_key]))
+            q._params = params
 
         if lockmode is not None:
-            q = q.with_lockmode(lockmode)
-        q = q.__get_options(populate_existing=bool(refresh_instance), version_check=(lockmode is not None), only_load_props=only_load_props, refresh_instance=refresh_instance)
+            q._lockmode = lockmode
+        q.__get_options(populate_existing=bool(refresh_instance), version_check=(lockmode is not None), only_load_props=only_load_props, refresh_instance=refresh_instance)
         q._order_by = None
         try:
             # call using all() to avoid LIMIT compilation complexity
@@ -1053,41 +1184,26 @@ class Query(object):
     def _select_args(self):
         return {'limit':self._limit, 'offset':self._offset, 'distinct':self._distinct, 'group_by':self._group_by or None, 'having':self._having or None}
     _select_args = property(_select_args)
-    
+
     def _should_nest_selectable(self):
         kwargs = self._select_args
         return (kwargs.get('limit') is not None or kwargs.get('offset') is not None or kwargs.get('distinct', False))
     _should_nest_selectable = property(_should_nest_selectable)
 
-    def count(self, whereclause=None, params=None, **kwargs):
-        """Apply this query's criterion to a SELECT COUNT statement.
-
-        the whereclause, params and \**kwargs arguments are deprecated.  use filter()
-        and other generative methods to establish modifiers.
-
-        """
-        q = self
-        if whereclause is not None:
-            q = q.filter(whereclause)
-        if params is not None:
-            q = q.params(params)
-        q = q._legacy_select_kwargs(**kwargs)
-        return q._count()
-
-    def _count(self):
+    def count(self):
         """Apply this query's criterion to a SELECT COUNT statement.
 
         this is the purely generative version which will become
         the public method in version 0.5.
 
         """
-        return self._col_aggregate(sql.literal_column('1'), sql.func.count, nested_cols=list(self.mapper.primary_key))
+        return self._col_aggregate(sql.literal_column('1'), sql.func.count, nested_cols=list(self._mapper_zero().primary_key))
 
     def _col_aggregate(self, col, func, nested_cols=None):
         whereclause = self._criterion
-        
+
         context = QueryContext(self)
-        from_obj = self._from_obj
+        from_obj = self.__mapper_zero_from_obj()
 
         if self._should_nest_selectable:
             if not nested_cols:
@@ -1097,113 +1213,97 @@ class Query(object):
             s = sql.select([func(s.corresponding_column(col) or col)]).select_from(s)
         else:
             s = sql.select([func(col)], whereclause, from_obj=from_obj, **self._select_args)
-            
+
         if self._autoflush and not self._populate_existing:
             self.session._autoflush()
-        return self.session.scalar(s, params=self._params, mapper=self.mapper)
+        return self.session.scalar(s, params=self._params, mapper=self._mapper_zero())
 
     def compile(self):
         """compiles and returns a SQL statement based on the criterion and conditions within this Query."""
 
         return self._compile_context().statement
 
-    def _compile_context(self):
-
+    def _compile_context(self, labels=True):
         context = QueryContext(self)
 
-        if self._statement:
-            self._statement.use_labels = True
-            context.statement = self._statement
+        if context.statement:
             return context
 
-        from_obj = self._from_obj
-        adapter = self._aliases_head
-        
         if self._lockmode:
             try:
-                for_update = {'read':'read','update':True,'update_nowait':'nowait',None:False}[self._lockmode]
+                for_update = {'read': 'read',
+                              'update': True,
+                              'update_nowait': 'nowait',
+                              None: False}[self._lockmode]
             except KeyError:
-                raise exceptions.ArgumentError("Unknown lockmode '%s'" % self._lockmode)
+                raise sa_exc.ArgumentError("Unknown lockmode '%s'" % self._lockmode)
         else:
             for_update = False
-            
-        context.from_clause = from_obj
-        context.whereclause = self._criterion
-        context.order_by = self._order_by
-        
+
         for entity in self._entities:
             entity.setup_context(self, context)
-            
-        if self._eager_loaders and self._should_nest_selectable:
-            # eager loaders are present, and the SELECT has limiting criterion
-            # produce a "wrapped" selectable.
-            
+
+        eager_joins = context.eager_joins.values()
+
+        if context.from_clause:
+            froms = [context.from_clause]  # "load from a single FROM" mode, i.e. when select_from() or join() is used
+        else:
+            froms = context.froms   # "load from discrete FROMs" mode, i.e. when each _MappedEntity has its own FROM
+
+        if eager_joins and self._should_nest_selectable:
+            # for eager joins present and LIMIT/OFFSET/DISTINCT, wrap the query inside a select,
+            # then append eager joins onto that
+
             if context.order_by:
-                context.order_by = [expression._literal_as_text(o) for o in util.to_list(context.order_by) or []]
-                if adapter:
-                    context.order_by = adapter.adapt_list(context.order_by)
-                # locate all embedded Column clauses so they can be added to the
-                # "inner" select statement where they'll be available to the enclosing
-                # statement's "order by"
-                # TODO: this likely doesn't work with very involved ORDER BY expressions,
-                # such as those including subqueries
                 order_by_col_expr = list(chain(*[sql_util.find_columns(o) for o in context.order_by]))
             else:
                 context.order_by = None
                 order_by_col_expr = []
-                
-            if adapter:
-                context.primary_columns = adapter.adapt_list(context.primary_columns)
-            
-            inner = sql.select(context.primary_columns + order_by_col_expr, context.whereclause, from_obj=context.from_clause, use_labels=True, correlate=False, order_by=context.order_by, **self._select_args).alias()
-            local_adapter = sql_util.ClauseAdapter(inner)
 
-            context.row_adapter = mapperutil.create_row_adapter(inner, equivalent_columns=self.mapper._equivalent_columns)
+            inner = sql.select(context.primary_columns + order_by_col_expr, context.whereclause, from_obj=froms, use_labels=labels, correlate=False, order_by=context.order_by, **self._select_args)
+
+            if self._correlate:
+                inner = inner.correlate(*self._correlate)
+
+            inner = inner.alias()
 
-            statement = sql.select([inner] + context.secondary_columns, for_update=for_update, use_labels=True)
+            equivs = self.__all_equivs()
 
-            if context.eager_joins:
-                eager_joins = local_adapter.traverse(context.eager_joins)
-                statement.append_from(eager_joins)
+            context.adapter = sql_util.ColumnAdapter(inner, equivs)
+
+            statement = sql.select([inner] + context.secondary_columns, for_update=for_update, use_labels=labels)
+
+            from_clause = inner
+            for eager_join in eager_joins:
+                # EagerLoader places a 'stop_on' attribute on the join, 
+                # giving us a marker as to where the "splice point" of the join should be
+                from_clause = sql_util.splice_joins(from_clause, eager_join, eager_join.stop_on)
+
+            statement.append_from(from_clause)
 
             if context.order_by:
+                local_adapter = sql_util.ClauseAdapter(inner)
                 statement.append_order_by(*local_adapter.copy_and_process(context.order_by))
 
             statement.append_order_by(*context.eager_order_by)
         else:
-            if context.order_by:
-                context.order_by = [expression._literal_as_text(o) for o in util.to_list(context.order_by) or []]
-                if adapter:
-                    context.order_by = adapter.adapt_list(context.order_by)
-            else:
+            if not context.order_by:
                 context.order_by = None
-            
-            if adapter:
-                context.primary_columns = adapter.adapt_list(context.primary_columns)
-                context.row_adapter = mapperutil.create_row_adapter(adapter.alias, equivalent_columns=self.mapper._equivalent_columns)
-                
+
             if self._distinct and context.order_by:
                 order_by_col_expr = list(chain(*[sql_util.find_columns(o) for o in context.order_by]))
                 context.primary_columns += order_by_col_expr
 
-            statement = sql.select(context.primary_columns + context.secondary_columns, context.whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, order_by=context.order_by, **self._select_args)
+            froms += context.eager_joins.values()
 
-            if context.eager_joins:
-                if adapter:
-                    context.eager_joins = adapter.adapt_clause(context.eager_joins)
-                statement.append_from(context.eager_joins)
+            statement = sql.select(context.primary_columns + context.secondary_columns, context.whereclause, from_obj=froms, use_labels=labels, for_update=for_update, correlate=False, order_by=context.order_by, **self._select_args)
+            if self._correlate:
+                statement = statement.correlate(*self._correlate)
 
             if context.eager_order_by:
-                if adapter:
-                    context.eager_order_by = adapter.adapt_list(context.eager_order_by)
                 statement.append_order_by(*context.eager_order_by)
 
-        # polymorphic mappers which have concrete tables in their hierarchy usually
-        # require row aliasing unconditionally.  
-        if not context.row_adapter and self.mapper._requires_row_aliasing:
-            context.row_adapter = mapperutil.create_row_adapter(self.table, equivalent_columns=self.mapper._equivalent_columns)
-            
-        context.statement = statement
+        context.statement = statement._annotate({'_halt_adapt': True})
 
         return context
 
@@ -1213,462 +1313,257 @@ class Query(object):
     def __str__(self):
         return str(self.compile())
 
-    # DEPRECATED LAND !
-
-    def _generative_col_aggregate(self, col, func):
-        """apply the given aggregate function to the query and return the newly
-        resulting ``Query``. (deprecated)
-        """
-        if self._column_aggregate is not None:
-            raise exceptions.InvalidRequestError("Query already contains an aggregate column or function")
-        q = self.__no_statement("aggregate")
-        q._column_aggregate = (col, func)
-        return q
-
-    def apply_min(self, col):
-        """apply the SQL ``min()`` function against the given column to the
-        query and return the newly resulting ``Query``.
-        
-        DEPRECATED.
-        """
-        return self._generative_col_aggregate(col, sql.func.min)
-
-    def apply_max(self, col):
-        """apply the SQL ``max()`` function against the given column to the
-        query and return the newly resulting ``Query``.
-
-        DEPRECATED.
-        """
-        return self._generative_col_aggregate(col, sql.func.max)
-
-    def apply_sum(self, col):
-        """apply the SQL ``sum()`` function against the given column to the
-        query and return the newly resulting ``Query``.
-
-        DEPRECATED.
-        """
-        return self._generative_col_aggregate(col, sql.func.sum)
-
-    def apply_avg(self, col):
-        """apply the SQL ``avg()`` function against the given column to the
-        query and return the newly resulting ``Query``.
-
-        DEPRECATED.
-        """
-        return self._generative_col_aggregate(col, sql.func.avg)
-
-    def list(self): #pragma: no cover
-        """DEPRECATED.  use all()"""
-
-        return list(self)
-
-    def scalar(self): #pragma: no cover
-        """DEPRECATED.  use first()"""
-
-        return self.first()
-
-    def _legacy_filter_by(self, *args, **kwargs): #pragma: no cover
-        return self.filter(self._legacy_join_by(args, kwargs, start=self._joinpoint))
-
-    def count_by(self, *args, **params): #pragma: no cover
-        """DEPRECATED.  use query.filter_by(\**params).count()"""
-
-        return self.count(self.join_by(*args, **params))
-
 
-    def select_whereclause(self, whereclause=None, params=None, **kwargs): #pragma: no cover
-        """DEPRECATED.  use query.filter(whereclause).all()"""
-
-        q = self.filter(whereclause)._legacy_select_kwargs(**kwargs)
-        if params is not None:
-            q = q.params(params)
-        return list(q)
+class _QueryEntity(object):
+    """represent an entity column returned within a Query result."""
 
-    def _legacy_select_from(self, from_obj):
-        q = self._clone()
-        if len(from_obj) > 1:
-            raise exceptions.ArgumentError("Multiple-entry from_obj parameter no longer supported")
-        q._from_obj = from_obj[0]
-        return q
+    def __new__(cls, *args, **kwargs):
+        if cls is _QueryEntity:
+            entity = args[1]
+            if _is_mapped_class(entity):
+                cls = _MapperEntity
+            else:
+                cls = _ColumnEntity
+        return object.__new__(cls)
 
-    def _legacy_select_kwargs(self, **kwargs): #pragma: no cover
-        q = self
-        if "order_by" in kwargs and kwargs['order_by']:
-            q = q.order_by(kwargs['order_by'])
-        if "group_by" in kwargs:
-            q = q.group_by(kwargs['group_by'])
-        if "from_obj" in kwargs:
-            q = q._legacy_select_from(kwargs['from_obj'])
-        if "lockmode" in kwargs:
-            q = q.with_lockmode(kwargs['lockmode'])
-        if "distinct" in kwargs:
-            q = q.distinct()
-        if "limit" in kwargs:
-            q = q.limit(kwargs['limit'])
-        if "offset" in kwargs:
-            q = q.offset(kwargs['offset'])
+    def _clone(self):
+        q = self.__class__.__new__(self.__class__)
+        q.__dict__ = self.__dict__.copy()
         return q
 
+class _MapperEntity(_QueryEntity):
+    """mapper/class/AliasedClass entity"""
 
-    def get_by(self, *args, **params): #pragma: no cover
-        """DEPRECATED.  use query.filter_by(\**params).first()"""
-
-        ret = self._extension.get_by(self, *args, **params)
-        if ret is not mapper.EXT_CONTINUE:
-            return ret
-
-        return self._legacy_filter_by(*args, **params).first()
-
-    def select_by(self, *args, **params): #pragma: no cover
-        """DEPRECATED. use use query.filter_by(\**params).all()."""
-
-        ret = self._extension.select_by(self, *args, **params)
-        if ret is not mapper.EXT_CONTINUE:
-            return ret
-
-        return self._legacy_filter_by(*args, **params).list()
-
-    def join_by(self, *args, **params): #pragma: no cover
-        """DEPRECATED. use join() to construct joins based on attribute names."""
+    def __init__(self, query, entity, entity_name=None):
+        self.primary_entity = not query._entities
+        query._entities.append(self)
 
-        return self._legacy_join_by(args, params, start=self._joinpoint)
+        self.entities = [entity]
+        self.entity_zero = entity
+        self.entity_name = entity_name
 
-    def _build_select(self, arg=None, params=None, **kwargs): #pragma: no cover
-        if isinstance(arg, sql.FromClause) and arg.supports_execution():
-            return self.from_statement(arg)
-        elif arg is not None:
-            return self.filter(arg)._legacy_select_kwargs(**kwargs)
+    def setup_entity(self, entity, mapper, adapter, from_obj, is_aliased_class, with_polymorphic):
+        self.mapper = mapper
+        self.extension = self.mapper.extension
+        self.adapter = adapter
+        self.selectable  = from_obj
+        self._with_polymorphic = with_polymorphic
+        self.is_aliased_class = is_aliased_class
+        if is_aliased_class:
+            self.path_entity = self.entity = self.entity_zero = entity
         else:
-            return self._legacy_select_kwargs(**kwargs)
-
-    def selectfirst(self, arg=None, **kwargs): #pragma: no cover
-        """DEPRECATED.  use query.filter(whereclause).first()"""
-
-        return self._build_select(arg, **kwargs).first()
-
-    def selectone(self, arg=None, **kwargs): #pragma: no cover
-        """DEPRECATED.  use query.filter(whereclause).one()"""
-
-        return self._build_select(arg, **kwargs).one()
-
-    def select(self, arg=None, **kwargs): #pragma: no cover
-        """DEPRECATED.  use query.filter(whereclause).all(), or query.from_statement(statement).all()"""
+            self.path_entity = mapper.base_mapper
+            self.entity = self.entity_zero = mapper
 
-        ret = self._extension.select(self, arg=arg, **kwargs)
-        if ret is not mapper.EXT_CONTINUE:
-            return ret
-        return self._build_select(arg, **kwargs).all()
-
-    def execute(self, clauseelement, params=None, *args, **kwargs): #pragma: no cover
-        """DEPRECATED.  use query.from_statement().all()"""
-
-        return self._select_statement(clauseelement, params, **kwargs)
-
-    def select_statement(self, statement, **params): #pragma: no cover
-        """DEPRECATED.  Use query.from_statement(statement)"""
-
-        return self._select_statement(statement, params)
-
-    def select_text(self, text, **params): #pragma: no cover
-        """DEPRECATED.  Use query.from_statement(statement)"""
+    def set_with_polymorphic(self, query, cls_or_mappers, selectable):
+        if cls_or_mappers is None:
+            query._reset_polymorphic_adapter(self.mapper)
+            return
 
-        return self._select_statement(text, params)
-
-    def _select_statement(self, statement, params=None, **kwargs): #pragma: no cover
-        q = self.from_statement(statement)
-        if params is not None:
-            q = q.params(params)
-        q.__get_options(**kwargs)
-        return list(q)
-
-    def join_to(self, key): #pragma: no cover
-        """DEPRECATED. use join() to create joins based on property names."""
-
-        [keys, p] = self._locate_prop(key)
-        return self.join_via(keys)
+        mappers, from_obj = self.mapper._with_polymorphic_args(cls_or_mappers, selectable)
+        self._with_polymorphic = mappers
 
-    def join_via(self, keys): #pragma: no cover
-        """DEPRECATED. use join() to create joins based on property names."""
+        # TODO: do the wrapped thing here too so that with_polymorphic() can be
+        # applied to aliases
+        if not self.is_aliased_class:
+            self.selectable = from_obj
+            self.adapter = query._get_polymorphic_adapter(self, from_obj)
 
-        mapper = self._joinpoint
-        clause = None
-        for key in keys:
-            prop = mapper.get_property(key, resolve_synonyms=True)
-            if clause is None:
-                clause = prop._get_join(mapper)
-            else:
-                clause &= prop._get_join(mapper)
-            mapper = prop.mapper
+    def corresponds_to(self, entity):
+        if _is_aliased_class(entity):
+            return entity is self.path_entity
+        else:
+            return entity.base_mapper is self.path_entity
 
-        return clause
+    def _get_entity_clauses(self, query, context):
 
-    def _legacy_join_by(self, args, params, start=None): #pragma: no cover
-        import properties
+        adapter = None
+        if not self.is_aliased_class and query._polymorphic_adapters:
+            for mapper in self.mapper.iterate_to_root():
+                adapter = query._polymorphic_adapters.get(mapper.mapped_table, None)
+                if adapter:
+                    break
 
-        clause = None
-        for arg in args:
-            if clause is None:
-                clause = arg
-            else:
-                clause &= arg
+        if not adapter and self.adapter:
+            adapter = self.adapter
 
-        for key, value in params.iteritems():
-            (keys, prop) = self._locate_prop(key, start=start)
-            if isinstance(prop, properties.PropertyLoader):
-                c = prop.compare(operators.eq, value) & self.join_via(keys[:-1])
+        if adapter:
+            if query._from_obj_alias:
+                ret = adapter.wrap(query._from_obj_alias)
             else:
-                c = prop.compare(operators.eq, value) & self.join_via(keys)
-            if clause is None:
-                clause =  c
-            else:
-                clause &= c
-        return clause
-
-    def _locate_prop(self, key, start=None): #pragma: no cover
-        import properties
-        keys = []
-        seen = util.Set()
-        def search_for_prop(mapper_):
-            if mapper_ in seen:
-                return None
-            seen.add(mapper_)
-
-            prop = mapper_.get_property(key, resolve_synonyms=True, raiseerr=False)
-            if prop is not None:
-                if isinstance(prop, properties.PropertyLoader):
-                    keys.insert(0, prop.key)
-                return prop
-            else:
-                for prop in mapper_.iterate_properties:
-                    if not isinstance(prop, properties.PropertyLoader):
-                        continue
-                    x = search_for_prop(prop.mapper)
-                    if x:
-                        keys.insert(0, prop.key)
-                        return x
-                else:
-                    return None
-        p = search_for_prop(start or self.mapper)
-        if p is None:
-            raise exceptions.InvalidRequestError("Can't locate property named '%s'" % key)
-        return [keys, p]
-
-    def selectfirst_by(self, *args, **params): #pragma: no cover
-        """DEPRECATED. Use query.filter_by(\**kwargs).first()"""
-
-        return self._legacy_filter_by(*args, **params).first()
-
-    def selectone_by(self, *args, **params): #pragma: no cover
-        """DEPRECATED. Use query.filter_by(\**kwargs).one()"""
-
-        return self._legacy_filter_by(*args, **params).one()
-
-    for deprecated_method in ('list', 'scalar', 'count_by',
-                              'select_whereclause', 'get_by', 'select_by',
-                              'join_by', 'selectfirst', 'selectone', 'select',
-                              'execute', 'select_statement', 'select_text',
-                              'join_to', 'join_via', 'selectfirst_by',
-                              'selectone_by', 'apply_max', 'apply_min',
-                              'apply_avg', 'apply_sum'):
-        locals()[deprecated_method] = \
-            util.deprecated(None, False)(locals()[deprecated_method])
-
-class _QueryEntity(object):
-    """represent an entity column returned within a Query result."""
-    
-    def legacy_guess_type(self, e):
-        if isinstance(e, type):
-            return _MapperEntity(mapper=mapper.class_mapper(e))
-        elif isinstance(e, mapper.Mapper):
-            return _MapperEntity(mapper=e)
+                ret = adapter
         else:
-            return _ColumnEntity(column=e)
-    legacy_guess_type=classmethod(legacy_guess_type)
+            ret = query._from_obj_alias
 
-class _MapperEntity(_QueryEntity):
-    """entity column corresponding to mapped ORM instances."""
-    
-    def __init__(self, mapper, alias=None, id=None):
-        self.mapper = mapper
-        self.alias = alias
-        self.alias_id = id
-
-    def _get_entity_clauses(self, query):
-        if self.alias:
-            return self.alias
-        elif self.alias_id:
-            try:
-                return query._alias_ids[self.alias_id][0]
-            except KeyError:
-                raise exceptions.InvalidRequestError("Query has no alias identified by '%s'" % self.alias_id)
-
-        l = query._alias_ids.get(self.mapper)
-        if l:
-            if len(l) > 1:
-                raise exceptions.InvalidRequestError("Ambiguous join for entity '%s'; specify id=<someid> to query.join()/query.add_entity()" % str(self.mapper))
-            return l[0]
-        else:
-            return None
-            
-    def row_processor(self, query, context, single_entity):
-        clauses = self._get_entity_clauses(query) 
-        if clauses:
-            def proc(context, row):
-                return self.mapper._instance(context, clauses.row_decorator(row), None)
-        else:
-            def proc(context, row):
-                return self.mapper._instance(context, row, None)
-            
-        return proc
-    
-    def setup_context(self, query, context):
-        clauses = self._get_entity_clauses(query)
-        for value in self.mapper.iterate_properties:
-            context.exec_with_path(self.mapper, value.key, value.setup, context, parentclauses=clauses)
+        return ret
 
-    def __str__(self):
-        return str(self.mapper)
+    def row_processor(self, query, context, custom_rows):
+        adapter = self._get_entity_clauses(query, context)
 
-class _PrimaryMapperEntity(_MapperEntity):
-    """entity column corresponding to the 'primary' (first) mapped ORM instance."""
+        if context.adapter and adapter:
+            adapter = adapter.wrap(context.adapter)
+        elif not adapter:
+            adapter = context.adapter
 
-    def row_processor(self, query, context, single_entity):
-        if single_entity and 'append_result' in context.extension.methods:    
+        # polymorphic mappers which have concrete tables in their hierarchy usually
+        # require row aliasing unconditionally.
+        if not adapter and self.mapper._requires_row_aliasing:
+            adapter = sql_util.ColumnAdapter(self.selectable, self.mapper._equivalent_columns)
+
+        if self.primary_entity:
+            _instance = self.mapper._instance_processor(context, (self.path_entity,), adapter, 
+                extension=self.extension, only_load_props=query._only_load_props, refresh_instance=context.refresh_instance
+            )
+        else:
+            _instance = self.mapper._instance_processor(context, (self.path_entity,), adapter)
+        
+        if custom_rows:
             def main(context, row, result):
-                if context.row_adapter:
-                    row = context.row_adapter(row)
-                self.mapper._instance(context, row, result,
-                    extension=context.extension, only_load_props=context.only_load_props, refresh_instance=context.refresh_instance
-                )
-        elif context.row_adapter:
-            def main(context, row):
-                return self.mapper._instance(context, context.row_adapter(row), None,
-                    extension=context.extension, only_load_props=context.only_load_props, refresh_instance=context.refresh_instance
-                )
+                _instance(row, result)
         else:
             def main(context, row):
-                return self.mapper._instance(context, row, None,
-                    extension=context.extension, only_load_props=context.only_load_props, refresh_instance=context.refresh_instance
-                )
+                return _instance(row, None)
         
-        return main
+        if self.is_aliased_class:
+            entname = self.entity._sa_label_name
+        else:
+            entname = self.mapper.class_.__name__
+            
+        return main, entname
 
     def setup_context(self, query, context):
         # if single-table inheritance mapper, add "typecol IN (polymorphic)" criterion so
         # that we only load the appropriate types
         if self.mapper.single and self.mapper.inherits is not None and self.mapper.polymorphic_on is not None and self.mapper.polymorphic_identity is not None:
             context.whereclause = sql.and_(context.whereclause, self.mapper.polymorphic_on.in_([m.polymorphic_identity for m in self.mapper.polymorphic_iterator()]))
-        
-        if context.order_by is False:
-            if self.mapper.order_by:
-                context.order_by = self.mapper.order_by
-            elif context.from_clause.default_order_by():
-                context.order_by = context.from_clause.default_order_by()
-                
-        for value in self.mapper._iterate_polymorphic_properties(query._with_polymorphic, context.from_clause):
+
+        context.froms.append(self.selectable)
+
+        adapter = self._get_entity_clauses(query, context)
+
+        if self.primary_entity:
+            if context.order_by is False:
+                # the "default" ORDER BY use case applies only to "mapper zero".  the "from clause" default should
+                # go away in 0.5 (or...maybe 0.6).
+                if self.mapper.order_by:
+                    context.order_by = self.mapper.order_by
+                elif context.from_clause:
+                    context.order_by = context.from_clause.default_order_by()
+                else:
+                    context.order_by = self.selectable.default_order_by()
+            if context.order_by and adapter:
+                context.order_by = adapter.adapt_list(util.to_list(context.order_by))
+
+        for value in self.mapper._iterate_polymorphic_properties(self._with_polymorphic):
             if query._only_load_props and value.key not in query._only_load_props:
                 continue
-            context.exec_with_path(self.mapper, value.key, value.setup, context, only_load_props=query._only_load_props)
+            value.setup(context, self, (self.path_entity,), adapter, only_load_props=query._only_load_props, column_collection=context.primary_columns)
+
+    def __str__(self):
+        return str(self.mapper)
+
 
 class _ColumnEntity(_QueryEntity):
-    """entity column corresponding to Table or selectable columns."""
+    """Column/expression based entity."""
+
+    def __init__(self, query, column, entity_name=None):
+        if isinstance(column, expression.FromClause) and not isinstance(column, expression.ColumnElement):
+            for c in column.c:
+                _ColumnEntity(query, c)
+            return
+            
+        query._entities.append(self)
 
-    def __init__(self, column, id):
         if isinstance(column, basestring):
             column = sql.literal_column(column)
-            
-        if column and isinstance(column, sql.ColumnElement) and not hasattr(column, '_label'):
+        elif isinstance(column, (attributes.QueryableAttribute, mapper.Mapper._CompileOnAttr)):
+            column = column.__clause_element__()
+        elif not isinstance(column, sql.ColumnElement):
+            raise sa_exc.InvalidRequestError("Invalid column expression '%r'" % column)
+
+        if not hasattr(column, '_label'):
             column = column.label(None)
+
         self.column = column
-        self.alias_id = id
+        self.entity_name = None
+        self.froms = util.Set()
+        self.entities = util.OrderedSet([elem._annotations['parententity'] for elem in visitors.iterate(column, {}) if 'parententity' in elem._annotations])
+        if self.entities:
+            self.entity_zero = list(self.entities)[0]
+        else:
+            self.entity_zero = None
+        
+    def setup_entity(self, entity, mapper, adapter, from_obj, is_aliased_class, with_polymorphic):
+        self.froms.add(from_obj)
 
     def __resolve_expr_against_query_aliases(self, query, expr, context):
-        if not query._alias_ids:
-            return expr
-            
-        if ('_ColumnEntity', expr) in context.attributes:
-            return context.attributes[('_ColumnEntity', expr)]
-        
-        if self.alias_id:
-            try:
-                aliases = query._alias_ids[self.alias_id][0]
-            except KeyError:
-                raise exceptions.InvalidRequestError("Query has no alias identified by '%s'" % self.alias_id)
+        return query._adapt_clause(expr, False, True)
 
-            def _locate_aliased(element):
-                if element in query._alias_ids:
-                    return aliases
-        else:
-            def _locate_aliased(element):
-                if element in query._alias_ids:
-                    aliases = query._alias_ids[element]
-                    if len(aliases) > 1:
-                        raise exceptions.InvalidRequestError("Ambiguous join for entity '%s'; specify id=<someid> to query.join()/query.add_column(), or use the aliased() function to use explicit class aliases." % expr)
-                    return aliases[0]
-                return None
-
-        class Adapter(visitors.ClauseVisitor):
-            def before_clone(self, element):
-                if isinstance(element, expression.FromClause):
-                    alias = _locate_aliased(element)
-                    if alias:
-                        return alias.alias
-                
-                if hasattr(element, 'table'):
-                    alias = _locate_aliased(element.table)
-                    if alias:
-                        return alias.aliased_column(element)
+    def row_processor(self, query, context, custom_rows):
+        column = self.__resolve_expr_against_query_aliases(query, self.column, context)
 
-                return None
+        if context.adapter:
+            column = context.adapter.columns[column]
 
-        context.attributes[('_ColumnEntity', expr)] = ret = Adapter().traverse(expr, clone=True)
-        return ret
-        
-    def row_processor(self, query, context, single_entity):
-        column = self.__resolve_expr_against_query_aliases(query, self.column, context)
         def proc(context, row):
             return row[column]
-        return proc
-    
+            
+        return (proc, getattr(column, 'name', None))
+
     def setup_context(self, query, context):
         column = self.__resolve_expr_against_query_aliases(query, self.column, context)
-        context.secondary_columns.append(column)
-    
+        context.froms += list(self.froms)
+        context.primary_columns.append(column)
+
     def __str__(self):
         return str(self.column)
 
-        
-Query.logger = logging.class_logger(Query)
+Query.logger = log.class_logger(Query)
 
 class QueryContext(object):
     def __init__(self, query):
+
+        if query._statement:
+            if isinstance(query._statement, expression._SelectBaseMixin) and not query._statement.use_labels:
+                self.statement = query._statement.apply_labels()
+            else:
+                self.statement = query._statement
+        else:
+            self.statement = None
+            self.from_clause = query._from_obj
+            self.whereclause = query._criterion
+            self.order_by = query._order_by
+            if self.order_by:
+                self.order_by = [expression._literal_as_text(o) for o in util.to_list(self.order_by)]
+            
         self.query = query
-        self.mapper = query.mapper
         self.session = query.session
-        self.extension = query._extension
-        self.statement = None
-        self.row_adapter = None
         self.populate_existing = query._populate_existing
         self.version_check = query._version_check
-        self.only_load_props = query._only_load_props
         self.refresh_instance = query._refresh_instance
-        self.path = ()
         self.primary_columns = []
         self.secondary_columns = []
         self.eager_order_by = []
-        self.eager_joins = None
+
+        self.eager_joins = {}
+        self.froms = []
+        self.adapter = None
+
         self.options = query._with_options
         self.attributes = query._attributes.copy()
 
-    def exec_with_path(self, mapper, propkey, fn, *args, **kwargs):
-        oldpath = self.path
-        self.path += (mapper.base_mapper, propkey)
-        try:
-            return fn(*args, **kwargs)
-        finally:
-            self.path = oldpath
+class AliasOption(interfaces.MapperOption):
 
+    def __init__(self, alias):
+        self.alias = alias
 
+    def process_query(self, query):
+        if isinstance(self.alias, basestring):
+            alias = query._mapper_zero().mapped_table.alias(self.alias)
+        else:
+            alias = self.alias
+        query._from_obj_alias = sql_util.ColumnAdapter(alias)
+    
 
 _runid = 1L
 _id_lock = util.threading.Lock()
index 479b2f7374983c6781edde74e4fad80db1e32db2..c1d3db9f1d34d66df1af60c0758bc89c3473f14f 100644 (file)
@@ -1,8 +1,17 @@
+# scoping.py
+# Copyright (C) the SQLAlchemy authors and contributors
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+import inspect
+import types
+
+import sqlalchemy.exceptions as sa_exc
 from sqlalchemy.util import ScopedRegistry, to_list, get_cls_kwargs
-from sqlalchemy.orm import MapperExtension, EXT_CONTINUE, object_session, class_mapper
+from sqlalchemy.orm import MapperExtension, EXT_CONTINUE, object_session, \
+     class_mapper
 from sqlalchemy.orm.session import Session
-from sqlalchemy import exceptions
-import types
 
 __all__ = ['ScopedSession']
 
@@ -33,7 +42,7 @@ class ScopedSession(object):
             scope = kwargs.pop('scope', False)
             if scope is not None:
                 if self.registry.has():
-                    raise exceptions.InvalidRequestError("Scoped session is already present; no new arguments may be specified.")
+                    raise sa_exc.InvalidRequestError("Scoped session is already present; no new arguments may be specified.")
                 else:
                     sess = self.session_factory(**kwargs)
                     self.registry.set(sess)
@@ -53,7 +62,7 @@ class ScopedSession(object):
         
         from sqlalchemy.orm import mapper
         
-        extension_args = dict([(arg,kwargs.pop(arg))
+        extension_args = dict([(arg, kwargs.pop(arg))
                                for arg in get_cls_kwargs(_ScopedExt)
                                if arg in kwargs])
         
@@ -110,10 +119,10 @@ for prop in ('bind', 'dirty', 'deleted', 'new', 'identity_map'):
     setattr(ScopedSession, prop, makeprop(prop))
 
 def clslevel(name):
-    def do(cls, *args,**kwargs):
+    def do(cls, *args, **kwargs):
         return getattr(Session, name)(*args, **kwargs)
     return classmethod(do)
-for prop in ('close_all','object_session', 'identity_key'):
+for prop in ('close_all', 'object_session', 'identity_key'):
     setattr(ScopedSession, prop, clslevel(prop))
     
 class _ScopedExt(MapperExtension):
@@ -121,6 +130,7 @@ class _ScopedExt(MapperExtension):
         self.context = context
         self.validate = validate
         self.save_on_init = save_on_init
+        self.set_kwargs_on_init = None
     
     def validating(self):
         return _ScopedExt(self.context, validate=True)
@@ -128,37 +138,49 @@ class _ScopedExt(MapperExtension):
     def configure(self, **kwargs):
         return _ScopedExt(self.context, **kwargs)
     
-    def get_session(self):
-        return self.context.registry()
-
     def instrument_class(self, mapper, class_):
         class query(object):
             def __getattr__(s, key):
                 return getattr(self.context.registry().query(class_), key)
             def __call__(s):
                 return self.context.registry().query(class_)
-
+            def __get__(self, instance, cls):
+                return self
+                
         if not 'query' in class_.__dict__: 
             class_.query = query()
-        
+
+        if self.set_kwargs_on_init is None:
+            self.set_kwargs_on_init = class_.__init__ is object.__init__
+        if self.set_kwargs_on_init:
+            def __init__(self, **kwargs):
+                pass
+            class_.__init__ = __init__
+
     def init_instance(self, mapper, class_, oldinit, instance, args, kwargs):
         if self.save_on_init:
             entity_name = kwargs.pop('_sa_entity_name', None)
             session = kwargs.pop('_sa_session', None)
-        if not isinstance(oldinit, types.MethodType):
+
+        if self.set_kwargs_on_init:
             for key, value in kwargs.items():
                 if self.validate:
-                    if not mapper.get_property(key, resolve_synonyms=False, raiseerr=False):
-                        raise exceptions.ArgumentError("Invalid __init__ argument: '%s'" % key)
+                    if not mapper.get_property(key, resolve_synonyms=False,
+                                               raiseerr=False):
+                        raise sa_exc.ArgumentError(
+                            "Invalid __init__ argument: '%s'" % key)
                 setattr(instance, key, value)
             kwargs.clear()
+
         if self.save_on_init:
             session = session or self.context.registry()
-            session._save_impl(instance, entity_name=entity_name)
+            session._save_without_cascade(instance, entity_name=entity_name)
         return EXT_CONTINUE
 
     def init_failed(self, mapper, class_, oldinit, instance, args, kwargs):
-        object_session(instance).expunge(instance)
+        sess = object_session(instance)
+        if sess:
+            sess.expunge(instance)
         return EXT_CONTINUE
 
     def dispose_class(self, mapper, class_):
index 57f23ace295b5936f9400a0b5b7c5496dfec6d2f..68a3aed68af7144592a7d840abec523c798dda36 100644 (file)
@@ -6,18 +6,24 @@
 
 """Provides the Session class and related utilities."""
 
-
 import weakref
-from sqlalchemy import util, exceptions, sql, engine
-from sqlalchemy.orm import unitofwork, query, attributes, util as mapperutil
-from sqlalchemy.orm.mapper import object_mapper as _object_mapper
-from sqlalchemy.orm.mapper import class_mapper as _class_mapper
-from sqlalchemy.orm.mapper import Mapper
 
+import sqlalchemy.exceptions as sa_exc
+import sqlalchemy.orm.attributes
+from sqlalchemy import util, sql, engine
+from sqlalchemy.sql import util as sql_util, expression
+from sqlalchemy.orm import exc, unitofwork, query, attributes, \
+     util as mapperutil, SessionExtension
+from sqlalchemy.orm.util import object_mapper as _object_mapper
+from sqlalchemy.orm.util import class_mapper as _class_mapper
+from sqlalchemy.orm.util import _state_mapper, _state_has_identity, _class_to_mapper
+from sqlalchemy.orm.mapper import Mapper
+from sqlalchemy.orm.unitofwork import UOWTransaction
+from sqlalchemy.orm import identity
 
 __all__ = ['Session', 'SessionTransaction', 'SessionExtension']
 
-def sessionmaker(bind=None, class_=None, autoflush=True, transactional=True, **kwargs):
+def sessionmaker(bind=None, class_=None, autoflush=True, autocommit=False, autoexpire=True, **kwargs):
     """Generate a custom-configured [sqlalchemy.orm.session#Session] class.
 
     The returned object is a subclass of ``Session``, which, when instantiated with no
@@ -54,20 +60,111 @@ def sessionmaker(bind=None, class_=None, autoflush=True, transactional=True, **k
 
         sess = Session()
 
-    The function features a single keyword argument of its own, `class_`, which
-    may be used to specify an alternate class other than ``sqlalchemy.orm.session.Session``
-    which should be used by the returned class.  All other keyword arguments sent to
-    `sessionmaker()` are passed through to the instantiated `Session()` object.
-    """
+    Options:
+    
+        autocommit
+            Defaults to ``False``. When ``True``, the ``Session`` does not keep a
+            persistent transaction running, and will acquire connections from the engine
+            on an as-needed basis, returning them immediately after their use. Flushes
+            will begin and commit (or possibly rollback) their own transaction if no
+            transaction is present. When using this mode, the `session.begin()` method
+            may be used to begin a transaction explicitly.
+        
+            Leaving it on its default value of ``False`` means that the ``Session`` will
+            acquire a connection and begin a transaction the first time it is used, which
+            it will maintain persistently until ``rollback()``, ``commit()``, or
+            ``close()`` is called. When the transaction is released by any of these
+            methods, the ``Session`` is ready for the next usage, which will again acquire
+            and maintain a new connection/transaction.
+        
+        autoexpire
+            When ``True``, all instances will be fully expired after each ``rollback()``
+            and after each ``commit()``, so that all attribute/object access subsequent
+            to a completed transaction will load from the most recent database state.
+        
+        autoflush
+            When ``True``, all query operations will issue a ``flush()`` call to this
+            ``Session`` before proceeding. This is a convenience feature so that
+            ``flush()`` need not be called repeatedly in order for database queries to
+            retrieve results. It's typical that ``autoflush`` is used in conjunction with
+            ``autocommit=False``.  In this scenario, explicit calls to ``flush()`` are rarely
+            needed; you usually only need to call ``commit()`` (which flushes) to finalize 
+            changes.
+
+        bind
+            An optional ``Engine`` or ``Connection`` to which this ``Session`` should be
+            bound. When specified, all SQL operations performed by this session will
+            execute via this connectable.
+
+        binds
+            An optional dictionary, which contains more granular "bind" information than
+            the ``bind`` parameter provides. This dictionary can map individual ``Table``
+            instances as well as ``Mapper`` instances to individual ``Engine`` or
+            ``Connection`` objects. Operations which proceed relative to a particular
+            ``Mapper`` will consult this dictionary for the direct ``Mapper`` instance as
+            well as the mapper's ``mapped_table`` attribute in order to locate an
+            connectable to use. The full resolution is described in the ``get_bind()``
+            method of ``Session``. Usage looks like::
+
+                sess = Session(binds={
+                    SomeMappedClass : create_engine('postgres://engine1'),
+                    somemapper : create_engine('postgres://engine2'),
+                    some_table : create_engine('postgres://engine3'),
+                })
+
+            Also see the ``bind_mapper()`` and ``bind_table()`` methods.
+
+        \class_
+            Specify an alternate class other than ``sqlalchemy.orm.session.Session``
+            which should be used by the returned class.  This is the only argument 
+            that is local to the ``sessionmaker()`` function, and is not sent
+            directly to the constructor for ``Session``.
 
+        echo_uow
+            When ``True``, configure Python logging to dump all unit-of-work
+            transactions. This is the equivalent of
+            ``logging.getLogger('sqlalchemy.orm.unitofwork').setLevel(logging.DEBUG)``.
+
+        extension
+            An optional [sqlalchemy.orm.session#SessionExtension] instance, which will receive
+            pre- and post- commit and flush events, as well as a post-rollback event.  User-
+            defined code may be placed within these hooks using a user-defined subclass
+            of ``SessionExtension``.
+
+        twophase
+            When ``True``, all transactions will be started using
+            [sqlalchemy.engine_TwoPhaseTransaction]. During a ``commit()``, after
+            ``flush()`` has been issued for all attached databases, the ``prepare()``
+            method on each database's ``TwoPhaseTransaction`` will be called. This allows
+            each database to roll back the entire transaction, before each transaction is
+            committed.
+
+        weak_identity_map
+            When set to the default value of ``False``, a weak-referencing map is used;
+            instances which are not externally referenced will be garbage collected
+            immediately. For dereferenced instances which have pending changes present,
+            the attribute management system will create a temporary strong-reference to
+            the object which lasts until the changes are flushed to the database, at which
+            point it's again dereferenced. Alternatively, when using the value ``True``,
+            the identity map uses a regular Python dictionary to store instances. The
+            session will maintain all instances present until they are removed using
+            expunge(), clear(), or purge().
+    
+    """
+    
+    if 'transactional' in kwargs:
+        util.warn_deprecated("The 'transactional' argument to sessionmaker() is deprecated; use autocommit=True|False instead.")
+        autocommit = not kwargs.pop('transactional')
+        
     kwargs['bind'] = bind
     kwargs['autoflush'] = autoflush
-    kwargs['transactional'] = transactional
+    kwargs['autocommit'] = autocommit
+    kwargs['autoexpire'] = autoexpire
 
     if class_ is None:
         class_ = Session
 
-    class Sess(class_):
+    class Sess(object):
         def __init__(self, **local_kwargs):
             for k in kwargs:
                 local_kwargs.setdefault(k, kwargs[k])
@@ -83,57 +180,9 @@ def sessionmaker(bind=None, class_=None, autoflush=True, transactional=True, **k
 
             kwargs.update(new_kwargs)
         configure = classmethod(configure)
+    s = type.__new__(type, "Session", (Sess, class_), {})
+    return s
 
-    return Sess
-
-class SessionExtension(object):
-    """An extension hook object for Sessions.  Subclasses may be installed into a Session
-    (or sessionmaker) using the ``extension`` keyword argument.
-    """
-
-    def before_commit(self, session):
-        """Execute right before commit is called.
-
-        Note that this may not be per-flush if a longer running transaction is ongoing."""
-
-    def after_commit(self, session):
-        """Execute after a commit has occured.
-
-        Note that this may not be per-flush if a longer running transaction is ongoing."""
-
-    def after_rollback(self, session):
-        """Execute after a rollback has occured.
-
-        Note that this may not be per-flush if a longer running transaction is ongoing."""
-
-    def before_flush(self, session, flush_context, instances):
-        """Execute before flush process has started.
-
-        `instances` is an optional list of objects which were passed to the ``flush()``
-        method.
-        """
-
-    def after_flush(self, session, flush_context):
-        """Execute after flush has completed, but before commit has been called.
-
-        Note that the session's state is still in pre-flush, i.e. 'new', 'dirty',
-        and 'deleted' lists still show pre-flush state as well as the history
-        settings on instance attributes."""
-
-    def after_flush_postexec(self, session, flush_context):
-        """Execute after flush has completed, and after the post-exec state occurs.
-
-        This will be when the 'new', 'dirty', and 'deleted' lists are in their final
-        state.  An actual commit() may or may not have occured, depending on whether or not
-        the flush started its own transaction or participated in a larger transaction.
-        """
-    
-    def after_begin(self, session, transaction, connection):
-        """Execute after a transaction is begun on a connection
-        
-        `transaction` is the SessionTransaction. This method is called after an
-        engine level transaction is begun on a connection.
-        """
 
 class SessionTransaction(object):
     """Represents a Session-level Transaction.
@@ -157,59 +206,100 @@ class SessionTransaction(object):
         self.nested = nested
         self._active = True
         self._prepared = False
+        if not parent and nested:
+            raise sa_exc.InvalidRequestError("Can't start a SAVEPOINT transaction when no existing transaction is in progress")
+        self._take_snapshot()
 
-    is_active = property(lambda s: s.session is not None and s._active)
+    def is_active(self):
+        return self.session is not None and self._active
+    is_active = property(is_active)
     
     def _assert_is_active(self):
         self._assert_is_open()
         if not self._active:
-            raise exceptions.InvalidRequestError("The transaction is inactive due to a rollback in a subtransaction and should be closed")
+            raise sa_exc.InvalidRequestError("The transaction is inactive due to a rollback in a subtransaction.  Issue rollback() to cancel the transaction.")
 
     def _assert_is_open(self):
         if self.session is None:
-            raise exceptions.InvalidRequestError("The transaction is closed")
-
+            raise sa_exc.InvalidRequestError("The transaction is closed")
+    
+    def _is_transaction_boundary(self):
+        return self.nested or not self._parent
+    _is_transaction_boundary = property(_is_transaction_boundary)
+    
     def connection(self, bindkey, **kwargs):
         self._assert_is_active()
         engine = self.session.get_bind(bindkey, **kwargs)
-        return self.get_or_add(engine)
+        return self._connection_for_bind(engine)
 
-    def _begin(self, **kwargs):
+    def _begin(self, autoflush=True, nested=False):
         self._assert_is_active()
-        return SessionTransaction(self.session, self, **kwargs)
+        return SessionTransaction(self.session, self, autoflush=autoflush, nested=nested)
 
     def _iterate_parents(self, upto=None):
         if self._parent is upto:
             return (self,)
         else:
             if self._parent is None:
-                raise exceptions.InvalidRequestError("Transaction %s is not on the active transaction list" % upto)
+                raise sa_exc.InvalidRequestError("Transaction %s is not on the active transaction list" % upto)
             return (self,) + self._parent._iterate_parents(upto)
+    
+    def _take_snapshot(self):
+        if not self._is_transaction_boundary:
+            self._new = self._parent._new
+            self._deleted = self._parent._deleted
+            return
+        
+        if self.nested:
+            self.session.flush()
+            
+        if self.autoflush:
+            # TODO: the "dirty_states" assertion is expensive,
+            # so consider these assertions as temporary
+            # during development
+            assert not self.session._new
+            assert not self.session._deleted
+            assert not self.session._dirty_states
+        
+        self._new = weakref.WeakKeyDictionary()
+        self._deleted = weakref.WeakKeyDictionary()
+    
+    def _restore_snapshot(self):
+        assert self._is_transaction_boundary
+        
+        for s in util.Set(self._deleted).union(self.session._deleted):
+            self.session._update_impl(s)
+        
+        assert not self.session._deleted
+            
+        for s in util.Set(self._new).union(self.session._new):
+            self.session._expunge_state(s)
+        
+        for s in self.session.identity_map.all_states():
+            _expire_state(s, None)
+    
+    def _remove_snapshot(self):
+        assert self._is_transaction_boundary
 
-    def add(self, bind):
-        self._assert_is_active()
-        if self._parent is not None and not self.nested:
-            return self._parent.add(bind)
-
-        if bind.engine in self._connections:
-            raise exceptions.InvalidRequestError("Session already has a Connection associated for the given %sEngine" % (isinstance(bind, engine.Connection) and "Connection's " or ""))
-        return self.get_or_add(bind)
-
-    def get_or_add(self, bind):
+        if not self.nested and self.session.autoexpire:
+            for s in self.session.identity_map.all_states():
+                _expire_state(s, None)
+            
+    def _connection_for_bind(self, bind):
         self._assert_is_active()
         
         if bind in self._connections:
             return self._connections[bind][0]
         
-        if self._parent is not None:
-            conn = self._parent.get_or_add(bind)
+        if self._parent:
+            conn = self._parent._connection_for_bind(bind)
             if not self.nested:
                 return conn
         else:
             if isinstance(bind, engine.Connection):
                 conn = bind
                 if conn.engine in self._connections:
-                    raise exceptions.InvalidRequestError("Session already has a Connection associated for the given Connection's Engine")
+                    raise sa_exc.InvalidRequestError("Session already has a Connection associated for the given Connection's Engine")
             else:
                 conn = bind.contextual_connect()
 
@@ -227,9 +317,9 @@ class SessionTransaction(object):
 
     def prepare(self):
         if self._parent is not None or not self.session.twophase:
-            raise exceptions.InvalidRequestError("Only root two phase transactions of can be prepared")
+            raise sa_exc.InvalidRequestError("Only root two phase transactions of can be prepared")
         self._prepare_impl()
-        
+    
     def _prepare_impl(self):
         self._assert_is_active()
         if self.session.extension is not None and (self._parent is None or self.nested):
@@ -264,10 +354,12 @@ class SessionTransaction(object):
 
             if self.session.extension is not None:
                 self.session.extension.after_commit(self.session)
-
+            
+            self._remove_snapshot()
+                
         self.close()
         return self._parent
-
+    
     def rollback(self):
         self._assert_is_open()
         
@@ -291,6 +383,8 @@ class SessionTransaction(object):
         for t in util.Set(self._connections.values()):
             t[1].rollback()
 
+        self._restore_snapshot()
+
         if self.session.extension is not None:
             self.session.extension.after_rollback(self.session)
 
@@ -308,7 +402,7 @@ class SessionTransaction(object):
         self._deactivate()
         self.session = None
         self._connections = None
-
+    
     def __enter__(self):
         return self
 
@@ -356,9 +450,9 @@ class Session(object):
 
     * *Transient* - an instance that's not in a session, and is not saved to the database;
       i.e. it has no database identity. The only relationship such an object has to the ORM
-      is that its class has a `mapper()` associated with it.
+      is that its class has a ``mapper()`` associated with it.
 
-    * *Pending* - when you `save()` a transient instance, it becomes pending. It still
+    * *Pending* - when you ``add()`` a transient instance, it becomes pending. It still
       wasn't actually flushed to the database yet, but it will be when the next flush
       occurs.
 
@@ -372,108 +466,41 @@ class Session(object):
       they're detached, **except** they will not be able to issue any SQL in order to load
       collections or attributes which are not yet loaded, or were marked as "expired".
 
-    The session methods which control instance state include ``save()``, ``update()``,
-    ``save_or_update()``, ``delete()``, ``merge()``, and ``expunge()``.
+    The session methods which control instance state include ``add()``, ``delete()``, 
+    ``merge()``, and ``expunge()``.
 
-    The Session object is **not** threadsafe, particularly during flush operations.  A session
-    which is only read from (i.e. is never flushed) can be used by concurrent threads if it's
-    acceptable that some object instances may be loaded twice.
+    The Session object is generally **not** threadsafe.  A session which is set to ``autocommit``
+    and is only read from may be used by concurrent threads if it's acceptable that some object 
+    instances may be loaded twice.
 
     The typical pattern to managing Sessions in a multi-threaded environment is either to use
     mutexes to limit concurrent access to one thread at a time, or more commonly to establish
     a unique session for every thread, using a threadlocal variable.  SQLAlchemy provides
     a thread-managed Session adapter, provided by the [sqlalchemy.orm#scoped_session()] function.
+    
     """
-
-    def __init__(self, bind=None, autoflush=True, transactional=False, twophase=False, echo_uow=False, weak_identity_map=True, binds=None, extension=None):
+    def __init__(self, bind=None, autoflush=True, autoexpire=True, autocommit=False, twophase=False, echo_uow=False, weak_identity_map=True, binds=None, extension=None):
         """Construct a new Session.
         
-        A session is usually constructed using the [sqlalchemy.orm#create_session()] function, 
-        or its more "automated" variant [sqlalchemy.orm#sessionmaker()].
-
-        autoflush
-            When ``True``, all query operations will issue a ``flush()`` call to this
-            ``Session`` before proceeding. This is a convenience feature so that
-            ``flush()`` need not be called repeatedly in order for database queries to
-            retrieve results. It's typical that ``autoflush`` is used in conjunction with
-            ``transactional=True``, so that ``flush()`` is never called; you just call
-            ``commit()`` when changes are complete to finalize all changes to the
-            database.
-
-        bind
-            An optional ``Engine`` or ``Connection`` to which this ``Session`` should be
-            bound. When specified, all SQL operations performed by this session will
-            execute via this connectable.
-
-        binds
-            An optional dictionary, which contains more granular "bind" information than
-            the ``bind`` parameter provides. This dictionary can map individual ``Table``
-            instances as well as ``Mapper`` instances to individual ``Engine`` or
-            ``Connection`` objects. Operations which proceed relative to a particular
-            ``Mapper`` will consult this dictionary for the direct ``Mapper`` instance as
-            well as the mapper's ``mapped_table`` attribute in order to locate an
-            connectable to use. The full resolution is described in the ``get_bind()``
-            method of ``Session``. Usage looks like::
-
-                sess = Session(binds={
-                    SomeMappedClass : create_engine('postgres://engine1'),
-                    somemapper : create_engine('postgres://engine2'),
-                    some_table : create_engine('postgres://engine3'),
-                })
-
-            Also see the ``bind_mapper()`` and ``bind_table()`` methods.
-
-        echo_uow
-            When ``True``, configure Python logging to dump all unit-of-work
-            transactions. This is the equivalent of
-            ``logging.getLogger('sqlalchemy.orm.unitofwork').setLevel(logging.DEBUG)``.
-
-        extension
-            An optional [sqlalchemy.orm.session#SessionExtension] instance, which will receive
-            pre- and post- commit and flush events, as well as a post-rollback event.  User-
-            defined code may be placed within these hooks using a user-defined subclass
-            of ``SessionExtension``.
-
-        transactional
-            Set up this ``Session`` to automatically begin transactions. Setting this
-            flag to ``True`` is the rough equivalent of calling ``begin()`` after each
-            ``commit()`` operation, after each ``rollback()``, and after each
-            ``close()``. Basically, this has the effect that all session operations are
-            performed within the context of a transaction. Note that the ``begin()``
-            operation does not immediately utilize any connection resources; only when
-            connection resources are first required do they get allocated into a
-            transactional context.
-
-        twophase
-            When ``True``, all transactions will be started using
-            [sqlalchemy.engine_TwoPhaseTransaction]. During a ``commit()``, after
-            ``flush()`` has been issued for all attached databases, the ``prepare()``
-            method on each database's ``TwoPhaseTransaction`` will be called. This allows
-            each database to roll back the entire transaction, before each transaction is
-            committed.
-
-        weak_identity_map
-            When set to the default value of ``False``, a weak-referencing map is used;
-            instances which are not externally referenced will be garbage collected
-            immediately. For dereferenced instances which have pending changes present,
-            the attribute management system will create a temporary strong-reference to
-            the object which lasts until the changes are flushed to the database, at which
-            point it's again dereferenced. Alternatively, when using the value ``True``,
-            the identity map uses a regular Python dictionary to store instances. The
-            session will maintain all instances present until they are removed using
-            expunge(), clear(), or purge().
+        Arguments to ``Session`` are described using the [sqlalchemy.orm#sessionmaker()] function.
+        
         """
         self.echo_uow = echo_uow
-        self.weak_identity_map = weak_identity_map
-        self.uow = unitofwork.UnitOfWork(self)
-        self.identity_map = self.uow.identity_map
+        if weak_identity_map:
+            self._identity_cls = identity.WeakInstanceDict
+        else:
+            self._identity_cls = identity.StrongInstanceDict
+        self.identity_map = self._identity_cls()
 
+        self._new = {}   # InstanceState->object, strong refs object
+        self._deleted = {}  # same
         self.bind = bind
         self.__binds = {}
         self.transaction = None
         self.hash_key = id(self)
         self.autoflush = autoflush
-        self.transactional = transactional
+        self.autocommit = autocommit
+        self.autoexpire = autoexpire
         self.twophase = twophase
         self.extension = extension
         self._query_cls = query.Query
@@ -488,28 +515,59 @@ class Session(object):
                     for t in mapperortable._all_tables:
                         self.__binds[t] = value
 
-        if self.transactional:
+        if not self.autocommit:
             self.begin()
         _sessions[self.hash_key] = self
 
-    def begin(self, **kwargs):
-        """Begin a transaction on this Session."""
-
+    def begin(self, subtransactions=False, nested=False, _autoflush=True):
+        """Begin a transaction on this Session.
+        
+        If this Session is already within a transaction,
+        either a plain transaction or nested transaction,
+        an error is raised, unless ``subtransactions=True``
+        or ``nested=True`` is specified.
+        
+        The ``subtransactions=True`` flag indicates that
+        this ``begin()`` can create a subtransaction if a 
+        transaction is already in progress.  A subtransaction 
+        is a non-transactional, delimiting construct that 
+        allows matching begin()/commit() pairs to be nested 
+        together, with only the outermost begin/commit pair 
+        actually affecting transactional state.  When a rollback
+        is issued, the subtransaction will directly roll back 
+        the innermost real transaction, however each subtransaction 
+        still must be explicitly rolled back to maintain proper 
+        stacking of subtransactions.
+        
+        If no transaction is in progress,
+        then a real transaction is begun.  
+        
+        The ``nested`` flag begins a SAVEPOINT transaction
+        and is equivalent to calling ``begin_nested()``.
+        
+        """
         if self.transaction is not None:
-            self.transaction = self.transaction._begin(**kwargs)
+            if subtransactions or nested:
+                self.transaction = self.transaction._begin(nested=nested, autoflush=_autoflush)
+            else:
+                raise sa_exc.InvalidRequestError("A transaction is already begun.  Use subtransactions=True to allow subtransactions.")
         else:
-            self.transaction = SessionTransaction(self, **kwargs)
-        return self.transaction
-
-    create_transaction = begin
+            self.transaction = SessionTransaction(self, nested=nested, autoflush=_autoflush)
+        return self.transaction  # needed for __enter__/__exit__ hook
 
     def begin_nested(self):
         """Begin a `nested` transaction on this Session.
 
         This utilizes a ``SAVEPOINT`` transaction for databases
         which support this feature.
-        """
 
+        The nested transaction is a real transation, unlike
+        a "subtransaction" which corresponds to multiple
+        ``begin()`` calls.  The next ``rollback()`` or 
+        ``commit()`` call will operate upon this nested
+        transaction.
+        
+        """
         return self.begin(nested=True)
 
     def rollback(self):
@@ -517,42 +575,48 @@ class Session(object):
 
         If no transaction is in progress, this method is a
         pass-thru.
+        
+        This method rolls back the current transaction
+        or nested transaction regardless of subtransactions
+        being in effect.  All subtrasactions up to the 
+        first real transaction are closed.  Subtransactions 
+        occur when begin() is called mulitple times.
+        
         """
-
         if self.transaction is None:
             pass
         else:
             self.transaction.rollback()
-        # TODO: we can rollback attribute values.  however
-        # we would want to expand attributes.py to be able to save *two* rollback points, one to the
-        # last flush() and the other to when the object first entered the transaction.
-        # [ticket:705]
-        #attributes.rollback(*self.identity_map.values())
-        if self.transaction is None and self.transactional:
+        if self.transaction is None and not self.autocommit:
             self.begin()
 
     def commit(self):
-        """Commit the current transaction in progress.
+        """Flush any pending changes, and commit the current transaction 
+        in progress, assuming no subtransactions are in effect.
 
         If no transaction is in progress, this method raises
         an InvalidRequestError.
+        
+        If a subtransaction is in effect (which occurs when 
+        begin() is called multiple times), the subtransaction
+        will be closed, and the next call to ``commit()``
+        will operate on the enclosing transaction.
 
-        If the ``begin()`` method was called on this ``Session``
-        additional times subsequent to its first call,
-        ``commit()`` will not actually commit, and instead
-        pops an internal SessionTransaction off its internal stack
-        of transactions.  Only when the "root" SessionTransaction
-        is reached does an actual database-level commit occur.
-        """
+        For a session configured with autocommit=False, a new
+        transaction will be begun immediately after the commit,
+        but note that the newly begun transaction does *not* 
+        use any connection resources until the first SQL is 
+        actually emitted.
 
+        """
         if self.transaction is None:
-            if self.transactional:
+            if not self.autocommit:
                 self.begin()
             else:
-                raise exceptions.InvalidRequestError("No transaction is begun.")
+                raise sa_exc.InvalidRequestError("No transaction is begun.")
 
         self.transaction.commit()
-        if self.transaction is None and self.transactional:
+        if self.transaction is None and not self.autocommit:
             self.begin()
     
     def prepare(self):
@@ -565,10 +629,10 @@ class Session(object):
         not such, an InvalidRequestError is raised.
         """
         if self.transaction is None:
-            if self.transactional:
+            if not self.autocommit:
                 self.begin()
             else:
-                raise exceptions.InvalidRequestError("No transaction is begun.")
+                raise sa_exc.InvalidRequestError("No transaction is begun.")
 
         self.transaction.prepare()
 
@@ -594,7 +658,7 @@ class Session(object):
 
     def __connection(self, engine, **kwargs):
         if self.transaction is not None:
-            return self.transaction.get_or_add(engine)
+            return self.transaction._connection_for_bind(engine)
         else:
             return engine.contextual_connect(**kwargs)
 
@@ -620,6 +684,8 @@ class Session(object):
             the proper bind, in the case of ShardedSession.
             
         """
+        clause = expression._literal_as_text(clause)
+        
         engine = self.get_bind(mapper, clause=clause, instance=instance)
 
         return self.__connection(engine, close_with_result=True).execute(clause, params or {})
@@ -646,7 +712,7 @@ class Session(object):
         if self.transaction is not None:
             for transaction in self.transaction._iterate_parents():
                 transaction.close()
-        if self.transactional:
+        if not self.autocommit:
             # note this doesnt use any connection resources
             self.begin()
 
@@ -657,18 +723,24 @@ class Session(object):
             sess.close()
     close_all = classmethod(close_all)
 
-    def clear(self):
+    def expunge_all(self):
         """Remove all object instances from this ``Session``.
 
         This is equivalent to calling ``expunge()`` for all objects in
         this ``Session``.
         """
         
-        for instance in self:
-            self._unattach(instance)
-        self.uow = unitofwork.UnitOfWork(self)
-        self.identity_map = self.uow.identity_map
+        for state in self.identity_map.all_states() + list(self._new):
+            del state.session_id
 
+        self.identity_map = self._identity_cls()
+        self._new = {}
+        self._deleted = {}
+    clear = expunge_all
+    
+    # TODO: deprecate
+    #clear = util.deprecated()(expunge_all)
+    
     # TODO: need much more test coverage for bind_mapper() and similar !
 
     def bind_mapper(self, mapper, bind, entity_name=None):
@@ -713,79 +785,49 @@ class Session(object):
             
         """
         if mapper is None and clause is None:
-            if self.bind is not None:
+            if self.bind:
                 return self.bind
             else:
-                raise exceptions.UnboundExecutionError("This session is unbound to any Engine or Connection; specify a mapper to get_bind()")
+                raise sa_exc.UnboundExecutionError("This session is not bound to any Engine or Connection; specify a mapper to get_bind()")
 
-        elif len(self.__binds):
-            if mapper is not None:
-                if isinstance(mapper, type):
-                    mapper = _class_mapper(mapper)
+        elif self.__binds:
+            if mapper:
+                mapper = _class_to_mapper(mapper)
                 if mapper.base_mapper in self.__binds:
                     return self.__binds[mapper.base_mapper]
-                elif mapper.compile().mapped_table in self.__binds:
+                elif mapper.mapped_table in self.__binds:
                     return self.__binds[mapper.mapped_table]
-            if clause is not None:
-                for t in clause._table_iterator():
+            if clause:
+                for t in sql_util.find_tables(clause):
                     if t in self.__binds:
                         return self.__binds[t]
 
-        if self.bind is not None:
+        if self.bind:
             return self.bind
-        elif isinstance(clause, sql.expression.ClauseElement) and clause.bind is not None:
+        elif isinstance(clause, sql.expression.ClauseElement) and clause.bind:
             return clause.bind
-        elif mapper is None:
-            raise exceptions.UnboundExecutionError("Could not locate any mapper associated with SQL expression")
+        elif not mapper:
+            raise sa_exc.UnboundExecutionError("Could not locate any mapper associated with SQL expression")
         else:
-            if isinstance(mapper, type):
-                mapper = _class_mapper(mapper)
-            else:
-                mapper = mapper.compile()
+            mapper = _class_to_mapper(mapper)
             e = mapper.mapped_table.bind
             if e is None:
-                raise exceptions.UnboundExecutionError("Could not locate any Engine or Connection bound to mapper '%s'" % str(mapper))
+                raise sa_exc.UnboundExecutionError("Could not locate any Engine or Connection bound to mapper '%s'" % str(mapper))
             return e
 
-    def query(self, mapper_or_class, *addtl_entities, **kwargs):
-        """Return a new ``Query`` object corresponding to this ``Session`` and
-        the mapper, or the classes' primary mapper.
-
-        """
-        entity_name = kwargs.pop('entity_name', None)
-
-        if isinstance(mapper_or_class, type):
-            q = self._query_cls(_class_mapper(mapper_or_class, entity_name=entity_name), self, **kwargs)
-        else:
-            q = self._query_cls(mapper_or_class, self, **kwargs)
-
-        for ent in addtl_entities:
-            q = q.add_entity(ent)
-        return q
-
+    def query(self, *entities, **kwargs):
+        """Return a new ``Query`` object corresponding to this ``Session``."""
+        
+        return self._query_cls(entities, self, **kwargs)
     def _autoflush(self):
         if self.autoflush and (self.transaction is None or self.transaction.autoflush):
             self.flush()
+    
+    def _finalize_loaded(self, states):
+        for state in states:
+            state.commit_all()
 
-    def flush(self, objects=None):
-        """Flush all the object modifications present in this session
-        to the database.
-
-        `objects` is a collection or iterator of objects specifically to be
-        flushed; if ``None``, all new and modified objects are flushed.
-
-        """
-        if objects is not None:
-            try:
-                if not len(objects):
-                    return
-            except TypeError:
-                objects = list(objects)
-                if not objects:
-                    return
-        self.uow.flush(self, objects)
-
-    def get(self, class_, ident, **kwargs):
+    def get(self, class_, ident, entity_name=None):
         """Return an instance of the object based on the given
         identifier, or ``None`` if not found.
 
@@ -798,10 +840,9 @@ class Session(object):
         query.
         """
 
-        entity_name = kwargs.pop('entity_name', None)
-        return self.query(class_, entity_name=entity_name).get(ident, **kwargs)
+        return self.query(class_, entity_name=entity_name).get(ident)
 
-    def load(self, class_, ident, **kwargs):
+    def load(self, class_, ident, entity_name=None):
         """Return an instance of the object based on the given
         identifier.
 
@@ -816,8 +857,7 @@ class Session(object):
         query.
         """
 
-        entity_name = kwargs.pop('entity_name', None)
-        return self.query(class_, entity_name=entity_name).load(ident, **kwargs)
+        return self.query(class_, entity_name=entity_name).load(ident)
 
     def refresh(self, instance, attribute_names=None):
         """Refresh the attributes on the given instance.
@@ -838,11 +878,13 @@ class Session(object):
         refreshed.
         """
 
-        self._validate_persistent(instance)
+        state = attributes.instance_state(instance)
+        self._validate_persistent(state)
+        if self.query(_object_mapper(instance))._get(
+                state.key, refresh_instance=state,
+                only_load_props=attribute_names) is None:
+            raise sa_exc.InvalidRequestError("Could not refresh instance '%s'" % mapperutil.instance_str(instance))
 
-        if self.query(_object_mapper(instance))._get(instance._instance_key, refresh_instance=instance._state, only_load_props=attribute_names) is None:
-            raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % mapperutil.instance_str(instance))
-    
     def expire_all(self):
         """Expires all persistent instances within this Session.  
         
@@ -862,19 +904,17 @@ class Session(object):
         of attribute names indicating a subset of attributes to be
         expired.
         """
-
+        state = attributes.instance_state(instance)
+        self._validate_persistent(state)
         if attribute_names:
-            self._validate_persistent(instance)
-            _expire_state(instance._state, attribute_names=attribute_names)
+            _expire_state(state, attribute_names=attribute_names)
         else:
             # pre-fetch the full cascade since the expire is going to
             # remove associations
-            cascaded = list(_cascade_iterator('refresh-expire', instance))
-            self._validate_persistent(instance)
-            _expire_state(instance._state, None)
-            for (c, m) in cascaded:
-                self._validate_persistent(c)
-                _expire_state(c._state, None)
+            cascaded = list(_cascade_state_iterator('refresh-expire', state))
+            _expire_state(state, None)
+            for (state, m) in cascaded:
+                _expire_state(state, None)
 
     def prune(self):
         """Remove unreferenced instances cached in the identity map.
@@ -887,7 +927,7 @@ class Session(object):
         Returns the number of objects pruned.
         """
 
-        return self.uow.prune_identity_map()
+        return self.identity_map.prune()
 
     def expunge(self, instance):
         """Remove the given `instance` from this ``Session``.
@@ -896,11 +936,58 @@ class Session(object):
         Cascading will be applied according to the *expunge* cascade
         rule.
         """
-        self._validate_persistent(instance)
-        for c, m in [(instance, None)] + list(_cascade_iterator('expunge', instance)):
-            if c in self:
-                self.uow._remove_deleted(c._state)
-                self._unattach(c)
+        
+        state = attributes.instance_state(instance)
+        if state.session_id is not self.hash_key:
+            raise sa_exc.InvalidRequestError("Instance %s is not present in this Session" % mapperutil.state_str(state))
+        for s, m in [(state, None)] + list(_cascade_state_iterator('expunge', state)):
+            self._expunge_state(s)
+    
+    def _expunge_state(self, state):
+        if state in self._new:
+            self._new.pop(state)
+            del state.session_id
+        elif self.identity_map.contains_state(state):
+            self.identity_map.discard(state)
+            self._deleted.pop(state, None)
+            del state.session_id
+
+    def _register_newly_persistent(self, state):
+        mapper = _state_mapper(state)
+        instance_key = mapper._identity_key_from_state(state)
+
+        if state.key is None:
+            state.key = instance_key
+        elif state.key != instance_key:
+            # primary key switch
+            self.identity_map.remove(state)
+            state.key = instance_key
+
+        if hasattr(state, 'insert_order'):
+            delattr(state, 'insert_order')
+
+        obj = state.obj()
+        # prevent against last minute dereferences of the object
+        # TODO: identify a code path where state.obj() is None
+        if obj is not None:
+            if state.key in self.identity_map and not self.identity_map.contains_state(state):
+                self.identity_map.remove_key(state.key)
+            self.identity_map.add(state)
+            state.commit_all()
+
+        # remove from new last, might be the last strong ref
+        if state in self._new:
+            if self.transaction:
+                self.transaction._new[state] = True
+            self._new.pop(state)
+        
+    def _remove_newly_deleted(self, state):
+        if self.transaction:
+            self.transaction._deleted[state] = True
+            
+        self.identity_map.discard(state)
+        self._deleted.pop(state, None)
+        del state.session_id
 
     def save(self, instance, entity_name=None):
         """Add a transient (unsaved) instance to this ``Session``.
@@ -911,10 +998,21 @@ class Session(object):
 
         The `entity_name` keyword argument will further qualify the
         specific ``Mapper`` used to handle this instance.
+        
         """
-        self._save_impl(instance, entity_name=entity_name)
-        self._cascade_save_or_update(instance)
-
+        state = _state_for_unsaved_instance(instance, entity_name)
+        self._save_impl(state)
+        self._cascade_save_or_update(state, entity_name)
+    
+    # TODO
+    #save = util.deprecated("Use the add() method.")(save)
+    
+    def _save_without_cascade(self, instance, entity_name=None):
+        """used by scoping.py to save on init without cascade."""
+        
+        state = _state_for_unsaved_instance(instance, entity_name)
+        self._save_impl(state)
+        
     def update(self, instance, entity_name=None):
         """Bring the given detached (saved) instance into this
         ``Session``.
@@ -926,24 +1024,42 @@ class Session(object):
         This operation cascades the `save_or_update` method to
         associated instances if the relation is mapped with
         ``cascade="save-update"``.
+        
         """
+        state = attributes.instance_state(instance)
+        self._update_impl(state)
+        self._cascade_save_or_update(state, entity_name)
+        
+    # TODO
+    #update = util.deprecated("Use the add() method.")(update)
+    
+    def add(self, instance, entity_name=None):
+        """Add the given instance into this ``Session``.
 
-        self._update_impl(instance, entity_name=entity_name)
-        self._cascade_save_or_update(instance)
-
-    def save_or_update(self, instance, entity_name=None):
-        """Save or update the given instance into this ``Session``.
+        The non-None state `key` on the instance's state determines whether
+        to ``save()`` or ``update()`` the instance.
 
-        The presence of an `_instance_key` attribute on the instance
-        determines whether to ``save()`` or ``update()`` the instance.
         """
-
-        self._save_or_update_impl(instance, entity_name=entity_name)
-        self._cascade_save_or_update(instance)
-
-    def _cascade_save_or_update(self, instance):
-        for obj, mapper in _cascade_iterator('save-update', instance, halt_on=lambda c:c in self):
-            self._save_or_update_impl(obj, mapper.entity_name)
+        state = _state_for_unknown_persistence_instance(instance, entity_name)
+        self._save_or_update_state(state, entity_name)
+    
+    def add_all(self, instances):
+        """Add the given collection of instances to this ``Session``."""
+        
+        for instance in instances:
+            self.add(instance)
+        
+    # TODO
+    # save_or_update = util.deprecated("Use the add() method.")(add)
+    save_or_update = add
+    
+    def _save_or_update_state(self, state, entity_name):
+        self._save_or_update_impl(state)
+        self._cascade_save_or_update(state, entity_name)
+        
+    def _cascade_save_or_update(self, state, entity_name):
+        for state, mapper in _cascade_unknown_state_iterator('save-update', state, halt_on=lambda c:c in self):
+            self._save_or_update_impl(state)
 
     def delete(self, instance):
         """Mark the given instance as deleted.
@@ -951,9 +1067,10 @@ class Session(object):
         The delete operation occurs upon ``flush()``.
         """
 
-        self._delete_impl(instance)
-        for c, m in _cascade_iterator('delete', instance):
-            self._delete_impl(c, ignore_transient=True)
+        state = attributes.instance_state(instance)
+        self._delete_impl(state)
+        for state, m in _cascade_state_iterator('delete', state):
+            self._delete_impl(state, ignore_transient=True)
 
 
     def merge(self, instance, entity_name=None, dont_load=False, _recursive=None):
@@ -980,103 +1097,51 @@ class Session(object):
         if instance in _recursive:
             return _recursive[instance]
 
-        key = getattr(instance, '_instance_key', None)
+        new_instance = False
+        state = attributes.instance_state(instance)
+        key = state.key
         if key is None:
             if dont_load:
-                raise exceptions.InvalidRequestError("merge() with dont_load=True option does not support objects transient (i.e. unpersisted) objects.  flush() all changes on mapped instances before merging with dont_load=True.")
-            key = mapper.identity_key_from_instance(instance)
+                raise sa_exc.InvalidRequestError("merge() with dont_load=True option does not support objects transient (i.e. unpersisted) objects.  flush() all changes on mapped instances before merging with dont_load=True.")
+            key = mapper._identity_key_from_state(state)
 
         merged = None
         if key:
             if key in self.identity_map:
                 merged = self.identity_map[key]
             elif dont_load:
-                if instance._state.modified:
-                    raise exceptions.InvalidRequestError("merge() with dont_load=True option does not support objects marked as 'dirty'.  flush() all changes on mapped instances before merging with dont_load=True.")
-
-                merged = attributes.new_instance(mapper.class_)
-                merged._instance_key = key
-                merged._entity_name = entity_name
-                self._update_impl(merged, entity_name=mapper.entity_name)
+                if state.modified:
+                    raise sa_exc.InvalidRequestError("merge() with dont_load=True option does not support objects marked as 'dirty'.  flush() all changes on mapped instances before merging with dont_load=True.")
+
+                merged = mapper.class_manager.new_instance()
+                merged_state = attributes.instance_state(merged)
+                merged_state.key = key
+                merged_state.entity_name = entity_name
+                self._update_impl(merged_state)
+                new_instance = True
             else:
                 merged = self.get(mapper.class_, key[1])
-        
+
         if merged is None:
-            merged = attributes.new_instance(mapper.class_)
+            merged = mapper.class_manager.new_instance()
+            merged_state = attributes.instance_state(merged)
+            new_instance = True
             self.save(merged, entity_name=mapper.entity_name)
-            
+
         _recursive[instance] = merged
-        
+
         for prop in mapper.iterate_properties:
             prop.merge(self, instance, merged, dont_load, _recursive)
-            
+
         if dont_load:
-            merged._state.commit_all()  # remove any history
+            attributes.instance_state(merged).commit_all()  # remove any history
 
+        if new_instance:
+            merged_state._run_on_load(merged)
         return merged
 
     def identity_key(cls, *args, **kwargs):
-        """Get an identity key.
-
-        Valid call signatures:
-
-        * ``identity_key(class, ident, entity_name=None)``
-
-          class
-              mapped class (must be a positional argument)
-
-          ident
-              primary key, if the key is composite this is a tuple
-
-          entity_name
-              optional entity name
-
-        * ``identity_key(instance=instance)``
-
-          instance
-              object instance (must be given as a keyword arg)
-
-        * ``identity_key(class, row=row, entity_name=None)``
-
-          class
-              mapped class (must be a positional argument)
-
-          row
-              result proxy row (must be given as a keyword arg)
-
-          entity_name
-              optional entity name (must be given as a keyword arg)
-        """
-
-        if args:
-            if len(args) == 1:
-                class_ = args[0]
-                try:
-                    row = kwargs.pop("row")
-                except KeyError:
-                    ident = kwargs.pop("ident")
-                entity_name = kwargs.pop("entity_name", None)
-            elif len(args) == 2:
-                class_, ident = args
-                entity_name = kwargs.pop("entity_name", None)
-            elif len(args) == 3:
-                class_, ident, entity_name = args
-            else:
-                raise exceptions.ArgumentError("expected up to three "
-                    "positional arguments, got %s" % len(args))
-            if kwargs:
-                raise exceptions.ArgumentError("unknown keyword arguments: %s"
-                    % ", ".join(kwargs.keys()))
-            mapper = _class_mapper(class_, entity_name=entity_name)
-            if "ident" in locals():
-                return mapper.identity_key_from_primary_key(ident)
-            return mapper.identity_key_from_row(row)
-        instance = kwargs.pop("instance")
-        if kwargs:
-            raise exceptions.ArgumentError("unknown keyword arguments: %s"
-                % ", ".join(kwargs.keys()))
-        mapper = _object_mapper(instance)
-        return mapper.identity_key_from_instance(instance)
+        return mapperutil.identity_key(*args, **kwargs)
     identity_key = classmethod(identity_key)
 
     def object_session(cls, instance):
@@ -1085,83 +1150,164 @@ class Session(object):
         return object_session(instance)
     object_session = classmethod(object_session)
 
-    def _save_impl(self, instance, **kwargs):
-        if hasattr(instance, '_instance_key'):
-            raise exceptions.InvalidRequestError("Instance '%s' is already persistent" % mapperutil.instance_str(instance))
-        else:
-            # TODO: consolidate the steps here
-            attributes.manage(instance)
-            instance._entity_name = kwargs.get('entity_name', None)
-            self._attach(instance)
-            self.uow.register_new(instance)
-
-    def _update_impl(self, instance, **kwargs):
-        if instance in self and instance not in self.deleted:
+    def _validate_persistent(self, state):
+        if not self.identity_map.contains_state(state):
+            raise sa_exc.InvalidRequestError("Instance '%s' is not persistent within this Session" % mapperutil.state_str(state))
+
+    def _save_impl(self, state):
+        if state.key is not None:
+            raise sa_exc.InvalidRequestError(
+                "Object '%s' already has an identity - it can't be registered "
+                "as pending" % repr(obj))
+        self._attach(state)
+        if state not in self._new:
+            self._new[state] = state.obj()
+            state.insert_order = len(self._new)
+
+    def _update_impl(self, state):
+        if self.identity_map.contains_state(state) and state not in self._deleted:
             return
-        if not hasattr(instance, '_instance_key'):
-            raise exceptions.InvalidRequestError("Instance '%s' is not persisted" % mapperutil.instance_str(instance))
-        elif self.identity_map.get(instance._instance_key, instance) is not instance:
-            raise exceptions.InvalidRequestError("Could not update instance '%s', identity key %s; a different instance with the same identity key already exists in this session." % (mapperutil.instance_str(instance), instance._instance_key))
-        self._attach(instance)
-
-    def _save_or_update_impl(self, instance, entity_name=None):
-        key = getattr(instance, '_instance_key', None)
-        if key is None:
-            self._save_impl(instance, entity_name=entity_name)
+
+        if state.key is None:
+            raise sa_exc.InvalidRequestError(
+                "Instance '%s' is not persisted" %
+                mapperutil.state_str(state))
+                
+        if state.key in self.identity_map and not self.identity_map.contains_state(state):
+            raise sa_exc.InvalidRequestError(
+                "Could not update instance '%s', identity key %s; a different "
+                "instance with the same identity key already exists in this "
+                "session." % (mapperutil.state_str(state), state.key))
+                
+        self._attach(state)
+        self._deleted.pop(state, None)
+        self.identity_map.add(state)
+        
+    def _save_or_update_impl(self, state):
+        if state.key is None:
+            self._save_impl(state)
         else:
-            self._update_impl(instance, entity_name=entity_name)
+            self._update_impl(state)
 
-    def _delete_impl(self, instance, ignore_transient=False):
-        if instance in self and instance in self.deleted:
+    def _delete_impl(self, state, ignore_transient=False):
+        if self.identity_map.contains_state(state) and state in self._deleted:
             return
-        if not hasattr(instance, '_instance_key'):
+            
+        if state.key is None:
             if ignore_transient:
                 return
             else:
-                raise exceptions.InvalidRequestError("Instance '%s' is not persisted" % mapperutil.instance_str(instance))
-        if self.identity_map.get(instance._instance_key, instance) is not instance:
-            raise exceptions.InvalidRequestError("Instance '%s' is with key %s already persisted with a different identity" % (mapperutil.instance_str(instance), instance._instance_key))
-        self._attach(instance)
-        self.uow.register_deleted(instance)
-
-    def _attach(self, instance):
-        old_id = getattr(instance, '_sa_session_id', None)
-        if old_id != self.hash_key:
-            if old_id is not None and old_id in _sessions and instance in _sessions[old_id]:
-                raise exceptions.InvalidRequestError("Object '%s' is already attached "
-                                                     "to session '%s' (this is '%s')" %
-                                                     (mapperutil.instance_str(instance), old_id, id(self)))
-
-            key = getattr(instance, '_instance_key', None)
-            if key is not None:
-                self.identity_map[key] = instance
-            instance._sa_session_id = self.hash_key
-
-    def _unattach(self, instance):
-        if instance._sa_session_id == self.hash_key:
-            del instance._sa_session_id
-
-    def _validate_persistent(self, instance):
-        """Validate that the given instance is persistent within this
-        ``Session``.
-        """
-
-        if instance not in self:
-            raise exceptions.InvalidRequestError("Instance '%s' is not persistent within this Session" % mapperutil.instance_str(instance))
+                raise sa_exc.InvalidRequestError("Instance '%s' is not persisted" % mapperutil.state_str(state))
+        if state.key in self.identity_map and not self.identity_map.contains_state(state):
+            raise sa_exc.InvalidRequestError(
+                "Instance '%s' is with key %s already persisted with a "
+                "different identity" % (mapperutil.state_str(state),
+                                        state.key))
+
+        self._deleted[state] = state.obj()
+        self._attach(state)
+
+    def _attach(self, state):
+        if state.session_id and state.session_id is not self.hash_key:
+            raise sa_exc.InvalidRequestError(
+                "Object '%s' is already attached to session '%s' "
+                "(this is '%s')" % (mapperutil.state_str(state),
+                                    state.session_id, self.hash_key))
+        if state.session_id != self.hash_key:
+            state.session_id = self.hash_key
 
     def __contains__(self, instance):
         """Return True if the given instance is associated with this session.
 
         The instance may be pending or persistent within the Session for a
         result of True.
-        """
-
-        return instance._state in self.uow.new or (hasattr(instance, '_instance_key') and self.identity_map.get(instance._instance_key) is instance)
 
+        """
+        return self._contains_state(attributes.instance_state(instance))
+    
     def __iter__(self):
         """Return an iterator of all instances which are pending or persistent within this Session."""
 
-        return iter(list(self.uow.new.values()) + self.uow.identity_map.values())
+        return iter(list(self._new.values()) + self.identity_map.values())
+
+    def _contains_state(self, state):
+        return state in self._new or self.identity_map.contains_state(state)
+
+
+    def flush(self, objects=None):
+        """Flush all the object modifications present in this session
+        to the database.
+
+        `objects` is a list or tuple of objects specifically to be
+        flushed; if ``None``, all new and modified objects are flushed.
+
+        """
+        if not self.identity_map.check_modified() and not self._deleted and not self._new:
+            return
+            
+        dirty = self._dirty_states
+        if not dirty and not self._deleted and not self._new:
+            self.identity_map.modified = False
+            return
+
+        deleted = util.Set(self._deleted)
+        new = util.Set(self._new)
+
+        dirty = util.Set(dirty).difference(deleted)
+
+        flush_context = UOWTransaction(self)
+
+        if self.extension is not None:
+            self.extension.before_flush(self, flush_context, objects)
+
+        # create the set of all objects we want to operate upon
+        if objects:
+            # specific list passed in
+            objset = util.Set([attributes.instance_state(o) for o in objects])
+        else:
+            # or just everything
+            objset = util.Set(self.identity_map.all_states()).union(new)
+
+        # store objects whose fate has been decided
+        processed = util.Set()
+
+        # put all saves/updates into the flush context.  detect top-level orphans and throw them into deleted.
+        for state in new.union(dirty).intersection(objset).difference(deleted):
+            is_orphan = _state_mapper(state)._is_orphan(state)
+            if is_orphan and not _state_has_identity(state):
+                raise exc.FlushError("instance %s is an unsaved, pending instance and is an orphan (is not attached to %s)" %
+                    (
+                        mapperutil.state_str(state),
+                        ", nor ".join(["any parent '%s' instance via that classes' '%s' attribute" % (klass.__name__, key) for (key,klass) in _state_mapper(state).delete_orphans])
+                    ))
+            flush_context.register_object(state, isdelete=is_orphan)
+            processed.add(state)
+
+        # put all remaining deletes into the flush context.
+        for state in deleted.intersection(objset).difference(processed):
+            flush_context.register_object(state, isdelete=True)
+
+        if len(flush_context.tasks) == 0:
+            return
+        
+        flush_context.transaction = transaction = self.begin(subtransactions=True, _autoflush=False)
+        try:
+            flush_context.execute()
+
+            if self.extension is not None:
+                self.extension.after_flush(self, flush_context)
+            transaction.commit()
+        except:
+            transaction.rollback()
+            raise
+
+        flush_context.finalize_flush_changes()
+
+        if not objects:
+            self.identity_map.modified = False
+
+        if self.extension is not None:
+            self.extension.after_flush_postexec(self, flush_context)
 
     def is_modified(self, instance, include_collections=True, passive=False):
         """Return True if the given instance has modified attributes.
@@ -1180,7 +1326,7 @@ class Session(object):
         not be loaded in the course of performing this test.
         """
 
-        for attr in attributes._managed_attributes(instance.__class__):
+        for attr in attributes.manager_of_class(instance.__class__).attributes:
             if not include_collections and hasattr(attr.impl, 'get_collection'):
                 continue
             (added, unchanged, deleted) = attr.get_history(instance)
@@ -1188,8 +1334,23 @@ class Session(object):
                 return True
         return False
 
+    def _dirty_states(self):
+        """Return a set of all persistent states considered dirty.
+
+        This method returns all states that were modified including those that
+        were possibly deleted.
+
+        """
+        return util.IdentitySet(
+            [state for state in self.identity_map.all_states() if state.check_modified()]
+        )
+    _dirty_states = property(_dirty_states)
+
     def dirty(self):
-        """Return a ``Set`` of all instances marked as 'dirty' within this ``Session``.
+        """Return a set of all persistent instances considered dirty.
+
+        Instances are considered dirty when they were modified but not
+        deleted.
 
         Note that the 'dirty' state here is 'optimistic'; most attribute-setting or collection
         modification operations will mark an instance as 'dirty' and place it in this set,
@@ -1200,21 +1361,25 @@ class Session(object):
 
         To check if an instance has actionable net changes to its attributes, use the
         is_modified() method.
+
         """
+        
+        return util.IdentitySet(
+            [state.obj() for state in self._dirty_states if state not in self._deleted]
+        )
 
-        return self.uow.locate_dirty()
     dirty = property(dirty)
 
     def deleted(self):
         "Return a ``Set`` of all instances marked as 'deleted' within this ``Session``"
         
-        return util.IdentitySet(self.uow.deleted.values())
+        return util.IdentitySet(self._deleted.values())
     deleted = property(deleted)
 
     def new(self):
         "Return a ``Set`` of all instances marked as 'new' within this ``Session``."
         
-        return util.IdentitySet(self.uow.new.values())
+        return util.IdentitySet(self._new.values())
     new = property(new)
 
 def _expire_state(state, attribute_names):
@@ -1233,22 +1398,52 @@ register_attribute = unitofwork.register_attribute
 
 _sessions = weakref.WeakValueDictionary()
 
-def _cascade_iterator(cascade, instance, **kwargs):
-    mapper = _object_mapper(instance)
-    for (o, m) in mapper.cascade_iterator(cascade, instance._state, **kwargs):
-        yield o, m
+def _cascade_state_iterator(cascade, state, **kwargs):
+    mapper = _state_mapper(state)
+    for (o, m) in mapper.cascade_iterator(cascade, state, **kwargs):
+        yield attributes.instance_state(o), m
+
+def _cascade_unknown_state_iterator(cascade, state, **kwargs):
+    mapper = _state_mapper(state)
+    for (o, m) in mapper.cascade_iterator(cascade, state, **kwargs):
+        yield _state_for_unknown_persistence_instance(o, m.entity_name), m
+
+def _state_for_unsaved_instance(instance, entity_name):
+    manager = attributes.manager_of_class(instance.__class__)
+    if manager is None:
+        raise "FIXME unmapped instance"
+    if manager.has_state(instance):
+        state = manager.state_of(instance)
+        if state.key is not None:
+            raise sa_exc.InvalidRequestError(
+                "Instance '%s' is already persistent" %
+                mapperutil.state_str(state))
+    else:
+        state = manager.setup_instance(instance)
+    state.entity_name = entity_name
+    return state
+
+def _state_for_unknown_persistence_instance(instance, entity_name):
+    state = attributes.instance_state(instance)
+    state.entity_name = entity_name
+    return state
 
 def object_session(instance):
     """Return the ``Session`` to which the given instance is bound, or ``None`` if none."""
 
-    hashkey = getattr(instance, '_sa_session_id', None)
-    if hashkey is not None:
-        sess = _sessions.get(hashkey)
-        if sess is not None and instance in sess:
-            return sess
+    return _state_session(attributes.instance_state(instance))
+    
+def _state_session(state):
+    if state.session_id:
+        try:
+            return _sessions[state.session_id]
+        except KeyError:
+            pass
     return None
 
 # Lazy initialization to avoid circular imports
 unitofwork.object_session = object_session
+unitofwork._state_session = _state_session
 from sqlalchemy.orm import mapper
 mapper._expire_state = _expire_state
+mapper._state_session = _state_session
index 7cf4eb2cc50af682bcb44b67a6c875b3fc82f8d3..6850a0bb07c2b745a4d465cf6d48a63a0855c40e 100644 (file)
@@ -1,38 +1,49 @@
-"""Defines a rudimental 'horizontal sharding' system which allows a 
-Session to distribute queries and persistence operations across multiple 
-databases.
+# shard.py
+# Copyright (C) the SQLAlchemy authors and contributors
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-For a usage example, see the example ``examples/sharding/attribute_shard.py``.
+"""Horizontal sharding support.
+
+Defines a rudimental 'horizontal sharding' system which allows a Session to
+distribute queries and persistence operations across multiple databases.
+
+For a usage example, see the file ``examples/sharding/attribute_shard.py``
+included in the source distrbution.
 
 """
+
+import sqlalchemy.exceptions as sa_exc
+from sqlalchemy import util
 from sqlalchemy.orm.session import Session
 from sqlalchemy.orm.query import Query
-from sqlalchemy import exceptions, util
 
 __all__ = ['ShardedSession', 'ShardedQuery']
 
+
 class ShardedSession(Session):
     def __init__(self, shard_chooser, id_chooser, query_chooser, shards=None, **kwargs):
-        """construct a ShardedSession.
-        
-            shard_chooser
-                a callable which, passed a Mapper, a mapped instance, and possibly a
-                SQL clause, returns a shard ID. this id may be based off of the
-                attributes present within the object, or on some round-robin scheme. If
-                the scheme is based on a selection, it should set whatever state on the
-                instance to mark it in the future as participating in that shard.
-            
-            id_chooser
-                a callable, passed a query and a tuple of identity values,
-                which should return a list of shard ids where the ID might
-                reside.  The databases will be queried in the order of this
-                listing.
-                
-            query_chooser
-                for a given Query, returns the list of shard_ids where the query
-                should be issued.  Results from all shards returned will be 
-                combined together into a single listing.
-        
+        """Construct a ShardedSession.
+
+        shard_chooser
+          A callable which, passed a Mapper, a mapped instance, and possibly a
+          SQL clause, returns a shard ID.  This id may be based off of the
+          attributes present within the object, or on some round-robin
+          scheme. If the scheme is based on a selection, it should set
+          whatever state on the instance to mark it in the future as
+          participating in that shard.
+
+        id_chooser
+          A callable, passed a query and a tuple of identity values, which
+          should return a list of shard ids where the ID might reside.  The
+          databases will be queried in the order of this listing.
+
+        query_chooser
+          For a given Query, returns the list of shard_ids where the query
+          should be issued.  Results from all shards returned will be combined
+          together into a single listing.
+
         """
         super(ShardedSession, self).__init__(**kwargs)
         self.shard_chooser = shard_chooser
@@ -87,17 +98,17 @@ class ShardedQuery(Query):
         
     def _execute_and_instances(self, context):
         if self._shard_id is not None:
-            result = self.session.connection(mapper=self.mapper, shard_id=self._shard_id).execute(context.statement, **self._params)
+            result = self.session.connection(mapper=self._mapper_zero(), shard_id=self._shard_id).execute(context.statement, **self._params)
             try:
-                return iter(self.instances(result, querycontext=context))
+                return iter(self.instances(result, context))
             finally:
                 result.close()
         else:
             partial = []
             for shard_id in self.query_chooser(self):
-                result = self.session.connection(mapper=self.mapper, shard_id=shard_id).execute(context.statement, **self._params)
+                result = self.session.connection(mapper=self._mapper_zero(), shard_id=shard_id).execute(context.statement, **self._params)
                 try:
-                    partial = partial + list(self.instances(result, querycontext=context))
+                    partial = partial + list(self.instances(result, context))
                 finally:
                     result.close()
             # if some kind of in memory 'sorting' were done, this is where it would happen
@@ -124,4 +135,4 @@ class ShardedQuery(Query):
                 if o is not None:
                     return o
             else:
-                raise exceptions.InvalidRequestError("No instance found for identity %s" % repr(ident))
+                raise sa_exc.InvalidRequestError("No instance found for identity %s" % repr(ident))
index 65a8b019b865a3579061d6fa7b811cfdae73762d..8ae3042a6b3e1c8bf24fc0843f1624c801a204c3 100644 (file)
@@ -6,11 +6,13 @@
 
 """sqlalchemy.orm.interfaces.LoaderStrategy implementations, and related MapperOptions."""
 
-from sqlalchemy import sql, util, exceptions, logging
+import sqlalchemy.exceptions as sa_exc
+from sqlalchemy import sql, util, log
 from sqlalchemy.sql import util as sql_util
 from sqlalchemy.sql import visitors, expression, operators
 from sqlalchemy.orm import mapper, attributes
-from sqlalchemy.orm.interfaces import LoaderStrategy, StrategizedOption, MapperOption, PropertyOption, serialize_path, deserialize_path
+from sqlalchemy.orm.interfaces import LoaderStrategy, StrategizedOption, \
+     MapperOption, PropertyOption, serialize_path, deserialize_path
 from sqlalchemy.orm import session as sessionlib
 from sqlalchemy.orm import util as mapperutil
 
@@ -21,28 +23,53 @@ class ColumnLoader(LoaderStrategy):
     def init(self):
         super(ColumnLoader, self).init()
         self.columns = self.parent_property.columns
-        self._should_log_debug = logging.is_debug_enabled(self.logger)
+        self._should_log_debug = log.is_debug_enabled(self.logger)
         self.is_composite = hasattr(self.parent_property, 'composite_class')
         
-    def setup_query(self, context, parentclauses=None, **kwargs):
+    def setup_query(self, context, entity, path, adapter, column_collection=None, **kwargs):
         for c in self.columns:
-            if parentclauses is not None:
-                context.secondary_columns.append(parentclauses.aliased_column(c))
-            else:
-                context.primary_columns.append(c)
+            if adapter:
+                c = adapter.columns[c]
+            column_collection.append(c)
         
     def init_class_attribute(self):
         self.is_class_level = True
-        if self.is_composite:
-            self._init_composite_attribute()
+        self.logger.info("%s register managed attribute" % self)
+        coltype = self.columns[0].type
+        sessionlib.register_attribute(self.parent.class_, self.key, uselist=False, useobject=False, copy_function=coltype.copy_value, compare_function=coltype.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator, parententity=self.parent)
+        
+    def create_row_processor(self, selectcontext, path, mapper, row, adapter):
+        key, col = self.key, self.columns[0]
+        if adapter:
+            col = adapter.columns[col]
+        if col in row:
+            def new_execute(state, row, **flags):
+                state.dict[key] = row[col]
+                
+            if self._should_log_debug:
+                new_execute = self.debug_callable(new_execute, self.logger,
+                    "%s returning active column fetcher" % self,
+                    lambda state, row, **flags: "%s populating %s" % (self, mapperutil.state_attribute_str(state, key))
+                )
+            return (new_execute, None)
         else:
-            self._init_scalar_attribute()
+            def new_execute(state, row, isnew, **flags):
+                if isnew:
+                    state.expire_attributes([key])
+            if self._should_log_debug:
+                self.logger.debug("%s deferring load" % self)
+            return (new_execute, None)
+
+ColumnLoader.logger = log.class_logger(ColumnLoader)
+
+class CompositeColumnLoader(ColumnLoader):
+    def init_class_attribute(self):
+        self.is_class_level = True
+        self.logger.info("%s register managed composite attribute" % self)
 
-    def _init_composite_attribute(self):
-        self.logger.info("register managed composite attribute %s on class %s" % (self.key, self.parent.class_.__name__))
         def copy(obj):
-            return self.parent_property.composite_class(
-                *obj.__composite_values__())
+            return self.parent_property.composite_class(*obj.__composite_values__())
+            
         def compare(a, b):
             for col, aprop, bprop in zip(self.columns,
                                          a.__composite_values__(),
@@ -51,63 +78,56 @@ class ColumnLoader(LoaderStrategy):
                     return False
             else:
                 return True
-        sessionlib.register_attribute(self.parent.class_, self.key, uselist=False, useobject=False, copy_function=copy, compare_function=compare, mutable_scalars=True, comparator=self.parent_property.comparator)
-
-    def _init_scalar_attribute(self):
-        self.logger.info("register managed attribute %s on class %s" % (self.key, self.parent.class_.__name__))
-        coltype = self.columns[0].type
-        sessionlib.register_attribute(self.parent.class_, self.key, uselist=False, useobject=False, copy_function=coltype.copy_value, compare_function=coltype.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator)
-        
-    def create_row_processor(self, selectcontext, mapper, row):
-        if self.is_composite:
-            for c in self.columns:
-                if c not in row:
-                    break
-            else:
-                def new_execute(instance, row, **flags):
-                    if self._should_log_debug:
-                        self.logger.debug("populating %s with %s/%s..." % (mapperutil.attribute_str(instance, self.key), row.__class__.__name__, self.columns[0].key))
-                    instance.__dict__[self.key] = self.parent_property.composite_class(*[row[c] for c in self.columns])
-                if self._should_log_debug:
-                    self.logger.debug("Returning active composite column fetcher for %s %s" % (mapper, self.key))
-                return (new_execute, None, None)
-                
-        elif self.columns[0] in row:
-            def new_execute(instance, row, **flags):
+        sessionlib.register_attribute(self.parent.class_, self.key, uselist=False, useobject=False, copy_function=copy, compare_function=compare, mutable_scalars=True, comparator=self.parent_property.comparator, parententity=self.parent)
+
+    def create_row_processor(self, selectcontext, path, mapper, row, adapter):
+        key, columns, composite_class = self.key, self.columns, self.parent_property.composite_class
+        if adapter:
+            columns = [adapter.columns[c] for c in columns]
+        for c in columns:
+            if c not in row:
+                def new_execute(state, row, isnew, **flags):
+                    if isnew:
+                        state.expire_attributes([key])
                 if self._should_log_debug:
-                    self.logger.debug("populating %s with %s/%s" % (mapperutil.attribute_str(instance, self.key), row.__class__.__name__, self.columns[0].key))
-                instance.__dict__[self.key] = row[self.columns[0]]
-            if self._should_log_debug:
-                self.logger.debug("Returning active column fetcher for %s %s" % (mapper, self.key))
-            return (new_execute, None, None)
+                    self.logger.debug("%s deferring load" % self)
+                return (new_execute, None)
         else:
-            def new_execute(instance, row, isnew, **flags):
-                if isnew:
-                    instance._state.expire_attributes([self.key])
+            def new_execute(state, row, **flags):
+                state.dict[key] = composite_class(*[row[c] for c in columns])
+
             if self._should_log_debug:
-                self.logger.debug("Deferring load for %s %s" % (mapper, self.key))
-            return (new_execute, None, None)
+                new_execute = self.debug_callable(new_execute, self.logger,
+                    "%s returning active composite column fetcher" % self,
+                    lambda state, row, **flags: "populating %s" % (mapperutil.state_attribute_str(state, key))
+                )
 
-ColumnLoader.logger = logging.class_logger(ColumnLoader)
+            return (new_execute, None)
 
+CompositeColumnLoader.logger = log.class_logger(CompositeColumnLoader)
+    
 class DeferredColumnLoader(LoaderStrategy):
     """Deferred column loader, a per-column or per-column-group lazy loader."""
     
-    def create_row_processor(self, selectcontext, mapper, row):
-        if self.columns[0] in row:
-            return self.parent_property._get_strategy(ColumnLoader).create_row_processor(selectcontext, mapper, row)
+    def create_row_processor(self, selectcontext, path, mapper, row, adapter):
+        col = self.columns[0]
+        if adapter:
+            col = adapter.columns[col]
+        if col in row:
+            return self.parent_property._get_strategy(ColumnLoader).create_row_processor(selectcontext, path, mapper, row, adapter)
+
         elif not self.is_class_level or len(selectcontext.options):
-            def new_execute(instance, row, **flags):
-                if self._should_log_debug:
-                    self.logger.debug("set deferred callable on %s" % mapperutil.attribute_str(instance, self.key))
-                instance._state.set_callable(self.key, self.setup_loader(instance))
-            return (new_execute, None, None)
+            def new_execute(state, row, **flags):
+                state.set_callable(self.key, self.setup_loader(state))
         else:
-            def new_execute(instance, row, **flags):
-                if self._should_log_debug:
-                    self.logger.debug("set deferred callable on %s" % mapperutil.attribute_str(instance, self.key))
-                instance._state.reset(self.key)
-            return (new_execute, None, None)
+            def new_execute(state, row, **flags):
+                state.reset(self.key)
+
+        if self._should_log_debug:
+            new_execute = self.debug_callable(new_execute, self.logger, None,
+                lambda state, row, **flags: "set deferred callable on %s" % mapperutil.state_attribute_str(state, self.key)
+            )
+        return (new_execute, None)
 
     def init(self):
         super(DeferredColumnLoader, self).init()
@@ -115,25 +135,25 @@ class DeferredColumnLoader(LoaderStrategy):
             raise NotImplementedError("Deferred loading for composite types not implemented yet")
         self.columns = self.parent_property.columns
         self.group = self.parent_property.group
-        self._should_log_debug = logging.is_debug_enabled(self.logger)
+        self._should_log_debug = log.is_debug_enabled(self.logger)
 
     def init_class_attribute(self):
         self.is_class_level = True
-        self.logger.info("register managed attribute %s on class %s" % (self.key, self.parent.class_.__name__))
-        sessionlib.register_attribute(self.parent.class_, self.key, uselist=False, useobject=False, callable_=self.class_level_loader, copy_function=self.columns[0].type.copy_value, compare_function=self.columns[0].type.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator)
+        self.logger.info("%s register managed attribute" % self)
+        sessionlib.register_attribute(self.parent.class_, self.key, uselist=False, useobject=False, callable_=self.class_level_loader, copy_function=self.columns[0].type.copy_value, compare_function=self.columns[0].type.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator, parententity=self.parent)
 
-    def setup_query(self, context, only_load_props=None, **kwargs):
+    def setup_query(self, context, entity, path, adapter, only_load_props=None, **kwargs):
         if \
             (self.group is not None and context.attributes.get(('undefer', self.group), False)) or \
             (only_load_props and self.key in only_load_props):
             
-            self.parent_property._get_strategy(ColumnLoader).setup_query(context, **kwargs)
+            self.parent_property._get_strategy(ColumnLoader).setup_query(context, entity, path, adapter, **kwargs)
     
-    def class_level_loader(self, instance, props=None):
-        if not mapper.has_mapper(instance):
+    def class_level_loader(self, state, props=None):
+        if not mapperutil._state_has_mapper(state):
             return None
             
-        localparent = mapper.object_mapper(instance)
+        localparent = mapper._state_mapper(state)
 
         # adjust for the ColumnProperty associated with the instance
         # not being our own ColumnProperty.  This can occur when entity_name
@@ -141,38 +161,38 @@ class DeferredColumnLoader(LoaderStrategy):
         # to the class.
         prop = localparent.get_property(self.key)
         if prop is not self.parent_property:
-            return prop._get_strategy(DeferredColumnLoader).setup_loader(instance)
+            return prop._get_strategy(DeferredColumnLoader).setup_loader(state)
 
-        return LoadDeferredColumns(instance, self.key, props)
+        return LoadDeferredColumns(state, self.key, props)
         
-    def setup_loader(self, instance, props=None, create_statement=None):
-        return LoadDeferredColumns(instance, self.key, props, optimizing_statement=create_statement)
+    def setup_loader(self, state, props=None, create_statement=None):
+        return LoadDeferredColumns(state, self.key, props)
                 
-DeferredColumnLoader.logger = logging.class_logger(DeferredColumnLoader)
+DeferredColumnLoader.logger = log.class_logger(DeferredColumnLoader)
 
 class LoadDeferredColumns(object):
-    """callable, serializable loader object used by DeferredColumnLoader"""
+    """serializable loader object used by DeferredColumnLoader"""
     
-    def __init__(self, instance, key, keys, optimizing_statement=None):
-        self.instance = instance
+    def __init__(self, state, key, keys):
+        self.state = state
         self.key = key
         self.keys = keys
-        self.optimizing_statement = optimizing_statement
 
     def __getstate__(self):
-        return {'instance':self.instance, 'key':self.key, 'keys':self.keys}
+        return {'state':self.state, 'key':self.key, 'keys':self.keys}
     
     def __setstate__(self, state):
-        self.instance = state['instance']
+        self.state = state['state']
         self.key = state['key']
         self.keys = state['keys']
-        self.optimizing_statement = None
         
     def __call__(self):
-        if not mapper.has_identity(self.instance):
+        state = self.state
+        
+        if not mapper._state_has_identity(state):
             return None
-            
-        localparent = mapper.object_mapper(self.instance, raiseerror=False)
+        
+        localparent = mapper._state_mapper(state)
         
         prop = localparent.get_property(self.key)
         strategy = prop._get_strategy(DeferredColumnLoader)
@@ -185,22 +205,18 @@ class LoadDeferredColumns(object):
             toload = [self.key]
 
         # narrow the keys down to just those which have no history
-        group = [k for k in toload if k in self.instance._state.unmodified]
+        group = [k for k in toload if k in state.unmodified]
 
         if strategy._should_log_debug:
-            strategy.logger.debug("deferred load %s group %s" % (mapperutil.attribute_str(self.instance, self.key), group and ','.join(group) or 'None'))
+            strategy.logger.debug("deferred load %s group %s" % (mapperutil.state_attribute_str(state, self.key), group and ','.join(group) or 'None'))
 
-        session = sessionlib.object_session(self.instance)
+        session = sessionlib._state_session(state)
         if session is None:
-            raise exceptions.UnboundExecutionError("Parent instance %s is not bound to a Session; deferred load operation of attribute '%s' cannot proceed" % (self.instance.__class__, self.key))
+            raise sa_exc.UnboundExecutionError("Parent instance %s is not bound to a Session; deferred load operation of attribute '%s' cannot proceed" % (mapperutil.state_str(state), self.key))
 
         query = session.query(localparent)
-        if not self.optimizing_statement:
-            ident = self.instance._instance_key[1]
-            query._get(None, ident=ident, only_load_props=group, refresh_instance=self.instance._state)
-        else:
-            statement, params = self.optimizing_statement(self.instance)
-            query.from_statement(statement).params(params)._get(None, only_load_props=group, refresh_instance=self.instance._state)
+        ident = state.key[1]
+        query._get(None, ident=ident, only_load_props=group, refresh_instance=state)
         return attributes.ATTR_WAS_SET
 
 class DeferredOption(StrategizedOption):
@@ -223,55 +239,63 @@ class UndeferGroupOption(MapperOption):
 class AbstractRelationLoader(LoaderStrategy):
     def init(self):
         super(AbstractRelationLoader, self).init()
-        for attr in ['primaryjoin', 'secondaryjoin', 'secondary', 'foreign_keys', 'mapper', 'target', 'table', 'uselist', 'cascade', 'attributeext', 'order_by', 'remote_side', 'direction']:
+        for attr in ['mapper', 'target', 'table', 'uselist']:
             setattr(self, attr, getattr(self.parent_property, attr))
-        self._should_log_debug = logging.is_debug_enabled(self.logger)
+        self._should_log_debug = log.is_debug_enabled(self.logger)
         
-    def _init_instance_attribute(self, instance, callable_=None):
+    def _init_instance_attribute(self, state, callable_=None):
         if callable_:
-            instance._state.set_callable(self.key, callable_)
+            state.set_callable(self.key, callable_)
         else:
-            instance._state.initialize(self.key)
+            state.initialize(self.key)
         
     def _register_attribute(self, class_, callable_=None, **kwargs):
-        self.logger.info("register managed %s attribute %s on class %s" % ((self.uselist and "list-holding" or "scalar"), self.key, self.parent.class_.__name__))
-        sessionlib.register_attribute(class_, self.key, uselist=self.uselist, useobject=True, extension=self.attributeext, cascade=self.cascade,  trackparent=True, typecallable=self.parent_property.collection_class, callable_=callable_, comparator=self.parent_property.comparator, **kwargs)
+        self.logger.info("%s register managed %s attribute" % (self, (self.uselist and "collection" or "scalar")))
+        
+        if self.parent_property.backref:
+            attribute_ext = self.parent_property.backref.extension
+        else:
+            attribute_ext = None
+        
+        sessionlib.register_attribute(class_, self.key, uselist=self.uselist, useobject=True, extension=attribute_ext, cascade=self.parent_property.cascade,  trackparent=True, typecallable=self.parent_property.collection_class, callable_=callable_, comparator=self.parent_property.comparator, parententity=self.parent, **kwargs)
 
 class NoLoader(AbstractRelationLoader):
     def init_class_attribute(self):
         self.is_class_level = True
         self._register_attribute(self.parent.class_)
 
-    def create_row_processor(self, selectcontext, mapper, row):
-        def new_execute(instance, row, ispostselect, **flags):
-            if not ispostselect:
-                if self._should_log_debug:
-                    self.logger.debug("initializing blank scalar/collection on %s" % mapperutil.attribute_str(instance, self.key))
-                self._init_instance_attribute(instance)
-        return (new_execute, None, None)
+    def create_row_processor(self, selectcontext, path, mapper, row, adapter):
+        def new_execute(state, row, **flags):
+            self._init_instance_attribute(state)
+
+        if self._should_log_debug:
+            new_execute = self.debug_callable(new_execute, self.logger, None,
+                lambda state, row, **flags: "initializing blank scalar/collection on %s" % mapperutil.state_attribute_str(state, self.key)
+            )
+        return (new_execute, None)
 
-NoLoader.logger = logging.class_logger(NoLoader)
+NoLoader.logger = log.class_logger(NoLoader)
         
 class LazyLoader(AbstractRelationLoader):
     def init(self):
         super(LazyLoader, self).init()
         (self.__lazywhere, self.__bind_to_col, self._equated_columns) = self.__create_lazy_clause(self.parent_property)
         
-        self.logger.info(str(self.parent_property) + " lazy loading clause " + str(self.__lazywhere))
+        self.logger.info("%s lazy loading clause %s" % (self, self.__lazywhere))
 
         # determine if our "lazywhere" clause is the same as the mapper's
         # get() clause.  then we can just use mapper.get()
         #from sqlalchemy.orm import query
         self.use_get = not self.uselist and self.mapper._get_clause[0].compare(self.__lazywhere)
         if self.use_get:
-            self.logger.info(str(self.parent_property) + " will use query.get() to optimize instance loads")
+            self.logger.info("%s will use query.get() to optimize instance loads" % self)
 
     def init_class_attribute(self):
         self.is_class_level = True
         self._register_attribute(self.parent.class_, callable_=self.class_level_loader)
 
-    def lazy_clause(self, instance, reverse_direction=False):
-        if instance is None:
+    def lazy_clause(self, state, reverse_direction=False):
+        if state is None:
             return self._lazy_none_clause(reverse_direction)
             
         if not reverse_direction:
@@ -285,8 +309,8 @@ class LazyLoader(AbstractRelationLoader):
                 # use the "committed" (database) version to get query column values
                 # also its a deferred value; so that when used by Query, the committed value is used
                 # after an autoflush occurs
-                bindparam.value = lambda: mapper._get_committed_attr_by_column(instance, bind_to_col[bindparam.key])
-        return visitors.traverse(criterion, clone=True, visit_bindparam=visit_bindparam)
+                bindparam.value = lambda: mapper._get_committed_state_attr_by_column(state, bind_to_col[bindparam.key])
+        return visitors.cloned_traverse(criterion, {}, {'bindparam':visit_bindparam})
     
     def _lazy_none_clause(self, reverse_direction=False):
         if not reverse_direction:
@@ -305,13 +329,13 @@ class LazyLoader(AbstractRelationLoader):
                 binary.right = expression.null()
                 binary.operator = operators.is_
         
-        return visitors.traverse(criterion, clone=True, visit_binary=visit_binary)
+        return visitors.cloned_traverse(criterion, {}, {'binary':visit_binary})
         
-    def class_level_loader(self, instance, options=None, path=None):
-        if not mapper.has_mapper(instance):
+    def class_level_loader(self, state, options=None, path=None):
+        if not mapperutil._state_has_mapper(state):
             return None
 
-        localparent = mapper.object_mapper(instance)
+        localparent = mapper._state_mapper(state)
 
         # adjust for the PropertyLoader associated with the instance
         # not being our own PropertyLoader.  This can occur when entity_name
@@ -319,35 +343,41 @@ class LazyLoader(AbstractRelationLoader):
         # to the class.
         prop = localparent.get_property(self.key)
         if prop is not self.parent_property:
-            return prop._get_strategy(LazyLoader).setup_loader(instance)
+            return prop._get_strategy(LazyLoader).setup_loader(state)
         
-        return LoadLazyAttribute(instance, self.key, options, path)
+        return LoadLazyAttribute(state, self.key, options, path)
 
-    def setup_loader(self, instance, options=None, path=None):
-        return LoadLazyAttribute(instance, self.key, options, path)
+    def setup_loader(self, state, options=None, path=None):
+        return LoadLazyAttribute(state, self.key, options, path)
 
-    def create_row_processor(self, selectcontext, mapper, row):
+    def create_row_processor(self, selectcontext, path, mapper, row, adapter):
         if not self.is_class_level or len(selectcontext.options):
-            def new_execute(instance, row, ispostselect, **flags):
-                if not ispostselect:
-                    if self._should_log_debug:
-                        self.logger.debug("set instance-level lazy loader on %s" % mapperutil.attribute_str(instance, self.key))
-                    # we are not the primary manager for this attribute on this class - set up a per-instance lazyloader,
-                    # which will override the class-level behavior
-                    
-                    self._init_instance_attribute(instance, callable_=self.setup_loader(instance, selectcontext.options, selectcontext.query._current_path + selectcontext.path))
-            return (new_execute, None, None)
+            path = path + (self.key,)
+            def new_execute(state, row, **flags):
+                # we are not the primary manager for this attribute on this class - set up a per-instance lazyloader,
+                # which will override the class-level behavior
+                self._init_instance_attribute(state, callable_=self.setup_loader(state, selectcontext.options, selectcontext.query._current_path + path))
+
+            if self._should_log_debug:
+                new_execute = self.debug_callable(new_execute, self.logger, None,
+                    lambda state, row, **flags: "set instance-level lazy loader on %s" % mapperutil.state_attribute_str(state, self.key)
+                )
+
+            return (new_execute, None)
         else:
-            def new_execute(instance, row, ispostselect, **flags):
-                if not ispostselect:
-                    if self._should_log_debug:
-                        self.logger.debug("set class-level lazy loader on %s" % mapperutil.attribute_str(instance, self.key))
-                    # we are the primary manager for this attribute on this class - reset its per-instance attribute state, 
-                    # so that the class-level lazy loader is executed when next referenced on this instance.
-                    # this usually is not needed unless the constructor of the object referenced the attribute before we got 
-                    # to load data into it.
-                    instance._state.reset(self.key)
-            return (new_execute, None, None)
+            def new_execute(state, row, **flags):
+                # we are the primary manager for this attribute on this class - reset its per-instance attribute state, 
+                # so that the class-level lazy loader is executed when next referenced on this instance.
+                # this usually is not needed unless the constructor of the object referenced the attribute before we got 
+                # to load data into it.
+                state.reset(self.key)
+
+            if self._should_log_debug:
+                new_execute = self.debug_callable(new_execute, self.logger, None,
+                    lambda state, row, **flags: "set class-level lazy loader on %s" % mapperutil.state_attribute_str(state, self.key)
+                )
+
+            return (new_execute, None)
 
     def __create_lazy_clause(cls, prop, reverse_direction=False):
         binds = {}
@@ -374,16 +404,16 @@ class LazyLoader(AbstractRelationLoader):
                     binds[col] = sql.bindparam(None, None, type_=col.type)
                 return binds[col]
             return None
-                    
-        lazywhere = prop.primaryjoin
         
+        lazywhere = prop.primaryjoin
+
         if not prop.secondaryjoin or not reverse_direction:
-            lazywhere = visitors.traverse(lazywhere, before_clone=col_to_bind, clone=True
+            lazywhere = visitors.replacement_traverse(lazywhere, {}, col_to_bind
         
         if prop.secondaryjoin is not None:
             secondaryjoin = prop.secondaryjoin
             if reverse_direction:
-                secondaryjoin = visitors.traverse(secondaryjoin, before_clone=col_to_bind, clone=True)
+                secondaryjoin = visitors.replacement_traverse(secondaryjoin, {}, col_to_bind)
             lazywhere = sql.and_(lazywhere, secondaryjoin)
     
         bind_to_col = dict([(binds[col].key, col) for col in binds])
@@ -391,47 +421,44 @@ class LazyLoader(AbstractRelationLoader):
         return (lazywhere, bind_to_col, equated_columns)
     __create_lazy_clause = classmethod(__create_lazy_clause)
     
-LazyLoader.logger = logging.class_logger(LazyLoader)
+LazyLoader.logger = log.class_logger(LazyLoader)
 
 class LoadLazyAttribute(object):
-    """callable, serializable loader object used by LazyLoader"""
+    """serializable loader object used by LazyLoader"""
 
-    def __init__(self, instance, key, options, path):
-        self.instance = instance
+    def __init__(self, state, key, options, path):
+        self.state = state
         self.key = key
         self.options = options
         self.path = path
         
     def __getstate__(self):
-        return {'instance':self.instance, 'key':self.key, 'options':self.options, 'path':serialize_path(self.path)}
+        return {'state':self.state, 'key':self.key, 'options':self.options, 'path':serialize_path(self.path)}
 
     def __setstate__(self, state):
-        self.instance = state['instance']
+        self.state = state['state']
         self.key = state['key']
-        self.options= state['options']
+        self.options = state['options']
         self.path = deserialize_path(state['path'])
         
     def __call__(self):
-        instance = self.instance
-        
-        if not mapper.has_identity(instance):
+        state = self.state
+        if not mapper._state_has_identity(state):
             return None
 
-        instance_mapper = mapper.object_mapper(instance)
+        instance_mapper = mapper._state_mapper(state)
         prop = instance_mapper.get_property(self.key)
         strategy = prop._get_strategy(LazyLoader)
         
         if strategy._should_log_debug:
-            strategy.logger.debug("lazy load attribute %s on instance %s" % (self.key, mapperutil.instance_str(instance)))
+            strategy.logger.debug("loading %s" % mapperutil.state_attribute_str(state, self.key))
 
-        session = sessionlib.object_session(instance)
+        session = sessionlib._state_session(state)
         if session is None:
-            try:
-                session = instance_mapper.get_session()
-            except exceptions.InvalidRequestError:
-                raise exceptions.UnboundExecutionError("Parent instance %s is not bound to a Session, and no contextual session is established; lazy load operation of attribute '%s' cannot proceed" % (instance.__class__, self.key))
+            raise sa_exc.UnboundExecutionError("Parent instance %s is not bound to a Session; lazy load operation of attribute '%s' cannot proceed" % (mapperutil.state_str(state), self.key))
 
-        q = session.query(prop.mapper).autoflush(False)
+        q = session.query(prop.mapper).autoflush(False)._adapt_all_clauses()
+        
         if self.path:
             q = q._with_current_path(self.path)
             
@@ -441,7 +468,7 @@ class LoadLazyAttribute(object):
             ident = []
             allnulls = True
             for primary_key in prop.mapper.primary_key: 
-                val = instance_mapper._get_committed_attr_by_column(instance, strategy._equated_columns[primary_key])
+                val = instance_mapper._get_committed_state_attr_by_column(state, strategy._equated_columns[primary_key])
                 allnulls = allnulls and val is None
                 ident.append(val)
             if allnulls:
@@ -450,14 +477,14 @@ class LoadLazyAttribute(object):
                 q = q._conditional_options(*self.options)
             return q.get(ident)
             
-        if strategy.order_by is not False:
-            q = q.order_by(strategy.order_by)
-        elif strategy.secondary is not None and strategy.secondary.default_order_by() is not None:
-            q = q.order_by(strategy.secondary.default_order_by())
+        if prop.order_by is not False:
+            q = q.order_by(prop.order_by)
+        elif prop.secondary is not None and prop.secondary.default_order_by() is not None:
+            q = q.order_by(prop.secondary.default_order_by())
 
         if self.options:
             q = q._conditional_options(*self.options)
-        q = q.filter(strategy.lazy_clause(instance))
+        q = q.filter(strategy.lazy_clause(state))
 
         result = q.all()
         if strategy.uselist:
@@ -478,19 +505,35 @@ class EagerLoader(AbstractRelationLoader):
         self.join_depth = self.parent_property.join_depth
 
     def init_class_attribute(self):
-        # class-level eager strategy; add the PropertyLoader
-        # to the parent's list of "eager loaders"; this tells the Query
-        # that eager loaders will be used in a normal query
-        self.parent._eager_loaders.add(self.parent_property)
-        
-        # initialize a lazy loader on the class level attribute
         self.parent_property._get_strategy(LazyLoader).init_class_attribute()
         
-    def setup_query(self, context, parentclauses=None, parentmapper=None, **kwargs):
+    def setup_query(self, context, entity, path, adapter, column_collection=None, parentmapper=None, **kwargs):
         """Add a left outer join to the statement thats being constructed."""
+
+        path = path + (self.key,)
+
+        # check for user-defined eager alias
+        if ("eager_row_processor", path) in context.attributes:
+            clauses = context.attributes[("eager_row_processor", path)]
+            
+            adapter = entity._get_entity_clauses(context.query, context)
+            if adapter and clauses:
+                context.attributes[("eager_row_processor", path)] = clauses = adapter.wrap(clauses)
+            elif adapter:
+                context.attributes[("eager_row_processor", path)] = clauses = adapter
+                
+        else:
         
-        path = context.path
-        
+            clauses = self.__create_eager_join(context, entity, path, adapter, parentmapper)
+            if not clauses:
+                return
+
+            context.attributes[("eager_row_processor", path)] = clauses
+            
+        for value in self.mapper._iterate_polymorphic_properties():
+            value.setup(context, entity, path + (self.mapper.base_mapper,), clauses, parentmapper=self.mapper, column_collection=context.secondary_columns)
+    
+    def __create_eager_join(self, context, entity, path, adapter, parentmapper):
         # check for join_depth or basic recursion,
         # if the current path was not explicitly stated as 
         # a desired "loaderstrategy" (i.e. via query.options())
@@ -502,159 +545,148 @@ class EagerLoader(AbstractRelationLoader):
                 if self.mapper.base_mapper in path:
                     return
 
-        if ("eager_row_processor", path) in context.attributes:
-            # if user defined eager_row_processor, that's contains_eager().
-            # don't render LEFT OUTER JOIN, generate an AliasedClauses from 
-            # the decorator (this is a hack here, cleaned up in 0.5)
-            cl = context.attributes[("eager_row_processor", path)]
-            if cl:
-                row = cl(None)
-                class ActsLikeAliasedClauses(object):
-                    def aliased_column(self, col):
-                        return row.map[col]
-                clauses = ActsLikeAliasedClauses()
-            else:
-                clauses = None
-        else:
-            clauses = self.__create_eager_join(context, path, parentclauses, parentmapper, **kwargs)
-            if not clauses:
-                return
-
-        for value in self.mapper._iterate_polymorphic_properties():
-            context.exec_with_path(self.mapper, value.key, value.setup, context, parentclauses=clauses, parentmapper=self.mapper)
-
-    def __create_eager_join(self, context, path, parentclauses, parentmapper, **kwargs):
         if parentmapper is None:
-            localparent = context.mapper
+            localparent = entity.mapper
         else:
             localparent = parentmapper
-        
-        if context.eager_joins:
-            towrap = context.eager_joins
+    
+        # whether or not the Query will wrap the selectable in a subquery,
+        # and then attach eager load joins to that (i.e., in the case of LIMIT/OFFSET etc.)
+        should_nest_selectable = context.query._should_nest_selectable
+    
+        if entity in context.eager_joins:
+            entity_key, default_towrap = entity, entity.selectable
+        elif should_nest_selectable or not context.from_clause or not sql_util.search(context.from_clause, entity.selectable):
+            # if no from_clause, or a from_clause we can't join to, or a subquery is going to be generated, 
+            # store eager joins per _MappedEntity; Query._compile_context will 
+            # add them as separate selectables to the select(), or splice them together
+            # after the subquery is generated
+            entity_key, default_towrap = entity, entity.selectable
         else:
-            towrap = context.from_clause
-        
-        # create AliasedClauses object to build up the eager query.  this is cached after 1st creation.    
+            # otherwise, create a single eager join from the from clause.  
+            # Query._compile_context will adapt as needed and append to the
+            # FROM clause of the select().
+            entity_key, default_towrap = None, context.from_clause
+    
+        towrap = context.eager_joins.setdefault(entity_key, default_towrap)
+    
+        # create AliasedClauses object to build up the eager query.  this is cached after 1st creation.
+        # this also allows ORMJoin to cache the aliased joins it produces since we pass the same
+        # args each time in the typical case.
+        path_key = util.WeakCompositeKey(*path)
         try:
-            clauses = self.clauses[path]
+            clauses = self.clauses[path_key]
         except KeyError:
-            clauses = mapperutil.PropertyAliasedClauses(self.parent_property, self.parent_property.primaryjoin, self.parent_property.secondaryjoin, parentclauses)
-            self.clauses[path] = clauses
+            self.clauses[path_key] = clauses = mapperutil.ORMAdapter(mapperutil.AliasedClass(self.mapper), 
+                    equivalents=self.mapper._equivalent_columns, 
+                    chain_to=adapter)
 
-        # place the "row_decorator" from the AliasedClauses into the QueryContext, where it will
-        # be picked up in create_row_processor() when results are fetched
-        context.attributes[("eager_row_processor", path)] = clauses.row_decorator
-        
-        if self.secondaryjoin is not None:
-            context.eager_joins = sql.outerjoin(towrap, clauses.secondary, clauses.primaryjoin).outerjoin(clauses.alias, clauses.secondaryjoin)
-            
-            # TODO: check for "deferred" cols on parent/child tables here ?  this would only be
-            # useful if the primary/secondaryjoin are against non-PK columns on the tables (and therefore might be deferred)
-            
-            if self.order_by is False and self.secondary.default_order_by() is not None:
-                context.eager_order_by += clauses.secondary.default_order_by()
+        if adapter:
+            if getattr(adapter, 'aliased_class', None):
+                onclause = getattr(adapter.aliased_class, self.key, self.parent_property)
+            else:
+                onclause = getattr(mapperutil.AliasedClass(self.parent, adapter.selectable), self.key, self.parent_property)
         else:
-            context.eager_joins = towrap.outerjoin(clauses.alias, clauses.primaryjoin)
-            # ensure all the cols on the parent side are actually in the
+            onclause = self.parent_property
+    
+        context.eager_joins[entity_key] = eagerjoin = mapperutil.outerjoin(towrap, clauses.aliased_class, onclause)
+        
+        # send a hint to the Query as to where it may "splice" this join
+        eagerjoin.stop_on = entity.selectable
+        
+        if not self.parent_property.secondary and context.query._should_nest_selectable and not parentmapper:
+            # for parentclause that is the non-eager end of the join,
+            # ensure all the parent cols in the primaryjoin are actually in the
             # columns clause (i.e. are not deferred), so that aliasing applied by the Query propagates 
             # those columns outward.  This has the effect of "undefering" those columns.
-            for col in sql_util.find_columns(clauses.primaryjoin):
+            for col in sql_util.find_columns(self.parent_property.primaryjoin):
                 if localparent.mapped_table.c.contains_column(col):
+                    if adapter:
+                        col = adapter.columns[col]
                     context.primary_columns.append(col)
-                
-            if self.order_by is False and clauses.alias.default_order_by() is not None:
-                context.eager_order_by += clauses.alias.default_order_by()
-
-        if clauses.order_by:
-            context.eager_order_by += util.to_list(clauses.order_by)
         
+        if self.parent_property.order_by is False:
+            if self.parent_property.secondaryjoin:
+                default_order_by = eagerjoin.left.right.default_order_by()
+            else:
+                default_order_by = eagerjoin.right.default_order_by()
+            if default_order_by:
+                context.eager_order_by += default_order_by
+        elif self.parent_property.order_by:
+            context.eager_order_by += eagerjoin._target_adapter.copy_and_process(util.to_list(self.parent_property.order_by))
+            
         return clauses
         
-    def _create_row_decorator(self, selectcontext, row, path):
-        """Create a *row decorating* function that will apply eager
-        aliasing to the row.
-        
-        Also check that an identity key can be retrieved from the row,
-        else return None.
-        """
-        
-        #print "creating row decorator for path ", "->".join([str(s) for s in path])
-        
-        if ("eager_row_processor", path) in selectcontext.attributes:
-            decorator = selectcontext.attributes[("eager_row_processor", path)]
-            if decorator is None:
-                decorator = lambda row: row
+    def __create_eager_adapter(self, context, row, adapter, path):
+        if ("eager_row_processor", path) in context.attributes:
+            decorator = context.attributes[("eager_row_processor", path)]
         else:
             if self._should_log_debug:
                 self.logger.debug("Could not locate aliased clauses for key: " + str(path))
-            return None
+            return False
 
+        if adapter and decorator:
+            decorator = adapter.wrap(decorator)
+        elif adapter:
+            decorator = adapter
+            
         try:
-            decorated_row = decorator(row)
-            # check for identity key
-            identity_key = self.mapper.identity_key_from_row(decorated_row)
-            # and its good
+            identity_key = self.mapper.identity_key_from_row(row, decorator)
             return decorator
         except KeyError, k:
             # no identity key - dont return a row processor, will cause a degrade to lazy
             if self._should_log_debug:
-                self.logger.debug("could not locate identity key from row '%s'; missing column '%s'" % (repr(decorated_row), str(k)))
-            return None
-
-    def create_row_processor(self, selectcontext, mapper, row):
+                self.logger.debug("could not locate identity key from row; missing column '%s'" % k)
+            return False
 
-        row_decorator = self._create_row_decorator(selectcontext, row, selectcontext.path)
-        pathstr = ','.join([str(x) for x in selectcontext.path])
-        if row_decorator is not None:
-            def execute(instance, row, isnew, **flags):
-                decorated_row = row_decorator(row)
-
-                if not self.uselist:
-                    if self._should_log_debug:
-                        self.logger.debug("eagerload scalar instance on %s" % mapperutil.attribute_str(instance, self.key))
+    def create_row_processor(self, context, path, mapper, row, adapter):
+        path = path + (self.key,)
+        eager_adapter = self.__create_eager_adapter(context, row, adapter, path)
+        
+        if eager_adapter is not False:
+            key = self.key
+            _instance = self.mapper._instance_processor(context, path + (self.mapper.base_mapper,), eager_adapter)
+            
+            if not self.uselist:
+                def execute(state, row, isnew, **flags):
                     if isnew:
                         # set a scalar object instance directly on the
                         # parent object, bypassing InstrumentedAttribute
                         # event handlers.
-                        #
-                        instance.__dict__[self.key] = self.mapper._instance(selectcontext, decorated_row, None)
+                        state.dict[key] = _instance(row, None)
                     else:
                         # call _instance on the row, even though the object has been created,
                         # so that we further descend into properties
-                        self.mapper._instance(selectcontext, decorated_row, None)
-                else:
-                    if isnew or self.key not in instance._state.appenders:
-                        # appender_key can be absent from selectcontext.attributes with isnew=False
+                        _instance(row, None)
+            else:
+                def execute(state, row, isnew, **flags):
+                    if isnew or (state, key) not in context.attributes:
+                        # appender_key can be absent from context.attributes with isnew=False
                         # when self-referential eager loading is used; the same instance may be present
                         # in two distinct sets of result columns
-                        
-                        if self._should_log_debug:
-                            self.logger.debug("initialize UniqueAppender on %s" % mapperutil.attribute_str(instance, self.key))
 
-                        collection = attributes.init_collection(instance, self.key)
+                        collection = attributes.init_collection(state, key)
                         appender = util.UniqueAppender(collection, 'append_without_event')
 
-                        instance._state.appenders[self.key] = appender
-                    
-                    result_list = instance._state.appenders[self.key]
-                    if self._should_log_debug:
-                        self.logger.debug("eagerload list instance on %s" % mapperutil.attribute_str(instance, self.key))
+                        context.attributes[(state, key)] = appender
+
+                    result_list = context.attributes[(state, key)]
                     
-                    self.mapper._instance(selectcontext, decorated_row, result_list)
+                    _instance(row, result_list)
 
             if self._should_log_debug:
-                self.logger.debug("Returning eager instance loader for %s" % str(self))
+                execute = self.debug_callable(execute, self.logger, 
+                    "%s returning eager instance loader" % self,
+                    lambda state, row, isnew, **flags: "%s eagerload %s" % (self, self.uselist and "scalar attribute" or "collection")
+                )
 
-            return (execute, execute, None)
+            return (execute, execute)
         else:
             if self._should_log_debug:
-                self.logger.debug("eager loader %s degrading to lazy loader" % str(self))
-            return self.parent_property._get_strategy(LazyLoader).create_row_processor(selectcontext, mapper, row)
+                self.logger.debug("%s degrading to lazy loader" % self)
+            return self.parent_property._get_strategy(LazyLoader).create_row_processor(context, path, mapper, row, adapter)
 
-    def __str__(self):
-        return str(self.parent) + "." + self.key
-        
-EagerLoader.logger = logging.class_logger(EagerLoader)
+EagerLoader.logger = log.class_logger(EagerLoader)
 
 class EagerLazyOption(StrategizedOption):
     def __init__(self, key, lazy=True, chained=False, mapper=None):
@@ -665,20 +697,6 @@ class EagerLazyOption(StrategizedOption):
     def is_chained(self):
         return not self.lazy and self.chained
         
-    def process_query_property(self, query, paths):
-        if self.lazy:
-            if paths[-1] in query._eager_loaders:
-                query._eager_loaders = query._eager_loaders.difference(util.Set([paths[-1]]))
-        else:
-            if not self.chained:
-                paths = [paths[-1]]
-            res = util.Set()
-            for path in paths:
-                if len(path) - len(query._current_path) == 2:
-                    res.add(path)
-            query._eager_loaders = query._eager_loaders.union(res)
-        super(EagerLazyOption, self).process_query_property(query, paths)
-
     def get_strategy_class(self):
         if self.lazy:
             return LazyLoader
@@ -687,24 +705,26 @@ class EagerLazyOption(StrategizedOption):
         elif self.lazy is None:
             return NoLoader
 
-EagerLazyOption.logger = logging.class_logger(EagerLazyOption)
-
-class RowDecorateOption(PropertyOption):
-    def __init__(self, key, decorator=None, alias=None):
-        super(RowDecorateOption, self).__init__(key)
-        self.decorator = decorator
+class LoadEagerFromAliasOption(PropertyOption):
+    def __init__(self, key, alias=None):
+        super(LoadEagerFromAliasOption, self).__init__(key)
+        if alias:
+            if not isinstance(alias, basestring):
+                m, alias, is_aliased_class = mapperutil._entity_info(alias)
         self.alias = alias
 
     def process_query_property(self, query, paths):
-        if self.alias is not None and self.decorator is None:
-            (mapper, propname) = paths[-1][-2:]
-
-            prop = mapper.get_property(propname, resolve_synonyms=True)
+        if self.alias:
             if isinstance(self.alias, basestring):
-                self.alias = prop.target.alias(self.alias)
+                (mapper, propname) = paths[-1][-2:]
 
-            self.decorator = mapperutil.create_row_adapter(self.alias)
-        query._attributes[("eager_row_processor", paths[-1])] = self.decorator
+                prop = mapper.get_property(propname, resolve_synonyms=True)
+                self.alias = prop.target.alias(self.alias)
+            if not isinstance(self.alias, expression.Alias):
+                import pdb
+                pdb.set_trace()
+            query._attributes[("eager_row_processor", paths[-1])] = sql_util.ColumnAdapter(self.alias)
+        else:
+            query._attributes[("eager_row_processor", paths[-1])] = None
 
-RowDecorateOption.logger = logging.class_logger(RowDecorateOption)
         
index 39a7b5044c40064267f827163e75fa6f35239de0..eca80df25e9eefd9991fbd4453b64b25550605ea 100644 (file)
@@ -8,31 +8,27 @@
 based on join conditions.
 """
 
-from sqlalchemy import schema, exceptions, util
-from sqlalchemy.sql import visitors, operators, util as sqlutil
-from sqlalchemy import logging
-from sqlalchemy.orm import util as mapperutil
-from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE, MANYTOMANY  # legacy
+from sqlalchemy.orm import exc, util as mapperutil
 
 def populate(source, source_mapper, dest, dest_mapper, synchronize_pairs):
     for l, r in synchronize_pairs:
         try:
             value = source_mapper._get_state_attr_by_column(source, l)
-        except exceptions.UnmappedColumnError:
+        except exc.UnmappedColumnError:
             _raise_col_to_prop(False, source_mapper, l, dest_mapper, r)
 
         try:
             dest_mapper._set_state_attr_by_column(dest, r, value)
-        except exceptions.UnmappedColumnError:
-            self._raise_col_to_prop(True, source_mapper, l, dest_mapper, r)
+        except exc.UnmappedColumnError:
+            _raise_col_to_prop(True, source_mapper, l, dest_mapper, r)
 
 def clear(dest, dest_mapper, synchronize_pairs):
     for l, r in synchronize_pairs:
         if r.primary_key:
-            raise exceptions.AssertionError("Dependency rule tried to blank-out primary key column '%s' on instance '%s'" % (r, mapperutil.state_str(dest)))
+            raise AssertionError("Dependency rule tried to blank-out primary key column '%s' on instance '%s'" % (r, mapperutil.state_str(dest)))
         try:
             dest_mapper._set_state_attr_by_column(dest, r, None)
-        except exceptions.UnmappedColumnError:
+        except exc.UnmappedColumnError:
             _raise_col_to_prop(True, None, l, dest_mapper, r)
 
 def update(source, source_mapper, dest, old_prefix, synchronize_pairs):
@@ -40,8 +36,8 @@ def update(source, source_mapper, dest, old_prefix, synchronize_pairs):
         try:
             oldvalue = source_mapper._get_committed_attr_by_column(source.obj(), l)
             value = source_mapper._get_state_attr_by_column(source, l)
-        except exceptions.UnmappedColumnError:
-            self._raise_col_to_prop(False, source_mapper, l, None, r)
+        except exc.UnmappedColumnError:
+            _raise_col_to_prop(False, source_mapper, l, None, r)
         dest[r.key] = value
         dest[old_prefix + r.key] = oldvalue
 
@@ -49,16 +45,16 @@ def populate_dict(source, source_mapper, dict_, synchronize_pairs):
     for l, r in synchronize_pairs:
         try:
             value = source_mapper._get_state_attr_by_column(source, l)
-        except exceptions.UnmappedColumnError:
+        except exc.UnmappedColumnError:
             _raise_col_to_prop(False, source_mapper, l, None, r)
-            
+
         dict_[r.key] = value
 
 def source_changes(uowcommit, source, source_mapper, synchronize_pairs):
     for l, r in synchronize_pairs:
         try:
             prop = source_mapper._get_col_to_prop(l)
-        except exceptions.UnmappedColumnError:
+        except exc.UnmappedColumnError:
             _raise_col_to_prop(False, source_mapper, l, None, r)
         (added, unchanged, deleted) = uowcommit.get_attribute_history(source, prop.key, passive=True)
         if added and deleted:
@@ -70,7 +66,7 @@ def dest_changes(uowcommit, dest, dest_mapper, synchronize_pairs):
     for l, r in synchronize_pairs:
         try:
             prop = dest_mapper._get_col_to_prop(r)
-        except exceptions.UnmappedColumnError:
+        except exc.UnmappedColumnError:
             _raise_col_to_prop(True, None, l, dest_mapper, r)
         (added, unchanged, deleted) = uowcommit.get_attribute_history(dest, prop.key, passive=True)
         if added and deleted:
@@ -80,7 +76,6 @@ def dest_changes(uowcommit, dest, dest_mapper, synchronize_pairs):
 
 def _raise_col_to_prop(isdest, source_mapper, source_column, dest_mapper, dest_column):
     if isdest:
-        raise exceptions.UnmappedColumnError("Can't execute sync rule for destination column '%s'; mapper '%s' does not map this column.  Try using an explicit `foreign_keys` collection which does not include this column (or use a viewonly=True relation)." % (dest_column, source_mapper))
+        raise exc.UnmappedColumnError("Can't execute sync rule for destination column '%s'; mapper '%s' does not map this column.  Try using an explicit `foreign_keys` collection which does not include this column (or use a viewonly=True relation)." % (dest_column, source_mapper))
     else:
-        raise exceptions.UnmappedColumnError("Can't execute sync rule for source column '%s'; mapper '%s' does not map this column.  Try using an explicit `foreign_keys` collection which does not include destination column '%s' (or use a viewonly=True relation)." % (source_column, source_mapper, dest_column))
-        
+        raise exc.UnmappedColumnError("Can't execute sync rule for source column '%s'; mapper '%s' does not map this column.  Try using an explicit `foreign_keys` collection which does not include destination column '%s' (or use a viewonly=True relation)." % (source_column, source_mapper, dest_column))
index 66b68770d619ab87ce0d08ded2aa363c65d45657..4edfeefdceb6905da2b35ed083f0c41d3b9b1e7a 100644 (file)
@@ -17,16 +17,19 @@ unique against their primary key identity using an *identity map*
 pattern.  The Unit of Work then maintains lists of objects that are
 new, dirty, or deleted and provides the capability to flush all those
 changes at once.
+
 """
 
-import StringIO, weakref
-from sqlalchemy import util, logging, topological, exceptions
+import StringIO
+
+from sqlalchemy import util, log, topological
 from sqlalchemy.orm import attributes, interfaces
 from sqlalchemy.orm import util as mapperutil
-from sqlalchemy.orm.mapper import object_mapper, _state_mapper, has_identity
+from sqlalchemy.orm.mapper import _state_mapper
 
 # Load lazily
 object_session = None
+_state_session = None
 
 class UOWEventHandler(interfaces.AttributeExtension):
     """An event handler added to all relation attributes which handles
@@ -38,33 +41,33 @@ class UOWEventHandler(interfaces.AttributeExtension):
         self.class_ = class_
         self.cascade = cascade
     
-    def _target_mapper(self, obj):
-        prop = object_mapper(obj).get_property(self.key)
+    def _target_mapper(self, state):
+        prop = _state_mapper(state).get_property(self.key)
         return prop.mapper
 
-    def append(self, obj, item, initiator):
+    def append(self, state, item, initiator):
         # process "save_update" cascade rules for when an instance is appended to the list of another instance
-        sess = object_session(obj)
+        sess = _state_session(state)
         if sess:
             if self.cascade.save_update and item not in sess:
-                sess.save_or_update(item, entity_name=self._target_mapper(obj).entity_name)
+                sess.save_or_update(item, entity_name=self._target_mapper(state).entity_name)
 
-    def remove(self, obj, item, initiator):
-        sess = object_session(obj)
+    def remove(self, state, item, initiator):
+        sess = _state_session(state)
         if sess:
             # expunge pending orphans
             if self.cascade.delete_orphan and item in sess.new:
-                if self._target_mapper(obj)._is_orphan(item):
+                if self._target_mapper(state)._is_orphan(attributes.instance_state(item)):
                     sess.expunge(item)
 
-    def set(self, obj, newvalue, oldvalue, initiator):
+    def set(self, state, newvalue, oldvalue, initiator):
         # process "save_update" cascade rules for when an instance is attached to another instance
         if oldvalue is newvalue:
             return
-        sess = object_session(obj)
+        sess = _state_session(state)
         if sess:
             if newvalue is not None and self.cascade.save_update and newvalue not in sess:
-                sess.save_or_update(newvalue, entity_name=self._target_mapper(obj).entity_name)
+                sess.save_or_update(newvalue, entity_name=self._target_mapper(state).entity_name)
             if self.cascade.delete_orphan and oldvalue in sess.new:
                 sess.expunge(oldvalue)
 
@@ -86,184 +89,6 @@ def register_attribute(class_, key, *args, **kwargs):
     
 
 
-class UnitOfWork(object):
-    """Main UOW object which stores lists of dirty/new/deleted objects.
-
-    Provides top-level *flush* functionality as well as the
-    default transaction boundaries involved in a write
-    operation.
-    """
-
-    def __init__(self, session):
-        if session.weak_identity_map:
-            self.identity_map = attributes.WeakInstanceDict()
-        else:
-            self.identity_map = attributes.StrongInstanceDict()
-
-        self.new = {}   # InstanceState->object, strong refs object
-        self.deleted = {}  # same
-        self.logger = logging.instance_logger(self, echoflag=session.echo_uow)
-
-    def _remove_deleted(self, state):
-        if '_instance_key' in state.dict:
-            del self.identity_map[state.dict['_instance_key']]
-        self.deleted.pop(state, None)
-        self.new.pop(state, None)
-
-    def _is_valid(self, state):
-        if '_instance_key' in state.dict:
-            return state.dict['_instance_key'] in self.identity_map
-        else:
-            return state in self.new
-
-    def _register_clean(self, state):
-        """register the given object as 'clean' (i.e. persistent) within this unit of work, after
-        a save operation has taken place."""
-
-        mapper = _state_mapper(state)
-        instance_key = mapper._identity_key_from_state(state)
-        
-        if '_instance_key' not in state.dict:
-            state.dict['_instance_key'] = instance_key
-            
-        elif state.dict['_instance_key'] != instance_key:
-            # primary key switch
-            del self.identity_map[state.dict['_instance_key']]
-            state.dict['_instance_key'] = instance_key
-            
-        if hasattr(state, 'insert_order'):
-            delattr(state, 'insert_order')
-        
-        o = state.obj()
-        # prevent against last minute dereferences of the object
-        # TODO: identify a code path where state.obj() is None
-        if o is not None:
-            self.identity_map[state.dict['_instance_key']] = o
-            state.commit_all()
-        
-        # remove from new last, might be the last strong ref
-        self.new.pop(state, None)
-
-    def register_new(self, obj):
-        """register the given object as 'new' (i.e. unsaved) within this unit of work."""
-
-        if hasattr(obj, '_instance_key'):
-            raise exceptions.InvalidRequestError("Object '%s' already has an identity - it can't be registered as new" % repr(obj))
-        if obj._state not in self.new:
-            self.new[obj._state] = obj
-            obj._state.insert_order = len(self.new)
-
-    def register_deleted(self, obj):
-        """register the given persistent object as 'to be deleted' within this unit of work."""
-        
-        self.deleted[obj._state] = obj
-
-    def locate_dirty(self):
-        """return a set of all persistent instances within this unit of work which 
-        either contain changes or are marked as deleted.
-        """
-        
-        # a little bit of inlining for speed
-        return util.IdentitySet([x for x in self.identity_map.values() 
-            if x._state not in self.deleted 
-            and (
-                x._state.modified
-                or (x.__class__._class_state.has_mutable_scalars and x._state.is_modified())
-            )
-            ])
-
-    def flush(self, session, objects=None):
-        """create a dependency tree of all pending SQL operations within this unit of work and execute."""
-
-        dirty = [x for x in self.identity_map.all_states()
-            if x.modified
-            or (x.class_._class_state.has_mutable_scalars and x.is_modified())
-        ]
-        
-        if not dirty and not self.deleted and not self.new:
-            return
-        
-        deleted = util.Set(self.deleted)
-        new = util.Set(self.new)
-        
-        dirty = util.Set(dirty).difference(deleted)
-        
-        flush_context = UOWTransaction(self, session)
-
-        if session.extension is not None:
-            session.extension.before_flush(session, flush_context, objects)
-
-        # create the set of all objects we want to operate upon
-        if objects:
-            # specific list passed in
-            objset = util.Set([o._state for o in objects])
-        else:
-            # or just everything
-            objset = util.Set(self.identity_map.all_states()).union(new)
-            
-        # store objects whose fate has been decided
-        processed = util.Set()
-
-        # put all saves/updates into the flush context.  detect top-level orphans and throw them into deleted.
-        for state in new.union(dirty).intersection(objset).difference(deleted):
-            if state in processed:
-                continue
-
-            obj = state.obj()
-            is_orphan = _state_mapper(state)._is_orphan(obj)
-            if is_orphan and not has_identity(obj):
-                raise exceptions.FlushError("instance %s is an unsaved, pending instance and is an orphan (is not attached to %s)" %
-                    (
-                        obj,
-                        ", nor ".join(["any parent '%s' instance via that classes' '%s' attribute" % (klass.__name__, key) for (key,klass) in _state_mapper(state).delete_orphans])
-                    ))
-            flush_context.register_object(state, isdelete=is_orphan)
-            processed.add(state)
-
-        # put all remaining deletes into the flush context.
-        for state in deleted.intersection(objset).difference(processed):
-            flush_context.register_object(state, isdelete=True)
-
-        if len(flush_context.tasks) == 0:
-            return
-            
-        session.create_transaction(autoflush=False)
-        flush_context.transaction = session.transaction
-        try:
-            flush_context.execute()
-            
-            if session.extension is not None:
-                session.extension.after_flush(session, flush_context)
-            session.commit()
-        except:
-            session.rollback()
-            raise
-
-        flush_context.post_exec()
-
-        if session.extension is not None:
-            session.extension.after_flush_postexec(session, flush_context)
-
-    def prune_identity_map(self):
-        """Removes unreferenced instances cached in a strong-referencing identity map.
-
-        Note that this method is only meaningful if "weak_identity_map"
-        on the parent Session is set to False and therefore this UnitOfWork's
-        identity map is a regular dictionary
-        
-        Removes any object in the identity map that is not referenced
-        in user code or scheduled for a unit of work operation.  Returns
-        the number of objects pruned.
-        """
-
-        if isinstance(self.identity_map, attributes.WeakInstanceDict):
-            return 0
-        ref_count = len(self.identity_map)
-        dirty = self.locate_dirty()
-        keepers = weakref.WeakValueDictionary(self.identity_map)
-        self.identity_map.clear()
-        self.identity_map.update(keepers)
-        return ref_count - len(self.identity_map)
 
 class UOWTransaction(object):
     """Handles the details of organizing and executing transaction
@@ -275,8 +100,7 @@ class UOWTransaction(object):
     packages.
     """
 
-    def __init__(self, uow, session):
-        self.uow = uow
+    def __init__(self, session):
         self.session = session
         self.mapper_flush_opts = session._mapper_flush_opts
         
@@ -291,7 +115,7 @@ class UOWTransaction(object):
         # information. 
         self.attributes = {}
 
-        self.logger = logging.instance_logger(self, echoflag=session.echo_uow)
+        self.logger = log.instance_logger(self, echoflag=session.echo_uow)
 
     def get_attribute_history(self, state, key, passive=True):
         hashkey = ("history", state, key)
@@ -310,19 +134,18 @@ class UOWTransaction(object):
             (added, unchanged, deleted) = attributes.get_history(state, key, passive=passive)
             self.attributes[hashkey] = (added, unchanged, deleted, passive)
 
-        if added is None:
+        if added is None or not state.get_impl(key).uses_objects:
             return (added, unchanged, deleted)
         else:
             return (
-                [getattr(c, '_state', c) for c in added],
-                [getattr(c, '_state', c) for c in unchanged],
-                [getattr(c, '_state', c) for c in deleted],
+                [c is not None and attributes.instance_state(c) or None for c in added],
+                [c is not None and attributes.instance_state(c) or None for c in unchanged],
+                [c is not None and attributes.instance_state(c) or None for c in deleted],
                 )
 
-        
-    def register_object(self, state, isdelete = False, listonly = False, postupdate=False, post_update_cols=None, **kwargs):
+    def register_object(self, state, isdelete=False, listonly=False, postupdate=False, post_update_cols=None):
         # if object is not in the overall session, do nothing
-        if not self.uow._is_valid(state):
+        if not self.session._contains_state(state):
             if self._should_log_debug:
                 self.logger.debug("object %s not part of session, not registering for flush" % (mapperutil.state_str(state)))
             return
@@ -331,12 +154,12 @@ class UOWTransaction(object):
             self.logger.debug("register object for flush: %s isdelete=%s listonly=%s postupdate=%s" % (mapperutil.state_str(state), isdelete, listonly, postupdate))
 
         mapper = _state_mapper(state)
-        
+
         task = self.get_task_by_mapper(mapper)
         if postupdate:
             task.append_postupdate(state, post_update_cols)
         else:
-            task.append(state, listonly, isdelete=isdelete, **kwargs)
+            task.append(state, listonly=listonly, isdelete=isdelete)
 
     def set_row_switch(self, state):
         """mark a deleted object as a 'row switch'.
@@ -451,22 +274,26 @@ class UOWTransaction(object):
         import uowdumper
         uowdumper.UOWDumper(tasks, buf)
         return buf.getvalue()
-        
-    def post_exec(self):
+    
+    def elements(self):
+        """return an iterator of all UOWTaskElements within this UOWTransaction."""
+        for task in self.tasks.values():
+            for elem in task.elements:
+                yield elem
+    elements = property(elements)
+    
+    def finalize_flush_changes(self):
         """mark processed objects as clean / deleted after a successful flush().
         
         this method is called within the flush() method after the 
         execute() method has succeeded and the transaction has been committed.
         """
 
-        for task in self.tasks.values():
-            for elem in task.elements:
-                if elem.state is None:
-                    continue
-                if elem.isdelete:
-                    self.uow._remove_deleted(elem.state)
-                else:
-                    self.uow._register_clean(elem.state)
+        for elem in self.elements:
+            if elem.isdelete:
+                self.session._remove_newly_deleted(elem.state)
+            else:
+                self.session._register_newly_persistent(elem.state)
 
     def _sort_dependencies(self):
         nodes = topological.sort_with_cycles(self.dependencies, 
@@ -489,10 +316,9 @@ class UOWTransaction(object):
 
 class UOWTask(object):
     """Represents all of the objects in the UOWTransaction which correspond to
-    a particular mapper.  This is the primary class of three classes used to generate
-    the elements of the dependency graph.
+    a particular mapper.  
+    
     """
-
     def __init__(self, uowtransaction, mapper, base_task=None):
         self.uowtransaction = uowtransaction
 
@@ -515,6 +341,7 @@ class UOWTask(object):
         # mapping of InstanceState -> UOWTaskElement
         self._objects = {} 
 
+        self.dependent_tasks = []
         self.dependencies = util.Set()
         self.cyclical_dependencies = util.Set()
 
@@ -564,11 +391,6 @@ class UOWTask(object):
         
         rec.update(listonly, isdelete)
     
-    def _append_cyclical_childtask(self, task):
-        if "cyclical" not in self._objects:
-            self._objects["cyclical"] = UOWTaskElement(None)
-        self._objects["cyclical"].childtasks.append(task)
-
     def append_postupdate(self, state, post_update_cols):
         """issue a 'post update' UPDATE statement via this object's mapper immediately.  
         
@@ -577,8 +399,8 @@ class UOWTask(object):
         """
 
         # postupdates are UPDATED immeditely (for now)
-        # convert post_update_cols list to a Set so that __hashcode__ is used to compare columns
-        # instead of __eq__
+        # convert post_update_cols list to a Set so that __hash__() is used to compare columns
+        # instead of __eq__()
         self.mapper._save_obj([state], self.uowtransaction, postupdate=True, post_update_cols=util.Set(post_update_cols))
 
     def __contains__(self, state):
@@ -607,26 +429,42 @@ class UOWTask(object):
                 for rec in callable(task):
                     yield rec
         return property(collection)
-        
-    elements = property(lambda self:self._objects.values())
     
-    polymorphic_elements = _polymorphic_collection(lambda task:task.elements)
-
-    polymorphic_tosave_elements = property(lambda self: [rec for rec in self.polymorphic_elements
-                                             if not rec.isdelete])
-                                             
-    polymorphic_todelete_elements = property(lambda self:[rec for rec in self.polymorphic_elements
-                                               if rec.isdelete])
+    def _elements(self):
+        return self._objects.values()
+    elements = property(_elements)
+    
+    polymorphic_elements = _polymorphic_collection(_elements)
 
-    polymorphic_tosave_objects = property(lambda self:[rec.state for rec in self.polymorphic_elements
-                                          if rec.state is not None and not rec.listonly and rec.isdelete is False])
+    def polymorphic_tosave_elements(self):
+        return [rec for rec in self.polymorphic_elements if not rec.isdelete]
+    polymorphic_tosave_elements = property(polymorphic_tosave_elements)
+    
+    def polymorphic_todelete_elements(self):
+        return [rec for rec in self.polymorphic_elements if rec.isdelete]
+    polymorphic_todelete_elements = property(polymorphic_todelete_elements)
+
+    def polymorphic_tosave_objects(self):
+        return [
+            rec.state for rec in self.polymorphic_elements
+            if rec.state is not None and not rec.listonly and rec.isdelete is False
+        ]
+    polymorphic_tosave_objects = property(polymorphic_tosave_objects)
 
-    polymorphic_todelete_objects = property(lambda self:[rec.state for rec in self.polymorphic_elements
-                                          if rec.state is not None and not rec.listonly and rec.isdelete is True])
+    def polymorphic_todelete_objects(self):
+        return [
+            rec.state for rec in self.polymorphic_elements
+            if rec.state is not None and not rec.listonly and rec.isdelete is True
+        ]
+    polymorphic_todelete_objects = property(polymorphic_todelete_objects)
 
-    polymorphic_dependencies = _polymorphic_collection(lambda task:task.dependencies)
+    def polymorphic_dependencies(self):
+        return self.dependencies
+    polymorphic_dependencies = _polymorphic_collection(polymorphic_dependencies)
     
-    polymorphic_cyclical_dependencies = _polymorphic_collection(lambda task:task.cyclical_dependencies)
+    def polymorphic_cyclical_dependencies(self):
+        return self.cyclical_dependencies
+    polymorphic_cyclical_dependencies = _polymorphic_collection(polymorphic_cyclical_dependencies)
     
     def _sort_circular_dependencies(self, trans, cycles):
         """Create a hierarchical tree of *subtasks*
@@ -741,7 +579,7 @@ class UOWTask(object):
             if t is None:
                 t = UOWTask(self.uowtransaction, originating_task.mapper)
                 nexttasks[originating_task] = t
-                parenttask._append_cyclical_childtask(t)
+                parenttask.dependent_tasks.append(t)
             t.append(state, originating_task._objects[state].listonly, isdelete=originating_task._objects[state].isdelete)
 
             if state in dependencies:
@@ -777,29 +615,17 @@ class UOWTask(object):
         return ret
 
     def __repr__(self):
-        if self.mapper is not None:
-            if self.mapper.__class__.__name__ == 'Mapper':
-                name = self.mapper.class_.__name__ + "/" + self.mapper.local_table.description
-            else:
-                name = repr(self.mapper)
-        else:
-            name = '(none)'
-        return ("UOWTask(%s) Mapper: '%s'" % (hex(id(self)), name))
+        return ("UOWTask(%s) Mapper: '%r'" % (hex(id(self)), self.mapper))
 
 class UOWTaskElement(object):
-    """An element within a UOWTask.
-
-    Corresponds to a single object instance to be saved, deleted, or
-    just part of the transaction as a placeholder for further
-    dependencies (i.e. 'listonly').
-
-    may also store additional sub-UOWTasks.
+    """Corresponds to a single InstanceState to be saved, deleted,
+    or otherwise marked as having dependencies.  A collection of 
+    UOWTaskElements are held by a UOWTask.
+    
     """
-
     def __init__(self, state):
         self.state = state
         self.listonly = True
-        self.childtasks = []
         self.isdelete = False
         self.__preprocessed = {}
 
@@ -835,11 +661,11 @@ class UOWTaskElement(object):
 
 class UOWDependencyProcessor(object):
     """In between the saving and deleting of objects, process
-    *dependent* data, such as filling in a foreign key on a child item
+    dependent data, such as filling in a foreign key on a child item
     from a new primary key, or deleting association rows before a
     delete.  This object acts as a proxy to a DependencyProcessor.
+    
     """
-
     def __init__(self, processor, targettask):
         self.processor = processor
         self.targettask = targettask
@@ -877,12 +703,12 @@ class UOWDependencyProcessor(object):
             return elem.state
 
         ret = False
-        elements = [getobj(elem) for elem in self.targettask.polymorphic_tosave_elements if elem.state is not None and not elem.is_preprocessed(self)]
+        elements = [getobj(elem) for elem in self.targettask.polymorphic_tosave_elements if not elem.is_preprocessed(self)]
         if elements:
             ret = True
             self.processor.preprocess_dependencies(self.targettask, elements, trans, delete=False)
 
-        elements = [getobj(elem) for elem in self.targettask.polymorphic_todelete_elements if elem.state is not None and not elem.is_preprocessed(self)]
+        elements = [getobj(elem) for elem in self.targettask.polymorphic_todelete_elements if not elem.is_preprocessed(self)]
         if elements:
             ret = True
             self.processor.preprocess_dependencies(self.targettask, elements, trans, delete=True)
@@ -892,9 +718,9 @@ class UOWDependencyProcessor(object):
         """process all objects contained within this ``UOWDependencyProcessor``s target task."""
         
         if not delete:
-            self.processor.process_dependencies(self.targettask, [elem.state for elem in self.targettask.polymorphic_tosave_elements if elem.state is not None], trans, delete=False)
+            self.processor.process_dependencies(self.targettask, [elem.state for elem in self.targettask.polymorphic_tosave_elements], trans, delete=False)
         else:
-            self.processor.process_dependencies(self.targettask, [elem.state for elem in self.targettask.polymorphic_todelete_elements if elem.state is not None], trans, delete=True)
+            self.processor.process_dependencies(self.targettask, [elem.state for elem in self.targettask.polymorphic_todelete_elements], trans, delete=True)
 
     def get_object_dependencies(self, state, trans, passive):
         return trans.get_attribute_history(state, self.processor.key, passive=passive)
@@ -907,7 +733,6 @@ class UOWDependencyProcessor(object):
         when toplogically sorting on a per-instance basis.
         
         """
-        
         return self.processor.whose_dependent_on_who(state1, state2)
 
     def branch(self, task):
@@ -917,7 +742,6 @@ class UOWDependencyProcessor(object):
         is broken up into many individual ``UOWTask`` objects.
         
         """
-        
         return UOWDependencyProcessor(self.processor, task)
     
         
@@ -944,13 +768,11 @@ class UOWExecutor(object):
     def execute_save_steps(self, trans, task):
         self.save_objects(trans, task)
         self.execute_cyclical_dependencies(trans, task, False)
-        self.execute_per_element_childtasks(trans, task, False)
         self.execute_dependencies(trans, task, False)
         self.execute_dependencies(trans, task, True)
-        
+
     def execute_delete_steps(self, trans, task):
         self.execute_cyclical_dependencies(trans, task, True)
-        self.execute_per_element_childtasks(trans, task, True)
         self.delete_objects(trans, task)
 
     def execute_dependencies(self, trans, task, isdelete=None):
@@ -964,12 +786,5 @@ class UOWExecutor(object):
     def execute_cyclical_dependencies(self, trans, task, isdelete):
         for dep in task.polymorphic_cyclical_dependencies:
             self.execute_dependency(trans, dep, isdelete)
-
-    def execute_per_element_childtasks(self, trans, task, isdelete):
-        for element in task.polymorphic_tosave_elements + task.polymorphic_todelete_elements:
-            self.execute_element_childtasks(trans, element, isdelete)
-
-    def execute_element_childtasks(self, trans, element, isdelete):
-        for child in element.childtasks:
-            self.execute(trans, [child], isdelete)
-
+        for t in task.dependent_tasks:
+            self.execute(trans, [t], isdelete)
index 4b3fed70aa1e664989f614a8c9e28624de23fa9b..09b82167da36734fd3ae16f3c9087d7dcc3fec16 100644 (file)
@@ -6,17 +6,15 @@
 
 """Dumps out a string representation of a UOWTask structure"""
 
+from sqlalchemy import util
 from sqlalchemy.orm import unitofwork
 from sqlalchemy.orm import util as mapperutil
-from sqlalchemy import util
 
 class UOWDumper(unitofwork.UOWExecutor):
-    def __init__(self, tasks, buf, verbose=False):
-        self.verbose = verbose
+    def __init__(self, tasks, buf):
         self.indent = 0
         self.tasks = tasks
         self.buf = buf
-        self.headers = {}
         self.execute(None, tasks)
 
     def execute(self, trans, tasks, isdelete=None):
@@ -62,88 +60,23 @@ class UOWDumper(unitofwork.UOWExecutor):
         for rec in l:
             if rec.listonly:
                 continue
-            self.header("Save elements"+ self._inheritance_tag(task))
             self.buf.write(self._indent()[:-1] + "+-" + self._repr_task_element(rec)  + "\n")
-            self.closeheader()
 
     def delete_objects(self, trans, task):
         for rec in task.polymorphic_todelete_elements:
             if rec.listonly:
                 continue
-            self.header("Delete elements"+ self._inheritance_tag(task))
             self.buf.write(self._indent() + "- " + self._repr_task_element(rec)  + "\n")
-            self.closeheader()
-
-    def _inheritance_tag(self, task):
-        if not self.verbose:
-            return ""
-        else:
-            return (" (inheriting task %s)" % self._repr_task(task))
-
-    def header(self, text):
-        """Write a given header just once."""
-
-        if not self.verbose:
-            return
-        try:
-            self.headers[text]
-        except KeyError:
-            self.buf.write(self._indent() +  "- " + text + "\n")
-            self.headers[text] = True
-
-    def closeheader(self):
-        if not self.verbose:
-            return
-        self.buf.write(self._indent() + "- ------\n")
 
     def execute_dependency(self, transaction, dep, isdelete):
         self._dump_processor(dep, isdelete)
 
-    def execute_save_steps(self, trans, task):
-        super(UOWDumper, self).execute_save_steps(trans, task)
-
-    def execute_delete_steps(self, trans, task):
-        super(UOWDumper, self).execute_delete_steps(trans, task)
-
-    def execute_dependencies(self, trans, task, isdelete=None):
-        super(UOWDumper, self).execute_dependencies(trans, task, isdelete)
-
-    def execute_cyclical_dependencies(self, trans, task, isdelete):
-        self.header("Cyclical %s dependencies" % (isdelete and "delete" or "save"))
-        super(UOWDumper, self).execute_cyclical_dependencies(trans, task, isdelete)
-        self.closeheader()
-
-    def execute_per_element_childtasks(self, trans, task, isdelete):
-        super(UOWDumper, self).execute_per_element_childtasks(trans, task, isdelete)
-
-    def execute_element_childtasks(self, trans, element, isdelete):
-        self.header("%s subelements of UOWTaskElement(%s)" % ((isdelete and "Delete" or "Save"), hex(id(element))))
-        super(UOWDumper, self).execute_element_childtasks(trans, element, isdelete)
-        self.closeheader()
-
     def _dump_processor(self, proc, deletes):
         if deletes:
             val = proc.targettask.polymorphic_todelete_elements
         else:
             val = proc.targettask.polymorphic_tosave_elements
 
-        if self.verbose:
-            self.buf.write(self._indent() + "   +- %s attribute on %s (UOWDependencyProcessor(%d) processing %s)\n" % (
-                repr(proc.processor.key),
-                    ("%s's to be %s" % (self._repr_task_class(proc.targettask), deletes and "deleted" or "saved")),
-                hex(id(proc)),
-                self._repr_task(proc.targettask))
-            )
-        elif False:
-            self.buf.write(self._indent() + "   +- %s attribute on %s\n" % (
-                repr(proc.processor.key),
-                    ("%s's to be %s" % (self._repr_task_class(proc.targettask), deletes and "deleted" or "saved")),
-                )
-            )
-
-        if len(val) == 0:
-            if self.verbose:
-                self.buf.write(self._indent() + "   +- " + "(no objects)\n")
         for v in val:
             self.buf.write(self._indent() + "   +- " + self._repr_task_element(v, proc.processor.key, process=True) + "\n")
 
@@ -155,9 +88,7 @@ class UOWDumper(unitofwork.UOWExecutor):
                 objid = "%s.%s" % (mapperutil.state_str(te.state), attribute)
             else:
                 objid = mapperutil.state_str(te.state)
-        if self.verbose:
-            return "%s (UOWTaskElement(%s, %s))" % (objid, hex(id(te)), (te.listonly and 'listonly' or (te.isdelete and 'delete' or 'save')))
-        elif process:
+        if process:
             return "Process %s" % (objid)
         else:
             return "%s %s" % ((te.isdelete and "Delete" or "Save"), objid)
index 19e5e59b93280f213d021dfc47b63b8552fb26bf..09b5aa7780b45a23d8ce70ce0bd21678d5ccaf13 100644 (file)
@@ -4,15 +4,19 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-from sqlalchemy import sql, util, exceptions
-from sqlalchemy.sql import util as sql_util
-from sqlalchemy.sql.util import row_adapter as create_row_adapter
-from sqlalchemy.sql import visitors
-from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE, PropComparator
+import new
 
-all_cascades = util.Set(["delete", "delete-orphan", "all", "merge",
+import sqlalchemy.exceptions as sa_exc
+from sqlalchemy import sql, util
+from sqlalchemy.sql import expression, util as sql_util, operators
+from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE, PropComparator, MapperProperty
+from sqlalchemy.orm import attributes
+
+all_cascades = util.FrozenSet(["delete", "delete-orphan", "all", "merge",
                          "expunge", "save-update", "refresh-expire", "none"])
 
+_INSTRUMENTOR = ('mapper', 'instrumentor')
+
 class CascadeOptions(object):
     """Keeps track of the options sent to relation().cascade"""
 
@@ -26,7 +30,7 @@ class CascadeOptions(object):
         self.refresh_expire = "refresh-expire" in values or "all" in values
         for x in values:
             if x not in all_cascades:
-                raise exceptions.ArgumentError("Invalid cascade option '%s'" % x)
+                raise sa_exc.ArgumentError("Invalid cascade option '%s'" % x)
 
     def __contains__(self, item):
         return getattr(self, item.replace("-", "_"), False)
@@ -78,235 +82,277 @@ def polymorphic_union(table_map, typecolname, aliasname='p_union'):
             result.append(sql.select([col(name, table) for name in colnames], from_obj=[table]))
     return sql.union_all(*result).alias(aliasname)
 
+def identity_key(*args, **kwargs):
+    """Get an identity key.
 
-class ExtensionCarrier(object):
-    """stores a collection of MapperExtension objects.
-    
-    allows an extension methods to be called on contained MapperExtensions
-    in the order they were added to this object.  Also includes a 'methods' dictionary
-    accessor which allows for a quick check if a particular method
-    is overridden on any contained MapperExtensions.
+    Valid call signatures:
+
+    * ``identity_key(class, ident, entity_name=None)``
+
+      class
+          mapped class (must be a positional argument)
+
+      ident
+          primary key, if the key is composite this is a tuple
+
+      entity_name
+          optional entity name
+
+    * ``identity_key(instance=instance)``
+
+      instance
+          object instance (must be given as a keyword arg)
+
+    * ``identity_key(class, row=row, entity_name=None)``
+
+      class
+          mapped class (must be a positional argument)
+
+      row
+          result proxy row (must be given as a keyword arg)
+
+      entity_name
+          optional entity name (must be given as a keyword arg)
     """
+    from sqlalchemy.orm import class_mapper, object_mapper
+    if args:
+        if len(args) == 1:
+            class_ = args[0]
+            try:
+                row = kwargs.pop("row")
+            except KeyError:
+                ident = kwargs.pop("ident")
+            entity_name = kwargs.pop("entity_name", None)
+        elif len(args) == 2:
+            class_, ident = args
+            entity_name = kwargs.pop("entity_name", None)
+        elif len(args) == 3:
+            class_, ident, entity_name = args
+        else:
+            raise sa_exc.ArgumentError("expected up to three "
+                "positional arguments, got %s" % len(args))
+        if kwargs:
+            raise sa_exc.ArgumentError("unknown keyword arguments: %s"
+                % ", ".join(kwargs.keys()))
+        mapper = class_mapper(class_, entity_name=entity_name)
+        if "ident" in locals():
+            return mapper.identity_key_from_primary_key(ident)
+        return mapper.identity_key_from_row(row)
+    instance = kwargs.pop("instance")
+    if kwargs:
+        raise sa_exc.ArgumentError("unknown keyword arguments: %s"
+            % ", ".join(kwargs.keys()))
+    mapper = object_mapper(instance)
+    return mapper.identity_key_from_instance(instance)
     
-    def __init__(self, _elements=None):
+class ExtensionCarrier(object):
+    """Fronts an ordered collection of MapperExtension objects.
+
+    Bundles multiple MapperExtensions into a unified callable unit,
+    encapsulating ordering, looping and EXT_CONTINUE logic.  The
+    ExtensionCarrier implements the MapperExtension interface, e.g.::
+
+      carrier.after_insert(...args...)
+
+    Also includes a 'methods' dictionary accessor which allows for a quick
+    check if a particular method is overridden on any contained
+    MapperExtensions.
+
+    """
+
+    interface = util.Set([method for method in dir(MapperExtension)
+                          if not method.startswith('_')])
+
+    def __init__(self, extensions=None):
         self.methods = {}
-        if _elements is not None:
-            self.__elements = [self.__inspect(e) for e in _elements]
-        else:
-            self.__elements = []
-        
-    def copy(self):
-        return ExtensionCarrier(list(self.__elements))
-        
-    def __iter__(self):
-        return iter(self.__elements)
+        self._extensions = []
+        for ext in extensions or ():
+            self.append(ext)
 
-    def insert(self, extension):
-        """Insert a MapperExtension at the beginning of this ExtensionCarrier's list."""
+    def copy(self):
+        return ExtensionCarrier(self._extensions)
 
-        self.__elements.insert(0, self.__inspect(extension))
+    def push(self, extension):
+        """Insert a MapperExtension at the beginning of the collection."""
+        self._register(extension)
+        self._extensions.insert(0, extension)
 
     def append(self, extension):
-        """Append a MapperExtension at the end of this ExtensionCarrier's list."""
+        """Append a MapperExtension at the end of the collection."""
+        self._register(extension)
+        self._extensions.append(extension)
 
-        self.__elements.append(self.__inspect(extension))
+    def __iter__(self):
+        """Iterate over MapperExtensions in the collection."""
+        return iter(self._extensions)
+
+    def _register(self, extension):
+        """Register callable fronts for overridden interface methods."""
+        for method in self.interface:
+            if method in self.methods:
+                continue
+            impl = getattr(extension, method, None)
+            if impl and impl is not getattr(MapperExtension, method):
+                self.methods[method] = self._create_do(method)
+
+    def _create_do(self, method):
+        """Return a closure that loops over impls of the named method."""
 
-    def __inspect(self, extension):
-        for meth in MapperExtension.__dict__.keys():
-            if meth not in self.methods and hasattr(extension, meth) and getattr(extension, meth) is not getattr(MapperExtension, meth):
-                self.methods[meth] = self.__create_do(meth)
-        return extension
-           
-    def __create_do(self, funcname):
         def _do(*args, **kwargs):
-            for elem in self.__elements:
-                ret = getattr(elem, funcname)(*args, **kwargs)
+            for ext in self._extensions:
+                ret = getattr(ext, method)(*args, **kwargs)
                 if ret is not EXT_CONTINUE:
                     return ret
             else:
                 return EXT_CONTINUE
-
         try:
-            _do.__name__ = funcname
+            _do.__name__ = method.im_func.func_name
         except:
-            # cant set __name__ in py 2.3 
             pass
         return _do
-    
-    def _pass(self, *args, **kwargs):
+
+    def _pass(*args, **kwargs):
         return EXT_CONTINUE
-        
+    _pass = staticmethod(_pass)
+
     def __getattr__(self, key):
+        """Delegate MapperExtension methods to bundled fronts."""
+        if key not in self.interface:
+            raise AttributeError(key)
         return self.methods.get(key, self._pass)
 
-class AliasedClauses(object):
-    """Creates aliases of a mapped tables for usage in ORM queries, and provides expression adaptation."""
-
-    def __init__(self, alias, equivalents=None, chain_to=None, should_adapt=True):
-        self.alias = alias
-        self.equivalents = equivalents
-        self.row_decorator = self._create_row_adapter()
-        self.should_adapt = should_adapt
-        if should_adapt:
-            self.adapter = sql_util.ClauseAdapter(self.alias, equivalents=equivalents)
+class ORMAdapter(sql_util.ColumnAdapter):
+    def __init__(self, entity, equivalents=None, chain_to=None):
+        mapper, selectable, is_aliased_class = _entity_info(entity)
+        if is_aliased_class:
+            self.aliased_class = entity
         else:
-            self.adapter = visitors.NullVisitor()
-
-        if chain_to:
-            self.adapter.chain(chain_to.adapter)
-            
-    def aliased_column(self, column):
-        if not self.should_adapt:
-            return column
-            
-        conv = self.alias.corresponding_column(column)
-        if conv:
-            return conv
-        
-        # process column-level subqueries    
-        aliased_column = sql_util.ClauseAdapter(self.alias, equivalents=self.equivalents).traverse(column, clone=True)
-
-        # anonymize labels which might have specific names
-        if isinstance(aliased_column, expression._Label):
-            aliased_column = aliased_column.label(None)
-
-        # add to row decorator explicitly
-        self.row_decorator({}).map[column] = aliased_column
-        return aliased_column
-
-    def adapt_clause(self, clause):
-        return self.adapter.traverse(clause, clone=True)
-    
-    def adapt_list(self, clauses):
-        return self.adapter.copy_and_process(clauses)
-        
-    def _create_row_adapter(self):
-        return create_row_adapter(self.alias, equivalent_columns=self.equivalents)
+            self.aliased_class = None
+        sql_util.ColumnAdapter.__init__(self, selectable, equivalents, chain_to)
 
+class AliasedClass(object):
+    def __init__(self, cls, alias=None, name=None):
+        self.__mapper = _class_to_mapper(cls)
+        self.__target = self.__mapper.class_
+        alias = alias or self.__mapper._with_polymorphic_selectable.alias()
+        self.__adapter = sql_util.ClauseAdapter(alias, equivalents=self.__mapper._equivalent_columns)
+        self.__alias = alias
+        self._sa_label_name = name
+        self.__name__ = 'AliasedClass_' + str(self.__target)
+
+    def __adapt_prop(self, prop):
+        existing = getattr(self.__target, prop.key)
+        comparator = AliasedComparator(self, self.__adapter, existing.comparator)
+        queryattr = attributes.QueryableAttribute(
+            existing.impl, parententity=self, comparator=comparator)
+        setattr(self, prop.key, queryattr)
+        return queryattr
 
-class PropertyAliasedClauses(AliasedClauses):
-    """extends AliasedClauses to add support for primary/secondary joins on a relation()."""
-    
-    def __init__(self, prop, primaryjoin, secondaryjoin, parentclauses=None, alias=None, should_adapt=True):
-        self.prop = prop
-        self.mapper = self.prop.mapper
-        self.table = self.prop.table
-        self.parentclauses = parentclauses
-
-        if not alias:
-            from_obj = self.mapper._with_polymorphic_selectable()
-            alias = from_obj.alias()
-
-        super(PropertyAliasedClauses, self).__init__(alias, equivalents=self.mapper._equivalent_columns, chain_to=parentclauses, should_adapt=should_adapt)
-        
-        if prop.secondary:
-            self.secondary = prop.secondary.alias()
-            primary_aliasizer = sql_util.ClauseAdapter(self.secondary)
-            secondary_aliasizer = sql_util.ClauseAdapter(self.alias, equivalents=self.equivalents).chain(sql_util.ClauseAdapter(self.secondary))
-
-            if parentclauses is not None:
-                primary_aliasizer.chain(sql_util.ClauseAdapter(parentclauses.alias, equivalents=parentclauses.equivalents))
-
-            self.secondaryjoin = secondary_aliasizer.traverse(secondaryjoin, clone=True)
-            self.primaryjoin = primary_aliasizer.traverse(primaryjoin, clone=True)
-        else:
-            primary_aliasizer = sql_util.ClauseAdapter(self.alias, exclude=prop.local_side, equivalents=self.equivalents)
-            if parentclauses is not None: 
-                primary_aliasizer.chain(sql_util.ClauseAdapter(parentclauses.alias, exclude=prop.remote_side, equivalents=parentclauses.equivalents))
-            
-            self.primaryjoin = primary_aliasizer.traverse(primaryjoin, clone=True)
-            self.secondary = None
-            self.secondaryjoin = None
-        
-        if prop.order_by:
-            if prop.secondary:
-                # usually this is not used but occasionally someone has a sort key in their secondary
-                # table, even tho SA does not support writing this column directly
-                self.order_by = secondary_aliasizer.copy_and_process(util.to_list(prop.order_by))
+    def __getattr__(self, key):
+        prop = self.__mapper._get_property(key, raiseerr=False)
+        if prop:
+            return self.__adapt_prop(prop)
+
+        for base in self.__target.__mro__:
+            try:
+                attr = object.__getattribute__(base, key)
+            except AttributeError:
+                continue
             else:
-                self.order_by = primary_aliasizer.copy_and_process(util.to_list(prop.order_by))
-                
+                break
         else:
-            self.order_by = None
+            raise AttributeError(key)
 
-class AliasedClass(object):
-    def __new__(cls, target):
-        from sqlalchemy.orm import attributes
-        mapper = _class_to_mapper(target)
-        alias = mapper.mapped_table.alias()
-        retcls = type(target.__name__ + "Alias", (cls,), {'alias':alias})
-        retcls._class_state = mapper._class_state
-        for prop in mapper.iterate_properties:
-            existing = mapper._class_state.attrs[prop.key]
-            setattr(retcls, prop.key, attributes.InstrumentedAttribute(existing.impl, comparator=AliasedComparator(alias, existing.comparator)))
-
-        return retcls
+        if hasattr(attr, 'func_code'):
+            is_method = getattr(self.__target, key, None)
+            if is_method and is_method.im_self is not None:
+                return new.instancemethod(attr.im_func, self, self)
+            else:
+                return None
+        elif hasattr(attr, '__get__'):
+            return attr.__get__(None, self)
+        else:
+            return attr
 
-    def __init__(self, alias):
-        self.alias = alias
+    def __repr__(self):
+        return '<AliasedClass at 0x%x; %s>' % (
+            id(self), self.__target.__name__)
 
 class AliasedComparator(PropComparator):
-    def __init__(self, alias, comparator):
-        self.alias = alias
+    def __init__(self, aliasedclass, adapter, comparator):
+        self.aliasedclass = aliasedclass
         self.comparator = comparator
-        self.adapter = sql_util.ClauseAdapter(alias) 
+        self.adapter = adapter
+        self.__clause_element = self.adapter.traverse(self.comparator.__clause_element__())._annotate({'parententity': aliasedclass})
 
-    def clause_element(self):
-        return self.adapter.traverse(self.comparator.clause_element(), clone=True)
+    def __clause_element__(self):
+        return self.__clause_element
 
     def operate(self, op, *other, **kwargs):
-        return self.adapter.traverse(self.comparator.operate(op, *other, **kwargs), clone=True)
+        return self.adapter.traverse(self.comparator.operate(op, *other, **kwargs))
 
     def reverse_operate(self, op, other, **kwargs):
-        return self.adapter.traverse(self.comparator.reverse_operate(op, *other, **kwargs), clone=True)
-
-from sqlalchemy.sql import expression
-_selectable = expression._selectable
-def _orm_selectable(selectable):
-    if _is_mapped_class(selectable):
-        if _is_aliased_class(selectable):
-            return selectable.alias
-        else:
-            return _class_to_mapper(selectable)._with_polymorphic_selectable()
-    else:
-        return _selectable(selectable)
-expression._selectable = _orm_selectable
+        return self.adapter.traverse(self.comparator.reverse_operate(op, *other, **kwargs))
+
+def _orm_annotate(element, exclude=None):
+    def clone(elem):
+        if exclude and elem in exclude:
+            elem = elem._clone()
+        elif '_orm_adapt' not in elem._annotations:
+            elem = elem._annotate({'_orm_adapt':True})
+        elem._copy_internals(clone=clone)
+        return elem
+    
+    if element is not None:
+        element = clone(element)
+    return element
+
 
 class _ORMJoin(expression.Join):
-    """future functionality."""
 
     __visit_name__ = expression.Join.__visit_name__
-    
+
     def __init__(self, left, right, onclause=None, isouter=False):
-        if _is_mapped_class(left) or _is_mapped_class(right):
-            if hasattr(left, '_orm_mappers'):
-                left_mapper = left._orm_mappers[1]
-                adapt_from = left.right
+        if hasattr(left, '_orm_mappers'):
+            left_mapper = left._orm_mappers[1]
+            adapt_from = left.right
+
+        else:
+            left_mapper, left, left_is_aliased = _entity_info(left)
+            if left_is_aliased or not left_mapper:
+                adapt_from = left
             else:
-                left_mapper = _class_to_mapper(left)
-                if _is_aliased_class(left):
-                    adapt_from = left.alias
-                else:
-                    adapt_from = None
+                adapt_from = None
 
-            right_mapper = _class_to_mapper(right)
+        right_mapper, right, right_is_aliased = _entity_info(right)
+        if right_is_aliased:
+            adapt_to = right
+        else:
+            adapt_to = None
+
+        if left_mapper or right_mapper:
             self._orm_mappers = (left_mapper, right_mapper)
-            
+
             if isinstance(onclause, basestring):
                 prop = left_mapper.get_property(onclause)
+            elif isinstance(onclause, attributes.QueryableAttribute):
+                adapt_from = onclause.__clause_element__()
+                prop = onclause.property
+            elif isinstance(onclause, MapperProperty):
+                prop = onclause
+            else:
+                prop = None
 
-                if _is_aliased_class(right):
-                    adapt_to = right.alias
-                else:
-                    adapt_to = None
-
-                pj, sj, source, dest, target_adapter = prop._create_joins(source_selectable=adapt_from, dest_selectable=adapt_to, source_polymorphic=True, dest_polymorphic=True)
+            if prop:
+                pj, sj, source, dest, secondary, target_adapter = prop._create_joins(source_selectable=adapt_from, dest_selectable=adapt_to, source_polymorphic=True, dest_polymorphic=True)
 
                 if sj:
-                    left = sql.join(left, prop.secondary, onclause=pj)
+                    left = sql.join(left, secondary, pj, isouter)
                     onclause = sj
                 else:
                     onclause = pj
+                self._target_adapter = target_adapter
+
         expression.Join.__init__(self, left, right, onclause, isouter)
 
     def join(self, right, onclause=None, isouter=False):
@@ -315,37 +361,81 @@ class _ORMJoin(expression.Join):
     def outerjoin(self, right, onclause=None):
         return _ORMJoin(self, right, onclause, True)
 
-def _join(left, right, onclause=None):
-    """future functionality."""
-    
-    return _ORMJoin(left, right, onclause, False)
-
-def _outerjoin(left, right, onclause=None):
-    """future functionality."""
+def join(left, right, onclause=None, isouter=False):
+    return _ORMJoin(left, right, onclause, isouter)
 
+def outerjoin(left, right, onclause=None):
     return _ORMJoin(left, right, onclause, True)
-    
-def has_identity(object):
-    return hasattr(object, '_instance_key')
 
-def _state_has_identity(state):
-    return '_instance_key' in state.dict
+def with_parent(instance, prop):
+    """Return criterion which selects instances with a given parent.
 
-def _is_mapped_class(cls):
-    return hasattr(cls, '_class_state')
+    instance
+      a parent instance, which should be persistent or detached.
+
+     property
+       a class-attached descriptor, MapperProperty or string property name
+       attached to the parent instance.
+
+     \**kwargs
+       all extra keyword arguments are propagated to the constructor of
+       Query.
 
-def _is_aliased_class(obj):
-    return isinstance(obj, type) and issubclass(obj, AliasedClass)
-    
-def has_mapper(object):
-    """Return True if the given object has had a mapper association
-    set up, either through loading, or via insertion in a session.
     """
+    if isinstance(prop, basestring):
+        mapper = object_mapper(instance)
+        prop = mapper.get_property(prop, resolve_synonyms=True)
+    elif isinstance(prop, attributes.QueryableAttribute):
+        prop = prop.property
+
+    return prop.compare(operators.eq, instance, value_is_parent=True)
+
+
+def _entity_info(entity, entity_name=None, compile=True):
+    if isinstance(entity, AliasedClass):
+        return entity._AliasedClass__mapper, entity._AliasedClass__alias, True
+    elif _is_mapped_class(entity):
+        if isinstance(entity, type):
+            mapper = class_mapper(entity, entity_name, compile)
+        else:
+            if compile:
+                mapper = entity.compile()
+            else:
+                mapper = entity
+        return mapper, mapper._with_polymorphic_selectable, False
+    else:
+        return None, entity, False
+
+def _entity_descriptor(entity, key):
+    if isinstance(entity, AliasedClass):
+        desc = getattr(entity, key)
+        return desc, desc.property
+    elif isinstance(entity, type):
+        desc = attributes.manager_of_class(entity)[key]
+        return desc, desc.property
+    else:
+        desc = entity.class_manager[key]
+        return desc, desc.property
+
+def _orm_columns(entity):
+    mapper, selectable, is_aliased_class = _entity_info(entity)
+    if isinstance(selectable, expression.Selectable):
+        return [c for c in selectable.c]
+    else:
+        return [selectable]
+
+def _orm_selectable(entity):
+    mapper, selectable, is_aliased_class = _entity_info(entity)
+    return selectable
 
-    return hasattr(object, '_entity_name')
+def _is_aliased_class(entity):
+    return isinstance(entity, AliasedClass)
 
 def _state_mapper(state, entity_name=None):
-    return state.class_._class_state.mappers[state.dict.get('_entity_name', entity_name)]
+    if state.entity_name is not attributes.NO_ENTITY_NAME:
+        # Override the given entity name if the object is not transient.
+        entity_name = state.entity_name
+    return state.manager.mappers[entity_name]
 
 def object_mapper(object, entity_name=None, raiseerror=True):
     """Given an object, return the primary Mapper associated with the object instance.
@@ -363,36 +453,40 @@ def object_mapper(object, entity_name=None, raiseerror=True):
             be located.  If False, return None.
 
     """
-
-    try:
-        mapper = object.__class__._class_state.mappers[getattr(object, '_entity_name', entity_name)]
-    except (KeyError, AttributeError):
-        if raiseerror:
-            raise exceptions.InvalidRequestError("Class '%s' entity name '%s' has no mapper associated with it" % (object.__class__.__name__, getattr(object, '_entity_name', entity_name)))
-        else:
-            return None
-    return mapper
+    state = attributes.instance_state(object)
+    if state.entity_name is not attributes.NO_ENTITY_NAME:
+        # Override the given entity name if the object is not transient.
+        entity_name = state.entity_name
+    return class_mapper(
+        type(object), entity_name=entity_name,
+        compile=False, raiseerror=raiseerror)
 
 def class_mapper(class_, entity_name=None, compile=True, raiseerror=True):
-    """Given a class and optional entity_name, return the primary Mapper associated with the key.
+    """Given a class (or an object) and optional entity_name, return the primary Mapper associated with the key.
 
     If no mapper can be located, raises ``InvalidRequestError``.
-    """
 
+    """
+    
+    if not isinstance(class_, type):
+        class_ = type(class_)
     try:
-        mapper = class_._class_state.mappers[entity_name]
+        class_manager = attributes.manager_of_class(class_)
+        mapper = class_manager.mappers[entity_name]
     except (KeyError, AttributeError):
-        if raiseerror:
-            raise exceptions.InvalidRequestError("Class '%s' entity name '%s' has no mapper associated with it" % (class_.__name__, entity_name))
-        else:
-            return None
+        if not raiseerror:
+            return
+        raise sa_exc.InvalidRequestError(
+            "Class '%s' entity name '%s' has no mapper associated with it" %
+            (class_.__name__, entity_name))
     if compile:
-        return mapper.compile()
-    else:
-        return mapper
+        mapper = mapper.compile()
+    return mapper
 
 def _class_to_mapper(class_or_mapper, entity_name=None, compile=True):
-    if isinstance(class_or_mapper, type):
+    if _is_aliased_class(class_or_mapper):
+        return class_or_mapper._AliasedClass__mapper
+    elif isinstance(class_or_mapper, type):
         return class_mapper(class_or_mapper, entity_name=entity_name, compile=compile)
     else:
         if compile:
@@ -400,10 +494,32 @@ def _class_to_mapper(class_or_mapper, entity_name=None, compile=True):
         else:
             return class_or_mapper
 
+def has_identity(object):
+    state = attributes.instance_state(object)
+    return _state_has_identity(state)
+
+def _state_has_identity(state):
+    return bool(state.key)
+
+def has_mapper(object):
+    state = attributes.instance_state(object)
+    return _state_has_mapper(state)
+
+def _state_has_mapper(state):
+    return state.entity_name is not attributes.NO_ENTITY_NAME
+
+def _is_mapped_class(cls):
+    from sqlalchemy.orm import mapperlib as mapper
+    if isinstance(cls, (AliasedClass, mapper.Mapper)):
+        return True
+
+    manager = attributes.manager_of_class(cls)
+    return manager and _INSTRUMENTOR in manager.info
+
 def instance_str(instance):
     """Return a string describing an instance."""
 
-    return instance.__class__.__name__ + "@" + hex(id(instance))
+    return state_str(attributes.instance_state(instance))
 
 def state_str(state):
     """Return a string describing an instance."""
@@ -415,12 +531,24 @@ def state_str(state):
 def attribute_str(instance, attribute):
     return instance_str(instance) + "." + attribute
 
+def state_attribute_str(state, attribute):
+    return state_str(state) + "." + attribute
+
 def identity_equal(a, b):
     if a is b:
         return True
-    id_a = getattr(a, '_instance_key', None)
-    id_b = getattr(b, '_instance_key', None)
-    if id_a is None or id_b is None:
+    if a is None or b is None:
+        return False
+    try:
+        state_a = attributes.instance_state(a)
+        state_b = attributes.instance_state(b)
+    except (KeyError, AttributeError):
+        return False
+    if state_a.key is None or state_b.key is None:
         return False
-    return id_a == id_b
+    return state_a.key == state_b.key
 
+# TODO: Avoid circular import.
+attributes.identity_equal = identity_equal
+attributes._is_aliased_class = _is_aliased_class
+attributes._entity_info = _entity_info
index 31adf77d12d3e8f61f78b9ec6164e8b7bb80f157..c1b29a1d0d0bafa267fd01a2f20b7f566e82fe62 100644 (file)
@@ -18,7 +18,7 @@ SQLAlchemy connection pool.
 
 import weakref, time
 
-from sqlalchemy import exceptions, logging
+from sqlalchemy import exc, log
 from sqlalchemy import queue as Queue
 from sqlalchemy.util import thread, threading, pickle, as_interface
 
@@ -118,7 +118,7 @@ class Pool(object):
     """
     def __init__(self, creator, recycle=-1, echo=None, use_threadlocal=True,
                  reset_on_return=True, listeners=None):
-        self.logger = logging.instance_logger(self, echoflag=echo)
+        self.logger = log.instance_logger(self, echoflag=echo)
         # the WeakValueDictionary works more nicely than a regular dict of
         # weakrefs.  the latter can pile up dead reference objects which don't
         # get cleaned out.  WVD adds from 1-6 method calls to a checkout
@@ -342,7 +342,7 @@ class _ConnectionFairy(object):
             return self._connection_record.info
         except AttributeError:
             if self.connection is None:
-                raise exceptions.InvalidRequestError("This connection is closed")
+                raise exc.InvalidRequestError("This connection is closed")
             try:
                 return self._detached_info
             except AttributeError:
@@ -359,7 +359,7 @@ class _ConnectionFairy(object):
         """
 
         if self.connection is None:
-            raise exceptions.InvalidRequestError("This connection is closed")
+            raise exc.InvalidRequestError("This connection is closed")
         if self._connection_record is not None:
             self._connection_record.invalidate(e=e)
         self.connection = None
@@ -378,8 +378,8 @@ class _ConnectionFairy(object):
 
     def checkout(self):
         if self.connection is None:
-            raise exceptions.InvalidRequestError("This connection is closed")
-        self.__counter +=1
+            raise exc.InvalidRequestError("This connection is closed")
+        self.__counter += 1
 
         if not self._pool._on_checkout or self.__counter != 1:
             return self
@@ -391,7 +391,7 @@ class _ConnectionFairy(object):
                 for l in self._pool._on_checkout:
                     l.checkout(self.connection, self._connection_record, self)
                 return self
-            except exceptions.DisconnectionError, e:
+            except exc.DisconnectionError, e:
                 if self._pool._should_log_info:
                     self._pool.log(
                     "Disconnection detected on checkout: %s" % e)
@@ -402,7 +402,7 @@ class _ConnectionFairy(object):
         if self._pool._should_log_info:
             self._pool.log("Reconnection attempts exhausted on checkout")
         self.invalidate()
-        raise exceptions.InvalidRequestError("This connection is closed")
+        raise exc.InvalidRequestError("This connection is closed")
 
     def detach(self):
         """Separate this connection from its Pool.
@@ -426,7 +426,7 @@ class _ConnectionFairy(object):
             self._connection_record = None
 
     def close(self):
-        self.__counter -=1
+        self.__counter -= 1
         if self.__counter == 0:
             self._close()
 
@@ -601,7 +601,7 @@ class QueuePool(Pool):
                 if not wait:
                     return self.do_get()
                 else:
-                    raise exceptions.TimeoutError("QueuePool limit of size %d overflow %d reached, connection timed out, timeout %d" % (self.size(), self.overflow(), self._timeout))
+                    raise exc.TimeoutError("QueuePool limit of size %d overflow %d reached, connection timed out, timeout %d" % (self.size(), self.overflow(), self._timeout))
 
             if self._overflow_lock is not None:
                 self._overflow_lock.acquire()
@@ -658,10 +658,10 @@ class NullPool(Pool):
         return "NullPool"
 
     def do_return_conn(self, conn):
-       conn.close()
+        conn.close()
 
     def do_return_invalid(self, conn):
-       pass
+        pass
 
     def do_get(self):
         return self.create_connection()
index 9a4bf4109ca5b940808817024403389b5f7a80a1..1f0b52ace645e49f43fb121982a27254912cb135 100644 (file)
@@ -27,7 +27,7 @@ are part of the SQL expression language, they are usable as components in SQL
 expressions.  """
 
 import re, inspect
-from sqlalchemy import types, exceptions, util, databases
+from sqlalchemy import types, exc, util, databases
 from sqlalchemy.sql import expression, visitors
 
 URL = None
@@ -42,10 +42,11 @@ class SchemaItem(object):
     """Base class for items that define a database schema."""
 
     __metaclass__ = expression._FigureVisitName
-
+    quote = None
+    
     def _init_items(self, *args):
         """Initialize the list of child items for this SchemaItem."""
-
+        
         for item in args:
             if item is not None:
                 item._set_parent(self)
@@ -95,7 +96,7 @@ class _TableSingleton(expression._FigureVisitName):
         try:
             table = metadata.tables[key]
             if not useexisting and table._cant_override(*args, **kwargs):
-                raise exceptions.InvalidRequestError(
+                raise exc.InvalidRequestError(
                     "Table '%s' is already defined for this MetaData instance.  "
                     "Specify 'useexisting=True' to redefine options and "
                     "columns on an existing Table object." % key)
@@ -104,7 +105,7 @@ class _TableSingleton(expression._FigureVisitName):
             return table
         except KeyError:
             if mustexist:
-                raise exceptions.InvalidRequestError(
+                raise exc.InvalidRequestError(
                     "Table '%s' not defined" % (key))
             try:
                 return type.__call__(self, name, metadata, *args, **kwargs)
@@ -182,17 +183,19 @@ class Table(SchemaItem, expression.TableClause):
             Deprecated; this is an oracle-only argument - "schema" should
             be used in its place.
 
-          quote
-            When True, indicates that the Table identifier must be quoted.
-            This flag does *not* disable quoting; for case-insensitive names,
-            use an all lower case identifier.
+        quote
+          Force quoting of the identifier on or off, based on `True` or
+          `False`.  Defaults to `None`.  This flag is rarely needed, 
+          as quoting is normally applied
+          automatically for known reserved words, as well as for
+          "case sensitive" identifiers.  An identifier is "case sensitive"
+          if it contains non-lowercase letters, otherwise it's 
+          considered to be "case insensitive".
 
           quote_schema
-            When True, indicates that the schema identifier must be quoted.
-            This flag does *not* disable quoting; for case-insensitive names,
-            use an all lower case identifier.
+            same as 'quote' but applies to the schema identifier.
+            
         """
-
         super(Table, self).__init__(name)
         self.metadata = metadata
         self.schema = kwargs.pop('schema', kwargs.pop('owner', None))
@@ -214,7 +217,7 @@ class Table(SchemaItem, expression.TableClause):
 
         self._set_parent(metadata)
 
-       self.__extra_kwargs(**kwargs)
+        self.__extra_kwargs(**kwargs)
 
         # load column definitions from the database if 'autoload' is defined
         # we do it after the table is in the singleton dictionary to support
@@ -234,7 +237,7 @@ class Table(SchemaItem, expression.TableClause):
         autoload_with = kwargs.pop('autoload_with', None)
         schema = kwargs.pop('schema', None)
         if schema and schema != self.schema:
-            raise exceptions.ArgumentError(
+            raise exc.ArgumentError(
                 "Can't change schema of existing table from '%s' to '%s'",
                 (self.schema, schema))
 
@@ -258,8 +261,8 @@ class Table(SchemaItem, expression.TableClause):
             ['autoload', 'autoload_with', 'schema', 'owner']))
 
     def __extra_kwargs(self, **kwargs):
-        self.quote = kwargs.pop('quote', False)
-        self.quote_schema = kwargs.pop('quote_schema', False)
+        self.quote = kwargs.pop('quote', None)
+        self.quote_schema = kwargs.pop('quote_schema', None)
         if kwargs.get('info'):
             self._info = kwargs.pop('info')
 
@@ -488,9 +491,13 @@ class Column(SchemaItem, expression._ColumnClause):
             or subtype of Integer.
 
           quote
-            When True, indicates that the Column identifier must be quoted.
-            This flag does *not* disable quoting; for case-insensitive names,
-            use an all lower case identifier.
+            Force quoting of the identifier on or off, based on `True` or
+            `False`.  Defaults to `None`.  This flag is rarely needed, 
+            as quoting is normally applied
+            automatically for known reserved words, as well as for
+            "case sensitive" identifiers.  An identifier is "case sensitive"
+            if it contains non-lowercase letters, otherwise it's 
+            considered to be "case insensitive".
         """
 
         name = kwargs.pop('name', None)
@@ -499,7 +506,7 @@ class Column(SchemaItem, expression._ColumnClause):
             args = list(args)
             if isinstance(args[0], basestring):
                 if name is not None:
-                    raise exceptions.ArgumentError(
+                    raise exc.ArgumentError(
                         "May not pass name positionally and as a keyword.")
                 name = args.pop(0)
         if args:
@@ -507,7 +514,7 @@ class Column(SchemaItem, expression._ColumnClause):
                 (isinstance(args[0], type) and
                  issubclass(args[0], types.AbstractType))):
                 if type_ is not None:
-                    raise exceptions.ArgumentError(
+                    raise exc.ArgumentError(
                         "May not pass type_ positionally and as a keyword.")
                 type_ = args.pop(0)
 
@@ -520,15 +527,17 @@ class Column(SchemaItem, expression._ColumnClause):
         self.default = kwargs.pop('default', None)
         self.index = kwargs.pop('index', None)
         self.unique = kwargs.pop('unique', None)
-        self.quote = kwargs.pop('quote', False)
+        self.quote = kwargs.pop('quote', None)
         self.onupdate = kwargs.pop('onupdate', None)
         self.autoincrement = kwargs.pop('autoincrement', True)
         self.constraints = util.Set()
         self.foreign_keys = util.OrderedSet()
+        util.set_creation_order(self)
+
         if kwargs.get('info'):
             self._info = kwargs.pop('info')
         if kwargs:
-            raise exceptions.ArgumentError(
+            raise exc.ArgumentError(
                 "Unknown arguments passed to Column: " + repr(kwargs.keys()))
 
     def __str__(self):
@@ -545,7 +554,7 @@ class Column(SchemaItem, expression._ColumnClause):
     bind = property(bind)
 
     def references(self, column):
-        """Return True if this references the given column via a foreign key."""
+        """Return True if this Column references the given column via foreign key."""
         for fk in self.foreign_keys:
             if fk.references(column.table):
                 return True
@@ -576,14 +585,14 @@ class Column(SchemaItem, expression._ColumnClause):
 
     def _set_parent(self, table):
         if self.name is None:
-            raise exceptions.ArgumentError(
+            raise exc.ArgumentError(
                 "Column must be constructed with a name or assign .name "
                 "before adding to a Table.")
         if self.key is None:
             self.key = self.name
         self.metadata = table.metadata
         if getattr(self, 'table', None) is not None:
-            raise exceptions.ArgumentError("this Column already has a table!")
+            raise exc.ArgumentError("this Column already has a table!")
         if not self._is_oid:
             self._pre_existing_column = table._columns.get(self.key)
 
@@ -594,7 +603,7 @@ class Column(SchemaItem, expression._ColumnClause):
         if self.primary_key:
             table.primary_key.replace(self)
         elif self.key in table.primary_key:
-            raise exceptions.ArgumentError(
+            raise exc.ArgumentError(
                 "Trying to redefine primary-key column '%s' as a "
                 "non-primary-key column on table '%s'" % (
                 self.key, table.fullname))
@@ -604,14 +613,14 @@ class Column(SchemaItem, expression._ColumnClause):
 
         if self.index:
             if isinstance(self.index, basestring):
-                raise exceptions.ArgumentError(
+                raise exc.ArgumentError(
                     "The 'index' keyword argument on Column is boolean only. "
                     "To create indexes with a specific name, create an "
                     "explicit Index object external to the Table.")
             Index('ix_%s' % self._label, self, unique=self.unique)
         elif self.unique:
             if isinstance(self.unique, basestring):
-                raise exceptions.ArgumentError(
+                raise exc.ArgumentError(
                     "The 'unique' keyword argument on Column is boolean only. "
                     "To create unique constraints or indexes with a specific "
                     "name, append an explicit UniqueConstraint to the Table's "
@@ -631,17 +640,17 @@ class Column(SchemaItem, expression._ColumnClause):
         """Create a copy of this ``Column``, unitialized.
 
         This is used in ``Table.tometadata``.
-        """
 
+        """
         return Column(self.name, self.type, self.default, key = self.key, primary_key = self.primary_key, nullable = self.nullable, _is_oid = self._is_oid, quote=self.quote, index=self.index, autoincrement=self.autoincrement, *[c.copy() for c in self.constraints])
-
-    def _make_proxy(self, selectable, name = None):
+    
+    def _make_proxy(self, selectable, name=None):
         """Create a *proxy* for this column.
 
         This is a copy of this ``Column`` referenced by a different parent
         (such as an alias or select statement).
-        """
 
+        """
         fk = [ForeignKey(f._colspec) for f in self.foreign_keys]
         c = Column(name or self.name, self.type, self.default, key = name or self.key, primary_key = self.primary_key, nullable = self.nullable, _is_oid = self._is_oid, quote=self.quote, *fk)
         c.table = selectable
@@ -654,7 +663,6 @@ class Column(SchemaItem, expression._ColumnClause):
         [c._init_items(f) for f in fk]
         return c
 
-
     def get_children(self, schema_visitor=False, **kwargs):
         if schema_visitor:
             return [x for x in (self.default, self.onupdate) if x is not None] + \
@@ -670,8 +678,8 @@ class ForeignKey(SchemaItem):
 
     For a composite (multiple column) FOREIGN KEY, use a ForeignKeyConstraint
     within the Table definition.
-    """
 
+    """
     def __init__(self, column, constraint=None, use_alter=False, name=None, onupdate=None, ondelete=None, deferrable=None, initially=None):
         """Construct a column-level FOREIGN KEY.
 
@@ -742,14 +750,15 @@ class ForeignKey(SchemaItem):
 
     def references(self, table):
         """Return True if the given table is referenced by this ForeignKey."""
-
         return table.corresponding_column(self.column) is not None
 
     def get_referent(self, table):
         """Return the column in the given table referenced by this ForeignKey.
 
         Returns None if this ``ForeignKey`` does not reference the given table.
+
         """
+
         return table.corresponding_column(self.column)
 
     def column(self):
@@ -766,22 +775,22 @@ class ForeignKey(SchemaItem):
                         parenttable = c.table
                         break
                 else:
-                    raise exceptions.ArgumentError(
+                    raise exc.ArgumentError(
                         "Parent column '%s' does not descend from a "
                         "table-attached Column" % str(self.parent))
                 m = re.match(r"^(.+?)(?:\.(.+?))?(?:\.(.+?))?$", self._colspec,
                              re.UNICODE)
                 if m is None:
-                    raise exceptions.ArgumentError(
+                    raise exc.ArgumentError(
                         "Invalid foreign key column specification: %s" %
                         self._colspec)
                 if m.group(3) is None:
                     (tname, colname) = m.group(1, 2)
                     schema = None
                 else:
-                    (schema,tname,colname) = m.group(1,2,3)
+                    (schema, tname, colname) = m.group(1, 2, 3)
                 if _get_table_key(tname, schema) not in parenttable.metadata:
-                    raise exceptions.NoReferencedTableError(
+                    raise exc.NoReferencedTableError(
                         "Could not find table '%s' with which to generate a "
                         "foreign key" % tname)
                 table = Table(tname, parenttable.metadata,
@@ -797,13 +806,13 @@ class ForeignKey(SchemaItem):
                     else:
                         self._column = table.c[colname]
                 except KeyError, e:
-                    raise exceptions.ArgumentError(
+                    raise exc.ArgumentError(
                         "Could not create ForeignKey '%s' on table '%s': "
                         "table '%s' has no column named '%s'" % (
                         self._colspec, parenttable.name, table.name, str(e)))
-            
-            elif isinstance(self._colspec, expression.Operators):
-                self._column = self._colspec.clause_element()
+
+            elif hasattr(self._colspec, '__clause_element__'):
+                self._column = self._colspec.__clause_element__()
             else:
                 self._column = self._colspec
 
@@ -906,12 +915,11 @@ class ColumnDefault(DefaultGenerator):
 
         defaulted = argspec[3] is not None and len(argspec[3]) or 0
         if positionals - defaulted > 1:
-            raise exceptions.ArgumentError(
+            raise exc.ArgumentError(
                 "ColumnDefault Python function takes zero or one "
                 "positional arguments")
         return fn
 
-
     def _visit_name(self):
         if self.for_update:
             return "column_onupdate"
@@ -926,12 +934,12 @@ class Sequence(DefaultGenerator):
     """Represents a named database sequence."""
 
     def __init__(self, name, start=None, increment=None, schema=None,
-                 optional=False, quote=False, **kwargs):
+                 optional=False, quote=None, **kwargs):
         super(Sequence, self).__init__(**kwargs)
         self.name = name
         self.start = start
         self.increment = increment
-        self.optional=optional
+        self.optional = optional
         self.quote = quote
         self.schema = schema
         self.kwargs = kwargs
@@ -960,7 +968,6 @@ class Sequence(DefaultGenerator):
             bind = _bind_or_error(self)
         bind.drop(self, checkfirst=checkfirst)
 
-
 class Constraint(SchemaItem):
     """A table-level SQL constraint, such as a KEY.
 
@@ -989,8 +996,11 @@ class Constraint(SchemaItem):
         self.initially = initially
 
     def __contains__(self, x):
-        return self.columns.contains_column(x)
-
+        return x in self.columns
+    
+    def contains_column(self, col):
+        return self.columns.contains_column(col)
+        
     def keys(self):
         return self.columns.keys()
 
@@ -1105,7 +1115,7 @@ class ForeignKeyConstraint(Constraint):
         self.onupdate = onupdate
         self.ondelete = ondelete
         if self.name is None and use_alter:
-            raise exceptions.ArgumentError("Alterable ForeignKey/ForeignKeyConstraint requires a name")
+            raise exc.ArgumentError("Alterable ForeignKey/ForeignKeyConstraint requires a name")
         self.use_alter = use_alter
 
     def _set_parent(self, table):
@@ -1113,7 +1123,7 @@ class ForeignKeyConstraint(Constraint):
         if self not in table.constraints:
             table.constraints.add(self)
             for (c, r) in zip(self.__colnames, self.__refcolnames):
-                self.append_element(c,r)
+                self.append_element(c, r)
 
     def append_element(self, col, refcol):
         fk = ForeignKey(refcol, constraint=self, name=self.name, onupdate=self.onupdate, ondelete=self.ondelete, use_alter=self.use_alter)
@@ -1159,7 +1169,7 @@ class PrimaryKeyConstraint(Constraint):
                                deferrable=kwargs.pop('deferrable', None),
                                initially=kwargs.pop('initially', None))
         if kwargs:
-            raise exceptions.ArgumentError(
+            raise exc.ArgumentError(
                 'Unknown PrimaryKeyConstraint argument(s): %s' %
                 ', '.join([repr(x) for x in kwargs.keys()]))
 
@@ -1174,14 +1184,14 @@ class PrimaryKeyConstraint(Constraint):
 
     def add(self, col):
         self.columns.add(col)
-        col.primary_key=True
+        col.primary_key = True
     append_column = add
 
     def replace(self, col):
         self.columns.replace(col)
 
     def remove(self, col):
-        col.primary_key=False
+        col.primary_key = False
         del self.columns[col.key]
 
     def copy(self):
@@ -1222,7 +1232,7 @@ class UniqueConstraint(Constraint):
                                deferrable=kwargs.pop('deferrable', None),
                                initially=kwargs.pop('initially', None))
         if kwargs:
-            raise exceptions.ArgumentError(
+            raise exc.ArgumentError(
                 'Unknown UniqueConstraint argument(s): %s' %
                 ', '.join([repr(x) for x in kwargs.keys()]))
 
@@ -1295,11 +1305,11 @@ class Index(SchemaItem):
             self._set_parent(column.table)
         elif column.table != self.table:
             # all columns muse be from same table
-            raise exceptions.ArgumentError(
+            raise exc.ArgumentError(
                 "All index columns must be from same table. "
                 "%s is from %s not %s" % (column, column.table, self.table))
         elif column.name in [ c.name for c in self.columns ]:
-            raise exceptions.ArgumentError(
+            raise exc.ArgumentError(
                 "A column may not appear twice in the "
                 "same index (%s already has column %s)" % (self.name, column))
         self.columns.append(column)
@@ -1370,7 +1380,7 @@ class MetaData(SchemaItem):
         self.ddl_listeners = util.defaultdict(list)
         if reflect:
             if not bind:
-                raise exceptions.ArgumentError(
+                raise exc.ArgumentError(
                     "A bind must be supplied in conjunction with reflect=True")
             self.reflect()
 
@@ -1508,7 +1518,7 @@ class MetaData(SchemaItem):
             missing = [name for name in only if name not in available]
             if missing:
                 s = schema and (" schema '%s'" % schema) or ''
-                raise exceptions.InvalidRequestError(
+                raise exc.InvalidRequestError(
                     'Could not reflect: requested table(s) not available '
                     'in %s%s: (%s)' % (bind.engine.url, s, ', '.join(missing)))
             load = [name for name in only if name not in current]
@@ -1777,12 +1787,12 @@ class DDL(object):
         """
 
         if not isinstance(statement, basestring):
-            raise exceptions.ArgumentError(
+            raise exc.ArgumentError(
                 "Expected a string or unicode SQL statement, got '%r'" %
                 statement)
         if (on is not None and
             (not isinstance(on, basestring) and not callable(on))):
-            raise exceptions.ArgumentError(
+            raise exc.ArgumentError(
                 "Expected the name of a database dialect or a callable for "
                 "'on' criteria, got type '%s'." % type(on).__name__)
 
@@ -1858,10 +1868,10 @@ class DDL(object):
         """
 
         if not hasattr(schema_item, 'ddl_listeners'):
-            raise exceptions.ArgumentError(
+            raise exc.ArgumentError(
                 "%s does not support DDL events" % type(schema_item).__name__)
         if event not in schema_item.ddl_events:
-            raise exceptions.ArgumentError(
+            raise exc.ArgumentError(
                 "Unknown event, expected one of (%s), got '%r'" %
                 (', '.join(schema_item.ddl_events), event))
         schema_item.ddl_listeners[event].append(self)
@@ -1955,5 +1965,5 @@ def _bind_or_error(schemaitem):
                'Execution can not proceed without a database to execute '
                'against.  Either execute with an explicit connection or '
                'assign %s to enable implicit execution.') % (item, bindable)
-        raise exceptions.UnboundExecutionError(msg)
+        raise exc.UnboundExecutionError(msg)
     return bind
index c966f396a29a44cdec5f4d7dda7b104bcab45ae8..5ea9eb1e66d5c8cd058a87240c335fb323f1b558 100644 (file)
@@ -1,2 +1,2 @@
 from sqlalchemy.sql.expression import *
-from sqlalchemy.sql.visitors import ClauseVisitor, NoColumnVisitor
+from sqlalchemy.sql.visitors import ClauseVisitor
index 1fe9ef0622b2c40e5fe2d9441dfc8b504b5264d1..78bb4e31ca56856d23237aaaaf47217c16bfebe4 100644 (file)
@@ -19,7 +19,7 @@ is otherwise internal to SQLAlchemy.
 """
 
 import string, re, itertools
-from sqlalchemy import schema, engine, util, exceptions
+from sqlalchemy import schema, engine, util, exc
 from sqlalchemy.sql import operators, functions
 from sqlalchemy.sql import expression as sql
 
@@ -115,8 +115,6 @@ class DefaultCompiler(engine.Compiled):
     paradigm as visitors.ClauseVisitor but implements its own traversal.
     """
 
-    __traverse_options__ = {'column_collections':False, 'entry':True}
-
     operators = OPERATORS
     functions = FUNCTIONS
 
@@ -162,17 +160,12 @@ class DefaultCompiler(engine.Compiled):
         # for aliases
         self.generated_ids = {}
 
-        # paramstyle from the dialect (comes from DB-API)
-        self.paramstyle = self.dialect.paramstyle
-
         # true if the paramstyle is positional
         self.positional = self.dialect.positional
+        if self.positional:
+            self.positiontup = []
 
-        self.bindtemplate = BIND_TEMPLATES[self.paramstyle]
-
-        # a list of the compiled's bind parameter names, used to help
-        # formulate a positional argument list
-        self.positiontup = []
+        self.bindtemplate = BIND_TEMPLATES[self.dialect.paramstyle]
 
         # an IdentifierPreparer that formats the quoting of identifiers
         self.preparer = self.dialect.identifier_preparer
@@ -230,15 +223,18 @@ class DefaultCompiler(engine.Compiled):
         return ""
 
     def visit_grouping(self, grouping, **kwargs):
-        return "(" + self.process(grouping.elem) + ")"
+        return "(" + self.process(grouping.element) + ")"
 
-    def visit_label(self, label, result_map=None):
+    def visit_label(self, label, result_map=None, render_labels=False):
+        if not render_labels:
+            return self.process(label.element)
+            
         labelname = self._truncated_identifier("colident", label.name)
 
         if result_map is not None:
-            result_map[labelname.lower()] = (label.name, (label, label.obj, labelname), label.obj.type)
+            result_map[labelname.lower()] = (label.name, (label, label.element, labelname), label.element.type)
 
-        return " ".join([self.process(label.obj), self.operator_string(operators.as_), self.preparer.format_label(label, labelname)])
+        return " ".join([self.process(label.element), self.operator_string(operators.as_), self.preparer.format_label(label, labelname)])
 
     def visit_column(self, column, result_map=None, **kwargs):
 
@@ -261,16 +257,16 @@ class DefaultCompiler(engine.Compiled):
         if getattr(column, "is_literal", False):
             name = self.escape_literal_column(name)
         else:
-            name = self.preparer.quote(column, name)
+            name = self.preparer.quote(name, column.quote)
 
         if column.table is None or not column.table.named_with_column:
             return name
         else:
             if getattr(column.table, 'schema', None):
-                schema_prefix = self.preparer.quote(column.table, column.table.schema) + '.'
+                schema_prefix = self.preparer.quote(column.table.schema, column.table.quote_schema) + '.'
             else:
                 schema_prefix = ''
-            return schema_prefix + self.preparer.quote(column.table, ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) + "." + name
+            return schema_prefix + self.preparer.quote(ANONYMOUS_LABEL.sub(self._process_anon, column.table.name), column.table.quote) + "." + name
 
     def escape_literal_column(self, text):
         """provide escaping for the literal_column() construct."""
@@ -387,7 +383,7 @@ class DefaultCompiler(engine.Compiled):
         if name in self.binds:
             existing = self.binds[name]
             if existing is not bindparam and (existing.unique or bindparam.unique):
-                raise exceptions.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key)
+                raise exc.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key)
         self.binds[bindparam.key] = self.binds[name] = bindparam
         return self.bindparam_string(name)
 
@@ -418,7 +414,7 @@ class DefaultCompiler(engine.Compiled):
         return truncname
     
     def _process_anon(self, match):
-        (ident, derived) = match.group(1,2)
+        (ident, derived) = match.group(1, 2)
 
         key = ('anonymous', ident)
         if key in self.generated_ids:
@@ -436,8 +432,9 @@ class DefaultCompiler(engine.Compiled):
     def bindparam_string(self, name):
         if self.positional:
             self.positiontup.append(name)
-
-        return self.bindtemplate % {'name':name, 'position':len(self.positiontup)}
+            return self.bindtemplate % {'name':name, 'position':len(self.positiontup)}
+        else:
+            return self.bindtemplate % {'name':name}
 
     def visit_alias(self, alias, asfrom=False, **kwargs):
         if asfrom:
@@ -490,7 +487,7 @@ class DefaultCompiler(engine.Compiled):
 
         froms = select._get_display_froms(existingfroms)
 
-        correlate_froms = util.Set(itertools.chain(*([froms] + [f._get_from_objects() for f in froms])))
+        correlate_froms = util.Set(sql._from_objects(*froms))
 
         # TODO: might want to propigate existing froms for select(select(select))
         # where innermost select should correlate to outermost
@@ -504,6 +501,7 @@ class DefaultCompiler(engine.Compiled):
             [c for c in [
                 self.process(
                     self.label_select_column(select, co, asfrom=asfrom), 
+                    render_labels=True,
                     **column_clause_args) 
                 for co in select.inner_columns
             ]
@@ -580,9 +578,9 @@ class DefaultCompiler(engine.Compiled):
     def visit_table(self, table, asfrom=False, **kwargs):
         if asfrom:
             if getattr(table, "schema", None):
-                return self.preparer.quote(table, table.schema) + "." + self.preparer.quote(table, table.name)
+                return self.preparer.quote(table.schema, table.quote_schema) + "." + self.preparer.quote(table.name, table.quote)
             else:
-                return self.preparer.quote(table, table.name)
+                return self.preparer.quote(table.name, table.quote)
         else:
             return ""
 
@@ -603,7 +601,7 @@ class DefaultCompiler(engine.Compiled):
 
         return (insert + " INTO %s (%s) VALUES (%s)" %
                 (preparer.format_table(insert_stmt.table),
-                 ', '.join([preparer.quote(c[0], c[0].name)
+                 ', '.join([preparer.quote(c[0].name, c[0].quote)
                             for c in colparams]),
                  ', '.join([c[1] for c in colparams])))
 
@@ -613,7 +611,7 @@ class DefaultCompiler(engine.Compiled):
         self.isupdate = True
         colparams = self._get_colparams(update_stmt)
 
-        text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.quote(c[0], c[0].name), c[1]) for c in colparams], ', ')
+        text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.quote(c[0].name, c[0].quote), c[1]) for c in colparams], ', ')
 
         if update_stmt._whereclause:
             text += " WHERE " + self.process(update_stmt._whereclause)
@@ -837,7 +835,7 @@ class SchemaGenerator(DDLBase):
         if constraint.name is not None:
             self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint))
         self.append("PRIMARY KEY ")
-        self.append("(%s)" % ', '.join([self.preparer.quote(c, c.name) for c in constraint]))
+        self.append("(%s)" % ', '.join([self.preparer.quote(c.name, c.quote) for c in constraint]))
         self.define_constraint_deferrability(constraint)
 
     def visit_foreign_key_constraint(self, constraint):
@@ -858,9 +856,9 @@ class SchemaGenerator(DDLBase):
                         preparer.format_constraint(constraint))
         table = list(constraint.elements)[0].column.table
         self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % (
-            ', '.join([preparer.quote(f.parent, f.parent.name) for f in constraint.elements]),
+            ', '.join([preparer.quote(f.parent.name, f.parent.quote) for f in constraint.elements]),
             preparer.format_table(table),
-            ', '.join([preparer.quote(f.column, f.column.name) for f in constraint.elements])
+            ', '.join([preparer.quote(f.column.name, f.column.quote) for f in constraint.elements])
         ))
         if constraint.ondelete is not None:
             self.append(" ON DELETE %s" % constraint.ondelete)
@@ -873,7 +871,7 @@ class SchemaGenerator(DDLBase):
         if constraint.name is not None:
             self.append("CONSTRAINT %s " %
                         self.preparer.format_constraint(constraint))
-        self.append(" UNIQUE (%s)" % (', '.join([self.preparer.quote(c, c.name) for c in constraint])))
+        self.append(" UNIQUE (%s)" % (', '.join([self.preparer.quote(c.name, c.quote) for c in constraint])))
         self.define_constraint_deferrability(constraint)
 
     def define_constraint_deferrability(self, constraint):
@@ -896,7 +894,7 @@ class SchemaGenerator(DDLBase):
         self.append("INDEX %s ON %s (%s)" \
                     % (preparer.format_index(index),
                        preparer.format_table(index.table),
-                       string.join([preparer.quote(c, c.name) for c in index.columns], ', ')))
+                       string.join([preparer.quote(c.name, c.quote) for c in index.columns], ', ')))
         self.execute()
 
 
@@ -1005,9 +1003,12 @@ class IdentifierPreparer(object):
                 or not self.legal_characters.match(unicode(value))
                 or (lc_value != value))
 
-    def quote(self, obj, ident):
-        if getattr(obj, 'quote', False):
+    def quote(self, ident, force):
+        if force:
             return self.quote_identifier(ident)
+        elif force is False:
+            return ident
+            
         if ident in self.__strings:
             return self.__strings[ident]
         else:
@@ -1017,53 +1018,47 @@ class IdentifierPreparer(object):
                 self.__strings[ident] = ident
             return self.__strings[ident]
 
-    def should_quote(self, object):
-        return object.quote or self._requires_quotes(object.name)
-
     def format_sequence(self, sequence, use_schema=True):
-        name = self.quote(sequence, sequence.name)
+        name = self.quote(sequence.name, sequence.quote)
         if not self.omit_schema and use_schema and sequence.schema is not None:
-            name = self.quote(sequence, sequence.schema) + "." + name
+            name = self.quote(sequence.schema, sequence.quote) + "." + name
         return name
 
     def format_label(self, label, name=None):
-        return self.quote(label, name or label.name)
+        return self.quote(name or label.name, label.quote)
 
     def format_alias(self, alias, name=None):
-        return self.quote(alias, name or alias.name)
+        return self.quote(name or alias.name, alias.quote)
 
     def format_savepoint(self, savepoint, name=None):
-        return self.quote(savepoint, name or savepoint.ident)
+        return self.quote(name or savepoint.ident, savepoint.quote)
 
     def format_constraint(self, constraint):
-        return self.quote(constraint, constraint.name)
+        return self.quote(constraint.name, constraint.quote)
 
     def format_index(self, index):
-        return self.quote(index, index.name)
+        return self.quote(index.name, index.quote)
 
     def format_table(self, table, use_schema=True, name=None):
         """Prepare a quoted table and schema name."""
 
         if name is None:
             name = table.name
-        result = self.quote(table, name)
+        result = self.quote(name, table.quote)
         if not self.omit_schema and use_schema and getattr(table, "schema", None):
-            result = self.quote(table, table.schema) + "." + result
+            result = self.quote(table.schema, table.quote_schema) + "." + result
         return result
 
     def format_column(self, column, use_table=False, name=None, table_name=None):
-        """Prepare a quoted column name.
-
-        deprecated.  use preparer.quote(col, column.name) or combine with format_table()
-        """
+        """Prepare a quoted column name."""
 
         if name is None:
             name = column.name
         if not getattr(column, 'is_literal', False):
             if use_table:
-                return self.format_table(column.table, use_schema=False, name=table_name) + "." + self.quote(column, name)
+                return self.format_table(column.table, use_schema=False, name=table_name) + "." + self.quote(name, column.quote)
             else:
-                return self.quote(column, name)
+                return self.quote(name, column.quote)
         else:
             # literal textual elements get stuck into ColumnClause alot, which shouldnt get quoted
             if use_table:
@@ -1079,7 +1074,7 @@ class IdentifierPreparer(object):
         # a longer sequence.
 
         if not self.omit_schema and use_schema and getattr(table, 'schema', None):
-            return (self.quote_identifier(table.schema),
+            return (self.quote(table.schema, table.quote_schema),
                     self.format_table(table, use_schema=False))
         else:
             return (self.format_table(table, use_schema=False), )
index 867fdd69c3ce5e6be7d2a14957c31b5a456edf19..7ce6377011289920bc164e6d4495b23b65c29bc5 100644 (file)
@@ -26,12 +26,12 @@ to stay the same in future releases.
 """
 
 import itertools, re
-from sqlalchemy import util, exceptions
+from sqlalchemy import util, exc
 from sqlalchemy.sql import operators, visitors
 from sqlalchemy import types as sqltypes
 
 functions, schema, sql_util = None, None, None
-DefaultDialect, ClauseAdapter = None, None
+DefaultDialect, ClauseAdapter, Annotated = None, None, None
 
 __all__ = [
     'Alias', 'ClauseElement',
@@ -503,15 +503,21 @@ def collate(expression, collation):
 
 def exists(*args, **kwargs):
     """Return an ``EXISTS`` clause as applied to a [sqlalchemy.sql.expression#Select] object.
+    
+    Calling styles are of the following forms::
+    
+        # use on an existing select()
+        s = select([<columns>]).where(<criterion>)
+        s = exists(s)
+        
+        # construct a select() at once
+        exists(['*'], **select_arguments).where(<criterion>)
+        
+        # columns argument is optional, generates "EXISTS (SELECT *)"
+        # by default.
+        exists().where(<criterion>) 
 
-    The resulting [sqlalchemy.sql.expression#_Exists] object can be executed by
-    itself or used as a subquery within an enclosing select.
-
-    \*args, \**kwargs
-      all arguments are sent directly to the [sqlalchemy.sql.expression#select()]
-      function to produce a ``SELECT`` statement.
     """
-
     return _Exists(*args, **kwargs)
 
 def union(*selects, **kwargs):
@@ -872,27 +878,36 @@ def _compound_select(keyword, *selects, **kwargs):
     return CompoundSelect(keyword, *selects, **kwargs)
 
 def _is_literal(element):
-    return not isinstance(element, ClauseElement)
+    return not isinstance(element, (ClauseElement, Operators))
+
+def _from_objects(*elements, **kwargs):
+    return itertools.chain(*[element._get_from_objects(**kwargs) for element in elements])
 
+def _labeled(element):
+    if not hasattr(element, 'name'):
+        return element.label(None)
+    else:
+        return element
+        
 def _literal_as_text(element):
-    if isinstance(element, Operators):
-        return element.expression_element()
+    if hasattr(element, '__clause_element__'):
+        return element.__clause_element__()
     elif _is_literal(element):
         return _TextClause(unicode(element))
     else:
         return element
 
 def _literal_as_column(element):
-    if isinstance(element, Operators):
-        return element.clause_element()
+    if hasattr(element, '__clause_element__'):
+        return element.__clause_element__()
     elif _is_literal(element):
         return literal_column(str(element))
     else:
         return element
 
 def _literal_as_binds(element, name=None, type_=None):
-    if isinstance(element, Operators):
-        return element.expression_element()
+    if hasattr(element, '__clause_element__'):
+        return element.__clause_element__()
     elif _is_literal(element):
         if element is None:
             return null()
@@ -902,17 +917,17 @@ def _literal_as_binds(element, name=None, type_=None):
         return element
 
 def _no_literals(element):
-    if isinstance(element, Operators):
-        return element.expression_element()
+    if hasattr(element, '__clause_element__'):
+        return element.__clause_element__()
     elif _is_literal(element):
-        raise exceptions.ArgumentError("Ambiguous literal: %r.  Use the 'text()' function to indicate a SQL expression literal, or 'literal()' to indicate a bound value." % element)
+        raise exc.ArgumentError("Ambiguous literal: %r.  Use the 'text()' function to indicate a SQL expression literal, or 'literal()' to indicate a bound value." % element)
     else:
         return element
     
 def _corresponding_column_or_error(fromclause, column, require_embedded=False):
     c = fromclause.corresponding_column(column, require_embedded=require_embedded)
     if not c:
-        raise exceptions.InvalidRequestError("Given column '%s', attached to table '%s', failed to locate a corresponding column from table '%s'" % (str(column), str(getattr(column, 'table', None)), fromclause.description))
+        raise exc.InvalidRequestError("Given column '%s', attached to table '%s', failed to locate a corresponding column from table '%s'" % (str(column), str(getattr(column, 'table', None)), fromclause.description))
     return c
 
 def _selectable(element):
@@ -921,9 +936,8 @@ def _selectable(element):
     elif isinstance(element, Selectable):
         return element
     else:
-        raise exceptions.ArgumentError("Object '%s' is not a Selectable and does not implement `__selectable__()`" % repr(element))
+        raise exc.ArgumentError("Object '%s' is not a Selectable and does not implement `__selectable__()`" % repr(element))
 
-    
 def is_column(col):
     """True if ``col`` is an instance of ``ColumnElement``."""
     return isinstance(col, ColumnElement)
@@ -941,7 +955,9 @@ class _FigureVisitName(type):
 class ClauseElement(object):
     """Base class for elements of a programmatically constructed SQL expression."""
     __metaclass__ = _FigureVisitName
-
+    _annotations = {}
+    supports_execution = False
+    
     def _clone(self):
         """Create a shallow copy of this ClauseElement.
 
@@ -976,6 +992,14 @@ class ClauseElement(object):
         """
 
         raise NotImplementedError(repr(self))
+    
+    def _annotate(self, values):
+        """return a copy of this ClauseElement with the given annotations dictionary."""
+
+        global Annotated
+        if Annotated is None:
+            from sqlalchemy.sql.util import Annotated
+        return Annotated(self, values)
 
     def unique_params(self, *optionaldict, **kwargs):
         """Return a copy with ``bindparam()`` elments replaced.
@@ -1006,14 +1030,14 @@ class ClauseElement(object):
         if len(optionaldict) == 1:
             kwargs.update(optionaldict[0])
         elif len(optionaldict) > 1:
-            raise exceptions.ArgumentError("params() takes zero or one positional dictionary argument")
+            raise exc.ArgumentError("params() takes zero or one positional dictionary argument")
 
         def visit_bindparam(bind):
             if bind.key in kwargs:
                 bind.value = kwargs[bind.key]
             if unique:
                 bind._convert_to_unique()
-        return visitors.traverse(self, visit_bindparam=visit_bindparam, clone=True)
+        return visitors.cloned_traverse(self, {}, {'bindparam':visit_bindparam})
 
     def compare(self, other):
         """Compare this ClauseElement to the given ClauseElement.
@@ -1049,11 +1073,6 @@ class ClauseElement(object):
     def self_group(self, against=None):
         return self
 
-    def supports_execution(self):
-        """Return True if this clause element represents a complete executable statement."""
-
-        return False
-
     def bind(self):
         """Returns the Engine or Connection to which this ClauseElement is bound, or None if none found."""
 
@@ -1062,7 +1081,7 @@ class ClauseElement(object):
                 return self._bind
         except AttributeError:
             pass
-        for f in self._get_from_objects():
+        for f in _from_objects(self):
             if f is self:
                 continue
             engine = f.bind
@@ -1083,7 +1102,7 @@ class ClauseElement(object):
                    'Engine for execution. Or, assign a bind to the statement '
                    'or the Metadata of its underlying tables to enable '
                    'implicit execution via this method.' % label)
-            raise exceptions.UnboundExecutionError(msg)
+            raise exc.UnboundExecutionError(msg)
         return e.execute_clauseelement(self, multiparams, params)
 
     def scalar(self, *multiparams, **params):
@@ -1159,6 +1178,12 @@ class ClauseElement(object):
                 self.__module__, self.__class__.__name__, id(self), friendly)
 
 
+class _Immutable(object):
+    """mark a ClauseElement as 'immutable' when expressions are cloned."""
+    
+    def _clone(self):
+        return self
+        
 class Operators(object):
     def __and__(self, other):
         return self.operate(operators.and_, other)
@@ -1174,9 +1199,6 @@ class Operators(object):
             return self.operate(operators.op, opstring, b)
         return op
 
-    def clause_element(self):
-        raise NotImplementedError()
-
     def operate(self, op, *other, **kwargs):
         raise NotImplementedError()
 
@@ -1216,7 +1238,7 @@ class ColumnOperators(Operators):
     def ilike(self, other, escape=None):
         return self.operate(operators.ilike_op, other, escape=escape)
 
-    def in_(self, *other):
+    def in_(self, other):
         return self.operate(operators.in_op, other)
 
     def startswith(self, other, **kwargs):
@@ -1279,18 +1301,18 @@ class _CompareMixin(ColumnOperators):
     def __compare(self, op, obj, negate=None, reverse=False, **kwargs):
         if obj is None or isinstance(obj, _Null):
             if op == operators.eq:
-                return _BinaryExpression(self.expression_element(), null(), operators.is_, negate=operators.isnot)
+                return _BinaryExpression(self, null(), operators.is_, negate=operators.isnot)
             elif op == operators.ne:
-                return _BinaryExpression(self.expression_element(), null(), operators.isnot, negate=operators.is_)
+                return _BinaryExpression(self, null(), operators.isnot, negate=operators.is_)
             else:
-                raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL")
+                raise exc.ArgumentError("Only '='/'!=' operators can be used with NULL")
         else:
             obj = self._check_literal(obj)
 
         if reverse:
-            return _BinaryExpression(obj, self.expression_element(), op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs)
+            return _BinaryExpression(obj, self, op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs)
         else:
-            return _BinaryExpression(self.expression_element(), obj, op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs)
+            return _BinaryExpression(self, obj, op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs)
 
     def __operate(self, op, obj, reverse=False):
         obj = self._check_literal(obj)
@@ -1298,9 +1320,9 @@ class _CompareMixin(ColumnOperators):
         type_ = self._compare_type(obj)
 
         if reverse:
-            return _BinaryExpression(obj, self.expression_element(), type_.adapt_operator(op), type_=type_)
+            return _BinaryExpression(obj, self, type_.adapt_operator(op), type_=type_)
         else:
-            return _BinaryExpression(self.expression_element(), obj, type_.adapt_operator(op), type_=type_)
+            return _BinaryExpression(self, obj, type_.adapt_operator(op), type_=type_)
 
     # a mapping of operators with the method they use, along with their negated
     # operator for comparison operators
@@ -1329,17 +1351,10 @@ class _CompareMixin(ColumnOperators):
         o = _CompareMixin.operators[op]
         return o[0](self, op, other, reverse=True, *o[1:], **kwargs)
 
-    def in_(self, *other):
-        return self._in_impl(operators.in_op, operators.notin_op, *other)
-
-    def _in_impl(self, op, negate_op, *other):
-        # Handle old style *args argument passing
-        if len(other) != 1 or not isinstance(other[0], Selectable) and (not hasattr(other[0], '__iter__') or isinstance(other[0], basestring)):
-            util.warn_deprecated('passing in_ arguments as varargs is deprecated, in_ takes a single argument that is a sequence or a selectable')
-            seq_or_selectable = other
-        else:
-            seq_or_selectable = other[0]
+    def in_(self, other):
+        return self._in_impl(operators.in_op, operators.notin_op, other)
 
+    def _in_impl(self, op, negate_op, seq_or_selectable):
         if isinstance(seq_or_selectable, Selectable):
             return self.__compare( op, seq_or_selectable, negate=negate_op)
 
@@ -1348,7 +1363,7 @@ class _CompareMixin(ColumnOperators):
         for o in seq_or_selectable:
             if not _is_literal(o):
                 if not isinstance( o, _CompareMixin):
-                    raise exceptions.InvalidRequestError( "in() function accepts either a list of non-selectable values, or a selectable: "+repr(o) )
+                    raise exc.InvalidRequestError( "in() function accepts either a list of non-selectable values, or a selectable: "+repr(o) )
             else:
                 o = self._bind_param(o)
             args.append(o)
@@ -1433,22 +1448,13 @@ class _CompareMixin(ColumnOperators):
         if isinstance(other, _BindParamClause) and isinstance(other.type, sqltypes.NullType):
             other.type = self.type
             return other
-        elif isinstance(other, Operators):
-            return other.expression_element()
+        elif hasattr(other, '__clause_element__'):
+            return other.__clause_element__()
         elif _is_literal(other):
             return self._bind_param(other)
         else:
             return other
 
-    def clause_element(self):
-        """Allow ``_CompareMixins`` to return the underlying ``ClauseElement``, for non-``ClauseElement`` ``_CompareMixins``."""
-        return self
-
-    def expression_element(self):
-        """Allow ``_CompareMixins`` to return the appropriate object to be used in expressions."""
-
-        return self
-
     def _compare_type(self, obj):
         """Allow subclasses to override the type used in constructing
         ``_BinaryExpression`` objects.
@@ -1480,23 +1486,22 @@ class ColumnElement(ClauseElement, _CompareMixin):
 
     primary_key = False
     foreign_keys = []
-
+    quote = None
+    
     def base_columns(self):
-        if hasattr(self, '_base_columns'):
-            return self._base_columns
-        self._base_columns = util.Set([c for c in self.proxy_set if not hasattr(c, 'proxies')])
+        if not hasattr(self, '_base_columns'):
+            self._base_columns = util.Set([c for c in self.proxy_set if not hasattr(c, 'proxies')])
         return self._base_columns
     base_columns = property(base_columns)
 
     def proxy_set(self):
-        if hasattr(self, '_proxy_set'):
-            return self._proxy_set
-        s = util.Set([self])
-        if hasattr(self, 'proxies'):
-            for c in self.proxies:
-                s = s.union(c.proxy_set)
-        self._proxy_set = s
-        return s
+        if not hasattr(self, '_proxy_set'):
+            s = util.Set([self])
+            if hasattr(self, 'proxies'):
+                for c in self.proxies:
+                    s.update(c.proxy_set)
+            self._proxy_set = s
+        return self._proxy_set
     proxy_set = property(proxy_set)
 
     def shares_lineage(self, othercolumn):
@@ -1518,7 +1523,7 @@ class ColumnElement(ClauseElement, _CompareMixin):
             co = _ColumnClause(self.anon_label, selectable, type_=getattr(self, 'type', None))
 
         co.proxies = [self]
-        selectable.columns[name]= co
+        selectable.columns[name] = co
         return co
 
     def anon_label(self):
@@ -1613,7 +1618,7 @@ class ColumnCollection(util.OrderedProperties):
 
     def __contains__(self, other):
         if not isinstance(other, basestring):
-            raise exceptions.ArgumentError("__contains__ requires a string argument")
+            raise exc.ArgumentError("__contains__ requires a string argument")
         return util.OrderedProperties.__contains__(self, other)
 
     def contains_column(self, col):
@@ -1641,6 +1646,9 @@ class ColumnSet(util.OrderedSet):
                     l.append(c==local)
         return and_(*l)
 
+    def __hash__(self):
+        return hash(tuple(self._list))
+
 class Selectable(ClauseElement):
     """mark a class as being selectable"""
 
@@ -1648,8 +1656,9 @@ class FromClause(Selectable):
     """Represent an element that can be used within the ``FROM`` clause of a ``SELECT`` statement."""
 
     __visit_name__ = 'fromclause'
-    named_with_column=False
+    named_with_column = False
     _hide_froms = []
+    quote = None
 
     def _get_from_objects(self, **modifiers):
         return []
@@ -1694,12 +1703,12 @@ class FromClause(Selectable):
         return fromclause in util.Set(self._cloned_set)
 
     def replace_selectable(self, old, alias):
-      """replace all occurences of FromClause 'old' with the given Alias object, returning a copy of this ``FromClause``."""
+        """replace all occurences of FromClause 'old' with the given Alias object, returning a copy of this ``FromClause``."""
 
-      global ClauseAdapter
-      if ClauseAdapter is None:
-          from sqlalchemy.sql.util import ClauseAdapter
-      return ClauseAdapter(alias).traverse(self, clone=True)
+        global ClauseAdapter
+        if ClauseAdapter is None:
+            from sqlalchemy.sql.util import ClauseAdapter
+        return ClauseAdapter(alias).traverse(self)
 
     def correspond_on_equivalents(self, column, equivalents):
         col = self.corresponding_column(column, require_embedded=True)
@@ -1859,7 +1868,7 @@ class _BindParamClause(ClauseElement, _CompareMixin):
 
     def _convert_to_unique(self):
         if not self.unique:
-            self.unique=True
+            self.unique = True
             self.key = "{ANON %d %s}" % (id(self), self._orig_key or 'param')
 
     def _get_from_objects(self, **modifiers):
@@ -1910,6 +1919,7 @@ class _TextClause(ClauseElement):
     __visit_name__ = 'textclause'
 
     _bind_params_regex = re.compile(r'(?<![:\w\x5c]):(\w+)(?!:)', re.UNICODE)
+    supports_execution = True
 
     _hide_froms = []
     oid_column = None
@@ -1950,12 +1960,6 @@ class _TextClause(ClauseElement):
     def _get_from_objects(self, **modifiers):
         return []
     
-    def supports_execution(self):
-        return True
-
-    def _table_iterator(self):
-        return iter([])
-
 class _Null(ColumnElement):
     """Represent the NULL keyword in a SQL statement.
 
@@ -2042,6 +2046,7 @@ class _CalculatedClause(ColumnElement):
     __visit_name__ = 'calculatedclause'
 
     def __init__(self, name, *clauses, **kwargs):
+        ColumnElement.__init__(self)
         self.name = name
         self.type = sqltypes.to_instance(kwargs.get('type_', None))
         self._bind = kwargs.get('bind', None)
@@ -2061,7 +2066,7 @@ class _CalculatedClause(ColumnElement):
 
     def clauses(self):
         if isinstance(self.clause_expr, _Grouping):
-            return self.clause_expr.elem
+            return self.clause_expr.element
         else:
             return self.clause_expr
     clauses = property(clauses)
@@ -2239,8 +2244,13 @@ class _Exists(_UnaryExpression):
     __visit_name__ = _UnaryExpression.__visit_name__
 
     def __init__(self, *args, **kwargs):
-        kwargs['correlate'] = True
-        s = select(*args, **kwargs).as_scalar().self_group()
+        if args and isinstance(args[0], _SelectBaseMixin):
+            s = args[0]
+        else:
+            if not args:
+                args = ([literal_column('*')],)
+            s = select(*args, **kwargs).as_scalar().self_group()
+            
         _UnaryExpression.__init__(self, s, operator=operators.exists)
 
     def select(self, whereclause=None, **params):
@@ -2272,7 +2282,7 @@ class Join(FromClause):
         self.right = _selectable(right).self_group()
 
         if onclause is None:
-            self.onclause = self.__match_primaries(self.left, self.right)
+            self.onclause = self._match_primaries(self.left, self.right)
         else:
             self.onclause = onclause
         
@@ -2310,7 +2320,7 @@ class Join(FromClause):
     def get_children(self, **kwargs):
         return self.left, self.right, self.onclause
 
-    def __match_primaries(self, primary, secondary):
+    def _match_primaries(self, primary, secondary):
         global sql_util
         if not sql_util:
             from sqlalchemy.sql import util as sql_util
@@ -2359,7 +2369,7 @@ class Join(FromClause):
         return self.select(use_labels=True, correlate=False).alias(name)
 
     def _hide_froms(self):
-        return itertools.chain(*[x.left._get_from_objects() + x.right._get_from_objects() for x in self._cloned_set])
+        return itertools.chain(*[_from_objects(x.left, x.right) for x in self._cloned_set])
     _hide_froms = property(_hide_froms)
 
     def _get_from_objects(self, **modifiers):
@@ -2382,9 +2392,10 @@ class Alias(FromClause):
     def __init__(self, selectable, alias=None):
         baseselectable = selectable
         while isinstance(baseselectable, Alias):
-            baseselectable = baseselectable.selectable
+            baseselectable = baseselectable.element
         self.original = baseselectable
-        self.selectable = selectable
+        self.supports_execution = baseselectable.supports_execution
+        self.element = selectable
         if alias is None:
             if self.original.named_with_column:
                 alias = getattr(self.original, 'name', None)
@@ -2398,112 +2409,100 @@ class Alias(FromClause):
     def is_derived_from(self, fromclause):
         if fromclause in util.Set(self._cloned_set):
             return True
-        return self.selectable.is_derived_from(fromclause)
-
-    def supports_execution(self):
-        return self.original.supports_execution()
-
-    def _table_iterator(self):
-        return self.original._table_iterator()
+        return self.element.is_derived_from(fromclause)
 
     def _populate_column_collection(self):
-        for col in self.selectable.columns:
+        for col in self.element.columns:
             col._make_proxy(self)
-        if self.selectable.oid_column is not None:
-            self._oid_column = self.selectable.oid_column._make_proxy(self)
+        if self.element.oid_column is not None:
+            self._oid_column = self.element.oid_column._make_proxy(self)
 
     def _copy_internals(self, clone=_clone):
-       self._reset_exported()
-       self.selectable = _clone(self.selectable)
-       baseselectable = self.selectable
-       while isinstance(baseselectable, Alias):
-           baseselectable = baseselectable.selectable
-       self.original = baseselectable
+        self._reset_exported()
+        self.element = _clone(self.element)
+        baseselectable = self.element
+        while isinstance(baseselectable, Alias):
+            baseselectable = baseselectable.selectable
+        self.original = baseselectable
 
     def get_children(self, column_collections=True, aliased_selectables=True, **kwargs):
         if column_collections:
             for c in self.c:
                 yield c
         if aliased_selectables:
-            yield self.selectable
+            yield self.element
 
     def _get_from_objects(self, **modifiers):
         return [self]
 
     def bind(self):
-        return self.selectable.bind
+        return self.element.bind
     bind = property(bind)
 
-class _ColumnElementAdapter(ColumnElement):
-    """Adapts a ClauseElement which may or may not be a
-    ColumnElement subclass itself into an object which
-    acts like a ColumnElement.
-    """
+class _Grouping(ColumnElement):
+    """Represent a grouping within a column expression"""
 
-    def __init__(self, elem):
-        self.elem = elem
-        self.type = getattr(elem, 'type', None)
+    def __init__(self, element):
+        ColumnElement.__init__(self)
+        self.element = element
+        self.type = getattr(element, 'type', None)
 
     def key(self):
-        return self.elem.key
+        return self.element.key
     key = property(key)
 
     def _label(self):
         try:
-            return self.elem._label
+            return self.element._label
         except AttributeError:
             return self.anon_label
     _label = property(_label)
 
     def _copy_internals(self, clone=_clone):
-        self.elem = clone(self.elem)
+        self.element = clone(self.element)
 
     def get_children(self, **kwargs):
-        return self.elem,
+        return self.element,
 
     def _get_from_objects(self, **modifiers):
-        return self.elem._get_from_objects(**modifiers)
+        return self.element._get_from_objects(**modifiers)
 
     def __getattr__(self, attr):
-        return getattr(self.elem, attr)
+        return getattr(self.element, attr)
 
     def __getstate__(self):
-        return {'elem':self.elem, 'type':self.type}
+        return {'element':self.element, 'type':self.type}
 
     def __setstate__(self, state):
-        self.elem = state['elem']
+        self.element = state['element']
         self.type = state['type']
 
-class _Grouping(_ColumnElementAdapter):
-    """Represent a grouping within a column expression"""
-    pass
-
 class _FromGrouping(FromClause):
     """Represent a grouping of a FROM clause"""
     __visit_name__ = 'grouping'
 
-    def __init__(self, elem):
-        self.elem = elem
+    def __init__(self, element):
+        self.element = element
 
     def columns(self):
-        return self.elem.columns
+        return self.element.columns
     columns = c = property(columns)
 
     def _hide_froms(self):
-        return self.elem._hide_froms
+        return self.element._hide_froms
     _hide_froms = property(_hide_froms)
 
     def get_children(self, **kwargs):
-        return self.elem,
+        return self.element,
 
     def _copy_internals(self, clone=_clone):
-        self.elem = clone(self.elem)
+        self.element = clone(self.element)
 
     def _get_from_objects(self, **modifiers):
-        return self.elem._get_from_objects(**modifiers)
+        return self.element._get_from_objects(**modifiers)
 
     def __getattr__(self, attr):
-        return getattr(self.elem, attr)
+        return getattr(self.element, attr)
 
 class _Label(ColumnElement):
     """Represents a column label (AS).
@@ -2516,12 +2515,12 @@ class _Label(ColumnElement):
     ``ColumnElement`` subclasses.
     """
 
-    def __init__(self, name, obj, type_=None):
-        while isinstance(obj, _Label):
-            obj = obj.obj
-        self.name = name or "{ANON %d %s}" % (id(self), getattr(obj, 'name', 'anon'))
-        self.obj = obj.self_group(against=operators.as_)
-        self.type = sqltypes.to_instance(type_ or getattr(obj, 'type', None))
+    def __init__(self, name, element, type_=None):
+        while isinstance(element, _Label):
+            element = element.element
+        self.name = name or "{ANON %d %s}" % (id(self), getattr(element, 'name', 'anon'))
+        self.element = element.self_group(against=operators.as_)
+        self.type = sqltypes.to_instance(type_ or getattr(element, 'type', None))
 
     def key(self):
         return self.name
@@ -2532,8 +2531,9 @@ class _Label(ColumnElement):
     _label = property(_label)
 
     def _proxy_attr(name):
+        get = util.attrgetter(name)
         def attr(self):
-            return getattr(self.obj, name)
+            return get(self.element)
         return property(attr)
 
     proxies = _proxy_attr('proxies')
@@ -2542,27 +2542,24 @@ class _Label(ColumnElement):
     primary_key = _proxy_attr('primary_key')
     foreign_keys = _proxy_attr('foreign_keys')
 
-    def expression_element(self):
-        return self.obj
-
     def get_children(self, **kwargs):
-        return self.obj,
+        return self.element,
 
     def _copy_internals(self, clone=_clone):
-        self.obj = clone(self.obj)
+        self.element = clone(self.element)
 
     def _get_from_objects(self, **modifiers):
-        return self.obj._get_from_objects(**modifiers)
+        return self.element._get_from_objects(**modifiers)
 
     def _make_proxy(self, selectable, name = None):
-        if isinstance(self.obj, (Selectable, ColumnElement)):
-            e = self.obj._make_proxy(selectable, name=self.name)
+        if isinstance(self.element, (Selectable, ColumnElement)):
+            e = self.element._make_proxy(selectable, name=self.name)
         else:
             e = column(self.name)._make_proxy(selectable=selectable)
         e.proxies.append(self)
         return e
 
-class _ColumnClause(ColumnElement):
+class _ColumnClause(_Immutable, ColumnElement):
     """Represents a generic column expression from any textual string.
 
     This includes columns associated with tables, aliases and select
@@ -2602,16 +2599,7 @@ class _ColumnClause(ColumnElement):
         return self.name.encode('ascii', 'backslashreplace')
     description = property(description)
 
-    def _clone(self):
-        # ColumnClause is immutable
-        return self
-
     def _label(self):
-        """Generate a 'label' string for this column.
-        """
-
-        # for a "literal" column, we've no idea what the text is
-        # therefore no 'label' can be automatically generated
         if self.is_literal:
             return None
         if not self.__label:
@@ -2626,24 +2614,21 @@ class _ColumnClause(ColumnElement):
                     counter = 1
                     while label in self.table.c:
                         label = self.__label + "_" + str(counter)
-                        counter +=1
+                        counter += 1
                     self.__label = label
             else:
                 self.__label = self.name
         return self.__label
-
     _label = property(_label)
 
     def label(self, name):
-        # if going off the "__label" property and its None, we have
-        # no label; return self
         if name is None:
             return self
         else:
             return super(_ColumnClause, self).label(name)
 
     def _get_from_objects(self, **modifiers):
-        if self.table is not None:
+        if self.table:
             return [self.table]
         else:
             return []
@@ -2651,20 +2636,20 @@ class _ColumnClause(ColumnElement):
     def _bind_param(self, obj):
         return _BindParamClause(self.name, obj, type_=self.type, unique=True)
 
-    def _make_proxy(self, selectable, name = None):
+    def _make_proxy(self, selectable, name=None, attach=True):
         # propigate the "is_literal" flag only if we are keeping our name,
         # otherwise its considered to be a label
         is_literal = self.is_literal and (name is None or name == self.name)
         c = _ColumnClause(name or self.name, selectable=selectable, _is_oid=self._is_oid, type_=self.type, is_literal=is_literal)
         c.proxies = [self]
-        if not self._is_oid:
+        if attach and not self._is_oid:
             selectable.columns[c.name] = c
         return c
 
     def _compare_type(self, obj):
         return self.type
 
-class TableClause(FromClause):
+class TableClause(_Immutable, FromClause):
     """Represents a "table" construct.
 
     Note that this represents tables only as another syntactical
@@ -2691,10 +2676,6 @@ class TableClause(FromClause):
         return self.name.encode('ascii', 'backslashreplace')
     description = property(description)
 
-    def _clone(self):
-        # TableClause is immutable
-        return self
-
     def append_column(self, c):
         self._columns[c.name] = c
         c.table = self
@@ -2724,10 +2705,11 @@ class TableClause(FromClause):
     def _get_from_objects(self, **modifiers):
         return [self]
 
-
 class _SelectBaseMixin(object):
     """Base class for ``Select`` and ``CompoundSelects``."""
 
+    supports_execution = True
+    
     def __init__(self, use_labels=False, for_update=False, limit=None, offset=None, order_by=None, group_by=None, bind=None, autocommit=False):
         self.use_labels = use_labels
         self.for_update = for_update
@@ -2773,11 +2755,6 @@ class _SelectBaseMixin(object):
         """
         return self.as_scalar().label(name)
 
-    def supports_execution(self):
-        """part of the ClauseElement contract; returns ``True`` in all cases for this class."""
-
-        return True
-
     def autocommit(self):
         """return a new selectable with the 'autocommit' flag set to True."""
 
@@ -2860,15 +2837,15 @@ class _SelectBaseMixin(object):
 class _ScalarSelect(_Grouping):
     __visit_name__ = 'grouping'
 
-    def __init__(self, elem):
-        self.elem = elem
-        cols = list(elem.inner_columns)
+    def __init__(self, element):
+        self.element = element
+        cols = list(element.inner_columns)
         if len(cols) != 1:
-            raise exceptions.InvalidRequestError("Scalar select can only be created from a Select object that has exactly one column expression.")
+            raise exc.InvalidRequestError("Scalar select can only be created from a Select object that has exactly one column expression.")
         self.type = cols[0].type
 
     def columns(self):
-        raise exceptions.InvalidRequestError("Scalar Select expression has no columns; use this object directly within a column-level expression.")
+        raise exc.InvalidRequestError("Scalar Select expression has no columns; use this object directly within a column-level expression.")
     columns = c = property(columns)
 
     def self_group(self, **kwargs):
@@ -2893,7 +2870,7 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
             if not numcols:
                 numcols = len(s.c)
             elif len(s.c) != numcols:
-                raise exceptions.ArgumentError("All selectables passed to CompoundSelect must have identical numbers of columns; select #%d has %d columns, select #%d has %d" %
+                raise exc.ArgumentError("All selectables passed to CompoundSelect must have identical numbers of columns; select #%d has %d columns, select #%d has %d" %
                     (1, len(self.selects[0].c), n+1, len(s.c))
                 )
             if s._order_by_clause:
@@ -2936,11 +2913,6 @@ class CompoundSelect(_SelectBaseMixin, FromClause):
         return (column_collections and list(self.c) or []) + \
             [self._order_by_clause, self._group_by_clause] + list(self.selects)
 
-    def _table_iterator(self):
-        for s in self.selects:
-            for t in s._table_iterator():
-                yield t
-
     def bind(self):
         if self._bind:
             return self._bind
@@ -2976,6 +2948,7 @@ class Select(_SelectBaseMixin, FromClause):
         self._distinct = distinct
 
         self._correlate = util.Set()
+        self._froms = util.OrderedSet()
 
         if columns:
             self._raw_columns = [
@@ -2983,22 +2956,23 @@ class Select(_SelectBaseMixin, FromClause):
                 for c in
                 [_literal_as_column(c) for c in columns]
             ]
+
+            self._froms.update(_from_objects(*self._raw_columns))
         else:
             self._raw_columns = []
-        
-        if from_obj:
-            self._froms = util.Set([
-                _is_literal(f) and _TextClause(f) or f
-                for f in util.to_list(from_obj)
-            ])
-        else:
-            self._froms = util.Set()
-            
+
         if whereclause:
             self._whereclause = _literal_as_text(whereclause)
+            self._froms.update(_from_objects(self._whereclause, is_where=True))
         else:
             self._whereclause = None
 
+        if from_obj:
+            self._froms.update([
+                _is_literal(f) and _TextClause(f) or f
+                for f in util.to_list(from_obj)
+            ])
+
         if having:
             self._having = _literal_as_text(having)
         else:
@@ -3020,36 +2994,28 @@ class Select(_SelectBaseMixin, FromClause):
         correlating.
         
         """
-        froms = util.OrderedSet()
-
-        for col in self._raw_columns:
-            froms.update(col._get_from_objects())
-
-        if self._whereclause is not None:
-            froms.update(self._whereclause._get_from_objects(is_where=True))
-
-        if self._froms:
-            froms.update(self._froms)
+        froms = self._froms
         
         toremove = itertools.chain(*[f._hide_froms for f in froms])
-        froms.difference_update(toremove)
+        if toremove:
+            froms = froms.difference(toremove)
 
         if len(froms) > 1 or self._correlate:
             if self._correlate:
-                froms.difference_update(_cloned_intersection(froms, self._correlate))
+                froms = froms.difference(_cloned_intersection(froms, self._correlate))
                 
             if self._should_correlate and existing_froms:
-                froms.difference_update(_cloned_intersection(froms, existing_froms))
+                froms = froms.difference(_cloned_intersection(froms, existing_froms))
                 
                 if not len(froms):
-                    raise exceptions.InvalidRequestError("Select statement '%s' returned no FROM clauses due to auto-correlation; specify correlate(<tables>) to control correlation manually." % self)
+                    raise exc.InvalidRequestError("Select statement '%s' returned no FROM clauses due to auto-correlation; specify correlate(<tables>) to control correlation manually." % self)
                     
         return froms
 
     froms = property(_get_display_froms, doc="""Return a list of all FromClause elements which will be applied to the FROM clause of the resulting statement.""")
 
     def type(self):
-        raise exceptions.InvalidRequestError("Select objects don't have a type.  Call as_scalar() on this Select object to return a 'scalar' version of this Select.")
+        raise exc.InvalidRequestError("Select objects don't have a type.  Call as_scalar() on this Select object to return a 'scalar' version of this Select.")
     type = property(type)
 
     def locate_all_froms(self):
@@ -3059,22 +3025,10 @@ class Select(_SelectBaseMixin, FromClause):
         is specifically for those FromClause elements that would actually be rendered.
         
         """
-        if hasattr(self, '_all_froms'):
-            return self._all_froms
-
-        froms = util.Set(
-            itertools.chain(*
-                [self._froms] +
-                [f._get_from_objects() for f in self._froms] +
-                [col._get_from_objects() for col in self._raw_columns]
-            )
-        )
+        if not hasattr(self, '_all_froms'):
+            self._all_froms = self._froms.union(_from_objects(*list(self._froms)))
 
-        if self._whereclause:
-            froms.update(self._whereclause._get_from_objects(is_where=True))
-
-        self._all_froms = froms
-        return froms
+        return self._all_froms
 
     def inner_columns(self):
         """an iteratorof all ColumnElement expressions which would
@@ -3092,7 +3046,7 @@ class Select(_SelectBaseMixin, FromClause):
     def is_derived_from(self, fromclause):
         if self in util.Set(fromclause._cloned_set):
             return True
-
+        
         for f in self.locate_all_froms():
             if f.is_derived_from(fromclause):
                 return True
@@ -3112,7 +3066,7 @@ class Select(_SelectBaseMixin, FromClause):
         """return child elements as per the ClauseElement specification."""
 
         return (column_collections and list(self.columns) or []) + \
-            list(self.locate_all_froms()) + \
+            list(self._froms) + \
             [x for x in (self._whereclause, self._having, self._order_by_clause, self._group_by_clause) if x is not None]
 
     def column(self, column):
@@ -3125,6 +3079,7 @@ class Select(_SelectBaseMixin, FromClause):
             column = column.self_group(against=operators.comma_op)
 
         s._raw_columns = s._raw_columns + [column]
+        s._froms = s._froms.union(_from_objects(column))
         return s
 
     def where(self, whereclause):
@@ -3185,7 +3140,7 @@ class Select(_SelectBaseMixin, FromClause):
         
         """
         s = self._generate()
-        s._should_correlate=False
+        s._should_correlate = False
         if fromclauses == (None,):
             s._correlate = util.Set()
         else:
@@ -3195,7 +3150,7 @@ class Select(_SelectBaseMixin, FromClause):
     def append_correlation(self, fromclause):
         """append the given correlation expression to this select() construct."""
         
-        self._should_correlate=False
+        self._should_correlate = False
         self._correlate = self._correlate.union([fromclause])
 
     def append_column(self, column):
@@ -3207,6 +3162,7 @@ class Select(_SelectBaseMixin, FromClause):
             column = column.self_group(against=operators.comma_op)
 
         self._raw_columns = self._raw_columns + [column]
+        self._froms = self._froms.union(_from_objects(column))
         self._reset_exported()
 
     def append_prefix(self, clause):
@@ -3221,10 +3177,13 @@ class Select(_SelectBaseMixin, FromClause):
         The expression will be joined to existing WHERE criterion via AND.
 
         """
+        whereclause = _literal_as_text(whereclause)
+        self._froms = self._froms.union(_from_objects(whereclause, is_where=True))
+        
         if self._whereclause is not None:
-            self._whereclause = and_(self._whereclause, _literal_as_text(whereclause))
+            self._whereclause = and_(self._whereclause, whereclause)
         else:
-            self._whereclause = _literal_as_text(whereclause)
+            self._whereclause = whereclause
 
     def append_having(self, having):
         """append the given expression to this select() construct's HAVING criterion.
@@ -3311,31 +3270,23 @@ class Select(_SelectBaseMixin, FromClause):
 
         return intersect_all(self, other, **kwargs)
 
-    def _table_iterator(self):
-        for t in visitors.NoColumnVisitor().iterate(self):
-            if isinstance(t, TableClause):
-                yield t
-
     def bind(self):
         if self._bind:
             return self._bind
-        for f in self._froms:
-            if f is self:
-                continue
-            e = f.bind
-            if e:
-                self._bind = e
-                return e
-        # look through the columns (largely synomous with looking
-        # through the FROMs except in the case of _CalculatedClause/_Function)
-        for c in self._raw_columns:
-            if getattr(c, 'table', None) is self:
-                continue
-            e = c.bind
+        if not self._froms:
+            for c in self._raw_columns:
+                e = c.bind
+                if e:
+                    self._bind = e
+                    return e
+        else:
+            e = list(self._froms)[0].bind
             if e:
                 self._bind = e
                 return e
+
         return None
+        
     def _set_bind(self, bind):
         self._bind = bind
     bind = property(bind, _set_bind)
@@ -3343,11 +3294,7 @@ class Select(_SelectBaseMixin, FromClause):
 class _UpdateBase(ClauseElement):
     """Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements."""
 
-    def supports_execution(self):
-        return True
-
-    def _table_iterator(self):
-        return iter([self.table])
+    supports_execution = True
 
     def _generate(self):
         s = self.__class__.__new__(self.__class__)
@@ -3407,7 +3354,7 @@ class Insert(_ValuesBase):
         self._bind = bind
         self.table = table
         self.select = None
-        self.inline=inline
+        self.inline = inline
         if prefixes:
             self._prefixes = [_literal_as_text(p) for p in prefixes]
         else:
@@ -3502,10 +3449,11 @@ class Delete(_UpdateBase):
         self._whereclause = clone(self._whereclause)
 
 class _IdentifiedClause(ClauseElement):
+    supports_execution = True
+    quote = None
+    
     def __init__(self, ident):
         self.ident = ident
-    def supports_execution(self):
-        return True
 
 class SavepointClause(_IdentifiedClause):
     pass
index dfd638ecb1b2f5acef8faa8488a456e6daeac5ea..46dcaba66b99d6bca491bb1eab23257870ff31d4 100644 (file)
@@ -44,7 +44,7 @@ def between_op(a, b, c):
     return a.between(b, c)
 
 def in_op(a, b):
-    return a.in_(*b)
+    return a.in_(b)
 
 def notin_op(a, b):
     raise NotImplementedError()
index d299982cfa0e24fbc60cf9bc16c4e8b7c2d78596..944a68def9fafe72ddcc1f563dfe64aa7150c2b3 100644 (file)
@@ -1,4 +1,4 @@
-from sqlalchemy import exceptions, schema, topological, util, sql
+from sqlalchemy import exc, schema, topological, util, sql
 from sqlalchemy.sql import expression, operators, visitors
 from itertools import chain
 
@@ -8,43 +8,57 @@ def sort_tables(tables, reverse=False):
     """sort a collection of Table objects in order of their foreign-key dependency."""
     
     tuples = []
-    class TVisitor(schema.SchemaVisitor):
-        def visit_foreign_key(_self, fkey):
-            if fkey.use_alter:
-                return
-            parent_table = fkey.column.table
-            if parent_table in tables:
-                child_table = fkey.parent.table
-                tuples.append( ( parent_table, child_table ) )
-    vis = TVisitor()
+    def visit_foreign_key(fkey):
+        if fkey.use_alter:
+            return
+        parent_table = fkey.column.table
+        if parent_table in tables:
+            child_table = fkey.parent.table
+            tuples.append( ( parent_table, child_table ) )
+
     for table in tables:
-        vis.traverse(table)
+        visitors.traverse(table, {'schema_visitor':True}, {'foreign_key':visit_foreign_key})    
     sequence = topological.sort(tuples, tables)
     if reverse:
         return util.reversed(sequence)
     else:
         return sequence
 
-def find_tables(clause, check_columns=False, include_aliases=False):
+def search(clause, target):
+    if not clause:
+        return False
+    for elem in visitors.iterate(clause, {'column_collections':False}):
+        if elem is target:
+            return True
+    else:
+        return False
+
+def find_tables(clause, check_columns=False, include_aliases=False, include_joins=False, include_selects=False):
     """locate Table objects within the given expression."""
     
     tables = []
-    kwargs = {}
+    _visitors = {}
+    
+    def visit_something(elem):
+        tables.append(elem)
+        
+    if include_selects:
+        _visitors['select'] = _visitors['compound_select'] = visit_something
+    
+    if include_joins:
+        _visitors['join'] = visit_something
+        
     if include_aliases:
-        def visit_alias(alias):
-            tables.append(alias)
-        kwargs['visit_alias']  = visit_alias
+        _visitors['alias']  = visit_something
 
     if check_columns:
         def visit_column(column):
             tables.append(column.table)
-        kwargs['visit_column'] = visit_column
+        _visitors['column'] = visit_column
 
-    def visit_table(table):
-        tables.append(table)
-    kwargs['visit_table'] = visit_table
+    _visitors['table'] = visit_something
 
-    visitors.traverse(clause, traverse_options= {'column_collections':False}, **kwargs)
+    visitors.traverse(clause, {'column_collections':False}, _visitors)
     return tables
 
 def find_columns(clause):
@@ -53,7 +67,7 @@ def find_columns(clause):
     cols = util.Set()
     def visit_column(col):
         cols.add(col)
-    visitors.traverse(clause, visit_column=visit_column)
+    visitors.traverse(clause, {}, {'column':visit_column})
     return cols
 
 def join_condition(a, b, ignore_nonexistent_tables=False):
@@ -72,7 +86,7 @@ def join_condition(a, b, ignore_nonexistent_tables=False):
     for fk in b.foreign_keys:
         try:
             col = fk.get_referent(a)
-        except exceptions.NoReferencedTableError:
+        except exc.NoReferencedTableError:
             if ignore_nonexistent_tables:
                 continue
             else:
@@ -81,27 +95,26 @@ def join_condition(a, b, ignore_nonexistent_tables=False):
         if col:
             crit.append(col == fk.parent)
             constraints.add(fk.constraint)
-
     if a is not b:
         for fk in a.foreign_keys:
             try:
                 col = fk.get_referent(b)
-            except exceptions.NoReferencedTableError:
+            except exc.NoReferencedTableError:
                 if ignore_nonexistent_tables:
                     continue
                 else:
                     raise
-            
+
             if col:
                 crit.append(col == fk.parent)
                 constraints.add(fk.constraint)
 
     if len(crit) == 0:
-        raise exceptions.ArgumentError(
+        raise exc.ArgumentError(
             "Can't find any foreign key relationships "
             "between '%s' and '%s'" % (a.description, b.description))
     elif len(constraints) > 1:
-        raise exceptions.ArgumentError(
+        raise exc.ArgumentError(
             "Can't determine join between '%s' and '%s'; "
             "tables have more than one foreign key "
             "constraint relationship between them. "
@@ -111,7 +124,70 @@ def join_condition(a, b, ignore_nonexistent_tables=False):
         return (crit[0])
     else:
         return sql.and_(*crit)
+
+class Annotated(object):
+    """clones a ClauseElement and applies an 'annotations' dictionary.
+    
+    Unlike regular clones, this clone also mimics __hash__() and 
+    __cmp__() of the original element so that it takes its place
+    in hashed collections.
     
+    A reference to the original element is maintained, for the important
+    reason of keeping its hash value current.  When GC'ed, the 
+    hash value may be reused, causing conflicts.
+
+    """
+    def __new__(cls, *args):
+        if not args:
+            return object.__new__(cls)
+        else:
+            element, values = args
+            return object.__new__(
+                type.__new__(type, "Annotated%s" % element.__class__.__name__, (Annotated, element.__class__), {}) 
+            )
+
+    def __init__(self, element, values):
+        self.__dict__ = element.__dict__.copy()
+        self.__element = element
+        self._annotations = values
+
+    def _annotate(self, values):
+        _values = self._annotations.copy()
+        _values.update(values)
+        clone = self.__class__.__new__(self.__class__)
+        clone.__dict__ = self.__dict__.copy()
+        clone._annotations = _values
+        return clone
+        
+    def __hash__(self):
+        return hash(self.__element)
+
+    def __cmp__(self, other):
+        return cmp(hash(self.__element), hash(other))
+
+def splice_joins(left, right, stop_on=None):
+    if left is None:
+        return right
+        
+    stack = [(right, None)]
+
+    adapter = ClauseAdapter(left)
+    ret = None
+    while stack:
+        (right, prevright) = stack.pop()
+        if isinstance(right, expression.Join) and right is not stop_on:
+            right = right._clone()
+            right._reset_exported()
+            right.onclause = adapter.traverse(right.onclause)
+            stack.append((right.left, right))
+        else:
+            right = adapter.traverse(right)
+        if prevright:
+            prevright.left = right
+        if not ret:
+            ret = right
+
+    return ret
     
 def reduce_columns(columns, *clauses):
     """given a list of columns, return a 'reduced' set based on natural equivalents.
@@ -151,7 +227,7 @@ def reduce_columns(columns, *clauses):
                             omit.add(c)
                             break
         for clause in clauses:
-            visitors.traverse(clause, visit_binary=visit_binary)
+            visitors.traverse(clause, {}, {'binary':visit_binary})
 
     return expression.ColumnSet(columns.difference(omit))
 
@@ -159,7 +235,7 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None, consider_as_re
     """traverse an expression and locate binary criterion pairs."""
     
     if consider_as_foreign_keys and consider_as_referenced_keys:
-        raise exceptions.ArgumentError("Can only specify one of 'consider_as_foreign_keys' or 'consider_as_referenced_keys'")
+        raise exc.ArgumentError("Can only specify one of 'consider_as_foreign_keys' or 'consider_as_referenced_keys'")
         
     def visit_binary(binary):
         if not any_operator and binary.operator != operators.eq:
@@ -184,7 +260,7 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None, consider_as_re
                 elif binary.right.references(binary.left):
                     pairs.append((binary.left, binary.right))
     pairs = []
-    visitors.traverse(expression, visit_binary=visit_binary)
+    visitors.traverse(expression, {}, {'binary':visit_binary})
     return pairs
 
 def folded_equivalents(join, equivs=None):
@@ -195,15 +271,15 @@ def folded_equivalents(join, equivs=None):
     This function is used by Join.select(fold_equivalents=True).
     
     TODO: deprecate ?
-    """
 
+    """
     if equivs is None:
         equivs = util.Set()
     def visit_binary(binary):
         if binary.operator == operators.eq and binary.left.name == binary.right.name:
             equivs.add(binary.right)
             equivs.add(binary.left)
-    visitors.traverse(join.onclause, visit_binary=visit_binary)
+    visitors.traverse(join.onclause, {}, {'binary':visit_binary})
     collist = []
     if isinstance(join.left, expression.Join):
         left = folded_equivalents(join.left, equivs)
@@ -246,43 +322,8 @@ class AliasedRow(object):
     def keys(self):
         return self.row.keys()
 
-def row_adapter(from_, equivalent_columns=None):
-    """create a row adapter callable against a selectable."""
-    
-    if equivalent_columns is None:
-        equivalent_columns = {}
-
-    def locate_col(col):
-        c = from_.corresponding_column(col)
-        if c:
-            return c
-        elif col in equivalent_columns:
-            for c2 in equivalent_columns[col]:
-                corr = from_.corresponding_column(c2)
-                if corr:
-                    return corr
-        return col
-        
-    map = util.PopulateDict(locate_col)
-    
-    def adapt(row):
-        return AliasedRow(row, map)
-    return adapt
-
-class ColumnsInClause(visitors.ClauseVisitor):
-    """Given a selectable, visit clauses and determine if any columns
-    from the clause are in the selectable.
-    """
-
-    def __init__(self, selectable):
-        self.selectable = selectable
-        self.result = False
-
-    def visit_column(self, column):
-        if self.selectable.c.get(column.key) is column:
-            self.result = True
 
-class ClauseAdapter(visitors.ClauseVisitor):
+class ClauseAdapter(visitors.ReplacingCloningVisitor):
     """Given a clause (like as in a WHERE criterion), locate columns
     which are embedded within a given selectable, and changes those
     columns to be that of the selectable.
@@ -308,58 +349,76 @@ class ClauseAdapter(visitors.ClauseVisitor):
     condition to read::
 
       s.c.col1 == table2.c.col1
-    """
-
-    __traverse_options__ = {'column_collections':False}
 
-    def __init__(self, selectable, include=None, exclude=None, equivalents=None):
-        self.__traverse_options__ = self.__traverse_options__.copy()
-        self.__traverse_options__['stop_on'] = [selectable]
+    """
+    def __init__(self, selectable, equivalents=None, include=None, exclude=None):
+        self.__traverse_options__ = {'column_collections':False, 'stop_on':[selectable]}
         self.selectable = selectable
         self.include = include
         self.exclude = exclude
-        self.equivalents = equivalents
-    
-    def traverse(self, obj, clone=True):
-        if not clone:
-            raise exceptions.ArgumentError("ClauseAdapter 'clone' argument must be True")
-        return visitors.ClauseVisitor.traverse(self, obj, clone=True)
-    
-    def copy_and_chain(self, adapter):
-        """create a copy of this adapter and chain to the given adapter.
-
-        currently this adapter must be unchained to start, raises
-        an exception if it's already chained.
-
-        Does not modify the given adapter.
-        """
+        self.equivalents = equivalents or {}
 
-        if adapter is None:
-            return self
+    def _corresponding_column(self, col, require_embedded):
+        newcol = self.selectable.corresponding_column(col, require_embedded=require_embedded)
 
-        if hasattr(self, '_next'):
-            raise NotImplementedError("Can't chain_to on an already chained ClauseAdapter (yet)")
-
-        ca = ClauseAdapter(self.selectable, self.include, self.exclude, self.equivalents)
-        ca._next = adapter
-        return ca
+        if not newcol and col in self.equivalents:
+            for equiv in self.equivalents[col]:
+                newcol = self.selectable.corresponding_column(equiv, require_embedded=require_embedded)
+                if newcol:
+                    return newcol
+        return newcol
 
-    def before_clone(self, col):
+    def replace(self, col):
         if isinstance(col, expression.FromClause):
             if self.selectable.is_derived_from(col):
                 return self.selectable
+
         if not isinstance(col, expression.ColumnElement):
             return None
-        if self.include is not None:
-            if col not in self.include:
-                return None
-        if self.exclude is not None:
-            if col in self.exclude:
-                return None
-        newcol = self.selectable.corresponding_column(col, require_embedded=True)
-        if newcol is None and self.equivalents is not None and col in self.equivalents:
-            for equiv in self.equivalents[col]:
-                newcol = self.selectable.corresponding_column(equiv, require_embedded=True)
-                if newcol:
-                    return newcol
-        return newcol
+
+        if self.include and col not in self.include:
+            return None
+        elif self.exclude and col in self.exclude:
+            return None
+
+        return self._corresponding_column(col, True)
+
+class ColumnAdapter(ClauseAdapter):
+
+    def __init__(self, selectable, equivalents=None, chain_to=None, include=None, exclude=None):
+        ClauseAdapter.__init__(self, selectable, equivalents, include, exclude)
+        if chain_to:
+            self.chain(chain_to)
+        self.columns = util.PopulateDict(self._locate_col)
+
+    def wrap(self, adapter):
+        ac = self.__class__.__new__(self.__class__)
+        ac.__dict__ = self.__dict__.copy()
+        ac._locate_col = ac._wrap(ac._locate_col, adapter._locate_col)
+        ac.adapt_clause = ac._wrap(ac.adapt_clause, adapter.adapt_clause)
+        ac.adapt_list = ac._wrap(ac.adapt_list, adapter.adapt_list)
+        ac.columns = util.PopulateDict(ac._locate_col)
+        return ac
+
+    adapt_clause = ClauseAdapter.traverse
+    adapt_list = ClauseAdapter.copy_and_process
+
+    def _wrap(self, local, wrapped):
+        def locate(col):
+            col = local(col)
+            return wrapped(col)
+        return locate
+
+    def _locate_col(self, col):
+        c = self._corresponding_column(col, False)
+        if not c:
+            c = self.adapt_clause(col)
+            
+            # anonymize labels in case they have a hardcoded name
+            if isinstance(c, expression._Label):
+                c = c.label(None)
+        return c    
+
+    def adapted_row(self, row):
+        return AliasedRow(row, self.columns)
+    
index 9888a228a39a55a50729f12e0708d1cfc4041ef5..738dae9c7e1ee7f3b1b1204da598671467a401cd 100644 (file)
 from sqlalchemy import util
 
 class ClauseVisitor(object):
-    """Traverses and visits ``ClauseElement`` structures.
-    
-    Calls visit_XXX() methods for each particular
-    ``ClauseElement`` subclass encountered.  Traversal of a
-    hierarchy of ``ClauseElements`` is achieved via the
-    ``traverse()`` method, which is passed the lead
-    ``ClauseElement``.
-    
-    By default, ``ClauseVisitor`` traverses all elements
-    fully.  Options can be specified at the class level via the 
-    ``__traverse_options__`` dictionary which will be passed
-    to the ``get_children()`` method of each ``ClauseElement``;
-    these options can indicate modifications to the set of 
-    elements returned, such as to not return column collections
-    (column_collections=False) or to return Schema-level items
-    (schema_visitor=True).
-    
-    ``ClauseVisitor`` also supports a simultaneous copy-and-traverse
-    operation, which will produce a copy of a given ``ClauseElement``
-    structure while at the same time allowing ``ClauseVisitor`` subclasses
-    to modify the new structure in-place.
-    
-    """
     __traverse_options__ = {}
     
-    def traverse_single(self, obj, **kwargs):
-        """visit a single element, without traversing its child elements."""
-        
+    def traverse_single(self, obj):
         for v in self._iterate_visitors:
             meth = getattr(v, "visit_%s" % obj.__visit_name__, None)
             if meth:
-                return meth(obj, **kwargs)
+                return meth(obj)
     
-    traverse_chained = traverse_single
-        
     def iterate(self, obj):
         """traverse the given expression structure, returning an iterator of all elements."""
-        
-        stack = [obj]
-        traversal = util.deque()
-        while stack:
-            t = stack.pop()
-            traversal.appendleft(t)
-            for c in t.get_children(**self.__traverse_options__):
-                stack.append(c)
-        return iter(traversal)
-        
-    def traverse(self, obj, clone=False):
-        """traverse and visit the given expression structure.
-        
-        Returns the structure given, or a copy of the structure if
-        clone=True.
-        
-        When the copy operation takes place, the before_clone() method
-        will receive each element before it is copied.  If the method
-        returns a non-None value, the return value is taken as the 
-        "copied" element and traversal will not descend further.  
-        
-        The visit_XXX() methods receive the element *after* it's been
-        copied.  To compare an element to another regardless of
-        one element being a cloned copy of the original, the 
-        '_cloned_set' attribute of ClauseElement can be used for the compare, 
-        i.e.::
-        
-            original in copied._cloned_set
-            
-        
-        """
-        if clone:
-            return self._cloned_traversal(obj)
-        else:
-            return self._non_cloned_traversal(obj)
-
-    def copy_and_process(self, list_):
-        """Apply cloned traversal to the given list of elements, and return the new list."""
-
-        return [self._cloned_traversal(x) for x in list_]
 
-    def before_clone(self, elem):
-        """receive pre-copied elements during a cloning traversal.
-        
-        If the method returns a new element, the element is used 
-        instead of creating a simple copy of the element.  Traversal 
-        will halt on the newly returned element if it is re-encountered.
-        """
-        return None
-    
-    def _clone_element(self, elem, stop_on, cloned):
-        for v in self._iterate_visitors:
-            newelem = v.before_clone(elem)
-            if newelem:
-                stop_on.add(newelem)
-                return newelem
-
-        if elem not in cloned:
-            # the full traversal will only make a clone of a particular element
-            # once.
-            cloned[elem] = elem._clone()
-        return cloned[elem]
-            
-    def _cloned_traversal(self, obj):
-        """a recursive traversal which creates copies of elements, returning the new structure."""
-        
-        stop_on = self.__traverse_options__.get('stop_on', [])
-        return self._cloned_traversal_impl(obj, util.Set(stop_on), {}, _clone_toplevel=True)
-        
-    def _cloned_traversal_impl(self, elem, stop_on, cloned, _clone_toplevel=False):
-        if elem in stop_on:
-            return elem
-
-        if _clone_toplevel:
-            elem = self._clone_element(elem, stop_on, cloned)
-            if elem in stop_on:
-                return elem
-
-        def clone(element):
-            return self._clone_element(element, stop_on, cloned)
-        elem._copy_internals(clone=clone)
+        return iterate(obj, self.__traverse_options__)
         
-        self.traverse_single(elem)
+    def traverse(self, obj):
+        """traverse and visit the given expression structure."""
 
-        for e in elem.get_children(**self.__traverse_options__):
-            if e not in stop_on:
-                self._cloned_traversal_impl(e, stop_on, cloned)
-        return elem
+        visitors = {}
 
-    def _non_cloned_traversal(self, obj):
-        """a non-recursive, non-cloning traversal."""
-
-        for target in self.iterate(obj):
-            self.traverse_single(target)
-        return obj
+        for name in dir(self):
+            if name.startswith('visit_'):
+                visitors[name[6:]] = getattr(self, name)
+            
+        return traverse(obj, self.__traverse_options__, visitors)
 
     def _iterate_visitors(self):
         """iterate through this visitor and each 'chained' visitor."""
@@ -152,31 +43,136 @@ class ClauseVisitor(object):
         tail._next = visitor
         return self
 
-class NoColumnVisitor(ClauseVisitor):
-    """ClauseVisitor with 'column_collections' set to False; will not
-    traverse the front-facing Column collections on Table, Alias, Select, 
-    and CompoundSelect objects.
+class CloningVisitor(ClauseVisitor):
+    def copy_and_process(self, list_):
+        """Apply cloned traversal to the given list of elements, and return the new list."""
+
+        return [self.traverse(x) for x in list_]
+
+    def traverse(self, obj):
+        """traverse and visit the given expression structure."""
+
+        visitors = {}
+
+        for name in dir(self):
+            if name.startswith('visit_'):
+                visitors[name[6:]] = getattr(self, name)
+            
+        return cloned_traverse(obj, self.__traverse_options__, visitors)
+
+class ReplacingCloningVisitor(CloningVisitor):
+    def replace(self, elem):
+        """receive pre-copied elements during a cloning traversal.
+        
+        If the method returns a new element, the element is used 
+        instead of creating a simple copy of the element.  Traversal 
+        will halt on the newly returned element if it is re-encountered.
+        """
+        return None
+
+    def traverse(self, obj):
+        """traverse and visit the given expression structure."""
+
+        def replace(elem):
+            for v in self._iterate_visitors:
+                e = v.replace(elem)
+                if e:
+                    return e
+        return replacement_traverse(obj, self.__traverse_options__, replace)
+
+def iterate(obj, opts):
+    """traverse the given expression structure, returning an iterator.
+    
+    traversal is configured to be breadth-first.
     
     """
+    stack = util.deque([obj])
+    while stack:
+        t = stack.popleft()
+        yield t
+        for c in t.get_children(**opts):
+            stack.append(c)
+
+def iterate_depthfirst(obj, opts):
+    """traverse the given expression structure, returning an iterator.
     
-    __traverse_options__ = {'column_collections':False}
-
-class NullVisitor(ClauseVisitor):
-    def traverse(self, obj, clone=False):
-        next = getattr(self, '_next', None)
-        if next:
-            return next.traverse(obj, clone=clone)
-        else:
-            return obj
-        
-def traverse(clause, **kwargs):
-    """traverse the given clause, applying visit functions passed in as keyword arguments."""
+    traversal is configured to be depth-first.
+    
+    """
+    stack = util.deque([obj])
+    traversal = util.deque()
+    while stack:
+        t = stack.pop()
+        traversal.appendleft(t)
+        for c in t.get_children(**opts):
+            stack.append(c)
+    return iter(traversal)
+
+def traverse_using(iterator, obj, visitors):
+    """visit the given expression structure using the given iterator of objects."""
+
+    for target in iterator:
+        meth = visitors.get(target.__visit_name__, None)
+        if meth:
+            meth(target)
+    return obj
     
-    clone = kwargs.pop('clone', False)
-    class Vis(ClauseVisitor):
-        __traverse_options__ = kwargs.pop('traverse_options', {})
-    vis = Vis()
-    for key in kwargs:
-        setattr(vis, key, kwargs[key])
-    return vis.traverse(clause, clone=clone)
+def traverse(obj, opts, visitors):
+    """traverse and visit the given expression structure using the default iterator."""
+
+    return traverse_using(iterate(obj, opts), obj, visitors)
+
+def traverse_depthfirst(obj, opts, visitors):
+    """traverse and visit the given expression structure using the depth-first iterator."""
+
+    return traverse_using(iterate_depthfirst(obj, opts), obj, visitors)
+
+def cloned_traverse(obj, opts, visitors):
+    cloned = {}
+
+    def clone(element):
+        if element not in cloned:
+            cloned[element] = element._clone()
+        return cloned[element]
+
+    obj = clone(obj)
+    stack = [obj]
+
+    while stack:
+        t = stack.pop()
+        if t in cloned:
+            continue
+        t._copy_internals(clone=clone)
+
+        meth = visitors.get(t.__visit_name__, None)
+        if meth:
+            meth(t)
+
+        for c in t.get_children(**opts):
+            stack.append(c)
+    return obj
+
+def replacement_traverse(obj, opts, replace):
+    cloned = {}
+    stop_on = util.Set(opts.get('stop_on', []))
+
+    def clone(element):
+        newelem = replace(element)
+        if newelem:
+            stop_on.add(newelem)
+            return newelem
+
+        if element not in cloned:
+            cloned[element] = element._clone()
+        return cloned[element]
 
+    obj = clone(obj)
+    stack = [obj]
+    while stack:
+        t = stack.pop()
+        if t in stop_on:
+            continue
+        t._copy_internals(clone=clone)
+        for c in t.get_children(**opts):
+            stack.append(c)
+    return obj
index 99612397989f25528c15e62face61ef69a67fbc9..9ef3dfaf45cca5f744fae0b480ce2d773477afbd 100644 (file)
@@ -19,7 +19,7 @@ conditions.
 """
 
 from sqlalchemy import util
-from sqlalchemy.exceptions import CircularDependencyError
+from sqlalchemy.exc import CircularDependencyError
 
 __all__ = ['sort', 'sort_with_cycles', 'sort_as_tree']
 
@@ -207,9 +207,9 @@ def _sort(tuples, allitems, allow_cycles=False, ignore_self_cycles=False):
                     for n in lead.cycles:
                         if n is not lead:
                             n._cyclical = True
-                            for (n,k) in list(edges.edges_by_parent(n)):
+                            for (n, k) in list(edges.edges_by_parent(n)):
                                 edges.add((lead, k))
-                                edges.remove((n,k))
+                                edges.remove((n, k))
                 continue
             else:
                 # long cycles not allowed
@@ -248,7 +248,7 @@ def _organize_as_tree(nodes):
         nodealldeps = node.all_deps()
         if nodealldeps:
             # iterate over independent node indexes in reverse order so we can efficiently remove them
-            for index in xrange(len(independents)-1,-1,-1):
+            for index in xrange(len(independents) - 1, -1, -1):
                 child, childsubtree, childcycles = independents[index]
                 # if there is a dependency between this node and an independent node
                 if (childsubtree.intersection(nodealldeps) or childcycles.intersection(node.dependencies)):
@@ -261,7 +261,7 @@ def _organize_as_tree(nodes):
                     # remove the child from list of independent subtrees
                     independents[index:index+1] = []
         # add node as a new independent subtree
-        independents.append((node,subtree,cycles))
+        independents.append((node, subtree, cycles))
     # choose an arbitrary node from list of all independent subtrees
     head = independents.pop()[0]
     # add all other independent subtrees as a child of the chosen root
index e06ec9a5a507af30e382548e5c3a970d9207cb2f..bae079e649bd45ef9077e24cb3bb66f6e48e4191 100644 (file)
@@ -24,7 +24,7 @@ __all__ = [ 'TypeEngine', 'TypeDecorator', 'AbstractType',
 import inspect
 import datetime as dt
 
-from sqlalchemy import exceptions
+from sqlalchemy import exc
 from sqlalchemy.util import pickle, Decimal as _python_Decimal
 import sqlalchemy.util as util
 NoneType = type(None)
@@ -173,7 +173,6 @@ class TypeEngine(AbstractType):
     def get_col_spec(self):
         raise NotImplementedError()
 
-
     def bind_processor(self, dialect):
         return None
 
@@ -214,7 +213,7 @@ class TypeDecorator(AbstractType):
     
     def __init__(self, *args, **kwargs):
         if not hasattr(self.__class__, 'impl'):
-            raise exceptions.AssertionError("TypeDecorator implementations require a class-level variable 'impl' which refers to the class of type being decorated")
+            raise AssertionError("TypeDecorator implementations require a class-level variable 'impl' which refers to the class of type being decorated")
         self.impl = self.__class__.impl(*args, **kwargs)
 
     def dialect_impl(self, dialect, **kwargs):
@@ -231,7 +230,7 @@ class TypeDecorator(AbstractType):
             typedesc = self.load_dialect_impl(dialect)
         tt = self.copy()
         if not isinstance(tt, self.__class__):
-            raise exceptions.AssertionError("Type object %s does not properly implement the copy() method, it must return an object of type %s" % (self, self.__class__))
+            raise AssertionError("Type object %s does not properly implement the copy() method, it must return an object of type %s" % (self, self.__class__))
         tt.impl = typedesc
         self._impl_dict[dialect] = tt
         return tt
@@ -299,7 +298,7 @@ class TypeDecorator(AbstractType):
         return self.impl.copy_value(value)
 
     def compare_values(self, x, y):
-        return self.impl.compare_values(x,y)
+        return self.impl.compare_values(x, y)
 
     def is_mutable(self):
         return self.impl.is_mutable()
@@ -363,12 +362,14 @@ class Concatenable(object):
 class String(Concatenable, TypeEngine):
     """A sized string type.
 
-    Usually corresponds to VARCHAR.  Can also take Python unicode objects
+    In SQL, corresponds to VARCHAR.  Can also take Python unicode objects
     and encode to the database's encoding in bind params (and the reverse for
     result sets.)
 
-    a String with no length will adapt itself automatically to a Text
-    object at the dialect level (this behavior is deprecated in 0.4).
+    The `length` field is usually required when the `String` type is used within a 
+    CREATE TABLE statement, since VARCHAR requires a length on most databases.
+    Currently SQLite is an exception to this.
+    
     """
     def __init__(self, length=None, convert_unicode=False, assert_unicode=None):
         self.length = length
@@ -393,7 +394,7 @@ class String(Concatenable, TypeEngine):
                                   "param value %r" % value)
                         return value
                     else:
-                        raise exceptions.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value)
+                        raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value)
                 else:
                     return value
             return process
@@ -411,26 +412,6 @@ class String(Concatenable, TypeEngine):
         else:
             return None
 
-    def dialect_impl(self, dialect, **kwargs):
-        _for_ddl = kwargs.pop('_for_ddl', False)
-        if _for_ddl and self.length is None:
-            label = util.to_ascii(_for_ddl is True and
-                                  '' or (' for column "%s"' % str(_for_ddl)))
-            util.warn_deprecated(
-                "Using String type with no length for CREATE TABLE "
-                "is deprecated; use the Text type explicitly" + label)
-        return TypeEngine.dialect_impl(self, dialect, **kwargs)
-
-    def get_search_list(self):
-        l = super(String, self).get_search_list()
-        # if we are String or Unicode with no length,
-        # return Text as the highest-priority type
-        # to be adapted by the dialect
-        if self.length is None and l[0] in (String, Unicode):
-            return (Text,) + l
-        else:
-            return l
-
     def get_dbapi_type(self, dbapi):
         return dbapi.STRING
 
@@ -632,7 +613,7 @@ class Interval(TypeDecorator):
             if value is None:
                 return None
             return dt.datetime.utcfromtimestamp(0) + value
-            
+
     def process_result_value(self, value, dialect):
         if dialect.__class__ in self.__supported:
             return value
@@ -641,23 +622,68 @@ class Interval(TypeDecorator):
                 return None
             return value - dt.datetime.utcfromtimestamp(0)
 
-class FLOAT(Float): pass
-TEXT = Text
-class NUMERIC(Numeric): pass
-class DECIMAL(Numeric): pass
-class INT(Integer): pass
+class FLOAT(Float):
+    """The SQL FLOAT type."""
+
+
+class NUMERIC(Numeric):
+    """The SQL NUMERIC type."""
+
+
+class DECIMAL(Numeric):
+    """The SQL DECIMAL type."""
+
+
+class INT(Integer):
+    """The SQL INT or INTEGER type."""
+
+
 INTEGER = INT
-class SMALLINT(Smallinteger): pass
-class TIMESTAMP(DateTime): pass
-class DATETIME(DateTime): pass
-class DATE(Date): pass
-class TIME(Time): pass
-class CLOB(Text): pass
-class VARCHAR(String): pass
-class CHAR(String): pass
-class NCHAR(Unicode): pass
-class BLOB(Binary): pass
-class BOOLEAN(Boolean): pass
+
+class SMALLINT(Smallinteger):
+    """The SQL SMALLINT type."""
+
+
+class TIMESTAMP(DateTime):
+    """The SQL TIMESTAMP type."""
+
+
+class DATETIME(DateTime):
+    """The SQL DATETIME type."""
+
+
+class DATE(Date):
+    """The SQL DATE type."""
+
+
+class TIME(Time):
+    """The SQL TIME type."""
+
+
+TEXT = Text
+
+class CLOB(Text):
+    """The SQL CLOB type."""
+
+
+class VARCHAR(String):
+    """The SQL VARCHAR type."""
+
+
+class CHAR(String):
+    """The SQL CHAR type."""
+
+
+class NCHAR(Unicode):
+    """The SQL NCHAR type."""
+
+
+class BLOB(Binary):
+    """The SQL BLOB type."""
+
+
+class BOOLEAN(Boolean):
+    """The SQL BOOLEAN type."""
 
 NULLTYPE = NullType()
 
index e88c4b3b9b176c80d91a8405f969848e804d4d66..ff1108c3b29b40b418b809788b98a143695df909 100644 (file)
@@ -8,7 +8,7 @@ import inspect, itertools, new, operator, sets, sys, warnings, weakref
 import __builtin__
 types = __import__('types')
 
-from sqlalchemy import exceptions
+from sqlalchemy import exc
 
 try:
     import thread, threading
@@ -18,14 +18,16 @@ except ImportError:
 
 try:
     Set = set
+    FrozenSet = frozenset
     set_types = set, sets.Set
 except NameError:
     set_types = sets.Set,
-    # layer some of __builtin__.set's binop behavior onto sets.Set
-    class Set(sets.Set):
+
+    def py24_style_ops():
+        """Layer some of __builtin__.set's binop behavior onto sets.Set."""
+
         def _binary_sanity_check(self, other):
             pass
-
         def issubset(self, iterable):
             other = type(self)(iterable)
             return sets.Set.issubset(self, other)
@@ -38,7 +40,6 @@ except NameError:
         def __ge__(self, other):
             sets.Set._binary_sanity_check(self, other)
             return sets.Set.__ge__(self, other)
-
         # lt and gt still require a BaseSet
         def __lt__(self, other):
             sets.Set._binary_sanity_check(self, other)
@@ -63,6 +64,14 @@ except NameError:
             if not isinstance(other, sets.BaseSet):
                 return NotImplemented
             return sets.Set.__isub__(self, other)
+        return locals()
+
+    py24_style_ops = py24_style_ops()
+    Set = type('Set', (sets.Set,), py24_style_ops)
+    FrozenSet = type('FrozenSet', (sets.ImmutableSet,), py24_style_ops)
+    del py24_style_ops
+
+EMPTY_SET = FrozenSet()
 
 try:
     import cPickle as pickle
@@ -96,10 +105,16 @@ except ImportError:
 
 try:
     from operator import attrgetter
-except:
+except ImportError:
     def attrgetter(attribute):
         return lambda value: getattr(value, attribute)
 
+try:
+    from operator import itemgetter
+except ImportError:
+    def itemgetter(attribute):
+        return lambda value: value[attribute]
+
 if sys.version_info >= (2, 5):
     class PopulateDict(dict):
         """a dict which populates missing values via a creation function.
@@ -169,17 +184,17 @@ except ImportError:
     class deque(list):
         def appendleft(self, x):
             self.insert(0, x)
-        
+
         def extendleft(self, iterable):
             self[0:0] = list(iterable)
 
         def popleft(self):
             return self.pop(0)
-            
+
         def rotate(self, n):
             for i in xrange(n):
                 self.appendleft(self.pop())
-                
+
 def to_list(x, default=None):
     if x is None:
         return default
@@ -188,18 +203,34 @@ def to_list(x, default=None):
     else:
         return x
 
-def array_as_starargs_decorator(func):
+def array_as_starargs_decorator(fn):
     """Interpret a single positional array argument as
     *args for the decorated method.
-    
+
     """
+
     def starargs_as_list(self, *args, **kwargs):
-        if len(args) == 1:
-            return func(self, *to_list(args[0], []), **kwargs)
+        if isinstance(args, basestring) or (len(args) == 1 and not isinstance(args[0], tuple)):
+            return fn(self, *to_list(args[0], []), **kwargs)
         else:
-            return func(self, *args, **kwargs)
-    return starargs_as_list
-    
+            return fn(self, *args, **kwargs)
+    starargs_as_list.__doc__ = fn.__doc__
+    return function_named(starargs_as_list, fn.__name__)
+
+def array_as_starargs_fn_decorator(fn):
+    """Interpret a single positional array argument as
+    *args for the decorated function.
+
+    """
+
+    def starargs_as_list(*args, **kwargs):
+        if isinstance(args, basestring) or (len(args) == 1 and not isinstance(args[0], tuple)):
+            return fn(*to_list(args[0], []), **kwargs)
+        else:
+            return fn(*args, **kwargs)
+    starargs_as_list.__doc__ = fn.__doc__
+    return function_named(starargs_as_list, fn.__name__)
+
 def to_set(x):
     if x is None:
         return Set()
@@ -281,14 +312,121 @@ def get_func_kwargs(func):
     """Return the full set of legal kwargs for the given `func`."""
     return inspect.getargspec(func)[0]
 
+def format_argspec_plus(fn, grouped=True):
+    """Returns a dictionary of formatted, introspected function arguments.
+
+    A enhanced variant of inspect.formatargspec to support code generation.
+
+    fn
+       An inspectable callable
+    grouped
+      Defaults to True; include (parens, around, argument) lists
+
+    Returns:
+
+    args
+      Full inspect.formatargspec for fn
+    self_arg
+      The name of the first positional argument, or None
+    apply_pos
+      args, re-written in calling rather than receiving syntax.  Arguments are
+      passed positionally.
+    apply_kw
+      Like apply_pos, except keyword-ish args are passed as keywords.
+
+    Example::
+
+      >>> format_argspec_plus(lambda self, a, b, c=3, **d: 123)
+      {'args': '(self, a, b, c=3, **d)',
+       'self_arg': 'self',
+       'apply_kw': '(self, a, b, c=c, **d)',
+       'apply_pos': '(self, a, b, c, **d)'}
+
+    """
+    spec = inspect.getargspec(fn)
+    args = inspect.formatargspec(*spec)
+    self_arg = spec[0] and spec[0][0] or None
+    apply_pos = inspect.formatargspec(spec[0], spec[1], spec[2])
+    defaulted_vals = spec[3] is not None and spec[0][0-len(spec[3]):] or ()
+    apply_kw = inspect.formatargspec(spec[0], spec[1], spec[2], defaulted_vals,
+                                     formatvalue=lambda x: '=' + x)
+    if grouped:
+        return dict(args=args, self_arg=self_arg,
+                    apply_pos=apply_pos, apply_kw=apply_kw)
+    else:
+        return dict(args=args[1:-1], self_arg=self_arg,
+                    apply_pos=apply_pos[1:-1], apply_kw=apply_kw[1:-1])
+
+def format_argspec_init(method, grouped=True):
+    """format_argspec_plus with considerations for typical __init__ methods
+
+    Wraps format_argspec_plus with error handling strategies for typical
+    __init__ cases::
+
+      object.__init__ -> (self)
+      other unreflectable (usually C) -> (self, *args, **kwargs)
+
+    """
+    try:
+        return format_argspec_plus(method, grouped=grouped)
+    except TypeError:
+        self_arg = 'self'
+        if method is object.__init__:
+            args = grouped and '(self)' or 'self'
+        else:
+            args = (grouped and '(self, *args, **kwargs)'
+                            or 'self, *args, **kwargs')
+        return dict(self_arg='self', args=args, apply_pos=args, apply_kw=args)
+
+def getargspec_init(method):
+    """inspect.getargspec with considerations for typical __init__ methods
+
+    Wraps inspect.getargspec with error handling for typical __init__ cases::
+
+      object.__init__ -> (self)
+      other unreflectable (usually C) -> (self, *args, **kwargs)
+
+    """
+    try:
+        return inspect.getargspec(method)
+    except TypeError:
+        if method is object.__init__:
+            return (['self'], None, None, None)
+        else:
+            return (['self'], 'args', 'kwargs', None)
+
 def unbound_method_to_callable(func_or_cls):
     """Adjust the incoming callable such that a 'self' argument is not required."""
-    
+
     if isinstance(func_or_cls, types.MethodType) and not func_or_cls.im_self:
         return func_or_cls.im_func
     else:
         return func_or_cls
 
+def class_hierarchy(cls):
+    """Return an unordered sequence of all classes related to cls.
+
+    Traverses diamond hierarchies.
+
+    Fibs slightly: subclasses of builtin types are not returned.  Thus
+    class_hierarchy(class A(object)) returns (A, object), not A plus every
+    class systemwide that derives from object.
+
+    """
+    hier = Set([cls])
+    process = list(cls.__mro__)
+    while process:
+        c = process.pop()
+        for b in [_ for _ in c.__bases__ if _ not in hier]:
+            process.append(b)
+            hier.add(b)
+        if c.__module__ == '__builtin__':
+            continue
+        for s in [_ for _ in c.__subclasses__() if _ not in hier]:
+            process.append(s)
+            hier.add(s)
+    return list(hier)
+
 # from paste.deploy.converters
 def asbool(obj):
     if isinstance(obj, (str, unicode)):
@@ -328,9 +466,12 @@ def duck_type_collection(specimen, default=None):
             return specimen.__emulates__
 
     isa = isinstance(specimen, type) and issubclass or isinstance
-    if isa(specimen, list): return list
-    if isa(specimen, set_types): return Set
-    if isa(specimen, dict): return dict
+    if isa(specimen, list):
+        return list
+    elif isa(specimen, set_types):
+        return Set
+    elif isa(specimen, dict):
+        return dict
 
     if hasattr(specimen, 'append'):
         return list
@@ -370,10 +511,23 @@ def assert_arg_type(arg, argtype, name):
         return arg
     else:
         if isinstance(argtype, tuple):
-            raise exceptions.ArgumentError("Argument '%s' is expected to be one of type %s, got '%s'" % (name, ' or '.join(["'%s'" % str(a) for a in argtype]), str(type(arg))))
+            raise exc.ArgumentError("Argument '%s' is expected to be one of type %s, got '%s'" % (name, ' or '.join(["'%s'" % str(a) for a in argtype]), str(type(arg))))
         else:
-            raise exceptions.ArgumentError("Argument '%s' is expected to be of type '%s', got '%s'" % (name, str(argtype), str(type(arg))))
+            raise exc.ArgumentError("Argument '%s' is expected to be of type '%s', got '%s'" % (name, str(argtype), str(type(arg))))
 
+_creation_order = 1
+def set_creation_order(instance):
+    """assign a '_creation_order' sequence to the given instance.
+    
+    This allows multiple instances to be sorted in order of
+    creation (typically within a single thread; the counter is
+    not particularly threadsafe).
+    
+    """
+    global _creation_order
+    instance._creation_order = _creation_order
+    _creation_order +=1
+    
 def warn_exception(func, *args, **kwargs):
     """executes the given function, catches all exceptions and converts to a warning."""
     try:
@@ -430,22 +584,22 @@ class SimpleProperty(object):
 
 
 class NotImplProperty(object):
-  """a property that raises ``NotImplementedError``."""
+    """a property that raises ``NotImplementedError``."""
 
-  def __init__(self, doc):
-      self.__doc__ = doc
+    def __init__(self, doc):
+        self.__doc__ = doc
 
-  def __set__(self, obj, value):
-      raise NotImplementedError()
+    def __set__(self, obj, value):
+        raise NotImplementedError()
 
-  def __delete__(self, obj):
-      raise NotImplementedError()
+    def __delete__(self, obj):
+        raise NotImplementedError()
 
-  def __get__(self, obj, owner):
-      if obj is None:
-          return self
-      else:
-          raise NotImplementedError()
+    def __get__(self, obj, owner):
+        if obj is None:
+            return self
+        else:
+            raise NotImplementedError()
 
 class OrderedProperties(object):
     """An object that maintains the order in which attributes are set upon it.
@@ -496,10 +650,10 @@ class OrderedProperties(object):
 
     def __contains__(self, key):
         return key in self._data
-    
+
     def update(self, value):
         self._data.update(value)
-        
+
     def get(self, key, default=None):
         if key in self:
             return self[key]
@@ -529,7 +683,10 @@ class OrderedDict(dict):
     def clear(self):
         self._list = []
         dict.clear(self)
-
+    
+    def sort(self, fn=None):
+        self._list.sort(fn)
+        
     def update(self, ____sequence=None, **kwargs):
         if ____sequence is not None:
             if hasattr(____sequence, 'keys'):
@@ -622,22 +779,24 @@ class OrderedSet(Set):
         if d is not None:
             self.update(d)
 
-    def add(self, key):
-        if key not in self:
-            self._list.append(key)
-        Set.add(self, key)
+    def add(self, element):
+        if element not in self:
+            self._list.append(element)
+        Set.add(self, element)
 
     def remove(self, element):
         Set.remove(self, element)
         self._list.remove(element)
 
+    def insert(self, pos, element):
+        if element not in self:
+            self._list.insert(pos, element)
+        Set.add(self, element)
+
     def discard(self, element):
-        try:
-            Set.remove(self, element)
-        except KeyError:
-            pass
-        else:
+        if element in self:
             self._list.remove(element)
+            Set.remove(self, element)
 
     def clear(self):
         Set.clear(self)
@@ -650,22 +809,22 @@ class OrderedSet(Set):
         return iter(self._list)
 
     def __repr__(self):
-      return '%s(%r)' % (self.__class__.__name__, self._list)
+        return '%s(%r)' % (self.__class__.__name__, self._list)
 
     __str__ = __repr__
 
     def update(self, iterable):
-      add = self.add
-      for i in iterable:
-          add(i)
-      return self
+        add = self.add
+        for i in iterable:
+            add(i)
+        return self
 
     __ior__ = update
 
     def union(self, other):
-      result = self.__class__(self)
-      result.update(other)
-      return result
+        result = self.__class__(self)
+        result.update(other)
+        return result
 
     __or__ = union
 
@@ -698,10 +857,10 @@ class OrderedSet(Set):
     __iand__ = intersection_update
 
     def symmetric_difference_update(self, other):
-      Set.symmetric_difference_update(self, other)
-      self._list =  [ a for a in self._list if a in self]
-      self._list += [ a for a in other._list if a in self]
-      return self
+        Set.symmetric_difference_update(self, other)
+        self._list =  [ a for a in self._list if a in self]
+        self._list += [ a for a in other._list if a in self]
+        return self
 
     __ixor__ = symmetric_difference_update
 
@@ -1021,6 +1180,35 @@ class ScopedRegistry(object):
     def _get_key(self):
         return self.scopefunc()
 
+class WeakCompositeKey(object):
+    """an weak-referencable, hashable collection which is strongly referenced
+    until any one of its members is garbage collected.
+
+    """
+    keys = Set()
+
+    def __init__(self, *args):
+        self.args = [self.__ref(arg) for arg in args]
+        WeakCompositeKey.keys.add(self)
+
+    def __ref(self, arg):
+        if isinstance(arg, type):
+            return weakref.ref(arg, self.__remover)
+        else:
+            return lambda: arg
+
+    def __remover(self, wr):
+        WeakCompositeKey.keys.discard(self)
+
+    def __hash__(self):
+        return hash(tuple(self))
+
+    def __cmp__(self, other):
+        return cmp(tuple(self), tuple(other))
+
+    def __iter__(self):
+        return iter([arg() for arg in self.args])
+
 class _symbol(object):
     def __init__(self, name):
         """Construct a new named symbol."""
@@ -1059,7 +1247,6 @@ class symbol(object):
         finally:
             symbol._lock.release()
 
-
 def as_interface(obj, cls=None, methods=None, required=None):
     """Ensure basic interface compliance for an instance or dict of callables.
 
@@ -1155,21 +1342,12 @@ def function_named(fn, name):
                           fn.func_defaults, fn.func_closure)
     return fn
 
-def conditional_cache_decorator(func):
-    """apply conditional caching to the return value of a function."""
-
-    return cache_decorator(func, conditional=True)
-
-def cache_decorator(func, conditional=False):
+def cache_decorator(func):
     """apply caching to the return value of a function."""
 
     name = '_cached_' + func.__name__
-    
+
     def do_with_cache(self, *args, **kwargs):
-        if conditional:
-            cache = kwargs.pop('cache', False)
-            if not cache:
-                return func(self, *args, **kwargs)
         try:
             return getattr(self, name)
         except AttributeError:
@@ -1177,21 +1355,109 @@ def cache_decorator(func, conditional=False):
             setattr(self, name, value)
             return value
     return do_with_cache
-    
+
 def reset_cached(instance, name):
     try:
         delattr(instance, '_cached_' + name)
     except AttributeError:
         pass
 
+class WeakIdentityMapping(weakref.WeakKeyDictionary):
+    """A WeakKeyDictionary with an object identity index.
+
+    Adds a .by_id dictionary to a regular WeakKeyDictionary.  Trades
+    performance during mutation operations for accelerated lookups by id().
+
+    The usual cautions about weak dictionaries and iteration also apply to
+    this subclass.
+
+    """
+    _none = symbol('none')
+
+    def __init__(self):
+        weakref.WeakKeyDictionary.__init__(self)
+        self.by_id = {}
+        self._weakrefs = {}
+
+    def __setitem__(self, object, value):
+        oid = id(object)
+        self.by_id[oid] = value
+        if oid not in self._weakrefs:
+            self._weakrefs[oid] = self._ref(object)
+        weakref.WeakKeyDictionary.__setitem__(self, object, value)
+
+    def __delitem__(self, object):
+        del self._weakrefs[id(object)]
+        del self.by_id[id(object)]
+        weakref.WeakKeyDictionary.__delitem__(self, object)
+
+    def setdefault(self, object, default=None):
+        value = weakref.WeakKeyDictionary.setdefault(self, object, default)
+        oid = id(object)
+        if value is default:
+            self.by_id[oid] = default
+        if oid not in self._weakrefs:
+            self._weakrefs[oid] = self._ref(object)
+        return value
+
+    def pop(self, object, default=_none):
+        if default is self._none:
+            value = weakref.WeakKeyDictionary.pop(self, object)
+        else:
+            value = weakref.WeakKeyDictionary.pop(self, object, default)
+        if id(object) in self.by_id:
+            del self._weakrefs[id(object)]
+            del self.by_id[id(object)]
+        return value
+
+    def popitem(self):
+        item = weakref.WeakKeyDictionary.popitem(self)
+        oid = id(item[0])
+        del self._weakrefs[oid]
+        del self.by_id[oid]
+        return item
+
+    def clear(self):
+        self._weakrefs.clear()
+        self.by_id.clear()
+        weakref.WeakKeyDictionary.clear(self)
+
+    def update(self, *a, **kw):
+        raise NotImplementedError
+
+    def _cleanup(self, wr, key=None):
+        if key is None:
+            key = wr.key
+        try:
+            del self._weakrefs[key]
+        except (KeyError, AttributeError):  # pragma: no cover
+            pass                            # pragma: no cover
+        try:
+            del self.by_id[key]
+        except (KeyError, AttributeError):  # pragma: no cover
+            pass                            # pragma: no cover
+    if sys.version_info < (2, 4):           # pragma: no cover
+        def _ref(self, object):
+            oid = id(object)
+            return weakref.ref(object, lambda wr: self._cleanup(wr, oid))
+    else:
+        class _keyed_weakref(weakref.ref):
+            def __init__(self, object, callback):
+                weakref.ref.__init__(self, object, callback)
+                self.key = id(object)
+
+        def _ref(self, object):
+            return self._keyed_weakref(object, self._cleanup)
+
+
 def warn(msg):
     if isinstance(msg, basestring):
-        warnings.warn(msg, exceptions.SAWarning, stacklevel=3)
+        warnings.warn(msg, exc.SAWarning, stacklevel=3)
     else:
         warnings.warn(msg, stacklevel=3)
 
 def warn_deprecated(msg):
-    warnings.warn(msg, exceptions.SADeprecationWarning, stacklevel=3)
+    warnings.warn(msg, exc.SADeprecationWarning, stacklevel=3)
 
 def deprecated(message=None, add_deprecation_to_docstring=True):
     """Decorates a function and issues a deprecation warning on use.
@@ -1216,7 +1482,7 @@ def deprecated(message=None, add_deprecation_to_docstring=True):
 
     def decorate(fn):
         return _decorate_with_warning(
-            fn, exceptions.SADeprecationWarning,
+            fn, exc.SADeprecationWarning,
             message % dict(func=fn.__name__), header)
     return decorate
 
@@ -1248,7 +1514,7 @@ def pending_deprecation(version, message=None,
 
     def decorate(fn):
         return _decorate_with_warning(
-            fn, exceptions.SAPendingDeprecationWarning,
+            fn, exc.SAPendingDeprecationWarning,
             message % dict(func=fn.__name__), header)
     return decorate
 
index 25d34ffd39a018a8e218fe6fe683e6e74856a37f..f891bc92e583af060fe6ab1d069e9c8a42b683b7 100644 (file)
@@ -1,7 +1,7 @@
 import testenv; testenv.configure_for_tests()
 import sqlalchemy.topological as topological
 from sqlalchemy import util
-from testlib import *
+from testlib import TestBase
 
 
 class DependencySortTest(TestBase):
index 84b84793ce14276a3ba58df1800432f3ff089389..cbbb941c67e2571ccf1f60ee7b4441e7bda08c77 100644 (file)
@@ -1,9 +1,8 @@
 """Tests exceptions and DB-API exception wrapping."""
 import testenv; testenv.configure_for_tests()
-import sys, unittest
+import unittest
 import exceptions as stdlib_exceptions
-from sqlalchemy import exceptions as sa_exceptions
-from testlib import *
+from sqlalchemy import exc as sa_exceptions
 
 
 class Error(stdlib_exceptions.StandardError):
@@ -48,10 +47,10 @@ class WrapTest(unittest.TestCase):
         # subclasses of sqlalchemy.exceptions.DBAPIError
         try:
             raise sa_exceptions.DBAPIError.instance(
-                '', [], sa_exceptions.AssertionError())
+                '', [], sa_exceptions.ArgumentError())
         except sa_exceptions.DBAPIError, e:
             self.assert_(e.__class__ is sa_exceptions.DBAPIError)
-        except sa_exceptions.AssertionError:
+        except sa_exceptions.ArgumentError:
             self.assert_(False)
 
     def test_db_error_keyboard_interrupt(self):
index a00338f5f52ec50d43d62caf1c6d269e6529d9fa..070ffb5835552a7c3f8f4d189b630f3d23283f05 100644 (file)
@@ -1,8 +1,9 @@
 import testenv; testenv.configure_for_tests()
-import unittest
-from sqlalchemy import util, sql, exceptions
-from testlib import *
-from testlib import sorted
+import threading, unittest
+from sqlalchemy import util, sql, exc
+from testlib import TestBase
+from testlib.testing import eq_, is_, ne_
+from testlib.compat import frozenset, set, sorted
 
 class OrderedDictTest(TestBase):
     def test_odict(self):
@@ -12,40 +13,37 @@ class OrderedDictTest(TestBase):
         o['snack'] = 'attack'
         o['c'] = 3
 
-        self.assert_(o.keys() == ['a', 'b', 'snack', 'c'])
-        self.assert_(o.values() == [1, 2, 'attack', 3])
+        eq_(o.keys(), ['a', 'b', 'snack', 'c'])
+        eq_(o.values(), [1, 2, 'attack', 3])
 
         o.pop('snack')
 
-        self.assert_(o.keys() == ['a', 'b', 'c'])
-        self.assert_(o.values() == [1, 2, 3])
+        eq_(o.keys(), ['a', 'b', 'c'])
+        eq_(o.values(), [1, 2, 3])
 
         o2 = util.OrderedDict(d=4)
         o2['e'] = 5
 
-        self.assert_(o2.keys() == ['d', 'e'])
-        self.assert_(o2.values() == [4, 5])
+        eq_(o2.keys(), ['d', 'e'])
+        eq_(o2.values(), [4, 5])
 
         o.update(o2)
-        self.assert_(o.keys() == ['a', 'b', 'c', 'd', 'e'])
-        self.assert_(o.values() == [1, 2, 3, 4, 5])
+        eq_(o.keys(), ['a', 'b', 'c', 'd', 'e'])
+        eq_(o.values(), [1, 2, 3, 4, 5])
 
         o.setdefault('c', 'zzz')
         o.setdefault('f', 6)
-        self.assert_(o.keys() == ['a', 'b', 'c', 'd', 'e', 'f'])
-        self.assert_(o.values() == [1, 2, 3, 4, 5, 6])
+        eq_(o.keys(), ['a', 'b', 'c', 'd', 'e', 'f'])
+        eq_(o.values(), [1, 2, 3, 4, 5, 6])
 
 class OrderedSetTest(TestBase):
     def test_mutators_against_iter(self):
         # testing a set modified against an iterator
         o = util.OrderedSet([3,2, 4, 5])
 
-        self.assertEquals(o.difference(iter([3,4])),
-                          util.OrderedSet([2,5]))
-        self.assertEquals(o.intersection(iter([3,4, 6])),
-                          util.OrderedSet([3, 4]))
-        self.assertEquals(o.union(iter([3,4, 6])),
-                          util.OrderedSet([2, 3, 4, 5, 6]))
+        eq_(o.difference(iter([3,4])), util.OrderedSet([2,5]))
+        eq_(o.intersection(iter([3,4, 6])), util.OrderedSet([3, 4]))
+        eq_(o.union(iter([3,4, 6])), util.OrderedSet([2, 3, 4, 5, 6]))
 
 class ColumnCollectionTest(TestBase):
     def test_in(self):
@@ -59,8 +57,8 @@ class ColumnCollectionTest(TestBase):
         try:
             cc['col1'] in cc
             assert False
-        except exceptions.ArgumentError, e:
-            assert str(e) == "__contains__ requires a string argument"
+        except exc.ArgumentError, e:
+            eq_(str(e), "__contains__ requires a string argument")
 
     def test_compare(self):
         cc1 = sql.ColumnCollection()
@@ -90,11 +88,11 @@ class ArgSingletonTest(unittest.TestCase):
         m3 = MyClass(3, 4)
         assert m1 is m3
         assert m2 is not m3
-        assert len(util.ArgSingleton.instances) == 2
+        eq_(len(util.ArgSingleton.instances), 2)
 
         m1 = m2 = m3 = None
         MyClass.dispose(MyClass)
-        assert len(util.ArgSingleton.instances) == 0
+        eq_(len(util.ArgSingleton.instances), 0)
 
 
 class ImmutableSubclass(str):
@@ -140,7 +138,7 @@ class IdentitySetTest(unittest.TestCase):
     def assert_eq(self, identityset, expected_iterable):
         expected = sorted([id(o) for o in expected_iterable])
         found = sorted([id(o) for o in identityset])
-        self.assertEquals(found, expected)
+        eq_(found, expected)
 
     def test_init(self):
         ids = util.IdentitySet([1,2,3,2,1])
@@ -184,32 +182,35 @@ class IdentitySetTest(unittest.TestCase):
         ids.remove(o1)
         self.assertRaises(KeyError, ids.remove, o1)
 
-        self.assert_(ids.copy() == ids)
-        self.assert_(ids != None)
-        self.assert_(not(ids == None))
-        self.assert_(ids != IdentitySet([o1,o2,o3]))
+        eq_(ids.copy(), ids)
+
+        # explicit __eq__ and __ne__ tests
+        assert ids != None
+        assert not(ids == None)
+
+        ne_(ids, IdentitySet([o1,o2,o3]))
         ids.clear()
-        self.assert_(o1 not in ids)
+        assert o1 not in ids
         ids.add(o2)
-        self.assert_(o2 in ids)
-        self.assert_(ids.pop() == o2)
+        assert o2 in ids
+        eq_(ids.pop(), o2)
         ids.add(o1)
-        self.assert_(len(ids) == 1)
+        eq_(len(ids), 1)
 
         isuper = IdentitySet([o1,o2])
-        self.assert_(ids < isuper)
-        self.assert_(ids.issubset(isuper))
-        self.assert_(isuper.issuperset(ids))
-        self.assert_(isuper > ids)
-
-        self.assert_(ids.union(isuper) == isuper)
-        self.assert_(ids | isuper == isuper)
-        self.assert_(isuper - ids == IdentitySet([o2]))
-        self.assert_(isuper.difference(ids) == IdentitySet([o2]))
-        self.assert_(ids.intersection(isuper) == IdentitySet([o1]))
-        self.assert_(ids & isuper == IdentitySet([o1]))
-        self.assert_(ids.symmetric_difference(isuper) == IdentitySet([o2]))
-        self.assert_(ids ^ isuper == IdentitySet([o2]))
+        assert ids < isuper
+        assert ids.issubset(isuper)
+        assert isuper.issuperset(ids)
+        assert isuper > ids
+
+        eq_(ids.union(isuper), isuper)
+        eq_(ids | isuper, isuper)
+        eq_(isuper - ids, IdentitySet([o2]))
+        eq_(isuper.difference(ids), IdentitySet([o2]))
+        eq_(ids.intersection(isuper), IdentitySet([o1]))
+        eq_(ids & isuper, IdentitySet([o1]))
+        eq_(ids.symmetric_difference(isuper), IdentitySet([o2]))
+        eq_(ids ^ isuper, IdentitySet([o2]))
 
         ids.update(isuper)
         ids |= isuper
@@ -223,16 +224,16 @@ class IdentitySetTest(unittest.TestCase):
         ids.update('foobar')
         try:
             ids |= 'foobar'
-            self.assert_(False)
+            assert False
         except TypeError:
-            self.assert_(True)
+            assert True
 
         try:
             s = set([o1,o2])
             s |= ids
-            self.assert_(False)
+            assert False
         except TypeError:
-            self.assert_(True)
+            assert True
 
         self.assertRaises(TypeError, cmp, ids)
         self.assertRaises(TypeError, hash, ids)
@@ -243,8 +244,8 @@ class IdentitySetTest(unittest.TestCase):
         s1 = set([1,2,3])
         s2 = set([3,4,5])
 
-        self.assertEquals(os1 - os2, util.IdentitySet([1, 2]))
-        self.assertEquals(os2 - os1, util.IdentitySet([4, 5]))
+        eq_(os1 - os2, util.IdentitySet([1, 2]))
+        eq_(os2 - os1, util.IdentitySet([4, 5]))
         self.assertRaises(TypeError, lambda: os1 - s2)
         self.assertRaises(TypeError, lambda: os1 - [3, 4, 5])
         self.assertRaises(TypeError, lambda: s1 - os2)
@@ -256,7 +257,7 @@ class DictlikeIteritemsTest(unittest.TestCase):
 
     def _ok(self, instance):
         iterator = util.dictlike_iteritems(instance)
-        self.assertEquals(set(iterator), self.baseline)
+        eq_(set(iterator), self.baseline)
 
     def _notok(self, instance):
         self.assertRaises(TypeError,
@@ -322,6 +323,33 @@ class DictlikeIteritemsTest(unittest.TestCase):
         self._notok(duck6())
 
 
+class DuckTypeCollectionTest(TestBase):
+    def test_sets(self):
+        import sets
+        class SetLike(object):
+            def add(self):
+                pass
+
+        class ForcedSet(list):
+            __emulates__ = set
+
+        for type_ in (set,
+                      sets.Set,
+                      util.Set,
+                      SetLike,
+                      ForcedSet):
+            eq_(util.duck_type_collection(type_), util.Set)
+            instance = type_()
+            eq_(util.duck_type_collection(instance), util.Set)
+
+        for type_ in (frozenset,
+                      sets.ImmutableSet,
+                      util.FrozenSet):
+            is_(util.duck_type_collection(type_), None)
+            instance = type_()
+            is_(util.duck_type_collection(instance), None)
+
+
 class ArgInspectionTest(TestBase):
     def test_get_cls_kwargs(self):
         class A(object):
@@ -359,7 +387,7 @@ class ArgInspectionTest(TestBase):
             pass
 
         def test(cls, *expected):
-            self.assertEquals(set(util.get_cls_kwargs(cls)), set(expected))
+            eq_(set(util.get_cls_kwargs(cls)), set(expected))
 
         test(A, 'a')
         test(A1, 'a1')
@@ -382,7 +410,7 @@ class ArgInspectionTest(TestBase):
         def f4(**foo): pass
 
         def test(fn, *expected):
-            self.assertEquals(set(util.get_func_kwargs(fn)), set(expected))
+            eq_(set(util.get_func_kwargs(fn)), set(expected))
 
         test(f1)
         test(f2, 'foo')
@@ -419,7 +447,336 @@ class SymbolTest(TestBase):
             assert rt is sym1
             assert rt is sym2
 
+class WeakIdentityMappingTest(TestBase):
+    class Data(object):
+        pass
+
+    def _some_data(self, some=20):
+        return [self.Data() for _ in xrange(some)]
+
+    def _fixture(self, some=20):
+        data = self._some_data()
+        wim = util.WeakIdentityMapping()
+        for idx, obj in enumerate(data):
+            wim[obj] = idx
+        return data, wim
+
+    def test_delitem(self):
+        data, wim = self._fixture()
+        needle = data[-1]
+
+        assert needle in wim
+        assert id(needle) in wim.by_id
+        eq_(wim[needle], wim.by_id[id(needle)])
+
+        del wim[needle]
+
+        assert needle not in wim
+        assert id(needle) not in wim.by_id
+        eq_(len(wim), (len(data) - 1))
+
+        data.remove(needle)
+
+        assert needle not in wim
+        assert id(needle) not in wim.by_id
+        eq_(len(wim), len(data))
+
+    def test_setitem(self):
+        data, wim = self._fixture()
+
+        o1, oid1 = data[-1], id(data[-1])
+
+        assert o1 in wim
+        assert oid1 in wim.by_id
+        eq_(wim[o1], wim.by_id[oid1])
+        id_keys = set(wim.by_id.keys())
+
+        wim[o1] = 1234
+        assert o1 in wim
+        assert oid1 in wim.by_id
+        eq_(wim[o1], wim.by_id[oid1])
+        eq_(set(wim.by_id.keys()), id_keys)
+
+        o2 = self.Data()
+        oid2 = id(o2)
+
+        wim[o2] = 5678
+        assert o2 in wim
+        assert oid2 in wim.by_id
+        eq_(wim[o2], wim.by_id[oid2])
+
+    def test_pop(self):
+        data, wim = self._fixture()
+        needle = data[-1]
+
+        needle = data.pop()
+        assert needle in wim
+        assert id(needle) in wim.by_id
+        eq_(wim[needle], wim.by_id[id(needle)])
+        eq_(len(wim), (len(data) + 1))
+
+        wim.pop(needle)
+        assert needle not in wim
+        assert id(needle) not in wim.by_id
+        eq_(len(wim), len(data))
+
+    def test_pop_default(self):
+        data, wim = self._fixture()
+        needle = data[-1]
+
+        value = wim[needle]
+        x = wim.pop(needle, 123)
+        ne_(x, 123)
+        eq_(x, value)
+        assert needle not in wim
+        assert id(needle) not in wim.by_id
+        eq_(len(data), (len(wim) + 1))
+
+        n2 = self.Data()
+        y = wim.pop(n2, 456)
+        eq_(y, 456)
+        assert n2 not in wim
+        assert id(n2) not in wim.by_id
+        eq_(len(data), (len(wim) + 1))
+
+    def test_popitem(self):
+        data, wim = self._fixture()
+        (needle, idx) = wim.popitem()
+
+        assert needle in data
+        eq_(len(data), (len(wim) + 1))
+        assert id(needle) not in wim.by_id
+
+    def test_setdefault(self):
+        data, wim = self._fixture()
+
+        o1 = self.Data()
+        oid1 = id(o1)
+
+        assert o1 not in wim
+
+        res1 = wim.setdefault(o1, 123)
+        assert o1 in wim
+        assert oid1 in wim.by_id
+        eq_(res1, 123)
+        id_keys = set(wim.by_id.keys())
+
+        res2 = wim.setdefault(o1, 456)
+        assert o1 in wim
+        assert oid1 in wim.by_id
+        eq_(res2, 123)
+        assert set(wim.by_id.keys()) == id_keys
+
+        del wim[o1]
+        assert o1 not in wim
+        assert oid1 not in wim.by_id
+        ne_(set(wim.by_id.keys()), id_keys)
+
+        res3 = wim.setdefault(o1, 789)
+        assert o1 in wim
+        assert oid1 in wim.by_id
+        eq_(res3, 789)
+        eq_(set(wim.by_id.keys()), id_keys)
+
+    def test_clear(self):
+        data, wim = self._fixture()
+
+        assert len(data) == len(wim) == len(wim.by_id)
+        wim.clear()
+
+        eq_(wim, {})
+        eq_(wim.by_id, {})
+
+    def test_update(self):
+        data, wim = self._fixture()
+        self.assertRaises(NotImplementedError, wim.update)
+
+    def test_weak_clear(self):
+        data, wim = self._fixture()
+
+        assert len(data) == len(wim) == len(wim.by_id)
+
+        del data[:]
+        eq_(wim, {})
+        eq_(wim.by_id, {})
+        eq_(wim._weakrefs, {})
+
+    def test_weak_single(self):
+        data, wim = self._fixture()
+
+        assert len(data) == len(wim) == len(wim.by_id)
+
+        oid = id(data[0])
+        del data[0]
+
+        assert len(data) == len(wim) == len(wim.by_id)
+        assert oid not in wim.by_id
+
+    def test_weak_threadhop(self):
+        data, wim = self._fixture()
+        data = set(data)
+
+        cv = threading.Condition()
+
+        def empty(obj):
+            cv.acquire()
+            obj.clear()
+            cv.notify()
+            cv.release()
+
+        th = threading.Thread(target=empty, args=(data,))
+
+        cv.acquire()
+        th.start()
+        cv.wait()
+        cv.release()
+
+        eq_(wim, {})
+        eq_(wim.by_id, {})
+        eq_(wim._weakrefs, {})
+
+
+class TestFormatArgspec(TestBase):
+    def test_specs(self):
+        def test(fn, wanted, grouped=None):
+            if grouped is None:
+                parsed = util.format_argspec_plus(fn)
+            else:
+                parsed = util.format_argspec_plus(fn, grouped=grouped)
+            eq_(parsed, wanted)
+
+        test(lambda: None,
+           {'args': '()', 'self_arg': None,
+            'apply_kw': '()', 'apply_pos': '()' })
+
+        test(lambda: None,
+           {'args': '', 'self_arg': None,
+            'apply_kw': '', 'apply_pos': '' },
+           grouped=False)
+
+        test(lambda self: None,
+           {'args': '(self)', 'self_arg': 'self',
+            'apply_kw': '(self)', 'apply_pos': '(self)' })
+
+        test(lambda self: None,
+           {'args': 'self', 'self_arg': 'self',
+            'apply_kw': 'self', 'apply_pos': 'self' },
+           grouped=False)
+
+        test(lambda *a: None,
+           {'args': '(*a)', 'self_arg': None,
+            'apply_kw': '(*a)', 'apply_pos': '(*a)' })
+
+        test(lambda **kw: None,
+           {'args': '(**kw)', 'self_arg': None,
+            'apply_kw': '(**kw)', 'apply_pos': '(**kw)' })
+
+        test(lambda *a, **kw: None,
+           {'args': '(*a, **kw)', 'self_arg': None,
+            'apply_kw': '(*a, **kw)', 'apply_pos': '(*a, **kw)' })
+
+        test(lambda a, *b: None,
+           {'args': '(a, *b)', 'self_arg': 'a',
+            'apply_kw': '(a, *b)', 'apply_pos': '(a, *b)' })
+
+        test(lambda a, **b: None,
+           {'args': '(a, **b)', 'self_arg': 'a',
+            'apply_kw': '(a, **b)', 'apply_pos': '(a, **b)' })
+
+        test(lambda a, *b, **c: None,
+           {'args': '(a, *b, **c)', 'self_arg': 'a',
+            'apply_kw': '(a, *b, **c)', 'apply_pos': '(a, *b, **c)' })
+
+        test(lambda a, b=1, **c: None,
+           {'args': '(a, b=1, **c)', 'self_arg': 'a',
+            'apply_kw': '(a, b=b, **c)', 'apply_pos': '(a, b, **c)' })
+
+        test(lambda a=1, b=2: None,
+           {'args': '(a=1, b=2)', 'self_arg': 'a',
+            'apply_kw': '(a=a, b=b)', 'apply_pos': '(a, b)' })
+
+        test(lambda a=1, b=2: None,
+           {'args': 'a=1, b=2', 'self_arg': 'a',
+            'apply_kw': 'a=a, b=b', 'apply_pos': 'a, b' },
+           grouped=False)
+
+    def test_init_grouped(self):
+        object_spec = {
+            'args': '(self)', 'self_arg': 'self',
+            'apply_pos': '(self)', 'apply_kw': '(self)'}
+        wrapper_spec = {
+            'args': '(self, *args, **kwargs)', 'self_arg': 'self',
+            'apply_pos': '(self, *args, **kwargs)',
+            'apply_kw': '(self, *args, **kwargs)'}
+        custom_spec = {
+            'args': '(slef, a=123)', 'self_arg': 'slef', # yes, slef
+            'apply_pos': '(slef, a)', 'apply_kw': '(slef, a=a)'}
+
+        self._test_init(None, object_spec, wrapper_spec, custom_spec)
+        self._test_init(True, object_spec, wrapper_spec, custom_spec)
+
+    def test_init_bare(self):
+        object_spec = {
+            'args': 'self', 'self_arg': 'self',
+            'apply_pos': 'self', 'apply_kw': 'self'}
+        wrapper_spec = {
+            'args': 'self, *args, **kwargs', 'self_arg': 'self',
+            'apply_pos': 'self, *args, **kwargs',
+            'apply_kw': 'self, *args, **kwargs'}
+        custom_spec = {
+            'args': 'slef, a=123', 'self_arg': 'slef', # yes, slef
+            'apply_pos': 'slef, a', 'apply_kw': 'slef, a=a'}
+
+        self._test_init(False, object_spec, wrapper_spec, custom_spec)
+
+    def _test_init(self, grouped, object_spec, wrapper_spec, custom_spec):
+        def test(fn, wanted):
+            if grouped is None:
+                parsed = util.format_argspec_init(fn)
+            else:
+                parsed = util.format_argspec_init(fn, grouped=grouped)
+            eq_(parsed, wanted)
+
+        class O(object): pass
+
+        test(O.__init__, object_spec)
+
+        class O(object):
+            def __init__(self):
+                pass
+
+        test(O.__init__, object_spec)
+
+        class O(object):
+            def __init__(slef, a=123):
+                pass
+
+        test(O.__init__, custom_spec)
+
+        class O(list): pass
+
+        test(O.__init__, wrapper_spec)
+
+        class O(list):
+            def __init__(self, *args, **kwargs):
+                pass
+
+        test(O.__init__, wrapper_spec)
+
+        class O(list):
+            def __init__(self):
+                pass
+
+        test(O.__init__, object_spec)
+
+        class O(list):
+            def __init__(slef, a=123):
+                pass
+
+        test(O.__init__, custom_spec)
+
 class AsInterfaceTest(TestBase):
+
     class Something(object):
         def _ignoreme(self): pass
         def foo(self): pass
@@ -442,9 +799,9 @@ class AsInterfaceTest(TestBase):
                           cls=self.Something, required=('foo'))
 
         obj = self.Something()
-        self.assertEqual(obj, util.as_interface(obj, cls=self.Something))
-        self.assertEqual(obj, util.as_interface(obj, methods=('foo',)))
-        self.assertEqual(
+        eq_(obj, util.as_interface(obj, cls=self.Something))
+        eq_(obj, util.as_interface(obj, methods=('foo',)))
+        eq_(
             obj, util.as_interface(obj, cls=self.Something,
                                    required=('outofband',)))
         partial = self.Partial()
@@ -453,12 +810,11 @@ class AsInterfaceTest(TestBase):
         slotted.bar = lambda self: 123
 
         for obj in partial, slotted:
-            self.assertEqual(obj, util.as_interface(obj, cls=self.Something))
+            eq_(obj, util.as_interface(obj, cls=self.Something))
             self.assertRaises(TypeError, util.as_interface, obj,
                               methods=('foo'))
-            self.assertEqual(obj, util.as_interface(obj, methods=('bar',)))
-            self.assertEqual(
-                obj, util.as_interface(obj, cls=self.Something,
+            eq_(obj, util.as_interface(obj, methods=('bar',)))
+            eq_(obj, util.as_interface(obj, cls=self.Something,
                                        required=('bar',)))
             self.assertRaises(TypeError, util.as_interface, obj,
                               cls=self.Something, required=('foo',))
index f929443fd42eb8177f3922d9c79b429c81521fcf..da6cc697094884679234be5042b4a5e02026fe7e 100644 (file)
@@ -1,7 +1,7 @@
 import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
 from sqlalchemy.databases import firebird
-from sqlalchemy.exceptions import ProgrammingError
+from sqlalchemy.exc import ProgrammingError
 from sqlalchemy.sql import table, column
 from testlib import *
 
index 0a35f54705ee9746be7fb4d80cac26ce8ff7cb9a..f0bcd00e139e51349c0151b8d077b5b521d46fa0 100644 (file)
@@ -3,7 +3,7 @@
 import testenv; testenv.configure_for_tests()
 import StringIO, sys
 from sqlalchemy import *
-from sqlalchemy import exceptions, sql
+from sqlalchemy import exc, sql
 from sqlalchemy.util import Decimal
 from sqlalchemy.databases import maxdb
 from testlib import *
@@ -53,7 +53,7 @@ class ReflectionTest(TestBase, AssertsExecutionResults):
         finally:
             try:
                 testing.db.execute("DROP TABLE dectest")
-            except exceptions.DatabaseError:
+            except exc.DatabaseError:
                 pass
 
     def test_decimal_fixed_serial(self):
@@ -165,7 +165,7 @@ class ReflectionTest(TestBase, AssertsExecutionResults):
         finally:
             try:
                 testing.db.execute("DROP TABLE assorted")
-            except exceptions.DatabaseError:
+            except exc.DatabaseError:
                 pass
 
 class DBAPITest(TestBase, AssertsExecutionResults):
index b5d7f1641b44080c59fc976d085454dd3f4f06ae..c3ce338df1ffa324626ed87017560f324ad99bf8 100755 (executable)
@@ -2,7 +2,7 @@ import testenv; testenv.configure_for_tests()
 import re
 from sqlalchemy import *
 from sqlalchemy.orm import *
-from sqlalchemy import exceptions
+from sqlalchemy import exc
 from sqlalchemy.sql import table, column
 from sqlalchemy.databases import mssql
 from testlib import *
@@ -210,7 +210,7 @@ class QueryTest(TestBase):
                 r = users.select(limit=3, offset=2,
                                  order_by=[users.c.user_id]).execute().fetchall()
                 assert False # InvalidRequestError should have been raised
-            except exceptions.InvalidRequestError:
+            except exc.InvalidRequestError:
                 pass
         finally:
             metadata.drop_all()
index 00478908ef8c8d46e174c27890e0ba88a49c82ff..923658b014c69282b62308cb180c8037d97eab2f 100644 (file)
@@ -1,7 +1,7 @@
 import testenv; testenv.configure_for_tests()
 import sets
 from sqlalchemy import *
-from sqlalchemy import sql, exceptions
+from sqlalchemy import sql, exc
 from sqlalchemy.databases import mysql
 from testlib import *
 
@@ -537,13 +537,13 @@ class TypesTest(TestBase, AssertsExecutionResults):
         try:
             enum_table.insert().execute(e1=None, e2=None, e3=None, e4=None)
             self.assert_(False)
-        except exceptions.SQLError:
+        except exc.SQLError:
             self.assert_(True)
 
         try:
             enum_table.insert().execute(e1='c', e2='c', e3='c', e4='c')
             self.assert_(False)
-        except exceptions.InvalidRequestError:
+        except exc.InvalidRequestError:
             self.assert_(True)
 
         enum_table.insert().execute()
index cdd575dd38abb764e8ed7bf8384596b24463bdd0..24353152d3fafb3673b8e637ae9c42a773888667 100644 (file)
@@ -120,10 +120,10 @@ AND mytable.myid = myothertable.otherid(+)",
 
         query = table1.outerjoin(table2, table1.c.myid==table2.c.otherid).outerjoin(table3, table3.c.userid==table2.c.otherid)
         self.assert_compile(query.select(), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable LEFT OUTER JOIN myothertable ON mytable.myid = myothertable.otherid LEFT OUTER JOIN thirdtable ON thirdtable.userid = myothertable.otherid")
-        self.assert_compile(query.select(), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable, myothertable, thirdtable WHERE mytable.myid = myothertable.otherid(+) AND thirdtable.userid(+) = myothertable.otherid", dialect=oracle.dialect(use_ansi=False))
+        self.assert_compile(query.select(), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable, myothertable, thirdtable WHERE thirdtable.userid(+) = myothertable.otherid AND mytable.myid = myothertable.otherid(+)", dialect=oracle.dialect(use_ansi=False))
 
         query = table1.join(table2, table1.c.myid==table2.c.otherid).join(table3, table3.c.userid==table2.c.otherid)
-        self.assert_compile(query.select(), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable, myothertable, thirdtable WHERE mytable.myid = myothertable.otherid AND thirdtable.userid = myothertable.otherid", dialect=oracle.dialect(use_ansi=False))
+        self.assert_compile(query.select(), "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable, myothertable, thirdtable WHERE thirdtable.userid = myothertable.otherid AND mytable.myid = myothertable.otherid", dialect=oracle.dialect(use_ansi=False))
 
         query = table1.join(table2, table1.c.myid==table2.c.otherid).outerjoin(table3, table3.c.userid==table2.c.otherid)
         self.assert_compile(query.select().order_by(table1.oid_column).limit(10).offset(5), "SELECT myid, name, description, otherid, othername, userid, \
@@ -131,7 +131,7 @@ otherstuff FROM (SELECT mytable.myid AS myid, mytable.name AS name, \
 mytable.description AS description, myothertable.otherid AS otherid, \
 myothertable.othername AS othername, thirdtable.userid AS userid, \
 thirdtable.otherstuff AS otherstuff, ROW_NUMBER() OVER (ORDER BY mytable.rowid) AS ora_rn \
-FROM mytable, myothertable, thirdtable WHERE mytable.myid = myothertable.otherid AND thirdtable.userid(+) = myothertable.otherid) \
+FROM mytable, myothertable, thirdtable WHERE thirdtable.userid(+) = myothertable.otherid AND mytable.myid = myothertable.otherid) \
 WHERE ora_rn>5 AND ora_rn<=15", dialect=oracle.dialect(use_ansi=False))
 
     def test_alias_outer_join(self):
index 90cc0a47742a7c5b7218ca8a74bca1236b5be4a4..3e5c200e41ea67529b0a5f3f234675800ce5fd80 100644 (file)
@@ -2,7 +2,7 @@ import testenv; testenv.configure_for_tests()
 import datetime
 from sqlalchemy import *
 from sqlalchemy.orm import *
-from sqlalchemy import exceptions
+from sqlalchemy import exc
 from sqlalchemy.databases import postgres
 from sqlalchemy.engine.strategies import MockEngineStrategy
 from testlib import *
@@ -332,12 +332,12 @@ class InsertTest(TestBase, AssertsExecutionResults):
         try:
             table.insert().execute({'data':'d2'})
             assert False
-        except exceptions.IntegrityError, e:
+        except exc.IntegrityError, e:
             assert "violates not-null constraint" in str(e)
         try:
             table.insert().execute({'data':'d2'}, {'data':'d3'})
             assert False
-        except exceptions.IntegrityError, e:
+        except exc.IntegrityError, e:
             assert "violates not-null constraint" in str(e)
 
         table.insert().execute({'id':31, 'data':'d2'}, {'id':32, 'data':'d3'})
@@ -359,12 +359,12 @@ class InsertTest(TestBase, AssertsExecutionResults):
         try:
             table.insert().execute({'data':'d2'})
             assert False
-        except exceptions.IntegrityError, e:
+        except exc.IntegrityError, e:
             assert "violates not-null constraint" in str(e)
         try:
             table.insert().execute({'data':'d2'}, {'data':'d3'})
             assert False
-        except exceptions.IntegrityError, e:
+        except exc.IntegrityError, e:
             assert "violates not-null constraint" in str(e)
 
         table.insert().execute({'id':31, 'data':'d2'}, {'id':32, 'data':'d3'})
@@ -387,7 +387,7 @@ class DomainReflectionTest(TestBase, AssertsExecutionResults):
         try:
             con.execute('CREATE DOMAIN testdomain INTEGER NOT NULL DEFAULT 42')
             con.execute('CREATE DOMAIN alt_schema.testdomain INTEGER DEFAULT 0')
-        except exceptions.SQLError, e:
+        except exc.SQLError, e:
             if not "already exists" in str(e):
                 raise e
         con.execute('CREATE TABLE testtable (question integer, answer testdomain)')
index 585a853d2a4fa5376be1e1d9fd1c3c6b12fb2389..4cde5fc3312b986dc2312af0eada46250280823f 100644 (file)
@@ -3,7 +3,7 @@
 import testenv; testenv.configure_for_tests()
 import datetime
 from sqlalchemy import *
-from sqlalchemy import exceptions
+from sqlalchemy import exc
 from sqlalchemy.databases import sqlite
 from testlib import *
 
@@ -34,11 +34,11 @@ class TestTypes(TestBase, AssertsExecutionResults):
     @testing.uses_deprecated('Using String type with no length')
     def test_type_reflection(self):
         # (ask_for, roundtripped_as_if_different)
-        specs = [( String(), sqlite.SLText(), ),
+        specs = [( String(), sqlite.SLString(), ),
                  ( String(1), sqlite.SLString(1), ),
                  ( String(3), sqlite.SLString(3), ),
                  ( Text(), sqlite.SLText(), ),
-                 ( Unicode(), sqlite.SLText(), ),
+                 ( Unicode(), sqlite.SLString(), ),
                  ( Unicode(1), sqlite.SLString(1), ),
                  ( Unicode(3), sqlite.SLString(3), ),
                  ( UnicodeText(), sqlite.SLText(), ),
@@ -94,7 +94,7 @@ class TestTypes(TestBase, AssertsExecutionResults):
                 for table in rt, rv:
                     for i, reflected in enumerate(table.c):
                         print reflected.type, type(expected[i])
-                        assert isinstance(reflected.type, type(expected[i]))
+                        assert isinstance(reflected.type, type(expected[i])), type(expected[i])
             finally:
                 db.execute('DROP VIEW types_v')
         finally:
@@ -212,7 +212,7 @@ class DialectTest(TestBase, AssertsExecutionResults):
         except:
             try:
                 cx.execute('DROP TABLE tempy')
-            except exceptions.DBAPIError:
+            except exc.DBAPIError:
                 pass
             raise
 
@@ -247,7 +247,7 @@ class InsertTest(TestBase, AssertsExecutionResults):
     @testing.exclude('sqlite', '<', (3, 4))
     def test_empty_insert_pk2(self):
         self.assertRaises(
-            exceptions.DBAPIError,
+            exc.DBAPIError,
             self._test_empty_insert,
             Table('b', MetaData(testing.db),
                   Column('x', Integer, primary_key=True),
@@ -256,7 +256,7 @@ class InsertTest(TestBase, AssertsExecutionResults):
     @testing.exclude('sqlite', '<', (3, 4))
     def test_empty_insert_pk3(self):
         self.assertRaises(
-            exceptions.DBAPIError,
+            exc.DBAPIError,
             self._test_empty_insert,
             Table('c', MetaData(testing.db),
                   Column('x', Integer, primary_key=True),
index b59cd284a19f94cba7ee30ba4779e9f6407ec18d..300a4eae6e845c2309190c8db3f4119c7be80b24 100644 (file)
@@ -2,9 +2,10 @@
 including the deprecated versions of these arguments"""
 
 import testenv; testenv.configure_for_tests()
-from sqlalchemy import *
-from sqlalchemy import engine, exceptions
-from testlib import *
+from sqlalchemy import engine, exc
+from sqlalchemy import MetaData, ThreadLocalMetaData
+from testlib.sa import Table, Column, Integer, String, func, Sequence, text
+from testlib import TestBase, testing
 
 
 class BindTest(TestBase):
@@ -41,7 +42,7 @@ class BindTest(TestBase):
             try:
                 meth()
                 assert False
-            except exceptions.UnboundExecutionError, e:
+            except exc.UnboundExecutionError, e:
                 self.assertEquals(
                     str(e),
                     "The MetaData "
@@ -59,7 +60,7 @@ class BindTest(TestBase):
             try:
                 meth()
                 assert False
-            except exceptions.UnboundExecutionError, e:
+            except exc.UnboundExecutionError, e:
                 self.assertEquals(
                     str(e),
                     "The Table 'test_table' "
@@ -71,6 +72,10 @@ class BindTest(TestBase):
 
     @testing.future
     def test_create_drop_err2(self):
+        metadata = MetaData()
+        table = Table('test_table', metadata,
+            Column('foo', Integer))
+
         for meth in [
             table.exists,
             table.create,
@@ -79,7 +84,7 @@ class BindTest(TestBase):
             try:
                 meth()
                 assert False
-            except exceptions.UnboundExecutionError, e:
+            except exc.UnboundExecutionError, e:
                 self.assertEquals(
                     str(e),
                     "The Table 'test_table' "
@@ -201,7 +206,7 @@ class BindTest(TestBase):
                     assert e.bind is None
                     e.execute()
                     assert False
-                except exceptions.UnboundExecutionError, e:
+                except exc.UnboundExecutionError, e:
                     assert str(e).endswith(
                         'is not bound and does not support direct '
                         'execution. Supply this statement to a Connection or '
@@ -248,7 +253,7 @@ class BindTest(TestBase):
             try:
                 sess.flush()
                 assert False
-            except exceptions.InvalidRequestError, e:
+            except exc.InvalidRequestError, e:
                 assert str(e).startswith("Could not locate any Engine or Connection bound to mapper")
         finally:
             if isinstance(bind, engine.Connection):
index 258c6141206d310d412801862cd3be545536193d..117ee1219bbf7755a510f652b6bdfec1937b1d4f 100644 (file)
@@ -1,9 +1,9 @@
 import testenv; testenv.configure_for_tests()
-from sqlalchemy import *
-from sqlalchemy import exceptions
 from sqlalchemy.schema import DDL
-import sqlalchemy
-from testlib import *
+from sqlalchemy import create_engine
+from testlib.sa import MetaData, Table, Column, Integer, String
+import testlib.sa as tsa
+from testlib import TestBase, testing
 
 
 class DDLEventTest(TestBase):
@@ -294,7 +294,7 @@ class DDLExecutionTest(TestBase):
             try:
                 r = eval(py)
                 assert False
-            except exceptions.UnboundExecutionError:
+            except tsa.exc.UnboundExecutionError:
                 pass
 
         for bind in engine, cx:
@@ -310,7 +310,7 @@ class DDLTest(TestBase):
         engine = create_engine(testing.db.name + '://',
                                strategy='mock', executor=executor)
         engine.dialect.identifier_preparer = \
-           sqlalchemy.sql.compiler.IdentifierPreparer(engine.dialect)
+           tsa.sql.compiler.IdentifierPreparer(engine.dialect)
         return engine
 
     def test_tokens(self):
@@ -324,7 +324,7 @@ class DDLTest(TestBase):
         ddl = DDL('%(schema)s-%(table)s-%(fullname)s')
 
         self.assertEquals(ddl._expand(sane_alone, bind), '-t-t')
-        self.assertEquals(ddl._expand(sane_schema, bind), '"s"-t-s.t')
+        self.assertEquals(ddl._expand(sane_schema, bind), 's-t-s.t')
         self.assertEquals(ddl._expand(insane_alone, bind), '-"t t"-"t t"')
         self.assertEquals(ddl._expand(insane_schema, bind),
                           '"s s"-"t t"-"s s"."t t"')
index 260a05e270d85730d9a8647d71969811be169533..36a6bc3179a43708410a30762422b4f3896f152a 100644 (file)
@@ -1,8 +1,13 @@
 import testenv; testenv.configure_for_tests()
-from sqlalchemy import *
-from sqlalchemy import exceptions
-from testlib import *
+import re
+from sqlalchemy.interfaces import ConnectionProxy
+from testlib.sa import MetaData, Table, Column, Integer, String, INT, \
+     VARCHAR, func
+import testlib.sa as tsa
+from testlib import TestBase, testing, engines
 
+
+users, metadata = None, None
 class ExecuteTest(TestBase):
     def setUpAll(self):
         global users, metadata
@@ -70,8 +75,85 @@ class ExecuteTest(TestBase):
             try:
                 conn.execute("osdjafioajwoejoasfjdoifjowejfoawejqoijwef")
                 assert False
-            except exceptions.DBAPIError:
+            except tsa.exc.DBAPIError:
                 assert True
 
+class ProxyConnectionTest(TestBase):
+    def test_proxy(self):
+        
+        stmts = []
+        cursor_stmts = []
+        
+        class MyProxy(ConnectionProxy):
+            def execute(self, conn, execute, clauseelement, *multiparams, **params):
+                stmts.append(
+                    (str(clauseelement), params,multiparams)
+                )
+                return execute(clauseelement, *multiparams, **params)
+
+            def cursor_execute(self, execute, cursor, statement, parameters, context, executemany):
+                cursor_stmts.append(
+                    (statement, parameters, None)
+                )
+                return execute(cursor, statement, parameters, context)
+        
+        def assert_stmts(expected, received):
+            for stmt, params, posn in expected:
+                if not received:
+                    assert False
+                while received:
+                    teststmt, testparams, testmultiparams = received.pop(0)
+                    teststmt = re.compile(r'[\n\t ]+', re.M).sub(' ', teststmt).strip()
+                    if teststmt.startswith(stmt) and (testparams==params or testparams==posn):
+                        break
+
+        for engine in (
+            engines.testing_engine(options=dict(proxy=MyProxy())),
+            engines.testing_engine(options=dict(proxy=MyProxy(), strategy='threadlocal'))
+        ):
+            m = MetaData(engine)
+
+            t1 = Table('t1', m, Column('c1', Integer, primary_key=True), Column('c2', String(50), default=func.lower('Foo'), primary_key=True))
+
+            m.create_all()
+            try:
+                t1.insert().execute(c1=5, c2='some data')
+                t1.insert().execute(c1=6)
+                assert engine.execute("select * from t1").fetchall() == [(5, 'some data'), (6, 'foo')]
+            finally:
+                m.drop_all()
+            
+            engine.dispose()
+            
+            compiled = [
+                ("CREATE TABLE t1", {}, None),
+                ("INSERT INTO t1 (c1, c2)", {'c2': 'some data', 'c1': 5}, None),
+                ("INSERT INTO t1 (c1, c2)", {'c1': 6}, None),
+                ("select * from t1", {}, None),
+                ("DROP TABLE t1", {}, None)
+            ]
+
+            if engine.dialect.preexecute_pk_sequences:
+                cursor = [
+                    ("CREATE TABLE t1", {}, None),
+                    ("INSERT INTO t1 (c1, c2)", {'c2': 'some data', 'c1': 5}, [5, 'some data']),
+                    ("SELECT lower", {'lower_2':'Foo'}, ['Foo']),
+                    ("INSERT INTO t1 (c1, c2)", {'c2': 'foo', 'c1': 6}, [6, 'foo']),
+                    ("select * from t1", {}, None),
+                    ("DROP TABLE t1", {}, None)
+                ]
+            else:
+                cursor = [
+                    ("CREATE TABLE t1", {}, None),
+                    ("INSERT INTO t1 (c1, c2)", {'c2': 'some data', 'c1': 5}, [5, 'some data']),
+                    ("INSERT INTO t1 (c1, c2)", {'c1': 6, "lower_2":"Foo"}, [6, "Foo"]),  # bind param name 'lower_2' might be incorrect
+                    ("select * from t1", {}, None),
+                    ("DROP TABLE t1", {}, None)
+                ]
+                
+            assert_stmts(compiled, stmts)
+            assert_stmts(cursor, cursor_stmts)
+    
+
 if __name__ == "__main__":
     testenv.main()
index 22cdaafee4c41b1f1003ac6059e934d7e1affa11..90f8a00a855322a7fade38c2a90ddf12b9cda05c 100644 (file)
@@ -1,8 +1,11 @@
 import testenv; testenv.configure_for_tests()
-from sqlalchemy import *
-from sqlalchemy import exceptions
-from testlib import *
 import pickle
+from sqlalchemy import MetaData
+from testlib.sa import Table, Column, Integer, String, UniqueConstraint, \
+     CheckConstraint, ForeignKey
+import testlib.sa as tsa
+from testlib import TestBase, ComparesTables, testing
+
 
 class MetaDataTest(TestBase, ComparesTables):
     def test_metadata_connect(self):
@@ -30,7 +33,7 @@ class MetaDataTest(TestBase, ComparesTables):
                 t2 = Table('table1', metadata, Column('col1', Integer, primary_key=True),
                     Column('col2', String(20)))
                 assert False
-            except exceptions.InvalidRequestError, e:
+            except tsa.exc.InvalidRequestError, e:
                 assert str(e) == "Table 'table1' is already defined for this MetaData instance.  Specify 'useexisting=True' to redefine options and columns on an existing Table object."
         finally:
             metadata.drop_all()
@@ -109,7 +112,7 @@ class MetaDataTest(TestBase, ComparesTables):
             meta.drop_all(testing.db)
 
     def test_nonexistent(self):
-        self.assertRaises(exceptions.NoSuchTableError, Table,
+        self.assertRaises(tsa.exc.NoSuchTableError, Table,
                           'fake_table',
                           MetaData(testing.db), autoload=True)
 
index 117c3ed4bb3ebc71d7933bf97b1016884cdbd3e1..1f7d09c9df7aa77838b6d269469ce547194fc616 100644 (file)
@@ -1,9 +1,9 @@
 import testenv; testenv.configure_for_tests()
 import ConfigParser, StringIO
-from sqlalchemy import *
-from sqlalchemy import exceptions, pool, engine
 import sqlalchemy.engine.url as url
-from testlib import *
+from sqlalchemy import create_engine, engine_from_config
+import testlib.sa as tsa
+from testlib import TestBase
 
 
 class ParseConnectTest(TestBase):
@@ -92,10 +92,10 @@ pool_timeout=10
             }
 
         prefixed = dict(ini.items('prefixed'))
-        self.assert_(engine._coerce_config(prefixed, 'sqlalchemy.') == expected)
+        self.assert_(tsa.engine._coerce_config(prefixed, 'sqlalchemy.') == expected)
 
         plain = dict(ini.items('plain'))
-        self.assert_(engine._coerce_config(plain, '') == expected)
+        self.assert_(tsa.engine._coerce_config(plain, '') == expected)
 
     def test_engine_from_config(self):
         dbapi = MockDBAPI()
@@ -181,7 +181,7 @@ pool_timeout=10
         try:
             c = e.connect()
             assert False
-        except exceptions.DBAPIError:
+        except tsa.exc.DBAPIError:
             assert True
 
     def test_urlattr(self):
@@ -200,11 +200,11 @@ pool_timeout=10
         assert e.pool._recycle == 50
 
         # these args work for QueuePool
-        e = create_engine('postgres://', max_overflow=8, pool_timeout=60, poolclass=pool.QueuePool, module=MockDBAPI())
+        e = create_engine('postgres://', max_overflow=8, pool_timeout=60, poolclass=tsa.pool.QueuePool, module=MockDBAPI())
 
         try:
             # but not SingletonThreadPool
-            e = create_engine('sqlite://', max_overflow=8, pool_timeout=60, poolclass=pool.SingletonThreadPool)
+            e = create_engine('sqlite://', max_overflow=8, pool_timeout=60, poolclass=tsa.pool.SingletonThreadPool)
             assert False
         except TypeError:
             assert True
index 75cb08e3c880e4d29ce187a0f1b5e258a7096d45..f2b74a45a780dfe3650256a99f9f28115fa7e140 100644 (file)
@@ -1,9 +1,8 @@
 import testenv; testenv.configure_for_tests()
-import threading, thread, time, gc
-import sqlalchemy.pool as pool
-import sqlalchemy.interfaces as interfaces
-import sqlalchemy.exceptions as exceptions
-from testlib import *
+import threading, time, gc
+from sqlalchemy import pool
+import testlib.sa as tsa
+from testlib import TestBase
 
 
 mcid = 1
@@ -127,7 +126,7 @@ class PoolTest(TestBase):
         try:
             c4 = p.connect()
             assert False
-        except exceptions.TimeoutError, e:
+        except tsa.exc.TimeoutError, e:
             assert int(time.time() - now) == 2
 
     def test_timeout_race(self):
@@ -145,7 +144,7 @@ class PoolTest(TestBase):
                 now = time.time()
                 try:
                     c1 = p.connect()
-                except exceptions.TimeoutError, e:
+                except tsa.exc.TimeoutError, e:
                     timeouts.append(int(time.time()) - now)
                     continue
                 time.sleep(4)
@@ -181,7 +180,7 @@ class PoolTest(TestBase):
                     peaks.append(p.overflow())
                     con.close()
                     del con
-                except exceptions.TimeoutError:
+                except tsa.exc.TimeoutError:
                     pass
         threads = []
         for i in xrange(thread_count):
@@ -444,7 +443,7 @@ class PoolTest(TestBase):
                 # con can be None if invalidated
                 assert record is not None
                 self.checked_in.append(con)
-        class ListenAll(interfaces.PoolListener, InstrumentingListener):
+        class ListenAll(tsa.interfaces.PoolListener, InstrumentingListener):
             pass
         class ListenConnect(InstrumentingListener):
             def connect(self, con, record):
index d0d037a3407526c8cab7832c7be6957c85f41a00..1539d80e0df6ec01d0db15ac4a87fd925d8f92ae 100644 (file)
@@ -1,7 +1,8 @@
 import testenv; testenv.configure_for_tests()
-import sys, weakref
-from sqlalchemy import create_engine, exceptions, select, MetaData, Table, Column, Integer, String
-from testlib import *
+import weakref
+from testlib.sa import select, MetaData, Table, Column, Integer, String
+import testlib.sa as tsa
+from testlib import TestBase, testing, engines
 
 
 class MockDisconnect(Exception):
@@ -43,13 +44,14 @@ class MockCursor(object):
     def close(self):
         pass
 
+db, dbapi = None, None
 class MockReconnectTest(TestBase):
     def setUp(self):
         global db, dbapi
         dbapi = MockDBAPI()
 
         # create engine using our current dburi
-        db = create_engine('postgres://foo:bar@localhost/test', module=dbapi)
+        db = tsa.create_engine('postgres://foo:bar@localhost/test', module=dbapi)
 
         # monkeypatch disconnect checker
         db.dialect.is_disconnect = lambda e: isinstance(e, MockDisconnect)
@@ -80,7 +82,7 @@ class MockReconnectTest(TestBase):
         try:
             conn.execute(select([1]))
             assert False
-        except exceptions.DBAPIError:
+        except tsa.exc.DBAPIError:
             pass
 
         # assert was invalidated
@@ -108,7 +110,7 @@ class MockReconnectTest(TestBase):
         try:
             conn.execute(select([1]))
             assert False
-        except exceptions.DBAPIError:
+        except tsa.exc.DBAPIError:
             pass
 
         # assert was invalidated
@@ -120,7 +122,7 @@ class MockReconnectTest(TestBase):
         try:
             conn.execute(select([1]))
             assert False
-        except exceptions.InvalidRequestError, e:
+        except tsa.exc.InvalidRequestError, e:
             assert str(e) == "Can't reconnect until invalid transaction is rolled back"
 
         assert trans.is_active
@@ -128,7 +130,7 @@ class MockReconnectTest(TestBase):
         try:
             trans.commit()
             assert False
-        except exceptions.InvalidRequestError, e:
+        except tsa.exc.InvalidRequestError, e:
             assert str(e) == "Can't reconnect until invalid transaction is rolled back"
 
         assert trans.is_active
@@ -154,7 +156,7 @@ class MockReconnectTest(TestBase):
         try:
             conn.execute(select([1]))
             assert False
-        except exceptions.DBAPIError:
+        except tsa.exc.DBAPIError:
             pass
 
         assert not conn.closed
@@ -168,7 +170,7 @@ class MockReconnectTest(TestBase):
         assert not conn.invalidated
         assert len(dbapi.connections) == 1
 
-
+engine = None
 class RealReconnectTest(TestBase):
     def setUp(self):
         global engine
@@ -188,7 +190,7 @@ class RealReconnectTest(TestBase):
         try:
             conn.execute(select([1]))
             assert False
-        except exceptions.DBAPIError, e:
+        except tsa.exc.DBAPIError, e:
             if not e.connection_invalidated:
                 raise
 
@@ -204,7 +206,7 @@ class RealReconnectTest(TestBase):
         try:
             conn.execute(select([1]))
             assert False
-        except exceptions.DBAPIError, e:
+        except tsa.exc.DBAPIError, e:
             if not e.connection_invalidated:
                 raise
         assert conn.invalidated
@@ -212,7 +214,7 @@ class RealReconnectTest(TestBase):
         assert not conn.invalidated
 
         conn.close()
-    
+
     def test_close(self):
         conn = engine.connect()
         self.assertEquals(conn.execute(select([1])).scalar(), 1)
@@ -223,7 +225,7 @@ class RealReconnectTest(TestBase):
         try:
             conn.execute(select([1]))
             assert False
-        except exceptions.DBAPIError, e:
+        except tsa.exc.DBAPIError, e:
             if not e.connection_invalidated:
                 raise
 
@@ -244,7 +246,7 @@ class RealReconnectTest(TestBase):
         try:
             conn.execute(select([1]))
             assert False
-        except exceptions.DBAPIError, e:
+        except tsa.exc.DBAPIError, e:
             if not e.connection_invalidated:
                 raise
 
@@ -255,7 +257,7 @@ class RealReconnectTest(TestBase):
         try:
             conn.execute(select([1]))
             assert False
-        except exceptions.InvalidRequestError, e:
+        except tsa.exc.InvalidRequestError, e:
             assert str(e) == "Can't reconnect until invalid transaction is rolled back"
 
         assert trans.is_active
@@ -263,7 +265,7 @@ class RealReconnectTest(TestBase):
         try:
             trans.commit()
             assert False
-        except exceptions.InvalidRequestError, e:
+        except tsa.exc.InvalidRequestError, e:
             assert str(e) == "Can't reconnect until invalid transaction is rolled back"
 
         assert trans.is_active
@@ -275,6 +277,7 @@ class RealReconnectTest(TestBase):
         self.assertEquals(conn.execute(select([1])).scalar(), 1)
         assert not conn.invalidated
 
+meta, table, engine = None, None, None
 class InvalidateDuringResultTest(TestBase):
     def setUp(self):
         global meta, table, engine
@@ -287,28 +290,28 @@ class InvalidateDuringResultTest(TestBase):
         table.insert().execute(
             [{'id':i, 'name':'row %d' % i} for i in range(1, 100)]
         )
-        
+
     def tearDown(self):
         meta.drop_all()
         engine.dispose()
-    
-    @testing.fails_on('mysql')    
+
+    @testing.fails_on('mysql')
     def test_invalidate_on_results(self):
         conn = engine.connect()
-        
+
         result = conn.execute("select * from sometable")
         for x in xrange(20):
             result.fetchone()
-        
+
         engine.test_shutdown()
         try:
             result.fetchone()
             assert False
-        except exceptions.DBAPIError, e:
+        except tsa.exc.DBAPIError, e:
             if not e.connection_invalidated:
                 raise
 
         assert conn.invalidated
-        
+
 if __name__ == '__main__':
     testenv.main()
index 2ace3306a25b1256c8519a679d850bc7f5e88763..64c8489ed728e0cd3848ea62bfd1dfa1f82012ee 100644 (file)
@@ -1,12 +1,13 @@
 import testenv; testenv.configure_for_tests()
 import StringIO, unicodedata
-from sqlalchemy import *
-from sqlalchemy import exceptions
-from sqlalchemy import types as sqltypes
-from testlib import *
-from testlib import engines
+import sqlalchemy as sa
+from testlib.sa import MetaData, Table, Column
+from testlib import TestBase, ComparesTables, testing, engines, sa as tsa
+from testlib.compat import set
 
 
+metadata, users = None, None
+
 class ReflectionTest(TestBase, ComparesTables):
 
     @testing.exclude('mysql', '<', (4, 1, 1))
@@ -14,35 +15,38 @@ class ReflectionTest(TestBase, ComparesTables):
         meta = MetaData(testing.db)
 
         users = Table('engine_users', meta,
-            Column('user_id', INT, primary_key=True),
-            Column('user_name', VARCHAR(20), nullable=False),
-            Column('test1', CHAR(5), nullable=False),
-            Column('test2', Float(5), nullable=False),
-            Column('test3', Text),
-            Column('test4', Numeric, nullable = False),
-            Column('test5', DateTime),
-            Column('parent_user_id', Integer, ForeignKey('engine_users.user_id')),
-            Column('test6', DateTime, nullable=False),
-            Column('test7', Text),
-            Column('test8', Binary),
-            Column('test_passivedefault2', Integer, PassiveDefault("5")),
-            Column('test9', Binary(100)),
-            Column('test_numeric', Numeric()),
+            Column('user_id', sa.INT, primary_key=True),
+            Column('user_name', sa.VARCHAR(20), nullable=False),
+            Column('test1', sa.CHAR(5), nullable=False),
+            Column('test2', sa.Float(5), nullable=False),
+            Column('test3', sa.Text),
+            Column('test4', sa.Numeric, nullable = False),
+            Column('test5', sa.DateTime),
+            Column('parent_user_id', sa.Integer,
+                   sa.ForeignKey('engine_users.user_id')),
+            Column('test6', sa.DateTime, nullable=False),
+            Column('test7', sa.Text),
+            Column('test8', sa.Binary),
+            Column('test_passivedefault2', sa.Integer, sa.PassiveDefault("5")),
+            Column('test9', sa.Binary(100)),
+            Column('test_numeric', sa.Numeric()),
             test_needs_fk=True,
         )
 
         addresses = Table('engine_email_addresses', meta,
-            Column('address_id', Integer, primary_key = True),
-            Column('remote_user_id', Integer, ForeignKey(users.c.user_id)),
-            Column('email_address', String(20)),
+            Column('address_id', sa.Integer, primary_key = True),
+            Column('remote_user_id', sa.Integer, sa.ForeignKey(users.c.user_id)),
+            Column('email_address', sa.String(20)),
             test_needs_fk=True,
         )
         meta.create_all()
 
         try:
             meta2 = MetaData()
-            reflected_users = Table('engine_users', meta2, autoload=True, autoload_with=testing.db)
-            reflected_addresses = Table('engine_email_addresses', meta2, autoload=True, autoload_with=testing.db)
+            reflected_users = Table('engine_users', meta2, autoload=True,
+                                    autoload_with=testing.db)
+            reflected_addresses = Table('engine_email_addresses', meta2,
+                                        autoload=True, autoload_with=testing.db)
             self.assert_tables_equal(users, reflected_users)
             self.assert_tables_equal(addresses, reflected_addresses)
         finally:
@@ -51,22 +55,25 @@ class ReflectionTest(TestBase, ComparesTables):
 
     def test_include_columns(self):
         meta = MetaData(testing.db)
-        foo = Table('foo', meta, *[Column(n, String(30)) for n in ['a', 'b', 'c', 'd', 'e', 'f']])
+        foo = Table('foo', meta, *[Column(n, sa.String(30))
+                                   for n in ['a', 'b', 'c', 'd', 'e', 'f']])
         meta.create_all()
         try:
             meta2 = MetaData(testing.db)
-            foo = Table('foo', meta2, autoload=True, include_columns=['b', 'f', 'e'])
+            foo = Table('foo', meta2, autoload=True,
+                        include_columns=['b', 'f', 'e'])
             # test that cols come back in original order
             self.assertEquals([c.name for c in foo.c], ['b', 'e', 'f'])
             for c in ('b', 'f', 'e'):
                 assert c in foo.c
             for c in ('a', 'c', 'd'):
                 assert c not in foo.c
-                
+
             # test against a table which is already reflected
             meta3 = MetaData(testing.db)
             foo = Table('foo', meta3, autoload=True)
-            foo = Table('foo', meta3, include_columns=['b', 'f', 'e'], useexisting=True)
+            foo = Table('foo', meta3, include_columns=['b', 'f', 'e'],
+                        useexisting=True)
             self.assertEquals([c.name for c in foo.c], ['b', 'e', 'f'])
             for c in ('b', 'f', 'e'):
                 assert c in foo.c
@@ -79,7 +86,7 @@ class ReflectionTest(TestBase, ComparesTables):
     def test_unknown_types(self):
         meta = MetaData(testing.db)
         t = Table("test", meta,
-            Column('foo', DateTime))
+            Column('foo', sa.DateTime))
 
         import sys
         dialect_module = sys.modules[testing.db.dialect.__module__]
@@ -100,14 +107,14 @@ class ReflectionTest(TestBase, ComparesTables):
                 m2 = MetaData(testing.db)
                 t2 = Table("test", m2, autoload=True)
                 assert False
-            except exceptions.SAWarning:
+            except tsa.exc.SAWarning:
                 assert True
 
             @testing.emits_warning('Did not recognize type')
             def warns():
                 m3 = MetaData(testing.db)
                 t3 = Table("test", m3, autoload=True)
-                assert t3.c.foo.type.__class__ == sqltypes.NullType
+                assert t3.c.foo.type.__class__ == sa.types.NullType
 
         finally:
             dialect_module.ischema_names = ischema_names
@@ -117,9 +124,9 @@ class ReflectionTest(TestBase, ComparesTables):
         meta = MetaData(testing.db)
         table = Table(
             'override_test', meta,
-            Column('col1', Integer, primary_key=True),
-            Column('col2', String(20)),
-            Column('col3', Numeric)
+            Column('col1', sa.Integer, primary_key=True),
+            Column('col2', sa.String(20)),
+            Column('col3', sa.Numeric)
         )
         table.create()
 
@@ -127,12 +134,12 @@ class ReflectionTest(TestBase, ComparesTables):
         try:
             table = Table(
                 'override_test', meta2,
-                Column('col2', Unicode()),
-                Column('col4', String(30)), autoload=True)
+                Column('col2', sa.Unicode()),
+                Column('col4', sa.String(30)), autoload=True)
 
-            self.assert_(isinstance(table.c.col1.type, Integer))
-            self.assert_(isinstance(table.c.col2.type, Unicode))
-            self.assert_(isinstance(table.c.col4.type, String))
+            self.assert_(isinstance(table.c.col1.type, sa.Integer))
+            self.assert_(isinstance(table.c.col2.type, sa.Unicode))
+            self.assert_(isinstance(table.c.col4.type, sa.String))
         finally:
             table.drop()
 
@@ -142,18 +149,19 @@ class ReflectionTest(TestBase, ComparesTables):
 
         meta = MetaData(testing.db)
         users = Table('users', meta,
-            Column('id', Integer, primary_key=True),
-            Column('name', String(30)))
+            Column('id', sa.Integer, primary_key=True),
+            Column('name', sa.String(30)))
         addresses = Table('addresses', meta,
-            Column('id', Integer, primary_key=True),
-            Column('street', String(30)))
+            Column('id', sa.Integer, primary_key=True),
+            Column('street', sa.String(30)))
 
 
         meta.create_all()
         try:
             meta2 = MetaData(testing.db)
             a2 = Table('addresses', meta2,
-                Column('id', Integer, ForeignKey('users.id'), primary_key=True),
+                Column('id', sa.Integer,
+                       sa.ForeignKey('users.id'), primary_key=True),
                 autoload=True)
             u2 = Table('users', meta2, autoload=True)
 
@@ -164,7 +172,8 @@ class ReflectionTest(TestBase, ComparesTables):
             meta3 = MetaData(testing.db)
             u3 = Table('users', meta3, autoload=True)
             a3 = Table('addresses', meta3,
-                Column('id', Integer, ForeignKey('users.id'), primary_key=True),
+                Column('id', sa.Integer, sa.ForeignKey('users.id'),
+                       primary_key=True),
                 autoload=True)
 
             assert list(a3.primary_key) == [a3.c.id]
@@ -180,18 +189,18 @@ class ReflectionTest(TestBase, ComparesTables):
 
         meta = MetaData(testing.db)
         users = Table('users', meta,
-            Column('id', Integer, primary_key=True),
-            Column('name', String(30)))
+            Column('id', sa.Integer, primary_key=True),
+            Column('name', sa.String(30)))
         addresses = Table('addresses', meta,
-            Column('id', Integer, primary_key=True),
-            Column('street', String(30)),
-            Column('user_id', Integer))
+            Column('id', sa.Integer, primary_key=True),
+            Column('street', sa.String(30)),
+            Column('user_id', sa.Integer))
 
         meta.create_all()
         try:
             meta2 = MetaData(testing.db)
             a2 = Table('addresses', meta2,
-                Column('user_id', Integer, ForeignKey('users.id')),
+                Column('user_id', sa.Integer, sa.ForeignKey('users.id')),
                 autoload=True)
             u2 = Table('users', meta2, autoload=True)
 
@@ -205,19 +214,19 @@ class ReflectionTest(TestBase, ComparesTables):
             meta3 = MetaData(testing.db)
             u3 = Table('users', meta3, autoload=True)
             a3 = Table('addresses', meta3,
-                Column('user_id', Integer, ForeignKey('users.id')),
+                Column('user_id', sa.Integer, sa.ForeignKey('users.id')),
                 autoload=True)
 
             assert u3.join(a3).onclause == u3.c.id==a3.c.user_id
 
             meta4 = MetaData(testing.db)
             u4 = Table('users', meta4,
-                       Column('id', Integer, key='u_id', primary_key=True),
+                       Column('id', sa.Integer, key='u_id', primary_key=True),
                        autoload=True)
             a4 = Table('addresses', meta4,
-                       Column('id', Integer, key='street', primary_key=True),
-                       Column('street', String(30), key='user_id'),
-                       Column('user_id', Integer, ForeignKey('users.u_id'),
+                       Column('id', sa.Integer, key='street', primary_key=True),
+                       Column('street', sa.String(30), key='user_id'),
+                       Column('user_id', sa.Integer, sa.ForeignKey('users.u_id'),
                               key='id'),
                        autoload=True)
 
@@ -237,19 +246,19 @@ class ReflectionTest(TestBase, ComparesTables):
 
         meta = MetaData(testing.db)
         users = Table('users', meta,
-            Column('id', Integer, primary_key=True),
-            Column('name', String(30)),
+            Column('id', sa.Integer, primary_key=True),
+            Column('name', sa.String(30)),
             test_needs_fk=True)
         addresses = Table('addresses', meta,
-            Column('id', Integer,primary_key=True),
-            Column('user_id', Integer, ForeignKey('users.id')),
+            Column('id', sa.Integer, primary_key=True),
+            Column('user_id', sa.Integer, sa.ForeignKey('users.id')),
             test_needs_fk=True)
 
         meta.create_all()
         try:
             meta2 = MetaData(testing.db)
             a2 = Table('addresses', meta2,
-                Column('user_id',Integer, ForeignKey('users.id')),
+                Column('user_id', sa.Integer, sa.ForeignKey('users.id')),
                 autoload=True)
             u2 = Table('users', meta2, autoload=True)
 
@@ -263,11 +272,11 @@ class ReflectionTest(TestBase, ComparesTables):
 
             meta2 = MetaData(testing.db)
             u2 = Table('users', meta2, 
-                Column('id', Integer, primary_key=True),
+                Column('id', sa.Integer, primary_key=True),
                 autoload=True)
             a2 = Table('addresses', meta2,
-                Column('id', Integer, primary_key=True),
-                Column('user_id',Integer, ForeignKey('users.id')),
+                Column('id', sa.Integer, primary_key=True),
+                Column('user_id', sa.Integer, sa.ForeignKey('users.id')),
                 autoload=True)
 
             assert len(a2.foreign_keys) == 1
@@ -279,31 +288,31 @@ class ReflectionTest(TestBase, ComparesTables):
             assert u2.join(a2).onclause == u2.c.id==a2.c.user_id
         finally:
             meta.drop_all()
-    
+
     def test_use_existing(self):
         meta = MetaData(testing.db)
         users = Table('users', meta,
-            Column('id', Integer, primary_key=True),
-            Column('name', String(30)),
+            Column('id', sa.Integer, primary_key=True),
+            Column('name', sa.String(30)),
             test_needs_fk=True)
         addresses = Table('addresses', meta,
-            Column('id', Integer,primary_key=True),
-            Column('user_id', Integer, ForeignKey('users.id')),
-            Column('data', String(100)),
+            Column('id', sa.Integer,primary_key=True),
+            Column('user_id', sa.Integer, sa.ForeignKey('users.id')),
+            Column('data', sa.String(100)),
             test_needs_fk=True)
 
         meta.create_all()
         try:
             meta2 = MetaData(testing.db)
-            addresses = Table('addresses', meta2, Column('data', Unicode), autoload=True)
+            addresses = Table('addresses', meta2, Column('data', sa.Unicode), autoload=True)
             try:
-                users = Table('users', meta2, Column('name', Unicode), autoload=True)
+                users = Table('users', meta2, Column('name', sa.Unicode), autoload=True)
                 assert False
-            except exceptions.InvalidRequestError, err:
+            except tsa.exc.InvalidRequestError, err:
                 assert str(err) == "Table 'users' is already defined for this MetaData instance.  Specify 'useexisting=True' to redefine options and columns on an existing Table object."
             
-            users = Table('users', meta2, Column('name', Unicode), autoload=True, useexisting=True)
-            assert isinstance(users.c.name.type, Unicode)
+            users = Table('users', meta2, Column('name', sa.Unicode), autoload=True, useexisting=True)
+            assert isinstance(users.c.name.type, sa.Unicode)
 
             assert not users.quote
             
@@ -328,8 +337,8 @@ class ReflectionTest(TestBase, ComparesTables):
         try:
             metadata = MetaData(bind=testing.db)
             book = Table('book', metadata, autoload=True)
-            assert book.c.id  in book.primary_key
-            assert book.c.series not in book.primary_key
+            assert book.primary_key.contains_column(book.c.id)
+            assert not book.primary_key.contains_column(book.c.series)
             assert len(book.primary_key) == 1
         finally:
             testing.db.execute("drop table book")
@@ -337,14 +346,14 @@ class ReflectionTest(TestBase, ComparesTables):
     def test_fk_error(self):
         metadata = MetaData(testing.db)
         slots_table = Table('slots', metadata,
-            Column('slot_id', Integer, primary_key=True),
-            Column('pkg_id', Integer, ForeignKey('pkgs.pkg_id')),
-            Column('slot', String(128)),
+            Column('slot_id', sa.Integer, primary_key=True),
+            Column('pkg_id', sa.Integer, sa.ForeignKey('pkgs.pkg_id')),
+            Column('slot', sa.String(128)),
             )
         try:
             metadata.create_all()
             assert False
-        except exceptions.InvalidRequestError, err:
+        except tsa.exc.InvalidRequestError, err:
             assert str(err) == "Could not find table 'pkgs' with which to generate a foreign key"
 
     def test_composite_pks(self):
@@ -363,9 +372,9 @@ class ReflectionTest(TestBase, ComparesTables):
         try:
             metadata = MetaData(bind=testing.db)
             book = Table('book', metadata, autoload=True)
-            assert book.c.id  in book.primary_key
-            assert book.c.isbn  in book.primary_key
-            assert book.c.series not in book.primary_key
+            assert book.primary_key.contains_column(book.c.id)
+            assert book.primary_key.contains_column(book.c.isbn)
+            assert not book.primary_key.contains_column(book.c.series)
             assert len(book.primary_key) == 2
         finally:
             testing.db.execute("drop table book")
@@ -377,20 +386,20 @@ class ReflectionTest(TestBase, ComparesTables):
         meta = MetaData(testing.db)
         multi = Table(
             'multi', meta,
-            Column('multi_id', Integer, primary_key=True),
-            Column('multi_rev', Integer, primary_key=True),
-            Column('multi_hoho', Integer, primary_key=True),
-            Column('name', String(50), nullable=False),
-            Column('val', String(100)),
+            Column('multi_id', sa.Integer, primary_key=True),
+            Column('multi_rev', sa.Integer, primary_key=True),
+            Column('multi_hoho', sa.Integer, primary_key=True),
+            Column('name', sa.String(50), nullable=False),
+            Column('val', sa.String(100)),
             test_needs_fk=True,
         )
         multi2 = Table('multi2', meta,
-            Column('id', Integer, primary_key=True),
-            Column('foo', Integer),
-            Column('bar', Integer),
-            Column('lala', Integer),
-            Column('data', String(50)),
-            ForeignKeyConstraint(['foo', 'bar', 'lala'], ['multi.multi_id', 'multi.multi_rev', 'multi.multi_hoho']),
+            Column('id', sa.Integer, primary_key=True),
+            Column('foo', sa.Integer),
+            Column('bar', sa.Integer),
+            Column('lala', sa.Integer),
+            Column('data', sa.String(50)),
+            sa.ForeignKeyConstraint(['foo', 'bar', 'lala'], ['multi.multi_id', 'multi.multi_rev', 'multi.multi_hoho']),
             test_needs_fk=True,
         )
         meta.create_all()
@@ -401,8 +410,8 @@ class ReflectionTest(TestBase, ComparesTables):
             table2 = Table('multi2', meta2, autoload=True, autoload_with=testing.db)
             self.assert_tables_equal(multi, table)
             self.assert_tables_equal(multi2, table2)
-            j = join(table, table2)
-            self.assert_(and_(table.c.multi_id==table2.c.foo, table.c.multi_rev==table2.c.bar, table.c.multi_hoho==table2.c.lala).compare(j.onclause))
+            j = sa.join(table, table2)
+            self.assert_(sa.and_(table.c.multi_id==table2.c.foo, table.c.multi_rev==table2.c.bar, table.c.multi_hoho==table2.c.lala).compare(j.onclause))
         finally:
             meta.drop_all()
 
@@ -412,10 +421,10 @@ class ReflectionTest(TestBase, ComparesTables):
         # check a table that uses an SQL reserved name doesn't cause an error
         meta = MetaData(testing.db)
         table_a = Table('select', meta,
-                       Column('not', Integer, primary_key=True),
-                       Column('from', String(12), nullable=False),
-                       UniqueConstraint('from', name='when'))
-        Index('where', table_a.c['from'])
+                       Column('not', sa.Integer, primary_key=True),
+                       Column('from', sa.String(12), nullable=False),
+                       sa.UniqueConstraint('from', name='when'))
+        sa.Index('where', table_a.c['from'])
 
         # There's currently no way to calculate identifier case normalization
         # in isolation, so...
@@ -426,17 +435,17 @@ class ReflectionTest(TestBase, ComparesTables):
         quoter = meta.bind.dialect.identifier_preparer.quote_identifier
 
         table_b = Table('false', meta,
-                        Column('create', Integer, primary_key=True),
-                        Column('true', Integer, ForeignKey('select.not')),
-                        CheckConstraint('%s <> 1' % quoter(check_col),
+                        Column('create', sa.Integer, primary_key=True),
+                        Column('true', sa.Integer, sa.ForeignKey('select.not')),
+                        sa.CheckConstraint('%s <> 1' % quoter(check_col),
                                         name='limit'))
 
         table_c = Table('is', meta,
-                        Column('or', Integer, nullable=False, primary_key=True),
-                        Column('join', Integer, nullable=False, primary_key=True),
-                        PrimaryKeyConstraint('or', 'join', name='to'))
+                        Column('or', sa.Integer, nullable=False, primary_key=True),
+                        Column('join', sa.Integer, nullable=False, primary_key=True),
+                        sa.PrimaryKeyConstraint('or', 'join', name='to'))
 
-        index_c = Index('else', table_c.c.join)
+        index_c = sa.Index('else', table_c.c.join)
 
         meta.create_all()
 
@@ -462,7 +471,7 @@ class ReflectionTest(TestBase, ComparesTables):
 
         baseline = MetaData(testing.db)
         for name in names:
-            Table(name, baseline, Column('id', Integer, primary_key=True))
+            Table(name, baseline, Column('id', sa.Integer, primary_key=True))
         baseline.create_all()
 
         try:
@@ -484,7 +493,7 @@ class ReflectionTest(TestBase, ComparesTables):
             try:
                 m4.reflect(only=['rt_a', 'rt_f'])
                 self.assert_(False)
-            except exceptions.InvalidRequestError, e:
+            except tsa.exc.InvalidRequestError, e:
                 self.assert_(e.args[0].endswith('(rt_f)'))
 
             m5 = MetaData(testing.db)
@@ -501,7 +510,7 @@ class ReflectionTest(TestBase, ComparesTables):
             try:
                 m8 = MetaData(reflect=True)
                 self.assert_(False)
-            except exceptions.ArgumentError, e:
+            except tsa.exc.ArgumentError, e:
                 self.assert_(
                     e.args[0] ==
                     "A bind must be supplied in conjunction with reflect=True")
@@ -521,27 +530,27 @@ class CreateDropTest(TestBase):
         global metadata, users
         metadata = MetaData()
         users = Table('users', metadata,
-                      Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key=True),
-                      Column('user_name', String(40)),
+                      Column('user_id', sa.Integer, sa.Sequence('user_id_seq', optional=True), primary_key=True),
+                      Column('user_name', sa.String(40)),
                       )
 
         addresses = Table('email_addresses', metadata,
-            Column('address_id', Integer, Sequence('address_id_seq', optional=True), primary_key = True),
-            Column('user_id', Integer, ForeignKey(users.c.user_id)),
-            Column('email_address', String(40)),
+            Column('address_id', sa.Integer, sa.Sequence('address_id_seq', optional=True), primary_key = True),
+            Column('user_id', sa.Integer, sa.ForeignKey(users.c.user_id)),
+            Column('email_address', sa.String(40)),
         )
 
         orders = Table('orders', metadata,
-            Column('order_id', Integer, Sequence('order_id_seq', optional=True), primary_key = True),
-            Column('user_id', Integer, ForeignKey(users.c.user_id)),
-            Column('description', String(50)),
-            Column('isopen', Integer),
+            Column('order_id', sa.Integer, sa.Sequence('order_id_seq', optional=True), primary_key = True),
+            Column('user_id', sa.Integer, sa.ForeignKey(users.c.user_id)),
+            Column('description', sa.String(50)),
+            Column('isopen', sa.Integer),
         )
 
         orderitems = Table('items', metadata,
-            Column('item_id', INT, Sequence('items_id_seq', optional=True), primary_key = True),
-            Column('order_id', INT, ForeignKey("orders")),
-            Column('item_name', VARCHAR(50)),
+            Column('item_id', sa.INT, sa.Sequence('items_id_seq', optional=True), primary_key = True),
+            Column('order_id', sa.INT, sa.ForeignKey("orders")),
+            Column('item_name', sa.VARCHAR(50)),
         )
 
     def test_sorter( self ):
@@ -590,10 +599,10 @@ class SchemaManipulationTest(TestBase):
     def test_append_constraint_unique(self):
         meta = MetaData()
         
-        users = Table('users', meta, Column('id', Integer))
-        addresses = Table('addresses', meta, Column('id', Integer), Column('user_id', Integer))
+        users = Table('users', meta, Column('id', sa.Integer))
+        addresses = Table('addresses', meta, Column('id', sa.Integer), Column('user_id', sa.Integer))
         
-        fk = ForeignKeyConstraint(['user_id'],[users.c.id])
+        fk = sa.ForeignKeyConstraint(['user_id'],[users.c.id])
         
         addresses.append_constraint(fk)
         addresses.append_constraint(fk)
@@ -616,7 +625,7 @@ class UnicodeReflectionTest(TestBase):
                 names = set([u'plain', u'Unit\u00e9ble', u'\u6e2c\u8a66'])
 
             for name in names:
-                Table(name, metadata, Column('id', Integer, Sequence(name + "_id_seq"), primary_key=True))
+                Table(name, metadata, Column('id', sa.Integer, sa.Sequence(name + "_id_seq"), primary_key=True))
             metadata.create_all()
 
             reflected = set(bind.table_names())
@@ -642,18 +651,18 @@ class SchemaTest(TestBase):
     def test_iteration(self):
         metadata = MetaData()
         table1 = Table('table1', metadata,
-            Column('col1', Integer, primary_key=True),
+            Column('col1', sa.Integer, primary_key=True),
             schema='someschema')
         table2 = Table('table2', metadata,
-            Column('col1', Integer, primary_key=True),
-            Column('col2', Integer, ForeignKey('someschema.table1.col1')),
+            Column('col1', sa.Integer, primary_key=True),
+            Column('col2', sa.Integer, sa.ForeignKey('someschema.table1.col1')),
             schema='someschema')
         # ensure this doesnt crash
         print [t for t in metadata.table_iterator()]
         buf = StringIO.StringIO()
         def foo(s, p=None):
             buf.write(s)
-        gen = create_engine(testing.db.name + "://", strategy="mock", executor=foo)
+        gen = sa.create_engine(testing.db.name + "://", strategy="mock", executor=foo)
         gen = gen.dialect.schemagenerator(gen.dialect, gen)
         gen.traverse(table1)
         gen.traverse(table2)
@@ -681,12 +690,12 @@ class SchemaTest(TestBase):
 
         metadata = MetaData(engine)
         table1 = Table('table1', metadata,
-                       Column('col1', Integer, primary_key=True),
+                       Column('col1', sa.Integer, primary_key=True),
                        schema=schema)
         table2 = Table('table2', metadata,
-                       Column('col1', Integer, primary_key=True),
-                       Column('col2', Integer,
-                              ForeignKey('%s.table1.col1' % schema)),
+                       Column('col1', sa.Integer, primary_key=True),
+                       Column('col2', sa.Integer,
+                              sa.ForeignKey('%s.table1.col1' % schema)),
                        schema=schema)
         try:
             metadata.create_all()
@@ -704,8 +713,8 @@ class HasSequenceTest(TestBase):
         global metadata, users
         metadata = MetaData()
         users = Table('users', metadata,
-                      Column('user_id', Integer, Sequence('user_id_seq'), primary_key=True),
-                      Column('user_name', String(40)),
+                      Column('user_id', sa.Integer, sa.Sequence('user_id_seq'), primary_key=True),
+                      Column('user_name', sa.String(40)),
                       )
 
     @testing.unsupported('sqlite', 'mysql', 'mssql', 'access', 'sybase')
index edae14da29a2ee6db58cee7dd9dd01abb44071c5..1cb6ba7a1e573605705c59eea85dfd1cccd213e7 100644 (file)
@@ -1,11 +1,11 @@
 import testenv; testenv.configure_for_tests()
 import sys, time, threading
-
-from sqlalchemy import *
-from sqlalchemy.orm import *
-from testlib import *
+from testlib.sa import create_engine, MetaData, Table, Column, INT, VARCHAR, \
+     Sequence, select, Integer, String, func, text
+from testlib import TestBase, testing
 
 
+users, metadata = None, None
 class TransactionTest(TestBase):
     def setUpAll(self):
         global users, metadata
@@ -22,7 +22,7 @@ class TransactionTest(TestBase):
     def tearDownAll(self):
         users.drop(testing.db)
 
-    def testcommits(self):
+    def test_commits(self):
         connection = testing.db.connect()
         transaction = connection.begin()
         connection.execute(users.insert(), user_id=1, user_name='user1')
@@ -38,7 +38,7 @@ class TransactionTest(TestBase):
         assert len(result.fetchall()) == 3
         transaction.commit()
 
-    def testrollback(self):
+    def test_rollback(self):
         """test a basic rollback"""
         connection = testing.db.connect()
         transaction = connection.begin()
@@ -51,7 +51,7 @@ class TransactionTest(TestBase):
         assert len(result.fetchall()) == 0
         connection.close()
 
-    def testraise(self):
+    def test_raise(self):
         connection = testing.db.connect()
 
         transaction = connection.begin()
@@ -70,7 +70,7 @@ class TransactionTest(TestBase):
         connection.close()
 
     @testing.exclude('mysql', '<', (5, 0, 3))
-    def testnestedrollback(self):
+    def test_nested_rollback(self):
         connection = testing.db.connect()
 
         try:
@@ -100,7 +100,7 @@ class TransactionTest(TestBase):
 
 
     @testing.exclude('mysql', '<', (5, 0, 3))
-    def testnesting(self):
+    def test_nesting(self):
         connection = testing.db.connect()
         transaction = connection.begin()
         connection.execute(users.insert(), user_id=1, user_name='user1')
@@ -118,7 +118,7 @@ class TransactionTest(TestBase):
         connection.close()
 
     @testing.exclude('mysql', '<', (5, 0, 3))
-    def testclose(self):
+    def test_close(self):
         connection = testing.db.connect()
         transaction = connection.begin()
         connection.execute(users.insert(), user_id=1, user_name='user1')
@@ -139,7 +139,7 @@ class TransactionTest(TestBase):
         connection.close()
 
     @testing.exclude('mysql', '<', (5, 0, 3))
-    def testclose2(self):
+    def test_close2(self):
         connection = testing.db.connect()
         transaction = connection.begin()
         connection.execute(users.insert(), user_id=1, user_name='user1')
@@ -159,10 +159,8 @@ class TransactionTest(TestBase):
         assert len(result.fetchall()) == 0
         connection.close()
 
-
-    @testing.unsupported('sqlite', 'mssql', 'sybase', 'access')
-    @testing.exclude('mysql', '<', (5, 0, 3))
-    def testnestedsubtransactionrollback(self):
+    @testing.requires.savepoints
+    def test_nested_subtransaction_rollback(self):
         connection = testing.db.connect()
         transaction = connection.begin()
         connection.execute(users.insert(), user_id=1, user_name='user1')
@@ -178,9 +176,8 @@ class TransactionTest(TestBase):
         )
         connection.close()
 
-    @testing.unsupported('sqlite', 'mssql', 'sybase', 'access')
-    @testing.exclude('mysql', '<', (5, 0, 3))
-    def testnestedsubtransactioncommit(self):
+    @testing.requires.savepoints
+    def test_nested_subtransaction_commit(self):
         connection = testing.db.connect()
         transaction = connection.begin()
         connection.execute(users.insert(), user_id=1, user_name='user1')
@@ -196,9 +193,8 @@ class TransactionTest(TestBase):
         )
         connection.close()
 
-    @testing.unsupported('sqlite', 'mssql', 'sybase', 'access')
-    @testing.exclude('mysql', '<', (5, 0, 3))
-    def testrollbacktosubtransaction(self):
+    @testing.requires.savepoints
+    def test_rollback_to_subtransaction(self):
         connection = testing.db.connect()
         transaction = connection.begin()
         connection.execute(users.insert(), user_id=1, user_name='user1')
@@ -216,10 +212,8 @@ class TransactionTest(TestBase):
         )
         connection.close()
 
-    @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access',
-                         'oracle', 'maxdb')
-    @testing.exclude('mysql', '<', (5, 0, 3))
-    def testtwophasetransaction(self):
+    @testing.requires.two_phase_transactions
+    def test_two_phase_transaction(self):
         connection = testing.db.connect()
 
         transaction = connection.begin_twophase()
@@ -246,10 +240,9 @@ class TransactionTest(TestBase):
         )
         connection.close()
 
-    @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access',
-                         'oracle', 'maxdb')
-    @testing.exclude('mysql', '<', (5, 0, 3))
-    def testmixedtwophasetransaction(self):
+    @testing.requires.two_phase_transactions
+    @testing.requires.savepoints
+    def test_mixed_two_phase_transaction(self):
         connection = testing.db.connect()
 
         transaction = connection.begin_twophase()
@@ -281,11 +274,9 @@ class TransactionTest(TestBase):
         )
         connection.close()
 
-    @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access',
-                         'oracle', 'maxdb')
-    # fixme: see if this is still true and/or can be convert to fails_on()
-    @testing.unsupported('mysql')
-    def testtwophaserecover(self):
+    @testing.requires.two_phase_transactions
+    @testing.fails_on('mysql')
+    def test_two_phase_recover(self):
         # MySQL recovery doesn't currently seem to work correctly
         # Prepared transactions disappear when connections are closed and even
         # when they aren't it doesn't seem possible to use the recovery id.
@@ -316,10 +307,8 @@ class TransactionTest(TestBase):
         )
         connection2.close()
 
-    @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access',
-                         'oracle', 'maxdb')
-    @testing.exclude('mysql', '<', (5, 0, 3))
-    def testmultipletwophase(self):
+    @testing.requires.two_phase_transactions
+    def test_multiple_two_phase(self):
         conn = testing.db.connect()
 
         xa = conn.begin_twophase()
@@ -355,7 +344,7 @@ class AutoRollbackTest(TestBase):
         metadata.drop_all(testing.db)
 
     @testing.unsupported('sqlite')
-    def testrollback_deadlock(self):
+    def test_rollback_deadlock(self):
         """test that returning connections to the pool clears any object locks."""
         conn1 = testing.db.connect()
         conn2 = testing.db.connect()
@@ -375,12 +364,13 @@ class AutoRollbackTest(TestBase):
         users.drop(conn2)
         conn2.close()
 
+foo = None
 class ExplicitAutoCommitTest(TestBase):
-    """test the 'autocommit' flag on select() and text() objects.  
-    
+    """test the 'autocommit' flag on select() and text() objects.
+
     Requires Postgres so that we may define a custom function which modifies the database.
     """
-    
+
     __only_on__ = 'postgres'
 
     def setUpAll(self):
@@ -392,13 +382,13 @@ class ExplicitAutoCommitTest(TestBase):
 
     def tearDown(self):
         foo.delete().execute()
-        
+
     def tearDownAll(self):
         testing.db.execute("drop function insert_foo(varchar)")
         metadata.drop_all()
-    
+
     def test_control(self):
-        # test that not using autocommit does not commit 
+        # test that not using autocommit does not commit
         conn1 = testing.db.connect()
         conn2 = testing.db.connect()
 
@@ -412,44 +402,45 @@ class ExplicitAutoCommitTest(TestBase):
         trans.commit()
 
         assert conn2.execute(select([foo.c.data])).fetchall() == [('data1',), ('moredata',)]
-        
+
         conn1.close()
         conn2.close()
-        
+
     def test_explicit_compiled(self):
         conn1 = testing.db.connect()
         conn2 = testing.db.connect()
-        
+
         conn1.execute(select([func.insert_foo('data1')], autocommit=True))
         assert conn2.execute(select([foo.c.data])).fetchall() == [('data1',)]
 
         conn1.execute(select([func.insert_foo('data2')]).autocommit())
         assert conn2.execute(select([foo.c.data])).fetchall() == [('data1',), ('data2',)]
-        
+
         conn1.close()
         conn2.close()
-    
+
     def test_explicit_text(self):
         conn1 = testing.db.connect()
         conn2 = testing.db.connect()
-        
+
         conn1.execute(text("select insert_foo('moredata')", autocommit=True))
         assert conn2.execute(select([foo.c.data])).fetchall() == [('moredata',)]
-        
+
         conn1.close()
         conn2.close()
 
     def test_implicit_text(self):
         conn1 = testing.db.connect()
         conn2 = testing.db.connect()
-        
+
         conn1.execute(text("insert into foo (data) values ('implicitdata')"))
         assert conn2.execute(select([foo.c.data])).fetchall() == [('implicitdata',)]
-        
+
         conn1.close()
         conn2.close()
-        
-    
+
+
+tlengine = None
 class TLTransactionTest(TestBase):
     def setUpAll(self):
         global users, metadata, tlengine
@@ -502,7 +493,7 @@ class TLTransactionTest(TestBase):
         finally:
             external_connection.close()
 
-    def testrollback(self):
+    def test_rollback(self):
         """test a basic rollback"""
         tlengine.begin()
         tlengine.execute(users.insert(), user_id=1, user_name='user1')
@@ -517,7 +508,7 @@ class TLTransactionTest(TestBase):
         finally:
             external_connection.close()
 
-    def testcommit(self):
+    def test_commit(self):
         """test a basic commit"""
         tlengine.begin()
         tlengine.execute(users.insert(), user_id=1, user_name='user1')
@@ -532,7 +523,7 @@ class TLTransactionTest(TestBase):
         finally:
             external_connection.close()
 
-    def testcommits(self):
+    def test_commits(self):
         assert tlengine.connect().execute("select count(1) from query_users").scalar() == 0
 
         connection = tlengine.contextual_connect()
@@ -551,7 +542,7 @@ class TLTransactionTest(TestBase):
         assert len(l) == 3, "expected 3 got %d" % len(l)
         transaction.commit()
 
-    def testrollback_off_conn(self):
+    def test_rollback_off_conn(self):
         # test that a TLTransaction opened off a TLConnection allows that
         # TLConnection to be aware of the transactional context
         conn = tlengine.contextual_connect()
@@ -568,7 +559,7 @@ class TLTransactionTest(TestBase):
         finally:
             external_connection.close()
 
-    def testmorerollback_off_conn(self):
+    def test_morerollback_off_conn(self):
         # test that an existing TLConnection automatically takes place in a TLTransaction
         # opened on a second TLConnection
         conn = tlengine.contextual_connect()
@@ -586,7 +577,7 @@ class TLTransactionTest(TestBase):
         finally:
             external_connection.close()
 
-    def testcommit_off_conn(self):
+    def test_commit_off_connection(self):
         conn = tlengine.contextual_connect()
         trans = conn.begin()
         conn.execute(users.insert(), user_id=1, user_name='user1')
@@ -603,7 +594,7 @@ class TLTransactionTest(TestBase):
 
     @testing.unsupported('sqlite')
     @testing.exclude('mysql', '<', (5, 0, 3))
-    def testnesting(self):
+    def test_nesting(self):
         """tests nesting of transactions"""
         external_connection = tlengine.connect()
         self.assert_(external_connection.connection is not tlengine.contextual_connect().connection)
@@ -622,7 +613,7 @@ class TLTransactionTest(TestBase):
             external_connection.close()
 
     @testing.exclude('mysql', '<', (5, 0, 3))
-    def testmixednesting(self):
+    def test_mixed_nesting(self):
         """tests nesting of transactions off the TLEngine directly inside of
         tranasctions off the connection from the TLEngine"""
         external_connection = tlengine.connect()
@@ -651,7 +642,7 @@ class TLTransactionTest(TestBase):
             external_connection.close()
 
     @testing.exclude('mysql', '<', (5, 0, 3))
-    def testmoremixednesting(self):
+    def test_more_mixed_nesting(self):
         """tests nesting of transactions off the connection from the TLEngine
         inside of tranasctions off thbe TLEngine directly."""
         external_connection = tlengine.connect()
@@ -674,24 +665,9 @@ class TLTransactionTest(TestBase):
         finally:
             external_connection.close()
 
-    @testing.exclude('mysql', '<', (5, 0, 3))
-    def testsessionnesting(self):
-        class User(object):
-            pass
-        try:
-            mapper(User, users)
-
-            sess = create_session(bind=tlengine)
-            tlengine.begin()
-            u = User()
-            sess.save(u)
-            sess.flush()
-            tlengine.commit()
-        finally:
-            clear_mappers()
 
 
-    def testconnections(self):
+    def test_connections(self):
         """tests that contextual_connect is threadlocal"""
         c1 = tlengine.contextual_connect()
         c2 = tlengine.contextual_connect()
@@ -699,10 +675,8 @@ class TLTransactionTest(TestBase):
         c2.close()
         assert c1.connection.connection is not None
 
-    @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access',
-                         'oracle', 'maxdb')
-    @testing.exclude('mysql', '<', (5, 0, 3))
-    def testtwophasetransaction(self):
+    @testing.requires.two_phase_transactions
+    def test_two_phase_transaction(self):
         tlengine.begin_twophase()
         tlengine.execute(users.insert(), user_id=1, user_name='user1')
         tlengine.prepare()
@@ -726,6 +700,7 @@ class TLTransactionTest(TestBase):
             [(1,),(2,)]
         )
 
+counters = None
 class ForUpdateTest(TestBase):
     def setUpAll(self):
         global counters, metadata
@@ -770,7 +745,7 @@ class ForUpdateTest(TestBase):
 
     @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access')
 
-    def testqueued_update(self):
+    def test_queued_update(self):
         """Test SELECT FOR UPDATE with concurrent modifications.
 
         Runs concurrent modifications on a single row in the users table,
@@ -832,7 +807,7 @@ class ForUpdateTest(TestBase):
         return errors
 
     @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access')
-    def testqueued_select(self):
+    def test_queued_select(self):
         """Simple SELECT FOR UPDATE conflict test"""
 
         errors = self._threaded_overlap(2, [(1,2,3),(3,4,5)])
@@ -842,7 +817,7 @@ class ForUpdateTest(TestBase):
 
     @testing.unsupported('sqlite', 'mysql', 'mssql', 'firebird',
                          'sybase', 'access')
-    def testnowait_select(self):
+    def test_nowait_select(self):
         """Simple SELECT FOR UPDATE NOWAIT conflict test"""
 
         errors = self._threaded_overlap(2, [(1,2,3),(3,4,5)],
diff --git a/test/ext/activemapper.py b/test/ext/activemapper.py
deleted file mode 100644 (file)
index fa112c3..0000000
+++ /dev/null
@@ -1,357 +0,0 @@
-import testenv; testenv.configure_for_tests()
-from datetime import datetime
-
-from sqlalchemy.ext.activemapper           import ActiveMapper, column, one_to_many, one_to_one, many_to_many, objectstore
-from sqlalchemy             import and_, or_, exceptions
-from sqlalchemy             import ForeignKey, String, Integer, DateTime, Table, Column
-from sqlalchemy.orm         import clear_mappers, backref, create_session, class_mapper
-import sqlalchemy.ext.activemapper as activemapper
-import sqlalchemy
-from testlib import *
-
-
-class testcase(TestBase):
-    def setUpAll(self):
-        clear_mappers()
-        objectstore.clear()
-        global Person, Preferences, Address
-
-        class Person(ActiveMapper):
-            class mapping:
-                __version_id_col__ = 'row_version'
-                full_name   = column(String(128))
-                first_name  = column(String(128))
-                middle_name = column(String(128))
-                last_name   = column(String(128))
-                birth_date  = column(DateTime)
-                ssn         = column(String(128))
-                gender      = column(String(128))
-                home_phone  = column(String(128))
-                cell_phone  = column(String(128))
-                work_phone  = column(String(128))
-                row_version = column(Integer, default=0)
-                prefs_id    = column(Integer, foreign_key=ForeignKey('preferences.id'))
-                addresses   = one_to_many('Address', colname='person_id', backref='person', order_by=['state', 'city', 'postal_code'])
-                preferences = one_to_one('Preferences', colname='pref_id', backref='person')
-
-            def __str__(self):
-                s =  '%s\n' % self.full_name
-                s += '  * birthdate: %s\n' % (self.birth_date or 'not provided')
-                s += '  * fave color: %s\n' % (self.preferences.favorite_color or 'Unknown')
-                s += '  * personality: %s\n' % (self.preferences.personality_type or 'Unknown')
-
-                for address in self.addresses:
-                    s += '  * address: %s\n' % address.address_1
-                    s += '             %s, %s %s\n' % (address.city, address.state, address.postal_code)
-
-                return s
-
-        class Preferences(ActiveMapper):
-            class mapping:
-                __table__        = 'preferences'
-                favorite_color   = column(String(128))
-                personality_type = column(String(128))
-
-        class Address(ActiveMapper):
-            class mapping:
-                # note that in other objects, the 'id' primary key is
-                # automatically added -- if you specify a primary key,
-                # then ActiveMapper will not add an integer primary key
-                # for you.
-                id          = column(Integer, primary_key=True)
-                type        = column(String(128))
-                address_1   = column(String(128))
-                city        = column(String(128))
-                state       = column(String(128))
-                postal_code = column(String(128))
-                person_id   = column(Integer, foreign_key=ForeignKey('person.id'))
-
-        activemapper.metadata.bind = testing.db
-        activemapper.create_tables()
-
-    def tearDownAll(self):
-        clear_mappers()
-        activemapper.drop_tables()
-
-    def tearDown(self):
-        for t in activemapper.metadata.table_iterator(reverse=True):
-            t.delete().execute()
-
-    def create_person_one(self):
-        # create a person
-        p1 = Person(
-                full_name='Jonathan LaCour',
-                birth_date=datetime(1979, 10, 12),
-                preferences=Preferences(
-                                favorite_color='Green',
-                                personality_type='ENTP'
-                            ),
-                addresses=[
-                    Address(
-                        address_1='123 Some Great Road.',
-                        city='Atlanta',
-                        state='GA',
-                        postal_code='30338'
-                    ),
-                    Address(
-                        address_1='435 Franklin Road.',
-                        city='Atlanta',
-                        state='GA',
-                        postal_code='30342'
-                    )
-                ]
-             )
-        return p1
-
-
-    def create_person_two(self):
-        p2 = Person(
-                full_name='Lacey LaCour',
-                addresses=[
-                    Address(
-                        address_1='123 Some Great Road.',
-                        city='Atlanta',
-                        state='GA',
-                        postal_code='30338'
-                    ),
-                    Address(
-                        address_1='200 Main Street',
-                        city='Roswell',
-                        state='GA',
-                        postal_code='30075'
-                    )
-                ]
-             )
-        # I don't like that I have to do this... and putting
-        # a "self.preferences = Preferences()" into the __init__
-        # of Person also doens't seem to fix this
-        p2.preferences = Preferences()
-
-        return p2
-
-
-    def test_create(self):
-        p1 = self.create_person_one()
-        objectstore.flush()
-        objectstore.clear()
-
-        results = Person.query.all()
-
-        self.assertEquals(len(results), 1)
-
-        person = results[0]
-        self.assertEquals(person.id, p1.id)
-        self.assertEquals(len(person.addresses), 2)
-        self.assertEquals(person.addresses[0].postal_code, '30338')
-
-    @testing.unsupported('mysql')
-    def test_update(self):
-        p1 = self.create_person_one()
-        objectstore.flush()
-        objectstore.clear()
-
-        person = Person.query.first()
-        person.gender = 'F'
-        objectstore.flush()
-        objectstore.clear()
-        self.assertEquals(person.row_version, 2)
-
-        person = Person.query.first()
-        person.gender = 'M'
-        objectstore.flush()
-        objectstore.clear()
-        self.assertEquals(person.row_version, 3)
-
-        #TODO: check that a concurrent modification raises exception
-        p1 = Person.query.first()
-        s1 = objectstore()
-        s2 = create_session()
-        objectstore.registry.set(s2)
-        p2 = Person.query.first()
-        p1.first_name = "jack"
-        p2.first_name = "ed"
-        objectstore.flush()
-        try:
-            objectstore.registry.set(s1)
-            objectstore.flush()
-            # Only dialects with a sane rowcount can detect the ConcurrentModificationError
-            if testing.db.dialect.supports_sane_rowcount:
-                assert False
-        except exceptions.ConcurrentModificationError:
-            pass
-
-
-    def test_delete(self):
-        p1 = self.create_person_one()
-
-        objectstore.flush()
-        objectstore.clear()
-
-        results = Person.query.all()
-        self.assertEquals(len(results), 1)
-
-        objectstore.delete(results[0])
-        objectstore.flush()
-        objectstore.clear()
-
-        results = Person.query.all()
-        self.assertEquals(len(results), 0)
-
-
-    def test_multiple(self):
-        p1 = self.create_person_one()
-        p2 = self.create_person_two()
-
-        objectstore.flush()
-        objectstore.clear()
-
-        # select and make sure we get back two results
-        people = Person.query.all()
-        self.assertEquals(len(people), 2)
-
-        # make sure that our backwards relationships work
-        self.assertEquals(people[0].addresses[0].person.id, p1.id)
-        self.assertEquals(people[1].addresses[0].person.id, p2.id)
-
-        # try a more complex select
-        results = Person.query.filter(
-            or_(
-                and_(
-                    Address.c.person_id == Person.c.id,
-                    Address.c.postal_code.like('30075')
-                ),
-                and_(
-                    Person.c.prefs_id == Preferences.c.id,
-                    Preferences.c.favorite_color == 'Green'
-                )
-            )
-        ).all()
-        self.assertEquals(len(results), 2)
-
-
-    def test_oneway_backref(self):
-        # FIXME: I don't know why, but it seems that my backwards relationship
-        #        on preferences still ends up being a list even though I pass
-        #        in uselist=False...
-        # FIXED: the backref is a new PropertyLoader which needs its own "uselist".
-        # uses a function which I dont think existed when you first wrote ActiveMapper.
-        p1 = self.create_person_one()
-        self.assertEquals(p1.preferences.person, p1)
-        objectstore.flush()
-        objectstore.delete(p1)
-
-        objectstore.flush()
-        objectstore.clear()
-
-
-    def test_select_by(self):
-        # FIXME: either I don't understand select_by, or it doesn't work.
-        # FIXED (as good as we can for now): yup....everyone thinks it works that way....it only
-        # generates joins for keyword arguments, not ColumnClause args.  would need a new layer of
-        # "MapperClause" objects to use properties in expressions. (MB)
-
-        p1 = self.create_person_one()
-        p2 = self.create_person_two()
-
-        objectstore.flush()
-        objectstore.clear()
-
-        results = Person.query.join('addresses').filter(
-            Address.c.postal_code.like('30075')
-        ).all()
-        self.assertEquals(len(results), 1)
-
-        self.assertEquals(Person.query.count(), 2)
-
-class testmanytomany(TestBase):
-     def setUpAll(self):
-         clear_mappers()
-         objectstore.clear()
-         global secondarytable, foo, baz
-         secondarytable = Table("secondarytable",
-             activemapper.metadata,
-             Column("foo_id", Integer, ForeignKey("foo.id"),primary_key=True),
-             Column("baz_id", Integer, ForeignKey("baz.id"),primary_key=True))
-
-         class foo(activemapper.ActiveMapper):
-             class mapping:
-                 name = column(String(30))
-#                 bazrel = many_to_many('baz', secondarytable, backref='foorel')
-
-         class baz(activemapper.ActiveMapper):
-             class mapping:
-                 name = column(String(30))
-                 foorel = many_to_many("foo", secondarytable, backref='bazrel')
-
-         activemapper.metadata.bind = testing.db
-         activemapper.create_tables()
-
-     # Create a couple of activemapper objects
-     def create_objects(self):
-         return foo(name='foo1'), baz(name='baz1')
-
-     def tearDownAll(self):
-         clear_mappers()
-         activemapper.drop_tables()
-         objectstore.clear()
-     def testbasic(self):
-         # Set up activemapper objects
-         foo1, baz1 = self.create_objects()
-
-         objectstore.flush()
-         objectstore.clear()
-
-         foo1 = foo.query.filter_by(name='foo1').one()
-         baz1 = baz.query.filter_by(name='baz1').one()
-
-         # Just checking ...
-         assert (foo1.name == 'foo1')
-         assert (baz1.name == 'baz1')
-
-         # Diagnostics ...
-         # import sys
-         # sys.stderr.write("\nbazrel missing from dir(foo1):\n%s\n"  % dir(foo1))
-         # sys.stderr.write("\nbazrel in foo1 relations:\n%s\n" %  foo1.relations)
-
-         # Optimistically based on activemapper one_to_many test, try  to append
-         # baz1 to foo1.bazrel - (AttributeError: 'foo' object has no attribute 'bazrel')
-         foo1.bazrel.append(baz1)
-         assert (foo1.bazrel == [baz1])
-
-class testselfreferential(TestBase):
-    def setUpAll(self):
-        clear_mappers()
-        objectstore.clear()
-        global TreeNode
-        class TreeNode(activemapper.ActiveMapper):
-            class mapping:
-                id = column(Integer, primary_key=True)
-                name = column(String(30))
-                parent_id = column(Integer, foreign_key=ForeignKey('treenode.id'))
-                children = one_to_many('TreeNode', colname='id', backref='parent')
-
-        activemapper.metadata.bind = testing.db
-        activemapper.create_tables()
-    def tearDownAll(self):
-        clear_mappers()
-        activemapper.drop_tables()
-
-    def testbasic(self):
-        t = TreeNode(name='node1')
-        t.children.append(TreeNode(name='node2'))
-        t.children.append(TreeNode(name='node3'))
-        objectstore.flush()
-        objectstore.clear()
-
-        t = TreeNode.query.filter_by(name='node1').one()
-        assert (t.name == 'node1')
-        assert (t.children[0].name == 'node2')
-        assert (t.children[1].name == 'node3')
-        assert (t.children[1].parent is t)
-
-        objectstore.clear()
-        t = TreeNode.query.filter_by(name='node3').one()
-        assert (t.parent is TreeNode.query.filter_by(name='node1').one())
-
-if __name__ == '__main__':
-    testenv.main()
index d5db4d01ed1e496c57bedaa37e561219d25725ba..1b6dc53d2e48ce57ae8bb00f8ba91acab2eeaf9d 100644 (file)
@@ -2,8 +2,7 @@ import testenv; testenv.configure_for_tests()
 import doctest, sys, unittest
 
 def suite():
-    unittest_modules = ['ext.activemapper',
-                        'ext.assignmapper',
+    unittest_modules = [
                         'ext.declarative',
                         'ext.orderinglist',
                         'ext.associationproxy']
diff --git a/test/ext/assignmapper.py b/test/ext/assignmapper.py
deleted file mode 100644 (file)
index 1cb2ca3..0000000
+++ /dev/null
@@ -1,83 +0,0 @@
-import testenv; testenv.configure_for_tests()
-from sqlalchemy import *
-from sqlalchemy import exceptions
-from sqlalchemy.orm import create_session, clear_mappers, relation, class_mapper
-from sqlalchemy.ext.assignmapper import assign_mapper
-from sqlalchemy.ext.sessioncontext import SessionContext
-from testlib import *
-
-
-class AssignMapperTest(TestBase):
-    def setUpAll(self):
-        global metadata, table, table2
-        metadata = MetaData(testing.db)
-        table = Table('sometable', metadata,
-            Column('id', Integer, primary_key=True),
-            Column('data', String(30)))
-        table2 = Table('someothertable', metadata,
-            Column('id', Integer, primary_key=True),
-            Column('someid', None, ForeignKey('sometable.id'))
-            )
-        metadata.create_all()
-
-    @testing.uses_deprecated('SessionContext', 'assign_mapper')
-    def setUp(self):
-        global SomeObject, SomeOtherObject, ctx
-        class SomeObject(object):pass
-        class SomeOtherObject(object):pass
-
-        ctx = SessionContext(create_session)
-        assign_mapper(ctx, SomeObject, table, properties={
-            'options':relation(SomeOtherObject)
-            })
-        assign_mapper(ctx, SomeOtherObject, table2)
-
-        s = SomeObject()
-        s.id = 1
-        s.data = 'hello'
-        sso = SomeOtherObject()
-        s.options.append(sso)
-        ctx.current.flush()
-        ctx.current.clear()
-
-    def tearDownAll(self):
-        metadata.drop_all()
-
-    def tearDown(self):
-        for table in metadata.table_iterator(reverse=True):
-            table.delete().execute()
-        clear_mappers()
-
-    @testing.uses_deprecated('assign_mapper')
-    def test_override_attributes(self):
-
-        sso = SomeOtherObject.query().first()
-
-        assert SomeObject.query.filter_by(id=1).one().options[0].id == sso.id
-
-        s2 = SomeObject(someid=12)
-        s3 = SomeOtherObject(someid=123, bogus=345)
-
-        class ValidatedOtherObject(object):pass
-        assign_mapper(ctx, ValidatedOtherObject, table2, validate=True)
-
-        v1 = ValidatedOtherObject(someid=12)
-        try:
-            v2 = ValidatedOtherObject(someid=12, bogus=345)
-            assert False
-        except exceptions.ArgumentError:
-            pass
-
-    @testing.uses_deprecated('assign_mapper')
-    def test_dont_clobber_methods(self):
-        class MyClass(object):
-            def expunge(self):
-                return "an expunge !"
-
-        assign_mapper(ctx, MyClass, table2)
-
-        assert MyClass().expunge() == "an expunge !"
-
-
-if __name__ == '__main__':
-    testenv.main()
index ab07627ddaccbdcc24c8232772af551568307a84..4c4f9b0127f4a12c16a63f169daf7408635394f8 100644 (file)
@@ -5,7 +5,7 @@ from sqlalchemy.orm import *
 from sqlalchemy.orm.interfaces import MapperExtension
 from sqlalchemy.ext.declarative import declarative_base, declared_synonym, \
                                        synonym_for, comparable_using
-from sqlalchemy import exceptions
+from sqlalchemy import exc
 from testlib.fixtures import Base as Fixture
 from testlib import *
 
@@ -94,7 +94,7 @@ class DeclarativeTest(TestBase, AssertsExecutionResults):
 
                 id = Column(Integer, primary_key=True)
                 foo = column_property(User.id==5)
-        self.assertRaises(exceptions.InvalidRequestError, go)
+        self.assertRaises(exc.InvalidRequestError, go)
         
     def test_add_prop(self):
         class User(Base, Fixture):
@@ -183,7 +183,7 @@ class DeclarativeTest(TestBase, AssertsExecutionResults):
                 name = Column('name', String(50))
             assert False
         self.assertRaisesMessage(
-            exceptions.ArgumentError,
+            exc.ArgumentError,
             "Mapper Mapper|User|users could not assemble any primary key",
             define)
 
index 73406c00d509f22a027d3686ff26304059b82d39..77745aea144e04a633097b85ddffc81b0a7fd84c 100644 (file)
@@ -6,7 +6,9 @@ import sharding.alltests as sharding
 
 def suite():
     modules_to_test = (
-    'orm.attributes',
+        'orm.attributes',
+        'orm.extendedattr',
+        'orm.instrumentation',
         'orm.query',
         'orm.lazy_relations',
         'orm.eager_relations',
@@ -19,15 +21,17 @@ def suite():
         'orm.assorted_eager',
 
         'orm.naturalpks',
-        'orm.sessioncontext',
         'orm.unitofwork',
         'orm.session',
+        'orm.transaction',
+        'orm.scoping',
         'orm.cascade',
         'orm.relationships',
         'orm.association',
         'orm.merge',
         'orm.pickled',
         'orm.memusage',
+        'orm.utils',
 
         'orm.cycles',
 
@@ -36,6 +40,8 @@ def suite():
         'orm.manytomany',
         'orm.onetoone',
         'orm.dynamic',
+
+        'orm.deprecations',
         )
     alltests = unittest.TestSuite()
     for name in modules_to_test:
index 65d70253835341fbcade7875842b393d304392f4..1115849d25f30b772de5e727b4c2eef796b31fbe 100644 (file)
@@ -5,7 +5,6 @@ from sqlalchemy.orm import *
 from testlib import *
 
 class AssociationTest(TestBase):
-    @testing.uses_deprecated('association option')
     def setUpAll(self):
         global items, item_keywords, keywords, metadata, Item, Keyword, KeywordAssociation
         metadata = MetaData(testing.db)
@@ -46,7 +45,7 @@ class AssociationTest(TestBase):
             'keyword':relation(Keyword, lazy=False)
         }, primary_key=[item_keywords.c.item_id, item_keywords.c.keyword_id], order_by=[item_keywords.c.data])
         mapper(Item, items, properties={
-            'keywords' : relation(KeywordAssociation, association=Keyword)
+            'keywords' : relation(KeywordAssociation, cascade="all, delete-orphan")
         })
 
     def tearDown(self):
@@ -123,7 +122,6 @@ class AssociationTest(TestBase):
         print loaded
         self.assert_(saved == loaded)
 
-    @testing.uses_deprecated('association option')
     def testdelete(self):
         sess = create_session()
         item1 = Item('item1')
@@ -185,7 +183,7 @@ in self.c ]
 
         mapper(Originals, table_originals, order_by=Originals.order,
             properties={
-                'people': relation(IsAuthor, association=People),
+                'people': relation(IsAuthor, cascade="all, delete-orphan"),
                 'authors': relation(People, secondary=table_isauthor, backref='written',
                             primaryjoin=and_(table_originals.c.ID==table_isauthor.c.OriginalsID,
                             table_isauthor.c.Kind=='A')),
@@ -193,7 +191,7 @@ in self.c ]
                 'date': table_originals.c.Date,
             })
         mapper(People, table_people, order_by=People.order, properties=    {
-                'originals':        relation(IsAuthor, association=Originals),
+                'originals':        relation(IsAuthor, cascade="all, delete-orphan"),
                 'name':             table_people.c.Name,
                 'country':          table_people.c.Country,
             })
index af3fcbc7bbf7e26d99752178f9266c18f8cdf590..731a9f91617545c31c967d000f35b955986e2761 100644 (file)
@@ -4,7 +4,6 @@ import testenv; testenv.configure_for_tests()
 import random, datetime
 from sqlalchemy import *
 from sqlalchemy.orm import *
-from sqlalchemy.ext.sessioncontext import SessionContext
 from testlib import *
 from testlib import fixtures
 
@@ -125,15 +124,6 @@ class EagerTest(TestBase, AssertsExecutionResults):
         print result
         assert result == [u'1 Some Category', u'3 Some Category']
 
-    @testing.uses_deprecated('//select')
-    def test_withouteagerload_deprecated(self):
-        s = create_session()
-        l=s.query(Test).select ( and_(tests.c.owner_id==1,or_(options.c.someoption==None,options.c.someoption==False)),
-            from_obj=[tests.outerjoin(options,and_(tests.c.id==options.c.test_id,tests.c.owner_id==options.c.owner_id))])
-        result = ["%d %s" % ( t.id,t.category.name ) for t in l]
-        print result
-        assert result == [u'1 Some Category', u'3 Some Category']
-
     def test_witheagerload(self):
         """test that an eagerload locates the correct "from" clause with
         which to attach to, when presented with a query that already has a complicated from clause."""
@@ -152,17 +142,6 @@ class EagerTest(TestBase, AssertsExecutionResults):
         print result
         assert result == [u'1 Some Category', u'3 Some Category']
 
-    @testing.uses_deprecated('//select')
-    def test_witheagerload_deprecated(self):
-        """As test_witheagerload, but via select()."""
-        s = create_session()
-        q=s.query(Test).options(eagerload('category'))
-        l=q.select ( and_(tests.c.owner_id==1,or_(options.c.someoption==None,options.c.someoption==False)),
-            from_obj=[tests.outerjoin(options,and_(tests.c.id==options.c.test_id,tests.c.owner_id==options.c.owner_id))])
-        result = ["%d %s" % ( t.id,t.category.name ) for t in l]
-        print result
-        assert result == [u'1 Some Category', u'3 Some Category']
-
     def test_dslish(self):
         """test the same as witheagerload except using generative"""
         s = create_session()
@@ -188,16 +167,6 @@ class EagerTest(TestBase, AssertsExecutionResults):
         print result
         assert result == [u'3 Some Category']
 
-    @testing.unsupported('sybase')
-    @testing.uses_deprecated('//select', '//join_to')
-    def test_withoutouterjoin_literal_deprecated(self):
-        s = create_session()
-        q=s.query(Test).options(eagerload('category'))
-        l=q.select( (tests.c.owner_id==1) & ('options.someoption is null or options.someoption=%s' % false) & q.join_to('owner_option') )
-        result = ["%d %s" % ( t.id,t.category.name ) for t in l]
-        print result
-        assert result == [u'3 Some Category']
-
     def test_withoutouterjoin(self):
         s = create_session()
         q=s.query(Test).options(eagerload('category'))
@@ -206,15 +175,6 @@ class EagerTest(TestBase, AssertsExecutionResults):
         print result
         assert result == [u'3 Some Category']
 
-    @testing.uses_deprecated('//select', '//join_to', '//join_via')
-    def test_withoutouterjoin_deprecated(self):
-        s = create_session()
-        q=s.query(Test).options(eagerload('category'))
-        l=q.select( (tests.c.owner_id==1) & ((options.c.someoption==None) | (options.c.someoption==False)) & q.join_to('owner_option') )
-        result = ["%d %s" % ( t.id,t.category.name ) for t in l]
-        print result
-        assert result == [u'3 Some Category']
-
 class EagerTest2(TestBase, AssertsExecutionResults):
     def setUpAll(self):
         global metadata, middle, left, right
@@ -389,7 +349,7 @@ class EagerTest4(ORMTest):
         sess.flush()
 
         q = sess.query(Department)
-        q = q.join('employees').filter(Employee.c.name.startswith('J')).distinct().order_by([desc(Department.c.name)])
+        q = q.join('employees').filter(Employee.name.startswith('J')).distinct().order_by([desc(Department.name)])
         assert q.count() == 2
         assert q[0] is d2
 
@@ -543,12 +503,11 @@ class EagerTest6(ORMTest):
         x.inheritedParts
 
 class EagerTest7(ORMTest):
-    @testing.uses_deprecated('SessionContext')
     def define_tables(self, metadata):
         global companies_table, addresses_table, invoice_table, phones_table, items_table, ctx
         global Company, Address, Phone, Item,Invoice
 
-        ctx = SessionContext(create_session)
+        ctx = scoped_session(create_session)
 
         companies_table = Table('companies', metadata,
             Column('company_id', Integer, Sequence('company_id_seq', optional=True), primary_key = True),
@@ -606,20 +565,19 @@ class EagerTest7(ORMTest):
             def __repr__(self):
                 return "Item: " + repr(getattr(self, 'item_id', None)) + " " + repr(getattr(self, 'invoice_id', None)) + " " + repr(self.code) + " " + repr(self.qty)
 
-    @testing.uses_deprecated('SessionContext')
     def testone(self):
         """tests eager load of a many-to-one attached to a one-to-many.  this testcase illustrated
         the bug, which is that when the single Company is loaded, no further processing of the rows
         occurred in order to load the Company's second Address object."""
 
         mapper(Address, addresses_table, properties={
-            }, extension=ctx.mapper_extension)
+            }, extension=ctx.extension)
         mapper(Company, companies_table, properties={
             'addresses' : relation(Address, lazy=False),
-            }, extension=ctx.mapper_extension)
+            }, extension=ctx.extension)
         mapper(Invoice, invoice_table, properties={
             'company': relation(Company, lazy=False, )
-            }, extension=ctx.mapper_extension)
+            }, extension=ctx.extension)
 
         c1 = Company()
         c1.company_name = 'company 1'
@@ -633,18 +591,18 @@ class EagerTest7(ORMTest):
         i1.date = datetime.datetime.now()
         i1.company = c1
 
-        ctx.current.flush()
+        ctx.flush()
 
         company_id = c1.company_id
         invoice_id = i1.invoice_id
 
-        ctx.current.clear()
+        ctx.clear()
 
-        c = ctx.current.query(Company).get(company_id)
+        c = ctx.query(Company).get(company_id)
 
-        ctx.current.clear()
+        ctx.clear()
 
-        i = ctx.current.query(Invoice).get(invoice_id)
+        i = ctx.query(Invoice).get(invoice_id)
 
         print repr(c)
         print repr(i.company)
@@ -653,24 +611,24 @@ class EagerTest7(ORMTest):
     def testtwo(self):
         """this is the original testcase that includes various complicating factors"""
 
-        mapper(Phone, phones_table, extension=ctx.mapper_extension)
+        mapper(Phone, phones_table, extension=ctx.extension)
 
         mapper(Address, addresses_table, properties={
             'phones': relation(Phone, lazy=False, backref='address')
-            }, extension=ctx.mapper_extension)
+            }, extension=ctx.extension)
 
         mapper(Company, companies_table, properties={
             'addresses' : relation(Address, lazy=False, backref='company'),
-            }, extension=ctx.mapper_extension)
+            }, extension=ctx.extension)
 
-        mapper(Item, items_table, extension=ctx.mapper_extension)
+        mapper(Item, items_table, extension=ctx.extension)
 
         mapper(Invoice, invoice_table, properties={
             'items': relation(Item, lazy=False, backref='invoice'),
             'company': relation(Company, lazy=False, backref='invoices')
-            }, extension=ctx.mapper_extension)
+            }, extension=ctx.extension)
 
-        ctx.current.clear()
+        ctx.clear()
         c1 = Company()
         c1.company_name = 'company 1'
 
@@ -705,13 +663,13 @@ class EagerTest7(ORMTest):
 
         c1.addresses.append(a2)
 
-        ctx.current.flush()
+        ctx.flush()
 
         company_id = c1.company_id
 
-        ctx.current.clear()
+        ctx.clear()
 
-        a = ctx.current.query(Company).get(company_id)
+        a = ctx.query(Company).get(company_id)
         print repr(a)
 
         # set up an invoice
@@ -734,18 +692,18 @@ class EagerTest7(ORMTest):
         item3.qty = 3
         item3.invoice = i1
 
-        ctx.current.flush()
+        ctx.flush()
 
         invoice_id = i1.invoice_id
 
-        ctx.current.clear()
+        ctx.clear()
 
-        c = ctx.current.query(Company).get(company_id)
+        c = ctx.query(Company).get(company_id)
         print repr(c)
 
-        ctx.current.clear()
+        ctx.clear()
 
-        i = ctx.current.query(Invoice).get(invoice_id)
+        i = ctx.query(Invoice).get(invoice_id)
 
         assert repr(i.company) == repr(c), repr(i.company) +  " does not match " + repr(c)
 
index caa129e5ea0864c911bf8ca909315415735b6f35..3883cdcd1b0a285868d30b8307a4f4007249a5e2 100644 (file)
@@ -2,18 +2,24 @@ import testenv; testenv.configure_for_tests()
 import pickle
 import sqlalchemy.orm.attributes as attributes
 from sqlalchemy.orm.collections import collection
-from sqlalchemy import exceptions
+from sqlalchemy.orm.interfaces import AttributeExtension
+from sqlalchemy import exc as sa_exc
 from testlib import *
 from testlib import fixtures
 
-ROLLBACK_SUPPORTED=False
-
-# these test classes defined at the module
-# level to support pickling
-class MyTest(object):pass
-class MyTest2(object):pass
+# global for pickling tests
+MyTest = None
+MyTest2 = None
 
 class AttributesTest(TestBase):
+    def setUp(self):
+        global MyTest, MyTest2
+        class MyTest(object): pass
+        class MyTest2(object): pass
+
+    def tearDown(self):
+        global MyTest, MyTest2
+        MyTest, MyTest2 = None, None
 
     def test_basic(self):
         class User(object):pass
@@ -29,7 +35,7 @@ class AttributesTest(TestBase):
         u.email_address = 'lala@123.com'
 
         self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
-        u._state.commit_all()
+        attributes.instance_state(u).commit_all()
         self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
 
         u.user_name = 'heythere'
@@ -99,31 +105,33 @@ class AttributesTest(TestBase):
         class Foo(object):pass
 
         data = {'a':'this is a', 'b':12}
-        def loader(instance, keys):
+        def loader(state, keys):
             for k in keys:
-                instance.__dict__[k] = data[k]
+                state.dict[k] = data[k]
             return attributes.ATTR_WAS_SET
 
-        attributes.register_class(Foo, deferred_scalar_loader=loader)
+        attributes.register_class(Foo)
+        manager = attributes.manager_of_class(Foo)
+        manager.deferred_scalar_loader = loader
         attributes.register_attribute(Foo, 'a', uselist=False, useobject=False)
         attributes.register_attribute(Foo, 'b', uselist=False, useobject=False)
 
         f = Foo()
-        f._state.expire_attributes(None)
+        attributes.instance_state(f).expire_attributes(None)
         self.assertEquals(f.a, "this is a")
         self.assertEquals(f.b, 12)
 
         f.a = "this is some new a"
-        f._state.expire_attributes(None)
+        attributes.instance_state(f).expire_attributes(None)
         self.assertEquals(f.a, "this is a")
         self.assertEquals(f.b, 12)
 
-        f._state.expire_attributes(None)
+        attributes.instance_state(f).expire_attributes(None)
         f.a = "this is another new a"
         self.assertEquals(f.a, "this is another new a")
         self.assertEquals(f.b, 12)
 
-        f._state.expire_attributes(None)
+        attributes.instance_state(f).expire_attributes(None)
         self.assertEquals(f.a, "this is a")
         self.assertEquals(f.b, 12)
 
@@ -131,23 +139,25 @@ class AttributesTest(TestBase):
         self.assertEquals(f.a, None)
         self.assertEquals(f.b, 12)
 
-        f._state.commit_all()
+        attributes.instance_state(f).commit_all()
         self.assertEquals(f.a, None)
         self.assertEquals(f.b, 12)
 
     def test_deferred_pickleable(self):
         data = {'a':'this is a', 'b':12}
-        def loader(instance, keys):
+        def loader(state, keys):
             for k in keys:
-                instance.__dict__[k] = data[k]
+                state.dict[k] = data[k]
             return attributes.ATTR_WAS_SET
 
-        attributes.register_class(MyTest, deferred_scalar_loader=loader)
+        attributes.register_class(MyTest)
+        manager = attributes.manager_of_class(MyTest)
+        manager.deferred_scalar_loader=loader
         attributes.register_attribute(MyTest, 'a', uselist=False, useobject=False)
         attributes.register_attribute(MyTest, 'b', uselist=False, useobject=False)
 
         m = MyTest()
-        m._state.expire_attributes(None)
+        attributes.instance_state(m).expire_attributes(None)
         assert 'a' not in m.__dict__
         m2 = pickle.loads(pickle.dumps(m))
         assert 'a' not in m2.__dict__
@@ -176,7 +186,7 @@ class AttributesTest(TestBase):
         u.addresses.append(a)
 
         self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com')
-        u, a._state.commit_all()
+        u, attributes.instance_state(a).commit_all()
         self.assert_(u.user_id == 7 and u.user_name == 'john' and u.addresses[0].email_address == 'lala@123.com')
 
         u.user_name = 'heythere'
@@ -186,6 +196,45 @@ class AttributesTest(TestBase):
         u.addresses.append(a)
         self.assert_(u.user_id == 7 and u.user_name == 'heythere' and u.addresses[0].email_address == 'lala@123.com' and u.addresses[1].email_address == 'foo@bar.com')
 
+    def test_scalar_listener(self):
+        # listeners on ScalarAttributeImpl and MutableScalarAttributeImpl aren't used normally.
+        # test that they work for the benefit of user extensions
+        class Foo(object):
+            pass
+        
+        results = []
+        class ReceiveEvents(AttributeExtension):
+            def append(self, state, child, initiator):
+                assert False
+
+            def remove(self, state, child, initiator):
+                results.append(("remove", state.obj(), child))
+
+            def set(self, state, child, oldchild, initiator):
+                results.append(("set", state.obj(), child, oldchild))
+        
+        attributes.register_class(Foo)
+        attributes.register_attribute(Foo, 'x', uselist=False, mutable_scalars=False, useobject=False, extension=ReceiveEvents())
+        attributes.register_attribute(Foo, 'y', uselist=False, mutable_scalars=True, useobject=False, copy_function=lambda x:x, extension=ReceiveEvents())
+        
+        f = Foo()
+        f.x = 5
+        f.x = 17
+        del f.x
+        f.y = [1,2,3]
+        f.y = [4,5,6]
+        del f.y
+        
+        self.assertEquals(results, [
+            ('set', f, 5, None),
+            ('set', f, 17, 5),
+            ('remove', f, 17),
+            ('set', f, [1,2,3], None),
+            ('set', f, [4,5,6], [1,2,3]),
+            ('remove', f, [4,5,6])
+        ])
+        
+        
     def test_lazytrackparent(self):
         """test that the "hasparent" flag works properly when lazy loaders and backrefs are used"""
 
@@ -201,9 +250,9 @@ class AttributesTest(TestBase):
         # create objects as if they'd been freshly loaded from the database (without history)
         b = Blog()
         p1 = Post()
-        b._state.set_callable('posts', lambda:[p1])
-        p1._state.set_callable('blog', lambda:b)
-        p1, b._state.commit_all()
+        attributes.instance_state(b).set_callable('posts', lambda:[p1])
+        attributes.instance_state(p1).set_callable('blog', lambda:b)
+        p1, attributes.instance_state(b).commit_all()
 
         # no orphans (called before the lazy loaders fire off)
         assert attributes.has_parent(Blog, p1, 'posts', optimistic=True)
@@ -253,10 +302,10 @@ class AttributesTest(TestBase):
         states = set()
         class Foo(object):
             def __init__(self):
-                states.add(self._state)
+                states.add(attributes.instance_state(self))
         class Bar(Foo):
             def __init__(self):
-                states.add(self._state)
+                states.add(attributes.instance_state(self))
                 Foo.__init__(self)
 
 
@@ -283,10 +332,10 @@ class AttributesTest(TestBase):
         el = Element()
         x = Bar()
         x.element = el
-        self.assertEquals(attributes.get_history(x._state, 'element'), ([el],[], []))
-        x._state.commit_all()
+        self.assertEquals(attributes.get_history(attributes.instance_state(x), 'element'), ([el],[], []))
+        attributes.instance_state(x).commit_all()
 
-        (added, unchanged, deleted) = attributes.get_history(x._state, 'element')
+        (added, unchanged, deleted) = attributes.get_history(attributes.instance_state(x), 'element')
         assert added == []
         assert unchanged == [el]
 
@@ -312,9 +361,9 @@ class AttributesTest(TestBase):
         attributes.register_attribute(Bar, 'id', uselist=False, useobject=True)
 
         x = Foo()
-        x._state.commit_all()
+        attributes.instance_state(x).commit_all()
         x.col2.append(bar4)
-        self.assertEquals(attributes.get_history(x._state, 'col2'), ([bar4], [bar1, bar2, bar3], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(x), 'col2'), ([bar4], [bar1, bar2, bar3], []))
 
     def test_parenttrack(self):
         class Foo(object):pass
@@ -358,9 +407,9 @@ class AttributesTest(TestBase):
         attributes.register_attribute(Foo, 'element', uselist=False, copy_function=lambda x:[y for y in x], mutable_scalars=True, useobject=False)
         x = Foo()
         x.element = ['one', 'two', 'three']
-        x._state.commit_all()
+        attributes.instance_state(x).commit_all()
         x.element[1] = 'five'
-        assert x._state.is_modified()
+        assert attributes.instance_state(x).check_modified()
 
         attributes.unregister_class(Foo)
 
@@ -368,9 +417,9 @@ class AttributesTest(TestBase):
         attributes.register_attribute(Foo, 'element', uselist=False, useobject=False)
         x = Foo()
         x.element = ['one', 'two', 'three']
-        x._state.commit_all()
+        attributes.instance_state(x).commit_all()
         x.element[1] = 'five'
-        assert not x._state.is_modified()
+        assert not attributes.instance_state(x).check_modified()
 
     def test_descriptorattributes(self):
         """changeset: 1633 broke ability to use ORM to map classes with unusual
@@ -379,27 +428,31 @@ class AttributesTest(TestBase):
         This is a simple regression test to prevent that defect.
         """
         class des(object):
-            def __get__(self, instance, owner): raise AttributeError('fake attribute')
+            def __get__(self, instance, owner):
+                raise AttributeError('fake attribute')
 
         class Foo(object):
             A = des()
 
-
+        attributes.register_class(Foo)
         attributes.unregister_class(Foo)
 
     def test_collectionclasses(self):
 
         class Foo(object):pass
         attributes.register_class(Foo)
+
         attributes.register_attribute(Foo, "collection", uselist=True, typecallable=set, useobject=True)
+        assert attributes.manager_of_class(Foo).is_instrumented("collection")
         assert isinstance(Foo().collection, set)
 
         attributes.unregister_attribute(Foo, "collection")
-
+        assert not attributes.manager_of_class(Foo).is_instrumented("collection")
+        
         try:
             attributes.register_attribute(Foo, "collection", uselist=True, typecallable=dict, useobject=True)
             assert False
-        except exceptions.ArgumentError, e:
+        except sa_exc.ArgumentError, e:
             assert str(e) == "Type InstrumentedDict must elect an appender method to be a collection class"
 
         class MyDict(dict):
@@ -418,7 +471,7 @@ class AttributesTest(TestBase):
         try:
             attributes.register_attribute(Foo, "collection", uselist=True, typecallable=MyColl, useobject=True)
             assert False
-        except exceptions.ArgumentError, e:
+        except sa_exc.ArgumentError, e:
             assert str(e) == "Type MyColl must elect an appender method to be a collection class"
 
         class MyColl(object):
@@ -435,7 +488,7 @@ class AttributesTest(TestBase):
         try:
             Foo().collection
             assert True
-        except exceptions.ArgumentError, e:
+        except sa_exc.ArgumentError, e:
             assert False
 
 
@@ -512,7 +565,7 @@ class BackrefTest(TestBase):
         j.port = None
         self.assert_(p.jack is None)
 
-class DeferredBackrefTest(TestBase):
+class PendingBackrefTest(TestBase):
     def setUp(self):
         global Post, Blog, called, lazy_load
 
@@ -550,6 +603,7 @@ class DeferredBackrefTest(TestBase):
 
         b = Blog("blog 1")
         p = Post("post 4")
+        
         p.blog = b
         p = Post("post 5")
         p.blog = b
@@ -559,6 +613,22 @@ class DeferredBackrefTest(TestBase):
         # calling backref calls the callable, populates extra posts
         assert b.posts == [p1, p2, p3, Post("post 4"), Post("post 5")]
         assert called[0] == 1
+    
+    def test_lazy_history(self):
+        global lazy_load
+
+        p1, p2, p3 = Post("post 1"), Post("post 2"), Post("post 3")
+        lazy_load = [p1, p2, p3]
+        
+        b = Blog("blog 1")
+        p = Post("post 4")
+        p.blog = b
+        
+        p4 = Post("post 5")
+        p4.blog = b
+        assert called[0] == 0
+        self.assertEquals(attributes.instance_state(b).get_history('posts'), ([p, p4], [p1, p2, p3], []))
+        assert called[0] == 1
 
     def test_lazy_remove(self):
         global lazy_load
@@ -609,17 +679,17 @@ class HistoryTest(TestBase):
         attributes.register_attribute(Foo, 'someattr', uselist=False, useobject=False)
 
         f = Foo()
-        self.assertEquals(Foo.someattr.impl.get_committed_value(f._state), None)
+        self.assertEquals(Foo.someattr.impl.get_committed_value(attributes.instance_state(f)), None)
 
         f.someattr = 3
-        self.assertEquals(Foo.someattr.impl.get_committed_value(f._state), None)
+        self.assertEquals(Foo.someattr.impl.get_committed_value(attributes.instance_state(f)), None)
 
         f = Foo()
         f.someattr = 3
-        self.assertEquals(Foo.someattr.impl.get_committed_value(f._state), None)
+        self.assertEquals(Foo.someattr.impl.get_committed_value(attributes.instance_state(f)), None)
         
-        f._state.commit(['someattr'])
-        self.assertEquals(Foo.someattr.impl.get_committed_value(f._state), 3)
+        attributes.instance_state(f).commit(['someattr'])
+        self.assertEquals(Foo.someattr.impl.get_committed_value(attributes.instance_state(f)), 3)
 
     def test_scalar(self):
         class Foo(fixtures.Base):
@@ -630,48 +700,59 @@ class HistoryTest(TestBase):
 
         # case 1.  new object
         f = Foo()
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [], []))
 
         f.someattr = "hi"
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), (['hi'], [], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), (['hi'], [], []))
 
-        f._state.commit(['someattr'])
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], ['hi'], []))
+        attributes.instance_state(f).commit(['someattr'])
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], ['hi'], []))
 
         f.someattr = 'there'
 
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), (['there'], [], ['hi']))
-        f._state.commit(['someattr'])
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), (['there'], [], ['hi']))
+        attributes.instance_state(f).commit(['someattr'])
 
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], ['there'], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], ['there'], []))
 
         del f.someattr
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [], ['there']))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [], ['there']))
 
         # case 2.  object with direct dictionary settings (similar to a load operation)
         f = Foo()
         f.__dict__['someattr'] = 'new'
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], ['new'], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], ['new'], []))
 
         f.someattr = 'old'
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), (['old'], [], ['new']))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), (['old'], [], ['new']))
 
-        f._state.commit(['someattr'])
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], ['old'], []))
+        attributes.instance_state(f).commit(['someattr'])
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], ['old'], []))
 
         # setting None on uninitialized is currently a change for a scalar attribute
         # no lazyload occurs so this allows overwrite operation to proceed
         f = Foo()
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [], []))
+        print f._foostate.committed_state
         f.someattr = None
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([None], [], []))
+        print f._foostate.committed_state, f._foostate.dict
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([None], [], []))
 
         f = Foo()
         f.__dict__['someattr'] = 'new'
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], ['new'], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], ['new'], []))
         f.someattr = None
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([None], [], ['new']))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([None], [], ['new']))
 
+        # set same value twice
+        f = Foo()
+        attributes.instance_state(f).commit(['someattr'])
+        f.someattr = 'one'
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), (['one'], [], []))
+        f.someattr = 'two'
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), (['two'], [], []))
+        
+        
     def test_mutable_scalar(self):
         class Foo(fixtures.Base):
             pass
@@ -681,33 +762,33 @@ class HistoryTest(TestBase):
 
         # case 1.  new object
         f = Foo()
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [], []))
 
         f.someattr = {'foo':'hi'}
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([{'foo':'hi'}], [], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([{'foo':'hi'}], [], []))
 
-        f._state.commit(['someattr'])
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [{'foo':'hi'}], []))
-        self.assertEquals(f._state.committed_state['someattr'], {'foo':'hi'})
+        attributes.instance_state(f).commit(['someattr'])
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [{'foo':'hi'}], []))
+        self.assertEquals(attributes.instance_state(f).committed_state['someattr'], {'foo':'hi'})
 
         f.someattr['foo'] = 'there'
-        self.assertEquals(f._state.committed_state['someattr'], {'foo':'hi'})
+        self.assertEquals(attributes.instance_state(f).committed_state['someattr'], {'foo':'hi'})
 
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([{'foo':'there'}], [], [{'foo':'hi'}]))
-        f._state.commit(['someattr'])
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([{'foo':'there'}], [], [{'foo':'hi'}]))
+        attributes.instance_state(f).commit(['someattr'])
 
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [{'foo':'there'}], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [{'foo':'there'}], []))
 
         # case 2.  object with direct dictionary settings (similar to a load operation)
         f = Foo()
         f.__dict__['someattr'] = {'foo':'new'}
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [{'foo':'new'}], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [{'foo':'new'}], []))
 
         f.someattr = {'foo':'old'}
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([{'foo':'old'}], [], [{'foo':'new'}]))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([{'foo':'old'}], [], [{'foo':'new'}]))
 
-        f._state.commit(['someattr'])
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [{'foo':'old'}], []))
+        attributes.instance_state(f).commit(['someattr'])
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [{'foo':'old'}], []))
 
 
     def test_use_object(self):
@@ -729,48 +810,56 @@ class HistoryTest(TestBase):
 
         # case 1.  new object
         f = Foo()
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [None], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [None], []))
 
         f.someattr = hi
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([hi], [], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], [], []))
 
-        f._state.commit(['someattr'])
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [hi], []))
+        attributes.instance_state(f).commit(['someattr'])
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [hi], []))
 
         f.someattr = there
 
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([there], [], [hi]))
-        f._state.commit(['someattr'])
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([there], [], [hi]))
+        attributes.instance_state(f).commit(['someattr'])
 
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [there], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [there], []))
 
         del f.someattr
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([None], [], [there]))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([None], [], [there]))
 
         # case 2.  object with direct dictionary settings (similar to a load operation)
         f = Foo()
-        f.__dict__['someattr'] = new
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [new], []))
+        f.__dict__['someattr'] = 'new'
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], ['new'], []))
 
         f.someattr = old
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([old], [], [new]))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old], [], ['new']))
 
-        f._state.commit(['someattr'])
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [old], []))
+        attributes.instance_state(f).commit(['someattr'])
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [old], []))
 
         # setting None on uninitialized is currently not a change for an object attribute
         # (this is different than scalar attribute).  a lazyload has occured so if its
         # None, its really None
         f = Foo()
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [None], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [None], []))
         f.someattr = None
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [None], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [None], []))
 
         f = Foo()
-        f.__dict__['someattr'] = new
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [new], []))
+        f.__dict__['someattr'] = 'new'
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], ['new'], []))
         f.someattr = None
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([None], [], [new]))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([None], [], ['new']))
+
+        # set same value twice
+        f = Foo()
+        attributes.instance_state(f).commit(['someattr'])
+        f.someattr = 'one'
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), (['one'], [], []))
+        f.someattr = 'two'
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), (['two'], [], []))
 
     def test_object_collections_set(self):
         class Foo(fixtures.Base):
@@ -789,39 +878,39 @@ class HistoryTest(TestBase):
 
         # case 1.  new object
         f = Foo()
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [], []))
 
         f.someattr = [hi]
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([hi], [], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], [], []))
 
-        f._state.commit(['someattr'])
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [hi], []))
+        attributes.instance_state(f).commit(['someattr'])
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [hi], []))
 
         f.someattr = [there]
 
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([there], [], [hi]))
-        f._state.commit(['someattr'])
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([there], [], [hi]))
+        attributes.instance_state(f).commit(['someattr'])
 
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [there], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [there], []))
 
         f.someattr = [hi]
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([hi], [], [there]))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], [], [there]))
 
         f.someattr = [old, new]
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([old, new], [], [there]))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old, new], [], [there]))
 
         # case 2.  object with direct settings (similar to a load operation)
         f = Foo()
-        collection = attributes.init_collection(f, 'someattr')
+        collection = attributes.init_collection(attributes.instance_state(f), 'someattr')
         collection.append_without_event(new)
-        f._state.commit_all()
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [new], []))
+        attributes.instance_state(f).commit_all()
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [new], []))
 
         f.someattr = [old]
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([old], [], [new]))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old], [], [new]))
 
-        f._state.commit(['someattr'])
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [old], []))
+        attributes.instance_state(f).commit(['someattr'])
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [old], []))
 
     def test_dict_collections(self):
         class Foo(fixtures.Base):
@@ -840,16 +929,16 @@ class HistoryTest(TestBase):
         new = Bar(name='new')
 
         f = Foo()
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [], []))
 
         f.someattr['hi'] = hi
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([hi], [], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], [], []))
 
         f.someattr['there'] = there
-        self.assertEquals(tuple([set(x) for x in attributes.get_history(f._state, 'someattr')]), (set([hi, there]), set([]), set([])))
+        self.assertEquals(tuple([set(x) for x in attributes.get_history(attributes.instance_state(f), 'someattr')]), (set([hi, there]), set([]), set([])))
 
-        f._state.commit(['someattr'])
-        self.assertEquals(tuple([set(x) for x in attributes.get_history(f._state, 'someattr')]), (set([]), set([hi, there]), set([])))
+        attributes.instance_state(f).commit(['someattr'])
+        self.assertEquals(tuple([set(x) for x in attributes.get_history(attributes.instance_state(f), 'someattr')]), (set([]), set([hi, there]), set([])))
 
     def test_object_collections_mutate(self):
         class Foo(fixtures.Base):
@@ -868,65 +957,65 @@ class HistoryTest(TestBase):
 
         # case 1.  new object
         f = Foo(id=1)
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [], []))
 
         f.someattr.append(hi)
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([hi], [], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], [], []))
 
-        f._state.commit(['someattr'])
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [hi], []))
+        attributes.instance_state(f).commit(['someattr'])
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [hi], []))
 
         f.someattr.append(there)
 
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([there], [hi], []))
-        f._state.commit(['someattr'])
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([there], [hi], []))
+        attributes.instance_state(f).commit(['someattr'])
 
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [hi, there], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [hi, there], []))
 
         f.someattr.remove(there)
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [hi], [there]))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [hi], [there]))
 
         f.someattr.append(old)
         f.someattr.append(new)
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([old, new], [hi], [there]))
-        f._state.commit(['someattr'])
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [hi, old, new], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old, new], [hi], [there]))
+        attributes.instance_state(f).commit(['someattr'])
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [hi, old, new], []))
 
         f.someattr.pop(0)
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [old, new], [hi]))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [old, new], [hi]))
 
         # case 2.  object with direct settings (similar to a load operation)
         f = Foo()
         f.__dict__['id'] = 1
-        collection = attributes.init_collection(f, 'someattr')
+        collection = attributes.init_collection(attributes.instance_state(f), 'someattr')
         collection.append_without_event(new)
-        f._state.commit_all()
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [new], []))
+        attributes.instance_state(f).commit_all()
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [new], []))
 
         f.someattr.append(old)
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([old], [new], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([old], [new], []))
 
-        f._state.commit(['someattr'])
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [new, old], []))
+        attributes.instance_state(f).commit(['someattr'])
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [new, old], []))
 
         f = Foo()
-        collection = attributes.init_collection(f, 'someattr')
+        collection = attributes.init_collection(attributes.instance_state(f), 'someattr')
         collection.append_without_event(new)
-        f._state.commit_all()
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [new], []))
+        attributes.instance_state(f).commit_all()
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [new], []))
 
         f.id = 1
         f.someattr.remove(new)
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([], [], [new]))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([], [], [new]))
 
         # case 3.  mixing appends with sets
         f = Foo()
         f.someattr.append(hi)
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([hi], [], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi], [], []))
         f.someattr.append(there)
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([hi, there], [], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([hi, there], [], []))
         f.someattr = [there]
-        self.assertEquals(attributes.get_history(f._state, 'someattr'), ([there], [], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'someattr'), ([there], [], []))
 
     def test_collections_via_backref(self):
         class Foo(fixtures.Base):
@@ -941,19 +1030,19 @@ class HistoryTest(TestBase):
 
         f1 = Foo()
         b1 = Bar()
-        self.assertEquals(attributes.get_history(f1._state, 'bars'), ([], [], []))
-        self.assertEquals(attributes.get_history(b1._state, 'foo'), ([], [None], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([], [], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(b1), 'foo'), ([], [None], []))
 
         #b1.foo = f1
         f1.bars.append(b1)
-        self.assertEquals(attributes.get_history(f1._state, 'bars'), ([b1], [], []))
-        self.assertEquals(attributes.get_history(b1._state, 'foo'), ([f1], [], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b1], [], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(b1), 'foo'), ([f1], [], []))
 
         b2 = Bar()
         f1.bars.append(b2)
-        self.assertEquals(attributes.get_history(f1._state, 'bars'), ([b1, b2], [], []))
-        self.assertEquals(attributes.get_history(b1._state, 'foo'), ([f1], [], []))
-        self.assertEquals(attributes.get_history(b2._state, 'foo'), ([f1], [], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b1, b2], [], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(b1), 'foo'), ([f1], [], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(b2), 'foo'), ([f1], [], []))
 
     def test_lazy_backref_collections(self):
         class Foo(fixtures.Base):
@@ -978,17 +1067,17 @@ class HistoryTest(TestBase):
         f = Foo()
         bar4 = Bar()
         bar4.foo = f
-        self.assertEquals(attributes.get_history(f._state, 'bars'), ([bar4], [bar1, bar2, bar3], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bars'), ([bar4], [bar1, bar2, bar3], []))
 
         lazy_load = None
         f = Foo()
         bar4 = Bar()
         bar4.foo = f
-        self.assertEquals(attributes.get_history(f._state, 'bars'), ([bar4], [], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bars'), ([bar4], [], []))
 
         lazy_load = [bar1, bar2, bar3]
-        f._state.expire_attributes(['bars'])
-        self.assertEquals(attributes.get_history(f._state, 'bars'), ([], [bar1, bar2, bar3], []))
+        attributes.instance_state(f).expire_attributes(['bars'])
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bars'), ([], [bar1, bar2, bar3], []))
 
     def test_collections_via_lazyload(self):
         class Foo(fixtures.Base):
@@ -1011,26 +1100,26 @@ class HistoryTest(TestBase):
 
         f = Foo()
         f.bars = []
-        self.assertEquals(attributes.get_history(f._state, 'bars'), ([], [], [bar1, bar2, bar3]))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bars'), ([], [], [bar1, bar2, bar3]))
 
         f = Foo()
         f.bars.append(bar4)
-        self.assertEquals(attributes.get_history(f._state, 'bars'), ([bar4], [bar1, bar2, bar3], []) )
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bars'), ([bar4], [bar1, bar2, bar3], []) )
 
         f = Foo()
         f.bars.remove(bar2)
-        self.assertEquals(attributes.get_history(f._state, 'bars'), ([], [bar1, bar3], [bar2]))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bars'), ([], [bar1, bar3], [bar2]))
         f.bars.append(bar4)
-        self.assertEquals(attributes.get_history(f._state, 'bars'), ([bar4], [bar1, bar3], [bar2]))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bars'), ([bar4], [bar1, bar3], [bar2]))
 
         f = Foo()
         del f.bars[1]
-        self.assertEquals(attributes.get_history(f._state, 'bars'), ([], [bar1, bar3], [bar2]))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bars'), ([], [bar1, bar3], [bar2]))
 
         lazy_load = None
         f = Foo()
         f.bars.append(bar2)
-        self.assertEquals(attributes.get_history(f._state, 'bars'), ([bar2], [], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bars'), ([bar2], [], []))
 
     def test_scalar_via_lazyload(self):
         class Foo(fixtures.Base):
@@ -1051,24 +1140,24 @@ class HistoryTest(TestBase):
 
         f = Foo()
         self.assertEquals(f.bar, "hi")
-        self.assertEquals(attributes.get_history(f._state, 'bar'), ([], ["hi"], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bar'), ([], ["hi"], []))
 
         f = Foo()
         f.bar = None
-        self.assertEquals(attributes.get_history(f._state, 'bar'), ([None], [], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bar'), ([None], [], []))
 
         f = Foo()
         f.bar = "there"
-        self.assertEquals(attributes.get_history(f._state, 'bar'), (["there"], [], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bar'), (["there"], [], []))
         f.bar = "hi"
-        self.assertEquals(attributes.get_history(f._state, 'bar'), (["hi"], [], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bar'), (["hi"], [], []))
 
         f = Foo()
         self.assertEquals(f.bar, "hi")
         del f.bar
-        self.assertEquals(attributes.get_history(f._state, 'bar'), ([], [], ["hi"]))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bar'), ([], [], ["hi"]))
         assert f.bar is None
-        self.assertEquals(attributes.get_history(f._state, 'bar'), ([None], [], ["hi"]))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bar'), ([None], [], ["hi"]))
 
     def test_scalar_object_via_lazyload(self):
         class Foo(fixtures.Base):
@@ -1092,24 +1181,25 @@ class HistoryTest(TestBase):
         # operations
 
         f = Foo()
-        self.assertEquals(attributes.get_history(f._state, 'bar'), ([], [bar1], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bar'), ([], [bar1], []))
 
         f = Foo()
         f.bar = None
-        self.assertEquals(attributes.get_history(f._state, 'bar'), ([None], [], [bar1]))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bar'), ([None], [], [bar1]))
 
         f = Foo()
         f.bar = bar2
-        self.assertEquals(attributes.get_history(f._state, 'bar'), ([bar2], [], [bar1]))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bar'), ([bar2], [], [bar1]))
         f.bar = bar1
-        self.assertEquals(attributes.get_history(f._state, 'bar'), ([], [bar1], []))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bar'), ([], [bar1], []))
 
         f = Foo()
         self.assertEquals(f.bar, bar1)
         del f.bar
-        self.assertEquals(attributes.get_history(f._state, 'bar'), ([None], [], [bar1]))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bar'), ([None], [], [bar1]))
         assert f.bar is None
-        self.assertEquals(attributes.get_history(f._state, 'bar'), ([None], [], [bar1]))
+        self.assertEquals(attributes.get_history(attributes.instance_state(f), 'bar'), ([None], [], [bar1]))
 
+    
 if __name__ == "__main__":
     testenv.main()
index 7a68a4d58a994d0e3b83595761fe4601ed110ad7..4a2dc4419347f4107cf55201d87297de0279f843 100644 (file)
@@ -1,8 +1,9 @@
 import testenv; testenv.configure_for_tests()
 
 from sqlalchemy import *
-from sqlalchemy import exceptions
+from sqlalchemy import exc as sa_exc
 from sqlalchemy.orm import *
+from sqlalchemy.orm import attributes, exc as orm_exc
 from testlib import *
 from testlib import fixtures
 
@@ -45,7 +46,7 @@ class O2MCascadeTest(fixtures.FixtureTest):
         try:
             sess.flush()
             assert False
-        except exceptions.FlushError, e:
+        except orm_exc.FlushError, e:
             assert "is an orphan" in str(e)
 
     def test_delete(self):
@@ -571,7 +572,7 @@ class UnsavedOrphansTest(ORMTest):
         s.save(a)
         try:
             s.flush()
-        except exceptions.FlushError, e:
+        except orm_exc.FlushError, e:
             pass
         assert a.address_id is None, "Error: address should not be persistent"
 
@@ -794,7 +795,7 @@ class DoubleParentOrphanTest(ORMTest):
         try:
             session.flush()
             assert False
-        except exceptions.FlushError, e:
+        except orm_exc.FlushError, e:
             assert True
 
 class CollectionAssignmentOrphanTest(ORMTest):
@@ -831,7 +832,7 @@ class CollectionAssignmentOrphanTest(ORMTest):
         self.assertEquals(sess.query(A).get(a1.id), A(name='a1', bs=[B(name='b1'), B(name='b2'), B(name='b3')]))
 
         a1 = sess.query(A).get(a1.id)
-        assert not class_mapper(B)._is_orphan(a1.bs[0])
+        assert not class_mapper(B)._is_orphan(attributes.instance_state(a1.bs[0]))
         a1.bs[0].foo='b2modified'
         a1.bs[1].foo='b3modified'
         sess.flush()
index 711dc730ba35a817aadfaf36a5110f43de795873..94e36f3668e892cd7213bd8500f9188f5dd165c6 100644 (file)
@@ -2,7 +2,7 @@ import testenv; testenv.configure_for_tests()
 import sys
 from operator import and_
 from sqlalchemy import *
-import sqlalchemy.exceptions as exceptions
+import sqlalchemy.exc as sa_exc
 from sqlalchemy.orm import create_session, mapper, relation, \
     interfaces, attributes
 import sqlalchemy.orm.collections as collections
@@ -933,13 +933,13 @@ class CollectionsTest(TestBase):
             self._test_adapter(dict, dictable_entity,
                                to_set=lambda c: set(c.values()))
             self.assert_(False)
-        except exceptions.ArgumentError, e:
+        except sa_exc.ArgumentError, e:
             self.assert_(e.args[0] == 'Type InstrumentedDict must elect an appender method to be a collection class')
 
         try:
             self._test_dict(dict)
             self.assert_(False)
-        except exceptions.ArgumentError, e:
+        except sa_exc.ArgumentError, e:
             self.assert_(e.args[0] == 'Type InstrumentedDict must elect an appender method to be a collection class')
 
     def test_dict_subclass(self):
index 31b6860623c4b72bf8d029862f776f5c2d70e5b0..59d636baec3525b4b0d6c105bcf3423662b9da11 100644 (file)
@@ -1,6 +1,6 @@
 import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
-from sqlalchemy import exceptions
+from sqlalchemy import exc as sa_exc
 from sqlalchemy.orm import *
 from testlib import *
 
@@ -118,7 +118,7 @@ class CompileTest(TestBase, AssertsExecutionResults):
         try:
             class_mapper(Product).compile()
             assert False
-        except exceptions.ArgumentError, e:
+        except sa_exc.ArgumentError, e:
             assert str(e).index("Error creating backref ") > -1
 
     def testthree(self):
@@ -177,7 +177,7 @@ class CompileTest(TestBase, AssertsExecutionResults):
         try:
             compile_mappers()
             assert False
-        except exceptions.ArgumentError, e:
+        except sa_exc.ArgumentError, e:
             assert str(e).index("Error creating backref") > -1
 
 if __name__ == '__main__':
index f956a4529b0ee86a3f0a9f32a9daddcf6055bcbb..8b5173d3ccf0afda907495543a7590a29bb22999 100644 (file)
@@ -173,22 +173,25 @@ class InheritTestOne(TestBase, AssertsExecutionResults):
             Column("child2_data", String(50))
             )
         meta.create_all()
+        
     def tearDownAll(self):
         meta.drop_all()
+        
     def testmanytooneonly(self):
         """test similar to SelfReferentialTest.testmanytooneonly"""
+        
         class Parent(object):
-                pass
+            pass
 
         mapper(Parent, parent)
 
         class Child1(Parent):
-                pass
+            pass
 
         mapper(Child1, child1, inherits=Parent)
 
         class Child2(Parent):
-                pass
+            pass
 
         mapper(Child2, child2, properties={
                         "child1": relation(Child1,
@@ -216,7 +219,9 @@ class InheritTestOne(TestBase, AssertsExecutionResults):
 class InheritTestTwo(ORMTest):
     """the fix in BiDirectionalManyToOneTest raised this issue, regarding
     the 'circular sort' containing UOWTasks that were still polymorphic, which could
-    create duplicate entries in the final sort"""
+    create duplicate entries in the final sort
+    
+    """
     def define_tables(self, metadata):
         global a, b, c
         a = Table('a', metadata,
@@ -235,6 +240,7 @@ class InheritTestTwo(ORMTest):
             Column('data', String(30)),
             Column('aid', Integer, ForeignKey('a.id', use_alter=True, name="foo")),
             )
+            
     def test_flush(self):
         class A(object):pass
         class B(A):pass
@@ -484,17 +490,19 @@ class OneToManyManyToOneTest(TestBase, AssertsExecutionResults):
 
     def testcycle(self):
         """this test has a peculiar aspect in that it doesnt create as many dependent
-        relationships as the other tests, and revealed a small glitch in the circular dependency sorting."""
+        relationships as the other tests, and revealed a small glitch in the circular dependency sorting.
+        
+        """
         class Person(object):
-         pass
+            pass
 
         class Ball(object):
-         pass
+            pass
 
         Ball.mapper = mapper(Ball, ball)
         Person.mapper = mapper(Person, person, properties= dict(
-         balls = relation(Ball.mapper, primaryjoin=ball.c.person_id==person.c.id, remote_side=ball.c.person_id),
-         favorateBall = relation(Ball.mapper, primaryjoin=person.c.favorite_ball_id==ball.c.id, remote_side=person.c.favorite_ball_id),
+             balls = relation(Ball.mapper, primaryjoin=ball.c.person_id==person.c.id, remote_side=ball.c.person_id),
+             favorateBall = relation(Ball.mapper, primaryjoin=person.c.favorite_ball_id==ball.c.id, remote_side=ball.c.id),
          )
         )
 
@@ -502,10 +510,9 @@ class OneToManyManyToOneTest(TestBase, AssertsExecutionResults):
         p = Person()
         p.balls.append(b)
         sess = create_session()
-        sess.save(b)
-        sess.save(b)
+        sess.save(p)
         sess.flush()
-
+        
     def testpostupdate_m2o(self):
         """tests a cycle between two rows, with a post_update on the many-to-one"""
         class Person(object):
@@ -860,6 +867,7 @@ class SelfReferentialPostUpdateTest2(TestBase, AssertsExecutionResults):
         a_table.create()
     def tearDownAll(self):
         a_table.drop()
+
     def testbasic(self):
         """test that post_update remembers to be involved in update operations as well,
         since it replaces the normal dependency processing completely [ticket:413]"""
diff --git a/test/orm/deprecations.py b/test/orm/deprecations.py
new file mode 100644 (file)
index 0000000..d6caaa1
--- /dev/null
@@ -0,0 +1,394 @@
+"""The collection of modern alternatives to deprecated & removed functionality.
+
+Collects specimens of old ORM code and explicitly covers the recommended
+modern (i.e. not deprecated) alternative to them.  The tests snippets here can
+be migrated directly to the wiki, docs, etc.
+
+"""
+import testenv; testenv.configure_for_tests()
+from sqlalchemy import *
+from sqlalchemy.orm import *
+from testlib import *
+
+
+users, addresses = None, None
+session = None
+
+class Base(object):
+    def __init__(self, **kw):
+        for k, v in kw.iteritems():
+            setattr(self, k, v)
+
+class User(Base): pass
+class Address(Base): pass
+
+
+class QueryAlternativesTest(ORMTest):
+    '''Collects modern idioms for Queries
+
+    The docstring for each test case serves as miniature documentation about
+    the deprecated use case, and the test body illustrates (and covers) the
+    intended replacement code to accomplish the same task.
+
+    Documenting the "old way" including the argument signature helps these
+    cases remain useful to readers even after the deprecated method has been
+    removed from the modern codebase.
+
+    Format:
+
+    def test_deprecated_thing(self):
+        """Query.methodname(old, arg, **signature)
+
+        output = session.query(User).deprecatedmethod(inputs)
+
+        """
+        # 0.4+
+        output = session.query(User).newway(inputs)
+        assert output is correct
+
+        # 0.5+
+        output = session.query(User).evennewerway(inputs)
+        assert output is correct
+
+    '''
+    keep_mappers = True
+    keep_data = True
+
+    def define_tables(self, metadata):
+        global users_table, addresses_table
+        users_table = Table(
+            'users', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('name', String(64)))
+
+        addresses_table = Table(
+            'addresses', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('user_id', Integer, ForeignKey('users.id')),
+            Column('email_address', String(128)),
+            Column('purpose', String(16)),
+            Column('bounces', Integer, default=0))
+
+    def setup_mappers(self):
+        mapper(User, users_table, properties=dict(
+            addresses=relation(Address, backref='user'),
+            ))
+        mapper(Address, addresses_table)
+
+    def insert_data(self):
+        user_cols = ('id', 'name')
+        user_rows = ((1, 'jack'), (2, 'ed'), (3, 'fred'), (4, 'chuck'))
+        users_table.insert().execute(
+            [dict(zip(user_cols, row)) for row in user_rows])
+
+        add_cols = ('id', 'user_id', 'email_address', 'purpose', 'bounces')
+        add_rows = (
+            (1, 1, 'jack@jack.home', 'Personal', 0),
+            (2, 1, 'jack@jack.bizz', 'Work', 1),
+            (3, 2, 'ed@foo.bar', 'Personal', 0),
+            (4, 3, 'fred@the.fred', 'Personal', 10))
+
+        addresses_table.insert().execute(
+            [dict(zip(add_cols, row)) for row in add_rows])
+
+    def setUp(self):
+        super(QueryAlternativesTest, self).setUp()
+        global session
+        if session is None:
+            session = create_session()
+
+    def tearDown(self):
+        super(QueryAlternativesTest, self).tearDown()
+        session.clear()
+
+    ######################################################################
+
+    def test_apply_max(self):
+        """Query.apply_max(col)
+
+        max = session.query(Address).apply_max(Address.bounces)
+
+        """
+        # 0.5.0
+        maxes = list(session.query(Address).values(func.max(Address.bounces)))
+        max = maxes[0][0]
+        assert max == 10
+
+        max = session.query(func.max(Address.bounces)).one()[0]
+        assert max == 10
+
+    def test_apply_min(self):
+        """Query.apply_min(col)
+
+        min = session.query(Address).apply_min(Address.bounces)
+
+        """
+        # 0.5.0
+        mins = list(session.query(Address).values(func.min(Address.bounces)))
+        min = mins[0][0]
+        assert min == 0
+
+        min = session.query(func.min(Address.bounces)).one()[0]
+        assert min == 0
+
+    def test_apply_avg(self):
+        """Query.apply_avg(col)
+
+        avg = session.query(Address).apply_avg(Address.bounces)
+
+        """
+        avgs = list(session.query(Address).values(func.avg(Address.bounces)))
+        avg = avgs[0][0]
+        assert avg > 0 and avg < 10
+
+        avg = session.query(func.avg(Address.bounces)).one()[0]
+        assert avg > 0 and avg < 10
+
+    def test_apply_sum(self):
+        """Query.apply_sum(col)
+
+        avg = session.query(Address).apply_avg(Address.bounces)
+
+        """
+        avgs = list(session.query(Address).values(func.avg(Address.bounces)))
+        avg = avgs[0][0]
+        assert avg > 0 and avg < 10
+
+        avg = session.query(func.avg(Address.bounces)).one()[0]
+        assert avg > 0 and avg < 10
+
+    def test_count_by(self):
+        """Query.count_by(*args, **params)
+
+        num = session.query(Address).count_by(purpose='Personal')
+
+        # old-style implicit *_by join
+        num = session.query(User).count_by(purpose='Personal')
+
+        """
+        num = session.query(Address).filter_by(purpose='Personal').count()
+        assert num == 3, num
+
+        num = (session.query(User).join('addresses').
+               filter(Address.purpose=='Personal')).count()
+        assert num == 3, num
+
+    def test_count_whereclause(self):
+        """Query.count(whereclause=None, params=None, **kwargs)
+
+        num = session.query(Address).count(address_table.c.bounces > 1)
+
+        """
+        num = session.query(Address).filter(Address.bounces > 1).count()
+        assert num == 1, num
+
+    def test_execute(self):
+        """Query.execute(clauseelement, params=None, *args, **kwargs)
+
+        users = session.query(User).execute(users_table.select())
+
+        """
+        users = session.query(User).from_statement(users_table.select()).all()
+        assert len(users) == 4
+
+    def test_get_by(self):
+        """Query.get_by(*args, **params)
+
+        user = session.query(User).get_by(name='ed')
+
+        # 0.3-style implicit *_by join
+        user = session.query(User).get_by(email_addresss='fred@the.fred')
+
+        """
+        user = session.query(User).filter_by(name='ed').first()
+        assert user.name == 'ed'
+
+        user = (session.query(User).join('addresses').
+                filter(Address.email_address=='fred@the.fred')).first()
+        assert user.name == 'fred'
+
+        user = session.query(User).filter(
+            User.addresses.any(Address.email_address=='fred@the.fred')).first()
+        assert user.name == 'fred'
+
+    def test_instances_entities(self):
+        """Query.instances(cursor, *mappers_or_columns, **kwargs)
+
+        sel = users_table.join(addresses_table).select(use_labels=True)
+        res = session.query(User).instances(sel.execute(), Address)
+
+        """
+        sel = users_table.join(addresses_table).select(use_labels=True)
+        res = session.query(User, Address).instances(sel.execute())
+
+        assert len(res) == 4
+        cola, colb = res[0]
+        assert isinstance(cola, User) and isinstance(colb, Address)
+
+
+    def test_join_by(self):
+        """Query.join_by(*args, **params)
+
+        TODO
+        """
+
+    def test_join_to(self):
+        """Query.join_to(key)
+
+        TODO
+        """
+
+    def test_join_via(self):
+        """Query.join_via(keys)
+
+        TODO
+        """
+
+    def test_list(self):
+        """Query.list()
+
+        users = session.query(User).list()
+
+        """
+        users = session.query(User).all()
+        assert len(users) == 4
+
+    def test_scalar(self):
+        """Query.scalar()
+
+        user = session.query(User).filter(User.id==1).scalar()
+
+        """
+        user = session.query(User).filter(User.id==1).first()
+        assert user.id==1
+
+    def test_select(self):
+        """Query.select(arg=None, **kwargs)
+
+        users = session.query(User).select(users_table.c.name != None)
+
+        """
+        users = session.query(User).filter(User.name != None).all()
+        assert len(users) == 4
+
+    def test_select_by(self):
+        """Query.select_by(*args, **params)
+
+        users = session.query(User).select_by(name='fred')
+
+        # 0.3 magic join on *_by methods
+        users = session.query(User).select_by(email_address='fred@the.fred')
+
+        """
+        users = session.query(User).filter_by(name='fred').all()
+        assert len(users) == 1
+
+        users = session.query(User).filter(User.name=='fred').all()
+        assert len(users) == 1
+
+        users = (session.query(User).join('addresses').
+                 filter_by(email_address='fred@the.fred')).all()
+        assert len(users) == 1
+
+        users = session.query(User).filter(User.addresses.any(
+            Address.email_address == 'fred@the.fred')).all()
+        assert len(users) == 1
+
+    def test_selectfirst(self):
+        """Query.selectfirst(arg=None, **kwargs)
+
+        bounced = session.query(Address).selectfirst(
+          addresses_table.c.bounces > 0)
+
+        """
+        bounced = session.query(Address).filter(Address.bounces > 0).first()
+        assert bounced.bounces > 0
+
+    def test_selectfirst_by(self):
+        """Query.selectfirst_by(*args, **params)
+
+        onebounce = session.query(Address).selectfirst_by(bounces=1)
+
+        # 0.3 magic join on *_by methods
+        onebounce_user = session.query(User).selectfirst_by(bounces=1)
+
+        """
+        onebounce = session.query(Address).filter_by(bounces=1).first()
+        assert onebounce.bounces == 1
+
+        onebounce_user = (session.query(User).join('addresses').
+                          filter_by(bounces=1)).first()
+        assert onebounce_user.name == 'jack'
+
+        onebounce_user = (session.query(User).join('addresses').
+                          filter(Address.bounces == 1)).first()
+        assert onebounce_user.name == 'jack'
+
+        onebounce_user = session.query(User).filter(User.addresses.any(
+            Address.bounces == 1)).first()
+        assert onebounce_user.name == 'jack'
+
+
+    def test_selectone(self):
+        """Query.selectone(arg=None, **kwargs)
+
+        ed = session.query(User).selectone(users_table.c.name == 'ed')
+
+        """
+        ed = session.query(User).filter(User.name == 'jack').one()
+
+    def test_selectone_by(self):
+        """Query.selectone_by
+
+        ed = session.query(User).selectone_by(name='ed')
+
+        # 0.3 magic join on *_by methods
+        ed = session.query(User).selectone_by(email_address='ed@foo.bar')
+
+        """
+        ed = session.query(User).filter_by(name='jack').one()
+
+        ed = session.query(User).filter(User.name == 'jack').one()
+
+        ed = session.query(User).join('addresses').filter(
+            Address.email_address == 'ed@foo.bar').one()
+
+        ed = session.query(User).filter(User.addresses.any(
+            Address.email_address == 'ed@foo.bar')).one()
+
+    def test_select_statement(self):
+        """Query.select_statement(statement, **params)
+
+        users = session.query(User).select_statement(users_table.select())
+
+        """
+        users = session.query(User).from_statement(users_table.select()).all()
+        assert len(users) == 4
+
+    def test_select_text(self):
+        """Query.select_text(text, **params)
+
+        users = session.query(User).select_text('SELECT * FROM users')
+
+        """
+        users = session.query(User).from_statement('SELECT * FROM users').all()
+        assert len(users) == 4
+
+    def test_select_whereclause(self):
+        """Query.select_whereclause(whereclause=None, params=None, **kwargs)
+
+
+        users = session,query(User).select_whereclause(users.c.name=='ed')
+        users = session.query(User).select_whereclause("name='ed'")
+
+        """
+        users = session.query(User).filter(User.name=='ed').all()
+        assert len(users) == 1 and users[0].name == 'ed'
+
+        users = session.query(User).filter("name='ed'").all()
+        assert len(users) == 1 and users[0].name == 'ed'
+
+
+
+if __name__ == '__main__':
+    testenv.main()
index c38b27823806422ec30bbb5e5c9c9de9f7fa8c34..0c3f1a95d0376bdda35a316d531a4b4dc41e647b 100644 (file)
@@ -129,7 +129,25 @@ class FlushTest(FixtureTest):
             User(name='jack', addresses=[Address(email_address='lala@hoho.com')]),
             User(name='ed', addresses=[Address(email_address='foo@bar.com')])
         ] == sess.query(User).all()
+    
+    def test_rollback(self):
+        class Fixture(Base):
+            pass
 
+        mapper(User, users, properties={
+            'addresses':dynamic_loader(mapper(Address, addresses))
+        })
+        sess = create_session(autoexpire=False, autocommit=False, autoflush=True)
+        u1 = User(name='jack')
+        u1.addresses.append(Address(email_address='lala@hoho.com'))
+        sess.save(u1)
+        sess.flush()
+        sess.commit()
+        u1.addresses.append(Address(email_address='foo@bar.com'))
+        self.assertEquals(u1.addresses.all(), [Address(email_address='lala@hoho.com'), Address(email_address='foo@bar.com')])
+        sess.rollback()
+        self.assertEquals(u1.addresses.all(), [Address(email_address='lala@hoho.com')])
+        
     @testing.fails_on('maxdb')
     def test_delete_nocascade(self):
         mapper(User, users, properties={
index 418df83dda7a8c4fca046ccf0984a8ab3218953c..94723a20bbfcd2c2713eae8143ece79173154a6e 100644 (file)
@@ -6,6 +6,7 @@ from sqlalchemy.orm import *
 from testlib import *
 from testlib.fixtures import *
 from query import QueryTest
+from sqlalchemy.orm import attributes
 
 class EagerTest(FixtureTest):
     keep_mappers = False
@@ -31,8 +32,8 @@ class EagerTest(FixtureTest):
 
         sess = create_session()
         user = sess.query(User).get(7)
-        assert getattr(User, 'addresses').hasparent(user.addresses[0], optimistic=True)
-        assert not class_mapper(Address)._is_orphan(user.addresses[0])
+        assert getattr(User, 'addresses').hasparent(attributes.instance_state(user.addresses[0]), optimistic=True)
+        assert not class_mapper(Address)._is_orphan(attributes.instance_state(user.addresses[0]))
 
     def test_orderby(self):
         mapper(User, users, properties = {
@@ -129,12 +130,18 @@ class EagerTest(FixtureTest):
         })
         mapper(User, users)
 
-        assert [Address(id=1, user=User(id=7)), Address(id=4, user=User(id=8)), Address(id=5, user=User(id=9))] == create_session().query(Address).filter(Address.id.in_([1, 4, 5])).all()
-
-        assert [Address(id=1, user=User(id=7)), Address(id=4, user=User(id=8)), Address(id=5, user=User(id=9))] == create_session().query(Address).filter(Address.id.in_([1, 4, 5])).limit(3).all()
-
         sess = create_session()
-        a = sess.query(Address).get(1)
+        
+        for q in [
+            sess.query(Address).filter(Address.id.in_([1, 4, 5])),
+            sess.query(Address).filter(Address.id.in_([1, 4, 5])).limit(3)
+        ]:
+            sess.clear()
+            self.assertEquals(q.all(), 
+                [Address(id=1, user=User(id=7)), Address(id=4, user=User(id=8)), Address(id=5, user=User(id=9))]
+            )
+
+        a = sess.query(Address).filter(Address.id==1).first()
         def go():
             assert a.user_id==7
         # assert that the eager loader added 'user_id' to the row
@@ -150,12 +157,17 @@ class EagerTest(FixtureTest):
             'user_id':deferred(addresses.c.user_id),
         })
         mapper(User, users, properties={'addresses':relation(Address, lazy=False)})
+        
+        for q in [
+            sess.query(User).filter(User.id==7),
+            sess.query(User).filter(User.id==7).limit(1)
+        ]:
+            sess.clear()
+            self.assertEquals(q.all(), 
+                [User(id=7, addresses=[Address(id=1)])]
+            )
 
-        assert [User(id=7, addresses=[Address(id=1)])] == create_session().query(User).filter(User.id==7).all()
-
-        assert [User(id=7, addresses=[Address(id=1)])] == create_session().query(User).limit(1).filter(User.id==7).all()
-
-        sess = create_session()
+        sess.clear()
         u = sess.query(User).get(7)
         def go():
             assert u.addresses[0].user_id==7
@@ -173,9 +185,9 @@ class EagerTest(FixtureTest):
         mapper(Dingaling, dingalings, properties={
             'address_id':deferred(dingalings.c.address_id)
         })
-        sess = create_session()
+        sess.clear()
         def go():
-            u = sess.query(User).limit(1).get(8)
+            u = sess.query(User).get(8)
             assert User(id=8, addresses=[Address(id=2, dingalings=[Dingaling(id=1)]), Address(id=3), Address(id=4)]) == u
         self.assert_sql_count(testing.db, go, 1)
 
@@ -192,11 +204,11 @@ class EagerTest(FixtureTest):
         self.assert_sql_count(testing.db, go, 1)
 
         def go():
-            assert fixtures.item_keyword_result[0:2] == q.join('keywords').filter(keywords.c.name == 'red').all()
+            assert fixtures.item_keyword_result[0:2] == q.join('keywords').filter(Keyword.name == 'red').all()
         self.assert_sql_count(testing.db, go, 1)
 
         def go():
-            assert fixtures.item_keyword_result[0:2] == q.join('keywords', aliased=True).filter(keywords.c.name == 'red').all()
+            assert fixtures.item_keyword_result[0:2] == q.join('keywords', aliased=True).filter(Keyword.name == 'red').all()
         self.assert_sql_count(testing.db, go, 1)
 
 
@@ -364,7 +376,7 @@ class EagerTest(FixtureTest):
         q = sess.query(User)
 
         def go():
-            l = q.filter(s.c.u2_id==User.c.id).distinct().all()
+            l = q.filter(s.c.u2_id==User.id).distinct().all()
             assert fixtures.user_address_result == l
         self.assert_sql_count(testing.db, go, 1)
 
@@ -377,7 +389,7 @@ class EagerTest(FixtureTest):
 
         sess = create_session()
         q = sess.query(Item)
-        l = q.filter((Item.c.description=='item 2') | (Item.c.description=='item 5') | (Item.c.description=='item 3')).\
+        l = q.filter((Item.description=='item 2') | (Item.description=='item 5') | (Item.description=='item 3')).\
             order_by(Item.id).limit(2).all()
 
         assert fixtures.item_keyword_result[1:3] == l
@@ -607,7 +619,7 @@ class AddEntityTest(FixtureTest):
               )
         ]
 
-    def test_basic(self):
+    def test_mapper_configured(self):
         mapper(User, users, properties={
             'addresses':relation(Address, lazy=False),
             'orders':relation(Order)
@@ -620,8 +632,9 @@ class AddEntityTest(FixtureTest):
 
 
         sess = create_session()
+        oalias = aliased(Order)
         def go():
-            ret = sess.query(User).add_entity(Order).join('orders', aliased=True).order_by(User.id).order_by(Order.id).all()
+            ret = sess.query(User, oalias).join(('orders', oalias)).order_by(User.id, oalias.id).all()
             self.assertEquals(ret, self._assert_result())
         self.assert_sql_count(testing.db, go, 1)
 
@@ -638,14 +651,15 @@ class AddEntityTest(FixtureTest):
 
         sess = create_session()
 
+        oalias = aliased(Order)
         def go():
-            ret = sess.query(User).options(eagerload('addresses')).add_entity(Order).join('orders', aliased=True).order_by(User.id).order_by(Order.id).all()
+            ret = sess.query(User, oalias).options(eagerload('addresses')).join(('orders', oalias)).order_by(User.id, oalias.id).all()
             self.assertEquals(ret, self._assert_result())
         self.assert_sql_count(testing.db, go, 6)
 
         sess.clear()
         def go():
-            ret = sess.query(User).options(eagerload('addresses')).add_entity(Order).options(eagerload('items', Order)).join('orders', aliased=True).order_by(User.id).order_by(Order.id).all()
+            ret = sess.query(User, oalias).options(eagerload('addresses'), eagerload(oalias.items)).join(('orders', oalias)).order_by(User.id, oalias.id).all()
             self.assertEquals(ret, self._assert_result())
         self.assert_sql_count(testing.db, go, 1)
 
@@ -933,11 +947,94 @@ class SelfReferentialM2MEagerTest(ORMTest):
         sess.flush()
         sess.clear()
 
-#        l = sess.query(Widget).filter(Widget.name=='w1').all()
-#        print l
         assert [Widget(name='w1', children=[Widget(name='w2')])] == sess.query(Widget).filter(Widget.name==u'w1').all()
 
+class MixedEntitiesTest(FixtureTest, AssertsCompiledSQL):
+    keep_mappers = True
+    keep_data = True
+    
+    def setup_mappers(self):
+        mapper(User, users, properties={
+            'addresses':relation(Address, backref='user'),
+            'orders':relation(Order, backref='user'), # o2m, m2o
+        })
+        mapper(Address, addresses)
+        mapper(Order, orders, properties={
+            'items':relation(Item, secondary=order_items, order_by=items.c.id),  #m2m
+        })
+        mapper(Item, items, properties={
+            'keywords':relation(Keyword, secondary=item_keywords) #m2m
+        })
+        mapper(Keyword, keywords)
+    
+    def test_two_entities(self):
+        sess = create_session()
+
+        # two FROM clauses
+        def go():
+            self.assertEquals(
+                [
+                    (User(id=9, addresses=[Address(id=5)]), Order(id=2, items=[Item(id=1), Item(id=2), Item(id=3)])),
+                    (User(id=9, addresses=[Address(id=5)]), Order(id=4, items=[Item(id=1), Item(id=5)])),
+                ],
+                sess.query(User, Order).filter(User.id==Order.user_id).options(eagerload(User.addresses), eagerload(Order.items)).filter(User.id==9).all(),
+            )
+        self.assert_sql_count(testing.db, go, 1)
+
+        # one FROM clause
+        def go():
+            self.assertEquals(
+                [
+                    (User(id=9, addresses=[Address(id=5)]), Order(id=2, items=[Item(id=1), Item(id=2), Item(id=3)])),
+                    (User(id=9, addresses=[Address(id=5)]), Order(id=4, items=[Item(id=1), Item(id=5)])),
+                ],
+                sess.query(User, Order).join(User.orders).options(eagerload(User.addresses), eagerload(Order.items)).filter(User.id==9).all(),
+            )
+        self.assert_sql_count(testing.db, go, 1)
+    
+    def test_aliased_entity(self):
+        sess = create_session()
+        
+        oalias = aliased(Order)
+        
+        # two FROM clauses
+        def go():
+            self.assertEquals(
+                [
+                    (User(id=9, addresses=[Address(id=5)]), Order(id=2, items=[Item(id=1), Item(id=2), Item(id=3)])),
+                    (User(id=9, addresses=[Address(id=5)]), Order(id=4, items=[Item(id=1), Item(id=5)])),
+                ],
+                sess.query(User, oalias).filter(User.id==oalias.user_id).options(eagerload(User.addresses), eagerload(oalias.items)).filter(User.id==9).all(),
+            )
+        self.assert_sql_count(testing.db, go, 1)
+
+        # one FROM clause
+        def go():
+            self.assertEquals(
+                [
+                    (User(id=9, addresses=[Address(id=5)]), Order(id=2, items=[Item(id=1), Item(id=2), Item(id=3)])),
+                    (User(id=9, addresses=[Address(id=5)]), Order(id=4, items=[Item(id=1), Item(id=5)])),
+                ],
+                sess.query(User, oalias).join((User.orders, oalias)).options(eagerload(User.addresses), eagerload(oalias.items)).filter(User.id==9).all(),
+            )
+        self.assert_sql_count(testing.db, go, 1)
+        
+        from sqlalchemy.engine.default import DefaultDialect
+        
+        # improper setup: oalias in the columns clause but join to usual orders alias.  
+        # this should create two FROM clauses even though the query has a from_clause set up via the join
+        self.assert_compile(sess.query(User, oalias).join(User.orders).options(eagerload(oalias.items)).with_labels().statement, 
+        "SELECT users.id AS users_id, users.name AS users_name, orders_1.id AS orders_1_id, "\
+        "orders_1.user_id AS orders_1_user_id, orders_1.address_id AS orders_1_address_id, "\
+        "orders_1.description AS orders_1_description, orders_1.isopen AS orders_1_isopen, items_1.id AS items_1_id, "\
+        "items_1.description AS items_1_description FROM users JOIN orders ON users.id = orders.user_id, "\
+        "orders AS orders_1 LEFT OUTER JOIN order_items AS order_items_1 ON orders_1.id = order_items_1.order_id "\
+        "LEFT OUTER JOIN items AS items_1 ON items_1.id = order_items_1.item_id ORDER BY users.id, items_1.id",
+        dialect=DefaultDialect()
+        )
+        
 class CyclicalInheritingEagerTest(ORMTest):
+
     def define_tables(self, metadata):
         global t1, t2
         t1 = Table('t1', metadata,
@@ -1041,22 +1138,14 @@ class SubqueryTest(ORMTest):
             session.save(User(name='bar', tags=[Tag(score1=5.0, score2=4.0), Tag(score1=50.0, score2=1.0), Tag(score1=15.0, score2=2.0)]))
             session.flush()
             session.clear()
+            
+            for user in session.query(User).all():
+                self.assertEquals(user.query_score, user.prop_score)
 
             def go():
-                for user in session.query(User).all():
-                    self.assertEquals(user.query_score, user.prop_score)
-            self.assert_sql_count(testing.db, go, 1)
-
-
-            # fails for non labeled (fixed in 0.5):
-            if labeled:
-                def go():
-                    u = session.query(User).filter_by(name='joe').one()
-                    self.assertEquals(u.query_score, u.prop_score)
-                self.assert_sql_count(testing.db, go, 1)
-            else:
                 u = session.query(User).filter_by(name='joe').one()
                 self.assertEquals(u.query_score, u.prop_score)
+            self.assert_sql_count(testing.db, go, 1)
             
             for t in (tags_table, users_table):
                 t.delete().execute()
index 760f8fce901c34509ee5d27233b61261131dbed1..d9c9e4002f6166b0fbd08478f122e031501257ee 100644 (file)
@@ -1,19 +1,18 @@
 import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
 from sqlalchemy.orm import *
-from sqlalchemy.ext.sessioncontext import SessionContext
 from testlib import *
 from testlib.tables import *
+from testlib import fixtures
 
 class EntityTest(TestBase, AssertsExecutionResults):
     """tests mappers that are constructed based on "entity names", which allows the same class
     to have multiple primary mappers """
 
-    @testing.uses_deprecated('SessionContext')
     def setUpAll(self):
         global user1, user2, address1, address2, metadata, ctx
         metadata = MetaData(testing.db)
-        ctx = SessionContext(create_session)
+        ctx = scoped_session(create_session)
 
         user1 = Table('user1', metadata,
             Column('user_id', Integer, Sequence('user1_id_seq', optional=True),
@@ -45,28 +44,31 @@ class EntityTest(TestBase, AssertsExecutionResults):
     def tearDownAll(self):
         metadata.drop_all()
     def tearDown(self):
-        ctx.current.clear()
+        ctx.clear()
         clear_mappers()
         for t in metadata.table_iterator(reverse=True):
             t.delete().execute()
 
-    @testing.uses_deprecated('SessionContextExt')
     def testbasic(self):
         """tests a pair of one-to-many mapper structures, establishing that both
         parent and child objects honor the "entity_name" attribute attached to the object
         instances."""
-        class User(object):pass
-        class Address(object):pass
+        class User(object):
+            def __init__(self, **kw):
+                pass
+        class Address(object):
+            def __init__(self, **kw):
+                pass
 
-        a1mapper = mapper(Address, address1, entity_name='address1', extension=ctx.mapper_extension)
-        a2mapper = mapper(Address, address2, entity_name='address2', extension=ctx.mapper_extension)
+        a1mapper = mapper(Address, address1, entity_name='address1', extension=ctx.extension)
+        a2mapper = mapper(Address, address2, entity_name='address2', extension=ctx.extension)
         u1mapper = mapper(User, user1, entity_name='user1', properties ={
             'addresses':relation(a1mapper)
-        }, extension=ctx.mapper_extension)
+        }, extension=ctx.extension)
         u2mapper =mapper(User, user2, entity_name='user2', properties={
             'addresses':relation(a2mapper)
-        }, extension=ctx.mapper_extension)
-
+        }, extension=ctx.extension)
+        
         u1 = User(_sa_entity_name='user1')
         u1.name = 'this is user 1'
         a1 = Address(_sa_entity_name='address1')
@@ -79,22 +81,22 @@ class EntityTest(TestBase, AssertsExecutionResults):
         a2.email='a2@foo.com'
         u2.addresses.append(a2)
 
-        ctx.current.flush()
+        ctx.flush()
         assert user1.select().execute().fetchall() == [(u1.user_id, u1.name)]
         assert user2.select().execute().fetchall() == [(u2.user_id, u2.name)]
         assert address1.select().execute().fetchall() == [(a1.address_id, u1.user_id, 'a1@foo.com')]
         assert address2.select().execute().fetchall() == [(a1.address_id, u2.user_id, 'a2@foo.com')]
 
-        ctx.current.clear()
-        u1list = ctx.current.query(User, entity_name='user1').all()
-        u2list = ctx.current.query(User, entity_name='user2').all()
+        ctx.clear()
+        u1list = ctx.query(User, entity_name='user1').all()
+        u2list = ctx.query(User, entity_name='user2').all()
         assert len(u1list) == len(u2list) == 1
         assert u1list[0] is not u2list[0]
         assert len(u1list[0].addresses) == len(u2list[0].addresses) == 1
 
-        u1 = ctx.current.query(User, entity_name='user1').first()
-        ctx.current.refresh(u1)
-        ctx.current.expire(u1)
+        u1 = ctx.query(User, entity_name='user1').first()
+        ctx.refresh(u1)
+        ctx.expire(u1)
 
 
     def testcascade(self):
@@ -142,18 +144,24 @@ class EntityTest(TestBase, AssertsExecutionResults):
 
     def testpolymorphic(self):
         """tests that entity_name can be used to have two kinds of relations on the same class."""
-        class User(object):pass
-        class Address1(object):pass
-        class Address2(object):pass
+        class User(object):
+            def __init__(self, **kw):
+                pass
+        class Address1(object):
+            def __init__(self, **kw):
+                pass
+        class Address2(object):
+            def __init__(self, **kw):
+                pass
 
-        a1mapper = mapper(Address1, address1, extension=ctx.mapper_extension)
-        a2mapper = mapper(Address2, address2, extension=ctx.mapper_extension)
+        a1mapper = mapper(Address1, address1, extension=ctx.extension)
+        a2mapper = mapper(Address2, address2, extension=ctx.extension)
         u1mapper = mapper(User, user1, entity_name='user1', properties ={
             'addresses':relation(a1mapper)
-        }, extension=ctx.mapper_extension)
+        }, extension=ctx.extension)
         u2mapper =mapper(User, user2, entity_name='user2', properties={
             'addresses':relation(a2mapper)
-        }, extension=ctx.mapper_extension)
+        }, extension=ctx.extension)
 
         u1 = User(_sa_entity_name='user1')
         u1.name = 'this is user 1'
@@ -167,15 +175,15 @@ class EntityTest(TestBase, AssertsExecutionResults):
         a2.email='a2@foo.com'
         u2.addresses.append(a2)
 
-        ctx.current.flush()
+        ctx.flush()
         assert user1.select().execute().fetchall() == [(u1.user_id, u1.name)]
         assert user2.select().execute().fetchall() == [(u2.user_id, u2.name)]
         assert address1.select().execute().fetchall() == [(a1.address_id, u1.user_id, 'a1@foo.com')]
         assert address2.select().execute().fetchall() == [(a1.address_id, u2.user_id, 'a2@foo.com')]
 
-        ctx.current.clear()
-        u1list = ctx.current.query(User, entity_name='user1').all()
-        u2list = ctx.current.query(User, entity_name='user2').all()
+        ctx.clear()
+        u1list = ctx.query(User, entity_name='user1').all()
+        u2list = ctx.query(User, entity_name='user2').all()
         assert len(u1list) == len(u2list) == 1
         assert u1list[0] is not u2list[0]
         assert len(u1list[0].addresses) == len(u2list[0].addresses) == 1
@@ -186,13 +194,15 @@ class EntityTest(TestBase, AssertsExecutionResults):
 
     def testpolymorphic_deferred(self):
         """test that deferred columns load properly using entity names"""
-        class User(object):pass
+        class User(object):
+            def __init__(self, **kwargs):
+                pass
         u1mapper = mapper(User, user1, entity_name='user1', properties ={
             'name':deferred(user1.c.name)
-        }, extension=ctx.mapper_extension)
+        }, extension=ctx.extension)
         u2mapper =mapper(User, user2, entity_name='user2', properties={
             'name':deferred(user2.c.name)
-        }, extension=ctx.mapper_extension)
+        }, extension=ctx.extension)
 
         u1 = User(_sa_entity_name='user1')
         u1.name = 'this is user 1'
@@ -200,13 +210,13 @@ class EntityTest(TestBase, AssertsExecutionResults):
         u2 = User(_sa_entity_name='user2')
         u2.name='this is user 2'
 
-        ctx.current.flush()
+        ctx.flush()
         assert user1.select().execute().fetchall() == [(u1.user_id, u1.name)]
         assert user2.select().execute().fetchall() == [(u2.user_id, u2.name)]
 
-        ctx.current.clear()
-        u1list = ctx.current.query(User, entity_name='user1').all()
-        u2list = ctx.current.query(User, entity_name='user2').all()
+        ctx.clear()
+        u1list = ctx.query(User, entity_name='user1').all()
+        u2list = ctx.query(User, entity_name='user2').all()
         assert len(u1list) == len(u2list) == 1
         assert u1list[0] is not u2list[0]
         # the deferred column load requires that setup_loader() check that the correct DeferredColumnLoader
@@ -214,6 +224,49 @@ class EntityTest(TestBase, AssertsExecutionResults):
         assert u1list[0].name == 'this is user 1'
         assert u2list[0].name == 'this is user 2'
 
+class SelfReferentialTest(ORMTest):
+    def define_tables(self, metadata):
+        global nodes
+            
+        nodes = Table('nodes', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('parent_id', Integer, ForeignKey('nodes.id')),
+            Column('data', String(50)),
+            Column('type', String(50)),
+            )
+
+    # fails inconsistently.  entity name needs deterministic 
+    # instrumentation.
+    def dont_test_relation(self):
+        class Node(fixtures.Base):
+            pass
+        
+        foonodes = nodes.select().where(nodes.c.type=='foo').alias()
+        barnodes = nodes.select().where(nodes.c.type=='bar').alias()
+        
+        # TODO: the order of instrumentation here is not deterministic;
+        # therefore the test fails sporadically since "Node.data" references
+        # different mappers at different times
+        m1 = mapper(Node, nodes)
+        m2 = mapper(Node, foonodes, entity_name='foo')
+        m3 = mapper(Node, barnodes, entity_name='bar')
+        
+        m1.add_property('foonodes', relation(m2, primaryjoin=nodes.c.id==foonodes.c.parent_id, 
+            backref=backref('foo_parent', remote_side=nodes.c.id, primaryjoin=nodes.c.id==foonodes.c.parent_id)))
+        m1.add_property('barnodes', relation(m3, primaryjoin=nodes.c.id==barnodes.c.parent_id, 
+            backref=backref('bar_parent', remote_side=nodes.c.id, primaryjoin=nodes.c.id==barnodes.c.parent_id)))
+        
+        sess = create_session()
+        
+        n1 = Node(data='n1', type='bat')
+        n1.foonodes.append(Node(data='n2', type='foo'))
+        Node(data='n3', type='bar', bar_parent=n1)
+        sess.save(n1)
+        sess.flush()
+        sess.clear()
+        
+        self.assertEquals(sess.query(Node, entity_name="bar").one(), Node(data='n3'))
+        self.assertEquals(sess.query(Node).filter(Node.data=='n1').one(), Node(data='n1', foonodes=[Node(data='n2')], barnodes=[Node(data='n3')]))
 
 if __name__ == "__main__":
     testenv.main()
index 58c05a3820eaeae0f3a638e0740895015c9b9418..e9960786651126610cd33c3cf76fc72959f5f6f5 100644 (file)
@@ -2,8 +2,9 @@
 
 import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
-from sqlalchemy import exceptions
+from sqlalchemy import exc as sa_exc
 from sqlalchemy.orm import *
+from sqlalchemy.orm import attributes, exc as orm_exc
 from testlib import *
 from testlib.fixtures import *
 import gc
@@ -39,12 +40,12 @@ class ExpireTest(FixtureTest):
         sess.expire(u)
         # object isnt refreshed yet, using dict to bypass trigger
         assert u.__dict__.get('name') != 'jack'
-        assert 'name' in u._state.expired_attributes
+        assert 'name' in attributes.instance_state(u).expired_attributes
 
         sess.query(User).all()
         # test that it refreshed
         assert u.__dict__['name'] == 'jack'
-        assert 'name' not in u._state.expired_attributes
+        assert 'name' not in attributes.instance_state(u).expired_attributes
 
         def go():
             assert u.name == 'jack'
@@ -56,8 +57,49 @@ class ExpireTest(FixtureTest):
         u = s.get(User, 7)
         s.clear()
 
-        self.assertRaisesMessage(exceptions.InvalidRequestError, r"is not persistent within this Session", lambda: s.expire(u))
+        self.assertRaisesMessage(sa_exc.InvalidRequestError, r"is not persistent within this Session", s.expire, u)
+    
+    def test_get_refreshes(self):
+        mapper(User, users)
+        s = create_session()
+        u = s.get(User, 10)
+        s.expire_all()
 
+        def go():
+            u = s.get(User, 10)  # get() refreshes
+        self.assert_sql_count(testing.db, go, 1)
+        def go():
+            self.assertEquals(u.name, 'chuck')  # attributes unexpired
+        self.assert_sql_count(testing.db, go, 0)
+        def go():
+            u = s.get(User, 10)  # expire flag reset, so not expired
+        self.assert_sql_count(testing.db, go, 0)
+
+        s.expire_all()
+        users.delete().where(User.id==10).execute()
+        
+        # object is gone, get() returns None
+        assert u in s
+        assert s.get(User, 10) is None
+        assert u not in s # and expunges
+    
+        # add it back
+        s.add(u)
+        # nope, raises ObjectDeletedError
+        self.assertRaises(orm_exc.ObjectDeletedError, getattr, u, 'name')
+        
+    def test_refresh_cancels_expire(self):
+        mapper(User, users)
+        s = create_session()
+        u = s.get(User, 7)
+        s.expire(u)
+        s.refresh(u)
+        
+        def go():
+            u = s.get(User, 7)
+            self.assertEquals(u.name, 'jack')
+        self.assert_sql_count(testing.db, go, 0)
+        
     def test_expire_doesntload_on_set(self):
         mapper(User, users)
 
@@ -79,18 +121,16 @@ class ExpireTest(FixtureTest):
 
         sess.expire(u, attribute_names=['name'])
         sess.expunge(u)
-        try:
-            u.name
-        except exceptions.InvalidRequestError, e:
-            assert str(e) == "Instance <class 'testlib.fixtures.User'> is not bound to a Session, and no contextual session is established; attribute refresh operation cannot proceed"
+        self.assertRaises(sa_exc.UnboundExecutionError, getattr, u, 'name')
     
-    def test_pending_doesnt_raise(self):
+    def test_pending_raises(self):
+        # this was the opposite in 0.4, but the reasoning there seemed off.
+        # expiring a pending instance makes no sense, so should raise
         mapper(User, users)
         sess = create_session()
         u = User(id=15)
         sess.save(u)
-        sess.expire(u, ['name'])
-        assert u.name is None
+        self.assertRaises(sa_exc.InvalidRequestError, sess.expire, u, ['name'])
         
     def test_no_instance_key(self):
         # this tests an artificial condition such that 
@@ -103,11 +143,11 @@ class ExpireTest(FixtureTest):
 
         sess.expire(u, attribute_names=['name'])
         sess.expunge(u)
-        del u._instance_key
+        attributes.instance_state(u).key = None
         assert 'name' not in u.__dict__
         sess.save(u)
         assert u.name == 'jack'
-        
+
     def test_expire_preserves_changes(self):
         """test that the expire load operation doesn't revert post-expire changes"""
 
@@ -163,7 +203,7 @@ class ExpireTest(FixtureTest):
 
         orders.update(id=3).execute(description='order 3 modified')
         assert o.isopen == 1
-        assert o._state.dict['description'] == 'order 3 modified'
+        assert attributes.instance_state(o).dict['description'] == 'order 3 modified'
         def go():
             sess.flush()
         self.assert_sql_count(testing.db, go, 0)
@@ -180,7 +220,7 @@ class ExpireTest(FixtureTest):
         u.addresses[0].email_address = 'someotheraddress'
         s.expire(u)
         u.name
-        print u._state.dict
+        print attributes.instance_state(u).dict
         assert u.addresses[0].email_address == 'ed@wood.com'
 
     def test_expired_lazy(self):
@@ -307,28 +347,28 @@ class ExpireTest(FixtureTest):
         sess.expire(o, attribute_names=['description'])
         assert 'id' in o.__dict__
         assert 'description' not in o.__dict__
-        assert o._state.dict['isopen'] == 1
+        assert attributes.instance_state(o).dict['isopen'] == 1
 
         orders.update(orders.c.id==3).execute(description='order 3 modified')
 
         def go():
             assert o.description == 'order 3 modified'
         self.assert_sql_count(testing.db, go, 1)
-        assert o._state.dict['description'] == 'order 3 modified'
+        assert attributes.instance_state(o).dict['description'] == 'order 3 modified'
 
         o.isopen = 5
         sess.expire(o, attribute_names=['description'])
         assert 'id' in o.__dict__
         assert 'description' not in o.__dict__
         assert o.__dict__['isopen'] == 5
-        assert o._state.committed_state['isopen'] == 1
+        assert attributes.instance_state(o).committed_state['isopen'] == 1
 
         def go():
             assert o.description == 'order 3 modified'
         self.assert_sql_count(testing.db, go, 1)
         assert o.__dict__['isopen'] == 5
-        assert o._state.dict['description'] == 'order 3 modified'
-        assert o._state.committed_state['isopen'] == 1
+        assert attributes.instance_state(o).dict['description'] == 'order 3 modified'
+        assert attributes.instance_state(o).committed_state['isopen'] == 1
 
         sess.flush()
 
@@ -578,44 +618,8 @@ class PolymorphicExpireTest(ORMTest):
             {'person_id':3, 'status':'old engineer'},
         )
 
-    def test_poly_select(self):
-        mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person')
-        mapper(Engineer, engineers, inherits=Person, polymorphic_identity='engineer')
-        
-        sess = create_session()
-        [p1, e1, e2] = sess.query(Person).order_by(people.c.person_id).all()
-        
-        sess.expire(p1)
-        sess.expire(e1, ['status'])
-        sess.expire(e2)
-        
-        for p in [p1, e2]:
-            assert 'name' not in p.__dict__
-        
-        assert 'name' in e1.__dict__
-        assert 'status' not in e2.__dict__
-        assert 'status' not in e1.__dict__
-        
-        e1.name = 'new engineer name'
-        
-        def go():
-            sess.query(Person).all()
-        self.assert_sql_count(testing.db, go, 3)
-        
-        for p in [p1, e1, e2]:
-            assert 'name' in p.__dict__
-        
-        assert 'status' in e2.__dict__
-        assert 'status' in e1.__dict__
-        def go():
-            assert e1.name == 'new engineer name'
-            assert e2.name == 'engineer2'
-            assert e1.status == 'new engineer'
-        self.assert_sql_count(testing.db, go, 0)
-        self.assertEquals(Engineer.name.get_history(e1), (['new engineer name'], [], ['engineer1']))
-        
     def test_poly_deferred(self):
-        mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person', polymorphic_fetch='deferred')
+        mapper(Person, people, polymorphic_on=people.c.type, polymorphic_identity='person')
         mapper(Engineer, engineers, inherits=Person, polymorphic_identity='engineer')
 
         sess = create_session()
@@ -700,7 +704,7 @@ class RefreshTest(FixtureTest):
         s = create_session()
         u = s.get(User, 7)
         s.clear()
-        self.assertRaisesMessage(exceptions.InvalidRequestError, r"is not persistent within this Session", lambda: s.refresh(u))
+        self.assertRaisesMessage(sa_exc.InvalidRequestError, r"is not persistent within this Session", lambda: s.refresh(u))
         
     def test_refresh_expired(self):
         mapper(User, users)
diff --git a/test/orm/extendedattr.py b/test/orm/extendedattr.py
new file mode 100644 (file)
index 0000000..a5c2c4a
--- /dev/null
@@ -0,0 +1,303 @@
+import testenv; testenv.configure_for_tests()
+import pickle
+from sqlalchemy import util
+import sqlalchemy.orm.attributes as attributes
+from sqlalchemy.orm.collections import collection
+from sqlalchemy.orm.attributes import set_attribute, get_attribute, del_attribute, is_instrumented
+from sqlalchemy.orm import clear_mappers
+from sqlalchemy.orm import InstrumentationManager
+
+from testlib import *
+
+class MyTypesManager(InstrumentationManager):
+
+    def instrument_attribute(self, class_, key, attr):
+        pass
+
+    def install_descriptor(self, class_, key, attr):
+        pass
+
+    def uninstall_descriptor(self, class_, key):
+        pass
+
+    def instrument_collection_class(self, class_, key, collection_class):
+        return MyListLike
+
+    def get_instance_dict(self, class_, instance):
+        return instance._goofy_dict
+
+    def initialize_instance_dict(self, class_, instance):
+        instance.__dict__['_goofy_dict'] = {}
+
+    def install_state(self, class_, instance, state):
+        instance.__dict__['_my_state'] = state
+
+    def state_getter(self, class_):
+        return lambda instance: instance.__dict__['_my_state']
+
+class MyListLike(list):
+    # add @appender, @remover decorators as needed
+    _sa_iterator = list.__iter__
+    def _sa_appender(self, item, _sa_initiator=None):
+        if _sa_initiator is not False:
+            self._sa_adapter.fire_append_event(item, _sa_initiator)
+        list.append(self, item)
+    append = _sa_appender
+    def _sa_remover(self, item, _sa_initiator=None):
+        self._sa_adapter.fire_pre_remove_event(_sa_initiator)
+        if _sa_initiator is not False:
+            self._sa_adapter.fire_remove_event(item, _sa_initiator)
+        list.remove(self, item)
+    remove = _sa_remover
+
+class MyBaseClass(object):
+    __sa_instrumentation_manager__ = InstrumentationManager
+
+class MyClass(object):
+
+    # This proves that a staticmethod will work here; don't
+    # flatten this back to a class assignment!
+    def __sa_instrumentation_manager__(cls):
+        return MyTypesManager(cls)
+
+    __sa_instrumentation_manager__ = staticmethod(__sa_instrumentation_manager__)
+    
+    # This proves SA can handle a class with non-string dict keys
+    locals()[42] = 99   # Don't remove this line!
+
+    def __init__(self, **kwargs):
+        for k in kwargs:
+            setattr(self, k, kwargs[k])
+
+    def __getattr__(self, key):
+        if is_instrumented(self, key):
+            return get_attribute(self, key)
+        else:
+            try:
+                return self._goofy_dict[key]
+            except KeyError:
+                raise AttributeError(key)
+
+    def __setattr__(self, key, value):
+        if is_instrumented(self, key):
+            set_attribute(self, key, value)
+        else:
+            self._goofy_dict[key] = value
+
+    def __hasattr__(self, key):
+        if is_instrumented(self, key):
+            return True
+        else:
+            return key in self._goofy_dict
+
+    def __delattr__(self, key):
+        if is_instrumented(self, key):
+            del_attribute(self, key)
+        else:
+            del self._goofy_dict[key]
+
+class UserDefinedExtensionTest(TestBase):
+    def tearDownAll(self):
+        clear_mappers()
+        attributes._install_lookup_strategy(util.symbol('native'))
+
+    def test_basic(self):
+        for base in (object, MyBaseClass, MyClass):
+            class User(base):
+                pass
+
+            attributes.register_class(User)
+            attributes.register_attribute(User, 'user_id', uselist = False, useobject=False)
+            attributes.register_attribute(User, 'user_name', uselist = False, useobject=False)
+            attributes.register_attribute(User, 'email_address', uselist = False, useobject=False)
+
+            u = User()
+            u.user_id = 7
+            u.user_name = 'john'
+            u.email_address = 'lala@123.com'
+
+            self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
+            attributes.instance_state(u).commit_all()
+            self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
+
+            u.user_name = 'heythere'
+            u.email_address = 'foo@bar.com'
+            self.assert_(u.user_id == 7 and u.user_name == 'heythere' and u.email_address == 'foo@bar.com')
+
+    def test_deferred(self):
+        for base in (object, MyBaseClass, MyClass):
+            class Foo(base):pass
+
+            data = {'a':'this is a', 'b':12}
+            def loader(state, keys):
+                for k in keys:
+                    state.dict[k] = data[k]
+                return attributes.ATTR_WAS_SET
+
+            attributes.register_class(Foo)
+            manager = attributes.manager_of_class(Foo)
+            manager.deferred_scalar_loader = loader
+            attributes.register_attribute(Foo, 'a', uselist=False, useobject=False)
+            attributes.register_attribute(Foo, 'b', uselist=False, useobject=False)
+            
+            assert Foo in attributes.instrumentation_registry.state_finders
+            f = Foo()
+            attributes.instance_state(f).expire_attributes(None)
+            self.assertEquals(f.a, "this is a")
+            self.assertEquals(f.b, 12)
+
+            f.a = "this is some new a"
+            attributes.instance_state(f).expire_attributes(None)
+            self.assertEquals(f.a, "this is a")
+            self.assertEquals(f.b, 12)
+
+            attributes.instance_state(f).expire_attributes(None)
+            f.a = "this is another new a"
+            self.assertEquals(f.a, "this is another new a")
+            self.assertEquals(f.b, 12)
+
+            attributes.instance_state(f).expire_attributes(None)
+            self.assertEquals(f.a, "this is a")
+            self.assertEquals(f.b, 12)
+
+            del f.a
+            self.assertEquals(f.a, None)
+            self.assertEquals(f.b, 12)
+
+            attributes.instance_state(f).commit_all()
+            self.assertEquals(f.a, None)
+            self.assertEquals(f.b, 12)
+
+    def test_inheritance(self):
+        """tests that attributes are polymorphic"""
+
+        for base in (object, MyBaseClass, MyClass):
+            class Foo(base):pass
+            class Bar(Foo):pass
+
+            attributes.register_class(Foo)
+            attributes.register_class(Bar)
+
+            def func1():
+                print "func1"
+                return "this is the foo attr"
+            def func2():
+                print "func2"
+                return "this is the bar attr"
+            def func3():
+                print "func3"
+                return "this is the shared attr"
+            attributes.register_attribute(Foo, 'element', uselist=False, callable_=lambda o:func1, useobject=True)
+            attributes.register_attribute(Foo, 'element2', uselist=False, callable_=lambda o:func3, useobject=True)
+            attributes.register_attribute(Bar, 'element', uselist=False, callable_=lambda o:func2, useobject=True)
+
+            x = Foo()
+            y = Bar()
+            assert x.element == 'this is the foo attr'
+            assert y.element == 'this is the bar attr', y.element
+            assert x.element2 == 'this is the shared attr'
+            assert y.element2 == 'this is the shared attr'
+
+    def test_collection_with_backref(self):
+        for base in (object, MyBaseClass, MyClass):
+            class Post(base):pass
+            class Blog(base):pass
+
+            attributes.register_class(Post)
+            attributes.register_class(Blog)
+            attributes.register_attribute(Post, 'blog', uselist=False, extension=attributes.GenericBackrefExtension('posts'), trackparent=True, useobject=True)
+            attributes.register_attribute(Blog, 'posts', uselist=True, extension=attributes.GenericBackrefExtension('blog'), trackparent=True, useobject=True)
+            b = Blog()
+            (p1, p2, p3) = (Post(), Post(), Post())
+            b.posts.append(p1)
+            b.posts.append(p2)
+            b.posts.append(p3)
+            self.assert_(b.posts == [p1, p2, p3])
+            self.assert_(p2.blog is b)
+
+            p3.blog = None
+            self.assert_(b.posts == [p1, p2])
+            p4 = Post()
+            p4.blog = b
+            self.assert_(b.posts == [p1, p2, p4])
+
+            p4.blog = b
+            p4.blog = b
+            self.assert_(b.posts == [p1, p2, p4])
+
+            # assert no failure removing None
+            p5 = Post()
+            p5.blog = None
+            del p5.blog
+
+    def test_history(self):
+        for base in (object, MyBaseClass, MyClass):
+            class Foo(base):
+                pass
+            class Bar(base):
+                pass
+
+            attributes.register_class(Foo)
+            attributes.register_class(Bar)
+            attributes.register_attribute(Foo, "name", uselist=False, useobject=False)
+            attributes.register_attribute(Foo, "bars", uselist=True, trackparent=True, useobject=True)
+            attributes.register_attribute(Bar, "name", uselist=False, useobject=False)
+
+
+            f1 = Foo()
+            f1.name = 'f1'
+
+            self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'name'), (['f1'], [], []))
+
+            b1 = Bar()
+            b1.name = 'b1'
+            f1.bars.append(b1)
+            self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b1], [], []))
+
+            attributes.instance_state(f1).commit_all()
+            attributes.instance_state(b1).commit_all()
+
+            self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'name'), ([], ['f1'], []))
+            self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([], [b1], []))
+
+            f1.name = 'f1mod'
+            b2 = Bar()
+            b2.name = 'b2'
+            f1.bars.append(b2)
+            self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'name'), (['f1mod'], [], ['f1']))
+            self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b2], [b1], []))
+            f1.bars.remove(b1)
+            self.assertEquals(attributes.get_history(attributes.instance_state(f1), 'bars'), ([b2], [], [b1]))
+
+    def test_null_instrumentation(self):
+        class Foo(MyBaseClass):
+            pass
+        attributes.register_class(Foo)
+        attributes.register_attribute(Foo, "name", uselist=False, useobject=False)
+        attributes.register_attribute(Foo, "bars", uselist=True, trackparent=True, useobject=True)
+
+        assert Foo.name == attributes.manager_of_class(Foo).get_inst('name')
+        assert Foo.bars == attributes.manager_of_class(Foo).get_inst('bars')
+
+    def test_alternate_finders(self):
+        """Ensure the generic finder front-end deals with edge cases."""
+
+        class Unknown(object): pass
+        class Known(MyBaseClass): pass
+
+        attributes.register_class(Known)
+        k, u = Known(), Unknown()
+
+        assert attributes.manager_of_class(Unknown) is None
+        assert attributes.manager_of_class(Known) is not None
+        assert attributes.manager_of_class(None) is None
+
+        assert attributes.instance_state(k) is not None
+        self.assertRaises((AttributeError, KeyError),
+                          attributes.instance_state, u)
+        self.assertRaises((AttributeError, KeyError),
+                          attributes.instance_state, None)
+
+
+if __name__ == '__main__':
+    testing.main()
index aced8f626feec3530d7e1c39ce6defc2f2da5d6f..88793f7435450039dd41b2f101a8adc8dde19496 100644 (file)
@@ -1,7 +1,6 @@
 import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
 from sqlalchemy.orm import *
-from sqlalchemy import exceptions
 from testlib import *
 import testlib.tables as tables
 
@@ -35,8 +34,8 @@ class GenerativeQueryTest(TestBase):
 
     def test_selectby(self):
         res = create_session(bind=testing.db).query(Foo).filter_by(range=5)
-        assert res.order_by([Foo.c.bar])[0].bar == 5
-        assert res.order_by([desc(Foo.c.bar)])[0].bar == 95
+        assert res.order_by([Foo.bar])[0].bar == 5
+        assert res.order_by([desc(Foo.bar)])[0].bar == 95
 
     @testing.unsupported('mssql')
     @testing.fails_on('maxdb')
@@ -60,8 +59,8 @@ class GenerativeQueryTest(TestBase):
         assert query.count() == 100
         assert query.filter(foo.c.bar<30).min(foo.c.bar) == 0
         assert query.filter(foo.c.bar<30).max(foo.c.bar) == 29
-        assert query.filter(foo.c.bar<30).apply_max(foo.c.bar).first() == 29
-        assert query.filter(foo.c.bar<30).apply_max(foo.c.bar).one() == 29
+        assert query.filter(foo.c.bar<30).values(func.max(foo.c.bar)).next()[0] == 29
+        assert query.filter(foo.c.bar<30).values(func.max(foo.c.bar)).next()[0] == 29
 
     def test_aggregate_1(self):
         if (testing.against('mysql') and
@@ -77,22 +76,20 @@ class GenerativeQueryTest(TestBase):
         avg = query.filter(foo.c.bar < 30).avg(foo.c.bar)
         assert round(avg, 1) == 14.5
 
-    @testing.fails_on('firebird', 'mssql')
-    @testing.uses_deprecated('Call to deprecated function apply_avg')
     def test_aggregate_3(self):
         query = create_session(bind=testing.db).query(Foo)
 
-        avg_f = query.filter(foo.c.bar<30).apply_avg(foo.c.bar).first()
+        avg_f = query.filter(foo.c.bar<30).values(func.avg(foo.c.bar)).next()[0]
         assert round(avg_f, 1) == 14.5
 
-        avg_o = query.filter(foo.c.bar<30).apply_avg(foo.c.bar).one()
+        avg_o = query.filter(foo.c.bar<30).values(func.avg(foo.c.bar)).next()[0]
         assert round(avg_o, 1) == 14.5
 
     def test_filter(self):
         query = create_session(bind=testing.db).query(Foo)
         assert query.count() == 100
-        assert query.filter(Foo.c.bar < 30).count() == 30
-        res2 = query.filter(Foo.c.bar < 30).filter(Foo.c.bar > 10)
+        assert query.filter(Foo.bar < 30).count() == 30
+        res2 = query.filter(Foo.bar < 30).filter(Foo.bar > 10)
         assert res2.count() == 19
 
     def test_options(self):
@@ -105,12 +102,12 @@ class GenerativeQueryTest(TestBase):
 
     def test_order_by(self):
         query = create_session(bind=testing.db).query(Foo)
-        assert query.order_by([Foo.c.bar])[0].bar == 0
-        assert query.order_by([desc(Foo.c.bar)])[0].bar == 99
+        assert query.order_by([Foo.bar])[0].bar == 0
+        assert query.order_by([desc(Foo.bar)])[0].bar == 99
 
     def test_offset(self):
         query = create_session(bind=testing.db).query(Foo)
-        assert list(query.order_by([Foo.c.bar]).offset(10))[0].bar == 10
+        assert list(query.order_by([Foo.bar]).offset(10))[0].bar == 10
 
     def test_offset(self):
         query = create_session(bind=testing.db).query(Foo)
@@ -168,7 +165,7 @@ class RelationsTest(TestBase, AssertsExecutionResults):
         })
         session = create_session(bind=testing.db)
         query = session.query(tables.User)
-        x = query.join(['orders', 'items']).filter(tables.Item.c.item_id==2)
+        x = query.join(['orders', 'items']).filter(tables.Item.item_id==2)
         print x.compile()
         self.assert_result(list(x), tables.User, tables.user_result[2])
     def test_outerjointo(self):
@@ -180,7 +177,7 @@ class RelationsTest(TestBase, AssertsExecutionResults):
         })
         session = create_session(bind=testing.db)
         query = session.query(tables.User)
-        x = query.outerjoin(['orders', 'items']).filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2))
+        x = query.outerjoin(['orders', 'items']).filter(or_(tables.Order.order_id==None,tables.Item.item_id==2))
         print x.compile()
         self.assert_result(list(x), tables.User, *tables.user_result[1:3])
     def test_outerjointo_count(self):
@@ -192,7 +189,7 @@ class RelationsTest(TestBase, AssertsExecutionResults):
         })
         session = create_session(bind=testing.db)
         query = session.query(tables.User)
-        x = query.outerjoin(['orders', 'items']).filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2)).count()
+        x = query.outerjoin(['orders', 'items']).filter(or_(tables.Order.order_id==None,tables.Item.item_id==2)).count()
         assert x==2
     def test_from(self):
         mapper(tables.User, tables.users, properties={
@@ -203,7 +200,7 @@ class RelationsTest(TestBase, AssertsExecutionResults):
         session = create_session(bind=testing.db)
         query = session.query(tables.User)
         x = query.select_from(tables.users.outerjoin(tables.orders).outerjoin(tables.orderitems)).\
-            filter(or_(tables.Order.c.order_id==None,tables.Item.c.item_id==2))
+            filter(or_(tables.Order.order_id==None,tables.Item.item_id==2))
         print x.compile()
         self.assert_result(list(x), tables.User, *tables.user_result[1:3])
 
@@ -238,27 +235,6 @@ class CaseSensitiveTest(TestBase):
         res = q.filter(and_(table1.c.ID==table2.c.T1ID,table2.c.T1ID==1)).distinct()
         self.assertEqual(res.count(), 1)
 
-class SelfRefTest(ORMTest):
-    def define_tables(self, metadata):
-        global t1
-        t1 = Table('t1', metadata,
-            Column('id', Integer, primary_key=True),
-            Column('parent_id', Integer, ForeignKey('t1.id'))
-            )
-    def test_noautojoin(self):
-        class T(object):pass
-        mapper(T, t1, properties={'children':relation(T)})
-        sess = create_session(bind=testing.db)
-        def go():
-            sess.query(T).join('children')
-        self.assertRaisesMessage(exceptions.InvalidRequestError, 
-            "Self-referential query on 'T\.children \(T\)' property requires aliased=True argument.", go)
-
-        def go():
-            sess.query(T).join(['children']).select_by(id=7)
-        self.assertRaisesMessage(exceptions.InvalidRequestError, 
-            "Self-referential query on 'T\.children \(T\)' property requires aliased=True argument.", go)
-
 
 
 if __name__ == "__main__":
index 5f7a1075628c82b0baf795abec70d5cd20275180..e6977506ab17fcbeca301e29fbf19ee4d0036293 100644 (file)
@@ -1,7 +1,7 @@
 import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
 from sqlalchemy.orm import *
-from sqlalchemy.orm.sync import ONETOMANY, MANYTOONE
+from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE
 from testlib import *
 
 
index 076c7b76b8e787143f08089a0ee6c86be1fc1e39..367c2e73cc0747b97eb967e607815232774d1ed6 100644 (file)
@@ -1,6 +1,6 @@
 import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
-from sqlalchemy import exceptions, util
+from sqlalchemy import util
 from sqlalchemy.orm import *
 from testlib import *
 from testlib import fixtures
@@ -32,8 +32,8 @@ class ABCTest(ORMTest):
             else:
                 abc = bc = None
 
-            mapper(A, a, select_table=abc, polymorphic_on=a.c.type, polymorphic_identity='a', polymorphic_fetch=fetchtype)
-            mapper(B, b, select_table=bc, inherits=A, polymorphic_identity='b', polymorphic_fetch=fetchtype)
+            mapper(A, a, select_table=abc, polymorphic_on=a.c.type, polymorphic_identity='a')
+            mapper(B, b, select_table=bc, inherits=A, polymorphic_identity='b')
             mapper(C, c, inherits=B, polymorphic_identity='c')
 
             a1 = A(adata='a1')
@@ -82,8 +82,7 @@ class ABCTest(ORMTest):
         return test_roundtrip
 
     test_union = make_test('union')
-    test_select = make_test('select')
-    test_deferred = make_test('deferred')
+    test_none = make_test('none')
 
 
 if __name__ == '__main__':
index 8a0b6f30af1e9cf99cdd22e96e4c25c0303260d2..91e7b3b7f45a0b565e959c7a72c4f3559fe8dcd3 100644 (file)
@@ -1,7 +1,8 @@
 import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
-from sqlalchemy import exceptions, util
+from sqlalchemy import exc as sa_exc, util
 from sqlalchemy.orm import *
+from sqlalchemy.orm import exc as orm_exc
 from testlib import *
 from testlib import fixtures
 
@@ -302,7 +303,7 @@ class ConstructionTest(ORMTest):
                 'content_type':relation(content_types)
             }, polymorphic_identity='contents')
             assert False
-        except exceptions.ArgumentError, e:
+        except sa_exc.ArgumentError, e:
             assert str(e) == "Mapper 'Mapper|Content|content' specifies a polymorphic_identity of 'contents', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument"
 
     def testbackref(self):
@@ -397,7 +398,7 @@ class FlushTest(ORMTest):
         class Admin(User):pass
         role_mapper = mapper(Role, roles)
         user_mapper = mapper(User, users, properties = {
-                'roles' : relation(Role, secondary=user_roles, lazy=False, private=False)
+                'roles' : relation(Role, secondary=user_roles, lazy=False)
             }
         )
         admin_mapper = mapper(Admin, admins, inherits=user_mapper)
@@ -432,7 +433,7 @@ class FlushTest(ORMTest):
 
         role_mapper = mapper(Role, roles)
         user_mapper = mapper(User, users, properties = {
-                'roles' : relation(Role, secondary=user_roles, lazy=False, private=False)
+                'roles' : relation(Role, secondary=user_roles, lazy=False)
             }
         )
 
@@ -507,13 +508,13 @@ class VersioningTest(ORMTest):
         try:
             sess2.query(Base).with_lockmode('read').get(s1.id)
             assert False
-        except exceptions.ConcurrentModificationError, e:
+        except orm_exc.ConcurrentModificationError, e:
             assert True
 
         try:
             sess2.flush()
             assert False
-        except exceptions.ConcurrentModificationError, e:
+        except orm_exc.ConcurrentModificationError, e:
             assert True
 
         sess2.refresh(s2)
@@ -553,7 +554,7 @@ class VersioningTest(ORMTest):
             s1.subdata = 'some new subdata'
             sess.flush()
             assert False
-        except exceptions.ConcurrentModificationError, e:
+        except orm_exc.ConcurrentModificationError, e:
             assert True
 
 
@@ -608,7 +609,7 @@ class DistinctPKTest(ORMTest):
             mapper(Employee, employee_table, inherits=person_mapper, primary_key=[person_table.c.id, employee_table.c.id])
             self._do_test(True)
             assert False
-        except exceptions.SAWarning, e:
+        except sa_exc.SAWarning, e:
             assert str(e) == "On mapper Mapper|Employee|employees, primary key column 'employees.id' is being combined with distinct primary key column 'persons.id' in attribute 'id'.  Use explicit properties to give each column its own mapped attribute name.", str(e)
 
     def test_explicit_pk(self):
index 29fa1df6053e713c296563b61ef16c9f0f77bbaa..ffc95ac056f545ced5f0124ff290d403317bbb83 100644 (file)
@@ -74,6 +74,10 @@ class ConcreteTest(ORMTest):
         assert set([repr(x) for x in session.query(Manager).all()]) == set(["Manager Tom knows how to manage things"])
         assert set([repr(x) for x in session.query(Engineer).all()]) == set(["Engineer Kurt knows how to hack"])
 
+        manager = session.query(Manager).one()
+        session.expire(manager, ['manager_data'])
+        self.assertEquals(manager.manager_data, "knows how to manage things")
+
     def test_multi_level(self):
         class Employee(object):
             def __init__(self, name):
index b2dd6c658ed98528347c54bcfdfcdf748c38173b..e9e5e1ef6bcef857f210f9aa9b6e7ce26a27f019 100644 (file)
@@ -166,7 +166,7 @@ class PolymorphicCircularTest(ORMTest):
 
         # clear and query forwards
         sess.clear()
-        node = sess.query(Table1).filter(Table1.c.id==t.id).first()
+        node = sess.query(Table1).filter(Table1.id==t.id).first()
         assertlist = []
         while (node):
             assertlist.append(node)
@@ -178,7 +178,7 @@ class PolymorphicCircularTest(ORMTest):
 
         # clear and query backwards
         sess.clear()
-        node = sess.query(Table1).filter(Table1.c.id==obj.id).first()
+        node = sess.query(Table1).filter(Table1.id==obj.id).first()
         assertlist = []
         while (node):
             assertlist.insert(0, node)
@@ -189,9 +189,6 @@ class PolymorphicCircularTest(ORMTest):
         backwards = repr(assertlist)
 
         # everything should match !
-        print "ORIGNAL", original
-        print "BACKWARDS",backwards
-        print "FORWARDS", forwards
         assert original == forwards == backwards
 
 if __name__ == '__main__':
index 5442520242351de873daf501a307c5bc171bd035..141aedcac6fe54befa8fbf03cc8b5d28d0dbe4ce 100644 (file)
@@ -4,7 +4,7 @@ import testenv; testenv.configure_for_tests()
 import sets
 from sqlalchemy import *
 from sqlalchemy.orm import *
-from sqlalchemy import exceptions
+from sqlalchemy.orm import exc as orm_exc
 from testlib import *
 from testlib import fixtures
 
@@ -122,7 +122,7 @@ class RelationToSubclassTest(PolymorphTest):
 class RoundTripTest(PolymorphTest):
     pass
 
-def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_colprop=False, use_literal_join=False, polymorphic_fetch=None, use_outer_joins=False):
+def generate_round_trip_test(include_base, lazy_relation, redefine_colprop, with_polymorphic):
     """generates a round trip test.
 
     include_base - whether or not to include the base 'person' type in the union.
@@ -131,62 +131,52 @@ def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_co
     use_literal_join - primary join condition is explicitly specified
     """
     def test_roundtrip(self):
-        # create a union that represents both types of joins.
-        if not polymorphic_fetch == 'union':
-            person_join = None
-            manager_join = None
-        elif include_base:
-            if use_outer_joins:
-                person_join = people.outerjoin(engineers).outerjoin(managers).outerjoin(boss)
-                manager_join = people.join(managers).outerjoin(boss)
-            else:
+        if with_polymorphic == 'unions':
+            if include_base:
                 person_join = polymorphic_union(
                     {
                         'engineer':people.join(engineers),
                         'manager':people.join(managers),
                         'person':people.select(people.c.type=='person'),
                     }, None, 'pjoin')
-
-                manager_join = people.join(managers).outerjoin(boss)
-        else:
-            if use_outer_joins:
-                person_join = people.outerjoin(engineers).outerjoin(managers).outerjoin(boss)
-                manager_join = people.join(managers).outerjoin(boss)
             else:
                 person_join = polymorphic_union(
                     {
                         'engineer':people.join(engineers),
                         'manager':people.join(managers),
                     }, None, 'pjoin')
-                manager_join = people.join(managers).outerjoin(boss)
+                
+            manager_join = people.join(managers).outerjoin(boss)
+            person_with_polymorphic = ['*', person_join]
+            manager_with_polymorphic = ['*', manager_join]
+        elif with_polymorphic == 'joins':
+            person_join = people.outerjoin(engineers).outerjoin(managers).outerjoin(boss)
+            manager_join = people.join(managers).outerjoin(boss)
+            person_with_polymorphic = ['*', person_join]
+            manager_with_polymorphic = ['*', manager_join]
+        elif with_polymorphic == 'auto':
+            person_with_polymorphic = '*'
+            manager_with_polymorphic = '*'
+        else:
+            person_with_polymorphic = None
+            manager_with_polymorphic = None
 
         if redefine_colprop:
-            person_mapper = mapper(Person, people, select_table=person_join, polymorphic_fetch=polymorphic_fetch, polymorphic_on=people.c.type, polymorphic_identity='person', properties= {'person_name':people.c.name})
+            person_mapper = mapper(Person, people, with_polymorphic=person_with_polymorphic, polymorphic_on=people.c.type, polymorphic_identity='person', properties= {'person_name':people.c.name})
         else:
-            person_mapper = mapper(Person, people, select_table=person_join, polymorphic_fetch=polymorphic_fetch, polymorphic_on=people.c.type, polymorphic_identity='person')
+            person_mapper = mapper(Person, people, with_polymorphic=person_with_polymorphic, polymorphic_on=people.c.type, polymorphic_identity='person')
 
         mapper(Engineer, engineers, inherits=person_mapper, polymorphic_identity='engineer')
-        mapper(Manager, managers, inherits=person_mapper, select_table=manager_join, polymorphic_identity='manager')
+        mapper(Manager, managers, inherits=person_mapper, with_polymorphic=manager_with_polymorphic, polymorphic_identity='manager')
 
         mapper(Boss, boss, inherits=Manager, polymorphic_identity='boss')
 
-        if use_literal_join:
-            mapper(Company, companies, properties={
-                'employees': relation(Person, lazy=lazy_relation,
-                                      primaryjoin=(people.c.company_id ==
-                                                   companies.c.company_id),
-                                      cascade="all,delete-orphan",
-                                      backref="company", 
-                                      order_by=people.c.person_id
-                )
-            })
-        else:
-            mapper(Company, companies, properties={
-                'employees': relation(Person, lazy=lazy_relation,
-                                      cascade="all, delete-orphan",
-                backref="company", order_by=people.c.person_id
-                )
-            })
+        mapper(Company, companies, properties={
+            'employees': relation(Person, lazy=lazy_relation,
+                                  cascade="all, delete-orphan",
+            backref="company", order_by=people.c.person_id
+            )
+        })
 
         if redefine_colprop:
             person_attribute_name = 'person_name'
@@ -224,18 +214,16 @@ def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_co
 
         def go():
             cc = session.query(Company).get(c.company_id)
-            for e in cc.employees:
-                assert e._instance_key[0] == Person
             self.assertEquals(cc.employees, employees)
             
         if not lazy_relation:
-            if polymorphic_fetch=='union':
+            if with_polymorphic != 'none':
                 self.assert_sql_count(testing.db, go, 1)
             else:
                 self.assert_sql_count(testing.db, go, 5)
 
         else:
-            if polymorphic_fetch=='union':
+            if with_polymorphic != 'none':
                 self.assert_sql_count(testing.db, go, 2)
             else:
                 self.assert_sql_count(testing.db, go, 6)
@@ -265,21 +253,20 @@ def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_co
         session.flush()
         session.clear()
         
-        if polymorphic_fetch == 'select':
-            def go():
-                session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first()
-            self.assert_sql_count(testing.db, go, 2)
-            session.clear()
-            dilbert = session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first()
-            def go():
-                # assert that only primary table is queried for already-present-in-session
-                d = session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first()
-            self.assert_sql_count(testing.db, go, 1)
+        def go():
+            session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first()
+        self.assert_sql_count(testing.db, go, 1)
+        session.clear()
+        dilbert = session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first()
+        def go():
+            # assert that only primary table is queried for already-present-in-session
+            d = session.query(Person).filter(getattr(Person, person_attribute_name)=='dilbert').first()
+        self.assert_sql_count(testing.db, go, 1)
 
         # test standalone orphans
         daboss = Boss(status='BBB', manager_name='boss', golf_swing='fore', **{person_attribute_name:'daboss'})
         session.save(daboss)
-        self.assertRaises(exceptions.FlushError, session.flush)
+        self.assertRaises(orm_exc.FlushError, session.flush)
         c = session.query(Company).first()
         daboss.company = c
         manager_list = [e for e in c.employees if isinstance(e, Manager)]
@@ -295,24 +282,21 @@ def generate_round_trip_test(include_base=False, lazy_relation=True, redefine_co
         self.assertEquals(people.count().scalar(), 0)
         
     test_roundtrip = _function_named(
-        test_roundtrip, "test_%s%s%s%s%s" % (
+        test_roundtrip, "test_%s%s%s_%s" % (
           (lazy_relation and "lazy" or "eager"),
           (include_base and "_inclbase" or ""),
           (redefine_colprop and "_redefcol" or ""),
-          (polymorphic_fetch != 'union' and '_' + polymorphic_fetch or (use_literal_join and "_litjoin" or "")),
-          (use_outer_joins and '_outerjoins' or '')))
+          with_polymorphic))
     setattr(RoundTripTest, test_roundtrip.__name__, test_roundtrip)
 
-for include_base in [True, False]:
-    for lazy_relation in [True, False]:
-        for redefine_colprop in [True, False]:
-            for use_literal_join in [True, False]:
-                for polymorphic_fetch in ['union', 'select', 'deferred']:
-                    if polymorphic_fetch == 'union':
-                        for use_outer_joins in [True, False]:
-                            generate_round_trip_test(include_base, lazy_relation, redefine_colprop, use_literal_join, polymorphic_fetch, use_outer_joins)
-                    else:
-                        generate_round_trip_test(include_base, lazy_relation, redefine_colprop, use_literal_join, polymorphic_fetch, False)
+for lazy_relation in [True, False]:
+    for redefine_colprop in [True, False]:
+        for with_polymorphic in ['unions', 'joins', 'auto', 'none']:
+            if with_polymorphic == 'unions':
+                for include_base in [True, False]:
+                    generate_round_trip_test(include_base, lazy_relation, redefine_colprop, with_polymorphic)
+            else:
+                generate_round_trip_test(False, lazy_relation, redefine_colprop, with_polymorphic)
 
 if __name__ == "__main__":
     testenv.main()
index ed003927bb64d45047ce7ae0d3ce0db4779f6c27..4b17e9e9d7f6ebecd9b9afacf612aa86efd769a4 100644 (file)
@@ -4,7 +4,7 @@ inheritance setups for which we maintain compatibility.
 
 import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
-from sqlalchemy import exceptions, util
+from sqlalchemy import util
 from sqlalchemy.orm import *
 from testlib import *
 from testlib import fixtures
@@ -560,7 +560,7 @@ class RelationTest7(ORMTest):
 
         class Car(PersistentObject):
             def __repr__(self):
-                return "Car number %d, name %s" % i(self.car_id, self.name)
+                return "Car number %d, name %s" % (self.car_id, self.name)
 
         class Offraod_Car(Car):
             def __repr__(self):
@@ -725,18 +725,18 @@ class GenerativeTest(TestBase, AssertsExecutionResults):
         session.save(car2)
         session.flush()
 
-        # test these twice because theres caching involved, as well previous issues that modified the polymorphic union
-        for x in range(0, 2):
-            r = session.query(Person).filter(people.c.name.like('%2')).join('status').filter_by(name="active")
-            assert str(list(r)) == "[Manager M2, category YYYYYYYYY, status Status active, Engineer E2, field X, status Status active]"
-            r = session.query(Engineer).join('status').filter(people.c.name.in_(['E2', 'E3', 'E4', 'M4', 'M2', 'M1']) & (status.c.name=="active"))
-            assert str(list(r)) == "[Engineer E2, field X, status Status active, Engineer E3, field X, status Status active]"
-            # this test embeds the original polymorphic union (employee_join) fully
-            # into the WHERE criterion, using a correlated select. ticket #577 tracks
-            # that Query's adaptation of the WHERE clause does not dig into the
-            # mapped selectable itself, which permanently breaks the mapped selectable.
-            r = session.query(Person).filter(exists([Car.c.owner], Car.c.owner==employee_join.c.person_id))
-            assert str(list(r)) == "[Engineer E4, field X, status Status dead]"
+        # this particular adapt used to cause a recursion overflow;
+        # added here for testing
+        e = exists([Car.owner], Car.owner==employee_join.c.person_id)
+        Query(Person)._adapt_clause(employee_join, False, False)
+        
+        r = session.query(Person).filter(Person.name.like('%2')).join('status').filter_by(name="active")
+        assert str(list(r)) == "[Manager M2, category YYYYYYYYY, status Status active, Engineer E2, field X, status Status active]"
+        r = session.query(Engineer).join('status').filter(Person.name.in_(['E2', 'E3', 'E4', 'M4', 'M2', 'M1']) & (status.c.name=="active"))
+        assert str(list(r)) == "[Engineer E2, field X, status Status active, Engineer E3, field X, status Status active]"
+
+        r = session.query(Person).filter(exists([1], Car.owner==Person.person_id))
+        assert str(list(r)) == "[Engineer E4, field X, status Status dead]"
 
 class MultiLevelTest(ORMTest):
     def define_tables(self, metadata):
index 34ead1622cfd16633e73a074c7f83087c3970e1c..6a40efc4ac6dd75abf7692c5547b8f9cc0a13447 100644 (file)
@@ -7,9 +7,11 @@ import testenv; testenv.configure_for_tests()
 import sets
 from sqlalchemy import *
 from sqlalchemy.orm import *
-from sqlalchemy import exceptions
+from sqlalchemy import exc as sa_exc
 from testlib import *
 from testlib import fixtures
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.engine import default
 
 class Company(fixtures.Base):
     pass
@@ -30,7 +32,7 @@ class Paperwork(fixtures.Base):
     pass
 
 def make_test(select_type):
-    class PolymorphicQueryTest(ORMTest):
+    class PolymorphicQueryTest(ORMTest, AssertsCompiledSQL):
         keep_data = True
         keep_mappers = True
 
@@ -184,11 +186,42 @@ def make_test(select_type):
 
         def test_primary_eager_aliasing(self):
             sess = create_session()
+            
+            # assert the SQL itself here to ensure no over-joining is taking place
+            if select_type == '':
+                self.assert_compile(
+                    sess.query(Person).options(eagerload(Engineer.machines)).limit(2).offset(1).with_labels().statement, 
+                    "SELECT people.person_id AS people_person_id, people.company_id AS people_company_id, "\
+                    "people.name AS people_name, people.type AS people_type FROM people ORDER BY people.person_id  LIMIT 2 OFFSET 1", 
+                    dialect=default.DefaultDialect())
+                
             def go():
                 self.assertEquals(sess.query(Person).options(eagerload(Engineer.machines))[1:3].all(), all_employees[1:3])
             self.assert_sql_count(testing.db, go, {'':6, 'Polymorphic':3}.get(select_type, 4))
 
             sess = create_session()
+
+            if select_type == '':
+                self.assert_compile(
+                    sess.query(Person).with_polymorphic('*').options(eagerload(Engineer.machines)).limit(2).offset(1).with_labels().statement, 
+                    "SELECT anon_1.people_person_id AS anon_1_people_person_id, anon_1.people_company_id AS anon_1_people_company_id, "\
+                    "anon_1.people_name AS anon_1_people_name, anon_1.people_type AS anon_1_people_type, anon_1.engineers_person_id AS "\
+                    "anon_1_engineers_person_id, anon_1.engineers_status AS anon_1_engineers_status, anon_1.engineers_engineer_name AS "\
+                    "anon_1_engineers_engineer_name, anon_1.engineers_primary_language AS anon_1_engineers_primary_language, "\
+                    "anon_1.managers_person_id AS anon_1_managers_person_id, anon_1.managers_status AS anon_1_managers_status, "\
+                    "anon_1.managers_manager_name AS anon_1_managers_manager_name, anon_1.boss_boss_id AS anon_1_boss_boss_id, "\
+                    "anon_1.boss_golf_swing AS anon_1_boss_golf_swing, machines_1.machine_id AS machines_1_machine_id, machines_1.name AS "\
+                    "machines_1_name, machines_1.engineer_id AS machines_1_engineer_id FROM (SELECT people.person_id AS people_person_id, "\
+                    "people.company_id AS people_company_id, people.name AS people_name, people.type AS people_type, engineers.person_id AS "\
+                    "engineers_person_id, engineers.status AS engineers_status, engineers.engineer_name AS engineers_engineer_name, "\
+                    "engineers.primary_language AS engineers_primary_language, managers.person_id AS managers_person_id, managers.status "\
+                    "AS managers_status, managers.manager_name AS managers_manager_name, boss.boss_id AS boss_boss_id, boss.golf_swing "\
+                    "AS boss_golf_swing FROM people LEFT OUTER JOIN engineers ON people.person_id = engineers.person_id LEFT OUTER JOIN "\
+                    "managers ON people.person_id = managers.person_id LEFT OUTER JOIN boss ON managers.person_id = boss.boss_id ORDER BY "\
+                    "people.person_id  LIMIT 2 OFFSET 1) AS anon_1 LEFT OUTER JOIN machines AS machines_1 ON anon_1.engineers_person_id = "\
+                    "machines_1.engineer_id ORDER BY anon_1.people_person_id, machines_1.machine_id", 
+                    dialect=default.DefaultDialect())
+
             def go():
                 self.assertEquals(sess.query(Person).with_polymorphic('*').options(eagerload(Engineer.machines))[1:3].all(), all_employees[1:3])
             self.assert_sql_count(testing.db, go, 3)
@@ -199,9 +232,9 @@ def make_test(select_type):
             
             # for all mappers, ensure the primary key has been calculated as just the "person_id"
             # column
-            self.assertEquals(sess.query(Person).get(e1.person_id), Engineer(name="dilbert"))
-            self.assertEquals(sess.query(Engineer).get(e1.person_id), Engineer(name="dilbert"))
-            self.assertEquals(sess.query(Manager).get(b1.person_id), Boss(name="pointy haired boss"))
+            self.assertEquals(sess.query(Person).get(e1.person_id), Engineer(name="dilbert", primary_language="java"))
+            self.assertEquals(sess.query(Engineer).get(e1.person_id), Engineer(name="dilbert", primary_language="java"))
+            self.assertEquals(sess.query(Manager).get(b1.person_id), Boss(name="pointy haired boss", golf_swing="fore"))
             
         def test_filter_on_subclass(self):
             sess = create_session()
@@ -219,7 +252,7 @@ def make_test(select_type):
 
         def test_join_from_polymorphic(self):
             sess = create_session()
-        
+
             for aliased in (True, False):
                 self.assertEquals(sess.query(Person).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%review%')).all(), [b1, m1])
 
@@ -227,7 +260,7 @@ def make_test(select_type):
 
                 self.assertEquals(sess.query(Engineer).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1])
 
-                self.assertEquals(sess.query(Person).join('paperwork', aliased=aliased).filter(Person.c.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1])
+                self.assertEquals(sess.query(Person).join('paperwork', aliased=aliased).filter(Person.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1])
 
         def test_join_from_with_polymorphic(self):
             sess = create_session()
@@ -240,14 +273,14 @@ def make_test(select_type):
                 self.assertEquals(sess.query(Person).with_polymorphic([Manager, Engineer]).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1, m1])
 
                 sess.clear()
-                self.assertEquals(sess.query(Person).with_polymorphic([Manager, Engineer]).join('paperwork', aliased=aliased).filter(Person.c.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1])
+                self.assertEquals(sess.query(Person).with_polymorphic([Manager, Engineer]).join('paperwork', aliased=aliased).filter(Person.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1])
     
         def test_join_to_polymorphic(self):
             sess = create_session()
             self.assertEquals(sess.query(Company).join('employees').filter(Person.name=='vlad').one(), c2)
 
             self.assertEquals(sess.query(Company).join('employees', aliased=True).filter(Person.name=='vlad').one(), c2)
-        
+
         def test_polymorphic_any(self):
             sess = create_session()
 
@@ -305,6 +338,8 @@ def make_test(select_type):
                 Manager(name="dogbert", manager_name="dogbert", status="regular manager"),
                 Engineer(name="vlad", engineer_name="vlad", primary_language="cobol", status="elbonian engineer")
             ]
+            self.assertEquals(sess.query(Person).with_polymorphic('*').all(), emps_without_relations)
+            
             
             def go():
                 self.assertEquals(sess.query(Person).with_polymorphic(Engineer).filter(Engineer.primary_language=='java').all(), emps_without_relations[0:1])
@@ -345,6 +380,7 @@ def make_test(select_type):
             ]
             
             sess = create_session()
+            
             def go():
                 # test load Companies with lazy load to 'employees'
                 self.assertEquals(sess.query(Company).all(), assert_result)
@@ -359,7 +395,7 @@ def make_test(select_type):
             # in the case of select_type='', the eagerload doesn't take in this case; 
             # it eagerloads company->people, then a load for each of 5 rows, then lazyload of "machines"            
             self.assert_sql_count(testing.db, go, {'':7, 'Polymorphic':1}.get(select_type, 2))
-
+    
         def test_eagerload_on_subclass(self):
             sess = create_session()
             def go():
@@ -371,10 +407,15 @@ def make_test(select_type):
             
         def test_join_to_subclass(self):
             sess = create_session()
+            self.assertEquals(sess.query(Company).join(('employees', people.join(engineers))).filter(Engineer.primary_language=='java').all(), [c1])
 
             if select_type == '':
                 self.assertEquals(sess.query(Company).select_from(companies.join(people).join(engineers)).filter(Engineer.primary_language=='java').all(), [c1])
                 self.assertEquals(sess.query(Company).join(('employees', people.join(engineers))).filter(Engineer.primary_language=='java').all(), [c1])
+                
+                ealias = aliased(Engineer)
+                self.assertEquals(sess.query(Company).join(('employees', ealias)).filter(ealias.primary_language=='java').all(), [c1])
+
                 self.assertEquals(sess.query(Person).select_from(people.join(engineers)).join(Engineer.machines).all(), [e1, e2, e3])
                 self.assertEquals(sess.query(Person).select_from(people.join(engineers)).join(Engineer.machines).filter(Machine.name.ilike("%ibm%")).all(), [e1, e3])
                 self.assertEquals(sess.query(Company).join([('employees', people.join(engineers)), Engineer.machines]).all(), [c1, c2])
@@ -445,6 +486,150 @@ def make_test(select_type):
         
             self.assertEquals(sess.query(Person).filter(Person.person_id==e2.person_id).one(), e2)
     
+        def test_from_alias(self):
+            sess = create_session()
+            
+            palias = aliased(Person)
+            self.assertEquals(
+                sess.query(palias).filter(palias.name.in_(['dilbert', 'wally'])).all(),
+                [e1, e2]
+            )
+            
+        def test_self_referential(self):
+            sess = create_session()
+            
+            c1_employees = [e1, e2, b1, m1]
+            
+            palias = aliased(Person)
+            self.assertEquals(
+                sess.query(Person, palias).filter(Person.company_id==palias.company_id).filter(Person.name=='dogbert').\
+                    filter(Person.person_id>palias.person_id).order_by(Person.person_id, palias.person_id).all(), 
+                [
+                    (m1, e1),
+                    (m1, e2),
+                    (m1, b1),
+                ]
+            )
+
+            self.assertEquals(
+                sess.query(Person, palias).filter(Person.company_id==palias.company_id).filter(Person.name=='dogbert').\
+                    filter(Person.person_id>palias.person_id).from_self().order_by(Person.person_id, palias.person_id).all(), 
+                [
+                    (m1, e1),
+                    (m1, e2),
+                    (m1, b1),
+                ]
+            )
+        
+        def test_nesting_queries(self):
+            sess = create_session()
+            
+            # query.statement places a flag "no_adapt" on the returned statement.  This prevents
+            # the polymorphic adaptation in the second "filter" from hitting it, which would pollute 
+            # the subquery and usually results in recursion overflow errors within the adaption.
+            subq = sess.query(engineers.c.person_id).filter(Engineer.primary_language=='java').statement.as_scalar()
+            
+            self.assertEquals(sess.query(Person).filter(Person.person_id==subq).one(), e1)
+            
+            
+        def test_mixed_entities(self):
+            sess = create_session()
+
+            self.assertEquals(
+                sess.query(Company.name, Person).join(Company.employees).filter(Company.name=='Elbonia, Inc.').all(),
+                [(u'Elbonia, Inc.', 
+                    Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',primary_language=u'cobol'))]
+            )
+
+            self.assertEquals(
+                sess.query(Person, Company.name).join(Company.employees).filter(Company.name=='Elbonia, Inc.').all(),
+                [(Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',primary_language=u'cobol'),
+                    u'Elbonia, Inc.')]
+            )
+
+            self.assertEquals(
+                sess.query(Person.name, Company.name).join(Company.employees).filter(Company.name=='Elbonia, Inc.').all(),
+                [(u'vlad',u'Elbonia, Inc.')]
+            )
+
+            self.assertEquals(
+                sess.query(Engineer.primary_language).filter(Person.type=='engineer').all(),
+                [(u'java',), (u'c++',), (u'cobol',)]
+            )
+
+            if select_type != '':
+                self.assertEquals(
+                    sess.query(Engineer, Company.name).join(Company.employees).filter(Person.type=='engineer').all(),
+                    [
+                    (Engineer(status=u'regular engineer',engineer_name=u'dilbert',name=u'dilbert',company_id=1,primary_language=u'java',person_id=1,type=u'engineer'), u'MegaCorp, Inc.'), 
+                    (Engineer(status=u'regular engineer',engineer_name=u'wally',name=u'wally',company_id=1,primary_language=u'c++',person_id=2,type=u'engineer'), u'MegaCorp, Inc.'), 
+                    (Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',company_id=2,primary_language=u'cobol',person_id=5,type=u'engineer'), u'Elbonia, Inc.')
+                    ]
+                )
+            
+                self.assertEquals(
+                    sess.query(Engineer.primary_language, Company.name).join(Company.employees).filter(Person.type=='engineer').order_by(desc(Engineer.primary_language)).all(),
+                    [(u'java', u'MegaCorp, Inc.'), (u'cobol', u'Elbonia, Inc.'), (u'c++', u'MegaCorp, Inc.')]
+                )
+
+            palias = aliased(Person)
+            self.assertEquals(
+                sess.query(Person, Company.name, palias).join(Company.employees).filter(Company.name=='Elbonia, Inc.').filter(palias.name=='dilbert').all(),
+                [(Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',primary_language=u'cobol'),
+                    u'Elbonia, Inc.', 
+                    Engineer(status=u'regular engineer',engineer_name=u'dilbert',name=u'dilbert',company_id=1,primary_language=u'java',person_id=1,type=u'engineer'))]
+            )
+
+            self.assertEquals(
+                sess.query(palias, Company.name, Person).join(Company.employees).filter(Company.name=='Elbonia, Inc.').filter(palias.name=='dilbert').all(),
+                [(Engineer(status=u'regular engineer',engineer_name=u'dilbert',name=u'dilbert',company_id=1,primary_language=u'java',person_id=1,type=u'engineer'),
+                    u'Elbonia, Inc.', 
+                    Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',primary_language=u'cobol'),)
+                ]
+            )
+
+            self.assertEquals(
+                sess.query(Person.name, Company.name, palias.name).join(Company.employees).filter(Company.name=='Elbonia, Inc.').filter(palias.name=='dilbert').all(),
+                [(u'vlad', u'Elbonia, Inc.', u'dilbert')]
+            )
+            
+            palias = aliased(Person)
+            self.assertEquals(
+                sess.query(Person.type, Person.name, palias.type, palias.name).filter(Person.company_id==palias.company_id).filter(Person.name=='dogbert').\
+                    filter(Person.person_id>palias.person_id).order_by(Person.person_id, palias.person_id).all(), 
+                [(u'manager', u'dogbert', u'engineer', u'dilbert'), 
+                (u'manager', u'dogbert', u'engineer', u'wally'), 
+                (u'manager', u'dogbert', u'boss', u'pointy haired boss')]
+            )
+        
+            self.assertEquals(
+                sess.query(Person.name, Paperwork.description).filter(Person.person_id==Paperwork.person_id).order_by(Person.name, Paperwork.description).all(), 
+                [(u'dilbert', u'tps report #1'), (u'dilbert', u'tps report #2'), (u'dogbert', u'review #2'), 
+                (u'dogbert', u'review #3'), 
+                (u'pointy haired boss', u'review #1'), 
+                (u'vlad', u'elbonian missive #3'),
+                (u'wally', u'tps report #3'), 
+                (u'wally', u'tps report #4'),
+                ]
+            )
+
+            if select_type != '':
+                self.assertEquals(
+                    sess.query(func.count(Person.person_id)).filter(Engineer.primary_language=='java').all(), 
+                    [(1, )]
+                )
+            
+            self.assertEquals(
+                sess.query(Company.name, func.count(Person.person_id)).filter(Company.company_id==Person.company_id).group_by(Company.name).order_by(Company.name).all(),
+                [(u'Elbonia, Inc.', 1), (u'MegaCorp, Inc.', 4)]
+            )
+
+            self.assertEquals(
+                sess.query(Company.name, func.count(Person.person_id)).join(Company.employees).group_by(Company.name).order_by(Company.name).all(),
+                [(u'Elbonia, Inc.', 1), (u'MegaCorp, Inc.', 4)]
+            )
+    
+    
     PolymorphicQueryTest.__name__ = "Polymorphic%sTest" % select_type
     return PolymorphicQueryTest
 
@@ -500,11 +685,6 @@ class SelfReferentialTest(ORMTest):
         
         self.assertEquals(sess.query(Engineer).join('reports_to', aliased=True).filter(Person.name=='dogbert').first(), Engineer(name='dilbert'))
         
-    def test_noalias_raises(self):
-        sess = create_session()
-        def go():
-            sess.query(Engineer).join('reports_to')
-        self.assertRaises(exceptions.InvalidRequestError, go)
 
 class M2MFilterTest(ORMTest):
     keep_mappers = True
@@ -570,6 +750,59 @@ class M2MFilterTest(ORMTest):
         sess = create_session()
         self.assertEquals(sess.query(Organization).filter(Organization.engineers.of_type(Engineer).any(Engineer.name=='e1')).all(), [Organization(name='org1')])
         self.assertEquals(sess.query(Organization).filter(Organization.engineers.any(Engineer.name=='e1')).all(), [Organization(name='org1')])
+
+class SelfReferentialM2MTest(ORMTest, AssertsCompiledSQL):
+    def define_tables(self, metadata):
+        Base = declarative_base(metadata=metadata)
+
+        secondary_table = Table('secondary', Base.metadata,
+           Column('left_id', Integer, ForeignKey('parent.id'), nullable=False),
+           Column('right_id', Integer, ForeignKey('parent.id'), nullable=False))
+          
+        global Parent, Child1, Child2
+        class Parent(Base):
+           __tablename__ = 'parent'
+           id = Column(Integer, primary_key=True)
+           cls = Column(String(50))
+           __mapper_args__ = dict(polymorphic_on = cls )
+
+        class Child1(Parent):
+           __tablename__ = 'child1'
+           id = Column(Integer, ForeignKey('parent.id'), primary_key=True)
+           __mapper_args__ = dict(polymorphic_identity = 'child1')
+
+        class Child2(Parent):
+           __tablename__ = 'child2'
+           id = Column(Integer, ForeignKey('parent.id'), primary_key=True)
+           __mapper_args__ = dict(polymorphic_identity = 'child2')
+
+        Child1.left_child2 = relation(Child2, secondary = secondary_table,
+               primaryjoin = Parent.id == secondary_table.c.right_id,
+               secondaryjoin = Parent.id == secondary_table.c.left_id,
+               uselist = False,
+                               )
+
+    def test_eager_join(self):
+        session = create_session()
         
+        c1 = Child1()
+        c1.left_child2 = Child2()
+        session.add(c1)
+        session.flush()
+        
+        q = session.query(Child1).options(eagerload('left_child2'))
+
+        # test that the splicing of the join works here, doesnt break in the middle of "parent join child1"
+        self.assert_compile(q.limit(1).with_labels().statement, 
+        "SELECT anon_1.child1_id AS anon_1_child1_id, anon_1.parent_id AS anon_1_parent_id, "\
+        "anon_1.parent_cls AS anon_1_parent_cls, anon_2.child2_id AS anon_2_child2_id, anon_2.parent_id AS anon_2_parent_id, "\
+        "anon_2.parent_cls AS anon_2_parent_cls FROM (SELECT child1.id AS child1_id, parent.id AS parent_id, "\
+        "parent.cls AS parent_cls, parent.id AS parent_oid FROM parent JOIN child1 ON parent.id = child1.id ORDER BY parent.id  "\
+        "LIMIT 1) AS anon_1 LEFT OUTER JOIN secondary AS secondary_1 ON anon_1.parent_id = secondary_1.right_id LEFT OUTER JOIN "\
+        "(SELECT parent.id AS parent_id, parent.cls AS parent_cls, child2.id AS child2_id FROM parent JOIN child2 ON parent.id = child2.id) "\
+        "AS anon_2 ON anon_2.parent_id = secondary_1.left_id ORDER BY anon_1.child1_id"
+        , dialect=default.DefaultDialect())
+        assert q.first() is c1
+
 if __name__ == "__main__":
     testenv.main()
index 81223cc02e36445b757e61ed953706a01c3169d4..dabb701cd90321a3ed31ee01749187395dcbcdeb 100644 (file)
@@ -61,6 +61,10 @@ class SingleInheritanceTest(TestBase, AssertsExecutionResults):
         assert session.query(Engineer).all() == [e1, e2]
         assert session.query(Manager).all() == [m1]
         assert session.query(JuniorEngineer).all() == [e2]
-
+        
+        m1 = session.query(Manager).one()
+        session.expire(m1, ['manager_data'])
+        self.assertEquals(m1.manager_data, "knows how to manage things")
+        
 if __name__ == '__main__':
     testenv.main()
diff --git a/test/orm/instrumentation.py b/test/orm/instrumentation.py
new file mode 100644 (file)
index 0000000..5cb3a5c
--- /dev/null
@@ -0,0 +1,745 @@
+import testenv; testenv.configure_for_tests()
+from sqlalchemy import MetaData, Table, Column, Integer, ForeignKey
+from sqlalchemy import util
+from sqlalchemy.orm import attributes
+from sqlalchemy.orm import create_session
+from sqlalchemy.orm import interfaces
+from sqlalchemy.orm import mapper
+from sqlalchemy.orm import relation
+
+from testlib.testing import eq_, ne_
+from testlib.compat import _function_named
+from testlib import TestBase
+
+
+def modifies_instrumentation_finders(fn):
+    def decorated(*args, **kw):
+        pristine = attributes.instrumentation_finders[:]
+        try:
+            fn(*args, **kw)
+        finally:
+            del attributes.instrumentation_finders[:]
+            attributes.instrumentation_finders.extend(pristine)
+    return _function_named(decorated, fn.func_name)
+
+def with_lookup_strategy(strategy):
+    def decorate(fn):
+        def wrapped(*args, **kw):
+            current = attributes._lookup_strategy
+            try:
+                attributes._install_lookup_strategy(strategy)
+                return fn(*args, **kw)
+            finally:
+                attributes._install_lookup_strategy(current)
+        return _function_named(wrapped, fn.func_name)
+    return decorate
+
+
+class InitTest(TestBase):
+    def fixture(self):
+        return Table('t', MetaData(),
+                     Column('id', Integer, primary_key=True),
+                     Column('type', Integer),
+                     Column('x', Integer),
+                     Column('y', Integer))
+
+    def register(self, cls, canary):
+        original_init = cls.__init__
+        attributes.register_class(cls)
+        ne_(cls.__init__, original_init)
+        manager = attributes.manager_of_class(cls)
+        def on_init(state, instance, args, kwargs):
+            canary.append((cls, 'on_init', type(instance)))
+        manager.events.add_listener('on_init', on_init)
+
+    def test_ai(self):
+        inits = []
+
+        class A(object):
+            def __init__(self):
+                inits.append((A, '__init__'))
+
+        obj = A()
+        eq_(inits, [(A, '__init__')])
+
+    def test_A(self):
+        inits = []
+
+        class A(object): pass
+        self.register(A, inits)
+
+        obj = A()
+        eq_(inits, [(A, 'on_init', A)])
+
+    def test_Ai(self):
+        inits = []
+
+        class A(object):
+            def __init__(self):
+                inits.append((A, '__init__'))
+        self.register(A, inits)
+
+        obj = A()
+        eq_(inits, [(A, 'on_init', A), (A, '__init__')])
+
+    def test_ai_B(self):
+        inits = []
+
+        class A(object):
+            def __init__(self):
+                inits.append((A, '__init__'))
+
+        class B(A): pass
+        self.register(B, inits)
+
+        obj = A()
+        eq_(inits, [(A, '__init__')])
+
+        del inits[:]
+
+        obj = B()
+        eq_(inits, [(B, 'on_init', B), (A, '__init__')])
+
+    def test_ai_Bi(self):
+        inits = []
+
+        class A(object):
+            def __init__(self):
+                inits.append((A, '__init__'))
+
+        class B(A):
+            def __init__(self):
+                inits.append((B, '__init__'))
+                super(B, self).__init__()
+        self.register(B, inits)
+
+        obj = A()
+        eq_(inits, [(A, '__init__')])
+
+        del inits[:]
+
+        obj = B()
+        eq_(inits, [(B, 'on_init', B), (B, '__init__'), (A, '__init__')])
+
+    def test_Ai_bi(self):
+        inits = []
+
+        class A(object):
+            def __init__(self):
+                inits.append((A, '__init__'))
+        self.register(A, inits)
+
+        class B(A):
+            def __init__(self):
+                inits.append((B, '__init__'))
+                super(B, self).__init__()
+
+        obj = A()
+        eq_(inits, [(A, 'on_init', A), (A, '__init__')])
+
+        del inits[:]
+
+        obj = B()
+        eq_(inits, [(B, '__init__'), (A, 'on_init', B), (A, '__init__')])
+
+    def test_Ai_Bi(self):
+        inits = []
+
+        class A(object):
+            def __init__(self):
+                inits.append((A, '__init__'))
+        self.register(A, inits)
+
+        class B(A):
+            def __init__(self):
+                inits.append((B, '__init__'))
+                super(B, self).__init__()
+        self.register(B, inits)
+
+        obj = A()
+        eq_(inits, [(A, 'on_init', A), (A, '__init__')])
+
+        del inits[:]
+
+        obj = B()
+        eq_(inits, [(B, 'on_init', B), (B, '__init__'), (A, '__init__')])
+
+    def test_Ai_B(self):
+        inits = []
+
+        class A(object):
+            def __init__(self):
+                inits.append((A, '__init__'))
+        self.register(A, inits)
+
+        class B(A): pass
+        self.register(B, inits)
+
+        obj = A()
+        eq_(inits, [(A, 'on_init', A), (A, '__init__')])
+
+        del inits[:]
+
+        obj = B()
+        eq_(inits, [(B, 'on_init', B), (A, '__init__')])
+
+    def test_Ai_Bi_Ci(self):
+        inits = []
+
+        class A(object):
+            def __init__(self):
+                inits.append((A, '__init__'))
+        self.register(A, inits)
+
+        class B(A):
+            def __init__(self):
+                inits.append((B, '__init__'))
+                super(B, self).__init__()
+        self.register(B, inits)
+
+        class C(B):
+            def __init__(self):
+                inits.append((C, '__init__'))
+                super(C, self).__init__()
+        self.register(C, inits)
+
+        obj = A()
+        eq_(inits, [(A, 'on_init', A), (A, '__init__')])
+
+        del inits[:]
+
+        obj = B()
+        eq_(inits, [(B, 'on_init', B), (B, '__init__'), (A, '__init__')])
+
+        del inits[:]
+        obj = C()
+        eq_(inits, [(C, 'on_init', C), (C, '__init__'), (B, '__init__'),
+                   (A, '__init__')])
+
+    def test_Ai_bi_Ci(self):
+        inits = []
+
+        class A(object):
+            def __init__(self):
+                inits.append((A, '__init__'))
+        self.register(A, inits)
+
+        class B(A):
+            def __init__(self):
+                inits.append((B, '__init__'))
+                super(B, self).__init__()
+
+        class C(B):
+            def __init__(self):
+                inits.append((C, '__init__'))
+                super(C, self).__init__()
+        self.register(C, inits)
+
+        obj = A()
+        eq_(inits, [(A, 'on_init', A), (A, '__init__')])
+
+        del inits[:]
+
+        obj = B()
+        eq_(inits, [(B, '__init__'), (A, 'on_init', B), (A, '__init__')])
+
+        del inits[:]
+        obj = C()
+        eq_(inits, [(C, 'on_init', C), (C, '__init__'),  (B, '__init__'),
+                   (A, '__init__')])
+
+    def test_Ai_b_Ci(self):
+        inits = []
+
+        class A(object):
+            def __init__(self):
+                inits.append((A, '__init__'))
+        self.register(A, inits)
+
+        class B(A): pass
+
+        class C(B):
+            def __init__(self):
+                inits.append((C, '__init__'))
+                super(C, self).__init__()
+        self.register(C, inits)
+
+        obj = A()
+        eq_(inits, [(A, 'on_init', A), (A, '__init__')])
+
+        del inits[:]
+
+        obj = B()
+        eq_(inits, [(A, 'on_init', B), (A, '__init__')])
+
+        del inits[:]
+        obj = C()
+        eq_(inits, [(C, 'on_init', C), (C, '__init__'), (A, '__init__')])
+
+    def test_Ai_B_Ci(self):
+        inits = []
+
+        class A(object):
+            def __init__(self):
+                inits.append((A, '__init__'))
+        self.register(A, inits)
+
+        class B(A): pass
+        self.register(B, inits)
+
+        class C(B):
+            def __init__(self):
+                inits.append((C, '__init__'))
+                super(C, self).__init__()
+        self.register(C, inits)
+
+        obj = A()
+        eq_(inits, [(A, 'on_init', A), (A, '__init__')])
+
+        del inits[:]
+
+        obj = B()
+        eq_(inits, [(B, 'on_init', B), (A, '__init__')])
+
+        del inits[:]
+        obj = C()
+        eq_(inits, [(C, 'on_init', C), (C, '__init__'), (A, '__init__')])
+
+    def test_Ai_B_C(self):
+        inits = []
+
+        class A(object):
+            def __init__(self):
+                inits.append((A, '__init__'))
+        self.register(A, inits)
+
+        class B(A): pass
+        self.register(B, inits)
+
+        class C(B): pass
+        self.register(C, inits)
+
+        obj = A()
+        eq_(inits, [(A, 'on_init', A), (A, '__init__')])
+
+        del inits[:]
+
+        obj = B()
+        eq_(inits, [(B, 'on_init', B), (A, '__init__')])
+
+        del inits[:]
+        obj = C()
+        eq_(inits, [(C, 'on_init', C), (A, '__init__')])
+
+    def test_A_Bi_C(self):
+        inits = []
+
+        class A(object): pass
+        self.register(A, inits)
+
+        class B(A):
+            def __init__(self):
+                inits.append((B, '__init__'))
+        self.register(B, inits)
+
+        class C(B): pass
+        self.register(C, inits)
+
+        obj = A()
+        eq_(inits, [(A, 'on_init', A)])
+
+        del inits[:]
+
+        obj = B()
+        eq_(inits, [(B, 'on_init', B), (B, '__init__')])
+
+        del inits[:]
+        obj = C()
+        eq_(inits, [(C, 'on_init', C), (B, '__init__')])
+
+    def test_A_B_Ci(self):
+        inits = []
+
+        class A(object): pass
+        self.register(A, inits)
+
+        class B(A): pass
+        self.register(B, inits)
+
+        class C(B):
+            def __init__(self):
+                inits.append((C, '__init__'))
+        self.register(C, inits)
+
+        obj = A()
+        eq_(inits, [(A, 'on_init', A)])
+
+        del inits[:]
+
+        obj = B()
+        eq_(inits, [(B, 'on_init', B)])
+
+        del inits[:]
+        obj = C()
+        eq_(inits, [(C, 'on_init', C), (C, '__init__')])
+
+    def test_A_B_C(self):
+        inits = []
+
+        class A(object): pass
+        self.register(A, inits)
+
+        class B(A): pass
+        self.register(B, inits)
+
+        class C(B): pass
+        self.register(C, inits)
+
+        obj = A()
+        eq_(inits, [(A, 'on_init', A)])
+
+        del inits[:]
+
+        obj = B()
+        eq_(inits, [(B, 'on_init', B)])
+
+        del inits[:]
+        obj = C()
+        eq_(inits, [(C, 'on_init', C)])
+
+
+class MapperInitTest(TestBase):
+
+    def fixture(self):
+        return Table('t', MetaData(),
+                     Column('id', Integer, primary_key=True),
+                     Column('type', Integer),
+                     Column('x', Integer),
+                     Column('y', Integer))
+
+    def test_partially_mapped_inheritance(self):
+        class A(object):
+            pass
+
+        class B(A):
+            pass
+
+        class C(B):
+            def __init__(self):
+                pass
+
+        mapper(A, self.fixture())
+
+        a = attributes.instance_state(A())
+        assert isinstance(a, attributes.InstanceState)
+        assert type(a) is not attributes.InstanceState
+
+        b = attributes.instance_state(B())
+        assert isinstance(b, attributes.InstanceState)
+        assert type(b) is not attributes.InstanceState
+
+        # C is unmanaged
+        cobj = C()
+        self.assertRaises((AttributeError, TypeError),
+                          attributes.instance_state, cobj)
+
+class InstrumentationCollisionTest(TestBase):
+    def test_none(self):
+        class A(object): pass
+        attributes.register_class(A)
+
+        mgr_factory = lambda cls: attributes.ClassManager(cls)
+        class B(object):
+            __sa_instrumentation_manager__ = staticmethod(mgr_factory)
+        attributes.register_class(B)
+
+        class C(object):
+            __sa_instrumentation_manager__ = attributes.ClassManager
+        attributes.register_class(C)
+
+    def test_single_down(self):
+        class A(object): pass
+        attributes.register_class(A)
+
+        mgr_factory = lambda cls: attributes.ClassManager(cls)
+        class B(A):
+            __sa_instrumentation_manager__ = staticmethod(mgr_factory)
+
+        self.assertRaises(TypeError, attributes.register_class, B)
+
+    def test_single_up(self):
+
+        class A(object): pass
+        # delay registration
+
+        mgr_factory = lambda cls: attributes.ClassManager(cls)
+        class B(A):
+            __sa_instrumentation_manager__ = staticmethod(mgr_factory)
+        attributes.register_class(B)
+        self.assertRaises(TypeError, attributes.register_class, A)
+
+    def test_diamond_b1(self):
+        mgr_factory = lambda cls: attributes.ClassManager(cls)
+
+        class A(object): pass
+        class B1(A): pass
+        class B2(A):
+            __sa_instrumentation_manager__ = mgr_factory
+        class C(object): pass
+
+        self.assertRaises(TypeError, attributes.register_class, B1)
+
+    def test_diamond_b2(self):
+        mgr_factory = lambda cls: attributes.ClassManager(cls)
+
+        class A(object): pass
+        class B1(A): pass
+        class B2(A):
+            __sa_instrumentation_manager__ = mgr_factory
+        class C(object): pass
+
+        self.assertRaises(TypeError, attributes.register_class, B2)
+
+    def test_diamond_c_b(self):
+        mgr_factory = lambda cls: attributes.ClassManager(cls)
+
+        class A(object): pass
+        class B1(A): pass
+        class B2(A):
+            __sa_instrumentation_manager__ = mgr_factory
+        class C(object): pass
+
+        attributes.register_class(C)
+        self.assertRaises(TypeError, attributes.register_class, B1)
+
+
+class OnLoadTest(TestBase):
+    """Check that Events.on_load is not hit in regular attributes operations."""
+
+    def test_basic(self):
+        import pickle
+
+        global A
+        class A(object):
+            pass
+
+        def canary(instance): assert False
+
+        try:
+            attributes.register_class(A)
+            manager = attributes.manager_of_class(A)
+            manager.events.add_listener('on_load', canary)
+
+            a = A()
+            p_a = pickle.dumps(a)
+            re_a = pickle.loads(p_a)
+        finally:
+            del A
+
+
+class ExtendedEventsTest(TestBase):
+    """Allow custom Events implementations."""
+
+    @modifies_instrumentation_finders
+    def test_subclassed(self):
+        class MyEvents(attributes.Events):
+            pass
+        class MyClassManager(attributes.ClassManager):
+            event_registry_factory = MyEvents
+
+        attributes.instrumentation_finders.insert(0, lambda cls: MyClassManager)
+
+        class A(object): pass
+
+        attributes.register_class(A)
+        manager = attributes.manager_of_class(A)
+        assert isinstance(manager.events, MyEvents)
+
+
+class NativeInstrumentationTest(TestBase):
+    @with_lookup_strategy(util.symbol('native'))
+    def test_register_reserved_attribute(self):
+        class T(object): pass
+
+        attributes.register_class(T)
+        manager = attributes.manager_of_class(T)
+
+        sa = attributes.ClassManager.STATE_ATTR
+        ma = attributes.ClassManager.MANAGER_ATTR
+
+        fails = lambda method, attr: self.assertRaises(
+            KeyError, getattr(manager, method), attr, property())
+
+        fails('install_member', sa)
+        fails('install_member', ma)
+        fails('install_descriptor', sa)
+        fails('install_descriptor', ma)
+
+    @with_lookup_strategy(util.symbol('native'))
+    def test_mapped_stateattr(self):
+        t = Table('t', MetaData(),
+                  Column('id', Integer, primary_key=True),
+                  Column(attributes.ClassManager.STATE_ATTR, Integer))
+
+        class T(object): pass
+
+        self.assertRaises(KeyError, mapper, T, t)
+
+    @with_lookup_strategy(util.symbol('native'))
+    def test_mapped_managerattr(self):
+        t = Table('t', MetaData(),
+                  Column('id', Integer, primary_key=True),
+                  Column(attributes.ClassManager.MANAGER_ATTR, Integer))
+
+        class T(object): pass
+        self.assertRaises(KeyError, mapper, T, t)
+
+
+class MiscTest(TestBase):
+    """Seems basic, but not directly covered elsewhere!"""
+
+    def test_compileonattr(self):
+        t = Table('t', MetaData(),
+                  Column('id', Integer, primary_key=True),
+                  Column('x', Integer))
+        class A(object): pass
+        mapper(A, t)
+
+        a = A()
+        assert a.id is None
+
+    def test_compileonattr_rel(self):
+        m = MetaData()
+        t1 = Table('t1', m,
+                   Column('id', Integer, primary_key=True),
+                   Column('x', Integer))
+        t2 = Table('t2', m,
+                   Column('id', Integer, primary_key=True),
+                   Column('t1_id', Integer, ForeignKey('t1.id')))
+        class A(object): pass
+        class B(object): pass
+        mapper(A, t1, properties=dict(bs=relation(B)))
+        mapper(B, t2)
+
+        a = A()
+        assert not a.bs
+
+    def test_compileonattr_rel_backref_a(self):
+        m = MetaData()
+        t1 = Table('t1', m,
+                   Column('id', Integer, primary_key=True),
+                   Column('x', Integer))
+        t2 = Table('t2', m,
+                   Column('id', Integer, primary_key=True),
+                   Column('t1_id', Integer, ForeignKey('t1.id')))
+
+        class Base(object):
+            def __init__(self, *args, **kwargs):
+                pass
+
+        for base in object, Base:
+            class A(base): pass
+            class B(base): pass
+            mapper(A, t1, properties=dict(bs=relation(B, backref='a')))
+            mapper(B, t2)
+
+            b = B()
+            assert b.a is None
+            a = A()
+            b.a = a
+
+            session = create_session()
+            session.save(b)
+            assert a in session, "base is %s" % base
+
+    def test_compileonattr_rel_backref_b(self):
+        m = MetaData()
+        t1 = Table('t1', m,
+                   Column('id', Integer, primary_key=True),
+                   Column('x', Integer))
+        t2 = Table('t2', m,
+                   Column('id', Integer, primary_key=True),
+                   Column('t1_id', Integer, ForeignKey('t1.id')))
+
+        class Base(object):
+            def __init__(self): pass
+        class Base_AKW(object):
+            def __init__(self, *args, **kwargs): pass
+
+        for base in object, Base, Base_AKW:
+            class A(base): pass
+            class B(base): pass
+            mapper(A, t1)
+            mapper(B, t2, properties=dict(a=relation(A, backref='bs')))
+
+            a = A()
+            b = B()
+            b.a = a
+
+            session = create_session()
+            session.save(a)
+            assert b in session, 'base: %s' % base
+
+    def test_compileonattr_rel_entity_name(self):
+        m = MetaData()
+        t1 = Table('t1', m,
+                   Column('id', Integer, primary_key=True),
+                   Column('x', Integer))
+        t2 = Table('t2', m,
+                   Column('id', Integer, primary_key=True),
+                   Column('t1_id', Integer, ForeignKey('t1.id')))
+        class A(object): pass
+        class B(object): pass
+        mapper(A, t1, properties=dict(bs=relation(B)), entity_name='x')
+        mapper(B, t2)
+
+        a = A()
+        assert not a.bs
+
+class FinderTest(TestBase):
+    def test_standard(self):
+        class A(object): pass
+
+        attributes.register_class(A)
+
+        eq_(type(attributes.manager_of_class(A)), attributes.ClassManager)
+
+    def test_nativeext_interfaceexact(self):
+        class A(object):
+            __sa_instrumentation_manager__ = interfaces.InstrumentationManager
+
+        attributes.register_class(A)
+        ne_(type(attributes.manager_of_class(A)), attributes.ClassManager)
+
+    def test_nativeext_submanager(self):
+        class Mine(attributes.ClassManager): pass
+        class A(object):
+            __sa_instrumentation_manager__ = Mine
+
+        attributes.register_class(A)
+        eq_(type(attributes.manager_of_class(A)), Mine)
+
+    @modifies_instrumentation_finders
+    def test_customfinder_greedy(self):
+        class Mine(attributes.ClassManager): pass
+        class A(object): pass
+        def find(cls):
+            return Mine
+
+        attributes.instrumentation_finders.insert(0, find)
+        attributes.register_class(A)
+        eq_(type(attributes.manager_of_class(A)), Mine)
+
+    @modifies_instrumentation_finders
+    def test_customfinder_pass(self):
+        class A(object): pass
+        def find(cls):
+            return None
+
+        attributes.instrumentation_finders.insert(0, find)
+        attributes.register_class(A)
+        eq_(type(attributes.manager_of_class(A)), attributes.ClassManager)
+
+
+if __name__ == "__main__":
+    testenv.main()
index 55d79fd32b94f8d458445a2db31d2db4525e7fb6..1dd5d5e942257d488d57de4169edb1b8e81a2067 100644 (file)
@@ -2,12 +2,13 @@
 
 import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
-from sqlalchemy import exceptions
+from sqlalchemy import exc as sa_exc
 from sqlalchemy.orm import *
 from testlib import *
 from testlib.fixtures import *
 from query import QueryTest
 import datetime
+from sqlalchemy.orm import attributes
 
 class LazyTest(FixtureTest):
     keep_mappers = False
@@ -21,35 +22,17 @@ class LazyTest(FixtureTest):
         q = sess.query(User)
         assert [User(id=7, addresses=[Address(id=1, email_address='jack@bean.com')])] == q.filter(users.c.id == 7).all()
 
-    @testing.uses_deprecated('SessionContext')
-    def test_bindstosession(self):
-        """test that lazy loaders use the mapper's contextual session if the parent instance
-        is not in a session, and that an error is raised if no contextual session"""
-
-        from sqlalchemy.ext.sessioncontext import SessionContext
-        ctx = SessionContext(create_session)
-        m = mapper(User, users, properties = dict(
-            addresses = relation(mapper(Address, addresses, extension=ctx.mapper_extension), lazy=True)
-        ), extension=ctx.mapper_extension)
-        q = ctx.current.query(m)
-        u = q.filter(users.c.id == 7).first()
-        ctx.current.expunge(u)
-        assert User(id=7, addresses=[Address(id=1, email_address='jack@bean.com')]) == u
-
-        clear_mappers()
+    def test_needs_parent(self):
+        """test the error raised when parent object is not bound."""
 
         mapper(User, users, properties={
             'addresses':relation(mapper(Address, addresses), lazy=True)
         })
-        try:
-            sess = create_session()
-            q = sess.query(User)
-            u = q.filter(users.c.id == 7).first()
-            sess.expunge(u)
-            assert User(id=7, addresses=[Address(id=1, email_address='jack@bean.com')]) == u
-            assert False
-        except exceptions.InvalidRequestError, err:
-            assert "not bound to a Session, and no contextual session" in str(err)
+        sess = create_session()
+        q = sess.query(User)
+        u = q.filter(users.c.id == 7).first()
+        sess.expunge(u)
+        self.assertRaises(sa_exc.InvalidRequestError, getattr, u, 'addresses')
 
     def test_orderby(self):
         mapper(User, users, properties = {
@@ -127,8 +110,8 @@ class LazyTest(FixtureTest):
 
         sess = create_session()
         user = sess.query(User).get(7)
-        assert getattr(User, 'addresses').hasparent(user.addresses[0], optimistic=True)
-        assert not class_mapper(Address)._is_orphan(user.addresses[0])
+        assert getattr(User, 'addresses').hasparent(attributes.instance_state(user.addresses[0]), optimistic=True)
+        assert not class_mapper(Address)._is_orphan(attributes.instance_state(user.addresses[0]))
 
 
     def test_limit(self):
@@ -170,7 +153,7 @@ class LazyTest(FixtureTest):
         u2 = users.alias('u2')
         s = union_all(u2.select(use_labels=True), u2.select(use_labels=True), u2.select(use_labels=True)).alias('u')
         print [key for key in s.c.keys()]
-        l = q.filter(s.c.u2_id==User.c.id).distinct().all()
+        l = q.filter(s.c.u2_id==User.id).distinct().all()
         assert fixtures.user_all_result == l
 
     def test_one_to_many_scalar(self):
index ca6410533d0540c28e497d1ba18cc400b3167830..e8580af4a208302f170ee36299111c4b9d3e86ed 100644 (file)
@@ -1,8 +1,8 @@
 import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
+from sqlalchemy import exc as sa_exc
 from sqlalchemy.orm import *
 from testlib import *
-from sqlalchemy import exceptions
 
 class Place(object):
     '''represents a place'''
@@ -75,14 +75,7 @@ class M2MTest(ORMTest):
         mapper(Transition, transition, properties={
             'places':relation(Place, secondary=place_input, backref='transitions')
         })
-        try:
-            compile_mappers()
-            assert False
-        except exceptions.ArgumentError, e:
-            assert str(e) in [
-                "Error creating backref 'transitions' on relation 'Transition.places (Place)': property of that name exists on mapper 'Mapper|Place|place'",
-                "Error creating backref 'places' on relation 'Place.transitions (Transition)': property of that name exists on mapper 'Mapper|Transition|transition'"
-            ]
+        self.assertRaisesMessage(sa_exc.ArgumentError, "Error creating backref", compile_mappers)
 
 
     def testcircular(self):
index 7dce096145ca4e1a4ff0542abbf2fb966935bf48..017b2534cfae207b47662f770ab5376c414dd531 100644 (file)
@@ -2,9 +2,8 @@
 
 import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
-from sqlalchemy import exceptions, sql
+from sqlalchemy import exc as sa_exc, sql
 from sqlalchemy.orm import *
-from sqlalchemy.ext.sessioncontext import SessionContext, SessionContextExt
 from testlib import *
 from testlib import fixtures
 from testlib.tables import *
@@ -32,15 +31,44 @@ class MapperTest(MapperSuperTest):
             properties={
             'addresses':relation(Address, backref='email_address')
         })
-        self.assertRaises(exceptions.ArgumentError, compile_mappers)
+        self.assertRaises(sa_exc.ArgumentError, compile_mappers)
 
     def test_prop_accessor(self):
         mapper(User, users)
         self.assertRaises(NotImplementedError, getattr, class_mapper(User), 'properties')
 
+    @testing.uses_deprecated(
+        'Call to deprecated function _instance_key',
+        'Call to deprecated function _sa_session_id',
+        'Call to deprecated function _entity_name')
+    def test_legacy_accesors(self):
+        u1 = User()
+        assert not hasattr(u1, '_instance_key')
+        assert not hasattr(u1, '_sa_session_id')
+        assert not hasattr(u1, '_entity_name')
+
+        mapper(User, users)
+        u1 = User()
+        assert not hasattr(u1, '_instance_key')
+        assert not hasattr(u1, '_sa_session_id')
+        assert u1._entity_name is None
+
+        sess = create_session()
+        sess.save(u1)
+        assert not hasattr(u1, '_instance_key')
+        assert u1._sa_session_id == sess.hash_key
+        assert u1._entity_name is None
+
+        sess.flush()
+        assert u1._instance_key == class_mapper(u1).identity_key_from_instance(u1)
+        assert u1._sa_session_id == sess.hash_key
+        assert u1._entity_name is None
+        sess.delete(u1)
+        sess.flush()
+
     def test_badcascade(self):
         mapper(Address, addresses)
-        self.assertRaises(exceptions.ArgumentError, relation, Address, cascade="fake, all, delete-orphan")
+        self.assertRaises(sa_exc.ArgumentError, relation, Address, cascade="fake, all, delete-orphan")
 
     def test_columnprefix(self):
         mapper(User, users, column_prefix='_', properties={
@@ -56,26 +84,27 @@ class MapperTest(MapperSuperTest):
 
     def test_no_pks(self):
         s = select([users.c.user_name]).alias('foo')
-        self.assertRaises(exceptions.ArgumentError, mapper, User, s)
-    
+        self.assertRaises(sa_exc.ArgumentError, mapper, User, s)
+
     def test_recompile_on_othermapper(self):
-        """test the global '_new_mappers' flag such that a compile 
+        """test the global '_new_mappers' flag such that a compile
         trigger on an already-compiled mapper still triggers a check against all mappers."""
 
         from sqlalchemy.orm import mapperlib
-        
+
         mapper(User, users)
         compile_mappers()
         assert mapperlib._new_mappers is False
-        
-        m = mapper(Address, addresses, properties={'user':relation(User, backref="addresses")})
-        
-        assert m._Mapper__props_init is False
+
+        m = mapper(Address, addresses, properties={
+                'user': relation(User, backref="addresses")})
+
+        assert m.compiled is False
         assert mapperlib._new_mappers is True
         u = User()
         assert User.addresses
         assert mapperlib._new_mappers is False
-    
+
     def test_compileonsession(self):
         m = mapper(User, users)
         session = create_session()
@@ -95,7 +124,7 @@ class MapperTest(MapperSuperTest):
     def test_badconstructor(self):
         """test that if the construction of a mapped class fails, the instnace does not get placed in the session"""
         class Foo(object):
-            def __init__(self, one, two):
+            def __init__(self, one, two, _sa_session=None):
                 pass
         mapper(Foo, users)
         sess = create_session()
@@ -103,14 +132,13 @@ class MapperTest(MapperSuperTest):
         assert len(list(sess)) == 0
         self.assertRaises(TypeError, Foo, 'one')
 
-    @testing.uses_deprecated('SessionContext', 'SessionContextExt')
-    def test_constructorexceptions(self):
+    def test_constructorexc(self):
         """test that exceptions raised in the mapped class are not masked by sa decorations"""
         ex = AssertionError('oops')
         sess = create_session()
 
         class Foo(object):
-            def __init__(self):
+            def __init__(self, **kw):
                 raise ex
         mapper(Foo, users)
 
@@ -121,7 +149,7 @@ class MapperTest(MapperSuperTest):
             assert e is ex
 
         clear_mappers()
-        mapper(Foo, users, extension=SessionContextExt(SessionContext()))
+        mapper(Foo, users, extension=scoped_session(create_session).extension)
         def bad_expunge(foo):
             raise Exception("this exception should be stated as a warning")
 
@@ -130,7 +158,7 @@ class MapperTest(MapperSuperTest):
             Foo(_sa_session=sess)
             assert False
         except Exception, e:
-            assert isinstance(e, exceptions.SAWarning)
+            assert isinstance(e, sa_exc.SAWarning), e
 
         clear_mappers()
 
@@ -172,7 +200,7 @@ class MapperTest(MapperSuperTest):
         mapper(User, users, properties = {
             'addresses' : relation(mapper(Address, addresses))
         })
-        assert (User.user_id==3).compare(users.c.user_id==3)
+        self.assertEquals((User.user_id==3).__str__(), (users.c.user_id==3).__str__())
 
         clear_mappers()
 
@@ -232,7 +260,7 @@ class MapperTest(MapperSuperTest):
         m.add_property('uc_user_name2', comparable_property(
                 UCComparator, User.uc_user_name2))
 
-        sess = create_session(transactional=True)
+        sess = create_session(autocommit=False)
         assert sess.query(User).get(7)
 
         u = sess.query(User).filter_by(user_name='jack').one()
@@ -337,14 +365,14 @@ class MapperTest(MapperSuperTest):
                 'addresses':relation(Address)
             }).compile()
             assert False
-        except exceptions.ArgumentError, e:
+        except sa_exc.ArgumentError, e:
             assert "Attempting to assign a new relation 'addresses' to a non-primary mapper on class 'User'" in str(e)
 
     def test_illegal_non_primary_2(self):
         try:
             mapper(User, users, non_primary=True)
             assert False
-        except exceptions.InvalidRequestError, e:
+        except sa_exc.InvalidRequestError, e:
             assert "Configure a primary mapper first" in str(e)
 
     def test_propfilters(self):
@@ -386,7 +414,6 @@ class MapperTest(MapperSuperTest):
         def assert_props(cls, want):
             have = set([n for n in dir(cls) if not n.startswith('_')])
             want = set(want)
-            want.add('c')
             self.assert_(have == want, repr(have) + " " + repr(want))
 
         assert_props(Person, ['id', 'name', 'type'])
@@ -398,16 +425,6 @@ class MapperTest(MapperSuperTest):
         assert_props(Hoho, ['id', 'name', 'type'])
         assert_props(Lala, ['p_employee_number', 'p_id', 'p_name', 'p_type'])
 
-    @testing.uses_deprecated('//select_by', '//join_via', '//list')
-    def test_recursive_select_by_deprecated(self):
-        """test that no endless loop occurs when traversing for select_by"""
-        m = mapper(User, users, properties={
-            'orders':relation(mapper(Order, orders), backref='user'),
-            'addresses':relation(mapper(Address, addresses), backref='user'),
-        })
-        q = create_session().query(m)
-        q.select_by(email_address='foo')
-
     def test_mappingtojoin(self):
         """test mapping to a join"""
         usersaddresses = sql.join(users, addresses, users.c.user_id == addresses.c.user_id)
@@ -472,21 +489,6 @@ class MapperTest(MapperSuperTest):
 
         self.assert_result(l, User, user_result[0])
 
-    @testing.uses_deprecated('//select')
-    def test_customjoin_deprecated(self):
-        """test that the from_obj parameter to query.select() can be used
-        to totally replace the FROM parameters of the generated query."""
-
-        m = mapper(User, users, properties={
-            'orders':relation(mapper(Order, orders, properties={
-                'items':relation(mapper(Item, orderitems))
-            }))
-        })
-
-        q = create_session().query(m)
-        l = q.select((orderitems.c.item_name=='item 4'), from_obj=[users.join(orders).join(orderitems)])
-        self.assert_result(l, User, user_result[0])
-
     def test_orderby(self):
         """test ordering at the mapper and query level"""
 
@@ -527,21 +529,14 @@ class MapperTest(MapperSuperTest):
         mapper(User, users)
         q = create_session().query(User)
         self.assert_(q.count()==3)
-        self.assert_(q.count(users.c.user_id.in_([8,9]))==2)
-
-    @testing.unsupported('firebird')
-    @testing.uses_deprecated('//count_by', '//join_by', '//join_via')
-    def test_count_by_deprecated(self):
-        mapper(User, users)
-        q = create_session().query(User)
-        self.assert_(q.count_by(user_name='fred')==1)
+        self.assert_(q.filter(users.c.user_id.in_([8,9])).count()==2)
 
     def test_manytomany_count(self):
         mapper(Item, orderitems, properties = dict(
                 keywords = relation(mapper(Keyword, keywords), itemkeywords, lazy = True),
             ))
         q = create_session().query(Item)
-        assert q.join('keywords').distinct().count(Keyword.c.name=="red") == 2
+        assert q.join('keywords').distinct().filter(Keyword.name=="red").count() == 2
 
     def test_override(self):
         # assert that overriding a column raises an error
@@ -550,7 +545,7 @@ class MapperTest(MapperSuperTest):
                     'user_name' : relation(mapper(Address, addresses)),
                 }).compile()
             self.assert_(False, "should have raised ArgumentError")
-        except exceptions.ArgumentError, e:
+        except sa_exc.ArgumentError, e:
             self.assert_(True)
 
         clear_mappers()
@@ -601,8 +596,8 @@ class MapperTest(MapperSuperTest):
         self.assert_result(u.adlist, Address, *(user_address_result[0]['addresses'][1]))
 
         addr = sess.query(Address).filter_by(address_id=user_address_result[0]['addresses'][1][0]['address_id']).one()
-        u = sess.query(User).filter_by(adname=addr).one()
-        u2 = sess.query(User).filter_by(adlist=addr).one()
+        u = sess.query(User).filter(User.adname.contains(addr)).one()
+        u2 = sess.query(User).filter(User.adlist.contains(addr)).one()
 
         assert u is u2
 
@@ -641,7 +636,7 @@ class MapperTest(MapperSuperTest):
             })
             User.not_user_name
             assert False
-        except exceptions.ArgumentError, e:
+        except sa_exc.ArgumentError, e:
             assert str(e) == "Can't compile synonym '_user_name': no column on table 'users' named 'not_user_name'"
 
         clear_mappers()
@@ -742,33 +737,6 @@ class OptionsTest(MapperSuperTest):
             self.assert_result(u.adlist, Address, *(user_address_result[0]['addresses'][1]))
         self.assert_sql_count(testing.db, go, 1)
 
-    @testing.uses_deprecated('//select_by')
-    def test_extension_options(self):
-        sess  = create_session()
-        class ext1(MapperExtension):
-            def populate_instance(self, mapper, selectcontext, row, instance, **flags):
-                """test options at the Mapper._instance level"""
-                instance.TEST = "hello world"
-                return EXT_CONTINUE
-        mapper(User, users, extension=ext1(), properties={
-            'addresses':relation(mapper(Address, addresses), lazy=False)
-        })
-        class testext(MapperExtension):
-            def select_by(self, *args, **kwargs):
-                """test options at the Query level"""
-                return "HI"
-            def populate_instance(self, mapper, selectcontext, row, instance, **flags):
-                """test options at the Mapper._instance level"""
-                instance.TEST_2 = "also hello world"
-                return EXT_CONTINUE
-        l = sess.query(User).options(extension(testext())).select_by(x=5)
-        assert l == "HI"
-        l = sess.query(User).options(extension(testext())).get(7)
-        assert l.user_id == 7
-        assert l.TEST == "hello world"
-        assert l.TEST_2 == "also hello world"
-        assert not hasattr(l.addresses[0], 'TEST')
-        assert not hasattr(l.addresses[0], 'TEST2')
 
     def test_eageroptions(self):
         """tests that a lazy relation can be upgraded to an eager relation via the options method"""
@@ -927,9 +895,9 @@ class OptionsTest(MapperSuperTest):
 
         sess.clear()
 
-        self.assertRaisesMessage(exceptions.ArgumentError, 
-            r"Can't find entity Mapper\|Order\|orders in Query.  Current list: \['Mapper\|User\|users'\]", 
-            sess.query(User).options, eagerload('items', Order)
+        self.assertRaisesMessage(sa_exc.ArgumentError,
+            r"Can't find entity Mapper\|Order\|orders in Query.  Current list: \['Mapper\|User\|users'\]",
+            sess.query(User).options, eagerload(Order.items)
         )
 
         # eagerload "keywords" on items.  it will lazy load "orders", then lazy load
@@ -1333,11 +1301,29 @@ class MapperExtensionTest(TestBase):
     def setUpAll(self):
         tables.create()
 
-        global methods, Ext
+    def tearDown(self):
+        clear_mappers()
+        tables.delete()
 
+    def tearDownAll(self):
+        tables.drop()
+
+    def extension(self):
         methods = []
 
         class Ext(MapperExtension):
+            def instrument_class(self, mapper, cls):
+                methods.append('instrument_class')
+                return EXT_CONTINUE
+
+            def init_instance(self, mapper, class_, oldinit, instance, args, kwargs):
+                methods.append('init_instance')
+                return EXT_CONTINUE
+
+            def init_failed(self, mapper, class_, oldinit, instance, args, kwargs):
+                methods.append('init_failed')
+                return EXT_CONTINUE
+
             def load(self, query, *args, **kwargs):
                 methods.append('load')
                 return EXT_CONTINUE
@@ -1386,16 +1372,12 @@ class MapperExtensionTest(TestBase):
                 methods.append('after_delete')
                 return EXT_CONTINUE
 
-    def tearDown(self):
-        clear_mappers()
-        methods[:] = []
-        tables.delete()
-
-    def tearDownAll(self):
-        tables.drop()
+        return Ext, methods
 
     def test_basic(self):
         """test that common user-defined methods get called."""
+        Ext, methods = self.extension()
+
         mapper(User, users, extension=Ext())
         sess = create_session()
         u = User()
@@ -1408,13 +1390,17 @@ class MapperExtensionTest(TestBase):
         sess.flush()
         sess.delete(u)
         sess.flush()
-        self.assertEquals(methods, 
-            ['before_insert', 'after_insert', 'load', 'translate_row', 'populate_instance', 'append_result', 'get', 'translate_row', 
-            'create_instance', 'populate_instance', 'append_result', 'before_update', 'after_update', 'before_delete', 'after_delete']        
-        )
+        self.assertEquals(methods,
+            ['instrument_class', 'init_instance', 'before_insert',
+             'after_insert', 'load', 'translate_row', 'populate_instance',
+             'append_result', 'get', 'translate_row', 'create_instance',
+             'populate_instance', 'append_result', 'before_update',
+             'after_update', 'before_delete', 'after_delete'])
+
 
     def test_inheritance(self):
-        # test using inheritance
+        Ext, methods = self.extension()
+
         class AdminUser(User):
             pass
 
@@ -1432,13 +1418,18 @@ class MapperExtensionTest(TestBase):
         sess.flush()
         sess.delete(am)
         sess.flush()
-        self.assertEquals(methods, 
-        ['before_insert', 'after_insert', 'load', 'translate_row', 'populate_instance', 'append_result', 'get', 
-        'translate_row', 'create_instance', 'populate_instance', 'append_result', 'before_update', 'after_update', 'before_delete', 'after_delete'])
+        self.assertEquals(methods,
+            ['instrument_class', 'instrument_class', 'init_instance',
+             'before_insert', 'after_insert', 'load', 'translate_row',
+             'populate_instance', 'append_result', 'get', 'translate_row',
+             'create_instance', 'populate_instance', 'append_result',
+             'before_update', 'after_update', 'before_delete', 'after_delete'])
 
     def test_after_with_no_changes(self):
         # test that after_update is called even if no cols were updated
 
+        Ext, methods = self.extension()
+
         mapper(Item, orderitems, extension=Ext() , properties={
             'keywords':relation(Keyword, secondary=itemkeywords)
         })
@@ -1450,15 +1441,20 @@ class MapperExtensionTest(TestBase):
         sess.save(i1)
         sess.save(k1)
         sess.flush()
-        self.assertEquals(methods, ['before_insert', 'after_insert', 'before_insert', 'after_insert'])
+        self.assertEquals(methods,
+            ['instrument_class', 'instrument_class', 'init_instance',
+             'init_instance', 'before_insert', 'after_insert',
+             'before_insert', 'after_insert'])
 
-        methods[:] = []
+        del methods[:]
         i1.keywords.append(k1)
         sess.flush()
         self.assertEquals(methods, ['before_update', 'after_update'])
 
 
     def test_inheritance_with_dupes(self):
+        Ext, methods = self.extension()
+
         # test using inheritance, same extension on both mappers
         class AdminUser(User):
             pass
@@ -1478,10 +1474,49 @@ class MapperExtensionTest(TestBase):
         sess.flush()
         sess.delete(am)
         sess.flush()
-        self.assertEquals(methods, 
-            ['before_insert', 'after_insert', 'load', 'translate_row', 'populate_instance', 'append_result', 'get', 'translate_row', 
-            'create_instance', 'populate_instance', 'append_result', 'before_update', 'after_update', 'before_delete', 'after_delete']
-            )
+        self.assertEquals(methods,
+            ['instrument_class', 'instrument_class', 'init_instance',
+             'before_insert', 'after_insert', 'load', 'translate_row',
+             'populate_instance', 'append_result', 'get', 'translate_row',
+             'create_instance', 'populate_instance', 'append_result',
+             'before_update', 'after_update', 'before_delete', 'after_delete'])
+
+    def test_single_instrumentor(self):
+        ext_None, methods_None = self.extension()
+        ext_x, methods_x = self.extension()
+
+        def reset():
+            clear_mappers()
+            del methods_None[:]
+            del methods_x[:]
+
+        mapper(User, users, extension=ext_None())
+        mapper(User, users, extension=ext_x(), entity_name='x')
+        User()
+
+        self.assertEquals(methods_None, ['instrument_class', 'init_instance'])
+        self.assertEquals(methods_x, [])
+
+        reset()
+
+        mapper(User, users, extension=ext_x(), entity_name='x')
+        mapper(User, users, extension=ext_None())
+        User()
+
+        self.assertEquals(methods_x, ['instrument_class', 'init_instance'])
+        self.assertEquals(methods_None, [])
+
+        reset()
+
+        ext_y, methods_y = self.extension()
+
+        mapper(User, users, extension=ext_x(), entity_name='x')
+        mapper(User, users, extension=ext_y(), entity_name='y')
+        User()
+
+        self.assertEquals(methods_x, ['instrument_class', 'init_instance'])
+        self.assertEquals(methods_y, [])
+
 
 class RequirementsTest(ORMTest):
     """Tests the contract for user classes."""
@@ -1519,13 +1554,13 @@ class RequirementsTest(ORMTest):
         class OldStyle:
             pass
 
-        self.assertRaises(exceptions.ArgumentError, mapper, OldStyle, t1)
+        self.assertRaises(sa_exc.ArgumentError, mapper, OldStyle, t1)
 
         class NoWeakrefSupport(str):
             pass
 
         # TODO: is weakref support detectable without an instance?
-        #self.assertRaises(exceptions.ArgumentError, mapper, NoWeakrefSupport, t2)
+        #self.assertRaises(sa_exc.ArgumentError, mapper, NoWeakrefSupport, t2)
 
     def test_comparison_overrides(self):
         """Simple tests to ensure users can supply comparison __methods__.
@@ -1584,7 +1619,6 @@ class RequirementsTest(ORMTest):
                     return self.value == other.value
                 return False
 
-                
         mapper(H1, t1, properties={
             'h2s': relation(H2, backref='h1'),
             'h3s': relation(H3, secondary=t4, backref='h1s'),
@@ -1654,6 +1688,92 @@ class NoEqFoo(object):
     def __ne__(self, other):
         raise NotImplementedError()
 
+class MagicNamesTest(ORMTest):
+
+    def define_tables(self, metadata):
+        Table('cartographers', metadata,
+              Column('id', Integer, primary_key=True),
+              Column('name', String(50)),
+              Column('alias', String(50)),
+              Column('quip', String(100)))
+        Table('maps', metadata,
+              Column('id', Integer, primary_key=True),
+              Column('cart_id', Integer,
+                     ForeignKey('cartographers.id')),
+              Column('state', String(2)),
+              Column('data', Text))
+
+    def tables(self):
+        cat = testing._otest_metadata.tables
+        return cat['cartographers'], cat['maps']
+
+    def classes(self):
+        class Base(object):
+            def __init__(self, **kw):
+                for key, value in kw.iteritems():
+                    setattr(self, key, value)
+        class Cartographer(Base): pass
+        class Map(Base): pass
+
+        return Cartographer, Map
+
+    @testing.future
+    def test_mappish(self):
+        t1, t2 = self.tables()
+        Cartographer, Map = self.classes()
+        mapper(Cartographer, t1, properties=dict(
+            query=t1.c.quip))
+        mapper(Map, t2, properties=dict(
+            mapper=relation(Cartographer, backref='maps')))
+
+        c = Cartographer(name='Lenny', alias='The Dude',
+                         query='Where be dragons?')
+        m = Map(state='AK', mapper=c)
+
+        sess = create_session()
+        sess.save(c)
+        sess.flush()
+        sess.clear()
+
+        for C, M in ((Cartographer, Map), (aliased(Cartographer), aliased(Map))):
+            print C, M
+            c1 = (sess.query(C).
+                  filter(C.alias=='The Dude').
+                  filter(C.query=='Where be dragons?')).one()
+            m1 = sess.query(M).filter(M.mapper==c1).one()
+
+    @testing.future
+    def test_stateish(self):
+        from sqlalchemy.orm import attributes
+        if hasattr(attributes, 'ClassManager'):
+            syn1 = attributes.ClassManager.STATE_ATTR
+            syn2 = attributes.ClassManager.MANAGER_ATTR
+        else:
+            syn1 = '_state'
+            syn2 = '_class_state'
+
+
+        t1, t2 = self.tables()
+        Cartographer, Map = self.classes()
+        mapper(Map, t2, properties=dict(
+            syn1=t2.c.state,
+            syn2=t2.c.data))
+
+        m = Map()
+        setattr(m, syn1, 'AK')
+        setattr(m, syn2, '10x10')
+
+        sess = create_session()
+        sess.save(m)
+        sess.flush()
+        sess.clear()
+
+        for M in (Map, aliased(Map)):
+            print M
+            sess.query(M).filter(getattr(M, syn1) == 'AK').one()
+            sess.query(M).filter(getattr(M, syn2) == '10x10').one()
+
+
 class ScalarRequirementsTest(ORMTest):
     def define_tables(self, metadata):
         import pickle
@@ -1661,14 +1781,14 @@ class ScalarRequirementsTest(ORMTest):
         t1 = Table('t1', metadata, Column('id', Integer, primary_key=True),
             Column('data', PickleType(pickler=pickle))  # dont use cPickle due to import weirdness
         )
-        
+
     def test_correct_comparison(self):
-                
+
         class H1(fixtures.Base):
             pass
-            
+
         mapper(H1, t1)
-        
+
         h1 = H1(data=NoEqFoo('12345'))
         s = create_session()
         s.save(h1)
@@ -1676,7 +1796,7 @@ class ScalarRequirementsTest(ORMTest):
         s.clear()
         h1 = s.get(H1, h1.id)
         assert h1.data.data == '12345'
-        
+
 
 if __name__ == "__main__":
     testenv.main()
index fd61ccc28c4548bbfc75db94d352a259f5b3ac95..6ca42d53d996c3a6c1ad257155f782ef06b142bd 100644 (file)
@@ -1,8 +1,8 @@
 import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
-from sqlalchemy import exceptions
+from sqlalchemy import exc as sa_exc
 from sqlalchemy.orm import *
-from sqlalchemy.orm import mapperlib
+from sqlalchemy.orm import mapperlib, attributes
 from sqlalchemy.util import OrderedSet
 from testlib import *
 from testlib import fixtures
@@ -21,20 +21,34 @@ class MergeTest(TestBase, AssertsExecutionResults):
         clear_mappers()
         tables.delete()
 
+    def on_load_tracker(self, cls, canary=None):
+        if canary is None:
+            def canary(instance):
+                canary.called += 1
+            canary.called = 0
+
+        manager = attributes.manager_of_class(cls)
+        manager.events.add_listener('on_load', canary)
+
+        return canary
+
     def test_transient_to_pending(self):
         class User(fixtures.Base):
             pass
         mapper(User, users)
         sess = create_session()
+        on_load = self.on_load_tracker(User)
 
         u = User(user_id=7, user_name='fred')
+        assert on_load.called == 0
         u2 = sess.merge(u)
+        assert on_load.called == 1
         assert u2 in sess
         self.assertEquals(u2, User(user_id=7, user_name='fred'))
         sess.flush()
         sess.clear()
         self.assertEquals(sess.query(User).first(), User(user_id=7, user_name='fred'))
-    
+
     def test_transient_to_pending_collection(self):
         class User(fixtures.Base):
             pass
@@ -42,47 +56,72 @@ class MergeTest(TestBase, AssertsExecutionResults):
             pass
         mapper(User, users, properties={'addresses':relation(Address, backref='user', collection_class=OrderedSet)})
         mapper(Address, addresses)
+        on_load = self.on_load_tracker(User)
+        self.on_load_tracker(Address, on_load)
 
         u = User(user_id=7, user_name='fred', addresses=OrderedSet([
             Address(address_id=1, email_address='fred1'),
             Address(address_id=2, email_address='fred2'),
-        ]))
+            ]))
+        assert on_load.called == 0
+
         sess = create_session()
         sess.merge(u)
+        assert on_load.called == 3
+
+        merged_users = [e for e in sess if isinstance(e, User)]
+        assert len(merged_users) == 1
+        assert merged_users[0] is not u
+
         sess.flush()
         sess.clear()
 
-        self.assertEquals(sess.query(User).one(), 
+        self.assertEquals(sess.query(User).one(),
             User(user_id=7, user_name='fred', addresses=OrderedSet([
                 Address(address_id=1, email_address='fred1'),
                 Address(address_id=2, email_address='fred2'),
             ]))
         )
-        
+
     def test_transient_to_persistent(self):
         class User(fixtures.Base):
             pass
         mapper(User, users)
+        on_load = self.on_load_tracker(User)
+
         sess = create_session()
         u = User(user_id=7, user_name='fred')
         sess.save(u)
         sess.flush()
         sess.clear()
-        
-        u2 = User(user_id=7, user_name='fred jones')
+
+        assert on_load.called == 0
+
+        _u2 = u2 = User(user_id=7, user_name='fred jones')
+        assert on_load.called == 0
         u2 = sess.merge(u2)
+        assert u2 is not _u2
+        assert on_load.called == 1
         sess.flush()
         sess.clear()
         self.assertEquals(sess.query(User).first(), User(user_id=7, user_name='fred jones'))
-        
+        assert on_load.called == 2
+
     def test_transient_to_persistent_collection(self):
         class User(fixtures.Base):
             pass
         class Address(fixtures.Base):
             pass
-        mapper(User, users, properties={'addresses':relation(Address, backref='user', collection_class=OrderedSet)})
+        mapper(User, users, properties={
+            'addresses':relation(Address, 
+                        backref='user', 
+                        collection_class=OrderedSet, cascade="all, delete-orphan")
+        })
         mapper(Address, addresses)
         
+        on_load = self.on_load_tracker(User)
+        self.on_load_tracker(Address, on_load)
+
         u = User(user_id=7, user_name='fred', addresses=OrderedSet([
             Address(address_id=1, email_address='fred1'),
             Address(address_id=2, email_address='fred2'),
@@ -91,14 +130,21 @@ class MergeTest(TestBase, AssertsExecutionResults):
         sess.save(u)
         sess.flush()
         sess.clear()
-        
+
+        assert on_load.called == 0
+
         u = User(user_id=7, user_name='fred', addresses=OrderedSet([
             Address(address_id=3, email_address='fred3'),
             Address(address_id=4, email_address='fred4'),
         ]))
-        
+
         u = sess.merge(u)
-        self.assertEquals(u, 
+        
+        assert on_load.called == 5, on_load.called    # 1. merges User object.  updates into session.
+                                                      # 2.,3. merges Address ids 3 & 4, saves into session.
+                                                      # 4.,5. loads pre-existing elements in "addresses" collection, 
+                                                      # marks as deleted, Address ids 1 and 2.
+        self.assertEquals(u,
             User(user_id=7, user_name='fred', addresses=OrderedSet([
                 Address(address_id=3, email_address='fred3'),
                 Address(address_id=4, email_address='fred4'),
@@ -106,13 +152,13 @@ class MergeTest(TestBase, AssertsExecutionResults):
         )
         sess.flush()
         sess.clear()
-        self.assertEquals(sess.query(User).one(), 
+        self.assertEquals(sess.query(User).one(),
             User(user_id=7, user_name='fred', addresses=OrderedSet([
                 Address(address_id=3, email_address='fred3'),
                 Address(address_id=4, email_address='fred4'),
             ]))
         )
-        
+
     def test_detached_to_persistent_collection(self):
         class User(fixtures.Base):
             pass
@@ -120,7 +166,9 @@ class MergeTest(TestBase, AssertsExecutionResults):
             pass
         mapper(User, users, properties={'addresses':relation(Address, backref='user', collection_class=OrderedSet)})
         mapper(Address, addresses)
-        
+        on_load = self.on_load_tracker(User)
+        self.on_load_tracker(Address, on_load)
+
         a = Address(address_id=1, email_address='fred1')
         u = User(user_id=7, user_name='fred', addresses=OrderedSet([
             a,
@@ -130,34 +178,39 @@ class MergeTest(TestBase, AssertsExecutionResults):
         sess.save(u)
         sess.flush()
         sess.clear()
-        
+
         u.user_name='fred jones'
         u.addresses.add(Address(address_id=3, email_address='fred3'))
         u.addresses.remove(a)
-        
+
+        assert on_load.called == 0
         u = sess.merge(u)
+        assert on_load.called == 4
         sess.flush()
         sess.clear()
-        
-        self.assertEquals(sess.query(User).first(), 
+
+        self.assertEquals(sess.query(User).first(),
             User(user_id=7, user_name='fred jones', addresses=OrderedSet([
                 Address(address_id=2, email_address='fred2'),
                 Address(address_id=3, email_address='fred3'),
             ]))
         )
-        
+
     def test_unsaved_cascade(self):
         """test merge of a transient entity with two child transient entities, with a bidirectional relation."""
-        
+
         class User(fixtures.Base):
             pass
         class Address(fixtures.Base):
             pass
-            
+
         mapper(User, users, properties={
             'addresses':relation(mapper(Address, addresses), cascade="all", backref="user")
         })
+        on_load = self.on_load_tracker(User)
+        self.on_load_tracker(Address, on_load)
         sess = create_session()
+
         u = User(user_id=7, user_name='fred')
         a1 = Address(email_address='foo@bar.com')
         a2 = Address(email_address='hoho@bar.com')
@@ -165,12 +218,16 @@ class MergeTest(TestBase, AssertsExecutionResults):
         u.addresses.append(a2)
 
         u2 = sess.merge(u)
+        assert on_load.called == 3
+
         self.assertEquals(u, User(user_id=7, user_name='fred', addresses=[Address(email_address='foo@bar.com'), Address(email_address='hoho@bar.com')]))
         self.assertEquals(u2, User(user_id=7, user_name='fred', addresses=[Address(email_address='foo@bar.com'), Address(email_address='hoho@bar.com')]))
         sess.flush()
         sess.clear()
         u2 = sess.query(User).get(7)
         self.assertEquals(u2, User(user_id=7, user_name='fred', addresses=[Address(email_address='foo@bar.com'), Address(email_address='hoho@bar.com')]))
+        assert on_load.called == 6
+
 
     def test_attribute_cascade(self):
         """test merge of a persistent entity with two child persistent entities."""
@@ -183,6 +240,9 @@ class MergeTest(TestBase, AssertsExecutionResults):
         mapper(User, users, properties={
             'addresses':relation(mapper(Address, addresses), backref='user')
         })
+        on_load = self.on_load_tracker(User)
+        self.on_load_tracker(Address, on_load)
+
         sess = create_session()
 
         # set up data and save
@@ -202,9 +262,12 @@ class MergeTest(TestBase, AssertsExecutionResults):
         u.user_name = 'fred2'
         u.addresses[1].email_address = 'hoho@lalala.com'
 
+        assert on_load.called == 3
+
         # new session, merge modified data into session
         sess3 = create_session()
         u3 = sess3.merge(u)
+        assert on_load.called == 6
 
         # ensure local changes are pending
         self.assertEquals(u3, User(user_id=7, user_name='fred2', addresses=[Address(email_address='foo@bar.com'), Address(email_address='hoho@lalala.com')]))
@@ -216,6 +279,7 @@ class MergeTest(TestBase, AssertsExecutionResults):
         sess.clear()
         u = sess.query(User).get(7)
         self.assertEquals(u, User(user_id=7, user_name='fred2', addresses=[Address(email_address='foo@bar.com'), Address(email_address='hoho@lalala.com')]))
+        assert on_load.called == 9
 
         # merge persistent object into another session
         sess4 = create_session()
@@ -227,6 +291,7 @@ class MergeTest(TestBase, AssertsExecutionResults):
             sess4.flush()
         # no changes; therefore flush should do nothing
         self.assert_sql_count(testing.db, go, 0)
+        assert on_load.called == 12
 
         # test with "dontload" merge
         sess5 = create_session()
@@ -240,6 +305,7 @@ class MergeTest(TestBase, AssertsExecutionResults):
         # but also, dont_load wipes out any difference in committed state,
         # so no flush at all
         self.assert_sql_count(testing.db, go, 0)
+        assert on_load.called == 15
 
         sess4 = create_session()
         u = sess4.merge(u, dont_load=True)
@@ -249,11 +315,13 @@ class MergeTest(TestBase, AssertsExecutionResults):
             sess4.flush()
         # afafds change flushes
         self.assert_sql_count(testing.db, go, 1)
+        assert on_load.called == 18
 
         sess5 = create_session()
         u2 = sess5.query(User).get(u.user_id)
         assert u2.user_name == 'fred2'
         assert u2.addresses[1].email_address == 'afafds'
+        assert on_load.called == 21
 
     def test_one_to_many_cascade(self):
 
@@ -265,6 +333,9 @@ class MergeTest(TestBase, AssertsExecutionResults):
             'addresses':relation(mapper(Address, addresses)),
             'orders':relation(Order, backref='customer')
         })
+        on_load = self.on_load_tracker(User)
+        self.on_load_tracker(Address, on_load)
+        self.on_load_tracker(Order, on_load)
 
         sess = create_session()
         u = User()
@@ -282,16 +353,24 @@ class MergeTest(TestBase, AssertsExecutionResults):
         sess.save(u)
         sess.flush()
 
+        assert on_load.called == 0
+
         sess2 = create_session()
         u2 = sess2.query(User).get(u.user_id)
+        assert on_load.called == 1
+
         u.orders[0].items[1].item_name = 'item 2 modified'
         sess2.merge(u)
         assert u2.orders[0].items[1].item_name == 'item 2 modified'
+        assert on_load.called == 2
+
+        sess3 = create_session()
+        o2 = sess3.query(Order).get(o.order_id)
+        assert on_load.called == 3
 
-        sess2 = create_session()
-        o2 = sess2.query(Order).get(o.order_id)
         o.customer.user_name = 'also fred'
-        sess2.merge(o)
+        sess3.merge(o)
+        assert on_load.called == 4
         assert o2.customer.user_name == 'also fred'
 
     def test_one_to_one_cascade(self):
@@ -299,7 +378,10 @@ class MergeTest(TestBase, AssertsExecutionResults):
         mapper(User, users, properties={
             'address':relation(mapper(Address, addresses),uselist = False)
         })
+        on_load = self.on_load_tracker(User)
+        self.on_load_tracker(Address, on_load)
         sess = create_session()
+
         u = User()
         u.user_id = 7
         u.user_name = "fred"
@@ -310,19 +392,25 @@ class MergeTest(TestBase, AssertsExecutionResults):
         sess.save(u)
         sess.flush()
 
+        assert on_load.called == 0
+
         sess2 = create_session()
         u2 = sess2.query(User).get(7)
+        assert on_load.called == 1
         u2.user_name = 'fred2'
         u2.address.email_address = 'hoho@lalala.com'
+        assert on_load.called == 2
 
         u3 = sess.merge(u2)
-    
+        assert on_load.called == 2
+        assert u3 is u
+
     def test_transient_dontload(self):
         mapper(User, users)
 
         sess = create_session()
         u = User()
-        self.assertRaisesMessage(exceptions.InvalidRequestError, "dont_load=True option does not support", sess.merge, u, dont_load=True)
+        self.assertRaisesMessage(sa_exc.InvalidRequestError, "dont_load=True option does not support", sess.merge, u, dont_load=True)
 
 
     def test_dontload_with_backrefs(self):
@@ -407,7 +495,7 @@ class MergeTest(TestBase, AssertsExecutionResults):
         try:
             sess2.merge(u, dont_load=True)
             assert False
-        except exceptions.InvalidRequestError, e:
+        except sa_exc.InvalidRequestError, e:
             assert "merge() with dont_load=True option does not support objects marked as 'dirty'.  flush() all changes on mapped instances before merging with dont_load=True." in str(e)
 
         u2 = sess2.query(User).get(7)
@@ -443,7 +531,8 @@ class MergeTest(TestBase, AssertsExecutionResults):
         u2 = sess2.merge(u, dont_load=True)
         assert not sess2.dirty
         # assert merged instance has a mapper and lazy load proceeds
-        assert hasattr(u2, '_entity_name')
+        state = attributes.instance_state(u2)
+        assert state.entity_name is not attributes.NO_ENTITY_NAME
         assert mapperlib.has_mapper(u2)
         def go():
             assert u2.addresses != []
@@ -505,7 +594,7 @@ class MergeTest(TestBase, AssertsExecutionResults):
         assert not sess2.dirty
         a2 = u2.addresses[0]
         a2.email_address='somenewaddress'
-        assert not object_mapper(a2)._is_orphan(a2)
+        assert not object_mapper(a2)._is_orphan(attributes.instance_state(a2))
         sess2.flush()
         sess2.clear()
         assert sess2.query(User).get(u2.user_id).addresses[0].email_address == 'somenewaddress'
@@ -526,11 +615,11 @@ class MergeTest(TestBase, AssertsExecutionResults):
             # if dont_load is changed to support dirty objects, this code needs to pass
             a2 = u2.addresses[0]
             a2.email_address='somenewaddress'
-            assert not object_mapper(a2)._is_orphan(a2)
+            assert not object_mapper(a2)._is_orphan(attributes.instance_state(a2))
             sess2.flush()
             sess2.clear()
             assert sess2.query(User).get(u2.user_id).addresses[0].email_address == 'somenewaddress'
-        except exceptions.InvalidRequestError, e:
+        except sa_exc.InvalidRequestError, e:
             assert "dont_load=True option does not support" in str(e)
 
 
index ec7d2fca9915ae662bfef6e6678b1fddb52801d8..67cf5e9adc0996baf9ba0bfff0328dfad426ded5 100644 (file)
@@ -1,8 +1,7 @@
 import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
 from sqlalchemy.orm import *
-from sqlalchemy import exceptions
-
+from sqlalchemy.orm import attributes, exc as orm_exc
 from testlib.fixtures import *
 from testlib import *
 
@@ -62,17 +61,13 @@ class NaturalPKTest(ORMTest):
         sess.flush()
         assert sess.get(User, 'jack') is u1
 
-        users.update(values={u1.c.username:'jack'}).execute(username='ed')
+        users.update(values={User.username:'jack'}).execute(username='ed')
 
-        try:
-            # expire/refresh works off of primary key.  the PK is gone
-            # in this case so theres no way to look it up.  criterion-
-            # based session invalidation could solve this [ticket:911]
-            sess.expire(u1)
-            u1.username
-            assert False
-        except exceptions.InvalidRequestError, e:
-            assert "Could not refresh instance" in str(e)
+        # expire/refresh works off of primary key.  the PK is gone
+        # in this case so theres no way to look it up.  criterion-
+        # based session invalidation could solve this [ticket:911]
+        sess.expire(u1)
+        self.assertRaises(orm_exc.ObjectDeletedError, getattr, u1, 'username')
 
         sess.clear()
         assert sess.get(User, 'jack') is None
@@ -154,7 +149,7 @@ class NaturalPKTest(ORMTest):
         u1.username = 'ed'
 
         print id(a1), id(a2), id(u1)
-        print u1._state.parents
+        print attributes.instance_state(u1).parents
         def go():
             sess.flush()
         if passive_updates:
index ae0d6ef86dc506b23aa8dc80c53426ed80328cab..eb425c5779891f980079e41f15d2a65859257039 100644 (file)
@@ -1,7 +1,6 @@
 import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
 from sqlalchemy.orm import *
-from sqlalchemy.ext.sessioncontext import SessionContext
 from testlib import *
 
 class Jack(object):
@@ -29,7 +28,7 @@ class O2OTest(TestBase, AssertsExecutionResults):
     def setUpAll(self):
         global jack, port, metadata, ctx
         metadata = MetaData(testing.db)
-        ctx = SessionContext(create_session)
+        ctx = scoped_session(create_session)
         jack = Table('jack', metadata,
             Column('id', Integer, primary_key=True),
             #Column('room_id', Integer, ForeignKey("room.id")),
@@ -54,22 +53,21 @@ class O2OTest(TestBase, AssertsExecutionResults):
     def tearDownAll(self):
         metadata.drop_all()
 
-    @testing.uses_deprecated('SessionContext')
     def test1(self):
-        mapper(Port, port, extension=ctx.mapper_extension)
+        mapper(Port, port, extension=ctx.extension)
         mapper(Jack, jack, order_by=[jack.c.number],properties = {
             'port': relation(Port, backref='jack', uselist=False, lazy=True),
-        }, extension=ctx.mapper_extension)
+        }, extension=ctx.extension)
 
         j=Jack(number='101')
         p=Port(name='fa0/1')
         j.port=p
-        ctx.current.flush()
+        ctx.flush()
         jid = j.id
         pid = p.id
 
-        j=ctx.current.query(Jack).get(jid)
-        p=ctx.current.query(Port).get(pid)
+        j=ctx.query(Jack).get(jid)
+        p=ctx.query(Port).get(pid)
         print p.jack
         assert p.jack is not None
         assert p.jack is  j
@@ -77,17 +75,17 @@ class O2OTest(TestBase, AssertsExecutionResults):
         p.jack=None
         assert j.port is None #works
 
-        ctx.current.clear()
+        ctx.clear()
 
-        j=ctx.current.query(Jack).get(jid)
-        p=ctx.current.query(Port).get(pid)
+        j=ctx.query(Jack).get(jid)
+        p=ctx.query(Port).get(pid)
 
         j.port=None
         self.assert_(p.jack is None)
-        ctx.current.flush()
+        ctx.flush()
 
-        ctx.current.delete(j)
-        ctx.current.flush()
+        ctx.delete(j)
+        ctx.flush()
 
 if __name__ == "__main__":
     testenv.main()
index 84f5e5dafb2c5914d8dc02c2398f574fc945c197..6bb455d4166d813affc85623fe49cfed945b4f5b 100644 (file)
@@ -1,6 +1,5 @@
 import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
-from sqlalchemy import exceptions
 from sqlalchemy.orm import *
 from testlib import *
 from testlib.fixtures import *
@@ -113,7 +112,7 @@ class PolymorphicDeferredTest(ORMTest):
             )
 
     def test_polymorphic_deferred(self):
-        mapper(User, users, polymorphic_identity='user', polymorphic_on=users.c.type, polymorphic_fetch='deferred')
+        mapper(User, users, polymorphic_identity='user', polymorphic_on=users.c.type)
         mapper(EmailUser, email_users, inherits=User, polymorphic_identity='emailuser')
 
         eu = EmailUser(name="user1", email_address='foo@bar.com')
index f1afdb90b40b140660e68b27e8edf700bc3e73dc..bc67740f275610b7cc5fd56a6b4d36b2bd4eba04 100644 (file)
@@ -1,7 +1,7 @@
 import testenv; testenv.configure_for_tests()
 import operator
 from sqlalchemy import *
-from sqlalchemy import exceptions, util
+from sqlalchemy import exc as sa_exc, util
 from sqlalchemy.sql import compiler
 from sqlalchemy.engine import default
 from sqlalchemy.orm import *
@@ -10,12 +10,13 @@ from testlib import *
 from testlib import engines
 from testlib.fixtures import *
 
-from sqlalchemy.orm.util import _join as join, _outerjoin as outerjoin
+from sqlalchemy.orm.util import join, outerjoin, with_parent
 
 class QueryTest(FixtureTest):
     keep_mappers = True
     keep_data = True
 
+
     def setup_mappers(self):
         mapper(User, users, properties={
             'addresses':relation(Address, backref='user'),
@@ -68,11 +69,8 @@ class GetTest(QueryTest):
 
         s = create_session()
         
-        try:
-            s.query(User).join('addresses').filter(Address.user_id==8).get(7)
-            assert False
-        except exceptions.SAWarning, e:
-            assert str(e) == "Query.get() being called on a Query with existing criterion; criterion is being ignored."
+        q = s.query(User).join('addresses').filter(Address.user_id==8)
+        self.assertRaises(sa_exc.SAWarning, q.get, 7)
 
         @testing.emits_warning('Query.*')
         def warns():
@@ -119,7 +117,7 @@ class GetTest(QueryTest):
         try:
             assert s.query(User).load(19) is None
             assert False
-        except exceptions.InvalidRequestError:
+        except sa_exc.InvalidRequestError:
             assert True
 
         u = s.query(User).load(7)
@@ -193,6 +191,29 @@ class GetTest(QueryTest):
         assert u.addresses[0].email_address == 'jack@bean.com'
         assert u.orders[1].items[2].description == 'item 5'
 
+class InvalidGenerationsTest(QueryTest):
+    def test_no_limit_offset(self):
+        s = create_session()
+        
+        q = s.query(User).limit(2)
+        self.assertRaises(sa_exc.SAWarning, q.join, "addresses")
+
+        self.assertRaises(sa_exc.SAWarning, q.filter, User.name=='ed')
+
+        self.assertRaises(sa_exc.SAWarning, q.filter_by, name='ed')
+    
+    def test_no_from(self):
+        s = create_session()
+    
+        q = s.query(User).select_from(users)
+        self.assertRaises(sa_exc.InvalidRequestError, q.select_from, users)
+
+        q = s.query(User).join('addresses')
+        self.assertRaises(sa_exc.InvalidRequestError, q.select_from, users)
+        
+        # this is fine, however
+        q.from_self()
+        
 class OperatorTest(QueryTest):
     """test sql.Comparator implementation for MapperProperties"""
 
@@ -268,8 +289,40 @@ class OperatorTest(QueryTest):
             c = expr.compile(dialect=default.DefaultDialect())
             assert str(c) == compare, "%s != %s" % (str(c), compare)
 
+class RawSelectTest(QueryTest, AssertsCompiledSQL):
+    """compare a bunch of select() tests with the equivalent Query using straight table/columns.
+    
+    Results should be the same as Query should act as a select() pass-thru for ClauseElement entities.
+    
+    """
+    def test_select(self):
+        sess = create_session()
+
+        self.assert_compile(sess.query(users).select_from(users.select()).with_labels().statement, 
+            "SELECT users.id AS users_id, users.name AS users_name FROM users, (SELECT users.id AS id, users.name AS name FROM users) AS anon_1")
+
+        self.assert_compile(sess.query(users, exists([1], from_obj=addresses)).with_labels().statement, 
+            "SELECT users.id AS users_id, users.name AS users_name, EXISTS (SELECT 1 FROM addresses) AS anon_1 FROM users")
 
+        # a little tedious here, adding labels to work around Query's auto-labelling.
+        # also correlate needed explicitly.  hmmm.....
+        # TODO: can we detect only one table in the "froms" and then turn off use_labels ?
+        s = sess.query(addresses.c.id.label('id'), addresses.c.email_address.label('email')).\
+            filter(addresses.c.user_id==users.c.id).correlate(users).statement.alias()
+            
+        self.assert_compile(sess.query(users, s.c.email).select_from(users.join(s, s.c.id==users.c.id)).with_labels().statement, 
+                "SELECT users.id AS users_id, users.name AS users_name, anon_1.email AS anon_1_email "
+                "FROM users JOIN (SELECT addresses.id AS id, addresses.email_address AS email FROM addresses "
+                "WHERE addresses.user_id = users.id) AS anon_1 ON anon_1.id = users.id",
+                dialect=default.DefaultDialect()
+            )
+
+        x = func.lala(users.c.id).label('foo')
+        self.assert_compile(sess.query(x).filter(x==5).statement, 
+            "SELECT lala(users.id) AS foo FROM users WHERE lala(users.id) = :param_1", dialect=default.DefaultDialect())
+        
 class CompileTest(QueryTest):
+        
     def test_deferred(self):
         session = create_session()
         s = session.query(User).filter(and_(addresses.c.email_address == bindparam('emailad'), Address.user_id==User.id)).compile()
@@ -324,7 +377,7 @@ class FilterTest(QueryTest):
         try:
             sess.query(User).filter(User.addresses == address)
             assert False
-        except exceptions.InvalidRequestError:
+        except sa_exc.InvalidRequestError:
             assert True
 
         assert [User(id=10)] == sess.query(User).filter(User.addresses==None).all()
@@ -332,7 +385,7 @@ class FilterTest(QueryTest):
         try:
             assert [User(id=7), User(id=9), User(id=10)] == sess.query(User).filter(User.addresses!=address).all()
             assert False
-        except exceptions.InvalidRequestError:
+        except sa_exc.InvalidRequestError:
             assert True
 
         #assert [User(id=7), User(id=9), User(id=10)] == sess.query(User).filter(User.addresses!=address).all()
@@ -348,33 +401,15 @@ class FilterTest(QueryTest):
             filter(User.addresses.any(id=4)).all()
 
         assert [User(id=9)] == sess.query(User).filter(User.addresses.any(email_address='fred@fred.com')).all()
-
-    @testing.fails_on_everything_except()
-    def test_broken_any_1(self):
-        sess = create_session()
         
-        # overcorrelates
+        # test that any() doesn't overcorrelate
         assert [User(id=7), User(id=8)] == sess.query(User).join("addresses").filter(~User.addresses.any(Address.email_address=='fred@fred.com')).all()
-
-    def test_broken_any_2(self):
-        sess = create_session()
         
-        # works, filter is before the join
-        assert [User(id=7), User(id=8)] == sess.query(User).filter(~User.addresses.any(Address.email_address=='fred@fred.com')).join("addresses", aliased=True).all()
-        
-    def test_broken_any_3(self):
-        sess = create_session()
-        
-        # works, filter is after the join, but reset_joinpoint is called, removing aliasing
-        assert [User(id=7), User(id=8)] == sess.query(User).join("addresses", aliased=True).filter(Address.email_address != None).reset_joinpoint().filter(~User.addresses.any(email_address='fred@fred.com')).all()
+        # test that the contents are not adapted by the aliased join
+        assert [User(id=7), User(id=8)] == sess.query(User).join("addresses", aliased=True).filter(~User.addresses.any(Address.email_address=='fred@fred.com')).all()
 
-    @testing.fails_on_everything_except()
-    def test_broken_any_4(self):
-        sess = create_session()
-        
-        # filter is after the join, gets aliased.  in 0.5 any(), has() and not contains() are shielded from aliasing
         assert [User(id=10)] == sess.query(User).outerjoin("addresses", aliased=True).filter(~User.addresses.any()).all()
-
+        
     @testing.unsupported('maxdb') # can core
     def test_has(self):
         sess = create_session()
@@ -384,6 +419,12 @@ class FilterTest(QueryTest):
 
         assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).filter(Address.user.has(User.name.like('%ed%'), id=8)).all()
 
+        # test has() doesn't overcorrelate
+        assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).join("user").filter(Address.user.has(User.name.like('%ed%'), id=8)).all()
+
+        # test has() doesnt' get subquery contents adapted by aliased join
+        assert [Address(id=2), Address(id=3), Address(id=4)] == sess.query(Address).join("user", aliased=True).filter(Address.user.has(User.name.like('%ed%'), id=8)).all()
+        
         dingaling = sess.query(Dingaling).get(2)
         assert [User(id=9)] == sess.query(User).filter(User.addresses.any(Address.dingaling==dingaling)).all()
         
@@ -457,23 +498,39 @@ class FromSelfTest(QueryTest):
             (User(id=8), Address(id=4)),
             (User(id=9), Address(id=5))
         ] == create_session().query(User).filter(User.id.in_([8,9]))._from_self().join('addresses').add_entity(Address).order_by(User.id, Address.id).all()
+    
+    def test_multiple_entities(self):
+        sess = create_session()
+
+        self.assertEquals(
+            sess.query(User, Address).filter(User.id==Address.user_id).filter(Address.id.in_([2, 5]))._from_self().all(),
+            [
+                (User(id=8), Address(id=2)),
+                (User(id=9), Address(id=5))
+            ]
+        )
+
+        self.assertEquals(
+            sess.query(User, Address).filter(User.id==Address.user_id).filter(Address.id.in_([2, 5]))._from_self().options(eagerload('addresses')).first(),
+            (User(id=8, addresses=[Address(), Address(), Address()]), Address(id=2)),
+        )
         
 class AggregateTest(QueryTest):
+
     def test_sum(self):
         sess = create_session()
         orders = sess.query(Order).filter(Order.id.in_([2, 3, 4]))
         assert orders.sum(Order.user_id * Order.address_id) == 79
 
-    @testing.uses_deprecated('Call to deprecated function apply_sum')
     def test_apply(self):
         sess = create_session()
-        assert sess.query(Order).apply_sum(Order.user_id * Order.address_id).filter(Order.id.in_([2, 3, 4])).one() == 79
+        assert sess.query(func.sum(Order.user_id * Order.address_id)).filter(Order.id.in_([2, 3, 4])).one() == (79,)
 
     def test_having(self):
         sess = create_session()
-        assert [User(name=u'ed',id=8)] == sess.query(User).group_by([c for c in User.c]).join('addresses').having(func.count(Address.c.id)> 2).all()
+        assert [User(name=u'ed',id=8)] == sess.query(User).group_by(User).join('addresses').having(func.count(Address.id)> 2).all()
 
-        assert [User(name=u'jack',id=7), User(name=u'fred',id=9)] == sess.query(User).group_by([c for c in User.c]).join('addresses').having(func.count(Address.c.id)< 2).all()
+        assert [User(name=u'jack',id=7), User(name=u'fred',id=9)] == sess.query(User).group_by(User).join('addresses').having(func.count(Address.id)< 2).all()
 
 class CountTest(QueryTest):
     def test_basic(self):
@@ -561,10 +618,16 @@ class ParentTest(QueryTest):
         o = sess.query(Order).with_parent(u1, property='orders').all()
         assert [Order(description="order 1"), Order(description="order 3"), Order(description="order 5")] == o
 
-        # test static method
-        o = Query.query_from_parent(u1, property='orders', session=sess).all()
+        o = sess.query(Order).filter(with_parent(u1, User.orders)).all()
         assert [Order(description="order 1"), Order(description="order 3"), Order(description="order 5")] == o
-
+        
+        # test static method
+        @testing.uses_deprecated(".*query_from_parent")
+        def go():
+            o = Query.query_from_parent(u1, property='orders', session=sess).all()
+            assert [Order(description="order 1"), Order(description="order 3"), Order(description="order 5")] == o
+        go()
+        
         # test generative criterion
         o = sess.query(Order).with_parent(u1).filter(orders.c.id>2).all()
         assert [Order(description="order 3"), Order(description="order 5")] == o
@@ -582,7 +645,7 @@ class ParentTest(QueryTest):
         try:
             q = sess.query(Item).with_parent(u1)
             assert False
-        except exceptions.InvalidRequestError, e:
+        except sa_exc.InvalidRequestError, e:
             assert str(e) == "Could not locate a property which relates instances of class 'Item' to instances of class 'User'"
 
     def test_m2m(self):
@@ -594,28 +657,6 @@ class ParentTest(QueryTest):
 
 class JoinTest(QueryTest):
 
-    def test_getjoinable_tables(self):
-        sess = create_session()
-
-        sel1 = select([users]).alias()
-        sel2 = select([users], from_obj=users.join(addresses)).alias()
-
-        j1 = sel1.join(users, sel1.c.id==users.c.id)
-        j2 = j1.join(addresses)
-
-        for from_obj, assert_cond in (
-            (users, [users]),
-            (users.join(addresses), [users, addresses]),
-            (sel1, [sel1]),
-            (sel2, [sel2]),
-            (sel1.join(users, sel1.c.id==users.c.id), [sel1, users]),
-            (sel2.join(users, sel2.c.id==users.c.id), [sel2, users]),
-            (j2, [j1, j2, sel1, users, addresses])
-
-        ):
-            ret = set(sess.query(User).select_from(from_obj)._get_joinable_tables())
-            self.assertEquals(ret, set(assert_cond).union([from_obj]), [x.description for x in ret])
-
     def test_overlapping_paths(self):
         for aliased in (True,False):
             # load a user who has an order that contains item id 3 and address id 1 (order 3, owned by jack)
@@ -654,7 +695,34 @@ class JoinTest(QueryTest):
 
     def test_orderby_arg_bug(self):
         sess = create_session()
+        # no arg error
+        result = sess.query(User).join('orders', aliased=True).order_by([Order.id]).reset_joinpoint().order_by(users.c.id).all()
+    
+    def test_no_onclause(self):
+        sess = create_session()
+
+        self.assertEquals(
+            sess.query(User).select_from(join(User, Order).join(Item, Order.items)).filter(Item.description == 'item 4').all(),
+            [User(name='jack')]
+        )
+
+        self.assertEquals(
+            sess.query(User).join(Order, (Item, Order.items)).filter(Item.description == 'item 4').all(),
+            [User(name='jack')]
+        )
         
+    def test_clause_onclause(self):
+        sess = create_session()
+
+        self.assertEquals(
+            sess.query(User).join(
+                (Order, User.id==Order.user_id), 
+                (order_items, Order.id==order_items.c.order_id), 
+                (Item, order_items.c.item_id==Item.id)
+            ).filter(Item.description == 'item 4').all(),
+            [User(name='jack')]
+        )
+
         # no arg error
         result = sess.query(User).join('orders', aliased=True).order_by([Order.id]).reset_joinpoint().order_by(users.c.id).all()
         
@@ -682,13 +750,43 @@ class JoinTest(QueryTest):
         l = q.select_from(outerjoin(User, AdAlias)).filter(AdAlias.email_address=='ed@bettyboop.com').all()
         self.assertEquals(l, [(user8, address3)])
 
-
         l = q.select_from(outerjoin(User, AdAlias, 'addresses')).filter(AdAlias.email_address=='ed@bettyboop.com').all()
         self.assertEquals(l, [(user8, address3)])
 
         l = q.select_from(outerjoin(User, AdAlias, User.id==AdAlias.user_id)).filter(AdAlias.email_address=='ed@bettyboop.com').all()
         self.assertEquals(l, [(user8, address3)])
 
+        # this is the first test where we are joining "backwards" - from AdAlias to User even though
+        # the query is against User
+        q = sess.query(User, AdAlias)
+        l = q.join(AdAlias.user).filter(User.name=='ed')
+        self.assertEquals(l.all(), [(user8, address2),(user8, address3),(user8, address4),])
+
+        q = sess.query(User, AdAlias).select_from(join(AdAlias, User, AdAlias.user)).filter(User.name=='ed')
+        self.assertEquals(l.all(), [(user8, address2),(user8, address3),(user8, address4),])
+        
+    def test_implicit_joins_from_aliases(self):
+        sess = create_session()
+        OrderAlias = aliased(Order)
+
+        self.assertEquals(
+            sess.query(OrderAlias).join('items').filter_by(description='item 3').all(),
+            [
+                Order(address_id=1,description=u'order 1',isopen=0,user_id=7,id=1), 
+                Order(address_id=4,description=u'order 2',isopen=0,user_id=9,id=2), 
+                Order(address_id=1,description=u'order 3',isopen=1,user_id=7,id=3)
+            ]
+        )
+         
+        self.assertEquals(
+            sess.query(User, OrderAlias, Item.description).join(('orders', OrderAlias), 'items').filter_by(description='item 3').all(),
+            [
+                (User(name=u'jack',id=7), Order(address_id=1,description=u'order 1',isopen=0,user_id=7,id=1), u'item 3'), 
+                (User(name=u'jack',id=7), Order(address_id=1,description=u'order 3',isopen=1,user_id=7,id=3), u'item 3'), 
+                (User(name=u'fred',id=9), Order(address_id=4,description=u'order 2',isopen=0,user_id=9,id=2), u'item 3')
+            ]
+        )   
+        
     def test_aliased_classes_m2m(self):
         sess = create_session()
         
@@ -725,20 +823,6 @@ class JoinTest(QueryTest):
             ]
         )
         
-    def test_generative_join(self):
-        # test that alised_ids is copied
-        sess = create_session()
-        q = sess.query(User).add_entity(Address)
-        q1 = q.join('addresses', aliased=True)
-        q2 = q.join('addresses', aliased=True)
-        q3 = q2.join('addresses', aliased=True)
-        q4 = q2.join('addresses', aliased=True, id='someid')
-        q5 = q2.join('addresses', aliased=True, id='someid')
-        q6 = q5.join('addresses', aliased=True, id='someid')
-        assert q1._alias_ids[class_mapper(Address)] != q2._alias_ids[class_mapper(Address)]
-        assert q2._alias_ids[class_mapper(Address)] != q3._alias_ids[class_mapper(Address)]
-        assert q4._alias_ids['someid'] != q5._alias_ids['someid']
-        
     def test_reset_joinpoint(self):
         for aliased in (True, False):
             # load a user who has an order that contains item id 3 and address id 1 (order 3, owned by jack)
@@ -779,43 +863,19 @@ class JoinTest(QueryTest):
         assert q.count() == 1
         assert [User(id=7)] == q.all()
 
+
         # test the control version - same joins but not aliased.  rows are not returned because order 3 does not have item 1
-        # addtionally by placing this test after the previous one, test that the "aliasing" step does not corrupt the
-        # join clauses that are cached by the relationship.
-        q = sess.query(User).join('orders').filter(Order.description=="order 3").join(['orders', 'items']).filter(Order.description=="item 1")
+        q = sess.query(User).join('orders').filter(Order.description=="order 3").join(['orders', 'items']).filter(Item.description=="item 1")
         assert [] == q.all()
         assert q.count() == 0
 
         q = sess.query(User).join('orders', aliased=True).filter(Order.items.any(Item.description=='item 4'))
         assert [User(id=7)] == q.all()
-
-    def test_aliased_add_entity(self):
-        """test the usage of aliased joins with add_entity()"""
-        sess = create_session()
-        q = sess.query(User).join('orders', aliased=True, id='order1').filter(Order.description=="order 3").join(['orders', 'items'], aliased=True, id='item1').filter(Item.description=="item 1")
-
-        try:
-            q.add_entity(Order, id='fakeid').compile()
-            assert False
-        except exceptions.InvalidRequestError, e:
-            assert str(e) == "Query has no alias identified by 'fakeid'"
-
-        try:
-            q.add_entity(Order, id='fakeid').instances(None)
-            assert False
-        except exceptions.InvalidRequestError, e:
-            assert str(e) == "Query has no alias identified by 'fakeid'"
-
-        q = q.add_entity(Order, id='order1').add_entity(Item, id='item1')
+        
+        # test that aliasing gets reset when join() is called
+        q = sess.query(User).join('orders', aliased=True).filter(Order.description=="order 3").join('orders', aliased=True).filter(Order.description=="order 5")
         assert q.count() == 1
-        assert [(User(id=7), Order(description='order 3'), Item(description='item 1'))] == q.all()
-
-        q = sess.query(User).add_entity(Order).join('orders', aliased=True).filter(Order.description=="order 3").join('orders', aliased=True).filter(Order.description=='order 4')
-        try:
-            q.compile()
-            assert False
-        except exceptions.InvalidRequestError, e:
-            assert str(e) == "Ambiguous join for entity 'Mapper|Order|orders'; specify id=<someid> to query.join()/query.add_entity()"
+        assert [User(id=7)] == q.all()
 
 class MultiplePathTest(ORMTest):
     def define_tables(self, metadata):
@@ -849,11 +909,10 @@ class MultiplePathTest(ORMTest):
         })
         mapper(T2, t2)
 
-        try:
-            create_session().query(T1).join('t2s_1').filter(t2.c.id==5).reset_joinpoint().join('t2s_2')
-            assert False
-        except exceptions.InvalidRequestError, e:
-            assert str(e) == "Can't join to property 't2s_2'; a path to this table along a different secondary table already exists.  Use the `alias=True` argument to `join()`."
+        q = create_session().query(T1).join('t2s_1').filter(t2.c.id==5).reset_joinpoint()
+        self.assertRaisesMessage(sa_exc.InvalidRequestError, "a path to this table along a different secondary table already exists.",
+            q.join, 't2s_2'
+        )
 
         create_session().query(T1).join('t2s_1', aliased=True).filter(t2.c.id==5).reset_joinpoint().join('t2s_2').all()
         create_session().query(T1).join('t2s_1').filter(t2.c.id==5).reset_joinpoint().join('t2s_2', aliased=True).all()
@@ -926,26 +985,34 @@ class InstancesTest(QueryTest, AssertsCompiledSQL):
             assert fixtures.user_address_result == l
         self.assert_sql_count(testing.db, go, 1)
 
+        # better way.  use select_from()
+        def go():
+            l = sess.query(User).select_from(query).options(contains_eager('addresses')).all()
+            assert fixtures.user_address_result == l
+        self.assert_sql_count(testing.db, go, 1)
+
     def test_contains_eager(self):
         sess = create_session()
 
+        # test that contains_eager suppresses the normal outer join rendering
         q = sess.query(User).outerjoin(User.addresses).options(contains_eager(User.addresses))
-        self.assert_compile(q.statement, "SELECT addresses.id AS addresses_id, addresses.user_id AS addresses_user_id, "
-        "addresses.email_address AS addresses_email_address, users.id AS users_id, users.name AS users_name "\
-        "FROM users LEFT OUTER JOIN addresses ON users.id = addresses.user_id ORDER BY users.id", dialect=default.DefaultDialect())
-        
+        self.assert_compile(q.with_labels().statement, "SELECT users.id AS users_id, users.name AS users_name, "\
+                "addresses.id AS addresses_id, addresses.user_id AS addresses_user_id, "\
+                "addresses.email_address AS addresses_email_address FROM users LEFT OUTER JOIN addresses "\
+                "ON users.id = addresses.user_id ORDER BY users.id", dialect=default.DefaultDialect())
+
         def go():
             assert fixtures.user_address_result == q.all()
         self.assert_sql_count(testing.db, go, 1)
         sess.clear()
-        
+
         adalias = addresses.alias()
         q = sess.query(User).select_from(users.outerjoin(adalias)).options(contains_eager(User.addresses, alias=adalias))
         def go():
             assert fixtures.user_address_result == q.all()
         self.assert_sql_count(testing.db, go, 1)
         sess.clear()
-        
+
         selectquery = users.outerjoin(addresses).select(users.c.id<10, use_labels=True, order_by=[users.c.id, addresses.c.id])
         q = sess.query(User)
 
@@ -956,6 +1023,13 @@ class InstancesTest(QueryTest, AssertsCompiledSQL):
 
         sess.clear()
 
+
+        def go():
+            l = q.options(contains_eager(User.addresses)).instances(selectquery.execute())
+            assert fixtures.user_address_result[0:3] == l
+        self.assert_sql_count(testing.db, go, 1)
+        sess.clear()
+
         def go():
             l = q.options(contains_eager('addresses')).from_statement(selectquery).all()
             assert fixtures.user_address_result[0:3] == l
@@ -966,38 +1040,34 @@ class InstancesTest(QueryTest, AssertsCompiledSQL):
         selectquery = users.outerjoin(adalias).select(use_labels=True, order_by=[users.c.id, adalias.c.id])
         sess = create_session()
         q = sess.query(User)
-
+        
+        # string alias name
         def go():
-            # test using a string alias name
             l = q.options(contains_eager('addresses', alias="adalias")).instances(selectquery.execute())
             assert fixtures.user_address_result == l
         self.assert_sql_count(testing.db, go, 1)
         sess.clear()
 
+        # expression.Alias object
         def go():
-            # test using the Alias object itself
             l = q.options(contains_eager('addresses', alias=adalias)).instances(selectquery.execute())
             assert fixtures.user_address_result == l
         self.assert_sql_count(testing.db, go, 1)
 
         sess.clear()
 
-        def decorate(row):
-            d = {}
-            for c in addresses.c:
-                d[c] = row[adalias.corresponding_column(c)]
-            return d
-
+        # Aliased object
+        adalias = aliased(Address)
         def go():
-            # test using a custom 'decorate' function
-            l = q.options(contains_eager('addresses', decorator=decorate)).instances(selectquery.execute())
-            assert fixtures.user_address_result == l
+            l = q.options(contains_eager('addresses', alias=adalias)).outerjoin((adalias, User.addresses)).order_by(User.id, adalias.id)
+            assert fixtures.user_address_result == l.all()
         self.assert_sql_count(testing.db, go, 1)
         sess.clear()
+        
 
         oalias = orders.alias('o1')
         ialias = items.alias('i1')
-        query = users.outerjoin(oalias).outerjoin(order_items).outerjoin(ialias).select(use_labels=True).order_by(users.c.id).order_by(oalias.c.id).order_by(ialias.c.id)
+        query = users.outerjoin(oalias).outerjoin(order_items).outerjoin(ialias).select(use_labels=True).order_by(users.c.id, oalias.c.id, ialias.c.id)
         q = create_session().query(User)
         # test using string alias with more than one level deep
         def go():
@@ -1014,9 +1084,24 @@ class InstancesTest(QueryTest, AssertsCompiledSQL):
         self.assert_sql_count(testing.db, go, 1)
         sess.clear()
 
+        # test using Aliased with more than one level deep
+        oalias = aliased(Order)
+        ialias = aliased(Item)
+        def go():
+            l = q.options(contains_eager(User.orders, alias=oalias), contains_eager(User.orders, Order.items, alias=ialias)).\
+                outerjoin((oalias, User.orders), (ialias, Order.items)).order_by(User.id, oalias.id, ialias.id)
+            assert fixtures.user_order_result == l.all()
+        self.assert_sql_count(testing.db, go, 1)
+        sess.clear()
+
+
+class MixedEntitiesTest(QueryTest):
+
     def test_values(self):
         sess = create_session()
 
+        assert list(sess.query(User).values()) == list()
+
         sel = users.select(User.id.in_([7, 8])).alias()
         q = sess.query(User)
         q2 = q.select_from(sel).values(User.name)
@@ -1035,19 +1120,166 @@ class InstancesTest(QueryTest, AssertsCompiledSQL):
         q2 = q.join('addresses').filter(User.name.like('%e%')).order_by(desc(Address.email_address))[1:3].values(User.name, Address.email_address)
         self.assertEquals(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@lala.com')])
         
-        q2 = q.join('addresses', aliased=True).filter(User.name.like('%e%')).values(User.name, Address.email_address)
+        adalias = aliased(Address)
+        q2 = q.join(('addresses', adalias)).filter(User.name.like('%e%')).values(User.name, adalias.email_address)
         self.assertEquals(list(q2), [(u'ed', u'ed@wood.com'), (u'ed', u'ed@bettyboop.com'), (u'ed', u'ed@lala.com'), (u'fred', u'fred@fred.com')])
         
         q2 = q.values(func.count(User.name))
         assert q2.next() == (4,)
 
-        u2 = users.alias()
-        q2 = q.select_from(sel).filter(u2.c.id>1).order_by([users.c.id, sel.c.id, u2.c.id]).values(users.c.name, sel.c.name, u2.c.name)
+        u2 = aliased(User)
+        q2 = q.select_from(sel).filter(u2.id>1).order_by([User.id, sel.c.id, u2.id]).values(User.name, sel.c.name, u2.name)
         self.assertEquals(list(q2), [(u'jack', u'jack', u'jack'), (u'jack', u'jack', u'ed'), (u'jack', u'jack', u'fred'), (u'jack', u'jack', u'chuck'), (u'ed', u'ed', u'jack'), (u'ed', u'ed', u'ed'), (u'ed', u'ed', u'fred'), (u'ed', u'ed', u'chuck')])
         
-        q2 = q.select_from(sel).filter(users.c.id>1).values(users.c.name, sel.c.name, User.name)
-        self.assertEquals(list(q2), [(u'jack', u'jack', u'jack'), (u'ed', u'ed', u'ed')])
+        q2 = q.select_from(sel).filter(User.id==8).values(User.name, sel.c.name, User.name)
+        self.assertEquals(list(q2), [(u'ed', u'ed', u'ed')])
+
+        # using User.xxx is alised against "sel", so this query returns nothing
+        q2 = q.select_from(sel).filter(User.id==8).filter(User.id>sel.c.id).values(User.name, sel.c.name, User.name)
+        self.assertEquals(list(q2), [])
+
+        # whereas this uses users.c.xxx, is not aliased and creates a new join
+        q2 = q.select_from(sel).filter(users.c.id==8).filter(users.c.id>sel.c.id).values(users.c.name, sel.c.name, User.name)
+        self.assertEquals(list(q2), [(u'ed', u'jack', u'jack')])
     
+    def test_tuple_labeling(self):
+        sess = create_session()
+        for row in sess.query(User, Address).join(User.addresses).all():
+            self.assertEquals(set(row.keys()), set(['User', 'Address']))
+            self.assertEquals(row.User, row[0])
+            self.assertEquals(row.Address, row[1])
+            
+        for row in sess.query(User.name, User.id.label('foobar')):
+            self.assertEquals(set(row.keys()), set(['name', 'foobar']))
+            self.assertEquals(row.name, row[0])
+            self.assertEquals(row.foobar, row[1])
+
+        for row in sess.query(User).values(User.name, User.id.label('foobar')):
+            self.assertEquals(set(row.keys()), set(['name', 'foobar']))
+            self.assertEquals(row.name, row[0])
+            self.assertEquals(row.foobar, row[1])
+
+        oalias = aliased(Order)
+        for row in sess.query(User, oalias).join(User.orders).all():
+            self.assertEquals(set(row.keys()), set(['User']))
+            self.assertEquals(row.User, row[0])
+
+        oalias = aliased(Order, name='orders')
+        for row in sess.query(User, oalias).join(User.orders).all():
+            self.assertEquals(set(row.keys()), set(['User', 'orders']))
+            self.assertEquals(row.User, row[0])
+            self.assertEquals(row.orders, row[1])
+
+
+    def test_column_queries(self):
+        sess = create_session()
+
+        self.assertEquals(sess.query(User.name).all(), [(u'jack',), (u'ed',), (u'fred',), (u'chuck',)])
+        
+        sel = users.select(User.id.in_([7, 8])).alias()
+        q = sess.query(User.name)
+        q2 = q.select_from(sel).all()
+        self.assertEquals(list(q2), [(u'jack',), (u'ed',)])
+
+        self.assertEquals(sess.query(User.name, Address.email_address).filter(User.id==Address.user_id).all(), [
+            (u'jack', u'jack@bean.com'), (u'ed', u'ed@wood.com'), 
+            (u'ed', u'ed@bettyboop.com'), (u'ed', u'ed@lala.com'), 
+            (u'fred', u'fred@fred.com')
+        ])
+        
+        self.assertEquals(sess.query(User.name, func.count(Address.email_address)).outerjoin(User.addresses).group_by(User.id, User.name).order_by(User.id).all(), 
+            [(u'jack', 1), (u'ed', 3), (u'fred', 1), (u'chuck', 0)]
+        )
+
+        self.assertEquals(sess.query(User, func.count(Address.email_address)).outerjoin(User.addresses).group_by(User).order_by(User.id).all(), 
+            [(User(name='jack',id=7), 1), (User(name='ed',id=8), 3), (User(name='fred',id=9), 1), (User(name='chuck',id=10), 0)]
+        )
+
+        self.assertEquals(sess.query(func.count(Address.email_address), User).outerjoin(User.addresses).group_by(User).order_by(User.id).all(), 
+            [(1, User(name='jack',id=7)), (3, User(name='ed',id=8)), (1, User(name='fred',id=9)), (0, User(name='chuck',id=10))]
+        )
+        
+        adalias = aliased(Address)
+        self.assertEquals(sess.query(User, func.count(adalias.email_address)).outerjoin(('addresses', adalias)).group_by(User).order_by(User.id).all(), 
+            [(User(name='jack',id=7), 1), (User(name='ed',id=8), 3), (User(name='fred',id=9), 1), (User(name='chuck',id=10), 0)]
+        )
+
+        self.assertEquals(sess.query(func.count(adalias.email_address), User).outerjoin((User.addresses, adalias)).group_by(User).order_by(User.id).all(),
+            [(1, User(name=u'jack',id=7)), (3, User(name=u'ed',id=8)), (1, User(name=u'fred',id=9)), (0, User(name=u'chuck',id=10))]
+        )
+
+        # select from aliasing + explicit aliasing
+        self.assertEquals(
+            sess.query(User, adalias.email_address, adalias.id).outerjoin((User.addresses, adalias)).from_self(User, adalias.email_address).order_by(User.id, adalias.id).all(),
+            [
+                (User(name=u'jack',id=7), u'jack@bean.com'), 
+                (User(name=u'ed',id=8), u'ed@wood.com'), 
+                (User(name=u'ed',id=8), u'ed@bettyboop.com'),
+                (User(name=u'ed',id=8), u'ed@lala.com'), 
+                (User(name=u'fred',id=9), u'fred@fred.com'), 
+                (User(name=u'chuck',id=10), None)
+            ]
+        )
+        
+        # anon + select from aliasing
+        self.assertEquals(
+            sess.query(User).join(User.addresses, aliased=True).filter(Address.email_address.like('%ed%')).from_self().all(),
+            [
+                User(name=u'ed',id=8), 
+                User(name=u'fred',id=9), 
+            ]
+        )
+
+        # test eager aliasing, with/without select_from aliasing
+        for q in [
+            sess.query(User, adalias.email_address).outerjoin((User.addresses, adalias)).options(eagerload(User.addresses)).order_by(User.id, adalias.id).limit(10),
+            sess.query(User, adalias.email_address, adalias.id).outerjoin((User.addresses, adalias)).from_self(User, adalias.email_address).options(eagerload(User.addresses)).order_by(User.id, adalias.id).limit(10),
+        ]:
+            self.assertEquals(
+                q.all(),
+                [(User(addresses=[Address(user_id=7,email_address=u'jack@bean.com',id=1)],name=u'jack',id=7), u'jack@bean.com'), 
+                (User(addresses=[
+                                    Address(user_id=8,email_address=u'ed@wood.com',id=2), 
+                                    Address(user_id=8,email_address=u'ed@bettyboop.com',id=3), 
+                                    Address(user_id=8,email_address=u'ed@lala.com',id=4)],name=u'ed',id=8), u'ed@wood.com'), 
+                (User(addresses=[
+                            Address(user_id=8,email_address=u'ed@wood.com',id=2), 
+                            Address(user_id=8,email_address=u'ed@bettyboop.com',id=3), 
+                            Address(user_id=8,email_address=u'ed@lala.com',id=4)],name=u'ed',id=8), u'ed@bettyboop.com'), 
+                (User(addresses=[
+                            Address(user_id=8,email_address=u'ed@wood.com',id=2), 
+                            Address(user_id=8,email_address=u'ed@bettyboop.com',id=3), 
+                            Address(user_id=8,email_address=u'ed@lala.com',id=4)],name=u'ed',id=8), u'ed@lala.com'), 
+                (User(addresses=[Address(user_id=9,email_address=u'fred@fred.com',id=5)],name=u'fred',id=9), u'fred@fred.com'), 
+
+                (User(addresses=[],name=u'chuck',id=10), None)]
+        )
+            
+    def test_self_referential(self):
+        
+        sess = create_session()
+        oalias = aliased(Order)
+        
+        for q in [
+            sess.query(Order, oalias).filter(Order.user_id==oalias.user_id).filter(Order.user_id==7).filter(Order.id>oalias.id),
+            sess.query(Order, oalias)._from_self().filter(Order.user_id==oalias.user_id).filter(Order.user_id==7).filter(Order.id>oalias.id),
+            # here we go....two layers of aliasing
+            sess.query(Order, oalias).filter(Order.user_id==oalias.user_id).filter(Order.user_id==7).filter(Order.id>oalias.id)._from_self().order_by(Order.id, oalias.id).limit(10).options(eagerload(Order.items)),
+
+            # gratuitous four layers
+            sess.query(Order, oalias).filter(Order.user_id==oalias.user_id).filter(Order.user_id==7).filter(Order.id>oalias.id)._from_self()._from_self()._from_self().order_by(Order.id, oalias.id).limit(10).options(eagerload(Order.items)),
+
+        ]:
+        
+            self.assertEquals(
+            q.all(),
+            [
+                (Order(address_id=1,description=u'order 3',isopen=1,user_id=7,id=3), Order(address_id=1,description=u'order 1',isopen=0,user_id=7,id=1)), 
+                (Order(address_id=None,description=u'order 5',isopen=0,user_id=7,id=5), Order(address_id=1,description=u'order 1',isopen=0,user_id=7,id=1)), 
+                (Order(address_id=None,description=u'order 5',isopen=0,user_id=7,id=5), Order(address_id=1,description=u'order 3',isopen=1,user_id=7,id=3))                
+            ]
+        )
+        
     def test_multi_mappers(self):
 
         test_session = create_session()
@@ -1055,7 +1287,6 @@ class InstancesTest(QueryTest, AssertsCompiledSQL):
         (user7, user8, user9, user10) = test_session.query(User).all()
         (address1, address2, address3, address4, address5) = test_session.query(Address).all()
 
-        # note the result is a cartesian product
         expected = [(user7, address1),
             (user8, address2),
             (user8, address3),
@@ -1066,30 +1297,24 @@ class InstancesTest(QueryTest, AssertsCompiledSQL):
         sess = create_session()
 
         selectquery = users.outerjoin(addresses).select(use_labels=True, order_by=[users.c.id, addresses.c.id])
-        q = sess.query(User)
-        l = q.instances(selectquery.execute(), Address)
-        assert l == expected
-
+        self.assertEquals(sess.query(User, Address).instances(selectquery.execute()), expected)
         sess.clear()
 
-        for aliased in (False, True):
-            q = sess.query(User)
-
-            q = q.add_entity(Address).outerjoin('addresses', aliased=aliased)
-            l = q.all()
-            assert l == expected
+        for address_entity in (Address, aliased(Address)):
+            q = sess.query(User).add_entity(address_entity).outerjoin(('addresses', address_entity)).order_by(User.id, address_entity.id)
+            self.assertEquals(q.all(), expected)
             sess.clear()
 
-            q = sess.query(User).add_entity(Address)
-            l = q.join('addresses', aliased=aliased).filter_by(email_address='ed@bettyboop.com').all()
-            assert l == [(user8, address3)]
+            q = sess.query(User).add_entity(address_entity)
+            q = q.join(('addresses', address_entity)).filter_by(email_address='ed@bettyboop.com')
+            self.assertEquals(q.all(), [(user8, address3)])
             sess.clear()
 
-            q = sess.query(User, Address).join('addresses', aliased=aliased).filter_by(email_address='ed@bettyboop.com')
-            assert q.all() == [(user8, address3)]
+            q = sess.query(User, address_entity).join(('addresses', address_entity)).filter_by(email_address='ed@bettyboop.com')
+            self.assertEquals(q.all(), [(user8, address3)])
             sess.clear()
 
-            q = sess.query(User, Address).join('addresses', aliased=aliased).options(eagerload('addresses')).filter_by(email_address='ed@bettyboop.com')
+            q = sess.query(User, address_entity).join(('addresses', address_entity)).options(eagerload('addresses')).filter_by(email_address='ed@bettyboop.com')
             self.assertEquals(list(util.OrderedSet(q.all())), [(user8, address3)])
             sess.clear()
 
@@ -1123,18 +1348,12 @@ class InstancesTest(QueryTest, AssertsCompiledSQL):
 
         expected = [(u, u.name) for u in sess.query(User).all()]
 
-        for add_col in (User.name, users.c.name, User.c.name):
+        for add_col in (User.name, users.c.name):
             assert sess.query(User).add_column(add_col).all() == expected
             sess.clear()
 
-        self.assertRaises(exceptions.InvalidRequestError, sess.query(User).add_column, object())
+        self.assertRaises(sa_exc.InvalidRequestError, sess.query(User).add_column, object())
     
-    def test_ambiguous_column(self):
-        sess = create_session()
-        
-        q = sess.query(User).join('addresses', aliased=True).join('addresses', aliased=True).add_column(Address.id)
-        self.assertRaises(exceptions.InvalidRequestError, iter, q)
-        
     def test_multi_columns_2(self):
         """test aliased/nonalised joins with the usage of add_column()"""
         sess = create_session()
@@ -1146,12 +1365,16 @@ class InstancesTest(QueryTest, AssertsCompiledSQL):
             (user10, 0)
             ]
 
-        for aliased in (False, True):
-            q = sess.query(User)
-            q = q.group_by([c for c in users.c]).order_by(User.id).outerjoin('addresses', aliased=aliased).add_column(func.count(Address.id).label('count'))
-            l = q.all()
-            assert l == expected
-            sess.clear()
+        q = sess.query(User)
+        q = q.group_by([c for c in users.c]).order_by(User.id).outerjoin('addresses').add_column(func.count(Address.id).label('count'))
+        self.assertEquals(q.all(), expected)
+        sess.clear()
+        
+        adalias = aliased(Address)
+        q = sess.query(User)
+        q = q.group_by([c for c in users.c]).order_by(User.id).outerjoin(('addresses', adalias)).add_column(func.count(adalias.id).label('count'))
+        self.assertEquals(q.all(), expected)
+        sess.clear()
 
         s = select([users, func.count(addresses.c.id).label('count')]).select_from(users.outerjoin(addresses)).group_by(*[c for c in users.c]).order_by(User.id)
         q = sess.query(User)
@@ -1159,7 +1382,7 @@ class InstancesTest(QueryTest, AssertsCompiledSQL):
         assert l == expected
 
 
-    def test_two_columns(self):
+    def test_raw_columns(self):
         sess = create_session()
         (user7, user8, user9, user10) = sess.query(User).all()
         expected = [
@@ -1168,8 +1391,9 @@ class InstancesTest(QueryTest, AssertsCompiledSQL):
             (user9, 1, "Name:fred"),
             (user10, 0, "Name:chuck")]
 
-        q = create_session().query(User).add_column(func.count(addresses.c.id))\
-            .add_column(("Name:" + users.c.name)).outerjoin('addresses', aliased=True)\
+        adalias = addresses.alias()
+        q = create_session().query(User).add_column(func.count(adalias.c.id))\
+            .add_column(("Name:" + users.c.name)).outerjoin(('addresses', adalias))\
             .group_by([c for c in users.c]).order_by(users.c.id)
 
         assert q.all() == expected
@@ -1190,14 +1414,19 @@ class InstancesTest(QueryTest, AssertsCompiledSQL):
         assert q.all() == expected
         sess.clear()
 
-        # test with outerjoin() both aliased and non
-        for aliased in (False, True):
-            q = create_session().query(User).add_column(func.count(addresses.c.id))\
-                .add_column(("Name:" + users.c.name)).outerjoin('addresses', aliased=aliased)\
-                .group_by([c for c in users.c]).order_by(users.c.id)
+        q = create_session().query(User).add_column(func.count(addresses.c.id))\
+            .add_column(("Name:" + users.c.name)).outerjoin('addresses')\
+            .group_by([c for c in users.c]).order_by(users.c.id)
 
-            assert q.all() == expected
-            sess.clear()
+        assert q.all() == expected
+        sess.clear()
+
+        q = create_session().query(User).add_column(func.count(adalias.c.id))\
+            .add_column(("Name:" + users.c.name)).outerjoin(('addresses', adalias))\
+            .group_by([c for c in users.c]).order_by(users.c.id)
+
+        assert q.all() == expected
+        sess.clear()
 
 
 class SelectFromTest(QueryTest):
@@ -1217,7 +1446,7 @@ class SelectFromTest(QueryTest):
 
         self.assertEquals(sess.query(User).select_from(sel).all(), [User(id=7), User(id=8)])
 
-        self.assertEquals(sess.query(User).select_from(sel).filter(User.c.id==8).all(), [User(id=8)])
+        self.assertEquals(sess.query(User).select_from(sel).filter(User.id==8).all(), [User(id=8)])
 
         self.assertEquals(sess.query(User).select_from(sel).order_by(desc(User.name)).all(), [
             User(name='jack',id=7), User(name='ed',id=8)
@@ -1273,7 +1502,8 @@ class SelectFromTest(QueryTest):
             ]
         )
 
-        self.assertEquals(sess.query(User).select_from(sel).join('addresses', aliased=True).add_entity(Address).order_by(User.id).order_by(Address.id).all(),
+        adalias = aliased(Address)
+        self.assertEquals(sess.query(User).select_from(sel).join(('addresses', adalias)).add_entity(adalias).order_by(User.id).order_by(adalias.id).all(),
             [
                 (User(name='jack',id=7), Address(user_id=7,email_address='jack@bean.com',id=1)),
                 (User(name='ed',id=8), Address(user_id=8,email_address='ed@wood.com',id=2)),
@@ -1297,12 +1527,15 @@ class SelectFromTest(QueryTest):
 
         sel = users.select(users.c.id.in_([7, 8]))
         sess = create_session()
+        
+        # TODO: remove
+        sess.query(User).select_from(sel).options(eagerload_all('orders.items.keywords')).join('orders', 'items', 'keywords', aliased=True).filter(Keyword.name.in_(['red', 'big', 'round'])).all()
 
-        self.assertEquals(sess.query(User).select_from(sel).join(['orders', 'items', 'keywords']).filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [
+        self.assertEquals(sess.query(User).select_from(sel).join('orders', 'items', 'keywords').filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [
             User(name=u'jack',id=7)
         ])
 
-        self.assertEquals(sess.query(User).select_from(sel).join(['orders', 'items', 'keywords'], aliased=True).filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [
+        self.assertEquals(sess.query(User).select_from(sel).join('orders', 'items', 'keywords', aliased=True).filter(Keyword.name.in_(['red', 'big', 'round'])).all(), [
             User(name=u'jack',id=7)
         ])
 
@@ -1355,7 +1588,7 @@ class SelectFromTest(QueryTest):
         sess.clear()
 
         def go():
-            self.assertEquals(sess.query(User).options(eagerload('addresses')).select_from(sel).filter(User.c.id==8).all(),
+            self.assertEquals(sess.query(User).options(eagerload('addresses')).select_from(sel).filter(User.id==8).all(),
                 [User(id=8, addresses=[Address(id=2), Address(id=3), Address(id=4)])]
             )
         self.assert_sql_count(testing.db, go, 1)
@@ -1364,7 +1597,7 @@ class SelectFromTest(QueryTest):
         def go():
             self.assertEquals(sess.query(User).options(eagerload('addresses')).select_from(sel)[1], User(id=8, addresses=[Address(id=2), Address(id=3), Address(id=4)]))
         self.assert_sql_count(testing.db, go, 1)
-
+    
 class CustomJoinTest(QueryTest):
     keep_mappers = False
 
@@ -1428,6 +1661,10 @@ class SelfReferentialTest(ORMTest):
         node = sess.query(Node).join('children', aliased=True).filter_by(data='n122').first()
         assert node.data=='n12'
 
+        ret = sess.query(Node.data).join(Node.children, aliased=True).filter_by(data='n122').all()
+        assert ret == [('n12',)]
+
+        
         node = sess.query(Node).join(['children', 'children'], aliased=True).filter_by(data='n122').first()
         assert node.data=='n1'
 
@@ -1461,10 +1698,66 @@ class SelfReferentialTest(ORMTest):
             list(sess.query(Node).select_from(join(Node, n1, 'parent').join(n2, 'parent')).\
             filter(and_(Node.data=='n122', n1.data=='n12', n2.data=='n1')).values(Node.data, n1.data, n2.data)),
             [('n122', 'n12', 'n1')])
+    
+    def test_join_to_nonaliased(self):
+        sess = create_session()
         
-    def test_any(self):
+        n1 = aliased(Node)
+
+        # using 'n1.parent' implicitly joins to unaliased Node
+        self.assertEquals(
+            sess.query(n1).join(n1.parent).filter(Node.data=='n1').all(),
+            [Node(parent_id=1,data=u'n11',id=2), Node(parent_id=1,data=u'n12',id=3), Node(parent_id=1,data=u'n13',id=4)]
+        )
+        
+        # explicit (new syntax)
+        self.assertEquals(
+            sess.query(n1).join((Node, n1.parent)).filter(Node.data=='n1').all(),
+            [Node(parent_id=1,data=u'n11',id=2), Node(parent_id=1,data=u'n12',id=3), Node(parent_id=1,data=u'n13',id=4)]
+        )
+        
+    def test_multiple_explicit_entities(self):
         sess = create_session()
         
+        parent = aliased(Node)
+        grandparent = aliased(Node)
+        self.assertEquals(
+            sess.query(Node, parent, grandparent).\
+                join((Node.parent, parent), (parent.parent, grandparent)).\
+                    filter(Node.data=='n122').filter(parent.data=='n12').\
+                    filter(grandparent.data=='n1').first(),
+            (Node(data='n122'), Node(data='n12'), Node(data='n1'))
+        )
+
+        self.assertEquals(
+            sess.query(Node, parent, grandparent).\
+                join((Node.parent, parent), (parent.parent, grandparent)).\
+                    filter(Node.data=='n122').filter(parent.data=='n12').\
+                    filter(grandparent.data=='n1')._from_self().first(),
+            (Node(data='n122'), Node(data='n12'), Node(data='n1'))
+        )
+
+        self.assertEquals(
+            sess.query(Node, parent, grandparent).\
+                join((Node.parent, parent), (parent.parent, grandparent)).\
+                    filter(Node.data=='n122').filter(parent.data=='n12').\
+                    filter(grandparent.data=='n1').\
+                    options(eagerload(Node.children)).first(),
+            (Node(data='n122'), Node(data='n12'), Node(data='n1'))
+        )
+
+        self.assertEquals(
+            sess.query(Node, parent, grandparent).\
+                join((Node.parent, parent), (parent.parent, grandparent)).\
+                    filter(Node.data=='n122').filter(parent.data=='n12').\
+                    filter(grandparent.data=='n1')._from_self().\
+                    options(eagerload(Node.children)).first(),
+            (Node(data='n122'), Node(data='n12'), Node(data='n1'))
+        )
+        
+        
+    def test_any(self):
+        sess = create_session()
         self.assertEquals(sess.query(Node).filter(Node.children.any(Node.data=='n1')).all(), [])
         self.assertEquals(sess.query(Node).filter(Node.children.any(Node.data=='n12')).all(), [Node(data='n1')])
         self.assertEquals(sess.query(Node).filter(~Node.children.any()).all(), [Node(data='n11'), Node(data='n13'),Node(data='n121'),Node(data='n122'),Node(data='n123'),])
@@ -1561,6 +1854,8 @@ class SelfReferentialM2MTest(ORMTest):
         )
         
 class ExternalColumnsTest(QueryTest):
+    """test mappers with SQL-expressions added as column properties."""
+    
     keep_mappers = False
 
     def setup_mappers(self):
@@ -1568,15 +1863,11 @@ class ExternalColumnsTest(QueryTest):
 
     def test_external_columns_bad(self):
 
-        self.assertRaisesMessage(exceptions.ArgumentError, "not represented in mapper's table", mapper, User, users, properties={
+        self.assertRaisesMessage(sa_exc.ArgumentError, "not represented in mapper's table", mapper, User, users, properties={
             'concat': (users.c.id * 2),
         })
         clear_mappers()
 
-        self.assertRaisesMessage(exceptions.ArgumentError, "must be given a ColumnElement as its argument.", column_property,
-            select([func.count(addresses.c.id)], users.c.id==addresses.c.user_id).correlate(users)
-        )
-
     def test_external_columns_good(self):
         """test querying mappings that reference external columns or selectables."""
         
@@ -1586,19 +1877,21 @@ class ExternalColumnsTest(QueryTest):
         })
 
         mapper(Address, addresses, properties={
-            'user':relation(User, lazy=True)
+            'user':relation(User)
         })
 
         sess = create_session()
-
         
-        l = sess.query(User).all()
-        assert [
-            User(id=7, concat=14, count=1),
-            User(id=8, concat=16, count=3),
-            User(id=9, concat=18, count=1),
-            User(id=10, concat=20, count=0),
-        ] == l
+        sess.query(Address).options(eagerload('user')).all()
+
+        self.assertEquals(sess.query(User).all(), 
+            [
+                User(id=7, concat=14, count=1),
+                User(id=8, concat=16, count=3),
+                User(id=9, concat=18, count=1),
+                User(id=10, concat=20, count=0),
+            ]
+        )
 
         address_result = [
             Address(id=1, user=User(id=7, concat=14, count=1)),
@@ -1617,15 +1910,24 @@ class ExternalColumnsTest(QueryTest):
                self.assertEquals(sess.query(Address).options(eagerload('user')).all(), address_result)
             self.assert_sql_count(testing.db, go, 1)
         
-        tuple_address_result = [(address, address.user) for address in address_result]
-        
-        q =sess.query(Address).join('user', aliased=True, id='ualias').join('user', aliased=True).add_column(User.concat)
-        self.assertRaisesMessage(exceptions.InvalidRequestError, "Ambiguous", q.all)
-        
-        self.assertEquals(sess.query(Address).join('user', aliased=True, id='ualias').add_entity(User, id='ualias').all(), tuple_address_result)
+        ualias = aliased(User)
+        self.assertEquals(
+            sess.query(Address, ualias).join(('user', ualias)).all(), 
+            [(address, address.user) for address in address_result]
+        )
 
-        self.assertEquals(sess.query(Address).join('user', aliased=True, id='ualias').join('user', aliased=True).\
-                add_column(User.concat, id='ualias').add_column(User.count, id='ualias').all(),
+        self.assertEquals(
+                sess.query(Address, ualias.count).join(('user', ualias)).join('user', aliased=True).all(),
+                [
+                    (Address(id=1), 1),
+                    (Address(id=2), 3),
+                    (Address(id=3), 3),
+                    (Address(id=4), 3),
+                    (Address(id=5), 1)
+                ]
+            )
+
+        self.assertEquals(sess.query(Address, ualias.concat, ualias.count).join(('user', ualias)).join('user', aliased=True).all(),
             [
                 (Address(id=1), 14, 1),
                 (Address(id=2), 16, 3),
@@ -1635,15 +1937,21 @@ class ExternalColumnsTest(QueryTest):
             ]
         )
 
-        self.assertEquals(list(sess.query(Address).join('user').values(Address.id, User.id, User.concat, User.count)), 
-            [(1, 7, 14, 1), (2, 8, 16, 3), (3, 8, 16, 3), (4, 8, 16, 3), (5, 9, 18, 1)]
+        ua = aliased(User)
+        self.assertEquals(sess.query(Address, ua.concat, ua.count).select_from(join(Address, ua, 'user')).options(eagerload(Address.user)).all(),
+            [
+                (Address(id=1, user=User(id=7, concat=14, count=1)), 14, 1),
+                (Address(id=2, user=User(id=8, concat=16, count=3)), 16, 3),
+                (Address(id=3, user=User(id=8, concat=16, count=3)), 16, 3),
+                (Address(id=4, user=User(id=8, concat=16, count=3)), 16, 3),
+                (Address(id=5, user=User(id=9, concat=18, count=1)), 18, 1)
+            ]
         )
 
-        self.assertEquals(list(sess.query(Address).join('user', aliased=True).values(Address.id, User.id, User.concat, User.count)), 
+        self.assertEquals(list(sess.query(Address).join('user').values(Address.id, User.id, User.concat, User.count)), 
             [(1, 7, 14, 1), (2, 8, 16, 3), (3, 8, 16, 3), (4, 8, 16, 3), (5, 9, 18, 1)]
         )
 
-        ua = aliased(User)
         self.assertEquals(list(sess.query(Address, ua).select_from(join(Address,ua, 'user')).values(Address.id, ua.id, ua.concat, ua.count)), 
             [(1, 7, 14, 1), (2, 8, 16, 3), (3, 8, 16, 3), (4, 8, 16, 3), (5, 9, 18, 1)]
         )
index 40773f8359599013eaa697bf45996a146d257cea..b33684e2fed913108a3d29f5e053f6e5b16ee575 100644 (file)
@@ -1,9 +1,9 @@
 import testenv; testenv.configure_for_tests()
 import datetime
 from sqlalchemy import *
-from sqlalchemy import exceptions, types
+from sqlalchemy import exc as sa_exc, types
 from sqlalchemy.orm import *
-from sqlalchemy.orm import collections
+from sqlalchemy.orm import collections, attributes, exc as orm_exc
 from sqlalchemy.orm.collections import collection
 from testlib import *
 from testlib import fixtures
@@ -278,7 +278,13 @@ class RelationTest3(TestBase):
                 self.pagename = pagename
                 self.currentversion = PageVersion(self, 1)
             def __repr__(self):
-                return "Page jobno:%s pagename:%s %s" % (self.jobno, self.pagename, getattr(self, '_instance_key', None))
+                try:
+                    state = attributes.instance_state(self)
+                    key = state.key
+                except (KeyError, AttributeError):
+                    key = None
+                return ("Page jobno:%s pagename:%s %s" %
+                        (self.jobno, self.pagename, key))
             def add_version(self):
                 self.currentversion = PageVersion(self, self.currentversion.version+1)
                 comment = self.add_comment()
@@ -393,7 +399,7 @@ class RelationTest4(ORMTest):
         try:
             sess.flush()
             assert False
-        except exceptions.AssertionError, e:
+        except AssertionError, e:
             assert str(e).startswith("Dependency rule tried to blank-out primary key column 'B.id' on instance ")
 
     def test_no_delete_PK_BtoA(self):
@@ -413,7 +419,7 @@ class RelationTest4(ORMTest):
         try:
             sess.flush()
             assert False
-        except exceptions.AssertionError, e:
+        except AssertionError, e:
             assert str(e).startswith("Dependency rule tried to blank-out primary key column 'B.id' on instance ")
 
     @testing.fails_on_everything_except('sqlite', 'mysql')
@@ -627,7 +633,7 @@ class TypeMatchTest(ORMTest):
         try:
             sess.save(a1)
             assert False
-        except exceptions.AssertionError, err:
+        except AssertionError, err:
             assert str(err) == "Attribute 'bs' on class '%s' doesn't handle objects of type '%s'" % (A, C)
     def test_o2m_onflush(self):
         class A(object):pass
@@ -646,11 +652,8 @@ class TypeMatchTest(ORMTest):
         sess.save(a1)
         sess.save(b1)
         sess.save(c1)
-        try:
-            sess.flush()
-            assert False
-        except exceptions.FlushError, err:
-            assert str(err).startswith("Attempting to flush an item of type %s on collection 'A.bs (B)', which is handled by mapper 'Mapper|B|b' and does not load items of that type.  Did you mean to use a polymorphic mapper for this relationship ?" % C)
+        self.assertRaisesMessage(orm_exc.FlushError, "Attempting to flush an item", sess.flush)
+
     def test_o2m_nopoly_onflush(self):
         class A(object):pass
         class B(object):pass
@@ -668,11 +671,7 @@ class TypeMatchTest(ORMTest):
         sess.save(a1)
         sess.save(b1)
         sess.save(c1)
-        try:
-            sess.flush()
-            assert False
-        except exceptions.FlushError, err:
-            assert str(err).startswith("Attempting to flush an item of type %s on collection 'A.bs (B)', which is handled by mapper 'Mapper|B|b' and does not load items of that type.  Did you mean to use a polymorphic mapper for this relationship ?" % C)
+        self.assertRaisesMessage(orm_exc.FlushError, "Attempting to flush an item", sess.flush)
 
     def test_m2o_nopoly_onflush(self):
         class A(object):pass
@@ -687,11 +686,8 @@ class TypeMatchTest(ORMTest):
         sess = create_session()
         sess.save(b1)
         sess.save(d1)
-        try:
-            sess.flush()
-            assert False
-        except exceptions.FlushError, err:
-            assert str(err).startswith("Attempting to flush an item of type %s on collection 'D.a (A)', which is handled by mapper 'Mapper|A|a' and does not load items of that type.  Did you mean to use a polymorphic mapper for this relationship ?" % B)
+        self.assertRaisesMessage(orm_exc.FlushError, "Attempting to flush an item", sess.flush)
+
     def test_m2o_oncascade(self):
         class A(object):pass
         class B(object):pass
@@ -703,11 +699,7 @@ class TypeMatchTest(ORMTest):
         d1 = D()
         d1.a = b1
         sess = create_session()
-        try:
-            sess.save(d1)
-            assert False
-        except exceptions.AssertionError, err:
-            assert str(err) == "Attribute 'a' on class '%s' doesn't handle objects of type '%s'" % (D, B)
+        self.assertRaisesMessage(AssertionError, "doesn't handle objects of type", sess.save, d1)
 
 class TypedAssociationTable(ORMTest):
     def define_tables(self, metadata):
@@ -1030,6 +1022,7 @@ class ViewOnlyTest6(ORMTest):
         
         a = sess.query(T1).first()
         self.assertEquals(a.t3s, [T3(data='t3')])
+
         
     def test_remote_side_escalation(self):
         class T1(fixtures.Base):
@@ -1051,7 +1044,7 @@ class ViewOnlyTest6(ORMTest):
             't3s':relation(T3, secondary=t2tot3)
         })
         mapper(T3, t3)
-        self.assertRaisesMessage(exceptions.ArgumentError, "Specify remote_side argument", compile_mappers)
+        self.assertRaisesMessage(sa_exc.ArgumentError, "Specify remote_side argument", compile_mappers)
 
 class ExplicitLocalRemoteTest(ORMTest):
     def define_tables(self, metadata):
@@ -1210,7 +1203,7 @@ class ExplicitLocalRemoteTest(ORMTest):
             )
         })
         mapper(T2, t2)
-        self.assertRaises(exceptions.ArgumentError, compile_mappers)
+        self.assertRaises(sa_exc.ArgumentError, compile_mappers)
         
         clear_mappers()
         mapper(T1, t1, properties={
@@ -1219,7 +1212,7 @@ class ExplicitLocalRemoteTest(ORMTest):
             )
         })
         mapper(T2, t2)
-        self.assertRaises(exceptions.ArgumentError, compile_mappers)
+        self.assertRaises(sa_exc.ArgumentError, compile_mappers)
         
 class InvalidRelationEscalationTest(ORMTest):
     def define_tables(self, metadata):
@@ -1237,7 +1230,7 @@ class InvalidRelationEscalationTest(ORMTest):
         })
 
         mapper(Bar, bars)
-        self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers)
+        self.assertRaisesMessage(sa_exc.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers)
 
     def test_no_join_self_ref(self):
         mapper(Foo, foos, properties={
@@ -1245,7 +1238,7 @@ class InvalidRelationEscalationTest(ORMTest):
         })
 
         mapper(Bar, bars)
-        self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers)
+        self.assertRaisesMessage(sa_exc.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers)
         
     def test_no_equated(self):
         mapper(Foo, foos, properties={
@@ -1253,7 +1246,7 @@ class InvalidRelationEscalationTest(ORMTest):
         })
 
         mapper(Bar, bars)
-        self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+        self.assertRaisesMessage(sa_exc.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
 
     def test_no_equated_fks(self):
         mapper(Foo, foos, properties={
@@ -1261,7 +1254,7 @@ class InvalidRelationEscalationTest(ORMTest):
         })
 
         mapper(Bar, bars)
-        self.assertRaisesMessage(exceptions.ArgumentError, "Could not locate any equated, locally mapped column pairs for primaryjoin condition", compile_mappers)
+        self.assertRaisesMessage(sa_exc.ArgumentError, "Could not locate any equated, locally mapped column pairs for primaryjoin condition", compile_mappers)
 
     def test_no_equated_self_ref(self):
         mapper(Foo, foos, properties={
@@ -1269,7 +1262,7 @@ class InvalidRelationEscalationTest(ORMTest):
         })
 
         mapper(Bar, bars)
-        self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+        self.assertRaisesMessage(sa_exc.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
 
     def test_no_equated_self_ref(self):
         mapper(Foo, foos, properties={
@@ -1277,7 +1270,7 @@ class InvalidRelationEscalationTest(ORMTest):
         })
 
         mapper(Bar, bars)
-        self.assertRaisesMessage(exceptions.ArgumentError, "Could not locate any equated, locally mapped column pairs for primaryjoin condition", compile_mappers)
+        self.assertRaisesMessage(sa_exc.ArgumentError, "Could not locate any equated, locally mapped column pairs for primaryjoin condition", compile_mappers)
 
     def test_no_equated_viewonly(self):
         mapper(Foo, foos, properties={
@@ -1285,7 +1278,7 @@ class InvalidRelationEscalationTest(ORMTest):
         })
 
         mapper(Bar, bars)
-        self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+        self.assertRaisesMessage(sa_exc.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
 
     def test_no_equated_self_ref_viewonly(self):
         mapper(Foo, foos, properties={
@@ -1294,7 +1287,7 @@ class InvalidRelationEscalationTest(ORMTest):
 
         mapper(Bar, bars)
 
-        self.assertRaisesMessage(exceptions.ArgumentError, "Specify the foreign_keys argument to indicate which columns on the relation are foreign.", compile_mappers)
+        self.assertRaisesMessage(sa_exc.ArgumentError, "Specify the foreign_keys argument to indicate which columns on the relation are foreign.", compile_mappers)
 
     def test_no_equated_self_ref_viewonly_fks(self):
         mapper(Foo, foos, properties={
@@ -1308,21 +1301,21 @@ class InvalidRelationEscalationTest(ORMTest):
             'bars':relation(Bar, primaryjoin=foos.c.id==bars.c.fid)
         })
         mapper(Bar, bars)
-        self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+        self.assertRaisesMessage(sa_exc.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
     
     def test_equated_self_ref(self):
         mapper(Foo, foos, properties={
             'foos':relation(Foo, primaryjoin=foos.c.id==foos.c.fid)
         })
 
-        self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+        self.assertRaisesMessage(sa_exc.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
 
     def test_equated_self_ref_wrong_fks(self):
         mapper(Foo, foos, properties={
             'foos':relation(Foo, primaryjoin=foos.c.id==foos.c.fid, foreign_keys=[bars.c.id])
         })
 
-        self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+        self.assertRaisesMessage(sa_exc.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
 
 class InvalidRelationEscalationTestM2M(ORMTest):
     def define_tables(self, metadata):
@@ -1341,7 +1334,7 @@ class InvalidRelationEscalationTestM2M(ORMTest):
         })
 
         mapper(Bar, bars)
-        self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers)
+        self.assertRaisesMessage(sa_exc.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers)
 
     def test_no_secondaryjoin(self):
         mapper(Foo, foos, properties={
@@ -1349,7 +1342,7 @@ class InvalidRelationEscalationTestM2M(ORMTest):
         })
 
         mapper(Bar, bars)
-        self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers)
+        self.assertRaisesMessage(sa_exc.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers)
 
     def test_bad_primaryjoin(self):
         mapper(Foo, foos, properties={
@@ -1357,7 +1350,7 @@ class InvalidRelationEscalationTestM2M(ORMTest):
         })
 
         mapper(Bar, bars)
-        self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+        self.assertRaisesMessage(sa_exc.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
 
     def test_bad_secondaryjoin(self):
         mapper(Foo, foos, properties={
@@ -1365,7 +1358,7 @@ class InvalidRelationEscalationTestM2M(ORMTest):
         })
 
         mapper(Bar, bars)
-        self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for secondaryjoin condition", compile_mappers)
+        self.assertRaisesMessage(sa_exc.ArgumentError, "Could not determine relation direction for secondaryjoin condition", compile_mappers)
 
     def test_no_equated_secondaryjoin(self):
         mapper(Foo, foos, properties={
@@ -1373,7 +1366,7 @@ class InvalidRelationEscalationTestM2M(ORMTest):
         })
 
         mapper(Bar, bars)
-        self.assertRaisesMessage(exceptions.ArgumentError, "Could not locate any equated, locally mapped column pairs for secondaryjoin condition", compile_mappers)
+        self.assertRaisesMessage(sa_exc.ArgumentError, "Could not locate any equated, locally mapped column pairs for secondaryjoin condition", compile_mappers)
 
 
 if __name__ == "__main__":
diff --git a/test/orm/scoping.py b/test/orm/scoping.py
new file mode 100644 (file)
index 0000000..523f376
--- /dev/null
@@ -0,0 +1,171 @@
+import testenv; testenv.configure_for_tests()
+from sqlalchemy import *
+from sqlalchemy import exc as sa_exc
+from sqlalchemy.orm import *
+from testlib import *
+from testlib import fixtures
+
+
+class ScopedSessionTest(ORMTest):
+
+    def define_tables(self, metadata):
+        global table, table2
+        table = Table('sometable', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('data', String(30)))
+        table2 = Table('someothertable', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('someid', None, ForeignKey('sometable.id'))
+            )
+
+    def test_basic(self):
+        Session = scoped_session(sessionmaker())
+
+        class SomeObject(fixtures.Base):
+            query = Session.query_property()
+        class SomeOtherObject(fixtures.Base):
+            query = Session.query_property()
+
+        mapper(SomeObject, table, properties={
+            'options':relation(SomeOtherObject)
+        })
+        mapper(SomeOtherObject, table2)
+
+        s = SomeObject(id=1, data="hello")
+        sso = SomeOtherObject()
+        s.options.append(sso)
+        Session.save(s)
+        Session.commit()
+        Session.refresh(sso)
+        Session.remove()
+
+        self.assertEquals(SomeObject(id=1, data="hello", options=[SomeOtherObject(someid=1)]), Session.query(SomeObject).one())
+        self.assertEquals(SomeObject(id=1, data="hello", options=[SomeOtherObject(someid=1)]), SomeObject.query.one())
+        self.assertEquals(SomeOtherObject(someid=1), SomeOtherObject.query.filter(SomeOtherObject.someid==sso.someid).one())
+
+
+class ScopedMapperTest(TestBase):
+    def setUpAll(self):
+        global metadata, table, table2
+        metadata = MetaData(testing.db)
+        table = Table('sometable', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('data', String(30)))
+        table2 = Table('someothertable', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('someid', None, ForeignKey('sometable.id'))
+            )
+        metadata.create_all()
+
+    def setUp(self):
+        global SomeObject, SomeOtherObject
+        class SomeObject(fixtures.Base):pass
+        class SomeOtherObject(fixtures.Base):pass
+
+        global Session
+
+        Session = scoped_session(create_session)
+        Session.mapper(SomeObject, table, properties={
+            'options':relation(SomeOtherObject)
+        })
+        Session.mapper(SomeOtherObject, table2)
+
+        s = SomeObject()
+        s.id = 1
+        s.data = 'hello'
+        sso = SomeOtherObject()
+        s.options.append(sso)
+        Session.flush()
+        Session.clear()
+
+    def tearDownAll(self):
+        metadata.drop_all()
+
+    def tearDown(self):
+        for table in metadata.table_iterator(reverse=True):
+            table.delete().execute()
+        clear_mappers()
+
+    def test_query(self):
+        sso = SomeOtherObject.query().first()
+        assert SomeObject.query.filter_by(id=1).one().options[0].id == sso.id
+
+    def test_query_compiles(self):
+        class Foo(object):
+            pass
+        Session.mapper(Foo, table2)
+        assert hasattr(Foo, 'query')
+
+        ext = MapperExtension()
+
+        class Bar(object):
+            pass
+        Session.mapper(Bar, table2, extension=[ext])
+        assert hasattr(Bar, 'query')
+
+        class Baz(object):
+            pass
+        Session.mapper(Baz, table2, extension=ext)
+        assert hasattr(Baz, 'query')
+
+    def test_validating_constructor(self):
+        s2 = SomeObject(someid=12)
+        s3 = SomeOtherObject(someid=123, bogus=345)
+
+        class ValidatedOtherObject(object): pass
+        Session.mapper(ValidatedOtherObject, table2, validate=True)
+
+        v1 = ValidatedOtherObject(someid=12)
+        self.assertRaises(sa_exc.ArgumentError, ValidatedOtherObject, someid=12, bogus=345)
+
+    def test_dont_clobber_methods(self):
+        class MyClass(object):
+            def expunge(self):
+                return "an expunge !"
+
+        Session.mapper(MyClass, table2)
+
+        assert MyClass().expunge() == "an expunge !"
+
+class ScopedMapperTest2(ORMTest):
+    def define_tables(self, metadata):
+        global table, table2
+        table = Table('sometable', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('data', String(30)),
+            Column('type', String(30))
+
+            )
+        table2 = Table('someothertable', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('someid', None, ForeignKey('sometable.id')),
+            Column('somedata', String(30)),
+            )
+
+    def test_inheritance(self):
+        def expunge_list(l):
+            for x in l:
+                Session.expunge(x)
+            return l
+
+        class BaseClass(fixtures.Base):
+            pass
+        class SubClass(BaseClass):
+            pass
+
+        Session = scoped_session(sessionmaker())
+        Session.mapper(BaseClass, table, polymorphic_identity='base', polymorphic_on=table.c.type)
+        Session.mapper(SubClass, table2, polymorphic_identity='sub', inherits=BaseClass)
+
+        b = BaseClass(data='b1')
+        s =  SubClass(data='s1', somedata='somedata')
+        Session.commit()
+        Session.clear()
+
+        assert expunge_list([BaseClass(data='b1'), SubClass(data='s1', somedata='somedata')]) == BaseClass.query.all()
+        assert expunge_list([SubClass(data='s1', somedata='somedata')]) == SubClass.query.all()
+
+
+
+if __name__ == "__main__":
+    testenv.main()
index fc5be6f505f7504ac29ef30ee508f142cd16c6ab..a16c24fc11962f5f34bf112150ab1f895075199d 100644 (file)
@@ -2,7 +2,7 @@
 
 import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
-from sqlalchemy import exceptions
+from sqlalchemy import exc as sa_exc
 from sqlalchemy.orm import *
 from testlib import *
 from testlib.fixtures import *
@@ -21,7 +21,7 @@ class SelectableNoFromsTest(ORMTest):
         class Subset(object):
             pass
         selectable = select(["x", "y", "z"])
-        self.assertRaisesMessage(exceptions.InvalidRequestError, "Could not find any Table objects", mapper, Subset, selectable)
+        self.assertRaisesMessage(sa_exc.InvalidRequestError, "Could not find any Table objects", mapper, Subset, selectable)
 
     @testing.emits_warning('.*creating an Alias.*')
     def test_basic(self):
index 49932f8d9d78718012a47b1feeb8e09f9631530f..719ecccf9d84716160ef4e58570920da74406886 100644 (file)
@@ -1,14 +1,15 @@
 import testenv; testenv.configure_for_tests()
+import gc
+import pickle
 from sqlalchemy import *
-from sqlalchemy import exceptions, util
+from sqlalchemy import exc as sa_exc, util
 from sqlalchemy.orm import *
+from sqlalchemy.orm import attributes
 from sqlalchemy.orm.session import SessionExtension
 from sqlalchemy.orm.session import Session as SessionCls
 from testlib import *
 from testlib.tables import *
 from testlib import fixtures, tables
-import pickle
-import gc
 
 
 class SessionTest(TestBase, AssertsExecutionResults):
@@ -27,7 +28,8 @@ class SessionTest(TestBase, AssertsExecutionResults):
         pass
 
     def test_close(self):
-        """test that flush() doenst close a connection the session didnt open"""
+        """test that flush() doesn't close a connection the session didn't open"""
+
         c = testing.db.connect()
         class User(object):pass
         mapper(User, users)
@@ -83,9 +85,14 @@ class SessionTest(TestBase, AssertsExecutionResults):
         # then see if expunge fails
         session.expunge(u)
 
+        assert object_session(u) is attributes.instance_state(u).session_id is None
+        for a in u.addresses:
+            assert object_session(a) is attributes.instance_state(a).session_id is None
+
     @engines.close_open_connections
     def test_binds_from_expression(self):
         """test that Session can extract Table objects from ClauseElements and match them to tables."""
+
         Session = sessionmaker(binds={users:testing.db, addresses:testing.db})
         sess = Session()
         sess.execute(users.insert(), params=dict(user_id=1, user_name='ed'))
@@ -123,7 +130,7 @@ class SessionTest(TestBase, AssertsExecutionResults):
         conn1 = testing.db.connect()
         conn2 = testing.db.connect()
 
-        sess = create_session(transactional=True, bind=conn1)
+        sess = create_session(autocommit=False, bind=conn1)
         u = User()
         sess.save(u)
         sess.flush()
@@ -134,20 +141,6 @@ class SessionTest(TestBase, AssertsExecutionResults):
         assert testing.db.connect().execute("select count(1) from users").scalar() == 1
         sess.close()
 
-    def test_flush_noop(self):
-        session = create_session()
-        session.uow = object()
-
-        self.assertRaises(AttributeError, session.flush)
-
-        session = create_session()
-        session.uow = object()
-
-        session.flush(objects=[])
-        session.flush(objects=set())
-        session.flush(objects=())
-        session.flush(objects=iter([]))
-
     @testing.unsupported('sqlite', 'mssql') # TEMP: test causes mssql to hang
     @engines.close_open_connections
     def test_autoflush(self):
@@ -156,7 +149,7 @@ class SessionTest(TestBase, AssertsExecutionResults):
         conn1 = testing.db.connect()
         conn2 = testing.db.connect()
 
-        sess = create_session(bind=conn1, transactional=True, autoflush=True)
+        sess = create_session(bind=conn1, autocommit=False, autoflush=True)
         u = User()
         u.user_name='ed'
         sess.save(u)
@@ -179,7 +172,7 @@ class SessionTest(TestBase, AssertsExecutionResults):
         })
         mapper(Address, addresses)
 
-        sess = create_session(autoflush=True, transactional=True)
+        sess = create_session(autoflush=True, autocommit=False)
         u = User(user_name='ed', addresses=[Address(email_address='foo')])
         sess.save(u)
         self.assertEquals(sess.query(Address).filter(Address.user==u).one(), Address(email_address='foo'))
@@ -191,7 +184,7 @@ class SessionTest(TestBase, AssertsExecutionResults):
         mapper(User, users)
 
         try:
-            sess = create_session(transactional=True, autoflush=True)
+            sess = create_session(autocommit=False, autoflush=True)
             u = User()
             u.user_name='ed'
             sess.save(u)
@@ -214,7 +207,7 @@ class SessionTest(TestBase, AssertsExecutionResults):
         conn1 = testing.db.connect()
         conn2 = testing.db.connect()
 
-        sess = create_session(bind=conn1, transactional=True, autoflush=True)
+        sess = create_session(bind=conn1, autocommit=False, autoflush=True)
         u = User()
         u.user_name='ed'
         sess.save(u)
@@ -223,18 +216,17 @@ class SessionTest(TestBase, AssertsExecutionResults):
         assert testing.db.connect().execute("select count(1) from users").scalar() == 1
         sess.commit()
 
-    # TODO: not doing rollback of attributes right now.
-    def dont_test_autoflush_rollback(self):
+    def test_autoflush_rollback(self):
         tables.data()
         mapper(Address, addresses)
         mapper(User, users, properties={
             'addresses':relation(Address)
         })
 
-        sess = create_session(transactional=True, autoflush=True)
+        sess = create_session(autocommit=False, autoflush=True)
         u = sess.query(User).get(8)
         newad = Address()
-        newad.email_address == 'something new'
+        newad.email_address = 'something new'
         u.addresses.append(newad)
         u.user_name = 'some new name'
         assert u.user_name == 'some new name'
@@ -244,16 +236,26 @@ class SessionTest(TestBase, AssertsExecutionResults):
         assert u.user_name == 'ed'
         assert len(u.addresses) == 3
         assert newad not in u.addresses
-
+        
+        # pending objects dont get expired
+        assert newad.email_address == 'something new'
+    
+    def test_textual_execute(self):
+        """test that Session.execute() converts to text()"""
+        
+        tables.data()
+        sess = create_session(bind=testing.db)
+        # use :bindparam style
+        self.assertEquals(sess.execute("select * from users where user_id=:id", {'id':7}).fetchall(), [(7, u'jack')])
 
     @engines.close_open_connections
-    def test_external_joined_transaction(self):
+    def test_subtransaction_on_external(self):
         class User(object):pass
         mapper(User, users)
         conn = testing.db.connect()
         trans = conn.begin()
-        sess = create_session(bind=conn, transactional=True, autoflush=True)
-        sess.begin()
+        sess = create_session(bind=conn, autocommit=False, autoflush=True)
+        sess.begin(subtransactions=True)
         u = User()
         sess.save(u)
         sess.flush()
@@ -271,7 +273,7 @@ class SessionTest(TestBase, AssertsExecutionResults):
         try:
             conn = testing.db.connect()
             trans = conn.begin()
-            sess = create_session(bind=conn, transactional=True, autoflush=True)
+            sess = create_session(bind=conn, autocommit=False, autoflush=True)
             u1 = User()
             sess.save(u1)
             sess.flush()
@@ -288,16 +290,14 @@ class SessionTest(TestBase, AssertsExecutionResults):
             conn.close()
             raise
 
-    @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access',
-                         'oracle', 'maxdb')
-    @engines.close_open_connections
+    @testing.requires.savepoints
     def test_heavy_nesting(self):
         session = create_session(bind=testing.db)
 
         session.begin()
         session.connection().execute("insert into users (user_name) values ('user1')")
 
-        session.begin()
+        session.begin(subtransactions=True)
 
         session.begin_nested()
 
@@ -312,9 +312,7 @@ class SessionTest(TestBase, AssertsExecutionResults):
         assert session.connection().execute("select count(1) from users").scalar() == 2
 
 
-    @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access',
-                         'oracle', 'maxdb')
-    @testing.exclude('mysql', '<', (5, 0, 3))
+    @testing.requires.two_phase_transactions
     def test_twophase(self):
         # TODO: mock up a failure condition here
         # to ensure a rollback succeeds
@@ -324,7 +322,7 @@ class SessionTest(TestBase, AssertsExecutionResults):
         mapper(Address, addresses)
 
         engine2 = create_engine(testing.db.url)
-        sess = create_session(transactional=False, autoflush=False, twophase=True)
+        sess = create_session(autocommit=True, autoflush=False, twophase=True)
         sess.bind_mapper(User, testing.db)
         sess.bind_mapper(Address, engine2)
         sess.begin()
@@ -338,11 +336,11 @@ class SessionTest(TestBase, AssertsExecutionResults):
         assert users.count().scalar() == 1
         assert addresses.count().scalar() == 1
 
-    def test_joined_transaction(self):
+    def test_subtransaction_on_noautocommit(self):
         class User(object):pass
         mapper(User, users)
-        sess = create_session(transactional=True, autoflush=True)
-        sess.begin()
+        sess = create_session(autocommit=False, autoflush=True)
+        sess.begin(subtransactions=True)
         u = User()
         sess.save(u)
         sess.flush()
@@ -351,9 +349,7 @@ class SessionTest(TestBase, AssertsExecutionResults):
         assert len(sess.query(User).all()) == 0
         sess.close()
 
-    @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access',
-                         'oracle', 'maxdb')
-    @testing.exclude('mysql', '<', (5, 0, 3))
+    @testing.requires.savepoints
     def test_nested_transaction(self):
         class User(object):pass
         mapper(User, users)
@@ -376,13 +372,11 @@ class SessionTest(TestBase, AssertsExecutionResults):
         assert len(sess.query(User).all()) == 1
         sess.close()
 
-    @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access',
-                         'oracle', 'maxdb')
-    @testing.exclude('mysql', '<', (5, 0, 3))
+    @testing.requires.savepoints
     def test_nested_autotrans(self):
         class User(object):pass
         mapper(User, users)
-        sess = create_session(transactional=True)
+        sess = create_session(autocommit=False)
         u = User()
         sess.save(u)
         sess.flush()
@@ -399,14 +393,12 @@ class SessionTest(TestBase, AssertsExecutionResults):
         assert len(sess.query(User).all()) == 1
         sess.close()
 
-    @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access',
-                         'oracle', 'maxdb')
-    @testing.exclude('mysql', '<', (5, 0, 3))
+    @testing.requires.savepoints
     def test_nested_transaction_connection_add(self):
         class User(object): pass
         mapper(User, users)
 
-        sess = create_session(transactional=False)
+        sess = create_session(autocommit=True)
 
         sess.begin()
         sess.begin_nested()
@@ -436,18 +428,16 @@ class SessionTest(TestBase, AssertsExecutionResults):
 
         sess.close()
 
-    @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access',
-                         'oracle', 'maxdb')
-    @testing.exclude('mysql', '<', (5, 0, 3))
+    @testing.requires.savepoints
     def test_mixed_transaction_control(self):
         class User(object): pass
         mapper(User, users)
 
-        sess = create_session(transactional=False)
+        sess = create_session(autocommit=True)
 
         sess.begin()
         sess.begin_nested()
-        transaction = sess.begin()
+        transaction = sess.begin(subtransactions=True)
 
         sess.save(User())
 
@@ -469,14 +459,12 @@ class SessionTest(TestBase, AssertsExecutionResults):
 
         sess.close()
 
-    @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access',
-                         'oracle', 'maxdb')
-    @testing.exclude('mysql', '<', (5, 0, 3))
+    @testing.requires.savepoints
     def test_mixed_transaction_close(self):
         class User(object): pass
         mapper(User, users)
 
-        sess = create_session(transactional=True)
+        sess = create_session(autocommit=False)
 
         sess.begin_nested()
 
@@ -492,27 +480,20 @@ class SessionTest(TestBase, AssertsExecutionResults):
 
         self.assertEquals(len(sess.query(User).all()), 1)
 
-    @testing.unsupported('sqlite', 'mssql', 'firebird', 'sybase', 'access',
-                         'oracle', 'maxdb')
-    @testing.exclude('mysql', '<', (5, 0, 3))
     def test_error_on_using_inactive_session(self):
         class User(object): pass
         mapper(User, users)
 
-        sess = create_session(transactional=False)
+        sess = create_session(autocommit=True)
 
-        try:
-            sess.begin()
-            sess.begin()
+        sess.begin()
+        sess.begin(subtransactions=True)
 
-            sess.save(User())
-            sess.flush()
+        sess.save(User())
+        sess.flush()
 
-            sess.rollback()
-            sess.begin()
-            assert False
-        except exceptions.InvalidRequestError, e:
-            self.assertEquals(str(e), "The transaction is inactive due to a rollback in a subtransaction and should be closed")
+        sess.rollback()
+        self.assertRaisesMessage(sa_exc.InvalidRequestError, "inactive due to a rollback in a subtransaction", sess.begin, subtransactions=True)
         sess.close()
 
     @engines.close_open_connections
@@ -521,30 +502,14 @@ class SessionTest(TestBase, AssertsExecutionResults):
         mapper(User, users)
         c = testing.db.connect()
         sess = create_session(bind=c)
-        sess.create_transaction()
+        sess.begin()
         transaction = sess.transaction
         u = User()
         sess.save(u)
         sess.flush()
-        assert transaction.get_or_add(testing.db) is transaction.get_or_add(c) is c
+        assert transaction._connection_for_bind(testing.db) is transaction._connection_for_bind(c) is c
 
-        try:
-            transaction.add(testing.db.connect())
-            assert False
-        except exceptions.InvalidRequestError, e:
-            assert str(e) == "Session already has a Connection associated for the given Connection's Engine"
-
-        try:
-            transaction.get_or_add(testing.db.connect())
-            assert False
-        except exceptions.InvalidRequestError, e:
-            assert str(e) == "Session already has a Connection associated for the given Connection's Engine"
-
-        try:
-            transaction.add(testing.db)
-            assert False
-        except exceptions.InvalidRequestError, e:
-            assert str(e) == "Session already has a Connection associated for the given Engine"
+        self.assertRaisesMessage(sa_exc.InvalidRequestError, "Session already has a Connection associated", transaction._connection_for_bind, testing.db.connect())
 
         transaction.rollback()
         assert len(sess.query(User).all()) == 0
@@ -555,7 +520,7 @@ class SessionTest(TestBase, AssertsExecutionResults):
         mapper(User, users)
         c = testing.db.connect()
 
-        sess = create_session(bind=c, transactional=True)
+        sess = create_session(bind=c, autocommit=False)
         u = User()
         sess.save(u)
         sess.flush()
@@ -563,7 +528,7 @@ class SessionTest(TestBase, AssertsExecutionResults):
         assert not c.in_transaction()
         assert c.scalar("select count(1) from users") == 0
 
-        sess = create_session(bind=c, transactional=True)
+        sess = create_session(bind=c, autocommit=False)
         u = User()
         sess.save(u)
         sess.flush()
@@ -576,7 +541,7 @@ class SessionTest(TestBase, AssertsExecutionResults):
         c = testing.db.connect()
 
         trans = c.begin()
-        sess = create_session(bind=c, transactional=False)
+        sess = create_session(bind=c, autocommit=True)
         u = User()
         sess.save(u)
         sess.flush()
@@ -596,17 +561,8 @@ class SessionTest(TestBase, AssertsExecutionResults):
 
         user = User()
 
-        try:
-            s.update(user)
-            assert False
-        except exceptions.InvalidRequestError, e:
-            assert str(e) == "Instance 'User@%s' is not persisted" % hex(id(user))
-
-        try:
-            s.delete(user)
-            assert False
-        except exceptions.InvalidRequestError, e:
-            assert str(e) == "Instance 'User@%s' is not persisted" % hex(id(user))
+        self.assertRaisesMessage(sa_exc.InvalidRequestError, "is not persisted", s.update, user)
+        self.assertRaisesMessage(sa_exc.InvalidRequestError, "is not persisted", s.delete, user)
 
         s.save(user)
         s.flush()
@@ -632,25 +588,13 @@ class SessionTest(TestBase, AssertsExecutionResults):
         assert user in s
         assert user not in s.dirty
 
-        try:
-            s.save(user)
-            assert False
-        except exceptions.InvalidRequestError, e:
-            assert str(e) == "Instance 'User@%s' is already persistent" % hex(id(user))
+        self.assertRaisesMessage(sa_exc.InvalidRequestError, "is already persistent", s.save, user)
 
         s2 = create_session()
-        try:
-            s2.delete(user)
-            assert False
-        except exceptions.InvalidRequestError, e:
-            assert "is already attached to session" in str(e)
+        self.assertRaisesMessage(sa_exc.InvalidRequestError, "is already attached to session", s2.delete, user)
 
         u2 = s2.query(User).get(user.user_id)
-        try:
-            s.delete(u2)
-            assert False
-        except exceptions.InvalidRequestError, e:
-            assert "already persisted with a different identity" in str(e)
+        self.assertRaisesMessage(sa_exc.InvalidRequestError, "already persisted with a different identity", s.delete, u2)
 
         s.delete(user)
         s.flush()
@@ -707,21 +651,18 @@ class SessionTest(TestBase, AssertsExecutionResults):
         del user
         gc.collect()
         assert len(s.identity_map) == 0
-        assert len(s.identity_map.data) == 0
 
         user = s.query(User).one()
         user.user_name = 'fred'
         del user
         gc.collect()
         assert len(s.identity_map) == 1
-        assert len(s.identity_map.data) == 1
         assert len(s.dirty) == 1
 
         s.flush()
         gc.collect()
         assert not s.dirty
         assert not s.identity_map
-        assert not s.identity_map.data
 
         user = s.query(User).one()
         assert user.user_name == 'fred'
@@ -890,7 +831,7 @@ class SessionTest(TestBase, AssertsExecutionResults):
         assert log == ['before_flush', 'after_begin', 'after_flush', 'before_commit', 'after_commit', 'after_flush_postexec']
 
         log = []
-        sess = create_session(transactional=True, extension=MyExt())
+        sess = create_session(autocommit=False, extension=MyExt())
         u = User()
         sess.save(u)
         sess.flush()
@@ -906,7 +847,7 @@ class SessionTest(TestBase, AssertsExecutionResults):
         assert log == ['before_commit', 'after_commit']
         
         log = []
-        sess = create_session(transactional=True, extension=MyExt(), bind=testing.db)
+        sess = create_session(autocommit=False, extension=MyExt(), bind=testing.db)
         conn = sess.connection()
         assert log == ['after_begin']
 
@@ -918,11 +859,7 @@ class SessionTest(TestBase, AssertsExecutionResults):
         u1 = User()
         sess1.save(u1)
 
-        try:
-            sess2.save(u1)
-            assert False
-        except exceptions.InvalidRequestError, e:
-            assert "already attached to session" in str(e)
+        self.assertRaisesMessage(sa_exc.InvalidRequestError, "already attached to session", sess2.save, u1)
 
         u2 = pickle.loads(pickle.dumps(u1))
 
@@ -941,6 +878,7 @@ class SessionTest(TestBase, AssertsExecutionResults):
         sess.expunge(u1)
 
         assert u1 not in sess
+        assert Session.object_session(u1) is None
 
         u2 = sess.query(User).get(u1.user_id)
         assert u2 is not None and u2 is not u1
@@ -950,12 +888,14 @@ class SessionTest(TestBase, AssertsExecutionResults):
 
         sess.expunge(u2)
         assert u2 not in sess
+        assert Session.object_session(u2) is None
 
         u1.user_name = "John"
         u2.user_name = "Doe"
 
         sess.update(u1)
         assert u1 in sess
+        assert Session.object_session(u1) is sess
 
         sess.flush()
 
@@ -981,197 +921,39 @@ class SessionTest(TestBase, AssertsExecutionResults):
         assert len(list(sess)) == 1
 
 
-class ScopedSessionTest(ORMTest):
-
-    def define_tables(self, metadata):
-        global table, table2
-        table = Table('sometable', metadata,
-            Column('id', Integer, primary_key=True),
-            Column('data', String(30)))
-        table2 = Table('someothertable', metadata,
-            Column('id', Integer, primary_key=True),
-            Column('someid', None, ForeignKey('sometable.id'))
-            )
-
-    def test_basic(self):
-        Session = scoped_session(sessionmaker())
-
-        class SomeObject(fixtures.Base):
-            query = Session.query_property()
-        class SomeOtherObject(fixtures.Base):
-            query = Session.query_property()
-
-        mapper(SomeObject, table, properties={
-            'options':relation(SomeOtherObject)
-        })
-        mapper(SomeOtherObject, table2)
-
-        s = SomeObject(id=1, data="hello")
-        sso = SomeOtherObject()
-        s.options.append(sso)
-        Session.save(s)
-        Session.commit()
-        Session.remove()
-
-        self.assertEquals(SomeObject(id=1, data="hello", options=[SomeOtherObject(someid=1)]), Session.query(SomeObject).one())
-        self.assertEquals(SomeObject(id=1, data="hello", options=[SomeOtherObject(someid=1)]), SomeObject.query.one())
-        self.assertEquals(SomeOtherObject(someid=1), SomeOtherObject.query.filter(SomeOtherObject.someid==sso.someid).one())
-
-class ScopedMapperTest(TestBase):
+class TLTransactionTest(TestBase):
     def setUpAll(self):
-        global metadata, table, table2
-        metadata = MetaData(testing.db)
-        table = Table('sometable', metadata,
-            Column('id', Integer, primary_key=True),
-            Column('data', String(30), nullable=False))
-        table2 = Table('someothertable', metadata,
-            Column('id', Integer, primary_key=True),
-            Column('someid', None, ForeignKey('sometable.id'))
-            )
-        metadata.create_all()
-
-    def setUp(self):
-        global SomeObject, SomeOtherObject
-        class SomeObject(object):pass
-        class SomeOtherObject(object):pass
-
-        global Session
-
-        Session = scoped_session(create_session)
-        Session.mapper(SomeObject, table, properties={
-            'options':relation(SomeOtherObject)
-        })
-        Session.mapper(SomeOtherObject, table2)
-
-        s = SomeObject()
-        s.id = 1
-        s.data = 'hello'
-        sso = SomeOtherObject()
-        s.options.append(sso)
-        Session.flush()
-        Session.clear()
-
-    def tearDownAll(self):
-        metadata.drop_all()
-
+        global users, metadata, tlengine
+        tlengine = create_engine(testing.db.url, strategy='threadlocal')
+        metadata = MetaData()
+        users = Table('query_users', metadata,
+            Column('user_id', INT, Sequence('query_users_id_seq', optional=True), primary_key=True),
+            Column('user_name', VARCHAR(20)),
+            test_needs_acid=True,
+        )
+        users.create(tlengine)
     def tearDown(self):
-        for table in metadata.table_iterator(reverse=True):
-            table.delete().execute()
-        clear_mappers()
-
-    def test_query(self):
-        sso = SomeOtherObject.query().first()
-        assert SomeObject.query.filter_by(id=1).one().options[0].id == sso.id
+        tlengine.execute(users.delete())
 
-    def test_query_compiles(self):
-        class Foo(object):
-            pass
-        Session.mapper(Foo, table2)
-        assert hasattr(Foo, 'query')
-
-        ext = MapperExtension()
-
-        class Bar(object):
-            pass
-        Session.mapper(Bar, table2, extension=[ext])
-        assert hasattr(Bar, 'query')
+    def tearDownAll(self):
+        users.drop(tlengine)
+        tlengine.dispose()
 
-        class Baz(object):
+    @testing.exclude('mysql', '<', (5, 0, 3))
+    def testsessionnesting(self):
+        class User(object):
             pass
-        Session.mapper(Baz, table2, extension=ext)
-        assert hasattr(Baz, 'query')
-
-    def test_validating_constructor(self):
-        s2 = SomeObject(someid=12)
-        s3 = SomeOtherObject(someid=123, bogus=345)
-
-        class ValidatedOtherObject(object):pass
-        Session.mapper(ValidatedOtherObject, table2, validate=True)
-
-        v1 = ValidatedOtherObject(someid=12)
         try:
-            v2 = ValidatedOtherObject(someid=12, bogus=345)
-            assert False
-        except exceptions.ArgumentError:
-            pass
-
-    def test_dont_clobber_methods(self):
-        class MyClass(object):
-            def expunge(self):
-                return "an expunge !"
-
-        Session.mapper(MyClass, table2)
-
-        assert MyClass().expunge() == "an expunge !"
-
-    def _test_autoflush_saveoninit(self, on_init, autoflush=None):
-        Session = scoped_session(
-            sessionmaker(transactional=True, autoflush=True))
-
-        class Foo(object):
-            def __init__(self, data=None):
-                if autoflush is not None:
-                    friends = Session.query(Foo).autoflush(autoflush).all()
-                else:
-                    friends = Session.query(Foo).all()
-                self.data = data
-
-        Session.mapper(Foo, table, save_on_init=on_init)
-
-        a1 = Foo('an address')
-        Session.flush()
-
-    def test_autoflush_saveoninit(self):
-        """Test save_on_init + query.autoflush()"""
-        self._test_autoflush_saveoninit(False)
-        self._test_autoflush_saveoninit(False, True)
-        self._test_autoflush_saveoninit(False, False)
-
-        self.assertRaises(exceptions.DBAPIError,
-                          self._test_autoflush_saveoninit, True)
-        self.assertRaises(exceptions.DBAPIError,
-                          self._test_autoflush_saveoninit, True, True)
-        self._test_autoflush_saveoninit(True, False)
-
-
-class ScopedMapperTest2(ORMTest):
-    def define_tables(self, metadata):
-        global table, table2
-        table = Table('sometable', metadata,
-            Column('id', Integer, primary_key=True),
-            Column('data', String(30)),
-            Column('type', String(30))
-
-            )
-        table2 = Table('someothertable', metadata,
-            Column('id', Integer, primary_key=True),
-            Column('someid', None, ForeignKey('sometable.id')),
-            Column('somedata', String(30)),
-            )
-
-    def test_inheritance(self):
-        def expunge_list(l):
-            for x in l:
-                Session.expunge(x)
-            return l
-
-        class BaseClass(fixtures.Base):
-            pass
-        class SubClass(BaseClass):
-            pass
-
-        Session = scoped_session(sessionmaker())
-        Session.mapper(BaseClass, table, polymorphic_identity='base', polymorphic_on=table.c.type)
-        Session.mapper(SubClass, table2, polymorphic_identity='sub', inherits=BaseClass)
-
-        b = BaseClass(data='b1')
-        s =  SubClass(data='s1', somedata='somedata')
-        Session.commit()
-        Session.clear()
-
-        assert expunge_list([BaseClass(data='b1'), SubClass(data='s1', somedata='somedata')]) == BaseClass.query.all()
-        assert expunge_list([SubClass(data='s1', somedata='somedata')]) == SubClass.query.all()
+            mapper(User, users)
 
+            sess = create_session(bind=tlengine)
+            tlengine.begin()
+            u = User()
+            sess.save(u)
+            sess.flush()
+            tlengine.commit()
+        finally:
+            clear_mappers()
 
 
 if __name__ == "__main__":
diff --git a/test/orm/sessioncontext.py b/test/orm/sessioncontext.py
deleted file mode 100644 (file)
index c743dab..0000000
+++ /dev/null
@@ -1,48 +0,0 @@
-import testenv; testenv.configure_for_tests()
-from sqlalchemy import *
-from sqlalchemy.orm import *
-from sqlalchemy.ext.sessioncontext import SessionContext
-from sqlalchemy.orm.session import object_session, Session
-from testlib import *
-
-
-metadata = MetaData()
-users = Table('users', metadata,
-    Column('user_id', Integer, Sequence('user_id_seq', optional=True), primary_key = True),
-    Column('user_name', String(40)),
-)
-
-class SessionContextTest(TestBase, AssertsExecutionResults):
-    def setUp(self):
-        clear_mappers()
-
-    def do_test(self, class_, context):
-        """test session assignment on object creation"""
-        obj = class_()
-        assert context.current == object_session(obj)
-
-        # keep a reference so the old session doesn't get gc'd
-        old_session = context.current
-
-        context.current = Session()
-        assert context.current != object_session(obj)
-        assert old_session == object_session(obj)
-
-        new_session = context.current
-        del context.current
-        assert context.current != new_session
-        assert old_session == object_session(obj)
-
-        obj2 = class_()
-        assert context.current == object_session(obj2)
-
-    @testing.uses_deprecated('SessionContext')
-    def test_mapper_extension(self):
-        context = SessionContext(Session)
-        class User(object): pass
-        User.mapper = mapper(User, users, extension=context.mapper_extension)
-        self.do_test(User, context)
-
-
-if __name__ == "__main__":
-    testenv.main()
index d231b14a2ce8ffacc626f2083490ae28d860ce72..f25d097fd7624810f1a24645a7283dbccf617d25 100644 (file)
@@ -1,7 +1,7 @@
 import testenv; testenv.configure_for_tests()
 import datetime, os
 from sqlalchemy import *
-from sqlalchemy import exceptions, sql
+from sqlalchemy import sql
 from sqlalchemy.orm import *
 from sqlalchemy.orm.shard import ShardedSession
 from sqlalchemy.sql import operators
@@ -93,7 +93,7 @@ class ShardTest(TestBase):
             else:
                 return ids
 
-        create_session = sessionmaker(class_=ShardedSession, autoflush=True, transactional=True)
+        create_session = sessionmaker(class_=ShardedSession, autoflush=True, autocommit=False)
 
         create_session.configure(shards={
             'north_america':db1,
@@ -139,7 +139,7 @@ class ShardTest(TestBase):
         for c in [tokyo, newyork, toronto, london, dublin, brasilia, quito]:
             sess.save(c)
         sess.commit()
-
+        tokyo.city   # reload 'city' attribute on tokyo
         sess.clear()
 
         assert db2.execute(weather_locations.select()).fetchall() == [(1, 'Asia', 'Tokyo')]
diff --git a/test/orm/transaction.py b/test/orm/transaction.py
new file mode 100644 (file)
index 0000000..ca36800
--- /dev/null
@@ -0,0 +1,360 @@
+import testenv; testenv.configure_for_tests()
+import operator
+from sqlalchemy import *
+from sqlalchemy import exc as sa_exc
+from sqlalchemy.orm import *
+from testlib import *
+from testlib.fixtures import *
+
+
+class TransactionTest(FixtureTest):
+    keep_mappers = True
+    session = sessionmaker()
+
+    def setup_mappers(self):
+        mapper(User, users, properties={
+            'addresses':relation(Address, backref='user',
+                                 cascade="all, delete-orphan"),
+            })
+        mapper(Address, addresses)
+
+
+class FixtureDataTest(TransactionTest):
+    refresh_data = True
+
+    def test_attrs_on_rollback(self):
+        sess = self.session()
+        u1 = sess.get(User, 7)
+        u1.name = 'ed'
+        sess.rollback()
+        self.assertEquals(u1.name, 'jack')
+
+    def test_commit_persistent(self):
+        sess = self.session()
+        u1 = sess.get(User, 7)
+        u1.name = 'ed'
+        sess.flush()
+        sess.commit()
+        self.assertEquals(u1.name, 'ed')
+
+    def test_concurrent_commit_persistent(self):
+        s1 = self.session()
+        u1 = s1.get(User, 7)
+        u1.name = 'ed'
+        s1.commit()
+
+        s2 = self.session()
+        u2 = s2.get(User, 7)
+        assert u2.name == 'ed'
+        u2.name = 'will'
+        s2.commit()
+
+        assert u1.name == 'will'
+
+class AutoExpireTest(TransactionTest):
+    tables_only = True
+
+    def test_expunge_pending_on_rollback(self):
+        sess = self.session()
+        u2= User(name='newuser')
+        sess.add(u2)
+        assert u2 in sess
+        sess.rollback()
+        assert u2 not in sess
+
+    def test_trans_pending_cleared_on_commit(self):
+        sess = self.session()
+        u2= User(name='newuser')
+        sess.add(u2)
+        assert u2 in sess
+        sess.commit()
+        assert u2 in sess
+        u3 = User(name='anotheruser')
+        sess.add(u3)
+        sess.rollback()
+        assert u3 not in sess
+        assert u2 in sess
+
+    def test_update_deleted_on_rollback(self):
+        s = self.session()
+        u1 = User(name='ed')
+        s.add(u1)
+        s.commit()
+
+        s.delete(u1)
+        assert u1 in s.deleted
+        s.rollback()
+        assert u1 in s
+        assert u1 not in s.deleted
+
+    def test_trans_deleted_cleared_on_rollback(self):
+        s = self.session()
+        u1 = User(name='ed')
+        s.add(u1)
+        s.commit()
+
+        s.delete(u1)
+        s.commit()
+        assert u1 not in s
+        s.rollback()
+        assert u1 not in s
+
+    def test_update_deleted_on_rollback_cascade(self):
+        s = self.session()
+        u1 = User(name='ed', addresses=[Address(email_address='foo')])
+        s.add(u1)
+        s.commit()
+
+        s.delete(u1)
+        assert u1 in s.deleted
+        assert u1.addresses[0] in s.deleted
+        s.rollback()
+        assert u1 in s
+        assert u1 not in s.deleted
+        assert u1.addresses[0] not in s.deleted
+
+    def test_update_deleted_on_rollback_orphan(self):
+        s = self.session()
+        u1 = User(name='ed', addresses=[Address(email_address='foo')])
+        s.add(u1)
+        s.commit()
+
+        a1 = u1.addresses[0]
+        u1.addresses.remove(a1)
+
+        s.flush()
+        self.assertEquals(s.query(Address).filter(Address.email_address=='foo').all(), [])
+        s.rollback()
+        assert a1 not in s.deleted
+        assert u1.addresses == [a1]
+
+    def test_commit_pending(self):
+        sess = self.session()
+        u1 = User(name='newuser')
+        sess.add(u1)
+        sess.flush()
+        sess.commit()
+        self.assertEquals(u1.name, 'newuser')
+
+
+    def test_concurrent_commit_pending(self):
+        s1 = self.session()
+        u1 = User(name='edward')
+        s1.add(u1)
+        s1.commit()
+
+        s2 = self.session()
+        u2 = s2.query(User).filter(User.name=='edward').one()
+        u2.name = 'will'
+        s2.commit()
+
+        assert u1.name == 'will'
+
+class RollbackRecoverTest(TransactionTest):
+    only_tables = True
+
+    def test_pk_violation(self):
+        s = self.session()
+        a1 = Address(email_address='foo')
+        u1 = User(id=1, name='ed', addresses=[a1])
+        s.add(u1)
+        s.commit()
+
+        a2 = Address(email_address='bar')
+        u2 = User(id=1, name='jack', addresses=[a2])
+
+        u1.name = 'edward'
+        a1.email_address = 'foober'
+        s.add(u2)
+        self.assertRaises(sa_exc.FlushError, s.commit)
+        self.assertRaises(sa_exc.InvalidRequestError, s.commit)
+        s.rollback()
+        assert u2 not in s
+        assert a2 not in s
+        assert u1 in s
+        assert a1 in s
+        assert u1.name == 'ed'
+        assert a1.email_address == 'foo'
+        u1.name = 'edward'
+        a1.email_address = 'foober'
+        s.commit()
+        assert s.query(User).all() == [User(id=1, name='edward', addresses=[Address(email_address='foober')])]
+
+    @testing.requires.savepoints
+    def test_pk_violation_with_savepoint(self):
+        s = self.session()
+        a1 = Address(email_address='foo')
+        u1 = User(id=1, name='ed', addresses=[a1])
+        s.add(u1)
+        s.commit()
+
+        a2 = Address(email_address='bar')
+        u2 = User(id=1, name='jack', addresses=[a2])
+
+        u1.name = 'edward'
+        a1.email_address = 'foober'
+        s.begin_nested()
+        s.add(u2)
+        self.assertRaises(sa_exc.FlushError, s.commit)
+        self.assertRaises(sa_exc.InvalidRequestError, s.commit)
+        s.rollback()
+        assert u2 not in s
+        assert a2 not in s
+        assert u1 in s
+        assert a1 in s
+
+        s.commit()
+        assert s.query(User).all() == [User(id=1, name='edward', addresses=[Address(email_address='foober')])]
+
+
+class SavepointTest(TransactionTest):
+
+    only_tables = True
+
+    @testing.requires.savepoints
+    def test_savepoint_rollback(self):
+        s = self.session()
+        u1 = User(name='ed')
+        u2 = User(name='jack')
+        s.add_all([u1, u2])
+
+        s.begin_nested()
+        u3 = User(name='wendy')
+        u4 = User(name='foo')
+        u1.name = 'edward'
+        u2.name = 'jackward'
+        s.add_all([u3, u4])
+        self.assertEquals(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)])
+        s.rollback()
+        assert u1.name == 'ed'
+        assert u2.name == 'jack'
+        self.assertEquals(s.query(User.name).order_by(User.id).all(), [('ed',), ('jack',)])
+        s.commit()
+        assert u1.name == 'ed'
+        assert u2.name == 'jack'
+        self.assertEquals(s.query(User.name).order_by(User.id).all(), [('ed',), ('jack',)])
+
+    @testing.requires.savepoints
+    def test_savepoint_commit(self):
+        s = self.session()
+        u1 = User(name='ed')
+        u2 = User(name='jack')
+        s.add_all([u1, u2])
+
+        s.begin_nested()
+        u3 = User(name='wendy')
+        u4 = User(name='foo')
+        u1.name = 'edward'
+        u2.name = 'jackward'
+        s.add_all([u3, u4])
+        self.assertEquals(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)])
+        s.commit()
+        def go():
+            assert u1.name == 'edward'
+            assert u2.name == 'jackward'
+            self.assertEquals(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)])
+        self.assert_sql_count(testing.db, go, 1)
+
+        s.commit()
+        self.assertEquals(s.query(User.name).order_by(User.id).all(), [('edward',), ('jackward',), ('wendy',), ('foo',)])
+
+    @testing.requires.savepoints
+    def test_savepoint_rollback_collections(self):
+        s = self.session()
+        u1 = User(name='ed', addresses=[Address(email_address='foo')])
+        s.add(u1)
+        s.commit()
+
+        u1.name='edward'
+        u1.addresses.append(Address(email_address='bar'))
+        s.begin_nested()
+        u2 = User(name='jack', addresses=[Address(email_address='bat')])
+        s.add(u2)
+        self.assertEquals(s.query(User).order_by(User.id).all(),
+            [
+                User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]),
+                User(name='jack', addresses=[Address(email_address='bat')])
+            ]
+        )
+        s.rollback()
+        self.assertEquals(s.query(User).order_by(User.id).all(),
+            [
+                User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]),
+            ]
+        )
+        s.commit()
+        self.assertEquals(s.query(User).order_by(User.id).all(),
+            [
+                User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]),
+            ]
+        )
+
+    @testing.requires.savepoints
+    def test_savepoint_commit_collections(self):
+        s = self.session()
+        u1 = User(name='ed', addresses=[Address(email_address='foo')])
+        s.add(u1)
+        s.commit()
+
+        u1.name='edward'
+        u1.addresses.append(Address(email_address='bar'))
+        s.begin_nested()
+        u2 = User(name='jack', addresses=[Address(email_address='bat')])
+        s.add(u2)
+        self.assertEquals(s.query(User).order_by(User.id).all(),
+            [
+                User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]),
+                User(name='jack', addresses=[Address(email_address='bat')])
+            ]
+        )
+        s.commit()
+        self.assertEquals(s.query(User).order_by(User.id).all(),
+            [
+                User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]),
+                User(name='jack', addresses=[Address(email_address='bat')])
+            ]
+        )
+        s.commit()
+        self.assertEquals(s.query(User).order_by(User.id).all(),
+            [
+                User(name='edward', addresses=[Address(email_address='foo'), Address(email_address='bar')]),
+                User(name='jack', addresses=[Address(email_address='bat')])
+            ]
+        )
+
+    @testing.requires.savepoints
+    def test_expunge_pending_on_rollback(self):
+        sess = self.session()
+
+        sess.begin_nested()
+        u2= User(name='newuser')
+        sess.add(u2)
+        assert u2 in sess
+        sess.rollback()
+        assert u2 not in sess
+
+    @testing.requires.savepoints
+    def test_update_deleted_on_rollback(self):
+        s = self.session()
+        u1 = User(name='ed')
+        s.add(u1)
+        s.commit()
+
+        s.begin_nested()
+        s.delete(u1)
+        assert u1 in s.deleted
+        s.rollback()
+        assert u1 in s
+        assert u1 not in s.deleted
+
+
+
+class AutocommitTest(TransactionTest):
+    def test_begin_nested_requires_trans(self):
+        sess = create_session(autocommit=True)
+        self.assertRaises(sa_exc.InvalidRequestError, sess.begin_nested)
+
+
+
+if __name__ == '__main__':
+    testenv.main()
index cd2a3005ea4aa135feb3dfc0d9d23e7479491fde..4c6f6f4cff6df050c8ad33d9a0fd55ea8d439566 100644 (file)
@@ -5,8 +5,9 @@
 import testenv; testenv.configure_for_tests()
 import pickleable
 from sqlalchemy import *
-from sqlalchemy import exceptions, sql
+from sqlalchemy import exc as sa_exc, sql
 from sqlalchemy.orm import *
+from sqlalchemy.orm import attributes, exc as orm_exc
 from testlib import *
 from testlib.tables import *
 from testlib import engines, tables, fixtures
@@ -14,7 +15,7 @@ from testlib import engines, tables, fixtures
 
 # TODO: convert suite to not use Session.mapper, use fixtures.Base
 # with explicit session.save()
-Session = scoped_session(sessionmaker(autoflush=True, transactional=True))
+Session = scoped_session(sessionmaker(autoflush=True, autocommit=False, autoexpire=False))
 orm_mapper = mapper
 mapper = Session.mapper
 
@@ -28,8 +29,10 @@ class HistoryTest(ORMTest):
 
     def test_backref(self):
         s = Session()
-        class User(object):pass
-        class Address(object):pass
+        class User(object):
+            def __init__(self, **kw): pass
+        class Address(object):
+            def __init__(self, _sa_session=None): pass
         am = mapper(Address, addresses)
         m = mapper(User, users, properties = dict(
             addresses = relation(am, backref='user', lazy=False))
@@ -59,7 +62,9 @@ class VersioningTest(ORMTest):
     @engines.close_open_connections
     def test_basic(self):
         s = Session(scope=None)
-        class Foo(object):pass
+        class Foo(object):
+            def __init__(self, value, _sa_session=None):
+                self.value = value
         mapper(Foo, version_table, version_id_col=version_table.c.version_id)
         f1 = Foo(value='f1', _sa_session=s)
         f2 = Foo(value='f2', _sa_session=s)
@@ -67,26 +72,22 @@ class VersioningTest(ORMTest):
 
         f1.value='f1rev2'
         s.commit()
+
         s2 = Session()
         f1_s = s2.query(Foo).get(f1.id)
         f1_s.value='f1rev3'
         s2.commit()
 
         f1.value='f1rev3mine'
-        success = False
-        try:
-            # a concurrent session has modified this, should throw
-            # an exception
-            s.commit()
-        except exceptions.ConcurrentModificationError, e:
-            #print e
-            success = True
 
         # Only dialects with a sane rowcount can detect the ConcurrentModificationError
         if testing.db.dialect.supports_sane_rowcount:
-            assert success
-
-        s.close()
+            self.assertRaises(orm_exc.ConcurrentModificationError, s.commit)
+            s.rollback()
+        else:
+            s.commit()
+        
+        # new in 0.5 !  dont need to close the session
         f1 = s.query(Foo).get(f1.id)
         f2 = s.query(Foo).get(f2.id)
 
@@ -95,33 +96,29 @@ class VersioningTest(ORMTest):
 
         s.delete(f1)
         s.delete(f2)
-        success = False
-        try:
-            s.commit()
-        except exceptions.ConcurrentModificationError, e:
-            #print e
-            success = True
+
         if testing.db.dialect.supports_sane_multi_rowcount:
-            assert success
+            self.assertRaises(orm_exc.ConcurrentModificationError, s.commit)
+        else:
+            s.commit()
 
     @engines.close_open_connections
     def test_versioncheck(self):
         """test that query.with_lockmode performs a 'version check' on an already loaded instance"""
         s1 = Session(scope=None)
-        class Foo(object):pass
+        class Foo(object):
+            def __init__(self, _sa_session=None): pass
         mapper(Foo, version_table, version_id_col=version_table.c.version_id)
-        f1s1 =Foo(value='f1', _sa_session=s1)
+        f1s1 = Foo(_sa_session=s1)
+        f1s1.value = 'f1 value'
         s1.commit()
         s2 = Session()
         f1s2 = s2.query(Foo).get(f1s1.id)
         f1s2.value='f1 new value'
         s2.commit()
-        try:
-            # load, version is wrong
-            s1.query(Foo).with_lockmode('read').get(f1s1.id)
-            assert False
-        except exceptions.ConcurrentModificationError, e:
-            assert True
+        # load, version is wrong
+        self.assertRaises(orm_exc.ConcurrentModificationError, s1.query(Foo).with_lockmode('read').get, f1s1.id)
+
         # reload it
         s1.query(Foo).load(f1s1.id)
         # now assert version OK
@@ -135,9 +132,11 @@ class VersioningTest(ORMTest):
     def test_noversioncheck(self):
         """test that query.with_lockmode works OK when the mapper has no version id col"""
         s1 = Session()
-        class Foo(object):pass
+        class Foo(object):
+            def __init__(self, _sa_session=None): pass
         mapper(Foo, version_table)
-        f1s1 =Foo(value='f1', _sa_session=s1)
+        f1s1 =Foo(_sa_session=s1)
+        f1s1.value = 'foo'
         f1s1.version_id=0
         s1.commit()
         s2 = Session()
@@ -271,9 +270,11 @@ class MutableTypesTest(ORMTest):
         Session.commit()
         Session.close()
         f2 = Session.query(Foo).filter_by(id=f1.id).one()
+        assert 'data' in attributes.instance_state(f2).unmodified
         assert f2.data == f1.data
         f2.data.y = 19
         assert f2 in Session.dirty
+        assert 'data' not in attributes.instance_state(f2).unmodified
         Session.commit()
         Session.close()
         f3 = Session.query(Foo).filter_by(id=f1.id).one()
@@ -439,8 +440,11 @@ class PKTest(ORMTest):
         e.multi_rev = 2
         Session.commit()
         Session.close()
-        e2 = Query(Entry).get((e.multi_id, 2))
-        self.assert_(e is not e2 and e._instance_key == e2._instance_key)
+        e2 = Session.query(Entry).get((e.multi_id, 2))
+        self.assert_(e is not e2)
+        state = attributes.instance_state(e)
+        state2 = attributes.instance_state(e2)
+        self.assert_(state.key == state2.key)
 
     # this one works with sqlite since we are manually setting up pk values
     def test_manualpk(self):
@@ -514,8 +518,7 @@ class ClauseAttributesTest(ORMTest):
             Column('counter', Integer, default=1))
 
     def test_update(self):
-        class User(object):
-            pass
+        class User(fixtures.Base): pass
         mapper(User, users_table)
         u = User(name='test')
         sess = Session()
@@ -530,8 +533,7 @@ class ClauseAttributesTest(ORMTest):
         self.assert_sql_count(testing.db, go, 1)
 
     def test_multi_update(self):
-        class User(object):
-            pass
+        class User(fixtures.Base): pass
         mapper(User, users_table)
         u = User(name='test')
         sess = Session()
@@ -553,8 +555,7 @@ class ClauseAttributesTest(ORMTest):
 
     @testing.unsupported('mssql')
     def test_insert(self):
-        class User(object):
-            pass
+        class User(fixtures.Base): pass
         mapper(User, users_table)
         u = User(name='test', counter=select([5]))
         sess = Session()
@@ -641,7 +642,7 @@ class ExtraPassiveDeletesTest(ORMTest):
                 'children':relation(MyOtherClass, passive_deletes='all', cascade="all")
             })
             assert False
-        except exceptions.ArgumentError, e:
+        except sa_exc.ArgumentError, e:
             assert str(e) == "Can't set passive_deletes='all' in conjunction with 'delete' or 'delete-orphan' cascade"
 
     @testing.unsupported('sqlite')
@@ -669,7 +670,7 @@ class ExtraPassiveDeletesTest(ORMTest):
         assert myothertable.count().scalar() == 4
         mc = sess.query(MyClass).get(mc.id)
         sess.delete(mc)
-        self.assertRaises(exceptions.DBAPIError, sess.commit)
+        self.assertRaises(sa_exc.DBAPIError, sess.commit)
 
     @testing.unsupported('sqlite')
     def test_extra_passive_2(self):
@@ -694,7 +695,7 @@ class ExtraPassiveDeletesTest(ORMTest):
         mc = sess.query(MyClass).get(mc.id)
         sess.delete(mc)
         mc.children[0].data = 'some new data'
-        self.assertRaises(exceptions.DBAPIError, sess.commit)
+        self.assertRaises(sa_exc.DBAPIError, sess.commit)
 
 
 class DefaultTest(ORMTest):
@@ -736,7 +737,7 @@ class DefaultTest(ORMTest):
             secondary_table.append_column(Column('hoho', hohotype, ForeignKey('default_test.hoho')))
 
     def test_insert(self):
-        class Hoho(object):pass
+        class Hoho(fixtures.Base): pass
         mapper(Hoho, default_table)
 
         h1 = Hoho(hoho=althohoval)
@@ -790,7 +791,7 @@ class DefaultTest(ORMTest):
 
     def test_insert_nopostfetch(self):
         # populates the PassiveDefaults explicitly so there is no "post-update"
-        class Hoho(object):pass
+        class Hoho(fixtures.Base): pass
         mapper(Hoho, default_table)
 
         h1 = Hoho(hoho="15", counter="15")
@@ -803,7 +804,7 @@ class DefaultTest(ORMTest):
         self.assert_sql_count(testing.db, go, 0)
 
     def test_update(self):
-        class Hoho(object):pass
+        class Hoho(fixtures.Base): pass
         mapper(Hoho, default_table)
         h1 = Hoho()
         Session.commit()
@@ -971,7 +972,7 @@ class OneToManyTest(ORMTest):
 
     def test_o2m_delete_parent(self):
         m = mapper(User, users, properties = dict(
-            address = relation(mapper(Address, addresses), lazy = True, uselist = False, private = False)
+            address = relation(mapper(Address, addresses), lazy=True, uselist=False)
         ))
         u = User()
         a = Address()
@@ -981,7 +982,10 @@ class OneToManyTest(ORMTest):
         Session.commit()
         Session.delete(u)
         Session.commit()
-        self.assert_(a.address_id is not None and a.user_id is None and u._instance_key not in Session.identity_map and a._instance_key in Session.identity_map)
+        self.assert_(a.address_id is not None)
+        self.assert_(a.user_id is None)
+        self.assert_(attributes.instance_state(a).key in Session.identity_map)
+        self.assert_(attributes.instance_state(u).key not in Session.identity_map)
 
     def test_onetoone(self):
         m = mapper(User, users, properties = dict(
@@ -2029,7 +2033,7 @@ class TransactionTest(ORMTest):
         orm_mapper(T2, t2)
 
     def test_close_transaction_on_commit_fail(self):
-        Session = sessionmaker(autoflush=False, transactional=False)
+        Session = sessionmaker(autoflush=False, autocommit=True)
         sess = Session()
 
         # with a deferred constraint, this fails at COMMIT time instead
diff --git a/test/orm/utils.py b/test/orm/utils.py
new file mode 100644 (file)
index 0000000..4bb2464
--- /dev/null
@@ -0,0 +1,208 @@
+import testenv; testenv.configure_for_tests()
+from sqlalchemy.orm import interfaces, util
+from testlib import *
+from testlib import fixtures
+from sqlalchemy import Column
+from sqlalchemy import Integer
+from sqlalchemy import MetaData
+from sqlalchemy import Table
+from sqlalchemy.orm import aliased
+from sqlalchemy.orm import mapper
+
+
+class ExtensionCarrierTest(TestBase):
+    def test_basic(self):
+        carrier = util.ExtensionCarrier()
+
+        assert 'translate_row' not in carrier.methods
+        assert carrier.translate_row() is interfaces.EXT_CONTINUE
+        assert 'translate_row' not in carrier.methods
+
+        self.assertRaises(AttributeError, lambda: carrier.snickysnack)
+
+        class Partial(object):
+            def __init__(self, marker):
+                self.marker = marker
+            def translate_row(self, row):
+                return self.marker
+
+        carrier.append(Partial('end'))
+        assert 'translate_row' in carrier.methods
+        assert carrier.translate_row(None) == 'end'
+
+        carrier.push(Partial('front'))
+        assert carrier.translate_row(None) == 'front'
+
+        assert 'populate_instance' not in carrier.methods
+        carrier.append(interfaces.MapperExtension)
+        assert 'populate_instance' in carrier.methods
+
+        assert carrier.interface
+        for m in carrier.interface:
+            assert getattr(interfaces.MapperExtension, m)
+
+class AliasedClassTest(TestBase):
+    def point_map(self, cls):
+        table = Table('point', MetaData(),
+                    Column('id', Integer(), primary_key=True),
+                    Column('x', Integer),
+                    Column('y', Integer))
+        mapper(cls, table)
+        return table
+
+    def test_simple(self):
+        class Point(object):
+            pass
+        table = self.point_map(Point)
+
+        alias = aliased(Point)
+
+        assert alias.id
+        assert alias.x
+        assert alias.y
+
+        assert Point.id.__clause_element__().table is table
+        assert alias.id.__clause_element__().table is not table
+
+    def test_notcallable(self):
+        class Point(object):
+            pass
+        table = self.point_map(Point)
+        alias = aliased(Point)
+
+        self.assertRaises(TypeError, alias)
+
+    def test_instancemethods(self):
+        class Point(object):
+            def zero(self):
+                self.x, self.y = 0, 0
+
+        table = self.point_map(Point)
+        alias = aliased(Point)
+
+        assert Point.zero
+        assert not getattr(alias, 'zero')
+
+    def test_classmethods(self):
+        class Point(object):
+            @classmethod
+            def max_x(cls):
+                return 100
+
+        table = self.point_map(Point)
+        alias = aliased(Point)
+
+        assert Point.max_x
+        assert alias.max_x
+        assert Point.max_x() == alias.max_x()
+
+    def test_simpleproperties(self):
+        class Point(object):
+            @property
+            def max_x(self):
+                return 100
+
+        table = self.point_map(Point)
+        alias = aliased(Point)
+
+        assert Point.max_x
+        assert Point.max_x != 100
+        assert alias.max_x
+        assert Point.max_x is alias.max_x
+
+    def test_descriptors(self):
+        class descriptor(object):
+            """Tortured..."""
+            def __init__(self, fn):
+                self.fn = fn
+            def __get__(self, obj, owner):
+                if obj is not None:
+                    return self.fn(obj, obj)
+                else:
+                    return self
+            def method(self):
+                return 'method'
+
+        class Point(object):
+            center = (0, 0)
+            @descriptor
+            def thing(self, arg):
+                return arg.center
+
+        table = self.point_map(Point)
+        alias = aliased(Point)
+
+        assert Point.thing != (0, 0)
+        assert Point().thing == (0, 0)
+        assert Point.thing.method() == 'method'
+
+        assert alias.thing != (0, 0)
+        assert alias.thing.method() == 'method'
+
+    def test_hybrid_descriptors(self):
+        from sqlalchemy import Column  # override testlib's override
+        import new
+
+        class MethodDescriptor(object):
+            def __init__(self, func):
+                self.func = func
+            def __get__(self, instance, owner):
+                if instance is None:
+                    args = (self.func, owner, owner.__class__)
+                else:
+                    args = (self.func, instance, owner)
+                return new.instancemethod(*args)
+
+        class PropertyDescriptor(object):
+            def __init__(self, fget, fset, fdel):
+                self.fget = fget
+                self.fset = fset
+                self.fdel = fdel
+            def __get__(self, instance, owner):
+                if instance is None:
+                    return self.fget(owner)
+                else:
+                    return self.fget(instance)
+            def __set__(self, instance, value):
+                self.fset(instance, value)
+            def __delete__(self, instance):
+                self.fdel(instance)
+        hybrid = MethodDescriptor
+        def hybrid_property(fget, fset=None, fdel=None):
+            return PropertyDescriptor(fget, fset, fdel)
+
+        def assert_table(expr, table):
+            for child in expr.get_children():
+                if isinstance(child, Column):
+                    assert child.table is table
+
+        class Point(object):
+            def __init__(self, x, y):
+                self.x, self.y = x, y
+            @hybrid
+            def left_of(self, other):
+                return self.x < other.x
+
+            double_x = hybrid_property(lambda self: self.x * 2)
+
+        table = self.point_map(Point)
+        alias = aliased(Point)
+        alias_table = alias.x.__clause_element__().table
+        assert table is not alias_table
+
+        p1 = Point(-10, -10)
+        p2 = Point(20, 20)
+
+        assert p1.left_of(p2)
+        assert p1.double_x == -20
+
+        assert_table(Point.double_x, table)
+        assert_table(alias.double_x, alias_table)
+
+        assert_table(Point.left_of(p2), table)
+        assert_table(alias.left_of(p2), alias_table)
+
+
+if __name__ == '__main__':
+    testenv.main()
+
index bc2834ff74736b59057e51f81d62779784754c45..a848b866cc8ef1fb376343bdf0275366484209ee 100644 (file)
@@ -37,6 +37,7 @@ def load():
 
 @profiling.profiled('masseagerload', always=True, sort=['cumulative'])
 def masseagerload(session):
+    session.begin()
     query = session.query(Item)
     l = query.all()
     print "loaded ", len(l), " items each with ", len(l[0].subs), "subitems"
index 4e1111aa2aa7074ed84720dca6eebaa075eea20e..cd0a29ee3d6afa1d8216773824046f21979e5476 100644 (file)
@@ -15,11 +15,11 @@ class CompileTest(TestBase, AssertsExecutionResults):
             Column('c1', Integer, primary_key=True),
             Column('c2', String(30)))
 
-    @profiling.function_call_count(74, {'2.3': 44, '2.4': 42})
+    @profiling.function_call_count(67, {'2.3': 44, '2.4': 42})
     def test_insert(self):
         t1.insert().compile()
 
-    @profiling.function_call_count(75, {'2.3': 47, '2.4': 42})
+    @profiling.function_call_count(68, {'2.3': 47, '2.4': 42})
     def test_update(self):
         t1.update().compile()
 
index 0994b5d4be1e0f6ff2e6e19af692a0f6be5d0728..cdf663a4e2c57fd6439407775bd55b4c32dc9179 100644 (file)
@@ -332,7 +332,7 @@ class ZooMarkTest(TestBase):
     def test_profile_2_insert(self):
         self.test_baseline_2_insert()
 
-    @profiling.function_call_count(4923, {'2.4': 2557})
+    @profiling.function_call_count(4662, {'2.4': 2557})
     def test_profile_3_properties(self):
         self.test_baseline_3_properties()
 
@@ -344,7 +344,7 @@ class ZooMarkTest(TestBase):
     def test_profile_5_aggregates(self):
         self.test_baseline_5_aggregates()
 
-    @profiling.function_call_count(1988, {'2.4': 1048})
+    @profiling.function_call_count(1882, {'2.4': 1048})
     def test_profile_6_editing(self):
         self.test_baseline_6_editing()
 
index 6aecefd3c33f740806e82d7ed905d5ed3a031bd4..876f820b5c51a8e675d04def11797b453a097a7c 100644 (file)
@@ -2,7 +2,7 @@ import testenv; testenv.configure_for_tests()
 import sys
 from sqlalchemy import *
 from testlib import *
-from sqlalchemy import util, exceptions
+from sqlalchemy import util, exc
 from sqlalchemy.sql import table, column
 
 
@@ -91,7 +91,7 @@ class CaseTest(TestBase, AssertsCompiledSQL):
     def test_literal_interpretation(self):
         t = table('test', column('col1'))
         
-        self.assertRaises(exceptions.ArgumentError, case, [("x", "y")])
+        self.assertRaises(exc.ArgumentError, case, [("x", "y")])
         
         self.assert_compile(case([("x", "y")], value=t.c.col1), "CASE test.col1 WHEN :param_1 THEN :param_2 END")
         self.assert_compile(case([(t.c.col1==7, "y")], else_="z"), "CASE WHEN (test.col1 = :col1_1) THEN :param_1 ELSE :param_2 END")
index 76bf9b389c6f41ef5b485f77ff036f64d823866e..661be891aee6ef0b54e4c8b45176bd1348ee1685 100644 (file)
@@ -1,6 +1,6 @@
 import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
-from sqlalchemy import exceptions, sql
+from sqlalchemy import exc, sql
 from testlib import *
 from sqlalchemy import Table, Column  # don't use testlib's wrappers
 
@@ -37,7 +37,7 @@ class ColumnDefinitionTest(TestBase):
     def test_incomplete(self):
         c = self.columns()
 
-        self.assertRaises(exceptions.ArgumentError, Table, 't', MetaData(), *c)
+        self.assertRaises(exc.ArgumentError, Table, 't', MetaData(), *c)
 
     def test_incomplete_key(self):
         c = Column(Integer)
@@ -52,8 +52,8 @@ class ColumnDefinitionTest(TestBase):
 
 
     def test_bogus(self):
-        self.assertRaises(exceptions.ArgumentError, Column, 'foo', name='bar')
-        self.assertRaises(exceptions.ArgumentError, Column, 'foo', Integer,
+        self.assertRaises(exc.ArgumentError, Column, 'foo', name='bar')
+        self.assertRaises(exc.ArgumentError, Column, 'foo', Integer,
                           type_=Integer())
 
 if __name__ == "__main__":
index 2908e07da929ef10da54ed59a20567c99cf120af..966930ca9738c06071e4bba4ddc286ae49cc805b 100644 (file)
@@ -1,6 +1,6 @@
 import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
-from sqlalchemy import exceptions
+from sqlalchemy import exc
 from testlib import *
 from testlib import config, engines
 
@@ -72,14 +72,14 @@ class ConstraintTest(TestBase, AssertsExecutionResults):
         try:
             foo.insert().execute(id=2,x=5,y=9)
             assert False
-        except exceptions.SQLError:
+        except exc.SQLError:
             assert True
 
         bar.insert().execute(id=1,x=10)
         try:
             bar.insert().execute(id=2,x=5)
             assert False
-        except exceptions.SQLError:
+        except exc.SQLError:
             assert True
 
     def test_unique_constraint(self):
@@ -100,12 +100,12 @@ class ConstraintTest(TestBase, AssertsExecutionResults):
         try:
             foo.insert().execute(id=3, value='value1')
             assert False
-        except exceptions.SQLError:
+        except exc.SQLError:
             assert True
         try:
             bar.insert().execute(id=3, value='a', value2='b')
             assert False
-        except exceptions.SQLError:
+        except exc.SQLError:
             assert True
 
     def test_index_create(self):
index 22660c0607b568ee2b9eae156884474fcd71e736..e9ed21a650827fc276eb7bfc2e02a8f25bde6042 100644 (file)
@@ -1,7 +1,7 @@
 import testenv; testenv.configure_for_tests()
 import datetime
 from sqlalchemy import *
-from sqlalchemy import exceptions, schema, util
+from sqlalchemy import exc, schema, util
 from sqlalchemy.orm import mapper, create_session
 from testlib import *
 
@@ -122,7 +122,7 @@ class DefaultTest(TestBase):
             try:
                 c = ColumnDefault(fn)
                 assert False, str(fn)
-            except exceptions.ArgumentError, e:
+            except exc.ArgumentError, e:
                 assert str(e) == ex_msg
 
     def test_argsignature(self):
@@ -327,7 +327,7 @@ class AutoIncrementTest(TestBase):
                 nonai_table.insert().execute(data='row 1')
                 nonai_table.insert().execute(data='row 2')
                 assert False
-            except exceptions.SQLError, e:
+            except exc.SQLError, e:
                 print "Got exception", str(e)
                 assert True
 
index d1ce17c72f556bd7267e5ece011b53b45810cd4f..82814ef1b9c157d7df025ddc48680572a24c87cb 100644 (file)
@@ -2,7 +2,7 @@ import testenv; testenv.configure_for_tests()
 import datetime
 from sqlalchemy import *
 from sqlalchemy.sql import table, column
-from sqlalchemy import databases, exceptions, sql, util
+from sqlalchemy import databases, sql, util
 from sqlalchemy.sql.compiler import BIND_TEMPLATES
 from sqlalchemy.engine import default
 from sqlalchemy import types as sqltypes
index 820474282157322e51646226b6c28fc90283707d..cf5ea8235213684b8d1761619bb64894de2494f2 100644 (file)
@@ -1,7 +1,7 @@
 import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
 from sqlalchemy.sql import table, column, ClauseElement
-from sqlalchemy.sql.expression import  _clone
+from sqlalchemy.sql.expression import  _clone, _from_objects
 from testlib import *
 from sqlalchemy.sql.visitors import *
 from sqlalchemy import util
@@ -82,14 +82,14 @@ class TraversalTest(TestBase, AssertsExecutionResults):
     def test_clone(self):
         struct = B(A("expr1"), A("expr2"), B(A("expr1b"), A("expr2b")), A("expr3"))
 
-        class Vis(ClauseVisitor):
+        class Vis(CloningVisitor):
             def visit_a(self, a):
                 pass
             def visit_b(self, b):
                 pass
 
         vis = Vis()
-        s2 = vis.traverse(struct, clone=True)
+        s2 = vis.traverse(struct)
         assert struct == s2
         assert not struct.is_other(s2)
     
@@ -103,7 +103,7 @@ class TraversalTest(TestBase, AssertsExecutionResults):
                 pass
 
         vis = Vis()
-        s2 = vis.traverse(struct, clone=False)
+        s2 = vis.traverse(struct)
         assert struct == s2
         assert struct.is_other(s2)
 
@@ -112,7 +112,7 @@ class TraversalTest(TestBase, AssertsExecutionResults):
         struct2 = B(A("expr1"), A("expr2modified"), B(A("expr1b"), A("expr2b")), A("expr3"))
         struct3 = B(A("expr1"), A("expr2"), B(A("expr1b"), A("expr2bmodified")), A("expr3"))
 
-        class Vis(ClauseVisitor):
+        class Vis(CloningVisitor):
             def visit_a(self, a):
                 if a.expr == "expr2":
                     a.expr = "expr2modified"
@@ -120,12 +120,12 @@ class TraversalTest(TestBase, AssertsExecutionResults):
                 pass
 
         vis = Vis()
-        s2 = vis.traverse(struct, clone=True)
+        s2 = vis.traverse(struct)
         assert struct != s2
         assert not struct.is_other(s2)
         assert struct2 == s2
 
-        class Vis2(ClauseVisitor):
+        class Vis2(CloningVisitor):
             def visit_a(self, a):
                 if a.expr == "expr2b":
                     a.expr = "expr2bmodified"
@@ -133,7 +133,7 @@ class TraversalTest(TestBase, AssertsExecutionResults):
                 pass
 
         vis2 = Vis2()
-        s3 = vis2.traverse(struct, clone=True)
+        s3 = vis2.traverse(struct)
         assert struct != s3
         assert struct3 == s3
 
@@ -156,7 +156,7 @@ class ClauseTest(TestBase, AssertsCompiledSQL):
 
     def test_binary(self):
         clause = t1.c.col2 == t2.c.col2
-        assert str(clause) == ClauseVisitor().traverse(clause, clone=True)
+        assert str(clause) == CloningVisitor().traverse(clause)
 
     def test_binary_anon_label_quirk(self):
         t = table('t1', column('col1'))
@@ -175,25 +175,25 @@ class ClauseTest(TestBase, AssertsCompiledSQL):
     def test_join(self):
         clause = t1.join(t2, t1.c.col2==t2.c.col2)
         c1 = str(clause)
-        assert str(clause) == str(ClauseVisitor().traverse(clause, clone=True))
+        assert str(clause) == str(CloningVisitor().traverse(clause))
 
-        class Vis(ClauseVisitor):
+        class Vis(CloningVisitor):
             def visit_binary(self, binary):
                 binary.right = t2.c.col3
 
-        clause2 = Vis().traverse(clause, clone=True)
+        clause2 = Vis().traverse(clause)
         assert c1 == str(clause)
         assert str(clause2) == str(t1.join(t2, t1.c.col2==t2.c.col3))
     
     def test_text(self):
         clause = text("select * from table where foo=:bar", bindparams=[bindparam('bar')])
         c1 = str(clause)
-        class Vis(ClauseVisitor):
+        class Vis(CloningVisitor):
             def visit_textclause(self, text):
                 text.text = text.text + " SOME MODIFIER=:lala"
                 text.bindparams['lala'] = bindparam('lala')
 
-        clause2 = Vis().traverse(clause, clone=True)
+        clause2 = Vis().traverse(clause)
         assert c1 == str(clause)
         assert str(clause2) == c1 + " SOME MODIFIER=:lala"
         assert clause.bindparams.keys() == ['bar']
@@ -203,24 +203,27 @@ class ClauseTest(TestBase, AssertsCompiledSQL):
         s2 = select([t1])
         s2_assert = str(s2)
         s3_assert = str(select([t1], t1.c.col2==7))
-        class Vis(ClauseVisitor):
+        class Vis(CloningVisitor):
             def visit_select(self, select):
                 select.append_whereclause(t1.c.col2==7)
-        s3 = Vis().traverse(s2, clone=True)
+        s3 = Vis().traverse(s2)
         assert str(s3) == s3_assert
         assert str(s2) == s2_assert
         print str(s2)
         print str(s3)
+        class Vis(ClauseVisitor):
+            def visit_select(self, select):
+                select.append_whereclause(t1.c.col2==7)
         Vis().traverse(s2)
         assert str(s2) == s3_assert
 
         print "------------------"
 
         s4_assert = str(select([t1], and_(t1.c.col2==7, t1.c.col3==9)))
-        class Vis(ClauseVisitor):
+        class Vis(CloningVisitor):
             def visit_select(self, select):
                 select.append_whereclause(t1.c.col3==9)
-        s4 = Vis().traverse(s3, clone=True)
+        s4 = Vis().traverse(s3)
         print str(s3)
         print str(s4)
         assert str(s4) == s4_assert
@@ -228,12 +231,12 @@ class ClauseTest(TestBase, AssertsCompiledSQL):
 
         print "------------------"
         s5_assert = str(select([t1], and_(t1.c.col2==7, t1.c.col1==9)))
-        class Vis(ClauseVisitor):
+        class Vis(CloningVisitor):
             def visit_binary(self, binary):
                 if binary.left is t1.c.col3:
                     binary.left = t1.c.col1
                     binary.right = bindparam("col1", unique=True)
-        s5 = Vis().traverse(s4, clone=True)
+        s5 = Vis().traverse(s4)
         print str(s4)
         print str(s5)
         assert str(s5) == s5_assert
@@ -241,13 +244,13 @@ class ClauseTest(TestBase, AssertsCompiledSQL):
     
     def test_union(self):
         u = union(t1.select(), t2.select())
-        u2 = ClauseVisitor().traverse(u, clone=True)
+        u2 = CloningVisitor().traverse(u)
         assert str(u) == str(u2)
         assert [str(c) for c in u2.c] == [str(c) for c in u.c]
 
         u = union(t1.select(), t2.select())
         cols = [str(c) for c in u.c]
-        u2 = ClauseVisitor().traverse(u, clone=True)
+        u2 = CloningVisitor().traverse(u)
         assert str(u) == str(u2)
         assert [str(c) for c in u2.c] == cols
         
@@ -265,7 +268,7 @@ class ClauseTest(TestBase, AssertsCompiledSQL):
         """test that unique bindparams change their name upon clone() to prevent conflicts"""
 
         s = select([t1], t1.c.col1==bindparam(None, unique=True)).alias()
-        s2 = ClauseVisitor().traverse(s, clone=True).alias()
+        s2 = CloningVisitor().traverse(s).alias()
         s3 = select([s], s.c.col2==s2.c.col2)
 
         self.assert_compile(s3, "SELECT anon_1.col1, anon_1.col2, anon_1.col3 FROM (SELECT table1.col1 AS col1, table1.col2 AS col2, "\
@@ -274,7 +277,7 @@ class ClauseTest(TestBase, AssertsCompiledSQL):
         "WHERE anon_1.col2 = anon_2.col2")
 
         s = select([t1], t1.c.col1==4).alias()
-        s2 = ClauseVisitor().traverse(s, clone=True).alias()
+        s2 = CloningVisitor().traverse(s).alias()
         s3 = select([s], s.c.col2==s2.c.col2)
         self.assert_compile(s3, "SELECT anon_1.col1, anon_1.col2, anon_1.col3 FROM (SELECT table1.col1 AS col1, table1.col2 AS col2, "\
         "table1.col3 AS col3 FROM table1 WHERE table1.col1 = :col1_1) AS anon_1, "\
@@ -286,26 +289,51 @@ class ClauseTest(TestBase, AssertsCompiledSQL):
         subq = t2.select().alias('subq')
         s = select([t1.c.col1, subq.c.col1], from_obj=[t1, subq, t1.join(subq, t1.c.col1==subq.c.col2)])
         orig = str(s)
-        s2 = ClauseVisitor().traverse(s, clone=True)
+        s2 = CloningVisitor().traverse(s)
         assert orig == str(s) == str(s2)
 
-        s4 = ClauseVisitor().traverse(s2, clone=True)
+        s4 = CloningVisitor().traverse(s2)
         assert orig == str(s) == str(s2) == str(s4)
 
-        s3 = sql_util.ClauseAdapter(table('foo')).traverse(s, clone=True)
+        s3 = sql_util.ClauseAdapter(table('foo')).traverse(s)
         assert orig == str(s) == str(s3)
 
-        s4 = sql_util.ClauseAdapter(table('foo')).traverse(s3, clone=True)
+        s4 = sql_util.ClauseAdapter(table('foo')).traverse(s3)
         assert orig == str(s) == str(s3) == str(s4)
 
     def test_correlated_select(self):
         s = select(['*'], t1.c.col1==t2.c.col1, from_obj=[t1, t2]).correlate(t2)
-        class Vis(ClauseVisitor):
+        class Vis(CloningVisitor):
             def visit_select(self, select):
                 select.append_whereclause(t1.c.col2==7)
 
-        self.assert_compile(Vis().traverse(s, clone=True), "SELECT * FROM table1 WHERE table1.col1 = table2.col1 AND table1.col2 = :col2_1")
-
+        self.assert_compile(Vis().traverse(s), "SELECT * FROM table1 WHERE table1.col1 = table2.col1 AND table1.col2 = :col2_1")
+    
+    def test_this_thing(self):
+        s = select([t1]).where(t1.c.col1=='foo').alias()
+        s2 = select([s.c.col1])
+        
+        self.assert_compile(s2, "SELECT anon_1.col1 FROM (SELECT table1.col1 AS col1, table1.col2 AS col2, table1.col3 AS col3 FROM table1 WHERE table1.col1 = :col1_1) AS anon_1")
+        t1a = t1.alias()
+        s2 = sql_util.ClauseAdapter(t1a).traverse(s2)
+        self.assert_compile(s2, "SELECT anon_1.col1 FROM (SELECT table1_1.col1 AS col1, table1_1.col2 AS col2, table1_1.col3 AS col3 FROM table1 AS table1_1 WHERE table1_1.col1 = :col1_1) AS anon_1")
+        
+    def test_select_fromtwice(self):
+        t1a = t1.alias()
+        
+        s = select([1], t1.c.col1==t1a.c.col1, from_obj=t1a).correlate(t1)
+        self.assert_compile(s, "SELECT 1 FROM table1 AS table1_1 WHERE table1.col1 = table1_1.col1")
+        
+        s = CloningVisitor().traverse(s)
+        self.assert_compile(s, "SELECT 1 FROM table1 AS table1_1 WHERE table1.col1 = table1_1.col1")
+        
+        s = select([t1]).where(t1.c.col1=='foo').alias()
+        
+        s2 = select([1], t1.c.col1==s.c.col1, from_obj=s).correlate(t1)
+        self.assert_compile(s2, "SELECT 1 FROM (SELECT table1.col1 AS col1, table1.col2 AS col2, table1.col3 AS col3 FROM table1 WHERE table1.col1 = :col1_1) AS anon_1 WHERE table1.col1 = anon_1.col1")
+        s2 = ReplacingCloningVisitor().traverse(s2)
+        self.assert_compile(s2, "SELECT 1 FROM (SELECT table1.col1 AS col1, table1.col2 AS col2, table1.col3 AS col3 FROM table1 WHERE table1.col1 = :col1_1) AS anon_1 WHERE table1.col1 = anon_1.col1")
+        
 class ClauseAdapterTest(TestBase, AssertsCompiledSQL):
     def setUpAll(self):
         global t1, t2
@@ -330,69 +358,88 @@ class ClauseAdapterTest(TestBase, AssertsCompiledSQL):
         assert t1alias in s._froms
 
         self.assert_compile(select(['*'], t2alias.c.col1==s), "SELECT * FROM table2 AS t2alias WHERE t2alias.col1 = (SELECT * FROM table1 AS t1alias)")
-        s = vis.traverse(s, clone=True)
+        s = vis.traverse(s)
+
         assert t2alias not in s._froms  # not present because it's been cloned
+
         assert t1alias in s._froms # present because the adapter placed it there
+
         # correlate list on "s" needs to take into account the full _cloned_set for each element in _froms when correlating
         self.assert_compile(select(['*'], t2alias.c.col1==s), "SELECT * FROM table2 AS t2alias WHERE t2alias.col1 = (SELECT * FROM table1 AS t1alias)")
 
         s = select(['*'], from_obj=[t1alias, t2alias]).correlate(t2alias).as_scalar()
         self.assert_compile(select(['*'], t2alias.c.col1==s), "SELECT * FROM table2 AS t2alias WHERE t2alias.col1 = (SELECT * FROM table1 AS t1alias)")
-        s = vis.traverse(s, clone=True)
+        s = vis.traverse(s)
         self.assert_compile(select(['*'], t2alias.c.col1==s), "SELECT * FROM table2 AS t2alias WHERE t2alias.col1 = (SELECT * FROM table1 AS t1alias)")
-        s = ClauseVisitor().traverse(s, clone=True)
+        s = CloningVisitor().traverse(s)
         self.assert_compile(select(['*'], t2alias.c.col1==s), "SELECT * FROM table2 AS t2alias WHERE t2alias.col1 = (SELECT * FROM table1 AS t1alias)")
         
         s = select(['*']).where(t1.c.col1==t2.c.col1).as_scalar()
         self.assert_compile(select([t1.c.col1, s]), "SELECT table1.col1, (SELECT * FROM table2 WHERE table1.col1 = table2.col1) AS anon_1 FROM table1")
         vis = sql_util.ClauseAdapter(t1alias)
-        s = vis.traverse(s, clone=True)
+        s = vis.traverse(s)
         self.assert_compile(select([t1alias.c.col1, s]), "SELECT t1alias.col1, (SELECT * FROM table2 WHERE t1alias.col1 = table2.col1) AS anon_1 FROM table1 AS t1alias")
-        s = ClauseVisitor().traverse(s, clone=True)
+        s = CloningVisitor().traverse(s)
         self.assert_compile(select([t1alias.c.col1, s]), "SELECT t1alias.col1, (SELECT * FROM table2 WHERE t1alias.col1 = table2.col1) AS anon_1 FROM table1 AS t1alias")
 
         s = select(['*']).where(t1.c.col1==t2.c.col1).correlate(t1).as_scalar()
         self.assert_compile(select([t1.c.col1, s]), "SELECT table1.col1, (SELECT * FROM table2 WHERE table1.col1 = table2.col1) AS anon_1 FROM table1")
         vis = sql_util.ClauseAdapter(t1alias)
-        s = vis.traverse(s, clone=True)
+        s = vis.traverse(s)
         self.assert_compile(select([t1alias.c.col1, s]), "SELECT t1alias.col1, (SELECT * FROM table2 WHERE t1alias.col1 = table2.col1) AS anon_1 FROM table1 AS t1alias")
-        s = ClauseVisitor().traverse(s, clone=True)
+        s = CloningVisitor().traverse(s)
         self.assert_compile(select([t1alias.c.col1, s]), "SELECT t1alias.col1, (SELECT * FROM table2 WHERE t1alias.col1 = table2.col1) AS anon_1 FROM table1 AS t1alias")
-        
+
+    @testing.fails_on_everything_except()
+    def test_joins_dont_adapt(self):
+        # adapting to a join, i.e. ClauseAdapter(t1.join(t2)), doesn't make much sense.
+        # ClauseAdapter doesn't make any changes if it's against a straight join.
+        users = table('users', column('id'))
+        addresses = table('addresses', column('id'), column('user_id'))
+
+        ualias = users.alias()
+
+        s = select([func.count(addresses.c.id)], users.c.id==addresses.c.user_id).correlate(users) #.as_scalar().label(None)
+        s= sql_util.ClauseAdapter(ualias).traverse(s)
+
+        j1 = addresses.join(ualias, addresses.c.user_id==ualias.c.id)
+
+        self.assert_compile(sql_util.ClauseAdapter(j1).traverse(s), "SELECT count(addresses.id) AS count_1 FROM addresses WHERE users_1.id = addresses.user_id")
         
     def test_table_to_alias(self):
 
         t1alias = t1.alias('t1alias')
 
         vis = sql_util.ClauseAdapter(t1alias)
-        ff = vis.traverse(func.count(t1.c.col1).label('foo'), clone=True)
-        assert ff._get_from_objects() == [t1alias]
+        ff = vis.traverse(func.count(t1.c.col1).label('foo'))
+        assert list(_from_objects(ff)) == [t1alias]
 
-        self.assert_compile(vis.traverse(select(['*'], from_obj=[t1]), clone=True), "SELECT * FROM table1 AS t1alias")
-        self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2), clone=True), "SELECT * FROM table1 AS t1alias, table2 WHERE t1alias.col1 = table2.col2")
-        self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]), clone=True), "SELECT * FROM table1 AS t1alias, table2 WHERE t1alias.col1 = table2.col2")
-        self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t1), clone=True), "SELECT * FROM table2 WHERE t1alias.col1 = table2.col2")
-        self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t2), clone=True), "SELECT * FROM table1 AS t1alias WHERE t1alias.col1 = table2.col2")
+        self.assert_compile(vis.traverse(select(['*'], from_obj=[t1])), "SELECT * FROM table1 AS t1alias")
+        self.assert_compile(select(['*'], t1.c.col1==t2.c.col2), "SELECT * FROM table1, table2 WHERE table1.col1 = table2.col2")
+        self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2)), "SELECT * FROM table1 AS t1alias, table2 WHERE t1alias.col1 = table2.col2")
+        self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2])), "SELECT * FROM table1 AS t1alias, table2 WHERE t1alias.col1 = table2.col2")
+        self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t1)), "SELECT * FROM table2 WHERE t1alias.col1 = table2.col2")
+        self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t2)), "SELECT * FROM table1 AS t1alias WHERE t1alias.col1 = table2.col2")
 
 
         s = select(['*'], from_obj=[t1]).alias('foo')
         self.assert_compile(s.select(), "SELECT foo.* FROM (SELECT * FROM table1) AS foo")
-        self.assert_compile(vis.traverse(s.select(), clone=True), "SELECT foo.* FROM (SELECT * FROM table1 AS t1alias) AS foo")
+        self.assert_compile(vis.traverse(s.select()), "SELECT foo.* FROM (SELECT * FROM table1 AS t1alias) AS foo")
         self.assert_compile(s.select(), "SELECT foo.* FROM (SELECT * FROM table1) AS foo")
 
-        ff = vis.traverse(func.count(t1.c.col1).label('foo'), clone=True)
-        self.assert_compile(ff, "count(t1alias.col1) AS foo")
-        assert ff._get_from_objects() == [t1alias]
+        ff = vis.traverse(func.count(t1.c.col1).label('foo'))
+        self.assert_compile(select([ff]), "SELECT count(t1alias.col1) AS foo FROM table1 AS t1alias")
+        assert list(_from_objects(ff)) == [t1alias]
 
 # TODO:
     #    self.assert_compile(vis.traverse(select([func.count(t1.c.col1).label('foo')]), clone=True), "SELECT count(t1alias.col1) AS foo FROM table1 AS t1alias")
 
         t2alias = t2.alias('t2alias')
         vis.chain(sql_util.ClauseAdapter(t2alias))
-        self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2), clone=True), "SELECT * FROM table1 AS t1alias, table2 AS t2alias WHERE t1alias.col1 = t2alias.col2")
-        self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]), clone=True), "SELECT * FROM table1 AS t1alias, table2 AS t2alias WHERE t1alias.col1 = t2alias.col2")
-        self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t1), clone=True), "SELECT * FROM table2 AS t2alias WHERE t1alias.col1 = t2alias.col2")
-        self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t2), clone=True), "SELECT * FROM table1 AS t1alias WHERE t1alias.col1 = t2alias.col2")
+        self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2)), "SELECT * FROM table1 AS t1alias, table2 AS t2alias WHERE t1alias.col1 = t2alias.col2")
+        self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2])), "SELECT * FROM table1 AS t1alias, table2 AS t2alias WHERE t1alias.col1 = t2alias.col2")
+        self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t1)), "SELECT * FROM table2 AS t2alias WHERE t1alias.col1 = t2alias.col2")
+        self.assert_compile(vis.traverse(select(['*'], t1.c.col1==t2.c.col2, from_obj=[t1, t2]).correlate(t2)), "SELECT * FROM table1 AS t1alias WHERE t1alias.col1 = t2alias.col2")
 
     def test_include_exclude(self):
         m = MetaData()
@@ -517,6 +564,65 @@ class ClauseAdapterTest(TestBase, AssertsCompiledSQL):
             "WHERE c.bid = anon_1.b_aid"
         )
 
+class SpliceJoinsTest(TestBase, AssertsCompiledSQL):
+    def setUpAll(self):
+        global table1, table2, table3, table4
+        def _table(name):
+            return table(name, column("col1"), column("col2"),column("col3"))
+        
+        table1, table2, table3, table4 = [_table(name) for name in ("table1", "table2", "table3", "table4")]    
+
+    def test_splice(self):
+        (t1, t2, t3, t4) = (table1, table2, table1.alias(), table2.alias())
+        
+        j = t1.join(t2, t1.c.col1==t2.c.col1).join(t3, t2.c.col1==t3.c.col1).join(t4, t4.c.col1==t1.c.col1)
+        
+        s = select([t1]).where(t1.c.col2<5).alias()
+        
+        self.assert_compile(sql_util.splice_joins(s, j), 
+            "(SELECT table1.col1 AS col1, table1.col2 AS col2, "\
+            "table1.col3 AS col3 FROM table1 WHERE table1.col2 < :col2_1) AS anon_1 "\
+            "JOIN table2 ON anon_1.col1 = table2.col1 JOIN table1 AS table1_1 ON table2.col1 = table1_1.col1 "\
+            "JOIN table2 AS table2_1 ON table2_1.col1 = anon_1.col1")
+
+    def test_stop_on(self):
+        (t1, t2, t3) = (table1, table2, table3)
+        
+        j1= t1.join(t2, t1.c.col1==t2.c.col1)
+        j2 = j1.join(t3, t2.c.col1==t3.c.col1)
+        
+        s = select([t1]).select_from(j1).alias()
+        
+        self.assert_compile(sql_util.splice_joins(s, j2), 
+            "(SELECT table1.col1 AS col1, table1.col2 AS col2, table1.col3 AS col3 FROM table1 JOIN table2 "\
+            "ON table1.col1 = table2.col1) AS anon_1 JOIN table2 ON anon_1.col1 = table2.col1 JOIN table3 "\
+            "ON table2.col1 = table3.col1"
+        )
+
+        self.assert_compile(sql_util.splice_joins(s, j2, j1), 
+            "(SELECT table1.col1 AS col1, table1.col2 AS col2, table1.col3 AS col3 FROM table1 "\
+            "JOIN table2 ON table1.col1 = table2.col1) AS anon_1 JOIN table3 ON table2.col1 = table3.col1")
+    
+    def test_splice_2(self):
+        t2a = table2.alias()
+        t3a = table3.alias()
+        j1 = table1.join(t2a, table1.c.col1==t2a.c.col1).join(t3a, t2a.c.col2==t3a.c.col2)
+        
+        t2b = table4.alias()
+        j2 = table1.join(t2b, table1.c.col3==t2b.c.col3)
+        
+        self.assert_compile(sql_util.splice_joins(table1, j1), 
+            "table1 JOIN table2 AS table2_1 ON table1.col1 = table2_1.col1 "\
+            "JOIN table3 AS table3_1 ON table2_1.col2 = table3_1.col2")
+            
+        self.assert_compile(sql_util.splice_joins(table1, j2), "table1 JOIN table4 AS table4_1 ON table1.col3 = table4_1.col3")
+
+        self.assert_compile(sql_util.splice_joins(sql_util.splice_joins(table1, j1), j2), 
+            "table1 JOIN table2 AS table2_1 ON table1.col1 = table2_1.col1 "\
+            "JOIN table3 AS table3_1 ON table2_1.col2 = table3_1.col2 "\
+            "JOIN table4 AS table4_1 ON table1.col3 = table4_1.col3")
+    
+        
 class SelectTest(TestBase, AssertsCompiledSQL):
     """tests the generative capability of Select"""
 
index e6d6714c2cf960f99ad5579a1b56498dbfd2435b..a305a5314695a9f6f450802e9daf7b5ac9920f10 100644 (file)
@@ -1,7 +1,7 @@
 import testenv; testenv.configure_for_tests()
 import datetime
 from sqlalchemy import *
-from sqlalchemy import exceptions, sql
+from sqlalchemy import exc, sql
 from sqlalchemy.engine import default
 from testlib import *
 
@@ -426,7 +426,7 @@ class QueryTest(TestBase):
         try:
             print r['user_id']
             assert False
-        except exceptions.InvalidRequestError, e:
+        except exc.InvalidRequestError, e:
             assert str(e) == "Ambiguous column name 'user_id' in result set! try 'use_labels' option on select statement." or \
                    str(e) == "Ambiguous column name 'USER_ID' in result set! try 'use_labels' option on select statement."
 
@@ -466,7 +466,7 @@ class QueryTest(TestBase):
     def test_cant_execute_join(self):
         try:
             users.join(addresses).execute()
-        except exceptions.ArgumentError, e:
+        except exc.ArgumentError, e:
             assert str(e).startswith('Not an executable clause: ')
 
 
index 825e836ff86e9440a36f606e71274b596f004b3d..d137b44a3ad1d17cdb7b015c6efa6e98e2ce94da 100644 (file)
@@ -4,7 +4,7 @@ from sqlalchemy import sql
 from sqlalchemy.sql import compiler
 from testlib import *
 
-class QuoteTest(TestBase):
+class QuoteTest(TestBase, AssertsCompiledSQL):
     def setUpAll(self):
         # TODO: figure out which databases/which identifiers allow special characters to be used,
         # such as:  spaces, quote characters, punctuation characters, set up tests for those as
@@ -67,7 +67,23 @@ class QuoteTest(TestBase):
         res2 = select([table2.c.d123, table2.c.u123, table2.c.MixedCase], use_labels=True).execute().fetchall()
         print res2
         assert(res2==[(1,2,3),(2,2,3),(4,3,2)])
+    
+    def test_quote_flag(self):
+        metadata = MetaData()
+        t1 = Table('TableOne', metadata, 
+            Column('ColumnOne', Integer), schema="FooBar")
+        self.assert_compile(t1.select(), '''SELECT "FooBar"."TableOne"."ColumnOne" FROM "FooBar"."TableOne"''')
+
+        metadata = MetaData()
+        t1 = Table('t1', metadata, 
+            Column('col1', Integer, quote=True), quote=True, schema="foo", quote_schema=True)
+        self.assert_compile(t1.select(), '''SELECT "foo"."t1"."col1" FROM "foo"."t1"''')
 
+        metadata = MetaData()
+        t1 = Table('TableOne', metadata, 
+            Column('ColumnOne', Integer, quote=False), quote=False, schema="FooBar", quote_schema=False)
+        self.assert_compile(t1.select(), '''SELECT FooBar.TableOne.ColumnOne FROM FooBar.TableOne''')
+        
     @testing.unsupported('oracle')
     def testlabels(self):
         """test the quoting of labels.
@@ -86,16 +102,19 @@ class QuoteTest(TestBase):
         table = Table("ImATable", metadata,
             Column("col1", Integer))
         x = select([table.c.col1.label("ImATable_col1")]).alias("SomeAlias")
-        assert str(select([x.c.ImATable_col1])) == '''SELECT "SomeAlias"."ImATable_col1" \nFROM (SELECT "ImATable".col1 AS "ImATable_col1" \nFROM "ImATable") AS "SomeAlias"'''
+        self.assert_compile(select([x.c.ImATable_col1]),
+            '''SELECT "SomeAlias"."ImATable_col1" FROM (SELECT "ImATable".col1 AS "ImATable_col1" FROM "ImATable") AS "SomeAlias"''')
 
         # note that 'foo' and 'FooCol' are literals already quoted
         x = select([sql.literal_column("'foo'").label("somelabel")], from_obj=[table]).alias("AnAlias")
         x = x.select()
-        assert str(x) == '''SELECT "AnAlias".somelabel \nFROM (SELECT 'foo' AS somelabel \nFROM "ImATable") AS "AnAlias"'''
+        self.assert_compile(x, 
+            '''SELECT "AnAlias".somelabel FROM (SELECT 'foo' AS somelabel FROM "ImATable") AS "AnAlias"''')
 
         x = select([sql.literal_column("'FooCol'").label("SomeLabel")], from_obj=[table])
         x = x.select()
-        assert str(x) == '''SELECT "SomeLabel" \nFROM (SELECT 'FooCol' AS "SomeLabel" \nFROM "ImATable")'''
+        self.assert_compile(x, 
+            '''SELECT "SomeLabel" FROM (SELECT 'FooCol' AS "SomeLabel" FROM "ImATable")''')
 
 
 class PreparerTest(TestBase):
index bea8621121be7f72272399fb71c80f4aca81e65c..3ecf63d3498be29fdfdb6ad52746af1b4fcd0b1a 100644 (file)
@@ -1,7 +1,7 @@
 import testenv; testenv.configure_for_tests()
 import datetime, re, operator
 from sqlalchemy import *
-from sqlalchemy import exceptions, sql, util
+from sqlalchemy import exc, sql, util
 from sqlalchemy.sql import table, column, compiler
 from sqlalchemy.databases import sqlite, postgres, mysql, oracle, firebird, mssql
 from testlib import *
@@ -154,7 +154,7 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
         t2 = table('t2', column('c'), column('d'))
         s = select([t.c.a]).where(t.c.a==t2.c.d).as_scalar()
         s2 =select([t, t2, s])
-        self.assertRaises(exceptions.InvalidRequestError, str, s2)
+        self.assertRaises(exc.InvalidRequestError, str, s2)
 
         # intentional again
         s = s.correlate(t, t2)
@@ -245,14 +245,14 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
         try:
             s = select([table1.c.myid, table1.c.name]).as_scalar()
             assert False
-        except exceptions.InvalidRequestError, err:
+        except exc.InvalidRequestError, err:
             assert str(err) == "Scalar select can only be created from a Select object that has exactly one column expression.", str(err)
 
         try:
             # generic function which will look at the type of expression
             func.coalesce(select([table1.c.myid]))
             assert False
-        except exceptions.InvalidRequestError, err:
+        except exc.InvalidRequestError, err:
             assert str(err) == "Select objects don't have a type.  Call as_scalar() on this Select object to return a 'scalar' version of this Select.", str(err)
 
         s = select([table1.c.myid], scalar=True, correlate=False)
@@ -278,12 +278,12 @@ sq.myothertable_othername AS sq_myothertable_othername FROM (" + sqstring + ") A
         s = select([table1.c.myid]).as_scalar()
         try:
             s.c.foo
-        except exceptions.InvalidRequestError, err:
+        except exc.InvalidRequestError, err:
             assert str(err) == 'Scalar Select expression has no columns; use this object directly within a column-level expression.'
 
         try:
             s.columns.foo
-        except exceptions.InvalidRequestError, err:
+        except exc.InvalidRequestError, err:
             assert str(err) == 'Scalar Select expression has no columns; use this object directly within a column-level expression.'
 
         zips = table('zips',
@@ -807,8 +807,8 @@ mytable.description FROM myothertable JOIN mytable ON mytable.myid = myothertabl
 
         self.assert_compile(
             select(
-                [join(join(table1, table2, table1.c.myid == table2.c.otherid), table3, table1.c.myid == table3.c.userid)
-            ]),
+                [join(join(table1, table2, table1.c.myid == table2.c.otherid), table3, table1.c.myid == table3.c.userid)]
+            ),
             "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername, thirdtable.userid, thirdtable.otherstuff FROM mytable JOIN myothertable ON mytable.myid = myothertable.otherid JOIN thirdtable ON mytable.myid = thirdtable.userid"
         )
 
@@ -854,7 +854,7 @@ EXISTS (select yay from foo where boo = lar)",
     def test_compound_selects(self):
         try:
             union(table3.select(), table1.select())
-        except exceptions.ArgumentError, err:
+        except exc.ArgumentError, err:
             assert str(err) == "All selectables passed to CompoundSelect must have identical numbers of columns; select #1 has 2 columns, select #2 has 3"
     
         x = union(
@@ -1048,10 +1048,10 @@ UNION SELECT mytable.myid FROM mytable"
 
         # check that conflicts with "unique" params are caught
         s = select([table1], or_(table1.c.myid==7, table1.c.myid==bindparam('myid_1')))
-        self.assertRaisesMessage(exceptions.CompileError, "conflicts with unique bind parameter of the same name", str, s)
+        self.assertRaisesMessage(exc.CompileError, "conflicts with unique bind parameter of the same name", str, s)
 
         s = select([table1], or_(table1.c.myid==7, table1.c.myid==8, table1.c.myid==bindparam('myid_1')))
-        self.assertRaisesMessage(exceptions.CompileError, "conflicts with unique bind parameter of the same name", str, s)
+        self.assertRaisesMessage(exc.CompileError, "conflicts with unique bind parameter of the same name", str, s)
 
 
 
@@ -1153,20 +1153,6 @@ UNION SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE
         self.assert_compile(select([table1], table1.c.myid.in_([])),
         "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE (CASE WHEN (mytable.myid IS NULL) THEN NULL ELSE 0 END = 1)")
 
-    @testing.uses_deprecated('passing in_')
-    def test_in_deprecated_api(self):
-        self.assert_compile(select([table1], table1.c.myid.in_('abc')),
-        "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:myid_1)")
-
-        self.assert_compile(select([table1], table1.c.myid.in_(1)),
-        "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:myid_1)")
-
-        self.assert_compile(select([table1], table1.c.myid.in_(1,2)),
-        "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE mytable.myid IN (:myid_1, :myid_2)")
-
-        self.assert_compile(select([table1], table1.c.myid.in_()),
-        "SELECT mytable.myid, mytable.name, mytable.description FROM mytable WHERE (CASE WHEN (mytable.myid IS NULL) THEN NULL ELSE 0 END = 1)")
-
     def test_cast(self):
         tbl = table('casttest',
                     column('id', Integer),
index b29ba8d5c0c30d21b7b7639fb6cd2d84f283ff68..66793a25bafbf5f718cb910cc8c8ce8f347c00c3 100755 (executable)
@@ -6,7 +6,7 @@ import testenv; testenv.configure_for_tests()
 from sqlalchemy import *
 from testlib import *
 from sqlalchemy.sql import util as sql_util
-from sqlalchemy import exceptions
+from sqlalchemy import exc
 
 metadata = MetaData()
 table = Table('table1', metadata,
@@ -164,7 +164,7 @@ class SelectableTest(TestBase, AssertsExecutionResults):
         print str(j)
         self.assert_(criterion.compare(j.onclause))
 
-    def testcolumnlabels(self):
+    def test_column_labels(self):
         a = select([table.c.col1.label('acol1'), table.c.col2.label('acol2'), table.c.col3.label('acol3')])
         print str(a)
         print [c for c in a.columns]
@@ -173,13 +173,13 @@ class SelectableTest(TestBase, AssertsExecutionResults):
         criterion = a.c.acol1 == table2.c.col2
         print str(j)
         self.assert_(criterion.compare(j.onclause))
-
+    
     def test_labeled_select_correspoinding(self):
         l1 = select([func.max(table.c.col1)]).label('foo')
 
         s = select([l1])
         assert s.corresponding_column(l1).name == s.c.foo
-
+        
         s = select([table.c.col1, l1])
         assert s.corresponding_column(l1).name == s.c.foo
 
@@ -193,7 +193,7 @@ class SelectableTest(TestBase, AssertsExecutionResults):
         print str(j.onclause)
         self.assert_(criterion.compare(j.onclause))
 
-    def testtablejoinedtoselectoftable(self):
+    def test_table_joined_to_select_of_table(self):
         metadata = MetaData()
         a = Table('a', metadata,
             Column('id', Integer, primary_key=True))
@@ -242,7 +242,7 @@ class SelectableTest(TestBase, AssertsExecutionResults):
 
         s = select([t2, t3], use_labels=True)
 
-        self.assertRaises(exceptions.NoReferencedTableError, s.join, t1)
+        self.assertRaises(exc.NoReferencedTableError, s.join, t1)
         
 class PrimaryKeyTest(TestBase, AssertsExecutionResults):
     def test_join_pk_collapse_implicit(self):
index 09a3702ee74c9583d21547db767d4df1c3e23233..9cd6f9bdb8a58f5aa3a104eae8447c4b3a9e7b56 100644 (file)
@@ -1,7 +1,7 @@
 import testenv; testenv.configure_for_tests()
 import datetime, os, pickleable, re
 from sqlalchemy import *
-from sqlalchemy import exceptions, types, util
+from sqlalchemy import exc, types, util
 from sqlalchemy.sql import operators
 import sqlalchemy.engine.url as url
 from sqlalchemy.databases import mssql, oracle, mysql, postgres, firebird
@@ -40,17 +40,6 @@ class AdaptTest(TestBase):
             assert isinstance(dialect_type, mssql.MSNVarchar)
             assert dialect_type.get_col_spec() == 'NVARCHAR(10)'
 
-    def testoracletext(self):
-        dialect = oracle.OracleDialect()
-        class MyDecoratedType(types.TypeDecorator):
-            impl = String
-            def copy(self):
-                return MyDecoratedType()
-
-        col = Column('', MyDecoratedType)
-        dialect_type = col.type.dialect_impl(dialect)
-        assert isinstance(dialect_type.impl, oracle.OracleText), repr(dialect_type.impl)
-
 
     def testoracletimestamp(self):
         dialect = oracle.OracleDialect()
@@ -77,29 +66,29 @@ class AdaptTest(TestBase):
         firebird_dialect = firebird.FBDialect()
 
         for dialect, start, test in [
-            (oracle_dialect, String(), oracle.OracleText),
+            (oracle_dialect, String(), oracle.OracleString),
             (oracle_dialect, VARCHAR(), oracle.OracleString),
             (oracle_dialect, String(50), oracle.OracleString),
-            (oracle_dialect, Unicode(), oracle.OracleText),
+            (oracle_dialect, Unicode(), oracle.OracleString),
             (oracle_dialect, UnicodeText(), oracle.OracleText),
             (oracle_dialect, NCHAR(), oracle.OracleString),
             (oracle_dialect, oracle.OracleRaw(50), oracle.OracleRaw),
-            (mysql_dialect, String(), mysql.MSText),
+            (mysql_dialect, String(), mysql.MSString),
             (mysql_dialect, VARCHAR(), mysql.MSString),
             (mysql_dialect, String(50), mysql.MSString),
-            (mysql_dialect, Unicode(), mysql.MSText),
+            (mysql_dialect, Unicode(), mysql.MSString),
             (mysql_dialect, UnicodeText(), mysql.MSText),
             (mysql_dialect, NCHAR(), mysql.MSNChar),
-            (postgres_dialect, String(), postgres.PGText),
+            (postgres_dialect, String(), postgres.PGString),
             (postgres_dialect, VARCHAR(), postgres.PGString),
             (postgres_dialect, String(50), postgres.PGString),
-            (postgres_dialect, Unicode(), postgres.PGText),
+            (postgres_dialect, Unicode(), postgres.PGString),
             (postgres_dialect, UnicodeText(), postgres.PGText),
             (postgres_dialect, NCHAR(), postgres.PGString),
-            (firebird_dialect, String(), firebird.FBText),
+            (firebird_dialect, String(), firebird.FBString),
             (firebird_dialect, VARCHAR(), firebird.FBString),
             (firebird_dialect, String(50), firebird.FBString),
-            (firebird_dialect, Unicode(), firebird.FBText),
+            (firebird_dialect, Unicode(), firebird.FBString),
             (firebird_dialect, UnicodeText(), firebird.FBText),
             (firebird_dialect, NCHAR(), firebird.FBString),
         ]:
@@ -118,9 +107,9 @@ class UserDefinedTest(TestBase):
     def testprocessing(self):
 
         global users
-        users.insert().execute(user_id = 2, goofy = 'jack', goofy2='jack', goofy3='jack', goofy4=u'jack', goofy5=u'jack', goofy6='jack', goofy7=u'jack', goofy8=12, goofy9=12)
-        users.insert().execute(user_id = 3, goofy = 'lala', goofy2='lala', goofy3='lala', goofy4=u'lala', goofy5=u'lala', goofy6='lala', goofy7=u'lala', goofy8=15, goofy9=15)
-        users.insert().execute(user_id = 4, goofy = 'fred', goofy2='fred', goofy3='fred', goofy4=u'fred', goofy5=u'fred', goofy6='fred', goofy7=u'fred', goofy8=9, goofy9=9)
+        users.insert().execute(user_id = 2, goofy = 'jack', goofy2='jack', goofy4=u'jack', goofy5=u'jack', goofy6='jack', goofy7=u'jack', goofy8=12, goofy9=12)
+        users.insert().execute(user_id = 3, goofy = 'lala', goofy2='lala', goofy4=u'lala', goofy5=u'lala', goofy6='lala', goofy7=u'lala', goofy8=15, goofy9=15)
+        users.insert().execute(user_id = 4, goofy = 'fred', goofy2='fred', goofy4=u'fred', goofy5=u'fred', goofy6='fred', goofy7=u'fred', goofy8=9, goofy9=9)
 
         l = users.select().execute().fetchall()
         for assertstr, assertint, assertint2, row in zip(
@@ -130,11 +119,11 @@ class UserDefinedTest(TestBase):
             l
 
         ):
-            for col in row[1:8]:
+            for col in row[1:7]:
                 self.assertEquals(col, assertstr)
-            self.assertEquals(row[8], assertint)
-            self.assertEquals(row[9], assertint2)
-            for col in (row[4], row[5], row[7]):
+            self.assertEquals(row[7], assertint)
+            self.assertEquals(row[8], assertint2)
+            for col in (row[3], row[4], row[6]):
                 assert isinstance(col, unicode)
 
     def setUpAll(self):
@@ -250,13 +239,10 @@ class UserDefinedTest(TestBase):
             # decorated type with an argument, so its a String
             Column('goofy2', MyDecoratedType(50), nullable = False),
 
-            # decorated type without an argument, it will adapt_args to TEXT
-            Column('goofy3', MyDecoratedType, nullable = False),
-
-            Column('goofy4', MyUnicodeType, nullable = False),
-            Column('goofy5', LegacyUnicodeType, nullable = False),
+            Column('goofy4', MyUnicodeType(50), nullable = False),
+            Column('goofy5', LegacyUnicodeType(50), nullable = False),
             Column('goofy6', LegacyType, nullable = False),
-            Column('goofy7', MyNewUnicodeType, nullable = False),
+            Column('goofy7', MyNewUnicodeType(50), nullable = False),
             Column('goofy8', MyNewIntType, nullable = False),
             Column('goofy9', MyNewIntSubClass, nullable = False),
 
@@ -344,7 +330,7 @@ class UnicodeTest(TestBase, AssertsExecutionResults):
         try:
             unicode_table.insert().execute(unicode_varchar='not unicode')
             assert False
-        except exceptions.SAWarning, e:
+        except exc.SAWarning, e:
             assert str(e) == "Unicode type received non-unicode bind param value 'not unicode'", str(e)
 
         unicode_engine = engines.utf8_engine(options={'convert_unicode':True,
@@ -353,7 +339,7 @@ class UnicodeTest(TestBase, AssertsExecutionResults):
             try:
                 unicode_engine.execute(unicode_table.insert(), plain_varchar='im not unicode')
                 assert False
-            except exceptions.InvalidRequestError, e:
+            except exc.InvalidRequestError, e:
                 assert str(e) == "Unicode type received non-unicode bind param value 'im not unicode'"
 
             @testing.emits_warning('.*non-unicode bind')
@@ -664,33 +650,20 @@ class DateTest(TestBase, AssertsExecutionResults):
             t.drop(checkfirst=True)
 
 class StringTest(TestBase, AssertsExecutionResults):
-    def test_nolen_string_deprecated(self):
+    
+    
+    def test_nolength_string(self):
+        # this tests what happens with String DDL with no length.
+        # seems like we need to decide amongst "VARCHAR" (sqlite, postgres), "TEXT" (mysql)
+        # i.e. theres some inconsisency here.
+        
         metadata = MetaData(testing.db)
         foo =Table('foo', metadata,
             Column('one', String))
-
-        # no warning
-        select([func.count("*")], bind=testing.db).execute()
-
-        try:
-            # warning during CREATE
-            foo.create()
-            assert False
-        except exceptions.SADeprecationWarning, e:
-            assert "Using String type with no length" in str(e)
-            assert re.search(r'\bone\b', str(e))
-
-        bar = Table('bar', metadata, Column('one', String(40)))
-
-        try:
-            # no warning
-            bar.create()
-
-            # no warning for non-lengthed string
-            select([func.count("*")], from_obj=bar).execute()
-        finally:
-            bar.drop()
-
+        
+        foo.create()
+        foo.drop()
+        
 def _missing_decimal():
     """Python implementation supports decimals"""
     try:
index 98552b0f39fd8fc5b3c790868e545e16f03290d2..67e56e3d8a9321c0f45535996e11aad77ada6058 100644 (file)
@@ -3,14 +3,21 @@
 Load after sqlalchemy imports to use instrumented stand-ins like Table.
 """
 
+import sys
 import testlib.config
 from testlib.schema import Table, Column
 from testlib.orm import mapper
 import testlib.testing as testing
-from testlib.testing import rowset
-from testlib.testing import TestBase, AssertsExecutionResults, ORMTest, AssertsCompiledSQL, ComparesTables
+from testlib.testing import \
+     AssertsCompiledSQL, \
+     AssertsExecutionResults, \
+     ComparesTables, \
+     ORMTest, \
+     TestBase, \
+     rowset
 import testlib.profiling as profiling
 import testlib.engines as engines
+import testlib.requires as requires
 from testlib.compat import set, frozenset, sorted, _function_named
 
 
@@ -18,6 +25,15 @@ __all__ = ('testing',
            'mapper',
            'Table', 'Column',
            'rowset',
-           'TestBase', 'AssertsExecutionResults', 'ORMTest', 'AssertsCompiledSQL', 'ComparesTables',
+           'TestBase', 'AssertsExecutionResults', 'ORMTest',
+           'AssertsCompiledSQL', 'ComparesTables',
            'profiling', 'engines',
            'set', 'frozenset', 'sorted', '_function_named')
+
+
+testing.requires = requires
+
+sys.modules['testlib.sa'] = sa = testing.CompositeModule(
+    'testlib.sa', 'sqlalchemy', 'testlib.schema', orm=testing.CompositeModule(
+    'testlib.sa.orm', 'sqlalchemy.orm', 'testlib.orm'))
+sys.modules['testlib.sa.orm'] = sa.orm
index ba12b78ac8128960865a8e01b5c247b82443c33f..fcb7fa1e9256af355f03b2c4676f97d42e4e50ff 100644 (file)
@@ -1,6 +1,6 @@
-import itertools, new, sys, warnings
+import new
 
-__all__ = 'set', 'frozenset', 'sorted', '_function_named', 'deque'
+__all__ = 'set', 'frozenset', 'sorted', '_function_named', 'deque', 'reversed'
 
 try:
     set = set
@@ -68,6 +68,16 @@ except NameError:
             l.sort()
         return l
 
+try:
+    reversed = reversed
+except NameError:
+    def reversed(seq):
+        i = len(seq) - 1
+        while  i >= 0:
+            yield seq[i]
+            i -= 1
+        raise StopIteration()
+
 try:
     from collections import deque
 except ImportError:
@@ -77,9 +87,7 @@ except ImportError:
         def popleft(self):
             return self.pop(0)
         def extendleft(self, iterable):
-            items = list(iterable)
-            items.reverse()
-            for x in items:
+            for x in reversed(list(iterable)):
                 self.insert(0, x)
 
 def _function_named(fn, newname):
index f5694df57e55c92fc1630ad7861f49f6df4e4bab..5ad35a066361a7862269f45bb5c7455c9f424ee6 100644 (file)
@@ -1,6 +1,6 @@
 import sys, types, weakref
 from testlib import config
-from testlib.compat import *
+from testlib.compat import set, _function_named, deque
 
 
 class ConnectionKiller(object):
index eb7eff279b43f2f9e7b98a2210e41eba2c3d7f8c..2d559f53b28cdd89ccc4faba3dc3d95310647e07 100644 (file)
@@ -14,8 +14,8 @@ Includes::
 """
 
 import sys
-from StringIO import StringIO
-from tokenize import *
+from tokenize import generate_tokens, INDENT, DEDENT, NAME, OP, NL, NEWLINE, \
+     NUMBER, STRING, COMMENT
 
 __all__ = ['py23_decorators', 'py23']
 
index e8d71179a8c129e320d3fbc2742c5629e0163173..f56b865c6823b9a9eaea84a9cb1ca3ed18f35ea8 100644 (file)
@@ -1,14 +1,16 @@
-# can't be imported until the path is setup; be sure to configure
-# first if covering.
-from sqlalchemy import *
-from sqlalchemy import util
-from testlib import *
-
-__all__ = ['keywords', 'addresses', 'Base', 'Keyword', 'FixtureTest', 'Dingaling', 'item_keywords', 
-            'dingalings', 'User', 'items', 'Fixtures', 'orders', 'install_fixture_data', 'Address', 'users', 
+from testlib.sa import MetaData, Table, Column, Integer, String, ForeignKey
+from testlib.sa.orm import attributes
+from testlib import ORMTest
+from testlib.compat import set
+
+
+__all__ = ['keywords', 'addresses', 'Base', 'Keyword', 'FixtureTest',
+           'Dingaling', 'item_keywords', 'dingalings', 'User', 'items',
+           'Fixtures', 'orders', 'install_fixture_data', 'Address', 'users',
             'order_items', 'Item', 'Order', 'fixtures']
-            
-_recursion_stack = util.Set()
+
+
+_recursion_stack = set()
 class Base(object):
     def __init__(self, **kwargs):
         for k in kwargs:
@@ -36,10 +38,15 @@ class Base(object):
         _recursion_stack.add(self)
         try:
             # pick the entity thats not SA persisted as the source
+            try:
+                state = attributes.instance_state(self)
+                key = state.key
+            except (KeyError, AttributeError):
+                key = None
             if other is None:
                 a = self
                 b = other
-            elif hasattr(self, '_instance_key'):
+            elif key is not None:
                 a = other
                 b = self
             else:
@@ -57,8 +64,9 @@ class Base(object):
                         battr = getattr(b, attr)
                     except AttributeError:
                         #print "b class does not have attribute named '%s'" % attr
+                        #raise
                         return False
-                    
+
                     if list(value) == list(battr):
                         continue
                     else:
@@ -84,43 +92,60 @@ metadata = MetaData()
 
 users = Table('users', metadata,
     Column('id', Integer, primary_key=True),
-    Column('name', String(30), nullable=False))
+    Column('name', String(30), nullable=False),
+    test_needs_acid=True,
+    test_needs_fk=True
+    )
 
 orders = Table('orders', metadata,
     Column('id', Integer, primary_key=True),
     Column('user_id', None, ForeignKey('users.id')),
     Column('address_id', None, ForeignKey('addresses.id')),
     Column('description', String(30)),
-    Column('isopen', Integer)
+    Column('isopen', Integer),
+    test_needs_acid=True,
+    test_needs_fk=True
     )
 
 addresses = Table('addresses', metadata,
     Column('id', Integer, primary_key=True),
     Column('user_id', None, ForeignKey('users.id')),
-    Column('email_address', String(50), nullable=False))
+    Column('email_address', String(50), nullable=False),
+    test_needs_acid=True,
+    test_needs_fk=True)
 
 dingalings = Table("dingalings", metadata,
     Column('id', Integer, primary_key=True),
     Column('address_id', None, ForeignKey('addresses.id')),
-    Column('data', String(30))
+    Column('data', String(30)),
+    test_needs_acid=True,
+    test_needs_fk=True
     )
 
 items = Table('items', metadata,
     Column('id', Integer, primary_key=True),
-    Column('description', String(30), nullable=False)
+    Column('description', String(30), nullable=False),
+    test_needs_acid=True,
+    test_needs_fk=True
     )
 
 order_items = Table('order_items', metadata,
     Column('item_id', None, ForeignKey('items.id')),
-    Column('order_id', None, ForeignKey('orders.id')))
+    Column('order_id', None, ForeignKey('orders.id')),
+    test_needs_acid=True,
+    test_needs_fk=True)
 
 item_keywords = Table('item_keywords', metadata,
     Column('item_id', None, ForeignKey('items.id')),
-    Column('keyword_id', None, ForeignKey('keywords.id')))
+    Column('keyword_id', None, ForeignKey('keywords.id')),
+    test_needs_acid=True,
+    test_needs_fk=True)
 
 keywords = Table('keywords', metadata,
     Column('id', Integer, primary_key=True),
-    Column('name', String(30), nullable=False)
+    Column('name', String(30), nullable=False),
+    test_needs_acid=True,
+    test_needs_fk=True
     )
 
 def install_fixture_data():
@@ -203,14 +228,15 @@ def install_fixture_data():
 
 class FixtureTest(ORMTest):
     refresh_data = False
-
+    only_tables = False
+    
     def setUpAll(self):
         super(FixtureTest, self).setUpAll()
-        if self.keep_data:
+        if not self.only_tables and self.keep_data:
             install_fixture_data()
 
     def setUp(self):
-        if self.refresh_data:
+        if not self.only_tables and self.refresh_data:
             install_fixture_data()
 
     def define_tables(self, meta):
index b452d1fb827498d42c1f5e72d91d28f0750f2e66..e423b9904257bbdab69adca19ba47c3c438fc604 100644 (file)
@@ -1,8 +1,7 @@
 """Profiling support for unit and performance tests."""
 
 import os, sys
-from testlib.config import parser, post_configure
-from testlib.compat import *
+from testlib.compat import set, _function_named
 import testlib.config
 
 __all__ = 'profiled', 'function_call_count', 'conditional_call_count'
@@ -26,8 +25,6 @@ def profiled(target=None, **target_opts):
     configuration and command-line options.
     """
 
-    import time, hotshot, hotshot.stats
-
     # manual or automatic namespacing by module would remove conflict issues
     if target is None:
         target = 'anonymous_target'
diff --git a/test/testlib/requires.py b/test/testlib/requires.py
new file mode 100644 (file)
index 0000000..a4604ff
--- /dev/null
@@ -0,0 +1,32 @@
+"""Global database feature support policy.
+
+Provides decorators to mark tests requiring specific feature support from the
+target database.
+
+"""
+from testlib import testing
+
+def savepoints(fn):
+    """Target database must support savepoints."""
+    return (testing.unsupported(
+            'access',
+            'mssql',
+            'sqlite',
+            'sybase',
+            )
+            (testing.exclude('mysql', '<', (5, 0, 3))
+             (fn)))
+
+def two_phase_transactions(fn):
+    """Target database must support two-phase transactions."""
+    return (testing.unsupported(
+            'access',
+            'firebird',
+            'maxdb',
+            'mssql',
+            'oracle',
+            'sqlite',
+            'sybase',
+            )
+            (testing.exclude('mysql', '<', (5, 0, 3))
+             (fn)))
index 37f3591aded32420ce99c2e3380ab503885a5b6e..9cedc02f0a360c91f4768a1ac2bd30ad377c178b 100644 (file)
@@ -1,5 +1,5 @@
 from testlib import testing
-import itertools
+
 schema = None
 
 __all__ = 'Table', 'Column',
index 33b1b20db9aaddbf087407b6892dceda5b7cddfc..3399acaae37ce78299a37cd78ff9bc434cc4505e 100644 (file)
@@ -1,8 +1,9 @@
 # can't be imported until the path is setup; be sure to configure
 # first if covering.
-from sqlalchemy import *
+
 from testlib import testing
-from testlib.schema import Table, Column
+from testlib.sa import MetaData, Table, Column, Integer, String, Sequence, \
+     ForeignKey, VARCHAR, INT
 
 
 # these are older test fixtures, used primarily by test/orm/mapper.py and
index cf0936e9228db56df1e264f59e2a01c134752dd3..1e2ca62e923ddbf683cae0ee5fe8fe12088cce08 100644 (file)
@@ -2,15 +2,27 @@
 
 # monkeypatches unittest.TestLoader.suiteClass at import time
 
-import itertools, os, operator, re, sys, unittest, warnings
+import itertools
+import operator
+import re
+import sys
+import types
+import unittest
+import warnings
 from cStringIO import StringIO
+
 import testlib.config as config
-from testlib.compat import *
+from testlib.compat import set, _function_named, reversed
 
-sql, sqltypes, schema, MetaData, clear_mappers, Session, util = None, None, None, None, None, None, None
-sa_exceptions = None
+# Delayed imports
+MetaData = None
+Session = None
+clear_mappers = None
+sa_exc = None
+schema = None
+sqltypes = None
+util = None
 
-__all__ = ('TestBase', 'AssertsExecutionResults', 'ComparesTables', 'ORMTest', 'AssertsCompiledSQL')
 
 _ops = { '<': operator.lt,
          '>': operator.gt,
@@ -25,6 +37,9 @@ _ops = { '<': operator.lt,
 # sugar ('testing.db'); set here by config() at runtime
 db = None
 
+# more sugar, installed by __init__
+requires = None
+
 def fails_if(callable_):
     """Mark a test as expected to fail if callable_ returns True.
 
@@ -224,17 +239,17 @@ def emits_warning(*messages):
     # - update: jython looks ok, it uses cpython's module
     def decorate(fn):
         def safe(*args, **kw):
-            global sa_exceptions
-            if sa_exceptions is None:
-                import sqlalchemy.exceptions as sa_exceptions
+            global sa_exc
+            if sa_exc is None:
+                import sqlalchemy.exc as sa_exc
 
             if not messages:
                 filters = [dict(action='ignore',
-                                category=sa_exceptions.SAWarning)]
+                                category=sa_exc.SAWarning)]
             else:
                 filters = [dict(action='ignore',
                                 message=message,
-                                category=sa_exceptions.SAWarning)
+                                category=sa_exc.SAWarning)
                            for message in messages ]
             for f in filters:
                 warnings.filterwarnings(**f)
@@ -259,17 +274,17 @@ def uses_deprecated(*messages):
 
     def decorate(fn):
         def safe(*args, **kw):
-            global sa_exceptions
-            if sa_exceptions is None:
-                import sqlalchemy.exceptions as sa_exceptions
+            global sa_exc
+            if sa_exc is None:
+                import sqlalchemy.exc as sa_exc
 
             if not messages:
                 filters = [dict(action='ignore',
-                                category=sa_exceptions.SADeprecationWarning)]
+                                category=sa_exc.SADeprecationWarning)]
             else:
                 filters = [dict(action='ignore',
                                 message=message,
-                                category=sa_exceptions.SADeprecationWarning)
+                                category=sa_exc.SADeprecationWarning)
                            for message in
                            [ (m.startswith('//') and
                               ('Call to deprecated function ' + m[2:]) or m)
@@ -287,13 +302,13 @@ def uses_deprecated(*messages):
 def resetwarnings():
     """Reset warning behavior to testing defaults."""
 
-    global sa_exceptions
-    if sa_exceptions is None:
-        import sqlalchemy.exceptions as sa_exceptions
+    global sa_exc
+    if sa_exc is None:
+        import sqlalchemy.exc as sa_exc
 
     warnings.resetwarnings()
-    warnings.filterwarnings('error', category=sa_exceptions.SADeprecationWarning)
-    warnings.filterwarnings('error', category=sa_exceptions.SAWarning)
+    warnings.filterwarnings('error', category=sa_exc.SADeprecationWarning)
+    warnings.filterwarnings('error', category=sa_exc.SAWarning)
 
     if sys.version_info < (2, 4):
         warnings.filterwarnings('ignore', category=FutureWarning)
@@ -338,6 +353,23 @@ def rowset(results):
     return set([tuple(row) for row in results])
 
 
+def eq_(a, b, msg=None):
+    """Assert a == b, with repr messaging on failure."""
+    assert a == b, msg or "%r != %r" % (a, b)
+
+def ne_(a, b, msg=None):
+    """Assert a != b, with repr messaging on failure."""
+    assert a != b, msg or "%r == %r" % (a, b)
+
+def is_(a, b, msg=None):
+    """Assert a is b, with repr messaging on failure."""
+    assert a is b, msg or "%r is not %r" % (a, b)
+
+def is_not_(a, b, msg=None):
+    """Assert a is not b, with repr messaging on failure."""
+    assert a is not b, msg or "%r is %r" % (a, b)
+
+
 class TestData(object):
     """Tracks SQL expressions as they are executed via an instrumented ExecutionContext."""
 
@@ -360,10 +392,6 @@ class ExecutionContextWrapper(object):
     can be tracked."""
 
     def __init__(self, ctx):
-        global sql
-        if sql is None:
-            from sqlalchemy import sql
-
         self.__dict__['ctx'] = ctx
     def __getattr__(self, key):
         return getattr(self.ctx, key)
@@ -414,7 +442,7 @@ class ExecutionContextWrapper(object):
 
             query = self.convert_statement(query)
             equivalent = ( (statement == query)
-                           or ( (config.db.name == 'oracle') and (self.trailing_underscore_pattern.sub(r'\1', statement) == query) ) 
+                           or ( (config.db.name == 'oracle') and (self.trailing_underscore_pattern.sub(r'\1', statement) == query) )
                          ) \
                          and \
                          ( (params is None) or (params == parameters)
@@ -422,7 +450,7 @@ class ExecutionContextWrapper(object):
                                                for (k, v) in p.items()])
                                          for p in parameters]
                          )
-            testdata.unittest.assert_(equivalent, 
+            testdata.unittest.assert_(equivalent,
                     "Testing for query '%s' params %s, received '%s' with params %s" % (query, repr(params), statement, repr(parameters)))
         testdata.sql_count += 1
         self.ctx.post_execution()
@@ -445,6 +473,44 @@ class ExecutionContextWrapper(object):
             query = re.sub(r':([\w_]+)', repl, query)
         return query
 
+
+def _import_by_name(name):
+    submodule = name.split('.')[-1]
+    return __import__(name, globals(), locals(), [submodule])
+
+class CompositeModule(types.ModuleType):
+    """Merged attribute access for multiple modules."""
+
+    # break the habit
+    __all__ = ()
+
+    def __init__(self, name, *modules, **overrides):
+        """Construct a new lazy composite of modules.
+
+        Modules may be string names or module-like instances.  Individual
+        attribute overrides may be specified as keyword arguments for
+        convenience.
+
+        The constructed module will resolve attribute access in reverse order:
+        overrides, then each member of reversed(modules).  Modules specified
+        by name will be loaded lazily when encountered in attribute
+        resolution.
+
+        """
+        types.ModuleType.__init__(self, name)
+        self.__modules = list(reversed(modules))
+        for key, value in overrides.iteritems():
+            setattr(self, key, value)
+
+    def __getattr__(self, key):
+        for idx, mod in enumerate(self.__modules):
+            if isinstance(mod, basestring):
+                self.__modules[idx] = mod = _import_by_name(mod)
+            if hasattr(mod, key):
+                return getattr(mod, key)
+        raise AttributeError(key)
+
+
 class TestBase(unittest.TestCase):
     # A sequence of dialect names to exclude from the test class.
     __unsupported_on__ = ()
@@ -469,14 +535,14 @@ class TestBase(unittest.TestCase):
     def shortDescription(self):
         """overridden to not return docstrings"""
         return None
-    
+
     def assertRaisesMessage(self, except_cls, msg, callable_, *args, **kwargs):
         try:
             callable_(*args, **kwargs)
             assert False, "Callable did not raise expected exception"
         except except_cls, e:
             assert re.search(msg, str(e)), "Exception message did not match: '%s'" % str(e)
-        
+
     if not hasattr(unittest.TestCase, 'assertTrue'):
         assertTrue = unittest.TestCase.failUnless
     if not hasattr(unittest.TestCase, 'assertFalse'):
@@ -522,7 +588,7 @@ class ComparesTables(object):
                 set(type(c.type).__mro__).difference(base_mro)
                 )
             ) > 0, "Type '%s' doesn't correspond to type '%s'" % (reflected_c.type, c.type)
-            
+
             if isinstance(c.type, sqltypes.String):
                 self.assertEquals(c.type.length, reflected_c.type.length)
 
@@ -535,18 +601,18 @@ class ComparesTables(object):
             elif not c.primary_key or not against('postgres'):
                 print repr(c)
                 assert reflected_c.default is None, reflected_c.default
-        
+
         assert len(table.primary_key) == len(reflected_table.primary_key)
         for c in table.primary_key:
             assert reflected_table.primary_key.columns[c.name]
 
-    
+
 class AssertsExecutionResults(object):
     def assert_result(self, result, class_, *objects):
         result = list(result)
         print repr(result)
         self.assert_list(result, class_, objects)
-        
+
     def assert_list(self, result, class_, list):
         self.assert_(len(result) == len(list),
                      "result list is not the same size as test list, " +
@@ -675,10 +741,10 @@ class ORMTest(TestBase, AssertsExecutionResults):
 
     def define_tables(self, _otest_metadata):
         raise NotImplementedError()
-    
+
     def setup_mappers(self):
         pass
-        
+
     def insert_data(self):
         pass