diff --git a/src/calibre/scraper/fetch.py b/src/calibre/scraper/fetch.py index e00d82f351..78fd2ab5c6 100644 --- a/src/calibre/scraper/fetch.py +++ b/src/calibre/scraper/fetch.py @@ -55,7 +55,7 @@ class FakeResponse: class Browser: - def __init__(self, user_agent: str = '', headers: tuple[tuple[str, str], ...] = ()): + def __init__(self, user_agent: str = '', headers: tuple[tuple[str, str], ...] = (), start_worker: bool = False): self.tdir = '' self.worker = self.dispatcher = None self.dispatch_map = {} @@ -64,6 +64,8 @@ class Browser: self.user_agent = user_agent self.lock = Lock() self.shutting_down = False + if start_worker: + self._ensure_state() def open(self, url_or_request, data=None, timeout=None): if data is not None: diff --git a/src/calibre/scraper/fetch_backend.py b/src/calibre/scraper/fetch_backend.py index 0e99e0a547..a720972f70 100644 --- a/src/calibre/scraper/fetch_backend.py +++ b/src/calibre/scraper/fetch_backend.py @@ -27,7 +27,7 @@ class RequestInterceptor(QWebEngineUrlRequestInterceptor): fb: FetchBackend = self.parent() if fb: key = qurl_to_key(req.requestUrl()) - if dr := fb.download_requests[key]: + if dr := fb.download_requests.get(key): for (name, val) in dr.headers: req.setHttpHeader(name.encode(), val.encode()) @@ -36,7 +36,10 @@ def qurl_to_string(url: QUrl | str) -> str: return bytes(QUrl(url).toEncoded()).decode() -qurl_to_key = qurl_to_string +def qurl_to_key(url: QUrl | str) -> str: + return qurl_to_string(url).rstrip('/') + + Headers = list[tuple[str, str]] @@ -91,6 +94,7 @@ class FetchBackend(QWebEnginePage): self.output_dir = output_dir or os.getcwd() profile.setDownloadPath(self.output_dir) super().__init__(profile, parent) + sys.excepthook = self.excepthook self.interceptor = RequestInterceptor(self) profile.setUrlRequestInterceptor(self.interceptor) self.request_download.connect(self.download, type=Qt.ConnectionType.QueuedConnection) @@ -104,6 +108,11 @@ class FetchBackend(QWebEnginePage): t.setInterval(50) t.timeout.connect(self.enforce_timeouts) + def excepthook(self, cls: type, exc: Exception, tb) -> None: + if not isinstance(exc, KeyboardInterrupt): + sys.__excepthook__(cls, exc, tb) + QApplication.instance().exit(1) + def on_input_finished(self, error_msg: str) -> None: if error_msg: self.send_response({'action': 'input_error', 'error': error_msg}) diff --git a/src/calibre/scraper/test_fetch_backend.py b/src/calibre/scraper/test_fetch_backend.py index 2f2b12d0fd..cec5cbd70b 100644 --- a/src/calibre/scraper/test_fetch_backend.py +++ b/src/calibre/scraper/test_fetch_backend.py @@ -1,14 +1,18 @@ #!/usr/bin/env python # License: GPLv3 Copyright: 2024, Kovid Goyal +import http.server +import json import os import re import unittest +from threading import Event, Thread from lxml.html import fromstring, tostring from calibre.utils.resources import get_path as P +from .fetch import Browser from .simple import Overseer skip = '' @@ -43,5 +47,80 @@ class TestSimpleWebEngineScraper(unittest.TestCase): self.assertFalse(w) +class Handler(http.server.BaseHTTPRequestHandler): + + request_count = 0 + + def do_GET(self): + h = {} + for k, v in self.headers.items(): + h.setdefault(k, []).append(v) + ans = { + 'path': self.path, + 'headers': h, + 'request_count': self.request_count, + } + data = json.dumps(ans).encode() + self.send_response(http.HTTPStatus.OK) + self.send_header('Content-type', 'application/json') + self.send_header('Content-Length', str(len(data))) + self.end_headers() + self.flush_headers() + self.wfile.write(data) + + def log_request(self, code='-', size='-'): + self.request_count += 1 + + +@unittest.skipIf(skip, skip) +class TestFetchBackend(unittest.TestCase): + + ae = unittest.TestCase.assertEqual + + def setUp(self): + self.server_started = Event() + self.server_thread = Thread(target=self.run_server, daemon=True) + self.server_thread.start() + self.server_started.wait(5) + self.request_count = 0 + + def tearDown(self): + self.server.shutdown() + self.server_thread.join(5) + + def test_recipe_browser(self): + def u(path=''): + return f'http://localhost:{self.port}{path}' + def get(path=''): + raw = br.open(u(path)).read() + return json.loads(raw) + br = Browser(user_agent='test-ua', headers=(('th', '1'),), start_worker=True) + try: + r = get() + self.ae(r['request_count'], 0) + print(r) + self.ae(r['headers']['th'], ['1']) + self.ae(r['headers']['User-Agent'], ['test-ua']) + self.assertIn('Accept-Encoding', r['headers']) + finally: + br.shutdown() + + def run_server(self): + import socketserver + + def create_handler(*a): + ans = Handler(*a) + ans.backend = self + return ans + + with socketserver.TCPServer(("", 0), create_handler) as httpd: + self.server = httpd + self.port = httpd.server_address[1] + self.server_started.set() + httpd.serve_forever() + + def find_tests(): - return unittest.defaultTestLoader.loadTestsFromTestCase(TestSimpleWebEngineScraper) + ans = unittest.defaultTestLoader.loadTestsFromTestCase(TestSimpleWebEngineScraper) + ans.addTests(iter(unittest.defaultTestLoader.loadTestsFromTestCase(TestFetchBackend))) + return ans