]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
:sparkles: Add support for multi-file uploads (#158)
authorSebastián Ramírez <tiangolo@gmail.com>
Sun, 14 Apr 2019 18:12:14 +0000 (22:12 +0400)
committerGitHub <noreply@github.com>
Sun, 14 Apr 2019 18:12:14 +0000 (22:12 +0400)
docs/src/request_files/tutorial002.py [new file with mode: 0644]
docs/tutorial/request-files.md
fastapi/dependencies/utils.py
fastapi/routing.py
tests/test_tutorial/test_request_files/test_tutorial002.py [new file with mode: 0644]

diff --git a/docs/src/request_files/tutorial002.py b/docs/src/request_files/tutorial002.py
new file mode 100644 (file)
index 0000000..bc665b2
--- /dev/null
@@ -0,0 +1,33 @@
+from typing import List
+
+from fastapi import FastAPI, File, UploadFile
+from starlette.responses import HTMLResponse
+
+app = FastAPI()
+
+
+@app.post("/files/")
+async def create_files(files: List[bytes] = File(...)):
+    return {"file_sizes": [len(file) for file in files]}
+
+
+@app.post("/uploadfiles/")
+async def create_upload_files(files: List[UploadFile] = File(...)):
+    return {"filenames": [file.filename for file in files]}
+
+
+@app.get("/")
+async def main():
+    content = """
+<body>
+<form action="/files/" enctype="multipart/form-data" method="post">
+<input name="files" type="file" multiple>
+<input type="submit">
+</form>
+<form action="/uploadfiles/" enctype="multipart/form-data" method="post">
+<input name="files" type="file" multiple>
+<input type="submit">
+</form>
+</body>
+    """
+    return HTMLResponse(content=content)
index bc73dd25e7d93f9539438d2d393aad2e88a8f5d1..2e0c2f802af07fb48a3d5dd23050391e68f356e4 100644 (file)
@@ -43,7 +43,7 @@ Using `UploadFile` has several advantages over `bytes`:
 
 * It uses a "spooled" file:
     * A file stored in memory up to a maximum size limit, and after passing this limit it will be stored in disk.
-* This means that it will work well for large files like images, videos, large binaries, etc. All without consuming all the memory.
+* This means that it will work well for large files like images, videos, large binaries, etc. without consuming all the memory.
 * You can get metadata from the uploaded file.
 * It has a <a href="https://docs.python.org/3/glossary.html#term-file-like-object" target="_blank">file-like</a> `async` interface.
 * It exposes an actual Python <a href="https://docs.python.org/3/library/tempfile.html#tempfile.SpooledTemporaryFile" target="_blank">`SpooledTemporaryFile`</a> object that you can pass directly to other libraries that expect a file-like object.
@@ -107,6 +107,20 @@ The way HTML forms (`<form></form>`) sends the data to the server normally uses
 
     This is not a limitation of **FastAPI**, it's part of the HTTP protocol.
 
+## Multiple file uploads
+
+It's possible to upload several files at the same time.
+
+They would be associated to the same "form field" sent using "form data".
+
+To use that, declare a `List` of `bytes` or `UploadFile`:
+
+```Python hl_lines="10 15"
+{!./src/request_files/tutorial002.py!}
+```
+
+You will receive, as declared, a `list` of `bytes` or `UploadFile`s.
+
 ## Recap
 
 Use `File` to declare files to be uploaded as input parameters (as form data).
index 4cf737d6703e63df9869710ed72afc002c739272..c9f61813240cbde879c6e01301be9eca5d2c775d 100644 (file)
@@ -31,8 +31,8 @@ from pydantic.schema import get_annotation_from_schema
 from pydantic.utils import lenient_issubclass
 from starlette.background import BackgroundTasks
 from starlette.concurrency import run_in_threadpool
-from starlette.datastructures import UploadFile
-from starlette.requests import Headers, QueryParams, Request
+from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
+from starlette.requests import Request
 
 param_supported_types = (
     str,
@@ -47,6 +47,10 @@ param_supported_types = (
     Decimal,
 )
 
+sequence_shapes = {Shape.LIST, Shape.SET, Shape.TUPLE}
+sequence_types = (list, set, tuple)
+sequence_shape_to_type = {Shape.LIST: list, Shape.SET: set, Shape.TUPLE: tuple}
+
 
 def get_sub_dependant(
     *, param: inspect.Parameter, path: str, security_scopes: List[str] = None
@@ -318,7 +322,7 @@ def request_params_to_args(
     values = {}
     errors = []
     for field in required_params:
-        if field.shape in {Shape.LIST, Shape.SET, Shape.TUPLE} and isinstance(
+        if field.shape in sequence_shapes and isinstance(
             received_params, (QueryParams, Headers)
         ):
             value = received_params.getlist(field.alias)
@@ -358,11 +362,20 @@ async def request_body_to_args(
         embed = getattr(field.schema, "embed", None)
         if len(required_params) == 1 and not embed:
             received_body = {field.alias: received_body}
-        elif received_body is None:
-            received_body = {}
         for field in required_params:
-            value = received_body.get(field.alias)
-            if value is None or (isinstance(field.schema, params.Form) and value == ""):
+            if field.shape in sequence_shapes and isinstance(received_body, FormData):
+                value = received_body.getlist(field.alias)
+            else:
+                value = received_body.get(field.alias)
+            if (
+                value is None
+                or (isinstance(field.schema, params.Form) and value == "")
+                or (
+                    isinstance(field.schema, params.Form)
+                    and field.shape in sequence_shapes
+                    and len(value) == 0
+                )
+            ):
                 if field.required:
                     errors.append(
                         ErrorWrapper(
@@ -380,6 +393,15 @@ async def request_body_to_args(
                 and isinstance(value, UploadFile)
             ):
                 value = await value.read()
+            elif (
+                field.shape in sequence_shapes
+                and isinstance(field.schema, params.File)
+                and lenient_issubclass(field.type_, bytes)
+                and isinstance(value, sequence_types)
+            ):
+                awaitables = [sub_value.read() for sub_value in value]
+                contents = await asyncio.gather(*awaitables)
+                value = sequence_shape_to_type[field.shape](contents)
             v_, errors_ = field.validate(value, values, loc=("body", field.alias))
             if isinstance(errors_, ErrorWrapper):
                 errors.append(errors_)
@@ -391,10 +413,14 @@ async def request_body_to_args(
 
 
 def get_schema_compatible_field(*, field: Field) -> Field:
+    out_field = field
     if lenient_issubclass(field.type_, UploadFile):
-        return Field(
+        use_type: type = bytes
+        if field.shape in sequence_shapes:
+            use_type = List[bytes]
+        out_field = Field(
             name=field.name,
-            type_=bytes,
+            type_=use_type,
             class_validators=field.class_validators,
             model_config=field.model_config,
             default=field.default,
@@ -402,10 +428,10 @@ def get_schema_compatible_field(*, field: Field) -> Field:
             alias=field.alias,
             schema=field.schema,
         )
-    return field
+    return out_field
 
 
-def get_body_field(*, dependant: Dependant, name: str) -> Field:
+def get_body_field(*, dependant: Dependant, name: str) -> Optional[Field]:
     flat_dependant = get_flat_dependant(dependant)
     if not flat_dependant.body_params:
         return None
index 2bdf46ddc3c7dde48b2841c522de76d4abe89700..a078662d868fc4d6272045f851fdf0417ec93b28 100644 (file)
@@ -53,12 +53,7 @@ def get_app(
             body = None
             if body_field:
                 if is_body_form:
-                    raw_body = await request.form()
-                    form_fields = {}
-                    for field, value in raw_body.items():
-                        form_fields[field] = value
-                    if form_fields:
-                        body = form_fields
+                    body = await request.form()
                 else:
                     body_bytes = await request.body()
                     if body_bytes:
diff --git a/tests/test_tutorial/test_request_files/test_tutorial002.py b/tests/test_tutorial/test_request_files/test_tutorial002.py
new file mode 100644 (file)
index 0000000..e6b7ba4
--- /dev/null
@@ -0,0 +1,219 @@
+import os
+
+from starlette.testclient import TestClient
+
+from request_files.tutorial002 import app
+
+client = TestClient(app)
+
+openapi_schema = {
+    "openapi": "3.0.2",
+    "info": {"title": "Fast API", "version": "0.1.0"},
+    "paths": {
+        "/files/": {
+            "post": {
+                "responses": {
+                    "200": {
+                        "description": "Successful Response",
+                        "content": {"application/json": {"schema": {}}},
+                    },
+                    "422": {
+                        "description": "Validation Error",
+                        "content": {
+                            "application/json": {
+                                "schema": {
+                                    "$ref": "#/components/schemas/HTTPValidationError"
+                                }
+                            }
+                        },
+                    },
+                },
+                "summary": "Create Files",
+                "operationId": "create_files_files__post",
+                "requestBody": {
+                    "content": {
+                        "multipart/form-data": {
+                            "schema": {"$ref": "#/components/schemas/Body_create_files"}
+                        }
+                    },
+                    "required": True,
+                },
+            }
+        },
+        "/uploadfiles/": {
+            "post": {
+                "responses": {
+                    "200": {
+                        "description": "Successful Response",
+                        "content": {"application/json": {"schema": {}}},
+                    },
+                    "422": {
+                        "description": "Validation Error",
+                        "content": {
+                            "application/json": {
+                                "schema": {
+                                    "$ref": "#/components/schemas/HTTPValidationError"
+                                }
+                            }
+                        },
+                    },
+                },
+                "summary": "Create Upload Files",
+                "operationId": "create_upload_files_uploadfiles__post",
+                "requestBody": {
+                    "content": {
+                        "multipart/form-data": {
+                            "schema": {
+                                "$ref": "#/components/schemas/Body_create_upload_files"
+                            }
+                        }
+                    },
+                    "required": True,
+                },
+            }
+        },
+        "/": {
+            "get": {
+                "responses": {
+                    "200": {
+                        "description": "Successful Response",
+                        "content": {"application/json": {"schema": {}}},
+                    }
+                },
+                "summary": "Main",
+                "operationId": "main__get",
+            }
+        },
+    },
+    "components": {
+        "schemas": {
+            "Body_create_files": {
+                "title": "Body_create_files",
+                "required": ["files"],
+                "type": "object",
+                "properties": {
+                    "files": {
+                        "title": "Files",
+                        "type": "array",
+                        "items": {"type": "string", "format": "binary"},
+                    }
+                },
+            },
+            "Body_create_upload_files": {
+                "title": "Body_create_upload_files",
+                "required": ["files"],
+                "type": "object",
+                "properties": {
+                    "files": {
+                        "title": "Files",
+                        "type": "array",
+                        "items": {"type": "string", "format": "binary"},
+                    }
+                },
+            },
+            "ValidationError": {
+                "title": "ValidationError",
+                "required": ["loc", "msg", "type"],
+                "type": "object",
+                "properties": {
+                    "loc": {
+                        "title": "Location",
+                        "type": "array",
+                        "items": {"type": "string"},
+                    },
+                    "msg": {"title": "Message", "type": "string"},
+                    "type": {"title": "Error Type", "type": "string"},
+                },
+            },
+            "HTTPValidationError": {
+                "title": "HTTPValidationError",
+                "type": "object",
+                "properties": {
+                    "detail": {
+                        "title": "Detail",
+                        "type": "array",
+                        "items": {"$ref": "#/components/schemas/ValidationError"},
+                    }
+                },
+            },
+        }
+    },
+}
+
+
+def test_openapi_schema():
+    response = client.get("/openapi.json")
+    assert response.status_code == 200
+    assert response.json() == openapi_schema
+
+
+file_required = {
+    "detail": [
+        {
+            "loc": ["body", "files"],
+            "msg": "field required",
+            "type": "value_error.missing",
+        }
+    ]
+}
+
+
+def test_post_form_no_body():
+    response = client.post("/files/")
+    assert response.status_code == 422
+    assert response.json() == file_required
+
+
+def test_post_body_json():
+    response = client.post("/files/", json={"file": "Foo"})
+    print(response)
+    print(response.content)
+    assert response.status_code == 422
+    assert response.json() == file_required
+
+
+def test_post_files(tmpdir):
+    path = os.path.join(tmpdir, "test.txt")
+    with open(path, "wb") as file:
+        file.write(b"<file content>")
+    path2 = os.path.join(tmpdir, "test2.txt")
+    with open(path2, "wb") as file:
+        file.write(b"<file content2>")
+
+    client = TestClient(app)
+    response = client.post(
+        "/files/",
+        files=(
+            ("files", ("test.txt", open(path, "rb"))),
+            ("files", ("test2.txt", open(path2, "rb"))),
+        ),
+    )
+    assert response.status_code == 200
+    assert response.json() == {"file_sizes": [14, 15]}
+
+
+def test_post_upload_file(tmpdir):
+    path = os.path.join(tmpdir, "test.txt")
+    with open(path, "wb") as file:
+        file.write(b"<file content>")
+    path2 = os.path.join(tmpdir, "test2.txt")
+    with open(path2, "wb") as file:
+        file.write(b"<file content2>")
+
+    client = TestClient(app)
+    response = client.post(
+        "/uploadfiles/",
+        files=(
+            ("files", ("test.txt", open(path, "rb"))),
+            ("files", ("test2.txt", open(path2, "rb"))),
+        ),
+    )
+    assert response.status_code == 200
+    assert response.json() == {"filenames": ["test.txt", "test2.txt"]}
+
+
+def test_get_root():
+    client = TestClient(app)
+    response = client.get("/")
+    assert response.status_code == 200
+    assert b"<form" in response.content