import numpy as np
import pandas as pd
from libpyhat.derived.m3.m3_funcs import bd_func, oneum_continuum, twoum_continuum
from libpyhat.derived.m3.m3_funcs import (
    warn_m3,
    warn_m3_noisy,
    warn_m3_slow,
)
from libpyhat.derived.utils import (
    reflectance,
    get_roi,
    compute_slope,
)
from libpyhat.transform.continuum import continuum_correction
# REFLECTANCE
[docs]
def r540(data):
    """
    Name: R540
    Parameter: 0.54 um reflectance
    Formulation:
    R540 = R539
    Rationale: Reference I/F
    Bands: R539

    Parameters
    ----------
    data : PyHAT SpectralData object

    Returns
    -------
     data: the same PyHAT SpectralData object, with the column
     ("parameter", "R540") added to ``data.df``
    """
    # kernel=0 is passed straight through to the reflectance helper.
    band_value = reflectance(data, 539, kernel=0)
    data.df[("parameter", "R540")] = band_value
    return data
[docs]
def r750(data):
    """
    Name: R750
    Parameter: 0.75 um reflectance
    Formulation:
    R750 = R749
    Rationale: Reference I/F
    Bands: R749

    Parameters
    ----------
    data : PyHAT SpectralData object

    Returns
    -------
     data: the same PyHAT SpectralData object, with the column
     ("parameter", "R750") added to ``data.df``
    """
    # kernel=0 is passed straight through to the reflectance helper.
    band_value = reflectance(data, 749, kernel=0)
    data.df[("parameter", "R750")] = band_value
    return data
[docs]
def r1580(data):
    """
    Name: R1580
    Parameter: 1.6 um reflectance
    Formulation:
    R1580 = R1579
    Rationale: IR Albedo
    Bands: R1579

    Parameters
    ----------
    data : PyHAT SpectralData object

    Returns
    -------
     data: the same PyHAT SpectralData object, with the column
     ("parameter", "R1580") added to ``data.df``
    """
    # kernel=0 is passed straight through to the reflectance helper.
    band_value = reflectance(data, 1579, kernel=0)
    data.df[("parameter", "R1580")] = band_value
    return data
[docs]
def r2780(data):
    """
    Name: R2780
    Parameter: 2.8 um reflectance
    Formulation:
    R2780 = R2778
    Rationale: Reference I/F
    Bands: R2778

    Parameters
    ----------
    data : PyHAT SpectralData object

    Returns
    -------
     data: the same PyHAT SpectralData object, with the column
     ("parameter", "R2780") added to ``data.df``
    """
    # kernel=0 is passed straight through to the reflectance helper.
    band_value = reflectance(data, 2778, kernel=0)
    data.df[("parameter", "R2780")] = band_value
    return data
# RATIOS
[docs]
def visnir(data):
    """
    Name: VISNIR
    Parameter: Visible-nearIR Ratio
    Formulation:
    VISNIR = R699/R1579
    Rationale: Optical Maturity and mare-highland
    Bands: R699, R1579

    Parameters
    ----------
    data : PyHAT SpectralData object

    Returns
    -------
     data: the same PyHAT SpectralData object, with the column
     ("parameter", "VISNIR") added to ``data.df``
    """
    numerator = reflectance(data, 699, kernel=0)
    denominator = reflectance(data, 1579, kernel=0)
    ratio = numerator / denominator
    data.df[("parameter", "VISNIR")] = ratio
    return data
[docs]
def r950_750(data):
    """
    Name: R950_750
    Parameter: Ratio of 950nm to 750nm, mafic absorption
    Formulation:
    R950_750 = R949/R749
    Rationale: Quick look at mafic absorption
    Bands: R749, R949

    Parameters
    ----------
    data : PyHAT SpectralData object

    Returns
    -------
     data: the same PyHAT SpectralData object, with the column
     ("parameter", "R950_750") added to ``data.df``
    """
    denominator = reflectance(data, 749, kernel=0)
    numerator = reflectance(data, 949, kernel=0)
    ratio = numerator / denominator
    data.df[("parameter", "R950_750")] = ratio
    return data
[docs]
def twoum_ratio(data):
    """
    Name: 2um_Ratio
    Parameter: 2 um ratio
    Formulation:
    2um_Ratio = R1578/R2538
    Bands: R1578, R2538

    Parameters
    ----------
    data : PyHAT SpectralData object

    Returns
    -------
     data: the same PyHAT SpectralData object, with the column
     ("parameter", "2um_ratio") added to ``data.df``
    """
    numerator = reflectance(data, 1578, kernel=0)
    denominator = reflectance(data, 2538, kernel=0)
    ratio = numerator / denominator
    data.df[("parameter", "2um_ratio")] = ratio
    return data
[docs]
def thermal_ratio(data):
    """
    Name: Thermal_Ratio
    Formulation:
    Thermal_Ratio = R2538/R2978
    Bands: R2538, R2978

    Parameters
    ----------
    data : PyHAT SpectralData object

    Returns
    -------
     data: the same PyHAT SpectralData object, with the column
     ("parameter", "Thermal_Ratio") added to ``data.df``
    """
    numerator = reflectance(data, 2538, kernel=0)
    denominator = reflectance(data, 2978, kernel=0)
    ratio = numerator / denominator
    data.df[("parameter", "Thermal_Ratio")] = ratio
    return data
# SLOPES
[docs]
def visslope(data):
    """
    Name: Vis_Slope
    Parameter: UV-visible continuum slope
    Formulation:
    Vis_Slope = (R749 - R419) / (749 - 419)
    Rationale: UV-Vis Slope (%/nm)
    Bands: R419, R749

    Parameters
    ----------
    data : PyHAT SpectralData object

    Returns
    -------
     data: the same PyHAT SpectralData object, with the column
     ("parameter", "vis_slope") added to ``data.df``
    """
    # Resolve the nominal band centers with the object's own wavelength
    # lookup, then use the resolved wavelengths for both the reflectance
    # lookups and the slope computation.
    resolved = data.closest_wvl([419, 749])
    short_wvl = resolved[0]
    long_wvl = resolved[1]
    short_refl = reflectance(data, short_wvl, kernel=0)
    long_refl = reflectance(data, long_wvl, kernel=0)
    slope = compute_slope(short_wvl, long_wvl, short_refl, long_refl)
    data.df[("parameter", "vis_slope")] = slope
    return data
[docs]
def oneum_slope(data):
    """
    Name: 1um_Slope
    Parameter: continuum slope between 0.70 and 1.6 um
    Formulation:
    1um_Slope = (R1579 - R699) / (1579 - 699)
    Rationale: Vis-NIR Slope (%/nm)
    Bands: R699, R1579

    Parameters
    ----------
    data : PyHAT SpectralData object

    Returns
    -------
     data: the same PyHAT SpectralData object, with the column
     ("parameter", "1um_slope") added to ``data.df``
    """
    # Resolve the nominal band centers with the object's own wavelength
    # lookup, then use the resolved wavelengths for both the reflectance
    # lookups and the slope computation.
    resolved = data.closest_wvl([699, 1579])
    short_wvl = resolved[0]
    long_wvl = resolved[1]
    short_refl = reflectance(data, short_wvl, kernel=0)
    long_refl = reflectance(data, long_wvl, kernel=0)
    slope = compute_slope(short_wvl, long_wvl, short_refl, long_refl)
    data.df[("parameter", "1um_slope")] = slope
    return data
[docs]
def twoum_slope(data):
    """
    Name: 2um_Slope
    Parameter: continuum slope between 1.6 and 2.5 um
    Formulation:
    2um_Slope = (R2538 - R1578) / (2538 - 1578)
    Rationale: NIR Slope
    Bands: R1578, R2538

    Parameters
    ----------
    data : PyHAT SpectralData object

    Returns
    -------
     data: the same PyHAT SpectralData object, with the column
     ("parameter", "2um_slope") added to ``data.df``
    """
    # Use the 1578/2538 nm anchors named in the formulation. The previous
    # code requested 699/1579 nm here, which made this parameter a
    # duplicate of 1um_slope.
    band_wvls = data.closest_wvl([1578, 2538])
    low_r = reflectance(data, band_wvls[0], kernel=0)
    high_r = reflectance(data, band_wvls[1], kernel=0)
    data.df[("parameter", "2um_slope")] = compute_slope(
        band_wvls[0], band_wvls[1], low_r, high_r
    )
    return data
# BAND DEPTHS
[docs]
def bd620(data):
    """
    Name: BD620
    Parameter: Band Depth at 620 nm
    Formulation:
    Numerator = R619
    Denominator = ((R749 - R419) / (749 - 419)) * (619 - 419) + R419
    BD620 = 1 - [Numerator/Denominator]
    Rationale: Possible Ti or Impact Melt
    Bands: R419, R619, R749

    The band-depth computation itself is not active in this function: it
    only calls ``warn_m3("BD620")`` and does not touch ``data``.

    Parameters
    ----------
    data : unused

    Returns
    -------
     None
    """
    # Disabled implementation, kept for reference:
    #   wavelengths = [419, 619, 749]
    #   return utils.generic_func(data, wavelengths,
    #                             func=m3_funcs.bd_func, pass_wvs=True,
    #                             **kwargs)
    warn_m3("BD620")
    return None
[docs]
def bd950(data):
    """
    Name: BD950
    Parameter: Band Depth at 950 nm
    Formulation:
    Numerator = R949
    Denominator = ((R1579 - R749) / (1579 - 749)) * (949 - 749) + R749
    BD950 = 1 - [Numerator/Denominator]
    Rationale: OPX Comparison with Kaguya
    Bands: R749, R949, R1579

    Parameters
    ----------
    data : PyHAT SpectralData object

    Returns
    -------
     data: the same PyHAT SpectralData object, with the column
     ("parameter", "BD950") added to ``data.df``
    """
    warn_m3_noisy("BD950")
    # Order passed to bd_func: short anchor, band center, long anchor.
    resolved_wvls = data.closest_wvl([749, 949, 1579])
    depth = bd_func(data, resolved_wvls)
    data.df[("parameter", "BD950")] = depth
    return data
[docs]
def bd1050(data):
    """
    Name: BD1050
    Parameter: Band Depth at 1050 nm
    Formulation:
    Numerator = R1049
    Denominator = ((R1579 - R749) / (1579 - 749)) * (1049 - 749) + R749
    BD1050 = 1 - [Numerator/Denominator]
    Rationale: OLV Comparison with Kaguya
    Bands: R749, R1049, R1579

    Parameters
    ----------
    data : PyHAT SpectralData object

    Returns
    -------
     data: the same PyHAT SpectralData object, with the column
     ("parameter", "BD1050") added to ``data.df``
    """
    warn_m3_noisy("BD1050")
    # Order passed to bd_func: short anchor, band center, long anchor.
    resolved_wvls = data.closest_wvl([749, 1049, 1579])
    depth = bd_func(data, resolved_wvls)
    data.df[("parameter", "BD1050")] = depth
    return data
[docs]
def bd1250(data):
    """
    Name: BD1250
    Parameter: Band Depth at 1250 nm
    Formulation:
    Numerator = R1249
    Denominator = ((R1579 - R749) / (1579 - 749)) * (1249 - 749) + R749
    BD1250 = 1 - [Numerator/Denominator]
    Rationale: PLAG Comparison with Kaguya
    Bands: R749, R1249, R1579

    Parameters
    ----------
    data : PyHAT SpectralData object

    Returns
    -------
     data: the same PyHAT SpectralData object, with the column
     ("parameter", "BD1250") added to ``data.df``
    """
    warn_m3_noisy("BD1250")
    # Order passed to bd_func: short anchor, band center, long anchor.
    resolved_wvls = data.closest_wvl([749, 1249, 1579])
    depth = bd_func(data, resolved_wvls)
    data.df[("parameter", "BD1250")] = depth
    return data
[docs]
def bd3000_old(data):
    """
    Name: BD3000 (older formulation)
    Parameter: 3 um band depth using 2um continuum
    Formulation:
    Numerator = R2978
    Denominator = ((R2538 - R1578) / (2538 - 1578)) * (2978 - 1578) + R1578
    BD3000 = 1 - [Numerator/Denominator]
    Rationale: H2O
    Bands: R1578, R2538, R2978

    The band-depth computation itself is not active in this function: it
    only calls ``warn_m3("BD3000_old")`` and does not touch ``data``.

    Parameters
    ----------
    data : unused

    Returns
    -------
     None
    """
    # Disabled implementation, kept for reference:
    #   wavelengths = [1578, 2538, 2978]
    #   return utils.generic_func(data, wavelengths,
    #                             func=m3_funcs.bd3000_func, **kwargs)
    warn_m3("BD3000_old")
    return None
[docs]
def bd3000(data):
    """
    Name: BD3000
    Parameter: 3 um band depth
    Formulation:
    HBD = [1 - (BB/RC)]
    BB = (R2896 + R2936) / 2
    RC = (R2657 + R2697) / 2
    Rationale: Estimate relative OH
    Bands: R2657, R2697, R2896, R2936

    Parameters
    ----------
    data : PyHAT SpectralData object

    Returns
    -------
     data: the same PyHAT SpectralData object, with the column
     ("parameter", "BD3000") added to ``data.df``
    """
    resolved = data.closest_wvl([2657, 2697, 2896, 2936])
    # Reflectances in the same order as the requested bands:
    # [R2657, R2697, R2896, R2936].
    refl = [reflectance(data, resolved[i], kernel=0) for i in range(4)]
    band_mean = (refl[2] + refl[3]) / 2
    continuum_mean = (refl[0] + refl[1]) / 2
    data.df[("parameter", "BD3000")] = 1 - (band_mean / continuum_mean)
    return data
[docs]
def bd1900(data):
    """
    Name: BD1900
    Parameter: Band Depth at 1900 nm: low Ca pyroxene index
    Formulation:
    Numerator = R1898
    Denominator = ((R2498 - R1408) / (2498 - 1408)) * (1898 - 1408) + R1408
    BD1900 = 1 - [Numerator/Denominator]
    Rationale: pyroxene will be positive; favors LCP
    Bands: R1408, R1898, R2498

    Parameters
    ----------
    data : PyHAT SpectralData object

    Returns
    -------
     data: the same PyHAT SpectralData object, with the column
     ("parameter", "BD1900") added to ``data.df``
    """
    # Order passed to bd_func: short anchor, band center, long anchor.
    resolved_wvls = data.closest_wvl([1408, 1898, 2498])
    depth = bd_func(data, resolved_wvls)
    data.df[("parameter", "BD1900")] = depth
    return data
[docs]
def bd2300(data):
    """
    Name: BD2300
    Parameter: Band Depth at 2300 nm: low Ca pyroxene index
    Formulation:
    Numerator = R2298
    Denominator = ((R2578 - R1578) / (2578 - 1578)) * (2298 - 1578) + R1578
    BD2300 = 1 - [Numerator/Denominator]
    Rationale: pyroxene will be positive; favors LCP
    Bands: R1578, R2298, R2578

    Parameters
    ----------
    data : PyHAT SpectralData object

    Returns
    -------
     data: the same PyHAT SpectralData object, with the column
     ("parameter", "BD2300") added to ``data.df``
    """
    # Order passed to bd_func: short anchor, band center, long anchor.
    resolved_wvls = data.closest_wvl([1578, 2298, 2578])
    depth = bd_func(data, resolved_wvls)
    data.df[("parameter", "BD2300")] = depth
    return data
[docs]
def bdi1000(data):
    """
    Name: BDI1000
    Parameter: 1 um integrated band depth
    Formulation:
    BDI1000 = Sum with n values 0-26: (1 - [R(789 + 20n) / Rc(789 + 20n)])
    Rationale: Fe Mineralogy
    Bands: nominal wavelengths 789 + 20n nm for n = 0..26 (789 - 1309 nm)

    Parameters
    ----------
    data : PyHAT SpectralData object

    Returns
    -------
     data: the same PyHAT SpectralData object, with the column
     ("parameter", "BDI1000") added to ``data.df``
    """
    warn_m3_slow("BDI1000")
    # Iterate n = 0..26 inclusive. The previous ``while n <= 26`` loop never
    # advanced n, so it never terminated.
    bds = []
    for n in range(27):
        x = 789 + 20 * n
        r = reflectance(data, x, kernel=0)
        rc = oneum_continuum(data, x)
        bds.append(1 - r / rc)
    # Sum across the sampled wavelengths only (axis 0 of the stacked list),
    # giving one integrated depth per spectrum. Without an axis, np.sum
    # would also add all spectra together into a single scalar.
    # NOTE(review): assumes reflectance/oneum_continuum return one value per
    # spectrum (row); confirm against the helpers.
    data.df[("parameter", "BDI1000")] = np.sum(bds, axis=0)
    return data
[docs]
def bdi2000(data):
    """
    Name: BDI2000
    Parameter: 2 um integrated band depth
    Formulation:
    BDI2000 = Sum with n values 0-21: (1 - [R(1658 + 40n) / Rc2(1658 + 40n)])
    Rationale: Fe Mineralogy
    Bands: R1658 - R2498 (in steps of 40)

    Parameters
    ----------
    data : PyHAT SpectralData object

    Returns
    -------
     data: the same PyHAT SpectralData object, with the column
     ("parameter", "BDI2000") added to ``data.df``
    """
    warn_m3_slow("BDI2000")
    # Iterate n = 0..21 inclusive. The previous ``while n <= 21`` loop never
    # advanced n, so it never terminated.
    bds = []
    for n in range(22):
        x = 1658 + 40 * n
        r = reflectance(data, x, kernel=0)
        rc = twoum_continuum(data, x)
        bds.append(1 - r / rc)
    # Sum across the sampled wavelengths only (axis 0 of the stacked list),
    # giving one integrated depth per spectrum. Without an axis, np.sum
    # would also add all spectra together into a single scalar.
    # NOTE(review): assumes reflectance/twoum_continuum return one value per
    # spectrum (row); confirm against the helpers.
    data.df[("parameter", "BDI2000")] = np.sum(bds, axis=0)
    return data
[docs]
def olindex(data):
    """
    Name: OLINDEX
    Parameter: Olivine Index
    Formulation:
    slope = (R1750 - R650) / (1750 - 650)
    a = 0.1 * [(slope * (860-650) + R650) / R860]
    b = 0.5 * [(slope * (1047-650) + R650) / R1047]
    c = 0.25 * [(slope * (1230-650) + R650) / R1230]
    OLINDEX = a + b + c
    Rationale: Olivine will be strongly positive
    Bands: R650, R860, R1047, R1230, R1750

    Parameters
    ----------
    data : PyHAT SpectralData object

    Returns
    -------
     data: the same PyHAT SpectralData object, with the column
     ("parameter", "OLINDEX") added to ``data.df``
    """
    # NOTE(review): each term below is weight * bd_func(...) / R(center).
    # Whether that equals the bracketed expressions in the formulation
    # depends on what bd_func returns - confirm against m3_funcs.
    resolved = data.closest_wvl([650, 860, 1047, 1230, 1750])
    short_anchor = resolved[0]
    long_anchor = resolved[4]
    weighted_centers = (
        (0.1, resolved[1]),
        (0.5, resolved[2]),
        (0.25, resolved[3]),
    )
    terms = []
    for weight, center in weighted_centers:
        depth = bd_func(data, [short_anchor, center, long_anchor])
        center_refl = reflectance(data, center, kernel=0)
        terms.append(weight * depth / center_refl)
    data.df[("parameter", "OLINDEX")] = terms[0] + terms[1] + terms[2]
    return data
[docs]
def oneum_min(data):
    """
    Name: 1um_min
    Parameter: 1 um band center
    Formulation:
    Wavelength between 890-1349 at which 1-R/Rc is maximized
    Rationale: Fe mineralogy

    Parameters
    ----------
    data : PyHAT SpectralData object

    Returns
    -------
     data: the same PyHAT SpectralData object, with the column
     ("parameter", "1um_min") added to ``data.df``
    """
    # Select the channels inside the 890-1349 search window (inclusive).
    in_window = (data.wvls >= 890) & (data.wvls <= 1349)
    window_wvls = data.wvls[in_window]
    measured = data.df[data.spect_label][window_wvls]
    # Evaluate the 1 um continuum at each channel in the window, one
    # column per wavelength.
    continuum = pd.DataFrame()
    for wvl in window_wvls:
        continuum[wvl] = oneum_continuum(data, wvl)
    band_depth = 1 - measured / continuum
    # For every row, the column label (wavelength) with the largest depth.
    data.df[("parameter", "1um_min")] = band_depth.idxmax(axis=1)
    return data
[docs]
def oneum_fwhm(data, threshold=0.02):
    """
    Name: 1um_FWHM
    Parameter: 1 um full width at half max
    Formulation:
    locate the two points where continuum-removed
    reflectance intersects 0.5*(1-R(1um_min)/Rc(1um_min))
    Rationale: Fe mineralogy

    The 1um_min parameter is computed first if it is not already present
    in ``data.df``. Each row of ``data.df`` is then processed on its own:
    the spectrum is continuum-corrected between 699 and 1579, and on each
    side of the band minimum the channel closest to the minimum whose
    continuum-removed value lies within ``threshold`` of the half-depth
    level is taken as the half-max point. When either side has no such
    channel, a message is printed and NaN is stored for that row.

    Parameters
    ----------
    data : PyHAT SpectralData object
    threshold : maximum absolute difference between the continuum-removed
        spectrum and the half-depth level for a channel to count as an
        intersection (default 0.02)

    Returns
    -------
     data: PyHAT SpectralData object with three new columns added:
     ("parameter", "1um_FWHM_min"), ("parameter", "1um_FWHM_max") and
     ("parameter", "1um_FWHM")
    """
    # TODO: Make the implementation of this less terrible
    warn_m3_slow("1um_FWHM")
    if ("parameter", "1um_min") not in data.df.columns:
        data = oneum_min(data)
    # Per-row results, appended in row order and written out at the end.
    width = []
    wvl_min = []
    wvl_max = []
    for i in range(data.df.shape[0]):
        # get a SpectralData object for the current spectrum
        # (presumably get_roi with identical x and y bounds selects the
        # single pixel at (x, y) - confirm against get_roi)
        x = data.df[("meta", "x")].iloc[i]
        y = data.df[("meta", "y")].iloc[i]
        spectrum = get_roi(data, [x, x], [y, y])
        # get the wvl of the 1um minimum for the current spectrum and
        # calculate actual and continuum reflectance
        # (reflectance is called with its default kernel here, unlike the
        # kernel=0 calls elsewhere in this file)
        wvl_1um_min = data.df[("parameter", "1um_min")].iloc[i]
        ri = reflectance(spectrum, wvl_1um_min)
        rci = oneum_continuum(spectrum, wvl_1um_min)
        # get the half band depth and find where the spectrum is closest to
        # that value; only the first element of the result is used
        half_bd = 0.5 * (1 - (ri / rci))
        half_bd = half_bd.values[0]
        # the second return value (cont) is not used below
        spectrum_cr, cont = continuum_correction(
            spectrum, [699, 1579], method="linear", divide=True, verbose=False
        )
        # get the absolute value of the continuum removed spectrum minus (
        # 1-half_bd)
        intersects = np.abs(spectrum_cr.df[spectrum_cr.spect_label] - (1 - half_bd))
        # get the subset where the wvl is less than/greater than the band
        # minimum
        intersects_lt = intersects.iloc[:, spectrum_cr.wvls < wvl_1um_min]
        intersects_gt = intersects.iloc[:, spectrum_cr.wvls > wvl_1um_min]
        # Identify where the spectrum is close to the half band depth value
        where_close_lt = intersects_lt < threshold
        where_close_gt = intersects_gt < threshold
        # find the corresponding wvls (column labels where the mask is True)
        close_wvls_lt = where_close_lt.columns[np.squeeze(where_close_lt.values)]
        close_wvls_gt = where_close_gt.columns[np.squeeze(where_close_gt.values)]
        if close_wvls_lt.empty or close_wvls_gt.empty:
            # No candidate on at least one side: report it and record NaN
            # for this row instead of raising.
            print(
                "No intersects found! Consider increasing the threshold "
                "value and visually inspect your spectrum."
            )
            print("Current threshold = " + str(threshold))
            wvl_min_tmp = np.nan
            wvl_max_tmp = np.nan
            width_tmp = np.nan
        else:
            # choose the one closest to the band minimum
            wvl_diff_lt = np.abs(close_wvls_lt - wvl_1um_min)
            wvl_min_tmp = close_wvls_lt[wvl_diff_lt == np.min(wvl_diff_lt)][0]
            wvl_diff_gt = np.abs(close_wvls_gt - wvl_1um_min)
            wvl_max_tmp = close_wvls_gt[wvl_diff_gt == np.min(wvl_diff_gt)][0]
            # width is the wavelength distance between the two half-max points
            width_tmp = wvl_max_tmp - wvl_min_tmp
        wvl_min.append(wvl_min_tmp)
        wvl_max.append(wvl_max_tmp)
        width.append(width_tmp)
    data.df[("parameter", "1um_FWHM_min")] = wvl_min
    data.df[("parameter", "1um_FWHM_max")] = wvl_max
    data.df[("parameter", "1um_FWHM")] = width
    return data 
[docs]
def oneum_sym(data, threshold=0.02):
    """
    Name: 1um_symmetry
    Parameter: 1 um band symmetry
    Formulation:
    a = 1um_min - short wavelength point found in 1um_FWHM
    b = long wavelength point found in 1um_FWHM - 1um_min
    1um_symmetry = b/a
    Rationale: Numbers greater than 1 may be enriched in olivine

    The 1um_min and 1um_FWHM parameters are computed on demand when their
    columns are not already present in ``data.df``.

    Parameters
    ----------
    data : PyHAT SpectralData object
    threshold : forwarded to oneum_fwhm when 1um_FWHM has to be computed
        (default 0.02)

    Returns
    -------
     data: PyHAT SpectralData object with the column
     ("parameter", "1um_symmetry") added to ``data.df``
    """
    min_key = ("parameter", "1um_min")
    if min_key not in data.df.columns:
        data = oneum_min(data)
    if ("parameter", "1um_FWHM") not in data.df.columns:
        data = oneum_fwhm(data, threshold=threshold)
    # Distance from the band minimum to the long- and short-wavelength
    # half-max points, respectively.
    long_side = data.df[("parameter", "1um_FWHM_max")] - data.df[min_key]
    short_side = data.df[min_key] - data.df[("parameter", "1um_FWHM_min")]
    data.df[("parameter", "1um_symmetry")] = long_side / short_side
    return data
[docs]
def bd1um_ratio(data):
    """
    Name: bd1um_ratio
    Parameter: BD930/BD990
    Formulation:
    BD930 = 1-R929/((R1579-R699)/(1579-699)*(929-699)+R699)
    BD990 = 1-R989/((R1579-R699)/(1579-699)*(989-699)+R699)
    bd1um_ratio = BD930/BD990
    Rationale: Enhancement in low Ca pyroxene relative to high Ca pyroxene

    Parameters
    ----------
    data : PyHAT SpectralData object

    Returns
    -------
     data: the same PyHAT SpectralData object, with the column
     ("parameter", "BD1um_ratio") added to ``data.df``
    """
    # Both depths share the 699/1579 anchors; only the band center differs.
    wvls_930 = data.closest_wvl([699, 929, 1579])
    depth_930 = bd_func(data, wvls_930)
    wvls_990 = data.closest_wvl([699, 989, 1579])
    depth_990 = bd_func(data, wvls_990)
    data.df[("parameter", "BD1um_ratio")] = depth_930 / depth_990
    return data
[docs]
def bd2um_ratio(data):
    """
    Name: bd2um_ratio
    Parameter: 2um band depth ratio
    Formulation:
    a = 1-R1898/((R2578-R1578)/(2578-1578)*(1898-1578)+R1578)
    b = 1-R2298/((R2578-R1578)/(2578-1578)*(2298-1578)+R1578)
    bd2um_ratio = a/b
    Rationale: Enhancement in low Ca pyroxene relative to high Ca pyroxene

    Parameters
    ----------
    data : PyHAT SpectralData object

    Returns
    -------
     data: the same PyHAT SpectralData object, with the column
     ("parameter", "BD2um_ratio") added to ``data.df``
    """
    # Both depths share the 1578/2578 anchors; only the band center differs.
    wvls_1900 = data.closest_wvl([1578, 1898, 2578])
    depth_1900 = bd_func(data, wvls_1900)
    wvls_2300 = data.closest_wvl([1578, 2298, 2578])
    depth_2300 = bd_func(data, wvls_2300)
    data.df[("parameter", "BD2um_ratio")] = depth_1900 / depth_2300
    return data