import stat
import typing
from email.utils import parsedate
+from pathlib import Path
import anyio
self.all_directories = self.get_directories(directory, packages)
self.html = html
self.config_checked = False
- if check_dir and directory is not None and not os.path.isdir(directory):
+ if check_dir and directory is not None and not Path(directory).is_dir():
raise RuntimeError(f"Directory '{directory}' does not exist")
def get_directories(
spec = importlib.util.find_spec(package)
assert spec is not None, f"Package {package!r} could not be found."
assert spec.origin is not None, f"Package {package!r} could not be found."
- package_directory = os.path.normpath(
- os.path.join(spec.origin, "..", statics_dir)
- )
- assert os.path.isdir(
- package_directory
+ package_directory = Path(spec.origin).joinpath("..", statics_dir).resolve()
+ assert (
+ package_directory.is_dir()
), f"Directory '{statics_dir!r}' in package {package!r} could not be found."
directories.append(package_directory)
response = await self.get_response(path, scope)
await response(scope, receive, send)
- def get_path(self, scope: Scope) -> str:
+ def get_path(self, scope: Scope) -> Path:
"""
Given the ASGI scope, return the `path` string to serve up,
with OS specific path separators, and any '..', '.' components removed.
"""
- return os.path.normpath(os.path.join(*scope["path"].split("/")))
+ return Path(*scope["path"].split("/"))
- async def get_response(self, path: str, scope: Scope) -> Response:
+ async def get_response(self, path: Path, scope: Scope) -> Response:
"""
Returns an HTTP response, given the incoming path, method and request headers.
"""
elif stat_result and stat.S_ISDIR(stat_result.st_mode) and self.html:
# We're in HTML mode, and have got a directory URL.
# Check if we have 'index.html' file to serve.
- index_path = os.path.join(path, "index.html")
+ index_path = path.joinpath("index.html")
full_path, stat_result = await anyio.to_thread.run_sync(
self.lookup_path, index_path
)
raise HTTPException(status_code=404)
def lookup_path(
- self, path: str
- ) -> typing.Tuple[str, typing.Optional[os.stat_result]]:
+ self, path: Path
+ ) -> typing.Tuple[Path, typing.Optional[os.stat_result]]:
for directory in self.all_directories:
- full_path = os.path.realpath(os.path.join(directory, path))
- directory = os.path.realpath(directory)
- if os.path.commonprefix([full_path, directory]) != directory:
- # Don't allow misbehaving clients to break out of the static files
- # directory.
- continue
+ original_path = Path(directory).joinpath(path)
+ full_path = original_path.resolve()
+ directory = Path(directory).resolve()
try:
- return full_path, os.stat(full_path)
+ stat_result = os.lstat(original_path)
+ full_path.relative_to(directory)
+ return full_path, stat_result
+ except ValueError:
+ # Allow clients to break out of the static files directory
+ # if following symlinks.
+ if stat.S_ISLNK(stat_result.st_mode):
+ stat_result = os.lstat(full_path)
+ return full_path, stat_result
except (FileNotFoundError, NotADirectoryError):
continue
- return "", None
+ return Path(), None
def file_response(
self,
directory = os.path.join(tmpdir, "foo")
os.mkdir(directory)
- path = os.path.join(tmpdir, "example.txt")
- with open(path, "w") as file:
+ file_path = os.path.join(tmpdir, "example.txt")
+ with open(file_path, "w") as file:
file.write("outside root dir")
app = StaticFiles(directory=directory)
response = client.get("/example.txt")
assert response.status_code == 500
assert response.text == "Internal Server Error"
+
+
+def test_staticfiles_follows_symlinks_to_break_out_of_dir(
+ tmp_path: pathlib.Path, test_client_factory
+):
+ statics_path = tmp_path.joinpath("statics")
+ statics_path.mkdir()
+
+ symlink_path = tmp_path.joinpath("symlink")
+ symlink_path.mkdir()
+
+ symlink_file_path = symlink_path.joinpath("index.html")
+ with open(symlink_file_path, "w") as file:
+ file.write("<h1>Hello</h1>")
+
+ statics_file_path = statics_path.joinpath("index.html")
+ statics_file_path.symlink_to(symlink_file_path)
+
+ app = StaticFiles(directory=statics_path)
+ client = test_client_factory(app)
+
+ response = client.get("/index.html")
+ assert response.url == "http://testserver/index.html"
+ assert response.status_code == 200
+ assert response.text == "<h1>Hello</h1>"