]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
add nullable construct to better type outer joins
authorFederico Caselli <cfederico87@gmail.com>
Sat, 5 Aug 2023 13:07:26 +0000 (15:07 +0200)
committerFederico Caselli <cfederico87@gmail.com>
Tue, 15 Aug 2023 09:10:39 +0000 (11:10 +0200)
Fixes: #10173
Change-Id: Ie203232de744882235f1543ec8c73c7d5fe99e3e

doc/build/changelog/unreleased_20/10173.rst [new file with mode: 0644]
doc/build/core/sqlelement.rst
lib/sqlalchemy/__init__.py
lib/sqlalchemy/sql/__init__.py
lib/sqlalchemy/sql/_typing.py
test/typing/plain_files/sql/typed_results.py

diff --git a/doc/build/changelog/unreleased_20/10173.rst b/doc/build/changelog/unreleased_20/10173.rst
new file mode 100644 (file)
index 0000000..ad1b4ad
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: typing, usecase
+    :tickets: 10173
+
+    Added new typing only utility functions :func:`.Nullable` and 
+    :func:`.NotNullable` to type a column or ORM class as, respectively,
+    nullable or not nullable.
+    These function are no-op at runtime, returning the input unchanged.
index e2174a2c843051126495c67f194c0785f41b40f3..9481bf5d9f5ed50ca799f905de3ca0febba648af 100644 (file)
@@ -223,5 +223,13 @@ The classes here are generated using the constructors listed at
 .. autoclass:: UnaryExpression
    :members:
 
+Column Element Typing Utilities
+-------------------------------
 
+Standalone utility functions imported from the ``sqlalchemy`` namespace
+to improve support by type checkers.
 
+
+.. autofunction:: sqlalchemy.NotNullable
+
+.. autofunction:: sqlalchemy.Nullable
index aab16256dac554067e650daf25e9cac83276bdc8..37886f7b2682561ffc8b9567b377d466ef55be02 100644 (file)
@@ -80,6 +80,8 @@ from .schema import Sequence as Sequence
 from .schema import Table as Table
 from .schema import UniqueConstraint as UniqueConstraint
 from .sql import ColumnExpressionArgument as ColumnExpressionArgument
+from .sql import NotNullable as NotNullable
+from .sql import Nullable as Nullable
 from .sql import SelectLabelStyle as SelectLabelStyle
 from .sql.expression import Alias as Alias
 from .sql.expression import alias as alias
index c4c8c7d27a79378febfe6a258a705e9f82eaecf5..a81509fed745d8f33cb27b3ef970cc55d81b8576 100644 (file)
@@ -8,6 +8,8 @@ from typing import Any
 from typing import TYPE_CHECKING
 
 from ._typing import ColumnExpressionArgument as ColumnExpressionArgument
+from ._typing import NotNullable as NotNullable
+from ._typing import Nullable as Nullable
 from .base import Executable as Executable
 from .compiler import COLLECT_CARTESIAN_PRODUCTS as COLLECT_CARTESIAN_PRODUCTS
 from .compiler import FROM_LINTING as FROM_LINTING
index f83c4b47714a861b42a6a7ddfeb5f70b7f3f30b6..a08a770945df2de522f8baddcff8dc7b15121e88 100644 (file)
@@ -13,6 +13,8 @@ from typing import Callable
 from typing import Dict
 from typing import Mapping
 from typing import NoReturn
+from typing import Optional
+from typing import overload
 from typing import Set
 from typing import Tuple
 from typing import Type
@@ -150,7 +152,6 @@ sets; select(...), insert().returning(...), etc.
 _TypedColumnClauseArgument = Union[
     roles.TypedColumnsClauseRole[_T],
     "SQLCoreOperations[_T]",
-    roles.ExpressionElementRole[_T],
     Type[_T],
 ]
 
@@ -374,3 +375,82 @@ def _no_kw() -> exc.ArgumentError:
 def _unexpected_kw(methname: str, kw: Dict[str, Any]) -> NoReturn:
     k = list(kw)[0]
     raise TypeError(f"{methname} got an unexpected keyword argument '{k}'")
+
+
+@overload
+def Nullable(
+    val: "SQLCoreOperations[_T]",
+) -> "SQLCoreOperations[Optional[_T]]":
+    ...
+
+
+@overload
+def Nullable(
+    val: roles.ExpressionElementRole[_T],
+) -> roles.ExpressionElementRole[Optional[_T]]:
+    ...
+
+
+@overload
+def Nullable(val: Type[_T]) -> Type[Optional[_T]]:
+    ...
+
+
+def Nullable(
+    val: _TypedColumnClauseArgument[_T],
+) -> _TypedColumnClauseArgument[Optional[_T]]:
+    """Types a column or ORM class as nullable.
+
+    This can be used in select and other contexts to express that the value of
+    a column can be null, for example due to an outer join::
+
+        stmt1 = select(A, Nullable(B)).outerjoin(A.bs)
+        stmt2 = select(A.data, Nullable(B.data)).outerjoin(A.bs)
+
+    At runtime this method returns the input unchanged.
+
+    .. versionadded:: 2.0.20
+    """
+    return val  # type: ignore
+
+
+@overload
+def NotNullable(
+    val: "SQLCoreOperations[Optional[_T]]",
+) -> "SQLCoreOperations[_T]":
+    ...
+
+
+@overload
+def NotNullable(
+    val: roles.ExpressionElementRole[Optional[_T]],
+) -> roles.ExpressionElementRole[_T]:
+    ...
+
+
+@overload
+def NotNullable(val: Type[Optional[_T]]) -> Type[_T]:
+    ...
+
+
+@overload
+def NotNullable(val: Optional[Type[_T]]) -> Type[_T]:
+    ...
+
+
+def NotNullable(
+    val: Union[_TypedColumnClauseArgument[Optional[_T]], Optional[Type[_T]]],
+) -> _TypedColumnClauseArgument[_T]:
+    """Types a column or ORM class as not nullable.
+
+    This can be used in select and other contexts to express that the value of
+    a column cannot be null, for example due to a where condition on a
+    nullable column::
+
+        stmt = select(NotNullable(A.value)).where(A.value.is_not(None))
+
+    At runtime this method returns the input unchanged.
+
+    .. versionadded:: 2.0.20
+    """
+    return val  # type: ignore
index 3596099cb86169d0c07a53105b0a2f342c1429ab..c7842a7e7995d0407d46b1b94c73a67b92e6ada2 100644 (file)
@@ -2,6 +2,9 @@ from __future__ import annotations
 
 import asyncio
 from typing import cast
+from typing import Optional
+from typing import Tuple
+from typing import Type
 
 from sqlalchemy import Column
 from sqlalchemy import column
@@ -9,6 +12,9 @@ from sqlalchemy import create_engine
 from sqlalchemy import insert
 from sqlalchemy import Integer
 from sqlalchemy import MetaData
+from sqlalchemy import NotNullable
+from sqlalchemy import Nullable
+from sqlalchemy import Select
 from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import Table
@@ -32,6 +38,7 @@ class User(Base):
 
     id: Mapped[int] = mapped_column(primary_key=True)
     name: Mapped[str]
+    value: Mapped[Optional[str]]
 
 
 t_user = Table(
@@ -665,3 +672,29 @@ async def t_async_session_stream_scalars() -> None:
 
     # EXPECTED_RE_TYPE: typing.Sequence\*?\[builtins.str\*?\]
     reveal_type(data)
+
+
+def test_outerjoin_10173() -> None:
+    class Other(Base):
+        __tablename__ = "other"
+
+        id: Mapped[int] = mapped_column(primary_key=True)
+        name: Mapped[str]
+
+    stmt: Select[Tuple[User, Other]] = select(User, Other).outerjoin(
+        Other, User.id == Other.id
+    )
+    stmt2: Select[Tuple[User, Optional[Other]]] = select(
+        User, Nullable(Other)
+    ).outerjoin(Other, User.id == Other.id)
+    stmt3: Select[Tuple[int, Optional[str]]] = select(
+        User.id, Nullable(Other.name)
+    ).outerjoin(Other, User.id == Other.id)
+
+    def go(W: Optional[Type[Other]]) -> None:
+        stmt4: Select[Tuple[str, Other]] = select(
+            NotNullable(User.value), NotNullable(W)
+        ).where(User.value.is_not(None))
+        print(stmt4)
+
+    print(stmt, stmt2, stmt3)