Add stratified split for shuffle=False in train_test_split
Describe the workflow you want to enable
When splitting time series data, the data is often split without shuffling. However, train_test_split currently supports stratified splitting only with shuffle=True. It would be helpful to support the stratify option with shuffle=False as well.
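As a small illustration of the motivation, the sketch below uses made-up, time-ordered labels (the data and the split size are assumptions, not taken from the issue) to show how a plain unshuffled cut can leave train and test with quite different class proportions:

```python
from collections import Counter

# Hypothetical time-ordered labels: class "b" becomes more frequent later on.
labels = ["a"] * 6 + ["b"] * 2 + ["a"] * 2 + ["b"] * 2

n_train = 9
# An unshuffled, unstratified split simply cuts the sequence at n_train.
train_labels = labels[:n_train]
test_labels = labels[n_train:]

train_counts = Counter(train_labels)
test_counts = Counter(test_labels)
print("train:", dict(train_counts))
print("test: ", dict(test_counts))
```

Here the full data is two thirds "a", but the test part ends up mostly "b". A stratified split with shuffle=False would aim to keep the class ratio in both parts.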
Describe your proposed solution
Add a shuffle option to StratifiedShuffleSplit, and permute indices only when shuffle is True. Then use StratifiedShuffleSplit with this shuffle option in train_test_split.
For example:
def train_test_split(*arrays, **options):
    ...
    if shuffle is False:
        if stratify is not None:
            cv = StratifiedShuffleSplit(test_size=n_test, train_size=n_train,
                                        random_state=random_state, shuffle=False)
            train, test = next(cv.split(X=arrays[0], y=stratify))
        else:
            train = np.arange(n_train)
            test = np.arange(n_train, n_train + n_test)
    ...

class StratifiedShuffleSplit(BaseShuffleSplit):
    def __init__(self, n_splits=10, test_size=None, train_size=None,
                 random_state=None, shuffle=True):
        super().__init__(
            n_splits=n_splits,
            test_size=test_size,
            train_size=train_size,
            random_state=random_state)
        self._default_test_size = 0.1
        self.shuffle = shuffle

    def _iter_indices(self, X, y, groups=None):
        n_samples = _num_samples(X)
        y = check_array(y, ensure_2d=False, dtype=None)
        ...
        rng = check_random_state(self.random_state)

        for _ in range(self.n_splits):
            # if there are ties in the class-counts, we want
            # to make sure to break them anew in each iteration
            n_i = _approximate_mode(class_counts, n_train, rng)
            class_counts_remaining = class_counts - n_i
            t_i = _approximate_mode(class_counts_remaining, n_test, rng)

            train = []
            test = []

            for i in range(n_classes):
                if self.shuffle:
                    permutation = rng.permutation(class_counts[i])
                    perm_indices_class_i = class_indices[i].take(permutation,
                                                                 mode='clip')
                else:
                    perm_indices_class_i = class_indices[i]

                train.extend(perm_indices_class_i[:n_i[i]])
                test.extend(perm_indices_class_i[n_i[i]:n_i[i] + t_i[i]])

            if self.shuffle:
                train = rng.permutation(train)
                test = rng.permutation(test)

            yield train, test
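The per-class slicing in the shuffle=False branch can be tried out without scikit-learn. The following is a minimal, standard-library sketch under stated assumptions: allocate_per_class and stratified_ordered_split are hypothetical names, the largest-remainder rounding is only a stand-in for scikit-learn's private _approximate_mode helper, and the final sort is an addition of this sketch, not part of the proposal.

```python
from collections import defaultdict


def allocate_per_class(class_counts, n_total):
    """Split n_total across classes proportionally (largest-remainder rounding).

    Assumption: a simple stand-in for _approximate_mode, not the library's algorithm.
    """
    total = sum(class_counts.values())
    if total == 0:
        return {c: 0 for c in class_counts}
    exact = {c: n_total * k / total for c, k in class_counts.items()}
    alloc = {c: int(v) for c, v in exact.items()}
    leftover = n_total - sum(alloc.values())
    # Hand the remaining slots to the classes with the largest fractional parts.
    by_remainder = sorted(exact, key=lambda c: exact[c] - alloc[c], reverse=True)
    for c in by_remainder[:leftover]:
        alloc[c] += 1
    return alloc


def stratified_ordered_split(labels, n_train, n_test):
    """Return (train, test) index lists without shuffling.

    Within each class, the earliest samples go to train and the following
    ones go to test, as in the shuffle=False branch.
    """
    class_indices = defaultdict(list)
    for idx, label in enumerate(labels):
        class_indices[label].append(idx)

    counts = {c: len(ix) for c, ix in class_indices.items()}
    n_i = allocate_per_class(counts, n_train)
    remaining = {c: counts[c] - n_i[c] for c in counts}
    t_i = allocate_per_class(remaining, n_test)

    train, test = [], []
    for c, ix in class_indices.items():
        train.extend(ix[:n_i[c]])
        test.extend(ix[n_i[c]:n_i[c] + t_i[c]])
    # Sorting restores the original order of the samples (sketch-only addition).
    return sorted(train), sorted(test)


# Hypothetical time-ordered labels, 8 of class "a" and 4 of class "b".
labels = ["a"] * 6 + ["b"] * 2 + ["a"] * 2 + ["b"] * 2
train, test = stratified_ordered_split(labels, n_train=9, n_test=3)
print(train)  # [0, 1, 2, 3, 4, 5, 6, 7, 10]
print(test)   # [8, 9, 11]
```

Both parts keep the 2:1 class ratio of the full data. In this example index 10 lands in train while 8 and 9 land in test, so the cut is chronological within each class rather than across the whole sequence.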
Describe alternatives you’ve considered, if relevant
An alternative is to add the split logic directly in train_test_split (as is done for the case where shuffle=False and stratify is None), or to add a new method. But that new code would be very similar to StratifiedShuffleSplit, so modifying StratifiedShuffleSplit would be better.
Issue Analytics
- State:
- Created 3 years ago
- Reactions: 9
- Comments:6 (2 by maintainers)
Top GitHub Comments
No, I haven’t. But before posting this issue, I confirmed that the code works. I can create a pull request.
This feature will be very useful.