]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
support NewType in type_annotation_map
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 30 Jan 2023 20:12:52 +0000 (15:12 -0500)
committermike bayer <mike_mp@zzzcomputing.com>
Tue, 31 Jan 2023 19:13:16 +0000 (19:13 +0000)
Added support for :pep:`484` ``NewType`` to be used in the
:paramref:`_orm.registry.type_annotation_map` as well as within
:class:`.Mapped` constructs. These types will behave in the same way as
custom subclasses of types right now; they must appear explicitly within
the :paramref:`_orm.registry.type_annotation_map` to be mapped.

Within this change, the lookup between decl_api._resolve_type
and TypeEngine._resolve_for_python_type is streamlined to not
inspect the given type multiple times, instead passing
in from decl_api to TypeEngine the already "flattened" version
of a Generic or NewType type.

Fixes: #9175
Change-Id: I227cf84b4b88e4567fa2d1d7da0c05b54e00c562

doc/build/changelog/unreleased_20/9175.rst [new file with mode: 0644]
lib/sqlalchemy/orm/decl_api.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/sql/type_api.py
lib/sqlalchemy/util/typing.py
test/orm/declarative/test_tm_future_annotations_sync.py
test/orm/declarative/test_typed_mapping.py

diff --git a/doc/build/changelog/unreleased_20/9175.rst b/doc/build/changelog/unreleased_20/9175.rst
new file mode 100644 (file)
index 0000000..f791849
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, orm
+    :tickets: 9175
+
+    Added support for :pep:`484` ``NewType`` to be used in the
+    :paramref:`_orm.registry.type_annotation_map` as well as within
+    :class:`.Mapped` constructs. These types will behave in the same way as
+    custom subclasses of types right now; they must appear explicitly within
+    the :paramref:`_orm.registry.type_annotation_map` to be mapped.
index a46c1a7fbb294695b8939bc2ea38f1b92f8d1a27..4f84438330ba00ad37cc5545fef7e91163903b52 100644 (file)
@@ -19,6 +19,7 @@ from typing import ClassVar
 from typing import Dict
 from typing import FrozenSet
 from typing import Generic
+from typing import Iterable
 from typing import Iterator
 from typing import Mapping
 from typing import Optional
@@ -74,7 +75,9 @@ from ..util import hybridmethod
 from ..util import hybridproperty
 from ..util import typing as compat_typing
 from ..util.typing import CallableReference
+from ..util.typing import flatten_newtype
 from ..util.typing import is_generic
+from ..util.typing import is_newtype
 from ..util.typing import Literal
 
 if TYPE_CHECKING:
@@ -85,7 +88,7 @@ if TYPE_CHECKING:
     from .interfaces import MapperProperty
     from .state import InstanceState  # noqa
     from ..sql._typing import _TypeEngineArgument
-    from ..util.typing import GenericProtocol
+    from ..sql.type_api import _MatchedOnType
 
 _T = TypeVar("_T", bound=Any)
 
@@ -1211,21 +1214,24 @@ class registry:
         )
 
     def _resolve_type(
-        self, python_type: Union[GenericProtocol[Any], Type[Any]]
+        self, python_type: _MatchedOnType
     ) -> Optional[sqltypes.TypeEngine[Any]]:
 
-        search: Tuple[Union[GenericProtocol[Any], Type[Any]], ...]
+        search: Iterable[Tuple[_MatchedOnType, Type[Any]]]
 
         if is_generic(python_type):
             python_type_type: Type[Any] = python_type.__origin__
-            search = (python_type,)
+            search = ((python_type, python_type_type),)
+        elif is_newtype(python_type):
+            python_type_type = flatten_newtype(python_type)
+            search = ((python_type, python_type_type),)
         else:
-            # don't know why is_generic() TypeGuard[GenericProtocol[Any]]
-            # check above is not sufficient here
             python_type_type = cast("Type[Any]", python_type)
-            search = python_type_type.__mro__
+            flattened = None
+            search = ((pt, pt) for pt in python_type_type.__mro__)
 
-        for pt in search:
+        for pt, flattened in search:
+            # we search through full __mro__ for types.  however...
             sql_type = self.type_annotation_map.get(pt)
             if sql_type is None:
                 sql_type = sqltypes._type_map_get(pt)  # type: ignore  # noqa: E501
@@ -1233,8 +1239,15 @@ class registry:
             if sql_type is not None:
                 sql_type_inst = sqltypes.to_instance(sql_type)  # type: ignore
 
+                # ... this additional step will reject most
+                # type -> supertype matches, such as if we had
+                # a MyInt(int) subclass.  note also we pass NewType()
+                # here directly; these always have to be in the
+                # type_annotation_map to be useful
                 resolved_sql_type = sql_type_inst._resolve_for_python_type(
-                    python_type_type, pt
+                    python_type_type,
+                    pt,
+                    flattened,
                 )
                 if resolved_sql_type is not None:
                     return resolved_sql_type
index b5c79b4b97a9d1c383377cae976b2ee1eacd7232..717e6c0b229ec7169f007abb506b7e3d01f39e28 100644 (file)
@@ -59,7 +59,6 @@ from .. import util
 from ..engine import processors
 from ..util import langhelpers
 from ..util import OrderedDict
-from ..util.typing import GenericProtocol
 from ..util.typing import Literal
 
 if TYPE_CHECKING:
@@ -69,6 +68,7 @@ if TYPE_CHECKING:
     from .schema import MetaData
     from .type_api import _BindProcessorType
     from .type_api import _ComparatorFactory
+    from .type_api import _MatchedOnType
     from .type_api import _ResultProcessorType
     from ..engine.interfaces import Dialect
 
@@ -1493,14 +1493,16 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
             return enums, enums
 
     def _resolve_for_literal(self, value: Any) -> Enum:
-        typ = self._resolve_for_python_type(type(value), type(value))
+        tv = type(value)
+        typ = self._resolve_for_python_type(tv, tv, tv)
         assert typ is not None
         return typ
 
     def _resolve_for_python_type(
         self,
         python_type: Type[Any],
-        matched_on: Union[GenericProtocol[Any], Type[Any]],
+        matched_on: _MatchedOnType,
+        matched_on_flattened: Type[Any],
     ) -> Optional[Enum]:
         if not issubclass(python_type, enum.Enum):
             return None
index 79c8897636fcf86c66518d94207cb58c48dd379d..fefbf49974f266e7e7d5e967350a91fc1b2b6cb2 100644 (file)
@@ -19,6 +19,7 @@ from typing import cast
 from typing import Dict
 from typing import Generic
 from typing import Mapping
+from typing import NewType
 from typing import Optional
 from typing import overload
 from typing import Sequence
@@ -35,7 +36,6 @@ from .operators import ColumnOperators
 from .visitors import Visitable
 from .. import exc
 from .. import util
-from ..util.typing import flatten_generic
 from ..util.typing import Protocol
 from ..util.typing import TypedDict
 from ..util.typing import TypeGuard
@@ -65,6 +65,8 @@ _O = TypeVar("_O", bound=object)
 _TE = TypeVar("_TE", bound="TypeEngine[Any]")
 _CT = TypeVar("_CT", bound=Any)
 
+_MatchedOnType = Union["GenericProtocol[Any]", NewType, Type[Any]]
+
 # replace with pep-673 when applicable
 SelfTypeEngine = typing.TypeVar("SelfTypeEngine", bound="TypeEngine[Any]")
 
@@ -731,7 +733,8 @@ class TypeEngine(Visitable, Generic[_T]):
     def _resolve_for_python_type(
         self: SelfTypeEngine,
         python_type: Type[Any],
-        matched_on: Union[GenericProtocol[Any], Type[Any]],
+        matched_on: _MatchedOnType,
+        matched_on_flattened: Type[Any],
     ) -> Optional[SelfTypeEngine]:
         """given a Python type (e.g. ``int``, ``str``, etc. ) return an
         instance of this :class:`.TypeEngine` that's appropriate for this type.
@@ -772,9 +775,7 @@ class TypeEngine(Visitable, Generic[_T]):
 
         """
 
-        matched_on = flatten_generic(matched_on)
-
-        if python_type is not matched_on:
+        if python_type is not matched_on_flattened:
             return None
 
         return self
index e1670ed21b0c2eac0d399669e3bf1f35646023d9..51e95ecfa2af009fa16f9180807fc17db1af7d2a 100644 (file)
@@ -18,6 +18,7 @@ from typing import Dict
 from typing import ForwardRef
 from typing import Generic
 from typing import Iterable
+from typing import NewType
 from typing import NoReturn
 from typing import Optional
 from typing import overload
@@ -71,10 +72,9 @@ typing_get_args = get_args
 typing_get_origin = get_origin
 
 
-# copied from TypeShed, required in order to implement
-# MutableMapping.update()
-
-_AnnotationScanType = Union[Type[Any], str, ForwardRef, "GenericProtocol[Any]"]
+_AnnotationScanType = Union[
+    Type[Any], str, ForwardRef, NewType, "GenericProtocol[Any]"
+]
 
 
 class ArgsTypeProcotol(Protocol):
@@ -105,6 +105,8 @@ class GenericProtocol(Protocol[_T]):
     #     ...
 
 
+# copied from TypeShed, required in order to implement
+# MutableMapping.update()
 class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]):
     def keys(self) -> Iterable[_KT]:
         ...
@@ -247,17 +249,23 @@ def is_pep593(type_: Optional[_AnnotationScanType]) -> bool:
     return type_ is not None and typing_get_origin(type_) is Annotated
 
 
+def is_newtype(type_: Optional[_AnnotationScanType]) -> TypeGuard[NewType]:
+    return hasattr(type_, "__supertype__")
+
+    # doesn't work in 3.8, 3.7 as it passes a closure, not an
+    # object instance
+    # return isinstance(type_, NewType)
+
+
 def is_generic(type_: _AnnotationScanType) -> TypeGuard[GenericProtocol[Any]]:
     return hasattr(type_, "__args__") and hasattr(type_, "__origin__")
 
 
-def flatten_generic(
-    type_: Union[GenericProtocol[Any], Type[Any]]
-) -> Type[Any]:
-    if is_generic(type_):
-        return type_.__origin__
-    else:
-        return cast("Type[Any]", type_)
+def flatten_newtype(type_: NewType) -> Type[Any]:
+    super_type = type_.__supertype__
+    while is_newtype(super_type):
+        super_type = super_type.__supertype__
+    return super_type
 
 
 def is_fwd_ref(
index a83b02cd028a7bf4639a54738bee74fad2238ab0..8d3961ef701bab59033e255100a066e8dbdec3c5 100644 (file)
@@ -18,6 +18,7 @@ from typing import ClassVar
 from typing import Dict
 from typing import Generic
 from typing import List
+from typing import NewType
 from typing import Optional
 from typing import Set
 from typing import Type
@@ -460,9 +461,13 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         global anno_str, anno_str_optional, anno_str_mc
         global anno_str_optional_mc, anno_str_mc_nullable
         global anno_str_optional_mc_notnull
+        global newtype_str
+
         anno_str = Annotated[str, 50]
         anno_str_optional = Annotated[Optional[str], 30]
 
+        newtype_str = NewType("MyType", str)
+
         anno_str_mc = Annotated[str, mapped_column()]
         anno_str_optional_mc = Annotated[Optional[str], mapped_column()]
         anno_str_mc_nullable = Annotated[str, mapped_column(nullable=True)]
@@ -471,7 +476,11 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         ]
 
         decl_base.registry.update_type_annotation_map(
-            {anno_str: String(50), anno_str_optional: String(30)}
+            {
+                anno_str: String(50),
+                anno_str_optional: String(30),
+                newtype_str: String(40),
+            }
         )
 
         class User(decl_base):
@@ -520,6 +529,11 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
                 *args, nullable=True
             )
 
+            newtype_1a: Mapped[newtype_str] = mapped_column(*args)
+            newtype_1b: Mapped[newtype_str] = mapped_column(
+                *args, nullable=True
+            )
+
         is_false(User.__table__.c.lnnl_rndf.nullable)
         is_false(User.__table__.c.lnnl_rnnl.nullable)
         is_true(User.__table__.c.lnnl_rnl.nullable)
@@ -589,6 +603,41 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         is_true(MyClass.__table__.c.data_two.nullable)
         eq_(MyClass.__table__.c.data_three.type.length, 50)
 
+    def test_pep484_newtypes_as_typemap_keys(
+        self, decl_base: Type[DeclarativeBase]
+    ):
+
+        global str50, str30, str3050
+
+        str50 = NewType("str50", str)
+        str30 = NewType("str30", str)
+        str3050 = NewType("str30", str50)
+
+        decl_base.registry.update_type_annotation_map(
+            {str50: String(50), str30: String(30), str3050: String(150)}
+        )
+
+        class MyClass(decl_base):
+            __tablename__ = "my_table"
+
+            id: Mapped[str50] = mapped_column(primary_key=True)
+            data_one: Mapped[str30]
+            data_two: Mapped[str50]
+            data_three: Mapped[Optional[str30]]
+            data_four: Mapped[str3050]
+
+        eq_(MyClass.__table__.c.data_one.type.length, 30)
+        is_false(MyClass.__table__.c.data_one.nullable)
+
+        eq_(MyClass.__table__.c.data_two.type.length, 50)
+        is_false(MyClass.__table__.c.data_two.nullable)
+
+        eq_(MyClass.__table__.c.data_three.type.length, 30)
+        is_true(MyClass.__table__.c.data_three.nullable)
+
+        eq_(MyClass.__table__.c.data_four.type.length, 150)
+        is_false(MyClass.__table__.c.data_four.nullable)
+
     def test_extract_base_type_from_pep593(
         self, decl_base: Type[DeclarativeBase]
     ):
@@ -1396,7 +1445,9 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
 
     def test_type_secondary_resolution(self):
         class MyString(String):
-            def _resolve_for_python_type(self, python_type, matched_type):
+            def _resolve_for_python_type(
+                self, python_type, matched_type, matched_on_flattened
+            ):
                 return String(length=42)
 
         Base = declarative_base(type_annotation_map={str: MyString})
index 9a2faf22a46e1e2f3d44399d3abc921df02e2f11..87fc298629e5438f86bc847e5a9297418f8fd808 100644 (file)
@@ -9,6 +9,7 @@ from typing import ClassVar
 from typing import Dict
 from typing import Generic
 from typing import List
+from typing import NewType
 from typing import Optional
 from typing import Set
 from typing import Type
@@ -451,9 +452,13 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         # anno only: global anno_str, anno_str_optional, anno_str_mc
         # anno only: global anno_str_optional_mc, anno_str_mc_nullable
         # anno only: global anno_str_optional_mc_notnull
+        # anno only: global newtype_str
+
         anno_str = Annotated[str, 50]
         anno_str_optional = Annotated[Optional[str], 30]
 
+        newtype_str = NewType("MyType", str)
+
         anno_str_mc = Annotated[str, mapped_column()]
         anno_str_optional_mc = Annotated[Optional[str], mapped_column()]
         anno_str_mc_nullable = Annotated[str, mapped_column(nullable=True)]
@@ -462,7 +467,11 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         ]
 
         decl_base.registry.update_type_annotation_map(
-            {anno_str: String(50), anno_str_optional: String(30)}
+            {
+                anno_str: String(50),
+                anno_str_optional: String(30),
+                newtype_str: String(40),
+            }
         )
 
         class User(decl_base):
@@ -511,6 +520,11 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
                 *args, nullable=True
             )
 
+            newtype_1a: Mapped[newtype_str] = mapped_column(*args)
+            newtype_1b: Mapped[newtype_str] = mapped_column(
+                *args, nullable=True
+            )
+
         is_false(User.__table__.c.lnnl_rndf.nullable)
         is_false(User.__table__.c.lnnl_rnnl.nullable)
         is_true(User.__table__.c.lnnl_rnl.nullable)
@@ -580,6 +594,41 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         is_true(MyClass.__table__.c.data_two.nullable)
         eq_(MyClass.__table__.c.data_three.type.length, 50)
 
+    def test_pep484_newtypes_as_typemap_keys(
+        self, decl_base: Type[DeclarativeBase]
+    ):
+
+        # anno only: global str50, str30, str3050
+
+        str50 = NewType("str50", str)
+        str30 = NewType("str30", str)
+        str3050 = NewType("str30", str50)
+
+        decl_base.registry.update_type_annotation_map(
+            {str50: String(50), str30: String(30), str3050: String(150)}
+        )
+
+        class MyClass(decl_base):
+            __tablename__ = "my_table"
+
+            id: Mapped[str50] = mapped_column(primary_key=True)
+            data_one: Mapped[str30]
+            data_two: Mapped[str50]
+            data_three: Mapped[Optional[str30]]
+            data_four: Mapped[str3050]
+
+        eq_(MyClass.__table__.c.data_one.type.length, 30)
+        is_false(MyClass.__table__.c.data_one.nullable)
+
+        eq_(MyClass.__table__.c.data_two.type.length, 50)
+        is_false(MyClass.__table__.c.data_two.nullable)
+
+        eq_(MyClass.__table__.c.data_three.type.length, 30)
+        is_true(MyClass.__table__.c.data_three.nullable)
+
+        eq_(MyClass.__table__.c.data_four.type.length, 150)
+        is_false(MyClass.__table__.c.data_four.nullable)
+
     def test_extract_base_type_from_pep593(
         self, decl_base: Type[DeclarativeBase]
     ):
@@ -1387,7 +1436,9 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
 
     def test_type_secondary_resolution(self):
         class MyString(String):
-            def _resolve_for_python_type(self, python_type, matched_type):
+            def _resolve_for_python_type(
+                self, python_type, matched_type, matched_on_flattened
+            ):
                 return String(length=42)
 
         Base = declarative_base(type_annotation_map={str: MyString})