Switch to asyncpg & create initial listener

This commit is contained in:
Zoe Roux 2025-05-10 02:06:12 +02:00
parent 479e91e2e2
commit 7098a8326d
No known key found for this signature in database
7 changed files with 100 additions and 90 deletions

View File

@ -4,5 +4,4 @@ guessit@git+https://github.com/zoriya/guessit
aiohttp aiohttp
watchfiles watchfiles
langcodes langcodes
psycopg[binary,pool] asyncpg

View File

@ -1,27 +1,37 @@
import logging import logging
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import asyncpg
from fastapi import FastAPI from fastapi import FastAPI
from psycopg import AsyncConnection
from psycopg_pool import AsyncConnectionPool from .client import KyooClient
from .providers.composite import CompositeProvider
from .providers.themoviedatabase import TheMovieDatabase
from .requests import RequestProcessor
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logging.getLogger("watchfiles").setLevel(logging.WARNING) logging.getLogger("watchfiles").setLevel(logging.WARNING)
logging.getLogger("rebulk").setLevel(logging.WARNING) logging.getLogger("rebulk").setLevel(logging.WARNING)
pool = AsyncConnectionPool(open=False, kwargs={"autocommit": True})
@asynccontextmanager @asynccontextmanager
async def lifetime(): async def lifetime():
await pool.open() async with (
await asyncpg.create_pool() as pool,
create_request_processor(pool) as processor,
):
await processor.listen_for_requests()
yield yield
await pool.close()
async def get_db() -> AsyncConnection: @asynccontextmanager
async with pool.connection() as ret: async def create_request_processor(pool: asyncpg.Pool):
yield ret async with (
pool.acquire() as db,
KyooClient() as client,
TheMovieDatabase() as themoviedb,
):
yield RequestProcessor(db, client, CompositeProvider(themoviedb))
app = FastAPI( app = FastAPI(

View File

@ -1,5 +1,6 @@
import os import os
from logging import getLogger from logging import getLogger
from types import TracebackType
from aiohttp import ClientSession from aiohttp import ClientSession
@ -27,7 +28,12 @@ class KyooClient:
async def __aenter__(self): async def __aenter__(self):
return self return self
async def __aexit__(self): async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
):
await self._client.close() await self._client.close()
async def get_videos_info(self) -> VideoInfo: async def get_videos_info(self) -> VideoInfo:

View File

@ -8,9 +8,9 @@ from .provider import Provider
class CompositeProvider(Provider): class CompositeProvider(Provider):
def __init__(self): def __init__(self, themoviedb: Provider):
self._tvdb: Provider = None # type: ignore self._tvdb: Provider = None # type: ignore
self._themoviedb: Provider = None # type: ignore self._themoviedb = themoviedb
@property @property
@override @override

View File

@ -4,6 +4,7 @@ from collections.abc import Generator
from datetime import datetime from datetime import datetime
from logging import getLogger from logging import getLogger
from statistics import mean from statistics import mean
from types import TracebackType
from typing import Any, cast, override from typing import Any, cast, override
from aiohttp import ClientSession from aiohttp import ClientSession
@ -71,7 +72,12 @@ class TheMovieDatabase(Provider):
async def __aenter__(self): async def __aenter__(self):
return self return self
async def __aexit__(self): async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
):
await self._client.close() await self._client.close()
@property @property

View File

@ -3,8 +3,7 @@ from __future__ import annotations
from logging import getLogger from logging import getLogger
from typing import Literal from typing import Literal
from psycopg import AsyncConnection from asyncpg import Connection
from psycopg.rows import class_row, dict_row
from pydantic import Field from pydantic import Field
from .client import KyooClient from .client import KyooClient
@ -31,7 +30,7 @@ class Request(Model, extra="allow"):
class RequestProcessor: class RequestProcessor:
def __init__( def __init__(
self, self,
database: AsyncConnection, database: Connection,
client: KyooClient, client: KyooClient,
providers: CompositeProvider, providers: CompositeProvider,
): ):
@ -40,8 +39,7 @@ class RequestProcessor:
self._providers = providers self._providers = providers
async def enqueue(self, requests: list[Request]): async def enqueue(self, requests: list[Request]):
async with self._database.cursor() as cur: await self._database.executemany(
await cur.executemany(
""" """
insert into scanner.requests(kind, title, year, external_id, videos) insert into scanner.requests(kind, title, year, external_id, videos)
values (%(kind)s, %(title) s, %(year)s, %(external_id)s, %(videos)s) values (%(kind)s, %(title) s, %(year)s, %(external_id)s, %(videos)s)
@ -51,19 +49,14 @@ class RequestProcessor:
""", """,
(x.model_dump() for x in requests), (x.model_dump() for x in requests),
) )
# TODO: how will this conflict be handled if the request is already running _ = await self._database.execute("notify scanner.requests")
if cur.rowcount > 0:
_ = await cur.execute("notify scanner.requests")
async def process_requests(self): async def listen_for_requests(self):
async with await self._database.execute("listen scanner.requests"): logger.info("Listening for requestes")
gen = self._database.notifies() await self._database.add_listener("scanner.requests", self.process_request)
async for _ in gen:
await self._process_request()
async def _process_request(self): async def process_request(self):
async with self._database.cursor(row_factory=class_row(Request)) as cur: cur = await self._database.fetchrow(
cur = await cur.execute(
""" """
update update
scanner.requests scanner.requests
@ -85,16 +78,14 @@ class RequestProcessor:
* *
""" """
) )
request = await cur.fetchone() if cur is None:
if request is None:
return return
request = Request.model_validate(cur)
logger.info(f"Starting to process {request.title}") logger.info(f"Starting to process {request.title}")
try: try:
show = await self._run_request(request) show = await self._run_request(request)
finished = await self._database.fetchrow(
async with self._database.cursor(row_factory=dict_row) as cur:
cur = await cur.execute(
""" """
delete from scanner.requests delete from scanner.requests
where pk = %s where pk = %s
@ -103,8 +94,7 @@ class RequestProcessor:
""", """,
[request.pk], [request.pk],
) )
finished = await anext(cur) if finished and finished["videos"] != request.videos:
if finished["videos"] != request.videos:
await self._client.link_videos(show.slug, finished["videos"]) await self._client.link_videos(show.slug, finished["videos"])
except Exception as e: except Exception as e:
logger.error("Couldn't process request", exc_info=e) logger.error("Couldn't process request", exc_info=e)

View File

@ -7,8 +7,7 @@
aiohttp aiohttp
watchfiles watchfiles
langcodes langcodes
psycopg asyncpg
psycopg-pool
]); ]);
in in
pkgs.mkShell { pkgs.mkShell {