]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
correct and simplify weighted_processing_order()
authorBob Halley <halley@dnspython.org>
Wed, 9 Sep 2020 12:35:36 +0000 (05:35 -0700)
committerBob Halley <halley@dnspython.org>
Wed, 9 Sep 2020 12:35:36 +0000 (05:35 -0700)
dns/rdtypes/ANY/URI.py
dns/rdtypes/IN/SRV.py
dns/rdtypes/util.py
tests/test_processing_order.py

index 4e7ad01c6f20939e01716372e106885935fa9e7f..ccbd2ce4bf8cbaedbf3dd011f1e2882100d3f6b3 100644 (file)
@@ -76,4 +76,4 @@ class URI(dns.rdata.Rdata):
 
     @classmethod
     def _processing_order(cls, iterable):
-        return dns.rdtypes.util.weighted_processing_order(iterable, False)
+        return dns.rdtypes.util.weighted_processing_order(iterable)
index 3eb227f98ec70198fdc7756573a5bdd5e27448ca..6d9b683af8ab09ad2c59eed0e8855771e057dec3 100644 (file)
@@ -72,4 +72,4 @@ class SRV(dns.rdata.Rdata):
 
     @classmethod
     def _processing_order(cls, iterable):
-        return dns.rdtypes.util.weighted_processing_order(iterable, True)
+        return dns.rdtypes.util.weighted_processing_order(iterable)
index 695754df924532f80dec8bd2e73ebf15daaf760f..7fc08cde8ff2ef7cff29da05189e92dca61491a2 100644 (file)
@@ -190,7 +190,6 @@ class Bitmap:
 def _priority_table(items):
     by_priority = collections.defaultdict(list)
     for rdata in items:
-        key = rdata._processing_priority()
         by_priority[rdata._processing_priority()].append(rdata)
     return by_priority
 
@@ -206,43 +205,27 @@ def priority_processing_order(iterable):
         ordered.extend(rdatas)
     return ordered
 
-def _processing_weight(rdata, adjust_zero_weight):
-    weight = rdata._processing_weight()
-    if weight == 0 and adjust_zero_weight:
-        return 0.1
-    else:
-        return weight
+_no_weight = 0.1
 
-def weighted_processing_order(iterable, adjust_zero_weight=False):
+def weighted_processing_order(iterable):
     items = list(iterable)
     if len(items) == 1:
         return items
     by_priority = _priority_table(items)
     ordered = []
     for k in sorted(by_priority.keys()):
-        weights_vary = False
-        weights = []
         rdatas = by_priority[k]
-        for rdata in rdatas:
-            weight = _processing_weight(rdata, adjust_zero_weight)
-            if len(weights) > 0 and weight != weights[-1]:
-                weights_vary = True
-            weights.append(weight)
-        if weights_vary:
-            while len(rdatas) > 1:
-                items = random.choices(rdatas, weights)
-                rdata = items[0]
-                ordered.append(rdata)
-                rdatas.remove(rdata)
-                weight = _processing_weight(rdata, adjust_zero_weight)
-                weights.remove(weight)
-            ordered.append(rdatas[0])
-        elif weights[0] == 0:
-            # All the weights are 0!  (This can't happen with SRV, but
-            # can with URI.  It's not clear from the URI RFC what you do here
-            # as it doesn't discuss weight.
-            return []
-        else:
-            random.shuffle(rdatas)
-            ordered.extend(rdatas)
+        total = sum(rdata._processing_weight() or _no_weight
+                    for rdata in rdatas)
+        while len(rdatas) > 1:
+            r = random.uniform(0, total)
+            for (n, rdata) in enumerate(rdatas):
+                weight = rdata._processing_weight() or _no_weight
+                if weight > r:
+                    break
+                r -= weight
+            total -= weight
+            ordered.append(rdata)
+            del rdatas[n]
+        ordered.append(rdatas[0])
     return ordered
index cc33fd4e6dc9956387e75082964a3aca1fc52e9e..2fa1b27124656e4c6d991dea997ad07f0db04161 100644 (file)
@@ -102,12 +102,3 @@ def test_processing_all_zero_weight_srv():
             assert rds[j] in po
         seen.add(tuple(po))
     assert len(seen) == 6
-
-
-def test_processing_all_zero_weight_uri():
-    rds = dns.rdataset.from_text('in', 'uri', 300,
-                                 '10 0 "ftp://ftp1.example.com/public"',
-                                 '10 0 "ftp://ftp2.example.com/public"',
-                                 '10 0 "ftp://ftp3.example.com/public"')
-    po = rds.processing_order()
-    assert len(po) == 0