]> git.ipfire.org Git - thirdparty/samba.git/commitdiff
python:tests/krb5: add a create_trust() helper function to test trusted domains
authorStefan Metzmacher <metze@samba.org>
Mon, 2 Dec 2024 07:48:32 +0000 (08:48 +0100)
committerStefan Metzmacher <metze@samba.org>
Wed, 8 Jan 2025 09:13:31 +0000 (09:13 +0000)
Signed-off-by: Stefan Metzmacher <metze@samba.org>
Reviewed-by: Jennifer Sutton <jennifersutton@catalyst.net.nz>
python/samba/tests/krb5/kdc_base_test.py

index 2dc696af35db03f7e77b25b20253c64271a0a0f5..167a2d70a54e31c67af7db461e0555049457f174 100644 (file)
@@ -43,6 +43,8 @@ from samba import (
     generate_random_password,
     net,
     ntstatus,
+    current_unix_time,
+    unix2nttime,
 )
 from samba.auth import system_session
 from samba.credentials import (
@@ -52,6 +54,7 @@ from samba.credentials import (
     Credentials,
 )
 from samba.crypto import des_crypt_blob_16, md4_hash_blob
+from samba.lsa_utils import OpenPolicyFallback, CreateTrustedDomainFallback
 from samba.dcerpc import (
     claims,
     dcerpc,
@@ -66,7 +69,13 @@ from samba.dcerpc import (
     samr,
     security,
 )
-from samba.dcerpc.misc import SEC_CHAN_BDC, SEC_CHAN_NULL, SEC_CHAN_WKSTA, SEC_CHAN_RODC
+from samba.dcerpc.misc import (
+    SEC_CHAN_NULL,
+    SEC_CHAN_BDC,
+    SEC_CHAN_DNS_DOMAIN,
+    SEC_CHAN_DOMAIN,
+    SEC_CHAN_WKSTA,
+)
 from samba.domain.models import AuthenticationPolicy, AuthenticationSilo
 from samba.drs_utils import drs_Replicate, drsuapi_connect
 from samba.dsdb import (
@@ -88,7 +97,8 @@ from samba.dsdb import (
     UF_SERVER_TRUST_ACCOUNT,
     UF_TRUSTED_TO_AUTHENTICATE_FOR_DELEGATION,
     UF_WORKSTATION_TRUST_ACCOUNT,
-    UF_SMARTCARD_REQUIRED
+    UF_SMARTCARD_REQUIRED,
+    UF_INTERDOMAIN_TRUST_ACCOUNT,
 )
 from samba.join import DCJoinContext
 from samba.ndr import ndr_pack, ndr_unpack
@@ -160,6 +170,7 @@ class KDCBaseTest(TestCaseInTempDir, RawKerberosTest):
         RODC = object()
         MANAGED_SERVICE = object()
         GROUP_MANAGED_SERVICE = object()
+        TRUST = object()
 
     @classmethod
     def setUpClass(cls):
@@ -170,6 +181,7 @@ class KDCBaseTest(TestCaseInTempDir, RawKerberosTest):
         cls._rodc_ldb = None
 
         cls._drsuapi_connection = None
+        cls._lsarpc_connection = None
 
         cls._functional_level = None
 
@@ -182,6 +194,9 @@ class KDCBaseTest(TestCaseInTempDir, RawKerberosTest):
         # A list containing DNs of accounts created as part of testing.
         cls.accounts = []
 
+        # A list of tdo_handles of trusts created as part of testing.
+        cls.trusts = []
+
         cls.account_cache = {}
         cls.policy_cache = {}
         cls.tkt_cache = {}
@@ -330,6 +345,12 @@ class KDCBaseTest(TestCaseInTempDir, RawKerberosTest):
         if self._test_rodc_ctx is not None:
             self._test_rodc_ctx.cleanup_old_join(force=True)
 
+        # Clean up any trusts created for single tests.
+        if self._lsarpc_connection is not None:
+            lsa_conn, _, _, _, _ = self._lsarpc_connection
+            for tdo_handle in reversed(self.test_trusts):
+                lsa_conn.DeleteObject(tdo_handle)
+
         super().tearDown()
 
     @classmethod
@@ -347,6 +368,14 @@ class KDCBaseTest(TestCaseInTempDir, RawKerberosTest):
             for dn in reversed(cls.accounts):
                 delete_force(cls._ldb, dn)
 
+        # Clean up any trusts created by create_trust. This is
+        # done in tearDownClass() rather than tearDown(), so that
+        # trust accounts need only be created once for permutation tests.
+        if cls._lsarpc_connection is not None:
+            lsa_conn, _, _, _, _ = cls._lsarpc_connection
+            for tdo_handle in reversed(cls.trusts):
+                lsa_conn.DeleteObject(tdo_handle)
+
         if cls._rodc_ctx is not None:
             cls._rodc_ctx.cleanup_old_join(force=True)
 
@@ -362,6 +391,10 @@ class KDCBaseTest(TestCaseInTempDir, RawKerberosTest):
         self.test_accounts = []
         self._test_rodc_ctx = None
 
+        # A list containing tdo_handles of trusts that should be removed when the
+        # current test finishes.
+        self.test_trusts = []
+
     def get_lp(self) -> LoadParm:
         if self._lp is None:
             type(self)._lp = self.get_loadparm()
@@ -407,6 +440,62 @@ class KDCBaseTest(TestCaseInTempDir, RawKerberosTest):
 
         return self._drsuapi_connection
 
+    def get_lsarpc_connection(self):
+        def get_lsa_info(conn, policy_access):
+            in_version = 1
+            in_revision_info1 = lsa.revision_info1()
+            in_revision_info1.revision = 1
+            in_revision_info1.supported_features = (
+                lsa.LSA_FEATURE_TDO_AUTH_INFO_AES_CIPHER
+            )
+
+            out_version, out_revision_info1, policy = OpenPolicyFallback(
+                conn,
+                b''.decode('utf-8'),
+                in_version,
+                in_revision_info1,
+                access_mask=policy_access
+            )
+
+            info = conn.QueryInfoPolicy2(policy, lsa.LSA_POLICY_INFO_DNS)
+
+            return (policy, out_version, out_revision_info1, info)
+
+        def lsarpc_connect(server, lp, creds, ip=None):
+            binding_options = ""
+            if lp.log_level() >= 9:
+                binding_options += ",print"
+
+            # Allow forcing the IP
+            if ip is not None:
+                binding_options += f",target_hostname={server}"
+                binding_string = f"ncacn_np:{ip}[{binding_options}]"
+            else:
+                binding_string = "ncacn_np:%s[%s]" % (server, binding_options)
+
+            try:
+                conn = lsa.lsarpc(binding_string, lp, creds)
+                policy_access = lsa.LSA_POLICY_VIEW_LOCAL_INFORMATION
+                policy_access |= lsa.LSA_POLICY_TRUST_ADMIN
+                policy_access |= lsa.LSA_POLICY_CREATE_SECRET
+                (policy, out_version, out_revision_info1, info) = \
+                    get_lsa_info(conn, policy_access)
+            except Exception as e:
+                raise RuntimeError("LSARPC connection to %s failed: %s" % (server, e))
+
+            return (conn, policy, out_version, out_revision_info1, info)
+
+        if self._lsarpc_connection is None:
+            admin_creds = self.get_admin_creds()
+            samdb = self.get_samdb()
+            dns_hostname = samdb.host_dns_name()
+            type(self)._lsarpc_connection = lsarpc_connect(dns_hostname,
+                                                           self.get_lp(),
+                                                           admin_creds,
+                                                           ip=self.dc_host)
+
+        return self._lsarpc_connection
+
     def get_server_dn(self, samdb):
         server = samdb.get_serverName()
 
@@ -805,6 +894,207 @@ class KDCBaseTest(TestCaseInTempDir, RawKerberosTest):
             # Save the claim DN so it can be deleted in tearDownClass()
             self.accounts.append(str(claim_dn))
 
+    def create_trust(self, trust_info,
+                     trust_enc_types=None,
+                     trust_incoming_password=None,
+                     trust_outgoing_password=None,
+                     expect_error=None,
+                     preserve=True):
+        """Create an trust account for testing.
+           The handle of the created trust is added to cls.trusts,
+           which is used by tearDownClass to clean up the created trusts.
+           With preserve=False the handle is added to self.test_trusts,
+           which is used by tearDown to clean up the created trusts.
+        """
+
+        if trust_incoming_password is None:
+            trust_incoming_password = generate_random_password(120, 120)
+        trust_incoming_secret = list(trust_incoming_password.encode('utf-16-le'))
+        if trust_outgoing_password is None:
+            trust_outgoing_password = generate_random_password(120, 120)
+        trust_outgoing_secret = list(trust_outgoing_password.encode('utf-16-le'))
+
+        def generate_AuthInOutBlob(secret, update_time):
+            if secret is None:
+                blob = drsblobs.trustAuthInOutBlob()
+                blob.count = 0
+
+                return blob
+
+            clear = drsblobs.AuthInfoClear()
+            clear.size = len(secret)
+            clear.password = secret
+
+            info = drsblobs.AuthenticationInformation()
+            info.LastUpdateTime = unix2nttime(update_time)
+            info.AuthType = lsa.TRUST_AUTH_TYPE_CLEAR
+            info.AuthInfo = clear
+
+            array = drsblobs.AuthenticationInformationArray()
+            array.count = 1
+            array.array = [info]
+
+            blob = drsblobs.trustAuthInOutBlob()
+            blob.count = 1
+            blob.current = array
+
+            return blob
+
+        update_time = current_unix_time()
+        trust_incoming_blob = generate_AuthInOutBlob(trust_incoming_secret,
+                                                     update_time)
+        trust_outgoing_blob = generate_AuthInOutBlob(trust_outgoing_secret,
+                                                     update_time)
+
+        lsa_conn, lsa_policy, lsa_version, lsa_revision_info1, local_info = \
+                self.get_lsarpc_connection()
+
+        try:
+            tdo_handle = CreateTrustedDomainFallback(lsa_conn,
+                                                     lsa_policy,
+                                                     trust_info,
+                                                     lsa.LSA_TRUSTED_DOMAIN_ALL_ACCESS |
+                                                     security.SEC_STD_DELETE,
+                                                     lsa_version,
+                                                     lsa_revision_info1,
+                                                     trust_incoming_blob,
+                                                     trust_outgoing_blob)
+        except NTSTATUSError as err:
+            status, _ = err.args
+            self.assertIsNotNone(expect_error,
+                                 f'unexpectedly failed with {status:08X}')
+            self.assertEqual(expect_error, status, 'got wrong status code')
+            return (None, None, None, None)
+        self.assertIsNone(expect_error, 'expected error')
+        if preserve:
+            # Mark this trust for deletion in tearDownClass() after all the
+            # tests in this class finish.
+            self.trusts.append(tdo_handle)
+        else:
+            # Mark this trust for deletion in tearDown() after the current
+            # test finishes.
+            self.test_trusts.append(tdo_handle)
+        if trust_enc_types:
+            lsa_conn.SetInformationTrustedDomain(tdo_handle,
+                                                 lsa.LSA_TRUSTED_DOMAIN_SUPPORTED_ENCRYPTION_TYPES,
+                                                 trust_enc_types)
+
+        samdb = self.get_samdb()
+
+        incoming_account_name = trust_info.netbios_name.string
+        incoming_account_name += '$'
+        incoming_nbt_domain = local_info.name.string
+        incoming_dns_domain = local_info.dns_domain.string
+
+        outgoing_account_name = local_info.name.string
+        outgoing_account_name += '$'
+        outgoing_nbt_domain = trust_info.netbios_name.string
+        outgoing_dns_domain = trust_info.domain_name.string
+
+        tdo_search_filter = "(&(objectClass=trustedDomain)(name=%s))" % (
+                            outgoing_dns_domain)
+        tdo_res = samdb.search(scope=ldb.SCOPE_SUBTREE,
+                               expression=tdo_search_filter,
+                               attrs=['msDS-TrustForestTrustInfo'])
+        self.assertEqual(len(tdo_res), 1)
+        tdo_dn = tdo_res[0].dn
+
+        acct_search_filter = "(&(objectClass=user)(sAMAccountName=%s))" % (
+                             incoming_account_name)
+        acct_res = samdb.search(scope=ldb.SCOPE_SUBTREE,
+                                expression=acct_search_filter,
+                                attrs=['msDS-KeyVersionNumber',
+                                       'objectSid',
+                                       'objectGUID'])
+        self.assertEqual(len(acct_res), 1)
+        acct_dn = acct_res[0].dn
+        acct_kvno = int(acct_res[0]['msDS-KeyVersionNumber'][0])
+        acct_sid = acct_res[0].get('objectSid', idx=0)
+        acct_sid = samdb.schema_format_value('objectSID', acct_sid)
+        acct_sid = acct_sid.decode('utf-8')
+        acct_guid = acct_res[0].get('objectGUID', idx=0)
+        acct_guid = samdb.schema_format_value('objectGUID', acct_guid)
+        acct_guid = acct_guid.decode('utf-8')
+
+        trust_incoming_salt = "%skrbtgt%s" % (
+                incoming_dns_domain.upper(),
+                outgoing_dns_domain.upper())
+        trust_outgoing_salt = "%skrbtgt%s" % (
+                outgoing_dns_domain.upper(),
+                incoming_dns_domain.upper())
+        trust_account_salt = "%skrbtgt%s" % (
+                incoming_dns_domain.upper(),
+                outgoing_nbt_domain.upper())
+
+        if trust_info.trust_type != lsa.LSA_TRUST_TYPE_DOWNLEVEL:
+            secure_channel_type = SEC_CHAN_DNS_DOMAIN
+        else:
+            secure_channel_type = SEC_CHAN_DOMAIN
+
+        incoming_creds = KerberosCredentials()
+        incoming_creds.guess(self.get_lp())
+        incoming_creds.set_realm(incoming_dns_domain.upper())
+        incoming_creds.set_domain(incoming_nbt_domain.upper())
+        incoming_creds.set_forced_salt(trust_incoming_salt.encode('utf-8'))
+        incoming_creds.set_password(trust_incoming_password)
+        incoming_creds.set_username(incoming_account_name)
+        incoming_creds.set_workstation('')
+        incoming_creds.set_secure_channel_type(secure_channel_type)
+        incoming_creds.set_dn(tdo_dn)
+        incoming_creds.set_type(self.AccountType.TRUST)
+        incoming_creds.set_user_account_control(UF_INTERDOMAIN_TRUST_ACCOUNT)
+        self.creds_set_enctypes(incoming_creds)
+
+        outgoing_creds = KerberosCredentials()
+        outgoing_creds.guess(self.get_lp())
+        outgoing_creds.set_realm(outgoing_dns_domain.upper())
+        outgoing_creds.set_domain(outgoing_nbt_domain.upper())
+        outgoing_creds.set_forced_salt(trust_outgoing_salt.encode('utf-8'))
+        outgoing_creds.set_password(trust_outgoing_password)
+        outgoing_creds.set_username(outgoing_account_name)
+        outgoing_creds.set_workstation('')
+        outgoing_creds.set_secure_channel_type(secure_channel_type)
+        outgoing_creds.set_dn(tdo_dn)
+        outgoing_creds.set_type(self.AccountType.TRUST)
+        outgoing_creds.set_user_account_control(UF_INTERDOMAIN_TRUST_ACCOUNT)
+        self.creds_set_enctypes(outgoing_creds)
+
+        account_creds = KerberosCredentials()
+        account_creds.guess(self.get_lp())
+        account_creds.set_realm(incoming_dns_domain.upper())
+        account_creds.set_domain(incoming_nbt_domain.upper())
+        account_creds.set_forced_salt(trust_account_salt.encode('utf-8'))
+        account_creds.set_password(trust_incoming_password)
+        account_creds.set_username(incoming_account_name)
+        account_creds.set_workstation('TEST-TRUST-DC')
+        account_creds.set_secure_channel_type(secure_channel_type)
+        account_creds.set_dn(acct_dn)
+        account_creds.set_type(self.AccountType.TRUST)
+        account_creds.set_user_account_control(UF_INTERDOMAIN_TRUST_ACCOUNT)
+        account_creds.set_kvno(acct_kvno)
+        account_creds.set_sid(str(acct_sid))
+        account_creds.set_guid(acct_guid)
+        if trust_enc_types is not None:
+            self.creds_set_enctypes(account_creds,
+                                    extra_bits=trust_enc_types.enc_types)
+        else:
+            self.creds_set_enctypes(account_creds)
+
+        incoming_creds.set_trust_outgoing_creds(outgoing_creds)
+        incoming_creds.set_trust_account_creds(account_creds)
+
+        outgoing_creds.set_trust_incoming_creds(incoming_creds)
+        outgoing_creds.set_trust_account_creds(account_creds)
+
+        account_creds.set_trust_incoming_creds(incoming_creds)
+        account_creds.set_trust_outgoing_creds(outgoing_creds)
+
+        self.remember_creds_for_keytab_export(incoming_creds)
+        self.remember_creds_for_keytab_export(outgoing_creds)
+        self.remember_creds_for_keytab_export(account_creds)
+
+        return (tdo_handle, incoming_creds, outgoing_creds, account_creds)
+
     def create_account(self, samdb, name, account_type=AccountType.USER,
                        spn=None, upn=None, additional_details=None,
                        ou=None, account_control=0, add_dollar=None,