From 5f194f73bf9d789cb2b1eed3c48243d13f9b0a54 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 11 Jul 2018 16:08:51 +0100 Subject: [PATCH] Add FileResponse --- requirements.txt | 1 + setup.py | 1 + starlette/__init__.py | 2 + starlette/datastructures.py | 53 +++-- starlette/middleware/__init__.py | 4 + starlette/middleware/database.py | 27 +++ starlette/multipart.py | 365 +++++++++++++++++++++++++++++++ starlette/response.py | 125 ++++++----- 8 files changed, 510 insertions(+), 68 deletions(-) create mode 100644 starlette/middleware/__init__.py create mode 100644 starlette/middleware/database.py create mode 100644 starlette/multipart.py diff --git a/requirements.txt b/requirements.txt index 96b8130d..fbec2e51 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +aiofiles requests # Testing diff --git a/setup.py b/setup.py index 3777b4ec..646c7613 100644 --- a/setup.py +++ b/setup.py @@ -44,6 +44,7 @@ setup( author_email='tom@tomchristie.com', packages=get_packages('starlette'), install_requires=[ + 'aiofiles', 'requests', ], classifiers=[ diff --git a/starlette/__init__.py b/starlette/__init__.py index 22ae9352..33d95e7c 100644 --- a/starlette/__init__.py +++ b/starlette/__init__.py @@ -1,5 +1,6 @@ from starlette.decorators import asgi_application from starlette.response import ( + FileResponse, HTMLResponse, JSONResponse, Response, @@ -13,6 +14,7 @@ from starlette.testclient import TestClient __all__ = ( "asgi_application", + "FileResponse", "HTMLResponse", "JSONResponse", "Path", diff --git a/starlette/datastructures.py b/starlette/datastructures.py index ba601764..71c52c04 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -119,16 +119,15 @@ class Headers(typing.Mapping[str, str]): An immutable, case-insensitive multidict. """ - def __init__(self, value: typing.Union[StrDict, StrPairs] = None) -> None: - if value is None: + def __init__(self, raw_headers=None) -> None: + if raw_headers is None: self._list = [] else: - assert isinstance(value, list) - for header_key, header_value in value: + for header_key, header_value in raw_headers: assert isinstance(header_key, bytes) assert isinstance(header_value, bytes) assert header_key == header_key.lower() - self._list = value + self._list = raw_headers def keys(self): return [key.decode("latin-1") for key, value in self._list] @@ -148,7 +147,7 @@ class Headers(typing.Mapping[str, str]): except KeyError: return default - def get_list(self, key: str) -> typing.List[str]: + def getlist(self, key: str) -> typing.List[str]: get_header_key = key.lower().encode("latin-1") return [ item_value.decode("latin-1") @@ -164,7 +163,11 @@ class Headers(typing.Mapping[str, str]): raise KeyError(key) def __contains__(self, key: str): - return key.lower() in self.keys() + get_header_key = key.lower().encode("latin-1") + for header_key, header_value in self._list: + if header_key == get_header_key: + return True + return False def __iter__(self): return iter(self.items()) @@ -183,6 +186,9 @@ class Headers(typing.Mapping[str, str]): class MutableHeaders(Headers): def __setitem__(self, key: str, value: str): + """ + Set the header `key` to `value`, removing any duplicate entries. + """ set_key = key.lower().encode("latin-1") set_value = value.encode("latin-1") @@ -196,19 +202,30 @@ class MutableHeaders(Headers): self._list.append((set_key, set_value)) - def set_default(self, key: str, value: str): - set_key = key.lower().encode("latin-1") - set_value = value.encode("latin-1") + def __delitem__(self, key: str): + """ + Remove the header `key`. + """ + del_key = key.lower().encode("latin-1") - is_set = False pop_indexes = [] for idx, (item_key, item_value) in enumerate(self._list): - if item_key == set_key: - if not is_set: - is_set = True - self._list[idx] = set_value - else: - pop_indexes.append(idx) + if item_key == del_key: + pop_indexes.append(idx) for idx in reversed(pop_indexes): - del self._list[idx] + del (self._list[idx]) + + def setdefault(self, key: str, value: str): + """ + If the header `key` does not exist, then set it to `value`. + Returns the header value. + """ + set_key = key.lower().encode("latin-1") + set_value = value.encode("latin-1") + + for idx, (item_key, item_value) in enumerate(self._list): + if item_key == set_key: + return item_value.decode("latin-1") + self._list.append((set_key, set_value)) + return value diff --git a/starlette/middleware/__init__.py b/starlette/middleware/__init__.py new file mode 100644 index 00000000..bdd37f2d --- /dev/null +++ b/starlette/middleware/__init__.py @@ -0,0 +1,4 @@ +from starlette.middleware.database import DatabaseMiddleware + + +__all__ = ["DatabaseMiddleware"] diff --git a/starlette/middleware/database.py b/starlette/middleware/database.py new file mode 100644 index 00000000..9fb5303b --- /dev/null +++ b/starlette/middleware/database.py @@ -0,0 +1,27 @@ +from functools import partial +from urllib.parse import urlparse +import asyncio +import asyncpg + + +class DatabaseMiddleware: + def __init__(self, app, database_url=None, database_config=None): + self.app = app + if database_config is None: + parsed = urlparse(database_url) + database_config = { + "user": parsed.user, + "password": parsed.password, + "database": parsed.database, + "host": parsed.host, + "port": parsed.port, + } + loop = asyncio.get_event_loop() + loop.run_until_complete(self.create_pool(database_config)) + + async def create_pool(self, database_config): + self.pool = await asyncpg.create_pool(**database_config) + + def __call__(self, scope): + scope["database"] = self.pool + return self.app(scope) diff --git a/starlette/multipart.py b/starlette/multipart.py new file mode 100644 index 00000000..c8965c10 --- /dev/null +++ b/starlette/multipart.py @@ -0,0 +1,365 @@ +_begin_form = "begin_form" +_begin_file = "begin_file" +_cont = "cont" +_end = "end" + + +class MultiPartParser(object): + def __init__( + self, + stream_factory=None, + charset="utf-8", + errors="replace", + max_form_memory_size=None, + cls=None, + buffer_size=64 * 1024, + ): + self.charset = charset + self.errors = errors + self.max_form_memory_size = max_form_memory_size + self.stream_factory = ( + default_stream_factory if stream_factory is None else stream_factory + ) + self.cls = MultiDict if cls is None else cls + + # make sure the buffer size is divisible by four so that we can base64 + # decode chunk by chunk + assert buffer_size % 4 == 0, "buffer size has to be divisible by 4" + # also the buffer size has to be at least 1024 bytes long or long headers + # will freak out the system + assert buffer_size >= 1024, "buffer size has to be at least 1KB" + + self.buffer_size = buffer_size + + def _fix_ie_filename(self, filename): + """Internet Explorer 6 transmits the full file name if a file is + uploaded. This function strips the full path if it thinks the + filename is Windows-like absolute. + """ + if filename[1:3] == ":\\" or filename[:2] == "\\\\": + return filename.split("\\")[-1] + return filename + + def _find_terminator(self, iterator): + """The terminator might have some additional newlines before it. + There is at least one application that sends additional newlines + before headers (the python setuptools package). + """ + for line in iterator: + if not line: + break + line = line.strip() + if line: + return line + return b"" + + def fail(self, message): + raise ValueError(message) + + def get_part_encoding(self, headers): + transfer_encoding = headers.get("content-transfer-encoding") + if ( + transfer_encoding is not None + and transfer_encoding in _supported_multipart_encodings + ): + return transfer_encoding + + def get_part_charset(self, headers): + # Figure out input charset for current part + content_type = headers.get("content-type") + if content_type: + mimetype, ct_params = parse_options_header(content_type) + return ct_params.get("charset", self.charset) + return self.charset + + def start_file_streaming(self, filename, headers, total_content_length): + if isinstance(filename, bytes): + filename = filename.decode(self.charset, self.errors) + filename = self._fix_ie_filename(filename) + content_type = headers.get("content-type") + try: + content_length = int(headers["content-length"]) + except (KeyError, ValueError): + content_length = 0 + container = self.stream_factory( + total_content_length, content_type, filename, content_length + ) + return filename, container + + def in_memory_threshold_reached(self, bytes): + raise exceptions.RequestEntityTooLarge() + + def validate_boundary(self, boundary): + if not boundary: + self.fail("Missing boundary") + if not is_valid_multipart_boundary(boundary): + self.fail("Invalid boundary: %s" % boundary) + if len(boundary) > self.buffer_size: # pragma: no cover + # this should never happen because we check for a minimum size + # of 1024 and boundaries may not be longer than 200. The only + # situation when this happens is for non debug builds where + # the assert is skipped. + self.fail("Boundary longer than buffer size") + + class LineSplitter(object): + def __init__(self, cap=None): + self.buffer = b"" + self.cap = cap + + def _splitlines(self, pre, post): + buf = pre + post + rv = [] + if not buf: + return rv, b"" + lines = buf.splitlines(True) + iv = b"" + for line in lines: + iv += line + while self.cap and len(iv) >= self.cap: + rv.append(iv[: self.cap]) + iv = iv[self.cap :] + if line[-1:] in b"\r\n": + rv.append(iv) + iv = b"" + # If this isn't the very end of the stream and what we got ends + # with \r, we need to hold on to it in case an \n comes next + if post and rv and not iv and rv[-1][-1:] == b"\r": + iv = rv[-1] + del rv[-1] + return rv, iv + + def feed(self, data): + lines, self.buffer = self._splitlines(self.buffer, data) + if not data: + lines += [self.buffer] + if self.buffer: + lines += [b""] + return lines + + class LineParser(object): + def __init__(self, parent, boundary): + self.parent = parent + self.boundary = boundary + self._next_part = b"--" + boundary + self._last_part = self._next_part + b"--" + self._state = self._state_pre_term + self._output = [] + self._headers = [] + self._tail = b"" + self._codec = None + + def _start_content(self): + disposition = self._headers.get("content-disposition") + if disposition is None: + raise ValueError("Missing Content-Disposition header") + self.disposition, extra = parse_options_header(disposition) + transfer_encoding = self.parent.get_part_encoding(self._headers) + if transfer_encoding is not None: + if transfer_encoding == "base64": + transfer_encoding = "base64_codec" + try: + self._codec = codecs.lookup(transfer_encoding) + except Exception: + raise ValueError( + "Cannot decode transfer-encoding: %r" % transfer_encoding + ) + self.name = extra.get("name") + self.filename = extra.get("filename") + if self.filename is not None: + self._output.append( + ("begin_file", (self._headers, self.name, self.filename)) + ) + else: + self._output.append(("begin_form", (self._headers, self.name))) + return self._state_output + + def _state_done(self, line): + return self._state_done + + def _state_output(self, line): + if not line: + raise ValueError("Unexpected end of file") + sline = line.rstrip() + if sline == self._last_part: + self._tail = b"" + self._output.append(("end", None)) + return self._state_done + elif sline == self._next_part: + self._tail = b"" + self._output.append(("end", None)) + self._headers = [] + return self._state_headers + + if self._codec: + try: + line, _ = self._codec.decode(line) + except Exception: + raise ValueError("Could not decode transfer-encoded chunk") + + # We don't know yet whether we can output the final newline, so + # we'll save it in self._tail and output it next time. + tail = self._tail + if line[-2:] == b"\r\n": + self._output.append(("cont", tail + line[:-2])) + self._tail = line[-2:] + elif line[-1:] in b"\r\n": + self._output.append(("cont", tail + line[:-1])) + self._tail = line[-1:] + else: + self._output.append(("cont", tail + line)) + self._tail = b"" + return self._state_output + + def _state_pre_term(self, line): + if not line: + raise ValueError("Unexpected end of file") + return self._state_pre_term + line = line.rstrip(b"\r\n") + if not line: + return self._state_pre_term + if line == self._last_part: + return self._state_done + elif line == self._next_part: + self._headers = [] + return self._state_headers + raise ValueError("Expected boundary at start of multipart data") + + def _state_headers(self, line): + if line is None: + raise ValueError("Unexpected end of file during headers") + line = to_native(line) + line, line_terminated = _line_parse(line) + if not line_terminated: + raise ValueError("Unexpected end of line in multipart header") + if not line: + self._headers = Headers(self._headers) + return self._start_content() + if line[0] in " \t" and self._headers: + key, value = self._headers[-1] + self._headers[-1] = (key, value + "\n " + line[1:]) + else: + parts = line.split(":", 1) + if len(parts) == 2: + self._headers.append((parts[0].strip(), parts[1].strip())) + else: + raise ValueError("Malformed header") + return self._state_headers + + def feed(self, lines): + self._output = [] + s = self._state + for line in lines: + s = s(line) + self._state = s + return self._output + + class PartParser(object): + def __init__(self, parent, content_length): + self.parent = parent + self.content_length = content_length + self._write = None + self._in_memory = 0 + self._guard_memory = False + + def _feed_one(self, event): + ev, data = event + p = self.parent + if ev == "begin_file": + self._headers, self._name, filename = data + self._filename, self._container = p.start_file_streaming( + filename, self._headers, self.content_length + ) + self._write = self._container.write + self._is_file = True + self._guard_memory = False + elif ev == "begin_form": + self._headers, self._name = data + self._container = [] + self._write = self._container.append + self._is_file = False + self._guard_memory = p.max_form_memory_size is not None + elif ev == "cont": + self._write(data) + if self._guard_memory: + self._in_memory += len(data) + if self._in_memory > p.max_form_memory_size: + p.in_memory_threshold_reached(self._in_memory) + elif ev == "end": + if self._is_file: + self._container.seek(0) + return ( + "file", + ( + self._name, + FileStorage( + self._container, + self._filename, + self._name, + headers=self._headers, + ), + ), + ) + else: + part_charset = p.get_part_charset(self._headers) + return ( + "form", + ( + self._name, + b"".join(self._container).decode(part_charset, p.errors), + ), + ) + + def feed(self, events): + rv = [] + for event in events: + v = self._feed_one(event) + if v is not None: + rv.append(v) + return rv + + def parse_lines(self, file, boundary, content_length, cap_at_buffer=True): + """Generate parts of + ``('begin_form', (headers, name))`` + ``('begin_file', (headers, name, filename))`` + ``('cont', bytestring)`` + ``('end', None)`` + Always obeys the grammar + parts = ( begin_form cont* end | + begin_file cont* end )* + """ + + line_splitter = self.LineSplitter(self.buffer_size if cap_at_buffer else None) + line_parser = self.LineParser(self, boundary) + while True: + buf = file.read(self.buffer_size) + lines = line_splitter.feed(buf) + parts = line_parser.feed(lines) + for part in parts: + yield part + if buf == b"": + break + + def parse_parts(self, file, boundary, content_length): + """Generate ``('file', (name, val))`` and + ``('form', (name, val))`` parts. + """ + line_splitter = self.LineSplitter() + line_parser = self.LineParser(self, boundary) + part_parser = self.PartParser(self, content_length) + while True: + buf = file.read(self.buffer_size) + lines = line_splitter.feed(buf) + parts = line_parser.feed(lines) + events = part_parser.feed(parts) + for event in events: + yield event + if buf == b"": + break + + def parse(self, file, boundary, content_length): + formstream, filestream = tee( + self.parse_parts(file, boundary, content_length), 2 + ) + form = (p[1] for p in formstream if p[0] == "form") + files = (p[1] for p in filestream if p[0] == "file") + return self.cls(form), self.cls(files) diff --git a/starlette/response.py b/starlette/response.py index a3bc5a4b..c3c98197 100644 --- a/starlette/response.py +++ b/starlette/response.py @@ -1,5 +1,7 @@ +from mimetypes import guess_type from starlette.datastructures import MutableHeaders from starlette.types import Receive, Send +import aiofiles import json import typing import os @@ -18,36 +20,47 @@ class Response: ) -> None: self.body = self.render(content) self.status_code = status_code - self.media_type = self.media_type if media_type is None else media_type - self.raw_headers = [] if headers is None else [ - (k.lower().encode("latin-1"), v.encode("latin-1")) - for k, v in headers.items() - ] - self.headers = MutableHeaders(self.raw_headers) - self.set_default_headers() + if media_type is not None: + self.media_type = media_type + self.init_headers(headers) def render(self, content: typing.Any) -> bytes: if isinstance(content, bytes): return content return content.encode(self.charset) - def set_default_headers(self): - content_length = str(len(self.body)) if hasattr(self, 'body') else None - content_type = self.default_content_type - - if content_length is not None: - self.headers.set_default("content-length", content_length) - if content_type is not None: - self.headers.set_default("content-type", content_type) + def init_headers(self, headers): + if headers is None: + raw_headers = [] + populate_content_length = True + populate_content_type = True + else: + raw_headers = [ + (k.lower().encode("latin-1"), v.encode("latin-1")) + for k, v in headers.items() + ] + keys = [h[0] for h in raw_headers] + populate_content_length = b"content-length" in keys + populate_content_type = b"content-type" in keys + + body = getattr(self, "body", None) + if body is not None and populate_content_length: + content_length = str(len(body)) + raw_headers.append((b"content-length", content_length.encode("latin-1"))) + + content_type = self.media_type + if content_type is not None and populate_content_type: + if content_type.startswith("text/"): + content_type += "; charset=" + self.charset + raw_headers.append((b"content-type", content_type.encode("latin-1"))) + + self.raw_headers = raw_headers @property - def default_content_type(self): - if self.media_type is None: - return None - - if self.media_type.startswith('text/') and self.charset is not None: - return '%s; charset=%s' % (self.media_type, self.charset) - return self.media_type + def headers(self): + if not hasattr(self, "_headers"): + self._headers = MutableHeaders(self.raw_headers) + return self._headers async def __call__(self, receive: Receive, send: Send) -> None: await send( @@ -92,21 +105,14 @@ class StreamingResponse(Response): self.body_iterator = content self.status_code = status_code self.media_type = self.media_type if media_type is None else media_type - self.raw_headers = [] if headers is None else [ - (k.lower().encode("latin-1"), v.encode("latin-1")) - for k, v in headers.items() - ] - self.headers = MutableHeaders(self.raw_headers) - self.set_default_headers() + self.init_headers(headers) async def __call__(self, receive: Receive, send: Send) -> None: await send( { "type": "http.response.start", "status": self.status_code, - "headers": [ - [key.encode(), value.encode()] for key, value in self.headers - ], + "headers": self.raw_headers, } ) async for chunk in self.body_iterator: @@ -115,22 +121,41 @@ class StreamingResponse(Response): await send({"type": "http.response.body", "body": chunk, "more_body": True}) await send({"type": "http.response.body", "body": b"", "more_body": False}) -# -# class FileResponse: -# def __init__( -# self, -# path: str, -# headers: dict = None, -# media_type: str = None, -# filename: str = None -# ) -> None: -# self.path = path -# self.status_code = 200 -# if media_type is not None: -# self.media_type = media_type -# if filename is not None: -# self.filename = filename -# else: -# self.filename = os.path.basename(path) -# -# self.set_default_headers(headers) + +class FileResponse(Response): + chunk_size = 4096 + + def __init__( + self, + path: str, + headers: dict = None, + media_type: str = None, + filename: str = None, + ) -> None: + self.path = path + self.status_code = 200 + self.filename = filename + if media_type is None: + media_type = guess_type(filename or path)[0] or "text/plain" + self.media_type = media_type + self.init_headers(headers) + if self.filename is not None: + content_disposition = 'attachment; filename="{}"'.format(self.filename) + self.headers.setdefault("content-disposition", content_disposition) + + async def __call__(self, receive: Receive, send: Send) -> None: + await send( + { + "type": "http.response.start", + "status": self.status_code, + "headers": self.raw_headers, + } + ) + async with aiofiles.open(self.path, mode="rb") as file: + more_body = True + while more_body: + chunk = await file.read(self.chunk_size) + more_body = len(chunk) == self.chunk_size + await send( + {"type": "http.response.body", "body": chunk, "more_body": False} + ) -- 2.47.3