dask_ml.model_selection.GridSearchCV errors for keras model
See original GitHub issue
I am trying to use a Keras model with dask_ml.model_selection.GridSearchCV. If I do not set up a Dask client, it works fine. However, I get errors when I run with two Dask workers; it seems that something cannot be deserialized. I would appreciate any suggestions about this problem.
distributed.protocol.pickle - INFO - Failed to deserialize b'\x80\x04\x95b\x05\x00\x00\x00\x00\x00\x00\x8c\x1bkeras.wrappers.scikit_learn\x94\x8c\x0fKerasClassifier\x94\x93\x94)\x81\x94}
ValueError Traceback (most recent call last)
~/.local/lib/python3.7/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
1112 subfeed_t = self.graph.as_graph_element(
-> 1113 subfeed, allow_tensor=True, allow_operation=False)
1114 except Exception as e:
~/.local/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in as_graph_element(self, obj, allow_tensor, allow_operation)
3795 with self._lock:
-> 3796 return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
3797
~/.local/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in _as_graph_element_locked(self, obj, allow_tensor, allow_operation)
3874 if obj.graph is not self:
-> 3875 raise ValueError("Tensor %s is not an element of this graph." % obj)
3876 return obj
ValueError: Tensor Tensor("Placeholder:0", shape=(160, 640), dtype=float32) is not an element of this graph.
During handling of the above exception, another exception occurred:
TypeError Traceback (most recent call last)
<timed exec> in <module>
/conda/lib/python3.7/site-packages/dask_ml/model_selection/_search.py in fit(self, X, y, groups, **fit_params)
1293 dsk, keys = build_refit_graph(estimator, X, y, best_params, fit_params)
1294
-> 1295 out = scheduler(dsk, keys, num_workers=n_jobs)
1296 self.best_estimator_ = out[0]
1297
/conda/lib/python3.7/site-packages/distributed/client.py in get(self, dsk, keys, restrictions, loose_restrictions, resources, sync, asynchronous, direct, retries, priority, fifo_timeout, actors, **kwargs)
2525 should_rejoin = False
2526 try:
-> 2527 results = self.gather(packed, asynchronous=asynchronous, direct=direct)
2528 finally:
2529 for f in futures.values():
/conda/lib/python3.7/site-packages/distributed/client.py in gather(self, futures, errors, direct, asynchronous)
1821 direct=direct,
1822 local_worker=local_worker,
-> 1823 asynchronous=asynchronous,
1824 )
1825
/conda/lib/python3.7/site-packages/distributed/client.py in sync(self, func, asynchronous, callback_timeout, *args, **kwargs)
761 else:
762 return sync(
--> 763 self.loop, func, *args, callback_timeout=callback_timeout, **kwargs
764 )
765
/conda/lib/python3.7/site-packages/distributed/utils.py in sync(loop, func, callback_timeout, *args, **kwargs)
330 e.wait(10)
331 if error[0]:
--> 332 six.reraise(*error[0])
333 else:
334 return result[0]
/conda/lib/python3.7/site-packages/six.py in reraise(tp, value, tb)
691 if value.__traceback__ is not tb:
692 raise value.with_traceback(tb)
--> 693 raise value
694 finally:
695 value = None
/conda/lib/python3.7/site-packages/distributed/utils.py in f()
315 if callback_timeout is not None:
316 future = gen.with_timeout(timedelta(seconds=callback_timeout), future)
--> 317 result[0] = yield future
318 except Exception as exc:
319 error[0] = sys.exc_info()
/conda/lib/python3.7/site-packages/tornado/gen.py in run(self)
733
734 try:
--> 735 value = future.result()
736 except Exception:
737 exc_info = sys.exc_info()
/conda/lib/python3.7/site-packages/tornado/gen.py in run(self)
740 if exc_info is not None:
741 try:
--> 742 yielded = self.gen.throw(*exc_info) # type: ignore
743 finally:
744 # Break up a reference to itself
/conda/lib/python3.7/site-packages/distributed/client.py in _gather(self, futures, errors, direct, local_worker)
1705 else:
1706 self._gather_future = future
-> 1707 response = yield future
1708
1709 if response["status"] == "error":
/conda/lib/python3.7/site-packages/tornado/gen.py in run(self)
733
734 try:
--> 735 value = future.result()
736 except Exception:
737 exc_info = sys.exc_info()
/conda/lib/python3.7/site-packages/tornado/gen.py in run(self)
740 if exc_info is not None:
741 try:
--> 742 yielded = self.gen.throw(*exc_info) # type: ignore
743 finally:
744 # Break up a reference to itself
/conda/lib/python3.7/site-packages/distributed/client.py in _gather_remote(self, direct, local_worker)
1758
1759 else: # ask scheduler to gather data for us
-> 1760 response = yield self.scheduler.gather(keys=keys)
1761 finally:
1762 self._gather_semaphore.release()
/conda/lib/python3.7/site-packages/tornado/gen.py in run(self)
733
734 try:
--> 735 value = future.result()
736 except Exception:
737 exc_info = sys.exc_info()
/conda/lib/python3.7/site-packages/tornado/gen.py in run(self)
740 if exc_info is not None:
741 try:
--> 742 yielded = self.gen.throw(*exc_info) # type: ignore
743 finally:
744 # Break up a reference to itself
/conda/lib/python3.7/site-packages/distributed/core.py in send_recv_from_rpc(**kwargs)
739 name, comm.name = comm.name, "ConnectionPool." + key
740 try:
--> 741 result = yield send_recv(comm=comm, op=key, **kwargs)
742 finally:
743 self.pool.reuse(self.addr, comm)
/conda/lib/python3.7/site-packages/tornado/gen.py in run(self)
733
734 try:
--> 735 value = future.result()
736 except Exception:
737 exc_info = sys.exc_info()
/conda/lib/python3.7/site-packages/tornado/gen.py in run(self)
740 if exc_info is not None:
741 try:
--> 742 yielded = self.gen.throw(*exc_info) # type: ignore
743 finally:
744 # Break up a reference to itself
/conda/lib/python3.7/site-packages/distributed/core.py in send_recv(comm, reply, serializers, deserializers, **kwargs)
533 yield comm.write(msg, serializers=serializers, on_error="raise")
534 if reply:
--> 535 response = yield comm.read(deserializers=deserializers)
536 else:
537 response = None
/conda/lib/python3.7/site-packages/tornado/gen.py in run(self)
733
734 try:
--> 735 value = future.result()
736 except Exception:
737 exc_info = sys.exc_info()
/conda/lib/python3.7/site-packages/tornado/gen.py in run(self)
740 if exc_info is not None:
741 try:
--> 742 yielded = self.gen.throw(*exc_info) # type: ignore
743 finally:
744 # Break up a reference to itself
/conda/lib/python3.7/site-packages/distributed/comm/tcp.py in read(self, deserializers)
216 try:
217 msg = yield from_frames(
--> 218 frames, deserialize=self.deserialize, deserializers=deserializers
219 )
220 except EOFError:
/conda/lib/python3.7/site-packages/tornado/gen.py in run(self)
733
734 try:
--> 735 value = future.result()
736 except Exception:
737 exc_info = sys.exc_info()
/conda/lib/python3.7/site-packages/tornado/gen.py in run(self)
740 if exc_info is not None:
741 try:
--> 742 yielded = self.gen.throw(*exc_info) # type: ignore
743 finally:
744 # Break up a reference to itself
/conda/lib/python3.7/site-packages/distributed/comm/utils.py in from_frames(frames, deserialize, deserializers)
81
82 if deserialize and size > FRAME_OFFLOAD_THRESHOLD:
---> 83 res = yield offload(_from_frames)
84 else:
85 res = _from_frames()
/conda/lib/python3.7/site-packages/tornado/gen.py in run(self)
733
734 try:
--> 735 value = future.result()
736 except Exception:
737 exc_info = sys.exc_info()
/conda/lib/python3.7/concurrent/futures/_base.py in result(self, timeout)
423 raise CancelledError()
424 elif self._state == FINISHED:
--> 425 return self.__get_result()
426
427 self._condition.wait(timeout)
/conda/lib/python3.7/concurrent/futures/_base.py in __get_result(self)
382 def __get_result(self):
383 if self._exception:
--> 384 raise self._exception
385 else:
386 return self._result
/conda/lib/python3.7/concurrent/futures/thread.py in run(self)
55
56 try:
---> 57 result = self.fn(*self.args, **self.kwargs)
58 except BaseException as exc:
59 self.future.set_exception(exc)
/conda/lib/python3.7/site-packages/distributed/comm/utils.py in _from_frames()
69 try:
70 return protocol.loads(
---> 71 frames, deserialize=deserialize, deserializers=deserializers
72 )
73 except EOFError:
/conda/lib/python3.7/site-packages/distributed/protocol/core.py in loads(frames, deserialize, deserializers)
124 fs = decompress(head, fs)
125 fs = merge_frames(head, fs)
--> 126 value = _deserialize(head, fs, deserializers=deserializers)
127 else:
128 value = Serialized(head, fs)
/conda/lib/python3.7/site-packages/distributed/protocol/serialize.py in deserialize(header, frames, deserializers)
188 )
189 dumps, loads, wants_context = families[name]
--> 190 return loads(header, frames)
191
192
/conda/lib/python3.7/site-packages/distributed/protocol/serialize.py in pickle_loads(header, frames)
62
63 def pickle_loads(header, frames):
---> 64 return pickle.loads(b"".join(frames))
65
66
/conda/lib/python3.7/site-packages/distributed/protocol/pickle.py in loads(x)
59 def loads(x):
60 try:
---> 61 return pickle.loads(x)
62 except Exception:
63 logger.info("Failed to deserialize %s", x[:10000], exc_info=True)
/conda/lib/python3.7/site-packages/keras/engine/network.py in __setstate__(self, state)
1264
1265 def __setstate__(self, state):
-> 1266 model = saving.unpickle_model(state)
1267 self.__dict__.update(model.__dict__)
1268
/conda/lib/python3.7/site-packages/keras/engine/saving.py in unpickle_model(state)
433 def unpickle_model(state):
434 f = h5dict(state, mode='r')
--> 435 return _deserialize_model(f)
436
437
/conda/lib/python3.7/site-packages/keras/engine/saving.py in _deserialize_model(f, custom_objects, compile)
285 ' elements.')
286 weight_value_tuples += zip(symbolic_weights, weight_values)
--> 287 K.batch_set_value(weight_value_tuples)
288
289 if compile:
/conda/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py in batch_set_value(tuples)
2468 assign_ops.append(assign_op)
2469 feed_dict[assign_placeholder] = value
-> 2470 get_session().run(assign_ops, feed_dict=feed_dict)
2471
2472
~/.local/lib/python3.7/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
948 try:
949 result = self._run(None, fetches, feed_dict, options_ptr,
--> 950 run_metadata_ptr)
951 if run_metadata:
952 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
~/.local/lib/python3.7/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
1114 except Exception as e:
1115 raise TypeError(
-> 1116 'Cannot interpret feed_dict key as Tensor: ' + e.args[0])
1117
1118 if isinstance(subfeed_val, ops.Tensor):
TypeError: Cannot interpret feed_dict key as Tensor: Tensor Tensor("Placeholder:0", shape=(160, 640), dtype=float32) is not an element of this graph.
Issue Analytics
- State:
- Created 4 years ago
- Comments: 9 (4 by maintainers)
Top Results From Across the Web
dask_ml.model_selection.GridSearchCV - Dask-ML
Exhaustive search over specified parameter values for an estimator. GridSearchCV implements a “fit” and a “score” method. It also implements “predict”, “ ...
Read more >
Does dask_ml.model_selection.GridSearchCV support GPU ...
I am trying to run sklearn.neural_network.MLPRegressor coupled with using dask_ml.model_selection.GridSearchCV to train my model, and I don't know whether the ...
Read more >
How to Grid Search Hyperparameters for Deep Learning ...
The GridSearchCV process will then construct and evaluate one model for each combination of parameters. Cross validation is used to evaluate ...
Read more >
Highest scored 'dask-ml' questions - Stack Overflow
import joblib from sklearn.externals.joblib import parallel_backend with joblib.parallel_backend('dask'): from dask_ml.model_selection import GridSearchCV ...
Read more >
Grid Search - Dask - Saturn Cloud
This includes GridSearchCV and other hyperparameter search options. To use it, we load our data into a Dask DataFrame and use Dask ML's...
Read more >
Top Related Medium Post
No results found
Top Related StackOverflow Question
No results found
Troubleshoot Live Code
Lightrun enables developers to add logs, metrics and snapshots to live code - no restarts or redeploys required.
Start Free
Top Related Reddit Thread
No results found
Top Related Hackernoon Post
No results found
Top Related Tweet
No results found
Top Related Dev.to Post
No results found
Top Related Hashnode Post
No results found
Top GitHub Comments
There’s support for Keras serialization now in SciKeras, which brings a Scikit-Learn API to Keras. This is mentioned explicitly in the documentation on https://ml.dask.org/keras.html.
We’re trying to merge serialization support upstream in Tensorflow: https://github.com/tensorflow/tensorflow/pull/39609, https://github.com/tensorflow/community/pull/286
For those interested, I have an example of using a trained Keras model to make predictions with Dask:
https://github.com/dask/distributed/issues/2333