]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
modernize hybrids and apply typing
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 16 Feb 2023 14:39:07 +0000 (09:39 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 16 Feb 2023 23:01:31 +0000 (18:01 -0500)
Improved the typing support for the :ref:`hybrids_toplevel`
extension, updated all documentation to use ORM Annotated Declarative
mappings, and added a new modifier called :attr:`.hybrid_property.inplace`.
This modifier provides a way to alter the state of a :class:`.hybrid_property`
**in place**, which is essentially what very early versions of hybrids
did, before SQLAlchemy version 1.2.0 :ticket:`3912` changed this to
remove in-place mutation.  This in-place mutation is now restored on an
**opt-in** basis to allow a single hybrid to have multiple methods
set up, without the need to name all the methods the same and without the
need to carefully "chain" differently-named methods in order to maintain
the composition.  Typing tools such as Mypy and Pyright do not allow
same-named methods on a class, so with this change a succinct method
of setting up hybrids with typing support is restored.

Change-Id: Iea88025f023428f9f006846d09fbb4be391f5ebb
References: #9321

doc/build/changelog/unreleased_20/9321.rst [new file with mode: 0644]
doc/build/changelog/whatsnew_20.rst
lib/sqlalchemy/ext/hybrid.py
lib/sqlalchemy/sql/_typing.py
lib/sqlalchemy/sql/selectable.py
test/ext/mypy/plain_files/hybrid_four.py [new file with mode: 0644]
test/ext/mypy/plain_files/hybrid_one.py
test/ext/mypy/plain_files/hybrid_three.py [new file with mode: 0644]
test/ext/mypy/plain_files/hybrid_two.py
test/ext/test_hybrid.py

diff --git a/doc/build/changelog/unreleased_20/9321.rst b/doc/build/changelog/unreleased_20/9321.rst
new file mode 100644 (file)
index 0000000..cfc0cb0
--- /dev/null
@@ -0,0 +1,21 @@
+.. change::
+    :tags: usecase, typing
+    :tickets: 9321
+
+    Improved the typing support for the :ref:`hybrids_toplevel`
+    extension, updated all documentation to use ORM Annotated Declarative
+    mappings, and added a new modifier called :attr:`.hybrid_property.inplace`.
+    This modifier provides a way to alter the state of a :class:`.hybrid_property`
+    **in place**, which is essentially what very early versions of hybrids
+    did, before SQLAlchemy version 1.2.0 :ticket:`3912` changed this to
+    remove in-place mutation.  This in-place mutation is now restored on an
+    **opt-in** basis to allow a single hybrid to have multiple methods
+    set up, without the need to name all the methods the same and without the
+    need to carefully "chain" differently-named methods in order to maintain
+    the composition.  Typing tools such as Mypy and Pyright do not allow
+    same-named methods on a class, so with this change a succinct method
+    of setting up hybrids with typing support is restored.
+
+    .. seealso::
+
+        :ref:`hybrid_pep484_naming`
index 60783653a8636ffac96c6b5e35f7406928a52f14..3698a64f021379c1c878f6e8b08775ef68eb73af 100644 (file)
@@ -1824,7 +1824,7 @@ operations.
 
 
 ORM Declarative Applies Column Orders Differently; Control behavior using ``__table_cls__``
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 Declarative has changed the system by which mapped columns that originate from
 mixin or abstract base classes are sorted along with the columns that are on the
index 7061ff2b4c1eda071ae38d61a63a1ca6e971e1a4..9fdf4d777b7309f1a9b0ba91a13fa7bb8766f147 100644 (file)
@@ -22,36 +22,42 @@ instance level.  Below, each function decorated with :class:`.hybrid_method` or
 :class:`.hybrid_property` may receive ``self`` as an instance of the class, or
 as the class itself::
 
-    from sqlalchemy import Column, Integer
-    from sqlalchemy.ext.declarative import declarative_base
-    from sqlalchemy.orm import Session, aliased
-    from sqlalchemy.ext.hybrid import hybrid_property, hybrid_method
+    from __future__ import annotations
+
+    from sqlalchemy.ext.hybrid import hybrid_method
+    from sqlalchemy.ext.hybrid import hybrid_property
+    from sqlalchemy.orm import DeclarativeBase
+    from sqlalchemy.orm import Mapped
+    from sqlalchemy.orm import mapped_column
 
-    Base = declarative_base()
+
+    class Base(DeclarativeBase):
+        pass
 
     class Interval(Base):
         __tablename__ = 'interval'
 
-        id = Column(Integer, primary_key=True)
-        start = Column(Integer, nullable=False)
-        end = Column(Integer, nullable=False)
+        id: Mapped[int] = mapped_column(primary_key=True)
+        start: Mapped[int]
+        end: Mapped[int]
 
-        def __init__(self, start, end):
+        def __init__(self, start: int, end: int):
             self.start = start
             self.end = end
 
         @hybrid_property
-        def length(self):
+        def length(self) -> int:
             return self.end - self.start
 
         @hybrid_method
-        def contains(self, point):
+        def contains(self, point: int) -> bool:
             return (self.start <= point) & (point <= self.end)
 
         @hybrid_method
-        def intersects(self, other):
+        def intersects(self, other: Interval) -> bool:
             return self.contains(other.start) | self.contains(other.end)
 
+
 Above, the ``length`` property returns the difference between the
 ``end`` and ``start`` attributes.  With an instance of ``Interval``,
 this subtraction occurs in Python, using normal Python descriptor
@@ -141,18 +147,22 @@ two separate Python expressions should be defined.  The
 example we'll define the radius of the interval, which requires the
 usage of the absolute value function::
 
+    from sqlalchemy import ColumnElement
+    from sqlalchemy import Float
     from sqlalchemy import func
+    from sqlalchemy import type_coerce
 
-    class Interval:
+    class Interval(Base):
         # ...
 
         @hybrid_property
-        def radius(self):
+        def radius(self) -> float:
             return abs(self.length) / 2
 
-        @radius.expression
-        def radius(cls):
-            return func.abs(cls.length) / 2
+        @radius.inplace.expression
+        @classmethod
+        def _radius_expression(cls) -> ColumnElement[float]:
+            return type_coerce(func.abs(cls.length) / 2, Float)
 
 Above the Python function ``abs()`` is used for instance-level
 operations, the SQL function ``ABS()`` is used via the :data:`.func`
@@ -169,27 +179,72 @@ object for class-level expressions:
     FROM interval
     WHERE abs(interval."end" - interval.start) / :abs_1 > :param_1
 
-.. note:: When defining an expression for a hybrid property or method, the
-   expression method **must** retain the name of the original hybrid, else
-   the new hybrid with the additional state will be attached to the class
-   with the non-matching name. To use the example above::
+.. _hybrid_pep484_naming:
+
+Notes on Method Names in a pep-484 World
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+In order to work with Python typing tools such as mypy, all method
+names on a class must be differently-named.   While experienced Python users
+will note that the Python ``@property`` decorator does not have this limitation
+with typing tools, as of this writing this is only because all Python typing
+tools have hardcoded rules that are specific to ``@property`` which are
+not made available to any other user-defined decorators
+(see https://github.com/python/typing/discussions/1102 .)
 
-    class Interval:
+Therefore SQLAlchemy 2.0 introduces a new modifier
+:attr:`.hybrid_property.inplace` that allows new methods to be added to an
+existing hybrid property **in place**, so that the official name of the hybrid
+can be stated once up front, and the correctly-named hybrid property can then
+be re-used to add more methods, **without** the need to name those methods the
+same way and thus avoiding naming conflicts::
+
+
+    class Interval(Base):
         # ...
 
         @hybrid_property
-        def radius(self):
+        def radius(self) -> float:
             return abs(self.length) / 2
 
-        # WRONG - the non-matching name will cause this function to be
-        # ignored
-        @radius.expression
-        def radius_expression(cls):
-            return func.abs(cls.length) / 2
+        @radius.inplace.setter
+        def _radius_setter(self, value: float) -> None:
+            # for example only
+            self.length = value * 2
+
+        @radius.inplace.expression
+        @classmethod
+        def _radius_expression(cls) -> ColumnElement[float]::
+            return type_coerce(func.abs(cls.length) / 2, Float)
+
+When not using the :attr:`.hybrid_property.inplace` modifier, all hybrid
+property modifiers return a **new** object each time.   Without
+:attr:`.hybrid_property.inplace`, the above methods need to be carefully
+chained together::
+
+    class Interval(Base):
+        # ...
+
+        # old approach not using .inplace
+
+        @hybrid_property
+        def _radius_getter(self) -> float:
+            return abs(self.length) / 2
+
+        @_radius_getter.setter
+        def _radius_setter(self, value: float) -> None:
+            # for example only
+            self.length = value * 2
+
+        @_radius_setter.expression
+        @classmethod
+        def _radius_expression(cls) -> ColumnElement[float]::
+            return type_coerce(func.abs(cls.length) / 2, Float)
+
+.. versionadded:: 2.0.4 Added :attr:`.hybrid_property.inplace` to allow
+   less verbose construction of composite :class:`.hybrid_property` objects
+   while not having to use repeated method names.
 
-   This is also true for other mutator methods, such as
-   :meth:`.hybrid_property.update_expression`. This is the same behavior
-   as that of the ``@property`` construct that is part of standard Python.
 
 Defining Setters
 ----------------
@@ -197,15 +252,15 @@ Defining Setters
 Hybrid properties can also define setter methods.  If we wanted
 ``length`` above, when set, to modify the endpoint value::
 
-    class Interval:
+    class Interval(Base):
         # ...
 
         @hybrid_property
-        def length(self):
+        def length(self) -> int:
             return self.end - self.start
 
-        @length.setter
-        def length(self, value):
+        @length.inplace.setter
+        def _length_setter(self, value: int) -> None:
             self.end = self.start + value
 
 The ``length(self, value)`` method is now called upon set::
@@ -239,33 +294,34 @@ accommodate a value passed to :meth:`_query.Query.update` which can affect
 this, using the :meth:`.hybrid_property.update_expression` decorator.
 A handler that works similarly to our setter would be::
 
-    class Interval:
+    from typing import List, Tuple, Any
+
+    class Interval(Base):
         # ...
 
         @hybrid_property
-        def length(self):
+        def length(self) -> int:
             return self.end - self.start
 
-        @length.setter
-        def length(self, value):
+        @length.inplace.setter
+        def _length_setter(self, value: int) -> None:
             self.end = self.start + value
 
-        @length.update_expression
-        def length(cls, value):
+        @length.inplace.update_expression
+        def _length_update_expression(cls, value: Any) -> List[Tuple[Any, Any]]:
             return [
                 (cls.end, cls.start + value)
             ]
 
-Above, if we use ``Interval.length`` in an UPDATE expression as::
+Above, if we use ``Interval.length`` in an UPDATE expression, we get
+a hybrid SET expression:
 
-    session.query(Interval).update(
-        {Interval.length: 25}, synchronize_session='fetch')
-
-We'll get an UPDATE statement along the lines of:
+.. sourcecode:: pycon+sql
 
-.. sourcecode:: sql
 
-    UPDATE interval SET end=start + :value
+    >>> from sqlalchemy import update
+    >>> print(update(Interval).values({Interval.length: 25}))
+    {printsql}UPDATE interval SET "end"=(interval.start + :start_1)
 
 In some cases, the default "evaluate" strategy can't perform the SET
 expression in Python; while the addition operator we're using above
@@ -273,35 +329,6 @@ is supported, for more complex SET expressions it will usually be necessary
 to use either the "fetch" or False synchronization strategy as illustrated
 above.
 
-.. note:: For ORM bulk updates to work with hybrids, the function name
-   of the hybrid must match that of how it is accessed.    Something
-   like this wouldn't work::
-
-        class Interval:
-            # ...
-
-            def _get(self):
-                return self.end - self.start
-
-            def _set(self, value):
-                self.end = self.start + value
-
-            def _update_expr(cls, value):
-                return [
-                    (cls.end, cls.start + value)
-                ]
-
-            length = hybrid_property(
-                fget=_get, fset=_set, update_expr=_update_expr
-            )
-
-    The Python descriptor protocol does not provide any reliable way for
-    a descriptor to know what attribute name it was accessed as, and
-    the UPDATE scheme currently relies upon being able to access the
-    attribute from an instance by name in order to perform the instance
-    synchronization step.
-
-.. versionadded:: 1.2 added support for bulk updates to hybrid properties.
 
 Working with Relationships
 --------------------------
@@ -317,57 +344,89 @@ Join-Dependent Relationship Hybrid
 Consider the following declarative
 mapping which relates a ``User`` to a ``SavingsAccount``::
 
-    from sqlalchemy import Column, Integer, ForeignKey, Numeric, String
-    from sqlalchemy.orm import relationship
-    from sqlalchemy.ext.declarative import declarative_base
+    from __future__ import annotations
+
+    from decimal import Decimal
+    from typing import cast
+    from typing import List
+    from typing import Optional
+
+    from sqlalchemy import ForeignKey
+    from sqlalchemy import Numeric
+    from sqlalchemy import String
+    from sqlalchemy import SQLColumnExpression
     from sqlalchemy.ext.hybrid import hybrid_property
+    from sqlalchemy.orm import DeclarativeBase
+    from sqlalchemy.orm import Mapped
+    from sqlalchemy.orm import mapped_column
+    from sqlalchemy.orm import relationship
+
+
+    class Base(DeclarativeBase):
+        pass
 
-    Base = declarative_base()
 
     class SavingsAccount(Base):
         __tablename__ = 'account'
-        id = Column(Integer, primary_key=True)
-        user_id = Column(Integer, ForeignKey('user.id'), nullable=False)
-        balance = Column(Numeric(15, 5))
+        id: Mapped[int] = mapped_column(primary_key=True)
+        user_id: Mapped[int] = mapped_column(ForeignKey('user.id'))
+        balance: Mapped[Decimal] = mapped_column(Numeric(15, 5))
+
+        owner: Mapped[User] = relationship(back_populates="accounts")
 
     class User(Base):
         __tablename__ = 'user'
-        id = Column(Integer, primary_key=True)
-        name = Column(String(100), nullable=False)
+        id: Mapped[int] = mapped_column(primary_key=True)
+        name: Mapped[str] = mapped_column(String(100))
 
-        accounts = relationship("SavingsAccount", backref="owner")
+        accounts: Mapped[List[SavingsAccount]] = relationship(
+            back_populates="owner", lazy="selectin"
+        )
 
         @hybrid_property
-        def balance(self):
+        def balance(self) -> Optional[Decimal]:
             if self.accounts:
                 return self.accounts[0].balance
             else:
                 return None
 
-        @balance.setter
-        def balance(self, value):
+        @balance.inplace.setter
+        def _balance_setter(self, value: Optional[Decimal]) -> None:
+            assert value is not None
+
             if not self.accounts:
-                account = Account(owner=self)
+                account = SavingsAccount(owner=self)
             else:
                 account = self.accounts[0]
             account.balance = value
 
-        @balance.expression
-        def balance(cls):
-            return SavingsAccount.balance
+        @balance.inplace.expression
+        @classmethod
+        def _balance_expression(cls) -> SQLColumnExpression[Optional[Decimal]]:
+            return cast("SQLColumnExpression[Optional[Decimal]]", SavingsAccount.balance)
 
 The above hybrid property ``balance`` works with the first
 ``SavingsAccount`` entry in the list of accounts for this user.   The
 in-Python getter/setter methods can treat ``accounts`` as a Python
 list available on ``self``.
 
-However, at the expression level, it's expected that the ``User`` class will
+.. tip:: The ``User.balance`` getter in the above example accesses the
+   ``self.acccounts`` collection, which will normally be loaded via the
+   :func:`.selectinload` loader strategy configured on the ``User.balance``
+   :func:`_orm.relationship`. The default loader strategy when not otherwise
+   stated on :func:`_orm.relationship` is :func:`.lazyload`, which emits SQL on
+   demand. When using asyncio, on-demand loaders such as :func:`.lazyload` are
+   not supported, so care should be taken to ensure the ``self.accounts``
+   collection is accessible to this hybrid accessor when using asyncio.
+
+At the expression level, it's expected that the ``User`` class will
 be used in an appropriate context such that an appropriate join to
 ``SavingsAccount`` will be present:
 
 .. sourcecode:: pycon+sql
 
-    >>> print(Session().query(User, User.balance).
+    >>> from sqlalchemy import select
+    >>> print(select(User, User.balance).
     ...       join(User.accounts).filter(User.balance > 5000))
     {printsql}SELECT "user".id AS user_id, "user".name AS user_name,
     account.balance AS account_balance
@@ -381,8 +440,9 @@ would use an outer join:
 
 .. sourcecode:: pycon+sql
 
+    >>> from sqlalchemy import select
     >>> from sqlalchemy import or_
-    >>> print (Session().query(User, User.balance).outerjoin(User.accounts).
+    >>> print (select(User, User.balance).outerjoin(User.accounts).
     ...         filter(or_(User.balance < 5000, User.balance == None)))
     {printsql}SELECT "user".id AS user_id, "user".name AS user_name,
     account.balance AS account_balance
@@ -400,48 +460,73 @@ illustrated at :ref:`mapper_column_property_sql_expressions`,
 we can adjust our ``SavingsAccount`` example to aggregate the balances for
 *all* accounts, and use a correlated subquery for the column expression::
 
-    from sqlalchemy import Column, Integer, ForeignKey, Numeric, String
-    from sqlalchemy.orm import relationship
-    from sqlalchemy.ext.declarative import declarative_base
+    from __future__ import annotations
+
+    from decimal import Decimal
+    from typing import List
+
+    from sqlalchemy import ForeignKey
+    from sqlalchemy import func
+    from sqlalchemy import Numeric
+    from sqlalchemy import select
+    from sqlalchemy import SQLColumnExpression
+    from sqlalchemy import String
     from sqlalchemy.ext.hybrid import hybrid_property
-    from sqlalchemy import select, func
+    from sqlalchemy.orm import DeclarativeBase
+    from sqlalchemy.orm import Mapped
+    from sqlalchemy.orm import mapped_column
+    from sqlalchemy.orm import relationship
+
+
+    class Base(DeclarativeBase):
+        pass
 
-    Base = declarative_base()
 
     class SavingsAccount(Base):
         __tablename__ = 'account'
-        id = Column(Integer, primary_key=True)
-        user_id = Column(Integer, ForeignKey('user.id'), nullable=False)
-        balance = Column(Numeric(15, 5))
+        id: Mapped[int] = mapped_column(primary_key=True)
+        user_id: Mapped[int] = mapped_column(ForeignKey('user.id'))
+        balance: Mapped[Decimal] = mapped_column(Numeric(15, 5))
+
+        owner: Mapped[User] = relationship(back_populates="accounts")
 
     class User(Base):
         __tablename__ = 'user'
-        id = Column(Integer, primary_key=True)
-        name = Column(String(100), nullable=False)
+        id: Mapped[int] = mapped_column(primary_key=True)
+        name: Mapped[str] = mapped_column(String(100))
 
-        accounts = relationship("SavingsAccount", backref="owner")
+        accounts: Mapped[List[SavingsAccount]] = relationship(
+            back_populates="owner", lazy="selectin"
+        )
 
         @hybrid_property
-        def balance(self):
-            return sum(acc.balance for acc in self.accounts)
+        def balance(self) -> Decimal:
+            return sum((acc.balance for acc in self.accounts), start=Decimal("0"))
+
+        @balance.inplace.expression
+        @classmethod
+        def _balance_expression(cls) -> SQLColumnExpression[Decimal]:
+            return (
+                select(func.sum(SavingsAccount.balance))
+                .where(SavingsAccount.user_id == cls.id)
+                .label("total_balance")
+            )
 
-        @balance.expression
-        def balance(cls):
-            return select(func.sum(SavingsAccount.balance)).\
-                    where(SavingsAccount.user_id==cls.id).\
-                    label('total_balance')
 
 The above recipe will give us the ``balance`` column which renders
 a correlated SELECT:
 
 .. sourcecode:: pycon+sql
 
-    >>> print(s.query(User).filter(User.balance > 400))
-    {printsql}SELECT "user".id AS user_id, "user".name AS user_name
+    >>> from sqlalchemy import select
+    >>> print(select(User).filter(User.balance > 400))
+    {printsql}SELECT "user".id, "user".name
     FROM "user"
-    WHERE (SELECT sum(account.balance) AS sum_1
-    FROM account
-    WHERE account.user_id = "user".id) > :param_1
+    WHERE (
+        SELECT sum(account.balance) AS sum_1 FROM account
+        WHERE account.user_id = "user".id
+    ) > :param_1
+
 
 .. _hybrid_custom_comparators:
 
@@ -462,28 +547,39 @@ idiosyncratic behavior on the SQL side.
 The example class below allows case-insensitive comparisons on the attribute
 named ``word_insensitive``::
 
-    from sqlalchemy.ext.hybrid import Comparator, hybrid_property
-    from sqlalchemy import func, Column, Integer, String
-    from sqlalchemy.orm import Session
-    from sqlalchemy.ext.declarative import declarative_base
+    from __future__ import annotations
 
-    Base = declarative_base()
+    from typing import Any
 
-    class CaseInsensitiveComparator(Comparator):
-        def __eq__(self, other):
+    from sqlalchemy import ColumnElement
+    from sqlalchemy import func
+    from sqlalchemy.ext.hybrid import Comparator
+    from sqlalchemy.ext.hybrid import hybrid_property
+    from sqlalchemy.orm import DeclarativeBase
+    from sqlalchemy.orm import Mapped
+    from sqlalchemy.orm import mapped_column
+
+    class Base(DeclarativeBase):
+        pass
+
+
+    class CaseInsensitiveComparator(Comparator[str]):
+        def __eq__(self, other: Any) -> ColumnElement[bool]:  # type: ignore[override]  # noqa: E501
             return func.lower(self.__clause_element__()) == func.lower(other)
 
     class SearchWord(Base):
         __tablename__ = 'searchword'
-        id = Column(Integer, primary_key=True)
-        word = Column(String(255), nullable=False)
+
+        id: Mapped[int] = mapped_column(primary_key=True)
+        word: Mapped[str]
 
         @hybrid_property
-        def word_insensitive(self):
+        def word_insensitive(self) -> str:
             return self.word.lower()
 
-        @word_insensitive.comparator
-        def word_insensitive(cls):
+        @word_insensitive.inplace.comparator
+        @classmethod
+        def _word_insensitive_comparator(cls) -> CaseInsensitiveComparator:
             return CaseInsensitiveComparator(cls.word)
 
 Above, SQL expressions against ``word_insensitive`` will apply the ``LOWER()``
@@ -491,11 +587,13 @@ SQL function to both sides:
 
 .. sourcecode:: pycon+sql
 
-    >>> print(Session().query(SearchWord).filter_by(word_insensitive="Trucks"))
-    {printsql}SELECT searchword.id AS searchword_id, searchword.word AS searchword_word
+    >>> from sqlalchemy import select
+    >>> print(select(SearchWord).filter_by(word_insensitive="Trucks"))
+    {printsql}SELECT searchword.id, searchword.word
     FROM searchword
     WHERE lower(searchword.word) = lower(:lower_1)
 
+
 The ``CaseInsensitiveComparator`` above implements part of the
 :class:`.ColumnOperators` interface.   A "coercion" operation like
 lowercasing can be applied to all comparison operations (i.e. ``eq``,
@@ -522,27 +620,29 @@ how the standard Python ``@property`` object works::
     class FirstNameOnly(Base):
         # ...
 
-        first_name = Column(String)
+        first_name: Mapped[str]
 
         @hybrid_property
-        def name(self):
+        def name(self) -> str:
             return self.first_name
 
-        @name.setter
-        def name(self, value):
+        @name.inplace.setter
+        def _name_setter(self, value: str) -> None:
             self.first_name = value
 
     class FirstNameLastName(FirstNameOnly):
         # ...
 
-        last_name = Column(String)
+        last_name: Mapped[str]
 
+        # 'inplace' is not used here; calling getter creates a copy
+        # of FirstNameOnly.name that is local to FirstNameLastName
         @FirstNameOnly.name.getter
-        def name(self):
+        def name(self) -> str:
             return self.first_name + ' ' + self.last_name
 
-        @name.setter
-        def name(self, value):
+        @name.inplace.setter
+        def _name_setter(self, value: str) -> None:
             self.first_name, self.last_name = value.split(' ', 1)
 
 Above, the ``FirstNameLastName`` class refers to the hybrid from
@@ -559,9 +659,10 @@ reference the instrumented attribute back to the hybrid object::
     class FirstNameLastName(FirstNameOnly):
         # ...
 
-        last_name = Column(String)
+        last_name: Mapped[str]
 
         @FirstNameOnly.name.overrides.expression
+        @classmethod
         def name(cls):
             return func.concat(cls.first_name, ' ', cls.last_name)
 
@@ -620,11 +721,11 @@ SQL side or Python side. Our ``SearchWord`` class can now deliver the
 
     class SearchWord(Base):
         __tablename__ = 'searchword'
-        id = Column(Integer, primary_key=True)
-        word = Column(String(255), nullable=False)
+        id: Mapped[int] = mapped_column(primary_key=True)
+        word: Mapped[str]
 
         @hybrid_property
-        def word_insensitive(self):
+        def word_insensitive(self) -> CaseInsensitiveWord:
             return CaseInsensitiveWord(self.word)
 
 The ``word_insensitive`` attribute now has case-insensitive comparison behavior
@@ -633,7 +734,7 @@ value is converted to lower case on the Python side here):
 
 .. sourcecode:: pycon+sql
 
-    >>> print(Session().query(SearchWord).filter_by(word_insensitive="Trucks"))
+    >>> print(select(SearchWord).filter_by(word_insensitive="Trucks"))
     {printsql}SELECT searchword.id AS searchword_id, searchword.word AS searchword_word
     FROM searchword
     WHERE lower(searchword.word) = :lower_1
@@ -642,14 +743,14 @@ SQL expression versus SQL expression:
 
 .. sourcecode:: pycon+sql
 
+    >>> from sqlalchemy.orm import aliased
     >>> sw1 = aliased(SearchWord)
     >>> sw2 = aliased(SearchWord)
-    >>> print(Session().query(
-    ...                    sw1.word_insensitive,
-    ...                    sw2.word_insensitive).\
-    ...                        filter(
-    ...                            sw1.word_insensitive > sw2.word_insensitive
-    ...                        ))
+    >>> print(
+    ...     select(sw1.word_insensitive, sw2.word_insensitive).filter(
+    ...         sw1.word_insensitive > sw2.word_insensitive
+    ...     )
+    ... )
     {printsql}SELECT lower(searchword_1.word) AS lower_1,
     lower(searchword_2.word) AS lower_2
     FROM searchword AS searchword_1, searchword AS searchword_2
@@ -703,6 +804,7 @@ from ..orm import attributes
 from ..orm import InspectionAttrExtensionType
 from ..orm import interfaces
 from ..orm import ORMDescriptor
+from ..orm.attributes import QueryableAttribute
 from ..sql import roles
 from ..sql._typing import is_has_clause_element
 from ..sql.elements import ColumnElement
@@ -716,6 +818,7 @@ from ..util.typing import Self
 if TYPE_CHECKING:
     from ..orm.interfaces import MapperProperty
     from ..orm.util import AliasedInsp
+    from ..sql import SQLColumnExpression
     from ..sql._typing import _ColumnExpressionArgument
     from ..sql._typing import _DMLColumnArgument
     from ..sql._typing import _HasClauseElement
@@ -725,6 +828,7 @@ if TYPE_CHECKING:
 _P = ParamSpec("_P")
 _R = TypeVar("_R")
 _T = TypeVar("_T", bound=Any)
+_TE = TypeVar("_TE", bound=Any)
 _T_co = TypeVar("_T_co", bound=Any, covariant=True)
 _T_con = TypeVar("_T_con", bound=Any, contravariant=True)
 
@@ -770,8 +874,8 @@ class _HybridSetterType(Protocol[_T_con]):
 
 class _HybridUpdaterType(Protocol[_T_con]):
     def __call__(
-        self,
-        cls: Type[Any],
+        s,
+        cls: Any,
         value: Union[_T_con, _ColumnExpressionArgument[_T_con]],
     ) -> List[Tuple[_DMLColumnArgument, Any]]:
         ...
@@ -782,13 +886,45 @@ class _HybridDeleterType(Protocol[_T_co]):
         ...
 
 
-class _HybridExprCallableType(Protocol[_T_co]):
+class _HybridExprCallableType(Protocol[_T]):
     def __call__(
         self, cls: Any
-    ) -> Union[_HasClauseElement, ColumnElement[_T_co]]:
+    ) -> Union[_HasClauseElement, SQLColumnExpression[_T]]:
+        ...
+
+
+class _HybridComparatorCallableType(Protocol[_T]):
+    def __call__(self, cls: Any) -> Comparator[_T]:
         ...
 
 
+class _HybridClassLevelAccessor(QueryableAttribute[_T]):
+    """Describe the object returned by a hybrid_property() when
+    called as a class-level descriptor.
+
+    """
+
+    if TYPE_CHECKING:
+
+        def getter(self, fget: _HybridGetterType[_T]) -> hybrid_property[_T]:
+            ...
+
+        def setter(self, fset: _HybridSetterType[_T]) -> hybrid_property[_T]:
+            ...
+
+        def deleter(self, fdel: _HybridDeleterType[_T]) -> hybrid_property[_T]:
+            ...
+
+        @property
+        def overrides(self) -> hybrid_property[_T]:
+            ...
+
+        def update_expression(
+            self, meth: _HybridUpdaterType[_T]
+        ) -> hybrid_property[_T]:
+            ...
+
+
 class hybrid_method(interfaces.InspectionAttrInfo, Generic[_P, _R]):
     """A decorator which allows definition of a Python object method with both
     instance-level and class-level behavior.
@@ -817,8 +953,9 @@ class hybrid_method(interfaces.InspectionAttrInfo, Generic[_P, _R]):
                     return self._value + x + y
 
                 @value.expression
-                def value(self, x, y):
-                    return func.some_function(self._value, x, y)
+                @classmethod
+                def value(cls, x, y):
+                    return func.some_function(cls._value, x, y)
 
         """
         self.func = func
@@ -827,6 +964,23 @@ class hybrid_method(interfaces.InspectionAttrInfo, Generic[_P, _R]):
         else:
             self.expression(func)  # type: ignore
 
+    @property
+    def inplace(self) -> Self:
+        """Return the inplace mutator for this :class:`.hybrid_method`.
+
+        The :class:`.hybrid_method` class already performs "in place" mutation
+        when the :meth:`.hybrid_method.expression` decorator is called,
+        so this attribute returns Self.
+
+        .. versionadded:: 2.0.4
+
+        .. seealso::
+
+            :ref:`hybrid_pep484_naming`
+
+        """
+        return self
+
     @overload
     def __get__(
         self, instance: Literal[None], owner: Type[object]
@@ -859,6 +1013,13 @@ class hybrid_method(interfaces.InspectionAttrInfo, Generic[_P, _R]):
         return self
 
 
+def _unwrap_classmethod(meth: _T) -> _T:
+    if isinstance(meth, classmethod):
+        return meth.__func__  # type: ignore
+    else:
+        return meth
+
+
 class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]):
     """A decorator which allows definition of a Python descriptor with both
     instance-level and class-level behavior.
@@ -898,9 +1059,9 @@ class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]):
         self.fget = fget
         self.fset = fset
         self.fdel = fdel
-        self.expr = expr
-        self.custom_comparator = custom_comparator
-        self.update_expr = update_expr
+        self.expr = _unwrap_classmethod(expr)
+        self.custom_comparator = _unwrap_classmethod(custom_comparator)
+        self.update_expr = _unwrap_classmethod(update_expr)
         util.update_wrapper(self, fget)
 
     @overload
@@ -910,7 +1071,7 @@ class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]):
     @overload
     def __get__(
         self, instance: Literal[None], owner: Type[object]
-    ) -> SQLCoreOperations[_T]:
+    ) -> _HybridClassLevelAccessor[_T]:
         ...
 
     @overload
@@ -919,7 +1080,7 @@ class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]):
 
     def __get__(
         self, instance: Optional[object], owner: Optional[Type[object]]
-    ) -> Union[hybrid_property[_T], SQLCoreOperations[_T], _T]:
+    ) -> Union[hybrid_property[_T], _HybridClassLevelAccessor[_T], _T]:
         if owner is None:
             return self
         elif instance is None:
@@ -982,6 +1143,81 @@ class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]):
         """
         return self
 
+    class _InPlace(Generic[_TE]):
+        """A builder helper for .hybrid_property.
+
+        .. versionadded:: 2.0.4
+
+        """
+
+        __slots__ = ("attr",)
+
+        def __init__(self, attr: hybrid_property[_TE]):
+            self.attr = attr
+
+        def _set(self, **kw: Any) -> hybrid_property[_TE]:
+            for k, v in kw.items():
+                setattr(self.attr, k, _unwrap_classmethod(v))
+            return self.attr
+
+        def getter(self, fget: _HybridGetterType[_TE]) -> hybrid_property[_TE]:
+            return self._set(fget=fget)
+
+        def setter(self, fset: _HybridSetterType[_TE]) -> hybrid_property[_TE]:
+            return self._set(fset=fset)
+
+        def deleter(
+            self, fdel: _HybridDeleterType[_TE]
+        ) -> hybrid_property[_TE]:
+            return self._set(fdel=fdel)
+
+        def expression(
+            self, expr: _HybridExprCallableType[_TE]
+        ) -> hybrid_property[_TE]:
+            return self._set(expr=expr)
+
+        def comparator(
+            self, comparator: _HybridComparatorCallableType[_TE]
+        ) -> hybrid_property[_TE]:
+            return self._set(custom_comparator=comparator)
+
+        def update_expression(
+            self, meth: _HybridUpdaterType[_TE]
+        ) -> hybrid_property[_TE]:
+            return self._set(update_expr=meth)
+
+    @property
+    def inplace(self) -> _InPlace[_T]:
+        """Return the inplace mutator for this :class:`.hybrid_property`.
+
+        This is to allow in-place mutation of the hybrid, allowing the first
+        hybrid method of a certain name to be re-used in order to add
+        more methods without having to name those methods the same, e.g.::
+
+            class Interval(Base):
+                # ...
+
+                @hybrid_property
+                def radius(self) -> float:
+                    return abs(self.length) / 2
+
+                @radius.inplace.setter
+                def _radius_setter(self, value: float) -> None:
+                    self.length = value * 2
+
+                @radius.inplace.expression
+                def _radius_expression(cls) -> ColumnElement[float]:
+                    return type_coerce(func.abs(cls.length) / 2, Float)
+
+        .. versionadded:: 2.0.4
+
+        .. seealso::
+
+            :ref:`hybrid_pep484_naming`
+
+        """
+        return hybrid_property._InPlace(self)
+
     def getter(self, fget: _HybridGetterType[_T]) -> hybrid_property[_T]:
         """Provide a modifying decorator that defines a getter method.
 
@@ -1035,7 +1271,9 @@ class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]):
 
         return self._copy(expr=expr)
 
-    def comparator(self, comparator: Comparator[_T]) -> hybrid_property[_T]:
+    def comparator(
+        self, comparator: _HybridComparatorCallableType[_T]
+    ) -> hybrid_property[_T]:
         """Provide a modifying decorator that defines a custom
         comparator producing method.
 
@@ -1111,7 +1349,7 @@ class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]):
     @util.memoized_property
     def _expr_comparator(
         self,
-    ) -> Callable[[Any], interfaces.PropComparator[_T]]:
+    ) -> Callable[[Any], _HybridClassLevelAccessor[_T]]:
         if self.custom_comparator is not None:
             return self._get_comparator(self.custom_comparator)
         elif self.expr is not None:
@@ -1121,7 +1359,7 @@ class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]):
 
     def _get_expr(
         self, expr: _HybridExprCallableType[_T]
-    ) -> Callable[[Any], interfaces.PropComparator[_T]]:
+    ) -> Callable[[Any], _HybridClassLevelAccessor[_T]]:
         def _expr(cls: Any) -> ExprComparator[_T]:
             return ExprComparator(cls, expr(cls), self)
 
@@ -1131,13 +1369,13 @@ class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]):
 
     def _get_comparator(
         self, comparator: Any
-    ) -> Callable[[Any], interfaces.PropComparator[_T]]:
+    ) -> Callable[[Any], _HybridClassLevelAccessor[_T]]:
 
         proxy_attr = attributes.create_proxied_attribute(self)
 
         def expr_comparator(
             owner: Type[object],
-        ) -> interfaces.PropComparator[_T]:
+        ) -> _HybridClassLevelAccessor[_T]:
             # because this is the descriptor protocol, we don't really know
             # what our attribute name is.  so search for it through the
             # MRO.
@@ -1149,12 +1387,15 @@ class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]):
             else:
                 name = attributes._UNKNOWN_ATTR_KEY  # type: ignore[assignment]
 
-            return proxy_attr(
-                owner,
-                name,
-                self,
-                comparator(owner),
-                doc=comparator.__doc__ or self.__doc__,
+            return cast(
+                "_HybridClassLevelAccessor[_T]",
+                proxy_attr(
+                    owner,
+                    name,
+                    self,
+                    comparator(owner),
+                    doc=comparator.__doc__ or self.__doc__,
+                ),
             )
 
         return expr_comparator
@@ -1166,7 +1407,7 @@ class Comparator(interfaces.PropComparator[_T]):
     classes for usage with hybrids."""
 
     def __init__(
-        self, expression: Union[_HasClauseElement, ColumnElement[_T]]
+        self, expression: Union[_HasClauseElement, SQLColumnExpression[_T]]
     ):
         self.expression = expression
 
@@ -1201,7 +1442,7 @@ class ExprComparator(Comparator[_T]):
     def __init__(
         self,
         cls: Type[Any],
-        expression: Union[_HasClauseElement, ColumnElement[_T]],
+        expression: Union[_HasClauseElement, SQLColumnExpression[_T]],
         hybrid: hybrid_property[_T],
     ):
         self.cls = cls
index e1190f7dd294c9fc760e6c5968a1e21fa39e5a13..ab124103fd229f64203a525463f432632371d966 100644 (file)
@@ -40,7 +40,6 @@ if TYPE_CHECKING:
     from .dml import UpdateBase
     from .dml import ValuesBase
     from .elements import ClauseElement
-    from .elements import ColumnClause
     from .elements import ColumnElement
     from .elements import KeyedColumnElement
     from .elements import quoted_name
@@ -224,7 +223,10 @@ _SelectStatementForCompoundArgument = Union[
 """SELECT statement acceptable by ``union()`` and other SQL set operations"""
 
 _DMLColumnArgument = Union[
-    str, "ColumnClause[Any]", _HasClauseElement, roles.DMLColumnRole
+    str,
+    _HasClauseElement,
+    roles.DMLColumnRole,
+    "SQLCoreOperations",
 ]
 """A DML column expression.  This is a "key" inside of insert().values(),
 update().values(), and related.
index 21b83d556e1b6ab68e2ceadd09be800e233eff96..75b5d09e3f56885574e7b487e0adad7dfb7e1f6c 100644 (file)
@@ -3561,7 +3561,7 @@ class SelectBase(
 
         .. seealso::
 
-            :meth:`_expression.SelectBase.as_scalar`.
+            :meth:`_expression.SelectBase.scalar_subquery`.
 
         """
         return self.scalar_subquery().label(name)
diff --git a/test/ext/mypy/plain_files/hybrid_four.py b/test/ext/mypy/plain_files/hybrid_four.py
new file mode 100644 (file)
index 0000000..a81ad96
--- /dev/null
@@ -0,0 +1,64 @@
+from __future__ import annotations
+
+from typing import Any
+
+from sqlalchemy import ColumnElement
+from sqlalchemy import func
+from sqlalchemy.ext.hybrid import Comparator
+from sqlalchemy.ext.hybrid import hybrid_property
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+
+
+class Base(DeclarativeBase):
+    pass
+
+
+class CaseInsensitiveComparator(Comparator[str]):
+    def __eq__(self, other: Any) -> ColumnElement[bool]:  # type: ignore[override]  # noqa: E501
+        return func.lower(self.__clause_element__()) == func.lower(other)
+
+
+class SearchWord(Base):
+    __tablename__ = "searchword"
+
+    id: Mapped[int] = mapped_column(primary_key=True)
+    word: Mapped[str]
+
+    @hybrid_property
+    def word_insensitive(self) -> str:
+        return self.word.lower()
+
+    @word_insensitive.inplace.comparator
+    @classmethod
+    def _word_insensitive_comparator(cls) -> CaseInsensitiveComparator:
+        return CaseInsensitiveComparator(cls.word)
+
+
+class FirstNameOnly(Base):
+    __tablename__ = "f"
+
+    id: Mapped[int] = mapped_column(primary_key=True)
+    first_name: Mapped[str]
+
+    @hybrid_property
+    def name(self) -> str:
+        return self.first_name
+
+    @name.inplace.setter
+    def _name_setter(self, value: str) -> None:
+        self.first_name = value
+
+
+class FirstNameLastName(FirstNameOnly):
+
+    last_name: Mapped[str]
+
+    @FirstNameOnly.name.getter
+    def name(self) -> str:
+        return self.first_name + " " + self.last_name
+
+    @name.inplace.setter
+    def _name_setter(self, value: str) -> None:
+        self.first_name, self.last_name = value.split(" ", 1)
index b3ce365acd83e9b5e1380ef378cd68395a0fd31f..52a2a19ed0cee04be85d23e1fef6b244cf43348a 100644 (file)
@@ -69,7 +69,7 @@ if typing.TYPE_CHECKING:
     # EXPECTED_RE_TYPE: builtins.int\*?
     reveal_type(i1.length)
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\]
+    # EXPECTED_RE_TYPE: sqlalchemy.*._HybridClassLevelAccessor\[builtins.int\*?\]
     reveal_type(Interval.length)
 
     # EXPECTED_RE_TYPE: sqlalchemy.*.BinaryExpression\[builtins.bool\*?\]
diff --git a/test/ext/mypy/plain_files/hybrid_three.py b/test/ext/mypy/plain_files/hybrid_three.py
new file mode 100644 (file)
index 0000000..86b0e4b
--- /dev/null
@@ -0,0 +1,87 @@
+from __future__ import annotations
+
+from decimal import Decimal
+from typing import cast
+from typing import List
+from typing import Optional
+
+from sqlalchemy import ForeignKey
+from sqlalchemy import Numeric
+from sqlalchemy import SQLColumnExpression
+from sqlalchemy import String
+from sqlalchemy.ext.hybrid import hybrid_property
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+from sqlalchemy.orm import relationship
+
+
+class Base(DeclarativeBase):
+    pass
+
+
+class SavingsAccount(Base):
+    __tablename__ = "account"
+    id: Mapped[int] = mapped_column(primary_key=True)
+    user_id: Mapped[int] = mapped_column(ForeignKey("user.id"))
+    balance: Mapped[Decimal] = mapped_column(Numeric(15, 5))
+
+
+class UserStyleOne(Base):
+    __tablename__ = "user"
+    id: Mapped[int] = mapped_column(primary_key=True)
+    name: Mapped[str] = mapped_column(String(100))
+
+    accounts: Mapped[List[SavingsAccount]] = relationship()
+
+    @hybrid_property
+    def _balance_getter(self) -> Optional[Decimal]:
+        if self.accounts:
+            return self.accounts[0].balance
+        else:
+            return None
+
+    @_balance_getter.setter
+    def _balance_setter(self, value: Optional[Decimal]) -> None:
+        assert value is not None
+        if not self.accounts:
+            account = SavingsAccount(owner=self)
+        else:
+            account = self.accounts[0]
+        account.balance = value
+
+    @_balance_setter.expression
+    def balance(cls) -> SQLColumnExpression[Optional[Decimal]]:
+        return cast(
+            "SQLColumnExpression[Optional[Decimal]]", SavingsAccount.balance
+        )
+
+
+class UserStyleTwo(Base):
+    __tablename__ = "user"
+    id: Mapped[int] = mapped_column(primary_key=True)
+    name: Mapped[str] = mapped_column(String(100))
+
+    accounts: Mapped[List[SavingsAccount]] = relationship()
+
+    @hybrid_property
+    def balance(self) -> Optional[Decimal]:
+        if self.accounts:
+            return self.accounts[0].balance
+        else:
+            return None
+
+    @balance.inplace.setter
+    def _balance_setter(self, value: Optional[Decimal]) -> None:
+        assert value is not None
+        if not self.accounts:
+            account = SavingsAccount(owner=self)
+        else:
+            account = self.accounts[0]
+        account.balance = value
+
+    @balance.inplace.expression
+    def _balance_expression(cls) -> SQLColumnExpression[Optional[Decimal]]:
+        return cast(
+            "SQLColumnExpression[Optional[Decimal]]", SavingsAccount.balance
+        )
index 619bbc839c19c6e8ff763f449b0bdd4d50d5572a..db50d7678e0ef42f36e66a9e85ba3e026de1158d 100644 (file)
@@ -30,17 +30,36 @@ class Interval(Base):
     def length(self) -> int:
         return self.end - self.start
 
-    # im not sure if there's a way to get typing tools to not complain about
-    # the re-defined name here, it handles it for plain @property
-    # but im not sure if that's hardcoded
-    # see https://github.com/python/typing/discussions/1102
+    # old way - chain decorators + modifiers
 
     @hybrid_property
     def _inst_radius(self) -> float:
         return abs(self.length) / 2
 
     @_inst_radius.expression
-    def radius(cls) -> ColumnElement[float]:
+    def old_radius(cls) -> ColumnElement[float]:
+        f1 = func.abs(cls.length, type_=Float())
+
+        expr = f1 / 2
+
+        # while we are here, check some Float[] / div type stuff
+        if typing.TYPE_CHECKING:
+            # EXPECTED_RE_TYPE: sqlalchemy.*Function\[builtins.float\*?\]
+            reveal_type(f1)
+
+            # EXPECTED_RE_TYPE: sqlalchemy.*ColumnElement\[builtins.float\*?\]
+            reveal_type(expr)
+        return expr
+
+    # new way - use the original decorator with inplace
+
+    @hybrid_property
+    def new_radius(self) -> float:
+        return abs(self.length) / 2
+
+    @new_radius.inplace.expression
+    @classmethod
+    def _new_radius_expr(cls) -> ColumnElement[float]:
         f1 = func.abs(cls.length, type_=Float())
 
         expr = f1 / 2
@@ -59,13 +78,17 @@ i1 = Interval(5, 10)
 i2 = Interval(7, 12)
 
 l1: int = i1.length
-rd: float = i2.radius
+rdo: float = i2.old_radius
+rdn: float = i2.new_radius
 
 expr1 = Interval.length.in_([5, 10])
 
-expr2 = Interval.radius
+expr2o = Interval.old_radius
+
+expr2n = Interval.new_radius
 
-expr3 = Interval.radius.in_([0.5, 5.2])
+expr3o = Interval.old_radius.in_([0.5, 5.2])
+expr3n = Interval.new_radius.in_([0.5, 5.2])
 
 
 if typing.TYPE_CHECKING:
@@ -73,22 +96,34 @@ if typing.TYPE_CHECKING:
     reveal_type(i1.length)
 
     # EXPECTED_RE_TYPE: builtins.float\*?
-    reveal_type(i2.radius)
+    reveal_type(i2.old_radius)
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.int\*?\]
+    # EXPECTED_RE_TYPE: builtins.float\*?
+    reveal_type(i2.new_radius)
+
+    # EXPECTED_RE_TYPE: sqlalchemy.*._HybridClassLevelAccessor\[builtins.int\*?\]
     reveal_type(Interval.length)
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.float\*?\]
-    reveal_type(Interval.radius)
+    # EXPECTED_RE_TYPE: sqlalchemy.*._HybridClassLevelAccessor\[builtins.float\*?\]
+    reveal_type(Interval.old_radius)
+
+    # EXPECTED_RE_TYPE: sqlalchemy.*._HybridClassLevelAccessor\[builtins.float\*?\]
+    reveal_type(Interval.new_radius)
 
     # EXPECTED_RE_TYPE: sqlalchemy.*.BinaryExpression\[builtins.bool\*?\]
     reveal_type(expr1)
 
-    # EXPECTED_RE_TYPE: sqlalchemy.*.SQLCoreOperations\[builtins.float\*?\]
-    reveal_type(expr2)
+    # EXPECTED_RE_TYPE: sqlalchemy.*._HybridClassLevelAccessor\[builtins.float\*?\]
+    reveal_type(expr2o)
+
+    # EXPECTED_RE_TYPE: sqlalchemy.*._HybridClassLevelAccessor\[builtins.float\*?\]
+    reveal_type(expr2n)
+
+    # EXPECTED_RE_TYPE: sqlalchemy.*.BinaryExpression\[builtins.bool\*?\]
+    reveal_type(expr3o)
 
     # EXPECTED_RE_TYPE: sqlalchemy.*.BinaryExpression\[builtins.bool\*?\]
-    reveal_type(expr3)
+    reveal_type(expr3n)
 
 # test #9268
 
index df69b36af1bd44f1fc68cfc880122ab91bde4903..c0928506679593a69fab24681d39fed4b8e4374e 100644 (file)
@@ -1,5 +1,6 @@
 from decimal import Decimal
 
+from sqlalchemy import column
 from sqlalchemy import exc
 from sqlalchemy import ForeignKey
 from sqlalchemy import func
@@ -26,6 +27,7 @@ from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_false
+from sqlalchemy.testing import is_not
 from sqlalchemy.testing.fixtures import fixture_session
 from sqlalchemy.testing.schema import Column
 
@@ -33,7 +35,7 @@ from sqlalchemy.testing.schema import Column
 class PropertyComparatorTest(fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = "default"
 
-    def _fixture(self):
+    def _fixture(self, use_inplace=False, use_classmethod=False):
         Base = declarative_base()
 
         class UCComparator(hybrid.Comparator):
@@ -54,9 +56,33 @@ class PropertyComparatorTest(fixtures.TestBase, AssertsCompiledSQL):
                 "This is a docstring"
                 return self._value - 5
 
-            @value.comparator
-            def value(cls):
-                return UCComparator(cls._value)
+            if use_classmethod:
+                if use_inplace:
+
+                    @value.inplace.comparator
+                    @classmethod
+                    def _value_comparator(cls):
+                        return UCComparator(cls._value)
+
+                else:
+
+                    @value.comparator
+                    @classmethod
+                    def value(cls):
+                        return UCComparator(cls._value)
+
+            else:
+                if use_inplace:
+
+                    @value.inplace.comparator
+                    def _value_comparator(cls):
+                        return UCComparator(cls._value)
+
+                else:
+
+                    @value.comparator
+                    def value(cls):
+                        return UCComparator(cls._value)
 
             @value.setter
             def value(self, v):
@@ -70,31 +96,51 @@ class PropertyComparatorTest(fixtures.TestBase, AssertsCompiledSQL):
         eq_(a1._value, 10)
         eq_(a1.value, 5)
 
-    def test_value(self):
-        A = self._fixture()
+    @testing.variation("use_inplace", [True, False])
+    @testing.variation("use_classmethod", [True, False])
+    def test_value(self, use_inplace, use_classmethod):
+        A = self._fixture(
+            use_inplace=use_inplace, use_classmethod=use_classmethod
+        )
         eq_(str(A.value == 5), "upper(a.value) = upper(:upper_1)")
 
-    def test_aliased_value(self):
-        A = self._fixture()
+    @testing.variation("use_inplace", [True, False])
+    @testing.variation("use_classmethod", [True, False])
+    def test_aliased_value(self, use_inplace, use_classmethod):
+        A = self._fixture(
+            use_inplace=use_inplace, use_classmethod=use_classmethod
+        )
         eq_(str(aliased(A).value == 5), "upper(a_1.value) = upper(:upper_1)")
 
-    def test_query(self):
-        A = self._fixture()
+    @testing.variation("use_inplace", [True, False])
+    @testing.variation("use_classmethod", [True, False])
+    def test_query(self, use_inplace, use_classmethod):
+        A = self._fixture(
+            use_inplace=use_inplace, use_classmethod=use_classmethod
+        )
         sess = fixture_session()
         self.assert_compile(
             sess.query(A.value), "SELECT a.value AS a_value FROM a"
         )
 
-    def test_aliased_query(self):
-        A = self._fixture()
+    @testing.variation("use_inplace", [True, False])
+    @testing.variation("use_classmethod", [True, False])
+    def test_aliased_query(self, use_inplace, use_classmethod):
+        A = self._fixture(
+            use_inplace=use_inplace, use_classmethod=use_classmethod
+        )
         sess = fixture_session()
         self.assert_compile(
             sess.query(aliased(A).value),
             "SELECT a_1.value AS a_1_value FROM a AS a_1",
         )
 
-    def test_aliased_filter(self):
-        A = self._fixture()
+    @testing.variation("use_inplace", [True, False])
+    @testing.variation("use_classmethod", [True, False])
+    def test_aliased_filter(self, use_inplace, use_classmethod):
+        A = self._fixture(
+            use_inplace=use_inplace, use_classmethod=use_classmethod
+        )
         sess = fixture_session()
         self.assert_compile(
             sess.query(aliased(A)).filter_by(value="foo"),
@@ -189,7 +235,8 @@ class PropertyComparatorTest(fixtures.TestBase, AssertsCompiledSQL):
 class PropertyExpressionTest(fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = "default"
 
-    def _fixture(self):
+    def _fixture(self, use_inplace=False, use_classmethod=False):
+        use_inplace, use_classmethod = bool(use_inplace), bool(use_classmethod)
         Base = declarative_base()
 
         class A(Base):
@@ -202,15 +249,42 @@ class PropertyExpressionTest(fixtures.TestBase, AssertsCompiledSQL):
                 "This is an instance-level docstring"
                 return int(self._value) - 5
 
-            @value.expression
-            def value(cls):
-                "This is a class-level docstring"
-                return func.foo(cls._value) + cls.bar_value
-
             @value.setter
             def value(self, v):
                 self._value = v + 5
 
+            if use_classmethod:
+                if use_inplace:
+
+                    @value.inplace.expression
+                    @classmethod
+                    def _value_expr(cls):
+                        "This is a class-level docstring"
+                        return func.foo(cls._value) + cls.bar_value
+
+                else:
+
+                    @value.expression
+                    @classmethod
+                    def value(cls):
+                        "This is a class-level docstring"
+                        return func.foo(cls._value) + cls.bar_value
+
+            else:
+                if use_inplace:
+
+                    @value.inplace.expression
+                    def _value_expr(cls):
+                        "This is a class-level docstring"
+                        return func.foo(cls._value) + cls.bar_value
+
+                else:
+
+                    @value.expression
+                    def value(cls):
+                        "This is a class-level docstring"
+                        return func.foo(cls._value) + cls.bar_value
+
             @hybrid.hybrid_property
             def bar_value(cls):
                 return func.bar(cls._value)
@@ -412,22 +486,34 @@ class PropertyExpressionTest(fixtures.TestBase, AssertsCompiledSQL):
             "a.lastname AS name FROM a) AS anon_1",
         )
 
-    def test_info(self):
-        A = self._fixture()
+    @testing.variation("use_inplace", [True, False])
+    @testing.variation("use_classmethod", [True, False])
+    def test_info(self, use_inplace, use_classmethod):
+        A = self._fixture(
+            use_inplace=use_inplace, use_classmethod=use_classmethod
+        )
         inspect(A).all_orm_descriptors.value.info["some key"] = "some value"
         eq_(
             inspect(A).all_orm_descriptors.value.info,
             {"some key": "some value"},
         )
 
-    def test_set_get(self):
-        A = self._fixture()
+    @testing.variation("use_inplace", [True, False])
+    @testing.variation("use_classmethod", [True, False])
+    def test_set_get(self, use_inplace, use_classmethod):
+        A = self._fixture(
+            use_inplace=use_inplace, use_classmethod=use_classmethod
+        )
         a1 = A(value=5)
         eq_(a1._value, 10)
         eq_(a1.value, 5)
 
-    def test_expression(self):
-        A = self._fixture()
+    @testing.variation("use_inplace", [True, False])
+    @testing.variation("use_classmethod", [True, False])
+    def test_expression(self, use_inplace, use_classmethod):
+        A = self._fixture(
+            use_inplace=use_inplace, use_classmethod=use_classmethod
+        )
         self.assert_compile(
             A.value.__clause_element__(), "foo(a.value) + bar(a.value)"
         )
@@ -455,15 +541,23 @@ class PropertyExpressionTest(fixtures.TestBase, AssertsCompiledSQL):
             "AND foo(a.value) + bar(a.value) = :param_1)",
         )
 
-    def test_aliased_expression(self):
-        A = self._fixture()
+    @testing.variation("use_inplace", [True, False])
+    @testing.variation("use_classmethod", [True, False])
+    def test_aliased_expression(self, use_inplace, use_classmethod):
+        A = self._fixture(
+            use_inplace=use_inplace, use_classmethod=use_classmethod
+        )
         self.assert_compile(
             aliased(A).value.__clause_element__(),
             "foo(a_1.value) + bar(a_1.value)",
         )
 
-    def test_query(self):
-        A = self._fixture()
+    @testing.variation("use_inplace", [True, False])
+    @testing.variation("use_classmethod", [True, False])
+    def test_query(self, use_inplace, use_classmethod):
+        A = self._fixture(
+            use_inplace=use_inplace, use_classmethod=use_classmethod
+        )
         sess = fixture_session()
         self.assert_compile(
             sess.query(A).filter_by(value="foo"),
@@ -471,8 +565,12 @@ class PropertyExpressionTest(fixtures.TestBase, AssertsCompiledSQL):
             "FROM a WHERE foo(a.value) + bar(a.value) = :param_1",
         )
 
-    def test_aliased_query(self):
-        A = self._fixture()
+    @testing.variation("use_inplace", [True, False])
+    @testing.variation("use_classmethod", [True, False])
+    def test_aliased_query(self, use_inplace, use_classmethod):
+        A = self._fixture(
+            use_inplace=use_inplace, use_classmethod=use_classmethod
+        )
         sess = fixture_session()
         self.assert_compile(
             sess.query(aliased(A)).filter_by(value="foo"),
@@ -480,8 +578,12 @@ class PropertyExpressionTest(fixtures.TestBase, AssertsCompiledSQL):
             "FROM a AS a_1 WHERE foo(a_1.value) + bar(a_1.value) = :param_1",
         )
 
-    def test_docstring(self):
-        A = self._fixture()
+    @testing.variation("use_inplace", [True, False])
+    @testing.variation("use_classmethod", [True, False])
+    def test_docstring(self, use_inplace, use_classmethod):
+        A = self._fixture(
+            use_inplace=use_inplace, use_classmethod=use_classmethod
+        )
         eq_(A.value.__doc__, "This is a class-level docstring")
 
         # no docstring here since we get a literal
@@ -854,6 +956,84 @@ class SynonymOfPropertyTest(fixtures.TestBase, AssertsCompiledSQL):
         )
 
 
+class InplaceCreationTest(fixtures.TestBase, AssertsCompiledSQL):
+    """test 'inplace' definitions added for 2.0 to assist with typing
+    limitations.
+
+    """
+
+    __dialect__ = "default"
+
+    def test_property_integration(self, decl_base):
+        class Person(decl_base):
+            __tablename__ = "person"
+            id = Column(Integer, primary_key=True)
+            _name = Column(String)
+
+            @hybrid.hybrid_property
+            def name(self):
+                return self._name
+
+            @name.inplace.setter
+            def _name_setter(self, value):
+                self._name = value.title()
+
+            @name.inplace.expression
+            def _name_expression(self):
+                return func.concat("Hello", self._name)
+
+        p1 = Person(_name="name")
+        eq_(p1.name, "name")
+        p1.name = "new name"
+        eq_(p1.name, "New Name")
+
+        self.assert_compile(Person.name, "concat(:concat_1, person._name)")
+
+    def test_method_integration(self, decl_base):
+        class A(decl_base):
+            __tablename__ = "a"
+            id = Column(Integer, primary_key=True)
+            _value = Column("value", String)
+
+            @hybrid.hybrid_method
+            def value(self, x):
+                return int(self._value) + x
+
+            @value.inplace.expression
+            def _value_expression(cls, value):
+                return func.foo(cls._value, value) + value
+
+        a1 = A(_value="10")
+        eq_(a1.value(5), 15)
+
+        self.assert_compile(A.value(column("q")), "foo(a.value, q) + q")
+
+    def test_property_unit(self):
+        def one():
+            pass
+
+        def two():
+            pass
+
+        def three():
+            pass
+
+        prop = hybrid.hybrid_property(one)
+
+        prop2 = prop.inplace.expression(two)
+
+        prop3 = prop.inplace.setter(three)
+
+        is_(prop, prop2)
+        is_(prop, prop3)
+
+        def four():
+            pass
+
+        prop4 = prop.setter(four)
+        is_not(prop, prop4)
+
+
 class MethodExpressionTest(fixtures.TestBase, AssertsCompiledSQL):
     __dialect__ = "default"