From: Mike Bayer Date: Thu, 16 Feb 2023 14:39:07 +0000 (-0500) Subject: modernize hybrids and apply typing X-Git-Tag: rel_2_0_4~4^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=81993801dd39dd4a5973f8500e849f35ac07f2f3;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git modernize hybrids and apply typing Improved the typing support for the :ref:`hybrids_toplevel` extension, updated all documentation to use ORM Annotated Declarative mappings, and added a new modifier called :attr:`.hybrid_property.inplace`. This modifier provides a way to alter the state of a :class:`.hybrid_property` **in place**, which is essentially what very early versions of hybrids did, before SQLAlchemy version 1.2.0 :ticket:`3912` changed this to remove in-place mutation. This in-place mutation is now restored on an **opt-in** basis to allow a single hybrid to have multiple methods set up, without the need to name all the methods the same and without the need to carefully "chain" differently-named methods in order to maintain the composition. Typing tools such as Mypy and Pyright do not allow same-named methods on a class, so with this change a succinct method of setting up hybrids with typing support is restored. Change-Id: Iea88025f023428f9f006846d09fbb4be391f5ebb References: #9321 --- diff --git a/doc/build/changelog/unreleased_20/9321.rst b/doc/build/changelog/unreleased_20/9321.rst new file mode 100644 index 0000000000..cfc0cb0d75 --- /dev/null +++ b/doc/build/changelog/unreleased_20/9321.rst @@ -0,0 +1,21 @@ +.. change:: + :tags: usecase, typing + :tickets: 9321 + + Improved the typing support for the :ref:`hybrids_toplevel` + extension, updated all documentation to use ORM Annotated Declarative + mappings, and added a new modifier called :attr:`.hybrid_property.inplace`. + This modifier provides a way to alter the state of a :class:`.hybrid_property` + **in place**, which is essentially what very early versions of hybrids + did, before SQLAlchemy version 1.2.0 :ticket:`3912` changed this to + remove in-place mutation. This in-place mutation is now restored on an + **opt-in** basis to allow a single hybrid to have multiple methods + set up, without the need to name all the methods the same and without the + need to carefully "chain" differently-named methods in order to maintain + the composition. Typing tools such as Mypy and Pyright do not allow + same-named methods on a class, so with this change a succinct method + of setting up hybrids with typing support is restored. + + .. seealso:: + + :ref:`hybrid_pep484_naming` diff --git a/doc/build/changelog/whatsnew_20.rst b/doc/build/changelog/whatsnew_20.rst index 60783653a8..3698a64f02 100644 --- a/doc/build/changelog/whatsnew_20.rst +++ b/doc/build/changelog/whatsnew_20.rst @@ -1824,7 +1824,7 @@ operations. ORM Declarative Applies Column Orders Differently; Control behavior using ``__table_cls__`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Declarative has changed the system by which mapped columns that originate from mixin or abstract base classes are sorted along with the columns that are on the diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py index 7061ff2b4c..9fdf4d777b 100644 --- a/lib/sqlalchemy/ext/hybrid.py +++ b/lib/sqlalchemy/ext/hybrid.py @@ -22,36 +22,42 @@ instance level. Below, each function decorated with :class:`.hybrid_method` or :class:`.hybrid_property` may receive ``self`` as an instance of the class, or as the class itself:: - from sqlalchemy import Column, Integer - from sqlalchemy.ext.declarative import declarative_base - from sqlalchemy.orm import Session, aliased - from sqlalchemy.ext.hybrid import hybrid_property, hybrid_method + from __future__ import annotations + + from sqlalchemy.ext.hybrid import hybrid_method + from sqlalchemy.ext.hybrid import hybrid_property + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column - Base = declarative_base() + + class Base(DeclarativeBase): + pass class Interval(Base): __tablename__ = 'interval' - id = Column(Integer, primary_key=True) - start = Column(Integer, nullable=False) - end = Column(Integer, nullable=False) + id: Mapped[int] = mapped_column(primary_key=True) + start: Mapped[int] + end: Mapped[int] - def __init__(self, start, end): + def __init__(self, start: int, end: int): self.start = start self.end = end @hybrid_property - def length(self): + def length(self) -> int: return self.end - self.start @hybrid_method - def contains(self, point): + def contains(self, point: int) -> bool: return (self.start <= point) & (point <= self.end) @hybrid_method - def intersects(self, other): + def intersects(self, other: Interval) -> bool: return self.contains(other.start) | self.contains(other.end) + Above, the ``length`` property returns the difference between the ``end`` and ``start`` attributes. With an instance of ``Interval``, this subtraction occurs in Python, using normal Python descriptor @@ -141,18 +147,22 @@ two separate Python expressions should be defined. The example we'll define the radius of the interval, which requires the usage of the absolute value function:: + from sqlalchemy import ColumnElement + from sqlalchemy import Float from sqlalchemy import func + from sqlalchemy import type_coerce - class Interval: + class Interval(Base): # ... @hybrid_property - def radius(self): + def radius(self) -> float: return abs(self.length) / 2 - @radius.expression - def radius(cls): - return func.abs(cls.length) / 2 + @radius.inplace.expression + @classmethod + def _radius_expression(cls) -> ColumnElement[float]: + return type_coerce(func.abs(cls.length) / 2, Float) Above the Python function ``abs()`` is used for instance-level operations, the SQL function ``ABS()`` is used via the :data:`.func` @@ -169,27 +179,72 @@ object for class-level expressions: FROM interval WHERE abs(interval."end" - interval.start) / :abs_1 > :param_1 -.. note:: When defining an expression for a hybrid property or method, the - expression method **must** retain the name of the original hybrid, else - the new hybrid with the additional state will be attached to the class - with the non-matching name. To use the example above:: +.. _hybrid_pep484_naming: + +Notes on Method Names in a pep-484 World +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In order to work with Python typing tools such as mypy, all method +names on a class must be differently-named. While experienced Python users +will note that the Python ``@property`` decorator does not have this limitation +with typing tools, as of this writing this is only because all Python typing +tools have hardcoded rules that are specific to ``@property`` which are +not made available to any other user-defined decorators +(see https://github.com/python/typing/discussions/1102 .) - class Interval: +Therefore SQLAlchemy 2.0 introduces a new modifier +:attr:`.hybrid_property.inplace` that allows new methods to be added to an +existing hybrid property **in place**, so that the official name of the hybrid +can be stated once up front, and the correctly-named hybrid property can then +be re-used to add more methods, **without** the need to name those methods the +same way and thus avoiding naming conflicts:: + + + class Interval(Base): # ... @hybrid_property - def radius(self): + def radius(self) -> float: return abs(self.length) / 2 - # WRONG - the non-matching name will cause this function to be - # ignored - @radius.expression - def radius_expression(cls): - return func.abs(cls.length) / 2 + @radius.inplace.setter + def _radius_setter(self, value: float) -> None: + # for example only + self.length = value * 2 + + @radius.inplace.expression + @classmethod + def _radius_expression(cls) -> ColumnElement[float]:: + return type_coerce(func.abs(cls.length) / 2, Float) + +When not using the :attr:`.hybrid_property.inplace` modifier, all hybrid +property modifiers return a **new** object each time. Without +:attr:`.hybrid_property.inplace`, the above methods need to be carefully +chained together:: + + class Interval(Base): + # ... + + # old approach not using .inplace + + @hybrid_property + def _radius_getter(self) -> float: + return abs(self.length) / 2 + + @_radius_getter.setter + def _radius_setter(self, value: float) -> None: + # for example only + self.length = value * 2 + + @_radius_setter.expression + @classmethod + def _radius_expression(cls) -> ColumnElement[float]:: + return type_coerce(func.abs(cls.length) / 2, Float) + +.. versionadded:: 2.0.4 Added :attr:`.hybrid_property.inplace` to allow + less verbose construction of composite :class:`.hybrid_property` objects + while not having to use repeated method names. - This is also true for other mutator methods, such as - :meth:`.hybrid_property.update_expression`. This is the same behavior - as that of the ``@property`` construct that is part of standard Python. Defining Setters ---------------- @@ -197,15 +252,15 @@ Defining Setters Hybrid properties can also define setter methods. If we wanted ``length`` above, when set, to modify the endpoint value:: - class Interval: + class Interval(Base): # ... @hybrid_property - def length(self): + def length(self) -> int: return self.end - self.start - @length.setter - def length(self, value): + @length.inplace.setter + def _length_setter(self, value: int) -> None: self.end = self.start + value The ``length(self, value)`` method is now called upon set:: @@ -239,33 +294,34 @@ accommodate a value passed to :meth:`_query.Query.update` which can affect this, using the :meth:`.hybrid_property.update_expression` decorator. A handler that works similarly to our setter would be:: - class Interval: + from typing import List, Tuple, Any + + class Interval(Base): # ... @hybrid_property - def length(self): + def length(self) -> int: return self.end - self.start - @length.setter - def length(self, value): + @length.inplace.setter + def _length_setter(self, value: int) -> None: self.end = self.start + value - @length.update_expression - def length(cls, value): + @length.inplace.update_expression + def _length_update_expression(cls, value: Any) -> List[Tuple[Any, Any]]: return [ (cls.end, cls.start + value) ] -Above, if we use ``Interval.length`` in an UPDATE expression as:: +Above, if we use ``Interval.length`` in an UPDATE expression, we get +a hybrid SET expression: - session.query(Interval).update( - {Interval.length: 25}, synchronize_session='fetch') - -We'll get an UPDATE statement along the lines of: +.. sourcecode:: pycon+sql -.. sourcecode:: sql - UPDATE interval SET end=start + :value + >>> from sqlalchemy import update + >>> print(update(Interval).values({Interval.length: 25})) + {printsql}UPDATE interval SET "end"=(interval.start + :start_1) In some cases, the default "evaluate" strategy can't perform the SET expression in Python; while the addition operator we're using above @@ -273,35 +329,6 @@ is supported, for more complex SET expressions it will usually be necessary to use either the "fetch" or False synchronization strategy as illustrated above. -.. note:: For ORM bulk updates to work with hybrids, the function name - of the hybrid must match that of how it is accessed. Something - like this wouldn't work:: - - class Interval: - # ... - - def _get(self): - return self.end - self.start - - def _set(self, value): - self.end = self.start + value - - def _update_expr(cls, value): - return [ - (cls.end, cls.start + value) - ] - - length = hybrid_property( - fget=_get, fset=_set, update_expr=_update_expr - ) - - The Python descriptor protocol does not provide any reliable way for - a descriptor to know what attribute name it was accessed as, and - the UPDATE scheme currently relies upon being able to access the - attribute from an instance by name in order to perform the instance - synchronization step. - -.. versionadded:: 1.2 added support for bulk updates to hybrid properties. Working with Relationships -------------------------- @@ -317,57 +344,89 @@ Join-Dependent Relationship Hybrid Consider the following declarative mapping which relates a ``User`` to a ``SavingsAccount``:: - from sqlalchemy import Column, Integer, ForeignKey, Numeric, String - from sqlalchemy.orm import relationship - from sqlalchemy.ext.declarative import declarative_base + from __future__ import annotations + + from decimal import Decimal + from typing import cast + from typing import List + from typing import Optional + + from sqlalchemy import ForeignKey + from sqlalchemy import Numeric + from sqlalchemy import String + from sqlalchemy import SQLColumnExpression from sqlalchemy.ext.hybrid import hybrid_property + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column + from sqlalchemy.orm import relationship + + + class Base(DeclarativeBase): + pass - Base = declarative_base() class SavingsAccount(Base): __tablename__ = 'account' - id = Column(Integer, primary_key=True) - user_id = Column(Integer, ForeignKey('user.id'), nullable=False) - balance = Column(Numeric(15, 5)) + id: Mapped[int] = mapped_column(primary_key=True) + user_id: Mapped[int] = mapped_column(ForeignKey('user.id')) + balance: Mapped[Decimal] = mapped_column(Numeric(15, 5)) + + owner: Mapped[User] = relationship(back_populates="accounts") class User(Base): __tablename__ = 'user' - id = Column(Integer, primary_key=True) - name = Column(String(100), nullable=False) + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(String(100)) - accounts = relationship("SavingsAccount", backref="owner") + accounts: Mapped[List[SavingsAccount]] = relationship( + back_populates="owner", lazy="selectin" + ) @hybrid_property - def balance(self): + def balance(self) -> Optional[Decimal]: if self.accounts: return self.accounts[0].balance else: return None - @balance.setter - def balance(self, value): + @balance.inplace.setter + def _balance_setter(self, value: Optional[Decimal]) -> None: + assert value is not None + if not self.accounts: - account = Account(owner=self) + account = SavingsAccount(owner=self) else: account = self.accounts[0] account.balance = value - @balance.expression - def balance(cls): - return SavingsAccount.balance + @balance.inplace.expression + @classmethod + def _balance_expression(cls) -> SQLColumnExpression[Optional[Decimal]]: + return cast("SQLColumnExpression[Optional[Decimal]]", SavingsAccount.balance) The above hybrid property ``balance`` works with the first ``SavingsAccount`` entry in the list of accounts for this user. The in-Python getter/setter methods can treat ``accounts`` as a Python list available on ``self``. -However, at the expression level, it's expected that the ``User`` class will +.. tip:: The ``User.balance`` getter in the above example accesses the + ``self.acccounts`` collection, which will normally be loaded via the + :func:`.selectinload` loader strategy configured on the ``User.balance`` + :func:`_orm.relationship`. The default loader strategy when not otherwise + stated on :func:`_orm.relationship` is :func:`.lazyload`, which emits SQL on + demand. When using asyncio, on-demand loaders such as :func:`.lazyload` are + not supported, so care should be taken to ensure the ``self.accounts`` + collection is accessible to this hybrid accessor when using asyncio. + +At the expression level, it's expected that the ``User`` class will be used in an appropriate context such that an appropriate join to ``SavingsAccount`` will be present: .. sourcecode:: pycon+sql - >>> print(Session().query(User, User.balance). + >>> from sqlalchemy import select + >>> print(select(User, User.balance). ... join(User.accounts).filter(User.balance > 5000)) {printsql}SELECT "user".id AS user_id, "user".name AS user_name, account.balance AS account_balance @@ -381,8 +440,9 @@ would use an outer join: .. sourcecode:: pycon+sql + >>> from sqlalchemy import select >>> from sqlalchemy import or_ - >>> print (Session().query(User, User.balance).outerjoin(User.accounts). + >>> print (select(User, User.balance).outerjoin(User.accounts). ... filter(or_(User.balance < 5000, User.balance == None))) {printsql}SELECT "user".id AS user_id, "user".name AS user_name, account.balance AS account_balance @@ -400,48 +460,73 @@ illustrated at :ref:`mapper_column_property_sql_expressions`, we can adjust our ``SavingsAccount`` example to aggregate the balances for *all* accounts, and use a correlated subquery for the column expression:: - from sqlalchemy import Column, Integer, ForeignKey, Numeric, String - from sqlalchemy.orm import relationship - from sqlalchemy.ext.declarative import declarative_base + from __future__ import annotations + + from decimal import Decimal + from typing import List + + from sqlalchemy import ForeignKey + from sqlalchemy import func + from sqlalchemy import Numeric + from sqlalchemy import select + from sqlalchemy import SQLColumnExpression + from sqlalchemy import String from sqlalchemy.ext.hybrid import hybrid_property - from sqlalchemy import select, func + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column + from sqlalchemy.orm import relationship + + + class Base(DeclarativeBase): + pass - Base = declarative_base() class SavingsAccount(Base): __tablename__ = 'account' - id = Column(Integer, primary_key=True) - user_id = Column(Integer, ForeignKey('user.id'), nullable=False) - balance = Column(Numeric(15, 5)) + id: Mapped[int] = mapped_column(primary_key=True) + user_id: Mapped[int] = mapped_column(ForeignKey('user.id')) + balance: Mapped[Decimal] = mapped_column(Numeric(15, 5)) + + owner: Mapped[User] = relationship(back_populates="accounts") class User(Base): __tablename__ = 'user' - id = Column(Integer, primary_key=True) - name = Column(String(100), nullable=False) + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(String(100)) - accounts = relationship("SavingsAccount", backref="owner") + accounts: Mapped[List[SavingsAccount]] = relationship( + back_populates="owner", lazy="selectin" + ) @hybrid_property - def balance(self): - return sum(acc.balance for acc in self.accounts) + def balance(self) -> Decimal: + return sum((acc.balance for acc in self.accounts), start=Decimal("0")) + + @balance.inplace.expression + @classmethod + def _balance_expression(cls) -> SQLColumnExpression[Decimal]: + return ( + select(func.sum(SavingsAccount.balance)) + .where(SavingsAccount.user_id == cls.id) + .label("total_balance") + ) - @balance.expression - def balance(cls): - return select(func.sum(SavingsAccount.balance)).\ - where(SavingsAccount.user_id==cls.id).\ - label('total_balance') The above recipe will give us the ``balance`` column which renders a correlated SELECT: .. sourcecode:: pycon+sql - >>> print(s.query(User).filter(User.balance > 400)) - {printsql}SELECT "user".id AS user_id, "user".name AS user_name + >>> from sqlalchemy import select + >>> print(select(User).filter(User.balance > 400)) + {printsql}SELECT "user".id, "user".name FROM "user" - WHERE (SELECT sum(account.balance) AS sum_1 - FROM account - WHERE account.user_id = "user".id) > :param_1 + WHERE ( + SELECT sum(account.balance) AS sum_1 FROM account + WHERE account.user_id = "user".id + ) > :param_1 + .. _hybrid_custom_comparators: @@ -462,28 +547,39 @@ idiosyncratic behavior on the SQL side. The example class below allows case-insensitive comparisons on the attribute named ``word_insensitive``:: - from sqlalchemy.ext.hybrid import Comparator, hybrid_property - from sqlalchemy import func, Column, Integer, String - from sqlalchemy.orm import Session - from sqlalchemy.ext.declarative import declarative_base + from __future__ import annotations - Base = declarative_base() + from typing import Any - class CaseInsensitiveComparator(Comparator): - def __eq__(self, other): + from sqlalchemy import ColumnElement + from sqlalchemy import func + from sqlalchemy.ext.hybrid import Comparator + from sqlalchemy.ext.hybrid import hybrid_property + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column + + class Base(DeclarativeBase): + pass + + + class CaseInsensitiveComparator(Comparator[str]): + def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 return func.lower(self.__clause_element__()) == func.lower(other) class SearchWord(Base): __tablename__ = 'searchword' - id = Column(Integer, primary_key=True) - word = Column(String(255), nullable=False) + + id: Mapped[int] = mapped_column(primary_key=True) + word: Mapped[str] @hybrid_property - def word_insensitive(self): + def word_insensitive(self) -> str: return self.word.lower() - @word_insensitive.comparator - def word_insensitive(cls): + @word_insensitive.inplace.comparator + @classmethod + def _word_insensitive_comparator(cls) -> CaseInsensitiveComparator: return CaseInsensitiveComparator(cls.word) Above, SQL expressions against ``word_insensitive`` will apply the ``LOWER()`` @@ -491,11 +587,13 @@ SQL function to both sides: .. sourcecode:: pycon+sql - >>> print(Session().query(SearchWord).filter_by(word_insensitive="Trucks")) - {printsql}SELECT searchword.id AS searchword_id, searchword.word AS searchword_word + >>> from sqlalchemy import select + >>> print(select(SearchWord).filter_by(word_insensitive="Trucks")) + {printsql}SELECT searchword.id, searchword.word FROM searchword WHERE lower(searchword.word) = lower(:lower_1) + The ``CaseInsensitiveComparator`` above implements part of the :class:`.ColumnOperators` interface. A "coercion" operation like lowercasing can be applied to all comparison operations (i.e. ``eq``, @@ -522,27 +620,29 @@ how the standard Python ``@property`` object works:: class FirstNameOnly(Base): # ... - first_name = Column(String) + first_name: Mapped[str] @hybrid_property - def name(self): + def name(self) -> str: return self.first_name - @name.setter - def name(self, value): + @name.inplace.setter + def _name_setter(self, value: str) -> None: self.first_name = value class FirstNameLastName(FirstNameOnly): # ... - last_name = Column(String) + last_name: Mapped[str] + # 'inplace' is not used here; calling getter creates a copy + # of FirstNameOnly.name that is local to FirstNameLastName @FirstNameOnly.name.getter - def name(self): + def name(self) -> str: return self.first_name + ' ' + self.last_name - @name.setter - def name(self, value): + @name.inplace.setter + def _name_setter(self, value: str) -> None: self.first_name, self.last_name = value.split(' ', 1) Above, the ``FirstNameLastName`` class refers to the hybrid from @@ -559,9 +659,10 @@ reference the instrumented attribute back to the hybrid object:: class FirstNameLastName(FirstNameOnly): # ... - last_name = Column(String) + last_name: Mapped[str] @FirstNameOnly.name.overrides.expression + @classmethod def name(cls): return func.concat(cls.first_name, ' ', cls.last_name) @@ -620,11 +721,11 @@ SQL side or Python side. Our ``SearchWord`` class can now deliver the class SearchWord(Base): __tablename__ = 'searchword' - id = Column(Integer, primary_key=True) - word = Column(String(255), nullable=False) + id: Mapped[int] = mapped_column(primary_key=True) + word: Mapped[str] @hybrid_property - def word_insensitive(self): + def word_insensitive(self) -> CaseInsensitiveWord: return CaseInsensitiveWord(self.word) The ``word_insensitive`` attribute now has case-insensitive comparison behavior @@ -633,7 +734,7 @@ value is converted to lower case on the Python side here): .. sourcecode:: pycon+sql - >>> print(Session().query(SearchWord).filter_by(word_insensitive="Trucks")) + >>> print(select(SearchWord).filter_by(word_insensitive="Trucks")) {printsql}SELECT searchword.id AS searchword_id, searchword.word AS searchword_word FROM searchword WHERE lower(searchword.word) = :lower_1 @@ -642,14 +743,14 @@ SQL expression versus SQL expression: .. sourcecode:: pycon+sql + >>> from sqlalchemy.orm import aliased >>> sw1 = aliased(SearchWord) >>> sw2 = aliased(SearchWord) - >>> print(Session().query( - ... sw1.word_insensitive, - ... sw2.word_insensitive).\ - ... filter( - ... sw1.word_insensitive > sw2.word_insensitive - ... )) + >>> print( + ... select(sw1.word_insensitive, sw2.word_insensitive).filter( + ... sw1.word_insensitive > sw2.word_insensitive + ... ) + ... ) {printsql}SELECT lower(searchword_1.word) AS lower_1, lower(searchword_2.word) AS lower_2 FROM searchword AS searchword_1, searchword AS searchword_2 @@ -703,6 +804,7 @@ from ..orm import attributes from ..orm import InspectionAttrExtensionType from ..orm import interfaces from ..orm import ORMDescriptor +from ..orm.attributes import QueryableAttribute from ..sql import roles from ..sql._typing import is_has_clause_element from ..sql.elements import ColumnElement @@ -716,6 +818,7 @@ from ..util.typing import Self if TYPE_CHECKING: from ..orm.interfaces import MapperProperty from ..orm.util import AliasedInsp + from ..sql import SQLColumnExpression from ..sql._typing import _ColumnExpressionArgument from ..sql._typing import _DMLColumnArgument from ..sql._typing import _HasClauseElement @@ -725,6 +828,7 @@ if TYPE_CHECKING: _P = ParamSpec("_P") _R = TypeVar("_R") _T = TypeVar("_T", bound=Any) +_TE = TypeVar("_TE", bound=Any) _T_co = TypeVar("_T_co", bound=Any, covariant=True) _T_con = TypeVar("_T_con", bound=Any, contravariant=True) @@ -770,8 +874,8 @@ class _HybridSetterType(Protocol[_T_con]): class _HybridUpdaterType(Protocol[_T_con]): def __call__( - self, - cls: Type[Any], + s, + cls: Any, value: Union[_T_con, _ColumnExpressionArgument[_T_con]], ) -> List[Tuple[_DMLColumnArgument, Any]]: ... @@ -782,13 +886,45 @@ class _HybridDeleterType(Protocol[_T_co]): ... -class _HybridExprCallableType(Protocol[_T_co]): +class _HybridExprCallableType(Protocol[_T]): def __call__( self, cls: Any - ) -> Union[_HasClauseElement, ColumnElement[_T_co]]: + ) -> Union[_HasClauseElement, SQLColumnExpression[_T]]: + ... + + +class _HybridComparatorCallableType(Protocol[_T]): + def __call__(self, cls: Any) -> Comparator[_T]: ... +class _HybridClassLevelAccessor(QueryableAttribute[_T]): + """Describe the object returned by a hybrid_property() when + called as a class-level descriptor. + + """ + + if TYPE_CHECKING: + + def getter(self, fget: _HybridGetterType[_T]) -> hybrid_property[_T]: + ... + + def setter(self, fset: _HybridSetterType[_T]) -> hybrid_property[_T]: + ... + + def deleter(self, fdel: _HybridDeleterType[_T]) -> hybrid_property[_T]: + ... + + @property + def overrides(self) -> hybrid_property[_T]: + ... + + def update_expression( + self, meth: _HybridUpdaterType[_T] + ) -> hybrid_property[_T]: + ... + + class hybrid_method(interfaces.InspectionAttrInfo, Generic[_P, _R]): """A decorator which allows definition of a Python object method with both instance-level and class-level behavior. @@ -817,8 +953,9 @@ class hybrid_method(interfaces.InspectionAttrInfo, Generic[_P, _R]): return self._value + x + y @value.expression - def value(self, x, y): - return func.some_function(self._value, x, y) + @classmethod + def value(cls, x, y): + return func.some_function(cls._value, x, y) """ self.func = func @@ -827,6 +964,23 @@ class hybrid_method(interfaces.InspectionAttrInfo, Generic[_P, _R]): else: self.expression(func) # type: ignore + @property + def inplace(self) -> Self: + """Return the inplace mutator for this :class:`.hybrid_method`. + + The :class:`.hybrid_method` class already performs "in place" mutation + when the :meth:`.hybrid_method.expression` decorator is called, + so this attribute returns Self. + + .. versionadded:: 2.0.4 + + .. seealso:: + + :ref:`hybrid_pep484_naming` + + """ + return self + @overload def __get__( self, instance: Literal[None], owner: Type[object] @@ -859,6 +1013,13 @@ class hybrid_method(interfaces.InspectionAttrInfo, Generic[_P, _R]): return self +def _unwrap_classmethod(meth: _T) -> _T: + if isinstance(meth, classmethod): + return meth.__func__ # type: ignore + else: + return meth + + class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]): """A decorator which allows definition of a Python descriptor with both instance-level and class-level behavior. @@ -898,9 +1059,9 @@ class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]): self.fget = fget self.fset = fset self.fdel = fdel - self.expr = expr - self.custom_comparator = custom_comparator - self.update_expr = update_expr + self.expr = _unwrap_classmethod(expr) + self.custom_comparator = _unwrap_classmethod(custom_comparator) + self.update_expr = _unwrap_classmethod(update_expr) util.update_wrapper(self, fget) @overload @@ -910,7 +1071,7 @@ class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]): @overload def __get__( self, instance: Literal[None], owner: Type[object] - ) -> SQLCoreOperations[_T]: + ) -> _HybridClassLevelAccessor[_T]: ... @overload @@ -919,7 +1080,7 @@ class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]): def __get__( self, instance: Optional[object], owner: Optional[Type[object]] - ) -> Union[hybrid_property[_T], SQLCoreOperations[_T], _T]: + ) -> Union[hybrid_property[_T], _HybridClassLevelAccessor[_T], _T]: if owner is None: return self elif instance is None: @@ -982,6 +1143,81 @@ class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]): """ return self + class _InPlace(Generic[_TE]): + """A builder helper for .hybrid_property. + + .. versionadded:: 2.0.4 + + """ + + __slots__ = ("attr",) + + def __init__(self, attr: hybrid_property[_TE]): + self.attr = attr + + def _set(self, **kw: Any) -> hybrid_property[_TE]: + for k, v in kw.items(): + setattr(self.attr, k, _unwrap_classmethod(v)) + return self.attr + + def getter(self, fget: _HybridGetterType[_TE]) -> hybrid_property[_TE]: + return self._set(fget=fget) + + def setter(self, fset: _HybridSetterType[_TE]) -> hybrid_property[_TE]: + return self._set(fset=fset) + + def deleter( + self, fdel: _HybridDeleterType[_TE] + ) -> hybrid_property[_TE]: + return self._set(fdel=fdel) + + def expression( + self, expr: _HybridExprCallableType[_TE] + ) -> hybrid_property[_TE]: + return self._set(expr=expr) + + def comparator( + self, comparator: _HybridComparatorCallableType[_TE] + ) -> hybrid_property[_TE]: + return self._set(custom_comparator=comparator) + + def update_expression( + self, meth: _HybridUpdaterType[_TE] + ) -> hybrid_property[_TE]: + return self._set(update_expr=meth) + + @property + def inplace(self) -> _InPlace[_T]: + """Return the inplace mutator for this :class:`.hybrid_property`. + + This is to allow in-place mutation of the hybrid, allowing the first + hybrid method of a certain name to be re-used in order to add + more methods without having to name those methods the same, e.g.:: + + class Interval(Base): + # ... + + @hybrid_property + def radius(self) -> float: + return abs(self.length) / 2 + + @radius.inplace.setter + def _radius_setter(self, value: float) -> None: + self.length = value * 2 + + @radius.inplace.expression + def _radius_expression(cls) -> ColumnElement[float]: + return type_coerce(func.abs(cls.length) / 2, Float) + + .. versionadded:: 2.0.4 + + .. seealso:: + + :ref:`hybrid_pep484_naming` + + """ + return hybrid_property._InPlace(self) + def getter(self, fget: _HybridGetterType[_T]) -> hybrid_property[_T]: """Provide a modifying decorator that defines a getter method. @@ -1035,7 +1271,9 @@ class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]): return self._copy(expr=expr) - def comparator(self, comparator: Comparator[_T]) -> hybrid_property[_T]: + def comparator( + self, comparator: _HybridComparatorCallableType[_T] + ) -> hybrid_property[_T]: """Provide a modifying decorator that defines a custom comparator producing method. @@ -1111,7 +1349,7 @@ class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]): @util.memoized_property def _expr_comparator( self, - ) -> Callable[[Any], interfaces.PropComparator[_T]]: + ) -> Callable[[Any], _HybridClassLevelAccessor[_T]]: if self.custom_comparator is not None: return self._get_comparator(self.custom_comparator) elif self.expr is not None: @@ -1121,7 +1359,7 @@ class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]): def _get_expr( self, expr: _HybridExprCallableType[_T] - ) -> Callable[[Any], interfaces.PropComparator[_T]]: + ) -> Callable[[Any], _HybridClassLevelAccessor[_T]]: def _expr(cls: Any) -> ExprComparator[_T]: return ExprComparator(cls, expr(cls), self) @@ -1131,13 +1369,13 @@ class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]): def _get_comparator( self, comparator: Any - ) -> Callable[[Any], interfaces.PropComparator[_T]]: + ) -> Callable[[Any], _HybridClassLevelAccessor[_T]]: proxy_attr = attributes.create_proxied_attribute(self) def expr_comparator( owner: Type[object], - ) -> interfaces.PropComparator[_T]: + ) -> _HybridClassLevelAccessor[_T]: # because this is the descriptor protocol, we don't really know # what our attribute name is. so search for it through the # MRO. @@ -1149,12 +1387,15 @@ class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]): else: name = attributes._UNKNOWN_ATTR_KEY # type: ignore[assignment] - return proxy_attr( - owner, - name, - self, - comparator(owner), - doc=comparator.__doc__ or self.__doc__, + return cast( + "_HybridClassLevelAccessor[_T]", + proxy_attr( + owner, + name, + self, + comparator(owner), + doc=comparator.__doc__ or self.__doc__, + ), ) return expr_comparator @@ -1166,7 +1407,7 @@ class Comparator(interfaces.PropComparator[_T]): classes for usage with hybrids.""" def __init__( - self, expression: Union[_HasClauseElement, ColumnElement[_T]] + self, expression: Union[_HasClauseElement, SQLColumnExpression[_T]] ): self.expression = expression @@ -1201,7 +1442,7 @@ class ExprComparator(Comparator[_T]): def __init__( self, cls: Type[Any], - expression: Union[_HasClauseElement, ColumnElement[_T]], + expression: Union[_HasClauseElement, SQLColumnExpression[_T]], hybrid: hybrid_property[_T], ): self.cls = cls diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index e1190f7dd2..ab124103fd 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -40,7 +40,6 @@ if TYPE_CHECKING: from .dml import UpdateBase from .dml import ValuesBase from .elements import ClauseElement - from .elements import ColumnClause from .elements import ColumnElement from .elements import KeyedColumnElement from .elements import quoted_name @@ -224,7 +223,10 @@ _SelectStatementForCompoundArgument = Union[ """SELECT statement acceptable by ``union()`` and other SQL set operations""" _DMLColumnArgument = Union[ - str, "ColumnClause[Any]", _HasClauseElement, roles.DMLColumnRole + str, + _HasClauseElement, + roles.DMLColumnRole, + "SQLCoreOperations", ] """A DML column expression. This is a "key" inside of insert().values(), update().values(), and related. diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 21b83d556e..75b5d09e3f 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -3561,7 +3561,7 @@ class SelectBase( .. seealso:: - :meth:`_expression.SelectBase.as_scalar`. + :meth:`_expression.SelectBase.scalar_subquery`. """ return self.scalar_subquery().label(name) diff --git a/test/ext/mypy/plain_files/hybrid_four.py b/test/ext/mypy/plain_files/hybrid_four.py new file mode 100644 index 0000000000..a81ad96c41 --- /dev/null +++ b/test/ext/mypy/plain_files/hybrid_four.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +from typing import Any + +from sqlalchemy import ColumnElement +from sqlalchemy import func +from sqlalchemy.ext.hybrid import Comparator +from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column + + +class Base(DeclarativeBase): + pass + + +class CaseInsensitiveComparator(Comparator[str]): + def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 + return func.lower(self.__clause_element__()) == func.lower(other) + + +class SearchWord(Base): + __tablename__ = "searchword" + + id: Mapped[int] = mapped_column(primary_key=True) + word: Mapped[str] + + @hybrid_property + def word_insensitive(self) -> str: + return self.word.lower() + + @word_insensitive.inplace.comparator + @classmethod + def _word_insensitive_comparator(cls) -> CaseInsensitiveComparator: + return CaseInsensitiveComparator(cls.word) + + +class FirstNameOnly(Base): + __tablename__ = "f" + + id: Mapped[int] = mapped_column(primary_key=True) + first_name: Mapped[str] + + @hybrid_property + def name(self) -> str: + return self.first_name + + @name.inplace.setter + def _name_setter(self, value: str) -> None: + self.first_name = value + + +class FirstNameLastName(FirstNameOnly): + + last_name: Mapped[str] + + @FirstNameOnly.name.getter + def name(self) -> str: + return self.first_name + " " + self.last_name + + @name.inplace.setter + def _name_setter(self, value: str) -> None: + self.first_name, self.last_name = value.split(" ", 1) diff --git a/test/ext/mypy/plain_files/hybrid_one.py b/test/ext/mypy/plain_files/hybrid_one.py index b3ce365acd..52a2a19ed0 100644 --- a/test/ext/mypy/plain_files/hybrid_one.py +++ b/test/ext/mypy/plain_files/hybrid_one.py @@ -69,7 +69,7 @@ if typing.TYPE_CHECKING: # EXPECTED_RE_TYPE: builtins.int\*? reveal_type(i1.length) - # EXPECTED_RE_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\] + # EXPECTED_RE_TYPE: sqlalchemy.*._HybridClassLevelAccessor\[builtins.int\*?\] reveal_type(Interval.length) # EXPECTED_RE_TYPE: sqlalchemy.*.BinaryExpression\[builtins.bool\*?\] diff --git a/test/ext/mypy/plain_files/hybrid_three.py b/test/ext/mypy/plain_files/hybrid_three.py new file mode 100644 index 0000000000..86b0e4b262 --- /dev/null +++ b/test/ext/mypy/plain_files/hybrid_three.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +from decimal import Decimal +from typing import cast +from typing import List +from typing import Optional + +from sqlalchemy import ForeignKey +from sqlalchemy import Numeric +from sqlalchemy import SQLColumnExpression +from sqlalchemy import String +from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import relationship + + +class Base(DeclarativeBase): + pass + + +class SavingsAccount(Base): + __tablename__ = "account" + id: Mapped[int] = mapped_column(primary_key=True) + user_id: Mapped[int] = mapped_column(ForeignKey("user.id")) + balance: Mapped[Decimal] = mapped_column(Numeric(15, 5)) + + +class UserStyleOne(Base): + __tablename__ = "user" + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(String(100)) + + accounts: Mapped[List[SavingsAccount]] = relationship() + + @hybrid_property + def _balance_getter(self) -> Optional[Decimal]: + if self.accounts: + return self.accounts[0].balance + else: + return None + + @_balance_getter.setter + def _balance_setter(self, value: Optional[Decimal]) -> None: + assert value is not None + if not self.accounts: + account = SavingsAccount(owner=self) + else: + account = self.accounts[0] + account.balance = value + + @_balance_setter.expression + def balance(cls) -> SQLColumnExpression[Optional[Decimal]]: + return cast( + "SQLColumnExpression[Optional[Decimal]]", SavingsAccount.balance + ) + + +class UserStyleTwo(Base): + __tablename__ = "user" + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(String(100)) + + accounts: Mapped[List[SavingsAccount]] = relationship() + + @hybrid_property + def balance(self) -> Optional[Decimal]: + if self.accounts: + return self.accounts[0].balance + else: + return None + + @balance.inplace.setter + def _balance_setter(self, value: Optional[Decimal]) -> None: + assert value is not None + if not self.accounts: + account = SavingsAccount(owner=self) + else: + account = self.accounts[0] + account.balance = value + + @balance.inplace.expression + def _balance_expression(cls) -> SQLColumnExpression[Optional[Decimal]]: + return cast( + "SQLColumnExpression[Optional[Decimal]]", SavingsAccount.balance + ) diff --git a/test/ext/mypy/plain_files/hybrid_two.py b/test/ext/mypy/plain_files/hybrid_two.py index 619bbc839c..db50d7678e 100644 --- a/test/ext/mypy/plain_files/hybrid_two.py +++ b/test/ext/mypy/plain_files/hybrid_two.py @@ -30,17 +30,36 @@ class Interval(Base): def length(self) -> int: return self.end - self.start - # im not sure if there's a way to get typing tools to not complain about - # the re-defined name here, it handles it for plain @property - # but im not sure if that's hardcoded - # see https://github.com/python/typing/discussions/1102 + # old way - chain decorators + modifiers @hybrid_property def _inst_radius(self) -> float: return abs(self.length) / 2 @_inst_radius.expression - def radius(cls) -> ColumnElement[float]: + def old_radius(cls) -> ColumnElement[float]: + f1 = func.abs(cls.length, type_=Float()) + + expr = f1 / 2 + + # while we are here, check some Float[] / div type stuff + if typing.TYPE_CHECKING: + # EXPECTED_RE_TYPE: sqlalchemy.*Function\[builtins.float\*?\] + reveal_type(f1) + + # EXPECTED_RE_TYPE: sqlalchemy.*ColumnElement\[builtins.float\*?\] + reveal_type(expr) + return expr + + # new way - use the original decorator with inplace + + @hybrid_property + def new_radius(self) -> float: + return abs(self.length) / 2 + + @new_radius.inplace.expression + @classmethod + def _new_radius_expr(cls) -> ColumnElement[float]: f1 = func.abs(cls.length, type_=Float()) expr = f1 / 2 @@ -59,13 +78,17 @@ i1 = Interval(5, 10) i2 = Interval(7, 12) l1: int = i1.length -rd: float = i2.radius +rdo: float = i2.old_radius +rdn: float = i2.new_radius expr1 = Interval.length.in_([5, 10]) -expr2 = Interval.radius +expr2o = Interval.old_radius + +expr2n = Interval.new_radius -expr3 = Interval.radius.in_([0.5, 5.2]) +expr3o = Interval.old_radius.in_([0.5, 5.2]) +expr3n = Interval.new_radius.in_([0.5, 5.2]) if typing.TYPE_CHECKING: @@ -73,22 +96,34 @@ if typing.TYPE_CHECKING: reveal_type(i1.length) # EXPECTED_RE_TYPE: builtins.float\*? - reveal_type(i2.radius) + reveal_type(i2.old_radius) - # EXPECTED_RE_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\] + # EXPECTED_RE_TYPE: builtins.float\*? + reveal_type(i2.new_radius) + + # EXPECTED_RE_TYPE: sqlalchemy.*._HybridClassLevelAccessor\[builtins.int\*?\] reveal_type(Interval.length) - # EXPECTED_RE_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.float\*?\] - reveal_type(Interval.radius) + # EXPECTED_RE_TYPE: sqlalchemy.*._HybridClassLevelAccessor\[builtins.float\*?\] + reveal_type(Interval.old_radius) + + # EXPECTED_RE_TYPE: sqlalchemy.*._HybridClassLevelAccessor\[builtins.float\*?\] + reveal_type(Interval.new_radius) # EXPECTED_RE_TYPE: sqlalchemy.*.BinaryExpression\[builtins.bool\*?\] reveal_type(expr1) - # EXPECTED_RE_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.float\*?\] - reveal_type(expr2) + # EXPECTED_RE_TYPE: sqlalchemy.*._HybridClassLevelAccessor\[builtins.float\*?\] + reveal_type(expr2o) + + # EXPECTED_RE_TYPE: sqlalchemy.*._HybridClassLevelAccessor\[builtins.float\*?\] + reveal_type(expr2n) + + # EXPECTED_RE_TYPE: sqlalchemy.*.BinaryExpression\[builtins.bool\*?\] + reveal_type(expr3o) # EXPECTED_RE_TYPE: sqlalchemy.*.BinaryExpression\[builtins.bool\*?\] - reveal_type(expr3) + reveal_type(expr3n) # test #9268 diff --git a/test/ext/test_hybrid.py b/test/ext/test_hybrid.py index df69b36af1..c092850667 100644 --- a/test/ext/test_hybrid.py +++ b/test/ext/test_hybrid.py @@ -1,5 +1,6 @@ from decimal import Decimal +from sqlalchemy import column from sqlalchemy import exc from sqlalchemy import ForeignKey from sqlalchemy import func @@ -26,6 +27,7 @@ from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import is_false +from sqlalchemy.testing import is_not from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column @@ -33,7 +35,7 @@ from sqlalchemy.testing.schema import Column class PropertyComparatorTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = "default" - def _fixture(self): + def _fixture(self, use_inplace=False, use_classmethod=False): Base = declarative_base() class UCComparator(hybrid.Comparator): @@ -54,9 +56,33 @@ class PropertyComparatorTest(fixtures.TestBase, AssertsCompiledSQL): "This is a docstring" return self._value - 5 - @value.comparator - def value(cls): - return UCComparator(cls._value) + if use_classmethod: + if use_inplace: + + @value.inplace.comparator + @classmethod + def _value_comparator(cls): + return UCComparator(cls._value) + + else: + + @value.comparator + @classmethod + def value(cls): + return UCComparator(cls._value) + + else: + if use_inplace: + + @value.inplace.comparator + def _value_comparator(cls): + return UCComparator(cls._value) + + else: + + @value.comparator + def value(cls): + return UCComparator(cls._value) @value.setter def value(self, v): @@ -70,31 +96,51 @@ class PropertyComparatorTest(fixtures.TestBase, AssertsCompiledSQL): eq_(a1._value, 10) eq_(a1.value, 5) - def test_value(self): - A = self._fixture() + @testing.variation("use_inplace", [True, False]) + @testing.variation("use_classmethod", [True, False]) + def test_value(self, use_inplace, use_classmethod): + A = self._fixture( + use_inplace=use_inplace, use_classmethod=use_classmethod + ) eq_(str(A.value == 5), "upper(a.value) = upper(:upper_1)") - def test_aliased_value(self): - A = self._fixture() + @testing.variation("use_inplace", [True, False]) + @testing.variation("use_classmethod", [True, False]) + def test_aliased_value(self, use_inplace, use_classmethod): + A = self._fixture( + use_inplace=use_inplace, use_classmethod=use_classmethod + ) eq_(str(aliased(A).value == 5), "upper(a_1.value) = upper(:upper_1)") - def test_query(self): - A = self._fixture() + @testing.variation("use_inplace", [True, False]) + @testing.variation("use_classmethod", [True, False]) + def test_query(self, use_inplace, use_classmethod): + A = self._fixture( + use_inplace=use_inplace, use_classmethod=use_classmethod + ) sess = fixture_session() self.assert_compile( sess.query(A.value), "SELECT a.value AS a_value FROM a" ) - def test_aliased_query(self): - A = self._fixture() + @testing.variation("use_inplace", [True, False]) + @testing.variation("use_classmethod", [True, False]) + def test_aliased_query(self, use_inplace, use_classmethod): + A = self._fixture( + use_inplace=use_inplace, use_classmethod=use_classmethod + ) sess = fixture_session() self.assert_compile( sess.query(aliased(A).value), "SELECT a_1.value AS a_1_value FROM a AS a_1", ) - def test_aliased_filter(self): - A = self._fixture() + @testing.variation("use_inplace", [True, False]) + @testing.variation("use_classmethod", [True, False]) + def test_aliased_filter(self, use_inplace, use_classmethod): + A = self._fixture( + use_inplace=use_inplace, use_classmethod=use_classmethod + ) sess = fixture_session() self.assert_compile( sess.query(aliased(A)).filter_by(value="foo"), @@ -189,7 +235,8 @@ class PropertyComparatorTest(fixtures.TestBase, AssertsCompiledSQL): class PropertyExpressionTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = "default" - def _fixture(self): + def _fixture(self, use_inplace=False, use_classmethod=False): + use_inplace, use_classmethod = bool(use_inplace), bool(use_classmethod) Base = declarative_base() class A(Base): @@ -202,15 +249,42 @@ class PropertyExpressionTest(fixtures.TestBase, AssertsCompiledSQL): "This is an instance-level docstring" return int(self._value) - 5 - @value.expression - def value(cls): - "This is a class-level docstring" - return func.foo(cls._value) + cls.bar_value - @value.setter def value(self, v): self._value = v + 5 + if use_classmethod: + if use_inplace: + + @value.inplace.expression + @classmethod + def _value_expr(cls): + "This is a class-level docstring" + return func.foo(cls._value) + cls.bar_value + + else: + + @value.expression + @classmethod + def value(cls): + "This is a class-level docstring" + return func.foo(cls._value) + cls.bar_value + + else: + if use_inplace: + + @value.inplace.expression + def _value_expr(cls): + "This is a class-level docstring" + return func.foo(cls._value) + cls.bar_value + + else: + + @value.expression + def value(cls): + "This is a class-level docstring" + return func.foo(cls._value) + cls.bar_value + @hybrid.hybrid_property def bar_value(cls): return func.bar(cls._value) @@ -412,22 +486,34 @@ class PropertyExpressionTest(fixtures.TestBase, AssertsCompiledSQL): "a.lastname AS name FROM a) AS anon_1", ) - def test_info(self): - A = self._fixture() + @testing.variation("use_inplace", [True, False]) + @testing.variation("use_classmethod", [True, False]) + def test_info(self, use_inplace, use_classmethod): + A = self._fixture( + use_inplace=use_inplace, use_classmethod=use_classmethod + ) inspect(A).all_orm_descriptors.value.info["some key"] = "some value" eq_( inspect(A).all_orm_descriptors.value.info, {"some key": "some value"}, ) - def test_set_get(self): - A = self._fixture() + @testing.variation("use_inplace", [True, False]) + @testing.variation("use_classmethod", [True, False]) + def test_set_get(self, use_inplace, use_classmethod): + A = self._fixture( + use_inplace=use_inplace, use_classmethod=use_classmethod + ) a1 = A(value=5) eq_(a1._value, 10) eq_(a1.value, 5) - def test_expression(self): - A = self._fixture() + @testing.variation("use_inplace", [True, False]) + @testing.variation("use_classmethod", [True, False]) + def test_expression(self, use_inplace, use_classmethod): + A = self._fixture( + use_inplace=use_inplace, use_classmethod=use_classmethod + ) self.assert_compile( A.value.__clause_element__(), "foo(a.value) + bar(a.value)" ) @@ -455,15 +541,23 @@ class PropertyExpressionTest(fixtures.TestBase, AssertsCompiledSQL): "AND foo(a.value) + bar(a.value) = :param_1)", ) - def test_aliased_expression(self): - A = self._fixture() + @testing.variation("use_inplace", [True, False]) + @testing.variation("use_classmethod", [True, False]) + def test_aliased_expression(self, use_inplace, use_classmethod): + A = self._fixture( + use_inplace=use_inplace, use_classmethod=use_classmethod + ) self.assert_compile( aliased(A).value.__clause_element__(), "foo(a_1.value) + bar(a_1.value)", ) - def test_query(self): - A = self._fixture() + @testing.variation("use_inplace", [True, False]) + @testing.variation("use_classmethod", [True, False]) + def test_query(self, use_inplace, use_classmethod): + A = self._fixture( + use_inplace=use_inplace, use_classmethod=use_classmethod + ) sess = fixture_session() self.assert_compile( sess.query(A).filter_by(value="foo"), @@ -471,8 +565,12 @@ class PropertyExpressionTest(fixtures.TestBase, AssertsCompiledSQL): "FROM a WHERE foo(a.value) + bar(a.value) = :param_1", ) - def test_aliased_query(self): - A = self._fixture() + @testing.variation("use_inplace", [True, False]) + @testing.variation("use_classmethod", [True, False]) + def test_aliased_query(self, use_inplace, use_classmethod): + A = self._fixture( + use_inplace=use_inplace, use_classmethod=use_classmethod + ) sess = fixture_session() self.assert_compile( sess.query(aliased(A)).filter_by(value="foo"), @@ -480,8 +578,12 @@ class PropertyExpressionTest(fixtures.TestBase, AssertsCompiledSQL): "FROM a AS a_1 WHERE foo(a_1.value) + bar(a_1.value) = :param_1", ) - def test_docstring(self): - A = self._fixture() + @testing.variation("use_inplace", [True, False]) + @testing.variation("use_classmethod", [True, False]) + def test_docstring(self, use_inplace, use_classmethod): + A = self._fixture( + use_inplace=use_inplace, use_classmethod=use_classmethod + ) eq_(A.value.__doc__, "This is a class-level docstring") # no docstring here since we get a literal @@ -854,6 +956,84 @@ class SynonymOfPropertyTest(fixtures.TestBase, AssertsCompiledSQL): ) +class InplaceCreationTest(fixtures.TestBase, AssertsCompiledSQL): + """test 'inplace' definitions added for 2.0 to assist with typing + limitations. + + """ + + __dialect__ = "default" + + def test_property_integration(self, decl_base): + class Person(decl_base): + __tablename__ = "person" + id = Column(Integer, primary_key=True) + _name = Column(String) + + @hybrid.hybrid_property + def name(self): + return self._name + + @name.inplace.setter + def _name_setter(self, value): + self._name = value.title() + + @name.inplace.expression + def _name_expression(self): + return func.concat("Hello", self._name) + + p1 = Person(_name="name") + eq_(p1.name, "name") + p1.name = "new name" + eq_(p1.name, "New Name") + + self.assert_compile(Person.name, "concat(:concat_1, person._name)") + + def test_method_integration(self, decl_base): + class A(decl_base): + __tablename__ = "a" + id = Column(Integer, primary_key=True) + _value = Column("value", String) + + @hybrid.hybrid_method + def value(self, x): + return int(self._value) + x + + @value.inplace.expression + def _value_expression(cls, value): + return func.foo(cls._value, value) + value + + a1 = A(_value="10") + eq_(a1.value(5), 15) + + self.assert_compile(A.value(column("q")), "foo(a.value, q) + q") + + def test_property_unit(self): + def one(): + pass + + def two(): + pass + + def three(): + pass + + prop = hybrid.hybrid_property(one) + + prop2 = prop.inplace.expression(two) + + prop3 = prop.inplace.setter(three) + + is_(prop, prop2) + is_(prop, prop3) + + def four(): + pass + + prop4 = prop.setter(four) + is_not(prop, prop4) + + class MethodExpressionTest(fixtures.TestBase, AssertsCompiledSQL): __dialect__ = "default"