]> git.ipfire.org Git - pbs.git/commitdiff
users: Fix subscribing to push notifications
authorMichael Tremer <michael.tremer@ipfire.org>
Sun, 26 Jan 2025 10:24:46 +0000 (10:24 +0000)
committerMichael Tremer <michael.tremer@ipfire.org>
Sun, 26 Jan 2025 10:24:46 +0000 (10:24 +0000)
Signed-off-by: Michael Tremer <michael.tremer@ipfire.org>
src/buildservice/users.py
src/web/users.py

index 8a8154faf06e0c7b66a5ffb5a49c79869cc2d04d..0c1c476a3bfe07f3285900807de4822c2abd64fc 100644 (file)
@@ -446,17 +446,20 @@ class Users(base.Object):
                """
                        The public part of the VAPID key
                """
-               return self.config.get("vapid", "public-key")
+               return self.backend.config.get("vapid", "public-key")
 
        @property
        def vapid_private_key(self):
                """
                        The private part of the VAPID key
                """
-               return self.config.get("vapid", "private-key")
+               return self.backend.config.get("vapid", "private-key")
 
-       @functools.cached_property
-       def application_server_key(self):
+       @functools.cache
+       def get_application_server_key(self):
+               """
+                       Generates the key that we are sending to the client
+               """
                lines = []
 
                for line in self.vapid_public_key.splitlines():
@@ -477,7 +480,8 @@ class Users(base.Object):
                # Encode the key URL-safe
                key = base64.urlsafe_b64encode(key).strip(b"=")
 
-               return key
+               # Return as string
+               return key.decode()
 
 
 class User(database.Base, database.BackendMixin, database.SoftDeleteMixin):
@@ -912,23 +916,24 @@ class User(database.Base, database.BackendMixin, database.SoftDeleteMixin):
 
        # Push Subscriptions
 
-       @lazy_property
-       def subscriptions(self):
-               subscriptions = self._get_subscriptions("""
-                       SELECT
-                               *
-                       FROM
-                               user_push_subscriptions
-                       WHERE
-                               deleted_at IS NULL
-                       AND
-                               user_id = %s
-                       ORDER BY
-                               created_at
-                       """, self.id,
+       async def get_subscriptions(self):
+               """
+                       Fetches all current subscriptions
+               """
+               stmt = (
+                       sqlalchemy
+                       .select(
+                               UserPushSubscription,
+                       ).where(
+                               UserPushSubscription.deleted_at == None,
+                               UserPushSubscription.user == self,
+                       )
+                       .order_by(
+                               UserPushSubscription.created_at.asc(),
+                       )
                )
 
-               return set(subscriptions)
+               return await self.db.fetch_as_list(stmt)
 
        async def subscribe(self, endpoint, p256dh, auth, user_agent=None):
                """
@@ -944,24 +949,19 @@ class User(database.Base, database.BackendMixin, database.SoftDeleteMixin):
                if not isinstance(auth, bytes):
                        auth = base64.urlsafe_b64decode(auth + "==")
 
-               subscription = self._get_subscription("""
-                       INSERT INTO
-                               user_push_subscriptions
-                       (
-                               user_id,
-                               user_agent,
-                               endpoint,
-                               p256dh,
-                               auth
-                       )
-                       VALUES
-                       (
-                               %s, %s, %s, %s, %s
-                       )
-                       RETURNING *
-                       """, self.id, user_agent, endpoint, p256dh, auth,
+               # Insert into the database
+               subscription = await self.db.insert(
+                       UserPushSubscription,
+                       user       = self,
+                       user_agent = user_agent,
+                       endpoint   = endpoint,
+                       p256dh     = p256dh,
+                       auth       = auth,
                )
 
+               # Log action
+               log.info("%s subscribed to push notifications" % self)
+
                # Send a message
                await subscription.send(
                        self._make_message(
@@ -976,8 +976,10 @@ class User(database.Base, database.BackendMixin, database.SoftDeleteMixin):
                """
                        Sends a message to all active subscriptions
                """
-               # Return False if the user has no subscriptions
-               if not self.subscriptions:
+               subscriptions = await self.get_subscriptions()
+
+               # Return early if there are no subscriptions
+               if not subscriptions:
                        return False
 
                # Format the message
@@ -1003,7 +1005,7 @@ class User(database.Base, database.BackendMixin, database.SoftDeleteMixin):
                return message
 
 
-class UserPushSubscription(database.Base):
+class UserPushSubscription(database.Base, database.BackendMixin):
        __tablename__ = "user_push_subscriptions"
 
        # ID
@@ -1020,7 +1022,8 @@ class UserPushSubscription(database.Base):
 
        # UUID
 
-       uuid = Column(UUID, unique=True, nullable=False)
+       uuid = Column(UUID, unique=True, nullable=False,
+               server_default=sqlalchemy.func.gen_random_uuid())
 
        # Created At
 
@@ -1035,16 +1038,9 @@ class UserPushSubscription(database.Base):
 
        endpoint = Column(Text, nullable=False)
 
-       @lazy_property
-       def p256dh(self):
-               """
-                       The client's public key
-               """
-               p = cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey.from_encoded_point(
-                       cryptography.hazmat.primitives.asymmetric.ec.SECP256R1(), bytes(self.data.p256dh),
-               )
+       # P256DH
 
-               return p
+       p256dh = Column(LargeBinary, nullable=False)
 
        # Auth
 
@@ -1223,6 +1219,11 @@ class UserPushSubscription(database.Base):
                record_size = 4096
                chunk_size = record_size - 17
 
+               # The client's public key
+               p256dh = cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey.from_encoded_point(
+                       cryptography.hazmat.primitives.asymmetric.ec.SECP256R1(), bytes(self.p256dh),
+               )
+
                # Generate an ephemeral server key
                server_private_key = cryptography.hazmat.primitives.asymmetric.ec.generate_private_key(
                        cryptography.hazmat.primitives.asymmetric.ec.SECP256R1,
@@ -1233,7 +1234,7 @@ class UserPushSubscription(database.Base):
                context = b"WebPush: info\x00"
 
                # Serialize the client's public key
-               context += self.p256dh.public_bytes(
+               context += p256dh.public_bytes(
                        cryptography.hazmat.primitives.serialization.Encoding.X962,
                        cryptography.hazmat.primitives.serialization.PublicFormat.UncompressedPoint,
                )
@@ -1246,7 +1247,7 @@ class UserPushSubscription(database.Base):
 
                # Perform key derivation with ECDH
                secret = server_private_key.exchange(
-                       cryptography.hazmat.primitives.asymmetric.ec.ECDH(), self.p256dh,
+                       cryptography.hazmat.primitives.asymmetric.ec.ECDH(), p256dh,
                )
 
                # Derive more stuff
index 45986354a0d83c9742d47bd352795fad30af5298..d42d3c5dd5da84168094c09a6ec3ff2d73e2f96b 100644 (file)
@@ -95,8 +95,8 @@ class BuildsHandler(base.BaseHandler):
 
 class PushSubscribeHandler(base.BaseHandler):
        @base.authenticated
-       def get(self):
-               self.render("users/subscribe.html")
+       async def get(self):
+               await self.render("users/subscribe.html")
 
        @base.authenticated
        async def post(self):
@@ -120,9 +120,13 @@ class PushSubscribeHandler(base.BaseHandler):
                        "user_agent" : self.user_agent,
                }
 
-               with self.db.transaction():
+               async with await self.db.transaction():
                        await self.current_user.subscribe(**args)
 
+               # Send empty response
+               self.set_status(204)
+               self.finish()
+
 
 #class PushSubscribeButton(ui_modules.UIModule):
 #      def render(self):