]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
Revert "Allow staticfiles to follow symlinks outside directory (#1377)" (#1681)
authorMarcelo Trylesinski <marcelotryle@gmail.com>
Fri, 10 Jun 2022 05:12:12 +0000 (07:12 +0200)
committerGitHub <noreply@github.com>
Fri, 10 Jun 2022 05:12:12 +0000 (07:12 +0200)
This reverts commit d3dccdc477652b6de5a7b6b14a2bf3fa2f94be2c.

starlette/staticfiles.py
tests/test_staticfiles.py

index da10a390c253de363b315fb01ab9df709f17ca17..d09630f35b85bd72efc8febed9ac816c551744a6 100644 (file)
@@ -3,7 +3,6 @@ import os
 import stat
 import typing
 from email.utils import parsedate
-from pathlib import Path
 
 import anyio
 
@@ -52,7 +51,7 @@ class StaticFiles:
         self.all_directories = self.get_directories(directory, packages)
         self.html = html
         self.config_checked = False
-        if check_dir and directory is not None and not Path(directory).is_dir():
+        if check_dir and directory is not None and not os.path.isdir(directory):
             raise RuntimeError(f"Directory '{directory}' does not exist")
 
     def get_directories(
@@ -78,9 +77,11 @@ class StaticFiles:
             spec = importlib.util.find_spec(package)
             assert spec is not None, f"Package {package!r} could not be found."
             assert spec.origin is not None, f"Package {package!r} could not be found."
-            package_directory = Path(spec.origin).joinpath("..", statics_dir).resolve()
-            assert (
-                package_directory.is_dir()
+            package_directory = os.path.normpath(
+                os.path.join(spec.origin, "..", statics_dir)
+            )
+            assert os.path.isdir(
+                package_directory
             ), f"Directory '{statics_dir!r}' in package {package!r} could not be found."
             directories.append(package_directory)
 
@@ -100,14 +101,14 @@ class StaticFiles:
         response = await self.get_response(path, scope)
         await response(scope, receive, send)
 
-    def get_path(self, scope: Scope) -> Path:
+    def get_path(self, scope: Scope) -> str:
         """
         Given the ASGI scope, return the `path` string to serve up,
         with OS specific path separators, and any '..', '.' components removed.
         """
-        return Path(*scope["path"].split("/"))
+        return os.path.normpath(os.path.join(*scope["path"].split("/")))
 
-    async def get_response(self, path: Path, scope: Scope) -> Response:
+    async def get_response(self, path: str, scope: Scope) -> Response:
         """
         Returns an HTTP response, given the incoming path, method and request headers.
         """
@@ -130,7 +131,7 @@ class StaticFiles:
         elif stat_result and stat.S_ISDIR(stat_result.st_mode) and self.html:
             # We're in HTML mode, and have got a directory URL.
             # Check if we have 'index.html' file to serve.
-            index_path = path.joinpath("index.html")
+            index_path = os.path.join(path, "index.html")
             full_path, stat_result = await anyio.to_thread.run_sync(
                 self.lookup_path, index_path
             )
@@ -157,25 +158,20 @@ class StaticFiles:
         raise HTTPException(status_code=404)
 
     def lookup_path(
-        self, path: Path
-    ) -> typing.Tuple[Path, typing.Optional[os.stat_result]]:
+        self, path: str
+    ) -> typing.Tuple[str, typing.Optional[os.stat_result]]:
         for directory in self.all_directories:
-            original_path = Path(directory).joinpath(path)
-            full_path = original_path.resolve()
-            directory = Path(directory).resolve()
+            full_path = os.path.realpath(os.path.join(directory, path))
+            directory = os.path.realpath(directory)
+            if os.path.commonprefix([full_path, directory]) != directory:
+                # Don't allow misbehaving clients to break out of the static files
+                # directory.
+                continue
             try:
-                stat_result = os.lstat(original_path)
-                full_path.relative_to(directory)
-                return full_path, stat_result
-            except ValueError:
-                # Allow clients to break out of the static files directory
-                # if following symlinks.
-                if stat.S_ISLNK(stat_result.st_mode):
-                    stat_result = os.lstat(full_path)
-                    return full_path, stat_result
+                return full_path, os.stat(full_path)
             except (FileNotFoundError, NotADirectoryError):
                 continue
-        return Path(), None
+        return "", None
 
     def file_response(
         self,
index 53f3ea9cd6accc531bb3de682d12599b365ac5e8..7d13a0522418c6bb46c363520d489d189643c395 100644 (file)
@@ -166,8 +166,8 @@ def test_staticfiles_prevents_breaking_out_of_directory(tmpdir):
     directory = os.path.join(tmpdir, "foo")
     os.mkdir(directory)
 
-    file_path = os.path.join(tmpdir, "example.txt")
-    with open(file_path, "w") as file:
+    path = os.path.join(tmpdir, "example.txt")
+    with open(path, "w") as file:
         file.write("outside root dir")
 
     app = StaticFiles(directory=directory)
@@ -441,28 +441,3 @@ def test_staticfiles_unhandled_os_error_returns_500(
     response = client.get("/example.txt")
     assert response.status_code == 500
     assert response.text == "Internal Server Error"
-
-
-def test_staticfiles_follows_symlinks_to_break_out_of_dir(
-    tmp_path: pathlib.Path, test_client_factory
-):
-    statics_path = tmp_path.joinpath("statics")
-    statics_path.mkdir()
-
-    symlink_path = tmp_path.joinpath("symlink")
-    symlink_path.mkdir()
-
-    symlink_file_path = symlink_path.joinpath("index.html")
-    with open(symlink_file_path, "w") as file:
-        file.write("<h1>Hello</h1>")
-
-    statics_file_path = statics_path.joinpath("index.html")
-    statics_file_path.symlink_to(symlink_file_path)
-
-    app = StaticFiles(directory=statics_path)
-    client = test_client_factory(app)
-
-    response = client.get("/index.html")
-    assert response.url == "http://testserver/index.html"
-    assert response.status_code == 200
-    assert response.text == "<h1>Hello</h1>"