From: Federico Caselli Date: Sat, 5 Aug 2023 13:07:26 +0000 (+0200) Subject: add nullable construct to better type outer joins X-Git-Tag: rel_2_0_20~3 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=49297c236bd91b25a1a9a48f380603f507d643ea;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git add nullable construct to better type outer joins Fixes: #10173 Change-Id: Ie203232de744882235f1543ec8c73c7d5fe99e3e --- diff --git a/doc/build/changelog/unreleased_20/10173.rst b/doc/build/changelog/unreleased_20/10173.rst new file mode 100644 index 0000000000..ad1b4ade3e --- /dev/null +++ b/doc/build/changelog/unreleased_20/10173.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: typing, usecase + :tickets: 10173 + + Added new typing only utility functions :func:`.Nullable` and + :func:`.NotNullable` to type a column or ORM class as, respectively, + nullable or not nullable. + These function are no-op at runtime, returning the input unchanged. diff --git a/doc/build/core/sqlelement.rst b/doc/build/core/sqlelement.rst index e2174a2c84..9481bf5d9f 100644 --- a/doc/build/core/sqlelement.rst +++ b/doc/build/core/sqlelement.rst @@ -223,5 +223,13 @@ The classes here are generated using the constructors listed at .. autoclass:: UnaryExpression :members: +Column Element Typing Utilities +------------------------------- +Standalone utility functions imported from the ``sqlalchemy`` namespace +to improve support by type checkers. + +.. autofunction:: sqlalchemy.NotNullable + +.. autofunction:: sqlalchemy.Nullable diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py index aab16256da..37886f7b26 100644 --- a/lib/sqlalchemy/__init__.py +++ b/lib/sqlalchemy/__init__.py @@ -80,6 +80,8 @@ from .schema import Sequence as Sequence from .schema import Table as Table from .schema import UniqueConstraint as UniqueConstraint from .sql import ColumnExpressionArgument as ColumnExpressionArgument +from .sql import NotNullable as NotNullable +from .sql import Nullable as Nullable from .sql import SelectLabelStyle as SelectLabelStyle from .sql.expression import Alias as Alias from .sql.expression import alias as alias diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py index c4c8c7d27a..a81509fed7 100644 --- a/lib/sqlalchemy/sql/__init__.py +++ b/lib/sqlalchemy/sql/__init__.py @@ -8,6 +8,8 @@ from typing import Any from typing import TYPE_CHECKING from ._typing import ColumnExpressionArgument as ColumnExpressionArgument +from ._typing import NotNullable as NotNullable +from ._typing import Nullable as Nullable from .base import Executable as Executable from .compiler import COLLECT_CARTESIAN_PRODUCTS as COLLECT_CARTESIAN_PRODUCTS from .compiler import FROM_LINTING as FROM_LINTING diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index f83c4b4771..a08a770945 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -13,6 +13,8 @@ from typing import Callable from typing import Dict from typing import Mapping from typing import NoReturn +from typing import Optional +from typing import overload from typing import Set from typing import Tuple from typing import Type @@ -150,7 +152,6 @@ sets; select(...), insert().returning(...), etc. _TypedColumnClauseArgument = Union[ roles.TypedColumnsClauseRole[_T], "SQLCoreOperations[_T]", - roles.ExpressionElementRole[_T], Type[_T], ] @@ -374,3 +375,82 @@ def _no_kw() -> exc.ArgumentError: def _unexpected_kw(methname: str, kw: Dict[str, Any]) -> NoReturn: k = list(kw)[0] raise TypeError(f"{methname} got an unexpected keyword argument '{k}'") + + +@overload +def Nullable( + val: "SQLCoreOperations[_T]", +) -> "SQLCoreOperations[Optional[_T]]": + ... + + +@overload +def Nullable( + val: roles.ExpressionElementRole[_T], +) -> roles.ExpressionElementRole[Optional[_T]]: + ... + + +@overload +def Nullable(val: Type[_T]) -> Type[Optional[_T]]: + ... + + +def Nullable( + val: _TypedColumnClauseArgument[_T], +) -> _TypedColumnClauseArgument[Optional[_T]]: + """Types a column or ORM class as nullable. + + This can be used in select and other contexts to express that the value of + a column can be null, for example due to an outer join:: + + stmt1 = select(A, Nullable(B)).outerjoin(A.bs) + stmt2 = select(A.data, Nullable(B.data)).outerjoin(A.bs) + + At runtime this method returns the input unchanged. + + .. versionadded:: 2.0.20 + """ + return val # type: ignore + + +@overload +def NotNullable( + val: "SQLCoreOperations[Optional[_T]]", +) -> "SQLCoreOperations[_T]": + ... + + +@overload +def NotNullable( + val: roles.ExpressionElementRole[Optional[_T]], +) -> roles.ExpressionElementRole[_T]: + ... + + +@overload +def NotNullable(val: Type[Optional[_T]]) -> Type[_T]: + ... + + +@overload +def NotNullable(val: Optional[Type[_T]]) -> Type[_T]: + ... + + +def NotNullable( + val: Union[_TypedColumnClauseArgument[Optional[_T]], Optional[Type[_T]]], +) -> _TypedColumnClauseArgument[_T]: + """Types a column or ORM class as not nullable. + + This can be used in select and other contexts to express that the value of + a column cannot be null, for example due to a where condition on a + nullable column:: + + stmt = select(NotNullable(A.value)).where(A.value.is_not(None)) + + At runtime this method returns the input unchanged. + + .. versionadded:: 2.0.20 + """ + return val # type: ignore diff --git a/test/typing/plain_files/sql/typed_results.py b/test/typing/plain_files/sql/typed_results.py index 3596099cb8..c7842a7e79 100644 --- a/test/typing/plain_files/sql/typed_results.py +++ b/test/typing/plain_files/sql/typed_results.py @@ -2,6 +2,9 @@ from __future__ import annotations import asyncio from typing import cast +from typing import Optional +from typing import Tuple +from typing import Type from sqlalchemy import Column from sqlalchemy import column @@ -9,6 +12,9 @@ from sqlalchemy import create_engine from sqlalchemy import insert from sqlalchemy import Integer from sqlalchemy import MetaData +from sqlalchemy import NotNullable +from sqlalchemy import Nullable +from sqlalchemy import Select from sqlalchemy import select from sqlalchemy import String from sqlalchemy import Table @@ -32,6 +38,7 @@ class User(Base): id: Mapped[int] = mapped_column(primary_key=True) name: Mapped[str] + value: Mapped[Optional[str]] t_user = Table( @@ -665,3 +672,29 @@ async def t_async_session_stream_scalars() -> None: # EXPECTED_RE_TYPE: typing.Sequence\*?\[builtins.str\*?\] reveal_type(data) + + +def test_outerjoin_10173() -> None: + class Other(Base): + __tablename__ = "other" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + + stmt: Select[Tuple[User, Other]] = select(User, Other).outerjoin( + Other, User.id == Other.id + ) + stmt2: Select[Tuple[User, Optional[Other]]] = select( + User, Nullable(Other) + ).outerjoin(Other, User.id == Other.id) + stmt3: Select[Tuple[int, Optional[str]]] = select( + User.id, Nullable(Other.name) + ).outerjoin(Other, User.id == Other.id) + + def go(W: Optional[Type[Other]]) -> None: + stmt4: Select[Tuple[str, Other]] = select( + NotNullable(User.value), NotNullable(W) + ).where(User.value.is_not(None)) + print(stmt4) + + print(stmt, stmt2, stmt3)