Additional input features for decoder
See original GitHub issue

I'm trying to implement a translation decoder (a subclass of FairseqIncrementalDecoder) that can take additional input features (e.g., language id, tags, etc.) in addition to the input tokens. I am doing this by including the feature as a keyword argument in forward, so the signature looks like this:
def forward(self, prev_output_tokens, encoder_out, segments=None, incremental_state=None):
where segments is the extra input I mentioned.
I also created a custom dataset and translation task to add the segment info into the batch. This works fine with fairseq-train; however, I can't use the model for inference with fairseq-generate, since SequenceGenerator calls the decoder directly with the default inputs (and without **kwargs; see https://github.com/pytorch/fairseq/blob/master/fairseq/sequence_generator.py#L608). Could you recommend a good approach for doing this? Thanks.
Top GitHub Comments
Hello, what I did was to subclass an existing/suitable FairseqDataset and FairseqTask in fairseq, then modify the relevant methods accordingly. Like this: https://github.com/raymondhs/fairseq-laser/blob/master/laser/laser_dataset.py#L19-L37 Here I modified the collater function to include a target language embedding as input to the decoder. (I think it's a little hacky though…)

Thanks for your kind reply! I'll figure out how to do this.