]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
🐛 Check Content-Type request header before assuming JSON (#2118)
authorPatrick Wang <1263870+patrickkwang@users.noreply.github.com>
Mon, 7 Jun 2021 10:46:18 +0000 (06:46 -0400)
committerGitHub <noreply@github.com>
Mon, 7 Jun 2021 10:46:18 +0000 (12:46 +0200)
Co-authored-by: Patrick Wang <patrickkwang@users.noreply.github.com>
Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
fastapi/routing.py
tests/test_tutorial/test_body/test_tutorial001.py
tests/test_tutorial/test_custom_request_and_route/test_tutorial001.py

index ac5e19d99835a7b7e07db0a8d87ac02e9794fb51..9b51f03cac5622c03e48cee18b7e7a11be3ad6a4 100644 (file)
@@ -1,4 +1,5 @@
 import asyncio
+import email.message
 import enum
 import inspect
 import json
@@ -36,7 +37,7 @@ from fastapi.utils import (
 )
 from pydantic import BaseModel
 from pydantic.error_wrappers import ErrorWrapper, ValidationError
-from pydantic.fields import ModelField
+from pydantic.fields import ModelField, Undefined
 from starlette import routing
 from starlette.concurrency import run_in_threadpool
 from starlette.exceptions import HTTPException
@@ -174,14 +175,26 @@ def get_request_handler(
 
     async def app(request: Request) -> Response:
         try:
-            body = None
+            body: Any = None
             if body_field:
                 if is_body_form:
                     body = await request.form()
                 else:
                     body_bytes = await request.body()
                     if body_bytes:
-                        body = await request.json()
+                        json_body: Any = Undefined
+                        content_type_value = request.headers.get("content-type")
+                        if content_type_value:
+                            message = email.message.Message()
+                            message["content-type"] = content_type_value
+                            if message.get_content_maintype() == "application":
+                                subtype = message.get_content_subtype()
+                                if subtype == "json" or subtype.endswith("+json"):
+                                    json_body = await request.json()
+                        if json_body != Undefined:
+                            body = json_body
+                        else:
+                            body = body_bytes
         except json.JSONDecodeError as e:
             raise RequestValidationError([ErrorWrapper(e, ("body", e.pos))], body=e.doc)
         except Exception as e:
index 38c6dbe876b26d082751ffaffcbedcf4f4baa666..c90240ae4c34948725e6a07c96f9948853e5d3f7 100644 (file)
@@ -173,25 +173,91 @@ def test_post_body(path, body, expected_status, expected_response):
 
 
 def test_post_broken_body():
-    response = client.post("/items/", data={"name": "Foo", "price": 50.5})
+    response = client.post(
+        "/items/",
+        headers={"content-type": "application/json"},
+        data="{some broken json}",
+    )
     assert response.status_code == 422, response.text
     assert response.json() == {
         "detail": [
             {
+                "loc": ["body", 1],
+                "msg": "Expecting property name enclosed in double quotes: line 1 column 2 (char 1)",
+                "type": "value_error.jsondecode",
                 "ctx": {
-                    "colno": 1,
-                    "doc": "name=Foo&price=50.5",
+                    "msg": "Expecting property name enclosed in double quotes",
+                    "doc": "{some broken json}",
+                    "pos": 1,
                     "lineno": 1,
-                    "msg": "Expecting value",
-                    "pos": 0,
+                    "colno": 2,
                 },
-                "loc": ["body", 0],
-                "msg": "Expecting value: line 1 column 1 (char 0)",
-                "type": "value_error.jsondecode",
             }
         ]
     }
+
+
+def test_post_form_for_json():
+    response = client.post("/items/", data={"name": "Foo", "price": 50.5})
+    assert response.status_code == 422, response.text
+    assert response.json() == {
+        "detail": [
+            {
+                "loc": ["body"],
+                "msg": "value is not a valid dict",
+                "type": "type_error.dict",
+            }
+        ]
+    }
+
+
+def test_explicit_content_type():
+    response = client.post(
+        "/items/",
+        data='{"name": "Foo", "price": 50.5}',
+        headers={"Content-Type": "application/json"},
+    )
+    assert response.status_code == 200, response.text
+
+
+def test_geo_json():
+    response = client.post(
+        "/items/",
+        data='{"name": "Foo", "price": 50.5}',
+        headers={"Content-Type": "application/geo+json"},
+    )
+    assert response.status_code == 200, response.text
+
+
+def test_wrong_headers():
+    data = '{"name": "Foo", "price": 50.5}'
+    invalid_dict = {
+        "detail": [
+            {
+                "loc": ["body"],
+                "msg": "value is not a valid dict",
+                "type": "type_error.dict",
+            }
+        ]
+    }
+
+    response = client.post("/items/", data=data, headers={"Content-Type": "text/plain"})
+    assert response.status_code == 422, response.text
+    assert response.json() == invalid_dict
+
+    response = client.post(
+        "/items/", data=data, headers={"Content-Type": "application/geo+json-seq"}
+    )
+    assert response.status_code == 422, response.text
+    assert response.json() == invalid_dict
+    response = client.post(
+        "/items/", data=data, headers={"Content-Type": "application/not-really-json"}
+    )
+    assert response.status_code == 422, response.text
+    assert response.json() == invalid_dict
+
+
+def test_other_exceptions():
     with patch("json.loads", side_effect=Exception):
         response = client.post("/items/", json={"test": "test2"})
         assert response.status_code == 400, response.text
-    assert response.json() == {"detail": "There was an error parsing the body"}
index cc85a8a82a5acb49d1ce784623bb341119759a94..3eb5822e288168bb7b3a88c42f0f9481a0daa25b 100644 (file)
@@ -25,6 +25,7 @@ def test_gzip_request(compress):
     if compress:
         data = gzip.compress(data)
         headers["Content-Encoding"] = "gzip"
+    headers["Content-Type"] = "application/json"
     response = client.post("/sum", data=data, headers=headers)
     assert response.json() == {"sum": n}