From: Michael Tremer Date: Fri, 19 May 2023 16:55:34 +0000 (+0000) Subject: users: Implement scaffolding for push notifications X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=d48f75f7ccd57a04a03538775b6ae0d4d5d01e72;p=pbs.git users: Implement scaffolding for push notifications Signed-off-by: Michael Tremer --- diff --git a/Makefile.am b/Makefile.am index 107b9e04..9614ba36 100644 --- a/Makefile.am +++ b/Makefile.am @@ -320,7 +320,8 @@ dist_templates_users_DATA = \ src/templates/users/delete.html \ src/templates/users/edit.html \ src/templates/users/index.html \ - src/templates/users/show.html + src/templates/users/show.html \ + src/templates/users/subscribe.html templates_usersdir = $(templatesdir)/users @@ -330,7 +331,8 @@ dist_templates_users_messages_DATA = \ templates_users_messagesdir = $(templates_usersdir)/messages dist_templates_users_modules_DATA = \ - src/templates/users/modules/list.html + src/templates/users/modules/list.html \ + src/templates/users/modules/push-subscribe-button.html templates_users_modulesdir = $(templates_usersdir)/modules @@ -359,14 +361,18 @@ static_js_DATA = \ src/static/js/builders-stats.min.js \ src/static/js/jquery.min.js \ src/static/js/job-log-stream.min.js \ - src/static/js/pbs.min.js + src/static/js/notification-worker.min.js \ + src/static/js/pbs.min.js \ + src/static/js/user-push-subscribe-button.min.js static_jsdir = $(staticdir)/js EXTRA_DIST += \ src/static/js/builders-stats.js \ src/static/js/job-log-stream.js \ - src/static/js/pbs.js + src/static/js/notification-worker.js \ + src/static/js/pbs.js \ + src/static/js/user-push-subscribe-button.js CLEANFILES += \ $(static_js_DATA) diff --git a/src/buildservice/users.py b/src/buildservice/users.py index de5d2066..1976a385 100644 --- a/src/buildservice/users.py +++ b/src/buildservice/users.py @@ -1,17 +1,33 @@ #!/usr/bin/python3 +import asyncio +import base64 +import binascii +import cryptography.hazmat.backends +import cryptography.hazmat.primitives.asymmetric.ec +import cryptography.hazmat.primitives.asymmetric.utils +import cryptography.hazmat.primitives.ciphers +import cryptography.hazmat.primitives.ciphers.aead +import cryptography.hazmat.primitives.hashes +import cryptography.hazmat.primitives.kdf.hkdf +import cryptography.hazmat.primitives.serialization import datetime import email.utils +import json import ldap import logging +import os import pickle +import struct import threading import time +import urllib.parse import tornado.locale from . import base from . import bugtracker +from . import httpclient from .decorators import * @@ -34,6 +50,9 @@ LDAP_ATTRS = ( "mailAlternateAddress", ) +class QuotaExceededError(Exception): + pass + class Users(base.Object): def init(self): # Initialize thread-local storage @@ -320,6 +339,80 @@ class Users(base.Object): return list(users) + # Push Notifications + + @property + def vapid_public_key(self): + """ + The public part of the VAPID key + """ + return self.settings.get("vapid-public-key") + + @property + def vapid_private_key(self): + """ + The private part of the VAPID key + """ + return self.settings.get("vapid-private-key") + + async def generate_vapid_keys(self): + """ + Generates the VAPID keys + """ + # Do not generate a new key if one exists + if self.vapid_public_key and self.vapid_private_key: + return + + with self.db.transaction(): + # Generate the private key + private_key = await self.backend.command( + "openssl", + "ecparam", + "-name", "prime256v1", + "-genkey", + "-noout", + return_output=True, + ) + + # Generate the public key + public_key = await self.backend.command( + "openssl", + "ec", + "-pubout", + input=private_key, + return_output=True, + ) + + # Store the keys + self.settings.set("vapid-public-key", public_key) + self.settings.set("vapid-private-key", private_key) + + log.info("VAPID keys have been successfully generated") + + @property + def application_server_key(self): + lines = [] + + for line in self.vapid_public_key.splitlines(): + if line.startswith("-"): + continue + + lines.append(line) + + # Join everything together + key = "".join(lines) + + # Decode the key + key = base64.b64decode(key) + + # Only take the last bit + key = key[-65:] + + # Encode the key URL-safe + key = base64.urlsafe_b64encode(key).strip(b"=") + + return key + class User(base.DataObject): table = "users" @@ -662,6 +755,423 @@ class User(base.DataObject): return list(uploads) + # Push Subscriptions -class QuotaExceededError(Exception): - pass + def _get_subscriptions(self, query, *args): + res = self.db.query(query, *args) + + for row in res: + yield UserPushSubscription(self.backend, row.id, data=row) + + def _get_subscription(self, query, *args): + res = self.db.get(query, *args) + + if res: + return UserPushSubscription(self.backend, res.id, data=res) + + @lazy_property + def subscriptions(self): + subscriptions = self._get_subscriptions(""" + SELECT + * + FROM + user_push_subscriptions + WHERE + deleted_at IS NULL + AND + user_id = %s + ORDER BY + created_at + """, self.id, + ) + + return set(subscriptions) + + async def subscribe(self, endpoint, p256dh, auth, user_agent=None): + """ + Creates a new subscription for this user + """ + # Decode p256dh + if not isinstance(p256dh, bytes): + p256dh = base64.urlsafe_b64decode(p256dh + "==") + + # Decode auth + if not isinstance(auth, bytes): + auth = base64.urlsafe_b64decode(auth + "==") + + subscription = self._get_subscription(""" + INSERT INTO + user_push_subscriptions + ( + user_id, + user_agent, + endpoint, + p256dh, + auth + ) + VALUES + ( + %s, %s, %s, %s, %s + ) + RETURNING * + """, self.id, user_agent, endpoint, p256dh, auth, + ) + + # Send a message + await subscription.send("Hello World!") + + return subscription + + async def send_push_message(self, message): + """ + Sends a message to all active subscriptions + """ + async with asyncio.TaskGroup() as tg: + for subscription in self.subscriptions: + tg.create_task(subscription.send(message)) + + +class UserPushSubscription(base.DataObject): + table = "user_push_subscriptions" + + @property + def uuid(self): + """ + UUID + """ + return self.data.uuid + + @property + def created_at(self): + return self.data.created_at + + @property + def deleted_at(self): + return self.data.deleted_at + + def delete(self): + """ + Deletes this subscription + """ + self._set_attribute_now("deleted_at") + + @property + def endpoint(self): + return self.data.endpoint + + @lazy_property + def p256dh(self): + """ + The client's public key + """ + p = cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey.from_encoded_point( + cryptography.hazmat.primitives.asymmetric.ec.SECP256R1(), bytes(self.data.p256dh), + ) + + return p + + @property + def auth(self): + return bytes(self.data.auth) + + @property + def vapid_private_key(self): + return cryptography.hazmat.primitives.serialization.load_pem_private_key( + self.backend.users.vapid_private_key.encode(), + password=None, + backend=cryptography.hazmat.backends.default_backend(), + ) + + @property + def vapid_public_key(self): + return self.vapid_private_key.public_key() + + async def send(self, message, ttl=0): + """ + Sends a message to the user using the push service + """ + # Convert strings into a message object + if isinstance(message, str): + message = { + "message" : message, + } + + # Convert dict() to JSON + if isinstance(message, dict): + message = json.dumps(message) + + # Encode everything as bytes + if not isinstance(message, bytes): + message = message.encode() + + # Encrypt the message + message = self._encrypt(message) + + # Create a signature + signature = self._sign() + + # Encode the public key + crypto_key = self.b64encode( + self.vapid_public_key.public_bytes( + cryptography.hazmat.primitives.serialization.Encoding.X962, + cryptography.hazmat.primitives.serialization.PublicFormat.UncompressedPoint, + ) + ).decode() + + # Form request headers + headers = { + "Authorization" : "WebPush %s" % signature, + "Crypto-Key" : "p256ecdsa=%s" % crypto_key, + + "Content-Type" : "application/octet-stream", + "Content-Encoding" : "aes128gcm", + "TTL" : "%s" % (ttl or 0), + } + + # Send the request + try: + await self.backend.httpclient.fetch(self.endpoint, method="POST", + headers=headers, body=message) + + except httpclient.HTTPError as e: + # 410 - Gone + # The subscription is no longer valid + if e.code == 410: + # Let's just delete ourselves + self.delete() + return + + # Raise everything else + raise e + + def _sign(self): + elements = [] + + for element in (self._jwt_info, self._jwt_data): + # Format the dictionary + element = json.dumps(element, separators=(',', ':'), sort_keys=True) + + # Encode to bytes + element = element.encode() + + # Encode URL-safe in base64 and remove any padding + element = self.b64encode(element) + + elements.append(element) + + # Concatenate + token = b".".join(elements) + + log.debug("String to sign: %s" % token) + + # Create the signature + signature = self.vapid_private_key.sign( + token, + cryptography.hazmat.primitives.asymmetric.ec.ECDSA( + cryptography.hazmat.primitives.hashes.SHA256(), + ), + ) + + # Decode the signature + r, s = cryptography.hazmat.primitives.asymmetric.utils.decode_dss_signature(signature) + + # Encode the signature in base64 + signature = self.b64encode( + self._num_to_bytes(r, 32) + self._num_to_bytes(s, 32), + ) + + # Put everything together + signature = b"%s.%s" % (token, signature) + signature = signature.decode() + + log.debug("Created signature: %s" % signature) + + return signature + + _jwt_info = { + "typ" : "JWT", + "alg" : "ES256", + } + + @property + def _jwt_data(self): + # Parse the URL + url = urllib.parse.urlparse(self.endpoint) + + # Let the signature expire after 12 hours + expires = time.time() + (12 * 3600) + + return { + "aud" : "%s://%s" % (url.scheme, url.netloc), + "exp" : int(expires), + "sub" : "mailto:info@ipfire.org", + } + + @staticmethod + def _num_to_bytes(n, pad_to): + """ + Returns the byte representation of an integer, in big-endian order. + """ + h = "%x" % n + + r = binascii.unhexlify("0" * (len(h) % 2) + h) + return b"\x00" * (pad_to - len(r)) + r + + @staticmethod + def _serialize_key(key): + if isinstance(key, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey): + return key.private_bytes( + cryptography.hazmat.primitives.serialization.Encoding.DER, + cryptography.hazmat.primitives.serialization.PrivateFormat.PKCS8, + cryptography.hazmat.primitives.serialization.NoEncryption(), + ) + + return key.public_bytes( + cryptography.hazmat.primitives.serialization.Encoding.X962, + cryptography.hazmat.primitives.serialization.PublicFormat.UncompressedPoint, + ) + + @staticmethod + def b64encode(data): + return base64.urlsafe_b64encode(data).strip(b"=") + + def _encrypt(self, message): + """ + This is an absolutely ugly monster of a function which will sign the message + """ + headers = {} + + # Generate some salt + salt = os.urandom(16) + + record_size = 4096 + chunk_size = record_size - 17 + + # Generate an ephemeral server key + server_private_key = cryptography.hazmat.primitives.asymmetric.ec.generate_private_key( + cryptography.hazmat.primitives.asymmetric.ec.SECP256R1, + cryptography.hazmat.backends.default_backend(), + ) + server_public_key = server_private_key.public_key() + + context = b"WebPush: info\x00" + + # Serialize the client's public key + context += self.p256dh.public_bytes( + cryptography.hazmat.primitives.serialization.Encoding.X962, + cryptography.hazmat.primitives.serialization.PublicFormat.UncompressedPoint, + ) + + # Serialize the server's public key + context += server_public_key.public_bytes( + cryptography.hazmat.primitives.serialization.Encoding.X962, + cryptography.hazmat.primitives.serialization.PublicFormat.UncompressedPoint, + ) + + # Perform key derivation with ECDH + secret = server_private_key.exchange( + cryptography.hazmat.primitives.asymmetric.ec.ECDH(), self.p256dh, + ) + + # Derive more stuff + hkdf_auth = cryptography.hazmat.primitives.kdf.hkdf.HKDF( + algorithm=cryptography.hazmat.primitives.hashes.SHA256(), + length=32, + salt=self.auth, + info=context, + backend=cryptography.hazmat.backends.default_backend(), + ) + secret = hkdf_auth.derive(secret) + + # Derive the signing key + hkdf_key = cryptography.hazmat.primitives.kdf.hkdf.HKDF( + algorithm=cryptography.hazmat.primitives.hashes.SHA256(), + length=16, + salt=salt, + info=b"Content-Encoding: aes128gcm\x00", + backend=cryptography.hazmat.backends.default_backend(), + ) + encryption_key = hkdf_key.derive(secret) + + # Derive a nonce + hkdf_nonce = cryptography.hazmat.primitives.kdf.hkdf.HKDF( + algorithm=cryptography.hazmat.primitives.hashes.SHA256(), + length=12, + salt=salt, + info=b"Content-Encoding: nonce\x00", + backend=cryptography.hazmat.backends.default_backend(), + ) + nonce = hkdf_nonce.derive(secret) + + result = b"" + chunks = 0 + + while True: + # Fetch a chunk + chunk, message = message[:chunk_size], message[chunk_size:] + if not chunk: + break + + # Is this the last chunk? + last = not message + + # Encrypt the chunk + result += self._encrypt_chunk(encryption_key, nonce, chunks, chunk, last) + + # Kepp counting... + chunks += 1 + + # Fetch the public key + key_id = server_public_key.public_bytes( + cryptography.hazmat.primitives.serialization.Encoding.X962, + cryptography.hazmat.primitives.serialization.PublicFormat.UncompressedPoint, + ) + + # Join the entire message together + message = [ + salt, + struct.pack("!L", record_size), + struct.pack("!B", len(key_id)), + key_id, + result, + ] + + return b"".join(message) + + def _encrypt_chunk(self, key, nonce, counter, chunk, last=False): + """ + Encrypts one chunk + """ + # Make the IV + iv = self._make_iv(nonce, counter) + + log.debug("Encrypting chunk %s: length = %s" % (counter + 1, len(chunk))) + + if last: + chunk += b"\x02" + else: + chunk += b"\x01" + + # Setup AES GCM + cipher = cryptography.hazmat.primitives.ciphers.Cipher( + cryptography.hazmat.primitives.ciphers.algorithms.AES128(key), + cryptography.hazmat.primitives.ciphers.modes.GCM(iv), + backend=cryptography.hazmat.backends.default_backend(), + ) + + # Get the encryptor + encryptor = cipher.encryptor() + + # Encrypt the chunk + chunk = encryptor.update(chunk) + + # Finalize this round + chunk += encryptor.finalize() + encryptor.tag + + return chunk + + @staticmethod + def _make_iv(base, counter): + mask, = struct.unpack("!Q", base[4:]) + + return base[:4] + struct.pack("!Q", counter ^ mask) diff --git a/src/database.sql b/src/database.sql index fa6fbd2c..95c85914 100644 --- a/src/database.sql +++ b/src/database.sql @@ -2,8 +2,8 @@ -- PostgreSQL database dump -- --- Dumped from database version 15.2 (Debian 15.2-2) --- Dumped by pg_dump version 15.2 (Debian 15.2-2) +-- Dumped from database version 15.3 (Debian 15.3-0+deb12u1) +-- Dumped by pg_dump version 15.3 (Debian 15.3-0+deb12u1) SET statement_timeout = 0; SET lock_timeout = 0; @@ -1088,6 +1088,43 @@ CREATE VIEW public.user_disk_usages AS GROUP BY objects.user_id; +-- +-- Name: user_push_subscriptions; Type: TABLE; Schema: public; Owner: - +-- + +CREATE TABLE public.user_push_subscriptions ( + id integer NOT NULL, + user_id integer NOT NULL, + uuid uuid DEFAULT gen_random_uuid() NOT NULL, + created_at timestamp without time zone DEFAULT CURRENT_TIMESTAMP NOT NULL, + deleted_at timestamp without time zone, + user_agent text, + endpoint text NOT NULL, + p256dh bytea NOT NULL, + auth bytea NOT NULL +); + + +-- +-- Name: user_push_subscriptions_id_seq; Type: SEQUENCE; Schema: public; Owner: - +-- + +CREATE SEQUENCE public.user_push_subscriptions_id_seq + AS integer + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1; + + +-- +-- Name: user_push_subscriptions_id_seq; Type: SEQUENCE OWNED BY; Schema: public; Owner: - +-- + +ALTER SEQUENCE public.user_push_subscriptions_id_seq OWNED BY public.user_push_subscriptions.id; + + -- -- Name: users; Type: TABLE; Schema: public; Owner: - -- @@ -1264,6 +1301,13 @@ ALTER TABLE ONLY public.sources ALTER COLUMN id SET DEFAULT nextval('public.sour ALTER TABLE ONLY public.uploads ALTER COLUMN id SET DEFAULT nextval('public.uploads_id_seq'::regclass); +-- +-- Name: user_push_subscriptions id; Type: DEFAULT; Schema: public; Owner: - +-- + +ALTER TABLE ONLY public.user_push_subscriptions ALTER COLUMN id SET DEFAULT nextval('public.user_push_subscriptions_id_seq'::regclass); + + -- -- Name: users id; Type: DEFAULT; Schema: public; Owner: - -- @@ -1455,6 +1499,14 @@ ALTER TABLE ONLY public.uploads ADD CONSTRAINT uploads_id PRIMARY KEY (id); +-- +-- Name: user_push_subscriptions user_push_subscriptions_pkey; Type: CONSTRAINT; Schema: public; Owner: - +-- + +ALTER TABLE ONLY public.user_push_subscriptions + ADD CONSTRAINT user_push_subscriptions_pkey PRIMARY KEY (id); + + -- -- Name: build_comments_build_id; Type: INDEX; Schema: public; Owner: - -- @@ -1784,6 +1836,13 @@ CREATE UNIQUE INDEX sources_slug ON public.sources USING btree (slug) WHERE (del CREATE UNIQUE INDEX uploads_uuid ON public.uploads USING btree (uuid); +-- +-- Name: user_push_subscriptions_user_id; Type: INDEX; Schema: public; Owner: - +-- + +CREATE INDEX user_push_subscriptions_user_id ON public.user_push_subscriptions USING btree (user_id) WHERE (deleted_at IS NULL); + + -- -- Name: sources on_update_current_timestamp; Type: TRIGGER; Schema: public; Owner: - -- @@ -2255,6 +2314,14 @@ ALTER TABLE ONLY public.uploads ADD CONSTRAINT uploads_user_id FOREIGN KEY (user_id) REFERENCES public.users(id); +-- +-- Name: user_push_subscriptions user_push_subscriptions_user_id; Type: FK CONSTRAINT; Schema: public; Owner: - +-- + +ALTER TABLE ONLY public.user_push_subscriptions + ADD CONSTRAINT user_push_subscriptions_user_id FOREIGN KEY (user_id) REFERENCES public.users(id); + + -- -- Name: SCHEMA public; Type: ACL; Schema: -; Owner: - -- diff --git a/src/scripts/pakfire-build-service b/src/scripts/pakfire-build-service index 8622a426..ed30120e 100644 --- a/src/scripts/pakfire-build-service +++ b/src/scripts/pakfire-build-service @@ -52,6 +52,10 @@ class Cli(object): # Sync "sync" : self.backend.sync, + # Users + "users:generate-vapid-keys" : self.backend.users.generate_vapid_keys, + "users:send-push-message" : self._users_send_push_message, + # Dist #"dist" : self.backend.sources.dist, @@ -197,6 +201,16 @@ class Cli(object): """ return await self.backend.mirrors.check(force=True) + async def _users_send_push_message(self, name, message): + # Fetch the user + user = self.backend.users.get_by_name(name) + if not user: + log.error("Could not find user %s" % name) + return + + # Send the message + await user.send_push_message(message) + async def main(): cli = Cli() diff --git a/src/static/js/notification-worker.js b/src/static/js/notification-worker.js new file mode 100644 index 00000000..b79b222f --- /dev/null +++ b/src/static/js/notification-worker.js @@ -0,0 +1,40 @@ +'use strict'; + +self.addEventListener("push", function(event) { + var data = {}; + + // Fetch the data as JSON + try { + data = event.data.json(); + } catch (e) { + // Nothing + } + + // Log what we have received + console.debug("Push notification has been received: " + data); + + const title = data.title || event.data.text(); + + const options = { + "body" : data.message, + }; + + // Show the notification + const notification = self.registration.showNotification(title, options); + + event.waitUntil(notification); +}); + +/* + Handle when the user clicks the notification +*/ +self.addEventListener("notificationclick", function(event) { + // Close the notification + event.notification.close(); + + event.waitUntil( + clients.openWindow('https://developers.google.com/web/') + ); +}); + +// pushsubscriptionchange Handle this? diff --git a/src/static/js/user-push-subscribe-button.js b/src/static/js/user-push-subscribe-button.js new file mode 100644 index 00000000..c17829ca --- /dev/null +++ b/src/static/js/user-push-subscribe-button.js @@ -0,0 +1,67 @@ +/* + * Request permission when the button is being clicked + */ + +// Check if the browser supports notifications +$(function() { + // Nothing to do if the browser supports notifications + if ("serviceWorker" in navigator && "PushManager" in window) + return; + + // If not, we will disable the button + $("#push-subscribe-button").prop("disabled", true); +}); + +// Handle button click +$("#push-subscribe-button").on("click", function() { + console.debug("Subscribe button clicked!"); + + // Fetch our application server key + const application_server_key = $(this).data("application-server-key"); + + // Request permission from the user + const request = new Promise(function (resolve, reject) { + const result = Notification.requestPermission(function (result) { + resolve(result); + }); + + if (result) { + result.then(resolve, reject); + } + }).then(function (result) { + if (result !== 'granted') { + throw new Error("We weren't granted permission."); + } + }); + + // Show some activity + $(this).addClass("is-loading"); + + // Register our service worker + var registration = navigator.serviceWorker.register("/static/js/notification-worker.min.js"); + + // Register with the push service + registration = registration.then(function (registration) { + return registration.pushManager.subscribe({ + userVisibleOnly: true, + applicationServerKey: application_server_key, + }); + }) + + // Fetch the PushSubscription + const subscription = registration.then(function (subscription) { + console.debug("Received PushSubscription: ", JSON.stringify(subscription)); + + // Send the PushSubscription to our server + $.post({ + "url" : "/users/push/subscribe", + + // Payload + "contentType" : "application/json", + "data" : JSON.stringify(subscription), + }); + + return subscription; + }); +}); + diff --git a/src/templates/users/modules/push-subscribe-button.html b/src/templates/users/modules/push-subscribe-button.html new file mode 100644 index 00000000..e0ea53e2 --- /dev/null +++ b/src/templates/users/modules/push-subscribe-button.html @@ -0,0 +1,4 @@ + diff --git a/src/templates/users/subscribe.html b/src/templates/users/subscribe.html new file mode 100644 index 00000000..60488dce --- /dev/null +++ b/src/templates/users/subscribe.html @@ -0,0 +1,40 @@ +{% extends "../modal.html" %} + +{% block title %}{{ _("Subscribe To Push Notifications") }}{% end block %} + +{% block breadcrumbs %} + +{% end block %} + +{% block modal_title %} +

{{ _("Subscribe To Push Notifications") }}

+{% end block %} + +{% block modal %} + {% raw xsrf_form_html() %} + +
+

+ {{ _("Do you want to subscribe to push notifications?") }} +

+
+ + {# Submit! #} +
+ {% module UserPushSubscribeButton() %} +
+{% end block %} diff --git a/src/web/__init__.py b/src/web/__init__.py index a23cfded..351ff7b3 100644 --- a/src/web/__init__.py +++ b/src/web/__init__.py @@ -79,6 +79,7 @@ class Application(tornado.web.Application): # Users "UsersList" : users.ListModule, + "UserPushSubscribeButton" : users.PushSubscribeButton, "CommitMessage" : ui_modules.CommitMessageModule, "CommitsTable" : ui_modules.CommitsTableModule, @@ -111,6 +112,7 @@ class Application(tornado.web.Application): (r"/users/(\w+)", users.ShowHandler), (r"/users/(\w+)/delete", users.DeleteHandler), (r"/users/(\w+)/edit", users.EditHandler), + (r"/users/push/subscribe", users.PushSubscribeHandler), # User Repositories (r"/users/(\w+)/repos/create", repos.CreateCustomHandler), @@ -216,6 +218,7 @@ class Application(tornado.web.Application): logging.info("Successfully initialied application") # Launch some initial tasks + self.backend.run_task(self.backend.users.generate_vapid_keys) self.backend.run_task(self.backend.builders.sync) self.backend.run_task(self.backend.builders.autoscale) diff --git a/src/web/users.py b/src/web/users.py index 80842991..888c0de9 100644 --- a/src/web/users.py +++ b/src/web/users.py @@ -1,5 +1,6 @@ #!/usr/bin/python +import json import tornado.locale import tornado.web @@ -90,6 +91,51 @@ class BuildsHandler(base.BaseHandler): self.render("users/builds.html", user=user, builds=user.builds) +class PushSubscribeHandler(base.BaseHandler): + @tornado.web.authenticated + def get(self): + self.render("users/subscribe.html") + + @tornado.web.authenticated + async def post(self): + # The request body must be JSON + if not self.request.headers.get("Content-Type") == "application/json": + raise tornado.web.HTTPError(400) + + # Parse the JSON blob + try: + blob = json.loads(self.request.body) + except json.DecodeError as e: + raise tornado.web.HTTPError(400, "Could not parse JSON: %s" % e) from e + + # Fetch all values + args = { + "endpoint" : blob.get("endpoint"), + "p256dh" : blob.get("keys").get("p256dh"), + "auth" : blob.get("keys").get("auth"), + + # Add the user agent + "user_agent" : self.user_agent, + } + + with self.db.transaction(): + await self.current_user.subscribe(**args) + + class ListModule(ui_modules.UIModule): def render(self, users): return self.render_string("users/modules/list.html", users=users) + + +class PushSubscribeButton(ui_modules.UIModule): + def render(self): + # Fetch the application server key + application_server_key = self.backend.users.application_server_key + + return self.render_string("users/modules/push-subscribe-button.html", + application_server_key=application_server_key) + + def javascript_files(self): + return ( + "js/user-push-subscribe-button.min.js", + )