from pydantic.errors import MissingError
from pydantic.fields import Field, Required, Shape
from pydantic.schema import get_annotation_from_schema
-from pydantic.utils import lenient_issubclass
+from pydantic.utils import ForwardRef, evaluate_forwardref, lenient_issubclass
from starlette.background import BackgroundTasks
from starlette.concurrency import run_in_threadpool
from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
return False
+def get_typed_signature(call: Callable) -> inspect.Signature:
+ signature = inspect.signature(call)
+ globalns = getattr(call, "__globals__", {})
+ typed_params = [
+ inspect.Parameter(
+ name=param.name,
+ kind=param.kind,
+ default=param.default,
+ annotation=get_typed_annotation(param, globalns),
+ )
+ for param in signature.parameters.values()
+ ]
+ typed_signature = inspect.Signature(typed_params)
+ return typed_signature
+
+
+def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str, Any]) -> Any:
+ annotation = param.annotation
+ if isinstance(annotation, str):
+ annotation = ForwardRef(annotation)
+ annotation = evaluate_forwardref(annotation, globalns, globalns)
+ return annotation
+
+
def get_dependant(
*,
path: str,
use_cache: bool = True,
) -> Dependant:
path_param_names = get_path_param_names(path)
- endpoint_signature = inspect.signature(call)
+ endpoint_signature = get_typed_signature(call)
signature_params = endpoint_signature.parameters
dependant = Dependant(call=call, name=name, path=path, use_cache=use_cache)
for param_name, param in signature_params.items():
]:
values: Dict[str, Any] = {}
errors: List[ErrorWrapper] = []
- response = response or Response( # type: ignore
- content=None, status_code=None, headers=None, media_type=None, background=None
+ response = response or Response(
+ content=None,
+ status_code=None, # type: ignore
+ headers=None,
+ media_type=None,
+ background=None,
)
dependency_cache = dependency_cache or {}
sub_dependant: Dependant
values.update(cookie_values)
errors += path_errors + query_errors + header_errors + cookie_errors
if dependant.body_params:
- body_values, body_errors = await request_body_to_args( # type: ignore # body_params checked above
+ body_values, body_errors = await request_body_to_args( # body_params checked above
required_params=dependant.body_params, received_body=body
)
values.update(body_values)
use_type = original_type
if lenient_issubclass(original_type, BaseModel):
original_type = cast(Type[BaseModel], original_type)
- use_type = create_model( # type: ignore
+ use_type = create_model(
original_type.__name__,
__config__=original_type.__config__,
- __validators__=original_type.__validators__,
+ __validators__=original_type.__validators__, # type: ignore
)
for f in original_type.__fields__.values():
use_type.__fields__[f.name] = f
username: str
-def get_current_user(oauth_header: str = Security(reusable_oauth2)):
+# Here we use string annotations to test them
+def get_current_user(oauth_header: "str" = Security(reusable_oauth2)):
user = User(username=oauth_header)
return user
@app.post("/login")
-def read_current_user(form_data: OAuth2PasswordRequestFormStrict = Depends()):
+# Here we use string annotations to test them
+def read_current_user(form_data: "OAuth2PasswordRequestFormStrict" = Depends()):
return form_data
@app.get("/users/me")
-def read_current_user(current_user: User = Depends(get_current_user)):
+# Here we use string annotations to test them
+def read_current_user(current_user: "User" = Depends(get_current_user)):
return current_user