#!/usr/bin/python3
+import asyncio
+import base64
+import binascii
+import cryptography.hazmat.backends
+import cryptography.hazmat.primitives.asymmetric.ec
+import cryptography.hazmat.primitives.asymmetric.utils
+import cryptography.hazmat.primitives.ciphers
+import cryptography.hazmat.primitives.ciphers.aead
+import cryptography.hazmat.primitives.hashes
+import cryptography.hazmat.primitives.kdf.hkdf
+import cryptography.hazmat.primitives.serialization
import datetime
import email.utils
+import json
import ldap
import logging
+import os
import pickle
+import struct
import threading
import time
+import urllib.parse
import tornado.locale
from . import base
from . import bugtracker
+from . import httpclient
from .decorators import *
"mailAlternateAddress",
)
+class QuotaExceededError(Exception):
+ pass
+
class Users(base.Object):
def init(self):
# Initialize thread-local storage
return list(users)
+ # Push Notifications
+
+ @property
+ def vapid_public_key(self):
+ """
+ The public part of the VAPID key
+ """
+ return self.settings.get("vapid-public-key")
+
+ @property
+ def vapid_private_key(self):
+ """
+ The private part of the VAPID key
+ """
+ return self.settings.get("vapid-private-key")
+
+ async def generate_vapid_keys(self):
+ """
+ Generates the VAPID keys
+ """
+ # Do not generate a new key if one exists
+ if self.vapid_public_key and self.vapid_private_key:
+ return
+
+ with self.db.transaction():
+ # Generate the private key
+ private_key = await self.backend.command(
+ "openssl",
+ "ecparam",
+ "-name", "prime256v1",
+ "-genkey",
+ "-noout",
+ return_output=True,
+ )
+
+ # Generate the public key
+ public_key = await self.backend.command(
+ "openssl",
+ "ec",
+ "-pubout",
+ input=private_key,
+ return_output=True,
+ )
+
+ # Store the keys
+ self.settings.set("vapid-public-key", public_key)
+ self.settings.set("vapid-private-key", private_key)
+
+ log.info("VAPID keys have been successfully generated")
+
+ @property
+ def application_server_key(self):
+ lines = []
+
+ for line in self.vapid_public_key.splitlines():
+ if line.startswith("-"):
+ continue
+
+ lines.append(line)
+
+ # Join everything together
+ key = "".join(lines)
+
+ # Decode the key
+ key = base64.b64decode(key)
+
+ # Only take the last bit
+ key = key[-65:]
+
+ # Encode the key URL-safe
+ key = base64.urlsafe_b64encode(key).strip(b"=")
+
+ return key
+
class User(base.DataObject):
table = "users"
return list(uploads)
+ # Push Subscriptions
-class QuotaExceededError(Exception):
- pass
+ def _get_subscriptions(self, query, *args):
+ res = self.db.query(query, *args)
+
+ for row in res:
+ yield UserPushSubscription(self.backend, row.id, data=row)
+
+ def _get_subscription(self, query, *args):
+ res = self.db.get(query, *args)
+
+ if res:
+ return UserPushSubscription(self.backend, res.id, data=res)
+
+ @lazy_property
+ def subscriptions(self):
+ subscriptions = self._get_subscriptions("""
+ SELECT
+ *
+ FROM
+ user_push_subscriptions
+ WHERE
+ deleted_at IS NULL
+ AND
+ user_id = %s
+ ORDER BY
+ created_at
+ """, self.id,
+ )
+
+ return set(subscriptions)
+
+ async def subscribe(self, endpoint, p256dh, auth, user_agent=None):
+ """
+ Creates a new subscription for this user
+ """
+ # Decode p256dh
+ if not isinstance(p256dh, bytes):
+ p256dh = base64.urlsafe_b64decode(p256dh + "==")
+
+ # Decode auth
+ if not isinstance(auth, bytes):
+ auth = base64.urlsafe_b64decode(auth + "==")
+
+ subscription = self._get_subscription("""
+ INSERT INTO
+ user_push_subscriptions
+ (
+ user_id,
+ user_agent,
+ endpoint,
+ p256dh,
+ auth
+ )
+ VALUES
+ (
+ %s, %s, %s, %s, %s
+ )
+ RETURNING *
+ """, self.id, user_agent, endpoint, p256dh, auth,
+ )
+
+ # Send a message
+ await subscription.send("Hello World!")
+
+ return subscription
+
+ async def send_push_message(self, message):
+ """
+ Sends a message to all active subscriptions
+ """
+ async with asyncio.TaskGroup() as tg:
+ for subscription in self.subscriptions:
+ tg.create_task(subscription.send(message))
+
+
+class UserPushSubscription(base.DataObject):
+ table = "user_push_subscriptions"
+
+ @property
+ def uuid(self):
+ """
+ UUID
+ """
+ return self.data.uuid
+
+ @property
+ def created_at(self):
+ return self.data.created_at
+
+ @property
+ def deleted_at(self):
+ return self.data.deleted_at
+
+ def delete(self):
+ """
+ Deletes this subscription
+ """
+ self._set_attribute_now("deleted_at")
+
+ @property
+ def endpoint(self):
+ return self.data.endpoint
+
+ @lazy_property
+ def p256dh(self):
+ """
+ The client's public key
+ """
+ p = cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePublicKey.from_encoded_point(
+ cryptography.hazmat.primitives.asymmetric.ec.SECP256R1(), bytes(self.data.p256dh),
+ )
+
+ return p
+
+ @property
+ def auth(self):
+ return bytes(self.data.auth)
+
+ @property
+ def vapid_private_key(self):
+ return cryptography.hazmat.primitives.serialization.load_pem_private_key(
+ self.backend.users.vapid_private_key.encode(),
+ password=None,
+ backend=cryptography.hazmat.backends.default_backend(),
+ )
+
+ @property
+ def vapid_public_key(self):
+ return self.vapid_private_key.public_key()
+
+ async def send(self, message, ttl=0):
+ """
+ Sends a message to the user using the push service
+ """
+ # Convert strings into a message object
+ if isinstance(message, str):
+ message = {
+ "message" : message,
+ }
+
+ # Convert dict() to JSON
+ if isinstance(message, dict):
+ message = json.dumps(message)
+
+ # Encode everything as bytes
+ if not isinstance(message, bytes):
+ message = message.encode()
+
+ # Encrypt the message
+ message = self._encrypt(message)
+
+ # Create a signature
+ signature = self._sign()
+
+ # Encode the public key
+ crypto_key = self.b64encode(
+ self.vapid_public_key.public_bytes(
+ cryptography.hazmat.primitives.serialization.Encoding.X962,
+ cryptography.hazmat.primitives.serialization.PublicFormat.UncompressedPoint,
+ )
+ ).decode()
+
+ # Form request headers
+ headers = {
+ "Authorization" : "WebPush %s" % signature,
+ "Crypto-Key" : "p256ecdsa=%s" % crypto_key,
+
+ "Content-Type" : "application/octet-stream",
+ "Content-Encoding" : "aes128gcm",
+ "TTL" : "%s" % (ttl or 0),
+ }
+
+ # Send the request
+ try:
+ await self.backend.httpclient.fetch(self.endpoint, method="POST",
+ headers=headers, body=message)
+
+ except httpclient.HTTPError as e:
+ # 410 - Gone
+ # The subscription is no longer valid
+ if e.code == 410:
+ # Let's just delete ourselves
+ self.delete()
+ return
+
+ # Raise everything else
+ raise e
+
+ def _sign(self):
+ elements = []
+
+ for element in (self._jwt_info, self._jwt_data):
+ # Format the dictionary
+ element = json.dumps(element, separators=(',', ':'), sort_keys=True)
+
+ # Encode to bytes
+ element = element.encode()
+
+ # Encode URL-safe in base64 and remove any padding
+ element = self.b64encode(element)
+
+ elements.append(element)
+
+ # Concatenate
+ token = b".".join(elements)
+
+ log.debug("String to sign: %s" % token)
+
+ # Create the signature
+ signature = self.vapid_private_key.sign(
+ token,
+ cryptography.hazmat.primitives.asymmetric.ec.ECDSA(
+ cryptography.hazmat.primitives.hashes.SHA256(),
+ ),
+ )
+
+ # Decode the signature
+ r, s = cryptography.hazmat.primitives.asymmetric.utils.decode_dss_signature(signature)
+
+ # Encode the signature in base64
+ signature = self.b64encode(
+ self._num_to_bytes(r, 32) + self._num_to_bytes(s, 32),
+ )
+
+ # Put everything together
+ signature = b"%s.%s" % (token, signature)
+ signature = signature.decode()
+
+ log.debug("Created signature: %s" % signature)
+
+ return signature
+
+ _jwt_info = {
+ "typ" : "JWT",
+ "alg" : "ES256",
+ }
+
+ @property
+ def _jwt_data(self):
+ # Parse the URL
+ url = urllib.parse.urlparse(self.endpoint)
+
+ # Let the signature expire after 12 hours
+ expires = time.time() + (12 * 3600)
+
+ return {
+ "aud" : "%s://%s" % (url.scheme, url.netloc),
+ "exp" : int(expires),
+ "sub" : "mailto:info@ipfire.org",
+ }
+
+ @staticmethod
+ def _num_to_bytes(n, pad_to):
+ """
+ Returns the byte representation of an integer, in big-endian order.
+ """
+ h = "%x" % n
+
+ r = binascii.unhexlify("0" * (len(h) % 2) + h)
+ return b"\x00" * (pad_to - len(r)) + r
+
+ @staticmethod
+ def _serialize_key(key):
+ if isinstance(key, cryptography.hazmat.primitives.asymmetric.ec.EllipticCurvePrivateKey):
+ return key.private_bytes(
+ cryptography.hazmat.primitives.serialization.Encoding.DER,
+ cryptography.hazmat.primitives.serialization.PrivateFormat.PKCS8,
+ cryptography.hazmat.primitives.serialization.NoEncryption(),
+ )
+
+ return key.public_bytes(
+ cryptography.hazmat.primitives.serialization.Encoding.X962,
+ cryptography.hazmat.primitives.serialization.PublicFormat.UncompressedPoint,
+ )
+
+ @staticmethod
+ def b64encode(data):
+ return base64.urlsafe_b64encode(data).strip(b"=")
+
+ def _encrypt(self, message):
+ """
+ This is an absolutely ugly monster of a function which will sign the message
+ """
+ headers = {}
+
+ # Generate some salt
+ salt = os.urandom(16)
+
+ record_size = 4096
+ chunk_size = record_size - 17
+
+ # Generate an ephemeral server key
+ server_private_key = cryptography.hazmat.primitives.asymmetric.ec.generate_private_key(
+ cryptography.hazmat.primitives.asymmetric.ec.SECP256R1,
+ cryptography.hazmat.backends.default_backend(),
+ )
+ server_public_key = server_private_key.public_key()
+
+ context = b"WebPush: info\x00"
+
+ # Serialize the client's public key
+ context += self.p256dh.public_bytes(
+ cryptography.hazmat.primitives.serialization.Encoding.X962,
+ cryptography.hazmat.primitives.serialization.PublicFormat.UncompressedPoint,
+ )
+
+ # Serialize the server's public key
+ context += server_public_key.public_bytes(
+ cryptography.hazmat.primitives.serialization.Encoding.X962,
+ cryptography.hazmat.primitives.serialization.PublicFormat.UncompressedPoint,
+ )
+
+ # Perform key derivation with ECDH
+ secret = server_private_key.exchange(
+ cryptography.hazmat.primitives.asymmetric.ec.ECDH(), self.p256dh,
+ )
+
+ # Derive more stuff
+ hkdf_auth = cryptography.hazmat.primitives.kdf.hkdf.HKDF(
+ algorithm=cryptography.hazmat.primitives.hashes.SHA256(),
+ length=32,
+ salt=self.auth,
+ info=context,
+ backend=cryptography.hazmat.backends.default_backend(),
+ )
+ secret = hkdf_auth.derive(secret)
+
+ # Derive the signing key
+ hkdf_key = cryptography.hazmat.primitives.kdf.hkdf.HKDF(
+ algorithm=cryptography.hazmat.primitives.hashes.SHA256(),
+ length=16,
+ salt=salt,
+ info=b"Content-Encoding: aes128gcm\x00",
+ backend=cryptography.hazmat.backends.default_backend(),
+ )
+ encryption_key = hkdf_key.derive(secret)
+
+ # Derive a nonce
+ hkdf_nonce = cryptography.hazmat.primitives.kdf.hkdf.HKDF(
+ algorithm=cryptography.hazmat.primitives.hashes.SHA256(),
+ length=12,
+ salt=salt,
+ info=b"Content-Encoding: nonce\x00",
+ backend=cryptography.hazmat.backends.default_backend(),
+ )
+ nonce = hkdf_nonce.derive(secret)
+
+ result = b""
+ chunks = 0
+
+ while True:
+ # Fetch a chunk
+ chunk, message = message[:chunk_size], message[chunk_size:]
+ if not chunk:
+ break
+
+ # Is this the last chunk?
+ last = not message
+
+ # Encrypt the chunk
+ result += self._encrypt_chunk(encryption_key, nonce, chunks, chunk, last)
+
+ # Kepp counting...
+ chunks += 1
+
+ # Fetch the public key
+ key_id = server_public_key.public_bytes(
+ cryptography.hazmat.primitives.serialization.Encoding.X962,
+ cryptography.hazmat.primitives.serialization.PublicFormat.UncompressedPoint,
+ )
+
+ # Join the entire message together
+ message = [
+ salt,
+ struct.pack("!L", record_size),
+ struct.pack("!B", len(key_id)),
+ key_id,
+ result,
+ ]
+
+ return b"".join(message)
+
+ def _encrypt_chunk(self, key, nonce, counter, chunk, last=False):
+ """
+ Encrypts one chunk
+ """
+ # Make the IV
+ iv = self._make_iv(nonce, counter)
+
+ log.debug("Encrypting chunk %s: length = %s" % (counter + 1, len(chunk)))
+
+ if last:
+ chunk += b"\x02"
+ else:
+ chunk += b"\x01"
+
+ # Setup AES GCM
+ cipher = cryptography.hazmat.primitives.ciphers.Cipher(
+ cryptography.hazmat.primitives.ciphers.algorithms.AES128(key),
+ cryptography.hazmat.primitives.ciphers.modes.GCM(iv),
+ backend=cryptography.hazmat.backends.default_backend(),
+ )
+
+ # Get the encryptor
+ encryptor = cipher.encryptor()
+
+ # Encrypt the chunk
+ chunk = encryptor.update(chunk)
+
+ # Finalize this round
+ chunk += encryptor.finalize() + encryptor.tag
+
+ return chunk
+
+ @staticmethod
+ def _make_iv(base, counter):
+ mask, = struct.unpack("!Q", base[4:])
+
+ return base[:4] + struct.pack("!Q", counter ^ mask)