return values, errors
+def is_union_of_base_models(field_type: Any) -> bool:
+ """Check if field type is a Union where all members are BaseModel subclasses."""
+ from fastapi.types import UnionType
+
+ origin = get_origin(field_type)
+
+ # Check if it's a Union type (covers both typing.Union and types.UnionType in Python 3.10+)
+ if origin is not Union and origin is not UnionType:
+ return False
+
+ union_args = get_args(field_type)
+
+ for arg in union_args:
+ if not lenient_issubclass(arg, BaseModel):
+ return False
+
+ return True
+
+
def _should_embed_body_fields(fields: List[ModelField]) -> bool:
if not fields:
return False
# If it explicitly specifies it is embedded, it has to be embedded
if getattr(first_field.field_info, "embed", None):
return True
- # If it's a Form (or File) field, it has to be a BaseModel to be top level
+ # If it's a Form (or File) field, it has to be a BaseModel (or a union of BaseModels) to be top level
# otherwise it has to be embedded, so that the key value pair can be extracted
- if isinstance(first_field.field_info, params.Form) and not lenient_issubclass(
- first_field.type_, BaseModel
+ if (
+ isinstance(first_field.field_info, params.Form)
+ and not lenient_issubclass(first_field.type_, BaseModel)
+ and not is_union_of_base_models(first_field.type_)
):
return True
return False
--- /dev/null
+from typing import Union
+
+from fastapi import FastAPI, Form
+from fastapi.testclient import TestClient
+from pydantic import BaseModel
+from typing_extensions import Annotated
+
+app = FastAPI()
+
+
+class UserForm(BaseModel):
+ name: str
+ email: str
+
+
+class CompanyForm(BaseModel):
+ company_name: str
+ industry: str
+
+
+@app.post("/form-union/")
+def post_union_form(data: Annotated[Union[UserForm, CompanyForm], Form()]):
+ return {"received": data}
+
+
+client = TestClient(app)
+
+
+def test_post_user_form():
+ response = client.post(
+ "/form-union/", data={"name": "John Doe", "email": "john@example.com"}
+ )
+ assert response.status_code == 200, response.text
+ assert response.json() == {
+ "received": {"name": "John Doe", "email": "john@example.com"}
+ }
+
+
+def test_post_company_form():
+ response = client.post(
+ "/form-union/", data={"company_name": "Tech Corp", "industry": "Technology"}
+ )
+ assert response.status_code == 200, response.text
+ assert response.json() == {
+ "received": {"company_name": "Tech Corp", "industry": "Technology"}
+ }
+
+
+def test_invalid_form_data():
+ response = client.post(
+ "/form-union/",
+ data={"name": "John", "company_name": "Tech Corp"},
+ )
+ assert response.status_code == 422, response.text
+
+
+def test_empty_form():
+ response = client.post("/form-union/")
+ assert response.status_code == 422, response.text
+
+
+def test_openapi_schema():
+ response = client.get("/openapi.json")
+ assert response.status_code == 200, response.text
+
+ assert response.json() == {
+ "openapi": "3.1.0",
+ "info": {"title": "FastAPI", "version": "0.1.0"},
+ "paths": {
+ "/form-union/": {
+ "post": {
+ "summary": "Post Union Form",
+ "operationId": "post_union_form_form_union__post",
+ "requestBody": {
+ "content": {
+ "application/x-www-form-urlencoded": {
+ "schema": {
+ "anyOf": [
+ {"$ref": "#/components/schemas/UserForm"},
+ {"$ref": "#/components/schemas/CompanyForm"},
+ ],
+ "title": "Data",
+ }
+ }
+ },
+ "required": True,
+ },
+ "responses": {
+ "200": {
+ "description": "Successful Response",
+ "content": {"application/json": {"schema": {}}},
+ },
+ "422": {
+ "description": "Validation Error",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/HTTPValidationError"
+ }
+ }
+ },
+ },
+ },
+ }
+ }
+ },
+ "components": {
+ "schemas": {
+ "CompanyForm": {
+ "properties": {
+ "company_name": {"type": "string", "title": "Company Name"},
+ "industry": {"type": "string", "title": "Industry"},
+ },
+ "type": "object",
+ "required": ["company_name", "industry"],
+ "title": "CompanyForm",
+ },
+ "HTTPValidationError": {
+ "properties": {
+ "detail": {
+ "items": {"$ref": "#/components/schemas/ValidationError"},
+ "type": "array",
+ "title": "Detail",
+ }
+ },
+ "type": "object",
+ "title": "HTTPValidationError",
+ },
+ "UserForm": {
+ "properties": {
+ "name": {"type": "string", "title": "Name"},
+ "email": {"type": "string", "title": "Email"},
+ },
+ "type": "object",
+ "required": ["name", "email"],
+ "title": "UserForm",
+ },
+ "ValidationError": {
+ "properties": {
+ "loc": {
+ "items": {
+ "anyOf": [{"type": "string"}, {"type": "integer"}]
+ },
+ "type": "array",
+ "title": "Location",
+ },
+ "msg": {"type": "string", "title": "Message"},
+ "type": {"type": "string", "title": "Error Type"},
+ },
+ "type": "object",
+ "required": ["loc", "msg", "type"],
+ "title": "ValidationError",
+ },
+ }
+ },
+ }