diff --git a/scanner/providers/idmapper.py b/scanner/providers/idmapper.py index c50745a0..1d9b0815 100644 --- a/scanner/providers/idmapper.py +++ b/scanner/providers/idmapper.py @@ -5,8 +5,6 @@ if TYPE_CHECKING: from providers.implementations.themoviedatabase import TheMovieDatabase from typing import List, Optional -from datetime import timedelta -from scanner.cache import cache from providers.types.metadataid import MetadataID @@ -23,11 +21,7 @@ class IdMapper: # Only fetch using tmdb if one of the required ids is not already known. should_fetch = required is not None and any((x not in ids for x in required)) if self._tmdb and self._tmdb.name in ids and should_fetch: - tmdb_info = await self._tmdb.identify_show( - ids[self._tmdb.name].data_id, - original_language=None, - language=[self.language], - ) + tmdb_info = await self._tmdb.identify_show(ids[self._tmdb.name].data_id) return {**ids, **tmdb_info.external_id} return ids diff --git a/scanner/providers/implementations/themoviedatabase.py b/scanner/providers/implementations/themoviedatabase.py index 99ac4bb0..b4ab427a 100644 --- a/scanner/providers/implementations/themoviedatabase.py +++ b/scanner/providers/implementations/themoviedatabase.py @@ -18,14 +18,21 @@ from ..types.studio import Studio from ..types.genre import Genre from ..types.metadataid import MetadataID from ..types.show import Show, ShowTranslation, Status as ShowStatus +from ..types.season import Season from ..types.collection import Collection, CollectionTranslation class TheMovieDatabase(Provider): def __init__( - self, client: ClientSession, api_key: str, xem: TheXem, idmapper: IdMapper + self, + languages, + client: ClientSession, + api_key: str, + xem: TheXem, + idmapper: IdMapper, ) -> None: super().__init__() + self._languages = languages self._client = client self._xem = xem self._idmapper = idmapper @@ -56,6 +63,9 @@ class TheMovieDatabase(Provider): def name(self) -> str: return "themoviedatabase" + def get_languages(self, *args): + return self._languages + list(args) + async def get( self, path: str, @@ -113,9 +123,7 @@ class TheMovieDatabase(Provider): }, ) - async def identify_movie( - self, name: str, year: Optional[int], *, language: list[str] - ) -> Movie: + async def identify_movie(self, name: str, year: Optional[int]) -> Movie: search_results = ( await self.get("search/movie", params={"query": name, "year": year}) )["results"] @@ -123,8 +131,7 @@ class TheMovieDatabase(Provider): raise ProviderError(f"No result for a movie named: {name}") search = self.get_best_result(search_results, name, year) movie_id = search["id"] - if search["original_language"] not in language: - language.append(search["original_language"]) + languages = self.get_languages(search["original_language"]) async def for_language(lng: str) -> Movie: movie = await self.get( @@ -216,7 +223,7 @@ class TheMovieDatabase(Provider): ret.translations = {lng: translation} return ret - ret = await self.process_translations(for_language, language) + ret = await self.process_translations(for_language, languages) # If we have more external_ids freely available, add them. ret.external_id = await self._idmapper.get_movie(ret.external_id) return ret @@ -225,12 +232,8 @@ class TheMovieDatabase(Provider): async def identify_show( self, show_id: str, - *, - original_language: Optional[str], - language: list[str], ) -> Show: - if original_language and original_language not in language: - language.append(original_language) + languages = self.get_languages() async def for_language(lng: str) -> Show: show = await self.get( @@ -332,15 +335,17 @@ class TheMovieDatabase(Provider): ) for x in items ], - languages=language, + languages=languages, ) for season in item.seasons ] return item ret = await self.process_translations( - for_language, language, merge_seasons_translations + for_language, languages, merge_seasons_translations ) + if ret.original_language is not None and ret.original_language not in ret.translations: + ret.translations[ret.original_language] = (await for_language(ret.original_language)).translations[ret.original_language] # If we have more external_ids freely available, add them. ret.external_id = await self._idmapper.get_show(ret.external_id) return ret @@ -375,6 +380,11 @@ class TheMovieDatabase(Provider): }, ) + async def identify_season(self, show_id: str, season_number: int) -> Season: + # We already get seasons info in the identify_show and chances are this gets cached already + show = await self.identify_show(show_id) + return show.seasons[season_number] + @cache(ttl=timedelta(days=1)) async def search_show(self, name: str, year: Optional[int]) -> PartialShow: search_results = ( @@ -425,12 +435,9 @@ class TheMovieDatabase(Provider): episode_nbr: Optional[int], absolute: Optional[int], year: Optional[int], - *, - language: list[str], ) -> Episode: show = await self.search_show(name, year) - if show.original_language and show.original_language not in language: - language.append(show.original_language) + languages = self.get_languages(show.original_language) # Keep it for xem overrides of season/episode old_name = name name = show.name @@ -506,7 +513,7 @@ class TheMovieDatabase(Provider): ret.translations = {lng: translation} return ret - return await self.process_translations(for_language, language) + return await self.process_translations(for_language, languages) def get_best_result( self, search_results: List[Any], name: str, year: Optional[int] @@ -588,7 +595,7 @@ class TheMovieDatabase(Provider): episode_nbr = absgrp[absolute - 1]["episode_number"] return (season, episode_nbr) # We assume that each season should be played in order with no special episodes. - show = await self.identify_show(show_id, original_language=None, language=[]) + show = await self.identify_show(show_id) seasons = [x.episodes_count for x in show.seasons] # enumerate(accumulate(season)) return [(0, 12), (1, 24)] if the show has two seasons with 12 eps # we take the last group that has less total episodes than the absolute number. @@ -603,9 +610,7 @@ class TheMovieDatabase(Provider): absgrp = await self.get_absolute_order(show_id) if absgrp is None: # We assume that each season should be played in order with no special episodes. - show = await self.identify_show( - show_id, original_language=None, language=[] - ) + show = await self.identify_show(show_id) return sum(x.episodes_count for x in show.seasons[:season]) + episode_nbr return next( ( @@ -617,9 +622,9 @@ class TheMovieDatabase(Provider): None, ) - async def identify_collection( - self, provider_id: str, *, language: list[str] - ) -> Collection: + async def identify_collection(self, provider_id: str) -> Collection: + languages = self.get_languages() + async def for_language(lng: str) -> Collection: collection = await self.get( f"collection/{provider_id}", @@ -651,4 +656,4 @@ class TheMovieDatabase(Provider): ret.translations = {lng: translation} return ret - return await self.process_translations(for_language, language) + return await self.process_translations(for_language, languages) diff --git a/scanner/providers/provider.py b/scanner/providers/provider.py index 8ac53906..31e5d244 100644 --- a/scanner/providers/provider.py +++ b/scanner/providers/provider.py @@ -5,8 +5,9 @@ from typing import Optional, TypeVar from providers.utils import ProviderError -from .types.episode import Episode from .types.show import Show +from .types.season import Season +from .types.episode import Episode from .types.movie import Movie from .types.collection import Collection @@ -33,7 +34,7 @@ class Provider: tmdb = os.environ.get("THEMOVIEDB_APIKEY") if tmdb: - tmdb = TheMovieDatabase(client, tmdb, xem, idmapper) + tmdb = TheMovieDatabase(languages, client, tmdb, xem, idmapper) providers.append(tmdb) else: tmdb = None @@ -52,15 +53,17 @@ class Provider: raise NotImplementedError @abstractmethod - async def identify_movie( - self, name: str, year: Optional[int], *, language: list[str] - ) -> Movie: + async def identify_movie(self, name: str, year: Optional[int]) -> Movie: raise NotImplementedError @abstractmethod - async def identify_show( - self, show_id: str, *, original_language: Optional[str], language: list[str] - ) -> Show: + async def identify_show(self, show_id: str) -> Show: + raise NotImplementedError + + @abstractmethod + async def identify_season( + self, show_id: str, season_number: Optional[int] + ) -> Season: raise NotImplementedError @abstractmethod @@ -71,13 +74,9 @@ class Provider: episode_nbr: Optional[int], absolute: Optional[int], year: Optional[int], - *, - language: list[str] ) -> Episode: raise NotImplementedError @abstractmethod - async def identify_collection( - self, provider_id: str, *, language: list[str] - ) -> Collection: + async def identify_collection(self, provider_id: str) -> Collection: raise NotImplementedError diff --git a/scanner/scanner/scanner.py b/scanner/scanner/scanner.py index 91dee2a6..584695f8 100644 --- a/scanner/scanner/scanner.py +++ b/scanner/scanner/scanner.py @@ -10,8 +10,9 @@ from guessit import guessit from typing import List, Literal, Any from providers.provider import Provider from providers.types.collection import Collection +from providers.types.show import Show from providers.types.episode import Episode, PartialShow -from providers.types.season import Season, SeasonTranslation +from providers.types.season import Season from .utils import batch, log_errors from .cache import cache, exec_as_cache, make_key @@ -88,9 +89,7 @@ class Scanner: logging.info("Identied %s: %s", path, raw) if raw["type"] == "movie": - movie = await self.provider.identify_movie( - raw["title"], raw.get("year"), language=self.languages - ) + movie = await self.provider.identify_movie(raw["title"], raw.get("year")) movie.path = str(path) logging.debug("Got movie: %s", movie) movie_id = await self.post("movies", data=movie.to_kyoo()) @@ -109,7 +108,6 @@ class Scanner: episode_nbr=raw.get("episode"), absolute=raw.get("episode") if "season" not in raw else None, year=raw.get("year"), - language=self.languages, ) episode.path = str(path) logging.debug("Got episode: %s", episode) @@ -117,8 +115,7 @@ class Scanner: if episode.season_number is not None: episode.season_id = await self.register_seasons( - show_id=episode.show_id, - season_number=episode.season_number, + episode.show, episode.show_id, episode.season_number ) await self.post("episodes", data=episode.to_kyoo()) else: @@ -129,9 +126,7 @@ class Scanner: async def create_collection(provider_id: str): # TODO: Check if a collection with the same metadata id exists already on kyoo. new_collection = ( - await self.provider.identify_collection( - provider_id, language=self.languages - ) + await self.provider.identify_collection(provider_id) if not any(collection.translations.keys()) else collection ) @@ -161,8 +156,6 @@ class Scanner: show = ( await self.provider.identify_show( episode.show.external_id[self.provider.name].data_id, - original_language=episode.show.original_language, - language=self.languages, ) if isinstance(episode.show, PartialShow) else episode.show @@ -194,16 +187,19 @@ class Scanner: provider_id = episode.show.external_id[self.provider.name].data_id return await create_show(provider_id) - # We use an external season cache because we want to edit this cache programatically - @cache(ttl=timedelta(days=1), cache=season_cache) - async def register_seasons(self, show_id: str, season_number: int) -> str: - # TODO: fetch season here. this will be useful when a new season of a show is aired after the show has been created on kyoo. - season = Season( - season_number=season_number, - show_id=show_id, - translations={lng: SeasonTranslation() for lng in self.languages}, - ) - return await self.post("seasons", data=season.to_kyoo()) + async def register_seasons( + self, show: Show | PartialShow, show_id: str, season_number: int + ) -> str: + # We use an external season cache because we want to edit this cache programatically + @cache(ttl=timedelta(days=1), cache=season_cache) + async def create_season(_: str, __: int): + season = await self.provider.identify_season( + show.external_id[self.provider.name].data_id, season_number + ) + season.show_id = show_id + return await self.post("seasons", data=season.to_kyoo()) + + return await create_season(show_id, season_number) async def post(self, path: str, *, data: dict[str, Any]) -> str: logging.debug(