import aiofiles
import json
import hashlib
+import os
import stat
import typing
headers: dict = None,
media_type: str = None,
filename: str = None,
+ stat_result: os.stat_result = None
) -> None:
self.path = path
self.status_code = 200
if self.filename is not None:
content_disposition = 'attachment; filename="{}"'.format(self.filename)
self.headers.setdefault("content-disposition", content_disposition)
+ self.stat_result = stat_result
+ if stat_result is not None:
+ self.set_stat_headers(stat_result)
def set_stat_headers(self, stat_result):
content_length = str(stat_result.st_size)
self.headers.setdefault("etag", etag)
async def __call__(self, receive: Receive, send: Send) -> None:
- stat_result = await aio_stat(self.path)
- self.set_stat_headers(stat_result)
+ if self.stat_result is None:
+ stat_result = await aio_stat(self.path)
+ self.set_stat_headers(stat_result)
await send(
{
"type": "http.response.start",
--- /dev/null
+from starlette import PlainTextResponse, FileResponse
+from aiofiles.os import stat as aio_stat
+import os
+import stat
+
+
+class StaticFile:
+ def __init__(self, *, path):
+ self.path = path
+
+ def __call__(self, scope):
+ if scope['method'] not in ('GET', 'HEAD'):
+ return PlainTextResponse('Method not allowed', status_code=406)
+ return _StaticFileResponder(scope, path=self.path, allow_404=False)
+
+
+class StaticFiles:
+ def __init__(self, *, directory):
+ self.directory = directory
+
+ def __call__(self, scope):
+ if scope['method'] not in ('GET', 'HEAD'):
+ return PlainTextResponse('Method not allowed', status_code=406)
+ split_path = scope['path'].split('/')
+ path = os.path.join(self.directory, *split_path)
+ return _StaticFileResponder(scope, path=path, allow_404=True)
+
+
+class _StaticFileResponder:
+ def __init__(self, scope, path, allow_404):
+ self.scope = scope
+ self.path = path
+ self.allow_404 = allow_404
+
+ async def __call__(self, receive, send):
+ try:
+ stat_result = await aio_stat(self.path)
+ except FileNotFoundError:
+ if not self.allow_404:
+ raise RuntimeError("StaticFile at path '%s' does not exist." % self.path)
+ response = PlainTextResponse('Not found', status_code=404)
+ else:
+ mode = stat_result.st_mode
+ if stat.S_ISREG(mode) or stat.S_ISLNK(mode):
+ response = FileResponse(self.path, stat_result=stat_result)
+ else:
+ if not self.allow_404:
+ raise RuntimeError("StaticFile at path '%s' is not a file." % self.path)
+ response = PlainTextResponse('Not found', status_code=404)
+ await response(receive, send)
--- /dev/null
+from starlette import TestClient
+from starlette.staticfiles import StaticFile, StaticFiles
+import os
+import pytest
+
+
+def test_staticfile(tmpdir):
+ path = os.path.join(tmpdir, "example.txt")
+ with open(path, "w") as file:
+ file.write("<file content>")
+
+ app = StaticFile(path=path)
+ client = TestClient(app)
+ response = client.get("/")
+ assert response.status_code == 200
+ assert response.text == '<file content>'
+
+
+def test_staticfile_post(tmpdir):
+ path = os.path.join(tmpdir, "example.txt")
+ with open(path, "w") as file:
+ file.write("<file content>")
+
+ app = StaticFile(path=path)
+ client = TestClient(app)
+ response = client.post("/")
+ assert response.status_code == 406
+ assert response.text == 'Method not allowed'
+
+
+def test_staticfile_with_directory_raises_error(tmpdir):
+ app = StaticFile(path=tmpdir)
+ client = TestClient(app)
+ with pytest.raises(RuntimeError) as exc:
+ response = client.get("/")
+ assert 'is not a file' in str(exc)
+
+
+def test_staticfile_with_missing_file_raises_error(tmpdir):
+ path = os.path.join(tmpdir, '404.txt')
+ app = StaticFile(path=path)
+ client = TestClient(app)
+ with pytest.raises(RuntimeError) as exc:
+ response = client.get("/")
+ assert 'does not exist' in str(exc)
+
+
+def test_staticfiles(tmpdir):
+ path = os.path.join(tmpdir, "example.txt")
+ with open(path, "w") as file:
+ file.write("<file content>")
+
+ app = StaticFiles(directory=tmpdir)
+ client = TestClient(app)
+ response = client.get("/example.txt")
+ assert response.status_code == 200
+ assert response.text == '<file content>'
+
+
+def test_staticfiles_post(tmpdir):
+ path = os.path.join(tmpdir, "example.txt")
+ with open(path, "w") as file:
+ file.write("<file content>")
+
+ app = StaticFiles(directory=tmpdir)
+ client = TestClient(app)
+ response = client.post("/example.txt")
+ assert response.status_code == 406
+ assert response.text == 'Method not allowed'
+
+
+def test_staticfiles_with_directory_returns_404(tmpdir):
+ path = os.path.join(tmpdir, "example.txt")
+ with open(path, "w") as file:
+ file.write("<file content>")
+
+ app = StaticFiles(directory=tmpdir)
+ client = TestClient(app)
+ response = client.get("/")
+ assert response.status_code == 404
+ assert response.text == 'Not found'
+
+
+def test_staticfiles_with_missing_file_returns_404(tmpdir):
+ path = os.path.join(tmpdir, "example.txt")
+ with open(path, "w") as file:
+ file.write("<file content>")
+
+ app = StaticFiles(directory=tmpdir)
+ client = TestClient(app)
+ response = client.get("/404.txt")
+ assert response.status_code == 404
+ assert response.text == 'Not found'