git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
refactor code generation tools, include --check command
author     Mike Bayer <mike_mp@zzzcomputing.com>
           Wed, 18 Jan 2023 17:45:42 +0000 (12:45 -0500)
committer  Mike Bayer <mike_mp@zzzcomputing.com>
           Wed, 18 Jan 2023 20:07:55 +0000 (15:07 -0500)
In particular, it looks like CI was not picking up on the
"git diff" oriented commands, which were failing to run due
to pathing issues.  Because we were setting cwd for
black/zimports relative to the sqlalchemy library, and tox
installs it in the venv, black/zimports would fail to run
from tox; and since these are invoked via subprocess.run,
we didn't pick up the failure.
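
For context, a minimal sketch (not part of this commit) of why the
failures were silent: subprocess.run only reports a nonzero exit
status through the returned CompletedProcess unless check=True is
passed, so a formatter that fails to start or exits nonzero does not
raise anything on its own.

    # Illustration only: subprocess.run swallows a failing exit status
    # unless explicitly asked to check it.
    import subprocess
    import sys

    # A command that exits nonzero, standing in for a black/zimports
    # invocation that cannot run from the tox venv's working directory.
    failing = [sys.executable, "-c", "import sys; sys.exit(1)"]

    result = subprocess.run(failing)
    print(result.returncode)  # prints 1, but no exception was raised

    try:
        subprocess.run(failing, check=True)  # raises on nonzero exit
    except subprocess.CalledProcessError as err:
        print(f"failure surfaced: {err}")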

This overall locks down how zimports/black are run so that
they definitely run from the source root, determined from
the location of tools/.
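
As a rough sketch (assuming the tools/ scripts sit one directory
below the checkout root, as in this tree; find_source_root is an
illustrative name, not from the commit), the root resolution amounts
to:

    # Sketch of the root-resolution idea in util/tool_support.py: walk
    # up from the tool script to the checkout root, verify it via
    # pyproject.toml, then run the formatters with cwd set to that root.
    from pathlib import Path

    def find_source_root(tool_script: str) -> Path:
        # tools/<script>.py -> tools/ -> repository root
        root = Path(tool_script).resolve().parent.parent
        assert (root / "pyproject.toml").exists(), "not a source checkout"
        return root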

Fixes: #8892
Change-Id: I7c54b747edd5a80e0c699b8456febf66d8b62375

lib/sqlalchemy/orm/scoping.py
lib/sqlalchemy/util/langhelpers.py
lib/sqlalchemy/util/tool_support.py [new file with mode: 0644]
test/orm/test_scoping.py
tools/format_docs_code.py
tools/generate_proxy_methods.py
tools/generate_tuple_map_overloads.py
tools/sync_test_files.py
tools/trace_orm_adapter.py
tox.ini

index 2b2a949befba0f291fae3de62fba2ac6ab750066..aafe03673f81da9e7e8cc19f4871ea46282c41a1 100644 (file)
@@ -877,6 +877,7 @@ class scoped_session(Generic[_S]):
         with_for_update: Optional[ForUpdateArg] = None,
         identity_token: Optional[Any] = None,
         execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
     ) -> Optional[_O]:
         r"""Return an instance based on the given primary key identifier,
         or ``None`` if not found.
@@ -975,6 +976,13 @@ class scoped_session(Generic[_S]):
             :ref:`orm_queryguide_execution_options` - ORM-specific execution
             options
 
+        :param bind_arguments: dictionary of additional arguments to determine
+         the bind.  May include "mapper", "bind", or other custom arguments.
+         Contents of this dictionary are passed to the
+         :meth:`.Session.get_bind` method.
+
+         .. versionadded:: 2.0.0rc1
+
         :return: The object instance, or ``None``.
 
 
@@ -988,15 +996,18 @@ class scoped_session(Generic[_S]):
             with_for_update=with_for_update,
             identity_token=identity_token,
             execution_options=execution_options,
+            bind_arguments=bind_arguments,
         )
 
     def get_bind(
         self,
         mapper: Optional[_EntityBindKey[_O]] = None,
+        *,
         clause: Optional[ClauseElement] = None,
         bind: Optional[_SessionBind] = None,
         _sa_skip_events: Optional[bool] = None,
         _sa_skip_for_implicit_returning: bool = False,
+        **kw: Any,
     ) -> Union[Engine, Connection]:
         r"""Return a "bind" to which this :class:`.Session` is bound.
 
@@ -1082,6 +1093,7 @@ class scoped_session(Generic[_S]):
             bind=bind,
             _sa_skip_events=_sa_skip_events,
             _sa_skip_for_implicit_returning=_sa_skip_for_implicit_returning,
+            **kw,
         )
 
     def is_modified(
index 5058b7e7852639aeefce1806bcc38745ec40d701..7671480452052267e1f80ac79ee93f28b5bc3983 100644 (file)
@@ -19,7 +19,6 @@ import hashlib
 import inspect
 import itertools
 import operator
-import os
 import re
 import sys
 import textwrap
@@ -34,7 +33,6 @@ from typing import Generic
 from typing import Iterator
 from typing import List
 from typing import Mapping
-from typing import no_type_check
 from typing import NoReturn
 from typing import Optional
 from typing import overload
@@ -2180,45 +2178,3 @@ def has_compiled_ext(raise_=False):
         )
     else:
         return False
-
-
-@no_type_check
-def console_scripts(
-    path: str, options: dict, ignore_output: bool = False
-) -> None:
-
-    import subprocess
-    import shlex
-    from pathlib import Path
-
-    is_posix = os.name == "posix"
-
-    entrypoint_name = options["entrypoint"]
-
-    for entry in compat.importlib_metadata_get("console_scripts"):
-        if entry.name == entrypoint_name:
-            impl = entry
-            break
-    else:
-        raise Exception(
-            f"Could not find entrypoint console_scripts.{entrypoint_name}"
-        )
-    cmdline_options_str = options.get("options", "")
-    cmdline_options_list = shlex.split(cmdline_options_str, posix=is_posix) + [
-        path
-    ]
-
-    kw = {}
-    if ignore_output:
-        kw["stdout"] = kw["stderr"] = subprocess.DEVNULL
-
-    subprocess.run(
-        [
-            sys.executable,
-            "-c",
-            "import %s; %s.%s()" % (impl.module, impl.module, impl.attr),
-        ]
-        + cmdline_options_list,
-        cwd=Path(__file__).parent.parent,
-        **kw,
-    )
diff --git a/lib/sqlalchemy/util/tool_support.py b/lib/sqlalchemy/util/tool_support.py
new file mode 100644 (file)
index 0000000..5a2fc3b
--- /dev/null
@@ -0,0 +1,198 @@
+# util/tool_support.py
+# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: allow-untyped-defs, allow-untyped-calls
+"""support routines for the helpers in tools/.
+
+These aren't imported by the enclosing util package as they are not
+needed for normal library use.
+
+"""
+from __future__ import annotations
+
+from argparse import ArgumentParser
+from argparse import Namespace
+import contextlib
+import difflib
+import os
+from pathlib import Path
+import shlex
+import shutil
+import subprocess
+import sys
+from typing import Any
+from typing import Dict
+from typing import Iterator
+from typing import Optional
+
+from . import compat
+
+
+class code_writer_cmd:
+    parser: ArgumentParser
+    args: Namespace
+    suppress_output: bool
+    diffs_detected: bool
+    source_root: Path
+    pyproject_toml_path: Path
+
+    def __init__(self, tool_script: str):
+        self.source_root = Path(tool_script).parent.parent
+        self.pyproject_toml_path = self.source_root / Path("pyproject.toml")
+        assert self.pyproject_toml_path.exists()
+
+        self.parser = ArgumentParser()
+        self.parser.add_argument(
+            "--stdout",
+            action="store_true",
+            help="Write to stdout instead of saving to file",
+        )
+        self.parser.add_argument(
+            "-c",
+            "--check",
+            help="Don't write the files back, just return the "
+            "status. Return code 0 means nothing would change. "
+            "Return code 1 means some files would be reformatted",
+            action="store_true",
+        )
+
+    def run_zimports(self, tempfile: str) -> None:
+        self._run_console_script(
+            str(tempfile),
+            {
+                "entrypoint": "zimports",
+                "options": f"--toml-config {self.pyproject_toml_path}",
+            },
+        )
+
+    def run_black(self, tempfile: str) -> None:
+        self._run_console_script(
+            str(tempfile),
+            {
+                "entrypoint": "black",
+                "options": f"--config {self.pyproject_toml_path}",
+            },
+        )
+
+    def _run_console_script(self, path: str, options: Dict[str, Any]) -> None:
+        """Run a Python console application from within the process.
+
+        Used for black, zimports
+
+        """
+
+        is_posix = os.name == "posix"
+
+        entrypoint_name = options["entrypoint"]
+
+        for entry in compat.importlib_metadata_get("console_scripts"):
+            if entry.name == entrypoint_name:
+                impl = entry
+                break
+        else:
+            raise Exception(
+                f"Could not find entrypoint console_scripts.{entrypoint_name}"
+            )
+        cmdline_options_str = options.get("options", "")
+        cmdline_options_list = shlex.split(
+            cmdline_options_str, posix=is_posix
+        ) + [path]
+
+        kw: Dict[str, Any] = {}
+        if self.suppress_output:
+            kw["stdout"] = kw["stderr"] = subprocess.DEVNULL
+
+        subprocess.run(
+            [
+                sys.executable,
+                "-c",
+                "import %s; %s.%s()" % (impl.module, impl.module, impl.attr),
+            ]
+            + cmdline_options_list,
+            cwd=str(self.source_root),
+            **kw,
+        )
+
+    def write_status(self, *text: str) -> None:
+        if not self.suppress_output:
+            sys.stderr.write(" ".join(text))
+
+    def write_output_file_from_text(
+        self, text: str, destination_path: str
+    ) -> None:
+        if self.args.check:
+            self._run_diff(destination_path, source=text)
+        elif self.args.stdout:
+            print(text)
+        else:
+            self.write_status(f"Writing {destination_path}...")
+            Path(destination_path).write_text(text)
+            self.write_status("done\n")
+
+    def write_output_file_from_tempfile(
+        self, tempfile: str, destination_path: str
+    ) -> None:
+        if self.args.check:
+            self._run_diff(destination_path, source_file=tempfile)
+            os.unlink(tempfile)
+        elif self.args.stdout:
+            with open(tempfile) as tf:
+                print(tf.read())
+            os.unlink(tempfile)
+        else:
+            self.write_status(f"Writing {destination_path}...")
+            shutil.move(tempfile, destination_path)
+            self.write_status("done\n")
+
+    def _run_diff(
+        self,
+        destination_path: str,
+        *,
+        source: Optional[str] = None,
+        source_file: Optional[str] = None,
+    ) -> None:
+        if source_file:
+            with open(source_file) as tf:
+                source_lines = list(tf)
+        elif source is not None:
+            source_lines = source.splitlines(keepends=True)
+        else:
+            assert False, "source or source_file is required"
+
+        with open(destination_path) as dp:
+            d = difflib.unified_diff(
+                list(dp),
+                source_lines,
+                fromfile=destination_path,
+                tofile="<proposed changes>",
+                n=3,
+                lineterm="\n",
+            )
+            d_as_list = list(d)
+            if d_as_list:
+                self.diffs_detected = True
+                print("".join(d_as_list))
+
+    @contextlib.contextmanager
+    def add_arguments(self) -> Iterator[ArgumentParser]:
+        yield self.parser
+
+    @contextlib.contextmanager
+    def run_program(self) -> Iterator[None]:
+        self.args = self.parser.parse_args()
+        if self.args.check:
+            self.diffs_detected = False
+            self.suppress_output = True
+        elif self.args.stdout:
+            self.suppress_output = True
+        else:
+            self.suppress_output = False
+        yield
+
+        if self.args.check and self.diffs_detected:
+            sys.exit(1)
+        else:
+            sys.exit(0)
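
The generator scripts changed below all wire this class up the same
way; a condensed sketch of that wiring (argument choices abbreviated
from the actual scripts, and assuming a source checkout where this
new module is importable):

    # Condensed from the tools/ scripts below: build the command object
    # from the script's own path, register extra arguments, then run the
    # program; in --check mode run_program() exits 1 if diffs were found.
    from sqlalchemy.util.tool_support import code_writer_cmd

    if __name__ == "__main__":
        cmd = code_writer_cmd(__file__)

        with cmd.add_arguments() as parser:
            parser.add_argument("--module", default="all")

        with cmd.run_program():
            # a script would generate into a tempfile here, then:
            # cmd.run_zimports(tempfile)
            # cmd.run_black(tempfile)
            # cmd.write_output_file_from_tempfile(tempfile, destination)
            pass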
index 22e1178aa75d1412097a00472dc7586ec06a034e..8c6ddfa0e58ce5bb8158c8f0d758ad19ef5ef878 100644 (file)
@@ -160,6 +160,7 @@ class ScopedSessionTest(fixtures.MappedTest):
                     with_for_update=None,
                     identity_token=None,
                     execution_options=util.EMPTY_DICT,
+                    bind_arguments=None,
                 ),
             ],
         )
index 3b11c24a81110d9daffb25ad6a95f0a3e3cebdfa..fcb844291cf331ed09194cf92b35d2acd82d1ea4 100644 (file)
@@ -4,7 +4,10 @@ this script parses the documentation files and runs black on the code blocks
 that it extracts from the documentation.
 
 .. versionadded:: 2.0
+
 """
+# mypy: ignore-errors
+
 from argparse import ArgumentParser
 from argparse import RawDescriptionHelpFormatter
 from collections.abc import Iterator
index c21db9d60121f9be2ac5361e1cb87979dd054a16..cc039d4d68597aad82a7ebbb2d6c8d2d40c9644a 100644 (file)
@@ -40,16 +40,16 @@ typed by hand.
 .. versionadded:: 2.0
 
 """
+# mypy: ignore-errors
+
 from __future__ import annotations
 
-from argparse import ArgumentParser
 import collections
 import importlib
 import inspect
 import os
 from pathlib import Path
 import re
-import shutil
 import sys
 from tempfile import NamedTemporaryFile
 import textwrap
@@ -65,9 +65,9 @@ from typing import TypeVar
 from sqlalchemy import util
 from sqlalchemy.util import compat
 from sqlalchemy.util import langhelpers
-from sqlalchemy.util.langhelpers import console_scripts
 from sqlalchemy.util.langhelpers import format_argspec_plus
 from sqlalchemy.util.langhelpers import inject_docstring_text
+from sqlalchemy.util.tool_support import code_writer_cmd
 
 is_posix = os.name == "posix"
 
@@ -340,7 +340,7 @@ def process_class(
         instrument(buf, prop, clslevel=True)
 
 
-def process_module(modname: str, filename: str) -> str:
+def process_module(modname: str, filename: str, cmd: code_writer_cmd) -> str:
 
     class_entries = classes[modname]
 
@@ -348,7 +348,9 @@ def process_module(modname: str, filename: str) -> str:
     # current working directory, so that black / zimports use
     # local pyproject.toml
     with NamedTemporaryFile(
-        mode="w", delete=False, suffix=".py", dir=Path(filename).parent
+        mode="w",
+        delete=False,
+        suffix=".py",
     ) as buf, open(filename) as orig_py:
 
         in_block = False
@@ -358,7 +360,7 @@ def process_module(modname: str, filename: str) -> str:
             if m:
                 current_clsname = m.group(1)
                 args = class_entries[current_clsname]
-                sys.stderr.write(
+                cmd.write_status(
                     f"Generating attributes for class {current_clsname}\n"
                 )
                 in_block = True
@@ -379,39 +381,21 @@ def process_module(modname: str, filename: str) -> str:
     return buf.name
 
 
-def run_module(modname, stdout):
+def run_module(modname: str, cmd: code_writer_cmd) -> None:
 
-    sys.stderr.write(f"importing module {modname}\n")
+    cmd.write_status(f"importing module {modname}\n")
     mod = importlib.import_module(modname)
-    filename = destination_path = mod.__file__
-    assert filename is not None
-
-    tempfile = process_module(modname, filename)
-
-    ignore_output = stdout
-
-    console_scripts(
-        str(tempfile),
-        {"entrypoint": "zimports"},
-        ignore_output=ignore_output,
-    )
+    destination_path = mod.__file__
+    assert destination_path is not None
 
-    console_scripts(
-        str(tempfile),
-        {"entrypoint": "black"},
-        ignore_output=ignore_output,
-    )
+    tempfile = process_module(modname, destination_path, cmd)
 
-    if stdout:
-        with open(tempfile) as tf:
-            print(tf.read())
-        os.unlink(tempfile)
-    else:
-        sys.stderr.write(f"Writing {destination_path}...\n")
-        shutil.move(tempfile, destination_path)
+    cmd.run_zimports(tempfile)
+    cmd.run_black(tempfile)
+    cmd.write_output_file_from_tempfile(tempfile, destination_path)
 
 
-def main(args):
+def main(cmd: code_writer_cmd) -> None:
     from sqlalchemy import util
     from sqlalchemy.util import langhelpers
 
@@ -420,8 +404,8 @@ def main(args):
     ) = create_proxy_methods
 
     for entry in entries:
-        if args.module in {"all", entry}:
-            run_module(entry, args.stdout)
+        if cmd.args.module in {"all", entry}:
+            run_module(entry, cmd)
 
 
 entries = [
@@ -432,17 +416,16 @@ entries = [
 ]
 
 if __name__ == "__main__":
-    parser = ArgumentParser()
-    parser.add_argument(
-        "--module",
-        choices=entries + ["all"],
-        default="all",
-        help="Which file to generate. Default is to regenerate all files",
-    )
-    parser.add_argument(
-        "--stdout",
-        action="store_true",
-        help="Write to stdout instead of saving to file",
-    )
-    args = parser.parse_args()
-    main(args)
+
+    cmd = code_writer_cmd(__file__)
+
+    with cmd.add_arguments() as parser:
+        parser.add_argument(
+            "--module",
+            choices=entries + ["all"],
+            default="all",
+            help="Which file to generate. Default is to regenerate all files",
+        )
+
+    with cmd.run_program():
+        main(cmd)
index ff8f37840a5c1754de8bb7e836b07dcde53ea89d..d4557734407950e8996bc11c1b42560ac3eb619f 100644 (file)
@@ -16,19 +16,19 @@ combinatoric generated code approach.
 .. versionadded:: 2.0
 
 """
+# mypy: ignore-errors
+
 from __future__ import annotations
 
-from argparse import ArgumentParser
 import importlib
 import os
 from pathlib import Path
 import re
-import shutil
 import sys
 from tempfile import NamedTemporaryFile
 import textwrap
 
-from sqlalchemy.util.langhelpers import console_scripts
+from sqlalchemy.util.tool_support import code_writer_cmd
 
 is_posix = os.name == "posix"
 
@@ -36,13 +36,15 @@ is_posix = os.name == "posix"
 sys.path.append(str(Path(__file__).parent.parent))
 
 
-def process_module(modname: str, filename: str) -> str:
+def process_module(modname: str, filename: str, cmd: code_writer_cmd) -> str:
 
     # use tempfile in same path as the module, or at least in the
     # current working directory, so that black / zimports use
     # local pyproject.toml
     with NamedTemporaryFile(
-        mode="w", delete=False, suffix=".py", dir=Path(filename).parent
+        mode="w",
+        delete=False,
+        suffix=".py",
     ) as buf, open(filename) as orig_py:
         indent = ""
         in_block = False
@@ -64,7 +66,7 @@ def process_module(modname: str, filename: str) -> str:
                 start_index = int(m.group(4))
                 end_index = int(m.group(5))
 
-                sys.stderr.write(
+                cmd.write_status(
                     f"Generating {start_index}-{end_index} overloads "
                     f"attributes for "
                     f"class {'self.' if use_self else ''}{current_fnname} "
@@ -111,42 +113,24 @@ def {current_fnname}(
     return buf.name
 
 
-def run_module(modname, stdout):
+def run_module(modname: str, cmd: code_writer_cmd) -> None:
 
-    sys.stderr.write(f"importing module {modname}\n")
+    cmd.write_status(f"importing module {modname}\n")
     mod = importlib.import_module(modname)
-    filename = destination_path = mod.__file__
-    assert filename is not None
-
-    tempfile = process_module(modname, filename)
-
-    ignore_output = stdout
-
-    console_scripts(
-        str(tempfile),
-        {"entrypoint": "zimports"},
-        ignore_output=ignore_output,
-    )
+    destination_path = mod.__file__
+    assert destination_path is not None
 
-    console_scripts(
-        str(tempfile),
-        {"entrypoint": "black"},
-        ignore_output=ignore_output,
-    )
+    tempfile = process_module(modname, destination_path, cmd)
 
-    if stdout:
-        with open(tempfile) as tf:
-            print(tf.read())
-        os.unlink(tempfile)
-    else:
-        sys.stderr.write(f"Writing {destination_path}...\n")
-        shutil.move(tempfile, destination_path)
+    cmd.run_zimports(tempfile)
+    cmd.run_black(tempfile)
+    cmd.write_output_file_from_tempfile(tempfile, destination_path)
 
 
-def main(args):
+def main(cmd: code_writer_cmd) -> None:
     for modname in entries:
-        if args.module in {"all", modname}:
-            run_module(modname, args.stdout)
+        if cmd.args.module in {"all", modname}:
+            run_module(modname, cmd)
 
 
 entries = [
@@ -158,17 +142,16 @@ entries = [
 ]
 
 if __name__ == "__main__":
-    parser = ArgumentParser()
-    parser.add_argument(
-        "--module",
-        choices=entries + ["all"],
-        default="all",
-        help="Which file to generate. Default is to regenerate all files",
-    )
-    parser.add_argument(
-        "--stdout",
-        action="store_true",
-        help="Write to stdout instead of saving to file",
-    )
-    args = parser.parse_args()
-    main(args)
+
+    cmd = code_writer_cmd(__file__)
+
+    with cmd.add_arguments() as parser:
+        parser.add_argument(
+            "--module",
+            choices=entries + ["all"],
+            default="all",
+            help="Which file to generate. Default is to regenerate all files",
+        )
+
+    with cmd.run_program():
+        main(cmd)
index 4ef15374a61d886830bb49b86c9ac7851fc817f8..4afa2dc8e74a4891491e25335972b3738245aeea 100644 (file)
@@ -5,8 +5,11 @@
 
 from __future__ import annotations
 
-from argparse import ArgumentParser
 from pathlib import Path
+from typing import Any
+from typing import Iterable
+
+from sqlalchemy.util.tool_support import code_writer_cmd
 
 header = '''\
 """This file is automatically generated from the file
@@ -22,27 +25,27 @@ from __future__ import annotations
 
 home = Path(__file__).parent.parent
 this_file = Path(__file__).relative_to(home).as_posix()
-remove_str = '# anno only: '
+remove_str = "# anno only: "
 
-def run_operation(name: str, source: str, dest: str):
-    print("Running", name, "...", end="", flush=True)
 
-    source_data = Path(source).read_text().replace(remove_str, '')
-    dest_data = header.format(source=source, this_file=this_file) + source_data
+def run_operation(
+    name: str, source: str, dest: str, cmd: code_writer_cmd
+) -> None:
 
-    Path(dest).write_text(dest_data)
+    source_data = Path(source).read_text().replace(remove_str, "")
+    dest_data = header.format(source=source, this_file=this_file) + source_data
 
-    print(".. done")
+    cmd.write_output_file_from_text(dest_data, dest)
 
 
-def main(file: str):
+def main(file: str, cmd: code_writer_cmd) -> None:
     if file == "all":
-        operations = files.items()
+        operations: Iterable[Any] = files.items()
     else:
         operations = [(file, files[file])]
 
     for name, info in operations:
-        run_operation(name, info["source"], info["dest"])
+        run_operation(name, info["source"], info["dest"], cmd)
 
 
 files = {
@@ -53,8 +56,11 @@ files = {
 }
 
 if __name__ == "__main__":
-    parser = ArgumentParser()
-    parser.add_argument("--file", choices=list(files) + ["all"], default="all")
-
-    args = parser.parse_args()
-    main(args.file)
+    cmd = code_writer_cmd(__file__)
+    with cmd.add_arguments() as parser:
+        parser.add_argument(
+            "--file", choices=list(files) + ["all"], default="all"
+        )
+
+    with cmd.run_program():
+        main(cmd.args.file, cmd)
index 42a23c9f72bb4974d722d7050c7fedd97a51041c..de8098bcb8f67f027cfb6132efe7bc1dd7f29c10 100644 (file)
@@ -23,6 +23,8 @@ You can then set a breakpoint at the end of any adapt step:
 
 
 """  # noqa: E501
+# mypy: ignore-errors
+
 
 from __future__ import annotations
 
diff --git a/tox.ini b/tox.ini
index 0260dad0dc733c6839a57b05da8ac27f861afb87..144de79a7995ce12d893c36211cd4faa595677c7 100644 (file)
--- a/tox.ini
+++ b/tox.ini
@@ -184,7 +184,7 @@ setenv=
 [testenv:lint]
 basepython = python3
 deps=
-      flake8
+      flake8==5.0.0
       #flake8-import-order
       git+https://github.com/sqlalchemyorg/flake8-import-order@fix_options
       flake8-builtins
@@ -195,7 +195,7 @@ deps=
       # in case it requires a version pin
       pydocstyle
       pygments
-      black==22.3.0
+      black==22.8.0
       slotscheck>=0.12,<0.13
 
       # this is to satisfy the mypy plugin dependency
@@ -214,9 +214,9 @@ commands =
      slotscheck -m sqlalchemy
      env DISABLE_SQLALCHEMY_CEXT_RUNTIME=1 slotscheck -m sqlalchemy
      python ./tools/format_docs_code.py --check
-     sh -c 'python ./tools/generate_tuple_map_overloads.py && git diff --exit-code'
-     sh -c 'python ./tools/generate_proxy_methods.py && git diff --exit-code'
-     sh -c 'python ./tools/sync_test_files.py && git diff --exit-code'
+     python ./tools/generate_tuple_map_overloads.py --check
+     python ./tools/generate_proxy_methods.py --check
+     python ./tools/sync_test_files.py --check
 
 
 # "pep8" env was renamed to "lint".