import re
import warnings
from collections.abc import Sequence
-from copy import copy, deepcopy
+from copy import copy
from dataclasses import dataclass, is_dataclass
from enum import Enum
from functools import lru_cache
values: dict[str, Any] = {}, # noqa: B006
*,
loc: tuple[Union[int, str], ...] = (),
- ) -> tuple[Any, Union[list[dict[str, Any]], None]]:
+ ) -> tuple[Any, list[dict[str, Any]]]:
try:
return (
self._type_adapter.validate_python(value, from_attributes=True),
- None,
+ [],
)
except ValidationError as exc:
return None, _regenerate_error_with_loc(
if "description" in item_def:
item_description = cast(str, item_def["description"]).split("\f")[0]
item_def["description"] = item_description
- new_mapping, new_definitions = _remap_definitions_and_field_mappings(
- model_name_map=model_name_map,
- definitions=definitions, # type: ignore[arg-type]
- field_mapping=field_mapping,
- )
- return new_mapping, new_definitions
-
-
-def _replace_refs(
- *,
- schema: dict[str, Any],
- old_name_to_new_name_map: dict[str, str],
-) -> dict[str, Any]:
- new_schema = deepcopy(schema)
- for key, value in new_schema.items():
- if key == "$ref":
- value = schema["$ref"]
- if isinstance(value, str):
- ref_name = schema["$ref"].split("/")[-1]
- if ref_name in old_name_to_new_name_map:
- new_name = old_name_to_new_name_map[ref_name]
- new_schema["$ref"] = REF_TEMPLATE.format(model=new_name)
- continue
- if isinstance(value, dict):
- new_schema[key] = _replace_refs(
- schema=value,
- old_name_to_new_name_map=old_name_to_new_name_map,
- )
- elif isinstance(value, list):
- new_value = []
- for item in value:
- if isinstance(item, dict):
- new_item = _replace_refs(
- schema=item,
- old_name_to_new_name_map=old_name_to_new_name_map,
- )
- new_value.append(new_item)
-
- else:
- new_value.append(item)
- new_schema[key] = new_value
- return new_schema
-
-
-def _remap_definitions_and_field_mappings(
- *,
- model_name_map: ModelNameMap,
- definitions: dict[str, Any],
- field_mapping: dict[
- tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
- ],
-) -> tuple[
- dict[tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue],
- dict[str, Any],
-]:
- old_name_to_new_name_map = {}
- for field_key, schema in field_mapping.items():
- model = field_key[0].type_
- if model not in model_name_map or "$ref" not in schema:
- continue
- new_name = model_name_map[model]
- old_name = schema["$ref"].split("/")[-1]
- if old_name in {f"{new_name}-Input", f"{new_name}-Output"}:
- continue
- old_name_to_new_name_map[old_name] = new_name
-
- new_field_mapping: dict[
- tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
- ] = {}
- for field_key, schema in field_mapping.items():
- new_schema = _replace_refs(
- schema=schema,
- old_name_to_new_name_map=old_name_to_new_name_map,
- )
- new_field_mapping[field_key] = new_schema
-
- new_definitions = {}
- for key, value in definitions.items():
- if key in old_name_to_new_name_map:
- new_key = old_name_to_new_name_map[key]
- else:
- new_key = key
- new_value = _replace_refs(
- schema=value,
- old_name_to_new_name_map=old_name_to_new_name_map,
- )
- new_definitions[new_key] = new_value
- return new_field_mapping, new_definitions
+ # definitions: dict[DefsRef, dict[str, Any]]
+ # but mypy complains about general str in other places that are not declared as
+ # DefsRef, although DefsRef is just str:
+ # DefsRef = NewType('DefsRef', str)
+ # So, a cast to simplify the types here
+ return field_mapping, cast(dict[str, dict[str, Any]], definitions)
def is_scalar_field(field: ModelField) -> bool:
return shared.sequence_annotation_to_type[origin_type](value) # type: ignore[no-any-return,index]
-def get_missing_field_error(loc: tuple[str, ...]) -> dict[str, Any]:
+def get_missing_field_error(loc: tuple[Union[int, str], ...]) -> dict[str, Any]:
error = ValidationError.from_exception_data(
"Field required", [{"type": "missing", "loc": loc, "input": {}}]
).errors(include_url=False)[0]
return {v: k for k, v in name_model_map.items()}
-def get_compat_model_name_map(fields: list[ModelField]) -> ModelNameMap:
- flat_models = get_flat_models_from_fields(fields, known_models=set())
- return get_model_name_map(flat_models)
-
-
def get_flat_models_from_model(
model: type["BaseModel"], known_models: Union[TypeModelSet, None] = None
) -> TypeModelSet: