]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
use a standard function to check for iterable collections
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 20 Dec 2023 15:56:18 +0000 (10:56 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 20 Dec 2023 20:42:02 +0000 (15:42 -0500)
Fixed 2.0 regression in :class:`.MutableList` where a routine that detects
sequences would not correctly filter out string or bytes instances, making
it impossible to assign a string value to a specific index (while
non-sequence values would work fine).

Fixes: #10784
Change-Id: I829cd2a1ef555184de8e6a752f39df65f69f6943
(cherry picked from commit 99da5ebab36da61b7bfa0b868f50974d6a4c4655)

doc/build/changelog/unreleased_20/10784.rst [new file with mode: 0644]
lib/sqlalchemy/ext/mutable.py
lib/sqlalchemy/sql/coercions.py
lib/sqlalchemy/util/__init__.py
lib/sqlalchemy/util/_collections.py
lib/sqlalchemy/util/typing.py
test/base/test_utils.py
test/ext/test_mutable.py

diff --git a/doc/build/changelog/unreleased_20/10784.rst b/doc/build/changelog/unreleased_20/10784.rst
new file mode 100644 (file)
index 0000000..a67d5b6
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 10784
+
+    Fixed 2.0 regression in :class:`.MutableList` where a routine that detects
+    sequences would not correctly filter out string or bytes instances, making
+    it impossible to assign a string value to a specific index (while
+    non-sequence values would work fine).
index 0f82518aaa176d241c037f9e63898aade601d880..ff4dea086617a3f1a550933c1333ba9d1c6bef50 100644 (file)
@@ -378,6 +378,7 @@ from weakref import WeakKeyDictionary
 from .. import event
 from .. import inspect
 from .. import types
+from .. import util
 from ..orm import Mapper
 from ..orm._typing import _ExternalEntityType
 from ..orm._typing import _O
@@ -909,10 +910,10 @@ class MutableList(Mutable, List[_T]):
         self[:] = state
 
     def is_scalar(self, value: _T | Iterable[_T]) -> TypeGuard[_T]:
-        return not isinstance(value, Iterable)
+        return not util.is_non_string_iterable(value)
 
     def is_iterable(self, value: _T | Iterable[_T]) -> TypeGuard[Iterable[_T]]:
-        return isinstance(value, Iterable)
+        return util.is_non_string_iterable(value)
 
     def __setitem__(
         self, index: SupportsIndex | slice, value: _T | Iterable[_T]
index c4d340713baa42caeaeb39686a0fa45bac4ea575..3926e557a948451ed471e21e15890d0f8bfb58cb 100644 (file)
@@ -851,9 +851,7 @@ class InElementImpl(RoleImpl):
         )
 
     def _literal_coercion(self, element, expr, operator, **kw):
-        if isinstance(element, collections_abc.Iterable) and not isinstance(
-            element, str
-        ):
+        if util.is_non_string_iterable(element):
             non_literal_expressions: Dict[
                 Optional[operators.ColumnOperators],
                 operators.ColumnOperators,
index c804f96887805eed518c99aa1a280bd6038305d1..b60dcf2d948188c2ea2f7108cd9c0ef6b4f75037 100644 (file)
@@ -157,3 +157,4 @@ from .langhelpers import warn_exception as warn_exception
 from .langhelpers import warn_limited as warn_limited
 from .langhelpers import wrap_callable as wrap_callable
 from .preloaded import preload_module as preload_module
+from .typing import is_non_string_iterable as is_non_string_iterable
index a0b1977ee5096594b55d47b699fe54a18f6c3a1f..1e602165c801d0f409a72821f42a7a0576fbc17f 100644 (file)
@@ -9,7 +9,6 @@
 """Collection classes and helpers."""
 from __future__ import annotations
 
-import collections.abc as collections_abc
 import operator
 import threading
 import types
@@ -36,6 +35,7 @@ from typing import ValuesView
 import weakref
 
 from ._has_cy import HAS_CYEXTENSION
+from .typing import is_non_string_iterable
 from .typing import Literal
 from .typing import Protocol
 
@@ -419,9 +419,7 @@ def coerce_generator_arg(arg: Any) -> List[Any]:
 def to_list(x: Any, default: Optional[List[Any]] = None) -> List[Any]:
     if x is None:
         return default  # type: ignore
-    if not isinstance(x, collections_abc.Iterable) or isinstance(
-        x, (str, bytes)
-    ):
+    if not is_non_string_iterable(x):
         return [x]
     elif isinstance(x, list):
         return x
index aad5709451dce41d83262d6702d832f0ca9cbd88..faf71c89a295f916eb7aead6ed2317a15ea998d7 100644 (file)
@@ -9,6 +9,7 @@
 from __future__ import annotations
 
 import builtins
+import collections.abc as collections_abc
 import re
 import sys
 import typing
@@ -296,6 +297,12 @@ def is_pep593(type_: Optional[_AnnotationScanType]) -> bool:
     return type_ is not None and typing_get_origin(type_) is Annotated
 
 
+def is_non_string_iterable(obj: Any) -> TypeGuard[Iterable[Any]]:
+    return isinstance(obj, collections_abc.Iterable) and not isinstance(
+        obj, (str, bytes)
+    )
+
+
 def is_literal(type_: _AnnotationScanType) -> bool:
     return get_origin(type_) is Literal
 
index 7dcf0968a7cd29f1422989ecdfc1100c97fdca03..de8712c852343b54c356695ec7f0413f4ad7bcd3 100644 (file)
@@ -1,4 +1,5 @@
 import copy
+from decimal import Decimal
 import inspect
 from pathlib import Path
 import pickle
@@ -31,6 +32,7 @@ from sqlalchemy.util import classproperty
 from sqlalchemy.util import compat
 from sqlalchemy.util import FastIntFlag
 from sqlalchemy.util import get_callable_argspec
+from sqlalchemy.util import is_non_string_iterable
 from sqlalchemy.util import langhelpers
 from sqlalchemy.util import preloaded
 from sqlalchemy.util import WeakSequence
@@ -1550,6 +1552,30 @@ class HashEqOverride:
             return True
 
 
+class MiscTest(fixtures.TestBase):
+    @testing.combinations(
+        (["one", "two", "three"], True),
+        (("one", "two", "three"), True),
+        ((), True),
+        ("four", False),
+        (252, False),
+        (Decimal("252"), False),
+        (b"four", False),
+        (iter("four"), True),
+        (b"", False),
+        ("", False),
+        (None, False),
+        ({"dict": "value"}, True),
+        ({}, True),
+        ({"set", "two"}, True),
+        (set(), True),
+        (util.immutabledict(), True),
+        (util.immutabledict({"key": "value"}), True),
+    )
+    def test_non_string_iterable_check(self, fixture, expected):
+        is_(is_non_string_iterable(fixture), expected)
+
+
 class IdentitySetTest(fixtures.TestBase):
     obj_type = object
 
index dffdac8d84298fceb2e3fed90b682f4f5baa4aa1..423784777862374d652ffb54c5c62a06c5e77990 100644 (file)
@@ -542,7 +542,7 @@ class _MutableListTestBase(_MutableListTestFixture):
             data={1, 2, 3},
         )
 
-    def test_in_place_mutation(self):
+    def test_in_place_mutation_int(self):
         sess = fixture_session()
 
         f1 = Foo(data=[1, 2])
@@ -554,7 +554,19 @@ class _MutableListTestBase(_MutableListTestFixture):
 
         eq_(f1.data, [3, 2])
 
-    def test_in_place_slice_mutation(self):
+    def test_in_place_mutation_str(self):
+        sess = fixture_session()
+
+        f1 = Foo(data=["one", "two"])
+        sess.add(f1)
+        sess.commit()
+
+        f1.data[0] = "three"
+        sess.commit()
+
+        eq_(f1.data, ["three", "two"])
+
+    def test_in_place_slice_mutation_int(self):
         sess = fixture_session()
 
         f1 = Foo(data=[1, 2, 3, 4])
@@ -566,6 +578,18 @@ class _MutableListTestBase(_MutableListTestFixture):
 
         eq_(f1.data, [1, 5, 6, 4])
 
+    def test_in_place_slice_mutation_str(self):
+        sess = fixture_session()
+
+        f1 = Foo(data=["one", "two", "three", "four"])
+        sess.add(f1)
+        sess.commit()
+
+        f1.data[1:3] = "five", "six"
+        sess.commit()
+
+        eq_(f1.data, ["one", "five", "six", "four"])
+
     def test_del_slice(self):
         sess = fixture_session()
 
@@ -1240,6 +1264,12 @@ class MutableColumnCopyArrayTest(_MutableListTestBase, fixtures.MappedTest):
             __tablename__ = "foo"
             id = Column(Integer, primary_key=True)
 
+    def test_in_place_mutation_str(self):
+        """this test is hardcoded to integer, skip strings"""
+
+    def test_in_place_slice_mutation_str(self):
+        """this test is hardcoded to integer, skip strings"""
+
 
 class MutableListWithScalarPickleTest(
     _MutableListTestBase, fixtures.MappedTest