]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
add full parameter types for ORM with_for_update
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 10 May 2023 15:08:07 +0000 (11:08 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 10 May 2023 19:24:11 +0000 (15:24 -0400)
Fixed typing for the :paramref:`_orm.Session.get.with_for_update` parameter
of :meth:`_orm.Session.get` and :meth:`_orm.Session.refresh` (as well as
corresponding methods on :class:`_asyncio.AsyncSession`) to accept boolean
``True`` and all other argument forms accepted by the parameter at runtime.

Fixes: #9762
Change-Id: Ied4d37a269906b3d9be5ab7d31a2fa863360cced

doc/build/changelog/unreleased_20/9762.rst [new file with mode: 0644]
lib/sqlalchemy/ext/asyncio/scoping.py
lib/sqlalchemy/ext/asyncio/session.py
lib/sqlalchemy/orm/scoping.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/sql/selectable.py
test/ext/mypy/plain_files/session.py

diff --git a/doc/build/changelog/unreleased_20/9762.rst b/doc/build/changelog/unreleased_20/9762.rst
new file mode 100644 (file)
index 0000000..9906bfb
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, typing
+    :tickets: 9762
+
+    Fixed typing for the :paramref:`_orm.Session.get.with_for_update` parameter
+    of :meth:`_orm.Session.get` and :meth:`_orm.Session.refresh` (as well as
+    corresponding methods on :class:`_asyncio.AsyncSession`) to accept boolean
+    ``True`` and all other argument forms accepted by the parameter at runtime.
index 52eeb08281ec9838d60be7dcb01b4ec0655e2c0d..49d8b3af937643866a0599a3cf0025052dd755d8 100644 (file)
@@ -55,7 +55,7 @@ if TYPE_CHECKING:
     from ...orm.session import _SessionBind
     from ...sql.base import Executable
     from ...sql.elements import ClauseElement
-    from ...sql.selectable import ForUpdateArg
+    from ...sql.selectable import ForUpdateParameter
     from ...sql.selectable import TypedReturnsRows
 
 _T = TypeVar("_T", bound=Any)
@@ -217,7 +217,7 @@ class async_scoped_session(Generic[_AS]):
         *,
         options: Optional[Sequence[ORMOption]] = None,
         populate_existing: bool = False,
-        with_for_update: Optional[ForUpdateArg] = None,
+        with_for_update: ForUpdateParameter = None,
         identity_token: Optional[Any] = None,
         execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
     ) -> Optional[_O]:
@@ -934,7 +934,7 @@ class async_scoped_session(Generic[_AS]):
         self,
         instance: object,
         attribute_names: Optional[Iterable[str]] = None,
-        with_for_update: Optional[ForUpdateArg] = None,
+        with_for_update: ForUpdateParameter = None,
     ) -> None:
         r"""Expire and refresh the attributes on the given instance.
 
index 00fee9716182254614d3fca927327d04ad35227b..61500330256f1b487eabc35d8fa9e4cc6b615d1d 100644 (file)
@@ -62,7 +62,7 @@ if TYPE_CHECKING:
     from ...sql._typing import _InfoType
     from ...sql.base import Executable
     from ...sql.elements import ClauseElement
-    from ...sql.selectable import ForUpdateArg
+    from ...sql.selectable import ForUpdateParameter
     from ...sql.selectable import TypedReturnsRows
 
 _AsyncSessionBind = Union["AsyncEngine", "AsyncConnection"]
@@ -301,7 +301,7 @@ class AsyncSession(ReversibleProxy[Session]):
         self,
         instance: object,
         attribute_names: Optional[Iterable[str]] = None,
-        with_for_update: Optional[ForUpdateArg] = None,
+        with_for_update: ForUpdateParameter = None,
     ) -> None:
         """Expire and refresh the attributes on the given instance.
 
@@ -566,7 +566,7 @@ class AsyncSession(ReversibleProxy[Session]):
         *,
         options: Optional[Sequence[ORMOption]] = None,
         populate_existing: bool = False,
-        with_for_update: Optional[ForUpdateArg] = None,
+        with_for_update: ForUpdateParameter = None,
         identity_token: Optional[Any] = None,
         execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
     ) -> Optional[_O]:
index 19217ec32e7cdb4c8f6d0b75c39b1254fe3749c7..83ce6e44e66f0ad2d5dd229e37c76acffee3617e 100644 (file)
@@ -70,7 +70,7 @@ if TYPE_CHECKING:
     from ..sql.base import Executable
     from ..sql.elements import ClauseElement
     from ..sql.roles import TypedColumnsClauseRole
-    from ..sql.selectable import ForUpdateArg
+    from ..sql.selectable import ForUpdateParameter
     from ..sql.selectable import TypedReturnsRows
 
 _T = TypeVar("_T", bound=Any)
@@ -889,7 +889,7 @@ class scoped_session(Generic[_S]):
         *,
         options: Optional[Sequence[ORMOption]] = None,
         populate_existing: bool = False,
-        with_for_update: Optional[ForUpdateArg] = None,
+        with_for_update: ForUpdateParameter = None,
         identity_token: Optional[Any] = None,
         execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
         bind_arguments: Optional[_BindArguments] = None,
@@ -1592,7 +1592,7 @@ class scoped_session(Generic[_S]):
         self,
         instance: object,
         attribute_names: Optional[Iterable[str]] = None,
-        with_for_update: Optional[ForUpdateArg] = None,
+        with_for_update: ForUpdateParameter = None,
     ) -> None:
         r"""Expire and refresh attributes on the given instance.
 
index 792b59e817f728764f8cd95bbeaa4cc9e60b671c..0ce53bcab603667033961c343e456ea710834409 100644 (file)
@@ -126,6 +126,7 @@ if typing.TYPE_CHECKING:
     from ..sql.base import ExecutableOption
     from ..sql.elements import ClauseElement
     from ..sql.roles import TypedColumnsClauseRole
+    from ..sql.selectable import ForUpdateParameter
     from ..sql.selectable import TypedReturnsRows
 
 _T = TypeVar("_T", bound=Any)
@@ -2911,7 +2912,7 @@ class Session(_SessionClassMethods, EventTarget):
         self,
         instance: object,
         attribute_names: Optional[Iterable[str]] = None,
-        with_for_update: Optional[ForUpdateArg] = None,
+        with_for_update: ForUpdateParameter = None,
     ) -> None:
         """Expire and refresh attributes on the given instance.
 
@@ -3432,7 +3433,7 @@ class Session(_SessionClassMethods, EventTarget):
         *,
         options: Optional[Sequence[ORMOption]] = None,
         populate_existing: bool = False,
-        with_for_update: Optional[ForUpdateArg] = None,
+        with_for_update: ForUpdateParameter = None,
         identity_token: Optional[Any] = None,
         execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
         bind_arguments: Optional[_BindArguments] = None,
@@ -3559,7 +3560,7 @@ class Session(_SessionClassMethods, EventTarget):
         *,
         options: Optional[Sequence[ExecutableOption]] = None,
         populate_existing: bool = False,
-        with_for_update: Optional[ForUpdateArg] = None,
+        with_for_update: ForUpdateParameter = None,
         identity_token: Optional[Any] = None,
         execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
         bind_arguments: Optional[_BindArguments] = None,
index 8a371951efabeefae63cac08f8143d951cb34a6b..19d4641808ba0e35e812322bb9238838881136cf 100644 (file)
@@ -3067,6 +3067,9 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause):
         return [self]
 
 
+ForUpdateParameter = Union["ForUpdateArg", None, bool, Dict[str, Any]]
+
+
 class ForUpdateArg(ClauseElement):
     _traverse_internals: _TraverseInternalsType = [
         ("of", InternalTraversal.dp_clauseelement_list),
@@ -3082,7 +3085,7 @@ class ForUpdateArg(ClauseElement):
 
     @classmethod
     def _from_argument(
-        cls, with_for_update: Union[ForUpdateArg, None, bool, Dict[str, Any]]
+        cls, with_for_update: ForUpdateParameter
     ) -> Optional[ForUpdateArg]:
         if isinstance(with_for_update, ForUpdateArg):
             return with_for_update
index 9106b901690944ec81b80dc7771c110a3462c4a2..dfebdd5a9ac34ecd42ee5aab4d7389e7ee62df07 100644 (file)
@@ -1,14 +1,20 @@
 from __future__ import annotations
 
+import asyncio
 from typing import List
 
 from sqlalchemy import create_engine
 from sqlalchemy import ForeignKey
+from sqlalchemy.ext.asyncio import async_scoped_session
+from sqlalchemy.ext.asyncio import async_sessionmaker
+from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.orm import DeclarativeBase
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import relationship
+from sqlalchemy.orm import scoped_session
 from sqlalchemy.orm import Session
+from sqlalchemy.orm import sessionmaker
 
 
 class Base(DeclarativeBase):
@@ -94,3 +100,41 @@ with Session(e) as sess:
     ).offset(User.id)
 
 # more result tests in typed_results.py
+
+
+def test_with_for_update() -> None:
+    """test #9762"""
+    sess = Session()
+    ss = scoped_session(sessionmaker())
+
+    sess.get(User, 1)
+    sess.get(User, 1, with_for_update=True)
+    ss.get(User, 1)
+    ss.get(User, 1, with_for_update=True)
+
+    u1 = User()
+    sess.refresh(u1)
+    sess.refresh(u1, with_for_update=True)
+    ss.refresh(u1)
+    ss.refresh(u1, with_for_update=True)
+
+
+async def test_with_for_update_async() -> None:
+    """test #9762"""
+    sess = AsyncSession()
+    ss = async_scoped_session(
+        async_sessionmaker(), scopefunc=asyncio.current_task
+    )
+
+    await sess.get(User, 1)
+    await sess.get(User, 1, with_for_update=True)
+
+    await ss.get(User, 1)
+    await ss.get(User, 1, with_for_update=True)
+
+    u1 = User()
+    await sess.refresh(u1)
+    await sess.refresh(u1, with_for_update=True)
+
+    await ss.refresh(u1)
+    await ss.refresh(u1, with_for_update=True)