From 7c5756c9d33b7e338336b5bf3b3000cf057de954 Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Sun, 6 Apr 2025 13:39:39 +0530 Subject: [PATCH] Fix derivation of batches in forked_map --- src/calibre/utils/forked_map.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/src/calibre/utils/forked_map.py b/src/calibre/utils/forked_map.py index 6b530aefa2..b3cef9161e 100644 --- a/src/calibre/utils/forked_map.py +++ b/src/calibre/utils/forked_map.py @@ -116,19 +116,21 @@ def forked_map(fn: Callable[[T], R], iterable: T, *iterables: T, timeout: int | Should be used only in worker processes that have no threads and that do not use/import any non fork safe libraries such as macOS system libraries. ''' + num_items = len(iterable) + sum(map(len, iterables)) if num_workers <= 0: - num_workers = max(1, os.cpu_count()) - chunk_size = max(1, len(iterables) / num_workers) + num_workers = max(1, min(num_items, os.cpu_count())) + chunk_size = max(1, num_items // num_workers) groups = batched((Job(i, arg, fn) for i, arg in enumerate(chain(iterable, *iterables))), chunk_size) cache: dict[int, Result] = {} pos = 0 workers = tuple(run_jobs(*g) for g in groups) + count = 0 with ExitStack() as stack: for w in workers: stack.push(w) wmap = {w.pipe.fileno(): w for w in workers} - while wmap: - ready, _, _ = select.select(tuple(wmap), (), (), timeout) + while wmap and count < num_items: + ready, _, _ = select.select(wmap, (), (), timeout) if not ready: raise TimeoutError(f'Forked workers did not produce a result in {timeout} seconds') for r in ready: @@ -138,6 +140,7 @@ def forked_map(fn: Callable[[T], R], iterable: T, *iterables: T, timeout: int | except EOFError: del wmap[r] continue + count += 1 if pos == result.id: if not result.ok: raise result.value @@ -151,7 +154,7 @@ def forked_map(fn: Callable[[T], R], iterable: T, *iterables: T, timeout: int | while r := cache.pop(pos, None): yield r.value pos += 1 - if pos < len(iterables): + if pos < num_items: raise OSError(f'Forked workers exited producing only {pos} out of {len(iterables)} results') @@ -169,16 +172,16 @@ def find_tests(): time.sleep(10 * x) return x with self.assertRaises(TimeoutError): - tuple(forked_map(sleep, range(os.cpu_count() * 3), timeout=0.001)) + tuple(forked_map(sleep, range(3), timeout=0.001)) def raise_error(x: int) -> None: raise ReferenceError('testing') with self.assertRaises(ReferenceError): - tuple(forked_map(raise_error, range(os.cpu_count() * 3))) + tuple(forked_map(raise_error, range(3))) timings = 0, 1, 2, 3 def echo(x: int) -> int: time.sleep(0.0001 * random.choice(timings)) return x - for num_workers in range(1, os.cpu_count() + 1): - items = tuple(range(num_workers * 3)) - self.assertEqual(tuple(map(echo, items)), tuple(forked_map(echo, items))) + for num_workers in range(1, 9): + items = tuple(range(num_workers * 3 + 1)) + self.assertEqual(tuple(map(echo, items)), tuple(forked_map(echo, items, num_workers=num_workers))) return unittest.defaultTestLoader.loadTestsFromTestCase(TestForkedMap)