RandomizedSearchCV with pytorch-tabnet

It appears that TabNetClassifier does not implement a get_params method, so it cannot be used with scikit-learn's hyperparameter search tools such as RandomizedSearchCV.

Is this reproducible on your end?

Many thanks

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-33-03d6c8d15377> in <module>()
      4 
      5 start = time()
----> 6 randomSearch.fit(X_train, y_train)
      7 
      8 

1 frames
/usr/local/lib/python3.6/dist-packages/sklearn/base.py in clone(estimator, safe)
     65                             "it does not seem to be a scikit-learn estimator "
     66                             "as it does not implement a 'get_params' methods."
---> 67                             % (repr(estimator), type(estimator)))
     68     klass = estimator.__class__
     69     new_object_params = estimator.get_params(deep=False)

TypeError: Cannot clone object 'TabNetClassifier(n_d=32, n_a=32, n_steps=5,
                 lr=0.02, seed=0,
                 gamma=1.5, n_independent=2, n_shared=2,
                 cat_idxs=[],
                 cat_dims=[],
                 cat_emb_dim=1,
                 lambda_sparse=0.0001, momentum=0.3,
                 clip_value=2.0,
                 verbose=1, device_name="auto",
                 model_name="DreamQuarkTabNet", epsilon=1e-15,
                 optimizer_fn=<class 'torch.optim.adam.Adam'>,
                 scheduler_params={'gamma': 0.95, 'step_size': 20},
                 scheduler_fn=<class 'torch.optim.lr_scheduler.StepLR'>, saving_path="./")' (type <class 'pytorch_tabnet.tab_model.TabNetClassifier'>): it does not seem to be a scikit-learn estimator as it does not implement a 'get_params' methods.
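The traceback comes from sklearn.base.clone, which RandomizedSearchCV uses to rebuild a fresh copy of the estimator from get_params(deep=False) for every candidate. A minimal sketch of why the error occurs, using two hypothetical toy classes (PlainModel and SklearnModel, not part of pytorch-tabnet): a class without get_params cannot be cloned, while the same constructor inheriting from scikit-learn's BaseEstimator can.

```python
from sklearn.base import BaseEstimator, clone


class PlainModel:
    """Hypothetical stand-in for a model that lacks get_params."""

    def __init__(self, n_d=8, lr=0.02):
        self.n_d = n_d
        self.lr = lr


class SklearnModel(BaseEstimator):
    """Same constructor, but BaseEstimator supplies get_params/set_params."""

    def __init__(self, n_d=8, lr=0.02):
        # BaseEstimator introspects __init__ and expects each argument
        # to be stored unchanged under an attribute of the same name.
        self.n_d = n_d
        self.lr = lr


def try_clone(estimator):
    """Return the cloned estimator, or the TypeError message if clone fails."""
    try:
        return clone(estimator)
    except TypeError as exc:
        return str(exc)


plain_result = try_clone(PlainModel(n_d=32))     # error message mentioning get_params
sk_result = try_clone(SklearnModel(n_d=32))      # new SklearnModel(n_d=32, lr=0.02)
```

This only illustrates the contract clone relies on; it does not show how pytorch-tabnet itself addresses the issue.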

Issue Analytics

  • State: closed
  • Created 4 years ago
  • Comments: 19 (7 by maintainers)

Top GitHub Comments

meechos commented, Apr 20, 2020 (3 reactions)

@Catypad, compatibility with sklearn is not established out of the box at the level of sklearn tools such as GridSearchCV. As a workaround, you can loop over a ParameterGrid yourself, as in the sketch below:

import numpy as np
import pandas as pd
import torch
from pytorch_tabnet.tab_model import TabNetClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import ParameterGrid

# Assumes X_train, y_train, X_valid, y_valid, X_test, y_test, num_epochs,
# patience, batch_size and virtual_batch_size are defined elsewhere.

# Function that instantiates a TabNet model (constructor arguments as in
# the pytorch-tabnet version shown in the traceback above).
def create_tabnet(n_d=32, n_a=32, n_steps=5, lr=0.02, gamma=1.5,
                  n_independent=2, n_shared=2, lambda_sparse=1e-4,
                  momentum=0.3, clip_value=2., cat_idxs=None,
                  cat_dims=None, cat_emb_dim=1):

    model = TabNetClassifier(
        n_d=n_d, n_a=n_a, n_steps=n_steps,
        lr=lr,
        gamma=gamma, n_independent=n_independent, n_shared=n_shared,
        cat_idxs=cat_idxs or [],
        cat_dims=cat_dims or [],
        cat_emb_dim=cat_emb_dim,
        lambda_sparse=lambda_sparse, momentum=momentum, clip_value=clip_value,
        optimizer_fn=torch.optim.Adam,
        scheduler_params={"gamma": 0.95, "step_size": 20},
        scheduler_fn=torch.optim.lr_scheduler.StepLR, epsilon=1e-15, verbose=0
    )
    return model

# Generate the parameter grid.
param_grid = dict(n_d=[24, 32],
                  n_a=[24],  # overwritten below so that n_a == n_d
                  n_steps=[3, 4, 5],
                  lr=[0.01, 0.02],
                  gamma=[1, 1.5, 2],
                  lambda_sparse=[1e-2, 1e-3, 1e-4],
                  momentum=[0.3, 0.4, 0.5],
                  n_shared=[2],
                  n_independent=[2],
                  clip_value=[2.],
)

grid = ParameterGrid(param_grid)

rows = []
for params in grid:
    params['n_a'] = params['n_d']  # n_a = n_d always, per the paper
    # Pass the hyperparameters to the constructor directly rather than
    # calling set_params, which this TabNetClassifier may not provide.
    tabnet = create_tabnet(**params)
    tabnet.fit(X_train=X_train, y_train=y_train,
               X_valid=X_valid, y_valid=y_valid,
               max_epochs=num_epochs, patience=patience,
               batch_size=batch_size, virtual_batch_size=virtual_batch_size,
              )

    # accuracy_score expects predicted labels, not class probabilities.
    y_pred = tabnet.predict(X_test)
    score = accuracy_score(y_test, y_pred)

    rows.append({**params, 'score': np.round(score, 3)})

# DataFrame.append is deprecated (removed in pandas 2.0), so build the
# results table once after the loop.
search_results = pd.DataFrame(rows)

Hope it helps!
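Since the question is about RandomizedSearchCV, the same manual loop can draw a random subset of settings with scikit-learn's ParameterSampler instead of enumerating the full grid. A minimal sketch, assuming a hypothetical evaluate(params) callback that would build, fit and score a TabNet model on validation data; here it is replaced by a toy scoring function so the block runs on its own.

```python
from sklearn.model_selection import ParameterSampler

param_grid = {
    "n_d": [24, 32],
    "n_steps": [3, 4, 5],
    "lr": [0.01, 0.02],
    "gamma": [1.0, 1.5, 2.0],
    "lambda_sparse": [1e-2, 1e-3, 1e-4],
    "momentum": [0.3, 0.4, 0.5],
}


def evaluate(params):
    """Hypothetical placeholder: in practice, build a TabNetClassifier from
    params, fit it, and return a validation score. Here a toy score is used."""
    return -abs(params["gamma"] - 1.5) - params["lambda_sparse"]


def random_search(param_grid, n_iter, evaluate, random_state=0):
    """Sample n_iter settings from param_grid and return them sorted best-first."""
    results = []
    # With only lists in param_grid, ParameterSampler samples without replacement.
    for params in ParameterSampler(param_grid, n_iter=n_iter, random_state=random_state):
        params = dict(params)
        params["n_a"] = params["n_d"]  # keep n_a equal to n_d, as in the loop above
        results.append({**params, "score": evaluate(params)})
    return sorted(results, key=lambda row: row["score"], reverse=True)


results = random_search(param_grid, n_iter=10, evaluate=evaluate)
```

Sampling 10 of the 324 combinations trades exhaustiveness for far fewer TabNet training runs, which is the same trade-off RandomizedSearchCV makes.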

meechos commented, Feb 19, 2020 (1 reaction)

I hope my feedback has been received as constructive, as it’s meant to be.

That said, I’ll do my best to add it, cheers!
