]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
make anno-only Mapped[] column available for mixins
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 15 Jul 2022 16:25:22 +0000 (12:25 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 15 Jul 2022 16:37:06 +0000 (12:37 -0400)
Documentation is relying on the recently improved
behavior of produce_column_copies() to make sure everything is
available on cls for a declared_attr.  therefore for anno-only
attribute, we also need to generate the mapped_column() up front
before scan is called.

noticed in pylance, allow @declared_attr to recognize
@classmethod also which allows letting typing tools know
something is explicitly a classmethod

Change-Id: I07ff1a642a75679f685914a33c674807929f4918

lib/sqlalchemy/orm/decl_api.py
lib/sqlalchemy/orm/decl_base.py
test/orm/declarative/test_mixin.py
test/orm/declarative/test_typed_mapping.py

index e6e69a9e07b97c19852bde27548b4e25f9745a92..05d6dacfb702c1d73d66649b8d893b00d8a33e15 100644 (file)
@@ -324,6 +324,16 @@ class declared_attr(interfaces._MappedAttribute[_T]):
         fn: _DeclaredAttrDecorated[_T],
         cascading: bool = False,
     ):
+        # suppport
+        # @declared_attr
+        # @classmethod
+        # def foo(cls) -> Mapped[thing]:
+        #    ...
+        # which seems to help typing tools interpret the fn as a classmethod
+        # for situations where needed
+        if isinstance(fn, classmethod):
+            fn = fn.__func__  # type: ignore
+
         self.fget = fn
         self._cascading = cascading
         self.__doc__ = fn.__doc__
index 62251fa2b458c68f943515dcd3b47507312dd311..108027dd50fd980e1dd13abaae350de932ec76b7 100644 (file)
@@ -737,6 +737,7 @@ class _ClassScanMapperConfig(_MapperConfig):
                 locally_collected_columns = self._produce_column_copies(
                     local_attributes_for_class,
                     attribute_is_overridden,
+                    fixed_table,
                 )
             else:
                 locally_collected_columns = {}
@@ -828,9 +829,7 @@ class _ClassScanMapperConfig(_MapperConfig):
                     # acting like that for now.
 
                     if isinstance(obj, (Column, MappedColumn)):
-                        self._collect_annotation(
-                            name, annotation, is_dataclass_field, True, obj
-                        )
+                        self._collect_annotation(name, annotation, True, obj)
                         # already copied columns to the mapped class.
                         continue
                     elif isinstance(obj, MapperProperty):
@@ -913,23 +912,18 @@ class _ClassScanMapperConfig(_MapperConfig):
                         self._collect_annotation(
                             name,
                             obj._collect_return_annotation(),
-                            False,
                             True,
                             obj,
                         )
                     elif _is_mapped_annotation(annotation, cls):
-                        generated_obj = self._collect_annotation(
-                            name, annotation, is_dataclass_field, True, obj
-                        )
-                        if obj is None:
-                            if not fixed_table:
-                                collected_attributes[name] = (
-                                    generated_obj
-                                    if generated_obj is not None
-                                    else MappedColumn()
-                                )
-                        else:
-                            collected_attributes[name] = obj
+                        # Mapped annotation without any object.
+                        # product_column_copies should have handled this.
+                        # if future support for other MapperProperty,
+                        # then test if this name is already handled and
+                        # otherwise proceed to generate.
+                        if not fixed_table:
+                            assert name in collected_attributes
+                        continue
                     else:
                         # here, the attribute is some other kind of
                         # property that we assume is not part of the
@@ -953,12 +947,10 @@ class _ClassScanMapperConfig(_MapperConfig):
                         obj = obj.fget()
 
                     collected_attributes[name] = obj
-                    self._collect_annotation(
-                        name, annotation, True, False, obj
-                    )
+                    self._collect_annotation(name, annotation, False, obj)
                 else:
                     generated_obj = self._collect_annotation(
-                        name, annotation, False, None, obj
+                        name, annotation, None, obj
                     )
                     if (
                         obj is None
@@ -1060,7 +1052,6 @@ class _ClassScanMapperConfig(_MapperConfig):
         self,
         name: str,
         raw_annotation: _AnnotationScanType,
-        is_dataclass: bool,
         expect_mapped: Optional[bool],
         attr_value: Any,
     ) -> Any:
@@ -1128,6 +1119,7 @@ class _ClassScanMapperConfig(_MapperConfig):
             [], Iterable[Tuple[str, Any, Any, bool]]
         ],
         attribute_is_overridden: Callable[[str, Any], bool],
+        fixed_table: bool,
     ) -> Dict[str, Union[Column[Any], MappedColumn[Any]]]:
         cls = self.cls
         dict_ = self.clsdict_view
@@ -1136,7 +1128,19 @@ class _ClassScanMapperConfig(_MapperConfig):
         # copy mixin columns to the mapped class
 
         for name, obj, annotation, is_dataclass in attributes_for_class():
-            if isinstance(obj, (Column, MappedColumn)):
+            if (
+                not fixed_table
+                and obj is None
+                and _is_mapped_annotation(annotation, cls)
+            ):
+                obj = self._collect_annotation(name, annotation, True, obj)
+                if obj is None:
+                    obj = MappedColumn()
+
+                locally_collected_attributes[name] = obj
+                setattr(cls, name, obj)
+
+            elif isinstance(obj, (Column, MappedColumn)):
                 if attribute_is_overridden(name, obj):
                     # if column has been overridden
                     # (like by the InstrumentedAttribute of the
index 72e14ceebd06c7dd12d921422e574eeec1563cd3..a6851de5b449ca39a182121fe2c6575be25bd4da 100644 (file)
@@ -1,5 +1,7 @@
 from operator import is_not
 
+from typing_extensions import Annotated
+
 import sqlalchemy as sa
 from sqlalchemy import ForeignKey
 from sqlalchemy import func
@@ -21,6 +23,7 @@ from sqlalchemy.orm import declared_attr
 from sqlalchemy.orm import deferred
 from sqlalchemy.orm import events as orm_events
 from sqlalchemy.orm import has_inherited_table
+from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import registry
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import synonym
@@ -1646,6 +1649,93 @@ class DeclarativeMixinPropertyTest(
             m2,
         )
 
+    @testing.combinations(
+        "anno",
+        "anno_w_clsmeth",
+        "pep593",
+        "nonanno",
+        "legacy",
+        argnames="clstype",
+    )
+    def test_column_property_col_ref(self, decl_base, clstype):
+
+        if clstype == "anno":
+
+            class SomethingMixin:
+                x: Mapped[int]
+                y: Mapped[int] = mapped_column()
+
+                @declared_attr
+                def x_plus_y(cls) -> Mapped[int]:
+                    return column_property(cls.x + cls.y)
+
+        elif clstype == "anno_w_clsmeth":
+            # this form works better w/ pylance, so support it
+            class SomethingMixin:
+                x: Mapped[int]
+                y: Mapped[int] = mapped_column()
+
+                @declared_attr
+                @classmethod
+                def x_plus_y(cls) -> Mapped[int]:
+                    return column_property(cls.x + cls.y)
+
+        elif clstype == "nonanno":
+
+            class SomethingMixin:
+                x = mapped_column(Integer)
+                y = mapped_column(Integer)
+
+                @declared_attr
+                def x_plus_y(cls) -> Mapped[int]:
+                    return column_property(cls.x + cls.y)
+
+        elif clstype == "pep593":
+            myint = Annotated[int, mapped_column(Integer)]
+
+            class SomethingMixin:
+                x: Mapped[myint]
+                y: Mapped[myint]
+
+                @declared_attr
+                def x_plus_y(cls) -> Mapped[int]:
+                    return column_property(cls.x + cls.y)
+
+        elif clstype == "legacy":
+
+            class SomethingMixin:
+                x = Column(Integer)
+                y = Column(Integer)
+
+                @declared_attr
+                def x_plus_y(cls) -> Mapped[int]:
+                    return column_property(cls.x + cls.y)
+
+        else:
+            assert False
+
+        class Something(SomethingMixin, Base):
+            __tablename__ = "something"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+
+        class SomethingElse(SomethingMixin, Base):
+            __tablename__ = "something_else"
+
+            id: Mapped[int] = mapped_column(primary_key=True)
+
+        # use the mixin twice, make sure columns are copied, etc
+        self.assert_compile(
+            select(Something.x_plus_y),
+            "SELECT something.x + something.y AS anon_1 FROM something",
+        )
+
+        self.assert_compile(
+            select(SomethingElse.x_plus_y),
+            "SELECT something_else.x + something_else.y AS anon_1 "
+            "FROM something_else",
+        )
+
     def test_doc(self):
         """test documentation transfer.
 
index cd45d96d1510912ee8e7928eec0831b7f7004847..c33aef9c45d83cb37403019e5edafd8e0b72bb43 100644 (file)
@@ -815,10 +815,9 @@ class MixinTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             __tablename__ = "a"
             id: Mapped[int] = mapped_column(primary_key=True)
 
-        # ordering of cols is TODO
-        eq_(A.__table__.c.keys(), ["id", "y", "name", "x"])
+        eq_(A.__table__.c.keys(), ["id", "name", "x", "y"])
 
-        self.assert_compile(select(A), "SELECT a.id, a.y, a.name, a.x FROM a")
+        self.assert_compile(select(A), "SELECT a.id, a.name, a.x, a.y FROM a")
 
     def test_mapped_column_omit_fn_fixed_table(self, decl_base):
         class MixinOne: