From: Vincent Fazio Date: Thu, 29 Dec 2022 18:46:42 +0000 (-0600) Subject: generate overload type stubs X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=1cd92bda272d2725f918f8ca86566962f784c3d4;p=thirdparty%2Fsqlalchemy%2Falembic.git generate overload type stubs Add support for generating @overload type stubs for proxied classes. This should fix typing issues with overloaded functions such as `context.get_x_argument` that narrow return types based on parameters. Fixes: #1147 Signed-off-by: Vincent Fazio --- diff --git a/tools/write_pyi.py b/tools/write_pyi.py index e5112fdb..83e77586 100644 --- a/tools/write_pyi.py +++ b/tools/write_pyi.py @@ -99,9 +99,30 @@ def generate_pyi_for_proxy( continue meth = getattr(cls, name, None) if callable(meth): - _generate_stub_for_meth( - cls, name, printer, env, name in context_managers - ) + # If there are overloads, generate only those + # Do not generate the base implementation to avoid mypy errors + overloads = typing.get_overloads(meth) + if overloads: + # use enumerate so we can generate docs on the last overload + for i, ovl in enumerate(overloads, 1): + _generate_stub_for_meth( + ovl, + cls, + printer, + env, + is_context_manager=name in context_managers, + is_overload=True, + base_method=meth, + gen_docs=(i == len(overloads)), + ) + else: + _generate_stub_for_meth( + meth, + cls, + printer, + env, + is_context_manager=name in context_managers, + ) else: _generate_stub_for_attr(cls, name, printer, env) @@ -133,12 +154,20 @@ def _generate_stub_for_attr(cls, name, printer, env): printer.writeline(f"{name}: {type_}") -def _generate_stub_for_meth(cls, name, printer, env, is_context_manager): - - fn = getattr(cls, name) +def _generate_stub_for_meth( + fn, + cls, + printer, + env, + is_context_manager, + is_overload=False, + base_method=None, + gen_docs=True, +): while hasattr(fn, "__wrapped__"): fn = fn.__wrapped__ + name = fn.__name__ spec = inspect_getfullargspec(fn) try: annotations = typing.get_type_hints(fn, env) @@ -168,17 +197,29 @@ def _generate_stub_for_meth(cls, name, printer, env, is_context_manager): retval = re.sub("NoneType", "None", retval) return retval + def _formatvalue(value): + return "=" + ("..." if value is Ellipsis else repr(value)) + argspec = inspect_formatargspec( *spec, formatannotation=_formatannotation, + formatvalue=_formatvalue, formatreturns=lambda val: f"-> {_formatannotation(val)}", ) + + overload = "@overload" if is_overload else "" contextmanager = "@contextmanager" if is_context_manager else "" + + fn_doc = base_method.__doc__ if base_method else fn.__doc__ + has_docs = gen_docs and fn_doc is not None + docs = '"""' + f"{fn_doc}" + '"""' if has_docs else "" + func_text = textwrap.dedent( f""" + {overload} {contextmanager} - def {name}{argspec}: - '''{fn.__doc__}''' + def {name}{argspec}: {"..." if not docs else ""} + {docs} """ )