Nested `scatter` calls lead to `KeyError`
Hi All,
I am currently working on improving the joblib-dask integration. It turns out that nested `Parallel` calls in joblib using the `dask` backend tend to error out with either `KeyError` or `CancelledError`.
I narrowed it down using only `dask` and `numpy`, and it seems that the issue comes from nested `scatter` calls.
Here is a reproducer: it consists of submitting functions that rely on scattered arrays. Each of these functions submits small arithmetic operations to be computed on scattered slices of its original input.
```python
import logging

import numpy as np
from distributed import LocalCluster, Client, get_client, secede, rejoin

NUM_INNER_TASKS = 10
NUM_OUTER_TASKS = 10


def my_sum(x, i, j):
    print(f"running inner task {j} of outer task {i}")
    return np.sum(x)


def outer_function(array, i):
    print(f"running outer task {i}")
    client = get_client()
    slices = [array[i + j :] for j in range(NUM_INNER_TASKS)]
    # commenting this line makes the code run successfully
    slices = client.scatter(slices, broadcast=True)
    futures = client.map(my_sum, slices, [i] * NUM_INNER_TASKS, range(NUM_INNER_TASKS))
    secede()
    results = client.gather(futures)
    rejoin()
    return sum(results)


if __name__ == "__main__":
    my_arrays = [np.ones(100000) for _ in range(10)]
    cluster = LocalCluster(
        n_workers=1, threads_per_worker=1, silence_logs=logging.WARNING
    )
    client = Client(cluster)
    future_arrays = client.scatter(my_arrays, direct=False)

    # using .map() instead of .submit() makes the code run successfully.
    # futures = client.map(outer_function, future_arrays, range(10))
    futures = []
    for i, arr in enumerate(future_arrays):
        future = client.submit(outer_function, arr, i)
        futures.append(future)

    results = client.gather(futures)
    print(results)
```
Two remarks:
- As noted in the code, using `client.map` instead of `client.submit` for the outer tasks makes the code run successfully (see the sketch below).
- Not scattering the `slices` inside `outer_function` also makes the code run successfully.
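For concreteness, here is a minimal sketch of the first workaround, assuming the same reproducer as above: the explicit `submit()` loop in the `__main__` block is replaced by a single `client.map()` call, mirroring the commented-out line in the script.

```python
# Sketch: map-based variant of the outer submission, replacing the
# submit() loop in the reproducer's __main__ block. NUM_OUTER_TASKS (10)
# matches the number of scattered arrays, as in the commented-out line.
futures = client.map(outer_function, future_arrays, range(NUM_OUTER_TASKS))
results = client.gather(futures)
print(results)
```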
My guess as of now is that dynamically creating new compute resources through `secede`/`rejoin` calls might interact badly with the data-locality logic of `distributed`. I'm investigating this on my own, but I'm not familiar enough with the dask/distributed codebase to trace this back efficiently.
Is this behavior supported? Is there a clear anti-pattern that I’m missing? Any pointer would be helpful.
Top GitHub Comments
Ah, that makes sense. I think that, short term, the solution of not hashing data in `scatter` is probably best. It's a little bit unclean, but I suspect that it actually has better performance, because locally scattering data is entirely free.
Short term, these problems also just go away if you use the `hash=False` keyword to `client.scatter`. This avoids any sort of collision between the different clients. It may also mean increased memory use, but maybe not, given that the work is likely to be done locally anyway.
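Applied to the reproducer above, that suggestion amounts to passing `hash=False` at the inner scatter call. A minimal sketch, assuming only `outer_function` changes and the rest of the script stays as posted:

```python
def outer_function(array, i):
    print(f"running outer task {i}")
    client = get_client()
    slices = [array[i + j :] for j in range(NUM_INNER_TASKS)]
    # hash=False gives each scattered slice its own key instead of a
    # content-derived one, so scatters from different nested tasks no
    # longer collide on keys (possibly at the cost of duplicate copies).
    slices = client.scatter(slices, broadcast=True, hash=False)
    futures = client.map(my_sum, slices, [i] * NUM_INNER_TASKS, range(NUM_INNER_TASKS))
    secede()
    results = client.gather(futures)
    rejoin()
    return sum(results)
```

With hashing disabled, identical data scattered from different tasks gets distinct keys rather than colliding ones, which is the collision the comment above refers to; the trade-off is the possible extra memory use it also mentions.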