]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- session.execute() will execute a Sequence object passed to
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 28 Sep 2008 19:10:22 +0000 (19:10 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 28 Sep 2008 19:10:22 +0000 (19:10 +0000)
  it (regression from 0.4).
- Removed the "raiseerror" keyword argument from object_mapper()
  and class_mapper().  These functions raise in all cases
  if the given class/instance is not mapped.
- Refined ExtensionCarrier to be itself a dict, removed
'methods' accessor
- moved identity_key tests to test/orm/utils.py
- some docstrings

CHANGES
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/scoping.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/orm/util.py
lib/sqlalchemy/sql/expression.py
test/orm/session.py
test/orm/utils.py

diff --git a/CHANGES b/CHANGES
index 3686034bcd8951a090c0a8b8a23f976c48faa8ca..3f981b927df588a9baa8ec46c00a9e2864c9eeed 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -31,6 +31,15 @@ CHANGES
       for zero length slices, slices with None on either end.
       [ticket:1177]
 
+    - Added an example illustrating Celko's "nested sets" as a 
+      SQLA mapping.
+      
+    - session.execute() will execute a Sequence object passed to
+      it (regression from 0.4).
+    
+    - Removed the "raiseerror" keyword argument from object_mapper()
+      and class_mapper().  These functions raise in all cases
+      if the given class/instance is not mapped.
 - sql
     - column.in_(someselect) can now be used as a columns-clause
       expression without the subquery bleeding into the FROM clause
index 886fcec91e85efc2e064d24c737a923661fc44b4..dd30bdda7a1a30cafd5b76b56d7960b1d90175ab 100644 (file)
@@ -630,9 +630,9 @@ class Mapper(object):
             class_mapper(cls)
 
             if cls.__dict__.get(clskey) is self:
-                # FIXME: there should not be any scenarios where
-                # a mapper compile leaves this CompileOnAttr in
-                # place.
+                # if this warning occurs, it usually means mapper
+                # compilation has failed, but operations upon the mapped
+                # classes have proceeded.
                 util.warn(
                     ("Attribute '%s' on class '%s' was not replaced during "
                      "mapper compilation operation") % (clskey, cls.__name__))
@@ -755,11 +755,18 @@ class Mapper(object):
                 for c in columns:
                     mc = self.mapped_table.corresponding_column(c)
                     if not mc:
-                        raise sa_exc.ArgumentError("Column '%s' is not represented in mapper's table.  Use the `column_property()` function to force this column to be mapped as a read-only attribute." % str(c))
+                        raise sa_exc.ArgumentError("Column '%s' is not represented in mapper's table.  "
+                            "Use the `column_property()` function to force this column "
+                            "to be mapped as a read-only attribute." % c)
                     mapped_column.append(mc)
                 prop = ColumnProperty(*mapped_column)
             else:
-                raise sa_exc.ArgumentError("WARNING: column '%s' conflicts with property '%s'.  To resolve this, map the column to the class under a different name in the 'properties' dictionary.  Or, to remove all awareness of the column entirely (including its availability as a foreign key), use the 'include_properties' or 'exclude_properties' mapper arguments to control specifically which table columns get mapped." % (column.key, repr(prop)))
+                raise sa_exc.ArgumentError("WARNING: column '%s' conflicts with property '%r'.  "
+                    "To resolve this, map the column to the class under a different "
+                    "name in the 'properties' dictionary.  Or, to remove all awareness "
+                    "of the column entirely (including its availability as a foreign key), "
+                    "use the 'include_properties' or 'exclude_properties' mapper arguments "
+                    "to control specifically which table columns get mapped." % (column.key, prop))
 
         if isinstance(prop, ColumnProperty):
             col = self.mapped_table.corresponding_column(prop.columns[0])
@@ -879,7 +886,7 @@ class Mapper(object):
                     for name in method.__sa_validators__:
                         self._validators[name] = method
 
-        if 'reconstruct_instance' in self.extension.methods:
+        if 'reconstruct_instance' in self.extension:
             def reconstruct(instance):
                 self.extension.reconstruct_instance(self, instance)
             event_registry.add_listener('on_load', reconstruct)
@@ -1067,10 +1074,10 @@ class Mapper(object):
             # call before_XXX extensions
             for state, mapper, connection, has_identity in tups:
                 if not has_identity:
-                    if 'before_insert' in mapper.extension.methods:
+                    if 'before_insert' in mapper.extension:
                         mapper.extension.before_insert(mapper, connection, state.obj())
                 else:
-                    if 'before_update' in mapper.extension.methods:
+                    if 'before_update' in mapper.extension:
                         mapper.extension.before_update(mapper, connection, state.obj())
 
         for state, mapper, connection, has_identity in tups:
@@ -1237,10 +1244,10 @@ class Mapper(object):
 
                 # call after_XXX extensions
                 if not has_identity:
-                    if 'after_insert' in mapper.extension.methods:
+                    if 'after_insert' in mapper.extension:
                         mapper.extension.after_insert(mapper, connection, state.obj())
                 else:
-                    if 'after_update' in mapper.extension.methods:
+                    if 'after_update' in mapper.extension:
                         mapper.extension.after_update(mapper, connection, state.obj())
 
     def _postfetch(self, uowtransaction, connection, table, state, resultproxy, params, value_params):
@@ -1289,7 +1296,7 @@ class Mapper(object):
             tups = [(state, _state_mapper(state), connection) for state in _sort_states(states)]
 
         for state, mapper, connection in tups:
-            if 'before_delete' in mapper.extension.methods:
+            if 'before_delete' in mapper.extension:
                 mapper.extension.before_delete(mapper, connection, state.obj())
 
         table_to_mapper = {}
@@ -1326,7 +1333,7 @@ class Mapper(object):
                     raise exc.ConcurrentModificationError("Deleted rowcount %d does not match number of objects deleted %d" % (c.rowcount, len(del_objects)))
 
         for state, mapper, connection in tups:
-            if 'after_delete' in mapper.extension.methods:
+            if 'after_delete' in mapper.extension:
                 mapper.extension.after_delete(mapper, connection, state.obj())
 
     def _register_dependencies(self, uowcommit):
@@ -1419,15 +1426,15 @@ class Mapper(object):
         if not extension:
             extension = self.extension
 
-        translate_row = 'translate_row' in extension.methods
-        create_instance = 'create_instance' in extension.methods
-        populate_instance = 'populate_instance' in extension.methods
-        append_result = 'append_result' in extension.methods
+        translate_row = extension.get('translate_row', None)
+        create_instance = extension.get('create_instance', None)
+        populate_instance = extension.get('populate_instance', None)
+        append_result = extension.get('append_result', None)
         populate_existing = context.populate_existing or self.always_refresh
 
         def _instance(row, result):
             if translate_row:
-                ret = extension.translate_row(self, context, row)
+                ret = translate_row(self, context, row)
                 if ret is not EXT_CONTINUE:
                     row = ret
 
@@ -1489,7 +1496,7 @@ class Mapper(object):
                 loaded_instance = True
 
                 if create_instance:
-                    instance = extension.create_instance(self, context, row, self.class_)
+                    instance = create_instance(self, context, row, self.class_)
                     if instance is EXT_CONTINUE:
                         instance = self.class_manager.new_instance()
                     else:
@@ -1517,7 +1524,7 @@ class Mapper(object):
                     state.runid = context.runid
                     context.progress.add(state)
 
-                if not populate_instance or extension.populate_instance(self, context, row, instance, only_load_props=only_load_props, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
+                if not populate_instance or populate_instance(self, context, row, instance, only_load_props=only_load_props, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
                     populate_state(state, row, isnew, only_load_props)
 
             else:
@@ -1533,13 +1540,13 @@ class Mapper(object):
                         attrs = state.unloaded
                         context.partials[state] = attrs  #<-- allow query.instances to commit the subset of attrs
 
-                    if not populate_instance or extension.populate_instance(self, context, row, instance, only_load_props=attrs, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
+                    if not populate_instance or populate_instance(self, context, row, instance, only_load_props=attrs, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
                         populate_state(state, row, isnew, attrs, instancekey=identitykey)
 
             if loaded_instance:
                 state._run_on_load(instance)
 
-            if result is not None and (not append_result or extension.append_result(self, context, row, instance, result, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE):
+            if result is not None and (not append_result or append_result(self, context, row, instance, result, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE):
                 result.append(instance)
 
             return instance
@@ -1651,7 +1658,7 @@ def _event_on_init(state, instance, args, kwargs):
     instrumenting_mapper = state.manager.info[_INSTRUMENTOR]
     # compile() always compiles all mappers
     instrumenting_mapper.compile()
-    if 'init_instance' in instrumenting_mapper.extension.methods:
+    if 'init_instance' in instrumenting_mapper.extension:
         instrumenting_mapper.extension.init_instance(
             instrumenting_mapper, instrumenting_mapper.class_,
             state.manager.events.original_init,
@@ -1661,7 +1668,7 @@ def _event_on_init_failure(state, instance, args, kwargs):
     """Run init_failed hooks."""
 
     instrumenting_mapper = state.manager.info[_INSTRUMENTOR]
-    if 'init_failed' in instrumenting_mapper.extension.methods:
+    if 'init_failed' in instrumenting_mapper.extension:
         util.warn_exception(
             instrumenting_mapper.extension.init_failed,
             instrumenting_mapper, instrumenting_mapper.class_,
index e25716316a85b960c14ca914da0d777c901539ba..4b6aa0018be3a29727e01292fbfcd47fde27996a 100644 (file)
@@ -1109,7 +1109,7 @@ class Query(object):
         else:
             filter = None
 
-        custom_rows = single_entity and 'append_result' in self._entities[0].extension.methods
+        custom_rows = single_entity and 'append_result' in self._entities[0].extension
 
         (process, labels) = zip(*[query_entity.row_processor(self, context, custom_rows) for query_entity in self._entities])
 
@@ -1433,7 +1433,6 @@ class Query(object):
 
         return result.rowcount
 
-
     def _compile_context(self, labels=True):
         context = QueryContext(self)
 
@@ -1563,9 +1562,6 @@ class Query(object):
                 crit = self._adapt_clause(crit, False, False)
                 context.whereclause = sql.and_(context.whereclause, crit)
 
-    def __log_debug(self, msg):
-        self.logger.debug(msg)
-
     def __str__(self):
         return str(self._compile_context().statement)
 
index 6ea28148d92b252ba12253d27a4803d21af1aef5..6e33be96bf929f02b632074f26158a8b9f6d58cd 100644 (file)
@@ -7,8 +7,9 @@
 import sqlalchemy.exceptions as sa_exc
 from sqlalchemy.util import ScopedRegistry, to_list, get_cls_kwargs
 from sqlalchemy.orm import (
-    EXT_CONTINUE, MapperExtension, class_mapper, object_session,
+    EXT_CONTINUE, MapperExtension, class_mapper, object_session
     )
+from sqlalchemy.orm import exc as orm_exc
 from sqlalchemy.orm.session import Session
 
 
@@ -102,15 +103,16 @@ class ScopedSession(object):
         """
         class query(object):
             def __get__(s, instance, owner):
-                mapper = class_mapper(owner, raiseerror=False)
-                if mapper:
-                    if query_cls:
-                        # custom query class
-                        return query_cls(mapper, session=self.registry())
-                    else:
-                        # session's configured query class
-                        return self.registry().query(mapper)
-                else:
+                try:
+                    mapper = class_mapper(owner)
+                    if mapper:
+                        if query_cls:
+                            # custom query class
+                            return query_cls(mapper, session=self.registry())
+                        else:
+                            # session's configured query class
+                            return self.registry().query(mapper)
+                except orm_exc.UnmappedClassError:
                     return None
         return query()
 
index ad987430a57c8f7925d02e69bcece337c80d63f0..0ca59141e069993870c408aa6779ca00bc91b8a7 100644 (file)
@@ -1535,18 +1535,7 @@ class Session(object):
 
         return util.IdentitySet(self._new.values())
 
-def _expire_state(state, attribute_names):
-    """Stand-alone expire instance function.
-
-    Installs a callable with the given instance's _state which will fire off
-    when any of the named attributes are accessed; their existing value is
-    removed.
-
-    If the list is None or blank, the entire instance is expired.
-
-    """
-    state.expire_attributes(attribute_names)
-
+_expire_state = attributes.InstanceState.expire_attributes
 register_attribute = unitofwork.register_attribute
 
 _sessions = weakref.WeakValueDictionary()
@@ -1591,6 +1580,7 @@ def _state_for_unknown_persistence_instance(instance):
 
 def object_session(instance):
     """Return the ``Session`` to which instance belongs, or None."""
+
     return _state_session(attributes.instance_state(instance))
 
 def _state_session(state):
index c1d93153e861bd8691466dc8fcd0a78b8db4b5e1..308266cd8319978169b5a630914c1e7187674f00 100644 (file)
@@ -20,7 +20,8 @@ from sqlalchemy.orm import util as mapperutil
 
 
 class DefaultColumnLoader(LoaderStrategy):
-    def _register_attribute(self, compare_function, copy_function, mutable_scalars, comparator_factory, callable_=None, proxy_property=None, active_history=False):
+    def _register_attribute(self, compare_function, copy_function, mutable_scalars, 
+            comparator_factory, callable_=None, proxy_property=None, active_history=False):
         self.logger.info("%s register managed attribute" % self)
 
         attribute_ext = util.to_list(self.parent_property.extension) or []
index 453b9f510d7fca13bf8cad216ebc890dce58506d..f461c4bdcb0e90022add4aaa6289ab69816b4e53 100644 (file)
@@ -170,7 +170,7 @@ def identity_key(*args, **kwargs):
     mapper = object_mapper(instance)
     return mapper.identity_key_from_instance(instance)
     
-class ExtensionCarrier(object):
+class ExtensionCarrier(dict):
     """Fronts an ordered collection of MapperExtension objects.
 
     Bundles multiple MapperExtensions into a unified callable unit,
@@ -179,9 +179,9 @@ class ExtensionCarrier(object):
 
       carrier.after_insert(...args...)
 
-    Also includes a 'methods' dictionary accessor which allows for a quick
-    check if a particular method is overridden on any contained
-    MapperExtensions.
+    The dictionary interface provides containment for implemented
+    method names mapped to a callable which executes that method
+    for participating extensions.
 
     """
 
@@ -189,7 +189,6 @@ class ExtensionCarrier(object):
                     if not method.startswith('_'))
 
     def __init__(self, extensions=None):
-        self.methods = {}
         self._extensions = []
         for ext in extensions or ():
             self.append(ext)
@@ -213,12 +212,11 @@ class ExtensionCarrier(object):
 
     def _register(self, extension):
         """Register callable fronts for overridden interface methods."""
-        for method in self.interface:
-            if method in self.methods:
-                continue
+        
+        for method in self.interface.difference(self):
             impl = getattr(extension, method, None)
             if impl and impl is not getattr(MapperExtension, method):
-                self.methods[method] = self._create_do(method)
+                self[method] = self._create_do(method)
 
     def _create_do(self, method):
         """Return a closure that loops over impls of the named method."""
@@ -230,10 +228,7 @@ class ExtensionCarrier(object):
                     return ret
             else:
                 return EXT_CONTINUE
-        try:
-            _do.__name__ = method.im_func.func_name
-        except:
-            pass
+        _do.__name__ = method
         return _do
 
     @staticmethod
@@ -242,9 +237,10 @@ class ExtensionCarrier(object):
 
     def __getattr__(self, key):
         """Delegate MapperExtension methods to bundled fronts."""
+        
         if key not in self.interface:
             raise AttributeError(key)
-        return self.methods.get(key, self._pass)
+        return self.get(key, self._pass)
 
 class ORMAdapter(sql_util.ColumnAdapter):
     def __init__(self, entity, equivalents=None, chain_to=None):
@@ -320,6 +316,11 @@ class AliasedComparator(PropComparator):
         return self.adapter.traverse(self.comparator.reverse_operate(op, *other, **kwargs))
 
 def _orm_annotate(element, exclude=None):
+    """Deep copy the given ClauseElement, annotating each element with the "_orm_adapt" flag.
+    
+    Elements within the exclude collection will be cloned but not annotated.
+    
+    """
     def clone(elem):
         if exclude and elem in exclude:
             elem = elem._clone()
@@ -334,7 +335,8 @@ def _orm_annotate(element, exclude=None):
 
 
 class _ORMJoin(expression.Join):
-
+    """Extend Join to support ORM constructs as input."""
+    
     __visit_name__ = expression.Join.__visit_name__
 
     def __init__(self, left, right, onclause=None, isouter=False):
@@ -387,9 +389,27 @@ class _ORMJoin(expression.Join):
         return _ORMJoin(self, right, onclause, True)
 
 def join(left, right, onclause=None, isouter=False):
+    """Produce an inner join between left and right clauses.
+    
+    In addition to the interface provided by 
+    sqlalchemy.sql.join(), left and right may be mapped 
+    classes or AliasedClass instances. The onclause may be a 
+    string name of a relation(), or a class-bound descriptor 
+    representing a relation.
+    
+    """
     return _ORMJoin(left, right, onclause, isouter)
 
 def outerjoin(left, right, onclause=None):
+    """Produce a left outer join between left and right clauses.
+    
+    In addition to the interface provided by 
+    sqlalchemy.sql.outerjoin(), left and right may be mapped 
+    classes or AliasedClass instances. The onclause may be a 
+    string name of a relation(), or a class-bound descriptor 
+    representing a relation.
+    
+    """
     return _ORMJoin(left, right, onclause, True)
 
 def with_parent(instance, prop):
@@ -417,6 +437,16 @@ def with_parent(instance, prop):
 
 
 def _entity_info(entity, compile=True):
+    """Return mapping information given a class, mapper, or AliasedClass.
+    
+    Returns 3-tuple of: mapper, mapped selectable, boolean indicating if this
+    is an aliased() construct.
+    
+    If the given entity is not a mapper, mapped class, or aliased construct,
+    returns None, the entity, False.  This is typically used to allow
+    unmapped selectables through.
+    
+    """
     if isinstance(entity, AliasedClass):
         return entity._AliasedClass__mapper, entity._AliasedClass__alias, True
     elif _is_mapped_class(entity):
@@ -432,6 +462,11 @@ def _entity_info(entity, compile=True):
         return None, entity, False
 
 def _entity_descriptor(entity, key):
+    """Return attribute/property information given an entity and string name.
+    
+    Returns a 2-tuple representing InstrumentedAttribute/MapperProperty.
+    
+    """
     if isinstance(entity, AliasedClass):
         desc = getattr(entity, key)
         return desc, desc.property
@@ -459,33 +494,26 @@ def _is_aliased_class(entity):
 def _state_mapper(state):
     return state.manager.mapper
 
-def object_mapper(object, raiseerror=True):
+def object_mapper(instance):
     """Given an object, return the primary Mapper associated with the object instance.
-
-        object
-            The object instance.
-
-        raiseerror
-            Defaults to True: raise an ``InvalidRequestError`` if no mapper can
-            be located.  If False, return None.
-
+    
+    Raises UnmappedInstanceError if no mapping is configured.
+    
     """
     try:
-        state = attributes.instance_state(object)
+        state = attributes.instance_state(instance)
+        if not state.manager.mapper:
+            raise exc.UnmappedInstanceError(instance)
+        return state.manager.mapper
     except exc.NO_STATE:
-        if not raiseerror:
-            return None
-        raise exc.UnmappedInstanceError(object)
-    return class_mapper(
-        type(object), compile=False, raiseerror=raiseerror)
+        raise exc.UnmappedInstanceError(instance)
 
-def class_mapper(class_, compile=True, raiseerror=True):
+def class_mapper(class_, compile=True):
     """Given a class (or an object), return the primary Mapper associated with the key.
 
-    If no mapper can be located, raises ``InvalidRequestError``.
-
-    """
+    Raises UnmappedClassError if no mapping is configured.
     
+    """
     if not isinstance(class_, type):
         class_ = type(class_)
     try:
@@ -497,9 +525,8 @@ def class_mapper(class_, compile=True, raiseerror=True):
             raise AttributeError
             
     except exc.NO_STATE:
-        if not raiseerror:
-            return
         raise exc.UnmappedClassError(class_)
+
     if compile:
         mapper = mapper.compile()
     return mapper
@@ -538,11 +565,12 @@ def instance_str(instance):
     return state_str(attributes.instance_state(instance))
 
 def state_str(state):
-    """Return a string describing an instance."""
+    """Return a string describing an instance via its InstanceState."""
+    
     if state is None:
         return "None"
     else:
-        return state.class_.__name__ + "@" + hex(id(state.obj()))
+        return '<%s at 0x%x>' % (state.class_.__name__, id(state.obj()))
 
 def attribute_str(instance, attribute):
     return instance_str(instance) + "." + attribute
index 81c9d3936330f24b4becdfc9ee3b8c055643ff26..8f611c05e378eb721dac0e9589acb8a4bba5b188 100644 (file)
@@ -880,7 +880,10 @@ def _compound_select(keyword, *selects, **kwargs):
     return CompoundSelect(keyword, *selects, **kwargs)
 
 def _is_literal(element):
-    return not isinstance(element, (ClauseElement, Operators))
+    global schema
+    if not schema:
+        from sqlalchemy import schema
+    return not isinstance(element, (ClauseElement, Operators, schema.SchemaItem))
 
 def _from_objects(*elements, **kwargs):
     return itertools.chain(*[element._get_from_objects(**kwargs) for element in elements])
index 8dd98512e3e40514aac5bab2d6e00d90a31ea917..f6cd2cc91a3129354bc7a98fcdd5f73df21113a9 100644 (file)
@@ -4,7 +4,7 @@ import inspect
 import pickle
 from sqlalchemy.orm import create_session, sessionmaker, attributes
 from testlib import engines, sa, testing, config
-from testlib.sa import Table, Column, Integer, String
+from testlib.sa import Table, Column, Integer, String, Sequence
 from testlib.sa.orm import mapper, relation, backref
 from testlib.testing import eq_
 from engine import _base as engine_base
@@ -62,6 +62,17 @@ class SessionTest(_fixtures.FixtureTest):
         finally:
             c.close()
 
+    @testing.requires.sequences
+    def test_sequence_execute(self):
+        seq = Sequence("some_sequence")
+        seq.create(testing.db)
+        try:
+            sess = create_session(bind=testing.db)
+            eq_(sess.execute(seq), 1)
+        finally:
+            seq.drop(testing.db)
+        
+        
     @testing.resolve_artifact_names
     def test_expunge_cascade(self):
         mapper(Address, addresses)
@@ -842,33 +853,6 @@ class SessionTest(_fixtures.FixtureTest):
         assert s.query(Address).one().id == a.id
         assert s.query(User).first() is None
 
-    @testing.resolve_artifact_names
-    def test_identity_key_1(self):
-        mapper(User, users)
-        s = create_session()
-        key = s.identity_key(User, 1)
-        eq_(key, (User, (1,)))
-        key = s.identity_key(User, ident=1)
-        eq_(key, (User, (1,)))
-
-    @testing.resolve_artifact_names
-    def test_identity_key_2(self):
-        mapper(User, users)
-        s = create_session()
-        u = User(name='u1')
-        s.add(u)
-        s.flush()
-        key = s.identity_key(instance=u)
-        eq_(key, (User, (u.id,)))
-
-    @testing.resolve_artifact_names
-    def test_identity_key_3(self):
-        mapper(User, users)
-        s = create_session()
-        row = {users.c.id: 1, users.c.name: "Frank"}
-        key = s.identity_key(User, row=row)
-        eq_(key, (User, (1,)))
-
     @testing.resolve_artifact_names
     def test_extension(self):
         mapper(User, users)
index 1f2cbe13a799ca24203807b0c832e52a1d89ca62..0a449fbf7067bfb94a707158883682bb60a06e10 100644 (file)
@@ -7,16 +7,18 @@ from sqlalchemy import Integer
 from sqlalchemy import MetaData
 from sqlalchemy import Table
 from sqlalchemy.orm import aliased
-from sqlalchemy.orm import mapper
+from sqlalchemy.orm import mapper, create_session
+from orm import _fixtures
+from testlib.testing import eq_
 
 
 class ExtensionCarrierTest(TestBase):
     def test_basic(self):
         carrier = util.ExtensionCarrier()
 
-        assert 'translate_row' not in carrier.methods
+        assert 'translate_row' not in carrier
         assert carrier.translate_row() is interfaces.EXT_CONTINUE
-        assert 'translate_row' not in carrier.methods
+        assert 'translate_row' not in carrier
 
         self.assertRaises(AttributeError, lambda: carrier.snickysnack)
 
@@ -27,15 +29,15 @@ class ExtensionCarrierTest(TestBase):
                 return self.marker
 
         carrier.append(Partial('end'))
-        assert 'translate_row' in carrier.methods
+        assert 'translate_row' in carrier
         assert carrier.translate_row(None) == 'end'
 
         carrier.push(Partial('front'))
         assert carrier.translate_row(None) == 'front'
 
-        assert 'populate_instance' not in carrier.methods
+        assert 'populate_instance' not in carrier
         carrier.append(interfaces.MapperExtension)
-        assert 'populate_instance' in carrier.methods
+        assert 'populate_instance' in carrier
 
         assert carrier.interface
         for m in carrier.interface:
@@ -201,8 +203,37 @@ class AliasedClassTest(TestBase):
 
         assert_table(Point.left_of(p2), table)
         assert_table(alias.left_of(p2), alias_table)
-    
 
+class IdentityKeyTest(_fixtures.FixtureTest):
+    run_inserts = None
+
+    @testing.resolve_artifact_names
+    def test_identity_key_1(self):
+        mapper(User, users)
+
+        key = util.identity_key(User, 1)
+        eq_(key, (User, (1,)))
+        key = util.identity_key(User, ident=1)
+        eq_(key, (User, (1,)))
+
+    @testing.resolve_artifact_names
+    def test_identity_key_2(self):
+        mapper(User, users)
+        s = create_session()
+        u = User(name='u1')
+        s.add(u)
+        s.flush()
+        key = util.identity_key(instance=u)
+        eq_(key, (User, (u.id,)))
+
+    @testing.resolve_artifact_names
+    def test_identity_key_3(self):
+        mapper(User, users)
+
+        row = {users.c.id: 1, users.c.name: "Frank"}
+        key = util.identity_key(User, row=row)
+        eq_(key, (User, (1,)))
+    
 if __name__ == '__main__':
     testenv.main()