Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sparse regridding without using ESMF #130

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 150 additions & 17 deletions ndpyramid/regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import itertools
import typing
import warnings

import datatree as dt
import numpy as np
Expand Down Expand Up @@ -177,10 +178,9 @@ def generate_weights_pyramid(
plevels['/'] = root
return dt.DataTree.from_dict(plevels)


def pyramid_regrid(
ds: xr.Dataset,
projection:typing.Literal['web-mercator', 'equidistant-cylindrical'] = 'web-mercator',
projection: typing.Literal['web-mercator', 'equidistant-cylindrical'] = 'web-mercator',
target_pyramid: dt.DataTree = None,
levels: int = None,
weights_pyramid: dt.DataTree = None,
Expand All @@ -189,7 +189,6 @@ def pyramid_regrid(
regridder_apply_kws: dict = None,
other_chunks: dict = None,
pixels_per_tile: int = 128,

) -> dt.DataTree:
"""Make a pyramid using xesmf's regridders

Expand Down Expand Up @@ -222,7 +221,7 @@ def pyramid_regrid(
pyramid : dt.DataTree
Multiscale data pyramid
"""
import xesmf as xe
#import xesmf as xe

if target_pyramid is None:
if levels is not None:
Expand All @@ -236,21 +235,30 @@ def pyramid_regrid(
regridder_kws = {'periodic': True, **regridder_kws}

# multiscales spec
save_kwargs = locals()
del save_kwargs['ds']
del save_kwargs['target_pyramid']
del save_kwargs['xe']
del save_kwargs['weights_pyramid']
projection_model = Projection(name=projection)
save_kwargs = {
'levels': levels,
'pixels_per_tile': pixels_per_tile,
'projection': projection,
'other_chunks': other_chunks,
'method': method,
'regridder_kws': regridder_kws,
'regridder_apply_kws': regridder_apply_kws,
}

attrs = {
'multiscales': multiscales_template(
datasets=[{'path': str(i)} for i in range(levels)],
datasets=[
{'path': str(i), 'level': i, 'crs': projection_model._crs} for i in range(levels)
],
type='reduce',
method='pyramid_regrid',
version=get_version(),
kwargs=save_kwargs,
)
}
save_kwargs.pop('levels')
save_kwargs.pop('other_chunks')

# set up pyramid

Expand All @@ -261,29 +269,154 @@ def pyramid_regrid(
grid = target_pyramid[str(level)].ds.load()
# get the regridder object
if weights_pyramid is None:
regridder = xe.Regridder(ds, grid, method, **regridder_kws)
#regridder = xe.Regridder(ds, grid, method, **regridder_kws)
raise NotImplementedError("This requires xESMF, which we're trying to avoid")
else:
# Reconstruct weights into format that xESMF understands
# this is a hack that assumes the weights were generated by
# the `generate_weights_pyramid` function

ds_w = weights_pyramid[str(level)].ds
weights = _reconstruct_xesmf_weights(ds_w)
regridder = xe.Regridder(
ds, grid, method, reuse_weights=True, weights=weights, **regridder_kws
)

# regrid
if regridder_apply_kws is None:
regridder_apply_kws = {}
regridder_apply_kws = {**{'keep_attrs': True}, **regridder_apply_kws}
plevels[str(level)] = regridder(ds, **regridder_apply_kws)

plevels[str(level)] = xr_regridder(ds, grid, weights, out_grid_shape=(grid.sizes['x'], grid.sizes['y']))

level_attrs = {
'multiscales': multiscales_template(
datasets=[{'path': '.', 'level': level, 'crs': projection_model._crs}],
type='reduce',
method='pyramid_regrid',
version=get_version(),
kwargs=save_kwargs,
)
}
plevels[str(level)].attrs['multiscales'] = level_attrs['multiscales']

root = xr.Dataset(attrs=attrs)
plevels['/'] = root
pyramid = dt.DataTree.from_dict(plevels)

pyramid = add_metadata_and_zarr_encoding(
pyramid, levels=levels, other_chunks=other_chunks, pixels_per_tile=pixels_per_tile, projection=Projection(name=projection)
pyramid,
levels=levels,
other_chunks=other_chunks,
pixels_per_tile=pixels_per_tile,
projection=Projection(name=projection),
)

return pyramid


def xr_regridder(
ds: xr.Dataset,
grid: xr.Dataset,
weights: xr.DataArray,
out_grid_shape: tuple[int, int],
) -> xr.Dataset:
"""
Comment on lines +314 to +320
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is very neat, @TomNicholas! thank you for this addition.

we should definitely figure out how to generalize this...

Cc @norlandrhagen / @maxrjones

Xarray-aware regridding function that uses weights from xESMF but performs the regridding using sparse matrix multiplication.

Parameters
----------
ds
weights
out_grid_shape

Returns
-------
regridded_ds
"""

latlon_dims = ['nlat', 'nlon']

shape_in = (ds.sizes['nlat'], ds.sizes['nlon'])
shape_out = out_grid_shape
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This out_grid_shape argument is probably superfluous - you could just read it from the grid object.


output_sizes = {'nlat': out_grid_shape[0], 'nlon': out_grid_shape[1]}

# make sure coords along non-core dims are propagated
# (this is probably superfluous now we're using xr.apply_ufunc)
non_lateral_dims = [d for d in ds.dims if d not in latlon_dims]
coords_to_copy = {d: ds.coords[d] for d in non_lateral_dims if d in ds.coords}

regridded_ds = xr.apply_ufunc(
esmf_apply_weights,
weights,
ds,
input_core_dims=[['out_dim', 'in_dim'], latlon_dims],
output_core_dims=[latlon_dims],
exclude_dims=set(latlon_dims),
kwargs={'shape_in': shape_in, 'shape_out': shape_out},
dask='parallelized',
dask_gufunc_kwargs={'output_sizes': output_sizes},
output_dtypes=[np.float32], # bug in xarray here where you can't pass output_dtypes via dask_gufunc_kwargs
keep_attrs=True,
).rename_dims(nlon='x', nlat='y')

# add coordinates for new grid (i.e. along core dims)
regridded_ds_with_new_grid_coords = xr.merge([regridded_ds, grid])

return regridded_ds_with_new_grid_coords.assign_coords(coords_to_copy)


def esmf_apply_weights(weights, indata, shape_in, shape_out):
"""
Apply regridding weights to data.

Parameters
----------
A : scipy sparse COO matrix
indata : numpy array of shape ``(..., n_lat, n_lon)`` or ``(..., n_y, n_x)``.
Should be C-ordered. Will be then tranposed to F-ordered.
shape_in, shape_out : tuple of two integers
Input/output data shape for unflatten operation.
For rectilinear grid, it is just ``(n_lat, n_lon)``.
Returns
-------
outdata : numpy array of shape ``(..., shape_out[0], shape_out[1])``.
Extra dimensions are the same as `indata`.
If input data is C-ordered, output will also be C-ordered.
"""

# COO matrix is fast with F-ordered array but slow with C-array, so we
# take in a C-ordered and then transpose)
# (CSR or CRS matrix is fast with C-ordered array but slow with F-array)
if not indata.flags["C_CONTIGUOUS"]:
warnings.warn("Input array is not C_CONTIGUOUS. "
"Will affect performance.")

# get input shape information
shape_horiz = indata.shape[-2:]
extra_shape = indata.shape[0:-2]

if shape_horiz != shape_in:
raise ValueError(
f"The horizontal shape of input data is {shape_horiz}, different from that of"
f"the regridder {shape_in}!"
)

n_points_in = shape_in[0] * shape_in[1]
if n_points_in != weights.shape[1]:
raise ValueError(
f"ny_in * nx_in should equal to weights.shape[1], but found {n_points_in} vs {weights.shape[1]}"
)

n_points_out = shape_out[0] * shape_out[1]
if n_points_out != weights.shape[0]:
raise ValueError(
f"ny_out * nx_out should equal to weights.shape[0], but found {n_points_out} vs {weights.shape[0]}"
)

# use flattened array for dot operation
indata_flat = indata.reshape(-1, shape_in[0]*shape_in[1])
outdata_flat = weights.dot(indata_flat.T).T

# unflattened output array
outdata = outdata_flat.reshape(
[*extra_shape, shape_out[0], shape_out[1]])

return outdata
1 change: 1 addition & 0 deletions ndpyramid/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def add_metadata_and_zarr_encoding(
pyramid.ds.attrs['multiscales'][0]['datasets'][level]['pixels_per_tile'] = pixels_per_tile
if projection:
pyramid.ds.attrs['multiscales'][0]['datasets'][level]['crs'] = projection._crs

# set dataset chunks
pyramid[slevel].ds = pyramid[slevel].ds.chunk(chunks)

Expand Down