]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Consult plugin_subject for non-ORM enabled stmts in get_bind()
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 14 May 2021 14:29:55 +0000 (10:29 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 14 May 2021 15:49:01 +0000 (11:49 -0400)
Enhanced the bind resolution rules for :meth:`_orm.Session.execute` so that
when a non-ORM statement such as an :func:`_sql.insert` construct
nonetheless is built against ORM objects, to the greatest degree possible
the ORM entity will be used to resolve the bind, such as for a
:class:`_orm.Session` that has a bind map set up on a common superclass
without specific mappers or tables named in the map.

Fixes: #6484
Change-Id: Iaa711b7f2c7451377b50af63029f37c4375a6f7e

doc/build/changelog/unreleased_14/6484.rst [new file with mode: 0644]
lib/sqlalchemy/orm/session.py
test/orm/test_bind.py

diff --git a/doc/build/changelog/unreleased_14/6484.rst b/doc/build/changelog/unreleased_14/6484.rst
new file mode 100644 (file)
index 0000000..28106dd
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 6484
+
+    Enhanced the bind resolution rules for :meth:`_orm.Session.execute` so that
+    when a non-ORM statement such as an :func:`_sql.insert` construct
+    nonetheless is built against ORM objects, to the greatest degree possible
+    the ORM entity will be used to resolve the bind, such as for a
+    :class:`_orm.Session` that has a bind map set up on a common superclass
+    without specific mappers or tables named in the map.
index cdf3a158565b1310e9715cc61495c2a340298d34..e4a9b90463315a0efe2a14e5c515e757b963c560 100644 (file)
@@ -1997,6 +1997,15 @@ class Session(_SessionClassMethods):
                     clause = mapper.persist_selectable
 
             if clause is not None:
+                plugin_subject = clause._propagate_attrs.get(
+                    "plugin_subject", None
+                )
+
+                if plugin_subject is not None:
+                    for cls in plugin_subject.mapper.class_.__mro__:
+                        if cls in self.__binds:
+                            return self.__binds[cls]
+
                 for obj in visitors.iterate(clause):
                     if obj in self.__binds:
                         return self.__binds[obj]
index 014fa152e9071dbe47ad72cd6ff3518abdae88df..5c7f3f72e5d200b00234ce7c98d55d8e7c4b8842 100644 (file)
@@ -1,5 +1,7 @@
 import sqlalchemy as sa
+from sqlalchemy import delete
 from sqlalchemy import ForeignKey
+from sqlalchemy import insert
 from sqlalchemy import inspect
 from sqlalchemy import Integer
 from sqlalchemy import MetaData
@@ -7,6 +9,8 @@ from sqlalchemy import select
 from sqlalchemy import table
 from sqlalchemy import testing
 from sqlalchemy import true
+from sqlalchemy import update
+from sqlalchemy.orm import aliased
 from sqlalchemy.orm import backref
 from sqlalchemy.orm import mapper
 from sqlalchemy.orm import relationship
@@ -734,3 +738,23 @@ class GetBindTest(fixtures.MappedTest):
             select(self.tables.concrete_sub_table)
         )
         is_(session.get_bind(clause=stmt), base_class_bind)
+
+    @testing.combinations(
+        (insert,),
+        (update,),
+        (delete,),
+        (select,),
+    )
+    def test_clause_extracts_orm_plugin_subject(self, sql_elem):
+        ClassWMixin = self.classes.ClassWMixin
+        MixinOne = self.classes.MixinOne
+        base_class_bind = Mock()
+
+        session = self._fixture({MixinOne: base_class_bind})
+
+        stmt = sql_elem(ClassWMixin)
+        is_(session.get_bind(clause=stmt), base_class_bind)
+
+        cwm_alias = aliased(ClassWMixin)
+        stmt = sql_elem(cwm_alias)
+        is_(session.get_bind(clause=stmt), base_class_bind)