From: Mike Bayer Date: Mon, 27 Jun 2022 16:56:27 +0000 (-0400) Subject: merge column args from Annotated left side X-Git-Tag: rel_2_0_0b1~206 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=5950a9d05f1cc123009223baa1915cc15f3340a7;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git merge column args from Annotated left side because we are forced by pep-681 to use the argument "default", we need a way to have client Column default separate from a dataclasses level default. Also, pep-681 does not support deriving the descriptor function from Annotated, so allow a brief right side mapped_column() to be present that will have more column-centric arguments from the left side Annotated to be merged. Change-Id: I039be1628d498486ba013b2798e1392ed1cd7f9f --- diff --git a/lib/sqlalchemy/orm/_orm_constructors.py b/lib/sqlalchemy/orm/_orm_constructors.py index 63fd3b76ed..bafad09f22 100644 --- a/lib/sqlalchemy/orm/_orm_constructors.py +++ b/lib/sqlalchemy/orm/_orm_constructors.py @@ -125,6 +125,7 @@ def mapped_column( unique: Optional[bool] = None, info: Optional[_InfoType] = None, onupdate: Optional[Any] = None, + insert_default: Optional[Any] = _NoArg.NO_ARG, server_default: Optional[_ServerDefaultType] = None, server_onupdate: Optional[FetchedValue] = None, quote: Optional[bool] = None, @@ -292,6 +293,17 @@ def mapped_column( ORM declarative process, and is not part of the :class:`_schema.Column` itself; instead, it indicates that this column should be "deferred" for loading as though mapped by :func:`_orm.deferred`. + :param default: This keyword argument, if present, is passed along to the + :class:`_schema.Column` constructor as the value of the + :paramref:`_schema.Column.default` parameter. However, as + :paramref:`_orm.mapped_column.default` is also consumed as a dataclasses + directive, the :paramref:`_orm.mapped_column.insert_default` parameter + should be used instead in a dataclasses context. + :param insert_default: Passed directly to the + :paramref:`_schema.Column.default` parameter; will supersede the value + of :paramref:`_orm.mapped_column.default` when present, however + :paramref:`_orm.mapped_column.default` will always apply to the + constructor default for a dataclasses mapping. :param \**kw: All remaining keyword argments are passed through to the constructor for the :class:`_schema.Column`. @@ -304,7 +316,11 @@ def mapped_column( name=name, type_=type_, autoincrement=autoincrement, - default=default, + insert_default=insert_default + if insert_default is not _NoArg.NO_ARG + else default + if default is not _NoArg.NO_ARG + else None, attribute_options=_AttributeOptions( init, repr, diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index d1faff1d96..7308b8fb12 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -505,18 +505,14 @@ class MappedColumn( if attr_opts is not None and attr_opts != _DEFAULT_ATTRIBUTE_OPTIONS: if attr_opts.dataclasses_default_factory is not _NoArg.NO_ARG: self._has_dataclass_arguments = True - kw["default"] = attr_opts.dataclasses_default_factory - elif attr_opts.dataclasses_default is not _NoArg.NO_ARG: - kw["default"] = attr_opts.dataclasses_default - if ( + elif ( attr_opts.dataclasses_init is not _NoArg.NO_ARG or attr_opts.dataclasses_repr is not _NoArg.NO_ARG ): self._has_dataclass_arguments = True - if "default" in kw and kw["default"] is _NoArg.NO_ARG: - kw.pop("default") + kw["default"] = kw.pop("insert_default", None) self.deferred = kw.pop("deferred", False) self.column = cast("Column[_T]", Column(*arg, **kw)) @@ -525,6 +521,7 @@ class MappedColumn( None, SchemaConst.NULL_UNSPECIFIED, ) + util.set_creation_order(self) def _copy(self: Self, **kw: Any) -> Self: @@ -630,14 +627,47 @@ class MappedColumn( if not self._has_nullable: self.column.nullable = nullable + our_type = de_optionalize_union_types(argument) + if is_fwd_ref(our_type): + our_type = de_stringify_annotation(cls, our_type) + + use_args_from = None + if is_pep593(our_type): + our_type_is_pep593 = True + for elem in typing_get_args(our_type): + if isinstance(elem, MappedColumn): + use_args_from = elem + break + else: + our_type_is_pep593 = False + + if use_args_from is not None: + if use_args_from.column.primary_key: + self.column.primary_key = True + if use_args_from.column.default is not None: + self.column.default = use_args_from.column.default + if ( + use_args_from.column.server_default + and self.column.server_default is None + ): + self.column.server_default = ( + use_args_from.column.server_default + ) + + for const in use_args_from.column.constraints: + if not const._type_bound: + new_const = const._copy() + new_const._set_parent(self.column) + + for fk in use_args_from.column.foreign_keys: + if not fk.constraint: + new_fk = fk._copy() + new_fk._set_parent(self.column) + if sqltype._isnull and not self.column.foreign_keys: new_sqltype = None - our_type = de_optionalize_union_types(argument) - - if is_fwd_ref(our_type): - our_type = de_stringify_annotation(cls, our_type) - if is_pep593(our_type): + if our_type_is_pep593: checks = (our_type,) + typing_get_args(our_type) else: checks = (our_type,) diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 313300f93f..569603d793 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -2529,7 +2529,6 @@ class ForeignKey(DialectKWArgs, SchemaItem): by the given string schema name. """ - fk = ForeignKey( self._get_colspec(schema=schema), use_alter=self.use_alter, diff --git a/lib/sqlalchemy/testing/schema.py b/lib/sqlalchemy/testing/schema.py index 46cbf4759e..6a13fc9055 100644 --- a/lib/sqlalchemy/testing/schema.py +++ b/lib/sqlalchemy/testing/schema.py @@ -15,9 +15,9 @@ from . import exclusions from .. import event from .. import schema from .. import types as sqltypes +from ..orm import mapped_column as _orm_mapped_column from ..util import OrderedDict - __all__ = ["Table", "Column"] table_options = {} @@ -60,15 +60,31 @@ def Table(*args, **kw) -> schema.Table: return schema.Table(*args, **kw) +def mapped_column(*args, **kw): + """An orm.mapped_column wrapper/hook for dialect-specific tweaks.""" + + return _schema_column(_orm_mapped_column, args, kw) + + def Column(*args, **kw): """A schema.Column wrapper/hook for dialect-specific tweaks.""" + return _schema_column(schema.Column, args, kw) + + +def _schema_column(factory, args, kw): test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")} if not config.requirements.foreign_key_ddl.enabled_for_config(config): args = [arg for arg in args if not isinstance(arg, schema.ForeignKey)] - col = schema.Column(*args, **kw) + construct = factory(*args, **kw) + + if factory is schema.Column: + col = construct + else: + col = construct.column + if test_opts.get("test_needs_autoincrement", False) and kw.get( "primary_key", False ): @@ -94,7 +110,7 @@ def Column(*args, **kw): ) event.listen(col, "after_parent_attach", add_seq, propagate=True) - return col + return construct class eq_type_affinity: diff --git a/test/orm/declarative/test_dc_transforms.py b/test/orm/declarative/test_dc_transforms.py index 271b235966..44976b5d88 100644 --- a/test/orm/declarative/test_dc_transforms.py +++ b/test/orm/declarative/test_dc_transforms.py @@ -9,9 +9,12 @@ from typing import Set from typing import Type from unittest import mock +from typing_extensions import Annotated + from sqlalchemy import Column from sqlalchemy import exc from sqlalchemy import ForeignKey +from sqlalchemy import func from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import select @@ -227,6 +230,55 @@ class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase): default="d1", default_factory=lambda: "d2" ) + def test_combine_args_from_pep593(self, decl_base: Type[DeclarativeBase]): + """test that we can set up column-level defaults separate from + dataclass defaults + + """ + intpk = Annotated[int, mapped_column(primary_key=True)] + str30 = Annotated[ + str, mapped_column(String(30), insert_default=func.foo()) + ] + s_str30 = Annotated[ + str, + mapped_column(String(30), server_default="some server default"), + ] + user_fk = Annotated[int, mapped_column(ForeignKey("user_account.id"))] + + class User(MappedAsDataclass, decl_base): + __tablename__ = "user_account" + + # we need this case for dataclasses that can't derive things + # from Annotated yet at the typing level + id: Mapped[intpk] = mapped_column(init=False) + name_none: Mapped[Optional[str30]] = mapped_column(default=None) + name: Mapped[str30] = mapped_column(default="hi") + name2: Mapped[s_str30] = mapped_column(default="there") + addresses: Mapped[List["Address"]] = relationship( # noqa: F821 + back_populates="user", default_factory=list + ) + + class Address(MappedAsDataclass, decl_base): + __tablename__ = "address" + + id: Mapped[intpk] = mapped_column(init=False) + email_address: Mapped[str] + user_id: Mapped[user_fk] = mapped_column(init=False) + user: Mapped["User"] = relationship( + back_populates="addresses", default=None + ) + + is_true(User.__table__.c.id.primary_key) + is_true(User.__table__.c.name_none.default.arg.compare(func.foo())) + is_true(User.__table__.c.name.default.arg.compare(func.foo())) + eq_(User.__table__.c.name2.server_default.arg, "some server default") + + is_true(Address.__table__.c.user_id.references(User.__table__.c.id)) + u1 = User() + eq_(u1.name_none, None) + eq_(u1.name, "hi") + eq_(u1.name2, "there") + def test_inheritance(self, dc_decl_base: Type[MappedAsDataclass]): class Person(dc_decl_base): __tablename__ = "person" diff --git a/test/orm/declarative/test_mixin.py b/test/orm/declarative/test_mixin.py index 36840b2d7a..eb8c7dbdf4 100644 --- a/test/orm/declarative/test_mixin.py +++ b/test/orm/declarative/test_mixin.py @@ -20,7 +20,6 @@ from sqlalchemy.orm import declared_attr from sqlalchemy.orm import deferred from sqlalchemy.orm import events as orm_events from sqlalchemy.orm import has_inherited_table -from sqlalchemy.orm import mapped_column from sqlalchemy.orm import registry from sqlalchemy.orm import relationship from sqlalchemy.orm import synonym @@ -35,6 +34,7 @@ from sqlalchemy.testing import is_true from sqlalchemy.testing import mock from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column +from sqlalchemy.testing.schema import mapped_column from sqlalchemy.testing.schema import Table from sqlalchemy.testing.util import gc_collect from sqlalchemy.util import classproperty @@ -159,11 +159,12 @@ class DeclarativeMixinTest(DeclarativeTestBase): eq_(obj.name, "testing") eq_(obj.foo(), "bar1") - def test_unique_column(self): + @testing.combinations(Column, mapped_column, argnames="_column") + def test_unique_column(self, _column): class MyMixin: - id = Column(Integer, primary_key=True) - value = Column(String, unique=True) + id = _column(Integer, primary_key=True) + value = _column(String, unique=True) class MyModel(Base, MyMixin): @@ -171,10 +172,11 @@ class DeclarativeMixinTest(DeclarativeTestBase): assert MyModel.__table__.c.value.unique - def test_hierarchical_bases_wbase(self): + @testing.combinations(Column, mapped_column, argnames="_column") + def test_hierarchical_bases_wbase(self, _column): class MyMixinParent: - id = Column( + id = _column( Integer, primary_key=True, test_needs_autoincrement=True ) @@ -183,12 +185,12 @@ class DeclarativeMixinTest(DeclarativeTestBase): class MyMixin(MyMixinParent): - baz = Column(String(100), nullable=False, index=True) + baz = _column(String(100), nullable=False, index=True) class MyModel(Base, MyMixin): __tablename__ = "test" - name = Column(String(100), nullable=False, index=True) + name = _column(String(100), nullable=False, index=True) Base.metadata.create_all(testing.db) session = fixture_session() @@ -201,10 +203,11 @@ class DeclarativeMixinTest(DeclarativeTestBase): eq_(obj.foo(), "bar1") eq_(obj.baz, "fu") - def test_hierarchical_bases_wdecorator(self): + @testing.combinations(Column, mapped_column, argnames="_column") + def test_hierarchical_bases_wdecorator(self, _column): class MyMixinParent: - id = Column( + id = _column( Integer, primary_key=True, test_needs_autoincrement=True ) @@ -213,7 +216,7 @@ class DeclarativeMixinTest(DeclarativeTestBase): class MyMixin(MyMixinParent): - baz = Column(String(100), nullable=False, index=True) + baz = _column(String(100), nullable=False, index=True) @mapper_registry.mapped class MyModel(MyMixin, object): @@ -232,22 +235,23 @@ class DeclarativeMixinTest(DeclarativeTestBase): eq_(obj.foo(), "bar1") eq_(obj.baz, "fu") - def test_mixin_overrides_wbase(self): + @testing.combinations(Column, mapped_column, argnames="_column") + def test_mixin_overrides_wbase(self, _column): """test a mixin that overrides a column on a superclass.""" class MixinA: - foo = Column(String(50)) + foo = _column(String(50)) class MixinB(MixinA): - foo = Column(Integer) + foo = _column(Integer) class MyModelA(Base, MixinA): __tablename__ = "testa" - id = Column(Integer, primary_key=True) + id = _column(Integer, primary_key=True) class MyModelB(Base, MixinB): __tablename__ = "testb" - id = Column(Integer, primary_key=True) + id = _column(Integer, primary_key=True) eq_(MyModelA.__table__.c.foo.type.__class__, String) eq_(MyModelB.__table__.c.foo.type.__class__, Integer) @@ -1120,7 +1124,7 @@ class DeclarativeMixinTest(DeclarativeTestBase): return cls.__name__.lower() __table_args__ = {"mysql_engine": "InnoDB"} - timestamp = Column(Integer) + timestamp = mapped_column(Integer) id = Column(Integer, primary_key=True) class Generic(Base, CommonMixin): diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py index ae2773d1c7..6f60a652ff 100644 --- a/test/orm/declarative/test_typed_mapping.py +++ b/test/orm/declarative/test_typed_mapping.py @@ -559,6 +559,44 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): is_true(table.c.data_three.nullable) is_true(table.c.data_four.nullable) + def test_extract_fk_col_from_pep593( + self, decl_base: Type[DeclarativeBase] + ): + intpk = Annotated[int, mapped_column(primary_key=True)] + element_ref = Annotated[int, mapped_column(ForeignKey("element.id"))] + + class Element(decl_base): + __tablename__ = "element" + + id: Mapped[intpk] + + class RefElementOne(decl_base): + __tablename__ = "refone" + + id: Mapped[intpk] + other_id: Mapped[element_ref] + + class RefElementTwo(decl_base): + __tablename__ = "reftwo" + + id: Mapped[intpk] + some_id: Mapped[element_ref] + + assert Element.__table__ is not None + assert RefElementOne.__table__ is not None + assert RefElementTwo.__table__ is not None + + is_true( + RefElementOne.__table__.c.other_id.references( + Element.__table__.c.id + ) + ) + is_true( + RefElementTwo.__table__.c.some_id.references( + Element.__table__.c.id + ) + ) + def test_unions(self): our_type = Numeric(10, 2)