]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Fix Staticfiles `404.html` in HTML mode (#1314)
authorEugene Mayer <white@pfel.ru>
Tue, 19 Oct 2021 09:27:35 +0000 (12:27 +0300)
committerGitHub <noreply@github.com>
Tue, 19 Oct 2021 09:27:35 +0000 (11:27 +0200)
Co-authored-by: Amin Alaee <mohammadamin.alaee@gmail.com>
starlette/staticfiles.py
tests/test_staticfiles.py

index 39a6972609202f27b13bf7591d4d71d01d99e7a8..f7057539fa1021d2f984df54ff3175f6e077ac81 100644 (file)
@@ -111,20 +111,6 @@ class StaticFiles:
             full_path, stat_result = await anyio.to_thread.run_sync(
                 self.lookup_path, path
             )
-        except (FileNotFoundError, NotADirectoryError):
-            if self.html:
-                # Check for '404.html' if we're in HTML mode.
-                full_path, stat_result = await anyio.to_thread.run_sync(
-                    self.lookup_path, "404.html"
-                )
-                if stat_result and stat.S_ISREG(stat_result.st_mode):
-                    return FileResponse(
-                        full_path,
-                        stat_result=stat_result,
-                        method=scope["method"],
-                        status_code=404,
-                    )
-            raise HTTPException(status_code=404)
         except PermissionError:
             raise HTTPException(status_code=401)
         except OSError:
@@ -149,6 +135,18 @@ class StaticFiles:
                     return RedirectResponse(url=url)
                 return self.file_response(full_path, stat_result, scope)
 
+        if self.html:
+            # Check for '404.html' if we're in HTML mode.
+            full_path, stat_result = await anyio.to_thread.run_sync(
+                self.lookup_path, "404.html"
+            )
+            if stat_result and stat.S_ISREG(stat_result.st_mode):
+                return FileResponse(
+                    full_path,
+                    stat_result=stat_result,
+                    method=scope["method"],
+                    status_code=404,
+                )
         raise HTTPException(status_code=404)
 
     def lookup_path(
@@ -161,7 +159,10 @@ class StaticFiles:
                 # Don't allow misbehaving clients to break out of the static files
                 # directory.
                 continue
-            return full_path, os.stat(full_path)
+            try:
+                return full_path, os.stat(full_path)
+            except (FileNotFoundError, NotADirectoryError):
+                continue
         return "", None
 
     def file_response(
index 48fdaf1a54a5a5223847be8a2353a6df0ed2a3d0..8057af68921d08bf31625a52dcf339857396e4f0 100644 (file)
@@ -229,7 +229,7 @@ def test_staticfiles_304_with_last_modified_compare_last_req(
     assert response.content == b"<file content>"
 
 
-def test_staticfiles_html(tmpdir, test_client_factory):
+def test_staticfiles_html_normal(tmpdir, test_client_factory):
     path = os.path.join(tmpdir, "404.html")
     with open(path, "w") as file:
         file.write("<h1>Custom not found page</h1>")
@@ -262,6 +262,73 @@ def test_staticfiles_html(tmpdir, test_client_factory):
     assert response.text == "<h1>Custom not found page</h1>"
 
 
+def test_staticfiles_html_without_index(tmpdir, test_client_factory):
+    path = os.path.join(tmpdir, "404.html")
+    with open(path, "w") as file:
+        file.write("<h1>Custom not found page</h1>")
+    path = os.path.join(tmpdir, "dir")
+    os.mkdir(path)
+
+    app = StaticFiles(directory=tmpdir, html=True)
+    client = test_client_factory(app)
+
+    response = client.get("/dir/")
+    assert response.url == "http://testserver/dir/"
+    assert response.status_code == 404
+    assert response.text == "<h1>Custom not found page</h1>"
+
+    response = client.get("/dir")
+    assert response.url == "http://testserver/dir"
+    assert response.status_code == 404
+    assert response.text == "<h1>Custom not found page</h1>"
+
+    response = client.get("/missing")
+    assert response.status_code == 404
+    assert response.text == "<h1>Custom not found page</h1>"
+
+
+def test_staticfiles_html_without_404(tmpdir, test_client_factory):
+    path = os.path.join(tmpdir, "dir")
+    os.mkdir(path)
+    path = os.path.join(path, "index.html")
+    with open(path, "w") as file:
+        file.write("<h1>Hello</h1>")
+
+    app = StaticFiles(directory=tmpdir, html=True)
+    client = test_client_factory(app)
+
+    response = client.get("/dir/")
+    assert response.url == "http://testserver/dir/"
+    assert response.status_code == 200
+    assert response.text == "<h1>Hello</h1>"
+
+    response = client.get("/dir")
+    assert response.url == "http://testserver/dir/"
+    assert response.status_code == 200
+    assert response.text == "<h1>Hello</h1>"
+
+    with pytest.raises(HTTPException) as exc_info:
+        response = client.get("/missing")
+    assert exc_info.value.status_code == 404
+
+
+def test_staticfiles_html_only_files(tmpdir, test_client_factory):
+    path = os.path.join(tmpdir, "hello.html")
+    with open(path, "w") as file:
+        file.write("<h1>Hello</h1>")
+
+    app = StaticFiles(directory=tmpdir, html=True)
+    client = test_client_factory(app)
+
+    with pytest.raises(HTTPException) as exc_info:
+        response = client.get("/")
+    assert exc_info.value.status_code == 404
+
+    response = client.get("/hello.html")
+    assert response.status_code == 200
+    assert response.text == "<h1>Hello</h1>"
+
+
 def test_staticfiles_cache_invalidation_for_deleted_file_html_mode(
     tmpdir, test_client_factory
 ):