]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
new features for pep 593 Annotated
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 14 Jun 2022 21:05:44 +0000 (17:05 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 15 Jun 2022 13:04:51 +0000 (09:04 -0400)
* extract the inner type from Annotated when the outer type
  isn't present in the type map, to allow for arbitrary Annotated
* allow _IntrospectsAnnotations objects to be directly present
  in an Annotated and resolve the mapper property from that.
  Currently implemented for mapped_column(), with message for
  others.  Can work for composite() and likely some
  relationship() as well at some point

References: https://twitter.com/zzzeek/status/1536693554621341697 and
replies

Change-Id: I04657050a8785f194bf8f63291faf3475af88781

lib/sqlalchemy/orm/decl_base.py
lib/sqlalchemy/orm/descriptor_props.py
lib/sqlalchemy/orm/interfaces.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/sql/annotation.py
lib/sqlalchemy/util/typing.py
test/orm/declarative/test_typed_mapping.py

index ce044d7e0e975d365647f987f24c9769349bcc1b..1366bedf2499b5663aea9d4677ee6257061d3484 100644 (file)
@@ -65,6 +65,7 @@ from ..util import topological
 from ..util.typing import _AnnotationScanType
 from ..util.typing import Protocol
 from ..util.typing import TypedDict
+from ..util.typing import typing_get_args
 
 if TYPE_CHECKING:
     from ._typing import _ClassDict
@@ -885,12 +886,16 @@ class _ClassScanMapperConfig(_MapperConfig):
                             obj,
                         )
                     elif _is_mapped_annotation(annotation, cls):
-                        self._collect_annotation(
+                        generated_obj = self._collect_annotation(
                             name, annotation, is_dataclass_field, True, obj
                         )
                         if obj is None:
                             if not fixed_table:
-                                collected_attributes[name] = MappedColumn()
+                                collected_attributes[name] = (
+                                    generated_obj
+                                    if generated_obj is not None
+                                    else MappedColumn()
+                                )
                         else:
                             collected_attributes[name] = obj
                     else:
@@ -920,7 +925,7 @@ class _ClassScanMapperConfig(_MapperConfig):
                         name, annotation, True, False, obj
                     )
                 else:
-                    self._collect_annotation(
+                    generated_obj = self._collect_annotation(
                         name, annotation, False, None, obj
                     )
                     if (
@@ -928,7 +933,11 @@ class _ClassScanMapperConfig(_MapperConfig):
                         and not fixed_table
                         and _is_mapped_annotation(annotation, cls)
                     ):
-                        collected_attributes[name] = MappedColumn()
+                        collected_attributes[name] = (
+                            generated_obj
+                            if generated_obj is not None
+                            else MappedColumn()
+                        )
                     elif name in clsdict_view:
                         collected_attributes[name] = obj
                     # else if the name is not in the cls.__dict__,
@@ -1022,9 +1031,9 @@ class _ClassScanMapperConfig(_MapperConfig):
         is_dataclass: bool,
         expect_mapped: Optional[bool],
         attr_value: Any,
-    ) -> None:
+    ) -> Any:
         if raw_annotation is None:
-            return
+            return attr_value
 
         is_dataclass = self.is_dataclass_prior_to_mapping
         allow_unmapped = self.allow_unmapped_annotations
@@ -1053,15 +1062,23 @@ class _ClassScanMapperConfig(_MapperConfig):
             expect_mapped=expect_mapped
             and not is_dataclass,  # self.allow_dataclass_fields,
         )
+
         if extracted_mapped_annotation is None:
             # ClassVar can come out here
-            return
+            return attr_value
+        elif attr_value is None:
+            for elem in typing_get_args(extracted_mapped_annotation):
+                # look in Annotated[...] for an ORM construct,
+                # such as Annotated[int, mapped_column(primary_key=True)]
+                if isinstance(elem, _IntrospectsAnnotations):
+                    attr_value = elem.found_in_pep593_annotated()
 
         self.collected_annotations[name] = (
             raw_annotation,
             extracted_mapped_annotation,
             is_dataclass,
         )
+        return attr_value
 
     def _warn_for_decl_attributes(
         self, cls: Type[Any], key: str, c: Any
index d67319700abe5e37e265097fc7cbbfd001488a76..6d308e141ced95c9dd695083341a37df8ce32226 100644 (file)
@@ -49,6 +49,8 @@ from .. import sql
 from .. import util
 from ..sql import expression
 from ..sql.elements import BindParameter
+from ..util.typing import is_pep593
+from ..util.typing import typing_get_args
 
 if typing.TYPE_CHECKING:
     from ._typing import _InstanceDict
@@ -342,6 +344,10 @@ class Composite(
         ):
             self._raise_for_required(key, cls)
         argument = extracted_mapped_annotation
+
+        if is_pep593(argument):
+            argument = typing_get_args(argument)[0]
+
         if argument and self.composite_class is None:
             if isinstance(argument, str) or hasattr(
                 argument, "__forward_arg__"
index e0034061d4e25e33e580afcb18899566e666c01b..a9ae4436f17bba5cfde29ee2b53cea7de76de137 100644 (file)
@@ -140,6 +140,15 @@ class ORMColumnDescription(TypedDict):
 class _IntrospectsAnnotations:
     __slots__ = ()
 
+    def found_in_pep593_annotated(self) -> Any:
+        """return a copy of this object to use in declarative when the
+        object is found inside of an Annotated object."""
+
+        raise NotImplementedError(
+            f"Use of the {self.__class__} construct inside of an "
+            f"Annotated object is not yet supported."
+        )
+
     def declarative_scan(
         self,
         registry: RegistryType,
index 064422293628c447adceb4cea691aab913e4f35d..d1faff1d964209fc6ab443f376d2e5d7bb084274 100644 (file)
@@ -52,8 +52,10 @@ from ..sql.schema import SchemaConst
 from ..util.typing import de_optionalize_union_types
 from ..util.typing import de_stringify_annotation
 from ..util.typing import is_fwd_ref
+from ..util.typing import is_pep593
 from ..util.typing import NoneType
 from ..util.typing import Self
+from ..util.typing import typing_get_args
 
 if TYPE_CHECKING:
     from ._typing import _IdentityKeyType
@@ -569,6 +571,9 @@ class MappedColumn(
         col = self.__clause_element__()
         return op(col._bind_param(op, other), col, **kwargs)  # type: ignore[return-value]  # noqa: E501
 
+    def found_in_pep593_annotated(self) -> Any:
+        return self._copy()
+
     def declarative_scan(
         self,
         registry: _RegistryType,
@@ -632,12 +637,19 @@ class MappedColumn(
             if is_fwd_ref(our_type):
                 our_type = de_stringify_annotation(cls, our_type)
 
-            if registry.type_annotation_map:
-                new_sqltype = registry.type_annotation_map.get(our_type)
-            if new_sqltype is None:
-                new_sqltype = sqltypes._type_map_get(our_type)  # type: ignore
-
-            if new_sqltype is None:
+            if is_pep593(our_type):
+                checks = (our_type,) + typing_get_args(our_type)
+            else:
+                checks = (our_type,)
+
+            for check_type in checks:
+                if registry.type_annotation_map:
+                    new_sqltype = registry.type_annotation_map.get(check_type)
+                if new_sqltype is None:
+                    new_sqltype = sqltypes._type_map_get(check_type)  # type: ignore  # noqa: E501
+                if new_sqltype is not None:
+                    break
+            else:
                 raise sa_exc.ArgumentError(
                     f"Could not locate SQLAlchemy Core "
                     f"type for Python type: {our_type}"
index 56d88bc2fb8881fba9c127b8818ff813bfd05a7b..61849d0539935f41ff33446bbeb444adf9e1b597 100644 (file)
@@ -9,6 +9,10 @@
 copies of SQL constructs which contain context-specific markers and
 associations.
 
+Note that the :class:`.Annotated` concept as implemented in this module is not
+related in any way to the pep-593 concept of "Annotated".
+
+
 """
 
 from __future__ import annotations
index 454de100bd2d02959ac6c768da3772df6045a66a..eb625e06e290d1fdb58a644f68e8cada63e67708 100644 (file)
@@ -61,9 +61,18 @@ else:
 
 if typing.TYPE_CHECKING or compat.py310:
     from typing import Annotated as Annotated
+    from typing import get_args as get_args
+    from typing import get_origin as get_origin
 else:
     from typing_extensions import Annotated as Annotated  # noqa: F401
 
+    # these are in py38 but don't work with Annotated correctly, so
+    # for 3.8 / 3.9 we use the typing extensions version
+    from typing_extensions import get_args as get_args  # noqa: F401
+    from typing_extensions import (
+        get_origin as get_origin,  # noqa: F401,
+    )
+
 if typing.TYPE_CHECKING or compat.py38:
     from typing import Literal as Literal
     from typing import Protocol as Protocol
@@ -75,6 +84,9 @@ else:
     from typing_extensions import TypedDict as TypedDict  # noqa: F401
     from typing_extensions import Final as Final  # noqa: F401
 
+typing_get_args = get_args
+typing_get_origin = get_origin
+
 # copied from TypeShed, required in order to implement
 # MutableMapping.update()
 
@@ -140,6 +152,10 @@ def de_stringify_annotation(
     return annotation  # type: ignore
 
 
+def is_pep593(type_: Optional[_AnnotationScanType]) -> bool:
+    return type_ is not None and typing_get_origin(type_) is Annotated
+
+
 def is_fwd_ref(type_: _AnnotationScanType) -> bool:
     return isinstance(type_, ForwardRef)
 
index afaa099b21326d91b1a6428ef8d613033a20bd92..ae2773d1c7f9bcda26fafd60d1769da107c37d08 100644 (file)
@@ -396,7 +396,9 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             id: Mapped["int"] = mapped_column(primary_key=True)
             data_one: Mapped["str"]
 
-    def test_annotated_types_as_keys(self, decl_base: Type[DeclarativeBase]):
+    def test_pep593_types_as_typemap_keys(
+        self, decl_base: Type[DeclarativeBase]
+    ):
         """neat!!!"""
 
         str50 = Annotated[str, 50]
@@ -425,6 +427,138 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         is_true(MyClass.__table__.c.data_two.nullable)
         eq_(MyClass.__table__.c.data_three.type.length, 50)
 
+    def test_extract_base_type_from_pep593(
+        self, decl_base: Type[DeclarativeBase]
+    ):
+        """base type is extracted from an Annotated structure if not otherwise
+        in the type lookup dictionary"""
+
+        class MyClass(decl_base):
+            __tablename__ = "my_table"
+
+            id: Mapped[Annotated[Annotated[int, "q"], "t"]] = mapped_column(
+                primary_key=True
+            )
+
+        is_(MyClass.__table__.c.id.type._type_affinity, Integer)
+
+    def test_extract_sqla_from_pep593_not_yet(
+        self, decl_base: Type[DeclarativeBase]
+    ):
+        """https://twitter.com/zzzeek/status/1536693554621341697"""
+
+        class SomeRelated(decl_base):
+            __tablename__: ClassVar[Optional[str]] = "some_related"
+            id: Mapped["int"] = mapped_column(primary_key=True)
+
+        with expect_raises_message(
+            NotImplementedError,
+            r"Use of the \<class 'sqlalchemy.orm."
+            r"relationships.Relationship'\> construct inside of an Annotated "
+            r"object is not yet supported.",
+        ):
+
+            class MyClass(decl_base):
+                __tablename__ = "my_table"
+
+                id: Mapped["int"] = mapped_column(primary_key=True)
+                data_one: Mapped[Annotated["SomeRelated", relationship()]]
+
+    def test_extract_sqla_from_pep593_plain(
+        self, decl_base: Type[DeclarativeBase]
+    ):
+        """extraction of mapped_column() from the Annotated type
+
+        https://twitter.com/zzzeek/status/1536693554621341697"""
+
+        intpk = Annotated[int, mapped_column(primary_key=True)]
+
+        strnone = Annotated[str, mapped_column()]  # str -> NOT NULL
+        str30nullable = Annotated[
+            str, mapped_column(String(30), nullable=True)  # nullable -> NULL
+        ]
+        opt_strnone = Optional[strnone]  # Optional[str] -> NULL
+        opt_str30 = Optional[str30nullable]  # nullable -> NULL
+
+        class MyClass(decl_base):
+            __tablename__ = "my_table"
+
+            id: Mapped[intpk]
+
+            data_one: Mapped[strnone]
+            data_two: Mapped[str30nullable]
+            data_three: Mapped[opt_strnone]
+            data_four: Mapped[opt_str30]
+
+        class MyOtherClass(decl_base):
+            __tablename__ = "my_other_table"
+
+            id: Mapped[intpk]
+
+            data_one: Mapped[strnone]
+            data_two: Mapped[str30nullable]
+            data_three: Mapped[opt_strnone]
+            data_four: Mapped[opt_str30]
+
+        for cls in MyClass, MyOtherClass:
+            table = cls.__table__
+            assert table is not None
+
+            is_(table.c.id.primary_key, True)
+            is_(table.c.id.table, table)
+
+            eq_(table.c.data_one.type.length, None)
+            eq_(table.c.data_two.type.length, 30)
+            eq_(table.c.data_three.type.length, None)
+
+            is_false(table.c.data_one.nullable)
+            is_true(table.c.data_two.nullable)
+            is_true(table.c.data_three.nullable)
+            is_true(table.c.data_four.nullable)
+
+    def test_extract_sqla_from_pep593_mixin(
+        self, decl_base: Type[DeclarativeBase]
+    ):
+        """extraction of mapped_column() from the Annotated type
+
+        https://twitter.com/zzzeek/status/1536693554621341697"""
+
+        intpk = Annotated[int, mapped_column(primary_key=True)]
+
+        strnone = Annotated[str, mapped_column()]  # str -> NOT NULL
+        str30nullable = Annotated[
+            str, mapped_column(String(30), nullable=True)  # nullable -> NULL
+        ]
+        opt_strnone = Optional[strnone]  # Optional[str] -> NULL
+        opt_str30 = Optional[str30nullable]  # nullable -> NULL
+
+        class HasPk:
+            id: Mapped[intpk]
+
+            data_one: Mapped[strnone]
+            data_two: Mapped[str30nullable]
+
+        class MyClass(HasPk, decl_base):
+            __tablename__ = "my_table"
+
+            data_three: Mapped[opt_strnone]
+            data_four: Mapped[opt_str30]
+
+        table = MyClass.__table__
+        assert table is not None
+
+        is_(table.c.id.primary_key, True)
+        is_(table.c.id.table, table)
+
+        eq_(table.c.data_one.type.length, None)
+        eq_(table.c.data_two.type.length, 30)
+        eq_(table.c.data_three.type.length, None)
+
+        is_false(table.c.data_one.nullable)
+        is_true(table.c.data_two.nullable)
+        is_true(table.c.data_three.nullable)
+        is_true(table.c.data_four.nullable)
+
     def test_unions(self):
         our_type = Numeric(10, 2)
 
@@ -1088,6 +1222,30 @@ class CompositeTest(fixtures.TestBase, testing.AssertsCompiledSQL):
                     mapped_column(), mapped_column(), mapped_column("zip")
                 )
 
+    def test_extract_from_pep593(self, decl_base):
+        @dataclasses.dataclass
+        class Address:
+            street: str
+            state: str
+            zip_: str
+
+        class User(decl_base):
+            __tablename__ = "user"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            name: Mapped[str] = mapped_column()
+
+            address: Mapped[Annotated[Address, "foo"]] = composite(
+                mapped_column(), mapped_column(), mapped_column("zip")
+            )
+
+        self.assert_compile(
+            select(User),
+            'SELECT "user".id, "user".name, "user".street, '
+            '"user".state, "user".zip FROM "user"',
+            dialect="default",
+        )
+
     def test_cls_not_composite_compliant(self, decl_base):
         class Address:
             def __init__(self, street: int, state: str, zip_: str):