package biolearn.GraphicalModel.Learning.SuffStat;

import biolearn.GraphicalModel.Learning.LearningException;
import biolearn.GraphicalModel.Learning.ObservationCondition;
import biolearn.GraphicalModel.Learning.SuffStat.NormalGammaStat;
import biolearn.GraphicalModel.Learning.SuffStat.Util.DataPoint;
import biolearn.GraphicalModel.Learning.SuffStat.Util.RTDP;
import biolearn.GraphicalModel.Learning.SuffStat.Util.RTDPSet;
import biolearn.GraphicalModel.Learning.SufficientStatistic;
import biolearn.GraphicalModel.RandomVariable;
import biolearn.Inconsistency;
import biolearn.NotImplementedYet;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import java.util.Random;
import java.util.Set;
import java.util.Vector;

/* loaded from: input_file:biolearn/GraphicalModel/Learning/SuffStat/WholeData.class */
/**
 * A {@link SufficientStatistic} that keeps every data point it is given instead of
 * summarising them.
 *
 * <p>Slot 0 of the inherited {@code data} array holds the points stored for the null
 * observation condition; the remaining slots are addressed through the inherited
 * {@code same_var_conditions_index} table. On top of plain storage the class offers:
 * <ul>
 *   <li>permutation of a single variable across points (optionally stratified by the
 *       value of another variable) and undoing it,</li>
 *   <li>copying all points, a bootstrap sample, or a cross-validation fold into another
 *       statistic,</li>
 *   <li>filling in missing values from the other variables of the same cluster,</li>
 *   <li>per-variable normalization ({@link #VARIANCE} or {@link #LENGTH}).</li>
 * </ul>
 */
public class WholeData extends SufficientStatistic {
    /** Normalization mode: values are stored as supplied. */
    public static final int NONE = 0;
    /** Normalization mode: each variable is centred on its mean and divided by its standard deviation. */
    public static final int VARIANCE = 1;
    /** Normalization mode: as {@link #VARIANCE}, then additionally divided by sqrt(number of points). */
    public static final int LENGTH = 2;

    /** One entry per stored point; defaults to the insertion position, see {@link #SetDataPointIndex(int)}. */
    public List<Integer> point_indices;
    /** Active normalization mode, one of {@link #NONE}, {@link #VARIANCE}, {@link #LENGTH}. */
    private int normalize = NONE;
    /** Per-variable means, computed lazily by {@link #calculate_mean_and_std()} and then cached. */
    private float[] means = null;
    /** Per-variable standard deviations, computed together with {@link #means}. */
    private double[] stds = null;
    /** Indices of variables whose standard deviation was exactly zero when {@link #Normalize(int)} ran. */
    public Set<Integer> fixed_vars = null;
    /** Variable currently holding permuted values, or -1 when no permutation is active. */
    private int currently_permuted_var = -1;
    /** Random source used by {@link #SampleWithReplacement(int, SufficientStatistic, int)}. */
    private final Random generator = new Random();

    /**
     * Accepts every variable.
     *
     * @param randomVariable ignored
     * @return always {@code true}
     */
    @Override // biolearn.GraphicalModel.Learning.SufficientStatistic
    public boolean compatibleVar(RandomVariable randomVariable) {
        return true;
    }

    /**
     * Initializes the superclass, then allocates one empty {@link RTDPSet} per entry of
     * {@code same_var_conditions}, switches normalization off and empties the index list.
     *
     * @param observationConditionArr conditions forwarded to the superclass
     * @throws LearningException propagated from the superclass
     * @throws NotImplementedYet if the first possible condition is not the null condition
     */
    @Override // biolearn.GraphicalModel.Learning.SufficientStatistic
    public void initialize(ObservationCondition[] observationConditionArr) throws LearningException {
        super.initialize(observationConditionArr);
        if (this.possible_conditions[0] != null) {
            throw new NotImplementedYet("normal gamma stat not including the null condition");
        }
        final int setCount = this.same_var_conditions.size();
        this.data = new List[setCount];
        for (int slot = 0; slot < setCount; slot++) {
            this.data[slot] = new RTDPSet();
        }
        this.normalize = NONE;
        this.point_indices = new Vector<Integer>();
    }

    /**
     * Stores a point via the superclass and records its position.
     *
     * <p>The point (which must be an {@link RTDP}) gets {@code index} set to the last
     * position of slot 0; any non-null {@code inhibition_originals} receive the same index.
     * A new entry equal to the current length of {@link #point_indices} is appended.
     *
     * @param dataPoint the point to add; cast to {@link RTDP}
     * @param i         forwarded unchanged to the superclass
     */
    @Override // biolearn.GraphicalModel.Learning.SufficientStatistic
    public void addDataPoint(DataPoint dataPoint, int i) {
        super.addDataPoint(dataPoint, i);
        final RTDP added = (RTDP) dataPoint;
        added.index = this.data[0].size() - 1;
        if (dataPoint.inhibition_originals != null) {
            for (int k = 0; k < dataPoint.inhibition_originals.length; k++) {
                if (dataPoint.inhibition_originals[k] != null) {
                    ((RTDP) dataPoint.inhibition_originals[k]).index = added.index;
                }
            }
        }
        this.point_indices.add(Integer.valueOf(this.point_indices.size()));
    }

    /**
     * Overwrites the index recorded for the most recently added point.
     *
     * @param i the index to record; fails with an index exception when no point has been added
     */
    public void SetDataPointIndex(int i) {
        this.point_indices.set(this.point_indices.size() - 1, Integer.valueOf(i));
    }

    /**
     * @param i position in {@link #point_indices}
     * @return the index recorded at that position
     */
    public int DataPointIndex(int i) {
        return this.point_indices.get(i).intValue();
    }

    /**
     * Permutes the values of variable {@code i} across the points of slot 0.
     *
     * <p>When {@code i2} is negative or equal to {@code i} the donor order is a plain shuffle.
     * Otherwise points are grouped by the integer value of variable {@code i2} (see
     * {@link #toInt(RTDP, int)}) and shuffled within each group, so every point receives a
     * value from a point with the same {@code i2} value. If a different variable is currently
     * permuted, that permutation is undone point by point before the new value is written.
     *
     * @param i  variable whose values are permuted
     * @param i2 stratifying variable, or a negative number / {@code i} for an unstratified shuffle
     */
    public void permute(int i, int i2) {
        final RTDPSet points = (RTDPSet) this.data[0];
        // donors.get(k) is the point whose original value of variable i goes to points.get(k).
        final Vector<RTDP> donors = new Vector<RTDP>(points);
        if (i2 < 0 || i2 == i) {
            Collections.shuffle(donors);
        } else {
            final int lowest = this.vars[i2].minValue().intValue();
            final Vector<List<RTDP>> strata = new Vector<List<RTDP>>();
            for (int level = 0; level < this.vars[i2].numValues(); level++) {
                strata.add(new Vector<RTDP>());
            }
            for (RTDP point : donors) {
                strata.get(toInt(point, i2) - lowest).add(point);
            }
            final Vector<ListIterator<RTDP>> cursors = new Vector<ListIterator<RTDP>>();
            for (List<RTDP> stratum : strata) {
                Collections.shuffle(stratum);
                cursors.add(stratum.listIterator());
            }
            // Rebuild the donor list in storage order, drawing each donor from the receiver's own stratum.
            donors.clear();
            for (RTDP receiver : points) {
                donors.add(cursors.get(toInt(receiver, i2) - lowest).next());
            }
        }
        final boolean revertPrevious = this.currently_permuted_var >= 0 && this.currently_permuted_var != i;
        final ListIterator<RTDP> receivers = points.listIterator();
        final ListIterator<RTDP> givers = donors.listIterator();
        while (receivers.hasNext()) {
            final RTDP receiver = receivers.next();
            final RTDP giver = givers.next();
            if (revertPrevious) {
                receiver.undoPermutation(this.currently_permuted_var);
            }
            receiver.set(i, giver.orig_val[i], true);
        }
        this.currently_permuted_var = i;
    }

    /**
     * Maps the stored value of variable {@code i} back to the original scale and rounds it.
     *
     * <p>Reverses the {@link #LENGTH} scaling (multiply by sqrt(number of points)) and then,
     * for any active normalization, the standardization (multiply by std, add mean).
     *
     * @param rtdp point to read
     * @param i    variable index
     * @return the rounded de-normalized value
     */
    private int toInt(RTDP rtdp, int i) {
        float value = rtdp.val[i];
        if (this.normalize == LENGTH) {
            value = (float) (value * Math.sqrt(this.data[0].size()));
        }
        if (this.normalize != NONE) {
            value = (value * ((float) this.stds[i])) + this.means[i];
        }
        return Math.round(value);
    }

    /** Reverts the active permutation on every point of slot 0, if any, and marks none as active. */
    public void undo_permutation() {
        if (this.currently_permuted_var >= 0) {
            for (RTDP point : (RTDPSet) this.data[0]) {
                point.undoPermutation(this.currently_permuted_var);
            }
        }
        this.currently_permuted_var = -1;
    }

    /**
     * Re-initializes {@code sufficientStatistic} and feeds it every point stored for a condition.
     *
     * @param i                   condition index looked up in {@code same_var_conditions_index};
     *                            a negative value selects slot 0
     * @param sufficientStatistic target statistic; initialized with {@code null} conditions
     * @throws Inconsistency if the target raises a {@link LearningException}
     */
    public void GetAll(int i, SufficientStatistic sufficientStatistic) {
        try {
            sufficientStatistic.initialize(null);
            for (RTDP point : selectConditionData(i)) {
                sufficientStatistic.addDataPoint(point, 0);
            }
            sufficientStatistic.endOfData();
        } catch (LearningException e) {
            throw new Inconsistency(e.toString());
        }
    }

    /**
     * Shorthand for {@link #GetAll(int, SufficientStatistic)} on slot 0.
     *
     * @param sufficientStatistic target statistic
     */
    public void GetAll(SufficientStatistic sufficientStatistic) {
        GetAll(-1, sufficientStatistic);
    }

    /**
     * Re-initializes {@code sufficientStatistic} and feeds it {@code i2} points drawn uniformly,
     * with replacement, from the set stored for a condition.
     *
     * @param i                   condition index; a negative value selects slot 0
     * @param sufficientStatistic target statistic; initialized with {@code null} conditions
     * @param i2                  number of points to draw
     * @throws Inconsistency if the target raises a {@link LearningException}
     */
    public void SampleWithReplacement(int i, SufficientStatistic sufficientStatistic, int i2) {
        try {
            sufficientStatistic.initialize(null);
            final RTDPSet pool = selectConditionData(i);
            for (int draw = 0; draw < i2; draw++) {
                // The bound must be the size of the set actually sampled; bounding by slot 0
                // can produce an out-of-range index when the condition's set is smaller.
                sufficientStatistic.addDataPoint(pool.get(this.generator.nextInt(pool.size())), 0);
            }
            sufficientStatistic.endOfData();
        } catch (LearningException e) {
            throw new Inconsistency(e.toString());
        }
    }

    /** @return the point set in slot 0 */
    @Override // biolearn.GraphicalModel.Learning.SufficientStatistic
    public RTDPSet Data() {
        return (RTDPSet) this.data[0];
    }

    /**
     * @param i condition index looked up in {@code same_var_conditions_index}
     * @return the point set stored for that condition
     */
    @Override // biolearn.GraphicalModel.Learning.SufficientStatistic
    public RTDPSet Data(int i) {
        return (RTDPSet) this.data[this.same_var_conditions_index[i]];
    }

    /**
     * Re-initializes {@code sufficientStatistic} and feeds it every point of slot 0 that lies
     * outside fold {@code i2} of {@code i} folds, i.e. the complement of {@link #GetTest}.
     * With a single fold ({@code i == 1}) all points are passed on.
     *
     * @param sufficientStatistic target statistic
     * @param i                   number of folds
     * @param i2                  zero-based fold to leave out
     * @throws NotImplementedYet if more than one same-variable condition exists
     * @throws Inconsistency     if the target raises a {@link LearningException}
     */
    public void GetValidation(SufficientStatistic sufficientStatistic, int i, int i2) {
        if (i == 1) {
            GetAll(sufficientStatistic);
            return;
        }
        if (this.same_var_conditions.size() > 1) {
            throw new NotImplementedYet("Cross-validation with observation conditions");
        }
        try {
            sufficientStatistic.initialize(null);
            final RTDPSet points = (RTDPSet) this.data[0];
            final int total = this.data[0].size();
            final int heldOutStart = foldBoundary(total, i2, i);
            final int heldOutEnd = foldBoundary(total, i2 + 1, i);
            final ListIterator<RTDP> before = points.listIterator();
            while (before.nextIndex() < heldOutStart) {
                sufficientStatistic.addDataPoint(before.next(), 0);
            }
            final ListIterator<RTDP> after = points.listIterator(heldOutEnd);
            while (after.hasNext()) {
                sufficientStatistic.addDataPoint(after.next(), 0);
            }
            sufficientStatistic.endOfData();
        } catch (LearningException e) {
            throw new Inconsistency(e.toString());
        }
    }

    /**
     * Re-initializes {@code sufficientStatistic} and feeds it the points of slot 0 that make up
     * fold {@code i2} of {@code i} folds.
     *
     * @param sufficientStatistic target statistic
     * @param i                   number of folds
     * @param i2                  zero-based fold to extract
     * @throws Inconsistency if the target raises a {@link LearningException}
     */
    public void GetTest(SufficientStatistic sufficientStatistic, int i, int i2) {
        try {
            sufficientStatistic.initialize(null);
            final int total = this.data[0].size();
            final int foldEnd = foldBoundary(total, i2 + 1, i);
            final ListIterator<RTDP> cursor = ((RTDPSet) this.data[0]).listIterator(foldBoundary(total, i2, i));
            while (cursor.nextIndex() < foldEnd) {
                sufficientStatistic.addDataPoint(cursor.next(), 0);
            }
            sufficientStatistic.endOfData();
        } catch (LearningException e) {
            throw new Inconsistency(e.toString());
        }
    }

    /**
     * Replaces missing values using the other variables assigned to the same cluster.
     *
     * <p>For each point, the standardized values of the point's non-missing variables in a
     * cluster are averaged; a missing variable of that cluster is then set to its own mean
     * plus that average times its own standard deviation. A cluster with no observed variable
     * in the point yields an average of 0, i.e. the variable's mean. Variables mapped to
     * cluster {@code -1} are not filled in and are recorded in {@code incomplete_vars}.
     *
     * @param iArr cluster assignment per variable; {@code -1} means unassigned
     * @throws NotImplementedYet if more than one possible observation condition exists
     */
    public void fill_in_missing(int[] iArr) {
        if (this.possible_conditions.length > 1) {
            throw new NotImplementedYet("Clustering with observation conditions");
        }
        calculate_mean_and_std();
        this.incomplete_vars.clear();
        for (RTDP point : Data()) {
            // Per-point cache: cluster id -> mean standardized value of its observed variables.
            final HashMap<Integer, Float> clusterScore = new HashMap<Integer, Float>();
            final Iterator<Integer> missing = point.missing_values.iterator();
            while (missing.hasNext()) {
                final int variable = missing.next().intValue();
                final int cluster = iArr[variable];
                if (cluster == -1) {
                    this.incomplete_vars.add(Integer.valueOf(variable));
                    continue;
                }
                Float score = clusterScore.get(Integer.valueOf(cluster));
                if (score == null) {
                    final NormalGammaStat.Stat stat = new NormalGammaStat.Stat();
                    for (int other = 0; other < point.val.length; other++) {
                        if (iArr[other] == cluster && !point.missing_values.contains(Integer.valueOf(other))) {
                            // A constant variable would give 0/0 = NaN and poison the average;
                            // score it as 0, the same convention Normalize() uses.
                            final float standardized = this.stds[other] == 0.0d
                                    ? 0.0f
                                    : (point.val[other] - this.means[other]) / ((float) this.stds[other]);
                            stat.add(standardized);
                        }
                    }
                    score = Float.valueOf(stat.count == 0 ? 0.0f : stat.mean());
                    clusterScore.put(Integer.valueOf(cluster), score);
                }
                point.set(variable, this.means[variable] + (score.floatValue() * ((float) this.stds[variable])));
            }
        }
    }

    /**
     * Normalizes all stored points in place.
     *
     * <p>Every variable is centred on its mean and divided by its standard deviation; a
     * variable with zero standard deviation becomes 0 and is recorded in {@link #fixed_vars}.
     * In {@link #LENGTH} mode the result is additionally divided by sqrt(number of points).
     * All data slots are emptied and the normalized points are re-added through
     * {@code addDataPoint}, which also rebuilds {@link #point_indices}.
     *
     * @param i requested mode; a call with the already active mode is a no-op
     * @throws NotImplementedYet if a different normalization is already active
     */
    public void Normalize(int i) {
        if (this.normalize == i) {
            return;
        }
        if (this.normalize != NONE) {
            throw new NotImplementedYet("undoing normalization");
        }
        this.normalize = i;
        calculate_mean_and_std();
        this.incomplete_vars.clear();
        this.point_indices.clear();
        final RTDPSet originals = (RTDPSet) this.data[0];
        for (int slot = 0; slot < this.data.length; slot++) {
            this.data[slot] = new RTDPSet();
        }
        final double lengthDivisor = Math.sqrt(originals.size());
        for (int p = 0; p < originals.size(); p++) {
            final RTDP source = originals.get(p);
            final float[] scaled = new float[source.val.length];
            for (int v = 0; v < source.val.length; v++) {
                scaled[v] = this.stds[v] == 0.0d ? 0.0f : (source.childValue(v) - this.means[v]) / ((float) this.stds[v]);
                if (this.normalize == LENGTH) {
                    scaled[v] = (float) (scaled[v] / lengthDivisor);
                }
            }
            addDataPoint(scaled, source.condition);
        }
        this.fixed_vars = new HashSet<Integer>();
        for (int v = 0; v < this.stds.length; v++) {
            if (this.stds[v] == 0.0d) {
                this.fixed_vars.add(Integer.valueOf(v));
            }
        }
    }

    /**
     * Computes the per-variable mean and standard deviation over slot 0, skipping values a
     * point lists as missing. Runs only once: later calls return immediately because the
     * result is cached. Requires at least one stored point (the first point fixes the
     * number of variables).
     */
    private void calculate_mean_and_std() {
        if (this.means != null) {
            return;
        }
        final int variableCount = ((RTDP) this.data[0].get(0)).val.length;
        this.means = new float[variableCount];
        this.stds = new double[variableCount];
        for (int v = 0; v < this.means.length; v++) {
            final NormalGammaStat.Stat stat = new NormalGammaStat.Stat();
            for (RTDP point : (RTDPSet) this.data[0]) {
                if (!point.missing_values.contains(Integer.valueOf(v))) {
                    stat.add(point.val[v]);
                }
            }
            this.means[v] = stat.mean();
            this.stds[v] = stat.std();
        }
    }

    /** @return the variables found constant by {@link #Normalize(int)}, or {@code null} if it has not run */
    public Set<Integer> fixedVars() {
        return this.fixed_vars;
    }

    /**
     * @param i variable index
     * @return the offset removed from the variable by normalization; 0 when normalization is off
     */
    public float getMean(int i) {
        return this.normalize == NONE ? 0.0f : this.means[i];
    }

    /**
     * @param i variable index
     * @return the factor the variable was divided by during normalization: 1 when normalization
     *         is off, the standard deviation otherwise, times sqrt(number of points) in
     *         {@link #LENGTH} mode
     */
    public double getMultiplier(int i) {
        if (this.normalize == NONE) {
            return 1.0d;
        }
        double factor = this.stds[i];
        if (this.normalize == LENGTH) {
            factor *= Math.sqrt(this.data[0].size());
        }
        return factor;
    }

    /**
     * @param collection variable indices
     * @return the average of {@link #getMean(int)} over the indices; 0 for an empty collection
     */
    public float getMean(Collection<Integer> collection) {
        if (collection.isEmpty()) {
            return 0.0f;
        }
        float total = 0.0f;
        for (Integer variable : collection) {
            total += getMean(variable.intValue());
        }
        return total / collection.size();
    }

    /**
     * @param collection variable indices
     * @return the average of {@link #getMultiplier(int)} over the indices; 1 for an empty collection
     */
    public double getMultiplier(Collection<Integer> collection) {
        if (collection.isEmpty()) {
            return 1.0d;
        }
        double total = 0.0d;
        for (Integer variable : collection) {
            total += getMultiplier(variable.intValue());
        }
        return total / collection.size();
    }

    /** @return the average of {@link #getMean(int)} over all variables */
    public float getOverallMean() {
        float total = 0.0f;
        for (int v = 0; v < this.vars.length; v++) {
            total += getMean(v);
        }
        return total / this.vars.length;
    }

    /** @return the average of {@link #getMultiplier(int)} over all variables */
    public double getOverallMultiplier() {
        double total = 0.0d;
        for (int v = 0; v < this.vars.length; v++) {
            total += getMultiplier(v);
        }
        return total / this.vars.length;
    }

    /** Returns the point set for a condition index; a negative index selects slot 0. */
    private RTDPSet selectConditionData(int condition) {
        return (RTDPSet) this.data[condition < 0 ? 0 : this.same_var_conditions_index[condition]];
    }

    /** Returns the first position of fold {@code fold} out of {@code folds}, computed in long to avoid int overflow. */
    private static int foldBoundary(int size, int fold, int folds) {
        return (int) (((long) size * fold) / folds);
    }
}
