From: Mike Bayer Date: Mon, 30 Jan 2023 20:12:52 +0000 (-0500) Subject: support NewType in type_annotation_map X-Git-Tag: rel_2_0_1~6 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=a21c715b7a89b0619db0d2d5b31617d17b25a27a;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git support NewType in type_annotation_map Added support for :pep:`484` ``NewType`` to be used in the :paramref:`_orm.registry.type_annotation_map` as well as within :class:`.Mapped` constructs. These types will behave in the same way as custom subclasses of types right now; they must appear explicitly within the :paramref:`_orm.registry.type_annotation_map` to be mapped. Within this change, the lookup between decl_api._resolve_type and TypeEngine._resolve_for_python_type is streamlined to not inspect the given type multiple times, instead passing in from decl_api to TypeEngine the already "flattened" version of a Generic or NewType type. Fixes: #9175 Change-Id: I227cf84b4b88e4567fa2d1d7da0c05b54e00c562 --- diff --git a/doc/build/changelog/unreleased_20/9175.rst b/doc/build/changelog/unreleased_20/9175.rst new file mode 100644 index 0000000000..f79184931b --- /dev/null +++ b/doc/build/changelog/unreleased_20/9175.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: bug, orm + :tickets: 9175 + + Added support for :pep:`484` ``NewType`` to be used in the + :paramref:`_orm.registry.type_annotation_map` as well as within + :class:`.Mapped` constructs. These types will behave in the same way as + custom subclasses of types right now; they must appear explicitly within + the :paramref:`_orm.registry.type_annotation_map` to be mapped. diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index a46c1a7fbb..4f84438330 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -19,6 +19,7 @@ from typing import ClassVar from typing import Dict from typing import FrozenSet from typing import Generic +from typing import Iterable from typing import Iterator from typing import Mapping from typing import Optional @@ -74,7 +75,9 @@ from ..util import hybridmethod from ..util import hybridproperty from ..util import typing as compat_typing from ..util.typing import CallableReference +from ..util.typing import flatten_newtype from ..util.typing import is_generic +from ..util.typing import is_newtype from ..util.typing import Literal if TYPE_CHECKING: @@ -85,7 +88,7 @@ if TYPE_CHECKING: from .interfaces import MapperProperty from .state import InstanceState # noqa from ..sql._typing import _TypeEngineArgument - from ..util.typing import GenericProtocol + from ..sql.type_api import _MatchedOnType _T = TypeVar("_T", bound=Any) @@ -1211,21 +1214,24 @@ class registry: ) def _resolve_type( - self, python_type: Union[GenericProtocol[Any], Type[Any]] + self, python_type: _MatchedOnType ) -> Optional[sqltypes.TypeEngine[Any]]: - search: Tuple[Union[GenericProtocol[Any], Type[Any]], ...] + search: Iterable[Tuple[_MatchedOnType, Type[Any]]] if is_generic(python_type): python_type_type: Type[Any] = python_type.__origin__ - search = (python_type,) + search = ((python_type, python_type_type),) + elif is_newtype(python_type): + python_type_type = flatten_newtype(python_type) + search = ((python_type, python_type_type),) else: - # don't know why is_generic() TypeGuard[GenericProtocol[Any]] - # check above is not sufficient here python_type_type = cast("Type[Any]", python_type) - search = python_type_type.__mro__ + flattened = None + search = ((pt, pt) for pt in python_type_type.__mro__) - for pt in search: + for pt, flattened in search: + # we search through full __mro__ for types. however... sql_type = self.type_annotation_map.get(pt) if sql_type is None: sql_type = sqltypes._type_map_get(pt) # type: ignore # noqa: E501 @@ -1233,8 +1239,15 @@ class registry: if sql_type is not None: sql_type_inst = sqltypes.to_instance(sql_type) # type: ignore + # ... this additional step will reject most + # type -> supertype matches, such as if we had + # a MyInt(int) subclass. note also we pass NewType() + # here directly; these always have to be in the + # type_annotation_map to be useful resolved_sql_type = sql_type_inst._resolve_for_python_type( - python_type_type, pt + python_type_type, + pt, + flattened, ) if resolved_sql_type is not None: return resolved_sql_type diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index b5c79b4b97..717e6c0b22 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -59,7 +59,6 @@ from .. import util from ..engine import processors from ..util import langhelpers from ..util import OrderedDict -from ..util.typing import GenericProtocol from ..util.typing import Literal if TYPE_CHECKING: @@ -69,6 +68,7 @@ if TYPE_CHECKING: from .schema import MetaData from .type_api import _BindProcessorType from .type_api import _ComparatorFactory + from .type_api import _MatchedOnType from .type_api import _ResultProcessorType from ..engine.interfaces import Dialect @@ -1493,14 +1493,16 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): return enums, enums def _resolve_for_literal(self, value: Any) -> Enum: - typ = self._resolve_for_python_type(type(value), type(value)) + tv = type(value) + typ = self._resolve_for_python_type(tv, tv, tv) assert typ is not None return typ def _resolve_for_python_type( self, python_type: Type[Any], - matched_on: Union[GenericProtocol[Any], Type[Any]], + matched_on: _MatchedOnType, + matched_on_flattened: Type[Any], ) -> Optional[Enum]: if not issubclass(python_type, enum.Enum): return None diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index 79c8897636..fefbf49974 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -19,6 +19,7 @@ from typing import cast from typing import Dict from typing import Generic from typing import Mapping +from typing import NewType from typing import Optional from typing import overload from typing import Sequence @@ -35,7 +36,6 @@ from .operators import ColumnOperators from .visitors import Visitable from .. import exc from .. import util -from ..util.typing import flatten_generic from ..util.typing import Protocol from ..util.typing import TypedDict from ..util.typing import TypeGuard @@ -65,6 +65,8 @@ _O = TypeVar("_O", bound=object) _TE = TypeVar("_TE", bound="TypeEngine[Any]") _CT = TypeVar("_CT", bound=Any) +_MatchedOnType = Union["GenericProtocol[Any]", NewType, Type[Any]] + # replace with pep-673 when applicable SelfTypeEngine = typing.TypeVar("SelfTypeEngine", bound="TypeEngine[Any]") @@ -731,7 +733,8 @@ class TypeEngine(Visitable, Generic[_T]): def _resolve_for_python_type( self: SelfTypeEngine, python_type: Type[Any], - matched_on: Union[GenericProtocol[Any], Type[Any]], + matched_on: _MatchedOnType, + matched_on_flattened: Type[Any], ) -> Optional[SelfTypeEngine]: """given a Python type (e.g. ``int``, ``str``, etc. ) return an instance of this :class:`.TypeEngine` that's appropriate for this type. @@ -772,9 +775,7 @@ class TypeEngine(Visitable, Generic[_T]): """ - matched_on = flatten_generic(matched_on) - - if python_type is not matched_on: + if python_type is not matched_on_flattened: return None return self diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index e1670ed21b..51e95ecfa2 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -18,6 +18,7 @@ from typing import Dict from typing import ForwardRef from typing import Generic from typing import Iterable +from typing import NewType from typing import NoReturn from typing import Optional from typing import overload @@ -71,10 +72,9 @@ typing_get_args = get_args typing_get_origin = get_origin -# copied from TypeShed, required in order to implement -# MutableMapping.update() - -_AnnotationScanType = Union[Type[Any], str, ForwardRef, "GenericProtocol[Any]"] +_AnnotationScanType = Union[ + Type[Any], str, ForwardRef, NewType, "GenericProtocol[Any]" +] class ArgsTypeProcotol(Protocol): @@ -105,6 +105,8 @@ class GenericProtocol(Protocol[_T]): # ... +# copied from TypeShed, required in order to implement +# MutableMapping.update() class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]): def keys(self) -> Iterable[_KT]: ... @@ -247,17 +249,23 @@ def is_pep593(type_: Optional[_AnnotationScanType]) -> bool: return type_ is not None and typing_get_origin(type_) is Annotated +def is_newtype(type_: Optional[_AnnotationScanType]) -> TypeGuard[NewType]: + return hasattr(type_, "__supertype__") + + # doesn't work in 3.8, 3.7 as it passes a closure, not an + # object instance + # return isinstance(type_, NewType) + + def is_generic(type_: _AnnotationScanType) -> TypeGuard[GenericProtocol[Any]]: return hasattr(type_, "__args__") and hasattr(type_, "__origin__") -def flatten_generic( - type_: Union[GenericProtocol[Any], Type[Any]] -) -> Type[Any]: - if is_generic(type_): - return type_.__origin__ - else: - return cast("Type[Any]", type_) +def flatten_newtype(type_: NewType) -> Type[Any]: + super_type = type_.__supertype__ + while is_newtype(super_type): + super_type = super_type.__supertype__ + return super_type def is_fwd_ref( diff --git a/test/orm/declarative/test_tm_future_annotations_sync.py b/test/orm/declarative/test_tm_future_annotations_sync.py index a83b02cd02..8d3961ef70 100644 --- a/test/orm/declarative/test_tm_future_annotations_sync.py +++ b/test/orm/declarative/test_tm_future_annotations_sync.py @@ -18,6 +18,7 @@ from typing import ClassVar from typing import Dict from typing import Generic from typing import List +from typing import NewType from typing import Optional from typing import Set from typing import Type @@ -460,9 +461,13 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): global anno_str, anno_str_optional, anno_str_mc global anno_str_optional_mc, anno_str_mc_nullable global anno_str_optional_mc_notnull + global newtype_str + anno_str = Annotated[str, 50] anno_str_optional = Annotated[Optional[str], 30] + newtype_str = NewType("MyType", str) + anno_str_mc = Annotated[str, mapped_column()] anno_str_optional_mc = Annotated[Optional[str], mapped_column()] anno_str_mc_nullable = Annotated[str, mapped_column(nullable=True)] @@ -471,7 +476,11 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): ] decl_base.registry.update_type_annotation_map( - {anno_str: String(50), anno_str_optional: String(30)} + { + anno_str: String(50), + anno_str_optional: String(30), + newtype_str: String(40), + } ) class User(decl_base): @@ -520,6 +529,11 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): *args, nullable=True ) + newtype_1a: Mapped[newtype_str] = mapped_column(*args) + newtype_1b: Mapped[newtype_str] = mapped_column( + *args, nullable=True + ) + is_false(User.__table__.c.lnnl_rndf.nullable) is_false(User.__table__.c.lnnl_rnnl.nullable) is_true(User.__table__.c.lnnl_rnl.nullable) @@ -589,6 +603,41 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): is_true(MyClass.__table__.c.data_two.nullable) eq_(MyClass.__table__.c.data_three.type.length, 50) + def test_pep484_newtypes_as_typemap_keys( + self, decl_base: Type[DeclarativeBase] + ): + + global str50, str30, str3050 + + str50 = NewType("str50", str) + str30 = NewType("str30", str) + str3050 = NewType("str30", str50) + + decl_base.registry.update_type_annotation_map( + {str50: String(50), str30: String(30), str3050: String(150)} + ) + + class MyClass(decl_base): + __tablename__ = "my_table" + + id: Mapped[str50] = mapped_column(primary_key=True) + data_one: Mapped[str30] + data_two: Mapped[str50] + data_three: Mapped[Optional[str30]] + data_four: Mapped[str3050] + + eq_(MyClass.__table__.c.data_one.type.length, 30) + is_false(MyClass.__table__.c.data_one.nullable) + + eq_(MyClass.__table__.c.data_two.type.length, 50) + is_false(MyClass.__table__.c.data_two.nullable) + + eq_(MyClass.__table__.c.data_three.type.length, 30) + is_true(MyClass.__table__.c.data_three.nullable) + + eq_(MyClass.__table__.c.data_four.type.length, 150) + is_false(MyClass.__table__.c.data_four.nullable) + def test_extract_base_type_from_pep593( self, decl_base: Type[DeclarativeBase] ): @@ -1396,7 +1445,9 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): def test_type_secondary_resolution(self): class MyString(String): - def _resolve_for_python_type(self, python_type, matched_type): + def _resolve_for_python_type( + self, python_type, matched_type, matched_on_flattened + ): return String(length=42) Base = declarative_base(type_annotation_map={str: MyString}) diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py index 9a2faf22a4..87fc298629 100644 --- a/test/orm/declarative/test_typed_mapping.py +++ b/test/orm/declarative/test_typed_mapping.py @@ -9,6 +9,7 @@ from typing import ClassVar from typing import Dict from typing import Generic from typing import List +from typing import NewType from typing import Optional from typing import Set from typing import Type @@ -451,9 +452,13 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): # anno only: global anno_str, anno_str_optional, anno_str_mc # anno only: global anno_str_optional_mc, anno_str_mc_nullable # anno only: global anno_str_optional_mc_notnull + # anno only: global newtype_str + anno_str = Annotated[str, 50] anno_str_optional = Annotated[Optional[str], 30] + newtype_str = NewType("MyType", str) + anno_str_mc = Annotated[str, mapped_column()] anno_str_optional_mc = Annotated[Optional[str], mapped_column()] anno_str_mc_nullable = Annotated[str, mapped_column(nullable=True)] @@ -462,7 +467,11 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): ] decl_base.registry.update_type_annotation_map( - {anno_str: String(50), anno_str_optional: String(30)} + { + anno_str: String(50), + anno_str_optional: String(30), + newtype_str: String(40), + } ) class User(decl_base): @@ -511,6 +520,11 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): *args, nullable=True ) + newtype_1a: Mapped[newtype_str] = mapped_column(*args) + newtype_1b: Mapped[newtype_str] = mapped_column( + *args, nullable=True + ) + is_false(User.__table__.c.lnnl_rndf.nullable) is_false(User.__table__.c.lnnl_rnnl.nullable) is_true(User.__table__.c.lnnl_rnl.nullable) @@ -580,6 +594,41 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): is_true(MyClass.__table__.c.data_two.nullable) eq_(MyClass.__table__.c.data_three.type.length, 50) + def test_pep484_newtypes_as_typemap_keys( + self, decl_base: Type[DeclarativeBase] + ): + + # anno only: global str50, str30, str3050 + + str50 = NewType("str50", str) + str30 = NewType("str30", str) + str3050 = NewType("str30", str50) + + decl_base.registry.update_type_annotation_map( + {str50: String(50), str30: String(30), str3050: String(150)} + ) + + class MyClass(decl_base): + __tablename__ = "my_table" + + id: Mapped[str50] = mapped_column(primary_key=True) + data_one: Mapped[str30] + data_two: Mapped[str50] + data_three: Mapped[Optional[str30]] + data_four: Mapped[str3050] + + eq_(MyClass.__table__.c.data_one.type.length, 30) + is_false(MyClass.__table__.c.data_one.nullable) + + eq_(MyClass.__table__.c.data_two.type.length, 50) + is_false(MyClass.__table__.c.data_two.nullable) + + eq_(MyClass.__table__.c.data_three.type.length, 30) + is_true(MyClass.__table__.c.data_three.nullable) + + eq_(MyClass.__table__.c.data_four.type.length, 150) + is_false(MyClass.__table__.c.data_four.nullable) + def test_extract_base_type_from_pep593( self, decl_base: Type[DeclarativeBase] ): @@ -1387,7 +1436,9 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): def test_type_secondary_resolution(self): class MyString(String): - def _resolve_for_python_type(self, python_type, matched_type): + def _resolve_for_python_type( + self, python_type, matched_type, matched_on_flattened + ): return String(length=42) Base = declarative_base(type_annotation_map={str: MyString})