]> git.ipfire.org Git - location/libloc.git/commitdiff
importer: Refactor feed parsing
authorMichael Tremer <michael.tremer@ipfire.org>
Sat, 2 Mar 2024 10:28:02 +0000 (10:28 +0000)
committerMichael Tremer <michael.tremer@ipfire.org>
Sat, 2 Mar 2024 10:28:02 +0000 (10:28 +0000)
This adds a bit of common code across all feeds, but is only implemented
for AWS at the moment.

Signed-off-by: Michael Tremer <michael.tremer@ipfire.org>
src/scripts/location-importer.in

index f4a93ee93aef14a45c613d7772d7f4e612987e6c..35f8a6461a3a41f71d80707e7aa10ade53782b2f 100644 (file)
@@ -1791,30 +1791,50 @@ class CLI(object):
                """
                        Update any third-party feeds
                """
-               # AWS
-               self._update_feed_for_aws()
+               success = True
+
+               # Create a downloader
+               downloader = location.importer.Downloader()
+
+               feeds = (
+                       # AWS IP Ranges
+                       ("AWS-IP-RANGES", self._import_aws_ip_ranges, "https://ip-ranges.amazonaws.com/ip-ranges.json"),
+               )
+
+               # Walk through all feeds
+               for name, callback, url, *args in feeds:
+                       try:
+                               self._process_feed(downloader, name, callback, url, *args)
+
+                       # Log an error but continue if an exception occurs
+                       except Exception as e:
+                               log.error("Error processing feed '%s': %s" % (name, e))
+                               success = False
 
                # Spamhaus
                self._update_feed_for_spamhaus_drop()
 
-       def _update_feed_for_aws(self):
-               # Download Amazon AWS IP allocation file to create overrides...
-               downloader = location.importer.Downloader()
+               # Return status
+               return 0 if success else 1
 
-               try:
-                       # Fetch IP ranges
-                       f = downloader.retrieve("https://ip-ranges.amazonaws.com/ip-ranges.json")
+       def _process_feed(self, downloader, name, callback, url, *args):
+               """
+                       Processes one feed
+               """
+               # Open the URL
+               f = downloader.retrieve(url)
 
-                       # Parse downloaded file
-                       aws_ip_dump = json.load(f)
-               except Exception as e:
-                       log.error("unable to preprocess Amazon AWS IP ranges: %s" % e)
-                       return
+               with self.db.transaction():
+                       # Drop any previous content
+                       self.db.execute("DELETE FROM autnum_feeds  WHERE source = %s", name)
+                       self.db.execute("DELETE FROM network_feeds WHERE source = %s", name)
 
-               # At this point, we can assume the downloaded file to be valid
-               self.db.execute("""
-                       DELETE FROM network_feeds WHERE source = 'Amazon AWS IP feed'
-               """)
+                       # Call the callback to process the feed
+                       return callback(name, f, *args)
+
+       def _import_aws_ip_ranges(self, name, f):
+               # Parse the feed
+               aws_ip_dump = json.load(f)
 
                # XXX: Set up a dictionary for mapping a region name to a country. Unfortunately,
                # there seems to be no machine-readable version available of this other than
@@ -1859,63 +1879,62 @@ class CLI(object):
                for row in rows:
                        validcountries.append(row.country_code)
 
-               with self.db.transaction():
-                       for snetwork in aws_ip_dump["prefixes"] + aws_ip_dump["ipv6_prefixes"]:
-                               try:
-                                       network = ipaddress.ip_network(snetwork.get("ip_prefix") or snetwork.get("ipv6_prefix"), strict=False)
-                               except ValueError:
-                                       log.warning("Unable to parse line: %s" % snetwork)
-                                       continue
+               for snetwork in aws_ip_dump["prefixes"] + aws_ip_dump["ipv6_prefixes"]:
+                       try:
+                               network = ipaddress.ip_network(snetwork.get("ip_prefix") or snetwork.get("ipv6_prefix"), strict=False)
+                       except ValueError:
+                               log.warning("Unable to parse line: %s" % snetwork)
+                               continue
 
-                               # Sanitize parsed networks...
-                               if not self._check_parsed_network(network):
-                                       continue
+                       # Sanitize parsed networks...
+                       if not self._check_parsed_network(network):
+                               continue
 
-                               # Determine region of this network...
-                               region = snetwork["region"]
-                               cc = None
-                               is_anycast = False
-
-                               # Any region name starting with "us-" will get "US" country code assigned straight away...
-                               if region.startswith("us-"):
-                                       cc = "US"
-                               elif region.startswith("cn-"):
-                                       # ... same goes for China ...
-                                       cc = "CN"
-                               elif region == "GLOBAL":
-                                       # ... funny region name for anycast-like networks ...
-                                       is_anycast = True
-                               elif region in aws_region_country_map:
-                                       # ... assign looked up country code otherwise ...
-                                       cc = aws_region_country_map[region]
-                               else:
-                                       # ... and bail out if we are missing something here
-                                       log.warning("Unable to determine country code for line: %s" % snetwork)
-                                       continue
+                       # Determine region of this network...
+                       region = snetwork["region"]
+                       cc = None
+                       is_anycast = False
+
+                       # Any region name starting with "us-" will get "US" country code assigned straight away...
+                       if region.startswith("us-"):
+                               cc = "US"
+                       elif region.startswith("cn-"):
+                               # ... same goes for China ...
+                               cc = "CN"
+                       elif region == "GLOBAL":
+                               # ... funny region name for anycast-like networks ...
+                               is_anycast = True
+                       elif region in aws_region_country_map:
+                               # ... assign looked up country code otherwise ...
+                               cc = aws_region_country_map[region]
+                       else:
+                               # ... and bail out if we are missing something here
+                               log.warning("Unable to determine country code for line: %s" % snetwork)
+                               continue
 
-                               # Skip networks with unknown country codes
-                               if not is_anycast and validcountries and cc not in validcountries:
-                                       log.warning("Skipping Amazon AWS network with bogus country '%s': %s" % \
-                                               (cc, network))
-                                       return
+                       # Skip networks with unknown country codes
+                       if not is_anycast and validcountries and cc not in validcountries:
+                               log.warning("Skipping Amazon AWS network with bogus country '%s': %s" % \
+                                       (cc, network))
+                               return
 
-                               # Conduct SQL statement...
-                               self.db.execute("""
-                                       INSERT INTO
-                                               network_feeds
-                                       (
-                                               network,
-                                               source,
-                                               country,
-                                               is_anycast
-                                       )
-                                       VALUES
-                                       (
-                                               %s, %s, %s, %s
-                                       )
-                                       ON CONFLICT (network, source) DO NOTHING
-                                       """, "%s" % network, "Amazon AWS IP feed", cc, is_anycast,
+                       # Conduct SQL statement...
+                       self.db.execute("""
+                               INSERT INTO
+                                       network_feeds
+                               (
+                                       network,
+                                       source,
+                                       country,
+                                       is_anycast
                                )
+                               VALUES
+                               (
+                                       %s, %s, %s, %s
+                               )
+                               ON CONFLICT (network, source) DO NOTHING
+                               """, "%s" % network, "Amazon AWS IP feed", cc, is_anycast,
+                       )
 
        def _update_feed_for_spamhaus_drop(self):
                downloader = location.importer.Downloader()