From: Vincent Fazio Date: Tue, 3 Jan 2023 17:39:13 +0000 (-0500) Subject: add overload stubs for proxied classes X-Git-Tag: rel_1_9_2~3^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=40df5ecb410660fdf96a750d626d4f2f5b2d98ea;p=thirdparty%2Fsqlalchemy%2Falembic.git add overload stubs for proxied classes ### Description Closes #1146 Closes #1147 Overloaded functions would not have type stubs generated by the stub generator for proxied classes. Now they will. ### Checklist This pull request is: - [ ] A documentation / typographical error fix - Good to go, no issue or tests are needed - [x] A short code fix - please include the issue number, and create an issue if none exists, which must include a complete example of the issue. one line code fixes without an issue and demonstration will not be accepted. - Please include: `Fixes: #` in the commit message - please include tests. one line code fixes without tests will not be accepted. - [ ] A new feature implementation - please include the issue number, and create an issue if none exists, which must include a complete example of how the feature would look. - Please include: `Fixes: #` in the commit message - please include tests. Closes: #1148 Pull-request: https://github.com/sqlalchemy/alembic/pull/1148 Pull-request-sha: ed3c28cc78e57314b7a4e533d77108efc6751949 Change-Id: I7c0ee9d333015174ee6ab754909748f745af2ff9 --- diff --git a/alembic/context.pyi b/alembic/context.pyi index 9871fadd..86345c4f 100644 --- a/alembic/context.pyi +++ b/alembic/context.pyi @@ -7,7 +7,9 @@ from typing import Callable from typing import ContextManager from typing import Dict from typing import List +from typing import Literal from typing import Optional +from typing import overload from typing import TextIO from typing import Tuple from typing import TYPE_CHECKING @@ -644,8 +646,13 @@ def get_tag_argument() -> Optional[str]: """ +@overload +def get_x_argument(as_dictionary: Literal[False]) -> List[str]: ... +@overload +def get_x_argument(as_dictionary: Literal[True]) -> Dict[str, str]: ... +@overload def get_x_argument( - as_dictionary: bool = False, + as_dictionary: bool = ..., ) -> Union[List[str], Dict[str, str]]: """Return the value(s) passed for the ``-x`` argument, if any. diff --git a/alembic/runtime/environment.py b/alembic/runtime/environment.py index 44dcd72d..a441d1fd 100644 --- a/alembic/runtime/environment.py +++ b/alembic/runtime/environment.py @@ -269,15 +269,17 @@ class EnvironmentContext(util.ModuleClsProxy): return self.context_opts.get("tag", None) @overload - def get_x_argument( # type:ignore[misc] - self, as_dictionary: Literal[False] = ... - ) -> List[str]: + def get_x_argument(self, as_dictionary: Literal[False]) -> List[str]: ... @overload - def get_x_argument( # type:ignore[misc] - self, as_dictionary: Literal[True] = ... - ) -> Dict[str, str]: + def get_x_argument(self, as_dictionary: Literal[True]) -> Dict[str, str]: + ... + + @overload + def get_x_argument( + self, as_dictionary: bool = ... + ) -> Union[List[str], Dict[str, str]]: ... def get_x_argument( diff --git a/alembic/util/compat.py b/alembic/util/compat.py index 289aaa22..2fe49573 100644 --- a/alembic/util/compat.py +++ b/alembic/util/compat.py @@ -10,6 +10,7 @@ from sqlalchemy.util.compat import inspect_formatargspec # noqa is_posix = os.name == "posix" +py311 = sys.version_info >= (3, 11) py39 = sys.version_info >= (3, 9) py38 = sys.version_info >= (3, 8) diff --git a/docs/build/unreleased/1147.rst b/docs/build/unreleased/1147.rst new file mode 100644 index 00000000..5f6f7dec --- /dev/null +++ b/docs/build/unreleased/1147.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, typing + :tickets: 1146, 1147 + + Fixed typing definitions for :meth:`.EnvironmentContext.get_x_argument`. + + Typing stubs are now generated for overloaded proxied methods such as + :meth:`.EnvironmentContext.get_x_argument`. \ No newline at end of file diff --git a/tests/requirements.py b/tests/requirements.py index aa88f66d..c774e673 100644 --- a/tests/requirements.py +++ b/tests/requirements.py @@ -402,7 +402,7 @@ class DefaultRequirements(SuiteRequirements): requirements, "black and zimports are required for this test" ) version = exclusions.only_if( - lambda _: compat.py39, "python 3.9 is required" + lambda _: compat.py311, "python 3.11 is required" ) sqlalchemy = exclusions.only_if( diff --git a/tools/write_pyi.py b/tools/write_pyi.py index e5112fdb..e3feb363 100644 --- a/tools/write_pyi.py +++ b/tools/write_pyi.py @@ -52,8 +52,10 @@ def generate_pyi_for_proxy( ): ignore_items = IGNORE_ITEMS.get(file_key, set()) context_managers = CONTEXT_MANAGERS.get(file_key, []) - if sys.version_info < (3, 9): - raise RuntimeError("This script must be run with Python 3.9 or higher") + if sys.version_info < (3, 11): + raise RuntimeError( + "This script must be run with Python 3.11 or higher" + ) # When using an absolute path on windows, this will generate the correct # relative path that shall be written to the top comment of the pyi file. @@ -99,9 +101,30 @@ def generate_pyi_for_proxy( continue meth = getattr(cls, name, None) if callable(meth): - _generate_stub_for_meth( - cls, name, printer, env, name in context_managers - ) + # If there are overloads, generate only those + # Do not generate the base implementation to avoid mypy errors + overloads = typing.get_overloads(meth) + if overloads: + # use enumerate so we can generate docs on the last overload + for i, ovl in enumerate(overloads, 1): + _generate_stub_for_meth( + ovl, + cls, + printer, + env, + is_context_manager=name in context_managers, + is_overload=True, + base_method=meth, + gen_docs=(i == len(overloads)), + ) + else: + _generate_stub_for_meth( + meth, + cls, + printer, + env, + is_context_manager=name in context_managers, + ) else: _generate_stub_for_attr(cls, name, printer, env) @@ -133,12 +156,20 @@ def _generate_stub_for_attr(cls, name, printer, env): printer.writeline(f"{name}: {type_}") -def _generate_stub_for_meth(cls, name, printer, env, is_context_manager): - - fn = getattr(cls, name) +def _generate_stub_for_meth( + fn, + cls, + printer, + env, + is_context_manager, + is_overload=False, + base_method=None, + gen_docs=True, +): while hasattr(fn, "__wrapped__"): fn = fn.__wrapped__ + name = fn.__name__ spec = inspect_getfullargspec(fn) try: annotations = typing.get_type_hints(fn, env) @@ -168,17 +199,29 @@ def _generate_stub_for_meth(cls, name, printer, env, is_context_manager): retval = re.sub("NoneType", "None", retval) return retval + def _formatvalue(value): + return "=" + ("..." if value is Ellipsis else repr(value)) + argspec = inspect_formatargspec( *spec, formatannotation=_formatannotation, + formatvalue=_formatvalue, formatreturns=lambda val: f"-> {_formatannotation(val)}", ) + + overload = "@overload" if is_overload else "" contextmanager = "@contextmanager" if is_context_manager else "" + + fn_doc = base_method.__doc__ if base_method else fn.__doc__ + has_docs = gen_docs and fn_doc is not None + docs = '"""' + f"{fn_doc}" + '"""' if has_docs else "" + func_text = textwrap.dedent( f""" + {overload} {contextmanager} - def {name}{argspec}: - '''{fn.__doc__}''' + def {name}{argspec}: {"..." if not docs else ""} + {docs} """ )