dask_ml.model_selection.GridSearchCV errors for keras model

I am trying to pass a Keras model to dask_ml.model_selection.GridSearchCV. If I do not set a client, it works fine. However, I get errors when I have two Dask workers: something apparently cannot be deserialized. I would appreciate any suggestions about this problem.

distributed.protocol.pickle - INFO - Failed to deserialize b'\x80\x04\x95b\x05\x00\x00\x00\x00\x00\x00\x8c\x1bkeras.wrappers.scikit_learn\x94\x8c\x0fKerasClassifier\x94\x93\x94)\x81\x94}

ValueError                                Traceback (most recent call last)
~/.local/lib/python3.7/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1112             subfeed_t = self.graph.as_graph_element(
-> 1113                 subfeed, allow_tensor=True, allow_operation=False)
   1114           except Exception as e:

~/.local/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in as_graph_element(self, obj, allow_tensor, allow_operation)
   3795     with self._lock:
-> 3796       return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
   3797 

~/.local/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in _as_graph_element_locked(self, obj, allow_tensor, allow_operation)
   3874       if obj.graph is not self:
-> 3875         raise ValueError("Tensor %s is not an element of this graph." % obj)
   3876       return obj

ValueError: Tensor Tensor("Placeholder:0", shape=(160, 640), dtype=float32) is not an element of this graph.

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
<timed exec> in <module>

/conda/lib/python3.7/site-packages/dask_ml/model_selection/_search.py in fit(self, X, y, groups, **fit_params)
   1293             dsk, keys = build_refit_graph(estimator, X, y, best_params, fit_params)
   1294 
-> 1295             out = scheduler(dsk, keys, num_workers=n_jobs)
   1296             self.best_estimator_ = out[0]
   1297 

/conda/lib/python3.7/site-packages/distributed/client.py in get(self, dsk, keys, restrictions, loose_restrictions, resources, sync, asynchronous, direct, retries, priority, fifo_timeout, actors, **kwargs)
   2525                     should_rejoin = False
   2526             try:
-> 2527                 results = self.gather(packed, asynchronous=asynchronous, direct=direct)
   2528             finally:
   2529                 for f in futures.values():

/conda/lib/python3.7/site-packages/distributed/client.py in gather(self, futures, errors, direct, asynchronous)
   1821                 direct=direct,
   1822                 local_worker=local_worker,
-> 1823                 asynchronous=asynchronous,
   1824             )
   1825 

/conda/lib/python3.7/site-packages/distributed/client.py in sync(self, func, asynchronous, callback_timeout, *args, **kwargs)
    761         else:
    762             return sync(
--> 763                 self.loop, func, *args, callback_timeout=callback_timeout, **kwargs
    764             )
    765 

/conda/lib/python3.7/site-packages/distributed/utils.py in sync(loop, func, callback_timeout, *args, **kwargs)
    330             e.wait(10)
    331     if error[0]:
--> 332         six.reraise(*error[0])
    333     else:
    334         return result[0]

/conda/lib/python3.7/site-packages/six.py in reraise(tp, value, tb)
    691             if value.__traceback__ is not tb:
    692                 raise value.with_traceback(tb)
--> 693             raise value
    694         finally:
    695             value = None

/conda/lib/python3.7/site-packages/distributed/utils.py in f()
    315             if callback_timeout is not None:
    316                 future = gen.with_timeout(timedelta(seconds=callback_timeout), future)
--> 317             result[0] = yield future
    318         except Exception as exc:
    319             error[0] = sys.exc_info()

/conda/lib/python3.7/site-packages/tornado/gen.py in run(self)
    733 
    734                     try:
--> 735                         value = future.result()
    736                     except Exception:
    737                         exc_info = sys.exc_info()

/conda/lib/python3.7/site-packages/tornado/gen.py in run(self)
    740                     if exc_info is not None:
    741                         try:
--> 742                             yielded = self.gen.throw(*exc_info)  # type: ignore
    743                         finally:
    744                             # Break up a reference to itself

/conda/lib/python3.7/site-packages/distributed/client.py in _gather(self, futures, errors, direct, local_worker)
   1705                 else:
   1706                     self._gather_future = future
-> 1707                 response = yield future
   1708 
   1709             if response["status"] == "error":

/conda/lib/python3.7/site-packages/tornado/gen.py in run(self)
    733 
    734                     try:
--> 735                         value = future.result()
    736                     except Exception:
    737                         exc_info = sys.exc_info()

/conda/lib/python3.7/site-packages/tornado/gen.py in run(self)
    740                     if exc_info is not None:
    741                         try:
--> 742                             yielded = self.gen.throw(*exc_info)  # type: ignore
    743                         finally:
    744                             # Break up a reference to itself

/conda/lib/python3.7/site-packages/distributed/client.py in _gather_remote(self, direct, local_worker)
   1758 
   1759             else:  # ask scheduler to gather data for us
-> 1760                 response = yield self.scheduler.gather(keys=keys)
   1761         finally:
   1762             self._gather_semaphore.release()

/conda/lib/python3.7/site-packages/tornado/gen.py in run(self)
    733 
    734                     try:
--> 735                         value = future.result()
    736                     except Exception:
    737                         exc_info = sys.exc_info()

/conda/lib/python3.7/site-packages/tornado/gen.py in run(self)
    740                     if exc_info is not None:
    741                         try:
--> 742                             yielded = self.gen.throw(*exc_info)  # type: ignore
    743                         finally:
    744                             # Break up a reference to itself

/conda/lib/python3.7/site-packages/distributed/core.py in send_recv_from_rpc(**kwargs)
    739             name, comm.name = comm.name, "ConnectionPool." + key
    740             try:
--> 741                 result = yield send_recv(comm=comm, op=key, **kwargs)
    742             finally:
    743                 self.pool.reuse(self.addr, comm)

/conda/lib/python3.7/site-packages/tornado/gen.py in run(self)
    733 
    734                     try:
--> 735                         value = future.result()
    736                     except Exception:
    737                         exc_info = sys.exc_info()

/conda/lib/python3.7/site-packages/tornado/gen.py in run(self)
    740                     if exc_info is not None:
    741                         try:
--> 742                             yielded = self.gen.throw(*exc_info)  # type: ignore
    743                         finally:
    744                             # Break up a reference to itself

/conda/lib/python3.7/site-packages/distributed/core.py in send_recv(comm, reply, serializers, deserializers, **kwargs)
    533         yield comm.write(msg, serializers=serializers, on_error="raise")
    534         if reply:
--> 535             response = yield comm.read(deserializers=deserializers)
    536         else:
    537             response = None

/conda/lib/python3.7/site-packages/tornado/gen.py in run(self)
    733 
    734                     try:
--> 735                         value = future.result()
    736                     except Exception:
    737                         exc_info = sys.exc_info()

/conda/lib/python3.7/site-packages/tornado/gen.py in run(self)
    740                     if exc_info is not None:
    741                         try:
--> 742                             yielded = self.gen.throw(*exc_info)  # type: ignore
    743                         finally:
    744                             # Break up a reference to itself

/conda/lib/python3.7/site-packages/distributed/comm/tcp.py in read(self, deserializers)
    216             try:
    217                 msg = yield from_frames(
--> 218                     frames, deserialize=self.deserialize, deserializers=deserializers
    219                 )
    220             except EOFError:

/conda/lib/python3.7/site-packages/tornado/gen.py in run(self)
    733 
    734                     try:
--> 735                         value = future.result()
    736                     except Exception:
    737                         exc_info = sys.exc_info()

/conda/lib/python3.7/site-packages/tornado/gen.py in run(self)
    740                     if exc_info is not None:
    741                         try:
--> 742                             yielded = self.gen.throw(*exc_info)  # type: ignore
    743                         finally:
    744                             # Break up a reference to itself

/conda/lib/python3.7/site-packages/distributed/comm/utils.py in from_frames(frames, deserialize, deserializers)
     81 
     82     if deserialize and size > FRAME_OFFLOAD_THRESHOLD:
---> 83         res = yield offload(_from_frames)
     84     else:
     85         res = _from_frames()

/conda/lib/python3.7/site-packages/tornado/gen.py in run(self)
    733 
    734                     try:
--> 735                         value = future.result()
    736                     except Exception:
    737                         exc_info = sys.exc_info()

/conda/lib/python3.7/concurrent/futures/_base.py in result(self, timeout)
    423                 raise CancelledError()
    424             elif self._state == FINISHED:
--> 425                 return self.__get_result()
    426 
    427             self._condition.wait(timeout)

/conda/lib/python3.7/concurrent/futures/_base.py in __get_result(self)
    382     def __get_result(self):
    383         if self._exception:
--> 384             raise self._exception
    385         else:
    386             return self._result

/conda/lib/python3.7/concurrent/futures/thread.py in run(self)
     55 
     56         try:
---> 57             result = self.fn(*self.args, **self.kwargs)
     58         except BaseException as exc:
     59             self.future.set_exception(exc)

/conda/lib/python3.7/site-packages/distributed/comm/utils.py in _from_frames()
     69         try:
     70             return protocol.loads(
---> 71                 frames, deserialize=deserialize, deserializers=deserializers
     72             )
     73         except EOFError:

/conda/lib/python3.7/site-packages/distributed/protocol/core.py in loads(frames, deserialize, deserializers)
    124                     fs = decompress(head, fs)
    125                 fs = merge_frames(head, fs)
--> 126                 value = _deserialize(head, fs, deserializers=deserializers)
    127             else:
    128                 value = Serialized(head, fs)

/conda/lib/python3.7/site-packages/distributed/protocol/serialize.py in deserialize(header, frames, deserializers)
    188         )
    189     dumps, loads, wants_context = families[name]
--> 190     return loads(header, frames)
    191 
    192 

/conda/lib/python3.7/site-packages/distributed/protocol/serialize.py in pickle_loads(header, frames)
     62 
     63 def pickle_loads(header, frames):
---> 64     return pickle.loads(b"".join(frames))
     65 
     66 

/conda/lib/python3.7/site-packages/distributed/protocol/pickle.py in loads(x)
     59 def loads(x):
     60     try:
---> 61         return pickle.loads(x)
     62     except Exception:
     63         logger.info("Failed to deserialize %s", x[:10000], exc_info=True)

/conda/lib/python3.7/site-packages/keras/engine/network.py in __setstate__(self, state)
   1264 
   1265     def __setstate__(self, state):
-> 1266         model = saving.unpickle_model(state)
   1267         self.__dict__.update(model.__dict__)
   1268 

/conda/lib/python3.7/site-packages/keras/engine/saving.py in unpickle_model(state)
    433 def unpickle_model(state):
    434     f = h5dict(state, mode='r')
--> 435     return _deserialize_model(f)
    436 
    437 

/conda/lib/python3.7/site-packages/keras/engine/saving.py in _deserialize_model(f, custom_objects, compile)
    285                              ' elements.')
    286         weight_value_tuples += zip(symbolic_weights, weight_values)
--> 287     K.batch_set_value(weight_value_tuples)
    288 
    289     if compile:

/conda/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py in batch_set_value(tuples)
   2468             assign_ops.append(assign_op)
   2469             feed_dict[assign_placeholder] = value
-> 2470         get_session().run(assign_ops, feed_dict=feed_dict)
   2471 
   2472 

~/.local/lib/python3.7/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    948     try:
    949       result = self._run(None, fetches, feed_dict, options_ptr,
--> 950                          run_metadata_ptr)
    951       if run_metadata:
    952         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

~/.local/lib/python3.7/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1114           except Exception as e:
   1115             raise TypeError(
-> 1116                 'Cannot interpret feed_dict key as Tensor: ' + e.args[0])
   1117 
   1118           if isinstance(subfeed_val, ops.Tensor):

TypeError: Cannot interpret feed_dict key as Tensor: Tensor Tensor("Placeholder:0", shape=(160, 640), dtype=float32) is not an element of this graph.
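
The traceback above shows the failure being raised while the object is unpickled in a worker thread (`concurrent/futures/thread.py`, then `pickle.loads`, then `__setstate__`). A minimal sketch for checking, outside Dask, whether an object survives a pickle round trip performed in a separate thread is shown below. The helper name `roundtrip_in_thread` is hypothetical, it uses only the standard-library `pickle` (Dask may serialize differently), and whether it reproduces this exact error for a Keras estimator is an assumption.

```python
import pickle
from concurrent.futures import ThreadPoolExecutor


def roundtrip_in_thread(obj):
    """Pickle obj in the calling thread, then unpickle it in a worker thread.

    Returns (restored_object, None) on success or (None, exception) on failure.
    """
    try:
        payload = pickle.dumps(obj)
    except Exception as exc:  # serialization itself failed
        return None, exc

    # Unpickle in a different thread, loosely mirroring the offloaded
    # deserialization step visible in the traceback.
    with ThreadPoolExecutor(max_workers=1) as pool:
        future = pool.submit(pickle.loads, payload)
        try:
            return future.result(), None
        except Exception as exc:  # e.g. an error raised inside __setstate__
            return None, exc
```

Calling `roundtrip_in_thread(estimator)` on the estimator passed to GridSearchCV returns the exception, if any, so it can be inspected without starting a cluster.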

Issue Analytics

  • State: open
  • Created: 4 years ago
  • Comments: 9 (4 by maintainers)

Top GitHub Comments

stsievert commented, Sep 3, 2020 (1 reaction)

There’s support for Keras serialization now in SciKeras, which brings a Scikit-Learn API to Keras. This is mentioned explicitly in the documentation on https://ml.dask.org/keras.html.

We’re trying to merge serialization support upstream in TensorFlow: https://github.com/tensorflow/tensorflow/pull/39609, https://github.com/tensorflow/community/pull/286
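
As a rough illustration of the SciKeras route mentioned above, the sketch below wraps a small Keras model in `scikeras.wrappers.KerasClassifier` and hands it to `dask_ml.model_selection.GridSearchCV`. The model architecture, the grid values, and the helper names `build_model` and `run_search` are assumptions for illustration and are not taken from the issue; it also assumes `tensorflow`, `scikeras`, `dask_ml`, and `distributed` are installed.

```python
# Hypothetical grid: keys prefixed with "model__" are routed by SciKeras
# to the arguments of build_model; "epochs" is a wrapper parameter.
PARAM_GRID = {"model__hidden_units": [16, 32], "epochs": [5, 10]}


def build_model(meta, hidden_units=16):
    """Return a compiled Keras model; `meta` is filled in by SciKeras."""
    from tensorflow import keras  # imported lazily: heavy dependency

    model = keras.Sequential(
        [
            keras.layers.Input(shape=(meta["n_features_in_"],)),
            keras.layers.Dense(hidden_units, activation="relu"),
            keras.layers.Dense(meta["n_classes_"], activation="softmax"),
        ]
    )
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
    return model


def run_search(X, y, scheduler_address=None):
    """Fit the grid search on a Dask cluster and return the fitted search."""
    from dask.distributed import Client
    from dask_ml.model_selection import GridSearchCV
    from scikeras.wrappers import KerasClassifier

    client = Client(scheduler_address)  # None starts a local cluster
    try:
        estimator = KerasClassifier(model=build_model, verbose=0)
        search = GridSearchCV(estimator, PARAM_GRID, cv=3)
        search.fit(X, y)
        return search
    finally:
        client.close()
```

With integer class labels in `y`, `run_search(X, y)` returns the fitted search object, from which `best_params_` can be read. Passing a scheduler address instead of `None` connects to an existing cluster.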

bw4sz commented, Jan 16, 2020 (0 reactions)

For those interested, I have an example of using a trained Keras model to predict with Dask.

https://github.com/dask/distributed/issues/2333
