]> git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Enhancement: support default groups for regular and social account signup (#9039)
authorshamoon <4887959+shamoon@users.noreply.github.com>
Mon, 24 Feb 2025 17:23:20 +0000 (09:23 -0800)
committerGitHub <noreply@github.com>
Mon, 24 Feb 2025 17:23:20 +0000 (09:23 -0800)
docs/configuration.md
src/paperless/adapter.py
src/paperless/apps.py
src/paperless/settings.py
src/paperless/signals.py
src/paperless/tests/test_adapter.py
src/paperless/tests/test_signals.py

index 441d4610522fd475651eb64fd05f1382ccb52c90..391b97d137395e4eecc8160f6e652267f7a77272 100644 (file)
@@ -557,6 +557,20 @@ This is for use with self-signed certificates against local IMAP servers.
     Settings this value has security implications for the security of your email.
     Understand what it does and be sure you need to before setting.
 
+### Authentication & SSO {#authentication}
+
+#### [`PAPERLESS_ACCOUNT_ALLOW_SIGNUPS=<bool>`](#PAPERLESS_ACCOUNT_ALLOW_SIGNUPS) {#PAPERLESS_ACCOUNT_ALLOW_SIGNUPS}
+
+: Allow users to signup for a new Paperless-ngx account.
+
+    Defaults to False
+
+#### [`PAPERLESS_ACCOUNT_DEFAULT_GROUPS=<comma-separated-list>`](#PAPERLESS_ACCOUNT_DEFAULT_GROUPS) {#PAPERLESS_ACCOUNT_DEFAULT_GROUPS}
+
+: A list of group names that users will be added to when they sign up for a new account. Groups listed here must already exist.
+
+    Defaults to None
+
 #### [`PAPERLESS_SOCIALACCOUNT_PROVIDERS=<json>`](#PAPERLESS_SOCIALACCOUNT_PROVIDERS) {#PAPERLESS_SOCIALACCOUNT_PROVIDERS}
 
 : This variable is used to setup login and signup via social account providers which are compatible with django-allauth.
@@ -580,12 +594,25 @@ system. See the corresponding
 
     Defaults to True
 
-#### [`PAPERLESS_ACCOUNT_ALLOW_SIGNUPS=<bool>`](#PAPERLESS_ACCOUNT_ALLOW_SIGNUPS) {#PAPERLESS_ACCOUNT_ALLOW_SIGNUPS}
+#### [`PAPERLESS_SOCIAL_ACCOUNT_SYNC_GROUPS=<bool>`](#PAPERLESS_SOCIAL_ACCOUNT_SYNC_GROUPS) {#PAPERLESS_SOCIAL_ACCOUNT_SYNC_GROUPS}
 
-: Allow users to signup for a new Paperless-ngx account.
+: Sync groups from the third party authentication system (e.g. OIDC) to Paperless-ngx. When enabled, users will be added or removed from groups based on their group membership in the third party authentication system. Groups must already exist in Paperless-ngx and have the same name as in the third party authentication system. Groups are updated upon logging in via the third party authentication system, see the corresponding [django-allauth documentation](https://docs.allauth.org/en/dev/socialaccount/signals.html).
+
+: In order to pass groups from the authentication system you will need to update your [PAPERLESS_SOCIALACCOUNT_PROVIDERS](#PAPERLESS_SOCIALACCOUNT_PROVIDERS) setting by adding a top-level "SCOPES" setting which includes "groups", e.g.:
+
+    ```json
+    {"openid_connect":{"SCOPE": ["openid","profile","email","groups"]...
+    ```
 
     Defaults to False
 
+#### [`PAPERLESS_SOCIAL_ACCOUNT_DEFAULT_GROUPS=<comma-separated-list>`](#PAPERLESS_SOCIAL_ACCOUNT_DEFAULT_GROUPS) {#PAPERLESS_SOCIAL_ACCOUNT_DEFAULT_GROUPS}
+
+: A list of group names that users who signup via social accounts will be added to upon signup. Groups listed here must already exist.
+If both the [PAPERLESS_ACCOUNT_DEFAULT_GROUPS](#PAPERLESS_ACCOUNT_DEFAULT_GROUPS) setting and this setting are used, the user will be added to both sets of groups.
+
+    Defaults to None
+
 #### [`PAPERLESS_ACCOUNT_DEFAULT_HTTP_PROTOCOL=<string>`](#PAPERLESS_ACCOUNT_DEFAULT_HTTP_PROTOCOL) {#PAPERLESS_ACCOUNT_DEFAULT_HTTP_PROTOCOL}
 
 : The protocol used when generating URLs, e.g. login callback URLs. See the corresponding
index add2bf45d028a71aa2379ff6b6b2e22d17f7aca6..e29acb2ff90a22a83fd7323bdf0b793ed48c1937 100644 (file)
@@ -1,12 +1,17 @@
+import logging
 from urllib.parse import quote
 
 from allauth.account.adapter import DefaultAccountAdapter
 from allauth.core import context
 from allauth.socialaccount.adapter import DefaultSocialAccountAdapter
 from django.conf import settings
+from django.contrib.auth.models import Group
+from django.contrib.auth.models import User
 from django.forms import ValidationError
 from django.urls import reverse
 
+logger = logging.getLogger("paperless.auth")
+
 
 class CustomAccountAdapter(DefaultAccountAdapter):
     def is_open_for_signup(self, request):
@@ -61,6 +66,20 @@ class CustomAccountAdapter(DefaultAccountAdapter):
             path = path.replace("UID-KEY", quote(key))
             return settings.PAPERLESS_URL + path
 
+    def save_user(self, request, user, form, commit=True):  # noqa: FBT002
+        """
+        Save the user instance. Default groups are assigned to the user, if
+        specified in the settings.
+        """
+        user: User = super().save_user(request, user, form, commit)
+        group_names: list[str] = settings.ACCOUNT_DEFAULT_GROUPS
+        if len(group_names) > 0:
+            groups = Group.objects.filter(name__in=group_names)
+            logger.debug(f"Adding default groups to user `{user}`: {group_names}")
+            user.groups.add(*groups)
+            user.save()
+        return user
+
 
 class CustomSocialAccountAdapter(DefaultSocialAccountAdapter):
     def is_open_for_signup(self, request, sociallogin):
@@ -80,10 +99,19 @@ class CustomSocialAccountAdapter(DefaultSocialAccountAdapter):
         url = reverse("base")
         return url
 
-    def populate_user(self, request, sociallogin, data):
+    def save_user(self, request, sociallogin, form=None):
         """
-        Populate the user with data from the social account. Stub is kept in case
-        global default permissions are implemented in the future.
+        Save the user instance. Default groups are assigned to the user, if
+        specified in the settings.
         """
-        # TODO: If default global permissions are implemented, should also be here
-        return super().populate_user(request, sociallogin, data)  # pragma: no cover
+        # save_user also calls account_adapter save_user which would set ACCOUNT_DEFAULT_GROUPS
+        user: User = super().save_user(request, sociallogin, form)
+        group_names: list[str] = settings.SOCIAL_ACCOUNT_DEFAULT_GROUPS
+        if len(group_names) > 0:
+            groups = Group.objects.filter(name__in=group_names)
+            logger.debug(
+                f"Adding default social groups to user `{user}`: {group_names}",
+            )
+            user.groups.add(*groups)
+            user.save()
+        return user
index b4147a2e35f65e0c0da9b61653865e92b491ed35..819d8d5ff02d4986647ba6269af5c4ab286eb280 100644 (file)
@@ -2,6 +2,7 @@ from django.apps import AppConfig
 from django.utils.translation import gettext_lazy as _
 
 from paperless.signals import handle_failed_login
+from paperless.signals import handle_social_account_updated
 
 
 class PaperlessConfig(AppConfig):
@@ -13,4 +14,9 @@ class PaperlessConfig(AppConfig):
         from django.contrib.auth.signals import user_login_failed
 
         user_login_failed.connect(handle_failed_login)
+
+        from allauth.socialaccount.signals import social_account_updated
+
+        social_account_updated.connect(handle_social_account_updated)
+
         AppConfig.ready(self)
index 8072f694e44a1661d7d6fccce02783cc56f38817..0c8c71ab9ed1219c0b5ffaf53200c26c64d94065 100644 (file)
@@ -480,6 +480,7 @@ ACCOUNT_DEFAULT_HTTP_PROTOCOL = os.getenv(
 
 ACCOUNT_ADAPTER = "paperless.adapter.CustomAccountAdapter"
 ACCOUNT_ALLOW_SIGNUPS = __get_boolean("PAPERLESS_ACCOUNT_ALLOW_SIGNUPS")
+ACCOUNT_DEFAULT_GROUPS = __get_list("PAPERLESS_ACCOUNT_DEFAULT_GROUPS")
 
 SOCIALACCOUNT_ADAPTER = "paperless.adapter.CustomSocialAccountAdapter"
 SOCIALACCOUNT_ALLOW_SIGNUPS = __get_boolean(
@@ -490,6 +491,8 @@ SOCIALACCOUNT_AUTO_SIGNUP = __get_boolean("PAPERLESS_SOCIAL_AUTO_SIGNUP")
 SOCIALACCOUNT_PROVIDERS = json.loads(
     os.getenv("PAPERLESS_SOCIALACCOUNT_PROVIDERS", "{}"),
 )
+SOCIAL_ACCOUNT_DEFAULT_GROUPS = __get_list("PAPERLESS_SOCIAL_ACCOUNT_DEFAULT_GROUPS")
+SOCIAL_ACCOUNT_SYNC_GROUPS = __get_boolean("PAPERLESS_SOCIAL_ACCOUNT_SYNC_GROUPS")
 
 MFA_TOTP_ISSUER = "Paperless-ngx"
 
index fa0298685d65764a6269113ed66b459db2226d04..a173ccc2e2987dc7718a0203b8e6c42987e4fd4f 100644 (file)
@@ -30,3 +30,21 @@ def handle_failed_login(sender, credentials, request, **kwargs):
             log_output += f" from private IP `{client_ip}`."
 
     logger.info(log_output)
+
+
+def handle_social_account_updated(sender, request, sociallogin, **kwargs):
+    """
+    Handle the social account update signal.
+    """
+    from django.contrib.auth.models import Group
+
+    social_account_groups = sociallogin.account.extra_data.get(
+        "groups",
+        [],
+    )  # None if not found
+    if settings.SOCIAL_ACCOUNT_SYNC_GROUPS and social_account_groups is not None:
+        groups = Group.objects.filter(name__in=social_account_groups)
+        logger.debug(
+            f"Syncing groups for user `{sociallogin.user}`: {social_account_groups}",
+        )
+        sociallogin.user.groups.set(groups, clear=True)
index 5659a279a9ec03d203c41ec0728c2f6a8f7fc62f..be4ad3d90814b72d7f9d356f92451fdcf8053807 100644 (file)
@@ -4,6 +4,8 @@ from allauth.account.adapter import get_adapter
 from allauth.core import context
 from allauth.socialaccount.adapter import get_adapter as get_social_adapter
 from django.conf import settings
+from django.contrib.auth.models import Group
+from django.contrib.auth.models import User
 from django.forms import ValidationError
 from django.http import HttpRequest
 from django.test import TestCase
@@ -81,6 +83,24 @@ class TestCustomAccountAdapter(TestCase):
                     expected_url,
                 )
 
+    @override_settings(ACCOUNT_DEFAULT_GROUPS=["group1", "group2"])
+    def test_save_user_adds_groups(self):
+        Group.objects.create(name="group1")
+        user = User.objects.create_user("testuser")
+        adapter = get_adapter()
+        form = mock.Mock(
+            cleaned_data={
+                "username": "testuser",
+                "email": "user@example.com",
+            },
+        )
+
+        user = adapter.save_user(HttpRequest(), user, form, commit=True)
+
+        self.assertEqual(user.groups.count(), 1)
+        self.assertTrue(user.groups.filter(name="group1").exists())
+        self.assertFalse(user.groups.filter(name="group2").exists())
+
 
 class TestCustomSocialAccountAdapter(TestCase):
     def test_is_open_for_signup(self):
@@ -105,3 +125,19 @@ class TestCustomSocialAccountAdapter(TestCase):
             adapter.get_connect_redirect_url(request, socialaccount),
             expected_url,
         )
+
+    @override_settings(SOCIAL_ACCOUNT_DEFAULT_GROUPS=["group1", "group2"])
+    def test_save_user_adds_groups(self):
+        Group.objects.create(name="group1")
+        adapter = get_social_adapter()
+        request = HttpRequest()
+        user = User.objects.create_user("testuser")
+        sociallogin = mock.Mock(
+            user=user,
+        )
+
+        user = adapter.save_user(request, sociallogin, None)
+
+        self.assertEqual(user.groups.count(), 1)
+        self.assertTrue(user.groups.filter(name="group1").exists())
+        self.assertFalse(user.groups.filter(name="group2").exists())
index dc425d6670043545f99c61f44cd983b2441225d5..0948ca575a46bb30f5a9c5ff55c03aa21f93b229 100644 (file)
@@ -1,7 +1,13 @@
+from unittest.mock import Mock
+
+from django.contrib.auth.models import Group
+from django.contrib.auth.models import User
 from django.http import HttpRequest
 from django.test import TestCase
+from django.test import override_settings
 
 from paperless.signals import handle_failed_login
+from paperless.signals import handle_social_account_updated
 
 
 class TestFailedLoginLogging(TestCase):
@@ -99,3 +105,88 @@ class TestFailedLoginLogging(TestCase):
                     "INFO:paperless.auth:Login failed for user `john lennon` from private IP `10.0.0.1`.",
                 ],
             )
+
+
+class TestSyncSocialLoginGroups(TestCase):
+    @override_settings(SOCIAL_ACCOUNT_SYNC_GROUPS=True)
+    def test_sync_enabled(self):
+        """
+        GIVEN:
+            - Enabled group syncing, a user, and a social login
+        WHEN:
+            - The social login is updated via signal after login
+        THEN:
+            - The user's groups are updated to match the social login's groups
+        """
+        group = Group.objects.create(name="group1")
+        user = User.objects.create_user(username="testuser")
+        sociallogin = Mock(
+            user=user,
+            account=Mock(
+                extra_data={
+                    "groups": ["group1"],
+                },
+            ),
+        )
+        handle_social_account_updated(
+            sender=None,
+            request=HttpRequest(),
+            sociallogin=sociallogin,
+        )
+        self.assertEqual(list(user.groups.all()), [group])
+
+    @override_settings(SOCIAL_ACCOUNT_SYNC_GROUPS=False)
+    def test_sync_disabled(self):
+        """
+        GIVEN:
+            - Disabled group syncing, a user, and a social login
+        WHEN:
+            - The social login is updated via signal after login
+        THEN:
+            - The user's groups are not updated
+        """
+        Group.objects.create(name="group1")
+        user = User.objects.create_user(username="testuser")
+        sociallogin = Mock(
+            user=user,
+            account=Mock(
+                extra_data={
+                    "groups": ["group1"],
+                },
+            ),
+        )
+        handle_social_account_updated(
+            sender=None,
+            request=HttpRequest(),
+            sociallogin=sociallogin,
+        )
+        self.assertEqual(list(user.groups.all()), [])
+
+    @override_settings(SOCIAL_ACCOUNT_SYNC_GROUPS=True)
+    def test_no_groups(self):
+        """
+        GIVEN:
+            - Enabled group syncing, a user, and a social login with no groups
+        WHEN:
+            - The social login is updated via signal after login
+        THEN:
+            - The user's groups are cleared to match the social login's groups
+        """
+        group = Group.objects.create(name="group1")
+        user = User.objects.create_user(username="testuser")
+        user.groups.add(group)
+        user.save()
+        sociallogin = Mock(
+            user=user,
+            account=Mock(
+                extra_data={
+                    "groups": [],
+                },
+            ),
+        )
+        handle_social_account_updated(
+            sender=None,
+            request=HttpRequest(),
+            sociallogin=sociallogin,
+        )
+        self.assertEqual(list(user.groups.all()), [])