1#!/usr/bin/python3
2
3import asyncio
4import base64
5import binascii
6import cryptography.hazmat.backends
7import cryptography.hazmat.primitives.asymmetric.ec
8import cryptography.hazmat.primitives.asymmetric.utils
9import cryptography.hazmat.primitives.ciphers
10import cryptography.hazmat.primitives.ciphers.aead
11import cryptography.hazmat.primitives.hashes
12import cryptography.hazmat.primitives.kdf.hkdf
13import cryptography.hazmat.primitives.serialization
14import datetime
15import email.utils
16import functools
17import json
18import ldap
19import logging
20import os
21import pickle
22import struct
23import threading
24import time
25import urllib.parse
26
27import tornado.locale
28
29import sqlalchemy
30from sqlalchemy import BigInteger, Boolean, Column, DateTime, ForeignKey, Integer
31from sqlalchemy import Interval, LargeBinary, Text, UUID
32
33from . import base
34from . import bugtracker
35from . import builds
36from . import database
37from . import httpclient
38from . import jobs
39from . import packages
40from . import repos
41from . import uploads
42
43from .decorators import *
44
45DEFAULT_STORAGE_QUOTA = 256 * 1024 * 1024 # 256 MiB
46
47# Setup logging
48log = logging.getLogger("pbs.users")
49
50# A list of LDAP attributes that we fetch
51LDAP_ATTRS = (
52 # UID
53 "uid",
54
55 # Common Name
56 "cn",
57
58 # First & Last Name
59 "givenName", "sn",
60
61 # Email Addresses
62 "mail",
63 "mailAlternateAddress",
64)
65
66class QuotaExceededError(Exception):
67 pass
68
69class Users(base.Object):
70 def init(self):
71 # Initialize thread-local storage
72 self.local = threading.local()
73
74 @property
75 def ldap(self):
76 if not hasattr(self.local, "ldap"):
77 # Fetch the LDAP URI
78 ldap_uri = self.backend.config.get("ldap", "uri")
79
80 log.debug("Connecting to %s..." % ldap_uri)
81
82 # Establish LDAP connection
83 self.local.ldap = ldap.initialize(ldap_uri)
84
85 return self.local.ldap
86
87 async def __aiter__(self):
88 users = await self._get_users("""
89 SELECT
90 *
91 FROM
92 users
93 WHERE
94 deleted_at IS NULL
95 ORDER BY
96 name
97 """,
98 )
99
100 return aiter(users)
101
102 def _ldap_query(self, query, attrlist=None, limit=0, search_base=None):
103 search_base = search_base or self.backend.config.get("ldap", "base")
104
105 log.debug("Performing LDAP query (%s): %s" % (search_base, query))
106
107 t = time.time()
108
109 # Ask the server to return up to 512 results per page
110 page_control = ldap.controls.SimplePagedResultsControl(True, size=512, cookie="")
111
112 results = []
113 pages = 0
114
115 # Perform the search
116 while True:
117 response = self.ldap.search_ext(search_base,
118 ldap.SCOPE_SUBTREE, query, attrlist=attrlist, sizelimit=limit,
119 serverctrls=[page_control],
120 )
121
122 # Fetch all results
123 type, data, rmsgid, serverctrls = self.ldap.result3(response)
124
125 # Append to local copy
126 results += data
127 pages += 1
128
129 controls = [c for c in serverctrls
130 if c.controlType == ldap.controls.SimplePagedResultsControl.controlType]
131
132 if not controls:
133 break
134
135 # Set the cookie for more results
136 page_control.cookie = controls[0].cookie
137
138 # There are no more results
139 if not page_control.cookie:
140 break
141
142 # Log time it took to perform the query
143 log.debug("Query took %.2fms (%s page(s))" % ((time.time() - t) * 1000.0, pages))
144
145 # Return all attributes (without the DN)
146 return [attrs for dn, attrs in results]
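# For illustration only (hypothetical values): this method returns the raw
# python-ldap attribute dictionaries, one per entry, with every value as a
# list of bytes, e.g.
#
#   self._ldap_query("(uid=jdoe)", attrlist=("uid", "mail"))
#   --> [{"uid": [b"jdoe"], "mail": [b"jdoe@example.org"]}]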
147
148 def _ldap_get(self, *args, **kwargs):
149 results = self._ldap_query(*args, **kwargs)
150
151 # No result
152 if not results:
153 return {}
154
155 # Too many results?
156 elif len(results) > 1:
157 raise OverflowError("Too many results returned for ldap_get()")
158
159 return results[0]
160
161 async def create(self, name, notify=False, storage_quota=None):
162 """
163 Creates a new user
164 """
165 # Set default for storage quota
166 if storage_quota is None:
167 storage_quota = DEFAULT_STORAGE_QUOTA
168
169 # Insert into database
170 user = await self.db.insert(
171 User,
172 name = name,
173 storage_quota = storage_quota,
174 )
175
176 log.debug("Created user %s" % user)
177
178 # Send a welcome email
179 if notify:
180 await user._send_welcome_email()
181
182 return user
183
184 async def get_by_name(self, name):
185 """
186 Fetch a user by its username
187 """
188 stmt = (
189 sqlalchemy
190 .select(User)
191 .where(
192 User.deleted_at == None,
193 User.name == name,
194 )
195 )
196
197 # Fetch the user from the database
198 user = await self.db.fetch_one(stmt)
199 if user:
200 return user
201
202 # Do nothing in test mode
203 if self.backend.test:
204 log.warning("Cannot use get_by_name in test mode")
205 return
206
207 # Search in LDAP
208 res = self._ldap_get(
209 "(&"
210 "(objectClass=person)"
211 "(uid=%s)"
212 ")" % name,
213 attrlist=("uid",),
214 )
215 if not res:
216 return
217
218 # Fetch the UID
219 uid = res.get("uid")[0].decode()
220
221 # Create a new user
222 return await self.create(uid)
223
224 async def get_by_email(self, mail):
225 # Strip any excess stuff from the email address
226 name, mail = email.utils.parseaddr(mail)
227
228 # Do nothing in test mode
229 if self.backend.test:
230 log.warning("Cannot use get_by_email in test mode")
231 return
232
233 # Search in LDAP
234 try:
235 res = self._ldap_get(
236 "(&"
237 "(objectClass=person)"
238 "(|"
239 "(mail=%s)"
240 "(mailAlternateAddress=%s)"
241 ")"
242 ")" % (mail, mail),
243 attrlist=("uid",),
244 )
245
246 except OverflowError as e:
247 raise OverflowError("Too many results for search for %s" % mail) from e
248
249 # No results
250 if not res:
251 return
252
253 # Fetch the UID
254 uid = res.get("uid")[0].decode()
255
256 return await self.get_by_name(uid)
257
258 async def _search_by_email(self, mails, include_missing=True):
259 """
260 Takes a list of email addresses and returns all users that could be found
261 """
262 users = []
263
264 for mail in mails:
265 user = await self.get_by_email(mail)
266
267 # Include the search string if no user could be found
268 if not user and include_missing:
269 user = mail
270
271 # Skip any duplicates
272 if user in users:
273 continue
274
275 users.append(user)
276
277 return users
278
279 async def search(self, q, limit=None):
280 # Do nothing in test mode
281 if self.backend.test:
282 log.warning("Cannot search for users in test mode")
283 return []
284
285 # Search for an exact match
286 user = await self.get_by_name(q)
287 if user:
288 return [user]
289
290 res = self._ldap_query(
291 "(&"
292 "(objectClass=person)"
293 "(|"
294 "(uid=%s)"
295 "(cn=*%s*)"
296 "(mail=%s)"
297 "(mailAlternateAddress=%s)"
298 ")"
299 ")" % (q, q, q, q),
300 attrlist=("uid",),
301 limit=limit,
302 )
303
304 # Fetch users
305 stmt = (
306 sqlalchemy
307 .select(User)
308 .where(
309 User.deleted_at == None,
310 User.name.in_([row.get("uid")[0].decode() for row in res]),
311 )
312 .order_by(
313 User.name,
314 )
315 )
316
317 # Return as list
318 return await self.db.fetch_as_list(stmt)
319
320 @functools.cached_property
321 def build_counts(self):
322 """
323 Returns a CTE that maps the user ID and the total number of builds
324 """
325 return (
326 sqlalchemy
327 .select(
328 # User ID
329 builds.Build.owner_id.label("user_id"),
330
331 # Count all builds
332 sqlalchemy.func.count(
333 builds.Build.id
334 ).label("count"),
335 )
336 .where(
337 builds.Build.owner_id != None,
338 builds.Build.test == False,
339 )
340 .group_by(
341 builds.Build.owner_id,
342 )
343 .cte("build_counts")
344 )
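# Roughly the SQL this CTE produces (sketch only, assuming the builds table
# is called "builds"):
#
#   SELECT owner_id AS user_id, COUNT(id) AS count
#     FROM builds
#    WHERE owner_id IS NOT NULL AND test IS false
#    GROUP BY owner_id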
345
346 async def get_top(self, limit=50):
347 """
348 Returns the top users (with the most builds)
349 """
350 stmt = (
351 sqlalchemy
352 .select(User)
353 .join(
354 self.build_counts,
355 self.build_counts.c.user_id == User.id,
356 )
357 .where(
358 User.deleted_at == None,
359 )
360 .order_by(
361 self.build_counts.c.count.desc(),
362 )
363 .limit(limit)
364 )
365
366 # Run the query
367 return await self.db.fetch_as_list(stmt)
368
369 @functools.cached_property
370 def build_times(self):
371 """
372 This is a CTE to easily access a user's consumed build time in the last 24 hours
373 """
374 return (
375 sqlalchemy
376
377 .select(
378 # Fetch the user by its ID
379 User.id.label("user_id"),
380
381 # Sum up the total build time
382 sqlalchemy.func.sum(
383 sqlalchemy.func.coalesce(
384 jobs.Job.finished_at,
385 sqlalchemy.func.current_timestamp()
386 )
387 - jobs.Job.started_at,
388 ).label("used_build_time"),
389 )
390
391 # Join builds & jobs
392 .join(
393 builds.Build,
394 builds.Build.owner_id == User.id,
395 )
396 .join(
397 jobs.Job,
398 jobs.Job.build_id == builds.Build.id,
399 )
400
401 # Only consider started jobs of active users with a quota
402 .where(
403 User.deleted_at == None,
404 User.daily_build_quota != None,
405
406 # Jobs must have been started
407 jobs.Job.started_at != None,
408
409 sqlalchemy.or_(
410 jobs.Job.finished_at == None,
411 jobs.Job.finished_at >=
412 sqlalchemy.func.current_timestamp() - datetime.timedelta(hours=24),
413 ),
414 )
415
416 # Group by user
417 .group_by(
418 User.id,
419 )
420
421 # Make this into a CTE
422 .cte("user_build_times")
423 )
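# In other words: used_build_time is the wall-clock time of all jobs that are
# either still running (finished_at IS NULL, counted up to the current
# timestamp via COALESCE) or that finished within the last 24 hours.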
424
425 @functools.cached_property
426 def exceeded_quotas(self):
427 return (
428 sqlalchemy
429
430 .select(
431 User.id,
432 self.build_times.c.used_build_time,
433 )
434 .where(
435 self.build_times.c.user_id == User.id,
436 self.build_times.c.used_build_time >= User.daily_build_quota,
437 )
438
439 # Make this into a CTE
440 .cte("user_exceeded_quotas")
441 )
442
443 # Push Notifications
444
445 @property
446 def vapid_public_key(self):
447 """
448 The public part of the VAPID key
449 """
450 return self.backend.config.get("vapid", "public-key")
451
452 @property
453 def vapid_private_key(self):
454 """
455 The private part of the VAPID key
456 """
457 return self.backend.config.get("vapid", "private-key")
458
459 @functools.cache
460 def get_application_server_key(self):
461 """
462 Generates the key that we are sending to the client
463 """
464 lines = []
465
466 for line in self.vapid_public_key.splitlines():
467 if line.startswith("-"):
468 continue
469
470 lines.append(line)
471
472 # Join everything together
473 key = "".join(lines)
474
475 # Decode the key
476 key = base64.b64decode(key)
477
478 # Only keep the last 65 bytes (the uncompressed EC point)
479 key = key[-65:]
480
481 # Encode the key URL-safe
482 key = base64.urlsafe_b64encode(key).strip(b"=")
483
484 # Return as string
485 return key.decode()
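# Note: browsers expect the applicationServerKey as the base64url-encoded,
# uncompressed P-256 point (0x04 || X || Y, 65 bytes), which is what this
# returns; the client passes it to PushManager.subscribe().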
486
487
488class User(database.Base, database.BackendMixin, database.SoftDeleteMixin):
489 __tablename__ = "users"
490
491 def __str__(self):
492 return self.realname or self.name
493
494 def __hash__(self):
495 return hash(self.id)
496
497 def __lt__(self, other):
498 if isinstance(other, self.__class__):
499 return self.name < other.name
500
501 elif isinstance(other, str):
502 return self.name < other
503
504 return NotImplemented
505
506 def to_json(self):
507 return {
508 "name" : self.name,
509 }
510
511 # ID
512
513 id = Column(Integer, primary_key=True)
514
515 # Name
516
517 name = Column(Text, nullable=False)
518
519 # Link
520
521 @property
522 def link(self):
523 return "/users/%s" % self.name
524
525 async def delete(self):
526 await self._set_attribute("deleted", True)
527
528 # Destroy all sessions
529 for session in self.sessions:
530 session.destroy()
531
532 # Fetch any attributes from LDAP
533
534 @functools.cached_property
535 def attrs(self):
536 # Use the stored attributes (only used in the test environment)
537 #if self.data._attrs:
538 # return pickle.loads(self.data._attrs)
539 #
540 return self.backend.users._ldap_get("(uid=%s)" % self.name, attrlist=LDAP_ATTRS)
541
542 def _get_attrs(self, key):
543 return [v.decode() for v in self.attrs.get(key, [])]
544
545 def _get_attr(self, key):
546 for value in self._get_attrs(key):
547 return value
548
549 # Realname
550
551 @property
552 def realname(self):
553 return self._get_attr("cn") or ""
554
555 @property
556 def email(self):
557 """
558 The primary email address
559 """
560 return self._get_attr("mail")
561
562 @property
563 def email_to(self):
564 """
565 The name/email address of the user in MIME format
566 """
567 return email.utils.formataddr((
568 self.realname or self.name,
569 self.email or "invalid@invalid.tld",
570 ))
571
572 async def send_email(self, *args, **kwargs):
573 return await self.backend.messages.send_template(
574 *args,
575 recipient=self,
576 locale=self.locale,
577 **kwargs,
578 )
579
580 async def _send_welcome_email(self):
581 """
582 Sends a welcome email to the user
583 """
584 await self.send_email("users/messages/welcome.txt")
585
586 # Admin
587
588 admin = Column(Boolean, nullable=False, default=False)
589
590 # Admin?
591
592 def is_admin(self):
593 return self.admin is True
594
595 # Locale
596
597 @property
598 def locale(self):
599 return tornado.locale.get()
600
601 # Avatar
602
603 def avatar(self, size=512):
604 """
605 Returns a URL to the avatar the user has uploaded
606 """
607 return "https://people.ipfire.org/users/%s.jpg?size=%s" % (self.name, size)
608
609 # Permissions
610
611 def has_perm(self, user):
612 """
613 Check, if the given user has the right to perform administrative
614 operations on this user.
615 """
616 # Anonymous people have no permission
617 if user is None:
618 return False
619
620 # Admins always have permission
621 if user.is_admin():
622 return True
623
624 # Users can edit themselves
625 if user == self:
626 return True
627
628 # No permission
629 return False
630
631 # Sessions
632
633 sessions = sqlalchemy.orm.relationship("Session", back_populates="user")
634
635 # Bugzilla API Key
636
637 bugzilla_api_key = Column(Text)
638
639 # Bugzilla
640
641 async def connect_to_bugzilla(self, api_key):
642 bz = bugtracker.Bugzilla(self.backend, api_key)
643
644 # Does the API key match with this user?
645 if self.email != await bz.whoami():
646 raise ValueError("The API key does not belong to %s" % self)
647
648 # Store the API key
649 self.bugzilla_api_key = api_key
650
651 @functools.cached_property
652 def bugzilla(self):
653 """
654 Connection to Bugzilla as this user
655 """
656 if self.bugzilla_api_key:
657 return bugtracker.Bugzilla(self.backend, self.bugzilla_api_key)
658
659 # Build Quota
660
661 daily_build_quota = Column(Interval)
662
663 # Build Times
664
665 async def get_used_daily_build_quota(self):
666 # Fetch the build time from the CTE
667 stmt = (
668 sqlalchemy
669 .select(
670 self.backend.users.build_times.c.used_build_time,
671 )
672 .where(
673 self.backend.users.build_times.c.user_id == self.id,
674 )
675 )
676
677 # Fetch the result
678 return await self.db.select_one(stmt, "used_build_time") or datetime.timedelta(0)
679
680 async def has_exceeded_build_quota(self):
681 if not self.daily_build_quota:
682 return False
683
684 return await self.get_used_daily_build_quota() >= self.daily_build_quota
685
686 # Storage Quota
687
688 storage_quota = Column(BigInteger)
689
690 async def has_exceeded_storage_quota(self, size=None):
691 """
692 Returns True if this user has exceeded their quota
693 """
694 # Skip quota check if this user has no quota
695 if not self.storage_quota:
696 return False
697
698 return await self.get_disk_usage() + (size or 0) >= self.storage_quota
699
700 async def check_storage_quota(self, size=None):
701 """
702 Determines the user's disk usage
703 and raises an exception when the user is over quota.
704 """
705 # Raise QuotaExceededError if this user is over quota
706 if await self.has_exceeded_storage_quota(size=size):
707 raise QuotaExceededError
708
709 async def get_disk_usage(self):
710 """
711 Returns the total disk usage of this user
712 """
713 source_packages = sqlalchemy.orm.aliased(packages.Package)
714 binary_packages = sqlalchemy.orm.aliased(packages.Package)
715
716 # Uploads
717 upload_disk_usage = (
718 sqlalchemy
719 .select(
720 uploads.Upload.size
721 )
722 .where(
723 uploads.Upload.user == self,
724 uploads.Upload.expires_at > sqlalchemy.func.current_timestamp(),
725 )
726 )
727
728 # Source Packages
729 source_package_disk_usage = (
730 sqlalchemy
731 .select(
732 source_packages.filesize
733 )
734 .select_from(
735 builds.Build,
736 )
737 .join(
738 source_packages,
739 source_packages.id == builds.Build.pkg_id,
740 )
741 .where(
742 # All objects must exist
743 source_packages.deleted_at == None,
744 builds.Build.deleted_at == None,
745
746 # Don't consider test builds
747 builds.Build.test == False,
748
749 # The build must be owned by the user
750 builds.Build.owner == self,
751 )
752 )
753
754 # Binary Packages
755 binary_package_disk_usage = (
756 sqlalchemy
757 .select(
758 binary_packages.filesize,
759 )
760 .select_from(
761 builds.Build,
762 )
763 .join(
764 jobs.Job,
765 jobs.Job.build_id == builds.Build.id,
766 )
767 .join(
768 jobs.JobPackage,
769 jobs.JobPackage.job_id == jobs.Job.id,
770 )
771 .join(
772 binary_packages,
773 binary_packages.id == jobs.JobPackage.pkg_id,
774 )
775 .where(
776 # All objects must exist
777 binary_packages.deleted_at == None,
778 builds.Build.deleted_at == None,
779 jobs.Job.deleted_at == None,
780
781 # Don't consider test builds
782 builds.Build.test == False,
783
784 # The build must be owned by the user
785 builds.Build.owner == self,
786 )
787 )
788
789 # Build Logs
790 build_log_disk_usage = (
791 sqlalchemy
792 .select(
793 jobs.Job.log_size
794 )
795 .select_from(
796 builds.Build,
797 )
798 .join(
799 jobs.Job,
800 jobs.Job.build_id == builds.Build.id,
801 )
802 .where(
803 # All objects must exist
804 builds.Build.deleted_at == None,
805 jobs.Job.deleted_at == None,
806
807 # Don't consider test builds
808 builds.Build.test == False,
809
810 # The build must be owned by the user
811 builds.Build.owner == self,
812 )
813 )
814
815 # Pull everything together
816 disk_usage = (
817 sqlalchemy
818 .union_all(
819 upload_disk_usage,
820 source_package_disk_usage,
821 binary_package_disk_usage,
822 build_log_disk_usage,
823 )
824 .cte("disk_usage")
825 )
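# Note: with UNION ALL, the column keys of the CTE come from the first
# SELECT, so the single column is addressable as "size" below even though
# the other branches select filesize/log_size.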
826
827 # Add it all up
828 stmt = (
829 sqlalchemy
830 .select(
831 sqlalchemy.func.sum(
832 disk_usage.c.size
833 ).label("disk_usage"),
834 )
835 )
836
837 # Run the query
838 return await self.db.select_one(stmt, "disk_usage") or 0
839
840 # Stats
841
842 async def get_total_builds(self):
843 stmt = (
844 sqlalchemy
845 .select(
846 self.backend.users.build_counts.c.count.label("count"),
847 )
848 .select_from(self.backend.users.build_counts)
849 .where(
850 self.backend.users.build_counts.c.user_id == self.id,
851 )
852 )
853
854 # Run the query
855 return await self.db.select_one(stmt, "count") or 0
856
857 async def get_total_build_time(self):
858 """
859 Returns the total build time
860 """
861 stmt = (
862 sqlalchemy
863 .select(
864 sqlalchemy.func.sum(
865 sqlalchemy.func.coalesce(
866 jobs.Job.finished_at,
867 sqlalchemy.func.current_timestamp()
868 )
869 - jobs.Job.started_at,
870 ).label("total_build_time")
871 )
872 .join(
873 builds.Build,
874 builds.Build.id == jobs.Job.build_id,
875 )
876 .where(
877 jobs.Job.started_at != None,
878 builds.Build.owner == self,
879 )
880 )
881
882 return await self.db.select_one(stmt, "total_build_time")
883
884 # Custom repositories
885
886 async def get_repos(self, distro=None):
887 """
888 Returns all custom repositories
889 """
890 stmt = (
891 sqlalchemy
892 .select(repos.Repo)
893 .where(
894 repos.Repo.deleted_at == None,
895 repos.Repo.owner == self,
896 )
897 .order_by(
898 repos.Repo.name,
899 )
900 )
901
902 # Filter by distribution
903 if distro:
904 stmt = stmt.where(
905 repos.Repo.distro == distro,
906 )
907
908 return await self.db.fetch_as_list(stmt)
909
910 async def get_repo(self, distro, slug=None):
911 """
912 Fetches a single repository
913 """
914 # Return the "home" repository if slug is empty
915 if slug is None:
916 slug = self.name
917
918 stmt = (
919 sqlalchemy
920 .select(repos.Repo)
921 .where(
922 repos.Repo.deleted_at == None,
923 repos.Repo.owner == self,
924 repos.Repo.distro == distro,
925 repos.Repo.slug == slug,
926 )
927 )
928
929 return await self.db.fetch_one(stmt)
930
931 # Uploads
932
933 def get_uploads(self):
934 """
935 Returns all uploads that belong to this user
936 """
937 stmt = (
938 sqlalchemy
939 .select(uploads.Upload)
940 .where(
941 uploads.Upload.user == self,
942 uploads.Upload.expires_at > sqlalchemy.func.current_timestamp(),
943 )
944 .order_by(
945 uploads.Upload.created_at.desc(),
946 )
947 )
948
949 return self.db.fetch(stmt)
950
951 # Push Subscriptions
952
953 async def is_subscribed(self):
954 """
955 Returns True if the user is subscribed.
956 """
957 subscriptions = await self.get_subscriptions()
958
959 return True if subscriptions else False
960
961 async def get_subscriptions(self):
962 """
963 Fetches all current subscriptions
964 """
965 stmt = (
966 sqlalchemy
967 .select(
968 UserPushSubscription,
969 )
970 .where(
971 UserPushSubscription.user == self,
972 )
973 .order_by(
974 UserPushSubscription.created_at.asc(),
975 )
976 )
977
978 return await self.db.fetch_as_list(stmt)
979
980 async def subscribe(self, endpoint, p256dh, auth, user_agent=None):
981 """
982 Creates a new subscription for this user
983 """
984 _ = self.locale.translate
985
986 # Decode p256dh
987 if not isinstance(p256dh, bytes):
988 p256dh = base64.urlsafe_b64decode(p256dh + "==")
989
990 # Decode auth
991 if not isinstance(auth, bytes):
992 auth = base64.urlsafe_b64decode(auth + "==")
993
994 # Insert into the database
995 subscription = await self.db.insert(
996 UserPushSubscription,
997 user = self,
998 user_agent = user_agent,
999 endpoint = endpoint,
1000 p256dh = p256dh,
1001 auth = auth,
1002 )
1003
1004 # Log action
1005 log.info("%s subscribed to push notifications" % self)
1006
1007 # Send a message
1008 await subscription.send(
1009 _("Hello, %s!") % self,
1010 _("You have successfully subscribed to push notifications."),
1011 )
1012
1013 return subscription
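# A minimal sketch of how this is typically called (handler and variable
# names are illustrative; the values come from the browser's
# PushSubscription JSON):
#
#   data = json.loads(request.body)
#   await user.subscribe(
#       endpoint   = data["endpoint"],
#       p256dh     = data["keys"]["p256dh"],
#       auth       = data["keys"]["auth"],
#       user_agent = request.headers.get("User-Agent"),
#   )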
1014
1015 async def send_push_message(self, *args, **kwargs):
1016 """
1017 Sends a message to all active subscriptions
1018 """
1019 subscriptions = await self.get_subscriptions()
1020
1021 # Return early if there are no subscriptions
1022 if not subscriptions:
1023 return False
1024
1025 # Send the message to all subscriptions
1026 for subscription in subscriptions:
1027 await subscription.send(*args, **kwargs)
1028
1029 return True
1030
1031
1032class UserPushSubscription(database.Base, database.BackendMixin):
1033 __tablename__ = "user_push_subscriptions"
1034
1035 # ID
1036
1037 id = Column(Integer, primary_key=True)
1038
1039 # User ID
1040
1041 user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
1042
1043 # User
1044
1045 user = sqlalchemy.orm.relationship("User", lazy="joined", innerjoin=True)
1046
1047 # UUID
1048
1049 uuid = Column(UUID, unique=True, nullable=False,
1050 server_default=sqlalchemy.func.gen_random_uuid())
1051
1052 # Created At
1053
1054 created_at = Column(DateTime(timezone=False), nullable=False,
1055 server_default=sqlalchemy.func.current_timestamp())
1056
1057 # User Agent
1058
1059 user_agent = Column(Text)
1060
1061 # Endpoint
1062
1063 endpoint = Column(Text, nullable=False)
1064
1065 # P256DH
1066
1067 p256dh = Column(LargeBinary, nullable=False)
1068
1069 # Auth
1070
1071 auth = Column(LargeBinary, nullable=False)
1072
1073 @property
1074 def vapid_private_key(self):
1075 return cryptography.hazmat.primitives.serialization.load_pem_private_key(
1076 self.backend.users.vapid_private_key.encode(),
1077 password=None,
1078 backend=cryptography.hazmat.backends.default_backend(),
1079 )
1080
1081 @property
1082 def vapid_public_key(self):
1083 return self.vapid_private_key.public_key()
1084
1085 async def send(self, title, body, ttl=None):
1086 """
1087 Sends a message to the user using the push service
1088 """
1089 message = {
1090 "title" : title,
1091 "body" : body,
1092 }
1093
1094 # Convert dict() to JSON
1095 message = json.dumps(message)
1096
1097 # Encrypt the message
1098 message = self._encrypt(message)
1099
1100 # Create a signature
1101 signature = self._sign()
1102
1103 # Encode the public key
1104 crypto_key = self.b64encode(
1105 self.vapid_public_key.public_bytes(
1106 cryptography.hazmat.primitives.serialization.Encoding.X962,
1107 cryptography.hazmat.primitives.serialization.PublicFormat.UncompressedPoint,
1108 )
1109 ).decode()
1110
1111 # Form request headers
1112 headers = {
1113 "Authorization" : "WebPush %s" % signature,
1114 "Crypto-Key" : "p256ecdsa=%s" % crypto_key,
1115
1116 "Content-Type" : "application/octet-stream",
1117 "Content-Encoding" : "aes128gcm",
1118 "TTL" : "%s" % (ttl or 0),
1119 }
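# Note: this uses the older draft "WebPush <JWT>" authorization scheme
# together with a Crypto-Key header; RFC 8292 later consolidated both into
# a single "Authorization: vapid t=<JWT>, k=<public key>" header.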
1120
1121 # Send the request
1122 try:
1123 await self.backend.httpclient.fetch(self.endpoint, method="POST",
1124 headers=headers, body=message)
1125
1126 except httpclient.HTTPError as e:
1127 # 410 - Gone
1128 # The subscription is no longer valid
1129 if e.code == 410:
1130 # Let's just delete ourselves
1131 await self.delete()
1132 return
1133
1134 # Raise everything else
1135 raise e
1136
1137 async def delete(self):
1138 """
1139 Deletes this subscription
1140 """
1141 # Immediately delete it
1142 await self.db.delete(self)
1143
1144 def _sign(self):
1145 elements = []
1146
1147 for element in (self._jwt_info, self._jwt_data):
1148 # Format the dictionary
1149 element = json.dumps(element, separators=(',', ':'), sort_keys=True)
1150
1151 # Encode to bytes
1152 element = element.encode()
1153
1154 # Encode URL-safe in base64 and remove any padding
1155 element = self.b64encode(element)
1156
1157 elements.append(element)
1158
1159 # Concatenate
1160 token = b".".join(elements)
1161
1162 log.debug("String to sign: %s" % token)
1163
1164 # Create the signature
1165 signature = self.vapid_private_key.sign(
1166 token,
1167 cryptography.hazmat.primitives.asymmetric.ec.ECDSA(
1168 cryptography.hazmat.primitives.hashes.SHA256(),
1169 ),
1170 )
1171
1172 # Decode the signature
1173 r, s = cryptography.hazmat.primitives.asymmetric.utils.decode_dss_signature(signature)
1174
1175 # Encode the signature in base64
1176 signature = self.b64encode(
1177 self._num_to_bytes(r, 32) + self._num_to_bytes(s, 32),
1178 )
1179
1180 # Put everything together
1181 signature = b"%s.%s" % (token, signature)
1182 signature = signature.decode()
1183
1184 log.debug("Created signature: %s" % signature)
1185
1186 return signature
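# The resulting token is a standard JWT: base64url(header).base64url(claims).base64url(signature),
# where the ES256 signature is the raw 64-byte r||s concatenation (hence the
# DER decode and fixed-width re-encoding above) rather than the DER form.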
1187
1188 _jwt_info = {
1189 "typ" : "JWT",
1190 "alg" : "ES256",
1191 }
1192
1193 @property
1194 def _jwt_data(self):
1195 # Parse the URL
1196 url = urllib.parse.urlparse(self.endpoint)
1197
1198 # Let the signature expire after 12 hours
1199 expires = time.time() + (12 * 3600)
1200
1201 return {
1202 "aud" : "%s://%s" % (url.scheme, url.netloc),
1203 "exp" : int(expires),
1204 "sub" : "mailto:info@ipfire.org",
1205 }
1206
1207 @staticmethod
1208 def _num_to_bytes(n, pad_to):
1209 """
1210 Returns the byte representation of an integer, in big-endian order.
1211 """
1212 h = "%x" % n
1213
1214 r = binascii.unhexlify("0" * (len(h) % 2) + h)
1215 return b"\x00" * (pad_to - len(r)) + r
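# For example (hypothetical values): _num_to_bytes(0x1234, 4) == b"\x00\x00\x12\x34"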
1216
1217 @staticmethod
1218 def _serialize_key(key):
1219 if isinstance(key, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey):
1220 return key.private_bytes(
1221 cryptography.hazmat.primitives.serialization.Encoding.DER,
1222 cryptography.hazmat.primitives.serialization.PrivateFormat.PKCS8,
1223 cryptography.hazmat.primitives.serialization.NoEncryption(),
1224 )
1225
1226 return key.public_bytes(
1227 cryptography.hazmat.primitives.serialization.Encoding.X962,
1228 cryptography.hazmat.primitives.serialization.PublicFormat.UncompressedPoint,
1229 )
1230
1231 @staticmethod
1232 def b64encode(data):
1233 return base64.urlsafe_b64encode(data).strip(b"=")
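# Base64url without padding, e.g. b64encode(b"\xff\xff") == b"__8"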
1234
1235 def _encrypt(self, message):
1236 """
1237 This is an absolutely ugly monster of a function which will encrypt the message
1238 """
1239 headers = {}
1240
1241 # Encode everything as bytes
1242 if not isinstance(message, bytes):
1243 message = message.encode()
1244
1245 # Generate some salt
1246 salt = os.urandom(16)
1247
1248 record_size = 4096
1249 chunk_size = record_size - 17
1250
1251 # The client's public key
1252 p256dh = cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey.from_encoded_point(
1253 cryptography.hazmat.primitives.asymmetric.ec.SECP256R1(), bytes(self.p256dh),
1254 )
1255
1256 # Generate an ephemeral server key
1257 server_private_key = cryptography.hazmat.primitives.asymmetric.ec.generate_private_key(
1258 cryptography.hazmat.primitives.asymmetric.ec.SECP256R1,
1259 cryptography.hazmat.backends.default_backend(),
1260 )
1261 server_public_key = server_private_key.public_key()
1262
1263 context = b"WebPush: info\x00"
1264
1265 # Serialize the client's public key
1266 context += p256dh.public_bytes(
1267 cryptography.hazmat.primitives.serialization.Encoding.X962,
1268 cryptography.hazmat.primitives.serialization.PublicFormat.UncompressedPoint,
1269 )
1270
1271 # Serialize the server's public key
1272 context += server_public_key.public_bytes(
1273 cryptography.hazmat.primitives.serialization.Encoding.X962,
1274 cryptography.hazmat.primitives.serialization.PublicFormat.UncompressedPoint,
1275 )
1276
1277 # Perform key derivation with ECDH
1278 secret = server_private_key.exchange(
1279 cryptography.hazmat.primitives.asymmetric.ec.ECDH(), p256dh,
1280 )
1281
1282 # Mix in the client's auth secret and the key context (RFC 8291 HKDF step)
1283 hkdf_auth = cryptography.hazmat.primitives.kdf.hkdf.HKDF(
1284 algorithm=cryptography.hazmat.primitives.hashes.SHA256(),
1285 length=32,
1286 salt=self.auth,
1287 info=context,
1288 backend=cryptography.hazmat.backends.default_backend(),
1289 )
1290 secret = hkdf_auth.derive(secret)
1291
1292 # Derive the content encryption key
1293 hkdf_key = cryptography.hazmat.primitives.kdf.hkdf.HKDF(
1294 algorithm=cryptography.hazmat.primitives.hashes.SHA256(),
1295 length=16,
1296 salt=salt,
1297 info=b"Content-Encoding: aes128gcm\x00",
1298 backend=cryptography.hazmat.backends.default_backend(),
1299 )
1300 encryption_key = hkdf_key.derive(secret)
1301
1302 # Derive a nonce
1303 hkdf_nonce = cryptography.hazmat.primitives.kdf.hkdf.HKDF(
1304 algorithm=cryptography.hazmat.primitives.hashes.SHA256(),
1305 length=12,
1306 salt=salt,
1307 info=b"Content-Encoding: nonce\x00",
1308 backend=cryptography.hazmat.backends.default_backend(),
1309 )
1310 nonce = hkdf_nonce.derive(secret)
1311
1312 result = b""
1313 chunks = 0
1314
1315 while True:
1316 # Fetch a chunk
1317 chunk, message = message[:chunk_size], message[chunk_size:]
1318 if not chunk:
1319 break
1320
1321 # Is this the last chunk?
1322 last = not message
1323
1324 # Encrypt the chunk
1325 result += self._encrypt_chunk(encryption_key, nonce, chunks, chunk, last)
1326
1327 # Keep counting...
1328 chunks += 1
1329
1330 # Fetch the public key
1331 key_id = server_public_key.public_bytes(
1332 cryptography.hazmat.primitives.serialization.Encoding.X962,
1333 cryptography.hazmat.primitives.serialization.PublicFormat.UncompressedPoint,
1334 )
1335
1336 # Join the entire message together
1337 message = [
1338 salt,
1339 struct.pack("!L", record_size),
1340 struct.pack("!B", len(key_id)),
1341 key_id,
1342 result,
1343 ]
1344
1345 return b"".join(message)
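# The resulting payload follows the aes128gcm framing from RFC 8188:
#   salt (16 bytes) || record size (4 bytes, big-endian) || key ID length (1 byte)
#   || key ID (the 65-byte ephemeral server public key) || encrypted records,
# where each record ends in a padding delimiter byte and a 16-byte GCM tag.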
1346
1347 def _encrypt_chunk(self, key, nonce, counter, chunk, last=False):
1348 """
1349 Encrypts one chunk
1350 """
1351 # Make the IV
1352 iv = self._make_iv(nonce, counter)
1353
1354 log.debug("Encrypting chunk %s: length = %s" % (counter + 1, len(chunk)))
1355
1356 if last:
1357 chunk += b"\x02"
1358 else:
1359 chunk += b"\x01"
1360
1361 # Setup AES GCM
1362 cipher = cryptography.hazmat.primitives.ciphers.Cipher(
1363 cryptography.hazmat.primitives.ciphers.algorithms.AES128(key),
1364 cryptography.hazmat.primitives.ciphers.modes.GCM(iv),
1365 backend=cryptography.hazmat.backends.default_backend(),
1366 )
1367
1368 # Get the encryptor
1369 encryptor = cipher.encryptor()
1370
1371 # Encrypt the chunk
1372 chunk = encryptor.update(chunk)
1373
1374 # Finalize this round
1375 chunk += encryptor.finalize() + encryptor.tag
1376
1377 return chunk
1378
1379 @staticmethod
1380 def _make_iv(base, counter):
1381 mask, = struct.unpack("!Q", base[4:])
1382
1383 return base[:4] + struct.pack("!Q", counter ^ mask)
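# Per RFC 8188, the per-record IV is the 12-byte nonce with the record counter
# XORed into its low 64 bits; record 0 therefore uses the derived nonce itself.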