import copy
import numpy as np
import numpy.polynomial.polynomial as poly
import scipy.stats as ss
[docs]
def regression(data, fit_wvls):
    """
    Fit a straight line to every spectrum by ordinary least-squares regression.

    Parameters
    ----------
    data : PyHAT spectral data object
    fit_wvls : The wavelengths whose values are used in the fit

    Returns
    -------
    Array of continuum values with one row per spectrum, evaluated at every
    wavelength in ``data.wvls`` (not only those used in the fit).
    """
    fit_region = data.df[data.spect_label][fit_wvls]
    # One regression per spectrum (row); each result exposes slope/intercept.
    fits = (
        ss.linregress(fit_wvls, spectrum) for _, spectrum in fit_region.iterrows()
    )
    return np.array([fit.slope * data.wvls + fit.intercept for fit in fits])
[docs]
def linear(data, fit_wvls):
    """
    Compute a continuum using a line between two points

    Parameters
    ----------
    data : PyHAT spectral data object
    fit_wvls: [x1,x2]
        x1: lower wavelength to use as line starting point
        x2: higher wavelength to use as line end point
        Only the first two elements are used; if more are given a notice is
        printed and the rest are ignored.

    Returns
    -------
    y: Array of y values at all wavelengths in the data object (not just
    those used in the fit), one row per spectrum.

    Raises
    ------
    ValueError
        If fewer than two wavelengths are given, or if the first two are
        equal (the slope of the line would be undefined).
    """
    if len(fit_wvls) < 2:
        raise ValueError(
            f"Linear fit needs two wavelengths; got {len(fit_wvls)}."
        )
    if len(fit_wvls) > 2:
        print(
            "Linear fit only takes two wavelengths! Using the first two "
            "elements of the wvls provided."
        )
    x1 = fit_wvls[0]
    x2 = fit_wvls[1]
    if x1 == x2:
        # Otherwise (x2 - x1) == 0 below silently yields inf/NaN continua.
        raise ValueError(
            f"Linear fit needs two distinct wavelengths; both are {x1}."
        )
    y1 = data.df[(data.spect_label, x1)]
    y2 = data.df[(data.spect_label, x2)]
    slope = np.asarray((y2 - y1) / (x2 - x1))
    intercept = np.asarray(y1 - (slope * x1))
    # Broadcast (n_spectra, 1) against (n_wavelengths,) to build all continua
    # at once; each element is the same m * x + b the per-row loop computed.
    wvls = np.asarray(data.wvls)
    return slope[:, np.newaxis] * wvls + intercept[:, np.newaxis]
[docs]
def polynomial(data, fit_wvls, order=2):
    """
    Compute a continuum using a polynomial fit.

    Parameters
    ----------
    data : PyHAT spectral data object
    fit_wvls: The wavelengths over which to perform the fit
    order: int, optional
        Degree of the polynomial fitted to each spectrum (default 2).

    Returns
    -------
    y: Array of y values at all wavelengths in the data object (not just
    those used in the fit), one row per spectrum.
    """
    yvals = data.df[data.spect_label][fit_wvls]
    # polyfit fits each column of a 2-D y independently, so transposing puts
    # one spectrum per column and every spectrum is fitted in a single call.
    coeffs = poly.polyfit(fit_wvls, np.asarray(yvals, dtype=float).T, order)
    # coeffs has shape (order + 1, n_spectra); with an ndarray x, polyval
    # evaluates every column at every wavelength -> (n_spectra, n_wavelengths).
    # np.asarray is required: polyval only broadcasts this way for ndarrays.
    return poly.polyval(np.asarray(data.wvls, dtype=float), coeffs)
[docs]
def continuum_correction(
    data, fit_wvls, method="linear", poly_order=2, divide=True, verbose=True
):
    """
    Remove a fitted continuum from the spectra in a PyHAT spectral data object.

    Parameters
    ----------
    data : PyHAT spectral data object
        The correction is applied to a deep copy; ``data.df`` is not modified.
    fit_wvls : The wavelengths used to derive the continuum. They are first
        snapped to available wavelengths with ``data.closest_wvl``.
    method : {"linear", "regression", "polynomial"}
        Continuum model; dispatches to the function of the same name.
    poly_order : int
        Polynomial order, only used when ``method == "polynomial"``.
    divide : bool
        If truthy, divide the spectra by the continuum; otherwise subtract it.
    verbose : bool
        If truthy, print which correction is being applied.

    Returns
    -------
    data_copy : corrected deep copy of ``data``
    y : the continuum array returned by the selected fitting function

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported names.
    """
    valid_methods = ("linear", "regression", "polynomial")
    # Validate before any work; previously an unknown method left ``y``
    # unbound and failed later with a misleading error.
    if method not in valid_methods:
        raise ValueError(
            f"Unknown continuum method {method!r}; expected one of {valid_methods}."
        )
    fit_wvls = data.closest_wvl(fit_wvls)
    data_copy = copy.deepcopy(data)
    if method == "linear":
        y = linear(data_copy, fit_wvls)
    elif method == "regression":
        y = regression(data_copy, fit_wvls)
    else:
        y = polynomial(data_copy, fit_wvls, order=poly_order)
    spectra = data_copy.df[data_copy.spect_label]
    # Truthiness (not ``is True``) so np.True_ / 1 also select division.
    if divide:
        if verbose:
            print("Dividing the data by the derived continuum")
        data_copy.df[data_copy.spect_label] = spectra / y
    else:
        if verbose:
            print("Subtracting the derived continuum")
        data_copy.df[data_copy.spect_label] = spectra - y
    return data_copy, y