From f5706b26b37a6ecd70a5e07fada3dd84c5d08b25 Mon Sep 17 00:00:00 2001 From: Michael Tremer Date: Sun, 26 Jan 2025 10:24:46 +0000 Subject: [PATCH] users: Fix subscribing to push notifications Signed-off-by: Michael Tremer --- src/buildservice/users.py | 103 +++++++++++++++++++------------------- src/web/users.py | 10 ++-- 2 files changed, 59 insertions(+), 54 deletions(-) diff --git a/src/buildservice/users.py b/src/buildservice/users.py index 8a8154fa..0c1c476a 100644 --- a/src/buildservice/users.py +++ b/src/buildservice/users.py @@ -446,17 +446,20 @@ class Users(base.Object): """ The public part of the VAPID key """ - return self.config.get("vapid", "public-key") + return self.backend.config.get("vapid", "public-key") @property def vapid_private_key(self): """ The private part of the VAPID key """ - return self.config.get("vapid", "private-key") + return self.backend.config.get("vapid", "private-key") - @functools.cached_property - def application_server_key(self): + @functools.cache + def get_application_server_key(self): + """ + Generates the key that we are sending to the client + """ lines = [] for line in self.vapid_public_key.splitlines(): @@ -477,7 +480,8 @@ class Users(base.Object): # Encode the key URL-safe key = base64.urlsafe_b64encode(key).strip(b"=") - return key + # Return as string + return key.decode() class User(database.Base, database.BackendMixin, database.SoftDeleteMixin): @@ -912,23 +916,24 @@ class User(database.Base, database.BackendMixin, database.SoftDeleteMixin): # Push Subscriptions - @lazy_property - def subscriptions(self): - subscriptions = self._get_subscriptions(""" - SELECT - * - FROM - user_push_subscriptions - WHERE - deleted_at IS NULL - AND - user_id = %s - ORDER BY - created_at - """, self.id, + async def get_subscriptions(self): + """ + Fetches all current subscriptions + """ + stmt = ( + sqlalchemy + .select( + UserPushSubscription, + ).where( + UserPushSubscription.deleted_at == None, + UserPushSubscription.user == self, + ) + .order_by( + UserPushSubscription.created_at.asc(), + ) ) - return set(subscriptions) + return await self.db.fetch_as_list(stmt) async def subscribe(self, endpoint, p256dh, auth, user_agent=None): """ @@ -944,24 +949,19 @@ class User(database.Base, database.BackendMixin, database.SoftDeleteMixin): if not isinstance(auth, bytes): auth = base64.urlsafe_b64decode(auth + "==") - subscription = self._get_subscription(""" - INSERT INTO - user_push_subscriptions - ( - user_id, - user_agent, - endpoint, - p256dh, - auth - ) - VALUES - ( - %s, %s, %s, %s, %s - ) - RETURNING * - """, self.id, user_agent, endpoint, p256dh, auth, + # Insert into the database + subscription = await self.db.insert( + UserPushSubscription, + user = self, + user_agent = user_agent, + endpoint = endpoint, + p256dh = p256dh, + auth = auth, ) + # Log action + log.info("%s subscribed to push notifications" % self) + # Send a message await subscription.send( self._make_message( @@ -976,8 +976,10 @@ class User(database.Base, database.BackendMixin, database.SoftDeleteMixin): """ Sends a message to all active subscriptions """ - # Return False if the user has no subscriptions - if not self.subscriptions: + subscriptions = await self.get_subscriptions() + + # Return early if there are no subscriptions + if not subscriptions: return False # Format the message @@ -1003,7 +1005,7 @@ class User(database.Base, database.BackendMixin, database.SoftDeleteMixin): return message -class UserPushSubscription(database.Base): +class UserPushSubscription(database.Base, database.BackendMixin): __tablename__ = "user_push_subscriptions" # ID @@ -1020,7 +1022,8 @@ class UserPushSubscription(database.Base): # UUID - uuid = Column(UUID, unique=True, nullable=False) + uuid = Column(UUID, unique=True, nullable=False, + server_default=sqlalchemy.func.gen_random_uuid()) # Created At @@ -1035,16 +1038,9 @@ class UserPushSubscription(database.Base): endpoint = Column(Text, nullable=False) - @lazy_property - def p256dh(self): - """ - The client's public key - """ - p = cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey.from_encoded_point( - cryptography.hazmat.primitives.asymmetric.ec.SECP256R1(), bytes(self.data.p256dh), - ) + # P256DH - return p + p256dh = Column(LargeBinary, nullable=False) # Auth @@ -1223,6 +1219,11 @@ class UserPushSubscription(database.Base): record_size = 4096 chunk_size = record_size - 17 + # The client's public key + p256dh = cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey.from_encoded_point( + cryptography.hazmat.primitives.asymmetric.ec.SECP256R1(), bytes(self.p256dh), + ) + # Generate an ephemeral server key server_private_key = cryptography.hazmat.primitives.asymmetric.ec.generate_private_key( cryptography.hazmat.primitives.asymmetric.ec.SECP256R1, @@ -1233,7 +1234,7 @@ class UserPushSubscription(database.Base): context = b"WebPush: info\x00" # Serialize the client's public key - context += self.p256dh.public_bytes( + context += p256dh.public_bytes( cryptography.hazmat.primitives.serialization.Encoding.X962, cryptography.hazmat.primitives.serialization.PublicFormat.UncompressedPoint, ) @@ -1246,7 +1247,7 @@ class UserPushSubscription(database.Base): # Perform key derivation with ECDH secret = server_private_key.exchange( - cryptography.hazmat.primitives.asymmetric.ec.ECDH(), self.p256dh, + cryptography.hazmat.primitives.asymmetric.ec.ECDH(), p256dh, ) # Derive more stuff diff --git a/src/web/users.py b/src/web/users.py index 45986354..d42d3c5d 100644 --- a/src/web/users.py +++ b/src/web/users.py @@ -95,8 +95,8 @@ class BuildsHandler(base.BaseHandler): class PushSubscribeHandler(base.BaseHandler): @base.authenticated - def get(self): - self.render("users/subscribe.html") + async def get(self): + await self.render("users/subscribe.html") @base.authenticated async def post(self): @@ -120,9 +120,13 @@ class PushSubscribeHandler(base.BaseHandler): "user_agent" : self.user_agent, } - with self.db.transaction(): + async with await self.db.transaction(): await self.current_user.subscribe(**args) + # Send empty response + self.set_status(204) + self.finish() + #class PushSubscribeButton(ui_modules.UIModule): # def render(self): -- 2.47.2