]>
Commit | Line | Data |
---|---|---|
857a7836 | 1 | #!/usr/bin/python3 |
9137135a | 2 | |
d48f75f7 MT |
3 | import asyncio |
4 | import base64 | |
5 | import binascii | |
6 | import cryptography.hazmat.backends | |
7 | import cryptography.hazmat.primitives.asymmetric.ec | |
8 | import cryptography.hazmat.primitives.asymmetric.utils | |
9 | import cryptography.hazmat.primitives.ciphers | |
10 | import cryptography.hazmat.primitives.ciphers.aead | |
11 | import cryptography.hazmat.primitives.hashes | |
12 | import cryptography.hazmat.primitives.kdf.hkdf | |
13 | import cryptography.hazmat.primitives.serialization | |
98b826d4 | 14 | import datetime |
b9c2a52b | 15 | import email.utils |
d48f75f7 | 16 | import json |
857a7836 | 17 | import ldap |
9137135a | 18 | import logging |
d48f75f7 | 19 | import os |
69922f6b | 20 | import pickle |
d48f75f7 | 21 | import struct |
fa598127 | 22 | import threading |
857a7836 | 23 | import time |
d48f75f7 | 24 | import urllib.parse |
9137135a MT |
25 | |
26 | import tornado.locale | |
27 | ||
2c909128 | 28 | from . import base |
c8550381 | 29 | from . import bugtracker |
d48f75f7 | 30 | from . import httpclient |
9137135a | 31 | |
b9c2a52b MT |
32 | from .decorators import * |
33 | ||
# Storage quota that Users.create() assigns when the caller does not pass one
DEFAULT_STORAGE_QUOTA = 256 * 1024 * 1024 # 256 MiB

# Setup logging
log = logging.getLogger("pbs.users")

# A list of LDAP attributes that we fetch
# (passed as attrlist when User.attrs looks up the LDAP entry of a user)
LDAP_ATTRS = (
    # UID
    "uid",

    # Common Name
    "cn",

    # First & Last Name
    "givenName", "sn",

    # Email Addresses
    "mail",
    "mailAlternateAddress",
)
54 | ||
# Common table expression "user_build_times" (to be placed after WITH):
# one row per non-deleted user who has a daily build quota set, where "used"
# is the summed runtime of their jobs that finished within the last 24 hours.
# Users without any such job do not show up at all because the WHERE clause
# filters on columns of the joined jobs table.
WITH_USED_BUILD_TIME_CTE = """
    user_build_times AS (
        SELECT
            users.id AS user_id,
            SUM(jobs.finished_at - jobs.started_at) AS used
        FROM
            users
        LEFT JOIN
            builds ON users.id = builds.owner_id
        LEFT JOIN
            jobs ON builds.id = jobs.build_id
        WHERE
            users.deleted_at IS NULL
        AND
            users.daily_build_quota IS NOT NULL
        AND
            jobs.started_at IS NOT NULL
        AND
            jobs.finished_at IS NOT NULL
        AND
            jobs.finished_at >= CURRENT_TIMESTAMP - INTERVAL '24 hours'
        GROUP BY
            users.id
    )
"""

# Extends the CTE above by "users_with_exceeded_quotas" which holds all users
# whose used build time has reached or passed their daily build quota.
# The %s placeholder is filled with WITH_USED_BUILD_TIME_CTE at import time.
WITH_EXCEEDED_QUOTAS_CTE = """
    -- Include used build time
    %s,

    users_with_exceeded_quotas AS (
        SELECT
            *
        FROM
            user_build_times build_times
        LEFT JOIN
            users ON build_times.user_id = users.id
        WHERE
            users.daily_build_quota IS NOT NULL
        AND
            build_times.used >= users.daily_build_quota
    )
""" % WITH_USED_BUILD_TIME_CTE
98 | ||
class QuotaExceededError(Exception):
    """
        Raised by User.check_storage_quota() when a user is over their storage quota
    """
101 | ||
9137135a | 102 | class Users(base.Object): |
def init(self):
    """
        Sets up the thread-local storage which holds the LDAP connection
        of each thread (see the ldap property)
    """
    self.local = threading.local()
857a7836 | 106 | |
@property
def ldap(self):
    """
        The LDAP connection of the calling thread.

        The connection is opened on first access and then kept
        in the thread-local storage.
    """
    try:
        return self.local.ldap
    except AttributeError:
        pass

    # Fetch the LDAP URI
    uri = self.backend.config.get("ldap", "uri")

    log.debug("Connecting to %s..." % uri)

    # Establish the connection and remember it for this thread
    self.local.ldap = ldap.initialize(uri)

    return self.local.ldap
8d8d65b4 | 119 | |
b9c2a52b MT |
120 | def _get_user(self, query, *args): |
121 | res = self.db.get(query, *args) | |
9137135a | 122 | |
b9c2a52b MT |
123 | if res: |
124 | return User(self.backend, res.id, data=res) | |
9137135a | 125 | |
b9c2a52b MT |
126 | def _get_users(self, query, *args): |
127 | res = self.db.query(query, *args) | |
8d8d65b4 | 128 | |
b9c2a52b MT |
129 | for row in res: |
130 | yield User(self.backend, row.id, data=row) | |
9137135a | 131 | |
b9c2a52b | 132 | def __iter__(self): |
857a7836 MT |
133 | users = self._get_users(""" |
134 | SELECT | |
135 | * | |
136 | FROM | |
137 | users | |
138 | WHERE | |
6e0c9ad4 | 139 | deleted_at IS NULL |
857a7836 MT |
140 | ORDER BY |
141 | name | |
142 | """, | |
143 | ) | |
8d8d65b4 | 144 | |
b9c2a52b | 145 | return iter(users) |
8d8d65b4 | 146 | |
b9c2a52b | 147 | def __len__(self): |
857a7836 MT |
148 | res = self.db.get(""" |
149 | SELECT | |
150 | COUNT(*) AS count | |
151 | FROM | |
152 | users | |
153 | WHERE | |
6e0c9ad4 | 154 | deleted_at IS NULL |
857a7836 MT |
155 | """, |
156 | ) | |
8d8d65b4 | 157 | |
b9c2a52b | 158 | return res.count |
9137135a | 159 | |
def _ldap_query(self, query, attrlist=None, limit=0, search_base=None):
    """
        Performs a paged LDAP subtree search and returns the attributes
        of all entries that matched.

        query       - the LDAP filter (any values must already be escaped)
        attrlist    - the attributes to request (passed on to python-ldap)
        limit       - maximum number of entries; 0 or None means no limit
        search_base - the DN to search below; defaults to the "base"
                      setting of the "ldap" configuration section

        Returns a list with one attribute dictionary per entry (the DN is dropped).
    """
    # Only fall back to the configured base if the caller did not pass one
    if search_base is None:
        search_base = self.backend.config.get("ldap", "base")

    # The size limit is handed to python-ldap as an integer, so map None to 0
    sizelimit = limit or 0

    log.debug("Performing LDAP query (%s): %s", search_base, query)

    started_at = time.time()

    # Ask for up to 512 results being returned at a time
    page_control = ldap.controls.SimplePagedResultsControl(True, size=512, cookie="")

    results = []
    pages = 0

    # Perform the search page by page
    while True:
        msgid = self.ldap.search_ext(search_base,
            ldap.SCOPE_SUBTREE, query, attrlist=attrlist, sizelimit=sizelimit,
            serverctrls=[page_control],
        )

        # Fetch all results of this page (result type and message ID are unused)
        _rtype, data, _rmsgid, serverctrls = self.ldap.result3(msgid)

        # Append to local copy
        results += data
        pages += 1

        controls = [c for c in serverctrls
            if c.controlType == ldap.controls.SimplePagedResultsControl.controlType]

        # The server did not send a paging control back, so this was all
        if not controls:
            break

        # Set the cookie for more results
        page_control.cookie = controls[0].cookie

        # There are no more results
        if not page_control.cookie:
            break

    # Log time it took to perform the query
    log.debug("Query took %.2fms (%s page(s))",
        (time.time() - started_at) * 1000.0, pages)

    # Return all attributes (without the DN)
    return [attrs for dn, attrs in results]
f6e6ff79 | 205 | |
857a7836 MT |
206 | def _ldap_get(self, *args, **kwargs): |
207 | results = self._ldap_query(*args, **kwargs) | |
f6e6ff79 | 208 | |
857a7836 MT |
209 | # No result |
210 | if not results: | |
211 | return {} | |
f6e6ff79 | 212 | |
857a7836 MT |
213 | # Too many results? |
214 | elif len(results) > 1: | |
71d98773 | 215 | raise OverflowError("Too many results returned for ldap_get()") |
abac2d48 | 216 | |
857a7836 | 217 | return results[0] |
abac2d48 | 218 | |
def create(self, name, notify=False, storage_quota=None, _attrs=None):
    """
        Creates a new user and returns it

        name          - the username
        notify        - send the welcome email after creation
        storage_quota - the storage quota, DEFAULT_STORAGE_QUOTA if None
        _attrs        - attributes which are stored pickled with the user
                        (User.attrs prefers them over an LDAP lookup)
    """
    # Pickle any attrs
    attrs = pickle.dumps(_attrs) if _attrs else _attrs

    # Set default for storage quota
    quota = DEFAULT_STORAGE_QUOTA if storage_quota is None else storage_quota

    query = """
        INSERT INTO
            users
        (
            name,
            storage_quota,
            _attrs
        )
        VALUES
        (
            %s, %s, %s
        )
        RETURNING
            *
    """
    user = self._get_user(query, name, quota, attrs)

    log.debug("Created user %s" % user)

    # Send a welcome email
    if notify:
        user._send_welcome_email()

    return user
040fc249 | 255 | |
def get_by_id(self, id):
    """
        Fetches a user by its database ID (returns None if there is no such user)
    """
    query = "SELECT * FROM users WHERE id = %s"

    return self._get_user(query, id)
040fc249 | 259 | |
def get_by_name(self, name):
    """
        Fetch a user by its username

        Looks for a local user first. If there is none, the username is
        looked up in LDAP and a local user is created for it on the fly.
        Returns None if the user is unknown (or when running in test mode
        where LDAP is not used).
    """
    # Try to find a local user
    user = self._get_user("""
        SELECT
            *
        FROM
            users
        WHERE
            deleted_at IS NULL
        AND
            name = %s
        """, name,
    )
    if user:
        return user

    # Do nothing in test mode
    if self.backend.test:
        log.warning("Cannot use get_by_name in test mode")
        return

    # The name ends up inside an LDAP filter and therefore must be escaped
    from ldap.filter import escape_filter_chars

    # Search in LDAP
    res = self._ldap_get(
        "(&"
            "(objectClass=person)"
            "(uid=%s)"
        ")" % escape_filter_chars(name),
        attrlist=("uid",),
    )
    if not res:
        return

    # Fetch the UID
    uid = res.get("uid")[0].decode()

    # Create a new user
    return self.create(uid)
f6e6ff79 | 300 | |
def get_by_email(self, mail):
    """
        Fetch a user by one of its email addresses

        The address may come in MIME format ("Name <address>"). It is looked
        up in LDAP (primary and alternate addresses) and the matching user is
        then fetched by its UID. Returns None if nobody could be found, and
        raises OverflowError if the address matches more than one LDAP entry.
    """
    # Strip any excess stuff from the email address
    name, mail = email.utils.parseaddr(mail)

    # Nothing to search for if the address could not be parsed
    if not mail:
        return

    # Do nothing in test mode
    if self.backend.test:
        log.warning("Cannot use get_by_email in test mode")
        return

    # The address ends up inside an LDAP filter and therefore must be escaped
    from ldap.filter import escape_filter_chars

    escaped = escape_filter_chars(mail)

    # Search in LDAP
    try:
        res = self._ldap_get(
            "(&"
                "(objectClass=person)"
                "(|"
                    "(mail=%s)"
                    "(mailAlternateAddress=%s)"
                ")"
            ")" % (escaped, escaped),
            attrlist=("uid",),
        )

    except OverflowError as e:
        raise OverflowError("Too many results for search for %s" % mail) from e

    # No results
    if not res:
        return

    # Fetch the UID
    uid = res.get("uid")[0].decode()

    return self.get_by_name(uid)
334 | ||
ffd6438e MT |
335 | def _search_by_email(self, mails, include_missing=True): |
336 | """ | |
337 | Takes a list of email addresses and returns all users that could be found | |
338 | """ | |
339 | users = [] | |
340 | ||
341 | for mail in mails: | |
342 | user = self.get_by_email(mail) | |
343 | ||
344 | # Include the search string if no user could be found | |
345 | if not user and include_missing: | |
346 | user = mail | |
347 | ||
348 | # Skip any duplicates | |
349 | if user in users: | |
350 | continue | |
351 | ||
352 | users.append(user) | |
353 | ||
354 | return users | |
355 | ||
def search(self, q, limit=None):
    """
        Searches LDAP for people matching q and returns the matching
        local users as a sorted list.

        q is compared against the UID and the email addresses and is
        searched for as a substring of the common name. It is treated as
        literal text, i.e. LDAP wildcards in q have no special meaning.
        limit is the maximum number of LDAP entries (None for no limit).
        Always returns an empty list in test mode.
    """
    # Do nothing in test mode
    if self.backend.test:
        log.warning("Cannot search for users in test mode")
        return []

    # The search string ends up inside an LDAP filter and therefore must be escaped
    from ldap.filter import escape_filter_chars

    escaped = escape_filter_chars(q)

    res = self._ldap_query(
        "(&"
            "(objectClass=person)"
            "(|"
                "(uid=%s)"
                "(cn=*%s*)"
                "(mail=%s)"
                "(mailAlternateAddress=%s)"
            ")"
        ")" % (escaped, escaped, escaped, escaped),
        attrlist=("uid",),
        # The size limit has to be a number, 0 means no limit
        limit=limit or 0,
    )

    # Fetch users
    users = self._get_users("""
        SELECT
            *
        FROM
            users
        WHERE
            deleted_at IS NULL
        AND
            name = ANY(%s)
        """, [row.get("uid")[0].decode() for row in res],
    )

    return sorted(users)
efbd7501 | 390 | |
1648678e MT |
391 | # Pakfire |
392 | ||
@property
def pakfire(self):
    """
        The user that is looked up by the email address pakfire@ipfire.org

        Raises RuntimeError if that user could not be found.
    """
    user = self.get_by_email("pakfire@ipfire.org")

    if user:
        return user

    raise RuntimeError("Missing Pakfire user")
403 | ||
@property
def top(self):
    """
        Returns up to 30 users, ordered by their number of non-test builds

        NOTE(review): the query counts all builds and has no time restriction,
        although this used to be documented as "in the last year" - confirm
        which of the two is intended.
    """
    query = """
        SELECT
            DISTINCT users.*,
            COUNT(builds.id) AS _sort
        FROM
            users
        LEFT JOIN
            builds ON users.id = builds.owner_id
        WHERE
            builds.test IS FALSE
        GROUP BY
            users.id
        ORDER BY
            _sort DESC
        LIMIT
            30
    """

    return [user for user in self._get_users(query)]
429 | ||
d48f75f7 MT |
430 | # Push Notifications |
431 | ||
@property
def vapid_public_key(self):
    """
        The public part of the VAPID key as stored in the settings
        under "vapid-public-key"
    """
    public_key = self.settings.get("vapid-public-key")

    return public_key
438 | ||
@property
def vapid_private_key(self):
    """
        The private part of the VAPID key as stored in the settings
        under "vapid-private-key"
    """
    private_key = self.settings.get("vapid-private-key")

    return private_key
445 | ||
async def generate_vapid_keys(self):
    """
        Generates the VAPID key pair and stores it in the settings

        Does nothing if both keys already exist. The keys are created by
        calling the openssl command (prime256v1 curve) inside a database
        transaction.
    """
    # Do not generate a new key if one exists
    if self.vapid_public_key and self.vapid_private_key:
        return

    genkey = ("openssl", "ecparam", "-name", "prime256v1", "-genkey", "-noout")
    pubout = ("openssl", "ec", "-pubout")

    with self.db.transaction():
        # Generate the private key
        private_key = await self.backend.command(*genkey, return_output=True)

        # Derive the public key from the private key
        public_key = await self.backend.command(
            *pubout, input=private_key, return_output=True,
        )

        # Store the keys
        self.settings.set("vapid-public-key", public_key)
        self.settings.set("vapid-private-key", private_key)

        log.info("VAPID keys have been successfully generated")
479 | ||
@property
def application_server_key(self):
    """
        The public VAPID key in the form that is handed to clients:
        the last 65 bytes of the decoded PEM body, encoded as URL-safe
        Base64 without padding (returned as bytes).
    """
    # Drop the "-----BEGIN/END-----" lines and join the Base64 body
    body = "".join(
        line for line in self.vapid_public_key.splitlines() if not line.startswith("-")
    )

    # Decode the key and only take the last bit
    raw = base64.b64decode(body)[-65:]

    # Encode the key URL-safe
    return base64.urlsafe_b64encode(raw).strip(b"=")
503 | ||
9137135a | 504 | |
b9c2a52b MT |
505 | class User(base.DataObject): |
506 | table = "users" | |
9137135a | 507 | |
20d7f5eb MT |
508 | def __repr__(self): |
509 | return "<%s %s>" % (self.__class__.__name__, self.realname) | |
510 | ||
367bfa3a | 511 | def __str__(self): |
dcfada36 | 512 | return self.realname or self.name |
367bfa3a | 513 | |
b0315eb4 MT |
514 | def __hash__(self): |
515 | return hash(self.id) | |
516 | ||
b9c2a52b MT |
517 | def __lt__(self, other): |
518 | if isinstance(other, self.__class__): | |
519 | return self.name < other.name | |
f6e6ff79 | 520 | |
b9c2a52b MT |
521 | elif isinstance(other, str): |
522 | return self.name < other | |
f6e6ff79 | 523 | |
cd849b46 MT |
524 | return NotImplemented |
525 | ||
def to_json(self):
    """
        Returns the JSON-serialisable representation of this user
        (currently only the username)
    """
    return dict(name=self.name)
530 | ||
@property
def name(self):
    """
        The username as stored in the database
    """
    row = self.data

    return row.name
534 | ||
def delete(self):
    """
        Marks this user as deleted and destroys all of their sessions
    """
    # NOTE(review): this sets "deleted" whereas all queries in this module
    # filter on "deleted_at IS NULL" - confirm which column is current.
    self._set_attribute("deleted", True)

    # Destroy all sessions
    active_sessions = self.sessions
    for active_session in active_sessions:
        active_session.destroy()
541 | ||
857a7836 | 542 | # Fetch any attributes from LDAP |
9137135a | 543 | |
@lazy_property
def attrs(self):
    """
        The LDAP attributes of this user

        If attributes have been stored with the user (only used in the test
        environment), those are used. Otherwise the entry is fetched from LDAP.
    """
    stored = self.data._attrs

    if stored:
        # NOTE(review): pickle.loads() can execute code when fed with crafted
        # data, so this column must only ever be written by Users.create().
        return pickle.loads(stored)

    ldap_filter = "(uid=%s)" % self.name

    return self.backend.users._ldap_get(ldap_filter, attrlist=LDAP_ATTRS)
9137135a | 551 | |
857a7836 MT |
552 | def _get_attrs(self, key): |
553 | return [v.decode() for v in self.attrs.get(key, [])] | |
9137135a | 554 | |
857a7836 MT |
555 | def _get_attr(self, key): |
556 | for value in self._get_attrs(key): | |
557 | return value | |
9137135a | 558 | |
857a7836 | 559 | # Realname |
f6e6ff79 | 560 | |
@property
def realname(self):
    """
        The common name (cn) of the user or an empty string if there is none
    """
    common_name = self._get_attr("cn")

    return common_name if common_name else ""
b9c2a52b | 564 | |
@property
def email(self):
    """
        The primary email address (first "mail" attribute), None if there is none
    """
    address = self._get_attr("mail")

    return address
9137135a | 571 | |
@property
def email_to(self):
    """
        The name/email address of the user in MIME format

        Falls back to the username if there is no real name and to
        invalid@invalid.tld if there is no email address.
    """
    display_name = self.realname or self.name
    address = self.email or "invalid@invalid.tld"

    return email.utils.formataddr((display_name, address))
23f86aae | 581 | |
def send_email(self, *args, **kwargs):
    """
        Sends a templated message to this user

        All arguments are passed on to messages.send_template() with this
        user as recipient and the user's locale.
    """
    messages = self.backend.messages

    return messages.send_template(*args, recipient=self, locale=self.locale, **kwargs)
68dd077d | 589 | |
18132fad MT |
590 | def _send_welcome_email(self): |
591 | """ | |
592 | Sends a welcome email to the user | |
593 | """ | |
594 | self.send_email("users/messages/welcome.txt") | |
595 | ||
def is_admin(self):
    """
        Returns True only if the admin flag of this user is exactly True
    """
    flag = self.data.admin

    return flag is True
9137135a | 598 | |
857a7836 | 599 | # Locale |
4947da2d | 600 | |
@property
def locale(self):
    """
        The locale used for this user, i.e. whatever tornado.locale.get()
        returns when called without any locale codes
    """
    user_locale = tornado.locale.get()

    return user_locale
9137135a | 604 | |
@property
def deleted(self):
    """
        The "deleted" column of this user
    """
    row = self.data

    return row.deleted
608 | ||
29256a69 | 609 | # Avatar |
9bf767c3 | 610 | |
def avatar(self, size=512):
    """
        Returns the URL of the user's avatar on people.ipfire.org
        with the requested size passed as query argument
    """
    return f"https://people.ipfire.org/users/{self.name}.jpg?size={size}"
9137135a | 616 | |
09d78b55 MT |
617 | # Permissions |
618 | ||
def get_perms(self):
    """
        Returns the permissions stored for this user
    """
    row = self.data

    return row.perms

def set_perms(self, perms):
    """
        Stores the given permissions (anything empty is stored as an empty list)
    """
    value = perms if perms else []

    self._set_attribute("perms", value)

perms = property(get_perms, set_perms)
f6e6ff79 | 626 | |
def has_perm(self, user):
    """
        Checks whether the given user may perform administrative
        operations on this user.

        That is the case for admins and for the user themselves,
        but never for anonymous visitors (None).
    """
    # Anonymous people have no permission
    if user is None:
        return False

    # Admins always have permission and users can edit themselves
    if user.is_admin() or user == self:
        return True

    # No permission
    return False
f6e6ff79 | 646 | |
@property
def sessions(self):
    """
        Returns a list of all sessions of this user which are still valid,
        oldest first
    """
    query = """
        SELECT
            *
        FROM
            sessions
        WHERE
            user_id = %s
        AND
            valid_until >= NOW()
        ORDER BY
            created_at
    """

    return [s for s in self.backend.sessions._get_sessions(query, self.id)]
b0315eb4 | 664 | |
c8550381 MT |
665 | # Bugzilla |
666 | ||
async def connect_to_bugzilla(self, api_key):
    """
        Stores the given Bugzilla API key for this user

        The key is only accepted if Bugzilla reports it as belonging to
        the primary email address of this user, otherwise ValueError is raised.
    """
    bz = bugtracker.Bugzilla(self.backend, api_key)

    # Does the API key match with this user?
    own_address = self.email
    key_owner = await bz.whoami()

    if own_address != key_owner:
        raise ValueError("The API key does not belong to %s" % self)

    self._set_attribute("bugzilla_api_key", api_key)
675 | ||
@lazy_property
def bugzilla(self):
    """
        Connection to Bugzilla as this user

        None if the user has not stored an API key.
    """
    api_key = self.data.bugzilla_api_key
    if not api_key:
        return None

    # NOTE: the original passes self.data.bugzilla_api_key read a second time;
    # the value is the same row attribute.
    return bugtracker.Bugzilla(self.backend, self.data.bugzilla_api_key)
683 | ||
bffb4fc5 MT |
684 | # Build Quota |
685 | ||
def get_daily_build_quota(self):
    """
        Returns the daily build quota of this user (None if there is none)
    """
    row = self.data

    return row.daily_build_quota

def set_daily_build_quota(self, quota):
    """
        Stores a new daily build quota
    """
    self._set_attribute("daily_build_quota", quota)

daily_build_quota = property(get_daily_build_quota, set_daily_build_quota)
693 | ||
@property
def _build_times(self):
    """
        Returns this user's row of the user_build_times CTE
        (see WITH_USED_BUILD_TIME_CTE), or nothing if there is no such row.
    """
    # The CTE is called WITH_USED_BUILD_TIME_CTE; this used to reference
    # an undefined name and failed with a NameError on every access.
    return self.db.get("""
        WITH %s

        SELECT
            *
        FROM
            user_build_times
        WHERE
            user_build_times.user_id = %%s
        """ % WITH_USED_BUILD_TIME_CTE, self.id,
    )
707 | ||
@property
def used_daily_build_quota(self):
    """
        Returns the build time this user has used according to the
        user_build_times CTE, or 0 if the CTE has no row for this user
    """
    query = """
        WITH %s

        SELECT
            user_build_times.used AS used
        FROM
            user_build_times
        WHERE
            user_build_times.user_id = %%s
    """ % WITH_USED_BUILD_TIME_CTE

    row = self.db.get(query, self.id)

    return row.used if row else 0
726 | ||
def has_exceeded_build_quota(self):
    """
        Returns True if this user has used up their daily build quota

        Users without a quota and users who have not used any build
        time never exceed their quota.
    """
    quota = self.daily_build_quota
    if not quota:
        return False

    # Fetch the used build time only once as every access runs a query
    used = self.used_daily_build_quota
    if not used:
        return False

    return used >= quota
bffb4fc5 | 735 | |
2275518a | 736 | # Storage Quota |
50533a78 | 737 | |
def get_storage_quota(self):
    """
        Returns the storage quota of this user
    """
    row = self.data

    return row.storage_quota

def set_storage_quota(self, quota):
    """
        Stores a new storage quota
    """
    self._set_attribute("storage_quota", quota)

storage_quota = property(get_storage_quota, set_storage_quota)
50533a78 | 745 | |
def has_exceeded_storage_quota(self, size=None):
    """
        Returns True if this user has exceeded their quota

        size is the size of something that is about to be stored and is
        added on top of the current disk usage. Users without a quota
        never exceed it.
    """
    quota = self.storage_quota

    # Skip quota check if this user has no quota
    if not quota:
        return False

    # The disk usage may come back empty for users who do not own anything
    usage = self.disk_usage or 0

    return usage + (size or 0) >= quota
50533a78 | 755 | |
def check_storage_quota(self, size=None):
    """
        Raises QuotaExceededError if the user is over their storage quota,
        optionally taking into account size bytes that are about to be added.
        Returns nothing otherwise.
    """
    over_quota = self.has_exceeded_storage_quota(size=size)

    if over_quota:
        raise QuotaExceededError
764 | ||
@lazy_property
def disk_usage(self):
    """
        Returns the total disk usage of this user

        This is the sum of the sizes of all unexpired uploads, the source and
        binary packages of all non-test builds and all build logs that have
        not been deleted. Returns 0 if the user does not own anything.
    """
    res = self.db.get("""
        WITH objects AS (
            -- Uploads
            SELECT
                uploads.size AS size
            FROM
                uploads
            WHERE
                uploads.user_id = %s
            AND
                uploads.expires_at > CURRENT_TIMESTAMP

            UNION ALL

            -- Source Packages
            SELECT
                packages.filesize AS size
            FROM
                builds
            LEFT JOIN
                packages ON builds.pkg_id = packages.id
            WHERE
                builds.deleted_at IS NULL
            AND
                builds.owner_id = %s
            AND
                builds.test IS FALSE
            AND
                packages.deleted_at IS NULL

            UNION ALL

            -- Binary Packages
            SELECT
                packages.filesize AS size
            FROM
                builds
            LEFT JOIN
                jobs ON builds.id = jobs.build_id
            LEFT JOIN
                job_packages ON jobs.id = job_packages.job_id
            LEFT JOIN
                packages ON job_packages.pkg_id = packages.id
            WHERE
                builds.deleted_at IS NULL
            AND
                builds.owner_id = %s
            AND
                builds.test IS FALSE
            AND
                jobs.deleted_at IS NULL
            AND
                packages.deleted_at IS NULL

            UNION ALL

            -- Build Logs
            SELECT
                jobs.log_size AS size
            FROM
                jobs
            LEFT JOIN
                builds ON builds.id = jobs.build_id
            WHERE
                builds.deleted_at IS NULL
            AND
                jobs.deleted_at IS NULL
            AND
                builds.owner_id = %s
            AND
                jobs.log_size IS NOT NULL
        )

        SELECT
            SUM(size) AS disk_usage
        FROM
            objects
        """, self.id, self.id, self.id, self.id,
    )

    # SUM() always returns a row, but the value is NULL if there was
    # nothing to add up, so the value itself has to be checked, too
    if res and res.disk_usage is not None:
        return res.disk_usage

    return 0
854 | ||
a3e946d8 MT |
855 | # Builds |
856 | ||
def get_builds(self, name=None, limit=None, offset=None):
    """
        Returns the non-test builds of this user, newest first

        If name is given, the call is handed over to get_builds_by_name().
        limit and offset are passed to the query for paging.
    """
    if name:
        return self.get_builds_by_name(name, limit=limit, offset=offset)

    query = """
        SELECT
            *
        FROM
            builds
        WHERE
            deleted_at IS NULL
        AND
            test IS FALSE
        AND
            owner_id = %s
        ORDER BY
            created_at DESC
        LIMIT
            %s
        OFFSET
            %s
    """

    return [b for b in self.backend.builds._get_builds(query, self.id, limit, offset)]
885 | ||
def get_builds_by_name(self, name, limit=None, offset=None):
    """
        Fetches all non-test builds of this user for the package called name

        The builds are returned newest first (like get_builds() does) so that
        paging with limit and offset walks through them in a stable order.
    """
    builds = self.backend.builds._get_builds("""
        SELECT
            builds.*
        FROM
            builds
        JOIN
            packages ON builds.pkg_id = packages.id
        WHERE
            builds.deleted_at IS NULL
        AND
            builds.test IS FALSE
        AND
            builds.owner_id = %s
        AND
            packages.deleted_at IS NULL
        AND
            packages.name = %s
        ORDER BY
            builds.created_at DESC
        LIMIT
            %s
        OFFSET
            %s
        """, self.id, name, limit, offset,
    )

    return list(builds)
915 | ||
282e875c MT |
916 | # Stats |
917 | ||
@lazy_property
def total_builds(self):
    """
        The number of non-test builds that are owned by this user
    """
    query = """
        SELECT
            COUNT(*) AS builds
        FROM
            builds
        WHERE
            test IS FALSE
        AND
            owner_id = %s
    """
    row = self.db.get(query, self.id)

    return row.builds if row else 0
936 | ||
@lazy_property
def total_build_time(self):
    """
        The summed runtime of all finished jobs of builds owned by this user

        Returns None if the query yields nothing.
    """
    query = """
        SELECT
            SUM(jobs.finished_at - jobs.started_at) AS total_time
        FROM
            jobs
        LEFT JOIN
            builds ON jobs.build_id = builds.id
        WHERE
            jobs.finished_at IS NOT NULL
        AND
            jobs.started_at IS NOT NULL
        AND
            builds.owner_id = %s
    """
    row = self.db.get(query, self.id)

    if not row:
        return None

    return row.total_time
957 | ||
f87d64a2 MT |
958 | # Custom repositories |
959 | ||
@property
def repos(self):
    """
        Returns all custom repositories of this user as a dictionary
        which maps each distribution to the list of its repositories
        (sorted by name)
    """
    query = """
        SELECT
            *
        FROM
            repositories
        WHERE
            deleted_at IS NULL
        AND
            owner_id = %s
        ORDER BY
            name"""

    # Group by distro
    by_distro = {}
    for repo in self.backend.repos._get_repositories(query, self.id):
        by_distro.setdefault(repo.distro, []).append(repo)

    return by_distro
f87d64a2 | 989 | |
def get_repo(self, distro, slug):
    """
        Fetches the repository of this user with the given slug
        for the given distribution

        An empty slug selects the "home" repository which is
        named like the user.
    """
    # Return the "home" repository if slug is empty
    wanted = slug if slug else self.name

    query = """
        SELECT
            *
        FROM
            repositories
        WHERE
            deleted_at IS NULL
        AND
            owner_id = %s
        AND
            distro_id = %s
        AND
            slug = %s"""

    return self.backend.repos._get_repository(query, self.id, distro, wanted)
1012 | ||
@property
def uploads(self):
    """
        Returns a list of all uploads of this user that have
        not expired yet, newest first
    """
    query = """
        SELECT
            *
        FROM
            uploads
        WHERE
            user_id = %s
        AND
            expires_at > CURRENT_TIMESTAMP
        ORDER BY
            created_at DESC
    """

    return [u for u in self.backend.uploads._get_uploads(query, self.id)]
1033 | ||
d48f75f7 | 1034 | # Push Subscriptions |
b9c2a52b | 1035 | |
d48f75f7 MT |
1036 | def _get_subscriptions(self, query, *args): |
1037 | res = self.db.query(query, *args) | |
1038 | ||
1039 | for row in res: | |
1040 | yield UserPushSubscription(self.backend, row.id, data=row) | |
1041 | ||
1042 | def _get_subscription(self, query, *args): | |
1043 | res = self.db.get(query, *args) | |
1044 | ||
1045 | if res: | |
1046 | return UserPushSubscription(self.backend, res.id, data=res) | |
1047 | ||
@lazy_property
def subscriptions(self):
    """
        The set of all push subscriptions of this user
        which have not been deleted
    """
    query = """
        SELECT
            *
        FROM
            user_push_subscriptions
        WHERE
            deleted_at IS NULL
        AND
            user_id = %s
        ORDER BY
            created_at
    """

    return {s for s in self._get_subscriptions(query, self.id)}
1065 | ||
async def subscribe(self, endpoint, p256dh, auth, user_agent=None):
    """
        Creates a new push subscription for this user and returns it

        p256dh and auth may be passed as bytes or as URL-safe Base64 strings
        (padding is optional). Once the subscription has been stored, a
        greeting is sent to it.
    """
    _ = self.locale.translate

    def decode(value):
        # Bytes are taken as they are, anything else is URL-safe Base64
        if isinstance(value, bytes):
            return value

        return base64.urlsafe_b64decode(value + "==")

    # Decode p256dh & auth
    p256dh = decode(p256dh)
    auth = decode(auth)

    query = """
        INSERT INTO
            user_push_subscriptions
        (
            user_id,
            user_agent,
            endpoint,
            p256dh,
            auth
        )
        VALUES
        (
            %s, %s, %s, %s, %s
        )
        RETURNING *
    """
    subscription = self._get_subscription(
        query, self.id, user_agent, endpoint, p256dh, auth,
    )

    # Send a message
    greeting = self._make_message(
        _("Hello, %s!") % self,
        _("You have successfully subscribed to push notifications."),
    )
    await subscription.send(greeting)

    return subscription
1107 | ||
async def send_push_message(self, title, body, **kwargs):
    """
    Delivers a push message to every active subscription of this user

    Any extra keyword arguments become additional attributes of the message.

    Returns False if the user has no subscriptions, and True once the
    message has been sent to all of them.
    """
    subscriptions = self.subscriptions

    # Nothing to do for users without any subscriptions
    if not subscriptions:
        return False

    # Build the message only once and share it between all subscriptions
    message = self._make_message(title, body, **kwargs)

    # Send to all subscriptions concurrently; leaving the task group
    # waits until every task has finished
    async with asyncio.TaskGroup() as tg:
        for subscription in subscriptions:
            tg.create_task(subscription.send(message))

    return True
1125 | ||
1126 | def _make_message(self, title, body, **kwargs): | |
1127 | """ | |
1128 | Formats a push message with the given attributes and some useful defaults | |
1129 | """ | |
1130 | # Format the message as a dictionary | |
1131 | message = { | |
1132 | "title" : title, | |
1133 | "body" : body, | |
1134 | } | kwargs | |
1135 | ||
1136 | return message | |
1137 | ||
d48f75f7 MT |
1138 | |
class UserPushSubscription(base.DataObject):
    """
    One push subscription of a user, backed by a row in the
    user_push_subscriptions table

    Besides the stored attributes, this class encrypts and signs messages
    and delivers them to the subscription's endpoint.
    """
    table = "user_push_subscriptions"
1141 | ||
@property
def uuid(self):
    """
    The UUID of this subscription, read from the stored data
    """
    return self.data.uuid
1148 | ||
@property
def created_at(self):
    """
    The created_at value of this subscription, read from the stored data
    """
    return self.data.created_at
1152 | ||
@property
def deleted_at(self):
    """
    The deleted_at value of this subscription, read from the stored data

    Queries in this file treat a NULL value as "not deleted".
    """
    return self.data.deleted_at
1156 | ||
def delete(self):
    """
    Deletes this subscription

    The subscription is only marked as deleted through its deleted_at
    attribute.
    """
    # NOTE(review): presumably stamps deleted_at with the current time -
    # confirm against base.DataObject._set_attribute_now()
    self._set_attribute_now("deleted_at")
1162 | ||
@property
def endpoint(self):
    """
    The URL of the push service that messages for this subscription
    are POSTed to
    """
    return self.data.endpoint
1166 | ||
@lazy_property
def p256dh(self):
    """
    The client's public key, parsed from its stored encoded point
    into a public key object on the SECP256R1 curve
    """
    ec = cryptography.hazmat.primitives.asymmetric.ec

    # bytes() normalises whatever buffer type the database returned
    # (presumably a memoryview - verify against the database driver)
    encoded_point = bytes(self.data.p256dh)

    return ec.EllipticCurvePublicKey.from_encoded_point(
        ec.SECP256R1(), encoded_point,
    )
1177 | ||
@property
def auth(self):
    """
    The stored auth value of this subscription, converted to bytes

    It is used as the salt of the first key derivation in _encrypt().
    """
    return bytes(self.data.auth)
1181 | ||
@property
def vapid_private_key(self):
    """
    The VAPID private key, loaded from the unencrypted PEM string that the
    users backend provides

    The key is parsed again on every access.
    """
    pem = self.backend.users.vapid_private_key.encode()

    return cryptography.hazmat.primitives.serialization.load_pem_private_key(
        pem,
        password=None,
        backend=cryptography.hazmat.backends.default_backend(),
    )
1189 | ||
@property
def vapid_public_key(self):
    """
    The public key that belongs to the VAPID private key
    """
    return self.vapid_private_key.public_key()
1193 | ||
async def send(self, message, ttl=0):
    """
    Encrypts a message and delivers it to the endpoint of this subscription

    message may be a string (which is wrapped into {"message" : ...}),
    a dictionary (which is serialised as JSON) or bytes. ttl is passed
    to the push service in the TTL header.

    If the push service responds with 410 (Gone), the subscription is
    deleted and no exception is raised. Any other HTTP error is raised.
    """
    serialization = cryptography.hazmat.primitives.serialization

    # Normalise the message to bytes: str -> dict -> JSON -> bytes
    if isinstance(message, str):
        message = {"message" : message}

    if isinstance(message, dict):
        message = json.dumps(message)

    if not isinstance(message, bytes):
        message = message.encode()

    # Encrypt the message
    payload = self._encrypt(message)

    # Create a signature
    signature = self._sign()

    # The push service needs our public key to verify the signature
    public_key = self.vapid_public_key.public_bytes(
        serialization.Encoding.X962,
        serialization.PublicFormat.UncompressedPoint,
    )
    crypto_key = self.b64encode(public_key).decode()

    # Form request headers
    headers = {
        "Authorization"    : "WebPush %s" % signature,
        "Crypto-Key"       : "p256ecdsa=%s" % crypto_key,

        "Content-Type"     : "application/octet-stream",
        "Content-Encoding" : "aes128gcm",
        "TTL"              : "%s" % (ttl or 0),
    }

    # Send the request
    try:
        await self.backend.httpclient.fetch(
            self.endpoint, method="POST", headers=headers, body=payload,
        )

    except httpclient.HTTPError as e:
        # Raise everything but 410
        if e.code != 410:
            raise

        # 410 - Gone
        # The subscription is no longer valid, so let's just delete ourselves
        self.delete()
1251 | ||
def _sign(self):
    """
    Creates the signed token that is sent in the Authorization header

    The result is a string of three dot-separated parts (the JWT header,
    the claims and the signature), each encoded as URL-safe base64
    without padding.
    """
    ec = cryptography.hazmat.primitives.asymmetric.ec
    hashes = cryptography.hazmat.primitives.hashes
    utils = cryptography.hazmat.primitives.asymmetric.utils

    # Serialise header and claims as compact JSON with sorted keys
    # and encode both as base64
    parts = [
        self.b64encode(
            json.dumps(part, separators=(',', ':'), sort_keys=True).encode(),
        )
        for part in (self._jwt_info, self._jwt_data)
    ]

    # This is what gets signed
    token = b".".join(parts)

    log.debug("String to sign: %s", token)

    # Sign the token using ECDSA with SHA-256
    der_signature = self.vapid_private_key.sign(
        token, ec.ECDSA(hashes.SHA256()),
    )

    # The signature is returned in DER format, but the token carries
    # r and s as two integers of 32 bytes each
    r, s = utils.decode_dss_signature(der_signature)
    raw_signature = self._num_to_bytes(r, 32) + self._num_to_bytes(s, 32)

    # Put everything together
    signature = b".".join((token, self.b64encode(raw_signature))).decode()

    log.debug("Created signature: %s", signature)

    return signature
1295 | ||
1296 | _jwt_info = { | |
1297 | "typ" : "JWT", | |
1298 | "alg" : "ES256", | |
1299 | } | |
1300 | ||
1301 | @property | |
1302 | def _jwt_data(self): | |
1303 | # Parse the URL | |
1304 | url = urllib.parse.urlparse(self.endpoint) | |
1305 | ||
1306 | # Let the signature expire after 12 hours | |
1307 | expires = time.time() + (12 * 3600) | |
1308 | ||
1309 | return { | |
1310 | "aud" : "%s://%s" % (url.scheme, url.netloc), | |
1311 | "exp" : int(expires), | |
1312 | "sub" : "mailto:info@ipfire.org", | |
1313 | } | |
1314 | ||
1315 | @staticmethod | |
1316 | def _num_to_bytes(n, pad_to): | |
1317 | """ | |
1318 | Returns the byte representation of an integer, in big-endian order. | |
1319 | """ | |
1320 | h = "%x" % n | |
1321 | ||
1322 | r = binascii.unhexlify("0" * (len(h) % 2) + h) | |
1323 | return b"\x00" * (pad_to - len(r)) + r | |
1324 | ||
@staticmethod
def _serialize_key(key):
    """
    Serialises a key into bytes

    Elliptic curve private keys are returned as unencrypted PKCS8 in DER
    encoding, everything else is treated as a public key and returned
    as an uncompressed point in X9.62 encoding.
    """
    ec = cryptography.hazmat.primitives.asymmetric.ec
    serialization = cryptography.hazmat.primitives.serialization

    # Public keys
    if not isinstance(key, ec.EllipticCurvePrivateKey):
        return key.public_bytes(
            serialization.Encoding.X962,
            serialization.PublicFormat.UncompressedPoint,
        )

    # Private keys
    return key.private_bytes(
        serialization.Encoding.DER,
        serialization.PrivateFormat.PKCS8,
        serialization.NoEncryption(),
    )
1338 | ||
1339 | @staticmethod | |
1340 | def b64encode(data): | |
1341 | return base64.urlsafe_b64encode(data).strip(b"=") | |
1342 | ||
def _encrypt(self, message):
    """
    Encrypts a message for this subscription using the "aes128gcm"
    content encoding

    For every message, a fresh salt and an ephemeral server key on the
    SECP256R1 curve are generated. The ECDH secret shared with the client's
    public key (p256dh) is mixed with the client's auth value, and the
    encryption key and the nonce are derived from the result.

    message must be bytes. It is split into records of at most 4096 bytes
    (including the delimiter and the authentication tag of each record).

    Returns bytes: a header made of the salt, the record size and the
    server's public key, followed by all encrypted records.
    """
    ec = cryptography.hazmat.primitives.asymmetric.ec
    hashes = cryptography.hazmat.primitives.hashes
    serialization = cryptography.hazmat.primitives.serialization

    backend = cryptography.hazmat.backends.default_backend()

    def derive(length, salt, info, key_material):
        # Runs one round of HKDF with SHA-256
        hkdf = cryptography.hazmat.primitives.kdf.hkdf.HKDF(
            algorithm=hashes.SHA256(),
            length=length,
            salt=salt,
            info=info,
            backend=backend,
        )

        return hkdf.derive(key_material)

    def encode_point(public_key):
        # Serialises a public key as an uncompressed point
        return public_key.public_bytes(
            serialization.Encoding.X962,
            serialization.PublicFormat.UncompressedPoint,
        )

    # Generate some salt
    salt = os.urandom(16)

    # Every record carries one delimiter byte and a 16 byte authentication
    # tag in addition to the payload
    record_size = 4096
    chunk_size = record_size - 17

    # Generate an ephemeral server key. The curve has to be passed as an
    # instance, as the rest of this class does; generate_private_key() is
    # not documented to accept the curve class itself.
    server_private_key = ec.generate_private_key(ec.SECP256R1(), backend)
    server_public_key = server_private_key.public_key()

    # Serialise the server's public key once. It is part of the key
    # derivation context and is sent as key ID in the header.
    key_id = encode_point(server_public_key)

    # The context binds the derived secret to both public keys
    context = b"WebPush: info\x00" + encode_point(self.p256dh) + key_id

    # Perform the key exchange with ECDH
    secret = server_private_key.exchange(ec.ECDH(), self.p256dh)

    # Mix in the client's auth value
    secret = derive(32, self.auth, context, secret)

    # Derive the content encryption key
    encryption_key = derive(16, salt, b"Content-Encoding: aes128gcm\x00", secret)

    # Derive a nonce
    nonce = derive(12, salt, b"Content-Encoding: nonce\x00", secret)

    # Encrypt the message record by record. An empty message results
    # in no records at all.
    records = []

    for counter, offset in enumerate(range(0, len(message), chunk_size)):
        chunk = message[offset:offset + chunk_size]

        # Is this the last chunk?
        last = offset + chunk_size >= len(message)

        records.append(
            self._encrypt_chunk(encryption_key, nonce, counter, chunk, last),
        )

    # Join the header and all records together
    return b"".join((
        salt,
        struct.pack("!L", record_size),
        struct.pack("!B", len(key_id)),
        key_id,
        *records,
    ))
1445 | ||
def _encrypt_chunk(self, key, nonce, counter, chunk, last=False):
    """
    Encrypts one chunk using AES-128 in GCM mode

    A delimiter byte is appended to the chunk before it is encrypted:
    0x02 for the last chunk of a message and 0x01 for any other.

    Returns the ciphertext followed by the authentication tag.
    """
    ciphers = cryptography.hazmat.primitives.ciphers

    # Every chunk is encrypted with its own IV
    iv = self._make_iv(nonce, counter)

    log.debug("Encrypting chunk %s: length = %s", counter + 1, len(chunk))

    # Append the delimiter
    record = chunk + (b"\x02" if last else b"\x01")

    # Setup AES GCM and get the encryptor
    encryptor = ciphers.Cipher(
        ciphers.algorithms.AES128(key),
        ciphers.modes.GCM(iv),
        backend=cryptography.hazmat.backends.default_backend(),
    ).encryptor()

    # Encrypt the record and finalize this round
    ciphertext = encryptor.update(record) + encryptor.finalize()

    # The tag is only available after the encryptor has been finalized
    return ciphertext + encryptor.tag
1477 | ||
1478 | @staticmethod | |
1479 | def _make_iv(base, counter): | |
1480 | mask, = struct.unpack("!Q", base[4:]) | |
1481 | ||
1482 | return base[:4] + struct.pack("!Q", counter ^ mask) |