]> git.ipfire.org Git - thirdparty/samba.git/commitdiff
drs_utils: Split process_chunk() out into its own class
authorJennifer Sutton <jennifersutton@catalyst.net.nz>
Tue, 23 Jul 2024 05:24:28 +0000 (17:24 +1200)
committerJo Sutton <jsutton@samba.org>
Mon, 26 May 2025 02:41:36 +0000 (02:41 +0000)
This makes it easier to add classes with new functionality without
having to figure out how to slot them into a linear class hierarchy.

BUG: https://bugzilla.samba.org/show_bug.cgi?id=15852

Signed-off-by: Jennifer Sutton <jennifersutton@catalyst.net.nz>
Reviewed-by: Douglas Bagnall <douglas.bagnall@catalyst.net.nz>
python/samba/drs_utils.py
python/samba/join.py

index beef4f8ada618f9a96630d9cd7464276c4f78479..ab65767d1bada54f21ffb723b48462ebe12eb31f 100644 (file)
@@ -29,6 +29,7 @@ from samba.dcerpc.drsuapi import (DRSUAPI_ATTID_name,
                                   DRSUAPI_SUPPORTED_EXTENSION_GETCHGREQ_V8,
                                   DRSUAPI_SUPPORTED_EXTENSION_GETCHGREQ_V10)
 import re
+from abc import ABCMeta, abstractmethod
 
 
 class drsException(Exception):
@@ -186,19 +187,49 @@ def drs_copy_highwater_mark(hwm, new_hwm):
     hwm.highest_usn = new_hwm.highest_usn
 
 
-class drs_Replicate(object):
-    """DRS replication calls"""
+class drs_ReplicatorImplBase(metaclass=ABCMeta):
+    @abstractmethod
+    def process_chunk(
+        self, samdb, level, ctr, schema, req_level, req, first_chunk,
+    ) -> None: ...
+
+    @abstractmethod
+    def supports_ext(self, ext) -> bool: ...
+
+    @abstractmethod
+    def get_nc_changes(self, req_level, req) -> bool: ...
+
 
+class drs_ReplicatorImpl(drs_ReplicatorImplBase):
     def __init__(self, binding_string, lp, creds, samdb, invocation_id):
         self.drs = drsuapi.drsuapi(binding_string, lp, creds)
-        (self.drs_handle, self.supports_ext) = drs_DsBind(self.drs)
+        (self.drs_handle, self._supports_ext) = drs_DsBind(self.drs)
         self.net = Net(creds=creds, lp=lp)
-        self.samdb = samdb
         if not isinstance(invocation_id, misc.GUID):
             raise RuntimeError("Must supply GUID for invocation_id")
         if invocation_id == misc.GUID("00000000-0000-0000-0000-000000000000"):
             raise RuntimeError("Must not set GUID 00000000-0000-0000-0000-000000000000 as invocation_id")
-        self.replication_state = self.net.replicate_init(self.samdb, lp, self.drs, invocation_id)
+        self.replication_state = self.net.replicate_init(samdb, lp, self.drs, invocation_id)
+
+    def process_chunk(self, samdb, level, ctr, schema, req_level, req, first_chunk):
+        """Processes a single chunk of received replication data"""
+        # pass the replication into the py_net.c python bindings for processing
+        self.net.replicate_chunk(self.replication_state, level, ctr,
+                                 schema=schema, req_level=req_level, req=req)
+
+    def supports_ext(self, ext) -> bool:
+        return self._supports_ext & ext
+
+    def get_nc_changes(self, req_level, req) -> bool:
+        return self.drs.DsGetNCChanges(self.drs_handle, req_level, req)
+
+
+class drs_Replicator:
+    """DRS replication implementation"""
+
+    def __init__(self, repl, samdb):
+        self.samdb = samdb
+        self.repl = repl
         self.more_flags = 0
 
     @staticmethod
@@ -247,23 +278,17 @@ class drs_Replicate(object):
             object_to_check = object_to_check.next_object
 
 
-    def process_chunk(self, level, ctr, schema, req_level, req, first_chunk):
-        """Processes a single chunk of received replication data"""
-        # pass the replication into the py_net.c python bindings for processing
-        self.net.replicate_chunk(self.replication_state, level, ctr,
-                                 schema=schema, req_level=req_level, req=req)
-
     def replicate(self, dn, source_dsa_invocation_id, destination_dsa_guid,
                   schema=False, exop=drsuapi.DRSUAPI_EXOP_NONE, rodc=False,
                   replica_flags=None, full_sync=True, sync_forced=False):
         """replicate a single DN"""
 
         # setup for a GetNCChanges call
-        if self.supports_ext & DRSUAPI_SUPPORTED_EXTENSION_GETCHGREQ_V10:
+        if self.repl.supports_ext(DRSUAPI_SUPPORTED_EXTENSION_GETCHGREQ_V10):
             req_level = 10
             req = drsuapi.DsGetNCChangesRequest10()
             req.more_flags = self.more_flags
-        elif self.supports_ext & DRSUAPI_SUPPORTED_EXTENSION_GETCHGREQ_V8:
+        elif self.repl.supports_ext(DRSUAPI_SUPPORTED_EXTENSION_GETCHGREQ_V8):
             req_level = 8
             req = drsuapi.DsGetNCChangesRequest8()
         else:
@@ -342,12 +367,14 @@ class drs_Replicate(object):
         first_chunk = True
 
         while True:
-            (level, ctr) = self.drs.DsGetNCChanges(self.drs_handle, req_level, req)
+            (level, ctr) = self.repl.get_nc_changes(req_level, req)
             if ctr.first_object is None and ctr.object_count != 0:
                 raise RuntimeError("DsGetNCChanges: NULL first_object with object_count=%u" % (ctr.object_count))
 
             try:
-                self.process_chunk(level, ctr, schema, req_level, req, first_chunk)
+                self.repl.process_chunk(
+                    self.samdb, level, ctr, schema, req_level, req, first_chunk
+                )
             except WERRORError as e:
                 # Check if retrying with the GET_TGT flag set might resolve this error
                 if self._should_retry_with_get_tgt(e.args[0], req):
@@ -383,14 +410,20 @@ class drs_Replicate(object):
         return (num_objects, num_links)
 
 
+class drs_Replicate(drs_Replicator):
+    """DRS replication calls"""
+
+    def __init__(self, binding_string, lp, creds, samdb, invocation_id):
+        repl = drs_ReplicatorImpl(binding_string, lp, creds, samdb, invocation_id)
+        super().__init__(repl, samdb)
+
 # Handles the special case of creating a new clone of a DB, while also renaming
 # the entire DB's objects on the way through
-class drs_ReplicateRenamer(drs_Replicate):
+class drs_ReplicateRenamer(drs_ReplicatorImplBase):
     """Uses DRS replication to rename the entire DB"""
 
-    def __init__(self, binding_string, lp, creds, samdb, invocation_id,
-                 old_base_dn, new_base_dn):
-        super().__init__(binding_string, lp, creds, samdb, invocation_id)
+    def __init__(self, repl, old_base_dn, new_base_dn):
+        self.repl = repl
         self.old_base_dn = old_base_dn
         self.new_base_dn = new_base_dn
 
@@ -404,15 +437,16 @@ class drs_ReplicateRenamer(drs_Replicate):
         """Uses string substitution to replace the base DN"""
         return re.sub('%s$' % self.old_base_dn, self.new_base_dn, dn_str)
 
-    def update_name_attr(self, base_obj):
+    @staticmethod
+    def update_name_attr(base_obj, samdb):
         """Updates the 'name' attribute for the base DN object"""
         for attr in base_obj.attribute_ctr.attributes:
             if attr.attid == DRSUAPI_ATTID_name:
-                base_dn = ldb.Dn(self.samdb, base_obj.identifier.dn)
+                base_dn = ldb.Dn(samdb, base_obj.identifier.dn)
                 new_name = base_dn.get_rdn_value()
                 attr.value_ctr.values[0].blob = new_name.encode("utf-16-le")
 
-    def rename_top_level_object(self, first_obj):
+    def rename_top_level_object(self, first_obj, samdb):
         """Renames the first/top-level object in a partition"""
         old_dn = first_obj.identifier.dn
         first_obj.identifier.dn = self.rename_dn(first_obj.identifier.dn)
@@ -421,9 +455,9 @@ class drs_ReplicateRenamer(drs_Replicate):
         # we also need to fix up the 'name' attribute for the base DN,
         # otherwise the RDNs won't match
         if first_obj.identifier.dn == self.new_base_dn:
-            self.update_name_attr(first_obj)
+            self.update_name_attr(first_obj, samdb)
 
-    def process_chunk(self, level, ctr, schema, req_level, req, first_chunk):
+    def process_chunk(self, samdb, level, ctr, schema, req_level, req, first_chunk):
         """Processes a single chunk of received replication data"""
 
         # we need to rename the NC in every chunk - this gets used in searches
@@ -434,7 +468,13 @@ class drs_ReplicateRenamer(drs_Replicate):
         # rename the first object in each partition. This will cause every
         # subsequent object in the partition to be renamed as a side-effect
         if first_chunk and ctr.object_count != 0:
-            self.rename_top_level_object(ctr.first_object.object)
+            self.rename_top_level_object(ctr.first_object.object, samdb)
 
         # then do the normal repl processing to apply this chunk to our DB
-        super().process_chunk(level, ctr, schema, req_level, req, first_chunk)
+        self.repl.process_chunk(samdb, level, ctr, schema, req_level, req, first_chunk)
+
+    def supports_ext(self, ext) -> bool:
+        return self.repl.supports_ext(ext)
+
+    def get_nc_changes(self, req_level, req) -> bool:
+        return self.repl.get_nc_changes(req_level, req)
index 0b5a468e6e62010e48b5a10786a9aa5b3a6cac43..3ee3bdac690afbc2d6f93cfde7c7efb4d59e9b76 100644 (file)
@@ -954,9 +954,14 @@ class DCJoinContext(object):
 
     def create_replicator(ctx, repl_creds, binding_options):
         """Creates a new DRS object for managing replications"""
-        return drs_utils.drs_Replicate(
-                "ncacn_ip_tcp:%s[%s]" % (ctx.server, binding_options),
-                ctx.lp, repl_creds, ctx.local_samdb, ctx.invocation_id)
+        repl = drs_utils.drs_ReplicatorImpl(
+            f"ncacn_ip_tcp:{ctx.server}[{binding_options}]",
+            ctx.lp,
+            repl_creds,
+            ctx.local_samdb,
+            ctx.invocation_id,
+        )
+        return repl
 
     def join_replicate(ctx):
         """Replicate the SAM."""
@@ -989,6 +994,7 @@ class DCJoinContext(object):
                 binding_options += ",print"
 
             repl = ctx.create_replicator(repl_creds, binding_options)
+            repl = drs_utils.drs_Replicator(repl, ctx.local_samdb)
 
             repl.replicate(ctx.schema_dn, source_dsa_invocation_id,
                            destination_dsa_guid, schema=True, rodc=ctx.RODC,
@@ -1726,11 +1732,8 @@ class DCCloneAndRenameContext(DCCloneContext):
         # We want to rename all the domain objects, and the simplest way to do
         # this is during replication. This is because the base DN of the top-
         # level replicated object will flow through to all the objects below it
-        binding_str = "ncacn_ip_tcp:%s[%s]" % (ctx.server, binding_options)
-        return drs_utils.drs_ReplicateRenamer(binding_str, ctx.lp, repl_creds,
-                                              ctx.local_samdb,
-                                              ctx.invocation_id,
-                                              ctx.base_dn, ctx.new_base_dn)
+        repl = super().create_replicator(repl_creds, binding_options)
+        return drs_utils.drs_ReplicateRenamer(repl, ctx.base_dn, ctx.new_base_dn)
 
     def create_non_global_lp(ctx, global_lp):
         """Creates a non-global LoadParm based on the global LP's settings"""