Add a facility to send common data (common to all jobs) to the workers in a pool

Kovid Goyal 2014-11-09 20:42:53 +05:30
parent feb269da46
commit 7a66bbbd40


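For context, here is roughly how the new facility is used end to end, modelled on the test added at the bottom of this commit (the import path is an assumption, since the file name is not shown in this view; Python 2 era code, matching the codebase):

    from calibre.utils.ipc.pool import Pool  # assumed module path

    p = Pool(name='Example')
    p.set_common_data(7)  # sent to each worker exactly once, not per job
    for i in range(10):
        # (job_id, module or source, function name, *args)
        p(i, 'def x(i, common_data=None):\n return common_data + i', 'x', i)
    p.wait_for_tasks(30)
    p.shutdown(), p.join()
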
@@ -52,6 +52,9 @@ class Worker(object):
         wr = WorkerResult(self.job_id, result, True, self)
         self.events.put(wr)
 
+    def set_common_data(self, data):
+        eintr_retry_call(self.conn.send, data)
+
 class Pool(Thread):
 
     daemon = True
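
Worker.set_common_data simply pushes the new data down the control connection of an already-running worker. eintr_retry_call is calibre's helper that retries a system call when it is interrupted by a signal; a minimal sketch of the pattern it implements (the real helper lives elsewhere in the codebase):

    import errno

    def eintr_retry_call(func, *args, **kwargs):
        # Retry the call if it fails with EINTR (interrupted system call),
        # so that a stray signal cannot abort the send. Sketch only.
        while True:
            try:
                return func(*args, **kwargs)
            except EnvironmentError as e:
                if getattr(e, 'errno', None) == errno.EINTR:
                    continue
                raise
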
@@ -67,6 +70,7 @@ class Pool(Thread):
         self.results = Queue()
         self.tracker = Queue()
         self.terminal_failure = None
+        self.common_data = None
 
         self.start()
@ -75,11 +79,28 @@ class Pool(Thread):
p = start_pipe_worker(
'from {0} import run_main, {1}; run_main({1})'.format(self.__class__.__module__, 'worker_main'), stdout=None)
sys.stdout.flush()
eintr_retry_call(p.stdin.write, cPickle.dumps((self.address, self.auth_key), -1))
eintr_retry_call(p.stdin.write, cPickle.dumps((self.address, self.auth_key, self.common_data), -1))
p.stdin.flush(), p.stdin.close()
conn = eintr_retry_call(self.listener.accept)
return Worker(p, conn, self.events, self.name)
def set_common_data(self, data=None):
''' Set some data that will be passed to all subsequent jobs without
needing to be transmitted every time. You must call this method before
queueing any jobs, otherwise the behavior is undefined. You can call it
after all jobs are done, then it will be used for the new round of
jobs. '''
with self.lock:
self.common_data = data
for worker in self.available_workers:
try:
worker.set_common_data(data)
except Exception:
import traceback
self.terminal_failure = TerminalFailure('Worker process crashed while sending common data', traceback.format_exc())
self.terminal_error()
break
def start_worker(self):
try:
self.available_workers.append(self.create_worker())
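
Pool.set_common_data thus delivers the data along two paths: workers created later receive it as part of the pickled bootstrap tuple in create_worker, while already-idle workers get it pushed over their connections immediately. That is why the docstring restricts when it may be called; a sketch of the two safe call patterns:

    p = Pool(name='Example')
    p.set_common_data(7)   # before any job is queued: always safe
    # ... queue jobs, then wait for all of them to finish ...
    p.set_common_data(8)   # after all jobs are done: applies to the next round
    # Calling it while jobs are in flight is undefined: a busy worker may or
    # may not see the new value for jobs it has already been handed.
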
@@ -217,7 +238,7 @@ class Pool(Thread):
         del self.available_workers[:]
         self.busy_workers.clear()
 
-def worker_main(conn):
+def worker_main(conn, common_data):
     from importlib import import_module
     while True:
         try:
@@ -231,6 +252,9 @@ def worker_main(conn):
             return 1
         if job is None:
             break
+        if not isinstance(job, Job):
+            common_data = job
+            continue
         try:
             if '\n' in job.module:
                 import_module('calibre.customize.ui')  # Load plugins
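
The worker's receive loop now carries an in-band protocol with three kinds of message: None requests shutdown, a Job instance is executed, and anything else is taken as a common-data update (this is how Worker.set_common_data above reaches a live worker). One consequence worth noting: a Job instance itself can never be used as common data, because isinstance(job, Job) is the discriminator. A sketch of the dispatch, with run_job as a hypothetical stand-in for the try-block that follows:

    def receive_loop(conn, common_data):
        # Sketch of the dispatch in worker_main.
        while True:
            msg = eintr_retry_call(conn.recv)
            if msg is None:               # shutdown request: leave the loop
                break
            if not isinstance(msg, Job):  # in-band common-data update
                common_data = msg
                continue
            run_job(msg, common_data)     # a normal Job: execute it
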
@@ -239,6 +263,8 @@ def worker_main(conn):
                 func = mod[job.func]
             else:
                 func = getattr(import_module(job.module), job.func)
+            if common_data is not None:
+                job.kwargs['common_data'] = common_data
             result = func(*job.args, **job.kwargs)
             result = Result(result, None, None)
         except Exception as err:
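
Because the data is injected via job.kwargs, any job function used with a pool that has common data set must accept a common_data keyword argument. A sketch of a conforming job function (the name and body here are hypothetical):

    def process(item, common_data=None):
        # common_data is filled in by worker_main just before the call;
        # it is transmitted once per worker, never pickled per job.
        return common_data + item
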
@@ -258,9 +284,9 @@ def worker_main(conn):
 def run_main(func):
     from multiprocessing.connection import Client
     from contextlib import closing
-    address, key = cPickle.loads(eintr_retry_call(sys.stdin.read))
+    address, key, common_data = cPickle.loads(eintr_retry_call(sys.stdin.read))
     with closing(Client(address, authkey=key)) as conn:
-        raise SystemExit(func(conn))
+        raise SystemExit(func(conn, common_data))
 
 def test():
     def get_results(pool, ignore_fail=False):
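
run_main is the entry point inside the spawned process: the parent writes a single pickled tuple to the child's stdin, and the child uses it to dial back over a multiprocessing connection. This commit widens that tuple from (address, auth_key) to (address, auth_key, common_data). A sketch of the parent's side of the handshake (make_child is a hypothetical stand-in for start_pipe_worker):

    import cPickle
    from multiprocessing.connection import Listener

    listener = Listener(authkey=b'secret')
    child = make_child()  # hypothetical: spawns 'run_main(worker_main)'
    # One pickled tuple carries everything the child needs to connect back,
    # now including the common data.
    child.stdin.write(cPickle.dumps((listener.address, b'secret', None), -1))
    child.stdin.flush(), child.stdin.close()
    conn = listener.accept()  # the worker has dialled back in
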
@@ -286,6 +312,21 @@ def test():
         raise SystemExit('%r != %r' % (expected_results, results))
     p.shutdown(), p.join()
 
+    # Test common_data
+    p = Pool(name='Test')
+    expected_results = {}
+    with p.lock:
+        p.start_worker()
+    p.set_common_data(7)
+    for i in range(1000):
+        p(i, 'def x(i, common_data=None):\n return common_data + i', 'x', i)
+        expected_results[i] = p.common_data + i
+    p.wait_for_tasks(30)
+    results = {k:v.value for k, v in get_results(p).iteritems()}
+    if results != expected_results:
+        raise SystemExit('%r != %r' % (expected_results, results))
+    p.shutdown(), p.join()
+
     # Test exceptions in jobs
     p = Pool(name='Test')
     for i in range(1000):