]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
some typing fixes
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 3 Jun 2022 14:34:19 +0000 (10:34 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 3 Jun 2022 17:29:57 +0000 (13:29 -0400)
* ClassVar for decl fields, add __tablename__
* dataclasses require annotations for all fields.  For us,
  if no annotation, then skip that field as part of what is
  considered to be a "dataclass", as this matches the behavior
  of pyright right now.   We could alternatively raise on this
  use, which is what dataclasses does.   we should ask the pep
  people
* plain field that's just "str", "int", etc., with no value.
  Disallow it unless __allow_unmapped__ is set.   If field
  has dataclasses.field, Column, None, a value etc, it goes through,
  and when using dataclasses mixin all such fields are considered
  for the dataclass setup just like a dataclass.  Hopefully this
  does not have major backwards compat issues.  __allow_unmapped__
  can be set on the base class, mixins, etc., it's liberal for
  now in case people have this problem.
* accommodate for ClassVar, these are not considered at all for
  mapping.

Change-Id: Id743aa0456bade9a5d5832796caeecc3dc4accb7

lib/sqlalchemy/orm/decl_api.py
lib/sqlalchemy/orm/decl_base.py
lib/sqlalchemy/orm/util.py
test/orm/declarative/test_dc_transforms.py
test/orm/declarative/test_typed_mapping.py

index feeda98f83851e951b7784ca58686b4405a89068..9c095c7401a7f7bbdd1a05d3eaa00d66619947ff 100644 (file)
@@ -683,13 +683,15 @@ class DeclarativeBase(
 
     """
 
-    registry: ClassVar[_RegistryType]
-    _sa_registry: ClassVar[_RegistryType]
-    metadata: ClassVar[MetaData]
-    __mapper__: ClassVar[Mapper[Any]]
-    __table__: Optional[FromClause]
-
     if typing.TYPE_CHECKING:
+        registry: ClassVar[_RegistryType]
+        _sa_registry: ClassVar[_RegistryType]
+        metadata: ClassVar[MetaData]
+
+        __mapper__: ClassVar[Mapper[Any]]
+        __table__: ClassVar[Optional[FromClause]]
+
+        __tablename__: ClassVar[Optional[str]]
 
         def __init__(self, **kw: Any):
             ...
index 1e7c0eaf6ab51446ed91af37568ead541fd6c157..ce044d7e0e975d365647f987f24c9769349bcc1b 100644 (file)
@@ -411,6 +411,8 @@ class _ClassScanMapperConfig(_MapperConfig):
         "inherits",
         "allow_dataclass_fields",
         "dataclass_setup_arguments",
+        "is_dataclass_prior_to_mapping",
+        "allow_unmapped_annotations",
     )
 
     registry: _RegistryType
@@ -430,6 +432,9 @@ class _ClassScanMapperConfig(_MapperConfig):
     mapper_args_fn: Optional[Callable[[], Dict[str, Any]]]
     inherits: Optional[Type[Any]]
 
+    is_dataclass_prior_to_mapping: bool
+    allow_unmapped_annotations: bool
+
     dataclass_setup_arguments: Optional[_DataclassArguments]
     """if the class has SQLAlchemy native dataclass parameters, where
     we will turn the class into a dataclass within the declarative mapping
@@ -440,7 +445,12 @@ class _ClassScanMapperConfig(_MapperConfig):
     allow_dataclass_fields: bool
     """if true, look for dataclass-processed Field objects on the target
     class as well as superclasses and extract ORM mapping directives from
-    the "metadata" attribute of each Field"""
+    the "metadata" attribute of each Field.
+
+    if False, dataclass fields can still be used, however they won't be
+    mapped.
+
+    """
 
     def __init__(
         self,
@@ -469,7 +479,13 @@ class _ClassScanMapperConfig(_MapperConfig):
             self.cls, "_sa_apply_dc_transforms", None
         )
 
-        cld = dataclasses.is_dataclass(cls_)
+        self.allow_unmapped_annotations = getattr(
+            self.cls, "__allow_unmapped__", False
+        )
+
+        self.is_dataclass_prior_to_mapping = cld = dataclasses.is_dataclass(
+            cls_
+        )
 
         sdk = _get_immediate_cls_attr(cls_, "__sa_dataclass_metadata_key__")
 
@@ -1007,19 +1023,39 @@ class _ClassScanMapperConfig(_MapperConfig):
         expect_mapped: Optional[bool],
         attr_value: Any,
     ) -> None:
+        if raw_annotation is None:
+            return
+
+        is_dataclass = self.is_dataclass_prior_to_mapping
+        allow_unmapped = self.allow_unmapped_annotations
 
         if expect_mapped is None:
-            expect_mapped = isinstance(attr_value, _MappedAttribute)
+            is_dataclass_field = isinstance(attr_value, dataclasses.Field)
+            expect_mapped = (
+                not is_dataclass_field
+                and not allow_unmapped
+                and (
+                    attr_value is None
+                    or isinstance(attr_value, _MappedAttribute)
+                )
+            )
+        else:
+            is_dataclass_field = False
 
+        is_dataclass_field = False
         extracted_mapped_annotation = _extract_mapped_subtype(
             raw_annotation,
             self.cls,
             name,
             type(attr_value),
             required=False,
-            is_dataclass_field=False,
-            expect_mapped=expect_mapped and not self.allow_dataclass_fields,
+            is_dataclass_field=is_dataclass_field,
+            expect_mapped=expect_mapped
+            and not is_dataclass,  # self.allow_dataclass_fields,
         )
+        if extracted_mapped_annotation is None:
+            # ClassVar can come out here
+            return
 
         self.collected_annotations[name] = (
             raw_annotation,
index c8c802d9f24853fb8e9b063ba38251a5219c471a..e1e78c99da94e735e36a8b084aa74ad914c4cdc4 100644 (file)
@@ -1970,6 +1970,7 @@ def _extract_mapped_subtype(
     required: bool,
     is_dataclass_field: bool,
     expect_mapped: bool = True,
+    raiseerr: bool = True,
 ) -> Optional[Union[type, str]]:
     """given an annotation, figure out if it's ``Mapped[something]`` and if
     so, return the ``something`` part.
@@ -2008,12 +2009,35 @@ def _extract_mapped_subtype(
                 our_annotated_str = anno_name
 
             if expect_mapped:
-                raise sa_exc.ArgumentError(
-                    f'Type annotation for "{cls.__name__}.{key}" '
-                    "should use the "
-                    f'syntax "Mapped[{our_annotated_str}]" or '
-                    f'"{attr_cls.__name__}[{our_annotated_str}]".'
-                )
+                if getattr(annotated, "__origin__", None) is typing.ClassVar:
+                    return None
+
+                if not raiseerr:
+                    return None
+
+                if attr_cls.__name__ == our_annotated_str or attr_cls is type(
+                    None
+                ):
+                    raise sa_exc.ArgumentError(
+                        f'Type annotation for "{cls.__name__}.{key}" '
+                        "should use the "
+                        f'syntax "Mapped[{our_annotated_str}]".  To leave '
+                        f"the attribute unmapped, use "
+                        f"ClassVar[{our_annotated_str}], assign a value to "
+                        f"the attribute, or "
+                        f"set __allow_unmapped__ = True on the class."
+                    )
+                else:
+                    raise sa_exc.ArgumentError(
+                        f'Type annotation for "{cls.__name__}.{key}" '
+                        "should use the "
+                        f'syntax "Mapped[{our_annotated_str}]" or '
+                        f'"{attr_cls.__name__}[{our_annotated_str}]".  To '
+                        f"leave the attribute unmapped, use "
+                        f"ClassVar[{our_annotated_str}], assign a value to "
+                        f"the attribute, or "
+                        f"set __allow_unmapped__ = True on the class."
+                    )
 
             else:
                 return annotated
index 308ebfeb17aa6b8b9b1707f1897c038ca1aef371..271b235966bf7c65aab3e5b72a6c207c2be48f3d 100644 (file)
@@ -2,6 +2,7 @@ import dataclasses
 import inspect as pyinspect
 from itertools import product
 from typing import Any
+from typing import ClassVar
 from typing import List
 from typing import Optional
 from typing import Set
@@ -59,8 +60,10 @@ class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase):
             __tablename__ = "b"
 
             id: Mapped[int] = mapped_column(primary_key=True, init=False)
-            a_id = mapped_column(ForeignKey("a.id"), init=False)
             data: Mapped[str]
+            a_id: Mapped[Optional[int]] = mapped_column(
+                ForeignKey("a.id"), init=False
+            )
             x: Mapped[Optional[int]] = mapped_column(default=None)
 
         A.__qualname__ = "some_module.A"
@@ -102,6 +105,32 @@ class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase):
         a3 = A("data")
         eq_(repr(a3), "some_module.A(id=None, data='data', x=None, bs=[])")
 
+    def test_no_anno_doesnt_go_into_dc(
+        self, dc_decl_base: Type[MappedAsDataclass]
+    ):
+        class User(dc_decl_base):
+            __tablename__: ClassVar[Optional[str]] = "user"
+
+            id: Mapped[int] = mapped_column(primary_key=True, init=False)
+            username: Mapped[str]
+            password: Mapped[str]
+            addresses: Mapped[List["Address"]] = relationship(  # noqa: F821
+                default_factory=list
+            )
+
+        class Address(dc_decl_base):
+            __tablename__: ClassVar[Optional[str]] = "address"
+
+            id: Mapped[int] = mapped_column(primary_key=True, init=False)
+
+            # should not be in the dataclass constructor
+            user_id = mapped_column(ForeignKey(User.id))
+
+            email_address: Mapped[str]
+
+        a1 = Address("email@address")
+        eq_(a1.email_address, "email@address")
+
     def test_basic_constructor_repr_cls_decorator(
         self, registry: _RegistryType
     ):
@@ -156,11 +185,13 @@ class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase):
         )
 
         a2 = A("10", x=5, bs=[B("data1"), B("data2", x=12)])
+
+        # note a_id isn't included because it wasn't annotated
         eq_(
             repr(a2),
             "some_module.A(id=None, data='10', x=5, "
-            "bs=[some_module.B(id=None, data='data1', a_id=None, x=None), "
-            "some_module.B(id=None, data='data2', a_id=None, x=12)])",
+            "bs=[some_module.B(id=None, data='data1', x=None), "
+            "some_module.B(id=None, data='data2', x=12)])",
         )
 
         a3 = A("data")
@@ -224,6 +255,50 @@ class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase):
         eq_(e1.engineer_name, "en")
         eq_(e1.primary_language, "pl")
 
+    def test_no_fields_wo_mapped_or_dc(
+        self, dc_decl_base: Type[MappedAsDataclass]
+    ):
+        """since I made this mistake in my own mapping video, lets have it
+        raise an error"""
+
+        with expect_raises_message(
+            exc.ArgumentError,
+            r'Type annotation for "A.data" should '
+            r'use the syntax "Mapped\[str\]".  '
+            r"To leave the attribute unmapped,",
+        ):
+
+            class A(dc_decl_base):
+                __tablename__ = "a"
+
+                id: Mapped[int] = mapped_column(primary_key=True, init=False)
+                data: str
+                ctrl_one: str = dataclasses.field()
+                some_field: int = dataclasses.field(default=5)
+
+    def test_allow_unmapped_fields_wo_mapped_or_dc(
+        self, dc_decl_base: Type[MappedAsDataclass]
+    ):
+        class A(dc_decl_base):
+            __tablename__ = "a"
+            __allow_unmapped__ = True
+
+            id: Mapped[int] = mapped_column(primary_key=True, init=False)
+            data: str
+            ctrl_one: str = dataclasses.field()
+            some_field: int = dataclasses.field(default=5)
+
+        a1 = A("data", "ctrl_one", 5)
+        eq_(
+            dataclasses.asdict(a1),
+            {
+                "ctrl_one": "ctrl_one",
+                "data": "data",
+                "id": None,
+                "some_field": 5,
+            },
+        )
+
     def test_integrated_dc(self, dc_decl_base: Type[MappedAsDataclass]):
         """We will be telling users "this is a dataclass that is also
         mapped". Therefore, they will want *any* kind of attribute to do what
@@ -237,17 +312,48 @@ class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase):
         class A(dc_decl_base):
             __tablename__ = "a"
 
-            ctrl_one: str
+            ctrl_one: str = dataclasses.field()
 
             id: Mapped[int] = mapped_column(primary_key=True, init=False)
             data: Mapped[str]
             some_field: int = dataclasses.field(default=5)
 
-            some_none_field: Optional[str] = None
+            some_none_field: Optional[str] = dataclasses.field(default=None)
+
+            some_other_int_field: int = 10
 
+        # some field is part of the constructor
         a1 = A("ctrlone", "datafield")
-        eq_(a1.some_field, 5)
-        eq_(a1.some_none_field, None)
+        eq_(
+            dataclasses.asdict(a1),
+            {
+                "ctrl_one": "ctrlone",
+                "data": "datafield",
+                "id": None,
+                "some_field": 5,
+                "some_none_field": None,
+                "some_other_int_field": 10,
+            },
+        )
+
+        a2 = A(
+            "ctrlone",
+            "datafield",
+            some_field=7,
+            some_other_int_field=12,
+            some_none_field="x",
+        )
+        eq_(
+            dataclasses.asdict(a2),
+            {
+                "ctrl_one": "ctrlone",
+                "data": "datafield",
+                "id": None,
+                "some_field": 7,
+                "some_none_field": "x",
+                "some_other_int_field": 12,
+            },
+        )
 
         # only Mapped[] is mapped
         self.assert_compile(select(A), "SELECT a.id, a.data FROM a")
@@ -260,10 +366,11 @@ class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase):
                     "data",
                     "some_field",
                     "some_none_field",
+                    "some_other_int_field",
                 ],
                 varargs=None,
                 varkw=None,
-                defaults=(5, None),
+                defaults=(5, None, 10),
                 kwonlyargs=[],
                 kwonlydefaults=None,
                 annotations={},
index ce8cd6bdf2e8afc0a2a051e229f3f2410c326710..01849a8ee9f216b8efbc9227dc2b613b7c07b67b 100644 (file)
@@ -1,6 +1,7 @@
 import dataclasses
 import datetime
 from decimal import Decimal
+from typing import ClassVar
 from typing import Dict
 from typing import Generic
 from typing import List
@@ -69,7 +70,9 @@ class DeclarativeBaseTest(fixtures.TestBase):
 
         class Tab(Base["Tab"]):
             __tablename__ = "foo"
-            a = Column(Integer, primary_key=True)
+
+            # old mypy plugin use
+            a: int = Column(Integer, primary_key=True)
 
         eq_(Tab.foo, 1)
         is_(Tab.__table__, inspect(Tab).local_table)
@@ -192,6 +195,88 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         is_true(User.__table__.c.data.nullable)
         assert isinstance(User.__table__.c.created_at.type, DateTime)
 
+    def test_i_have_a_classvar_on_my_class(self, decl_base):
+        class MyClass(decl_base):
+            __tablename__ = "mytable"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            data: Mapped[str] = mapped_column(default="some default")
+
+            status: ClassVar[int]
+
+        m1 = MyClass(id=1, data=5)
+        assert "status" not in inspect(m1).mapper.attrs
+
+    def test_i_have_plain_or_column_attrs_on_my_class_w_values(
+        self, decl_base
+    ):
+        class MyClass(decl_base):
+            __tablename__ = "mytable"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            data: Mapped[str] = mapped_column(default="some default")
+
+            old_column: str = Column(String)
+
+            # we assume this is intentional
+            status: int = 5
+
+        # it's mapped too
+        assert "old_column" in inspect(MyClass).attrs
+
+    def test_i_have_plain_attrs_on_my_class_disallowed(self, decl_base):
+        with expect_raises_message(
+            sa_exc.ArgumentError,
+            r'Type annotation for "MyClass.status" should use the syntax '
+            r'"Mapped\[int\]".  To leave the attribute unmapped, use '
+            r"ClassVar\[int\], assign a value to the attribute, or "
+            r"set __allow_unmapped__ = True on the class.",
+        ):
+
+            class MyClass(decl_base):
+                __tablename__ = "mytable"
+
+                id: Mapped[int] = mapped_column(primary_key=True)
+                data: Mapped[str] = mapped_column(default="some default")
+
+                # we assume this is not intentional.  because I made the
+                # same mistake myself :)
+                status: int
+
+    def test_i_have_plain_attrs_on_my_class_allowed(self, decl_base):
+        class MyClass(decl_base):
+            __tablename__ = "mytable"
+            __allow_unmapped__ = True
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            data: Mapped[str] = mapped_column(default="some default")
+
+            status: int
+
+    def test_allow_unmapped_on_mixin(self, decl_base):
+        class AllowsUnmapped:
+            __allow_unmapped__ = True
+
+        class MyClass(AllowsUnmapped, decl_base):
+            __tablename__ = "mytable"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            data: Mapped[str] = mapped_column(default="some default")
+
+            status: int
+
+    def test_allow_unmapped_on_base(self):
+        class Base(DeclarativeBase):
+            __allow_unmapped__ = True
+
+        class MyClass(Base):
+            __tablename__ = "mytable"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+            data: Mapped[str] = mapped_column(default="some default")
+
+            status: int
+
     def test_column_default(self, decl_base):
         class MyClass(decl_base):
             __tablename__ = "mytable"