git.ipfire.org Git - location/libloc.git/commitdiff
importer: Merge the importer's downloader into our main downloader
author: Michael Tremer <michael.tremer@ipfire.org>
Mon, 4 Mar 2024 12:20:10 +0000 (12:20 +0000)
committer: Michael Tremer <michael.tremer@ipfire.org>
Mon, 4 Mar 2024 12:20:10 +0000 (12:20 +0000)
I don't know why we ended up with duplicated code here, but there seems
to be no reason whatsoever for this.

Signed-off-by: Michael Tremer <michael.tremer@ipfire.org>
Makefile.am
src/python/location/downloader.py
src/python/location/importer.py [deleted file]
src/scripts/location-importer.in

index 6986afeeec95882bdf76a76903be96c24d33e1de..a7d1c4f461b0f1a6ab3f24b5af9fc6781bcbedff 100644 (file)
@@ -194,7 +194,6 @@ dist_pkgpython_PYTHON = \
        src/python/location/downloader.py \
        src/python/location/export.py \
        src/python/location/i18n.py \
-       src/python/location/importer.py \
        src/python/location/logger.py
 
 pyexec_LTLIBRARIES = \
index 3618968cfecc4ca874be8ce5393b3cad7b425cdf..4e9e1847e234b5ddec7edc40a51c084db8bf5b3a 100644 (file)
@@ -16,6 +16,7 @@
 #                                                                             #
 ###############################################################################
 
+import gzip
 import logging
 import lzma
 import os
@@ -207,3 +208,56 @@ class Downloader(object):
                                return False
 
                return True
+
+       def retrieve(self, url, **kwargs):
+               """
+                       This method will fetch the content at the given URL
+                       and will return a file-object to a temporary file.
+
+                       If the content was compressed, it will be decompressed on the fly.
+               """
+               # Open a temporary file to buffer the downloaded content
+               t = tempfile.SpooledTemporaryFile(max_size=100 * 1024 * 1024)
+
+               # Create a new request
+               req = self._make_request(url, **kwargs)
+
+               # Send request
+               res = self._send_request(req)
+
+               # Write the payload to the temporary file
+               with res as f:
+                       while True:
+                               buf = f.read(65536)
+                               if not buf:
+                                       break
+
+                               t.write(buf)
+
+               # Rewind the temporary file
+               t.seek(0)
+
+               gzip_compressed = False
+
+               # Fetch the content type
+               content_type = res.headers.get("Content-Type")
+
+               # Decompress any gzipped response on the fly
+               if content_type in ("application/x-gzip", "application/gzip"):
+                       gzip_compressed = True
+
+               # Check for the gzip magic in case web servers send a different MIME type
+               elif t.read(2) == b"\x1f\x8b":
+                       gzip_compressed = True
+
+               # Reset again
+               t.seek(0)
+
+               # Decompress the temporary file
+               if gzip_compressed:
+                       log.debug("Gzip compression detected")
+
+                       t = gzip.GzipFile(fileobj=t, mode="rb")
+
+               # Return the temporary file handle
+               return t
diff --git a/src/python/location/importer.py b/src/python/location/importer.py
deleted file mode 100644 (file)
index 58ec368..0000000
+++ /dev/null
@@ -1,101 +0,0 @@
-###############################################################################
-#                                                                             #
-# libloc - A library to determine the location of someone on the Internet     #
-#                                                                             #
-# Copyright (C) 2020 IPFire Development Team <info@ipfire.org>                #
-#                                                                             #
-# This library is free software; you can redistribute it and/or               #
-# modify it under the terms of the GNU Lesser General Public                  #
-# License as published by the Free Software Foundation; either                #
-# version 2.1 of the License, or (at your option) any later version.          #
-#                                                                             #
-# This library is distributed in the hope that it will be useful,             #
-# but WITHOUT ANY WARRANTY; without even the implied warranty of              #
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU           #
-# Lesser General Public License for more details.                             #
-#                                                                             #
-###############################################################################
-
-import gzip
-import logging
-import tempfile
-import urllib.request
-
-# Initialise logging
-log = logging.getLogger("location.importer")
-log.propagate = 1
-
-class Downloader(object):
-       def __init__(self):
-               self.proxy = None
-
-       def set_proxy(self, url):
-               """
-                       Sets a HTTP proxy that is used to perform all requests
-               """
-               log.info("Using proxy %s" % url)
-               self.proxy = url
-
-       def retrieve(self, url, **kwargs):
-               """
-                       This method will fetch the content at the given URL
-                       and will return a file-object to a temporary file.
-
-                       If the content was compressed, it will be decompressed on the fly.
-               """
-               # Open a temporary file to buffer the downloaded content
-               t = tempfile.SpooledTemporaryFile(max_size=100 * 1024 * 1024)
-
-               # Create a new request
-               req = urllib.request.Request(url, **kwargs)
-
-               # Configure proxy
-               if self.proxy:
-                       req.set_proxy(self.proxy, "http")
-
-               log.info("Retrieving %s..." % req.full_url)
-
-               # Send request
-               res = urllib.request.urlopen(req)
-
-               # Log the response headers
-               log.debug("Response Headers:")
-               for header in res.headers:
-                       log.debug("     %s: %s" % (header, res.headers[header]))
-
-               # Write the payload to the temporary file
-               with res as f:
-                       while True:
-                               buf = f.read(65536)
-                               if not buf:
-                                       break
-
-                               t.write(buf)
-
-               # Rewind the temporary file
-               t.seek(0)
-
-               gzip_compressed = False
-
-               # Fetch the content type
-               content_type = res.headers.get("Content-Type")
-
-               # Decompress any gzipped response on the fly
-               if content_type in ("application/x-gzip", "application/gzip"):
-                       gzip_compressed = True
-
-               # Check for the gzip magic in case web servers send a different MIME type
-               elif t.read(2) == b"\x1f\x8b":
-                       gzip_compressed = True
-
-               # Reset again
-               t.seek(0)
-
-               # Decompress the temporary file
-               if gzip_compressed:
-                       log.debug("Gzip compression detected")
-
-                       t = gzip.GzipFile(fileobj=t, mode="rb")
-
-               # Return the temporary file handle
-               return t
index eb142463a8dc98961b938a46364f22defb0dad3f..5b6ffad8531285077ecb15b05cc3ed9423408f4e 100644 (file)
@@ -33,7 +33,7 @@ import urllib.error
 # Load our location module
 import location
 import location.database
-import location.importer
+from location.downloader import Downloader
 from location.i18n import _
 
 # Initialise logging
@@ -162,6 +162,9 @@ class CLI(object):
                # Parse command line arguments
                args = self.parse_cli()
 
+               # Initialize the downloader
+               self.downloader = Downloader()
+
                # Initialise database
                self.db = self._setup_database(args)
 
@@ -689,8 +692,6 @@ class CLI(object):
                        writer.write(file)
 
        def handle_update_whois(self, ns):
-               downloader = location.importer.Downloader()
-
                # Did we run successfully?
                success = True
 
@@ -749,7 +750,7 @@ class CLI(object):
                                continue
 
                        try:
-                               self._process_source(downloader, name, feeds, countries)
+                               self._process_source(name, feeds, countries)
 
                        # Log an error but continue if an exception occurs
                        except Exception as e:
@@ -759,7 +760,7 @@ class CLI(object):
                # Return a non-zero exit code for errors
                return 0 if success else 1
 
-       def _process_source(self, downloader, source, feeds, countries):
+       def _process_source(self, source, feeds, countries):
                """
                        This function processes one source
                """
@@ -791,7 +792,7 @@ class CLI(object):
                        # Parse all feeds
                        for callback, url, *args in feeds:
                                # Retrieve the feed
-                               f = downloader.retrieve(url)
+                               f = self.downloader.retrieve(url)
 
                                # Call the callback
                                callback(source, countries, f, *args)
@@ -1599,9 +1600,6 @@ class CLI(object):
                                yield line
 
        def handle_update_geofeeds(self, ns):
-               # Create a downloader
-               downloader = location.importer.Downloader()
-
                # Sync geofeeds
                with self.db.transaction():
                        # Delete all geofeeds which are no longer linked
@@ -1652,7 +1650,7 @@ class CLI(object):
                # Update all geofeeds
                for geofeed in geofeeds:
                        with self.db.transaction():
-                               self._fetch_geofeed(downloader, geofeed)
+                               self._fetch_geofeed(geofeed)
 
                # Delete data from any feeds that did not update in the last two weeks
                with self.db.transaction():
@@ -1672,14 +1670,14 @@ class CLI(object):
                                        )
                        """)
 
-       def _fetch_geofeed(self, downloader, geofeed):
+       def _fetch_geofeed(self, geofeed):
                log.debug("Fetching Geofeed %s" % geofeed.url)
 
                with self.db.transaction():
                        # Open the URL
                        try:
                                # Send the request
-                               f = downloader.retrieve(geofeed.url, headers={
+                               f = self.downloader.retrieve(geofeed.url, headers={
                                        "User-Agent" : "location/%s" % location.__version__,
 
                                        # We expect some plain text file in CSV format
@@ -1897,9 +1895,6 @@ class CLI(object):
                """
                success = True
 
-               # Create a downloader
-               downloader = location.importer.Downloader()
-
                feeds = (
                        # AWS IP Ranges
                        ("AWS-IP-RANGES", self._import_aws_ip_ranges, "https://ip-ranges.amazonaws.com/ip-ranges.json"),
@@ -1928,7 +1923,7 @@ class CLI(object):
                                continue
 
                        try:
-                               self._process_feed(downloader, name, callback, url, *args)
+                               self._process_feed(name, callback, url, *args)
 
                        # Log an error but continue if an exception occurs
                        except Exception as e:
@@ -1938,12 +1933,12 @@ class CLI(object):
                # Return status
                return 0 if success else 1
 
-       def _process_feed(self, downloader, name, callback, url, *args):
+       def _process_feed(self, name, callback, url, *args):
                """
                        Processes one feed
                """
                # Open the URL
-               f = downloader.retrieve(url)
+               f = self.downloader.retrieve(url)
 
                with self.db.transaction():
                        # Drop any previous content