mirror of
				https://github.com/immich-app/immich.git
				synced 2025-10-24 23:39:03 -04:00 
			
		
		
		
	fix(ml): better model unloading (#3340)
* restart process on inactivity
* formatting
* always update `last_called`
* load models sequentially
* renamed variable, updated docs
* formatting
* made poll env name consistent with model ttl env
This commit is contained in:
		
							parent
							
								
									98f87c6548
								
							
						
					
					
						commit
						a6af4892e3
					
				| @ -188,19 +188,18 @@ Typesense URL example JSON before encoding: | ||||
| 
 | ||||
| | Variable                                         | Description                                                       |       Default       | Services         | | ||||
| | :----------------------------------------------- | :---------------------------------------------------------------- | :-----------------: | :--------------- | | ||||
| | `MACHINE_LEARNING_MODEL_TTL`<sup>\*1</sup>       | Inactivity time (s) before a model is unloaded (disabled if <= 0) |         `0`         | machine learning | | ||||
| | `MACHINE_LEARNING_MODEL_TTL`                     | Inactivity time (s) before a model is unloaded (disabled if <= 0) |        `300`        | machine learning | | ||||
| | `MACHINE_LEARNING_MODEL_TTL_POLL_S`              | Interval (s) between checks for the model TTL (disabled if <= 0)  |        `10`         | machine learning | | ||||
| | `MACHINE_LEARNING_CACHE_FOLDER`                  | Directory where models are downloaded                             |      `/cache`       | machine learning | | ||||
| | `MACHINE_LEARNING_REQUEST_THREADS`<sup>\*2</sup> | Thread count of the request thread pool (disabled if <= 0)        | number of CPU cores | machine learning | | ||||
| | `MACHINE_LEARNING_REQUEST_THREADS`<sup>\*1</sup> | Thread count of the request thread pool (disabled if <= 0)        | number of CPU cores | machine learning | | ||||
| | `MACHINE_LEARNING_MODEL_INTER_OP_THREADS`        | Number of parallel model operations                               |         `1`         | machine learning | | ||||
| | `MACHINE_LEARNING_MODEL_INTRA_OP_THREADS`        | Number of threads for each model operation                        |         `2`         | machine learning | | ||||
| | `MACHINE_LEARNING_WORKERS`<sup>\*3</sup>         | Number of worker processes to spawn                               |         `1`         | machine learning | | ||||
| | `MACHINE_LEARNING_WORKERS`<sup>\*2</sup>         | Number of worker processes to spawn                               |         `1`         | machine learning | | ||||
| | `MACHINE_LEARNING_WORKER_TIMEOUT`                | Maximum time (s) of unresponsiveness before a worker is killed    |        `120`        | machine learning | | ||||
| 
 | ||||
| \*1: This is an experimental feature. It may result in increased memory use over time when loading models repeatedly. | ||||
| \*1: It is recommended to begin with this parameter when changing the concurrency levels of the machine learning service and then tune the other ones. | ||||
| 
 | ||||
| \*2: It is recommended to begin with this parameter when changing the concurrency levels of the machine learning service and then tune the other ones. | ||||
| 
 | ||||
| \*3: Since each process duplicates models in memory, changing this is not recommended unless you have abundant memory to go around. | ||||
| \*2: Since each process duplicates models in memory, changing this is not recommended unless you have abundant memory to go around. | ||||
| 
 | ||||
| :::info | ||||
| 
 | ||||
|  | ||||
| @ -13,7 +13,8 @@ from .schemas import ModelType | ||||
| 
 | ||||
| class Settings(BaseSettings): | ||||
|     cache_folder: str = "/cache" | ||||
|     model_ttl: int = 0 | ||||
|     model_ttl: int = 300 | ||||
|     model_ttl_poll_s: int = 10 | ||||
|     host: str = "0.0.0.0" | ||||
|     port: int = 3003 | ||||
|     workers: int = 1 | ||||
|  | ||||
| @ -1,5 +1,9 @@ | ||||
| import asyncio | ||||
| import gc | ||||
| import os | ||||
| import sys | ||||
| import threading | ||||
| import time | ||||
| from concurrent.futures import ThreadPoolExecutor | ||||
| from typing import Any | ||||
| from zipfile import BadZipFile | ||||
| @ -34,7 +38,10 @@ def init_state() -> None: | ||||
|     ) | ||||
|     # asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code | ||||
|     app.state.thread_pool = ThreadPoolExecutor(settings.request_threads) if settings.request_threads > 0 else None | ||||
|     app.state.locks = {model_type: threading.Lock() for model_type in ModelType} | ||||
|     app.state.lock = threading.Lock() | ||||
|     app.state.last_called = None | ||||
|     if settings.model_ttl > 0 and settings.model_ttl_poll_s > 0: | ||||
|         asyncio.ensure_future(idle_shutdown_task()) | ||||
|     log.info(f"Initialized request thread pool with {settings.request_threads} threads.") | ||||
| 
 | ||||
| 
 | ||||
| @ -79,9 +86,9 @@ async def predict( | ||||
| 
 | ||||
| 
 | ||||
async def run(model: InferenceModel, inputs: Any) -> Any:
    """Run ``model.predict(inputs)`` and record the time of the call.

    The timestamp is stored in ``app.state.last_called`` before the
    prediction starts; the idle-shutdown check compares against that value.

    Args:
        model: The model whose ``predict`` method is invoked.
        inputs: Passed through unchanged to ``model.predict``.

    Returns:
        Whatever ``model.predict(inputs)`` returns.
    """
    app.state.last_called = time.time()

    pool = app.state.thread_pool
    if pool is not None:
        # Hand the blocking call to the request thread pool and await it.
        event_loop = asyncio.get_running_loop()
        return await event_loop.run_in_executor(pool, model.predict, inputs)

    # No thread pool configured: run the prediction inline.
    return model.predict(inputs)
| 
 | ||||
| 
 | ||||
| @ -90,7 +97,7 @@ async def load(model: InferenceModel) -> InferenceModel: | ||||
|         return model | ||||
| 
 | ||||
|     def _load() -> None: | ||||
|         with app.state.locks[model.model_type]: | ||||
|         with app.state.lock: | ||||
|             model.load() | ||||
| 
 | ||||
|     loop = asyncio.get_running_loop() | ||||
| @ -113,3 +120,27 @@ async def load(model: InferenceModel) -> InferenceModel: | ||||
|         else: | ||||
|             await loop.run_in_executor(app.state.thread_pool, _load) | ||||
|         return model | ||||
| 
 | ||||
| 
 | ||||
async def idle_shutdown_task() -> None:
    """Poll for inactivity and stop the event loop once the model TTL elapses.

    Every ``settings.model_ttl_poll_s`` seconds the current time is compared
    with ``app.state.last_called``. Nothing happens while ``last_called`` is
    ``None``. Once the idle time exceeds ``settings.model_ttl`` seconds, the
    task:

    1. requests cancellation of every other task on the running loop,
    2. closes stdout/stderr and points both at the null device,
    3. clears the model cache, runs a garbage collection pass and stops the
       event loop.

    The coroutine loops forever and never returns on its own.
    """
    while True:
        log.debug("Checking for inactivity...")
        last_called = app.state.last_called
        if last_called is not None and time.time() - last_called > settings.model_ttl:
            log.info("Shutting down due to inactivity.")
            loop = asyncio.get_running_loop()
            this_task = asyncio.current_task()
            for task in asyncio.all_tasks(loop):
                if task is not this_task:
                    # Task.cancel() only *requests* cancellation and returns a
                    # bool; it never raises CancelledError, so no guard is
                    # needed around it.
                    task.cancel()
            # NOTE(review): presumably silences output produced while the
            # cancelled tasks unwind during shutdown -- confirm intent.
            sys.stderr.close()
            sys.stdout.close()
            sys.stdout = sys.stderr = open(os.devnull, "w")
            try:
                await app.state.model_cache.cache.clear()
            except asyncio.CancelledError:
                # Deliberate best-effort: if this task is cancelled while the
                # cache is being cleared, skip the stop and poll again.
                pass
            else:
                gc.collect()
                loop.stop()
        await asyncio.sleep(settings.model_ttl_poll_s)
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user