]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Add FileResponse
authorTom Christie <tom@tomchristie.com>
Wed, 11 Jul 2018 15:08:51 +0000 (16:08 +0100)
committerTom Christie <tom@tomchristie.com>
Wed, 11 Jul 2018 15:08:51 +0000 (16:08 +0100)
requirements.txt
setup.py
starlette/__init__.py
starlette/datastructures.py
starlette/middleware/__init__.py [new file with mode: 0644]
starlette/middleware/database.py [new file with mode: 0644]
starlette/multipart.py [new file with mode: 0644]
starlette/response.py

index 96b8130dea7b41e4223ff6660084825c1614f64f..fbec2e5117f4f1dfbe6fc23866ed7fb634e06bbf 100644 (file)
@@ -1,3 +1,4 @@
+aiofiles
 requests
 
 # Testing
index 3777b4ec9864b1785ca064312b4a54d7c73edd41..646c76139d1ab972482610d85fbbca68bf7803df 100644 (file)
--- a/setup.py
+++ b/setup.py
@@ -44,6 +44,7 @@ setup(
     author_email='tom@tomchristie.com',
     packages=get_packages('starlette'),
     install_requires=[
+        'aiofiles',
         'requests',
     ],
     classifiers=[
index 22ae9352188f8533423ceebe6dac705621579b30..33d95e7c353d5891b84de77d62455b5a52fd2e20 100644 (file)
@@ -1,5 +1,6 @@
 from starlette.decorators import asgi_application
 from starlette.response import (
+    FileResponse,
     HTMLResponse,
     JSONResponse,
     Response,
@@ -13,6 +14,7 @@ from starlette.testclient import TestClient
 
 __all__ = (
     "asgi_application",
+    "FileResponse",
     "HTMLResponse",
     "JSONResponse",
     "Path",
index ba601764da1b906a84da6d610e377c91df9f0392..71c52c04d43682e902e53f30946185eaee0ebf4c 100644 (file)
@@ -119,16 +119,15 @@ class Headers(typing.Mapping[str, str]):
     An immutable, case-insensitive multidict.
     """
 
-    def __init__(self, value: typing.Union[StrDict, StrPairs] = None) -> None:
-        if value is None:
+    def __init__(self, raw_headers=None) -> None:
+        if raw_headers is None:
             self._list = []
         else:
-            assert isinstance(value, list)
-            for header_key, header_value in value:
+            for header_key, header_value in raw_headers:
                 assert isinstance(header_key, bytes)
                 assert isinstance(header_value, bytes)
                 assert header_key == header_key.lower()
-            self._list = value
+            self._list = raw_headers
 
     def keys(self):
         return [key.decode("latin-1") for key, value in self._list]
@@ -148,7 +147,7 @@ class Headers(typing.Mapping[str, str]):
         except KeyError:
             return default
 
-    def get_list(self, key: str) -> typing.List[str]:
+    def getlist(self, key: str) -> typing.List[str]:
         get_header_key = key.lower().encode("latin-1")
         return [
             item_value.decode("latin-1")
@@ -164,7 +163,11 @@ class Headers(typing.Mapping[str, str]):
         raise KeyError(key)
 
     def __contains__(self, key: str):
-        return key.lower() in self.keys()
+        get_header_key = key.lower().encode("latin-1")
+        for header_key, header_value in self._list:
+            if header_key == get_header_key:
+                return True
+        return False
 
     def __iter__(self):
         return iter(self.items())
@@ -183,6 +186,9 @@ class Headers(typing.Mapping[str, str]):
 
 class MutableHeaders(Headers):
     def __setitem__(self, key: str, value: str):
+        """
+        Set the header `key` to `value`, removing any duplicate entries.
+        """
         set_key = key.lower().encode("latin-1")
         set_value = value.encode("latin-1")
 
@@ -196,19 +202,30 @@ class MutableHeaders(Headers):
 
         self._list.append((set_key, set_value))
 
-    def set_default(self, key: str, value: str):
-        set_key = key.lower().encode("latin-1")
-        set_value = value.encode("latin-1")
+    def __delitem__(self, key: str):
+        """
+        Remove the header `key`.
+        """
+        del_key = key.lower().encode("latin-1")
 
-        is_set = False
         pop_indexes = []
         for idx, (item_key, item_value) in enumerate(self._list):
-            if item_key == set_key:
-                if not is_set:
-                    is_set = True
-                    self._list[idx] = set_value
-                else:
-                    pop_indexes.append(idx)
+            if item_key == del_key:
+                pop_indexes.append(idx)
 
         for idx in reversed(pop_indexes):
-            del self._list[idx]
+            del (self._list[idx])
+
+    def setdefault(self, key: str, value: str):
+        """
+        If the header `key` does not exist, then set it to `value`.
+        Returns the header value.
+        """
+        set_key = key.lower().encode("latin-1")
+        set_value = value.encode("latin-1")
+
+        for idx, (item_key, item_value) in enumerate(self._list):
+            if item_key == set_key:
+                return item_value.decode("latin-1")
+        self._list.append((set_key, set_value))
+        return value
diff --git a/starlette/middleware/__init__.py b/starlette/middleware/__init__.py
new file mode 100644 (file)
index 0000000..bdd37f2
--- /dev/null
@@ -0,0 +1,4 @@
+from starlette.middleware.database import DatabaseMiddleware
+
+
+__all__ = ["DatabaseMiddleware"]
diff --git a/starlette/middleware/database.py b/starlette/middleware/database.py
new file mode 100644 (file)
index 0000000..9fb5303
--- /dev/null
@@ -0,0 +1,27 @@
+from functools import partial
+from urllib.parse import urlparse
+import asyncio
+import asyncpg
+
+
+class DatabaseMiddleware:
+    def __init__(self, app, database_url=None, database_config=None):
+        self.app = app
+        if database_config is None:
+            parsed = urlparse(database_url)
+            database_config = {
+                "user": parsed.user,
+                "password": parsed.password,
+                "database": parsed.database,
+                "host": parsed.host,
+                "port": parsed.port,
+            }
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self.create_pool(database_config))
+
+    async def create_pool(self, database_config):
+        self.pool = await asyncpg.create_pool(**database_config)
+
+    def __call__(self, scope):
+        scope["database"] = self.pool
+        return self.app(scope)
diff --git a/starlette/multipart.py b/starlette/multipart.py
new file mode 100644 (file)
index 0000000..c8965c1
--- /dev/null
@@ -0,0 +1,365 @@
+_begin_form = "begin_form"
+_begin_file = "begin_file"
+_cont = "cont"
+_end = "end"
+
+
+class MultiPartParser(object):
+    def __init__(
+        self,
+        stream_factory=None,
+        charset="utf-8",
+        errors="replace",
+        max_form_memory_size=None,
+        cls=None,
+        buffer_size=64 * 1024,
+    ):
+        self.charset = charset
+        self.errors = errors
+        self.max_form_memory_size = max_form_memory_size
+        self.stream_factory = (
+            default_stream_factory if stream_factory is None else stream_factory
+        )
+        self.cls = MultiDict if cls is None else cls
+
+        # make sure the buffer size is divisible by four so that we can base64
+        # decode chunk by chunk
+        assert buffer_size % 4 == 0, "buffer size has to be divisible by 4"
+        # also the buffer size has to be at least 1024 bytes long or long headers
+        # will freak out the system
+        assert buffer_size >= 1024, "buffer size has to be at least 1KB"
+
+        self.buffer_size = buffer_size
+
+    def _fix_ie_filename(self, filename):
+        """Internet Explorer 6 transmits the full file name if a file is
+        uploaded.  This function strips the full path if it thinks the
+        filename is Windows-like absolute.
+        """
+        if filename[1:3] == ":\\" or filename[:2] == "\\\\":
+            return filename.split("\\")[-1]
+        return filename
+
+    def _find_terminator(self, iterator):
+        """The terminator might have some additional newlines before it.
+        There is at least one application that sends additional newlines
+        before headers (the python setuptools package).
+        """
+        for line in iterator:
+            if not line:
+                break
+            line = line.strip()
+            if line:
+                return line
+        return b""
+
+    def fail(self, message):
+        raise ValueError(message)
+
+    def get_part_encoding(self, headers):
+        transfer_encoding = headers.get("content-transfer-encoding")
+        if (
+            transfer_encoding is not None
+            and transfer_encoding in _supported_multipart_encodings
+        ):
+            return transfer_encoding
+
+    def get_part_charset(self, headers):
+        # Figure out input charset for current part
+        content_type = headers.get("content-type")
+        if content_type:
+            mimetype, ct_params = parse_options_header(content_type)
+            return ct_params.get("charset", self.charset)
+        return self.charset
+
+    def start_file_streaming(self, filename, headers, total_content_length):
+        if isinstance(filename, bytes):
+            filename = filename.decode(self.charset, self.errors)
+        filename = self._fix_ie_filename(filename)
+        content_type = headers.get("content-type")
+        try:
+            content_length = int(headers["content-length"])
+        except (KeyError, ValueError):
+            content_length = 0
+        container = self.stream_factory(
+            total_content_length, content_type, filename, content_length
+        )
+        return filename, container
+
+    def in_memory_threshold_reached(self, bytes):
+        raise exceptions.RequestEntityTooLarge()
+
+    def validate_boundary(self, boundary):
+        if not boundary:
+            self.fail("Missing boundary")
+        if not is_valid_multipart_boundary(boundary):
+            self.fail("Invalid boundary: %s" % boundary)
+        if len(boundary) > self.buffer_size:  # pragma: no cover
+            # this should never happen because we check for a minimum size
+            # of 1024 and boundaries may not be longer than 200.  The only
+            # situation when this happens is for non debug builds where
+            # the assert is skipped.
+            self.fail("Boundary longer than buffer size")
+
+    class LineSplitter(object):
+        def __init__(self, cap=None):
+            self.buffer = b""
+            self.cap = cap
+
+        def _splitlines(self, pre, post):
+            buf = pre + post
+            rv = []
+            if not buf:
+                return rv, b""
+            lines = buf.splitlines(True)
+            iv = b""
+            for line in lines:
+                iv += line
+                while self.cap and len(iv) >= self.cap:
+                    rv.append(iv[: self.cap])
+                    iv = iv[self.cap :]
+                if line[-1:] in b"\r\n":
+                    rv.append(iv)
+                    iv = b""
+            # If this isn't the very end of the stream and what we got ends
+            # with \r, we need to hold on to it in case an \n comes next
+            if post and rv and not iv and rv[-1][-1:] == b"\r":
+                iv = rv[-1]
+                del rv[-1]
+            return rv, iv
+
+        def feed(self, data):
+            lines, self.buffer = self._splitlines(self.buffer, data)
+            if not data:
+                lines += [self.buffer]
+                if self.buffer:
+                    lines += [b""]
+            return lines
+
+    class LineParser(object):
+        def __init__(self, parent, boundary):
+            self.parent = parent
+            self.boundary = boundary
+            self._next_part = b"--" + boundary
+            self._last_part = self._next_part + b"--"
+            self._state = self._state_pre_term
+            self._output = []
+            self._headers = []
+            self._tail = b""
+            self._codec = None
+
+        def _start_content(self):
+            disposition = self._headers.get("content-disposition")
+            if disposition is None:
+                raise ValueError("Missing Content-Disposition header")
+            self.disposition, extra = parse_options_header(disposition)
+            transfer_encoding = self.parent.get_part_encoding(self._headers)
+            if transfer_encoding is not None:
+                if transfer_encoding == "base64":
+                    transfer_encoding = "base64_codec"
+                try:
+                    self._codec = codecs.lookup(transfer_encoding)
+                except Exception:
+                    raise ValueError(
+                        "Cannot decode transfer-encoding: %r" % transfer_encoding
+                    )
+            self.name = extra.get("name")
+            self.filename = extra.get("filename")
+            if self.filename is not None:
+                self._output.append(
+                    ("begin_file", (self._headers, self.name, self.filename))
+                )
+            else:
+                self._output.append(("begin_form", (self._headers, self.name)))
+            return self._state_output
+
+        def _state_done(self, line):
+            return self._state_done
+
+        def _state_output(self, line):
+            if not line:
+                raise ValueError("Unexpected end of file")
+            sline = line.rstrip()
+            if sline == self._last_part:
+                self._tail = b""
+                self._output.append(("end", None))
+                return self._state_done
+            elif sline == self._next_part:
+                self._tail = b""
+                self._output.append(("end", None))
+                self._headers = []
+                return self._state_headers
+
+            if self._codec:
+                try:
+                    line, _ = self._codec.decode(line)
+                except Exception:
+                    raise ValueError("Could not decode transfer-encoded chunk")
+
+            # We don't know yet whether we can output the final newline, so
+            # we'll save it in self._tail and output it next time.
+            tail = self._tail
+            if line[-2:] == b"\r\n":
+                self._output.append(("cont", tail + line[:-2]))
+                self._tail = line[-2:]
+            elif line[-1:] in b"\r\n":
+                self._output.append(("cont", tail + line[:-1]))
+                self._tail = line[-1:]
+            else:
+                self._output.append(("cont", tail + line))
+                self._tail = b""
+            return self._state_output
+
+        def _state_pre_term(self, line):
+            if not line:
+                raise ValueError("Unexpected end of file")
+                return self._state_pre_term
+            line = line.rstrip(b"\r\n")
+            if not line:
+                return self._state_pre_term
+            if line == self._last_part:
+                return self._state_done
+            elif line == self._next_part:
+                self._headers = []
+                return self._state_headers
+            raise ValueError("Expected boundary at start of multipart data")
+
+        def _state_headers(self, line):
+            if line is None:
+                raise ValueError("Unexpected end of file during headers")
+            line = to_native(line)
+            line, line_terminated = _line_parse(line)
+            if not line_terminated:
+                raise ValueError("Unexpected end of line in multipart header")
+            if not line:
+                self._headers = Headers(self._headers)
+                return self._start_content()
+            if line[0] in " \t" and self._headers:
+                key, value = self._headers[-1]
+                self._headers[-1] = (key, value + "\n " + line[1:])
+            else:
+                parts = line.split(":", 1)
+                if len(parts) == 2:
+                    self._headers.append((parts[0].strip(), parts[1].strip()))
+                else:
+                    raise ValueError("Malformed header")
+            return self._state_headers
+
+        def feed(self, lines):
+            self._output = []
+            s = self._state
+            for line in lines:
+                s = s(line)
+            self._state = s
+            return self._output
+
+    class PartParser(object):
+        def __init__(self, parent, content_length):
+            self.parent = parent
+            self.content_length = content_length
+            self._write = None
+            self._in_memory = 0
+            self._guard_memory = False
+
+        def _feed_one(self, event):
+            ev, data = event
+            p = self.parent
+            if ev == "begin_file":
+                self._headers, self._name, filename = data
+                self._filename, self._container = p.start_file_streaming(
+                    filename, self._headers, self.content_length
+                )
+                self._write = self._container.write
+                self._is_file = True
+                self._guard_memory = False
+            elif ev == "begin_form":
+                self._headers, self._name = data
+                self._container = []
+                self._write = self._container.append
+                self._is_file = False
+                self._guard_memory = p.max_form_memory_size is not None
+            elif ev == "cont":
+                self._write(data)
+                if self._guard_memory:
+                    self._in_memory += len(data)
+                    if self._in_memory > p.max_form_memory_size:
+                        p.in_memory_threshold_reached(self._in_memory)
+            elif ev == "end":
+                if self._is_file:
+                    self._container.seek(0)
+                    return (
+                        "file",
+                        (
+                            self._name,
+                            FileStorage(
+                                self._container,
+                                self._filename,
+                                self._name,
+                                headers=self._headers,
+                            ),
+                        ),
+                    )
+                else:
+                    part_charset = p.get_part_charset(self._headers)
+                    return (
+                        "form",
+                        (
+                            self._name,
+                            b"".join(self._container).decode(part_charset, p.errors),
+                        ),
+                    )
+
+        def feed(self, events):
+            rv = []
+            for event in events:
+                v = self._feed_one(event)
+                if v is not None:
+                    rv.append(v)
+            return rv
+
+    def parse_lines(self, file, boundary, content_length, cap_at_buffer=True):
+        """Generate parts of
+        ``('begin_form', (headers, name))``
+        ``('begin_file', (headers, name, filename))``
+        ``('cont', bytestring)``
+        ``('end', None)``
+        Always obeys the grammar
+        parts = ( begin_form cont* end |
+                  begin_file cont* end )*
+        """
+
+        line_splitter = self.LineSplitter(self.buffer_size if cap_at_buffer else None)
+        line_parser = self.LineParser(self, boundary)
+        while True:
+            buf = file.read(self.buffer_size)
+            lines = line_splitter.feed(buf)
+            parts = line_parser.feed(lines)
+            for part in parts:
+                yield part
+            if buf == b"":
+                break
+
+    def parse_parts(self, file, boundary, content_length):
+        """Generate ``('file', (name, val))`` and
+        ``('form', (name, val))`` parts.
+        """
+        line_splitter = self.LineSplitter()
+        line_parser = self.LineParser(self, boundary)
+        part_parser = self.PartParser(self, content_length)
+        while True:
+            buf = file.read(self.buffer_size)
+            lines = line_splitter.feed(buf)
+            parts = line_parser.feed(lines)
+            events = part_parser.feed(parts)
+            for event in events:
+                yield event
+            if buf == b"":
+                break
+
+    def parse(self, file, boundary, content_length):
+        formstream, filestream = tee(
+            self.parse_parts(file, boundary, content_length), 2
+        )
+        form = (p[1] for p in formstream if p[0] == "form")
+        files = (p[1] for p in filestream if p[0] == "file")
+        return self.cls(form), self.cls(files)
index a3bc5a4b61dfacf260064a76ad1ab89cd2bfe938..c3c98197b67445198b139e81c2e38f22d0248a06 100644 (file)
@@ -1,5 +1,7 @@
+from mimetypes import guess_type
 from starlette.datastructures import MutableHeaders
 from starlette.types import Receive, Send
+import aiofiles
 import json
 import typing
 import os
@@ -18,36 +20,47 @@ class Response:
     ) -> None:
         self.body = self.render(content)
         self.status_code = status_code
-        self.media_type = self.media_type if media_type is None else media_type
-        self.raw_headers = [] if headers is None else [
-            (k.lower().encode("latin-1"), v.encode("latin-1"))
-            for k, v in headers.items()
-        ]
-        self.headers = MutableHeaders(self.raw_headers)
-        self.set_default_headers()
+        if media_type is not None:
+            self.media_type = media_type
+        self.init_headers(headers)
 
     def render(self, content: typing.Any) -> bytes:
         if isinstance(content, bytes):
             return content
         return content.encode(self.charset)
 
-    def set_default_headers(self):
-        content_length = str(len(self.body)) if hasattr(self, 'body') else None
-        content_type = self.default_content_type
-
-        if content_length is not None:
-            self.headers.set_default("content-length", content_length)
-        if content_type is not None:
-            self.headers.set_default("content-type", content_type)
+    def init_headers(self, headers):
+        if headers is None:
+            raw_headers = []
+            populate_content_length = True
+            populate_content_type = True
+        else:
+            raw_headers = [
+                (k.lower().encode("latin-1"), v.encode("latin-1"))
+                for k, v in headers.items()
+            ]
+            keys = [h[0] for h in raw_headers]
+            populate_content_length = b"content-length" in keys
+            populate_content_type = b"content-type" in keys
+
+        body = getattr(self, "body", None)
+        if body is not None and populate_content_length:
+            content_length = str(len(body))
+            raw_headers.append((b"content-length", content_length.encode("latin-1")))
+
+        content_type = self.media_type
+        if content_type is not None and populate_content_type:
+            if content_type.startswith("text/"):
+                content_type += "; charset=" + self.charset
+            raw_headers.append((b"content-type", content_type.encode("latin-1")))
+
+        self.raw_headers = raw_headers
 
     @property
-    def default_content_type(self):
-        if self.media_type is None:
-            return None
-
-        if self.media_type.startswith('text/') and self.charset is not None:
-            return '%s; charset=%s' % (self.media_type, self.charset)
-        return self.media_type
+    def headers(self):
+        if not hasattr(self, "_headers"):
+            self._headers = MutableHeaders(self.raw_headers)
+        return self._headers
 
     async def __call__(self, receive: Receive, send: Send) -> None:
         await send(
@@ -92,21 +105,14 @@ class StreamingResponse(Response):
         self.body_iterator = content
         self.status_code = status_code
         self.media_type = self.media_type if media_type is None else media_type
-        self.raw_headers = [] if headers is None else [
-            (k.lower().encode("latin-1"), v.encode("latin-1"))
-            for k, v in headers.items()
-        ]
-        self.headers = MutableHeaders(self.raw_headers)
-        self.set_default_headers()
+        self.init_headers(headers)
 
     async def __call__(self, receive: Receive, send: Send) -> None:
         await send(
             {
                 "type": "http.response.start",
                 "status": self.status_code,
-                "headers": [
-                    [key.encode(), value.encode()] for key, value in self.headers
-                ],
+                "headers": self.raw_headers,
             }
         )
         async for chunk in self.body_iterator:
@@ -115,22 +121,41 @@ class StreamingResponse(Response):
             await send({"type": "http.response.body", "body": chunk, "more_body": True})
         await send({"type": "http.response.body", "body": b"", "more_body": False})
 
-#
-# class FileResponse:
-#     def __init__(
-#         self,
-#         path: str,
-#         headers: dict = None,
-#         media_type: str = None,
-#         filename: str = None
-#     ) -> None:
-#         self.path = path
-#         self.status_code = 200
-#         if media_type is not None:
-#             self.media_type = media_type
-#         if filename is not None:
-#             self.filename = filename
-#         else:
-#             self.filename = os.path.basename(path)
-#
-#         self.set_default_headers(headers)
+
+class FileResponse(Response):
+    chunk_size = 4096
+
+    def __init__(
+        self,
+        path: str,
+        headers: dict = None,
+        media_type: str = None,
+        filename: str = None,
+    ) -> None:
+        self.path = path
+        self.status_code = 200
+        self.filename = filename
+        if media_type is None:
+            media_type = guess_type(filename or path)[0] or "text/plain"
+        self.media_type = media_type
+        self.init_headers(headers)
+        if self.filename is not None:
+            content_disposition = 'attachment; filename="{}"'.format(self.filename)
+            self.headers.setdefault("content-disposition", content_disposition)
+
+    async def __call__(self, receive: Receive, send: Send) -> None:
+        await send(
+            {
+                "type": "http.response.start",
+                "status": self.status_code,
+                "headers": self.raw_headers,
+            }
+        )
+        async with aiofiles.open(self.path, mode="rb") as file:
+            more_body = True
+            while more_body:
+                chunk = await file.read(self.chunk_size)
+                more_body = len(chunk) == self.chunk_size
+                await send(
+                    {"type": "http.response.body", "body": chunk, "more_body": False}
+                )