import copy
import numpy as np
import pandas as pd
from libpyhat.transform.interp import interp
def shift_spect(df, shifts, spect_label="wvl", meta_label="meta"):
    """Shift spectra along the wavelength axis by one or more amounts.

    For each shift value, the spectral column labels (wavelengths) are offset
    by that amount. For non-zero shifts, ``interp`` is then called with the
    original wavelength array, interpolating the shifted data back onto the
    original set of wavelengths. The results for all shift values are stacked
    row-wise into a single data frame.

    Arguments:
    df = The data frame. Spectra should be stored in rows, with each column
        having a multi-indexed column name. The top level should be the
        specified 'spect_label' string, the second level should be a floating
        point value indicating the wavelength.
    shifts = The amount(s) by which the spectra should be shifted: a single
        number (Python or NumPy int/float) or a list, tuple or NumPy array of
        numbers. A zero shift is always included in the output; if it is not
        among ``shifts`` it is appended. The caller's object is not modified.
    spect_label = the string used to label spectral data. Defaults to 'wvl'
    meta_label = top-level column label under which the shift amount is
        recorded, in a ``(meta_label, "Shift")`` column. Defaults to 'meta'

    Returns:
    A data frame holding one copy of the input rows per shift value (the
    original row index is repeated for each copy). Each copy records its
    shift amount in the ``(meta_label, "Shift")`` column.

    Raises:
    TypeError if ``shifts`` is not a number, list, tuple or NumPy array.
    """
    # TODO: Make flexible to user supplying spectra that is already shifted

    # Normalise ``shifts`` into a fresh list. A new list is built so the
    # caller's object is never mutated. An in-place append would leak the
    # added zero back to the caller, and on a NumPy array ``+= [0]`` would
    # broadcast an element-wise addition instead of appending.
    if isinstance(shifts, (int, float, np.integer, np.floating)):
        shift_list = [shifts]
    elif isinstance(shifts, np.ndarray):
        shift_list = shifts.ravel().tolist()
    elif isinstance(shifts, (list, tuple)):
        shift_list = list(shifts)
    else:
        raise TypeError(
            "shifts must be a number or a list, tuple or numpy array of "
            f"numbers, got {type(shifts).__name__}"
        )

    # *assume* the user wants to retain zero-shifted data; the unshifted
    # intensities are dropped from the frame below, so they must be re-added
    # as a zero-shift copy
    if 0.0 not in shift_list:
        shift_list.append(0)

    # Get the x (wavelength) values from the spectra into an array
    wvls = np.array(df[spect_label].columns.values, dtype=float)
    # Intensity values of the spectra; every shifted copy starts from these
    df_spect = df[spect_label]
    # Everything except the spectral columns; shared by every shifted copy
    df_rest = df.drop(spect_label, axis=1)

    df_list = []
    for shift_value in shift_list:
        # Fresh copy so relabelling the columns does not touch df_spect
        spectra_copy = df_spect.copy()
        # Offset the wavelength labels by the shift amount
        spectra_copy.columns = pd.MultiIndex.from_tuples(
            [(spect_label, wvl + shift_value) for wvl in wvls]
        )
        # Reinject the (shifted) spectra alongside the non-spectral columns
        df_tmp = pd.concat([df_rest, spectra_copy], axis=1)
        # A zero shift needs no interpolation. Otherwise interpolate the
        # shifted data back onto the original set of wavelengths.
        if shift_value != 0.0:
            df_tmp = interp(df_tmp, wvls)
        # Record the shift amount in a metadata column
        df_tmp[(meta_label, "Shift")] = shift_value
        df_list.append(df_tmp)

    # Combine all the shifted datasets together (row-wise)
    return pd.concat(df_list)