Handle large common data being sent to workers (30MB+)

This commit is contained in:
Kovid Goyal 2014-11-10 15:45:17 +05:30
parent b8e9bb0ca2
commit 9eae6a99c7

View File

@ -12,6 +12,7 @@ from collections import namedtuple
from Queue import Queue from Queue import Queue
from calibre import detect_ncpus, as_unicode, prints from calibre import detect_ncpus, as_unicode, prints
from calibre.ptempfile import PersistentTemporaryFile
from calibre.utils import join_with_timeout from calibre.utils import join_with_timeout
from calibre.utils.ipc import eintr_retry_call from calibre.utils.ipc import eintr_retry_call
@ -19,6 +20,9 @@ Job = namedtuple('Job', 'id module func args kwargs')
Result = namedtuple('Result', 'value err traceback') Result = namedtuple('Result', 'value err traceback')
WorkerResult = namedtuple('WorkerResult', 'id result is_terminal_failure worker') WorkerResult = namedtuple('WorkerResult', 'id result is_terminal_failure worker')
TerminalFailure = namedtuple('TerminalFailure', 'message tb job_id') TerminalFailure = namedtuple('TerminalFailure', 'message tb job_id')
File = namedtuple('File', 'name')
MAX_SIZE = 30 * 1024 * 1024 # max size of data to send over the connection (old versions of windows cannot handle arbitrary data lengths)
class Failure(Exception): class Failure(Exception):
@ -36,7 +40,7 @@ class Worker(object):
self.name = name or '' self.name = name or ''
def __call__(self, job): def __call__(self, job):
eintr_retry_call(self.conn.send, job) eintr_retry_call(self.conn.send_bytes, cPickle.dumps(job, -1))
if job is not None: if job is not None:
self.job_id = job.id self.job_id = job.id
t = Thread(target=self.recv, name='PoolWorker-'+self.name) t = Thread(target=self.recv, name='PoolWorker-'+self.name)
@ -45,7 +49,7 @@ class Worker(object):
def recv(self): def recv(self):
try: try:
result = eintr_retry_call(self.conn.recv) result = cPickle.loads(eintr_retry_call(self.conn.recv_bytes))
wr = WorkerResult(self.job_id, result, False, self) wr = WorkerResult(self.job_id, result, False, self)
except Exception as err: except Exception as err:
import traceback import traceback
@ -54,7 +58,7 @@ class Worker(object):
self.events.put(wr) self.events.put(wr)
def set_common_data(self, data): def set_common_data(self, data):
eintr_retry_call(self.conn.send, data) eintr_retry_call(self.conn.send_bytes, data)
class Pool(Thread): class Pool(Thread):
@ -71,7 +75,7 @@ class Pool(Thread):
self.results = Queue() self.results = Queue()
self.tracker = Queue() self.tracker = Queue()
self.terminal_failure = None self.terminal_failure = None
self.common_data = None self.common_data = cPickle.dumps(None, -1)
self.worker_data = None self.worker_data = None
self.start() self.start()
@ -84,7 +88,10 @@ class Pool(Thread):
eintr_retry_call(p.stdin.write, self.worker_data) eintr_retry_call(p.stdin.write, self.worker_data)
p.stdin.flush(), p.stdin.close() p.stdin.flush(), p.stdin.close()
conn = eintr_retry_call(self.listener.accept) conn = eintr_retry_call(self.listener.accept)
return Worker(p, conn, self.events, self.name) w = Worker(p, conn, self.events, self.name)
if self.common_data != cPickle.dumps(None, -1):
w.set_common_data(self.common_data)
return w
def set_common_data(self, data=None): def set_common_data(self, data=None):
''' Set some data that will be passed to all subsequent jobs without ''' Set some data that will be passed to all subsequent jobs without
@ -94,11 +101,15 @@ class Pool(Thread):
jobs. Can raise the :class:`Failure` exception if data could not be jobs. Can raise the :class:`Failure` exception if data could not be
sent to workers.''' sent to workers.'''
with self.lock: with self.lock:
self.common_data = data self.common_data = cPickle.dumps(data, -1)
self.worker_data = cPickle.dumps((self.address, self.auth_key, self.common_data), -1) if len(self.common_data) > MAX_SIZE:
self.cd_file = PersistentTemporaryFile('pool_common_data')
with self.cd_file as f:
f.write(self.common_data)
self.common_data = cPickle.dumps(File(f.name), -1)
for worker in self.available_workers: for worker in self.available_workers:
try: try:
worker.set_common_data(self.worker_data[-1]) worker.set_common_data(self.common_data)
except Exception: except Exception:
import traceback import traceback
self.terminal_failure = TerminalFailure('Worker process crashed while sending common data', traceback.format_exc()) self.terminal_failure = TerminalFailure('Worker process crashed while sending common data', traceback.format_exc())
@ -118,7 +129,7 @@ class Pool(Thread):
from calibre.utils.ipc.server import create_listener from calibre.utils.ipc.server import create_listener
self.auth_key = os.urandom(32) self.auth_key = os.urandom(32)
self.address, self.listener = create_listener(self.auth_key) self.address, self.listener = create_listener(self.auth_key)
self.worker_data = cPickle.dumps((self.address, self.auth_key, self.common_data), -1) self.worker_data = cPickle.dumps((self.address, self.auth_key), -1)
with self.lock: with self.lock:
if self.start_worker() is False: if self.start_worker() is False:
return return
@ -211,7 +222,7 @@ class Pool(Thread):
self.tracker.task_done() self.tracker.task_done()
while self.pending_jobs: while self.pending_jobs:
job = self.pending_jobs.pop() job = self.pending_jobs.pop()
self.results.put(WorkerResult(job.id, Result(None, None, None), True, worker)) self.results.put(WorkerResult(job.id, Result(None, None, None), True, None))
self.tracker.task_done() self.tracker.task_done()
self.shutdown_workers() self.shutdown_workers()
self.events.put(None) self.events.put(None)
@ -242,12 +253,18 @@ class Pool(Thread):
w.process.kill() w.process.kill()
del self.available_workers[:] del self.available_workers[:]
self.busy_workers.clear() self.busy_workers.clear()
if hasattr(self, 'cd_file'):
try:
os.remove(self.cd_file.name)
except EnvironmentError:
pass
def worker_main(conn, common_data): def worker_main(conn):
from importlib import import_module from importlib import import_module
common_data = None
while True: while True:
try: try:
job = eintr_retry_call(conn.recv) job = cPickle.loads(eintr_retry_call(conn.recv_bytes))
except EOFError: except EOFError:
break break
except Exception: except Exception:
@ -258,7 +275,10 @@ def worker_main(conn, common_data):
if job is None: if job is None:
break break
if not isinstance(job, Job): if not isinstance(job, Job):
common_data = cPickle.loads(job) if isinstance(job, File):
common_data = cPickle.load(open(job.name, 'rb'))
else:
common_data = job
continue continue
try: try:
if '\n' in job.module: if '\n' in job.module:
@ -276,7 +296,7 @@ def worker_main(conn, common_data):
import traceback import traceback
result = Result(None, as_unicode(err), traceback.format_exc()) result = Result(None, as_unicode(err), traceback.format_exc())
try: try:
eintr_retry_call(conn.send, result) eintr_retry_call(conn.send_bytes, cPickle.dumps(result, -1))
except EOFError: except EOFError:
break break
except Exception: except Exception:
@ -289,9 +309,9 @@ def worker_main(conn, common_data):
def run_main(func): def run_main(func):
from multiprocessing.connection import Client from multiprocessing.connection import Client
from contextlib import closing from contextlib import closing
address, key, common_data = cPickle.loads(eintr_retry_call(sys.stdin.read)) address, key = cPickle.loads(eintr_retry_call(sys.stdin.read))
with closing(Client(address, authkey=key)) as conn: with closing(Client(address, authkey=key)) as conn:
raise SystemExit(func(conn, common_data)) raise SystemExit(func(conn))
def test(): def test():
def get_results(pool, ignore_fail=False): def get_results(pool, ignore_fail=False):
@ -325,13 +345,24 @@ def test():
p.set_common_data(7) p.set_common_data(7)
for i in range(1000): for i in range(1000):
p(i, 'def x(i, common_data=None):\n return common_data + i', 'x', i) p(i, 'def x(i, common_data=None):\n return common_data + i', 'x', i)
expected_results[i] = p.common_data + i expected_results[i] = 7 + i
p.wait_for_tasks(30) p.wait_for_tasks(30)
results = {k:v.value for k, v in get_results(p).iteritems()} results = {k:v.value for k, v in get_results(p).iteritems()}
if results != expected_results: if results != expected_results:
raise SystemExit('%r != %r' % (expected_results, results)) raise SystemExit('%r != %r' % (expected_results, results))
p.shutdown(), p.join() p.shutdown(), p.join()
# Test large common data
p = Pool(name='Test')
data = b'a' * (4 * MAX_SIZE)
p.set_common_data(data)
p(0, 'def x(i, common_data=None):\n return len(common_data)', 'x', 0)
p.wait_for_tasks(30)
results = get_results(p)
if len(data) != results[0].value:
raise SystemExit('Common data was not returned correctly')
p.shutdown(), p.join()
# Test exceptions in jobs # Test exceptions in jobs
p = Pool(name='Test') p = Pool(name='Test')
for i in range(1000): for i in range(1000):
@ -361,3 +392,5 @@ def test():
if not p.failed: if not p.failed:
raise SystemExit('No expected terminal failure') raise SystemExit('No expected terminal failure')
p.shutdown(), p.join() p.shutdown(), p.join()
print ('Tests all passed!')