From: Michael Tremer Date: Tue, 18 Oct 2022 14:34:54 +0000 (+0000) Subject: uploads: Refactor so that we won't duplicate any data X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=7ef2c5281f9265c4d57a1699fade002f9e3d098f;p=pbs.git uploads: Refactor so that we won't duplicate any data Signed-off-by: Michael Tremer --- diff --git a/src/buildservice/uploads.py b/src/buildservice/uploads.py index 3e6b88c6..1754a356 100644 --- a/src/buildservice/uploads.py +++ b/src/buildservice/uploads.py @@ -1,11 +1,14 @@ #!/usr/bin/python3 +import asyncio import hashlib +import hmac import logging import os import tempfile from . import base +from . import users from .constants import * from .decorators import * @@ -35,7 +38,7 @@ class Uploads(base.Object): return self._get_upload("SELECT * FROM uploads \ WHERE uuid = %s AND expires_at > CURRENT_TIMESTAMP", uuid) - def allocate_file(self): + def _allocate_file(self): """ Returns a file handle which can be used to write temporary data to. """ @@ -49,14 +52,18 @@ class Uploads(base.Object): return tempfile.NamedTemporaryFile(dir=path, delete=False) - def create(self, filename, path, size=None, builder=None, user=None): + def create(self, filename, size, builder=None, user=None): # Check if either builder or user are set if not builder and not user: raise ValueError("builder or user must be set") - # Fetch size if none given - if size is None: - size = os.path.getsize(path) + # Check quota for users + if user: + # This will raise an exception if the quota has been exceeded + user.check_quota(size) + + # Allocate a new temporary file + f = self._allocate_file() upload = self._get_upload(""" INSERT INTO @@ -78,7 +85,7 @@ class Uploads(base.Object): ) RETURNING *""", filename, - path, + f.name, size, builder, user, @@ -167,6 +174,43 @@ class Upload(base.DataObject): def expires_at(self): return self.data.expires_at + async def write(self, data): + """ + Takes a chunk of data and writes it to disk + """ + log.debug("Received a new chunk of %s byte(s) of data for %s" % (len(data), self)) + + return await asyncio.to_thread(self._write, data) + + def _write(self, data): + # Write the data to disk + with open(self.path, "ab") as f: + # Fetch the current size + filesize = f.tell() + + # Check if we would exceed the filesize + if filesize + len(data) > self.size: + raise OverflowError + + # Write data + f.write(data) + + async def check_digest(self, algo, expected_digest): + """ + Checks if the upload matches an expected digest + """ + # Compute the digest + computed_digest = await self.digest(algo) + + # Compare the digests + if hmac.compare_digest(computed_digest, expected_digest): + return True + + # The digests didn't match + log.error("Upload does not match its digest") + + return False + async def digest(self, algo): """ Computes the digest of this download diff --git a/src/hub/uploads.py b/src/hub/uploads.py index 84fa1c54..f3531e11 100644 --- a/src/hub/uploads.py +++ b/src/hub/uploads.py @@ -19,10 +19,6 @@ # # ############################################################################### -import hashlib -import hmac -import logging -import os import tornado.web from .handlers import BaseHandler @@ -33,96 +29,58 @@ class CreateHandler(BaseHandler): @tornado.web.authenticated def initialize(self): # Fetch the filename - self.filename = self.get_argument("filename") + filename = self.get_argument("filename") # Fetch file size - self.filesize = self.get_argument_int("size") + size = self.get_argument_int("size") - # Check quota - if isinstance(self.current_user, users.User): + # Create a new upload + with self.db.transaction(): try: - self.current_user.check_quota(self.filesize) + self.upload = self.backend.uploads.create( + filename, + size=size, + builder=self.builder, + user=self.user, + ) except users.QuotaExceededError as e: raise tornado.web.HTTPError(400, "Quota exceeded for %s" % self.current_user) from e - # Count how many bytes have been received - self.bytes_read = 0 - - # Allocate a temporary file - self.f = self.backend.uploads.allocate_file() - - self.h, self.hexdigest = self._setup_digest() - - def data_received(self, data): + async def data_received(self, data): """ Called when some data is being received """ - self.bytes_read += len(data) - - # Abort if we have received more data than expected - if self.bytes_read > self.filesize: - # 422 - Unprocessable Entity - raise tornado.web.HTTPError(422) - - # Update the digest - self.h.update(data) - - # Write payload - self.f.write(data) + await self.upload.write(data) async def put(self): """ Called after the entire file has been received """ - logging.debug("Finished receiving data") + # Fetch the digest argument + algo, delim, hexdigest = self.get_argument("digest").partition(":") - # Finish computing the hexdigest - computed_hexdigest = self.h.hexdigest() + # Convert hexdigest + digest = bytes.fromhex(hexdigest) - # Check if the hexdigest matches - # If not, we will raise an error - if not hmac.compare_digest(self.hexdigest, computed_hexdigest): + # Check the digest + if not await self.upload.check_digest(algo, digest): # 422 - Unprocessable Entity - raise tornado.web.HTTPError(422) - - # Create a new upload object - with self.db.transaction(): - upload = self.backend.uploads.create( - self.filename, - self.f.name, - builder=self.builder, - user=self.user, - ) - - # Free the temporary file (to prevent cleanup) - self.f = None + raise tornado.web.HTTPError(422, "Digest did not match") # Send the ID of the upload back to the client self.finish({ - "id" : upload.uuid, - "expires_at" : upload.expires_at.isoformat(), + "id" : self.upload.uuid, + "expires_at" : self.upload.expires_at.isoformat(), }) - def _setup_digest(self): - # Fetch the digest argument - digest = self.get_argument("digest") - - # Find the algorithm - algo, delim, hexdigest = digest.partition(":") - - try: - h = hashlib.new(algo) - except ValueError as e: - raise tornado.web.HTTPError(415) from e - - return h, hexdigest + # Free upload to avoid cleanup + self.upload = None def on_connection_close(self): """ Called when a connection was unexpectedly closed """ - # Try deleting the file - if self.f: - logging.debug("Deleting temporary file %s" % self.f.name) - os.unlink(self.f.name) + # Delete the upload + #if self.upload: + # await self.upload.delete() diff --git a/tests/test.py b/tests/test.py index 9230a072..1ffa3124 100644 --- a/tests/test.py +++ b/tests/test.py @@ -3,7 +3,6 @@ import configparser import functools import os -import shutil import socket import tempfile import unittest @@ -185,21 +184,22 @@ class TestCase(unittest.IsolatedAsyncioTestCase): if user is None: user = self.user - # Allocate a new destination file - dst = self.backend.uploads.allocate_file() - - # Open the source file - with open(path, "rb") as src: - # Copy the entire content - shutil.copyfileobj(src, dst) - - # Close the destination file - dst.close() + # Determine the filesize + size = os.path.getsize(path) # Create the upload object - upload = self.backend.uploads.create(filename, dst.name, user=user, **kwargs) + upload = self.backend.uploads.create(filename, size=size, user=user, **kwargs) # Check if received the correct type self.assertIsInstance(upload, uploads.Upload) + # Copy the payload + with open(path, "rb") as f: + while True: + buf = f.read(4096) + if not buf: + break + + upload.write(buf) + return upload diff --git a/tests/upload.py b/tests/upload.py index 53202095..819456d2 100755 --- a/tests/upload.py +++ b/tests/upload.py @@ -1,7 +1,6 @@ #!/usr/bin/python3 import os -import tempfile import unittest import test @@ -12,37 +11,17 @@ class UploadTestCase(test.TestCase): """ Tests everything around uploads """ - def test_allocate(self): - """ - Tests whether we can allocate temporary files in the right place - """ - file = self.backend.uploads.allocate_file() - - # Check if we received the correct type - self.assertIsInstance(file, tempfile._TemporaryFileWrapper) - - # Check whether the file is located within basedir - self.assertTrue(file.name.startswith(self.backend.basepath)) - async def test_create_delete(self): """ Tests whether uploads can be created and deleted """ - # Create a new temporary file - file = self.backend.uploads.allocate_file() - self.assertIsNotNone(file) - # Create the upload object - upload = self.backend.uploads.create("test.blob", file.name, user=self.user) + upload = self.backend.uploads.create("test.blob", size=0, user=self.user) self.assertIsInstance(upload, uploads.Upload) self.assertEqual(upload.filename, "test.blob") - self.assertEqual(upload.path, file.name) - - # "Free" file - del file - # Check if the file still exists + # Check if the upload file exists self.assertTrue(os.path.exists(upload.path)) # Delete the upload