From: Michael Tremer
Date: Thu, 7 Mar 2024 12:18:14 +0000 (+0000)
Subject: importer: Wrap everything into asyncio
X-Git-Tag: 0.9.18~95
X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=98d9079b7623ecce6fb27ef6b6898ff49c7c1901;p=location%2Flibloc.git

importer: Wrap everything into asyncio

Signed-off-by: Michael Tremer
---

diff --git a/src/scripts/location-importer.in b/src/scripts/location-importer.in
index cc2e5e2..39bfcfc 100644
--- a/src/scripts/location-importer.in
+++ b/src/scripts/location-importer.in
@@ -18,6 +18,7 @@
 ###############################################################################

 import argparse
+import asyncio
 import csv
 import functools
 import http.client
@@ -158,7 +159,7 @@ class CLI(object):

 		return args

-	def run(self):
+	async def run(self):
 		# Parse command line arguments
 		args = self.parse_cli()

@@ -169,7 +170,7 @@ class CLI(object):
 			self.db = self._setup_database(args)

 		# Call function
-		ret = args.func(args)
+		ret = await args.func(args)

 		# Return with exit code
 		if ret:
@@ -320,7 +321,7 @@ class CLI(object):

 		return set((country.country_code for country in countries))

-	def handle_write(self, ns):
+	async def handle_write(self, ns):
 		"""
 			Compiles a database in libloc format out of what is in the database
 		"""
@@ -697,7 +698,7 @@ class CLI(object):
 		for file in ns.file:
 			writer.write(file)

-	def handle_update_whois(self, ns):
+	async def handle_update_whois(self, ns):
 		# Did we run successfully?
 		success = True

@@ -756,7 +757,7 @@ class CLI(object):
 				continue

 			try:
-				self._process_source(name, feeds, countries)
+				await self._process_source(name, feeds, countries)

 			# Log an error but continue if an exception occurs
 			except Exception as e:
@@ -766,7 +767,7 @@ class CLI(object):
 		# Return a non-zero exit code for errors
 		return 0 if success else 1

-	def _process_source(self, source, feeds, countries):
+	async def _process_source(self, source, feeds, countries):
 		"""
 			This function processes one source
 		"""
@@ -801,7 +802,7 @@ class CLI(object):
 				f = self.downloader.retrieve(url)

 				# Call the callback
-				callback(source, countries, f, *args)
+				await callback(source, countries, f, *args)

 	# Process all parsed networks from every RIR we happen to have access to,
 	# insert the largest network chunks into the networks table immediately...
@@ -948,7 +949,7 @@ class CLI(object):
 			""",
 		)

-	def _import_standard_format(self, source, countries, f, *args):
+	async def _import_standard_format(self, source, countries, f, *args):
 		"""
 			Imports a single standard format source feed
 		"""
@@ -956,12 +957,12 @@ class CLI(object):
 		for block in iterate_over_blocks(f):
 			self._parse_block(block, source, countries)

-	def _import_extended_format(self, source, countries, f, *args):
+	async def _import_extended_format(self, source, countries, f, *args):
 		# Iterate over all lines
 		for line in iterate_over_lines(f):
 			self._parse_line(block, source, countries)

-	def _import_arin_as_names(self, source, countries, f, *args):
+	async def _import_arin_as_names(self, source, countries, f, *args):
 		# Walk through the file
 		for line in csv.DictReader(feed, dialect="arin"):
 			log.debug("Processing object: %s" % line)
@@ -1423,12 +1424,12 @@ class CLI(object):
 			""", "%s" % network, country_code, [country], source_key,
 		)

-	def handle_update_announcements(self, ns):
+	async def handle_update_announcements(self, ns):
 		server = ns.server[0]

 		with self.db.transaction():
 			if server.startswith("/"):
-				self._handle_update_announcements_from_bird(server)
+				await self._handle_update_announcements_from_bird(server)

 			# Purge anything we never want here
 			self.db.execute("""
@@ -1487,7 +1488,7 @@ class CLI(object):
 				DELETE FROM announcements WHERE last_seen_at <= CURRENT_TIMESTAMP - INTERVAL '14 days';
 			""")

-	def _handle_update_announcements_from_bird(self, server):
+	async def _handle_update_announcements_from_bird(self, server):
 		# Pre-compile the regular expression for faster searching
 		route = re.compile(b"^\s(.+?)\s+.+?\[(?:AS(.*?))?.\]$")

@@ -1605,7 +1606,7 @@ class CLI(object):
 			# Otherwise return the line
 			yield line

-	def handle_update_geofeeds(self, ns):
+	async def handle_update_geofeeds(self, ns):
 		# Sync geofeeds
 		with self.db.transaction():
 			# Delete all geofeeds which are no longer linked
@@ -1673,10 +1674,14 @@ class CLI(object):
 				id
 		""")

+		ratelimiter = asyncio.Semaphore(32)
+
 		# Update all geofeeds
-		for geofeed in geofeeds:
-			with self.db.transaction():
-				self._fetch_geofeed(geofeed)
+		async with asyncio.TaskGroup() as tasks:
+			for geofeed in geofeeds:
+				task = tasks.create_task(
+					self._fetch_geofeed(ratelimiter, geofeed),
+				)

 		# Delete data from any feeds that did not update in the last two weeks
 		with self.db.transaction():
@@ -1696,139 +1701,146 @@ class CLI(object):
 			)
 		""")

-	def _fetch_geofeed(self, geofeed):
-		log.debug("Fetching Geofeed %s" % geofeed.url)
+	async def _fetch_geofeed(self, ratelimiter, geofeed):
+		async with ratelimiter:
+			log.debug("Fetching Geofeed %s" % geofeed.url)

-		with self.db.transaction():
-			# Open the URL
-			try:
-				# Send the request
-				f = self.downloader.retrieve(geofeed.url, headers={
-					"User-Agent" : "location/%s" % location.__version__,
+			with self.db.transaction():
+				# Open the URL
+				try:
+					# Send the request
+					f = await asyncio.to_thread(
+						self.downloader.retrieve, geofeed.url,
+						headers={
+							"User-Agent" : "location/%s" % location.__version__,
+
+							# We expect some plain text file in CSV format
+							"Accept" : "text/csv, text/plain",
+						},
+
+						# Don't wait longer than 10 seconds for a response
+						#timeout=10,
+					)

-					# We expect some plain text file in CSV format
-					"Accept" : "text/csv, text/plain",
-				})
+					# Remove any previous data
+					self.db.execute("DELETE FROM geofeed_networks \
+						WHERE geofeed_id = %s", geofeed.id)

-				# Remove any previous data
-				self.db.execute("DELETE FROM geofeed_networks \
-					WHERE geofeed_id = %s", geofeed.id)
+					lineno = 0

-				lineno = 0
+					# Read the output line by line
+					for line in f:
+						lineno += 1

-				# Read the output line by line
-				for line in f:
-					lineno += 1
+						try:
+							line = line.decode()

-					try:
-						line = line.decode()
+						# Ignore any lines we cannot decode
+						except UnicodeDecodeError:
+							log.debug("Could not decode line %s in %s" \
+								% (lineno, geofeed.url))
+							continue

-					# Ignore any lines we cannot decode
-					except UnicodeDecodeError:
-						log.debug("Could not decode line %s in %s" \
-							% (lineno, geofeed.url))
-						continue
+						# Strip any newline
+						line = line.rstrip()

-					# Strip any newline
-					line = line.rstrip()
+						# Skip empty lines
+						if not line:
+							continue

-					# Skip empty lines
-					if not line:
-						continue
-
-					# Skip comments
-					elif line.startswith("#"):
-						continue
+						# Skip comments
+						elif line.startswith("#"):
+							continue

-					# Try to parse the line
-					try:
-						fields = line.split(",", 5)
-					except ValueError:
-						log.debug("Could not parse line: %s" % line)
-						continue
+						# Try to parse the line
+						try:
+							fields = line.split(",", 5)
+						except ValueError:
+							log.debug("Could not parse line: %s" % line)
+							continue

-					# Check if we have enough fields
-					if len(fields) < 4:
-						log.debug("Not enough fields in line: %s" % line)
-						continue
+						# Check if we have enough fields
+						if len(fields) < 4:
+							log.debug("Not enough fields in line: %s" % line)
+							continue

-					# Fetch all fields
-					network, country, region, city, = fields[:4]
+						# Fetch all fields
+						network, country, region, city, = fields[:4]

-					# Try to parse the network
-					try:
-						network = ipaddress.ip_network(network, strict=False)
-					except ValueError:
-						log.debug("Could not parse network: %s" % network)
-						continue
+						# Try to parse the network
+						try:
+							network = ipaddress.ip_network(network, strict=False)
+						except ValueError:
+							log.debug("Could not parse network: %s" % network)
+							continue
+
+						# Strip any excess whitespace from country codes
+						country = country.strip()
+
+						# Make the country code uppercase
+						country = country.upper()
+
+						# Check the country code
+						if not country:
+							log.debug("Empty country code in Geofeed %s line %s" \
+								% (geofeed.url, lineno))
+							continue
+
+						elif not location.country_code_is_valid(country):
+							log.debug("Invalid country code in Geofeed %s:%s: %s" \
+								% (geofeed.url, lineno, country))
+							continue
+
+						# Write this into the database
+						self.db.execute("""
+							INSERT INTO
+								geofeed_networks (
+									geofeed_id,
+									network,
+									country,
+									region,
+									city
+								)
+							VALUES (%s, %s, %s, %s, %s)""",
+							geofeed.id,
+							"%s" % network,
+							country,
+							region,
+							city,
+						)

-					# Strip any excess whitespace from country codes
-					country = country.strip()
+				# Catch any HTTP errors
+				except urllib.request.HTTPError as e:
+					self.db.execute("UPDATE geofeeds SET status = %s, error = %s \
+						WHERE id = %s", e.code, "%s" % e, geofeed.id)

-					# Make the country code uppercase
-					country = country.upper()
+					# Remove any previous data when the feed has been deleted
+					if e.code == 404:
+						self.db.execute("DELETE FROM geofeed_networks \
+							WHERE geofeed_id = %s", geofeed.id)

-					# Check the country code
-					if not country:
-						log.debug("Empty country code in Geofeed %s line %s" \
-							% (geofeed.url, lineno))
-						continue
+				# Catch any other errors and connection timeouts
+				except (http.client.InvalidURL, urllib.request.URLError, TimeoutError) as e:
+					log.debug("Could not fetch URL %s: %s" % (geofeed.url, e))

-					elif not location.country_code_is_valid(country):
-						log.debug("Invalid country code in Geofeed %s:%s: %s" \
-							% (geofeed.url, lineno, country))
-						continue
+					self.db.execute("UPDATE geofeeds SET status = %s, error = %s \
+						WHERE id = %s", 599, "%s" % e, geofeed.id)

-					# Write this into the database
+				# Mark the geofeed as updated
+				else:
 					self.db.execute("""
-						INSERT INTO
-							geofeed_networks (
-								geofeed_id,
-								network,
-								country,
-								region,
-								city
-							)
-						VALUES (%s, %s, %s, %s, %s)""",
+						UPDATE
+							geofeeds
+						SET
+							updated_at = CURRENT_TIMESTAMP,
+							status = NULL,
+							error = NULL
+						WHERE
+							id = %s""",
 						geofeed.id,
-						"%s" % network,
-						country,
-						region,
-						city,
 					)

-			# Catch any HTTP errors
-			except urllib.request.HTTPError as e:
-				self.db.execute("UPDATE geofeeds SET status = %s, error = %s \
-					WHERE id = %s", e.code, "%s" % e, geofeed.id)
-
-				# Remove any previous data when the feed has been deleted
-				if e.code == 404:
-					self.db.execute("DELETE FROM geofeed_networks \
-						WHERE geofeed_id = %s", geofeed.id)
-
-			# Catch any other errors and connection timeouts
-			except (http.client.InvalidURL, urllib.request.URLError, TimeoutError) as e:
-				log.debug("Could not fetch URL %s: %s" % (geofeed.url, e))
-
-				self.db.execute("UPDATE geofeeds SET status = %s, error = %s \
-					WHERE id = %s", 599, "%s" % e, geofeed.id)
-
-			# Mark the geofeed as updated
-			else:
-				self.db.execute("""
-					UPDATE
-						geofeeds
-					SET
-						updated_at = CURRENT_TIMESTAMP,
-						status = NULL,
-						error = NULL
-					WHERE
-						id = %s""",
-					geofeed.id,
-				)
-
-	def handle_update_overrides(self, ns):
+	async def handle_update_overrides(self, ns):
 		with self.db.transaction():
 			# Drop any previous content
 			self.db.execute("TRUNCATE TABLE autnum_overrides")
@@ -1954,7 +1966,7 @@ class CLI(object):
 		else:
 			log.warning("Unsupported type: %s" % type)

-	def handle_update_feeds(self, ns):
+	async def handle_update_feeds(self, ns):
 		"""
 			Update any third-party feeds
 		"""
@@ -1988,7 +2000,7 @@ class CLI(object):
 				continue

 			try:
-				self._process_feed(name, callback, url, *args)
+				await self._process_feed(name, callback, url, *args)

 			# Log an error but continue if an exception occurs
 			except Exception as e:
@@ -1998,7 +2010,7 @@ class CLI(object):
 		# Return status
 		return 0 if success else 1

-	def _process_feed(self, name, callback, url, *args):
+	async def _process_feed(self, name, callback, url, *args):
 		"""
 			Processes one feed
 		"""
@@ -2011,9 +2023,9 @@ class CLI(object):
 			self.db.execute("DELETE FROM network_feeds WHERE source = %s", name)

 			# Call the callback to process the feed
-			return callback(name, f, *args)
+			return await callback(name, f, *args)

-	def _import_aws_ip_ranges(self, name, f):
+	async def _import_aws_ip_ranges(self, name, f):
 		# Parse the feed
 		feed = json.load(f)

@@ -2135,7 +2147,7 @@ class CLI(object):
 			""", "%s" % network, name, cc, is_anycast,
 		)

-	def _import_spamhaus_drop(self, name, f):
+	async def _import_spamhaus_drop(self, name, f):
 		"""
 			Import Spamhaus DROP IP feeds
 		"""
@@ -2191,7 +2203,7 @@ class CLI(object):
 		if not lines:
 			raise RuntimeError("Received bogus feed %s with no data" % name)

-	def _import_spamhaus_asndrop(self, name, f):
+	async def _import_spamhaus_asndrop(self, name, f):
 		"""
 			Import Spamhaus ASNDROP feed
 		"""
@@ -2262,7 +2274,7 @@ class CLI(object):
 		# Default to None
 		return None

-	def handle_import_countries(self, ns):
+	async def handle_import_countries(self, ns):
 		with self.db.transaction():
 			# Drop all data that we have
 			self.db.execute("TRUNCATE TABLE countries")
@@ -2362,9 +2374,10 @@ def iterate_over_lines(f):
 		# Strip the ending
 		yield line.rstrip()

-def main():
+async def main():
 	# Run the command line interface
 	c = CLI()
-	c.run()

-main()
+	await c.run()
+
+asyncio.run(main())
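
A note on the concurrency pattern introduced above: handle_update_geofeeds() now fans the geofeed fetches out into an asyncio.TaskGroup, caps the number of in-flight downloads with asyncio.Semaphore(32), and moves each blocking downloader.retrieve() call onto a worker thread via asyncio.to_thread(), so the event loop is never blocked on network I/O. The other converted handlers still await their callbacks one at a time, so for them the change makes the call chain async-ready rather than concurrent. Below is a minimal, standalone sketch of the TaskGroup/Semaphore/to_thread pattern; the URLS list and the fetch() helper are hypothetical stand-ins for the importer's feeds and downloader, not part of this patch:

	import asyncio
	import urllib.request

	# Hypothetical feed URLs, standing in for the geofeed list
	URLS = [
		"https://example.com/geofeed-1.csv",
		"https://example.com/geofeed-2.csv",
	]

	def fetch(url):
		# Blocking I/O, comparable to what downloader.retrieve() does
		with urllib.request.urlopen(url, timeout=10) as f:
			return f.read()

	async def fetch_with_limit(ratelimiter, url):
		# Hold a semaphore slot for the whole download so that at most
		# N fetches are in flight at once (the patch uses N = 32)
		async with ratelimiter:
			# Run the blocking call on a worker thread so the event
			# loop stays free to start and supervise other fetches
			return await asyncio.to_thread(fetch, url)

	async def main():
		ratelimiter = asyncio.Semaphore(32)

		# The TaskGroup waits for all tasks on exit and cancels the
		# remaining ones if any task raises
		async with asyncio.TaskGroup() as tasks:
			for url in URLS:
				tasks.create_task(fetch_with_limit(ratelimiter, url))

	asyncio.run(main())

asyncio.TaskGroup requires Python 3.11 or later; on older interpreters the same fan-out can be approximated with asyncio.gather(), at the cost of the TaskGroup's structured cancellation on first error.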