git.ipfire.org Git - people/ms/libloc.git/commitdiff
importer: Wrap everything into asyncio
author Michael Tremer <michael.tremer@ipfire.org>
Thu, 7 Mar 2024 12:18:14 +0000 (12:18 +0000)
committer Michael Tremer <michael.tremer@ipfire.org>
Thu, 7 Mar 2024 12:18:14 +0000 (12:18 +0000)
Signed-off-by: Michael Tremer <michael.tremer@ipfire.org>
src/scripts/location-importer.in

index cc2e5e2ffb1dfb98e4372f33db80812ca6112fc3..39bfcfc2210b5702fd49cee7b939e9c62fe9fcdb 100644 (file)
@@ -18,6 +18,7 @@
 ###############################################################################
 
 import argparse
+import asyncio
 import csv
 import functools
 import http.client
@@ -158,7 +159,7 @@ class CLI(object):
 
                return args
 
-       def run(self):
+       async def run(self):
                # Parse command line arguments
                args = self.parse_cli()
 
@@ -169,7 +170,7 @@ class CLI(object):
                self.db = self._setup_database(args)
 
                # Call function
-               ret = args.func(args)
+               ret = await args.func(args)
 
                # Return with exit code
                if ret:
@@ -320,7 +321,7 @@ class CLI(object):
 
                return set((country.country_code for country in countries))
 
-       def handle_write(self, ns):
+       async def handle_write(self, ns):
                """
                        Compiles a database in libloc format out of what is in the database
                """
@@ -697,7 +698,7 @@ class CLI(object):
                for file in ns.file:
                        writer.write(file)
 
-       def handle_update_whois(self, ns):
+       async def handle_update_whois(self, ns):
                # Did we run successfully?
                success = True
 
@@ -756,7 +757,7 @@ class CLI(object):
                                continue
 
                        try:
-                               self._process_source(name, feeds, countries)
+                               await self._process_source(name, feeds, countries)
 
                        # Log an error but continue if an exception occurs
                        except Exception as e:
@@ -766,7 +767,7 @@ class CLI(object):
                # Return a non-zero exit code for errors
                return 0 if success else 1
 
-       def _process_source(self, source, feeds, countries):
+       async def _process_source(self, source, feeds, countries):
                """
                        This function processes one source
                """
@@ -801,7 +802,7 @@ class CLI(object):
                                f = self.downloader.retrieve(url)
 
                                # Call the callback
-                               callback(source, countries, f, *args)
+                               await callback(source, countries, f, *args)
 
                        # Process all parsed networks from every RIR we happen to have access to,
                        # insert the largest network chunks into the networks table immediately...
@@ -948,7 +949,7 @@ class CLI(object):
                                """,
                        )
 
-       def _import_standard_format(self, source, countries, f, *args):
+       async def _import_standard_format(self, source, countries, f, *args):
                """
                        Imports a single standard format source feed
                """
@@ -956,12 +957,12 @@ class CLI(object):
                for block in iterate_over_blocks(f):
                        self._parse_block(block, source, countries)
 
-       def _import_extended_format(self, source, countries, f, *args):
+       async def _import_extended_format(self, source, countries, f, *args):
                # Iterate over all lines
                for line in iterate_over_lines(f):
                        self._parse_line(block, source, countries)
 
-       def _import_arin_as_names(self, source, countries, f, *args):
+       async def _import_arin_as_names(self, source, countries, f, *args):
                # Walk through the file
                for line in csv.DictReader(feed, dialect="arin"):
                        log.debug("Processing object: %s" % line)
@@ -1423,12 +1424,12 @@ class CLI(object):
                        """, "%s" % network, country_code, [country], source_key,
                )
 
-       def handle_update_announcements(self, ns):
+       async def handle_update_announcements(self, ns):
                server = ns.server[0]
 
                with self.db.transaction():
                        if server.startswith("/"):
-                               self._handle_update_announcements_from_bird(server)
+                               await self._handle_update_announcements_from_bird(server)
 
                        # Purge anything we never want here
                        self.db.execute("""
@@ -1487,7 +1488,7 @@ class CLI(object):
                                DELETE FROM announcements WHERE last_seen_at <= CURRENT_TIMESTAMP - INTERVAL '14 days';
                        """)
 
-       def _handle_update_announcements_from_bird(self, server):
+       async def _handle_update_announcements_from_bird(self, server):
                # Pre-compile the regular expression for faster searching
                route = re.compile(b"^\s(.+?)\s+.+?\[(?:AS(.*?))?.\]$")
 
@@ -1605,7 +1606,7 @@ class CLI(object):
                                # Otherwise return the line
                                yield line
 
-       def handle_update_geofeeds(self, ns):
+       async def handle_update_geofeeds(self, ns):
                # Sync geofeeds
                with self.db.transaction():
                        # Delete all geofeeds which are no longer linked
@@ -1673,10 +1674,14 @@ class CLI(object):
                                id
                """)
 
+               ratelimiter = asyncio.Semaphore(32)
+
                # Update all geofeeds
-               for geofeed in geofeeds:
-                       with self.db.transaction():
-                               self._fetch_geofeed(geofeed)
+               async with asyncio.TaskGroup() as tasks:
+                       for geofeed in geofeeds:
+                               task = tasks.create_task(
+                                       self._fetch_geofeed(ratelimiter, geofeed),
+                               )
 
                # Delete data from any feeds that did not update in the last two weeks
                with self.db.transaction():
@@ -1696,139 +1701,146 @@ class CLI(object):
                                        )
                        """)
 
-       def _fetch_geofeed(self, geofeed):
-               log.debug("Fetching Geofeed %s" % geofeed.url)
+       async def _fetch_geofeed(self, ratelimiter, geofeed):
+               async with ratelimiter:
+                       log.debug("Fetching Geofeed %s" % geofeed.url)
 
-               with self.db.transaction():
-                       # Open the URL
-                       try:
-                               # Send the request
-                               f = self.downloader.retrieve(geofeed.url, headers={
-                                       "User-Agent" : "location/%s" % location.__version__,
+                       with self.db.transaction():
+                               # Open the URL
+                               try:
+                                       # Send the request
+                                       f = await asyncio.to_thread(
+                                               self.downloader.retrieve, geofeed.url,
+                                               headers={
+                                                       "User-Agent" : "location/%s" % location.__version__,
+
+                                                       # We expect some plain text file in CSV format
+                                                       "Accept"     : "text/csv, text/plain",
+                                               },
+
+                                               # Don't wait longer than 10 seconds for a response
+                                               #timeout=10,
+                                       )
 
-                                       # We expect some plain text file in CSV format
-                                       "Accept"     : "text/csv, text/plain",
-                               })
+                                       # Remove any previous data
+                                       self.db.execute("DELETE FROM geofeed_networks \
+                                               WHERE geofeed_id = %s", geofeed.id)
 
-                               # Remove any previous data
-                               self.db.execute("DELETE FROM geofeed_networks \
-                                       WHERE geofeed_id = %s", geofeed.id)
+                                       lineno = 0
 
-                               lineno = 0
+                                       # Read the output line by line
+                                       for line in f:
+                                               lineno += 1
 
-                               # Read the output line by line
-                               for line in f:
-                                       lineno += 1
+                                               try:
+                                                       line = line.decode()
 
-                                       try:
-                                               line = line.decode()
+                                               # Ignore any lines we cannot decode
+                                               except UnicodeDecodeError:
+                                                       log.debug("Could not decode line %s in %s" \
+                                                               % (lineno, geofeed.url))
+                                                       continue
 
-                                       # Ignore any lines we cannot decode
-                                       except UnicodeDecodeError:
-                                               log.debug("Could not decode line %s in %s" \
-                                                       % (lineno, geofeed.url))
-                                               continue
+                                               # Strip any newline
+                                               line = line.rstrip()
 
-                                       # Strip any newline
-                                       line = line.rstrip()
+                                               # Skip empty lines
+                                               if not line:
+                                                       continue
 
-                                       # Skip empty lines
-                                       if not line:
-                                               continue
-
-                                       # Skip comments
-                                       elif line.startswith("#"):
-                                               continue
+                                               # Skip comments
+                                               elif line.startswith("#"):
+                                                       continue
 
-                                       # Try to parse the line
-                                       try:
-                                               fields = line.split(",", 5)
-                                       except ValueError:
-                                               log.debug("Could not parse line: %s" % line)
-                                               continue
+                                               # Try to parse the line
+                                               try:
+                                                       fields = line.split(",", 5)
+                                               except ValueError:
+                                                       log.debug("Could not parse line: %s" % line)
+                                                       continue
 
-                                       # Check if we have enough fields
-                                       if len(fields) < 4:
-                                               log.debug("Not enough fields in line: %s" % line)
-                                               continue
+                                               # Check if we have enough fields
+                                               if len(fields) < 4:
+                                                       log.debug("Not enough fields in line: %s" % line)
+                                                       continue
 
-                                       # Fetch all fields
-                                       network, country, region, city, = fields[:4]
+                                               # Fetch all fields
+                                               network, country, region, city, = fields[:4]
 
-                                       # Try to parse the network
-                                       try:
-                                               network = ipaddress.ip_network(network, strict=False)
-                                       except ValueError:
-                                               log.debug("Could not parse network: %s" % network)
-                                               continue
+                                               # Try to parse the network
+                                               try:
+                                                       network = ipaddress.ip_network(network, strict=False)
+                                               except ValueError:
+                                                       log.debug("Could not parse network: %s" % network)
+                                                       continue
+
+                                               # Strip any excess whitespace from country codes
+                                               country = country.strip()
+
+                                               # Make the country code uppercase
+                                               country = country.upper()
+
+                                               # Check the country code
+                                               if not country:
+                                                       log.debug("Empty country code in Geofeed %s line %s" \
+                                                               % (geofeed.url, lineno))
+                                                       continue
+
+                                               elif not location.country_code_is_valid(country):
+                                                       log.debug("Invalid country code in Geofeed %s:%s: %s" \
+                                                               % (geofeed.url, lineno, country))
+                                                       continue
+
+                                               # Write this into the database
+                                               self.db.execute("""
+                                                       INSERT INTO
+                                                               geofeed_networks (
+                                                                       geofeed_id,
+                                                                       network,
+                                                                       country,
+                                                                       region,
+                                                                       city
+                                                               )
+                                                       VALUES (%s, %s, %s, %s, %s)""",
+                                                       geofeed.id,
+                                                       "%s" % network,
+                                                       country,
+                                                       region,
+                                                       city,
+                                               )
 
-                                       # Strip any excess whitespace from country codes
-                                       country = country.strip()
+                               # Catch any HTTP errors
+                               except urllib.request.HTTPError as e:
+                                       self.db.execute("UPDATE geofeeds SET status = %s, error = %s \
+                                               WHERE id = %s", e.code, "%s" % e, geofeed.id)
 
-                                       # Make the country code uppercase
-                                       country = country.upper()
+                                       # Remove any previous data when the feed has been deleted
+                                       if e.code == 404:
+                                               self.db.execute("DELETE FROM geofeed_networks \
+                                                       WHERE geofeed_id = %s", geofeed.id)
 
-                                       # Check the country code
-                                       if not country:
-                                               log.debug("Empty country code in Geofeed %s line %s" \
-                                                       % (geofeed.url, lineno))
-                                               continue
+                               # Catch any other errors and connection timeouts
+                               except (http.client.InvalidURL, urllib.request.URLError, TimeoutError) as e:
+                                       log.debug("Could not fetch URL %s: %s" % (geofeed.url, e))
 
-                                       elif not location.country_code_is_valid(country):
-                                               log.debug("Invalid country code in Geofeed %s:%s: %s" \
-                                                       % (geofeed.url, lineno, country))
-                                               continue
+                                       self.db.execute("UPDATE geofeeds SET status = %s, error = %s \
+                                               WHERE id = %s", 599, "%s" % e, geofeed.id)
 
-                                       # Write this into the database
+                               # Mark the geofeed as updated
+                               else:
                                        self.db.execute("""
-                                               INSERT INTO
-                                                       geofeed_networks (
-                                                               geofeed_id,
-                                                               network,
-                                                               country,
-                                                               region,
-                                                               city
-                                                       )
-                                               VALUES (%s, %s, %s, %s, %s)""",
+                                               UPDATE
+                                                       geofeeds
+                                               SET
+                                                       updated_at = CURRENT_TIMESTAMP,
+                                                       status = NULL,
+                                                       error = NULL
+                                               WHERE
+                                                       id = %s""",
                                                geofeed.id,
-                                               "%s" % network,
-                                               country,
-                                               region,
-                                               city,
                                        )
 
-                       # Catch any HTTP errors
-                       except urllib.request.HTTPError as e:
-                               self.db.execute("UPDATE geofeeds SET status = %s, error = %s \
-                                       WHERE id = %s", e.code, "%s" % e, geofeed.id)
-
-                               # Remove any previous data when the feed has been deleted
-                               if e.code == 404:
-                                       self.db.execute("DELETE FROM geofeed_networks \
-                                               WHERE geofeed_id = %s", geofeed.id)
-
-                       # Catch any other errors and connection timeouts
-                       except (http.client.InvalidURL, urllib.request.URLError, TimeoutError) as e:
-                               log.debug("Could not fetch URL %s: %s" % (geofeed.url, e))
-
-                               self.db.execute("UPDATE geofeeds SET status = %s, error = %s \
-                                       WHERE id = %s", 599, "%s" % e, geofeed.id)
-
-                       # Mark the geofeed as updated
-                       else:
-                               self.db.execute("""
-                                       UPDATE
-                                               geofeeds
-                                       SET
-                                               updated_at = CURRENT_TIMESTAMP,
-                                               status = NULL,
-                                               error = NULL
-                                       WHERE
-                                               id = %s""",
-                                       geofeed.id,
-                               )
-
-       def handle_update_overrides(self, ns):
+       async def handle_update_overrides(self, ns):
                with self.db.transaction():
                        # Drop any previous content
                        self.db.execute("TRUNCATE TABLE autnum_overrides")
@@ -1954,7 +1966,7 @@ class CLI(object):
                                                else:
                                                        log.warning("Unsupported type: %s" % type)
 
-       def handle_update_feeds(self, ns):
+       async def handle_update_feeds(self, ns):
                """
                        Update any third-party feeds
                """
@@ -1988,7 +2000,7 @@ class CLI(object):
                                continue
 
                        try:
-                               self._process_feed(name, callback, url, *args)
+                               await self._process_feed(name, callback, url, *args)
 
                        # Log an error but continue if an exception occurs
                        except Exception as e:
@@ -1998,7 +2010,7 @@ class CLI(object):
                # Return status
                return 0 if success else 1
 
-       def _process_feed(self, name, callback, url, *args):
+       async def _process_feed(self, name, callback, url, *args):
                """
                        Processes one feed
                """
@@ -2011,9 +2023,9 @@ class CLI(object):
                        self.db.execute("DELETE FROM network_feeds WHERE source = %s", name)
 
                        # Call the callback to process the feed
-                       return callback(name, f, *args)
+                       return await callback(name, f, *args)
 
-       def _import_aws_ip_ranges(self, name, f):
+       async def _import_aws_ip_ranges(self, name, f):
                # Parse the feed
                feed = json.load(f)
 
@@ -2135,7 +2147,7 @@ class CLI(object):
                                """, "%s" % network, name, cc, is_anycast,
                        )
 
-       def _import_spamhaus_drop(self, name, f):
+       async def _import_spamhaus_drop(self, name, f):
                """
                        Import Spamhaus DROP IP feeds
                """
@@ -2191,7 +2203,7 @@ class CLI(object):
                if not lines:
                        raise RuntimeError("Received bogus feed %s with no data" % name)
 
-       def _import_spamhaus_asndrop(self, name, f):
+       async def _import_spamhaus_asndrop(self, name, f):
                """
                        Import Spamhaus ASNDROP feed
                """
@@ -2262,7 +2274,7 @@ class CLI(object):
                # Default to None
                return None
 
-       def handle_import_countries(self, ns):
+       async def handle_import_countries(self, ns):
                with self.db.transaction():
                        # Drop all data that we have
                        self.db.execute("TRUNCATE TABLE countries")
@@ -2362,9 +2374,10 @@ def iterate_over_lines(f):
                # Strip the ending
                yield line.rstrip()
 
-def main():
+async def main():
        # Run the command line interface
        c = CLI()
-       c.run()
 
-main()
+       await c.run()
+
+asyncio.run(main())