From be8ffcfa4d91d28acc4ffc08e3203e0b01e29cc7 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 9 May 2025 11:50:26 -0400 Subject: [PATCH] add future mode tests for MappedAsDataclass; more py314b1 regressions for py314b2 all issues should be resolved py314: yes Change-Id: I498a1f623aeb5eb664289236e01e35d8a3dec99f --- lib/sqlalchemy/testing/exclusions.py | 4 +- pyproject.toml | 1 + test/orm/declarative/test_dc_transforms.py | 13 + .../test_dc_transforms_future_anno_sync.py | 2704 +++++++++++++++++ test/typing/test_overloads.py | 10 +- tools/sync_test_files.py | 17 +- tox.ini | 2 +- 7 files changed, 2740 insertions(+), 11 deletions(-) create mode 100644 test/orm/declarative/test_dc_transforms_future_anno_sync.py diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py index 8ff9b64438..d28e9d85e0 100644 --- a/lib/sqlalchemy/testing/exclusions.py +++ b/lib/sqlalchemy/testing/exclusions.py @@ -392,8 +392,8 @@ def open(): # noqa return skip_if(BooleanPredicate(False, "mark as execute")) -def closed(): - return skip_if(BooleanPredicate(True, "marked as skip")) +def closed(reason="marked as skip"): + return skip_if(BooleanPredicate(True, reason)) def fails(reason=None): diff --git a/pyproject.toml b/pyproject.toml index b076c74f8e..9010569134 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ classifiers = [ "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", "Topic :: Database :: Front-Ends", diff --git a/test/orm/declarative/test_dc_transforms.py b/test/orm/declarative/test_dc_transforms.py index 004a119acd..34b9d1982b 100644 --- a/test/orm/declarative/test_dc_transforms.py +++ b/test/orm/declarative/test_dc_transforms.py @@ -164,6 +164,8 @@ class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase): a3 = A("data") eq_(repr(a3), "some_module.A(id=None, data='data', x=None, bs=[])") + # TODO: get this test to work with future anno mode as well + # anno only: @testing.exclusions.closed("doesn't work for future annotations mode yet") # noqa: E501 def test_generic_class(self): """further test for #8665""" @@ -311,6 +313,8 @@ class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase): a3 = A("data") eq_(repr(a3), "some_module.A(id=None, data='data', x=None, bs=[])") + # TODO: get this test to work with future anno mode as well + # anno only: @testing.exclusions.closed("doesn't work for future annotations mode yet") # noqa: E501 @testing.variation("dc_type", ["decorator", "superclass"]) def test_dataclass_fn(self, dc_type: Variation): annotations = {} @@ -387,6 +391,9 @@ class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase): take place on INSERT """ + + # anno only: global intpk, str30, s_str30, user_fk + intpk = Annotated[int, mapped_column(primary_key=True)] str30 = Annotated[ str, mapped_column(String(30), insert_default=func.foo()) @@ -1212,6 +1219,8 @@ class DataclassesForNonMappedClassesTest(fixtures.TestBase): c1 = Child() eq_regex(repr(c1), r".*\.Child\(a=10, b=7, c=9\)") + # TODO: get this test to work with future anno mode as well + # anno only: @testing.exclusions.closed("doesn't work for future annotations mode yet") # noqa: E501 def test_abstract_is_dc(self): collected_annotations = {} @@ -1233,6 +1242,8 @@ class DataclassesForNonMappedClassesTest(fixtures.TestBase): eq_(collected_annotations, {Mixin: {"b": int}, Child: {"c": int}}) eq_regex(repr(Child(6, 7)), r".*\.Child\(b=6, c=7\)") + # TODO: get this test to work with future anno mode as well + # anno only: @testing.exclusions.closed("doesn't work for future annotations mode yet") # noqa: E501 @testing.variation("check_annotations", [True, False]) def test_abstract_is_dc_w_mapped(self, check_annotations): if check_annotations: @@ -1296,6 +1307,8 @@ class DataclassesForNonMappedClassesTest(fixtures.TestBase): eq_regex(repr(Child(a=5, b=6, c=7)), r".*\.Child\(c=7\)") + # TODO: get this test to work with future anno mode as well + # anno only: @testing.exclusions.closed("doesn't work for future annotations mode yet") # noqa: E501 @testing.variation( "dataclass_scope", ["on_base", "on_mixin", "on_base_class", "on_sub_class"], diff --git a/test/orm/declarative/test_dc_transforms_future_anno_sync.py b/test/orm/declarative/test_dc_transforms_future_anno_sync.py new file mode 100644 index 0000000000..d1f319e240 --- /dev/null +++ b/test/orm/declarative/test_dc_transforms_future_anno_sync.py @@ -0,0 +1,2704 @@ +"""This file is automatically generated from the file +'test/orm/declarative/test_dc_transforms.py' +by the 'tools/sync_test_files.py' script. + +Do not edit manually, any change will be lost. +""" # noqa: E501 + +from __future__ import annotations + +import contextlib +import dataclasses +from dataclasses import InitVar +import functools +import inspect as pyinspect +from itertools import product +from typing import Any +from typing import ClassVar +from typing import Dict +from typing import Generic +from typing import List +from typing import Optional +from typing import Set +from typing import Type +from typing import TypeVar +from unittest import mock + +from typing_extensions import Annotated + +from sqlalchemy import BigInteger +from sqlalchemy import Column +from sqlalchemy import exc +from sqlalchemy import ForeignKey +from sqlalchemy import func +from sqlalchemy import inspect +from sqlalchemy import Integer +from sqlalchemy import JSON +from sqlalchemy import select +from sqlalchemy import String +from sqlalchemy import Table +from sqlalchemy import testing +from sqlalchemy.ext.associationproxy import association_proxy +from sqlalchemy.orm import column_property +from sqlalchemy.orm import composite +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import declared_attr +from sqlalchemy.orm import deferred +from sqlalchemy.orm import interfaces +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import MappedAsDataclass +from sqlalchemy.orm import MappedColumn +from sqlalchemy.orm import query_expression +from sqlalchemy.orm import registry +from sqlalchemy.orm import registry as _RegistryType +from sqlalchemy.orm import relationship +from sqlalchemy.orm import Session +from sqlalchemy.orm import synonym +from sqlalchemy.orm.attributes import LoaderCallableStatus +from sqlalchemy.sql.base import _NoArg +from sqlalchemy.testing import AssertsCompiledSQL +from sqlalchemy.testing import eq_ +from sqlalchemy.testing import eq_regex +from sqlalchemy.testing import expect_deprecated +from sqlalchemy.testing import expect_raises +from sqlalchemy.testing import expect_raises_message +from sqlalchemy.testing import fixtures +from sqlalchemy.testing import is_ +from sqlalchemy.testing import is_false +from sqlalchemy.testing import is_true +from sqlalchemy.testing import ne_ +from sqlalchemy.testing import Variation +from sqlalchemy.util import compat + + +def _dataclass_mixin_warning(clsname, attrnames): + return testing.expect_deprecated( + rf"When transforming .* to a dataclass, attribute\(s\) " + rf"{attrnames} originates from superclass .*{clsname}" + ) + + +class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase): + @testing.fixture(params=["(MAD, DB)", "(DB, MAD)"]) + def dc_decl_base(self, request, metadata): + _md = metadata + + if request.param == "(MAD, DB)": + + class Base(MappedAsDataclass, DeclarativeBase): + _mad_before = True + metadata = _md + type_annotation_map = { + str: String().with_variant( + String(50), "mysql", "mariadb", "oracle" + ) + } + + else: + # test #8665 by reversing the order of the classes + class Base(DeclarativeBase, MappedAsDataclass): + _mad_before = False + metadata = _md + type_annotation_map = { + str: String().with_variant( + String(50), "mysql", "mariadb", "oracle" + ) + } + + yield Base + Base.registry.dispose() + + def test_basic_constructor_repr_base_cls( + self, dc_decl_base: Type[MappedAsDataclass] + ): + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] + + x: Mapped[Optional[int]] = mapped_column(default=None) + + bs: Mapped[List["B"]] = relationship( # noqa: F821 + default_factory=list + ) + + class B(dc_decl_base): + __tablename__ = "b" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] + a_id: Mapped[Optional[int]] = mapped_column( + ForeignKey("a.id"), init=False + ) + x: Mapped[Optional[int]] = mapped_column(default=None) + + A.__qualname__ = "some_module.A" + B.__qualname__ = "some_module.B" + + eq_( + pyinspect.getfullargspec(A.__init__), + pyinspect.FullArgSpec( + args=["self", "data", "x", "bs"], + varargs=None, + varkw=None, + defaults=(LoaderCallableStatus.DONT_SET, mock.ANY), + kwonlyargs=[], + kwonlydefaults=None, + annotations={}, + ), + ) + eq_( + pyinspect.getfullargspec(B.__init__), + pyinspect.FullArgSpec( + args=["self", "data", "x"], + varargs=None, + varkw=None, + defaults=(LoaderCallableStatus.DONT_SET,), + kwonlyargs=[], + kwonlydefaults=None, + annotations={}, + ), + ) + + a2 = A("10", x=5, bs=[B("data1"), B("data2", x=12)]) + eq_( + repr(a2), + "some_module.A(id=None, data='10', x=5, " + "bs=[some_module.B(id=None, data='data1', a_id=None, x=None), " + "some_module.B(id=None, data='data2', a_id=None, x=12)])", + ) + + a3 = A("data") + eq_(repr(a3), "some_module.A(id=None, data='data', x=None, bs=[])") + + # TODO: get this test to work with future anno mode as well + @testing.exclusions.closed( + "doesn't work for future annotations mode yet" + ) # noqa: E501 + def test_generic_class(self): + """further test for #8665""" + + T_Value = TypeVar("T_Value") + + class SomeBaseClass(DeclarativeBase): + pass + + class GenericSetting( + MappedAsDataclass, SomeBaseClass, Generic[T_Value] + ): + __tablename__ = "xx" + + id: Mapped[int] = mapped_column( + Integer, primary_key=True, init=False + ) + + key: Mapped[str] = mapped_column(String, init=True) + + value: Mapped[T_Value] = mapped_column( + JSON, init=True, default_factory=lambda: {} + ) + + new_instance: GenericSetting[Dict[str, Any]] = ( # noqa: F841 + GenericSetting(key="x", value={"foo": "bar"}) + ) + + def test_no_anno_doesnt_go_into_dc( + self, dc_decl_base: Type[MappedAsDataclass] + ): + class User(dc_decl_base): + __tablename__: ClassVar[Optional[str]] = "user" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + username: Mapped[str] + password: Mapped[str] + addresses: Mapped[List["Address"]] = relationship( # noqa: F821 + default_factory=list + ) + + class Address(dc_decl_base): + __tablename__: ClassVar[Optional[str]] = "address" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + + # should not be in the dataclass constructor + user_id = mapped_column(ForeignKey(User.id)) + + email_address: Mapped[str] + + a1 = Address("email@address") + eq_(a1.email_address, "email@address") + + def test_warn_on_non_dc_mixin(self): + class _BaseMixin: + create_user: Mapped[int] = mapped_column() + update_user: Mapped[Optional[int]] = mapped_column( + default=None, init=False + ) + + class Base(DeclarativeBase, MappedAsDataclass, _BaseMixin): + pass + + class SubMixin: + foo: Mapped[str] + bar: Mapped[str] = mapped_column() + + with ( + _dataclass_mixin_warning( + "_BaseMixin", "'create_user', 'update_user'" + ), + _dataclass_mixin_warning("SubMixin", "'foo', 'bar'"), + ): + + class User(SubMixin, Base): + __tablename__ = "sys_user" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + username: Mapped[str] = mapped_column(String) + password: Mapped[str] = mapped_column(String) + + def test_basic_constructor_repr_cls_decorator( + self, registry: _RegistryType + ): + @registry.mapped_as_dataclass() + class A: + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] + + x: Mapped[Optional[int]] = mapped_column(default=None) + + bs: Mapped[List["B"]] = relationship( # noqa: F821 + default_factory=list + ) + + @registry.mapped_as_dataclass() + class B: + __tablename__ = "b" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + a_id = mapped_column(ForeignKey("a.id"), init=False) + data: Mapped[str] + x: Mapped[Optional[int]] = mapped_column(default=None) + + A.__qualname__ = "some_module.A" + B.__qualname__ = "some_module.B" + + eq_( + pyinspect.getfullargspec(A.__init__), + pyinspect.FullArgSpec( + args=["self", "data", "x", "bs"], + varargs=None, + varkw=None, + defaults=(LoaderCallableStatus.DONT_SET, mock.ANY), + kwonlyargs=[], + kwonlydefaults=None, + annotations={}, + ), + ) + eq_( + pyinspect.getfullargspec(B.__init__), + pyinspect.FullArgSpec( + args=["self", "data", "x"], + varargs=None, + varkw=None, + defaults=(LoaderCallableStatus.DONT_SET,), + kwonlyargs=[], + kwonlydefaults=None, + annotations={}, + ), + ) + + a2 = A("10", x=5, bs=[B("data1"), B("data2", x=12)]) + + # note a_id isn't included because it wasn't annotated + eq_( + repr(a2), + "some_module.A(id=None, data='10', x=5, " + "bs=[some_module.B(id=None, data='data1', x=None), " + "some_module.B(id=None, data='data2', x=12)])", + ) + + a3 = A("data") + eq_(repr(a3), "some_module.A(id=None, data='data', x=None, bs=[])") + + # TODO: get this test to work with future anno mode as well + @testing.exclusions.closed( + "doesn't work for future annotations mode yet" + ) # noqa: E501 + @testing.variation("dc_type", ["decorator", "superclass"]) + def test_dataclass_fn(self, dc_type: Variation): + annotations = {} + + def dc_callable(kls, **kw) -> Type[Any]: + annotations[kls] = kls.__annotations__ + return dataclasses.dataclass(kls, **kw) # type: ignore + + if dc_type.decorator: + reg = registry() + + @reg.mapped_as_dataclass(dataclass_callable=dc_callable) + class MappedClass: + __tablename__ = "mapped_class" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + + eq_(annotations, {MappedClass: {"id": int, "name": str}}) + + elif dc_type.superclass: + + class Base(DeclarativeBase): + pass + + class Mixin(MappedAsDataclass, dataclass_callable=dc_callable): + id: Mapped[int] = mapped_column(primary_key=True) + + class MappedClass(Mixin, Base): + __tablename__ = "mapped_class" + name: Mapped[str] + + eq_( + annotations, + {Mixin: {"id": int}, MappedClass: {"id": int, "name": str}}, + ) + else: + dc_type.fail() + + def test_default_fn(self, dc_decl_base: Type[MappedAsDataclass]): + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] = mapped_column(default="d1") + data2: Mapped[str] = mapped_column(default_factory=lambda: "d2") + + a1 = A() + eq_(a1.data, "d1") + eq_(a1.data2, "d2") + + def test_default_factory_vs_collection_class( + self, dc_decl_base: Type[MappedAsDataclass] + ): + # this is currently the error raised by dataclasses. We can instead + # do this validation ourselves, but overall I don't know that we + # can hit every validation and rule that's in dataclasses + with expect_raises_message( + ValueError, "cannot specify both default and default_factory" + ): + + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] = mapped_column( + default="d1", default_factory=lambda: "d2" + ) + + def test_combine_args_from_pep593(self, decl_base: Type[DeclarativeBase]): + """test that we can set up column-level defaults separate from + dataclass defaults with a pep593 setup; however the dataclass + defaults need to override the insert_defaults so that they + take place on INSERT + + """ + + global intpk, str30, s_str30, user_fk + + intpk = Annotated[int, mapped_column(primary_key=True)] + str30 = Annotated[ + str, mapped_column(String(30), insert_default=func.foo()) + ] + s_str30 = Annotated[ + str, + mapped_column(String(30), server_default="some server default"), + ] + user_fk = Annotated[int, mapped_column(ForeignKey("user_account.id"))] + + class User(MappedAsDataclass, decl_base): + __tablename__ = "user_account" + + # we need this case for dataclasses that can't derive things + # from Annotated yet at the typing level + id: Mapped[intpk] = mapped_column(init=False) + name_plain: Mapped[str30] = mapped_column() + name_no_init: Mapped[str30] = mapped_column(init=False) + name_none: Mapped[Optional[str30]] = mapped_column(default=None) + name_insert_none: Mapped[Optional[str30]] = mapped_column( + insert_default=None, init=False + ) + name: Mapped[str30] = mapped_column(default="hi") + name_insert: Mapped[str30] = mapped_column( + insert_default="hi", init=False + ) + name2: Mapped[s_str30] = mapped_column(default="there") + name2_insert: Mapped[s_str30] = mapped_column( + insert_default="there", init=False + ) + addresses: Mapped[List["Address"]] = relationship( # noqa: F821 + back_populates="user", default_factory=list + ) + + class Address(MappedAsDataclass, decl_base): + __tablename__ = "address" + + id: Mapped[intpk] = mapped_column(init=False) + email_address: Mapped[str] + user_id: Mapped[user_fk] = mapped_column(init=False) + user: Mapped[Optional["User"]] = relationship( + back_populates="addresses", default=None + ) + + is_true(User.__table__.c.id.primary_key) + + # the default from the Annotated overrides mapped_cols that have + # nothing for default or insert default + is_true(User.__table__.c.name_plain.default.arg.compare(func.foo())) + is_true(User.__table__.c.name_no_init.default.arg.compare(func.foo())) + + # mapped cols that have None for default or insert default, that + # default overrides + is_true(User.__table__.c.name_none.default is None) + is_true(User.__table__.c.name_insert_none.default is None) + + # mapped cols that have a value for default or insert default, that + # default overrides + is_true(User.__table__.c.name.default.arg == "hi") + is_true(User.__table__.c.name2.default.arg == "there") + is_true(User.__table__.c.name_insert.default.arg == "hi") + is_true(User.__table__.c.name2_insert.default.arg == "there") + + eq_(User.__table__.c.name2.server_default.arg, "some server default") + + is_true(Address.__table__.c.user_id.references(User.__table__.c.id)) + u1 = User(name_plain="name") + eq_(u1.name_none, None) + eq_(u1.name_insert_none, None) + eq_(u1.name, "hi") + eq_(u1.name2, "there") + eq_(u1.name_insert, None) + eq_(u1.name2_insert, None) + + def test_inheritance(self, dc_decl_base: Type[MappedAsDataclass]): + class Person(dc_decl_base): + __tablename__ = "person" + person_id: Mapped[int] = mapped_column( + primary_key=True, init=False + ) + name: Mapped[str] + type: Mapped[str] = mapped_column(init=False) + + __mapper_args__ = {"polymorphic_on": type} + + class Engineer(Person): + __tablename__ = "engineer" + + person_id: Mapped[int] = mapped_column( + ForeignKey("person.person_id"), primary_key=True, init=False + ) + + status: Mapped[str] = mapped_column(String(30)) + engineer_name: Mapped[str] + primary_language: Mapped[str] + __mapper_args__ = {"polymorphic_identity": "engineer"} + + e1 = Engineer("nm", "st", "en", "pl") + eq_(e1.name, "nm") + eq_(e1.status, "st") + eq_(e1.engineer_name, "en") + eq_(e1.primary_language, "pl") + + def test_non_mapped_fields_wo_mapped_or_dc( + self, dc_decl_base: Type[MappedAsDataclass] + ): + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: str + ctrl_one: str = dataclasses.field() + some_field: int = dataclasses.field(default=5) + + a1 = A("data", "ctrl_one", 5) + eq_( + dataclasses.asdict(a1), + { + "ctrl_one": "ctrl_one", + "data": "data", + "id": None, + "some_field": 5, + }, + ) + + def test_non_mapped_fields_wo_mapped_or_dc_w_inherits( + self, dc_decl_base: Type[MappedAsDataclass] + ): + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: str + ctrl_one: str = dataclasses.field() + some_field: int = dataclasses.field(default=5) + + class B(A): + b_data: Mapped[str] = mapped_column(default="bd") + + # ensure we didnt break dataclasses contract of removing Field + # issue #8880 + eq_(A.__dict__["some_field"], 5) + assert "ctrl_one" not in A.__dict__ + + b1 = B(data="data", ctrl_one="ctrl_one", some_field=5, b_data="x") + eq_( + dataclasses.asdict(b1), + { + "ctrl_one": "ctrl_one", + "data": "data", + "id": None, + "some_field": 5, + "b_data": "x", + }, + ) + + def test_init_var(self, dc_decl_base: Type[MappedAsDataclass]): + class User(dc_decl_base): + __tablename__ = "user_account" + + id: Mapped[int] = mapped_column(init=False, primary_key=True) + name: Mapped[str] + + password: InitVar[str] + repeat_password: InitVar[str] + + password_hash: Mapped[str] = mapped_column( + init=False, nullable=False + ) + + def __post_init__(self, password: str, repeat_password: str): + if password != repeat_password: + raise ValueError("passwords do not match") + + self.password_hash = f"some hash... {password}" + + u1 = User(name="u1", password="p1", repeat_password="p1") + eq_(u1.password_hash, "some hash... p1") + self.assert_compile( + select(User), + "SELECT user_account.id, user_account.name, " + "user_account.password_hash FROM user_account", + ) + + def test_integrated_dc(self, dc_decl_base: Type[MappedAsDataclass]): + """We will be telling users "this is a dataclass that is also + mapped". Therefore, they will want *any* kind of attribute to do what + it would normally do in a dataclass, including normal types without any + field and explicit use of dataclasses.field(). additionally, we'd like + ``Mapped`` to mean "persist this attribute". So the absence of + ``Mapped`` should also mean something too. + + """ + + class A(dc_decl_base): + __tablename__ = "a" + + ctrl_one: str = dataclasses.field() + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] + some_field: int = dataclasses.field(default=5) + + some_none_field: Optional[str] = dataclasses.field(default=None) + + some_other_int_field: int = 10 + + # some field is part of the constructor + a1 = A("ctrlone", "datafield") + eq_( + dataclasses.asdict(a1), + { + "ctrl_one": "ctrlone", + "data": "datafield", + "id": None, + "some_field": 5, + "some_none_field": None, + "some_other_int_field": 10, + }, + ) + + a2 = A( + "ctrlone", + "datafield", + some_field=7, + some_other_int_field=12, + some_none_field="x", + ) + eq_( + dataclasses.asdict(a2), + { + "ctrl_one": "ctrlone", + "data": "datafield", + "id": None, + "some_field": 7, + "some_none_field": "x", + "some_other_int_field": 12, + }, + ) + + # only Mapped[] is mapped + self.assert_compile(select(A), "SELECT a.id, a.data FROM a") + eq_( + pyinspect.getfullargspec(A.__init__), + pyinspect.FullArgSpec( + args=[ + "self", + "ctrl_one", + "data", + "some_field", + "some_none_field", + "some_other_int_field", + ], + varargs=None, + varkw=None, + defaults=(5, None, 10), + kwonlyargs=[], + kwonlydefaults=None, + annotations={}, + ), + ) + + def test_dc_on_top_of_non_dc(self, decl_base: Type[DeclarativeBase]): + class Person(decl_base): + __tablename__ = "person" + person_id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + type: Mapped[str] = mapped_column() + + __mapper_args__ = {"polymorphic_on": type} + + class Engineer(MappedAsDataclass, Person): + __tablename__ = "engineer" + + person_id: Mapped[int] = mapped_column( + ForeignKey("person.person_id"), primary_key=True, init=False + ) + + status: Mapped[str] = mapped_column(String(30)) + engineer_name: Mapped[str] + primary_language: Mapped[str] + __mapper_args__ = {"polymorphic_identity": "engineer"} + + e1 = Engineer("st", "en", "pl") + eq_(e1.status, "st") + eq_(e1.engineer_name, "en") + eq_(e1.primary_language, "pl") + + eq_( + pyinspect.getfullargspec(Person.__init__), + # the boring **kw __init__ + pyinspect.FullArgSpec( + args=["self"], + varargs=None, + varkw="kwargs", + defaults=None, + kwonlyargs=[], + kwonlydefaults=None, + annotations={}, + ), + ) + + eq_( + pyinspect.getfullargspec(Engineer.__init__), + # the exciting dataclasses __init__ + pyinspect.FullArgSpec( + args=["self", "status", "engineer_name", "primary_language"], + varargs=None, + varkw=None, + defaults=None, + kwonlyargs=[], + kwonlydefaults=None, + annotations={}, + ), + ) + + def test_compare(self, dc_decl_base: Type[MappedAsDataclass]): + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, compare=False) + data: Mapped[str] + + a1 = A(id=0, data="foo") + a2 = A(id=1, data="foo") + eq_(a1, a2) + + @testing.requires.python310 + def test_kw_only_attribute(self, dc_decl_base: Type[MappedAsDataclass]): + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] = mapped_column(kw_only=True) + + fas = pyinspect.getfullargspec(A.__init__) + eq_(fas.args, ["self", "id"]) + eq_(fas.kwonlyargs, ["data"]) + + @testing.combinations(True, False, argnames="unsafe_hash") + def test_hash_attribute( + self, dc_decl_base: Type[MappedAsDataclass], unsafe_hash + ): + class A(dc_decl_base, unsafe_hash=unsafe_hash): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, hash=False) + data: Mapped[str] = mapped_column(hash=True) + + a = A(id=1, data="x") + if not unsafe_hash or not dc_decl_base._mad_before: + with expect_raises(TypeError): + a_hash1 = hash(a) + else: + a_hash1 = hash(a) + a.id = 41 + eq_(hash(a), a_hash1) + a.data = "y" + ne_(hash(a), a_hash1) + + @testing.requires.python310 + def test_kw_only_dataclass_constant( + self, dc_decl_base: Type[MappedAsDataclass] + ): + class Mixin(MappedAsDataclass): + a: Mapped[int] = mapped_column(primary_key=True) + b: Mapped[int] = mapped_column(default=1) + + class Child(Mixin, dc_decl_base): + __tablename__ = "child" + + _: dataclasses.KW_ONLY + c: Mapped[int] + + c1 = Child(1, c=5) + eq_(c1, Child(a=1, b=1, c=5)) + + def test_mapped_column_overrides(self, dc_decl_base): + """test #8688""" + + class TriggeringMixin(MappedAsDataclass): + mixin_value: Mapped[int] = mapped_column(BigInteger) + + class NonTriggeringMixin(MappedAsDataclass): + mixin_value: Mapped[int] + + class Foo(dc_decl_base, TriggeringMixin): + __tablename__ = "foo" + id: Mapped[int] = mapped_column(primary_key=True, init=False) + foo_value: Mapped[float] = mapped_column(default=78) + + class Bar(dc_decl_base, NonTriggeringMixin): + __tablename__ = "bar" + id: Mapped[int] = mapped_column(primary_key=True, init=False) + bar_value: Mapped[float] = mapped_column(default=78) + + f1 = Foo(mixin_value=5) + eq_(f1.foo_value, 78) + + b1 = Bar(mixin_value=5) + eq_(b1.bar_value, 78) + + def test_mixing_MappedAsDataclass_with_decorator_raises(self, registry): + """test #9211""" + + class Mixin(MappedAsDataclass): + id: Mapped[int] = mapped_column(primary_key=True, init=False) + + with expect_raises_message( + exc.InvalidRequestError, + "Class .*Foo.* is already a dataclass; ensure that " + "base classes / decorator styles of establishing dataclasses " + "are not being mixed. ", + ): + + @registry.mapped_as_dataclass + class Foo(Mixin): + bar_value: Mapped[float] = mapped_column(default=78) + + def test_MappedAsDataclass_table_provided(self, registry): + """test #11973""" + + with expect_raises_message( + exc.InvalidRequestError, + "Class .*Foo.* already defines a '__table__'. " + "ORM Annotated Dataclasses do not support a pre-existing " + "'__table__' element", + ): + + @registry.mapped_as_dataclass + class Foo: + __table__ = Table("foo", registry.metadata) + foo: Mapped[float] + + def test_dataclass_exception_wrapped(self, dc_decl_base): + with expect_raises_message( + exc.InvalidRequestError, + r"Python dataclasses error encountered when creating dataclass " + r"for \'Foo\': .*Please refer to Python dataclasses.*", + ) as ec: + + class Foo(dc_decl_base): + id: Mapped[int] = mapped_column(primary_key=True, init=False) + foo_value: Mapped[float] = mapped_column(default=78) + foo_no_value: Mapped[float] = mapped_column() + __tablename__ = "foo" + + is_true(isinstance(ec.error.__cause__, TypeError)) + + def test_dataclass_default(self, dc_decl_base): + """test for #9879""" + + def c10(): + return 10 + + def c20(): + return 20 + + class A(dc_decl_base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(primary_key=True) + def_init: Mapped[int] = mapped_column(default=42) + call_init: Mapped[int] = mapped_column(default_factory=c10) + def_no_init: Mapped[int] = mapped_column(default=13, init=False) + call_no_init: Mapped[int] = mapped_column( + default_factory=c20, init=False + ) + + a = A(id=100) + eq_(a.def_init, 42) + eq_(a.call_init, 10) + eq_(a.def_no_init, 13) + eq_(a.call_no_init, 20) + + fields = {f.name: f for f in dataclasses.fields(A)} + eq_(fields["def_init"].default, LoaderCallableStatus.DONT_SET) + eq_(fields["call_init"].default_factory, c10) + eq_(fields["def_no_init"].default, dataclasses.MISSING) + ne_(fields["def_no_init"].default_factory, dataclasses.MISSING) + eq_(fields["call_no_init"].default_factory, c20) + + def test_dataclass_default_callable(self, dc_decl_base): + """test for #9936""" + + def cd(): + return 42 + + with expect_deprecated( + "Callable object passed to the ``default`` parameter for " + "attribute 'value' in a ORM-mapped Dataclasses context is " + "ambiguous, and this use will raise an error in a future " + "release. If this callable is intended to produce Core level ", + "Callable object passed to the ``default`` parameter for " + "attribute 'no_init' in a ORM-mapped Dataclasses context is " + "ambiguous, and this use will raise an error in a future " + "release. If this callable is intended to produce Core level ", + ): + + class A(dc_decl_base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(primary_key=True) + value: Mapped[int] = mapped_column(default=cd) + no_init: Mapped[int] = mapped_column(default=cd, init=False) + + a = A(id=100) + is_false("no_init" in a.__dict__) + eq_(a.value, cd) + eq_(a.no_init, None) + + fields = {f.name: f for f in dataclasses.fields(A)} + eq_(fields["value"].default, cd) + eq_(fields["no_init"].default, cd) + + +class RelationshipDefaultFactoryTest(fixtures.TestBase): + def test_list(self, dc_decl_base: Type[MappedAsDataclass]): + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + + bs: Mapped[List["B"]] = relationship( # noqa: F821 + default_factory=lambda: [B(data="hi")] + ) + + class B(dc_decl_base): + __tablename__ = "b" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + a_id = mapped_column(ForeignKey("a.id"), init=False) + data: Mapped[str] + + a1 = A() + eq_(a1.bs[0].data, "hi") + + def test_set(self, dc_decl_base: Type[MappedAsDataclass]): + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + + bs: Mapped[Set["B"]] = relationship( # noqa: F821 + default_factory=lambda: {B(data="hi")} + ) + + class B(dc_decl_base, unsafe_hash=True): + __tablename__ = "b" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + a_id = mapped_column(ForeignKey("a.id"), init=False) + data: Mapped[str] + + a1 = A() + eq_(a1.bs.pop().data, "hi") + + def test_oh_no_mismatch(self, dc_decl_base: Type[MappedAsDataclass]): + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + + bs: Mapped[Set["B"]] = relationship( # noqa: F821 + default_factory=lambda: [B(data="hi")] + ) + + class B(dc_decl_base, unsafe_hash=True): + __tablename__ = "b" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + a_id = mapped_column(ForeignKey("a.id"), init=False) + data: Mapped[str] + + # old school collection mismatch error FTW + with expect_raises_message( + TypeError, "Incompatible collection type: list is not set-like" + ): + A() + + def test_one_to_one_example(self, dc_decl_base: Type[MappedAsDataclass]): + """test example in the relationship docs will derive uselist=False + correctly""" + + class Parent(dc_decl_base): + __tablename__ = "parent" + + id: Mapped[int] = mapped_column(init=False, primary_key=True) + child: Mapped["Child"] = relationship( # noqa: F821 + back_populates="parent", default=None + ) + + class Child(dc_decl_base): + __tablename__ = "child" + + id: Mapped[int] = mapped_column(init=False, primary_key=True) + parent_id: Mapped[int] = mapped_column( + ForeignKey("parent.id"), init=False + ) + parent: Mapped["Parent"] = relationship( + back_populates="child", default=None + ) + + c1 = Child() + p1 = Parent(child=c1) + is_(p1.child, c1) + is_(c1.parent, p1) + + p2 = Parent() + is_(p2.child, None) + + def test_replace_operation_works_w_history_etc( + self, registry: _RegistryType + ): + @registry.mapped_as_dataclass + class A: + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] + + x: Mapped[Optional[int]] = mapped_column(default=None) + + bs: Mapped[List["B"]] = relationship( # noqa: F821 + default_factory=list + ) + + @registry.mapped_as_dataclass + class B: + __tablename__ = "b" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + a_id = mapped_column(ForeignKey("a.id"), init=False) + data: Mapped[str] + x: Mapped[Optional[int]] = mapped_column(default=None) + + registry.metadata.create_all(testing.db) + + with Session(testing.db) as sess: + a1 = A("data", 10, [B("b1"), B("b2", x=5), B("b3")]) + sess.add(a1) + sess.commit() + + a2 = dataclasses.replace(a1, x=12, bs=[B("b4")]) + + assert a1 in sess + assert not sess.is_modified(a1, include_collections=True) + assert a2 not in sess + eq_(inspect(a2).attrs.x.history, ([12], (), ())) + sess.add(a2) + sess.commit() + + eq_(sess.scalars(select(A.x).order_by(A.id)).all(), [10, 12]) + eq_( + sess.scalars(select(B.data).order_by(B.id)).all(), + ["b1", "b2", "b3", "b4"], + ) + + def test_post_init(self, registry: _RegistryType): + @registry.mapped_as_dataclass + class A: + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] = mapped_column(init=False) + + def __post_init__(self): + self.data = "some data" + + a1 = A() + eq_(a1.data, "some data") + + def test_no_field_args_w_new_style(self, registry: _RegistryType): + with expect_raises_message( + exc.InvalidRequestError, + "SQLAlchemy mapped dataclasses can't consume mapping information", + ): + + @registry.mapped_as_dataclass() + class A: + __tablename__ = "a" + __sa_dataclass_metadata_key__ = "sa" + + account_id: int = dataclasses.field( + init=False, + metadata={"sa": Column(Integer, primary_key=True)}, + ) + + def test_no_field_args_w_new_style_two(self, registry: _RegistryType): + @dataclasses.dataclass + class Base: + pass + + with expect_raises_message( + exc.InvalidRequestError, + "SQLAlchemy mapped dataclasses can't consume mapping information", + ): + + @registry.mapped_as_dataclass() + class A(Base): + __tablename__ = "a" + __sa_dataclass_metadata_key__ = "sa" + + account_id: int = dataclasses.field( + init=False, + metadata={"sa": Column(Integer, primary_key=True)}, + ) + + +class DataclassesForNonMappedClassesTest(fixtures.TestBase): + """test for cases added in #9179""" + + def test_base_is_dc(self): + class Parent(MappedAsDataclass, DeclarativeBase): + a: int + + class Child(Parent): + __tablename__ = "child" + b: Mapped[int] = mapped_column(primary_key=True) + + eq_regex(repr(Child(5, 6)), r".*\.Child\(a=5, b=6\)") + + def test_base_is_dc_plus_options(self): + class Parent(MappedAsDataclass, DeclarativeBase, unsafe_hash=True): + a: int + + class Child(Parent, repr=False): + __tablename__ = "child" + b: Mapped[int] = mapped_column(primary_key=True) + + c1 = Child(5, 6) + eq_(hash(c1), hash(Child(5, 6))) + + # still reprs, because base has a repr, but b not included + eq_regex(repr(c1), r".*\.Child\(a=5\)") + + def test_base_is_dc_init_var(self): + class Parent(MappedAsDataclass, DeclarativeBase): + a: InitVar[int] + + class Child(Parent): + __tablename__ = "child" + b: Mapped[int] = mapped_column(primary_key=True) + + c1 = Child(a=5, b=6) + eq_regex(repr(c1), r".*\.Child\(b=6\)") + + def test_base_is_dc_field(self): + class Parent(MappedAsDataclass, DeclarativeBase): + a: int = dataclasses.field(default=10) + + class Child(Parent): + __tablename__ = "child" + b: Mapped[int] = mapped_column(primary_key=True, default=7) + + c1 = Child(a=5, b=6) + eq_regex(repr(c1), r".*\.Child\(a=5, b=6\)") + + c1 = Child(b=6) + eq_regex(repr(c1), r".*\.Child\(a=10, b=6\)") + + c1 = Child() + eq_regex(repr(c1), r".*\.Child\(a=10, b=7\)") + + def test_abstract_and_base_is_dc(self): + class Parent(MappedAsDataclass, DeclarativeBase): + a: int + + class Mixin(Parent): + __abstract__ = True + b: int + + class Child(Mixin): + __tablename__ = "child" + c: Mapped[int] = mapped_column(primary_key=True) + + eq_regex(repr(Child(5, 6, 7)), r".*\.Child\(a=5, b=6, c=7\)") + + def test_abstract_and_base_is_dc_plus_options(self): + class Parent(MappedAsDataclass, DeclarativeBase): + a: int + + class Mixin(Parent, unsafe_hash=True): + __abstract__ = True + b: int + + class Child(Mixin, repr=False): + __tablename__ = "child" + c: Mapped[int] = mapped_column(primary_key=True) + + eq_(hash(Child(5, 6, 7)), hash(Child(5, 6, 7))) + + eq_regex(repr(Child(5, 6, 7)), r".*\.Child\(a=5, b=6\)") + + def test_abstract_and_base_is_dc_init_var(self): + class Parent(MappedAsDataclass, DeclarativeBase): + a: InitVar[int] + + class Mixin(Parent): + __abstract__ = True + b: InitVar[int] + + class Child(Mixin): + __tablename__ = "child" + c: Mapped[int] = mapped_column(primary_key=True) + + c1 = Child(a=5, b=6, c=7) + eq_regex(repr(c1), r".*\.Child\(c=7\)") + + def test_abstract_and_base_is_dc_field(self): + class Parent(MappedAsDataclass, DeclarativeBase): + a: int = dataclasses.field(default=10) + + class Mixin(Parent): + __abstract__ = True + b: int = dataclasses.field(default=7) + + class Child(Mixin): + __tablename__ = "child" + c: Mapped[int] = mapped_column(primary_key=True, default=9) + + c1 = Child(b=6, c=7) + eq_regex(repr(c1), r".*\.Child\(a=10, b=6, c=7\)") + + c1 = Child() + eq_regex(repr(c1), r".*\.Child\(a=10, b=7, c=9\)") + + # TODO: get this test to work with future anno mode as well + @testing.exclusions.closed( + "doesn't work for future annotations mode yet" + ) # noqa: E501 + def test_abstract_is_dc(self): + collected_annotations = {} + + def check_args(cls, **kw): + collected_annotations[cls] = cls.__annotations__ + return dataclasses.dataclass(cls, **kw) + + class Parent(DeclarativeBase): + a: int + + class Mixin(MappedAsDataclass, Parent, dataclass_callable=check_args): + __abstract__ = True + b: int + + class Child(Mixin): + __tablename__ = "child" + c: Mapped[int] = mapped_column(primary_key=True) + + eq_(collected_annotations, {Mixin: {"b": int}, Child: {"c": int}}) + eq_regex(repr(Child(6, 7)), r".*\.Child\(b=6, c=7\)") + + # TODO: get this test to work with future anno mode as well + @testing.exclusions.closed( + "doesn't work for future annotations mode yet" + ) # noqa: E501 + @testing.variation("check_annotations", [True, False]) + def test_abstract_is_dc_w_mapped(self, check_annotations): + if check_annotations: + collected_annotations = {} + + def check_args(cls, **kw): + collected_annotations[cls] = cls.__annotations__ + return dataclasses.dataclass(cls, **kw) + + class_kw = {"dataclass_callable": check_args} + else: + class_kw = {} + + class Parent(DeclarativeBase): + a: int + + class Mixin(MappedAsDataclass, Parent, **class_kw): + __abstract__ = True + b: Mapped[int] = mapped_column() + + class Child(Mixin): + __tablename__ = "child" + c: Mapped[int] = mapped_column(primary_key=True) + + if check_annotations: + # note: current dataclasses process adds Field() object to Child + # based on attributes which include those from Mixin. This means + # the annotations of Child are also augmented while we do + # dataclasses collection. + eq_( + collected_annotations, + {Mixin: {"b": int}, Child: {"b": int, "c": int}}, + ) + eq_regex(repr(Child(6, 7)), r".*\.Child\(b=6, c=7\)") + + def test_mixin_and_base_is_dc(self): + class Parent(MappedAsDataclass, DeclarativeBase): + a: int + + @dataclasses.dataclass + class Mixin: + b: int + + class Child(Mixin, Parent): + __tablename__ = "child" + c: Mapped[int] = mapped_column(primary_key=True) + + eq_regex(repr(Child(5, 6, 7)), r".*\.Child\(a=5, b=6, c=7\)") + + def test_mixin_and_base_is_dc_init_var(self): + class Parent(MappedAsDataclass, DeclarativeBase): + a: InitVar[int] + + @dataclasses.dataclass + class Mixin: + b: InitVar[int] + + class Child(Mixin, Parent): + __tablename__ = "child" + c: Mapped[int] = mapped_column(primary_key=True) + + eq_regex(repr(Child(a=5, b=6, c=7)), r".*\.Child\(c=7\)") + + # TODO: get this test to work with future anno mode as well + @testing.exclusions.closed( + "doesn't work for future annotations mode yet" + ) # noqa: E501 + @testing.variation( + "dataclass_scope", + ["on_base", "on_mixin", "on_base_class", "on_sub_class"], + ) + @testing.variation( + "test_alternative_callable", + [True, False], + ) + def test_mixin_w_inheritance( + self, dataclass_scope, test_alternative_callable + ): + """test #9226""" + + expected_annotations = {} + + if test_alternative_callable: + collected_annotations = {} + + def check_args(cls, **kw): + collected_annotations[cls] = getattr( + cls, "__annotations__", {} + ) + return dataclasses.dataclass(cls, **kw) + + klass_kw = {"dataclass_callable": check_args} + else: + klass_kw = {} + + if dataclass_scope.on_base: + + class Base(MappedAsDataclass, DeclarativeBase, **klass_kw): + pass + + expected_annotations[Base] = {} + else: + + class Base(DeclarativeBase): + pass + + if dataclass_scope.on_mixin: + + class Mixin(MappedAsDataclass, **klass_kw): + @declared_attr.directive + @classmethod + def __tablename__(cls) -> str: + return cls.__name__.lower() + + @declared_attr.directive + @classmethod + def __mapper_args__(cls) -> Dict[str, Any]: + return { + "polymorphic_identity": cls.__name__, + "polymorphic_on": "polymorphic_type", + } + + @declared_attr + @classmethod + def polymorphic_type(cls) -> Mapped[str]: + return mapped_column( + String, + insert_default=cls.__name__, + init=False, + ) + + expected_annotations[Mixin] = {} + + non_dc_mixin = contextlib.nullcontext + + else: + + class Mixin: + @declared_attr.directive + @classmethod + def __tablename__(cls) -> str: + return cls.__name__.lower() + + @declared_attr.directive + @classmethod + def __mapper_args__(cls) -> Dict[str, Any]: + return { + "polymorphic_identity": cls.__name__, + "polymorphic_on": "polymorphic_type", + } + + if dataclass_scope.on_base or dataclass_scope.on_base_class: + + @declared_attr + @classmethod + def polymorphic_type(cls) -> Mapped[str]: + return mapped_column( + String, + insert_default=cls.__name__, + init=False, + ) + + else: + + @declared_attr + @classmethod + def polymorphic_type(cls) -> Mapped[str]: + return mapped_column( + String, + insert_default=cls.__name__, + ) + + non_dc_mixin = functools.partial( + _dataclass_mixin_warning, "Mixin", "'polymorphic_type'" + ) + + if dataclass_scope.on_base_class: + with non_dc_mixin(): + + class Book(Mixin, MappedAsDataclass, Base, **klass_kw): + id: Mapped[int] = mapped_column( + Integer, + primary_key=True, + init=False, + ) + + else: + if dataclass_scope.on_base: + local_non_dc_mixin = non_dc_mixin + else: + local_non_dc_mixin = contextlib.nullcontext + + with local_non_dc_mixin(): + + class Book(Mixin, Base): + if not dataclass_scope.on_sub_class: + id: Mapped[int] = mapped_column( # noqa: A001 + Integer, primary_key=True, init=False + ) + else: + id: Mapped[int] = mapped_column( # noqa: A001 + Integer, + primary_key=True, + ) + + if MappedAsDataclass in Book.__mro__: + expected_annotations[Book] = {"id": int, "polymorphic_type": str} + + if dataclass_scope.on_sub_class: + with non_dc_mixin(): + + class Novel(MappedAsDataclass, Book, **klass_kw): + id: Mapped[int] = mapped_column( # noqa: A001 + ForeignKey("book.id"), + primary_key=True, + init=False, + ) + description: Mapped[Optional[str]] + + else: + with non_dc_mixin(): + + class Novel(Book): + id: Mapped[int] = mapped_column( + ForeignKey("book.id"), + primary_key=True, + init=False, + ) + description: Mapped[Optional[str]] + + expected_annotations[Novel] = {"id": int, "description": Optional[str]} + + if test_alternative_callable: + eq_(collected_annotations, expected_annotations) + + n1 = Novel("the description") + eq_(n1.description, "the description") + + +class DataclassArgsTest(fixtures.TestBase): + dc_arg_names = ("init", "repr", "eq", "order", "unsafe_hash") + if compat.py310: + dc_arg_names += ("match_args", "kw_only") + + @testing.fixture(params=product(dc_arg_names, (True, False))) + def dc_argument_fixture(self, request: Any, registry: _RegistryType): + name, use_defaults = request.param + + args = {n: n == name for n in self.dc_arg_names} + if args["order"]: + args["eq"] = True + if use_defaults: + default = { + "init": True, + "repr": True, + "eq": True, + "order": False, + "unsafe_hash": False, + } + if compat.py310: + default |= {"match_args": True, "kw_only": False} + to_apply = {k: v for k, v in args.items() if v} + effective = {**default, **to_apply} + return to_apply, effective + else: + return args, args + + @testing.fixture(params=["mapped_column", "deferred"]) + def mapped_expr_constructor(self, request): + name = request.param + + if name == "mapped_column": + yield mapped_column(default=7, init=True) + elif name == "deferred": + yield deferred(Column(Integer), default=7, init=True) + + def test_attrs_rejected_if_not_a_dc( + self, mapped_expr_constructor, decl_base: Type[DeclarativeBase] + ): + if isinstance(mapped_expr_constructor, MappedColumn): + unwanted_args = "'init'" + else: + unwanted_args = "'default', 'init'" + with expect_raises_message( + exc.ArgumentError, + r"Attribute 'x' on class .*A.* includes dataclasses " + r"argument\(s\): " + rf"{unwanted_args} but class does not specify SQLAlchemy native " + "dataclass configuration", + ): + + class A(decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + + x: Mapped[int] = mapped_expr_constructor + + def _assert_cls(self, cls, dc_arguments): + if dc_arguments["init"]: + + def create(data, x): + if dc_arguments.get("kw_only"): + return cls(data=data, x=x) + else: + return cls(data, x) + + else: + + def create(data, x): + a1 = cls() + a1.data = data + a1.x = x + return a1 + + for n in self.dc_arg_names: + if dc_arguments[n]: + getattr(self, f"_assert_{n}")(cls, create, dc_arguments) + else: + getattr(self, f"_assert_not_{n}")(cls, create, dc_arguments) + + if dc_arguments["init"]: + a1 = cls(data="some data") + eq_(a1.x, 7) + + a1 = create("some data", 15) + some_int = a1.some_int + eq_( + dataclasses.asdict(a1), + {"data": "some data", "id": None, "some_int": some_int, "x": 15}, + ) + eq_(dataclasses.astuple(a1), (None, "some data", some_int, 15)) + + def _assert_unsafe_hash(self, cls, create, dc_arguments): + a1 = create("d1", 5) + hash(a1) + + def _assert_not_unsafe_hash(self, cls, create, dc_arguments): + a1 = create("d1", 5) + + if dc_arguments["eq"]: + with expect_raises(TypeError): + hash(a1) + else: + hash(a1) + + def _assert_eq(self, cls, create, dc_arguments): + a1 = create("d1", 5) + a2 = create("d2", 10) + a3 = create("d1", 5) + + eq_(a1, a3) + ne_(a1, a2) + + def _assert_not_eq(self, cls, create, dc_arguments): + a1 = create("d1", 5) + a2 = create("d2", 10) + a3 = create("d1", 5) + + eq_(a1, a1) + ne_(a1, a3) + ne_(a1, a2) + + def _assert_order(self, cls, create, dc_arguments): + is_false(create("g", 10) < create("b", 7)) + + is_true(create("g", 10) > create("b", 7)) + + is_false(create("g", 10) <= create("b", 7)) + + is_true(create("g", 10) >= create("b", 7)) + + eq_( + list(sorted([create("g", 10), create("g", 5), create("b", 7)])), + [ + create("b", 7), + create("g", 5), + create("g", 10), + ], + ) + + def _assert_not_order(self, cls, create, dc_arguments): + with expect_raises(TypeError): + create("g", 10) < create("b", 7) + + with expect_raises(TypeError): + create("g", 10) > create("b", 7) + + with expect_raises(TypeError): + create("g", 10) <= create("b", 7) + + with expect_raises(TypeError): + create("g", 10) >= create("b", 7) + + def _assert_repr(self, cls, create, dc_arguments): + assert "__repr__" in cls.__dict__ + a1 = create("some data", 12) + eq_regex(repr(a1), r".*A\(id=None, data='some data', x=12\)") + + def _assert_not_repr(self, cls, create, dc_arguments): + assert "__repr__" not in cls.__dict__ + + # if a superclass has __repr__, then we still get repr. + # so can't test this + # a1 = create("some data", 12) + # eq_regex(repr(a1), r"<.*A object at 0x.*>") + + def _assert_init(self, cls, create, dc_arguments): + if not dc_arguments.get("kw_only", False): + a1 = cls("some data", 5) + + eq_(a1.data, "some data") + eq_(a1.x, 5) + + a2 = cls(data="some data", x=5) + eq_(a2.data, "some data") + eq_(a2.x, 5) + + a3 = cls(data="some data") + eq_(a3.data, "some data") + eq_(a3.x, 7) + + def _assert_not_init(self, cls, create, dc_arguments): + with expect_raises(TypeError): + cls("Some data", 5) + + # behavior change in 2.1, even if init=False we set descriptor + # defaults + + a1 = cls(data="some data") + eq_(a1.data, "some data") + + eq_(a1.x, 7) + + a1 = cls() + eq_(a1.data, None) + + # but this breaks for synonyms + eq_(a1.x, 7) + + def _assert_match_args(self, cls, create, dc_arguments): + if not dc_arguments["kw_only"]: + is_true(len(cls.__match_args__) > 0) + + def _assert_not_match_args(self, cls, create, dc_arguments): + is_false(hasattr(cls, "__match_args__")) + + def _assert_kw_only(self, cls, create, dc_arguments): + if dc_arguments["init"]: + fas = pyinspect.getfullargspec(cls.__init__) + eq_(fas.args, ["self"]) + eq_( + len(fas.kwonlyargs), + len(pyinspect.signature(cls.__init__).parameters) - 1, + ) + + def _assert_not_kw_only(self, cls, create, dc_arguments): + if dc_arguments["init"]: + fas = pyinspect.getfullargspec(cls.__init__) + eq_( + len(fas.args), + len(pyinspect.signature(cls.__init__).parameters), + ) + eq_(fas.kwonlyargs, []) + + def test_dc_arguments_decorator( + self, + dc_argument_fixture, + mapped_expr_constructor, + registry: _RegistryType, + ): + @registry.mapped_as_dataclass(**dc_argument_fixture[0]) + class A: + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] + + some_int: Mapped[int] = mapped_column(init=False, repr=False) + + x: Mapped[Optional[int]] = mapped_expr_constructor + + self._assert_cls(A, dc_argument_fixture[1]) + + def test_dc_arguments_base( + self, + dc_argument_fixture, + mapped_expr_constructor, + registry: _RegistryType, + ): + reg = registry + + class Base( + MappedAsDataclass, DeclarativeBase, **dc_argument_fixture[0] + ): + registry = reg + + class A(Base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] + + some_int: Mapped[int] = mapped_column(init=False, repr=False) + + x: Mapped[Optional[int]] = mapped_expr_constructor + + self._assert_cls(A, dc_argument_fixture[1]) + + def test_dc_arguments_perclass( + self, + dc_argument_fixture, + mapped_expr_constructor, + decl_base: Type[DeclarativeBase], + ): + class A(MappedAsDataclass, decl_base, **dc_argument_fixture[0]): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] + + some_int: Mapped[int] = mapped_column(init=False, repr=False) + + x: Mapped[Optional[int]] = mapped_expr_constructor + + self._assert_cls(A, dc_argument_fixture[1]) + + def test_dc_arguments_override_base(self, registry: _RegistryType): + reg = registry + + class Base(MappedAsDataclass, DeclarativeBase, init=False, order=True): + registry = reg + + class A(Base, init=True, repr=False): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] + + some_int: Mapped[int] = mapped_column(init=False, repr=False) + + x: Mapped[Optional[int]] = mapped_column(default=7) + + effective = { + "init": True, + "repr": False, + "eq": True, + "order": True, + "unsafe_hash": False, + } + if compat.py310: + effective |= {"match_args": True, "kw_only": False} + self._assert_cls(A, effective) + + def test_dc_base_unsupported_argument(self, registry: _RegistryType): + reg = registry + with expect_raises(TypeError): + + class Base(MappedAsDataclass, DeclarativeBase, slots=True): + registry = reg + + class Base2(MappedAsDataclass, DeclarativeBase, order=True): + registry = reg + + with expect_raises(TypeError): + + class A(Base2, slots=False): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + + def test_dc_decorator_unsupported_argument(self, registry: _RegistryType): + reg = registry + with expect_raises(TypeError): + + @registry.mapped_as_dataclass(slots=True) + class Base(DeclarativeBase): + registry = reg + + class Base2(MappedAsDataclass, DeclarativeBase, order=True): + registry = reg + + with expect_raises(TypeError): + + @registry.mapped_as_dataclass(slots=True) + class A(Base2): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + + def test_dc_raise_for_slots( + self, + registry: _RegistryType, + decl_base: Type[DeclarativeBase], + ): + reg = registry + with expect_raises_message( + exc.ArgumentError, + r"Dataclass argument\(s\) 'slots', 'unknown' are not accepted", + ): + + class A(MappedAsDataclass, decl_base): + __tablename__ = "a" + _sa_apply_dc_transforms = {"slots": True, "unknown": 5} + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + + with expect_raises_message( + exc.ArgumentError, + r"Dataclass argument\(s\) 'slots' are not accepted", + ): + + class Base(MappedAsDataclass, DeclarativeBase, order=True): + registry = reg + _sa_apply_dc_transforms = {"slots": True} + + with expect_raises_message( + exc.ArgumentError, + r"Dataclass argument\(s\) 'slots', 'unknown' are not accepted", + ): + + @reg.mapped + class C: + __tablename__ = "a" + _sa_apply_dc_transforms = {"slots": True, "unknown": 5} + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + + @testing.variation("use_arguments", [True, False]) + @testing.combinations( + mapped_column, + lambda **kw: synonym("some_int", **kw), + lambda **kw: deferred(Column(Integer), **kw), + lambda **kw: composite("foo", **kw), + lambda **kw: relationship("Foo", **kw), + lambda **kw: association_proxy("foo", "bar", **kw), + argnames="construct", + ) + def test_attribute_options(self, use_arguments, construct): + if use_arguments: + kw = { + "init": False, + "repr": False, + "default": None, + "default_factory": list, + "compare": True, + "kw_only": False, + "hash": False, + } + exp = interfaces._AttributeOptions( + False, False, None, list, True, False, False + ) + else: + kw = {} + exp = interfaces._DEFAULT_ATTRIBUTE_OPTIONS + + prop = construct(**kw) + eq_(prop._attribute_options, exp) + + @testing.variation("use_arguments", [True, False]) + @testing.combinations( + lambda **kw: column_property(Column(Integer), **kw), + lambda **kw: query_expression(**kw), + argnames="construct", + ) + def test_ro_attribute_options(self, use_arguments, construct): + if use_arguments: + kw = { + "repr": False, + "compare": True, + } + exp = interfaces._AttributeOptions( + False, + False, + _NoArg.NO_ARG, + _NoArg.NO_ARG, + True, + _NoArg.NO_ARG, + _NoArg.NO_ARG, + ) + else: + kw = {} + exp = interfaces._DEFAULT_READONLY_ATTRIBUTE_OPTIONS + + prop = construct(**kw) + eq_(prop._attribute_options, exp) + + +class MixinColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): + """tests for #8718""" + + __dialect__ = "default" + + @testing.fixture + def model(self): + def go(use_mixin, use_inherits, mad_setup, dataclass_kw): + if use_mixin: + if mad_setup == "dc, mad": + + class BaseEntity( + DeclarativeBase, MappedAsDataclass, **dataclass_kw + ): + pass + + elif mad_setup == "mad, dc": + + class BaseEntity( + MappedAsDataclass, DeclarativeBase, **dataclass_kw + ): + pass + + elif mad_setup == "subclass": + + class BaseEntity(DeclarativeBase): + pass + + class IdMixin(MappedAsDataclass): + id: Mapped[int] = mapped_column( + primary_key=True, init=False + ) + + if mad_setup == "subclass": + + class A( + IdMixin, MappedAsDataclass, BaseEntity, **dataclass_kw + ): + __mapper_args__ = { + "polymorphic_on": "type", + "polymorphic_identity": "a", + } + + __tablename__ = "a" + type: Mapped[str] = mapped_column(String, init=False) + data: Mapped[str] = mapped_column(String, init=False) + + else: + + class A(IdMixin, BaseEntity): + __mapper_args__ = { + "polymorphic_on": "type", + "polymorphic_identity": "a", + } + + __tablename__ = "a" + type: Mapped[str] = mapped_column(String, init=False) + data: Mapped[str] = mapped_column(String, init=False) + + else: + if mad_setup == "dc, mad": + + class BaseEntity( + DeclarativeBase, MappedAsDataclass, **dataclass_kw + ): + id: Mapped[int] = mapped_column( + primary_key=True, init=False + ) + + elif mad_setup == "mad, dc": + + class BaseEntity( + MappedAsDataclass, DeclarativeBase, **dataclass_kw + ): + id: Mapped[int] = mapped_column( + primary_key=True, init=False + ) + + elif mad_setup == "subclass": + + class BaseEntity(MappedAsDataclass, DeclarativeBase): + id: Mapped[int] = mapped_column( + primary_key=True, init=False + ) + + if mad_setup == "subclass": + + class A(BaseEntity, **dataclass_kw): + __mapper_args__ = { + "polymorphic_on": "type", + "polymorphic_identity": "a", + } + + __tablename__ = "a" + type: Mapped[str] = mapped_column(String, init=False) + data: Mapped[str] = mapped_column(String, init=False) + + else: + + class A(BaseEntity): + __mapper_args__ = { + "polymorphic_on": "type", + "polymorphic_identity": "a", + } + + __tablename__ = "a" + type: Mapped[str] = mapped_column(String, init=False) + data: Mapped[str] = mapped_column(String, init=False) + + if use_inherits: + + class B(A): + __mapper_args__ = { + "polymorphic_identity": "b", + } + b_data: Mapped[str] = mapped_column(String, init=False) + + return B + else: + return A + + yield go + + @testing.combinations("inherits", "plain", argnames="use_inherits") + @testing.combinations("mixin", "base", argnames="use_mixin") + @testing.combinations( + "mad, dc", "dc, mad", "subclass", argnames="mad_setup" + ) + def test_mapping(self, model, use_inherits, use_mixin, mad_setup): + target_cls = model( + use_inherits=use_inherits == "inherits", + use_mixin=use_mixin == "mixin", + mad_setup=mad_setup, + dataclass_kw={}, + ) + + obj = target_cls() + assert "id" not in obj.__dict__ + + +class CompositeTest(fixtures.TestBase, testing.AssertsCompiledSQL): + __dialect__ = "default" + + def test_composite_setup(self, dc_decl_base: Type[MappedAsDataclass]): + @dataclasses.dataclass + class Point: + x: int + y: int + + class Edge(dc_decl_base): + __tablename__ = "edge" + id: Mapped[int] = mapped_column(primary_key=True, init=False) + graph_id: Mapped[int] = mapped_column( + ForeignKey("graph.id"), init=False + ) + + start: Mapped[Point] = composite( + Point, mapped_column("x1"), mapped_column("y1"), default=None + ) + + end: Mapped[Point] = composite( + Point, mapped_column("x2"), mapped_column("y2"), default=None + ) + + class Graph(dc_decl_base): + __tablename__ = "graph" + id: Mapped[int] = mapped_column(primary_key=True, init=False) + + edges: Mapped[List[Edge]] = relationship() + + Point.__qualname__ = "mymodel.Point" + Edge.__qualname__ = "mymodel.Edge" + Graph.__qualname__ = "mymodel.Graph" + g = Graph( + edges=[ + Edge(start=Point(1, 2), end=Point(3, 4)), + Edge(start=Point(7, 8), end=Point(5, 6)), + ] + ) + eq_( + repr(g), + "mymodel.Graph(id=None, edges=[mymodel.Edge(id=None, " + "graph_id=None, start=mymodel.Point(x=1, y=2), " + "end=mymodel.Point(x=3, y=4)), " + "mymodel.Edge(id=None, graph_id=None, " + "start=mymodel.Point(x=7, y=8), end=mymodel.Point(x=5, y=6))])", + ) + + def test_named_setup(self, dc_decl_base: Type[MappedAsDataclass]): + @dataclasses.dataclass + class Address: + street: str + state: str + zip_: str + + class User(dc_decl_base): + __tablename__ = "user" + + id: Mapped[int] = mapped_column( + primary_key=True, init=False, repr=False + ) + name: Mapped[str] = mapped_column() + + address: Mapped[Address] = composite( + Address, + mapped_column(), + mapped_column(), + mapped_column("zip"), + default=None, + ) + + Address.__qualname__ = "mymodule.Address" + User.__qualname__ = "mymodule.User" + u = User( + name="user 1", + address=Address("123 anywhere street", "NY", "12345"), + ) + u2 = User("u2") + eq_( + repr(u), + "mymodule.User(name='user 1', " + "address=mymodule.Address(street='123 anywhere street', " + "state='NY', zip_='12345'))", + ) + eq_(repr(u2), "mymodule.User(name='u2', address=None)") + + +class ReadOnlyAttrTest(fixtures.TestBase, testing.AssertsCompiledSQL): + """tests related to #9628""" + + __dialect__ = "default" + + @testing.combinations( + (query_expression,), (column_property,), argnames="construct" + ) + def test_default_behavior( + self, dc_decl_base: Type[MappedAsDataclass], construct + ): + class MyClass(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] = mapped_column() + + const: Mapped[str] = construct(data + "asdf") + + m1 = MyClass(data="foo") + eq_(m1, MyClass(data="foo")) + ne_(m1, MyClass(data="bar")) + + eq_regex( + repr(m1), + r".*MyClass\(id=None, data='foo', const=None\)", + ) + + @testing.combinations( + (query_expression,), (column_property,), argnames="construct" + ) + def test_no_repr_behavior( + self, dc_decl_base: Type[MappedAsDataclass], construct + ): + class MyClass(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] = mapped_column() + + const: Mapped[str] = construct(data + "asdf", repr=False) + + m1 = MyClass(data="foo") + + eq_regex( + repr(m1), + r".*MyClass\(id=None, data='foo'\)", + ) + + @testing.combinations( + (query_expression,), (column_property,), argnames="construct" + ) + def test_enable_compare( + self, dc_decl_base: Type[MappedAsDataclass], construct + ): + class MyClass(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] = mapped_column() + + const: Mapped[str] = construct(data + "asdf", compare=True) + + m1 = MyClass(data="foo") + eq_(m1, MyClass(data="foo")) + ne_(m1, MyClass(data="bar")) + + m2 = MyClass(data="foo") + m2.const = "some const" + ne_(m2, MyClass(data="foo")) + m3 = MyClass(data="foo") + m3.const = "some const" + eq_(m2, m3) + + +class UseDescriptorDefaultsTest(fixtures.TestBase, testing.AssertsCompiledSQL): + """tests related to #12168""" + + __dialect__ = "default" + + @testing.fixture(params=[True, False]) + def dc_decl_base(self, request, metadata): + _md = metadata + + udd = request.param + + class Base(MappedAsDataclass, DeclarativeBase): + use_descriptor_defaults = udd + + if not use_descriptor_defaults: + _sa_disable_descriptor_defaults = True + + metadata = _md + type_annotation_map = { + str: String().with_variant( + String(50), "mysql", "mariadb", "oracle" + ) + } + + yield Base + Base.registry.dispose() + + def test_mapped_column_default(self, dc_decl_base): + + class MyClass(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] = mapped_column(default="my_default") + + mc = MyClass() + eq_(mc.data, "my_default") + + if not MyClass.use_descriptor_defaults: + eq_(mc.__dict__["data"], "my_default") + else: + assert "data" not in mc.__dict__ + + eq_(MyClass.__table__.c.data.default.arg, "my_default") + + def test_mapped_column_default_and_insert_default(self, dc_decl_base): + with expect_raises_message( + exc.ArgumentError, + "The 'default' and 'insert_default' parameters of " + "Column are mutually exclusive", + ): + mapped_column(default="x", insert_default="y") + + def test_relationship_only_none_default(self): + with expect_raises_message( + exc.ArgumentError, + r"Only 'None' is accepted as dataclass " + r"default for a relationship\(\)", + ): + relationship(default="not none") + + @testing.variation("uselist_type", ["implicit", "m2o_explicit"]) + def test_relationship_only_nouselist_none_default( + self, dc_decl_base, uselist_type + ): + with expect_raises_message( + exc.ArgumentError, + rf"On relationship {'A.bs' if uselist_type.implicit else 'B.a'}, " + "the dataclass default for relationship " + "may only be set for a relationship that references a scalar " + "value, i.e. many-to-one or explicitly uselist=False", + ): + + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] + + if uselist_type.implicit: + bs: Mapped[List["B"]] = relationship("B", default=None) + + class B(dc_decl_base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(primary_key=True) + a_id: Mapped[int] = mapped_column(ForeignKey("a.id")) + data: Mapped[str] + + if uselist_type.m2o_explicit: + a: Mapped[List[A]] = relationship( + "A", uselist=True, default=None + ) + + dc_decl_base.registry.configure() + + def test_constructor_repr(self, dc_decl_base): + + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] + + x: Mapped[Optional[int]] = mapped_column(default=None) + + bs: Mapped[List["B"]] = relationship( # noqa: F821 + default_factory=list + ) + + class B(dc_decl_base): + __tablename__ = "b" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] + a_id: Mapped[Optional[int]] = mapped_column( + ForeignKey("a.id"), init=False + ) + x: Mapped[Optional[int]] = mapped_column(default=None) + + A.__qualname__ = "some_module.A" + B.__qualname__ = "some_module.B" + + eq_( + pyinspect.getfullargspec(A.__init__), + pyinspect.FullArgSpec( + args=["self", "data", "x", "bs"], + varargs=None, + varkw=None, + defaults=( + (LoaderCallableStatus.DONT_SET, mock.ANY) + if A.use_descriptor_defaults + else (None, mock.ANY) + ), + kwonlyargs=[], + kwonlydefaults=None, + annotations={}, + ), + ) + eq_( + pyinspect.getfullargspec(B.__init__), + pyinspect.FullArgSpec( + args=["self", "data", "x"], + varargs=None, + varkw=None, + defaults=( + (LoaderCallableStatus.DONT_SET,) + if B.use_descriptor_defaults + else (None,) + ), + kwonlyargs=[], + kwonlydefaults=None, + annotations={}, + ), + ) + + a2 = A("10", x=5, bs=[B("data1"), B("data2", x=12)]) + eq_( + repr(a2), + "some_module.A(id=None, data='10', x=5, " + "bs=[some_module.B(id=None, data='data1', a_id=None, x=None), " + "some_module.B(id=None, data='data2', a_id=None, x=12)])", + ) + + a3 = A("data") + eq_(repr(a3), "some_module.A(id=None, data='data', x=None, bs=[])") + + def test_defaults_if_no_init_dc_level( + self, dc_decl_base: Type[MappedAsDataclass] + ): + + class MyClass(dc_decl_base, init=False): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] = mapped_column(default="default_status") + + mc = MyClass() + if MyClass.use_descriptor_defaults: + # behavior change of honoring default when dataclass init=False + eq_(mc.data, "default_status") + else: + eq_(mc.data, None) # "default_status") + + def test_defaults_w_no_init_attr_level( + self, dc_decl_base: Type[MappedAsDataclass] + ): + + class MyClass(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: Mapped[str] = mapped_column( + default="default_status", init=False + ) + + mc = MyClass() + eq_(mc.data, "default_status") + + if MyClass.use_descriptor_defaults: + assert "data" not in mc.__dict__ + else: + eq_(mc.__dict__["data"], "default_status") + + @testing.variation("use_attr_init", [True, False]) + def test_fk_set_scenario(self, dc_decl_base, use_attr_init): + if use_attr_init: + attr_init_kw = {} + else: + attr_init_kw = {"init": False} + + class Parent(dc_decl_base): + __tablename__ = "parent" + id: Mapped[int] = mapped_column( + primary_key=True, autoincrement=False + ) + + class Child(dc_decl_base): + __tablename__ = "child" + id: Mapped[int] = mapped_column(primary_key=True) + parent_id: Mapped[Optional[int]] = mapped_column( + ForeignKey("parent.id"), default=None + ) + parent: Mapped[Optional[Parent]] = relationship( + default=None, **attr_init_kw + ) + + dc_decl_base.metadata.create_all(testing.db) + + with Session(testing.db) as sess: + p1 = Parent(id=14) + sess.add(p1) + sess.flush() + + # parent_id=14, parent=None but fk is kept + c1 = Child(id=7, parent_id=14) + sess.add(c1) + sess.flush() + + if Parent.use_descriptor_defaults: + assert c1.parent is p1 + else: + assert c1.parent is None + + @testing.variation("use_attr_init", [True, False]) + def test_merge_scenario(self, dc_decl_base, use_attr_init): + if use_attr_init: + attr_init_kw = {} + else: + attr_init_kw = {"init": False} + + class MyClass(dc_decl_base): + __tablename__ = "myclass" + + id: Mapped[int] = mapped_column( + primary_key=True, autoincrement=False + ) + name: Mapped[str] + status: Mapped[str] = mapped_column( + default="default_status", **attr_init_kw + ) + + dc_decl_base.metadata.create_all(testing.db) + + with Session(testing.db) as sess: + if use_attr_init: + u1 = MyClass(id=1, name="x", status="custom_status") + else: + u1 = MyClass(id=1, name="x") + u1.status = "custom_status" + sess.add(u1) + + sess.flush() + + u2 = sess.merge(MyClass(id=1, name="y")) + is_(u2, u1) + eq_(u2.name, "y") + + if MyClass.use_descriptor_defaults: + eq_(u2.status, "custom_status") + else: + # was overridden by the default in __dict__ + eq_(u2.status, "default_status") + + if use_attr_init: + u3 = sess.merge( + MyClass(id=1, name="z", status="default_status") + ) + else: + mc = MyClass(id=1, name="z") + mc.status = "default_status" + u3 = sess.merge(mc) + + is_(u3, u1) + eq_(u3.name, "z") + + # field was explicit so is overridden by merge + eq_(u3.status, "default_status") + + +class SynonymDescriptorDefaultTest(AssertsCompiledSQL, fixtures.TestBase): + """test new behaviors for synonyms given dataclasses descriptor defaults + introduced in 2.1. Related to #12168""" + + __dialect__ = "default" + + @testing.fixture(params=[True, False]) + def dc_decl_base(self, request, metadata): + _md = metadata + + udd = request.param + + class Base(MappedAsDataclass, DeclarativeBase): + use_descriptor_defaults = udd + + if not use_descriptor_defaults: + _sa_disable_descriptor_defaults = True + + metadata = _md + type_annotation_map = { + str: String().with_variant( + String(50), "mysql", "mariadb", "oracle" + ) + } + + yield Base + Base.registry.dispose() + + def test_syn_matches_col_default( + self, dc_decl_base: Type[MappedAsDataclass] + ): + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + some_int: Mapped[int] = mapped_column(default=7, init=False) + some_syn: Mapped[int] = synonym("some_int", default=7) + + a1 = A() + eq_(a1.some_syn, 7) + eq_(a1.some_int, 7) + + a1 = A(some_syn=10) + eq_(a1.some_syn, 10) + eq_(a1.some_int, 10) + + @testing.variation("some_int_init", [True, False]) + def test_syn_does_not_match_col_default( + self, dc_decl_base: Type[MappedAsDataclass], some_int_init + ): + with ( + expect_raises_message( + exc.ArgumentError, + "Synonym 'some_syn' default argument 10 must match the " + "dataclasses default value of proxied object 'some_int', " + "currently 7", + ) + if dc_decl_base.use_descriptor_defaults + else contextlib.nullcontext() + ): + + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + some_int: Mapped[int] = mapped_column( + default=7, init=bool(some_int_init) + ) + some_syn: Mapped[int] = synonym("some_int", default=10) + + @testing.variation("some_int_init", [True, False]) + def test_syn_requires_col_default( + self, dc_decl_base: Type[MappedAsDataclass], some_int_init + ): + with ( + expect_raises_message( + exc.ArgumentError, + "Synonym 'some_syn' default argument 10 must match the " + "dataclasses default value of proxied object 'some_int', " + "currently not set", + ) + if dc_decl_base.use_descriptor_defaults + else contextlib.nullcontext() + ): + + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + some_int: Mapped[int] = mapped_column(init=bool(some_int_init)) + some_syn: Mapped[int] = synonym("some_int", default=10) + + @testing.variation("intermediary_init", [True, False]) + @testing.variation("some_syn_2_first", [True, False]) + def test_syn_matches_syn_default_one( + self, + intermediary_init, + some_syn_2_first, + dc_decl_base: Type[MappedAsDataclass], + ): + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + + if some_syn_2_first: + some_syn_2: Mapped[int] = synonym("some_syn", default=7) + + some_int: Mapped[int] = mapped_column(default=7, init=False) + some_syn: Mapped[int] = synonym( + "some_int", default=7, init=bool(intermediary_init) + ) + + if not some_syn_2_first: + some_syn_2: Mapped[int] = synonym("some_syn", default=7) + + a1 = A() + eq_(a1.some_syn_2, 7) + eq_(a1.some_syn, 7) + eq_(a1.some_int, 7) + + a1 = A(some_syn_2=10) + + if not A.use_descriptor_defaults: + if some_syn_2_first: + eq_(a1.some_syn_2, 7) + eq_(a1.some_syn, 7) + eq_(a1.some_int, 7) + else: + eq_(a1.some_syn_2, 10) + eq_(a1.some_syn, 10) + eq_(a1.some_int, 10) + else: + eq_(a1.some_syn_2, 10) + eq_(a1.some_syn, 10) + eq_(a1.some_int, 10) + + # here we have both some_syn and some_syn_2 in the constructor, + # which makes absolutely no sense to do in practice. + # the new 2.1 behavior we can see is better, however, having + # multiple synonyms in a chain with dataclasses with more than one + # of them in init is pretty much a bad idea + if intermediary_init: + a1 = A(some_syn_2=10, some_syn=12) + if some_syn_2_first: + eq_(a1.some_syn_2, 12) + eq_(a1.some_syn, 12) + eq_(a1.some_int, 12) + else: + eq_(a1.some_syn_2, 10) + eq_(a1.some_syn, 10) + eq_(a1.some_int, 10) diff --git a/test/typing/test_overloads.py b/test/typing/test_overloads.py index 355b4b568b..38ba568371 100644 --- a/test/typing/test_overloads.py +++ b/test/typing/test_overloads.py @@ -80,12 +80,10 @@ class OverloadTest(fixtures.TestBase): @testing.combinations( (CoreExecuteOptionsParameter, core_execution_options), - # https://github.com/python/cpython/issues/133701 - ( - OrmExecuteOptionsParameter, - orm_execution_options, - testing.requires.fail_python314b1, - ), + # note: this failed on python 3.14.0b1 + # due to https://github.com/python/cpython/issues/133701. + # something to keep in mind in case it breaks again + (OrmExecuteOptionsParameter, orm_execution_options), ) def test_typed_dicts(self, typ, expected): # we currently expect these to be union types with first entry diff --git a/tools/sync_test_files.py b/tools/sync_test_files.py index f855cd12c2..4c825c2d7f 100644 --- a/tools/sync_test_files.py +++ b/tools/sync_test_files.py @@ -6,6 +6,7 @@ from __future__ import annotations from pathlib import Path +from tempfile import NamedTemporaryFile from typing import Any from typing import Iterable @@ -34,7 +35,15 @@ def run_operation( source_data = Path(source).read_text().replace(remove_str, "") dest_data = header.format(source=source, this_file=this_file) + source_data - cmd.write_output_file_from_text(dest_data, dest) + with NamedTemporaryFile( + mode="w", + delete=False, + suffix=".py", + ) as buf: + buf.write(dest_data) + + cmd.run_black(buf.name) + cmd.write_output_file_from_tempfile(buf.name, dest) def main(file: str, cmd: code_writer_cmd) -> None: @@ -51,7 +60,11 @@ files = { "typed_annotation": { "source": "test/orm/declarative/test_typed_mapping.py", "dest": "test/orm/declarative/test_tm_future_annotations_sync.py", - } + }, + "dc_typed_annotation": { + "source": "test/orm/declarative/test_dc_transforms.py", + "dest": "test/orm/declarative/test_dc_transforms_future_anno_sync.py", + }, } if __name__ == "__main__": diff --git a/tox.ini b/tox.ini index 3012ec8748..a82d40812a 100644 --- a/tox.ini +++ b/tox.ini @@ -31,7 +31,7 @@ extras= # this can be limited to specific python versions IF there is no # greenlet available for the most recent python. otherwise # keep this present in all cases - py{38,39,310,311,312,313}: {[greenletextras]extras} + py{38,39,310,311,312,313,314}: {[greenletextras]extras} postgresql: postgresql postgresql: postgresql_pg8000 -- 2.47.2