import tqdm
from django.conf import settings
from django.contrib.auth.models import Group
+from django.contrib.auth.models import Permission
from django.contrib.auth.models import User
+from django.contrib.contenttypes.models import ContentType
from django.core import serializers
from django.core.management.base import BaseCommand
from django.core.management.base import CommandError
from django.db import transaction
from django.utils import timezone
from filelock import FileLock
+from guardian.models import GroupObjectPermission
+from guardian.models import UserObjectPermission
from documents.file_handling import delete_empty_directories
from documents.file_handling import generate_filename
serializers.serialize("json", UiSettings.objects.all()),
)
+ manifest += json.loads(
+ serializers.serialize("json", ContentType.objects.all()),
+ )
+
+ manifest += json.loads(
+ serializers.serialize("json", Permission.objects.all()),
+ )
+
+ manifest += json.loads(
+ serializers.serialize("json", UserObjectPermission.objects.all()),
+ )
+
+ manifest += json.loads(
+ serializers.serialize("json", GroupObjectPermission.objects.all()),
+ )
+
# 3. Export files from each document
for index, document_dict in tqdm.tqdm(
enumerate(document_manifest),
import tqdm
from django.conf import settings
+from django.contrib.auth.models import Permission
+from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import FieldDoesNotExist
from django.core.management import call_command
from django.core.management.base import BaseCommand
from django.core.management.base import CommandError
from django.core.serializers.base import DeserializationError
+from django.db import IntegrityError
+from django.db import transaction
from django.db.models.signals import m2m_changed
from django.db.models.signals import post_save
from filelock import FileLock
):
# Fill up the database with whatever is in the manifest
try:
- for manifest_path in manifest_paths:
- call_command("loaddata", manifest_path)
- except (FieldDoesNotExist, DeserializationError) as e:
+ with transaction.atomic():
+ for manifest_path in manifest_paths:
+ # delete these since pk can change, re-created from import
+ ContentType.objects.all().delete()
+ Permission.objects.all().delete()
+ call_command("loaddata", manifest_path)
+ except (FieldDoesNotExist, DeserializationError, IntegrityError) as e:
self.stdout.write(self.style.ERROR("Database import failed"))
if (
self.version is not None
from unittest import mock
from zipfile import ZipFile
+from django.contrib.auth.models import Group
+from django.contrib.auth.models import Permission
+from django.contrib.contenttypes.models import ContentType
from django.core.management import call_command
from django.core.management.base import CommandError
+from django.db import IntegrityError
from django.test import TestCase
from django.test import override_settings
from django.utils import timezone
+from guardian.models import GroupObjectPermission
+from guardian.models import UserObjectPermission
+from guardian.shortcuts import assign_perm
from documents.management.commands import document_exporter
from documents.models import Correspondent
self.addCleanup(shutil.rmtree, self.target)
self.user = User.objects.create(username="temp_admin")
+ self.user2 = User.objects.create(username="user2")
+ self.group1 = Group.objects.create(name="group1")
self.d1 = Document.objects.create(
content="Content",
user=self.user,
)
+ assign_perm("view_document", self.user2, self.d2)
+ assign_perm("view_document", self.group1, self.d3)
+
self.t1 = Tag.objects.create(name="t")
self.dt1 = DocumentType.objects.create(name="dt")
self.c1 = Correspondent.objects.create(name="c")
manifest = self._do_export(use_filename_format=use_filename_format)
- self.assertEqual(len(manifest), 10)
+ self.assertEqual(len(manifest), 149)
# dont include consumer or AnonymousUser users
self.assertEqual(
len(list(filter(lambda e: e["model"] == "auth.user", manifest))),
- 1,
+ 2,
)
self.assertEqual(
Correspondent.objects.all().delete()
DocumentType.objects.all().delete()
Tag.objects.all().delete()
+ Permission.objects.all().delete()
+ UserObjectPermission.objects.all().delete()
+ GroupObjectPermission.objects.all().delete()
self.assertEqual(Document.objects.count(), 0)
call_command("document_importer", "--no-progress-bar", self.target)
self.assertEqual(Document.objects.get(id=self.d2.id).title, "wow2")
self.assertEqual(Document.objects.get(id=self.d3.id).title, "wow2")
self.assertEqual(Document.objects.get(id=self.d4.id).title, "wow_dec")
+ self.assertEqual(GroupObjectPermission.objects.count(), 1)
+ self.assertEqual(UserObjectPermission.objects.count(), 1)
+ self.assertEqual(Permission.objects.count(), 108)
messages = check_sanity()
# everything is alright after the test
self.assertEqual(len(messages), 0)
self.assertEqual(Document.objects.count(), 0)
call_command("document_importer", "--no-progress-bar", self.target)
self.assertEqual(Document.objects.count(), 4)
+
+ def test_import_db_transaction_failed(self):
+ """
+ GIVEN:
+ - Import from manifest started
+ WHEN:
+ - Import of database fails
+ THEN:
+ - ContentType & Permission objects are not deleted, db transaction rolled back
+ """
+
+ shutil.rmtree(os.path.join(self.dirs.media_dir, "documents"))
+ shutil.copytree(
+ os.path.join(os.path.dirname(__file__), "samples", "documents"),
+ os.path.join(self.dirs.media_dir, "documents"),
+ )
+
+ self.assertEqual(ContentType.objects.count(), 27)
+ self.assertEqual(Permission.objects.count(), 108)
+
+ manifest = self._do_export()
+
+ with paperless_environment():
+ self.assertEqual(
+ len(list(filter(lambda e: e["model"] == "auth.permission", manifest))),
+ 108,
+ )
+ # add 1 more to db to show objects are not re-created by import
+ Permission.objects.create(
+ name="test",
+ codename="test_perm",
+ content_type_id=1,
+ )
+ self.assertEqual(Permission.objects.count(), 109)
+
+ # will cause an import error
+ self.user.delete()
+ self.user = User.objects.create(username="temp_admin")
+
+ with self.assertRaises(IntegrityError):
+ call_command("document_importer", "--no-progress-bar", self.target)
+
+ self.assertEqual(ContentType.objects.count(), 27)
+ self.assertEqual(Permission.objects.count(), 109)