location-importer.in: fix typo
[location/libloc.git] / src / python / downloader.py
1 #!/usr/bin/python3
2 ###############################################################################
3 # #
4 # libloc - A library to determine the location of someone on the Internet #
5 # #
6 # Copyright (C) 2020 IPFire Development Team <info@ipfire.org> #
7 # #
8 # This library is free software; you can redistribute it and/or #
9 # modify it under the terms of the GNU Lesser General Public #
10 # License as published by the Free Software Foundation; either #
11 # version 2.1 of the License, or (at your option) any later version. #
12 # #
13 # This library is distributed in the hope that it will be useful, #
14 # but WITHOUT ANY WARRANTY; without even the implied warranty of #
15 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU #
16 # Lesser General Public License for more details. #
17 # #
18 ###############################################################################
19
20 import logging
21 import lzma
22 import os
23 import random
24 import stat
25 import tempfile
26 import time
27 import urllib.error
28 import urllib.parse
29 import urllib.request
30
31 from . import __version__
32 from _location import Database, DATABASE_VERSION_LATEST
33
# File name of the xz-compressed database that is requested from a mirror
DATABASE_FILENAME = "location.db.xz"

# Mirror base URLs that are used when the caller does not supply its own
MIRRORS = ("https://location.ipfire.org/databases/",)

# Logger for this module; its records are passed on to the parent loggers
log = logging.getLogger("location.downloader")
log.propagate = 1
42
class Downloader(object):
    """
    Fetches the xz-compressed location database from a list of mirrors,
    decompresses it into a temporary file and checks it before handing
    the file back to the caller.
    """

    def __init__(self, version=DATABASE_VERSION_LATEST, mirrors=None):
        """
        version: database version; it becomes the path component in front
            of the database file name when the download URL is built
        mirrors: iterable of mirror base URLs; MIRRORS is used when this
            is None or empty
        """
        self.version = version

        # Set mirrors or use defaults
        self.mirrors = list(mirrors or MIRRORS)

        # Randomize mirrors
        random.shuffle(self.mirrors)

        # Get proxies from environment
        self.proxies = self._get_proxies()

    def _get_proxies(self):
        """
        Returns a dict that maps "https"/"http" to the value of the
        https_proxy/http_proxy environment variable. Variables that are
        unset or empty are left out.
        """
        proxies = {}

        for protocol in ("https", "http"):
            proxy = os.environ.get("%s_proxy" % protocol)

            if proxy:
                proxies[protocol] = proxy

        return proxies

    def _make_request(self, url, baseurl=None, headers=None):
        """
        Builds a GET request for url (joined onto baseurl if given),
        carrying the given headers plus our User-Agent and the proxies
        collected from the environment.

        The headers mapping passed in is never modified.
        """
        if baseurl:
            url = urllib.parse.urljoin(baseurl, url)

        req = urllib.request.Request(url, method="GET")

        # Work on a private copy so that neither a shared default nor the
        # caller's dictionary is changed by adding the User-Agent
        headers = dict(headers or {})
        headers["User-Agent"] = "location/%s" % __version__

        # Set headers
        for header, value in headers.items():
            req.add_header(header, value)

        # Set proxies
        for protocol, proxy in self.proxies.items():
            req.set_proxy(proxy, protocol)

        return req

    def _send_request(self, req, **kwargs):
        """
        Sends req using urllib.request.urlopen() (kwargs are passed
        through) and returns the response object.

        Request and response headers are logged at debug level. A
        urllib.error.HTTPError is logged and then re-raised.
        """
        # Log request headers
        log.debug("HTTP %s Request to %s", req.method, req.host)
        log.debug(" URL: %s", req.full_url)
        log.debug(" Headers:")
        for k, v in req.header_items():
            log.debug(" %s: %s", k, v)

        try:
            res = urllib.request.urlopen(req, **kwargs)

        except urllib.error.HTTPError as e:
            # Log response headers
            log.debug("HTTP Response: %s", e.code)
            log.debug(" Headers:")
            for header in e.headers:
                log.debug(" %s: %s", header, e.headers[header])

            # Hand the error on to the caller with its traceback intact
            raise

        # Log response headers
        log.debug("HTTP Response: %s", res.code)
        log.debug(" Headers:")
        for k, v in res.headers.items():
            log.debug(" %s: %s", k, v)

        return res

    def download(self, public_key, timestamp=None, tmpdir=None, **kwargs):
        """
        Tries every mirror in turn until one delivers a database that
        passes _check_database().

        public_key: path of the public key file used for verification
        timestamp: optional datetime-like object; it is sent as
            If-Modified-Since and databases created before it are rejected
            NOTE(review): the header value is labelled "GMT" without any
            conversion, so this presumably expects UTC - confirm
        tmpdir: directory in which the temporary file is created
        kwargs: accepted, but currently not used

        Returns the temporary file object holding the decompressed
        database. It has already been closed when it is returned, but the
        file still exists on disk (see its name attribute).

        Raises FileNotFoundError if no mirror delivered a usable database.
        Errors that are not handled here (e.g. urllib.error.URLError)
        propagate to the caller. In both cases the temporary file is
        removed.
        """
        url = "%s/%s" % (self.version, DATABASE_FILENAME)

        headers = {}
        if timestamp:
            headers["If-Modified-Since"] = timestamp.strftime(
                "%a, %d %b %Y %H:%M:%S GMT",
            )

        t = tempfile.NamedTemporaryFile(dir=tmpdir, delete=False)

        # Becomes True only when the file is handed over to the caller
        keep = False

        try:
            with t:
                # Try all mirrors
                for mirror in self.mirrors:
                    # Start every attempt with an empty file. truncate()
                    # cuts at the current position, so we have to rewind
                    # first or data of a failed attempt would stay in place
                    t.seek(0)
                    t.truncate()

                    # Prepare HTTP request
                    req = self._make_request(url, baseurl=mirror, headers=headers)

                    try:
                        with self._send_request(req) as res:
                            decompressor = lzma.LZMADecompressor()

                            # Read all data
                            while True:
                                buf = res.read(1024)
                                if not buf:
                                    break

                                # Decompress data
                                buf = decompressor.decompress(buf)
                                if buf:
                                    t.write(buf)

                        # Write all data to disk
                        t.flush()

                    # Catch decompression errors
                    except lzma.LZMAError as e:
                        log.warning("Could not decompress downloaded file: %s", e)
                        continue

                    except urllib.error.HTTPError as e:
                        # The file on the server was too old
                        if e.code == 304:
                            log.warning("%s is serving an outdated database. Trying next mirror...", mirror)

                        # Log any other HTTP errors
                        else:
                            log.warning("%s reported: %s", mirror, e)

                        continue

                    # Check if the downloaded database is recent
                    if not self._check_database(t, public_key, timestamp):
                        log.warning("Downloaded database is outdated. Trying next mirror...")
                        continue

                    # Make the file readable for everyone
                    os.chmod(t.name, stat.S_IRUSR|stat.S_IRGRP|stat.S_IROTH)

                    # Return temporary file
                    keep = True
                    return t

        finally:
            # Delete the temporary file after unsuccessful downloads and
            # when an unexpected error is on its way to the caller
            if not keep:
                os.unlink(t.name)

        raise FileNotFoundError(url)

    def _check_database(self, f, public_key, timestamp=None):
        """
        Checks the downloaded database if it can be opened,
        verified and if it is recent enough

        f: file object of the download; only its name is used
        public_key: path of the public key file
        timestamp: optional datetime-like object; a database created
            before it is rejected

        Returns True if the database passed all checks, False otherwise.
        """
        log.debug("Opening downloaded database at %s", f.name)

        db = Database(f.name)

        # Database is not recent
        if timestamp and db.created_at < timestamp.timestamp():
            return False

        log.info("Downloaded new database from %s", time.strftime(
            "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(db.created_at),
        ))

        # Verify the database
        with open(public_key, "r") as key:
            if not db.verify(key):
                log.error("Could not verify database")
                return False

        return True