import numpy as np
import pandas as pd
from libpyhat import spectral_data
[docs]
def reflectance(data, central_wvl, kernel=5):
    """
    Return the reflectance of ``data`` at ``central_wvl``.

    This is a thin wrapper: all the work is delegated to ``median_kernel``
    with the arguments passed through positionally.

    NOTE(review): ``median_kernel`` is not visible here; presumably it takes
    a median over a window of ``kernel`` bands around ``central_wvl`` —
    confirm against its definition.

    Parameters
    ----------
    data : object
        Spectral data understood by ``median_kernel``.
    central_wvl : float
        Wavelength at which to read the reflectance.
    kernel : int, optional
        Window parameter forwarded to ``median_kernel`` (default 5).

    Returns
    -------
    : object
        Whatever ``median_kernel`` returns.
    """
    return median_kernel(data, central_wvl, kernel)
[docs]
def compute_b_a(wavelengths):
    """
    Given a set of three wavelengths compute their b and a values as per
    the Viviano Beck CRISM Derived Products paper
    (Revised CRISM spectral parameters and summary
    products based on the currently detected
    mineral diversity on Mars).

    The wavelengths may be given in any order; they are sorted into
    short (s), center (c) and long (l) and then
    ``b = (lambda_c - lambda_s) / (lambda_l - lambda_s)`` and ``a = 1 - b``.

    Parameters
    ----------
    wavelengths : iterable
        Exactly three wavelength values (list, tuple, array, ...). The input
        is not modified.

    Returns
    -------
    b : float
        b value from the paper
    a : float
        a value from the paper

    Raises
    ------
    ValueError
        If ``wavelengths`` does not contain exactly three values.
    """
    # sorted() instead of list.sort(): works for any iterable (tuples,
    # generators) and does not reorder the caller's list as a side effect.
    lambda_s, lambda_c, lambda_l = sorted(wavelengths)
    b = (lambda_c - lambda_s) / (lambda_l - lambda_s)
    a = 1.0 - b
    return b, a
[docs]
def compute_slope(x1, x2, y1, y2):
    """
    Return the slope of the line through (x1, y1) and (x2, y2).

    Parameters
    ----------
    x1 : float
        x coordinate of the first point
    x2 : float
        x coordinate of the second point
    y1 : float
        y coordinate of the first point
    y2 : float
        y coordinate of the second point

    Returns
    -------
    : float
        ``(y2 - y1) / (x2 - x1)``; no guard is applied when ``x1 == x2``.
    """
    rise = y2 - y1
    run = x2 - x1
    return rise / run
[docs]
def line_fit(slope, x, b):
    """
    Evaluate the line ``y = slope * x + b`` at ``x``.

    Parameters
    ----------
    slope : float
        Slope of the line
    x : float
        Position along the x axis
    b : float
        y intercept of the line

    Returns
    -------
    : float
        The y value of the line at ``x``.
    """
    scaled = slope * x
    return scaled + b
[docs]
def continuum_reflectance(data, central_wvl, left_wvl, right_wvl):
    """
    Estimate the continuum reflectance at ``central_wvl``.

    Reflectance is read at ``left_wvl`` and ``right_wvl`` (with
    ``kernel=0``). A straight line is drawn between those two points, and
    the line is evaluated at ``central_wvl``.

    Parameters
    ----------
    data : object
        Spectral data accepted by ``reflectance``.
    central_wvl : float
        Wavelength at which the continuum is evaluated.
    left_wvl : float
        Wavelength of the left (short-side) anchor point.
    right_wvl : float
        Wavelength of the right (long-side) anchor point.

    Returns
    -------
    : float
        Linearly interpolated reflectance at ``central_wvl``.
    """
    refl_left = reflectance(data, left_wvl, kernel=0)
    refl_right = reflectance(data, right_wvl, kernel=0)
    gradient = compute_slope(left_wvl, right_wvl, refl_left, refl_right)
    # The line is anchored at the left point, so measure x from left_wvl.
    offset = central_wvl - left_wvl
    return line_fit(gradient, offset, refl_left)
[docs]
def clip_multiband(multiband_image, n):
    """
    Clip each band of a multi-band image to mean +/- n stdevs and rescale
    it to the range [0, 1].

    For every band:
    - Values below ``mean - n*std`` are raised to that bound.
    - Values above ``mean + n*std`` are lowered to that bound.
    - The band is then linearly rescaled so the lower bound maps to 0 and
      the upper bound maps to 1.

    NaNs are ignored when computing the statistics and stay NaN in the
    output.

    Parameters
    ----------
    multiband_image : numpy.ndarray
        Array representing a multiband image with dimensions [x, y, bands].
        Floating-point input is modified in place (and returned). Any other
        dtype is first converted to a float copy, because writing the
        rescaled values back into an integer array would truncate them
        to 0/1.
    n : float
        The number of stdevs to keep on either side of each band's mean.

    Returns
    -------
    multiband_image : numpy.ndarray
        The image with every band clipped and rescaled to [0, 1]. A band
        with zero spread (e.g. constant) is set to 0 (NaNs preserved)
        instead of producing NaN through a division by zero.
    """
    if not np.issubdtype(multiband_image.dtype, np.floating):
        multiband_image = multiband_image.astype(float)
    for i in range(multiband_image.shape[2]):
        band = multiband_image[:, :, i]
        # calculate the mean and stdev of each band, ignoring NaNs
        band_std = np.nanstd(band)
        band_mean = np.nanmean(band)
        # clip the band to +/- n stdevs around the mean
        band_min = band_mean - n * band_std
        band_max = band_mean + n * band_std
        # raise-then-lower, the same order as masked assignment; NaN propagates
        clipped = np.minimum(np.maximum(band, band_min), band_max)
        span = band_max - band_min
        if span != 0:
            multiband_image[:, :, i] = (clipped - band_min) / span
        else:
            # Zero spread: nothing to stretch; avoid 0/0 but keep NaNs.
            multiband_image[:, :, i] = np.where(np.isnan(band), np.nan, 0.0)
    return multiband_image
[docs]
def get_roi(data, roi_x, roi_y, meta_label="meta"):
    """
    Extract an ROI from a Spectral Data object.

    Parameters
    ----------
    data : PyHAT Spectral Data object
        Object exposing ``df``, ``meta_label`` and ``spect_label``.
    roi_x : sequence
        Inclusive min and max x values for the ROI ``[x_min, x_max]``.
    roi_y : sequence
        Inclusive min and max y values for the ROI ``[y_min, y_max]``.
    meta_label : str, optional
        Top-level column label holding the ``x``/``y`` coordinates used for
        the selection (default ``"meta"``). NOTE(review): the returned
        object is built with ``data.meta_label``, not this argument.

    Returns
    -------
    data_roi : Spectral data object
        The spectra within the ROI. Extracting an ROI strips any geodata
        associated with the original object. The x and y values are shifted
        so that the smallest x and y inside the ROI become 0.
    """
    x_col = (meta_label, "x")
    y_col = (meta_label, "y")
    x = data.df[x_col]
    y = data.df[y_col]
    in_roi = (
        (x >= roi_x[0])
        & (x <= roi_x[1])
        & (y >= roi_y[0])
        & (y <= roi_y[1])
    )
    # Explicit copy: the shifts below must write to an independent frame,
    # not a slice of data.df (avoids chained assignment/SettingWithCopy).
    roi_df = data.df[in_roi].copy()
    roi_df[x_col] = roi_df[x_col] - np.min(roi_df[x_col])
    roi_df[y_col] = roi_df[y_col] - np.min(roi_df[y_col])
    data_roi = spectral_data.SpectralData(
        roi_df, meta_label=data.meta_label, spect_label=data.spect_label, geodata=None
    )
    return data_roi
[docs]
def get_roi_avg(data, roi_x, roi_y, meta_label="meta"):
    """
    Calculate the average spectrum within a specified ROI.

    Parameters
    ----------
    data : PyHAT Spectral Data object
        Object accepted by ``get_roi``.
    roi_x : sequence
        Inclusive min and max x values for the ROI ``[x_min, x_max]``.
    roi_y : sequence
        Inclusive min and max y values for the ROI ``[y_min, y_max]``.
    meta_label : str, optional
        Label of the coordinate columns, forwarded to ``get_roi``
        (default ``"meta"``, which matches the previous behavior).

    Returns
    -------
    roi_avg : pandas.Series
        Column-wise mean of the ROI's spectral columns.
    """
    data_roi = get_roi(data, roi_x, roi_y, meta_label=meta_label)
    roi_avg = data_roi.df[data_roi.spect_label].mean(axis=0)
    return roi_avg
[docs]
def get_sub_spectrum(spectrum, low_wvl=None, high_wvl=None):
    """
    Grab a subdomain of spectrum according to starting and ending
    wavelength values. Endpoints are included.

    Parameters
    ----------
    spectrum : pandas.Series
        Single spectrum whose index holds the wavelength values.
    low_wvl : float, optional
        The low wavelength (in units of the actual wavelength, not the
        index position). Defaults to the smallest wavelength.
    high_wvl : float, optional
        The high wavelength (in units of the actual wavelength, not the
        index position). Defaults to the largest wavelength.

    Returns
    -------
    sub_spectrum : pandas.Series
        The entries of ``spectrum`` whose wavelength lies in
        ``[low_wvl, high_wvl]``. If ``low_wvl`` is not strictly less than
        ``high_wvl``, a message is printed and ``spectrum`` itself (the same
        object) is returned.
    """
    wvls = spectrum.index.values  # spectrum is a series with wvls as indices
    # If start and/or end values of the subdomain weren't provided, then just
    # use the start and/or end of the spectrum
    if low_wvl is None:
        low_wvl = np.min(wvls)
    if high_wvl is None:
        high_wvl = np.max(wvls)
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, and the bare AssertionError printed an empty message.
    if not low_wvl < high_wvl:
        print(
            f"Invalid wavelength range: low_wvl ({low_wvl}) must be less "
            f"than high_wvl ({high_wvl})."
        )
        print("\nReturning the original spectrum.")
        return spectrum
    # Cut out the subdomain in the spectrum
    in_range = (wvls >= low_wvl) & (wvls <= high_wvl)
    return spectrum[in_range]
[docs]
def band_minimum(spectrum, low_wvl=None, high_wvl=None):
    """
    Find the smallest value of a 1-D spectrum and the wavelength where it
    occurs, optionally restricted to ``[low_wvl, high_wvl]``.

    Parameters
    ----------
    spectrum : pandas.Series
        Single spectrum whose index holds the wavelength values.
    low_wvl : float, optional
        Lower wavelength bound, passed to ``get_sub_spectrum``.
    high_wvl : float, optional
        Upper wavelength bound, passed to ``get_sub_spectrum``.

    Returns
    -------
    min_wvl : index label
        Wavelength (index label) of the first occurrence of the minimum.
    min_value : float
        The minimum value found.
    """
    window = get_sub_spectrum(spectrum, low_wvl=low_wvl, high_wvl=high_wvl)
    lowest = np.amin(window)
    at_lowest = window == lowest
    # First matching label wins when the minimum occurs more than once.
    lowest_wvl = window.index[at_lowest][0]
    return lowest_wvl, lowest
[docs]
def band_center(spectrum, low_wvl=None, high_wvl=None, degree=3):
    """
    Fit a polynomial to (a subdomain of) a spectrum and locate its minimum.

    The polynomial is evaluated only at the spectrum's own wavelengths. The
    "center" is therefore the sample wavelength where the fitted curve is
    lowest. High-degree polynomials can make this unreliable, so plot the
    fit to check it.

    Parameters
    ----------
    spectrum : pandas.Series
        Single spectrum whose index holds the wavelength values.
    low_wvl : float, optional
        Lower wavelength bound of the fit, passed to ``get_sub_spectrum``.
    high_wvl : float, optional
        Upper wavelength bound of the fit, passed to ``get_sub_spectrum``.
    degree : int, optional
        Degree of the fitted polynomial (default 3).

    Returns
    -------
    centerwvl : float
        Wavelength at which the fitted polynomial is smallest.
    centerval : float
        Value of the fitted polynomial at ``centerwvl``.
    center_fit : pandas.Series
        Fitted values indexed by wavelength (as floats).
    """
    subset = get_sub_spectrum(spectrum, low_wvl=low_wvl, high_wvl=high_wvl)
    wvl_grid = np.array(subset.index.values, dtype=float)
    intensities = np.array(subset, dtype=float)
    poly = np.polynomial.Polynomial.fit(wvl_grid, intensities, degree)
    fitted = pd.Series(poly(wvl_grid), index=wvl_grid)
    center_wvl, center_val = band_minimum(fitted)
    return center_wvl, center_val, fitted
[docs]
def band_area(spectrum, low_wvl=None, high_wvl=None, mask_threshold=1.0):
    """
    Compute the area under the spectrum with the trapezoidal rule. The user
    can specify a subsection, or 'band', of the spectrum to integrate over.

    Parameters
    ----------
    spectrum : pandas.Series
        Single spectrum whose index holds the wavelength values.
    low_wvl : float, optional
        The low wavelength (in units of the actual wavelength, not the
        index position).
    high_wvl : float, optional
        The high wavelength (in units of the actual wavelength, not the
        index position).
    mask_threshold : float or None, optional
        Spectrum values above this threshold are set to 0 before
        integration, which is useful for excluding bands with bad data
        (default 1.0). Pass ``None`` to disable masking. The input
        ``spectrum`` is never modified.

    Returns
    -------
    : float
        The area under the curve.
    """
    # Copy: get_sub_spectrum can return ``spectrum`` itself (invalid range),
    # and the masking below must not zero out the caller's data.
    sub_spectrum = get_sub_spectrum(spectrum, low_wvl, high_wvl).copy()
    if mask_threshold is not None:
        sub_spectrum[sub_spectrum > mask_threshold] = 0
    # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid.
    trapezoid = getattr(np, "trapezoid", None)
    if trapezoid is None:  # NumPy < 2.0
        trapezoid = np.trapz
    return trapezoid(
        np.array(sub_spectrum, dtype=float),
        x=np.array(sub_spectrum.index.values, dtype=float),
    )
[docs]
def band_asymmetry(spectrum, low_wvl=None, high_wvl=None, degree=3):
    """
    Compute the asymmetry of an absorption feature as
    ``abs((left_area - right_area) / (left_area + right_area))``. The
    split point is the band center found by ``band_center``.

    Parameters
    ----------
    spectrum : pandas.Series
        Single spectrum whose index holds the wavelength values.
    low_wvl : float, optional
        The low wavelength (in units of the actual wavelength, not the
        index position).
    high_wvl : float, optional
        The high wavelength (in units of the actual wavelength, not the
        index position).
    degree : int, optional
        The degree of the polynomial used to locate the band center.

    Returns
    -------
    asymmetry : float
        Value indicating how asymmetrical the two halves of the band are,
        where 1.0 is completely asymmetrical and 0.0 is completely
        symmetrical. NaN if the total area is zero.
    """
    sub_spectrum = get_sub_spectrum(spectrum, low_wvl, high_wvl)
    center, _, _ = band_center(
        sub_spectrum, low_wvl=None, high_wvl=None, degree=degree
    )
    wvls = np.asarray(sub_spectrum.index.values, dtype=float)
    # When the center sits on an endpoint, that half is a single point with
    # zero area. It must be special-cased, because get_sub_spectrum falls
    # back to the *whole* spectrum when low_wvl >= high_wvl.
    if center > wvls.min():
        area_left = band_area(sub_spectrum, high_wvl=center)
    else:
        area_left = 0.0
    if center < wvls.max():
        area_right = band_area(sub_spectrum, low_wvl=center)
    else:
        area_right = 0.0
    total_area = area_left + area_right
    if total_area == 0:
        return np.nan
    asymmetry = abs((area_left - area_right) / total_area)
    return asymmetry