Rework language handling in the scanner and handle seasons better

This commit is contained in:
Zoe Roux 2024-01-07 22:42:18 +01:00
parent f9d6a816b0
commit e1b555781c
4 changed files with 63 additions and 69 deletions

View File

@ -5,8 +5,6 @@ if TYPE_CHECKING:
from providers.implementations.themoviedatabase import TheMovieDatabase from providers.implementations.themoviedatabase import TheMovieDatabase
from typing import List, Optional from typing import List, Optional
from datetime import timedelta
from scanner.cache import cache
from providers.types.metadataid import MetadataID from providers.types.metadataid import MetadataID
@ -23,11 +21,7 @@ class IdMapper:
# Only fetch using tmdb if one of the required ids is not already known. # Only fetch using tmdb if one of the required ids is not already known.
should_fetch = required is not None and any((x not in ids for x in required)) should_fetch = required is not None and any((x not in ids for x in required))
if self._tmdb and self._tmdb.name in ids and should_fetch: if self._tmdb and self._tmdb.name in ids and should_fetch:
tmdb_info = await self._tmdb.identify_show( tmdb_info = await self._tmdb.identify_show(ids[self._tmdb.name].data_id)
ids[self._tmdb.name].data_id,
original_language=None,
language=[self.language],
)
return {**ids, **tmdb_info.external_id} return {**ids, **tmdb_info.external_id}
return ids return ids

View File

@ -18,14 +18,21 @@ from ..types.studio import Studio
from ..types.genre import Genre from ..types.genre import Genre
from ..types.metadataid import MetadataID from ..types.metadataid import MetadataID
from ..types.show import Show, ShowTranslation, Status as ShowStatus from ..types.show import Show, ShowTranslation, Status as ShowStatus
from ..types.season import Season
from ..types.collection import Collection, CollectionTranslation from ..types.collection import Collection, CollectionTranslation
class TheMovieDatabase(Provider): class TheMovieDatabase(Provider):
def __init__( def __init__(
self, client: ClientSession, api_key: str, xem: TheXem, idmapper: IdMapper self,
languages,
client: ClientSession,
api_key: str,
xem: TheXem,
idmapper: IdMapper,
) -> None: ) -> None:
super().__init__() super().__init__()
self._languages = languages
self._client = client self._client = client
self._xem = xem self._xem = xem
self._idmapper = idmapper self._idmapper = idmapper
@ -56,6 +63,9 @@ class TheMovieDatabase(Provider):
def name(self) -> str: def name(self) -> str:
return "themoviedatabase" return "themoviedatabase"
def get_languages(self, *args: Optional[str]) -> list[str]:
    """Return the languages metadata should be fetched in.

    The result starts with the languages this provider was configured
    with (``self._languages``) followed by any extra language given in
    ``args`` -- typically the original language of the item being
    identified.

    Extras that are falsy (``None`` or an empty string, e.g. when the
    original language of a show is unknown) or already present are
    skipped, so no language is requested twice and ``None`` is never
    handed to the per-language fetchers.

    A new list is returned on every call; ``self._languages`` is never
    mutated.
    """
    languages = list(self._languages)
    for lng in args:
        if lng and lng not in languages:
            languages.append(lng)
    return languages
async def get( async def get(
self, self,
path: str, path: str,
@ -113,9 +123,7 @@ class TheMovieDatabase(Provider):
}, },
) )
async def identify_movie( async def identify_movie(self, name: str, year: Optional[int]) -> Movie:
self, name: str, year: Optional[int], *, language: list[str]
) -> Movie:
search_results = ( search_results = (
await self.get("search/movie", params={"query": name, "year": year}) await self.get("search/movie", params={"query": name, "year": year})
)["results"] )["results"]
@ -123,8 +131,7 @@ class TheMovieDatabase(Provider):
raise ProviderError(f"No result for a movie named: {name}") raise ProviderError(f"No result for a movie named: {name}")
search = self.get_best_result(search_results, name, year) search = self.get_best_result(search_results, name, year)
movie_id = search["id"] movie_id = search["id"]
if search["original_language"] not in language: languages = self.get_languages(search["original_language"])
language.append(search["original_language"])
async def for_language(lng: str) -> Movie: async def for_language(lng: str) -> Movie:
movie = await self.get( movie = await self.get(
@ -216,7 +223,7 @@ class TheMovieDatabase(Provider):
ret.translations = {lng: translation} ret.translations = {lng: translation}
return ret return ret
ret = await self.process_translations(for_language, language) ret = await self.process_translations(for_language, languages)
# If we have more external_ids freely available, add them. # If we have more external_ids freely available, add them.
ret.external_id = await self._idmapper.get_movie(ret.external_id) ret.external_id = await self._idmapper.get_movie(ret.external_id)
return ret return ret
@ -225,12 +232,8 @@ class TheMovieDatabase(Provider):
async def identify_show( async def identify_show(
self, self,
show_id: str, show_id: str,
*,
original_language: Optional[str],
language: list[str],
) -> Show: ) -> Show:
if original_language and original_language not in language: languages = self.get_languages()
language.append(original_language)
async def for_language(lng: str) -> Show: async def for_language(lng: str) -> Show:
show = await self.get( show = await self.get(
@ -332,15 +335,17 @@ class TheMovieDatabase(Provider):
) )
for x in items for x in items
], ],
languages=language, languages=languages,
) )
for season in item.seasons for season in item.seasons
] ]
return item return item
ret = await self.process_translations( ret = await self.process_translations(
for_language, language, merge_seasons_translations for_language, languages, merge_seasons_translations
) )
if ret.original_language is not None and ret.original_language not in ret.translations:
ret.translations[ret.original_language] = (await for_language(ret.original_language)).translations[ret.original_language]
# If we have more external_ids freely available, add them. # If we have more external_ids freely available, add them.
ret.external_id = await self._idmapper.get_show(ret.external_id) ret.external_id = await self._idmapper.get_show(ret.external_id)
return ret return ret
@ -375,6 +380,11 @@ class TheMovieDatabase(Provider):
}, },
) )
async def identify_season(self, show_id: str, season_number: int) -> Season:
    """Return the season numbered ``season_number`` of the show ``show_id``.

    Seasons are taken from ``identify_show`` rather than fetched on their
    own, since it already returns them (and that result is presumably
    cached -- verify against ``identify_show``'s decorators).

    Raises:
        ProviderError: if the show has no season with that number.
    """
    show = await self.identify_show(show_id)
    # Match on the season's own number instead of its position in the list:
    # the list is not guaranteed to start at season 0 or to be gap-free, so
    # `show.seasons[season_number]` could silently return the wrong season.
    season = next(
        (x for x in show.seasons if x.season_number == season_number),
        None,
    )
    if season is None:
        raise ProviderError(
            f"Could not find season {season_number} for show {show_id}"
        )
    return season
@cache(ttl=timedelta(days=1)) @cache(ttl=timedelta(days=1))
async def search_show(self, name: str, year: Optional[int]) -> PartialShow: async def search_show(self, name: str, year: Optional[int]) -> PartialShow:
search_results = ( search_results = (
@ -425,12 +435,9 @@ class TheMovieDatabase(Provider):
episode_nbr: Optional[int], episode_nbr: Optional[int],
absolute: Optional[int], absolute: Optional[int],
year: Optional[int], year: Optional[int],
*,
language: list[str],
) -> Episode: ) -> Episode:
show = await self.search_show(name, year) show = await self.search_show(name, year)
if show.original_language and show.original_language not in language: languages = self.get_languages(show.original_language)
language.append(show.original_language)
# Keep it for xem overrides of season/episode # Keep it for xem overrides of season/episode
old_name = name old_name = name
name = show.name name = show.name
@ -506,7 +513,7 @@ class TheMovieDatabase(Provider):
ret.translations = {lng: translation} ret.translations = {lng: translation}
return ret return ret
return await self.process_translations(for_language, language) return await self.process_translations(for_language, languages)
def get_best_result( def get_best_result(
self, search_results: List[Any], name: str, year: Optional[int] self, search_results: List[Any], name: str, year: Optional[int]
@ -588,7 +595,7 @@ class TheMovieDatabase(Provider):
episode_nbr = absgrp[absolute - 1]["episode_number"] episode_nbr = absgrp[absolute - 1]["episode_number"]
return (season, episode_nbr) return (season, episode_nbr)
# We assume that each season should be played in order with no special episodes. # We assume that each season should be played in order with no special episodes.
show = await self.identify_show(show_id, original_language=None, language=[]) show = await self.identify_show(show_id)
seasons = [x.episodes_count for x in show.seasons] seasons = [x.episodes_count for x in show.seasons]
# enumerate(accumulate(season)) return [(0, 12), (1, 24)] if the show has two seasons with 12 eps # enumerate(accumulate(season)) return [(0, 12), (1, 24)] if the show has two seasons with 12 eps
# we take the last group that has less total episodes than the absolute number. # we take the last group that has less total episodes than the absolute number.
@ -603,9 +610,7 @@ class TheMovieDatabase(Provider):
absgrp = await self.get_absolute_order(show_id) absgrp = await self.get_absolute_order(show_id)
if absgrp is None: if absgrp is None:
# We assume that each season should be played in order with no special episodes. # We assume that each season should be played in order with no special episodes.
show = await self.identify_show( show = await self.identify_show(show_id)
show_id, original_language=None, language=[]
)
return sum(x.episodes_count for x in show.seasons[:season]) + episode_nbr return sum(x.episodes_count for x in show.seasons[:season]) + episode_nbr
return next( return next(
( (
@ -617,9 +622,9 @@ class TheMovieDatabase(Provider):
None, None,
) )
async def identify_collection( async def identify_collection(self, provider_id: str) -> Collection:
self, provider_id: str, *, language: list[str] languages = self.get_languages()
) -> Collection:
async def for_language(lng: str) -> Collection: async def for_language(lng: str) -> Collection:
collection = await self.get( collection = await self.get(
f"collection/{provider_id}", f"collection/{provider_id}",
@ -651,4 +656,4 @@ class TheMovieDatabase(Provider):
ret.translations = {lng: translation} ret.translations = {lng: translation}
return ret return ret
return await self.process_translations(for_language, language) return await self.process_translations(for_language, languages)

View File

@ -5,8 +5,9 @@ from typing import Optional, TypeVar
from providers.utils import ProviderError from providers.utils import ProviderError
from .types.episode import Episode
from .types.show import Show from .types.show import Show
from .types.season import Season
from .types.episode import Episode
from .types.movie import Movie from .types.movie import Movie
from .types.collection import Collection from .types.collection import Collection
@ -33,7 +34,7 @@ class Provider:
tmdb = os.environ.get("THEMOVIEDB_APIKEY") tmdb = os.environ.get("THEMOVIEDB_APIKEY")
if tmdb: if tmdb:
tmdb = TheMovieDatabase(client, tmdb, xem, idmapper) tmdb = TheMovieDatabase(languages, client, tmdb, xem, idmapper)
providers.append(tmdb) providers.append(tmdb)
else: else:
tmdb = None tmdb = None
@ -52,15 +53,17 @@ class Provider:
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
async def identify_movie( async def identify_movie(self, name: str, year: Optional[int]) -> Movie:
self, name: str, year: Optional[int], *, language: list[str]
) -> Movie:
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
async def identify_show( async def identify_show(self, show_id: str) -> Show:
self, show_id: str, *, original_language: Optional[str], language: list[str] raise NotImplementedError
) -> Show:
@abstractmethod
async def identify_season(
self, show_id: str, season_number: Optional[int]
) -> Season:
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
@ -71,13 +74,9 @@ class Provider:
episode_nbr: Optional[int], episode_nbr: Optional[int],
absolute: Optional[int], absolute: Optional[int],
year: Optional[int], year: Optional[int],
*,
language: list[str]
) -> Episode: ) -> Episode:
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
async def identify_collection( async def identify_collection(self, provider_id: str) -> Collection:
self, provider_id: str, *, language: list[str]
) -> Collection:
raise NotImplementedError raise NotImplementedError

View File

@ -10,8 +10,9 @@ from guessit import guessit
from typing import List, Literal, Any from typing import List, Literal, Any
from providers.provider import Provider from providers.provider import Provider
from providers.types.collection import Collection from providers.types.collection import Collection
from providers.types.show import Show
from providers.types.episode import Episode, PartialShow from providers.types.episode import Episode, PartialShow
from providers.types.season import Season, SeasonTranslation from providers.types.season import Season
from .utils import batch, log_errors from .utils import batch, log_errors
from .cache import cache, exec_as_cache, make_key from .cache import cache, exec_as_cache, make_key
@ -88,9 +89,7 @@ class Scanner:
logging.info("Identied %s: %s", path, raw) logging.info("Identied %s: %s", path, raw)
if raw["type"] == "movie": if raw["type"] == "movie":
movie = await self.provider.identify_movie( movie = await self.provider.identify_movie(raw["title"], raw.get("year"))
raw["title"], raw.get("year"), language=self.languages
)
movie.path = str(path) movie.path = str(path)
logging.debug("Got movie: %s", movie) logging.debug("Got movie: %s", movie)
movie_id = await self.post("movies", data=movie.to_kyoo()) movie_id = await self.post("movies", data=movie.to_kyoo())
@ -109,7 +108,6 @@ class Scanner:
episode_nbr=raw.get("episode"), episode_nbr=raw.get("episode"),
absolute=raw.get("episode") if "season" not in raw else None, absolute=raw.get("episode") if "season" not in raw else None,
year=raw.get("year"), year=raw.get("year"),
language=self.languages,
) )
episode.path = str(path) episode.path = str(path)
logging.debug("Got episode: %s", episode) logging.debug("Got episode: %s", episode)
@ -117,8 +115,7 @@ class Scanner:
if episode.season_number is not None: if episode.season_number is not None:
episode.season_id = await self.register_seasons( episode.season_id = await self.register_seasons(
show_id=episode.show_id, episode.show, episode.show_id, episode.season_number
season_number=episode.season_number,
) )
await self.post("episodes", data=episode.to_kyoo()) await self.post("episodes", data=episode.to_kyoo())
else: else:
@ -129,9 +126,7 @@ class Scanner:
async def create_collection(provider_id: str): async def create_collection(provider_id: str):
# TODO: Check if a collection with the same metadata id exists already on kyoo. # TODO: Check if a collection with the same metadata id exists already on kyoo.
new_collection = ( new_collection = (
await self.provider.identify_collection( await self.provider.identify_collection(provider_id)
provider_id, language=self.languages
)
if not any(collection.translations.keys()) if not any(collection.translations.keys())
else collection else collection
) )
@ -161,8 +156,6 @@ class Scanner:
show = ( show = (
await self.provider.identify_show( await self.provider.identify_show(
episode.show.external_id[self.provider.name].data_id, episode.show.external_id[self.provider.name].data_id,
original_language=episode.show.original_language,
language=self.languages,
) )
if isinstance(episode.show, PartialShow) if isinstance(episode.show, PartialShow)
else episode.show else episode.show
@ -194,17 +187,20 @@ class Scanner:
provider_id = episode.show.external_id[self.provider.name].data_id provider_id = episode.show.external_id[self.provider.name].data_id
return await create_show(provider_id) return await create_show(provider_id)
async def register_seasons(
self, show: Show | PartialShow, show_id: str, season_number: int
) -> str:
# We use an external season cache because we want to edit this cache programmatically # We use an external season cache because we want to edit this cache programmatically
@cache(ttl=timedelta(days=1), cache=season_cache) @cache(ttl=timedelta(days=1), cache=season_cache)
async def register_seasons(self, show_id: str, season_number: int) -> str: async def create_season(_: str, __: int):
# TODO: fetch season here. this will be useful when a new season of a show is aired after the show has been created on kyoo. season = await self.provider.identify_season(
season = Season( show.external_id[self.provider.name].data_id, season_number
season_number=season_number,
show_id=show_id,
translations={lng: SeasonTranslation() for lng in self.languages},
) )
season.show_id = show_id
return await self.post("seasons", data=season.to_kyoo()) return await self.post("seasons", data=season.to_kyoo())
return await create_season(show_id, season_number)
async def post(self, path: str, *, data: dict[str, Any]) -> str: async def post(self, path: str, *, data: dict[str, Any]) -> str:
logging.debug( logging.debug(
"Sending %s: %s", "Sending %s: %s",