mirror of
https://github.com/kovidgoyal/calibre.git
synced 2025-07-09 03:04:10 -04:00
Handle large common data being sent to workers (30MB+)
This commit is contained in:
parent
b8e9bb0ca2
commit
9eae6a99c7
@ -12,6 +12,7 @@ from collections import namedtuple
|
||||
from Queue import Queue
|
||||
|
||||
from calibre import detect_ncpus, as_unicode, prints
|
||||
from calibre.ptempfile import PersistentTemporaryFile
|
||||
from calibre.utils import join_with_timeout
|
||||
from calibre.utils.ipc import eintr_retry_call
|
||||
|
||||
@ -19,6 +20,9 @@ Job = namedtuple('Job', 'id module func args kwargs')
|
||||
Result = namedtuple('Result', 'value err traceback')
|
||||
WorkerResult = namedtuple('WorkerResult', 'id result is_terminal_failure worker')
|
||||
TerminalFailure = namedtuple('TerminalFailure', 'message tb job_id')
|
||||
File = namedtuple('File', 'name')
|
||||
|
||||
MAX_SIZE = 30 * 1024 * 1024 # max size of data to send over the connection (old versions of windows cannot handle arbitrary data lengths)
|
||||
|
||||
class Failure(Exception):
|
||||
|
||||
@ -36,7 +40,7 @@ class Worker(object):
|
||||
self.name = name or ''
|
||||
|
||||
def __call__(self, job):
|
||||
eintr_retry_call(self.conn.send, job)
|
||||
eintr_retry_call(self.conn.send_bytes, cPickle.dumps(job, -1))
|
||||
if job is not None:
|
||||
self.job_id = job.id
|
||||
t = Thread(target=self.recv, name='PoolWorker-'+self.name)
|
||||
@ -45,7 +49,7 @@ class Worker(object):
|
||||
|
||||
def recv(self):
|
||||
try:
|
||||
result = eintr_retry_call(self.conn.recv)
|
||||
result = cPickle.loads(eintr_retry_call(self.conn.recv_bytes))
|
||||
wr = WorkerResult(self.job_id, result, False, self)
|
||||
except Exception as err:
|
||||
import traceback
|
||||
@ -54,7 +58,7 @@ class Worker(object):
|
||||
self.events.put(wr)
|
||||
|
||||
def set_common_data(self, data):
|
||||
eintr_retry_call(self.conn.send, data)
|
||||
eintr_retry_call(self.conn.send_bytes, data)
|
||||
|
||||
class Pool(Thread):
|
||||
|
||||
@ -71,7 +75,7 @@ class Pool(Thread):
|
||||
self.results = Queue()
|
||||
self.tracker = Queue()
|
||||
self.terminal_failure = None
|
||||
self.common_data = None
|
||||
self.common_data = cPickle.dumps(None, -1)
|
||||
self.worker_data = None
|
||||
|
||||
self.start()
|
||||
@ -84,7 +88,10 @@ class Pool(Thread):
|
||||
eintr_retry_call(p.stdin.write, self.worker_data)
|
||||
p.stdin.flush(), p.stdin.close()
|
||||
conn = eintr_retry_call(self.listener.accept)
|
||||
return Worker(p, conn, self.events, self.name)
|
||||
w = Worker(p, conn, self.events, self.name)
|
||||
if self.common_data != cPickle.dumps(None, -1):
|
||||
w.set_common_data(self.common_data)
|
||||
return w
|
||||
|
||||
def set_common_data(self, data=None):
|
||||
''' Set some data that will be passed to all subsequent jobs without
|
||||
@ -94,11 +101,15 @@ class Pool(Thread):
|
||||
jobs. Can raise the :class:`Failure` exception if data could not be
|
||||
sent to workers.'''
|
||||
with self.lock:
|
||||
self.common_data = data
|
||||
self.worker_data = cPickle.dumps((self.address, self.auth_key, self.common_data), -1)
|
||||
self.common_data = cPickle.dumps(data, -1)
|
||||
if len(self.common_data) > MAX_SIZE:
|
||||
self.cd_file = PersistentTemporaryFile('pool_common_data')
|
||||
with self.cd_file as f:
|
||||
f.write(self.common_data)
|
||||
self.common_data = cPickle.dumps(File(f.name), -1)
|
||||
for worker in self.available_workers:
|
||||
try:
|
||||
worker.set_common_data(self.worker_data[-1])
|
||||
worker.set_common_data(self.common_data)
|
||||
except Exception:
|
||||
import traceback
|
||||
self.terminal_failure = TerminalFailure('Worker process crashed while sending common data', traceback.format_exc())
|
||||
@ -118,7 +129,7 @@ class Pool(Thread):
|
||||
from calibre.utils.ipc.server import create_listener
|
||||
self.auth_key = os.urandom(32)
|
||||
self.address, self.listener = create_listener(self.auth_key)
|
||||
self.worker_data = cPickle.dumps((self.address, self.auth_key, self.common_data), -1)
|
||||
self.worker_data = cPickle.dumps((self.address, self.auth_key), -1)
|
||||
with self.lock:
|
||||
if self.start_worker() is False:
|
||||
return
|
||||
@ -211,7 +222,7 @@ class Pool(Thread):
|
||||
self.tracker.task_done()
|
||||
while self.pending_jobs:
|
||||
job = self.pending_jobs.pop()
|
||||
self.results.put(WorkerResult(job.id, Result(None, None, None), True, worker))
|
||||
self.results.put(WorkerResult(job.id, Result(None, None, None), True, None))
|
||||
self.tracker.task_done()
|
||||
self.shutdown_workers()
|
||||
self.events.put(None)
|
||||
@ -242,12 +253,18 @@ class Pool(Thread):
|
||||
w.process.kill()
|
||||
del self.available_workers[:]
|
||||
self.busy_workers.clear()
|
||||
if hasattr(self, 'cd_file'):
|
||||
try:
|
||||
os.remove(self.cd_file.name)
|
||||
except EnvironmentError:
|
||||
pass
|
||||
|
||||
def worker_main(conn, common_data):
|
||||
def worker_main(conn):
|
||||
from importlib import import_module
|
||||
common_data = None
|
||||
while True:
|
||||
try:
|
||||
job = eintr_retry_call(conn.recv)
|
||||
job = cPickle.loads(eintr_retry_call(conn.recv_bytes))
|
||||
except EOFError:
|
||||
break
|
||||
except Exception:
|
||||
@ -258,7 +275,10 @@ def worker_main(conn, common_data):
|
||||
if job is None:
|
||||
break
|
||||
if not isinstance(job, Job):
|
||||
common_data = cPickle.loads(job)
|
||||
if isinstance(job, File):
|
||||
common_data = cPickle.load(open(job.name, 'rb'))
|
||||
else:
|
||||
common_data = job
|
||||
continue
|
||||
try:
|
||||
if '\n' in job.module:
|
||||
@ -276,7 +296,7 @@ def worker_main(conn, common_data):
|
||||
import traceback
|
||||
result = Result(None, as_unicode(err), traceback.format_exc())
|
||||
try:
|
||||
eintr_retry_call(conn.send, result)
|
||||
eintr_retry_call(conn.send_bytes, cPickle.dumps(result, -1))
|
||||
except EOFError:
|
||||
break
|
||||
except Exception:
|
||||
@ -289,9 +309,9 @@ def worker_main(conn, common_data):
|
||||
def run_main(func):
|
||||
from multiprocessing.connection import Client
|
||||
from contextlib import closing
|
||||
address, key, common_data = cPickle.loads(eintr_retry_call(sys.stdin.read))
|
||||
address, key = cPickle.loads(eintr_retry_call(sys.stdin.read))
|
||||
with closing(Client(address, authkey=key)) as conn:
|
||||
raise SystemExit(func(conn, common_data))
|
||||
raise SystemExit(func(conn))
|
||||
|
||||
def test():
|
||||
def get_results(pool, ignore_fail=False):
|
||||
@ -325,13 +345,24 @@ def test():
|
||||
p.set_common_data(7)
|
||||
for i in range(1000):
|
||||
p(i, 'def x(i, common_data=None):\n return common_data + i', 'x', i)
|
||||
expected_results[i] = p.common_data + i
|
||||
expected_results[i] = 7 + i
|
||||
p.wait_for_tasks(30)
|
||||
results = {k:v.value for k, v in get_results(p).iteritems()}
|
||||
if results != expected_results:
|
||||
raise SystemExit('%r != %r' % (expected_results, results))
|
||||
p.shutdown(), p.join()
|
||||
|
||||
# Test large common data
|
||||
p = Pool(name='Test')
|
||||
data = b'a' * (4 * MAX_SIZE)
|
||||
p.set_common_data(data)
|
||||
p(0, 'def x(i, common_data=None):\n return len(common_data)', 'x', 0)
|
||||
p.wait_for_tasks(30)
|
||||
results = get_results(p)
|
||||
if len(data) != results[0].value:
|
||||
raise SystemExit('Common data was not returned correctly')
|
||||
p.shutdown(), p.join()
|
||||
|
||||
# Test exceptions in jobs
|
||||
p = Pool(name='Test')
|
||||
for i in range(1000):
|
||||
@ -361,3 +392,5 @@ def test():
|
||||
if not p.failed:
|
||||
raise SystemExit('No expected terminal failure')
|
||||
p.shutdown(), p.join()
|
||||
|
||||
print ('Tests all passed!')
|
||||
|
Loading…
x
Reference in New Issue
Block a user