]> git.ipfire.org Git - thirdparty/fastapi/fastapi.git/commitdiff
:sparkles: Add support for strings and __future__ type annotations (#451)
authordmontagu <35119617+dmontagu@users.noreply.github.com>
Sun, 29 Sep 2019 21:19:09 +0000 (14:19 -0700)
committerSebastián Ramírez <tiangolo@gmail.com>
Sun, 29 Sep 2019 21:19:09 +0000 (16:19 -0500)
* Add support for strings and __future__ annotations

* Add comments indicating reason for string annotations

* Fix ignores (including removing some unused ignores)

fastapi/dependencies/utils.py
fastapi/openapi/models.py
fastapi/utils.py
tests/test_security_oauth2.py

index 7f0f59092233ef7a890f01bed4aad8abb889bbfb..852f1e0253dee82483ab157c76c676cd201c7d99 100644 (file)
@@ -26,7 +26,7 @@ from pydantic.error_wrappers import ErrorWrapper
 from pydantic.errors import MissingError
 from pydantic.fields import Field, Required, Shape
 from pydantic.schema import get_annotation_from_schema
-from pydantic.utils import lenient_issubclass
+from pydantic.utils import ForwardRef, evaluate_forwardref, lenient_issubclass
 from starlette.background import BackgroundTasks
 from starlette.concurrency import run_in_threadpool
 from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
@@ -171,6 +171,30 @@ def is_scalar_sequence_field(field: Field) -> bool:
     return False
 
 
+def get_typed_signature(call: Callable) -> inspect.Signature:
+    signature = inspect.signature(call)
+    globalns = getattr(call, "__globals__", {})
+    typed_params = [
+        inspect.Parameter(
+            name=param.name,
+            kind=param.kind,
+            default=param.default,
+            annotation=get_typed_annotation(param, globalns),
+        )
+        for param in signature.parameters.values()
+    ]
+    typed_signature = inspect.Signature(typed_params)
+    return typed_signature
+
+
+def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str, Any]) -> Any:
+    annotation = param.annotation
+    if isinstance(annotation, str):
+        annotation = ForwardRef(annotation)
+        annotation = evaluate_forwardref(annotation, globalns, globalns)
+    return annotation
+
+
 def get_dependant(
     *,
     path: str,
@@ -180,7 +204,7 @@ def get_dependant(
     use_cache: bool = True,
 ) -> Dependant:
     path_param_names = get_path_param_names(path)
-    endpoint_signature = inspect.signature(call)
+    endpoint_signature = get_typed_signature(call)
     signature_params = endpoint_signature.parameters
     dependant = Dependant(call=call, name=name, path=path, use_cache=use_cache)
     for param_name, param in signature_params.items():
@@ -329,8 +353,12 @@ async def solve_dependencies(
 ]:
     values: Dict[str, Any] = {}
     errors: List[ErrorWrapper] = []
-    response = response or Response(  # type: ignore
-        content=None, status_code=None, headers=None, media_type=None, background=None
+    response = response or Response(
+        content=None,
+        status_code=None,  # type: ignore
+        headers=None,
+        media_type=None,
+        background=None,
     )
     dependency_cache = dependency_cache or {}
     sub_dependant: Dependant
@@ -405,7 +433,7 @@ async def solve_dependencies(
     values.update(cookie_values)
     errors += path_errors + query_errors + header_errors + cookie_errors
     if dependant.body_params:
-        body_values, body_errors = await request_body_to_args(  # type: ignore # body_params checked above
+        body_values, body_errors = await request_body_to_args(  # body_params checked above
             required_params=dependant.body_params, received_body=body
         )
         values.update(body_values)
index 3dd9f04dc83856fabc8717c935bb8a4dc2f58978..e5c50070e04b66d3e289a9719c243358dd3d0422 100644 (file)
@@ -11,7 +11,7 @@ try:
     import email_validator
 
     assert email_validator  # make autoflake ignore the unused import
-    from pydantic.types import EmailStr  # type: ignore
+    from pydantic.types import EmailStr
 except ImportError:  # pragma: no cover
     logger.warning(
         "email-validator not installed, email fields will be treated as str.\n"
index 17a16b52276872ba6e8d80e74ca65107e44575bd..8cb0ec123b3774a53985a23be36fb75768550690 100644 (file)
@@ -58,10 +58,10 @@ def create_cloned_field(field: Field) -> Field:
     use_type = original_type
     if lenient_issubclass(original_type, BaseModel):
         original_type = cast(Type[BaseModel], original_type)
-        use_type = create_model(  # type: ignore
+        use_type = create_model(
             original_type.__name__,
             __config__=original_type.__config__,
-            __validators__=original_type.__validators__,
+            __validators__=original_type.__validators__,  # type: ignore
         )
         for f in original_type.__fields__.values():
             use_type.__fields__[f.name] = f
index 890613b2909341815fd830043d2d9ce73ad2a9fe..5cf2592f314dc570a2b739b6344205f2cabb58d8 100644 (file)
@@ -21,18 +21,21 @@ class User(BaseModel):
     username: str
 
 
-def get_current_user(oauth_header: str = Security(reusable_oauth2)):
+# Here we use string annotations to test them
+def get_current_user(oauth_header: "str" = Security(reusable_oauth2)):
     user = User(username=oauth_header)
     return user
 
 
 @app.post("/login")
-def read_current_user(form_data: OAuth2PasswordRequestFormStrict = Depends()):
+# Here we use string annotations to test them
+def read_current_user(form_data: "OAuth2PasswordRequestFormStrict" = Depends()):
     return form_data
 
 
 @app.get("/users/me")
-def read_current_user(current_user: User = Depends(get_current_user)):
+# Here we use string annotations to test them
+def read_current_user(current_user: "User" = Depends(get_current_user)):
     return current_user