diff --git a/scanner/scanner/identifiers/anilist.py b/scanner/scanner/identifiers/anilist.py index 4ce0e7f2..ee10afd3 100644 --- a/scanner/scanner/identifiers/anilist.py +++ b/scanner/scanner/identifiers/anilist.py @@ -113,7 +113,11 @@ class AnimeListData: async def get_anilist_data() -> AnimeListData: logger.info("Fetching anime-lists XML databases...") ret = AnimeListData(fetched_at=datetime.now()) - async with ClientSession() as session: + async with ClientSession( + headers={ + "User-Agent": "kyoo scanner v5", + }, + ) as session: async with session.get(AnimeTitlesDb.get_url()) as resp: resp.raise_for_status() titles = AnimeTitlesDb.from_xml(await resp.read()) diff --git a/scanner/scanner/identifiers/guess/guess.py b/scanner/scanner/identifiers/guess/guess.py index d69150b8..2b86ba63 100644 --- a/scanner/scanner/identifiers/guess/guess.py +++ b/scanner/scanner/identifiers/guess/guess.py @@ -13,6 +13,7 @@ rblk = cast(Rebulk, default_api.rebulk).rules(rules) def guessit( name: str, *, + expected_titles: list[str] = [], extra_flags: dict[str, Any] = {}, ) -> dict[str, list[Match]]: return default_api.guessit( @@ -20,6 +21,7 @@ def guessit( { "episode_prefer_number": True, "excludes": "language", + "expected_titles": expected_titles, "enforce_list": True, "advanced": True, } diff --git a/scanner/scanner/identifiers/guess/rules.py b/scanner/scanner/identifiers/guess/rules.py index 9c38c0b9..7a55a655 100644 --- a/scanner/scanner/identifiers/guess/rules.py +++ b/scanner/scanner/identifiers/guess/rules.py @@ -76,6 +76,95 @@ class UnlistTitles(Rule): return [titles, [title]] +class ExpectedTitles(Rule): + """Fix both alternate names and seasons that are known titles but parsed differently by guessit + + Example: "JoJo's Bizarre Adventure - Diamond is Unbreakable - 12.mkv" + Default: + ```json + { + "title": "JoJo's Bizarre Adventure", + "alternative_title": "Diamond is Unbreakable", + "episode": 12, + } + ``` + Expected: + ```json + { + "title": "JoJo's Bizarre Adventure - Diamond is Unbreakable", + "episode": 12, + } + ``` + + Or + Example: 'Owarimonogatari S2 E15.mkv' + Default: + ```json + { + "title": "Owarimonogatari", + "season": 2, + "episode": 15 + } + ``` + Expected: + ```json + { + "title": "Owarimonogatari S2", + "episode": 15 + } + ``` + """ + + priority = POST_PROCESS + consequence = [RemoveMatch, AppendMatch] + + @override + def when(self, matches: Matches, context) -> Any: + from ..anilist import normalize_title + + titles: list[Match] = matches.named("title", lambda m: m.tagged("title")) # type: ignore + + if not titles or not context["expected_titles"]: + return + title = titles[0] + + # Greedily collect all adjacent matches that could be part of the title + absorbed: list[Match] = [] + current = title + while True: + nmatch: list[Match] = matches.next(current) + if not nmatch or not ( + nmatch[0].tagged("title") + or nmatch[0].named("season") + or nmatch[0].named("part") + ): + break + absorbed.append(nmatch[0]) + current = nmatch[0] + if not absorbed: + return + + # Try longest combined title first, then progressively shorter ones + for end in range(len(absorbed), 0, -1): + candidate_matches = absorbed[:end] + + mtitle = f"{title.value}" + prev = title + for m in candidate_matches: + holes: list[Match] = matches.holes(prev.end, m.start) # type: ignore + hole = "".join( + f" {h.value}" if h.value != "-" else " - " for h in holes + ) + mtitle = f"{mtitle}{hole}{m.value}" + prev = m + + if normalize_title(mtitle) in context["expected_titles"]: + new_title = copy(title) + new_title.end = candidate_matches[-1].end + new_title.value = mtitle + return [[title] + candidate_matches, [new_title]] + + class MultipleSeasonRule(Rule): """Understand `abcd Season 2 - 5.mkv` as S2E5 diff --git a/scanner/scanner/identifiers/identify.py b/scanner/scanner/identifiers/identify.py index 2ac9ea96..41cc7db5 100644 --- a/scanner/scanner/identifiers/identify.py +++ b/scanner/scanner/identifiers/identify.py @@ -7,7 +7,7 @@ from typing import Callable, Literal, cast from rebulk.match import Match from ..models.videos import Guess, Video -from .anilist import identify_anilist +from .anilist import get_anilist_data, identify_anilist from .guess.guess import guessit logger = getLogger(__name__) @@ -20,7 +20,10 @@ pipeline: list[Callable[[str, Guess], Awaitable[Guess]]] = [ async def identify(path: str) -> Video: - raw = guessit(path) + raw = guessit( + path, + expected_titles=list((await get_anilist_data()).titles.keys()), + ) # guessit should only return one (according to the doc) title = raw.get("title", [])[0]