From: Mike Bayer Date: Mon, 26 Sep 2022 01:40:48 +0000 (-0400) Subject: add typing for sqlalchemy.orm.validates X-Git-Tag: rel_2_0_0b1~38^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=74f6e38ec717979bcde76244c94b1d8c519a5b63;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git add typing for sqlalchemy.orm.validates Fixes: #8577 Change-Id: Iede1c956078960fb866da45f1ac6aa43842516bc --- diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 98c0eba0ca..553f7b35b2 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -118,6 +118,8 @@ if TYPE_CHECKING: _T = TypeVar("_T", bound=Any) _MP = TypeVar("_MP", bound="MapperProperty[Any]") +_Fn = TypeVar("_Fn", bound="Callable[..., Any]") + _WithPolymorphicArg = Union[ Literal["*"], @@ -3895,7 +3897,9 @@ def reconstructor(fn): return fn -def validates(*names, **kw): +def validates( + *names: str, include_removes: bool = False, include_backrefs: bool = False +) -> Callable[[_Fn], _Fn]: r"""Decorate a method as a 'validator' for one or more named properties. Designates a method as a validator, a method which receives the @@ -3930,12 +3934,10 @@ def validates(*names, **kw): :ref:`simple_validators` - usage examples for :func:`.validates` """ - include_removes = kw.pop("include_removes", False) - include_backrefs = kw.pop("include_backrefs", True) - def wrap(fn): - fn.__sa_validators__ = names - fn.__sa_validation_opts__ = { + def wrap(fn: _Fn) -> _Fn: + fn.__sa_validators__ = names # type: ignore[attr-defined] + fn.__sa_validation_opts__ = { # type: ignore[attr-defined] "include_removes": include_removes, "include_backrefs": include_backrefs, } diff --git a/test/ext/mypy/plain_files/orm_config_constructs.py b/test/ext/mypy/plain_files/orm_config_constructs.py new file mode 100644 index 0000000000..008e16f240 --- /dev/null +++ b/test/ext/mypy/plain_files/orm_config_constructs.py @@ -0,0 +1,20 @@ +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import validates + + +class Base(DeclarativeBase): + pass + + +class User(Base): + __tablename__ = "User" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + + @validates("name", include_removes=True) + def validate_name(self, name: str) -> str: + """test #8577""" + return name + "hi"