]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Update type annotations for loading options
authorJanek Nouvertné <25355197+provinzkraut@users.noreply.github.com>
Wed, 16 Aug 2023 15:01:39 +0000 (11:01 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 18 Sep 2023 17:31:40 +0000 (13:31 -0400)
Update type annotations for ORM loading options, restricting them to accept
only `"*"` instead of any string for string arguments.  Pull request
courtesy Janek Nouvertné.

Fixes: #10131
Closes: #10133
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/10133
Pull-request-sha: 08793ed5bfbffbc8688a2573f021e834fc7de367

Change-Id: I33bb93d36cd1eb9d8c7390ed0e94a784e0b8af46

doc/build/changelog/unreleased_20/10131.rst [new file with mode: 0644]
lib/sqlalchemy/orm/path_registry.py
lib/sqlalchemy/orm/strategy_options.py
test/typing/plain_files/orm/orm_querying.py

diff --git a/doc/build/changelog/unreleased_20/10131.rst b/doc/build/changelog/unreleased_20/10131.rst
new file mode 100644 (file)
index 0000000..b82b721
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: typing, bug
+    :tickets: 10131
+
+    Update type annotations for ORM loading options, restricting them to accept
+    only `"*"` instead of any string for string arguments.  Pull request
+    courtesy Janek Nouvertné.
index 41ca328e1cf9d49e0ab90cbcc5e78146ec56bad0..58a064afab8848cf7498c41d0a6c30489067cce1 100644 (file)
@@ -57,9 +57,9 @@ else:
 
 
 _SerializedPath = List[Any]
-
+_StrPathToken = str
 _PathElementType = Union[
-    str, "_InternalEntityType[Any]", "MapperProperty[Any]"
+    _StrPathToken, "_InternalEntityType[Any]", "MapperProperty[Any]"
 ]
 
 # the representation is in fact
@@ -180,7 +180,7 @@ class PathRegistry(HasCacheKey):
         return id(self)
 
     @overload
-    def __getitem__(self, entity: str) -> TokenRegistry:
+    def __getitem__(self, entity: _StrPathToken) -> TokenRegistry:
         ...
 
     @overload
@@ -204,7 +204,11 @@ class PathRegistry(HasCacheKey):
     def __getitem__(
         self,
         entity: Union[
-            str, int, slice, _InternalEntityType[Any], MapperProperty[Any]
+            _StrPathToken,
+            int,
+            slice,
+            _InternalEntityType[Any],
+            MapperProperty[Any],
         ],
     ) -> Union[
         TokenRegistry,
@@ -355,7 +359,7 @@ class CreatesToken(PathRegistry):
     is_aliased_class: bool
     is_root: bool
 
-    def token(self, token: str) -> TokenRegistry:
+    def token(self, token: _StrPathToken) -> TokenRegistry:
         if token.endswith(f":{_WILDCARD_TOKEN}"):
             return TokenRegistry(self, token)
         elif token.endswith(f":{_DEFAULT_TOKEN}"):
@@ -385,7 +389,7 @@ class RootRegistry(CreatesToken):
     ) -> Union[TokenRegistry, AbstractEntityRegistry]:
         if entity in PathToken._intern:
             if TYPE_CHECKING:
-                assert isinstance(entity, str)
+                assert isinstance(entity, _StrPathToken)
             return TokenRegistry(self, PathToken._intern[entity])
         else:
             try:
@@ -433,10 +437,10 @@ class TokenRegistry(PathRegistry):
 
     inherit_cache = True
 
-    token: str
+    token: _StrPathToken
     parent: CreatesToken
 
-    def __init__(self, parent: CreatesToken, token: str):
+    def __init__(self, parent: CreatesToken, token: _StrPathToken):
         token = PathToken.intern(token)
 
         self.token = token
index d59fbb7693f534f1c688dcb9657926b084c4e517..fd1dbe122e3a5708f65b403ca5abdcc6cf79891a 100644 (file)
@@ -34,6 +34,7 @@ from .attributes import QueryableAttribute
 from .base import InspectionAttr
 from .interfaces import LoaderOption
 from .path_registry import _DEFAULT_TOKEN
+from .path_registry import _StrPathToken
 from .path_registry import _WILDCARD_TOKEN
 from .path_registry import AbstractEntityRegistry
 from .path_registry import path_is_property
@@ -77,7 +78,7 @@ if typing.TYPE_CHECKING:
     from ..sql.cache_key import CacheKey
 
 
-_AttrType = Union[str, "QueryableAttribute[Any]"]
+_AttrType = Union[Literal["*"], "QueryableAttribute[Any]"]
 
 _WildcardKeyType = Literal["relationship", "column"]
 _StrategySpec = Dict[str, Any]
@@ -1668,7 +1669,7 @@ class _LoadElement(
     def create(
         cls,
         path: PathRegistry,
-        attr: Optional[_AttrType],
+        attr: Union[_AttrType, _StrPathToken, None],
         strategy: Optional[_StrategyKey],
         wildcard_key: Optional[_WildcardKeyType],
         local_opts: Optional[_OptsType],
index 6bde850aaed978730487381b3b74dbf151bb6dea..fa59baad43a9185d3db3aea550834f0ac3801d79 100644 (file)
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 from sqlalchemy import ForeignKey
+from sqlalchemy import orm
 from sqlalchemy import select
 from sqlalchemy.orm import aliased
 from sqlalchemy.orm import DeclarativeBase
@@ -27,6 +28,7 @@ class B(Base):
     id: Mapped[int] = mapped_column(primary_key=True)
     a_id: Mapped[int] = mapped_column(ForeignKey("a.id"))
     data: Mapped[str]
+    a: Mapped[A] = relationship()
 
 
 def test_9669_and() -> None:
@@ -36,3 +38,89 @@ def test_9669_and() -> None:
 def test_9669_of_type() -> None:
     ba = aliased(B)
     select(A).options(selectinload(A.bs.of_type(ba)))
+
+
+def load_options_ok() -> None:
+    select(B).options(
+        orm.contains_eager("*").contains_eager(A.bs),
+        orm.load_only("*").load_only(A.bs),
+        orm.joinedload("*").joinedload(A.bs),
+        orm.subqueryload("*").subqueryload(A.bs),
+        orm.selectinload("*").selectinload(A.bs),
+        orm.lazyload("*").lazyload(A.bs),
+        orm.immediateload("*").immediateload(A.bs),
+        orm.noload("*").noload(A.bs),
+        orm.raiseload("*").raiseload(A.bs),
+        orm.defaultload("*").defaultload(A.bs),
+        orm.defer("*").defer(A.bs),
+        orm.undefer("*").undefer(A.bs),
+    )
+    select(B).options(
+        orm.contains_eager(B.a).contains_eager("*"),
+        orm.load_only(B.a).load_only("*"),
+        orm.joinedload(B.a).joinedload("*"),
+        orm.subqueryload(B.a).subqueryload("*"),
+        orm.selectinload(B.a).selectinload("*"),
+        orm.lazyload(B.a).lazyload("*"),
+        orm.immediateload(B.a).immediateload("*"),
+        orm.noload(B.a).noload("*"),
+        orm.raiseload(B.a).raiseload("*"),
+        orm.defaultload(B.a).defaultload("*"),
+        orm.defer(B.a).defer("*"),
+        orm.undefer(B.a).undefer("*"),
+    )
+
+
+def load_options_error() -> None:
+    select(B).options(
+        # EXPECTED_MYPY_RE: Argument 1 to .* has incompatible type .*
+        orm.contains_eager("foo"),
+        # EXPECTED_MYPY_RE: Argument 1 to .* has incompatible type .*
+        orm.load_only("foo"),
+        # EXPECTED_MYPY_RE: Argument 1 to .* has incompatible type .*
+        orm.joinedload("foo"),
+        # EXPECTED_MYPY_RE: Argument 1 to .* has incompatible type .*
+        orm.subqueryload("foo"),
+        # EXPECTED_MYPY_RE: Argument 1 to .* has incompatible type .*
+        orm.selectinload("foo"),
+        # EXPECTED_MYPY_RE: Argument 1 to .* has incompatible type .*
+        orm.lazyload("foo"),
+        # EXPECTED_MYPY_RE: Argument 1 to .* has incompatible type .*
+        orm.immediateload("foo"),
+        # EXPECTED_MYPY_RE: Argument 1 to .* has incompatible type .*
+        orm.noload("foo"),
+        # EXPECTED_MYPY_RE: Argument 1 to .* has incompatible type .*
+        orm.raiseload("foo"),
+        # EXPECTED_MYPY_RE: Argument 1 to .* has incompatible type .*
+        orm.defaultload("foo"),
+        # EXPECTED_MYPY_RE: Argument 1 to .* has incompatible type .*
+        orm.defer("foo"),
+        # EXPECTED_MYPY_RE: Argument 1 to .* has incompatible type .*
+        orm.undefer("foo"),
+    )
+    select(B).options(
+        # EXPECTED_MYPY_RE: Argument 1 to .* has incompatible type .*
+        orm.contains_eager(B.a).contains_eager("bar"),
+        # EXPECTED_MYPY_RE: Argument 1 to .* has incompatible type .*
+        orm.load_only(B.a).load_only("bar"),
+        # EXPECTED_MYPY_RE: Argument 1 to .* has incompatible type .*
+        orm.joinedload(B.a).joinedload("bar"),
+        # EXPECTED_MYPY_RE: Argument 1 to .* has incompatible type .*
+        orm.subqueryload(B.a).subqueryload("bar"),
+        # EXPECTED_MYPY_RE: Argument 1 to .* has incompatible type .*
+        orm.selectinload(B.a).selectinload("bar"),
+        # EXPECTED_MYPY_RE: Argument 1 to .* has incompatible type .*
+        orm.lazyload(B.a).lazyload("bar"),
+        # EXPECTED_MYPY_RE: Argument 1 to .* has incompatible type .*
+        orm.immediateload(B.a).immediateload("bar"),
+        # EXPECTED_MYPY_RE: Argument 1 to .* has incompatible type .*
+        orm.noload(B.a).noload("bar"),
+        # EXPECTED_MYPY_RE: Argument 1 to .* has incompatible type .*
+        orm.raiseload(B.a).raiseload("bar"),
+        # EXPECTED_MYPY_RE: Argument 1 to .* has incompatible type .*
+        orm.defaultload(B.a).defaultload("bar"),
+        # EXPECTED_MYPY_RE: Argument 1 to .* has incompatible type .*
+        orm.defer(B.a).defer("bar"),
+        # EXPECTED_MYPY_RE: Argument 1 to .* has incompatible type .*
+        orm.undefer(B.a).undefer("bar"),
+    )