From 514f2a8b4c6de8c033496543e9aaf2a0a4eb599d Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 14 Jun 2022 17:05:44 -0400 Subject: [PATCH] new features for pep 593 Annotated * extract the inner type from Annotated when the outer type isn't present in the type map, to allow for arbitrary Annotated * allow _IntrospectsAnnotations objects to be directly present in an Annotated and resolve the mapper property from that. Currently implemented for mapped_column(), with message for others. Can work for composite() and likely some relationship() as well at some point References: https://twitter.com/zzzeek/status/1536693554621341697 and replies Change-Id: I04657050a8785f194bf8f63291faf3475af88781 --- lib/sqlalchemy/orm/decl_base.py | 31 +++- lib/sqlalchemy/orm/descriptor_props.py | 6 + lib/sqlalchemy/orm/interfaces.py | 9 ++ lib/sqlalchemy/orm/properties.py | 24 +++- lib/sqlalchemy/sql/annotation.py | 4 + lib/sqlalchemy/util/typing.py | 16 +++ test/orm/declarative/test_typed_mapping.py | 160 ++++++++++++++++++++- 7 files changed, 236 insertions(+), 14 deletions(-) diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index ce044d7e0e..1366bedf24 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -65,6 +65,7 @@ from ..util import topological from ..util.typing import _AnnotationScanType from ..util.typing import Protocol from ..util.typing import TypedDict +from ..util.typing import typing_get_args if TYPE_CHECKING: from ._typing import _ClassDict @@ -885,12 +886,16 @@ class _ClassScanMapperConfig(_MapperConfig): obj, ) elif _is_mapped_annotation(annotation, cls): - self._collect_annotation( + generated_obj = self._collect_annotation( name, annotation, is_dataclass_field, True, obj ) if obj is None: if not fixed_table: - collected_attributes[name] = MappedColumn() + collected_attributes[name] = ( + generated_obj + if generated_obj is not None + else MappedColumn() + ) else: collected_attributes[name] = obj else: @@ -920,7 +925,7 @@ class _ClassScanMapperConfig(_MapperConfig): name, annotation, True, False, obj ) else: - self._collect_annotation( + generated_obj = self._collect_annotation( name, annotation, False, None, obj ) if ( @@ -928,7 +933,11 @@ class _ClassScanMapperConfig(_MapperConfig): and not fixed_table and _is_mapped_annotation(annotation, cls) ): - collected_attributes[name] = MappedColumn() + collected_attributes[name] = ( + generated_obj + if generated_obj is not None + else MappedColumn() + ) elif name in clsdict_view: collected_attributes[name] = obj # else if the name is not in the cls.__dict__, @@ -1022,9 +1031,9 @@ class _ClassScanMapperConfig(_MapperConfig): is_dataclass: bool, expect_mapped: Optional[bool], attr_value: Any, - ) -> None: + ) -> Any: if raw_annotation is None: - return + return attr_value is_dataclass = self.is_dataclass_prior_to_mapping allow_unmapped = self.allow_unmapped_annotations @@ -1053,15 +1062,23 @@ class _ClassScanMapperConfig(_MapperConfig): expect_mapped=expect_mapped and not is_dataclass, # self.allow_dataclass_fields, ) + if extracted_mapped_annotation is None: # ClassVar can come out here - return + return attr_value + elif attr_value is None: + for elem in typing_get_args(extracted_mapped_annotation): + # look in Annotated[...] for an ORM construct, + # such as Annotated[int, mapped_column(primary_key=True)] + if isinstance(elem, _IntrospectsAnnotations): + attr_value = elem.found_in_pep593_annotated() self.collected_annotations[name] = ( raw_annotation, extracted_mapped_annotation, is_dataclass, ) + return attr_value def _warn_for_decl_attributes( self, cls: Type[Any], key: str, c: Any diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index d67319700a..6d308e141c 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -49,6 +49,8 @@ from .. import sql from .. import util from ..sql import expression from ..sql.elements import BindParameter +from ..util.typing import is_pep593 +from ..util.typing import typing_get_args if typing.TYPE_CHECKING: from ._typing import _InstanceDict @@ -342,6 +344,10 @@ class Composite( ): self._raise_for_required(key, cls) argument = extracted_mapped_annotation + + if is_pep593(argument): + argument = typing_get_args(argument)[0] + if argument and self.composite_class is None: if isinstance(argument, str) or hasattr( argument, "__forward_arg__" diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index e0034061d4..a9ae4436f1 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -140,6 +140,15 @@ class ORMColumnDescription(TypedDict): class _IntrospectsAnnotations: __slots__ = () + def found_in_pep593_annotated(self) -> Any: + """return a copy of this object to use in declarative when the + object is found inside of an Annotated object.""" + + raise NotImplementedError( + f"Use of the {self.__class__} construct inside of an " + f"Annotated object is not yet supported." + ) + def declarative_scan( self, registry: RegistryType, diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 0644222936..d1faff1d96 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -52,8 +52,10 @@ from ..sql.schema import SchemaConst from ..util.typing import de_optionalize_union_types from ..util.typing import de_stringify_annotation from ..util.typing import is_fwd_ref +from ..util.typing import is_pep593 from ..util.typing import NoneType from ..util.typing import Self +from ..util.typing import typing_get_args if TYPE_CHECKING: from ._typing import _IdentityKeyType @@ -569,6 +571,9 @@ class MappedColumn( col = self.__clause_element__() return op(col._bind_param(op, other), col, **kwargs) # type: ignore[return-value] # noqa: E501 + def found_in_pep593_annotated(self) -> Any: + return self._copy() + def declarative_scan( self, registry: _RegistryType, @@ -632,12 +637,19 @@ class MappedColumn( if is_fwd_ref(our_type): our_type = de_stringify_annotation(cls, our_type) - if registry.type_annotation_map: - new_sqltype = registry.type_annotation_map.get(our_type) - if new_sqltype is None: - new_sqltype = sqltypes._type_map_get(our_type) # type: ignore - - if new_sqltype is None: + if is_pep593(our_type): + checks = (our_type,) + typing_get_args(our_type) + else: + checks = (our_type,) + + for check_type in checks: + if registry.type_annotation_map: + new_sqltype = registry.type_annotation_map.get(check_type) + if new_sqltype is None: + new_sqltype = sqltypes._type_map_get(check_type) # type: ignore # noqa: E501 + if new_sqltype is not None: + break + else: raise sa_exc.ArgumentError( f"Could not locate SQLAlchemy Core " f"type for Python type: {our_type}" diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index 56d88bc2fb..61849d0539 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -9,6 +9,10 @@ copies of SQL constructs which contain context-specific markers and associations. +Note that the :class:`.Annotated` concept as implemented in this module is not +related in any way to the pep-593 concept of "Annotated". + + """ from __future__ import annotations diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index 454de100bd..eb625e06e2 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -61,9 +61,18 @@ else: if typing.TYPE_CHECKING or compat.py310: from typing import Annotated as Annotated + from typing import get_args as get_args + from typing import get_origin as get_origin else: from typing_extensions import Annotated as Annotated # noqa: F401 + # these are in py38 but don't work with Annotated correctly, so + # for 3.8 / 3.9 we use the typing extensions version + from typing_extensions import get_args as get_args # noqa: F401 + from typing_extensions import ( + get_origin as get_origin, # noqa: F401, + ) + if typing.TYPE_CHECKING or compat.py38: from typing import Literal as Literal from typing import Protocol as Protocol @@ -75,6 +84,9 @@ else: from typing_extensions import TypedDict as TypedDict # noqa: F401 from typing_extensions import Final as Final # noqa: F401 +typing_get_args = get_args +typing_get_origin = get_origin + # copied from TypeShed, required in order to implement # MutableMapping.update() @@ -140,6 +152,10 @@ def de_stringify_annotation( return annotation # type: ignore +def is_pep593(type_: Optional[_AnnotationScanType]) -> bool: + return type_ is not None and typing_get_origin(type_) is Annotated + + def is_fwd_ref(type_: _AnnotationScanType) -> bool: return isinstance(type_, ForwardRef) diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py index afaa099b21..ae2773d1c7 100644 --- a/test/orm/declarative/test_typed_mapping.py +++ b/test/orm/declarative/test_typed_mapping.py @@ -396,7 +396,9 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): id: Mapped["int"] = mapped_column(primary_key=True) data_one: Mapped["str"] - def test_annotated_types_as_keys(self, decl_base: Type[DeclarativeBase]): + def test_pep593_types_as_typemap_keys( + self, decl_base: Type[DeclarativeBase] + ): """neat!!!""" str50 = Annotated[str, 50] @@ -425,6 +427,138 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): is_true(MyClass.__table__.c.data_two.nullable) eq_(MyClass.__table__.c.data_three.type.length, 50) + def test_extract_base_type_from_pep593( + self, decl_base: Type[DeclarativeBase] + ): + """base type is extracted from an Annotated structure if not otherwise + in the type lookup dictionary""" + + class MyClass(decl_base): + __tablename__ = "my_table" + + id: Mapped[Annotated[Annotated[int, "q"], "t"]] = mapped_column( + primary_key=True + ) + + is_(MyClass.__table__.c.id.type._type_affinity, Integer) + + def test_extract_sqla_from_pep593_not_yet( + self, decl_base: Type[DeclarativeBase] + ): + """https://twitter.com/zzzeek/status/1536693554621341697""" + + class SomeRelated(decl_base): + __tablename__: ClassVar[Optional[str]] = "some_related" + id: Mapped["int"] = mapped_column(primary_key=True) + + with expect_raises_message( + NotImplementedError, + r"Use of the \ construct inside of an Annotated " + r"object is not yet supported.", + ): + + class MyClass(decl_base): + __tablename__ = "my_table" + + id: Mapped["int"] = mapped_column(primary_key=True) + data_one: Mapped[Annotated["SomeRelated", relationship()]] + + def test_extract_sqla_from_pep593_plain( + self, decl_base: Type[DeclarativeBase] + ): + """extraction of mapped_column() from the Annotated type + + https://twitter.com/zzzeek/status/1536693554621341697""" + + intpk = Annotated[int, mapped_column(primary_key=True)] + + strnone = Annotated[str, mapped_column()] # str -> NOT NULL + str30nullable = Annotated[ + str, mapped_column(String(30), nullable=True) # nullable -> NULL + ] + opt_strnone = Optional[strnone] # Optional[str] -> NULL + opt_str30 = Optional[str30nullable] # nullable -> NULL + + class MyClass(decl_base): + __tablename__ = "my_table" + + id: Mapped[intpk] + + data_one: Mapped[strnone] + data_two: Mapped[str30nullable] + data_three: Mapped[opt_strnone] + data_four: Mapped[opt_str30] + + class MyOtherClass(decl_base): + __tablename__ = "my_other_table" + + id: Mapped[intpk] + + data_one: Mapped[strnone] + data_two: Mapped[str30nullable] + data_three: Mapped[opt_strnone] + data_four: Mapped[opt_str30] + + for cls in MyClass, MyOtherClass: + table = cls.__table__ + assert table is not None + + is_(table.c.id.primary_key, True) + is_(table.c.id.table, table) + + eq_(table.c.data_one.type.length, None) + eq_(table.c.data_two.type.length, 30) + eq_(table.c.data_three.type.length, None) + + is_false(table.c.data_one.nullable) + is_true(table.c.data_two.nullable) + is_true(table.c.data_three.nullable) + is_true(table.c.data_four.nullable) + + def test_extract_sqla_from_pep593_mixin( + self, decl_base: Type[DeclarativeBase] + ): + """extraction of mapped_column() from the Annotated type + + https://twitter.com/zzzeek/status/1536693554621341697""" + + intpk = Annotated[int, mapped_column(primary_key=True)] + + strnone = Annotated[str, mapped_column()] # str -> NOT NULL + str30nullable = Annotated[ + str, mapped_column(String(30), nullable=True) # nullable -> NULL + ] + opt_strnone = Optional[strnone] # Optional[str] -> NULL + opt_str30 = Optional[str30nullable] # nullable -> NULL + + class HasPk: + id: Mapped[intpk] + + data_one: Mapped[strnone] + data_two: Mapped[str30nullable] + + class MyClass(HasPk, decl_base): + __tablename__ = "my_table" + + data_three: Mapped[opt_strnone] + data_four: Mapped[opt_str30] + + table = MyClass.__table__ + assert table is not None + + is_(table.c.id.primary_key, True) + is_(table.c.id.table, table) + + eq_(table.c.data_one.type.length, None) + eq_(table.c.data_two.type.length, 30) + eq_(table.c.data_three.type.length, None) + + is_false(table.c.data_one.nullable) + is_true(table.c.data_two.nullable) + is_true(table.c.data_three.nullable) + is_true(table.c.data_four.nullable) + def test_unions(self): our_type = Numeric(10, 2) @@ -1088,6 +1222,30 @@ class CompositeTest(fixtures.TestBase, testing.AssertsCompiledSQL): mapped_column(), mapped_column(), mapped_column("zip") ) + def test_extract_from_pep593(self, decl_base): + @dataclasses.dataclass + class Address: + street: str + state: str + zip_: str + + class User(decl_base): + __tablename__ = "user" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column() + + address: Mapped[Annotated[Address, "foo"]] = composite( + mapped_column(), mapped_column(), mapped_column("zip") + ) + + self.assert_compile( + select(User), + 'SELECT "user".id, "user".name, "user".street, ' + '"user".state, "user".zip FROM "user"', + dialect="default", + ) + def test_cls_not_composite_compliant(self, decl_base): class Address: def __init__(self, street: int, state: str, zip_: str): -- 2.47.2