]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
set up Literal for synchronize_session
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 23 Jul 2022 14:07:58 +0000 (10:07 -0400)
committermike bayer <mike_mp@zzzcomputing.com>
Sat, 23 Jul 2022 21:59:03 +0000 (21:59 +0000)
Fixes: #8280
Change-Id: I59bc6cc0483375f79e17952188e0c2cde926502c

lib/sqlalchemy/orm/persistence.py
lib/sqlalchemy/orm/query.py
test/ext/mypy/plain_files/session.py

index 0c035e7cfaf9d360a42341f820dc980ccc2fbc37..c10f4701e031e2009a2ca2fb972eb99d111b8a5e 100644 (file)
@@ -55,6 +55,7 @@ from ..sql.dml import InsertDMLState
 from ..sql.dml import UpdateDMLState
 from ..sql.elements import BooleanClauseList
 from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
+from ..util.typing import Literal
 
 if TYPE_CHECKING:
     from .mapper import Mapper
@@ -65,6 +66,9 @@ if TYPE_CHECKING:
 _O = TypeVar("_O", bound=object)
 
 
+_SynchronizeSessionArgument = Literal[False, "evaluate", "fetch"]
+
+
 def _bulk_insert(
     mapper: Mapper[_O],
     mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
index a29e368b6a0e86a7ce99b81167286f70746eeb91..99131e3e9a7afdb114ebdab180ae2b9e1feae777 100644 (file)
@@ -99,6 +99,7 @@ if TYPE_CHECKING:
     from ._typing import _InternalEntityType
     from .mapper import Mapper
     from .path_registry import PathRegistry
+    from .persistence import _SynchronizeSessionArgument
     from .session import _PKIdentityArgument
     from .session import Session
     from .state import InstanceState
@@ -2969,7 +2970,9 @@ class Query(
             self._legacy_from_self(col).enable_eagerloads(False).scalar()
         )
 
-    def delete(self, synchronize_session: str = "evaluate") -> int:
+    def delete(
+        self, synchronize_session: _SynchronizeSessionArgument = "evaluate"
+    ) -> int:
         r"""Perform a DELETE with an arbitrary WHERE clause.
 
         Deletes rows matched by this query from the database.
@@ -3030,7 +3033,7 @@ class Query(
     def update(
         self,
         values: Dict[_DMLColumnArgument, Any],
-        synchronize_session: str = "evaluate",
+        synchronize_session: _SynchronizeSessionArgument = "evaluate",
         update_args: Optional[Dict[Any, Any]] = None,
     ) -> int:
         r"""Perform an UPDATE with an arbitrary WHERE clause.
index 0dfa0a75201eae70de67ccfdc8ccd0bab21da4ae..49f1b44cb4a31a23bd93c09f44f3101d52a2bd82 100644 (file)
@@ -58,4 +58,23 @@ with Session(e) as sess:
     # EXPECTED_TYPE: List[Row[Tuple[int]]]
     reveal_type(rows2)
 
+    # test #8280
+
+    sess.query(User).update(
+        {"name": User.name + " some name"}, synchronize_session="fetch"
+    )
+    sess.query(User).update(
+        {"name": User.name + " some name"}, synchronize_session=False
+    )
+    sess.query(User).update(
+        {"name": User.name + " some name"}, synchronize_session="evaluate"
+    )
+
+    sess.query(User).update(
+        {"name": User.name + " some name"},
+        # EXPECTED_MYPY: Argument "synchronize_session" to "update" of "Query" has incompatible type  # noqa: E501
+        synchronize_session="invalid",
+    )
+    sess.query(User).update({"name": User.name + " some name"})
+
 # more result tests in typed_results.py