]> git.ipfire.org Git - thirdparty/sqlalchemy/alembic.git/commitdiff
get write_pyi to support lowercase types with pipes
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 16 Dec 2025 14:53:13 +0000 (09:53 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 16 Dec 2025 15:54:11 +0000 (10:54 -0500)
The specific form of `tuple[] | None` produces a `types.UnionType`
in some way that seems to not be what it has ever been previously
(an explicit Union will give you `<class 'typing._UnionGenericAlias'>`,
apparently).  so repr() this specific case so we can move to newer
typing formats.

As a test, this moves the type of partial_reordering to the newer
format.

Also, write output file using shutil.move from tempfile, so that
crashes of write_pyi dont corrupt the file.

Add version checks for black, python version

bump minimum python version to 3.12 as 3.11 seems to have problems
we dont need to fix

Change-Id: I91914c1e1b979ad84ca8b82d362ed94312645994

alembic/op.pyi
alembic/operations/base.py
alembic/script/write_hooks.py
tests/requirements.py
tools/write_pyi.py

index 6da538713cb62479e8093e57f964c4a23df665cf..96f68b82ffadb4aa23897edcc0ca5b116cb70270 100644 (file)
@@ -256,7 +256,7 @@ def batch_alter_table(
     table_name: str,
     schema: Optional[str] = None,
     recreate: Literal["auto", "always", "never"] = "auto",
-    partial_reordering: Optional[list[tuple[str, ...]]] = None,
+    partial_reordering: list[tuple[str, ...]] | None = None,
     copy_from: Optional[Table] = None,
     table_args: Tuple[Any, ...] = (),
     table_kwargs: Mapping[str, Any] = immutabledict({}),
index 5ca38b1f99c1851fadd8924fc3d7daaf6c5427d1..be3a77b2ada212b7d9a238c331f8f1ce48062e3e 100644 (file)
@@ -248,7 +248,7 @@ class AbstractOperations(util.ModuleClsProxy):
         table_name: str,
         schema: Optional[str] = None,
         recreate: Literal["auto", "always", "never"] = "auto",
-        partial_reordering: Optional[list[tuple[str, ...]]] = None,
+        partial_reordering: list[tuple[str, ...]] | None = None,
         copy_from: Optional[Table] = None,
         table_args: Tuple[Any, ...] = (),
         table_kwargs: Mapping[str, Any] = util.immutabledict(),
index 6b8161dbb2c30887fefd1d7c30aa319665c5f0b1..3dd49d9108bd9b7b11737fdc2a307be08456d008 100644 (file)
@@ -135,7 +135,10 @@ def _run_hook(
 
 @register("console_scripts")
 def console_scripts(
-    path: str, options: dict, ignore_output: bool = False
+    path: str,
+    options: dict,
+    ignore_output: bool = False,
+    verify_version: tuple[int, ...] | None = None,
 ) -> None:
     entrypoint_name = _get_required_option(options, "entrypoint")
     for entry in compat.importlib_metadata_get("console_scripts"):
@@ -147,11 +150,17 @@ def console_scripts(
             f"Could not find entrypoint console_scripts.{entrypoint_name}"
         )
 
-    command = [
-        sys.executable,
-        "-c",
-        f"import {impl.module}; {impl.module}.{impl.attr}()",
-    ]
+    if verify_version:
+        pyscript = (
+            f"import {impl.module}; "
+            f"assert tuple(int(x) for x in {impl.module}.__version__.split('.')) >= {verify_version}, "  # noqa: E501
+            f"'need exactly version {verify_version} of {impl.name}'; "
+            f"{impl.module}.{impl.attr}()"
+        )
+    else:
+        pyscript = f"import {impl.module}; {impl.module}.{impl.attr}()"
+
+    command = [sys.executable, "-c", pyscript]
     _run_hook(path, options, ignore_output, command)
 
 
index a41fc4232a6e464062459f700f654ee287a37741..9f91a743d050ceb767455e1c9eedd5c00b0a8542 100644 (file)
@@ -376,7 +376,7 @@ class DefaultRequirements(SuiteRequirements):
             requirements, "black and zimports are required for this test"
         )
         version_low = exclusions.only_if(
-            lambda _: compat.py311, "python 3.11 is required"
+            lambda _: compat.py312, "python 3.12 is required"
         )
 
         version_high = exclusions.only_if(
index cd11e2e90d3a5339698c2a69c70c2d10520962d4..ec5a11e32fb79d6ca5e1f27e959cf62e41f61e59 100644 (file)
@@ -6,6 +6,7 @@ from dataclasses import dataclass
 from dataclasses import field
 from pathlib import Path
 import re
+import shutil
 import sys
 from tempfile import NamedTemporaryFile
 import textwrap
@@ -28,6 +29,10 @@ if True:  # avoid flake/zimports messing with the order
     from alembic.operations import ops
     import sqlalchemy as sa
 
+BLACK_VERSION = (25, 9, 0)
+PYTHON_VERSIONS = (3, 12), (3, 14)
+
+
 TRIM_MODULE = [
     "alembic.autogenerate.api.",
     "alembic.operations.base.",
@@ -55,9 +60,13 @@ ADDITIONAL_ENV = {
 def generate_pyi_for_proxy(
     file_info: FileInfo, destination_path: Path, ignore_output: bool
 ):
-    if sys.version_info < (3, 11):
+    lower_python, upper_python = PYTHON_VERSIONS
+    if sys.version_info < lower_python or sys.version_info >= upper_python:
         raise RuntimeError(
-            "This script must be run with Python 3.11 or higher"
+            f"Script supports at least python "
+            f"{".".join(str(x) for x in lower_python)} "
+            f"but less than {".".join(str(x) for x in upper_python)} "
+            "right now."
         )
 
     progname = Path(sys.argv[0]).as_posix()
@@ -132,6 +141,7 @@ def generate_pyi_for_proxy(
         str(destination_path),
         {"entrypoint": "black", "options": "-l79 --target-version py39"},
         ignore_output=ignore_output,
+        verify_version=BLACK_VERSION,
     )
 
 
@@ -176,6 +186,8 @@ def _generate_stub_for_meth(
     def _formatannotation(annotation, base_module=None):
         if getattr(annotation, "__module__", None) == "typing":
             retval = repr(annotation).replace("typing.", "")
+        elif getattr(annotation, "__module__", None) == "types":
+            retval = repr(annotation).replace("types.", "")
         elif isinstance(annotation, type):
             retval = annotation.__qualname__
         elif isinstance(annotation, typing.TypeVar):
@@ -188,6 +200,8 @@ def _generate_stub_for_meth(
         else:
             retval = annotation
 
+        assert isinstance(retval, str)
+
         retval = re.sub(r"TypeEngine\b", "TypeEngine[Any]", retval)
 
         retval = retval.replace("~", "")  # typevar repr as "~T"
@@ -249,9 +263,14 @@ def _generate_stub_for_meth(
 
 def run_file(finfo: FileInfo, stdout: bool):
     if not stdout:
-        generate_pyi_for_proxy(
-            finfo, destination_path=finfo.path, ignore_output=False
-        )
+        with NamedTemporaryFile(delete=False, suffix=finfo.path.suffix) as f:
+            f.close()
+            f_path = Path(f.name)
+            generate_pyi_for_proxy(
+                finfo, destination_path=f_path, ignore_output=False
+            )
+            shutil.move(f_path, finfo.path)
+
     else:
         with NamedTemporaryFile(delete=False, suffix=finfo.path.suffix) as f:
             f.close()