package biolearn.Applications;

import biolearn.Applications.FlowCytometry.LearnStructure;
import biolearn.Applications.FlowCytometry.choicePanel;
import biolearn.BayesianNetwork.Network;
import biolearn.GraphicalModel.CPDs.LinearGaussian;
import biolearn.GraphicalModel.CPDs.Tabular;
import biolearn.GraphicalModel.ContinuousRandomVariable;
import biolearn.GraphicalModel.DiscreteRandomVariable;
import biolearn.GraphicalModel.DiscretizedRandomVariable;
import biolearn.GraphicalModel.Learning.InputData.MultipleData;
import biolearn.GraphicalModel.Learning.Structure.Scores.BDe;
import biolearn.GraphicalModel.Learning.Structure.Scores.MeanSquareError;
import biolearn.GraphicalModel.Learning.SuffStat.JointCounts;
import biolearn.GraphicalModel.Learning.SuffStat.Util.DataPoint;
import biolearn.GraphicalModel.Learning.SuffStat.WholeData;
import biolearn.NotImplementedYet;
import biolearn.PRM.PRMInstance;
import biolearn.PRM.PRMSchema;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintStream;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Vector;

/* loaded from: input_file:biolearn/Applications/AverageModels.class */
/**
 * Command-line entry point that builds a single "average" network from a spec file and the
 * confidence text stored next to it ({@code <spec>.confidences}). It writes the result to
 * {@code <spec>.average} and, when test data is configured, prints a test-data score.
 *
 * <p>Configuration ({@code data}, {@code test_data}, {@code scoring_function},
 * {@code sampleSize}, {@code with_replacement}, {@code confidenceThreshold}, ...) is static
 * state inherited from {@link BiolearnApplication}. NOTE(review): presumably this state is
 * populated by {@code parseSpecFile} — confirm.
 */
public class AverageModels extends BiolearnApplication {

    /** Value assigned to {@code confidenceThreshold} when the spec leaves it non-positive. */
    private static final int DEFAULT_CONFIDENCE_THRESHOLD = 50;

    /**
     * Runs the averaging pipeline for one spec file.
     *
     * <ol>
     *   <li>Parses the spec file and reads {@code <spec>.confidences} in full.</li>
     *   <li>Sets {@code confidenceThreshold} to 50 when it is not positive.</li>
     *   <li>Builds a PRM with one {@code ActivatedProtein} object per relevant variable
     *       selected in {@link LearnStructure#choice_panel}, and induces a {@link Network}.</li>
     *   <li>Applies any configured discretizations, then hands the confidence text and
     *       threshold to {@code Network.setDescription}.</li>
     *   <li>If the scoring function uses {@link LinearGaussian} CPDs, or test data is present,
     *       re-fits every node's CPD from training-data sufficient statistics.</li>
     *   <li>Writes all records plus the network's {@code toString()} to
     *       {@code <spec>.average}.</li>
     *   <li>If test data is present, prints the summed {@code logPDF} of the test data.</li>
     * </ol>
     *
     * <p>Any failure is reported as a stack trace on standard error and is not rethrown.
     *
     * @param strArr command-line arguments. {@code strArr[0]} is the spec-file path; it is also
     *     the prefix for the {@code .confidences} input file and the {@code .average} output file.
     */
    public static void main(String[] strArr) {
        if (strArr.length < 1) {
            System.err.println("Usage: AverageModels <spec-file>");
            return;
        }
        String specFile = strArr[0];
        try {
            parseSpecFile(specFile);
            String confidences = readWholeFile(new File(specFile + ".confidences"));
            if (confidenceThreshold <= 0) {
                confidenceThreshold = DEFAULT_CONFIDENCE_THRESHOLD;
            }

            // --- Build the PRM schema and induce the network over the selected variables. ---
            Vector classNames = new Vector();
            classNames.add("ActivatedProtein");
            schema = new PRMSchema("biolearn.nolandata", classNames, new Vector());
            LearnStructure.choice_panel = new choicePanel(data.VarNames(), data.PerFileConstants());
            LearnStructure.setDiscretizations(false);
            if (LearnStructure.var_class == ContinuousRandomVariable.class) {
                schema.addRelevantAttribute("ActivatedProtein.ContinuousVal");
            } else if (LearnStructure.var_class == DiscreteRandomVariable.class) {
                schema.addRelevantAttribute("ActivatedProtein.DiscreteVal");
            } else if (LearnStructure.var_class == DiscretizedRandomVariable.class) {
                schema.addRelevantAttribute("ActivatedProtein.DiscretizedVal");
            }
            prm = new PRMInstance(schema);
            Iterator<String> varIt = LearnStructure.choice_panel.relevantVars().iterator();
            while (varIt.hasNext()) {
                prm.addObject("ActivatedProtein", varIt.next());
            }
            Network network = prm.InducedNetwork();

            // Node i is assumed to correspond to discretizations[i]; null entries are skipped.
            if (LearnStructure.var_class == DiscretizedRandomVariable.class) {
                for (int i = 0; i < LearnStructure.discretizations.length; i++) {
                    if (LearnStructure.discretizations[i] != null) {
                        ((DiscretizedRandomVariable) network.Nodes().get(i))
                                .setDiscretization(LearnStructure.discretizations[i].procedure);
                    }
                }
            }
            // NOTE(review): setDescription presumably builds the averaged structure from the
            // confidence text, using the threshold as a cutoff — confirm in Network.
            network.setDescription(confidences, confidenceThreshold);

            // --- Optionally re-fit CPDs from training-data sufficient statistics. ---
            Class<?> cpdType = scoring_function.CPDType();
            if (cpdType == LinearGaussian.class || test_data != null) {
                stat = scoring_function.expectedSufficientStatistic();
                stat.setModel(network);
                if (sampleSize > 0.0f) {
                    // sampleSize > 1 is an absolute count; otherwise a fraction of the data
                    // points, divided evenly across datasets for MultipleData.
                    int sampleCount = Math.round(sampleSize > 1.0f
                            ? sampleSize
                            : (sampleSize * data.numDataPoints())
                                    / (data instanceof MultipleData ? ((MultipleData) data).numDatasets() : 1));
                    data.Sample(stat, sampleCount, with_replacement);
                } else {
                    data.GetAll(stat);
                }
                float phantomSize = scoring_function instanceof BDe ? ((BDe) scoring_function).phantomDataSize : 0.0f;
                for (int node = 0; node < network.Nodes().size(); node++) {
                    if (cpdType == LinearGaussian.class) {
                        Integer[] target = {Integer.valueOf(node)};
                        Vector parents = new Vector();
                        Iterator<Integer> parentIt = network.Structure().getParents(node).iterator();
                        while (parentIt.hasNext()) {
                            parents.add(network.Nodes().get(parentIt.next().intValue()));
                        }
                        network.Nodes().get(node).setCPD(
                                MeanSquareError.linearRegression(Arrays.asList(target), parents, (WholeData) stat));
                    } else if (cpdType == Tabular.class) {
                        network.Nodes().get(node).setCPD(new Tabular(network.Structure().getParents(node),
                                network.Nodes().get(node), (JointCounts) stat, phantomSize));
                    } else if (test_data != null) {
                        throw new NotImplementedYet("Cross-validation with multiple models scored with "
                                + cpdType.getSimpleName() + " CPDs");
                    }
                }
            }

            // --- Write the averaged model; close the stream even if writing fails. ---
            PrintStream out = new PrintStream(new FileOutputStream(specFile + ".average"));
            try {
                WriteAllRecords(out, true);
                out.print(network.toString());
            } finally {
                out.close();
            }

            // --- Score held-out data. ---
            if (test_data != null) {
                // NOTE(review): this reuses the statistic filled with training data above;
                // whether GetAll resets it first is defined elsewhere — confirm.
                test_data.GetAll(stat);
                double score = 0.0d;
                for (int node = 0; node < network.Nodes().size(); node++) {
                    // NOTE(review): MayAddParent(-1) is used as a filter for which nodes
                    // contribute to the score — confirm the intended semantics.
                    if (!network.Nodes().get(node).MayAddParent(-1)) {
                        continue;
                    }
                    Integer[] target = {Integer.valueOf(node)};
                    List<Integer> targetList = Arrays.asList(target);
                    Iterator pointIt = stat.Data().iterator();
                    while (pointIt.hasNext()) {
                        score += network.Nodes().get(node).CPD().logPDF((DataPoint) pointIt.next(), targetList);
                    }
                }
                System.out.println(specFile + ": Test data score " + score);
            }
        } catch (Throwable th) {
            th.printStackTrace();
        }
    }

    /**
     * Reads the entire contents of {@code file} as text, using the platform default charset
     * (the same decoding {@link FileReader} applies).
     *
     * <p>Loops until end of stream, because a single {@code Reader.read} call may return fewer
     * characters than requested. Only the characters actually read are kept, so no padding
     * characters end up in the result when the byte length differs from the char count.
     *
     * @param file file to read
     * @return the full file contents
     * @throws IOException if the file cannot be opened or read
     */
    private static String readWholeFile(File file) throws IOException {
        FileReader reader = new FileReader(file);
        try {
            StringBuilder text = new StringBuilder();
            char[] buffer = new char[8192];
            int count;
            while ((count = reader.read(buffer)) != -1) {
                text.append(buffer, 0, count);
            }
            return text.toString();
        } finally {
            reader.close();
        }
    }
}
