]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
improved immutabledict merge_with and union
authorLucasMalor <249807577+LucasMalor@users.noreply.github.com>
Tue, 23 Dec 2025 16:28:35 +0000 (11:28 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 1 Jan 2026 18:42:11 +0000 (13:42 -0500)
Now `immutabledict.merge_with is an alias of `immutabledict.union`,
both accept multiple arguments.
The methods now avoid doing copies of not required: if the method is
called only one `immutabledict` that's not empty it's returned.

Fixes: #13043
Closes: #13042
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/13042
Pull-request-sha: bd53e488432edd5986c28f196b0363b976b26b04

Change-Id: I8078f239e1ca36994b488b15f2fac40facf7f249

lib/sqlalchemy/ext/asyncio/engine.py
lib/sqlalchemy/orm/context.py
lib/sqlalchemy/orm/loading.py
lib/sqlalchemy/orm/strategies.py
lib/sqlalchemy/sql/annotation.py
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/util/_immutabledict_cy.py
test/base/test_utils.py
test/perf/compiled_extensions/collections_.py
test/profiles.txt

index 69b194db2dcbd4e68cd8e6b51692e704d5d9b3ca..02d25689aaf9da479edce861d06d5694172eeb03 100644 (file)
@@ -41,6 +41,7 @@ from ...engine import Engine
 from ...engine.base import NestedTransaction
 from ...engine.base import Transaction
 from ...exc import ArgumentError
+from ...util import immutabledict
 from ...util.concurrency import greenlet_spawn
 from ...util.typing import TupleAny
 from ...util.typing import TypeVarTuple
@@ -68,6 +69,7 @@ if TYPE_CHECKING:
 _P = ParamSpec("_P")
 _T = TypeVar("_T", bound=Any)
 _Ts = TypeVarTuple("_Ts")
+_stream_results = immutabledict(stream_results=True)
 
 
 def create_async_engine(url: Union[str, URL], **kw: Any) -> AsyncEngine:
@@ -590,7 +592,7 @@ class AsyncConnection(  # type:ignore[misc]
             statement,
             parameters,
             execution_options=util.EMPTY_DICT.merge_with(
-                execution_options, {"stream_results": True}
+                execution_options, _stream_results
             ),
             _require_await=True,
         )
index 8abb20e12702bf93f778202a8f14c54a891992d2..03366bc9eb29c5d1e2036aec8fc49da316bc9344 100644 (file)
@@ -98,9 +98,6 @@ _T = TypeVar("_T", bound=Any)
 _Ts = TypeVarTuple("_Ts")
 _path_registry = PathRegistry.root
 
-_EMPTY_DICT = util.immutabledict()
-
-
 LABEL_STYLE_LEGACY_ORM = SelectLabelStyle.LABEL_STYLE_LEGACY_ORM
 
 
@@ -173,8 +170,8 @@ class QueryContext:
         bind_arguments: Optional[_BindArguments] = None,
     ):
         self.load_options = load_options
-        self.execution_options = execution_options or _EMPTY_DICT
-        self.bind_arguments = bind_arguments or _EMPTY_DICT
+        self.execution_options = execution_options or util.EMPTY_DICT
+        self.bind_arguments = bind_arguments or util.EMPTY_DICT
         self.compile_state = compile_state
         self.query = statement
 
@@ -783,8 +780,8 @@ class _ORMFromStatementCompileState(_ORMCompileState):
     eager_adding_joins = False
     compound_eager_adapter = None
 
-    extra_criteria_entities = _EMPTY_DICT
-    eager_joins = _EMPTY_DICT
+    extra_criteria_entities = util.EMPTY_DICT
+    eager_joins = util.EMPTY_DICT
 
     @classmethod
     def _create_orm_context(
@@ -1088,7 +1085,7 @@ class _CompoundSelectCompileState(
 class _ORMSelectCompileState(_ORMCompileState, SelectState):
     _already_joined_edges = ()
 
-    _memoized_entities = _EMPTY_DICT
+    _memoized_entities = util.EMPTY_DICT
 
     _from_obj_alias = None
     _has_mapper_entities = False
@@ -1128,7 +1125,7 @@ class _ORMSelectCompileState(_ORMCompileState, SelectState):
             # query, and at the moment subqueryloader is putting some things
             # in here that we explicitly don't want stuck in a cache.
             self.select_statement = select_statement._clone()
-            self.select_statement._execution_options = util.immutabledict()
+            self.select_statement._execution_options = util.EMPTY_DICT
         else:
             self.select_statement = select_statement
 
index ad874272db9e6528603eef406aa4d0a48a17d38a..a8a799ccd55cf7f9310a35f6d2fcd2c9fb6d079e 100644 (file)
@@ -686,7 +686,8 @@ def _load_on_pk_identity(
         load_options += {"_autoflush": False}
 
     execution_options = util.EMPTY_DICT.merge_with(
-        execution_options, {"_sa_orm_load_options": load_options}
+        execution_options,
+        util.immutabledict(_sa_orm_load_options=load_options),
     )
     result = (
         session.execute(
index a529c4196f3b8e24b8d64fcd6dad96022e7bff2a..bfe49a302359e47bf11c3f51a717520a571f02d8 100644 (file)
@@ -1127,10 +1127,7 @@ class _LazyLoader(
 
         if execution_options:
             execution_options = util.EMPTY_DICT.merge_with(
-                execution_options,
-                {
-                    "_sa_orm_load_options": load_options,
-                },
+                execution_options, {"_sa_orm_load_options": load_options}
             )
         else:
             execution_options = {
index fe951e74c0025771489f9e4c2692e720695fb3a4..3a79ec19bd58f5e54e5c00900e2715b7d44e13e2 100644 (file)
@@ -240,7 +240,7 @@ class SupportsCloneAnnotations(SupportsWrappingAnnotations):
             # clone is used when we are also copying
             # the expression for a deep deannotation
             new = self._clone()
-            new._annotations = util.immutabledict()
+            new._annotations = util.EMPTY_DICT
             new.__dict__.pop("_annotations_cache_key", None)
             return new
         else:
index 94d525ab64e0e2d4df0e5bd7c4a03aa619f92f1f..3c1a6714fa941e1c410938bfd1b7c22f2e8048f4 100644 (file)
@@ -979,7 +979,7 @@ class Options(metaclass=_MetaOptions):
                     result[local] = statement_exec_options[argname]
 
             new_options = existing_options + result
-            exec_options = util.immutabledict().merge_with(
+            exec_options = util.EMPTY_DICT.merge_with(
                 exec_options, {key: new_options}
             )
             return new_options, exec_options
index 9a32f19a76072bab2e58a52576d563b1f7a9aa14..eb61a9f09b27f8614c2de553a125e856b409f749 100644 (file)
@@ -4,12 +4,13 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: disable-error-code="misc, arg-type, untyped-decorator"
+# mypy: disable-error-code="misc, arg-type, type-arg, untyped-decorator"
 from __future__ import annotations
 
 from typing import Any
 from typing import Dict
 from typing import Hashable
+from typing import Literal
 from typing import Mapping
 from typing import NoReturn
 from typing import Optional
@@ -147,41 +148,72 @@ class immutabledict(Dict[_KT, _VT]):
 
     @cython.annotation_typing(False)  # avoid cython crash from generic return
     def union(
-        self, other: Optional[Mapping[_KT, _VT]] = None, /
+        self, *dicts: Optional[Mapping[_KT, _VT]]
     ) -> immutabledict[_KT, _VT]:
-        if not other:
-            return self
-        # new + update is faster than immutabledict(self)
-        result: immutabledict = immutabledict()  # type: ignore[type-arg]
-        PyDict_Update(result, self)
-        if isinstance(other, dict):
-            # c version of PyDict_Update supports only dicts
-            PyDict_Update(result, other)
-        else:
-            dict.update(result, other)
-        return result
+        return self._union_other(dicts)  # type: ignore[no-any-return]
 
     @cython.annotation_typing(False)  # avoid cython crash from generic return
     def merge_with(
         self, *dicts: Optional[Mapping[_KT, _VT]]
     ) -> immutabledict[_KT, _VT]:
-        result: Optional[immutabledict] = None  # type: ignore[type-arg]
-        d: object
-        if not dicts:
+        # this is an alias of union
+        return self._union_other(dicts)  # type: ignore[no-any-return]
+
+    @cython.cfunc
+    @cython.inline
+    def _union_other(self, others: tuple) -> immutabledict:
+        size = len(others)
+        if size == 0:
+            return self
+
+        # only_one == immutabledict : we found exactly one immutabledict that
+        # has contents; no other dict / immutabledict has any contents
+        #
+        # only_one is None : we found more than one dict / immutabledict that
+        # has contents
+        #
+        # only_one is False : we've found nothing that is not an empty
+        # immutabledict
+        only_one: immutabledict | None | Literal[False]
+
+        if self:
+            self_is_empty = False
+            only_one = self
+        else:
+            only_one = False
+            self_is_empty = True
+
+        for i in range(size):
+            d = others[i]
+            if not d:
+                continue
+
+            if only_one is False and isinstance(d, immutabledict):
+                only_one = d
+            else:
+                only_one = None
+                break
+
+        if only_one is False:
             return self
-        for d in dicts:
-            if d is not None and len(d) > 0:
-                if result is None:
-                    # new + update is faster than immutabledict(self)
-                    result = immutabledict()
-                    PyDict_Update(result, self)
-                if isinstance(d, dict):
-                    # c version of PyDict_Update supports only dicts
-                    PyDict_Update(result, d)
-                else:
-                    dict.update(result, d)
-
-        return self if result is None else result
+        elif only_one is not None:
+            return only_one
+
+        result: immutabledict = immutabledict()
+        if not self_is_empty:
+            PyDict_Update(result, self)
+
+        for i in range(size):
+            d = others[i]
+            if not d:
+                continue
+            if isinstance(d, dict):
+                # c version of PyDict_Update supports only dicts
+                PyDict_Update(result, d)
+            else:
+                dict.update(result, d)
+
+        return result
 
     def copy(self) -> Self:
         return self
index b228bcc580183b838ffee4a9aac6ccee0b430a72..391bbf75d6dcfc05d82b06c9f6fecb14a43ebdf6 100644 (file)
@@ -22,6 +22,7 @@ from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_false
 from sqlalchemy.testing import is_instance_of
 from sqlalchemy.testing import is_none
+from sqlalchemy.testing import is_not
 from sqlalchemy.testing import is_true
 from sqlalchemy.testing import mock
 from sqlalchemy.testing import ne_
@@ -360,50 +361,150 @@ class OrderedSetTest(fixtures.TestBase):
 
 
 class ImmutableDictTest(fixtures.TestBase):
-    def test_union_no_change(self):
-        d = util.immutabledict({1: 2, 3: 4})
+    methods = combinations(
+        util.immutabledict.union,
+        util.immutabledict.merge_with,
+        argnames="method",
+    )
 
-        d2 = d.union({})
+    @methods
+    def test_no_change(self, method):
+        d = util.immutabledict({1: 2, 3: 4})
 
+        d2 = method(d)
+        is_(d2, d)
+        d2 = method(d, {})
+        is_(d2, d)
+        d2 = method(d, None)
+        is_(d2, d)
+        d2 = method(d, {}, {}, {}, None)
         is_(d2, d)
 
-    def test_merge_with_no_change(self):
+    @methods
+    def test_no_change_self_empty(self, method):
         d = util.immutabledict({1: 2, 3: 4})
+        e = util.immutabledict()
+        d2 = method(e, d)
 
-        d2 = d.merge_with({}, None)
+        eq_(e, {})
+        is_(d2, d)
 
-        eq_(d2, {1: 2, 3: 4})
+        d2 = method(e, {}, d)
+        is_(d2, d)
+        d2 = method(e, None, d, {}, {})
         is_(d2, d)
 
-    def test_merge_with_dicts(self):
-        d = util.immutabledict({1: 2, 3: 4})
+        d2 = method(e, {1: 2, 3: 4})
+
+        eq_(d2, d)
+        assert isinstance(d2, util.immutabledict)
 
-        d2 = d.merge_with({3: 5, 7: 12}, {9: 18, 15: 25})
+        d2 = method(e, {1: 2, 3: 4}, {3: 5, 4: 7})
 
-        eq_(d, {1: 2, 3: 4})
-        eq_(d2, {1: 2, 3: 5, 7: 12, 9: 18, 15: 25})
+        eq_(d2, {1: 2, 3: 5, 4: 7})
         assert isinstance(d2, util.immutabledict)
 
-        d3 = d.merge_with({17: 42})
+    @methods
+    def test_start_empty_but_then_populate(self, method):
+        d = util.immutabledict()
+
+        d2 = method(d, {1: 2})
+        eq_(d2, {1: 2})
+        is_not(d2, d)
+
+        d3 = method(d, util.immutabledict(), {1: 2})
+        eq_(d3, {1: 2})
+
+        d4 = method(
+            d, util.immutabledict(), util.immutabledict({1: 2}), {3: 4}
+        )
+        eq_(d4, {1: 2, 3: 4})
+
+    @methods
+    def test_no_change_everyone_empty(self, method):
+        d = util.immutabledict()
+        e = util.immutabledict()
+        d2 = method(e, d)
+
+        eq_(e, {})
+        is_(d2, e)
+
+        f = {}
+
+        d3 = method(e, d, f)
+        eq_(e, {})
+        is_(d3, e)
+
+        g = util.immutabledict()
+        d4 = method(e, d, f, g)
+        eq_(e, {})
+        is_(d4, e)
+
+    @methods
+    def test_no_change_against_self(self, method):
+        d = util.immutabledict()
+        e = d
+        d2 = method(e, d)
+
+        eq_(e, {})
+        is_(d2, e)
+
+        f = d
 
-        eq_(d3, {1: 2, 3: 4, 17: 42})
+        d3 = method(e, d, f)
+        eq_(e, {})
+        is_(d3, e)
 
-    def test_merge_with_tuples(self):
+    @methods
+    def test_multiple_dicts(self, method):
         d = util.immutabledict({1: 2, 3: 4})
 
-        d2 = d.merge_with([(3, 5), (7, 12)], [(9, 18), (15, 25)])
+        d2 = method(d, {17: 42})
 
         eq_(d, {1: 2, 3: 4})
-        eq_(d2, {1: 2, 3: 5, 7: 12, 9: 18, 15: 25})
+        eq_(d2, {1: 2, 3: 4, 17: 42})
+
+        d3 = method(d, {3: 5, 7: 12}, {9: 18, 15: 25}, None)
 
-    def test_union_dictionary(self):
+        eq_(d3, {1: 2, 3: 5, 7: 12, 9: 18, 15: 25})
+        assert isinstance(d3, util.immutabledict)
+
+    @methods
+    def test_multiple_immutabledict(self, method):
         d = util.immutabledict({1: 2, 3: 4})
+        d2 = method(d, util.immutabledict({3: 5, 7: 12}))
+
+        eq_(d2, {1: 2, 3: 5, 7: 12})
+        assert isinstance(d2, util.immutabledict)
+        d2 = method(
+            d,
+            util.immutabledict({3: 5, 7: 12}),
+            util.immutabledict({7: 6, 11: 12}),
+        )
 
-        d2 = d.union({3: 5, 7: 12})
+        eq_(d2, {1: 2, 3: 5, 7: 6, 11: 12})
         assert isinstance(d2, util.immutabledict)
 
+        e = util.immutabledict()
+        d2 = method(
+            e,
+            util.immutabledict({3: 5, 7: 12}),
+            util.immutabledict({7: 6, 11: 12}),
+        )
+
+        eq_(d2, {3: 5, 7: 6, 11: 12})
+        assert isinstance(d2, util.immutabledict)
+
+    @methods
+    def test_with_tuples(self, method):
+        # this is not really supported, but it's useful to test the non-dict
+        # case
+        d = util.immutabledict({1: 2, 3: 4})
+
+        d2 = method(d, [(3, 5), (7, 12)], [(9, 18), (15, 25)])
+
         eq_(d, {1: 2, 3: 4})
-        eq_(d2, {1: 2, 3: 5, 7: 12})
+        eq_(d2, {1: 2, 3: 5, 7: 12, 9: 18, 15: 25})
 
     def _dont_test_union_kw(self):
         d = util.immutabledict({"a": "b", "c": "d"})
@@ -414,14 +515,6 @@ class ImmutableDictTest(fixtures.TestBase):
         eq_(d, {"a": "b", "c": "d"})
         eq_(d2, {"a": "b", "c": "d", "e": "f", "g": "h"})
 
-    def test_union_tuples(self):
-        d = util.immutabledict({1: 2, 3: 4})
-
-        d2 = d.union([(3, 5), (7, 12)])
-
-        eq_(d, {1: 2, 3: 4})
-        eq_(d2, {1: 2, 3: 5, 7: 12})
-
     def test_keys(self):
         d = util.immutabledict({1: 2, 3: 4})
 
@@ -454,6 +547,10 @@ class ImmutableDictTest(fixtures.TestBase):
         ne_(d, d4)
         eq_(d3, d4)
 
+    def test_copy(self):
+        d = util.immutabledict({1: 2, 3: 4})
+        is_(d.copy(), d)
+
     def test_serialize(self):
         d = util.immutabledict({1: 2, 3: 4})
         for loads, dumps in picklers():
index 7643bd14ef72c64f1d192752dd62dc07ea441764..0b66e1f4d535241e6420a8be0302fac554d7369e 100644 (file)
@@ -30,6 +30,7 @@ class ImmutableDict(Case):
     def init_objects(self):
         self.small = {"a": 5, "b": 4}
         self.large = {f"k{i}": f"v{i}" for i in range(50)}
+        self.empty = self.impl()
         self.d1 = self.impl({"x": 5, "y": 4})
         self.d2 = self.impl({f"key{i}": f"value{i}" for i in range(50)})
 
@@ -43,6 +44,10 @@ class ImmutableDict(Case):
     def init_empty(self):
         self.impl()
 
+    @test_case
+    def init_kw(self):
+        self.impl(a=1, b=2)
+
     @test_case
     def init(self):
         self.impl(self.small)
@@ -69,6 +74,12 @@ class ImmutableDict(Case):
     def union_large(self):
         self.d2.union(self.large)
 
+    @test_case
+    def union_imm(self):
+        self.empty.union(self.d1)
+        self.d1.union(self.d2)
+        self.d1.union(self.empty)
+
     @test_case
     def merge_with(self):
         self.d1.merge_with(self.small)
@@ -78,6 +89,21 @@ class ImmutableDict(Case):
     def merge_with_large(self):
         self.d2.merge_with(self.large)
 
+    @test_case
+    def merge_with_imm(self):
+        self.d1.merge_with(self.d2)
+        self.empty.merge_with(self.d1)
+        self.empty.merge_with(self.d1, self.d2)
+
+    @test_case
+    def merge_with_only_one(self):
+        self.d1.merge_with(self.empty, None, self.empty)
+        self.empty.merge_with(self.empty, self.d1, self.empty)
+
+    @test_case
+    def merge_with_many(self):
+        self.d1.merge_with(self.d2, self.small, None, self.small, self.large)
+
     @test_case
     def get(self):
         self.d1.get("x")
index 4a2254196f5becd5adaaadc7753abc00c5876f10..a650f68d6ef2c8d853ae57dc652a04efd5427d42 100644 (file)
@@ -459,7 +459,7 @@ test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_connection_execute
 test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_connection_execute x86_64_linux_cpython_3.13_postgresql_psycopg2_dbapiunicode_cextensions 50
 test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_connection_execute x86_64_linux_cpython_3.13_postgresql_psycopg2_dbapiunicode_nocextensions 52
 test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_connection_execute x86_64_linux_cpython_3.13_sqlite_pysqlite_dbapiunicode_cextensions 50
-test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_connection_execute x86_64_linux_cpython_3.13_sqlite_pysqlite_dbapiunicode_nocextensions 52
+test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_connection_execute x86_64_linux_cpython_3.13_sqlite_pysqlite_dbapiunicode_nocextensions 56
 test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_connection_execute x86_64_linux_cpython_3.14_mariadb_mysqldb_dbapiunicode_cextensions 50
 test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_connection_execute x86_64_linux_cpython_3.14_mariadb_mysqldb_dbapiunicode_nocextensions 52
 test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_connection_execute x86_64_linux_cpython_3.14_mssql_pyodbc_dbapiunicode_cextensions 50
@@ -469,7 +469,7 @@ test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_connection_execute
 test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_connection_execute x86_64_linux_cpython_3.14_postgresql_psycopg2_dbapiunicode_cextensions 50
 test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_connection_execute x86_64_linux_cpython_3.14_postgresql_psycopg2_dbapiunicode_nocextensions 52
 test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_connection_execute x86_64_linux_cpython_3.14_sqlite_pysqlite_dbapiunicode_cextensions 50
-test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_connection_execute x86_64_linux_cpython_3.14_sqlite_pysqlite_dbapiunicode_nocextensions 52
+test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_connection_execute x86_64_linux_cpython_3.14_sqlite_pysqlite_dbapiunicode_nocextensions 56
 
 # TEST: test.aaa_profiling.test_resultset.ExecutionTest.test_minimal_engine_execute