name=name,
type_=type_,
autoincrement=autoincrement,
- insert_default=insert_default
- if insert_default is not _NoArg.NO_ARG
- else default
- if default is not _NoArg.NO_ARG
- else None,
+ insert_default=insert_default,
attribute_options=_AttributeOptions(
init,
repr,
"_creation_order",
"foreign_keys",
"_has_nullable",
+ "_has_insert_default",
"deferred",
"_attribute_options",
"_has_dataclass_arguments",
):
self._has_dataclass_arguments = True
- kw["default"] = kw.pop("insert_default", None)
+ insert_default = kw.pop("insert_default", _NoArg.NO_ARG)
+ self._has_insert_default = insert_default is not _NoArg.NO_ARG
+
+ if self._has_insert_default:
+ kw["default"] = insert_default
+ elif attr_opts.dataclasses_default is not _NoArg.NO_ARG:
+ kw["default"] = attr_opts.dataclasses_default
self.deferred = kw.pop("deferred", False)
self.column = cast("Column[_T]", Column(*arg, **kw))
new.foreign_keys = new.column.foreign_keys
new._has_nullable = self._has_nullable
new._attribute_options = self._attribute_options
+ new._has_insert_default = self._has_insert_default
new._has_dataclass_arguments = self._has_dataclass_arguments
util.set_creation_order(new)
return new
our_type_is_pep593 = False
if use_args_from is not None:
- if use_args_from.column.primary_key:
- self.column.primary_key = True
- if use_args_from.column.default is not None:
- self.column.default = use_args_from.column.default
if (
- use_args_from.column.server_default
- and self.column.server_default is None
+ not self._has_insert_default
+ and use_args_from.column.default is not None
):
- self.column.server_default = (
- use_args_from.column.server_default
- )
-
- for const in use_args_from.column.constraints:
- if not const._type_bound:
- new_const = const._copy()
- new_const._set_parent(self.column)
-
- for fk in use_args_from.column.foreign_keys:
- if not fk.constraint:
- new_fk = fk._copy()
- new_fk._set_parent(self.column)
+ self.column.default = None
+ use_args_from.column._merge(self.column)
+ sqltype = self.column.type
if sqltype._isnull and not self.column.foreign_keys:
new_sqltype = None
server_default = self.server_default
server_onupdate = self.server_onupdate
if isinstance(server_default, (Computed, Identity)):
+ # TODO: likely should be copied in all cases
args.append(server_default._copy(**kw))
server_default = server_onupdate = None
if self._user_defined_nullable is not NULL_UNSPECIFIED:
column_kwargs["nullable"] = self._user_defined_nullable
+ # TODO: DefaultGenerator is not copied here! it's just used again
+ # with _set_parent() pointing to the old column. see the new
+ # use of _copy() in the new _merge() method
+
c = self._constructor(
name=self.name,
type_=type_,
)
return self._schema_item_copy(c)
+ def _merge(self, other: Column[Any]) -> None:
+ """merge the elements of another column into this one.
+
+ this is used by ORM pep-593 merge and will likely need a lot
+ of fixes.
+
+
+ """
+
+ if self.primary_key:
+ other.primary_key = True
+
+ type_ = self.type
+ if not type_._isnull and other.type._isnull:
+ if isinstance(type_, SchemaEventTarget):
+ type_ = type_.copy()
+
+ other.type = type_
+
+ if isinstance(type_, SchemaEventTarget):
+ type_._set_parent_with_dispatch(other)
+
+ for impl in type_._variant_mapping.values():
+ if isinstance(impl, SchemaEventTarget):
+ impl._set_parent_with_dispatch(other)
+
+ if (
+ self._user_defined_nullable is not NULL_UNSPECIFIED
+ and other._user_defined_nullable is NULL_UNSPECIFIED
+ ):
+ other.nullable = self.nullable
+
+ if self.default is not None and other.default is None:
+ new_default = self.default._copy()
+ new_default._set_parent(other)
+
+ if self.server_default and other.server_default is None:
+ new_server_default = self.server_default
+ if isinstance(new_server_default, FetchedValue):
+ new_server_default = new_server_default._copy()
+ new_server_default._set_parent(other)
+ else:
+ other.server_default = new_server_default
+
+ if self.server_onupdate and other.server_onupdate is None:
+ new_server_onupdate = self.server_onupdate
+ new_server_onupdate = new_server_onupdate._copy()
+ new_server_onupdate._set_parent(other)
+
+ if self.onupdate and other.onupdate is None:
+ new_onupdate = self.onupdate._copy()
+ new_onupdate._set_parent(other)
+
+ for const in self.constraints:
+ if not const._type_bound:
+ new_const = const._copy()
+ new_const._set_parent(other)
+
+ for fk in self.foreign_keys:
+ if not fk.constraint:
+ new_fk = fk._copy()
+ new_fk._set_parent(other)
+
def _make_proxy(
self,
selectable: FromClause,
else:
self.column.default = self
+ def _copy(self) -> DefaultGenerator:
+ raise NotImplementedError()
+
def _execute_on_connection(
self,
connection: Connection,
self.for_update = for_update
self.arg = arg
+ def _copy(self) -> ScalarElementColumnDefault:
+ return ScalarElementColumnDefault(
+ arg=self.arg, for_update=self.for_update
+ )
+
# _SQLExprDefault = Union["ColumnElement[Any]", "TextClause", "SelectBase"]
_SQLExprDefault = Union["ColumnElement[Any]", "TextClause"]
self.for_update = for_update
self.arg = arg
+ def _copy(self) -> ColumnElementColumnDefault:
+ return ColumnElementColumnDefault(
+ arg=self.arg, for_update=self.for_update
+ )
+
@util.memoized_property
@util.preload_module("sqlalchemy.sql.sqltypes")
def _arg_is_typed(self) -> bool:
self.for_update = for_update
self.arg = self._maybe_wrap_callable(arg)
+ def _copy(self) -> CallableColumnDefault:
+ return CallableColumnDefault(arg=self.arg, for_update=self.for_update)
+
def _maybe_wrap_callable(
self, fn: Union[_CallableColumnDefaultProtocol, Callable[[], Any]]
) -> _CallableColumnDefaultProtocol:
nomaxvalue: Optional[bool] = None,
cycle: Optional[bool] = None,
schema: Optional[Union[str, Literal[SchemaConst.BLANK_SCHEMA]]] = None,
- cache: Optional[bool] = None,
+ cache: Optional[int] = None,
order: Optional[bool] = None,
data_type: Optional[_TypeEngineArgument[int]] = None,
optional: bool = False,
super(Sequence, self)._set_parent(column)
column._on_table_attach(self._set_table)
+ def _copy(self) -> Sequence:
+ return Sequence(
+ name=self.name,
+ start=self.start,
+ increment=self.increment,
+ minvalue=self.minvalue,
+ maxvalue=self.maxvalue,
+ nominvalue=self.nominvalue,
+ nomaxvalue=self.nomaxvalue,
+ cycle=self.cycle,
+ schema=self.schema,
+ cache=self.cache,
+ order=self.order,
+ data_type=self.data_type,
+ optional=self.optional,
+ metadata=self.metadata,
+ for_update=self.for_update,
+ )
+
def _set_table(self, column: Column[Any], table: Table) -> None:
self._set_metadata(table.metadata)
else:
return self._clone(for_update) # type: ignore
+ def _copy(self) -> FetchedValue:
+ return FetchedValue(self.for_update)
+
def _clone(self, for_update: bool) -> Any:
n = self.__class__.__new__(self.__class__)
n.__dict__.update(self.__dict__)
self.arg = arg
self.reflected = _reflected
+ def _copy(self) -> DefaultClause:
+ return DefaultClause(
+ arg=self.arg, for_update=self.for_update, _reflected=self.reflected
+ )
+
def __repr__(self) -> str:
return "DefaultClause(%r, for_update=%r)" % (self.arg, self.for_update)
import uuid
from sqlalchemy import BIGINT
+from sqlalchemy import BigInteger
from sqlalchemy import Column
from sqlalchemy import DateTime
from sqlalchemy import exc as sa_exc
from sqlalchemy import ForeignKey
+from sqlalchemy import func
from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import Numeric
)
)
+ @testing.combinations(
+ ("default", lambda ctx: 10),
+ ("default", func.foo()),
+ ("onupdate", lambda ctx: 10),
+ ("onupdate", func.foo()),
+ ("server_onupdate", func.foo()),
+ ("server_default", func.foo()),
+ ("nullable", True),
+ ("nullable", False),
+ ("type", BigInteger()),
+ argnames="paramname, value",
+ )
+ @testing.combinations(True, False, argnames="optional")
+ @testing.combinations(True, False, argnames="include_existing_col")
+ def test_combine_args_from_pep593(
+ self,
+ decl_base: Type[DeclarativeBase],
+ paramname,
+ value,
+ include_existing_col,
+ optional,
+ ):
+ intpk = Annotated[int, mapped_column(primary_key=True)]
+
+ args = []
+ params = {}
+ if paramname == "type":
+ args.append(value)
+ else:
+ params[paramname] = value
+
+ element_ref = Annotated[int, mapped_column(*args, **params)]
+ if optional:
+ element_ref = Optional[element_ref]
+
+ class Element(decl_base):
+ __tablename__ = "element"
+
+ id: Mapped[intpk]
+
+ if include_existing_col:
+ data: Mapped[element_ref] = mapped_column()
+ else:
+ data: Mapped[element_ref]
+
+ if paramname in (
+ "default",
+ "onupdate",
+ "server_default",
+ "server_onupdate",
+ ):
+ default = getattr(Element.__table__.c.data, paramname)
+ is_(default.arg, value)
+ is_(default.column, Element.__table__.c.data)
+ elif paramname == "type":
+ assert type(Element.__table__.c.data.type) is type(value)
+ else:
+ is_(getattr(Element.__table__.c.data, paramname), value)
+
+ if paramname != "nullable":
+ is_(Element.__table__.c.data.nullable, optional)
+ else:
+ is_(Element.__table__.c.data.nullable, value)
+
+ @testing.combinations(
+ ("default", lambda ctx: 10, lambda ctx: 15),
+ ("default", func.foo(), func.bar()),
+ ("onupdate", lambda ctx: 10, lambda ctx: 15),
+ ("onupdate", func.foo(), func.bar()),
+ ("server_onupdate", func.foo(), func.bar()),
+ ("server_default", func.foo(), func.bar()),
+ ("nullable", True, False),
+ ("nullable", False, True),
+ ("type", BigInteger(), Numeric()),
+ argnames="paramname, value, override_value",
+ )
+ def test_dont_combine_args_from_pep593(
+ self,
+ decl_base: Type[DeclarativeBase],
+ paramname,
+ value,
+ override_value,
+ ):
+ intpk = Annotated[int, mapped_column(primary_key=True)]
+
+ args = []
+ params = {}
+ override_args = []
+ override_params = {}
+ if paramname == "type":
+ args.append(value)
+ override_args.append(override_value)
+ else:
+ params[paramname] = value
+ if paramname == "default":
+ override_params["insert_default"] = override_value
+ else:
+ override_params[paramname] = override_value
+
+ element_ref = Annotated[int, mapped_column(*args, **params)]
+
+ class Element(decl_base):
+ __tablename__ = "element"
+
+ id: Mapped[intpk]
+
+ data: Mapped[element_ref] = mapped_column(
+ *override_args, **override_params
+ )
+
+ if paramname in (
+ "default",
+ "onupdate",
+ "server_default",
+ "server_onupdate",
+ ):
+ default = getattr(Element.__table__.c.data, paramname)
+ is_(default.arg, override_value)
+ is_(default.column, Element.__table__.c.data)
+ elif paramname == "type":
+ assert type(Element.__table__.c.data.type) is type(override_value)
+ else:
+ is_(getattr(Element.__table__.c.data, paramname), override_value)
+
def test_unions(self):
our_type = Numeric(10, 2)
import sqlalchemy as tsa
from sqlalchemy import ARRAY
+from sqlalchemy import BigInteger
from sqlalchemy import bindparam
from sqlalchemy import BLANK_SCHEMA
from sqlalchemy import Boolean
from sqlalchemy import Column
from sqlalchemy import column
from sqlalchemy import ColumnDefault
+from sqlalchemy import Computed
from sqlalchemy import desc
from sqlalchemy import Enum
from sqlalchemy import event
from sqlalchemy import ForeignKey
from sqlalchemy import ForeignKeyConstraint
from sqlalchemy import func
+from sqlalchemy import Identity
from sqlalchemy import Index
from sqlalchemy import Integer
from sqlalchemy import MetaData
+from sqlalchemy import Numeric
from sqlalchemy import PrimaryKeyConstraint
from sqlalchemy import schema
from sqlalchemy import select
deregister(schema.CreateColumn)
+ @testing.combinations(
+ ("default", lambda ctx: 10),
+ ("default", func.foo()),
+ ("identity_gen", Identity()),
+ ("identity_gen", Sequence("some_seq")),
+ ("identity_gen", Computed("side * side")),
+ ("onupdate", lambda ctx: 10),
+ ("onupdate", func.foo()),
+ ("server_onupdate", func.foo()),
+ ("server_default", func.foo()),
+ ("nullable", True),
+ ("nullable", False),
+ ("type", BigInteger()),
+ ("type", Enum("one", "two", "three", create_constraint=True)),
+ argnames="paramname, value",
+ )
+ def test_merge_column(
+ self,
+ paramname,
+ value,
+ ):
+
+ args = []
+ params = {}
+ if paramname == "type" or isinstance(
+ value, (Computed, Sequence, Identity)
+ ):
+ args.append(value)
+ else:
+ params[paramname] = value
+
+ source = Column(*args, **params)
+
+ target = Column()
+
+ source._merge(target)
+
+ if isinstance(value, (Computed, Identity)):
+ default = target.server_default
+ assert isinstance(default, type(value))
+ elif isinstance(value, Sequence):
+ default = target.default
+ assert isinstance(default, type(value))
+
+ elif paramname in (
+ "default",
+ "onupdate",
+ "server_default",
+ "server_onupdate",
+ ):
+ default = getattr(target, paramname)
+ is_(default.arg, value)
+ is_(default.column, target)
+ elif paramname == "type":
+ assert type(target.type) is type(value)
+
+ if isinstance(target.type, Enum):
+ target.name = "data"
+ t = Table("t", MetaData(), target)
+ assert CheckConstraint in [type(c) for c in t.constraints]
+ else:
+ is_(getattr(target, paramname), value)
+
+ @testing.combinations(
+ ("default", lambda ctx: 10, lambda ctx: 15),
+ ("default", func.foo(), func.bar()),
+ ("identity_gen", Identity(), Identity()),
+ ("identity_gen", Sequence("some_seq"), Sequence("some_other_seq")),
+ ("identity_gen", Computed("side * side"), Computed("top / top")),
+ ("onupdate", lambda ctx: 10, lambda ctx: 15),
+ ("onupdate", func.foo(), func.bar()),
+ ("server_onupdate", func.foo(), func.bar()),
+ ("server_default", func.foo(), func.bar()),
+ ("nullable", True, False),
+ ("nullable", False, True),
+ ("type", BigInteger(), Numeric()),
+ argnames="paramname, value, override_value",
+ )
+ def test_dont_merge_column(
+ self,
+ paramname,
+ value,
+ override_value,
+ ):
+
+ args = []
+ params = {}
+ override_args = []
+ override_params = {}
+ if paramname == "type" or isinstance(
+ value, (Computed, Sequence, Identity)
+ ):
+ args.append(value)
+ override_args.append(override_value)
+ else:
+ params[paramname] = value
+ override_params[paramname] = override_value
+
+ source = Column(*args, **params)
+
+ target = Column(*override_args, **override_params)
+
+ source._merge(target)
+
+ if isinstance(value, Sequence):
+ default = target.default
+ assert default is override_value
+ elif isinstance(value, (Computed, Identity)):
+ default = target.server_default
+ assert default is override_value
+ elif paramname in (
+ "default",
+ "onupdate",
+ "server_default",
+ "server_onupdate",
+ ):
+ default = getattr(target, paramname)
+ is_(default.arg, override_value)
+ is_(default.column, target)
+ elif paramname == "type":
+ assert type(target.type) is type(override_value)
+ else:
+ is_(getattr(target, paramname), override_value)
+
class ColumnDefaultsTest(fixtures.TestBase):