package biolearn.GraphicalModel.Learning;

import biolearn.Applications.BiolearnApplication;
import biolearn.BayesianNetwork.Network;
import biolearn.GraphicalModel.Learning.Structure.Algorithms.MinReg;
import biolearn.GraphicalModel.Learning.SuffStat.Util.DataPoint;
import biolearn.GraphicalModel.Learning.SuffStat.Util.RTDP;
import biolearn.GraphicalModel.Learning.SuffStat.Util.RTDPSet;
import biolearn.GraphicalModel.Model;
import biolearn.GraphicalModel.RandomVariable;
import biolearn.Inconsistency;
import biolearn.NotImplementedYet;
import java.io.IOException;
import java.io.PrintStream;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import java.util.Set;
import java.util.Vector;

/* loaded from: input_file:biolearn/GraphicalModel/Learning/SufficientStatistic.class */
/**
 * Base class for sufficient statistics collected from data points over the
 * variables of a {@link Model}.
 *
 * <p>Data points are stored in the {@link #data} array of lists. List 0 is the
 * one returned by {@link #all_data_points()} / {@link #Data()}; the list used
 * for a particular variable is selected through
 * {@link #same_var_conditions_index}. The {@code data} array itself is never
 * allocated in this class -- presumably concrete subclasses do that (TODO
 * confirm against the subclasses).
 *
 * <p>Typical call order, as used by {@link #fromSample(Model, RTDPSet)}:
 * {@link #setModel(Model)}, {@link #initialize(ObservationCondition[])},
 * repeated {@code addDataPoint(..)}, then {@link #endOfData()}.
 */
public abstract class SufficientStatistic {
    // Partitioning mode constants. They are not referenced inside this class;
    // their meaning is defined by whoever consumes them.
    public static final int NOPARTITIONS = 0;
    public static final int DISTANCEPARTITIONS = 1;
    public static final int SIZEPARTITIONS = 2;

    /**
     * Groups of observation conditions. Entry {@code k} is the set of
     * conditions whose data points are excluded from {@code data[k]};
     * entry 0 is always the empty group.
     */
    protected List<Collection<ObservationCondition>> same_var_conditions;

    /** Model whose nodes define {@link #vars}; set by {@link #setModel(Model)}. */
    public Model model = null;

    /** The model's nodes, in the order returned by {@code model.Nodes()}. */
    protected RandomVariable[] vars = null;

    /** Conditions actually seen by {@code addDataPoint}; null until {@code initialize} is given conditions. */
    protected Collection<ObservationCondition> encountered_conditions = null;

    /** Condition index substituted when a data point is added with condition 0. */
    protected int default_condition = 0;

    /** Stored by {@link #setNoncoveredVars(Set)}; not read inside this class. */
    protected Set<Integer> current_noncovered_vars = null;

    /** For each variable, the index into {@link #same_var_conditions} (and {@link #data}) that applies to it. */
    protected int[] same_var_conditions_index = null;

    /** Cached result of {@link #missing_values_fraction()}; NaN means "not computed yet". */
    protected float missing_val_fraction = Float.NaN;

    /** Union of the {@code missing_values} of every data point added so far. */
    protected Set<Integer> incomplete_vars = null;

    /** Data point lists, indexed like {@link #same_var_conditions}. Not allocated in this class. */
    protected List[] data = null;

    /** Observation conditions indexed by condition number; a null entry means "no condition". */
    protected ObservationCondition[] possible_conditions = new ObservationCondition[1];

    /**
     * Returns the variables taken from the model.
     *
     * @return the internal array (not a copy); null before {@link #setModel(Model)}
     */
    public RandomVariable[] Vars() {
        return this.vars;
    }

    /**
     * Creates a statistic with a single, empty condition group.
     */
    public SufficientStatistic() {
        this.same_var_conditions = new Vector<Collection<ObservationCondition>>();
        this.same_var_conditions.add(new Vector<ObservationCondition>());
    }

    /**
     * Builds the statistic from a complete sample: sets the model, initializes
     * without observation conditions, adds every point of the sample with
     * condition 0 and finally calls {@link #endOfData()}.
     *
     * @param model    the model whose nodes become the variables
     * @param rTDPSet  the sample to read
     * @throws LearningException propagated from {@link #initialize(ObservationCondition[])}
     */
    public void fromSample(Model model, RTDPSet rTDPSet) throws LearningException {
        setModel(model);
        initialize(null);
        for (int i = 0; i < rTDPSet.size(); i++) {
            addDataPoint(rTDPSet.get(i), 0);
        }
        endOfData();
    }

    /**
     * Attaches the model and copies its nodes into {@link #vars}.
     *
     * @param model the model to attach
     * @throws Inconsistency if {@link #compatibleVar(RandomVariable)} rejects a node
     */
    public void setModel(Model model) {
        this.model = model;
        this.vars = new RandomVariable[model.Nodes().size()];
        for (int i = 0; i < this.vars.length; i++) {
            RandomVariable var = (RandomVariable) model.Nodes().get(i);
            this.vars[i] = var;
            if (!compatibleVar(var)) {
                throw new Inconsistency(String.valueOf(var.Name()) + " can't be handled by " + getClass().getSimpleName());
            }
        }
        // Only allocated once; a fresh int[] is already all zeros (every
        // variable mapped to condition group 0).
        if (this.same_var_conditions_index == null) {
            this.same_var_conditions_index = new int[this.vars.length];
        }
    }

    /**
     * Wraps the raw values in an {@link RTDP} and adds it.
     *
     * @param fArr the values of the data point
     * @param i    the condition index (0 selects the default condition)
     */
    public void addDataPoint(float[] fArr, int i) {
        addDataPoint(new RTDP(fArr), i);
    }

    /**
     * Adds a data point observed under the given condition.
     *
     * <p>A condition index of 0 is replaced by {@link #default_condition}. If
     * the resulting condition is non-null it is recorded as encountered and
     * applied to each of its compelled variables: for an activity inhibition
     * the value returned by {@code dataPoint.setActivityInhibition(..)} is
     * appended to that variable's data list, and for a zero intervention the
     * variable is set to its minimum value in the data point itself. The data
     * point is then appended to every list whose condition group does not
     * contain the condition.
     *
     * <p>Requires {@link #initialize(ObservationCondition[])} to have been
     * called and {@link #data} to be allocated.
     *
     * @param dataPoint the data point; its {@code condition} field is overwritten
     * @param i         the condition index (0 selects the default condition)
     */
    public void addDataPoint(DataPoint dataPoint, int i) {
        if (i == 0) {
            i = this.default_condition;
        }
        dataPoint.condition = i;
        ObservationCondition condition = this.possible_conditions[i];
        if (condition != null) {
            if (this.encountered_conditions != null) {
                this.encountered_conditions.add(condition);
            }
            if (i != 0 && BiolearnApplication.debug) {
                System.err.println("Observation condition " + condition);
            }
            Iterator<Integer> compelled = condition.Compelled().iterator();
            while (compelled.hasNext()) {
                int varIndex = compelled.next().intValue();
                if (BiolearnApplication.debug) {
                    System.err.println("Applying condition on " + varIndex + ' ' + this.model.Nodes().get(varIndex) + " inhibitor " + condition.isActivityInhibition(varIndex) + " zeroing " + condition.isZeroIntervention(varIndex));
                }
                if (condition.isActivityInhibition(varIndex)) {
                    float min = this.vars[varIndex].minValue().floatValue();
                    this.data[this.same_var_conditions_index[varIndex]].add(dataPoint.setActivityInhibition(varIndex, min));
                } else if (condition.isZeroIntervention(varIndex)) {
                    dataPoint.set(varIndex, this.vars[varIndex].minValue().floatValue());
                }
            }
        }
        this.incomplete_vars.addAll(dataPoint.missing_values);
        for (int group = 0; group < this.same_var_conditions.size(); group++) {
            if (!this.same_var_conditions.get(group).contains(condition)) {
                this.data[group].add(dataPoint);
            }
        }
    }

    /**
     * Sets the condition index used when a data point is added with condition 0.
     *
     * @param i the new default condition index
     */
    public void setDefaultCondition(int i) {
        this.default_condition = i;
    }

    /**
     * Stores the given set in {@link #current_noncovered_vars}.
     *
     * @param set the set to store (kept by reference)
     */
    public void setNoncoveredVars(Set<Integer> set) {
        this.current_noncovered_vars = set;
    }

    /**
     * Returns the data as one vector per data point, each holding one value
     * per variable.
     *
     * @return a new list of rows; values within 1e-6 of an integer are boxed
     *         as {@link Integer}, all others as {@link Float}
     */
    public List<List<Number>> DataPointVectors() {
        Vector<List<Number>> rows = new Vector<List<Number>>();
        for (Object element : all_data_points()) {
            DataPoint dataPoint = (DataPoint) element;
            Vector<Number> row = new Vector<Number>();
            for (int v = 0; v < this.vars.length; v++) {
                row.add(boxValue(dataPoint.childValue(v)));
            }
            rows.add(row);
        }
        return rows;
    }

    /**
     * Returns the data as one vector per variable, each holding one value per
     * data point (the transpose of {@link #DataPointVectors()}).
     *
     * @return a new list of columns; values within 1e-6 of an integer are
     *         boxed as {@link Integer}, all others as {@link Float}
     */
    public List<List<Number>> VarVectors() {
        Vector<List<Number>> columns = new Vector<List<Number>>();
        for (int v = 0; v < this.vars.length; v++) {
            Vector<Number> column = new Vector<Number>();
            for (Object element : all_data_points()) {
                column.add(boxValue(((DataPoint) element).childValue(v)));
            }
            columns.add(column);
        }
        return columns;
    }

    /** Boxes a value as an Integer when it is within 1e-6 of its rounding, otherwise as a Float. */
    private static Number boxValue(float value) {
        int rounded = Math.round(value);
        if (Math.abs(value - rounded) < 1.0E-6d) {
            return Integer.valueOf(rounded);
        }
        return Float.valueOf(value);
    }

    /**
     * Returns the indices collected from the {@code missing_values} of every
     * data point added so far.
     *
     * @return the internal set (not a copy); null before {@code initialize}
     */
    public Set<Integer> incompleteVars() {
        return this.incomplete_vars;
    }

    /**
     * Returns the fraction of missing values among the covered cells of the
     * data, computing and caching it on first use.
     *
     * <p>For every data point the number of non-covered variables is removed
     * both from the total cell count and from that point's missing-value
     * count. NOTE(review): the decompiled original referenced an undefined
     * temporary in the second subtraction; it is reconstructed here as the
     * non-covered count, which assumes non-covered variables are also listed
     * in {@code missing_values} -- confirm against {@code DataPoint}.
     *
     * <p>While the result is NaN (e.g. no covered cells at all) it is not
     * cached and is recomputed on every call.
     *
     * @return missing cells divided by covered cells
     */
    public float missing_values_fraction() {
        if (Float.isNaN(this.missing_val_fraction)) {
            // Widen before multiplying so large data sets cannot overflow int.
            long coveredCells = (long) this.vars.length * numDataPoints();
            long missingCells = 0;
            for (Object element : all_data_points()) {
                DataPoint dataPoint = (DataPoint) element;
                int noncovered = dataPoint.noncovered_vars == null ? 0 : dataPoint.noncovered_vars.size();
                coveredCells -= noncovered;
                missingCells += dataPoint.missing_values.size() - noncovered;
            }
            this.missing_val_fraction = ((float) missingCells) / ((float) coveredCells);
        }
        return this.missing_val_fraction;
    }

    /**
     * Returns the list of data points stored for condition group 0.
     *
     * @return {@code data[0]} (the internal list, not a copy)
     */
    public List all_data_points() {
        return this.data[0];
    }

    /**
     * Resets the incomplete-variable set and, when conditions are given,
     * installs them and rebuilds the condition groups.
     *
     * <p>With a null argument only {@link #incomplete_vars} is reset. Otherwise
     * the array becomes {@link #possible_conditions}, each non-null condition
     * is given the variables, and (if {@link #setModel(Model)} has already
     * allocated {@link #same_var_conditions_index}) variables are grouped by
     * the set of conditions, from index 1 upwards, that compel them.
     *
     * @param observationConditionArr conditions indexed by condition number, or null for none
     * @throws LearningException declared for subclasses; not thrown by this implementation
     * @throws NotImplementedYet if more than one condition group results and the
     *         model is not a {@link Network}, or the active algorithm is {@link MinReg}
     */
    public void initialize(ObservationCondition[] observationConditionArr) throws LearningException {
        this.incomplete_vars = new HashSet<Integer>();
        if (observationConditionArr == null) {
            return;
        }
        this.possible_conditions = observationConditionArr;
        for (int c = 0; c < observationConditionArr.length; c++) {
            if (observationConditionArr[c] != null) {
                observationConditionArr[c].setVars(this.vars);
            }
        }
        this.encountered_conditions = new HashSet<ObservationCondition>();
        if (this.same_var_conditions_index == null) {
            return;
        }

        // For each variable, collect the conditions (index >= 1) that compel it.
        Vector<Collection<ObservationCondition>> conditionsPerVar = new Vector<Collection<ObservationCondition>>();
        for (int v = 0; v < this.vars.length; v++) {
            conditionsPerVar.add(new Vector<ObservationCondition>());
        }
        for (int c = 1; c < observationConditionArr.length; c++) {
            ObservationCondition condition = observationConditionArr[c];
            if (condition != null) {
                Iterator<Integer> compelled = condition.Compelled().iterator();
                while (compelled.hasNext()) {
                    conditionsPerVar.get(compelled.next().intValue()).add(condition);
                }
            }
        }

        // Variables with equal condition lists share one group; group 0 is the
        // empty list, so unaffected variables keep index 0.
        Arrays.fill(this.same_var_conditions_index, 0);
        this.same_var_conditions = new Vector<Collection<ObservationCondition>>();
        this.same_var_conditions.add(new Vector<ObservationCondition>());
        for (int v = 0; v < this.vars.length; v++) {
            Collection<ObservationCondition> group = conditionsPerVar.get(v);
            int groupIndex = this.same_var_conditions.indexOf(group);
            if (groupIndex < 0) {
                groupIndex = this.same_var_conditions.size();
                this.same_var_conditions.add(group);
            }
            this.same_var_conditions_index[v] = groupIndex;
        }

        if (this.same_var_conditions.size() > 1) {
            if (!(this.model instanceof Network)) {
                throw new NotImplementedYet("Observation conditions with " + this.model.getClass().getName());
            }
            if (BiolearnApplication.algorithm instanceof MinReg) {
                throw new NotImplementedYet("Observation conditions with MinReg");
            }
        }
    }

    /**
     * Hook called after the last data point has been added. Does nothing here.
     */
    public void endOfData() {
    }

    /**
     * Tells whether this statistic can handle the given variable.
     *
     * @param randomVariable the variable to check
     * @return true if the variable is supported
     */
    public abstract boolean compatibleVar(RandomVariable randomVariable);

    /**
     * Returns the number of data points in condition group 0.
     *
     * @return the size of {@code data[0]}
     */
    public int numDataPoints() {
        return this.data[0].size();
    }

    /**
     * Returns the number of data points in the list that applies to a variable.
     *
     * @param i the variable index
     * @return the size of that variable's data list
     */
    public int numDataPoints(int i) {
        return this.data[this.same_var_conditions_index[i]].size();
    }

    /**
     * Returns the list of data points stored for condition group 0.
     *
     * @return {@code data[0]} (the internal list, not a copy)
     */
    public List Data() {
        return this.data[0];
    }

    /**
     * Returns the data list that applies to a variable.
     *
     * @param i the variable index
     * @return the internal list selected by {@code same_var_conditions_index[i]}
     */
    public List Data(int i) {
        return this.data[this.same_var_conditions_index[i]];
    }

    /**
     * Returns the union of the compelled variable indices of every condition
     * encountered so far.
     *
     * @return a new collection; empty when no conditions were installed or seen
     */
    public Collection<Integer> Compelled() {
        HashSet<Integer> compelled = new HashSet<Integer>();
        if (this.encountered_conditions != null) {
            for (ObservationCondition condition : this.encountered_conditions) {
                compelled.addAll(condition.Compelled());
            }
        }
        return compelled;
    }

    /**
     * Prints every data list, each preceded by its condition group and
     * followed by an empty line.
     *
     * @param printStream the stream to write to
     * @throws IOException declared for callers/subclasses; not thrown by this implementation
     */
    public void dump(PrintStream printStream) throws IOException {
        for (int group = 0; group < this.data.length; group++) {
            printStream.println("For condition " + this.same_var_conditions.get(group));
            for (Object dataPoint : this.data[group]) {
                printStream.println(dataPoint);
            }
            printStream.println();
        }
    }
}
