]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- several ORM attributes have been removed or made private:
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 1 Dec 2007 23:00:05 +0000 (23:00 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 1 Dec 2007 23:00:05 +0000 (23:00 +0000)
mapper.get_attr_by_column(), mapper.set_attr_by_column(),
mapper.pks_by_table, mapper.cascade_callable(),
MapperProperty.cascade_callable(), mapper.canload()
- refinements to mapper PK/table column organization, session cascading,
some naming convention work

13 files changed:
CHANGES
lib/sqlalchemy/orm/collections.py
lib/sqlalchemy/orm/dependency.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/sync.py
lib/sqlalchemy/sql/util.py
test/orm/inheritance/basic.py
test/orm/manytomany.py
test/orm/mapper.py

diff --git a/CHANGES b/CHANGES
index 1dc797d77edd6c38729a74d822149bddabdc50e0..6e275254f3e6e586751fc7879c270be8b881160f 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -34,6 +34,11 @@ CHANGES
      relationship where it takes effect for all inheriting mappers. 
      [ticket:883]
      
+   - several ORM attributes have been removed or made private:
+     mapper.get_attr_by_column(), mapper.set_attr_by_column(), 
+     mapper.pks_by_table, mapper.cascade_callable(), 
+     MapperProperty.cascade_callable(), mapper.canload()
+     
    - fixed endless loop issue when using lazy="dynamic" on both 
      sides of a bi-directional relationship [ticket:872]
 
index 942b880c94c51221803f7500711a9090eb7516cf..c2cd4cf09dfe1ca7024492fd46a708d297cdd931 100644 (file)
@@ -122,7 +122,7 @@ def column_mapped_collection(mapping_spec):
     if isinstance(mapping_spec, schema.Column):
         def keyfunc(value):
             m = object_mapper(value)
-            return m.get_attr_by_column(value, mapping_spec)
+            return m._get_attr_by_column(value, mapping_spec)
     else:
         cols = []
         for c in mapping_spec:
@@ -133,7 +133,7 @@ def column_mapped_collection(mapping_spec):
         mapping_spec = tuple(cols)
         def keyfunc(value):
             m = object_mapper(value)
-            return tuple([m.get_attr_by_column(value, c) for c in mapping_spec])
+            return tuple([m._get_attr_by_column(value, c) for c in mapping_spec])
     return lambda: MappedCollection(keyfunc)
 
 def attribute_mapped_collection(attr_name):
index 9688999169cec10b7a0c7631353a0fda89a0d8e3..9220c5743b30af4a047a97e98917ff43a6176bab 100644 (file)
@@ -111,7 +111,7 @@ class DependencyProcessor(object):
     def _verify_canload(self, child):
         if not self.enable_typechecks:
             return
-        if child is not None and not self.mapper.canload(child):
+        if child is not None and not self.mapper._canload(child):
             raise exceptions.FlushError("Attempting to flush an item of type %s on collection '%s', which is handled by mapper '%s' and does not load items of that type.  Did you mean to use a polymorphic mapper for this relationship ?  Set 'enable_typechecks=False' on the relation() to disable this exception.  Mismatched typeloading may cause bi-directional relationships (backrefs) to not function properly." % (child.__class__, self.prop, self.mapper))
         
     def _synchronize(self, obj, child, associationrow, clearkeys, uowcommit):
@@ -240,7 +240,7 @@ class OneToManyDP(DependencyProcessor):
                             uowcommit.register_object(child, isdelete=False)
                         elif self.hasparent(child) is False:
                             uowcommit.register_object(child, isdelete=True)
-                            for c in self.mapper.cascade_iterator('delete', child):
+                            for c, m in self.mapper.cascade_iterator('delete', child):
                                 uowcommit.register_object(c, isdelete=True)
 
     def _synchronize(self, obj, child, associationrow, clearkeys, uowcommit):
@@ -294,7 +294,7 @@ class ManyToOneDP(DependencyProcessor):
                         for child in childlist.deleted_items() + childlist.unchanged_items():
                             if child is not None and self.hasparent(child) is False:
                                 uowcommit.register_object(child, isdelete=True)
-                                for c in self.mapper.cascade_iterator('delete', child):
+                                for c, m in self.mapper.cascade_iterator('delete', child):
                                     uowcommit.register_object(c, isdelete=True)
         else:
             for obj in deplist:
@@ -305,7 +305,7 @@ class ManyToOneDP(DependencyProcessor):
                         for child in childlist.deleted_items():
                             if self.hasparent(child) is False:
                                 uowcommit.register_object(child, isdelete=True)
-                                for c in self.mapper.cascade_iterator('delete', child):
+                                for c, m in self.mapper.cascade_iterator('delete', child):
                                     uowcommit.register_object(c, isdelete=True)
 
     def _synchronize(self, obj, child, associationrow, clearkeys, uowcommit):
@@ -391,7 +391,7 @@ class ManyToManyDP(DependencyProcessor):
                     for child in childlist.deleted_items():
                         if self.cascade.delete_orphan and self.hasparent(child) is False:
                             uowcommit.register_object(child, isdelete=True)
-                            for c in self.mapper.cascade_iterator('delete', child):
+                            for c, m in self.mapper.cascade_iterator('delete', child):
                                 uowcommit.register_object(c, isdelete=True)
 
     def _synchronize(self, obj, child, associationrow, clearkeys, uowcommit):
index aa0b2dcc242bd0711cf44c7d94357e23b805dc03..815ea8cebb660df817704d623ce5f6a4e966c210 100644 (file)
@@ -319,18 +319,15 @@ class MapperProperty(object):
         """
         
         raise NotImplementedError()
-        
+
     def cascade_iterator(self, type, object, recursive=None, halt_on=None):
-        """return an iterator of objects which are child objects of the given object,
-        as attached to the attribute corresponding to this MapperProperty."""
+        """iterate through instances related to the given instance along
+        a particular 'cascade' path, starting with this MapperProperty.
         
-        return []
-
-    def cascade_callable(self, type, object, callable_, recursive=None, halt_on=None):
-        """run the given callable across all objects which are child objects of 
-        the given object, as attached to the attribute corresponding to this MapperProperty."""
+        see PropertyLoader for the related instance implementation.
+        """
         
-        return []
+        return iter([])
 
     def get_criterion(self, query, key, value):
         """Return a ``WHERE`` clause suitable for this
index 5673be44c117233a8d2338f345a515a3dde396ed..67087c5708b97e5f99a276d278bbb7d88f80d78b 100644 (file)
@@ -4,15 +4,15 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-import weakref, warnings, operator
+import weakref, warnings
+from itertools import chain
 from sqlalchemy import sql, util, exceptions, logging
-from sqlalchemy.sql import expression, visitors
+from sqlalchemy.sql import expression, visitors, operators
 from sqlalchemy.sql import util as sqlutil
 from sqlalchemy.orm import util as mapperutil
 from sqlalchemy.orm.util import ExtensionCarrier, create_row_adapter
 from sqlalchemy.orm import sync, attributes
 from sqlalchemy.orm.interfaces import MapperProperty, EXT_CONTINUE, PropComparator
-deferred_load = None
 
 __all__ = ['Mapper', 'class_mapper', 'object_mapper', 'mapper_registry']
 
@@ -22,7 +22,7 @@ mapper_registry = weakref.WeakKeyDictionary()
 # a list of MapperExtensions that will be installed in all mappers by default
 global_extensions = []
 
-# a constant returned by get_attr_by_column to indicate
+# a constant returned by _get_attr_by_column to indicate
 # this mapper is not handling an attribute for a particular
 # column
 NO_ATTRIBUTE = object()
@@ -152,6 +152,7 @@ class Mapper(object):
         self._compile_inheritance()
         self._compile_tables()
         self._compile_properties()
+        self._compile_pks()
         self._compile_selectable()
 
         self.__log("constructed")
@@ -376,12 +377,6 @@ class Mapper(object):
         self.polymorphic_map[key] = class_or_mapper
 
     def _compile_tables(self):
-        """After the inheritance relationships have been reconciled,
-        set up some more table-based instance variables and determine
-        the *primary key* columns for all tables represented by this
-        ``Mapper``.
-        """
-
         # summary of the various Selectable units:
         # mapped_table - the Selectable that represents a join of the underlying Tables to be saved (or just the Table)
         # local_table - the Selectable that was passed to this Mapper's constructor, if any
@@ -401,27 +396,25 @@ class Mapper(object):
         if not self.tables:
             raise exceptions.InvalidRequestError("Could not find any Table objects in mapped table '%s'" % str(self.mapped_table))
 
-        # TODO: move the "figure pks" step down into compile_properties; after 
-        # all columns have been mapped, assemble PK columns and their
-        # proxied parents into the pks_by_table collection, then get rid 
-        # of the _has_pks method
-        
-        # determine primary key columns
-        self.pks_by_table = {}
+    def _compile_pks(self):
 
-        # go through all of our represented tables
-        # and assemble primary key columns
-        for t in self.tables + [self.mapped_table]:
+        self._pks_by_table = {}
+        self._cols_by_table = {}
+        
+        all_cols = util.Set(chain(*[c2 for c2 in [col.proxy_set for col in [c for c in self._columntoproperty]]]))
+        pk_cols = util.Set([c for c in all_cols if c.primary_key])
+        
+        for t in util.Set(self.tables + [self.mapped_table]):
             self._all_tables.add(t)
-            if t not in self.pks_by_table:
-                self.pks_by_table[t] = util.OrderedSet()
-            self.pks_by_table[t].update(t.primary_key)
-                
-        if self.primary_key_argument is not None:
+            if t.primary_key and pk_cols.issuperset(t.primary_key):
+                self._pks_by_table[t] = util.Set(t.primary_key).intersection(pk_cols)
+            self._cols_by_table[t] = util.Set(t.c).intersection(all_cols)
+            
+        if self.primary_key_argument:
             for k in self.primary_key_argument:
-                self.pks_by_table.setdefault(k.table, util.OrderedSet()).add(k)
+                self._pks_by_table.setdefault(k.table, util.Set()).add(k)
                 
-        if len(self.pks_by_table[self.mapped_table]) == 0:
+        if len(self._pks_by_table[self.mapped_table]) == 0:
             raise exceptions.ArgumentError("Could not assemble any primary key columns for mapped table '%s'" % (self.mapped_table.name))
 
         if self.inherits is not None and not self.concrete and not self.primary_key_argument:
@@ -437,7 +430,7 @@ class Mapper(object):
         
             primary_key = expression.ColumnSet()
 
-            for col in (self.primary_key_argument or self.pks_by_table[self.mapped_table]):
+            for col in (self.primary_key_argument or self._pks_by_table[self.mapped_table]):
                 c = self.mapped_table.corresponding_column(col, raiseerr=False)
                 if c is None:
                     for cc in self._equivalent_columns[col]:
@@ -474,7 +467,10 @@ class Mapper(object):
 
             self.primary_key = primary_key
             self.__log("Identified primary key columns: " + str(primary_key))
-        
+
+            # create a "get clause" based on the primary key.  this is used
+            # by query.get() and many-to-one lazyloads to load this item
+            # by primary key.
             _get_clause = sql.and_()
             _get_params = {}
             for primary_key in self.primary_key:
@@ -510,7 +506,7 @@ class Mapper(object):
 
         result = {}
         def visit_binary(binary):
-            if binary.operator == operator.eq:
+            if binary.operator == operators.eq:
                 if binary.left in result:
                     result[binary.left].add(binary.right)
                 else:
@@ -533,13 +529,17 @@ class Mapper(object):
                 return
             recursive.add(col)
             for fk in col.foreign_keys:
-                result.setdefault(fk.column, util.Set()).add(equiv)
+                if fk.column not in result:
+                    result[fk.column] = util.Set()
+                result[fk.column].add(equiv)
                 equivs(fk.column, recursive, col)
                 
-        for column in (self.primary_key_argument or self.pks_by_table[self.mapped_table]):
+        for column in (self.primary_key_argument or self._pks_by_table[self.mapped_table]):
             for col in column.proxy_set:
                 if not col.foreign_keys:
-                    result.setdefault(col, util.Set()).add(col)
+                    if col not in result:
+                        result[col] = util.Set()
+                    result[col].add(col)
                 else:
                     equivs(col, util.Set(), col)
                     
@@ -571,11 +571,6 @@ class Mapper(object):
             return getattr(getattr(cls, clskey), key)
             
     def _compile_properties(self):
-        """Inspect the properties dictionary sent to the Mapper's
-        constructor as well as the mapped_table, and create
-        ``MapperProperty`` objects corresponding to each mapped column
-        and relation.
-        """
 
         # object attribute names mapped to MapperProperty objects
         self.__props = util.OrderedDict()
@@ -637,9 +632,6 @@ class Mapper(object):
                 # TODO: the "property already exists" case is still not well defined here.  
                 # assuming single-column, etc.
                 
-                if column in self.primary_key and prop.columns[-1] in self.primary_key:
-                    warnings.warn(RuntimeWarning("On mapper %s, primary key column '%s' is being combined with distinct primary key column '%s' in attribute '%s'.  Use explicit properties to give each column its own mapped attribute name." % (str(self), str(column), str(prop.columns[-1]), key)))
-
                 if prop.parent is not self:
                     # existing ColumnProperty from an inheriting mapper.
                     # make a copy and append our column to it
@@ -896,38 +888,27 @@ class Mapper(object):
         instance.
         """
 
-        return [self.get_attr_by_column(instance, column) for column in self.primary_key]
+        return [self._get_attr_by_column(instance, column) for column in self.primary_key]
 
-    def canload(self, instance):
+    def _canload(self, instance):
         """return true if this mapper is capable of loading the given instance"""
         if self.polymorphic_on is not None:
             return isinstance(instance, self.class_)
         else:
             return instance.__class__ is self.class_
         
-    def _getpropbycolumn(self, column, raiseerror=True):
+    def _get_attr_by_column(self, obj, column):
+        """Return an instance attribute using a Column as the key."""
         try:
-            return self._columntoproperty[column]
+            return self._columntoproperty[column].getattr(obj, column)
         except KeyError:
-            try:
-                prop = self.__props[column.key]
-                if not raiseerror:
-                    return None
+            prop = self.__props.get(column.key, None)
+            if prop:
                 raise exceptions.InvalidRequestError("Column '%s.%s' is not available, due to conflicting property '%s':%s" % (column.table.name, column.name, column.key, repr(prop)))
-            except KeyError:
-                if not raiseerror:
-                    return None
+            else:
                 raise exceptions.InvalidRequestError("No column %s.%s is configured on mapper %s..." % (column.table.name, column.name, str(self)))
-
-    def get_attr_by_column(self, obj, column, raiseerror=True):
-        """Return an instance attribute using a Column as the key."""
-
-        prop = self._getpropbycolumn(column, raiseerror)
-        if prop is None:
-            return NO_ATTRIBUTE
-        return prop.getattr(obj, column)
-
-    def set_attr_by_column(self, obj, column, value):
+        
+    def _set_attr_by_column(self, obj, column, value):
         """Set the value of an instance attribute using a Column as the key."""
 
         self._columntoproperty[column].setattr(obj, value, column)
@@ -996,18 +977,18 @@ class Mapper(object):
         table_to_mapper = {}
         for mapper in self.base_mapper.polymorphic_iterator():
             for t in mapper.tables:
-                table_to_mapper.setdefault(t, mapper)
+                table_to_mapper[t] = mapper
 
-        for table in sqlutil.sort_tables(table_to_mapper.keys(), reverse=False):
+        for table in sqlutil.sort_tables(table_to_mapper.keys()):
             # two lists to store parameters for each table/object pair located
             insert = []
             update = []
 
             for obj, connection in tups:
                 mapper = object_mapper(obj)
-                if table not in mapper.tables or not mapper._has_pks(table):
+                if table not in mapper._pks_by_table:
                     continue
-                pks = mapper.pks_by_table[table]
+                pks = mapper._pks_by_table[table]
                 instance_key = mapper.identity_key_from_instance(obj)
 
                 if self.__should_log_debug:
@@ -1019,11 +1000,11 @@ class Mapper(object):
                 hasdata = False
 
                 if isinsert:
-                    for col in table.columns:
+                    for col in mapper._cols_by_table[table]:
                         if col is mapper.version_id_col:
                             params[col.key] = 1
                         elif col in pks:
-                            value = mapper.get_attr_by_column(obj, col)
+                            value = mapper._get_attr_by_column(obj, col)
                             if value is not None:
                                 params[col.key] = value
                         elif mapper.polymorphic_on is not None and mapper.polymorphic_on.shares_lineage(col):
@@ -1033,9 +1014,7 @@ class Mapper(object):
                             if col.default is None or value is not None:
                                 params[col.key] = value
                         else:
-                            value = mapper.get_attr_by_column(obj, col, False)
-                            if value is NO_ATTRIBUTE:
-                                continue
+                            value = mapper._get_attr_by_column(obj, col)
                             if col.default is None or value is not None:
                                 if isinstance(value, sql.ClauseElement):
                                     value_params[col] = value
@@ -1043,24 +1022,22 @@ class Mapper(object):
                                     params[col.key] = value
                     insert.append((obj, params, mapper, connection, value_params))
                 else:
-                    for col in table.columns:
+                    for col in mapper._cols_by_table[table]:
                         if col is mapper.version_id_col:
-                            params[col._label] = mapper.get_attr_by_column(obj, col)
+                            params[col._label] = mapper._get_attr_by_column(obj, col)
                             params[col.key] = params[col._label] + 1
                             for prop in mapper._columntoproperty.values():
                                 history = attributes.get_history(obj, prop.key, passive=True)
                                 if history and history.added_items():
                                     hasdata = True
                         elif col in pks:
-                            params[col._label] = mapper.get_attr_by_column(obj, col)
+                            params[col._label] = mapper._get_attr_by_column(obj, col)
                         elif mapper.polymorphic_on is not None and mapper.polymorphic_on.shares_lineage(col):
                             pass
                         else:
                             if post_update_cols is not None and col not in post_update_cols:
                                 continue
-                            prop = mapper._getpropbycolumn(col, False)
-                            if prop is None:
-                                continue
+                            prop = mapper._columntoproperty[col]
                             history = attributes.get_history(obj, prop.key, passive=True)
                             if history:
                                 a = history.added_items()
@@ -1076,14 +1053,14 @@ class Mapper(object):
             if update:
                 mapper = table_to_mapper[table]
                 clause = sql.and_()
-                for col in mapper.pks_by_table[table]:
+                for col in mapper._pks_by_table[table]:
                     clause.clauses.append(col == sql.bindparam(col._label, type_=col.type, unique=True))
                 if mapper.version_id_col is not None and table.c.contains_column(mapper.version_id_col):
                     clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col._label, type_=col.type, unique=True))
                 statement = table.update(clause)
                 rows = 0
                 supports_sane_rowcount = True
-                pks = mapper.pks_by_table[table]
+                pks = mapper._pks_by_table[table]
                 def comparator(a, b):
                     for col in pks:
                         x = cmp(a[1][col._label],b[1][col._label])
@@ -1115,23 +1092,22 @@ class Mapper(object):
 
                     if primary_key is not None:
                         i = 0
-                        for col in mapper.pks_by_table[table]:
-                            if mapper.get_attr_by_column(obj, col) is None and len(primary_key) > i:
-                                mapper.set_attr_by_column(obj, col, primary_key[i])
+                        for col in mapper._pks_by_table[table]:
+                            if mapper._get_attr_by_column(obj, col) is None and len(primary_key) > i:
+                                mapper._set_attr_by_column(obj, col, primary_key[i])
                             i+=1
                     mapper._postfetch(connection, table, obj, c, c.last_inserted_params(), value_params)
 
                     # synchronize newly inserted ids from one table to the next
                     # TODO: this fires off more than needed, try to organize syncrules
                     # per table
-                    mappers = list(mapper.iterate_to_root())
-                    mappers.reverse()
-                    for m in mappers:
+                    for m in util.reversed(list(mapper.iterate_to_root())):
                         if m._synchronizer is not None:
                             m._synchronizer.execute(obj, obj)
 
                     # testlib.pragma exempt:__hash__
                     inserted_objects.add((id(obj), obj, connection))
+
         if not postupdate:
             for id_, obj, connection in inserted_objects:
                 for mapper in object_mapper(obj).iterate_to_root():
@@ -1141,7 +1117,7 @@ class Mapper(object):
                 for mapper in object_mapper(obj).iterate_to_root():
                     if 'after_update' in mapper.extension.methods:
                         mapper.extension.after_update(mapper, connection, obj)
-
+    
     def _postfetch(self, connection, table, obj, resultproxy, params, value_params):
         """After an ``INSERT`` or ``UPDATE``, assemble newly generated
         values on an instance.  For columns which are marked as being generated
@@ -1152,20 +1128,15 @@ class Mapper(object):
         postfetch_cols = resultproxy.postfetch_cols().union(util.Set(value_params.keys())) 
         deferred_props = []
 
-        for c in table.c:
+        for c in self._cols_by_table[table]:
             if c in postfetch_cols and (not c.key in params or c in value_params):
-                prop = self._getpropbycolumn(c, raiseerror=False)
-                if prop is None:
-                    continue
+                prop = self._columntoproperty[c]
                 deferred_props.append(prop.key)
                 continue
             if c.primary_key or not c.key in params:
                 continue
-            v = self.get_attr_by_column(obj, c, False)
-            if v is NO_ATTRIBUTE:
-                continue
-            elif v != params[c.key]:
-                self.set_attr_by_column(obj, c, params[c.key])
+            if self._get_attr_by_column(obj, c) != params[c.key]:
+                self._set_attr_by_column(obj, c, params[c.key])
         
         if deferred_props:
             expire_instance(obj, deferred_props)
@@ -1196,13 +1167,13 @@ class Mapper(object):
         table_to_mapper = {}
         for mapper in self.base_mapper.polymorphic_iterator():
             for t in mapper.tables:
-                table_to_mapper.setdefault(t, mapper)
+                table_to_mapper[t] = mapper
 
         for table in sqlutil.sort_tables(table_to_mapper.keys(), reverse=True):
             delete = {}
             for (obj, connection) in tups:
                 mapper = object_mapper(obj)
-                if table not in mapper.tables or not mapper._has_pks(table):
+                if table not in mapper._pks_by_table:
                     continue
 
                 params = {}
@@ -1210,23 +1181,23 @@ class Mapper(object):
                     continue
                 else:
                     delete.setdefault(connection, []).append(params)
-                for col in mapper.pks_by_table[table]:
-                    params[col.key] = mapper.get_attr_by_column(obj, col)
+                for col in mapper._pks_by_table[table]:
+                    params[col.key] = mapper._get_attr_by_column(obj, col)
                 if mapper.version_id_col is not None and table.c.contains_column(mapper.version_id_col):
-                    params[mapper.version_id_col.key] = mapper.get_attr_by_column(obj, mapper.version_id_col)
+                    params[mapper.version_id_col.key] = mapper._get_attr_by_column(obj, mapper.version_id_col)
                 # testlib.pragma exempt:__hash__
                 deleted_objects.add((id(obj), obj, connection))
             for connection, del_objects in delete.iteritems():
                 mapper = table_to_mapper[table]
                 def comparator(a, b):
-                    for col in mapper.pks_by_table[table]:
+                    for col in mapper._pks_by_table[table]:
                         x = cmp(a[col.key],b[col.key])
                         if x != 0:
                             return x
                     return 0
                 del_objects.sort(comparator)
                 clause = sql.and_()
-                for col in mapper.pks_by_table[table]:
+                for col in mapper._pks_by_table[table]:
                     clause.clauses.append(col == sql.bindparam(col.key, type_=col.type, unique=True))
                 if mapper.version_id_col is not None and table.c.contains_column(mapper.version_id_col):
                     clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col.key, type_=mapper.version_id_col.type, unique=True))
@@ -1240,17 +1211,6 @@ class Mapper(object):
                 if 'after_delete' in mapper.extension.methods:
                     mapper.extension.after_delete(mapper, connection, obj)
 
-    def _has_pks(self, table):
-        # TODO: determine this beforehand
-        if self.pks_by_table.get(table, None):
-            for k in self.pks_by_table[table]:
-                if k not in self._columntoproperty:
-                    return False
-            else:
-                return True
-        else:
-            return False
-
     def register_dependencies(self, uowcommit, *args, **kwargs):
         """Register ``DependencyProcessor`` instances with a
         ``unitofwork.UOWTransaction``.
@@ -1263,8 +1223,8 @@ class Mapper(object):
             prop.register_dependencies(uowcommit, *args, **kwargs)
 
     def cascade_iterator(self, type, object, recursive=None, halt_on=None):
-        """Iterate each element in an object graph, for all relations
-        taht meet the given cascade rule.
+        """Iterate each element and its mapper in an object graph, 
+        for all relations that meet the given cascade rule.
 
         type
           The name of the cascade rule (i.e. save-update, delete,
@@ -1282,33 +1242,8 @@ class Mapper(object):
         if recursive is None:
             recursive=util.IdentitySet()
         for prop in self.__props.values():
-            for c in prop.cascade_iterator(type, object, recursive, halt_on=halt_on):
-                yield c
-
-    def cascade_callable(self, type, object, callable_, recursive=None, halt_on=None):
-        """Execute a callable for each element in an object graph, for
-        all relations that meet the given cascade rule.
-
-        type
-          The name of the cascade rule (i.e. save-update, delete, etc.)
-
-        object
-          The lead object instance.  child items will be processed per
-          the relations defined for this object's mapper.
-
-        callable\_
-          The callable function.
-
-        recursive
-          Used by the function for internal context during recursive
-          calls, leave as None.
-          
-        """
-
-        if recursive is None:
-            recursive=util.IdentitySet()
-        for prop in self.__props.values():
-            prop.cascade_callable(type, object, callable_, recursive, halt_on=halt_on)
+            for (c, m) in prop.cascade_iterator(type, object, recursive, halt_on=halt_on):
+                yield (c, m)
 
     def get_select_mapper(self):
         """Return the mapper used for issuing selects.
@@ -1365,8 +1300,8 @@ class Mapper(object):
                 
             isnew = False
 
-            if context.version_check and self.version_id_col is not None and self.get_attr_by_column(instance, self.version_id_col) != row[self.version_id_col]:
-                raise exceptions.ConcurrentModificationError("Instance '%s' version of %s does not match %s" % (instance, self.get_attr_by_column(instance, self.version_id_col), row[self.version_id_col]))
+            if context.version_check and self.version_id_col is not None and self._get_attr_by_column(instance, self.version_id_col) != row[self.version_id_col]:
+                raise exceptions.ConcurrentModificationError("Instance '%s' version of %s does not match %s" % (instance, self._get_attr_by_column(instance, self.version_id_col), row[self.version_id_col]))
 
             if context.populate_existing or self.always_refresh or instance._state.trigger is not None:
                 instance._state.trigger = None
@@ -1541,7 +1476,7 @@ class Mapper(object):
 
             params = {}
             for c in param_names:
-                params[c.name] = self.get_attr_by_column(instance, c)
+                params[c.name] = self._get_attr_by_column(instance, c)
             row = selectcontext.session.connection(self).execute(statement, params).fetchone()
             self.populate_instance(selectcontext, instance, row, isnew=False, instancekey=identitykey, ispostselect=True)
 
index 2c50ec92fca5b3b292fbee7a52363ade6a178072..806c91a2b7cebea3bfd36486c582b284526fc1a7 100644 (file)
@@ -12,13 +12,13 @@ to handle flush-time dependency sorting and processing.
 """
 
 from sqlalchemy import sql, schema, util, exceptions, logging
-from sqlalchemy.sql import util as sql_util, visitors
+from sqlalchemy.sql import util as sql_util, visitors, operators
 from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency
 from sqlalchemy.orm import session as sessionlib
 from sqlalchemy.orm import util as mapperutil
-import operator
 from sqlalchemy.orm.interfaces import StrategizedProperty, PropComparator, MapperProperty
 from sqlalchemy.exceptions import ArgumentError
+import warnings
 
 __all__ = ['ColumnProperty', 'CompositeProperty', 'SynonymProperty', 'PropertyLoader', 'BackRef']
 
@@ -48,7 +48,12 @@ class ColumnProperty(StrategizedProperty):
             return strategies.DeferredColumnLoader(self)
         else:
             return strategies.ColumnLoader(self)
-    
+
+    def do_init(self):
+        super(ColumnProperty, self).do_init()
+        if len(self.columns) > 1 and self.parent.primary_key.issuperset(self.columns):
+            warnings.warn(RuntimeWarning("On mapper %s, primary key column '%s' is being combined with distinct primary key column '%s' in attribute '%s'.  Use explicit properties to give each column its own mapped attribute name." % (str(self.parent), str(self.columns[1]), str(self.columns[0]), self.key)))
+        
     def copy(self):
         return ColumnProperty(deferred=self.deferred, group=self.group, *self.columns)
         
@@ -75,10 +80,8 @@ class ColumnProperty(StrategizedProperty):
             col = self.prop.columns[0]
             return op(col._bind_param(other), col)
 
-            
 ColumnProperty.logger = logging.class_logger(ColumnProperty)
 
-
 class CompositeProperty(ColumnProperty):
     """subclasses ColumnProperty to provide composite type support."""
     
@@ -86,6 +89,10 @@ class CompositeProperty(ColumnProperty):
         super(CompositeProperty, self).__init__(*columns, **kwargs)
         self.composite_class = class_
         self.comparator = kwargs.pop('comparator', CompositeProperty.Comparator)(self)
+
+    def do_init(self):
+        super(ColumnProperty, self).do_init()
+        # TODO: similar PK check as ColumnProperty does ?
         
     def copy(self):
         return CompositeProperty(deferred=self.deferred, group=self.group, composite_class=self.composite_class, *self.columns)
@@ -283,7 +290,7 @@ class PropertyLoader(StrategizedProperty):
             return ~sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))]))
             
     def compare(self, op, value, value_is_parent=False):
-        if op == operator.eq:
+        if op == operators.eq:
             if value is None:
                 return ~sql.exists([1], self.prop.mapper.mapped_table, self.prop.primaryjoin)
             else:
@@ -347,23 +354,9 @@ class PropertyLoader(StrategizedProperty):
                 if not isinstance(c, self.mapper.class_):
                     raise exceptions.AssertionError("Attribute '%s' on class '%s' doesn't handle objects of type '%s'" % (self.key, str(self.parent.class_), str(c.__class__)))
                 recursive.add(c)
-                yield c
-                for c2 in mapper.cascade_iterator(type, c, recursive):
-                    yield c2
-
-    def cascade_callable(self, type, object, callable_, recursive, halt_on=None):
-        if not type in self.cascade:
-            return
-
-        mapper = self.mapper.primary_mapper()
-        passive = type != 'delete' or self.passive_deletes
-        for c in attributes.get_as_list(object, self.key, passive=passive):
-            if c is not None and c not in recursive and (halt_on is None or not halt_on(c)):
-                if not isinstance(c, self.mapper.class_):
-                    raise exceptions.AssertionError("Attribute '%s' on class '%s' doesn't handle objects of type '%s'" % (self.key, str(self.parent.class_), str(c.__class__)))
-                recursive.add(c)
-                callable_(c, mapper.entity_name)
-                mapper.cascade_callable(type, c, callable_, recursive)
+                yield (c, mapper)
+                for (c2, m) in mapper.cascade_iterator(type, c, recursive):
+                    yield (c2, m)
 
     def _get_target_class(self):
         """Return the target class of the relation, even if the
@@ -464,7 +457,7 @@ class PropertyLoader(StrategizedProperty):
         if self.foreign_keys:
             self._opposite_side = util.Set()
             def visit_binary(binary):
-                if binary.operator != operator.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
+                if binary.operator != operators.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
                     return
                 if binary.left in self.foreign_keys:
                     self._opposite_side.add(binary.right)
@@ -477,7 +470,7 @@ class PropertyLoader(StrategizedProperty):
             self.foreign_keys = util.Set()
             self._opposite_side = util.Set()
             def visit_binary(binary):
-                if binary.operator != operator.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
+                if binary.operator != operators.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
                     return
 
                 # this check is for when the user put the "view_only" flag on and has tables that have nothing
index 7097273f59c1326dabd3bfe0536e4618456d04bf..28ef39aba419dca53f09d072f67fc7651bd1fb77 100644 (file)
@@ -1,4 +1,4 @@
-# objectstore.py
+# session.py
 # Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com
 #
 # This module is part of SQLAlchemy and is released under
@@ -103,10 +103,10 @@ class SessionExtension(object):
 
         Note that this may not be per-flush if a longer running transaction is ongoing."""
 
-    def before_flush(self, session, flush_context, objects):
+    def before_flush(self, session, flush_context, instances):
         """execute before flush process has started.
         
-        'objects' is an optional list of objects which were passed to the ``flush()``
+        'instances' is an optional list of objects which were passed to the ``flush()``
         method.
         """
 
@@ -719,7 +719,7 @@ class Session(object):
         entity_name = kwargs.pop('entity_name', None)
         return self.query(class_, entity_name=entity_name).load(ident, **kwargs)
 
-    def refresh(self, obj, attribute_names=None):
+    def refresh(self, instance, attribute_names=None):
         """Refresh the attributes on the given instance.
         
         When called, a query will be issued
@@ -738,12 +738,12 @@ class Session(object):
         refreshed.
         """
 
-        self._validate_persistent(obj)
+        self._validate_persistent(instance)
             
-        if self.query(obj.__class__)._get(obj._instance_key, refresh_instance=obj, only_load_props=attribute_names) is None:
-            raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % mapperutil.instance_str(obj))
+        if self.query(instance.__class__)._get(instance._instance_key, refresh_instance=instance, only_load_props=attribute_names) is None:
+            raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % mapperutil.instance_str(instance))
 
-    def expire(self, obj, attribute_names=None):
+    def expire(self, instance, attribute_names=None):
         """Expire the attributes on the given instance.
         
         The instance's attributes are instrumented such that
@@ -764,11 +764,16 @@ class Session(object):
         """
         
         if attribute_names:
-            self._validate_persistent(obj)
-            expire_instance(obj, attribute_names=attribute_names)
+            self._validate_persistent(instance)
+            expire_instance(instance, attribute_names=attribute_names)
         else:
-            for c in [obj] + list(_object_mapper(obj).cascade_iterator('refresh-expire', obj)):
-                self._validate_persistent(obj)
+            # pre-fetch the full cascade since the expire is going to 
+            # remove associations
+            cascaded = list(_cascade_iterator('refresh-expire', instance))
+            self._validate_persistent(instance)
+            expire_instance(instance, None)
+            for (c, m) in cascaded:
+                self._validate_persistent(c)
                 expire_instance(c, None)
 
     def prune(self):
@@ -784,20 +789,20 @@ class Session(object):
 
         return self.uow.prune_identity_map()
 
-    def expunge(self, object):
-        """Remove the given `object` from this ``Session``.
+    def expunge(self, instance):
+        """Remove the given `instance` from this ``Session``.
 
-        This will free all internal references to the object.
+        This will free all internal references to the instance.
         Cascading will be applied according to the *expunge* cascade
         rule.
         """
-        self._validate_persistent(object)
-        for c in [object] + list(_object_mapper(object).cascade_iterator('expunge', object)):
+        self._validate_persistent(instance)
+        for c, m in [(instance, None)] + list(_cascade_iterator('expunge', instance)):
             if c in self:
                 self.uow._remove_deleted(c)
                 self._unattach(c)
 
-    def save(self, object, entity_name=None):
+    def save(self, instance, entity_name=None):
         """Add a transient (unsaved) instance to this ``Session``.
 
         This operation cascades the `save_or_update` method to
@@ -808,12 +813,10 @@ class Session(object):
         specific ``Mapper`` used to handle this instance.
         """
 
-        self._save_impl(object, entity_name=entity_name)
-        _object_mapper(object).cascade_callable('save-update', object,
-                                                lambda c, e:self._save_or_update_impl(c, e),
-                                                halt_on=lambda c:c in self)
+        self._save_impl(instance, entity_name=entity_name)
+        self._cascade_save_or_update(instance)
 
-    def update(self, object, entity_name=None):
+    def update(self, instance, entity_name=None):
         """Bring the given detached (saved) instance into this
         ``Session``.
 
@@ -826,37 +829,37 @@ class Session(object):
         ``cascade="save-update"``.
         """
 
-        self._update_impl(object, entity_name=entity_name)
-        _object_mapper(object).cascade_callable('save-update', object,
-                                                lambda c, e:self._save_or_update_impl(c, e),
-                                                halt_on=lambda c:c in self)
+        self._update_impl(instance, entity_name=entity_name)
+        self._cascade_save_or_update(instance)
 
-    def save_or_update(self, object, entity_name=None):
-        """Save or update the given object into this ``Session``.
+    def save_or_update(self, instance, entity_name=None):
+        """Save or update the given instance into this ``Session``.
 
         The presence of an `_instance_key` attribute on the instance
         determines whether to ``save()`` or ``update()`` the instance.
         """
 
-        self._save_or_update_impl(object, entity_name=entity_name)
-        _object_mapper(object).cascade_callable('save-update', object,
-                                                lambda c, e:self._save_or_update_impl(c, e),
-                                                halt_on=lambda c:c in self)
+        self._save_or_update_impl(instance, entity_name=entity_name)
+        self._cascade_save_or_update(instance)
+    
+    def _cascade_save_or_update(self, instance):
+        for obj, mapper in _cascade_iterator('save-update', instance, halt_on=lambda c:c in self):
+            self._save_or_update_impl(obj, mapper.entity_name)
 
-    def delete(self, object):
+    def delete(self, instance):
         """Mark the given instance as deleted.
 
         The delete operation occurs upon ``flush()``.
         """
 
-        self._delete_impl(object)
-        for c in list(_object_mapper(object).cascade_iterator('delete', object)):
+        self._delete_impl(instance)
+        for c, m in _cascade_iterator('delete', instance):
             self._delete_impl(c, ignore_transient=True)
 
 
-    def merge(self, object, entity_name=None, dont_load=False, _recursive=None):
-        """Copy the state of the given `object` onto the persistent
-        object with the same identifier.
+    def merge(self, instance, entity_name=None, dont_load=False, _recursive=None):
+        """Copy the state of the given `instance` onto the persistent
+        instance with the same identifier.
 
         If there is no persistent instance currently associated with
         the session, it will be loaded.  Return the persistent
@@ -871,20 +874,20 @@ class Session(object):
         if _recursive is None:
             _recursive = {}  #TODO: this should be an IdentityDict
         if entity_name is not None:
-            mapper = _class_mapper(object.__class__, entity_name=entity_name)
+            mapper = _class_mapper(instance.__class__, entity_name=entity_name)
         else:
-            mapper = _object_mapper(object)
-        if object in _recursive:
-            return _recursive[object]
+            mapper = _object_mapper(instance)
+        if instance in _recursive:
+            return _recursive[instance]
         
-        key = getattr(object, '_instance_key', None)
+        key = getattr(instance, '_instance_key', None)
         if key is None:
             merged = attributes.new_instance(mapper.class_)
         else:
             if key in self.identity_map:
                 merged = self.identity_map[key]
             elif dont_load:
-                if object._state.modified:
+                if instance._state.modified:
                     raise exceptions.InvalidRequestError("merge() with dont_load=True option does not support objects marked as 'dirty'.  flush() all changes on mapped instances before merging with dont_load=True.")
                     
                 merged = attributes.new_instance(mapper.class_)
@@ -894,10 +897,10 @@ class Session(object):
             else:
                 merged = self.get(mapper.class_, key[1])
                 if merged is None:
-                    raise exceptions.AssertionError("Instance %s has an instance key but is not persisted" % mapperutil.instance_str(object))
-        _recursive[object] = merged
+                    raise exceptions.AssertionError("Instance %s has an instance key but is not persisted" % mapperutil.instance_str(instance))
+        _recursive[instance] = merged
         for prop in mapper.iterate_properties:
-            prop.merge(self, object, merged, dont_load, _recursive)
+            prop.merge(self, instance, merged, dont_load, _recursive)
         if key is None:
             self.save(merged, entity_name=mapper.entity_name)
         elif dont_load:
@@ -968,96 +971,96 @@ class Session(object):
         return mapper.identity_key_from_instance(instance)
     identity_key = classmethod(identity_key)
     
-    def object_session(cls, obj):
+    def object_session(cls, instance):
         """return the ``Session`` to which the given object belongs."""
         
-        return object_session(obj)
+        return object_session(instance)
     object_session = classmethod(object_session)
     
-    def _save_impl(self, obj, **kwargs):
-        if hasattr(obj, '_instance_key'):
-            raise exceptions.InvalidRequestError("Instance '%s' is already persistent" % mapperutil.instance_str(obj))
+    def _save_impl(self, instance, **kwargs):
+        if hasattr(instance, '_instance_key'):
+            raise exceptions.InvalidRequestError("Instance '%s' is already persistent" % mapperutil.instance_str(instance))
         else:
             # TODO: consolidate the steps here
-            attributes.manage(obj)
-            obj._entity_name = kwargs.get('entity_name', None)
-            self._attach(obj)
-            self.uow.register_new(obj)
+            attributes.manage(instance)
+            instance._entity_name = kwargs.get('entity_name', None)
+            self._attach(instance)
+            self.uow.register_new(instance)
 
-    def _update_impl(self, obj, **kwargs):
-        if obj in self and obj not in self.deleted:
+    def _update_impl(self, instance, **kwargs):
+        if instance in self and instance not in self.deleted:
             return
-        if not hasattr(obj, '_instance_key'):
-            raise exceptions.InvalidRequestError("Instance '%s' is not persisted" % mapperutil.instance_str(obj))
-        elif self.identity_map.get(obj._instance_key, obj) is not obj:
+        if not hasattr(instance, '_instance_key'):
+            raise exceptions.InvalidRequestError("Instance '%s' is not persisted" % mapperutil.instance_str(instance))
+        elif self.identity_map.get(instance._instance_key, instance) is not instance:
             raise exceptions.InvalidRequestError("Could not update instance '%s', identity key %s; a different instance with the same identity key already exists in this session." % (mapperutil.instance_str(obj), obj._instance_key))
-        self._attach(obj)
+        self._attach(instance)
 
-    def _save_or_update_impl(self, object, entity_name=None):
-        key = getattr(object, '_instance_key', None)
+    def _save_or_update_impl(self, instance, entity_name=None):
+        key = getattr(instance, '_instance_key', None)
         if key is None:
-            self._save_impl(object, entity_name=entity_name)
+            self._save_impl(instance, entity_name=entity_name)
         else:
-            self._update_impl(object, entity_name=entity_name)
+            self._update_impl(instance, entity_name=entity_name)
 
-    def _delete_impl(self, obj, ignore_transient=False):
-        if obj in self and obj in self.deleted:
+    def _delete_impl(self, instance, ignore_transient=False):
+        if instance in self and instance in self.deleted:
             return
-        if not hasattr(obj, '_instance_key'):
+        if not hasattr(instance, '_instance_key'):
             if ignore_transient:
                 return
             else:
-                raise exceptions.InvalidRequestError("Instance '%s' is not persisted" % mapperutil.instance_str(obj))
-        if self.identity_map.get(obj._instance_key, obj) is not obj:
-            raise exceptions.InvalidRequestError("Instance '%s' is with key %s already persisted with a different identity" % (mapperutil.instance_str(obj), obj._instance_key))
-        self._attach(obj)
-        self.uow.register_deleted(obj)
-
-    def _register_persistent(self, obj):
-        obj._sa_session_id = self.hash_key
-        self.identity_map[obj._instance_key] = obj
-        obj._state.commit_all()
-
-    def _attach(self, obj):
-        old_id = getattr(obj, '_sa_session_id', None)
+                raise exceptions.InvalidRequestError("Instance '%s' is not persisted" % mapperutil.instance_str(instance))
+        if self.identity_map.get(instance._instance_key, instance) is not instance:
+            raise exceptions.InvalidRequestError("Instance '%s' is with key %s already persisted with a different identity" % (mapperutil.instance_str(instance), instance._instance_key))
+        self._attach(instance)
+        self.uow.register_deleted(instance)
+
+    def _register_persistent(self, instance):
+        instance._sa_session_id = self.hash_key
+        self.identity_map[instance._instance_key] = instance
+        instance._state.commit_all()
+
+    def _attach(self, instance):
+        old_id = getattr(instance, '_sa_session_id', None)
         if old_id != self.hash_key:
-            if old_id is not None and old_id in _sessions and obj in _sessions[old_id]:
+            if old_id is not None and old_id in _sessions and instance in _sessions[old_id]:
                 raise exceptions.InvalidRequestError("Object '%s' is already attached "
                                                      "to session '%s' (this is '%s')" %
-                                                     (mapperutil.instance_str(obj), old_id, id(self)))
+                                                     (mapperutil.instance_str(instance), old_id, id(self)))
 
-            key = getattr(obj, '_instance_key', None)
+            key = getattr(instance, '_instance_key', None)
             if key is not None:
-                self.identity_map[key] = obj
-            obj._sa_session_id = self.hash_key
+                self.identity_map[key] = instance
+            instance._sa_session_id = self.hash_key
         
-    def _unattach(self, obj):
-        if obj._sa_session_id == self.hash_key:
-            del obj._sa_session_id
+    def _unattach(self, instance):
+        if instance._sa_session_id == self.hash_key:
+            del instance._sa_session_id
 
-    def _validate_persistent(self, obj):
-        """Validate that the given object is persistent within this
+    def _validate_persistent(self, instance):
+        """Validate that the given instance is persistent within this
         ``Session``.
         """
         
-        return obj in self
+        return instance in self
 
-    def __contains__(self, obj):
-        """return True if the given object is associated with this session.
+    def __contains__(self, instance):
+        """return True if the given instance is associated with this session.
         
         The instance may be pending or persistent within the Session for a
         result of True.
         """
         
-        return obj in self.uow.new or (hasattr(obj, '_instance_key') and self.identity_map.get(obj._instance_key) is obj)
+        return instance in self.uow.new or (hasattr(instance, '_instance_key') and self.identity_map.get(instance._instance_key) is instance)
 
     def __iter__(self):
-        """return an iterator of all objects which are pending or persistent within this Session."""
+        """return an iterator of all instances which are pending or persistent within this Session."""
         
         return iter(list(self.uow.new) + self.uow.identity_map.values())
 
-    def is_modified(self, obj, include_collections=True, passive=False):
-        """return True if the given object has modified attributes.
+    def is_modified(self, instance, include_collections=True, passive=False):
+        """return True if the given instance has modified attributes.
         
         This method retrieves a history instance for each instrumented attribute
         on the instance and performs a comparison of the current value to its
@@ -1073,15 +1076,15 @@ class Session(object):
         not be loaded in the course of performing this test.
         """
 
-        for attr in attributes.managed_attributes(obj.__class__):
+        for attr in attributes.managed_attributes(instance.__class__):
             if not include_collections and hasattr(attr.impl, 'get_collection'):
                 continue
-            if attr.get_history(obj).is_modified():
+            if attr.get_history(instance).is_modified():
                 return True
         return False
         
     dirty = property(lambda s:s.uow.locate_dirty(),
-                     doc="""A ``Set`` of all objects marked as 'dirty' within this ``Session``.  
+                     doc="""A ``Set`` of all instances marked as 'dirty' within this ``Session``.  
                      
                      Note that the 'dirty' state here is 'optimistic'; most attribute-setting or collection
                      modification operations will mark an instance as 'dirty' and place it in this set,
@@ -1095,12 +1098,12 @@ class Session(object):
                      """)
 
     deleted = property(lambda s:s.uow.deleted,
-                       doc="A ``Set`` of all objects marked as 'deleted' within this ``Session``")
+                       doc="A ``Set`` of all instances marked as 'deleted' within this ``Session``")
 
     new = property(lambda s:s.uow.new,
-                   doc="A ``Set`` of all objects marked as 'new' within this ``Session``.")
+                   doc="A ``Set`` of all instances marked as 'new' within this ``Session``.")
 
-def expire_instance(obj, attribute_names):
+def expire_instance(instance, attribute_names):
     """standalone expire instance function. 
     
     installs a callable with the given instance's _state
@@ -1110,29 +1113,30 @@ def expire_instance(obj, attribute_names):
     If the list is None or blank, the entire instance is expired.
     """
     
-    if obj._state.trigger is None:
+    if instance._state.trigger is None:
         def load_attributes(instance, attribute_names):
             if object_session(instance).query(instance.__class__)._get(instance._instance_key, refresh_instance=instance, only_load_props=attribute_names) is None:
                 raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % mapperutil.instance_str(instance))
-        obj._state.trigger = load_attributes
+        instance._state.trigger = load_attributes
         
-    obj._state.expire_attributes(attribute_names)
+    instance._state.expire_attributes(attribute_names)
     
 register_attribute = unitofwork.register_attribute
 
-# this dictionary maps the hash key of a Session to the Session itself, and
-# acts as a Registry with which to locate Sessions.  this is to enable
-# object instances to be associated with Sessions without having to attach the
-# actual Session object directly to the object instance.
 _sessions = weakref.WeakValueDictionary()
 
-def object_session(obj):
-    """Return the ``Session`` to which the given object is bound, or ``None`` if none."""
+def _cascade_iterator(cascade, instance, **kwargs):
+    mapper = _object_mapper(instance)
+    for (o, m) in mapper.cascade_iterator(cascade, instance, **kwargs):
+        yield o, m
+
+def object_session(instance):
+    """Return the ``Session`` to which the given instance is bound, or ``None`` if none."""
 
-    hashkey = getattr(obj, '_sa_session_id', None)
+    hashkey = getattr(instance, '_sa_session_id', None)
     if hashkey is not None:
         sess = _sessions.get(hashkey)
-        if obj in sess:
+        if instance in sess:
             return sess
     return None
 
index 096a42bb71dffd5ba844137024127d061221dfca..3c647ac60466732e6a787edd29e53f9e56a9aa84 100644 (file)
@@ -109,7 +109,7 @@ class ColumnLoader(LoaderStrategy):
             def create_statement(instance):
                 params = {}
                 for c in param_names:
-                    params[c.name] = mapper.get_attr_by_column(instance, c)
+                    params[c.name] = mapper._get_attr_by_column(instance, c)
                 return (statement, params)
             
             def new_execute(instance, row, isnew, **flags):
@@ -301,7 +301,7 @@ class LazyLoader(AbstractRelationLoader):
             def visit_bindparam(s, bindparam):
                 mapper = reverse_direction and self.parent_property.mapper or self.parent_property.parent
                 if bindparam.key in bind_to_col:
-                    bindparam.value = mapper.get_attr_by_column(instance, bind_to_col[bindparam.key])
+                    bindparam.value = mapper._get_attr_by_column(instance, bind_to_col[bindparam.key])
         return Visitor().traverse(criterion, clone=True)
     
     def setup_loader(self, instance, options=None, path=None):
@@ -338,7 +338,7 @@ class LazyLoader(AbstractRelationLoader):
             if self.use_get:
                 params = {}
                 for col, bind in self.lazybinds.iteritems():
-                    params[bind.key] = self.parent.get_attr_by_column(instance, col)
+                    params[bind.key] = self.parent._get_attr_by_column(instance, col)
                 ident = []
                 nonnulls = False
                 for primary_key in self.select_mapper.primary_key: 
index 9575aa958feb476dcabc6e5715bfe9e5bc222086..8132c7e4a5df273db834620928add7cf2d43b542 100644 (file)
@@ -115,10 +115,12 @@ class SyncRule(object):
         #print "SyncRule", source_mapper, source_column, dest_column, dest_mapper
 
     def dest_primary_key(self):
+        # late-evaluating boolean since some syncs are created
+        # before the mapper has assembled pks
         try:
             return self._dest_primary_key
         except AttributeError:
-            self._dest_primary_key = self.dest_mapper is not None and self.dest_column in self.dest_mapper.pks_by_table[self.dest_column.table] and not self.dest_mapper.allow_null_pks
+            self._dest_primary_key = self.dest_mapper is not None and self.dest_column in self.dest_mapper._pks_by_table[self.dest_column.table] and not self.dest_mapper.allow_null_pks
             return self._dest_primary_key
 
     def execute(self, source, dest, obj, child, clearkeys):
@@ -131,7 +133,7 @@ class SyncRule(object):
             value = None
             clearkeys = True
         else:
-            value = self.source_mapper.get_attr_by_column(source, self.source_column)
+            value = self.source_mapper._get_attr_by_column(source, self.source_column)
         if isinstance(dest, dict):
             dest[self.dest_column.key] = value
         else:
@@ -140,7 +142,7 @@ class SyncRule(object):
 
             if logging.is_debug_enabled(self.logger):
                 self.logger.debug("execute() instances: %s(%s)->%s(%s) ('%s')" % (mapperutil.instance_str(source), str(self.source_column), mapperutil.instance_str(dest), str(self.dest_column), value))
-            self.dest_mapper.set_attr_by_column(dest, self.dest_column, value)
+            self.dest_mapper._set_attr_by_column(dest, self.dest_column, value)
 
 SyncRule.logger = logging.class_logger(SyncRule)
 
index d3e89d57e65fc7b80144d674151e68586571283b..1cf0cb1b05f17f1cfec2b41733ffb86cb1b45748 100644 (file)
@@ -18,8 +18,9 @@ def sort_tables(tables, reverse=False):
         vis.traverse(table)
     sequence = topological.QueueDependencySorter( tuples, tables).sort(create_tree=False)
     if reverse:
-        sequence.reverse()
-    return sequence
+        return util.reversed(sequence)
+    else:
+        return sequence
 
 def find_tables(clause, check_columns=False, include_aliases=False):
     tables = []
index 0301d01c99af69cf5b508d09b197389cc10c6674..5affaa238f2e0a3890349eaae9556163e8d7722a 100644 (file)
@@ -489,7 +489,7 @@ class DistinctPKTest(ORMTest):
             self._do_test(True)
             assert False
         except RuntimeWarning, e:
-            assert str(e) == "On mapper Mapper|Employee|employees, primary key column 'employees.id' is being combined with distinct primary key column 'persons.id' in attribute 'id'.  Use explicit properties to give each column its own mapped attribute name."
+            assert str(e) == "On mapper Mapper|Employee|employees, primary key column 'employees.id' is being combined with distinct primary key column 'persons.id' in attribute 'id'.  Use explicit properties to give each column its own mapped attribute name.", str(e)
 
     def test_explicit_pk(self):
         person_mapper = mapper(Person, person_table)
index 5608ae67deefd0acd298285a0e0a4f982e673c66..a94e9bbc444fbae802c6a145515d974c7baba94a 100644 (file)
@@ -79,7 +79,11 @@ class M2MTest(ORMTest):
             compile_mappers()
             assert False
         except exceptions.ArgumentError, e:
-            assert str(e) == "Error creating backref 'transitions' on relation 'Transition.places (Place)': property of that name exists on mapper 'Mapper|Place|place'"
+            assert str(e) in [
+                "Error creating backref 'transitions' on relation 'Transition.places (Place)': property of that name exists on mapper 'Mapper|Place|place'",
+                "Error creating backref 'places' on relation 'Place.transitions (Transition)': property of that name exists on mapper 'Mapper|Transition|transition'"
+            ]
+            
         
     def testcircular(self):
         """tests a many-to-many relationship from a table to itself."""
index 9cd07b8fee75d0da96fb8068d40f64f484633248..f0d553630213279aa83b4955547ec2a76e2cd035 100644 (file)
@@ -269,8 +269,8 @@ class MapperTest(MapperSuperTest):
         class A(object):pass
         m = mapper(A, account_ids_table.join(account_stuff_table))
         m.compile()
-        assert m._has_pks(account_ids_table)
-        assert not m._has_pks(account_stuff_table)
+        assert account_ids_table in m._pks_by_table
+        assert account_stuff_table not in m._pks_by_table
         metadata.create_all(testbase.db)
         try:
             sess = create_session(bind=testbase.db)