]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
limit new constructor scan thing for composites to dataclasses only
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 22 Feb 2022 02:26:42 +0000 (21:26 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 22 Feb 2022 02:26:42 +0000 (21:26 -0500)
Fixes: #7753
Change-Id: Ibf92fa34097a7d6b39dc71c72253034e314bd6a1

lib/sqlalchemy/orm/descriptor_props.py
test/orm/test_composites.py

index 4616e40945ffc810312574369eed0dec0c8a88c7..d4d010cbe686346a560c71d450ed05b211b600f3 100644 (file)
@@ -12,6 +12,7 @@ as actively in the load/persist ORM loop.
 """
 from __future__ import annotations
 
+from dataclasses import is_dataclass
 import inspect
 import itertools
 import operator
@@ -248,12 +249,10 @@ class Composite(
         self.descriptor = property(fget, fset, fdel)
 
     @util.preload_module("sqlalchemy.orm.properties")
-    @util.preload_module("sqlalchemy.orm.decl_base")
     def declarative_scan(
         self, registry, cls, key, annotation, is_dataclass_field
     ):
         MappedColumn = util.preloaded.orm_properties.MappedColumn
-        decl_base = util.preloaded.orm_decl_base
 
         argument = _extract_mapped_subtype(
             annotation,
@@ -273,6 +272,17 @@ class Composite(
                     f"class argument"
                 )
             self.composite_class = argument
+
+        if is_dataclass(self.composite_class):
+            self._setup_for_dataclass(registry, cls, key)
+
+    @util.preload_module("sqlalchemy.orm.properties")
+    @util.preload_module("sqlalchemy.orm.decl_base")
+    def _setup_for_dataclass(self, registry, cls, key):
+        MappedColumn = util.preloaded.orm_properties.MappedColumn
+
+        decl_base = util.preloaded.orm_decl_base
+
         insp = inspect.signature(self.composite_class)
         for param, attr in itertools.zip_longest(
             insp.parameters.values(), self.attrs
@@ -289,7 +299,7 @@ class Composite(
             elif isinstance(attr, schema.Column):
                 decl_base._undefer_column_name(param.name, attr)
 
-        if not hasattr(cls, "__composite_values__"):
+        if not hasattr(self.composite_class, "__composite_values__"):
             getter = operator.attrgetter(
                 *[p.name for p in insp.parameters.values()]
             )
index afa3daf1356c2a09979d5cc73e1ac1907bb8e016..b8d9d900804f91a10db8b1385b8017869dc2c51e 100644 (file)
@@ -14,6 +14,7 @@ from sqlalchemy.orm import Session
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import is_
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
@@ -425,7 +426,7 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
         e = Edge()
         eq_(e.start, None)
 
-    def test_no_name_declarative(self, decl_base):
+    def test_no_name_declarative(self, decl_base, connection):
         """test #7751"""
 
         class Point:
@@ -467,6 +468,89 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
             "vertices.y2 FROM vertices",
         )
 
+        decl_base.metadata.create_all(connection)
+        s = Session(connection)
+        hv = Vertex(start=Point(1, 2), end=Point(3, 4))
+        s.add(hv)
+        s.commit()
+
+        is_(
+            hv,
+            s.scalars(
+                select(Vertex).where(Vertex.start == Point(1, 2))
+            ).first(),
+        )
+
+    def test_no_name_declarative_two(self, decl_base, connection):
+        """test #7752"""
+
+        class Point:
+            def __init__(self, x, y):
+                self.x = x
+                self.y = y
+
+            def __composite_values__(self):
+                return self.x, self.y
+
+            def __repr__(self):
+                return "Point(x=%r, y=%r)" % (self.x, self.y)
+
+            def __eq__(self, other):
+                return (
+                    isinstance(other, Point)
+                    and other.x == self.x
+                    and other.y == self.y
+                )
+
+            def __ne__(self, other):
+                return not self.__eq__(other)
+
+        class Vertex:
+            def __init__(self, start, end):
+                self.start = start
+                self.end = end
+
+            @classmethod
+            def _generate(self, x1, y1, x2, y2):
+                """generate a Vertex from a row"""
+                return Vertex(Point(x1, y1), Point(x2, y2))
+
+            def __composite_values__(self):
+                return (
+                    self.start.__composite_values__()
+                    + self.end.__composite_values__()
+                )
+
+        class HasVertex(decl_base):
+            __tablename__ = "has_vertex"
+            id = Column(Integer, primary_key=True)
+            x1 = Column(Integer)
+            y1 = Column(Integer)
+            x2 = Column(Integer)
+            y2 = Column(Integer)
+
+            vertex = composite(Vertex._generate, x1, y1, x2, y2)
+
+        self.assert_compile(
+            select(HasVertex),
+            "SELECT has_vertex.id, has_vertex.x1, has_vertex.y1, "
+            "has_vertex.x2, has_vertex.y2 FROM has_vertex",
+        )
+
+        decl_base.metadata.create_all(connection)
+        s = Session(connection)
+        hv = HasVertex(vertex=Vertex(Point(1, 2), Point(3, 4)))
+        s.add(hv)
+        s.commit()
+        is_(
+            hv,
+            s.scalars(
+                select(HasVertex).where(
+                    HasVertex.vertex == Vertex(Point(1, 2), Point(3, 4))
+                )
+            ).first(),
+        )
+
 
 class NestedTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
     @classmethod