]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
add typing for sqlalchemy.orm.validates
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 26 Sep 2022 01:40:48 +0000 (21:40 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 26 Sep 2022 01:41:14 +0000 (21:41 -0400)
Fixes: #8577
Change-Id: Iede1c956078960fb866da45f1ac6aa43842516bc

lib/sqlalchemy/orm/mapper.py
test/ext/mypy/plain_files/orm_config_constructs.py [new file with mode: 0644]

index 98c0eba0cab6b3f873df1b7308cd20a79fa5e658..553f7b35b25bc0a58b214439329a67594fa2c130 100644 (file)
@@ -118,6 +118,8 @@ if TYPE_CHECKING:
 
 _T = TypeVar("_T", bound=Any)
 _MP = TypeVar("_MP", bound="MapperProperty[Any]")
+_Fn = TypeVar("_Fn", bound="Callable[..., Any]")
+
 
 _WithPolymorphicArg = Union[
     Literal["*"],
@@ -3895,7 +3897,9 @@ def reconstructor(fn):
     return fn
 
 
-def validates(*names, **kw):
+def validates(
+    *names: str, include_removes: bool = False, include_backrefs: bool = False
+) -> Callable[[_Fn], _Fn]:
     r"""Decorate a method as a 'validator' for one or more named properties.
 
     Designates a method as a validator, a method which receives the
@@ -3930,12 +3934,10 @@ def validates(*names, **kw):
       :ref:`simple_validators` - usage examples for :func:`.validates`
 
     """
-    include_removes = kw.pop("include_removes", False)
-    include_backrefs = kw.pop("include_backrefs", True)
 
-    def wrap(fn):
-        fn.__sa_validators__ = names
-        fn.__sa_validation_opts__ = {
+    def wrap(fn: _Fn) -> _Fn:
+        fn.__sa_validators__ = names  # type: ignore[attr-defined]
+        fn.__sa_validation_opts__ = {  # type: ignore[attr-defined]
             "include_removes": include_removes,
             "include_backrefs": include_backrefs,
         }
diff --git a/test/ext/mypy/plain_files/orm_config_constructs.py b/test/ext/mypy/plain_files/orm_config_constructs.py
new file mode 100644 (file)
index 0000000..008e16f
--- /dev/null
@@ -0,0 +1,20 @@
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+from sqlalchemy.orm import validates
+
+
+class Base(DeclarativeBase):
+    pass
+
+
+class User(Base):
+    __tablename__ = "User"
+
+    id: Mapped[int] = mapped_column(primary_key=True)
+    name: Mapped[str]
+
+    @validates("name", include_removes=True)
+    def validate_name(self, name: str) -> str:
+        """test #8577"""
+        return name + "hi"