]> git.ipfire.org Git - thirdparty/paperless-ngx.git/commitdiff
Adds a layer to translate between differing formats of socket based Redis URLs
authorTrenton H <797416+stumpylog@users.noreply.github.com>
Fri, 2 Dec 2022 17:34:59 +0000 (09:34 -0800)
committerTrenton H <797416+stumpylog@users.noreply.github.com>
Sat, 3 Dec 2022 16:39:32 +0000 (08:39 -0800)
src/paperless/settings.py
src/paperless/tests/test_settings.py

index 456e15745562eb9563fcb3831c301373143331fb..eef7344da86bc42f18131a059718c3ba3da3e695 100644 (file)
@@ -8,6 +8,7 @@ import tempfile
 from typing import Final
 from typing import Optional
 from typing import Set
+from typing import Tuple
 from urllib.parse import urlparse
 
 from celery.schedules import crontab
@@ -65,6 +66,34 @@ def __get_path(key: str, default: str) -> str:
     return os.path.abspath(os.path.normpath(os.environ.get(key, default)))
 
 
+def _parse_redis_url(env_redis: Optional[str]) -> Tuple[str]:
+    """
+    Gets the Redis information from the environment or a default and handles
+    converting from incompatible django_channels and celery formats.
+
+    Returns a tuple of (celery_url, channels_url)
+    """
+
+    # Not set, return a compatible default
+    if env_redis is None:
+        return ("redis://localhost:6379", "redis://localhost:6379")
+
+    _, path = env_redis.split(":")
+
+    if "unix" in env_redis.lower():
+        # channels_redis socket format, looks like:
+        # "unix:///path/to/redis.sock"
+        return (f"redis+socket:{path}", env_redis)
+
+    elif "+socket" in env_redis.lower():
+        # celery socket style, looks like:
+        # "redis+socket:///path/to/redis.sock"
+        return (env_redis, f"unix:{path}")
+
+    # Not a socket
+    return (env_redis, env_redis)
+
+
 # NEVER RUN WITH DEBUG IN PRODUCTION.
 DEBUG = __get_boolean("PAPERLESS_DEBUG", "NO")
 
@@ -182,7 +211,9 @@ ASGI_APPLICATION = "paperless.asgi.application"
 STATIC_URL = os.getenv("PAPERLESS_STATIC_URL", BASE_URL + "static/")
 WHITENOISE_STATIC_PREFIX = "/static/"
 
-_REDIS_URL = os.getenv("PAPERLESS_REDIS", "redis://localhost:6379")
+_CELERY_REDIS_URL, _CHANNELS_REDIS_URL = _parse_redis_url(
+    os.getenv("PAPERLESS_REDIS", None),
+)
 
 # TODO: what is this used for?
 TEMPLATES = [
@@ -205,7 +236,7 @@ CHANNEL_LAYERS = {
     "default": {
         "BACKEND": "channels_redis.core.RedisChannelLayer",
         "CONFIG": {
-            "hosts": [_REDIS_URL],
+            "hosts": [_CHANNELS_REDIS_URL],
             "capacity": 2000,  # default 100
             "expiry": 15,  # default 60
         },
@@ -468,7 +499,7 @@ TASK_WORKERS = __get_int("PAPERLESS_TASK_WORKERS", 1)
 
 WORKER_TIMEOUT: Final[int] = __get_int("PAPERLESS_WORKER_TIMEOUT", 1800)
 
-CELERY_BROKER_URL = _REDIS_URL
+CELERY_BROKER_URL = _CELERY_REDIS_URL
 CELERY_TIMEZONE = TIME_ZONE
 
 CELERY_WORKER_HIJACK_ROOT_LOGGER = False
@@ -513,7 +544,7 @@ CELERY_BEAT_SCHEDULE_FILENAME = os.path.join(DATA_DIR, "celerybeat-schedule.db")
 CACHES = {
     "default": {
         "BACKEND": "django.core.cache.backends.redis.RedisCache",
-        "LOCATION": _REDIS_URL,
+        "LOCATION": _CHANNELS_REDIS_URL,
     },
 }
 
index fed4079e2c86381b978a9bcd668112e1d8df71c3..fa839299fad5ec97a57bd02c3a19eb9cb28113e6 100644 (file)
@@ -3,6 +3,7 @@ from unittest import mock
 from unittest import TestCase
 
 from paperless.settings import _parse_ignore_dates
+from paperless.settings import _parse_redis_url
 from paperless.settings import default_threads_per_worker
 
 
@@ -82,3 +83,35 @@ class TestIgnoreDateParsing(TestCase):
                 self.assertGreaterEqual(default_threads, 1)
 
                 self.assertLessEqual(default_workers * default_threads, i)
+
+    def test_redis_socket_parsing(self):
+        """
+        GIVEN:
+            - Various Redis connection URI formats
+        WHEN:
+            - The URI is parsed
+        THEN:
+            - Socket based URIs are translated
+            - Non-socket URIs are unchanged
+            - None provided uses default
+        """
+
+        for input, expected in [
+            (None, ("redis://localhost:6379", "redis://localhost:6379")),
+            (
+                "redis+socket:///run/redis/redis.sock",
+                (
+                    "redis+socket:///run/redis/redis.sock",
+                    "unix:///run/redis/redis.sock",
+                ),
+            ),
+            (
+                "unix:///run/redis/redis.sock",
+                (
+                    "redis+socket:///run/redis/redis.sock",
+                    "unix:///run/redis/redis.sock",
+                ),
+            ),
+        ]:
+            result = _parse_redis_url(input)
+            self.assertTupleEqual(expected, result)