From: Mike Bayer Date: Wed, 18 Jan 2023 17:45:42 +0000 (-0500) Subject: refactor code generation tools , include --check command X-Git-Tag: rel_2_0_0rc3~2^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=cd96ffe287e26651f8dce4f688bf87af1e423f06;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git refactor code generation tools , include --check command in particular it looks like CI was not picking up on the "git diff" oriented commands, which were failing to run due to pathing issues. As we were setting cwd for black/zimports relative to sqlalchemy library, and tox installs it in the venv, black/zimports would fail to run from tox, and since these are subprocess.run we didn't pick up the failure. This overall locks down how zimports/black are run so that we are definitely from the source root, by using the location of tools/ to determine the root. Fixes: #8892 Change-Id: I7c54b747edd5a80e0c699b8456febf66d8b62375 --- diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py index 2b2a949bef..aafe03673f 100644 --- a/lib/sqlalchemy/orm/scoping.py +++ b/lib/sqlalchemy/orm/scoping.py @@ -877,6 +877,7 @@ class scoped_session(Generic[_S]): with_for_update: Optional[ForUpdateArg] = None, identity_token: Optional[Any] = None, execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, ) -> Optional[_O]: r"""Return an instance based on the given primary key identifier, or ``None`` if not found. @@ -975,6 +976,13 @@ class scoped_session(Generic[_S]): :ref:`orm_queryguide_execution_options` - ORM-specific execution options + :param bind_arguments: dictionary of additional arguments to determine + the bind. May include "mapper", "bind", or other custom arguments. + Contents of this dictionary are passed to the + :meth:`.Session.get_bind` method. + + .. versionadded: 2.0.0rc1 + :return: The object instance, or ``None``. @@ -988,15 +996,18 @@ class scoped_session(Generic[_S]): with_for_update=with_for_update, identity_token=identity_token, execution_options=execution_options, + bind_arguments=bind_arguments, ) def get_bind( self, mapper: Optional[_EntityBindKey[_O]] = None, + *, clause: Optional[ClauseElement] = None, bind: Optional[_SessionBind] = None, _sa_skip_events: Optional[bool] = None, _sa_skip_for_implicit_returning: bool = False, + **kw: Any, ) -> Union[Engine, Connection]: r"""Return a "bind" to which this :class:`.Session` is bound. @@ -1082,6 +1093,7 @@ class scoped_session(Generic[_S]): bind=bind, _sa_skip_events=_sa_skip_events, _sa_skip_for_implicit_returning=_sa_skip_for_implicit_returning, + **kw, ) def is_modified( diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 5058b7e785..7671480452 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -19,7 +19,6 @@ import hashlib import inspect import itertools import operator -import os import re import sys import textwrap @@ -34,7 +33,6 @@ from typing import Generic from typing import Iterator from typing import List from typing import Mapping -from typing import no_type_check from typing import NoReturn from typing import Optional from typing import overload @@ -2180,45 +2178,3 @@ def has_compiled_ext(raise_=False): ) else: return False - - -@no_type_check -def console_scripts( - path: str, options: dict, ignore_output: bool = False -) -> None: - - import subprocess - import shlex - from pathlib import Path - - is_posix = os.name == "posix" - - entrypoint_name = options["entrypoint"] - - for entry in compat.importlib_metadata_get("console_scripts"): - if entry.name == entrypoint_name: - impl = entry - break - else: - raise Exception( - f"Could not find entrypoint console_scripts.{entrypoint_name}" - ) - cmdline_options_str = options.get("options", "") - cmdline_options_list = shlex.split(cmdline_options_str, posix=is_posix) + [ - path - ] - - kw = {} - if ignore_output: - kw["stdout"] = kw["stderr"] = subprocess.DEVNULL - - subprocess.run( - [ - sys.executable, - "-c", - "import %s; %s.%s()" % (impl.module, impl.module, impl.attr), - ] - + cmdline_options_list, - cwd=Path(__file__).parent.parent, - **kw, - ) diff --git a/lib/sqlalchemy/util/tool_support.py b/lib/sqlalchemy/util/tool_support.py new file mode 100644 index 0000000000..5a2fc3ba05 --- /dev/null +++ b/lib/sqlalchemy/util/tool_support.py @@ -0,0 +1,198 @@ +# util/tool_support.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: allow-untyped-defs, allow-untyped-calls +"""support routines for the helpers in tools/. + +These aren't imported by the enclosing util package as the are not +needed for normal library use. + +""" +from __future__ import annotations + +from argparse import ArgumentParser +from argparse import Namespace +import contextlib +import difflib +import os +from pathlib import Path +import shlex +import shutil +import subprocess +import sys +from typing import Any +from typing import Dict +from typing import Iterator +from typing import Optional + +from . import compat + + +class code_writer_cmd: + parser: ArgumentParser + args: Namespace + suppress_output: bool + diffs_detected: bool + source_root: Path + pyproject_toml_path: Path + + def __init__(self, tool_script: str): + self.source_root = Path(tool_script).parent.parent + self.pyproject_toml_path = self.source_root / Path("pyproject.toml") + assert self.pyproject_toml_path.exists() + + self.parser = ArgumentParser() + self.parser.add_argument( + "--stdout", + action="store_true", + help="Write to stdout instead of saving to file", + ) + self.parser.add_argument( + "-c", + "--check", + help="Don't write the files back, just return the " + "status. Return code 0 means nothing would change. " + "Return code 1 means some files would be reformatted", + action="store_true", + ) + + def run_zimports(self, tempfile: str) -> None: + self._run_console_script( + str(tempfile), + { + "entrypoint": "zimports", + "options": f"--toml-config {self.pyproject_toml_path}", + }, + ) + + def run_black(self, tempfile: str) -> None: + self._run_console_script( + str(tempfile), + { + "entrypoint": "black", + "options": f"--config {self.pyproject_toml_path}", + }, + ) + + def _run_console_script(self, path: str, options: Dict[str, Any]) -> None: + """Run a Python console application from within the process. + + Used for black, zimports + + """ + + is_posix = os.name == "posix" + + entrypoint_name = options["entrypoint"] + + for entry in compat.importlib_metadata_get("console_scripts"): + if entry.name == entrypoint_name: + impl = entry + break + else: + raise Exception( + f"Could not find entrypoint console_scripts.{entrypoint_name}" + ) + cmdline_options_str = options.get("options", "") + cmdline_options_list = shlex.split( + cmdline_options_str, posix=is_posix + ) + [path] + + kw: Dict[str, Any] = {} + if self.suppress_output: + kw["stdout"] = kw["stderr"] = subprocess.DEVNULL + + subprocess.run( + [ + sys.executable, + "-c", + "import %s; %s.%s()" % (impl.module, impl.module, impl.attr), + ] + + cmdline_options_list, + cwd=str(self.source_root), + **kw, + ) + + def write_status(self, *text: str) -> None: + if not self.suppress_output: + sys.stderr.write(" ".join(text)) + + def write_output_file_from_text( + self, text: str, destination_path: str + ) -> None: + if self.args.check: + self._run_diff(destination_path, source=text) + elif self.args.stdout: + print(text) + else: + self.write_status(f"Writing {destination_path}...") + Path(destination_path).write_text(text) + self.write_status("done\n") + + def write_output_file_from_tempfile( + self, tempfile: str, destination_path: str + ) -> None: + if self.args.check: + self._run_diff(destination_path, source_file=tempfile) + os.unlink(tempfile) + elif self.args.stdout: + with open(tempfile) as tf: + print(tf.read()) + os.unlink(tempfile) + else: + self.write_status(f"Writing {destination_path}...") + shutil.move(tempfile, destination_path) + self.write_status("done\n") + + def _run_diff( + self, + destination_path: str, + *, + source: Optional[str] = None, + source_file: Optional[str] = None, + ) -> None: + if source_file: + with open(source_file) as tf: + source_lines = list(tf) + elif source is not None: + source_lines = source.splitlines(keepends=True) + else: + assert False, "source or source_file is required" + + with open(destination_path) as dp: + d = difflib.unified_diff( + list(dp), + source_lines, + fromfile=destination_path, + tofile="", + n=3, + lineterm="\n", + ) + d_as_list = list(d) + if d_as_list: + self.diffs_detected = True + print("".join(d_as_list)) + + @contextlib.contextmanager + def add_arguments(self) -> Iterator[ArgumentParser]: + yield self.parser + + @contextlib.contextmanager + def run_program(self) -> Iterator[None]: + self.args = self.parser.parse_args() + if self.args.check: + self.diffs_detected = False + self.suppress_output = True + elif self.args.stdout: + self.suppress_output = True + else: + self.suppress_output = False + yield + + if self.args.check and self.diffs_detected: + sys.exit(1) + else: + sys.exit(0) diff --git a/test/orm/test_scoping.py b/test/orm/test_scoping.py index 22e1178aa7..8c6ddfa0e5 100644 --- a/test/orm/test_scoping.py +++ b/test/orm/test_scoping.py @@ -160,6 +160,7 @@ class ScopedSessionTest(fixtures.MappedTest): with_for_update=None, identity_token=None, execution_options=util.EMPTY_DICT, + bind_arguments=None, ), ], ) diff --git a/tools/format_docs_code.py b/tools/format_docs_code.py index 3b11c24a81..fcb844291c 100644 --- a/tools/format_docs_code.py +++ b/tools/format_docs_code.py @@ -4,7 +4,10 @@ this script parses the documentation files and runs black on the code blocks that it extracts from the documentation. .. versionadded:: 2.0 + """ +# mypy: ignore-errors + from argparse import ArgumentParser from argparse import RawDescriptionHelpFormatter from collections.abc import Iterator diff --git a/tools/generate_proxy_methods.py b/tools/generate_proxy_methods.py index c21db9d601..cc039d4d68 100644 --- a/tools/generate_proxy_methods.py +++ b/tools/generate_proxy_methods.py @@ -40,16 +40,16 @@ typed by hand. .. versionadded:: 2.0 """ +# mypy: ignore-errors + from __future__ import annotations -from argparse import ArgumentParser import collections import importlib import inspect import os from pathlib import Path import re -import shutil import sys from tempfile import NamedTemporaryFile import textwrap @@ -65,9 +65,9 @@ from typing import TypeVar from sqlalchemy import util from sqlalchemy.util import compat from sqlalchemy.util import langhelpers -from sqlalchemy.util.langhelpers import console_scripts from sqlalchemy.util.langhelpers import format_argspec_plus from sqlalchemy.util.langhelpers import inject_docstring_text +from sqlalchemy.util.tool_support import code_writer_cmd is_posix = os.name == "posix" @@ -340,7 +340,7 @@ def process_class( instrument(buf, prop, clslevel=True) -def process_module(modname: str, filename: str) -> str: +def process_module(modname: str, filename: str, cmd: code_writer_cmd) -> str: class_entries = classes[modname] @@ -348,7 +348,9 @@ def process_module(modname: str, filename: str) -> str: # current working directory, so that black / zimports use # local pyproject.toml with NamedTemporaryFile( - mode="w", delete=False, suffix=".py", dir=Path(filename).parent + mode="w", + delete=False, + suffix=".py", ) as buf, open(filename) as orig_py: in_block = False @@ -358,7 +360,7 @@ def process_module(modname: str, filename: str) -> str: if m: current_clsname = m.group(1) args = class_entries[current_clsname] - sys.stderr.write( + cmd.write_status( f"Generating attributes for class {current_clsname}\n" ) in_block = True @@ -379,39 +381,21 @@ def process_module(modname: str, filename: str) -> str: return buf.name -def run_module(modname, stdout): +def run_module(modname: str, cmd: code_writer_cmd) -> None: - sys.stderr.write(f"importing module {modname}\n") + cmd.write_status(f"importing module {modname}\n") mod = importlib.import_module(modname) - filename = destination_path = mod.__file__ - assert filename is not None - - tempfile = process_module(modname, filename) - - ignore_output = stdout - - console_scripts( - str(tempfile), - {"entrypoint": "zimports"}, - ignore_output=ignore_output, - ) + destination_path = mod.__file__ + assert destination_path is not None - console_scripts( - str(tempfile), - {"entrypoint": "black"}, - ignore_output=ignore_output, - ) + tempfile = process_module(modname, destination_path, cmd) - if stdout: - with open(tempfile) as tf: - print(tf.read()) - os.unlink(tempfile) - else: - sys.stderr.write(f"Writing {destination_path}...\n") - shutil.move(tempfile, destination_path) + cmd.run_zimports(tempfile) + cmd.run_black(tempfile) + cmd.write_output_file_from_tempfile(tempfile, destination_path) -def main(args): +def main(cmd: code_writer_cmd) -> None: from sqlalchemy import util from sqlalchemy.util import langhelpers @@ -420,8 +404,8 @@ def main(args): ) = create_proxy_methods for entry in entries: - if args.module in {"all", entry}: - run_module(entry, args.stdout) + if cmd.args.module in {"all", entry}: + run_module(entry, cmd) entries = [ @@ -432,17 +416,16 @@ entries = [ ] if __name__ == "__main__": - parser = ArgumentParser() - parser.add_argument( - "--module", - choices=entries + ["all"], - default="all", - help="Which file to generate. Default is to regenerate all files", - ) - parser.add_argument( - "--stdout", - action="store_true", - help="Write to stdout instead of saving to file", - ) - args = parser.parse_args() - main(args) + + cmd = code_writer_cmd(__file__) + + with cmd.add_arguments() as parser: + parser.add_argument( + "--module", + choices=entries + ["all"], + default="all", + help="Which file to generate. Default is to regenerate all files", + ) + + with cmd.run_program(): + main(cmd) diff --git a/tools/generate_tuple_map_overloads.py b/tools/generate_tuple_map_overloads.py index ff8f37840a..d455773440 100644 --- a/tools/generate_tuple_map_overloads.py +++ b/tools/generate_tuple_map_overloads.py @@ -16,19 +16,19 @@ combinatoric generated code approach. .. versionadded:: 2.0 """ +# mypy: ignore-errors + from __future__ import annotations -from argparse import ArgumentParser import importlib import os from pathlib import Path import re -import shutil import sys from tempfile import NamedTemporaryFile import textwrap -from sqlalchemy.util.langhelpers import console_scripts +from sqlalchemy.util.tool_support import code_writer_cmd is_posix = os.name == "posix" @@ -36,13 +36,15 @@ is_posix = os.name == "posix" sys.path.append(str(Path(__file__).parent.parent)) -def process_module(modname: str, filename: str) -> str: +def process_module(modname: str, filename: str, cmd: code_writer_cmd) -> str: # use tempfile in same path as the module, or at least in the # current working directory, so that black / zimports use # local pyproject.toml with NamedTemporaryFile( - mode="w", delete=False, suffix=".py", dir=Path(filename).parent + mode="w", + delete=False, + suffix=".py", ) as buf, open(filename) as orig_py: indent = "" in_block = False @@ -64,7 +66,7 @@ def process_module(modname: str, filename: str) -> str: start_index = int(m.group(4)) end_index = int(m.group(5)) - sys.stderr.write( + cmd.write_status( f"Generating {start_index}-{end_index} overloads " f"attributes for " f"class {'self.' if use_self else ''}{current_fnname} " @@ -111,42 +113,24 @@ def {current_fnname}( return buf.name -def run_module(modname, stdout): +def run_module(modname: str, cmd: code_writer_cmd) -> None: - sys.stderr.write(f"importing module {modname}\n") + cmd.write_status(f"importing module {modname}\n") mod = importlib.import_module(modname) - filename = destination_path = mod.__file__ - assert filename is not None - - tempfile = process_module(modname, filename) - - ignore_output = stdout - - console_scripts( - str(tempfile), - {"entrypoint": "zimports"}, - ignore_output=ignore_output, - ) + destination_path = mod.__file__ + assert destination_path is not None - console_scripts( - str(tempfile), - {"entrypoint": "black"}, - ignore_output=ignore_output, - ) + tempfile = process_module(modname, destination_path, cmd) - if stdout: - with open(tempfile) as tf: - print(tf.read()) - os.unlink(tempfile) - else: - sys.stderr.write(f"Writing {destination_path}...\n") - shutil.move(tempfile, destination_path) + cmd.run_zimports(tempfile) + cmd.run_black(tempfile) + cmd.write_output_file_from_tempfile(tempfile, destination_path) -def main(args): +def main(cmd: code_writer_cmd) -> None: for modname in entries: - if args.module in {"all", modname}: - run_module(modname, args.stdout) + if cmd.args.module in {"all", modname}: + run_module(modname, cmd) entries = [ @@ -158,17 +142,16 @@ entries = [ ] if __name__ == "__main__": - parser = ArgumentParser() - parser.add_argument( - "--module", - choices=entries + ["all"], - default="all", - help="Which file to generate. Default is to regenerate all files", - ) - parser.add_argument( - "--stdout", - action="store_true", - help="Write to stdout instead of saving to file", - ) - args = parser.parse_args() - main(args) + + cmd = code_writer_cmd(__file__) + + with cmd.add_arguments() as parser: + parser.add_argument( + "--module", + choices=entries + ["all"], + default="all", + help="Which file to generate. Default is to regenerate all files", + ) + + with cmd.run_program(): + main(cmd) diff --git a/tools/sync_test_files.py b/tools/sync_test_files.py index 4ef15374a6..4afa2dc8e7 100644 --- a/tools/sync_test_files.py +++ b/tools/sync_test_files.py @@ -5,8 +5,11 @@ from __future__ import annotations -from argparse import ArgumentParser from pathlib import Path +from typing import Any +from typing import Iterable + +from sqlalchemy.util.tool_support import code_writer_cmd header = '''\ """This file is automatically generated from the file @@ -22,27 +25,27 @@ from __future__ import annotations home = Path(__file__).parent.parent this_file = Path(__file__).relative_to(home).as_posix() -remove_str = '# anno only: ' +remove_str = "# anno only: " -def run_operation(name: str, source: str, dest: str): - print("Running", name, "...", end="", flush=True) - source_data = Path(source).read_text().replace(remove_str, '') - dest_data = header.format(source=source, this_file=this_file) + source_data +def run_operation( + name: str, source: str, dest: str, cmd: code_writer_cmd +) -> None: - Path(dest).write_text(dest_data) + source_data = Path(source).read_text().replace(remove_str, "") + dest_data = header.format(source=source, this_file=this_file) + source_data - print(".. done") + cmd.write_output_file_from_text(dest_data, dest) -def main(file: str): +def main(file: str, cmd: code_writer_cmd) -> None: if file == "all": - operations = files.items() + operations: Iterable[Any] = files.items() else: operations = [(file, files[file])] for name, info in operations: - run_operation(name, info["source"], info["dest"]) + run_operation(name, info["source"], info["dest"], cmd) files = { @@ -53,8 +56,11 @@ files = { } if __name__ == "__main__": - parser = ArgumentParser() - parser.add_argument("--file", choices=list(files) + ["all"], default="all") - - args = parser.parse_args() - main(args.file) + cmd = code_writer_cmd(__file__) + with cmd.add_arguments() as parser: + parser.add_argument( + "--file", choices=list(files) + ["all"], default="all" + ) + + with cmd.run_program(): + main(cmd.args.file, cmd) diff --git a/tools/trace_orm_adapter.py b/tools/trace_orm_adapter.py index 42a23c9f72..de8098bcb8 100644 --- a/tools/trace_orm_adapter.py +++ b/tools/trace_orm_adapter.py @@ -23,6 +23,8 @@ You can then set a breakpoint at the end of any adapt step: """ # noqa: E501 +# mypy: ignore-errors + from __future__ import annotations diff --git a/tox.ini b/tox.ini index 0260dad0dc..144de79a79 100644 --- a/tox.ini +++ b/tox.ini @@ -184,7 +184,7 @@ setenv= [testenv:lint] basepython = python3 deps= - flake8 + flake8==5.0.0 #flake8-import-order git+https://github.com/sqlalchemyorg/flake8-import-order@fix_options flake8-builtins @@ -195,7 +195,7 @@ deps= # in case it requires a version pin pydocstyle pygments - black==22.3.0 + black==22.8.0 slotscheck>=0.12,<0.13 # this is to satisfy the mypy plugin dependency @@ -214,9 +214,9 @@ commands = slotscheck -m sqlalchemy env DISABLE_SQLALCHEMY_CEXT_RUNTIME=1 slotscheck -m sqlalchemy python ./tools/format_docs_code.py --check - sh -c 'python ./tools/generate_tuple_map_overloads.py && git diff --exit-code' - sh -c 'python ./tools/generate_proxy_methods.py && git diff --exit-code' - sh -c 'python ./tools/sync_test_files.py && git diff --exit-code' + python ./tools/generate_tuple_map_overloads.py --check + python ./tools/generate_proxy_methods.py --check + python ./tools/sync_test_files.py --check # "pep8" env was renamed to "lint".