]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
use a standard function to check for iterable collections
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 20 Dec 2023 15:56:18 +0000 (10:56 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 20 Dec 2023 20:41:54 +0000 (15:41 -0500)
Fixed 2.0 regression in :class:`.MutableList` where a routine that detects
sequences would not correctly filter out string or bytes instances, making
it impossible to assign a string value to a specific index (while
non-sequence values would work fine).

Fixes: #10784
Change-Id: I829cd2a1ef555184de8e6a752f39df65f69f6943

doc/build/changelog/unreleased_20/10784.rst [new file with mode: 0644]
lib/sqlalchemy/ext/mutable.py
lib/sqlalchemy/sql/coercions.py
lib/sqlalchemy/util/__init__.py
lib/sqlalchemy/util/_collections.py
lib/sqlalchemy/util/typing.py
test/base/test_utils.py
test/ext/test_mutable.py

diff --git a/doc/build/changelog/unreleased_20/10784.rst b/doc/build/changelog/unreleased_20/10784.rst
new file mode 100644 (file)
index 0000000..a67d5b6
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 10784
+
+    Fixed 2.0 regression in :class:`.MutableList` where a routine that detects
+    sequences would not correctly filter out string or bytes instances, making
+    it impossible to assign a string value to a specific index (while
+    non-sequence values would work fine).
index 312f0e49d67713462fd2b826d984922c2cb457ab..bd5820ebdef070f03c752c9bd9029c7b91802d48 100644 (file)
@@ -379,6 +379,7 @@ from weakref import WeakKeyDictionary
 from .. import event
 from .. import inspect
 from .. import types
+from .. import util
 from ..orm import Mapper
 from ..orm._typing import _ExternalEntityType
 from ..orm._typing import _O
@@ -909,10 +910,10 @@ class MutableList(Mutable, List[_T]):
         self[:] = state
 
     def is_scalar(self, value: _T | Iterable[_T]) -> TypeGuard[_T]:
-        return not isinstance(value, Iterable)
+        return not util.is_non_string_iterable(value)
 
     def is_iterable(self, value: _T | Iterable[_T]) -> TypeGuard[Iterable[_T]]:
-        return isinstance(value, Iterable)
+        return util.is_non_string_iterable(value)
 
     def __setitem__(
         self, index: SupportsIndex | slice, value: _T | Iterable[_T]
index c4d340713baa42caeaeb39686a0fa45bac4ea575..3926e557a948451ed471e21e15890d0f8bfb58cb 100644 (file)
@@ -851,9 +851,7 @@ class InElementImpl(RoleImpl):
         )
 
     def _literal_coercion(self, element, expr, operator, **kw):
-        if isinstance(element, collections_abc.Iterable) and not isinstance(
-            element, str
-        ):
+        if util.is_non_string_iterable(element):
             non_literal_expressions: Dict[
                 Optional[operators.ColumnOperators],
                 operators.ColumnOperators,
index caaa657f935fa1faf9e31b97f8d37ca095d13878..91e400e1813b4295663ed5d25cafa5149d51bcd3 100644 (file)
@@ -156,3 +156,4 @@ from .langhelpers import warn_exception as warn_exception
 from .langhelpers import warn_limited as warn_limited
 from .langhelpers import wrap_callable as wrap_callable
 from .preloaded import preload_module as preload_module
+from .typing import is_non_string_iterable as is_non_string_iterable
index 90cfa716e9e713fc5c4becfe0b87b2e52e6b9c9b..bf5d7117db85f436cf66ce33437bd23d5428841d 100644 (file)
@@ -9,7 +9,6 @@
 """Collection classes and helpers."""
 from __future__ import annotations
 
-import collections.abc as collections_abc
 import operator
 import threading
 import types
@@ -37,6 +36,7 @@ from typing import ValuesView
 import weakref
 
 from ._has_cy import HAS_CYEXTENSION
+from .typing import is_non_string_iterable
 from .typing import Literal
 
 if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
@@ -419,9 +419,7 @@ def coerce_generator_arg(arg: Any) -> List[Any]:
 def to_list(x: Any, default: Optional[List[Any]] = None) -> List[Any]:
     if x is None:
         return default  # type: ignore
-    if not isinstance(x, collections_abc.Iterable) or isinstance(
-        x, (str, bytes)
-    ):
+    if not is_non_string_iterable(x):
         return [x]
     elif isinstance(x, list):
         return x
index d13859309afba00e5f89744cd07945061118fe6a..c4f41d9151825518b6fb38fea7bdc05621a8910f 100644 (file)
@@ -9,6 +9,7 @@
 from __future__ import annotations
 
 import builtins
+import collections.abc as collections_abc
 import re
 import sys
 from typing import Any
@@ -293,6 +294,12 @@ def is_pep593(type_: Optional[_AnnotationScanType]) -> bool:
     return type_ is not None and typing_get_origin(type_) is Annotated
 
 
+def is_non_string_iterable(obj: Any) -> TypeGuard[Iterable[Any]]:
+    return isinstance(obj, collections_abc.Iterable) and not isinstance(
+        obj, (str, bytes)
+    )
+
+
 def is_literal(type_: _AnnotationScanType) -> bool:
     return get_origin(type_) is Literal
 
index 7dcf0968a7cd29f1422989ecdfc1100c97fdca03..de8712c852343b54c356695ec7f0413f4ad7bcd3 100644 (file)
@@ -1,4 +1,5 @@
 import copy
+from decimal import Decimal
 import inspect
 from pathlib import Path
 import pickle
@@ -31,6 +32,7 @@ from sqlalchemy.util import classproperty
 from sqlalchemy.util import compat
 from sqlalchemy.util import FastIntFlag
 from sqlalchemy.util import get_callable_argspec
+from sqlalchemy.util import is_non_string_iterable
 from sqlalchemy.util import langhelpers
 from sqlalchemy.util import preloaded
 from sqlalchemy.util import WeakSequence
@@ -1550,6 +1552,30 @@ class HashEqOverride:
             return True
 
 
+class MiscTest(fixtures.TestBase):
+    @testing.combinations(
+        (["one", "two", "three"], True),
+        (("one", "two", "three"), True),
+        ((), True),
+        ("four", False),
+        (252, False),
+        (Decimal("252"), False),
+        (b"four", False),
+        (iter("four"), True),
+        (b"", False),
+        ("", False),
+        (None, False),
+        ({"dict": "value"}, True),
+        ({}, True),
+        ({"set", "two"}, True),
+        (set(), True),
+        (util.immutabledict(), True),
+        (util.immutabledict({"key": "value"}), True),
+    )
+    def test_non_string_iterable_check(self, fixture, expected):
+        is_(is_non_string_iterable(fixture), expected)
+
+
 class IdentitySetTest(fixtures.TestBase):
     obj_type = object
 
index dffdac8d84298fceb2e3fed90b682f4f5baa4aa1..423784777862374d652ffb54c5c62a06c5e77990 100644 (file)
@@ -542,7 +542,7 @@ class _MutableListTestBase(_MutableListTestFixture):
             data={1, 2, 3},
         )
 
-    def test_in_place_mutation(self):
+    def test_in_place_mutation_int(self):
         sess = fixture_session()
 
         f1 = Foo(data=[1, 2])
@@ -554,7 +554,19 @@ class _MutableListTestBase(_MutableListTestFixture):
 
         eq_(f1.data, [3, 2])
 
-    def test_in_place_slice_mutation(self):
+    def test_in_place_mutation_str(self):
+        sess = fixture_session()
+
+        f1 = Foo(data=["one", "two"])
+        sess.add(f1)
+        sess.commit()
+
+        f1.data[0] = "three"
+        sess.commit()
+
+        eq_(f1.data, ["three", "two"])
+
+    def test_in_place_slice_mutation_int(self):
         sess = fixture_session()
 
         f1 = Foo(data=[1, 2, 3, 4])
@@ -566,6 +578,18 @@ class _MutableListTestBase(_MutableListTestFixture):
 
         eq_(f1.data, [1, 5, 6, 4])
 
+    def test_in_place_slice_mutation_str(self):
+        sess = fixture_session()
+
+        f1 = Foo(data=["one", "two", "three", "four"])
+        sess.add(f1)
+        sess.commit()
+
+        f1.data[1:3] = "five", "six"
+        sess.commit()
+
+        eq_(f1.data, ["one", "five", "six", "four"])
+
     def test_del_slice(self):
         sess = fixture_session()
 
@@ -1240,6 +1264,12 @@ class MutableColumnCopyArrayTest(_MutableListTestBase, fixtures.MappedTest):
             __tablename__ = "foo"
             id = Column(Integer, primary_key=True)
 
+    def test_in_place_mutation_str(self):
+        """this test is hardcoded to integer, skip strings"""
+
+    def test_in_place_slice_mutation_str(self):
+        """this test is hardcoded to integer, skip strings"""
+
 
 class MutableListWithScalarPickleTest(
     _MutableListTestBase, fixtures.MappedTest