Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
jojoelfe committed Sep 8, 2024
1 parent 02a1daa commit 07c8d10
Show file tree
Hide file tree
Showing 7 changed files with 330 additions and 240 deletions.
152 changes: 56 additions & 96 deletions src/ttfsc/_cli.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from pathlib import Path
from typing import Annotated, Optional

import mrcfile
import torch
import typer
from rich import print as rprint
from torch_fourier_shell_correlation import fsc

from ._masking import Masking
from .ttfsc import ttfsc

cli = typer.Typer(name="ttfsc", no_args_is_help=True, add_completion=False)

Expand Down Expand Up @@ -56,98 +54,77 @@ def ttfsc_cli(
float, typer.Option("--correct-from-fraction-of-estimated-resolution", rich_help_panel="Masking correction options")
] = 0.5,
) -> None:
with mrcfile.open(map1) as f:
map1_tensor = torch.tensor(f.data)
if pixel_spacing_angstroms is None:
pixel_spacing_angstroms = f.voxel_size.x
with mrcfile.open(map2) as f:
map2_tensor = torch.tensor(f.data)

frequency_pixels = torch.fft.rfftfreq(map1_tensor.shape[0])
resolution_angstroms = (1 / frequency_pixels) * pixel_spacing_angstroms

fsc_values_unmasked = fsc(map1_tensor, map2_tensor)

estimated_resolution_frequency_pixel = float(frequency_pixels[(fsc_values_unmasked < fsc_threshold).nonzero()[0] - 1])
estimated_resolution_angstrom = float(resolution_angstroms[(fsc_values_unmasked < fsc_threshold).nonzero()[0] - 1])

rprint(f"Estimated resolution using {fsc_threshold} criterion in unmasked map: {estimated_resolution_angstrom:.2f} Å")

fsc_values_masked = None
fsc_values_corrected = None
if mask != Masking.none:
from ._masking import calculate_masked_fsc

(estimated_resolution_angstrom, estimated_resolution_frequency_pixel, fsc_values_masked, mask_tensor) = (
calculate_masked_fsc(
map1_tensor,
map2_tensor,
pixel_spacing_angstroms=pixel_spacing_angstroms,
fsc_threshold=fsc_threshold,
mask=mask,
mask_radius_angstroms=mask_radius_angstroms,
mask_soft_edge_width_pixels=mask_soft_edge_width_pixels,
)
result = ttfsc(
map1=map1,
map2=map2,
pixel_spacing_angstroms=pixel_spacing_angstroms,
fsc_threshold=fsc_threshold,
mask=mask,
mask_radius_angstroms=mask_radius_angstroms,
mask_soft_edge_width_pixels=mask_soft_edge_width_pixels,
correct_for_masking=correct_for_masking,
correct_from_resolution=correct_from_resolution,
correct_from_fraction_of_estimated_resolution=correct_from_fraction_of_estimated_resolution,
)

rprint(
f"Estimated resolution using {fsc_threshold} criterion in unmasked map: "
f"{result.estimated_resolution_angstrom_unmasked:.2f} Å"
)
if result.estimated_resolution_angstrom_masked is not None:
rprint(
f"Estimated resolution using {fsc_threshold} "
f"criterion in masked map: {result.estimated_resolution_angstrom_masked:.2f} Å"
)
if result.estimated_resolution_angstrom_corrected is not None:
print(
f"Estimated resolution using {fsc_threshold} "
f"criterion with correction after {result.correction_from_resolution_angstrom:.2f} Å: "
f"{result.estimated_resolution_angstrom_corrected:.2f} Å"
)
rprint(f"Estimated resolution using {fsc_threshold} criterion in masked map: {estimated_resolution_angstrom:.2f} Å")
if correct_for_masking:
from ._masking import calculate_noise_injected_fsc

(
estimated_resolution_angstrom,
estimated_resolution_frequency_pixel,
correction_from_resolution_angstrom,
fsc_values_corrected,
fsc_values_randomized,
) = calculate_noise_injected_fsc(
map1_tensor,
map2_tensor,
mask_tensor=mask_tensor,
fsc_values_masked=fsc_values_masked,
pixel_spacing_angstroms=pixel_spacing_angstroms,
fsc_threshold=fsc_threshold,
estimated_resolution_frequency_pixel=estimated_resolution_frequency_pixel,
correct_from_resolution=correct_from_resolution,
correct_from_fraction_of_estimated_resolution=correct_from_fraction_of_estimated_resolution,
)
rprint(
f"Estimated resolution using {fsc_threshold} "
f"criterion with correction after {correction_from_resolution_angstrom:.2f} Å: "
f"{estimated_resolution_angstrom:.2f} Å"
)
if save_starfile:
import pandas as pd
import starfile
from numpy import nan

from ._starfile_schema import RelionDataGeneral, RelionFSCData
from ._data_models import RelionDataGeneral, RelionFSCData

data_general = RelionDataGeneral(
rlnFinalResolution=estimated_resolution_angstrom,
rlnUnfilteredMapHalf1=map1.name,
rlnUnfilteredMapHalf2=map2.name,
rlnFinalResolution=result.estimated_resolution_angstrom,
rlnUnfilteredMapHalf1=str(result.map1),
rlnUnfilteredMapHalf2=str(result.map2),
)
if mask != Masking.none:
data_general.rlnParticleBoxFractionSolventMask = mask_tensor.sum().item() / mask_tensor.numel()
if result.mask_tensor is not None:
data_general.rlnParticleBoxFractionSolventMask = result.mask_tensor.sum().item() / result.mask_tensor.numel()
if correct_for_masking:
data_general.rlnRandomiseFrom = correction_from_resolution_angstrom
if result.correction_from_resolution_angstrom is None:
raise ValueError("Phase randomization cutoff has not been calculated")
data_general.rlnRandomiseFrom = result.correction_from_resolution_angstrom

fsc_data = []
for i, (f, r) in enumerate(zip(fsc_values_unmasked, resolution_angstroms)):
for i in range(result.num_shells):
fsc_data.append(
RelionFSCData(
rlnSpectralIndex=i,
rlnResolution=r,
rlnAngstromResolution=r,
rlnFourierShellCorrelationCorrected=fsc_values_corrected[i] if fsc_values_corrected is not None else nan,
rlnFourierShellCorrelationUnmaskedMaps=f,
rlnFourierShellCorrelationMaskedMaps=fsc_values_masked[i] if fsc_values_masked is not None else nan,
rlnCorrectedFourierShellCorrelationPhaseRandomizedMaskedMaps=fsc_values_randomized[i]
if correct_for_masking
rlnResolution=result.frequency_pixels[i],
rlnAngstromResolution=result.resolution_angstroms[i],
rlnFourierShellCorrelationCorrected=result.fsc_values_corrected[i]
if result.fsc_values_corrected is not None
else nan,
rlnFourierShellCorrelationUnmaskedMaps=result.fsc_values_unmasked[i],
rlnFourierShellCorrelationMaskedMaps=result.fsc_values_masked[i]
if result.fsc_values_masked is not None
else nan,
rlnFourierShellCorrelationParticleMaskFraction=f
rlnCorrectedFourierShellCorrelationPhaseRandomizedMaskedMaps=result.fsc_values_randomized[i]
if result.fsc_values_randomized is not None
else nan,
rlnFourierShellCorrelationParticleMaskFraction=result.fsc_values_unmasked[i]
/ data_general.rlnParticleBoxFractionSolventMask
/ (1.0 + (1.0 / data_general.rlnParticleBoxFractionSolventMask - 1.0) * f.abs())
/ (
1.0
+ (1.0 / data_general.rlnParticleBoxFractionSolventMask - 1.0) * result.fsc_values_unmasked[i].abs()
)
if mask != Masking.none
else nan,
)
Expand All @@ -160,23 +137,6 @@ def ttfsc_cli(
from ._plotting import plot_matplotlib, plot_plottile

if plot_with_matplotlib:
plot_matplotlib(
fsc_values_unmasked=fsc_values_unmasked,
fsc_values_masked=fsc_values_masked,
fsc_values_corrected=fsc_values_corrected,
frequency_pixels=frequency_pixels,
pixel_spacing_angstroms=pixel_spacing_angstroms,
estimated_resolution_frequency_pixel=estimated_resolution_frequency_pixel,
fsc_threshold=fsc_threshold,
plot_matplotlib_style=plot_matplotlib_style,
)
plot_matplotlib(result, plot_matplotlib_style=plot_matplotlib_style)
else:
plot_plottile(
fsc_values=fsc_values_unmasked,
fsc_values_masked=fsc_values_masked,
fsc_values_corrected=fsc_values_corrected,
frequency_pixels=frequency_pixels,
pixel_spacing_angstroms=pixel_spacing_angstroms,
estimated_resolution_frequency_pixel=estimated_resolution_frequency_pixel,
fsc_threshold=fsc_threshold,
)
plot_plottile(result)
68 changes: 68 additions & 0 deletions src/ttfsc/_data_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from enum import Enum
from pathlib import Path
from typing import List, Optional

from numpy import nan
from pydantic import BaseModel
from torch import Tensor


class Masking(str, Enum):
    """Masking modes that can be selected for the FSC calculation.

    Members are also ``str`` instances, so they compare equal to their
    plain string values.
    """

    # No mask is applied.
    none = "none"
    # Spherical mask.
    sphere = "sphere"


class TTFSCResult(BaseModel):
    """Inputs and results of one Fourier shell correlation calculation.

    Fields that default to ``None`` are optional results; consumers are
    expected to check them for ``None`` before use.
    """

    # The two input maps: file path plus the corresponding tensor.
    map1: Path
    map1_tensor: Tensor
    map2: Path
    map2_tensor: Tensor

    # Calculation settings.
    pixel_spacing_angstroms: float
    fsc_threshold: float

    # Masking settings. ``mask`` is annotated as ``str``; its default is the
    # ``Masking.none`` enum member, which is itself a ``str`` subclass.
    mask: str = Masking.none
    mask_filename: Optional[Path] = None
    mask_tensor: Optional[Tensor] = None
    mask_radius_angstroms: float = 50.0
    mask_soft_edge_width_pixels: int = 10

    # Number of entries in the per-shell tensors below.
    num_shells: int

    # Resolution estimates. The masked / corrected variants stay ``None``
    # unless the corresponding calculation was carried out.
    estimated_resolution_angstrom: float
    estimated_resolution_angstrom_unmasked: float
    estimated_resolution_angstrom_masked: Optional[float] = None
    estimated_resolution_angstrom_corrected: Optional[float] = None
    estimated_resolution_frequency_pixel: float

    # Per-shell axes and FSC curves.
    frequency_pixels: Tensor
    resolution_angstroms: Tensor
    fsc_values_unmasked: Tensor
    fsc_values_masked: Optional[Tensor] = None
    fsc_values_corrected: Optional[Tensor] = None
    fsc_values_masked_randomized: Optional[Tensor] = None
    fsc_values_randomized: Optional[Tensor] = None

    # Masking-correction parameters.
    correction_from_resolution_angstrom: Optional[float] = None
    correct_from_fraction_of_estimated_resolution: float = 0.5

    # NOTE(review): presumably required because ``Tensor`` is not a type
    # pydantic validates natively - confirm against the pydantic docs.
    model_config: dict = {"arbitrary_types_allowed": True}


class RelionDataGeneral(BaseModel):
    """General (non per-shell) values for the star file output.

    The resolution and the two half-map names are required. The remaining
    fields default to ``nan`` (or an empty string for the mask name) and are
    meant to be filled in only when the matching step was performed.
    """

    rlnFinalResolution: float
    rlnUnfilteredMapHalf1: str
    rlnUnfilteredMapHalf2: str
    rlnParticleBoxFractionSolventMask: float = nan
    rlnRandomiseFrom: float = nan
    rlnMaskName: str = ""


class RelionFSCData(BaseModel):
    """One row of per-shell FSC data for the star file output.

    Every field is required; callers supply ``nan`` for curves that were
    not calculated.
    """

    # Index of the shell this row describes.
    rlnSpectralIndex: int
    rlnResolution: float
    rlnAngstromResolution: float
    rlnFourierShellCorrelationCorrected: float
    rlnFourierShellCorrelationParticleMaskFraction: float
    rlnFourierShellCorrelationUnmaskedMaps: float
    rlnFourierShellCorrelationMaskedMaps: float
    rlnCorrectedFourierShellCorrelationPhaseRandomizedMaskedMaps: float


class RelionStarfile(BaseModel):
    """Complete star file content: one general block plus the per-shell rows."""

    data_general: RelionDataGeneral
    fsc_data: List[RelionFSCData]
Loading

0 comments on commit 07c8d10

Please sign in to comment.