add support for groups in train_test_split
`train_test_split` has support for the options ‘stratify’ and ‘shuffle’ but not ‘groups’.
I’m interested in adding support for ‘groups’ to `train_test_split` (so that all samples from the same group end up in either train or test, but not both). In 0.18.1 there is support for the option ‘stratify’, and in master support for ‘shuffle’ was recently added. If others think it would be useful, I’d like to make a PR.
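A minimal sketch of the grouping guarantee described here, in plain Python with no scikit-learn dependency. `group_train_test_split_indices` is a hypothetical helper, not part of any library: it shuffles the distinct group labels, assigns a fraction of them to the test side, and then places every sample on the side its group landed on. Mirroring how scikit-learn documents a float `test_size` for `GroupShuffleSplit`, `test_size` here is a proportion of groups (rounded up), not of samples.

```python
import math
import random
from typing import Hashable, List, Sequence, Tuple


def group_train_test_split_indices(
    groups: Sequence[Hashable], test_size: float = 0.25, seed: int = 0
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: split sample indices so that every group lands wholly
    in train or wholly in test. test_size is a fraction of groups, rounded up."""
    if not 0.0 < test_size < 1.0:
        raise ValueError("test_size must be strictly between 0 and 1")
    # Sort the distinct labels first so the shuffle is reproducible for a given seed.
    unique_groups = sorted(set(groups), key=repr)
    if len(unique_groups) < 2:
        raise ValueError("need at least two distinct groups to split")
    random.Random(seed).shuffle(unique_groups)
    # Round up, but always leave at least one group for training.
    n_test = min(math.ceil(test_size * len(unique_groups)), len(unique_groups) - 1)
    test_groups = set(unique_groups[:n_test])
    # Each sample follows its group, so no group can appear on both sides.
    train_idx = [i for i, g in enumerate(groups) if g not in test_groups]
    test_idx = [i for i, g in enumerate(groups) if g in test_groups]
    return train_idx, test_idx


groups = ["a", "a", "b", "b", "c", "c", "d", "d"]
train, test = group_train_test_split_indices(groups, test_size=0.25, seed=0)
```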
The rules start to get a little complicated with the options ‘stratify’, ‘shuffle’, and ‘groups’ interacting. It makes sense to throw an error when `groups` is not None and `shuffle` is False (there is similar logic for `stratify` and `shuffle` in master). It also makes sense to raise a `ValueError` if `groups` is not None and `stratify` is not None, since there is a `GroupShuffleSplit` class and a `StratifiedShuffleSplit` class but no `StratifiedGroupShuffleSplit`.
The rules look something like this:
| stratify | shuffle | groups | behavior |
|---|---|---|---|
| None | True | None | use ShuffleSplit |
| None | False | None | no shuffling, just split on n_train |
| not None | True | None | use StratifiedShuffleSplit |
| not None | False | None | raise ValueError |
| None | True | not None | use GroupShuffleSplit (proposed) |
| None | False | not None | raise ValueError (proposed) |
| not None | True | not None | raise ValueError (proposed) |
| not None | False | not None | raise ValueError (proposed) |
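The table reads as a small decision procedure. The sketch below encodes it row by row; `choose_split_strategy` is a hypothetical name, and it returns the splitter's class name as a string rather than constructing a scikit-learn object, so it only documents the proposed dispatch and is not how `train_test_split` is actually implemented.

```python
from typing import Any, Optional


def choose_split_strategy(stratify: Optional[Any], shuffle: bool, groups: Optional[Any]) -> str:
    """Hypothetical helper: name the splitter the proposal would use for each
    row of the table, or raise ValueError where the table says so."""
    if groups is not None:
        if stratify is not None:
            # No combined stratified + grouped shuffle splitter exists.
            raise ValueError("stratify and groups cannot both be set")
        if not shuffle:
            raise ValueError("groups requires shuffle=True")
        return "GroupShuffleSplit"  # proposed
    if stratify is not None:
        if not shuffle:
            raise ValueError("stratify requires shuffle=True")
        return "StratifiedShuffleSplit"
    # Neither stratify nor groups: shuffle decides.
    return "ShuffleSplit" if shuffle else "no shuffling, just split on n_train"
```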
Another possibility is for `train_test_split` to be passed a cross-validator class explicitly (rather than inferring which one to use), but that might put more burden on the caller, considering this is a convenience function.
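A sketch of the explicit-cross-validator alternative, assuming the helper simply receives a splitter object and takes its first split. `split_once` is hypothetical; it relies only on the scikit-learn splitter convention of a `split(X, y=None, groups=None)` method that yields `(train_indices, test_indices)` pairs, so the block itself needs nothing beyond `typing`.

```python
from typing import Any, Optional, Sequence, Tuple


def split_once(
    cv: Any,
    X: Sequence[Any],
    y: Optional[Sequence[Any]] = None,
    groups: Optional[Sequence[Any]] = None,
) -> Tuple[Any, Any]:
    """Hypothetical helper: return the first (train, test) index pair produced by
    any splitter object that follows the split(X, y, groups) convention."""
    # split() returns a generator; next() pulls only the first pair.
    return next(iter(cv.split(X, y, groups)))
```

With scikit-learn installed, a grouped split would then look like `split_once(GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=0), X, groups=groups)`, which illustrates the trade-off mentioned above: the caller now has to pick and configure the splitter.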
If this is easier to discuss in the form of a PR, I’d be happy to submit one. And if I’m missing a simpler solution to this, I’d be happy to learn about it.
Thanks, Dennis
Issue Analytics
- Created 6 years ago
- Reactions: 27
- Comments: 40 (15 by maintainers)
Top GitHub Comments
Hm, I’m not sure we want to go overboard with this helper. You can do
In case anyone else is using Python 3 and got

`AttributeError: 'generator' object has no attribute 'next'`

from the code `train_inds, test_inds = GroupShuffleSplit().split(X, groups=groups).next()`, the following works in Python 3:

```python
from sklearn.model_selection import GroupShuffleSplit

train_inds, test_inds = next(GroupShuffleSplit().split(X, groups=groups))
```
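The error comes from the generator protocol: in Python 3, generators expose `__next__` rather than a `next` method, while the builtin `next()` (available since Python 2.6) works in both. A minimal stand-in generator, hypothetical and not part of scikit-learn, shows the difference without needing scikit-learn installed:

```python
def index_splits():
    """Hypothetical stand-in for a splitter's split(): yields (train, test) index lists."""
    yield [0, 1, 2], [3]
    yield [1, 2, 3], [0]


gen = index_splits()
# Python 3 generators implement __next__, not next, so gen.next() raises AttributeError.
print(hasattr(gen, "next"))      # False
print(hasattr(gen, "__next__"))  # True
# The builtin next() calls __next__ and returns the first (train, test) pair.
train_inds, test_inds = next(gen)
```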