from __future__ import absolute_import, division, print_function, with_statement
+import socket
+
from tornado.concurrent import Future
from tornado.escape import native_str, utf8
from tornado import gen
# and its socket attribute replaced with None.
self.address_family = stream.socket.family
self.no_keep_alive = no_keep_alive
+ # In HTTPServerRequest we want an IP, not a full socket address.
+ if (self.address_family in (socket.AF_INET, socket.AF_INET6) and
+ address is not None):
+ self.remote_ip = address[0]
+ else:
+ # Unix (or other) socket; fake the remote address.
+ self.remote_ip = '0.0.0.0'
if protocol:
self.protocol = protocol
elif isinstance(stream, iostream.SSLIOStream):
def start_request(self, connection):
return _ServerRequestAdapter(self, connection)
+
class _ServerRequestAdapter(httputil.HTTPMessageDelegate):
- """Adapts the `HTTPMessageDelegate` interface to the `HTTPServerRequest`
- interface expected by our clients.
+ """Adapts the `HTTPMessageDelegate` interface to the interface expected
+ by our clients.
"""
def __init__(self, server, connection):
self.server = server
self.connection = connection
- self._chunks = []
-
- def headers_received(self, start_line, headers):
- # HTTPRequest wants an IP, not a full socket address
- if self.connection.address_family in (socket.AF_INET, socket.AF_INET6):
- remote_ip = self.connection.address[0]
+ self.request = None
+ if isinstance(server.request_callback,
+ httputil.HTTPServerConnectionDelegate):
+ self.delegate = server.request_callback.start_request(connection)
+ self._chunks = None
else:
- # Unix (or other) socket; fake the remote address
- remote_ip = '0.0.0.0'
+ self.delegate = None
+ self._chunks = []
+
+ def _apply_xheaders(self, headers):
+ """Rewrite the connection.remote_ip and connection.protocol fields.
+
+ This is hacky, but the movement of logic between `HTTPServer`
+ and `.Application` leaves us without a clean place to do this.
+ """
+ self._orig_remote_ip = self.connection.remote_ip
+ self._orig_protocol = self.connection.protocol
+ # Squid uses X-Forwarded-For, others use X-Real-Ip
+ ip = headers.get("X-Forwarded-For", self.connection.remote_ip)
+ ip = ip.split(',')[-1].strip()
+ ip = headers.get("X-Real-Ip", ip)
+ if netutil.is_valid_ip(ip):
+ self.connection.remote_ip = ip
+ # AWS uses X-Forwarded-Proto
+ proto_header = headers.get(
+ "X-Scheme", headers.get("X-Forwarded-Proto",
+ self.connection.protocol))
+ if proto_header in ("http", "https"):
+ self.connection.protocol = proto_header
+
+ def _unapply_xheaders(self):
+ """Undo changes from `_apply_xheaders`.
+
+ Xheaders are per-request so they should not leak to the next
+ request on the same connection.
+ """
+ self.connection.remote_ip = self._orig_remote_ip
+ self.connection.protocol = self._orig_protocol
- protocol = self.connection.protocol
-
- # xheaders can override the defaults
+ def headers_received(self, start_line, headers):
if self.server.xheaders:
- # Squid uses X-Forwarded-For, others use X-Real-Ip
- ip = headers.get("X-Forwarded-For", remote_ip)
- ip = ip.split(',')[-1].strip()
- ip = headers.get("X-Real-Ip", ip)
- if netutil.is_valid_ip(ip):
- remote_ip = ip
- # AWS uses X-Forwarded-Proto
- proto_header = headers.get(
- "X-Scheme", headers.get("X-Forwarded-Proto", protocol))
- if proto_header in ("http", "https"):
- protocol = proto_header
-
- self.request = httputil.HTTPServerRequest(
- connection=self.connection, method=start_line.method,
- uri=start_line.path, version=start_line.version,
- headers=headers, remote_ip=remote_ip, protocol=protocol)
+ self._apply_xheaders(headers)
+ if self.delegate is None:
+ self.request = httputil.HTTPServerRequest(
+ connection=self.connection, start_line=start_line,
+ headers=headers)
+ else:
+ self.delegate.headers_received(start_line, headers)
def data_received(self, chunk):
- self._chunks.append(chunk)
+ if self.delegate is None:
+ self._chunks.append(chunk)
+ else:
+ self.delegate.data_received(chunk)
def finish(self):
- self.request.body = b''.join(self._chunks)
- self.request._parse_body()
- self.server.request_callback(self.request)
+ if self.delegate is None:
+ self.request.body = b''.join(self._chunks)
+ self.request._parse_body()
+ self.server.request_callback(self.request)
+ else:
+ self.delegate.finish()
+ if self.server.xheaders:
+ self._unapply_xheaders()
HTTPRequest = httputil.HTTPServerRequest
are typically kept open in HTTP/1.1, multiple requests can be handled
sequentially on a single connection.
"""
- def __init__(self, method, uri, version="HTTP/1.0", headers=None,
+ def __init__(self, method=None, uri=None, version="HTTP/1.0", headers=None,
body=None, remote_ip=None, protocol=None, host=None,
- files=None, connection=None):
+ files=None, connection=None, start_line=None):
+ if start_line is not None:
+ method, uri, version = start_line
self.method = method
self.uri = uri
self.version = version
self.body = body or ""
# set remote IP and protocol
- self.remote_ip = remote_ip
- self.protocol = protocol or "http"
+ self.remote_ip = remote_ip or getattr(connection, 'remote_ip')
+ self.protocol = protocol or getattr(connection, 'protocol', "http")
self.host = host or self.headers.get("Host") or "127.0.0.1"
self.files = files or {}
return wrapper
-class Application(object):
+class Application(httputil.HTTPServerConnectionDelegate):
"""A collection of request handlers that make up a web application.
Instances of this class are callable and can be passed directly to
except TypeError:
pass
- def __call__(self, request):
- """Called by HTTPServer to execute the request."""
- transforms = [t(request) for t in self.transforms]
- handler = None
- args = []
- kwargs = {}
- handlers = self._get_host_handlers(request)
- if not handlers:
- handler = RedirectHandler(
- self, request, url="http://" + self.default_host + "/")
- else:
- for spec in handlers:
- match = spec.regex.match(request.path)
- if match:
- handler = spec.handler_class(self, request, **spec.kwargs)
- if spec.regex.groups:
- # None-safe wrapper around url_unescape to handle
- # unmatched optional groups correctly
- def unquote(s):
- if s is None:
- return s
- return escape.url_unescape(s, encoding=None,
- plus=False)
- # Pass matched groups to the handler. Since
- # match.groups() includes both named and unnamed groups,
- # we want to use either groups or groupdict but not both.
- # Note that args are passed as bytes so the handler can
- # decide what encoding to use.
-
- if spec.regex.groupindex:
- kwargs = dict(
- (str(k), unquote(v))
- for (k, v) in match.groupdict().items())
- else:
- args = [unquote(s) for s in match.groups()]
- break
- if not handler:
- if self.settings.get('default_handler_class'):
- handler_class = self.settings['default_handler_class']
- handler_args = self.settings.get(
- 'default_handler_args', {})
- else:
- handler_class = ErrorHandler
- handler_args = dict(status_code=404)
- handler = handler_class(self, request, **handler_args)
+ def start_request(self, connection):
+ # Modern HTTPServer interface
+ return _RequestDispatcher(self, connection)
- # If template cache is disabled (usually in the debug mode),
- # re-compile templates and reload static files on every
- # request so you don't need to restart to see changes
- if not self.settings.get("compiled_template_cache", True):
- with RequestHandler._template_loader_lock:
- for loader in RequestHandler._template_loaders.values():
- loader.reset()
- if not self.settings.get('static_hash_cache', True):
- StaticFileHandler.reset()
-
- handler._execute(transforms, *args, **kwargs)
- return handler
+ def __call__(self, request):
+ # Legacy HTTPServer interface
+ dispatcher = _RequestDispatcher(self, None)
+ dispatcher.set_request(request)
+ return dispatcher.execute()
def reverse_url(self, name, *args):
"""Returns a URL path for handler named ``name``
handler._request_summary(), request_time)
+class _RequestDispatcher(httputil.HTTPMessageDelegate):
+ def __init__(self, application, connection):
+ self.application = application
+ self.connection = connection
+ self.request = None
+ self.chunks = []
+ self.handler_class = None
+ self.handler_kwargs = None
+ self.path_args = []
+ self.path_kwargs = {}
+
+ def headers_received(self, start_line, headers):
+ self.set_request(httputil.HTTPServerRequest(
+ connection=self.connection, start_line=start_line, headers=headers))
+
+ def set_request(self, request):
+ self.request = request
+ self._find_handler()
+
+ def _find_handler(self):
+ # Identify the handler to use as soon as we have the request.
+ # Save url path arguments for later.
+ app = self.application
+ handlers = app._get_host_handlers(self.request)
+ if not handlers:
+ self.handler_class = RedirectHandler
+ self.handler_kwargs = dict(url="http://" + app.default_host + "/")
+ return
+ for spec in handlers:
+ match = spec.regex.match(self.request.path)
+ if match:
+ self.handler_class = spec.handler_class
+ self.handler_kwargs = spec.kwargs
+ if spec.regex.groups:
+ # Pass matched groups to the handler. Since
+ # match.groups() includes both named and
+ # unnamed groups, we want to use either groups
+ # or groupdict but not both.
+ if spec.regex.groupindex:
+ self.path_kwargs = dict(
+ (str(k), _unquote_or_none(v))
+ for (k, v) in match.groupdict().items())
+ else:
+ self.path_args = [_unquote_or_none(s)
+ for s in match.groups()]
+ return
+ if app.settings.get('default_handler_class'):
+ self.handler_class = app.settings['default_handler_class']
+ self.handler_kwargs = app.settings.get(
+ 'default_handler_args', {})
+ else:
+ self.handler_class = ErrorHandler
+ self.handler_kwargs = dict(status_code=404)
+
+ def data_received(self, data):
+ self.chunks.append(data)
+
+ def finish(self):
+ self.request.body = b''.join(self.chunks)
+ self.request._parse_body()
+ self.execute()
+
+ def execute(self):
+ # If template cache is disabled (usually in the debug mode),
+ # re-compile templates and reload static files on every
+ # request so you don't need to restart to see changes
+ if not self.application.settings.get("compiled_template_cache", True):
+ with RequestHandler._template_loader_lock:
+ for loader in RequestHandler._template_loaders.values():
+ loader.reset()
+ if not self.application.settings.get('static_hash_cache', True):
+ StaticFileHandler.reset()
+
+ handler = self.handler_class(self.application, self.request,
+ **self.handler_kwargs)
+ transforms = [t(self.request) for t in self.application.transforms]
+ handler._execute(transforms, *self.path_args, **self.path_kwargs)
+ return handler
+
+
+
class HTTPError(Exception):
"""An exception that will turn into an HTTP error response.
for part in parts:
hash.update(utf8(part))
return utf8(hash.hexdigest())
+
+def _unquote_or_none(s):
+ """None-safe wrapper around url_unescape to handle unamteched optional
+ groups correctly.
+
+ Note that args are passed as bytes so the handler can decide what
+ encoding to use.
+ """
+ if s is None:
+ return s
+ return escape.url_unescape(s, encoding=None, plus=False)