--- /dev/null
+from __future__ import annotations
+
+import pkgutil
+import sys
+import tokenize
+from io import StringIO
+from contextlib import contextmanager
+from dataclasses import dataclass
+from itertools import chain
+from tokenize import TokenInfo
+
+TYPE_CHECKING = False
+
+if TYPE_CHECKING:
+ from typing import Any, Iterable, Iterator, Mapping
+
+
def make_default_module_completer() -> ModuleCompleter:
    """Create the ModuleCompleter that pyrepl uses by default.

    Relative imports are resolved against the '_pyrepl' package.
    """
    # Inside pyrepl, __package__ is set to '_pyrepl'
    pyrepl_namespace = {'__package__': '_pyrepl'}
    return ModuleCompleter(namespace=pyrepl_namespace)
+
+
class ModuleCompleter:
    """A completer for Python import statements.

    Examples:
        - import <tab>
        - import foo<tab>
        - import foo.<tab>
        - import foo as bar, baz<tab>

        - from <tab>
        - from foo<tab>
        - from foo import <tab>
        - from foo import bar<tab>
        - from foo import (bar as baz, qux<tab>
    """

    def __init__(self, namespace: Mapping[str, Any] | None = None) -> None:
        """Create a completer.

        'namespace' is only consulted for its '__package__' entry, which is
        used to resolve relative imports.
        """
        self.namespace = namespace or {}
        # Top-level modules found on sys.path; filled lazily by 'global_cache'.
        self._global_cache: list[pkgutil.ModuleInfo] = []
        # Snapshot of sys.path that '_global_cache' corresponds to.
        self._curr_sys_path: list[str] = sys.path[:]

    def get_completions(self, line: str) -> list[str]:
        """Return the next possible import completions for 'line'."""
        result = ImportParser(line).parse()
        if not result:
            return []
        try:
            return self.complete(*result)
        except Exception:
            # Some unexpected error occurred, make it look like
            # no completions are available
            return []

    def complete(self, from_name: str | None, name: str | None) -> list[str]:
        """Return completions for a parsed import statement.

        'from_name' is the module path after 'from' (None for a plain
        'import'); 'name' is the name after 'import' (None while the 'from'
        module itself is still being typed).  If both are None there is
        nothing to complete and an empty list is returned.
        """
        if from_name is not None and name is not None:
            # from x.y import z<tab>
            return self.find_modules(from_name, name)

        # import x.y.z<tab>  or  from x.y.z<tab>
        dotted_name = name if from_name is None else from_name
        if dotted_name is None:
            return []
        path, prefix = self.get_path_and_prefix(dotted_name)
        return [self.format_completion(path, module)
                for module in self.find_modules(path, prefix)]

    def find_modules(self, path: str, prefix: str) -> list[str]:
        """Find all modules under 'path' that start with 'prefix'.

        Every name is returned once, in discovery order.
        """
        modules = self._find_modules(path, prefix)
        # The same module may be reachable through several sys.path entries;
        # dict.fromkeys() drops the duplicates while keeping the order.
        # Also filter out invalid module names (for example those containing
        # dashes that cannot be imported with 'import').
        return [mod for mod in dict.fromkeys(modules) if mod.isidentifier()]

    def _find_modules(self, path: str, prefix: str) -> list[str]:
        """Unfiltered search used by find_modules()."""
        if not path:
            # Top-level import (e.g. `import foo<tab>` or `from foo<tab>`)
            return [name for _, name, _ in self.global_cache
                    if name.startswith(prefix)]

        if path.startswith('.'):
            # Convert relative path to absolute path.  '__package__' may be
            # missing or None, in which case there is no parent package to
            # resolve against.
            package = self.namespace.get('__package__') or ''
            resolved = self.resolve_relative_name(path, package)
            if not resolved:
                return []
            path = resolved

        # Walk down the package hierarchy one dotted segment at a time.
        modules: Iterable[pkgutil.ModuleInfo] = self.global_cache
        for segment in path.split('.'):
            modules = [mod_info for mod_info in modules
                       if mod_info.ispkg and mod_info.name == segment]
            modules = self.iter_submodules(modules)
        return [module.name for module in modules
                if module.name.startswith(prefix)]

    def iter_submodules(self, parent_modules: list[pkgutil.ModuleInfo]) -> Iterator[pkgutil.ModuleInfo]:
        """Iterate over all submodules of the given parent modules."""
        specs = [info.module_finder.find_spec(info.name, None)
                 for info in parent_modules if info.ispkg]
        # 'submodule_search_locations' can be absent or None; treat both
        # as "no locations" instead of failing while chaining.
        search_locations = set(chain.from_iterable(
            getattr(spec, 'submodule_search_locations', None) or ()
            for spec in specs if spec
        ))
        return pkgutil.iter_modules(search_locations)

    def get_path_and_prefix(self, dotted_name: str) -> tuple[str, str]:
        """
        Split a dotted name into an import path and a
        final prefix that is to be completed.

        Examples:
            'foo.bar' -> 'foo', 'bar'
            'foo.' -> 'foo', ''
            '.foo' -> '.', 'foo'
        """
        # Leading dots (relative import level) always belong to the path.
        stripped = dotted_name.lstrip('.')
        dots = dotted_name[:len(dotted_name) - len(stripped)]
        head, sep, tail = stripped.rpartition('.')
        if not sep:
            return dots, stripped
        return dots + head, tail

    def format_completion(self, path: str, module: str) -> str:
        """Join 'path' and 'module' into the dotted text shown to the user."""
        if path == '' or path.endswith('.'):
            return f'{path}{module}'
        return f'{path}.{module}'

    def resolve_relative_name(self, name: str, package: str) -> str | None:
        """Resolve a relative module name to an absolute name.

        Returns None when 'name' has more leading dots than 'package'
        has levels.

        Example: resolve_relative_name('.foo', 'bar') -> 'bar.foo'
        """
        # taken from importlib._bootstrap
        level = len(name) - len(name.lstrip('.'))
        bits = package.rsplit('.', level - 1)
        if len(bits) < level:
            return None
        base = bits[0]
        name = name[level:]
        return f'{base}.{name}' if name else base

    @property
    def global_cache(self) -> list[pkgutil.ModuleInfo]:
        """Global module cache.

        Rebuilt whenever it is empty or sys.path differs from the snapshot
        taken at the last rebuild.
        """
        if not self._global_cache or self._curr_sys_path != sys.path:
            self._curr_sys_path = sys.path[:]
            self._global_cache = list(pkgutil.iter_modules())
        return self._global_cache
+
+
class ImportParser:
    """
    Parses incomplete import statements that are
    suitable for autocomplete suggestions.

    Examples:
        - import foo          -> Result(from_name=None, name='foo')
        - import foo.         -> Result(from_name=None, name='foo.')
        - from foo            -> Result(from_name='foo', name=None)
        - from foo import bar -> Result(from_name='foo', name='bar')
        - from .foo import (  -> Result(from_name='.foo', name='')

    Note that the parser works in reverse order, starting from the
    last token in the input string. This makes the parser more robust
    when parsing multiple statements.
    """
    # Token types that carry no information for import parsing.
    _ignored_tokens = {
        tokenize.INDENT, tokenize.DEDENT, tokenize.COMMENT,
        tokenize.NL, tokenize.NEWLINE, tokenize.ENDMARKER
    }
    # NAME tokens that can never be (part of) a module name.
    _keywords = {'import', 'from', 'as'}

    def __init__(self, code: str) -> None:
        """Tokenize 'code' and store its significant tokens in reverse order."""
        self.code = code
        tokens = []
        try:
            for t in tokenize.generate_tokens(StringIO(code).readline):
                if t.type not in self._ignored_tokens:
                    tokens.append(t)
        except tokenize.TokenError as e:
            if 'unexpected EOF' not in str(e):
                # unexpected EOF is fine, since we're parsing an
                # incomplete statement, but other errors are not
                # because we may not have all the tokens so it's
                # safer to bail out
                tokens = []
        except SyntaxError:
            tokens = []
        self.tokens = TokenQueue(tokens[::-1])

    def parse(self) -> tuple[str | None, str | None] | None:
        """Return (from_name, name) for the trailing import, or None."""
        if not (res := self._parse()):
            return None
        return res.from_name, res.name

    def _parse(self) -> Result | None:
        """Try 'from ... import ...' first, then plain 'import ...'."""
        for alternative in (self.parse_from_import, self.parse_import):
            # save_state() swallows ParseError and rewinds the queue, so a
            # failed alternative simply falls through to the next one.
            with self.tokens.save_state():
                return alternative()
        return None

    def parse_import(self) -> Result:
        """Parse the tail of an 'import a, b.c as d, e' statement.

        Raises ParseError if the code does not end in such a statement.
        """
        ends_with_space = self.code.endswith(' ')
        # 'import <tab>': besides the textual check, require that the last
        # real token is the 'import' keyword, so that text ending in a
        # comment or in an identifier such as 'foo_import' is rejected.
        if (ends_with_space
                and self.code.rstrip().endswith('import')
                and self.tokens.peek_string('import')):
            return Result(name='')
        if self.tokens.peek_string(','):
            name = ''
        else:
            if ends_with_space:
                raise ParseError('parse_import')
            name = self.parse_dotted_name()
        if name.startswith('.'):
            # relative names are only valid in 'from' imports
            raise ParseError('parse_import')
        while self.tokens.peek_string(','):
            self.tokens.pop()
            self.parse_dotted_as_name()
        if self.tokens.peek_string('import'):
            return Result(name=name)
        raise ParseError('parse_import')

    def parse_from_import(self) -> Result:
        """Parse the tail of a 'from a.b import c, d' statement.

        Raises ParseError if the code does not end in such a statement.
        """
        stripped = self.code.rstrip()
        ends_with_space = self.code.endswith(' ')
        if stripped.endswith('import') and ends_with_space:
            return Result(from_name=self.parse_empty_from_import(), name='')
        # 'from <tab>': as in parse_import(), the last real token must be
        # the 'from' keyword itself, not a comment or a longer identifier.
        if (ends_with_space
                and stripped.endswith('from')
                and self.tokens.peek_string('from')):
            return Result(from_name='')
        if self.tokens.peek_string('(') or self.tokens.peek_string(','):
            return Result(from_name=self.parse_empty_from_import(), name='')
        if ends_with_space:
            raise ParseError('parse_from_import')
        name = self.parse_dotted_name()
        if '.' in name:
            # a dotted name can only be the module after 'from'
            self.tokens.pop_string('from')
            return Result(from_name=name)
        if self.tokens.peek_string('from'):
            return Result(from_name=name)
        from_name = self.parse_empty_from_import()
        return Result(from_name=from_name, name=name)

    def parse_empty_from_import(self) -> str:
        """Consume '[, names] [(] import <module> from' and return the module."""
        if self.tokens.peek_string(','):
            self.tokens.pop()
            self.parse_as_names()
        if self.tokens.peek_string('('):
            self.tokens.pop()
        self.tokens.pop_string('import')
        return self.parse_from()

    def parse_from(self) -> str:
        """Consume '<module> from' and return the module name."""
        from_name = self.parse_dotted_name()
        self.tokens.pop_string('from')
        return from_name

    def parse_dotted_as_name(self) -> str | None:
        """Consume one 'a.b.c' or 'a.b.c as d' item of an import list.

        Returns the remaining dotted part, or None if there is none.
        """
        self.tokens.pop_name()
        if self.tokens.peek_string('as'):
            self.tokens.pop()
        # The name just popped may already be the whole item; in that case
        # parse_dotted_name() fails and save_state() rewinds the queue.
        with self.tokens.save_state():
            return self.parse_dotted_name()
        return None

    def _peek_module_name(self) -> bool:
        """Return True if the next token is a NAME that is not a keyword."""
        tok = self.tokens.peek()
        return (tok is not None
                and tok.type == tokenize.NAME
                and tok.string not in self._keywords)

    def parse_dotted_name(self) -> str:
        """Consume a (possibly relative or trailing-dot) dotted name.

        Raises ParseError if no dot or name is next in the queue.
        """
        # Parts are collected back to front, since the queue is reversed.
        name = []
        if self.tokens.peek_string('.'):
            name.append('.')
            self.tokens.pop()
        if self._peek_module_name():
            name.append(self.tokens.pop_name())
        if not name:
            raise ParseError('parse_dotted_name')
        while self.tokens.peek_string('.'):
            name.append('.')
            self.tokens.pop()
            if self._peek_module_name():
                name.append(self.tokens.pop_name())
            else:
                break

        # remaining leading dots of a relative name
        while self.tokens.peek_string('.'):
            name.append('.')
            self.tokens.pop()
        return ''.join(name[::-1])

    def parse_as_names(self) -> None:
        """Consume a comma separated list of 'name' / 'name as alias' items."""
        self.parse_as_name()
        while self.tokens.peek_string(','):
            self.tokens.pop()
            self.parse_as_name()

    def parse_as_name(self) -> None:
        """Consume a single 'name' or 'name as alias' item."""
        self.tokens.pop_name()
        if self.tokens.peek_string('as'):
            self.tokens.pop()
            self.tokens.pop_name()
+
+
class ParseError(Exception):
    """Raised when the tokens do not form the import construct being parsed."""
+
+
@dataclass(frozen=True)
class Result:
    """Immutable outcome of parsing an incomplete import statement."""

    # Module path written after 'from'; None for a plain 'import' statement.
    from_name: str | None = None
    # Name written after 'import'; None while the 'from' module is being typed.
    name: str | None = None
+
+
class TokenQueue:
    """Provides helper functions for working with a sequence of tokens."""

    def __init__(self, tokens: list[TokenInfo]) -> None:
        self.tokens: list[TokenInfo] = tokens
        # Position of the next token to be consumed.
        self.index: int = 0
        # Positions saved by the currently active save_state() blocks.
        self.stack: list[int] = []

    @contextmanager
    def save_state(self) -> Iterator[None]:
        """Rewind to the current position if the block raises ParseError.

        The ParseError is swallowed.  Any other exception propagates
        unchanged; in every case the saved position is removed from
        'stack' again so that nested blocks stay balanced.
        """
        saved_index = self.index
        self.stack.append(saved_index)
        try:
            yield
        except ParseError:
            self.index = saved_index
        finally:
            self.stack.pop()

    def __bool__(self) -> bool:
        """Return True while there are tokens left to consume."""
        return self.index < len(self.tokens)

    def peek(self) -> TokenInfo | None:
        """Return the next token without consuming it, or None at the end."""
        if not self:
            return None
        return self.tokens[self.index]

    def peek_name(self) -> bool:
        """Return True if the next token is a NAME token."""
        if not (tok := self.peek()):
            return False
        return tok.type == tokenize.NAME

    def pop_name(self) -> str:
        """Consume the next token and return its text.

        Raises ParseError if it is not a NAME token or the queue is empty.
        """
        tok = self.pop()
        if tok.type != tokenize.NAME:
            raise ParseError('pop_name')
        return tok.string

    def peek_string(self, string: str) -> bool:
        """Return True if the text of the next token equals 'string'."""
        if not (tok := self.peek()):
            return False
        return tok.string == string

    def pop_string(self, string: str) -> str:
        """Consume the next token, which must have the text 'string'.

        Raises ParseError on a mismatch or if the queue is empty.
        """
        tok = self.pop()
        if tok.string != string:
            raise ParseError('pop_string')
        return tok.string

    def pop(self) -> TokenInfo:
        """Consume and return the next token.

        Raises ParseError if the queue is empty.
        """
        if not self:
            raise ParseError('pop')
        tok = self.tokens[self.index]
        self.index += 1
        return tok
def get_completions(self, stem: str) -> list[str]:
    """Return the completions for 'stem'; this implementation offers none."""
    no_completions: list[str] = []
    return no_completions
+
def get_line(self) -> str:
    """Return the buffer contents up to the cursor position as a string."""
    before_cursor = self.buffer[:self.pos]
    return ''.join(before_cursor)
from . import commands, historical_reader
from .completing_reader import CompletingReader
from .console import Console as ConsoleType
+from ._module_completer import ModuleCompleter, make_default_module_completer
Console: type[ConsoleType]
_error: tuple[type[Exception], ...] | type[Exception]
class ReadlineConfig:
readline_completer: Completer | None = None
completer_delims: frozenset[str] = frozenset(" \t\n`~!@#$%^&*()-=+[{]}\\|;:'\",<>/?")
-
+ module_completer: ModuleCompleter = field(default_factory=make_default_module_completer)
@dataclass(kw_only=True)
class ReadlineAlikeReader(historical_reader.HistoricalReader, CompletingReader):
return "".join(b[p + 1 : self.pos])
def get_completions(self, stem: str) -> list[str]:
+ if module_completions := self.get_module_completions():
+ return module_completions
if len(stem) == 0 and self.more_lines is not None:
b = self.buffer
p = self.pos
result.sort()
return result
def get_module_completions(self) -> list[str]:
    """Return what the configured module completer offers for get_line()."""
    text_before_cursor = self.get_line()
    completer = self.config.module_completer
    return completer.get_completions(text_before_cursor)
+
def get_trimmed_history(self, maxlength: int) -> list[str]:
if maxlength >= 0:
cut = len(self.history) - maxlength
code_to_events,
)
from _pyrepl.console import Event
+from _pyrepl._module_completer import ImportParser, ModuleCompleter
from _pyrepl.readline import (ReadlineAlikeReader, ReadlineConfig,
_ReadlineWrapper)
from _pyrepl.readline import multiline_input as readline_multiline_input
self.assertEqual(mock_stderr.getvalue(), "")
class TestPyReplModuleCompleter(TestCase):
    """Tests for import statement completion in PyREPL."""

    def setUp(self):
        # Tests replace sys.path with a new list; keep the original object
        # so that tearDown() can put it back.
        self._saved_sys_path = sys.path

    def tearDown(self):
        sys.path = self._saved_sys_path

    def prepare_reader(self, events, namespace):
        """Return a reader fed with 'events', completing from 'namespace'."""
        console = FakeConsole(events)
        config = ReadlineConfig()
        config.readline_completer = rlcompleter.Completer(namespace).complete
        reader = ReadlineAlikeReader(console=console, config=config)
        return reader

    def test_import_completions(self):
        import importlib
        # Make iter_modules() search only the standard library.
        # This makes the test more reliable in case there are
        # other user packages/scripts on PYTHONPATH which can
        # interfere with the completions.
        lib_path = os.path.dirname(importlib.__path__[0])
        sys.path = [lib_path]

        cases = (
            ("import path\t\n", "import pathlib"),
            ("import importlib.\t\tres\t\n", "import importlib.resources"),
            ("import importlib.resources.\t\ta\t\n", "import importlib.resources.abc"),
            ("import foo, impo\t\n", "import foo, importlib"),
            ("import foo as bar, impo\t\n", "import foo as bar, importlib"),
            ("from impo\t\n", "from importlib"),
            ("from importlib.res\t\n", "from importlib.resources"),
            ("from importlib.\t\tres\t\n", "from importlib.resources"),
            ("from importlib.resources.ab\t\n", "from importlib.resources.abc"),
            ("from importlib import mac\t\n", "from importlib import machinery"),
            ("from importlib import res\t\n", "from importlib import resources"),
            ("from importlib.res\t import a\t\n", "from importlib.resources import abc"),
        )
        for code, expected in cases:
            with self.subTest(code=code):
                events = code_to_events(code)
                reader = self.prepare_reader(events, namespace={})
                output = reader.readline()
                self.assertEqual(output, expected)

    def test_relative_import_completions(self):
        cases = (
            ("from .readl\t\n", "from .readline"),
            ("from . import readl\t\n", "from . import readline"),
        )
        for code, expected in cases:
            with self.subTest(code=code):
                events = code_to_events(code)
                reader = self.prepare_reader(events, namespace={})
                output = reader.readline()
                self.assertEqual(output, expected)

    @patch("pkgutil.iter_modules", lambda: [(None, 'valid_name', None),
                                            (None, 'invalid-name', None)])
    def test_invalid_identifiers(self):
        # Make sure modules which are not valid identifiers
        # are not suggested as those cannot be imported via 'import'.
        cases = (
            ("import valid\t\n", "import valid_name"),
            # 'invalid-name' contains a dash and should not be completed
            ("import invalid\t\n", "import invalid"),
        )
        for code, expected in cases:
            with self.subTest(code=code):
                events = code_to_events(code)
                reader = self.prepare_reader(events, namespace={})
                output = reader.readline()
                self.assertEqual(output, expected)

    def test_get_path_and_prefix(self):
        cases = (
            ('', ('', '')),
            ('.', ('.', '')),
            ('..', ('..', '')),
            ('.foo', ('.', 'foo')),
            ('..foo', ('..', 'foo')),
            ('..foo.', ('..foo', '')),
            ('..foo.bar', ('..foo', 'bar')),
            ('.foo.bar.', ('.foo.bar', '')),
            ('..foo.bar.', ('..foo.bar', '')),
            ('foo', ('', 'foo')),
            ('foo.', ('foo', '')),
            ('foo.bar', ('foo', 'bar')),
            ('foo.bar.', ('foo.bar', '')),
            ('foo.bar.baz', ('foo.bar', 'baz')),
        )
        completer = ModuleCompleter()
        for name, expected in cases:
            with self.subTest(name=name):
                self.assertEqual(completer.get_path_and_prefix(name), expected)

    def test_parse(self):
        cases = (
            ('import ', (None, '')),
            ('import foo', (None, 'foo')),
            ('import foo,', (None, '')),
            ('import foo, ', (None, '')),
            ('import foo, bar', (None, 'bar')),
            ('import foo, bar, baz', (None, 'baz')),
            ('import foo as bar,', (None, '')),
            ('import foo as bar, ', (None, '')),
            ('import foo as bar, baz', (None, 'baz')),
            ('import a.', (None, 'a.')),
            ('import a.b', (None, 'a.b')),
            ('import a.b.', (None, 'a.b.')),
            ('import a.b.c', (None, 'a.b.c')),
            ('import a.b.c, foo', (None, 'foo')),
            ('import a.b.c, foo.bar', (None, 'foo.bar')),
            ('import a.b.c, foo.bar,', (None, '')),
            ('import a.b.c, foo.bar, ', (None, '')),
            ('from foo', ('foo', None)),
            ('from a.', ('a.', None)),
            ('from a.b', ('a.b', None)),
            ('from a.b.', ('a.b.', None)),
            ('from a.b.c', ('a.b.c', None)),
            ('from foo import ', ('foo', '')),
            ('from foo import a', ('foo', 'a')),
            ('from ', ('', None)),
            ('from . import a', ('.', 'a')),
            ('from .foo import a', ('.foo', 'a')),
            ('from ..foo import a', ('..foo', 'a')),
            ('from foo import (', ('foo', '')),
            ('from foo import ( ', ('foo', '')),
            ('from foo import (a', ('foo', 'a')),
            ('from foo import (a,', ('foo', '')),
            ('from foo import (a, ', ('foo', '')),
            ('from foo import (a, c', ('foo', 'c')),
            ('from foo import (a as b, c', ('foo', 'c')),
        )
        # The parser should not get tripped up by any other preceding
        # statements, so every case is parsed on its own and again after
        # a complete statement ending in a newline or a semicolon.
        prefixes = ('', 'import xyz\n', 'import xyz;')
        for code, parsed in cases:
            for prefix in prefixes:
                full_code = f'{prefix}{code}'
                with self.subTest(code=full_code):
                    actual = ImportParser(full_code).parse()
                    self.assertEqual(actual, parsed)

    def test_parse_error(self):
        cases = (
            '',
            'import foo ',
            'from foo ',
            'import foo. ',
            'import foo.bar ',
            'from foo ',
            'from foo. ',
            'from foo.bar ',
            'from foo import bar ',
            'from foo import (bar ',
            'from foo import bar, baz ',
            'import foo as',
            'import a. as',
            'import a.b as',
            'import a.b. as',
            'import a.b.c as',
            'import (foo',
            'import (',
            'import .foo',
            'import ..foo',
            'import .foo.bar',
            'import foo; x = 1',
            'import a.; x = 1',
            'import a.b; x = 1',
            'import a.b.; x = 1',
            'import a.b.c; x = 1',
            'from foo import a as',
            'from foo import a. as',
            'from foo import a.b as',
            'from foo import a.b. as',
            'from foo import a.b.c as',
            'from foo impo',
            'import import',
            'import from',
            'import as',
            'from import',
            'from from',
            'from as',
            'from foo import import',
            'from foo import from',
            'from foo import as',
        )
        for code in cases:
            # Parse inside the subTest so that an unexpected exception is
            # reported against the offending input.
            with self.subTest(code=code):
                actual = ImportParser(code).parse()
                self.assertEqual(actual, None)
+
class TestPasteEvent(TestCase):
def prepare_reader(self, events):
console = FakeConsole(events)
--- /dev/null
+Add module autocomplete to PyREPL.