]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Add set_default
authorTom Christie <tom@tomchristie.com>
Wed, 11 Jul 2018 12:16:03 +0000 (13:16 +0100)
committerTom Christie <tom@tomchristie.com>
Wed, 11 Jul 2018 12:16:03 +0000 (13:16 +0100)
starlette/datastructures.py
starlette/response.py

index cfb24bf923e8468fe26994b2a20b99e8e57dca12..ba601764da1b906a84da6d610e377c91df9f0392 100644 (file)
@@ -195,3 +195,20 @@ class MutableHeaders(Headers):
             del self._list[idx]
 
         self._list.append((set_key, set_value))
+
+    def set_default(self, key: str, value: str):
+        set_key = key.lower().encode("latin-1")
+        set_value = value.encode("latin-1")
+
+        is_set = False
+        pop_indexes = []
+        for idx, (item_key, item_value) in enumerate(self._list):
+            if item_key == set_key:
+                if not is_set:
+                    is_set = True
+                    self._list[idx] = set_value
+                else:
+                    pop_indexes.append(idx)
+
+        for idx in reversed(pop_indexes):
+            del self._list[idx]
index 28d0efd8bff2f916bb779ab412d65e1c56362f14..a3bc5a4b61dfacf260064a76ad1ab89cd2bfe938 100644 (file)
@@ -2,6 +2,7 @@ from starlette.datastructures import MutableHeaders
 from starlette.types import Receive, Send
 import json
 import typing
+import os
 
 
 class Response:
@@ -17,46 +18,36 @@ class Response:
     ) -> None:
         self.body = self.render(content)
         self.status_code = status_code
-        if media_type is not None:
-            self.media_type = media_type
-        self.set_default_headers(headers)
+        self.media_type = self.media_type if media_type is None else media_type
+        self.raw_headers = [] if headers is None else [
+            (k.lower().encode("latin-1"), v.encode("latin-1"))
+            for k, v in headers.items()
+        ]
+        self.headers = MutableHeaders(self.raw_headers)
+        self.set_default_headers()
 
     def render(self, content: typing.Any) -> bytes:
         if isinstance(content, bytes):
             return content
         return content.encode(self.charset)
 
-    def set_default_headers(self, headers: dict = None):
-        if headers is None:
-            raw_headers = []
-            missing_content_length = True
-            missing_content_type = True
-        else:
-            raw_headers = [
-                (k.lower().encode("latin-1"), v.encode("latin-1"))
-                for k, v in headers.items()
-            ]
-            missing_content_length = "content-length" not in headers
-            missing_content_type = "content-type" not in headers
-
-        if missing_content_length:
-            content_length = str(len(self.body)).encode()
-            raw_headers.append((b"content-length", content_length))
-
-        if self.media_type is not None and missing_content_type:
-            content_type = self.media_type
-            if content_type.startswith("text/") and self.charset is not None:
-                content_type += "; charset=%s" % self.charset
-            content_type_value = content_type.encode("latin-1")
-            raw_headers.append((b"content-type", content_type_value))
-
-        self.raw_headers = raw_headers
+    def set_default_headers(self):
+        content_length = str(len(self.body)) if hasattr(self, 'body') else None
+        content_type = self.default_content_type
+
+        if content_length is not None:
+            self.headers.set_default("content-length", content_length)
+        if content_type is not None:
+            self.headers.set_default("content-type", content_type)
 
     @property
-    def headers(self):
-        if not hasattr(self, "_headers"):
-            self._headers = MutableHeaders(self.raw_headers)
-        return self._headers
+    def default_content_type(self):
+        if self.media_type is None:
+            return None
+
+        if self.media_type.startswith('text/') and self.charset is not None:
+            return '%s; charset=%s' % (self.media_type, self.charset)
+        return self.media_type
 
     async def __call__(self, receive: Receive, send: Send) -> None:
         await send(
@@ -100,9 +91,13 @@ class StreamingResponse(Response):
     ) -> None:
         self.body_iterator = content
         self.status_code = status_code
-        if media_type is not None:
-            self.media_type = media_type
-        self.set_default_headers(headers)
+        self.media_type = self.media_type if media_type is None else media_type
+        self.raw_headers = [] if headers is None else [
+            (k.lower().encode("latin-1"), v.encode("latin-1"))
+            for k, v in headers.items()
+        ]
+        self.headers = MutableHeaders(self.raw_headers)
+        self.set_default_headers()
 
     async def __call__(self, receive: Receive, send: Send) -> None:
         await send(
@@ -120,22 +115,22 @@ class StreamingResponse(Response):
             await send({"type": "http.response.body", "body": chunk, "more_body": True})
         await send({"type": "http.response.body", "body": b"", "more_body": False})
 
-    def set_default_headers(self, headers: dict = None):
-        if headers is None:
-            raw_headers = []
-            missing_content_type = True
-        else:
-            raw_headers = [
-                (k.lower().encode("latin-1"), v.encode("latin-1"))
-                for k, v in headers.items()
-            ]
-            missing_content_type = "content-type" not in headers
-
-        if self.media_type is not None and missing_content_type:
-            content_type = self.media_type
-            if content_type.startswith("text/") and self.charset is not None:
-                content_type += "; charset=%s" % self.charset
-            content_type_value = content_type.encode("latin-1")
-            raw_headers.append((b"content-type", content_type_value))
-
-        self.raw_headers = raw_headers
+#
+# class FileResponse:
+#     def __init__(
+#         self,
+#         path: str,
+#         headers: dict = None,
+#         media_type: str = None,
+#         filename: str = None
+#     ) -> None:
+#         self.path = path
+#         self.status_code = 200
+#         if media_type is not None:
+#             self.media_type = media_type
+#         if filename is not None:
+#             self.filename = filename
+#         else:
+#             self.filename = os.path.basename(path)
+#
+#         self.set_default_headers(headers)