]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Turn directory into string on `lookup_path` on commonpath comparison (#2851)
authorMarcelo Trylesinski <marcelotryle@gmail.com>
Fri, 24 Jan 2025 11:13:42 +0000 (12:13 +0100)
committerGitHub <noreply@github.com>
Fri, 24 Jan 2025 11:13:42 +0000 (12:13 +0100)
* Turn directory into string on `lookup_path` on commonpath comparison

* remove str cast complication

starlette/staticfiles.py
tests/test_staticfiles.py

index 746e740e0e215983a69907bac68dd40bc71edc0e..34be04cdc7c486a66944e147b7c775f44402f599 100644 (file)
@@ -156,9 +156,8 @@ class StaticFiles:
             else:
                 full_path = os.path.realpath(joined_path)
                 directory = os.path.realpath(directory)
-            if os.path.commonpath([full_path, directory]) != directory:
-                # Don't allow misbehaving clients to break out of the static files
-                # directory.
+            if os.path.commonpath([full_path, directory]) != str(directory):
+                # Don't allow misbehaving clients to break out of the static files directory.
                 continue
             try:
                 return full_path, os.stat(full_path)
index b4f13171987eac9ef07e2e9f42122f16735b18a8..2c5e7e2dfa265f3e6c0d09aff871aa7fe9e06533 100644 (file)
@@ -576,16 +576,15 @@ def test_staticfiles_avoids_path_traversal(tmp_path: Path) -> None:
     assert exc_info.value.detail == "Not Found"
 
 
-def test_staticfiles_self_symlinks(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
-    statics_path = os.path.join(tmpdir, "statics")
-    os.mkdir(statics_path)
+def test_staticfiles_self_symlinks(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
+    statics_path = tmp_path / "statics"
+    statics_path.mkdir()
 
-    source_file_path = os.path.join(statics_path, "index.html")
-    with open(source_file_path, "w") as file:
-        file.write("<h1>Hello</h1>")
+    source_file_path = statics_path / "index.html"
+    source_file_path.write_text("<h1>Hello</h1>", encoding="utf-8")
 
-    statics_symlink_path = os.path.join(tmpdir, "statics_symlink")
-    os.symlink(statics_path, statics_symlink_path)
+    statics_symlink_path = tmp_path / "statics_symlink"
+    statics_symlink_path.symlink_to(statics_path)
 
     app = StaticFiles(directory=statics_symlink_path, follow_symlink=True)
     client = test_client_factory(app)