]> git.ipfire.org Git - thirdparty/starlette.git/commitdiff
fix IndexError in TemplateResponse (#2909)
authorAlex Oleshkevich <alex.oleshkevich@gmail.com>
Sun, 16 Mar 2025 12:18:00 +0000 (13:18 +0100)
committerGitHub <noreply@github.com>
Sun, 16 Mar 2025 12:18:00 +0000 (13:18 +0100)
starlette/templating.py
tests/test_templates.py

index 6b01aac9209fdcc893f4b24fd09d8a3f14149ec9..f764858b88365c0485828cbf2d3efda62b9d3fb3 100644 (file)
@@ -168,9 +168,9 @@ class Jinja2Templates:
                 name = args[0]
                 context = args[1] if len(args) > 1 else kwargs.get("context", {})
                 status_code = args[2] if len(args) > 2 else kwargs.get("status_code", 200)
-                headers = args[2] if len(args) > 2 else kwargs.get("headers")
-                media_type = args[3] if len(args) > 3 else kwargs.get("media_type")
-                background = args[4] if len(args) > 4 else kwargs.get("background")
+                headers = args[3] if len(args) > 3 else kwargs.get("headers")
+                media_type = args[4] if len(args) > 4 else kwargs.get("media_type")
+                background = args[5] if len(args) > 5 else kwargs.get("background")
 
                 if "request" not in context:
                     raise ValueError('context must include a "request" key')
index 6b2080c17793877d7b0cbe5e8c9f5ef83abb0086..e182cb82b2b2ade1fb3533c25aca49617490aaef 100644 (file)
@@ -279,6 +279,127 @@ def test_templates_warns_when_first_argument_isnot_request(
     spy.assert_called()
 
 
+class TestTemplatesArgsOnly:
+    # MAINTAINERS: remove after 1.0
+    def test_name_and_context(self, tmpdir: Path, test_client_factory: TestClientFactory) -> None:
+        path = os.path.join(tmpdir, "index.html")
+        with open(path, "w") as file:
+            file.write("value: {{ a }}")
+        templates = Jinja2Templates(directory=str(tmpdir))
+
+        def page(request: Request) -> Response:
+            return templates.TemplateResponse(
+                "index.html",
+                {"a": "b", "request": request},
+            )
+
+        app = Starlette(routes=[Route("/", page)])
+        client = test_client_factory(app)
+        with pytest.warns(DeprecationWarning):
+            response = client.get("/")
+
+        assert response.text == "value: b"  # context was rendered
+        assert response.status_code == 200
+
+    def test_status_code(self, tmpdir: Path, test_client_factory: TestClientFactory) -> None:
+        path = os.path.join(tmpdir, "index.html")
+        with open(path, "w") as file:
+            file.write("value: {{ a }}")
+        templates = Jinja2Templates(directory=str(tmpdir))
+
+        def page(request: Request) -> Response:
+            return templates.TemplateResponse(
+                "index.html",
+                {"a": "b", "request": request},
+                201,
+            )
+
+        app = Starlette(routes=[Route("/", page)])
+        client = test_client_factory(app)
+        with pytest.warns(DeprecationWarning):
+            response = client.get("/")
+
+        assert response.text == "value: b"  # context was rendered
+        assert response.status_code == 201
+
+    def test_headers(self, tmpdir: Path, test_client_factory: TestClientFactory) -> None:
+        path = os.path.join(tmpdir, "index.html")
+        with open(path, "w") as file:
+            file.write("value: {{ a }}")
+        templates = Jinja2Templates(directory=str(tmpdir))
+
+        def page(request: Request) -> Response:
+            return templates.TemplateResponse(
+                "index.html",
+                {"a": "b", "request": request},
+                201,
+                {"x-key": "value"},
+            )
+
+        app = Starlette(routes=[Route("/", page)])
+        client = test_client_factory(app)
+        with pytest.warns(DeprecationWarning):
+            response = client.get("/")
+
+        assert response.text == "value: b"  # context was rendered
+        assert response.status_code == 201
+        assert response.headers["x-key"] == "value"
+
+    def test_media_type(self, tmpdir: Path, test_client_factory: TestClientFactory) -> None:
+        path = os.path.join(tmpdir, "index.html")
+        with open(path, "w") as file:
+            file.write("value: {{ a }}")
+        templates = Jinja2Templates(directory=str(tmpdir))
+
+        def page(request: Request) -> Response:
+            return templates.TemplateResponse(
+                "index.html",
+                {"a": "b", "request": request},
+                201,
+                {"x-key": "value"},
+                "text/plain",
+            )
+
+        app = Starlette(routes=[Route("/", page)])
+        client = test_client_factory(app)
+        with pytest.warns(DeprecationWarning):
+            response = client.get("/")
+
+        assert response.text == "value: b"  # context was rendered
+        assert response.status_code == 201
+        assert response.headers["x-key"] == "value"
+        assert response.headers["content-type"] == "text/plain; charset=utf-8"
+
+    def test_all_args(self, tmpdir: Path, test_client_factory: TestClientFactory) -> None:
+        path = os.path.join(tmpdir, "index.html")
+        with open(path, "w") as file:
+            file.write("value: {{ a }}")
+        templates = Jinja2Templates(directory=str(tmpdir))
+
+        spy = mock.MagicMock()
+
+        def page(request: Request) -> Response:
+            return templates.TemplateResponse(
+                "index.html",
+                {"a": "b", "request": request},
+                201,
+                {"x-key": "value"},
+                "text/plain",
+                BackgroundTask(func=spy),
+            )
+
+        app = Starlette(routes=[Route("/", page)])
+        client = test_client_factory(app)
+        with pytest.warns(DeprecationWarning):
+            response = client.get("/")
+
+        assert response.text == "value: b"  # context was rendered
+        assert response.status_code == 201
+        assert response.headers["x-key"] == "value"
+        assert response.headers["content-type"] == "text/plain; charset=utf-8"
+        spy.assert_called()
+
+
 def test_templates_when_first_argument_is_request(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
     # MAINTAINERS: remove after 1.0
     path = os.path.join(tmpdir, "index.html")