From 5a42b168fd223b9f195b3519addbb6b050f4efc3 Mon Sep 17 00:00:00 2001 From: TomNicholas Date: Fri, 31 May 2024 13:18:19 -0600 Subject: [PATCH 1/2] use Matt's sparse regridding function --- ndpyramid/regrid.py | 162 +++++++++++++++++++++++++++++++++++++++----- ndpyramid/utils.py | 1 + 2 files changed, 145 insertions(+), 18 deletions(-) diff --git a/ndpyramid/regrid.py b/ndpyramid/regrid.py index b43e44a..146500a 100644 --- a/ndpyramid/regrid.py +++ b/ndpyramid/regrid.py @@ -177,10 +177,9 @@ def generate_weights_pyramid( plevels['/'] = root return dt.DataTree.from_dict(plevels) - -def pyramid_regrid( +def pyramid_regrid_sparse( ds: xr.Dataset, - projection:typing.Literal['web-mercator', 'equidistant-cylindrical'] = 'web-mercator', + projection: typing.Literal['web-mercator', 'equidistant-cylindrical'] = 'web-mercator', target_pyramid: dt.DataTree = None, levels: int = None, weights_pyramid: dt.DataTree = None, @@ -189,7 +188,6 @@ def pyramid_regrid( regridder_apply_kws: dict = None, other_chunks: dict = None, pixels_per_tile: int = 128, - ) -> dt.DataTree: """Make a pyramid using xesmf's regridders @@ -222,7 +220,7 @@ def pyramid_regrid( pyramid : dt.DataTree Multiscale data pyramid """ - import xesmf as xe + #import xesmf as xe if target_pyramid is None: if levels is not None: @@ -236,21 +234,30 @@ def pyramid_regrid( regridder_kws = {'periodic': True, **regridder_kws} # multiscales spec - save_kwargs = locals() - del save_kwargs['ds'] - del save_kwargs['target_pyramid'] - del save_kwargs['xe'] - del save_kwargs['weights_pyramid'] + projection_model = Projection(name=projection) + save_kwargs = { + 'levels': levels, + 'pixels_per_tile': pixels_per_tile, + 'projection': projection, + 'other_chunks': other_chunks, + 'method': method, + 'regridder_kws': regridder_kws, + 'regridder_apply_kws': regridder_apply_kws, + } attrs = { 'multiscales': multiscales_template( - datasets=[{'path': str(i)} for i in range(levels)], + datasets=[ + {'path': str(i), 'level': i, 'crs': projection_model._crs} for i in range(levels) + ], type='reduce', method='pyramid_regrid', version=get_version(), kwargs=save_kwargs, ) } + save_kwargs.pop('levels') + save_kwargs.pop('other_chunks') # set up pyramid @@ -261,7 +268,8 @@ def pyramid_regrid( grid = target_pyramid[str(level)].ds.load() # get the regridder object if weights_pyramid is None: - regridder = xe.Regridder(ds, grid, method, **regridder_kws) + #regridder = xe.Regridder(ds, grid, method, **regridder_kws) + raise NotImplementedError("This requires xESMF, which we're trying to avoid") else: # Reconstruct weights into format that xESMF understands # this is a hack that assumes the weights were generated by @@ -269,21 +277,139 @@ def pyramid_regrid( ds_w = weights_pyramid[str(level)].ds weights = _reconstruct_xesmf_weights(ds_w) - regridder = xe.Regridder( - ds, grid, method, reuse_weights=True, weights=weights, **regridder_kws - ) + #regridder = xe.Regridder( + # ds, grid, method, reuse_weights=True, weights=weights, **regridder_kws + #) + # regrid if regridder_apply_kws is None: regridder_apply_kws = {} - regridder_apply_kws = {**{'keep_attrs': True}, **regridder_apply_kws} - plevels[str(level)] = regridder(ds, **regridder_apply_kws) + #regridder_apply_kws = {**{'keep_attrs': True}, **regridder_apply_kws} + + plevels[str(level)] = xr_regridder(ds, grid, weights, out_grid_shape=(grid.sizes['x'], grid.sizes['y'])) + + level_attrs = { + 'multiscales': multiscales_template( + datasets=[{'path': '.', 'level': level, 'crs': projection_model._crs}], + type='reduce', + method='pyramid_regrid', + version=get_version(), + kwargs=save_kwargs, + ) + } + plevels[str(level)].attrs['multiscales'] = level_attrs['multiscales'] root = xr.Dataset(attrs=attrs) plevels['/'] = root pyramid = dt.DataTree.from_dict(plevels) pyramid = add_metadata_and_zarr_encoding( - pyramid, levels=levels, other_chunks=other_chunks, pixels_per_tile=pixels_per_tile, projection=Projection(name=projection) + pyramid, + levels=levels, + other_chunks=other_chunks, + pixels_per_tile=pixels_per_tile, + projection=Projection(name=projection), ) return pyramid + + +def xr_regridder( + ds: xr.Dataset, + grid: xr.Dataset, + weights: xr.DataArray, + out_grid_shape: tuple[int, int], +) -> xr.Dataset: + """ + Xarray-aware regridding function that uses weights from xESMF but performs the regridding using sparse matrix multiplication. + + Parameters + ---------- + ds + weights + out_grid_shape + + Returns + ------- + regridded_ds + """ + + latlon_dims = ['nlat', 'nlon'] + + shape_in = (ds.sizes['nlat'], ds.sizes['nlon']) + shape_out = out_grid_shape + + # make sure coords along non-core dims are propagated + # (this is probably superfluous now we're using xr.apply_ufunc) + non_lateral_dims = [d for d in ds.dims if d not in latlon_dims] + coords_to_copy = {d: ds.coords[d] for d in non_lateral_dims if d in ds.coords} + + regridded_ds = xr.apply_ufunc( + esmf_apply_weights, + weights, + ds, + input_core_dims=[['out_dim', 'in_dim'], latlon_dims], + output_core_dims=[latlon_dims], + exclude_dims=set(latlon_dims), + kwargs={'shape_in': shape_in, 'shape_out': shape_out}, + keep_attrs=True, + ).rename_dims(nlon='x', nlat='y') + + # add coordinates for new grid (i.e. along core dims) + regridded_ds_with_new_grid_coords = xr.merge([regridded_ds, grid]) + + return regridded_ds_with_new_grid_coords.assign_coords(coords_to_copy) + + +def esmf_apply_weights(weights, indata, shape_in, shape_out): + """ + Apply regridding weights to data. + + Parameters + ---------- + A : scipy sparse COO matrix + indata : numpy array of shape ``(..., n_lat, n_lon)`` or ``(..., n_y, n_x)``. + Should be C-ordered. Will be then tranposed to F-ordered. + shape_in, shape_out : tuple of two integers + Input/output data shape for unflatten operation. + For rectilinear grid, it is just ``(n_lat, n_lon)``. + Returns + ------- + outdata : numpy array of shape ``(..., shape_out[0], shape_out[1])``. + Extra dimensions are the same as `indata`. + If input data is C-ordered, output will also be C-ordered. + """ + + # COO matrix is fast with F-ordered array but slow with C-array, so we + # take in a C-ordered and then transpose) + # (CSR or CRS matrix is fast with C-ordered array but slow with F-array) + if not indata.flags["C_CONTIGUOUS"]: + warnings.warn("Input array is not C_CONTIGUOUS. " + "Will affect performance.") + + # get input shape information + shape_horiz = indata.shape[-2:] + extra_shape = indata.shape[0:-2] + + assert shape_horiz == shape_in, ( + "The horizontal shape of input data is {}, different from that of" + "the regridder {}!".format(shape_horiz, shape_in) + ) + + assert shape_in[0] * shape_in[1] == weights.shape[1], ( + "ny_in * nx_in should equal to weights.shape[1]" + ) + + assert shape_out[0] * shape_out[1] == weights.shape[0], ( + "ny_out * nx_out should equal to weights.shape[0]" + ) + + # use flattened array for dot operation + indata_flat = indata.reshape(-1, shape_in[0]*shape_in[1]) + outdata_flat = weights.dot(indata_flat.T).T + + # unflattened output array + outdata = outdata_flat.reshape( + [*extra_shape, shape_out[0], shape_out[1]]) + + return outdata diff --git a/ndpyramid/utils.py b/ndpyramid/utils.py index 54cc1b1..7b41bf4 100644 --- a/ndpyramid/utils.py +++ b/ndpyramid/utils.py @@ -147,6 +147,7 @@ def add_metadata_and_zarr_encoding( pyramid.ds.attrs['multiscales'][0]['datasets'][level]['pixels_per_tile'] = pixels_per_tile if projection: pyramid.ds.attrs['multiscales'][0]['datasets'][level]['crs'] = projection._crs + # set dataset chunks pyramid[slevel].ds = pyramid[slevel].ds.chunk(chunks) From 56aeea3382f6c34688c7d0b369bc8d14db7f5c17 Mon Sep 17 00:00:00 2001 From: Tom Nicholas Date: Mon, 3 Jun 2024 11:20:19 -0600 Subject: [PATCH 2/2] dask parallelization --- ndpyramid/regrid.py | 35 +++++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/ndpyramid/regrid.py b/ndpyramid/regrid.py index 146500a..bc6c15e 100644 --- a/ndpyramid/regrid.py +++ b/ndpyramid/regrid.py @@ -2,6 +2,7 @@ import itertools import typing +import warnings import datatree as dt import numpy as np @@ -177,7 +178,7 @@ def generate_weights_pyramid( plevels['/'] = root return dt.DataTree.from_dict(plevels) -def pyramid_regrid_sparse( +def pyramid_regrid( ds: xr.Dataset, projection: typing.Literal['web-mercator', 'equidistant-cylindrical'] = 'web-mercator', target_pyramid: dt.DataTree = None, @@ -277,14 +278,10 @@ def pyramid_regrid_sparse( ds_w = weights_pyramid[str(level)].ds weights = _reconstruct_xesmf_weights(ds_w) - #regridder = xe.Regridder( - # ds, grid, method, reuse_weights=True, weights=weights, **regridder_kws - #) # regrid if regridder_apply_kws is None: regridder_apply_kws = {} - #regridder_apply_kws = {**{'keep_attrs': True}, **regridder_apply_kws} plevels[str(level)] = xr_regridder(ds, grid, weights, out_grid_shape=(grid.sizes['x'], grid.sizes['y'])) @@ -339,6 +336,8 @@ def xr_regridder( shape_in = (ds.sizes['nlat'], ds.sizes['nlon']) shape_out = out_grid_shape + output_sizes = {'nlat': out_grid_shape[0], 'nlon': out_grid_shape[1]} + # make sure coords along non-core dims are propagated # (this is probably superfluous now we're using xr.apply_ufunc) non_lateral_dims = [d for d in ds.dims if d not in latlon_dims] @@ -352,6 +351,9 @@ def xr_regridder( output_core_dims=[latlon_dims], exclude_dims=set(latlon_dims), kwargs={'shape_in': shape_in, 'shape_out': shape_out}, + dask='parallelized', + dask_gufunc_kwargs={'output_sizes': output_sizes}, + output_dtypes=[np.float32], # bug in xarray here where you can't pass output_dtypes via dask_gufunc_kwargs keep_attrs=True, ).rename_dims(nlon='x', nlat='y') @@ -391,18 +393,23 @@ def esmf_apply_weights(weights, indata, shape_in, shape_out): shape_horiz = indata.shape[-2:] extra_shape = indata.shape[0:-2] - assert shape_horiz == shape_in, ( - "The horizontal shape of input data is {}, different from that of" - "the regridder {}!".format(shape_horiz, shape_in) + if shape_horiz != shape_in: + raise ValueError( + f"The horizontal shape of input data is {shape_horiz}, different from that of" + f"the regridder {shape_in}!" ) - assert shape_in[0] * shape_in[1] == weights.shape[1], ( - "ny_in * nx_in should equal to weights.shape[1]" - ) + n_points_in = shape_in[0] * shape_in[1] + if n_points_in != weights.shape[1]: + raise ValueError( + f"ny_in * nx_in should equal to weights.shape[1], but found {n_points_in} vs {weights.shape[1]}" + ) - assert shape_out[0] * shape_out[1] == weights.shape[0], ( - "ny_out * nx_out should equal to weights.shape[0]" - ) + n_points_out = shape_out[0] * shape_out[1] + if n_points_out != weights.shape[0]: + raise ValueError( + f"ny_out * nx_out should equal to weights.shape[0], but found {n_points_out} vs {weights.shape[0]}" + ) # use flattened array for dot operation indata_flat = indata.reshape(-1, shape_in[0]*shape_in[1])