]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Check for Mapping explicitly in 2.0 params
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 4 Nov 2021 21:02:24 +0000 (17:02 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 4 Nov 2021 22:07:21 +0000 (18:07 -0400)
Fixed issue in future :class:`_future.Connection` object where the
:meth:`_future.Connection.execute` method would not accept a non-dict
mapping object, such as SQLAlchemy's own :class:`.RowMapping` or other
``abc.collections.Mapping`` object as a parameter dictionary.

Fixes: #7291
Change-Id: I819f079d86d19d1d81c570e0680f987e51e34b84

doc/build/changelog/unreleased_14/7291.rst [new file with mode: 0644]
lib/sqlalchemy/engine/util.py
test/engine/test_execute.py

diff --git a/doc/build/changelog/unreleased_14/7291.rst b/doc/build/changelog/unreleased_14/7291.rst
new file mode 100644 (file)
index 0000000..add383e
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, engine
+    :tickets: 7291
+
+    Fixed issue in future :class:`_future.Connection` object where the
+    :meth:`_future.Connection.execute` method would not accept a non-dict
+    mapping object, such as SQLAlchemy's own :class:`.RowMapping` or other
+    ``abc.collections.Mapping`` object as a parameter dictionary.
index 4f2e031ab7422b2767d765295f2b9b458f210ff3..8eb0f182085cd7b5f36787c69019cf8654e9e118 100644 (file)
@@ -147,9 +147,9 @@ def _distill_params_20(params):
     elif isinstance(
         params,
         (tuple, dict, immutabledict),
-        # avoid abc.__instancecheck__
-        # (collections_abc.Sequence, collections_abc.Mapping),
-    ):
+        # only do abc.__instancecheck__ for Mapping after we've checked
+        # for plain dictionaries and would otherwise raise
+    ) or isinstance(params, collections_abc.Mapping):
         return (params,), _no_kw
     else:
         raise exc.ArgumentError("mapping or sequence expected for parameters")
index aeb6b484905ccf34a5074a516feb457105cc513c..23df3b03d2fd70637c4c03192bb846217a1f42bb 100644 (file)
@@ -264,6 +264,58 @@ class ExecuteTest(fixtures.TablesTest):
             (4, "sally"),
         ]
 
+    def test_non_dict_mapping(self, connection):
+        """ensure arbitrary Mapping works for execute()"""
+
+        class NotADict(collections_abc.Mapping):
+            def __init__(self, _data):
+                self._data = _data
+
+            def __iter__(self):
+                return iter(self._data)
+
+            def __len__(self):
+                return len(self._data)
+
+            def __getitem__(self, key):
+                return self._data[key]
+
+            def keys(self):
+                return self._data.keys()
+
+        nd = NotADict({"a": 10, "b": 15})
+        eq_(dict(nd), {"a": 10, "b": 15})
+
+        result = connection.execute(
+            select(
+                bindparam("a", type_=Integer), bindparam("b", type_=Integer)
+            ),
+            nd,
+        )
+        eq_(result.first(), (10, 15))
+
+    def test_row_works_as_mapping(self, connection):
+        """ensure the RowMapping object works as a parameter dictionary for
+        execute."""
+
+        result = connection.execute(
+            select(literal(10).label("a"), literal(15).label("b"))
+        )
+        row = result.first()
+        eq_(row, (10, 15))
+        eq_(row._mapping, {"a": 10, "b": 15})
+
+        result = connection.execute(
+            select(
+                bindparam("a", type_=Integer).label("a"),
+                bindparam("b", type_=Integer).label("b"),
+            ),
+            row._mapping,
+        )
+        row = result.first()
+        eq_(row, (10, 15))
+        eq_(row._mapping, {"a": 10, "b": 15})
+
     def test_dialect_has_table_assertion(self):
         with expect_raises_message(
             tsa.exc.ArgumentError,
@@ -3463,6 +3515,58 @@ class FutureExecuteTest(fixtures.FutureEngineMixin, fixtures.TablesTest):
             test_needs_acid=True,
         )
 
+    def test_non_dict_mapping(self, connection):
+        """ensure arbitrary Mapping works for execute()"""
+
+        class NotADict(collections_abc.Mapping):
+            def __init__(self, _data):
+                self._data = _data
+
+            def __iter__(self):
+                return iter(self._data)
+
+            def __len__(self):
+                return len(self._data)
+
+            def __getitem__(self, key):
+                return self._data[key]
+
+            def keys(self):
+                return self._data.keys()
+
+        nd = NotADict({"a": 10, "b": 15})
+        eq_(dict(nd), {"a": 10, "b": 15})
+
+        result = connection.execute(
+            select(
+                bindparam("a", type_=Integer), bindparam("b", type_=Integer)
+            ),
+            nd,
+        )
+        eq_(result.first(), (10, 15))
+
+    def test_row_works_as_mapping(self, connection):
+        """ensure the RowMapping object works as a parameter dictionary for
+        execute."""
+
+        result = connection.execute(
+            select(literal(10).label("a"), literal(15).label("b"))
+        )
+        row = result.first()
+        eq_(row, (10, 15))
+        eq_(row._mapping, {"a": 10, "b": 15})
+
+        result = connection.execute(
+            select(
+                bindparam("a", type_=Integer).label("a"),
+                bindparam("b", type_=Integer).label("b"),
+            ),
+            row._mapping,
+        )
+        row = result.first()
+        eq_(row, (10, 15))
+        eq_(row._mapping, {"a": 10, "b": 15})
+
     @testing.combinations(
         ({}, {}, {}),
         ({"a": "b"}, {}, {"a": "b"}),