import functools
import numpy as np
try:
    # Prefer NumPy's multi_dot, which picks the cheapest multiplication
    # order. NOTE: this previously read ``from np.linalg import multi_dot``
    # — ``np`` is only a local alias, not an importable module, so the
    # import always failed and the fallback was always used.
    from numpy.linalg import multi_dot
except ImportError:
    # Fallback for very old NumPy versions: plain left-to-right reduction.
    def multi_dot(arrays):
        """Compute the dot product of a sequence of arrays, left to right."""
        return functools.reduce(np.dot, arrays)
def svt_thresh(X, thresh):
    """Singular value thresholding.

    Solves argmin_Z 1/2 ||Z-X||_F^2 + thresh ||Z||_*, i.e. the proximal
    operator of the *nuclear* norm (sum of singular values). Soft-thresholds
    the singular values of X by ``thresh``, which reduces rank.

    See: http://www-stat.stanford.edu/~candes/papers/SVT.pdf
    """
    U, s, V = np.linalg.svd(X, full_matrices=False)
    # Shrink each singular value toward zero, clipping at zero.
    s = np.maximum(0, s - thresh)
    return (U * s).dot(V)
def soft_thresh(X, thresh):
    """Elementwise soft thresholding.

    Solves argmin_Z 1/2 ||Z-X||_F^2 + thresh ||Z||_1, i.e. the proximal
    operator of the l1-norm (a sparsifier): entries within ``thresh`` of
    zero are zeroed, all others shrink toward zero by ``thresh``.

    See: http://www.simonlucey.com/soft-thresholding/
    """
    return np.sign(X) * np.maximum(0, np.abs(X) - thresh)
def prepare_data(
    A,
    B,
    metaColNameA="Target",
    metaColNameB="Target",
    averageRepeats=True,
    colvar="wvl",
    meta_label="meta",
):
    """Align two spectral DataFrames prior to calibration transfer.

    A and B are DataFrames with two-level columns: metadata under
    ``meta_label`` (e.g. ``("meta", "Target")``) and spectral channels
    under ``colvar``. Rows whose metadata value is absent from the other
    data set are dropped, both sets are sorted by that value, and
    (optionally) repeated measurements of a target are collapsed into
    their mean spectrum. Uncertainties are not propagated.

    Returns the two aligned DataFrames ``(A_mean, B_mean)``.
    """
    # Using a column specified by the user, this will identify any rows
    # that do not match in that column and removes them from both datasets.
    # Usually, the 'Target' metadata is used
    A = A.loc[A[(meta_label, metaColNameA)].isin(B[(meta_label, metaColNameB)])]
    B = B.loc[B[(meta_label, metaColNameB)].isin(A[(meta_label, metaColNameA)])]
    # This will alphabetically sort the data according to the column
    # specified by the user
    A = A.sort_values((meta_label, metaColNameA))
    B = B.sort_values((meta_label, metaColNameB))
    # Check to make sure the spectral channels for each dataset are
    # identical, otherwise you are performing calibration transfer
    # on two fundamentally different datasets.
    # TODO: Swap these assertions with exception handling.
    assert len(A[colvar].columns) == len(
        B[colvar].columns
    ), "Data sets A and B have different numbers of spectral channels!"
    assert A[colvar].columns.values[0] == B[colvar].columns.values[0], (
        "Data set A and B wavelengths are not identical. Check rounding "
        "and/or resample one data set onto the other's wavelengths"
    )
    # Default to the filtered/sorted data. Binding these here also fixes a
    # bug where A_mean/B_mean were unbound (NameError) when
    # averageRepeats=False.
    A_mean = A
    B_mean = B
    # The user may choose to average repeated data and consolidate it
    # into a single spectrum. This will not propagate or track uncertainties.
    # TODO: Build in error propagation and tracking
    if averageRepeats:
        # Determine the unique measurements according to the metadata column
        # specified by the user
        A_uniques = np.unique(A[(meta_label, metaColNameA)])
        # If every row is already unique there is nothing to average
        if not len(A_uniques) == len(A[(meta_label, metaColNameA)]):
            # Loop through the unique metadata names
            for value in A_uniques:
                # Determine the rows that match the unique value
                rows = A_mean[(meta_label, metaColNameA)] == value
                # Generate a mean spectrum from the matching spectra.
                # NOTE(review): rows.index[rows] yields index *labels* but is
                # passed to .iloc; this only lines up because the index is
                # reset to 0..n-1 after every pass through the loop.
                avg = np.mean(A_mean.iloc[rows.index[rows]][colvar], axis=0)
                # Build a row (metadata value followed by the mean spectrum)
                # to inject back into the dataset.
                # NOTE(review): assumes `meta_label` holds exactly one
                # column, otherwise row lengths cannot match — TODO confirm.
                avg_row = np.concatenate(([value], avg.values))
                # Inject the mean spectrum into the first row of the matching
                # spectra
                A_mean.loc[rows.index[rows][0]] = avg_row
                # Drop all other spectra after the first row, which now
                # houses the mean spectrum
                A_mean = A_mean.drop(rows.index[rows][1:])
                # Reset the indices
                A_mean.index = np.arange(len(A_mean[meta_label]))
        # Repeat the process as above for the second dataset
        B_uniques = np.unique(B[(meta_label, metaColNameB)])
        if not len(B_uniques) == len(B[(meta_label, metaColNameB)]):
            for value in B_uniques:
                rows = B_mean[(meta_label, metaColNameB)] == value
                avg = np.mean(B_mean.iloc[rows.index[rows]][colvar], axis=0)
                avg_row = np.concatenate(([value], avg.values))
                B_mean.loc[rows.index[rows][0]] = avg_row
                B_mean = B_mean.drop(rows.index[rows][1:])
                B_mean.index = np.arange(len(B_mean[meta_label]))
    # make sure we're still working with floats (the row injection above
    # can upcast the spectral columns)
    A_mean[colvar] = A_mean[colvar].astype(float)
    B_mean[colvar] = B_mean[colvar].astype(float)
    return A_mean, B_mean
def check_data(data1, data2, label1, label2, spect_label="wvl"):
    """Sanity-check that two spectral DataFrames share a wavelength grid.

    Asserts that both DataFrames have the same number of spectral channels
    under ``spect_label`` and that their final wavelength columns match.
    NOTE(review): only the *last* channel label is compared — equal length
    plus matching endpoint is used as a proxy for identical grids; it will
    not catch interior mismatches.

    ``label1``/``label2`` are human-readable names used in the assertion
    messages. Raises AssertionError on mismatch; returns None otherwise.
    """
    # TODO: Swap these assertions with exception handling.
    assert len(data1[spect_label].columns) == len(data2[spect_label].columns), (
        f"Data sets {label1} and {label2} have different numbers "
        "of spectral channels!"
    )
    assert (
        data1[spect_label].columns.values[-1] == data2[spect_label].columns.values[-1]
    ), (
        f"Data set {label1} and {label2} wavelengths are not "
        "identical. Check "
        "rounding and/or resample "
        "one data set onto the "
        "other's wavelengths"
    )