importer: Fix incorrect variable name
index 55a293bee83314135dfddc0fcd48b2cac4a0e587..b9d8eccb824edf61c296de982cb7fed562c742bf 100644
@@ -18,6 +18,7 @@
 ###############################################################################
 
 import argparse
+import asyncio
 import csv
 import functools
 import http.client
@@ -158,7 +159,7 @@ class CLI(object):
 
                return args
 
-       def run(self):
+       async def run(self):
                # Parse command line arguments
                args = self.parse_cli()
 
@@ -169,7 +170,7 @@ class CLI(object):
                self.db = self._setup_database(args)
 
                # Call function
-               ret = args.func(args)
+               ret = await args.func(args)
 
                # Return with exit code
                if ret:
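The hunks above turn `run()` into a coroutine and `await` whatever handler argparse selected through `set_defaults(func=...)`. A minimal sketch of that dispatch pattern, assuming a single illustrative "write" subcommand (the handler body is a placeholder and not taken from the importer):

import argparse
import asyncio

class CLI(object):
	def parse_cli(self):
		parser = argparse.ArgumentParser()
		subparsers = parser.add_subparsers(dest="command", required=True)

		# Illustrative "write" subcommand only
		write = subparsers.add_parser("write")
		write.add_argument("file", nargs="+")
		write.set_defaults(func=self.handle_write)

		return parser.parse_args()

	async def handle_write(self, ns):
		# Placeholder handler; the real handlers talk to the database
		for file in ns.file:
			print("Would write %s" % file)

		return 0

	async def run(self):
		# Parse command line arguments
		args = self.parse_cli()

		# The selected handler is a coroutine and therefore has to be awaited
		ret = await args.func(args)

		# Return with exit code
		if ret:
			raise SystemExit(ret)

asyncio.run(CLI().run())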
@@ -320,7 +321,7 @@ class CLI(object):
 
                return set((country.country_code for country in countries))
 
-       def handle_write(self, ns):
+       async def handle_write(self, ns):
                """
                        Compiles a database in libloc format out of what is in the database
                """
@@ -697,7 +698,7 @@ class CLI(object):
                for file in ns.file:
                        writer.write(file)
 
-       def handle_update_whois(self, ns):
+       async def handle_update_whois(self, ns):
                # Did we run successfully?
                success = True
 
@@ -756,7 +757,7 @@ class CLI(object):
                                continue
 
                        try:
-                               self._process_source(name, feeds, countries)
+                               await self._process_source(name, feeds, countries)
 
                        # Log an error but continue if an exception occurs
                        except Exception as e:
@@ -766,7 +767,7 @@ class CLI(object):
                # Return a non-zero exit code for errors
                return 0 if success else 1
 
-       def _process_source(self, source, feeds, countries):
+       async def _process_source(self, source, feeds, countries):
                """
                        This function processes one source
                """
@@ -801,7 +802,8 @@ class CLI(object):
                                f = self.downloader.retrieve(url)
 
                                # Call the callback
-                               callback(source, countries, f, *args)
+                               with self.db.pipeline():
+                                       await callback(source, countries, f, *args)
 
                        # Process all parsed networks from every RIR we happen to have access to,
                        # insert the largest network chunks into the networks table immediately...
@@ -948,7 +950,7 @@ class CLI(object):
                                """,
                        )
 
-       def _import_standard_format(self, source, countries, f, *args):
+       async def _import_standard_format(self, source, countries, f, *args):
                """
                        Imports a single standard format source feed
                """
@@ -956,12 +958,12 @@ class CLI(object):
                for block in iterate_over_blocks(f):
                        self._parse_block(block, source, countries)
 
-       def _import_extended_format(self, source, countries, f, *args):
+       async def _import_extended_format(self, source, countries, f, *args):
                # Iterate over all lines
                for line in iterate_over_lines(f):
-                       self._parse_line(block, source, countries)
+                       self._parse_line(line, source, countries)
 
-       def _import_arin_as_names(self, source, countries, f, *args):
+       async def _import_arin_as_names(self, source, countries, f, *args):
                # Walk through the file
                for line in csv.DictReader(feed, dialect="arin"):
                        log.debug("Processing object: %s" % line)
@@ -1423,12 +1425,12 @@ class CLI(object):
                        """, "%s" % network, country_code, [country], source_key,
                )
 
-       def handle_update_announcements(self, ns):
+       async def handle_update_announcements(self, ns):
                server = ns.server[0]
 
                with self.db.transaction():
                        if server.startswith("/"):
-                               self._handle_update_announcements_from_bird(server)
+                               await self._handle_update_announcements_from_bird(server)
 
                        # Purge anything we never want here
                        self.db.execute("""
@@ -1487,7 +1489,7 @@ class CLI(object):
                                DELETE FROM announcements WHERE last_seen_at <= CURRENT_TIMESTAMP - INTERVAL '14 days';
                        """)
 
-       def _handle_update_announcements_from_bird(self, server):
+       async def _handle_update_announcements_from_bird(self, server):
                # Pre-compile the regular expression for faster searching
                route = re.compile(b"^\s(.+?)\s+.+?\[(?:AS(.*?))?.\]$")
 
@@ -1605,7 +1607,7 @@ class CLI(object):
                                # Otherwise return the line
                                yield line
 
-       def handle_update_geofeeds(self, ns):
+       async def handle_update_geofeeds(self, ns):
                # Sync geofeeds
                with self.db.transaction():
                        # Delete all geofeeds which are no longer linked
@@ -1673,10 +1675,14 @@ class CLI(object):
                                id
                """)
 
+               ratelimiter = asyncio.Semaphore(32)
+
                # Update all geofeeds
-               for geofeed in geofeeds:
-                       with self.db.transaction():
-                               self._fetch_geofeed(geofeed)
+               async with asyncio.TaskGroup() as tasks:
+                       for geofeed in geofeeds:
+                               task = tasks.create_task(
+                                       self._fetch_geofeed(ratelimiter, geofeed),
+                               )
 
                # Delete data from any feeds that did not update in the last two weeks
                with self.db.transaction():
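The replacement of the sequential loop fetches all geofeeds concurrently: an `asyncio.TaskGroup` (Python 3.11+) spawns one task per feed and waits for all of them, while the `asyncio.Semaphore(32)` caps how many fetches run at the same time. A self-contained sketch of that pattern with a dummy fetch function:

import asyncio

async def fetch_geofeed(ratelimiter, url):
	# The semaphore allows at most 32 fetches to run at the same time
	async with ratelimiter:
		print("Fetching %s" % url)

		# Stand-in for the real download and database update
		await asyncio.sleep(0.1)

async def update_geofeeds(urls):
	ratelimiter = asyncio.Semaphore(32)

	# The task group waits for all tasks before the block is left
	async with asyncio.TaskGroup() as tasks:
		for url in urls:
			tasks.create_task(fetch_geofeed(ratelimiter, url))

asyncio.run(update_geofeeds(
	["https://example.com/geofeed-%s.csv" % i for i in range(100)],
))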
@@ -1696,135 +1702,152 @@ class CLI(object):
                                        )
                        """)
 
-       def _fetch_geofeed(self, geofeed):
-               log.debug("Fetching Geofeed %s" % geofeed.url)
+       async def _fetch_geofeed(self, ratelimiter, geofeed):
+               async with ratelimiter:
+                       log.debug("Fetching Geofeed %s" % geofeed.url)
 
-               with self.db.transaction():
-                       # Open the URL
-                       try:
-                               # Send the request
-                               f = self.downloader.retrieve(geofeed.url, headers={
-                                       "User-Agent" : "location/%s" % location.__version__,
+                       with self.db.transaction():
+                               # Open the URL
+                               try:
+                                       # Send the request
+                                       f = await asyncio.to_thread(
+                                               self.downloader.retrieve,
 
-                                       # We expect some plain text file in CSV format
-                                       "Accept"     : "text/csv, text/plain",
-                               })
+                                               # Fetch the feed by its URL
+                                               geofeed.url,
 
-                               # Remove any previous data
-                               self.db.execute("DELETE FROM geofeed_networks \
-                                       WHERE geofeed_id = %s", geofeed.id)
+                                               # Send some extra headers
+                                               headers={
+                                                       "User-Agent" : "location/%s" % location.__version__,
 
-                               lineno = 0
+                                                       # We expect some plain text file in CSV format
+                                                       "Accept"     : "text/csv, text/plain",
+                                               },
 
-                               # Read the output line by line
-                               for line in f:
-                                       lineno += 1
+                                               # Don't wait longer than 10 seconds for a response
+                                               timeout=10,
+                                       )
 
-                                       try:
-                                               line = line.decode()
+                                       # Remove any previous data
+                                       self.db.execute("DELETE FROM geofeed_networks \
+                                               WHERE geofeed_id = %s", geofeed.id)
 
-                                       # Ignore any lines we cannot decode
-                                       except UnicodeDecodeError:
-                                               log.debug("Could not decode line %s in %s" \
-                                                       % (lineno, geofeed.url))
-                                               continue
+                                       lineno = 0
 
-                                       # Strip any newline
-                                       line = line.rstrip()
+                                       # Read the output line by line
+                                       with self.db.pipeline():
+                                               for line in f:
+                                                       lineno += 1
 
-                                       # Skip empty lines
-                                       if not line:
-                                               continue
+                                                       try:
+                                                               line = line.decode()
 
-                                       # Try to parse the line
-                                       try:
-                                               fields = line.split(",", 5)
-                                       except ValueError:
-                                               log.debug("Could not parse line: %s" % line)
-                                               continue
+                                                       # Ignore any lines we cannot decode
+                                                       except UnicodeDecodeError:
+                                                               log.debug("Could not decode line %s in %s" \
+                                                                       % (lineno, geofeed.url))
+                                                               continue
 
-                                       # Check if we have enough fields
-                                       if len(fields) < 4:
-                                               log.debug("Not enough fields in line: %s" % line)
-                                               continue
+                                                       # Strip any newline
+                                                       line = line.rstrip()
 
-                                       # Fetch all fields
-                                       network, country, region, city, = fields[:4]
+                                                       # Skip empty lines
+                                                       if not line:
+                                                               continue
 
-                                       # Try to parse the network
-                                       try:
-                                               network = ipaddress.ip_network(network, strict=False)
-                                       except ValueError:
-                                               log.debug("Could not parse network: %s" % network)
-                                               continue
+                                                       # Skip comments
+                                                       elif line.startswith("#"):
+                                                               continue
 
-                                       # Strip any excess whitespace from country codes
-                                       country = country.strip()
+                                                       # Try to parse the line
+                                                       try:
+                                                               fields = line.split(",", 5)
+                                                       except ValueError:
+                                                               log.debug("Could not parse line: %s" % line)
+                                                               continue
 
-                                       # Make the country code uppercase
-                                       country = country.upper()
+                                                       # Check if we have enough fields
+                                                       if len(fields) < 4:
+                                                               log.debug("Not enough fields in line: %s" % line)
+                                                               continue
 
-                                       # Check the country code
-                                       if not country:
-                                               log.debug("Empty country code in Geofeed %s line %s" \
-                                                       % (geofeed.url, lineno))
-                                               continue
+                                                       # Fetch all fields
+                                                       network, country, region, city, = fields[:4]
 
-                                       elif not location.country_code_is_valid(country):
-                                               log.debug("Invalid country code in Geofeed %s:%s: %s" \
-                                                       % (geofeed.url, lineno, country))
-                                               continue
+                                                       # Try to parse the network
+                                                       try:
+                                                               network = ipaddress.ip_network(network, strict=False)
+                                                       except ValueError:
+                                                               log.debug("Could not parse network: %s" % network)
+                                                               continue
 
-                                       # Write this into the database
-                                       self.db.execute("""
-                                               INSERT INTO
-                                                       geofeed_networks (
-                                                               geofeed_id,
-                                                               network,
+                                                       # Strip any excess whitespace from country codes
+                                                       country = country.strip()
+
+                                                       # Make the country code uppercase
+                                                       country = country.upper()
+
+                                                       # Check the country code
+                                                       if not country:
+                                                               log.debug("Empty country code in Geofeed %s line %s" \
+                                                                       % (geofeed.url, lineno))
+                                                               continue
+
+                                                       elif not location.country_code_is_valid(country):
+                                                               log.debug("Invalid country code in Geofeed %s:%s: %s" \
+                                                                       % (geofeed.url, lineno, country))
+                                                               continue
+
+                                                       # Write this into the database
+                                                       self.db.execute("""
+                                                               INSERT INTO
+                                                                       geofeed_networks (
+                                                                               geofeed_id,
+                                                                               network,
+                                                                               country,
+                                                                               region,
+                                                                               city
+                                                                       )
+                                                               VALUES (%s, %s, %s, %s, %s)""",
+                                                               geofeed.id,
+                                                               "%s" % network,
                                                                country,
                                                                region,
-                                                               city
+                                                               city,
                                                        )
-                                               VALUES (%s, %s, %s, %s, %s)""",
-                                               geofeed.id,
-                                               "%s" % network,
-                                               country,
-                                               region,
-                                               city,
-                                       )
 
-                       # Catch any HTTP errors
-                       except urllib.request.HTTPError as e:
-                               self.db.execute("UPDATE geofeeds SET status = %s, error = %s \
-                                       WHERE id = %s", e.code, "%s" % e, geofeed.id)
+                               # Catch any HTTP errors
+                               except urllib.request.HTTPError as e:
+                                       self.db.execute("UPDATE geofeeds SET status = %s, error = %s \
+                                               WHERE id = %s", e.code, "%s" % e, geofeed.id)
 
-                               # Remove any previous data when the feed has been deleted
-                               if e.code == 404:
-                                       self.db.execute("DELETE FROM geofeed_networks \
-                                               WHERE geofeed_id = %s", geofeed.id)
+                                       # Remove any previous data when the feed has been deleted
+                                       if e.code == 404:
+                                               self.db.execute("DELETE FROM geofeed_networks \
+                                                       WHERE geofeed_id = %s", geofeed.id)
 
-                       # Catch any other errors and connection timeouts
-                       except (http.client.InvalidURL, urllib.request.URLError, TimeoutError) as e:
-                               log.debug("Could not fetch URL %s: %s" % (geofeed.url, e))
+                               # Catch any other errors and connection timeouts
+                               except (http.client.InvalidURL, urllib.request.URLError, TimeoutError) as e:
+                                       log.debug("Could not fetch URL %s: %s" % (geofeed.url, e))
 
-                               self.db.execute("UPDATE geofeeds SET status = %s, error = %s \
-                                       WHERE id = %s", 599, "%s" % e, geofeed.id)
+                                       self.db.execute("UPDATE geofeeds SET status = %s, error = %s \
+                                               WHERE id = %s", 599, "%s" % e, geofeed.id)
 
-                       # Mark the geofeed as updated
-                       else:
-                               self.db.execute("""
-                                       UPDATE
-                                               geofeeds
-                                       SET
-                                               updated_at = CURRENT_TIMESTAMP,
-                                               status = NULL,
-                                               error = NULL
-                                       WHERE
-                                               id = %s""",
-                                       geofeed.id,
-                               )
+                               # Mark the geofeed as updated
+                               else:
+                                       self.db.execute("""
+                                               UPDATE
+                                                       geofeeds
+                                               SET
+                                                       updated_at = CURRENT_TIMESTAMP,
+                                                       status = NULL,
+                                                       error = NULL
+                                               WHERE
+                                                       id = %s""",
+                                               geofeed.id,
+                                       )
 
-       def handle_update_overrides(self, ns):
+       async def handle_update_overrides(self, ns):
                with self.db.transaction():
                        # Drop any previous content
                        self.db.execute("TRUNCATE TABLE autnum_overrides")
@@ -1917,7 +1940,21 @@ class CLI(object):
                                                elif type == "geofeed":
                                                        url = block.get("geofeed")
 
-                                                       # XXX Check the URL
+                                                       # Parse the URL
+                                                       try:
+                                                               url = urllib.parse.urlparse(url)
+                                                       except ValueError as e:
+                                                               log.warning("Skipping invalid URL %s: %s" % (url, e))
+                                                               continue
+
+                                                       # Make sure that this is a HTTPS URL
+                                                       if not url.scheme == "https":
+                                                               log.warning("Skipping Geofeed URL that is not using HTTPS: %s" \
+                                                                       % url.geturl())
+                                                               continue
+
+                                                       # Normalize the URL and convert it back
+                                                       url = url.geturl()
 
                                                        self.db.execute("""
                                                                INSERT INTO
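The new block replaces the `XXX Check the URL` placeholder: Geofeed URLs taken from override files are now parsed with `urllib.parse.urlparse()`, rejected unless they use HTTPS, and stored in normalized form. Extracted into a small helper for illustration, the check looks roughly like this:

import urllib.parse

def normalize_geofeed_url(url):
	# Parse the URL
	try:
		url = urllib.parse.urlparse(url)
	except ValueError as e:
		print("Skipping invalid URL %s: %s" % (url, e))
		return None

	# Make sure that this is a HTTPS URL
	if not url.scheme == "https":
		print("Skipping Geofeed URL that is not using HTTPS: %s" % url.geturl())
		return None

	# Normalize the URL and convert it back
	return url.geturl()

print(normalize_geofeed_url("https://example.com/geofeed.csv"))
print(normalize_geofeed_url("http://example.com/geofeed.csv"))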
@@ -1936,7 +1973,7 @@ class CLI(object):
                                                else:
                                                        log.warning("Unsupported type: %s" % type)
 
-       def handle_update_feeds(self, ns):
+       async def handle_update_feeds(self, ns):
                """
                        Update any third-party feeds
                """
@@ -1970,7 +2007,7 @@ class CLI(object):
                                continue
 
                        try:
-                               self._process_feed(name, callback, url, *args)
+                               await self._process_feed(name, callback, url, *args)
 
                        # Log an error but continue if an exception occurs
                        except Exception as e:
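`handle_update_feeds()` keeps its error isolation after the conversion: each feed callback is awaited inside its own `try` block, a failure is logged, and the command still returns a non-zero exit code at the end. A condensed sketch of that control flow, with an illustrative feed name and callback:

import asyncio
import logging

logging.basicConfig(level=logging.DEBUG)
log = logging.getLogger("importer-sketch")

async def import_example_feed(name, url):
	# Stand-in for the real callback, which downloads and parses the feed
	log.debug("Would import %s from %s" % (name, url))

async def handle_update_feeds():
	feeds = (
		("EXAMPLE-FEED", import_example_feed, "https://example.com/feed.json"),
	)

	# Did we run successfully?
	success = True

	for name, callback, url, *args in feeds:
		try:
			await callback(name, url, *args)

		# Log an error but continue if an exception occurs
		except Exception as e:
			log.error("Error processing feed '%s': %s" % (name, e))
			success = False

	# Return status
	return 0 if success else 1

print(asyncio.run(handle_update_feeds()))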
@@ -1980,7 +2017,7 @@ class CLI(object):
                # Return status
                return 0 if success else 1
 
-       def _process_feed(self, name, callback, url, *args):
+       async def _process_feed(self, name, callback, url, *args):
                """
                        Processes one feed
                """
@@ -1993,9 +2030,10 @@ class CLI(object):
                        self.db.execute("DELETE FROM network_feeds WHERE source = %s", name)
 
                        # Call the callback to process the feed
-                       return callback(name, f, *args)
+                       with self.db.pipeline():
+                               return await callback(name, f, *args)
 
-       def _import_aws_ip_ranges(self, name, f):
+       async def _import_aws_ip_ranges(self, name, f):
                # Parse the feed
                feed = json.load(f)
 
@@ -2117,7 +2155,7 @@ class CLI(object):
                                """, "%s" % network, name, cc, is_anycast,
                        )
 
-       def _import_spamhaus_drop(self, name, f):
+       async def _import_spamhaus_drop(self, name, f):
                """
                        Import Spamhaus DROP IP feeds
                """
@@ -2173,7 +2211,7 @@ class CLI(object):
                if not lines:
                        raise RuntimeError("Received bogus feed %s with no data" % name)
 
-       def _import_spamhaus_asndrop(self, name, f):
+       async def _import_spamhaus_asndrop(self, name, f):
                """
                        Import Spamhaus ASNDROP feed
                """
@@ -2244,7 +2282,7 @@ class CLI(object):
                # Default to None
                return None
 
-       def handle_import_countries(self, ns):
+       async def handle_import_countries(self, ns):
                with self.db.transaction():
                        # Drop all data that we have
                        self.db.execute("TRUNCATE TABLE countries")
@@ -2344,9 +2382,10 @@ def iterate_over_lines(f):
                # Strip the ending
                yield line.rstrip()
 
-def main():
+async def main():
        # Run the command line interface
        c = CLI()
-       c.run()
 
-main()
+       await c.run()
+
+asyncio.run(main())