package biolearn.ModuleNetwork.Learning;

import biolearn.Applications.BiolearnApplication;
import biolearn.GraphicalModel.CPDs.RegressionTree;
import biolearn.GraphicalModel.Learning.LearningException;
import biolearn.GraphicalModel.Learning.Structure.Candidate;
import biolearn.GraphicalModel.Learning.Structure.Constraint;
import biolearn.GraphicalModel.Learning.Structure.ScoringFunction;
import biolearn.GraphicalModel.Learning.SuffStat.Util.RTDP;
import biolearn.GraphicalModel.Learning.SuffStat.Util.RTDPSet;
import biolearn.GraphicalModel.Learning.SuffStat.WholeData;
import biolearn.GraphicalModel.Learning.SufficientStatistic;
import biolearn.GraphicalModel.ModelNode;
import biolearn.GraphicalModel.VariableCPD;
import biolearn.ModuleNetwork.Module;
import biolearn.ModuleNetwork.ModuleAssignment;
import biolearn.ModuleNetwork.Network;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import java.util.Locale;
import java.util.Random;
import java.util.Set;
import java.util.Vector;

/* loaded from: input_file:biolearn/ModuleNetwork/Learning/MaxScoreImprovementAssignment.class */
/**
 * Assignment step that greedily moves individual module members to the reassignment
 * that most improves the score.
 *
 * <p>Each pass visits the module members of the network in shuffled order. For every
 * member that currently has an assignment, all neighbouring candidates returned by
 * {@code ReassignMember.allNeighbors} are scored. If the best-ranked one compares
 * strictly better than the current candidate (per {@code ScoringFunction.compare}),
 * it is applied. Ties among equally-best candidates are broken uniformly at random.
 *
 * <p>Regulation programs are re-learned ({@code Search.learn_regulation_program.light_run})
 * during a pass once enough members have been visited since the last relearn, and again
 * at the end of the pass. Passes stop after {@code maxiterations} passes, or as soon as a
 * pass moves fewer than {@code stopthreshold * NumModuleNodes()} members.
 *
 * <p>Recognised options (case-insensitive):
 * <ul>
 *   <li>{@code relearnfraction=F} (default 1.0)</li>
 *   <li>{@code clusteringspeedup}, plus the options of {@link AssignmentClusteringSpeedUp}</li>
 *   <li>{@code maxiterations=N} (default 1)</li>
 *   <li>{@code stopthreshold=F} (default 0.05)</li>
 * </ul>
 */
public class MaxScoreImprovementAssignment implements AssignmentAlgorithm {
    /** Maximum number of passes over the module members per call to newAssignment. */
    int max_iterations;
    /** A pass that moves fewer than {@code stop_threshold * NumModuleNodes()} members ends the search. */
    float stop_threshold;
    /** Fraction of the member list to visit between mid-pass regulation-program relearns. */
    float relearn_reg_fraction;
    /** Shared randomness for tie-breaking between best candidates and initial cluster seeding. */
    static Random generator = new Random();
    /** Optional k-means based speed-up; null unless the "clusteringspeedup" option is given. */
    public AssignmentClusteringSpeedUp clusteringSpeedup;
    /** Passed to ReassignMember; when true the clustering speed-up bookkeeping is skipped. */
    boolean do_singletons = false;
    /** Network for which {@link #memberlist} was last built (compared by identity). */
    private Network network = null;
    /** Indices of the network nodes for which {@code isModuleMember} is true. */
    private Vector<Integer> memberlist = null;

    /**
     * Optional acceleration that groups modules by k-means over their mean profiles.
     *
     * <p>Options (case-insensitive, all read from the same option vector as the outer class):
     * <ul>
     *   <li>{@code modulespercluster=N} (default 5)</li>
     *   <li>{@code numkmeansiter=N} (default 20)</li>
     *   <li>{@code numkmeansruns=N} (default 3)</li>
     *   <li>{@code testspeedup}</li>
     *   <li>{@code fractionclusterstest=F} (default 0.25)</li>
     *   <li>{@code movesbeforerecluster=N} (default 100; a value of 0 or less disables periodic
     *       re-clustering)</li>
     * </ul>
     *
     * <p>NOTE(review): {@code fractionClustersTest} is not read in this file, and
     * {@code rejectedCandidates} is not filled here. Presumably ReassignMember uses both;
     * confirm.
     */
    public static class AssignmentClusteringSpeedUp {
        /** Candidates skipped by the speed-up; allocated only in test mode. */
        public HashSet<Candidate> rejectedCandidates;
        public int modulesPerCluster;
        public int numKmeansIter;
        public int numKmeansRuns;
        public float fractionClustersTest;
        public int movesbeforerecluster;
        /** Test mode: report how many applied moves were also in {@link #rejectedCandidates}. */
        boolean test;
        public int missed = 0;
        /** Cluster index per module. */
        public int[] modclusass = null;
        public List<List<Number>> cluscenters = null;
        /** Mean profile per module, indexed like {@code ModuleAssignment.Modules()}. */
        public List<List<Number>> moduleMeans = null;
        List<List<List<Number>>> clusters = null;
        public RandomKMeansAssignment rka = null;

        /**
         * Parses speed-up options from {@code vector}; unrecognised entries are ignored.
         *
         * @param vector option strings such as {@code "modulespercluster=5"}
         */
        public AssignmentClusteringSpeedUp(Vector<String> vector) {
            this.rejectedCandidates = null;
            this.modulesPerCluster = 5;
            this.numKmeansIter = 20;
            this.numKmeansRuns = 3;
            this.fractionClustersTest = 0.25f;
            this.movesbeforerecluster = 100;
            this.test = false;
            for (String option : vector) {
                // Locale.ROOT keeps matching independent of the default locale (a Turkish
                // locale would otherwise lowercase 'I' to a dotless 'ı').
                String lowerCase = option.toLowerCase(Locale.ROOT);
                if (lowerCase.startsWith("modulespercluster=")) {
                    this.modulesPerCluster = Integer.parseInt(lowerCase.substring(18));
                } else if (lowerCase.startsWith("numkmeansiter=")) {
                    this.numKmeansIter = Integer.parseInt(lowerCase.substring(14));
                } else if (lowerCase.startsWith("numkmeansruns=")) {
                    this.numKmeansRuns = Integer.parseInt(lowerCase.substring(14));
                } else if (lowerCase.startsWith("testspeedup")) {
                    this.test = true;
                    this.rejectedCandidates = new HashSet<>();
                } else if (lowerCase.startsWith("fractionclusterstest=")) {
                    this.fractionClustersTest = Float.parseFloat(lowerCase.substring(21));
                } else if (lowerCase.startsWith("movesbeforerecluster=")) {
                    this.movesbeforerecluster = Integer.parseInt(lowerCase.substring(21));
                }
            }
        }

        /**
         * Computes every module's mean profile, seeds clusters at random and runs k-means.
         *
         * <p>The cluster count is {@code modules / modulesPerCluster}. It is kept at least 1
         * whenever there are modules, because {@code Random.nextInt(0)} would throw.
         *
         * @param network          network queried with {@code isModuleMember}
         * @param wholeData        data used to compute module means
         * @param moduleAssignment current assignment whose modules are clustered
         */
        public void initialClustering(Network network, WholeData wholeData, ModuleAssignment moduleAssignment) {
            this.missed = 0;
            List<ModelNode> modules = moduleAssignment.Modules();
            int size = modules.size();
            this.moduleMeans = new ArrayList<List<Number>>(size);
            for (ModelNode module : modules) {
                this.moduleMeans.add(calculateModuleMean((Module) module, wholeData));
            }
            this.rka = new RandomKMeansAssignment(new Vector());
            this.rka.num_kmeans_runs = this.numKmeansRuns;
            int clusterCount = size / Math.max(1, this.modulesPerCluster);
            if (size > 0 && clusterCount < 1) {
                clusterCount = 1;
            }
            this.rka.num_clusters = clusterCount;
            this.modclusass = new int[size];
            this.clusters = new Vector<List<List<Number>>>();
            for (int c = 0; c < clusterCount; c++) {
                this.clusters.add(new Vector<List<Number>>());
            }
            // NOTE(review): i is a module index here, but isModuleMember is indexed by
            // network node index in newAssignment — confirm which is intended.
            for (int i = 0; i < size; i++) {
                if (network.isModuleMember(i)) {
                    int seedCluster = MaxScoreImprovementAssignment.generator.nextInt(clusterCount);
                    this.clusters.get(seedCluster).add(this.moduleMeans.get(i));
                }
            }
            this.cluscenters = RandomKMeansAssignment.performKMeans(this.moduleMeans, this.clusters, null, this.modclusass, this.numKmeansIter);
        }

        /**
         * Builds a per-data-point vector for {@code module}.
         *
         * <p>The module's regression tree is re-partitioned over the data. Every point in a
         * leaf then receives that leaf's mean over the module's members, or 0 when the module
         * has no members. Entries for points not found in any leaf remain {@code null}.
         *
         * @param module              module whose CPD must be a {@code RegressionTree}
         * @param sufficientStatistic must be a {@code WholeData}
         * @return list with one entry per data point, indexed by {@code RTDP.index}
         */
        public ArrayList<Number> calculateModuleMean(Module module, SufficientStatistic sufficientStatistic) {
            WholeData wholeData = (WholeData) sufficientStatistic;
            Float[] pointMeans = new Float[wholeData.Data().size()];
            RegressionTree regressionTree = (RegressionTree) module.CPD();
            regressionTree.setLeaves(wholeData.Data(), module.MemberIndices());
            Set<Integer> memberIndices = module.MemberIndices();
            int[] members = new int[memberIndices.size()];
            int next = 0;
            for (Integer member : memberIndices) {
                members[next++] = member.intValue();
            }
            for (Integer leafIndex : regressionTree.Leaf_indices()) {
                int leaf = leafIndex.intValue();
                RTDPSet data;
                if (leaf == 1) {
                    // Leaf index 1 uses the whole data set (presumably the root of an unsplit tree).
                    data = wholeData.Data();
                } else {
                    // NOTE(review): return value unused; kept in case LeafAt has side effects.
                    regressionTree.LeafAt(leaf);
                    data = regressionTree.getData(leaf);
                }
                float mean = members.length > 0 ? data.getStats(members).mean() : 0.0f;
                Iterator<RTDP> points = data.iterator();
                while (points.hasNext()) {
                    pointMeans[points.next().index] = Float.valueOf(mean);
                }
            }
            return new ArrayList<Number>(Arrays.asList(pointMeans));
        }

        /** Recomputes the mean profile of module {@code i} after its membership changed. */
        public void updateModuleMeans(ModuleAssignment moduleAssignment, WholeData wholeData, int i) {
            this.moduleMeans.set(i, calculateModuleMean((Module) moduleAssignment.Modules().get(i), wholeData));
        }

        /** Reassigns module {@code i} to its closest current cluster center. */
        public void updateModuleClusterAssignmment(int i) {
            this.modclusass[i] = RandomKMeansAssignment.closest_point(this.cluscenters, this.moduleMeans.get(i));
        }

        /** Re-runs k-means starting from the current centers and assignments. */
        public void recalculateClusters() {
            this.cluscenters = RandomKMeansAssignment.performKMeans(this.moduleMeans, this.clusters, this.cluscenters, this.modclusass, this.numKmeansIter);
        }
    }

    /**
     * Formats a module's member count for debug output.
     *
     * @return {@code "1"} for a null module; the count wrapped in brackets when the CPD is a
     *         {@code RegressionTree} with no inner nodes; otherwise the plain count
     */
    private static String num_members(Module module) {
        if (module == null) {
            return "1";
        }
        boolean leafOnlyTree = (module.CPD() instanceof RegressionTree) && ((RegressionTree) module.CPD()).InnerNodes().isEmpty();
        String count = String.valueOf(module.Members().size());
        return leafOnlyTree ? "[" + count + "]" : count;
    }

    /**
     * Parses options from {@code vector}; unrecognised entries are ignored.
     *
     * @param vector option strings such as {@code "maxiterations=3"}
     */
    public MaxScoreImprovementAssignment(Vector<String> vector) {
        this.max_iterations = 1;
        this.stop_threshold = 0.05f;
        this.relearn_reg_fraction = 1.0f;
        this.clusteringSpeedup = null;
        for (String option : vector) {
            // Locale.ROOT: see AssignmentClusteringSpeedUp constructor.
            String lowerCase = option.toLowerCase(Locale.ROOT);
            if (lowerCase.startsWith("relearnfraction=")) {
                this.relearn_reg_fraction = Float.parseFloat(lowerCase.substring(16));
            } else if (lowerCase.startsWith("clusteringspeedup")) {
                this.clusteringSpeedup = new AssignmentClusteringSpeedUp(vector);
            } else if (lowerCase.startsWith("maxiterations=")) {
                this.max_iterations = Integer.parseInt(lowerCase.substring(14));
            } else if (lowerCase.startsWith("stopthreshold=")) {
                this.stop_threshold = Float.parseFloat(lowerCase.substring(14));
            }
        }
    }

    /**
     * Runs up to {@code max_iterations} greedy reassignment passes.
     *
     * @param network             network being learned
     * @param candidateArr        {@code candidateArr[0]} is the current candidate. It is
     *                            replaced in place by every applied move.
     * @param sufficientStatistic data; cast to {@code WholeData}
     * @param scoringFunction     used to score neighbours; its cache is reset before each
     *                            member is scored and on return
     * @param constraintArr       forwarded to neighbour generation
     * @return total number of members moved across all passes
     * @throws LearningException if neighbour generation or scoring fails
     */
    @Override // biolearn.ModuleNetwork.Learning.AssignmentAlgorithm
    public int newAssignment(Network network, Candidate[] candidateArr, SufficientStatistic sufficientStatistic, ScoringFunction scoringFunction, Constraint[] constraintArr) throws LearningException {
        refreshMemberList(network);
        int totalMoves = 0;
        ReassignMember reassignMember = new ReassignMember(this.do_singletons);
        ModuleAssignment moduleAssignment = (ModuleAssignment) candidateArr[0].structure;
        if (BiolearnApplication.debugModules) {
            printMemberCounts("beginning members", moduleAssignment);
        }
        if (this.clusteringSpeedup != null) {
            this.clusteringSpeedup.initialClustering(network, (WholeData) sufficientStatistic, moduleAssignment);
        }
        for (int iteration = 0; iteration < this.max_iterations; iteration++) {
            Collections.shuffle(this.memberlist);
            int lastRelearnIndex = 0;
            int movesThisPass = 0;
            ListIterator<Integer> members = this.memberlist.listIterator();
            while (members.hasNext()) {
                int member = members.next().intValue();
                if (moduleAssignment.assignment(member) == null) {
                    continue;
                }
                List<Candidate> neighbors = reassignMember.allNeighbors(network, member, candidateArr[0], constraintArr, this.clusteringSpeedup, (WholeData) sufficientStatistic);
                if (neighbors.isEmpty()) {
                    continue;
                }
                scoringFunction.resetCache();
                for (Candidate neighbor : neighbors) {
                    scoringFunction.score(network, neighbor, sufficientStatistic);
                }
                // After sorting, the first element is treated as the best candidate.
                Collections.sort(neighbors);
                if (ScoringFunction.compare(neighbors.get(0).score, candidateArr[0].score) > 0.0f) {
                    totalMoves++;
                    movesThisPass++;
                    List<Candidate> ties = bestTies(neighbors);
                    Candidate candidate = ties.get(generator.nextInt(ties.size()));
                    if (BiolearnApplication.debug) {
                        System.err.println("Trying to reassign " + network.Nodes().get(member).Name() + ", current assignment " + moduleAssignment.Assignments()[member]);
                        System.err.println("Best score gives " + ties.size() + " options, chose assignment to " + candidate.modification.to + "\nfrom score " + candidate.local_scores[candidate.modification.from].doubleValue() + "->" + candidate.updated_from_score.doubleValue() + "\nto score " + candidate.local_scores[candidate.modification.to].doubleValue() + "->" + candidate.updated_to_score.doubleValue() + "\ntotal score + " + candidateArr[0].score + "->" + candidate.score);
                    }
                    int fromModule = candidate.modification.from;
                    int toModule = candidate.modification.to;
                    candidate.applyModification();
                    candidateArr[0] = candidate;
                    moduleAssignment = (ModuleAssignment) candidateArr[0].structure;
                    if (!this.do_singletons && this.clusteringSpeedup != null) {
                        updateClusteringAfterMove(moduleAssignment, (WholeData) sufficientStatistic, candidate, fromModule, toModule, totalMoves);
                    }
                    if (members.nextIndex() - lastRelearnIndex > this.memberlist.size() * this.relearn_reg_fraction) {
                        Search.learn_regulation_program.light_run(candidateArr[0]);
                        lastRelearnIndex = members.nextIndex();
                        if (BiolearnApplication.debugModules) {
                            System.err.println("After mid-assignment adjustment of regulatory programs:\n" + candidateArr[0].toString());
                        }
                    }
                }
            }
            // Relearn once more unless the last mid-pass relearn already covered the whole list.
            if (lastRelearnIndex < this.memberlist.size()) {
                Search.learn_regulation_program.light_run(candidateArr[0]);
            }
            if (BiolearnApplication.debugModules) {
                System.err.println("assignment " + (this.do_singletons ? "singletons " : "normal ") + totalMoves + " moved at " + new Date().toString());
            }
            if (movesThisPass < this.stop_threshold * network.NumModuleNodes()) {
                break;
            }
        }
        if (BiolearnApplication.debugModules) {
            printMemberCounts("module members", moduleAssignment);
        }
        scoringFunction.resetCache();
        return totalMoves;
    }

    /** Rebuilds {@link #memberlist} when called with a different network instance than last time. */
    private void refreshMemberList(Network network) {
        if (this.network == network) {
            return;
        }
        this.network = network;
        this.memberlist = new Vector<>();
        for (int i = 0; i < network.Nodes().size(); i++) {
            if (network.isModuleMember(i)) {
                this.memberlist.add(Integer.valueOf(i));
            }
        }
    }

    /** Prints {@code header} followed by each module's {@link #num_members} string to stderr. */
    private static void printMemberCounts(String header, ModuleAssignment moduleAssignment) {
        StringBuilder line = new StringBuilder(header);
        for (ModelNode module : moduleAssignment.Modules()) {
            line.append(' ').append(num_members((Module) module));
        }
        System.err.println(line.toString());
    }

    /**
     * Returns the leading run of {@code sorted} whose scores compare equal to the first element.
     *
     * @param sorted non-empty, already sorted candidate list
     */
    private static List<Candidate> bestTies(List<Candidate> sorted) {
        List<Candidate> ties = new ArrayList<>();
        Iterator<Candidate> it = sorted.iterator();
        Candidate best = it.next();
        ties.add(best);
        while (it.hasNext()) {
            Candidate other = it.next();
            if (ScoringFunction.compare(other.score, best.score) != 0.0f) {
                break;
            }
            ties.add(other);
        }
        return ties;
    }

    /**
     * Refreshes the speed-up's module means and cluster assignments for the two modules
     * touched by a move.
     *
     * <p>Every {@code movesbeforerecluster} moves, all clusters are recomputed instead. A
     * non-positive {@code movesbeforerecluster} disables this to avoid a division by zero.
     * In test mode, it also reports how many applied moves were in the rejected set.
     */
    private void updateClusteringAfterMove(ModuleAssignment moduleAssignment, WholeData data, Candidate moved, int fromModule, int toModule, int totalMoves) {
        AssignmentClusteringSpeedUp speedup = this.clusteringSpeedup;
        speedup.updateModuleMeans(moduleAssignment, data, fromModule);
        speedup.updateModuleMeans(moduleAssignment, data, toModule);
        if (speedup.movesbeforerecluster > 0 && totalMoves % speedup.movesbeforerecluster == 0) {
            speedup.recalculateClusters();
        } else {
            speedup.updateModuleClusterAssignmment(fromModule);
            speedup.updateModuleClusterAssignmment(toModule);
        }
        if (speedup.test) {
            if (speedup.rejectedCandidates.contains(moved)) {
                speedup.missed++;
            }
            System.err.println("Missed " + speedup.missed + " out of " + totalMoves + " elements moved");
        }
    }

    /** Never requests an early stop of regulatory-program learning. */
    @Override // biolearn.ModuleNetwork.Learning.AssignmentAlgorithm
    public boolean regulatory_program_stop_condition(VariableCPD[] variableCPDArr, VariableCPD[] variableCPDArr2) {
        return false;
    }

    /**
     * Writes this component's configuration line.
     *
     * <p>Only {@code relearnfraction} is recorded. NOTE(review): {@code maxiterations},
     * {@code stopthreshold} and {@code clusteringspeedup} are not written.
     *
     * @param printStream destination
     * @param z           when true, the line is prefixed with {@code "# "}
     */
    @Override // biolearn.BiolearnComponent
    public void WriteRecord(PrintStream printStream, boolean z) throws IOException {
        if (z) {
            printStream.print("# ");
        }
        printStream.print("AssignmentAlgorithm MaxScoreImprovementAssignment relearnfraction=" + this.relearn_reg_fraction);
        printStream.println();
    }

    /** Sets whether reassignment runs in singleton mode (forwarded to ReassignMember). */
    public void DoSingletons(boolean z) {
        this.do_singletons = z;
    }
}
