"""Compound low-level WCS wrapper (``ndcube.wcs.wrappers.compound_wcs``)."""
from functools import reduce
import numpy as np
from astropy.wcs.wcsapi.wrappers.base import BaseWCSWrapper
# Only the compound wrapper is exported by ``from ... import *``; ``tuplesum``
# and ``Mapping`` below are module-internal helpers.
__all__ = ['CompoundLowLevelWCS']
def tuplesum(lists):
    """
    Concatenate an iterable of sequences into a single flat tuple.

    Parameters
    ----------
    lists : iterable of iterables
        The sequences to concatenate, in order. Each one is converted with
        ``tuple`` before joining.

    Returns
    -------
    tuple
        All elements of all inputs, in order. An empty input gives ``()``.
    """
    # The ``()`` initializer makes an empty ``lists`` return an empty tuple
    # instead of raising ``TypeError: reduce() of empty iterable``.
    return reduce(tuple.__add__, map(tuple, lists), ())
class Mapping:
    """
    Allows inputs to be reordered, duplicated or dropped.

    This is a very stripped down version of `astropy.modeling.models.Mapping`
    to be able to handle input of arbitrary type.

    Parameters
    ----------
    mapping : tuple
        A tuple of integers representing indices of the inputs to this model
        to return and in what order to return them. See
        :ref:`compound-model-mappings` for more details.
    """

    def __init__(self, mapping):
        self.mapping = mapping
        # The highest referenced input index fixes how many inputs are expected.
        self.n_inputs = max(mapping) + 1
        self.n_outputs = len(mapping)

    def __call__(self, *values):
        """Return ``values`` picked (and possibly repeated) in mapping order."""
        return tuple(values[idx] for idx in self.mapping)

    @property
    def inverse(self):
        """
        A `Mapping` that undoes this one.

        Each input is recovered from its first occurrence among the outputs,
        so duplicated inputs collapse back to a single value.

        Raises
        ------
        ValueError
            If this mapping drops one or more inputs, since those cannot be
            recovered from the outputs.
        """
        # Work on a tuple copy so any integer sequence (e.g. an ndarray) works.
        mapping = tuple(self.mapping)
        dropped = [idx for idx in range(self.n_inputs) if idx not in mapping]
        if dropped:
            raise ValueError(
                f"{self!r} drops input(s) {dropped} and is therefore not invertible.")
        return type(self)(tuple(mapping.index(idx) for idx in range(self.n_inputs)))

    def __repr__(self):
        return f'<Mapping({self.mapping})>'
class CompoundLowLevelWCS(BaseWCSWrapper):
    """
    A wrapper that takes multiple low level WCS objects and makes a compound
    WCS that combines them.

    The world dimensions of the compound WCS are the concatenation of the
    world dimensions of the input WCSes, in order. Pixel dimensions may be
    shared between input WCSes via ``mapping``.

    Parameters
    ----------
    *wcs : `~astropy.wcs.wcsapi.BaseLowLevelWCS`
        The WCSes to combine.
    mapping : `tuple`, optional
        The mapping from the pixel dimensions of the compound WCS to the pixel
        dimensions of the underlying WCSes. This should have length equal to
        the total number of pixel dimensions in all input WCSes and have a
        maximum of the number of pixel dimensions of the resulting compound
        WCS -1 (counts from 0). For example ``(0, 1, 2, 1)`` would end up with
        the second and fourth pixel dimensions in the input WCSes being
        shared, so the compound WCS would have 3 pixel dimensions ``(2 + 1)``.
        If not given, no pixel dimensions are shared.
        See :ref:`compound-model-mappings` for more examples of this input
        format.
    pixel_atol : `float`, optional
        A tolerance used to check that the resulting pixel coordinates from
        ``world_to_pixel`` are the same from all WCSes. Defaults to ``1e-8``.

    Raises
    ------
    ValueError
        If ``mapping`` does not have one entry per input pixel dimension, or
        if the pixel shapes or bounds of the input WCSes disagree on a shared
        pixel dimension.
    """

    def __init__(self, *wcs, mapping=None, pixel_atol=1e-8):
        self._wcs = wcs
        if not mapping:
            # Default: no shared axes; the compound pixel dimensions are the
            # underlying WCS pixel dimensions concatenated in order.
            mapping = tuple(range(self._all_pixel_n_dim))
        if not len(mapping) == self._all_pixel_n_dim:
            raise ValueError(
                "The length of the mapping must equal the total number of pixel dimensions in all input WCSes.")
        self.mapping = Mapping(mapping)
        self.atol = pixel_atol
        # Validate the pixel bounds and shape are consistent; both properties
        # raise ValueError when shared dimensions disagree.
        self.pixel_bounds
        self.pixel_shape

    @property
    def _all_pixel_n_dim(self):
        # Total pixel dimensions of the underlying WCSes, counting shared ones
        # once per WCS.
        return sum(w.pixel_n_dim for w in self._wcs)

    @property
    def pixel_n_dim(self):
        return self.mapping.n_inputs

    @property
    def world_n_dim(self):
        return sum(w.world_n_dim for w in self._wcs)

    @property
    def world_axis_physical_types(self):
        return tuplesum([w.world_axis_physical_types for w in self._wcs])

    @property
    def world_axis_units(self):
        return tuplesum([w.world_axis_units for w in self._wcs])

    def pixel_to_world_values(self, *pixel_arrays):
        """
        Convert compound pixel coordinates to world values.

        The compound pixel inputs are expanded through ``mapping`` into the
        pixel inputs of each underlying WCS; the world outputs of every WCS
        are then concatenated in order. Following the low-level WCS API, a
        single array is returned when ``world_n_dim`` is 1, otherwise a tuple.
        """
        pixel_arrays = self.mapping(*pixel_arrays)
        world_arrays = []
        for w in self._wcs:
            pixel_arrays_sub = pixel_arrays[:w.pixel_n_dim]
            world_arrays_sub = w.pixel_to_world_values(*pixel_arrays_sub)
            # A 1-D world result is a bare array, not a tuple.
            if w.world_n_dim > 1:
                world_arrays.extend(world_arrays_sub)
            else:
                world_arrays.append(world_arrays_sub)
            pixel_arrays = pixel_arrays[w.pixel_n_dim:]
        if self.world_n_dim == 1:
            return world_arrays[0]
        return tuple(world_arrays)

    def world_to_pixel_values(self, *world_arrays):
        """
        Convert world values to compound pixel coordinates.

        Each underlying WCS converts its slice of ``world_arrays``. Underlying
        pixel dimensions that share a compound pixel dimension must agree to
        within ``pixel_atol`` (via `numpy.allclose`, NaNs comparing equal).
        Following the low-level WCS API, a single array is returned when
        ``pixel_n_dim`` is 1, otherwise a tuple.

        Raises
        ------
        ValueError
            If shared pixel dimensions disagree.
        """
        pixel_arrays = []
        for w in self._wcs:
            world_arrays_sub = world_arrays[:w.world_n_dim]
            world_arrays = world_arrays[w.world_n_dim:]
            pixel_arrays_sub = w.world_to_pixel_values(*world_arrays_sub)
            # A 1-D pixel result is a bare array, not a tuple.
            if w.pixel_n_dim > 1:
                pixel_arrays.extend(pixel_arrays_sub)
            else:
                pixel_arrays.append(pixel_arrays_sub)
        # Compare as an array: ``tuple == int`` is always False, which would
        # silently skip the shared-axis consistency check.
        mapping = np.asarray(self.mapping.mapping)
        for mapped_axis in set(self.mapping.mapping):
            idx, = (mapping == mapped_axis).nonzero()
            idx_0 = idx[0]
            for idx_n in idx[1:]:
                if not np.allclose(pixel_arrays[idx_0], pixel_arrays[idx_n],
                                   atol=self.atol, equal_nan=True):
                    raise ValueError(
                        "The world inputs for shared pixel axes did not result in a pixel "
                        f"coordinate to within {self.atol} relative accuracy."
                    )
        pixel_arrays = self.mapping.inverse(*pixel_arrays)
        if self.pixel_n_dim == 1:
            return pixel_arrays[0]
        return pixel_arrays

    @property
    def world_axis_object_components(self):
        # Component names are suffixed with the WCS index so they line up with
        # the keys produced by ``world_axis_object_classes``.
        all_components = []
        for iw, w in enumerate(self._wcs):
            all_components += [(f'{component[0]}_{iw}',) + component[1:] for component
                               in w.world_axis_object_components]
        return all_components

    @property
    def world_axis_object_classes(self):
        # TODO: deal with name conflicts
        all_classes = {}
        for iw, w in enumerate(self._wcs):
            for key, value in w.world_axis_object_classes.items():
                all_classes[f'{key}_{iw}'] = value
        return all_classes

    @property
    def pixel_shape(self):
        # None unless every underlying WCS has a known shape.
        if not any(w.array_shape is None for w in self._wcs):
            pixel_shape = tuplesum(w.pixel_shape for w in self._wcs)
            out_shape = self.mapping.inverse(*pixel_shape)
            for i, ix in enumerate(self.mapping.mapping):
                if out_shape[ix] != pixel_shape[i]:
                    raise ValueError(
                        "The pixel shapes of the supplied WCSes do not match "
                        "for the dimensions shared by the supplied mapping.")
            return out_shape
        return None

    @staticmethod
    def _bounds_unset(bounds):
        # ``None`` or an empty sequence both mean "no bounds specified".
        return bounds is None or len(bounds) == 0

    @property
    def pixel_bounds(self):
        """
        Bounds of each compound pixel dimension, or None if no input WCS has
        bounds.

        Bounds for a shared dimension are taken from whichever input WCSes
        specify them, regardless of order; dimensions without bounds get
        ``(iinfo(int).min, iinfo(int).max)``.
        """
        if all(w.pixel_bounds is None for w in self._wcs):
            return None
        # One entry per underlying pixel dimension; ``()`` marks a dimension
        # of a WCS that has no bounds.
        pixel_bounds = tuplesum(
            [()] * w.pixel_n_dim if self._bounds_unset(w.pixel_bounds) else w.pixel_bounds
            for w in self._wcs)
        out_bounds = [()] * self.pixel_n_dim
        for bounds, ix in zip(pixel_bounds, self.mapping.mapping):
            if self._bounds_unset(bounds):
                continue
            if self._bounds_unset(out_bounds[ix]):
                out_bounds[ix] = bounds
            # Compare as tuples so list- and tuple-valued bounds can match.
            elif tuple(out_bounds[ix]) != tuple(bounds):
                raise ValueError(
                    "The pixel bounds of the supplied WCSes do not match "
                    "for the dimensions shared by the supplied mapping.")
        iint = np.iinfo(int)
        return tuple((iint.min, iint.max) if self._bounds_unset(o) else o for o in out_bounds)

    @property
    def pixel_axis_names(self):
        """
        Names of the compound pixel dimensions.

        When several underlying dimensions share a compound dimension, their
        distinct non-empty names are joined with ``' / '``.
        """
        pixel_names = tuplesum(w.pixel_axis_names for w in self._wcs)
        # Build per-axis lists rather than assigning into the (immutable)
        # tuple returned by ``Mapping.inverse``.
        names_per_axis = [[] for _ in range(self.pixel_n_dim)]
        for name, ix in zip(pixel_names, self.mapping.mapping):
            if name and name not in names_per_axis[ix]:
                names_per_axis[ix].append(name)
        return tuple(' / '.join(names) for names in names_per_axis)

    @property
    def world_axis_names(self):
        return tuplesum(w.world_axis_names for w in self._wcs)

    @property
    def axis_correlation_matrix(self):
        # Block-diagonal matrix over all underlying pixel dimensions ...
        full_matrix = np.zeros((self.world_n_dim, self._all_pixel_n_dim), dtype=bool)
        iw = ip = 0
        for w in self._wcs:
            full_matrix[iw:iw + w.world_n_dim, ip:ip + w.pixel_n_dim] = w.axis_correlation_matrix
            iw += w.world_n_dim
            ip += w.pixel_n_dim
        # ... then OR together the columns that map to the same compound
        # pixel dimension.
        matrix = np.zeros((self.world_n_dim, self.pixel_n_dim), dtype=bool)
        for i, ix in enumerate(self.mapping.mapping):
            matrix[:, ix] = np.logical_or(matrix[:, ix], full_matrix[:, i])
        return matrix

    @property
    def serialized_classes(self):
        return any(w.serialized_classes for w in self._wcs)