]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Repair _reinstall_default_lookups to also flip the _extended flag
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 1 May 2015 16:06:34 +0000 (12:06 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 1 May 2015 16:33:45 +0000 (12:33 -0400)
off again so that test fixtures setup/teardown instrumentation as
expected
- clean up test_extendedattr.py and fix it to no longer leak
itself outside by ensuring _reinstall_default_lookups is always called,
part of #3408
- Fixed bug where when using extended attribute instrumentation system,
the correct exception would not be raised when :func:`.class_mapper`
were called with an invalid input that also happened to not
be weak referencable, such as an integer.
fixes #3408

doc/build/changelog/changelog_09.rst
lib/sqlalchemy/ext/instrumentation.py
test/ext/test_extendedattr.py

index e66203ed343869d6cf0d3d08e2775229c0ca94e6..2506d21bdb2cdd9807cf8d99796e554bb8e49d3e 100644 (file)
 .. changelog::
     :version: 0.9.10
 
+    .. change::
+        :tags: bug, ext
+        :tickets: 3408
+        :versions: 1.0.4
+
+        Fixed bug where when using extended attribute instrumentation system,
+        the correct exception would not be raised when :func:`.class_mapper`
+        were called with an invalid input that also happened to not
+        be weak referencable, such as an integer.
+
     .. change::
         :tags: bug, tests, pypy
         :tickets: 3406
index 0241366612bb05ecdc5b0c26b5de48435ef1efc3..30a0ab7d739540677b05fe5471223c0448a94c87 100644 (file)
@@ -166,7 +166,13 @@ class ExtendedInstrumentationRegistry(InstrumentationFactory):
     def manager_of_class(self, cls):
         if cls is None:
             return None
-        return self._manager_finders.get(cls, _default_manager_getter)(cls)
+        try:
+            finder = self._manager_finders.get(cls, _default_manager_getter)
+        except TypeError:
+            # due to weakref lookup on invalid object
+            return None
+        else:
+            return finder(cls)
 
     def state_of(self, instance):
         if instance is None:
@@ -392,6 +398,7 @@ def _reinstall_default_lookups():
             manager_of_class=_default_manager_getter
         )
     )
+    _instrumentation_factory._extended = False
 
 
 def _install_lookups(lookups):
index c7627c8b23616a3736bb779fe4628220dd56e94d..653418ac4007d57758399faab2b31e728de0b167 100644 (file)
@@ -1,10 +1,12 @@
 from sqlalchemy.testing import eq_, assert_raises, assert_raises_message, ne_
 from sqlalchemy import util
+import sqlalchemy as sa
+from sqlalchemy.orm import class_mapper
 from sqlalchemy.orm import attributes
-from sqlalchemy.orm.attributes import set_attribute, get_attribute, del_attribute
+from sqlalchemy.orm.attributes import set_attribute, \
+    get_attribute, del_attribute
 from sqlalchemy.orm.instrumentation import is_instrumented
 from sqlalchemy.orm import clear_mappers
-from sqlalchemy import testing
 from sqlalchemy.testing import fixtures
 from sqlalchemy.ext import instrumentation
 from sqlalchemy.orm.instrumentation import register_class
@@ -12,6 +14,7 @@ from sqlalchemy.testing.util import decorator
 from sqlalchemy.orm import events
 from sqlalchemy import event
 
+
 @decorator
 def modifies_instrumentation_finders(fn, *args, **kw):
     pristine = instrumentation.instrumentation_finders[:]
@@ -21,15 +24,11 @@ def modifies_instrumentation_finders(fn, *args, **kw):
         del instrumentation.instrumentation_finders[:]
         instrumentation.instrumentation_finders.extend(pristine)
 
-def with_lookup_strategy(strategy):
-    @decorator
-    def decorate(fn, *args, **kw):
-        try:
-            ext_instrumentation._install_instrumented_lookups()
-            return fn(*args, **kw)
-        finally:
-            ext_instrumentation._reinstall_default_lookups()
-    return decorate
+
+class _ExtBase(object):
+    @classmethod
+    def teardown_class(cls):
+        instrumentation._reinstall_default_lookups()
 
 
 class MyTypesManager(instrumentation.InstrumentationManager):
@@ -58,16 +57,19 @@ class MyTypesManager(instrumentation.InstrumentationManager):
     def state_getter(self, class_):
         return lambda instance: instance.__dict__['_my_state']
 
+
 class MyListLike(list):
     # add @appender, @remover decorators as needed
     _sa_iterator = list.__iter__
     _sa_linker = None
     _sa_converter = None
+
     def _sa_appender(self, item, _sa_initiator=None):
         if _sa_initiator is not False:
             self._sa_adapter.fire_append_event(item, _sa_initiator)
         list.append(self, item)
     append = _sa_appender
+
     def _sa_remover(self, item, _sa_initiator=None):
         self._sa_adapter.fire_pre_remove_event(_sa_initiator)
         if _sa_initiator is not False:
@@ -75,57 +77,64 @@ class MyListLike(list):
         list.remove(self, item)
     remove = _sa_remover
 
-class MyBaseClass(object):
-    __sa_instrumentation_manager__ = instrumentation.InstrumentationManager
-
-class MyClass(object):
-
-    # This proves that a staticmethod will work here; don't
-    # flatten this back to a class assignment!
-    def __sa_instrumentation_manager__(cls):
-        return MyTypesManager(cls)
-
-    __sa_instrumentation_manager__ = staticmethod(__sa_instrumentation_manager__)
-
-    # This proves SA can handle a class with non-string dict keys
-    if not util.pypy and not util.jython:
-        locals()[42] = 99   # Don't remove this line!
-
-    def __init__(self, **kwargs):
-        for k in kwargs:
-            setattr(self, k, kwargs[k])
-
-    def __getattr__(self, key):
-        if is_instrumented(self, key):
-            return get_attribute(self, key)
-        else:
-            try:
-                return self._goofy_dict[key]
-            except KeyError:
-                raise AttributeError(key)
-
-    def __setattr__(self, key, value):
-        if is_instrumented(self, key):
-            set_attribute(self, key, value)
-        else:
-            self._goofy_dict[key] = value
-
-    def __hasattr__(self, key):
-        if is_instrumented(self, key):
-            return True
-        else:
-            return key in self._goofy_dict
-
-    def __delattr__(self, key):
-        if is_instrumented(self, key):
-            del_attribute(self, key)
-        else:
-            del self._goofy_dict[key]
-
-class UserDefinedExtensionTest(fixtures.ORMTest):
+
+MyBaseClass, MyClass = None, None
+
+
+class UserDefinedExtensionTest(_ExtBase, fixtures.ORMTest):
+
     @classmethod
-    def teardown_class(cls):
-        instrumentation._reinstall_default_lookups()
+    def setup_class(cls):
+        global MyBaseClass, MyClass
+
+        class MyBaseClass(object):
+            __sa_instrumentation_manager__ = \
+                instrumentation.InstrumentationManager
+
+        class MyClass(object):
+
+            # This proves that a staticmethod will work here; don't
+            # flatten this back to a class assignment!
+            def __sa_instrumentation_manager__(cls):
+                return MyTypesManager(cls)
+
+            __sa_instrumentation_manager__ = staticmethod(
+                __sa_instrumentation_manager__)
+
+            # This proves SA can handle a class with non-string dict keys
+            if not util.pypy and not util.jython:
+                locals()[42] = 99   # Don't remove this line!
+
+            def __init__(self, **kwargs):
+                for k in kwargs:
+                    setattr(self, k, kwargs[k])
+
+            def __getattr__(self, key):
+                if is_instrumented(self, key):
+                    return get_attribute(self, key)
+                else:
+                    try:
+                        return self._goofy_dict[key]
+                    except KeyError:
+                        raise AttributeError(key)
+
+            def __setattr__(self, key, value):
+                if is_instrumented(self, key):
+                    set_attribute(self, key, value)
+                else:
+                    self._goofy_dict[key] = value
+
+            def __hasattr__(self, key):
+                if is_instrumented(self, key):
+                    return True
+                else:
+                    return key in self._goofy_dict
+
+            def __delattr__(self, key):
+                if is_instrumented(self, key):
+                    del_attribute(self, key)
+                else:
+                    del self._goofy_dict[key]
 
     def teardown(self):
         clear_mappers()
@@ -135,15 +144,25 @@ class UserDefinedExtensionTest(fixtures.ORMTest):
             pass
 
         register_class(User)
-        attributes.register_attribute(User, 'user_id', uselist = False, useobject=False)
-        attributes.register_attribute(User, 'user_name', uselist = False, useobject=False)
-        attributes.register_attribute(User, 'email_address', uselist = False, useobject=False)
+        attributes.register_attribute(
+            User, 'user_id', uselist=False, useobject=False)
+        attributes.register_attribute(
+            User, 'user_name', uselist=False, useobject=False)
+        attributes.register_attribute(
+            User, 'email_address', uselist=False, useobject=False)
 
         u = User()
         u.user_id = 7
         u.user_name = 'john'
         u.email_address = 'lala@123.com'
-        self.assert_(u.__dict__ == {'_my_state':u._my_state, '_goofy_dict':{'user_id':7, 'user_name':'john', 'email_address':'lala@123.com'}}, u.__dict__)
+        eq_(
+            u.__dict__,
+            {
+                '_my_state': u._my_state,
+                '_goofy_dict': {
+                    'user_id': 7, 'user_name': 'john',
+                    'email_address': 'lala@123.com'}}
+        )
 
     def test_basic(self):
         for base in (object, MyBaseClass, MyClass):
@@ -151,29 +170,40 @@ class UserDefinedExtensionTest(fixtures.ORMTest):
                 pass
 
             register_class(User)
-            attributes.register_attribute(User, 'user_id', uselist = False, useobject=False)
-            attributes.register_attribute(User, 'user_name', uselist = False, useobject=False)
-            attributes.register_attribute(User, 'email_address', uselist = False, useobject=False)
+            attributes.register_attribute(
+                User, 'user_id', uselist=False, useobject=False)
+            attributes.register_attribute(
+                User, 'user_name', uselist=False, useobject=False)
+            attributes.register_attribute(
+                User, 'email_address', uselist=False, useobject=False)
 
             u = User()
             u.user_id = 7
             u.user_name = 'john'
             u.email_address = 'lala@123.com'
 
-            self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
-            attributes.instance_state(u)._commit_all(attributes.instance_dict(u))
-            self.assert_(u.user_id == 7 and u.user_name == 'john' and u.email_address == 'lala@123.com')
+            eq_(u.user_id, 7)
+            eq_(u.user_name, "john")
+            eq_(u.email_address, "lala@123.com")
+            attributes.instance_state(u)._commit_all(
+                attributes.instance_dict(u))
+            eq_(u.user_id, 7)
+            eq_(u.user_name, "john")
+            eq_(u.email_address, "lala@123.com")
 
             u.user_name = 'heythere'
             u.email_address = 'foo@bar.com'
-            self.assert_(u.user_id == 7 and u.user_name == 'heythere' and u.email_address == 'foo@bar.com')
+            eq_(u.user_id, 7)
+            eq_(u.user_name, "heythere")
+            eq_(u.email_address, "foo@bar.com")
 
     def test_deferred(self):
         for base in (object, MyBaseClass, MyClass):
             class Foo(base):
                 pass
 
-            data = {'a':'this is a', 'b':12}
+            data = {'a': 'this is a', 'b': 12}
+
             def loader(state, keys):
                 for k in keys:
                     state.dict[k] = data[k]
@@ -181,30 +211,38 @@ class UserDefinedExtensionTest(fixtures.ORMTest):
 
             manager = register_class(Foo)
             manager.deferred_scalar_loader = loader
-            attributes.register_attribute(Foo, 'a', uselist=False, useobject=False)
-            attributes.register_attribute(Foo, 'b', uselist=False, useobject=False)
+            attributes.register_attribute(
+                Foo, 'a', uselist=False, useobject=False)
+            attributes.register_attribute(
+                Foo, 'b', uselist=False, useobject=False)
 
             if base is object:
-                assert Foo not in instrumentation._instrumentation_factory._state_finders
+                assert Foo not in \
+                    instrumentation._instrumentation_factory._state_finders
             else:
-                assert Foo in instrumentation._instrumentation_factory._state_finders
+                assert Foo in \
+                    instrumentation._instrumentation_factory._state_finders
 
             f = Foo()
-            attributes.instance_state(f)._expire(attributes.instance_dict(f), set())
+            attributes.instance_state(f)._expire(
+                attributes.instance_dict(f), set())
             eq_(f.a, "this is a")
             eq_(f.b, 12)
 
             f.a = "this is some new a"
-            attributes.instance_state(f)._expire(attributes.instance_dict(f), set())
+            attributes.instance_state(f)._expire(
+                attributes.instance_dict(f), set())
             eq_(f.a, "this is a")
             eq_(f.b, 12)
 
-            attributes.instance_state(f)._expire(attributes.instance_dict(f), set())
+            attributes.instance_state(f)._expire(
+                attributes.instance_dict(f), set())
             f.a = "this is another new a"
             eq_(f.a, "this is another new a")
             eq_(f.b, 12)
 
-            attributes.instance_state(f)._expire(attributes.instance_dict(f), set())
+            attributes.instance_state(f)._expire(
+                attributes.instance_dict(f), set())
             eq_(f.a, "this is a")
             eq_(f.b, 12)
 
@@ -212,7 +250,8 @@ class UserDefinedExtensionTest(fixtures.ORMTest):
             eq_(f.a, None)
             eq_(f.b, 12)
 
-            attributes.instance_state(f)._commit_all(attributes.instance_dict(f))
+            attributes.instance_state(f)._commit_all(
+                attributes.instance_dict(f))
             eq_(f.a, None)
             eq_(f.b, 12)
 
@@ -220,27 +259,32 @@ class UserDefinedExtensionTest(fixtures.ORMTest):
         """tests that attributes are polymorphic"""
 
         for base in (object, MyBaseClass, MyClass):
-            class Foo(base):pass
-            class Bar(Foo):pass
+            class Foo(base):
+                pass
+
+            class Bar(Foo):
+                pass
 
             register_class(Foo)
             register_class(Bar)
 
             def func1(state, passive):
                 return "this is the foo attr"
+
             def func2(state, passive):
                 return "this is the bar attr"
+
             def func3(state, passive):
                 return "this is the shared attr"
             attributes.register_attribute(Foo, 'element',
-                    uselist=False, callable_=func1,
-                    useobject=True)
+                                          uselist=False, callable_=func1,
+                                          useobject=True)
             attributes.register_attribute(Foo, 'element2',
-                    uselist=False, callable_=func3,
-                    useobject=True)
+                                          uselist=False, callable_=func3,
+                                          useobject=True)
             attributes.register_attribute(Bar, 'element',
-                    uselist=False, callable_=func2,
-                    useobject=True)
+                                          uselist=False, callable_=func2,
+                                          useobject=True)
 
             x = Foo()
             y = Bar()
@@ -251,15 +295,20 @@ class UserDefinedExtensionTest(fixtures.ORMTest):
 
     def test_collection_with_backref(self):
         for base in (object, MyBaseClass, MyClass):
-            class Post(base):pass
-            class Blog(base):pass
+            class Post(base):
+                pass
+
+            class Blog(base):
+                pass
 
             register_class(Post)
             register_class(Blog)
-            attributes.register_attribute(Post, 'blog', uselist=False,
-                    backref='posts', trackparent=True, useobject=True)
-            attributes.register_attribute(Blog, 'posts', uselist=True,
-                    backref='blog', trackparent=True, useobject=True)
+            attributes.register_attribute(
+                Post, 'blog', uselist=False,
+                backref='posts', trackparent=True, useobject=True)
+            attributes.register_attribute(
+                Blog, 'posts', uselist=True,
+                backref='blog', trackparent=True, useobject=True)
             b = Blog()
             (p1, p2, p3) = (Post(), Post(), Post())
             b.posts.append(p1)
@@ -287,47 +336,77 @@ class UserDefinedExtensionTest(fixtures.ORMTest):
         for base in (object, MyBaseClass, MyClass):
             class Foo(base):
                 pass
+
             class Bar(base):
                 pass
 
             register_class(Foo)
             register_class(Bar)
-            attributes.register_attribute(Foo, "name", uselist=False, useobject=False)
-            attributes.register_attribute(Foo, "bars", uselist=True, trackparent=True, useobject=True)
-            attributes.register_attribute(Bar, "name", uselist=False, useobject=False)
-
+            attributes.register_attribute(
+                Foo, "name", uselist=False, useobject=False)
+            attributes.register_attribute(
+                Foo, "bars", uselist=True, trackparent=True, useobject=True)
+            attributes.register_attribute(
+                Bar, "name", uselist=False, useobject=False)
 
             f1 = Foo()
             f1.name = 'f1'
 
-            eq_(attributes.get_state_history(attributes.instance_state(f1), 'name'), (['f1'], (), ()))
+            eq_(
+                attributes.get_state_history(
+                    attributes.instance_state(f1), 'name'),
+                (['f1'], (), ()))
 
             b1 = Bar()
             b1.name = 'b1'
             f1.bars.append(b1)
-            eq_(attributes.get_state_history(attributes.instance_state(f1), 'bars'), ([b1], [], []))
-
-            attributes.instance_state(f1)._commit_all(attributes.instance_dict(f1))
-            attributes.instance_state(b1)._commit_all(attributes.instance_dict(b1))
-
-            eq_(attributes.get_state_history(attributes.instance_state(f1), 'name'), ((), ['f1'], ()))
-            eq_(attributes.get_state_history(attributes.instance_state(f1), 'bars'), ((), [b1], ()))
+            eq_(
+                attributes.get_state_history(
+                    attributes.instance_state(f1), 'bars'),
+                ([b1], [], []))
+
+            attributes.instance_state(f1)._commit_all(
+                attributes.instance_dict(f1))
+            attributes.instance_state(b1)._commit_all(
+                attributes.instance_dict(b1))
+
+            eq_(
+                attributes.get_state_history(
+                    attributes.instance_state(f1),
+                    'name'),
+                ((), ['f1'], ()))
+            eq_(
+                attributes.get_state_history(
+                    attributes.instance_state(f1),
+                    'bars'),
+                ((), [b1], ()))
 
             f1.name = 'f1mod'
             b2 = Bar()
             b2.name = 'b2'
             f1.bars.append(b2)
-            eq_(attributes.get_state_history(attributes.instance_state(f1), 'name'), (['f1mod'], (), ['f1']))
-            eq_(attributes.get_state_history(attributes.instance_state(f1), 'bars'), ([b2], [b1], []))
+            eq_(
+                attributes.get_state_history(
+                    attributes.instance_state(f1), 'name'),
+                (['f1mod'], (), ['f1']))
+            eq_(
+                attributes.get_state_history(
+                    attributes.instance_state(f1), 'bars'),
+                ([b2], [b1], []))
             f1.bars.remove(b1)
-            eq_(attributes.get_state_history(attributes.instance_state(f1), 'bars'), ([b2], [], [b1]))
+            eq_(
+                attributes.get_state_history(
+                    attributes.instance_state(f1), 'bars'),
+                ([b2], [], [b1]))
 
     def test_null_instrumentation(self):
         class Foo(MyBaseClass):
             pass
         register_class(Foo)
-        attributes.register_attribute(Foo, "name", uselist=False, useobject=False)
-        attributes.register_attribute(Foo, "bars", uselist=True, trackparent=True, useobject=True)
+        attributes.register_attribute(
+            Foo, "name", uselist=False, useobject=False)
+        attributes.register_attribute(
+            Foo, "bars", uselist=True, trackparent=True, useobject=True)
 
         assert Foo.name == attributes.manager_of_class(Foo)['name']
         assert Foo.bars == attributes.manager_of_class(Foo)['bars']
@@ -335,8 +414,11 @@ class UserDefinedExtensionTest(fixtures.ORMTest):
     def test_alternate_finders(self):
         """Ensure the generic finder front-end deals with edge cases."""
 
-        class Unknown(object): pass
-        class Known(MyBaseClass): pass
+        class Unknown(object):
+            pass
+
+        class Known(MyBaseClass):
+            pass
 
         register_class(Known)
         k, u = Known(), Unknown()
@@ -347,28 +429,59 @@ class UserDefinedExtensionTest(fixtures.ORMTest):
 
         assert attributes.instance_state(k) is not None
         assert_raises((AttributeError, KeyError),
-                          attributes.instance_state, u)
+                      attributes.instance_state, u)
         assert_raises((AttributeError, KeyError),
-                          attributes.instance_state, None)
+                      attributes.instance_state, None)
+
+    def test_unmapped_not_type_error(self):
+        """extension version of the same test in test_mapper.
+
+        fixes #3408
+        """
+        assert_raises_message(
+            sa.exc.ArgumentError,
+            "Class object expected, got '5'.",
+            class_mapper, 5
+        )
 
+    def test_unmapped_not_type_error_iter_ok(self):
+        """extension version of the same test in test_mapper.
+
+        fixes #3408
+        """
+        assert_raises_message(
+            sa.exc.ArgumentError,
+            r"Class object expected, got '\(5, 6\)'.",
+            class_mapper, (5, 6)
+        )
+
+
+class FinderTest(_ExtBase, fixtures.ORMTest):
 
-class FinderTest(fixtures.ORMTest):
     def test_standard(self):
-        class A(object): pass
+        class A(object):
+            pass
 
         register_class(A)
 
-        eq_(type(instrumentation.manager_of_class(A)), instrumentation.ClassManager)
+        eq_(
+            type(instrumentation.manager_of_class(A)),
+            instrumentation.ClassManager)
 
     def test_nativeext_interfaceexact(self):
         class A(object):
-            __sa_instrumentation_manager__ = instrumentation.InstrumentationManager
+            __sa_instrumentation_manager__ = \
+                instrumentation.InstrumentationManager
 
         register_class(A)
-        ne_(type(instrumentation.manager_of_class(A)), instrumentation.ClassManager)
+        ne_(
+            type(instrumentation.manager_of_class(A)),
+            instrumentation.ClassManager)
 
     def test_nativeext_submanager(self):
-        class Mine(instrumentation.ClassManager): pass
+        class Mine(instrumentation.ClassManager):
+            pass
+
         class A(object):
             __sa_instrumentation_manager__ = Mine
 
@@ -377,8 +490,12 @@ class FinderTest(fixtures.ORMTest):
 
     @modifies_instrumentation_finders
     def test_customfinder_greedy(self):
-        class Mine(instrumentation.ClassManager): pass
-        class A(object): pass
+        class Mine(instrumentation.ClassManager):
+            pass
+
+        class A(object):
+            pass
+
         def find(cls):
             return Mine
 
@@ -388,20 +505,28 @@ class FinderTest(fixtures.ORMTest):
 
     @modifies_instrumentation_finders
     def test_customfinder_pass(self):
-        class A(object): pass
+        class A(object):
+            pass
+
         def find(cls):
             return None
 
         instrumentation.instrumentation_finders.insert(0, find)
         register_class(A)
-        eq_(type(instrumentation.manager_of_class(A)), instrumentation.ClassManager)
+        eq_(
+            type(instrumentation.manager_of_class(A)),
+            instrumentation.ClassManager)
+
+
+class InstrumentationCollisionTest(_ExtBase, fixtures.ORMTest):
 
-class InstrumentationCollisionTest(fixtures.ORMTest):
     def test_none(self):
-        class A(object): pass
+        class A(object):
+            pass
         register_class(A)
 
         mgr_factory = lambda cls: instrumentation.ClassManager(cls)
+
         class B(object):
             __sa_instrumentation_manager__ = staticmethod(mgr_factory)
         register_class(B)
@@ -411,79 +536,114 @@ class InstrumentationCollisionTest(fixtures.ORMTest):
         register_class(C)
 
     def test_single_down(self):
-        class A(object): pass
+        class A(object):
+            pass
         register_class(A)
 
         mgr_factory = lambda cls: instrumentation.ClassManager(cls)
+
         class B(A):
             __sa_instrumentation_manager__ = staticmethod(mgr_factory)
 
-        assert_raises_message(TypeError, "multiple instrumentation implementations", register_class, B)
+        assert_raises_message(
+            TypeError, "multiple instrumentation implementations",
+            register_class, B)
 
     def test_single_up(self):
 
-        class A(object): pass
+        class A(object):
+            pass
         # delay registration
 
         mgr_factory = lambda cls: instrumentation.ClassManager(cls)
+
         class B(A):
             __sa_instrumentation_manager__ = staticmethod(mgr_factory)
         register_class(B)
 
-        assert_raises_message(TypeError, "multiple instrumentation implementations", register_class, A)
+        assert_raises_message(
+            TypeError, "multiple instrumentation implementations",
+            register_class, A)
 
     def test_diamond_b1(self):
         mgr_factory = lambda cls: instrumentation.ClassManager(cls)
 
-        class A(object): pass
-        class B1(A): pass
+        class A(object):
+            pass
+
+        class B1(A):
+            pass
+
         class B2(A):
             __sa_instrumentation_manager__ = staticmethod(mgr_factory)
-        class C(object): pass
 
-        assert_raises_message(TypeError, "multiple instrumentation implementations", register_class, B1)
+        class C(object):
+            pass
+
+        assert_raises_message(
+            TypeError, "multiple instrumentation implementations",
+            register_class, B1)
 
     def test_diamond_b2(self):
         mgr_factory = lambda cls: instrumentation.ClassManager(cls)
 
-        class A(object): pass
-        class B1(A): pass
+        class A(object):
+            pass
+
+        class B1(A):
+            pass
+
         class B2(A):
             __sa_instrumentation_manager__ = staticmethod(mgr_factory)
-        class C(object): pass
+
+        class C(object):
+            pass
 
         register_class(B2)
-        assert_raises_message(TypeError, "multiple instrumentation implementations", register_class, B1)
+        assert_raises_message(
+            TypeError, "multiple instrumentation implementations",
+            register_class, B1)
 
     def test_diamond_c_b(self):
         mgr_factory = lambda cls: instrumentation.ClassManager(cls)
 
-        class A(object): pass
-        class B1(A): pass
+        class A(object):
+            pass
+
+        class B1(A):
+            pass
+
         class B2(A):
             __sa_instrumentation_manager__ = staticmethod(mgr_factory)
-        class C(object): pass
+
+        class C(object):
+            pass
 
         register_class(C)
 
-        assert_raises_message(TypeError, "multiple instrumentation implementations", register_class, B1)
+        assert_raises_message(
+            TypeError, "multiple instrumentation implementations",
+            register_class, B1)
 
 
-class ExtendedEventsTest(fixtures.ORMTest):
+class ExtendedEventsTest(_ExtBase, fixtures.ORMTest):
+
     """Allow custom Events implementations."""
 
     @modifies_instrumentation_finders
     def test_subclassed(self):
         class MyEvents(events.InstanceEvents):
             pass
+
         class MyClassManager(instrumentation.ClassManager):
             dispatch = event.dispatcher(MyEvents)
 
-        instrumentation.instrumentation_finders.insert(0, lambda cls: MyClassManager)
+        instrumentation.instrumentation_finders.insert(
+            0, lambda cls: MyClassManager)
 
-        class A(object): pass
+        class A(object):
+            pass
 
         register_class(A)
         manager = instrumentation.manager_of_class(A)
         assert issubclass(manager.dispatch._events, MyEvents)
-