git.ipfire.org Git - people/ms/libloc.git/commitdiff
importer: Use database pipelining when parsing feeds
author: Michael Tremer <michael.tremer@ipfire.org>
Thu, 7 Mar 2024 12:39:09 +0000 (12:39 +0000)
committer: Michael Tremer <michael.tremer@ipfire.org>
Thu, 7 Mar 2024 12:45:58 +0000 (12:45 +0000)
Pipelining should allow us to parse feeds faster since we no longer wait
for a response from the database for each row that we are inserting.

Signed-off-by: Michael Tremer <michael.tremer@ipfire.org>
src/python/location/database.py
src/scripts/location-importer.in

index 28ce20d385314e0a73422d0688d9f0ee5071ce06..c31379c110c4b5f24cf2c948f607b4d90d2da426 100644 (file)
@@ -151,6 +151,14 @@ class Connection(object):
 
                return conn.transaction()
 
+       def pipeline(self):
+               """
+                       Sets the connection into pipeline mode.
+               """
+               conn = self.connection()
+
+               return conn.pipeline()
+
 
 class Row(dict):
        """A dict that allows for object-like property access syntax."""
index 2de49fbc2ac28c3ca19c2fc2ff1760d0525619c2..708f793e27916f1ab49f74bbeaf4a3e96143d369 100644 (file)
@@ -802,7 +802,8 @@ class CLI(object):
                                f = self.downloader.retrieve(url)
 
                                # Call the callback
-                               await callback(source, countries, f, *args)
+                               with self.db.pipeline():
+                                       await callback(source, countries, f, *args)
 
                        # Process all parsed networks from every RIR we happen to have access to,
                        # insert the largest network chunks into the networks table immediately...
@@ -1734,85 +1735,86 @@ class CLI(object):
                                        lineno = 0
 
                                        # Read the output line by line
-                                       for line in f:
-                                               lineno += 1
+                                       with self.db.pipeline():
+                                               for line in f:
+                                                       lineno += 1
 
-                                               try:
-                                                       line = line.decode()
+                                                       try:
+                                                               line = line.decode()
 
-                                               # Ignore any lines we cannot decode
-                                               except UnicodeDecodeError:
-                                                       log.debug("Could not decode line %s in %s" \
-                                                               % (lineno, geofeed.url))
-                                                       continue
+                                                       # Ignore any lines we cannot decode
+                                                       except UnicodeDecodeError:
+                                                               log.debug("Could not decode line %s in %s" \
+                                                                       % (lineno, geofeed.url))
+                                                               continue
 
-                                               # Strip any newline
-                                               line = line.rstrip()
+                                                       # Strip any newline
+                                                       line = line.rstrip()
 
-                                               # Skip empty lines
-                                               if not line:
-                                                       continue
+                                                       # Skip empty lines
+                                                       if not line:
+                                                               continue
 
-                                               # Skip comments
-                                               elif line.startswith("#"):
-                                                       continue
+                                                       # Skip comments
+                                                       elif line.startswith("#"):
+                                                               continue
 
-                                               # Try to parse the line
-                                               try:
-                                                       fields = line.split(",", 5)
-                                               except ValueError:
-                                                       log.debug("Could not parse line: %s" % line)
-                                                       continue
+                                                       # Try to parse the line
+                                                       try:
+                                                               fields = line.split(",", 5)
+                                                       except ValueError:
+                                                               log.debug("Could not parse line: %s" % line)
+                                                               continue
 
-                                               # Check if we have enough fields
-                                               if len(fields) < 4:
-                                                       log.debug("Not enough fields in line: %s" % line)
-                                                       continue
+                                                       # Check if we have enough fields
+                                                       if len(fields) < 4:
+                                                               log.debug("Not enough fields in line: %s" % line)
+                                                               continue
 
-                                               # Fetch all fields
-                                               network, country, region, city, = fields[:4]
+                                                       # Fetch all fields
+                                                       network, country, region, city, = fields[:4]
 
-                                               # Try to parse the network
-                                               try:
-                                                       network = ipaddress.ip_network(network, strict=False)
-                                               except ValueError:
-                                                       log.debug("Could not parse network: %s" % network)
-                                                       continue
-
-                                               # Strip any excess whitespace from country codes
-                                               country = country.strip()
-
-                                               # Make the country code uppercase
-                                               country = country.upper()
-
-                                               # Check the country code
-                                               if not country:
-                                                       log.debug("Empty country code in Geofeed %s line %s" \
-                                                               % (geofeed.url, lineno))
-                                                       continue
-
-                                               elif not location.country_code_is_valid(country):
-                                                       log.debug("Invalid country code in Geofeed %s:%s: %s" \
-                                                               % (geofeed.url, lineno, country))
-                                                       continue
-
-                                               # Write this into the database
-                                               self.db.execute("""
-                                                       INSERT INTO
-                                                               geofeed_networks (
-                                                                       geofeed_id,
-                                                                       network,
-                                                                       country,
-                                                                       region,
-                                                                       city
-                                                               )
-                                                       VALUES (%s, %s, %s, %s, %s)""",
-                                                       geofeed.id,
-                                                       "%s" % network,
-                                                       country,
-                                                       region,
-                                                       city,
-                                               )
+                                                       # Try to parse the network
+                                                       try:
+                                                               network = ipaddress.ip_network(network, strict=False)
+                                                       except ValueError:
+                                                               log.debug("Could not parse network: %s" % network)
+                                                               continue
+
+                                                       # Strip any excess whitespace from country codes
+                                                       country = country.strip()
+
+                                                       # Make the country code uppercase
+                                                       country = country.upper()
+
+                                                       # Check the country code
+                                                       if not country:
+                                                               log.debug("Empty country code in Geofeed %s line %s" \
+                                                                       % (geofeed.url, lineno))
+                                                               continue
+
+                                                       elif not location.country_code_is_valid(country):
+                                                               log.debug("Invalid country code in Geofeed %s:%s: %s" \
+                                                                       % (geofeed.url, lineno, country))
+                                                               continue
+
+                                                       # Write this into the database
+                                                       self.db.execute("""
+                                                               INSERT INTO
+                                                                       geofeed_networks (
+                                                                               geofeed_id,
+                                                                               network,
+                                                                               country,
+                                                                               region,
+                                                                               city
+                                                                       )
+                                                               VALUES (%s, %s, %s, %s, %s)""",
+                                                               geofeed.id,
+                                                               "%s" % network,
+                                                               country,
+                                                               region,
+                                                               city,
+                                                       )
 
                                # Catch any HTTP errors
                                except urllib.request.HTTPError as e:
@@ -2028,7 +2030,8 @@ class CLI(object):
                        self.db.execute("DELETE FROM network_feeds WHERE source = %s", name)
 
                        # Call the callback to process the feed
-                       return await callback(name, f, *args)
+                       with self.db.pipeline():
+                               return await callback(name, f, *args)
 
        async def _import_aws_ip_ranges(self, name, f):
                # Parse the feed