
Improve tests to make them run on variously typed data using the `global_dtype` fixture


Context: the new global_dtype fixture and SKLEARN_RUN_FLOAT32_TESTS environment variable

The introduction of low-level computational routines for 32-bit floating-point (float32) data motivated extending the tests so that they also run on 32-bit data.

In this regard, https://github.com/scikit-learn/scikit-learn/pull/22690 introduced a new global_dtype fixture as well as the SKLEARN_RUN_FLOAT32_TESTS environment variable, which together make it possible to run the tests on 32-bit data.

Tests are run on 32-bit data by setting SKLEARN_RUN_FLOAT32_TESTS=1.

For instance, this runs the first global_dtype-parametrised test:

SKLEARN_RUN_FLOAT32_TESTS=1 pytest sklearn/feature_selection/tests/test_mutual_info.py -k test_compute_mi_cc

This makes it possible to run the tests on 32-bit data in a CI job; currently, a single CI job does so.

More details about the fixture are available in the online developer documentation for the SKLEARN_RUN_FLOAT32_TESTS environment variable:

https://scikit-learn.org/dev/computing/parallelism.html#environment-variables
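The following is a minimal sketch of how a fixture like global_dtype can be gated by SKLEARN_RUN_FLOAT32_TESTS. It assumes the fixture lives in a conftest.py and that only the value 1 enables float32 runs. The run_float32_tests helper is hypothetical, and this is not necessarily how scikit-learn's own conftest.py implements the fixture:

```python
import os

import numpy as np
import pytest


def run_float32_tests(environ=os.environ):
    """Hypothetical helper: True when float32 variants should run.

    Assumption: only the value "1" enables them, as in SKLEARN_RUN_FLOAT32_TESTS=1.
    """
    return environ.get("SKLEARN_RUN_FLOAT32_TESTS", "0") == "1"


# float32 variants are still collected, but skipped unless the variable is set.
_SKIP_FLOAT32 = pytest.mark.skipif(
    not run_float32_tests(),
    reason="Set SKLEARN_RUN_FLOAT32_TESTS=1 to run float32 tests",
)


@pytest.fixture(params=[np.float64, pytest.param(np.float32, marks=_SKIP_FLOAT32)])
def global_dtype(request):
    # Each test requesting this fixture runs once per dtype in params.
    return request.param
```

Any test that declares a global_dtype argument is then collected once per dtype. The float32 run is reported as skipped unless SKLEARN_RUN_FLOAT32_TESTS=1 is set. Using a skip mark, rather than dropping the parameter, keeps the skipped variants visible in the test report.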

Guidelines to convert existing tests

  • Not all scikit-learn tests must use this fixture. Only tests that actually assert the closeness of results using assert_allclose should be parametrised. For instance, tests that check the exception messages raised when passing invalid inputs must not be converted.

  • Tests using np.testing.assert_allclose must now use sklearn.utils._testing.assert_allclose as a drop-in replacement.

  • Check that fitted attributes and return values whose dtype depends on the dtype of the input data structure actually have the expected dtype. Typically, when all inputs are continuous values in float32, scikit-learn should often (but not always) carry out all arithmetic operations at that precision level and return output arrays with the same precision level. There can be exceptions. Make them explicit with an inline comment in the test, possibly with a TODO marker when one thinks that the current behavior should change (see, for instance, the related issues #11000 and #22682).

  • To avoid having to review huge PRs that touch many files at once and can lead to conflicts, let’s open PRs that edit at most one test file at a time. For instance, use a title such as:

TST use global_dtype in sklearn/_loss/tests/test_glm_distribution.py

  • Please reference #22881 (i.e. this issue) in the description of the PR and put the full filename of the test file you edit in the title of the PR.

  • To convert an existing test with a fixed dtype, the general pattern is to rewrite a function such as:

import numpy as np
from numpy.testing import assert_allclose

def test_some_function():
    # ...
    rng = np.random.RandomState(0)
    X = rng.rand(n_samples, n_features)
    y = rng.rand(n_samples)
    model.fit(X, y)
    # ...
    y_pred = model.predict(X)
    assert_allclose(y_pred, y_true)

to:

import numpy as np
from sklearn.utils._testing import assert_allclose

def test_some_function(global_dtype):
    # ...
    rng = np.random.RandomState(0)
    X = rng.rand(n_samples, n_features).astype(global_dtype, copy=False)
    y = rng.rand(n_samples).astype(global_dtype, copy=False)
    model.fit(X, y)
    # ...
    # Fitted attributes and predictions should keep the input precision.
    assert model.fitted_param_.dtype == global_dtype
    y_pred = model.predict(X)
    assert y_pred.dtype == global_dtype
    assert_allclose(y_pred, y_true)

Then check that the test passes on 32-bit data:

SKLEARN_RUN_FLOAT32_TESTS=1 pytest sklearn/some_module/tests/test_some_module.py -k test_some_function

Failures are to be handled on a case-by-case basis.

List of test modules to upgrade

Test modules found with: find sklearn -name "test_*.py"
  • sklearn/_loss/tests/test_glm_distribution.py
  • sklearn/_loss/tests/test_link.py
  • sklearn/_loss/tests/test_loss.py
  • sklearn/cluster/tests/test_affinity_propagation.py #22667
  • sklearn/cluster/tests/test_bicluster.py
  • sklearn/cluster/tests/test_birch.py #22671
  • sklearn/cluster/tests/test_dbscan.py
  • sklearn/cluster/tests/test_feature_agglomeration.py
  • sklearn/cluster/tests/test_hierarchical.py
  • sklearn/cluster/tests/test_k_means.py
  • sklearn/cluster/tests/test_mean_shift.py #22672
  • sklearn/cluster/tests/test_optics.py #22665
  • sklearn/cluster/tests/test_spectral.py #22669
  • sklearn/compose/tests/test_column_transformer.py
  • sklearn/compose/tests/test_target.py
  • sklearn/covariance/tests/test_covariance.py
  • sklearn/covariance/tests/test_elliptic_envelope.py
  • sklearn/covariance/tests/test_graphical_lasso.py
  • sklearn/covariance/tests/test_robust_covariance.py
  • sklearn/cross_decomposition/tests/test_pls.py
  • sklearn/datasets/tests/test_20news.py
  • sklearn/datasets/tests/test_base.py
  • sklearn/datasets/tests/test_california_housing.py
  • sklearn/datasets/tests/test_common.py
  • sklearn/datasets/tests/test_covtype.py
  • sklearn/datasets/tests/test_kddcup99.py
  • sklearn/datasets/tests/test_lfw.py
  • sklearn/datasets/tests/test_olivetti_faces.py
  • sklearn/datasets/tests/test_openml.py
  • sklearn/datasets/tests/test_rcv1.py
  • sklearn/datasets/tests/test_samples_generator.py
  • sklearn/datasets/tests/test_svmlight_format.py
  • sklearn/decomposition/tests/test_dict_learning.py
  • sklearn/decomposition/tests/test_factor_analysis.py
  • sklearn/decomposition/tests/test_fastica.py
  • sklearn/decomposition/tests/test_incremental_pca.py
  • sklearn/decomposition/tests/test_kernel_pca.py
  • sklearn/decomposition/tests/test_nmf.py
  • sklearn/decomposition/tests/test_online_lda.py
  • sklearn/decomposition/tests/test_pca.py
  • sklearn/decomposition/tests/test_sparse_pca.py
  • sklearn/decomposition/tests/test_truncated_svd.py
  • sklearn/ensemble/_hist_gradient_boosting/tests/test_binning.py
  • sklearn/ensemble/_hist_gradient_boosting/tests/test_bitset.py
  • sklearn/ensemble/_hist_gradient_boosting/tests/test_compare_lightgbm.py
  • sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py
  • sklearn/ensemble/_hist_gradient_boosting/tests/test_grower.py
  • sklearn/ensemble/_hist_gradient_boosting/tests/test_histogram.py
  • sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py
  • sklearn/ensemble/_hist_gradient_boosting/tests/test_predictor.py
  • sklearn/ensemble/_hist_gradient_boosting/tests/test_splitting.py
  • sklearn/ensemble/_hist_gradient_boosting/tests/test_warm_start.py
  • sklearn/ensemble/tests/test_bagging.py
  • sklearn/ensemble/tests/test_base.py
  • sklearn/ensemble/tests/test_common.py
  • sklearn/ensemble/tests/test_forest.py
  • sklearn/ensemble/tests/test_gradient_boosting_loss_functions.py
  • sklearn/ensemble/tests/test_gradient_boosting.py
  • sklearn/ensemble/tests/test_iforest.py
  • sklearn/ensemble/tests/test_stacking.py
  • sklearn/ensemble/tests/test_voting.py
  • sklearn/ensemble/tests/test_weight_boosting.py
  • sklearn/experimental/tests/test_enable_hist_gradient_boosting.py
  • sklearn/experimental/tests/test_enable_iterative_imputer.py
  • sklearn/experimental/tests/test_enable_successive_halving.py
  • sklearn/feature_extraction/tests/test_dict_vectorizer.py
  • sklearn/feature_extraction/tests/test_feature_hasher.py
  • sklearn/feature_extraction/tests/test_image.py
  • sklearn/feature_extraction/tests/test_text.py
  • sklearn/feature_selection/tests/test_base.py
  • sklearn/feature_selection/tests/test_chi2.py
  • sklearn/feature_selection/tests/test_feature_select.py
  • sklearn/feature_selection/tests/test_from_model.py
  • sklearn/feature_selection/tests/test_mutual_info.py #22677
  • sklearn/feature_selection/tests/test_rfe.py
  • sklearn/feature_selection/tests/test_sequential.py
  • sklearn/feature_selection/tests/test_variance_threshold.py
  • sklearn/gaussian_process/tests/test_gpc.py
  • sklearn/gaussian_process/tests/test_gpr.py
  • sklearn/gaussian_process/tests/test_kernels.py
  • sklearn/impute/tests/test_base.py
  • sklearn/impute/tests/test_common.py
  • sklearn/impute/tests/test_impute.py
  • sklearn/impute/tests/test_knn.py
  • sklearn/inspection/_plot/tests/test_plot_partial_dependence.py
  • sklearn/inspection/tests/test_partial_dependence.py
  • sklearn/inspection/tests/test_permutation_importance.py
  • sklearn/linear_model/_glm/tests/test_glm.py
  • sklearn/linear_model/_glm/tests/test_link.py
  • sklearn/linear_model/tests/test_base.py
  • sklearn/linear_model/tests/test_bayes.py
  • sklearn/linear_model/tests/test_common.py
  • sklearn/linear_model/tests/test_coordinate_descent.py
  • sklearn/linear_model/tests/test_huber.py
  • sklearn/linear_model/tests/test_least_angle.py
  • sklearn/linear_model/tests/test_linear_loss.py
  • sklearn/linear_model/tests/test_logistic.py
  • sklearn/linear_model/tests/test_omp.py
  • sklearn/linear_model/tests/test_passive_aggressive.py
  • sklearn/linear_model/tests/test_perceptron.py
  • sklearn/linear_model/tests/test_quantile.py
  • sklearn/linear_model/tests/test_ransac.py
  • sklearn/linear_model/tests/test_ridge.py
  • sklearn/linear_model/tests/test_sag.py
  • sklearn/linear_model/tests/test_sgd.py
  • sklearn/linear_model/tests/test_sparse_coordinate_descent.py
  • sklearn/linear_model/tests/test_theil_sen.py
  • sklearn/manifold/tests/test_isomap.py #22673
  • sklearn/manifold/tests/test_locally_linear.py #22676
  • sklearn/manifold/tests/test_mds.py
  • sklearn/manifold/tests/test_spectral_embedding.py
  • sklearn/manifold/tests/test_t_sne.py #22675
  • sklearn/metrics/_plot/tests/test_base.py
  • sklearn/metrics/_plot/tests/test_common_curve_display.py
  • sklearn/metrics/_plot/tests/test_confusion_matrix_display.py
  • sklearn/metrics/_plot/tests/test_det_curve_display.py
  • sklearn/metrics/_plot/tests/test_plot_confusion_matrix.py
  • sklearn/metrics/_plot/tests/test_plot_curve_common.py
  • sklearn/metrics/_plot/tests/test_plot_det_curve.py
  • sklearn/metrics/_plot/tests/test_plot_precision_recall.py
  • sklearn/metrics/_plot/tests/test_plot_roc_curve.py
  • sklearn/metrics/_plot/tests/test_precision_recall_display.py
  • sklearn/metrics/_plot/tests/test_roc_curve_display.py
  • sklearn/metrics/cluster/tests/test_bicluster.py
  • sklearn/metrics/cluster/tests/test_common.py
  • sklearn/metrics/cluster/tests/test_supervised.py
  • sklearn/metrics/cluster/tests/test_unsupervised.py
  • sklearn/metrics/tests/test_classification.py
  • sklearn/metrics/tests/test_common.py
  • sklearn/metrics/tests/test_dist_metrics.py
  • sklearn/metrics/tests/test_pairwise_distances_reduction.py
  • sklearn/metrics/tests/test_pairwise.py #22666
  • sklearn/metrics/tests/test_ranking.py
  • sklearn/metrics/tests/test_regression.py
  • sklearn/metrics/tests/test_score_objects.py
  • sklearn/mixture/tests/test_bayesian_mixture.py
  • sklearn/mixture/tests/test_gaussian_mixture.py
  • sklearn/mixture/tests/test_mixture.py
  • sklearn/model_selection/tests/test_search.py
  • sklearn/model_selection/tests/test_split.py
  • sklearn/model_selection/tests/test_successive_halving.py
  • sklearn/model_selection/tests/test_validation.py
  • sklearn/neighbors/tests/test_ball_tree.py
  • sklearn/neighbors/tests/test_graph.py
  • sklearn/neighbors/tests/test_kd_tree.py
  • sklearn/neighbors/tests/test_kde.py
  • sklearn/neighbors/tests/test_lof.py #22665
  • sklearn/neighbors/tests/test_nca.py
  • sklearn/neighbors/tests/test_nearest_centroid.py
  • sklearn/neighbors/tests/test_neighbors_pipeline.py
  • sklearn/neighbors/tests/test_neighbors_tree.py
  • sklearn/neighbors/tests/test_neighbors.py #22663
  • sklearn/neighbors/tests/test_quad_tree.py
  • sklearn/neural_network/tests/test_base.py
  • sklearn/neural_network/tests/test_mlp.py
  • sklearn/neural_network/tests/test_rbm.py
  • sklearn/neural_network/tests/test_stochastic_optimizers.py
  • sklearn/preprocessing/tests/test_common.py
  • sklearn/preprocessing/tests/test_data.py
  • sklearn/preprocessing/tests/test_discretization.py
  • sklearn/preprocessing/tests/test_encoders.py
  • sklearn/preprocessing/tests/test_function_transformer.py
  • sklearn/preprocessing/tests/test_label.py
  • sklearn/preprocessing/tests/test_polynomial.py
  • sklearn/semi_supervised/tests/test_label_propagation.py
  • sklearn/semi_supervised/tests/test_self_training.py
  • sklearn/svm/tests/test_bounds.py
  • sklearn/svm/tests/test_sparse.py
  • sklearn/svm/tests/test_svm.py
  • sklearn/tests/test_base.py
  • sklearn/tests/test_build.py
  • sklearn/tests/test_calibration.py
  • sklearn/tests/test_check_build.py
  • sklearn/tests/test_common.py
  • sklearn/tests/test_config.py
  • sklearn/tests/test_discriminant_analysis.py
  • sklearn/tests/test_docstring_parameters.py
  • sklearn/tests/test_docstrings.py
  • sklearn/tests/test_dummy.py
  • sklearn/tests/test_init.py
  • sklearn/tests/test_isotonic.py
  • sklearn/tests/test_kernel_approximation.py
  • sklearn/tests/test_kernel_ridge.py
  • sklearn/tests/test_metaestimators.py
  • sklearn/tests/test_min_dependencies_readme.py
  • sklearn/tests/test_multiclass.py
  • sklearn/tests/test_multioutput.py
  • sklearn/tests/test_naive_bayes.py
  • sklearn/tests/test_pipeline.py
  • sklearn/tests/test_random_projection.py
  • sklearn/tree/tests/test_export.py
  • sklearn/tree/tests/test_reingold_tilford.py
  • sklearn/tree/tests/test_tree.py
  • sklearn/utils/tests/test_arpack.py
  • sklearn/utils/tests/test_arrayfuncs.py
  • sklearn/utils/tests/test_class_weight.py
  • sklearn/utils/tests/test_cython_blas.py
  • sklearn/utils/tests/test_cython_templating.py
  • sklearn/utils/tests/test_deprecation.py
  • sklearn/utils/tests/test_encode.py
  • sklearn/utils/tests/test_estimator_checks.py
  • sklearn/utils/tests/test_estimator_html_repr.py
  • sklearn/utils/tests/test_extmath.py
  • sklearn/utils/tests/test_fast_dict.py
  • sklearn/utils/tests/test_fixes.py
  • sklearn/utils/tests/test_graph.py
  • sklearn/utils/tests/test_metaestimators.py
  • sklearn/utils/tests/test_mocking.py
  • sklearn/utils/tests/test_multiclass.py
  • sklearn/utils/tests/test_murmurhash.py
  • sklearn/utils/tests/test_optimize.py
  • sklearn/utils/tests/test_parallel.py
  • sklearn/utils/tests/test_pprint.py
  • sklearn/utils/tests/test_random.py
  • sklearn/utils/tests/test_readonly_wrapper.py
  • sklearn/utils/tests/test_seq_dataset.py
  • sklearn/utils/tests/test_shortest_path.py
  • sklearn/utils/tests/test_show_versions.py
  • sklearn/utils/tests/test_sparsefuncs.py
  • sklearn/utils/tests/test_stats.py
  • sklearn/utils/tests/test_tags.py
  • sklearn/utils/tests/test_testing.py
  • sklearn/utils/tests/test_utils.py
  • sklearn/utils/tests/test_validation.py
  • sklearn/utils/tests/test_weight_vector.py

Note that some of those files might not have any test to update.

Issue Analytics

  • State: open
  • Created 2 years ago
  • Comments: 8 (8 by maintainers)

Top GitHub Comments

jjerphan commented on Mar 23, 2022 (1 reaction):

@adam2392: some estimators aren’t preserving the provided input dtype.

In that case, it still makes sense to extend the tests by adding TODO markers, as explained in the description of this issue:

Check that fitted attributes and return values whose dtype depends on the dtype of the input data structure actually have the expected dtype. Typically, when all inputs are continuous values in float32, scikit-learn should often (but not always) carry out all arithmetic operations at that precision level and return output arrays with the same precision level. There can be exceptions. Make them explicit with an inline comment in the test, possibly with a TODO marker when one thinks that the current behavior should change (see, for instance, the related issues https://github.com/scikit-learn/scikit-learn/issues/11000 and https://github.com/scikit-learn/scikit-learn/issues/22682).

ogrisel commented on Mar 22, 2022 (0 reactions):

I changed it to hard because it requires some knowledge to figure out which test should use the fixture and which test should not.

I extended the tests of #22806 using a combination of the global_dtype and global_random_seed fixtures, and it revealed numerical problems that would not be visible with the default seed. I already fixed the most obvious ones, but I am not 100% sure whether this is good enough. So I indeed anticipate those PRs to be hard on average.

