Refactoring of `DeMask` model
❓ Questions and Help
Goal
Hey guys, I was planning to refactor `DeMask` to use `BaseEncoderMaskerDecoder` so that it supports 1D input data when traced with 2D (and the other way round). Currently, a traced `DeMask` model only works with data that has the same input shape as the example used for tracing, because of the way `forward` is written.
Refactoring the class with `BaseEncoderMaskerDecoder` would automatically solve this issue, in addition to cleaning up the class. However, I had some questions before making a whole bunch of changes (see below).
Questions about the current code
In `__init__`
1. When computing `n_feats_input` with `fb_type` set to `stft`, isn't the value `n_feats_input = (self.encoder.filterbank.n_filters) // 2 + 1` only valid for even values of `n_filters`? If it is, wouldn't we need to either support odd values of `n_filters` or raise an error when `n_filters` is odd?
2. When computing `n_feats_output`, there is a check on `self.input_type`. Shouldn't it be on `self.output_type` instead? The error message also mentions "Input type", but shouldn't it be about "Output type"?
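To make the two points above concrete, here is a rough sketch of what the checks could look like. It is not the actual `DeMask` code: the helper names `stft_n_feats_input` and `check_output_type` and the set of accepted output types are assumptions made purely for illustration. The first helper takes the "raise an error" option for odd `n_filters`; the second validates the output type and names it correctly in the message.

```python
# Hypothetical helpers for illustration only, not part of Asteroid.
ASSUMED_OUTPUT_TYPES = ("mag", "reim")  # assumption: the real accepted values may differ


def stft_n_feats_input(n_filters: int) -> int:
    """Return n_filters // 2 + 1, refusing odd filter counts explicitly."""
    if n_filters % 2 != 0:
        raise ValueError(
            f"n_filters must be even when fb_type is 'stft', got {n_filters}"
        )
    return n_filters // 2 + 1


def check_output_type(output_type: str) -> str:
    """Validate the *output* type and mention 'Output type' in the error."""
    if output_type not in ASSUMED_OUTPUT_TYPES:
        raise NotImplementedError(
            f"Output type should be one of {ASSUMED_OUTPUT_TYPES}, "
            f"got {output_type!r}"
        )
    return output_type
```

Failing early like this would surface a bad configuration at construction time rather than somewhere inside `forward`.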
In `forward`
The current flow of data in the `forward` of `BaseEncoderMaskerDecoder` is as follows (written as plain text, as Mermaid diagrams are not supported on GitHub yet):
wav -> <input shaping> -> [Encoder] -> [Encoder postprocess] -> [Encoder activation] -> tf_rep
tf_rep -> [Masker] -> masker_out --> [Masker postprocess] -> <mult> -> masked
  |                                                            ^
  |------------------------------------------------------------|
masked -> [Decoder] -> [Decoder postprocess] -> <shape output> -> reconstruction
where the values in square brackets indicate blocks that can be changed by any inheriting class, and the values in <> indicate operations that are hard-coded in the `forward` of `BaseEncoderMaskerDecoder`.
As the `forward` of `DeMask` is written, the input to the mask estimator is not always `tf_rep`. To make it compatible with `BaseEncoderMaskerDecoder`, we could add a `preprocess_masker_input` hook between `tf_rep` and `[Masker]` in the base class. Like all the hooks in `BaseEncoderMaskerDecoder`, this would have a default behaviour of returning the data given as input.
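Here is a toy sketch of that idea, just to show the shape of the change. It uses plain Python lists instead of tensors and leaves out input shaping, the encoder activation and output shaping. The class names and every hook name other than `preprocess_masker_input` are placeholders, and the "magnitude only" override is an assumption for illustration, not what `DeMask` necessarily does.

```python
class ToyEncoderMaskerDecoder:
    """Toy stand-in for the base class, working on lists of floats."""

    def __init__(self, encoder, masker, decoder):
        self.encoder = encoder
        self.masker = masker
        self.decoder = decoder

    # Hooks: identity by default, overridden by subclasses when needed.
    def postprocess_encoded(self, tf_rep):
        return tf_rep

    def preprocess_masker_input(self, tf_rep):  # the proposed new hook
        return tf_rep

    def postprocess_masks(self, masks):
        return masks

    def postprocess_decoded(self, decoded):
        return decoded

    def forward(self, wav):
        tf_rep = self.postprocess_encoded(self.encoder(wav))
        masker_in = self.preprocess_masker_input(tf_rep)
        masks = self.postprocess_masks(self.masker(masker_in))
        # Hard-coded <mult>: masks always apply to tf_rep, not to masker_in.
        masked = [t * m for t, m in zip(tf_rep, masks)]
        return self.postprocess_decoded(self.decoder(masked))


class ToyDeMask(ToyEncoderMaskerDecoder):
    def preprocess_masker_input(self, tf_rep):
        # Assumption for illustration: the masker only sees magnitudes.
        return [abs(t) for t in tf_rep]


def identity(xs):
    return xs


def keep_large(xs):
    # Toy masker: binary mask keeping entries strictly greater than 1.
    return [1.0 if x > 1 else 0.0 for x in xs]


base = ToyEncoderMaskerDecoder(identity, keep_large, identity)
demask = ToyDeMask(identity, keep_large, identity)
wav = [-2.0, 0.5, 3.0]
base_out = base.forward(wav)
demask_out = demask.forward(wav)
```

With this input the base model masks out the `-2.0` entry, while the toy `DeMask` keeps it because its masker saw `2.0` instead. The subclass only overrides the hook, and `forward` itself stays in the base class, which is the point of the refactoring.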
Top GitHub Comments
`DeMask` was merged in a hurry so the code is not so clean indeed. It's a good thing to clean it up, thanks!
1. This code should probably use `filterbank.n_feats_out`, which accounts for the Nyquist frequency in the STFT. About the error, I think it'll be raised in `take_mag`, but we can raise it before as well.
2. Correct.
Completely agree, let's add that!
Things we have to pay attention to:
See #294 about that question.
Please open the PR so we can discuss it there 😃