]> git.ipfire.org Git - people/ms/libloc.git/commitdiff
importer: Simplify fetching countries
authorMichael Tremer <michael.tremer@ipfire.org>
Sat, 2 Mar 2024 10:36:00 +0000 (10:36 +0000)
committerMichael Tremer <michael.tremer@ipfire.org>
Sat, 2 Mar 2024 10:36:00 +0000 (10:36 +0000)
Signed-off-by: Michael Tremer <michael.tremer@ipfire.org>
src/scripts/location-importer.in

index 35f8a6461a3a41f71d80707e7aa10ade53782b2f..5c81641d6c3fb3e37f1b3bf07b259401219c5313 100644 (file)
@@ -284,6 +284,15 @@ class CLI(object):
 
                return db
 
+       def fetch_countries(self):
+               """
+                       Returns a list of all countries on the list
+               """
+               # Fetch all valid country codes to check parsed networks aganist...
+               countries = self.db.query("SELECT country_code FROM countries ORDER BY country_code")
+
+               return [country.country_code for country in countries]
+
        def handle_write(self, ns):
                """
                        Compiles a database in libloc format out of what is in the database
@@ -668,7 +677,7 @@ class CLI(object):
                error = False
 
                # Fetch all valid country codes to check parsed networks against
-               validcountries = self.countries
+               validcountries = self.fetch_countries()
 
                # Iterate over all potential sources
                for source in sorted(location.importer.SOURCES):
@@ -1872,12 +1881,8 @@ class CLI(object):
                                "sa-east-1": "BR"
                                }
 
-               # Fetch all valid country codes to check parsed networks aganist...
-               rows = self.db.query("SELECT * FROM countries ORDER BY country_code")
-               validcountries = []
-
-               for row in rows:
-                       validcountries.append(row.country_code)
+               # Fetch all countries that we know of
+               countries = self.fetch_countries()
 
                for snetwork in aws_ip_dump["prefixes"] + aws_ip_dump["ipv6_prefixes"]:
                        try:
@@ -1913,7 +1918,7 @@ class CLI(object):
                                continue
 
                        # Skip networks with unknown country codes
-                       if not is_anycast and validcountries and cc not in validcountries:
+                       if not is_anycast and countries and cc not in countries:
                                log.warning("Skipping Amazon AWS network with bogus country '%s': %s" % \
                                        (cc, network))
                                return
@@ -2091,14 +2096,6 @@ class CLI(object):
                # Default to None
                return None
 
-       @property
-       def countries(self):
-               # Fetch all valid country codes to check parsed networks aganist
-               rows = self.db.query("SELECT * FROM countries ORDER BY country_code")
-
-               # Return all countries
-               return [row.country_code for row in rows]
-
        def handle_import_countries(self, ns):
                with self.db.transaction():
                        # Drop all data that we have