]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
generate overload type stubs
authorVincent Fazio <vfazio@gmail.com>
Thu, 29 Dec 2022 18:46:42 +0000 (12:46 -0600)
committerVincent Fazio <vfazio@xes-inc.com>
Fri, 30 Dec 2022 23:02:27 +0000 (17:02 -0600)
Add support for generating @overload type stubs for proxied classes.

This should fix typing issues with overloaded functions such as
`context.get_x_argument` that narrow return types based on parameters.

Fixes: #1147
Signed-off-by: Vincent Fazio <vfazio@gmail.com>
tools/write_pyi.py

index e5112fdb0be31dc64920bd8a5a63f93db178fb3d..83e7758603a66648c11f5568fdbe573a79be9e16 100644 (file)
@@ -99,9 +99,30 @@ def generate_pyi_for_proxy(
                 continue
             meth = getattr(cls, name, None)
             if callable(meth):
-                _generate_stub_for_meth(
-                    cls, name, printer, env, name in context_managers
-                )
+                # If there are overloads, generate only those
+                # Do not generate the base implementation to avoid mypy errors
+                overloads = typing.get_overloads(meth)
+                if overloads:
+                    # use enumerate so we can generate docs on the last overload
+                    for i, ovl in enumerate(overloads, 1):
+                        _generate_stub_for_meth(
+                            ovl,
+                            cls,
+                            printer,
+                            env,
+                            is_context_manager=name in context_managers,
+                            is_overload=True,
+                            base_method=meth,
+                            gen_docs=(i == len(overloads)),
+                        )
+                else:
+                    _generate_stub_for_meth(
+                        meth,
+                        cls,
+                        printer,
+                        env,
+                        is_context_manager=name in context_managers,
+                    )
             else:
                 _generate_stub_for_attr(cls, name, printer, env)
 
@@ -133,12 +154,20 @@ def _generate_stub_for_attr(cls, name, printer, env):
     printer.writeline(f"{name}: {type_}")
 
 
-def _generate_stub_for_meth(cls, name, printer, env, is_context_manager):
-
-    fn = getattr(cls, name)
+def _generate_stub_for_meth(
+    fn,
+    cls,
+    printer,
+    env,
+    is_context_manager,
+    is_overload=False,
+    base_method=None,
+    gen_docs=True,
+):
     while hasattr(fn, "__wrapped__"):
         fn = fn.__wrapped__
 
+    name = fn.__name__
     spec = inspect_getfullargspec(fn)
     try:
         annotations = typing.get_type_hints(fn, env)
@@ -168,17 +197,29 @@ def _generate_stub_for_meth(cls, name, printer, env, is_context_manager):
         retval = re.sub("NoneType", "None", retval)
         return retval
 
+    def _formatvalue(value):
+        return "=" + ("..." if value is Ellipsis else repr(value))
+
     argspec = inspect_formatargspec(
         *spec,
         formatannotation=_formatannotation,
+        formatvalue=_formatvalue,
         formatreturns=lambda val: f"-> {_formatannotation(val)}",
     )
+
+    overload = "@overload" if is_overload else ""
     contextmanager = "@contextmanager" if is_context_manager else ""
+
+    fn_doc = base_method.__doc__ if base_method else fn.__doc__
+    has_docs = gen_docs and fn_doc is not None
+    docs = '"""' + f"{fn_doc}" + '"""' if has_docs else ""
+
     func_text = textwrap.dedent(
         f"""
+    {overload}
     {contextmanager}
-    def {name}{argspec}:
-        '''{fn.__doc__}'''
+    def {name}{argspec}: {"..." if not docs else ""}
+        {docs}
     """
     )