]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-111201: auto-indentation in _pyrepl (#119348)
authorArnon Yaari <wiggin15@yahoo.com>
Wed, 22 May 2024 04:21:14 +0000 (00:21 -0400)
committerGitHub <noreply@github.com>
Wed, 22 May 2024 04:21:14 +0000 (06:21 +0200)
Co-authored-by: Ɓukasz Langa <lukasz@langa.pl>
Lib/_pyrepl/readline.py
Lib/test/test_pyrepl/test_pyrepl.py

index 787dbc06e58e18661d45cf16995642e8003dfd01..054a39b7442655ec34b73f2d37232f7dd28f54a4 100644 (file)
@@ -99,6 +99,7 @@ class ReadlineAlikeReader(historical_reader.HistoricalReader, CompletingReader):
     # Instance fields
     config: ReadlineConfig
     more_lines: MoreLinesCallable | None = None
+    last_used_indentation: str | None = None
 
     def __post_init__(self) -> None:
         super().__post_init__()
@@ -157,6 +158,11 @@ class ReadlineAlikeReader(historical_reader.HistoricalReader, CompletingReader):
             cut = 0
         return self.history[cut:]
 
+    def update_last_used_indentation(self) -> None:
+        indentation = _get_first_indentation(self.buffer)
+        if indentation is not None:
+            self.last_used_indentation = indentation
+
     # --- simplified support for reading multiline Python statements ---
 
     def collect_keymap(self) -> tuple[tuple[KeySpec, CommandName], ...]:
@@ -211,6 +217,28 @@ def _get_previous_line_indent(buffer: list[str], pos: int) -> tuple[int, int | N
     return prevlinestart, indent
 
 
+def _get_first_indentation(buffer: list[str]) -> str | None:
+    indented_line_start = None
+    for i in range(len(buffer)):
+        if (i < len(buffer) - 1
+            and buffer[i] == "\n"
+            and buffer[i + 1] in " \t"
+        ):
+            indented_line_start = i + 1
+        elif indented_line_start is not None and buffer[i] not in " \t\n":
+            return ''.join(buffer[indented_line_start : i])
+    return None
+
+
+def _is_last_char_colon(buffer: list[str]) -> bool:
+    i = len(buffer)
+    while i > 0:
+        i -= 1
+        if buffer[i] not in " \t\n":  # ignore whitespaces
+            return buffer[i] == ":"
+    return False
+
+
 class maybe_accept(commands.Command):
     def do(self) -> None:
         r: ReadlineAlikeReader
@@ -227,9 +255,18 @@ class maybe_accept(commands.Command):
             # auto-indent the next line like the previous line
             prevlinestart, indent = _get_previous_line_indent(r.buffer, r.pos)
             r.insert("\n")
-            if not self.reader.paste_mode and indent:
-                for i in range(prevlinestart, prevlinestart + indent):
-                    r.insert(r.buffer[i])
+            if not self.reader.paste_mode:
+                if indent:
+                    for i in range(prevlinestart, prevlinestart + indent):
+                        r.insert(r.buffer[i])
+                r.update_last_used_indentation()
+                if _is_last_char_colon(r.buffer):
+                    if r.last_used_indentation is not None:
+                        indentation = r.last_used_indentation
+                    else:
+                        # default
+                        indentation = " " * 4
+                    r.insert(indentation)
         elif not self.reader.paste_mode:
             self.finish = True
         else:
index b643ae5895c97e831f99f801f24f1752c7a5951a..930f6759fb0b482143bddd97b3ae234d51d83143 100644 (file)
@@ -5,19 +5,31 @@ import rlcompleter
 from unittest import TestCase
 from unittest.mock import patch
 
-from .support import FakeConsole, handle_all_events, handle_events_narrow_console
-from .support import more_lines, multiline_input, code_to_events
+from .support import (
+    FakeConsole,
+    handle_all_events,
+    handle_events_narrow_console,
+    more_lines,
+    multiline_input,
+    code_to_events,
+)
 from _pyrepl.console import Event
 from _pyrepl.readline import ReadlineAlikeReader, ReadlineConfig
 from _pyrepl.readline import multiline_input as readline_multiline_input
 
 
 class TestCursorPosition(TestCase):
+    def prepare_reader(self, events):
+        console = FakeConsole(events)
+        config = ReadlineConfig(readline_completer=None)
+        reader = ReadlineAlikeReader(console=console, config=config)
+        return reader
+
     def test_up_arrow_simple(self):
         # fmt: off
         code = (
-            'def f():\n'
-            '  ...\n'
+            "def f():\n"
+            "  ...\n"
         )
         # fmt: on
         events = itertools.chain(
@@ -34,8 +46,8 @@ class TestCursorPosition(TestCase):
     def test_down_arrow_end_of_input(self):
         # fmt: off
         code = (
-            'def f():\n'
-            '  ...\n'
+            "def f():\n"
+            "  ...\n"
         )
         # fmt: on
         events = itertools.chain(
@@ -300,6 +312,79 @@ class TestCursorPosition(TestCase):
         self.assertEqual(reader.pos, 10)
         self.assertEqual(reader.cxy, (1, 1))
 
+    def test_auto_indent_default(self):
+        # fmt: off
+        input_code = (
+            "def f():\n"
+                "pass\n\n"
+        )
+
+        output_code = (
+            "def f():\n"
+            "    pass\n"
+            "    "
+        )
+        # fmt: on
+
+    def test_auto_indent_continuation(self):
+        # auto indenting according to previous user indentation
+        # fmt: off
+        events = itertools.chain(
+            code_to_events("def f():\n"),
+            # add backspace to delete default auto-indent
+            [
+                Event(evt="key", data="backspace", raw=bytearray(b"\x7f")),
+            ],
+            code_to_events(
+                "  pass\n"
+                  "pass\n\n"
+            ),
+        )
+
+        output_code = (
+            "def f():\n"
+            "  pass\n"
+            "  pass\n"
+            "  "
+        )
+        # fmt: on
+
+        reader = self.prepare_reader(events)
+        output = multiline_input(reader)
+        self.assertEqual(output, output_code)
+
+    def test_auto_indent_prev_block(self):
+        # auto indenting according to indentation in different block
+        # fmt: off
+        events = itertools.chain(
+            code_to_events("def f():\n"),
+            # add backspace to delete default auto-indent
+            [
+                Event(evt="key", data="backspace", raw=bytearray(b"\x7f")),
+            ],
+            code_to_events(
+                "  pass\n"
+                "pass\n\n"
+            ),
+            code_to_events(
+                "def g():\n"
+                  "pass\n\n"
+            ),
+        )
+
+
+        output_code = (
+            "def g():\n"
+            "  pass\n"
+            "  "
+        )
+        # fmt: on
+
+        reader = self.prepare_reader(events)
+        output1 = multiline_input(reader)
+        output2 = multiline_input(reader)
+        self.assertEqual(output2, output_code)
+
 
 class TestPyReplOutput(TestCase):
     def prepare_reader(self, events):
@@ -316,14 +401,12 @@ class TestPyReplOutput(TestCase):
 
     def test_multiline_edit(self):
         events = itertools.chain(
-            code_to_events("def f():\n  ...\n\n"),
+            code_to_events("def f():\n...\n\n"),
             [
                 Event(evt="key", data="up", raw=bytearray(b"\x1bOA")),
                 Event(evt="key", data="up", raw=bytearray(b"\x1bOA")),
                 Event(evt="key", data="up", raw=bytearray(b"\x1bOA")),
                 Event(evt="key", data="right", raw=bytearray(b"\x1bOC")),
-                Event(evt="key", data="right", raw=bytearray(b"\x1bOC")),
-                Event(evt="key", data="right", raw=bytearray(b"\x1bOC")),
                 Event(evt="key", data="backspace", raw=bytearray(b"\x7f")),
                 Event(evt="key", data="g", raw=bytearray(b"g")),
                 Event(evt="key", data="down", raw=bytearray(b"\x1bOB")),
@@ -334,9 +417,9 @@ class TestPyReplOutput(TestCase):
         reader = self.prepare_reader(events)
 
         output = multiline_input(reader)
-        self.assertEqual(output, "def f():\n  ...\n  ")
+        self.assertEqual(output, "def f():\n    ...\n    ")
         output = multiline_input(reader)
-        self.assertEqual(output, "def g():\n  ...\n  ")
+        self.assertEqual(output, "def g():\n    ...\n    ")
 
     def test_history_navigation_with_up_arrow(self):
         events = itertools.chain(
@@ -485,6 +568,7 @@ class TestPyReplCompleter(TestCase):
             @property
             def test_func(self):
                 import warnings
+
                 warnings.warn("warnings\n")
                 return None
 
@@ -508,12 +592,12 @@ class TestPasteEvent(TestCase):
     def test_paste(self):
         # fmt: off
         code = (
-            'def a():\n'
-            '  for x in range(10):\n'
-            '    if x%2:\n'
-            '      print(x)\n'
-            '    else:\n'
-            '      pass\n'
+            "def a():\n"
+            "  for x in range(10):\n"
+            "    if x%2:\n"
+            "      print(x)\n"
+            "    else:\n"
+            "      pass\n"
         )
         # fmt: on
 
@@ -534,10 +618,10 @@ class TestPasteEvent(TestCase):
     def test_paste_mid_newlines(self):
         # fmt: off
         code = (
-            'def f():\n'
-            '  x = y\n'
-            '  \n'
-            '  y = z\n'
+            "def f():\n"
+            "  x = y\n"
+            "  \n"
+            "  y = z\n"
         )
         # fmt: on
 
@@ -558,16 +642,16 @@ class TestPasteEvent(TestCase):
     def test_paste_mid_newlines_not_in_paste_mode(self):
         # fmt: off
         code = (
-            'def f():\n'
-            '  x = y\n'
-            '  \n'
-            '  y = z\n\n'
+            "def f():\n"
+                "x = y\n"
+                "\n"
+                "y = z\n\n"
         )
 
         expected = (
-            'def f():\n'
-            '  x = y\n'
-            '    '
+            "def f():\n"
+            "    x = y\n"
+            "    "
         )
         # fmt: on
 
@@ -579,20 +663,20 @@ class TestPasteEvent(TestCase):
     def test_paste_not_in_paste_mode(self):
         # fmt: off
         input_code = (
-            'def a():\n'
-            '  for x in range(10):\n'
-            '    if x%2:\n'
-            '      print(x)\n'
-            '    else:\n'
-            '      pass\n\n'
+            "def a():\n"
+                "for x in range(10):\n"
+                    "if x%2:\n"
+                        "print(x)\n"
+                    "else:\n"
+                        "pass\n\n"
         )
 
         output_code = (
-            'def a():\n'
-            '  for x in range(10):\n'
-            '      if x%2:\n'
-            '            print(x)\n'
-            '                else:'
+            "def a():\n"
+            "    for x in range(10):\n"
+            "        if x%2:\n"
+            "            print(x)\n"
+            "            else:"
         )
         # fmt: on
 
@@ -605,25 +689,25 @@ class TestPasteEvent(TestCase):
         """Test that bracketed paste using \x1b[200~ and \x1b[201~ works."""
         # fmt: off
         input_code = (
-            'def a():\n'
-            '  for x in range(10):\n'
-            '\n'
-            '    if x%2:\n'
-            '      print(x)\n'
-            '\n'
-            '    else:\n'
-            '      pass\n'
+            "def a():\n"
+            "  for x in range(10):\n"
+            "\n"
+            "    if x%2:\n"
+            "      print(x)\n"
+            "\n"
+            "    else:\n"
+            "      pass\n"
         )
 
         output_code = (
-            'def a():\n'
-            '  for x in range(10):\n'
-            '\n'
-            '    if x%2:\n'
-            '      print(x)\n'
-            '\n'
-            '    else:\n'
-            '      pass\n'
+            "def a():\n"
+            "  for x in range(10):\n"
+            "\n"
+            "    if x%2:\n"
+            "      print(x)\n"
+            "\n"
+            "    else:\n"
+            "      pass\n"
         )
         # fmt: on