import math

import numpy as np
from numpy import *
from scipy.optimize import nnls
# TODO: This is not fully translated from Matlab.
"""
Corresponds with ELMM_ADMM.m from the toolbox at the following links:
https://openremotesensing.net/knowledgebase/spectral-variability-and-extended-linear-mixing-model/
https://github.com/ricardoborsoi/Unmixing_with_Deep_Generative_Models/blob/master/other_methods/ELMM_ADMM.m

The algorithm is presented in detail in:
   L. Drumetz, M. A. Veganzones, S. Henrot, R. Phlypo, J. Chanussot and
   C. Jutten, "Blind Hyperspectral Unmixing Using an Extended Linear
   Mixing Model to Address Spectral Variability," in IEEE Transactions on
   Image Processing, vol. 25, no. 8, pp. 3890-3905, Aug. 2016.
"""
def _regularization_weights(value, P, name):
    """Normalize a regularization parameter to ``(length-P vector, is_scalar)``."""
    weights = np.atleast_1d(np.asarray(value, dtype=float)).ravel()
    if weights.size == 1:
        return np.full(P, weights[0]), True
    if weights.size == P:
        return weights, False
    raise ValueError(f"{name} must be a scalar or a P-dimensional vector")


def _difference_operators(m, n):
    """Return the FFTs of the horizontal and vertical circular difference kernels."""
    kernel_h = np.zeros((m, n))
    kernel_v = np.zeros((m, n))
    # "+=" / "-=" so that a single-column (or single-row) image yields a zero
    # kernel instead of overwriting the centre tap.
    kernel_h[0, 0] += 1
    kernel_h[0, n - 1] -= 1
    kernel_v[0, 0] += 1
    kernel_v[m - 1, 0] -= 1
    return np.fft.fft2(kernel_h), np.fft.fft2(kernel_v)


def _squared_norm(*arrays):
    """Return the sum of the squared Frobenius norms of ``arrays``."""
    return math.fsum(float(np.sum(np.abs(a) ** 2)) for a in arrays)


def _update_endmembers(data_r, A, psi_maps, S0, lambda_s):
    """Solve the per-pixel least-squares problem for S and clip it at 1e-6."""
    P = A.shape[0]
    # numer[k] = x_k a_k^T + lambda_s * S0 diag(psi_k), stacked as (N, L, P)
    numer = np.einsum("lk,pk->klp", data_r, A) + lambda_s * (
        S0[None, :, :] * psi_maps.T[:, None, :]
    )
    # gram[k] = a_k a_k^T + lambda_s * I, stacked as (N, P, P)
    gram = np.einsum("pk,qk->kpq", A, A) + lambda_s * np.eye(P)
    # pinv keeps the update defined when lambda_s == 0 (gram is then rank one)
    S = np.matmul(numer, np.linalg.pinv(gram)).transpose(1, 2, 0)
    return np.maximum(1e-6, S)


def _shrink(grad, thresholds, is_scalar, norm_sr):
    """Apply the proximal operator of the chosen spatial norm to ``grad``."""
    if norm_sr == "1,1":
        # one threshold per endmember row, broadcast over the pixels
        return soft(grad, thresholds[:, None])
    if is_scalar:
        return vector_soft_col(grad, thresholds[0])
    return np.vstack(
        [vector_soft_col(row, tau) for row, tau in zip(grad, thresholds)]
    )


def _admm_abundances(
    data_r,
    S,
    A,
    lam_a,
    scalar_lambda_a,
    norm_sr,
    FDh,
    FDv,
    m,
    n,
    maxiter_admm,
    epsilon_admm_abs,
    epsilon_admm_rel,
    verbose,
):
    """Run the ADMM inner loop for the spatially regularized abundance update."""
    P, N = A.shape
    tau_incr = 2.0
    tau_decr = 2.0
    nu = 10.0
    rho = 10.0  # barrier parameter, restarted for every abundance update

    # quantities that stay fixed while S is fixed
    gram = np.einsum("lpk,lqk->kpq", S, S)  # S_k^T S_k
    proj = np.einsum("lpk,lk->pk", S, data_r)  # S_k^T x_k
    identity = np.eye(P)
    unit_rhs = np.ones((N, P))
    FDhC = np.conj(FDh)[:, :, None]
    FDvC = np.conj(FDv)[:, :, None]
    spectrum = (1 + np.abs(FDh) ** 2 + np.abs(FDv) ** 2)[:, :, None]

    # split variables
    v1 = A.copy()
    v4 = A.copy()
    Hhv1 = ConvC(v1, FDh, m, n, P)
    Hvv1 = ConvC(v1, FDv, m, n, P)
    v2 = Hhv1.copy()
    v3 = Hvv1.copy()

    # scaled Lagrange multipliers
    d1 = np.zeros((P, N))
    d2 = np.zeros_like(v2)
    d3 = np.zeros_like(v3)
    d4 = np.zeros((P, N))

    for j in range(maxiter_admm):
        # every variable below is rebound (never modified in place), so these
        # references keep the previous iterate
        A_old, v1_old, v4_old, d1_old, d4_old = A, v1, v4, d1, d4
        p_res2_old = v2 - Hhv1
        p_res3_old = v3 - Hvv1

        # min w.r.t. A: per-pixel quadratic with the sum-to-one constraint.
        # With z = alpha^-1 b and w = alpha^-1 1 the KKT solution is
        # a = z - mu * w, mu = (1^T z - 1) / (1^T w).
        rhs = np.stack(
            [(proj + rho * (v1 + d1 + v4 + d4)).T, unit_rhs], axis=2
        )
        sol = np.linalg.solve(gram + 2 * rho * identity, rhs)
        z = sol[:, :, 0]
        w = sol[:, :, 1]
        mu = (z.sum(axis=1) - 1) / w.sum(axis=1)
        A = (z - mu[:, None] * w).T

        # min w.r.t. v1: diagonal system in the Fourier domain
        rhs_hat = (
            np.fft.fft2(conv2im(A - d1, m, n, P), axes=(0, 1))
            + np.fft.fft2(conv2im(v2 + d2, m, n, P), axes=(0, 1)) * FDhC
            + np.fft.fft2(conv2im(v3 + d3, m, n, P), axes=(0, 1)) * FDvC
        )
        v1_im = np.real(np.fft.ifft2(rhs_hat / spectrum, axes=(0, 1)))
        v1 = conv2mat(v1_im, m, n, P)
        Hhv1 = ConvC(v1, FDh, m, n, P)
        Hvv1 = ConvC(v1, FDv, m, n, P)

        # min w.r.t. v2 and v3 (spatial gradients), then v4 (nonnegativity)
        thresholds = lam_a / rho
        v2 = _shrink(Hhv1 - d2, thresholds, scalar_lambda_a, norm_sr)
        v3 = _shrink(Hvv1 - d3, thresholds, scalar_lambda_a, norm_sr)
        v4 = np.maximum(A - d4, 0)

        # dual update
        p_res1 = v1 - A
        p_res2 = v2 - Hhv1
        p_res3 = v3 - Hvv1
        p_res4 = v4 - A
        d1 = d1 + p_res1
        d2 = d2 + p_res2
        d3 = d3 + p_res3
        d4 = d4 + p_res4

        # primal and dual residuals and their tolerances
        primal = math.sqrt(_squared_norm(p_res1, p_res2, p_res3, p_res4))
        dual = rho * math.sqrt(_squared_norm(v1_old - v1, v4_old - v4))
        epsilon_primal = math.sqrt(
            4 * P * N
        ) * epsilon_admm_abs + epsilon_admm_rel * float(
            np.maximum(
                math.sqrt(2 * _squared_norm(A)),
                math.sqrt(
                    _squared_norm(v1_old, p_res2_old, p_res3_old, v4_old)
                ),
            )
        )
        epsilon_dual = math.sqrt(
            P * N
        ) * epsilon_admm_abs + rho * epsilon_admm_rel * math.sqrt(
            _squared_norm(d1_old, d4_old)
        )

        if verbose:
            norm_old = np.linalg.norm(A_old)
            rel_A = np.abs(np.linalg.norm(A) - norm_old) / norm_old
            print(
                f"iter {j}, rel_A = {rel_A}, primal = {primal}, "
                f"eps_p = {epsilon_primal}, dual = {dual}, "
                f"eps_d = {epsilon_dual}, rho = {rho}"
            )

        # never stop on the very first iteration
        if j > 0 and primal < epsilon_primal and dual < epsilon_dual:
            break

        # residual balancing; the scaled multipliers have to be rescaled
        # whenever rho changes
        if primal > nu * dual:
            rho *= tau_incr
            d1, d2, d3, d4 = (d / tau_incr for d in (d1, d2, d3, d4))
        elif dual > nu * primal:
            rho /= tau_decr
            d1, d2, d3, d4 = (d * tau_decr for d in (d1, d2, d3, d4))

    return A


def _update_scalings(S, S0, S0ptS0, lambda_s, lam_psi, FDh, FDv, m, n):
    """Update the scaling factor maps, with or without spatial smoothing."""
    P = S0.shape[1]
    proj = np.einsum("lp,lpk->pk", S0, S)  # s0_p^T s_{p,k}
    if not np.any(lam_psi):
        return proj / S0ptS0[:, None]
    # Stationarity of
    #   lambda_s/2 * sum_k ||s_pk - psi_pk s0_p||^2
    #   + lambda_psi/2 * (||Hh psi_p||^2 + ||Hv psi_p||^2)
    # gives a system that is diagonal in the Fourier domain.
    laplacian = np.abs(FDh) ** 2 + np.abs(FDv) ** 2
    denom = (
        lam_psi[None, None, :] * laplacian[:, :, None]
        + lambda_s * S0ptS0[None, None, :]
    )
    numer_hat = np.fft.fft2(conv2im(lambda_s * proj, m, n, P), axes=(0, 1))
    psi_im = np.real(np.fft.ifft2(numer_hat / denom, axes=(0, 1)))
    return conv2mat(psi_im, m, n, P)


def _abundance_penalty(A, FDh, FDv, m, n, norm_sr):
    """Return the per-endmember spatial penalty on the abundance maps."""
    P = A.shape[0]
    grad_h = ConvC(A, FDh, m, n, P)
    grad_v = ConvC(A, FDv, m, n, P)
    if norm_sr == "2,1":
        return np.sqrt(grad_h ** 2 + grad_v ** 2).sum(axis=1)
    return (np.abs(grad_h) + np.abs(grad_v)).sum(axis=1)


def _scaling_penalty(psi_maps, FDh, FDv, m, n):
    """Return the per-endmember Tikhonov penalty on the scaling factor maps."""
    P = psi_maps.shape[0]
    grad_h = ConvC(psi_maps, FDh, m, n, P)
    grad_v = ConvC(psi_maps, FDv, m, n, P)
    return 0.5 * ((grad_h ** 2).sum(axis=1) + (grad_v ** 2).sum(axis=1))


def elmm_admm(data, A_init, psis_init, S0, lambda_s, lambda_a, lambda_psi, **kwargs):
    """
    Unmix hyperspectral data using the Extended Linear Mixing Model (ELMM).

    The function alternates between three updates (endmembers S, abundances
    A, scaling factors PSI) to decrease

        J(S,A,PSI) = 1/2 * sum_k ||x_k - S_k a_k||^2
                     + lambda_s/2 * sum_k ||S_k - S0 diag(psi_k)||_F^2
                     + lambda_a . R(A) + lambda_psi . R(PSI)

    R(A) is a spatial regularization on the abundance maps (anisotropic total
    variation for norm_sr='1,1', a penalty on the gradient magnitude for
    norm_sr='2,1'); R(PSI) is a Tikhonov penalty on the spatial gradients of
    the scaling factor maps. Spatial gradients are circular first-order
    differences. When lambda_a is non-zero the abundance update is an ADMM
    loop that imposes the sum-to-one constraint in closed form and
    nonnegativity through a split variable; otherwise FCLSU is called for
    every pixel.

    Pixels are indexed in column-major order (k = row + m * column), as in
    the original MATLAB implementation.

    Mandatory inputs:
        -data: m*n*L image cube (m rows, n columns, L spectral bands)
        -A_init: P*N initial abundance matrix (P endmembers, N = m*n pixels)
        -psis_init: P*N initial scaling factor matrix
        -S0: L*P reference endmember matrix
        -lambda_s: weight of the ELMM tightness term
        -lambda_a: weight of the spatial regularization on the abundances
        -lambda_psi: weight of the spatial regularization on the scaling
        factors
        lambda_a and lambda_psi may each be a scalar (same weight for every
        endmember) or any array holding exactly P values (one weight per
        endmember).

    Optional keyword arguments:
        -norm_sr: '1,1' or '2,1' (default: '1,1')
        -verbose: print progress information (default: True)
        -maxiter_anls: maximum number of outer iterations (default: 100)
        -maxiter_admm: maximum number of ADMM iterations (default: 100)
        -epsilon_s, epsilon_a, epsilon_psi: tolerances on the relative
        variation of S, A and PSI between two consecutive outer iterations
        (default: 10^(-3) each)
        -epsilon_admm_abs, epsilon_admm_rel: absolute and relative
        tolerances on the ADMM primal and dual residuals (default: 10^(-2))
        -return_optim: when True, a dictionary with the per-iteration values
        of 'objective', 'norm_fitting', 'source_model', 'TV_a' and
        'smooth_psi' is appended to the returned list (default: False)

    Returns:
        A list [A, psi_maps, S] with
        -A: P*N abundance matrix
        -psi_maps: P*N scaling factor matrix
        -S: L*P*N tensor containing the endmember matrix of every pixel
        followed by the optimization dictionary when return_optim is True.
        The input arrays are not modified.

    Raises:
        ValueError: if norm_sr is unknown, if lambda_a or lambda_psi holds
        neither 1 nor P values, or if the input shapes are inconsistent.
    """
    norm_sr = kwargs.get("norm_sr", "1,1")
    verbose = kwargs.get("verbose", True)
    maxiter_anls = kwargs.get("maxiter_anls", 100)
    maxiter_admm = kwargs.get("maxiter_admm", 100)
    epsilon_s = kwargs.get("epsilon_s", 10 ** (-3))
    epsilon_a = kwargs.get("epsilon_a", 10 ** (-3))
    epsilon_psi = kwargs.get("epsilon_psi", 10 ** (-3))
    epsilon_admm_abs = kwargs.get("epsilon_admm_abs", 10 ** (-2))
    epsilon_admm_rel = kwargs.get("epsilon_admm_rel", 10 ** (-2))
    return_optim = kwargs.get("return_optim", False)

    if norm_sr not in ("1,1", "2,1"):
        raise ValueError("norm_sr must be '1,1' or '2,1'")

    # work on float copies so that the caller's arrays are left untouched
    data = np.asarray(data, dtype=float)
    S0 = np.asarray(S0, dtype=float)
    A = np.array(A_init, dtype=float)
    psi_maps = np.array(psis_init, dtype=float)

    P = A.shape[0]  # number of endmembers
    lam_a, scalar_lambda_a = _regularization_weights(lambda_a, P, "lambda_a")
    lam_psi, _ = _regularization_weights(lambda_psi, P, "lambda_psi")

    m, n, L = data.shape
    N = m * n
    if A.shape != (P, N) or psi_maps.shape != (P, N):
        raise ValueError("A_init and psis_init must both be P x (m*n) matrices")
    if S0.shape != (L, P):
        raise ValueError("S0 must be an L x P matrix")

    # L x N pixel matrix, column-major pixel order as in MATLAB's reshape
    data_r = data.reshape((N, L), order="F").T
    # MATLAB: S = repmat(S0,[1,1,N]);
    S = np.repeat(S0[:, :, None], N, axis=2)
    S0ptS0 = np.einsum("lp,lp->p", S0, S0)  # diag(S0' * S0)
    FDh, FDv = _difference_operators(m, n)

    use_tv = bool(np.any(lam_a))
    use_smooth = bool(np.any(lam_psi))

    history = {
        "objective": np.zeros(maxiter_anls),
        "norm_fitting": np.zeros(maxiter_anls),
        "source_model": np.zeros(maxiter_anls),
        "TV_a": np.zeros((maxiter_anls, P)),
        "smooth_psi": np.zeros((maxiter_anls, P)),
    }
    n_iter = 0

    for i in range(maxiter_anls):
        # the updates below rebind S, A and psi_maps to new arrays
        S_old, A_old, psi_old = S, A, psi_maps

        # S_update
        S = _update_endmembers(data_r, A, psi_maps, S0, lambda_s)

        # A_update
        if use_tv:
            A = _admm_abundances(
                data_r,
                S,
                A,
                lam_a,
                scalar_lambda_a,
                norm_sr,
                FDh,
                FDv,
                m,
                n,
                maxiter_admm,
                epsilon_admm_abs,
                epsilon_admm_rel,
                verbose,
            )
        else:
            # without spatial regularization
            A = np.column_stack(
                [np.ravel(FCLSU(data_r[:, k], S[:, :, k])) for k in range(N)]
            )

        # psi_update
        psi_maps = _update_scalings(
            S, S0, S0ptS0, lambda_s, lam_psi, FDh, FDv, m, n
        )

        # residuals of the ANLS loop
        rs = float(
            np.mean(
                np.linalg.norm(S - S_old, axis=(0, 1))
                / np.linalg.norm(S_old, axis=(0, 1))
            )
        )
        ra = np.linalg.norm(A - A_old) / np.linalg.norm(A_old)
        rpsi = np.linalg.norm(psi_maps - psi_old) / np.linalg.norm(psi_old)

        # objective function value and its individual terms
        fit = 0.5 * _squared_norm(data_r - np.einsum("lpk,pk->lk", S, A))
        src = 0.5 * _squared_norm(S - S0[:, :, None] * psi_maps[None, :, :])
        tv_rows = (
            _abundance_penalty(A, FDh, FDv, m, n, norm_sr)
            if use_tv
            else np.zeros(P)
        )
        smooth_rows = (
            _scaling_penalty(psi_maps, FDh, FDv, m, n)
            if use_smooth
            else np.zeros(P)
        )
        history["norm_fitting"][i] = fit
        history["source_model"][i] = src
        history["TV_a"][i] = tv_rows
        history["smooth_psi"][i] = smooth_rows
        history["objective"][i] = (
            fit
            + lambda_s * src
            + float(np.dot(lam_a, tv_rows))
            + float(np.dot(lam_psi, smooth_rows))
        )
        n_iter = i + 1

        if verbose:
            print(f"iteration: {i}")

        # termination test
        if rs < epsilon_s and ra < epsilon_a and rpsi < epsilon_psi:
            break

    outputs = [A, psi_maps, S]
    if return_optim:
        outputs.append({key: values[:n_iter] for key, values in history.items()})
    return outputs
# Auxiliary functions.
# Fully Constrained Linear Spectral Unmixing (FCLSU).
# It has its own .m file in the MATLAB source and
# may be worth moving into its own module here as well.
def FCLSU(HIM, M):
    """
    Estimate abundances for each pixel with non-negative least squares.

    Every pixel (column of HIM) and the endmember matrix M are scaled by
    1/1000 and extended with one extra equation, a row of ones with target
    value 1, before being passed to scipy's nnls. Because the spectral
    equations are scaled down, that extra sum-to-one equation carries a
    comparatively large weight in the fit.

    Inputs:
        -HIM: array with one pixel per column (rows are spectral bands); a
        1-D array is treated as a single pixel
        -M: bands x endmembers matrix
    Returns:
        pixels x endmembers array with one row of abundances per pixel
    """
    pixels = HIM[:, None] if len(HIM.shape) == 1 else HIM
    n_endmembers = M.shape[1]
    weight = 1 / 1000
    # scaled endmembers on top, the sum-to-one row underneath
    system = vstack((weight * M, ones((1, n_endmembers))))
    out = zeros((pixels.shape[1], n_endmembers))
    for idx, spectrum in enumerate(pixels.T):
        target = append(weight * spectrum, 1.0)
        out[idx, :] = nnls(system, target)[0]
    return out
# circular convolution
def ConvC(X, FK, m, n, P):
    """
    Circularly convolve each of the P maps stored in X with one 2-D kernel.

    MATLAB equivalent:
    reshape(real(ifft2(fft2(reshape(X', m,n,P)).*repmat(FK,[1,1,P]))), m*n,P)'

    Inputs:
        -X: P x (m*n) matrix, one flattened map per row, pixels in
        column-major order (k = row + m * column)
        -FK: m x n array holding the 2-D FFT of the convolution kernel
        -m, n: image height and width
        -P: number of maps
    Returns:
        P x (m*n) real matrix with the filtered maps, same pixel order as X
    """
    # order="F" reproduces MATLAB's column-major reshape
    images = np.asarray(X).T.reshape((m, n, P), order="F")
    # numpy's fft2 works on the last two axes by default, so the image axes
    # have to be named explicitly; FK is broadcast over the P maps
    spectrum = np.fft.fft2(images, axes=(0, 1)) * np.asarray(FK)[:, :, None]
    filtered = np.real(np.fft.ifft2(spectrum, axes=(0, 1)))
    return filtered.reshape((m * n, P), order="F").T
# convert matrix to image
def conv2im(A, m, n, P):
    """
    Reshape a P x (m*n) matrix into an m x n x P image stack.

    Pixels are read in column-major order (k = row + m * column), the same
    order used for the flattened data matrix and by ConvC, and the order of
    MATLAB's reshape(A', m, n, P).
    """
    return A.T.reshape((m, n, P), order="F")


# convert image to matrix
def conv2mat(A, m, n, P):
    """
    Reshape an m x n x P image stack into a P x (m*n) matrix.

    Inverse of conv2im: pixels are written in column-major order.
    """
    return A.reshape((m * n, P), order="F").T
# # soft(x,T) and vector_soft_col(X, tau) are in a separate .m file in
# source code
# soft-thresholding function
def soft(x, T):
    """
    Soft-threshold x element-wise: shrink every magnitude by T, floor at 0.

    Inputs:
        -x: array of values to shrink
        -T: threshold(s), expected to be non-negative; a scalar or an array
        that broadcasts against x
    Returns:
        x itself when every threshold is zero, otherwise a new array in
        which each entry keeps its sign and has magnitude max(|x| - T, 0)
    """
    x = np.asarray(x)
    T = np.asarray(T)
    if not np.any(T):
        return x
    shrunk = np.maximum(np.abs(x) - T, 0)
    denom = shrunk + T
    # denom is zero only where both the threshold and x are zero; the result
    # is zero there, so any non-zero divisor avoids the 0/0
    return shrunk / np.where(denom == 0, 1, denom) * x
# computes the vector soft columnwise
def vector_soft_col(X, tau):
    """
    Shrink the Euclidean norm of every column of X by tau.

    MATLAB equivalent:
    NU = sqrt(sum(X.^2)); A = max(0, NU-tau);
    Y = repmat(A./(A+tau), size(X,1), 1) .* X

    Inputs:
        -X: 2-D array whose columns are shrunk independently; a 1-D array is
        treated as one single group (its norm is taken over all entries)
        -tau: non-negative threshold
    Returns:
        array of the same shape as X; a column whose norm is at most tau
        becomes zero, any other column is rescaled so that its norm
        decreases by tau
    """
    X = np.asarray(X)
    col_norms = np.sqrt(np.sum(X ** 2, axis=0, keepdims=True))
    shrunk = np.maximum(col_norms - tau, 0)
    denom = shrunk + tau
    # denom is zero only for an all-zero column with tau == 0; the scale is
    # irrelevant there, so any non-zero divisor avoids the 0/0
    scale = shrunk / np.where(denom == 0, 1, denom)
    return scale * X
"""
# code block used for testing by running this .py file
# this test case is NOT sufficient
# only used for getting the basic matrix shapes correct
# It's recommended that a real test case is set up before much further work 
on the algorithm itself
if __name__ == '__main__':
    # print statements are littered around the algorithm
    # this is how it was done in MATLAB
    # should be removed before final PyHAT implementation
    m = 5
    n = 5
    L = 5
    P = 5
    N = m * n
    arb = 5  # arbitrary integer
    
    data = ones((n,n,L))
    A_init = ones((P, N))
    psis_init = ones((P, N))
    S0 = ones((L, P))
    lambda_s = arb
    lambda_a = ones((arb,arb))
    lambda_psi = ones((arb,arb))
    
    output = elmm_admm( data, A_init, psis_init, S0, lambda_s, lambda_a, 
    lambda_psi )
    print( output )
"""