Stuck on an issue?

Lightrun Answers was designed to reduce the constant googling that comes with debugging 3rd party libraries. It collects links to all the places you might be looking at while hunting down a tough bug.

And, if you’re still stuck at the end, we’re happy to hop on a call to see how we can help out.

add stratified split for "shuffle=False" in train_test_split

See original GitHub issue

Describe the workflow you want to enable

When splitting time series data, the data is often split without shuffling. Currently, train_test_split supports stratified splitting only when shuffle=True. It would be helpful to support the stratify option with shuffle=False as well.
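The limitation described above can be reproduced with a short snippet (this assumes scikit-learn is installed; the exact error message may vary between versions):

```python
import numpy as np
from sklearn.model_selection import train_test_split

X = np.arange(20).reshape(10, 2)
y = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

# Passing stratify together with shuffle=False currently raises a ValueError.
try:
    train_test_split(X, y, test_size=0.2, shuffle=False, stratify=y)
except ValueError as exc:
    print("raises:", exc)
```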

Describe your proposed solution

Add a shuffle option to StratifiedShuffleSplit, and permute indices only when shuffle is True. Then use StratifiedShuffleSplit with that option inside train_test_split.
For example:

def train_test_split(*arrays, **options):
    ...
    if shuffle is False:
        if stratify is not None:
            cv = StratifiedShuffleSplit(test_size=n_test, train_size=n_train,
                                        random_state=random_state,
                                        shuffle=False)
            train, test = next(cv.split(X=arrays[0], y=stratify))
        else:
            train = np.arange(n_train)
            test = np.arange(n_train, n_train + n_test)
    ...

class StratifiedShuffleSplit(BaseShuffleSplit):
    def __init__(self, n_splits=10, test_size=None, train_size=None,
                 random_state=None, shuffle=True):
        super().__init__(
            n_splits=n_splits,
            test_size=test_size,
            train_size=train_size,
            random_state=random_state)
        self._default_test_size = 0.1
        self.shuffle = shuffle

    def _iter_indices(self, X, y, groups=None):
        n_samples = _num_samples(X)
        y = check_array(y, ensure_2d=False, dtype=None)
        
        ...

        rng = check_random_state(self.random_state)

        for _ in range(self.n_splits):
            # if there are ties in the class-counts, we want
            # to make sure to break them anew in each iteration
            n_i = _approximate_mode(class_counts, n_train, rng)
            class_counts_remaining = class_counts - n_i
            t_i = _approximate_mode(class_counts_remaining, n_test, rng)

            train = []
            test = []

            for i in range(n_classes):
                if self.shuffle:
                    permutation = rng.permutation(class_counts[i])
                    perm_indices_class_i = class_indices[i].take(permutation,
                                                                 mode='clip')
                else:
                    perm_indices_class_i = class_indices[i]

                train.extend(perm_indices_class_i[:n_i[i]])
                test.extend(perm_indices_class_i[n_i[i]:n_i[i] + t_i[i]])

            if self.shuffle:
                train = rng.permutation(train)
                test = rng.permutation(test)

            yield train, test

Describe alternatives you’ve considered, if relevant

An alternative idea is to add the split logic directly to train_test_split (as is already done for the case shuffle=False and stratify is None), or to add a new method. But that new code would be very similar to StratifiedShuffleSplit, so modifying StratifiedShuffleSplit seems better.
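Until such an option exists, a similar non-shuffled stratified split can be sketched in plain NumPy: split each class's indices in their original order, then re-sort so time ordering is preserved within train and test. The helper name stratified_split_no_shuffle is illustrative, not part of scikit-learn:

```python
import numpy as np

def stratified_split_no_shuffle(y, test_size=0.25):
    """Per-class contiguous split: earliest samples of each class go to
    train, latest to test, preserving the original ordering."""
    y = np.asarray(y)
    train_idx, test_idx = [], []
    for cls in np.unique(y):
        idx = np.flatnonzero(y == cls)      # indices already in time order
        n_test = int(round(len(idx) * test_size))
        train_idx.extend(idx[: len(idx) - n_test])
        test_idx.extend(idx[len(idx) - n_test :])
    # Re-sort so both subsets keep the global time ordering.
    return np.sort(train_idx), np.sort(test_idx)

y = np.array([0, 0, 1, 0, 1, 1, 0, 1, 0, 1])
train, test = stratified_split_no_shuffle(y, test_size=0.4)
print(train, test)  # -> [0 1 2 3 4 5] [6 7 8 9]
```

Both subsets keep the per-class proportions (here 50/50) while no index is ever permuted, which is the behavior the proposed shuffle=False path would provide.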

Issue Analytics

  • State: open
  • Created: 3 years ago
  • Reactions: 9
  • Comments: 6 (2 by maintainers)

Top GitHub Comments

1 reaction
Yukimura66 commented, Sep 5, 2021

No, I haven’t. But before posting this issue, I confirmed that the code works. I can create a pull request.

1 reaction
joanlofe commented, Oct 6, 2020

This feature will be very useful.


Top Results From Across the Web

  • stratify data without train_test_split shuffle - Stack Overflow
    sklearn train_test_split stratify works only when the setting is shuffle=True. [See documentation: If shuffle=False then stratify must be None.]
  • What is the role of 'shuffle' in train_test_split()? - Cross Validated
    With shuffle=True you split the data randomly. For example, say that you have balanced binary classification data and it is ordered by labels...
  • sklearn.model_selection.StratifiedShuffleSplit
    Provides train/test indices to split data in train/test sets. This cross-validation object is a merge of StratifiedKFold and ShuffleSplit, which returns ...
  • How to Use Sklearn train_test_split in Python - Sharp Sight
    The Sklearn train_test_split function splits a dataset into training ... If you set shuffle = False, then you must set stratify =...
  • Split Your Dataset With scikit-learn's train_test_split()
    In this tutorial, you'll learn why it's important to split your dataset ... Finally, you can turn off data shuffling and random split...
