Add a basic test for the fetch backend

This commit is contained in:
Kovid Goyal 2024-08-09 10:44:02 +05:30
parent 4f077f1934
commit d437f1b644
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
3 changed files with 94 additions and 4 deletions

View File

@ -55,7 +55,7 @@ class FakeResponse:
class Browser:
def __init__(self, user_agent: str = '', headers: tuple[tuple[str, str], ...] = ()):
def __init__(self, user_agent: str = '', headers: tuple[tuple[str, str], ...] = (), start_worker: bool = False):
self.tdir = ''
self.worker = self.dispatcher = None
self.dispatch_map = {}
@ -64,6 +64,8 @@ class Browser:
self.user_agent = user_agent
self.lock = Lock()
self.shutting_down = False
if start_worker:
self._ensure_state()
def open(self, url_or_request, data=None, timeout=None):
if data is not None:

View File

@ -27,7 +27,7 @@ class RequestInterceptor(QWebEngineUrlRequestInterceptor):
fb: FetchBackend = self.parent()
if fb:
key = qurl_to_key(req.requestUrl())
if dr := fb.download_requests[key]:
if dr := fb.download_requests.get(key):
for (name, val) in dr.headers:
req.setHttpHeader(name.encode(), val.encode())
@ -36,7 +36,10 @@ def qurl_to_string(url: QUrl | str) -> str:
return bytes(QUrl(url).toEncoded()).decode()
qurl_to_key = qurl_to_string
def qurl_to_key(url: QUrl | str) -> str:
return qurl_to_string(url).rstrip('/')
Headers = list[tuple[str, str]]
@ -91,6 +94,7 @@ class FetchBackend(QWebEnginePage):
self.output_dir = output_dir or os.getcwd()
profile.setDownloadPath(self.output_dir)
super().__init__(profile, parent)
sys.excepthook = self.excepthook
self.interceptor = RequestInterceptor(self)
profile.setUrlRequestInterceptor(self.interceptor)
self.request_download.connect(self.download, type=Qt.ConnectionType.QueuedConnection)
@ -104,6 +108,11 @@ class FetchBackend(QWebEnginePage):
t.setInterval(50)
t.timeout.connect(self.enforce_timeouts)
def excepthook(self, cls: type, exc: Exception, tb) -> None:
if not isinstance(exc, KeyboardInterrupt):
sys.__excepthook__(cls, exc, tb)
QApplication.instance().exit(1)
def on_input_finished(self, error_msg: str) -> None:
if error_msg:
self.send_response({'action': 'input_error', 'error': error_msg})

View File

@ -1,14 +1,18 @@
#!/usr/bin/env python
# License: GPLv3 Copyright: 2024, Kovid Goyal <kovid at kovidgoyal.net>
import http.server
import json
import os
import re
import unittest
from threading import Event, Thread
from lxml.html import fromstring, tostring
from calibre.utils.resources import get_path as P
from .fetch import Browser
from .simple import Overseer
skip = ''
@ -43,5 +47,80 @@ class TestSimpleWebEngineScraper(unittest.TestCase):
self.assertFalse(w)
class Handler(http.server.BaseHTTPRequestHandler):
request_count = 0
def do_GET(self):
h = {}
for k, v in self.headers.items():
h.setdefault(k, []).append(v)
ans = {
'path': self.path,
'headers': h,
'request_count': self.request_count,
}
data = json.dumps(ans).encode()
self.send_response(http.HTTPStatus.OK)
self.send_header('Content-type', 'application/json')
self.send_header('Content-Length', str(len(data)))
self.end_headers()
self.flush_headers()
self.wfile.write(data)
def log_request(self, code='-', size='-'):
self.request_count += 1
@unittest.skipIf(skip, skip)
class TestFetchBackend(unittest.TestCase):
ae = unittest.TestCase.assertEqual
def setUp(self):
self.server_started = Event()
self.server_thread = Thread(target=self.run_server, daemon=True)
self.server_thread.start()
self.server_started.wait(5)
self.request_count = 0
def tearDown(self):
self.server.shutdown()
self.server_thread.join(5)
def test_recipe_browser(self):
def u(path=''):
return f'http://localhost:{self.port}{path}'
def get(path=''):
raw = br.open(u(path)).read()
return json.loads(raw)
br = Browser(user_agent='test-ua', headers=(('th', '1'),), start_worker=True)
try:
r = get()
self.ae(r['request_count'], 0)
print(r)
self.ae(r['headers']['th'], ['1'])
self.ae(r['headers']['User-Agent'], ['test-ua'])
self.assertIn('Accept-Encoding', r['headers'])
finally:
br.shutdown()
def run_server(self):
import socketserver
def create_handler(*a):
ans = Handler(*a)
ans.backend = self
return ans
with socketserver.TCPServer(("", 0), create_handler) as httpd:
self.server = httpd
self.port = httpd.server_address[1]
self.server_started.set()
httpd.serve_forever()
def find_tests():
return unittest.defaultTestLoader.loadTestsFromTestCase(TestSimpleWebEngineScraper)
ans = unittest.defaultTestLoader.loadTestsFromTestCase(TestSimpleWebEngineScraper)
ans.addTests(iter(unittest.defaultTestLoader.loadTestsFromTestCase(TestFetchBackend)))
return ans