]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Allow staticfiles to follow symlinks outside directory (#1377)
authorAmin Alaee <mohammadamin.alaee@gmail.com>
Sat, 28 May 2022 14:17:54 +0000 (16:17 +0200)
committerGitHub <noreply@github.com>
Sat, 28 May 2022 14:17:54 +0000 (16:17 +0200)
starlette/staticfiles.py
tests/test_staticfiles.py

index d09630f35b85bd72efc8febed9ac816c551744a6..da10a390c253de363b315fb01ab9df709f17ca17 100644 (file)
@@ -3,6 +3,7 @@ import os
 import stat
 import typing
 from email.utils import parsedate
+from pathlib import Path
 
 import anyio
 
@@ -51,7 +52,7 @@ class StaticFiles:
         self.all_directories = self.get_directories(directory, packages)
         self.html = html
         self.config_checked = False
-        if check_dir and directory is not None and not os.path.isdir(directory):
+        if check_dir and directory is not None and not Path(directory).is_dir():
             raise RuntimeError(f"Directory '{directory}' does not exist")
 
     def get_directories(
@@ -77,11 +78,9 @@ class StaticFiles:
             spec = importlib.util.find_spec(package)
             assert spec is not None, f"Package {package!r} could not be found."
             assert spec.origin is not None, f"Package {package!r} could not be found."
-            package_directory = os.path.normpath(
-                os.path.join(spec.origin, "..", statics_dir)
-            )
-            assert os.path.isdir(
-                package_directory
+            package_directory = Path(spec.origin).joinpath("..", statics_dir).resolve()
+            assert (
+                package_directory.is_dir()
             ), f"Directory '{statics_dir!r}' in package {package!r} could not be found."
             directories.append(package_directory)
 
@@ -101,14 +100,14 @@ class StaticFiles:
         response = await self.get_response(path, scope)
         await response(scope, receive, send)
 
-    def get_path(self, scope: Scope) -> str:
+    def get_path(self, scope: Scope) -> Path:
         """
         Given the ASGI scope, return the `path` string to serve up,
         with OS specific path separators, and any '..', '.' components removed.
         """
-        return os.path.normpath(os.path.join(*scope["path"].split("/")))
+        return Path(*scope["path"].split("/"))
 
-    async def get_response(self, path: str, scope: Scope) -> Response:
+    async def get_response(self, path: Path, scope: Scope) -> Response:
         """
         Returns an HTTP response, given the incoming path, method and request headers.
         """
@@ -131,7 +130,7 @@ class StaticFiles:
         elif stat_result and stat.S_ISDIR(stat_result.st_mode) and self.html:
             # We're in HTML mode, and have got a directory URL.
             # Check if we have 'index.html' file to serve.
-            index_path = os.path.join(path, "index.html")
+            index_path = path.joinpath("index.html")
             full_path, stat_result = await anyio.to_thread.run_sync(
                 self.lookup_path, index_path
             )
@@ -158,20 +157,25 @@ class StaticFiles:
         raise HTTPException(status_code=404)
 
     def lookup_path(
-        self, path: str
-    ) -> typing.Tuple[str, typing.Optional[os.stat_result]]:
+        self, path: Path
+    ) -> typing.Tuple[Path, typing.Optional[os.stat_result]]:
         for directory in self.all_directories:
-            full_path = os.path.realpath(os.path.join(directory, path))
-            directory = os.path.realpath(directory)
-            if os.path.commonprefix([full_path, directory]) != directory:
-                # Don't allow misbehaving clients to break out of the static files
-                # directory.
-                continue
+            original_path = Path(directory).joinpath(path)
+            full_path = original_path.resolve()
+            directory = Path(directory).resolve()
             try:
-                return full_path, os.stat(full_path)
+                stat_result = os.lstat(original_path)
+                full_path.relative_to(directory)
+                return full_path, stat_result
+            except ValueError:
+                # Allow clients to break out of the static files directory
+                # if following symlinks.
+                if stat.S_ISLNK(stat_result.st_mode):
+                    stat_result = os.lstat(full_path)
+                    return full_path, stat_result
             except (FileNotFoundError, NotADirectoryError):
                 continue
-        return "", None
+        return Path(), None
 
     def file_response(
         self,
index 7d13a0522418c6bb46c363520d489d189643c395..53f3ea9cd6accc531bb3de682d12599b365ac5e8 100644 (file)
@@ -166,8 +166,8 @@ def test_staticfiles_prevents_breaking_out_of_directory(tmpdir):
     directory = os.path.join(tmpdir, "foo")
     os.mkdir(directory)
 
-    path = os.path.join(tmpdir, "example.txt")
-    with open(path, "w") as file:
+    file_path = os.path.join(tmpdir, "example.txt")
+    with open(file_path, "w") as file:
         file.write("outside root dir")
 
     app = StaticFiles(directory=directory)
@@ -441,3 +441,28 @@ def test_staticfiles_unhandled_os_error_returns_500(
     response = client.get("/example.txt")
     assert response.status_code == 500
     assert response.text == "Internal Server Error"
+
+
+def test_staticfiles_follows_symlinks_to_break_out_of_dir(
+    tmp_path: pathlib.Path, test_client_factory
+):
+    statics_path = tmp_path.joinpath("statics")
+    statics_path.mkdir()
+
+    symlink_path = tmp_path.joinpath("symlink")
+    symlink_path.mkdir()
+
+    symlink_file_path = symlink_path.joinpath("index.html")
+    with open(symlink_file_path, "w") as file:
+        file.write("<h1>Hello</h1>")
+
+    statics_file_path = statics_path.joinpath("index.html")
+    statics_file_path.symlink_to(symlink_file_path)
+
+    app = StaticFiles(directory=statics_path)
+    client = test_client_factory(app)
+
+    response = client.get("/index.html")
+    assert response.url == "http://testserver/index.html"
+    assert response.status_code == 200
+    assert response.text == "<h1>Hello</h1>"