Switch to asyncpg & create initial listener

This commit is contained in:
Zoe Roux 2025-05-10 02:06:12 +02:00
parent 479e91e2e2
commit 7098a8326d
No known key found for this signature in database
7 changed files with 100 additions and 90 deletions

View File

@ -4,5 +4,4 @@ guessit@git+https://github.com/zoriya/guessit
aiohttp aiohttp
watchfiles watchfiles
langcodes langcodes
psycopg[binary,pool] asyncpg

View File

@ -1,27 +1,37 @@
import logging import logging
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import asyncpg
from fastapi import FastAPI from fastapi import FastAPI
from psycopg import AsyncConnection
from psycopg_pool import AsyncConnectionPool from .client import KyooClient
from .providers.composite import CompositeProvider
from .providers.themoviedatabase import TheMovieDatabase
from .requests import RequestProcessor
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logging.getLogger("watchfiles").setLevel(logging.WARNING) logging.getLogger("watchfiles").setLevel(logging.WARNING)
logging.getLogger("rebulk").setLevel(logging.WARNING) logging.getLogger("rebulk").setLevel(logging.WARNING)
pool = AsyncConnectionPool(open=False, kwargs={"autocommit": True})
@asynccontextmanager @asynccontextmanager
async def lifetime(): async def lifetime():
await pool.open() async with (
yield await asyncpg.create_pool() as pool,
await pool.close() create_request_processor(pool) as processor,
):
await processor.listen_for_requests()
yield
async def get_db() -> AsyncConnection: @asynccontextmanager
async with pool.connection() as ret: async def create_request_processor(pool: asyncpg.Pool):
yield ret async with (
pool.acquire() as db,
KyooClient() as client,
TheMovieDatabase() as themoviedb,
):
yield RequestProcessor(db, client, CompositeProvider(themoviedb))
app = FastAPI( app = FastAPI(

View File

@ -1,5 +1,6 @@
import os import os
from logging import getLogger from logging import getLogger
from types import TracebackType
from aiohttp import ClientSession from aiohttp import ClientSession
@ -27,7 +28,12 @@ class KyooClient:
async def __aenter__(self): async def __aenter__(self):
return self return self
async def __aexit__(self): async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
):
await self._client.close() await self._client.close()
async def get_videos_info(self) -> VideoInfo: async def get_videos_info(self) -> VideoInfo:

View File

@ -8,9 +8,9 @@ from .provider import Provider
class CompositeProvider(Provider): class CompositeProvider(Provider):
def __init__(self): def __init__(self, themoviedb: Provider):
self._tvdb: Provider = None # type: ignore self._tvdb: Provider = None # type: ignore
self._themoviedb: Provider = None # type: ignore self._themoviedb = themoviedb
@property @property
@override @override

View File

@ -4,6 +4,7 @@ from collections.abc import Generator
from datetime import datetime from datetime import datetime
from logging import getLogger from logging import getLogger
from statistics import mean from statistics import mean
from types import TracebackType
from typing import Any, cast, override from typing import Any, cast, override
from aiohttp import ClientSession from aiohttp import ClientSession
@ -71,7 +72,12 @@ class TheMovieDatabase(Provider):
async def __aenter__(self): async def __aenter__(self):
return self return self
async def __aexit__(self): async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
):
await self._client.close() await self._client.close()
@property @property

View File

@ -3,8 +3,7 @@ from __future__ import annotations
from logging import getLogger from logging import getLogger
from typing import Literal from typing import Literal
from psycopg import AsyncConnection from asyncpg import Connection
from psycopg.rows import class_row, dict_row
from pydantic import Field from pydantic import Field
from .client import KyooClient from .client import KyooClient
@ -31,7 +30,7 @@ class Request(Model, extra="allow"):
class RequestProcessor: class RequestProcessor:
def __init__( def __init__(
self, self,
database: AsyncConnection, database: Connection,
client: KyooClient, client: KyooClient,
providers: CompositeProvider, providers: CompositeProvider,
): ):
@ -40,85 +39,76 @@ class RequestProcessor:
self._providers = providers self._providers = providers
async def enqueue(self, requests: list[Request]): async def enqueue(self, requests: list[Request]):
async with self._database.cursor() as cur: await self._database.executemany(
await cur.executemany( """
insert into scanner.requests(kind, title, year, external_id, videos)
values (%(kind)s, %(title) s, %(year)s, %(external_id)s, %(videos)s)
on conflict (kind, title, year)
do update set
videos = videos || excluded.videos
""",
(x.model_dump() for x in requests),
)
_ = await self._database.execute("notify scanner.requests")
async def listen_for_requests(self):
logger.info("Listening for requestes")
await self._database.add_listener("scanner.requests", self.process_request)
async def process_request(self):
cur = await self._database.fetchrow(
"""
update
scanner.requests
set
status = 'running',
started_at = nom()::timestamptz
where
pk in (
select
*
from
scanner.requests
where
status = 'pending'
limit 1
for update
skip locked)
returning
*
"""
)
if cur is None:
return
request = Request.model_validate(cur)
logger.info(f"Starting to process {request.title}")
try:
show = await self._run_request(request)
finished = await self._database.fetchrow(
""" """
insert into scanner.requests(kind, title, year, external_id, videos) delete from scanner.requests
values (%(kind)s, %(title) s, %(year)s, %(external_id)s, %(videos)s) where pk = %s
on conflict (kind, title, year) returning
do update set videos
videos = videos || excluded.videos
""", """,
(x.model_dump() for x in requests), [request.pk],
) )
# TODO: how will this conflict be handled if the request is already running if finished and finished["videos"] != request.videos:
if cur.rowcount > 0: await self._client.link_videos(show.slug, finished["videos"])
_ = await cur.execute("notify scanner.requests") except Exception as e:
logger.error("Couldn't process request", exc_info=e)
async def process_requests(self):
async with await self._database.execute("listen scanner.requests"):
gen = self._database.notifies()
async for _ in gen:
await self._process_request()
async def _process_request(self):
async with self._database.cursor(row_factory=class_row(Request)) as cur:
cur = await cur.execute( cur = await cur.execute(
""" """
update update
scanner.requests scanner.requests
set set
status = 'running', status = 'failed'
started_at = nom()::timestamptz
where where
pk in ( pk = %s
select """,
* [request.pk],
from
scanner.requests
where
status = 'pending'
limit 1
for update
skip locked)
returning
*
"""
) )
request = await cur.fetchone()
if request is None:
return
logger.info(f"Starting to process {request.title}")
try:
show = await self._run_request(request)
async with self._database.cursor(row_factory=dict_row) as cur:
cur = await cur.execute(
"""
delete from scanner.requests
where pk = %s
returning
videos
""",
[request.pk],
)
finished = await anext(cur)
if finished["videos"] != request.videos:
await self._client.link_videos(show.slug, finished["videos"])
except Exception as e:
logger.error("Couldn't process request", exc_info=e)
cur = await cur.execute(
"""
update
scanner.requests
set
status = 'failed'
where
pk = %s
""",
[request.pk],
)
async def _run_request(self, request: Request) -> Resource: async def _run_request(self, request: Request) -> Resource:
if request.kind == "movie": if request.kind == "movie":

View File

@ -7,8 +7,7 @@
aiohttp aiohttp
watchfiles watchfiles
langcodes langcodes
psycopg asyncpg
psycopg-pool
]); ]);
in in
pkgs.mkShell { pkgs.mkShell {