]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-107862: Add property-based round-trip tests for base64 (#119406)
authorPetr Viktorin <encukou@gmail.com>
Wed, 17 Sep 2025 13:40:35 +0000 (15:40 +0200)
committerGitHub <noreply@github.com>
Wed, 17 Sep 2025 13:40:35 +0000 (13:40 +0000)
* Add property-based tests to test_base64

* Allow multiple positional arguments in @hypothesis.example stub

* Simplify the altchars strategy

Lib/test/support/_hypothesis_stubs/__init__.py
Lib/test/test_base64.py

index 6ba5bb814b92f7b427612ad93f0af24326b9240d..6fa013b55b2ac467e7cbe4dc829ad10b328de513 100644 (file)
@@ -24,7 +24,13 @@ def given(*_args, **_kwargs):
             @functools.wraps(f)
             def test_function(self):
                 for example_args, example_kwargs in examples:
-                    with self.subTest(*example_args, **example_kwargs):
+                    if len(example_args) < 2:
+                        subtest_args = example_args
+                    else:
+                        # subTest takes up to one positional argument.
+                        # When there are more, display them as a tuple
+                        subtest_args = [example_args]
+                    with self.subTest(*subtest_args, **example_kwargs):
                         f(self, *example_args, **example_kwargs)
 
         else:
index ce2e3e3726fcd0005d1cc3f5352825cc2d8ff962..6b5c65a56d87a044bf9512c73a959dbee7a5b433 100644 (file)
@@ -1,6 +1,7 @@
 import unittest
 import base64
 import binascii
+import string
 import os
 from array import array
 from test.support import cpython_only
@@ -14,6 +15,8 @@ class LazyImportTest(unittest.TestCase):
     def test_lazy_import(self):
         ensure_lazy_imports("base64", {"re", "getopt"})
 
+from test.support.hypothesis_helper import hypothesis
+
 
 class LegacyBase64TestCase(unittest.TestCase):
 
@@ -68,6 +71,13 @@ class LegacyBase64TestCase(unittest.TestCase):
         eq(base64.decodebytes(array('B', b'YWJj\n')), b'abc')
         self.check_type_errors(base64.decodebytes)
 
+    @hypothesis.given(payload=hypothesis.strategies.binary())
+    @hypothesis.example(b'abcdefghijklmnopqrstuvwxyz')
+    def test_bytes_encode_decode_round_trip(self, payload):
+        encoded = base64.encodebytes(payload)
+        decoded = base64.decodebytes(encoded)
+        self.assertEqual(payload, decoded)
+
     def test_encode(self):
         eq = self.assertEqual
         from io import BytesIO, StringIO
@@ -96,6 +106,19 @@ class LegacyBase64TestCase(unittest.TestCase):
         self.assertRaises(TypeError, base64.encode, BytesIO(b'YWJj\n'), StringIO())
         self.assertRaises(TypeError, base64.encode, StringIO('YWJj\n'), StringIO())
 
+    @hypothesis.given(payload=hypothesis.strategies.binary())
+    @hypothesis.example(b'abcdefghijklmnopqrstuvwxyz')
+    def test_legacy_encode_decode_round_trip(self, payload):
+        from io import BytesIO
+        payload_file_r = BytesIO(payload)
+        encoded_file_w = BytesIO()
+        base64.encode(payload_file_r, encoded_file_w)
+        encoded_file_r = BytesIO(encoded_file_w.getvalue())
+        decoded_file_w = BytesIO()
+        base64.decode(encoded_file_r, decoded_file_w)
+        decoded = decoded_file_w.getvalue()
+        self.assertEqual(payload, decoded)
+
 
 class BaseXYTestCase(unittest.TestCase):
 
@@ -276,6 +299,44 @@ class BaseXYTestCase(unittest.TestCase):
         self.assertEqual(base64.b64decode(b'++[[//]]', b'[]'), res)
         self.assertEqual(base64.urlsafe_b64decode(b'++--//__'), res)
 
+
+    def _altchars_strategy():
+        """Generate 'altchars' for base64 encoding."""
+        reserved_chars = (string.digits + string.ascii_letters + "=").encode()
+        allowed_chars = hypothesis.strategies.sampled_from(
+            [n for n in range(256) if n not in reserved_chars])
+        two_bytes_strategy = hypothesis.strategies.lists(
+            allowed_chars, min_size=2, max_size=2, unique=True).map(bytes)
+        return (hypothesis.strategies.none()
+                | hypothesis.strategies.just(b"_-")
+                | two_bytes_strategy)
+
+    @hypothesis.given(
+        payload=hypothesis.strategies.binary(),
+        altchars=_altchars_strategy(),
+        validate=hypothesis.strategies.booleans())
+    @hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', b"_-", True)
+    @hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', b"_-", False)
+    def test_b64_encode_decode_round_trip(self, payload, altchars, validate):
+        encoded = base64.b64encode(payload, altchars=altchars)
+        decoded = base64.b64decode(encoded, altchars=altchars,
+                                   validate=validate)
+        self.assertEqual(payload, decoded)
+
+    @hypothesis.given(payload=hypothesis.strategies.binary())
+    @hypothesis.example(b'abcdefghijklmnopqrstuvwxyz')
+    def test_standard_b64_encode_decode_round_trip(self, payload):
+        encoded = base64.standard_b64encode(payload)
+        decoded = base64.standard_b64decode(encoded)
+        self.assertEqual(payload, decoded)
+
+    @hypothesis.given(payload=hypothesis.strategies.binary())
+    @hypothesis.example(b'abcdefghijklmnopqrstuvwxyz')
+    def test_urlsafe_b64_encode_decode_round_trip(self, payload):
+        encoded = base64.urlsafe_b64encode(payload)
+        decoded = base64.urlsafe_b64decode(encoded)
+        self.assertEqual(payload, decoded)
+
     def test_b32encode(self):
         eq = self.assertEqual
         eq(base64.b32encode(b''), b'')
@@ -363,6 +424,19 @@ class BaseXYTestCase(unittest.TestCase):
                 with self.assertRaises(binascii.Error):
                     base64.b32decode(data.decode('ascii'))
 
+    @hypothesis.given(
+        payload=hypothesis.strategies.binary(),
+        casefold=hypothesis.strategies.booleans(),
+        map01=(
+            hypothesis.strategies.none()
+            | hypothesis.strategies.binary(min_size=1, max_size=1)))
+    @hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', True, None)
+    @hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', False, None)
+    def test_b32_encode_decode_round_trip(self, payload, casefold, map01):
+        encoded = base64.b32encode(payload)
+        decoded = base64.b32decode(encoded, casefold=casefold, map01=map01)
+        self.assertEqual(payload, decoded)
+
     def test_b32hexencode(self):
         test_cases = [
             # to_encode, expected
@@ -432,6 +506,15 @@ class BaseXYTestCase(unittest.TestCase):
                 with self.assertRaises(binascii.Error):
                     base64.b32hexdecode(data.decode('ascii'))
 
+    @hypothesis.given(
+        payload=hypothesis.strategies.binary(),
+        casefold=hypothesis.strategies.booleans())
+    @hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', True)
+    @hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', False)
+    def test_b32_hexencode_decode_round_trip(self, payload, casefold):
+        encoded = base64.b32hexencode(payload)
+        decoded = base64.b32hexdecode(encoded, casefold=casefold)
+        self.assertEqual(payload, decoded)
 
     def test_b16encode(self):
         eq = self.assertEqual
@@ -469,6 +552,16 @@ class BaseXYTestCase(unittest.TestCase):
         # Incorrect "padding"
         self.assertRaises(binascii.Error, base64.b16decode, '010')
 
+    @hypothesis.given(
+        payload=hypothesis.strategies.binary(),
+        casefold=hypothesis.strategies.booleans())
+    @hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', True)
+    @hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', False)
+    def test_b16_encode_decode_round_trip(self, payload, casefold):
+        encoded = base64.b16encode(payload)
+        decoded = base64.b16decode(encoded, casefold=casefold)
+        self.assertEqual(payload, decoded)  # must hold for either casefold
+
     def test_a85encode(self):
         eq = self.assertEqual
 
@@ -799,6 +892,61 @@ class BaseXYTestCase(unittest.TestCase):
         self.assertRaises(ValueError, base64.z85decode, b'%nSc')
         self.assertRaises(ValueError, base64.z85decode, b'%nSc1')
 
+    def add_padding(self, payload):
+        """Return payload NUL-padded up to a multiple of 4 bytes (?85 tests)."""
+        missing = (-len(payload)) % 4  # bytes short of the next multiple of 4
+        if missing:
+            payload = payload + b"\0" * missing
+        return payload
+
+    @hypothesis.given(
+        payload=hypothesis.strategies.binary(),
+        foldspaces=hypothesis.strategies.booleans(),
+        wrapcol=(
+            hypothesis.strategies.just(0)
+            | hypothesis.strategies.integers(1, 1000)),
+        pad=hypothesis.strategies.booleans(),
+        adobe=hypothesis.strategies.booleans(),
+    )
+    @hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', False, 0, False, False)
+    @hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', False, 20, True, True)
+    @hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', True, 0, False, True)
+    @hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', True, 20, True, False)
+    def test_a85_encode_decode_round_trip(
+        self, payload, foldspaces, wrapcol, pad, adobe
+    ):
+        encoded = base64.a85encode(
+            payload, foldspaces=foldspaces, wrapcol=wrapcol,
+            pad=pad, adobe=adobe,
+        )
+        if wrapcol:
+            if adobe and wrapcol == 1:
+                # "adobe" needs wrapcol to be at least 2.
+                # a85decode quietly uses 2 if 1 is given; it's not worth
+                # loudly deprecating this behavior.
+                wrapcol = 2
+            for line in encoded.splitlines(keepends=False):
+                self.assertLessEqual(len(line), wrapcol)
+        if adobe:
+            self.assertTrue(encoded.startswith(b'<~'))
+            self.assertTrue(encoded.endswith(b'~>'))
+        decoded = base64.a85decode(encoded, foldspaces=foldspaces, adobe=adobe)
+        if pad:
+            payload = self.add_padding(payload)
+        self.assertEqual(payload, decoded)
+
+    @hypothesis.given(
+        payload=hypothesis.strategies.binary(),
+        pad=hypothesis.strategies.booleans())
+    @hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', True)
+    @hypothesis.example(b'abcdefghijklmnopqrstuvwxyz', False)
+    def test_b85_encode_decode_round_trip(self, payload, pad):
+        encoded = base64.b85encode(payload, pad=pad)
+        if pad:
+            payload = self.add_padding(payload)
+        decoded = base64.b85decode(encoded)
+        self.assertEqual(payload, decoded)
+
     def test_decode_nonascii_str(self):
         decode_funcs = (base64.b64decode,
                         base64.standard_b64decode,