From 9eae6a99c7cda50dde9281efed8fa7223498581b Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Mon, 10 Nov 2014 15:45:17 +0530 Subject: [PATCH] Handle large common data being sent to workers (30MB+) --- src/calibre/utils/ipc/pool.py | 67 ++++++++++++++++++++++++++--------- 1 file changed, 50 insertions(+), 17 deletions(-) diff --git a/src/calibre/utils/ipc/pool.py b/src/calibre/utils/ipc/pool.py index 21995cc260..dfc3b9bfe0 100644 --- a/src/calibre/utils/ipc/pool.py +++ b/src/calibre/utils/ipc/pool.py @@ -12,6 +12,7 @@ from collections import namedtuple from Queue import Queue from calibre import detect_ncpus, as_unicode, prints +from calibre.ptempfile import PersistentTemporaryFile from calibre.utils import join_with_timeout from calibre.utils.ipc import eintr_retry_call @@ -19,6 +20,9 @@ Job = namedtuple('Job', 'id module func args kwargs') Result = namedtuple('Result', 'value err traceback') WorkerResult = namedtuple('WorkerResult', 'id result is_terminal_failure worker') TerminalFailure = namedtuple('TerminalFailure', 'message tb job_id') +File = namedtuple('File', 'name') + +MAX_SIZE = 30 * 1024 * 1024 # max size of data to send over the connection (old versions of windows cannot handle arbitrary data lengths) class Failure(Exception): @@ -36,7 +40,7 @@ class Worker(object): self.name = name or '' def __call__(self, job): - eintr_retry_call(self.conn.send, job) + eintr_retry_call(self.conn.send_bytes, cPickle.dumps(job, -1)) if job is not None: self.job_id = job.id t = Thread(target=self.recv, name='PoolWorker-'+self.name) @@ -45,7 +49,7 @@ class Worker(object): def recv(self): try: - result = eintr_retry_call(self.conn.recv) + result = cPickle.loads(eintr_retry_call(self.conn.recv_bytes)) wr = WorkerResult(self.job_id, result, False, self) except Exception as err: import traceback @@ -54,7 +58,7 @@ class Worker(object): self.events.put(wr) def set_common_data(self, data): - eintr_retry_call(self.conn.send, data) + eintr_retry_call(self.conn.send_bytes, data) class Pool(Thread): @@ -71,7 +75,7 @@ class Pool(Thread): self.results = Queue() self.tracker = Queue() self.terminal_failure = None - self.common_data = None + self.common_data = cPickle.dumps(None, -1) self.worker_data = None self.start() @@ -84,7 +88,10 @@ class Pool(Thread): eintr_retry_call(p.stdin.write, self.worker_data) p.stdin.flush(), p.stdin.close() conn = eintr_retry_call(self.listener.accept) - return Worker(p, conn, self.events, self.name) + w = Worker(p, conn, self.events, self.name) + if self.common_data != cPickle.dumps(None, -1): + w.set_common_data(self.common_data) + return w def set_common_data(self, data=None): ''' Set some data that will be passed to all subsequent jobs without @@ -94,11 +101,15 @@ class Pool(Thread): jobs. Can raise the :class:`Failure` exception is data could not be sent to workers.''' with self.lock: - self.common_data = data - self.worker_data = cPickle.dumps((self.address, self.auth_key, self.common_data), -1) + self.common_data = cPickle.dumps(data, -1) + if len(self.common_data) > MAX_SIZE: + self.cd_file = PersistentTemporaryFile('pool_common_data') + with self.cd_file as f: + f.write(self.common_data) + self.common_data = cPickle.dumps(File(f.name), -1) for worker in self.available_workers: try: - worker.set_common_data(self.worker_data[-1]) + worker.set_common_data(self.common_data) except Exception: import traceback self.terminal_failure = TerminalFailure('Worker process crashed while sending common data', traceback.format_exc()) @@ -118,7 +129,7 @@ class Pool(Thread): from calibre.utils.ipc.server import create_listener self.auth_key = os.urandom(32) self.address, self.listener = create_listener(self.auth_key) - self.worker_data = cPickle.dumps((self.address, self.auth_key, self.common_data), -1) + self.worker_data = cPickle.dumps((self.address, self.auth_key), -1) with self.lock: if self.start_worker() is False: return @@ -211,7 +222,7 @@ class Pool(Thread): self.tracker.task_done() while self.pending_jobs: job = self.pending_jobs.pop() - self.results.put(WorkerResult(job.id, Result(None, None, None), True, worker)) + self.results.put(WorkerResult(job.id, Result(None, None, None), True, None)) self.tracker.task_done() self.shutdown_workers() self.events.put(None) @@ -242,12 +253,18 @@ class Pool(Thread): w.process.kill() del self.available_workers[:] self.busy_workers.clear() + if hasattr(self, 'cd_file'): + try: + os.remove(self.cd_file.name) + except EnvironmentError: + pass -def worker_main(conn, common_data): +def worker_main(conn): from importlib import import_module + common_data = None while True: try: - job = eintr_retry_call(conn.recv) + job = cPickle.loads(eintr_retry_call(conn.recv_bytes)) except EOFError: break except Exception: @@ -258,7 +275,10 @@ def worker_main(conn, common_data): if job is None: break if not isinstance(job, Job): - common_data = cPickle.loads(job) + if isinstance(job, File): + common_data = cPickle.load(open(job.name, 'rb')) + else: + common_data = job continue try: if '\n' in job.module: @@ -276,7 +296,7 @@ def worker_main(conn, common_data): import traceback result = Result(None, as_unicode(err), traceback.format_exc()) try: - eintr_retry_call(conn.send, result) + eintr_retry_call(conn.send_bytes, cPickle.dumps(result, -1)) except EOFError: break except Exception: @@ -289,9 +309,9 @@ def worker_main(conn, common_data): def run_main(func): from multiprocessing.connection import Client from contextlib import closing - address, key, common_data = cPickle.loads(eintr_retry_call(sys.stdin.read)) + address, key = cPickle.loads(eintr_retry_call(sys.stdin.read)) with closing(Client(address, authkey=key)) as conn: - raise SystemExit(func(conn, common_data)) + raise SystemExit(func(conn)) def test(): def get_results(pool, ignore_fail=False): @@ -325,13 +345,24 @@ def test(): p.set_common_data(7) for i in range(1000): p(i, 'def x(i, common_data=None):\n return common_data + i', 'x', i) - expected_results[i] = p.common_data + i + expected_results[i] = 7 + i p.wait_for_tasks(30) results = {k:v.value for k, v in get_results(p).iteritems()} if results != expected_results: raise SystemExit('%r != %r' % (expected_results, results)) p.shutdown(), p.join() + # Test large common data + p = Pool(name='Test') + data = b'a' * (4 * MAX_SIZE) + p.set_common_data(data) + p(0, 'def x(i, common_data=None):\n return len(common_data)', 'x', 0) + p.wait_for_tasks(30) + results = get_results(p) + if len(data) != results[0].value: + raise SystemExit('Common data was not returned correctly') + p.shutdown(), p.join() + # Test exceptions in jobs p = Pool(name='Test') for i in range(1000): @@ -361,3 +392,5 @@ def test(): if not p.failed: raise SystemExit('No expected terminal failure') p.shutdown(), p.join() + + print ('Tests all passed!')