From a7e9f8f1b31c8dad319937e439908f3ed66e7144 Mon Sep 17 00:00:00 2001 From: Michael Tremer Date: Wed, 21 Feb 2024 16:19:57 +0000 Subject: [PATCH] tests: Build out more dedup tests Signed-off-by: Michael Tremer --- tests/python/networks-dedup.py | 63 +++++++++++++++++++++++++++------- 1 file changed, 51 insertions(+), 12 deletions(-) diff --git a/tests/python/networks-dedup.py b/tests/python/networks-dedup.py index bcb8f9c..abceaae 100644 --- a/tests/python/networks-dedup.py +++ b/tests/python/networks-dedup.py @@ -23,21 +23,25 @@ import tempfile import unittest class Test(unittest.TestCase): - def test_dudup_simple(self): + def __test(self, inputs, outputs): """ - Creates a couple of redundant networks and expects fewer being written + Takes a list of networks that are written to the database and + compares the result with the second argument. """ with tempfile.NamedTemporaryFile() as f: w = location.Writer() - # Add 10.0.0.0/8 - n = w.add_network("10.0.0.0/8") + # Add all inputs + for network, cc, asn in inputs: + n = w.add_network(network) - # Add 10.0.0.0/16 - w.add_network("10.0.0.0/16") + # Add CC + if cc: + n.country_code = cc - # Add 10.0.0.0/24 - w.add_network("10.0.0.0/24") + # Add ASN + if asn: + n.asn = asn # Write file w.write(f.name) @@ -45,11 +49,46 @@ class Test(unittest.TestCase): # Re-open the database db = location.Database(f.name) - for i, network in enumerate(db.networks): - # The only network we should see is 10.0.0.0/8 - self.assertEqual(network, n) + # Check if the output matches what we expect + self.assertCountEqual( + outputs, ["%s" % network for network in db.networks], + ) + + def test_dudup_simple(self): + """ + Creates a couple of redundant networks and expects fewer being written + """ + self.__test( + ( + ("10.0.0.0/8", None, None), + ("10.0.0.0/16", None, None), + ("10.0.0.0/24", None, None), + ), + + # Everything should be put into the /8 subnet + ("10.0.0.0/8",), + ) + + def test_dedup_noop(self): + """ + Nothing should be changed here + """ + self.maxDiff = None + + networks = ( + ("10.0.0.0/8", None, None), + ("20.0.0.0/8", None, None), + ("30.0.0.0/8", None, None), + ("40.0.0.0/8", None, None), + ("50.0.0.0/8", None, None), + ("60.0.0.0/8", None, None), + ("70.0.0.0/8", None, None), + ("80.0.0.0/8", None, None), + ("90.0.0.0/8", None, None), + ) - self.assertTrue(i == 0) + # The input should match the output + self.__test(networks, [network for network, cc, asn in networks]) if __name__ == "__main__": -- 2.39.2