package biolearn.GraphicalModel.Learning.Structure.Algorithms;

import biolearn.Applications.BiolearnApplication;
import biolearn.BayesianNetwork.Inference.BeliefPropagation;
import biolearn.BayesianNetwork.Network;
import biolearn.GraphicalModel.CPDs.Tabular;
import biolearn.GraphicalModel.Learning.LearningException;
import biolearn.GraphicalModel.Learning.Structure.Candidate;
import biolearn.GraphicalModel.Learning.Structure.Scores.BDe;
import biolearn.GraphicalModel.Learning.Structure.SearchAlgorithm;
import biolearn.GraphicalModel.Learning.SuffStat.JointCounts;
import biolearn.GraphicalModel.Learning.SuffStat.Util.ADTree;
import biolearn.GraphicalModel.Learning.SuffStat.Util.DataPoint;
import biolearn.GraphicalModel.Learning.SuffStat.Util.DiscreteDataPoint;
import biolearn.GraphicalModel.Learning.SufficientStatistic;
import biolearn.GraphicalModel.Model;
import biolearn.GraphicalModel.ModelStructure;
import biolearn.GraphicalModel.PDAG;
import biolearn.GraphicalModel.VariableCPD;
import biolearn.Inconsistency;
import biolearn.NotImplementedYet;
import java.io.IOException;
import java.io.PrintStream;
import java.util.Collection;
import java.util.Date;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Vector;

/* loaded from: input_file:biolearn/GraphicalModel/Learning/Structure/Algorithms/InferenceLoop.class */
/**
 * Structure-learning algorithm for data sets with missing values.
 *
 * <p>Every missing entry is first given a default guess. The class then alternates between two steps:
 * (1) learning a structure with a delegate {@link SearchAlgorithm} on the filled data, and
 * (2) re-filling each missing entry with the highest-likelihood value of its belief, computed by
 * belief propagation on the Bayesian network just learned. The loop stops when a re-fill changes no
 * value, or after {@code maxIterations} learning rounds. Only {@link Network} models scored with
 * {@link BDe} are supported.
 */
public class InferenceLoop extends SearchAlgorithm {
    /** Delegate algorithm that learns a structure from the currently filled data. */
    private SearchAlgorithm learning;
    /** Largest fraction of missing entries accepted by {@link #run()}. */
    private float max_missing_values;
    /** Upper bound on the number of learn/fill rounds performed by {@link #run()}. */
    private int max_iterations;
    /**
     * Per data point, a copy of its values whose missing positions hold the value imputed by the
     * most recent fill. Used to count how many imputed values changed between fills.
     * NOTE(review): data points are used as HashMap keys while their values are mutated in place.
     * If DataPoint's equals/hashCode depend on those values, lookups will miss. Confirm that
     * DataPoint uses identity equality.
     */
    private Map<DataPoint, int[]> previous_values;
    /** Constructor arguments, kept verbatim so that {@link #WriteRecord} can reproduce them. */
    private List<String> args_cache;

    /**
     * Creates the loop from {@code key=value} option strings. Keys are matched case-insensitively;
     * unknown keys and strings without a {@code key=} prefix are ignored.
     *
     * <ul>
     *   <li>{@code maxMissingValues}: maximum fraction of missing entries (default 0.02).
     *   <li>{@code maxIterations}: maximum number of learn/fill rounds (default 10).
     *   <li>{@code learning}: name of the delegate algorithm class under
     *       {@code Structure.Algorithms}. This option is required. The delegate is constructed
     *       with the full argument vector.
     * </ul>
     *
     * @param vector option strings
     * @throws Inconsistency if no {@code learning} option is given
     * @throws Exception if a numeric option cannot be parsed or the delegate cannot be constructed
     */
    public InferenceLoop(Vector<String> vector) throws Exception {
        this.learning = null;
        this.max_missing_values = 0.02f;
        this.max_iterations = 10;
        this.args_cache = new Vector<String>(vector);
        for (String arg : vector) {
            int eq = arg.indexOf('=');
            if (eq <= 0) {
                continue; // not a key=value option, or the key is empty
            }
            String key = arg.substring(0, eq);
            String value = arg.substring(eq + 1);
            if (key.equalsIgnoreCase("maxMissingValues")) {
                this.max_missing_values = Float.parseFloat(value);
            } else if (key.equalsIgnoreCase("maxIterations")) {
                this.max_iterations = Integer.parseInt(value);
            } else if (key.equalsIgnoreCase("learning")) {
                this.learning = (SearchAlgorithm) BiolearnApplication.subclassConstructor("Structure.Algorithms", value).newInstance(vector);
            }
        }
        if (this.learning == null) {
            throw new Inconsistency("Learning algorithm for inference loop not specified");
        }
        // These are static fields, so this changes the thresholds for every ADTree in the process.
        // Presumably this disables lazy expansion and caching, because the data values are
        // rewritten in every round. TODO confirm.
        ADTree.lazy_threshold = Integer.MAX_VALUE;
        ADTree.cache_threshold = Integer.MAX_VALUE;
    }

    /**
     * Runs the learn/fill loop and returns the candidates produced by the last delegate run.
     *
     * <p>When this method returns, the data set holds the values imputed by the final fill.
     *
     * @return candidates from the final call to the delegate's {@code run()}
     * @throws NotImplementedYet if the model is not a {@link Network} or the score is not {@link BDe}
     * @throws LearningException if the missing-value fraction exceeds {@code maxMissingValues}, if the
     *     delegate returns no candidate, or if the delegate itself fails
     */
    @Override // biolearn.GraphicalModel.Learning.Structure.SearchAlgorithm
    public Collection<Candidate> run() throws LearningException {
        if (!(this.model instanceof Network)) {
            throw new NotImplementedYet("Inference loop can only be used with Bayesian networks");
        }
        if (!(this.score instanceof BDe)) {
            throw new NotImplementedYet("Inference loop can only be used with BDe scoring");
        }
        if (this.data.missing_values_fraction() > this.max_missing_values) {
            throw new LearningException("Too many missing values; " + (this.data.missing_values_fraction() * 100.0f) + "% of data is missing");
        }
        this.previous_values = new HashMap<DataPoint, int[]>();
        // Initial guess for every missing entry; there is no model yet.
        fill_missing_values(this.data, null);
        Collection<Candidate> result;
        Model learnedModel = null; // network learned in the previous round, if any
        for (int iteration = 1; ; iteration++) {
            this.data.endOfData();
            // The last fill rewrote data values, so any cached scores are reset.
            this.score.resetCache();
            this.learning.initialize(this.model, this.start, this.mod, this.score, this.data, this.constraints, this.choice_test);
            result = this.learning.run();
            Iterator<Candidate> candidates = result.iterator();
            if (!candidates.hasNext()) {
                throw new LearningException("Learning algorithm returned no candidate structure in iteration " + iteration);
            }
            // The first candidate returned by the delegate is taken as the learned structure.
            Candidate best = candidates.next();
            if (BiolearnApplication.debugModules) {
                System.err.print("Iteration " + iteration + " at " + new Date().toString() + ": ");
                if (learnedModel != null) {
                    System.err.println(compare(learnedModel.Structure(), best.structure));
                }
                System.err.println(this.model.toString(best.structure));
            }
            // Mark the missing entries as missing again, so the CPDs below are not estimated from
            // the imputed values. This presumably relies on JointCounts skipping NaN entries; confirm.
            unfill_missing_values(this.data);
            int numVars = this.model.CandidateChildren().size();
            VariableCPD[] cpds = new VariableCPD[numVars];
            for (int v = 0; v < cpds.length; v++) {
                cpds[v] = new Tabular(best.structure.getParents(v), this.model.CandidateChildren().get(v), (JointCounts) this.data, ((BDe) this.score).phantomDataSize);
            }
            learnedModel = this.model.m44clone();
            learnedModel.learnedStructure(best.structure, cpds);
            // Stop once a re-fill changes nothing (converged) or the round budget is used up.
            if (fill_missing_values(this.data, learnedModel) <= 0 || iteration >= this.max_iterations) {
                break;
            }
        }
        return result;
    }

    /**
     * Assigns a value to each missing entry of the data that is NaN-flagged and not non-covered.
     *
     * <p>Without a model, each such entry gets {@code numValues(var) / 2}. With a model, belief
     * propagation is run per data point, using that point's observed entries as evidence, and each
     * missing entry gets the highest-likelihood value of its belief. Afterwards the data's
     * incomplete-variable set is cleared.
     *
     * @param sufficientStatistic the data set; it is cast to {@link JointCounts} whenever a data
     *     point has entries to fill
     * @param model network to run inference on, or {@code null} for the initial guess
     * @return number of filled entries whose value differs from the one recorded for that entry on
     *     the previous call (on the first call, from the point's value at that time)
     */
    private long fill_missing_values(SufficientStatistic sufficientStatistic, Model model) {
        BeliefPropagation inference = model == null ? null : new BeliefPropagation((Network) model);
        long filled = 0;     // missing entries that were assigned a value
        long changed = 0;    // of those, entries whose value differs from the previous fill
        long inferences = 0; // belief-propagation runs performed
        for (DiscreteDataPoint point : sufficientStatistic.all_data_points()) {
            if (point.missing_values.isEmpty()) {
                continue;
            }
            // Skip points whose non-covered set is as large as their missing set. Presumably all
            // of their missing entries are non-covered, so there is nothing to fill.
            if (point.noncovered_vars != null && point.noncovered_vars.size() == point.missing_values.size()) {
                continue;
            }
            if (BiolearnApplication.debug) {
                System.err.println("Data point is " + point);
            }
            if (inference != null) {
                inferences++;
                inference.clear();
                int numVars = model.CandidateChildren().size();
                for (int varIndex = 0; varIndex < numVars; varIndex++) {
                    if (!point.missing_values.contains(Integer.valueOf(varIndex))) {
                        // NOTE(review): evidence is shifted by minValue(varIndex), and by a further
                        // 1 when the weight is below 0.5. The inferred value stored below is not
                        // shifted back. Confirm that both are in the same value space.
                        String evidence = String.valueOf(((JointCounts) sufficientStatistic).minValue(varIndex) + point.values[varIndex] + (((double) point.weight[varIndex]) < 0.5d ? 1 : 0));
                        inference.setEvidence(varIndex, evidence);
                    }
                }
                inference.run();
            }
            int[] previous = this.previous_values.get(point);
            if (previous == null) {
                previous = point.values.clone();
                this.previous_values.put(point, previous);
            }
            for (Integer missing : point.missing_values) {
                int varIndex = missing.intValue();
                if (!point.isNaN[varIndex]) {
                    continue;
                }
                if (point.noncovered_vars != null && point.noncovered_vars.contains(missing)) {
                    continue;
                }
                filled++;
                int value = inference == null
                        ? ((JointCounts) sufficientStatistic).numValues(varIndex) / 2
                        : Math.round(inference.getBeliefs(varIndex).HighestLikelihoodValue(null));
                point.set(varIndex, value);
                if (value != previous[varIndex]) {
                    previous[varIndex] = value;
                    changed++;
                }
            }
            if (BiolearnApplication.debug) {
                System.err.println("Data point became " + point);
            }
        }
        // unfill_missing_values repopulates this set before the next estimation.
        sufficientStatistic.incompleteVars().clear();
        if (BiolearnApplication.debugModules) {
            System.err.println("Finished inference at " + new Date().toString() + ": changed " + changed + " values out of " + filled + " by running " + inferences + " inferences");
        }
        return changed;
    }

    /**
     * Undoes the assignments made by {@code fill_missing_values}. Every missing entry of every data
     * point gets value -1, its NaN flag set and weight NaN, and its variable is added to the data's
     * incomplete-variable set.
     *
     * @param sufficientStatistic the data set to modify in place
     */
    private void unfill_missing_values(SufficientStatistic sufficientStatistic) {
        for (DiscreteDataPoint point : sufficientStatistic.all_data_points()) {
            for (Integer missing : point.missing_values) {
                int varIndex = missing.intValue();
                point.values[varIndex] = -1;
                point.isNaN[varIndex] = true;
                point.weight[varIndex] = Float.NaN;
                sufficientStatistic.incompleteVars().add(missing);
            }
        }
    }

    /**
     * Summarizes the edge differences between two structures over the children of
     * {@code modelStructure}.
     *
     * <p>For each parent p of child c in {@code modelStructure}, the edge is counted as follows:
     * <ul>
     *   <li>Removed, if {@code modelStructure2} has no edge between p and c in either direction.
     *   <li>Direction changed, if {@code modelStructure2} has c→p.
     *   <li>Direction changed, if both structures are {@link PDAG}s and the edge p→c is compelled
     *       in exactly one of them.
     * </ul>
     * A parent in {@code modelStructure2} with no edge to c in either direction in
     * {@code modelStructure} counts as added.
     *
     * @param modelStructure reference structure
     * @param modelStructure2 structure compared against the reference
     * @return {@code "<added> edges added, <removed> removed, <changed> directions changed"}
     */
    public static String compare(ModelStructure modelStructure, ModelStructure modelStructure2) {
        boolean bothPDAGs = modelStructure instanceof PDAG && modelStructure2 instanceof PDAG;
        int added = 0;
        int removed = 0;
        int reversed = 0;
        int numChildren = modelStructure.numChildren();
        for (int child = 0; child < numChildren; child++) {
            Iterator<Integer> oldParents = modelStructure.getParents(child).iterator();
            while (oldParents.hasNext()) {
                int parent = oldParents.next().intValue();
                boolean sameDirection = modelStructure2.hasEdge(parent, child);
                boolean opposite = modelStructure2.hasEdge(child, parent);
                if (!sameDirection && !opposite) {
                    removed++;
                } else if (opposite || (bothPDAGs && ((PDAG) modelStructure).hasCompelledEdge(parent, child) != ((PDAG) modelStructure2).hasCompelledEdge(parent, child))) {
                    reversed++;
                }
            }
            Iterator<Integer> newParents = modelStructure2.getParents(child).iterator();
            while (newParents.hasNext()) {
                int parent = newParents.next().intValue();
                if (!modelStructure.hasEdge(parent, child) && !modelStructure.hasEdge(child, parent)) {
                    added++;
                }
            }
        }
        return added + " edges added, " + removed + " removed, " + reversed + " directions changed";
    }

    /**
     * Writes this component's configuration line. The line holds the record type
     * ({@code InitialStructure} when this instance is the application's initial structure,
     * otherwise {@code Algorithm}), the name {@code InferenceLoop}, and the original constructor
     * arguments.
     *
     * @param printStream destination stream
     * @param z if true, the line is prefixed with {@code "# "}
     */
    @Override // biolearn.BiolearnComponent
    public void WriteRecord(PrintStream printStream, boolean z) throws IOException {
        StringBuilder line = new StringBuilder();
        if (z) {
            line.append("# ");
        }
        line.append(BiolearnApplication.initial_structure == this ? "InitialStructure" : "Algorithm");
        line.append(" InferenceLoop");
        for (String arg : this.args_cache) {
            line.append(' ').append(arg);
        }
        printStream.println(line.toString());
    }
}
