from typing import Final
from typing import Optional
from typing import Set
+from typing import Tuple
from urllib.parse import urlparse
from celery.schedules import crontab
return os.path.abspath(os.path.normpath(os.environ.get(key, default)))
+def _parse_redis_url(env_redis: Optional[str]) -> Tuple[str]:
+ """
+ Gets the Redis information from the environment or a default and handles
+ converting from incompatible django_channels and celery formats.
+
+ Returns a tuple of (celery_url, channels_url)
+ """
+
+ # Not set, return a compatible default
+ if env_redis is None:
+ return ("redis://localhost:6379", "redis://localhost:6379")
+
+ _, path = env_redis.split(":")
+
+ if "unix" in env_redis.lower():
+ # channels_redis socket format, looks like:
+ # "unix:///path/to/redis.sock"
+ return (f"redis+socket:{path}", env_redis)
+
+ elif "+socket" in env_redis.lower():
+ # celery socket style, looks like:
+ # "redis+socket:///path/to/redis.sock"
+ return (env_redis, f"unix:{path}")
+
+ # Not a socket
+ return (env_redis, env_redis)
+
+
# NEVER RUN WITH DEBUG IN PRODUCTION.
DEBUG = __get_boolean("PAPERLESS_DEBUG", "NO")
STATIC_URL = os.getenv("PAPERLESS_STATIC_URL", BASE_URL + "static/")
WHITENOISE_STATIC_PREFIX = "/static/"
-_REDIS_URL = os.getenv("PAPERLESS_REDIS", "redis://localhost:6379")
+_CELERY_REDIS_URL, _CHANNELS_REDIS_URL = _parse_redis_url(
+ os.getenv("PAPERLESS_REDIS", None),
+)
# TODO: what is this used for?
TEMPLATES = [
"default": {
"BACKEND": "channels_redis.core.RedisChannelLayer",
"CONFIG": {
- "hosts": [_REDIS_URL],
+ "hosts": [_CHANNELS_REDIS_URL],
"capacity": 2000, # default 100
"expiry": 15, # default 60
},
WORKER_TIMEOUT: Final[int] = __get_int("PAPERLESS_WORKER_TIMEOUT", 1800)
-CELERY_BROKER_URL = _REDIS_URL
+CELERY_BROKER_URL = _CELERY_REDIS_URL
CELERY_TIMEZONE = TIME_ZONE
CELERY_WORKER_HIJACK_ROOT_LOGGER = False
CACHES = {
"default": {
"BACKEND": "django.core.cache.backends.redis.RedisCache",
- "LOCATION": _REDIS_URL,
+ "LOCATION": _CHANNELS_REDIS_URL,
},
}
from unittest import TestCase
from paperless.settings import _parse_ignore_dates
+from paperless.settings import _parse_redis_url
from paperless.settings import default_threads_per_worker
self.assertGreaterEqual(default_threads, 1)
self.assertLessEqual(default_workers * default_threads, i)
+
+ def test_redis_socket_parsing(self):
+ """
+ GIVEN:
+ - Various Redis connection URI formats
+ WHEN:
+ - The URI is parsed
+ THEN:
+ - Socket based URIs are translated
+ - Non-socket URIs are unchanged
+ - None provided uses default
+ """
+
+ for input, expected in [
+ (None, ("redis://localhost:6379", "redis://localhost:6379")),
+ (
+ "redis+socket:///run/redis/redis.sock",
+ (
+ "redis+socket:///run/redis/redis.sock",
+ "unix:///run/redis/redis.sock",
+ ),
+ ),
+ (
+ "unix:///run/redis/redis.sock",
+ (
+ "redis+socket:///run/redis/redis.sock",
+ "unix:///run/redis/redis.sock",
+ ),
+ ),
+ ]:
+ result = _parse_redis_url(input)
+ self.assertTupleEqual(expected, result)