Eliminate deadlocks in the pool at the cost of a small chance to leak worker processes

This commit is contained in:
Kovid Goyal 2014-11-11 00:53:14 +05:30
parent 1db7a9e32d
commit 46fe62e45f

View File

@ -7,7 +7,7 @@ __license__ = 'GPL v3'
__copyright__ = '2014, Kovid Goyal <kovid at kovidgoyal.net>' __copyright__ = '2014, Kovid Goyal <kovid at kovidgoyal.net>'
import os, cPickle, sys import os, cPickle, sys
from threading import Thread, RLock from threading import Thread
from collections import namedtuple from collections import namedtuple
from Queue import Queue from Queue import Queue
@ -66,7 +66,6 @@ class Pool(Thread):
def __init__(self, max_workers=None, name=None): def __init__(self, max_workers=None, name=None):
Thread.__init__(self, name=name) Thread.__init__(self, name=name)
self.lock = RLock()
self.max_workers = max_workers or detect_ncpus() self.max_workers = max_workers or detect_ncpus()
self.available_workers = [] self.available_workers = []
self.busy_workers = {} self.busy_workers = {}
@ -77,6 +76,7 @@ class Pool(Thread):
self.terminal_failure = None self.terminal_failure = None
self.common_data = cPickle.dumps(None, -1) self.common_data = cPickle.dumps(None, -1)
self.worker_data = None self.worker_data = None
self.shutting_down = False
self.start() self.start()
@ -87,7 +87,6 @@ class Pool(Thread):
after all jobs are done, then it will be used for the new round of after all jobs are done, then it will be used for the new round of
jobs. Can raise the :class:`Failure` exception if data could not be jobs. Can raise the :class:`Failure` exception if data could not be
sent to workers.''' sent to workers.'''
with self.lock:
if self.failed: if self.failed:
raise Failure(self.terminal_failure) raise Failure(self.terminal_failure)
self.events.put(data) self.events.put(data)
@ -105,10 +104,9 @@ class Pool(Thread):
:param func: Name of the function from ``module`` that will be :param func: Name of the function from ``module`` that will be
executed. ``args`` and ``kwargs`` will be passed to the function. executed. ``args`` and ``kwargs`` will be passed to the function.
''' '''
job = Job(job_id, module, func, args, kwargs)
with self.lock:
if self.failed: if self.failed:
raise Failure(self.terminal_failure) raise Failure(self.terminal_failure)
job = Job(job_id, module, func, args, kwargs)
self.tracker.put(None) self.tracker.put(None)
self.events.put(job) self.events.put(job)
@ -124,6 +122,13 @@ class Pool(Thread):
else: else:
join_with_timeout(self.tracker, timeout) join_with_timeout(self.tracker, timeout)
def shutdown(self):
''' Shutdown this pool, terminating all worker processes. The pool cannot
be used after a shutdown. '''
self.shutting_down = True
self.events.put(None)
self.shutdown_workers()
def create_worker(self): def create_worker(self):
from calibre.utils.ipc.simple_worker import start_pipe_worker from calibre.utils.ipc.simple_worker import start_pipe_worker
p = start_pipe_worker( p = start_pipe_worker(
@ -139,7 +144,9 @@ class Pool(Thread):
def start_worker(self): def start_worker(self):
try: try:
self.available_workers.append(self.create_worker()) w = self.create_worker()
if not self.shutting_down:
self.available_workers.append(w)
except Exception: except Exception:
import traceback import traceback
self.terminal_failure = TerminalFailure('Failed to start worker process', traceback.format_exc(), None) self.terminal_failure = TerminalFailure('Failed to start worker process', traceback.format_exc(), None)
@ -151,19 +158,15 @@ class Pool(Thread):
self.auth_key = os.urandom(32) self.auth_key = os.urandom(32)
self.address, self.listener = create_listener(self.auth_key) self.address, self.listener = create_listener(self.auth_key)
self.worker_data = cPickle.dumps((self.address, self.auth_key), -1) self.worker_data = cPickle.dumps((self.address, self.auth_key), -1)
with self.lock:
if self.start_worker() is False: if self.start_worker() is False:
return return
while True: while True:
event = self.events.get() event = self.events.get()
with self.lock: if event is None or self.shutting_down:
if event is None:
break break
if self.handle_event(event) is False: if self.handle_event(event) is False:
break break
with self.lock:
self.shutdown_workers()
def handle_event(self, event): def handle_event(self, event):
if isinstance(event, Job): if isinstance(event, Job):
@ -228,33 +231,38 @@ class Pool(Thread):
job = self.pending_jobs.pop() job = self.pending_jobs.pop()
self.results.put(WorkerResult(job.id, Result(None, None, None), True, None)) self.results.put(WorkerResult(job.id, Result(None, None, None), True, None))
self.tracker.task_done() self.tracker.task_done()
self.shutdown_workers() self.shutdown()
self.events.put(None)
def shutdown(self):
with self.lock:
self.events.put(None)
self.shutdown_workers()
def shutdown_workers(self, wait_time=0.1): def shutdown_workers(self, wait_time=0.1):
for worker in self.available_workers: for worker in self.busy_workers:
if worker.process.poll() is None:
try: try:
worker(None) worker.process.terminate()
except EnvironmentError:
pass # If the process has already been killed
workers = [w.process for w in self.available_workers + list(self.busy_workers)]
aw = list(self.available_workers)
def join():
for w in aw:
try:
w(None)
except Exception: except Exception:
pass pass
for worker in self.busy_workers:
worker.process.terminate()
workers = [w.process for w in self.available_workers + list(self.busy_workers)]
def join():
for w in workers: for w in workers:
try:
w.wait() w.wait()
except Exception:
pass
reaper = Thread(target=join, name='ReapPoolWorkers') reaper = Thread(target=join, name='ReapPoolWorkers')
reaper.daemon = True reaper.daemon = True
reaper.start() reaper.start()
reaper.join(wait_time) reaper.join(wait_time)
for w in self.available_workers: for w in workers:
if w.process.poll() is None: if w.poll() is None:
w.process.kill() try:
w.kill()
except EnvironmentError:
pass
del self.available_workers[:] del self.available_workers[:]
self.busy_workers.clear() self.busy_workers.clear()
if hasattr(self, 'cd_file'): if hasattr(self, 'cd_file'):
@ -344,7 +352,6 @@ def test():
# Test common_data # Test common_data
p = Pool(name='Test') p = Pool(name='Test')
expected_results = {} expected_results = {}
with p.lock:
p.start_worker() p.start_worker()
p.set_common_data(7) p.set_common_data(7)
for i in range(1000): for i in range(1000):