package biolearn.ModuleNetwork.Learning;

import biolearn.Applications.ExpandModules;
import biolearn.GraphicalModel.CPDs.RegressionTree;
import biolearn.GraphicalModel.Learning.LearningException;
import biolearn.GraphicalModel.Learning.Structure.Candidate;
import biolearn.GraphicalModel.Learning.Structure.Constraint;
import biolearn.GraphicalModel.Learning.Structure.Scores.NormalGamma;
import biolearn.GraphicalModel.Learning.Structure.ScoringFunction;
import biolearn.GraphicalModel.Learning.SuffStat.WholeData;
import biolearn.GraphicalModel.Learning.SufficientStatistic;
import biolearn.GraphicalModel.VariableCPD;
import biolearn.Inconsistency;
import biolearn.ModuleNetwork.ModuleAssignment;
import biolearn.ModuleNetwork.Network;
import biolearn.NotImplementedYet;
import java.io.IOException;
import java.io.PrintStream;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.Vector;

/* loaded from: input_file:biolearn/ModuleNetwork/Learning/SingleModuleRefinement.class */
/**
 * Assignment algorithm that refines a single module: it can narrow the module
 * (via {@code ModuleAssignment.filter}) and/or expand it (via
 * {@code ExpandModules.compute_new_members}).
 *
 * <p>Configured from {@code key=value} arguments whose keys are matched
 * case-insensitively. Several keys write static fields of {@code Search}
 * rather than fields of this instance, so constructing an instance has global
 * side effects.
 */
public class SingleModuleRefinement implements AssignmentAlgorithm {
    // Recognized argument prefixes (lower-case, including the '=' separator).
    private static final String MAX_ITERATIONS = "maxiterations=";
    private static final String SCORE_CUTOFF = "scorecutoff=";
    private static final String LIKELIHOOD_CUTOFF = "likelihoodcutoff=";
    private static final String PRUNING_LEAVE_OUT_FRACTION = "pruningleaveoutfraction=";
    private static final String IMPROVEMENT_DROP_THRESHOLD = "improvementdropthreshold=";
    private static final String SLOPE_RATIO_THRESHOLD = "sloperatiothreshold=";
    private static final String CORRELATION_THRESHOLD = "correlationthreshold=";
    private static final String EXPANSION_THRESHOLD = "expansionthreshold=";
    private static final String RELEVANT_TREE_LEVELS = "relevanttreelevels=";

    /** Score cutoff passed to {@code ModuleAssignment.filter}; filtering is enabled only if this or the likelihood cutoff is positive. */
    float score_worsening_cutoff;
    /** Likelihood cutoff passed to {@code ModuleAssignment.filter}; filtering is enabled only if this or the score cutoff is positive. */
    float likelihood_improvement_cutoff;
    /** One minus the user-supplied {@code expansionthreshold}; expansion runs only when this is positive. */
    float expansion_threshold;
    /** Number of regression-tree levels compared by {@link #regulatory_program_stop_condition}. */
    int relevant_levels;
    /** The constructor arguments exactly as given (original case), echoed by {@link #WriteRecord}. */
    private final List<String> args_cache;

    /**
     * Parses the configuration arguments.
     *
     * <p>Recognized keys (case-insensitive): {@code maxIterations},
     * {@code scoreCutoff}, {@code likelihoodCutoff},
     * {@code pruningLeaveOutFraction}, {@code improvementDropThreshold},
     * {@code slopeRatioThreshold} (values below 1 are replaced by their
     * reciprocal), {@code correlationThreshold}, {@code expansionThreshold}
     * (stored as {@code 1 - value}) and {@code relevantTreeLevels} (default 3).
     * The four pruning-related keys also switch on
     * {@code Search.final_tree_pruning}. {@code Search.stop_threshold} is always
     * reset to -1. Unrecognized arguments are ignored.
     *
     * @param vector the {@code key=value} arguments
     * @throws Inconsistency if none of the score cutoff, likelihood cutoff or
     *         expansion threshold ends up positive
     * @throws NumberFormatException if a recognized key has a malformed value
     */
    public SingleModuleRefinement(Vector<String> vector) {
        this.score_worsening_cutoff = -1.0f;
        this.likelihood_improvement_cutoff = -1.0f;
        this.expansion_threshold = -1.0f;
        this.relevant_levels = 3;
        this.args_cache = new Vector<>(vector);
        Search.stop_threshold = -1.0f;
        for (String rawArg : vector) {
            // Locale.ROOT: under the default locale, a Turkish locale lower-cases
            // 'I' to dotless 'ı', so e.g. "MaxIterations=5" would silently not match.
            String arg = rawArg.toLowerCase(Locale.ROOT);
            if (arg.startsWith(MAX_ITERATIONS)) {
                Search.max_iterations = Integer.parseInt(valueAfter(arg, MAX_ITERATIONS));
            } else if (arg.startsWith(SCORE_CUTOFF)) {
                this.score_worsening_cutoff = Float.parseFloat(valueAfter(arg, SCORE_CUTOFF));
            } else if (arg.startsWith(LIKELIHOOD_CUTOFF)) {
                this.likelihood_improvement_cutoff = Float.parseFloat(valueAfter(arg, LIKELIHOOD_CUTOFF));
            } else if (arg.startsWith(PRUNING_LEAVE_OUT_FRACTION)) {
                Search.leave_out_fraction = Float.parseFloat(valueAfter(arg, PRUNING_LEAVE_OUT_FRACTION));
                Search.final_tree_pruning = true;
            } else if (arg.startsWith(IMPROVEMENT_DROP_THRESHOLD)) {
                Search.improvement_drop_threshold = Float.parseFloat(valueAfter(arg, IMPROVEMENT_DROP_THRESHOLD));
                Search.final_tree_pruning = true;
            } else if (arg.startsWith(SLOPE_RATIO_THRESHOLD)) {
                Search.slope_ratio_threshold = Float.parseFloat(valueAfter(arg, SLOPE_RATIO_THRESHOLD));
                // Normalize so the stored ratio is always >= 1.
                if (Search.slope_ratio_threshold < 1.0f) {
                    Search.slope_ratio_threshold = 1.0f / Search.slope_ratio_threshold;
                }
                Search.final_tree_pruning = true;
            } else if (arg.startsWith(CORRELATION_THRESHOLD)) {
                Search.split_side_correlation_threshold = Float.parseFloat(valueAfter(arg, CORRELATION_THRESHOLD));
                Search.final_tree_pruning = true;
            } else if (arg.startsWith(EXPANSION_THRESHOLD)) {
                this.expansion_threshold = 1.0f - Float.parseFloat(valueAfter(arg, EXPANSION_THRESHOLD));
            } else if (arg.startsWith(RELEVANT_TREE_LEVELS)) {
                this.relevant_levels = Integer.parseInt(valueAfter(arg, RELEVANT_TREE_LEVELS));
            }
        }
        if (this.score_worsening_cutoff <= 0.0f && this.likelihood_improvement_cutoff <= 0.0f && this.expansion_threshold <= 0.0f) {
            throw new Inconsistency("SingleModuleRefinement requires specifying parameters for module narrowing and/or expansion");
        }
    }

    /** Returns the text of {@code arg} after {@code prefix}; the caller has already checked {@code startsWith}. */
    private static String valueAfter(String arg, String prefix) {
        return arg.substring(prefix.length());
    }

    /**
     * Narrows and/or expands module 0 of the first candidate's structure.
     *
     * <p>Filtering runs when either cutoff is positive. Expansion runs when
     * {@link #expansion_threshold} is positive; each new member it returns is
     * reassigned to module 0.
     *
     * @return the number of constituents removed from module 0 by filtering,
     *         plus the number of new members found by expansion
     * @throws NotImplementedYet if the scoring function is not
     *         {@code NormalGamma} or the structure does not have exactly one
     *         module
     * @throws LearningException as declared by the interface
     */
    @Override
    public int newAssignment(Network network, Candidate[] candidateArr, SufficientStatistic sufficientStatistic, ScoringFunction scoringFunction, Constraint[] constraintArr) throws LearningException {
        if (!(scoringFunction instanceof NormalGamma)) {
            throw new NotImplementedYet("Single-module refinement for " + scoringFunction.getClass().getSimpleName() + " score");
        }
        Candidate candidate = candidateArr[0];
        if (candidate.structure.numChildren() != 1) {
            throw new NotImplementedYet("Single-module refinement for more than one module");
        }
        int sizeBefore = candidate.structure.constituents(0).size();
        if (this.score_worsening_cutoff > 0.0f || this.likelihood_improvement_cutoff > 0.0f) {
            // NOTE(review): assumes the statistic is a WholeData instance; any other type fails with ClassCastException.
            ((ModuleAssignment) candidate.structure).filter(scoringFunction, (WholeData) sufficientStatistic, sufficientStatistic, this.score_worsening_cutoff, this.likelihood_improvement_cutoff);
        }
        int changes = sizeBefore - candidate.structure.constituents(0).size();
        if (this.expansion_threshold > 0.0f) {
            List<Set<Integer>> newMembers = ExpandModules.compute_new_members(network, candidate, candidate.CPDs, scoringFunction, sufficientStatistic, this.expansion_threshold, 1);
            Set<Integer> additions = newMembers.get(0);
            changes += additions.size();
            for (int member : additions) {
                ((ModuleAssignment) candidate.structure).reassign(member, 0);
            }
        }
        return changes;
    }

    /**
     * Reports whether the top regulators agree between two regulatory programs.
     * The regulators compared are those in the first {@link #relevant_levels}
     * levels of each array's first CPD.
     *
     * <p>Both first elements must be {@code RegressionTree}s; otherwise a
     * {@code ClassCastException} is thrown.
     */
    @Override
    public boolean regulatory_program_stop_condition(VariableCPD[] variableCPDArr, VariableCPD[] variableCPDArr2) {
        RegressionTree first = (RegressionTree) variableCPDArr[0];
        RegressionTree second = (RegressionTree) variableCPDArr2[0];
        return first.top_regulators(this.relevant_levels).equals(second.top_regulators(this.relevant_levels));
    }

    /**
     * Writes a one-line record: {@code AssignmentAlgorithm SingleModuleRefinement}
     * followed by the original constructor arguments, each preceded by a space.
     *
     * @param printStream destination stream
     * @param z when {@code true}, the line is prefixed with {@code "# "}
     */
    @Override
    public void WriteRecord(PrintStream printStream, boolean z) throws IOException {
        if (z) {
            printStream.print("# ");
        }
        printStream.print("AssignmentAlgorithm SingleModuleRefinement");
        for (String arg : this.args_cache) {
            printStream.print(" " + arg);
        }
        printStream.println();
    }
}
