From: Mike Bayer Date: Tue, 22 Feb 2022 02:26:42 +0000 (-0500) Subject: limit new constructor scan thing for composites to dataclasses only X-Git-Tag: rel_2_0_0b1~474 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=90ecbc817b68fa8916bd06b8ebf076b1ec4d9232;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git limit new constructor scan thing for composites to dataclasses only Fixes: #7753 Change-Id: Ibf92fa34097a7d6b39dc71c72253034e314bd6a1 --- diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index 4616e40945..d4d010cbe6 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -12,6 +12,7 @@ as actively in the load/persist ORM loop. """ from __future__ import annotations +from dataclasses import is_dataclass import inspect import itertools import operator @@ -248,12 +249,10 @@ class Composite( self.descriptor = property(fget, fset, fdel) @util.preload_module("sqlalchemy.orm.properties") - @util.preload_module("sqlalchemy.orm.decl_base") def declarative_scan( self, registry, cls, key, annotation, is_dataclass_field ): MappedColumn = util.preloaded.orm_properties.MappedColumn - decl_base = util.preloaded.orm_decl_base argument = _extract_mapped_subtype( annotation, @@ -273,6 +272,17 @@ class Composite( f"class argument" ) self.composite_class = argument + + if is_dataclass(self.composite_class): + self._setup_for_dataclass(registry, cls, key) + + @util.preload_module("sqlalchemy.orm.properties") + @util.preload_module("sqlalchemy.orm.decl_base") + def _setup_for_dataclass(self, registry, cls, key): + MappedColumn = util.preloaded.orm_properties.MappedColumn + + decl_base = util.preloaded.orm_decl_base + insp = inspect.signature(self.composite_class) for param, attr in itertools.zip_longest( insp.parameters.values(), self.attrs @@ -289,7 +299,7 @@ class Composite( elif isinstance(attr, schema.Column): decl_base._undefer_column_name(param.name, attr) - if not hasattr(cls, "__composite_values__"): + if not hasattr(self.composite_class, "__composite_values__"): getter = operator.attrgetter( *[p.name for p in insp.parameters.values()] ) diff --git a/test/orm/test_composites.py b/test/orm/test_composites.py index afa3daf135..b8d9d90080 100644 --- a/test/orm/test_composites.py +++ b/test/orm/test_composites.py @@ -14,6 +14,7 @@ from sqlalchemy.orm import Session from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing import is_ from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -425,7 +426,7 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL): e = Edge() eq_(e.start, None) - def test_no_name_declarative(self, decl_base): + def test_no_name_declarative(self, decl_base, connection): """test #7751""" class Point: @@ -467,6 +468,89 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL): "vertices.y2 FROM vertices", ) + decl_base.metadata.create_all(connection) + s = Session(connection) + hv = Vertex(start=Point(1, 2), end=Point(3, 4)) + s.add(hv) + s.commit() + + is_( + hv, + s.scalars( + select(Vertex).where(Vertex.start == Point(1, 2)) + ).first(), + ) + + def test_no_name_declarative_two(self, decl_base, connection): + """test #7752""" + + class Point: + def __init__(self, x, y): + self.x = x + self.y = y + + def __composite_values__(self): + return self.x, self.y + + def __repr__(self): + return "Point(x=%r, y=%r)" % (self.x, self.y) + + def __eq__(self, other): + return ( + isinstance(other, Point) + and other.x == self.x + and other.y == self.y + ) + + def __ne__(self, other): + return not self.__eq__(other) + + class Vertex: + def __init__(self, start, end): + self.start = start + self.end = end + + @classmethod + def _generate(self, x1, y1, x2, y2): + """generate a Vertex from a row""" + return Vertex(Point(x1, y1), Point(x2, y2)) + + def __composite_values__(self): + return ( + self.start.__composite_values__() + + self.end.__composite_values__() + ) + + class HasVertex(decl_base): + __tablename__ = "has_vertex" + id = Column(Integer, primary_key=True) + x1 = Column(Integer) + y1 = Column(Integer) + x2 = Column(Integer) + y2 = Column(Integer) + + vertex = composite(Vertex._generate, x1, y1, x2, y2) + + self.assert_compile( + select(HasVertex), + "SELECT has_vertex.id, has_vertex.x1, has_vertex.y1, " + "has_vertex.x2, has_vertex.y2 FROM has_vertex", + ) + + decl_base.metadata.create_all(connection) + s = Session(connection) + hv = HasVertex(vertex=Vertex(Point(1, 2), Point(3, 4))) + s.add(hv) + s.commit() + is_( + hv, + s.scalars( + select(HasVertex).where( + HasVertex.vertex == Vertex(Point(1, 2), Point(3, 4)) + ) + ).first(), + ) + class NestedTest(fixtures.MappedTest, testing.AssertsCompiledSQL): @classmethod