]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
bpo-45017: move opcode-related logic from modulefinder to dis (GH-28246)
authorIrit Katriel <1055913+iritkatriel@users.noreply.github.com>
Thu, 9 Sep 2021 13:04:12 +0000 (14:04 +0100)
committerGitHub <noreply@github.com>
Thu, 9 Sep 2021 13:04:12 +0000 (14:04 +0100)
Lib/dis.py
Lib/modulefinder.py
Lib/test/test_dis.py

index 66487dce0eefc0032b1ae27cb5d4cc949f722acd..a073572e59e66deeae488ac13aa73b87b9fcec03 100644 (file)
@@ -535,6 +535,42 @@ def findlinestarts(code):
             yield start, line
     return
 
+def _find_imports(co):
+    """Find import statements in the code
+
+    Generate triplets (name, level, fromlist) where
+    name is the imported module and level, fromlist are
+    the corresponding args to __import__.
+    """
+    IMPORT_NAME = opmap['IMPORT_NAME']
+    LOAD_CONST = opmap['LOAD_CONST']
+
+    consts = co.co_consts
+    names = co.co_names
+    opargs = [(op, arg) for _, op, arg in _unpack_opargs(co.co_code)
+                  if op != EXTENDED_ARG]
+    for i, (op, oparg) in enumerate(opargs):
+        if (op == IMPORT_NAME and i >= 2
+                and opargs[i-1][0] == opargs[i-2][0] == LOAD_CONST):
+            level = consts[opargs[i-2][1]]
+            fromlist = consts[opargs[i-1][1]]
+            yield (names[oparg], level, fromlist)
+
+def _find_store_names(co):
+    """Find names of variables which are written in the code
+
+    Generate sequence of strings
+    """
+    STORE_OPS = {
+        opmap['STORE_NAME'],
+        opmap['STORE_GLOBAL']
+    }
+
+    names = co.co_names
+    for _, op, arg in _unpack_opargs(co.co_code):
+        if op in STORE_OPS:
+            yield names[arg]
+
 
 class Bytecode:
     """The bytecode operations of a piece of code
index cb455f40c4d7894ef73ab25bed6659e917565394..a0a020f9eeb9b415b7ca39566c779e13b0051928 100644 (file)
@@ -8,14 +8,6 @@ import os
 import io
 import sys
 
-
-LOAD_CONST = dis.opmap['LOAD_CONST']
-IMPORT_NAME = dis.opmap['IMPORT_NAME']
-STORE_NAME = dis.opmap['STORE_NAME']
-STORE_GLOBAL = dis.opmap['STORE_GLOBAL']
-STORE_OPS = STORE_NAME, STORE_GLOBAL
-EXTENDED_ARG = dis.EXTENDED_ARG
-
 # Old imp constants:
 
 _SEARCH_ERROR = 0
@@ -394,24 +386,13 @@ class ModuleFinder:
 
     def scan_opcodes(self, co):
         # Scan the code, and yield 'interesting' opcode combinations
-        code = co.co_code
-        names = co.co_names
-        consts = co.co_consts
-        opargs = [(op, arg) for _, op, arg in dis._unpack_opargs(code)
-                  if op != EXTENDED_ARG]
-        for i, (op, oparg) in enumerate(opargs):
-            if op in STORE_OPS:
-                yield "store", (names[oparg],)
-                continue
-            if (op == IMPORT_NAME and i >= 2
-                    and opargs[i-1][0] == opargs[i-2][0] == LOAD_CONST):
-                level = consts[opargs[i-2][1]]
-                fromlist = consts[opargs[i-1][1]]
-                if level == 0: # absolute import
-                    yield "absolute_import", (fromlist, names[oparg])
-                else: # relative import
-                    yield "relative_import", (level, fromlist, names[oparg])
-                continue
+        for name in dis._find_store_names(co):
+            yield "store", (name,)
+        for name, level, fromlist in dis._find_imports(co):
+            if level == 0:  # absolute import
+                yield "absolute_import", (fromlist, name)
+            else:  # relative import
+                yield "relative_import", (level, fromlist, name)
 
     def scan_code(self, co, m):
         code = co.co_code
index b97e41cdfab5ec66e9e82c7adb858c028ce59532..a140a89f0e7e8ddd9bd1e46e88bb11125f22e538 100644 (file)
@@ -1326,5 +1326,38 @@ class TestBytecodeTestCase(BytecodeTestCase):
         with self.assertRaises(AssertionError):
             self.assertNotInBytecode(code, "LOAD_CONST", 1)
 
+class TestFinderMethods(unittest.TestCase):
+    def test__find_imports(self):
+        cases = [
+            ("import a.b.c", ('a.b.c', 0, None)),
+            ("from a.b import c", ('a.b', 0, ('c',))),
+            ("from a.b import c as d", ('a.b', 0, ('c',))),
+            ("from a.b import *", ('a.b', 0, ('*',))),
+            ("from ...a.b import c as d", ('a.b', 3, ('c',))),
+            ("from ..a.b import c as d, e as f", ('a.b', 2, ('c', 'e'))),
+            ("from ..a.b import *", ('a.b', 2, ('*',))),
+        ]
+        for src, expected in cases:
+            with self.subTest(src=src):
+                code = compile(src, "<string>", "exec")
+                res = tuple(dis._find_imports(code))
+                self.assertEqual(len(res), 1)
+                self.assertEqual(res[0], expected)
+
+    def test__find_store_names(self):
+        cases = [
+            ("x+y", ()),
+            ("x=y=1", ('x', 'y')),
+            ("x+=y", ('x',)),
+            ("global x\nx=y=1", ('x', 'y')),
+            ("global x\nz=x", ('z',)),
+        ]
+        for src, expected in cases:
+            with self.subTest(src=src):
+                code = compile(src, "<string>", "exec")
+                res = tuple(dis._find_store_names(code))
+                self.assertEqual(res, expected)
+
+
 if __name__ == "__main__":
     unittest.main()