package Catalano.MachineLearning.Regression.RegressionTrees.Learning;

import Catalano.Core.ArraysUtil;
import Catalano.MachineLearning.Dataset.DatasetRegression;
import Catalano.MachineLearning.Dataset.DecisionVariable;
import Catalano.MachineLearning.Regression.IRegression;
import Catalano.MachineLearning.Regression.RegressionTrees.RegressionTree;
import Catalano.Statistics.DescriptiveStatistics;
import Catalano.Statistics.Tools;
import java.io.Serializable;
import java.util.Arrays;

/* loaded from: classes.dex */
/**
 * Gradient tree boosting for regression.
 *
 * <p>The fitted model predicts {@code b + shrinkage * sum_t tree_t(x)}, where {@code b} is a
 * constant baseline (mean of the targets for {@link Loss#LeastSquares}, median otherwise) and each
 * {@link RegressionTree} is fitted to the pseudo-response of the chosen {@link Loss} on a random
 * subsample (drawn without replacement) of the training rows.
 */
public class GradientBoostingTree implements IRegression, Serializable {
    /** Quantile of the absolute residuals used as the Huber clipping threshold. */
    private static final double HUBER_ALPHA = 0.9d;

    /** Tree-size parameter forwarded to every {@link RegressionTree} (see {@link #getNumLeaves()}). */
    private int J;
    /** Number of trees grown by {@link #Learn(double[][], double[])}. */
    private int T;
    /** Feature descriptors; generated as continuous "F0", "F1", ... when not supplied. */
    private DecisionVariable[] attributes;
    /** Constant baseline added to every prediction. */
    private double b;
    /** Fraction of training rows sampled for each tree, in (0, 1]. */
    private double f;
    /** Per-feature importance summed over the trees currently in the ensemble. */
    private double[] importance;
    private Loss loss;
    /** Learning rate applied to every tree's contribution, in (0, 1]. */
    private double shrinkage;
    private RegressionTree[] trees;

    /**
     * Leaf-value estimator for the Huber loss.
     *
     * <p>The constructor chooses {@code delta} from the distribution of {@code |residual|} and writes
     * the clipped residuals (the Huber pseudo-response) into {@code response}. Declared static: it
     * needs no state from the enclosing model.
     */
    public static class HuberNodeOutput implements RegressionTree.NodeOutput {
        double alpha;
        double delta;
        double[] residual;
        double[] response;

        /**
         * @param residual current residuals {@code y - F(x)}; read but not modified here
         * @param response buffer with the same length as {@code residual}; overwritten with the
         *     clipped residuals
         * @param alpha quantile of {@code |residual|} used as the clipping threshold
         */
        public HuberNodeOutput(double[] residual, double[] response, double alpha) {
            this.residual = residual;
            this.response = response;
            this.alpha = alpha;
            int n = residual.length;
            // Use the response buffer as scratch space holding |residual|.
            for (int k = 0; k < n; k++) {
                response[k] = Math.abs(residual[k]);
            }
            // Clamp the rank so alpha == 1 does not index one past the end.
            // NOTE(review): assumes Argsort(..., true) yields ascending order, making this the
            // alpha-quantile.
            int rank = Math.min(n - 1, (int) (n * alpha));
            this.delta = response[ArraysUtil.Argsort(response, true)[rank]];
            for (int k = 0; k < n; k++) {
                double r = residual[k];
                response[k] = Math.abs(r) <= this.delta ? r : this.delta * Math.signum(r);
            }
        }

        /**
         * Returns the Huber leaf value for the rows flagged in {@code samples}: the median of their
         * residuals plus the mean of their delta-clipped deviations from that median.
         */
        @Override // Catalano.MachineLearning.Regression.RegressionTrees.RegressionTree.NodeOutput
        public double calculate(int[] samples) {
            double[] sampled = selectSampled(this.residual, samples);
            double median = DescriptiveStatistics.Median(sampled);
            double sum = 0.0d;
            for (int k = 0; k < samples.length; k++) {
                if (samples[k] > 0) {
                    double diff = this.residual[k] - median;
                    sum += Math.signum(diff) * Math.min(this.delta, Math.abs(diff));
                }
            }
            return median + (sum / sampled.length);
        }
    }

    /**
     * Leaf-value estimator for least-absolute-deviation loss: the median residual of the rows in
     * the leaf. Declared static: it needs no state from the enclosing model.
     */
    public static class LADNodeOutput implements RegressionTree.NodeOutput {
        double[] residual;

        /** @param residual residual array shared with (and updated by) the boosting loop */
        public LADNodeOutput(double[] residual) {
            this.residual = residual;
        }

        @Override // Catalano.MachineLearning.Regression.RegressionTrees.RegressionTree.NodeOutput
        public double calculate(int[] samples) {
            return DescriptiveStatistics.Median(selectSampled(this.residual, samples));
        }
    }

    /** Loss function minimised by the boosting procedure. */
    public enum Loss {
        LeastSquares,
        LeastAbsoluteDeviation,
        Huber
    }

    public GradientBoostingTree() {
        this(500);
    }

    public GradientBoostingTree(int ntrees) {
        this(ntrees, 6);
    }

    public GradientBoostingTree(int ntrees, int numLeaves) {
        this(ntrees, numLeaves, Loss.LeastSquares);
    }

    public GradientBoostingTree(int ntrees, int numLeaves, Loss loss) {
        this(ntrees, numLeaves, loss, 0.005d);
    }

    public GradientBoostingTree(int ntrees, int numLeaves, Loss loss, double shrinkage) {
        this(ntrees, numLeaves, loss, shrinkage, 0.7d);
    }

    public GradientBoostingTree(int ntrees, int numLeaves, Loss loss, double shrinkage, double samplingRate) {
        this(null, ntrees, numLeaves, loss, shrinkage, samplingRate);
    }

    public GradientBoostingTree(DecisionVariable[] attributes) {
        this(attributes, 500);
    }

    public GradientBoostingTree(DecisionVariable[] attributes, int ntrees) {
        this(attributes, ntrees, 6);
    }

    public GradientBoostingTree(DecisionVariable[] attributes, int ntrees, int numLeaves) {
        this(attributes, ntrees, numLeaves, Loss.LeastSquares);
    }

    public GradientBoostingTree(DecisionVariable[] attributes, int ntrees, int numLeaves, Loss loss) {
        this(attributes, ntrees, numLeaves, loss, 0.005d);
    }

    public GradientBoostingTree(DecisionVariable[] attributes, int ntrees, int numLeaves, Loss loss, double shrinkage) {
        this(attributes, ntrees, numLeaves, loss, shrinkage, 0.7d);
    }

    /**
     * Creates an untrained model. Hyper-parameters are validated when
     * {@link #Learn(double[][], double[])} is called.
     *
     * @param attributes feature descriptors, or {@code null} to treat every feature as continuous
     * @param ntrees number of trees to grow
     * @param numLeaves tree-size parameter passed to each {@link RegressionTree}
     * @param loss loss function to minimise
     * @param shrinkage learning rate, expected in (0, 1]
     * @param samplingRate fraction of rows sampled per tree, expected in (0, 1]
     */
    public GradientBoostingTree(DecisionVariable[] attributes, int ntrees, int numLeaves, Loss loss, double shrinkage, double samplingRate) {
        this.attributes = attributes;
        this.T = ntrees;
        this.J = numLeaves;
        this.loss = loss;
        this.shrinkage = shrinkage;
        this.f = samplingRate;
    }

    /** Returns the entries of {@code values} whose {@code samples} flag is positive, in index order. */
    private static double[] selectSampled(double[] values, int[] samples) {
        int count = 0;
        for (int s : samples) {
            if (s > 0) {
                count++;
            }
        }
        double[] out = new double[count];
        for (int k = 0, j = 0; k < samples.length; k++) {
            if (samples[k] > 0) {
                out[j++] = values[k];
            }
        }
        return out;
    }

    /**
     * Pre-computes, for every continuous feature, the row indices ordered by that feature's value.
     * Entries for non-continuous features are left {@code null}.
     */
    private int[][] sort(DecisionVariable[] attributes, double[][] x) {
        int n = x.length;
        int p = x[0].length;
        double[] column = new double[n];
        int[][] order = new int[p][];
        for (int j = 0; j < p; j++) {
            if (attributes[j].type == DecisionVariable.Type.Continuous) {
                for (int k = 0; k < n; k++) {
                    column[k] = x[k][j];
                }
                order[j] = ArraysUtil.Argsort(column, true);
            }
        }
        return order;
    }

    @Override // Catalano.MachineLearning.Regression.IRegression
    public void Learn(DatasetRegression datasetRegression) {
        Learn(datasetRegression.getInput(), datasetRegression.getOutput());
    }

    /**
     * Fits the ensemble to {@code x}/{@code y}, replacing any previously trained trees.
     *
     * @throws IllegalArgumentException if the data are empty or inconsistent, or a hyper-parameter
     *     is out of range
     */
    @Override // Catalano.MachineLearning.Regression.IRegression
    public void Learn(double[][] x, double[] y) {
        validateTrainingInput(x, y);
        ensureAttributes(x[0].length);

        int n = x.length;
        // Rows per tree; at least one so no tree is grown on an empty subsample.
        int sampleSize = Math.max(1, (int) Math.round(n * this.f));
        int[] permutation = new int[n];
        for (int k = 0; k < n; k++) {
            permutation[k] = k;
        }
        int[] samples = new int[n];

        // residual[k] tracks y[k] - F(x[k]) for the ensemble built so far.
        double[] residual = new double[n];
        double[] response = initializeBaseline(y, residual);
        // For LeastSquares no custom leaf estimator is supplied (null is passed to RegressionTree).
        RegressionTree.NodeOutput output =
                this.loss == Loss.LeastAbsoluteDeviation ? new LADNodeOutput(residual) : null;

        int[][] order = sort(this.attributes, x);
        this.trees = new RegressionTree[this.T];
        for (int t = 0; t < this.T; t++) {
            // Flag a fresh random subsample of rows for this tree.
            Arrays.fill(samples, 0);
            Catalano.Math.Tools.Permutate(permutation);
            for (int k = 0; k < sampleSize; k++) {
                samples[permutation[k]] = 1;
            }
            if (this.loss == Loss.Huber) {
                // Recomputes delta and the clipped pseudo-response from the current residuals.
                output = new HuberNodeOutput(residual, response, HUBER_ALPHA);
            }
            RegressionTree tree = new RegressionTree(this.attributes, this.J, order, samples, output);
            this.trees[t] = tree;
            tree.Learn(x, response);
            for (int k = 0; k < n; k++) {
                residual[k] -= this.shrinkage * tree.Predict(x[k]);
                if (this.loss == Loss.LeastAbsoluteDeviation) {
                    response[k] = Math.signum(residual[k]);
                }
            }
        }
        accumulateImportance();
    }

    /** Checks the training data and hyper-parameters before any state is modified. */
    private void validateTrainingInput(double[][] x, double[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", Integer.valueOf(x.length), Integer.valueOf(y.length)));
        }
        if (this.shrinkage <= 0.0d || this.shrinkage > 1.0d) {
            throw new IllegalArgumentException("Invalid shrinkage: " + this.shrinkage);
        }
        if (this.f <= 0.0d || this.f > 1.0d) {
            throw new IllegalArgumentException("Invalid sampling fraction: " + this.f);
        }
        if (x.length == 0) {
            throw new IllegalArgumentException("The training set is empty.");
        }
        if (this.T < 1) {
            throw new IllegalArgumentException("Invalid number of trees: " + this.T);
        }
        if (this.loss == null) {
            throw new IllegalArgumentException("The loss function must not be null.");
        }
    }

    /** Creates default continuous attributes if none were given; otherwise checks their count. */
    private void ensureAttributes(int numFeatures) {
        if (this.attributes == null) {
            this.attributes = new DecisionVariable[numFeatures];
            for (int j = 0; j < numFeatures; j++) {
                this.attributes[j] = new DecisionVariable("F" + j);
            }
        } else if (this.attributes.length != numFeatures) {
            throw new IllegalArgumentException(String.format("The number of attributes doesn't match the number of features: %d != %d", Integer.valueOf(this.attributes.length), Integer.valueOf(numFeatures)));
        }
    }

    /**
     * Sets the baseline {@code b}, fills {@code residual} with {@code y - b}, and returns the array
     * the first tree is fitted to.
     *
     * <ul>
     *   <li>LeastSquares: {@code b} is the mean; the returned array IS {@code residual}.
     *   <li>LeastAbsoluteDeviation: {@code b} is the median; returns {@code sign(residual)}.
     *   <li>Huber: {@code b} is the median; returns a buffer later filled by {@link HuberNodeOutput}.
     * </ul>
     */
    private double[] initializeBaseline(double[] y, double[] residual) {
        int n = y.length;
        switch (this.loss) {
            case LeastSquares:
                this.b = Tools.Mean(y);
                for (int k = 0; k < n; k++) {
                    residual[k] = y[k] - this.b;
                }
                return residual;
            case LeastAbsoluteDeviation: {
                // The median is taken on a scratch copy so the caller's y is never passed to Median.
                System.arraycopy(y, 0, residual, 0, n);
                this.b = DescriptiveStatistics.Median(residual);
                double[] response = new double[n];
                for (int k = 0; k < n; k++) {
                    residual[k] = y[k] - this.b;
                    response[k] = Math.signum(residual[k]);
                }
                return response;
            }
            case Huber:
                System.arraycopy(y, 0, residual, 0, n);
                this.b = DescriptiveStatistics.Median(residual);
                for (int k = 0; k < n; k++) {
                    residual[k] = y[k] - this.b;
                }
                return new double[n];
            default:
                throw new IllegalStateException("Unsupported loss: " + this.loss);
        }
    }

    /** Recomputes {@link #importance} as the element-wise sum over the current trees. */
    private void accumulateImportance() {
        this.importance = new double[this.attributes.length];
        for (RegressionTree tree : this.trees) {
            double[] treeImportance = tree.getImportance();
            for (int j = 0; j < treeImportance.length; j++) {
                this.importance[j] += treeImportance[j];
            }
        }
    }

    /**
     * Predicts the response for one feature vector using every tree currently in the ensemble
     * (which may be fewer than the number trained, after {@link #trim(int)}).
     *
     * @throws IllegalStateException if the model has not been trained
     */
    @Override // Catalano.MachineLearning.Regression.IRegression
    public double Predict(double[] x) {
        if (this.trees == null) {
            throw new IllegalStateException("The model has not been trained yet.");
        }
        double prediction = this.b;
        for (RegressionTree tree : this.trees) {
            prediction += this.shrinkage * tree.Predict(x);
        }
        return prediction;
    }

    /** Returns a shallow copy of this model (trees and arrays are shared with the original). */
    @Override // Catalano.MachineLearning.Regression.IRegression
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public IRegression m12clone() {
        try {
            return (IRegression) super.clone();
        } catch (CloneNotSupportedException e) {
            throw new IllegalArgumentException("Clone not supported: " + e.getMessage(), e);
        }
    }

    public Loss getLossFunction() {
        return this.loss;
    }

    public int getNumLeaves() {
        return this.J;
    }

    public double getSamplingRate() {
        return this.f;
    }

    /**
     * Returns per-feature importance summed over the trees in the ensemble, or {@code null} before
     * training. The internal array is returned, not a copy.
     */
    public double[] importance() {
        return this.importance;
    }

    /** Returns the number of trees in the ensemble, or 0 if the model has not been trained. */
    public int size() {
        return this.trees == null ? 0 : this.trees.length;
    }

    /**
     * Keeps only the first {@code ntrees} trees and recomputes feature importance for them.
     *
     * @throws IllegalStateException if the model has not been trained
     * @throws IllegalArgumentException if {@code ntrees} is not in {@code [1, size()]}
     */
    public void trim(int ntrees) {
        if (this.trees == null) {
            throw new IllegalStateException("The model has not been trained yet.");
        }
        if (ntrees > this.trees.length) {
            throw new IllegalArgumentException("The new model size is larger than the current size.");
        }
        if (ntrees <= 0) {
            throw new IllegalArgumentException("Invalid new model size: " + ntrees);
        }
        if (ntrees < this.trees.length) {
            this.trees = Arrays.copyOf(this.trees, ntrees);
            accumulateImportance();
        }
    }
}
