From 07c8d10476d074f22eb8b4b8668f6abf70da2d1b Mon Sep 17 00:00:00 2001 From: Johannes Elferich Date: Sat, 7 Sep 2024 22:55:09 -0400 Subject: [PATCH] Refactor --- src/ttfsc/_cli.py | 152 +++++++++++++--------------------- src/ttfsc/_data_models.py | 68 +++++++++++++++ src/ttfsc/_masking.py | 127 ++++++++++++---------------- src/ttfsc/_plotting.py | 73 +++++++--------- src/ttfsc/_starfile_schema.py | 29 ------- src/ttfsc/ttfsc.py | 111 +++++++++++++++++++++++++ tests/test_ttfsc.py | 10 ++- 7 files changed, 330 insertions(+), 240 deletions(-) create mode 100644 src/ttfsc/_data_models.py delete mode 100644 src/ttfsc/_starfile_schema.py create mode 100644 src/ttfsc/ttfsc.py diff --git a/src/ttfsc/_cli.py b/src/ttfsc/_cli.py index 75bdc1f..528a313 100644 --- a/src/ttfsc/_cli.py +++ b/src/ttfsc/_cli.py @@ -1,13 +1,11 @@ from pathlib import Path from typing import Annotated, Optional -import mrcfile -import torch import typer from rich import print as rprint -from torch_fourier_shell_correlation import fsc from ._masking import Masking +from .ttfsc import ttfsc cli = typer.Typer(name="ttfsc", no_args_is_help=True, add_completion=False) @@ -56,98 +54,77 @@ def ttfsc_cli( float, typer.Option("--correct-from-fraction-of-estimated-resolution", rich_help_panel="Masking correction options") ] = 0.5, ) -> None: - with mrcfile.open(map1) as f: - map1_tensor = torch.tensor(f.data) - if pixel_spacing_angstroms is None: - pixel_spacing_angstroms = f.voxel_size.x - with mrcfile.open(map2) as f: - map2_tensor = torch.tensor(f.data) - - frequency_pixels = torch.fft.rfftfreq(map1_tensor.shape[0]) - resolution_angstroms = (1 / frequency_pixels) * pixel_spacing_angstroms - - fsc_values_unmasked = fsc(map1_tensor, map2_tensor) - - estimated_resolution_frequency_pixel = float(frequency_pixels[(fsc_values_unmasked < fsc_threshold).nonzero()[0] - 1]) - estimated_resolution_angstrom = float(resolution_angstroms[(fsc_values_unmasked < fsc_threshold).nonzero()[0] - 1]) - - rprint(f"Estimated resolution using {fsc_threshold} criterion in unmasked map: {estimated_resolution_angstrom:.2f} Å") - - fsc_values_masked = None - fsc_values_corrected = None - if mask != Masking.none: - from ._masking import calculate_masked_fsc - - (estimated_resolution_angstrom, estimated_resolution_frequency_pixel, fsc_values_masked, mask_tensor) = ( - calculate_masked_fsc( - map1_tensor, - map2_tensor, - pixel_spacing_angstroms=pixel_spacing_angstroms, - fsc_threshold=fsc_threshold, - mask=mask, - mask_radius_angstroms=mask_radius_angstroms, - mask_soft_edge_width_pixels=mask_soft_edge_width_pixels, - ) + result = ttfsc( + map1=map1, + map2=map2, + pixel_spacing_angstroms=pixel_spacing_angstroms, + fsc_threshold=fsc_threshold, + mask=mask, + mask_radius_angstroms=mask_radius_angstroms, + mask_soft_edge_width_pixels=mask_soft_edge_width_pixels, + correct_for_masking=correct_for_masking, + correct_from_resolution=correct_from_resolution, + correct_from_fraction_of_estimated_resolution=correct_from_fraction_of_estimated_resolution, + ) + + rprint( + f"Estimated resolution using {fsc_threshold} criterion in unmasked map: " + f"{result.estimated_resolution_angstrom_unmasked:.2f} Å" + ) + if result.estimated_resolution_angstrom_masked is not None: + rprint( + f"Estimated resolution using {fsc_threshold} " + f"criterion in masked map: {result.estimated_resolution_angstrom_masked:.2f} Å" + ) + if result.estimated_resolution_angstrom_corrected is not None: + print( + f"Estimated resolution using {fsc_threshold} " + f"criterion with correction after {result.correction_from_resolution_angstrom:.2f} Å: " + f"{result.estimated_resolution_angstrom_corrected:.2f} Å" ) - rprint(f"Estimated resolution using {fsc_threshold} criterion in masked map: {estimated_resolution_angstrom:.2f} Å") - if correct_for_masking: - from ._masking import calculate_noise_injected_fsc - ( - estimated_resolution_angstrom, - estimated_resolution_frequency_pixel, - correction_from_resolution_angstrom, - fsc_values_corrected, - fsc_values_randomized, - ) = calculate_noise_injected_fsc( - map1_tensor, - map2_tensor, - mask_tensor=mask_tensor, - fsc_values_masked=fsc_values_masked, - pixel_spacing_angstroms=pixel_spacing_angstroms, - fsc_threshold=fsc_threshold, - estimated_resolution_frequency_pixel=estimated_resolution_frequency_pixel, - correct_from_resolution=correct_from_resolution, - correct_from_fraction_of_estimated_resolution=correct_from_fraction_of_estimated_resolution, - ) - rprint( - f"Estimated resolution using {fsc_threshold} " - f"criterion with correction after {correction_from_resolution_angstrom:.2f} Å: " - f"{estimated_resolution_angstrom:.2f} Å" - ) if save_starfile: import pandas as pd import starfile from numpy import nan - from ._starfile_schema import RelionDataGeneral, RelionFSCData + from ._data_models import RelionDataGeneral, RelionFSCData data_general = RelionDataGeneral( - rlnFinalResolution=estimated_resolution_angstrom, - rlnUnfilteredMapHalf1=map1.name, - rlnUnfilteredMapHalf2=map2.name, + rlnFinalResolution=result.estimated_resolution_angstrom, + rlnUnfilteredMapHalf1=str(result.map1), + rlnUnfilteredMapHalf2=str(result.map2), ) - if mask != Masking.none: - data_general.rlnParticleBoxFractionSolventMask = mask_tensor.sum().item() / mask_tensor.numel() + if result.mask_tensor is not None: + data_general.rlnParticleBoxFractionSolventMask = result.mask_tensor.sum().item() / result.mask_tensor.numel() if correct_for_masking: - data_general.rlnRandomiseFrom = correction_from_resolution_angstrom + if result.correction_from_resolution_angstrom is None: + raise ValueError("Phase randomization cutoff has not been calculated") + data_general.rlnRandomiseFrom = result.correction_from_resolution_angstrom fsc_data = [] - for i, (f, r) in enumerate(zip(fsc_values_unmasked, resolution_angstroms)): + for i in range(result.num_shells): fsc_data.append( RelionFSCData( rlnSpectralIndex=i, - rlnResolution=r, - rlnAngstromResolution=r, - rlnFourierShellCorrelationCorrected=fsc_values_corrected[i] if fsc_values_corrected is not None else nan, - rlnFourierShellCorrelationUnmaskedMaps=f, - rlnFourierShellCorrelationMaskedMaps=fsc_values_masked[i] if fsc_values_masked is not None else nan, - rlnCorrectedFourierShellCorrelationPhaseRandomizedMaskedMaps=fsc_values_randomized[i] - if correct_for_masking + rlnResolution=result.frequency_pixels[i], + rlnAngstromResolution=result.resolution_angstroms[i], + rlnFourierShellCorrelationCorrected=result.fsc_values_corrected[i] + if result.fsc_values_corrected is not None + else nan, + rlnFourierShellCorrelationUnmaskedMaps=result.fsc_values_unmasked[i], + rlnFourierShellCorrelationMaskedMaps=result.fsc_values_masked[i] + if result.fsc_values_masked is not None else nan, - rlnFourierShellCorrelationParticleMaskFraction=f + rlnCorrectedFourierShellCorrelationPhaseRandomizedMaskedMaps=result.fsc_values_randomized[i] + if result.fsc_values_randomized is not None + else nan, + rlnFourierShellCorrelationParticleMaskFraction=result.fsc_values_unmasked[i] / data_general.rlnParticleBoxFractionSolventMask - / (1.0 + (1.0 / data_general.rlnParticleBoxFractionSolventMask - 1.0) * f.abs()) + / ( + 1.0 + + (1.0 / data_general.rlnParticleBoxFractionSolventMask - 1.0) * result.fsc_values_unmasked[i].abs() + ) if mask != Masking.none else nan, ) @@ -160,23 +137,6 @@ def ttfsc_cli( from ._plotting import plot_matplotlib, plot_plottile if plot_with_matplotlib: - plot_matplotlib( - fsc_values_unmasked=fsc_values_unmasked, - fsc_values_masked=fsc_values_masked, - fsc_values_corrected=fsc_values_corrected, - frequency_pixels=frequency_pixels, - pixel_spacing_angstroms=pixel_spacing_angstroms, - estimated_resolution_frequency_pixel=estimated_resolution_frequency_pixel, - fsc_threshold=fsc_threshold, - plot_matplotlib_style=plot_matplotlib_style, - ) + plot_matplotlib(result, plot_matplotlib_style=plot_matplotlib_style) else: - plot_plottile( - fsc_values=fsc_values_unmasked, - fsc_values_masked=fsc_values_masked, - fsc_values_corrected=fsc_values_corrected, - frequency_pixels=frequency_pixels, - pixel_spacing_angstroms=pixel_spacing_angstroms, - estimated_resolution_frequency_pixel=estimated_resolution_frequency_pixel, - fsc_threshold=fsc_threshold, - ) + plot_plottile(result) diff --git a/src/ttfsc/_data_models.py b/src/ttfsc/_data_models.py new file mode 100644 index 0000000..4191690 --- /dev/null +++ b/src/ttfsc/_data_models.py @@ -0,0 +1,68 @@ +from enum import Enum +from pathlib import Path +from typing import List, Optional + +from numpy import nan +from pydantic import BaseModel +from torch import Tensor + + +class Masking(str, Enum): + none = "none" + sphere = "sphere" + + +class TTFSCResult(BaseModel): + map1: Path + map1_tensor: Tensor + map2: Path + map2_tensor: Tensor + pixel_spacing_angstroms: float + fsc_threshold: float + mask: str = Masking.none + mask_filename: Optional[Path] = None + mask_tensor: Optional[Tensor] = None + mask_radius_angstroms: float = 50.0 + mask_soft_edge_width_pixels: int = 10 + num_shells: int + estimated_resolution_angstrom: float + estimated_resolution_angstrom_unmasked: float + estimated_resolution_angstrom_masked: Optional[float] = None + estimated_resolution_angstrom_corrected: Optional[float] = None + estimated_resolution_frequency_pixel: float + frequency_pixels: Tensor + resolution_angstroms: Tensor + fsc_values_unmasked: Tensor + fsc_values_masked: Optional[Tensor] = None + fsc_values_corrected: Optional[Tensor] = None + fsc_values_masked_randomized: Optional[Tensor] = None + fsc_values_randomized: Optional[Tensor] = None + correction_from_resolution_angstrom: Optional[float] = None + correct_from_fraction_of_estimated_resolution: float = 0.5 + + model_config: dict = {"arbitrary_types_allowed": True} + + +class RelionDataGeneral(BaseModel): + rlnFinalResolution: float + rlnUnfilteredMapHalf1: str + rlnUnfilteredMapHalf2: str + rlnParticleBoxFractionSolventMask: float = nan + rlnRandomiseFrom: float = nan + rlnMaskName: str = "" + + +class RelionFSCData(BaseModel): + rlnSpectralIndex: int + rlnResolution: float + rlnAngstromResolution: float + rlnFourierShellCorrelationCorrected: float + rlnFourierShellCorrelationParticleMaskFraction: float + rlnFourierShellCorrelationUnmaskedMaps: float + rlnFourierShellCorrelationMaskedMaps: float + rlnCorrectedFourierShellCorrelationPhaseRandomizedMaskedMaps: float + + +class RelionStarfile(BaseModel): + data_general: RelionDataGeneral + fsc_data: List[RelionFSCData] diff --git a/src/ttfsc/_masking.py b/src/ttfsc/_masking.py index ce41950..3ad612f 100644 --- a/src/ttfsc/_masking.py +++ b/src/ttfsc/_masking.py @@ -1,42 +1,29 @@ -from enum import Enum -from typing import Optional - import torch from torch_fourier_shell_correlation import fsc - -class Masking(str, Enum): - none = "none" - sphere = "sphere" +from ._data_models import Masking, TTFSCResult -def calculate_noise_injected_fsc( - map1_tensor: torch.tensor, - map2_tensor: torch.tensor, - mask_tensor: torch.tensor, - fsc_values_masked: torch.tensor, - pixel_spacing_angstroms: float, - fsc_threshold: float, - estimated_resolution_frequency_pixel: float, - correct_from_resolution: Optional[float] = None, - correct_from_fraction_of_estimated_resolution: float = 0.5, -) -> tuple[float, float, float, torch.tensor, torch.tensor]: +def calculate_noise_injected_fsc(result: TTFSCResult) -> None: from torch_grid_utils import fftfreq_grid - map1_tensor_randomized = torch.fft.rfftn(map1_tensor) - map2_tensor_randomized = torch.fft.rfftn(map2_tensor) + map1_tensor_randomized = torch.fft.rfftn(result.map1_tensor) + map2_tensor_randomized = torch.fft.rfftn(result.map2_tensor) frequency_grid = fftfreq_grid( - image_shape=map1_tensor.shape, + image_shape=result.map1_tensor.shape, rfft=True, fftshift=False, norm=True, device=map1_tensor_randomized.device, ) - if correct_from_resolution is not None: - to_correct = frequency_grid > (1 / correct_from_resolution) / pixel_spacing_angstroms + if result.correction_from_resolution_angstrom is not None: + to_correct = frequency_grid > (1 / result.correction_from_resolution_angstrom) / result.pixel_spacing_angstroms else: - to_correct = frequency_grid > correct_from_fraction_of_estimated_resolution * estimated_resolution_frequency_pixel - # Rotate phases at frequencies higher than 0.25 + to_correct = ( + frequency_grid + > result.correct_from_fraction_of_estimated_resolution * result.estimated_resolution_frequency_pixel + ) + random_phases1 = torch.rand(frequency_grid[to_correct].shape) * 2 * torch.pi random_phases1 = torch.complex(torch.cos(random_phases1), torch.sin(random_phases1)) random_phases2 = torch.rand(frequency_grid[to_correct].shape) * 2 * torch.pi @@ -48,75 +35,71 @@ def calculate_noise_injected_fsc( map1_tensor_randomized = torch.fft.irfftn(map1_tensor_randomized) map2_tensor_randomized = torch.fft.irfftn(map2_tensor_randomized) - map1_tensor_randomized *= mask_tensor - map2_tensor_randomized *= mask_tensor - fsc_values_masked_randomized = fsc(map1_tensor_randomized, map2_tensor_randomized) - - frequency_pixels = torch.fft.rfftfreq(map1_tensor.shape[0]) - resolution_angstroms = (1 / frequency_pixels) * pixel_spacing_angstroms + map1_tensor_randomized *= result.mask_tensor + map2_tensor_randomized *= result.mask_tensor + result.fsc_values_masked_randomized = fsc(map1_tensor_randomized, map2_tensor_randomized) - if correct_from_resolution is None: - to_correct = frequency_pixels > correct_from_fraction_of_estimated_resolution * estimated_resolution_frequency_pixel - correct_from_resolution = pixel_spacing_angstroms / ( - correct_from_fraction_of_estimated_resolution * estimated_resolution_frequency_pixel + if result.correction_from_resolution_angstrom is None: + to_correct = ( + result.frequency_pixels + > result.correct_from_fraction_of_estimated_resolution * result.estimated_resolution_frequency_pixel + ) + result.correction_from_resolution_angstrom = result.pixel_spacing_angstroms / ( + result.correct_from_fraction_of_estimated_resolution * result.estimated_resolution_frequency_pixel ) else: - to_correct = frequency_pixels > (1 / correct_from_resolution) / pixel_spacing_angstroms - fsc_values_corrected = fsc_values_masked.clone() - fsc_values_corrected[to_correct] = (fsc_values_corrected[to_correct] - fsc_values_masked_randomized[to_correct]) / ( - 1.0 - fsc_values_masked_randomized[to_correct] + to_correct = ( + result.frequency_pixels > (1 / result.correction_from_resolution_angstrom) / result.pixel_spacing_angstroms + ) + if result.fsc_values_masked is None: + raise ValueError("Must calculate masked FSC before correcting for masking") + result.fsc_values_corrected = result.fsc_values_masked.clone() + result.fsc_values_corrected[to_correct] = ( + result.fsc_values_corrected[to_correct] - result.fsc_values_masked_randomized[to_correct] + ) / (1.0 - result.fsc_values_masked_randomized[to_correct]) + + result.estimated_resolution_frequency_pixel = float( + result.frequency_pixels[(result.fsc_values_corrected < result.fsc_threshold).nonzero()[0] - 1] ) - - estimated_resolution_frequency_pixel = float(frequency_pixels[(fsc_values_corrected < fsc_threshold).nonzero()[0] - 1]) - estimated_resolution_angstrom = float(resolution_angstroms[(fsc_values_corrected < fsc_threshold).nonzero()[0] - 1]) - - return ( - estimated_resolution_angstrom, - estimated_resolution_frequency_pixel, - correct_from_resolution, - fsc_values_corrected, - fsc_values_masked_randomized, + result.estimated_resolution_angstrom = float( + result.resolution_angstroms[(result.fsc_values_corrected < result.fsc_threshold).nonzero()[0] - 1] ) + result.estimated_resolution_angstrom_corrected = result.estimated_resolution_angstrom -def calculate_masked_fsc( - map1_tensor: torch.tensor, - map2_tensor: torch.tensor, - pixel_spacing_angstroms: float, - fsc_threshold: float, - mask: Masking, - mask_radius_angstroms: float = 100.0, - mask_soft_edge_width_pixels: int = 5, -) -> tuple[float, float, torch.tensor, torch.tensor]: - if mask == Masking.none: +def calculate_masked_fsc(result: TTFSCResult) -> None: + if result.mask == Masking.none: raise ValueError("Must choose a mask type") - if mask == Masking.sphere: + if result.mask == Masking.sphere: import numpy as np from ttmask.box_setup import box_setup from ttmask.soft_edge import add_soft_edge # Taken from https://github.com/teamtomo/ttmask/blob/main/src/ttmask/sphere.py # establish our coordinate system and empty mask - coordinates_centered, mask_tensor = box_setup(map1_tensor.shape[0]) + coordinates_centered, mask_tensor = box_setup(result.map1_tensor.shape[0]) # determine distances of each pixel to the center distance_to_center = np.linalg.norm(coordinates_centered, axis=-1) # set up criteria for which pixels are inside the sphere and modify values to 1. - inside_sphere = distance_to_center < (mask_radius_angstroms / pixel_spacing_angstroms) + inside_sphere = distance_to_center < (result.mask_radius_angstroms / result.pixel_spacing_angstroms) mask_tensor[inside_sphere] = 1 # if requested, a soft edge is added to the mask - mask_tensor = torch.tensor(add_soft_edge(mask_tensor, mask_soft_edge_width_pixels)) + result.mask_tensor = torch.tensor(add_soft_edge(mask_tensor, result.mask_soft_edge_width_pixels)) - map1_tensor_masked = map1_tensor * mask_tensor - map2_tensor_masked = map2_tensor * mask_tensor - fsc_values_masked = fsc(map1_tensor_masked, map2_tensor_masked) + map1_tensor_masked = result.map1_tensor * result.mask_tensor + map2_tensor_masked = result.map2_tensor * result.mask_tensor + result.fsc_values_masked = fsc(map1_tensor_masked, map2_tensor_masked) - frequency_pixels = torch.fft.rfftfreq(map1_tensor.shape[0]) - resolution_angstroms = (1 / frequency_pixels) * pixel_spacing_angstroms - - estimated_resolution_frequency_pixel = float(frequency_pixels[(fsc_values_masked < fsc_threshold).nonzero()[0] - 1]) - estimated_resolution_angstrom = float(resolution_angstroms[(fsc_values_masked < fsc_threshold).nonzero()[0] - 1]) + result.estimated_resolution_frequency_pixel = float( + result.frequency_pixels[(result.fsc_values_masked < result.fsc_threshold).nonzero()[0] - 1] + ) + result.estimated_resolution_angstrom = float( + result.resolution_angstroms[(result.fsc_values_masked < result.fsc_threshold).nonzero()[0] - 1] + ) + result.estimated_resolution_angstrom_masked = result.estimated_resolution_angstrom - return (estimated_resolution_angstrom, estimated_resolution_frequency_pixel, fsc_values_masked, mask_tensor) + return + raise NotImplementedError("Only sphere masking is implemented") diff --git a/src/ttfsc/_plotting.py b/src/ttfsc/_plotting.py index 5a67c1d..e28e98b 100644 --- a/src/ttfsc/_plotting.py +++ b/src/ttfsc/_plotting.py @@ -1,61 +1,42 @@ -from typing import Optional +from ._data_models import TTFSCResult -import torch - -def plot_matplotlib( - fsc_values_unmasked: torch.Tensor, - fsc_values_masked: Optional[torch.Tensor], - fsc_values_corrected: Optional[torch.Tensor], - frequency_pixels: torch.Tensor, - pixel_spacing_angstroms: float, - estimated_resolution_frequency_pixel: float, - fsc_threshold: float, - plot_matplotlib_style: str, -) -> None: +def plot_matplotlib(result: TTFSCResult, plot_matplotlib_style: str) -> None: from matplotlib import pyplot as plt plt.style.use(plot_matplotlib_style) - plt.hlines(0, frequency_pixels[1], frequency_pixels[-2], "black") - plt.plot(frequency_pixels[1:], fsc_values_unmasked[1:], label="FSC (unmasked)") - if fsc_values_masked is not None: - plt.plot(frequency_pixels[1:], fsc_values_masked[1:], label="FSC (masked)") - if fsc_values_corrected is not None: - plt.plot(frequency_pixels[1:], fsc_values_corrected[1:], label="FSC (corrected)") + plt.hlines(0, result.frequency_pixels[1], result.frequency_pixels[-2], "black") + plt.plot(result.frequency_pixels[1:], result.fsc_values_unmasked[1:], label="FSC (unmasked)") + if result.fsc_values_masked is not None: + plt.plot(result.frequency_pixels[1:], result.fsc_values_masked[1:], label="FSC (masked)") + if result.fsc_values_corrected is not None: + plt.plot(result.frequency_pixels[1:], result.fsc_values_corrected[1:], label="FSC (corrected)") plt.xlabel("Resolution (Å)") plt.ylabel("Correlation") - plt.gca().xaxis.set_major_formatter(lambda x, pos: f"{(1 / x) * pixel_spacing_angstroms:.2f}") - plt.xlim(frequency_pixels[1], frequency_pixels[-2]) + plt.gca().xaxis.set_major_formatter(lambda x, pos: f"{(1 / x) * result.pixel_spacing_angstroms:.2f}") + plt.xlim(result.frequency_pixels[1], result.frequency_pixels[-2]) plt.ylim(-0.05, 1.05) - plt.hlines(fsc_threshold, frequency_pixels[1], estimated_resolution_frequency_pixel, "red", "--") - plt.vlines(estimated_resolution_frequency_pixel, -0.05, fsc_threshold, "red", "--") + plt.hlines(result.fsc_threshold, result.frequency_pixels[1], result.estimated_resolution_frequency_pixel, "red", "--") + plt.vlines(result.estimated_resolution_frequency_pixel, -0.05, result.fsc_threshold, "red", "--") plt.legend() plt.tight_layout() plt.show() -def plot_plottile( - fsc_values: torch.Tensor, - fsc_values_masked: Optional[torch.Tensor], - fsc_values_corrected: Optional[torch.Tensor], - frequency_pixels: torch.Tensor, - pixel_spacing_angstroms: float, - estimated_resolution_frequency_pixel: float, - fsc_threshold: float, -) -> None: +def plot_plottile(result: TTFSCResult) -> None: import plotille fig = plotille.Figure() fig.width = 60 fig.height = 20 - fig.set_x_limits(float(frequency_pixels[1]), float(frequency_pixels[-1])) + fig.set_x_limits(float(result.frequency_pixels[1]), float(result.frequency_pixels[-1])) fig.set_y_limits(0, 1) fig.x_label = "Resolution [Å]" fig.y_label = "FSC" def resolution_callback(x: float, _: float) -> str: - return f"{(1 / x) * pixel_spacing_angstroms:.2f}" + return f"{(1 / x) * result.pixel_spacing_angstroms:.2f}" fig.x_ticks_fkt = resolution_callback @@ -63,18 +44,26 @@ def fsc_callback(x: float, _: float) -> str: return f"{x:.2f}" fig.y_ticks_fkt = fsc_callback - fig.plot(frequency_pixels[1:].numpy(), fsc_values[1:].numpy(), lc="blue", label="FSC (unmasked)") - if fsc_values_masked is not None: - fig.plot(frequency_pixels[1:].numpy(), fsc_values_masked[1:].numpy(), lc="green", label="FSC (masked)") - if fsc_values_corrected is not None: - fig.plot(frequency_pixels[1:].numpy(), fsc_values_corrected[1:].numpy(), lc="yellow", label="FSC (corrected)") + fig.plot(result.frequency_pixels[1:].numpy(), result.fsc_values_unmasked[1:].numpy(), lc="blue", label="FSC (unmasked)") + if result.fsc_values_masked is not None: + fig.plot(result.frequency_pixels[1:].numpy(), result.fsc_values_masked[1:].numpy(), lc="green", label="FSC (masked)") + if result.fsc_values_corrected is not None: + fig.plot( + result.frequency_pixels[1:].numpy(), + result.fsc_values_corrected[1:].numpy(), + lc="yellow", + label="FSC (corrected)", + ) fig.plot( - [float(frequency_pixels[1].numpy()), estimated_resolution_frequency_pixel], - [fsc_threshold, fsc_threshold], + [float(result.frequency_pixels[1].numpy()), result.estimated_resolution_frequency_pixel], + [result.fsc_threshold, result.fsc_threshold], lc="red", label=" ", ) fig.plot( - [estimated_resolution_frequency_pixel, estimated_resolution_frequency_pixel], [0, fsc_threshold], lc="red", label=" " + [result.estimated_resolution_frequency_pixel, result.estimated_resolution_frequency_pixel], + [0, result.fsc_threshold], + lc="red", + label=" ", ) print(fig.show(legend=True)) diff --git a/src/ttfsc/_starfile_schema.py b/src/ttfsc/_starfile_schema.py deleted file mode 100644 index e3b9215..0000000 --- a/src/ttfsc/_starfile_schema.py +++ /dev/null @@ -1,29 +0,0 @@ -from typing import List - -from numpy import nan -from pydantic import BaseModel - - -class RelionDataGeneral(BaseModel): - rlnFinalResolution: float - rlnUnfilteredMapHalf1: str - rlnUnfilteredMapHalf2: str - rlnParticleBoxFractionSolventMask: float = nan - rlnRandomiseFrom: float = nan - rlnMaskName: str = "" - - -class RelionFSCData(BaseModel): - rlnSpectralIndex: int - rlnResolution: float - rlnAngstromResolution: float - rlnFourierShellCorrelationCorrected: float - rlnFourierShellCorrelationParticleMaskFraction: float - rlnFourierShellCorrelationUnmaskedMaps: float - rlnFourierShellCorrelationMaskedMaps: float - rlnCorrectedFourierShellCorrelationPhaseRandomizedMaskedMaps: float - - -class RelionStarfile(BaseModel): - data_general: RelionDataGeneral - fsc_data: List[RelionFSCData] diff --git a/src/ttfsc/ttfsc.py b/src/ttfsc/ttfsc.py new file mode 100644 index 0000000..eaa9ec8 --- /dev/null +++ b/src/ttfsc/ttfsc.py @@ -0,0 +1,111 @@ +""" +Provides functionality for Fourier Shell Correlation (FSC) analysis. + +The main function in this module is `ttfsc`, which calculates the FSC between two +3D maps and returns a `TTFSCResult` object containing the results of the analysis. + +Example: + ttfsc("map1.mrc","map2.mrc") +""" + +from pathlib import Path +from typing import Optional + +import mrcfile +import torch +from torch_fourier_shell_correlation import fsc + +from ._data_models import TTFSCResult +from ._masking import Masking + + +def ttfsc( + map1: Path, + map2: Path, + pixel_spacing_angstroms: Optional[float] = None, + fsc_threshold: float = 0.143, + mask: Masking = Masking.none, + mask_radius_angstroms: float = 100.0, + mask_soft_edge_width_pixels: int = 10, + correct_for_masking: bool = True, + correct_from_resolution: Optional[float] = None, + correct_from_fraction_of_estimated_resolution: float = 0.5, +) -> TTFSCResult: + """ + Perform Fourier Shell Correlation (FSC) analysis between two maps. + + Args: + map1 (Path): Path to the first map file. + map2 (Path): Path to the second map file. + pixel_spacing_angstroms (Optional[float]): Pixel spacing in Å/px. If not provided, it will be taken from the header. + fsc_threshold (float): FSC threshold value. Default is 0.143. + mask (Masking): Masking option to use. Default is Masking.none. + mask_radius_angstroms (float): Radius of the mask in Å. Default is 100.0. + mask_soft_edge_width_pixels (int): Width of the soft edge of the mask in pixels. Default is 10. + correct_for_masking (bool): Whether to correct for masking effects. Default is True. + correct_from_resolution (Optional[float]): Resolution from which to start correction. + Default is None. + correct_from_fraction_of_estimated_resolution (float): Fraction of the estimated resolution + from which to start correction. Default is 0.5. + + Returns + ------- + TTFSCResult: The result of the FSC analysis, including FSC curves and resolution estimates. + + Example: + result = ttfsc( + map1=Path("map1.mrc"), + map2=Path("map2.mrc"), + pixel_spacing_angstroms=1.0, + fsc_threshold=0.143, + mask=Masking.soft, + mask_radius_angstroms=150.0, + mask_soft_edge_width_pixels=5, + correct_for_masking=True, + correct_from_resolution=3.0, + correct_from_fraction_of_estimated_resolution=0.5 + ) + """ + with mrcfile.open(map1) as f: + map1_tensor = torch.tensor(f.data) + if pixel_spacing_angstroms is None: + pixel_spacing_angstroms = f.voxel_size.x + with mrcfile.open(map2) as f: + map2_tensor = torch.tensor(f.data) + + frequency_pixels = torch.fft.rfftfreq(map1_tensor.shape[0]) + resolution_angstroms = (1 / frequency_pixels) * pixel_spacing_angstroms + + fsc_values_unmasked = fsc(map1_tensor, map2_tensor) + + estimated_resolution_frequency_pixel = float(frequency_pixels[(fsc_values_unmasked < fsc_threshold).nonzero()[0] - 1]) + estimated_resolution_angstrom = float(resolution_angstroms[(fsc_values_unmasked < fsc_threshold).nonzero()[0] - 1]) + result = TTFSCResult( + map1=map1, + map1_tensor=map1_tensor, + map2=map2, + map2_tensor=map2_tensor, + pixel_spacing_angstroms=pixel_spacing_angstroms, + fsc_threshold=fsc_threshold, + num_shells=len(frequency_pixels), + estimated_resolution_angstrom=estimated_resolution_angstrom, + estimated_resolution_angstrom_unmasked=estimated_resolution_angstrom, + estimated_resolution_frequency_pixel=estimated_resolution_frequency_pixel, + frequency_pixels=frequency_pixels, + resolution_angstroms=resolution_angstroms, + fsc_values_unmasked=fsc_values_unmasked, + ) + if mask != Masking.none: + from ._masking import calculate_masked_fsc + + result.mask = mask + result.mask_radius_angstroms = mask_radius_angstroms + result.mask_soft_edge_width_pixels = mask_soft_edge_width_pixels + calculate_masked_fsc(result) + if correct_for_masking: + from ._masking import calculate_noise_injected_fsc + + result.correction_from_resolution_angstrom = correct_from_resolution + result.correct_from_fraction_of_estimated_resolution = correct_from_fraction_of_estimated_resolution + calculate_noise_injected_fsc(result) + return result diff --git a/tests/test_ttfsc.py b/tests/test_ttfsc.py index 94cfc9e..7c80006 100644 --- a/tests/test_ttfsc.py +++ b/tests/test_ttfsc.py @@ -26,8 +26,16 @@ def halfmap2(): runner = CliRunner() -def test_app(halfmap1, halfmap2): +def test_app_nomask(halfmap1, halfmap2): result = runner.invoke(cli, [halfmap1, halfmap2]) print(result.output) assert result.exit_code == 0 assert "Estimated resolution using 0.143 criterion in unmasked map: 3.63 Å" in result.output + + +def test_app_spherical_mask(halfmap1, halfmap2): + result = runner.invoke(cli, [halfmap1, halfmap2, "--mask", "sphere", "--mask-radius-angstroms", "50"]) + print(result.output) + assert result.exit_code == 0 + assert "Estimated resolution using 0.143 criterion in masked map: 3.26 Å" in result.output + assert "Estimated resolution using 0.143 criterion with correction after 6.53 Å: 3.26 Å" in result.output