]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
support pep695 when resolving type map types
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 30 Dec 2023 15:36:40 +0000 (10:36 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 30 Dec 2023 17:05:17 +0000 (12:05 -0500)
Added preliminary support for Python 3.12 pep-695 type alias structures,
when resolving custom type maps for ORM Annotated Declarative mappings.

Fixes: #10807
Change-Id: Ia28123ce1d6d1fd6bae5e8a037be4754c890f281

doc/build/changelog/unreleased_20/10807.rst [new file with mode: 0644]
lib/sqlalchemy/orm/decl_api.py
lib/sqlalchemy/sql/type_api.py
lib/sqlalchemy/testing/requirements.py
lib/sqlalchemy/util/typing.py
setup.cfg
test/orm/declarative/test_tm_future_annotations_sync.py
test/orm/declarative/test_typed_mapping.py

diff --git a/doc/build/changelog/unreleased_20/10807.rst b/doc/build/changelog/unreleased_20/10807.rst
new file mode 100644 (file)
index 0000000..afceef6
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: usecase, orm
+    :tickets: 10807
+
+    Added preliminary support for Python 3.12 pep-695 type alias structures,
+    when resolving custom type maps for ORM Annotated Declarative mappings.
+
index 9520fbb971c8fe145babca2f5b97595f7e2194bd..e8e94f6a957a3a3b1a063542bf8f2c983b1c4de7 100644 (file)
@@ -77,6 +77,7 @@ from ..util.typing import flatten_newtype
 from ..util.typing import is_generic
 from ..util.typing import is_literal
 from ..util.typing import is_newtype
+from ..util.typing import is_pep695
 from ..util.typing import Literal
 from ..util.typing import Self
 
@@ -1264,6 +1265,10 @@ class registry:
         elif is_newtype(python_type):
             python_type_type = flatten_newtype(python_type)
             search = ((python_type, python_type_type),)
+        elif is_pep695(python_type):
+            python_type_type = python_type.__value__
+            flattened = None
+            search = ((python_type, python_type_type),)
         else:
             python_type_type = cast("Type[Any]", python_type)
             flattened = None
index 5b26e05cab099d39cbe46b03b827a302dafc3893..6a01fcec7017f4748e4068979fcd96adb6a212d3 100644 (file)
@@ -40,6 +40,7 @@ from .visitors import Visitable
 from .. import exc
 from .. import util
 from ..util.typing import Self
+from ..util.typing import TypeAliasType
 from ..util.typing import TypeGuard
 
 # these are back-assigned by sqltypes.
@@ -67,7 +68,9 @@ _O = TypeVar("_O", bound=object)
 _TE = TypeVar("_TE", bound="TypeEngine[Any]")
 _CT = TypeVar("_CT", bound=Any)
 
-_MatchedOnType = Union["GenericProtocol[Any]", NewType, Type[Any]]
+_MatchedOnType = Union[
+    "GenericProtocol[Any]", TypeAliasType, NewType, Type[Any]
+]
 
 
 class _NoValueInList(Enum):
index b288cbbaf49a71553b1eadf6849f12d93dc98ef9..467138c9b31163069bc0a3b95e96fcea85f2e32e 100644 (file)
@@ -1524,6 +1524,12 @@ class SuiteRequirements(Requirements):
             lambda: util.py311, "Python 3.11 or above required"
         )
 
+    @property
+    def python312(self):
+        return exclusions.only_if(
+            lambda: util.py312, "Python 3.12 or above required"
+        )
+
     @property
     def cpython(self):
         return exclusions.only_if(
index c4f41d9151825518b6fb38fea7bdc05621a8910f..a7724d0832129d8d6adf968910b43d0bc6645173 100644 (file)
@@ -50,7 +50,7 @@ if True:  # zimports removes the tailing comments
     from typing_extensions import TypeAlias as TypeAlias  # 3.10
     from typing_extensions import TypeGuard as TypeGuard  # 3.10
     from typing_extensions import Self as Self  # 3.11
-
+    from typing_extensions import TypeAliasType as TypeAliasType  # 3.12
 
 _T = TypeVar("_T", bound=Any)
 _KT = TypeVar("_KT")
@@ -74,7 +74,7 @@ typing_get_origin = get_origin
 
 
 _AnnotationScanType = Union[
-    Type[Any], str, ForwardRef, NewType, "GenericProtocol[Any]"
+    Type[Any], str, ForwardRef, NewType, TypeAliasType, "GenericProtocol[Any]"
 ]
 
 
@@ -316,6 +316,10 @@ def is_generic(type_: _AnnotationScanType) -> TypeGuard[GenericProtocol[Any]]:
     return hasattr(type_, "__args__") and hasattr(type_, "__origin__")
 
 
+def is_pep695(type_: _AnnotationScanType) -> TypeGuard[TypeAliasType]:
+    return isinstance(type_, TypeAliasType)
+
+
 def flatten_newtype(type_: NewType) -> Type[Any]:
     super_type = type_.__supertype__
     while is_newtype(super_type):
index 129a5aa82d9f28012a446ab9547409d5fe416fa5..f9248486262ed71f6e5d300996bd94c18d4eead2 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -36,7 +36,7 @@ package_dir =
     =lib
 
 install_requires =
-    typing-extensions >= 4.2.0
+    typing-extensions >= 4.6.0
 
 [options.extras_require]
 asyncio =
index e61900418e2e15b93ae800d25fe4913ba18ebc60..b3b83b3de2cc95a6a74baab81abcea280e4bf761 100644 (file)
@@ -25,12 +25,14 @@ from typing import Optional
 from typing import Set
 from typing import Type
 from typing import TYPE_CHECKING
+from typing import TypedDict
 from typing import TypeVar
 from typing import Union
 import uuid
 
 from typing_extensions import get_args as get_args
 from typing_extensions import Literal as Literal
+from typing_extensions import TypeAlias as TypeAlias
 
 from sqlalchemy import BIGINT
 from sqlalchemy import BigInteger
@@ -93,6 +95,31 @@ from sqlalchemy.util import compat
 from sqlalchemy.util.typing import Annotated
 
 
+class _SomeDict1(TypedDict):
+    type: Literal["1"]
+
+
+class _SomeDict2(TypedDict):
+    type: Literal["2"]
+
+
+_UnionTypeAlias: TypeAlias = Union[_SomeDict1, _SomeDict2]
+
+_StrTypeAlias: TypeAlias = str
+
+_StrPep695: TypeAlias = Union[_SomeDict1, _SomeDict2]
+_UnionPep695: TypeAlias = str
+
+if compat.py312:
+    exec(
+        """
+type _UnionPep695 = _SomeDict1 | _SomeDict2
+type _StrPep695 = str
+""",
+        globals(),
+    )
+
+
 def expect_annotation_syntax_error(name):
     return expect_raises_message(
         sa_exc.ArgumentError,
@@ -731,6 +758,41 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         is_true(MyClass.__table__.c.data_two.nullable)
         eq_(MyClass.__table__.c.data_three.type.length, 50)
 
+    def test_plain_typealias_as_typemap_keys(
+        self, decl_base: Type[DeclarativeBase]
+    ):
+        decl_base.registry.update_type_annotation_map(
+            {_UnionTypeAlias: JSON, _StrTypeAlias: String(30)}
+        )
+
+        class Test(decl_base):
+            __tablename__ = "test"
+            id: Mapped[int] = mapped_column(primary_key=True)
+            data: Mapped[_StrTypeAlias]
+            structure: Mapped[_UnionTypeAlias]
+
+        eq_(Test.__table__.c.data.type.length, 30)
+        is_(Test.__table__.c.structure.type._type_affinity, JSON)
+
+    @testing.requires.python312
+    def test_pep695_typealias_as_typemap_keys(
+        self, decl_base: Type[DeclarativeBase]
+    ):
+        """test #10807"""
+
+        decl_base.registry.update_type_annotation_map(
+            {_UnionPep695: JSON, _StrPep695: String(30)}
+        )
+
+        class Test(decl_base):
+            __tablename__ = "test"
+            id: Mapped[int] = mapped_column(primary_key=True)
+            data: Mapped[_StrPep695]  # type: ignore
+            structure: Mapped[_UnionPep695]  # type: ignore
+
+        eq_(Test.__table__.c.data.type.length, 30)
+        is_(Test.__table__.c.structure.type._type_affinity, JSON)
+
     @testing.requires.python310
     def test_we_got_all_attrs_test_annotated(self):
         argnames = _py_inspect.getfullargspec(mapped_column)
index 8da83ccb9d6cf25a4558a1d68315255ed7451a6a..8dcf2013939dfe013046b51797b8ac2062e6e0af 100644 (file)
@@ -16,12 +16,14 @@ from typing import Optional
 from typing import Set
 from typing import Type
 from typing import TYPE_CHECKING
+from typing import TypedDict
 from typing import TypeVar
 from typing import Union
 import uuid
 
 from typing_extensions import get_args as get_args
 from typing_extensions import Literal as Literal
+from typing_extensions import TypeAlias as TypeAlias
 
 from sqlalchemy import BIGINT
 from sqlalchemy import BigInteger
@@ -84,6 +86,31 @@ from sqlalchemy.util import compat
 from sqlalchemy.util.typing import Annotated
 
 
+class _SomeDict1(TypedDict):
+    type: Literal["1"]
+
+
+class _SomeDict2(TypedDict):
+    type: Literal["2"]
+
+
+_UnionTypeAlias: TypeAlias = Union[_SomeDict1, _SomeDict2]
+
+_StrTypeAlias: TypeAlias = str
+
+_StrPep695: TypeAlias = Union[_SomeDict1, _SomeDict2]
+_UnionPep695: TypeAlias = str
+
+if compat.py312:
+    exec(
+        """
+type _UnionPep695 = _SomeDict1 | _SomeDict2
+type _StrPep695 = str
+""",
+        globals(),
+    )
+
+
 def expect_annotation_syntax_error(name):
     return expect_raises_message(
         sa_exc.ArgumentError,
@@ -722,6 +749,41 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         is_true(MyClass.__table__.c.data_two.nullable)
         eq_(MyClass.__table__.c.data_three.type.length, 50)
 
+    def test_plain_typealias_as_typemap_keys(
+        self, decl_base: Type[DeclarativeBase]
+    ):
+        decl_base.registry.update_type_annotation_map(
+            {_UnionTypeAlias: JSON, _StrTypeAlias: String(30)}
+        )
+
+        class Test(decl_base):
+            __tablename__ = "test"
+            id: Mapped[int] = mapped_column(primary_key=True)
+            data: Mapped[_StrTypeAlias]
+            structure: Mapped[_UnionTypeAlias]
+
+        eq_(Test.__table__.c.data.type.length, 30)
+        is_(Test.__table__.c.structure.type._type_affinity, JSON)
+
+    @testing.requires.python312
+    def test_pep695_typealias_as_typemap_keys(
+        self, decl_base: Type[DeclarativeBase]
+    ):
+        """test #10807"""
+
+        decl_base.registry.update_type_annotation_map(
+            {_UnionPep695: JSON, _StrPep695: String(30)}
+        )
+
+        class Test(decl_base):
+            __tablename__ = "test"
+            id: Mapped[int] = mapped_column(primary_key=True)
+            data: Mapped[_StrPep695]  # type: ignore
+            structure: Mapped[_UnionPep695]  # type: ignore
+
+        eq_(Test.__table__.c.data.type.length, 30)
+        is_(Test.__table__.c.structure.type._type_affinity, JSON)
+
     @testing.requires.python310
     def test_we_got_all_attrs_test_annotated(self):
         argnames = _py_inspect.getfullargspec(mapped_column)