]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Add StaticFile and StaticFiles ASGI applications
authorTom Christie <tom@tomchristie.com>
Thu, 12 Jul 2018 12:13:44 +0000 (13:13 +0100)
committerTom Christie <tom@tomchristie.com>
Thu, 12 Jul 2018 12:13:44 +0000 (13:13 +0100)
starlette/response.py
starlette/staticfiles.py [new file with mode: 0644]
tests/test_response.py
tests/test_staticfiles.py [new file with mode: 0644]

index d3294713d305e0953edff5b9f31f93577cb3eaa7..a1f4010ed7a64f7bf536a4b73b38663334bc3ae8 100644 (file)
@@ -6,6 +6,7 @@ from starlette.types import Receive, Send
 import aiofiles
 import json
 import hashlib
+import os
 import stat
 import typing
 
@@ -134,6 +135,7 @@ class FileResponse(Response):
         headers: dict = None,
         media_type: str = None,
         filename: str = None,
+        stat_result: os.stat_result = None
     ) -> None:
         self.path = path
         self.status_code = 200
@@ -145,6 +147,9 @@ class FileResponse(Response):
         if self.filename is not None:
             content_disposition = 'attachment; filename="{}"'.format(self.filename)
             self.headers.setdefault("content-disposition", content_disposition)
+        self.stat_result = stat_result
+        if stat_result is not None:
+            self.set_stat_headers(stat_result)
 
     def set_stat_headers(self, stat_result):
         content_length = str(stat_result.st_size)
@@ -156,8 +161,9 @@ class FileResponse(Response):
         self.headers.setdefault("etag", etag)
 
     async def __call__(self, receive: Receive, send: Send) -> None:
-        stat_result = await aio_stat(self.path)
-        self.set_stat_headers(stat_result)
+        if self.stat_result is None:
+            stat_result = await aio_stat(self.path)
+            self.set_stat_headers(stat_result)
         await send(
             {
                 "type": "http.response.start",
diff --git a/starlette/staticfiles.py b/starlette/staticfiles.py
new file mode 100644 (file)
index 0000000..d596415
--- /dev/null
@@ -0,0 +1,50 @@
+from starlette import PlainTextResponse, FileResponse
+from aiofiles.os import stat as aio_stat
+import os
+import stat
+
+
+class StaticFile:
+    def __init__(self, *, path):
+        self.path = path
+
+    def __call__(self, scope):
+        if scope['method'] not in ('GET', 'HEAD'):
+            return PlainTextResponse('Method not allowed', status_code=406)
+        return _StaticFileResponder(scope, path=self.path, allow_404=False)
+
+
+class StaticFiles:
+    def __init__(self, *, directory):
+        self.directory = directory
+
+    def __call__(self, scope):
+        if scope['method'] not in ('GET', 'HEAD'):
+            return PlainTextResponse('Method not allowed', status_code=406)
+        split_path = scope['path'].split('/')
+        path = os.path.join(self.directory, *split_path)
+        return _StaticFileResponder(scope, path=path, allow_404=True)
+
+
+class _StaticFileResponder:
+    def __init__(self, scope, path, allow_404):
+        self.scope = scope
+        self.path = path
+        self.allow_404 = allow_404
+
+    async def __call__(self, receive, send):
+        try:
+            stat_result = await aio_stat(self.path)
+        except FileNotFoundError:
+            if not self.allow_404:
+                raise RuntimeError("StaticFile at path '%s' does not exist." % self.path)
+            response = PlainTextResponse('Not found', status_code=404)
+        else:
+            mode = stat_result.st_mode
+            if stat.S_ISREG(mode) or stat.S_ISLNK(mode):
+                response = FileResponse(self.path, stat_result=stat_result)
+            else:
+                if not self.allow_404:
+                    raise RuntimeError("StaticFile at path '%s' is not a file." % self.path)
+                response = PlainTextResponse('Not found', status_code=404)
+        await response(receive, send)
index c22602b0aa1ecd1846fe9d034b0942b24c827d80..670aa6ca178ef0321267de5d9ee9f6da74283f65 100644 (file)
@@ -1,5 +1,6 @@
 from starlette import FileResponse, Response, StreamingResponse, TestClient
 import asyncio
+import os
 
 
 def test_text_response():
@@ -68,11 +69,12 @@ def test_response_headers():
 
 
 def test_file_response(tmpdir):
-    with open("xyz", "wb") as file:
+    path = os.path.join(tmpdir, "xyz")
+    with open(path, "wb") as file:
         file.write(b"<file content>")
 
     def app(scope):
-        return FileResponse(path="xyz", filename="example.png")
+        return FileResponse(path=path, filename="example.png")
 
     client = TestClient(app)
     response = client.get("/")
diff --git a/tests/test_staticfiles.py b/tests/test_staticfiles.py
new file mode 100644 (file)
index 0000000..5565a43
--- /dev/null
@@ -0,0 +1,93 @@
+from starlette import TestClient
+from starlette.staticfiles import StaticFile, StaticFiles
+import os
+import pytest
+
+
+def test_staticfile(tmpdir):
+    path = os.path.join(tmpdir, "example.txt")
+    with open(path, "w") as file:
+        file.write("<file content>")
+
+    app = StaticFile(path=path)
+    client = TestClient(app)
+    response = client.get("/")
+    assert response.status_code == 200
+    assert response.text == '<file content>'
+
+
+def test_staticfile_post(tmpdir):
+    path = os.path.join(tmpdir, "example.txt")
+    with open(path, "w") as file:
+        file.write("<file content>")
+
+    app = StaticFile(path=path)
+    client = TestClient(app)
+    response = client.post("/")
+    assert response.status_code == 406
+    assert response.text == 'Method not allowed'
+
+
+def test_staticfile_with_directory_raises_error(tmpdir):
+    app = StaticFile(path=tmpdir)
+    client = TestClient(app)
+    with pytest.raises(RuntimeError) as exc:
+        response = client.get("/")
+    assert 'is not a file' in str(exc)
+
+
+def test_staticfile_with_missing_file_raises_error(tmpdir):
+    path = os.path.join(tmpdir, '404.txt')
+    app = StaticFile(path=path)
+    client = TestClient(app)
+    with pytest.raises(RuntimeError) as exc:
+        response = client.get("/")
+    assert 'does not exist' in str(exc)
+
+
+def test_staticfiles(tmpdir):
+    path = os.path.join(tmpdir, "example.txt")
+    with open(path, "w") as file:
+        file.write("<file content>")
+
+    app = StaticFiles(directory=tmpdir)
+    client = TestClient(app)
+    response = client.get("/example.txt")
+    assert response.status_code == 200
+    assert response.text == '<file content>'
+
+
+def test_staticfiles_post(tmpdir):
+    path = os.path.join(tmpdir, "example.txt")
+    with open(path, "w") as file:
+        file.write("<file content>")
+
+    app = StaticFiles(directory=tmpdir)
+    client = TestClient(app)
+    response = client.post("/example.txt")
+    assert response.status_code == 406
+    assert response.text == 'Method not allowed'
+
+
+def test_staticfiles_with_directory_returns_404(tmpdir):
+    path = os.path.join(tmpdir, "example.txt")
+    with open(path, "w") as file:
+        file.write("<file content>")
+
+    app = StaticFiles(directory=tmpdir)
+    client = TestClient(app)
+    response = client.get("/")
+    assert response.status_code == 404
+    assert response.text == 'Not found'
+
+
+def test_staticfiles_with_missing_file_returns_404(tmpdir):
+    path = os.path.join(tmpdir, "example.txt")
+    with open(path, "w") as file:
+        file.write("<file content>")
+
+    app = StaticFiles(directory=tmpdir)
+    client = TestClient(app)
+    response = client.get("/404.txt")
+    assert response.status_code == 404
+    assert response.text == 'Not found'