Use the new cache in the scanner

This commit is contained in:
Zoe Roux 2024-01-05 14:46:39 +01:00
parent 5f787bedfe
commit f110509ea0
3 changed files with 26 additions and 63 deletions

View File

@ -2,7 +2,6 @@ import asyncio
from datetime import datetime, timedelta from datetime import datetime, timedelta
from functools import wraps from functools import wraps
from typing import Any, Optional, Tuple from typing import Any, Optional, Tuple
from providers.utils import ProviderError
type Cache = dict[Any, Tuple[Optional[asyncio.Event], Optional[datetime], Any]] type Cache = dict[Any, Tuple[Optional[asyncio.Event], Optional[datetime], Any]]

View File

@ -1,3 +1,4 @@
from datetime import timedelta
import os import os
import asyncio import asyncio
import logging import logging
@ -11,7 +12,10 @@ from providers.provider import Provider
from providers.types.collection import Collection from providers.types.collection import Collection
from providers.types.episode import Episode, PartialShow from providers.types.episode import Episode, PartialShow
from providers.types.season import Season, SeasonTranslation from providers.types.season import Season, SeasonTranslation
from .utils import batch, log_errors, provider_cache, set_in_cache from .utils import batch, log_errors
from .cache import cache, exec_as_cache, make_key
season_cache = {}
class Scanner: class Scanner:
@ -29,7 +33,6 @@ class Scanner:
self._ignore_pattern = re.compile("") self._ignore_pattern = re.compile("")
logging.error(f"Invalid ignore pattern. Ignoring. Error: {e}") logging.error(f"Invalid ignore pattern. Ignoring. Error: {e}")
self.provider = Provider.get_all(client, languages)[0] self.provider = Provider.get_all(client, languages)[0]
self.cache = {"shows": {}, "seasons": {}, "collections": {}}
self.languages = languages self.languages = languages
async def scan(self, path: str): async def scan(self, path: str):
@ -122,7 +125,7 @@ class Scanner:
logging.warn("Unknown video file type: %s", raw["type"]) logging.warn("Unknown video file type: %s", raw["type"])
async def create_or_get_collection(self, collection: Collection) -> str: async def create_or_get_collection(self, collection: Collection) -> str:
@provider_cache("collection") @cache(ttl=timedelta(days=1))
async def create_collection(provider_id: str): async def create_collection(provider_id: str):
# TODO: Check if a collection with the same metadata id exists already on kyoo. # TODO: Check if a collection with the same metadata id exists already on kyoo.
new_collection = ( new_collection = (
@ -152,7 +155,7 @@ class Scanner:
r.raise_for_status() r.raise_for_status()
async def create_or_get_show(self, episode: Episode) -> str: async def create_or_get_show(self, episode: Episode) -> str:
@provider_cache("shows") @cache(ttl=timedelta(days=1))
async def create_show(_: str): async def create_show(_: str):
# TODO: Check if a show with the same metadata id exists already on kyoo. # TODO: Check if a show with the same metadata id exists already on kyoo.
show = ( show = (
@ -167,20 +170,32 @@ class Scanner:
# TODO: collections # TODO: collections
logging.debug("Got show: %s", episode) logging.debug("Got show: %s", episode)
ret = await self.post("show", data=show.to_kyoo()) ret = await self.post("show", data=show.to_kyoo())
try:
for season in show.seasons: async def create_season(season: Season):
try:
season.show_id = ret season.show_id = ret
await self.post("seasons", data=season.to_kyoo()) return await self.post("seasons", data=season.to_kyoo())
set_in_cache(key=["seasons", ret, season.season_number]) except Exception as e:
except Exception as e: logging.exception("Unhandled error create a season", exc_info=e)
logging.exception("Unhandled error create a season", exc_info=e)
season_tasks = (
exec_as_cache(
season_cache,
make_key((ret, s.season_number)),
lambda: create_season(s),
)
for s in show.seasons
)
await asyncio.gather(*season_tasks)
return ret return ret
# The parameter is only used as a key for the cache. # The parameter is only used as a key for the cache.
provider_id = episode.show.external_id[self.provider.name].data_id provider_id = episode.show.external_id[self.provider.name].data_id
return await create_show(provider_id) return await create_show(provider_id)
@provider_cache("seasons") # We use an external season cache because we want to edit this cache programmatically
@cache(ttl=timedelta(days=1), cache=season_cache)
async def register_seasons(self, show_id: str, season_number: int) -> str: async def register_seasons(self, show_id: str, season_number: int) -> str:
# TODO: fetch season here. this will be useful when a new season of a show is aired after the show has been created on kyoo. # TODO: fetch season here. this will be useful when a new season of a show is aired after the show has been created on kyoo.
season = Season( season = Season(

View File

@ -1,4 +1,3 @@
import asyncio
import logging import logging
from functools import wraps from functools import wraps
from itertools import islice from itertools import islice
@ -31,53 +30,3 @@ def log_errors(f):
logging.exception("Unhandled error", exc_info=e) logging.exception("Unhandled error", exc_info=e)
return internal return internal
# Module-level store shared by every function decorated with `provider_cache`.
# It is a tree of nested dicts; a leaf entry may hold the keys "event" and "ret".
cache = {}


def provider_cache(*args):
    """Build a decorator that memoises an async function inside ``cache``.

    The decorator arguments select (creating it on demand) a nested dict
    under the module-level ``cache``. On every call of the decorated
    function, its positional arguments select a nested entry below that
    dict. Keyword arguments are forwarded to the function but are not part
    of the cache key.

    The first caller for a given entry stores an ``asyncio.Event`` under
    ``"event"``, awaits the wrapped function and stores its result under
    ``"ret"``. Any other caller for the same entry waits for the event and
    returns the stored result.

    Raises:
        ProviderError: when the entry has an event but no stored result,
            i.e. the call that was supposed to fill it raised.
            (ProviderError is not defined in this block; it is expected to
            be in scope in the module.)
    """
    # Resolve the namespace once, at decoration time.
    base = cache
    for name in args:
        base = base.setdefault(name, {})

    def wrapper(f):
        @wraps(f)
        async def internal(*args, **kwargs):
            # Walk from the decorator-level node on EVERY call using a local
            # cursor. Rebinding the closure variable here (as a `nonlocal`)
            # would make each call start where the previous one stopped, so
            # repeated calls would never find their own entry again.
            entry = base
            for arg in args:
                entry = entry.setdefault(arg, {})

            if "event" in entry:
                # Someone already started (or finished) this computation.
                await entry["event"].wait()
                if "ret" not in entry:
                    raise ProviderError("Cache miss. Another error should exist")
                return entry["ret"]

            entry["event"] = asyncio.Event()
            try:
                ret = await f(*args, **kwargs)
                entry["ret"] = ret
            finally:
                # Always wake up waiters, whether the call succeeded or raised.
                entry["event"].set()
            return ret

        return internal

    return wrapper
def set_in_cache(key: list[str | int]):
    """Flag the ``cache`` entry addressed by ``key`` as already completed.

    ``key`` is a path through the nested dicts of the module-level ``cache``;
    missing levels are created along the way. The entry reached receives an
    ``asyncio.Event`` that is already set, stored under ``"event"``.

    Only ``"event"`` is written: no ``"ret"`` value is stored by this function.
    """
    node = cache
    for part in key:
        node = node.setdefault(part, {})

    done = asyncio.Event()
    done.set()
    node["event"] = done