]> git.ipfire.org Git - thirdparty/samba.git/commitdiff
netcmd: tests: tests tidyup and make use of setUpTestData
authorRob van der Linde <rob@catalyst.net.nz>
Tue, 26 Sep 2023 11:20:49 +0000 (00:20 +1300)
committerAndrew Bartlett <abartlet@samba.org>
Fri, 29 Sep 2023 02:18:34 +0000 (02:18 +0000)
Still only load the test data once per test class, but much easier to read.

Made several methods static for creating/deleting claims, policies and silos.

Signed-off-by: Rob van der Linde <rob@catalyst.net.nz>
Reviewed-by: Douglas Bagnall <douglas.bagnall@catalyst.net.nz>
Reviewed-by: Andrew Bartlett <abartlet@samba.org>
python/samba/tests/samba_tool/domain_auth_base.py
python/samba/tests/samba_tool/domain_auth_policy.py
python/samba/tests/samba_tool/domain_auth_silo.py
python/samba/tests/samba_tool/domain_claim.py
python/samba/tests/samba_tool/domain_models.py

index d949c516be1810c37c173f5ddacc0569e94546a0..a0f423767c6654adf33c4f297840051d64ad040e 100644 (file)
@@ -1,6 +1,6 @@
 # Unix SMB/CIFS implementation.
 #
-# Base class for samba-tool domain auth policy and silo commands
+# Base test class for samba-tool domain auth policy and silo commands.
 #
 # Copyright (C) Catalyst.Net Ltd. 2023
 #
@@ -26,59 +26,36 @@ from ldb import SCOPE_ONELEVEL
 
 from .base import SambaToolCmdTest
 
+HOST = "ldap://{DC_SERVER}".format(**os.environ)
+CREDS = "-U{DC_USERNAME}%{DC_PASSWORD}".format(**os.environ)
+
 
 class BaseAuthCmdTest(SambaToolCmdTest):
-    def setUp(self):
-        super().setUp()
-
-        if self._first_self is None:
-            cls = type(self)
-            cls.host = "ldap://{DC_SERVER}".format(**os.environ)
-            cls.creds = "-U{DC_USERNAME}%{DC_PASSWORD}".format(**os.environ)
-            cls.samdb = self.getSamDB("-H", self.host, self.creds)
-
-            # Generate some test data.
-            self.create_authentication_policy(name="Single Policy")
-            self.create_authentication_policy(name="User Policy")
-            self.create_authentication_policy(name="Service Policy")
-            self.create_authentication_policy(name="Computer Policy")
-
-            self.create_authentication_silo(name="Developers",
-                                            description="Developers, Developers",
-                                            policy="Single Policy")
-            self.create_authentication_silo(name="Managers",
-                                            description="Managers",
-                                            policy="Single Policy")
-            self.create_authentication_silo(name="QA",
-                                            description="Quality Assurance",
-                                            user_policy="User Policy",
-                                            service_policy="Service Policy",
-                                            computer_policy="Computer Policy")
-
-            cls._first_self = self
+    """Base test class for samba-tool domain auth policy and silo commands."""
 
     @classmethod
     def setUpClass(cls):
+        cls.samdb = cls.getSamDB("-H", HOST, CREDS)
         super().setUpClass()
-        cls._first_self = None
-        cls.policies = []
-        cls.silos = []
 
     @classmethod
-    def tearDownClass(cls):
-        """Remove data created by setUp, and kept for the lifetime of the
-        class."""
-        first_self = cls._first_self
-        if first_self is not None:
-            for policy in first_self.policies:
-                first_self.delete_authentication_policy(policy, force=True)
-
-            for silo in first_self.silos:
-                first_self.delete_authentication_silo(silo, force=True)
-
-            cls._first_self = None
-
-        super().tearDownClass()
+    def setUpTestData(cls):
+        cls.create_authentication_policy(name="Single Policy")
+        cls.create_authentication_policy(name="User Policy")
+        cls.create_authentication_policy(name="Service Policy")
+        cls.create_authentication_policy(name="Computer Policy")
+
+        cls.create_authentication_silo(name="Developers",
+                                       description="Developers, Developers",
+                                       policy="Single Policy")
+        cls.create_authentication_silo(name="Managers",
+                                       description="Managers",
+                                       policy="Single Policy")
+        cls.create_authentication_silo(name="QA",
+                                       description="Quality Assurance",
+                                       user_policy="User Policy",
+                                       service_policy="Service Policy",
+                                       computer_policy="Computer Policy")
 
     def get_services_dn(self):
         """Returns Services DN."""
@@ -125,13 +102,14 @@ class BaseAuthCmdTest(SambaToolCmdTest):
     def _run(cls, *argv):
         """Override _run, so we don't always have to pass host and creds."""
         args = list(argv)
-        args.extend(["-H", cls.host, cls.creds])
+        args.extend(["-H", HOST, CREDS])
         return super()._run(*args)
 
     runcmd = _run
     runsubcmd = _run
 
-    def create_authentication_policy(self, name, description=None, audit=False,
+    @classmethod
+    def create_authentication_policy(cls, name, description=None, audit=False,
                                      protect=False):
         """Create an authentication policy."""
 
@@ -148,13 +126,15 @@ class BaseAuthCmdTest(SambaToolCmdTest):
 
         # Run command and store name in self.silos for tearDownClass to clean
         # up.
-        result, out, err = self.runcmd(*cmd)
-        self.assertIsNone(result, msg=err)
-        self.assertTrue(out.startswith("Created authentication policy"))
-        self.policies.append(name)
+        result, out, err = cls.runcmd(*cmd)
+        assert result is None
+        assert out.startswith("Created authentication policy")
+        cls.addClassCleanup(cls.delete_authentication_policy,
+                            name=name, force=True)
         return name
 
-    def delete_authentication_policy(self, name, force=False):
+    @classmethod
+    def delete_authentication_policy(cls, name, force=False):
         """Delete authentication policy by name."""
         cmd = ["domain", "auth", "policy", "delete", "--name", name]
 
@@ -162,11 +142,12 @@ class BaseAuthCmdTest(SambaToolCmdTest):
         if force:
             cmd.append("--force")
 
-        result, out, err = self.runcmd(*cmd)
-        self.assertIsNone(result, msg=err)
-        self.assertIn("Deleted authentication policy", out)
+        result, out, err = cls.runcmd(*cmd)
+        assert result is None
+        assert "Deleted authentication policy" in out
 
-    def create_authentication_silo(self, name, description=None, policy=None,
+    @classmethod
+    def create_authentication_silo(cls, name, description=None, policy=None,
                                    user_policy=None, service_policy=None,
                                    computer_policy=None, audit=False,
                                    protect=False):
@@ -194,13 +175,15 @@ class BaseAuthCmdTest(SambaToolCmdTest):
 
         # Run command and store name in self.silos for tearDownClass to clean
         # up.
-        result, out, err = self.runcmd(*cmd)
-        self.assertIsNone(result, msg=err)
-        self.assertTrue(out.startswith("Created authentication silo"))
-        self.silos.append(name)
+        result, out, err = cls.runcmd(*cmd)
+        assert result is None
+        assert out.startswith("Created authentication silo")
+        cls.addClassCleanup(cls.delete_authentication_silo,
+                            name=name, force=True)
         return name
 
-    def delete_authentication_silo(self, name, force=False):
+    @classmethod
+    def delete_authentication_silo(cls, name, force=False):
         """Delete authentication silo by name."""
         cmd = ["domain", "auth", "silo", "delete", "--name", name]
 
@@ -208,9 +191,9 @@ class BaseAuthCmdTest(SambaToolCmdTest):
         if force:
             cmd.append("--force")
 
-        result, out, err = self.runcmd(*cmd)
-        self.assertIsNone(result, msg=err)
-        self.assertIn("Deleted authentication silo", out)
+        result, out, err = cls.runcmd(*cmd)
+        assert result is None
+        assert "Deleted authentication silo" in out
 
     def get_authentication_silo(self, name):
         """Get authentication silo by name."""
index acd62804cf1a0b1393f5d098b53f260c09f429f8..50e12fbbf08eb2d38e1e808a32b2dec6c46fc056 100644 (file)
@@ -39,8 +39,10 @@ class AuthPolicyCmdTestCase(BaseAuthCmdTest):
         result, out, err = self.runcmd("domain", "auth", "policy", "list")
         self.assertIsNone(result, msg=err)
 
-        # Check each authentication policy we created is there.
-        for policy in self.policies:
+        expected_policies = [
+            "Single Policy", "User Policy", "Service Policy", "Computer Policy"]
+
+        for policy in expected_policies:
             self.assertIn(policy, out)
 
     def test_authentication_policy_list_json(self):
@@ -52,8 +54,10 @@ class AuthPolicyCmdTestCase(BaseAuthCmdTest):
         # we should get valid json
         policies = json.loads(out)
 
-        # each policy in self.policies must be present
-        for name in self.policies:
+        expected_policies = [
+            "Single Policy", "User Policy", "Service Policy", "Computer Policy"]
+
+        for name in expected_policies:
             policy = policies[name]
             self.assertIn("name", policy)
             self.assertIn("msDS-AuthNPolicy", list(policy["objectClass"]))
index 2b18098ba0f81fe6a573a0052c88d26fe7c5f1d8..f7cd7859f037f952098b60ccf98b4207dfd2e5b1 100644 (file)
@@ -38,8 +38,9 @@ class AuthSiloCmdTestCase(BaseAuthCmdTest):
         result, out, err = self.runcmd("domain", "auth", "silo", "list")
         self.assertIsNone(result, msg=err)
 
-        # Check each silo we created is there.
-        for silo in self.silos:
+        expected_silos = ["Developers", "Managers", "QA"]
+
+        for silo in expected_silos:
             self.assertIn(silo, out)
 
     def test_authentication_silo_list_json(self):
@@ -51,8 +52,9 @@ class AuthSiloCmdTestCase(BaseAuthCmdTest):
         # we should get valid json
         silos = json.loads(out)
 
-        # each silo in self.silos must be present
-        for name in self.silos:
+        expected_silos = ["Developers", "Managers", "QA"]
+
+        for name in expected_silos:
             silo = silos[name]
             self.assertIn("msDS-AuthNPolicySilo", list(silo["objectClass"]))
             self.assertIn("description", silo)
index 0ae61f34ae780ababc5cfa6c077450d36acc0b95..675b63ad6c37a7ea265f6ab5c043ce707c4afa84 100644 (file)
@@ -63,60 +63,30 @@ VALUE_TYPES = [
     "Yes/No"
 ]
 
+HOST = "ldap://{DC_SERVER}".format(**os.environ)
+CREDS = "-U{DC_USERNAME}%{DC_PASSWORD}".format(**os.environ)
+
 
 class ClaimCmdTestCase(SambaToolCmdTest):
-    def setUp(self):
-        super().setUp()
-        self.this_test_claim_types = set()
-
-        if self._first_self is None:
-            cls = type(self)
-            cls.host = "ldap://{DC_SERVER}".format(**os.environ)
-            cls.creds = "-U{DC_USERNAME}%{DC_PASSWORD}".format(**os.environ)
-            cls.samdb = self.getSamDB("-H", self.host, self.creds)
-
-            # Generate some known claim types used by tests.
-            for attribute in ATTRIBUTES:
-                self.create_claim_type(attribute, classes=["user"], preserve=True)
-
-            # Generate some more with unique names not in the ATTRIBUTES list.
-            self.create_claim_type("accountExpires", name="expires",
-                                   classes=["user"], preserve=True)
-            self.create_claim_type("department", name="dept", classes=["user"],
-                                   protect=True, preserve=True)
-            self.create_claim_type("carLicense", name="plate", classes=["user"],
-                                   disable=True, preserve=True)
-
-            cls._first_self = self
-
-    def tearDown(self):
-        # Remove claim types created by a single test.
-        first_self = self._first_self
-        if first_self is not None:
-            for claim_type in first_self.this_test_claim_types:
-                first_self.delete_claim_type(claim_type, force=True)
-                first_self.claim_types.remove(claim_type)
-
-        super().tearDown()
 
     @classmethod
     def setUpClass(cls):
+        cls.samdb = cls.getSamDB("-H", HOST, CREDS)
         super().setUpClass()
-        cls._first_self = None
-        cls.claim_types = set()
 
     @classmethod
-    def tearDownClass(cls):
-        # Remove claim types created by setUp, and kept for the lifetime of the
-        # class.
-        first_self = cls._first_self
-        if first_self is not None:
-            for claim_type in first_self.claim_types:
-                first_self.delete_claim_type(claim_type, force=True)
-
-            cls._first_self = None
+    def setUpTestData(cls):
+        # Generate some known claim types used by tests.
+        for attribute in ATTRIBUTES:
+            cls.create_claim_type(attribute, classes=["user"])
 
-        super().tearDownClass()
+        # Generate some more with unique names not in the ATTRIBUTES list.
+        cls.create_claim_type("accountExpires", name="expires",
+                              classes=["user"])
+        cls.create_claim_type("department", name="dept", classes=["user"],
+                              protect=True)
+        cls.create_claim_type("carLicense", name="plate", classes=["user"],
+                              disable=True)
 
     def get_services_dn(self):
         """Returns Services DN."""
@@ -134,15 +104,15 @@ class ClaimCmdTestCase(SambaToolCmdTest):
     def _run(cls, *argv):
         """Override _run, so we don't always have to pass host and creds."""
         args = list(argv)
-        args.extend(["-H", cls.host, cls.creds])
+        args.extend(["-H", HOST, CREDS])
         return super()._run(*args)
 
     runcmd = _run
     runsubcmd = _run
 
-    def create_claim_type(self, attribute, name=None, description=None,
-                          classes=None, disable=False, protect=False,
-                          preserve=False):
+    @classmethod
+    def create_claim_type(cls, attribute, name=None, description=None,
+                          classes=None, disable=False, protect=False):
         """Create a claim type using the samba-tool command."""
 
         # if name is specified it will override the attribute name
@@ -166,16 +136,14 @@ class ClaimCmdTestCase(SambaToolCmdTest):
         if protect:
             cmd.append("--protect")
 
-        result, out, err = self.runcmd(*cmd)
-        self.assertIsNone(result, msg=err)
-        self.assertTrue(out.startswith("Created claim type"))
-        if preserve:
-            self.claim_types.add(display_name)
-        else:
-            self.this_test_claim_types.add(display_name)
+        result, out, err = cls.runcmd(*cmd)
+        assert result is None
+        assert out.startswith("Created claim type")
+        cls.addClassCleanup(cls.delete_claim_type, name=display_name, force=True)
         return display_name
 
-    def delete_claim_type(self, name, force=False):
+    @classmethod
+    def delete_claim_type(cls, name, force=False):
         """Delete claim type by display name."""
         cmd = ["domain", "claim", "claim-type", "delete", "--name", name]
 
@@ -183,9 +151,9 @@ class ClaimCmdTestCase(SambaToolCmdTest):
         if force:
             cmd.append("--force")
 
-        result, out, err = self.runcmd(*cmd)
-        self.assertIsNone(result, msg=err)
-        self.assertIn("Deleted claim type", out)
+        result, out, err = cls.runcmd(*cmd)
+        assert result is None
+        assert "Deleted claim type" in out
 
     def get_claim_type(self, name):
         """Get claim type by display name."""
@@ -203,8 +171,7 @@ class ClaimCmdTestCase(SambaToolCmdTest):
         result, out, err = self.runcmd("domain", "claim", "claim-type", "list")
         self.assertIsNone(result, msg=err)
 
-        # check each claim type we created is there
-        for claim_type in self.claim_types:
+        for claim_type in ATTRIBUTES:
             self.assertIn(claim_type, out)
 
     def test_claim_type_list_json(self):
@@ -217,8 +184,7 @@ class ClaimCmdTestCase(SambaToolCmdTest):
         json_result = json.loads(out)
         claim_types = list(json_result.keys())
 
-        # check each claim type we created is there
-        for claim_type in self.claim_types:
+        for claim_type in ATTRIBUTES:
             self.assertIn(claim_type, claim_types)
 
     def test_claim_type_view(self):
index cc1572d2d6902cdf605d7f2b42d5d4ab6310008a..4b0b10ead2ebb4699e0e9981982b3c90f3128f3b 100644 (file)
@@ -32,6 +32,9 @@ from samba.ndr import ndr_unpack
 
 from .base import SambaToolCmdTest
 
+HOST = "ldap://{DC_SERVER}".format(**os.environ)
+CREDS = "-U{DC_USERNAME}%{DC_PASSWORD}".format(**os.environ)
+
 
 class FieldTestMixin:
     """Tests a model field to ensure it behaves correctly in both directions.
@@ -39,11 +42,10 @@ class FieldTestMixin:
     Use a mixin since TestCase can't be marked as abstract.
     """
 
-    def setUp(self):
-        super().setUp()
-        self.host = "ldap://{DC_SERVER}".format(**os.environ)
-        self.creds = "-U{DC_USERNAME}%{DC_PASSWORD}".format(**os.environ)
-        self.samdb = self.getSamDB("-H", self.host, self.creds)
+    @classmethod
+    def setUpClass(cls):
+        cls.samdb = cls.getSamDB("-H", HOST, CREDS)
+        super().setUpClass()
 
     def get_users_dn(self):
         """Returns Users DN."""