import numpy as np
import pandas as pd
from PyQt5 import QtWidgets
import gui.core.regressionMethods as rm
from gui.ui.RegressionTrain import Ui_Form
from gui.util import Qtickle
from gui.util.Modules import Modules
from libpyhat.regression import regression
from libpyhat.spectral_data import SpectralData
class RegressionTrain(Ui_Form, Modules):
    """Workflow panel that trains one regression model per ``run()`` call.

    The user selects a data set, X column groups, Y composition columns, a Y
    range and an algorithm.  ``run()`` fits the model, registers it in
    ``self.models`` under a descriptive key, and, when the fitted estimator
    exposes them, appends its coefficients and X/Y means to the
    ``'Model Coefficients'`` and ``'Model Means'`` entries of ``self.data``.
    """

    def setupUi(self, Form):
        """Build the generated form, the shared module UI and the algorithm panels."""
        self.Form = Form
        super().setupUi(Form)
        Modules.setupUi(self, Form)
        self.regressionMethods()

    def set_yRange(self):
        """Fill the Y min/max spin boxes with the range of the selected Y column.

        Any failure (nothing selected, missing column, non-numeric data) is
        reported on stdout instead of being raised.
        """
        try:
            yvar = (self.comp_label, self.yVariableList.currentItem().text())
            column = self.data[self.chooseDataComboBox.currentText()].df[yvar]
            # Compute both limits before touching either spin box.
            ymax, ymin = column.max(), column.min()
            self.yMaxDoubleSpinBox.setValue(ymax)
            self.yMinDoubleSpinBox.setValue(ymin)
        except Exception:
            print(
                'Failed to update Y range. Selected data may be non-numeric!'
            )

    def refreshLists(self):
        """Repopulate the Y and X variable lists from the selected data set."""
        self.changeComboListVars(self.yVariableList, self.yvar_choices())
        self.changeComboListVars(self.xVariableList, self.xvar_choices())

    def getGuiParams(self):
        """Snapshot the panel state (overrides ``Modules.getGuiParams``).

        :return: a list whose element 0 is this panel's own widget state and
            whose remaining elements are the per-algorithm states, in the
            iteration order of ``self.alg``.
        """
        self.qt = Qtickle.Qtickle(self)
        state = [self.qt.guiSave()]
        state.extend(alg.getGuiParams() for alg in self.alg.values())
        return state

    def setGuiParams(self, dict):
        """Restore a snapshot produced by ``getGuiParams``.

        Overrides ``Modules.setGuiParams`` because the snapshot is a list:
        element 0 restores this panel, elements 1.. restore the algorithms.

        :param dict: the saved state.  The name is kept for backward
            compatibility; the value is indexed by position.
        """
        self.qt = Qtickle.Qtickle(self)
        self.qt.guiRestore(dict[0])
        for key, params in self._saved_alg_params(dict):
            self.alg[key].setGuiParams(params)

    def selectiveSetGuiParams(self, dict):
        """Selectively restore a snapshot produced by ``getGuiParams``.

        Overrides ``Modules``' selective restore.  Element 0 is selectively
        restored onto this panel, and each algorithm applies its own
        ``selectiveSetGuiParams`` to its matching element (1..).

        :param dict: the saved state.  The name is kept for backward
            compatibility; the value is indexed by position.
        """
        self.qt = Qtickle.Qtickle(self)
        self.qt.selectiveGuiRestore(dict[0])
        for key, params in self._saved_alg_params(dict):
            self.alg[key].selectiveSetGuiParams(params)

    def _saved_alg_params(self, saved):
        """Pair each algorithm key with its entry in a saved-state list.

        Entry 0 belongs to the panel itself, so the i-th algorithm (in
        ``self.alg`` order) maps to entry ``i + 1``.  Surplus or missing
        entries are ignored instead of raising IndexError.
        """
        return zip(self.alg, (saved[i] for i in range(1, len(saved))))

    def run(self):
        """Train the selected algorithm on the selected data and register the model.

        Side effects: makes sure the ``'Model Coefficients'`` and
        ``'Model Means'`` data keys exist, increments
        ``Modules.model_count``, stores the fitted model and its X/Y variable
        lists under a descriptive model key, and exports coefficients/means
        when the estimator provides them.
        """
        for key in ('Model Coefficients', 'Model Means'):
            self._ensure_datakey(key)
        Modules.model_count += 1
        self.count = Modules.model_count

        method = self.chooseAlgorithmComboBox.currentText()
        datakey = self.chooseDataComboBox.currentText()
        self.comp_label = self.data[datakey].comp_label
        xvars = [str(item.text()) for item in self.xVariableList.selectedItems()]
        yvars = [(self.comp_label, str(item.text()))
                 for item in self.yVariableList.selectedItems()]
        ymin = self.yMinDoubleSpinBox.value()
        ymax = self.yMaxDoubleSpinBox.value()

        params, modelkey = self.alg[method].run()
        modelkey = "{} - {} - ({}, {}) {}".format(
            method, yvars[0][-1],
            ymin, ymax,
            modelkey
        )
        self.list_amend(self.modelkeys, self.count, modelkey)
        self.models[modelkey] = regression.regression([method], [params])

        df = self.data[datakey].df
        x = np.asarray(df[xvars])
        y = np.asarray(df[yvars])
        # Inclusive bounds: set_yRange fills the spin boxes with the data's own
        # min/max, so strict comparisons would silently drop those samples.
        # all(axis=1) keeps a row only if every selected Y is in range and,
        # unlike np.squeeze, always yields a 1-D row mask.
        in_range = ((y >= ymin) & (y <= ymax)).all(axis=1)
        self.models[modelkey].fit(x[in_range, :], y[in_range])
        self.model_xvars[modelkey] = xvars
        self.model_yvars[modelkey] = yvars
        self._store_model_outputs(modelkey, df, xvars)

    def _ensure_datakey(self, key):
        """Register ``key`` in ``self.datakeys`` with a new data count if absent."""
        if key not in self.datakeys:
            Modules.data_count += 1
            self.list_amend(self.datakeys, Modules.data_count, key)

    def _store_model_outputs(self, modelkey, df, xvars):
        """Best-effort export of the fitted coefficients and X/Y means.

        Coefficients are stored before the means are attempted, so a model
        that has ``coef_`` but no ``x_mean_`` still contributes a coefficient
        row.  A missing attribute is treated as "not available for this
        estimator" and ends the export quietly; any other failure is reported
        on stdout rather than aborting the (already completed) training.
        """
        try:
            columns = df[xvars].columns
            model = self.models[modelkey].model
            coef = self._model_row(model.coef_, columns, modelkey)
            try:
                coef[('meta', 'Intercept')] = model.intercept_
            except (AttributeError, ValueError):
                pass  # no intercept that fits a single-row frame
            self._append_model_frame('Model Coefficients', coef,
                                     spect_label=xvars)
            # track the x mean from the model too
            means = self._model_row(model.x_mean_, columns, modelkey)
            means[('meta', 'ymean')] = model.y_mean_
            self._append_model_frame('Model Means', means)
        except AttributeError:
            # Estimator does not expose coef_ / x_mean_ / y_mean_.
            pass
        except Exception as err:
            print('Could not store coefficients/means for model {}: {}'.format(
                modelkey, err))

    @staticmethod
    def _model_row(values, columns, modelkey):
        """Return a one-row frame of per-feature ``values`` tagged with ``modelkey``.

        ``columns`` supplies the (MultiIndex) column labels for the features;
        a ``('meta', 'Model')`` column carrying ``modelkey`` is appended.
        """
        # atleast_1d keeps a single-feature model from collapsing to a scalar.
        row = pd.DataFrame(np.atleast_1d(np.squeeze(values)))
        row.index = pd.MultiIndex.from_tuples(columns.values)
        row = row.T
        row[('meta', 'Model')] = modelkey
        return row

    def _append_model_frame(self, key, frame, **kwargs):
        """Append ``frame`` to ``self.data[key]``, creating the entry if absent.

        Extra keyword arguments are forwarded to ``SpectralData``.
        """
        try:
            existing = self.data[key].df
        except (KeyError, AttributeError):
            combined = frame
        else:
            combined = pd.concat([existing, frame])
        self.data[key] = SpectralData(combined, name=key, **kwargs)

    def yvar_choices(self):
        """Return the composition columns of the selected data set.

        Falls back to a single placeholder entry if they cannot be read.
        """
        try:
            yvarchoices = self.data[self.chooseDataComboBox.currentText()].df[
                self.comp_label].columns.values
            # Drop columns named like 'Unnamed...' (presumably placeholder
            # headers from the source file).
            yvarchoices = [i for i in yvarchoices if 'Unnamed' not in i]
        except Exception:
            yvarchoices = ['No composition columns!']
        return yvarchoices

    def xvar_choices(self):
        """Return the top-level column groups of the selected data set.

        Falls back to a single placeholder entry if they cannot be read.
        """
        try:
            columns = self.data[self.chooseDataComboBox.currentText()].df.columns
            # MultiIndex.levels can still list groups whose columns were
            # dropped; prune them so only selectable groups are offered.
            xvarchoices = columns.remove_unused_levels().levels[0].values
            # Drop columns named like 'Unnamed...' (presumably placeholder
            # headers from the source file).
            xvarchoices = [i for i in xvarchoices if 'Unnamed' not in i]
        except Exception:
            xvarchoices = ['No valid choices!']
        return xvarchoices

    def hideAll(self):
        """Hide every algorithm parameter panel."""
        for alg in self.alg.values():
            alg.setHidden(True)

    def regressionMethods(self):
        """Create, lay out and hide one parameter panel per regression algorithm.

        The insertion order of ``self.alg`` defines the order used by
        getGuiParams/setGuiParams, so do not reorder the entries.
        """
        self.alg = {
            'ARD': rm.ARD.Ui_Form(),
            'BRR': rm.BayesianRidge.Ui_Form(),
            'Elastic Net': rm.ElasticNet.Ui_Form(),
            # Disabled: 'GP': rm.GP.Ui_Form(),
            # Disabled: 'KRR': rm.KRR.Ui_Form(),
            'LARS': rm.LARS.Ui_Form(),
            'LASSO': rm.Lasso.Ui_Form(),
            # Disabled: 'LASSO LARS': rm.LassoLARS.Ui_Form(),
            'OLS': rm.OLS.Ui_Form(),
            'OMP': rm.OMP.Ui_Form(),
            'PLS': rm.PLS.Ui_Form(),
            'Ridge': rm.Ridge.Ui_Form(),
            'SVR': rm.SVR.Ui_Form(),
            'GBR': rm.GBR.Ui_Form(),
            'RF': rm.RF.Ui_Form()
        }
        for alg in self.alg.values():
            alg.setupUi(self.Form)
            self.algorithmLayout.addWidget(alg.get_widget())
            alg.setHidden(True)
 
if __name__ == "__main__":
    # Stand-alone launcher: show the training panel in its own window.
    import sys

    application = QtWidgets.QApplication(sys.argv)
    host_widget = QtWidgets.QWidget()
    trainer = RegressionTrain()
    trainer.setupUi(host_widget)
    host_widget.show()
    sys.exit(application.exec_())