Sub-batch of a batch
🚀 The feature, motivation and pitch
I would like to be able to index a batch to obtain a sub-batch, rather than a list of data objects. E.g. if I have a Batch object my_batch, then something like my_batch.subbatch(np.arange(5)) would return another Batch object containing the first 5 graphs in my_batch. Perhaps this can be achieved by calling subgraph as a subroutine.
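To make the request concrete, here is a minimal pure-Python sketch of what building a sub-batch directly could look like. It does not use PyTorch Geometric: MiniBatch and subbatch are hypothetical names, and the layout (node features of all graphs concatenated, edges stored with batch-global node ids, and pointer lists marking where each graph's nodes and edges start) is an assumption meant to resemble a collated batch. The key step is shifting each selected graph's edge endpoints from its old node offset to its new one.

```python
from dataclasses import dataclass
from typing import List, Sequence, Tuple


@dataclass
class MiniBatch:
    """Hypothetical stand-in for a collated batch (not the PyG Batch class)."""
    x: List[float]                # one feature per node, all graphs concatenated
    edges: List[Tuple[int, int]]  # (src, dst) pairs using batch-global node ids
    node_ptr: List[int]           # graph g owns nodes node_ptr[g]:node_ptr[g + 1]
    edge_ptr: List[int]           # graph g owns edges edge_ptr[g]:edge_ptr[g + 1]

    @property
    def num_graphs(self) -> int:
        return len(self.node_ptr) - 1


def subbatch(batch: MiniBatch, idx: Sequence[int]) -> MiniBatch:
    """Collect graphs `idx` (in that order) into a new MiniBatch directly,
    without first splitting the batch into per-graph objects."""
    x, edges, node_ptr, edge_ptr = [], [], [0], [0]
    for g in idx:
        n0, n1 = batch.node_ptr[g], batch.node_ptr[g + 1]
        e0, e1 = batch.edge_ptr[g], batch.edge_ptr[g + 1]
        # Move edge endpoints from the graph's old node offset (n0)
        # to the offset it will have in the new batch.
        shift = node_ptr[-1] - n0
        x.extend(batch.x[n0:n1])
        edges.extend((s + shift, d + shift) for s, d in batch.edges[e0:e1])
        node_ptr.append(node_ptr[-1] + (n1 - n0))
        edge_ptr.append(edge_ptr[-1] + (e1 - e0))
    return MiniBatch(x, edges, node_ptr, edge_ptr)


# Three toy graphs with 2, 3 and 2 nodes.
batch = MiniBatch(
    x=[0.0, 0.1, 1.0, 1.1, 1.2, 2.0, 2.1],
    edges=[(0, 1), (2, 3), (3, 4), (6, 5)],
    node_ptr=[0, 2, 5, 7],
    edge_ptr=[0, 1, 3, 4],
)
sub = subbatch(batch, [2, 1])  # graph 2 first, then graph 1
```

Here sub holds graph 2 followed by graph 1, with graph 2's edge (6, 5) renumbered to (1, 0). The work done is proportional to the size of the selected graphs.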
Alternatives
Of course, Batch.from_data_list(my_batch[:5]) would achieve this functionality, but it seems slow, since it first splits the batch into a list of individual data objects and then collates them again. I would like to construct a sub-batch directly, if possible.
Additional context
In my particular application, I have a list of Data objects, and I need to include a non-decreasing number of them in a Batch every time I pass them into my GNN. I could pass Batch.from_data_list(my_list[:i]) into my GNN at each iteration i, but I presume it would be faster to construct a batch of the whole list once at the start, complete_batch, and then pass the sub-batch complete_batch.subbatch(np.arange(i)) at each iteration i.
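For this prefix use case the problem may be even simpler. The sketch below is a hypothetical pure-Python illustration, not PyG code, and assumes the collated batch stores graphs back to back in their original order, with edges given as batch-global node ids plus pointer lists for node and edge offsets. Because the first i graphs start at node offset 0, their edge endpoints are already correct, so every field is a single contiguous slice and no re-indexing is needed. With PyTorch tensors, the analogous basic slices (for example x[:n] or edge_index[:, :e]) would be views rather than copies.

```python
from typing import List, Tuple

Edge = Tuple[int, int]


def prefix_subbatch(
    x: List[float],
    edges: List[Edge],
    node_ptr: List[int],
    edge_ptr: List[int],
    i: int,
) -> Tuple[List[float], List[Edge], List[int], List[int]]:
    """Return the first `i` graphs of a collated batch.

    A prefix starts at node offset 0, so edge endpoints need no shifting:
    each field is one contiguous slice ending at the i-th pointer.
    """
    n, e = node_ptr[i], edge_ptr[i]
    return x[:n], edges[:e], node_ptr[: i + 1], edge_ptr[: i + 1]


# Three toy graphs (2, 3 and 2 nodes), collated once up front.
x = [0.0, 0.1, 1.0, 1.1, 1.2, 2.0, 2.1]
edges = [(0, 1), (2, 3), (3, 4), (6, 5)]
node_ptr = [0, 2, 5, 7]
edge_ptr = [0, 1, 3, 4]

# A growing prefix per iteration, as in the training loop described above.
prefixes = [prefix_subbatch(x, edges, node_ptr, edge_ptr, i) for i in range(1, 4)]
```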
Good point, I have now added this.
Could we reuse the logic in the index_select method to select the elements passed to the sub-batch function? The PR above currently requires the input to be a boolean tensor (Tensor.bool), but I think it could be generalized to accept the same inputs that index_select uses to return a list of elements, as done in the linked function.