Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
jojoelfe committed Aug 26, 2024
1 parent e841de2 commit 411db34
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 91 deletions.
132 changes: 52 additions & 80 deletions src/ttfsc/_cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from enum import Enum
from pathlib import Path
from typing import Annotated, Optional

Expand All @@ -8,12 +7,9 @@
from rich import print as rprint
from torch_fourier_shell_correlation import fsc

cli = typer.Typer(name="ttfsc", no_args_is_help=True, add_completion=False)

from ._masking import Masking

class Masking(str, Enum):
none = "none"
sphere = "sphere"
cli = typer.Typer(name="ttfsc", no_args_is_help=True, add_completion=False)


@cli.command(no_args_is_help=True)
Expand Down Expand Up @@ -44,17 +40,21 @@ def ttfsc_cli(
str, typer.Option("--plot-matplotlib-style", rich_help_panel="Plotting options")
] = "default",
mask: Annotated[Masking, typer.Option("--mask", rich_help_panel="Masking options")] = Masking.none,
mask_radius: Annotated[float, typer.Option("--mask-radius", rich_help_panel="Masking options")] = 100.0,
mask_soft_edge_width: Annotated[int, typer.Option("--mask-soft-edge-width", rich_help_panel="Masking options")] = 10,
mask_radius_angstroms: Annotated[
float, typer.Option("--mask-radius-angstroms", rich_help_panel="Masking options")
] = 100.0,
mask_soft_edge_width_pixels: Annotated[
int, typer.Option("--mask-soft-edge-width-pixels", rich_help_panel="Masking options")
] = 10,
correct_for_masking: Annotated[
bool, typer.Option("--correct-for-masking", rich_help_panel="Masking correction options")
] = True,
correct_from_resolution: Annotated[
Optional[float], typer.Option("--correct-from_resolution", rich_help_panel="Masking correction options")
] = 10.0,
] = None,
correct_from_fraction_of_estimated_resolution: Annotated[
float, typer.Option("--correct-from-fraction-of-estimated-resolution", rich_help_panel="Masking correction options")
] = 0.25,
] = 0.5,
) -> None:
with mrcfile.open(map1) as f:
map1_tensor = torch.tensor(f.data)
Expand All @@ -67,82 +67,52 @@ def ttfsc_cli(
resolution_angstroms = (1 / frequency_pixels) * pixel_spacing_angstroms

fsc_values_unmasked = fsc(map1_tensor, map2_tensor)
fsc_values_masked = None

estimated_resolution_frequency_pixel = float(frequency_pixels[(fsc_values_unmasked < fsc_threshold).nonzero()[0] - 1])
estimated_resolution_angstrom = float(resolution_angstroms[(fsc_values_unmasked < fsc_threshold).nonzero()[0] - 1])

if mask == Masking.sphere:
import numpy as np
from ttmask.box_setup import box_setup
from ttmask.soft_edge import add_soft_edge
# Taken from https://github.com/teamtomo/ttmask/blob/main/src/ttmask/sphere.py

# establish our coordinate system and empty mask
coordinates_centered, mask_tensor = box_setup(map1_tensor.shape[0])

# determine distances of each pixel to the center
distance_to_center = np.linalg.norm(coordinates_centered, axis=-1)
rprint(f"Estimated resolution using {fsc_threshold} criterion in unmasked map: {estimated_resolution_angstrom:.2f} Å")

# set up criteria for which pixels are inside the sphere and modify values to 1.
inside_sphere = distance_to_center < (mask_radius / pixel_spacing_angstroms)
mask_tensor[inside_sphere] = 1

# if requested, a soft edge is added to the mask
mask_tensor = add_soft_edge(mask_tensor, mask_soft_edge_width)

if correct_for_masking:
from torch_grid_utils import fftfreq_grid
fsc_values_masked = None
if mask != Masking.none:
from ._masking import calculate_masked_fsc

map1_tensor_randomized = torch.fft.rfftn(map1_tensor)
map2_tensor_randomized = torch.fft.rfftn(map2_tensor)
frequency_grid = fftfreq_grid(
image_shape=map1_tensor.shape,
rfft=True,
fftshift=False,
norm=True,
device=map1_tensor_randomized.device,
(estimated_resolution_angstrom, estimated_resolution_frequency_pixel, fsc_values_masked, mask_tensor) = (
calculate_masked_fsc(
map1_tensor,
map2_tensor,
pixel_spacing_angstroms=pixel_spacing_angstroms,
fsc_threshold=fsc_threshold,
mask=mask,
mask_radius_angstroms=mask_radius_angstroms,
mask_soft_edge_width_pixels=mask_soft_edge_width_pixels,
)
if correct_from_resolution is not None:
to_correct = frequency_grid > (1 / correct_from_resolution) / pixel_spacing_angstroms
else:
to_correct = (
frequency_grid > correct_from_fraction_of_estimated_resolution * estimated_resolution_frequency_pixel
)
# Rotate phases at frequencies higher than 0.25
random_phases1 = torch.rand(frequency_grid[to_correct].shape) * 2 * torch.pi
random_phases1 = torch.complex(torch.cos(random_phases1), torch.sin(random_phases1))
random_phases2 = torch.rand(frequency_grid[to_correct].shape) * 2 * torch.pi
random_phases2 = torch.complex(torch.cos(random_phases2), torch.sin(random_phases2))

map1_tensor_randomized[to_correct] *= random_phases1
map2_tensor_randomized[to_correct] *= random_phases2

map1_tensor_randomized = torch.fft.irfftn(map1_tensor_randomized)
map2_tensor_randomized = torch.fft.irfftn(map2_tensor_randomized)

map1_tensor_randomized *= mask_tensor
map2_tensor_randomized *= mask_tensor
fsc_values_masked_randomized = fsc(map1_tensor_randomized, map2_tensor_randomized)
map1_tensor = map1_tensor * mask_tensor
map2_tensor = map2_tensor * mask_tensor
fsc_values_masked = fsc(map1_tensor, map2_tensor)
)
rprint(f"Estimated resolution using {fsc_threshold} criterion in masked map: {estimated_resolution_angstrom:.2f} Å")
if correct_for_masking:
if correct_from_resolution is None:
to_correct = (
frequency_pixels > correct_from_fraction_of_estimated_resolution * estimated_resolution_frequency_pixel
)
else:
to_correct = frequency_pixels > (1 / correct_from_resolution) / pixel_spacing_angstroms
fsc_values_corrected = fsc_values_masked.clone()
fsc_values_corrected[to_correct] = (
fsc_values_corrected[to_correct] - fsc_values_masked_randomized[to_correct]
) / (1.0 - fsc_values_masked_randomized[to_correct])

estimated_resolution_frequency_pixel = float(frequency_pixels[(fsc_values_masked < fsc_threshold).nonzero()[0] - 1])
estimated_resolution_angstrom = float(resolution_angstroms[(fsc_values_masked < fsc_threshold).nonzero()[0] - 1])

rprint(f"Estimated resolution using {fsc_threshold} criterion: {estimated_resolution_angstrom:.2f} Å")
from ._masking import calculate_noise_injected_fsc

(
estimated_resolution_angstrom,
estimated_resolution_frequency_pixel,
correction_from_resolution_angstrom,
fsc_values_corrected,
) = calculate_noise_injected_fsc(
map1_tensor,
map2_tensor,
mask_tensor=mask_tensor,
fsc_values_masked=fsc_values_masked,
pixel_spacing_angstroms=pixel_spacing_angstroms,
fsc_threshold=fsc_threshold,
estimated_resolution_frequency_pixel=estimated_resolution_frequency_pixel,
correct_from_resolution=correct_from_resolution,
correct_from_fraction_of_estimated_resolution=correct_from_fraction_of_estimated_resolution,
)
rprint(
f"Estimated resolution using {fsc_threshold} "
f"criterion with correction after {correction_from_resolution_angstrom:.2f} Å: "
f"{estimated_resolution_angstrom:.2f} Å"
)

if plot:
from ._plotting import plot_matplotlib, plot_plottile
Expand All @@ -152,15 +122,17 @@ def ttfsc_cli(
fsc_values_unmasked=fsc_values_unmasked,
fsc_values_masked=fsc_values_masked,
fsc_values_corrected=fsc_values_corrected,
resolution_angstroms=resolution_angstroms,
estimated_resolution_angstrom=estimated_resolution_angstrom,
frequency_pixels=frequency_pixels,
pixel_spacing_angstroms=pixel_spacing_angstroms,
estimated_resolution_frequency_pixel=estimated_resolution_frequency_pixel,
fsc_threshold=fsc_threshold,
plot_matplotlib_style=plot_matplotlib_style,
)
else:
plot_plottile(
fsc_values=fsc_values_unmasked,
fsc_values_masked=fsc_values_masked,
fsc_values_corrected=fsc_values_corrected,
frequency_pixels=frequency_pixels,
pixel_spacing_angstroms=pixel_spacing_angstroms,
estimated_resolution_frequency_pixel=estimated_resolution_frequency_pixel,
Expand Down
121 changes: 121 additions & 0 deletions src/ttfsc/_masking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from enum import Enum
from typing import Optional

import torch
from torch_fourier_shell_correlation import fsc


class Masking(str, Enum):
none = "none"
sphere = "sphere"


def calculate_noise_injected_fsc(
map1_tensor: torch.tensor,
map2_tensor: torch.tensor,
mask_tensor: torch.tensor,
fsc_values_masked: torch.tensor,
pixel_spacing_angstroms: float,
fsc_threshold: float,
estimated_resolution_frequency_pixel: float,
correct_from_resolution: Optional[float] = None,
correct_from_fraction_of_estimated_resolution: float = 0.5,
):
from torch_grid_utils import fftfreq_grid

map1_tensor_randomized = torch.fft.rfftn(map1_tensor)
map2_tensor_randomized = torch.fft.rfftn(map2_tensor)
frequency_grid = fftfreq_grid(
image_shape=map1_tensor.shape,
rfft=True,
fftshift=False,
norm=True,
device=map1_tensor_randomized.device,
)
if correct_from_resolution is not None:
to_correct = frequency_grid > (1 / correct_from_resolution) / pixel_spacing_angstroms
else:
to_correct = frequency_grid > correct_from_fraction_of_estimated_resolution * estimated_resolution_frequency_pixel
# Rotate phases at frequencies higher than 0.25
random_phases1 = torch.rand(frequency_grid[to_correct].shape) * 2 * torch.pi
random_phases1 = torch.complex(torch.cos(random_phases1), torch.sin(random_phases1))
random_phases2 = torch.rand(frequency_grid[to_correct].shape) * 2 * torch.pi
random_phases2 = torch.complex(torch.cos(random_phases2), torch.sin(random_phases2))

map1_tensor_randomized[to_correct] *= random_phases1
map2_tensor_randomized[to_correct] *= random_phases2

map1_tensor_randomized = torch.fft.irfftn(map1_tensor_randomized)
map2_tensor_randomized = torch.fft.irfftn(map2_tensor_randomized)

map1_tensor_randomized *= mask_tensor
map2_tensor_randomized *= mask_tensor
fsc_values_masked_randomized = fsc(map1_tensor_randomized, map2_tensor_randomized)

frequency_pixels = torch.fft.rfftfreq(map1_tensor.shape[0])
resolution_angstroms = (1 / frequency_pixels) * pixel_spacing_angstroms

if correct_from_resolution is None:
to_correct = frequency_pixels > correct_from_fraction_of_estimated_resolution * estimated_resolution_frequency_pixel
correct_from_resolution = pixel_spacing_angstroms / (
correct_from_fraction_of_estimated_resolution * estimated_resolution_frequency_pixel
)
else:
to_correct = frequency_pixels > (1 / correct_from_resolution) / pixel_spacing_angstroms
fsc_values_corrected = fsc_values_masked.clone()
fsc_values_corrected[to_correct] = (fsc_values_corrected[to_correct] - fsc_values_masked_randomized[to_correct]) / (
1.0 - fsc_values_masked_randomized[to_correct]
)

estimated_resolution_frequency_pixel = float(frequency_pixels[(fsc_values_corrected < fsc_threshold).nonzero()[0] - 1])
estimated_resolution_angstrom = float(resolution_angstroms[(fsc_values_corrected < fsc_threshold).nonzero()[0] - 1])

return (
estimated_resolution_angstrom,
estimated_resolution_frequency_pixel,
correct_from_resolution,
fsc_values_corrected,
)


def calculate_masked_fsc(
map1_tensor: torch.tensor,
map2_tensor: torch.tensor,
pixel_spacing_angstroms: float,
fsc_threshold: float,
mask: Masking,
mask_radius_angstroms: float = 100.0,
mask_soft_edge_width_pixels: int = 5,
) -> tuple[float, float, torch.tensor, torch.tensor]:
if mask == Masking.none:
raise ValueError("Must choose a mask type")
if mask == Masking.sphere:
import numpy as np
from ttmask.box_setup import box_setup
from ttmask.soft_edge import add_soft_edge
# Taken from https://github.com/teamtomo/ttmask/blob/main/src/ttmask/sphere.py

# establish our coordinate system and empty mask
coordinates_centered, mask_tensor = box_setup(map1_tensor.shape[0])

# determine distances of each pixel to the center
distance_to_center = np.linalg.norm(coordinates_centered, axis=-1)

# set up criteria for which pixels are inside the sphere and modify values to 1.
inside_sphere = distance_to_center < (mask_radius_angstroms / pixel_spacing_angstroms)
mask_tensor[inside_sphere] = 1

# if requested, a soft edge is added to the mask
mask_tensor = add_soft_edge(mask_tensor, mask_soft_edge_width_pixels)

map1_tensor_masked = map1_tensor * mask_tensor
map2_tensor_masked = map2_tensor * mask_tensor
fsc_values_masked = fsc(map1_tensor_masked, map2_tensor_masked)

frequency_pixels = torch.fft.rfftfreq(map1_tensor.shape[0])
resolution_angstroms = (1 / frequency_pixels) * pixel_spacing_angstroms

estimated_resolution_frequency_pixel = float(frequency_pixels[(fsc_values_masked < fsc_threshold).nonzero()[0] - 1])
estimated_resolution_angstrom = float(resolution_angstroms[(fsc_values_masked < fsc_threshold).nonzero()[0] - 1])

return (estimated_resolution_angstrom, estimated_resolution_frequency_pixel, fsc_values_masked, mask_tensor)
32 changes: 21 additions & 11 deletions src/ttfsc/_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,29 @@ def plot_matplotlib(
fsc_values_unmasked: torch.Tensor,
fsc_values_masked: Optional[torch.Tensor],
fsc_values_corrected: Optional[torch.Tensor],
resolution_angstroms: torch.Tensor,
estimated_resolution_angstrom: float,
frequency_pixels: torch.Tensor,
pixel_spacing_angstroms: float,
estimated_resolution_frequency_pixel: float,
fsc_threshold: float,
plot_matplotlib_style: str,
) -> None:
from matplotlib import pyplot as plt

plt.style.use(plot_matplotlib_style)
plt.hlines(0, resolution_angstroms[1], resolution_angstroms[-2], "black")
plt.plot(resolution_angstroms, fsc_values_unmasked, label="FSC (unmasked)")
plt.hlines(0, frequency_pixels[1], frequency_pixels[-2], "black")
plt.plot(frequency_pixels[1:], fsc_values_unmasked[1:], label="FSC (unmasked)")
if fsc_values_masked is not None:
plt.plot(resolution_angstroms, fsc_values_masked, label="FSC (masked)")
plt.plot(frequency_pixels[1:], fsc_values_masked[1:], label="FSC (masked)")
if fsc_values_corrected is not None:
plt.plot(resolution_angstroms, fsc_values_corrected, label="FSC (corrected)")
plt.plot(frequency_pixels[1:], fsc_values_corrected[1:], label="FSC (corrected)")

plt.xlabel("Resolution (Å)")
plt.ylabel("Correlation")
plt.xscale("log")
plt.xlim(resolution_angstroms[1], resolution_angstroms[-2])
plt.gca().xaxis.set_major_formatter(lambda x, pos: f"{(1 / x) * pixel_spacing_angstroms:.2f}")
plt.xlim(frequency_pixels[1], frequency_pixels[-2])
plt.ylim(-0.05, 1.05)
plt.hlines(fsc_threshold, resolution_angstroms[1], estimated_resolution_angstrom, "red", "--")
plt.vlines(estimated_resolution_angstrom, -0.05, fsc_threshold, "red", "--")
plt.hlines(fsc_threshold, frequency_pixels[1], estimated_resolution_frequency_pixel, "red", "--")
plt.vlines(estimated_resolution_frequency_pixel, -0.05, fsc_threshold, "red", "--")
plt.legend()
plt.tight_layout()
plt.show()
Expand All @@ -37,6 +38,7 @@ def plot_matplotlib(
def plot_plottile(
fsc_values: torch.Tensor,
fsc_values_masked: Optional[torch.Tensor],
fsc_values_corrected: Optional[torch.Tensor],
frequency_pixels: torch.Tensor,
pixel_spacing_angstroms: float,
estimated_resolution_frequency_pixel: float,
Expand All @@ -49,15 +51,23 @@ def plot_plottile(
fig.height = 20
fig.set_x_limits(float(frequency_pixels[1]), float(frequency_pixels[-1]))
fig.set_y_limits(0, 1)
fig.x_label = "Resolution [Å]"
fig.y_label = "FSC"

def resolution_callback(x: float, _: float) -> str:
return f"{(1 / x) * pixel_spacing_angstroms:.2f}"

fig.x_ticks_fkt = resolution_callback

def fsc_callback(x: float, _: float) -> str:
return f"{x:.2f}"

fig.y_ticks_fkt = fsc_callback
fig.plot(frequency_pixels[1:].numpy(), fsc_values[1:].numpy(), lc="blue", label="FSC (unmasked)")
if fsc_values_masked is not None:
fig.plot(frequency_pixels[1:].numpy(), fsc_values_masked[1:].numpy(), lc="green", label="FSC (masked)")

if fsc_values_corrected is not None:
fig.plot(frequency_pixels[1:].numpy(), fsc_values_corrected[1:].numpy(), lc="yellow", label="FSC (corrected)")
fig.plot(
[float(frequency_pixels[1].numpy()), estimated_resolution_frequency_pixel],
[fsc_threshold, fsc_threshold],
Expand Down

0 comments on commit 411db34

Please sign in to comment.