]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Adjust dataclass rules to account for field w/ default
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 18 Mar 2021 19:07:03 +0000 (15:07 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 18 Mar 2021 19:08:26 +0000 (15:08 -0400)
Fixed issue in new ORM dataclasses functionality where dataclass fields on
an abstract base or mixin that contained column or other mapping constructs
would not be mapped if they also included a "default" key within the
dataclasses.field() object.

Fixes: #6093
Change-Id: I628086ceb48ab1dd0702f239cd12be74074f58f1

doc/build/changelog/unreleased_14/6093.rst [new file with mode: 0644]
lib/sqlalchemy/orm/decl_base.py
test/orm/test_dataclasses_py3k.py

diff --git a/doc/build/changelog/unreleased_14/6093.rst b/doc/build/changelog/unreleased_14/6093.rst
new file mode 100644 (file)
index 0000000..95e4af6
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: bug, orm, dataclasses
+    :tickets: 6093
+
+    Fixed issue in new ORM dataclasses functionality where dataclass fields on
+    an abstract base or mixin that contained column or other mapping constructs
+    would not be mapped if they also included a "default" key within the
+    dataclasses.field() object.
+
index a21af192e0a0baf7006778863abdc6202a08c951..0a73288fd6421256f7acbd61b0ee134c7c21f6d4 100644 (file)
@@ -15,6 +15,7 @@ from sqlalchemy.orm import instrumentation
 from . import clsregistry
 from . import exc as orm_exc
 from . import mapper as mapperlib
+from .attributes import InstrumentedAttribute
 from .attributes import QueryableAttribute
 from .base import _is_mapped_class
 from .base import InspectionAttr
@@ -366,18 +367,24 @@ class _ClassScanMapperConfig(_MapperConfig):
                 elif ret is not absent:
                     return True
 
+                all_field = all_datacls_fields.get(key, absent)
+
                 ret = getattr(cls, key, obj)
 
                 if ret is obj:
                     return False
-                elif ret is not absent:
-                    return True
 
-                ret = all_datacls_fields.get(key, absent)
+                # for dataclasses, this could be the
+                # 'default' of the field.  so filter more specifically
+                # for an already-mapped InstrumentedAttribute
+                if ret is not absent and isinstance(
+                    ret, InstrumentedAttribute
+                ):
+                    return True
 
-                if ret is obj:
+                if all_field is obj:
                     return False
-                elif ret is not absent:
+                elif all_field is not absent:
                     return True
 
                 # can't find another attribute
@@ -401,15 +408,18 @@ class _ClassScanMapperConfig(_MapperConfig):
                     yield name, obj
 
         else:
+            field_names = set()
 
             def local_attributes_for_class():
-                for name, obj in vars(cls).items():
-                    yield name, obj
                 for field in util.local_dataclass_fields(cls):
                     if sa_dataclass_metadata_key in field.metadata:
+                        field_names.add(field.name)
                         yield field.name, field.metadata[
                             sa_dataclass_metadata_key
                         ]
+                for name, obj in vars(cls).items():
+                    if name not in field_names:
+                        yield name, obj
 
         return local_attributes_for_class
 
index 51c76c125a4760191c67bd234e1b43ab4eae92b3..ef1c12050edf4efe78ab8677eff7988ccc268948 100644 (file)
@@ -521,7 +521,7 @@ class FieldEmbeddedWMixinTest(FieldEmbeddedDeclarativeDataclassesTest):
         eq_(dataclasses.astuple(widget), (None, "Bar", True))
 
 
-class PropagationBlockTest(fixtures.TestBase):
+class PropagationFromMixinTest(fixtures.TestBase):
     __requires__ = ("dataclasses",)
 
     run_setup_classes = "each"
@@ -559,6 +559,25 @@ class PropagationBlockTest(fixtures.TestBase):
 
         run_test(CommonMixin)
 
+    def test_propagate_w_field_mixin_col_and_default(self, run_test):
+        @dataclasses.dataclass
+        class CommonMixin:
+            __sa_dataclass_metadata_key__ = "sa"
+
+            @declared_attr
+            def __tablename__(cls):
+                return cls.__name__.lower()
+
+            __table_args__ = {"mysql_engine": "InnoDB"}
+
+            timestamp: int = dataclasses.field(
+                init=False,
+                default=12,
+                metadata={"sa": Column(Integer, nullable=False)},
+            )
+
+        run_test(CommonMixin)
+
     @testing.fixture()
     def run_test(self):
         def go(CommonMixin):
@@ -603,3 +622,99 @@ class PropagationBlockTest(fixtures.TestBase):
         yield go
 
         clear_mappers()
+
+
+class PropagationFromAbstractTest(fixtures.TestBase):
+    __requires__ = ("dataclasses",)
+
+    run_setup_classes = "each"
+    run_setup_mappers = "each"
+
+    def test_propagate_w_plain_mixin_col(self, run_test):
+        @dataclasses.dataclass
+        class BaseType:
+            __sa_dataclass_metadata_key__ = "sa"
+
+            __table_args__ = {"mysql_engine": "InnoDB"}
+
+            discriminator: str = Column("type", String(50))
+            __mapper_args__ = dict(polymorphic_on=discriminator)
+            id: int = Column(Integer, primary_key=True)
+            value: int = Column(Integer())
+
+            timestamp: int = Column(Integer)
+
+        run_test(BaseType)
+
+    def test_propagate_w_field_mixin_col(self, run_test):
+        @dataclasses.dataclass
+        class BaseType:
+            __sa_dataclass_metadata_key__ = "sa"
+
+            __table_args__ = {"mysql_engine": "InnoDB"}
+
+            discriminator: str = Column("type", String(50))
+            __mapper_args__ = dict(polymorphic_on=discriminator)
+            id: int = Column(Integer, primary_key=True)
+            value: int = Column(Integer())
+
+            timestamp: int = dataclasses.field(
+                init=False,
+                metadata={"sa": Column(Integer, nullable=False)},
+            )
+
+        run_test(BaseType)
+
+    def test_propagate_w_field_mixin_col_and_default(self, run_test):
+        @dataclasses.dataclass
+        class BaseType:
+            __sa_dataclass_metadata_key__ = "sa"
+
+            __table_args__ = {"mysql_engine": "InnoDB"}
+
+            discriminator: str = Column("type", String(50))
+            __mapper_args__ = dict(polymorphic_on=discriminator)
+            id: int = Column(Integer, primary_key=True)
+            value: int = Column(Integer())
+
+            timestamp: int = dataclasses.field(
+                init=False,
+                default=None,
+                metadata={"sa": Column(Integer, nullable=False)},
+            )
+
+        run_test(BaseType)
+
+    @testing.fixture()
+    def run_test(self):
+        def go(BaseType):
+            declarative = registry().mapped
+
+            @declarative
+            @dataclasses.dataclass
+            class Single(BaseType):
+
+                __tablename__ = "single"
+                __mapper_args__ = dict(polymorphic_identity="type1")
+
+            @declarative
+            @dataclasses.dataclass
+            class Joined(Single):
+                __tablename__ = "joined"
+                __mapper_args__ = dict(polymorphic_identity="type2")
+                id = Column(Integer, ForeignKey("single.id"), primary_key=True)
+
+            eq_(Single.__table__.name, "single")
+            eq_(
+                list(Single.__table__.c.keys()),
+                ["type", "id", "value", "timestamp"],
+            )
+            eq_(Single.__table__.kwargs, {"mysql_engine": "InnoDB"})
+
+            eq_(Joined.__table__.name, "joined")
+            eq_(list(Joined.__table__.c.keys()), ["id"])
+            eq_(Joined.__table__.kwargs, {"mysql_engine": "InnoDB"})
+
+        yield go
+
+        clear_mappers()