from test.support import hashlib_helper
+from _operator import _compare_digest as operator_compare_digest
+
try:
from _hashlib import HMAC as C_HMAC
from _hashlib import hmac_new as c_hmac_new
+ from _hashlib import compare_digest as openssl_compare_digest
except ImportError:
C_HMAC = None
c_hmac_new = None
+ openssl_compare_digest = None
def ignore_warning(func):
class CompareDigestTestCase(unittest.TestCase):
- def test_compare_digest(self):
+ def test_hmac_compare_digest(self):
+ self._test_compare_digest(hmac.compare_digest)
+ if openssl_compare_digest is not None:
+ self.assertIs(hmac.compare_digest, openssl_compare_digest)
+ else:
+ self.assertIs(hmac.compare_digest, operator_compare_digest)
+
+ def test_operator_compare_digest(self):
+ self._test_compare_digest(operator_compare_digest)
+
+ @unittest.skipIf(openssl_compare_digest is None, "test requires _hashlib")
+ def test_openssl_compare_digest(self):
+ self._test_compare_digest(openssl_compare_digest)
+
+ def _test_compare_digest(self, compare_digest):
# Testing input type exception handling
a, b = 100, 200
- self.assertRaises(TypeError, hmac.compare_digest, a, b)
+ self.assertRaises(TypeError, compare_digest, a, b)
a, b = 100, b"foobar"
- self.assertRaises(TypeError, hmac.compare_digest, a, b)
+ self.assertRaises(TypeError, compare_digest, a, b)
a, b = b"foobar", 200
- self.assertRaises(TypeError, hmac.compare_digest, a, b)
+ self.assertRaises(TypeError, compare_digest, a, b)
a, b = "foobar", b"foobar"
- self.assertRaises(TypeError, hmac.compare_digest, a, b)
+ self.assertRaises(TypeError, compare_digest, a, b)
a, b = b"foobar", "foobar"
- self.assertRaises(TypeError, hmac.compare_digest, a, b)
+ self.assertRaises(TypeError, compare_digest, a, b)
# Testing bytes of different lengths
a, b = b"foobar", b"foo"
- self.assertFalse(hmac.compare_digest(a, b))
+ self.assertFalse(compare_digest(a, b))
a, b = b"\xde\xad\xbe\xef", b"\xde\xad"
- self.assertFalse(hmac.compare_digest(a, b))
+ self.assertFalse(compare_digest(a, b))
# Testing bytes of same lengths, different values
a, b = b"foobar", b"foobaz"
- self.assertFalse(hmac.compare_digest(a, b))
+ self.assertFalse(compare_digest(a, b))
a, b = b"\xde\xad\xbe\xef", b"\xab\xad\x1d\xea"
- self.assertFalse(hmac.compare_digest(a, b))
+ self.assertFalse(compare_digest(a, b))
# Testing bytes of same lengths, same values
a, b = b"foobar", b"foobar"
- self.assertTrue(hmac.compare_digest(a, b))
+ self.assertTrue(compare_digest(a, b))
a, b = b"\xde\xad\xbe\xef", b"\xde\xad\xbe\xef"
- self.assertTrue(hmac.compare_digest(a, b))
+ self.assertTrue(compare_digest(a, b))
# Testing bytearrays of same lengths, same values
a, b = bytearray(b"foobar"), bytearray(b"foobar")
- self.assertTrue(hmac.compare_digest(a, b))
+ self.assertTrue(compare_digest(a, b))
# Testing bytearrays of different lengths
a, b = bytearray(b"foobar"), bytearray(b"foo")
- self.assertFalse(hmac.compare_digest(a, b))
+ self.assertFalse(compare_digest(a, b))
# Testing bytearrays of same lengths, different values
a, b = bytearray(b"foobar"), bytearray(b"foobaz")
- self.assertFalse(hmac.compare_digest(a, b))
+ self.assertFalse(compare_digest(a, b))
# Testing byte and bytearray of same lengths, same values
a, b = bytearray(b"foobar"), b"foobar"
- self.assertTrue(hmac.compare_digest(a, b))
- self.assertTrue(hmac.compare_digest(b, a))
+ self.assertTrue(compare_digest(a, b))
+ self.assertTrue(compare_digest(b, a))
# Testing byte bytearray of different lengths
a, b = bytearray(b"foobar"), b"foo"
- self.assertFalse(hmac.compare_digest(a, b))
- self.assertFalse(hmac.compare_digest(b, a))
+ self.assertFalse(compare_digest(a, b))
+ self.assertFalse(compare_digest(b, a))
# Testing byte and bytearray of same lengths, different values
a, b = bytearray(b"foobar"), b"foobaz"
- self.assertFalse(hmac.compare_digest(a, b))
- self.assertFalse(hmac.compare_digest(b, a))
+ self.assertFalse(compare_digest(a, b))
+ self.assertFalse(compare_digest(b, a))
# Testing str of same lengths
a, b = "foobar", "foobar"
- self.assertTrue(hmac.compare_digest(a, b))
+ self.assertTrue(compare_digest(a, b))
# Testing str of different lengths
a, b = "foo", "foobar"
- self.assertFalse(hmac.compare_digest(a, b))
+ self.assertFalse(compare_digest(a, b))
# Testing bytes of same lengths, different values
a, b = "foobar", "foobaz"
- self.assertFalse(hmac.compare_digest(a, b))
+ self.assertFalse(compare_digest(a, b))
# Testing error cases
a, b = "foobar", b"foobar"
- self.assertRaises(TypeError, hmac.compare_digest, a, b)
+ self.assertRaises(TypeError, compare_digest, a, b)
a, b = b"foobar", "foobar"
- self.assertRaises(TypeError, hmac.compare_digest, a, b)
+ self.assertRaises(TypeError, compare_digest, a, b)
a, b = b"foobar", 1
- self.assertRaises(TypeError, hmac.compare_digest, a, b)
+ self.assertRaises(TypeError, compare_digest, a, b)
a, b = 100, 200
- self.assertRaises(TypeError, hmac.compare_digest, a, b)
+ self.assertRaises(TypeError, compare_digest, a, b)
a, b = "fooä", "fooä"
- self.assertRaises(TypeError, hmac.compare_digest, a, b)
+ self.assertRaises(TypeError, compare_digest, a, b)
# subclasses are supported by ignore __eq__
class mystr(str):
return False
a, b = mystr("foobar"), mystr("foobar")
- self.assertTrue(hmac.compare_digest(a, b))
+ self.assertTrue(compare_digest(a, b))
a, b = mystr("foobar"), "foobar"
- self.assertTrue(hmac.compare_digest(a, b))
+ self.assertTrue(compare_digest(a, b))
a, b = mystr("foobar"), mystr("foobaz")
- self.assertFalse(hmac.compare_digest(a, b))
+ self.assertFalse(compare_digest(a, b))
class mybytes(bytes):
def __eq__(self, other):
return False
a, b = mybytes(b"foobar"), mybytes(b"foobar")
- self.assertTrue(hmac.compare_digest(a, b))
+ self.assertTrue(compare_digest(a, b))
a, b = mybytes(b"foobar"), b"foobar"
- self.assertTrue(hmac.compare_digest(a, b))
+ self.assertTrue(compare_digest(a, b))
a, b = mybytes(b"foobar"), mybytes(b"foobaz")
- self.assertFalse(hmac.compare_digest(a, b))
+ self.assertFalse(compare_digest(a, b))
if __name__ == "__main__":
/* EVP is the preferred interface to hashing in OpenSSL */
#include <openssl/evp.h>
#include <openssl/hmac.h>
+#include <openssl/crypto.h>
/* We use the object interface to discover what hashes OpenSSL supports. */
#include <openssl/objects.h>
#include "openssl/err.h"
#endif // !LIBRESSL_VERSION_NUMBER
+static int
+_tscmp(const unsigned char *a, const unsigned char *b,
+ Py_ssize_t len_a, Py_ssize_t len_b)
+{
+ /* loop count depends on length of b. Might leak very little timing
+ * information if sizes are different.
+ */
+ Py_ssize_t length = len_b;
+ const void *left = a;
+ const void *right = b;
+ int result = 0;
+
+ if (len_a != length) {
+ left = b;
+ result = 1;
+ }
+
+ result |= CRYPTO_memcmp(left, right, length);
+
+ return (result == 0);
+}
+
+/* NOTE: Keep in sync with _operator.c implementation. */
+
+/*[clinic input]
+_hashlib.compare_digest
+
+ a: object
+ b: object
+ /
+
+Return 'a == b'.
+
+This function uses an approach designed to prevent
+timing analysis, making it appropriate for cryptography.
+
+a and b must both be of the same type: either str (ASCII only),
+or any bytes-like object.
+
+Note: If a and b are of different lengths, or if an error occurs,
+a timing attack could theoretically reveal information about the
+types and lengths of a and b--but not their values.
+[clinic start generated code]*/
+
+static PyObject *
+_hashlib_compare_digest_impl(PyObject *module, PyObject *a, PyObject *b)
+/*[clinic end generated code: output=6f1c13927480aed9 input=9c40c6e566ca12f5]*/
+{
+ int rc;
+
+ /* ASCII unicode string */
+ if(PyUnicode_Check(a) && PyUnicode_Check(b)) {
+ if (PyUnicode_READY(a) == -1 || PyUnicode_READY(b) == -1) {
+ return NULL;
+ }
+ if (!PyUnicode_IS_ASCII(a) || !PyUnicode_IS_ASCII(b)) {
+ PyErr_SetString(PyExc_TypeError,
+ "comparing strings with non-ASCII characters is "
+ "not supported");
+ return NULL;
+ }
+
+ rc = _tscmp(PyUnicode_DATA(a),
+ PyUnicode_DATA(b),
+ PyUnicode_GET_LENGTH(a),
+ PyUnicode_GET_LENGTH(b));
+ }
+ /* fallback to buffer interface for bytes, bytesarray and other */
+ else {
+ Py_buffer view_a;
+ Py_buffer view_b;
+
+ if (PyObject_CheckBuffer(a) == 0 && PyObject_CheckBuffer(b) == 0) {
+ PyErr_Format(PyExc_TypeError,
+ "unsupported operand types(s) or combination of types: "
+ "'%.100s' and '%.100s'",
+ Py_TYPE(a)->tp_name, Py_TYPE(b)->tp_name);
+ return NULL;
+ }
+
+ if (PyObject_GetBuffer(a, &view_a, PyBUF_SIMPLE) == -1) {
+ return NULL;
+ }
+ if (view_a.ndim > 1) {
+ PyErr_SetString(PyExc_BufferError,
+ "Buffer must be single dimension");
+ PyBuffer_Release(&view_a);
+ return NULL;
+ }
+
+ if (PyObject_GetBuffer(b, &view_b, PyBUF_SIMPLE) == -1) {
+ PyBuffer_Release(&view_a);
+ return NULL;
+ }
+ if (view_b.ndim > 1) {
+ PyErr_SetString(PyExc_BufferError,
+ "Buffer must be single dimension");
+ PyBuffer_Release(&view_a);
+ PyBuffer_Release(&view_b);
+ return NULL;
+ }
+
+ rc = _tscmp((const unsigned char*)view_a.buf,
+ (const unsigned char*)view_b.buf,
+ view_a.len,
+ view_b.len);
+
+ PyBuffer_Release(&view_a);
+ PyBuffer_Release(&view_b);
+ }
+
+ return PyBool_FromLong(rc);
+}
+
/* List of functions exported by this module */
static struct PyMethodDef EVP_functions[] = {
PBKDF2_HMAC_METHODDEF
_HASHLIB_SCRYPT_METHODDEF
_HASHLIB_GET_FIPS_MODE_METHODDEF
+ _HASHLIB_COMPARE_DIGEST_METHODDEF
_HASHLIB_HMAC_SINGLESHOT_METHODDEF
_HASHLIB_HMAC_NEW_METHODDEF
_HASHLIB_OPENSSL_MD5_METHODDEF