-
Notifications
You must be signed in to change notification settings - Fork 42
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add CLARA Clustering algorithm (#83)
* add CLARA * add example * fix typo * add doc * fix docstring * add CLARA to test_common * add size check to pass tests * fix tests * update doc * add test consistency clara kmedoids * black * handle types KMedoids * Apply suggestions from code review Co-authored-by: Roman Yurchak <[email protected]> * correct 32 bit * change name variables * create private function inertia and changelog Co-authored-by: Roman Yurchak <[email protected]>
- Loading branch information
1 parent
445aaf8
commit 5c47ba2
Showing
8 changed files
with
470 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,6 +32,7 @@ Clustering | |
|
||
cluster.KMedoids | ||
cluster.CommonNNClustering | ||
cluster.CLARA | ||
|
||
Robust | ||
==================== | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
""" | ||
====================================================================== | ||
A demo of K-Medoids vs CLARA clustering on the handwritten digits data | ||
====================================================================== | ||
In this example we compare different computation time of K-Medoids and CLARA on | ||
the handwritten digits data. | ||
""" | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
import time | ||
|
||
from sklearn_extra.cluster import KMedoids, CLARA | ||
from sklearn.datasets import load_digits | ||
from sklearn.decomposition import PCA | ||
from sklearn.preprocessing import scale | ||
|
||
print(__doc__) | ||
|
||
# Authors: Timo Erkkilä <[email protected]> | ||
# Antti Lehmussola <[email protected]> | ||
# Kornel Kiełczewski <[email protected]> | ||
# License: BSD 3 clause | ||
|
||
np.random.seed(42) | ||
|
||
digits = load_digits() | ||
data = scale(digits.data) | ||
n_digits = len(np.unique(digits.target)) | ||
|
||
reduced_data = PCA(n_components=2).fit_transform(data) | ||
|
||
# Step size of the mesh. Decrease to increase the quality of the VQ. | ||
h = 0.02 # point in the mesh [x_min, m_max]x[y_min, y_max]. | ||
|
||
# Plot the decision boundary. For that, we will assign a color to each | ||
x_min, x_max = reduced_data[:, 0].min() - 1, reduced_data[:, 0].max() + 1 | ||
y_min, y_max = reduced_data[:, 1].min() - 1, reduced_data[:, 1].max() + 1 | ||
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h)) | ||
|
||
plt.figure() | ||
plt.clf() | ||
|
||
plt.suptitle( | ||
"Comparing KMedoids and CLARA", | ||
fontsize=14, | ||
) | ||
|
||
|
||
selected_models = [ | ||
( | ||
KMedoids(metric="cosine", n_clusters=n_digits), | ||
"KMedoids (cosine)", | ||
), | ||
( | ||
KMedoids(metric="manhattan", n_clusters=n_digits), | ||
"KMedoids (manhattan)", | ||
), | ||
( | ||
CLARA( | ||
metric="cosine", | ||
n_clusters=n_digits, | ||
init="heuristic", | ||
n_sampling=50, | ||
), | ||
"CLARA (cosine)", | ||
), | ||
( | ||
CLARA( | ||
metric="manhattan", | ||
n_clusters=n_digits, | ||
init="heuristic", | ||
n_sampling=50, | ||
), | ||
"CLARA (manhattan)", | ||
), | ||
] | ||
|
||
plot_rows = int(np.ceil(len(selected_models) / 2.0)) | ||
plot_cols = 2 | ||
|
||
for i, (model, description) in enumerate(selected_models): | ||
|
||
# Obtain labels for each point in mesh. Use last trained model. | ||
init_time = time.time() | ||
model.fit(reduced_data) | ||
Z = model.predict(np.c_[xx.ravel(), yy.ravel()]) | ||
computation_time = time.time() - init_time | ||
|
||
# Put the result into a color plot | ||
Z = Z.reshape(xx.shape) | ||
plt.subplot(plot_cols, plot_rows, i + 1) | ||
plt.imshow( | ||
Z, | ||
interpolation="nearest", | ||
extent=(xx.min(), xx.max(), yy.min(), yy.max()), | ||
cmap=plt.cm.Paired, | ||
aspect="auto", | ||
origin="lower", | ||
) | ||
|
||
plt.plot( | ||
reduced_data[:, 0], reduced_data[:, 1], "k.", markersize=2, alpha=0.3 | ||
) | ||
# Plot the centroids as a white X | ||
centroids = model.cluster_centers_ | ||
plt.scatter( | ||
centroids[:, 0], | ||
centroids[:, 1], | ||
marker="x", | ||
s=169, | ||
linewidths=3, | ||
color="w", | ||
zorder=10, | ||
) | ||
plt.title(description + ": %.2Fs" % (computation_time)) | ||
plt.xlim(x_min, x_max) | ||
plt.ylim(y_min, y_max) | ||
plt.xticks(()) | ||
plt.yticks(()) | ||
|
||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
from ._k_medoids import KMedoids | ||
from ._k_medoids import KMedoids, CLARA | ||
from ._commonnn import commonnn, CommonNNClustering | ||
|
||
__all__ = ["KMedoids", "CommonNNClustering", "commonnn"] | ||
__all__ = ["KMedoids", "CLARA", "CommonNNClustering", "commonnn"] |
Oops, something went wrong.