From: Mike Bayer Date: Wed, 20 Dec 2023 15:56:18 +0000 (-0500) Subject: use a standard function to check for iterable collections X-Git-Tag: rel_2_0_24~5 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=1fad31cb948d5e1b7421e39d84bc18179875fd26;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git use a standard function to check for iterable collections Fixed 2.0 regression in :class:`.MutableList` where a routine that detects sequences would not correctly filter out string or bytes instances, making it impossible to assign a string value to a specific index (while non-sequence values would work fine). Fixes: #10784 Change-Id: I829cd2a1ef555184de8e6a752f39df65f69f6943 (cherry picked from commit 99da5ebab36da61b7bfa0b868f50974d6a4c4655) --- diff --git a/doc/build/changelog/unreleased_20/10784.rst b/doc/build/changelog/unreleased_20/10784.rst new file mode 100644 index 0000000000..a67d5b6392 --- /dev/null +++ b/doc/build/changelog/unreleased_20/10784.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, orm + :tickets: 10784 + + Fixed 2.0 regression in :class:`.MutableList` where a routine that detects + sequences would not correctly filter out string or bytes instances, making + it impossible to assign a string value to a specific index (while + non-sequence values would work fine). diff --git a/lib/sqlalchemy/ext/mutable.py b/lib/sqlalchemy/ext/mutable.py index 0f82518aaa..ff4dea0866 100644 --- a/lib/sqlalchemy/ext/mutable.py +++ b/lib/sqlalchemy/ext/mutable.py @@ -378,6 +378,7 @@ from weakref import WeakKeyDictionary from .. import event from .. import inspect from .. import types +from .. import util from ..orm import Mapper from ..orm._typing import _ExternalEntityType from ..orm._typing import _O @@ -909,10 +910,10 @@ class MutableList(Mutable, List[_T]): self[:] = state def is_scalar(self, value: _T | Iterable[_T]) -> TypeGuard[_T]: - return not isinstance(value, Iterable) + return not util.is_non_string_iterable(value) def is_iterable(self, value: _T | Iterable[_T]) -> TypeGuard[Iterable[_T]]: - return isinstance(value, Iterable) + return util.is_non_string_iterable(value) def __setitem__( self, index: SupportsIndex | slice, value: _T | Iterable[_T] diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index c4d340713b..3926e557a9 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -851,9 +851,7 @@ class InElementImpl(RoleImpl): ) def _literal_coercion(self, element, expr, operator, **kw): - if isinstance(element, collections_abc.Iterable) and not isinstance( - element, str - ): + if util.is_non_string_iterable(element): non_literal_expressions: Dict[ Optional[operators.ColumnOperators], operators.ColumnOperators, diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index c804f96887..b60dcf2d94 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -157,3 +157,4 @@ from .langhelpers import warn_exception as warn_exception from .langhelpers import warn_limited as warn_limited from .langhelpers import wrap_callable as wrap_callable from .preloaded import preload_module as preload_module +from .typing import is_non_string_iterable as is_non_string_iterable diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index a0b1977ee5..1e602165c8 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -9,7 +9,6 @@ """Collection classes and helpers.""" from __future__ import annotations -import collections.abc as collections_abc import operator import threading import types @@ -36,6 +35,7 @@ from typing import ValuesView import weakref from ._has_cy import HAS_CYEXTENSION +from .typing import is_non_string_iterable from .typing import Literal from .typing import Protocol @@ -419,9 +419,7 @@ def coerce_generator_arg(arg: Any) -> List[Any]: def to_list(x: Any, default: Optional[List[Any]] = None) -> List[Any]: if x is None: return default # type: ignore - if not isinstance(x, collections_abc.Iterable) or isinstance( - x, (str, bytes) - ): + if not is_non_string_iterable(x): return [x] elif isinstance(x, list): return x diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index aad5709451..faf71c89a2 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -9,6 +9,7 @@ from __future__ import annotations import builtins +import collections.abc as collections_abc import re import sys import typing @@ -296,6 +297,12 @@ def is_pep593(type_: Optional[_AnnotationScanType]) -> bool: return type_ is not None and typing_get_origin(type_) is Annotated +def is_non_string_iterable(obj: Any) -> TypeGuard[Iterable[Any]]: + return isinstance(obj, collections_abc.Iterable) and not isinstance( + obj, (str, bytes) + ) + + def is_literal(type_: _AnnotationScanType) -> bool: return get_origin(type_) is Literal diff --git a/test/base/test_utils.py b/test/base/test_utils.py index 7dcf0968a7..de8712c852 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -1,4 +1,5 @@ import copy +from decimal import Decimal import inspect from pathlib import Path import pickle @@ -31,6 +32,7 @@ from sqlalchemy.util import classproperty from sqlalchemy.util import compat from sqlalchemy.util import FastIntFlag from sqlalchemy.util import get_callable_argspec +from sqlalchemy.util import is_non_string_iterable from sqlalchemy.util import langhelpers from sqlalchemy.util import preloaded from sqlalchemy.util import WeakSequence @@ -1550,6 +1552,30 @@ class HashEqOverride: return True +class MiscTest(fixtures.TestBase): + @testing.combinations( + (["one", "two", "three"], True), + (("one", "two", "three"), True), + ((), True), + ("four", False), + (252, False), + (Decimal("252"), False), + (b"four", False), + (iter("four"), True), + (b"", False), + ("", False), + (None, False), + ({"dict": "value"}, True), + ({}, True), + ({"set", "two"}, True), + (set(), True), + (util.immutabledict(), True), + (util.immutabledict({"key": "value"}), True), + ) + def test_non_string_iterable_check(self, fixture, expected): + is_(is_non_string_iterable(fixture), expected) + + class IdentitySetTest(fixtures.TestBase): obj_type = object diff --git a/test/ext/test_mutable.py b/test/ext/test_mutable.py index dffdac8d84..4237847778 100644 --- a/test/ext/test_mutable.py +++ b/test/ext/test_mutable.py @@ -542,7 +542,7 @@ class _MutableListTestBase(_MutableListTestFixture): data={1, 2, 3}, ) - def test_in_place_mutation(self): + def test_in_place_mutation_int(self): sess = fixture_session() f1 = Foo(data=[1, 2]) @@ -554,7 +554,19 @@ class _MutableListTestBase(_MutableListTestFixture): eq_(f1.data, [3, 2]) - def test_in_place_slice_mutation(self): + def test_in_place_mutation_str(self): + sess = fixture_session() + + f1 = Foo(data=["one", "two"]) + sess.add(f1) + sess.commit() + + f1.data[0] = "three" + sess.commit() + + eq_(f1.data, ["three", "two"]) + + def test_in_place_slice_mutation_int(self): sess = fixture_session() f1 = Foo(data=[1, 2, 3, 4]) @@ -566,6 +578,18 @@ class _MutableListTestBase(_MutableListTestFixture): eq_(f1.data, [1, 5, 6, 4]) + def test_in_place_slice_mutation_str(self): + sess = fixture_session() + + f1 = Foo(data=["one", "two", "three", "four"]) + sess.add(f1) + sess.commit() + + f1.data[1:3] = "five", "six" + sess.commit() + + eq_(f1.data, ["one", "five", "six", "four"]) + def test_del_slice(self): sess = fixture_session() @@ -1240,6 +1264,12 @@ class MutableColumnCopyArrayTest(_MutableListTestBase, fixtures.MappedTest): __tablename__ = "foo" id = Column(Integer, primary_key=True) + def test_in_place_mutation_str(self): + """this test is hardcoded to integer, skip strings""" + + def test_in_place_slice_mutation_str(self): + """this test is hardcoded to integer, skip strings""" + class MutableListWithScalarPickleTest( _MutableListTestBase, fixtures.MappedTest