mirror of
				https://github.com/paperless-ngx/paperless-ngx.git
				synced 2025-10-26 08:12:34 -04:00 
			
		
		
		
	Adds a layer to translate between differing formats of socket based Redis URLs
This commit is contained in:
		
							parent
							
								
									2e8706f4e2
								
							
						
					
					
						commit
						01d070b882
					
				| @ -8,6 +8,7 @@ import tempfile | ||||
| from typing import Final | ||||
| from typing import Optional | ||||
| from typing import Set | ||||
| from typing import Tuple | ||||
| from urllib.parse import urlparse | ||||
| 
 | ||||
| from celery.schedules import crontab | ||||
| @ -65,6 +66,34 @@ def __get_path(key: str, default: str) -> str: | ||||
|     return os.path.abspath(os.path.normpath(os.environ.get(key, default))) | ||||
| 
 | ||||
| 
 | ||||
| def _parse_redis_url(env_redis: Optional[str]) -> Tuple[str]: | ||||
|     """ | ||||
|     Gets the Redis information from the environment or a default and handles | ||||
|     converting from incompatible django_channels and celery formats. | ||||
| 
 | ||||
|     Returns a tuple of (celery_url, channels_url) | ||||
|     """ | ||||
| 
 | ||||
|     # Not set, return a compatible default | ||||
|     if env_redis is None: | ||||
|         return ("redis://localhost:6379", "redis://localhost:6379") | ||||
| 
 | ||||
|     _, path = env_redis.split(":") | ||||
| 
 | ||||
|     if "unix" in env_redis.lower(): | ||||
|         # channels_redis socket format, looks like: | ||||
|         # "unix:///path/to/redis.sock" | ||||
|         return (f"redis+socket:{path}", env_redis) | ||||
| 
 | ||||
|     elif "+socket" in env_redis.lower(): | ||||
|         # celery socket style, looks like: | ||||
|         # "redis+socket:///path/to/redis.sock" | ||||
|         return (env_redis, f"unix:{path}") | ||||
| 
 | ||||
|     # Not a socket | ||||
|     return (env_redis, env_redis) | ||||
| 
 | ||||
| 
 | ||||
| # NEVER RUN WITH DEBUG IN PRODUCTION. | ||||
| DEBUG = __get_boolean("PAPERLESS_DEBUG", "NO") | ||||
| 
 | ||||
| @ -182,7 +211,9 @@ ASGI_APPLICATION = "paperless.asgi.application" | ||||
| STATIC_URL = os.getenv("PAPERLESS_STATIC_URL", BASE_URL + "static/") | ||||
| WHITENOISE_STATIC_PREFIX = "/static/" | ||||
| 
 | ||||
| _REDIS_URL = os.getenv("PAPERLESS_REDIS", "redis://localhost:6379") | ||||
| _CELERY_REDIS_URL, _CHANNELS_REDIS_URL = _parse_redis_url( | ||||
|     os.getenv("PAPERLESS_REDIS", None), | ||||
| ) | ||||
| 
 | ||||
| # TODO: what is this used for? | ||||
| TEMPLATES = [ | ||||
| @ -205,7 +236,7 @@ CHANNEL_LAYERS = { | ||||
|     "default": { | ||||
|         "BACKEND": "channels_redis.core.RedisChannelLayer", | ||||
|         "CONFIG": { | ||||
|             "hosts": [_REDIS_URL], | ||||
|             "hosts": [_CHANNELS_REDIS_URL], | ||||
|             "capacity": 2000,  # default 100 | ||||
|             "expiry": 15,  # default 60 | ||||
|         }, | ||||
| @ -468,7 +499,7 @@ TASK_WORKERS = __get_int("PAPERLESS_TASK_WORKERS", 1) | ||||
| 
 | ||||
| WORKER_TIMEOUT: Final[int] = __get_int("PAPERLESS_WORKER_TIMEOUT", 1800) | ||||
| 
 | ||||
| CELERY_BROKER_URL = _REDIS_URL | ||||
| CELERY_BROKER_URL = _CELERY_REDIS_URL | ||||
| CELERY_TIMEZONE = TIME_ZONE | ||||
| 
 | ||||
| CELERY_WORKER_HIJACK_ROOT_LOGGER = False | ||||
| @ -513,7 +544,7 @@ CELERY_BEAT_SCHEDULE_FILENAME = os.path.join(DATA_DIR, "celerybeat-schedule.db") | ||||
| CACHES = { | ||||
|     "default": { | ||||
|         "BACKEND": "django.core.cache.backends.redis.RedisCache", | ||||
|         "LOCATION": _REDIS_URL, | ||||
|         "LOCATION": _CHANNELS_REDIS_URL, | ||||
|     }, | ||||
| } | ||||
| 
 | ||||
|  | ||||
| @ -3,6 +3,7 @@ from unittest import mock | ||||
| from unittest import TestCase | ||||
| 
 | ||||
| from paperless.settings import _parse_ignore_dates | ||||
| from paperless.settings import _parse_redis_url | ||||
| from paperless.settings import default_threads_per_worker | ||||
| 
 | ||||
| 
 | ||||
| @ -82,3 +83,35 @@ class TestIgnoreDateParsing(TestCase): | ||||
|                 self.assertGreaterEqual(default_threads, 1) | ||||
| 
 | ||||
|                 self.assertLessEqual(default_workers * default_threads, i) | ||||
| 
 | ||||
|     def test_redis_socket_parsing(self): | ||||
|         """ | ||||
|         GIVEN: | ||||
|             - Various Redis connection URI formats | ||||
|         WHEN: | ||||
|             - The URI is parsed | ||||
|         THEN: | ||||
|             - Socket based URIs are translated | ||||
|             - Non-socket URIs are unchanged | ||||
|             - None provided uses default | ||||
|         """ | ||||
| 
 | ||||
|         for input, expected in [ | ||||
|             (None, ("redis://localhost:6379", "redis://localhost:6379")), | ||||
|             ( | ||||
|                 "redis+socket:///run/redis/redis.sock", | ||||
|                 ( | ||||
|                     "redis+socket:///run/redis/redis.sock", | ||||
|                     "unix:///run/redis/redis.sock", | ||||
|                 ), | ||||
|             ), | ||||
|             ( | ||||
|                 "unix:///run/redis/redis.sock", | ||||
|                 ( | ||||
|                     "redis+socket:///run/redis/redis.sock", | ||||
|                     "unix:///run/redis/redis.sock", | ||||
|                 ), | ||||
|             ), | ||||
|         ]: | ||||
|             result = _parse_redis_url(input) | ||||
|             self.assertTupleEqual(expected, result) | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user