]> git.ipfire.org Git - thirdparty/samba.git/commitdiff
tests/krb5: Add function for creating claims
authorJoseph Sutton <josephsutton@catalyst.net.nz>
Fri, 4 Mar 2022 03:20:18 +0000 (16:20 +1300)
committerAndrew Bartlett <abartlet@samba.org>
Fri, 9 Sep 2022 00:14:38 +0000 (00:14 +0000)
Signed-off-by: Joseph Sutton <josephsutton@catalyst.net.nz>
Reviewed-by: Andrew Bartlett <abartlet@samba.org>
python/samba/tests/krb5/kdc_base_test.py

index 0733c5c967528a6bf4694e248552c509ad2474bd..73fa5a04baf68210937cfeb04902717837a75025 100644 (file)
@@ -130,6 +130,57 @@ class KDCBaseTest(RawKerberosTest):
 
         cls.ldb_cleanups = []
 
+        cls._claim_types_dn = None
+
+    def get_claim_types_dn(self):
+        samdb = self.get_samdb()
+
+        if self._claim_types_dn is None:
+            claim_config_dn = samdb.get_config_basedn()
+
+            self.assertTrue(claim_config_dn.add_child(
+                'CN=Claims Configuration,CN=Services'))
+            details = {
+                'dn': claim_config_dn,
+                'objectClass': 'container',
+            }
+            try:
+                samdb.add(details)
+            except ldb.LdbError as err:
+                num, _ = err.args
+                if num != ldb.ERR_ENTRY_ALREADY_EXISTS:
+                    raise
+            else:
+                self.accounts.append(str(claim_config_dn))
+
+            claim_types_dn = claim_config_dn
+            self.assertTrue(claim_types_dn.add_child('CN=Claim Types'))
+            details = {
+                'dn': claim_types_dn,
+                'objectClass': 'msDS-ClaimTypes',
+            }
+            try:
+                samdb.add(details)
+            except ldb.LdbError as err:
+                num, _ = err.args
+                if num != ldb.ERR_ENTRY_ALREADY_EXISTS:
+                    raise
+            else:
+                self.accounts.append(str(claim_types_dn))
+
+            type(self)._claim_types_dn = claim_types_dn
+
+        # Return a copy of the DN.
+        return ldb.Dn(samdb, str(self._claim_types_dn))
+
+    def tearDown(self):
+        # Clean up any accounts created for single tests.
+        if self._ldb is not None:
+            for dn in reversed(self.test_accounts):
+                delete_force(self._ldb, dn)
+
+        super().tearDown()
+
     @classmethod
     def tearDownClass(cls):
         # Clean up any accounts created by create_account. This is
@@ -155,6 +206,10 @@ class KDCBaseTest(RawKerberosTest):
         self.do_asn1_print = global_asn1_print
         self.do_hexdump = global_hexdump
 
+        # A list containing DNs of accounts that should be removed when the
+        # current test finishes.
+        self.test_accounts = []
+
     def get_lp(self):
         if self._lp is None:
             type(self)._lp = self.get_loadparm()
@@ -268,6 +323,88 @@ class KDCBaseTest(RawKerberosTest):
 
         return dn
 
+    def get_dn_from_attribute(self, attribute):
+        return self.get_dn_from_schema(attribute, 'attributeSchema')
+
+    def get_dn_from_class(self, attribute):
+        return self.get_dn_from_schema(attribute, 'classSchema')
+
+    def get_dn_from_schema(self, name, object_class):
+        samdb = self.get_samdb()
+        schema_dn = samdb.get_schema_basedn()
+
+        res = samdb.search(base=schema_dn,
+                           scope=ldb.SCOPE_ONELEVEL,
+                           expression=(f'(&(objectClass={object_class})'
+                                       f'(lDAPDisplayName={name}))'))
+        self.assertEqual(1, len(res),
+                         f'could not locate {name} in {object_class}')
+
+        return res[0].dn
+
+    def create_claim(self,
+                     claim_id,
+                     enabled=None,
+                     attribute=None,
+                     single_valued=None,
+                     source=None,
+                     source_type=None,
+                     for_classes=None,
+                     value_type=None):
+        samdb = self.get_samdb()
+
+        claim_dn = self.get_claim_types_dn()
+        self.assertTrue(claim_dn.add_child(f'CN={claim_id}'))
+
+        details = {
+            'dn': claim_dn,
+            'objectClass': 'msDS-ClaimType',
+        }
+
+        if enabled is True:
+            enabled = 'TRUE'
+        elif enabled is False:
+            enabled = 'FALSE'
+
+        if attribute is not None:
+            attribute = str(self.get_dn_from_attribute(attribute))
+
+        if single_valued is True:
+            single_valued = 'TRUE'
+        elif single_valued is False:
+            single_valued = 'FALSE'
+
+        if for_classes is not None:
+            for_classes = [str(self.get_dn_from_class(name))
+                           for name in for_classes]
+
+        if isinstance(value_type, int):
+            value_type = str(value_type)
+
+        if enabled is not None:
+            details['Enabled'] = enabled
+        if attribute is not None:
+            details['msDS-ClaimAttributeSource'] = attribute
+        if single_valued is not None:
+            details['msDS-ClaimIsSingleValued'] = single_valued
+        if source is not None:
+            details['msDS-ClaimSource'] = source
+        if source_type is not None:
+            details['msDS-ClaimSourceType'] = source_type
+        if for_classes is not None:
+            details['msDS-ClaimTypeAppliesToClass'] = for_classes
+        if value_type is not None:
+            details['msDS-ClaimValueType'] = value_type
+
+        # Save the claim DN so it can be deleted in tearDown()
+        self.test_accounts.append(str(claim_dn))
+
+        # Remove the claim if it exists; this will happen if a previous test
+        # run failed
+        delete_force(samdb, claim_dn)
+
+        samdb.add(details)
+
     def create_account(self, samdb, name, account_type=AccountType.USER,
                        spn=None, upn=None, additional_details=None,
                        ou=None, account_control=0, add_dollar=True,