diff --git a/scanner/providers/implementations/anilist.py b/scanner/providers/implementations/anilist.py index 275658f4..8e250885 100644 --- a/scanner/providers/implementations/anilist.py +++ b/scanner/providers/implementations/anilist.py @@ -1,9 +1,8 @@ import asyncio from aiohttp import ClientSession -from datetime import date +from datetime import date, timedelta from logging import getLogger -from typing import Awaitable, Callable, Dict, List, Optional, Any, TypeVar -from itertools import accumulate, zip_longest +from typing import Optional from providers.utils import ProviderError from matcher.cache import cache @@ -56,24 +55,30 @@ class AniList(Provider): return "anilist" async def get(self, query: str, not_found: str, **variables: Optional[str | int]): - logger.error(variables) - async with self._client.post( - self.base, - json={ - "query": query, - "variables": {k: v for (k, v) in variables.items() if v is not None}, - }, - ) as r: - if r.status == 404: - raise ProviderError(not_found) - ret = await r.json() - logger.error(ret) - r.raise_for_status() - if "errors" in ret: + while True: + async with self._client.post( + self.base, + json={ + "query": query, + "variables": { + k: v for (k, v) in variables.items() if v is not None + }, + }, + ) as r: + if r.status == 404: + raise ProviderError(not_found) + if r.status == 429: + await asyncio.sleep(float(r.headers["Retry-After"])) + continue + ret = await r.json() logger.error(ret) - raise Exception(ret["errors"]) - return ret["data"] + r.raise_for_status() + if "errors" in ret: + logger.error(ret) + raise Exception(ret["errors"]) + return ret["data"] + @cache(ttl=timedelta(days=1)) async def query_anime( self, *, @@ -94,6 +99,7 @@ class AniList(Provider): } description(asHtml: false) status + episodes startDate { year month @@ -153,7 +159,7 @@ class AniList(Provider): not_found=f"Could not find the show {id or ''}{search or ''}", ) ret = q["Media"] - return Show( + show = Show( translations={ "en": ShowTranslation( name=ret["title"]["romaji"], @@ -217,7 +223,27 @@ class AniList(Provider): }, seasons=[], ) + show.seasons.append( + Season( + # TODO: fill this approprietly + season_number=1, + episodes_count=ret["episodes"], + start_air=show.start_air, + end_air=show.end_air, + external_id=show.external_id, + translations={ + "en": SeasonTranslation( + name=show.translations["en"].name, + overview=show.translations["en"].overview, + posters=show.translations["en"].posters, + thumbnails=[], + ) + }, + ) + ) + return show + @cache(ttl=timedelta(days=1)) async def query_movie( self, *, @@ -378,7 +404,8 @@ class AniList(Provider): return await self.query_anime(id=show_id) async def identify_season(self, show_id: str, season: int) -> Season: - raise NotImplementedError + show = await self.query_anime(id=show_id) + return next((x for x in show.seasons if x.season_number == season)) async def identify_episode( self, show_id: str, season: Optional[int], episode_nbr: int, absolute: int diff --git a/scanner/providers/implementations/themoviedatabase.py b/scanner/providers/implementations/themoviedatabase.py index 3cf99a88..e48bd3f2 100644 --- a/scanner/providers/implementations/themoviedatabase.py +++ b/scanner/providers/implementations/themoviedatabase.py @@ -2,7 +2,7 @@ import asyncio from aiohttp import ClientSession from datetime import datetime, timedelta from logging import getLogger -from typing import Awaitable, Callable, Dict, List, Optional, Any, TypeVar +from typing import cast, Awaitable, Callable, Dict, List, Optional, Any, TypeVar from itertools import accumulate, zip_longest from providers.utils import ProviderError @@ -635,7 +635,9 @@ class TheMovieDatabase(Provider): show = await self.identify_show(show_id) # Dont forget to ingore the special season (season_number 0) seasons_nbrs = [x.season_number for x in show.seasons if x.season_number != 0] - seasons_eps = [x.episodes_count for x in show.seasons if x.season_number != 0] + seasons_eps = [ + cast(int, x.episodes_count) for x in show.seasons if x.season_number != 0 + ] if not any(seasons_nbrs): return (None, None) @@ -663,7 +665,7 @@ class TheMovieDatabase(Provider): show = await self.identify_show(show_id) return ( sum( - x.episodes_count + cast(int, x.episodes_count) for x in show.seasons if 0 < x.season_number < season ) diff --git a/scanner/providers/types/season.py b/scanner/providers/types/season.py index 0c224ece..a2568ff9 100644 --- a/scanner/providers/types/season.py +++ b/scanner/providers/types/season.py @@ -19,7 +19,7 @@ class Season: season_number: int # This is not used by kyoo, this is just used internaly by the TMDB provider. # maybe this should be moved? - episodes_count: int + episodes_count: Optional[int] start_air: Optional[date | int] = None end_air: Optional[date | int] = None external_id: dict[str, MetadataID] = field(default_factory=dict)