diff --git a/doc/api.rst b/doc/api.rst
index 57b36246..25fc8ed8 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -32,6 +32,7 @@ Clustering
 
    cluster.KMedoids
    cluster.CommonNNClustering
+   cluster.CLARA
 
 Robust
 ====================
diff --git a/doc/changelog.rst b/doc/changelog.rst
index 36aa2ecc..053b9197 100644
--- a/doc/changelog.rst
+++ b/doc/changelog.rst
@@ -4,6 +4,9 @@ Changelog
 Unreleased
 ----------
 
+- Add `CLARA` (Clustering for Large Applications), which extends k-medoids to
+  be more scalable by using a sampling approach.
+  [`#83 `_].
 - Fix `_estimator_type` for :class:`~sklearn_extra.robust` estimators. Fix misbehavior of
   scikit-learn's :class:`~sklearn.model_selection.cross_val_score` and
   :class:`~sklearn.grid_search.GridSearchCV` for :class:`~sklearn_extra.robust.RobustWeightedClassifier`
diff --git a/doc/modules/cluster.rst b/doc/modules/cluster.rst
index 62390f9b..5bf9e259 100644
--- a/doc/modules/cluster.rst
+++ b/doc/modules/cluster.rst
@@ -1,8 +1,8 @@
 .. _cluster:
 
-=====================================================
-Clustering with KMedoids and Common-nearest-neighbors
-=====================================================
+============================================================
+Clustering with KMedoids, CLARA and Common-nearest-neighbors
+============================================================
 
 .. _k_medoids:
 
 K-Medoids
@@ -82,6 +82,38 @@ when speed is an issue.
     for performing face recognition. International Journal of Soft Computing,
     Mathematics and Control, 3(3), pp 1-12.
 
+
+CLARA
+=====
+
+ :class:`CLARA` is related to the :class:`KMedoids` algorithm. CLARA
+ (Clustering for Large Applications) extends k-medoids to be more scalable
+ by using a sampling approach.
+
+ .. topic:: Examples:
+
+   * :ref:`sphx_glr_auto_examples_plot_clara_digits.py`: Applying K-Medoids and
+     CLARA on digits with various distance metrics.
+
+
+ **Algorithm description:**
+ CLARA uses random samples of the dataset, each of size `n_sampling`.
+ The algorithm is iterative: first, one sub-sample is selected and CLARA applies
+ KMedoids to it to obtain `n_clusters` medoids. At the next step, CLARA draws
+ `n_sampling` - `n_clusters` new points from the dataset; the next sub-sample
+ is composed of the best medoids found so far (with respect to the inertia on
+ the whole dataset, not only on the sub-sample), to which the newly drawn
+ points are added. KMedoids is then applied to this new sub-sample, and the
+ loop repeats until `n_sampling_iter` sub-samples have been used.
+
+
+ .. topic:: References:
+
+   * Kaufman, L. and Rousseeuw, P.J. (2008). Clustering Large Applications (Program CLARA).
+     In Finding Groups in Data (eds L. Kaufman and P.J. Rousseeuw).
+     doi:10.1002/9780470316801.ch3
+
 .. _commonnn:
 
 Common-nearest-neighbors clustering
diff --git a/examples/plot_clara_digits.py b/examples/plot_clara_digits.py
new file mode 100644
index 00000000..e1bb1f54
--- /dev/null
+++ b/examples/plot_clara_digits.py
@@ -0,0 +1,121 @@
+"""
+======================================================================
+A demo of K-Medoids vs CLARA clustering on the handwritten digits data
+======================================================================
+In this example we compare the computation time of K-Medoids and CLARA on
+the handwritten digits data.
+"""
+import numpy as np
+import matplotlib.pyplot as plt
+import time
+
+from sklearn_extra.cluster import KMedoids, CLARA
+from sklearn.datasets import load_digits
+from sklearn.decomposition import PCA
+from sklearn.preprocessing import scale
+
+print(__doc__)
+
+# Authors: Timo Erkkilä
+#          Antti Lehmussola
+#          Kornel Kiełczewski
+# License: BSD 3 clause
+
+np.random.seed(42)
+
+digits = load_digits()
+data = scale(digits.data)
+n_digits = len(np.unique(digits.target))
+
+reduced_data = PCA(n_components=2).fit_transform(data)
+
+# Step size of the mesh. Decrease to increase the quality of the VQ.
+h = 0.02  # point in the mesh [x_min, x_max] x [y_min, y_max].
+
+# Plot the decision boundary. For that, we will assign a color to each
+# point in the mesh.
+x_min, x_max = reduced_data[:, 0].min() - 1, reduced_data[:, 0].max() + 1
+y_min, y_max = reduced_data[:, 1].min() - 1, reduced_data[:, 1].max() + 1
+xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
+
+plt.figure()
+plt.clf()
+
+plt.suptitle(
+    "Comparing KMedoids and CLARA",
+    fontsize=14,
+)
+
+
+selected_models = [
+    (
+        KMedoids(metric="cosine", n_clusters=n_digits),
+        "KMedoids (cosine)",
+    ),
+    (
+        KMedoids(metric="manhattan", n_clusters=n_digits),
+        "KMedoids (manhattan)",
+    ),
+    (
+        CLARA(
+            metric="cosine",
+            n_clusters=n_digits,
+            init="heuristic",
+            n_sampling=50,
+        ),
+        "CLARA (cosine)",
+    ),
+    (
+        CLARA(
+            metric="manhattan",
+            n_clusters=n_digits,
+            init="heuristic",
+            n_sampling=50,
+        ),
+        "CLARA (manhattan)",
+    ),
+]
+
+plot_rows = int(np.ceil(len(selected_models) / 2.0))
+plot_cols = 2
+
+for i, (model, description) in enumerate(selected_models):
+
+    # Fit the model, then obtain labels for each point in the mesh,
+    # timing the combined fit + predict step.
+    init_time = time.time()
+    model.fit(reduced_data)
+    Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
+    computation_time = time.time() - init_time
+
+    # Put the result into a color plot
+    Z = Z.reshape(xx.shape)
+    plt.subplot(plot_rows, plot_cols, i + 1)
+    plt.imshow(
+        Z,
+        interpolation="nearest",
+        extent=(xx.min(), xx.max(), yy.min(), yy.max()),
+        cmap=plt.cm.Paired,
+        aspect="auto",
+        origin="lower",
+    )
+
+    plt.plot(
+        reduced_data[:, 0], reduced_data[:, 1], "k.", markersize=2, alpha=0.3
+    )
+    # Plot the centroids as a white X
+    centroids = model.cluster_centers_
+    plt.scatter(
+        centroids[:, 0],
+        centroids[:, 1],
+        marker="x",
+        s=169,
+        linewidths=3,
+        color="w",
+        zorder=10,
+    )
+    plt.title(description + ": %.2Fs" % (computation_time))
+    plt.xlim(x_min, x_max)
+    plt.ylim(y_min, y_max)
+    plt.xticks(())
+    plt.yticks(())
+
+plt.show()
diff --git a/sklearn_extra/cluster/__init__.py b/sklearn_extra/cluster/__init__.py
index 0d4cf43c..426f8b99 100644
--- a/sklearn_extra/cluster/__init__.py
+++ b/sklearn_extra/cluster/__init__.py
@@ -1,4 +1,4 @@
-from ._k_medoids import KMedoids
+from ._k_medoids import KMedoids, CLARA
 from ._commonnn import commonnn, CommonNNClustering
 
-__all__ = ["KMedoids", "CommonNNClustering", "commonnn"]
+__all__ = ["KMedoids", "CLARA", "CommonNNClustering", "commonnn"]
diff --git a/sklearn_extra/cluster/_k_medoids.py b/sklearn_extra/cluster/_k_medoids.py
index 05b1f198..cccd575c 100644
--- a/sklearn_extra/cluster/_k_medoids.py
+++ b/sklearn_extra/cluster/_k_medoids.py
@@ -20,10 +20,31 @@
 from sklearn.utils.validation import check_is_fitted
 from sklearn.exceptions import ConvergenceWarning
 
-# cython implementation of swap step in PAM algorithm.
+# cython implementation of steps in PAM algorithm.
 from ._k_medoids_helper import _compute_optimal_swap, _build
 
 
+def _compute_inertia(distances):
+    """Compute inertia of new samples. Inertia is defined as the sum of the
+    sample distances to closest cluster centers.
+
+    Parameters
+    ----------
+    distances : {array-like, sparse matrix}, shape=(n_samples, n_clusters)
+        Distances to cluster centers.
+
+    Returns
+    -------
+    Sum of sample distances to closest cluster centers.
+    """
+
+    # Define inertia as the sum of the sample-distances
+    # to closest cluster centers
+    inertia = np.sum(np.min(distances, axis=1))
+
+    return inertia
+
+
 class KMedoids(BaseEstimator, ClusterMixin, TransformerMixin):
     """k-medoids clustering.
 
@@ -43,7 +64,7 @@ class KMedoids(BaseEstimator, ClusterMixin, TransformerMixin):
     method : {'alternate', 'pam'}, default: 'alternate'
         Which algorithm to use. 'alternate' is faster while 'pam' is more accurate.
 
-    init : {'random', 'heuristic', 'k-medoids++', 'build'}, optional, default: 'build'
+    init : {'random', 'heuristic', 'k-medoids++', 'build'}, optional, default: 'heuristic'
         Specify medoid initialization method. 'random' selects n_clusters
         elements from the dataset. 'heuristic' picks the n_clusters points
         with the smallest sum distance to every other point. 'k-medoids++'
@@ -196,6 +217,7 @@ def fit(self, X, y=None):
         )
 
         D = pairwise_distances(X, metric=self.metric)
+
         medoid_idxs = self._initialize_medoids(
             D, self.n_clusters, random_state_
         )
@@ -204,7 +226,16 @@ def fit(self, X, y=None):
         if self.method == "pam":
             # Compute the distance to the first and second closest points
             # among medoids.
-            Djs, Ejs = np.sort(D[medoid_idxs], axis=0)[[0, 1]]
+
+            if self.n_clusters == 1 and self.max_iter > 0:
+                # PAM SWAP step can only be used for n_clusters > 1
+                warnings.warn(
+                    "n_clusters should be larger than 1 if max_iter != 0; "
+                    "setting max_iter to 0."
+                )
+                self.max_iter = 0
+            elif self.max_iter > 0:
+                Djs, Ejs = np.sort(D[medoid_idxs], axis=0)[[0, 1]]
 
         # Continue the algorithm as long as
         # the medoids keep changing and the maximum number
@@ -233,6 +264,7 @@ def fit(self, X, y=None):
 
                     # update Djs and Ejs with new medoids
                     Djs, Ejs = np.sort(D[medoid_idxs], axis=0)[[0, 1]]
+
         else:
             raise ValueError(
                 f"method={self.method} is not supported. Supported methods "
@@ -259,7 +291,7 @@ def fit(self, X, y=None):
         # the training data to clusters
         self.labels_ = np.argmin(D[medoid_idxs, :], axis=0)
         self.medoid_indices_ = medoid_idxs
-        self.inertia_ = self._compute_inertia(self.transform(X))
+        self.inertia_ = _compute_inertia(self.transform(X))
 
         # Return self to enable method chaining
         return self
@@ -301,7 +333,7 @@ def _update_medoid_idxs_in_place(self, D, labels, medoid_idxs):
 
     def _compute_cost(self, D, medoid_idxs):
         """Compute the cost for a given configuration of the medoids"""
-        return self._compute_inertia(D[:, medoid_idxs])
+        return _compute_inertia(D[:, medoid_idxs])
 
     def transform(self, X):
         """Transforms X to cluster-distance space.
@@ -375,26 +407,6 @@ def predict(self, X):
 
         return pd_argmin
 
-    def _compute_inertia(self, distances):
-        """Compute inertia of new samples. Inertia is defined as the sum of the
-        sample distances to closest cluster centers.
-
-        Parameters
-        ----------
-        distances : {array-like, sparse matrix}, shape=(n_samples, n_clusters)
-            Distances to cluster centers.
-
-        Returns
-        -------
-        Sum of sample distances to closest cluster centers.
-        """
-
-        # Define inertia as the sum of the sample-distances
-        # to closest cluster centers
-        inertia = np.sum(np.min(distances, axis=1))
-
-        return inertia
-
     def _initialize_medoids(self, D, n_clusters, random_state_):
         """Select initial medoids when beginning clustering."""
 
@@ -499,3 +511,239 @@ def _kpp_init(self, D, n_clusters, random_state_, n_local_trials=None):
             closest_dist_sq = best_dist_sq
 
         return centers
+
+
+class CLARA(BaseEstimator, ClusterMixin, TransformerMixin):
+    """CLARA clustering.
+
+    Read more in the :ref:`User Guide `.
+    CLARA (Clustering for Large Applications) extends the k-medoids approach
+    to a large number of objects. This algorithm uses a sampling approach.
+
+    Parameters
+    ----------
+    n_clusters : int, optional, default: 8
+        The number of clusters to form as well as the number of medoids to
+        generate.
+
+    metric : string, or callable, optional, default: 'euclidean'
+        What distance metric to use. See :func:metrics.pairwise_distances
+
+    init : {'random', 'heuristic', 'k-medoids++', 'build'}, optional, default: 'build'
+        Specify medoid initialization method. See :class:`KMedoids` for a
+        description of each method.
+
+    max_iter : int, optional, default : 300
+        Specify the maximum number of iterations when fitting PAM. It can be zero,
+        in which case only the initialization is computed.
+
+    n_sampling : int or None, optional, default : None
+        Size of the sampled dataset at each iteration. n_sampling is a trade-off
+        between complexity and efficiency. If None, then n_sampling is set
+        to min(n_samples, 40 + 2 * self.n_clusters) as suggested by the authors
+        of the algorithm. Must be smaller than n_samples.
+
+    n_sampling_iter : int, optional, default : 5
+        Number of different sub-samples to draw, i.e. the number of iterations.
+
+    random_state : int, RandomState instance or None, optional
+        Specify random state for the random number generator. Used to
+        initialise medoids when init='random'.
+
+    Attributes
+    ----------
+    cluster_centers_ : array, shape = (n_clusters, n_features)
+            or None if metric == 'precomputed'
+        Cluster centers, i.e. medoids (elements from the original dataset)
+
+    medoid_indices_ : array, shape = (n_clusters,)
+        The indices of the medoid rows in X
+
+    labels_ : array, shape = (n_samples,)
+        Labels of each point
+
+    inertia_ : float
+        Sum of distances of samples to their closest cluster center.
+
+    Examples
+    --------
+    >>> from sklearn_extra.cluster import CLARA
+    >>> import numpy as np
+    >>> from sklearn.datasets import make_blobs
+    >>> X, _ = make_blobs(centers=[[0,0],[1,1]], n_features=2,random_state=0)
+    >>> clara = CLARA(n_clusters=2, random_state=0).fit(X)
+    >>> clara.predict([[0,0], [4,4]])
+    array([0, 1])
+    >>> clara.inertia_
+    122.44919397611667
+
+    References
+    ----------
+    Kaufman, L. and Rousseeuw, P.J. (2008). Clustering Large Applications (Program CLARA).
+    In Finding Groups in Data (eds L. Kaufman and P.J. Rousseeuw).
+    doi:10.1002/9780470316801.ch3
+
+    See also
+    --------
+
+    KMedoids
+        CLARA is a variant of KMedoids that uses a sub-sampling scheme; as
+        such, if the dataset is sufficiently small, KMedoids is preferable.
+
+    Notes
+    -----
+    Contrary to KMedoids, CLARA is linear in N, the number of samples, in both
+    space and time complexity. On the other hand, it scales quadratically with
+    n_sampling.
+
+    """
+
+    def __init__(
+        self,
+        n_clusters=8,
+        metric="euclidean",
+        init="build",
+        max_iter=300,
+        n_sampling=None,
+        n_sampling_iter=5,
+        random_state=None,
+    ):
+        self.n_clusters = n_clusters
+        self.metric = metric
+        self.init = init
+        self.max_iter = max_iter
+        self.n_sampling = n_sampling
+        self.n_sampling_iter = n_sampling_iter
+        self.random_state = random_state
+
+    def fit(self, X, y=None):
+        """Fit CLARA to the provided data.
+
+        Parameters
+        ----------
+        X : array-like, shape = (n_samples, n_features), \
+            or (n_samples, n_samples) if metric == 'precomputed'
+            Dataset to cluster.
+
+        y : Ignored
+
+        Returns
+        -------
+        self
+        """
+        X = check_array(X, dtype=[np.float64, np.float32])
+        n = len(X)
+
+        random_state_ = check_random_state(self.random_state)
+
+        if self.n_sampling is None:
+            n_sampling = max(
+                min(n, 40 + 2 * self.n_clusters), self.n_clusters + 1
+            )
+        else:
+            n_sampling = self.n_sampling
+
+        # Check n_sampling.
+        if n < self.n_clusters:
+            raise ValueError(
+                "n_samples should be greater than self.n_clusters"
+            )
+
+        if self.n_clusters >= n_sampling:
+            raise ValueError(
+                "n_sampling must be strictly greater than self.n_clusters"
+            )
+
+        medoids_idxs = random_state_.choice(
+            np.arange(n), size=self.n_clusters, replace=False
+        )
+        best_score = np.inf
+        for _ in range(self.n_sampling_iter):
+            if n_sampling >= n:
+                sample_idxs = np.arange(n)
+            else:
+                # Keep the best medoids found so far in every sub-sample.
+                sample_idxs = np.hstack(
+                    [
+                        medoids_idxs,
+                        random_state_.choice(
+                            np.delete(np.arange(n), medoids_idxs),
+                            size=n_sampling - self.n_clusters,
+                            replace=False,
+                        ),
+                    ]
+                )
+            pam = KMedoids(
+                n_clusters=self.n_clusters,
+                metric=self.metric,
+                method="pam",
+                init=self.init,
+                max_iter=self.max_iter,
+                random_state=random_state_,
+            )
+            pam.fit(X[sample_idxs])
+            self.cluster_centers_ = pam.cluster_centers_
+            self.inertia_ = _compute_inertia(self.transform(X))
+
+            # Compare candidates using the inertia over the whole dataset,
+            # not only over the sub-sample.
+            if self.inertia_ < best_score:
+                best_score = self.inertia_
+                # Map medoid indices from the sub-sample back to the dataset.
+                medoids_idxs = sample_idxs[pam.medoid_indices_]
+
+        # Expose the best configuration found across all sub-samples.
+        self.medoid_indices_ = medoids_idxs
+        self.cluster_centers_ = X[medoids_idxs]
+        self.inertia_ = _compute_inertia(self.transform(X))
+        self.labels_ = np.argmin(self.transform(X), axis=1)
+        self.n_iter_ = self.n_sampling_iter
+
+        return self
+
+    def transform(self, X):
+        """Transforms X to cluster-distance space.
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix}, shape (n_query, n_features), \
+            or (n_query, n_indexed) if metric == 'precomputed'
+            Data to transform.
+
+        Returns
+        -------
+        X_new : {array-like, sparse matrix}, shape=(n_query, n_clusters)
+            X transformed in the new space of distances to cluster centers.
+        """
+        X = check_array(
+            X, accept_sparse=["csr", "csc"], dtype=[np.float64, np.float32]
+        )
+
+        if self.metric == "precomputed":
+            check_is_fitted(self, "medoid_indices_")
+            return X[:, self.medoid_indices_]
+        else:
+            check_is_fitted(self, "cluster_centers_")
+
+            Y = self.cluster_centers_
+            return pairwise_distances(X, Y=Y, metric=self.metric)
+
+    def predict(self, X):
+        """Predict the closest cluster for each sample in X.
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix}, shape (n_query, n_features), \
+            or (n_query, n_indexed) if metric == 'precomputed'
+            New data to predict.
+
+        Returns
+        -------
+        labels : array, shape = (n_query,)
+            Index of the cluster each sample belongs to.
+        """
+        X = check_array(
+            X, accept_sparse=["csr", "csc"], dtype=[np.float64, np.float32]
+        )
+
+        if self.metric == "precomputed":
+            check_is_fitted(self, "medoid_indices_")
+            return np.argmin(X[:, self.medoid_indices_], axis=1)
+        else:
+            check_is_fitted(self, "cluster_centers_")
+
+            # Assign data points to clusters based on which cluster
+            # assignment yields the smallest distance
+            return pairwise_distances_argmin(
+                X, Y=self.cluster_centers_, metric=self.metric
+            )
diff --git a/sklearn_extra/cluster/tests/test_k_medoids.py b/sklearn_extra/cluster/tests/test_k_medoids.py
index 6030241d..89742f3d 100644
--- a/sklearn_extra/cluster/tests/test_k_medoids.py
+++ b/sklearn_extra/cluster/tests/test_k_medoids.py
@@ -11,7 +11,7 @@
 
 from numpy.testing import assert_allclose, assert_array_equal
 
-from sklearn_extra.cluster import KMedoids
+from sklearn_extra.cluster import KMedoids, CLARA
 from sklearn.cluster import KMeans
 from sklearn.datasets import make_blobs
 
@@ -48,6 +48,17 @@ def test_kmedoid_results(method, init, dtype):
         assert dtype is np.dtype(km.transform(X_cc.astype(dtype)).dtype).type
 
 
+def test_clara_results():
+    expected = np.hstack([np.zeros(50), np.ones(50)])
+    km = CLARA(n_clusters=2)
+    km.fit(X_cc)
+    # This test uses data that are not perfectly separable, so the
+    # accuracy is not 1; it is around 0.85.
+    assert (np.mean(km.labels_ == expected) > 0.8) or (
+        1 - np.mean(km.labels_ == expected) > 0.8
+    )
+
+
 def test_medoids_invalid_method():
     with pytest.raises(ValueError, match="invalid is not supported"):
         KMedoids(n_clusters=1, method="invalid").fit([[0, 1], [1, 1]])
@@ -340,7 +351,7 @@ def test_kmedoids_on_sparse_input():
 # Test the build initialization.
 def test_build():
     X, y = fetch_20newsgroups_vectorized(return_X_y=True)
-    # Select only the first 1000 samples
+    # Select only the first 500 samples
     X = X[:500]
     y = y[:500]
     # Precompute cosine distance matrix
@@ -352,6 +363,26 @@ def test_build():
     assert len(np.unique(ske.labels_)) == 20
 
 
+def test_clara_consistency_iris():
+    # Test that CLARA matches PAM when the full sample is used.
+
+    rng = np.random.RandomState(seed)
+    X_iris = load_iris()["data"]
+
+    clara = CLARA(
+        n_clusters=3,
+        n_sampling_iter=1,
+        n_sampling=len(X_iris),
+        random_state=rng,
+    )
+
+    model = KMedoids(n_clusters=3, init="build", random_state=rng)
+
+    model.fit(X_iris)
+    clara.fit(X_iris)
+    assert np.sum(model.labels_ == clara.labels_) == len(X_iris)
+
+
 def test_seuclidean():
     with pytest.warns(None) as record:
         km = KMedoids(2, metric="seuclidean", method="pam")
diff --git a/sklearn_extra/tests/test_common.py b/sklearn_extra/tests/test_common.py
index 2da6cf22..3a72dc32 100644
--- a/sklearn_extra/tests/test_common.py
+++ b/sklearn_extra/tests/test_common.py
@@ -3,7 +3,7 @@
 
 from sklearn_extra.kernel_approximation import Fastfood
 from sklearn_extra.kernel_methods import EigenProClassifier, EigenProRegressor
-from sklearn_extra.cluster import KMedoids, CommonNNClustering
+from sklearn_extra.cluster import KMedoids, CommonNNClustering, CLARA
 from sklearn_extra.robust import (
     RobustWeightedClassifier,
     RobustWeightedRegressor,
@@ -14,6 +14,7 @@
 ALL_ESTIMATORS = [
     Fastfood,
     KMedoids,
+    CLARA,
     EigenProClassifier,
     EigenProRegressor,
     CommonNNClustering,
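
Reviewer note: the sketch below is not part of the patch. It restates the sampling loop from `CLARA.fit` (and the algorithm description added to `doc/modules/cluster.rst`) as a minimal standalone script, useful for checking the logic independently of the estimator plumbing. It assumes only NumPy, scikit-learn, and the existing `KMedoids` estimator; the data and the variable names mirroring the new parameters (`n_sampling`, `n_sampling_iter`) are illustrative, not the shipped implementation.

```python
# Minimal sketch of the CLARA sampling loop (illustrative, not the shipped code).
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.metrics import pairwise_distances
from sklearn_extra.cluster import KMedoids

X, _ = make_blobs(n_samples=500, centers=3, random_state=0)
n_clusters = 3
n_sampling = 40 + 2 * n_clusters  # default suggested by Kaufman & Rousseeuw
n_sampling_iter = 5
rng = np.random.RandomState(0)

medoid_idxs = rng.choice(len(X), size=n_clusters, replace=False)
best_score = np.inf
for _ in range(n_sampling_iter):
    # Each sub-sample keeps the best medoids found so far, plus fresh points.
    others = np.delete(np.arange(len(X)), medoid_idxs)
    sample_idxs = np.hstack(
        [medoid_idxs, rng.choice(others, size=n_sampling - n_clusters, replace=False)]
    )
    pam = KMedoids(n_clusters=n_clusters, method="pam", init="build").fit(X[sample_idxs])
    # Score candidates on the WHOLE dataset, not just the sub-sample.
    inertia = pairwise_distances(X, pam.cluster_centers_).min(axis=1).sum()
    if inertia < best_score:
        best_score = inertia
        # Map sub-sample-relative medoid indices back to dataset indices.
        medoid_idxs = sample_idxs[pam.medoid_indices_]

labels = pairwise_distances(X, X[medoid_idxs]).argmin(axis=1)
print("best inertia:", best_score, "medoid indices:", medoid_idxs)
```

With `n_sampling = len(X)` and `n_sampling_iter = 1` the loop degenerates to a single PAM run on the full dataset, which is exactly the behaviour `test_clara_consistency_iris` checks.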