]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-137371: refactor and fortify `test.support.hashlib_helper` (#137375)
authorBénédikt Tran <10796600+picnixz@users.noreply.github.com>
Wed, 10 Sep 2025 07:14:38 +0000 (09:14 +0200)
committerGitHub <noreply@github.com>
Wed, 10 Sep 2025 07:14:38 +0000 (09:14 +0200)
Lib/test/support/hashlib_helper.py
Lib/test/test_hmac.py
Lib/test/test_support.py

index 96be74e4105c182b569802d0fb92d5a4797241ad..49077d7cb4d75796d86b9707f2392610df91a01d 100644 (file)
@@ -9,15 +9,68 @@ from test.support import import_helper
 from types import MappingProxyType
 
 
-def try_import_module(module_name):
-    """Try to import a module and return None on failure."""
+def _parse_fullname(fullname, *, strict=False):
+    """Parse a fully-qualified name ``<module_name>.<member_name>``.
+
+    The ``module_name`` component contains one or more dots.
+    The ``member_name`` component does not contain any dot.
+
+    If *strict* is true, *fullname* must be a string. Otherwise,
+    it can be None, and, ``module_name`` and ``member_name`` will
+    also be None.
+    """
+    if fullname is None:
+        assert not strict
+        return None, None
+    assert isinstance(fullname, str), fullname
+    assert fullname.count(".") >= 1, fullname
+    module_name, member_name = fullname.rsplit(".", maxsplit=1)
+    return module_name, member_name
+
+
+def _import_module(module_name, *, strict=False):
+    """Import a module from its fully-qualified name.
+
+    If *strict* is false, import failures are suppressed and None is returned.
+    """
+    if module_name is None:
+        # To prevent a TypeError in importlib.import_module
+        if strict:
+            raise ImportError("no module to import")
+        return None
     try:
         return importlib.import_module(module_name)
-    except ImportError:
+    except ImportError as exc:
+        if strict:
+            raise exc
+        return None
+
+
+def _import_member(module_name, member_name, *, strict=False):
+    """Import a member from a module.
+
+    If *strict* is false, import failures are suppressed and None is returned.
+    """
+    if member_name is None:
+        if strict:
+            raise ImportError(f"no member to import from {module_name}")
         return None
+    module = _import_module(module_name, strict=strict)
+    if strict:
+        return getattr(module, member_name)
+    return getattr(module, member_name, None)
+
+
+class Implementation(enum.StrEnum):
+    # Indicate that the hash function is implemented by a built-in module.
+    builtin = enum.auto()
+    # Indicate that the hash function is implemented by OpenSSL.
+    openssl = enum.auto()
+    # Indicate that the hash function is provided through the public API.
+    hashlib = enum.auto()
 
 
-class HID(enum.StrEnum):
+class _HashId(enum.StrEnum):
     """Enumeration containing the canonical digest names.
 
     Those names should only be used by hashlib.new() or hmac.new().
@@ -57,62 +110,162 @@ class HID(enum.StrEnum):
         return self.startswith("blake2")
 
 
-CANONICAL_DIGEST_NAMES = frozenset(map(str, HID.__members__))
+CANONICAL_DIGEST_NAMES = frozenset(map(str, _HashId.__members__))
 NON_HMAC_DIGEST_NAMES = frozenset((
-    HID.shake_128, HID.shake_256,
-    HID.blake2s, HID.blake2b,
+    _HashId.shake_128, _HashId.shake_256,
+    _HashId.blake2s, _HashId.blake2b,
 ))
 
 
-class HashInfo:
-    """Dataclass storing explicit hash constructor names.
+class _HashInfoItem:
+    """Interface for interacting with a named object.
 
-    - *builtin* is the fully-qualified name for the explicit HACL*
-      hash constructor function, e.g., "_md5.md5".
-
-    - *openssl* is the name of the "_hashlib" module method for the explicit
-      OpenSSL hash constructor function, e.g., "openssl_md5".
+    The object is entirely described by its fully-qualified *fullname*.
 
-    - *hashlib* is the name of the "hashlib" module method for the explicit
-      hash constructor function, e.g., "md5".
+    *fullname* must be None or a string "<module_name>.<member_name>".
     """
 
-    def __init__(self, builtin, openssl=None, hashlib=None):
-        assert isinstance(builtin, str), builtin
-        assert len(builtin.split(".")) == 2, builtin
+    def __init__(self, fullname=None, *, strict=False):
+        module_name, member_name = _parse_fullname(fullname, strict=strict)
+        self.fullname = fullname
+        self.module_name = module_name
+        self.member_name = member_name
+
+    def import_module(self, *, strict=False):
+        """Import the described module.
+
+        If *strict* is true, an ImportError may be raised if importing fails,
+        otherwise, None is returned on error.
+        """
+        return _import_module(self.module_name, strict=strict)
 
-        self.builtin = builtin
-        self.builtin_module_name, self.builtin_method_name = (
-            self.builtin.split(".", maxsplit=1)
+    def import_member(self, *, strict=False):
+        """Import the described member.
+
+        If *strict* is true, an AttributeError or an ImportError may be
+        raised if importing fails; otherwise, None is returned on error.
+        """
+        return _import_member(
+            self.module_name, self.member_name, strict=strict
         )
 
-        assert openssl is None or openssl.startswith("openssl_")
-        self.openssl = self.openssl_method_name = openssl
-        self.openssl_module_name = "_hashlib" if openssl else None
 
-        assert hashlib is None or isinstance(hashlib, str)
-        self.hashlib = self.hashlib_method_name = hashlib
-        self.hashlib_module_name = "hashlib" if hashlib else None
+class _HashInfoBase:
+    """Base dataclass containing "backend" information.
+
+    Subclasses may define an attribute named after one of the known
+    implementations ("builtin", "openssl" or "hashlib") which stores
+    an _HashInfoItem object.
+
+    Those attributes can be retrieved through __getitem__(), e.g.,
+    ``info["builtin"]`` returns the _HashInfoItem corresponding to
+    the builtin implementation.
+    """
+
+    def __init__(self, canonical_name):
+        assert isinstance(canonical_name, _HashId), canonical_name
+        self.canonical_name = canonical_name
+
+    def __getitem__(self, implementation):
+        try:
+            attrname = Implementation(implementation)
+        except ValueError:
+            raise self.invalid_implementation_error(implementation) from None
+
+        try:
+            provider = getattr(self, attrname)
+        except AttributeError:
+            raise self.invalid_implementation_error(implementation) from None
+
+        if not isinstance(provider, _HashInfoItem):
+            raise KeyError(implementation)
+        return provider
+
+    def invalid_implementation_error(self, implementation):
+        msg = f"no implementation {implementation} for {self.canonical_name}"
+        return AssertionError(msg)
+
+
+class _HashTypeInfo(_HashInfoBase):
+    """Dataclass containing information for hash functions types.
+
+    - *canonical_name* must be a _HashId.
+
+    - *builtin* is the fully-qualified name for the builtin HACL* type,
+      e.g., "_md5.MD5Type".
+
+    - *openssl* is the fully-qualified name for the OpenSSL wrapper type,
+      e.g., "_hashlib.HASH".
+    """
+
+    def __init__(self, canonical_name, builtin, openssl):
+        super().__init__(canonical_name)
+        self.builtin = _HashInfoItem(builtin, strict=True)
+        self.openssl = _HashInfoItem(openssl, strict=True)
+
+    def fullname(self, implementation):
+        """Get the fully qualified name of a given implementation.
+
+        This returns a string of the form "MODULE_NAME.OBJECT_NAME" or None
+        if the hash function does not have a corresponding implementation.
+
+        *implementation* must be "builtin" or "openssl".
+        """
+        return self[implementation].fullname
 
     def module_name(self, implementation):
-        match implementation:
-            case "builtin":
-                return self.builtin_module_name
-            case "openssl":
-                return self.openssl_module_name
-            case "hashlib":
-                return self.hashlib_module_name
-        raise AssertionError(f"invalid implementation {implementation}")
+        """Get the name of the module containing the hash object type."""
+        return self[implementation].module_name
+
+    def object_type_name(self, implementation):
+        """Get the name of the hash object class name."""
+        return self[implementation].member_name
+
+    def import_module(self, implementation, *, allow_skip=False):
+        """Import the module containing the hash object type.
+
+        On error, return None if *allow_skip* is false, or raise SkipNoHash.
+        """
+        target = self[implementation]
+        module = target.import_module()
+        if allow_skip and module is None:
+            reason = f"cannot import module {target.module_name}"
+            raise SkipNoHash(self.canonical_name, implementation, reason)
+        return module
+
+    def import_object_type(self, implementation, *, allow_skip=False):
+        """Get the runtime hash object type.
+
+        On error, return None if *allow_skip* is false, or raise SkipNoHash.
+        """
+        target = self[implementation]
+        member = target.import_member()
+        if allow_skip and member is None:
+            reason = f"cannot import class {target.fullname}"
+            raise SkipNoHash(self.canonical_name, implementation, reason)
+        return member
 
-    def method_name(self, implementation):
-        match implementation:
-            case "builtin":
-                return self.builtin_method_name
-            case "openssl":
-                return self.openssl_method_name
-            case "hashlib":
-                return self.hashlib_method_name
-        raise AssertionError(f"invalid implementation {implementation}")
+
+class _HashFuncInfo(_HashInfoBase):
+    """Dataclass containing information for hash functions constructors.
+
+    - *canonical_name* must be a _HashId.
+
+    - *builtin* is the fully-qualified name of the HACL*
+      hash constructor function, e.g., "_md5.md5".
+
+    - *openssl* is the fully-qualified name of the "_hashlib" method
+      for the OpenSSL named constructor, e.g., "_hashlib.openssl_md5".
+
+    - *hashlib* is the fully-qualified name of the "hashlib" method
+      for the explicit named hash constructor, e.g., "hashlib.md5".
+    """
+
+    def __init__(self, canonical_name, builtin, openssl=None, hashlib=None):
+        super().__init__(canonical_name)
+        self.builtin = _HashInfoItem(builtin, strict=True)
+        self.openssl = _HashInfoItem(openssl, strict=False)
+        self.hashlib = _HashInfoItem(hashlib, strict=False)
 
     def fullname(self, implementation):
         """Get the fully qualified name of a given implementation.
@@ -122,63 +275,239 @@ class HashInfo:
 
         *implementation* must be "builtin", "openssl" or "hashlib".
         """
-        module_name = self.module_name(implementation)
-        method_name = self.method_name(implementation)
-        if module_name is None or method_name is None:
-            return None
-        return f"{module_name}.{method_name}"
-
-
-# Mapping from a "canonical" name to a pair (HACL*, _hashlib.*, hashlib.*)
-# constructors. If the constructor name is None, then this means that the
-# algorithm can only be used by the "agile" new() interfaces.
-_EXPLICIT_CONSTRUCTORS = MappingProxyType({  # fmt: skip
-    HID.md5: HashInfo("_md5.md5", "openssl_md5", "md5"),
-    HID.sha1: HashInfo("_sha1.sha1", "openssl_sha1", "sha1"),
-    HID.sha224: HashInfo("_sha2.sha224", "openssl_sha224", "sha224"),
-    HID.sha256: HashInfo("_sha2.sha256", "openssl_sha256", "sha256"),
-    HID.sha384: HashInfo("_sha2.sha384", "openssl_sha384", "sha384"),
-    HID.sha512: HashInfo("_sha2.sha512", "openssl_sha512", "sha512"),
-    HID.sha3_224: HashInfo(
-        "_sha3.sha3_224", "openssl_sha3_224", "sha3_224"
+        return self[implementation].fullname
+
+    def module_name(self, implementation):
+        """Get the name of the constructor function module.
+
+        The *implementation* must be "builtin", "openssl" or "hashlib".
+        """
+        return self[implementation].module_name
+
+    def method_name(self, implementation):
+        """Get the name of the constructor function module method.
+
+        Use fullname() to get the constructor function fully-qualified name.
+
+        The *implementation* must be "builtin", "openssl" or "hashlib".
+        """
+        return self[implementation].member_name
+
+
+class _HashInfo:
+    """Dataclass containing information for supported hash functions.
+
+    Attributes
+    ----------
+    canonical_name : _HashId
+        The hash function canonical name.
+    type : _HashTypeInfo
+        The hash object types information.
+    func : _HashTypeInfo
+        The hash object constructors information.
+    """
+
+    def __init__(
+        self,
+        canonical_name,
+        builtin_object_type_fullname,
+        openssl_object_type_fullname,
+        builtin_method_fullname,
+        openssl_method_fullname=None,
+        hashlib_method_fullname=None,
+    ):
+        """
+        - *canonical_name* must be a _HashId.
+
+        - *builtin_object_type_fullname* is the fully-qualified name
+          for the builtin HACL* type, e.g., "_md5.MD5Type".
+
+        - *openssl_object_type_fullname* is the fully-qualified name
+          for the OpenSSL wrapper type, e.g., "_hashlib.HASH".
+
+        - *builtin_method_fullname* is the fully-qualified name
+          of the HACL* hash constructor function, e.g., "_md5.md5".
+
+        - *openssl_method_fullname* is the fully-qualified name
+          of the "_hashlib" module method for the explicit OpenSSL
+          hash constructor function, e.g., "_hashlib.openssl_md5".
+
+        - *hashlib_method_fullname* is the fully-qualified name
+          of the "hashlib"  module method for the explicit hash
+          constructor function, e.g., "hashlib.md5".
+        """
+        assert isinstance(canonical_name, _HashId), canonical_name
+        self.canonical_name = canonical_name
+        self.type = _HashTypeInfo(
+            canonical_name,
+            builtin_object_type_fullname,
+            openssl_object_type_fullname,
+        )
+        self.func = _HashFuncInfo(
+            canonical_name,
+            builtin_method_fullname,
+            openssl_method_fullname,
+            hashlib_method_fullname,
+        )
+
+
+_HASHINFO_DATABASE = MappingProxyType({
+    _HashId.md5: _HashInfo(
+        _HashId.md5,
+        "_md5.MD5Type",
+        "_hashlib.HASH",
+        "_md5.md5",
+        "_hashlib.openssl_md5",
+        "hashlib.md5",
+    ),
+    _HashId.sha1: _HashInfo(
+        _HashId.sha1,
+        "_sha1.SHA1Type",
+        "_hashlib.HASH",
+        "_sha1.sha1",
+        "_hashlib.openssl_sha1",
+        "hashlib.sha1",
+    ),
+    _HashId.sha224: _HashInfo(
+        _HashId.sha224,
+        "_sha2.SHA224Type",
+        "_hashlib.HASH",
+        "_sha2.sha224",
+        "_hashlib.openssl_sha224",
+        "hashlib.sha224",
+    ),
+    _HashId.sha256: _HashInfo(
+        _HashId.sha256,
+        "_sha2.SHA256Type",
+        "_hashlib.HASH",
+        "_sha2.sha256",
+        "_hashlib.openssl_sha256",
+        "hashlib.sha256",
+    ),
+    _HashId.sha384: _HashInfo(
+        _HashId.sha384,
+        "_sha2.SHA384Type",
+        "_hashlib.HASH",
+        "_sha2.sha384",
+        "_hashlib.openssl_sha384",
+        "hashlib.sha384",
     ),
-    HID.sha3_256: HashInfo(
-        "_sha3.sha3_256", "openssl_sha3_256", "sha3_256"
+    _HashId.sha512: _HashInfo(
+        _HashId.sha512,
+        "_sha2.SHA512Type",
+        "_hashlib.HASH",
+        "_sha2.sha512",
+        "_hashlib.openssl_sha512",
+        "hashlib.sha512",
     ),
-    HID.sha3_384: HashInfo(
-        "_sha3.sha3_384", "openssl_sha3_384", "sha3_384"
+    _HashId.sha3_224: _HashInfo(
+        _HashId.sha3_224,
+        "_sha3.sha3_224",
+        "_hashlib.HASH",
+        "_sha3.sha3_224",
+        "_hashlib.openssl_sha3_224",
+        "hashlib.sha3_224",
     ),
-    HID.sha3_512: HashInfo(
-        "_sha3.sha3_512", "openssl_sha3_512", "sha3_512"
+    _HashId.sha3_256: _HashInfo(
+        _HashId.sha3_256,
+        "_sha3.sha3_256",
+        "_hashlib.HASH",
+        "_sha3.sha3_256",
+        "_hashlib.openssl_sha3_256",
+        "hashlib.sha3_256",
     ),
-    HID.shake_128: HashInfo(
-        "_sha3.shake_128", "openssl_shake_128", "shake_128"
+    _HashId.sha3_384: _HashInfo(
+        _HashId.sha3_384,
+        "_sha3.sha3_384",
+        "_hashlib.HASH",
+        "_sha3.sha3_384",
+        "_hashlib.openssl_sha3_384",
+        "hashlib.sha3_384",
     ),
-    HID.shake_256: HashInfo(
-        "_sha3.shake_256", "openssl_shake_256", "shake_256"
+    _HashId.sha3_512: _HashInfo(
+        _HashId.sha3_512,
+        "_sha3.sha3_512",
+        "_hashlib.HASH",
+        "_sha3.sha3_512",
+        "_hashlib.openssl_sha3_512",
+        "hashlib.sha3_512",
+    ),
+    _HashId.shake_128: _HashInfo(
+        _HashId.shake_128,
+        "_sha3.shake_128",
+        "_hashlib.HASHXOF",
+        "_sha3.shake_128",
+        "_hashlib.openssl_shake_128",
+        "hashlib.shake_128",
+    ),
+    _HashId.shake_256: _HashInfo(
+        _HashId.shake_256,
+        "_sha3.shake_256",
+        "_hashlib.HASHXOF",
+        "_sha3.shake_256",
+        "_hashlib.openssl_shake_256",
+        "hashlib.shake_256",
+    ),
+    _HashId.blake2s: _HashInfo(
+        _HashId.blake2s,
+        "_blake2.blake2s",
+        "_hashlib.HASH",
+        "_blake2.blake2s",
+        None,
+        "hashlib.blake2s",
+    ),
+    _HashId.blake2b: _HashInfo(
+        _HashId.blake2b,
+        "_blake2.blake2b",
+        "_hashlib.HASH",
+        "_blake2.blake2b",
+        None,
+        "hashlib.blake2b",
     ),
-    HID.blake2s: HashInfo("_blake2.blake2s", None, "blake2s"),
-    HID.blake2b: HashInfo("_blake2.blake2b", None, "blake2b"),
 })
-assert _EXPLICIT_CONSTRUCTORS.keys() == CANONICAL_DIGEST_NAMES
-get_hash_info = _EXPLICIT_CONSTRUCTORS.__getitem__
+assert _HASHINFO_DATABASE.keys() == CANONICAL_DIGEST_NAMES
+
+
+def get_hash_type_info(name):
+    info = _HASHINFO_DATABASE[name]
+    assert isinstance(info, _HashInfo), info
+    return info.type
+
+
+def get_hash_func_info(name):
+    info = _HASHINFO_DATABASE[name]
+    assert isinstance(info, _HashInfo), info
+    return info.func
+
+
+def _iter_hash_func_info(excluded):
+    for name, info in _HASHINFO_DATABASE.items():
+        if name not in excluded:
+            yield info.func
+
 
 # Mapping from canonical hash names to their explicit HACL* HMAC constructor.
 # There is currently no OpenSSL one-shot named function and there will likely
 # be none in the future.
-_EXPLICIT_HMAC_CONSTRUCTORS = {
-    HID(name): f"_hmac.compute_{name}"
-    for name in CANONICAL_DIGEST_NAMES
+_HMACINFO_DATABASE = {
+    _HashId(canonical_name): _HashInfoItem(f"_hmac.compute_{canonical_name}")
+    for canonical_name in CANONICAL_DIGEST_NAMES
 }
 # Neither HACL* nor OpenSSL supports HMAC over XOFs.
-_EXPLICIT_HMAC_CONSTRUCTORS[HID.shake_128] = None
-_EXPLICIT_HMAC_CONSTRUCTORS[HID.shake_256] = None
+_HMACINFO_DATABASE[_HashId.shake_128] = _HashInfoItem()
+_HMACINFO_DATABASE[_HashId.shake_256] = _HashInfoItem()
 # Strictly speaking, HMAC-BLAKE is meaningless as BLAKE2 is already a
 # keyed hash function. However, as it's exposed by HACL*, we test it.
-_EXPLICIT_HMAC_CONSTRUCTORS[HID.blake2s] = '_hmac.compute_blake2s_32'
-_EXPLICIT_HMAC_CONSTRUCTORS[HID.blake2b] = '_hmac.compute_blake2b_32'
-_EXPLICIT_HMAC_CONSTRUCTORS = MappingProxyType(_EXPLICIT_HMAC_CONSTRUCTORS)
-assert _EXPLICIT_HMAC_CONSTRUCTORS.keys() == CANONICAL_DIGEST_NAMES
+_HMACINFO_DATABASE[_HashId.blake2s] = _HashInfoItem('_hmac.compute_blake2s_32')
+_HMACINFO_DATABASE[_HashId.blake2b] = _HashInfoItem('_hmac.compute_blake2b_32')
+_HMACINFO_DATABASE = MappingProxyType(_HMACINFO_DATABASE)
+assert _HMACINFO_DATABASE.keys() == CANONICAL_DIGEST_NAMES
+
+
+def get_hmac_item_info(name):
+    info = _HMACINFO_DATABASE[name]
+    assert isinstance(info, _HashInfoItem), info
+    return info
 
 
 def _decorate_func_or_class(decorator_func, func_or_class):
@@ -230,26 +559,42 @@ def _ensure_wrapper_signature(wrapper, wrapped):
         )
 
 
-def requires_hashlib():
-    _hashlib = try_import_module("_hashlib")
+def _make_conditional_decorator(test, /, *test_args, **test_kwargs):
+    def decorator_func(func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            test(*test_args, **test_kwargs)
+            return func(*args, **kwargs)
+        return wrapper
+    return functools.partial(_decorate_func_or_class, decorator_func)
+
+
+def requires_openssl_hashlib():
+    _hashlib = _import_module("_hashlib")
     return unittest.skipIf(_hashlib is None, "requires _hashlib")
 
 
 def requires_builtin_hmac():
-    _hmac = try_import_module("_hmac")
+    _hmac = _import_module("_hmac")
     return unittest.skipIf(_hmac is None, "requires _hmac")
 
 
 class SkipNoHash(unittest.SkipTest):
     """A SkipTest exception raised when a hash is not available."""
 
-    def __init__(self, digestname, implementation=None, interface=None):
+    def __init__(self, digestname, implementation=None, reason=None):
         parts = ["missing", implementation, f"hash algorithm {digestname!r}"]
-        if interface is not None:
-            parts.append(f"for {interface}")
+        if reason is not None:
+            parts.insert(0, f"{reason}: ")
         super().__init__(" ".join(filter(None, parts)))
 
 
+class SkipNoHashInCall(SkipNoHash):
+
+    def __init__(self, func, digestname, implementation=None):
+        super().__init__(digestname, implementation, f"cannot use {func}")
+
+
 def _hashlib_new(digestname, openssl, /, **kwargs):
     """Check availability of [hashlib|_hashlib].new(digestname, **kwargs).
 
@@ -264,13 +609,12 @@ def _hashlib_new(digestname, openssl, /, **kwargs):
     # exceptions as it should be unconditionally available.
     hashlib = importlib.import_module("hashlib")
     # re-import '_hashlib' in case it was mocked
-    _hashlib = try_import_module("_hashlib")
+    _hashlib = _import_module("_hashlib")
     module = _hashlib if openssl and _hashlib is not None else hashlib
     try:
         module.new(digestname, **kwargs)
     except ValueError as exc:
-        interface = f"{module.__name__}.new"
-        raise SkipNoHash(digestname, interface=interface) from exc
+        raise SkipNoHashInCall(f"{module.__name__}.new", digestname) from exc
     return functools.partial(module.new, digestname)
 
 
@@ -315,7 +659,7 @@ def _openssl_new(digestname, /, **kwargs):
     try:
         _hashlib.new(digestname, **kwargs)
     except ValueError as exc:
-        raise SkipNoHash(digestname, interface="_hashlib.new") from exc
+        raise SkipNoHashInCall("_hashlib.new", digestname) from exc
     return functools.partial(_hashlib.new, digestname)
 
 
@@ -326,14 +670,15 @@ def _openssl_hash(digestname, /, **kwargs):
     or SkipTest is raised if none exists.
     """
     assert isinstance(digestname, str), digestname
-    fullname = f"_hashlib.openssl_{digestname}"
+    method_name = f"openssl_{digestname}"
+    fullname = f"_hashlib.{method_name}"
     try:
         # re-import '_hashlib' in case it was mocked
         _hashlib = importlib.import_module("_hashlib")
     except ImportError as exc:
         raise SkipNoHash(fullname, "openssl") from exc
     try:
-        constructor = getattr(_hashlib, f"openssl_{digestname}", None)
+        constructor = getattr(_hashlib, method_name)
     except AttributeError as exc:
         raise SkipNoHash(fullname, "openssl") from exc
     try:
@@ -343,16 +688,6 @@ def _openssl_hash(digestname, /, **kwargs):
     return constructor
 
 
-def _make_requires_hashdigest_decorator(test, /, *test_args, **test_kwargs):
-    def decorator_func(func):
-        @functools.wraps(func)
-        def wrapper(*args, **kwargs):
-            test(*test_args, **test_kwargs)
-            return func(*args, **kwargs)
-        return wrapper
-    return functools.partial(_decorate_func_or_class, decorator_func)
-
-
 def requires_hashdigest(digestname, openssl=None, *, usedforsecurity=True):
     """Decorator raising SkipTest if a hashing algorithm is not available.
 
@@ -370,7 +705,7 @@ def requires_hashdigest(digestname, openssl=None, *, usedforsecurity=True):
     ValueError: [digital envelope routines: EVP_DigestInit_ex] disabled for FIPS
     ValueError: unsupported hash type md4
     """
-    return _make_requires_hashdigest_decorator(
+    return _make_conditional_decorator(
         _hashlib_new, digestname, openssl, usedforsecurity=usedforsecurity
     )
 
@@ -380,34 +715,35 @@ def requires_openssl_hashdigest(digestname, *, usedforsecurity=True):
 
     The hashing algorithm may be missing or blocked by a strict crypto policy.
     """
-    return _make_requires_hashdigest_decorator(
+    return _make_conditional_decorator(
         _openssl_new, digestname, usedforsecurity=usedforsecurity
     )
 
 
-def requires_builtin_hashdigest(
-    module_name, digestname, *, usedforsecurity=True
-):
-    """Decorator raising SkipTest if a HACL* hashing algorithm is missing.
+def _make_requires_builtin_hashdigest_decorator(item, *, usedforsecurity=True):
+    assert isinstance(item, _HashInfoItem), item
+    return _make_conditional_decorator(
+        _builtin_hash,
+        item.module_name,
+        item.member_name,
+        usedforsecurity=usedforsecurity,
+    )
 
-    - The *module_name* is the C extension module name based on HACL*.
-    - The *digestname* is one of its member, e.g., 'md5'.
-    """
-    return _make_requires_hashdigest_decorator(
-        _builtin_hash, module_name, digestname, usedforsecurity=usedforsecurity
+
+def requires_builtin_hashdigest(canonical_name, *, usedforsecurity=True):
+    """Decorator raising SkipTest if a HACL* hashing algorithm is missing."""
+    info = get_hash_func_info(canonical_name)
+    return _make_requires_builtin_hashdigest_decorator(
+        info.builtin, usedforsecurity=usedforsecurity
     )
 
 
-def requires_builtin_hashes(*ignored, usedforsecurity=True):
+def requires_builtin_hashes(*, exclude=(), usedforsecurity=True):
     """Decorator raising SkipTest if one HACL* hashing algorithm is missing."""
     return _chain_decorators((
-        requires_builtin_hashdigest(
-            api.builtin_module_name,
-            api.builtin_method_name,
-            usedforsecurity=usedforsecurity,
-        )
-        for name, api in _EXPLICIT_CONSTRUCTORS.items()
-        if name not in ignored
+        _make_requires_builtin_hashdigest_decorator(
+            info.builtin, usedforsecurity=usedforsecurity
+        ) for info in _iter_hash_func_info(exclude)
     ))
 
 
@@ -424,69 +760,31 @@ class HashFunctionsTrait:
     implementation of HMAC).
     """
 
-    DIGEST_NAMES = [
-        'md5', 'sha1',
-        'sha224', 'sha256', 'sha384', 'sha512',
-        'sha3_224', 'sha3_256', 'sha3_384', 'sha3_512',
-    ]
-
     # Default 'usedforsecurity' to use when checking a hash function.
     # When the trait properties are callables (e.g., _md5.md5) and
     # not strings, they must be called with the same 'usedforsecurity'.
     usedforsecurity = True
 
-    @classmethod
-    def setUpClass(cls):
-        super().setUpClass()
-        assert CANONICAL_DIGEST_NAMES.issuperset(cls.DIGEST_NAMES)
-
     def is_valid_digest_name(self, digestname):
-        self.assertIn(digestname, self.DIGEST_NAMES)
+        self.assertIn(digestname, _HashId)
 
     def _find_constructor(self, digestname):
         # By default, a missing algorithm skips the test that uses it.
         self.is_valid_digest_name(digestname)
         self.skipTest(f"missing hash function: {digestname}")
 
-    @property
-    def md5(self):
-        return self._find_constructor("md5")
+    md5 = property(lambda self: self._find_constructor("md5"))
+    sha1 = property(lambda self: self._find_constructor("sha1"))
 
-    @property
-    def sha1(self):
-        return self._find_constructor("sha1")
-
-    @property
-    def sha224(self):
-        return self._find_constructor("sha224")
-
-    @property
-    def sha256(self):
-        return self._find_constructor("sha256")
-
-    @property
-    def sha384(self):
-        return self._find_constructor("sha384")
-
-    @property
-    def sha512(self):
-        return self._find_constructor("sha512")
-
-    @property
-    def sha3_224(self):
-        return self._find_constructor("sha3_224")
-
-    @property
-    def sha3_256(self):
-        return self._find_constructor("sha3_256")
-
-    @property
-    def sha3_384(self):
-        return self._find_constructor("sha3_384")
+    sha224 = property(lambda self: self._find_constructor("sha224"))
+    sha256 = property(lambda self: self._find_constructor("sha256"))
+    sha384 = property(lambda self: self._find_constructor("sha384"))
+    sha512 = property(lambda self: self._find_constructor("sha512"))
 
-    @property
-    def sha3_512(self):
-        return self._find_constructor("sha3_512")
+    sha3_224 = property(lambda self: self._find_constructor("sha3_224"))
+    sha3_256 = property(lambda self: self._find_constructor("sha3_256"))
+    sha3_384 = property(lambda self: self._find_constructor("sha3_384"))
+    sha3_512 = property(lambda self: self._find_constructor("sha3_512"))
 
 
 class NamedHashFunctionsTrait(HashFunctionsTrait):
@@ -497,7 +795,7 @@ class NamedHashFunctionsTrait(HashFunctionsTrait):
 
     def _find_constructor(self, digestname):
         self.is_valid_digest_name(digestname)
-        return digestname
+        return str(digestname)  # ensure that we are an exact string
 
 
 class OpenSSLHashFunctionsTrait(HashFunctionsTrait):
@@ -523,10 +821,10 @@ class BuiltinHashFunctionsTrait(HashFunctionsTrait):
 
     def _find_constructor(self, digestname):
         self.is_valid_digest_name(digestname)
-        info = _EXPLICIT_CONSTRUCTORS[digestname]
+        info = get_hash_func_info(digestname)
         return _builtin_hash(
-            info.builtin_module_name,
-            info.builtin_method_name,
+            info.builtin.module_name,
+            info.builtin.member_name,
             usedforsecurity=self.usedforsecurity,
         )
 
@@ -542,7 +840,7 @@ def find_gil_minsize(modules_names, default=2048):
     """
     sizes = []
     for module_name in modules_names:
-        module = try_import_module(module_name)
+        module = _import_module(module_name)
         if module is not None:
             sizes.append(getattr(module, '_GIL_MINSIZE', default))
     return max(sizes, default=default)
@@ -553,7 +851,7 @@ def _block_openssl_hash_new(blocked_name):
     assert isinstance(blocked_name, str), blocked_name
 
     # re-import '_hashlib' in case it was mocked
-    if (_hashlib := try_import_module("_hashlib")) is None:
+    if (_hashlib := _import_module("_hashlib")) is None:
         return contextlib.nullcontext()
 
     @functools.wraps(wrapped := _hashlib.new)
@@ -572,7 +870,7 @@ def _block_openssl_hmac_new(blocked_name):
     assert isinstance(blocked_name, str), blocked_name
 
     # re-import '_hashlib' in case it was mocked
-    if (_hashlib := try_import_module("_hashlib")) is None:
+    if (_hashlib := _import_module("_hashlib")) is None:
         return contextlib.nullcontext()
 
     @functools.wraps(wrapped := _hashlib.hmac_new)
@@ -590,7 +888,7 @@ def _block_openssl_hmac_digest(blocked_name):
     assert isinstance(blocked_name, str), blocked_name
 
     # re-import '_hashlib' in case it was mocked
-    if (_hashlib := try_import_module("_hashlib")) is None:
+    if (_hashlib := _import_module("_hashlib")) is None:
         return contextlib.nullcontext()
 
     @functools.wraps(wrapped := _hashlib.hmac_digest)
@@ -607,7 +905,7 @@ def _block_builtin_hash_new(name):
     """Block a buitin-in hash name from the hashlib.new() interface."""
     assert isinstance(name, str), name
     assert name.lower() == name, f"invalid name: {name}"
-    assert name in HID, f"invalid hash: {name}"
+    assert name in _HashId, f"invalid hash: {name}"
 
     # Re-import 'hashlib' in case it was mocked
     hashlib = importlib.import_module('hashlib')
@@ -620,7 +918,7 @@ def _block_builtin_hash_new(name):
     # so we need to block the possibility of importing it, but only
     # during the call to __get_builtin_constructor().
     get_builtin_constructor = getattr(hashlib, '__get_builtin_constructor')
-    builtin_module_name = _EXPLICIT_CONSTRUCTORS[name].builtin_module_name
+    builtin_module_name = get_hash_func_info(name).builtin.module_name
 
     @functools.wraps(get_builtin_constructor)
     def get_builtin_constructor_mock(name):
@@ -632,7 +930,7 @@ def _block_builtin_hash_new(name):
     return unittest.mock.patch.multiple(
         hashlib,
         __get_builtin_constructor=get_builtin_constructor_mock,
-        __builtin_constructor_cache=builtin_constructor_cache_mock
+        __builtin_constructor_cache=builtin_constructor_cache_mock,
     )
 
 
@@ -640,7 +938,7 @@ def _block_builtin_hmac_new(blocked_name):
     assert isinstance(blocked_name, str), blocked_name
 
     # re-import '_hmac' in case it was mocked
-    if (_hmac := try_import_module("_hmac")) is None:
+    if (_hmac := _import_module("_hmac")) is None:
         return contextlib.nullcontext()
 
     @functools.wraps(wrapped := _hmac.new)
@@ -657,7 +955,7 @@ def _block_builtin_hmac_digest(blocked_name):
     assert isinstance(blocked_name, str), blocked_name
 
     # re-import '_hmac' in case it was mocked
-    if (_hmac := try_import_module("_hmac")) is None:
+    if (_hmac := _import_module("_hmac")) is None:
         return contextlib.nullcontext()
 
     @functools.wraps(wrapped := _hmac.compute_digest)
@@ -671,30 +969,19 @@ def _block_builtin_hmac_digest(blocked_name):
 
 
 def _make_hash_constructor_blocker(name, dummy, implementation):
-    info = _EXPLICIT_CONSTRUCTORS[name]
-    module_name = info.module_name(implementation)
-    method_name = info.method_name(implementation)
-    if module_name is None or method_name is None:
+    info = get_hash_func_info(name)[implementation]
+    if (wrapped := info.import_member()) is None:
         # function shouldn't exist for this implementation
         return contextlib.nullcontext()
-
-    try:
-        module = importlib.import_module(module_name)
-    except ImportError:
-        # module is already disabled
-        return contextlib.nullcontext()
-
-    wrapped = getattr(module, method_name)
     wrapper = functools.wraps(wrapped)(dummy)
     _ensure_wrapper_signature(wrapper, wrapped)
-    return unittest.mock.patch(info.fullname(implementation), wrapper)
+    return unittest.mock.patch(info.fullname, wrapper)
 
 
 def _block_hashlib_hash_constructor(name):
     """Block explicit public constructors."""
     def dummy(data=b'', *, usedforsecurity=True, string=None):
         raise ValueError(f"blocked explicit public hash name: {name}")
-
     return _make_hash_constructor_blocker(name, dummy, 'hashlib')
 
 
@@ -714,23 +1001,18 @@ def _block_builtin_hash_constructor(name):
 
 def _block_builtin_hmac_constructor(name):
     """Block explicit HACL* HMAC constructors."""
-    fullname = _EXPLICIT_HMAC_CONSTRUCTORS[name]
-    if fullname is None:
+    info = get_hmac_item_info(name)
+    assert info.module_name is None or info.module_name == "_hmac", info
+    if (wrapped := info.import_member()) is None:
         # function shouldn't exist for this implementation
         return contextlib.nullcontext()
-    assert fullname.count('.') == 1, fullname
-    module_name, method = fullname.split('.', maxsplit=1)
-    assert module_name == '_hmac', module_name
-    try:
-        module = importlib.import_module(module_name)
-    except ImportError:
-        # module is already disabled
-        return contextlib.nullcontext()
-    @functools.wraps(wrapped := getattr(module, method))
+
+    @functools.wraps(wrapped)
     def wrapper(key, obj):
         raise ValueError(f"blocked hash name: {name}")
+
     _ensure_wrapper_signature(wrapper, wrapped)
-    return unittest.mock.patch(fullname, wrapper)
+    return unittest.mock.patch(info.fullname, wrapper)
 
 
 @contextlib.contextmanager
@@ -760,14 +1042,14 @@ def block_algorithm(name, *, allow_openssl=False, allow_builtin=False):
             # the OpenSSL implementation, except with usedforsecurity=False.
             # However, blocking such functions also means blocking them
             # so we again need to block them if we want to.
-            (_hashlib := try_import_module("_hashlib"))
+            (_hashlib := _import_module("_hashlib"))
             and _hashlib.get_fips_mode()
             and not allow_openssl
         ) or (
             # Without OpenSSL, hashlib.<name>() functions are aliases
             # to built-in functions, so both of them must be blocked
             # as the module may have been imported before the HACL ones.
-            not (_hashlib := try_import_module("_hashlib"))
+            not (_hashlib := _import_module("_hashlib"))
             and not allow_builtin
         ):
             stack.enter_context(_block_hashlib_hash_constructor(name))
@@ -794,3 +1076,21 @@ def block_algorithm(name, *, allow_openssl=False, allow_builtin=False):
             # _hmac.compute_digest(..., name)
             stack.enter_context(_block_builtin_hmac_digest(name))
         yield
+
+
+@contextlib.contextmanager
+def block_openssl_algorithms(*, exclude=()):
+    """Block OpenSSL implementations, except those given in *exclude*."""
+    with contextlib.ExitStack() as stack:
+        for name in CANONICAL_DIGEST_NAMES.difference(exclude):
+            stack.enter_context(block_algorithm(name, allow_builtin=True))
+        yield
+
+
+@contextlib.contextmanager
+def block_builtin_algorithms(*, exclude=()):
+    """Block HACL* implementations, except those given in *exclude*."""
+    with contextlib.ExitStack() as stack:
+        for name in CANONICAL_DIGEST_NAMES.difference(exclude):
+            stack.enter_context(block_algorithm(name, allow_openssl=True))
+        yield
index 5c29369d10b1432fe61177b712e41a6fa3d4b3b9..7634deeb1d8eb9be1207189259af9c8653e272ef 100644 (file)
@@ -161,7 +161,7 @@ class ThroughModuleAPIMixin(ModuleMixin, CreatorMixin, DigestMixin):
         return _call_digest_func(self.hmac.digest, key, msg, digestmod)
 
 
-@hashlib_helper.requires_hashlib()
+@hashlib_helper.requires_openssl_hashlib()
 class ThroughOpenSSLAPIMixin(CreatorMixin, DigestMixin):
     """Mixin delegating to _hashlib.hmac_new() and _hashlib.hmac_digest()."""
 
@@ -1431,7 +1431,7 @@ class HMACCompareDigestTestCase(CompareDigestMixin, unittest.TestCase):
             self.assertIs(self.compare_digest, operator_compare_digest)
 
 
-@hashlib_helper.requires_hashlib()
+@hashlib_helper.requires_openssl_hashlib()
 class OpenSSLCompareDigestTestCase(CompareDigestMixin, unittest.TestCase):
     compare_digest = openssl_compare_digest
 
@@ -1509,7 +1509,7 @@ class PyMiscellaneousTests(unittest.TestCase):
         hmac = import_fresh_module("hmac", blocked=["_hmac"])
         self.do_test_hmac_digest_overflow_error_switch_to_slow(hmac, size)
 
-    @hashlib_helper.requires_builtin_hashdigest("_md5", "md5")
+    @hashlib_helper.requires_builtin_hashdigest("md5")
     @bigmemtest(size=_4G + 5, memuse=2, dry_run=False)
     def test_hmac_digest_overflow_error_builtin_only(self, size):
         hmac = import_fresh_module("hmac", blocked=["_hashlib"])
index 12361aa4e518a605617db178946ae45d4a03d1eb..711691348ffe7c7394571569a1d635b81d9bafd6 100644 (file)
@@ -866,22 +866,22 @@ class TestHashlibSupport(unittest.TestCase):
             return default
 
     def fetch_hash_function(self, name, implementation):
-        info = hashlib_helper.get_hash_info(name)
-        match implementation:
-            case "hashlib":
-                assert info.hashlib is not None, info
-                return getattr(self.hashlib, info.hashlib)
-            case "openssl":
-                try:
-                    return getattr(self._hashlib, info.openssl, None)
-                except TypeError:
-                    return None
-        fullname = info.fullname(implementation)
+        info = hashlib_helper.get_hash_func_info(name)
+        match hashlib_helper.Implementation(implementation):
+            case hashlib_helper.Implementation.hashlib:
+                method_name = info.hashlib.member_name
+                assert isinstance(method_name, str), method_name
+                return getattr(self.hashlib, method_name)
+            case hashlib_helper.Implementation.openssl:
+                method_name = info.openssl.member_name
+                assert isinstance(method_name, str | None), method_name
+                return getattr(self._hashlib, method_name or "", None)
+        fullname = info[implementation].fullname
         return self.try_import_attribute(fullname)
 
     def fetch_hmac_function(self, name):
-        fullname = hashlib_helper._EXPLICIT_HMAC_CONSTRUCTORS[name]
-        return self.try_import_attribute(fullname)
+        target = hashlib_helper.get_hmac_item_info(name)
+        return target.import_member()
 
     def check_openssl_hash(self, name, *, disabled=True):
         """Check that OpenSSL HASH interface is enabled/disabled."""