From: Daniele Varrazzo Date: Tue, 8 Aug 2023 22:18:23 +0000 (+0100) Subject: fix(async-to-sync): fold long strings as multiline X-Git-Tag: pool-3.2.0~12^2~58 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=e65fd4ca73165af668fa02ee22a9c905cf80443c;p=thirdparty%2Fpsycopg.git fix(async-to-sync): fold long strings as multiline --- diff --git a/tests/test_connection.py b/tests/test_connection.py index 9516281b8..73ec2fba8 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -236,14 +236,14 @@ def test_commit(conn): @pytest.mark.crdb_skip("deferrable") def test_commit_error(conn): - sql = [ - "drop table if exists selfref;", - "create table selfref (", - "x serial primary key,", - "y int references selfref (x) deferrable initially deferred)", - ] - - conn.execute("".join(sql)) + conn.execute( + """ + drop table if exists selfref; + create table selfref ( + x serial primary key, + y int references selfref (x) deferrable initially deferred) + """ + ) conn.commit() conn.execute("insert into selfref (y) values (-1)") diff --git a/tests/test_connection_async.py b/tests/test_connection_async.py index 98edb2d21..d336c19de 100644 --- a/tests/test_connection_async.py +++ b/tests/test_connection_async.py @@ -234,14 +234,14 @@ async def test_commit(aconn): @pytest.mark.crdb_skip("deferrable") async def test_commit_error(aconn): - sql = [ - "drop table if exists selfref;", - "create table selfref (", - "x serial primary key,", - "y int references selfref (x) deferrable initially deferred)", - ] - - await aconn.execute("".join(sql)) + await aconn.execute( + """ + drop table if exists selfref; + create table selfref ( + x serial primary key, + y int references selfref (x) deferrable initially deferred) + """ + ) await aconn.commit() await aconn.execute("insert into selfref (y) values (-1)") diff --git a/tests/test_cursor.py b/tests/test_cursor.py index 622d9d1f9..5377d0d1e 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -495,12 +495,13 @@ def test_rownumber_none(conn, query): def test_rownumber_mixed(conn): cur = conn.cursor() - queries = [ - "select x from generate_series(1, 3) x", - "set timezone to utc", - "select x from generate_series(4, 6) x", - ] - cur.execute(";\n".join(queries)) + cur.execute( + """ +select x from generate_series(1, 3) x; +set timezone to utc; +select x from generate_series(4, 6) x; +""" + ) assert cur.rownumber == 0 assert cur.fetchone() == (1,) assert cur.rownumber == 1 diff --git a/tests/test_cursor_async.py b/tests/test_cursor_async.py index c0c49e68a..ad3673f45 100644 --- a/tests/test_cursor_async.py +++ b/tests/test_cursor_async.py @@ -500,12 +500,13 @@ async def test_rownumber_none(aconn, query): async def test_rownumber_mixed(aconn): cur = aconn.cursor() - queries = [ - "select x from generate_series(1, 3) x", - "set timezone to utc", - "select x from generate_series(4, 6) x", - ] - await cur.execute(";\n".join(queries)) + await cur.execute( + """ +select x from generate_series(1, 3) x; +set timezone to utc; +select x from generate_series(4, 6) x; +""" + ) assert cur.rownumber == 0 assert await cur.fetchone() == (1,) assert cur.rownumber == 1 diff --git a/tools/async_to_sync.py b/tools/async_to_sync.py index e34c7ae43..02bc04539 100755 --- a/tools/async_to_sync.py +++ b/tools/async_to_sync.py @@ -6,6 +6,7 @@ from __future__ import annotations import os import sys +from typing import Any from argparse import ArgumentParser, Namespace import ast_comments as ast @@ -43,7 +44,7 @@ def tree_to_str(tree: ast.AST, filename: str) -> str: # from the original file '{os.path.basename(filename)}' # DO NOT CHANGE! Change the original file instead. """ - rv += ast.unparse(tree) + rv += unparse(tree) return rv @@ -208,6 +209,27 @@ class BlanksInserter(ast.NodeTransformer): return new_body +def unparse(tree: ast.AST) -> str: + rv: str = Unparser().visit(tree) + return rv + + +class Unparser(ast._Unparser): + """ + Try to emit long strings as multiline. + + The normal class only tries to emit docstrings as multiline, + but the resulting source doesn't pass flake8. + """ + + # Beware: private method. Tested with in Python 3.10. + def _write_constant(self, value: Any) -> None: + if isinstance(value, str) and len(value) > 50: + self._write_str_avoiding_backslashes(value) + else: + super()._write_constant(value) + + def parse_cmdline() -> Namespace: parser = ArgumentParser(description=__doc__) parser.add_argument("filename", metavar="FILE", help="the file to process")