]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Add support for server push 629/head
authorJeremy Lainé <jeremy.laine@m4x.org>
Sun, 1 Sep 2019 22:36:19 +0000 (00:36 +0200)
committerJeremy Lainé <jeremy.laine@m4x.org>
Mon, 2 Sep 2019 12:01:30 +0000 (14:01 +0200)
This adds support for HTTP/2 and HTTP/3 server push by adding a
Request.send_push_promise method, which signals to push-enabled
servers that a push should be sent.

docs/server-push.md [new file with mode: 0644]
mkdocs.yml
starlette/requests.py
starlette/routing.py
tests/test_requests.py

diff --git a/docs/server-push.md b/docs/server-push.md
new file mode 100644 (file)
index 0000000..ba4d31a
--- /dev/null
@@ -0,0 +1,34 @@
+
+Starlette includes support for HTTP/2 and HTTP/3 server push, making it
+possible to push resources to the client to speed up page load times.
+
+### `Request.send_push_promise`
+
+Used to initiate a server push for a resource. If server push is not available
+this method does nothing.
+
+Signature: `send_push_promise(path)`
+
+* `path` - A string denoting the path of the resource.
+
+```python
+from starlette.applications import Starlette
+from starlette.responses import HTMLResponse
+from starlette.staticfiles import StaticFiles
+
+app = Starlette()
+
+
+@app.route("/")
+async def homepage(request):
+    """
+    Homepage which uses server push to deliver the stylesheet.
+    """
+    await request.send_push_promise("/static/style.css")
+    return HTMLResponse(
+        '<html><head><link rel="stylesheet" href="/static/style.css"/></head></html>'
+    )
+
+
+app.mount("/static", StaticFiles(directory="static"))
+```
index 2ec6a97c642b6dde0ae31f437fea6c6e446a293f..d35af5d1369212afacd7e4670fedc2375a42b6f9 100644 (file)
@@ -25,6 +25,7 @@ nav:
     - API Schemas: 'schemas.md'
     - Events: 'events.md'
     - Background Tasks: 'background.md'
+    - Server Push: 'server-push.md'
     - Exceptions: 'exceptions.md'
     - Configuration: 'config.md'
     - Test Client: 'testclient.md'
index a71c90753412fac0ae53c9a5791bd56deeaca875..a831a8936827b6d7231bbaefc4b2f68011f71d63 100644 (file)
@@ -6,7 +6,7 @@ from collections.abc import Mapping
 
 from starlette.datastructures import URL, Address, FormData, Headers, QueryParams, State
 from starlette.formparsers import FormParser, MultiPartParser
-from starlette.types import Message, Receive, Scope
+from starlette.types import Message, Receive, Scope, Send
 
 try:
     from multipart.multipart import parse_options_header
@@ -14,6 +14,15 @@ except ImportError:  # pragma: nocover
     parse_options_header = None  # type: ignore
 
 
+SERVER_PUSH_HEADERS_TO_COPY = {
+    "accept",
+    "accept-encoding",
+    "accept-language",
+    "cache-control",
+    "user-agent",
+}
+
+
 class ClientDisconnect(Exception):
     pass
 
@@ -121,11 +130,18 @@ async def empty_receive() -> Message:
     raise RuntimeError("Receive channel has not been made available")
 
 
+async def empty_send(message: Message) -> None:
+    raise RuntimeError("Send channel has not been made available")
+
+
 class Request(HTTPConnection):
-    def __init__(self, scope: Scope, receive: Receive = empty_receive):
+    def __init__(
+        self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send
+    ):
         super().__init__(scope)
         assert scope["type"] == "http"
         self._receive = receive
+        self._send = send
         self._stream_consumed = False
         self._is_disconnected = False
 
@@ -206,3 +222,15 @@ class Request(HTTPConnection):
                 self._is_disconnected = True
 
         return self._is_disconnected
+
+    async def send_push_promise(self, path: str) -> None:
+        if "http.response.push" in self.scope.get("extensions", {}):
+            raw_headers = []
+            for name in SERVER_PUSH_HEADERS_TO_COPY:
+                for value in self.headers.getlist(name):
+                    raw_headers.append(
+                        (name.encode("latin-1"), value.encode("latin-1"))
+                    )
+            await self._send(
+                {"type": "http.response.push", "path": path, "headers": raw_headers}
+            )
index 2c1f815db27958bee31548b32bf8e9f41a605b3d..4f955374322185cdd6d21b3abb6043bcf6b7627b 100644 (file)
@@ -36,7 +36,7 @@ def request_response(func: typing.Callable) -> ASGIApp:
     is_coroutine = asyncio.iscoroutinefunction(func)
 
     async def app(scope: Scope, receive: Receive, send: Send) -> None:
-        request = Request(scope, receive=receive)
+        request = Request(scope, receive=receive, send=send)
         if is_coroutine:
             response = await func(request)
         else:
index 03c74c4515f3b189c0b25dd080c2a63894e779bb..defdf3a45562c4e7cb25ab2dfc531a8301c0d963 100644 (file)
@@ -300,3 +300,61 @@ def test_chunked_encoding():
 
     response = client.post("/", data=post_body())
     assert response.json() == {"body": "foobar"}
+
+
+def test_request_send_push_promise():
+    async def app(scope, receive, send):
+        # the server is push-enabled
+        scope["extensions"]["http.response.push"] = {}
+
+        request = Request(scope, receive, send)
+        await request.send_push_promise("/style.css")
+
+        response = JSONResponse({"json": "OK"})
+        await response(scope, receive, send)
+
+    client = TestClient(app)
+    response = client.get("/")
+    assert response.json() == {"json": "OK"}
+
+
+def test_request_send_push_promise_without_push_extension():
+    """
+    If server does not support the `http.response.push` extension,
+    .send_push_promise() does nothing.
+    """
+
+    async def app(scope, receive, send):
+        request = Request(scope)
+        await request.send_push_promise("/style.css")
+
+        response = JSONResponse({"json": "OK"})
+        await response(scope, receive, send)
+
+    client = TestClient(app)
+    response = client.get("/")
+    assert response.json() == {"json": "OK"}
+
+
+def test_request_send_push_promise_without_setting_send():
+    """
+    If Request is instantiated without the send channel, then
+    .send_push_promise() is not available.
+    """
+
+    async def app(scope, receive, send):
+        # the server is push-enabled
+        scope["extensions"]["http.response.push"] = {}
+
+        data = "OK"
+        request = Request(scope)
+        try:
+            await request.send_push_promise("/style.css")
+        except RuntimeError:
+            data = "Send channel not available"
+        response = JSONResponse({"json": data})
+        await response(scope, receive, send)
+
+    client = TestClient(app)
+    response = client.get("/")
+    assert response.json() == {"json": "Send channel not available"}