[MAINT] Modularize Tree code and Splitter utility functions
In #20819, developers raised concerns about the current tree code.
Part of the problem is the modularity of this code and, as a result, its maintainability and upgradability. I propose the following very short refactors to the `_tree.pyx`/`_tree.pxd` and `_splitter.pyx`/`_splitter.pxd` files. This would be the first in a series of PRs to demonstrate that #20819 is fairly straightforward to implement.
### Tree class
The `Tree` class assumes axis-aligned splits. However, if the code that sets node values and the code that computes feature values for a given dataset are moved into their own methods, a subclass of `Tree` only needs to redefine these two methods (and provide a new `Splitter`) to enable a “new” type of tree.
I propose adding the following two methods to the `Tree` class and altering `_add_node()` and `_apply_dense()` to accompany these changes:
```cython
cdef int _set_node_values(self, SplitRecord split_node,
                          Node *node) nogil except -1:
    """Set the split data (feature and threshold) of a node.
    """
    node.feature = split_node.feature
    node.threshold = split_node.threshold
    return 1

cdef DTYPE_t _compute_feature(self, const DTYPE_t[:] X_ndarray,
                              Node *node) nogil:
    """Compute the feature value for a single sample (one row of X).

    In axis-aligned trees, this is simply the value in the column of X
    for this specific feature.
    """
    # the value of this node's split feature for the given sample
    cdef DTYPE_t feature = X_ndarray[node.feature]
    return feature
```
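To make the intended division of labour concrete, here is a minimal pure-Python model of the refactor. It is not Cython and not scikit-learn's actual code: `SplitRecord` and `Node` are simplified stand-ins for the Cython structs, `LEAF` mirrors the `-1` leaf marker, and `ObliqueSplitRecord`, `ObliqueTree` and `stump` are hypothetical names used only for illustration. The generic `_add_node` and `_apply_dense` only go through the two hooks, so the oblique subclass changes how samples are routed by overriding just `_set_node_values` and `_compute_feature`.

```python
from dataclasses import dataclass
from typing import List, Optional

import numpy as np

LEAF = -1  # "no child" marker, mirroring TREE_LEAF = -1 in scikit-learn


@dataclass
class SplitRecord:  # simplified stand-in for the Cython struct
    feature: int
    threshold: float


@dataclass
class Node:  # simplified stand-in for the Cython struct
    left_child: int = LEAF
    right_child: int = LEAF
    feature: int = -1
    threshold: float = 0.0


class Tree:
    def __init__(self) -> None:
        self.nodes: List[Node] = []

    # Hook 1: copy split data into a node (axis-aligned default).
    def _set_node_values(self, split: SplitRecord, node: Node) -> int:
        node.feature = split.feature
        node.threshold = split.threshold
        return 1

    # Hook 2: the value compared against the threshold for one sample row.
    def _compute_feature(self, x_row: np.ndarray, node: Node) -> float:
        return float(x_row[node.feature])

    # Generic code below never reads node.feature or X[i, feature] directly.
    def _add_node(self, parent: int, is_left: bool, is_leaf: bool,
                  split: Optional[SplitRecord] = None) -> int:
        node_id = len(self.nodes)
        node = Node()
        if not is_leaf:
            self._set_node_values(split, node)
        self.nodes.append(node)
        if parent != LEAF:
            if is_left:
                self.nodes[parent].left_child = node_id
            else:
                self.nodes[parent].right_child = node_id
        return node_id

    def _apply_dense(self, X: np.ndarray) -> np.ndarray:
        # Return the id of the leaf reached by each row of X.
        out = np.empty(X.shape[0], dtype=np.intp)
        for i in range(X.shape[0]):
            node_id = 0
            while self.nodes[node_id].left_child != LEAF:
                node = self.nodes[node_id]
                if self._compute_feature(X[i], node) <= node.threshold:
                    node_id = node.left_child
                else:
                    node_id = node.right_child
            out[i] = node_id
        return out


@dataclass
class ObliqueSplitRecord(SplitRecord):  # hypothetical
    weights: Optional[np.ndarray] = None


class ObliqueTree(Tree):  # hypothetical subclass: overrides only the hooks
    def _set_node_values(self, split: SplitRecord, node: Node) -> int:
        node.threshold = split.threshold
        node.weights = split.weights  # Python-only shortcut, see note below
        return 1

    def _compute_feature(self, x_row: np.ndarray, node: Node) -> float:
        # Split on a linear combination of features instead of one column.
        return float(np.dot(node.weights, x_row))


def stump(tree: Tree, split: SplitRecord) -> Tree:
    # Root split with two leaves, which get node ids 1 (left) and 2 (right).
    root = tree._add_node(LEAF, True, False, split)
    tree._add_node(root, True, True)
    tree._add_node(root, False, True)
    return tree


axis = stump(Tree(), SplitRecord(feature=1, threshold=0.5))
oblique = stump(ObliqueTree(), ObliqueSplitRecord(
    feature=-1, threshold=1.0, weights=np.array([1.0, 1.0])))
X = np.array([[9.0, 0.2], [0.1, 0.8]])
axis_leaves = axis._apply_dense(X).tolist()        # [1, 2]
oblique_leaves = oblique._apply_dense(X).tolist()  # [2, 1]
```

The same two rows land in different leaves only because the subclass redefined the hooks. In Cython, `Node` is a fixed C struct, so a real oblique subclass could not attach `weights` to it the way this Python sketch does; it would presumably need to store its projection weights elsewhere (for example, a per-node array), which is an assumption outside the scope of this proposal.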
### Splitter
`Splitter` uses utility functions that are defined only in the `.pyx` files. As a result, they are not available via `cimport`. This poses a problem for #20819, and also for downstream packages that might want to define a new splitter that subclasses `Splitter`.
I propose declaring the following functions in the `_splitter.pxd` file:
```cython
cdef inline void sort(DTYPE_t* Xf, SIZE_t* samples, SIZE_t n) nogil
cdef inline void swap(DTYPE_t* Xf, SIZE_t* samples, SIZE_t i, SIZE_t j) nogil
# and the other splitter utility functions.
...
```
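A pure-Python model of the contract behind these two utilities may help when writing a downstream splitter: `sort(Xf, samples, n)` orders the first `n` feature values in `Xf` ascending while applying the same permutation to `samples`, and `swap` exchanges one pair of positions in both arrays so they stay aligned. This sketch uses insertion sort for brevity; the Cython version in `_splitter.pyx` is an introsort, so only the behaviour, not the algorithm or its performance, is illustrated here.

```python
from typing import List


def swap(Xf: List[float], samples: List[int], i: int, j: int) -> None:
    # Swap positions i and j in *both* arrays so feature values and
    # sample indices stay paired.
    Xf[i], Xf[j] = Xf[j], Xf[i]
    samples[i], samples[j] = samples[j], samples[i]


def sort(Xf: List[float], samples: List[int], n: int) -> None:
    # Sort the first n entries of Xf ascending, permuting samples identically.
    # Insertion sort (O(n^2)) for brevity only.
    for i in range(1, n):
        j = i
        while j > 0 and Xf[j - 1] > Xf[j]:
            swap(Xf, samples, j - 1, j)
            j -= 1


Xf = [0.7, 0.1, 0.5, 0.1]
samples = [10, 11, 12, 13]
sort(Xf, samples, len(Xf))
# Xf -> [0.1, 0.1, 0.5, 0.7]; samples -> [11, 13, 12, 10]
```

Because a splitter scans the sorted `Xf` to evaluate candidate thresholds and then partitions by `samples`, any subclass that reuses these helpers depends on this pairing being preserved, which is why being able to `cimport` them matters.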
### Misc Notes
This proposal only addresses dense arrays. A follow-up issue and PR would be needed for sparse arrays.
Issue Analytics
- Created 2 years ago
- Comments: 6 (6 by maintainers)
Top GitHub Comments
I’m working on replacing `sort` altogether. This requires some minor refactoring of the splitter internals, plus benchmarks to make sure there are no performance regressions.
FYI, I have just opened https://github.com/scikit-learn/scikit-learn/pull/22760.