# git.ipfire.org Git - people/ms/libloc.git/blob - src/python/location-downloader.in
# downloader: Check DNS for most recent version
# [people/ms/libloc.git] / src / python / location-downloader.in
#!/usr/bin/python3
###############################################################################
#                                                                             #
# libloc - A library to determine the location of someone on the Internet     #
#                                                                             #
# Copyright (C) 2019 IPFire Development Team <info@ipfire.org>                #
#                                                                             #
# This library is free software; you can redistribute it and/or               #
# modify it under the terms of the GNU Lesser General Public                  #
# License as published by the Free Software Foundation; either                #
# version 2.1 of the License, or (at your option) any later version.          #
#                                                                             #
# This library is distributed in the hope that it will be useful,             #
# but WITHOUT ANY WARRANTY; without even the implied warranty of              #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU            #
# Lesser General Public License for more details.                             #
#                                                                             #
###############################################################################
19
import argparse
import datetime
import email.utils
import gettext
import logging
import logging.handlers
import lzma
import os
import random
import shutil
import sys
import tempfile
import time
import urllib.error
import urllib.parse
import urllib.request

# Load our location module
import location
38
# Name of the compressed database file that is requested from each mirror
DATABASE_FILENAME = "test.db.xz"

# Base URLs the database file name is joined onto; the Downloader shuffles
# its own copy of this list before use
MIRRORS = (
    "https://location.ipfire.org/databases/",
    "https://people.ipfire.org/~ms/location/",
)
44
def setup_logging(level=logging.INFO):
    """
    Configure and return the "location-downloader" logger.

    A console handler (threshold DEBUG) is always attached. A syslog
    handler (threshold INFO, facility LOG_DAEMON) is attached as well
    when the /dev/log socket can be opened.

    :param level: threshold set on the logger itself
    :returns: the configured logging.Logger
    """
    l = logging.getLogger("location-downloader")
    l.setLevel(level)

    # Log to console
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    l.addHandler(console)

    # Log to syslog. The handler connects to the socket in its constructor
    # and raises OSError when /dev/log does not exist, so treat syslog as
    # best-effort instead of making the whole program unusable.
    try:
        syslog = logging.handlers.SysLogHandler(
            address="/dev/log",
            facility=logging.handlers.SysLogHandler.LOG_DAEMON,
        )
    except OSError as e:
        l.debug("Could not connect to syslog: %s" % e)
    else:
        syslog.setLevel(logging.INFO)

        # Format syslog messages
        syslog.setFormatter(
            logging.Formatter("location-downloader[%(process)d]: %(message)s"),
        )
        l.addHandler(syslog)

    return l
65
# Initialise logging (module-wide logger used by all classes below)
log = setup_logging()
68
# i18n
def _(singular, plural=None, n=None):
    """
    Translate a message using the "libloc" gettext domain.

    :param singular: the message (singular form)
    :param plural: optional plural form; when given, the form is
        selected according to ``n``
    :param n: the count used to pick between singular and plural
    :returns: the translated string, or the untranslated input when no
        translation is available
    """
    if not plural:
        return gettext.dgettext("libloc", singular)

    return gettext.dngettext("libloc", singular, plural, n)
75
76
class Downloader(object):
    """
    Fetches an xz-compressed database file from a list of mirrors.

    The mirrors are tried in random order until one of them delivers a
    file that can be decompressed and opened as a database.
    """

    def __init__(self, mirrors):
        """
        :param mirrors: iterable of base URLs; a shuffled copy is kept
        """
        self.mirrors = list(mirrors)

        # Randomize mirrors
        random.shuffle(self.mirrors)

        # Get proxies from environment
        self.proxies = self._get_proxies()

    def _get_proxies(self):
        """
        Return a dict mapping "https"/"http" to the value of the
        ``https_proxy``/``http_proxy`` environment variable, skipping
        variables that are unset or empty.
        """
        proxies = {}

        for protocol in ("https", "http"):
            proxy = os.environ.get("%s_proxy" % protocol)
            if proxy:
                proxies[protocol] = proxy

        return proxies

    @staticmethod
    def _format_http_date(timestamp):
        """
        Format a datetime as an HTTP date ("Tue, 01 Jan 2019 00:00:00 GMT").

        The value is converted to UTC first. A naive datetime is taken to
        be local time, which matches how ``timestamp.timestamp()`` is
        interpreted in _check_database().
        """
        utc = timestamp.astimezone(datetime.timezone.utc)

        return email.utils.format_datetime(utc, usegmt=True)

    def _make_request(self, url, baseurl=None, headers=None):
        """
        Build a GET request.

        :param url: URL, or a path relative to ``baseurl``
        :param baseurl: optional base URL that ``url`` is joined onto
        :param headers: optional dict of extra request headers; it is
            not modified
        :returns: a urllib.request.Request
        """
        if baseurl:
            url = urllib.parse.urljoin(baseurl, url)

        req = urllib.request.Request(url, method="GET")

        # Work on a private copy so that neither the caller's dict nor a
        # shared default object is ever mutated
        headers = dict(headers or {})
        headers["User-Agent"] = "location-downloader/%s" % location.__version__

        # Set headers
        for header, value in headers.items():
            req.add_header(header, value)

        # Set proxies
        for protocol, proxy in self.proxies.items():
            req.set_proxy(proxy, protocol)

        return req

    def _send_request(self, req, **kwargs):
        """
        Send a request and return the response, logging request and
        response headers at debug level.

        :param req: the urllib.request.Request to send
        :param kwargs: passed through to urllib.request.urlopen()
        :raises urllib.error.HTTPError: re-raised after its headers
            have been logged
        """
        # Log request headers
        log.debug("HTTP %s Request to %s" % (req.method, req.host))
        log.debug(" URL: %s" % req.full_url)
        log.debug(" Headers:")
        for k, v in req.header_items():
            log.debug(" %s: %s" % (k, v))

        try:
            res = urllib.request.urlopen(req, **kwargs)

        except urllib.error.HTTPError as e:
            # Log response headers
            log.debug("HTTP Response: %s" % e.code)
            log.debug(" Headers:")
            for header in e.headers:
                log.debug(" %s: %s" % (header, e.headers[header]))

            # Let the caller deal with the error
            raise

        # Log response headers
        log.debug("HTTP Response: %s" % res.code)
        log.debug(" Headers:")
        for k, v in res.getheaders():
            log.debug(" %s: %s" % (k, v))

        return res

    def download(self, url, timestamp=None, **kwargs):
        """
        Download and decompress ``url`` from the first mirror that works.

        :param url: file name/path, joined onto each mirror's base URL
        :param timestamp: optional datetime; it is sent as
            If-Modified-Since and databases created before it are rejected
        :param kwargs: accepted for compatibility, currently unused
        :returns: the (closed) named temporary file holding the
            decompressed database; the caller has to remove it
        :raises FileNotFoundError: if no mirror delivered a usable file
        """
        headers = {}

        if timestamp:
            headers["If-Modified-Since"] = self._format_http_date(timestamp)

        t = tempfile.NamedTemporaryFile(delete=False)
        keep = False

        try:
            with t:
                # Try all mirrors
                for mirror in self.mirrors:
                    # Throw away anything a previous attempt has written.
                    # truncate() alone cuts at the current position (the
                    # end of the written data), so rewind first.
                    t.seek(0)
                    t.truncate()

                    # Prepare HTTP request
                    req = self._make_request(url, baseurl=mirror, headers=headers)

                    try:
                        with self._send_request(req) as res:
                            decompressor = lzma.LZMADecompressor()

                            # Read all data
                            while True:
                                buf = res.read(1024)
                                if not buf:
                                    break

                                # Decompress data
                                buf = decompressor.decompress(buf)
                                if buf:
                                    t.write(buf)

                        # Write all data to disk
                        t.flush()

                    # Catch decompression errors
                    except lzma.LZMAError as e:
                        log.warning("Could not decompress downloaded file: %s" % e)
                        continue

                    except urllib.error.HTTPError as e:
                        # The file on the server was too old
                        if e.code == 304:
                            log.warning("%s is serving an outdated database. Trying next mirror..." % mirror)

                        # Log any other HTTP errors
                        else:
                            log.warning("%s reported: %s" % (mirror, e))

                        continue

                    # Connection-level failures (DNS, refused, ...) should
                    # not stop us from trying the remaining mirrors
                    except urllib.error.URLError as e:
                        log.warning("%s reported: %s" % (mirror, e))
                        continue

                    # Check if the downloaded database is recent
                    if not self._check_database(t, timestamp):
                        log.warning("Downloaded database is outdated. Trying next mirror...")
                        continue

                    # Return temporary file
                    keep = True
                    return t

        finally:
            # The file was created with delete=False; remove it ourselves
            # unless it is handed over to the caller
            if not keep:
                os.unlink(t.name)

        raise FileNotFoundError(url)

    def _check_database(self, f, timestamp=None):
        """
        Checks the downloaded database if it can be opened
        and if it is recent enough

        :param f: file object whose ``name`` is the path of the database
        :param timestamp: optional datetime the database must not predate
        :returns: False if the database was created before ``timestamp``,
            True otherwise
        """
        log.debug("Opening downloaded database at %s" % f.name)

        db = location.Database(f.name)

        # Database is not recent
        if timestamp and db.created_at < timestamp.timestamp():
            return False

        log.info("Downloaded new database from %s" % (time.strftime(
            "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(db.created_at),
        )))

        return True
230
231
class CLI(object):
    """
    Command line interface of the location downloader.
    """

    def __init__(self):
        self.downloader = Downloader(mirrors=MIRRORS)

    def parse_cli(self):
        """
        Parse the command line.

        Switches the logger to DEBUG when --debug was given. Prints the
        usage and exits with code 2 when no sub-command was selected.

        :returns: the argparse namespace; ``func`` holds the handler of
            the selected sub-command
        """
        parser = argparse.ArgumentParser(
            description=_("Location Downloader Command Line Interface"),
        )
        subparsers = parser.add_subparsers()

        # Global configuration flags
        parser.add_argument("--debug", action="store_true",
            help=_("Enable debug output"))

        # version
        parser.add_argument("--version", action="version",
            version="%%(prog)s %s" % location.__version__)

        # database
        # NOTE(review): "@databasedir@" looks like a placeholder that is
        # substituted when this template is built - confirm
        parser.add_argument("--database", "-d",
            default="@databasedir@/database.db", help=_("Path to database"),
        )

        # Update
        update = subparsers.add_parser("update", help=_("Update database"))
        update.set_defaults(func=self.handle_update)

        args = parser.parse_args()

        # Enable debug logging
        if args.debug:
            log.setLevel(logging.DEBUG)

        # Print usage if no action was given
        if "func" not in args:
            parser.print_usage()
            sys.exit(2)

        return args

    def run(self):
        """
        Parse the command line, run the selected handler and exit the
        process with the handler's return value (0 if it is falsy).
        """
        # Parse command line arguments
        args = self.parse_cli()

        # Call function
        ret = args.func(args)

        # Return with exit code
        if ret:
            sys.exit(ret)

        # Otherwise just exit
        sys.exit(0)

    def handle_update(self, ns):
        """
        Replace the database at ``ns.database`` if a newer one exists.

        :param ns: parsed arguments; only ``ns.database`` is read
        :returns: 1 if no database could be downloaded, otherwise a
            falsy value
        """
        # Fetch the version we need from DNS
        # NOTE(review): assumes this always returns a UNIX timestamp;
        # verify what it returns when the lookup fails
        latest = location.discover_latest_version()

        # Parse timestamp into datetime format
        timestamp = datetime.datetime.fromtimestamp(latest)

        # Open database
        try:
            db = location.Database(ns.database)

        # There is no local database yet
        except FileNotFoundError:
            db = None

        # Check if we are already on the latest version
        if db is not None and db.created_at >= timestamp.timestamp():
            log.info("Already on the latest version")
            return

        # Try downloading a new database
        try:
            tmp = self.downloader.download(DATABASE_FILENAME, timestamp=timestamp)

        # If no file could be downloaded, log a message
        except FileNotFoundError:
            log.error("Could not download a new database")
            return 1

        # If we have not received a new file, there is nothing to do
        if not tmp:
            return 0

        # Write temporary file to destination
        try:
            shutil.copyfile(tmp.name, ns.database)

        # Remove temporary file, even if copying has failed
        finally:
            os.unlink(tmp.name)
326
327
def main():
    """Create the command line interface and run it."""
    # Run the command line interface
    c = CLI()
    c.run()


# Only start the program when executed as a script, not when imported
if __name__ == "__main__":
    main()