Skip to content

Commit

Permalink
Modify Stamp storage class for multi-arch elements
Browse files Browse the repository at this point in the history
  • Loading branch information
leeskelvin committed Mar 19, 2024
1 parent ab1f63e commit 2d8382f
Showing 1 changed file with 75 additions and 20 deletions.
95 changes: 75 additions & 20 deletions python/lsst/meas/algorithms/stamps.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
from lsst.utils import doImport
from lsst.utils.introspection import get_full_type_name

DEFAULT_ARCHIVE_ELEMENT_NAME = "ELEMENT"


def writeFits(filename, stamps, metadata, type_name, write_mask, write_variance, write_archive=False):
"""Write a single FITS file containing all stamps.
Expand Down Expand Up @@ -65,25 +67,34 @@ def writeFits(filename, stamps, metadata, type_name, write_mask, write_variance,
metadata["N_STAMPS"] = len(stamps)
metadata["STAMPCLS"] = type_name
# Record version number in case of future code changes
metadata["VERSION"] = 1
metadata["VERSION"] = 2
# create primary HDU with global metadata
fitsFile = Fits(filename, "w")
fitsFile.createEmpty()
# Store Persistables in an OutputArchive and write it
if write_archive:
archive_ids = []
oa = OutputArchive()
archive_ids = [oa.put(stamp.archive_element) for stamp in stamps]
metadata["ARCHIVE_IDS"] = archive_ids
names = set()
for stamp in stamps:
archive_elements = stamp._getArchiveElements()
archive_ids.append({name: oa.put(persistable) for name, persistable in archive_elements.items()})
names.update(archive_elements.keys())
fitsFile.writeMetadata(metadata)
oa.writeFits(fitsFile)
else:
archive_ids = [None] * len(stamps)
fitsFile.writeMetadata(metadata)
fitsFile.closeFile()
# add all pixel data optionally writing mask and variance information
for i, stamp in enumerate(stamps):
for i, (stamp, stamp_archive_ids) in enumerate(zip(stamps, archive_ids)):
metadata = PropertyList()
# EXTVER should be 1-based, the index from enumerate is 0-based
metadata.update({"EXTVER": i + 1, "EXTNAME": "IMAGE"})
if stamp_archive_ids:
metadata.update(stamp_archive_ids)
for name in sorted(names):
metadata.add("ARCHIVE_ELEMENT", name)
stamp.stamp_im.getImage().writeFits(filename, metadata=metadata, mode="a")
if write_mask:
metadata = PropertyList()
Expand Down Expand Up @@ -131,8 +142,13 @@ def readFitsWithOptions(filename, stamp_factory, options):
metadata = readMetadata(filename, hdu=0)
nStamps = metadata["N_STAMPS"]
has_archive = metadata["HAS_ARCHIVE"]
archive_names = None
archive_ids_v1 = None
if has_archive:
archive_ids = metadata.getArray("ARCHIVE_IDS")
if metadata["VERSION"] < 2:
archive_ids_v1 = metadata.getArray("ARCHIVE_IDS")
else:
archive_names = metadata.getArray("ARCHIVE_ELEMENT")
with Fits(filename, "r") as f:
nExtensions = f.countHdus()
# check if a bbox was provided
Expand Down Expand Up @@ -165,6 +181,7 @@ def readFitsWithOptions(filename, stamp_factory, options):
variance_dtype = np.dtype(np.float32) # Variance is always the same type.

# We need to be careful because nExtensions includes the primary HDU.
archive_ids = {}
for idx in range(nExtensions - 1):
dtype = None
md = readMetadata(filename, hdu=idx + 1)
Expand All @@ -174,6 +191,8 @@ def readFitsWithOptions(filename, stamp_factory, options):
dtype = variance_dtype
else:
dtype = default_dtype
if archive_names is not None:
archive_ids[idx] = {name: md[name] for name in archive_names if name in md.keys()}
elif md["EXTNAME"] == "MASK":
reader = MaskFitsReader(filename, hdu=idx + 1)
elif md["EXTNAME"] == "ARCHIVE_INDEX":
Expand All @@ -184,8 +203,9 @@ def readFitsWithOptions(filename, stamp_factory, options):
continue
else:
raise ValueError(f"Unknown extension type: {md['EXTNAME']}")
stamp_parts.setdefault(md["EXTVER"], {})[md["EXTNAME"].lower()] = reader.read(dtype=dtype,
**kwargs)
stamp_parts.setdefault(md["EXTVER"], {})[md["EXTNAME"].lower()] = reader.read(
dtype=dtype, **kwargs
)
if len(stamp_parts) != nStamps:
raise ValueError(
f"Number of stamps read ({len(stamp_parts)}) does not agree with the "
Expand All @@ -196,8 +216,12 @@ def readFitsWithOptions(filename, stamp_factory, options):
for k in range(nStamps):
# Need to increment by one since EXTVER starts at 1
maskedImage = masked_image_cls(**stamp_parts[k + 1])
archive_element = archive.get(archive_ids[k]) if has_archive else None
stamps.append(stamp_factory(maskedImage, metadata, k, archive_element))
if archive_ids_v1 is not None:
archive_elements = {DEFAULT_ARCHIVE_ELEMENT_NAME: archive.get(archive_ids_v1[k])}
elif archive_names is not None:
stamp_archive_ids = archive_ids.get(k, {})
archive_elements = {name: archive.get(id) for name, id in stamp_archive_ids.items()}
stamps.append(stamp_factory(maskedImage, metadata, k, archive_elements))

return stamps, metadata

Expand All @@ -213,7 +237,7 @@ class AbstractStamp(abc.ABC):

@classmethod
@abc.abstractmethod
def factory(cls, stamp_im, metadata, index, archive_element=None):
def factory(cls, stamp_im, metadata, index, archive_elements=None):
"""This method is needed to service the FITS reader. We need a standard
interface to construct objects like this. Parameters needed to
construct this object are passed in via a metadata dictionary and then
Expand All @@ -228,15 +252,30 @@ def factory(cls, stamp_im, metadata, index, archive_element=None):
needed by the constructor.
idx : `int`
Index into the lists in ``metadata``
archive_element : `~lsst.afw.table.io.Persistable`, optional
Archive element (e.g. Transform or WCS) associated with this stamp.
archive_elements : `~collections.abc.Mapping`[ `str` , \
`~lsst.afw.table.io.Persistable`], optional
Archive elements (e.g. Transform / WCS) associated with this stamp.
Returns
-------
stamp : `AbstractStamp`
An instance of this class
"""
raise NotImplementedError
raise NotImplementedError()

@abc.abstractmethod
def _getMaskedImage(self):
"""Return the image data."""
raise NotImplementedError()

@abc.abstractmethod
def _getArchiveElements(self):
"""Return the archive elements.
Keys should be upper case names that will be used directly as FITS
header keys.
"""
raise NotImplementedError()


def _default_position():
Expand Down Expand Up @@ -265,7 +304,7 @@ class Stamp(AbstractStamp):
position: SpherePoint | None = field(default_factory=_default_position)

@classmethod
def factory(cls, stamp_im, metadata, index, archive_element=None):
def factory(cls, stamp_im, metadata, index, archive_elements=None):
"""This method is needed to service the FITS reader. We need a standard
interface to construct objects like this. Parameters needed to
construct this object are passed in via a metadata dictionary and then
Expand All @@ -283,14 +322,23 @@ def factory(cls, stamp_im, metadata, index, archive_element=None):
needed by the constructor.
idx : `int`
Index into the lists in ``metadata``
archive_element : `~lsst.afw.table.io.Persistable`, optional
Archive element (e.g. Transform or WCS) associated with this stamp.
archive_elements : `~collections.abc.Mapping`[ `str` , \
`~lsst.afw.table.io.Persistable`], optional
Archive elements (e.g. Transform / WCS) associated with this stamp.
Returns
-------
stamp : `Stamp`
An instance of this class
"""
if archive_elements:
try:
(archive_element,) = archive_elements.values()
except TypeError:
raise RuntimeError("Expected exactly one archive element.")
else:
archive_element = None

if "RA_DEG" in metadata and "DEC_DEG" in metadata:
return cls(
stamp_im=stamp_im,
Expand All @@ -307,6 +355,12 @@ def factory(cls, stamp_im, metadata, index, archive_element=None):
position=SpherePoint(Angle(np.nan), Angle(np.nan)),
)

def _getMaskedImage(self):
return self.stamp_im

def _getArchiveElements(self):
return {DEFAULT_ARCHIVE_ELEMENT_NAME: self.archive_element}


class StampsBase(abc.ABC, Sequence):
"""Collection of stamps and associated metadata.
Expand Down Expand Up @@ -437,17 +491,18 @@ def getMaskedImages(self):
maskedImages :
`list` [`~lsst.afw.image.MaskedImageF`]
"""
return [stamp.stamp_im for stamp in self._stamps]
return [stamp._getMaskedImage() for stamp in self._stamps]

def getArchiveElements(self):
"""Retrieve archive elements associated with each stamp.
Returns
-------
archiveElements :
`list` [`~lsst.afw.table.io.Persistable`]
archiveElements : `list` [`dict`[ `str`, \
`~lsst.afw.table.io.Persistable` ]]
A list of archive elements associated with each stamp.
"""
return [stamp.archive_element for stamp in self._stamps]
return [stamp._getArchiveElements() for stamp in self._stamps]

@property
def metadata(self):
Expand Down

0 comments on commit 2d8382f

Please sign in to comment.