]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
implement column._merge()
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 16 Jul 2022 20:19:15 +0000 (16:19 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 16 Jul 2022 21:41:09 +0000 (17:41 -0400)
this takes the user-defined args of one Column and merges
them into the not-user-defined args of another Column.

Implemented within the pep-593 column transfer operation
to begin to make this new feature more robust.

work may still be needed for constraints etc. but
in theory everything from the left side annotated column
should take effect for the right side if not otherwise
specified on the right.

Change-Id: I57eb37ed6ceb4b60979a35cfc4b63731d990911d

lib/sqlalchemy/orm/_orm_constructors.py
lib/sqlalchemy/orm/properties.py
lib/sqlalchemy/sql/schema.py
test/orm/declarative/test_typed_mapping.py
test/sql/test_metadata.py

index ea95f14202b6d0f14632b29030b40828a7e696f8..c7e59e3d7a2195b259649a52427bf7f42a59d5d4 100644 (file)
@@ -316,11 +316,7 @@ def mapped_column(
         name=name,
         type_=type_,
         autoincrement=autoincrement,
-        insert_default=insert_default
-        if insert_default is not _NoArg.NO_ARG
-        else default
-        if default is not _NoArg.NO_ARG
-        else None,
+        insert_default=insert_default,
         attribute_options=_AttributeOptions(
             init,
             repr,
index c5f50d7b450edd1a5eb25d049d72ac6003c1838e..caf9ff3af8230a2f6359718b2ed05c52defbe2b0 100644 (file)
@@ -485,6 +485,7 @@ class MappedColumn(
         "_creation_order",
         "foreign_keys",
         "_has_nullable",
+        "_has_insert_default",
         "deferred",
         "_attribute_options",
         "_has_dataclass_arguments",
@@ -512,7 +513,13 @@ class MappedColumn(
             ):
                 self._has_dataclass_arguments = True
 
-        kw["default"] = kw.pop("insert_default", None)
+        insert_default = kw.pop("insert_default", _NoArg.NO_ARG)
+        self._has_insert_default = insert_default is not _NoArg.NO_ARG
+
+        if self._has_insert_default:
+            kw["default"] = insert_default
+        elif attr_opts.dataclasses_default is not _NoArg.NO_ARG:
+            kw["default"] = attr_opts.dataclasses_default
 
         self.deferred = kw.pop("deferred", False)
         self.column = cast("Column[_T]", Column(*arg, **kw))
@@ -531,6 +538,7 @@ class MappedColumn(
         new.foreign_keys = new.column.foreign_keys
         new._has_nullable = self._has_nullable
         new._attribute_options = self._attribute_options
+        new._has_insert_default = self._has_insert_default
         new._has_dataclass_arguments = self._has_dataclass_arguments
         util.set_creation_order(new)
         return new
@@ -642,27 +650,13 @@ class MappedColumn(
             our_type_is_pep593 = False
 
         if use_args_from is not None:
-            if use_args_from.column.primary_key:
-                self.column.primary_key = True
-            if use_args_from.column.default is not None:
-                self.column.default = use_args_from.column.default
             if (
-                use_args_from.column.server_default
-                and self.column.server_default is None
+                not self._has_insert_default
+                and use_args_from.column.default is not None
             ):
-                self.column.server_default = (
-                    use_args_from.column.server_default
-                )
-
-            for const in use_args_from.column.constraints:
-                if not const._type_bound:
-                    new_const = const._copy()
-                    new_const._set_parent(self.column)
-
-            for fk in use_args_from.column.foreign_keys:
-                if not fk.constraint:
-                    new_fk = fk._copy()
-                    new_fk._set_parent(self.column)
+                self.column.default = None
+            use_args_from.column._merge(self.column)
+            sqltype = self.column.type
 
         if sqltype._isnull and not self.column.foreign_keys:
             new_sqltype = None
index 979b8319e106cf45655270a1a0041e1227619a76..4ed5b9e6b1af52636f238f77de0379dfbd711e8d 100644 (file)
@@ -2233,6 +2233,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
         server_default = self.server_default
         server_onupdate = self.server_onupdate
         if isinstance(server_default, (Computed, Identity)):
+            # TODO: likely should be copied in all cases
             args.append(server_default._copy(**kw))
             server_default = server_onupdate = None
 
@@ -2243,6 +2244,10 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
         if self._user_defined_nullable is not NULL_UNSPECIFIED:
             column_kwargs["nullable"] = self._user_defined_nullable
 
+        # TODO: DefaultGenerator is not copied here!  it's just used again
+        # with _set_parent() pointing to the old column.  see the new
+        # use of _copy() in the new _merge() method
+
         c = self._constructor(
             name=self.name,
             type_=type_,
@@ -2264,6 +2269,69 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
         )
         return self._schema_item_copy(c)
 
+    def _merge(self, other: Column[Any]) -> None:
+        """merge the elements of another column into this one.
+
+        this is used by ORM pep-593 merge and will likely need a lot
+        of fixes.
+
+
+        """
+
+        if self.primary_key:
+            other.primary_key = True
+
+        type_ = self.type
+        if not type_._isnull and other.type._isnull:
+            if isinstance(type_, SchemaEventTarget):
+                type_ = type_.copy()
+
+            other.type = type_
+
+            if isinstance(type_, SchemaEventTarget):
+                type_._set_parent_with_dispatch(other)
+
+            for impl in type_._variant_mapping.values():
+                if isinstance(impl, SchemaEventTarget):
+                    impl._set_parent_with_dispatch(other)
+
+        if (
+            self._user_defined_nullable is not NULL_UNSPECIFIED
+            and other._user_defined_nullable is NULL_UNSPECIFIED
+        ):
+            other.nullable = self.nullable
+
+        if self.default is not None and other.default is None:
+            new_default = self.default._copy()
+            new_default._set_parent(other)
+
+        if self.server_default and other.server_default is None:
+            new_server_default = self.server_default
+            if isinstance(new_server_default, FetchedValue):
+                new_server_default = new_server_default._copy()
+                new_server_default._set_parent(other)
+            else:
+                other.server_default = new_server_default
+
+        if self.server_onupdate and other.server_onupdate is None:
+            new_server_onupdate = self.server_onupdate
+            new_server_onupdate = new_server_onupdate._copy()
+            new_server_onupdate._set_parent(other)
+
+        if self.onupdate and other.onupdate is None:
+            new_onupdate = self.onupdate._copy()
+            new_onupdate._set_parent(other)
+
+        for const in self.constraints:
+            if not const._type_bound:
+                new_const = const._copy()
+                new_const._set_parent(other)
+
+        for fk in self.foreign_keys:
+            if not fk.constraint:
+                new_fk = fk._copy()
+                new_fk._set_parent(other)
+
     def _make_proxy(
         self,
         selectable: FromClause,
@@ -2948,6 +3016,9 @@ class DefaultGenerator(Executable, SchemaItem):
         else:
             self.column.default = self
 
+    def _copy(self) -> DefaultGenerator:
+        raise NotImplementedError()
+
     def _execute_on_connection(
         self,
         connection: Connection,
@@ -3077,6 +3148,11 @@ class ScalarElementColumnDefault(ColumnDefault):
         self.for_update = for_update
         self.arg = arg
 
+    def _copy(self) -> ScalarElementColumnDefault:
+        return ScalarElementColumnDefault(
+            arg=self.arg, for_update=self.for_update
+        )
+
 
 # _SQLExprDefault = Union["ColumnElement[Any]", "TextClause", "SelectBase"]
 _SQLExprDefault = Union["ColumnElement[Any]", "TextClause"]
@@ -3101,6 +3177,11 @@ class ColumnElementColumnDefault(ColumnDefault):
         self.for_update = for_update
         self.arg = arg
 
+    def _copy(self) -> ColumnElementColumnDefault:
+        return ColumnElementColumnDefault(
+            arg=self.arg, for_update=self.for_update
+        )
+
     @util.memoized_property
     @util.preload_module("sqlalchemy.sql.sqltypes")
     def _arg_is_typed(self) -> bool:
@@ -3132,6 +3213,9 @@ class CallableColumnDefault(ColumnDefault):
         self.for_update = for_update
         self.arg = self._maybe_wrap_callable(arg)
 
+    def _copy(self) -> CallableColumnDefault:
+        return CallableColumnDefault(arg=self.arg, for_update=self.for_update)
+
     def _maybe_wrap_callable(
         self, fn: Union[_CallableColumnDefaultProtocol, Callable[[], Any]]
     ) -> _CallableColumnDefaultProtocol:
@@ -3266,7 +3350,7 @@ class Sequence(HasSchemaAttr, IdentityOptions, DefaultGenerator):
         nomaxvalue: Optional[bool] = None,
         cycle: Optional[bool] = None,
         schema: Optional[Union[str, Literal[SchemaConst.BLANK_SCHEMA]]] = None,
-        cache: Optional[bool] = None,
+        cache: Optional[int] = None,
         order: Optional[bool] = None,
         data_type: Optional[_TypeEngineArgument[int]] = None,
         optional: bool = False,
@@ -3459,6 +3543,25 @@ class Sequence(HasSchemaAttr, IdentityOptions, DefaultGenerator):
         super(Sequence, self)._set_parent(column)
         column._on_table_attach(self._set_table)
 
+    def _copy(self) -> Sequence:
+        return Sequence(
+            name=self.name,
+            start=self.start,
+            increment=self.increment,
+            minvalue=self.minvalue,
+            maxvalue=self.maxvalue,
+            nominvalue=self.nominvalue,
+            nomaxvalue=self.nomaxvalue,
+            cycle=self.cycle,
+            schema=self.schema,
+            cache=self.cache,
+            order=self.order,
+            data_type=self.data_type,
+            optional=self.optional,
+            metadata=self.metadata,
+            for_update=self.for_update,
+        )
+
     def _set_table(self, column: Column[Any], table: Table) -> None:
         self._set_metadata(table.metadata)
 
@@ -3522,6 +3625,9 @@ class FetchedValue(SchemaEventTarget):
         else:
             return self._clone(for_update)  # type: ignore
 
+    def _copy(self) -> FetchedValue:
+        return FetchedValue(self.for_update)
+
     def _clone(self, for_update: bool) -> Any:
         n = self.__class__.__new__(self.__class__)
         n.__dict__.update(self.__dict__)
@@ -3577,6 +3683,11 @@ class DefaultClause(FetchedValue):
         self.arg = arg
         self.reflected = _reflected
 
+    def _copy(self) -> DefaultClause:
+        return DefaultClause(
+            arg=self.arg, for_update=self.for_update, _reflected=self.reflected
+        )
+
     def __repr__(self) -> str:
         return "DefaultClause(%r, for_update=%r)" % (self.arg, self.for_update)
 
index c33aef9c45d83cb37403019e5edafd8e0b72bb43..f31f1f4e29a1f234396ddf78393410bb5443a679 100644 (file)
@@ -13,10 +13,12 @@ from typing import Union
 import uuid
 
 from sqlalchemy import BIGINT
+from sqlalchemy import BigInteger
 from sqlalchemy import Column
 from sqlalchemy import DateTime
 from sqlalchemy import exc as sa_exc
 from sqlalchemy import ForeignKey
+from sqlalchemy import func
 from sqlalchemy import inspect
 from sqlalchemy import Integer
 from sqlalchemy import Numeric
@@ -643,6 +645,130 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             )
         )
 
+    @testing.combinations(
+        ("default", lambda ctx: 10),
+        ("default", func.foo()),
+        ("onupdate", lambda ctx: 10),
+        ("onupdate", func.foo()),
+        ("server_onupdate", func.foo()),
+        ("server_default", func.foo()),
+        ("nullable", True),
+        ("nullable", False),
+        ("type", BigInteger()),
+        argnames="paramname, value",
+    )
+    @testing.combinations(True, False, argnames="optional")
+    @testing.combinations(True, False, argnames="include_existing_col")
+    def test_combine_args_from_pep593(
+        self,
+        decl_base: Type[DeclarativeBase],
+        paramname,
+        value,
+        include_existing_col,
+        optional,
+    ):
+        intpk = Annotated[int, mapped_column(primary_key=True)]
+
+        args = []
+        params = {}
+        if paramname == "type":
+            args.append(value)
+        else:
+            params[paramname] = value
+
+        element_ref = Annotated[int, mapped_column(*args, **params)]
+        if optional:
+            element_ref = Optional[element_ref]
+
+        class Element(decl_base):
+            __tablename__ = "element"
+
+            id: Mapped[intpk]
+
+            if include_existing_col:
+                data: Mapped[element_ref] = mapped_column()
+            else:
+                data: Mapped[element_ref]
+
+        if paramname in (
+            "default",
+            "onupdate",
+            "server_default",
+            "server_onupdate",
+        ):
+            default = getattr(Element.__table__.c.data, paramname)
+            is_(default.arg, value)
+            is_(default.column, Element.__table__.c.data)
+        elif paramname == "type":
+            assert type(Element.__table__.c.data.type) is type(value)
+        else:
+            is_(getattr(Element.__table__.c.data, paramname), value)
+
+        if paramname != "nullable":
+            is_(Element.__table__.c.data.nullable, optional)
+        else:
+            is_(Element.__table__.c.data.nullable, value)
+
+    @testing.combinations(
+        ("default", lambda ctx: 10, lambda ctx: 15),
+        ("default", func.foo(), func.bar()),
+        ("onupdate", lambda ctx: 10, lambda ctx: 15),
+        ("onupdate", func.foo(), func.bar()),
+        ("server_onupdate", func.foo(), func.bar()),
+        ("server_default", func.foo(), func.bar()),
+        ("nullable", True, False),
+        ("nullable", False, True),
+        ("type", BigInteger(), Numeric()),
+        argnames="paramname, value, override_value",
+    )
+    def test_dont_combine_args_from_pep593(
+        self,
+        decl_base: Type[DeclarativeBase],
+        paramname,
+        value,
+        override_value,
+    ):
+        intpk = Annotated[int, mapped_column(primary_key=True)]
+
+        args = []
+        params = {}
+        override_args = []
+        override_params = {}
+        if paramname == "type":
+            args.append(value)
+            override_args.append(override_value)
+        else:
+            params[paramname] = value
+            if paramname == "default":
+                override_params["insert_default"] = override_value
+            else:
+                override_params[paramname] = override_value
+
+        element_ref = Annotated[int, mapped_column(*args, **params)]
+
+        class Element(decl_base):
+            __tablename__ = "element"
+
+            id: Mapped[intpk]
+
+            data: Mapped[element_ref] = mapped_column(
+                *override_args, **override_params
+            )
+
+        if paramname in (
+            "default",
+            "onupdate",
+            "server_default",
+            "server_onupdate",
+        ):
+            default = getattr(Element.__table__.c.data, paramname)
+            is_(default.arg, override_value)
+            is_(default.column, Element.__table__.c.data)
+        elif paramname == "type":
+            assert type(Element.__table__.c.data.type) is type(override_value)
+        else:
+            is_(getattr(Element.__table__.c.data, paramname), override_value)
+
     def test_unions(self):
         our_type = Numeric(10, 2)
 
index 33b6e130f866b226220f27fdeaab6906ce3c2f16..b7913e6068da4bd16b6b8e5e6cfd9580a977f0a0 100644 (file)
@@ -3,6 +3,7 @@ import pickle
 
 import sqlalchemy as tsa
 from sqlalchemy import ARRAY
+from sqlalchemy import BigInteger
 from sqlalchemy import bindparam
 from sqlalchemy import BLANK_SCHEMA
 from sqlalchemy import Boolean
@@ -10,6 +11,7 @@ from sqlalchemy import CheckConstraint
 from sqlalchemy import Column
 from sqlalchemy import column
 from sqlalchemy import ColumnDefault
+from sqlalchemy import Computed
 from sqlalchemy import desc
 from sqlalchemy import Enum
 from sqlalchemy import event
@@ -17,9 +19,11 @@ from sqlalchemy import exc
 from sqlalchemy import ForeignKey
 from sqlalchemy import ForeignKeyConstraint
 from sqlalchemy import func
+from sqlalchemy import Identity
 from sqlalchemy import Index
 from sqlalchemy import Integer
 from sqlalchemy import MetaData
+from sqlalchemy import Numeric
 from sqlalchemy import PrimaryKeyConstraint
 from sqlalchemy import schema
 from sqlalchemy import select
@@ -4182,6 +4186,130 @@ class ColumnDefinitionTest(AssertsCompiledSQL, fixtures.TestBase):
 
         deregister(schema.CreateColumn)
 
+    @testing.combinations(
+        ("default", lambda ctx: 10),
+        ("default", func.foo()),
+        ("identity_gen", Identity()),
+        ("identity_gen", Sequence("some_seq")),
+        ("identity_gen", Computed("side * side")),
+        ("onupdate", lambda ctx: 10),
+        ("onupdate", func.foo()),
+        ("server_onupdate", func.foo()),
+        ("server_default", func.foo()),
+        ("nullable", True),
+        ("nullable", False),
+        ("type", BigInteger()),
+        ("type", Enum("one", "two", "three", create_constraint=True)),
+        argnames="paramname, value",
+    )
+    def test_merge_column(
+        self,
+        paramname,
+        value,
+    ):
+
+        args = []
+        params = {}
+        if paramname == "type" or isinstance(
+            value, (Computed, Sequence, Identity)
+        ):
+            args.append(value)
+        else:
+            params[paramname] = value
+
+        source = Column(*args, **params)
+
+        target = Column()
+
+        source._merge(target)
+
+        if isinstance(value, (Computed, Identity)):
+            default = target.server_default
+            assert isinstance(default, type(value))
+        elif isinstance(value, Sequence):
+            default = target.default
+            assert isinstance(default, type(value))
+
+        elif paramname in (
+            "default",
+            "onupdate",
+            "server_default",
+            "server_onupdate",
+        ):
+            default = getattr(target, paramname)
+            is_(default.arg, value)
+            is_(default.column, target)
+        elif paramname == "type":
+            assert type(target.type) is type(value)
+
+            if isinstance(target.type, Enum):
+                target.name = "data"
+                t = Table("t", MetaData(), target)
+                assert CheckConstraint in [type(c) for c in t.constraints]
+        else:
+            is_(getattr(target, paramname), value)
+
+    @testing.combinations(
+        ("default", lambda ctx: 10, lambda ctx: 15),
+        ("default", func.foo(), func.bar()),
+        ("identity_gen", Identity(), Identity()),
+        ("identity_gen", Sequence("some_seq"), Sequence("some_other_seq")),
+        ("identity_gen", Computed("side * side"), Computed("top / top")),
+        ("onupdate", lambda ctx: 10, lambda ctx: 15),
+        ("onupdate", func.foo(), func.bar()),
+        ("server_onupdate", func.foo(), func.bar()),
+        ("server_default", func.foo(), func.bar()),
+        ("nullable", True, False),
+        ("nullable", False, True),
+        ("type", BigInteger(), Numeric()),
+        argnames="paramname, value, override_value",
+    )
+    def test_dont_merge_column(
+        self,
+        paramname,
+        value,
+        override_value,
+    ):
+
+        args = []
+        params = {}
+        override_args = []
+        override_params = {}
+        if paramname == "type" or isinstance(
+            value, (Computed, Sequence, Identity)
+        ):
+            args.append(value)
+            override_args.append(override_value)
+        else:
+            params[paramname] = value
+            override_params[paramname] = override_value
+
+        source = Column(*args, **params)
+
+        target = Column(*override_args, **override_params)
+
+        source._merge(target)
+
+        if isinstance(value, Sequence):
+            default = target.default
+            assert default is override_value
+        elif isinstance(value, (Computed, Identity)):
+            default = target.server_default
+            assert default is override_value
+        elif paramname in (
+            "default",
+            "onupdate",
+            "server_default",
+            "server_onupdate",
+        ):
+            default = getattr(target, paramname)
+            is_(default.arg, override_value)
+            is_(default.column, target)
+        elif paramname == "type":
+            assert type(target.type) is type(override_value)
+        else:
+            is_(getattr(target, paramname), override_value)
+
 
 class ColumnDefaultsTest(fixtures.TestBase):