git.ipfire.org Git - pbs.git/commitdiff
uploads: Refactor caching the whole thing again...
authorMichael Tremer <michael.tremer@ipfire.org>
Sun, 23 Oct 2022 15:33:52 +0000 (15:33 +0000)
committerMichael Tremer <michael.tremer@ipfire.org>
Sun, 23 Oct 2022 15:33:52 +0000 (15:33 +0000)
Signed-off-by: Michael Tremer <michael.tremer@ipfire.org>
src/buildservice/packages.py
src/buildservice/uploads.py
src/database.sql
src/hub/uploads.py
tests/test.py

index 8bd465949eb5037a8b04cb043a29c8a5eb332461..e821800bdb2a47ddea3ef349d2042cc643b608b5 100644 (file)
@@ -81,10 +81,6 @@ class Packages(base.Object):
                if not self.backend.test and distro:
                        raise RuntimeError("Cannot alter distro when not in test mode")
 
-               # Check if the upload has been completed
-               if not upload.is_completed():
-                       raise RuntimeError("Cannot import package from incomplete upload")
-
                # Upload the archive
                archive = await self.backend.open(upload.path)
 
index b63e0890f6d4891f486f68a957588882f5d884df..9783fe7ceffc7767236b88db7770c41c400dfe3b 100644 (file)
@@ -3,7 +3,6 @@
 import asyncio
 import hashlib
 import hmac
-import io
 import logging
 import os
 import shutil
@@ -46,8 +45,6 @@ class Uploads(base.Object):
                                uploads
                        WHERE
                                uuid = %s
-                       AND
-                               completed_at IS NOT NULL
                        AND
                                expires_at > CURRENT_TIMESTAMP
                        """, uuid,
@@ -199,77 +196,10 @@ class Upload(base.DataObject):
        def expires_at(self):
                return self.data.expires_at
 
-       @lazy_property
-       def _buffer(self):
-               """
-                       Returns something that buffers any uploaded data.
-               """
-               return io.BytesIO()
-
        @lazy_property
        def _filesize(self):
                return os.path.getsize(self.path)
 
-       async def write(self, data):
-               """
-                       Takes a chunk of data and writes it to disk
-               """
-               #log.debug("Received a new chunk of %s byte(s) of data for %s" % (len(data), self))
-
-               # Check if we would exceed the filesize
-               if self._filesize + len(data) > self.size:
-                       raise OverflowError
-
-               # Append the chunk to the buffer
-               self._buffer.write(data)
-               self._filesize += len(data)
-
-               # Flush the buffer to disk after it has reached its threshold
-               if self._buffer.tell() >= MAX_BUFFER_SIZE:
-                       await self.flush()
-
-       async def flush(self):
-               """
-                       Flushes any buffered file content to disk
-               """
-               # Nothing to do if the buffer is empty
-               if not self._buffer.tell():
-                       return
-
-               #log.debug("Flushing buffer (%s byte(s))" % self._buffer.tell())
-
-               return await asyncio.to_thread(self._flush)
-
-       def _flush(self):
-               # Move back to the beginning of the buffer
-               self._buffer.seek(0)
-
-               # Append the buffer to the file
-               with open(self.path, "ab") as f:
-                       shutil.copyfileobj(self._buffer, f)
-
-               # Reset the buffer
-               self._buffer = io.BytesIO()
-
-       async def completed(self):
-               """
-                       Called when all content of the upload is received
-               """
-               # Flush anything that might have been left in the buffer
-               await self.flush()
-
-               # Mark as completed
-               self._set_attribute_now("completed_at")
-
-       def is_completed(self):
-               """
-                       Returns True if this upload has completed
-               """
-               if self.data.completed_at:
-                       return True
-
-               return False
-
        async def check_digest(self, algo, expected_digest):
                """
                        Checks if the upload matches an expected digest
@@ -290,9 +220,6 @@ class Upload(base.DataObject):
                """
                        Computes the digest of this download
                """
-               if not self.is_completed():
-                       raise RuntimeError("Upload has not been completed, yet")
-
                log.debug("Computing '%s' digest for %s" % (algo, self))
 
                return await asyncio.to_thread(self._digest, algo)
@@ -313,6 +240,17 @@ class Upload(base.DataObject):
                # Return the digest
                return h.digest()
 
+       async def copyfrom(self, src):
+               """
+                       Copies the content of this upload from the source file descriptor
+               """
+               return await asyncio.to_thread(self._copyfrom, src)
+
+       def _copyfrom(self, src):
+               # Open the destination file and copy all source data into it
+               with open(self.path, "wb") as dst:
+                       shutil.copyfileobj(src, dst)
+
        async def copyinto(self, dst):
                """
                        Copies the content of this upload into the destination file descriptor.
index 145270c6fc45e629e0d83d63c0222f0fad73a9aa..a6f2b863ec0ffe7051efae3228a719dc8e9b7731 100644 (file)
@@ -957,8 +957,7 @@ CREATE TABLE public.uploads (
     path text NOT NULL,
     size bigint NOT NULL,
     created_at timestamp without time zone DEFAULT CURRENT_TIMESTAMP NOT NULL,
-    expires_at timestamp without time zone DEFAULT (CURRENT_TIMESTAMP + '24:00:00'::interval) NOT NULL,
-    completed_at timestamp without time zone
+    expires_at timestamp without time zone DEFAULT (CURRENT_TIMESTAMP + '24:00:00'::interval) NOT NULL
 );
 
 
index 1fe36cdc3fc96b8aef11eceb2ca5036de30281c5..485dfedfe5df206d7e4db0b5b1689e1de40e78ee 100644 (file)
@@ -19,6 +19,7 @@
 #                                                                             #
 ###############################################################################
 
+import io
 import tornado.web
 
 from .handlers import BaseHandler
@@ -26,64 +27,60 @@ from .. import users
 
 @tornado.web.stream_request_body
 class CreateHandler(BaseHandler):
-       @tornado.web.authenticated
        def initialize(self):
+               # Buffer to cache the uploaded content
+               self.buffer = io.BytesIO()
+
+       def data_received(self, data):
+               """
+                       Called when some data is being received
+               """
+               self.buffer.write(data)
+
+       @tornado.web.authenticated
+       async def put(self):
+               """
+                       Called after the entire file has been received
+               """
                # Fetch the filename
                filename = self.get_argument("filename")
 
                # Fetch file size
                size = self.get_argument_int("size")
 
+               # Fetch the digest argument
+               algo, delim, hexdigest = self.get_argument("digest").partition(":")
+
+               # Convert hexdigest
+               digest = bytes.fromhex(hexdigest)
+
+               # Move to the beginning of the buffer
+               self.buffer.seek(0)
+
                # Create a new upload
                with self.db.transaction():
                        try:
-                               self.upload = self.backend.uploads.create(
+                               upload = self.backend.uploads.create(
                                        filename,
                                        size=size,
                                        builder=self.builder,
                                        user=self.user,
                                )
+
                        except users.QuotaExceededError as e:
                                raise tornado.web.HTTPError(400,
                                        "Quota exceeded for %s" % self.current_user) from e
 
-       async def data_received(self, data):
-               """
-                       Called when some data is being received
-               """
-               await self.upload.write(data)
-
-       async def put(self):
-               """
-                       Called after the entire file has been received
-               """
-               # Consider the upload completed
-               await self.upload.completed()
-
-               # Fetch the digest argument
-               algo, delim, hexdigest = self.get_argument("digest").partition(":")
-
-               # Convert hexdigest
-               digest = bytes.fromhex(hexdigest)
+                       # Import the payload from the buffer
+                       await upload.copyfrom(self.buffer)
 
-               # Check the digest
-               if not await self.upload.check_digest(algo, digest):
-                       # 422 - Unprocessable Entity
-                       raise tornado.web.HTTPError(422, "Digest did not match")
+                       # Check the digest
+                       if not await upload.check_digest(algo, digest):
+                               # 422 - Unprocessable Entity
+                               raise tornado.web.HTTPError(422, "Digest did not match")
 
                # Send the ID of the upload back to the client
                self.finish({
-                       "id"         : self.upload.uuid,
-                       "expires_at" : self.upload.expires_at.isoformat(),
+                       "id"         : upload.uuid,
+                       "expires_at" : upload.expires_at.isoformat(),
                })
-
-               # Free upload to avoid cleanup
-               self.upload = None
-
-       def on_connection_close(self):
-               """
-                       Called when a connection was unexpectedly closed
-               """
-               # Delete the upload
-               #if self.upload:
-               #       await self.upload.delete()
index 8a3a99766cf7922b87714e82d64723b2ac66afc5..b85d14cbf1720aac88f7ca2ef89a9e94b2c65edc 100644 (file)
@@ -195,14 +195,6 @@ class TestCase(unittest.IsolatedAsyncioTestCase):
 
                # Copy the payload
                with open(path, "rb") as f:
-                       while True:
-                               buf = f.read(4096)
-                               if not buf:
-                                       break
-
-                               await upload.write(buf)
-
-               # Complete the upload
-               await upload.completed()
+                       await upload.copyfrom(f)
 
                return upload