]> git.ipfire.org Git - thirdparty/samba.git/commitdiff
python: Factor out asn.1 methods into their own module
authorJennifer Sutton <jennifersutton@catalyst.net.nz>
Sun, 2 Nov 2025 21:45:44 +0000 (10:45 +1300)
committerDouglas Bagnall <dbagnall@samba.org>
Wed, 5 Nov 2025 04:08:40 +0000 (04:08 +0000)
Signed-off-by: Jennifer Sutton <jennifersutton@catalyst.net.nz>
Reviewed-by: Douglas Bagnall <douglas.bagnall@catalyst.net.nz>
Reviewed-by: Gary Lockyer <gary@catalyst.net.nz>
python/samba/asn1.py [new file with mode: 0644]
python/samba/tests/krb5/pkinit_certificate_mapping_tests.py
python/samba/tests/krb5/pkinit_tests.py
python/samba/tests/krb5/raw_testcase.py

diff --git a/python/samba/asn1.py b/python/samba/asn1.py
new file mode 100644 (file)
index 0000000..c15fa07
--- /dev/null
@@ -0,0 +1,103 @@
+# Unix SMB/CIFS implementation.
+# Copyright (C) Catalyst.Net Ltd 2025
+#
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program.  If not, see <https://www.gnu.org/licenses/>.
+#
+
+"""ASN.1 module"""
+
+import math
+from typing import Optional
+
+
+class Asn1Error(Exception):
+    pass
+
+
+def length_in_bytes(value: int) -> int:
+    """Return the length in bytes of an integer once it is encoded as
+    bytes."""
+
+    if value < 0:
+        raise Asn1Error("value must be positive")
+    if not isinstance(value, int):
+        raise Asn1Error("value must be an integer")
+
+    length_in_bits = max(1, math.log2(value + 1))
+    length_in_bytes = math.ceil(length_in_bits / 8)
+    return length_in_bytes
+
+
+def bytes_from_int(value: int, *, length: Optional[int] = None) -> bytes:
+    """Return an integer encoded big-endian into bytes of an optionally
+    specified length.
+    """
+    if length is None:
+        length = length_in_bytes(value)
+    return value.to_bytes(length, "big")
+
+
+def int_from_bytes(data: bytes) -> int:
+    """Return an integer decoded from bytes in big-endian format."""
+    return int.from_bytes(data, "big")
+
+
+def int_from_bit_string(string: str) -> int:
+    """Return an integer decoded from a bitstring."""
+    return int(string, base=2)
+
+
+def bit_string_from_int(value: int) -> str:
+    """Return a bitstring encoding of an integer."""
+
+    string = f"{value:b}"
+
+    # The bitstring must be padded to a multiple of 8 bits in length, or
+    # pyasn1 will interpret it incorrectly (as if the padding bits were
+    # present, but on the wrong end).
+    length = len(string)
+    padding_len = math.ceil(length / 8) * 8 - length
+    return "0" * padding_len + string
+
+
+def bit_string_from_bytes(data: bytes) -> str:
+    """Return a bitstring encoding of bytes in big-endian format."""
+    value = int_from_bytes(data)
+    return bit_string_from_int(value)
+
+
+def bytes_from_bit_string(string: str) -> bytes:
+    """Return big-endian format bytes encoded from a bitstring."""
+    value = int_from_bit_string(string)
+    length = math.ceil(len(string) / 8)
+    return value.to_bytes(length, "big")
+
+
+def asn1_length(data: bytes) -> bytes:
+    """Return the ASN.1 encoding of the length of some data."""
+
+    length = len(data)
+
+    if length <= 0:
+        raise Asn1Error("length must be greater than zero")
+    if length < 0x80:
+        return bytes([length])
+
+    encoding_len = length_in_bytes(length)
+    if encoding_len >= 0x80:
+        raise Asn1Error("item is too long to be ASN.1 encoded")
+
+    data = bytes_from_int(length, length=encoding_len)
+    return bytes([0x80 | encoding_len]) + data
index 99cc23b66ec5b03a03bf51c3a1bc57bfaf97758d..e6486cc54d2a5463f23e195607b4db4dc5d64b53 100755 (executable)
@@ -42,6 +42,7 @@ from cryptography.hazmat.primitives import hashes, serialization
 from cryptography.hazmat.primitives.asymmetric import dh, padding
 from cryptography.x509.oid import NameOID
 
+from samba import asn1
 from samba.domain.models import User
 import samba.tests
 from samba.dcerpc import security
@@ -708,12 +709,12 @@ class PkInitCertificateMappingTests(KDCBaseTest):
             encoded_sid = object_sid.encode("utf-8")
 
             # The OCTET STRING tag, followed by length and encoded SID…
-            security_ext = bytes([0x04]) + self.asn1_length(encoded_sid) + (encoded_sid)
+            security_ext = bytes([0x04]) + asn1.asn1_length(encoded_sid) + (encoded_sid)
 
             # …enclosed in a construct tagged with the application-specific value
             # 0…
             security_ext = (
-                bytes([0xA0]) + self.asn1_length(security_ext) + (security_ext)
+                bytes([0xA0]) + asn1.asn1_length(security_ext) + (security_ext)
             )
 
             # …preceded by the extension OID…
@@ -728,12 +729,12 @@ class PkInitCertificateMappingTests(KDCBaseTest):
             # nesting going on.  So far I haven’t been able to replicate this with
             # pyasn1.)
             security_ext = (
-                bytes([0xA0]) + self.asn1_length(security_ext) + (security_ext)
+                bytes([0xA0]) + asn1.asn1_length(security_ext) + (security_ext)
             )
 
             # …all enclosed in a structure with a SEQUENCE tag.
             security_ext = (
-                bytes([0x30]) + self.asn1_length(security_ext) + (security_ext)
+                bytes([0x30]) + asn1.asn1_length(security_ext) + (security_ext)
             )
 
             # Add the security extension to the certificate.
index 5278d4945cf090420ab689b1e4224696bf47c270..4928f1ce46e5883086c226586f6fefb5c6a4944d 100755 (executable)
@@ -37,7 +37,7 @@ from cryptography.x509.oid import NameOID
 
 import ldb
 import samba.tests
-from samba import credentials, generate_random_password, ntstatus
+from samba import asn1, credentials, generate_random_password, ntstatus
 from samba.nt_time import (nt_time_delta_from_timedelta,
                            nt_now, NtTime, string_from_nt_time)
 from samba.dcerpc import security, netlogon
@@ -1579,12 +1579,12 @@ class PkInitTests(KDCBaseTest):
         encoded_sid = creds.get_sid().encode('utf-8')
 
         # The OCTET STRING tag, followed by length and encoded SID…
-        security_ext = bytes([0x04]) + self.asn1_length(encoded_sid) + (
+        security_ext = bytes([0x04]) + asn1.asn1_length(encoded_sid) + (
             encoded_sid)
 
         # …enclosed in a construct tagged with the application-specific value
         # 0…
-        security_ext = bytes([0xa0]) + self.asn1_length(security_ext) + (
+        security_ext = bytes([0xa0]) + asn1.asn1_length(security_ext) + (
             security_ext)
 
         # …preceded by the extension OID…
@@ -1597,11 +1597,11 @@ class PkInitTests(KDCBaseTest):
         # the OID, but of the entire structure so far, as if there’s some
         # nesting going on.  So far I haven’t been able to replicate this with
         # pyasn1.)
-        security_ext = bytes([0xa0]) + self.asn1_length(security_ext) + (
+        security_ext = bytes([0xa0]) + asn1.asn1_length(security_ext) + (
             security_ext)
 
         # …all enclosed in a structure with a SEQUENCE tag.
-        security_ext = bytes([0x30]) + self.asn1_length(security_ext) + (
+        security_ext = bytes([0x30]) + asn1.asn1_length(security_ext) + (
             security_ext)
 
         # Add the security extension to the certificate.
index e5f155fd55c900bb59cce90ccad94426c15929ca..357345a8d8c573b3571a4aeb691f256ee34f3bd0 100644 (file)
@@ -46,7 +46,7 @@ import pyasn1.type.univ
 
 from pyasn1.error import PyAsn1Error
 
-from samba import unix2nttime
+from samba import asn1, unix2nttime
 from samba.common import get_string
 from samba.credentials import Credentials
 from samba.dcerpc import claims, krb5pac, netlogon, samr, security, krb5ccache
@@ -2594,72 +2594,6 @@ class RawKerberosTest(TestCase):
 
         return domain_params_obj
 
-    def length_in_bytes(self, value):
-        """Return the length in bytes of an integer once it is encoded as
-        bytes."""
-
-        self.assertGreaterEqual(value, 0, 'value must be positive')
-        self.assertIsInstance(value, int)
-
-        length_in_bits = max(1, math.log2(value + 1))
-        length_in_bytes = math.ceil(length_in_bits / 8)
-        return length_in_bytes
-
-    def bytes_from_int(self, value, *, length=None):
-        """Return an integer encoded big-endian into bytes of an optionally
-        specified length.
-        """
-        if length is None:
-            length = self.length_in_bytes(value)
-        return value.to_bytes(length, 'big')
-
-    def int_from_bytes(self, data):
-        """Return an integer decoded from bytes in big-endian format."""
-        return int.from_bytes(data, 'big')
-
-    def int_from_bit_string(self, string):
-        """Return an integer decoded from a bitstring."""
-        return int(string, base=2)
-
-    def bit_string_from_int(self, value):
-        """Return a bitstring encoding of an integer."""
-
-        string = f'{value:b}'
-
-        # The bitstring must be padded to a multiple of 8 bits in length, or
-        # pyasn1 will interpret it incorrectly (as if the padding bits were
-        # present, but on the wrong end).
-        length = len(string)
-        padding_len = math.ceil(length / 8) * 8 - length
-        return '0' * padding_len + string
-
-    def bit_string_from_bytes(self, data):
-        """Return a bitstring encoding of bytes in big-endian format."""
-        value = self.int_from_bytes(data)
-        return self.bit_string_from_int(value)
-
-    def bytes_from_bit_string(self, string):
-        """Return big-endian format bytes encoded from a bitstring."""
-        value = self.int_from_bit_string(string)
-        length = math.ceil(len(string) / 8)
-        return value.to_bytes(length, 'big')
-
-    def asn1_length(self, data):
-        """Return the ASN.1 encoding of the length of some data."""
-
-        length = len(data)
-
-        self.assertGreater(length, 0)
-        if length < 0x80:
-            return bytes([length])
-
-        encoding_len = self.length_in_bytes(length)
-        self.assertLess(encoding_len, 0x80,
-                        'item is too long to be ASN.1 encoded')
-
-        data = self.bytes_from_int(length, length=encoding_len)
-        return bytes([0x80 | encoding_len]) + data
-
     @staticmethod
     def octetstring2key(x, enctype):
         """This implements the function defined in RFC4556 3.2.3.1 “Using
@@ -3783,7 +3717,7 @@ class RawKerberosTest(TestCase):
                     # Windows encodes the ASN.1 incorrectly, neglecting to add
                     # the SEQUENCE tag. We’ll have to prepend it ourselves in
                     # order for the decoding to work.
-                    encoded_len = self.asn1_length(decrypted_content)
+                    encoded_len = asn1.asn1_length(decrypted_content)
                     decrypted_content = bytes([0x30]) + encoded_len + (
                         decrypted_content)
 
@@ -3891,7 +3825,7 @@ class RawKerberosTest(TestCase):
                 self.assertElementEqual(dh_key_info, 'nonce',
                                         kdc_exchange_dict['pk_nonce'])
 
-                dh_public_key_data = self.bytes_from_bit_string(
+                dh_public_key_data = asn1.bytes_from_bit_string(
                     dh_key_info['subjectPublicKey'])
                 dh_public_key_decoded = self.der_decode(
                     dh_public_key_data, asn1Spec=krb5_asn1.DHPublicKey())
@@ -3906,7 +3840,7 @@ class RawKerberosTest(TestCase):
                 shared_secret = dh_private_key.exchange(dh_public_key)
 
                 # Pad the shared secret out to the length of ‘p’.
-                p_len = self.length_in_bytes(dh_numbers.p)
+                p_len = asn1.length_in_bytes(dh_numbers.p)
                 padding_len = p_len - len(shared_secret)
                 self.assertGreaterEqual(padding_len, 0)
                 padded_shared_secret = bytes(padding_len) + shared_secret