Still only load the test data once per test class, but much easier to read.
Made several methods static for creating/deleting claims, policies and silos.
Signed-off-by: Rob van der Linde <rob@catalyst.net.nz>
Reviewed-by: Douglas Bagnall <douglas.bagnall@catalyst.net.nz>
Reviewed-by: Andrew Bartlett <abartlet@samba.org>
# Unix SMB/CIFS implementation.
#
-# Base class for samba-tool domain auth policy and silo commands
+# Base test class for samba-tool domain auth policy and silo commands.
#
# Copyright (C) Catalyst.Net Ltd. 2023
#
from .base import SambaToolCmdTest
+HOST = "ldap://{DC_SERVER}".format(**os.environ)
+CREDS = "-U{DC_USERNAME}%{DC_PASSWORD}".format(**os.environ)
+
class BaseAuthCmdTest(SambaToolCmdTest):
- def setUp(self):
- super().setUp()
-
- if self._first_self is None:
- cls = type(self)
- cls.host = "ldap://{DC_SERVER}".format(**os.environ)
- cls.creds = "-U{DC_USERNAME}%{DC_PASSWORD}".format(**os.environ)
- cls.samdb = self.getSamDB("-H", self.host, self.creds)
-
- # Generate some test data.
- self.create_authentication_policy(name="Single Policy")
- self.create_authentication_policy(name="User Policy")
- self.create_authentication_policy(name="Service Policy")
- self.create_authentication_policy(name="Computer Policy")
-
- self.create_authentication_silo(name="Developers",
- description="Developers, Developers",
- policy="Single Policy")
- self.create_authentication_silo(name="Managers",
- description="Managers",
- policy="Single Policy")
- self.create_authentication_silo(name="QA",
- description="Quality Assurance",
- user_policy="User Policy",
- service_policy="Service Policy",
- computer_policy="Computer Policy")
-
- cls._first_self = self
+ """Base test class for samba-tool domain auth policy and silo commands."""
@classmethod
def setUpClass(cls):
+ cls.samdb = cls.getSamDB("-H", HOST, CREDS)
super().setUpClass()
- cls._first_self = None
- cls.policies = []
- cls.silos = []
@classmethod
- def tearDownClass(cls):
- """Remove data created by setUp, and kept for the lifetime of the
- class."""
- first_self = cls._first_self
- if first_self is not None:
- for policy in first_self.policies:
- first_self.delete_authentication_policy(policy, force=True)
-
- for silo in first_self.silos:
- first_self.delete_authentication_silo(silo, force=True)
-
- cls._first_self = None
-
- super().tearDownClass()
+ def setUpTestData(cls):
+ cls.create_authentication_policy(name="Single Policy")
+ cls.create_authentication_policy(name="User Policy")
+ cls.create_authentication_policy(name="Service Policy")
+ cls.create_authentication_policy(name="Computer Policy")
+
+ cls.create_authentication_silo(name="Developers",
+ description="Developers, Developers",
+ policy="Single Policy")
+ cls.create_authentication_silo(name="Managers",
+ description="Managers",
+ policy="Single Policy")
+ cls.create_authentication_silo(name="QA",
+ description="Quality Assurance",
+ user_policy="User Policy",
+ service_policy="Service Policy",
+ computer_policy="Computer Policy")
def get_services_dn(self):
"""Returns Services DN."""
def _run(cls, *argv):
"""Override _run, so we don't always have to pass host and creds."""
args = list(argv)
- args.extend(["-H", cls.host, cls.creds])
+ args.extend(["-H", HOST, CREDS])
return super()._run(*args)
runcmd = _run
runsubcmd = _run
- def create_authentication_policy(self, name, description=None, audit=False,
+ @classmethod
+ def create_authentication_policy(cls, name, description=None, audit=False,
protect=False):
"""Create an authentication policy."""
# Run command and store name in self.silos for tearDownClass to clean
# up.
- result, out, err = self.runcmd(*cmd)
- self.assertIsNone(result, msg=err)
- self.assertTrue(out.startswith("Created authentication policy"))
- self.policies.append(name)
+ result, out, err = cls.runcmd(*cmd)
+ assert result is None
+ assert out.startswith("Created authentication policy")
+ cls.addClassCleanup(cls.delete_authentication_policy,
+ name=name, force=True)
return name
- def delete_authentication_policy(self, name, force=False):
+ @classmethod
+ def delete_authentication_policy(cls, name, force=False):
"""Delete authentication policy by name."""
cmd = ["domain", "auth", "policy", "delete", "--name", name]
if force:
cmd.append("--force")
- result, out, err = self.runcmd(*cmd)
- self.assertIsNone(result, msg=err)
- self.assertIn("Deleted authentication policy", out)
+ result, out, err = cls.runcmd(*cmd)
+ assert result is None
+ assert "Deleted authentication policy" in out
- def create_authentication_silo(self, name, description=None, policy=None,
+ @classmethod
+ def create_authentication_silo(cls, name, description=None, policy=None,
user_policy=None, service_policy=None,
computer_policy=None, audit=False,
protect=False):
# Run command and store name in self.silos for tearDownClass to clean
# up.
- result, out, err = self.runcmd(*cmd)
- self.assertIsNone(result, msg=err)
- self.assertTrue(out.startswith("Created authentication silo"))
- self.silos.append(name)
+ result, out, err = cls.runcmd(*cmd)
+ assert result is None
+ assert out.startswith("Created authentication silo")
+ cls.addClassCleanup(cls.delete_authentication_silo,
+ name=name, force=True)
return name
- def delete_authentication_silo(self, name, force=False):
+ @classmethod
+ def delete_authentication_silo(cls, name, force=False):
"""Delete authentication silo by name."""
cmd = ["domain", "auth", "silo", "delete", "--name", name]
if force:
cmd.append("--force")
- result, out, err = self.runcmd(*cmd)
- self.assertIsNone(result, msg=err)
- self.assertIn("Deleted authentication silo", out)
+ result, out, err = cls.runcmd(*cmd)
+ assert result is None
+ assert "Deleted authentication silo" in out
def get_authentication_silo(self, name):
"""Get authentication silo by name."""
result, out, err = self.runcmd("domain", "auth", "policy", "list")
self.assertIsNone(result, msg=err)
- # Check each authentication policy we created is there.
- for policy in self.policies:
+ expected_policies = [
+ "Single Policy", "User Policy", "Service Policy", "Computer Policy"]
+
+ for policy in expected_policies:
self.assertIn(policy, out)
def test_authentication_policy_list_json(self):
# we should get valid json
policies = json.loads(out)
- # each policy in self.policies must be present
- for name in self.policies:
+ expected_policies = [
+ "Single Policy", "User Policy", "Service Policy", "Computer Policy"]
+
+ for name in expected_policies:
policy = policies[name]
self.assertIn("name", policy)
self.assertIn("msDS-AuthNPolicy", list(policy["objectClass"]))
result, out, err = self.runcmd("domain", "auth", "silo", "list")
self.assertIsNone(result, msg=err)
- # Check each silo we created is there.
- for silo in self.silos:
+ expected_silos = ["Developers", "Managers", "QA"]
+
+ for silo in expected_silos:
self.assertIn(silo, out)
def test_authentication_silo_list_json(self):
# we should get valid json
silos = json.loads(out)
- # each silo in self.silos must be present
- for name in self.silos:
+ expected_silos = ["Developers", "Managers", "QA"]
+
+ for name in expected_silos:
silo = silos[name]
self.assertIn("msDS-AuthNPolicySilo", list(silo["objectClass"]))
self.assertIn("description", silo)
"Yes/No"
]
+HOST = "ldap://{DC_SERVER}".format(**os.environ)
+CREDS = "-U{DC_USERNAME}%{DC_PASSWORD}".format(**os.environ)
+
class ClaimCmdTestCase(SambaToolCmdTest):
- def setUp(self):
- super().setUp()
- self.this_test_claim_types = set()
-
- if self._first_self is None:
- cls = type(self)
- cls.host = "ldap://{DC_SERVER}".format(**os.environ)
- cls.creds = "-U{DC_USERNAME}%{DC_PASSWORD}".format(**os.environ)
- cls.samdb = self.getSamDB("-H", self.host, self.creds)
-
- # Generate some known claim types used by tests.
- for attribute in ATTRIBUTES:
- self.create_claim_type(attribute, classes=["user"], preserve=True)
-
- # Generate some more with unique names not in the ATTRIBUTES list.
- self.create_claim_type("accountExpires", name="expires",
- classes=["user"], preserve=True)
- self.create_claim_type("department", name="dept", classes=["user"],
- protect=True, preserve=True)
- self.create_claim_type("carLicense", name="plate", classes=["user"],
- disable=True, preserve=True)
-
- cls._first_self = self
-
- def tearDown(self):
- # Remove claim types created by a single test.
- first_self = self._first_self
- if first_self is not None:
- for claim_type in first_self.this_test_claim_types:
- first_self.delete_claim_type(claim_type, force=True)
- first_self.claim_types.remove(claim_type)
-
- super().tearDown()
@classmethod
def setUpClass(cls):
+ cls.samdb = cls.getSamDB("-H", HOST, CREDS)
super().setUpClass()
- cls._first_self = None
- cls.claim_types = set()
@classmethod
- def tearDownClass(cls):
- # Remove claim types created by setUp, and kept for the lifetime of the
- # class.
- first_self = cls._first_self
- if first_self is not None:
- for claim_type in first_self.claim_types:
- first_self.delete_claim_type(claim_type, force=True)
-
- cls._first_self = None
+ def setUpTestData(cls):
+ # Generate some known claim types used by tests.
+ for attribute in ATTRIBUTES:
+ cls.create_claim_type(attribute, classes=["user"])
- super().tearDownClass()
+ # Generate some more with unique names not in the ATTRIBUTES list.
+ cls.create_claim_type("accountExpires", name="expires",
+ classes=["user"])
+ cls.create_claim_type("department", name="dept", classes=["user"],
+ protect=True)
+ cls.create_claim_type("carLicense", name="plate", classes=["user"],
+ disable=True)
def get_services_dn(self):
"""Returns Services DN."""
def _run(cls, *argv):
"""Override _run, so we don't always have to pass host and creds."""
args = list(argv)
- args.extend(["-H", cls.host, cls.creds])
+ args.extend(["-H", HOST, CREDS])
return super()._run(*args)
runcmd = _run
runsubcmd = _run
- def create_claim_type(self, attribute, name=None, description=None,
- classes=None, disable=False, protect=False,
- preserve=False):
+ @classmethod
+ def create_claim_type(cls, attribute, name=None, description=None,
+ classes=None, disable=False, protect=False):
"""Create a claim type using the samba-tool command."""
# if name is specified it will override the attribute name
if protect:
cmd.append("--protect")
- result, out, err = self.runcmd(*cmd)
- self.assertIsNone(result, msg=err)
- self.assertTrue(out.startswith("Created claim type"))
- if preserve:
- self.claim_types.add(display_name)
- else:
- self.this_test_claim_types.add(display_name)
+ result, out, err = cls.runcmd(*cmd)
+ assert result is None
+ assert out.startswith("Created claim type")
+ cls.addClassCleanup(cls.delete_claim_type, name=display_name, force=True)
return display_name
- def delete_claim_type(self, name, force=False):
+ @classmethod
+ def delete_claim_type(cls, name, force=False):
"""Delete claim type by display name."""
cmd = ["domain", "claim", "claim-type", "delete", "--name", name]
if force:
cmd.append("--force")
- result, out, err = self.runcmd(*cmd)
- self.assertIsNone(result, msg=err)
- self.assertIn("Deleted claim type", out)
+ result, out, err = cls.runcmd(*cmd)
+ assert result is None
+ assert "Deleted claim type" in out
def get_claim_type(self, name):
"""Get claim type by display name."""
result, out, err = self.runcmd("domain", "claim", "claim-type", "list")
self.assertIsNone(result, msg=err)
- # check each claim type we created is there
- for claim_type in self.claim_types:
+ for claim_type in ATTRIBUTES:
self.assertIn(claim_type, out)
def test_claim_type_list_json(self):
json_result = json.loads(out)
claim_types = list(json_result.keys())
- # check each claim type we created is there
- for claim_type in self.claim_types:
+ for claim_type in ATTRIBUTES:
self.assertIn(claim_type, claim_types)
def test_claim_type_view(self):
from .base import SambaToolCmdTest
+HOST = "ldap://{DC_SERVER}".format(**os.environ)
+CREDS = "-U{DC_USERNAME}%{DC_PASSWORD}".format(**os.environ)
+
class FieldTestMixin:
"""Tests a model field to ensure it behaves correctly in both directions.
Use a mixin since TestCase can't be marked as abstract.
"""
- def setUp(self):
- super().setUp()
- self.host = "ldap://{DC_SERVER}".format(**os.environ)
- self.creds = "-U{DC_USERNAME}%{DC_PASSWORD}".format(**os.environ)
- self.samdb = self.getSamDB("-H", self.host, self.creds)
+ @classmethod
+ def setUpClass(cls):
+ cls.samdb = cls.getSamDB("-H", HOST, CREDS)
+ super().setUpClass()
def get_users_dn(self):
"""Returns Users DN."""