git.ipfire.org Git - pbs.git/commitdiff
uploads: Refactor so that we won't duplicate any data
author Michael Tremer <michael.tremer@ipfire.org>
Tue, 18 Oct 2022 14:34:54 +0000 (14:34 +0000)
committer Michael Tremer <michael.tremer@ipfire.org>
Tue, 18 Oct 2022 14:34:54 +0000 (14:34 +0000)
Signed-off-by: Michael Tremer <michael.tremer@ipfire.org>
src/buildservice/uploads.py
src/hub/uploads.py
tests/test.py
tests/upload.py

diff --git a/src/buildservice/uploads.py b/src/buildservice/uploads.py
index 3e6b88c6a9a115535d732b299b7bc33da4020789..1754a3568220c9d4e17cf8a955217a341f24e263 100644 (file)
@@ -1,11 +1,14 @@
 #!/usr/bin/python3
 
+import asyncio
 import hashlib
+import hmac
 import logging
 import os
 import tempfile
 
 from . import base
+from . import users
 from .constants import *
 from .decorators import *
 
@@ -35,7 +38,7 @@ class Uploads(base.Object):
                return self._get_upload("SELECT * FROM uploads \
                        WHERE uuid = %s AND expires_at > CURRENT_TIMESTAMP", uuid)
 
-       def allocate_file(self):
+       def _allocate_file(self):
                """
                        Returns a file handle which can be used to write temporary data to.
                """
@@ -49,14 +52,18 @@ class Uploads(base.Object):
 
                return tempfile.NamedTemporaryFile(dir=path, delete=False)
 
-       def create(self, filename, path, size=None, builder=None, user=None):
+       def create(self, filename, size, builder=None, user=None):
                # Check if either builder or user are set
                if not builder and not user:
                        raise ValueError("builder or user must be set")
 
-               # Fetch size if none given
-               if size is None:
-                       size = os.path.getsize(path)
+               # Check quota for users
+               if user:
+                       # This will raise an exception if the quota has been exceeded
+                       user.check_quota(size)
+
+               # Allocate a new temporary file
+               f = self._allocate_file()
 
                upload = self._get_upload("""
                        INSERT INTO
@@ -78,7 +85,7 @@ class Uploads(base.Object):
                        )
                        RETURNING *""",
                        filename,
-                       path,
+                       f.name,
                        size,
                        builder,
                        user,
@@ -167,6 +174,43 @@ class Upload(base.DataObject):
        def expires_at(self):
                return self.data.expires_at
 
+       async def write(self, data):
+               """
+                       Takes a chunk of data and writes it to disk
+               """
+               log.debug("Received a new chunk of %s byte(s) of data for %s" % (len(data), self))
+
+               return await asyncio.to_thread(self._write, data)
+
+       def _write(self, data):
+               # Write the data to disk
+               with open(self.path, "ab") as f:
+                       # Fetch the current size
+                       filesize = f.tell()
+
+                       # Check if we would exceed the filesize
+                       if filesize + len(data) > self.size:
+                               raise OverflowError
+
+                       # Write data
+                       f.write(data)
+
+       async def check_digest(self, algo, expected_digest):
+               """
+                       Checks if the upload matches an expected digest
+               """
+               # Compute the digest
+               computed_digest = await self.digest(algo)
+
+               # Compare the digests
+               if hmac.compare_digest(computed_digest, expected_digest):
+                       return True
+
+               # The digests didn't match
+               log.error("Upload does not match its digest")
+
+               return False
+
        async def digest(self, algo):
                """
                        Computes the digest of this download
diff --git a/src/hub/uploads.py b/src/hub/uploads.py
index 84fa1c543e5eaf1ef03fb4124a1d0e4bb8c560ba..f3531e116c35daba7168f171dc711e727b242e8b 100644 (file)
 #                                                                             #
 ###############################################################################
 
-import hashlib
-import hmac
-import logging
-import os
 import tornado.web
 
 from .handlers import BaseHandler
@@ -33,96 +29,58 @@ class CreateHandler(BaseHandler):
        @tornado.web.authenticated
        def initialize(self):
                # Fetch the filename
-               self.filename = self.get_argument("filename")
+               filename = self.get_argument("filename")
 
                # Fetch file size
-               self.filesize = self.get_argument_int("size")
+               size = self.get_argument_int("size")
 
-               # Check quota
-               if isinstance(self.current_user, users.User):
+               # Create a new upload
+               with self.db.transaction():
                        try:
-                               self.current_user.check_quota(self.filesize)
+                               self.upload = self.backend.uploads.create(
+                                       filename,
+                                       size=size,
+                                       builder=self.builder,
+                                       user=self.user,
+                               )
                        except users.QuotaExceededError as e:
                                raise tornado.web.HTTPError(400,
                                        "Quota exceeded for %s" % self.current_user) from e
 
-               # Count how many bytes have been received
-               self.bytes_read = 0
-
-               # Allocate a temporary file
-               self.f = self.backend.uploads.allocate_file()
-
-               self.h, self.hexdigest = self._setup_digest()
-
-       def data_received(self, data):
+       async def data_received(self, data):
                """
                        Called when some data is being received
                """
-               self.bytes_read += len(data)
-
-               # Abort if we have received more data than expected
-               if self.bytes_read > self.filesize:
-                       # 422 - Unprocessable Entity
-                       raise tornado.web.HTTPError(422)
-
-               # Update the digest
-               self.h.update(data)
-
-               # Write payload
-               self.f.write(data)
+               await self.upload.write(data)
 
        async def put(self):
                """
                        Called after the entire file has been received
                """
-               logging.debug("Finished receiving data")
+               # Fetch the digest argument
+               algo, delim, hexdigest = self.get_argument("digest").partition(":")
 
-               # Finish computing the hexdigest
-               computed_hexdigest = self.h.hexdigest()
+               # Convert hexdigest
+               digest = bytes.fromhex(hexdigest)
 
-               # Check if the hexdigest matches
-               # If not, we will raise an error
-               if not hmac.compare_digest(self.hexdigest, computed_hexdigest):
+               # Check the digest
+               if not await self.upload.check_digest(algo, digest):
                        # 422 - Unprocessable Entity
-                       raise tornado.web.HTTPError(422)
-
-               # Create a new upload object
-               with self.db.transaction():
-                       upload = self.backend.uploads.create(
-                               self.filename,
-                               self.f.name,
-                               builder=self.builder,
-                               user=self.user,
-                       )
-
-                       # Free the temporary file (to prevent cleanup)
-                       self.f = None
+                       raise tornado.web.HTTPError(422, "Digest did not match")
 
                # Send the ID of the upload back to the client
                self.finish({
-                       "id"         : upload.uuid,
-                       "expires_at" : upload.expires_at.isoformat(),
+                       "id"         : self.upload.uuid,
+                       "expires_at" : self.upload.expires_at.isoformat(),
                })
 
-       def _setup_digest(self):
-               # Fetch the digest argument
-               digest = self.get_argument("digest")
-
-               # Find the algorithm
-               algo, delim, hexdigest = digest.partition(":")
-
-               try:
-                       h = hashlib.new(algo)
-               except ValueError as e:
-                       raise tornado.web.HTTPError(415) from e
-
-               return h, hexdigest
+               # Free upload to avoid cleanup
+               self.upload = None
 
        def on_connection_close(self):
                """
                        Called when a connection was unexpectedly closed
                """
-               # Try deleting the file
-               if self.f:
-                       logging.debug("Deleting temporary file %s" % self.f.name)
-                       os.unlink(self.f.name)
+               # Delete the upload
+               #if self.upload:
+               #       await self.upload.delete()
diff --git a/tests/test.py b/tests/test.py
index 9230a07226f197a66eb66d2820fe74341b215584..1ffa3124ee51d43ec81b92e7af5928ddae716799 100644 (file)
@@ -3,7 +3,6 @@
 import configparser
 import functools
 import os
-import shutil
 import socket
 import tempfile
 import unittest
@@ -185,21 +184,22 @@ class TestCase(unittest.IsolatedAsyncioTestCase):
                if user is None:
                        user = self.user
 
-               # Allocate a new destination file
-               dst = self.backend.uploads.allocate_file()
-
-               # Open the source file
-               with open(path, "rb") as src:
-                       # Copy the entire content
-                       shutil.copyfileobj(src, dst)
-
-               # Close the destination file
-               dst.close()
+               # Determine the filesize
+               size = os.path.getsize(path)
 
                # Create the upload object
-               upload = self.backend.uploads.create(filename, dst.name, user=user, **kwargs)
+               upload = self.backend.uploads.create(filename, size=size, user=user, **kwargs)
 
                # Check if received the correct type
                self.assertIsInstance(upload, uploads.Upload)
 
+               # Copy the payload
+               with open(path, "rb") as f:
+                       while True:
+                               buf = f.read(4096)
+                               if not buf:
+                                       break
+
+                               upload.write(buf)
+
                return upload
diff --git a/tests/upload.py b/tests/upload.py
index 532020956ae86bbf1e24af9c2782ae79e5b5544f..819456d294d88ca9b312ac1f71d103e0503885bb 100755 (executable)
@@ -1,7 +1,6 @@
 #!/usr/bin/python3
 
 import os
-import tempfile
 import unittest
 
 import test
@@ -12,37 +11,17 @@ class UploadTestCase(test.TestCase):
        """
                Tests everything around uploads
        """
-       def test_allocate(self):
-               """
-                       Tests whether we can allocate temporary files in the right place
-               """
-               file = self.backend.uploads.allocate_file()
-
-               # Check if we received the correct type
-               self.assertIsInstance(file, tempfile._TemporaryFileWrapper)
-
-               # Check whether the file is located within basedir
-               self.assertTrue(file.name.startswith(self.backend.basepath))
-
        async def test_create_delete(self):
                """
                        Tests whether uploads can be created and deleted
                """
-               # Create a new temporary file
-               file = self.backend.uploads.allocate_file()
-               self.assertIsNotNone(file)
-
                # Create the upload object
-               upload = self.backend.uploads.create("test.blob", file.name, user=self.user)
+               upload = self.backend.uploads.create("test.blob", size=0, user=self.user)
 
                self.assertIsInstance(upload, uploads.Upload)
                self.assertEqual(upload.filename, "test.blob")
-               self.assertEqual(upload.path, file.name)
-
-               # "Free" file
-               del file
 
-               # Check if the file still exists
+               # Check if the upload file exists
                self.assertTrue(os.path.exists(upload.path))
 
                # Delete the upload