]> git.ipfire.org Git - people/ms/libloc.git/blob - src/python/location-downloader.in
961c5dffdaf631344588ebf04df8ebf1fcb51c9c
[people/ms/libloc.git] / src / python / location-downloader.in
1 #!/usr/bin/python3
2 ###############################################################################
3 # #
4 # libloc - A library to determine the location of someone on the Internet #
5 # #
6 # Copyright (C) 2019 IPFire Development Team <info@ipfire.org> #
7 # #
8 # This library is free software; you can redistribute it and/or #
9 # modify it under the terms of the GNU Lesser General Public #
10 # License as published by the Free Software Foundation; either #
11 # version 2.1 of the License, or (at your option) any later version. #
12 # #
13 # This library is distributed in the hope that it will be useful, #
14 # but WITHOUT ANY WARRANTY; without even the implied warranty of #
15 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU #
16 # Lesser General Public License for more details. #
17 # #
18 ###############################################################################
19
20 import argparse
21 import gettext
22 import logging
23 import logging.handlers
24 import lzma
25 import os
26 import random
27 import shutil
28 import sys
29 import tempfile
30 import time
31 import urllib.error
32 import urllib.parse
33 import urllib.request
34
35 # Load our location module
36 import location
37
# Compressed database file fetched from each mirror
DATABASE_FILENAME = "test.db.xz"

# Base URLs to download from. Each one ends in "/" so that
# urllib.parse.urljoin() appends the file name instead of replacing the
# last path component.
MIRRORS = (
    "https://location.ipfire.org/databases/",
    "https://people.ipfire.org/~ms/location/",
)
43
def setup_logging(level=logging.INFO):
    """
    Configure and return the "location-downloader" logger.

    level: threshold set on the logger itself.

    Two handlers are attached: a console StreamHandler that passes everything
    the logger accepts (default formatter), and a SysLogHandler talking to
    /dev/log with the LOG_DAEMON facility that only passes INFO and above,
    prefixing each message with the program name and PID.
    """
    logger = logging.getLogger("location-downloader")
    logger.setLevel(level)

    # Console output
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    logger.addHandler(console)

    # Syslog output, tagged like a classic daemon: "name[pid]: message"
    syslog = logging.handlers.SysLogHandler(
        address="/dev/log",
        facility=logging.handlers.SysLogHandler.LOG_DAEMON,
    )
    syslog.setLevel(logging.INFO)
    syslog.setFormatter(
        logging.Formatter("location-downloader[%(process)d]: %(message)s")
    )
    logger.addHandler(syslog)

    return logger

# Module-wide logger used by everything below
log = setup_logging()
67
68 # i18n
69 def _(singular, plural=None, n=None):
70 if plural:
71 return gettext.dngettext("libloc", singular, plural, n)
72
73 return gettext.dgettext("libloc", singular)
74
class NotModifiedError(Exception):
    """
    Signals an HTTP 304 Not Modified answer: the requested file has not
    changed on the server since the If-Modified-Since time that was sent.
    """
80
81
class Downloader(object):
    """
    Fetch xz/lzma-compressed files from a list of mirrors.

    Mirrors are tried in random order. The first one that delivers a
    complete, decompressible file wins.
    """

    def __init__(self, mirrors):
        """
        mirrors: iterable of base URLs; file names passed to download() are
        resolved relative to each of them.
        """
        self.mirrors = list(mirrors)

        # Randomize mirrors
        random.shuffle(self.mirrors)

        # Get proxies from environment
        self.proxies = self._get_proxies()

    def _get_proxies(self):
        """
        Return a dict mapping "https"/"http" to the proxy configured in the
        lower-case https_proxy/http_proxy environment variables.

        Unset or empty variables are skipped.
        """
        proxies = {}

        for protocol in ("https", "http"):
            proxy = os.environ.get("%s_proxy" % protocol, None)

            if proxy:
                proxies[protocol] = proxy

        return proxies

    def _make_request(self, url, baseurl=None, headers=None):
        """
        Build a GET request for url.

        url: file to fetch; joined onto baseurl when one is given.
        baseurl: optional base URL (e.g. a mirror).
        headers: optional dict of extra request headers. It is copied, so
            the caller's dict is never modified.

        Returns a urllib.request.Request that carries our User-Agent and,
        if one is configured for the request's scheme, a proxy.
        """
        if baseurl:
            url = urllib.parse.urljoin(baseurl, url)

        req = urllib.request.Request(url, method="GET")

        # Work on a copy so neither the caller's dict nor a shared default
        # object is mutated from call to call
        headers = dict(headers or {})
        headers["User-Agent"] = "location-downloader/%s" % location.__version__

        for header, value in headers.items():
            req.add_header(header, value)

        # Only the proxy for this request's own scheme applies. Calling
        # set_proxy() once per configured scheme would let a later one
        # (e.g. http) override the proxy chosen for an https request.
        proxy = self.proxies.get(req.type)
        if proxy:
            req.set_proxy(proxy, req.type)

        return req

    def _send_request(self, req, **kwargs):
        """
        Send req with urlopen() (kwargs are passed through) and return the
        response. Request and response headers are logged at debug level.

        Raises NotModifiedError on HTTP 304 and re-raises any other
        urllib.error.HTTPError.
        """
        # Log request headers
        log.debug("HTTP %s Request to %s", req.method, req.host)
        log.debug(" URL: %s", req.full_url)
        log.debug(" Headers:")
        for k, v in req.header_items():
            log.debug(" %s: %s", k, v)

        try:
            res = urllib.request.urlopen(req, **kwargs)

        except urllib.error.HTTPError as e:
            # Log response headers
            log.debug("HTTP Response: %s", e.code)
            log.debug(" Headers:")
            for k, v in e.headers.items():
                log.debug(" %s: %s", k, v)

            # Handle 304
            if e.code == 304:
                raise NotModifiedError() from e

            # Raise all other errors
            raise

        # Log response headers
        log.debug("HTTP Response: %s", res.code)
        log.debug(" Headers:")
        for k, v in res.getheaders():
            log.debug(" %s: %s", k, v)

        return res

    @staticmethod
    def _decompress(res, f):
        """Stream-decompress the xz/lzma body of res into the binary file f."""
        decompressor = lzma.LZMADecompressor()

        # Read all data
        while True:
            buf = res.read(1024)
            if not buf:
                break

            # Decompress data
            buf = decompressor.decompress(buf)
            if buf:
                f.write(buf)

        # A stream cut short (e.g. by a dropped connection) decompresses
        # without error but never reaches the end-of-stream marker
        if not decompressor.eof:
            raise lzma.LZMAError(
                "Compressed data ended before the end-of-stream marker was reached"
            )

        # Write all data to disk
        f.flush()

    def download(self, url, mtime=None, **kwargs):
        """
        Download url from the first mirror that serves it and decompress it
        into a new temporary file.

        url: file name, resolved against each mirror's base URL.
        mtime: optional POSIX timestamp of the local copy. When given, an
            If-Modified-Since header is sent.
        kwargs: accepted for compatibility; currently unused.

        Returns the (closed) NamedTemporaryFile holding the decompressed data.
        The caller owns t.name and must remove it. Returns None if the server
        reported the local copy as up to date.

        Raises FileNotFoundError if no mirror delivered the file, and re-raises
        any HTTP error other than 404. The temporary file is removed on every
        path that does not return it.
        """
        headers = {}

        if mtime is not None:
            headers["If-Modified-Since"] = time.strftime(
                "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(mtime),
            )

        t = tempfile.NamedTemporaryFile(delete=False)
        keep = False

        try:
            with t:
                # Try all mirrors
                for mirror in self.mirrors:
                    # Discard anything a previous, failed mirror wrote
                    t.seek(0)
                    t.truncate()

                    # Prepare HTTP request
                    req = self._make_request(url, baseurl=mirror, headers=headers)

                    try:
                        with self._send_request(req) as res:
                            self._decompress(res, t)

                    # Nothing to do when the database on the server is up to date
                    except NotModifiedError:
                        log.info("Local database is up to date")
                        return None

                    # Corrupt or truncated data: try the next mirror
                    except lzma.LZMAError as e:
                        log.warning("Could not decompress downloaded file: %s", e)
                        continue

                    except urllib.error.HTTPError as e:
                        # File not available on this mirror
                        if e.code == 404:
                            continue

                        raise

                    # Connection-level failures (e.g. DNS, connection refused):
                    # try the next mirror instead of giving up
                    except urllib.error.URLError as e:
                        log.warning("Could not reach %s: %s", mirror, e.reason)
                        continue

                    # Return temporary file
                    keep = True
                    return t

            raise FileNotFoundError(url)

        finally:
            # Never leak the temporary file on paths that do not return it
            if not keep:
                try:
                    os.unlink(t.name)
                except FileNotFoundError:
                    pass
217
218
class CLI(object):
    """
    Command line front end: parses arguments and dispatches to the handler
    of the chosen sub-command.
    """

    def __init__(self):
        self.downloader = Downloader(mirrors=MIRRORS)

    def parse_cli(self):
        """
        Parse sys.argv and return the resulting namespace.

        Switches the logger to DEBUG when --debug is given. Prints usage and
        exits with status 2 when no sub-command was selected.
        """
        parser = argparse.ArgumentParser(
            description=_("Location Downloader Command Line Interface"),
        )
        subparsers = parser.add_subparsers()

        # Global configuration flags
        parser.add_argument("--debug", action="store_true",
            help=_("Enable debug output"))

        # version
        parser.add_argument("--version", action="version",
            version="%%(prog)s %s" % location.__version__)

        # database
        parser.add_argument("--database", "-d",
            default="@databasedir@/database.db", help=_("Path to database"),
        )

        # Update
        update = subparsers.add_parser("update", help=_("Update database"))
        update.set_defaults(func=self.handle_update)

        args = parser.parse_args()

        # Enable debug logging
        if args.debug:
            log.setLevel(logging.DEBUG)

        # Print usage if no action was given: "func" is only set by a
        # sub-command's set_defaults()
        if "func" not in args:
            parser.print_usage()
            sys.exit(2)

        return args

    def run(self):
        """Parse the command line, run the selected handler and exit with its status."""
        args = self.parse_cli()

        # Handlers return a non-zero exit code on failure; None/0 means success
        ret = args.func(args)

        sys.exit(ret or 0)

    def handle_update(self, ns):
        """
        Download a new database and install it at ns.database.

        Returns 0 when the database was updated or is already current. Returns
        1 when no database could be downloaded or the downloaded one is not
        newer than the installed one. The downloaded temporary file is
        removed on every path.
        """
        mtime = None

        # Open database
        # NOTE(review): only FileNotFoundError is handled here. An unreadable
        # existing database presumably aborts the update; confirm what
        # location.Database raises.
        try:
            db = location.Database(ns.database)

            # Get mtime of the old file (sent as If-Modified-Since)
            mtime = os.path.getmtime(ns.database)
        except FileNotFoundError:
            db = None

        # Try downloading a new database
        try:
            t = self.downloader.download(DATABASE_FILENAME, mtime=mtime)

        # If no file could be downloaded, log a message
        except FileNotFoundError:
            log.error("Could not download a new database")
            return 1

        # If we have not received a new file, there is nothing to do
        if t is None:
            return 0

        # Save old database creation time
        created_at = db.created_at if db else 0

        try:
            # Open the downloaded file; any error propagates to the caller
            db = location.Database(t.name)

            # Check if the downloaded file is newer
            if db.created_at <= created_at:
                log.warning("Downloaded database is older than the current version")
                return 1

            log.info("Downloaded new database from %s" % (time.strftime(
                "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(db.created_at),
            )))

            # Write temporary file to destination
            shutil.copyfile(t.name, ns.database)

        finally:
            # Remove temporary file on success and failure alike
            os.unlink(t.name)

        return 0
321
322
def main():
    """Entry point: run the command line interface."""
    c = CLI()
    c.run()


# Only run when executed as a script, not when imported
if __name__ == "__main__":
    main()