Use a sentinel instead of None for cache presence

This commit is contained in:
Zoe Roux 2024-01-08 01:03:01 +01:00
parent b180429505
commit 0b55bd7dbb

View File

@ -1,9 +1,15 @@
import asyncio import asyncio
from datetime import datetime, timedelta from datetime import datetime, timedelta
from functools import wraps from functools import wraps
from typing import Any, Optional, Tuple from typing import Any, Optional, Tuple, Final, Literal
from enum import Enum
type Cache = dict[Any, Tuple[Optional[asyncio.Event], Optional[datetime], Any]] # Waiting for https://github.com/python/typing/issues/689 for better sentinels
class Sentinel(Enum):
NoneSentinel = 0
none: Final = Sentinel.NoneSentinel
type Cache = dict[Any, Tuple[asyncio.Event | Literal[none], datetime | Literal[none], Any]]
def cache(ttl: timedelta, cache: Optional[Cache] = None, typed=False): def cache(ttl: timedelta, cache: Optional[Cache] = None, typed=False):
""" """
@ -25,19 +31,19 @@ def cache(ttl: timedelta, cache: Optional[Cache] = None, typed=False):
async def wrapper(*args, **kwargs): async def wrapper(*args, **kwargs):
key = make_key(args, kwargs, typed) key = make_key(args, kwargs, typed)
ret = cache.get(key, (None, None, None)) ret = cache.get(key, (none, none, none))
# First check if the same method is already running and wait for it. # First check if the same method is already running and wait for it.
if ret[0] is not None: if ret[0] != none:
await ret[0].wait() await ret[0].wait()
ret = cache.get(key, (None, None, None)) ret = cache.get(key, (none, none, none))
if ret[2] is None: if ret[2] == none:
# ret[2] can be None if the cached method failed. if that is the case, run again. # ret[2] can be None if the cached method failed. if that is the case, run again.
return await wrapper(*args, **kwargs) return await wrapper(*args, **kwargs)
return ret[2] return ret[2]
# Return the cached result if it exits and is not expired # Return the cached result if it exits and is not expired
if ( if (
ret[2] is not None ret[2] != none
and ret[1] is not None and ret[1] != none
and datetime.now() - ret[1] < ttl and datetime.now() - ret[1] < ttl
): ):
return ret[2] return ret[2]
@ -50,7 +56,7 @@ def cache(ttl: timedelta, cache: Optional[Cache] = None, typed=False):
async def exec_as_cache(cache: Cache, key, f): async def exec_as_cache(cache: Cache, key, f):
event = asyncio.Event() event = asyncio.Event()
cache[key] = (event, None, None) cache[key] = (event, none, none)
try: try:
result = await f() result = await f()
except: except:
@ -58,7 +64,7 @@ async def exec_as_cache(cache: Cache, key, f):
event.set() event.set()
raise raise
cache[key] = (None, datetime.now(), result) cache[key] = (none, datetime.now(), result)
event.set() event.set()
return result return result