git.ipfire.org Git - pbs.git/commitdiff
hub: Refactor uploads for streaming
author: Michael Tremer <michael.tremer@ipfire.org>
Thu, 26 May 2022 09:05:11 +0000 (09:05 +0000)
committer: Michael Tremer <michael.tremer@ipfire.org>
Thu, 26 May 2022 09:05:11 +0000 (09:05 +0000)
Signed-off-by: Michael Tremer <michael.tremer@ipfire.org>
src/buildservice/uploads.py
src/database.sql
src/hub/__init__.py
src/hub/handlers.py

index b994a404be26357efa0b2fef032844a05832b6e7..7c7316fecaa1d698200900ac59006e22fce497be 100644 (file)
@@ -1,16 +1,10 @@
-#!/usr/bin/python
+#!/usr/bin/python3
 
-import datetime
-import hashlib
-import logging
-import os
 import shutil
+import tempfile
+import uuid
 
 from . import base
-from . import misc
-from . import packages
-from . import users
-
 from .constants import *
 from .decorators import *
 
@@ -28,42 +22,80 @@ class Uploads(base.Object):
                        yield Upload(self.backend, row.id, data=row)
 
        def __iter__(self):
-               uploads = self._get_uploads("SELECT * FROM uploads ORDER BY time_started DESC")
+               uploads = self._get_uploads("SELECT * FROM uploads \
+                       ORDER BY created_at DESC")
 
                return iter(uploads)
 
        def get_by_uuid(self, uuid):
-               return self._get_upload("SELECT * FROM uploads WHERE uuid = %s", uuid)
-
-       def create(self, filename, size, hash, builder=None, user=None):
-               assert builder or user
-
-               # Create a random ID for this upload
-               uuid = users.generate_random_string(64)
-
-               upload = self._get_upload("INSERT INTO uploads(uuid, filename, size, hash) \
-                       VALUES(%s, %s, %s, %s) RETURNING *", uuid, filename, size, hash)
-
-               if builder:
-                       upload.builder = builder
-
-               elif user:
-                       upload.user = user
-
-               # Create space to where we save the data.
-               dirname = os.path.dirname(upload.path)
-               if not os.path.exists(dirname):
-                       os.makedirs(dirname)
-
-               # Create empty file.
-               f = open(upload.path, "w")
-               f.close()
-
+               return self._get_upload("SELECT * FROM uploads \
+                       WHERE uuid = %s AND expires_at > CURRENT_TIMESTAMP", uuid)
+
+       def create(self, filename, f, builder=None, user=None):
+               # Check if either builder or user are set
+               if not builder and not user:
+                       raise ValueError("builder or user must be set")
+
+               # Reset f
+               f.seek(0)
+
+               # Create a new temporary file
+               t = tempfile.NamedTemporaryFile(dir=UPLOADS_DIR, delete=False)
+
+               # Copy all content from f
+               shutil.copyfileobj(f, t)
+
+               upload = self._get_upload("""
+                       INSERT INTO
+                               uploads
+                       (
+                               uuid,
+                               filename,
+                               path,
+                               size,
+                               builder_id,
+                               user_id
+                       )
+                       VALUES
+                       (
+                               %s,
+                               %s,
+                               %s,
+                               %s,
+                               %s,
+                               %s
+                       )
+                       RETURNING *""",
+                       "%s" % uuid.uuid4(),
+                       filename,
+                       t.name,
+                       t.tell(),
+                       builder.id if builder else None,
+                       user.id if user else None,
+               )
+
+               # Close the temporary file
+               t.close()
+
+               # Return the newly created upload object
                return upload
 
        def cleanup(self):
-               for upload in self:
-                       upload.cleanup()
+               # Find all expired uploads
+               uploads = self._get_uploads("""
+                       SELECT
+                               *
+                       FROM
+                               uploads
+                       WHERE
+                               expires_at <= CURRENT_TIMESTAMP
+                       ORDER BY
+                               created_at
+               """)
+
+               # Delete them all
+               for upload in uploads:
+                       upload.delete()
 
 
 class Upload(base.DataObject):
@@ -73,26 +105,18 @@ class Upload(base.DataObject):
        def uuid(self):
                return self.data.uuid
 
-       @property
-       def hash(self):
-               return self.data.hash
-
        @property
        def filename(self):
                return self.data.filename
 
        @property
        def path(self):
-               return os.path.join(UPLOADS_DIR, self.uuid, self.filename)
+               return self.data.path
 
        @property
        def size(self):
                return self.data.size
 
-       @property
-       def progress(self):
-               return self.data.progress / self.size
-
        # Builder
 
        def get_builder(self):
@@ -115,80 +139,17 @@ class Upload(base.DataObject):
 
        user = lazy_property(get_user, set_user)
 
-       def append(self, data):
-               # Check if the filesize was exceeded.
-               size = os.path.getsize(self.path) + len(data)
-               if size > self.data.size:
-                       raise Exception("Given filesize was exceeded for upload %s" % self.uuid)
-
-               logging.debug("Writing %s bytes to %s" % (len(data), self.path))
-
-               with open(self.path, "ab") as f:
-                       f.write(data)
+       def delete(self):
+               # Remove the uploaded data
+               shutil.rmtree(self.path, ignore_errors=True)
 
-               self._set_attribute("progress", size)
-
-       def validate(self):
-               size = os.path.getsize(self.path)
-               if not size == self.data.size:
-                       logging.error("Filesize is not okay: %s" % (self.uuid))
-                       return False
-
-               # Calculate a hash to validate the upload.
-               hash = misc.calc_hash(self.path, "sha1")
-
-               if not self.hash == hash:
-                       logging.error("Hash did not match: %s != %s" % (self.hash, hash))
-                       return False
-
-               return True
-
-       def finished(self):
-               """
-                       Update the status of the upload in the database to "finished".
-               """
-               # Check if the file was completely uploaded and the hash is correct.
-               # If not, the upload has failed.
-               if not self.validate():
-                       return False
-
-               self._set_attribute("finished", True)
-               self._set_attribute("time_finished", datetime.datetime.utcnow())
-
-               return True
-
-       def remove(self):
-               # Remove the uploaded data.
-               path = os.path.dirname(self.path)
-               if os.path.exists(path):
-                       shutil.rmtree(path, ignore_errors=True)
-
-               # Delete the upload from the database.
+               # Delete the upload from the database
                self.db.execute("DELETE FROM uploads WHERE id = %s", self.id)
 
        @property
-       def time_started(self):
-               return self.data.time_started
-
-       @property
-       def time_running(self):
-               # Get the seconds since we are running.
-               try:
-                       time_running = datetime.datetime.utcnow() - self.time_started
-                       time_running = time_running.total_seconds()
-               except:
-                       time_running = 0
-
-               return time_running
+       def created_at(self):
+               return self.data.created_at
 
        @property
-       def speed(self):
-               if not self.time_running:
-                       return 0
-
-               return self.data.progress / self.time_running
-
-       def cleanup(self):
-               # Remove uploads that are older than 2 hours.
-               if self.time_running >= 3600 * 2:
-                       self.remove()
+       def expires_at(self):
+               return self.data.expires_at
index 75bb7afa3a08d6db20b4e0d5800cd7f20153aff6..b1140887d44f5a8e51885ad9d4addf4b2d472c92 100644 (file)
@@ -1377,31 +1377,29 @@ ALTER SEQUENCE sources_id_seq OWNED BY sources.id;
 
 
 --
--- Name: uploads; Type: TABLE; Schema: public; Owner: pakfire; Tablespace: 
+-- Name: uploads; Type: TABLE; Schema: public; Owner: pakfire
 --
 
-CREATE TABLE uploads (
+CREATE TABLE public.uploads (
     id integer NOT NULL,
     uuid text NOT NULL,
     user_id integer,
     builder_id integer,
     filename text NOT NULL,
-    hash text,
+    path text NOT NULL,
     size bigint NOT NULL,
-    progress bigint DEFAULT 0 NOT NULL,
-    finished boolean DEFAULT false NOT NULL,
-    time_started timestamp without time zone DEFAULT now() NOT NULL,
-    time_finished timestamp without time zone
+    created_at timestamp without time zone DEFAULT CURRENT_TIMESTAMP NOT NULL,
+    expires_at timestamp without time zone DEFAULT (CURRENT_TIMESTAMP + '24:00:00'::interval) NOT NULL
 );
 
 
-ALTER TABLE uploads OWNER TO pakfire;
+ALTER TABLE public.uploads OWNER TO pakfire;
 
 --
 -- Name: uploads_id_seq; Type: SEQUENCE; Schema: public; Owner: pakfire
 --
 
-CREATE SEQUENCE uploads_id_seq
+CREATE SEQUENCE public.uploads_id_seq
     START WITH 1
     INCREMENT BY 1
     NO MINVALUE
@@ -1409,13 +1407,13 @@ CREATE SEQUENCE uploads_id_seq
     CACHE 1;
 
 
-ALTER TABLE uploads_id_seq OWNER TO pakfire;
+ALTER TABLE public.uploads_id_seq OWNER TO pakfire;
 
 --
 -- Name: uploads_id_seq; Type: SEQUENCE OWNED BY; Schema: public; Owner: pakfire
 --
 
-ALTER SEQUENCE uploads_id_seq OWNED BY uploads.id;
+ALTER SEQUENCE public.uploads_id_seq OWNED BY public.uploads.id;
 
 
 --
index 1037e420d626732f3a2ee2375d4ea2539ad4c59e..34c2393b68b9bcf008aaac33471a74cc5f4d33df 100644 (file)
@@ -40,6 +40,7 @@ class Application(tornado.web.Application):
                        (r"/packages/(.*)", handlers.PackagesGetHandler),
 
                        # Uploads
+                       (r"/upload", handlers.UploadHandler),
                        (r"/uploads/create", handlers.UploadsCreateHandler),
                        (r"/uploads/stream", handlers.UploadsStreamHandler),
                        (r"/uploads/(.*)/sendchunk", handlers.UploadsSendChunkHandler),
index 26ce1426afcc13a7055664f00a7bff62a070b1d6..fcdf50891c7697a195b5d4d1d61a80f56c42c4c8 100644 (file)
@@ -2,8 +2,10 @@
 
 import base64
 import hashlib
+import hmac
 import json
 import logging
+import tempfile
 import time
 import tornado.web
 
@@ -12,6 +14,8 @@ from .. import builders
 from .. import uploads
 from .. import users
 
+log = logging.getLogger("pakfire.hub")
+
 class LongPollMixin(object):
        def initialize(self):
                self._start_time = time.time()
@@ -141,6 +145,88 @@ class ErrorTestHandler(BaseHandler):
 
 # Uploads
 
+@tornado.web.stream_request_body
+class UploadHandler(BaseHandler):
+       @tornado.web.authenticated
+       def initialize(self):
+               # Fetch the filename
+               self.filename = self.get_argument("filename")
+
+               # Fetch file size
+               self.filesize = self.get_argument_int("size")
+
+               # XXX check quota
+
+               # Count how many bytes have been received
+               self.bytes_read = 0
+
+               # Allocate a temporary file
+               self.f = tempfile.SpooledTemporaryFile(max_size=10485760)
+
+               self.h, self.hexdigest = self._setup_digest()
+
+       def data_received(self, data):
+               """
+                       Called when some data is being received
+               """
+               self.bytes_read += len(data)
+
+               # Abort if we have received more data than expected
+               if self.bytes_read > self.filesize:
+                       # 422 - Unprocessable Entity
+                       raise tornado.web.HTTPError(422)
+
+               # Update the digest
+               self.h.update(data)
+
+               # Write payload
+               self.f.write(data)
+
+       def put(self):
+               """
+                       Called after the entire file has been received
+               """
+               log.debug("Finished receiving data")
+
+               # Finish computing the hexdigest
+               computed_hexdigest = self.h.hexdigest()
+
+               # Check if the hexdigest matches
+               # If not, we will raise an error
+               if not hmac.compare_digest(self.hexdigest, computed_hexdigest):
+                       # 422 - Unprocessable Entity
+                       raise tornado.web.HTTPError(422)
+
+               # Create a new upload object
+               with self.db.transaction():
+                       upload = self.backend.uploads.create(
+                               self.filename,
+                               self.f,
+                               builder=self.builder,
+                               user=self.user,
+                       )
+
+               # Send the ID of the upload back to the client
+               self.finish({
+                       "id"         : upload.uuid,
+                       "expires_at" : upload.expires_at.isoformat(),
+               })
+
+       def _setup_digest(self):
+               # Fetch the digest argument
+               digest = self.get_argument("digest")
+
+               # Find the algorithm
+               algo, delim, hexdigest = digest.partition(":")
+
+               try:
+                       h = hashlib.new(algo)
+               except ValueError as e:
+                       raise tornado.web.HTTPError(415) from e
+
+               return h, hexdigest
+
+
 class UploadsCreateHandler(BaseHandler):
        """
                Create a new upload object in the database and return a unique ID