"""
from __future__ import annotations
+from dataclasses import is_dataclass
import inspect
import itertools
import operator
self.descriptor = property(fget, fset, fdel)
@util.preload_module("sqlalchemy.orm.properties")
- @util.preload_module("sqlalchemy.orm.decl_base")
def declarative_scan(
self, registry, cls, key, annotation, is_dataclass_field
):
MappedColumn = util.preloaded.orm_properties.MappedColumn
- decl_base = util.preloaded.orm_decl_base
argument = _extract_mapped_subtype(
annotation,
f"class argument"
)
self.composite_class = argument
+
+ if is_dataclass(self.composite_class):
+ self._setup_for_dataclass(registry, cls, key)
+
+ @util.preload_module("sqlalchemy.orm.properties")
+ @util.preload_module("sqlalchemy.orm.decl_base")
+ def _setup_for_dataclass(self, registry, cls, key):
+ MappedColumn = util.preloaded.orm_properties.MappedColumn
+
+ decl_base = util.preloaded.orm_decl_base
+
insp = inspect.signature(self.composite_class)
for param, attr in itertools.zip_longest(
insp.parameters.values(), self.attrs
elif isinstance(attr, schema.Column):
decl_base._undefer_column_name(param.name, attr)
- if not hasattr(cls, "__composite_values__"):
+ if not hasattr(self.composite_class, "__composite_values__"):
getter = operator.attrgetter(
*[p.name for p in insp.parameters.values()]
)
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import is_
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
e = Edge()
eq_(e.start, None)
- def test_no_name_declarative(self, decl_base):
+ def test_no_name_declarative(self, decl_base, connection):
"""test #7751"""
class Point:
"vertices.y2 FROM vertices",
)
+ decl_base.metadata.create_all(connection)
+ s = Session(connection)
+ hv = Vertex(start=Point(1, 2), end=Point(3, 4))
+ s.add(hv)
+ s.commit()
+
+ is_(
+ hv,
+ s.scalars(
+ select(Vertex).where(Vertex.start == Point(1, 2))
+ ).first(),
+ )
+
+ def test_no_name_declarative_two(self, decl_base, connection):
+ """test #7752"""
+
+ class Point:
+ def __init__(self, x, y):
+ self.x = x
+ self.y = y
+
+ def __composite_values__(self):
+ return self.x, self.y
+
+ def __repr__(self):
+ return "Point(x=%r, y=%r)" % (self.x, self.y)
+
+ def __eq__(self, other):
+ return (
+ isinstance(other, Point)
+ and other.x == self.x
+ and other.y == self.y
+ )
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ class Vertex:
+ def __init__(self, start, end):
+ self.start = start
+ self.end = end
+
+ @classmethod
+ def _generate(self, x1, y1, x2, y2):
+ """generate a Vertex from a row"""
+ return Vertex(Point(x1, y1), Point(x2, y2))
+
+ def __composite_values__(self):
+ return (
+ self.start.__composite_values__()
+ + self.end.__composite_values__()
+ )
+
+ class HasVertex(decl_base):
+ __tablename__ = "has_vertex"
+ id = Column(Integer, primary_key=True)
+ x1 = Column(Integer)
+ y1 = Column(Integer)
+ x2 = Column(Integer)
+ y2 = Column(Integer)
+
+ vertex = composite(Vertex._generate, x1, y1, x2, y2)
+
+ self.assert_compile(
+ select(HasVertex),
+ "SELECT has_vertex.id, has_vertex.x1, has_vertex.y1, "
+ "has_vertex.x2, has_vertex.y2 FROM has_vertex",
+ )
+
+ decl_base.metadata.create_all(connection)
+ s = Session(connection)
+ hv = HasVertex(vertex=Vertex(Point(1, 2), Point(3, 4)))
+ s.add(hv)
+ s.commit()
+ is_(
+ hv,
+ s.scalars(
+ select(HasVertex).where(
+ HasVertex.vertex == Vertex(Point(1, 2), Point(3, 4))
+ )
+ ).first(),
+ )
+
class NestedTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
@classmethod