]> git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Include global and object-level permissions in export / import 3672/head
authorshamoon <4887959+shamoon@users.noreply.github.com>
Fri, 23 Jun 2023 15:56:18 +0000 (08:56 -0700)
committershamoon <4887959+shamoon@users.noreply.github.com>
Sat, 24 Jun 2023 06:33:36 +0000 (23:33 -0700)
adds test for transaction

src/documents/management/commands/document_exporter.py
src/documents/management/commands/document_importer.py
src/documents/tests/test_management_exporter.py

index fba89695b20ff54e24e3b76e2cc8e4681ef431d5..22fb5930878e26b98fb9c76afb52d1d44887c437 100644 (file)
@@ -11,13 +11,17 @@ from typing import Set
 import tqdm
 from django.conf import settings
 from django.contrib.auth.models import Group
+from django.contrib.auth.models import Permission
 from django.contrib.auth.models import User
+from django.contrib.contenttypes.models import ContentType
 from django.core import serializers
 from django.core.management.base import BaseCommand
 from django.core.management.base import CommandError
 from django.db import transaction
 from django.utils import timezone
 from filelock import FileLock
+from guardian.models import GroupObjectPermission
+from guardian.models import UserObjectPermission
 
 from documents.file_handling import delete_empty_directories
 from documents.file_handling import generate_filename
@@ -261,6 +265,22 @@ class Command(BaseCommand):
                 serializers.serialize("json", UiSettings.objects.all()),
             )
 
+            manifest += json.loads(
+                serializers.serialize("json", ContentType.objects.all()),
+            )
+
+            manifest += json.loads(
+                serializers.serialize("json", Permission.objects.all()),
+            )
+
+            manifest += json.loads(
+                serializers.serialize("json", UserObjectPermission.objects.all()),
+            )
+
+            manifest += json.loads(
+                serializers.serialize("json", GroupObjectPermission.objects.all()),
+            )
+
         # 3. Export files from each document
         for index, document_dict in tqdm.tqdm(
             enumerate(document_manifest),
index b00cb45fac2df945c066bfee03df28b2a4340798..baf6d75285a1a64f6a936952973446bbdb12bc2f 100644 (file)
@@ -7,11 +7,15 @@ from pathlib import Path
 
 import tqdm
 from django.conf import settings
+from django.contrib.auth.models import Permission
+from django.contrib.contenttypes.models import ContentType
 from django.core.exceptions import FieldDoesNotExist
 from django.core.management import call_command
 from django.core.management.base import BaseCommand
 from django.core.management.base import CommandError
 from django.core.serializers.base import DeserializationError
+from django.db import IntegrityError
+from django.db import transaction
 from django.db.models.signals import m2m_changed
 from django.db.models.signals import post_save
 from filelock import FileLock
@@ -116,9 +120,13 @@ class Command(BaseCommand):
         ):
             # Fill up the database with whatever is in the manifest
             try:
-                for manifest_path in manifest_paths:
-                    call_command("loaddata", manifest_path)
-            except (FieldDoesNotExist, DeserializationError) as e:
+                with transaction.atomic():
+                    for manifest_path in manifest_paths:
+                        # delete these since pk can change, re-created from import
+                        ContentType.objects.all().delete()
+                        Permission.objects.all().delete()
+                        call_command("loaddata", manifest_path)
+            except (FieldDoesNotExist, DeserializationError, IntegrityError) as e:
                 self.stdout.write(self.style.ERROR("Database import failed"))
                 if (
                     self.version is not None
index e7c116caf7dcaae41887cec392b5fbb4a591002a..421ae51fca96c9360940f6e09b8d32b092f8fa80 100644 (file)
@@ -7,11 +7,18 @@ from pathlib import Path
 from unittest import mock
 from zipfile import ZipFile
 
+from django.contrib.auth.models import Group
+from django.contrib.auth.models import Permission
+from django.contrib.contenttypes.models import ContentType
 from django.core.management import call_command
 from django.core.management.base import CommandError
+from django.db import IntegrityError
 from django.test import TestCase
 from django.test import override_settings
 from django.utils import timezone
+from guardian.models import GroupObjectPermission
+from guardian.models import UserObjectPermission
+from guardian.shortcuts import assign_perm
 
 from documents.management.commands import document_exporter
 from documents.models import Correspondent
@@ -34,6 +41,8 @@ class TestExportImport(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
         self.addCleanup(shutil.rmtree, self.target)
 
         self.user = User.objects.create(username="temp_admin")
+        self.user2 = User.objects.create(username="user2")
+        self.group1 = Group.objects.create(name="group1")
 
         self.d1 = Document.objects.create(
             content="Content",
@@ -73,6 +82,9 @@ class TestExportImport(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
             user=self.user,
         )
 
+        assign_perm("view_document", self.user2, self.d2)
+        assign_perm("view_document", self.group1, self.d3)
+
         self.t1 = Tag.objects.create(name="t")
         self.dt1 = DocumentType.objects.create(name="dt")
         self.c1 = Correspondent.objects.create(name="c")
@@ -141,12 +153,12 @@ class TestExportImport(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
 
         manifest = self._do_export(use_filename_format=use_filename_format)
 
-        self.assertEqual(len(manifest), 10)
+        self.assertEqual(len(manifest), 149)
 
         # dont include consumer or AnonymousUser users
         self.assertEqual(
             len(list(filter(lambda e: e["model"] == "auth.user", manifest))),
-            1,
+            2,
         )
 
         self.assertEqual(
@@ -218,6 +230,9 @@ class TestExportImport(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
             Correspondent.objects.all().delete()
             DocumentType.objects.all().delete()
             Tag.objects.all().delete()
+            Permission.objects.all().delete()
+            UserObjectPermission.objects.all().delete()
+            GroupObjectPermission.objects.all().delete()
             self.assertEqual(Document.objects.count(), 0)
 
             call_command("document_importer", "--no-progress-bar", self.target)
@@ -230,6 +245,9 @@ class TestExportImport(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
             self.assertEqual(Document.objects.get(id=self.d2.id).title, "wow2")
             self.assertEqual(Document.objects.get(id=self.d3.id).title, "wow2")
             self.assertEqual(Document.objects.get(id=self.d4.id).title, "wow_dec")
+            self.assertEqual(GroupObjectPermission.objects.count(), 1)
+            self.assertEqual(UserObjectPermission.objects.count(), 1)
+            self.assertEqual(Permission.objects.count(), 108)
             messages = check_sanity()
             # everything is alright after the test
             self.assertEqual(len(messages), 0)
@@ -641,3 +659,47 @@ class TestExportImport(DirectoriesMixin, FileSystemAssertsMixin, TestCase):
             self.assertEqual(Document.objects.count(), 0)
             call_command("document_importer", "--no-progress-bar", self.target)
             self.assertEqual(Document.objects.count(), 4)
+
+    def test_import_db_transaction_failed(self):
+        """
+        GIVEN:
+            - Import from manifest started
+        WHEN:
+            - Import of database fails
+        THEN:
+            - ContentType & Permission objects are not deleted, db transaction rolled back
+        """
+
+        shutil.rmtree(os.path.join(self.dirs.media_dir, "documents"))
+        shutil.copytree(
+            os.path.join(os.path.dirname(__file__), "samples", "documents"),
+            os.path.join(self.dirs.media_dir, "documents"),
+        )
+
+        self.assertEqual(ContentType.objects.count(), 27)
+        self.assertEqual(Permission.objects.count(), 108)
+
+        manifest = self._do_export()
+
+        with paperless_environment():
+            self.assertEqual(
+                len(list(filter(lambda e: e["model"] == "auth.permission", manifest))),
+                108,
+            )
+            # add 1 more to db to show objects are not re-created by import
+            Permission.objects.create(
+                name="test",
+                codename="test_perm",
+                content_type_id=1,
+            )
+            self.assertEqual(Permission.objects.count(), 109)
+
+            # will cause an import error
+            self.user.delete()
+            self.user = User.objects.create(username="temp_admin")
+
+            with self.assertRaises(IntegrityError):
+                call_command("document_importer", "--no-progress-bar", self.target)
+
+            self.assertEqual(ContentType.objects.count(), 27)
+            self.assertEqual(Permission.objects.count(), 109)