uploads: Refactor this once again
author Michael Tremer <michael.tremer@ipfire.org>
Wed, 29 Jan 2025 18:18:51 +0000 (18:18 +0000)
committer Michael Tremer <michael.tremer@ipfire.org>
Wed, 29 Jan 2025 18:18:51 +0000 (18:18 +0000)
This is not 100% async, but good enough. If the system cannot write to
its own file system, we probably have bigger issues.

This will now spool the uploaded data in memory for a little while, but
flush it out to disk once it has grown too large to be kept in memory
forever.
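
A minimal sketch of the spooling behaviour this relies on (the handler
below uses a 128 MiB threshold; the tiny max_size here is only to force
the rollover):

    import tempfile

    # Data stays in an in-memory buffer until max_size is exceeded,
    # then it is transparently flushed out to a real file on disk.
    with tempfile.SpooledTemporaryFile(max_size=16) as f:
        f.write(b"short")      # still held in memory
        f.write(b"x" * 64)     # exceeds max_size, rolls over to disk
        f.seek(0)              # rewind before reading back
        data = f.read()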

The file will then be copied (yes I know) and the new path stored.
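
Roughly what the new copyfrom() boils down to (a sketch with simplified
names; the real code goes through Backend.tempfile(sync=True,
delete=False) as shown in the first hunk below):

    import asyncio
    import shutil
    import tempfile

    async def copyfrom(src, tmpdir="/var/tmp"):
        # The caller rewinds src before handing it over in the real code
        src.seek(0)

        # Copy everything in a worker thread so the event loop stays free
        with tempfile.NamedTemporaryFile(prefix="upload-", dir=tmpdir, delete=False) as dst:
            await asyncio.to_thread(shutil.copyfileobj, src, dst)

        # Store the path of the copy
        return dst.name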

Signed-off-by: Michael Tremer <michael.tremer@ipfire.org>
src/buildservice/__init__.py
src/buildservice/uploads.py
src/web/uploads.py

index 894dbba5df666cd3a396efa7abc23d85a0f2d15a..ca3e5e8f83550a7d5b461d515b5adf8e5e1d4504 100644 (file)
@@ -448,13 +448,17 @@ class Backend(object):
        async def chmod(self, *args, **kwargs):
                return await asyncio.to_thread(os.chmod, *args, **kwargs)
 
-       def tempfile(self, mode="w+b", delete=True):
+       def tempfile(self, mode="w+b", sync=False, **kwargs):
                """
                        Returns an open file handle to a new temporary file
                """
                path = self.path("tmp")
 
-               return aiofiles.tempfile.NamedTemporaryFile(mode=mode, dir=path, delete=delete)
+               # If requested, return a sync temporary file
+               if sync:
+                       return tempfile.NamedTemporaryFile(mode=mode, dir=path, **kwargs)
+
+               return aiofiles.tempfile.NamedTemporaryFile(mode=mode, dir=path, **kwargs)
 
        def tempdir(self, **kwargs):
                """
index fff733f7c2735329e6f5e4b597eeecaae4f6338a..07de87159cddb8a4ea8201dc2a816d2e6406d239 100644 (file)
@@ -7,6 +7,7 @@ import logging
 import os
 import shutil
 import sqlalchemy
+import tempfile
 
 from sqlalchemy import Column, ForeignKey
 from sqlalchemy import BigInteger, DateTime, Integer, LargeBinary, Text, UUID
@@ -224,51 +225,12 @@ class Upload(database.Base, database.BackendMixin):
                if await self.has_payload():
                        raise FileExistsError("We already have the payload")
 
-               # Reset the source file handle
-               src.seek(0)
-
-               # Setup the hash function
-               h = hashlib.new("blake2b512")
-
-               async with self.backend.tempfile(delete=False) as f:
-                       try:
-                               while True:
-                                       buffer = src.read(1024 ** 2)
-                                       if not buffer:
-                                               break
-
-                                       # Put the buffer into the hash function
-                                       h.update(buffer)
-
-                                       # Write the buffer to disk
-                                       await f.write(buffer)
-
-                               # How many bytes did we write?
-                               received_size = await f.tell()
-
-                               # Get the digest
-                               computed_digest = h.digest()
-
-                               # Check if the filesize matches
-                               if not received_size == self.size:
-                                       raise ValueError("File size differs")
-
-                               # Check that the digest matches
-                               if not hmac.compare_digest(computed_digest, self.digest_blake2b512):
-                                       log.error("Upload %s had an incorrect digest:" % self)
-                                       log.error("  Expected: %s" % self.digest_blake2b512.hex())
-                                       log.error("  Got     : %s" % computed_digest.hex())
-
-                                       raise ValueError("Invalid digest")
-
-                       # If there has been any kind of exception, we want to delete the temporary file
-                       except Exception as e:
-                               await self.backend.unlink(f.name)
-
-                               raise e
+               # Copy the entire content to a new temporary file
+               with self.backend.tempfile(prefix="upload-", sync=True, delete=False) as dst:
+                       await asyncio.to_thread(shutil.copyfileobj, src, dst)
 
                # Store the path
-               self.path = f.name
+               self.path = dst.name
 
        # Copy the payload to somewhere else
 
index 97d1269f852b24b8931543311ce3c1c24041f808..888441fee42265a6a70f39aae30f656804497745 100644 (file)
 #                                                                             #
 ###############################################################################
 
+import asyncio
 import errno
 import hashlib
 import hmac
 import io
+import tempfile
 import tornado.web
 
 from . import base
@@ -118,13 +120,17 @@ class APIv1DetailHandler(base.APIMixin, base.BaseHandler):
        )
 
        def initialize(self):
-               # Buffer to cache the uploaded content
-               self.buffer = io.BytesIO()
+               # Create a temporary buffer in memory which will be flushed out to disk
+               # once it has received more than 128 MiB of data
+               self.f = tempfile.SpooledTemporaryFile(
+                       max_size = 128 * 1024 * 1024, # 128 MiB
+                       dir      = self.backend.path("tmp"),
+               )
 
                # Initialize the hashers
                self.hashers = { h : hashlib.new(h) for h in self.hashes }
 
-       def data_received(self, data):
+       async def data_received(self, data):
                """
                        Called when some data is being received
                """
@@ -133,7 +139,7 @@ class APIv1DetailHandler(base.APIMixin, base.BaseHandler):
                        self.hashers[h].update(data)
 
                # Write the data to the buffer
-               self.buffer.write(data)
+               await asyncio.to_thread(self.f.write, data)
 
        @base.negotiate
        async def get(self, uuid):
@@ -168,9 +174,12 @@ class APIv1DetailHandler(base.APIMixin, base.BaseHandler):
                # XXX has perm?
 
                # Fail if we did not receive anything
-               if not self.buffer.tell():
+               if not self.f.tell():
                        raise base.APIError(errno.ENODATA, "No data received")
 
+               # Rewind the file
+               self.f.seek(0)
+
                # Finalize digests of received data
                digests = {
                        h : self.hashers[h].digest() for h in self.hashers
@@ -183,12 +192,11 @@ class APIv1DetailHandler(base.APIMixin, base.BaseHandler):
                                        raise tornado.web.HTTPError(409, "%s digest does not match" % algo)
 
                # Import the payload from the buffer
-               async with await self.db.transaction():
-                       try:
-                               await upload.copyfrom(self.buffer)
+               try:
+                       await upload.copyfrom(self.f)
 
-                       except ValueError as e:
-                               raise base.APIError(errno.EINVAL, "%s" % e) from e
+               except ValueError as e:
+                       raise base.APIError(errno.EINVAL, "%s" % e) from e
 
                # Send no response
                self.set_status(204)