]> git.ipfire.org Git - location/libloc.git/blob - src/python/location-downloader.in
location-downloader: Add command to verify the downloaded database manually
[location/libloc.git] / src / python / location-downloader.in
1 #!/usr/bin/python3
2 ###############################################################################
3 # #
4 # libloc - A library to determine the location of someone on the Internet #
5 # #
6 # Copyright (C) 2019 IPFire Development Team <info@ipfire.org> #
7 # #
8 # This library is free software; you can redistribute it and/or #
9 # modify it under the terms of the GNU Lesser General Public #
10 # License as published by the Free Software Foundation; either #
11 # version 2.1 of the License, or (at your option) any later version. #
12 # #
13 # This library is distributed in the hope that it will be useful, #
14 # but WITHOUT ANY WARRANTY; without even the implied warranty of #
15 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU #
16 # Lesser General Public License for more details. #
17 # #
18 ###############################################################################
19
20 import argparse
21 import datetime
22 import gettext
23 import logging
24 import logging.handlers
25 import lzma
26 import os
27 import random
28 import shutil
29 import sys
30 import tempfile
31 import time
32 import urllib.error
33 import urllib.parse
34 import urllib.request
35
36 # Load our location module
37 import location
38
# Name of the compressed database file; it is joined onto a mirror's
# base URL when the update command downloads a new database
DATABASE_FILENAME = "test.db.xz"

# Base URLs that are handed to the Downloader as its list of mirrors
MIRRORS = (
    "https://location.ipfire.org/databases/",
    "https://people.ipfire.org/~ms/location/",
)
44
def setup_logging(level=logging.INFO):
    """
    Configures and returns the "location-downloader" logger.

    Two handlers are attached on every call:

    * a console handler (logging.StreamHandler) with its own threshold
      set to DEBUG and no formatter
    * a syslog handler writing to /dev/log with the daemon facility,
      a threshold of INFO and a "location-downloader[PID]: " prefix

    level is the threshold that is set on the logger itself.
    """
    logger = logging.getLogger("location-downloader")
    logger.setLevel(level)

    # Log to console
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    logger.addHandler(console)

    # Log to syslog; only these messages get the program name and PID prefix
    syslog = logging.handlers.SysLogHandler(
        address="/dev/log",
        facility=logging.handlers.SysLogHandler.LOG_DAEMON,
    )
    syslog.setLevel(logging.INFO)
    syslog.setFormatter(
        logging.Formatter("location-downloader[%(process)d]: %(message)s"),
    )
    logger.addHandler(syslog)

    return logger


# Initialise logging
log = setup_logging()
68
69 # i18n
70 def _(singular, plural=None, n=None):
71 if plural:
72 return gettext.dngettext("libloc", singular, plural, n)
73
74 return gettext.dgettext("libloc", singular)
75
76
class Downloader(object):
    """
    Downloads an XZ-compressed database from one of several mirrors.

    The mirrors are tried in random order. Proxies are read once from
    the https_proxy and http_proxy environment variables.
    """

    def __init__(self, mirrors):
        """
        mirrors is an iterable of base URLs; it is copied, so the
        caller's sequence is never modified.
        """
        self.mirrors = list(mirrors)

        # Randomize mirrors
        random.shuffle(self.mirrors)

        # Get proxies from environment
        self.proxies = self._get_proxies()

    def _get_proxies(self):
        """
        Returns a dict mapping "https"/"http" to the value of the
        matching <protocol>_proxy environment variable. Unset or empty
        variables are left out.
        """
        proxies = {}

        for protocol in ("https", "http"):
            proxy = os.environ.get("%s_proxy" % protocol, None)

            if proxy:
                proxies[protocol] = proxy

        return proxies

    def _make_request(self, url, baseurl=None, headers=None):
        """
        Builds a GET request for url (resolved against baseurl if given).

        headers is an optional mapping of extra request headers. It is
        copied, so neither the caller's mapping nor a shared default is
        ever modified. A User-Agent header is always set.
        """
        if baseurl:
            url = urllib.parse.urljoin(baseurl, url)

        req = urllib.request.Request(url, method="GET")

        # Work on a private copy and add our User-Agent
        headers = dict(headers or {})
        headers["User-Agent"] = "location-downloader/%s" % location.__version__

        # Set headers
        for header, value in headers.items():
            req.add_header(header, value)

        # Set proxies
        for protocol, proxy in self.proxies.items():
            req.set_proxy(proxy, protocol)

        return req

    def _send_request(self, req, **kwargs):
        """
        Sends req with urllib.request.urlopen() and returns the response.

        Request and response headers are logged at debug level. Any
        keyword arguments are passed on to urlopen(). HTTP errors are
        logged and then re-raised as urllib.error.HTTPError.
        """
        # Log request headers
        log.debug("HTTP %s Request to %s" % (req.method, req.host))
        log.debug(" URL: %s" % req.full_url)
        log.debug(" Headers:")
        for k, v in req.header_items():
            log.debug(" %s: %s" % (k, v))

        try:
            res = urllib.request.urlopen(req, **kwargs)

        except urllib.error.HTTPError as e:
            # Log response headers
            log.debug("HTTP Response: %s" % e.code)
            log.debug(" Headers:")
            for header, value in e.headers.items():
                log.debug(" %s: %s" % (header, value))

            # Pass the error on to the caller with its original traceback
            raise

        # Log response headers
        log.debug("HTTP Response: %s" % res.code)
        log.debug(" Headers:")
        for k, v in res.getheaders():
            log.debug(" %s: %s" % (k, v))

        return res

    def download(self, url, public_key, timestamp=None, **kwargs):
        """
        Downloads url from the first mirror that serves a usable database.

        The response body is decompressed (XZ) into a temporary file
        which is then checked with _check_database().

        url is the file name relative to each mirror's base URL.
        public_key is the path of the key file used for verification.
        timestamp is an optional datetime; when given it is sent as
        If-Modified-Since and databases created before it are rejected.
        Additional keyword arguments are accepted but currently unused.

        Returns the (closed) temporary file object; the caller owns the
        file at its .name and has to remove it.

        Raises FileNotFoundError if no mirror delivered a usable
        database. In that case the temporary file has been removed.
        """
        headers = {}

        if timestamp:
            # HTTP dates must be expressed in GMT, so convert via the POSIX
            # timestamp instead of formatting the datetime's own (possibly
            # local) wall-clock fields.
            headers["If-Modified-Since"] = time.strftime(
                "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(timestamp.timestamp()),
            )

        t = tempfile.NamedTemporaryFile(delete=False)
        keep = False

        try:
            with t:
                # Try all mirrors
                for mirror in self.mirrors:
                    # Start every attempt with an empty file. truncate()
                    # cuts at the current position, so rewind first.
                    t.seek(0)
                    t.truncate()

                    # Prepare HTTP request
                    req = self._make_request(url, baseurl=mirror, headers=headers)

                    try:
                        with self._send_request(req) as res:
                            decompressor = lzma.LZMADecompressor()

                            # Read all data
                            while True:
                                buf = res.read(1024)
                                if not buf:
                                    break

                                # Decompress data
                                buf = decompressor.decompress(buf)
                                if buf:
                                    t.write(buf)

                        # Write all data to disk
                        t.flush()

                    # Catch decompression errors
                    except lzma.LZMAError as e:
                        log.warning("Could not decompress downloaded file: %s" % e)
                        continue

                    except urllib.error.HTTPError as e:
                        # The file on the server was too old
                        if e.code == 304:
                            log.warning("%s is serving an outdated database. Trying next mirror..." % mirror)

                        # Log any other HTTP errors
                        else:
                            log.warning("%s reported: %s" % (mirror, e))

                        continue

                    # The mirror could not be reached at all
                    except urllib.error.URLError as e:
                        log.warning("%s could not be reached: %s" % (mirror, e.reason))
                        continue

                    # Check if the downloaded database is recent
                    if not self._check_database(t, public_key, timestamp):
                        log.warning("Downloaded database is outdated. Trying next mirror...")
                        continue

                    # Return temporary file
                    keep = True
                    return t

        finally:
            # Do not leave the temporary file behind unless we hand it
            # over to the caller
            if not keep:
                try:
                    os.unlink(t.name)
                except OSError:
                    pass

        raise FileNotFoundError(url)

    def _check_database(self, f, public_key, timestamp=None):
        """
        Checks the downloaded database if it can be opened,
        verified and if it is recent enough

        f is the file object of the download (only f.name is used),
        public_key the path of the key file and timestamp an optional
        datetime the database must not be older than.

        Returns True if the database passed all checks, else False.
        """
        log.debug("Opening downloaded database at %s" % f.name)

        db = location.Database(f.name)

        # Database is not recent
        if timestamp and db.created_at < timestamp.timestamp():
            return False

        log.info("Downloaded new database from %s" % (time.strftime(
            "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(db.created_at),
        )))

        # Verify the database
        with open(public_key, "r") as key:
            if not db.verify(key):
                log.error("Could not verify database")
                return False

        return True
236
237
class CLI(object):
    """
    Command line interface of the location downloader.

    Provides the "update" and "verify" sub-commands.
    """

    def __init__(self):
        self.downloader = Downloader(mirrors=MIRRORS)

    def parse_cli(self):
        """
        Parses the command line and returns the argparse namespace.

        Enables debug logging when --debug was given. If no sub-command
        was selected, the usage is printed and the process exits with
        code 2.
        """
        parser = argparse.ArgumentParser(
            description=_("Location Downloader Command Line Interface"),
        )
        subparsers = parser.add_subparsers()

        # Global configuration flags
        parser.add_argument("--debug", action="store_true",
            help=_("Enable debug output"))

        # version
        parser.add_argument("--version", action="version",
            version="%%(prog)s %s" % location.__version__)

        # database
        parser.add_argument("--database", "-d",
            default="@databasedir@/database.db", help=_("Path to database"),
        )

        # public key
        parser.add_argument("--public-key", "-k",
            default="@databasedir@/signing-key.pem", help=_("Public Signing Key"),
        )

        # Update
        update = subparsers.add_parser("update", help=_("Update database"))
        update.set_defaults(func=self.handle_update)

        # Verify
        verify = subparsers.add_parser("verify",
            help=_("Verify the downloaded database"))
        verify.set_defaults(func=self.handle_verify)

        args = parser.parse_args()

        # Enable debug logging
        if args.debug:
            log.setLevel(logging.DEBUG)

        # Print usage if no action was given
        if "func" not in args:
            parser.print_usage()
            sys.exit(2)

        return args

    def run(self):
        """
        Parses the command line, runs the selected sub-command and
        exits the process with its return value (0 if it returned
        nothing).
        """
        # Parse command line arguments
        args = self.parse_cli()

        # Call function
        ret = args.func(args)

        # Return with exit code
        sys.exit(ret or 0)

    def handle_update(self, ns):
        """
        Downloads a new database if the installed one is older than the
        latest published version and copies it to ns.database.

        Returns None if the installed database is already current,
        0 after a successful update, 1 if no database could be
        downloaded and 3 if the downloader returned no file.
        """
        # Fetch the version we need from DNS
        latest = location.discover_latest_version()

        # Parse timestamp into datetime format
        timestamp = datetime.datetime.fromtimestamp(latest)

        # Open database
        try:
            db = location.Database(ns.database)

        # There is no database yet, so we always have to download one
        except FileNotFoundError:
            pass

        else:
            # Check if we are already on the latest version
            if db.created_at >= timestamp.timestamp():
                log.info("Already on the latest version")
                return

        # Try downloading a new database
        try:
            t = self.downloader.download(DATABASE_FILENAME,
                public_key=ns.public_key, timestamp=timestamp)

        # If no file could be downloaded, log a message
        except FileNotFoundError:
            log.error("Could not download a new database")
            return 1

        # If we have not received a new file, there is nothing to do
        if not t:
            return 3

        try:
            # Write temporary file to destination
            shutil.copyfile(t.name, ns.database)

        finally:
            # Remove temporary file, even if copying it has failed
            os.unlink(t.name)

        return 0

    def handle_verify(self, ns):
        """
        Verifies the database at ns.database against the key at
        ns.public_key.

        Returns 0 on success, 1 if the verification failed and 127 if
        the database file does not exist.
        """
        try:
            db = location.Database(ns.database)
        except FileNotFoundError as e:
            log.error("%s: %s" % (ns.database, e))
            return 127

        # Verify the database
        with open(ns.public_key, "r") as f:
            if not db.verify(f):
                log.error("Could not verify database")
                return 1

        # Success
        log.debug("Database successfully verified")
        return 0
362
363
def main():
    """
    Entry point of the script: creates the command line interface
    and runs it. CLI.run() ends the process via sys.exit().
    """
    # Run the command line interface
    c = CLI()
    c.run()


# Only run when executed as a script, not when the file is imported
if __name__ == "__main__":
    main()