]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
work around Python 3.11 IntEnum issue; update FastIntFlag
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 9 Nov 2022 23:41:54 +0000 (18:41 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 10 Nov 2022 17:19:02 +0000 (12:19 -0500)
in [1], Python 3.11 seems to have changed the behavior of
IntEnum.  We didn't notice this because we have our own
workaround class already, but typing did.   Ensure we remain
compatible with IntFlag.

This change also modifies FastIntFlag to no longer use
global symbols; this is unnecessary as we assign FastIntFlag
members explicitly.  Use of ``symbol()`` should probably
be phased out.

[1] https://github.com/python/cpython/issues/99304
Fixes: #8783
Change-Id: I8ae2e871ff1467ae5ca1f63e66b5dae45d4a6c93

doc/build/changelog/unreleased_20/8783.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/psycopg2.py
lib/sqlalchemy/orm/base.py
lib/sqlalchemy/orm/collections.py
lib/sqlalchemy/util/langhelpers.py
test/base/test_utils.py

diff --git a/doc/build/changelog/unreleased_20/8783.rst b/doc/build/changelog/unreleased_20/8783.rst
new file mode 100644 (file)
index 0000000..1462cf7
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug
+    :tickets: 8783
+
+    Adjusted internal use of the Python ``IntFlag`` class which changed its
+    behavioral contract in Python 3.11. This was not causing runtime failures
+    however caused typing runs to fail under Python 3.11.
index 4a8df5b5e8961976d6d530cc28d4b05e533d6f25..cb5cab178ea98f4870a87e7468f35faeca9ce2b2 100644 (file)
@@ -583,7 +583,7 @@ class ExecutemanyMode(FastIntFlag):
 (
     EXECUTEMANY_VALUES,
     EXECUTEMANY_VALUES_PLUS_BATCH,
-) = tuple(ExecutemanyMode)
+) = ExecutemanyMode.__members__.values()
 
 
 class PGDialect_psycopg2(_PGDialect_common_psycopg):
index e4a69a352ae33f1c07ef423130294d3c9d331dd2..b46c787996e4a729bff6a5e27d8778a029157421 100644 (file)
@@ -188,7 +188,7 @@ class PassiveFlag(FastIntFlag):
     PASSIVE_NO_FETCH,
     PASSIVE_NO_FETCH_RELATED,
     PASSIVE_ONLY_PERSISTENT,
-) = tuple(PassiveFlag)
+) = PassiveFlag.__members__.values()
 
 DEFAULT_MANAGER_ATTR = "_sa_class_manager"
 DEFAULT_STATE_ATTR = "_sa_instance_state"
index e3051e268f665f026675e7d4c3bfabfc67488226..0c1ccbf10990ea97fdb9a440ac9c5e2392ffeb17 100644 (file)
@@ -128,6 +128,7 @@ import weakref
 from .base import NO_KEY
 from .. import exc as sa_exc
 from .. import util
+from ..sql.base import NO_ARG
 from ..util.compat import inspect_getfullargspec
 from ..util.typing import Protocol
 
@@ -1222,8 +1223,6 @@ def _dict_decorators() -> Dict[str, Callable[[_FN], _FN]]:
         fn._sa_instrumented = True
         fn.__doc__ = getattr(dict, fn.__name__).__doc__
 
-    Unspecified = util.symbol("Unspecified")
-
     def __setitem__(fn):
         def __setitem__(self, key, value, _sa_initiator=None):
             if key in self:
@@ -1253,10 +1252,10 @@ def _dict_decorators() -> Dict[str, Callable[[_FN], _FN]]:
         return clear
 
     def pop(fn):
-        def pop(self, key, default=Unspecified):
+        def pop(self, key, default=NO_ARG):
             __before_pop(self)
             _to_del = key in self
-            if default is Unspecified:
+            if default is NO_ARG:
                 item = fn(self, key)
             else:
                 item = fn(self, key, default)
@@ -1293,8 +1292,8 @@ def _dict_decorators() -> Dict[str, Callable[[_FN], _FN]]:
         return setdefault
 
     def update(fn):
-        def update(self, __other=Unspecified, **kw):
-            if __other is not Unspecified:
+        def update(self, __other=NO_ARG, **kw):
+            if __other is not NO_ARG:
                 if hasattr(__other, "keys"):
                     for key in list(__other):
                         if key not in self or self[key] is not __other[key]:
@@ -1318,7 +1317,6 @@ def _dict_decorators() -> Dict[str, Callable[[_FN], _FN]]:
 
     l = locals().copy()
     l.pop("_tidy")
-    l.pop("Unspecified")
     return l
 
 
@@ -1346,8 +1344,6 @@ def _set_decorators() -> Dict[str, Callable[[_FN], _FN]]:
         fn._sa_instrumented = True
         fn.__doc__ = getattr(set, fn.__name__).__doc__
 
-    Unspecified = util.symbol("Unspecified")
-
     def add(fn):
         def add(self, value, _sa_initiator=None):
             if value not in self:
@@ -1500,7 +1496,6 @@ def _set_decorators() -> Dict[str, Callable[[_FN], _FN]]:
 
     l = locals().copy()
     l.pop("_tidy")
-    l.pop("Unspecified")
     return l
 
 
index d8d39f56c6200ec633175b81cc6de0283d0bb029..051a8c89e10dc649097cd3c885db2004b5f6ca19 100644 (file)
@@ -1603,11 +1603,6 @@ class symbol(int):
 
     Repeated calls of symbol('name') will all return the same instance.
 
-    In SQLAlchemy 2.0, symbol() is used for the implementation of
-    ``_FastIntFlag``, but otherwise should be mostly replaced by
-    ``enum.Enum`` and variants.
-
-
     """
 
     name: str
@@ -1632,7 +1627,17 @@ class symbol(int):
                 if doc:
                     sym.__doc__ = doc
 
+                # NOTE: we should ultimately get rid of this global thing,
+                # however, currently it is to support pickling.  The best
+                # change would be when we are on py3.11 at a minimum, we
+                # switch to stdlib enum.IntFlag.
                 cls.symbols[name] = sym
+            else:
+                if canonical and canonical != sym:
+                    raise TypeError(
+                        f"Can't replace canonical symbol for {name} "
+                        f"with new int value {canonical}"
+                    )
             return sym
 
     def __reduce__(self):
@@ -1665,8 +1670,16 @@ class _IntFlagMeta(type):
             setattr(cls, k, sym)
             items.append(sym)
 
+        cls.__members__ = _collections.immutabledict(
+            {sym.name: sym for sym in items}
+        )
+
     def __iter__(self) -> Iterator[symbol]:
-        return iter(self._items)
+        raise NotImplementedError(
+            "iter not implemented to ensure compatibility with "
+            "Python 3.11 IntFlag.  Please use __members__.  See "
+            "https://github.com/python/cpython/issues/99304"
+        )
 
 
 class _FastIntFlag(metaclass=_IntFlagMeta):
index 0ef86545b53cb93bc011b4a69fd5a1f413e3c198..098652928fa787e8f3812407e7df6e64f0078e9a 100644 (file)
@@ -2346,17 +2346,92 @@ class SymbolTest(fixtures.TestBase):
 
     def test_fast_int_flag(self):
         class Enum(FastIntFlag):
-            sym1 = 1
-            sym2 = 2
+            fi_sym1 = 1
+            fi_sym2 = 2
+
+            fi_sym3 = 3
+
+        assert Enum.fi_sym1 is not Enum.fi_sym3
+        assert Enum.fi_sym1 != Enum.fi_sym3
+
+        assert Enum.fi_sym1.name == "fi_sym1"
+
+        # modified for #8783
+        eq_(
+            list(Enum.__members__.values()),
+            [Enum.fi_sym1, Enum.fi_sym2, Enum.fi_sym3],
+        )
+
+    def test_fast_int_flag_still_global(self):
+        """FastIntFlag still causes elements to be global symbols.
+
+        This is to support pickling.  There are likely other ways to
+        achieve this, however this is what we have for now.
+
+        """
+
+        class Enum1(FastIntFlag):
+            fi_sym1 = 1
+            fi_sym2 = 2
+
+        class Enum2(FastIntFlag):
+            fi_sym1 = 1
+            fi_sym2 = 2
+
+        # they are global
+        assert Enum1.fi_sym1 is Enum2.fi_sym1
+
+    def test_fast_int_flag_dont_allow_conflicts(self):
+        """FastIntFlag still causes elements to be global symbols.
+
+        While we do this and haven't yet changed it, make sure conflicting
+        int values for the same name don't come in.
+
+        """
 
-            sym3 = 3
+        class Enum1(FastIntFlag):
+            fi_sym1 = 1
+            fi_sym2 = 2
 
-        assert Enum.sym1 is not Enum.sym3
-        assert Enum.sym1 != Enum.sym3
+        with expect_raises_message(
+            TypeError,
+            "Can't replace canonical symbol for fi_sym1 with new int value 2",
+        ):
+
+            class Enum2(FastIntFlag):
+                fi_sym1 = 2
+                fi_sym2 = 3
+
+    @testing.combinations("native", "ours", argnames="native")
+    def test_compare_to_native_py_intflag(self, native):
+        """monitor IntFlag behavior in upstream Python for #8783"""
+
+        if native == "native":
+            from enum import IntFlag
+        else:
+            from sqlalchemy.util import FastIntFlag as IntFlag
+
+        class Enum(IntFlag):
+            fi_sym1 = 1
+            fi_sym2 = 2
+            fi_sym4 = 4
+
+            fi_sym1plus2 = 3
 
-        assert Enum.sym1.name == "sym1"
+            # not an alias because there's no 16
+            fi_sym17 = 17
 
-        eq_(list(Enum), [Enum.sym1, Enum.sym2, Enum.sym3])
+        sym1, sym2, sym4, sym1plus2, sym17 = Enum.__members__.values()
+        eq_(
+            [sym1, sym2, sym4, sym1plus2, sym17],
+            [
+                Enum.fi_sym1,
+                Enum.fi_sym2,
+                Enum.fi_sym4,
+                Enum.fi_sym1plus2,
+                Enum.fi_sym17,
+            ],
+        )
 
     def test_pickle(self):
         sym1 = util.symbol("foo")
@@ -2395,6 +2470,20 @@ class SymbolTest(fixtures.TestBase):
         assert not (sym1 | sym2) & (sym3 | sym4)
         assert (sym1 | sym2) & (sym2 | sym4)
 
+    def test_fast_int_flag_no_more_iter(self):
+        """test #8783"""
+
+        class MyEnum(FastIntFlag):
+            sym1 = 1
+            sym2 = 2
+            sym3 = 4
+            sym4 = 8
+
+        with expect_raises_message(
+            NotImplementedError, "iter not implemented to ensure compatibility"
+        ):
+            list(MyEnum)
+
     def test_parser(self):
         class MyEnum(FastIntFlag):
             sym1 = 1
@@ -2402,7 +2491,7 @@ class SymbolTest(fixtures.TestBase):
             sym3 = 4
             sym4 = 8
 
-        sym1, sym2, sym3, sym4 = tuple(MyEnum)
+        sym1, sym2, sym3, sym4 = tuple(MyEnum.__members__.values())
         lookup_one = {sym1: [], sym2: [True], sym3: [False], sym4: [None]}
         lookup_two = {sym1: [], sym2: [True], sym3: [False]}
         lookup_three = {sym1: [], sym2: ["symbol2"], sym3: []}