Minimize memory usage in tree-based models by allowing a change of internal dtype
Internally in the `Tree` class, a float64/double type is used for the `value` array (https://github.com/scikit-learn/scikit-learn/blob/0.23.2/sklearn/tree/_tree.pyx#L547). This seems excessive in many cases. If it were possible to use other dtypes in such cases, a significant reduction in required memory would be obtainable.
A motivating example:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(n_samples=int(1e5), n_informative=10,
                           n_features=150, n_classes=450, random_state=8811)
simple_model = RandomForestClassifier(n_estimators=1)
simple_model = simple_model.fit(X, y)

# This is 98% of the memory required for the entire model:
print(simple_model.estimators_[0].tree_.value.nbytes / 1024**2)

# Does not really need to be float64. Could be e.g. uint16:
print(np.unique(simple_model.estimators_[0].tree_.value))
```

which prints:

```
366.8231964111328
array([  0.,   1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,  10.,
        11.,  12.,  13.,  14.,  15.,  16.,  17.,  18.,  19.,  20.,  21.,
        22.,  23.,  24.,  25.,  26.,  27.,  28.,  29.,  30.,  31.,  32.,
        33.,  34.,  35.,  36.,  37.,  38.,  39.,  40.,  41.,  42.,  43.,
        44.,  45.,  46.,  47.,  48.,  49.,  50.,  51.,  52.,  53.,  54.,
        55.,  56.,  57.,  58.,  59.,  60.,  61.,  62.,  63.,  64.,  65.,
        66.,  67.,  68.,  69.,  70.,  71.,  72.,  73.,  74.,  75.,  76.,
        77.,  78.,  79.,  80.,  81.,  82.,  83.,  84.,  85.,  86.,  87.,
        88.,  89.,  90.,  91.,  92.,  93.,  94.,  95.,  96.,  97.,  98.,
        99., 100., 101., 102., 103., 104., 105., 106., 107., 108., 109.,
       110., 111., 112., 113., 114., 115., 116., 117., 118., 119., 120.,
       121., 122., 123., 124., 125., 126., 127., 128., 129., 130., 131.,
       132., 133., 134., 135., 136., 137., 138., 139., 140., 141., 142.,
       143., 144., 145., 146., 147., 148., 149., 150., 151., 152., 153.,
       154., 155., 156., 157., 158., 159., 160., 161., 162., 163., 164.,
       165., 166., 167., 168., 169., 170., 171., 172., 173., 174., 175.,
       176., 177., 178., 179., 180., 182., 183., 184., 185., 187., 188.,
       189., 190., 191., 192., 193., 194., 195., 196., 197., 198., 199.,
       200., 201., 202., 203., 204., 205., 206., 207., 208., 209., 210.,
       211., 212., 213., 214., 215., 216., 217., 218., 219., 220., 221.,
       222., 223., 224., 225., 226., 227., 228., 229., 230., 231., 232.,
       233., 234., 235., 236., 237., 238., 239., 240., 241., 242., 243.,
       244., 245., 247., 248., 249., 250., 251., 252., 253., 254., 255.,
       256., 257., 267., 268., 270.])
```
In the above example, the memory requirements could be lowered by a factor of ~4 by switching to `uint16`. If there were fewer classes in the model, one could even settle for a `uint8`. For a regression model this probably needs to be a `float32`. However, even for the classification case, using a `float32` would allow for a memory reduction by a factor of ~2. As a side note, since the input data `X` is cast to `float32` anyway, it does not seem unreasonable to have `float32` be the default.
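As a rough back-of-the-envelope check of these factors, one can downcast a copy of the fitted tree's `value` array from the example above and compare sizes. This only measures the potential saving on a copy; the `Tree` class itself does not currently accept another dtype:

```python
import numpy as np

# Measure how large the value array of the example model above would be
# under narrower dtypes; the model itself is left unchanged.
value = simple_model.estimators_[0].tree_.value
for dtype in (np.float64, np.float32, np.uint16):
    print(dtype.__name__, value.astype(dtype).nbytes / 1024**2, "MiB")
# Roughly 367, 183 and 92 MiB, respectively.
```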
In my real-world case, I have a RandomForestClassifier that takes up ~10 GiB of RAM. If I could obtain a memory reduction by a factor of ~4, I would only need a ~4 GiB VM when hosting and/or training this model in the cloud, which would be a significant saving for me.
Thus, ideally, I would like to have classes like `RandomForestClassifier` or `RandomForestRegressor` intelligently pick the dtype of the `Tree` `value` array, or at least let me specify it myself.
Any feedback on the feasibility of this suggestion is highly appreciated. My Cython skills are a little inadequate to fully understand the internals of the `Tree` class.
This may be related to https://github.com/scikit-learn/scikit-learn/issues/14747 and https://github.com/scikit-learn/scikit-learn/issues/11165.
Changing `value` to `float` would also require changing `sample_weight` to `float` (because of `sum_total`). I would be happy to open a PR with this solution, but I would like to know your opinion.

WDYT @NicolasHug @rth @thomasjpfan?
Hey @Chroxvi,

IIUC, @NicolasHug refers to the following. The `sum_total` attribute is computed in these lines (for `ClassificationCriterion`):

https://github.com/scikit-learn/scikit-learn/blob/0fb307bf39bbdacd6ed713c00724f8f871d60370/sklearn/tree/_criterion.pyx#L327-L340
As noticed, the `sum_total` attribute depends on the `sample_weight` attribute (since we are computing the weighted count of each label). Then, `node_value` copies `sum_total` to the `value` attribute of the tree (`dest`) here:

https://github.com/scikit-learn/scikit-learn/blob/0fb307bf39bbdacd6ed713c00724f8f871d60370/sklearn/tree/_criterion.pyx#L495-L502
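To make this concrete, here is a small sketch of the accumulation that `ClassificationCriterion` performs in the lines linked above (plain Python rather than the actual Cython, and illustrative only): with non-integer sample weights, the per-class totals are fractional, so an integer dtype like `uint8` cannot hold them.

```python
import numpy as np

# Illustrative only: mimic how sum_total accumulates the weighted count
# of each label when sample weights are not integers.
y = np.array([0, 0, 1, 1, 1])
sample_weight = np.array([0.5, 1.0, 0.25, 1.0, 1.0])

sum_total = np.zeros(2)  # one slot per class
for label, w in zip(y, sample_weight):
    sum_total[label] += w

print(sum_total)  # [1.5  2.25] -- fractional, so uint8 cannot represent it
```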
That is how `sample_weight` relates to `node_value`, and why we cannot use `uint8` for `sum_total` and hence for `value`.

In fact, running a snippet along the following lines will return `double` values even for classification tasks:
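```python
# Minimal sketch of such a check: even a pure classification fit stores
# float64 (C double) in the tree's value array.
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(random_state=0).fit(X, y)
print(clf.tree_.value.dtype)  # float64
```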
Nevertheless, I think it is worth storing `float` instead of `double` and thereby reducing the memory.

WDYT?