mirror of
				https://github.com/immich-app/immich.git
				synced 2025-11-03 19:29:32 -05:00 
			
		
		
		
	* basic refactor and styling * removed batching * module entrypoint * removed unused imports * model superclass, model cache now in app state * fixed cache dir and enforced abstract method --------- Co-authored-by: Alex Tran <alex.tran1502@gmail.com>
		
			
				
	
	
		
			93 lines
		
	
	
		
			3.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			93 lines
		
	
	
		
			3.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import asyncio
 | 
						|
 | 
						|
from aiocache.backends.memory import SimpleMemoryCache
 | 
						|
from aiocache.lock import OptimisticLock
 | 
						|
from aiocache.plugins import BasePlugin, TimingPlugin
 | 
						|
 | 
						|
from ..schemas import ModelType
 | 
						|
from .base import InferenceModel
 | 
						|
 | 
						|
 | 
						|
class ModelCache:
    """Fetches a model from an in-memory cache, instantiating it if it's missing."""

    def __init__(
        self,
        ttl: float | None = None,
        revalidate: bool = False,
        timeout: int | None = None,
        profiling: bool = False,
    ):
        """
        Args:
            ttl: Unloads model after this duration. Disabled if None. Defaults to None.
            revalidate: Resets TTL on cache hit. Useful to keep models in memory while active. Defaults to False.
            timeout: Timeout handed to aiocache for each cache operation. It does not bound
                the time spent instantiating a model, which runs outside any cache call.
                Disabled if None. Defaults to None.
            profiling: Collects metrics for cache operations, adding slight overhead. Defaults to False.
        """

        self.ttl = ttl
        plugins = []

        if revalidate:
            plugins.append(RevalidationPlugin())
        if profiling:
            plugins.append(TimingPlugin())

        self.cache = SimpleMemoryCache(
            ttl=ttl, timeout=timeout, plugins=plugins, namespace=None
        )
        # One asyncio.Lock per cache key, created lazily in `get`. Concurrent
        # misses for the same model then load it only once, while misses for
        # different models can still load in parallel.
        self._load_locks: dict[str, asyncio.Lock] = {}

    async def get(
        self, model_name: str, model_type: ModelType, **model_kwargs
    ) -> InferenceModel:
        """
        Args:
            model_name: Name of model in the model hub used for the task.
            model_type: Model type or task, which determines which model zoo is used.
            **model_kwargs: Forwarded to `InferenceModel.from_model_type` when the model
                has to be instantiated. They are not part of the cache key, so an
                already-cached model is returned regardless of these arguments.

        Returns:
            model: The requested model.
        """

        key = self.cache.build_key(model_name, model_type.value)
        model = await self.cache.get(key)
        if model is not None:
            return model

        # Serialize loads per key. aiocache's OptimisticLock is not a mutex: when
        # the key is absent its CAS token is None and `cas` always succeeds, so
        # concurrent misses would each load (and overwrite) a full model.
        lock = self._load_locks.setdefault(key, asyncio.Lock())
        async with lock:
            # Another task may have finished loading while we waited for the lock.
            model = await self.cache.get(key)
            if model is None:
                # Model construction is blocking, so keep it off the event loop.
                model = await asyncio.get_running_loop().run_in_executor(
                    None,
                    lambda: InferenceModel.from_model_type(
                        model_type, model_name, **model_kwargs
                    ),
                )
                await self.cache.set(key, model, ttl=self.ttl)
        return model

    async def get_profiling(self) -> dict[str, float] | None:
        """
        Returns:
            The timing metrics collected on the cache, or None when the cache has no
            `profiling` attribute (profiling disabled, or presumably no timed
            operation recorded yet).
        """
        if not hasattr(self.cache, "profiling"):
            return None

        return self.cache.profiling  # type: ignore
 | 
						|
 | 
						|
 | 
						|
class RevalidationPlugin(BasePlugin):
    """Revalidates cache item's TTL after cache hit."""

    async def post_get(self, client, key, ret=None, namespace=None, **kwargs):
        """Restart the expiry timer of `key` if the `get` was a hit (non-None result)."""
        if ret is None:
            return
        await self._revalidate(client, key, namespace)

    async def post_multi_get(self, client, keys, ret=None, namespace=None, **kwargs):
        """Restart the expiry timer of every key in `keys` whose value was found."""
        if ret is None:
            return

        for key, val in zip(keys, ret):
            if val is not None:
                await self._revalidate(client, key, namespace)

    @staticmethod
    async def _revalidate(client, key, namespace):
        """Reset `key`'s TTL to `client.ttl`, but only if it currently has an expiry timer."""
        # `_handlers` is keyed by the fully namespaced key, so resolve it the same
        # way the client does (build_key falls back to the client's own namespace
        # when `namespace` is None) before looking it up.
        if client.build_key(key, namespace) in client._handlers:
            # Pass the raw key plus namespace: `expire` builds the namespaced key
            # itself, so handing it a pre-built key would prefix it twice.
            await client.expire(key, client.ttl, namespace=namespace)
 |