RandomizedSearchCV with pytorch-tabnet
It appears that TabNetClassifier does not have a get_params method, which scikit-learn needs for hyperparameter search.
Is this reproducible on your end?
Many thanks
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-33-03d6c8d15377> in <module>()
4
5 start = time()
----> 6 randomSearch.fit(X_train, y_train)
7
8
1 frames
/usr/local/lib/python3.6/dist-packages/sklearn/base.py in clone(estimator, safe)
65 "it does not seem to be a scikit-learn estimator "
66 "as it does not implement a 'get_params' methods."
---> 67 % (repr(estimator), type(estimator)))
68 klass = estimator.__class__
69 new_object_params = estimator.get_params(deep=False)
TypeError: Cannot clone object 'TabNetClassifier(n_d=32, n_a=32, n_steps=5,
lr=0.02, seed=0,
gamma=1.5, n_independent=2, n_shared=2,
cat_idxs=[],
cat_dims=[],
cat_emb_dim=1,
lambda_sparse=0.0001, momentum=0.3,
clip_value=2.0,
verbose=1, device_name="auto",
model_name="DreamQuarkTabNet", epsilon=1e-15,
optimizer_fn=<class 'torch.optim.adam.Adam'>,
scheduler_params={'gamma': 0.95, 'step_size': 20},
scheduler_fn=<class 'torch.optim.lr_scheduler.StepLR'>, saving_path="./")' (type <class 'pytorch_tabnet.tab_model.TabNetClassifier'>): it does not seem to be a scikit-learn estimator as it does not implement a 'get_params' methods.
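The traceback shows the contract that `sklearn.base.clone` relies on: it reads `estimator.get_params(deep=False)`, passes those values back to `estimator.__class__(**params)`, and checks that the new object stores them unchanged. The search classes then call `set_params(**candidate)` on each clone. Below is a minimal, dependency-free sketch of a duck-typed wrapper that satisfies this contract. `SklearnStyleWrapper` and `model_factory` are hypothetical names, and the sketch assumes the wrapped model exposes `fit(X, y)` and `predict(X)`; the `fit` arguments TabNetClassifier accepts differ between pytorch-tabnet releases, so the `fit` body is an assumption to adapt.

```python
from typing import Any, Callable, Optional


class SklearnStyleWrapper:
    """Hypothetical duck-typed wrapper exposing the estimator contract that
    sklearn.base.clone() and the *SearchCV classes rely on."""

    def __init__(self, model_factory: Optional[Callable[..., Any]] = None, **model_params: Any):
        # Store constructor arguments unchanged: clone() rebuilds the object
        # from get_params(deep=False) and checks that the values round-trip.
        self.model_factory = model_factory
        self.model_params = model_params

    def get_params(self, deep: bool = True) -> dict:
        # `deep` is accepted for API compatibility; nested estimators are not expanded.
        return {"model_factory": self.model_factory, **self.model_params}

    def set_params(self, **params: Any) -> "SklearnStyleWrapper":
        # Search classes call set_params(**candidate) on a fresh clone.
        for name, value in params.items():
            if name == "model_factory":
                self.model_factory = value
            else:
                self.model_params[name] = value
        return self

    def fit(self, X, y):
        # Build a new underlying model on every fit so clones never share state.
        # Assumes the wrapped model exposes fit(X, y); adapt to your library version.
        self.model_ = self.model_factory(**self.model_params)
        self.model_.fit(X, y)
        return self

    def predict(self, X):
        return self.model_.predict(X)

    def score(self, X, y) -> float:
        # Plain accuracy; search classes fall back to score() when no `scoring` is given.
        predictions = list(self.predict(X))
        labels = list(y)
        return sum(p == t for p, t in zip(predictions, labels)) / len(labels)
```

In principle, `SklearnStyleWrapper(model_factory=TabNetClassifier, n_d=8)` could then be handed to a search class with a grid over `n_d`. If scikit-learn is already a dependency, subclassing `sklearn.base.BaseEstimator` and `ClassifierMixin` with explicit constructor arguments provides `get_params`, `set_params` and `score` for free and is the more robust route.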
Issue created 4 years ago; 19 comments (7 by maintainers).
Top GitHub Comments
@Catypad, compatibility with sklearn is not established at the level of using sklearn utilities such as GridSearchCV out of the box. You may try ParameterGrid instead and loop over the candidate settings yourself. Hope it helps!
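A minimal sketch of the ParameterGrid-style approach: enumerate every parameter combination, train one model per combination, and keep the best validation score. To stay dependency-free, `itertools.product` over sorted keys stands in for `sklearn.model_selection.ParameterGrid`, which yields the same dictionaries for a single grid dict. `iter_grid`, `manual_grid_search` and the `train_and_score` callback are hypothetical names introduced for illustration.

```python
from itertools import product
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple


def iter_grid(grid: Dict[str, Sequence[Any]]) -> Iterator[Dict[str, Any]]:
    """Yield one dict per combination, like ParameterGrid(grid) for a single dict."""
    keys = sorted(grid)
    for values in product(*(grid[k] for k in keys)):
        yield dict(zip(keys, values))


def manual_grid_search(
    grid: Dict[str, Sequence[Any]],
    train_and_score: Callable[[Dict[str, Any]], float],
) -> Tuple[Optional[Dict[str, Any]], float, List[Tuple[Dict[str, Any], float]]]:
    """Train one model per parameter combination and keep the best score.

    `train_and_score` is a hypothetical user-supplied callable: it should build
    a model from `params`, fit it on the training split and return a score on a
    held-out validation split (higher is better).
    """
    best_params: Optional[Dict[str, Any]] = None
    best_score = float("-inf")
    history: List[Tuple[Dict[str, Any], float]] = []
    for params in iter_grid(grid):
        score = train_and_score(params)
        history.append((params, score))
        if score > best_score:
            best_params, best_score = params, score
    return best_params, best_score, history
```

With TabNet, `train_and_score` would construct `TabNetClassifier(**params)`, fit it and return, for example, validation accuracy; the exact `fit` arguments depend on the installed pytorch-tabnet version. Passing the training routine as a callback keeps the search loop independent of any one library's API.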
I hope my feedback has been received as constructive, as it’s meant to be.
That said, I’ll do my best to add it, cheers!