]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
add overload stubs for proxied classes
authorVincent Fazio <vfazio@gmail.com>
Tue, 3 Jan 2023 17:39:13 +0000 (12:39 -0500)
committersqla-tester <sqla-tester@sqlalchemy.org>
Tue, 3 Jan 2023 17:39:13 +0000 (12:39 -0500)
<!-- Provide a general summary of your proposed changes in the Title field above -->

### Description
Closes #1146
Closes #1147

<!-- Describe your changes in detail -->

Overloaded functions would not have type stubs generated by the stub generator for proxied classes. Now they will.

### Checklist
<!-- go over following points. check them with an `x` if they do apply, (they turn into clickable checkboxes once the PR is submitted, so no need to do everything at once)

-->

This pull request is:

- [ ] A documentation / typographical error fix
- Good to go, no issue or tests are needed
- [x] A short code fix
- please include the issue number, and create an issue if none exists, which
  must include a complete example of the issue.  one line code fixes without an
  issue and demonstration will not be accepted.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.   one line code fixes without tests will not be accepted.
- [ ] A new feature implementation
- please include the issue number, and create an issue if none exists, which must
  include a complete example of how the feature would look.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.

Closes: #1148
Pull-request: https://github.com/sqlalchemy/alembic/pull/1148
Pull-request-sha: ed3c28cc78e57314b7a4e533d77108efc6751949

Change-Id: I7c0ee9d333015174ee6ab754909748f745af2ff9

alembic/context.pyi
alembic/runtime/environment.py
alembic/util/compat.py
docs/build/unreleased/1147.rst [new file with mode: 0644]
tests/requirements.py
tools/write_pyi.py

index 9871fadddb9a85b27adef371fe6c6c3c35360fdd..86345c4f6076239f58b1ba6e526917b2d159ab13 100644 (file)
@@ -7,7 +7,9 @@ from typing import Callable
 from typing import ContextManager
 from typing import Dict
 from typing import List
+from typing import Literal
 from typing import Optional
+from typing import overload
 from typing import TextIO
 from typing import Tuple
 from typing import TYPE_CHECKING
@@ -644,8 +646,13 @@ def get_tag_argument() -> Optional[str]:
 
     """
 
+@overload
+def get_x_argument(as_dictionary: Literal[False]) -> List[str]: ...
+@overload
+def get_x_argument(as_dictionary: Literal[True]) -> Dict[str, str]: ...
+@overload
 def get_x_argument(
-    as_dictionary: bool = False,
+    as_dictionary: bool = ...,
 ) -> Union[List[str], Dict[str, str]]:
     """Return the value(s) passed for the ``-x`` argument, if any.
 
index 44dcd72db1e3fcdfdbedf2b9c52a06295ab39938..a441d1fd53739df3e1375d084defb01d377334aa 100644 (file)
@@ -269,15 +269,17 @@ class EnvironmentContext(util.ModuleClsProxy):
         return self.context_opts.get("tag", None)
 
     @overload
-    def get_x_argument(  # type:ignore[misc]
-        self, as_dictionary: Literal[False] = ...
-    ) -> List[str]:
+    def get_x_argument(self, as_dictionary: Literal[False]) -> List[str]:
         ...
 
     @overload
-    def get_x_argument(  # type:ignore[misc]
-        self, as_dictionary: Literal[True] = ...
-    ) -> Dict[str, str]:
+    def get_x_argument(self, as_dictionary: Literal[True]) -> Dict[str, str]:
+        ...
+
+    @overload
+    def get_x_argument(
+        self, as_dictionary: bool = ...
+    ) -> Union[List[str], Dict[str, str]]:
         ...
 
     def get_x_argument(
index 289aaa228a8027efd4b66ea98448dd2a18953fb5..2fe49573d61c22ae9c811ccf64c308a4707d205d 100644 (file)
@@ -10,6 +10,7 @@ from sqlalchemy.util.compat import inspect_formatargspec  # noqa
 
 is_posix = os.name == "posix"
 
+py311 = sys.version_info >= (3, 11)
 py39 = sys.version_info >= (3, 9)
 py38 = sys.version_info >= (3, 8)
 
diff --git a/docs/build/unreleased/1147.rst b/docs/build/unreleased/1147.rst
new file mode 100644 (file)
index 0000000..5f6f7de
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, typing
+    :tickets: 1146, 1147
+
+    Fixed typing definitions for :meth:`.EnvironmentContext.get_x_argument`.
+
+    Typing stubs are now generated for overloaded proxied methods such as
+    :meth:`.EnvironmentContext.get_x_argument`.
\ No newline at end of file
index aa88f66d22a194133c522382a64ca77ba1a70839..c774e673a4e8c71d6825673fff6ff137065ffb41 100644 (file)
@@ -402,7 +402,7 @@ class DefaultRequirements(SuiteRequirements):
             requirements, "black and zimports are required for this test"
         )
         version = exclusions.only_if(
-            lambda _: compat.py39, "python 3.9 is required"
+            lambda _: compat.py311, "python 3.11 is required"
         )
 
         sqlalchemy = exclusions.only_if(
index e5112fdb0be31dc64920bd8a5a63f93db178fb3d..e3feb363096d1ee9a2f3dc8fee0cf040dde4bc97 100644 (file)
@@ -52,8 +52,10 @@ def generate_pyi_for_proxy(
 ):
     ignore_items = IGNORE_ITEMS.get(file_key, set())
     context_managers = CONTEXT_MANAGERS.get(file_key, [])
-    if sys.version_info < (3, 9):
-        raise RuntimeError("This script must be run with Python 3.9 or higher")
+    if sys.version_info < (3, 11):
+        raise RuntimeError(
+            "This script must be run with Python 3.11 or higher"
+        )
 
     # When using an absolute path on windows, this will generate the correct
     # relative path that shall be written to the top comment of the pyi file.
@@ -99,9 +101,30 @@ def generate_pyi_for_proxy(
                 continue
             meth = getattr(cls, name, None)
             if callable(meth):
-                _generate_stub_for_meth(
-                    cls, name, printer, env, name in context_managers
-                )
+                # If there are overloads, generate only those
+                # Do not generate the base implementation to avoid mypy errors
+                overloads = typing.get_overloads(meth)
+                if overloads:
+                    # use enumerate so we can generate docs on the last overload
+                    for i, ovl in enumerate(overloads, 1):
+                        _generate_stub_for_meth(
+                            ovl,
+                            cls,
+                            printer,
+                            env,
+                            is_context_manager=name in context_managers,
+                            is_overload=True,
+                            base_method=meth,
+                            gen_docs=(i == len(overloads)),
+                        )
+                else:
+                    _generate_stub_for_meth(
+                        meth,
+                        cls,
+                        printer,
+                        env,
+                        is_context_manager=name in context_managers,
+                    )
             else:
                 _generate_stub_for_attr(cls, name, printer, env)
 
@@ -133,12 +156,20 @@ def _generate_stub_for_attr(cls, name, printer, env):
     printer.writeline(f"{name}: {type_}")
 
 
-def _generate_stub_for_meth(cls, name, printer, env, is_context_manager):
-
-    fn = getattr(cls, name)
+def _generate_stub_for_meth(
+    fn,
+    cls,
+    printer,
+    env,
+    is_context_manager,
+    is_overload=False,
+    base_method=None,
+    gen_docs=True,
+):
     while hasattr(fn, "__wrapped__"):
         fn = fn.__wrapped__
 
+    name = fn.__name__
     spec = inspect_getfullargspec(fn)
     try:
         annotations = typing.get_type_hints(fn, env)
@@ -168,17 +199,29 @@ def _generate_stub_for_meth(cls, name, printer, env, is_context_manager):
         retval = re.sub("NoneType", "None", retval)
         return retval
 
+    def _formatvalue(value):
+        return "=" + ("..." if value is Ellipsis else repr(value))
+
     argspec = inspect_formatargspec(
         *spec,
         formatannotation=_formatannotation,
+        formatvalue=_formatvalue,
         formatreturns=lambda val: f"-> {_formatannotation(val)}",
     )
+
+    overload = "@overload" if is_overload else ""
     contextmanager = "@contextmanager" if is_context_manager else ""
+
+    fn_doc = base_method.__doc__ if base_method else fn.__doc__
+    has_docs = gen_docs and fn_doc is not None
+    docs = '"""' + f"{fn_doc}" + '"""' if has_docs else ""
+
     func_text = textwrap.dedent(
         f"""
+    {overload}
     {contextmanager}
-    def {name}{argspec}:
-        '''{fn.__doc__}'''
+    def {name}{argspec}: {"..." if not docs else ""}
+        {docs}
     """
     )