###############################################################################
import argparse
+import asyncio
import csv
import functools
import http.client
return args
- def run(self):
+ async def run(self):
# Parse command line arguments
args = self.parse_cli()
self.db = self._setup_database(args)
# Call function
- ret = args.func(args)
+ ret = await args.func(args)
# Return with exit code
if ret:
return set((country.country_code for country in countries))
- def handle_write(self, ns):
+ async def handle_write(self, ns):
"""
Compiles a database in libloc format from the data stored in the SQL database
"""
for file in ns.file:
writer.write(file)
- def handle_update_whois(self, ns):
+ async def handle_update_whois(self, ns):
# Did we run successfully?
success = True
continue
try:
- self._process_source(name, feeds, countries)
+ await self._process_source(name, feeds, countries)
# Log an error but continue if an exception occurs
except Exception as e:
# Return a non-zero exit code for errors
return 0 if success else 1
- def _process_source(self, source, feeds, countries):
+ async def _process_source(self, source, feeds, countries):
"""
Processes a single source
"""
f = self.downloader.retrieve(url)
# Call the callback
- callback(source, countries, f, *args)
+ await callback(source, countries, f, *args)
# Process all parsed networks from every RIR we happen to have access to,
# insert the largest network chunks into the networks table immediately...
""",
)
- def _import_standard_format(self, source, countries, f, *args):
+ async def _import_standard_format(self, source, countries, f, *args):
"""
Imports a single standard format source feed
"""
for block in iterate_over_blocks(f):
self._parse_block(block, source, countries)
- def _import_extended_format(self, source, countries, f, *args):
+ async def _import_extended_format(self, source, countries, f, *args):
# Iterate over all lines
for line in iterate_over_lines(f):
self._parse_line(line, source, countries)
- def _import_arin_as_names(self, source, countries, f, *args):
+ async def _import_arin_as_names(self, source, countries, f, *args):
# Walk through the file
for line in csv.DictReader(f, dialect="arin"):
log.debug("Processing object: %s" % line)
""", "%s" % network, country_code, [country], source_key,
)
- def handle_update_announcements(self, ns):
+ async def handle_update_announcements(self, ns):
server = ns.server[0]
with self.db.transaction():
if server.startswith("/"):
- self._handle_update_announcements_from_bird(server)
+ await self._handle_update_announcements_from_bird(server)
# Purge anything we never want here
self.db.execute("""
DELETE FROM announcements WHERE last_seen_at <= CURRENT_TIMESTAMP - INTERVAL '14 days';
""")
- def _handle_update_announcements_from_bird(self, server):
+ async def _handle_update_announcements_from_bird(self, server):
# Pre-compile the regular expression for faster searching
route = re.compile(b"^\s(.+?)\s+.+?\[(?:AS(.*?))?.\]$")
# Otherwise return the line
yield line
- def handle_update_geofeeds(self, ns):
+ async def handle_update_geofeeds(self, ns):
# Sync geofeeds
with self.db.transaction():
# Delete all geofeeds which are no longer linked
id
""")
+ ratelimiter = asyncio.Semaphore(32)
+
# Update all geofeeds
- for geofeed in geofeeds:
- with self.db.transaction():
- self._fetch_geofeed(geofeed)
+ async with asyncio.TaskGroup() as tasks:
+ for geofeed in geofeeds:
+ tasks.create_task(
+ self._fetch_geofeed(ratelimiter, geofeed),
+ )
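+
+ # Leaving the "async with" block waits for all tasks to finish;
+ # if any task fails, the remaining tasks are cancelled and the
+ # errors are re-raised as an ExceptionGroup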
# Delete data from any feeds that did not update in the last two weeks
with self.db.transaction():
)
""")
- def _fetch_geofeed(self, geofeed):
- log.debug("Fetching Geofeed %s" % geofeed.url)
+ async def _fetch_geofeed(self, ratelimiter, geofeed):
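+ # Wait until the semaphore grants us a slot, which caps the
+ # number of geofeeds being fetched at the same time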
+ async with ratelimiter:
+ log.debug("Fetching Geofeed %s" % geofeed.url)
- with self.db.transaction():
- # Open the URL
- try:
- # Send the request
- f = self.downloader.retrieve(geofeed.url, headers={
- "User-Agent" : "location/%s" % location.__version__,
+ with self.db.transaction():
+ # Open the URL
+ try:
+ # Send the request
+ f = await asyncio.to_thread(
+ self.downloader.retrieve, geofeed.url,
+ headers={
+ "User-Agent" : "location/%s" % location.__version__,
+
+ # We expect a plain text file in CSV format
+ "Accept" : "text/csv, text/plain",
+ },
+
+ # Don't wait longer than 10 seconds for a response
+ #timeout=10,
+ )
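+
+ # (downloader.retrieve() performs blocking I/O, so it is pushed
+ # into a worker thread with asyncio.to_thread() and the event
+ # loop remains free to run the other fetch tasks)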
- # We expect some plain text file in CSV format
- "Accept" : "text/csv, text/plain",
- })
+ # Remove any previous data
+ self.db.execute("DELETE FROM geofeed_networks \
+ WHERE geofeed_id = %s", geofeed.id)
- # Remove any previous data
- self.db.execute("DELETE FROM geofeed_networks \
- WHERE geofeed_id = %s", geofeed.id)
+ lineno = 0
- lineno = 0
+ # Read the output line by line
+ for line in f:
+ lineno += 1
- # Read the output line by line
- for line in f:
- lineno += 1
+ try:
+ line = line.decode()
- try:
- line = line.decode()
+ # Ignore any lines we cannot decode
+ except UnicodeDecodeError:
+ log.debug("Could not decode line %s in %s" \
+ % (lineno, geofeed.url))
+ continue
- # Ignore any lines we cannot decode
- except UnicodeDecodeError:
- log.debug("Could not decode line %s in %s" \
- % (lineno, geofeed.url))
- continue
+ # Strip any newline
+ line = line.rstrip()
- # Strip any newline
- line = line.rstrip()
+ # Skip empty lines
+ if not line:
+ continue
- # Skip empty lines
- if not line:
- continue
-
- # Skip comments
- elif line.startswith("#"):
- continue
+ # Skip comments
+ elif line.startswith("#"):
+ continue
- # Try to parse the line
- try:
- fields = line.split(",", 5)
- except ValueError:
- log.debug("Could not parse line: %s" % line)
- continue
+ # Split the line into its fields (str.split() with a non-empty
+ # separator cannot raise, so no exception handling is needed)
+ fields = line.split(",", 5)
- # Check if we have enough fields
- if len(fields) < 4:
- log.debug("Not enough fields in line: %s" % line)
- continue
+ # Check if we have enough fields
+ if len(fields) < 4:
+ log.debug("Not enough fields in line: %s" % line)
+ continue
- # Fetch all fields
- network, country, region, city, = fields[:4]
+ # Fetch all fields
+ network, country, region, city = fields[:4]
- # Try to parse the network
- try:
- network = ipaddress.ip_network(network, strict=False)
- except ValueError:
- log.debug("Could not parse network: %s" % network)
- continue
+ # Try to parse the network
+ try:
+ network = ipaddress.ip_network(network, strict=False)
+ except ValueError:
+ log.debug("Could not parse network: %s" % network)
+ continue
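+
+ # (strict=False masks any host bits that are set, so e.g.
+ # 192.0.2.1/24 is accepted and stored as 192.0.2.0/24)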
+
+ # Strip any excess whitespace from country codes
+ country = country.strip()
+
+ # Make the country code uppercase
+ country = country.upper()
+
+ # Check the country code
+ if not country:
+ log.debug("Empty country code in Geofeed %s line %s" \
+ % (geofeed.url, lineno))
+ continue
+
+ elif not location.country_code_is_valid(country):
+ log.debug("Invalid country code in Geofeed %s:%s: %s" \
+ % (geofeed.url, lineno, country))
+ continue
+
+ # Write this into the database
+ self.db.execute("""
+ INSERT INTO
+ geofeed_networks (
+ geofeed_id,
+ network,
+ country,
+ region,
+ city
+ )
+ VALUES (%s, %s, %s, %s, %s)""",
+ geofeed.id,
+ "%s" % network,
+ country,
+ region,
+ city,
+ )
- # Strip any excess whitespace from country codes
- country = country.strip()
+ # Catch any HTTP errors
+ except urllib.request.HTTPError as e:
+ self.db.execute("UPDATE geofeeds SET status = %s, error = %s \
+ WHERE id = %s", e.code, "%s" % e, geofeed.id)
- # Make the country code uppercase
- country = country.upper()
+ # Remove any previous data when the feed has been deleted
+ if e.code == 404:
+ self.db.execute("DELETE FROM geofeed_networks \
+ WHERE geofeed_id = %s", geofeed.id)
- # Check the country code
- if not country:
- log.debug("Empty country code in Geofeed %s line %s" \
- % (geofeed.url, lineno))
- continue
+ # Catch any other errors and connection timeouts
+ except (http.client.InvalidURL, urllib.request.URLError, TimeoutError) as e:
+ log.debug("Could not fetch URL %s: %s" % (geofeed.url, e))
- elif not location.country_code_is_valid(country):
- log.debug("Invalid country code in Geofeed %s:%s: %s" \
- % (geofeed.url, lineno, country))
- continue
+ self.db.execute("UPDATE geofeeds SET status = %s, error = %s \
+ WHERE id = %s", 599, "%s" % e, geofeed.id)
- # Write this into the database
+ # Mark the geofeed as updated
+ else:
self.db.execute("""
- INSERT INTO
- geofeed_networks (
- geofeed_id,
- network,
- country,
- region,
- city
- )
- VALUES (%s, %s, %s, %s, %s)""",
+ UPDATE
+ geofeeds
+ SET
+ updated_at = CURRENT_TIMESTAMP,
+ status = NULL,
+ error = NULL
+ WHERE
+ id = %s""",
geofeed.id,
- "%s" % network,
- country,
- region,
- city,
)
- # Catch any HTTP errors
- except urllib.request.HTTPError as e:
- self.db.execute("UPDATE geofeeds SET status = %s, error = %s \
- WHERE id = %s", e.code, "%s" % e, geofeed.id)
-
- # Remove any previous data when the feed has been deleted
- if e.code == 404:
- self.db.execute("DELETE FROM geofeed_networks \
- WHERE geofeed_id = %s", geofeed.id)
-
- # Catch any other errors and connection timeouts
- except (http.client.InvalidURL, urllib.request.URLError, TimeoutError) as e:
- log.debug("Could not fetch URL %s: %s" % (geofeed.url, e))
-
- self.db.execute("UPDATE geofeeds SET status = %s, error = %s \
- WHERE id = %s", 599, "%s" % e, geofeed.id)
-
- # Mark the geofeed as updated
- else:
- self.db.execute("""
- UPDATE
- geofeeds
- SET
- updated_at = CURRENT_TIMESTAMP,
- status = NULL,
- error = NULL
- WHERE
- id = %s""",
- geofeed.id,
- )
-
- def handle_update_overrides(self, ns):
+ async def handle_update_overrides(self, ns):
with self.db.transaction():
# Drop any previous content
self.db.execute("TRUNCATE TABLE autnum_overrides")
else:
log.warning("Unsupported type: %s" % type)
- def handle_update_feeds(self, ns):
+ async def handle_update_feeds(self, ns):
"""
Update any third-party feeds
"""
continue
try:
- self._process_feed(name, callback, url, *args)
+ await self._process_feed(name, callback, url, *args)
# Log an error but continue if an exception occurs
except Exception as e:
# Return status
return 0 if success else 1
- def _process_feed(self, name, callback, url, *args):
+ async def _process_feed(self, name, callback, url, *args):
"""
Processes one feed
"""
self.db.execute("DELETE FROM network_feeds WHERE source = %s", name)
# Call the callback to process the feed
- return callback(name, f, *args)
+ return await callback(name, f, *args)
- def _import_aws_ip_ranges(self, name, f):
+ async def _import_aws_ip_ranges(self, name, f):
# Parse the feed
feed = json.load(f)
""", "%s" % network, name, cc, is_anycast,
)
- def _import_spamhaus_drop(self, name, f):
+ async def _import_spamhaus_drop(self, name, f):
"""
Import Spamhaus DROP IP feeds
"""
if not lines:
raise RuntimeError("Received bogus feed %s with no data" % name)
- def _import_spamhaus_asndrop(self, name, f):
+ async def _import_spamhaus_asndrop(self, name, f):
"""
Import Spamhaus ASNDROP feed
"""
# Default to None
return None
- def handle_import_countries(self, ns):
+ async def handle_import_countries(self, ns):
with self.db.transaction():
# Drop all data that we have
self.db.execute("TRUNCATE TABLE countries")
# Strip the line ending and any trailing whitespace
yield line.rstrip()
-def main():
+async def main():
# Run the command line interface
c = CLI()
- c.run()
-main()
+ await c.run()
+
+asyncio.run(main())
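+
+# asyncio.run() creates a new event loop, runs main() until it
+# completes, and closes the loop again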