From 7a66bbbd407c7a038ccbb9765b6ac5d3be560acc Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Sun, 9 Nov 2014 20:42:53 +0530 Subject: [PATCH] Add a facility to send common data (common to all jobs) to the workers in a pool --- src/calibre/utils/ipc/pool.py | 49 ++++++++++++++++++++++++++++++++--- 1 file changed, 45 insertions(+), 4 deletions(-) diff --git a/src/calibre/utils/ipc/pool.py b/src/calibre/utils/ipc/pool.py index 6e66d803bd..1e13e959cd 100644 --- a/src/calibre/utils/ipc/pool.py +++ b/src/calibre/utils/ipc/pool.py @@ -52,6 +52,9 @@ class Worker(object): wr = WorkerResult(self.job_id, result, True, self) self.events.put(wr) + def set_common_data(self, data): + eintr_retry_call(self.conn.send, data) + class Pool(Thread): daemon = True @@ -67,6 +70,7 @@ class Pool(Thread): self.results = Queue() self.tracker = Queue() self.terminal_failure = None + self.common_data = None self.start() @@ -75,11 +79,28 @@ class Pool(Thread): p = start_pipe_worker( 'from {0} import run_main, {1}; run_main({1})'.format(self.__class__.__module__, 'worker_main'), stdout=None) sys.stdout.flush() - eintr_retry_call(p.stdin.write, cPickle.dumps((self.address, self.auth_key), -1)) + eintr_retry_call(p.stdin.write, cPickle.dumps((self.address, self.auth_key, self.common_data), -1)) p.stdin.flush(), p.stdin.close() conn = eintr_retry_call(self.listener.accept) return Worker(p, conn, self.events, self.name) + def set_common_data(self, data=None): + ''' Set some data that will be passed to all subsequent jobs without + needing to be transmitted every time. You must call this method before + queueing any jobs, otherwise the behavior is undefined. You can call it + after all jobs are done, then it will be used for the new round of + jobs. ''' + with self.lock: + self.common_data = data + for worker in self.available_workers: + try: + worker.set_common_data(data) + except Exception: + import traceback + self.terminal_failure = TerminalFailure('Worker process crashed while sending common data', traceback.format_exc()) + self.terminal_error() + break + def start_worker(self): try: self.available_workers.append(self.create_worker()) @@ -217,7 +238,7 @@ class Pool(Thread): del self.available_workers[:] self.busy_workers.clear() -def worker_main(conn): +def worker_main(conn, common_data): from importlib import import_module while True: try: @@ -231,6 +252,9 @@ def worker_main(conn): return 1 if job is None: break + if not isinstance(job, Job): + common_data = job + continue try: if '\n' in job.module: import_module('calibre.customize.ui') # Load plugins @@ -239,6 +263,8 @@ def worker_main(conn): func = mod[job.func] else: func = getattr(import_module(job.module), job.func) + if common_data is not None: + job.kwargs['common_data'] = common_data result = func(*job.args, **job.kwargs) result = Result(result, None, None) except Exception as err: @@ -258,9 +284,9 @@ def worker_main(conn): def run_main(func): from multiprocessing.connection import Client from contextlib import closing - address, key = cPickle.loads(eintr_retry_call(sys.stdin.read)) + address, key, common_data = cPickle.loads(eintr_retry_call(sys.stdin.read)) with closing(Client(address, authkey=key)) as conn: - raise SystemExit(func(conn)) + raise SystemExit(func(conn, common_data)) def test(): def get_results(pool, ignore_fail=False): @@ -286,6 +312,21 @@ def test(): raise SystemExit('%r != %r' % (expected_results, results)) p.shutdown(), p.join() + # Test common_data + p = Pool(name='Test') + expected_results = {} + with p.lock: + p.start_worker() + p.set_common_data(7) + for i in range(1000): + p(i, 'def x(i, common_data=None):\n return common_data + i', 'x', i) + expected_results[i] = p.common_data + i + p.wait_for_tasks(30) + results = {k:v.value for k, v in get_results(p).iteritems()} + if results != expected_results: + raise SystemExit('%r != %r' % (expected_results, results)) + p.shutdown(), p.join() + # Test exceptions in jobs p = Pool(name='Test') for i in range(1000):