From: Rob van der Linde Date: Thu, 5 Oct 2023 01:03:14 +0000 (+1300) Subject: netcmd: don't turn exception into CommandError in run_validators X-Git-Tag: talloc-2.4.2~1142 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=bdad257a31210e7d6195212bd051f1136854ce6f;p=thirdparty%2Fsamba.git netcmd: don't turn exception into CommandError in run_validators It's the wrong place to do it. Instead, let it raise the original exception, capture it in _run, and call existing show_command_error method. Signed-off-by: Rob van der Linde Reviewed-by: Douglas Bagnall Reviewed-by: Andrew Bartlett --- diff --git a/python/samba/netcmd/__init__.py b/python/samba/netcmd/__init__.py index 81f8d68b9e0..edd254145f5 100644 --- a/python/samba/netcmd/__init__.py +++ b/python/samba/netcmd/__init__.py @@ -31,7 +31,6 @@ from samba.logger import get_samba_logger from samba.samdb import SamDB from .encoders import JSONEncoder -from .validators import ValidationError class Option(SambaOption): @@ -39,18 +38,10 @@ class Option(SambaOption): SUPPRESS_HELP = optparse.SUPPRESS_HELP def run_validators(self, opt, value): - """Runs the list of validators on the current option. - - If the validator raises ValidationError, turn that into CommandError - which gives nicer output. - """ + """Runs the list of validators on the current option.""" validators = getattr(self, "validators") or [] - for validator in validators: - try: - validator(opt, value) - except ValidationError as e: - raise CommandError(e) + validator(opt, value) def convert_value(self, opt, value): """Override convert_value to run validators just after. @@ -242,7 +233,14 @@ class Command(object): def _run(self, *argv): parser, optiongroups = self._create_parser(self.command_name) - opts, args = parser.parse_args(list(argv)) + + # Handle possible validation errors raised by parser + try: + opts, args = parser.parse_args(list(argv)) + except Exception as e: + self.show_command_error(e) + return -1 + # Filter out options from option groups kwargs = dict(opts.__dict__) for option_group in parser.option_groups: diff --git a/python/samba/tests/samba_tool/domain_auth_policy.py b/python/samba/tests/samba_tool/domain_auth_policy.py index 674c30fc2f7..12d17519f2f 100644 --- a/python/samba/tests/samba_tool/domain_auth_policy.py +++ b/python/samba/tests/samba_tool/domain_auth_policy.py @@ -26,7 +26,6 @@ from unittest.mock import patch from samba.dcerpc import security from samba.ndr import ndr_unpack -from samba.netcmd import CommandError from samba.netcmd.domain.models.exceptions import ModelError from samba.samdb import SamDB from samba.sd_utils import SDUtils @@ -141,22 +140,20 @@ class AuthPolicyCmdTestCase(BaseAuthCmdTest): self.assertEqual(str(policy["msDS-UserTGTLifetime"]), "60") # check lower bounds (45) - with self.assertRaises(CommandError) as e: - self.runcmd("domain", "auth", "policy", "create", - "--name", "userTGTLifetimeLower", - "--user-tgt-lifetime", "44") - + result, out, err = self.runcmd("domain", "auth", "policy", "create", + "--name", "userTGTLifetimeLower", + "--user-tgt-lifetime", "44") + self.assertEqual(result, -1) self.assertIn("--user-tgt-lifetime must be between 45 and 2147483647", - str(e.exception)) + err) # check upper bounds (2147483647) - with self.assertRaises(CommandError) as e: - self.runcmd("domain", "auth", "policy", "create", - "--name", "userTGTLifetimeUpper", - "--user-tgt-lifetime", "2147483648") - + result, out, err = self.runcmd("domain", "auth", "policy", "create", + "--name", "userTGTLifetimeUpper", + "--user-tgt-lifetime", "2147483648") + self.assertEqual(result, -1) self.assertIn("--user-tgt-lifetime must be between 45 and 2147483647", - str(e.exception)) + err) def test_authentication_policy_create_service_tgt_lifetime(self): """Test create a new authentication policy with --service-tgt-lifetime. @@ -177,22 +174,20 @@ class AuthPolicyCmdTestCase(BaseAuthCmdTest): self.assertEqual(str(policy["msDS-ServiceTGTLifetime"]), "60") # check lower bounds (45) - with self.assertRaises(CommandError) as e: - self.runcmd("domain", "auth", "policy", "create", - "--name", "serviceTGTLifetimeLower", - "--service-tgt-lifetime", "44") - + result, out, err = self.runcmd("domain", "auth", "policy", "create", + "--name", "serviceTGTLifetimeLower", + "--service-tgt-lifetime", "44") + self.assertEqual(result, -1) self.assertIn("--service-tgt-lifetime must be between 45 and 2147483647", - str(e.exception)) + err) # check upper bounds (2147483647) - with self.assertRaises(CommandError) as e: - self.runcmd("domain", "auth", "policy", "create", - "--name", "serviceTGTLifetimeUpper", - "--service-tgt-lifetime", "2147483648") - + result, out, err = self.runcmd("domain", "auth", "policy", "create", + "--name", "serviceTGTLifetimeUpper", + "--service-tgt-lifetime", "2147483648") + self.assertEqual(result, -1) self.assertIn("--service-tgt-lifetime must be between 45 and 2147483647", - str(e.exception)) + err) def test_authentication_policy_create_computer_tgt_lifetime(self): """Test create a new authentication policy with --computer-tgt-lifetime. @@ -213,22 +208,20 @@ class AuthPolicyCmdTestCase(BaseAuthCmdTest): self.assertEqual(str(policy["msDS-ComputerTGTLifetime"]), "60") # check lower bounds (45) - with self.assertRaises(CommandError) as e: - self.runcmd("domain", "auth", "policy", "create", - "--name", "computerTGTLifetimeLower", - "--computer-tgt-lifetime", "44") - + result, out, err = self.runcmd("domain", "auth", "policy", "create", + "--name", "computerTGTLifetimeLower", + "--computer-tgt-lifetime", "44") + self.assertEqual(result, -1) self.assertIn("--computer-tgt-lifetime must be between 45 and 2147483647", - str(e.exception)) + err) # check upper bounds (2147483647) - with self.assertRaises(CommandError) as e: - self.runcmd("domain", "auth", "policy", "create", - "--name", "computerTGTLifetimeUpper", - "--computer-tgt-lifetime", "2147483648") - + result, out, err = self.runcmd("domain", "auth", "policy", "create", + "--name", "computerTGTLifetimeUpper", + "--computer-tgt-lifetime", "2147483648") + self.assertEqual(result, -1) self.assertIn("--computer-tgt-lifetime must be between 45 and 2147483647", - str(e.exception)) + err) def test_authentication_policy_create_valid_sddl(self): """Test creating a new authentication policy with valid SDDL in a field.""" @@ -387,22 +380,20 @@ class AuthPolicyCmdTestCase(BaseAuthCmdTest): self.assertEqual(str(policy["msDS-UserTGTLifetime"]), "120") # check lower bounds (45) - with self.assertRaises(CommandError) as e: - self.runcmd("domain", "auth", "policy", "modify", - "--name", name, - "--user-tgt-lifetime", "44") - + result, out, err = self.runcmd("domain", "auth", "policy", "modify", + "--name", name, + "--user-tgt-lifetime", "44") + self.assertEqual(result, -1) self.assertIn("--user-tgt-lifetime must be between 45 and 2147483647", - str(e.exception)) + err) # check upper bounds (2147483647) - with self.assertRaises(CommandError) as e: - self.runcmd("domain", "auth", "policy", "modify", - "--name", name, - "--user-tgt-lifetime", "2147483648") - + result, out, err = self.runcmd("domain", "auth", "policy", "modify", + "--name", name, + "--user-tgt-lifetime", "2147483648") + self.assertEqual(result, -1) self.assertIn("-user-tgt-lifetime must be between 45 and 2147483647", - str(e.exception)) + err) def test_authentication_policy_modify_service_tgt_lifetime(self): """Test modifying an authentication policy --service-tgt-lifetime. @@ -425,22 +416,20 @@ class AuthPolicyCmdTestCase(BaseAuthCmdTest): self.assertEqual(str(policy["msDS-ServiceTGTLifetime"]), "120") # check lower bounds (45) - with self.assertRaises(CommandError) as e: - self.runcmd("domain", "auth", "policy", "modify", - "--name", name, - "--service-tgt-lifetime", "44") - + result, out, err = self.runcmd("domain", "auth", "policy", "modify", + "--name", name, + "--service-tgt-lifetime", "44") + self.assertEqual(result, -1) self.assertIn("--service-tgt-lifetime must be between 45 and 2147483647", - str(e.exception)) + err) # check upper bounds (2147483647) - with self.assertRaises(CommandError) as e: - self.runcmd("domain", "auth", "policy", "modify", - "--name", name, - "--service-tgt-lifetime", "2147483648") - + result, out, err = self.runcmd("domain", "auth", "policy", "modify", + "--name", name, + "--service-tgt-lifetime", "2147483648") + self.assertEqual(result, -1) self.assertIn("--service-tgt-lifetime must be between 45 and 2147483647", - str(e.exception)) + err) def test_authentication_policy_modify_computer_tgt_lifetime(self): """Test modifying an authentication policy --computer-tgt-lifetime. @@ -463,22 +452,20 @@ class AuthPolicyCmdTestCase(BaseAuthCmdTest): self.assertEqual(str(policy["msDS-ComputerTGTLifetime"]), "120") # check lower bounds (45) - with self.assertRaises(CommandError) as e: - self.runcmd("domain", "auth", "policy", "modify", - "--name", name, - "--computer-tgt-lifetime", "44") - + result, out, err = self.runcmd("domain", "auth", "policy", "modify", + "--name", name, + "--computer-tgt-lifetime", "44") + self.assertEqual(result, -1) self.assertIn("--computer-tgt-lifetime must be between 45 and 2147483647", - str(e.exception)) + err) # check upper bounds (2147483647) - with self.assertRaises(CommandError) as e: - self.runcmd("domain", "auth", "policy", "modify", - "--name", name, - "--computer-tgt-lifetime", "2147483648") - + result, out, err = self.runcmd("domain", "auth", "policy", "modify", + "--name", name, + "--computer-tgt-lifetime", "2147483648") + self.assertEqual(result, -1) self.assertIn("--computer-tgt-lifetime must be between 45 and 2147483647", - str(e.exception)) + err) def test_authentication_policy_modify_name_missing(self): """Test modify authentication but the --name argument is missing."""