git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
Fix negative bandwidth test and add online code path test. (gh-118600)
author Raymond Hettinger <rhettinger@users.noreply.github.com>
Sun, 5 May 2024 17:29:23 +0000 (12:29 -0500)
committer GitHub <noreply@github.com>
Sun, 5 May 2024 17:29:23 +0000 (12:29 -0500)
Lib/statistics.py
Lib/test/test_statistics.py

index f3ce2d8b6b442a5d2145882cefa6e9e43e6ff6d4..c2f4fe8e054d3ddf37c1a0fe50528f6352dc57e1 100644 (file)
@@ -1791,9 +1791,8 @@ def kde_random(data, h, kernel='normal', *, seed=None):
     if h <= 0.0:
         raise StatisticsError(f'Bandwidth h must be positive, not {h=!r}')
 
-    try:
-        kernel_invcdf = _kernel_invcdfs[kernel]
-    except KeyError:
+    kernel_invcdf = _kernel_invcdfs.get(kernel)
+    if kernel_invcdf is None:
         raise StatisticsError(f'Unknown kernel name: {kernel!r}')
 
     prng = _random.Random(seed)
index a60791e9b6e1f5bc4c55047d96584cf05c90b9b2..40680759d456ac7da3d569fe84827b6ce719214e 100644 (file)
@@ -2402,7 +2402,7 @@ class TestKDE(unittest.TestCase):
         with self.assertRaises(StatisticsError):
             kde(sample, h=0.0)                          # Zero bandwidth
         with self.assertRaises(StatisticsError):
-            kde(sample, h=0.0)                          # Negative bandwidth
+            kde(sample, h=-1.0)                         # Negative bandwidth
         with self.assertRaises(TypeError):
             kde(sample, h='str')                        # Wrong bandwidth type
         with self.assertRaises(StatisticsError):
@@ -2426,6 +2426,14 @@ class TestKDE(unittest.TestCase):
         self.assertEqual(f_hat(-1.0), 1/2)
         self.assertEqual(f_hat(1.0), 1/2)
 
+        # Test online updates to data
+
+        data = [1, 2]
+        f_hat = kde(data, 5.0, 'triangular')
+        self.assertEqual(f_hat(100), 0.0)
+        data.append(100)
+        self.assertGreater(f_hat(100), 0.0)
+
     def test_kde_kernel_invcdfs(self):
         kernel_invcdfs = statistics._kernel_invcdfs
         kde = statistics.kde
@@ -2462,7 +2470,7 @@ class TestKDE(unittest.TestCase):
         with self.assertRaises(TypeError):
             kde_random(iter(sample), 1.5)               # Data is not a sequence
         with self.assertRaises(StatisticsError):
-            kde_random(sample, h=0.0)                   # Zero bandwidth
+            kde_random(sample, h=-1.0)                  # Zero bandwidth
         with self.assertRaises(StatisticsError):
             kde_random(sample, h=0.0)                   # Negative bandwidth
         with self.assertRaises(TypeError):
@@ -2474,10 +2482,10 @@ class TestKDE(unittest.TestCase):
 
         h = 1.5
         kernel = 'cosine'
-        prng = kde_random(sample, h, kernel)
-        self.assertEqual(prng.__name__, 'rand')
-        self.assertIn(kernel, prng.__doc__)
-        self.assertIn(repr(h), prng.__doc__)
+        rand = kde_random(sample, h, kernel)
+        self.assertEqual(rand.__name__, 'rand')
+        self.assertIn(kernel, rand.__doc__)
+        self.assertIn(repr(h), rand.__doc__)
 
         # Approximate distribution test: Compare a random sample to the expected distribution
 
@@ -2507,6 +2515,14 @@ class TestKDE(unittest.TestCase):
                 for x in xarr:
                     self.assertTrue(math.isclose(p_observed(x), p_expected(x), abs_tol=0.0005))
 
+        # Test online updates to data
+
+        data = [1, 2]
+        rand = kde_random(data, 5, 'triangular')
+        self.assertLess(max([rand() for i in range(5000)]), 10)
+        data.append(100)
+        self.assertGreater(max(rand() for i in range(5000)), 10)
+
 
 class TestQuantiles(unittest.TestCase):