Add seasons metadata

This commit is contained in:
Zoe Roux 2023-03-29 02:25:41 +09:00
parent 75fb4b5809
commit 7388719cad
5 changed files with 132 additions and 57 deletions

View File

@ -107,7 +107,7 @@ namespace Kyoo.Core.Controllers
{ {
await base.Create(obj); await base.Create(obj);
_database.Entry(obj).State = EntityState.Added; _database.Entry(obj).State = EntityState.Added;
await _database.SaveChangesAsync(() => Get(obj.Slug)); await _database.SaveChangesAsync(() => Get(obj.ShowID, obj.SeasonNumber));
return obj; return obj;
} }

View File

@ -4,6 +4,8 @@ from aiohttp import ClientSession
from datetime import datetime from datetime import datetime
from typing import Awaitable, Callable, Dict, Optional, Any, TypeVar from typing import Awaitable, Callable, Dict, Optional, Any, TypeVar
from providers.utils import ProviderError
from ..provider import Provider from ..provider import Provider
from ..types.movie import Movie, MovieTranslation, Status as MovieStatus from ..types.movie import Movie, MovieTranslation, Status as MovieStatus
from ..types.season import Season, SeasonTranslation from ..types.season import Season, SeasonTranslation
@ -97,9 +99,12 @@ class TheMovieDatabase(Provider):
async def identify_movie( async def identify_movie(
self, name: str, year: Optional[int], *, language: list[str] self, name: str, year: Optional[int], *, language: list[str]
) -> Movie: ) -> Movie:
search = (await self.get("search/movie", params={"query": name, "year": year}))[ search_results = (
"results" await self.get("search/movie", params={"query": name, "year": year})
][0] )["results"]
if len(search_results) == 0:
raise ProviderError(f"No result for a movie named: {name}")
search = search_results[0]
movie_id = search["id"] movie_id = search["id"]
if search["original_language"] not in language: if search["original_language"] not in language:
language.append(search["original_language"]) language.append(search["original_language"])
@ -118,9 +123,9 @@ class TheMovieDatabase(Provider):
ret = Movie( ret = Movie(
original_language=movie["original_language"], original_language=movie["original_language"],
aliases=[x["title"] for x in movie["alternative_titles"]["titles"]], aliases=[x["title"] for x in movie["alternative_titles"]["titles"]],
release_date=datetime.strptime( release_date=datetime.strptime(movie["release_date"], "%Y-%m-%d").date()
movie["release_date"], "%Y-%m-%d" if movie["release_date"]
).date(), else None,
status=MovieStatus.FINISHED status=MovieStatus.FINISHED
if movie["status"] == "Released" if movie["status"] == "Released"
else MovieStatus.PLANNED, else MovieStatus.PLANNED,
@ -184,8 +189,12 @@ class TheMovieDatabase(Provider):
ret = Show( ret = Show(
original_language=show["original_language"], original_language=show["original_language"],
aliases=[x["title"] for x in show["alternative_titles"]["results"]], aliases=[x["title"] for x in show["alternative_titles"]["results"]],
start_air=datetime.strptime(show["first_air_date"], "%Y-%m-%d").date(), start_air=datetime.strptime(show["first_air_date"], "%Y-%m-%d").date()
end_air=datetime.strptime(show["last_air_date"], "%Y-%m-%d").date(), if show["first_air_date"]
else None,
end_air=datetime.strptime(show["last_air_date"], "%Y-%m-%d").date()
if show["last_air_date"]
else None,
status=ShowStatus.FINISHED status=ShowStatus.FINISHED
if show["status"] == "Released" if show["status"] == "Released"
else ShowStatus.AIRING else ShowStatus.AIRING
@ -258,7 +267,9 @@ class TheMovieDatabase(Provider):
) -> Season: ) -> Season:
return Season( return Season(
season_number=season["season_number"], season_number=season["season_number"],
start_air=datetime.strptime(season["air_date"], "%Y-%m-%d").date(), start_air=datetime.strptime(season["air_date"], "%Y-%m-%d").date()
if season["air_date"]
else None,
end_air=None, end_air=None,
external_ids={ external_ids={
self.name: MetadataID( self.name: MetadataID(
@ -289,14 +300,19 @@ class TheMovieDatabase(Provider):
*, *,
language: list[str], language: list[str],
) -> Episode: ) -> Episode:
search = (await self.get("search/tv", params={"query": name}))["results"][0] search_results = (await self.get("search/tv", params={"query": name}))[
"results"
]
if len(search_results) == 0:
raise ProviderError(f"No result for a tv show named: {name}")
search = search_results[0]
show_id = search["id"] show_id = search["id"]
if search["original_language"] not in language: if search["original_language"] not in language:
language.append(search["original_language"]) language.append(search["original_language"])
# TODO: Handle absolute episodes # TODO: Handle absolute episodes
if not season or not episode_nbr: if not season or not episode_nbr:
raise NotImplementedError( raise ProviderError(
"Absolute order episodes not implemented for the movie database" "Absolute order episodes not implemented for the movie database"
) )
@ -323,7 +339,9 @@ class TheMovieDatabase(Provider):
episode_number=episode["episode_number"], episode_number=episode["episode_number"],
# TODO: absolute numbers # TODO: absolute numbers
absolute_number=None, absolute_number=None,
release_date=datetime.strptime(episode["air_date"], "%Y-%m-%d").date(), release_date=datetime.strptime(episode["air_date"], "%Y-%m-%d").date()
if episode["air_date"]
else None,
thumbnail=f"https://image.tmdb.org/t/p/original{episode['poster_path']}" thumbnail=f"https://image.tmdb.org/t/p/original{episode['poster_path']}"
if "poster_path" in episode if "poster_path" in episode
else None, else None,

View File

@ -9,18 +9,18 @@ from .metadataid import MetadataID
@dataclass @dataclass
class SeasonTranslation: class SeasonTranslation:
name: Optional[str] name: Optional[str] = None
overview: Optional[str] overview: Optional[str] = None
posters: list[str] posters: list[str] = field(default_factory=list)
thumbnails: list[str] thumbnails: list[str] = field(default_factory=list)
@dataclass @dataclass
class Season: class Season:
season_number: int season_number: int
start_air: Optional[date | int] start_air: Optional[date | int] = None
end_air: Optional[date | int] end_air: Optional[date | int] = None
external_ids: dict[str, MetadataID] external_ids: dict[str, MetadataID] = field(default_factory=dict)
show_id: Optional[str] = None show_id: Optional[str] = None
translations: dict[str, SeasonTranslation] = field(default_factory=dict) translations: dict[str, SeasonTranslation] = field(default_factory=dict)

View File

@ -7,3 +7,7 @@ def format_date(date: date | int | None) -> str | None:
if isinstance(date, int): if isinstance(date, int):
return f"{date}-01-01" return f"{date}-01-01"
return date.isoformat() return date.isoformat()
class ProviderError(RuntimeError):
def __init__(self, *args: object) -> None:
super().__init__(*args)

View File

@ -8,7 +8,8 @@ from pathlib import Path
from guessit import guessit from guessit import guessit
from providers.provider import Provider from providers.provider import Provider
from providers.types.episode import Episode, PartialShow from providers.types.episode import Episode, PartialShow
from providers.types.season import Season from providers.types.season import Season, SeasonTranslation
from providers.utils import ProviderError
def log_errors(f): def log_errors(f):
@ -16,12 +17,64 @@ def log_errors(f):
async def internal(*args, **kwargs): async def internal(*args, **kwargs):
try: try:
await f(*args, **kwargs) await f(*args, **kwargs)
except ProviderError as e:
logging.error(str(e))
except Exception as e: except Exception as e:
logging.exception("Unhandled error", exc_info=e) logging.exception("Unhandled error", exc_info=e)
return internal return internal
cache = {}
def provider_cache(*args):
ic = cache
for arg in args:
if arg not in ic:
ic[arg] = {}
ic = ic[arg]
def wrapper(f):
@wraps(f)
async def internal(*args, **kwargs):
nonlocal ic
for arg in args:
if arg not in ic:
ic[arg] = {}
ic = ic[arg]
if "event" in ic:
await ic["event"].wait()
if not ic["ret"]:
raise ProviderError("Cache miss. Another error should exist")
return ic["ret"]
ic["event"] = asyncio.Event()
try:
ret = await f(*args, **kwargs)
ic["ret"] = ret
except:
ic["event"].set()
raise
ic["event"].set()
return ret
return internal
return wrapper
def set_in_cache(key: list[str | int]):
ic = cache
for arg in key:
if arg not in ic:
ic[arg] = {}
ic = ic[arg]
evt = asyncio.Event()
evt.set()
ic["event"] = evt
class Scanner: class Scanner:
def __init__( def __init__(
self, client: ClientSession, *, languages: list[str], api_key: str self, client: ClientSession, *, languages: list[str], api_key: str
@ -30,7 +83,7 @@ class Scanner:
self._api_key = api_key self._api_key = api_key
self._url = os.environ.get("KYOO_URL", "http://back:5000") self._url = os.environ.get("KYOO_URL", "http://back:5000")
self.provider = Provider.get_all(client)[0] self.provider = Provider.get_all(client)[0]
self.cache = {"shows": {}} self.cache = {"shows": {}, "seasons": {}}
self.languages = languages self.languages = languages
async def scan(self, path: str): async def scan(self, path: str):
@ -50,7 +103,6 @@ class Scanner:
return True return True
return False return False
@log_errors @log_errors
async def identify(self, path: Path): async def identify(self, path: Path):
if await self.is_registered(path): if await self.is_registered(path):
@ -58,7 +110,6 @@ class Scanner:
raw = guessit(path, "--episode-prefer-number") raw = guessit(path, "--episode-prefer-number")
logging.info("Identied %s: %s", path, raw) logging.info("Identied %s: %s", path, raw)
# TODO: check if episode/movie already exists in kyoo and skip if it does.
# TODO: Add collections support # TODO: Add collections support
if raw["type"] == "movie": if raw["type"] == "movie":
movie = await self.provider.identify_movie( movie = await self.provider.identify_movie(
@ -78,47 +129,49 @@ class Scanner:
episode.path = str(path) episode.path = str(path)
logging.debug("Got episode: %s", episode) logging.debug("Got episode: %s", episode)
episode.show_id = await self.create_or_get_show(episode) episode.show_id = await self.create_or_get_show(episode)
# TODO: Do the same things for seasons and wait for them to be created on the api (else the episode creation will fail)
if episode.season_number is not None:
await self.register_seasons(
show_id=episode.show_id,
season_number=episode.season_number,
)
await self.post("episodes", data=episode.to_kyoo()) await self.post("episodes", data=episode.to_kyoo())
else: else:
logging.warn("Unknown video file type: %s", raw["type"]) logging.warn("Unknown video file type: %s", raw["type"])
async def create_or_get_show(self, episode: Episode) -> str: async def create_or_get_show(self, episode: Episode) -> str:
provider_id = episode.show.external_ids[self.provider.name].id @provider_cache("shows")
if provider_id in self.cache["shows"]: async def create_show(_: str):
ret = self.cache["shows"][provider_id] # TODO: Check if a show with the same metadata id exists already on kyoo.
await ret["event"].wait() show = (
if not ret["id"]: await self.provider.identify_show(episode.show, language=self.languages)
raise RuntimeError("Provider failed to create the show") if isinstance(episode.show, PartialShow)
return ret["id"] else episode.show
)
self.cache["shows"][provider_id] = {"id": None, "event": asyncio.Event()} logging.debug("Got show: %s", episode)
# TODO: Check if a show with the same metadata id exists already on kyoo.
show = (
await self.provider.identify_show(episode.show, language=self.languages)
if isinstance(episode.show, PartialShow)
else episode.show
)
logging.debug("Got show: %s", episode)
try:
ret = await self.post("show", data=show.to_kyoo()) ret = await self.post("show", data=show.to_kyoo())
except: try:
# Allow tasks waiting for this show to bail out. for season in show.seasons:
self.cache["shows"][provider_id]["event"].set() season.show_id = ret
raise await self.post("seasons", data=season.to_kyoo())
self.cache["shows"][provider_id]["id"] = ret set_in_cache(key=["seasons", ret, season.season_number])
self.cache["shows"][provider_id]["event"].set() except Exception as e:
logging.exception("Unhandled error create a season", exc_info=e)
return ret
# TODO: Better handling of seasons registrations (maybe a lock also) # The parameter is only used as a key for the cache.
await self.register_seasons(ret, show.seasons) provider_id = episode.show.external_ids[self.provider.name].id
return ret return await create_show(provider_id)
async def register_seasons(self, show_id: str, seasons: list[Season]): @provider_cache("seasons")
for season in seasons: async def register_seasons(self, show_id: str, season_number: int):
season.show_id = show_id # TODO: fetch season here. this will be useful when a new season of a show is aired after the show has been created on kyoo.
await self.post("seasons", data=season.to_kyoo()) season = Season(
season_number=season_number,
show_id=show_id,
translations={lng: SeasonTranslation() for lng in self.languages},
)
await self.post("seasons", data=season.to_kyoo())
async def post(self, path: str, *, data: object) -> str: async def post(self, path: str, *, data: object) -> str:
logging.debug( logging.debug(