Eliminate deadlocks in the pool at the cost of a small chance to leak worker processes

This commit is contained in:
Kovid Goyal 2014-11-11 00:53:14 +05:30
parent 1db7a9e32d
commit 46fe62e45f

View File

@ -7,7 +7,7 @@ __license__ = 'GPL v3'
__copyright__ = '2014, Kovid Goyal <kovid at kovidgoyal.net>'
import os, cPickle, sys
from threading import Thread, RLock
from threading import Thread
from collections import namedtuple
from Queue import Queue
@ -66,7 +66,6 @@ class Pool(Thread):
def __init__(self, max_workers=None, name=None):
Thread.__init__(self, name=name)
self.lock = RLock()
self.max_workers = max_workers or detect_ncpus()
self.available_workers = []
self.busy_workers = {}
@ -77,6 +76,7 @@ class Pool(Thread):
self.terminal_failure = None
self.common_data = cPickle.dumps(None, -1)
self.worker_data = None
self.shutting_down = False
self.start()
@ -87,7 +87,6 @@ class Pool(Thread):
after all jobs are done, then it will be used for the new round of
jobs. Can raise the :class:`Failure` exception is data could not be
sent to workers.'''
with self.lock:
if self.failed:
raise Failure(self.terminal_failure)
self.events.put(data)
@ -105,10 +104,9 @@ class Pool(Thread):
:param func: Name of the function from ``module`` that will be
executed. ``args`` and ``kwargs`` will be passed to the function.
'''
job = Job(job_id, module, func, args, kwargs)
with self.lock:
if self.failed:
raise Failure(self.terminal_failure)
job = Job(job_id, module, func, args, kwargs)
self.tracker.put(None)
self.events.put(job)
@ -124,6 +122,13 @@ class Pool(Thread):
else:
join_with_timeout(self.tracker, timeout)
def shutdown(self):
''' Shutdown this pool, terminating all worker process. The pool cannot
be used after a shutdown. '''
self.shutting_down = True
self.events.put(None)
self.shutdown_workers()
def create_worker(self):
from calibre.utils.ipc.simple_worker import start_pipe_worker
p = start_pipe_worker(
@ -139,7 +144,9 @@ class Pool(Thread):
def start_worker(self):
try:
self.available_workers.append(self.create_worker())
w = self.create_worker()
if not self.shutting_down:
self.available_workers.append(w)
except Exception:
import traceback
self.terminal_failure = TerminalFailure('Failed to start worker process', traceback.format_exc(), None)
@ -151,19 +158,15 @@ class Pool(Thread):
self.auth_key = os.urandom(32)
self.address, self.listener = create_listener(self.auth_key)
self.worker_data = cPickle.dumps((self.address, self.auth_key), -1)
with self.lock:
if self.start_worker() is False:
return
while True:
event = self.events.get()
with self.lock:
if event is None:
if event is None or self.shutting_down:
break
if self.handle_event(event) is False:
break
with self.lock:
self.shutdown_workers()
def handle_event(self, event):
if isinstance(event, Job):
@ -228,33 +231,38 @@ class Pool(Thread):
job = self.pending_jobs.pop()
self.results.put(WorkerResult(job.id, Result(None, None, None), True, None))
self.tracker.task_done()
self.shutdown_workers()
self.events.put(None)
def shutdown(self):
with self.lock:
self.events.put(None)
self.shutdown_workers()
self.shutdown()
def shutdown_workers(self, wait_time=0.1):
for worker in self.available_workers:
for worker in self.busy_workers:
if worker.process.poll() is None:
try:
worker(None)
worker.process.terminate()
except EnvironmentError:
pass # If the process has already been killed
workers = [w.process for w in self.available_workers + list(self.busy_workers)]
aw = list(self.available_workers)
def join():
for w in aw:
try:
w(None)
except Exception:
pass
for worker in self.busy_workers:
worker.process.terminate()
workers = [w.process for w in self.available_workers + list(self.busy_workers)]
def join():
for w in workers:
try:
w.wait()
except Exception:
pass
reaper = Thread(target=join, name='ReapPoolWorkers')
reaper.daemon = True
reaper.start()
reaper.join(wait_time)
for w in self.available_workers:
if w.process.poll() is None:
w.process.kill()
for w in workers:
if w.poll() is None:
try:
w.kill()
except EnvironmentError:
pass
del self.available_workers[:]
self.busy_workers.clear()
if hasattr(self, 'cd_file'):
@ -344,7 +352,6 @@ def test():
# Test common_data
p = Pool(name='Test')
expected_results = {}
with p.lock:
p.start_worker()
p.set_common_data(7)
for i in range(1000):