From: Michael Tremer Date: Thu, 26 May 2022 09:05:11 +0000 (+0000) Subject: hub: Refactor uploads for streaming X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=96bcb9e7c9a8a0338bc33ee367db669bcd022786;p=pbs.git hub: Refactor uploads for streaming Signed-off-by: Michael Tremer --- diff --git a/src/buildservice/uploads.py b/src/buildservice/uploads.py index b994a404..7c7316fe 100644 --- a/src/buildservice/uploads.py +++ b/src/buildservice/uploads.py @@ -1,16 +1,10 @@ -#!/usr/bin/python +#!/usr/bin/python3 -import datetime -import hashlib -import logging -import os import shutil +import tempfile +import uuid from . import base -from . import misc -from . import packages -from . import users - from .constants import * from .decorators import * @@ -28,42 +22,80 @@ class Uploads(base.Object): yield Upload(self.backend, row.id, data=row) def __iter__(self): - uploads = self._get_uploads("SELECT * FROM uploads ORDER BY time_started DESC") + uploads = self._get_uploads("SELECT * FROM uploads \ + ORDER BY created_at DESC") return iter(uploads) def get_by_uuid(self, uuid): - return self._get_upload("SELECT * FROM uploads WHERE uuid = %s", uuid) - - def create(self, filename, size, hash, builder=None, user=None): - assert builder or user - - # Create a random ID for this upload - uuid = users.generate_random_string(64) - - upload = self._get_upload("INSERT INTO uploads(uuid, filename, size, hash) \ - VALUES(%s, %s, %s, %s) RETURNING *", uuid, filename, size, hash) - - if builder: - upload.builder = builder - - elif user: - upload.user = user - - # Create space to where we save the data. - dirname = os.path.dirname(upload.path) - if not os.path.exists(dirname): - os.makedirs(dirname) - - # Create empty file. - f = open(upload.path, "w") - f.close() - + return self._get_upload("SELECT * FROM uploads \ + WHERE uuid = %s AND expires_at > CURRENT_TIMESTAMP", uuid) + + def create(self, filename, f, builder=None, user=None): + # Check if either builder or user are set + if not builder and not user: + raise ValueError("builder or user must be set") + + # Reset f + f.seek(0) + + # Create a new temporary file + t = tempfile.NamedTemporaryFile(dir=UPLOADS_DIR, delete=False) + + # Copy all content from f + shutil.copyfileobj(f, t) + + upload = self._get_upload(""" + INSERT INTO + uploads + ( + uuid, + filename, + path, + size, + builder_id, + user_id + ) + VALUES + ( + %s, + %s, + %s, + %s, + %s, + %s + ) + RETURNING *""", + "%s" % uuid.uuid4(), + filename, + t.name, + t.tell(), + builder.id if builder else None, + user.id if user else None, + ) + + # Close the temporary file + t.close() + + # Return the newly created upload object return upload def cleanup(self): - for upload in self: - upload.cleanup() + # Find all expired uploads + uploads = self._get_uploads(""" + SELECT + * + FROM + uploads + WHERE + expires_at <= CURRENT_TIMESTAMP + ORDER BY + created_at + """) + + # Delete them all + for upload in uploads: + upload.delete() class Upload(base.DataObject): @@ -73,26 +105,18 @@ class Upload(base.DataObject): def uuid(self): return self.data.uuid - @property - def hash(self): - return self.data.hash - @property def filename(self): return self.data.filename @property def path(self): - return os.path.join(UPLOADS_DIR, self.uuid, self.filename) + return self.data.path @property def size(self): return self.data.size - @property - def progress(self): - return self.data.progress / self.size - # Builder def get_builder(self): @@ -115,80 +139,17 @@ class Upload(base.DataObject): user = lazy_property(get_user, set_user) - def append(self, data): - # Check if the filesize was exceeded. - size = os.path.getsize(self.path) + len(data) - if size > self.data.size: - raise Exception("Given filesize was exceeded for upload %s" % self.uuid) - - logging.debug("Writing %s bytes to %s" % (len(data), self.path)) - - with open(self.path, "ab") as f: - f.write(data) + def delete(self): + # Remove the uploaded data + shutil.rmtree(self.path, ignore_errors=True) - self._set_attribute("progress", size) - - def validate(self): - size = os.path.getsize(self.path) - if not size == self.data.size: - logging.error("Filesize is not okay: %s" % (self.uuid)) - return False - - # Calculate a hash to validate the upload. - hash = misc.calc_hash(self.path, "sha1") - - if not self.hash == hash: - logging.error("Hash did not match: %s != %s" % (self.hash, hash)) - return False - - return True - - def finished(self): - """ - Update the status of the upload in the database to "finished". - """ - # Check if the file was completely uploaded and the hash is correct. - # If not, the upload has failed. - if not self.validate(): - return False - - self._set_attribute("finished", True) - self._set_attribute("time_finished", datetime.datetime.utcnow()) - - return True - - def remove(self): - # Remove the uploaded data. - path = os.path.dirname(self.path) - if os.path.exists(path): - shutil.rmtree(path, ignore_errors=True) - - # Delete the upload from the database. + # Delete the upload from the database self.db.execute("DELETE FROM uploads WHERE id = %s", self.id) @property - def time_started(self): - return self.data.time_started - - @property - def time_running(self): - # Get the seconds since we are running. - try: - time_running = datetime.datetime.utcnow() - self.time_started - time_running = time_running.total_seconds() - except: - time_running = 0 - - return time_running + def created_at(self): + return self.data.created_at @property - def speed(self): - if not self.time_running: - return 0 - - return self.data.progress / self.time_running - - def cleanup(self): - # Remove uploads that are older than 2 hours. - if self.time_running >= 3600 * 2: - self.remove() + def expires_at(self): + return self.data.expires_at diff --git a/src/database.sql b/src/database.sql index 75bb7afa..b1140887 100644 --- a/src/database.sql +++ b/src/database.sql @@ -1377,31 +1377,29 @@ ALTER SEQUENCE sources_id_seq OWNED BY sources.id; -- --- Name: uploads; Type: TABLE; Schema: public; Owner: pakfire; Tablespace: +-- Name: uploads; Type: TABLE; Schema: public; Owner: pakfire -- -CREATE TABLE uploads ( +CREATE TABLE public.uploads ( id integer NOT NULL, uuid text NOT NULL, user_id integer, builder_id integer, filename text NOT NULL, - hash text, + path text NOT NULL, size bigint NOT NULL, - progress bigint DEFAULT 0 NOT NULL, - finished boolean DEFAULT false NOT NULL, - time_started timestamp without time zone DEFAULT now() NOT NULL, - time_finished timestamp without time zone + created_at timestamp without time zone DEFAULT CURRENT_TIMESTAMP NOT NULL, + expires_at timestamp without time zone DEFAULT (CURRENT_TIMESTAMP + '24:00:00'::interval) NOT NULL ); -ALTER TABLE uploads OWNER TO pakfire; +ALTER TABLE public.uploads OWNER TO pakfire; -- -- Name: uploads_id_seq; Type: SEQUENCE; Schema: public; Owner: pakfire -- -CREATE SEQUENCE uploads_id_seq +CREATE SEQUENCE public.uploads_id_seq START WITH 1 INCREMENT BY 1 NO MINVALUE @@ -1409,13 +1407,13 @@ CREATE SEQUENCE uploads_id_seq CACHE 1; -ALTER TABLE uploads_id_seq OWNER TO pakfire; +ALTER TABLE public.uploads_id_seq OWNER TO pakfire; -- -- Name: uploads_id_seq; Type: SEQUENCE OWNED BY; Schema: public; Owner: pakfire -- -ALTER SEQUENCE uploads_id_seq OWNED BY uploads.id; +ALTER SEQUENCE public.uploads_id_seq OWNED BY public.uploads.id; -- diff --git a/src/hub/__init__.py b/src/hub/__init__.py index 1037e420..34c2393b 100644 --- a/src/hub/__init__.py +++ b/src/hub/__init__.py @@ -40,6 +40,7 @@ class Application(tornado.web.Application): (r"/packages/(.*)", handlers.PackagesGetHandler), # Uploads + (r"/upload", handlers.UploadHandler), (r"/uploads/create", handlers.UploadsCreateHandler), (r"/uploads/stream", handlers.UploadsStreamHandler), (r"/uploads/(.*)/sendchunk", handlers.UploadsSendChunkHandler), diff --git a/src/hub/handlers.py b/src/hub/handlers.py index 26ce1426..fcdf5089 100644 --- a/src/hub/handlers.py +++ b/src/hub/handlers.py @@ -2,8 +2,10 @@ import base64 import hashlib +import hmac import json import logging +import tempfile import time import tornado.web @@ -12,6 +14,8 @@ from .. import builders from .. import uploads from .. import users +log = logging.getLogger("pakfire.hub") + class LongPollMixin(object): def initialize(self): self._start_time = time.time() @@ -141,6 +145,88 @@ class ErrorTestHandler(BaseHandler): # Uploads +@tornado.web.stream_request_body +class UploadHandler(BaseHandler): + @tornado.web.authenticated + def initialize(self): + # Fetch the filename + self.filename = self.get_argument("filename") + + # Fetch file size + self.filesize = self.get_argument_int("size") + + # XXX check quota + + # Count how many bytes have been received + self.bytes_read = 0 + + # Allocate a temporary file + self.f = tempfile.SpooledTemporaryFile(max_size=10485760) + + self.h, self.hexdigest = self._setup_digest() + + def data_received(self, data): + """ + Called when some data is being received + """ + self.bytes_read += len(data) + + # Abort if we have received more data than expected + if self.bytes_read > self.filesize: + # 422 - Unprocessable Entity + raise tornado.web.HTTPError(422) + + # Update the digest + self.h.update(data) + + # Write payload + self.f.write(data) + + def put(self): + """ + Called after the entire file has been received + """ + log.debug("Finished receiving data") + + # Finish computing the hexdigest + computed_hexdigest = self.h.hexdigest() + + # Check if the hexdigest matches + # If not, we will raise an error + if not hmac.compare_digest(self.hexdigest, computed_hexdigest): + # 422 - Unprocessable Entity + raise tornado.web.HTTPError(422) + + # Create a new upload object + with self.db.transaction(): + upload = self.backend.uploads.create( + self.filename, + self.f, + builder=self.builder, + user=self.user, + ) + + # Send the ID of the upload back to the client + self.finish({ + "id" : upload.uuid, + "expires_at" : upload.expires_at.isoformat(), + }) + + def _setup_digest(self): + # Fetch the digest argument + digest = self.get_argument("digest") + + # Find the algorithm + algo, delim, hexdigest = digest.partition(":") + + try: + h = hashlib.new(algo) + except ValueError as e: + raise tornado.web.HTTPError(415) from e + + return h, hexdigest + + class UploadsCreateHandler(BaseHandler): """ Create a new upload object in the database and return a unique ID