Risk of MemoryError when using Birch clustering algorithm with large datasets and a simple solution
Description
sklearn.cluster.Birch always raises a MemoryError when clustering a large dataset. For example, the MemoryError looks like this:
File "D:\Program Files\Python3\lib\site-packages\sklearn\cluster\birch.py", line 536, in partial_fit
return self._fit(X)
File "D:\Program Files\Python3\lib\site-packages\sklearn\cluster\birch.py", line 497, in _fit
self._global_clustering(X)
File "D:\Program Files\Python3\lib\site-packages\sklearn\cluster\birch.py", line 636, in _global_clustering
self.labels_ = self.predict(X)
File "D:\Program Files\Python3\lib\site-packages\sklearn\cluster\birch.py", line 571, in predict
reduced_distance = safe_sparse_dot(X, self.subcluster_centers_.T)
File "D:\Program Files\Python3\lib\site-packages\sklearn\utils\extmath.py", line 142, in safe_sparse_dot
return np.dot(a, b)
File "<__array_function__ internals>", line 6, in dot
numpy.core._exceptions.MemoryError: Unable to allocate array with shape (1000000, 30777) and data type float64
In one of my cases, the method predict(X) requires a huge amount of memory to create an np.array (about 1000000 * 30777 * 8 bytes / 1024**3 ≈ 229 GiB for the shape in the traceback above) when handling a 30M-sample 2D dataset (10M samples per partial_fit(X) call here). It is unreasonable that predict(X) computes the dot product of X and self.subcluster_centers_.T on the whole matrix at once.
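For reference, the size of that allocation can be checked directly; this is just a quick sketch using the array shape reported in the traceback above:

import numpy as np

# Size of the dense (n_samples, n_subclusters) float64 array that
# safe_sparse_dot(X, self.subcluster_centers_.T) tries to allocate in predict().
n_samples, n_subclusters = 1000000, 30777  # shape from the traceback above
n_bytes = n_samples * n_subclusters * np.dtype(np.float64).itemsize
print("%.0f GiB" % (n_bytes / 1024 ** 3))  # roughly 229 GiB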
I think a simple change can easily mitigate this high memory consumption.
Steps/Code to Reproduce
Just divide the large matrix reduced_distance of shape (n_samples, n_subcluster_centers) into a series of smaller matrices of shape (chunk_size, n_subcluster_centers), i.e. split it along axis 0.
The test code is modified from Compare BIRCH and MiniBatchKMeans.
The OldBirch class is the same as sklearn.cluster.Birch.
The NewBirch class has a modified method predict(X).
from itertools import cycle
from time import time
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from sklearn.cluster import Birch
from sklearn.datasets import make_blobs
from sklearn.utils import check_array
from sklearn.utils.extmath import safe_sparse_dot
from memory_profiler import profile
n_decompositon = 1000 # divide the array 'reduced_distance' into 1000 parts along the axis=0
class OldBirch(Birch):
    @profile
    def predict(self, X):
        # the original code
        X = check_array(X, accept_sparse='csr')
        self._check_fit(X)
        reduced_distance = safe_sparse_dot(X, self.subcluster_centers_.T)
        reduced_distance *= -2
        reduced_distance += self._subcluster_norms
        return self.subcluster_labels_[np.argmin(reduced_distance, axis=1)]
class NewBirch(Birch):
    @profile
    def predict(self, X):
        X = check_array(X, accept_sparse='csr')
        self._check_fit(X)
        '''
        try:
            reduced_distance = safe_sparse_dot(X, self.subcluster_centers_.T)  # the original code
            reduced_distance *= -2
            reduced_distance += self._subcluster_norms
            return self.subcluster_labels_[np.argmin(reduced_distance, axis=1)]
        except MemoryError:
        '''
        # assume that the matrix is dense
        argmin_list = np.array([], dtype=np.int)
        interval = int(np.ceil(X.shape[0] / n_decompositon))
        for index in range(0, n_decompositon - 1):
            lb = index * interval
            ub = (index + 1) * interval
            reduced_distance = safe_sparse_dot(X[lb:ub, :], self.subcluster_centers_.T)
            reduced_distance *= -2
            reduced_distance += self._subcluster_norms
            argmin_list = np.append(argmin_list, np.argmin(reduced_distance, axis=1))

        lb = (n_decompositon - 1) * interval
        reduced_distance = safe_sparse_dot(X[lb:X.shape[0], :], self.subcluster_centers_.T)
        reduced_distance *= -2
        reduced_distance += self._subcluster_norms
        argmin_list = np.append(argmin_list, np.argmin(reduced_distance, axis=1))

        return self.subcluster_labels_[argmin_list]
# Generate centers for the blobs so that it forms a 10 X 10 grid.
xx = np.linspace(-22, 22, 10)
yy = np.linspace(-22, 22, 10)
xx, yy = np.meshgrid(xx, yy)
n_centres = np.hstack((np.ravel(xx)[:, np.newaxis],
                       np.ravel(yy)[:, np.newaxis]))

# Generate blobs to do a comparison between MiniBatchKMeans and Birch.
X, y = make_blobs(n_samples=200000, centers=n_centres, random_state=0)

# Use all colors that matplotlib provides by default.
colors_ = cycle(colors.cnames.keys())

fig = plt.figure(figsize=(8, 4))
fig.subplots_adjust(left=0.04, right=0.98, bottom=0.1, top=0.9)

# Compute clustering with the old and the new Birch predict(X) method and plot.
birch_models = [OldBirch(threshold=1.7, n_clusters=None),
                NewBirch(threshold=1.7, n_clusters=None)]
final_step = ['with old predict(X) method', 'with new predict(X) method']

for ind, (birch_model, info) in enumerate(zip(birch_models, final_step)):
    t = time()
    birch_model.fit(X)
    time_ = time() - t
    print("Birch %s as the final step took %0.2f seconds" % (info, time_))

    # Plot result
    labels = birch_model.labels_
    centroids = birch_model.subcluster_centers_
    n_clusters = np.unique(labels).size
    print("n_clusters : %d" % n_clusters)

    ax = fig.add_subplot(1, 2, ind + 1)
    for this_centroid, k, col in zip(centroids, range(n_clusters), colors_):
        mask = labels == k
        ax.scatter(X[mask, 0], X[mask, 1],
                   c='w', edgecolor=col, marker='.', alpha=0.5)
        ax.scatter(this_centroid[0], this_centroid[1], marker='+',
                   c='k', s=25)
    ax.set_ylim([-25, 25])
    ax.set_xlim([-25, 25])
    ax.set_autoscaley_on(False)
    ax.set_title('Birch %s' % info)

plt.show()
Here, only the dense-matrix case is considered, not the sparse one.
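For reference, the same row-wise chunking can be written more compactly with sklearn.utils.gen_batches; the following is only a sketch of the idea (the class name ChunkedBirch and the batch_size value are illustrative, not part of any API), and since the slicing is done along rows it should also work for CSR sparse input:

import numpy as np
from sklearn.cluster import Birch
from sklearn.utils import check_array, gen_batches
from sklearn.utils.extmath import safe_sparse_dot

class ChunkedBirch(Birch):
    # Sketch: process X in row batches to bound the size of the temporary
    # (batch_size, n_subclusters) distance matrix.
    def predict(self, X, batch_size=200):
        X = check_array(X, accept_sparse='csr')
        self._check_fit(X)
        argmin = np.empty(X.shape[0], dtype=int)
        for sl in gen_batches(X.shape[0], batch_size):
            reduced_distance = safe_sparse_dot(X[sl], self.subcluster_centers_.T)
            reduced_distance *= -2
            reduced_distance += self._subcluster_norms
            argmin[sl] = np.argmin(reduced_distance, axis=1)
        return self.subcluster_labels_[argmin]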
Actual Results
The clustering results of the two predict(X) methods are identical (the two panels of the plot produced by the script show the same clusters).
The memory consumption is reduced remarkably, although the time cost increases slightly. I think the amount of the reduction depends on the number of divisions (i.e. the number of smaller matrices), n_decompositon.
Line # Mem usage Increment Line Contents
================================================
18 133.2 MiB 133.2 MiB @profile
19 def predict(self, X):
20 # the original code
21 133.2 MiB 0.0 MiB X = check_array(X, accept_sparse='csr')
22 133.2 MiB 0.0 MiB self._check_fit(X)
23 391.0 MiB 257.8 MiB reduced_distance = safe_sparse_dot(X, self.subcluster_centers_.T)
24 391.0 MiB 0.0 MiB reduced_distance *= -2
25 391.0 MiB 0.0 MiB reduced_distance += self._subcluster_norms
26 391.0 MiB 0.0 MiB return self.subcluster_labels_[np.argmin(reduced_distance, axis=1)]
Birch with old predict(X) method as the final step took 5.61 seconds
n_clusters : 166
Line # Mem usage Increment Line Contents
================================================
30 150.0 MiB 150.0 MiB @profile
31 def predict(self, X):
32 150.0 MiB 0.0 MiB X = check_array(X, accept_sparse='csr')
33 150.0 MiB 0.0 MiB self._check_fit(X)
34 '''
35 try:
36 reduced_distance = safe_sparse_dot(X, self.subcluster_centers_.T) # the original code
37 reduced_distance *= -2
38 reduced_distance += self._subcluster_norms
39 return self.subcluster_labels_[np.argmin(reduced_distance, axis=1)]
40 except MemoryError:
41 '''
42 # assume that the matrix is dense
43 150.0 MiB 0.0 MiB argmin_list = np.array([], dtype=np.int)
44 150.0 MiB 0.0 MiB interval = int(np.ceil(X.shape[0] / n_decompositon))
45 153.1 MiB 0.0 MiB for index in range(0, n_decompositon - 1):
46 153.1 MiB 0.0 MiB lb = index * interval
47 153.1 MiB 0.0 MiB ub = (index + 1) * interval
48 153.1 MiB 0.2 MiB reduced_distance = safe_sparse_dot(X[lb:ub, :], self.subcluster_centers_.T)
49 153.1 MiB 0.0 MiB reduced_distance *= -2
50 153.1 MiB 0.1 MiB reduced_distance += self._subcluster_norms
51 153.1 MiB 0.2 MiB argmin_list = np.append(argmin_list, np.argmin(reduced_distance, axis=1))
52
53 153.1 MiB 0.0 MiB lb = (n_decompositon - 1) * interval
54 153.1 MiB 0.0 MiB reduced_distance = safe_sparse_dot(X[lb:X.shape[0], :], self.subcluster_centers_.T)
55 153.1 MiB 0.0 MiB reduced_distance *= -2
56 153.1 MiB 0.0 MiB reduced_distance += self._subcluster_norms
57 153.1 MiB 0.0 MiB argmin_list = np.append(argmin_list, np.argmin(reduced_distance, axis=1))
58
59 153.1 MiB 0.0 MiB return self.subcluster_labels_[argmin_list]
Birch with new predict(X) method as the final step took 5.78 seconds
n_clusters : 166
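As a rough cross-check of these numbers, a back-of-the-envelope calculation: the 257.8 MiB increment in the old predict(X) corresponds to a (200000, n_subclusters) float64 array, so n_subclusters ≈ 257.8 * 2**20 / (200000 * 8) ≈ 169, consistent with the 166 clusters reported. With n_decompositon = 1000, each chunk contains ceil(200000 / 1000) = 200 rows, so the temporary array per chunk is only about 200 * 169 * 8 bytes ≈ 0.26 MiB, which matches the ~0.2 MiB increments in the new profile.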
I think nobody has been bothered by this before, possibly because Birch is rarely used; unlike KMeans, however, it can handle clusters of uneven size.
Versions
Python: 3.7.0 (v3.7.0:1bf9cc5093, Jun 27 2018, 04:59:51) [MSC v.1914 64 bit (AMD64)] on win32
OS: Windows 10
CPU: Intel® Core™ i7-9750H @ 2.60GHz
RAM: 16GB
Python dependencies:
numpy: 1.17.0
scikit-learn: 0.21.3
scipy: 1.3.2
memory-profiler: 0.55.0
psutil: 5.6.7
matplotlib: 3.1.1
Thank you.
Hi @ImLaoBJie. I intend to fix this issue based on your idea. For the chunking of the samples matrix I will use an existing scikit-learn helper, pairwise_distances_argmin(), which builds the chunks using get_chunk_n_rows() and gen_batches() based on the working_memory setting accessible from sklearn.get_config().
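A minimal sketch of what such a chunked predict(X) could look like on top of pairwise_distances_argmin (a hypothetical illustration, not the actual PR code; the class name and the 16 MB value are placeholders, and the Euclidean argmin yields the same labels as the argmin over the reduced distances used above):

from sklearn import config_context
from sklearn.cluster import Birch
from sklearn.metrics import pairwise_distances_argmin
from sklearn.utils import check_array

class BirchChunkedSketch(Birch):
    # Sketch: pairwise_distances_argmin chunks the distance computation
    # internally according to the working_memory configuration.
    def predict(self, X):
        X = check_array(X, accept_sparse='csr')
        self._check_fit(X)
        with config_context(working_memory=16):  # cap temporary arrays at ~16 MB
            argmin = pairwise_distances_argmin(X, self.subcluster_centers_)
        return self.subcluster_labels_[argmin]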
I added a new step to your demo that uses the chunked approach: BirchChunked.
This is the performance report (attached to the issue); BirchChunked has the smallest memory increase. The pyplot display (also attached) shows that all three examples create the same clusters.
I previously posted here a long explanation about garbage collection issues I was observing while working with pairwise_distances_argmin. It turned out I was mistaken: it happened because I was running in debug mode in PyCharm. Running in release mode shows no garbage collection problems and the memory stays low, so I removed that comment. In the PR I implemented a solution with a 16 MB working_memory. I have attached the performance benchmark script and the memory_profiler log. To run the benchmark script:
python performance_16027.py SCIKIT-LEARN_HOME

Attachments: perf.log, performance_16027.txt

@jnothman, @glemaitre, @jeremiedbb: the PR is ready, and the performance benchmark attached here validates the efficiency of the chunked-memory algorithm in keeping the memory footprint low.