Rework language handling in the scanner and handle seasons better

This commit is contained in:
Zoe Roux 2024-01-07 22:42:18 +01:00
parent f9d6a816b0
commit e1b555781c
4 changed files with 63 additions and 69 deletions

View File

@ -5,8 +5,6 @@ if TYPE_CHECKING:
from providers.implementations.themoviedatabase import TheMovieDatabase
from typing import List, Optional
from datetime import timedelta
from scanner.cache import cache
from providers.types.metadataid import MetadataID
@ -23,11 +21,7 @@ class IdMapper:
# Only fetch using tmdb if one of the required ids is not already known.
should_fetch = required is not None and any((x not in ids for x in required))
if self._tmdb and self._tmdb.name in ids and should_fetch:
tmdb_info = await self._tmdb.identify_show(
ids[self._tmdb.name].data_id,
original_language=None,
language=[self.language],
)
tmdb_info = await self._tmdb.identify_show(ids[self._tmdb.name].data_id)
return {**ids, **tmdb_info.external_id}
return ids

View File

@ -18,14 +18,21 @@ from ..types.studio import Studio
from ..types.genre import Genre
from ..types.metadataid import MetadataID
from ..types.show import Show, ShowTranslation, Status as ShowStatus
from ..types.season import Season
from ..types.collection import Collection, CollectionTranslation
class TheMovieDatabase(Provider):
def __init__(
self, client: ClientSession, api_key: str, xem: TheXem, idmapper: IdMapper
self,
languages,
client: ClientSession,
api_key: str,
xem: TheXem,
idmapper: IdMapper,
) -> None:
super().__init__()
self._languages = languages
self._client = client
self._xem = xem
self._idmapper = idmapper
@ -56,6 +63,9 @@ class TheMovieDatabase(Provider):
def name(self) -> str:
return "themoviedatabase"
def get_languages(self, *args):
return self._languages + list(args)
async def get(
self,
path: str,
@ -113,9 +123,7 @@ class TheMovieDatabase(Provider):
},
)
async def identify_movie(
self, name: str, year: Optional[int], *, language: list[str]
) -> Movie:
async def identify_movie(self, name: str, year: Optional[int]) -> Movie:
search_results = (
await self.get("search/movie", params={"query": name, "year": year})
)["results"]
@ -123,8 +131,7 @@ class TheMovieDatabase(Provider):
raise ProviderError(f"No result for a movie named: {name}")
search = self.get_best_result(search_results, name, year)
movie_id = search["id"]
if search["original_language"] not in language:
language.append(search["original_language"])
languages = self.get_languages(search["original_language"])
async def for_language(lng: str) -> Movie:
movie = await self.get(
@ -216,7 +223,7 @@ class TheMovieDatabase(Provider):
ret.translations = {lng: translation}
return ret
ret = await self.process_translations(for_language, language)
ret = await self.process_translations(for_language, languages)
# If we have more external_ids freely available, add them.
ret.external_id = await self._idmapper.get_movie(ret.external_id)
return ret
@ -225,12 +232,8 @@ class TheMovieDatabase(Provider):
async def identify_show(
self,
show_id: str,
*,
original_language: Optional[str],
language: list[str],
) -> Show:
if original_language and original_language not in language:
language.append(original_language)
languages = self.get_languages()
async def for_language(lng: str) -> Show:
show = await self.get(
@ -332,15 +335,17 @@ class TheMovieDatabase(Provider):
)
for x in items
],
languages=language,
languages=languages,
)
for season in item.seasons
]
return item
ret = await self.process_translations(
for_language, language, merge_seasons_translations
for_language, languages, merge_seasons_translations
)
if ret.original_language is not None and ret.original_language not in ret.translations:
ret.translations[ret.original_language] = (await for_language(ret.original_language)).translations[ret.original_language]
# If we have more external_ids freely available, add them.
ret.external_id = await self._idmapper.get_show(ret.external_id)
return ret
@ -375,6 +380,11 @@ class TheMovieDatabase(Provider):
},
)
async def identify_season(self, show_id: str, season_number: int) -> Season:
# We already get seasons info in the identify_show and chances are this gets cached already
show = await self.identify_show(show_id)
return show.seasons[season_number]
@cache(ttl=timedelta(days=1))
async def search_show(self, name: str, year: Optional[int]) -> PartialShow:
search_results = (
@ -425,12 +435,9 @@ class TheMovieDatabase(Provider):
episode_nbr: Optional[int],
absolute: Optional[int],
year: Optional[int],
*,
language: list[str],
) -> Episode:
show = await self.search_show(name, year)
if show.original_language and show.original_language not in language:
language.append(show.original_language)
languages = self.get_languages(show.original_language)
# Keep it for xem overrides of season/episode
old_name = name
name = show.name
@ -506,7 +513,7 @@ class TheMovieDatabase(Provider):
ret.translations = {lng: translation}
return ret
return await self.process_translations(for_language, language)
return await self.process_translations(for_language, languages)
def get_best_result(
self, search_results: List[Any], name: str, year: Optional[int]
@ -588,7 +595,7 @@ class TheMovieDatabase(Provider):
episode_nbr = absgrp[absolute - 1]["episode_number"]
return (season, episode_nbr)
# We assume that each season should be played in order with no special episodes.
show = await self.identify_show(show_id, original_language=None, language=[])
show = await self.identify_show(show_id)
seasons = [x.episodes_count for x in show.seasons]
# enumerate(accumulate(season)) return [(0, 12), (1, 24)] if the show has two seasons with 12 eps
# we take the last group that has less total episodes than the absolute number.
@ -603,9 +610,7 @@ class TheMovieDatabase(Provider):
absgrp = await self.get_absolute_order(show_id)
if absgrp is None:
# We assume that each season should be played in order with no special episodes.
show = await self.identify_show(
show_id, original_language=None, language=[]
)
show = await self.identify_show(show_id)
return sum(x.episodes_count for x in show.seasons[:season]) + episode_nbr
return next(
(
@ -617,9 +622,9 @@ class TheMovieDatabase(Provider):
None,
)
async def identify_collection(
self, provider_id: str, *, language: list[str]
) -> Collection:
async def identify_collection(self, provider_id: str) -> Collection:
languages = self.get_languages()
async def for_language(lng: str) -> Collection:
collection = await self.get(
f"collection/{provider_id}",
@ -651,4 +656,4 @@ class TheMovieDatabase(Provider):
ret.translations = {lng: translation}
return ret
return await self.process_translations(for_language, language)
return await self.process_translations(for_language, languages)

View File

@ -5,8 +5,9 @@ from typing import Optional, TypeVar
from providers.utils import ProviderError
from .types.episode import Episode
from .types.show import Show
from .types.season import Season
from .types.episode import Episode
from .types.movie import Movie
from .types.collection import Collection
@ -33,7 +34,7 @@ class Provider:
tmdb = os.environ.get("THEMOVIEDB_APIKEY")
if tmdb:
tmdb = TheMovieDatabase(client, tmdb, xem, idmapper)
tmdb = TheMovieDatabase(languages, client, tmdb, xem, idmapper)
providers.append(tmdb)
else:
tmdb = None
@ -52,15 +53,17 @@ class Provider:
raise NotImplementedError
@abstractmethod
async def identify_movie(
self, name: str, year: Optional[int], *, language: list[str]
) -> Movie:
async def identify_movie(self, name: str, year: Optional[int]) -> Movie:
raise NotImplementedError
@abstractmethod
async def identify_show(
self, show_id: str, *, original_language: Optional[str], language: list[str]
) -> Show:
async def identify_show(self, show_id: str) -> Show:
raise NotImplementedError
@abstractmethod
async def identify_season(
self, show_id: str, season_number: Optional[int]
) -> Season:
raise NotImplementedError
@abstractmethod
@ -71,13 +74,9 @@ class Provider:
episode_nbr: Optional[int],
absolute: Optional[int],
year: Optional[int],
*,
language: list[str]
) -> Episode:
raise NotImplementedError
@abstractmethod
async def identify_collection(
self, provider_id: str, *, language: list[str]
) -> Collection:
async def identify_collection(self, provider_id: str) -> Collection:
raise NotImplementedError

View File

@ -10,8 +10,9 @@ from guessit import guessit
from typing import List, Literal, Any
from providers.provider import Provider
from providers.types.collection import Collection
from providers.types.show import Show
from providers.types.episode import Episode, PartialShow
from providers.types.season import Season, SeasonTranslation
from providers.types.season import Season
from .utils import batch, log_errors
from .cache import cache, exec_as_cache, make_key
@ -88,9 +89,7 @@ class Scanner:
logging.info("Identied %s: %s", path, raw)
if raw["type"] == "movie":
movie = await self.provider.identify_movie(
raw["title"], raw.get("year"), language=self.languages
)
movie = await self.provider.identify_movie(raw["title"], raw.get("year"))
movie.path = str(path)
logging.debug("Got movie: %s", movie)
movie_id = await self.post("movies", data=movie.to_kyoo())
@ -109,7 +108,6 @@ class Scanner:
episode_nbr=raw.get("episode"),
absolute=raw.get("episode") if "season" not in raw else None,
year=raw.get("year"),
language=self.languages,
)
episode.path = str(path)
logging.debug("Got episode: %s", episode)
@ -117,8 +115,7 @@ class Scanner:
if episode.season_number is not None:
episode.season_id = await self.register_seasons(
show_id=episode.show_id,
season_number=episode.season_number,
episode.show, episode.show_id, episode.season_number
)
await self.post("episodes", data=episode.to_kyoo())
else:
@ -129,9 +126,7 @@ class Scanner:
async def create_collection(provider_id: str):
# TODO: Check if a collection with the same metadata id exists already on kyoo.
new_collection = (
await self.provider.identify_collection(
provider_id, language=self.languages
)
await self.provider.identify_collection(provider_id)
if not any(collection.translations.keys())
else collection
)
@ -161,8 +156,6 @@ class Scanner:
show = (
await self.provider.identify_show(
episode.show.external_id[self.provider.name].data_id,
original_language=episode.show.original_language,
language=self.languages,
)
if isinstance(episode.show, PartialShow)
else episode.show
@ -194,17 +187,20 @@ class Scanner:
provider_id = episode.show.external_id[self.provider.name].data_id
return await create_show(provider_id)
async def register_seasons(
self, show: Show | PartialShow, show_id: str, season_number: int
) -> str:
# We use an external season cache because we want to edit this cache programatically
@cache(ttl=timedelta(days=1), cache=season_cache)
async def register_seasons(self, show_id: str, season_number: int) -> str:
# TODO: fetch season here. this will be useful when a new season of a show is aired after the show has been created on kyoo.
season = Season(
season_number=season_number,
show_id=show_id,
translations={lng: SeasonTranslation() for lng in self.languages},
async def create_season(_: str, __: int):
season = await self.provider.identify_season(
show.external_id[self.provider.name].data_id, season_number
)
season.show_id = show_id
return await self.post("seasons", data=season.to_kyoo())
return await create_season(show_id, season_number)
async def post(self, path: str, *, data: dict[str, Any]) -> str:
logging.debug(
"Sending %s: %s",