diff --git a/doc/over_sampling.rst b/doc/over_sampling.rst
index 6159e925b..bc6bf66d0 100644
--- a/doc/over_sampling.rst
+++ b/doc/over_sampling.rst
@@ -152,8 +152,9 @@ nearest neighbors class. Those variants are presented in the figure below.
    :align: center
 
-The :class:`BorderlineSMOTE` [HWB2005]_, :class:`SVMSMOTE` [NCK2009]_, and
-:class:`KMeansSMOTE` [LDB2017]_ offer some variant of the SMOTE algorithm::
+The :class:`BorderlineSMOTE` [HWB2005]_, :class:`SVMSMOTE` [NCK2009]_,
+:class:`KMeansSMOTE` [LDB2017]_, and :class:`SafeLevelSMOTE` [BSL2009]_
+offer variants of the SMOTE algorithm::
 
     >>> from imblearn.over_sampling import BorderlineSMOTE
     >>> X_resampled, y_resampled = BorderlineSMOTE().fit_resample(X, y)
@@ -213,6 +214,14 @@ other extra interpolation.
       Imbalanced Learning Based on K-Means and SMOTE"
      https://arxiv.org/abs/1711.00837
 
+   .. [BSL2009] C. Bunkhumpornpat, K. Sinapiromsaran, C. Lursinsap,
+      "Safe-level-SMOTE: Safe-level-synthetic minority over-sampling
+      technique for handling the class imbalanced problem," In:
+      Theeramunkong T., Kijsirikul B., Cercone N., Ho TB. (eds)
+      Advances in Knowledge Discovery and Data Mining. PAKDD 2009.
+      Lecture Notes in Computer Science, vol 5476. Springer, Berlin,
+      Heidelberg, 475-482, 2009.
+
 Mathematical formulation
 ========================
 
@@ -274,6 +283,11 @@ parameter ``m_neighbors`` to decide if a sample is in danger, safe, or noise.
 method before to apply SMOTE. The clustering will group samples
 together and generate new samples depending of the cluster density.
 
+**SafeLevel** SMOTE --- cf. to :class:`SafeLevelSMOTE` --- uses the safe level
+(the number of positive instances among the nearest neighbors) to generate a
+synthetic instance. Compared to regular SMOTE, the new instance is positioned
+closer to whichever of the two parent instances has the larger safe level.
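+
+Concretely, the position of the new sample along the interpolation segment is
+derived from the ratio of the two safe levels. Writing :math:`sl_p` and
+:math:`sl_n` for the safe levels of the seed instance and of the selected
+neighbor, the implementation draws the interpolation gap uniformly from
+:math:`[0, sl_n / sl_p]` when :math:`sl_p \geq sl_n`, uniformly from
+:math:`[1 - sl_p / sl_n, 1]` when :math:`sl_p < sl_n`, and sets it to 0 (the
+seed is duplicated) when :math:`sl_n = 0`.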
+
 ADASYN works similarly to the regular SMOTE. However, the number of samples
 generated for each :math:`x_i` is proportional to the number of samples which
 are not from the same class than :math:`x_i` in a given
diff --git a/imblearn/over_sampling/__init__.py b/imblearn/over_sampling/__init__.py
index bd20b76ea..8027b18a2 100644
--- a/imblearn/over_sampling/__init__.py
+++ b/imblearn/over_sampling/__init__.py
@@ -10,6 +10,7 @@
 from ._smote import KMeansSMOTE
 from ._smote import SVMSMOTE
 from ._smote import SMOTENC
+from ._smote import SafeLevelSMOTE
 
 __all__ = [
     "ADASYN",
@@ -19,4 +20,5 @@
     "BorderlineSMOTE",
     "SVMSMOTE",
     "SMOTENC",
+    "SafeLevelSMOTE",
 ]
diff --git a/imblearn/over_sampling/_smote.py b/imblearn/over_sampling/_smote.py
index 483b0c720..09d695885 100644
--- a/imblearn/over_sampling/_smote.py
+++ b/imblearn/over_sampling/_smote.py
@@ -284,6 +284,11 @@ class BorderlineSMOTE(BaseSMOTE):
 
     SVMSMOTE : Over-sample using SVM-SMOTE variant.
 
+    KMeansSMOTE : Over-sample using KMeans-SMOTE variant.
+
+    SafeLevelSMOTE : Over-sample using SafeLevel-SMOTE variant.
+
     ADASYN : Over-sample using ADASYN.
 
     References
@@ -484,6 +489,10 @@ class SVMSMOTE(BaseSMOTE):
 
     BorderlineSMOTE : Over-sample using Borderline-SMOTE.
 
+    KMeansSMOTE : Over-sample using KMeans-SMOTE variant.
+
+    SafeLevelSMOTE : Over-sample using SafeLevel-SMOTE variant.
+
     ADASYN : Over-sample using ADASYN.
 
     References
@@ -586,12 +595,14 @@ def _fit_resample(self, X, y):
             n_generated_samples = int(fractions * (n_samples + 1))
             if np.count_nonzero(danger_bool) > 0:
                 nns = self.nn_k_.kneighbors(
-                    _safe_indexing(support_vector, np.flatnonzero(danger_bool)),
+                    _safe_indexing(
+                        support_vector, np.flatnonzero(danger_bool)),
                     return_distance=False,
                 )[:, 1:]
 
                 X_new_1, y_new_1 = self._make_samples(
-                    _safe_indexing(support_vector, np.flatnonzero(danger_bool)),
+                    _safe_indexing(
+                        support_vector, np.flatnonzero(danger_bool)),
                     y.dtype,
                     class_sample,
                     X_class,
@@ -602,12 +613,14 @@
 
             if np.count_nonzero(safety_bool) > 0:
                 nns = self.nn_k_.kneighbors(
-                    _safe_indexing(support_vector, np.flatnonzero(safety_bool)),
+                    _safe_indexing(
+                        support_vector, np.flatnonzero(safety_bool)),
                     return_distance=False,
                 )[:, 1:]
 
                 X_new_2, y_new_2 = self._make_samples(
-                    _safe_indexing(support_vector, np.flatnonzero(safety_bool)),
+                    _safe_indexing(
+                        support_vector, np.flatnonzero(safety_bool)),
                     y.dtype,
                     class_sample,
                     X_class,
@@ -691,6 +704,10 @@ class SMOTE(BaseSMOTE):
 
     SVMSMOTE : Over-sample using the SVM-SMOTE variant.
 
+    KMeansSMOTE : Over-sample using KMeans-SMOTE variant.
+
+    SafeLevelSMOTE : Over-sample using SafeLevel-SMOTE variant.
+
     ADASYN : Over-sample using ADASYN.
 
     References
@@ -860,6 +877,10 @@ class SMOTENC(SMOTE):
 
     BorderlineSMOTE : Over-sample using Borderline-SMOTE variant.
 
+    KMeansSMOTE : Over-sample using KMeans-SMOTE variant.
+
+    SafeLevelSMOTE : Over-sample using SafeLevel-SMOTE variant.
+
     ADASYN : Over-sample using ADASYN.
 
     References
@@ -1308,3 +1329,321 @@ def _fit_resample(self, X, y):
             y_resampled = np.hstack((y_resampled, y_new))
 
         return X_resampled, y_resampled
+
+
+@Substitution(
+    sampling_strategy=BaseOverSampler._sampling_strategy_docstring,
+    random_state=_random_state_docstring,
+)
+class SafeLevelSMOTE(BaseSMOTE):
+    """Class to perform over-sampling using safe-level SMOTE.
+
+    This is an implementation of the Safe-level-SMOTE described in [2]_.
+
+    Parameters
+    ----------
+    {sampling_strategy}
+
+    {random_state}
+
+    k_neighbors : int or object, optional (default=5)
+        If ``int``, number of nearest neighbours used to construct synthetic
+        samples. If object, an estimator that inherits from
+        :class:`sklearn.neighbors.base.KNeighborsMixin` that will be used to
+        find the k_neighbors.
+
+    m_neighbors : int or object, optional (default=10)
+        If ``int``, number of nearest neighbours used to determine the safe
+        level of an instance. If object, an estimator that inherits from
+        :class:`sklearn.neighbors.base.KNeighborsMixin` that will be used
+        to find the m_neighbors.
+
+    n_jobs : int or None, optional (default=None)
+        Number of CPU cores used when parallelizing the nearest-neighbors
+        searches. ``None`` means 1 unless in a
+        :obj:`joblib.parallel_backend` context. ``-1`` means using all
+        processors. See
+        `Glossary <https://scikit-learn.org/stable/glossary.html#term-n_jobs>`_
+        for more details.
+
+    Notes
+    -----
+    See the original paper [2]_ for more details.
+
+    Supports multi-class resampling. A one-vs.-rest scheme is used as
+    originally proposed in [1]_.
+
+    See also
+    --------
+    SMOTE : Over-sample using SMOTE.
+
+    SMOTENC : Over-sample using SMOTE for continuous and categorical features.
+
+    SVMSMOTE : Over-sample using SVM-SMOTE variant.
+
+    BorderlineSMOTE : Over-sample using Borderline-SMOTE.
+
+    KMeansSMOTE : Over-sample using KMeans-SMOTE variant.
+
+    ADASYN : Over-sample using ADASYN.
+
+    References
+    ----------
+    .. [1] N. V. Chawla, K. W. Bowyer, L. O. Hall, W. P. Kegelmeyer, "SMOTE:
+       synthetic minority over-sampling technique," Journal of Artificial
+       Intelligence Research, 321-357, 2002.
+
+    .. [2] C. Bunkhumpornpat, K. Sinapiromsaran, C. Lursinsap, "Safe-level-
+       SMOTE: Safe-level-synthetic minority over-sampling technique for
+       handling the class imbalanced problem," In: Theeramunkong T.,
+       Kijsirikul B., Cercone N., Ho TB. (eds) Advances in Knowledge
+       Discovery and Data Mining. PAKDD 2009. Lecture Notes in Computer
+       Science, vol 5476. Springer, Berlin, Heidelberg, 475-482, 2009.
+
+    Examples
+    --------
+
+    >>> from collections import Counter
+    >>> from sklearn.datasets import make_classification
+    >>> from imblearn.over_sampling import \
+SafeLevelSMOTE # doctest: +NORMALIZE_WHITESPACE
+    >>> X, y = make_classification(n_classes=2, class_sep=2,
+    ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0,
+    ... n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10)
+    >>> print('Original dataset shape %s' % Counter(y))
+    Original dataset shape Counter({{1: 900, 0: 100}})
+    >>> sm = SafeLevelSMOTE(random_state=42)
+    >>> X_res, y_res = sm.fit_resample(X, y)
+    >>> print('Resampled dataset shape %s' % Counter(y_res))
+    Resampled dataset shape Counter({{0: 900, 1: 900}})
+
+    """
+
+    def __init__(self,
+                 sampling_strategy='auto',
+                 random_state=None,
+                 k_neighbors=5,
+                 m_neighbors=10,
+                 n_jobs=None):
+        super().__init__(sampling_strategy=sampling_strategy,
+                         random_state=random_state,
+                         k_neighbors=k_neighbors,
+                         n_jobs=n_jobs)
+        self.m_neighbors = m_neighbors
+
+    def _assign_safe_levels(self, nn_estimator, samples, target_class, y):
+        """Assign a safe level to each instance in ``samples``.
+
+        Parameters
+        ----------
+        nn_estimator : estimator
+            An estimator that inherits from
+            :class:`sklearn.neighbors.base.KNeighborsMixin`. It provides
+            the nearest neighbors used to determine the safe levels.
+
+        samples : {array-like, sparse matrix}, shape (n_samples, n_features)
+            The samples to which the safe levels are assigned.
+
+        target_class : int or str
+            The class currently being over-sampled.
+
+        y : array-like, shape (n_samples,)
+            The true labels, needed to compute the safe levels.
+
+        Returns
+        -------
+        output : ndarray, shape (n_samples,)
+            An ndarray whose values are the safe levels of the instances
+            in the target class.
+        """
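+        # the safe level of a sample is the number of its m nearest
+        # neighbours (the sample itself excluded, hence the [:, 1:]) that
+        # belong to the class currently being over-sampled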
+        x = nn_estimator.kneighbors(samples, return_distance=False)[:, 1:]
+        nn_label = (y[x] == target_class).astype(int)
+        safe_levels = np.sum(nn_label, axis=1)
+        return safe_levels
+
+    def _validate_estimator(self):
+        super()._validate_estimator()
+        self.nn_m_ = check_neighbors_object('m_neighbors', self.m_neighbors,
+                                            additional_neighbor=1)
+        self.nn_m_.set_params(**{"n_jobs": self.n_jobs})
+
+    def _fit_resample(self, X, y):
+        self._validate_estimator()
+
+        X_resampled = X.copy()
+        y_resampled = y.copy()
+
+        for class_sample, n_samples in self.sampling_strategy_.items():
+            if n_samples == 0:
+                continue
+            target_class_indices = np.flatnonzero(y == class_sample)
+            X_class = _safe_indexing(X, target_class_indices)
+
+            self.nn_m_.fit(X)
+            safe_levels = self._assign_safe_levels(
+                self.nn_m_, X_class, class_sample, y)
+
+            # keep only the points in X_class with a safe level > 0;
+            # a point with safe level 0 is not used to generate
+            # synthetic instances
+            X_safe_indices = np.flatnonzero(safe_levels != 0)
+            X_safe_class = _safe_indexing(X_class, X_safe_indices)
+
+            self.nn_k_.fit(X_class)
+            nns = self.nn_k_.kneighbors(X_safe_class,
+                                        return_distance=False)[:, 1:]
+
+            sl_safe_class = safe_levels[X_safe_indices]
+            sl_nns = safe_levels[nns]
+            sl_safe_t = np.array([sl_safe_class]).transpose()
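+            # a neighbour with safe level 0 yields an infinite ratio;
+            # _generate_gap turns an infinite ratio into a gap of 0, so
+            # the synthetic point coincides with the safe seed instance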
+            with np.errstate(divide='ignore'):
+                safe_level_ratio = np.divide(sl_safe_t, sl_nns)
+
+            X_new, y_new = self._make_samples_safelevel(X_safe_class, y.dtype,
+                                                        class_sample, X_class,
+                                                        nns, n_samples,
+                                                        safe_level_ratio,
+                                                        1.0)
+
+            if sparse.issparse(X_new):
+                X_resampled = sparse.vstack([X_resampled, X_new])
+            else:
+                X_resampled = np.vstack((X_resampled, X_new))
+            y_resampled = np.hstack((y_resampled, y_new))
+
+        return X_resampled, y_resampled
+
+    def _make_samples_safelevel(self, X, y_dtype, y_type, nn_data, nn_num,
+                                n_samples, safe_level_ratio, step_size=1.):
+        """A support function that returns artificial samples using
+        safe-level SMOTE. It is similar to the ``_make_samples`` method of
+        SMOTE.
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix}, shape (n_samples_safe, n_features)
+            Points from which the new points will be created.
+
+        y_dtype : dtype
+            The data type of the targets.
+
+        y_type : str or int
+            The minority target value, just so the function can return the
+            target values for the synthetic samples with correct length in
+            a clear format.
+
+        nn_data : ndarray, shape (n_samples_all, n_features)
+            Data set carrying all the neighbours to be used.
+
+        nn_num : ndarray, shape (n_samples_safe, k_nearest_neighbours)
+            The nearest neighbours of each sample in `nn_data`.
+
+        n_samples : int
+            The number of samples to generate.
+
+        safe_level_ratio : ndarray, shape \
+(n_samples_safe, k_nearest_neighbours)
+            The ratio between the safe level of each seed sample and the
+            safe levels of its nearest neighbours.
+
+        step_size : float, optional (default=1.)
+            The step size to create samples.
+
+        Returns
+        -------
+        X_new : {ndarray, sparse matrix}, shape (n_samples_new, n_features)
+            Synthetically generated samples using the safe-level method.
+
+        y_new : ndarray, shape (n_samples_new,)
+            Target values for synthetic samples.
+        """
+        random_state = check_random_state(self.random_state)
+        samples_indices = random_state.randint(low=0,
+                                               high=len(nn_num.flatten()),
+                                               size=n_samples)
+        # decode the flat indices into (seed row, neighbour column) pairs
+        rows = np.floor_divide(samples_indices, nn_num.shape[1])
+        cols = np.mod(samples_indices, nn_num.shape[1])
+        gap_array = step_size * self._vgenerate_gap(safe_level_ratio)
+        gaps = gap_array.flatten()[samples_indices]
+
+        y_new = np.array([y_type] * n_samples, dtype=y_dtype)
+
+        if sparse.issparse(X):
+            row_indices, col_indices, samples = [], [], []
+            for i, (row, col, gap) in enumerate(zip(rows, cols, gaps)):
+                if X[row].nnz:
+                    sample = self._generate_sample(
+                        X, nn_data, nn_num, row, col, gap)
+                    row_indices += [i] * len(sample.indices)
+                    col_indices += sample.indices.tolist()
+                    samples += sample.data.tolist()
+            return (
+                sparse.csr_matrix(
+                    (samples, (row_indices, col_indices)),
+                    [len(samples_indices), X.shape[1]],
+                    dtype=X.dtype,
+                ),
+                y_new,
+            )
+        else:
+            X_new = np.zeros((n_samples, X.shape[1]), dtype=X.dtype)
+            for i, (row, col, gap) in enumerate(zip(rows, cols, gaps)):
+                X_new[i] = self._generate_sample(X, nn_data, nn_num,
+                                                 row, col, gap)
+            return X_new, y_new
+
+    def _generate_gap(self, a_ratio, rand_state=None):
+        """Generate a gap according to the safe-level ratio; non-vectorized
+        version.
+
+        Parameters
+        ----------
+        a_ratio : float
+            safe_level_ratio of a single data point.
+
+        rand_state : int, RandomState instance or None, optional
+            Seed or random number generator used to draw the gap.
+
+        Returns
+        -------
+        gap : float
+            A number between 0 and 1.
+        """
+        random_state = check_random_state(rand_state)
+        if np.isinf(a_ratio):
+            # the neighbour has safe level 0: duplicate the seed instance
+            gap = 0
+        elif a_ratio >= 1:
+            # the seed is at least as safe as the neighbour: stay close
+            # to the seed
+            gap = random_state.uniform(0, 1 / a_ratio)
+        else:
+            # the neighbour is safer: place the new sample close to it
+            gap = random_state.uniform(1 - a_ratio, 1)
+        return gap
+
+    def _vgenerate_gap(self, safe_level_ratio):
+        """Generate gaps according to safe_level_ratio; vectorized version
+        of ``_generate_gap``.
+
+        Parameters
+        ----------
+        safe_level_ratio : ndarray, shape \
+(n_samples_safe, k_nearest_neighbours)
+            safe_level_ratio of all instances with safe level > 0 in the
+            specified class.
+
+        Returns
+        -------
+        gap_array : ndarray, shape (n_samples_safe, k_nearest_neighbours)
+            The gap for all instances with safe level > 0 in the specified
+            class.
+        """
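+        # draw an integer seed for every entry so the element-wise gap
+        # generation is reproducible, then apply the scalar _generate_gap
+        # over each (ratio, seed) pair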
+        prng = check_random_state(self.random_state)
+        rand_state = prng.randint(
+            safe_level_ratio.size + 1, size=safe_level_ratio.shape)
+        vgap = np.vectorize(self._generate_gap)
+        gap_array = vgap(safe_level_ratio, rand_state)
+        return gap_array
diff --git a/imblearn/over_sampling/tests/test_safelevel_smote.py b/imblearn/over_sampling/tests/test_safelevel_smote.py
new file mode 100644
index 000000000..ad4a3c8aa
--- /dev/null
+++ b/imblearn/over_sampling/tests/test_safelevel_smote.py
@@ -0,0 +1,80 @@
+import pytest
+import numpy as np
+from collections import Counter
+
+from sklearn.neighbors import NearestNeighbors
+from scipy import sparse
+
+from sklearn.utils._testing import assert_allclose
+from sklearn.utils._testing import assert_array_equal
+
+from imblearn.over_sampling import SafeLevelSMOTE
+
+
+def data_np():
+    rng = np.random.RandomState(42)
+    X = rng.randn(20, 2)
+    y = np.array([0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0])
+    return X, y
+
+
+def data_sparse(format):
+    X = sparse.random(20, 2, density=0.3, format=format, random_state=42)
+    y = np.array([0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0])
+    return X, y
+
+
+@pytest.mark.parametrize(
+    "data",
+    [data_np(), data_sparse('csr'), data_sparse('csc')]
+)
+def test_safelevel_smote(data):
+    y_gt = np.array([0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0,
+                     1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0])
+    X, y = data
+    safelevel_smote = SafeLevelSMOTE(random_state=42)
+    X_res, y_res = safelevel_smote.fit_resample(X, y)
+
+    assert X_res.shape == (24, 2)
+    assert_array_equal(y_res, y_gt)
+
+
+def test_sl_smote_nn():
+    X, y = data_np()
+    safelevel_smote = SafeLevelSMOTE(random_state=42)
+    safelevel_smote_nn = SafeLevelSMOTE(
+        random_state=42,
+        k_neighbors=NearestNeighbors(n_neighbors=6),
+        m_neighbors=NearestNeighbors(n_neighbors=11),
+    )
+
+    X_res_1, y_res_1 = safelevel_smote.fit_resample(X, y)
+    X_res_2, y_res_2 = safelevel_smote_nn.fit_resample(X, y)
+
+    assert_allclose(X_res_1, X_res_2)
+    assert_array_equal(y_res_1, y_res_2)
+
+
+def test_sl_smote_pd():
+    pd = pytest.importorskip("pandas")
+    X, y = data_np()
+    X_pd = pd.DataFrame(X)
+    safelevel_smote = SafeLevelSMOTE(random_state=42)
+    X_res, y_res = safelevel_smote.fit_resample(X, y)
+    X_res_pd, y_res_pd = safelevel_smote.fit_resample(X_pd, y)
+
+    # compare as arrays so the check holds whether a DataFrame or an
+    # ndarray is returned for DataFrame input
+    assert_allclose(np.asarray(X_res_pd), X_res)
+    assert_allclose(y_res_pd, y_res)
+
+
+def test_sl_smote_multiclass():
+    rng = np.random.RandomState(42)
+    X = rng.randn(50, 2)
+    y = np.array([0] * 10 + [1] * 15 + [2] * 25)
+    safelevel_smote = SafeLevelSMOTE(random_state=42)
+    X_res, y_res = safelevel_smote.fit_resample(X, y)
+
+    count_y_res = Counter(y_res)
+    assert count_y_res[0] == 25
+    assert count_y_res[1] == 25
+    assert count_y_res[2] == 25
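+
+
+def test_sl_smote_gap_bounds():
+    # the gap drawn by the private helper must follow the safe-level rule:
+    # within [0, 1/ratio] when ratio >= 1, within [1 - ratio, 1] when
+    # ratio < 1, and exactly 0 for an infinite ratio
+    safelevel_smote = SafeLevelSMOTE(random_state=42)
+    for ratio in [0.2, 0.5, 1.0, 2.0, 5.0]:
+        gap = safelevel_smote._generate_gap(ratio, rand_state=0)
+        if ratio >= 1:
+            assert 0 <= gap <= 1 / ratio
+        else:
+            assert 1 - ratio <= gap <= 1
+    assert safelevel_smote._generate_gap(np.inf) == 0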