]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Improvements to the mechanism used by :class:`.Session` to locate
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 14 Oct 2014 18:04:17 +0000 (14:04 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 14 Oct 2014 18:04:17 +0000 (14:04 -0400)
"binds" (e.g. engines to use), such engines can be associated with
mixin classes, concrete subclasses, as well as a wider variety
of table metadata such as joined inheritance tables.
fixes #3035

doc/build/changelog/changelog_10.rst
doc/build/changelog/migration_10.rst
lib/sqlalchemy/orm/session.py
test/orm/test_bind.py
test/orm/test_session.py

index 8578c788395cb2e98ebdf863c48f43484ce650eb..66fa2ad267eee3ec6dbe21f011b2a7b653e3c67b 100644 (file)
     series as well.  For changes that are specific to 1.0 with an emphasis
     on compatibility concerns, see :doc:`/changelog/migration_10`.
 
+    .. change::
+        :tags: bug, orm
+        :tickets: 3035
+
+        Improvements to the mechanism used by :class:`.Session` to locate
+        "binds" (e.g. engines to use), such engines can be associated with
+        mixin classes, concrete subclasses, as well as a wider variety
+        of table metadata such as joined inheritance tables.
+
+        .. seealso::
+
+            :ref:`bug_3035`
+
     .. change::
         :tags: bug, general
         :tickets: 3218
index 951e3960306a03d9d75eba2bc7a617517d03ab5b..dd8964f8bbb88cfd61752030afa9fedfc5d57cad 100644 (file)
@@ -468,6 +468,48 @@ object totally smokes both namedtuple and KeyedTuple::
 
 :ticket:`3176`
 
+.. _bug_3035:
+
+Session.get_bind() handles a wider variety of inheritance scenarios
+-------------------------------------------------------------------
+
+The :meth:`.Session.get_bind` method is invoked whenever a query or unit
+of work flush process seeks to locate the database engine that corresponds
+to a particular class.   The method has been improved to handle a variety
+of inheritance-oriented scenarios, including:
+
+* Binding to a Mixin or Abstract Class::
+
+        class MyClass(SomeMixin, Base):
+            __tablename__ = 'my_table'
+            # ...
+
+        session = Session(binds={SomeMixin: some_engine})
+
+
+* Binding to inherited concrete subclasses individually based on table::
+
+        class BaseClass(Base):
+            __tablename__ = 'base'
+
+            # ...
+
+        class ConcreteSubClass(BaseClass):
+            __tablename__ = 'concrete'
+
+            # ...
+
+            __mapper_args__ = {'concrete': True}
+
+
+        session = Session(binds={
+            base_table: some_engine,
+            concrete_table: some_other_engine
+        })
+
+
+:ticket:`3035`
+
 .. _feature_3178:
 
 New systems to safely emit parameterized warnings
index 13afcb35746012e9ff881673696c42468991dfcf..db9d3a51d860241c6ee3291640cbce9615296cc6 100644 (file)
@@ -641,14 +641,8 @@ class Session(_SessionClassMethods):
                 SessionExtension._adapt_listener(self, ext)
 
         if binds is not None:
-            for mapperortable, bind in binds.items():
-                insp = inspect(mapperortable)
-                if insp.is_selectable:
-                    self.bind_table(mapperortable, bind)
-                elif insp.is_mapper:
-                    self.bind_mapper(mapperortable, bind)
-                else:
-                    assert False
+            for key, bind in binds.items():
+                self._add_bind(key, bind)
 
         if not self.autocommit:
             self.begin()
@@ -1026,40 +1020,47 @@ class Session(_SessionClassMethods):
     # TODO: + crystallize + document resolution order
     #       vis. bind_mapper/bind_table
 
-    def bind_mapper(self, mapper, bind):
-        """Bind operations for a mapper to a Connectable.
-
-        mapper
-          A mapper instance or mapped class
+    def _add_bind(self, key, bind):
+        try:
+            insp = inspect(key)
+        except sa_exc.NoInspectionAvailable:
+            if not isinstance(key, type):
+                raise exc.ArgumentError(
+                            "Not acceptable bind target: %s" %
+                            key)
+            else:
+                self.__binds[key] = bind
+        else:
+            if insp.is_selectable:
+                self.__binds[insp] = bind
+            elif insp.is_mapper:
+                self.__binds[insp.class_] = bind
+                for selectable in insp._all_tables:
+                    self.__binds[selectable] = bind
+            else:
+                raise exc.ArgumentError(
+                            "Not acceptable bind target: %s" %
+                            key)
 
-        bind
-          Any Connectable: a :class:`.Engine` or :class:`.Connection`.
+    def bind_mapper(self, mapper, bind):
+        """Associate a :class:`.Mapper` with a "bind", e.g. a :class:`.Engine`
+        or :class:`.Connection`.
 
-        All subsequent operations involving this mapper will use the given
-        `bind`.
+        The given mapper is added to a lookup used by the
+        :meth:`.Session.get_bind` method.
 
         """
-        if isinstance(mapper, type):
-            mapper = class_mapper(mapper)
-
-        self.__binds[mapper.base_mapper] = bind
-        for t in mapper._all_tables:
-            self.__binds[t] = bind
+        self._add_bind(mapper, bind)
 
     def bind_table(self, table, bind):
-        """Bind operations on a Table to a Connectable.
-
-        table
-          A :class:`.Table` instance
+        """Associate a :class:`.Table` with a "bind", e.g. a :class:`.Engine`
+        or :class:`.Connection`.
 
-        bind
-          Any Connectable: a :class:`.Engine` or :class:`.Connection`.
-
-        All subsequent operations involving this :class:`.Table` will use the
-        given `bind`.
+        The given mapper is added to a lookup used by the
+        :meth:`.Session.get_bind` method.
 
         """
-        self.__binds[table] = bind
+        self._add_bind(table, bind)
 
     def get_bind(self, mapper=None, clause=None):
         """Return a "bind" to which this :class:`.Session` is bound.
@@ -1113,6 +1114,7 @@ class Session(_SessionClassMethods):
             bound :class:`.MetaData`.
 
         """
+
         if mapper is clause is None:
             if self.bind:
                 return self.bind
@@ -1122,15 +1124,23 @@ class Session(_SessionClassMethods):
                     "Connection, and no context was provided to locate "
                     "a binding.")
 
-        c_mapper = mapper is not None and _class_to_mapper(mapper) or None
+        if mapper is not None:
+            try:
+                mapper = inspect(mapper)
+            except sa_exc.NoInspectionAvailable:
+                if isinstance(mapper, type):
+                    raise exc.UnmappedClassError(mapper)
+                else:
+                    raise
 
-        # manually bound?
         if self.__binds:
-            if c_mapper:
-                if c_mapper.base_mapper in self.__binds:
-                    return self.__binds[c_mapper.base_mapper]
-                elif c_mapper.mapped_table in self.__binds:
-                    return self.__binds[c_mapper.mapped_table]
+            if mapper:
+                for cls in mapper.class_.__mro__:
+                    if cls in self.__binds:
+                        return self.__binds[cls]
+                if clause is None:
+                    clause = mapper.mapped_table
+
             if clause is not None:
                 for t in sql_util.find_tables(clause, include_crud=True):
                     if t in self.__binds:
@@ -1142,12 +1152,12 @@ class Session(_SessionClassMethods):
         if isinstance(clause, sql.expression.ClauseElement) and clause.bind:
             return clause.bind
 
-        if c_mapper and c_mapper.mapped_table.bind:
-            return c_mapper.mapped_table.bind
+        if mapper and mapper.mapped_table.bind:
+            return mapper.mapped_table.bind
 
         context = []
         if mapper is not None:
-            context.append('mapper %s' % c_mapper)
+            context.append('mapper %s' % mapper)
         if clause is not None:
             context.append('SQL expression')
 
index 3e5af0cba416a0c2b33a8cb8f0193de1cfbe8133..33cd66ebcd2ed71af4f99f96b64cfbcf611e708b 100644 (file)
@@ -1,13 +1,14 @@
 from sqlalchemy.testing import assert_raises_message
-from sqlalchemy import MetaData, Integer
+from sqlalchemy import MetaData, Integer, ForeignKey
 from sqlalchemy.testing.schema import Table
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.orm import mapper, create_session
 import sqlalchemy as sa
 from sqlalchemy import testing
-from sqlalchemy.testing import fixtures, eq_, engines
+from sqlalchemy.testing import fixtures, eq_, engines, is_
 from sqlalchemy.orm import relationship, Session, backref, sessionmaker
 from test.orm import _fixtures
+from sqlalchemy.testing.mock import Mock
 
 
 class BindIntegrationTest(_fixtures.FixtureTest):
@@ -249,3 +250,218 @@ class SessionBindTest(fixtures.MappedTest):
             ('Could not locate a bind configured on Mapper|Foo|test_table '
              'or this Session'),
             sess.flush)
+
+
+class GetBindTest(fixtures.MappedTest):
+    @classmethod
+    def define_tables(cls, metadata):
+        Table(
+            'base_table', metadata,
+            Column('id', Integer, primary_key=True)
+        )
+        Table(
+            'w_mixin_table', metadata,
+            Column('id', Integer, primary_key=True)
+        )
+        Table(
+            'joined_sub_table', metadata,
+            Column('id', ForeignKey('base_table.id'), primary_key=True)
+        )
+        Table(
+            'concrete_sub_table', metadata,
+            Column('id', Integer, primary_key=True)
+        )
+
+    @classmethod
+    def setup_classes(cls):
+        class MixinOne(cls.Basic):
+            pass
+
+        class BaseClass(cls.Basic):
+            pass
+
+        class ClassWMixin(MixinOne, cls.Basic):
+            pass
+
+        class JoinedSubClass(BaseClass):
+            pass
+
+        class ConcreteSubClass(BaseClass):
+            pass
+
+    @classmethod
+    def setup_mappers(cls):
+        mapper(cls.classes.ClassWMixin, cls.tables.w_mixin_table)
+        mapper(cls.classes.BaseClass, cls.tables.base_table)
+        mapper(
+            cls.classes.JoinedSubClass,
+            cls.tables.joined_sub_table, inherits=cls.classes.BaseClass)
+        mapper(
+            cls.classes.ConcreteSubClass,
+            cls.tables.concrete_sub_table, inherits=cls.classes.BaseClass,
+            concrete=True)
+
+    def _fixture(self, binds):
+        return Session(binds=binds)
+
+    def test_fallback_table_metadata(self):
+        session = self._fixture({})
+        is_(
+            session.get_bind(self.classes.BaseClass),
+            testing.db
+        )
+
+    def test_bind_base_table_base_class(self):
+        base_class_bind = Mock()
+        session = self._fixture({
+            self.tables.base_table: base_class_bind
+        })
+
+        is_(
+            session.get_bind(self.classes.BaseClass),
+            base_class_bind
+        )
+
+    def test_bind_base_table_joined_sub_class(self):
+        base_class_bind = Mock()
+        session = self._fixture({
+            self.tables.base_table: base_class_bind
+        })
+
+        is_(
+            session.get_bind(self.classes.BaseClass),
+            base_class_bind
+        )
+        is_(
+            session.get_bind(self.classes.JoinedSubClass),
+            base_class_bind
+        )
+
+    def test_bind_joined_sub_table_joined_sub_class(self):
+        base_class_bind = Mock(name='base')
+        joined_class_bind = Mock(name='joined')
+        session = self._fixture({
+            self.tables.base_table: base_class_bind,
+            self.tables.joined_sub_table: joined_class_bind
+        })
+
+        is_(
+            session.get_bind(self.classes.BaseClass),
+            base_class_bind
+        )
+        # joined table inheritance has to query based on the base
+        # table, so this is what we expect
+        is_(
+            session.get_bind(self.classes.JoinedSubClass),
+            base_class_bind
+        )
+
+    def test_bind_base_table_concrete_sub_class(self):
+        base_class_bind = Mock()
+        session = self._fixture({
+            self.tables.base_table: base_class_bind
+        })
+
+        is_(
+            session.get_bind(self.classes.ConcreteSubClass),
+            testing.db
+        )
+
+    def test_bind_sub_table_concrete_sub_class(self):
+        base_class_bind = Mock(name='base')
+        concrete_sub_bind = Mock(name='concrete')
+
+        session = self._fixture({
+            self.tables.base_table: base_class_bind,
+            self.tables.concrete_sub_table: concrete_sub_bind
+        })
+
+        is_(
+            session.get_bind(self.classes.BaseClass),
+            base_class_bind
+        )
+        is_(
+            session.get_bind(self.classes.ConcreteSubClass),
+            concrete_sub_bind
+        )
+
+    def test_bind_base_class_base_class(self):
+        base_class_bind = Mock()
+        session = self._fixture({
+            self.classes.BaseClass: base_class_bind
+        })
+
+        is_(
+            session.get_bind(self.classes.BaseClass),
+            base_class_bind
+        )
+
+    def test_bind_mixin_class_simple_class(self):
+        base_class_bind = Mock()
+        session = self._fixture({
+            self.classes.MixinOne: base_class_bind
+        })
+
+        is_(
+            session.get_bind(self.classes.ClassWMixin),
+            base_class_bind
+        )
+
+    def test_bind_base_class_joined_sub_class(self):
+        base_class_bind = Mock()
+        session = self._fixture({
+            self.classes.BaseClass: base_class_bind
+        })
+
+        is_(
+            session.get_bind(self.classes.JoinedSubClass),
+            base_class_bind
+        )
+
+    def test_bind_joined_sub_class_joined_sub_class(self):
+        base_class_bind = Mock(name='base')
+        joined_class_bind = Mock(name='joined')
+        session = self._fixture({
+            self.classes.BaseClass: base_class_bind,
+            self.classes.JoinedSubClass: joined_class_bind
+        })
+
+        is_(
+            session.get_bind(self.classes.BaseClass),
+            base_class_bind
+        )
+        is_(
+            session.get_bind(self.classes.JoinedSubClass),
+            joined_class_bind
+        )
+
+    def test_bind_base_class_concrete_sub_class(self):
+        base_class_bind = Mock()
+        session = self._fixture({
+            self.classes.BaseClass: base_class_bind
+        })
+
+        is_(
+            session.get_bind(self.classes.ConcreteSubClass),
+            base_class_bind
+        )
+
+    def test_bind_sub_class_concrete_sub_class(self):
+        base_class_bind = Mock(name='base')
+        concrete_sub_bind = Mock(name='concrete')
+
+        session = self._fixture({
+            self.classes.BaseClass: base_class_bind,
+            self.classes.ConcreteSubClass: concrete_sub_bind
+        })
+
+        is_(
+            session.get_bind(self.classes.BaseClass),
+            base_class_bind
+        )
+        is_(
+            session.get_bind(self.classes.ConcreteSubClass),
+            concrete_sub_bind
+        )
+
+
index 06d1d733424f7696e1a7bf5575e811443e43778d..b0b00d5ed35e4c262b8f05fe63bde1c4c99403a9 100644 (file)
@@ -1403,14 +1403,19 @@ class SessionInterface(fixtures.TestBase):
         eq_(watchdog, instance_methods,
             watchdog.symmetric_difference(instance_methods))
 
-    def _test_class_guards(self, user_arg):
+    def _test_class_guards(self, user_arg, is_class=True):
         watchdog = set()
 
         def raises_(method, *args, **kw):
             watchdog.add(method)
             callable_ = getattr(create_session(), method)
-            assert_raises(sa.orm.exc.UnmappedClassError,
-                              callable_, *args, **kw)
+            if is_class:
+                assert_raises(
+                    sa.orm.exc.UnmappedClassError,
+                    callable_, *args, **kw)
+            else:
+                assert_raises(
+                    sa.exc.NoInspectionAvailable, callable_, *args, **kw)
 
         raises_('connection', mapper=user_arg)
 
@@ -1433,7 +1438,7 @@ class SessionInterface(fixtures.TestBase):
     def test_unmapped_primitives(self):
         for prim in ('doh', 123, ('t', 'u', 'p', 'l', 'e')):
             self._test_instance_guards(prim)
-            self._test_class_guards(prim)
+            self._test_class_guards(prim, is_class=False)
 
     def test_unmapped_class_for_instance(self):
         class Unmapped(object):
@@ -1457,7 +1462,7 @@ class SessionInterface(fixtures.TestBase):
         self._map_it(Mapped)
 
         self._test_instance_guards(early)
-        self._test_class_guards(early)
+        self._test_class_guards(early, is_class=False)
 
 
 class TLTransactionTest(fixtures.MappedTest):