From: Mike Bayer Date: Tue, 18 Oct 2022 17:25:06 +0000 (-0400) Subject: call super().__init_subclass__(); support GenericAlias X-Git-Tag: rel_2_0_0b2~8^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=de7007e7cc6231b067df71ca79efee75f3317eae;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git call super().__init_subclass__(); support GenericAlias Improved the :class:`.DeclarativeBase` class so that when combined with other mixins like :class:`.MappedAsDataclass`, the order of the classes may be in either order. Added support for mapped classes that are also ``Generic`` subclasses, to be specified as a ``GenericAlias`` object (e.g. ``MyClass[str]``) within statements and calls to :func:`_sa.inspect`. Fixes: #8665 Change-Id: I03063a28b0438a44b9e028fd9d45e8ce08bd18c4 --- diff --git a/doc/build/changelog/unreleased_20/8665.rst b/doc/build/changelog/unreleased_20/8665.rst new file mode 100644 index 0000000000..4f5535a586 --- /dev/null +++ b/doc/build/changelog/unreleased_20/8665.rst @@ -0,0 +1,18 @@ +.. change:: + :tags: bug, declarative, orm + :tickets: 8665 + + Improved the :class:`.DeclarativeBase` class so that when combined with + other mixins like :class:`.MappedAsDataclass`, the order of the classes may + be in either order. + + +.. change:: + :tags: usecase, declarative, orm + :tickets: 8665 + + Added support for mapped classes that are also ``Generic`` subclasses, + to be specified as a ``GenericAlias`` object (e.g. ``MyClass[str]``) + within statements and calls to :func:`_sa.inspect`. + + diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index 5724d53a25..a43b59a45c 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -581,7 +581,6 @@ class MappedAsDataclass(metaclass=DCTransformDeclarative): match_args: Union[_NoArg, bool] = _NoArg.NO_ARG, kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, ) -> None: - apply_dc_transforms: _DataclassArguments = { "init": init, "repr": repr, @@ -696,6 +695,7 @@ class DeclarativeBase( _setup_declarative_base(cls) else: _as_declarative(cls._sa_registry, cls, cls.__dict__) + super().__init_subclass__() def _check_not_declarative(cls: Type[Any], base: Type[Any]) -> None: diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 0f16df9c88..b9d1b50e7a 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -79,6 +79,7 @@ from ..util.langhelpers import MemoizedSlots from ..util.typing import de_stringify_annotation from ..util.typing import is_origin_of_cls from ..util.typing import Literal +from ..util.typing import typing_get_origin if typing.TYPE_CHECKING: from ._typing import _EntityType @@ -1361,6 +1362,18 @@ def _inspect_mc( return mapper +GenericAlias = type(List[_T]) + + +@inspection._inspects(GenericAlias) +def _inspect_generic_alias( + class_: Type[_O], +) -> Optional[Mapper[_O]]: + + origin = cast("Type[_O]", typing_get_origin(class_)) + return _inspect_mc(origin) + + @inspection._self_inspects class Bundle( ORMColumnsClauseRole[_T], diff --git a/test/orm/declarative/test_dc_transforms.py b/test/orm/declarative/test_dc_transforms.py index d6f1532ee5..ae1f9b35e1 100644 --- a/test/orm/declarative/test_dc_transforms.py +++ b/test/orm/declarative/test_dc_transforms.py @@ -3,10 +3,13 @@ import inspect as pyinspect from itertools import product from typing import Any from typing import ClassVar +from typing import Dict +from typing import Generic from typing import List from typing import Optional from typing import Set from typing import Type +from typing import TypeVar from unittest import mock from typing_extensions import Annotated @@ -17,6 +20,7 @@ from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import inspect from sqlalchemy import Integer +from sqlalchemy import JSON from sqlalchemy import select from sqlalchemy import String from sqlalchemy import testing @@ -47,6 +51,29 @@ from sqlalchemy.util import compat class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase): + @testing.fixture(params=["(MAD, DB)", "(DB, MAD)"]) + def dc_decl_base(self, request, metadata): + _md = metadata + + if request.param == "(MAD, DB)": + + class Base(MappedAsDataclass, DeclarativeBase): + metadata = _md + type_annotation_map = { + str: String().with_variant(String(50), "mysql", "mariadb") + } + + else: + # test #8665 by reversing the order of the classes + class Base(DeclarativeBase, MappedAsDataclass): + metadata = _md + type_annotation_map = { + str: String().with_variant(String(50), "mysql", "mariadb") + } + + yield Base + Base.registry.dispose() + def test_basic_constructor_repr_base_cls( self, dc_decl_base: Type[MappedAsDataclass] ): @@ -111,6 +138,33 @@ class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase): a3 = A("data") eq_(repr(a3), "some_module.A(id=None, data='data', x=None, bs=[])") + def test_generic_class(self): + """further test for #8665""" + + T_Value = TypeVar("T_Value") + + class SomeBaseClass(DeclarativeBase): + pass + + class GenericSetting( + MappedAsDataclass, SomeBaseClass, Generic[T_Value] + ): + __tablename__ = "xx" + + id: Mapped[int] = mapped_column( + Integer, primary_key=True, init=False + ) + + key: Mapped[str] = mapped_column(String, init=True) + + value: Mapped[T_Value] = mapped_column( + JSON, init=True, default_factory=lambda: {} + ) + + new_instance: GenericSetting[ # noqa: F841 + Dict[str, Any] + ] = GenericSetting(key="x", value={"foo": "bar"}) + def test_no_anno_doesnt_go_into_dc( self, dc_decl_base: Type[MappedAsDataclass] ): diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py index f06a608a8d..b64694fc59 100644 --- a/test/orm/declarative/test_typed_mapping.py +++ b/test/orm/declarative/test_typed_mapping.py @@ -1,6 +1,7 @@ import dataclasses import datetime from decimal import Decimal +from typing import Any from typing import ClassVar from typing import Dict from typing import Generic @@ -22,6 +23,7 @@ from sqlalchemy import func from sqlalchemy import Identity from sqlalchemy import inspect from sqlalchemy import Integer +from sqlalchemy import JSON from sqlalchemy import Numeric from sqlalchemy import select from sqlalchemy import String @@ -30,6 +32,7 @@ from sqlalchemy import testing from sqlalchemy import types from sqlalchemy import VARCHAR from sqlalchemy.exc import ArgumentError +from sqlalchemy.ext.mutable import MutableDict from sqlalchemy.orm import as_declarative from sqlalchemy.orm import composite from sqlalchemy.orm import declarative_base @@ -39,12 +42,14 @@ from sqlalchemy.orm import deferred from sqlalchemy.orm import DynamicMapped from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import MappedAsDataclass from sqlalchemy.orm import relationship from sqlalchemy.orm import undefer from sqlalchemy.orm import WriteOnlyMapped from sqlalchemy.orm.collections import attribute_keyed_dict from sqlalchemy.orm.collections import KeyFuncDict from sqlalchemy.schema import CreateTable +from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_raises from sqlalchemy.testing import expect_raises_message @@ -1898,3 +1903,52 @@ class WriteOnlyRelationshipTest(fixtures.TestBase): bs: WriteOnlyMapped[B] = relationship() self._assertions(A, B, "write_only") + + +class GenericMappingQueryTest(AssertsCompiledSQL, fixtures.TestBase): + """test the Generic support added as part of #8665""" + + __dialect__ = "default" + + @testing.fixture + def mapping(self): + T_Value = TypeVar("T_Value") + + class SomeBaseClass(DeclarativeBase): + pass + + class GenericSetting( + MappedAsDataclass, SomeBaseClass, Generic[T_Value] + ): + """Represents key value pairs for settings or values""" + + __tablename__ = "xx" + + id: Mapped[int] = mapped_column( + Integer, primary_key=True, init=False + ) + + key: Mapped[str] = mapped_column(String, init=True) + + value: Mapped[T_Value] = mapped_column( + MutableDict.as_mutable(JSON), + init=True, + default_factory=lambda: {}, + ) + + return GenericSetting + + def test_inspect(self, mapping): + GenericSetting = mapping + + typ = GenericSetting[Dict[str, Any]] + is_(inspect(typ), GenericSetting.__mapper__) + + def test_select(self, mapping): + GenericSetting = mapping + + typ = GenericSetting[Dict[str, Any]] + self.assert_compile( + select(typ).where(typ.key == "x"), + "SELECT xx.id, xx.key, xx.value FROM xx WHERE xx.key = :key_1", + )