size of slices can’t be functions of argument values
I’m having problems implementing the following `fori_loop` body.
```python
from jax.lax import dynamic_slice_in_dim

def body_fun(i, val):
    batch_sizes, packed_input_data, input_offset = val
    batch_size = batch_sizes[i]  # traced inside fori_loop
    # batch = packed_input_data[input_offset:input_offset+batch_size]
    # Fails under tracing: slice_size must be a concrete Python int.
    batch = dynamic_slice_in_dim(packed_input_data, input_offset, batch_size)
    input_offset += batch_size
    return batch_sizes, packed_input_data, input_offset
```
I’m trying to replace the commented-out line with JAX primitives, i.e. slicing `packed_input_data` (n x m) starting at `input_offset` (an int) with length `batch_sizes[i]`. The issue I’m running into is that `dynamic_slice_in_dim` throws:

TypeError: Abstract value passed to `int`, which requires a concrete value. Try using `value.astype(int)` instead.
What seems to be happening is that `dynamic_slice_in_dim` tries to convert `input_offset` and/or `batch_size` with `int()`, which fails because they are traced values that don’t support this conversion. If I pass literal ints for `start_index` and `slice_size`, it works. Calling `body_fun(i, val)` from a plain Python `for` loop also works without problems.
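As the title says, the size of a slice can’t be a function of traced argument values: inside `fori_loop` (or `jit`) the loop index and carry are abstract tracers, and a slice’s size determines the shape of its result, which JAX must know at trace time. `lax.dynamic_slice_in_dim` accepts a traced `start_index`, but `slice_size` has to be a concrete Python int. A common workaround, sketched here under the assumption that an upper bound on the batch size is known, is to always slice a fixed-size window and mask out the extra rows. The helper `masked_rows` and its `max_size` parameter are hypothetical names.

```python
import jax
import jax.numpy as jnp
from jax import lax


def masked_rows(data, start, size, max_size):
    """Return a (max_size, m) window of `data` starting at row `start`.

    `start` and `size` may be traced; `max_size` must be a Python int
    because it fixes the output shape. Rows at or beyond `size` are zeroed.
    """
    # Pad with max_size zero rows so the window never runs past the end;
    # dynamic_slice would otherwise clamp `start` and shift the window.
    padded = jnp.pad(data, ((0, max_size), (0, 0)))
    window = lax.dynamic_slice_in_dim(padded, start, max_size, axis=0)
    keep = jnp.arange(max_size) < size  # traced size only affects values
    return jnp.where(keep[:, None], window, 0)


data = jnp.arange(12.0).reshape(6, 2)
# Works under jit even though start and size are traced.
f = jax.jit(masked_rows, static_argnums=3)
out = f(data, 4, 2, 3)
```

The padding matters because `dynamic_slice` clamps start indices so the slice stays in bounds, which without padding would silently shift the window near the end of the array.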
How should I approach this problem? I can paste the full code if required. My full implementation also includes `packed_output_data` (same shape as `packed_input_data`), which I’m trying to update using `dynamic_update_slice`, plus a state array that I also slice by `batch_size`.

(I’m trying to implement variable-length sequence RNNs using PyTorch as inspiration.)
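Extending the fixed-window idea to the full loop described above (output buffer plus RNN state) might look like the following sketch. It assumes PyTorch-style packing, where `batch_sizes` is non-increasing so the active sequences at each step are the first `batch_size` rows of the state, and it uses a hypothetical `tanh` update in place of a real RNN cell. `run_packed` and `max_batch` are illustrative names, not a JAX API.

```python
import jax
import jax.numpy as jnp
from jax import lax


def run_packed(packed_input_data, batch_sizes, max_batch):
    """Toy RNN over a PyTorch-style packed sequence (sketch).

    packed_input_data: (n, m) rows for all time steps, concatenated.
    batch_sizes: (T,) int array, non-increasing, summing to n.
    max_batch: Python int >= max(batch_sizes); fixes every slice shape.
    """
    n, m = packed_input_data.shape
    # Pad inputs and outputs so fixed-size windows never run off the end.
    x_pad = jnp.pad(packed_input_data, ((0, max_batch), (0, 0)))
    out_pad = jnp.zeros((n + max_batch, m), packed_input_data.dtype)
    state = jnp.zeros((max_batch, m), packed_input_data.dtype)
    rows = jnp.arange(max_batch)

    def body_fun(i, val):
        out_pad, state, offset = val
        batch_size = batch_sizes[i]
        keep = (rows < batch_size)[:, None]  # active sequences this step
        x = lax.dynamic_slice_in_dim(x_pad, offset, max_batch)
        # Hypothetical cell; replace with a real RNN update.
        new_state = jnp.where(keep, jnp.tanh(x + state), state)
        # Merge into the existing output window so inactive rows are untouched.
        old = lax.dynamic_slice_in_dim(out_pad, offset, max_batch)
        out_pad = lax.dynamic_update_slice_in_dim(
            out_pad, jnp.where(keep, new_state, old), offset, axis=0)
        return out_pad, new_state, offset + batch_size

    init = (out_pad, state, jnp.zeros((), batch_sizes.dtype))
    out_pad, state, _ = lax.fori_loop(0, batch_sizes.shape[0], body_fun, init)
    return out_pad[:n], state


# Three sequences of lengths 3, 2 and 1 give batch_sizes [3, 2, 1].
out, h = jax.jit(run_packed, static_argnums=2)(
    jnp.ones((6, 2)), jnp.array([3, 2, 1]), 3)
```

Reading the old output window and merging with `jnp.where` before `dynamic_update_slice_in_dim` keeps each write correct regardless of iteration order; the cost is that every step does `max_batch` rows of work even when fewer sequences are active.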
I’ve read through https://github.com/google/jax/issues/2308, which seems similar, but I still cannot get this to work.
No matter which primitive I use to translate the second slice operation below, it fails with an exception that seems to imply `batch_size` must be a concrete value but an abstract value was supplied.
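Another option, if `batch_sizes` can be made a tuple of Python ints before tracing, is to mark it static and use an ordinary Python loop, which matches the observation that calling `body_fun` from a plain `for` loop works. Every slice then has a concrete size. The trade-off is that the loop is unrolled at trace time and `jit` recompiles for every distinct `batch_sizes` tuple. `split_packed` is a hypothetical name used only for this sketch.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=1)
def split_packed(packed_input_data, batch_sizes):
    # batch_sizes is a hashable tuple of Python ints marked static, so each
    # slice below has a concrete size; jit retraces for each new tuple.
    batches = []
    offset = 0
    for batch_size in batch_sizes:  # unrolled at trace time
        batches.append(packed_input_data[offset:offset + batch_size])
        offset += batch_size
    return batches


data = jnp.arange(12.0).reshape(6, 2)
batches = split_packed(data, (3, 2, 1))
```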
Yes, thanks!