]> git.ipfire.org Git - thirdparty/samba.git/commitdiff
python:lsa_utils: Fix fallback to OpenPolicy2
authorStefan Metzmacher <metze@samba.org>
Wed, 17 Jul 2024 16:12:31 +0000 (18:12 +0200)
committerJule Anger <janger@samba.org>
Thu, 20 Feb 2025 11:22:18 +0000 (11:22 +0000)
BUG: https://bugzilla.samba.org/show_bug.cgi?id=15680

Pair-Programmed-With: Andreas Schneider <asn@samba.org>
Signed-off-by: Andreas Schneider <asn@samba.org>
Signed-off-by: Stefan Metzmacher <metze@samba.org>
Autobuild-User(master): Andreas Schneider <asn@cryptomilk.org>
Autobuild-Date(master): Mon Feb 17 18:33:15 UTC 2025 on atb-devel-224

(cherry picked from commit a814f5d90a3fb85a94c9516dba224037e8fd76f1)

Autobuild-User(v4-22-test): Jule Anger <janger@samba.org>
Autobuild-Date(v4-22-test): Thu Feb 20 11:22:18 UTC 2025 on atb-devel-224

python/samba/lsa_utils.py
python/samba/netcmd/domain/trust.py
python/samba/tests/dcerpc/lsa_utils.py
python/samba/tests/krb5/kdc_base_test.py

index 571beb46c851dd97a4f1fd3953799018942501bd..506dc399c93aeea09ff6cf27a0d019d69638a040 100644 (file)
@@ -20,24 +20,27 @@ from samba.dcerpc import lsa, drsblobs, misc
 from samba.ndr import ndr_pack
 from samba import (
     NTSTATUSError,
+    ntstatus,
     aead_aes_256_cbc_hmac_sha512,
     arcfour_encrypt,
 )
-from samba.ntstatus import (
-    NT_STATUS_RPC_PROCNUM_OUT_OF_RANGE
-)
 from samba import crypto
 from secrets import token_bytes
+# FIXME from collections.abc import Callable
 
 
 def OpenPolicyFallback(
-    conn: lsa.lsarpc,
+    # new_lsa_conn: Callable[[], lsa.lsarpc], - FIXME the type doesn't work
+    # with python version 3.6 (CentOS8, SLES15).
+    new_lsa_conn,
     system_name: str,
     in_version: int,
     in_revision_info: lsa.revision_info1,
     sec_qos: bool,
     access_mask: int,
 ):
+    conn = new_lsa_conn()
+
     attr = lsa.ObjectAttribute()
     if sec_qos:
         qos = lsa.QosInfo()
@@ -48,26 +51,38 @@ def OpenPolicyFallback(
 
         attr.sec_qos = qos
 
-    try:
-        out_version, out_rev_info, policy = conn.OpenPolicy3(
-            system_name,
-            attr,
-            access_mask,
-            in_version,
-            in_revision_info
-        )
-    except NTSTATUSError as e:
-        if e.args[0] == NT_STATUS_RPC_PROCNUM_OUT_OF_RANGE:
-            out_version = 1
-            out_rev_info = lsa.revision_info1()
-            out_rev_info.revision = 1
-            out_rev_info.supported_features = 0
-
-            policy = conn.OpenPolicy2(system_name, attr, access_mask)
-        else:
-            raise
-
-    return out_version, out_rev_info, policy
+    open_policy2 = False
+    if in_revision_info is not None:
+        try:
+            out_version, out_rev_info, policy = conn.OpenPolicy3(
+                system_name,
+                attr,
+                access_mask,
+                in_version,
+                in_revision_info
+            )
+        except NTSTATUSError as e:
+            if e.args[0] == ntstatus.NT_STATUS_RPC_PROCNUM_OUT_OF_RANGE:
+                open_policy2 = True
+            if e.args[0] == ntstatus.NT_STATUS_ACCESS_DENIED:
+                # We need a new connection
+                conn = new_lsa_conn(basis_connection=conn)
+
+                open_policy2 = True
+            else:
+                raise
+    else:
+        open_policy2 = True
+
+    if open_policy2:
+        out_version = 1
+        out_rev_info = lsa.revision_info1()
+        out_rev_info.revision = 1
+        out_rev_info.supported_features = 0
+
+        policy = conn.OpenPolicy2(system_name, attr, access_mask)
+
+    return conn, out_version, out_rev_info, policy
 
 
 def CreateTrustedDomainRelax(
index f39d4814a111e4d868f63fded6a8628c1bc69dab..f3d75f841377f0f28eb7a216208a1caf56aced6c 100644 (file)
@@ -125,8 +125,13 @@ class DomainTrustCommand(Command):
         self.local_creds = local_creds
         return self.local_server
 
-    def new_local_lsa_connection(self):
-        return lsa.lsarpc(self.local_binding_string, self.local_lp, self.local_creds)
+    def new_local_lsa_connection(self, basis_connection=None):
+        return lsa.lsarpc(
+            self.local_binding_string,
+            self.local_lp,
+            self.local_creds,
+            basis_connection=basis_connection
+        )
 
     def new_local_netlogon_connection(self):
         return netlogon.netlogon(self.local_binding_string, self.local_lp, self.local_creds)
@@ -203,13 +208,23 @@ class DomainTrustCommand(Command):
         self.remote_creds = remote_creds
         return self.remote_server
 
-    def new_remote_lsa_connection(self):
-        return lsa.lsarpc(self.remote_binding_string, self.local_lp, self.remote_creds)
+    def new_remote_lsa_connection(self, basis_connection=None):
+        return lsa.lsarpc(
+            self.remote_binding_string,
+            self.local_lp,
+            self.remote_creds,
+            basis_connection=basis_connection
+        )
 
-    def new_remote_netlogon_connection(self):
-        return netlogon.netlogon(self.remote_binding_string, self.local_lp, self.remote_creds)
+    def new_remote_netlogon_connection(self, basis_connection=None):
+        return netlogon.netlogon(
+            self.remote_binding_string,
+            self.local_lp,
+            self.remote_creds,
+            basis_connection=basis_connection
+        )
 
-    def get_lsa_info(self, conn, policy_access):
+    def get_lsa_info(self, conn_fn, policy_access):
         in_version = 1
         in_revision_info1 = lsa.revision_info1()
         in_revision_info1.revision = 1
@@ -217,9 +232,9 @@ class DomainTrustCommand(Command):
             lsa.LSA_FEATURE_TDO_AUTH_INFO_AES_CIPHER
         )
 
-        out_version, out_revision_info1, policy = OpenPolicyFallback(
-            conn,
-            b''.decode('utf-8'),
+        conn, out_version, out_revision_info1, policy = OpenPolicyFallback(
+            conn_fn,
+            '',
             in_version,
             in_revision_info1,
             False,
@@ -228,7 +243,7 @@ class DomainTrustCommand(Command):
 
         info = conn.QueryInfoPolicy2(policy, lsa.LSA_POLICY_INFO_DNS)
 
-        return (policy, out_version, out_revision_info1, info)
+        return (conn, policy, out_version, out_revision_info1, info)
 
     def get_netlogon_dc_unc(self, conn, server, domain):
         try:
@@ -508,19 +523,15 @@ class cmd_domain_trust_show(DomainTrustCommand):
     def run(self, domain, sambaopts=None, versionopts=None, localdcopts=None):
 
         self.setup_local_server(sambaopts, localdcopts)
-        try:
-            local_lsa = self.new_local_lsa_connection()
-        except RuntimeError as error:
-            raise self.LocalRuntimeError(self, error, "failed to connect lsa server")
-
         try:
             local_policy_access = lsa.LSA_POLICY_VIEW_LOCAL_INFORMATION
             (
+                local_lsa,
                 local_policy,
                 local_version,
                 local_revision_info1,
                 local_lsa_info
-            ) = self.get_lsa_info(local_lsa, local_policy_access)
+            ) = self.get_lsa_info(self.new_local_lsa_connection, local_policy_access)
         except RuntimeError as error:
             raise self.LocalRuntimeError(self, error, "failed to query LSA_POLICY_INFO_DNS")
 
@@ -649,19 +660,16 @@ class cmd_domain_trust_modify(DomainTrustCommand):
             raise CommandError("modification arguments are required, try --help")
 
         self.setup_local_server(sambaopts, localdcopts)
-        try:
-            local_lsa = self.new_local_lsa_connection()
-        except RuntimeError as error:
-            raise self.LocalRuntimeError(self, error, "failed to connect to lsa server")
 
         try:
             local_policy_access = lsa.LSA_POLICY_VIEW_LOCAL_INFORMATION
             (
+                local_lsa,
                 local_policy,
                 local_version,
                 local_revision_info1,
                 local_lsa_info
-            ) = self.get_lsa_info(local_lsa, local_policy_access)
+            ) = self.get_lsa_info(self.new_local_lsa_connection, local_policy_access)
         except RuntimeError as error:
             raise self.LocalRuntimeError(self, error, "failed to query LSA_POLICY_INFO_DNS")
 
@@ -908,18 +916,15 @@ class cmd_domain_trust_create(DomainTrustCommand):
                 remote_trust_info.trust_attributes |= lsa.LSA_TRUST_ATTRIBUTE_TREAT_AS_EXTERNAL
 
         local_server = self.setup_local_server(sambaopts, localdcopts)
-        try:
-            local_lsa = self.new_local_lsa_connection()
-        except RuntimeError as error:
-            raise self.LocalRuntimeError(self, error, "failed to connect lsa server")
 
         try:
             (
+                local_lsa,
                 local_policy,
                 local_version,
                 local_revision_info1,
                 local_lsa_info
-            ) = self.get_lsa_info(local_lsa, local_policy_access)
+            ) = self.get_lsa_info(self.new_local_lsa_connection, local_policy_access)
         except RuntimeError as error:
             raise self.LocalRuntimeError(self, error, "failed to query LSA_POLICY_INFO_DNS")
 
@@ -933,18 +938,14 @@ class cmd_domain_trust_create(DomainTrustCommand):
         except RuntimeError as error:
             raise self.RemoteRuntimeError(self, error, "failed to locate remote server")
 
-        try:
-            remote_lsa = self.new_remote_lsa_connection()
-        except RuntimeError as error:
-            raise self.RemoteRuntimeError(self, error, "failed to connect lsa server")
-
         try:
             (
+                remote_lsa,
                 remote_policy,
                 remote_version,
                 remote_revision_info1,
                 remote_lsa_info
-            ) = self.get_lsa_info(remote_lsa, remote_policy_access)
+            ) = self.get_lsa_info(self.new_remote_lsa_connection, remote_policy_access)
         except RuntimeError as error:
             raise self.RemoteRuntimeError(self, error, "failed to query LSA_POLICY_INFO_DNS")
 
@@ -1297,18 +1298,15 @@ class cmd_domain_trust_delete(DomainTrustCommand):
             remote_policy_access |= lsa.LSA_POLICY_CREATE_SECRET
 
         self.setup_local_server(sambaopts, localdcopts)
-        try:
-            local_lsa = self.new_local_lsa_connection()
-        except RuntimeError as error:
-            raise self.LocalRuntimeError(self, error, "failed to connect lsa server")
 
         try:
             (
+                local_lsa,
                 local_policy,
                 local_version,
                 local_revision_info1,
                 local_lsa_info
-            ) = self.get_lsa_info(local_lsa, local_policy_access)
+            ) = self.get_lsa_info(self.new_local_lsa_connection, local_policy_access)
         except RuntimeError as error:
             raise self.LocalRuntimeError(self, error, "failed to query LSA_POLICY_INFO_DNS")
 
@@ -1338,18 +1336,14 @@ class cmd_domain_trust_delete(DomainTrustCommand):
             except RuntimeError as error:
                 raise self.RemoteRuntimeError(self, error, "failed to locate remote server")
 
-            try:
-                remote_lsa = self.new_remote_lsa_connection()
-            except RuntimeError as error:
-                raise self.RemoteRuntimeError(self, error, "failed to connect lsa server")
-
             try:
                 (
+                    remote_lsa,
                     remote_policy,
                     remote_version,
                     remote_revision_info1,
                     remote_lsa_info
-                ) = self.get_lsa_info(remote_lsa, remote_policy_access)
+                ) = self.get_lsa_info(self.new_remote_lsa_connection, remote_policy_access)
             except RuntimeError as error:
                 raise self.RemoteRuntimeError(self, error, "failed to query LSA_POLICY_INFO_DNS")
 
@@ -1450,18 +1444,15 @@ class cmd_domain_trust_validate(DomainTrustCommand):
         local_policy_access = lsa.LSA_POLICY_VIEW_LOCAL_INFORMATION
 
         local_server = self.setup_local_server(sambaopts, localdcopts)
-        try:
-            local_lsa = self.new_local_lsa_connection()
-        except RuntimeError as error:
-            raise self.LocalRuntimeError(self, error, "failed to connect lsa server")
 
         try:
             (
+                local_lsa,
                 local_policy,
                 local_version,
                 local_revision_info1,
                 local_lsa_info
-            ) = self.get_lsa_info(local_lsa, local_policy_access)
+            ) = self.get_lsa_info(self.new_local_lsa_connection, local_policy_access)
         except RuntimeError as error:
             raise self.LocalRuntimeError(self, error, "failed to query LSA_POLICY_INFO_DNS")
 
@@ -1897,11 +1888,12 @@ class cmd_domain_trust_namespaces(DomainTrustCommand):
 
         try:
             (
+                local_lsa,
                 local_policy,
                 local_version,
                 local_revision_info1,
                 local_lsa_info
-            ) = self.get_lsa_info(local_lsa, local_policy_access)
+            ) = self.get_lsa_info(self.new_local_lsa_connection, local_policy_access)
         except RuntimeError as error:
             raise self.LocalRuntimeError(self, error, "failed to query LSA_POLICY_INFO_DNS")
 
index fee9a45419bda9ac9f2ce5e11d1b33de2d76b929..8a3e7d242767e0e8d44ff2392a0f4813722da8c2 100644 (file)
@@ -35,6 +35,7 @@ from samba.lsa_utils import (
 
 
 class CreateTrustedDomain(TestCase):
+    smbencrypt = True
 
     def get_user_creds(self):
         c = Credentials()
@@ -47,26 +48,35 @@ class CreateTrustedDomain(TestCase):
         c.set_password(password)
         return c
 
-    def _create_trust_relax(self, smbencrypt=True):
+    def new_lsa_conn(self, basis_connection=None):
         creds = self.get_user_creds()
-
-        if smbencrypt:
+        if self.smbencrypt:
             creds.set_smb_encryption(SMB_ENCRYPTION_REQUIRED)
         else:
             creds.set_smb_encryption(SMB_ENCRYPTION_OFF)
 
         lp = self.get_loadparm()
-
         binding_string = (
             "ncacn_np:%s" % (samba.tests.env_get_var_value('SERVER'))
         )
-        lsa_conn = lsa.lsarpc(binding_string, lp, creds)
 
-        if smbencrypt:
+        lsa_conn = lsa.lsarpc(
+            binding_string,
+            lp,
+            creds,
+            basis_connection=basis_connection
+        )
+
+        if self.smbencrypt:
             self.assertTrue(lsa_conn.transport_encrypted())
         else:
             self.assertFalse(lsa_conn.transport_encrypted())
 
+        return lsa_conn
+
+    def _create_trust_relax(self, smbencrypt=True):
+        self.smbencrypt = smbencrypt
+
         in_version = 1
         in_revision_info1 = lsa.revision_info1()
         in_revision_info1.revision = 1
@@ -74,8 +84,13 @@ class CreateTrustedDomain(TestCase):
             lsa.LSA_FEATURE_TDO_AUTH_INFO_AES_CIPHER
         )
 
-        out_version, out_revision_info1, pol_handle = OpenPolicyFallback(
+        (
             lsa_conn,
+            out_version,
+            out_revision_info1,
+            pol_handle
+        ) = OpenPolicyFallback(
+            self.new_lsa_conn,
             '',
             in_version,
             in_revision_info1,
@@ -148,14 +163,7 @@ class CreateTrustedDomain(TestCase):
             self.assertIsNone(trustdom_handle)
 
     def _create_trust_fallback(self):
-        creds = self.get_user_creds()
-
-        lp = self.get_loadparm()
-
-        binding_string = (
-            "ncacn_np:%s" % (samba.tests.env_get_var_value('SERVER'))
-        )
-        lsa_conn = lsa.lsarpc(binding_string, lp, creds)
+        self.smbencrypt = True
 
         in_version = 1
         in_revision_info1 = lsa.revision_info1()
@@ -164,8 +172,13 @@ class CreateTrustedDomain(TestCase):
             lsa.LSA_FEATURE_TDO_AUTH_INFO_AES_CIPHER
         )
 
-        out_version, out_revision_info1, pol_handle = OpenPolicyFallback(
+        (
             lsa_conn,
+            out_version,
+            out_revision_info1,
+            pol_handle
+        ) = OpenPolicyFallback(
+            self.new_lsa_conn,
             '',
             in_version,
             in_revision_info1,
index 1da770e4fe8c88c07d451197c6fa7ee0ee66a49b..dee6ef830718cbe9f0700495fd029499d9cb7ceb 100644 (file)
@@ -57,7 +57,6 @@ from samba.crypto import des_crypt_blob_16, md4_hash_blob
 from samba.lsa_utils import OpenPolicyFallback, CreateTrustedDomainFallback
 from samba.dcerpc import (
     claims,
-    dcerpc,
     drsblobs,
     drsuapi,
     krb5ccache,
@@ -441,7 +440,7 @@ class KDCBaseTest(TestCaseInTempDir, RawKerberosTest):
         return self._drsuapi_connection
 
     def get_lsarpc_connection(self):
-        def get_lsa_info(conn, policy_access):
+        def get_lsa_info(conn_fn, policy_access):
             in_version = 1
             in_revision_info1 = lsa.revision_info1()
             in_revision_info1.revision = 1
@@ -449,9 +448,9 @@ class KDCBaseTest(TestCaseInTempDir, RawKerberosTest):
                 lsa.LSA_FEATURE_TDO_AUTH_INFO_AES_CIPHER
             )
 
-            out_version, out_revision_info1, policy = OpenPolicyFallback(
-                conn,
-                b''.decode('utf-8'),
+            conn, out_version, out_revision_info1, policy = OpenPolicyFallback(
+                conn_fn,
+                '',
                 in_version,
                 in_revision_info1,
                 False,
@@ -460,7 +459,18 @@ class KDCBaseTest(TestCaseInTempDir, RawKerberosTest):
 
             info = conn.QueryInfoPolicy2(policy, lsa.LSA_POLICY_INFO_DNS)
 
-            return (policy, out_version, out_revision_info1, info)
+            return (conn, policy, out_version, out_revision_info1, info)
+
+        def new_lsa_conn(basis_connection=None):
+            lp = self.get_lp()
+            admin_creds = self.get_admin_creds()
+
+            return lsa.lsarpc(
+                self._binding_string,
+                lp,
+                admin_creds,
+                basis_connection=basis_connection
+            )
 
         def lsarpc_connect(server, lp, creds, ip=None):
             binding_options = ""
@@ -474,13 +484,14 @@ class KDCBaseTest(TestCaseInTempDir, RawKerberosTest):
             else:
                 binding_string = "ncacn_np:%s[%s]" % (server, binding_options)
 
+            self._binding_string = binding_string
+
             try:
-                conn = lsa.lsarpc(binding_string, lp, creds)
                 policy_access = lsa.LSA_POLICY_VIEW_LOCAL_INFORMATION
                 policy_access |= lsa.LSA_POLICY_TRUST_ADMIN
                 policy_access |= lsa.LSA_POLICY_CREATE_SECRET
-                (policy, out_version, out_revision_info1, info) = \
-                    get_lsa_info(conn, policy_access)
+                (conn, policy, out_version, out_revision_info1, info) = \
+                    get_lsa_info(new_lsa_conn, policy_access)
             except Exception as e:
                 raise RuntimeError("LSARPC connection to %s failed: %s" % (server, e))