Documentation for recurrent
I’m studying RNNs using JAX, so I’m currently investigating Flax. I think the documentation in the RNN module is incorrect or out of date.
Results in TypeError: apply() missing 1 required positional argument: 'inputs'
Also, create builds and evaluates the model and returns a (y, model) tuple, so I feel like the design has changed and the recurrent examples should either initialise the state before calling create (but that wouldn’t scan), or call create_by_shape?
Edit: I found a test which seems to confirm that the docstring is incorrect, i.e. the code below creates an initial carry and passes it to create
2nd Edit:
Also, I’m slightly confused by LSTMCell.initialize_carry(): it requires a batch_dim and returns an initialised (zero) state for each batch. I might be missing something, but this seems to preclude using lax.scan() to process each batch sequentially, using the state from the previous batch as the initial state for the next batch (or using some other state estimator, which is specifically what I’m attempting). For example, I have 365 “trajectories” (time series), each consisting of 24 samples and 5 features. So the state should be of size 5, and I want to scan each trajectory from some initial state I provide (the intent is to use another net to estimate the state).
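To make the question concrete, here is a minimal sketch in plain JAX of the two patterns I have in mind. It deliberately does not use the Flax recurrent API: the tanh cell and the helper names (init_params, rnn_step, run_trajectory, run_chained) are hypothetical and only for illustration, and the zero initial states are placeholders for whatever another network would estimate.

```python
import jax
import jax.numpy as jnp
from jax import lax


def init_params(key, n_features, n_hidden):
    """Random weights for a plain tanh RNN cell (hypothetical helper, not Flax)."""
    k_in, k_rec = jax.random.split(key)
    return {
        "w_in": 0.1 * jax.random.normal(k_in, (n_features, n_hidden)),
        "w_rec": 0.1 * jax.random.normal(k_rec, (n_hidden, n_hidden)),
        "b": jnp.zeros((n_hidden,)),
    }


def rnn_step(params, carry, x_t):
    """One time step; returns (new_carry, output), the pair lax.scan expects."""
    new_carry = jnp.tanh(x_t @ params["w_in"] + carry @ params["w_rec"] + params["b"])
    return new_carry, new_carry


def run_trajectory(params, init_carry, xs):
    """Scan one trajectory of shape (time, features) from a caller-supplied carry."""
    return lax.scan(lambda carry, x_t: rnn_step(params, carry, x_t), init_carry, xs)


def run_chained(params, init_carry, trajectories):
    """Process trajectories in order; each starts from the previous final carry."""
    return lax.scan(lambda carry, xs: run_trajectory(params, carry, xs),
                    init_carry, trajectories)


# Sizes from the example above: 365 trajectories, 24 samples, 5 features.
n_traj, n_steps, n_features = 365, 24, 5
params = init_params(jax.random.PRNGKey(0), n_features, n_hidden=n_features)
data = jax.random.normal(jax.random.PRNGKey(1), (n_traj, n_steps, n_features))

# Pattern 1: one initial state per trajectory (zeros stand in for an estimator's output).
init_states = jnp.zeros((n_traj, n_features))
final_states, outputs = jax.vmap(run_trajectory, in_axes=(None, 0, 0))(
    params, init_states, data)

# Pattern 2: a single initial state, carried from one trajectory into the next.
last_state, chained_outputs = run_chained(params, jnp.zeros((n_features,)), data)
```

In pattern 1 the trajectories are independent, so they can be vmapped; in pattern 2 the outer lax.scan threads the final carry of one trajectory into the next, which is the sequential case I could not see how to express with a zero carry created per batch.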
Issue Analytics
- Created 4 years ago
- Comments: 18 (3 by maintainers)
Top GitHub Comments
I’m closing this issue since most important questions seem to have been answered. David: if you have any other specific concerns / questions, can you please open a new issue? Thanks!
I think both your suggestions are good insights. I guess the stochastic initialize_carry is a bit less crucial (the current way is at most a bit more cumbersome), but the initialization problem seems like something users can run into more often. I will ask around how people usually handle this, and see if this can be improved, or if we can have clearer guidelines on how best to initialize such variables. I guess we should at least have an example. I’ll let you know when I have an answer.