import enum
import inspect
+import sys
from abc import ABC, abstractmethod # pylint: disable=[no-name-in-module]
from typing import Any, Callable, Dict, Generic, List, Optional, Set, Tuple, Type, TypeVar, Union, cast
raise DataValidationError(f"expected bool, found {type(obj)}", object_path)
def _create_literal(self, tp: Type[Any], obj: Any, object_path: str) -> Any:
- expected = get_generic_type_arguments(tp)
+ args = get_generic_type_arguments(tp)
+
+ expected = []
+ if sys.version_info < (3, 9):
+ for arg in args:
+ if is_literal(arg):
+ expected += get_generic_type_arguments(arg)
+ else:
+ expected.append(arg)
+ else:
+ expected = args
+
if obj in expected:
return obj
raise DataValidationError(f"'{obj}' does not match any of the expected values {expected}", object_path)
v: str
+class _TestLiteral(ConfigSchema):
+ v: Literal[Literal["lit1"], Literal["lit2"]]
+
+
+@pytest.mark.parametrize("val", ["lit1", "lit2"])
+def test_parsing_literal_valid(val: str):
+ assert _TestLiteral(parse_yaml(f"v: {val}")).v == val
+
+
+@pytest.mark.parametrize("val", ["invalid", "false", 1, "null"])
+def test_parsing_literal_invalid(val: str):
+ with raises(DataValidationError):
+ _TestLiteral(parse_yaml(f"v: {val}"))
+
+
@pytest.mark.parametrize("val,exp", [("false", False), ("true", True), ("False", False), ("True", True)])
def test_parsing_bool_valid(val: str, exp: bool):
assert _TestBool(parse_yaml(f"v: {val}")).v == exp