From 364321653935686a71097a3995acd7cf20c74d7a Mon Sep 17 00:00:00 2001 From: Jon Parise Date: Sun, 18 Nov 2012 11:45:38 -0800 Subject: [PATCH] _get_host_handlers() now returns all host matches. This approach has more clearly defined precedence rules than the previous insertion-time strategy implemented in add_handlers(). It also correctly leaves pattern matching in the hands of the regular expression evaluator as opposed to directly comparing pattern strings. --- tornado/test/web_test.py | 35 ++++++++++++++++++++++++++++++++ tornado/web.py | 44 ++++++++++++++++------------------------ 2 files changed, 52 insertions(+), 27 deletions(-) diff --git a/tornado/test/web_test.py b/tornado/test/web_test.py index 46b4ec1e7..0991b509f 100644 --- a/tornado/test/web_test.py +++ b/tornado/test/web_test.py @@ -847,6 +847,41 @@ class CustomStaticFileTest(WebTestCase): wsgi_safe.append(CustomStaticFileTest) +class HostMatchingTest(WebTestCase): + class Handler(RequestHandler): + def initialize(self, reply): + self.reply = reply + + def get(self): + self.write(self.reply) + + def get_handlers(self): + return [("/foo", HostMatchingTest.Handler, {"reply": "wildcard"})] + + def test_host_matching(self): + self.app.add_handlers("www.example.com", + [("/foo", HostMatchingTest.Handler, {"reply": "[0]"})]) + self.app.add_handlers(r"www\.example\.com", + [("/bar", HostMatchingTest.Handler, {"reply": "[1]"})]) + self.app.add_handlers("www.example.com", + [("/baz", HostMatchingTest.Handler, {"reply": "[2]"})]) + + response = self.fetch("/foo") + self.assertEqual(response.body, b("wildcard")) + response = self.fetch("/bar") + self.assertEqual(response.code, 404) + response = self.fetch("/baz") + self.assertEqual(response.code, 404) + + response = self.fetch("/foo", headers={'Host': 'www.example.com'}) + self.assertEqual(response.body, b("[0]")) + response = self.fetch("/bar", headers={'Host': 'www.example.com'}) + self.assertEqual(response.body, b("[1]")) + response = self.fetch("/baz", headers={'Host': 'www.example.com'}) + self.assertEqual(response.body, b("[2]")) +wsgi_safe.append(HostMatchingTest) + + class NamedURLSpecGroupsTest(WebTestCase): def get_handlers(self): class EchoHandler(RequestHandler): diff --git a/tornado/web.py b/tornado/web.py index f252a76ed..00f3600ad 100644 --- a/tornado/web.py +++ b/tornado/web.py @@ -1314,32 +1314,21 @@ class Application(object): def add_handlers(self, host_pattern, host_handlers): """Appends the given handlers to our handler list. - Note that host patterns are processed sequentially in the - order they were added, and only the first matching pattern is - used. + Host patterns are processed sequentially in the order they were + added. All matching patterns will be considered. """ if not host_pattern.endswith("$"): host_pattern += "$" - - # Search for an existing handlers entry for this host pattern. - handlers = None - for entry in self.handlers: - if entry[0].pattern == host_pattern: - handlers = entry[1] - break - - # Otherwise, add a new handlers entry for this host pattern. - if handlers is None: - handlers = [] - # The handlers with the wildcard host_pattern are a special - # case - they're added in the constructor but should have lower - # precedence than the more-precise handlers added later. - # If a wildcard handler group exists, it should always be last - # in the list, so insert new groups just before it. - if self.handlers and self.handlers[-1][0].pattern == '.*$': - self.handlers.insert(-1, (re.compile(host_pattern), handlers)) - else: - self.handlers.append((re.compile(host_pattern), handlers)) + handlers = [] + # The handlers with the wildcard host_pattern are a special + # case - they're added in the constructor but should have lower + # precedence than the more-precise handlers added later. + # If a wildcard handler group exists, it should always be last + # in the list, so insert new groups just before it. + if self.handlers and self.handlers[-1][0].pattern == '.*$': + self.handlers.insert(-1, (re.compile(host_pattern), handlers)) + else: + self.handlers.append((re.compile(host_pattern), handlers)) for spec in host_handlers: if type(spec) is type(()): @@ -1371,15 +1360,16 @@ class Application(object): def _get_host_handlers(self, request): host = request.host.lower().split(':')[0] + matches = [] for pattern, handlers in self.handlers: if pattern.match(host): - return handlers + matches.extend(handlers) # Look for default host if not behind load balancer (for debugging) - if "X-Real-Ip" not in request.headers: + if not matches and "X-Real-Ip" not in request.headers: for pattern, handlers in self.handlers: if pattern.match(self.default_host): - return handlers - return None + matches.extend(handlers) + return matches or None def _load_ui_methods(self, methods): if type(methods) is types.ModuleType: -- 2.47.2