Skip to content

Commit

Permalink
First draft for correction by noise injection
Browse files Browse the repository at this point in the history
  • Loading branch information
jojoelfe committed Aug 21, 2024
1 parent 26293cc commit e974447
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 9 deletions.
64 changes: 55 additions & 9 deletions src/ttfsc/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,33 @@ class Masking(str, Enum):

@cli.command(no_args_is_help=True)
def ttfsc_cli(
map1: Annotated[Path, typer.Argument(show_default=False)],
map2: Annotated[Path, typer.Argument(show_default=False)],
map1: Annotated[Path, typer.Argument()],
map2: Annotated[Path, typer.Argument()],
pixel_spacing_angstroms: Annotated[
Optional[float],
typer.Option(
"--pixel-spacing-angstroms", show_default=False, help="Pixel spacing in Å/px, taken from header if not set"
"--pixel-spacing-angstroms",
show_default=False,
help="Pixel spacing in Å/px, taken from header if not set",
rich_help_panel="Input options",
),
] = None,
plot: Annotated[bool, typer.Option("--plot")] = True,
plot_with_matplotlib: Annotated[bool, typer.Option("--plot-with-matplotlib")] = False,
fsc_threshold: Annotated[float, typer.Option("--fsc-threshold", help="FSC threshold")] = 0.143,
mask: Annotated[Masking, typer.Option("--mask")] = Masking.none,
mask_radius: Annotated[float, typer.Option("--mask-radius")] = 100.0,
mask_soft_edge_width: Annotated[int, typer.Option("--mask-soft-edge-width")] = 10,
fsc_threshold: Annotated[
float, typer.Option("--fsc-threshold", help="FSC threshold", rich_help_panel="Input options")
] = 0.143,
plot: Annotated[bool, typer.Option("--plot", rich_help_panel="Plotting options")] = True,
plot_with_matplotlib: Annotated[
bool, typer.Option("--plot-with-matplotlib", rich_help_panel="Plotting options")
] = False,
mask: Annotated[Masking, typer.Option("--mask", rich_help_panel="Masking options")] = Masking.none,
mask_radius: Annotated[float, typer.Option("--mask-radius", rich_help_panel="Masking options")] = 100.0,
mask_soft_edge_width: Annotated[int, typer.Option("--mask-soft-edge-width", rich_help_panel="Masking options")] = 10,
correct_for_masking: Annotated[
bool, typer.Option("--correct-for-masking", rich_help_panel="Masking correction options")
] = True,
correct_from_resolution: Annotated[
float, typer.Option("--correct-from_resolution", rich_help_panel="Masking correction options")
] = True,
) -> None:
with mrcfile.open(map1) as f:
map1_tensor = torch.tensor(f.data)
Expand Down Expand Up @@ -68,9 +81,41 @@ def ttfsc_cli(
# if requested, a soft edge is added to the mask
mask_tensor = add_soft_edge(mask_tensor, mask_soft_edge_width)

if correct_for_masking:
from torch_grid_utils import fftfreq_grid

map1_tensor_randomized = torch.fft.rfftn(map1_tensor)
map2_tensor_randomized = torch.fft.rfftn(map2_tensor)
frequency_grid = fftfreq_grid(
image_shape=map1_tensor.shape,
rfft=True,
fftshift=False,
norm=True,
device=map1_tensor_randomized.device,
)
# Rotate phases at frequencies higher than 0.25
random_phases1 = torch.rand(frequency_grid[frequency_grid > 0.25].shape) * 2 * torch.pi
random_phases1 = torch.complex(torch.cos(random_phases1), torch.sin(random_phases1))
random_phases2 = torch.rand(frequency_grid[frequency_grid > 0.25].shape) * 2 * torch.pi
random_phases2 = torch.complex(torch.cos(random_phases2), torch.sin(random_phases2))

map1_tensor_randomized[frequency_grid > 0.25] *= random_phases1
map2_tensor_randomized[frequency_grid > 0.25] *= random_phases2

map1_tensor_randomized = torch.fft.irfftn(map1_tensor_randomized)
map2_tensor_randomized = torch.fft.irfftn(map2_tensor_randomized)

map1_tensor_randomized *= mask_tensor
map2_tensor_randomized *= mask_tensor
fsc_values_masked_randomized = fsc(map1_tensor_randomized, map2_tensor_randomized)
map1_tensor = map1_tensor * mask_tensor
map2_tensor = map2_tensor * mask_tensor
fsc_values_masked = fsc(map1_tensor, map2_tensor)
if correct_for_masking:
fsc_values_corrected = fsc_values_masked.clone()
fsc_values_corrected[frequency_pixels > 0.25] = (
fsc_values_corrected[frequency_pixels > 0.25] - fsc_values_masked_randomized[frequency_pixels > 0.25]
) / (1.0 - fsc_values_masked_randomized[frequency_pixels > 0.25])

estimated_resolution_frequency_pixel = float(frequency_pixels[(fsc_values_masked < fsc_threshold).nonzero()[0] - 1])
estimated_resolution_angstrom = float(resolution_angstroms[(fsc_values_masked < fsc_threshold).nonzero()[0] - 1])
Expand All @@ -84,6 +129,7 @@ def ttfsc_cli(
plot_matplotlib(
fsc_values_unmasked=fsc_values_unmasked,
fsc_values_masked=fsc_values_masked,
fsc_values_corrected=fsc_values_corrected,
resolution_angstroms=resolution_angstroms,
estimated_resolution_angstrom=estimated_resolution_angstrom,
fsc_threshold=fsc_threshold,
Expand Down
3 changes: 3 additions & 0 deletions src/ttfsc/_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
def plot_matplotlib(
fsc_values_unmasked: torch.Tensor,
fsc_values_masked: Optional[torch.Tensor],
fsc_values_corrected: Optional[torch.Tensor],
resolution_angstroms: torch.Tensor,
estimated_resolution_angstrom: float,
fsc_threshold: float,
Expand All @@ -16,6 +17,8 @@ def plot_matplotlib(
plt.plot(resolution_angstroms, fsc_values_unmasked, label="FSC (unmasked)")
if fsc_values_masked is not None:
plt.plot(resolution_angstroms, fsc_values_masked, label="FSC (masked)")
if fsc_values_corrected is not None:
plt.plot(resolution_angstroms, fsc_values_corrected, label="FSC (corrected)")

plt.xlabel("Resolution (Å)")
plt.ylabel("Correlation")
Expand Down

0 comments on commit e974447

Please sign in to comment.