From 53867792f331084543b55ce1e8ac8b8e6cba9543 Mon Sep 17 00:00:00 2001 From: Zoe Roux Date: Sat, 10 May 2025 02:06:12 +0200 Subject: [PATCH] Switch to asyncpg & create initial listener --- scanner/requirements.txt | 3 +- scanner/scanner/__init__.py | 30 ++-- scanner/scanner/client.py | 8 +- scanner/scanner/providers/composite.py | 4 +- scanner/scanner/providers/themoviedatabase.py | 8 +- scanner/scanner/requests.py | 134 ++++++++---------- scanner/shell.nix | 3 +- 7 files changed, 100 insertions(+), 90 deletions(-) diff --git a/scanner/requirements.txt b/scanner/requirements.txt index 30e94a49..0575ee6c 100644 --- a/scanner/requirements.txt +++ b/scanner/requirements.txt @@ -4,5 +4,4 @@ guessit@git+https://github.com/zoriya/guessit aiohttp watchfiles langcodes -psycopg[binary,pool] - +asyncpg diff --git a/scanner/scanner/__init__.py b/scanner/scanner/__init__.py index 7798f60e..58013167 100644 --- a/scanner/scanner/__init__.py +++ b/scanner/scanner/__init__.py @@ -1,27 +1,37 @@ import logging from contextlib import asynccontextmanager +import asyncpg from fastapi import FastAPI -from psycopg import AsyncConnection -from psycopg_pool import AsyncConnectionPool + +from .client import KyooClient +from .providers.composite import CompositeProvider +from .providers.themoviedatabase import TheMovieDatabase +from .requests import RequestProcessor logging.basicConfig(level=logging.INFO) logging.getLogger("watchfiles").setLevel(logging.WARNING) logging.getLogger("rebulk").setLevel(logging.WARNING) -pool = AsyncConnectionPool(open=False, kwargs={"autocommit": True}) - @asynccontextmanager async def lifetime(): - await pool.open() - yield - await pool.close() + async with ( + await asyncpg.create_pool() as pool, + create_request_processor(pool) as processor, + ): + await processor.listen_for_requests() + yield -async def get_db() -> AsyncConnection: - async with pool.connection() as ret: - yield ret +@asynccontextmanager +async def create_request_processor(pool: asyncpg.Pool): + async with ( + pool.acquire() as db, + KyooClient() as client, + TheMovieDatabase() as themoviedb, + ): + yield RequestProcessor(db, client, CompositeProvider(themoviedb)) app = FastAPI( diff --git a/scanner/scanner/client.py b/scanner/scanner/client.py index e211dcec..9478fd2a 100644 --- a/scanner/scanner/client.py +++ b/scanner/scanner/client.py @@ -1,5 +1,6 @@ import os from logging import getLogger +from types import TracebackType from aiohttp import ClientSession @@ -27,7 +28,12 @@ class KyooClient: async def __aenter__(self): return self - async def __aexit__(self): + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ): await self._client.close() async def get_videos_info(self) -> VideoInfo: diff --git a/scanner/scanner/providers/composite.py b/scanner/scanner/providers/composite.py index 5f0e2da5..dc52af8d 100644 --- a/scanner/scanner/providers/composite.py +++ b/scanner/scanner/providers/composite.py @@ -8,9 +8,9 @@ from .provider import Provider class CompositeProvider(Provider): - def __init__(self): + def __init__(self, themoviedb: Provider): self._tvdb: Provider = None # type: ignore - self._themoviedb: Provider = None # type: ignore + self._themoviedb = themoviedb @property @override diff --git a/scanner/scanner/providers/themoviedatabase.py b/scanner/scanner/providers/themoviedatabase.py index ddb33c09..66a1467f 100644 --- a/scanner/scanner/providers/themoviedatabase.py +++ b/scanner/scanner/providers/themoviedatabase.py @@ -4,6 +4,7 @@ from collections.abc import Generator from datetime import datetime from logging import getLogger from statistics import mean +from types import TracebackType from typing import Any, cast, override from aiohttp import ClientSession @@ -71,7 +72,12 @@ class TheMovieDatabase(Provider): async def __aenter__(self): return self - async def __aexit__(self): + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ): await self._client.close() @property diff --git a/scanner/scanner/requests.py b/scanner/scanner/requests.py index acb5abb8..41789f82 100644 --- a/scanner/scanner/requests.py +++ b/scanner/scanner/requests.py @@ -3,8 +3,7 @@ from __future__ import annotations from logging import getLogger from typing import Literal -from psycopg import AsyncConnection -from psycopg.rows import class_row, dict_row +from asyncpg import Connection from pydantic import Field from .client import KyooClient @@ -31,7 +30,7 @@ class Request(Model, extra="allow"): class RequestProcessor: def __init__( self, - database: AsyncConnection, + database: Connection, client: KyooClient, providers: CompositeProvider, ): @@ -40,85 +39,76 @@ class RequestProcessor: self._providers = providers async def enqueue(self, requests: list[Request]): - async with self._database.cursor() as cur: - await cur.executemany( + await self._database.executemany( + """ + insert into scanner.requests(kind, title, year, external_id, videos) + values (%(kind)s, %(title) s, %(year)s, %(external_id)s, %(videos)s) + on conflict (kind, title, year) + do update set + videos = videos || excluded.videos + """, + (x.model_dump() for x in requests), + ) + _ = await self._database.execute("notify scanner.requests") + + async def listen_for_requests(self): + logger.info("Listening for requestes") + await self._database.add_listener("scanner.requests", self.process_request) + + async def process_request(self): + cur = await self._database.fetchrow( + """ + update + scanner.requests + set + status = 'running', + started_at = nom()::timestamptz + where + pk in ( + select + * + from + scanner.requests + where + status = 'pending' + limit 1 + for update + skip locked) + returning + * + """ + ) + if cur is None: + return + request = Request.model_validate(cur) + + logger.info(f"Starting to process {request.title}") + try: + show = await self._run_request(request) + finished = await self._database.fetchrow( """ - insert into scanner.requests(kind, title, year, external_id, videos) - values (%(kind)s, %(title) s, %(year)s, %(external_id)s, %(videos)s) - on conflict (kind, title, year) - do update set - videos = videos || excluded.videos + delete from scanner.requests + where pk = %s + returning + videos """, - (x.model_dump() for x in requests), + [request.pk], ) - # TODO: how will this conflict be handled if the request is already running - if cur.rowcount > 0: - _ = await cur.execute("notify scanner.requests") - - async def process_requests(self): - async with await self._database.execute("listen scanner.requests"): - gen = self._database.notifies() - async for _ in gen: - await self._process_request() - - async def _process_request(self): - async with self._database.cursor(row_factory=class_row(Request)) as cur: + if finished and finished["videos"] != request.videos: + await self._client.link_videos(show.slug, finished["videos"]) + except Exception as e: + logger.error("Couldn't process request", exc_info=e) cur = await cur.execute( """ update scanner.requests set - status = 'running', - started_at = nom()::timestamptz + status = 'failed' where - pk in ( - select - * - from - scanner.requests - where - status = 'pending' - limit 1 - for update - skip locked) - returning - * - """ + pk = %s + """, + [request.pk], ) - request = await cur.fetchone() - if request is None: - return - - logger.info(f"Starting to process {request.title}") - try: - show = await self._run_request(request) - - async with self._database.cursor(row_factory=dict_row) as cur: - cur = await cur.execute( - """ - delete from scanner.requests - where pk = %s - returning - videos - """, - [request.pk], - ) - finished = await anext(cur) - if finished["videos"] != request.videos: - await self._client.link_videos(show.slug, finished["videos"]) - except Exception as e: - logger.error("Couldn't process request", exc_info=e) - cur = await cur.execute( - """ - update - scanner.requests - set - status = 'failed' - where - pk = %s - """, - [request.pk], - ) async def _run_request(self, request: Request) -> Resource: if request.kind == "movie": diff --git a/scanner/shell.nix b/scanner/shell.nix index 0b62e6c8..816ec910 100644 --- a/scanner/shell.nix +++ b/scanner/shell.nix @@ -7,8 +7,7 @@ aiohttp watchfiles langcodes - psycopg - psycopg-pool + asyncpg ]); in pkgs.mkShell {