from ..util.typing import _AnnotationScanType
from ..util.typing import Protocol
from ..util.typing import TypedDict
+from ..util.typing import typing_get_args
if TYPE_CHECKING:
from ._typing import _ClassDict
obj,
)
elif _is_mapped_annotation(annotation, cls):
- self._collect_annotation(
+ generated_obj = self._collect_annotation(
name, annotation, is_dataclass_field, True, obj
)
if obj is None:
if not fixed_table:
- collected_attributes[name] = MappedColumn()
+ collected_attributes[name] = (
+ generated_obj
+ if generated_obj is not None
+ else MappedColumn()
+ )
else:
collected_attributes[name] = obj
else:
name, annotation, True, False, obj
)
else:
- self._collect_annotation(
+ generated_obj = self._collect_annotation(
name, annotation, False, None, obj
)
if (
and not fixed_table
and _is_mapped_annotation(annotation, cls)
):
- collected_attributes[name] = MappedColumn()
+ collected_attributes[name] = (
+ generated_obj
+ if generated_obj is not None
+ else MappedColumn()
+ )
elif name in clsdict_view:
collected_attributes[name] = obj
# else if the name is not in the cls.__dict__,
is_dataclass: bool,
expect_mapped: Optional[bool],
attr_value: Any,
- ) -> None:
+ ) -> Any:
if raw_annotation is None:
- return
+ return attr_value
is_dataclass = self.is_dataclass_prior_to_mapping
allow_unmapped = self.allow_unmapped_annotations
expect_mapped=expect_mapped
and not is_dataclass, # self.allow_dataclass_fields,
)
+
if extracted_mapped_annotation is None:
# ClassVar can come out here
- return
+ return attr_value
+ elif attr_value is None:
+ for elem in typing_get_args(extracted_mapped_annotation):
+ # look in Annotated[...] for an ORM construct,
+ # such as Annotated[int, mapped_column(primary_key=True)]
+ if isinstance(elem, _IntrospectsAnnotations):
+ attr_value = elem.found_in_pep593_annotated()
self.collected_annotations[name] = (
raw_annotation,
extracted_mapped_annotation,
is_dataclass,
)
+ return attr_value
def _warn_for_decl_attributes(
self, cls: Type[Any], key: str, c: Any
from .. import util
from ..sql import expression
from ..sql.elements import BindParameter
+from ..util.typing import is_pep593
+from ..util.typing import typing_get_args
if typing.TYPE_CHECKING:
from ._typing import _InstanceDict
):
self._raise_for_required(key, cls)
argument = extracted_mapped_annotation
+
+ if is_pep593(argument):
+ argument = typing_get_args(argument)[0]
+
if argument and self.composite_class is None:
if isinstance(argument, str) or hasattr(
argument, "__forward_arg__"
class _IntrospectsAnnotations:
__slots__ = ()
+ def found_in_pep593_annotated(self) -> Any:
+ """return a copy of this object to use in declarative when the
+ object is found inside of an Annotated object."""
+
+ raise NotImplementedError(
+ f"Use of the {self.__class__} construct inside of an "
+ f"Annotated object is not yet supported."
+ )
+
def declarative_scan(
self,
registry: RegistryType,
from ..util.typing import de_optionalize_union_types
from ..util.typing import de_stringify_annotation
from ..util.typing import is_fwd_ref
+from ..util.typing import is_pep593
from ..util.typing import NoneType
from ..util.typing import Self
+from ..util.typing import typing_get_args
if TYPE_CHECKING:
from ._typing import _IdentityKeyType
col = self.__clause_element__()
return op(col._bind_param(op, other), col, **kwargs) # type: ignore[return-value] # noqa: E501
+ def found_in_pep593_annotated(self) -> Any:
+ return self._copy()
+
def declarative_scan(
self,
registry: _RegistryType,
if is_fwd_ref(our_type):
our_type = de_stringify_annotation(cls, our_type)
- if registry.type_annotation_map:
- new_sqltype = registry.type_annotation_map.get(our_type)
- if new_sqltype is None:
- new_sqltype = sqltypes._type_map_get(our_type) # type: ignore
-
- if new_sqltype is None:
+ if is_pep593(our_type):
+ checks = (our_type,) + typing_get_args(our_type)
+ else:
+ checks = (our_type,)
+
+ for check_type in checks:
+ if registry.type_annotation_map:
+ new_sqltype = registry.type_annotation_map.get(check_type)
+ if new_sqltype is None:
+ new_sqltype = sqltypes._type_map_get(check_type) # type: ignore # noqa: E501
+ if new_sqltype is not None:
+ break
+ else:
raise sa_exc.ArgumentError(
f"Could not locate SQLAlchemy Core "
f"type for Python type: {our_type}"
copies of SQL constructs which contain context-specific markers and
associations.
+Note that the :class:`.Annotated` concept as implemented in this module is not
+related in any way to the pep-593 concept of "Annotated".
+
+
"""
from __future__ import annotations
if typing.TYPE_CHECKING or compat.py310:
from typing import Annotated as Annotated
+ from typing import get_args as get_args
+ from typing import get_origin as get_origin
else:
from typing_extensions import Annotated as Annotated # noqa: F401
+ # these are in py38 but don't work with Annotated correctly, so
+ # for 3.8 / 3.9 we use the typing extensions version
+ from typing_extensions import get_args as get_args # noqa: F401
+ from typing_extensions import (
+ get_origin as get_origin, # noqa: F401,
+ )
+
if typing.TYPE_CHECKING or compat.py38:
from typing import Literal as Literal
from typing import Protocol as Protocol
from typing_extensions import TypedDict as TypedDict # noqa: F401
from typing_extensions import Final as Final # noqa: F401
+typing_get_args = get_args
+typing_get_origin = get_origin
+
# copied from TypeShed, required in order to implement
# MutableMapping.update()
return annotation # type: ignore
+def is_pep593(type_: Optional[_AnnotationScanType]) -> bool:
+ return type_ is not None and typing_get_origin(type_) is Annotated
+
+
def is_fwd_ref(type_: _AnnotationScanType) -> bool:
return isinstance(type_, ForwardRef)
id: Mapped["int"] = mapped_column(primary_key=True)
data_one: Mapped["str"]
- def test_annotated_types_as_keys(self, decl_base: Type[DeclarativeBase]):
+ def test_pep593_types_as_typemap_keys(
+ self, decl_base: Type[DeclarativeBase]
+ ):
"""neat!!!"""
str50 = Annotated[str, 50]
is_true(MyClass.__table__.c.data_two.nullable)
eq_(MyClass.__table__.c.data_three.type.length, 50)
+ def test_extract_base_type_from_pep593(
+ self, decl_base: Type[DeclarativeBase]
+ ):
+ """base type is extracted from an Annotated structure if not otherwise
+ in the type lookup dictionary"""
+
+ class MyClass(decl_base):
+ __tablename__ = "my_table"
+
+ id: Mapped[Annotated[Annotated[int, "q"], "t"]] = mapped_column(
+ primary_key=True
+ )
+
+ is_(MyClass.__table__.c.id.type._type_affinity, Integer)
+
+ def test_extract_sqla_from_pep593_not_yet(
+ self, decl_base: Type[DeclarativeBase]
+ ):
+ """https://twitter.com/zzzeek/status/1536693554621341697"""
+
+ class SomeRelated(decl_base):
+ __tablename__: ClassVar[Optional[str]] = "some_related"
+ id: Mapped["int"] = mapped_column(primary_key=True)
+
+ with expect_raises_message(
+ NotImplementedError,
+ r"Use of the \<class 'sqlalchemy.orm."
+ r"relationships.Relationship'\> construct inside of an Annotated "
+ r"object is not yet supported.",
+ ):
+
+ class MyClass(decl_base):
+ __tablename__ = "my_table"
+
+ id: Mapped["int"] = mapped_column(primary_key=True)
+ data_one: Mapped[Annotated["SomeRelated", relationship()]]
+
+ def test_extract_sqla_from_pep593_plain(
+ self, decl_base: Type[DeclarativeBase]
+ ):
+ """extraction of mapped_column() from the Annotated type
+
+ https://twitter.com/zzzeek/status/1536693554621341697"""
+
+ intpk = Annotated[int, mapped_column(primary_key=True)]
+
+ strnone = Annotated[str, mapped_column()] # str -> NOT NULL
+ str30nullable = Annotated[
+ str, mapped_column(String(30), nullable=True) # nullable -> NULL
+ ]
+ opt_strnone = Optional[strnone] # Optional[str] -> NULL
+ opt_str30 = Optional[str30nullable] # nullable -> NULL
+
+ class MyClass(decl_base):
+ __tablename__ = "my_table"
+
+ id: Mapped[intpk]
+
+ data_one: Mapped[strnone]
+ data_two: Mapped[str30nullable]
+ data_three: Mapped[opt_strnone]
+ data_four: Mapped[opt_str30]
+
+ class MyOtherClass(decl_base):
+ __tablename__ = "my_other_table"
+
+ id: Mapped[intpk]
+
+ data_one: Mapped[strnone]
+ data_two: Mapped[str30nullable]
+ data_three: Mapped[opt_strnone]
+ data_four: Mapped[opt_str30]
+
+ for cls in MyClass, MyOtherClass:
+ table = cls.__table__
+ assert table is not None
+
+ is_(table.c.id.primary_key, True)
+ is_(table.c.id.table, table)
+
+ eq_(table.c.data_one.type.length, None)
+ eq_(table.c.data_two.type.length, 30)
+ eq_(table.c.data_three.type.length, None)
+
+ is_false(table.c.data_one.nullable)
+ is_true(table.c.data_two.nullable)
+ is_true(table.c.data_three.nullable)
+ is_true(table.c.data_four.nullable)
+
+ def test_extract_sqla_from_pep593_mixin(
+ self, decl_base: Type[DeclarativeBase]
+ ):
+ """extraction of mapped_column() from the Annotated type
+
+ https://twitter.com/zzzeek/status/1536693554621341697"""
+
+ intpk = Annotated[int, mapped_column(primary_key=True)]
+
+ strnone = Annotated[str, mapped_column()] # str -> NOT NULL
+ str30nullable = Annotated[
+ str, mapped_column(String(30), nullable=True) # nullable -> NULL
+ ]
+ opt_strnone = Optional[strnone] # Optional[str] -> NULL
+ opt_str30 = Optional[str30nullable] # nullable -> NULL
+
+ class HasPk:
+ id: Mapped[intpk]
+
+ data_one: Mapped[strnone]
+ data_two: Mapped[str30nullable]
+
+ class MyClass(HasPk, decl_base):
+ __tablename__ = "my_table"
+
+ data_three: Mapped[opt_strnone]
+ data_four: Mapped[opt_str30]
+
+ table = MyClass.__table__
+ assert table is not None
+
+ is_(table.c.id.primary_key, True)
+ is_(table.c.id.table, table)
+
+ eq_(table.c.data_one.type.length, None)
+ eq_(table.c.data_two.type.length, 30)
+ eq_(table.c.data_three.type.length, None)
+
+ is_false(table.c.data_one.nullable)
+ is_true(table.c.data_two.nullable)
+ is_true(table.c.data_three.nullable)
+ is_true(table.c.data_four.nullable)
+
def test_unions(self):
our_type = Numeric(10, 2)
mapped_column(), mapped_column(), mapped_column("zip")
)
+ def test_extract_from_pep593(self, decl_base):
+ @dataclasses.dataclass
+ class Address:
+ street: str
+ state: str
+ zip_: str
+
+ class User(decl_base):
+ __tablename__ = "user"
+
+ id: Mapped[int] = mapped_column(primary_key=True)
+ name: Mapped[str] = mapped_column()
+
+ address: Mapped[Annotated[Address, "foo"]] = composite(
+ mapped_column(), mapped_column(), mapped_column("zip")
+ )
+
+ self.assert_compile(
+ select(User),
+ 'SELECT "user".id, "user".name, "user".street, '
+ '"user".state, "user".zip FROM "user"',
+ dialect="default",
+ )
+
def test_cls_not_composite_compliant(self, decl_base):
class Address:
def __init__(self, street: int, state: str, zip_: str):