full_path = os.path.abspath(joined_path)
else:
full_path = os.path.realpath(joined_path)
- directory = os.path.realpath(directory)
+ directory = os.path.realpath(directory)
if os.path.commonpath([full_path, directory]) != directory:
# Don't allow misbehaving clients to break out of the static files
# directory.
assert exc_info.value.status_code == 404
assert exc_info.value.detail == "Not Found"
+
+
+def test_staticfiles_self_symlinks(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
+ statics_path = os.path.join(tmpdir, "statics")
+ os.mkdir(statics_path)
+
+ source_file_path = os.path.join(statics_path, "index.html")
+ with open(source_file_path, "w") as file:
+ file.write("<h1>Hello</h1>")
+
+ statics_symlink_path = os.path.join(tmpdir, "statics_symlink")
+ os.symlink(statics_path, statics_symlink_path)
+
+ app = StaticFiles(directory=statics_symlink_path, follow_symlink=True)
+ client = test_client_factory(app)
+
+ response = client.get("/index.html")
+ assert response.url == "http://testserver/index.html"
+ assert response.status_code == 200
+ assert response.text == "<h1>Hello</h1>"