Fix absolute handling

This commit is contained in:
Zoe Roux 2024-01-08 00:23:48 +01:00
parent 6e6098b03a
commit 472718c595

View File

@ -3,7 +3,7 @@ import logging
from aiohttp import ClientSession from aiohttp import ClientSession
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Awaitable, Callable, Dict, List, Optional, Any, TypeVar from typing import Awaitable, Callable, Dict, List, Optional, Any, TypeVar
from itertools import accumulate from itertools import accumulate, zip_longest
from providers.idmapper import IdMapper from providers.idmapper import IdMapper
from providers.implementations.thexem import TheXem from providers.implementations.thexem import TheXem
@ -344,8 +344,13 @@ class TheMovieDatabase(Provider):
ret = await self.process_translations( ret = await self.process_translations(
for_language, languages, merge_seasons_translations for_language, languages, merge_seasons_translations
) )
if ret.original_language is not None and ret.original_language not in ret.translations: if (
ret.translations[ret.original_language] = (await for_language(ret.original_language)).translations[ret.original_language] ret.original_language is not None
and ret.original_language not in ret.translations
):
ret.translations[ret.original_language] = (
await for_language(ret.original_language)
).translations[ret.original_language]
# If we have more external_ids freely available, add them. # If we have more external_ids freely available, add them.
ret.external_id = await self._idmapper.get_show(ret.external_id) ret.external_id = await self._idmapper.get_show(ret.external_id)
return ret return ret
@ -383,7 +388,12 @@ class TheMovieDatabase(Provider):
async def identify_season(self, show_id: str, season_number: int) -> Season: async def identify_season(self, show_id: str, season_number: int) -> Season:
# We already get seasons info in the identify_show and chances are this gets cached already # We already get seasons info in the identify_show and chances are this gets cached already
show = await self.identify_show(show_id) show = await self.identify_show(show_id)
return show.seasons[season_number] ret = next((x for x in show.seasons if x.season_number == season_number), None)
if ret is None:
raise ProviderError(
f"Could not find season {season_number} for show {show.to_kyoo()['name']}"
)
return ret
@cache(ttl=timedelta(days=1)) @cache(ttl=timedelta(days=1))
async def search_show(self, name: str, year: Optional[int]) -> PartialShow: async def search_show(self, name: str, year: Optional[int]) -> PartialShow:
@ -467,11 +477,7 @@ class TheMovieDatabase(Provider):
if season is None or episode_nbr is None: if season is None or episode_nbr is None:
raise ProviderError( raise ProviderError(
"Could not guess season or episode number of the episode %s %d-%d (%d)", f"Could not guess season or episode number of the episode {name} {season}-{episode_nbr} ({absolute})",
name,
season,
episode_nbr,
absolute,
) )
if absolute is None: if absolute is None:
@ -596,22 +602,36 @@ class TheMovieDatabase(Provider):
return (season, episode_nbr) return (season, episode_nbr)
# We assume that each season should be played in order with no special episodes. # We assume that each season should be played in order with no special episodes.
show = await self.identify_show(show_id) show = await self.identify_show(show_id)
seasons = [x.episodes_count for x in show.seasons] # Dont forget to ingore the special season (season_number 0)
# enumerate(accumulate(season)) return [(0, 12), (1, 24)] if the show has two seasons with 12 eps seasons_nbrs = [x.season_number for x in show.seasons if x.season_number != 0]
seasons_eps = [x.episodes_count for x in show.seasons if x.season_number != 0]
# zip_longest(seasons_nbrs[1:], accumulate(seasons_eps)) return [(2, 12), (None, 24)] if the show has two seasons with 12 eps
# we take the last group that has less total episodes than the absolute number. # we take the last group that has less total episodes than the absolute number.
(season, total_ep_count) = next( return next(
(snbr, ep_cnt) (
for snbr, ep_cnt in reversed(list(enumerate(accumulate(seasons)))) (snbr, absolute - ep_cnt)
if ep_cnt <= absolute for snbr, ep_cnt in reversed(
list(zip_longest(seasons_nbrs[1:], accumulate(seasons_eps)))
)
if ep_cnt < absolute
),
# If the absolute episode number is lower than the 1st season number of episode, it is part of it.
(seasons_nbrs[0], absolute),
) )
return (show.seasons[season].season_number, absolute - total_ep_count)
async def get_absolute_number(self, show_id: str, season: int, episode_nbr: int): async def get_absolute_number(self, show_id: str, season: int, episode_nbr: int):
absgrp = await self.get_absolute_order(show_id) absgrp = await self.get_absolute_order(show_id)
if absgrp is None: if absgrp is None:
# We assume that each season should be played in order with no special episodes. # We assume that each season should be played in order with no special episodes.
show = await self.identify_show(show_id) show = await self.identify_show(show_id)
return sum(x.episodes_count for x in show.seasons[:season]) + episode_nbr return (
sum(
x.episodes_count
for x in show.seasons
if 0 < x.season_number < season
)
+ episode_nbr
)
return next( return next(
( (
# The + 1 is to go from 0based index to 1based absolute number # The + 1 is to go from 0based index to 1based absolute number