Handle large common data being sent to workers (30MB+)

This commit is contained in:
Kovid Goyal 2014-11-10 15:45:17 +05:30
parent b8e9bb0ca2
commit 9eae6a99c7

View File

@ -12,6 +12,7 @@ from collections import namedtuple
from Queue import Queue
from calibre import detect_ncpus, as_unicode, prints
from calibre.ptempfile import PersistentTemporaryFile
from calibre.utils import join_with_timeout
from calibre.utils.ipc import eintr_retry_call
@ -19,6 +20,9 @@ Job = namedtuple('Job', 'id module func args kwargs')
Result = namedtuple('Result', 'value err traceback')
WorkerResult = namedtuple('WorkerResult', 'id result is_terminal_failure worker')
TerminalFailure = namedtuple('TerminalFailure', 'message tb job_id')
File = namedtuple('File', 'name')
MAX_SIZE = 30 * 1024 * 1024 # max size of data to send over the connection (old versions of windows cannot handle arbitrary data lengths)
class Failure(Exception):
@ -36,7 +40,7 @@ class Worker(object):
self.name = name or ''
def __call__(self, job):
eintr_retry_call(self.conn.send, job)
eintr_retry_call(self.conn.send_bytes, cPickle.dumps(job, -1))
if job is not None:
self.job_id = job.id
t = Thread(target=self.recv, name='PoolWorker-'+self.name)
@ -45,7 +49,7 @@ class Worker(object):
def recv(self):
try:
result = eintr_retry_call(self.conn.recv)
result = cPickle.loads(eintr_retry_call(self.conn.recv_bytes))
wr = WorkerResult(self.job_id, result, False, self)
except Exception as err:
import traceback
@ -54,7 +58,7 @@ class Worker(object):
self.events.put(wr)
def set_common_data(self, data):
eintr_retry_call(self.conn.send, data)
eintr_retry_call(self.conn.send_bytes, data)
class Pool(Thread):
@ -71,7 +75,7 @@ class Pool(Thread):
self.results = Queue()
self.tracker = Queue()
self.terminal_failure = None
self.common_data = None
self.common_data = cPickle.dumps(None, -1)
self.worker_data = None
self.start()
@ -84,7 +88,10 @@ class Pool(Thread):
eintr_retry_call(p.stdin.write, self.worker_data)
p.stdin.flush(), p.stdin.close()
conn = eintr_retry_call(self.listener.accept)
return Worker(p, conn, self.events, self.name)
w = Worker(p, conn, self.events, self.name)
if self.common_data != cPickle.dumps(None, -1):
w.set_common_data(self.common_data)
return w
def set_common_data(self, data=None):
''' Set some data that will be passed to all subsequent jobs without
@ -94,11 +101,15 @@ class Pool(Thread):
jobs. Can raise the :class:`Failure` exception is data could not be
sent to workers.'''
with self.lock:
self.common_data = data
self.worker_data = cPickle.dumps((self.address, self.auth_key, self.common_data), -1)
self.common_data = cPickle.dumps(data, -1)
if len(self.common_data) > MAX_SIZE:
self.cd_file = PersistentTemporaryFile('pool_common_data')
with self.cd_file as f:
f.write(self.common_data)
self.common_data = cPickle.dumps(File(f.name), -1)
for worker in self.available_workers:
try:
worker.set_common_data(self.worker_data[-1])
worker.set_common_data(self.common_data)
except Exception:
import traceback
self.terminal_failure = TerminalFailure('Worker process crashed while sending common data', traceback.format_exc())
@ -118,7 +129,7 @@ class Pool(Thread):
from calibre.utils.ipc.server import create_listener
self.auth_key = os.urandom(32)
self.address, self.listener = create_listener(self.auth_key)
self.worker_data = cPickle.dumps((self.address, self.auth_key, self.common_data), -1)
self.worker_data = cPickle.dumps((self.address, self.auth_key), -1)
with self.lock:
if self.start_worker() is False:
return
@ -211,7 +222,7 @@ class Pool(Thread):
self.tracker.task_done()
while self.pending_jobs:
job = self.pending_jobs.pop()
self.results.put(WorkerResult(job.id, Result(None, None, None), True, worker))
self.results.put(WorkerResult(job.id, Result(None, None, None), True, None))
self.tracker.task_done()
self.shutdown_workers()
self.events.put(None)
@ -242,12 +253,18 @@ class Pool(Thread):
w.process.kill()
del self.available_workers[:]
self.busy_workers.clear()
if hasattr(self, 'cd_file'):
try:
os.remove(self.cd_file.name)
except EnvironmentError:
pass
def worker_main(conn, common_data):
def worker_main(conn):
from importlib import import_module
common_data = None
while True:
try:
job = eintr_retry_call(conn.recv)
job = cPickle.loads(eintr_retry_call(conn.recv_bytes))
except EOFError:
break
except Exception:
@ -258,7 +275,10 @@ def worker_main(conn, common_data):
if job is None:
break
if not isinstance(job, Job):
common_data = cPickle.loads(job)
if isinstance(job, File):
common_data = cPickle.load(open(job.name, 'rb'))
else:
common_data = job
continue
try:
if '\n' in job.module:
@ -276,7 +296,7 @@ def worker_main(conn, common_data):
import traceback
result = Result(None, as_unicode(err), traceback.format_exc())
try:
eintr_retry_call(conn.send, result)
eintr_retry_call(conn.send_bytes, cPickle.dumps(result, -1))
except EOFError:
break
except Exception:
@ -289,9 +309,9 @@ def worker_main(conn, common_data):
def run_main(func):
from multiprocessing.connection import Client
from contextlib import closing
address, key, common_data = cPickle.loads(eintr_retry_call(sys.stdin.read))
address, key = cPickle.loads(eintr_retry_call(sys.stdin.read))
with closing(Client(address, authkey=key)) as conn:
raise SystemExit(func(conn, common_data))
raise SystemExit(func(conn))
def test():
def get_results(pool, ignore_fail=False):
@ -325,13 +345,24 @@ def test():
p.set_common_data(7)
for i in range(1000):
p(i, 'def x(i, common_data=None):\n return common_data + i', 'x', i)
expected_results[i] = p.common_data + i
expected_results[i] = 7 + i
p.wait_for_tasks(30)
results = {k:v.value for k, v in get_results(p).iteritems()}
if results != expected_results:
raise SystemExit('%r != %r' % (expected_results, results))
p.shutdown(), p.join()
# Test large common data
p = Pool(name='Test')
data = b'a' * (4 * MAX_SIZE)
p.set_common_data(data)
p(0, 'def x(i, common_data=None):\n return len(common_data)', 'x', 0)
p.wait_for_tasks(30)
results = get_results(p)
if len(data) != results[0].value:
raise SystemExit('Common data was not returned correctly')
p.shutdown(), p.join()
# Test exceptions in jobs
p = Pool(name='Test')
for i in range(1000):
@ -361,3 +392,5 @@ def test():
if not p.failed:
raise SystemExit('No expected terminal failure')
p.shutdown(), p.join()
print ('Tests all passed!')