mirror of
https://github.com/kovidgoyal/calibre.git
synced 2025-07-09 03:04:10 -04:00
Handle large common data being sent to workers (30MB+)
This commit is contained in:
parent
b8e9bb0ca2
commit
9eae6a99c7
@ -12,6 +12,7 @@ from collections import namedtuple
|
|||||||
from Queue import Queue
|
from Queue import Queue
|
||||||
|
|
||||||
from calibre import detect_ncpus, as_unicode, prints
|
from calibre import detect_ncpus, as_unicode, prints
|
||||||
|
from calibre.ptempfile import PersistentTemporaryFile
|
||||||
from calibre.utils import join_with_timeout
|
from calibre.utils import join_with_timeout
|
||||||
from calibre.utils.ipc import eintr_retry_call
|
from calibre.utils.ipc import eintr_retry_call
|
||||||
|
|
||||||
@ -19,6 +20,9 @@ Job = namedtuple('Job', 'id module func args kwargs')
|
|||||||
Result = namedtuple('Result', 'value err traceback')
|
Result = namedtuple('Result', 'value err traceback')
|
||||||
WorkerResult = namedtuple('WorkerResult', 'id result is_terminal_failure worker')
|
WorkerResult = namedtuple('WorkerResult', 'id result is_terminal_failure worker')
|
||||||
TerminalFailure = namedtuple('TerminalFailure', 'message tb job_id')
|
TerminalFailure = namedtuple('TerminalFailure', 'message tb job_id')
|
||||||
|
File = namedtuple('File', 'name')
|
||||||
|
|
||||||
|
MAX_SIZE = 30 * 1024 * 1024 # max size of data to send over the connection (old versions of windows cannot handle arbitrary data lengths)
|
||||||
|
|
||||||
class Failure(Exception):
|
class Failure(Exception):
|
||||||
|
|
||||||
@ -36,7 +40,7 @@ class Worker(object):
|
|||||||
self.name = name or ''
|
self.name = name or ''
|
||||||
|
|
||||||
def __call__(self, job):
|
def __call__(self, job):
|
||||||
eintr_retry_call(self.conn.send, job)
|
eintr_retry_call(self.conn.send_bytes, cPickle.dumps(job, -1))
|
||||||
if job is not None:
|
if job is not None:
|
||||||
self.job_id = job.id
|
self.job_id = job.id
|
||||||
t = Thread(target=self.recv, name='PoolWorker-'+self.name)
|
t = Thread(target=self.recv, name='PoolWorker-'+self.name)
|
||||||
@ -45,7 +49,7 @@ class Worker(object):
|
|||||||
|
|
||||||
def recv(self):
|
def recv(self):
|
||||||
try:
|
try:
|
||||||
result = eintr_retry_call(self.conn.recv)
|
result = cPickle.loads(eintr_retry_call(self.conn.recv_bytes))
|
||||||
wr = WorkerResult(self.job_id, result, False, self)
|
wr = WorkerResult(self.job_id, result, False, self)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
import traceback
|
import traceback
|
||||||
@ -54,7 +58,7 @@ class Worker(object):
|
|||||||
self.events.put(wr)
|
self.events.put(wr)
|
||||||
|
|
||||||
def set_common_data(self, data):
|
def set_common_data(self, data):
|
||||||
eintr_retry_call(self.conn.send, data)
|
eintr_retry_call(self.conn.send_bytes, data)
|
||||||
|
|
||||||
class Pool(Thread):
|
class Pool(Thread):
|
||||||
|
|
||||||
@ -71,7 +75,7 @@ class Pool(Thread):
|
|||||||
self.results = Queue()
|
self.results = Queue()
|
||||||
self.tracker = Queue()
|
self.tracker = Queue()
|
||||||
self.terminal_failure = None
|
self.terminal_failure = None
|
||||||
self.common_data = None
|
self.common_data = cPickle.dumps(None, -1)
|
||||||
self.worker_data = None
|
self.worker_data = None
|
||||||
|
|
||||||
self.start()
|
self.start()
|
||||||
@ -84,7 +88,10 @@ class Pool(Thread):
|
|||||||
eintr_retry_call(p.stdin.write, self.worker_data)
|
eintr_retry_call(p.stdin.write, self.worker_data)
|
||||||
p.stdin.flush(), p.stdin.close()
|
p.stdin.flush(), p.stdin.close()
|
||||||
conn = eintr_retry_call(self.listener.accept)
|
conn = eintr_retry_call(self.listener.accept)
|
||||||
return Worker(p, conn, self.events, self.name)
|
w = Worker(p, conn, self.events, self.name)
|
||||||
|
if self.common_data != cPickle.dumps(None, -1):
|
||||||
|
w.set_common_data(self.common_data)
|
||||||
|
return w
|
||||||
|
|
||||||
def set_common_data(self, data=None):
|
def set_common_data(self, data=None):
|
||||||
''' Set some data that will be passed to all subsequent jobs without
|
''' Set some data that will be passed to all subsequent jobs without
|
||||||
@ -94,11 +101,15 @@ class Pool(Thread):
|
|||||||
jobs. Can raise the :class:`Failure` exception is data could not be
|
jobs. Can raise the :class:`Failure` exception is data could not be
|
||||||
sent to workers.'''
|
sent to workers.'''
|
||||||
with self.lock:
|
with self.lock:
|
||||||
self.common_data = data
|
self.common_data = cPickle.dumps(data, -1)
|
||||||
self.worker_data = cPickle.dumps((self.address, self.auth_key, self.common_data), -1)
|
if len(self.common_data) > MAX_SIZE:
|
||||||
|
self.cd_file = PersistentTemporaryFile('pool_common_data')
|
||||||
|
with self.cd_file as f:
|
||||||
|
f.write(self.common_data)
|
||||||
|
self.common_data = cPickle.dumps(File(f.name), -1)
|
||||||
for worker in self.available_workers:
|
for worker in self.available_workers:
|
||||||
try:
|
try:
|
||||||
worker.set_common_data(self.worker_data[-1])
|
worker.set_common_data(self.common_data)
|
||||||
except Exception:
|
except Exception:
|
||||||
import traceback
|
import traceback
|
||||||
self.terminal_failure = TerminalFailure('Worker process crashed while sending common data', traceback.format_exc())
|
self.terminal_failure = TerminalFailure('Worker process crashed while sending common data', traceback.format_exc())
|
||||||
@ -118,7 +129,7 @@ class Pool(Thread):
|
|||||||
from calibre.utils.ipc.server import create_listener
|
from calibre.utils.ipc.server import create_listener
|
||||||
self.auth_key = os.urandom(32)
|
self.auth_key = os.urandom(32)
|
||||||
self.address, self.listener = create_listener(self.auth_key)
|
self.address, self.listener = create_listener(self.auth_key)
|
||||||
self.worker_data = cPickle.dumps((self.address, self.auth_key, self.common_data), -1)
|
self.worker_data = cPickle.dumps((self.address, self.auth_key), -1)
|
||||||
with self.lock:
|
with self.lock:
|
||||||
if self.start_worker() is False:
|
if self.start_worker() is False:
|
||||||
return
|
return
|
||||||
@ -211,7 +222,7 @@ class Pool(Thread):
|
|||||||
self.tracker.task_done()
|
self.tracker.task_done()
|
||||||
while self.pending_jobs:
|
while self.pending_jobs:
|
||||||
job = self.pending_jobs.pop()
|
job = self.pending_jobs.pop()
|
||||||
self.results.put(WorkerResult(job.id, Result(None, None, None), True, worker))
|
self.results.put(WorkerResult(job.id, Result(None, None, None), True, None))
|
||||||
self.tracker.task_done()
|
self.tracker.task_done()
|
||||||
self.shutdown_workers()
|
self.shutdown_workers()
|
||||||
self.events.put(None)
|
self.events.put(None)
|
||||||
@ -242,12 +253,18 @@ class Pool(Thread):
|
|||||||
w.process.kill()
|
w.process.kill()
|
||||||
del self.available_workers[:]
|
del self.available_workers[:]
|
||||||
self.busy_workers.clear()
|
self.busy_workers.clear()
|
||||||
|
if hasattr(self, 'cd_file'):
|
||||||
|
try:
|
||||||
|
os.remove(self.cd_file.name)
|
||||||
|
except EnvironmentError:
|
||||||
|
pass
|
||||||
|
|
||||||
def worker_main(conn, common_data):
|
def worker_main(conn):
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
|
common_data = None
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
job = eintr_retry_call(conn.recv)
|
job = cPickle.loads(eintr_retry_call(conn.recv_bytes))
|
||||||
except EOFError:
|
except EOFError:
|
||||||
break
|
break
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -258,7 +275,10 @@ def worker_main(conn, common_data):
|
|||||||
if job is None:
|
if job is None:
|
||||||
break
|
break
|
||||||
if not isinstance(job, Job):
|
if not isinstance(job, Job):
|
||||||
common_data = cPickle.loads(job)
|
if isinstance(job, File):
|
||||||
|
common_data = cPickle.load(open(job.name, 'rb'))
|
||||||
|
else:
|
||||||
|
common_data = job
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
if '\n' in job.module:
|
if '\n' in job.module:
|
||||||
@ -276,7 +296,7 @@ def worker_main(conn, common_data):
|
|||||||
import traceback
|
import traceback
|
||||||
result = Result(None, as_unicode(err), traceback.format_exc())
|
result = Result(None, as_unicode(err), traceback.format_exc())
|
||||||
try:
|
try:
|
||||||
eintr_retry_call(conn.send, result)
|
eintr_retry_call(conn.send_bytes, cPickle.dumps(result, -1))
|
||||||
except EOFError:
|
except EOFError:
|
||||||
break
|
break
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -289,9 +309,9 @@ def worker_main(conn, common_data):
|
|||||||
def run_main(func):
|
def run_main(func):
|
||||||
from multiprocessing.connection import Client
|
from multiprocessing.connection import Client
|
||||||
from contextlib import closing
|
from contextlib import closing
|
||||||
address, key, common_data = cPickle.loads(eintr_retry_call(sys.stdin.read))
|
address, key = cPickle.loads(eintr_retry_call(sys.stdin.read))
|
||||||
with closing(Client(address, authkey=key)) as conn:
|
with closing(Client(address, authkey=key)) as conn:
|
||||||
raise SystemExit(func(conn, common_data))
|
raise SystemExit(func(conn))
|
||||||
|
|
||||||
def test():
|
def test():
|
||||||
def get_results(pool, ignore_fail=False):
|
def get_results(pool, ignore_fail=False):
|
||||||
@ -325,13 +345,24 @@ def test():
|
|||||||
p.set_common_data(7)
|
p.set_common_data(7)
|
||||||
for i in range(1000):
|
for i in range(1000):
|
||||||
p(i, 'def x(i, common_data=None):\n return common_data + i', 'x', i)
|
p(i, 'def x(i, common_data=None):\n return common_data + i', 'x', i)
|
||||||
expected_results[i] = p.common_data + i
|
expected_results[i] = 7 + i
|
||||||
p.wait_for_tasks(30)
|
p.wait_for_tasks(30)
|
||||||
results = {k:v.value for k, v in get_results(p).iteritems()}
|
results = {k:v.value for k, v in get_results(p).iteritems()}
|
||||||
if results != expected_results:
|
if results != expected_results:
|
||||||
raise SystemExit('%r != %r' % (expected_results, results))
|
raise SystemExit('%r != %r' % (expected_results, results))
|
||||||
p.shutdown(), p.join()
|
p.shutdown(), p.join()
|
||||||
|
|
||||||
|
# Test large common data
|
||||||
|
p = Pool(name='Test')
|
||||||
|
data = b'a' * (4 * MAX_SIZE)
|
||||||
|
p.set_common_data(data)
|
||||||
|
p(0, 'def x(i, common_data=None):\n return len(common_data)', 'x', 0)
|
||||||
|
p.wait_for_tasks(30)
|
||||||
|
results = get_results(p)
|
||||||
|
if len(data) != results[0].value:
|
||||||
|
raise SystemExit('Common data was not returned correctly')
|
||||||
|
p.shutdown(), p.join()
|
||||||
|
|
||||||
# Test exceptions in jobs
|
# Test exceptions in jobs
|
||||||
p = Pool(name='Test')
|
p = Pool(name='Test')
|
||||||
for i in range(1000):
|
for i in range(1000):
|
||||||
@ -361,3 +392,5 @@ def test():
|
|||||||
if not p.failed:
|
if not p.failed:
|
||||||
raise SystemExit('No expected terminal failure')
|
raise SystemExit('No expected terminal failure')
|
||||||
p.shutdown(), p.join()
|
p.shutdown(), p.join()
|
||||||
|
|
||||||
|
print ('Tests all passed!')
|
||||||
|
Loading…
x
Reference in New Issue
Block a user