From: Miss Islington (bot) <31488909+miss-islington@users.noreply.github.com> Date: Mon, 18 Jan 2021 18:36:07 +0000 (-0800) Subject: bpo-42944 Fix Random.sample when counts is not None (GH-24235) (GH-24243) X-Git-Tag: v3.9.2rc1~49 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=a90539f5723a4c34430761be8cba97daa8474abf;p=thirdparty%2FPython%2Fcpython.git bpo-42944 Fix Random.sample when counts is not None (GH-24235) (GH-24243) --- diff --git a/Lib/random.py b/Lib/random.py index 190df6a8c3a5..36e16a9063b5 100644 --- a/Lib/random.py +++ b/Lib/random.py @@ -442,7 +442,7 @@ class Random(_random.Random): raise TypeError('Counts must be integers') if total <= 0: raise ValueError('Total of counts must be greater than zero') - selections = sample(range(total), k=k) + selections = self.sample(range(total), k=k) bisect = _bisect return [population[bisect(cum_counts, s)] for s in selections] randbelow = self._randbelow diff --git a/Lib/test/test_random.py b/Lib/test/test_random.py index a80e71e67e4c..15a68418bdd8 100644 --- a/Lib/test/test_random.py +++ b/Lib/test/test_random.py @@ -207,33 +207,6 @@ class TestBasicOps: with self.assertRaises(ValueError): sample(['red', 'green', 'blue'], counts=[1, 2, 3, 4], k=2) # too many counts - def test_sample_counts_equivalence(self): - # Test the documented strong equivalence to a sample with repeated elements. - # We run this test on random.Random() which makes deterministic selections - # for a given seed value. - sample = random.sample - seed = random.seed - - colors = ['red', 'green', 'blue', 'orange', 'black', 'amber'] - counts = [500, 200, 20, 10, 5, 1 ] - k = 700 - seed(8675309) - s1 = sample(colors, counts=counts, k=k) - seed(8675309) - expanded = [color for (color, count) in zip(colors, counts) for i in range(count)] - self.assertEqual(len(expanded), sum(counts)) - s2 = sample(expanded, k=k) - self.assertEqual(s1, s2) - - pop = 'abcdefghi' - counts = [10, 9, 8, 7, 6, 5, 4, 3, 2] - seed(8675309) - s1 = ''.join(sample(pop, counts=counts, k=30)) - expanded = ''.join([letter for (letter, count) in zip(pop, counts) for i in range(count)]) - seed(8675309) - s2 = ''.join(sample(expanded, k=30)) - self.assertEqual(s1, s2) - def test_choices(self): choices = self.gen.choices data = ['red', 'green', 'blue', 'yellow'] @@ -888,6 +861,33 @@ class MersenneTwister_TestBasicOps(TestBasicOps, unittest.TestCase): self.assertEqual(self.gen.randbytes(n), gen2.getrandbits(n * 8).to_bytes(n, 'little')) + def test_sample_counts_equivalence(self): + # Test the documented strong equivalence to a sample with repeated elements. + # We run this test on random.Random() which makes deterministic selections + # for a given seed value. + sample = self.gen.sample + seed = self.gen.seed + + colors = ['red', 'green', 'blue', 'orange', 'black', 'amber'] + counts = [500, 200, 20, 10, 5, 1 ] + k = 700 + seed(8675309) + s1 = sample(colors, counts=counts, k=k) + seed(8675309) + expanded = [color for (color, count) in zip(colors, counts) for i in range(count)] + self.assertEqual(len(expanded), sum(counts)) + s2 = sample(expanded, k=k) + self.assertEqual(s1, s2) + + pop = 'abcdefghi' + counts = [10, 9, 8, 7, 6, 5, 4, 3, 2] + seed(8675309) + s1 = ''.join(sample(pop, counts=counts, k=30)) + expanded = ''.join([letter for (letter, count) in zip(pop, counts) for i in range(count)]) + seed(8675309) + s2 = ''.join(sample(expanded, k=30)) + self.assertEqual(s1, s2) + def gamma(z, sqrt2pi=(2.0*pi)**0.5): # Reflection to right half of complex plane diff --git a/Misc/NEWS.d/next/Library/2021-01-18-10-41-44.bpo-42944.RrONvy.rst b/Misc/NEWS.d/next/Library/2021-01-18-10-41-44.bpo-42944.RrONvy.rst new file mode 100644 index 000000000000..b78d10aa2554 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2021-01-18-10-41-44.bpo-42944.RrONvy.rst @@ -0,0 +1 @@ +Fix ``random.Random.sample`` when ``counts`` argument is not ``None``.