
Minimize memory usage in Tree-based models by allowing for change of internal dtype


Internally, the Tree class uses a float64/double dtype for the value array (https://github.com/scikit-learn/scikit-learn/blob/0.23.2/sklearn/tree/_tree.pyx#L547). This seems excessive in many cases. If it were possible to use other dtypes where they suffice, a significant reduction in required memory would be obtainable.

A motivating example:

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
import numpy as np

X, y = make_classification(n_samples=int(1e5), n_informative=10, n_features=150, n_classes=450, random_state=8811)
simple_model = RandomForestClassifier(n_estimators=1)
simple_model = simple_model.fit(X, y)

print(simple_model.estimators_[0].tree_.value.nbytes / 1024**2)   # This is 98% of the memory required for the entire model
366.8231964111328

print(np.unique(simple_model.estimators_[0].tree_.value))  # Does not really need to be float64. Could be e.g. uint16.
array([  0.,   1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,  10.,
        11.,  12.,  13.,  14.,  15.,  16.,  17.,  18.,  19.,  20.,  21.,
        22.,  23.,  24.,  25.,  26.,  27.,  28.,  29.,  30.,  31.,  32.,
        33.,  34.,  35.,  36.,  37.,  38.,  39.,  40.,  41.,  42.,  43.,
        44.,  45.,  46.,  47.,  48.,  49.,  50.,  51.,  52.,  53.,  54.,
        55.,  56.,  57.,  58.,  59.,  60.,  61.,  62.,  63.,  64.,  65.,
        66.,  67.,  68.,  69.,  70.,  71.,  72.,  73.,  74.,  75.,  76.,
        77.,  78.,  79.,  80.,  81.,  82.,  83.,  84.,  85.,  86.,  87.,
        88.,  89.,  90.,  91.,  92.,  93.,  94.,  95.,  96.,  97.,  98.,
        99., 100., 101., 102., 103., 104., 105., 106., 107., 108., 109.,
       110., 111., 112., 113., 114., 115., 116., 117., 118., 119., 120.,
       121., 122., 123., 124., 125., 126., 127., 128., 129., 130., 131.,
       132., 133., 134., 135., 136., 137., 138., 139., 140., 141., 142.,
       143., 144., 145., 146., 147., 148., 149., 150., 151., 152., 153.,
       154., 155., 156., 157., 158., 159., 160., 161., 162., 163., 164.,
       165., 166., 167., 168., 169., 170., 171., 172., 173., 174., 175.,
       176., 177., 178., 179., 180., 182., 183., 184., 185., 187., 188.,
       189., 190., 191., 192., 193., 194., 195., 196., 197., 198., 199.,
       200., 201., 202., 203., 204., 205., 206., 207., 208., 209., 210.,
       211., 212., 213., 214., 215., 216., 217., 218., 219., 220., 221.,
       222., 223., 224., 225., 226., 227., 228., 229., 230., 231., 232.,
       233., 234., 235., 236., 237., 238., 239., 240., 241., 242., 243.,
       244., 245., 247., 248., 249., 250., 251., 252., 253., 254., 255.,
       256., 257., 267., 268., 270.])

In the above example, the memory requirements could be lowered by a factor of ~4 by switching to uint16. If there were fewer classes in the model, one could even settle for a uint8. For a regression model this probably needs to be a float32. However, even for the classification case, using a float32 would allow for a memory reduction by a factor of ~2. As a side note, since the input data X is cast to float32 anyway, it does not seem unreasonable to have float32 be the default.
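
For a rough estimate of the attainable saving, the value array can be cast outside of scikit-learn and the sizes compared. Note this is only an estimate: the cast array cannot be assigned back into tree_, which keeps its internal float64 buffer.

value = simple_model.estimators_[0].tree_.value
print(value.nbytes / 1024**2)                     # ~367 MiB as float64
print(value.astype(np.float32).nbytes / 1024**2)  # ~183 MiB, factor ~2
print(value.astype(np.uint16).nbytes / 1024**2)   # ~92 MiB, factor ~4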

In my real-world case, I have a RandomForestClassifier which takes up about 10 GiB of RAM. If I could obtain a memory reduction by a factor of ~4, I would only need a ~4 GiB VM when hosting and/or training this model in the cloud, which would be a significant saving for me.

Thus, ideally, I would like classes like RandomForestClassifier or RandomForestRegressor to intelligently pick the dtype of the Tree value array, or at least let me specify it myself; a sketch of what I have in mind follows.
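
Something along these lines, purely hypothetical (value_dtype is not an existing scikit-learn parameter):

# Hypothetical API sketch; value_dtype does not exist in scikit-learn today.
model = RandomForestClassifier(n_estimators=100, value_dtype=np.float32)  # user-specified
model = RandomForestClassifier(n_estimators=100, value_dtype="auto")      # smallest safe dtype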

Any feedback on the feasibility of this suggestion is highly appreciated. My Cython skills are a little inadequate to fully understand the internals of the Tree class.

This may be related to https://github.com/scikit-learn/scikit-learn/issues/14747 and https://github.com/scikit-learn/scikit-learn/issues/11165.

Issue Analytics

  • State: open
  • Created: 3 years ago
  • Reactions: 1
  • Comments: 17 (12 by maintainers)

Top GitHub Comments

alfaro96 commented, Sep 28, 2020 (1 reaction)

Changing value to float would also require changing sample_weight to float (because of sum_total).

I would be happy to open a PR with this solution, but I would like to know your opinion.

WDYT @NicolasHug @rth @thomasjpfan?

alfaro96 commented, Sep 25, 2020 (1 reaction)

Hey @Chroxvi,

IIUC, @NicolasHug refers to the following. The sum_total attribute is computed in these lines (for ClassificationCriterion):

https://github.com/scikit-learn/scikit-learn/blob/0fb307bf39bbdacd6ed713c00724f8f871d60370/sklearn/tree/_criterion.pyx#L327-L340

As can be seen, the sum_total attribute depends on sample_weight (since we are computing the weighted count of each label).

Then, node_value copies sum_total to the value attribute of the tree (dest) here:

https://github.com/scikit-learn/scikit-learn/blob/0fb307bf39bbdacd6ed713c00724f8f871d60370/sklearn/tree/_criterion.pyx#L495-L502

That is how sample_weight relates to node_value, and why we cannot use uint8 for sum_total, and hence for value.
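
A rough pure-Python sketch of that accumulation (simplified names; not the actual Cython implementation) makes the dependency explicit: as soon as sample_weight holds arbitrary floats, sum_total must be a floating-point type as well.

import numpy as np

def node_sum_total(y, sample_weight, n_classes):
    # Mimics ClassificationCriterion: weighted count per class
    # for the samples falling into one node.
    sum_total = np.zeros(n_classes)  # must hold fractional (even negative) weights
    for label, w in zip(y, sample_weight):
        sum_total[label] += w
    return sum_total

print(node_sum_total([0, 1, 1, 2], [0.5, 1.0, -0.3, 2.0], 3))
# [0.5 0.7 2. ]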

In fact, running the following snippet:

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.utils.validation import check_random_state

seed = 8811
rng = check_random_state(seed)

X, y = make_classification(n_samples=10000, n_informative=10, n_features=150, n_classes=450, random_state=seed)
sample_weight = rng.randn(10000)
simple_model = RandomForestClassifier(n_estimators=1)
simple_model = simple_model.fit(X, y, sample_weight)

print(simple_model[0].tree_.value)
[[[10.82412723 -2.16428625 -5.63254548 ... -3.478739   -8.62627984
    4.78608331]]

 [[-0.58115954 -3.28410311 -3.27667071 ... -2.01077895 -1.13608312
    2.06086726]]

 [[ 0.43173988 -3.10751872 -3.45221219 ...  0.          0.60339334
    1.35368681]]

 ...

 [[-0.49758928  0.          0.         ...  0.          0.27629371
    0.        ]]

 [[ 0.75476597  0.          0.         ...  0.55597707  0.
    0.7089391 ]]

 [[11.40528677  1.11981687 -2.35587477 ... -1.46796005 -7.49019672
    2.72521605]]]

shows that value holds genuinely non-integer (double) numbers even for classification tasks.

Nevertheless, I think it is worthwhile to store float instead of double and thereby halve the memory.
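
To gauge whether float32 would noticeably change predictions, the class probabilities can be recomputed from a float32 copy of value outside the tree and compared with predict_proba. A rough sketch, assuming the unweighted simple_model and X from the original example:

import numpy as np

tree = simple_model.estimators_[0].tree_
value32 = tree.value.astype(np.float32)   # stand-in for hypothetical float32 storage

# Recompute per-sample class probabilities from the float32 copy.
leaf = tree.apply(X.astype(np.float32))   # leaf node index for each sample
proba32 = value32[leaf, 0, :]
proba32 /= proba32.sum(axis=1, keepdims=True)

proba64 = simple_model.estimators_[0].predict_proba(X)
print(np.abs(proba32 - proba64).max())    # expected to be a tiny float32 rounding error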

WDYT?
