Risk of MemoryError when using Birch clustering algorithm with large datasets and a simple solution
Description
A MemoryError is always raised when sklearn.cluster.Birch is used to cluster a large dataset. For example, the error looks like this:
File "D:\Program Files\Python3\lib\site-packages\sklearn\cluster\birch.py", line 536, in partial_fit
return self._fit(X)
File "D:\Program Files\Python3\lib\site-packages\sklearn\cluster\birch.py", line 497, in _fit
self._global_clustering(X)
File "D:\Program Files\Python3\lib\site-packages\sklearn\cluster\birch.py", line 636, in _global_clustering
self.labels_ = self.predict(X)
File "D:\Program Files\Python3\lib\site-packages\sklearn\cluster\birch.py", line 571, in predict
reduced_distance = safe_sparse_dot(X, self.subcluster_centers_.T)
File "D:\Program Files\Python3\lib\site-packages\sklearn\utils\extmath.py", line 142, in safe_sparse_dot
return np.dot(a, b)
File "<__array_function__ internals>", line 6, in dot
numpy.core._exceptions.MemoryError: Unable to allocate array with shape (1000000, 30777) and data type float64
In one of my cases, the method predict(X) needs a large amount of memory to create an np.array (around 1000000 * 30777 * 8 bytes / 1024^3 ≈ 229 GiB) when handling a 30M-sample 2D dataset (10M per partial_fit(X) here). It is unreasonable for predict(X) to compute the dot product of X and self.subcluster_centers_.T in a single step.
I think a simple change can easily mitigate this high memory consumption.
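For reference, a quick back-of-the-envelope check of that allocation (a minimal sketch; the shape and dtype come from the traceback above):

import numpy as np

# reduced_distance holds one float64 (8 bytes) per (sample, subcluster) pair.
n_samples, n_subclusters = 1000000, 30777
gib = n_samples * n_subclusters * np.dtype(np.float64).itemsize / 1024 ** 3
print("%.0f GiB" % gib)  # roughly 229 GiB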
Steps/Code to Reproduce
Just split the computation of the large matrix reduced_distance of shape (n_samples, n_subcluster_centers) into a series of smaller blocks of shape (n_samples_i, n_subcluster_centers), i.e. chunk X along the sample axis; a minimal sketch of the idea follows.
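Here is a minimal standalone sketch of the row-chunked argmin idea (this is not the author's patch; sklearn.utils.gen_batches and sklearn.utils.extmath.row_norms are existing helpers, used here only for illustration):

import numpy as np
from sklearn.utils import gen_batches
from sklearn.utils.extmath import row_norms, safe_sparse_dot

def chunked_nearest_center(X, centers, chunk_size=10000):
    # For each row of X, find the index of the nearest row of `centers`,
    # materialising only a (chunk_size, n_centers) block at a time.
    # -2 * X_chunk @ centers.T + ||centers||^2 has the same argmin as the
    # squared Euclidean distance, just as in Birch.predict.
    center_norms = row_norms(centers, squared=True)
    nearest = np.empty(X.shape[0], dtype=np.intp)
    for batch in gen_batches(X.shape[0], chunk_size):
        reduced = safe_sparse_dot(X[batch], centers.T)
        reduced *= -2
        reduced += center_norms
        nearest[batch] = np.argmin(reduced, axis=1)
    return nearest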
The full test code below is modified from the scikit-learn example "Compare BIRCH and MiniBatchKMeans". The OldBirch class is the same as sklearn.cluster.Birch. The NewBirch class has a modified predict(X) method.
from itertools import cycle
from time import time
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from sklearn.cluster import Birch
from sklearn.datasets import make_blobs
from sklearn.utils import check_array
from sklearn.utils.extmath import safe_sparse_dot
from memory_profiler import profile
n_decompositon = 1000 # divide the array 'reduced_distance' into 1000 parts along the axis=0
class OldBirch(Birch):
    @profile
    def predict(self, X):
        # the original code
        X = check_array(X, accept_sparse='csr')
        self._check_fit(X)
        reduced_distance = safe_sparse_dot(X, self.subcluster_centers_.T)
        reduced_distance *= -2
        reduced_distance += self._subcluster_norms
        return self.subcluster_labels_[np.argmin(reduced_distance, axis=1)]


class NewBirch(Birch):
    @profile
    def predict(self, X):
        X = check_array(X, accept_sparse='csr')
        self._check_fit(X)
        '''
        try:
            reduced_distance = safe_sparse_dot(X, self.subcluster_centers_.T)  # the original code
            reduced_distance *= -2
            reduced_distance += self._subcluster_norms
            return self.subcluster_labels_[np.argmin(reduced_distance, axis=1)]
        except MemoryError:
        '''
        # assume that the matrix is dense
        argmin_list = np.array([], dtype=np.int)
        interval = int(np.ceil(X.shape[0] / n_decompositon))
        for index in range(0, n_decompositon - 1):
            lb = index * interval
            ub = (index + 1) * interval
            reduced_distance = safe_sparse_dot(X[lb:ub, :], self.subcluster_centers_.T)
            reduced_distance *= -2
            reduced_distance += self._subcluster_norms
            argmin_list = np.append(argmin_list, np.argmin(reduced_distance, axis=1))
        lb = (n_decompositon - 1) * interval
        reduced_distance = safe_sparse_dot(X[lb:X.shape[0], :], self.subcluster_centers_.T)
        reduced_distance *= -2
        reduced_distance += self._subcluster_norms
        argmin_list = np.append(argmin_list, np.argmin(reduced_distance, axis=1))
        return self.subcluster_labels_[argmin_list]
# Generate centers for the blobs so that it forms a 10 X 10 grid.
xx = np.linspace(-22, 22, 10)
yy = np.linspace(-22, 22, 10)
xx, yy = np.meshgrid(xx, yy)
n_centres = np.hstack((np.ravel(xx)[:, np.newaxis],
                       np.ravel(yy)[:, np.newaxis]))
# Generate blobs to do a comparison between MiniBatchKMeans and Birch.
X, y = make_blobs(n_samples=200000, centers=n_centres, random_state=0)
# Use all colors that matplotlib provides by default.
colors_ = cycle(colors.cnames.keys())
fig = plt.figure(figsize=(8, 4))
fig.subplots_adjust(left=0.04, right=0.98, bottom=0.1, top=0.9)
# Compute clustering with Birch with and without the final clustering step
# and plot.
birch_models = [OldBirch(threshold=1.7, n_clusters=None),
                NewBirch(threshold=1.7, n_clusters=None)]
final_step = ['with old predict(X) method', 'with new predict(X) method']
for ind, (birch_model, info) in enumerate(zip(birch_models, final_step)):
    t = time()
    birch_model.fit(X)
    time_ = time() - t
    print("Birch %s as the final step took %0.2f seconds" % (
          info, (time() - t)))

    # Plot result
    labels = birch_model.labels_
    centroids = birch_model.subcluster_centers_
    n_clusters = np.unique(labels).size
    print("n_clusters : %d" % n_clusters)

    ax = fig.add_subplot(1, 2, ind + 1)
    for this_centroid, k, col in zip(centroids, range(n_clusters), colors_):
        mask = labels == k
        ax.scatter(X[mask, 0], X[mask, 1],
                   c='w', edgecolor=col, marker='.', alpha=0.5)
        ax.scatter(this_centroid[0], this_centroid[1], marker='+',
                   c='k', s=25)
    ax.set_ylim([-25, 25])
    ax.set_xlim([-25, 25])
    ax.set_autoscaley_on(False)
    ax.set_title('Birch %s' % info)
plt.show()
Here, only the dense-matrix case is handled; sparse input is not considered.
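As an aside, the same row-wise chunking could in principle also cover sparse input, because slicing rows of a CSR matrix keeps it sparse (a small illustration, not part of the test above):

import scipy.sparse as sp

X_sparse = sp.random(1000, 50, density=0.01, format='csr', random_state=0)
chunk = X_sparse[0:100, :]        # row slicing a CSR matrix returns a CSR matrix
print(chunk.shape, chunk.format)  # (100, 50) csr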
Actual Results
The clustering results are the same, and the memory consumption is reduced remarkably, though the time cost increases. The extent of the reduction depends, I think, on the number of divisions (smaller blocks), n_decompositon.
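To quantify that dependence, the peak size of each temporary block can be estimated roughly as follows (numbers taken from the test run above, so this is only illustrative):

import numpy as np

n_samples, n_subclusters = 200000, 166  # samples and (roughly) subcluster centers in the test run
n_decompositon = 1000                   # number of row chunks

chunk_rows = int(np.ceil(n_samples / n_decompositon))
full_mib = n_samples * n_subclusters * 8 / 1024 ** 2
chunk_mib = chunk_rows * n_subclusters * 8 / 1024 ** 2
print("unchunked: ~%.0f MiB, per chunk: ~%.2f MiB" % (full_mib, chunk_mib))
# roughly: unchunked: ~253 MiB, per chunk: ~0.25 MiB

This is consistent with the profiles below: the old predict(X) allocates roughly 258 MiB for reduced_distance at once, while the new one only ever adds a fraction of a MiB per iteration.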
Line # Mem usage Increment Line Contents
================================================
18 133.2 MiB 133.2 MiB @profile
19 def predict(self, X):
20 # the original code
21 133.2 MiB 0.0 MiB X = check_array(X, accept_sparse='csr')
22 133.2 MiB 0.0 MiB self._check_fit(X)
23 391.0 MiB 257.8 MiB reduced_distance = safe_sparse_dot(X, self.subcluster_centers_.T)
24 391.0 MiB 0.0 MiB reduced_distance *= -2
25 391.0 MiB 0.0 MiB reduced_distance += self._subcluster_norms
26 391.0 MiB 0.0 MiB return self.subcluster_labels_[np.argmin(reduced_distance, axis=1)]
Birch with old predict(X) method as the final step took 5.61 seconds
n_clusters : 166
Line # Mem usage Increment Line Contents
================================================
30 150.0 MiB 150.0 MiB @profile
31 def predict(self, X):
32 150.0 MiB 0.0 MiB X = check_array(X, accept_sparse='csr')
33 150.0 MiB 0.0 MiB self._check_fit(X)
34 '''
35 try:
36 reduced_distance = safe_sparse_dot(X, self.subcluster_centers_.T) # the original code
37 reduced_distance *= -2
38 reduced_distance += self._subcluster_norms
39 return self.subcluster_labels_[np.argmin(reduced_distance, axis=1)]
40 except MemoryError:
41 '''
42 # assume that the matrix is dense
43 150.0 MiB 0.0 MiB argmin_list = np.array([], dtype=np.int)
44 150.0 MiB 0.0 MiB interval = int(np.ceil(X.shape[0] / n_decompositon))
45 153.1 MiB 0.0 MiB for index in range(0, n_decompositon - 1):
46 153.1 MiB 0.0 MiB lb = index * interval
47 153.1 MiB 0.0 MiB ub = (index + 1) * interval
48 153.1 MiB 0.2 MiB reduced_distance = safe_sparse_dot(X[lb:ub, :], self.subcluster_centers_.T)
49 153.1 MiB 0.0 MiB reduced_distance *= -2
50 153.1 MiB 0.1 MiB reduced_distance += self._subcluster_norms
51 153.1 MiB 0.2 MiB argmin_list = np.append(argmin_list, np.argmin(reduced_distance, axis=1))
52
53 153.1 MiB 0.0 MiB lb = (n_decompositon - 1) * interval
54 153.1 MiB 0.0 MiB reduced_distance = safe_sparse_dot(X[lb:X.shape[0], :], self.subcluster_centers_.T)
55 153.1 MiB 0.0 MiB reduced_distance *= -2
56 153.1 MiB 0.0 MiB reduced_distance += self._subcluster_norms
57 153.1 MiB 0.0 MiB argmin_list = np.append(argmin_list, np.argmin(reduced_distance, axis=1))
58
59 153.1 MiB 0.0 MiB return self.subcluster_labels_[argmin_list]
Birch with new predict(X) method as the final step took 5.78 seconds
n_clusters : 166
I think nobody has been bothered by this yet, possibly because Birch is rarely used, even though, unlike KMeans, it can produce clusters of uneven size.
Versions
Python: 3.7.0 (v3.7.0:1bf9cc5093, Jun 27 2018, 04:59:51) [MSC v.1914 64 bit (AMD64)] on win32
OS: Windows 10
CPU: Intel® Core™ i7-9750H @ 2.60GHz
RAM: 16GB

Python dependencies:
numpy: 1.17.0
scikit-learn: 0.21.3
scipy: 1.3.2
memory-profiler: 0.55.0
psutil: 5.6.7
matplotlib: 3.1.1
Thank you.
Hi @ImLaoBJie. I intend to fix this issue based on your idea. For the chunking of the samples matrix I will use an existing scikit-learn routine, pairwise_distances_argmin(), which builds the chunks with get_chunk_n_rows() and gen_batches() based on the working_memory setting accessible from sklearn.get_config().
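A rough sketch of what that could look like (this is only an illustration of the approach, not the actual PR code; pairwise_distances_argmin and config_context are existing scikit-learn utilities, and chunked_predict is a hypothetical helper name):

from sklearn import config_context
from sklearn.metrics import pairwise_distances_argmin

def chunked_predict(birch_model, X, working_memory=16):
    # pairwise_distances_argmin processes the distance computation in chunks
    # sized from the working_memory budget (in MiB), so the full
    # (n_samples, n_subclusters) matrix is never materialised at once.
    with config_context(working_memory=working_memory):
        nearest = pairwise_distances_argmin(X, birch_model.subcluster_centers_)
    return birch_model.subcluster_labels_[nearest]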
I added a new step to your demo that uses the chunked approach, BirchChunked.
This is the performance report; BirchChunked has the smallest memory increase. And the pyplot display shows that all three examples create the same clusters.
I posted a long explanation here about garbage-collection issues I was observing while working with pairwise_distances_argmin. It turned out I was mistaken: it happened because I was running in debug mode in PyCharm. Running in release mode shows no garbage-collection problems and the memory stays low, so I removed that comment. In the PR I implemented a solution with a 16MB working_memory. The performance benchmark script and the memory_profiler log are attached. To run the benchmark script:
python performance_16027.py SCIKIT-LEARN_HOME
perf.log performance_16027.txt
@jnothman, @glemaitre, @jeremiedbb: the PR is ready, and the performance benchmark attached here validates the efficiency of the chunked-memory algorithm in keeping the memory footprint low.