]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
call super().__init_subclass__(); support GenericAlias
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 18 Oct 2022 17:25:06 +0000 (13:25 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 19 Oct 2022 12:00:35 +0000 (08:00 -0400)
Improved the :class:`.DeclarativeBase` class so that when combined with
other mixins like :class:`.MappedAsDataclass`, the order of the classes may
be in either order.

Added support for mapped classes that are also ``Generic`` subclasses,
to be specified as a ``GenericAlias`` object (e.g. ``MyClass[str]``)
within statements and calls to :func:`_sa.inspect`.

Fixes: #8665
Change-Id: I03063a28b0438a44b9e028fd9d45e8ce08bd18c4

doc/build/changelog/unreleased_20/8665.rst [new file with mode: 0644]
lib/sqlalchemy/orm/decl_api.py
lib/sqlalchemy/orm/util.py
test/orm/declarative/test_dc_transforms.py
test/orm/declarative/test_typed_mapping.py

diff --git a/doc/build/changelog/unreleased_20/8665.rst b/doc/build/changelog/unreleased_20/8665.rst
new file mode 100644 (file)
index 0000000..4f5535a
--- /dev/null
@@ -0,0 +1,18 @@
+.. change::
+    :tags: bug, declarative, orm
+    :tickets: 8665
+
+    Improved the :class:`.DeclarativeBase` class so that when combined with
+    other mixins like :class:`.MappedAsDataclass`, the order of the classes may
+    be in either order.
+
+
+.. change::
+    :tags: usecase, declarative, orm
+    :tickets: 8665
+
+    Added support for mapped classes that are also ``Generic`` subclasses,
+    to be specified as a ``GenericAlias`` object (e.g. ``MyClass[str]``)
+    within statements and calls to :func:`_sa.inspect`.
+
+
index 5724d53a25ee68f6fe1f6f1b1bbd80d8ffcb4a2b..a43b59a45c7e3813202a5ae1d3b1ca3f3714b6c4 100644 (file)
@@ -581,7 +581,6 @@ class MappedAsDataclass(metaclass=DCTransformDeclarative):
         match_args: Union[_NoArg, bool] = _NoArg.NO_ARG,
         kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG,
     ) -> None:
-
         apply_dc_transforms: _DataclassArguments = {
             "init": init,
             "repr": repr,
@@ -696,6 +695,7 @@ class DeclarativeBase(
             _setup_declarative_base(cls)
         else:
             _as_declarative(cls._sa_registry, cls, cls.__dict__)
+        super().__init_subclass__()
 
 
 def _check_not_declarative(cls: Type[Any], base: Type[Any]) -> None:
index 0f16df9c8858c485d562f5c3eea3b3aaef173cc5..b9d1b50e7a87574e86d66dd72dfe492285f399fd 100644 (file)
@@ -79,6 +79,7 @@ from ..util.langhelpers import MemoizedSlots
 from ..util.typing import de_stringify_annotation
 from ..util.typing import is_origin_of_cls
 from ..util.typing import Literal
+from ..util.typing import typing_get_origin
 
 if typing.TYPE_CHECKING:
     from ._typing import _EntityType
@@ -1361,6 +1362,18 @@ def _inspect_mc(
         return mapper
 
 
+GenericAlias = type(List[_T])
+
+
+@inspection._inspects(GenericAlias)
+def _inspect_generic_alias(
+    class_: Type[_O],
+) -> Optional[Mapper[_O]]:
+
+    origin = cast("Type[_O]", typing_get_origin(class_))
+    return _inspect_mc(origin)
+
+
 @inspection._self_inspects
 class Bundle(
     ORMColumnsClauseRole[_T],
index d6f1532ee5c39fa8c8b1af20d6c0cd4fcc21a1f4..ae1f9b35e11426a69fff199515750f433b2330d2 100644 (file)
@@ -3,10 +3,13 @@ import inspect as pyinspect
 from itertools import product
 from typing import Any
 from typing import ClassVar
+from typing import Dict
+from typing import Generic
 from typing import List
 from typing import Optional
 from typing import Set
 from typing import Type
+from typing import TypeVar
 from unittest import mock
 
 from typing_extensions import Annotated
@@ -17,6 +20,7 @@ from sqlalchemy import ForeignKey
 from sqlalchemy import func
 from sqlalchemy import inspect
 from sqlalchemy import Integer
+from sqlalchemy import JSON
 from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import testing
@@ -47,6 +51,29 @@ from sqlalchemy.util import compat
 
 
 class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase):
+    @testing.fixture(params=["(MAD, DB)", "(DB, MAD)"])
+    def dc_decl_base(self, request, metadata):
+        _md = metadata
+
+        if request.param == "(MAD, DB)":
+
+            class Base(MappedAsDataclass, DeclarativeBase):
+                metadata = _md
+                type_annotation_map = {
+                    str: String().with_variant(String(50), "mysql", "mariadb")
+                }
+
+        else:
+            # test #8665 by reversing the order of the classes
+            class Base(DeclarativeBase, MappedAsDataclass):
+                metadata = _md
+                type_annotation_map = {
+                    str: String().with_variant(String(50), "mysql", "mariadb")
+                }
+
+        yield Base
+        Base.registry.dispose()
+
     def test_basic_constructor_repr_base_cls(
         self, dc_decl_base: Type[MappedAsDataclass]
     ):
@@ -111,6 +138,33 @@ class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase):
         a3 = A("data")
         eq_(repr(a3), "some_module.A(id=None, data='data', x=None, bs=[])")
 
+    def test_generic_class(self):
+        """further test for #8665"""
+
+        T_Value = TypeVar("T_Value")
+
+        class SomeBaseClass(DeclarativeBase):
+            pass
+
+        class GenericSetting(
+            MappedAsDataclass, SomeBaseClass, Generic[T_Value]
+        ):
+            __tablename__ = "xx"
+
+            id: Mapped[int] = mapped_column(
+                Integer, primary_key=True, init=False
+            )
+
+            key: Mapped[str] = mapped_column(String, init=True)
+
+            value: Mapped[T_Value] = mapped_column(
+                JSON, init=True, default_factory=lambda: {}
+            )
+
+        new_instance: GenericSetting[  # noqa: F841
+            Dict[str, Any]
+        ] = GenericSetting(key="x", value={"foo": "bar"})
+
     def test_no_anno_doesnt_go_into_dc(
         self, dc_decl_base: Type[MappedAsDataclass]
     ):
index f06a608a8dbecc964a4705a0eae9e0e033b28193..b64694fc597e884d2419dc29722459d668bd8760 100644 (file)
@@ -1,6 +1,7 @@
 import dataclasses
 import datetime
 from decimal import Decimal
+from typing import Any
 from typing import ClassVar
 from typing import Dict
 from typing import Generic
@@ -22,6 +23,7 @@ from sqlalchemy import func
 from sqlalchemy import Identity
 from sqlalchemy import inspect
 from sqlalchemy import Integer
+from sqlalchemy import JSON
 from sqlalchemy import Numeric
 from sqlalchemy import select
 from sqlalchemy import String
@@ -30,6 +32,7 @@ from sqlalchemy import testing
 from sqlalchemy import types
 from sqlalchemy import VARCHAR
 from sqlalchemy.exc import ArgumentError
+from sqlalchemy.ext.mutable import MutableDict
 from sqlalchemy.orm import as_declarative
 from sqlalchemy.orm import composite
 from sqlalchemy.orm import declarative_base
@@ -39,12 +42,14 @@ from sqlalchemy.orm import deferred
 from sqlalchemy.orm import DynamicMapped
 from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import mapped_column
+from sqlalchemy.orm import MappedAsDataclass
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import undefer
 from sqlalchemy.orm import WriteOnlyMapped
 from sqlalchemy.orm.collections import attribute_keyed_dict
 from sqlalchemy.orm.collections import KeyFuncDict
 from sqlalchemy.schema import CreateTable
+from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import expect_raises
 from sqlalchemy.testing import expect_raises_message
@@ -1898,3 +1903,52 @@ class WriteOnlyRelationshipTest(fixtures.TestBase):
             bs: WriteOnlyMapped[B] = relationship()
 
         self._assertions(A, B, "write_only")
+
+
+class GenericMappingQueryTest(AssertsCompiledSQL, fixtures.TestBase):
+    """test the Generic support added as part of #8665"""
+
+    __dialect__ = "default"
+
+    @testing.fixture
+    def mapping(self):
+        T_Value = TypeVar("T_Value")
+
+        class SomeBaseClass(DeclarativeBase):
+            pass
+
+        class GenericSetting(
+            MappedAsDataclass, SomeBaseClass, Generic[T_Value]
+        ):
+            """Represents key value pairs for settings or values"""
+
+            __tablename__ = "xx"
+
+            id: Mapped[int] = mapped_column(
+                Integer, primary_key=True, init=False
+            )
+
+            key: Mapped[str] = mapped_column(String, init=True)
+
+            value: Mapped[T_Value] = mapped_column(
+                MutableDict.as_mutable(JSON),
+                init=True,
+                default_factory=lambda: {},
+            )
+
+        return GenericSetting
+
+    def test_inspect(self, mapping):
+        GenericSetting = mapping
+
+        typ = GenericSetting[Dict[str, Any]]
+        is_(inspect(typ), GenericSetting.__mapper__)
+
+    def test_select(self, mapping):
+        GenericSetting = mapping
+
+        typ = GenericSetting[Dict[str, Any]]
+        self.assert_compile(
+            select(typ).where(typ.key == "x"),
+            "SELECT xx.id, xx.key, xx.value FROM xx WHERE xx.key = :key_1",
+        )