]> git.ipfire.org Git - location/libloc.git/blob - src/python/location-downloader.in
location-downloader: Verify the database after download
[location/libloc.git] / src / python / location-downloader.in
1 #!/usr/bin/python3
2 ###############################################################################
3 # #
4 # libloc - A library to determine the location of someone on the Internet #
5 # #
6 # Copyright (C) 2019 IPFire Development Team <info@ipfire.org> #
7 # #
8 # This library is free software; you can redistribute it and/or #
9 # modify it under the terms of the GNU Lesser General Public #
10 # License as published by the Free Software Foundation; either #
11 # version 2.1 of the License, or (at your option) any later version. #
12 # #
13 # This library is distributed in the hope that it will be useful, #
14 # but WITHOUT ANY WARRANTY; without even the implied warranty of #
15 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU #
16 # Lesser General Public License for more details. #
17 # #
18 ###############################################################################
19
20 import argparse
21 import datetime
22 import gettext
23 import logging
24 import logging.handlers
25 import lzma
26 import os
27 import random
28 import shutil
29 import sys
30 import tempfile
31 import time
32 import urllib.error
33 import urllib.parse
34 import urllib.request
35
36 # Load our location module
37 import location
38
# Name of the xz-compressed database file that is requested from a mirror
DATABASE_FILENAME = "test.db.xz"

# Base URLs of the mirrors; the file name above is joined onto each of them
MIRRORS = (
    "https://location.ipfire.org/databases/",
    "https://people.ipfire.org/~ms/location/",
)
44
def setup_logging(level=logging.INFO):
    """
    Configures and returns the "location-downloader" logger.

    The logger gets a console handler (DEBUG and above) and, if the local
    syslog socket at /dev/log can be opened, a syslog handler (INFO and
    above, facility "daemon") whose messages are prefixed with the program
    name and the process ID.

    level is the level that is set on the logger itself.

    Calling this function again only updates the level; handlers are
    attached once so that messages are never emitted twice.
    """
    l = logging.getLogger("location-downloader")
    l.setLevel(level)

    # The handlers have already been set up by an earlier call
    if l.handlers:
        return l

    # Log to console
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    l.addHandler(console)

    # Log to syslog. The handler connects to the socket in its constructor
    # which fails with an OSError if there is no syslog daemon listening
    # (e.g. in containers). That must not prevent the program from starting.
    try:
        syslog = logging.handlers.SysLogHandler(address="/dev/log",
            facility=logging.handlers.SysLogHandler.LOG_DAEMON)
    except OSError as e:
        l.debug("Syslog is not available: %s", e)
    else:
        syslog.setLevel(logging.INFO)

        # Format syslog messages
        syslog.setFormatter(
            logging.Formatter("location-downloader[%(process)d]: %(message)s"),
        )
        l.addHandler(syslog)

    return l

# Initialise logging
log = setup_logging()
68
69 # i18n
def _(singular, plural=None, n=None):
    """
    Translates a message using the "libloc" gettext domain.

    If plural is given, the plural-aware lookup is used and n selects
    between the singular and the plural form. Otherwise singular is
    looked up on its own.
    """
    if plural:
        return gettext.dngettext("libloc", singular, plural, n)

    return gettext.dgettext("libloc", singular)
75
76
class Downloader(object):
    """
    Downloads an xz-compressed database file from a list of mirrors.

    The mirrors are tried in random order. The payload is decompressed into
    a temporary file which is only handed back to the caller after
    _check_database() has accepted it.
    """

    def __init__(self, mirrors):
        """
        mirrors is an iterable of base URLs; it is copied and shuffled.
        """
        self.mirrors = list(mirrors)

        # Randomize mirrors
        random.shuffle(self.mirrors)

        # Get proxies from environment
        self.proxies = self._get_proxies()

    def _get_proxies(self):
        """
        Returns a dict that maps "https"/"http" to the value of the
        https_proxy/http_proxy environment variable. Variables that are
        unset or empty are left out.
        """
        proxies = {}

        for protocol in ("https", "http"):
            proxy = os.environ.get("%s_proxy" % protocol)

            if proxy:
                proxies[protocol] = proxy

        return proxies

    def _make_request(self, url, baseurl=None, headers=None):
        """
        Builds a GET request for url.

        If baseurl is given, url is resolved relative to it. headers is an
        optional dict of extra request headers; it is never modified. The
        User-Agent header is always set to the name and version of this
        program.
        """
        if baseurl:
            url = urllib.parse.urljoin(baseurl, url)

        req = urllib.request.Request(url, method="GET")

        # Work on a copy so that neither the dict of the caller nor a
        # default shared between calls is changed
        headers = dict(headers or {})
        headers["User-Agent"] = "location-downloader/%s" % location.__version__

        # Set headers
        for header, value in headers.items():
            req.add_header(header, value)

        # Set proxies
        # NOTE(review): every configured proxy is applied regardless of the
        # scheme of the URL, and the raw environment value is passed on as
        # the proxy host - confirm that this is intended.
        for protocol, proxy in self.proxies.items():
            req.set_proxy(proxy, protocol)

        return req

    def _send_request(self, req, **kwargs):
        """
        Sends req and returns the response object.

        Request and response headers are written to the debug log. Any
        keyword arguments are passed on to urllib.request.urlopen().
        HTTP errors are logged and raised again.
        """
        # Log request headers
        log.debug("HTTP %s Request to %s", req.method, req.host)
        log.debug(" URL: %s", req.full_url)
        log.debug(" Headers:")
        for k, v in req.header_items():
            log.debug(" %s: %s", k, v)

        try:
            res = urllib.request.urlopen(req, **kwargs)

        except urllib.error.HTTPError as e:
            # Log response headers
            log.debug("HTTP Response: %s", e.code)
            log.debug(" Headers:")
            for header in e.headers:
                log.debug(" %s: %s", header, e.headers[header])

            # Let the caller decide what to do about the error
            raise

        # Log response headers
        log.debug("HTTP Response: %s", res.code)
        log.debug(" Headers:")
        for k, v in res.getheaders():
            log.debug(" %s: %s", k, v)

        return res

    def download(self, url, public_key, timestamp=None, **kwargs):
        """
        Downloads url from the first mirror that serves a usable database.

        url is resolved relative to each mirror. public_key is the path of
        the key file that is used to verify the database. If timestamp (a
        datetime) is given, it is sent as If-Modified-Since and databases
        that were created earlier are rejected. Additional keyword
        arguments are accepted but not used.

        Returns the (closed) temporary file which holds the decompressed
        database; the caller is responsible for removing it.

        Raises FileNotFoundError if no mirror delivered a usable database.
        The temporary file is removed whenever this method does not return
        it.
        """
        headers = {}

        if timestamp:
            # Go through the POSIX timestamp so that the header really is
            # in GMT, no matter whether the datetime is naive (local time)
            # or timezone-aware
            headers["If-Modified-Since"] = time.strftime(
                "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(timestamp.timestamp()),
            )

        t = tempfile.NamedTemporaryFile(delete=False)
        downloaded = False

        try:
            with t:
                # Try all mirrors
                for mirror in self.mirrors:
                    # Throw away anything a previous attempt has left
                    # behind. truncate() cuts at the current position, so
                    # we have to go back to the start first.
                    t.seek(0)
                    t.truncate()

                    # Prepare HTTP request
                    req = self._make_request(url, baseurl=mirror, headers=headers)

                    try:
                        with self._send_request(req) as res:
                            decompressor = lzma.LZMADecompressor()

                            # Read all data
                            while True:
                                buf = res.read(1024)
                                if not buf:
                                    break

                                # Decompress data
                                buf = decompressor.decompress(buf)
                                if buf:
                                    t.write(buf)

                        # An incomplete stream does not raise an error on
                        # its own, but it must not be accepted
                        if not decompressor.eof:
                            raise lzma.LZMAError("Compressed data ended prematurely")

                        # Write all data to disk
                        t.flush()

                    # Catch decompression errors (EOFError is raised when
                    # there is more data after the end of the stream)
                    except (lzma.LZMAError, EOFError) as e:
                        log.warning("Could not decompress downloaded file: %s", e)
                        continue

                    except urllib.error.HTTPError as e:
                        # The file on the server was too old
                        if e.code == 304:
                            log.warning("%s is serving an outdated database. Trying next mirror...", mirror)

                        # Log any other HTTP errors
                        else:
                            log.warning("%s reported: %s", mirror, e)

                        continue

                    # Check if the downloaded database is recent
                    if not self._check_database(t, public_key, timestamp):
                        log.warning("Downloaded database is outdated. Trying next mirror...")
                        continue

                    # Return temporary file
                    downloaded = True
                    return t

        finally:
            # The file has been created with delete=False, so it has to be
            # removed by hand unless it is handed over to the caller
            if not downloaded:
                os.unlink(t.name)

        raise FileNotFoundError(url)

    def _check_database(self, f, public_key, timestamp=None):
        """
        Checks the downloaded database if it can be opened,
        verified and if it is recent enough

        f is the file object of the downloaded database (only its name is
        used), public_key the path of the key file and timestamp an
        optional datetime the database must not be older than.

        Returns True if the database passed all checks and False otherwise.
        """
        log.debug("Opening downloaded database at %s", f.name)

        db = location.Database(f.name)

        # Database is not recent
        if timestamp and db.created_at < timestamp.timestamp():
            return False

        log.info("Downloaded new database from %s", time.strftime(
            "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(db.created_at),
        ))

        # Verify the database
        with open(public_key, "r") as key:
            if not db.verify(key):
                log.error("Could not verify database")
                return False

        return True
236
237
class CLI(object):
    """
    Command line interface of the location downloader.
    """

    def __init__(self):
        self.downloader = Downloader(mirrors=MIRRORS)

    def parse_cli(self):
        """
        Parses the command line and returns the resulting namespace.

        Debug logging is enabled if --debug was given. If no subcommand
        was selected, the usage is printed and the program exits with
        code 2.
        """
        parser = argparse.ArgumentParser(
            description=_("Location Downloader Command Line Interface"),
        )
        subparsers = parser.add_subparsers()

        # Global configuration flags
        parser.add_argument("--debug", action="store_true",
            help=_("Enable debug output"))

        # version
        parser.add_argument("--version", action="version",
            version="%%(prog)s %s" % location.__version__)

        # database
        parser.add_argument("--database", "-d",
            default="@databasedir@/database.db", help=_("Path to database"),
        )

        # public key
        parser.add_argument("--public-key", "-k",
            default="@databasedir@/signing-key.pem", help=_("Public Signing Key"),
        )

        # Update
        update = subparsers.add_parser("update", help=_("Update database"))
        update.set_defaults(func=self.handle_update)

        args = parser.parse_args()

        # Enable debug logging
        if args.debug:
            log.setLevel(logging.DEBUG)

        # Print usage if no action was given
        if "func" not in args:
            parser.print_usage()
            sys.exit(2)

        return args

    def run(self):
        """
        Parses the command line, runs the selected action and exits the
        process with the return value of the action (0 if it returned
        nothing).
        """
        # Parse command line arguments
        args = self.parse_cli()

        # Call function
        ret = args.func(args)

        # Return with exit code
        sys.exit(ret or 0)

    def handle_update(self, ns):
        """
        Downloads a new database if the installed one is out of date.

        ns is the namespace returned by parse_cli(); its database and
        public_key attributes are used.

        Returns 0 if the database is up to date or has been replaced and
        1 if no new database could be downloaded.
        """
        # Fetch the version we need from DNS
        latest = location.discover_latest_version()

        # Parse timestamp into datetime format. It is made timezone-aware
        # (UTC) so that converting it back to a POSIX timestamp and
        # formatting it as GMT both give the right result.
        timestamp = datetime.datetime.fromtimestamp(latest, datetime.timezone.utc)

        # Open database
        try:
            db = location.Database(ns.database)

        # There is no database, yet
        except FileNotFoundError:
            pass

        else:
            # Check if we are already on the latest version
            if db.created_at >= timestamp.timestamp():
                log.info("Already on the latest version")
                return 0

        # Try downloading a new database
        try:
            t = self.downloader.download(DATABASE_FILENAME,
                public_key=ns.public_key, timestamp=timestamp)

        # If no file could be downloaded, log a message
        except FileNotFoundError:
            log.error("Could not download a new database")
            return 1

        # If we have not received a new file, there is nothing to do
        if not t:
            return 0

        try:
            # Write temporary file to destination
            shutil.copyfile(t.name, ns.database)

        finally:
            # Remove temporary file, even if it could not be copied
            os.unlink(t.name)

        return 0
338
339
def main():
    """
    Runs the command line interface. This does not return because
    CLI.run() always ends the process with sys.exit().
    """
    # Run the command line interface
    c = CLI()
    c.run()

# Only run when executed as a script so that importing this file has no
# side effects
if __name__ == "__main__":
    main()