Erase features in 3D from 3D mask #35

Open · wants to merge 2 commits into main
41 changes: 41 additions & 0 deletions src/fidder/erase/cli.py
@@ -7,6 +7,7 @@
from typer import Option

from .erase import erase_masked_region as _erase_masked_region
from .erase import erase_masked_region_3d as _erase_masked_region_3d
from ..utils import get_pixel_spacing_from_header
from .._cli import cli, OPTION_PROMPT_KWARGS as PKWARGS

@@ -54,3 +55,43 @@ def erase_masked_region(
voxel_size=pixel_spacing,
overwrite=True,
)


@cli.command(name="erase_3d", no_args_is_help=True)
def erase_masked_region_3d(
input_image: Path = Option(
default=...,
help="Image file in MRC format.",
**PKWARGS
),
input_mask: Path = Option(
default=...,
help="Mask file in MRC format.",
**PKWARGS
),
output_image: Path = Option(
default=...,
help="Output file in MRC format.",
**PKWARGS
),
):
"""Erase a masked region in a cryo-EM image."""
volume = torch.as_tensor(mrcfile.read(input_image)).squeeze().float()
mask = torch.as_tensor(mrcfile.read(input_mask), dtype=torch.bool).squeeze()
if volume.shape != mask.shape:
raise ValueError('Shape mismatch between data in volume and mask files.')

erased_volume = _erase_masked_region_3d(
volume=volume,
mask=mask,
background_intensity_model_resolution=(8, 8, 8),
background_intensity_model_samples=25000,
)

pixel_spacing = get_pixel_spacing_from_header(input_image)
mrcfile.write(
name=output_image,
data=np.array(erased_volume, dtype=np.float32),
voxel_size=pixel_spacing,
overwrite=True,
)
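
As a quick usage sketch for review (not part of the diff): the new command can be exercised with Typer's test runner, assuming the top-level Typer app is the `cli` object imported from `fidder._cli` above; the MRC paths are placeholders.

from typer.testing import CliRunner
from fidder._cli import cli

runner = CliRunner()
result = runner.invoke(
    cli,
    [
        "erase_3d",
        "--input-image", "tomogram.mrc",  # placeholder input volume
        "--input-mask", "mask.mrc",  # placeholder binary mask
        "--output-image", "erased.mrc",  # placeholder output path
    ],
)
assert result.exit_code == 0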
61 changes: 59 additions & 2 deletions src/fidder/erase/erase.py
@@ -4,8 +4,8 @@
import torch
from einops import einops

from ..utils import estimate_background_std
from .sparse_local_mean import estimate_local_mean
from ..utils import estimate_background_std, estimate_background_std_3d
from .sparse_local_mean import estimate_local_mean, estimate_local_mean_3d


def erase_masked_region(
@@ -108,3 +108,60 @@ def _erase_single_image(
size=n_pixels_to_inpaint)
inpainted_image[idx_foreground] += torch.as_tensor(noise)
return inpainted_image


def erase_masked_region_3d(
volume: torch.Tensor,
mask: torch.Tensor,
background_intensity_model_resolution: Tuple[int, int, int] = (5, 5, 5),
background_intensity_model_samples: int = 20000,
) -> torch.Tensor:
"""Inpaint image(s) with gaussian noise.


Parameters
----------
image: torch.Tensor
`(b, h, w)` or `(h, w)` array containing image data for erase.
mask: torch.Tensor
`(b, h, w)` or `(h, w)` binary mask separating foreground from background pixels.
Foreground pixels (1) will be inpainted.
background_intensity_model_resolution: Tuple[int, int]
Number of points in each image dimension for the background mean model.
Minimum of two points in each dimension.
background_intensity_model_samples: int
Number of sample points used to determine the model of the background mean.

Returns
-------
inpainted_image: torch.Tensor
`(b, h, w)` or `(h, w)` array containing image data inpainted in the foreground pixels of the mask
with gaussian noise matching the local mean and global standard deviation of the image
for background pixels.
"""
volume = torch.as_tensor(volume)
mask = torch.as_tensor(mask, dtype=torch.bool)
if volume.shape != mask.shape:
raise ValueError("image shape must match mask shape.")

inpainted = torch.clone(volume)
local_mean = estimate_local_mean_3d(
volume=volume,
mask=torch.logical_not(mask),
resolution=background_intensity_model_resolution,
n_samples_for_fit=background_intensity_model_samples,
)

# fill foreground pixels with local mean
idx_foreground = torch.argwhere(mask)
idx_foreground = (idx_foreground[:, 0], idx_foreground[:, 1], idx_foreground[:, 2])

inpainted[idx_foreground] = local_mean[idx_foreground]

# add noise with mean=0, std=background std estimate
background_std = estimate_background_std_3d(volume, mask)
n_voxels_to_inpaint = idx_foreground[0].shape[0]
noise = np.random.normal(loc=0, scale=background_std, size=n_voxels_to_inpaint)
inpainted[idx_foreground] += torch.as_tensor(noise)

return torch.as_tensor(inpainted, dtype=torch.float32)
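
A minimal sketch of calling the new function directly (assuming the module path `fidder.erase.erase` from this file): foreground voxels are replaced, background voxels pass through unchanged.

import torch
from fidder.erase.erase import erase_masked_region_3d

volume = torch.randn(64, 64, 64)
mask = torch.zeros(64, 64, 64, dtype=torch.bool)
mask[24:40, 24:40, 24:40] = True  # cube of "fiducial" voxels to erase

erased = erase_masked_region_3d(volume=volume, mask=mask)
assert erased.shape == volume.shape
assert torch.equal(erased[~mask], volume[~mask])  # background untouched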
70 changes: 70 additions & 0 deletions src/fidder/erase/sparse_local_mean.py
@@ -3,6 +3,7 @@
import numpy as np
import torch
from scipy.interpolate import LSQBivariateSpline
from torch_cubic_spline_grids.b_spline_grids import CubicBSplineGrid3d


def estimate_local_mean(
@@ -59,3 +60,72 @@ def estimate_local_mean(
x = np.arange(image.shape[-1])
local_mean = background_model(y, x, grid=True)
return torch.tensor(local_mean, dtype=input_dtype)


def estimate_local_mean_3d(
volume: torch.Tensor,
mask: Optional[torch.Tensor] = None,
resolution: Tuple[int, int, int] = (5, 5, 5),
n_samples_for_fit: int = 20000,
):
"""Estimate local mean of an image with a bivariate cubic spline.

A mask can be provided to

Parameters
----------
image: torch.Tensor
`(h, w)` array containing image data.
mask: Optional[torch.Tensor]
`(h, w)` array containing a binary mask specifying foreground
and background pixels for the estimation.
resolution: Tuple[int, int]
Resolution of the local mean estimate in each dimension.
n_samples_for_fit: int
Number of samples taken from foreground pixels for background mean estimation.
The number of background pixels will be used if this number is greater than the
number of background pixels.

Returns
-------
local_mean: torch.Tensor
`(h, w)` array containing a local estimate of the local mean.
"""
input_dtype = volume.dtype
volume = volume.numpy()
mask = np.ones_like(volume) if mask is None else mask.numpy()

# get a random set of foreground voxels for the background fit
foreground_sample_idx = np.argwhere(mask == 1)

n_samples_for_fit = min(n_samples_for_fit, len(foreground_sample_idx))
selection = np.random.choice(
foreground_sample_idx.shape[0], size=n_samples_for_fit, replace=False
)
foreground_sample_idx = foreground_sample_idx[selection]
z, y, x = foreground_sample_idx[:, 0], foreground_sample_idx[:, 1], foreground_sample_idx[:, 2]

w = torch.as_tensor(volume[(z, y, x)])

grid = CubicBSplineGrid3d(resolution=resolution)
optimiser = torch.optim.Adam(grid.parameters(), lr=0.01)

foreground_sample_idx_rescaled = torch.tensor(
foreground_sample_idx / np.array(volume.shape), dtype=torch.float32
)
for i in range(500):
# what does the model predict for our observations?
prediction = grid(foreground_sample_idx_rescaled).squeeze()

# zero gradients and calculate loss between observations and model prediction
optimiser.zero_grad()
loss = torch.sum((prediction - w)**2)**0.5

# backpropagate loss and update values at points on grid
loss.backward()
optimiser.step()

# evaluate the fitted grid over the full volume
tz = torch.linspace(0, 1, volume.shape[0])
ty = torch.linspace(0, 1, volume.shape[1])
tx = torch.linspace(0, 1, volume.shape[2])
zz, yy, xx = torch.meshgrid(tz, ty, tx, indexing='ij')
local_mean = grid(torch.stack((zz, yy, xx), dim=-1)).detach().numpy().reshape(volume.shape)
return torch.tensor(local_mean, dtype=input_dtype)
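
To illustrate the fit in isolation, a sketch assuming only that `torch-cubic-spline-grids` is installed: fit the grid to samples of a smooth field on the unit cube with the same loop as above, then query it anywhere in [0, 1]^3.

import torch
from torch_cubic_spline_grids.b_spline_grids import CubicBSplineGrid3d

points = torch.rand(5000, 3)  # coordinates rescaled to [0, 1]^3
values = torch.sin(2 * torch.pi * points[:, 0]) + points[:, 1] * points[:, 2]

grid = CubicBSplineGrid3d(resolution=(5, 5, 5))
optimiser = torch.optim.Adam(grid.parameters(), lr=0.01)
for _ in range(500):
    prediction = grid(points).squeeze()  # model prediction at sample points
    optimiser.zero_grad()
    loss = torch.mean((prediction - values) ** 2)
    loss.backward()
    optimiser.step()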
47 changes: 47 additions & 0 deletions src/fidder/utils.py
@@ -101,6 +101,28 @@ def central_crop_2d(image: torch.Tensor, percentage: float = 25) -> torch.Tensor
return image[..., hf:hc, wf:wc]


def central_crop_3d(image: torch.Tensor, percentage: float = 25) -> torch.Tensor:
"""Get a central crop of (a batch of) 2D image(s).

Parameters
----------
image: torch.Tensor
`(b, h, w)` or `(h, w)` array of 2D images.
percentage: float
percentage of image height and width for cropped region.
Returns
-------
cropped_image: torch.Tensor
`(b, h, w)` or `(h, w)` array of cropped 2D images.
"""
h, w, d = image.shape[-3], image.shape[-2], image.shape[-1]
mh, mw, md = h // 2, w // 2, d // 2
dh, dw, dd = int(h * (percentage / 100 / 2)), int(w * (percentage / 100 / 2)), int(d * (percentage / 100 / 2))
hf, wf, df = mh - dh, mw - dw, md - dd
hc, wc, dc = mh + dh, mw + dw, md + dd
return image[..., hf:hc, wf:wc, df:dc]


def estimate_background_std(image: torch.Tensor, mask: torch.Tensor):
"""Estimate the standard deviation of the background from a central crop.

@@ -120,6 +142,31 @@ def estimate_background_std(image: torch.Tensor, mask: torch.Tensor):
return torch.std(image[mask == 0])


def estimate_background_std_3d(image: torch.Tensor, mask: torch.Tensor):
"""Estimate the standard deviation of the background from a central crop.

Parameters
----------
image: torch.Tensor
`(d, h, w)` array containing data for which the background standard deviation
will be estimated.
mask: torch.Tensor of 0 or 1
Binary mask separating foreground and background.

Returns
-------
standard_deviation: float
estimated standard deviation for the background.
"""
image = central_crop_3d(image, percentage=25).float()
mask = central_crop_3d(mask, percentage=25)
return float(torch.std(image[mask == 0]))


def get_pixel_spacing_from_header(image: Path) -> float:
with mrcfile.open(image, header_only=True, permissive=True) as mrc:
return float(mrc.voxel_size.x)
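
A quick sketch of the new crop semantics (assuming `fidder.utils` as the import path): with `percentage=25`, each cropped side is `2 * int(side * 0.125)` voxels.

import torch
from fidder.utils import central_crop_3d

volume = torch.zeros(100, 100, 100)
crop = central_crop_3d(volume, percentage=25)
print(crop.shape)  # torch.Size([24, 24, 24])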