Problems with numba ufunc + distributed
We have created a new software package called fastjmd95 that uses numba to accelerate computation of the ocean equation of state. Everything works fine with dask and a local scheduler. Now I want to run this code on a distributed dask cluster. It isn’t working, I think because the workers are not able to deserialize the numba functions properly.
Original Full Example
This example with real data can be run on any Pangeo cluster:
from intake import open_catalog
from fastjmd95 import rho
cat = open_catalog("https://raw.githubusercontent.com/pangeo-data/pangeo-datastore/master/intake-catalogs/ocean.yaml")
ds = cat["SOSE"].to_dask()
rhonil = 1025
pa_to_dbar = 1.0/10000
p = ds.PHrefC * rhonil * pa_to_dbar
s = ds.SALT
t = ds.THETA
r = rho(s.data, t.data, 0)
# works fine with local scheduler
r_mean = r[:5].compute()
# now start distributed scheduler
from dask.distributed import Client
client = Client()
r_mean = r[:5].compute()
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-4-7316322484d4> in <module>
----> 1 r_mean = r[:5].compute()
/srv/conda/envs/notebook/lib/python3.7/site-packages/dask/base.py in compute(self, **kwargs)
163 dask.base.compute
164 """
--> 165 (result,) = compute(self, traverse=False, **kwargs)
166 return result
167
/srv/conda/envs/notebook/lib/python3.7/site-packages/dask/base.py in compute(*args, **kwargs)
434 keys = [x.__dask_keys__() for x in collections]
435 postcomputes = [x.__dask_postcompute__() for x in collections]
--> 436 results = schedule(dsk, keys, **kwargs)
437 return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
438
/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/client.py in get(self, dsk, keys, restrictions, loose_restrictions, resources, sync, asynchronous, direct, retries, priority, fifo_timeout, actors, **kwargs)
2571 should_rejoin = False
2572 try:
-> 2573 results = self.gather(packed, asynchronous=asynchronous, direct=direct)
2574 finally:
2575 for f in futures.values():
/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/client.py in gather(self, futures, errors, direct, asynchronous)
1871 direct=direct,
1872 local_worker=local_worker,
-> 1873 asynchronous=asynchronous,
1874 )
1875
/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/client.py in sync(self, func, asynchronous, callback_timeout, *args, **kwargs)
766 else:
767 return sync(
--> 768 self.loop, func, *args, callback_timeout=callback_timeout, **kwargs
769 )
770
/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/utils.py in sync(loop, func, callback_timeout, *args, **kwargs)
332 if error[0]:
333 typ, exc, tb = error[0]
--> 334 raise exc.with_traceback(tb)
335 else:
336 return result[0]
/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/utils.py in f()
316 if callback_timeout is not None:
317 future = gen.with_timeout(timedelta(seconds=callback_timeout), future)
--> 318 result[0] = yield future
319 except Exception as exc:
320 error[0] = sys.exc_info()
/srv/conda/envs/notebook/lib/python3.7/site-packages/tornado/gen.py in run(self)
733
734 try:
--> 735 value = future.result()
736 except Exception:
737 exc_info = sys.exc_info()
/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/client.py in _gather(self, futures, errors, direct, local_worker)
1727 exc = CancelledError(key)
1728 else:
-> 1729 raise exception.with_traceback(traceback)
1730 raise exc
1731 if errors == "skip":
/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/protocol/pickle.py in loads()
57 def loads(x):
58 try:
---> 59 return pickle.loads(x)
60 except Exception:
61 logger.info("Failed to deserialize %s", x[:10000], exc_info=True)
/srv/conda/envs/notebook/lib/python3.7/site-packages/numpy/core/__init__.py in _ufunc_reconstruct()
123 # scipy.special.expit for instance.
124 mod = __import__(module, fromlist=[name])
--> 125 return getattr(mod, name)
126
127 def _ufunc_reduce(func):
AttributeError: module '__main__' has no attribute 'rho'
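The last frames show why this happens: NumPy pickles ufuncs by reference, storing only a (module, name) pair and looking the name up again on load. Paraphrasing the reconstruction logic visible in the traceback (from numpy/core/__init__.py):

def _ufunc_reconstruct(module, name):
    # Re-import the recorded module and fetch the ufunc by name. The pickle
    # never contains the compiled function itself, so the worker must find an
    # attribute called "rho" in the recorded module, here "__main__".
    mod = __import__(module, fromlist=[name])
    return getattr(mod, name)

On the worker, __main__ is a different module than my notebook session and has no attribute rho, hence the AttributeError.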
Minimal Example
I believe this reproduces the core problem:
import numpy as np
from numba import vectorize, float64, float32
import dask.array as dsa
from dask.distributed import Client
client = Client()
# define a numba ufunc
@vectorize([float64(float64), float32(float32)], nopython=True)
def test_numba(a):
return a**2
# verify that the workers can run it (client.run executes on each worker)
def try_numba_on_client():
data = np.arange(5, dtype='f4')
return test_numba(data)
client.run(try_numba_on_client)
# works, output is:
# > {'tcp://127.0.0.1:37583': array([ 0., 1., 4., 9., 16.]),
# > 'tcp://127.0.0.1:44855': array([ 0., 1., 4., 9., 16.])}
# use in a computation
data_dask = dsa.arange(5, dtype='f4')
test_numba(data_dask).compute()
At this point I get a KilledWorker error, and the worker log shows the following:
distributed.worker - ERROR - module '__main__' has no attribute 'test_numba'
Traceback (most recent call last):
  File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/worker.py", line 905, in handle_scheduler
    comm, every_cycle=[self.ensure_communicating, self.ensure_computing]
  File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/core.py", line 456, in handle_stream
    msgs = await comm.read()
  File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/comm/tcp.py", line 222, in read
    frames, deserialize=self.deserialize, deserializers=deserializers
  File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/comm/utils.py", line 69, in from_frames
    res = _from_frames()
  File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/comm/utils.py", line 55, in _from_frames
    frames, deserialize=deserialize, deserializers=deserializers
  File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/protocol/core.py", line 124, in loads
    value = _deserialize(head, fs, deserializers=deserializers)
  File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/protocol/serialize.py", line 255, in deserialize
    deserializers=deserializers,
  File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/protocol/serialize.py", line 268, in deserialize
    return loads(header, frames)
  File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/protocol/serialize.py", line 62, in pickle_loads
    return pickle.loads(b"".join(frames))
  File "/srv/conda/envs/notebook/lib/python3.7/site-packages/distributed/protocol/pickle.py", line 59, in loads
    return pickle.loads(x)
  File "/srv/conda/envs/notebook/lib/python3.7/site-packages/numpy/core/__init__.py", line 125, in _ufunc_reconstruct
    return getattr(mod, name)
AttributeError: module '__main__' has no attribute 'test_numba'
The basic error appears to be the same as in the full example.
This seems like a pretty straightforward use of numba + distributed, and I assumed this sort of usage was supported. Am I missing something obvious?
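For what it’s worth, here is a quick check of which module pickle will record for the ufunc (a sketch reusing test_numba from the minimal example; the comment states my expectation in a notebook session, not captured output):

from pickle import whichmodule

# pickle stores only (module, name) for a ufunc; this prints the module it
# will try to re-import on the worker, presumably '__main__' in a notebook.
print(whichmodule(test_numba, test_numba.__name__))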
Installed versions
I’m on dask 2.9.0 and numba 0.48.0.
>>> client.get_versions(check=True)
{'scheduler': {'host': (('python', '3.7.6.final.0'),
('python-bits', 64),
('OS', 'Linux'),
('OS-release', '4.19.76+'),
('machine', 'x86_64'),
('processor', 'x86_64'),
('byteorder', 'little'),
('LC_ALL', 'en_US.UTF-8'),
('LANG', 'en_US.UTF-8'),
('LOCALE', 'en_US.UTF-8')),
'packages': {'required': (('dask', '2.9.0'),
('distributed', '2.9.0'),
('msgpack', '0.6.2'),
('cloudpickle', '1.2.2'),
('tornado', '6.0.3'),
('toolz', '0.10.0')),
'optional': (('numpy', '1.17.3'),
('pandas', '0.25.3'),
('bokeh', '1.4.0'),
('lz4', '2.2.1'),
('dask_ml', '1.1.1'),
('blosc', '1.8.1'))}},
'workers': {'tcp://10.32.181.10:45663': {'host': (('python', '3.7.6.final.0'),
('python-bits', 64),
('OS', 'Linux'),
('OS-release', '4.19.76+'),
('machine', 'x86_64'),
('processor', 'x86_64'),
('byteorder', 'little'),
('LC_ALL', 'en_US.UTF-8'),
('LANG', 'en_US.UTF-8'),
('LOCALE', 'en_US.UTF-8')),
'packages': {'required': (('dask', '2.9.0'),
('distributed', '2.9.0'),
('msgpack', '0.6.2'),
('cloudpickle', '1.2.2'),
('tornado', '6.0.3'),
('toolz', '0.10.0')),
'optional': (('numpy', '1.17.3'),
('pandas', '0.25.3'),
('bokeh', '1.4.0'),
('lz4', '2.2.1'),
('dask_ml', '1.1.1'),
('blosc', '1.8.1'))}},
'tcp://10.32.181.11:37259': {'host': (('python', '3.7.6.final.0'),
('python-bits', 64),
('OS', 'Linux'),
('OS-release', '4.19.76+'),
('machine', 'x86_64'),
('processor', 'x86_64'),
('byteorder', 'little'),
('LC_ALL', 'en_US.UTF-8'),
('LANG', 'en_US.UTF-8'),
('LOCALE', 'en_US.UTF-8')),
'packages': {'required': (('dask', '2.9.0'),
('distributed', '2.9.0'),
('msgpack', '0.6.2'),
('cloudpickle', '1.2.2'),
('tornado', '6.0.3'),
('toolz', '0.10.0')),
'optional': (('numpy', '1.17.3'),
('pandas', '0.25.3'),
('bokeh', '1.4.0'),
('lz4', '2.2.1'),
('dask_ml', '1.1.1'),
('blosc', '1.8.1'))}}},
'client': {'host': [('python', '3.7.6.final.0'),
('python-bits', 64),
('OS', 'Linux'),
('OS-release', '4.19.76+'),
('machine', 'x86_64'),
('processor', 'x86_64'),
('byteorder', 'little'),
('LC_ALL', 'en_US.UTF-8'),
('LANG', 'en_US.UTF-8'),
('LOCALE', 'en_US.UTF-8')],
'packages': {'required': [('dask', '2.9.0'),
('distributed', '2.9.0'),
('msgpack', '0.6.2'),
('cloudpickle', '1.2.2'),
('tornado', '6.0.3'),
('toolz', '0.10.0')],
'optional': [('numpy', '1.17.3'),
('pandas', '0.25.3'),
('bokeh', '1.4.0'),
('lz4', '2.2.1'),
('dask_ml', '1.1.1'),
('blosc', '1.8.1')]}}}
Top GitHub Comments
I have a slightly better understanding of the situation now. The call order is something like pickle.loads → numpy.core._ufunc_reconstruct → getattr(mod, name), as the tracebacks above show. The test_numba.ufunc attribute is a NumPy ufunc that is (I think) dynamically generated by numba, and that’s what chokes up dask’s serialization.
Will start looking for solutions now.
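One way to see the failure without a cluster (a sketch reusing test_numba from the minimal example; the comments describe expected behavior rather than captured output):

import pickle

# The ufunc is pickled by reference: the payload records
# ('__main__', 'test_numba'), not the compiled function itself.
blob = pickle.dumps(test_numba)

# Loading works here only because this process's __main__ really defines
# test_numba; in a fresh worker process the same bytes raise
# AttributeError: module '__main__' has no attribute 'test_numba'
restored = pickle.loads(blob)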
@TomAugspurger, seems hackish, but maybe a band-aid is better than nothing.
However, I tried around a bit, and I think we are missing how much better pickle has gotten, or how reliable it actually is. I.e., I think NumPy’s approach is over-engineered, and that makes the solution harder than necessary. I tried modifying NumPy like this, but you can also do it manually:
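(A minimal sketch of what I mean, mirroring NumPy’s _ufunc_reconstruct/_ufunc_reduce pair but resolving dotted names attribute by attribute; untested, names illustrative:)

import copyreg
from pickle import whichmodule
import numpy as np

def _ufunc_reconstruct(module, name):
    # Like NumPy's version, but walk dotted names one attribute at a time,
    # so a name such as "test_numba.ufunc" can be resolved again.
    obj = __import__(module, fromlist=[name.split(".")[0]])
    for part in name.split("."):
        obj = getattr(obj, part)
    return obj

def _ufunc_reduce(func):
    name = func.__name__
    return _ufunc_reconstruct, (whichmodule(func, name), name)

# Register the replacement pickling logic for all NumPy ufuncs.
copyreg.pickle(np.ufunc, _ufunc_reduce)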
Now you need one more ingredient: test_numba.ufunc has to report its __name__ as "test_numba.ufunc" (a bit like a __qualname__). I tried this by hacking the ufunc name to be mutable. A __qualname__ would maybe be better, and I guess we could add a __qualname__ to ufuncs, but if printing the extra .ufunc seems OK, this solution is possible right now.
Now, overriding the ufunc pickling outside of NumPy seems pretty extreme, but I am not actually sure it’s all that bad. I did not check, but I think the above replacement is effectively identical to what NumPy does, except that it supports attributes in a __qualname__-like fashion.
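For completeness: any such override would have to be installed in the worker processes too, not just the client. A sketch of one way to do that with distributed (_ufunc_reduce is the replacement defined above, shipped to the workers by cloudpickle):

def install_ufunc_pickle_override():
    import copyreg
    import numpy as np
    # Re-register the replacement reducer inside each worker process, so
    # both sides agree on how ufuncs are pickled.
    copyreg.pickle(np.ufunc, _ufunc_reduce)

client.register_worker_callbacks(install_ufunc_pickle_override)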