package biolearn.GraphicalModel.Learning.SuffStat;

import biolearn.GraphicalModel.CPDs.RegressionTree;
import biolearn.GraphicalModel.Learning.SuffStat.Util.RTDP;
import biolearn.GraphicalModel.Learning.SuffStat.Util.RTDPSet;
import biolearn.GraphicalModel.RandomVariable;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import java.util.Vector;

/**
 * Whole-data sufficient statistics for variables whose CPD is (or may become) a
 * {@link RegressionTree}.
 *
 * <p>Provides lookup of the data points that reach a regression-tree node,
 * aggregate {@link Stat} computation, and the candidate-split statistics used
 * when choosing a split threshold. The class name suggests these statistics
 * feed a Normal-Gamma based score; that scoring code is not part of this class
 * (NOTE(review): confirm against the scoring component).
 */
public class NormalGammaStat extends WholeData {

    /**
     * Running sufficient statistics (count, sum, sum of squares) of a set of
     * observed values, optionally tagged with a split threshold.
     *
     * <p>NaN observations are skipped by the single-value {@code add} methods.
     * {@code count} counts accumulated values, while {@code dpcount} counts data
     * points (rows) and is incremented by callers, not by the {@code add}
     * methods for individual values.
     */
    public static class Stat {
        /** Split threshold this statistic belongs to, or NaN if it is not tied to a split. */
        public float threshold;
        /** Number of non-NaN values accumulated. */
        public int count;
        /** Number of data points (rows) accumulated; maintained by callers. */
        public int dpcount;
        /** Sum of the accumulated values. */
        public float sum;
        /** Sum of the squares of the accumulated values. */
        public double sumsq;
        /** Scale factor updated by {@link #multiply(float)}; mean/variance here do not apply it. */
        public float multiplier;

        /** Creates an empty statistic with no threshold and a multiplier of 1. */
        public Stat() {
            this.threshold = Float.NaN;
            this.count = 0;
            this.dpcount = 0;
            this.sum = 0.0f;
            this.sumsq = 0.0d;
            this.multiplier = 1.0f;
        }

        /** Creates an independent copy of {@code stat}, including threshold and multiplier. */
        public Stat(Stat stat) {
            this.threshold = stat.threshold;
            this.count = stat.count;
            this.dpcount = stat.dpcount;
            this.sum = stat.sum;
            this.sumsq = stat.sumsq;
            this.multiplier = stat.multiplier;
        }

        /**
         * Creates a statistic over {@code fArr}; NaN entries are skipped.
         *
         * @param fArr values to accumulate (must not be null)
         */
        public Stat(float[] fArr) {
            this();
            for (float value : fArr) {
                add(value);
            }
        }

        /**
         * Adds another statistic's count, dpcount, sum and sum of squares into this one.
         * The other statistic's threshold and multiplier are ignored.
         */
        public void add(Stat stat) {
            this.count += stat.count;
            this.dpcount += stat.dpcount;
            this.sum += stat.sum;
            this.sumsq += stat.sumsq;
        }

        /**
         * Removes another statistic's count, dpcount, sum and sum of squares from this one
         * (inverse of {@link #add(Stat)}). The misspelled name is kept for API compatibility.
         */
        public void substract(Stat stat) {
            this.count -= stat.count;
            this.dpcount -= stat.dpcount;
            this.sum -= stat.sum;
            this.sumsq -= stat.sumsq;
        }

        /** Accumulates a single value; NaN is ignored. */
        public void add(float f) {
            if (Float.isNaN(f)) {
                return;
            }
            this.count++;
            this.sum += f;
            // Square in double precision: squaring in float first would round away
            // precision that the double accumulator is meant to keep.
            this.sumsq += (double) f * f;
        }

        /**
         * Accumulates variable {@code i} of data point {@code rtdp}, unless it is flagged as NaN.
         * Uses the point's precomputed {@code sq[i]} for the sum of squares
         * (presumably {@code val[i]} squared — TODO confirm in RTDP).
         */
        public void add(RTDP rtdp, int i) {
            if (rtdp.isNaN[i]) {
                return;
            }
            this.count++;
            this.sum += rtdp.val[i];
            this.sumsq += rtdp.sq[i];
        }

        /** Multiplies the stored multiplier by {@code f}; the sums are left untouched. */
        public void multiply(float f) {
            this.multiplier *= f;
        }

        @Override
        public String toString() {
            return "Regression Tree Stat " + this.count + '*' + this.multiplier + ' ' + this.sum + ' ' + this.sumsq;
        }

        /**
         * Population variance ({@code sumsq/n - mean^2}) of the accumulated values.
         *
         * @return the variance, or 0 when nothing has been accumulated
         */
        public double variance() {
            if (this.count == 0) {
                return 0.0d;
            }
            double n = this.count;
            double s = this.sum;
            // Work in double: a float sum*sum loses precision and can overflow to Infinity.
            double v = (this.sumsq - (s * s) / n) / n;
            // The one-pass formula can go slightly negative through cancellation;
            // clamp so std() never returns NaN for valid data.
            return Math.max(0.0d, v);
        }

        /** Population standard deviation, i.e. the square root of {@link #variance()}. */
        public double std() {
            return Math.sqrt(variance());
        }

        /** Mean of the accumulated values; NaN (0/0) when nothing has been accumulated. */
        public float mean() {
            return this.sum / this.count;
        }
    }

    /**
     * Orders data points by their value of a single variable.
     *
     * <p>{@link Float#compare} returns 0 for ties and orders NaN after every other value,
     * so the {@link Comparator} contract holds. The previous {@code a < b ? -1 : 1} form
     * never returned 0 and could make {@link Collections#sort} throw.
     */
    private static class OrderByVar implements Comparator<RTDP> {
        private final int var;

        OrderByVar(int var) {
            this.var = var;
        }

        @Override
        public int compare(RTDP a, RTDP b) {
            return Float.compare(a.val[this.var], b.val[this.var]);
        }
    }

    /**
     * Accepts a variable when any of these holds:
     * {@code MayAddParent(-1)} returns false (presumably the variable cannot gain
     * parents — TODO confirm), it has no CPD yet, or its CPD is a {@link RegressionTree}.
     */
    @Override // biolearn.GraphicalModel.Learning.SuffStat.WholeData, biolearn.GraphicalModel.Learning.SufficientStatistic
    public boolean compatibleVar(RandomVariable randomVariable) {
        return !randomVariable.MayAddParent(-1) || randomVariable.CPD() == null || (randomVariable.CPD() instanceof RegressionTree);
    }

    /**
     * Returns the data points that reach tree node {@code i}.
     *
     * <p>Node indices use binary-heap numbering: the root is 1, and the children of k
     * are 2k (left) and 2k+1 (right). The root's data is the whole data group
     * selected by {@code same_var_conditions_index[iArr[0]]}. Any other node's data
     * is the left or right partition stored on its parent, found in {@code nodeArr}.
     *
     * @param iArr    variables of interest; only {@code iArr[0]} is used (to pick the data group)
     * @param nodeArr nodes to search for the parent of {@code i}
     * @param i       heap index of the node
     * @return the node's data points
     * @throws IllegalArgumentException if {@code i != 1} and no node in {@code nodeArr}
     *                                  has index {@code i / 2}
     */
    public RTDPSet getDataFromNode(int[] iArr, RegressionTree.Node[] nodeArr, int i) {
        if (i == 1) {
            return (RTDPSet) this.data[this.same_var_conditions_index[iArr[0]]];
        }
        int parentIndex = i / 2;
        for (RegressionTree.Node parent : nodeArr) {
            if (parent.index == parentIndex) {
                return i % 2 == 0 ? parent.left_data : parent.right_data;
            }
        }
        throw new IllegalArgumentException("No parent node with index " + parentIndex + " for node " + i);
    }

    /**
     * Aggregate statistics for the variables in {@code iArr}, over the whole data group
     * selected by {@code iArr[0]}. Delegates to {@code RTDPSet.getStats}.
     */
    public Stat getStats(int[] iArr) {
        return ((RTDPSet) this.data[this.same_var_conditions_index[iArr[0]]]).getStats(iArr);
    }

    /**
     * Computes the candidate-split statistics for node {@code nodeArr[0]}, splitting
     * on that node's variable.
     *
     * <p>The node's data points are appended to {@code rTDPSet}, which is then sorted
     * in place by the split variable. The caller presumably passes an empty scratch
     * set — TODO confirm. With k = (number of distinct split values) - 1, the
     * returned list holds:
     * <ul>
     *   <li>entries 0..k-1: left-side stats in ascending threshold order. Each covers
     *       every point whose value is strictly below its {@code threshold}, which is
     *       the smallest value on the right side.</li>
     *   <li>entries k..2k-1: right-side stats in descending threshold order (their
     *       {@code threshold} is NaN). Entry {@code 2k-1-j} is the complement of
     *       entry {@code j}.</li>
     * </ul>
     * Each point contributes its {@code fixedStat} when one is set (see
     * {@link #fixVars}); otherwise it contributes its values for the variables in
     * {@code iArr}.
     *
     * <p>NOTE(review): NaN split values sort last and never open a boundary. A
     * trailing NaN would suppress the right-side sweep. Confirm that split variables
     * cannot hold NaN.
     *
     * @return the split statistics; empty when the node has no data
     */
    public List<Stat> getStats(int[] iArr, RegressionTree.Node[] nodeArr, RTDPSet rTDPSet) {
        Vector<Stat> splits = new Vector<Stat>();
        RTDPSet nodeData = getDataFromNode(iArr, nodeArr, nodeArr[0].index);
        if (nodeData == null || nodeData.isEmpty()) {
            return splits;
        }
        rTDPSet.addAll(nodeData);
        int splitVar = nodeArr[0].var.Index();
        Collections.sort(rTDPSet, new OrderByVar(splitVar));

        // Forward sweep: snapshot the left-side accumulator each time a strictly
        // larger split value starts a new group.
        Stat left = new Stat();
        ListIterator<RTDP> it = rTDPSet.listIterator();
        RTDP point = it.next();
        float boundary = point.val[splitVar];
        while (true) {
            float value = point.val[splitVar];
            if (value > boundary) {
                boundary = value;
                left.threshold = value;
                splits.add(new Stat(left));
            }
            addDataPoint(left, point, iArr);
            if (!it.hasNext()) {
                break;
            }
            point = it.next();
        }

        // Backward sweep from the last point: snapshot the right-side accumulator
        // each time a strictly smaller split value is reached.
        Stat right = new Stat();
        boundary = point.val[splitVar];
        while (it.hasPrevious()) {
            RTDP prev = it.previous();
            float value = prev.val[splitVar];
            if (value < boundary) {
                boundary = value;
                splits.add(new Stat(right));
            }
            addDataPoint(right, prev, iArr);
        }
        return splits;
    }

    /**
     * Adds one data point to {@code stat}: increments its dpcount, then adds either
     * the point's fixed stats or its values for each variable in {@code vars}.
     */
    private static void addDataPoint(Stat stat, RTDP point, int[] vars) {
        stat.dpcount++;
        if (point.fixedStat != null) {
            stat.add(point.fixedStat);
        } else {
            for (int var : vars) {
                stat.add(point, var);
            }
        }
    }

    /**
     * Precomputes, for every data point, a {@link Stat} over the variables in
     * {@code collection}. While set, {@code getStats(int[], Node[], RTDPSet)} uses it
     * in place of that call's variable list.
     */
    public void fixVars(Collection<Integer> collection) {
        for (int i = 0; i < this.data.length; i++) {
            for (RTDP rtdp : this.data[i]) {
                Stat fixed = new Stat();
                for (int var : collection) {
                    fixed.add(rtdp, var);
                }
                rtdp.fixedStat = fixed;
            }
        }
    }

    /** Clears the per-point statistics installed by {@link #fixVars}. */
    public void unfixVars() {
        for (int i = 0; i < this.data.length; i++) {
            for (RTDP rtdp : this.data[i]) {
                rtdp.fixedStat = null;
            }
        }
    }
}
