Add a basic test for the fetch backend

Kovid Goyal 2024-08-09 10:44:02 +05:30
parent 4f077f1934
commit d437f1b644
GPG Key ID: 06BC317B515ACE7C
3 changed files with 94 additions and 4 deletions

View File

@@ -55,7 +55,7 @@ class FakeResponse:
 class Browser:
 
-    def __init__(self, user_agent: str = '', headers: tuple[tuple[str, str], ...] = ()):
+    def __init__(self, user_agent: str = '', headers: tuple[tuple[str, str], ...] = (), start_worker: bool = False):
         self.tdir = ''
         self.worker = self.dispatcher = None
         self.dispatch_map = {}
@@ -64,6 +64,8 @@ class Browser:
         self.user_agent = user_agent
         self.lock = Lock()
         self.shutting_down = False
+        if start_worker:
+            self._ensure_state()
 
     def open(self, url_or_request, data=None, timeout=None):
         if data is not None:
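
The new start_worker flag lets a caller start the backend worker eagerly via _ensure_state() instead of lazily on the first request. A minimal usage sketch, mirroring how the new test constructs the browser (the absolute module path and the URL are assumptions, not part of this commit):

# Sketch only: module path assumed to match the test's relative import.
from calibre.scraper.fetch import Browser

br = Browser(user_agent='test-ua', headers=(('th', '1'),), start_worker=True)
try:
    # With start_worker=True the worker is already running, so the first
    # open() call does not pay the startup cost.
    raw = br.open('http://localhost:8080/').read()  # placeholder URL
finally:
    br.shutdown()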

View File

@@ -27,7 +27,7 @@ class RequestInterceptor(QWebEngineUrlRequestInterceptor):
         fb: FetchBackend = self.parent()
         if fb:
             key = qurl_to_key(req.requestUrl())
-            if dr := fb.download_requests[key]:
+            if dr := fb.download_requests.get(key):
                 for (name, val) in dr.headers:
                     req.setHttpHeader(name.encode(), val.encode())
@@ -36,7 +36,10 @@ def qurl_to_string(url: QUrl | str) -> str:
     return bytes(QUrl(url).toEncoded()).decode()
 
 
-qurl_to_key = qurl_to_string
+def qurl_to_key(url: QUrl | str) -> str:
+    return qurl_to_string(url).rstrip('/')
+
 
 Headers = list[tuple[str, str]]
@@ -91,6 +94,7 @@ class FetchBackend(QWebEnginePage):
         self.output_dir = output_dir or os.getcwd()
         profile.setDownloadPath(self.output_dir)
         super().__init__(profile, parent)
+        sys.excepthook = self.excepthook
         self.interceptor = RequestInterceptor(self)
         profile.setUrlRequestInterceptor(self.interceptor)
         self.request_download.connect(self.download, type=Qt.ConnectionType.QueuedConnection)
@@ -104,6 +108,11 @@ class FetchBackend(QWebEnginePage):
         t.setInterval(50)
         t.timeout.connect(self.enforce_timeouts)
 
+    def excepthook(self, cls: type, exc: Exception, tb) -> None:
+        if not isinstance(exc, KeyboardInterrupt):
+            sys.__excepthook__(cls, exc, tb)
+        QApplication.instance().exit(1)
+
     def on_input_finished(self, error_msg: str) -> None:
         if error_msg:
             self.send_response({'action': 'input_error', 'error': error_msg})
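
Replacing the plain alias with a real qurl_to_key() normalizes trailing slashes, so the interceptor's download_requests.get(key) lookup matches whether or not the request URL ends in '/'. A Qt-free sketch of the same normalization (the names here are illustrative, not from the commit):

def key_for(url: str) -> str:
    # Same idea as the new qurl_to_key(): drop a trailing slash so that
    # 'http://host/path/' and 'http://host/path' share one dictionary key.
    return url.rstrip('/')

assert key_for('http://localhost:8080/') == key_for('http://localhost:8080')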

View File

@@ -1,14 +1,18 @@
 #!/usr/bin/env python
 # License: GPLv3 Copyright: 2024, Kovid Goyal <kovid at kovidgoyal.net>
 
+import http.server
+import json
 import os
 import re
 import unittest
+from threading import Event, Thread
 
 from lxml.html import fromstring, tostring
 
 from calibre.utils.resources import get_path as P
 
+from .fetch import Browser
 from .simple import Overseer
 
 skip = ''
@@ -43,5 +47,80 @@ class TestSimpleWebEngineScraper(unittest.TestCase):
         self.assertFalse(w)
 
 
+class Handler(http.server.BaseHTTPRequestHandler):
+
+    request_count = 0
+
+    def do_GET(self):
+        h = {}
+        for k, v in self.headers.items():
+            h.setdefault(k, []).append(v)
+        ans = {
+            'path': self.path,
+            'headers': h,
+            'request_count': self.request_count,
+        }
+        data = json.dumps(ans).encode()
+        self.send_response(http.HTTPStatus.OK)
+        self.send_header('Content-type', 'application/json')
+        self.send_header('Content-Length', str(len(data)))
+        self.end_headers()
+        self.flush_headers()
+        self.wfile.write(data)
+
+    def log_request(self, code='-', size='-'):
+        self.request_count += 1
+
+
+@unittest.skipIf(skip, skip)
+class TestFetchBackend(unittest.TestCase):
+
+    ae = unittest.TestCase.assertEqual
+
+    def setUp(self):
+        self.server_started = Event()
+        self.server_thread = Thread(target=self.run_server, daemon=True)
+        self.server_thread.start()
+        self.server_started.wait(5)
+        self.request_count = 0
+
+    def tearDown(self):
+        self.server.shutdown()
+        self.server_thread.join(5)
+
+    def test_recipe_browser(self):
+        def u(path=''):
+            return f'http://localhost:{self.port}{path}'
+
+        def get(path=''):
+            raw = br.open(u(path)).read()
+            return json.loads(raw)
+
+        br = Browser(user_agent='test-ua', headers=(('th', '1'),), start_worker=True)
+        try:
+            r = get()
+            self.ae(r['request_count'], 0)
+            print(r)
+            self.ae(r['headers']['th'], ['1'])
+            self.ae(r['headers']['User-Agent'], ['test-ua'])
+            self.assertIn('Accept-Encoding', r['headers'])
+        finally:
+            br.shutdown()
+
+    def run_server(self):
+        import socketserver
+
+        def create_handler(*a):
+            ans = Handler(*a)
+            ans.backend = self
+            return ans
+
+        with socketserver.TCPServer(("", 0), create_handler) as httpd:
+            self.server = httpd
+            self.port = httpd.server_address[1]
+            self.server_started.set()
+            httpd.serve_forever()
+
+
 def find_tests():
-    return unittest.defaultTestLoader.loadTestsFromTestCase(TestSimpleWebEngineScraper)
+    ans = unittest.defaultTestLoader.loadTestsFromTestCase(TestSimpleWebEngineScraper)
+    ans.addTests(iter(unittest.defaultTestLoader.loadTestsFromTestCase(TestFetchBackend)))
+    return ans
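
The combined suite returned by find_tests() can be run directly with the standard library runner, e.g. (a sketch; calibre's own test harness may invoke it differently):

if __name__ == '__main__':
    import unittest
    # Runs both TestSimpleWebEngineScraper and TestFetchBackend.
    unittest.TextTestRunner(verbosity=2).run(find_tests())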