diff --git a/back/src/Kyoo.Core/Controllers/Repositories/SeasonRepository.cs b/back/src/Kyoo.Core/Controllers/Repositories/SeasonRepository.cs index 8dc193c5..3d20216e 100644 --- a/back/src/Kyoo.Core/Controllers/Repositories/SeasonRepository.cs +++ b/back/src/Kyoo.Core/Controllers/Repositories/SeasonRepository.cs @@ -107,7 +107,7 @@ namespace Kyoo.Core.Controllers { await base.Create(obj); _database.Entry(obj).State = EntityState.Added; - await _database.SaveChangesAsync(() => Get(obj.Slug)); + await _database.SaveChangesAsync(() => Get(obj.ShowID, obj.SeasonNumber)); return obj; } diff --git a/scanner/providers/implementations/themoviedatabase.py b/scanner/providers/implementations/themoviedatabase.py index 1d04f6ec..5e3dc426 100644 --- a/scanner/providers/implementations/themoviedatabase.py +++ b/scanner/providers/implementations/themoviedatabase.py @@ -4,6 +4,8 @@ from aiohttp import ClientSession from datetime import datetime from typing import Awaitable, Callable, Dict, Optional, Any, TypeVar +from providers.utils import ProviderError + from ..provider import Provider from ..types.movie import Movie, MovieTranslation, Status as MovieStatus from ..types.season import Season, SeasonTranslation @@ -97,9 +99,12 @@ class TheMovieDatabase(Provider): async def identify_movie( self, name: str, year: Optional[int], *, language: list[str] ) -> Movie: - search = (await self.get("search/movie", params={"query": name, "year": year}))[ - "results" - ][0] + search_results = ( + await self.get("search/movie", params={"query": name, "year": year}) + )["results"] + if len(search_results) == 0: + raise ProviderError(f"No result for a movie named: {name}") + search = search_results[0] movie_id = search["id"] if search["original_language"] not in language: language.append(search["original_language"]) @@ -118,9 +123,9 @@ class TheMovieDatabase(Provider): ret = Movie( original_language=movie["original_language"], aliases=[x["title"] for x in movie["alternative_titles"]["titles"]], - release_date=datetime.strptime( - movie["release_date"], "%Y-%m-%d" - ).date(), + release_date=datetime.strptime(movie["release_date"], "%Y-%m-%d").date() + if movie["release_date"] + else None, status=MovieStatus.FINISHED if movie["status"] == "Released" else MovieStatus.PLANNED, @@ -184,8 +189,12 @@ class TheMovieDatabase(Provider): ret = Show( original_language=show["original_language"], aliases=[x["title"] for x in show["alternative_titles"]["results"]], - start_air=datetime.strptime(show["first_air_date"], "%Y-%m-%d").date(), - end_air=datetime.strptime(show["last_air_date"], "%Y-%m-%d").date(), + start_air=datetime.strptime(show["first_air_date"], "%Y-%m-%d").date() + if show["first_air_date"] + else None, + end_air=datetime.strptime(show["last_air_date"], "%Y-%m-%d").date() + if show["last_air_date"] + else None, status=ShowStatus.FINISHED if show["status"] == "Released" else ShowStatus.AIRING @@ -258,7 +267,9 @@ class TheMovieDatabase(Provider): ) -> Season: return Season( season_number=season["season_number"], - start_air=datetime.strptime(season["air_date"], "%Y-%m-%d").date(), + start_air=datetime.strptime(season["air_date"], "%Y-%m-%d").date() + if season["air_date"] + else None, end_air=None, external_ids={ self.name: MetadataID( @@ -289,14 +300,19 @@ class TheMovieDatabase(Provider): *, language: list[str], ) -> Episode: - search = (await self.get("search/tv", params={"query": name}))["results"][0] + search_results = (await self.get("search/tv", params={"query": name}))[ + "results" + ] + if len(search_results) == 0: + raise ProviderError(f"No result for a tv show named: {name}") + search = search_results[0] show_id = search["id"] if search["original_language"] not in language: language.append(search["original_language"]) # TODO: Handle absolute episodes if not season or not episode_nbr: - raise NotImplementedError( + raise ProviderError( "Absolute order episodes not implemented for the movie database" ) @@ -323,7 +339,9 @@ class TheMovieDatabase(Provider): episode_number=episode["episode_number"], # TODO: absolute numbers absolute_number=None, - release_date=datetime.strptime(episode["air_date"], "%Y-%m-%d").date(), + release_date=datetime.strptime(episode["air_date"], "%Y-%m-%d").date() + if episode["air_date"] + else None, thumbnail=f"https://image.tmdb.org/t/p/original{episode['poster_path']}" if "poster_path" in episode else None, diff --git a/scanner/providers/types/season.py b/scanner/providers/types/season.py index fde20130..ca7094b4 100644 --- a/scanner/providers/types/season.py +++ b/scanner/providers/types/season.py @@ -9,18 +9,18 @@ from .metadataid import MetadataID @dataclass class SeasonTranslation: - name: Optional[str] - overview: Optional[str] - posters: list[str] - thumbnails: list[str] + name: Optional[str] = None + overview: Optional[str] = None + posters: list[str] = field(default_factory=list) + thumbnails: list[str] = field(default_factory=list) @dataclass class Season: season_number: int - start_air: Optional[date | int] - end_air: Optional[date | int] - external_ids: dict[str, MetadataID] + start_air: Optional[date | int] = None + end_air: Optional[date | int] = None + external_ids: dict[str, MetadataID] = field(default_factory=dict) show_id: Optional[str] = None translations: dict[str, SeasonTranslation] = field(default_factory=dict) diff --git a/scanner/providers/utils.py b/scanner/providers/utils.py index c0e6d8a2..27b38212 100644 --- a/scanner/providers/utils.py +++ b/scanner/providers/utils.py @@ -7,3 +7,7 @@ def format_date(date: date | int | None) -> str | None: if isinstance(date, int): return f"{date}-01-01" return date.isoformat() + +class ProviderError(RuntimeError): + def __init__(self, *args: object) -> None: + super().__init__(*args) diff --git a/scanner/scanner/scanner.py b/scanner/scanner/scanner.py index 22b2df45..60142a88 100644 --- a/scanner/scanner/scanner.py +++ b/scanner/scanner/scanner.py @@ -8,7 +8,8 @@ from pathlib import Path from guessit import guessit from providers.provider import Provider from providers.types.episode import Episode, PartialShow -from providers.types.season import Season +from providers.types.season import Season, SeasonTranslation +from providers.utils import ProviderError def log_errors(f): @@ -16,12 +17,64 @@ def log_errors(f): async def internal(*args, **kwargs): try: await f(*args, **kwargs) + except ProviderError as e: + logging.error(str(e)) except Exception as e: logging.exception("Unhandled error", exc_info=e) return internal +cache = {} + + +def provider_cache(*args): + ic = cache + for arg in args: + if arg not in ic: + ic[arg] = {} + ic = ic[arg] + + def wrapper(f): + @wraps(f) + async def internal(*args, **kwargs): + nonlocal ic + for arg in args: + if arg not in ic: + ic[arg] = {} + ic = ic[arg] + + if "event" in ic: + await ic["event"].wait() + if not ic["ret"]: + raise ProviderError("Cache miss. Another error should exist") + return ic["ret"] + ic["event"] = asyncio.Event() + try: + ret = await f(*args, **kwargs) + ic["ret"] = ret + except: + ic["event"].set() + raise + ic["event"].set() + return ret + return internal + return wrapper + + +def set_in_cache(key: list[str | int]): + ic = cache + for arg in key: + if arg not in ic: + ic[arg] = {} + ic = ic[arg] + evt = asyncio.Event() + evt.set() + ic["event"] = evt + + + + class Scanner: def __init__( self, client: ClientSession, *, languages: list[str], api_key: str @@ -30,7 +83,7 @@ class Scanner: self._api_key = api_key self._url = os.environ.get("KYOO_URL", "http://back:5000") self.provider = Provider.get_all(client)[0] - self.cache = {"shows": {}} + self.cache = {"shows": {}, "seasons": {}} self.languages = languages async def scan(self, path: str): @@ -50,7 +103,6 @@ class Scanner: return True return False - @log_errors async def identify(self, path: Path): if await self.is_registered(path): @@ -58,7 +110,6 @@ class Scanner: raw = guessit(path, "--episode-prefer-number") logging.info("Identied %s: %s", path, raw) - # TODO: check if episode/movie already exists in kyoo and skip if it does. # TODO: Add collections support if raw["type"] == "movie": movie = await self.provider.identify_movie( @@ -78,47 +129,49 @@ class Scanner: episode.path = str(path) logging.debug("Got episode: %s", episode) episode.show_id = await self.create_or_get_show(episode) - # TODO: Do the same things for seasons and wait for them to be created on the api (else the episode creation will fail) + + if episode.season_number is not None: + await self.register_seasons( + show_id=episode.show_id, + season_number=episode.season_number, + ) await self.post("episodes", data=episode.to_kyoo()) else: logging.warn("Unknown video file type: %s", raw["type"]) async def create_or_get_show(self, episode: Episode) -> str: - provider_id = episode.show.external_ids[self.provider.name].id - if provider_id in self.cache["shows"]: - ret = self.cache["shows"][provider_id] - await ret["event"].wait() - if not ret["id"]: - raise RuntimeError("Provider failed to create the show") - return ret["id"] - - self.cache["shows"][provider_id] = {"id": None, "event": asyncio.Event()} - - # TODO: Check if a show with the same metadata id exists already on kyoo. - - show = ( - await self.provider.identify_show(episode.show, language=self.languages) - if isinstance(episode.show, PartialShow) - else episode.show - ) - logging.debug("Got show: %s", episode) - try: + @provider_cache("shows") + async def create_show(_: str): + # TODO: Check if a show with the same metadata id exists already on kyoo. + show = ( + await self.provider.identify_show(episode.show, language=self.languages) + if isinstance(episode.show, PartialShow) + else episode.show + ) + logging.debug("Got show: %s", episode) ret = await self.post("show", data=show.to_kyoo()) - except: - # Allow tasks waiting for this show to bail out. - self.cache["shows"][provider_id]["event"].set() - raise - self.cache["shows"][provider_id]["id"] = ret - self.cache["shows"][provider_id]["event"].set() + try: + for season in show.seasons: + season.show_id = ret + await self.post("seasons", data=season.to_kyoo()) + set_in_cache(key=["seasons", ret, season.season_number]) + except Exception as e: + logging.exception("Unhandled error create a season", exc_info=e) + return ret - # TODO: Better handling of seasons registrations (maybe a lock also) - await self.register_seasons(ret, show.seasons) - return ret + # The parameter is only used as a key for the cache. + provider_id = episode.show.external_ids[self.provider.name].id + return await create_show(provider_id) - async def register_seasons(self, show_id: str, seasons: list[Season]): - for season in seasons: - season.show_id = show_id - await self.post("seasons", data=season.to_kyoo()) + @provider_cache("seasons") + async def register_seasons(self, show_id: str, season_number: int): + # TODO: fetch season here. this will be useful when a new season of a show is aired after the show has been created on kyoo. + season = Season( + season_number=season_number, + show_id=show_id, + translations={lng: SeasonTranslation() for lng in self.languages}, + ) + await self.post("seasons", data=season.to_kyoo()) async def post(self, path: str, *, data: object) -> str: logging.debug(