]> git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Resolves test issues with Python 3.12 (#6902)
authorTrenton H <797416+stumpylog@users.noreply.github.com>
Mon, 3 Jun 2024 19:33:46 +0000 (12:33 -0700)
committerGitHub <noreply@github.com>
Mon, 3 Jun 2024 19:33:46 +0000 (12:33 -0700)
src/documents/consumer.py
src/documents/tests/test_api_app_config.py
src/paperless/tests/test_settings.py

index 0d5514e2c78d26beab9ad9a0d530a4aa5ed655cb..9447fb329fafe3ea31ead7baf2c33083a7872a94 100644 (file)
@@ -485,56 +485,65 @@ class ConsumerPlugin(
         Return the document object if it was successfully created.
         """
 
-        self._send_progress(
-            0,
-            100,
-            ProgressStatusOptions.STARTED,
-            ConsumerStatusShortMessage.NEW_FILE,
-        )
+        tempdir = None
 
-        # Make sure that preconditions for consuming the file are met.
+        try:
+            self._send_progress(
+                0,
+                100,
+                ProgressStatusOptions.STARTED,
+                ConsumerStatusShortMessage.NEW_FILE,
+            )
 
-        self.pre_check_file_exists()
-        self.pre_check_directories()
-        self.pre_check_duplicate()
-        self.pre_check_asn_value()
+            # Make sure that preconditions for consuming the file are met.
 
-        self.log.info(f"Consuming {self.filename}")
+            self.pre_check_file_exists()
+            self.pre_check_directories()
+            self.pre_check_duplicate()
+            self.pre_check_asn_value()
 
-        # For the actual work, copy the file into a tempdir
-        tempdir = tempfile.TemporaryDirectory(
-            prefix="paperless-ngx",
-            dir=settings.SCRATCH_DIR,
-        )
-        self.working_copy = Path(tempdir.name) / Path(self.filename)
-        copy_file_with_basic_stats(self.input_doc.original_file, self.working_copy)
+            self.log.info(f"Consuming {self.filename}")
 
-        # Determine the parser class.
+            # For the actual work, copy the file into a tempdir
+            tempdir = tempfile.TemporaryDirectory(
+                prefix="paperless-ngx",
+                dir=settings.SCRATCH_DIR,
+            )
+            self.working_copy = Path(tempdir.name) / Path(self.filename)
+            copy_file_with_basic_stats(self.input_doc.original_file, self.working_copy)
 
-        mime_type = magic.from_file(self.working_copy, mime=True)
+            # Determine the parser class.
 
-        self.log.debug(f"Detected mime type: {mime_type}")
+            mime_type = magic.from_file(self.working_copy, mime=True)
 
-        # Based on the mime type, get the parser for that type
-        parser_class: Optional[type[DocumentParser]] = get_parser_class_for_mime_type(
-            mime_type,
-        )
-        if not parser_class:
-            tempdir.cleanup()
-            self._fail(
-                ConsumerStatusShortMessage.UNSUPPORTED_TYPE,
-                f"Unsupported mime type {mime_type}",
+            self.log.debug(f"Detected mime type: {mime_type}")
+
+            # Based on the mime type, get the parser for that type
+            parser_class: Optional[type[DocumentParser]] = (
+                get_parser_class_for_mime_type(
+                    mime_type,
+                )
             )
+            if not parser_class:
+                tempdir.cleanup()
+                self._fail(
+                    ConsumerStatusShortMessage.UNSUPPORTED_TYPE,
+                    f"Unsupported mime type {mime_type}",
+                )
 
-        # Notify all listeners that we're going to do some work.
+            # Notify all listeners that we're going to do some work.
 
-        document_consumption_started.send(
-            sender=self.__class__,
-            filename=self.working_copy,
-            logging_group=self.logging_group,
-        )
+            document_consumption_started.send(
+                sender=self.__class__,
+                filename=self.working_copy,
+                logging_group=self.logging_group,
+            )
 
-        self.run_pre_consume_script()
+            self.run_pre_consume_script()
+        except:
+            if tempdir:
+                tempdir.cleanup()
+            raise
 
         def progress_callback(current_progress, max_progress):  # pragma: no cover
             # recalculate progress to be within 20 and 80
@@ -593,6 +602,9 @@ class ConsumerPlugin(
             archive_path = document_parser.get_archive_path()
 
         except ParseError as e:
+            document_parser.cleanup()
+            if tempdir:
+                tempdir.cleanup()
             self._fail(
                 str(e),
                 f"Error occurred while consuming document {self.filename}: {e}",
@@ -601,7 +613,8 @@ class ConsumerPlugin(
             )
         except Exception as e:
             document_parser.cleanup()
-            tempdir.cleanup()
+            if tempdir:
+                tempdir.cleanup()
             self._fail(
                 str(e),
                 f"Unexpected error while consuming document {self.filename}: {e}",
index ba14e664a60ee3a86ec9dcbf594ca3f65313aa5e..0d7771c0798df0a6a74f2db181b3857bd17ed2f5 100644 (file)
@@ -70,12 +70,13 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
         config.app_logo = "/logo/example.jpg"
         config.save()
         response = self.client.get("/api/ui_settings/", format="json")
-        self.assertDictContainsSubset(
+        self.assertDictEqual(
+            response.data["settings"],
             {
                 "app_title": config.app_title,
                 "app_logo": config.app_logo,
-            },
-            response.data["settings"],
+            }
+            response.data["settings"],
         )
 
     def test_api_update_config(self):
index e27630ffa3a196b04af3235b5ac99dbc2731af5b..0051d40e747cd24001f82b6a4fb5f0c37621e8de 100644 (file)
@@ -339,11 +339,12 @@ class TestDBSettings(TestCase):
         ):
             databases = _parse_db_settings()
 
-            self.assertDictContainsSubset(
-                {
+            self.assertDictEqual(
+                databases["default"]["OPTIONS"],
+                databases["default"]["OPTIONS"]
+                | {
                     "connect_timeout": 10.0,
                 },
-                databases["default"]["OPTIONS"],
             )
             self.assertDictEqual(
                 {