mirror of
https://github.com/kovidgoyal/calibre.git
synced 2025-07-09 03:04:10 -04:00
Add a facility to send common data (common to all jobs) to the workers in a pool
This commit is contained in:
parent
feb269da46
commit
7a66bbbd40
@ -52,6 +52,9 @@ class Worker(object):
|
|||||||
wr = WorkerResult(self.job_id, result, True, self)
|
wr = WorkerResult(self.job_id, result, True, self)
|
||||||
self.events.put(wr)
|
self.events.put(wr)
|
||||||
|
|
||||||
|
def set_common_data(self, data):
    ''' Push ``data`` to the worker process over this worker's connection.

    The object is handed to ``self.conn.send`` via ``eintr_retry_call``.
    On the receiving side, ``worker_main`` treats any received object that
    is not a ``Job`` as the new common data for subsequent jobs.

    NOTE(review): ``worker_main`` checks ``job is None`` (and stops its
    loop) before it checks for non-Job objects, so sending ``None`` here
    looks like it would shut the worker down rather than clear the common
    data -- confirm before relying on ``data=None``. '''
    send_to_worker = self.conn.send
    eintr_retry_call(send_to_worker, data)
|
||||||
|
|
||||||
class Pool(Thread):
|
class Pool(Thread):
|
||||||
|
|
||||||
daemon = True
|
daemon = True
|
||||||
@ -67,6 +70,7 @@ class Pool(Thread):
|
|||||||
self.results = Queue()
|
self.results = Queue()
|
||||||
self.tracker = Queue()
|
self.tracker = Queue()
|
||||||
self.terminal_failure = None
|
self.terminal_failure = None
|
||||||
|
self.common_data = None
|
||||||
|
|
||||||
self.start()
|
self.start()
|
||||||
|
|
||||||
@ -75,11 +79,28 @@ class Pool(Thread):
|
|||||||
p = start_pipe_worker(
|
p = start_pipe_worker(
|
||||||
'from {0} import run_main, {1}; run_main({1})'.format(self.__class__.__module__, 'worker_main'), stdout=None)
|
'from {0} import run_main, {1}; run_main({1})'.format(self.__class__.__module__, 'worker_main'), stdout=None)
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
eintr_retry_call(p.stdin.write, cPickle.dumps((self.address, self.auth_key), -1))
|
eintr_retry_call(p.stdin.write, cPickle.dumps((self.address, self.auth_key, self.common_data), -1))
|
||||||
p.stdin.flush(), p.stdin.close()
|
p.stdin.flush(), p.stdin.close()
|
||||||
conn = eintr_retry_call(self.listener.accept)
|
conn = eintr_retry_call(self.listener.accept)
|
||||||
return Worker(p, conn, self.events, self.name)
|
return Worker(p, conn, self.events, self.name)
|
||||||
|
|
||||||
|
def set_common_data(self, data=None):
    ''' Register data shared by every job that is queued from now on, so
    that it does not have to be sent along with each individual job.

    Call this before any jobs are queued; calling it while jobs are
    queued or running gives undefined behavior. It may be called again
    once all jobs have finished, in which case the new value applies to
    the next round of jobs.

    The value is stored on ``self.common_data`` (which is also pickled to
    newly created workers at startup) and is pushed to every worker
    currently in ``self.available_workers``. If pushing to a worker
    raises, a ``TerminalFailure`` carrying the traceback is recorded in
    ``self.terminal_failure``, ``self.terminal_error()`` is called and no
    further workers are contacted.

    NOTE(review): ``worker_main`` stops its loop when it receives
    ``None``, so calling this with ``data=None`` while workers are
    running looks like it would terminate them -- confirm. '''
    with self.lock:
        self.common_data = data
        for idle_worker in self.available_workers:
            try:
                idle_worker.set_common_data(data)
            except Exception:
                import traceback
                details = traceback.format_exc()
                self.terminal_failure = TerminalFailure(
                    'Worker process crashed while sending common data', details)
                self.terminal_error()
                # The pool is in a failed state; do not contact the rest.
                break
|
||||||
|
|
||||||
def start_worker(self):
|
def start_worker(self):
|
||||||
try:
|
try:
|
||||||
self.available_workers.append(self.create_worker())
|
self.available_workers.append(self.create_worker())
|
||||||
@ -217,7 +238,7 @@ class Pool(Thread):
|
|||||||
del self.available_workers[:]
|
del self.available_workers[:]
|
||||||
self.busy_workers.clear()
|
self.busy_workers.clear()
|
||||||
|
|
||||||
def worker_main(conn):
|
def worker_main(conn, common_data):
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
@ -231,6 +252,9 @@ def worker_main(conn):
|
|||||||
return 1
|
return 1
|
||||||
if job is None:
|
if job is None:
|
||||||
break
|
break
|
||||||
|
if not isinstance(job, Job):
|
||||||
|
common_data = job
|
||||||
|
continue
|
||||||
try:
|
try:
|
||||||
if '\n' in job.module:
|
if '\n' in job.module:
|
||||||
import_module('calibre.customize.ui') # Load plugins
|
import_module('calibre.customize.ui') # Load plugins
|
||||||
@ -239,6 +263,8 @@ def worker_main(conn):
|
|||||||
func = mod[job.func]
|
func = mod[job.func]
|
||||||
else:
|
else:
|
||||||
func = getattr(import_module(job.module), job.func)
|
func = getattr(import_module(job.module), job.func)
|
||||||
|
if common_data is not None:
|
||||||
|
job.kwargs['common_data'] = common_data
|
||||||
result = func(*job.args, **job.kwargs)
|
result = func(*job.args, **job.kwargs)
|
||||||
result = Result(result, None, None)
|
result = Result(result, None, None)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
@ -258,9 +284,9 @@ def worker_main(conn):
|
|||||||
def run_main(func):
|
def run_main(func):
|
||||||
from multiprocessing.connection import Client
|
from multiprocessing.connection import Client
|
||||||
from contextlib import closing
|
from contextlib import closing
|
||||||
address, key = cPickle.loads(eintr_retry_call(sys.stdin.read))
|
address, key, common_data = cPickle.loads(eintr_retry_call(sys.stdin.read))
|
||||||
with closing(Client(address, authkey=key)) as conn:
|
with closing(Client(address, authkey=key)) as conn:
|
||||||
raise SystemExit(func(conn))
|
raise SystemExit(func(conn, common_data))
|
||||||
|
|
||||||
def test():
|
def test():
|
||||||
def get_results(pool, ignore_fail=False):
|
def get_results(pool, ignore_fail=False):
|
||||||
@ -286,6 +312,21 @@ def test():
|
|||||||
raise SystemExit('%r != %r' % (expected_results, results))
|
raise SystemExit('%r != %r' % (expected_results, results))
|
||||||
p.shutdown(), p.join()
|
p.shutdown(), p.join()
|
||||||
|
|
||||||
|
# Test common_data
|
||||||
|
p = Pool(name='Test')
|
||||||
|
expected_results = {}
|
||||||
|
with p.lock:
|
||||||
|
p.start_worker()
|
||||||
|
p.set_common_data(7)
|
||||||
|
for i in range(1000):
|
||||||
|
p(i, 'def x(i, common_data=None):\n return common_data + i', 'x', i)
|
||||||
|
expected_results[i] = p.common_data + i
|
||||||
|
p.wait_for_tasks(30)
|
||||||
|
results = {k:v.value for k, v in get_results(p).iteritems()}
|
||||||
|
if results != expected_results:
|
||||||
|
raise SystemExit('%r != %r' % (expected_results, results))
|
||||||
|
p.shutdown(), p.join()
|
||||||
|
|
||||||
# Test exceptions in jobs
|
# Test exceptions in jobs
|
||||||
p = Pool(name='Test')
|
p = Pool(name='Test')
|
||||||
for i in range(1000):
|
for i in range(1000):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user