]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
handle staticfiles OSError exceptions (#1220)
authorAmin Alaee <mohammadamin.alaee@gmail.com>
Tue, 5 Oct 2021 06:01:10 +0000 (09:31 +0330)
committerGitHub <noreply@github.com>
Tue, 5 Oct 2021 06:01:10 +0000 (09:31 +0330)
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
starlette/staticfiles.py
tests/test_staticfiles.py

index 33ea0b0337602f8805ee7665c9a7ccb9a83d6dc2..39a6972609202f27b13bf7591d4d71d01d99e7a8 100644 (file)
@@ -7,12 +7,8 @@ from email.utils import parsedate
 import anyio
 
 from starlette.datastructures import URL, Headers
-from starlette.responses import (
-    FileResponse,
-    PlainTextResponse,
-    RedirectResponse,
-    Response,
-)
+from starlette.exceptions import HTTPException
+from starlette.responses import FileResponse, RedirectResponse, Response
 from starlette.types import Receive, Scope, Send
 
 PathLike = typing.Union[str, "os.PathLike[str]"]
@@ -109,9 +105,30 @@ class StaticFiles:
         Returns an HTTP response, given the incoming path, method and request headers.
         """
         if scope["method"] not in ("GET", "HEAD"):
-            return PlainTextResponse("Method Not Allowed", status_code=405)
+            raise HTTPException(status_code=405)
 
-        full_path, stat_result = await self.lookup_path(path)
+        try:
+            full_path, stat_result = await anyio.to_thread.run_sync(
+                self.lookup_path, path
+            )
+        except (FileNotFoundError, NotADirectoryError):
+            if self.html:
+                # Check for '404.html' if we're in HTML mode.
+                full_path, stat_result = await anyio.to_thread.run_sync(
+                    self.lookup_path, "404.html"
+                )
+                if stat_result and stat.S_ISREG(stat_result.st_mode):
+                    return FileResponse(
+                        full_path,
+                        stat_result=stat_result,
+                        method=scope["method"],
+                        status_code=404,
+                    )
+            raise HTTPException(status_code=404)
+        except PermissionError:
+            raise HTTPException(status_code=401)
+        except OSError:
+            raise
 
         if stat_result and stat.S_ISREG(stat_result.st_mode):
             # We have a static file to serve.
@@ -121,7 +138,9 @@ class StaticFiles:
             # We're in HTML mode, and have got a directory URL.
             # Check if we have 'index.html' file to serve.
             index_path = os.path.join(path, "index.html")
-            full_path, stat_result = await self.lookup_path(index_path)
+            full_path, stat_result = await anyio.to_thread.run_sync(
+                self.lookup_path, index_path
+            )
             if stat_result is not None and stat.S_ISREG(stat_result.st_mode):
                 if not scope["path"].endswith("/"):
                     # Directory URLs should redirect to always end in "/".
@@ -130,20 +149,9 @@ class StaticFiles:
                     return RedirectResponse(url=url)
                 return self.file_response(full_path, stat_result, scope)
 
-        if self.html:
-            # Check for '404.html' if we're in HTML mode.
-            full_path, stat_result = await self.lookup_path("404.html")
-            if stat_result is not None and stat.S_ISREG(stat_result.st_mode):
-                return FileResponse(
-                    full_path,
-                    stat_result=stat_result,
-                    method=scope["method"],
-                    status_code=404,
-                )
-
-        return PlainTextResponse("Not Found", status_code=404)
+        raise HTTPException(status_code=404)
 
-    async def lookup_path(
+    def lookup_path(
         self, path: str
     ) -> typing.Tuple[str, typing.Optional[os.stat_result]]:
         for directory in self.all_directories:
@@ -153,11 +161,7 @@ class StaticFiles:
                 # Don't allow misbehaving clients to break out of the static files
                 # directory.
                 continue
-            try:
-                stat_result = await anyio.to_thread.run_sync(os.stat, full_path)
-                return full_path, stat_result
-            except FileNotFoundError:
-                pass
+            return full_path, os.stat(full_path)
         return "", None
 
     def file_response(
index d5ec1afc5e6b3f00818e70ec883c691016328469..48fdaf1a54a5a5223847be8a2353a6df0ed2a3d0 100644 (file)
@@ -1,11 +1,13 @@
 import os
 import pathlib
+import stat
 import time
 
 import anyio
 import pytest
 
 from starlette.applications import Starlette
+from starlette.exceptions import HTTPException
 from starlette.requests import Request
 from starlette.routing import Mount
 from starlette.staticfiles import StaticFiles
@@ -71,8 +73,10 @@ def test_staticfiles_post(tmpdir, test_client_factory):
     with open(path, "w") as file:
         file.write("<file content>")
 
-    app = StaticFiles(directory=tmpdir)
+    routes = [Mount("/", app=StaticFiles(directory=tmpdir), name="static")]
+    app = Starlette(routes=routes)
     client = test_client_factory(app)
+
     response = client.post("/example.txt")
     assert response.status_code == 405
     assert response.text == "Method Not Allowed"
@@ -83,8 +87,10 @@ def test_staticfiles_with_directory_returns_404(tmpdir, test_client_factory):
     with open(path, "w") as file:
         file.write("<file content>")
 
-    app = StaticFiles(directory=tmpdir)
+    routes = [Mount("/", app=StaticFiles(directory=tmpdir), name="static")]
+    app = Starlette(routes=routes)
     client = test_client_factory(app)
+
     response = client.get("/")
     assert response.status_code == 404
     assert response.text == "Not Found"
@@ -95,8 +101,10 @@ def test_staticfiles_with_missing_file_returns_404(tmpdir, test_client_factory):
     with open(path, "w") as file:
         file.write("<file content>")
 
-    app = StaticFiles(directory=tmpdir)
+    routes = [Mount("/", app=StaticFiles(directory=tmpdir), name="static")]
+    app = Starlette(routes=routes)
     client = test_client_factory(app)
+
     response = client.get("/404.txt")
     assert response.status_code == 404
     assert response.text == "Not Found"
@@ -136,11 +144,15 @@ def test_staticfiles_config_check_occurs_only_once(tmpdir, test_client_factory):
     app = StaticFiles(directory=tmpdir)
     client = test_client_factory(app)
     assert not app.config_checked
-    client.get("/")
-    assert app.config_checked
-    client.get("/")
+
+    with pytest.raises(HTTPException):
+        client.get("/")
+
     assert app.config_checked
 
+    with pytest.raises(HTTPException):
+        client.get("/")
+
 
 def test_staticfiles_prevents_breaking_out_of_directory(tmpdir):
     directory = os.path.join(tmpdir, "foo")
@@ -154,9 +166,12 @@ def test_staticfiles_prevents_breaking_out_of_directory(tmpdir):
     # We can't test this with 'requests', so we test the app directly here.
     path = app.get_path({"path": "/../example.txt"})
     scope = {"method": "GET"}
-    response = anyio.run(app.get_response, path, scope)
-    assert response.status_code == 404
-    assert response.body == b"Not Found"
+
+    with pytest.raises(HTTPException) as exc_info:
+        anyio.run(app.get_response, path, scope)
+
+    assert exc_info.value.status_code == 404
+    assert exc_info.value.detail == "Not Found"
 
 
 def test_staticfiles_never_read_file_for_head_method(tmpdir, test_client_factory):
@@ -284,3 +299,70 @@ def test_staticfiles_cache_invalidation_for_deleted_file_html_mode(
     )
     assert resp_deleted.status_code == 404
     assert resp_deleted.text == "<p>404 file</p>"
+
+
+def test_staticfiles_with_invalid_dir_permissions_returns_401(
+    tmpdir, test_client_factory
+):
+    path = os.path.join(tmpdir, "example.txt")
+    with open(path, "w") as file:
+        file.write("<file content>")
+
+    os.chmod(tmpdir, stat.S_IRWXO)
+
+    routes = [Mount("/", app=StaticFiles(directory=tmpdir), name="static")]
+    app = Starlette(routes=routes)
+    client = test_client_factory(app)
+
+    response = client.get("/example.txt")
+    assert response.status_code == 401
+    assert response.text == "Unauthorized"
+
+
+def test_staticfiles_with_missing_dir_returns_404(tmpdir, test_client_factory):
+    path = os.path.join(tmpdir, "example.txt")
+    with open(path, "w") as file:
+        file.write("<file content>")
+
+    routes = [Mount("/", app=StaticFiles(directory=tmpdir), name="static")]
+    app = Starlette(routes=routes)
+    client = test_client_factory(app)
+
+    response = client.get("/foo/example.txt")
+    assert response.status_code == 404
+    assert response.text == "Not Found"
+
+
+def test_staticfiles_access_file_as_dir_returns_404(tmpdir, test_client_factory):
+    path = os.path.join(tmpdir, "example.txt")
+    with open(path, "w") as file:
+        file.write("<file content>")
+
+    routes = [Mount("/", app=StaticFiles(directory=tmpdir), name="static")]
+    app = Starlette(routes=routes)
+    client = test_client_factory(app)
+
+    response = client.get("/example.txt/foo")
+    assert response.status_code == 404
+    assert response.text == "Not Found"
+
+
+def test_staticfiles_unhandled_os_error_returns_500(
+    tmpdir, test_client_factory, monkeypatch
+):
+    def mock_timeout(*args, **kwargs):
+        raise TimeoutError
+
+    path = os.path.join(tmpdir, "example.txt")
+    with open(path, "w") as file:
+        file.write("<file content>")
+
+    routes = [Mount("/", app=StaticFiles(directory=tmpdir), name="static")]
+    app = Starlette(routes=routes)
+    client = test_client_factory(app, raise_server_exceptions=False)
+
+    monkeypatch.setattr("starlette.staticfiles.StaticFiles.lookup_path", mock_timeout)
+
+    response = client.get("/example.txt")
+    assert response.status_code == 500
+    assert response.text == "Internal Server Error"