###############################################################################
import argparse
-import gettext
+import datetime
+import logging
import lzma
import os
import random
# Load our location module
import location
-
-import logging
-logging.basicConfig(level=logging.INFO)
+from location.i18n import _
DATABASE_FILENAME = "test.db.xz"
MIRRORS = (
"https://people.ipfire.org/~ms/location/",
)
-# i18n
-def _(singular, plural=None, n=None):
- if plural:
- return gettext.dngettext("libloc", singular, plural, n)
-
- return gettext.dgettext("libloc", singular)
-
-class NotModifiedError(Exception):
- """
- Raised when the file has not been modified on the server
- """
- pass
-
+# Initialise logging
+log = logging.getLogger("location.downloader")
+log.propagate = 1
class Downloader(object):
def __init__(self, mirrors):
# Update headers
headers.update({
- "User-Agent" : "location-downloader/%s" % location.__version__,
+ "User-Agent" : "location-downloader/@VERSION@",
})
# Set headers
def _send_request(self, req, **kwargs):
# Log request headers
- logging.debug("HTTP %s Request to %s" % (req.method, req.host))
- logging.debug(" URL: %s" % req.full_url)
- logging.debug(" Headers:")
+ log.debug("HTTP %s Request to %s" % (req.method, req.host))
+ log.debug(" URL: %s" % req.full_url)
+ log.debug(" Headers:")
for k, v in req.header_items():
- logging.debug(" %s: %s" % (k, v))
+ log.debug(" %s: %s" % (k, v))
try:
res = urllib.request.urlopen(req, **kwargs)
except urllib.error.HTTPError as e:
# Log response headers
- logging.debug("HTTP Response: %s" % e.code)
- logging.debug(" Headers:")
+ log.debug("HTTP Response: %s" % e.code)
+ log.debug(" Headers:")
for header in e.headers:
- logging.debug(" %s: %s" % (header, e.headers[header]))
-
- # Handle 304
- if e.code == 304:
- raise NotModifiedError() from e
+ log.debug(" %s: %s" % (header, e.headers[header]))
# Raise all other errors
raise e
# Log response headers
- logging.debug("HTTP Response: %s" % res.code)
- logging.debug(" Headers:")
+ log.debug("HTTP Response: %s" % res.code)
+ log.debug(" Headers:")
for k, v in res.getheaders():
- logging.debug(" %s: %s" % (k, v))
+ log.debug(" %s: %s" % (k, v))
return res
- def download(self, url, mtime=None, **kwargs):
+ def download(self, url, public_key, timestamp=None, **kwargs):
headers = {}
- if mtime:
- headers["If-Modified-Since"] = time.strftime(
- "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(mtime),
+ if timestamp:
+ headers["If-Modified-Since"] = timestamp.strftime(
+ "%a, %d %b %Y %H:%M:%S GMT",
)
t = tempfile.NamedTemporaryFile(delete=False)
if buf:
t.write(buf)
- # Write all data to disk
- t.flush()
-
- # Nothing to do when the database on the server is up to date
- except NotModifiedError:
- logging.info("Local database is up to date")
- return
+ # Write all data to disk
+ t.flush()
# Catch decompression errors
except lzma.LZMAError as e:
- logging.warning("Could not decompress downloaded file: %s" % e)
+ log.warning("Could not decompress downloaded file: %s" % e)
continue
- # XXX what do we catch here?
except urllib.error.HTTPError as e:
- if e.code == 404:
- continue
+ # The file on the server was too old
+ if e.code == 304:
+ log.warning("%s is serving an outdated database. Trying next mirror..." % mirror)
- # Truncate the target file and drop downloaded content
- try:
- t.truncate()
- except OSError:
- pass
+ # Log any other HTTP errors
+ else:
+ log.warning("%s reported: %s" % (mirror, e))
- raise e
+ # Throw away any downloaded content and try again
+ t.truncate()
- # Return temporary file
- return t
+ else:
+ # Check if the downloaded database is recent
+ if not self._check_database(t, public_key, timestamp):
+ log.warning("Downloaded database is outdated. Trying next mirror...")
+
+ # Throw away the data and try again
+ t.truncate()
+ continue
+
+ # Return temporary file
+ return t
raise FileNotFoundError(url)
+ def _check_database(self, f, public_key, timestamp=None):
+ """
+ Checks the downloaded database if it can be opened,
+ verified and if it is recent enough
+ """
+ log.debug("Opening downloaded database at %s" % f.name)
+
+ db = location.Database(f.name)
+
+ # Database is not recent
+ if timestamp and db.created_at < timestamp.timestamp():
+ return False
+
+ log.info("Downloaded new database from %s" % (time.strftime(
+ "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(db.created_at),
+ )))
+
+ # Verify the database
+ with open(public_key, "r") as f:
+ if not db.verify(f):
+ log.error("Could not verify database")
+ return False
+
+ return True
+
class CLI(object):
def __init__(self):
# version
parser.add_argument("--version", action="version",
- version="%%(prog)s %s" % location.__version__)
+ version="%(prog)s @VERSION@")
# database
parser.add_argument("--database", "-d",
default="@databasedir@/database.db", help=_("Path to database"),
)
+ # public key
+ parser.add_argument("--public-key", "-k",
+ default="@databasedir@/signing-key.pem", help=_("Public Signing Key"),
+ )
+
# Update
update = subparsers.add_parser("update", help=_("Update database"))
update.set_defaults(func=self.handle_update)
+ # Verify
+ verify = subparsers.add_parser("verify",
+ help=_("Verify the downloaded database"))
+ verify.set_defaults(func=self.handle_verify)
+
args = parser.parse_args()
# Enable debug logging
if args.debug:
- logging.basicConfig(level=logging.DEBUG)
+ location.logger.set_level(logging.DEBUG)
# Print usage if no action was given
if not "func" in args:
sys.exit(0)
def handle_update(self, ns):
- mtime = None
+ # Fetch the version we need from DNS
+ t = location.discover_latest_version()
+
+ # Parse timestamp into datetime format
+ try:
+ timestamp = datetime.datetime.fromtimestamp(t)
+ except:
+ raise
# Open database
try:
db = location.Database(ns.database)
- # Get mtime of the old file
- mtime = os.path.getmtime(ns.database)
+ # Check if we are already on the latest version
+ if db.created_at >= timestamp.timestamp():
+ log.info("Already on the latest version")
+ return
+
except FileNotFoundError as e:
db = None
# Try downloading a new database
try:
- t = self.downloader.download(DATABASE_FILENAME, mtime=mtime)
+ t = self.downloader.download(DATABASE_FILENAME,
+ public_key=ns.public_key, timestamp=timestamp)
# If no file could be downloaded, log a message
except FileNotFoundError as e:
- logging.error("Could not download a new database")
+ log.error("Could not download a new database")
return 1
# If we have not received a new file, there is nothing to do
if not t:
- return 0
-
- # Save old database creation time
- created_at = db.created_at if db else 0
-
- # Try opening the downloaded file
- try:
- db = location.Database(t.name)
- except Exception as e:
- raise e
-
- # Check if the downloaded file is newer
- if db.created_at <= created_at:
- logging.warning("Downloaded database is older than the current version")
- return 1
-
- logging.info("Downloaded new database from %s" % (time.strftime(
- "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(db.created_at),
- )))
+ return 3
# Write temporary file to destination
shutil.copyfile(t.name, ns.database)
# Remove temporary file
os.unlink(t.name)
+ return 0
+
+ def handle_verify(self, ns):
+ try:
+ db = location.Database(ns.database)
+ except FileNotFoundError as e:
+ log.error("%s: %s" % (ns.database, e))
+ return 127
+
+ # Verify the database
+ with open(ns.public_key, "r") as f:
+ if not db.verify(f):
+ log.error("Could not verify database")
+ return 1
+
+ # Success
+ log.debug("Database successfully verified")
+ return 0
+
def main():
# Run the command line interface