From: Mike Bayer Date: Sat, 16 Jul 2022 20:19:15 +0000 (-0400) Subject: implement column._merge() X-Git-Tag: rel_2_0_0b1~174^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=26c0e8e1846a4e6ac05c15a1ad188a5655b72edb;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git implement column._merge() this takes the user-defined args of one Column and merges them into the not-user-defined args of another Column. Implemented within the pep-593 column transfer operation to begin to make this new feature more robust. work may still be needed for constraints etc. but in theory everything from the left side annotated column should take effect for the right side if not otherwise specified on the right. Change-Id: I57eb37ed6ceb4b60979a35cfc4b63731d990911d --- diff --git a/lib/sqlalchemy/orm/_orm_constructors.py b/lib/sqlalchemy/orm/_orm_constructors.py index ea95f14202..c7e59e3d7a 100644 --- a/lib/sqlalchemy/orm/_orm_constructors.py +++ b/lib/sqlalchemy/orm/_orm_constructors.py @@ -316,11 +316,7 @@ def mapped_column( name=name, type_=type_, autoincrement=autoincrement, - insert_default=insert_default - if insert_default is not _NoArg.NO_ARG - else default - if default is not _NoArg.NO_ARG - else None, + insert_default=insert_default, attribute_options=_AttributeOptions( init, repr, diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index c5f50d7b45..caf9ff3af8 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -485,6 +485,7 @@ class MappedColumn( "_creation_order", "foreign_keys", "_has_nullable", + "_has_insert_default", "deferred", "_attribute_options", "_has_dataclass_arguments", @@ -512,7 +513,13 @@ class MappedColumn( ): self._has_dataclass_arguments = True - kw["default"] = kw.pop("insert_default", None) + insert_default = kw.pop("insert_default", _NoArg.NO_ARG) + self._has_insert_default = insert_default is not _NoArg.NO_ARG + + if self._has_insert_default: + kw["default"] = insert_default + elif attr_opts.dataclasses_default is not _NoArg.NO_ARG: + kw["default"] = attr_opts.dataclasses_default self.deferred = kw.pop("deferred", False) self.column = cast("Column[_T]", Column(*arg, **kw)) @@ -531,6 +538,7 @@ class MappedColumn( new.foreign_keys = new.column.foreign_keys new._has_nullable = self._has_nullable new._attribute_options = self._attribute_options + new._has_insert_default = self._has_insert_default new._has_dataclass_arguments = self._has_dataclass_arguments util.set_creation_order(new) return new @@ -642,27 +650,13 @@ class MappedColumn( our_type_is_pep593 = False if use_args_from is not None: - if use_args_from.column.primary_key: - self.column.primary_key = True - if use_args_from.column.default is not None: - self.column.default = use_args_from.column.default if ( - use_args_from.column.server_default - and self.column.server_default is None + not self._has_insert_default + and use_args_from.column.default is not None ): - self.column.server_default = ( - use_args_from.column.server_default - ) - - for const in use_args_from.column.constraints: - if not const._type_bound: - new_const = const._copy() - new_const._set_parent(self.column) - - for fk in use_args_from.column.foreign_keys: - if not fk.constraint: - new_fk = fk._copy() - new_fk._set_parent(self.column) + self.column.default = None + use_args_from.column._merge(self.column) + sqltype = self.column.type if sqltype._isnull and not self.column.foreign_keys: new_sqltype = None diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 979b8319e1..4ed5b9e6b1 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -2233,6 +2233,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): server_default = self.server_default server_onupdate = self.server_onupdate if isinstance(server_default, (Computed, Identity)): + # TODO: likely should be copied in all cases args.append(server_default._copy(**kw)) server_default = server_onupdate = None @@ -2243,6 +2244,10 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): if self._user_defined_nullable is not NULL_UNSPECIFIED: column_kwargs["nullable"] = self._user_defined_nullable + # TODO: DefaultGenerator is not copied here! it's just used again + # with _set_parent() pointing to the old column. see the new + # use of _copy() in the new _merge() method + c = self._constructor( name=self.name, type_=type_, @@ -2264,6 +2269,69 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): ) return self._schema_item_copy(c) + def _merge(self, other: Column[Any]) -> None: + """merge the elements of another column into this one. + + this is used by ORM pep-593 merge and will likely need a lot + of fixes. + + + """ + + if self.primary_key: + other.primary_key = True + + type_ = self.type + if not type_._isnull and other.type._isnull: + if isinstance(type_, SchemaEventTarget): + type_ = type_.copy() + + other.type = type_ + + if isinstance(type_, SchemaEventTarget): + type_._set_parent_with_dispatch(other) + + for impl in type_._variant_mapping.values(): + if isinstance(impl, SchemaEventTarget): + impl._set_parent_with_dispatch(other) + + if ( + self._user_defined_nullable is not NULL_UNSPECIFIED + and other._user_defined_nullable is NULL_UNSPECIFIED + ): + other.nullable = self.nullable + + if self.default is not None and other.default is None: + new_default = self.default._copy() + new_default._set_parent(other) + + if self.server_default and other.server_default is None: + new_server_default = self.server_default + if isinstance(new_server_default, FetchedValue): + new_server_default = new_server_default._copy() + new_server_default._set_parent(other) + else: + other.server_default = new_server_default + + if self.server_onupdate and other.server_onupdate is None: + new_server_onupdate = self.server_onupdate + new_server_onupdate = new_server_onupdate._copy() + new_server_onupdate._set_parent(other) + + if self.onupdate and other.onupdate is None: + new_onupdate = self.onupdate._copy() + new_onupdate._set_parent(other) + + for const in self.constraints: + if not const._type_bound: + new_const = const._copy() + new_const._set_parent(other) + + for fk in self.foreign_keys: + if not fk.constraint: + new_fk = fk._copy() + new_fk._set_parent(other) + def _make_proxy( self, selectable: FromClause, @@ -2948,6 +3016,9 @@ class DefaultGenerator(Executable, SchemaItem): else: self.column.default = self + def _copy(self) -> DefaultGenerator: + raise NotImplementedError() + def _execute_on_connection( self, connection: Connection, @@ -3077,6 +3148,11 @@ class ScalarElementColumnDefault(ColumnDefault): self.for_update = for_update self.arg = arg + def _copy(self) -> ScalarElementColumnDefault: + return ScalarElementColumnDefault( + arg=self.arg, for_update=self.for_update + ) + # _SQLExprDefault = Union["ColumnElement[Any]", "TextClause", "SelectBase"] _SQLExprDefault = Union["ColumnElement[Any]", "TextClause"] @@ -3101,6 +3177,11 @@ class ColumnElementColumnDefault(ColumnDefault): self.for_update = for_update self.arg = arg + def _copy(self) -> ColumnElementColumnDefault: + return ColumnElementColumnDefault( + arg=self.arg, for_update=self.for_update + ) + @util.memoized_property @util.preload_module("sqlalchemy.sql.sqltypes") def _arg_is_typed(self) -> bool: @@ -3132,6 +3213,9 @@ class CallableColumnDefault(ColumnDefault): self.for_update = for_update self.arg = self._maybe_wrap_callable(arg) + def _copy(self) -> CallableColumnDefault: + return CallableColumnDefault(arg=self.arg, for_update=self.for_update) + def _maybe_wrap_callable( self, fn: Union[_CallableColumnDefaultProtocol, Callable[[], Any]] ) -> _CallableColumnDefaultProtocol: @@ -3266,7 +3350,7 @@ class Sequence(HasSchemaAttr, IdentityOptions, DefaultGenerator): nomaxvalue: Optional[bool] = None, cycle: Optional[bool] = None, schema: Optional[Union[str, Literal[SchemaConst.BLANK_SCHEMA]]] = None, - cache: Optional[bool] = None, + cache: Optional[int] = None, order: Optional[bool] = None, data_type: Optional[_TypeEngineArgument[int]] = None, optional: bool = False, @@ -3459,6 +3543,25 @@ class Sequence(HasSchemaAttr, IdentityOptions, DefaultGenerator): super(Sequence, self)._set_parent(column) column._on_table_attach(self._set_table) + def _copy(self) -> Sequence: + return Sequence( + name=self.name, + start=self.start, + increment=self.increment, + minvalue=self.minvalue, + maxvalue=self.maxvalue, + nominvalue=self.nominvalue, + nomaxvalue=self.nomaxvalue, + cycle=self.cycle, + schema=self.schema, + cache=self.cache, + order=self.order, + data_type=self.data_type, + optional=self.optional, + metadata=self.metadata, + for_update=self.for_update, + ) + def _set_table(self, column: Column[Any], table: Table) -> None: self._set_metadata(table.metadata) @@ -3522,6 +3625,9 @@ class FetchedValue(SchemaEventTarget): else: return self._clone(for_update) # type: ignore + def _copy(self) -> FetchedValue: + return FetchedValue(self.for_update) + def _clone(self, for_update: bool) -> Any: n = self.__class__.__new__(self.__class__) n.__dict__.update(self.__dict__) @@ -3577,6 +3683,11 @@ class DefaultClause(FetchedValue): self.arg = arg self.reflected = _reflected + def _copy(self) -> DefaultClause: + return DefaultClause( + arg=self.arg, for_update=self.for_update, _reflected=self.reflected + ) + def __repr__(self) -> str: return "DefaultClause(%r, for_update=%r)" % (self.arg, self.for_update) diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py index c33aef9c45..f31f1f4e29 100644 --- a/test/orm/declarative/test_typed_mapping.py +++ b/test/orm/declarative/test_typed_mapping.py @@ -13,10 +13,12 @@ from typing import Union import uuid from sqlalchemy import BIGINT +from sqlalchemy import BigInteger from sqlalchemy import Column from sqlalchemy import DateTime from sqlalchemy import exc as sa_exc from sqlalchemy import ForeignKey +from sqlalchemy import func from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import Numeric @@ -643,6 +645,130 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): ) ) + @testing.combinations( + ("default", lambda ctx: 10), + ("default", func.foo()), + ("onupdate", lambda ctx: 10), + ("onupdate", func.foo()), + ("server_onupdate", func.foo()), + ("server_default", func.foo()), + ("nullable", True), + ("nullable", False), + ("type", BigInteger()), + argnames="paramname, value", + ) + @testing.combinations(True, False, argnames="optional") + @testing.combinations(True, False, argnames="include_existing_col") + def test_combine_args_from_pep593( + self, + decl_base: Type[DeclarativeBase], + paramname, + value, + include_existing_col, + optional, + ): + intpk = Annotated[int, mapped_column(primary_key=True)] + + args = [] + params = {} + if paramname == "type": + args.append(value) + else: + params[paramname] = value + + element_ref = Annotated[int, mapped_column(*args, **params)] + if optional: + element_ref = Optional[element_ref] + + class Element(decl_base): + __tablename__ = "element" + + id: Mapped[intpk] + + if include_existing_col: + data: Mapped[element_ref] = mapped_column() + else: + data: Mapped[element_ref] + + if paramname in ( + "default", + "onupdate", + "server_default", + "server_onupdate", + ): + default = getattr(Element.__table__.c.data, paramname) + is_(default.arg, value) + is_(default.column, Element.__table__.c.data) + elif paramname == "type": + assert type(Element.__table__.c.data.type) is type(value) + else: + is_(getattr(Element.__table__.c.data, paramname), value) + + if paramname != "nullable": + is_(Element.__table__.c.data.nullable, optional) + else: + is_(Element.__table__.c.data.nullable, value) + + @testing.combinations( + ("default", lambda ctx: 10, lambda ctx: 15), + ("default", func.foo(), func.bar()), + ("onupdate", lambda ctx: 10, lambda ctx: 15), + ("onupdate", func.foo(), func.bar()), + ("server_onupdate", func.foo(), func.bar()), + ("server_default", func.foo(), func.bar()), + ("nullable", True, False), + ("nullable", False, True), + ("type", BigInteger(), Numeric()), + argnames="paramname, value, override_value", + ) + def test_dont_combine_args_from_pep593( + self, + decl_base: Type[DeclarativeBase], + paramname, + value, + override_value, + ): + intpk = Annotated[int, mapped_column(primary_key=True)] + + args = [] + params = {} + override_args = [] + override_params = {} + if paramname == "type": + args.append(value) + override_args.append(override_value) + else: + params[paramname] = value + if paramname == "default": + override_params["insert_default"] = override_value + else: + override_params[paramname] = override_value + + element_ref = Annotated[int, mapped_column(*args, **params)] + + class Element(decl_base): + __tablename__ = "element" + + id: Mapped[intpk] + + data: Mapped[element_ref] = mapped_column( + *override_args, **override_params + ) + + if paramname in ( + "default", + "onupdate", + "server_default", + "server_onupdate", + ): + default = getattr(Element.__table__.c.data, paramname) + is_(default.arg, override_value) + is_(default.column, Element.__table__.c.data) + elif paramname == "type": + assert type(Element.__table__.c.data.type) is type(override_value) + else: + is_(getattr(Element.__table__.c.data, paramname), override_value) + def test_unions(self): our_type = Numeric(10, 2) diff --git a/test/sql/test_metadata.py b/test/sql/test_metadata.py index 33b6e130f8..b7913e6068 100644 --- a/test/sql/test_metadata.py +++ b/test/sql/test_metadata.py @@ -3,6 +3,7 @@ import pickle import sqlalchemy as tsa from sqlalchemy import ARRAY +from sqlalchemy import BigInteger from sqlalchemy import bindparam from sqlalchemy import BLANK_SCHEMA from sqlalchemy import Boolean @@ -10,6 +11,7 @@ from sqlalchemy import CheckConstraint from sqlalchemy import Column from sqlalchemy import column from sqlalchemy import ColumnDefault +from sqlalchemy import Computed from sqlalchemy import desc from sqlalchemy import Enum from sqlalchemy import event @@ -17,9 +19,11 @@ from sqlalchemy import exc from sqlalchemy import ForeignKey from sqlalchemy import ForeignKeyConstraint from sqlalchemy import func +from sqlalchemy import Identity from sqlalchemy import Index from sqlalchemy import Integer from sqlalchemy import MetaData +from sqlalchemy import Numeric from sqlalchemy import PrimaryKeyConstraint from sqlalchemy import schema from sqlalchemy import select @@ -4182,6 +4186,130 @@ class ColumnDefinitionTest(AssertsCompiledSQL, fixtures.TestBase): deregister(schema.CreateColumn) + @testing.combinations( + ("default", lambda ctx: 10), + ("default", func.foo()), + ("identity_gen", Identity()), + ("identity_gen", Sequence("some_seq")), + ("identity_gen", Computed("side * side")), + ("onupdate", lambda ctx: 10), + ("onupdate", func.foo()), + ("server_onupdate", func.foo()), + ("server_default", func.foo()), + ("nullable", True), + ("nullable", False), + ("type", BigInteger()), + ("type", Enum("one", "two", "three", create_constraint=True)), + argnames="paramname, value", + ) + def test_merge_column( + self, + paramname, + value, + ): + + args = [] + params = {} + if paramname == "type" or isinstance( + value, (Computed, Sequence, Identity) + ): + args.append(value) + else: + params[paramname] = value + + source = Column(*args, **params) + + target = Column() + + source._merge(target) + + if isinstance(value, (Computed, Identity)): + default = target.server_default + assert isinstance(default, type(value)) + elif isinstance(value, Sequence): + default = target.default + assert isinstance(default, type(value)) + + elif paramname in ( + "default", + "onupdate", + "server_default", + "server_onupdate", + ): + default = getattr(target, paramname) + is_(default.arg, value) + is_(default.column, target) + elif paramname == "type": + assert type(target.type) is type(value) + + if isinstance(target.type, Enum): + target.name = "data" + t = Table("t", MetaData(), target) + assert CheckConstraint in [type(c) for c in t.constraints] + else: + is_(getattr(target, paramname), value) + + @testing.combinations( + ("default", lambda ctx: 10, lambda ctx: 15), + ("default", func.foo(), func.bar()), + ("identity_gen", Identity(), Identity()), + ("identity_gen", Sequence("some_seq"), Sequence("some_other_seq")), + ("identity_gen", Computed("side * side"), Computed("top / top")), + ("onupdate", lambda ctx: 10, lambda ctx: 15), + ("onupdate", func.foo(), func.bar()), + ("server_onupdate", func.foo(), func.bar()), + ("server_default", func.foo(), func.bar()), + ("nullable", True, False), + ("nullable", False, True), + ("type", BigInteger(), Numeric()), + argnames="paramname, value, override_value", + ) + def test_dont_merge_column( + self, + paramname, + value, + override_value, + ): + + args = [] + params = {} + override_args = [] + override_params = {} + if paramname == "type" or isinstance( + value, (Computed, Sequence, Identity) + ): + args.append(value) + override_args.append(override_value) + else: + params[paramname] = value + override_params[paramname] = override_value + + source = Column(*args, **params) + + target = Column(*override_args, **override_params) + + source._merge(target) + + if isinstance(value, Sequence): + default = target.default + assert default is override_value + elif isinstance(value, (Computed, Identity)): + default = target.server_default + assert default is override_value + elif paramname in ( + "default", + "onupdate", + "server_default", + "server_onupdate", + ): + default = getattr(target, paramname) + is_(default.arg, override_value) + is_(default.column, target) + elif paramname == "type": + assert type(target.type) is type(override_value) + else: + is_(getattr(target, paramname), override_value) + class ColumnDefaultsTest(fixtures.TestBase):