From: Mike Bayer Date: Fri, 3 Jun 2022 14:34:19 +0000 (-0400) Subject: some typing fixes X-Git-Tag: rel_2_0_0b1~270 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=47eff8b9e35dec9305d22484c17dd6c0649a876a;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git some typing fixes * ClassVar for decl fields, add __tablename__ * dataclasses require annotations for all fields. For us, if no annotation, then skip that field as part of what is considered to be a "dataclass", as this matches the behavior of pyright right now. We could alternatively raise on this use, which is what dataclasses does. we should ask the pep people * plain field that's just "str", "int", etc., with no value. Disallow it unless __allow_unmapped__ is set. If field has dataclasses.field, Column, None, a value etc, it goes through, and when using dataclasses mixin all such fields are considered for the dataclass setup just like a dataclass. Hopefully this does not have major backwards compat issues. __allow_unmapped__ can be set on the base class, mixins, etc., it's liberal for now in case people have this problem. * accommodate for ClassVar, these are not considered at all for mapping. Change-Id: Id743aa0456bade9a5d5832796caeecc3dc4accb7 --- diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index feeda98f83..9c095c7401 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -683,13 +683,15 @@ class DeclarativeBase( """ - registry: ClassVar[_RegistryType] - _sa_registry: ClassVar[_RegistryType] - metadata: ClassVar[MetaData] - __mapper__: ClassVar[Mapper[Any]] - __table__: Optional[FromClause] - if typing.TYPE_CHECKING: + registry: ClassVar[_RegistryType] + _sa_registry: ClassVar[_RegistryType] + metadata: ClassVar[MetaData] + + __mapper__: ClassVar[Mapper[Any]] + __table__: ClassVar[Optional[FromClause]] + + __tablename__: ClassVar[Optional[str]] def __init__(self, **kw: Any): ... diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index 1e7c0eaf6a..ce044d7e0e 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -411,6 +411,8 @@ class _ClassScanMapperConfig(_MapperConfig): "inherits", "allow_dataclass_fields", "dataclass_setup_arguments", + "is_dataclass_prior_to_mapping", + "allow_unmapped_annotations", ) registry: _RegistryType @@ -430,6 +432,9 @@ class _ClassScanMapperConfig(_MapperConfig): mapper_args_fn: Optional[Callable[[], Dict[str, Any]]] inherits: Optional[Type[Any]] + is_dataclass_prior_to_mapping: bool + allow_unmapped_annotations: bool + dataclass_setup_arguments: Optional[_DataclassArguments] """if the class has SQLAlchemy native dataclass parameters, where we will turn the class into a dataclass within the declarative mapping @@ -440,7 +445,12 @@ class _ClassScanMapperConfig(_MapperConfig): allow_dataclass_fields: bool """if true, look for dataclass-processed Field objects on the target class as well as superclasses and extract ORM mapping directives from - the "metadata" attribute of each Field""" + the "metadata" attribute of each Field. + + if False, dataclass fields can still be used, however they won't be + mapped. + + """ def __init__( self, @@ -469,7 +479,13 @@ class _ClassScanMapperConfig(_MapperConfig): self.cls, "_sa_apply_dc_transforms", None ) - cld = dataclasses.is_dataclass(cls_) + self.allow_unmapped_annotations = getattr( + self.cls, "__allow_unmapped__", False + ) + + self.is_dataclass_prior_to_mapping = cld = dataclasses.is_dataclass( + cls_ + ) sdk = _get_immediate_cls_attr(cls_, "__sa_dataclass_metadata_key__") @@ -1007,19 +1023,39 @@ class _ClassScanMapperConfig(_MapperConfig): expect_mapped: Optional[bool], attr_value: Any, ) -> None: + if raw_annotation is None: + return + + is_dataclass = self.is_dataclass_prior_to_mapping + allow_unmapped = self.allow_unmapped_annotations if expect_mapped is None: - expect_mapped = isinstance(attr_value, _MappedAttribute) + is_dataclass_field = isinstance(attr_value, dataclasses.Field) + expect_mapped = ( + not is_dataclass_field + and not allow_unmapped + and ( + attr_value is None + or isinstance(attr_value, _MappedAttribute) + ) + ) + else: + is_dataclass_field = False + is_dataclass_field = False extracted_mapped_annotation = _extract_mapped_subtype( raw_annotation, self.cls, name, type(attr_value), required=False, - is_dataclass_field=False, - expect_mapped=expect_mapped and not self.allow_dataclass_fields, + is_dataclass_field=is_dataclass_field, + expect_mapped=expect_mapped + and not is_dataclass, # self.allow_dataclass_fields, ) + if extracted_mapped_annotation is None: + # ClassVar can come out here + return self.collected_annotations[name] = ( raw_annotation, diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index c8c802d9f2..e1e78c99da 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -1970,6 +1970,7 @@ def _extract_mapped_subtype( required: bool, is_dataclass_field: bool, expect_mapped: bool = True, + raiseerr: bool = True, ) -> Optional[Union[type, str]]: """given an annotation, figure out if it's ``Mapped[something]`` and if so, return the ``something`` part. @@ -2008,12 +2009,35 @@ def _extract_mapped_subtype( our_annotated_str = anno_name if expect_mapped: - raise sa_exc.ArgumentError( - f'Type annotation for "{cls.__name__}.{key}" ' - "should use the " - f'syntax "Mapped[{our_annotated_str}]" or ' - f'"{attr_cls.__name__}[{our_annotated_str}]".' - ) + if getattr(annotated, "__origin__", None) is typing.ClassVar: + return None + + if not raiseerr: + return None + + if attr_cls.__name__ == our_annotated_str or attr_cls is type( + None + ): + raise sa_exc.ArgumentError( + f'Type annotation for "{cls.__name__}.{key}" ' + "should use the " + f'syntax "Mapped[{our_annotated_str}]". To leave ' + f"the attribute unmapped, use " + f"ClassVar[{our_annotated_str}], assign a value to " + f"the attribute, or " + f"set __allow_unmapped__ = True on the class." + ) + else: + raise sa_exc.ArgumentError( + f'Type annotation for "{cls.__name__}.{key}" ' + "should use the " + f'syntax "Mapped[{our_annotated_str}]" or ' + f'"{attr_cls.__name__}[{our_annotated_str}]". To ' + f"leave the attribute unmapped, use " + f"ClassVar[{our_annotated_str}], assign a value to " + f"the attribute, or " + f"set __allow_unmapped__ = True on the class." + ) else: return annotated diff --git a/test/orm/declarative/test_dc_transforms.py b/test/orm/declarative/test_dc_transforms.py index 308ebfeb17..271b235966 100644 --- a/test/orm/declarative/test_dc_transforms.py +++ b/test/orm/declarative/test_dc_transforms.py @@ -2,6 +2,7 @@ import dataclasses import inspect as pyinspect from itertools import product from typing import Any +from typing import ClassVar from typing import List from typing import Optional from typing import Set @@ -59,8 +60,10 @@ class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase): __tablename__ = "b" id: Mapped[int] = mapped_column(primary_key=True, init=False) - a_id = mapped_column(ForeignKey("a.id"), init=False) data: Mapped[str] + a_id: Mapped[Optional[int]] = mapped_column( + ForeignKey("a.id"), init=False + ) x: Mapped[Optional[int]] = mapped_column(default=None) A.__qualname__ = "some_module.A" @@ -102,6 +105,32 @@ class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase): a3 = A("data") eq_(repr(a3), "some_module.A(id=None, data='data', x=None, bs=[])") + def test_no_anno_doesnt_go_into_dc( + self, dc_decl_base: Type[MappedAsDataclass] + ): + class User(dc_decl_base): + __tablename__: ClassVar[Optional[str]] = "user" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + username: Mapped[str] + password: Mapped[str] + addresses: Mapped[List["Address"]] = relationship( # noqa: F821 + default_factory=list + ) + + class Address(dc_decl_base): + __tablename__: ClassVar[Optional[str]] = "address" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + + # should not be in the dataclass constructor + user_id = mapped_column(ForeignKey(User.id)) + + email_address: Mapped[str] + + a1 = Address("email@address") + eq_(a1.email_address, "email@address") + def test_basic_constructor_repr_cls_decorator( self, registry: _RegistryType ): @@ -156,11 +185,13 @@ class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase): ) a2 = A("10", x=5, bs=[B("data1"), B("data2", x=12)]) + + # note a_id isn't included because it wasn't annotated eq_( repr(a2), "some_module.A(id=None, data='10', x=5, " - "bs=[some_module.B(id=None, data='data1', a_id=None, x=None), " - "some_module.B(id=None, data='data2', a_id=None, x=12)])", + "bs=[some_module.B(id=None, data='data1', x=None), " + "some_module.B(id=None, data='data2', x=12)])", ) a3 = A("data") @@ -224,6 +255,50 @@ class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase): eq_(e1.engineer_name, "en") eq_(e1.primary_language, "pl") + def test_no_fields_wo_mapped_or_dc( + self, dc_decl_base: Type[MappedAsDataclass] + ): + """since I made this mistake in my own mapping video, lets have it + raise an error""" + + with expect_raises_message( + exc.ArgumentError, + r'Type annotation for "A.data" should ' + r'use the syntax "Mapped\[str\]". ' + r"To leave the attribute unmapped,", + ): + + class A(dc_decl_base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: str + ctrl_one: str = dataclasses.field() + some_field: int = dataclasses.field(default=5) + + def test_allow_unmapped_fields_wo_mapped_or_dc( + self, dc_decl_base: Type[MappedAsDataclass] + ): + class A(dc_decl_base): + __tablename__ = "a" + __allow_unmapped__ = True + + id: Mapped[int] = mapped_column(primary_key=True, init=False) + data: str + ctrl_one: str = dataclasses.field() + some_field: int = dataclasses.field(default=5) + + a1 = A("data", "ctrl_one", 5) + eq_( + dataclasses.asdict(a1), + { + "ctrl_one": "ctrl_one", + "data": "data", + "id": None, + "some_field": 5, + }, + ) + def test_integrated_dc(self, dc_decl_base: Type[MappedAsDataclass]): """We will be telling users "this is a dataclass that is also mapped". Therefore, they will want *any* kind of attribute to do what @@ -237,17 +312,48 @@ class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase): class A(dc_decl_base): __tablename__ = "a" - ctrl_one: str + ctrl_one: str = dataclasses.field() id: Mapped[int] = mapped_column(primary_key=True, init=False) data: Mapped[str] some_field: int = dataclasses.field(default=5) - some_none_field: Optional[str] = None + some_none_field: Optional[str] = dataclasses.field(default=None) + + some_other_int_field: int = 10 + # some field is part of the constructor a1 = A("ctrlone", "datafield") - eq_(a1.some_field, 5) - eq_(a1.some_none_field, None) + eq_( + dataclasses.asdict(a1), + { + "ctrl_one": "ctrlone", + "data": "datafield", + "id": None, + "some_field": 5, + "some_none_field": None, + "some_other_int_field": 10, + }, + ) + + a2 = A( + "ctrlone", + "datafield", + some_field=7, + some_other_int_field=12, + some_none_field="x", + ) + eq_( + dataclasses.asdict(a2), + { + "ctrl_one": "ctrlone", + "data": "datafield", + "id": None, + "some_field": 7, + "some_none_field": "x", + "some_other_int_field": 12, + }, + ) # only Mapped[] is mapped self.assert_compile(select(A), "SELECT a.id, a.data FROM a") @@ -260,10 +366,11 @@ class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase): "data", "some_field", "some_none_field", + "some_other_int_field", ], varargs=None, varkw=None, - defaults=(5, None), + defaults=(5, None, 10), kwonlyargs=[], kwonlydefaults=None, annotations={}, diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py index ce8cd6bdf2..01849a8ee9 100644 --- a/test/orm/declarative/test_typed_mapping.py +++ b/test/orm/declarative/test_typed_mapping.py @@ -1,6 +1,7 @@ import dataclasses import datetime from decimal import Decimal +from typing import ClassVar from typing import Dict from typing import Generic from typing import List @@ -69,7 +70,9 @@ class DeclarativeBaseTest(fixtures.TestBase): class Tab(Base["Tab"]): __tablename__ = "foo" - a = Column(Integer, primary_key=True) + + # old mypy plugin use + a: int = Column(Integer, primary_key=True) eq_(Tab.foo, 1) is_(Tab.__table__, inspect(Tab).local_table) @@ -192,6 +195,88 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): is_true(User.__table__.c.data.nullable) assert isinstance(User.__table__.c.created_at.type, DateTime) + def test_i_have_a_classvar_on_my_class(self, decl_base): + class MyClass(decl_base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] = mapped_column(default="some default") + + status: ClassVar[int] + + m1 = MyClass(id=1, data=5) + assert "status" not in inspect(m1).mapper.attrs + + def test_i_have_plain_or_column_attrs_on_my_class_w_values( + self, decl_base + ): + class MyClass(decl_base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] = mapped_column(default="some default") + + old_column: str = Column(String) + + # we assume this is intentional + status: int = 5 + + # it's mapped too + assert "old_column" in inspect(MyClass).attrs + + def test_i_have_plain_attrs_on_my_class_disallowed(self, decl_base): + with expect_raises_message( + sa_exc.ArgumentError, + r'Type annotation for "MyClass.status" should use the syntax ' + r'"Mapped\[int\]". To leave the attribute unmapped, use ' + r"ClassVar\[int\], assign a value to the attribute, or " + r"set __allow_unmapped__ = True on the class.", + ): + + class MyClass(decl_base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] = mapped_column(default="some default") + + # we assume this is not intentional. because I made the + # same mistake myself :) + status: int + + def test_i_have_plain_attrs_on_my_class_allowed(self, decl_base): + class MyClass(decl_base): + __tablename__ = "mytable" + __allow_unmapped__ = True + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] = mapped_column(default="some default") + + status: int + + def test_allow_unmapped_on_mixin(self, decl_base): + class AllowsUnmapped: + __allow_unmapped__ = True + + class MyClass(AllowsUnmapped, decl_base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] = mapped_column(default="some default") + + status: int + + def test_allow_unmapped_on_base(self): + class Base(DeclarativeBase): + __allow_unmapped__ = True + + class MyClass(Base): + __tablename__ = "mytable" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] = mapped_column(default="some default") + + status: int + def test_column_default(self, decl_base): class MyClass(decl_base): __tablename__ = "mytable"