]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
support pep695 when resolving type map types
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 30 Dec 2023 15:36:40 +0000 (10:36 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 31 Dec 2023 04:24:23 +0000 (23:24 -0500)
Added preliminary support for Python 3.12 pep-695 type alias structures,
when resolving custom type maps for ORM Annotated Declarative mappings.

Fixes: #10807
Change-Id: Ia28123ce1d6d1fd6bae5e8a037be4754c890f281
(cherry picked from commit 692525492986a109877d881b2f2936b610b9066f)

doc/build/changelog/unreleased_20/10807.rst [new file with mode: 0644]
lib/sqlalchemy/orm/decl_api.py
lib/sqlalchemy/sql/type_api.py
lib/sqlalchemy/testing/requirements.py
lib/sqlalchemy/util/typing.py
setup.cfg
test/orm/declarative/test_tm_future_annotations_sync.py
test/orm/declarative/test_typed_mapping.py

diff --git a/doc/build/changelog/unreleased_20/10807.rst b/doc/build/changelog/unreleased_20/10807.rst
new file mode 100644 (file)
index 0000000..afceef6
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: usecase, orm
+    :tickets: 10807
+
+    Added preliminary support for Python 3.12 pep-695 type alias structures,
+    when resolving custom type maps for ORM Annotated Declarative mappings.
+
index 60bd2ae4901097fc7b97c4e277a4fc2b9e1f9ce2..fe7ed146219cab6c6c19690fb85c8385e414ab5a 100644 (file)
@@ -77,6 +77,7 @@ from ..util.typing import flatten_newtype
 from ..util.typing import is_generic
 from ..util.typing import is_literal
 from ..util.typing import is_newtype
+from ..util.typing import is_pep695
 from ..util.typing import Literal
 from ..util.typing import Self
 
@@ -1264,6 +1265,10 @@ class registry:
         elif is_newtype(python_type):
             python_type_type = flatten_newtype(python_type)
             search = ((python_type, python_type_type),)
+        elif is_pep695(python_type):
+            python_type_type = python_type.__value__
+            flattened = None
+            search = ((python_type, python_type_type),)
         else:
             python_type_type = cast("Type[Any]", python_type)
             flattened = None
index 9226b01e61ac4eefa0a5012a132848f4fc8b9db4..ab387bc7afb21b451f21dc57096d0f2dab1f6cc9 100644 (file)
@@ -39,6 +39,7 @@ from .. import exc
 from .. import util
 from ..util.typing import Protocol
 from ..util.typing import Self
+from ..util.typing import TypeAliasType
 from ..util.typing import TypedDict
 from ..util.typing import TypeGuard
 
@@ -67,7 +68,9 @@ _O = TypeVar("_O", bound=object)
 _TE = TypeVar("_TE", bound="TypeEngine[Any]")
 _CT = TypeVar("_CT", bound=Any)
 
-_MatchedOnType = Union["GenericProtocol[Any]", NewType, Type[Any]]
+_MatchedOnType = Union[
+    "GenericProtocol[Any]", TypeAliasType, NewType, Type[Any]
+]
 
 
 class _NoValueInList(Enum):
index 4dd5176a3ee4f6848994a8fb156a6aa06e4c4c2e..be700a420cc9ca480a4f15a93806bf875767b2a1 100644 (file)
@@ -1530,6 +1530,12 @@ class SuiteRequirements(Requirements):
             lambda: util.py311, "Python 3.11 or above required"
         )
 
+    @property
+    def python312(self):
+        return exclusions.only_if(
+            lambda: util.py312, "Python 3.12 or above required"
+        )
+
     @property
     def cpython(self):
         return exclusions.only_if(
index faf71c89a295f916eb7aead6ed2317a15ea998d7..83735f93b74ef4b3958087000cf5909a00bf6bc8 100644 (file)
@@ -53,7 +53,7 @@ if True:  # zimports removes the tailing comments
     from typing_extensions import TypedDict as TypedDict  # 3.8
     from typing_extensions import TypeGuard as TypeGuard  # 3.10
     from typing_extensions import Self as Self  # 3.11
-
+    from typing_extensions import TypeAliasType as TypeAliasType  # 3.12
 
 _T = TypeVar("_T", bound=Any)
 _KT = TypeVar("_KT")
@@ -77,7 +77,7 @@ typing_get_origin = get_origin
 
 
 _AnnotationScanType = Union[
-    Type[Any], str, ForwardRef, NewType, "GenericProtocol[Any]"
+    Type[Any], str, ForwardRef, NewType, TypeAliasType, "GenericProtocol[Any]"
 ]
 
 
@@ -319,6 +319,10 @@ def is_generic(type_: _AnnotationScanType) -> TypeGuard[GenericProtocol[Any]]:
     return hasattr(type_, "__args__") and hasattr(type_, "__origin__")
 
 
+def is_pep695(type_: _AnnotationScanType) -> TypeGuard[TypeAliasType]:
+    return isinstance(type_, TypeAliasType)
+
+
 def flatten_newtype(type_: NewType) -> Type[Any]:
     super_type = type_.__supertype__
     while is_newtype(super_type):
index c8594c17885933458bafe13daaf9fa133d04894a..093961626f63984581752725eb3184d1b640a6fe 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -39,7 +39,7 @@ package_dir =
 install_requires =
     importlib-metadata;python_version<"3.8"
     greenlet != 0.4.17;(platform_machine=='aarch64' or (platform_machine=='ppc64le' or (platform_machine=='x86_64' or (platform_machine=='amd64' or (platform_machine=='AMD64' or (platform_machine=='win32' or platform_machine=='WIN32'))))))
-    typing-extensions >= 4.2.0
+    typing-extensions >= 4.6.0
 
 [options.extras_require]
 asyncio =
index e61900418e2e15b93ae800d25fe4913ba18ebc60..e64834b39d75d6eb96379e65c428d952e2beb317 100644 (file)
@@ -31,6 +31,8 @@ import uuid
 
 from typing_extensions import get_args as get_args
 from typing_extensions import Literal as Literal
+from typing_extensions import TypeAlias as TypeAlias
+from typing_extensions import TypedDict
 
 from sqlalchemy import BIGINT
 from sqlalchemy import BigInteger
@@ -93,6 +95,31 @@ from sqlalchemy.util import compat
 from sqlalchemy.util.typing import Annotated
 
 
+class _SomeDict1(TypedDict):
+    type: Literal["1"]
+
+
+class _SomeDict2(TypedDict):
+    type: Literal["2"]
+
+
+_UnionTypeAlias: TypeAlias = Union[_SomeDict1, _SomeDict2]
+
+_StrTypeAlias: TypeAlias = str
+
+_StrPep695: TypeAlias = Union[_SomeDict1, _SomeDict2]
+_UnionPep695: TypeAlias = str
+
+if compat.py312:
+    exec(
+        """
+type _UnionPep695 = _SomeDict1 | _SomeDict2
+type _StrPep695 = str
+""",
+        globals(),
+    )
+
+
 def expect_annotation_syntax_error(name):
     return expect_raises_message(
         sa_exc.ArgumentError,
@@ -731,6 +758,41 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         is_true(MyClass.__table__.c.data_two.nullable)
         eq_(MyClass.__table__.c.data_three.type.length, 50)
 
+    def test_plain_typealias_as_typemap_keys(
+        self, decl_base: Type[DeclarativeBase]
+    ):
+        decl_base.registry.update_type_annotation_map(
+            {_UnionTypeAlias: JSON, _StrTypeAlias: String(30)}
+        )
+
+        class Test(decl_base):
+            __tablename__ = "test"
+            id: Mapped[int] = mapped_column(primary_key=True)
+            data: Mapped[_StrTypeAlias]
+            structure: Mapped[_UnionTypeAlias]
+
+        eq_(Test.__table__.c.data.type.length, 30)
+        is_(Test.__table__.c.structure.type._type_affinity, JSON)
+
+    @testing.requires.python312
+    def test_pep695_typealias_as_typemap_keys(
+        self, decl_base: Type[DeclarativeBase]
+    ):
+        """test #10807"""
+
+        decl_base.registry.update_type_annotation_map(
+            {_UnionPep695: JSON, _StrPep695: String(30)}
+        )
+
+        class Test(decl_base):
+            __tablename__ = "test"
+            id: Mapped[int] = mapped_column(primary_key=True)
+            data: Mapped[_StrPep695]  # type: ignore
+            structure: Mapped[_UnionPep695]  # type: ignore
+
+        eq_(Test.__table__.c.data.type.length, 30)
+        is_(Test.__table__.c.structure.type._type_affinity, JSON)
+
     @testing.requires.python310
     def test_we_got_all_attrs_test_annotated(self):
         argnames = _py_inspect.getfullargspec(mapped_column)
index 8da83ccb9d6cf25a4558a1d68315255ed7451a6a..44327324cab8f50e9c6911552b78c7a8971d72a7 100644 (file)
@@ -22,6 +22,8 @@ import uuid
 
 from typing_extensions import get_args as get_args
 from typing_extensions import Literal as Literal
+from typing_extensions import TypeAlias as TypeAlias
+from typing_extensions import TypedDict
 
 from sqlalchemy import BIGINT
 from sqlalchemy import BigInteger
@@ -84,6 +86,31 @@ from sqlalchemy.util import compat
 from sqlalchemy.util.typing import Annotated
 
 
+class _SomeDict1(TypedDict):
+    type: Literal["1"]
+
+
+class _SomeDict2(TypedDict):
+    type: Literal["2"]
+
+
+_UnionTypeAlias: TypeAlias = Union[_SomeDict1, _SomeDict2]
+
+_StrTypeAlias: TypeAlias = str
+
+_StrPep695: TypeAlias = Union[_SomeDict1, _SomeDict2]
+_UnionPep695: TypeAlias = str
+
+if compat.py312:
+    exec(
+        """
+type _UnionPep695 = _SomeDict1 | _SomeDict2
+type _StrPep695 = str
+""",
+        globals(),
+    )
+
+
 def expect_annotation_syntax_error(name):
     return expect_raises_message(
         sa_exc.ArgumentError,
@@ -722,6 +749,41 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         is_true(MyClass.__table__.c.data_two.nullable)
         eq_(MyClass.__table__.c.data_three.type.length, 50)
 
+    def test_plain_typealias_as_typemap_keys(
+        self, decl_base: Type[DeclarativeBase]
+    ):
+        decl_base.registry.update_type_annotation_map(
+            {_UnionTypeAlias: JSON, _StrTypeAlias: String(30)}
+        )
+
+        class Test(decl_base):
+            __tablename__ = "test"
+            id: Mapped[int] = mapped_column(primary_key=True)
+            data: Mapped[_StrTypeAlias]
+            structure: Mapped[_UnionTypeAlias]
+
+        eq_(Test.__table__.c.data.type.length, 30)
+        is_(Test.__table__.c.structure.type._type_affinity, JSON)
+
+    @testing.requires.python312
+    def test_pep695_typealias_as_typemap_keys(
+        self, decl_base: Type[DeclarativeBase]
+    ):
+        """test #10807"""
+
+        decl_base.registry.update_type_annotation_map(
+            {_UnionPep695: JSON, _StrPep695: String(30)}
+        )
+
+        class Test(decl_base):
+            __tablename__ = "test"
+            id: Mapped[int] = mapped_column(primary_key=True)
+            data: Mapped[_StrPep695]  # type: ignore
+            structure: Mapped[_UnionPep695]  # type: ignore
+
+        eq_(Test.__table__.c.data.type.length, 30)
+        is_(Test.__table__.c.structure.type._type_affinity, JSON)
+
     @testing.requires.python310
     def test_we_got_all_attrs_test_annotated(self):
         argnames = _py_inspect.getfullargspec(mapped_column)