]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
merge column args from Annotated left side
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 27 Jun 2022 16:56:27 +0000 (12:56 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 28 Jun 2022 17:36:06 +0000 (13:36 -0400)
because we are forced by pep-681 to use the argument
"default", we need a way to have client Column default
separate from a dataclasses level default.  Also, pep-681
does not support deriving the descriptor function from
Annotated, so allow a brief right side mapped_column() to
be present that will have more column-centric arguments
from the left side Annotated to be merged.

Change-Id: I039be1628d498486ba013b2798e1392ed1cd7f9f

lib/sqlalchemy/orm/_orm_constructors.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/sql/schema.py
lib/sqlalchemy/testing/schema.py
test/orm/declarative/test_dc_transforms.py
test/orm/declarative/test_mixin.py
test/orm/declarative/test_typed_mapping.py

index 63fd3b76ed4332fe5e322624fa41f62bf8a8c743..bafad09f221341bbca67eb44fa348064abf0572e 100644 (file)
@@ -125,6 +125,7 @@ def mapped_column(
     unique: Optional[bool] = None,
     info: Optional[_InfoType] = None,
     onupdate: Optional[Any] = None,
+    insert_default: Optional[Any] = _NoArg.NO_ARG,
     server_default: Optional[_ServerDefaultType] = None,
     server_onupdate: Optional[FetchedValue] = None,
     quote: Optional[bool] = None,
@@ -292,6 +293,17 @@ def mapped_column(
      ORM declarative process, and is not part of the :class:`_schema.Column`
      itself; instead, it indicates that this column should be "deferred" for
      loading as though mapped by :func:`_orm.deferred`.
+    :param default: This keyword argument, if present, is passed along to the
+     :class:`_schema.Column` constructor as the value of the
+     :paramref:`_schema.Column.default` parameter.  However, as
+     :paramref:`_orm.mapped_column.default` is also consumed as a dataclasses
+     directive, the :paramref:`_orm.mapped_column.insert_default` parameter
+     should be used instead in a dataclasses context.
+    :param insert_default: Passed directly to the
+     :paramref:`_schema.Column.default` parameter; will supersede the value
+     of :paramref:`_orm.mapped_column.default` when present, however
+     :paramref:`_orm.mapped_column.default` will always apply to the
+     constructor default for a dataclasses mapping.
     :param \**kw: All remaining keyword argments are passed through to the
      constructor for the :class:`_schema.Column`.
 
@@ -304,7 +316,11 @@ def mapped_column(
         name=name,
         type_=type_,
         autoincrement=autoincrement,
-        default=default,
+        insert_default=insert_default
+        if insert_default is not _NoArg.NO_ARG
+        else default
+        if default is not _NoArg.NO_ARG
+        else None,
         attribute_options=_AttributeOptions(
             init,
             repr,
index d1faff1d964209fc6ab443f376d2e5d7bb084274..7308b8fb1250c58f1e8115e797ca7135be986f42 100644 (file)
@@ -505,18 +505,14 @@ class MappedColumn(
         if attr_opts is not None and attr_opts != _DEFAULT_ATTRIBUTE_OPTIONS:
             if attr_opts.dataclasses_default_factory is not _NoArg.NO_ARG:
                 self._has_dataclass_arguments = True
-                kw["default"] = attr_opts.dataclasses_default_factory
-            elif attr_opts.dataclasses_default is not _NoArg.NO_ARG:
-                kw["default"] = attr_opts.dataclasses_default
 
-            if (
+            elif (
                 attr_opts.dataclasses_init is not _NoArg.NO_ARG
                 or attr_opts.dataclasses_repr is not _NoArg.NO_ARG
             ):
                 self._has_dataclass_arguments = True
 
-        if "default" in kw and kw["default"] is _NoArg.NO_ARG:
-            kw.pop("default")
+        kw["default"] = kw.pop("insert_default", None)
 
         self.deferred = kw.pop("deferred", False)
         self.column = cast("Column[_T]", Column(*arg, **kw))
@@ -525,6 +521,7 @@ class MappedColumn(
             None,
             SchemaConst.NULL_UNSPECIFIED,
         )
+
         util.set_creation_order(self)
 
     def _copy(self: Self, **kw: Any) -> Self:
@@ -630,14 +627,47 @@ class MappedColumn(
         if not self._has_nullable:
             self.column.nullable = nullable
 
+        our_type = de_optionalize_union_types(argument)
+        if is_fwd_ref(our_type):
+            our_type = de_stringify_annotation(cls, our_type)
+
+        use_args_from = None
+        if is_pep593(our_type):
+            our_type_is_pep593 = True
+            for elem in typing_get_args(our_type):
+                if isinstance(elem, MappedColumn):
+                    use_args_from = elem
+                    break
+        else:
+            our_type_is_pep593 = False
+
+        if use_args_from is not None:
+            if use_args_from.column.primary_key:
+                self.column.primary_key = True
+            if use_args_from.column.default is not None:
+                self.column.default = use_args_from.column.default
+            if (
+                use_args_from.column.server_default
+                and self.column.server_default is None
+            ):
+                self.column.server_default = (
+                    use_args_from.column.server_default
+                )
+
+            for const in use_args_from.column.constraints:
+                if not const._type_bound:
+                    new_const = const._copy()
+                    new_const._set_parent(self.column)
+
+            for fk in use_args_from.column.foreign_keys:
+                if not fk.constraint:
+                    new_fk = fk._copy()
+                    new_fk._set_parent(self.column)
+
         if sqltype._isnull and not self.column.foreign_keys:
             new_sqltype = None
-            our_type = de_optionalize_union_types(argument)
-
-            if is_fwd_ref(our_type):
-                our_type = de_stringify_annotation(cls, our_type)
 
-            if is_pep593(our_type):
+            if our_type_is_pep593:
                 checks = (our_type,) + typing_get_args(our_type)
             else:
                 checks = (our_type,)
index 313300f93fd88387341ba33adf24d3371adec9f7..569603d79341f7f42b86761dd061a36d9fdec5f0 100644 (file)
@@ -2529,7 +2529,6 @@ class ForeignKey(DialectKWArgs, SchemaItem):
           by the given string schema name.
 
         """
-
         fk = ForeignKey(
             self._get_colspec(schema=schema),
             use_alter=self.use_alter,
index 46cbf4759ed31e099ebaca486eb6485c89f8c0e5..6a13fc905597189a02619aafb4cf1d95b88431de 100644 (file)
@@ -15,9 +15,9 @@ from . import exclusions
 from .. import event
 from .. import schema
 from .. import types as sqltypes
+from ..orm import mapped_column as _orm_mapped_column
 from ..util import OrderedDict
 
-
 __all__ = ["Table", "Column"]
 
 table_options = {}
@@ -60,15 +60,31 @@ def Table(*args, **kw) -> schema.Table:
     return schema.Table(*args, **kw)
 
 
+def mapped_column(*args, **kw):
+    """An orm.mapped_column wrapper/hook for dialect-specific tweaks."""
+
+    return _schema_column(_orm_mapped_column, args, kw)
+
+
 def Column(*args, **kw):
     """A schema.Column wrapper/hook for dialect-specific tweaks."""
 
+    return _schema_column(schema.Column, args, kw)
+
+
+def _schema_column(factory, args, kw):
     test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")}
 
     if not config.requirements.foreign_key_ddl.enabled_for_config(config):
         args = [arg for arg in args if not isinstance(arg, schema.ForeignKey)]
 
-    col = schema.Column(*args, **kw)
+    construct = factory(*args, **kw)
+
+    if factory is schema.Column:
+        col = construct
+    else:
+        col = construct.column
+
     if test_opts.get("test_needs_autoincrement", False) and kw.get(
         "primary_key", False
     ):
@@ -94,7 +110,7 @@ def Column(*args, **kw):
                 )
 
             event.listen(col, "after_parent_attach", add_seq, propagate=True)
-    return col
+    return construct
 
 
 class eq_type_affinity:
index 271b235966bf7c65aab3e5b72a6c207c2be48f3d..44976b5d88175526ecec50cc2ca15548ad055b5d 100644 (file)
@@ -9,9 +9,12 @@ from typing import Set
 from typing import Type
 from unittest import mock
 
+from typing_extensions import Annotated
+
 from sqlalchemy import Column
 from sqlalchemy import exc
 from sqlalchemy import ForeignKey
+from sqlalchemy import func
 from sqlalchemy import inspect
 from sqlalchemy import Integer
 from sqlalchemy import select
@@ -227,6 +230,55 @@ class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase):
                     default="d1", default_factory=lambda: "d2"
                 )
 
+    def test_combine_args_from_pep593(self, decl_base: Type[DeclarativeBase]):
+        """test that we can set up column-level defaults separate from
+        dataclass defaults
+
+        """
+        intpk = Annotated[int, mapped_column(primary_key=True)]
+        str30 = Annotated[
+            str, mapped_column(String(30), insert_default=func.foo())
+        ]
+        s_str30 = Annotated[
+            str,
+            mapped_column(String(30), server_default="some server default"),
+        ]
+        user_fk = Annotated[int, mapped_column(ForeignKey("user_account.id"))]
+
+        class User(MappedAsDataclass, decl_base):
+            __tablename__ = "user_account"
+
+            # we need this case for dataclasses that can't derive things
+            # from Annotated yet at the typing level
+            id: Mapped[intpk] = mapped_column(init=False)
+            name_none: Mapped[Optional[str30]] = mapped_column(default=None)
+            name: Mapped[str30] = mapped_column(default="hi")
+            name2: Mapped[s_str30] = mapped_column(default="there")
+            addresses: Mapped[List["Address"]] = relationship(  # noqa: F821
+                back_populates="user", default_factory=list
+            )
+
+        class Address(MappedAsDataclass, decl_base):
+            __tablename__ = "address"
+
+            id: Mapped[intpk] = mapped_column(init=False)
+            email_address: Mapped[str]
+            user_id: Mapped[user_fk] = mapped_column(init=False)
+            user: Mapped["User"] = relationship(
+                back_populates="addresses", default=None
+            )
+
+        is_true(User.__table__.c.id.primary_key)
+        is_true(User.__table__.c.name_none.default.arg.compare(func.foo()))
+        is_true(User.__table__.c.name.default.arg.compare(func.foo()))
+        eq_(User.__table__.c.name2.server_default.arg, "some server default")
+
+        is_true(Address.__table__.c.user_id.references(User.__table__.c.id))
+        u1 = User()
+        eq_(u1.name_none, None)
+        eq_(u1.name, "hi")
+        eq_(u1.name2, "there")
+
     def test_inheritance(self, dc_decl_base: Type[MappedAsDataclass]):
         class Person(dc_decl_base):
             __tablename__ = "person"
index 36840b2d7a788aa1987c096a497b07f73b75207e..eb8c7dbdf4d79e2470fa5c4c0a2156f699ef3c93 100644 (file)
@@ -20,7 +20,6 @@ from sqlalchemy.orm import declared_attr
 from sqlalchemy.orm import deferred
 from sqlalchemy.orm import events as orm_events
 from sqlalchemy.orm import has_inherited_table
-from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import registry
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import synonym
@@ -35,6 +34,7 @@ from sqlalchemy.testing import is_true
 from sqlalchemy.testing import mock
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
+from sqlalchemy.testing.schema import mapped_column
 from sqlalchemy.testing.schema import Table
 from sqlalchemy.testing.util import gc_collect
 from sqlalchemy.util import classproperty
@@ -159,11 +159,12 @@ class DeclarativeMixinTest(DeclarativeTestBase):
         eq_(obj.name, "testing")
         eq_(obj.foo(), "bar1")
 
-    def test_unique_column(self):
+    @testing.combinations(Column, mapped_column, argnames="_column")
+    def test_unique_column(self, _column):
         class MyMixin:
 
-            id = Column(Integer, primary_key=True)
-            value = Column(String, unique=True)
+            id = _column(Integer, primary_key=True)
+            value = _column(String, unique=True)
 
         class MyModel(Base, MyMixin):
 
@@ -171,10 +172,11 @@ class DeclarativeMixinTest(DeclarativeTestBase):
 
         assert MyModel.__table__.c.value.unique
 
-    def test_hierarchical_bases_wbase(self):
+    @testing.combinations(Column, mapped_column, argnames="_column")
+    def test_hierarchical_bases_wbase(self, _column):
         class MyMixinParent:
 
-            id = Column(
+            id = _column(
                 Integer, primary_key=True, test_needs_autoincrement=True
             )
 
@@ -183,12 +185,12 @@ class DeclarativeMixinTest(DeclarativeTestBase):
 
         class MyMixin(MyMixinParent):
 
-            baz = Column(String(100), nullable=False, index=True)
+            baz = _column(String(100), nullable=False, index=True)
 
         class MyModel(Base, MyMixin):
 
             __tablename__ = "test"
-            name = Column(String(100), nullable=False, index=True)
+            name = _column(String(100), nullable=False, index=True)
 
         Base.metadata.create_all(testing.db)
         session = fixture_session()
@@ -201,10 +203,11 @@ class DeclarativeMixinTest(DeclarativeTestBase):
         eq_(obj.foo(), "bar1")
         eq_(obj.baz, "fu")
 
-    def test_hierarchical_bases_wdecorator(self):
+    @testing.combinations(Column, mapped_column, argnames="_column")
+    def test_hierarchical_bases_wdecorator(self, _column):
         class MyMixinParent:
 
-            id = Column(
+            id = _column(
                 Integer, primary_key=True, test_needs_autoincrement=True
             )
 
@@ -213,7 +216,7 @@ class DeclarativeMixinTest(DeclarativeTestBase):
 
         class MyMixin(MyMixinParent):
 
-            baz = Column(String(100), nullable=False, index=True)
+            baz = _column(String(100), nullable=False, index=True)
 
         @mapper_registry.mapped
         class MyModel(MyMixin, object):
@@ -232,22 +235,23 @@ class DeclarativeMixinTest(DeclarativeTestBase):
         eq_(obj.foo(), "bar1")
         eq_(obj.baz, "fu")
 
-    def test_mixin_overrides_wbase(self):
+    @testing.combinations(Column, mapped_column, argnames="_column")
+    def test_mixin_overrides_wbase(self, _column):
         """test a mixin that overrides a column on a superclass."""
 
         class MixinA:
-            foo = Column(String(50))
+            foo = _column(String(50))
 
         class MixinB(MixinA):
-            foo = Column(Integer)
+            foo = _column(Integer)
 
         class MyModelA(Base, MixinA):
             __tablename__ = "testa"
-            id = Column(Integer, primary_key=True)
+            id = _column(Integer, primary_key=True)
 
         class MyModelB(Base, MixinB):
             __tablename__ = "testb"
-            id = Column(Integer, primary_key=True)
+            id = _column(Integer, primary_key=True)
 
         eq_(MyModelA.__table__.c.foo.type.__class__, String)
         eq_(MyModelB.__table__.c.foo.type.__class__, Integer)
@@ -1120,7 +1124,7 @@ class DeclarativeMixinTest(DeclarativeTestBase):
                 return cls.__name__.lower()
 
             __table_args__ = {"mysql_engine": "InnoDB"}
-            timestamp = Column(Integer)
+            timestamp = mapped_column(Integer)
             id = Column(Integer, primary_key=True)
 
         class Generic(Base, CommonMixin):
index ae2773d1c7f9bcda26fafd60d1769da107c37d08..6f60a652ff1bc2be72af023c49250f47e6063510 100644 (file)
@@ -559,6 +559,44 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
         is_true(table.c.data_three.nullable)
         is_true(table.c.data_four.nullable)
 
+    def test_extract_fk_col_from_pep593(
+        self, decl_base: Type[DeclarativeBase]
+    ):
+        intpk = Annotated[int, mapped_column(primary_key=True)]
+        element_ref = Annotated[int, mapped_column(ForeignKey("element.id"))]
+
+        class Element(decl_base):
+            __tablename__ = "element"
+
+            id: Mapped[intpk]
+
+        class RefElementOne(decl_base):
+            __tablename__ = "refone"
+
+            id: Mapped[intpk]
+            other_id: Mapped[element_ref]
+
+        class RefElementTwo(decl_base):
+            __tablename__ = "reftwo"
+
+            id: Mapped[intpk]
+            some_id: Mapped[element_ref]
+
+        assert Element.__table__ is not None
+        assert RefElementOne.__table__ is not None
+        assert RefElementTwo.__table__ is not None
+
+        is_true(
+            RefElementOne.__table__.c.other_id.references(
+                Element.__table__.c.id
+            )
+        )
+        is_true(
+            RefElementTwo.__table__.c.some_id.references(
+                Element.__table__.c.id
+            )
+        )
+
     def test_unions(self):
         our_type = Numeric(10, 2)