mirror of
https://github.com/kovidgoyal/calibre.git
synced 2025-07-08 18:54:09 -04:00
Add a facility to send common data (common to all jobs) to the workers in a pool
This commit is contained in:
parent
feb269da46
commit
7a66bbbd40
@ -52,6 +52,9 @@ class Worker(object):
|
||||
wr = WorkerResult(self.job_id, result, True, self)
|
||||
self.events.put(wr)
|
||||
|
||||
def set_common_data(self, data):
|
||||
eintr_retry_call(self.conn.send, data)
|
||||
|
||||
class Pool(Thread):
|
||||
|
||||
daemon = True
|
||||
@ -67,6 +70,7 @@ class Pool(Thread):
|
||||
self.results = Queue()
|
||||
self.tracker = Queue()
|
||||
self.terminal_failure = None
|
||||
self.common_data = None
|
||||
|
||||
self.start()
|
||||
|
||||
@ -75,11 +79,28 @@ class Pool(Thread):
|
||||
p = start_pipe_worker(
|
||||
'from {0} import run_main, {1}; run_main({1})'.format(self.__class__.__module__, 'worker_main'), stdout=None)
|
||||
sys.stdout.flush()
|
||||
eintr_retry_call(p.stdin.write, cPickle.dumps((self.address, self.auth_key), -1))
|
||||
eintr_retry_call(p.stdin.write, cPickle.dumps((self.address, self.auth_key, self.common_data), -1))
|
||||
p.stdin.flush(), p.stdin.close()
|
||||
conn = eintr_retry_call(self.listener.accept)
|
||||
return Worker(p, conn, self.events, self.name)
|
||||
|
||||
def set_common_data(self, data=None):
|
||||
''' Set some data that will be passed to all subsequent jobs without
|
||||
needing to be transmitted every time. You must call this method before
|
||||
queueing any jobs, otherwise the behavior is undefined. You can call it
|
||||
after all jobs are done, then it will be used for the new round of
|
||||
jobs. '''
|
||||
with self.lock:
|
||||
self.common_data = data
|
||||
for worker in self.available_workers:
|
||||
try:
|
||||
worker.set_common_data(data)
|
||||
except Exception:
|
||||
import traceback
|
||||
self.terminal_failure = TerminalFailure('Worker process crashed while sending common data', traceback.format_exc())
|
||||
self.terminal_error()
|
||||
break
|
||||
|
||||
def start_worker(self):
|
||||
try:
|
||||
self.available_workers.append(self.create_worker())
|
||||
@ -217,7 +238,7 @@ class Pool(Thread):
|
||||
del self.available_workers[:]
|
||||
self.busy_workers.clear()
|
||||
|
||||
def worker_main(conn):
|
||||
def worker_main(conn, common_data):
|
||||
from importlib import import_module
|
||||
while True:
|
||||
try:
|
||||
@ -231,6 +252,9 @@ def worker_main(conn):
|
||||
return 1
|
||||
if job is None:
|
||||
break
|
||||
if not isinstance(job, Job):
|
||||
common_data = job
|
||||
continue
|
||||
try:
|
||||
if '\n' in job.module:
|
||||
import_module('calibre.customize.ui') # Load plugins
|
||||
@ -239,6 +263,8 @@ def worker_main(conn):
|
||||
func = mod[job.func]
|
||||
else:
|
||||
func = getattr(import_module(job.module), job.func)
|
||||
if common_data is not None:
|
||||
job.kwargs['common_data'] = common_data
|
||||
result = func(*job.args, **job.kwargs)
|
||||
result = Result(result, None, None)
|
||||
except Exception as err:
|
||||
@ -258,9 +284,9 @@ def worker_main(conn):
|
||||
def run_main(func):
|
||||
from multiprocessing.connection import Client
|
||||
from contextlib import closing
|
||||
address, key = cPickle.loads(eintr_retry_call(sys.stdin.read))
|
||||
address, key, common_data = cPickle.loads(eintr_retry_call(sys.stdin.read))
|
||||
with closing(Client(address, authkey=key)) as conn:
|
||||
raise SystemExit(func(conn))
|
||||
raise SystemExit(func(conn, common_data))
|
||||
|
||||
def test():
|
||||
def get_results(pool, ignore_fail=False):
|
||||
@ -286,6 +312,21 @@ def test():
|
||||
raise SystemExit('%r != %r' % (expected_results, results))
|
||||
p.shutdown(), p.join()
|
||||
|
||||
# Test common_data
|
||||
p = Pool(name='Test')
|
||||
expected_results = {}
|
||||
with p.lock:
|
||||
p.start_worker()
|
||||
p.set_common_data(7)
|
||||
for i in range(1000):
|
||||
p(i, 'def x(i, common_data=None):\n return common_data + i', 'x', i)
|
||||
expected_results[i] = p.common_data + i
|
||||
p.wait_for_tasks(30)
|
||||
results = {k:v.value for k, v in get_results(p).iteritems()}
|
||||
if results != expected_results:
|
||||
raise SystemExit('%r != %r' % (expected_results, results))
|
||||
p.shutdown(), p.join()
|
||||
|
||||
# Test exceptions in jobs
|
||||
p = Pool(name='Test')
|
||||
for i in range(1000):
|
||||
|
Loading…
x
Reference in New Issue
Block a user