Fix derivation of batches in forked_map

Kovid Goyal 2025-04-06 13:39:39 +05:30
parent e34681b159
commit 7c5756c9d3

@@ -116,19 +116,21 @@ def forked_map(fn: Callable[[T], R], iterable: T, *iterables: T, timeout: int |
     Should be used only in worker processes that have no threads and that do not use/import any non fork safe libraries such as macOS
     system libraries.
     '''
+    num_items = len(iterable) + sum(map(len, iterables))
     if num_workers <= 0:
-        num_workers = max(1, os.cpu_count())
-    chunk_size = max(1, len(iterables) / num_workers)
+        num_workers = max(1, min(num_items, os.cpu_count()))
+    chunk_size = max(1, num_items // num_workers)
     groups = batched((Job(i, arg, fn) for i, arg in enumerate(chain(iterable, *iterables))), chunk_size)
     cache: dict[int, Result] = {}
     pos = 0
     workers = tuple(run_jobs(*g) for g in groups)
+    count = 0
     with ExitStack() as stack:
         for w in workers:
             stack.push(w)
         wmap = {w.pipe.fileno(): w for w in workers}
-        while wmap:
-            ready, _, _ = select.select(tuple(wmap), (), (), timeout)
+        while wmap and count < num_items:
+            ready, _, _ = select.select(wmap, (), (), timeout)
             if not ready:
                 raise TimeoutError(f'Forked workers did not produce a result in {timeout} seconds')
             for r in ready:
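
Note on the hunk above: the old derivation based the batch size on len(iterables), which counts only the extra positional iterables, not the items. With a single iterable that is 0, so chunk_size collapsed to 1 and every item got its own batch (and its own forked worker); with several iterables the true division produced a float. The new code counts every item, caps the worker count at that number, and uses integer division. The following standalone sketch (not the kitty code itself; it drops the Job/fn plumbing and needs Python 3.12+ for itertools.batched) shows how the corrected derivation splits the inputs:

# Standalone sketch of the corrected batch derivation (simplified: the real
# function wraps each argument in a Job and forks one run_jobs worker per batch).
# Requires Python >= 3.12 for itertools.batched.
import os
from itertools import batched, chain


def derive_batches(iterable, *iterables, num_workers=0):
    # Count every item across all positional iterables, as the new code does.
    num_items = len(iterable) + sum(map(len, iterables))
    if num_workers <= 0:
        # Never use more workers than there are items (os.cpu_count() may be None).
        num_workers = max(1, min(num_items, os.cpu_count() or 1))
    # Integer division on the item count; the old code divided len(iterables),
    # i.e. the number of *extra* iterables, and produced a float.
    chunk_size = max(1, num_items // num_workers)
    return list(batched(chain(iterable, *iterables), chunk_size))


print(derive_batches(range(7), num_workers=3))   # [(0, 1), (2, 3), (4, 5), (6,)]

As the example shows, a ragged item count can still yield one more batch than requested workers; since the diffed code forks one run_jobs worker per group, that simply means one extra short-lived worker.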
@@ -138,6 +140,7 @@ def forked_map(fn: Callable[[T], R], iterable: T, *iterables: T, timeout: int |
                 except EOFError:
                     del wmap[r]
                     continue
+                count += 1
                 if pos == result.id:
                     if not result.ok:
                         raise result.value
@@ -151,7 +154,7 @@ def forked_map(fn: Callable[[T], R], iterable: T, *iterables: T, timeout: int |
     while r := cache.pop(pos, None):
         yield r.value
         pos += 1
-    if pos < len(iterables):
+    if pos < num_items:
         raise OSError(f'Forked workers exited producing only {pos} out of {len(iterables)} results')
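
The two hunks above tie the result loop to the real item count: count records how many results have been received so the select loop can stop once all num_items results are in, cache and pos buffer results that arrive out of order so they are yielded in input order, and the final completeness check now compares pos against num_items instead of len(iterables). A minimal, self-contained sketch of that reordering bookkeeping (simplified; the real code reads Result objects from worker pipes):

# Minimal sketch of the ordered-yield pattern (not the kitty code): results
# arrive tagged with the index of their input item, possibly out of order,
# and are held in a dict until they can be yielded in input order.
def yield_in_order(tagged_results, num_items):
    cache = {}
    pos = 0      # index of the next result to yield
    count = 0    # how many results have arrived so far
    for idx, value in tagged_results:   # arrival order is arbitrary
        cache[idx] = value
        count += 1
        while pos in cache:             # flush everything that is now contiguous
            yield cache.pop(pos)
            pos += 1
        if count >= num_items:          # all results seen, stop waiting
            break
    if pos < num_items:
        raise OSError(f'produced only {pos} out of {num_items} results')


print(list(yield_in_order([(2, 'c'), (0, 'a'), (1, 'b')], 3)))  # ['a', 'b', 'c']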
@@ -169,16 +172,16 @@ def find_tests():
                time.sleep(10 * x)
                return x
            with self.assertRaises(TimeoutError):
-               tuple(forked_map(sleep, range(os.cpu_count() * 3), timeout=0.001))
+               tuple(forked_map(sleep, range(3), timeout=0.001))
            def raise_error(x: int) -> None:
                raise ReferenceError('testing')
            with self.assertRaises(ReferenceError):
-               tuple(forked_map(raise_error, range(os.cpu_count() * 3)))
+               tuple(forked_map(raise_error, range(3)))
            timings = 0, 1, 2, 3
            def echo(x: int) -> int:
                time.sleep(0.0001 * random.choice(timings))
                return x
-           for num_workers in range(1, os.cpu_count() + 1):
-               items = tuple(range(num_workers * 3))
-               self.assertEqual(tuple(map(echo, items)), tuple(forked_map(echo, items)))
+           for num_workers in range(1, 9):
+               items = tuple(range(num_workers * 3 + 1))
+               self.assertEqual(tuple(map(echo, items)), tuple(forked_map(echo, items, num_workers=num_workers)))
    return unittest.defaultTestLoader.loadTestsFromTestCase(TestForkedMap)
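
The test changes decouple the suite from the machine's CPU count: the timeout and error cases only need a handful of items, and the round-trip test now sweeps num_workers from 1 to 8 explicitly, presumably using num_workers * 3 + 1 items so that (for more than one worker) the item count does not divide evenly and the last batch is shorter. A quick way to see the batch shapes the corrected derivation produces for those test inputs (assumes Python 3.12+ for itertools.batched):

# Print the batch sizes for num_workers * 3 + 1 items, illustrating the
# ragged final batch the updated test exercises.
from itertools import batched

for num_workers in range(1, 9):
    num_items = num_workers * 3 + 1
    chunk_size = max(1, num_items // num_workers)
    sizes = [len(batch) for batch in batched(range(num_items), chunk_size)]
    print(f'{num_workers} workers: batch sizes {sizes}')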