git.ipfire.org Git - thirdparty/tornado.git/commitdiff
Thread safety for shared data structures (templates and static file hashes)
author Ben Darnell <ben@bendarnell.com>
Mon, 14 Nov 2011 04:15:44 +0000 (20:15 -0800)
committer Ben Darnell <ben@bendarnell.com>
Mon, 14 Nov 2011 04:15:44 +0000 (20:15 -0800)
These data structures were basically safe before thanks to the GIL, but could
lead to wasted work in multithreaded environments (such as the new Python 2.7
App Engine runtime).

Also fixed a bug that prevented debug mode from invalidating cached static file hashes.

tornado/template.py
tornado/web.py

index f28f40355da164493e7b3341528360da1d496136..fc9563016ac4e55a30a957ddbe2e7b1cccaf7f8a 100644 (file)
@@ -177,6 +177,7 @@ import logging
 import os.path
 import posixpath
 import re
+import threading
 
 from tornado import escape
 from tornado.util import bytes_type
@@ -282,10 +283,17 @@ class BaseLoader(object):
         self.autoescape = autoescape
         self.namespace = namespace or {}
         self.templates = {}
+        # self.lock protects self.templates.  It's a reentrant lock
+        # because templates may load other templates via `include` or
+        # `extends`.  Note that thanks to the GIL this code would be safe
+        # even without the lock, but could lead to wasted work as multiple
+        # threads tried to compile the same template simultaneously.
+        self.lock = threading.RLock()
 
     def reset(self):
         """Resets the cache of compiled templates."""
-        self.templates = {}
+        with self.lock:
+            self.templates = {}
 
     def resolve_path(self, name, parent_path=None):
         """Converts a possibly-relative path to absolute (used internally)."""
@@ -294,9 +302,10 @@ class BaseLoader(object):
     def load(self, name, parent_path=None):
         """Loads a template."""
         name = self.resolve_path(name, parent_path=parent_path)
-        if name not in self.templates:
-            self.templates[name] = self._create_template(name)
-        return self.templates[name]
+        with self.lock:
+            if name not in self.templates:
+                self.templates[name] = self._create_template(name)
+            return self.templates[name]
 
     def _create_template(self, name):
         raise NotImplementedError()
index e87f330545042f05a98ccd43037aa234290e2819..d9a77ccf66eb154aa013b3223c89dc56dfbd5587 100644 (file)
@@ -69,6 +69,7 @@ import os.path
 import re
 import stat
 import sys
+import threading
 import time
 import tornado
 import traceback
@@ -98,6 +99,9 @@ class RequestHandler(object):
     """
     SUPPORTED_METHODS = ("GET", "HEAD", "POST", "DELETE", "PUT", "OPTIONS")
 
+    _template_loaders = {}  # {path: template.BaseLoader}
+    _template_loader_lock = threading.Lock()
+
     def __init__(self, application, request, **kwargs):
         self.application = application
         self.request = request
@@ -537,12 +541,13 @@ class RequestHandler(object):
             while frame.f_code.co_filename == web_file:
                 frame = frame.f_back
             template_path = os.path.dirname(frame.f_code.co_filename)
-        if not getattr(RequestHandler, "_templates", None):
-            RequestHandler._templates = {}
-        if template_path not in RequestHandler._templates:
-            loader = self.create_template_loader(template_path)
-            RequestHandler._templates[template_path] = loader
-        t = RequestHandler._templates[template_path].load(template_name)
+        with RequestHandler._template_loader_lock:
+            if template_path not in RequestHandler._template_loaders:
+                loader = self.create_template_loader(template_path)
+                RequestHandler._template_loaders[template_path] = loader
+            else:
+                loader = RequestHandler._template_loaders[template_path]
+        t = loader.load(template_name)
         args = dict(
             handler=self,
             request=self.request,
@@ -1320,10 +1325,10 @@ class Application(object):
         # In debug mode, re-compile templates and reload static files on every
         # request so you don't need to restart to see changes
         if self.settings.get("debug"):
-            if getattr(RequestHandler, "_templates", None):
-                for loader in RequestHandler._templates.values():
+            with RequestHandler._template_loader_lock:
+                for loader in RequestHandler._template_loaders.values():
                     loader.reset()
-            RequestHandler._static_hashes = {}
+            StaticFileHandler.reset()
 
         handler._execute(transforms, *args, **kwargs)
         return handler
@@ -1424,11 +1429,17 @@ class StaticFileHandler(RequestHandler):
     CACHE_MAX_AGE = 86400*365*10 #10 years
 
     _static_hashes = {}
+    _lock = threading.Lock()  # protects _static_hashes
 
     def initialize(self, path, default_filename=None):
         self.root = os.path.abspath(path) + os.path.sep
         self.default_filename = default_filename
 
+    @classmethod
+    def reset(cls):
+        with cls._lock:
+            cls._static_hashes = {}
+
     def head(self, path):
         self.get(path, include_body=False)
 
@@ -1517,19 +1528,21 @@ class StaticFileHandler(RequestHandler):
         is the static path being requested.  The url returned should be
         relative to the current host.
         """
-        hashes = cls._static_hashes
         abs_path = os.path.join(settings["static_path"], path)
-        if abs_path not in hashes:
-            try:
-                f = open(abs_path, "rb")
-                hashes[abs_path] = hashlib.md5(f.read()).hexdigest()
-                f.close()
-            except Exception:
-                logging.error("Could not open static file %r", path)
-                hashes[abs_path] = None
+        with cls._lock:
+            hashes = cls._static_hashes
+            if abs_path not in hashes:
+                try:
+                    f = open(abs_path, "rb")
+                    hashes[abs_path] = hashlib.md5(f.read()).hexdigest()
+                    f.close()
+                except Exception:
+                    logging.error("Could not open static file %r", path)
+                    hashes[abs_path] = None
+            hsh = hashes.get(abs_path)
         static_url_prefix = settings.get('static_url_prefix', '/static/')
-        if hashes.get(abs_path):
-            return static_url_prefix + path + "?v=" + hashes[abs_path][:5]
+        if hsh:
+            return static_url_prefix + path + "?v=" + hsh[:5]
         else:
             return static_url_prefix + path