diff --git a/src/skmatter/preprocessing/_data.py b/src/skmatter/preprocessing/_data.py index 07160dea4..a805e7460 100644 --- a/src/skmatter/preprocessing/_data.py +++ b/src/skmatter/preprocessing/_data.py @@ -524,9 +524,15 @@ def fit(self, Knm, Kmm, y=None, sample_weight=None): if self.with_trace: Knm_centered = Knm - self.K_fit_rows_ - Khat = Knm_centered @ np.linalg.pinv(Kmm, self.rcond) @ Knm_centered.T + # The following is more correctly written as Knm @ Kmm^{-1} @ Knm.T + # but has been changed to Knm.T @ Knm @ Kmm^{-1} to avoid the memory + # overload often caused by storing n x n matrices. This is fine + # for the following trace, but should not be used for other operations. + Khat_trace = np.trace( + Knm_centered.T @ Knm_centered @ np.linalg.pinv(Kmm, self.rcond) + ) - self.scale_ = np.sqrt(np.trace(Khat) / Knm.shape[0]) + self.scale_ = Khat_trace / Knm.shape[0] else: self.scale_ = 1.0