From 333fa0ec15187bb7a726262e5630fe79323c46a1 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Thu, 3 Apr 2025 10:36:28 -0400 Subject: [PATCH] generalize composite bulk insert to hybrids Added new hybrid method :meth:`.hybrid_property.bulk_insert_setter` which works in a similar way as :meth:`.hybrid_property.update_expression` for bulk ORM operations. A user-defined class method can now populate a bulk insert mapping dictionary using the desired hybrid mechanics. New documentation is added showing how both of these methods can be used including in combination with the new :func:`_sql.from_dml_column` construct. Fixes: #12496 Change-Id: I39f6793538f14314e0147765fa2d780b7c99493e --- doc/build/changelog/migration_21.rst | 90 ++++ doc/build/changelog/unreleased_21/12496.rst | 15 + lib/sqlalchemy/ext/hybrid.py | 450 ++++++++++++++++--- lib/sqlalchemy/orm/attributes.py | 5 + lib/sqlalchemy/orm/bulk_persistence.py | 41 +- lib/sqlalchemy/orm/descriptor_props.py | 3 + lib/sqlalchemy/orm/interfaces.py | 5 + lib/sqlalchemy/orm/mapper.py | 7 +- lib/sqlalchemy/sql/_elements_constructors.py | 2 +- lib/sqlalchemy/sql/base.py | 27 +- lib/sqlalchemy/sql/expression.py | 1 + test/ext/test_hybrid.py | 415 +++++++++++++++++ 12 files changed, 992 insertions(+), 69 deletions(-) diff --git a/doc/build/changelog/migration_21.rst b/doc/build/changelog/migration_21.rst index 5634cdda64..cff0f3b52e 100644 --- a/doc/build/changelog/migration_21.rst +++ b/doc/build/changelog/migration_21.rst @@ -418,6 +418,96 @@ required if using the connection string directly with ``pyodbc.connect()``). :ticket:`11250` +.. _change_12496: + +New Hybrid DML hook features +---------------------------- + +To complement the existing :meth:`.hybrid_property.update_expression` decorator, +a new decorator :meth:`.hybrid_property.bulk_dml` is added, which works +specifically with parameter dictionaries passed to :meth:`_orm.Session.execute` +when dealing with ORM-enabled :func:`_dml.insert` or :func:`_dml.update`:: + + from typing import MutableMapping + from dataclasses import dataclass + + + @dataclass + class Point: + x: int + y: int + + + class Location(Base): + __tablename__ = "location" + + id: Mapped[int] = mapped_column(primary_key=True) + x: Mapped[int] + y: Mapped[int] + + @hybrid_property + def coordinates(self) -> Point: + return Point(self.x, self.y) + + @coordinates.inplace.bulk_dml + @classmethod + def _coordinates_bulk_dml( + cls, mapping: MutableMapping[str, Any], value: Point + ) -> None: + mapping["x"] = value.x + mapping["y"] = value.y + +Additionally, a new helper :func:`_sql.from_dml_column` is added, which may be +used with the :meth:`.hybrid_property.update_expression` hook to indicate +re-use of a column expression from elsewhere in the UPDATE statement's SET +clause:: + + from sqlalchemy import from_dml_column + + + class Product(Base): + __tablename__ = "product" + + id: Mapped[int] = mapped_column(primary_key=True) + price: Mapped[float] + tax_rate: Mapped[float] + + @hybrid_property + def total_price(self) -> float: + return self.price * (1 + self.tax_rate) + + @total_price.inplace.update_expression + @classmethod + def _total_price_update_expression(cls, value: Any) -> List[Tuple[Any, Any]]: + return [(cls.price, value / (1 + from_dml_column(cls.tax_rate)))] + +In the above example, if the ``tax_rate`` column is also indicated in the +SET clause of the UPDATE, that expression will be used for the ``total_price`` +expression rather than making use of the previous value of the ``tax_rate`` +column: + +.. sourcecode:: pycon+sql + + >>> from sqlalchemy import update + >>> print(update(Product).values({Product.tax_rate: 0.08, Product.total_price: 125.00})) + {printsql}UPDATE product SET tax_rate=:tax_rate, price=(:param_1 / (:tax_rate + :param_2)) + +When the target column is omitted, :func:`_sql.from_dml_column` falls back to +using the original column expression: + +.. sourcecode:: pycon+sql + + >>> from sqlalchemy import update + >>> print(update(Product).values({Product.total_price: 125.00})) + {printsql}UPDATE product SET price=(:param_1 / (tax_rate + :param_2)) + + +.. seealso:: + + :ref:`hybrid_bulk_update` + +:ticket:`12496` + .. _change_10556: Addition of ``BitString`` subclass for handling postgresql ``BIT`` columns diff --git a/doc/build/changelog/unreleased_21/12496.rst b/doc/build/changelog/unreleased_21/12496.rst index 77d8ffb7d3..78bc102443 100644 --- a/doc/build/changelog/unreleased_21/12496.rst +++ b/doc/build/changelog/unreleased_21/12496.rst @@ -9,3 +9,18 @@ is mostly intended to be a helper with ORM :class:`.hybrid_property` within DML hooks. +.. change:: + :tags: feature, orm + :tickets: 12496 + + Added new hybrid method :meth:`.hybrid_property.bulk_dml` which + works in a similar way as :meth:`.hybrid_property.update_expression` for + bulk ORM operations. A user-defined class method can now populate a bulk + insert mapping dictionary using the desired hybrid mechanics. New + documentation is added showing how both of these methods can be used + including in combination with the new :func:`_sql.from_dml_column` + construct. + + .. seealso:: + + :ref:`change_12496` diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py index cbf5e591c1..fe1f336852 100644 --- a/lib/sqlalchemy/ext/hybrid.py +++ b/lib/sqlalchemy/ext/hybrid.py @@ -320,59 +320,140 @@ The ``length(self, value)`` method is now called upon set:: .. _hybrid_bulk_update: -Allowing Bulk ORM Update ------------------------- +Supporting ORM Bulk INSERT and UPDATE +------------------------------------- -A hybrid can define a custom "UPDATE" handler for when using -ORM-enabled updates, allowing the hybrid to be used in the -SET clause of the update. +Hybrids have support for use in ORM Bulk INSERT/UPDATE operations described +at :ref:`orm_expression_update_delete`. There are two distinct hooks +that may be used supply a hybrid value within a DML operation: -Normally, when using a hybrid with :func:`_sql.update`, the SQL -expression is used as the column that's the target of the SET. If our -``Interval`` class had a hybrid ``start_point`` that linked to -``Interval.start``, this could be substituted directly:: +1. The :meth:`.hybrid_property.update_expression` hook indicates a method that + can provide one or more expressions to render in the SET clause of an + UPDATE or INSERT statement, in response to when a hybrid attribute is referenced + directly in the :meth:`.UpdateBase.values` method; i.e. the use shown + in :ref:`orm_queryguide_update_delete_where` and :ref:`orm_queryguide_insert_values` - from sqlalchemy import update +2. The :meth:`.hybrid_property.bulk_dml` hook indicates a method that + can intercept individual parameter dictionaries sent to :meth:`_orm.Session.execute`, + i.e. the use shown at :ref:`orm_queryguide_bulk_insert` as well + as :ref:`orm_queryguide_bulk_update`. - stmt = update(Interval).values({Interval.start_point: 10}) +Using update_expression with update.values() and insert.values() +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -However, when using a composite hybrid like ``Interval.length``, this -hybrid represents more than one column. We can set up a handler that will -accommodate a value passed in the VALUES expression which can affect -this, using the :meth:`.hybrid_property.update_expression` decorator. -A handler that works similarly to our setter would be:: +The :meth:`.hybrid_property.update_expression` decorator indicates a method +that is invoked when a hybrid is used in the :meth:`.ValuesBase.values` clause +of an :func:`_sql.update` or :func:`_sql.insert` statement. It returns a list +of tuple pairs ``[(x1, y1), (x2, y2), ...]`` which will expand into the SET +clause of an UPDATE statement as ``SET x1=y1, x2=y2, ...``. - from typing import List, Tuple, Any +The :func:`_sql.from_dml_column` construct is often useful as it can create a +SQL expression that refers to another column that may also present in the same +INSERT or UPDATE statement, alternatively falling back to referring to the +original column if such an expression is not present. +In the example below, the ``total_price`` hybrid will derive the ``price`` +column, by taking the given "total price" value and dividing it by a +``tax_rate`` value that is also present in the :meth:`.ValuesBase.values` call:: - class Interval(Base): - # ... + from sqlalchemy import from_dml_column - @hybrid_property - def length(self) -> int: - return self.end - self.start - @length.inplace.setter - def _length_setter(self, value: int) -> None: - self.end = self.start + value + class Product(Base): + __tablename__ = "product" - @length.inplace.update_expression - def _length_update_expression( + id: Mapped[int] = mapped_column(primary_key=True) + price: Mapped[float] + tax_rate: Mapped[float] + + @hybrid_property + def total_price(self) -> float: + return self.price * (1 + self.tax_rate) + + @total_price.inplace.update_expression + @classmethod + def _total_price_update_expression( cls, value: Any ) -> List[Tuple[Any, Any]]: - return [(cls.end, cls.start + value)] + return [(cls.price, value / (1 + from_dml_column(cls.tax_rate)))] -Above, if we use ``Interval.length`` in an UPDATE expression, we get -a hybrid SET expression: +When used in an UPDATE statement, :func:`_sql.from_dml_column` creates a +reference to the ``tax_rate`` column that will use the value passed to +the :meth:`.ValuesBase.values` method, rather than the existing value on the column +in the database. This allows the hybrid to access other values being +updated in the same statement: .. sourcecode:: pycon+sql + >>> from sqlalchemy import update + >>> print( + ... update(Product).values( + ... {Product.tax_rate: 0.08, Product.total_price: 125.00} + ... ) + ... ) + {printsql}UPDATE product SET tax_rate=:tax_rate, price=(:total_price / (:tax_rate + :param_1)) + +When the column referenced by :func:`_sql.from_dml_column` (in this case ``product.tax_rate``) +is omitted from :meth:`.ValuesBase.values`, the rendered expression falls back to +using the original column: + +.. sourcecode:: pycon+sql >>> from sqlalchemy import update - >>> print(update(Interval).values({Interval.length: 25})) - {printsql}UPDATE interval SET "end"=(interval.start + :start_1) + >>> print(update(Product).values({Product.total_price: 125.00})) + {printsql}UPDATE product SET price=(:total_price / (tax_rate + :param_1)) + + + +Using bulk_dml to intercept bulk parameter dictionaries +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. versionadded:: 2.1 -This SET expression is accommodated by the ORM automatically. +For bulk operations that pass a list of parameter dictionaries to +methods like :meth:`.Session.execute`, the +:meth:`.hybrid_property.bulk_dml` decorator provides a hook that can +receive each dictionary and populate it with new values. + +The implementation for the :meth:`.hybrid_property.bulk_dml` hook can retrieve +other column values from the parameter dictionary:: + + from typing import MutableMapping + + + class Product(Base): + __tablename__ = "product" + + id: Mapped[int] = mapped_column(primary_key=True) + price: Mapped[float] + tax_rate: Mapped[float] + + @hybrid_property + def total_price(self) -> float: + return self.price * (1 + self.tax_rate) + + @total_price.inplace.bulk_dml + @classmethod + def _total_price_bulk_dml( + cls, mapping: MutableMapping[str, Any], value: float + ) -> None: + mapping["price"] = value / (1 + mapping["tax_rate"]) + +This allows for bulk INSERT/UPDATE with derived values:: + + # Bulk INSERT + session.execute( + insert(Product), + [ + {"tax_rate": 0.08, "total_price": 125.00}, + {"tax_rate": 0.05, "total_price": 110.00}, + ], + ) + +Note that the method decorated by :meth:`.hybrid_property.bulk_dml` is invoked +only with parameter dictionaries and does not have the ability to use +SQL expressions in the given dictionaries, only literal Python values that will +be passed to parameters in the INSERT or UPDATE statement. .. seealso:: @@ -731,31 +812,36 @@ reference the instrumented attribute back to the hybrid object:: def name(cls): return func.concat(cls.first_name, " ", cls.last_name) +.. _hybrid_value_objects: + Hybrid Value Objects -------------------- -Note in our previous example, if we were to compare the ``word_insensitive`` +In the example shown previously at :ref:`hybrid_custom_comparators`, +if we were to compare the ``word_insensitive`` attribute of a ``SearchWord`` instance to a plain Python string, the plain Python string would not be coerced to lower case - the ``CaseInsensitiveComparator`` we built, being returned by ``@word_insensitive.comparator``, only applies to the SQL side. -A more comprehensive form of the custom comparator is to construct a *Hybrid -Value Object*. This technique applies the target value or expression to a value +A more comprehensive form of the custom comparator is to construct a **Hybrid +Value Object**. This technique applies the target value or expression to a value object which is then returned by the accessor in all cases. The value object allows control of all operations upon the value as well as how compared values are treated, both on the SQL expression side as well as the Python value side. Replacing the previous ``CaseInsensitiveComparator`` class with a new ``CaseInsensitiveWord`` class:: + from sqlalchemy import func + from sqlalchemy.ext.hybrid import Comparator + + class CaseInsensitiveWord(Comparator): "Hybrid value representing a lower case representation of a word." def __init__(self, word): - if isinstance(word, basestring): + if isinstance(word, str): self.word = word.lower() - elif isinstance(word, CaseInsensitiveWord): - self.word = word.word else: self.word = func.lower(word) @@ -774,11 +860,50 @@ Replacing the previous ``CaseInsensitiveComparator`` class with a new "Label to apply to Query tuple results" Above, the ``CaseInsensitiveWord`` object represents ``self.word``, which may -be a SQL function, or may be a Python native. By overriding ``operate()`` and -``__clause_element__()`` to work in terms of ``self.word``, all comparison -operations will work against the "converted" form of ``word``, whether it be -SQL side or Python side. Our ``SearchWord`` class can now deliver the -``CaseInsensitiveWord`` object unconditionally from a single hybrid call:: +be a SQL function, or may be a Python native string. The hybrid value object should +implement ``__clause_element__()``, which allows the object to be coerced into +a SQL-capable value when used in SQL expression constructs, as well as Python +comparison methods such as ``__eq__()``, which is accomplished in the above +example by subclassing :class:`.hybrid.Comparator` and overriding the +``operate()`` method. + +.. topic:: Building the Value object with dataclasses + + Hybrid value objects may also be implemented as Python dataclasses. If + modification to values upon construction is needed, use the + ``__post_init__()`` dataclasses method. Instance variables that work in + a "hybrid" fashion may be instance of a plain Python value, or an instance + of :class:`.SQLColumnExpression` genericized against that type. Also make sure to disable + dataclass comparison features, as the :class:`.hybrid.Comparator` class + provides these:: + + from sqlalchemy import func + from sqlalchemy.ext.hybrid import Comparator + from dataclasses import dataclass + + + @dataclass(eq=False) + class CaseInsensitiveWord(Comparator): + word: str | SQLColumnExpression[str] + + def __post_init__(self): + if isinstance(self.word, str): + self.word = self.word.lower() + else: + self.word = func.lower(self.word) + + def operate(self, op, other, **kwargs): + if not isinstance(other, CaseInsensitiveWord): + other = CaseInsensitiveWord(other) + return op(self.word, other.word, **kwargs) + + def __clause_element__(self): + return self.word + +With ``__clause_element__()`` provided, our ``SearchWord`` class +can now deliver the ``CaseInsensitiveWord`` object unconditionally from a +single hybrid method, returning an object that behaves appropriately +in both value-based and SQL contexts:: class SearchWord(Base): __tablename__ = "searchword" @@ -789,18 +914,20 @@ SQL side or Python side. Our ``SearchWord`` class can now deliver the def word_insensitive(self) -> CaseInsensitiveWord: return CaseInsensitiveWord(self.word) -The ``word_insensitive`` attribute now has case-insensitive comparison behavior -universally, including SQL expression vs. Python expression (note the Python -value is converted to lower case on the Python side here): +The class-level version of ``CaseInsensitiveWord`` will work in SQL +constructs: .. sourcecode:: pycon+sql - >>> print(select(SearchWord).filter_by(word_insensitive="Trucks")) + >>> print(select(SearchWord).filter(SearchWord.word_insensitive == "Trucks")) {printsql}SELECT searchword.id AS searchword_id, searchword.word AS searchword_word FROM searchword WHERE lower(searchword.word) = :lower_1 -SQL expression versus SQL expression: +By also subclassing :class:`.hybrid.Comparator` and providing an implementation +for ``operate()``, the ``word_insensitive`` attribute also has case-insensitive +comparison behavior universally, including SQL expression and Python expression +(note the Python value is converted to lower case on the Python side here): .. sourcecode:: pycon+sql @@ -841,6 +968,176 @@ measurement, currencies and encrypted passwords. `_ - on the techspot.zzzeek.org blog +.. _composite_hybrid_value_objects: + +Composite Hybrid Value Objects +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The functionality of :ref:`hybrid_value_objects` may also be expanded to +support "composite" forms; in this pattern, SQLAlchemy hybrids begin to +approximate most (though not all) the same functionality that is available from +the ORM natively via the :ref:`mapper_composite` feature. We can imitate the +example of ``Point`` and ``Vertex`` from that section using hybrids, where +``Point`` is modified to become a Hybrid Value Object:: + + from dataclasses import dataclass + + from sqlalchemy import tuple_ + from sqlalchemy.ext.hybrid import Comparator + from sqlalchemy import SQLColumnExpression + + + @dataclass(eq=False) + class Point(Comparator): + x: int | SQLColumnExpression[int] + y: int | SQLColumnExpression[int] + + def operate(self, op, other, **kwargs): + return op(self.x, other.x) & op(self.y, other.y) + + def __clause_element__(self): + return tuple_(self.x, self.y) + +Above, the ``operate()`` method is where the most "hybrid" behavior takes +place, making use of ``op()`` (the Python operator function in use) along +with the the bitwise ``&`` operator provides us with the SQL AND operator +in a SQL context, and boolean "and" in a Python boolean context. + +Following from there, the owning ``Vertex`` class now uses hybrids to +represent ``start`` and ``end``:: + + from sqlalchemy.orm import DeclarativeBase, Mapped + from sqlalchemy.orm import mapped_column + from sqlalchemy.ext.hybrid import hybrid_property + + + class Base(DeclarativeBase): + pass + + + class Vertex(Base): + __tablename__ = "vertices" + + id: Mapped[int] = mapped_column(primary_key=True) + + x1: Mapped[int] + y1: Mapped[int] + x2: Mapped[int] + y2: Mapped[int] + + @hybrid_property + def start(self) -> Point: + return Point(self.x1, self.y1) + + @start.inplace.setter + def _set_start(self, value: Point) -> None: + self.x1 = value.x + self.y1 = value.y + + @hybrid_property + def end(self) -> Point: + return Point(self.x2, self.y2) + + @end.inplace.setter + def _set_end(self, value: Point) -> None: + self.x2 = value.x + self.y2 = value.y + + def __repr__(self) -> str: + return f"Vertex(start={self.start}, end={self.end})" + +Using the above mapping, we can use expressions at the Python or SQL level +using ``Vertex.start`` and ``Vertex.end``:: + + >>> v1 = Vertex(start=Point(3, 4), end=Point(15, 10)) + >>> v1.end == Point(15, 10) + True + >>> stmt = ( + ... select(Vertex) + ... .where(Vertex.start == Point(3, 4)) + ... .where(Vertex.end < Point(7, 8)) + ... ) + >>> print(stmt) + SELECT vertices.id, vertices.x1, vertices.y1, vertices.x2, vertices.y2 + FROM vertices + WHERE vertices.x1 = :x1_1 AND vertices.y1 = :y1_1 AND vertices.x2 < :x2_1 AND vertices.y2 < :y2_1 + +DML Support for Composite Value Objects +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Composite value objects like ``Point`` can also be used with the ORM's +DML features. The :meth:`.hybrid_property.update_expression` decorator allows +the hybrid to expand a composite value into multiple column assignments +in UPDATE and INSERT statements:: + + class Location(Base): + __tablename__ = "location" + + id: Mapped[int] = mapped_column(primary_key=True) + x: Mapped[int] + y: Mapped[int] + + @hybrid_property + def coordinates(self) -> Point: + return Point(self.x, self.y) + + @coordinates.inplace.update_expression + @classmethod + def _coordinates_update_expression( + cls, value: Any + ) -> List[Tuple[Any, Any]]: + assert isinstance(value, Point) + return [(cls.x, value.x), (cls.y, value.y)] + +This allows UPDATE statements to work with the composite value: + +.. sourcecode:: pycon+sql + + >>> from sqlalchemy import update + >>> print( + ... update(Location) + ... .where(Location.id == 5) + ... .values({Location.coordinates: Point(25, 17)}) + ... ) + {printsql}UPDATE location SET x=:x, y=:y WHERE location.id = :id_1 + +For bulk operations that use parameter dictionaries, the +:meth:`.hybrid_property.bulk_dml` decorator provides a hook to +convert composite values into individual column values:: + + from typing import MutableMapping + + + class Location(Base): + # ... (same as above) + + @coordinates.inplace.bulk_dml + @classmethod + def _coordinates_bulk_dml( + cls, mapping: MutableMapping[str, Any], value: Point + ) -> None: + mapping["x"] = value.x + mapping["y"] = value.y + +This enables bulk operations with composite values:: + + # Bulk INSERT + session.execute( + insert(Location), + [ + {"id": 1, "coordinates": Point(10, 20)}, + {"id": 2, "coordinates": Point(30, 40)}, + ], + ) + + # Bulk UPDATE + session.execute( + update(Location), + [ + {"id": 1, "coordinates": Point(15, 25)}, + {"id": 2, "coordinates": Point(35, 45)}, + ], + ) """ # noqa @@ -851,6 +1148,7 @@ from typing import Callable from typing import cast from typing import Generic from typing import List +from typing import MutableMapping from typing import Optional from typing import overload from typing import Protocol @@ -861,6 +1159,7 @@ from typing import TYPE_CHECKING from typing import TypeVar from typing import Union +from .. import exc from .. import util from ..orm import attributes from ..orm import InspectionAttrExtensionType @@ -938,6 +1237,15 @@ class _HybridUpdaterType(Protocol[_T_con]): ) -> List[Tuple[_DMLColumnArgument, Any]]: ... +class _HybridBulkDMLType(Protocol[_T_co]): + def __call__( + s, + cls: Any, + mapping: MutableMapping[str, Any], + value: Any, + ) -> Any: ... + + class _HybridDeleterType(Protocol[_T_co]): def __call__(s, self: Any) -> None: ... @@ -979,6 +1287,10 @@ class _HybridClassLevelAccessor(QueryableAttribute[_T]): self, meth: _HybridUpdaterType[_T] ) -> hybrid_property[_T]: ... + def bulk_dml( + self, meth: _HybridBulkDMLType[_T] + ) -> hybrid_property[_T]: ... + class hybrid_method(interfaces.InspectionAttrInfo, Generic[_P, _R]): """A decorator which allows definition of a Python object method with both @@ -1093,6 +1405,7 @@ class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]): expr: Optional[_HybridExprCallableType[_T]] = None, custom_comparator: Optional[Comparator[_T]] = None, update_expr: Optional[_HybridUpdaterType[_T]] = None, + bulk_dml_setter: Optional[_HybridBulkDMLType[_T]] = None, ): """Create a new :class:`.hybrid_property`. @@ -1117,6 +1430,7 @@ class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]): self.expr = _unwrap_classmethod(expr) self.custom_comparator = _unwrap_classmethod(custom_comparator) self.update_expr = _unwrap_classmethod(update_expr) + self.bulk_dml_setter = _unwrap_classmethod(bulk_dml_setter) util.update_wrapper(self, fget) # type: ignore[arg-type] @overload @@ -1237,6 +1551,11 @@ class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]): ) -> hybrid_property[_TE]: return self._set(update_expr=meth) + def bulk_dml( + self, meth: _HybridBulkDMLType[_TE] + ) -> hybrid_property[_TE]: + return self._set(bulk_dml_setter=meth) + @property def inplace(self) -> _InPlace[_T]: """Return the inplace mutator for this :class:`.hybrid_property`. @@ -1388,6 +1707,14 @@ class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]): """ return self._copy(update_expr=meth) + def bulk_dml(self, meth: _HybridBulkDMLType[_T]) -> hybrid_property[_T]: + """Define a setter for bulk dml. + + .. versionadded:: 2.1 + + """ + return self._copy(bulk_dml=meth) + @util.memoized_property def _expr_comparator( self, @@ -1498,7 +1825,8 @@ class ExprComparator(Comparator[_T]): return self.hybrid.info def _bulk_update_tuples( - self, value: Any + self, + value: Any, ) -> Sequence[Tuple[_DMLColumnArgument, Any]]: if isinstance(self.expression, attributes.QueryableAttribute): return self.expression._bulk_update_tuples(value) @@ -1507,6 +1835,28 @@ class ExprComparator(Comparator[_T]): else: return [(self.expression, value)] + def _bulk_dml_setter(self, key: str) -> Optional[Callable[..., Any]]: + """return a callable that will process a bulk INSERT value""" + + meth = None + + def prop(mapping: MutableMapping[str, Any]) -> None: + nonlocal meth + value = mapping[key] + + if meth is None: + if self.hybrid.bulk_dml_setter is None: + raise exc.InvalidRequestError( + "Can't evaluate bulk DML statement; please " + "supply a bulk_dml decorated function" + ) + + meth = self.hybrid.bulk_dml_setter + + meth(self.cls, mapping, value) + + return prop + @util.non_memoized_property def property(self) -> MapperProperty[_T]: # this accessor is not normally used, however is accessed by things diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index bd229a271d..46462049cc 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -392,6 +392,11 @@ class QueryableAttribute( return self.comparator._bulk_update_tuples(value) + def _bulk_dml_setter(self, key: str) -> Optional[Callable[..., Any]]: + """return a callable that will process a bulk INSERT value""" + + return self.comparator._bulk_dml_setter(key) + def adapt_to_entity(self, adapt_to_entity: AliasedInsp[Any]) -> Self: assert not self._of_type return self.__class__( diff --git a/lib/sqlalchemy/orm/bulk_persistence.py b/lib/sqlalchemy/orm/bulk_persistence.py index 7918c3ba84..8e813d667a 100644 --- a/lib/sqlalchemy/orm/bulk_persistence.py +++ b/lib/sqlalchemy/orm/bulk_persistence.py @@ -35,6 +35,7 @@ from .context import _AbstractORMCompileState from .context import _ORMFromStatementCompileState from .context import FromStatement from .context import QueryContext +from .interfaces import PropComparator from .. import exc as sa_exc from .. import util from ..engine import Dialect @@ -150,7 +151,7 @@ def _bulk_insert( # for all other cases we need to establish a local dictionary # so that the incoming dictionaries aren't mutated mappings = [dict(m) for m in mappings] - _expand_composites(mapper, mappings) + _expand_other_attrs(mapper, mappings) connection = session_transaction.connection(base_mapper) @@ -309,7 +310,7 @@ def _bulk_update( mappings = [state.dict for state in mappings] else: mappings = [dict(m) for m in mappings] - _expand_composites(mapper, mappings) + _expand_other_attrs(mapper, mappings) if session_transaction.session.connection_callable: raise NotImplementedError( @@ -371,19 +372,32 @@ def _bulk_update( return _result.null_result() -def _expand_composites(mapper, mappings): - composite_attrs = mapper.composites - if not composite_attrs: - return +def _expand_other_attrs( + mapper: Mapper[Any], mappings: Iterable[Dict[str, Any]] +) -> None: + all_attrs = mapper.all_orm_descriptors + + attr_keys = set(all_attrs.keys()) - composite_keys = set(composite_attrs.keys()) - populators = { - key: composite_attrs[key]._populate_composite_bulk_save_mappings_fn() - for key in composite_keys + bulk_dml_setters = { + key: setter + for key, setter in ( + (key, attr._bulk_dml_setter(key)) + for key, attr in ( + (key, _entity_namespace_key(mapper, key, default=NO_VALUE)) + for key in attr_keys + ) + if attr is not NO_VALUE and isinstance(attr, PropComparator) + ) + if setter is not None } + setters_todo = set(bulk_dml_setters) + if not setters_todo: + return + for mapping in mappings: - for key in composite_keys.intersection(mapping): - populators[key](mapping) + for key in setters_todo.intersection(mapping): + bulk_dml_setters[key](mapping) class _ORMDMLState(_AbstractORMCompileState): @@ -401,7 +415,7 @@ class _ORMDMLState(_AbstractORMCompileState): if isinstance(k, str): desc = _entity_namespace_key(mapper, k, default=NO_VALUE) - if desc is NO_VALUE: + if not isinstance(desc, PropComparator): yield ( coercions.expect(roles.DMLColumnRole, k), ( @@ -426,6 +440,7 @@ class _ORMDMLState(_AbstractORMCompileState): attr = _entity_namespace_key( k_anno["entity_namespace"], k_anno["proxy_key"] ) + assert isinstance(attr, PropComparator) yield from core_get_crud_kv_pairs( statement, attr._bulk_update_tuples(v), diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index d5f7bcc876..287d065b0b 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -810,6 +810,9 @@ class CompositeProperty( return list(zip(self._comparable_elements, values)) + def _bulk_dml_setter(self, key: str) -> Optional[Callable[..., Any]]: + return self.prop._populate_composite_bulk_save_mappings_fn() + @util.memoized_property def _comparable_elements(self) -> Sequence[QueryableAttribute[Any]]: if self._adapt_to_entity: diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 266acbc472..71bcdcdf00 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -891,6 +891,11 @@ class PropComparator(SQLORMOperations[_T_co], Generic[_T_co], ColumnOperators): return [(cast("_DMLColumnArgument", self.__clause_element__()), value)] + def _bulk_dml_setter(self, key: str) -> Optional[Callable[..., Any]]: + """return a callable that will process a bulk INSERT value""" + + return None + def adapt_to_entity( self, adapt_to_entity: AliasedInsp[Any] ) -> PropComparator[_T_co]: diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 2f8bebee51..9bc5cc055d 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -112,6 +112,7 @@ if TYPE_CHECKING: from ..engine import RowMapping from ..sql._typing import _ColumnExpressionArgument from ..sql._typing import _EquivalentColumnMap + from ..sql.base import _EntityNamespace from ..sql.base import ReadOnlyColumnCollection from ..sql.elements import ColumnClause from ..sql.elements import ColumnElement @@ -3096,9 +3097,9 @@ class Mapper( return self._filter_properties(descriptor_props.SynonymProperty) - @property - def entity_namespace(self): - return self.class_ + @util.ro_non_memoized_property + def entity_namespace(self) -> _EntityNamespace: + return self.class_ # type: ignore[return-value] @HasMemoized.memoized_attribute def column_attrs(self) -> util.ReadOnlyProperties[ColumnProperty[Any]]: diff --git a/lib/sqlalchemy/sql/_elements_constructors.py b/lib/sqlalchemy/sql/_elements_constructors.py index abb5b14b4c..7fe4abb545 100644 --- a/lib/sqlalchemy/sql/_elements_constructors.py +++ b/lib/sqlalchemy/sql/_elements_constructors.py @@ -1750,7 +1750,7 @@ def true() -> True_: def tuple_( - *clauses: _ColumnExpressionArgument[Any], + *clauses: _ColumnExpressionOrLiteralArgument[Any], types: Optional[Sequence[_TypeEngineArgument[Any]]] = None, ) -> Tuple: """Return a :class:`.Tuple`. diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index fe6cdf6a07..9381954fa6 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -2384,11 +2384,34 @@ def _entity_namespace( raise +@overload def _entity_namespace_key( entity: Union[_HasEntityNamespace, ExternallyTraversible], key: str, - default: Union[SQLCoreOperations[Any], _NoArg] = NO_ARG, -) -> SQLCoreOperations[Any]: +) -> SQLCoreOperations[Any]: ... + + +@overload +def _entity_namespace_key( + entity: Union[_HasEntityNamespace, ExternallyTraversible], + key: str, + default: _NoArg, +) -> SQLCoreOperations[Any]: ... + + +@overload +def _entity_namespace_key( + entity: Union[_HasEntityNamespace, ExternallyTraversible], + key: str, + default: _T, +) -> Union[SQLCoreOperations[Any], _T]: ... + + +def _entity_namespace_key( + entity: Union[_HasEntityNamespace, ExternallyTraversible], + key: str, + default: Union[SQLCoreOperations[Any], _T, _NoArg] = NO_ARG, +) -> Union[SQLCoreOperations[Any], _T]: """Return an entry from an entity_namespace. diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 5abb4e3ec5..2b6df2e7cf 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -82,6 +82,7 @@ from .elements import ClauseList as ClauseList from .elements import CollectionAggregate as CollectionAggregate from .elements import ColumnClause as ColumnClause from .elements import ColumnElement as ColumnElement +from .elements import DMLTargetCopy as DMLTargetCopy from .elements import ExpressionClauseList as ExpressionClauseList from .elements import Extract as Extract from .elements import False_ as False_ diff --git a/test/ext/test_hybrid.py b/test/ext/test_hybrid.py index 09da020743..ac4274dd67 100644 --- a/test/ext/test_hybrid.py +++ b/test/ext/test_hybrid.py @@ -1,8 +1,13 @@ +from __future__ import annotations + +import dataclasses from decimal import Decimal +from typing import TYPE_CHECKING from sqlalchemy import column from sqlalchemy import exc from sqlalchemy import ForeignKey +from sqlalchemy import from_dml_column from sqlalchemy import func from sqlalchemy import insert from sqlalchemy import inspect @@ -14,11 +19,14 @@ from sqlalchemy import Numeric from sqlalchemy import select from sqlalchemy import String from sqlalchemy import testing +from sqlalchemy import tuple_ from sqlalchemy.ext import hybrid from sqlalchemy.orm import aliased from sqlalchemy.orm import column_property from sqlalchemy.orm import declarative_base from sqlalchemy.orm import declared_attr +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column from sqlalchemy.orm import relationship from sqlalchemy.orm import Session from sqlalchemy.orm import synonym @@ -30,10 +38,13 @@ from sqlalchemy.sql import update from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import is_false from sqlalchemy.testing import is_not +from sqlalchemy.testing.assertsql import CompiledSQL +from sqlalchemy.testing.assertsql import Conditional from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column @@ -1247,6 +1258,11 @@ class MethodExpressionTest(fixtures.TestBase, AssertsCompiledSQL): class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL): + """Original DML test suite when we first added the ability for ORM + UPDATE to handle hybrid values. + + """ + __dialect__ = "default" @classmethod @@ -1534,6 +1550,405 @@ class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL): eq_(s.scalar(select(Person.first_name).where(Person.id == 3)), "first") +if TYPE_CHECKING: + from sqlalchemy.sql import SQLColumnExpression + + +@dataclasses.dataclass(eq=False) +class Point(hybrid.Comparator): + x: int | SQLColumnExpression[int] + y: int | SQLColumnExpression[int] + + def operate(self, op, other, **kwargs): + return op(self.x, other.x) & op(self.y, other.y) + + def __clause_element__(self): + return tuple_(self.x, self.y) + + +class DMLTest( + fixtures.TestBase, AssertsCompiledSQL, testing.AssertsExecutionResults +): + """updated DML test suite when #12496 was done, where we created the use + cases of "expansive" and "derived" hybrids and how their use cases + differ, and also added the bulk_dml hook as well as the from_dml_column + construct. + + + """ + + __dialect__ = "default" + + @testing.fixture + def single_plain(self, decl_base): + """fixture with a single-col hybrid""" + + class A(decl_base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(primary_key=True) + + x: Mapped[int] + + @hybrid.hybrid_property + def x_plain(self): + return self.x + + return A + + @testing.fixture + def expand_plain(self, decl_base): + """fixture with an expand hybrid (deals w/ a value object that spans + multiple columns)""" + + class A(decl_base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(primary_key=True) + + x: Mapped[int] + y: Mapped[int] + + @hybrid.hybrid_property + def xy(self): + return Point(self.x, self.y) + + return A + + @testing.fixture + def expand_update(self, decl_base): + """fixture with an expand hybrid (deals w/ a value object that spans + multiple columns)""" + + class A(decl_base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(primary_key=True) + + x: Mapped[int] + y: Mapped[int] + + @hybrid.hybrid_property + def xy(self): + return Point(self.x, self.y) + + @xy.inplace.update_expression + @classmethod + def _xy(cls, value): + return [(cls.x, value.x), (cls.y, value.y)] + + return A + + @testing.fixture + def expand_dml(self, decl_base): + """fixture with an expand hybrid (deals w/ a value object that spans + multiple columns)""" + + class A(decl_base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(primary_key=True) + + x: Mapped[int] + y: Mapped[int] + + @hybrid.hybrid_property + def xy(self): + return Point(self.x, self.y) + + @xy.inplace.bulk_dml + @classmethod + def _xy(cls, mapping, value): + mapping["x"] = value.x + mapping["y"] = value.y + + return A + + @testing.fixture + def derived_update(self, decl_base): + """fixture with a derive hybrid (value is derived from other columns + with data that's not in the value object itself) + """ + + class A(decl_base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(primary_key=True) + + amount: Mapped[int] + rate: Mapped[float] + + @hybrid.hybrid_property + def adjusted_amount(self): + return self.amount * self.rate + + @adjusted_amount.inplace.update_expression + @classmethod + def _adjusted_amount(cls, value): + return [(cls.amount, value / from_dml_column(cls.rate))] + + return A + + @testing.fixture + def derived_dml(self, decl_base): + """fixture with a derive hybrid (value is derived from other columns + with data that's not in the value object itself) + """ + + class A(decl_base): + __tablename__ = "a" + id: Mapped[int] = mapped_column(primary_key=True) + + amount: Mapped[int] + rate: Mapped[float] + + @hybrid.hybrid_property + def adjusted_amount(self): + return self.amount * self.rate + + @adjusted_amount.inplace.bulk_dml + @classmethod + def _adjusted_amount(cls, mapping, value): + mapping["amount"] = value / mapping["rate"] + + return A + + def test_single_plain_update_values(self, single_plain): + A = single_plain + self.assert_compile( + update(A).values({A.x_plain: 10}), + "UPDATE a SET x=:x", + checkparams={"x": 10}, + ) + + def test_single_plain_insert_values(self, single_plain): + A = single_plain + self.assert_compile( + insert(A).values({A.x_plain: 10}), + "INSERT INTO a (x) VALUES (:x)", + checkparams={"x": 10}, + ) + + @testing.variation("crud", ["insert", "update"]) + def test_single_plain_bulk(self, crud, decl_base, single_plain): + A = single_plain + + decl_base.metadata.create_all(testing.db) + + with expect_raises_message( + exc.InvalidRequestError, + "Can't evaluate bulk DML statement; " + "please supply a bulk_dml decorated function", + ): + with Session(testing.db) as session: + session.execute( + insert(A) if crud.insert else update(A), + [ + {"x_plain": 10}, + {"x_plain": 11}, + ], + ) + + @testing.variation("keytype", ["attr", "string"]) + def test_expand_plain_update_values(self, expand_plain, keytype): + A = expand_plain + + # SQL tuple_ update happens instead due to __clause_element__ + self.assert_compile( + update(A) + .where(A.xy == Point(10, 12)) + .values({"xy" if keytype.string else A.xy: Point(5, 6)}), + "UPDATE a SET (x, y)=(:param_1, :param_2) " + "WHERE a.x = :x_1 AND a.y = :y_1", + {"param_1": 5, "param_2": 6, "x_1": 10, "y_1": 12}, + ) + + @testing.variation("crud", ["insert", "update"]) + def test_expand_update_bulk(self, crud, expand_update, decl_base): + A = expand_update + decl_base.metadata.create_all(testing.db) + + with expect_raises_message( + exc.InvalidRequestError, + "Can't evaluate bulk DML statement; " + "please supply a bulk_dml decorated function", + ): + with Session(testing.db) as session: + session.execute( + insert(A) if crud.insert else update(A), + [ + {"xy": Point(3, 4)}, + {"xy": Point(5, 6)}, + ], + ) + + @testing.variation("crud", ["insert", "update"]) + def test_expand_dml_bulk(self, crud, expand_dml, decl_base, connection): + A = expand_dml + decl_base.metadata.create_all(connection) + + with self.sql_execution_asserter(connection) as asserter: + with Session(connection) as session: + session.execute( + insert(A), + [ + {"id": 1, "xy": Point(3, 4)}, + {"id": 2, "xy": Point(5, 6)}, + ], + ) + + if crud.update: + session.execute( + update(A), + [ + {"id": 1, "xy": Point(10, 9)}, + {"id": 2, "xy": Point(7, 8)}, + ], + ) + asserter.assert_( + CompiledSQL( + "INSERT INTO a (id, x, y) VALUES (:id, :x, :y)", + [{"id": 1, "x": 3, "y": 4}, {"id": 2, "x": 5, "y": 6}], + ), + Conditional( + crud.update, + [ + CompiledSQL( + "UPDATE a SET x=:x, y=:y WHERE a.id = :a_id", + [ + {"x": 10, "y": 9, "a_id": 1}, + {"x": 7, "y": 8, "a_id": 2}, + ], + ) + ], + [], + ), + ) + + @testing.variation("keytype", ["attr", "string"]) + def test_expand_update_insert_values(self, expand_update, keytype): + A = expand_update + self.assert_compile( + insert(A).values({"xy" if keytype.string else A.xy: Point(5, 6)}), + "INSERT INTO a (x, y) VALUES (:x, :y)", + checkparams={"x": 5, "y": 6}, + ) + + @testing.variation("keytype", ["attr", "string"]) + def test_expand_update_update_values(self, expand_update, keytype): + A = expand_update + self.assert_compile( + update(A).values({"xy" if keytype.string else A.xy: Point(5, 6)}), + "UPDATE a SET x=:x, y=:y", + checkparams={"x": 5, "y": 6}, + ) + + ##################################################### + + @testing.variation("keytype", ["attr", "string"]) + def test_derived_update_insert_values(self, derived_update, keytype): + A = derived_update + self.assert_compile( + insert(A).values( + { + "rate" if keytype.string else A.rate: 1.5, + ( + "adjusted_amount" + if keytype.string + else A.adjusted_amount + ): 25, + } + ), + "INSERT INTO a (amount, rate) VALUES " + "((:param_1 / CAST(:rate AS FLOAT)), :rate)", + checkparams={"param_1": 25, "rate": 1.5}, + ) + + @testing.variation("keytype", ["attr", "string"]) + @testing.variation("rate_present", [True, False]) + def test_derived_update_update_values( + self, derived_update, rate_present, keytype + ): + A = derived_update + + if rate_present: + # when column is present in UPDATE SET, from_dml_column + # uses that expression + self.assert_compile( + update(A).values( + { + "rate" if keytype.string else A.rate: 1.5, + ( + "adjusted_amount" + if keytype.string + else A.adjusted_amount + ): 25, + } + ), + "UPDATE a SET amount=(:param_1 / CAST(:rate AS FLOAT)), " + "rate=:rate", + checkparams={"param_1": 25, "rate": 1.5}, + ) + else: + # when column is not present in UPDATE SET, from_dml_column + # renders the column, which will work in an UPDATE, but not INSERT + self.assert_compile( + update(A).values( + { + ( + "adjusted_amount" + if keytype.string + else A.adjusted_amount + ): 25 + } + ), + "UPDATE a SET amount=(:param_1 / CAST(a.rate AS FLOAT))", + checkparams={"param_1": 25}, + ) + + @testing.variation("crud", ["insert", "update"]) + def test_derived_dml_bulk(self, crud, derived_dml, decl_base, connection): + A = derived_dml + decl_base.metadata.create_all(connection) + + with self.sql_execution_asserter(connection) as asserter: + with Session(connection) as session: + session.execute( + insert(A), + [ + {"rate": 1.5, "adjusted_amount": 25}, + {"rate": 2.5, "adjusted_amount": 25}, + ], + ) + + if crud.update: + session.execute( + update(A), + [ + {"id": 1, "rate": 1.8, "adjusted_amount": 30}, + {"id": 2, "rate": 2.8, "adjusted_amount": 40}, + ], + ) + asserter.assert_( + CompiledSQL( + "INSERT INTO a (amount, rate) VALUES (:amount, :rate)", + [ + {"amount": 25 / 1.5, "rate": 1.5}, + {"amount": 25 / 2.5, "rate": 2.5}, + ], + ), + Conditional( + crud.update, + [ + CompiledSQL( + "UPDATE a SET amount=:amount, rate=:rate " + "WHERE a.id = :a_id", + [ + {"amount": 30 / 1.8, "rate": 1.8, "a_id": 1}, + {"amount": 40 / 2.8, "rate": 2.8, "a_id": 2}, + ], + ) + ], + [], + ), + ) + + class SpecialObjectTest(fixtures.TestBase, AssertsCompiledSQL): """tests against hybrids that return a non-ClauseElement. -- 2.47.3