CatBoostClassifier serialization doesn't work
System information
- OS: macOS Sierra
- Ray installed from: binary
- Ray version: 0.5.3
- Python version: 3.6.4
- Exact command to reproduce: run the code below
Describe the problem
Sending a CatBoostClassifier object to a remote function fails during serialization. Serializing the model manually with pickle works fine.
Source code / logs
Fails
```python
import ray
import numpy as np
from catboost import CatBoostClassifier

ray.init()


@ray.remote
def fit(model):
    train_data = np.random.randint(0, 100, size=(100, 10))
    train_label = np.random.randint(0, 2, size=(100))
    model.fit(train_data, train_label, cat_features=[0, 2, 5])
    return model


test_data = np.random.randint(0, 100, size=(50, 10))
model = CatBoostClassifier(iterations=2, depth=2, learning_rate=1, loss_function='Logloss', logging_level='Verbose')
model = ray.get(fit.remote(model))
preds_class = model.predict(test_data)
preds_proba = model.predict_proba(test_data)
print("class = ", preds_class)
print("proba = ", preds_proba)
```
Log
```
WARNING: Serializing objects of type <class 'catboost.core.CatBoostClassifier'> by expanding them as dictionaries of their fields. This behavior may be incorrect in some cases.
WARNING: Falling back to serializing objects of type <class '_catboost._CatBoost'> by using pickle. This may be inefficient.
Traceback (most recent call last):
  File "/Users/eavidan/IdeaProjects/python_playground/ray/serialization issue.py", line 22, in <module>
    model = ray.get(fit.remote(model))
  File "/Users/eavidan/anaconda/envs/ray/lib/python3.6/site-packages/ray/worker.py", line 2451, in func_call
    objectids = _submit_task(function_id, args)
  File "/Users/eavidan/anaconda/envs/ray/lib/python3.6/site-packages/ray/worker.py", line 2300, in _submit_task
    return worker.submit_task(function_id, args)
  File "/Users/eavidan/anaconda/envs/ray/lib/python3.6/site-packages/ray/worker.py", line 528, in submit_task
    args_for_local_scheduler.append(put(arg))
  File "/Users/eavidan/anaconda/envs/ray/lib/python3.6/site-packages/ray/worker.py", line 2198, in put
    worker.put_object(object_id, value)
  File "/Users/eavidan/anaconda/envs/ray/lib/python3.6/site-packages/ray/worker.py", line 357, in put_object
    self.store_and_register(object_id, value)
  File "/Users/eavidan/anaconda/envs/ray/lib/python3.6/site-packages/ray/worker.py", line 287, in store_and_register
    object_id.id()), self.serialization_context)
  File "plasma.pyx", line 394, in pyarrow.plasma.PlasmaClient.put
  File "serialization.pxi", line 235, in pyarrow.lib.serialize
  File "serialization.pxi", line 106, in pyarrow.lib.SerializationContext._serialize_callback
  File "/Users/eavidan/anaconda/envs/ray/lib/python3.6/site-packages/cloudpickle/cloudpickle.py", line 881, in dumps
    cp.dump(obj)
  File "/Users/eavidan/anaconda/envs/ray/lib/python3.6/site-packages/cloudpickle/cloudpickle.py", line 268, in dump
    return Pickler.dump(self, obj)
  File "/Users/eavidan/anaconda/envs/ray/lib/python3.6/pickle.py", line 409, in dump
    self.save(obj)
  File "/Users/eavidan/anaconda/envs/ray/lib/python3.6/pickle.py", line 496, in save
    rv = reduce(self.proto)
  File "stringsource", line 2, in _catboost._CatBoost.__reduce_cython__
TypeError: no default __reduce__ due to non-trivial __cinit__
Disconnecting client on fd 9
[WARN] (/Users/rkn/Workspace/ray/src/global_scheduler/global_scheduler.cc:404) Missed too many heartbeats from local scheduler, marking as dead.
Disconnecting client on fd 5

Process finished with exit code 1
```
Works
```python
import ray
import numpy as np
from catboost import CatBoostClassifier
import pickle

ray.init()


@ray.remote
def fit(model):
    train_data = np.random.randint(0, 100, size=(100, 10))
    train_label = np.random.randint(0, 2, size=(100))
    # The model arrives as pickled bytes, so Ray only has to ship a bytes object.
    model = pickle.loads(model)
    model.fit(train_data, train_label, cat_features=[0, 2, 5])
    # Pickle again so the trained model also crosses the task boundary as bytes.
    return pickle.dumps(model)


test_data = np.random.randint(0, 100, size=(50, 10))
model = CatBoostClassifier(iterations=2, depth=2, learning_rate=1, loss_function='Logloss', logging_level='Verbose')
# Pickle on the way in; unpickle the returned bytes on the way out.
model = pickle.loads(ray.get(fit.remote(pickle.dumps(model))))
preds_class = model.predict(test_data)
preds_proba = model.predict_proba(test_data)
print("class = ", preds_class)
print("proba = ", preds_proba)
```
Issue Analytics
- Created 5 years ago
- Comments: 5 (3 by maintainers)
Top GitHub Comments
This is fixed by https://github.com/ray-project/ray/pull/3468 now!
Running your code from above: