How to use torchtext for sequence labelling with wordpiece tokenizers
❓ Questions and Help
Description
Hi,
In a previous issue (#609), I asked how to use the tokenizer from the Transformers library with torchtext.
I would now like to use this tokenizer together with torchtext to load sequence labelling datasets. The issue I am facing is that the tokenizer introduces wordpiece tokens, which breaks the alignment between tokens and labels.
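A minimal illustration of the problem (the exact wordpieces shown in the comment are illustrative, not the actual output):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)

# One word carries one label in the raw data, but the wordpiece tokenizer
# may split it into several pieces:
print(tokenizer.tokenize('antihypertensive'))
# -> something like ['anti', '##hy', '##pert', '##ens', '##ive'], so token i
#    no longer lines up with label i.
```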
Ignoring labels, I am able to load a sequence labelling dataset with the Transformers tokenizer like so:
```python
from torchtext import data
from torchtext import datasets
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)

def preprocessor(batch):
    return tokenizer.encode(batch, add_special_tokens=True)

TEXT = data.Field(
    use_vocab=False,
    batch_first=True,
    pad_token=tokenizer.pad_token_id,
    preprocessing=preprocessor
)
# LABEL = data.LabelField()

fields = [('text', TEXT), ('unused_col_1', None), ('unused_col_2', None), ('label', None)]

train, valid, test = datasets.SequenceTaggingDataset.splits(
    path='/Users/johngiorgi/Downloads/bert_data/BC5CDR/chem',
    train='train.tsv',
    validation='devel.tsv',
    test='test.tsv',
    fields=fields
)

train_iter, valid_iter, test_iter = data.BucketIterator.splits(
    (train, valid, test), batch_sizes=(16, 256, 256)
)
# LABEL.build_vocab(train)
```
The data comes from here, and is a tab-separated file with four columns. The first column contains words, the last contains labels, and each sentence is separated by a newline, e.g.
```
Naloxone          227508  0   B
reverses          -       9   O
the               -       18  O
antihypertensive  -       22  O
effect            -       39  O
of                -       46  O
clonidine         -       49  B
.                 -       58  O
In                227508  60  O
...
```
But when I try to load the labels, e.g.
```python
from torchtext import data
from torchtext import datasets
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)

def preprocessor(batch):
    return tokenizer.encode(batch, add_special_tokens=True)

TEXT = data.Field(
    use_vocab=False,
    batch_first=True,
    pad_token=tokenizer.pad_token_id,
    preprocessing=preprocessor
)
LABEL = data.LabelField()

fields = [('text', TEXT), ('unused_col_1', None), ('unused_col_2', None), ('label', LABEL)]

train, valid, test = datasets.SequenceTaggingDataset.splits(
    path='/Users/johngiorgi/Downloads/bert_data/BC5CDR/chem',
    train='train.tsv',
    validation='devel.tsv',
    test='test.tsv',
    fields=fields
)

train_iter, valid_iter, test_iter = data.BucketIterator.splits(
    (train, valid, test), batch_sizes=(16, 256, 256)
)
LABEL.build_vocab(train)
```
I get a `TypeError` when trying to access a batch:
```python
batch = next(iter(train_iter))
```

```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-39-9919119fad82> in <module>
----> 1 batch = next(iter(train_iter))

~/miniconda3/envs/ml4h/lib/python3.7/site-packages/torchtext/data/iterator.py in __iter__(self)
    154                 else:
    155                     minibatch.sort(key=self.sort_key, reverse=True)
--> 156                 yield Batch(minibatch, self.dataset, self.device)
    157             if not self.repeat:
    158                 return

~/miniconda3/envs/ml4h/lib/python3.7/site-packages/torchtext/data/batch.py in __init__(self, data, dataset, device)
     32             if field is not None:
     33                 batch = [getattr(x, name) for x in data]
---> 34                 setattr(self, name, field.process(batch, device=device))
     35
     36     @classmethod

~/miniconda3/envs/ml4h/lib/python3.7/site-packages/torchtext/data/field.py in process(self, batch, device)
    235         """
    236         padded = self.pad(batch)
--> 237         tensor = self.numericalize(padded, device=device)
    238         return tensor
    239

~/miniconda3/envs/ml4h/lib/python3.7/site-packages/torchtext/data/field.py in numericalize(self, arr, device)
    336             arr = [[self.vocab.stoi[x] for x in ex] for ex in arr]
    337         else:
--> 338             arr = [self.vocab.stoi[x] for x in arr]
    339
    340         if self.postprocessing is not None:

~/miniconda3/envs/ml4h/lib/python3.7/site-packages/torchtext/data/field.py in <listcomp>(.0)
    336             arr = [[self.vocab.stoi[x] for x in ex] for ex in arr]
    337         else:
--> 338             arr = [self.vocab.stoi[x] for x in arr]
    339
    340         if self.postprocessing is not None:

TypeError: unhashable type: 'list'
```
I am guessing this arises because the number of items in the text and label fields is no longer the same; the `TypeError` itself comes from `vocab.stoi` being asked to look up a whole list at once.
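A torchtext-free reproduction of just that final failure (`vocab.stoi` is a plain dict, and a list can't be a dict key):

```python
stoi = {'B': 0, 'O': 1}  # a stand-in for self.vocab.stoi
stoi[['B', 'O', 'O']]    # looking up a whole list of labels as one key
# TypeError: unhashable type: 'list'
```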
Has anyone come across this issue and been able to solve it? I know how to write a function to re-align the labels with the wordpiece-tokenized text (roughly the sketch below). Where might I insert that function in the loading process?
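For reference, the re-alignment function I have in mind looks roughly like this (`align_labels` is my own sketch, not a torchtext or transformers API):

```python
def align_labels(words, labels, tokenizer):
    """Expand word-level labels so they line up with wordpiece tokens."""
    wordpieces, aligned = [], []
    for word, label in zip(words, labels):
        pieces = tokenizer.tokenize(word)  # one word -> one or more pieces
        wordpieces.extend(pieces)
        # Here every piece inherits the word's label; a special pad/'X'
        # label for the '##' continuation pieces is another common choice.
        aligned.extend([label] * len(pieces))
    return wordpieces, aligned
```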
Issue Analytics
- Created 4 years ago
- Comments: 6 (2 by maintainers)
Top GitHub Comments
@haorannlp Try AllenNLP!
This is an oversight with `LabelField`, as `sequential` defaults to `False`. Please replace this line with a sequential `Field`.
You will get misaligned (sequence-length-wise) batches, but that's fine as you know how to align them.
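Presumably the replacement is something along these lines (a sketch of mine, since the comment's exact snippet is not shown above; `data.Field` is sequential by default):

```python
# Instead of LABEL = data.LabelField(), use a plain Field, which defaults to
# sequential=True, so each example's list of tags is numericalized
# element-wise and padded like any other sequence:
LABEL = data.Field(batch_first=True, is_target=True)
```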