]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
Merged revisions 68363 via svnmerge from
authorAntoine Pitrou <solipsis@pitrou.net>
Tue, 6 Jan 2009 19:08:18 +0000 (19:08 +0000)
committerAntoine Pitrou <solipsis@pitrou.net>
Tue, 6 Jan 2009 19:08:18 +0000 (19:08 +0000)
svn+ssh://pythondev@svn.python.org/python/branches/py3k

................
  r68363 | antoine.pitrou | 2009-01-06 20:02:24 +0100 (mar., 06 janv. 2009) | 17 lines

  Merged revisions 68360-68361 via svnmerge from
  svn+ssh://pythondev@svn.python.org/python/trunk

  ........
    r68360 | antoine.pitrou | 2009-01-06 19:10:47 +0100 (mar., 06 janv. 2009) | 7 lines

    Issue #1180193: When importing a module from a .pyc (or .pyo) file with
    an existing .py counterpart, override the co_filename attributes of all
    code objects if the original filename is obsolete (which can happen if the
    file has been renamed, moved, or if it is accessed through different paths).
    Patch by Ziga Seilnacht and Jean-Paul Calderone.
  ........
    r68361 | antoine.pitrou | 2009-01-06 19:34:08 +0100 (mar., 06 janv. 2009) | 3 lines

    Use shutil.rmtree rather than os.rmdir.
  ........
................

Lib/test/test_import.py
Misc/NEWS
Python/import.c

index 6598d4ee5e73ddab61ae8097e14303fef8fbc5de..145ff9a79cafa78ba373026cea1a814ea9afad3f 100644 (file)
@@ -6,6 +6,7 @@ import sys
 import py_compile
 import warnings
 import imp
+import marshal
 from test.support import unlink, TESTFN, unload, run_unittest
 
 
@@ -230,6 +231,98 @@ class ImportTest(unittest.TestCase):
         else:
             self.fail("import by path didn't raise an exception")
 
+class TestPycRewriting(unittest.TestCase):
+    # Test that the `co_filename` attribute on code objects always points
+    # to the right file, even when various things happen (e.g. both the .py
+    # and the .pyc file are renamed).
+
+    module_name = "unlikely_module_name"
+    module_source = """
+import sys
+code_filename = sys._getframe().f_code.co_filename
+module_filename = __file__
+constant = 1
+def func():
+    pass
+func_filename = func.__code__.co_filename
+"""
+    dir_name = os.path.abspath(TESTFN)
+    file_name = os.path.join(dir_name, module_name) + os.extsep + "py"
+    compiled_name = file_name + ("c" if __debug__ else "o")
+
+    def setUp(self):
+        self.sys_path = sys.path[:]
+        self.orig_module = sys.modules.pop(self.module_name, None)
+        os.mkdir(self.dir_name)
+        with open(self.file_name, "w") as f:
+            f.write(self.module_source)
+        sys.path.insert(0, self.dir_name)
+
+    def tearDown(self):
+        sys.path[:] = self.sys_path
+        if self.orig_module is not None:
+            sys.modules[self.module_name] = self.orig_module
+        else:
+            del sys.modules[self.module_name]
+        for file_name in self.file_name, self.compiled_name:
+            if os.path.exists(file_name):
+                os.remove(file_name)
+        if os.path.exists(self.dir_name):
+            shutil.rmtree(self.dir_name)
+
+    def import_module(self):
+        ns = globals()
+        __import__(self.module_name, ns, ns)
+        return sys.modules[self.module_name]
+
+    def test_basics(self):
+        mod = self.import_module()
+        self.assertEqual(mod.module_filename, self.file_name)
+        self.assertEqual(mod.code_filename, self.file_name)
+        self.assertEqual(mod.func_filename, self.file_name)
+        del sys.modules[self.module_name]
+        mod = self.import_module()
+        self.assertEqual(mod.module_filename, self.file_name)
+        self.assertEqual(mod.code_filename, self.file_name)
+        self.assertEqual(mod.func_filename, self.file_name)
+
+    def test_incorrect_code_name(self):
+        py_compile.compile(self.file_name, dfile="another_module.py")
+        mod = self.import_module()
+        self.assertEqual(mod.module_filename, self.file_name)
+        self.assertEqual(mod.code_filename, self.file_name)
+        self.assertEqual(mod.func_filename, self.file_name)
+
+    def test_module_without_source(self):
+        target = "another_module.py"
+        py_compile.compile(self.file_name, dfile=target)
+        os.remove(self.file_name)
+        mod = self.import_module()
+        self.assertEqual(mod.module_filename, self.compiled_name)
+        self.assertEqual(mod.code_filename, target)
+        self.assertEqual(mod.func_filename, target)
+
+    def test_foreign_code(self):
+        py_compile.compile(self.file_name)
+        with open(self.compiled_name, "rb") as f:
+            header = f.read(8)
+            code = marshal.load(f)
+        constants = list(code.co_consts)
+        foreign_code = test_main.__code__
+        pos = constants.index(1)
+        constants[pos] = foreign_code
+        code = type(code)(code.co_argcount, code.co_kwonlyargcount,
+                          code.co_nlocals, code.co_stacksize,
+                          code.co_flags, code.co_code, tuple(constants),
+                          code.co_names, code.co_varnames, code.co_filename,
+                          code.co_name, code.co_firstlineno, code.co_lnotab,
+                          code.co_freevars, code.co_cellvars)
+        with open(self.compiled_name, "wb") as f:
+            f.write(header)
+            marshal.dump(code, f)
+        mod = self.import_module()
+        self.assertEqual(mod.constant.co_filename, foreign_code.co_filename)
+
 class PathsTests(unittest.TestCase):
     SAMPLES = ('test', 'test\u00e4\u00f6\u00fc\u00df', 'test\u00e9\u00e8',
                'test\u00b0\u00b3\u00b2')
@@ -288,7 +381,7 @@ class RelativeImport(unittest.TestCase):
         self.assertRaises(ValueError, check_relative)
 
 def test_main(verbose=None):
-    run_unittest(ImportTest, PathsTests, RelativeImport)
+    run_unittest(ImportTest, TestPycRewriting, PathsTests, RelativeImport)
 
 if __name__ == '__main__':
     # test needs to be a package, so we can do relative import
index 15e69c568d327a7400be15303542d48c02ba7923..e9dfcbaf3817a8b9da7b761bf1e7ab1269088c73 100644 (file)
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -12,6 +12,12 @@ What's New in Python 3.0.1?
 Core and Builtins
 -----------------
 
+- Issue #1180193: When importing a module from a .pyc (or .pyo) file with
+  an existing .py counterpart, override the co_filename attributes of all
+  code objects if the original filename is obsolete (which can happen if the
+  file has been renamed, moved, or if it is accessed through different paths).
+  Patch by Ziga Seilnacht and Jean-Paul Calderone.
+
 - Issue #4580: Fix slicing of memoryviews when the item size is greater than
   one byte. Also fixes the meaning of len() so that it returns the number of
   items, rather than the size in bytes.
index 2bad2e524821fa40ed4ba844b63596be9bafb5fe..80eb04bb5146e53547818fa4da49fe9e6674c93e 100644 (file)
@@ -959,6 +959,49 @@ write_compiled_module(PyCodeObject *co, char *cpathname, struct stat *srcstat)
                PySys_WriteStderr("# wrote %s\n", cpathname);
 }
 
+static void
+update_code_filenames(PyCodeObject *co, PyObject *oldname, PyObject *newname)
+{
+       PyObject *constants, *tmp;
+       Py_ssize_t i, n;
+
+       if (PyUnicode_Compare(co->co_filename, oldname))
+               return;
+
+       tmp = co->co_filename;
+       co->co_filename = newname;
+       Py_INCREF(co->co_filename);
+       Py_DECREF(tmp);
+
+       constants = co->co_consts;
+       n = PyTuple_GET_SIZE(constants);
+       for (i = 0; i < n; i++) {
+               tmp = PyTuple_GET_ITEM(constants, i);
+               if (PyCode_Check(tmp))
+                       update_code_filenames((PyCodeObject *)tmp,
+                                             oldname, newname);
+       }
+}
+
+static int
+update_compiled_module(PyCodeObject *co, char *pathname)
+{
+       PyObject *oldname, *newname;
+
+       if (!PyUnicode_CompareWithASCIIString(co->co_filename, pathname))
+               return 0;
+
+       newname = PyUnicode_FromString(pathname);
+       if (newname == NULL)
+               return -1;
+
+       oldname = co->co_filename;
+       Py_INCREF(oldname);
+       update_code_filenames(co, oldname, newname);
+       Py_DECREF(oldname);
+       Py_DECREF(newname);
+       return 1;
+}
 
 /* Load a source module from a given file and return its module
    object WITH INCREMENTED REFERENCE COUNT.  If there's a matching
@@ -999,6 +1042,8 @@ load_source_module(char *name, char *pathname, FILE *fp)
                fclose(fpc);
                if (co == NULL)
                        return NULL;
+               if (update_compiled_module(co, pathname) < 0)
+                       return NULL;
                if (Py_VerboseFlag)
                        PySys_WriteStderr("import %s # precompiled from %s\n",
                                name, cpathname);