
[tune] torchtext pickle error


The following code runs on Colab:

analysis = tune.run(train_lstm, config={
            "lr": tune.grid_search([0.0005, 0.001, 0.005]),
            "clip": tune.grid_search([1.0, 0.5, 0.1])
            }
    )

print("Best config: ", analysis.get_best_config(metric="mean_accuracy"))

# Get a dataframe for analyzing trial results.
df = analysis.dataframe()

where train_lstm is my own function, defined using PyTorch.
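For context: each tune.grid_search entry is expanded as a cross product, so this config should launch 3 × 3 = 9 trials. Tune first pickles train_lstm (the pickle.dumps call in registry.py, which is where the traceback below fails) so that a copy can be shipped to each trial's worker process. The snippet is a minimal standard-library sketch of the expansion only, not Ray's implementation; grid and expand_grid are hypothetical names.

```python
import itertools

# Hypothetical stand-in for the config above: each list plays the role of
# one tune.grid_search([...]) entry.
grid = {
    "lr": [0.0005, 0.001, 0.005],
    "clip": [1.0, 0.5, 0.1],
}


def expand_grid(grid):
    """Return one config dict per element of the cross product of the grid values."""
    keys = list(grid)
    return [dict(zip(keys, values)) for values in itertools.product(*grid.values())]


trials = expand_grid(grid)
print(len(trials))  # 9
print(trials[0])    # {'lr': 0.0005, 'clip': 1.0}
```

Every one of those trials runs in its own process, which is why everything the trainable references has to survive serialization.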

<ipython-input-31-c51e8c1cd47e> in <module>()
      2 analysis = tune.run(train_lstm, config={
      3             "lr": tune.grid_search([0.0005, 0.001, 0.005]),
----> 4             "clip": tune.grid_search([1.0, 0.5, 0.1])
      5             }
      6     )

5 frames
/usr/local/lib/python3.6/dist-packages/ray/tune/tune.py in run(run_or_experiment, name, stop, config, resources_per_trial, num_samples, local_dir, upload_dir, trial_name_creator, loggers, sync_to_cloud, sync_to_driver, checkpoint_freq, checkpoint_at_end, sync_on_checkpoint, keep_checkpoints_num, checkpoint_score_attr, global_checkpoint_period, export_formats, max_failures, restore, search_alg, scheduler, with_server, server_port, verbose, resume, queue_trials, reuse_actors, trial_executor, raise_on_failed_trial, return_trials, ray_auto_init, sync_function)
    227     for i, exp in enumerate(experiments):
    228         if not isinstance(exp, Experiment):
--> 229             run_identifier = Experiment.register_if_needed(exp)
    230             experiments[i] = Experiment(
    231                 name=name,

/usr/local/lib/python3.6/dist-packages/ray/tune/experiment.py in register_if_needed(cls, run_object)
    210                 logger.warning(
    211                     "No name detected on trainable. Using {}.".format(name))
--> 212             register_trainable(name, run_object)
    213             return name
    214         else:

/usr/local/lib/python3.6/dist-packages/ray/tune/registry.py in register_trainable(name, trainable)
     65         raise TypeError("Second argument must be convertable to Trainable",
     66                         trainable)
---> 67     _global_registry.register(TRAINABLE_CLASS, name, trainable)
     68 
     69 

/usr/local/lib/python3.6/dist-packages/ray/tune/registry.py in register(self, category, key, value)
    104             raise TuneError("Unknown category {} not among {}".format(
    105                 category, KNOWN_CATEGORIES))
--> 106         self._to_flush[(category, key)] = pickle.dumps(value)
    107         if _internal_kv_initialized():
    108             self.flush_values()

/usr/local/lib/python3.6/dist-packages/ray/cloudpickle/cloudpickle_fast.py in dumps(obj, protocol, buffer_callback)
     66     with io.BytesIO() as file:
     67         cp = CloudPickler(file, protocol=protocol, buffer_callback=buffer_callback)
---> 68         cp.dump(obj)
     69         return file.getvalue()
     70 

/usr/local/lib/python3.6/dist-packages/ray/cloudpickle/cloudpickle_fast.py in dump(self, obj)
    555     def dump(self, obj):
    556         try:
--> 557             return Pickler.dump(self, obj)
    558         except RuntimeError as e:
    559             if "recursion" in e.args[0]:

TypeError: 'generator' object is not callable

Can anyone figure out what is going wrong with the pickling?

Issue Analytics

  • State: closed
  • Created 4 years ago
  • Comments:7 (5 by maintainers)

Top GitHub Comments

6 reactions
richardliaw commented, Feb 22, 2020

Sorry for the slow reply - I actually thought I posted a reply already 😃

This is a scoping issue. Torchtext fields and datasets need to be created within the Trainable because the Trainable needs to be pickled and replicated across multiple parallel processes.

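If they are created outside it (for example at notebook scope) and referenced from the trainable, Ray's cloudpickle will try to serialize them along with the trainable and fail.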
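A minimal sketch of that failure mode, assuming only the standard-library pickle module. Ray actually uses cloudpickle, which can also serialize notebook-defined functions by value together with the globals they reference, so this is a simplification. make_batches stands in for a torchtext iterator, and HoldsGenerator, train_fn and is_picklable are hypothetical names: an object that already holds a live generator cannot be pickled, while a function that builds its generator from the config when it is called can be.

```python
import pickle


def make_batches(n):
    """Hypothetical stand-in for a torchtext iterator: yields items lazily."""
    for i in range(n):
        yield i


class HoldsGenerator:
    """Mimics a trainable that captured an already-built iterator."""

    def __init__(self):
        self.batches = make_batches(3)  # a live generator: not picklable


def train_fn(config):
    """Mimics a trainable that builds its data inside the function."""
    batches = make_batches(config["n"])  # created in the worker, never pickled
    return sum(batches)


def is_picklable(obj):
    """Return True if obj survives pickle.dumps, False on a TypeError."""
    try:
        pickle.dumps(obj)
        return True
    except TypeError:
        return False


print(is_picklable(HoldsGenerator()))                  # False
print(is_picklable(train_fn))                          # True
print(pickle.loads(pickle.dumps(train_fn))({"n": 4}))  # 6
```

The design choice mirrors the advice above: ship only the recipe (function plus config) to the workers and let each worker build its own data objects.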

For example, you’ll want to do something like:

from ray import tune

from nltk.util import ngrams
from collections import defaultdict, Counter
import numpy as np
import math
import tqdm
import random

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Parameter
import torch.optim as optim
import torchtext

# NOTE: this data loading is duplicated at module level only because the vocab
# size is needed to create the model. Code used by the trainable should only read
# plain values such as vocab_size from here, never the Field or Dataset objects.
# download and load the data
text_field = torchtext.data.Field()
datasets = torchtext.datasets.WikiText2.splits(root='.', text_field=text_field)
train_dataset, validation_dataset, test_dataset = datasets
text_field.build_vocab(train_dataset)
vocab = text_field.vocab
vocab_size = len(vocab)

train_text = train_dataset.examples[0].text  # a list of tokens (strings)
validation_text = validation_dataset.examples[0].text

#  ... other code

class LSTMModel():
    #  ... other code
    # (The elided constructor is assumed to store lr and clip as self.lr and self.clip.)
    def train_on_dataset(self, dataset, validation_dataset):
        # The iterator is built inside the trainable, so it is never pickled.
        train_iterator = torchtext.data.BPTTIterator(dataset, batch_size=64,
                                                     bptt_len=32, device='cuda')
        ## THE REST OF THE CODE IS THE SAME AS YOUR `train` FUNCTION
        # Initial LSTM state: (num_layers=3, batch_size=64, hidden_size).
        # torch.autograd.Variable is deprecated; plain tensors behave the same.
        h = torch.zeros(3, 64, self.network.lstm.hidden_size, device='cuda')
        c = torch.zeros(3, 64, self.network.lstm.hidden_size, device='cuda')
        state = (h, c)

        optimizer = optim.Adam(self.network.parameters(), lr=self.lr)
        n_epochs = 20

        validation_score_best = 300
        # NOTE: every trial writes to this same file; use a per-trial path
        # when trials run in parallel.
        PATH = '/content/model.pth'

        for epoch in range(n_epochs):
            print('Epoch', epoch)
            # Plain tqdm: the trial runs in a Ray worker process, not in the notebook kernel.
            for batch in tqdm.tqdm(train_iterator, leave=False):
                assert self.network.training, 'make sure your network is in train mode with `.train()`'

                # call zero_grad on your optimizer
                optimizer.zero_grad()

                context, word = batch.text, batch.target
                # Detach the carried-over state so backward() does not reach
                # into the previous batch's (already freed) graph.
                state = tuple(s.detach() for s in state)
                # run your network
                logits, state = self.network(context, state)

                logits = logits.view(-1, logits.shape[-1])
                word = word.view(-1)

                # compute a loss
                loss = F.cross_entropy(logits, word)

                # call `.backward()`, clip gradients with the tuned `clip` value
                # (assumed to be self.clip), then `.step()` the optimizer
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), self.clip)
                optimizer.step()

            validation_score = self.dataset_perplexity(validation_dataset)
            print('Validation score:', validation_score)

            if validation_score <= validation_score_best:
                validation_score_best = validation_score
                torch.save(self.network.state_dict(), PATH)

        self.network.load_state_dict(torch.load(PATH))

def train_net(config):
    # Create the torchtext Field and datasets inside the trainable so that
    # Ray/cloudpickle never has to serialize them.
    text_field = torchtext.data.Field()
    datasets = torchtext.datasets.WikiText2.splits(root='.', text_field=text_field)
    train_dataset, validation_dataset, test_dataset = datasets
    # The new Field needs its own vocab before the iterator can numericalize text.
    text_field.build_vocab(train_dataset)
    validation_text = validation_dataset.examples[0].text
    lstm_model = LSTMModel(config['lr'], config['clip'])
    lstm_model.train_on_dataset(train_dataset, validation_dataset)
    perp = lstm_model.perplexity(validation_text)
    print('lstm validation perplexity:', perp)
    tune.track.log(perplexity=perp)


analysis = tune.run(
    train_net,
    config={"lr": tune.grid_search([0.001, 0.01]), 'clip': tune.grid_search([0.001, 0.01])},
    # Each trial moves tensors to 'cuda', so request one GPU per trial.
    resources_per_trial={"gpu": 1})

# The trainable reports `perplexity` (lower is better), not `mean_accuracy`.
print("Best config: ", analysis.get_best_config(metric="perplexity", mode="min"))

# Get a dataframe for analyzing trial results.
df = analysis.dataframe()

0 reactions
JiahaoYao commented, Feb 24, 2020

Thanks @richardliaw, I will close #7202.

