diff --git a/tests/test_RuLSIF.py b/tests/test_RuLSIF.py index de83d2e..dbe7029 100644 --- a/tests/test_RuLSIF.py +++ b/tests/test_RuLSIF.py @@ -1,7 +1,7 @@ import unittest from scipy.stats import norm, multivariate_normal -from numpy import linspace +from numpy import linspace, mgrid from .context import densratio @@ -45,7 +45,9 @@ def test_alphadensratio_2d(self): y = multivariate_normal.rvs(size=300, mean=[1, 1], cov=[[1./2, 0], [0, 2]], random_state=71) result = densratio(x, y, alpha=0.5) self.assertIsNotNone(result) - density_ratio = result.compute_density_ratio(linspace(-1, 3)) + space_range = slice(-1, 3, 50j) + space_2d = mgrid[space_range, space_range].reshape(2, -1).T + density_ratio = result.compute_density_ratio(space_2d) self.assertTrue((density_ratio >= 0).all()) def test_densratio_dimension_error(self):