package biolearn.ModuleNetwork;

import biolearn.Applications.BiolearnApplication;
import biolearn.GraphicalModel.CPDs.RegressionTree;
import biolearn.GraphicalModel.Learning.Structure.Candidate;
import biolearn.GraphicalModel.Learning.Structure.ScoringFunction;
import biolearn.GraphicalModel.Learning.SuffStat.NormalGammaStat;
import biolearn.GraphicalModel.Learning.SuffStat.Util.RTDP;
import biolearn.GraphicalModel.Learning.SuffStat.WholeData;
import biolearn.GraphicalModel.Learning.SufficientStatistic;
import biolearn.GraphicalModel.ModelNode;
import biolearn.GraphicalModel.PDAG;
import biolearn.GraphicalModel.VariableCPD;
import biolearn.Inconsistency;
import biolearn.ModuleNetwork.Learning.ReassignMember;
import biolearn.NotImplementedYet;
import java.lang.reflect.InvocationTargetException;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import java.util.Set;
import java.util.Vector;

/* loaded from: input_file:biolearn/ModuleNetwork/ModuleAssignment.class */
/**
 * Assignment of the variables of a {@link Network} to {@link Module}s, viewed as a {@link PDAG}
 * whose children are the modules.
 *
 * <p>{@code assignments[v]} holds the index of the module containing variable {@code v}, or a
 * negative value when {@code v} is not assigned to any module. Modules whose index is
 * {@code >= super.numChildren()} are not (yet) children of the underlying PDAG; for those the
 * regulator set is read directly from the module's {@link RegressionTree} CPD, and edges into them
 * are ignored until {@link #synchronize()} registers them.
 */
public class ModuleAssignment extends PDAG {
    /** Modules in index order; every element is a {@link Module}. */
    private Vector<ModelNode> modules = new Vector<>();
    /** Module index per variable; negative means "unassigned". Shared with callers, not copied. */
    private int[] assignments;
    /** Network whose variables are grouped into modules. */
    private Network containing_network;

    /**
     * Builds the modules described by an assignment array.
     *
     * @param network the network whose variables are being assigned
     * @param initialAssignments module index per variable (negative = unassigned); stored by
     *     reference, not copied
     */
    public ModuleAssignment(Network network, int[] initialAssignments) {
        this.containing_network = network;
        this.assignments = initialAssignments;
        for (int variable = 0; variable < initialAssignments.length; variable++) {
            int target = initialAssignments[variable];
            if (target < 0) {
                continue;
            }
            // Create any missing modules up to and including the target index.
            while (target >= this.modules.size()) {
                appendEmptyModule();
            }
            ((Module) this.modules.get(target)).add(network.Nodes().get(variable));
        }
        initialize(network.CandidateParents().size(), this.modules.size());
    }

    /**
     * Appends a new, empty module bound to the containing network at the next free index.
     * Private so it is safe to call from the constructor.
     */
    private Module appendEmptyModule() {
        Module module = new Module();
        module.setModel(this.containing_network, this.modules.size());
        this.modules.add(module);
        return module;
    }

    /**
     * Appends {@code module} at the next free index and binds it to the containing network.
     *
     * @param module the module to add
     */
    public void addModule(Module module) {
        module.setModel(this.containing_network, this.modules.size());
        this.modules.add(module);
    }

    /**
     * Rebuilds this assignment from scratch for {@code network}. Each created module receives a
     * clone ({@code m49clone()}) of the CPD of the network's candidate child at the same index,
     * when that CPD is non-null. Unlike the constructor, this does not call {@code initialize}.
     *
     * @param network the network to bind to
     * @param newAssignments module index per variable (negative = unassigned); stored by reference
     */
    public void clone_assignments(Network network, int[] newAssignments) {
        this.assignments = newAssignments;
        this.containing_network = network;
        this.modules = new Vector<>();
        for (int variable = 0; variable < newAssignments.length; variable++) {
            int target = newAssignments[variable];
            if (target < 0) {
                continue;
            }
            while (this.modules.size() <= target) {
                addModule(new Module());
                int newest = this.modules.size() - 1;
                VariableCPD cpd = network.CandidateChildren().get(newest).CPD();
                if (cpd != null) {
                    this.modules.get(newest).setCPD(cpd.m49clone());
                }
            }
            ((Module) this.modules.get(target)).add(this.containing_network.Nodes().get(variable));
        }
    }

    /**
     * Returns {@code cpd} as a {@link RegressionTree}, the only CPD type whose regulators can be read.
     *
     * @throws NotImplementedYet if {@code cpd} is not a {@link RegressionTree}
     */
    private static RegressionTree regressionTreeOf(VariableCPD cpd) {
        if (!(cpd instanceof RegressionTree)) {
            throw new NotImplementedYet("Reading set of regulators out of CPD of class " + cpd.getClass().getSimpleName());
        }
        return (RegressionTree) cpd;
    }

    /**
     * Returns the regulators of module {@code i}. Modules registered with the PDAG delegate to the
     * superclass; otherwise the result is the set of split-variable indices of the module's
     * regression tree.
     *
     * @throws NotImplementedYet if an unregistered module's CPD is not a {@link RegressionTree}
     */
    @Override
    public Set<Integer> getParents(int i) {
        if (i < super.numChildren()) {
            return super.getParents(i);
        }
        Set<Integer> regulators = new HashSet<>();
        ListIterator<RegressionTree.Node> splits = regressionTreeOf(this.modules.get(i).CPD()).InnerNodes().listIterator();
        while (splits.hasNext()) {
            regulators.add(Integer.valueOf(splits.next().var.Index()));
        }
        return regulators;
    }

    /** Returns the number of modules, including ones not yet registered with the PDAG. */
    @Override
    public int numChildren() {
        return this.modules.size();
    }

    /**
     * Adds edge {@code i -> i2} only when module {@code i2} is already a child of the PDAG;
     * otherwise the call is silently ignored.
     */
    @Override
    public void addEdge(int i, int i2) {
        if (i2 < super.numChildren()) {
            super.addEdge(i, i2);
        }
    }

    /** Sets the CPD of module {@code i}. */
    public void setCPD(int i, VariableCPD variableCPD) {
        this.modules.get(i).setCPD(variableCPD);
    }

    /**
     * Brings the PDAG in line with the module list. It performs these steps:
     * <ol>
     *   <li>Drops empty modules and renumbers the survivors contiguously.</li>
     *   <li>Remaps {@code assignments} to the new numbering.</li>
     *   <li>Calls {@code childrenChange} with the new module count and the removed indices that
     *       were already PDAG children.</li>
     *   <li>Adds regulator edges, taken from the regression-tree splits, for modules that were not
     *       PDAG children before.</li>
     * </ol>
     *
     * @throws NotImplementedYet if a newly registered module's CPD is not a {@link RegressionTree}
     */
    public void synchronize() {
        Vector<Integer> removedRegistered = new Vector<>();
        int registered = super.numChildren();
        int[] newIndexOf = new int[this.modules.size()];
        int nextIndex = 0;
        ListIterator<ModelNode> it = this.modules.listIterator();
        while (it.hasNext()) {
            Module module = (Module) it.next();
            if (module.Members().isEmpty()) {
                if (module.Index() < registered) {
                    removedRegistered.add(Integer.valueOf(module.Index()));
                }
                it.remove();
            } else {
                newIndexOf[module.Index()] = nextIndex;
                module.setModel(this.containing_network, nextIndex);
                nextIndex++;
            }
        }
        for (int variable = 0; variable < this.assignments.length; variable++) {
            if (this.assignments[variable] >= 0) {
                this.assignments[variable] = newIndexOf[this.assignments[variable]];
            }
        }
        childrenChange(this.modules.size(), removedRegistered);
        // Surviving previously registered modules occupy the first (registered - removed) slots;
        // the rest were never PDAG children and need their edges added now.
        ListIterator<ModelNode> fresh = this.modules.listIterator(registered - removedRegistered.size());
        while (fresh.hasNext()) {
            RegressionTree tree = regressionTreeOf(fresh.next().CPD());
            int child = fresh.previousIndex();
            ListIterator<RegressionTree.Node> splits = tree.InnerNodes().listIterator();
            while (splits.hasNext()) {
                addEdge(splits.next().var.Index(), child);
            }
        }
    }

    /** Edge reversals are not offered for module structures; always 0. */
    @Override
    public int numPotentialReverses() {
        return 0;
    }

    /** Returns the member indices of module {@code i}. */
    @Override
    public Collection<Integer> constituents(int i) {
        return ((Module) this.modules.get(i)).MemberIndices();
    }

    /** Returns the module index of variable {@code i} (negative if unassigned). */
    @Override
    public int containing_node(int i) {
        return this.assignments[i];
    }

    /** Returns the live internal module list (not a copy). */
    public Vector<ModelNode> Modules() {
        return this.modules;
    }

    /** Returns the live internal assignment array (not a copy). */
    public int[] Assignments() {
        return this.assignments;
    }

    /**
     * Moves variable {@code i} to module {@code i2}, creating empty modules up to {@code i2} if
     * needed. A negative {@code i2} leaves the variable unassigned. Empty modules are not removed
     * here; call {@link #synchronize()} for that.
     */
    public void reassign(int i, int i2) {
        while (i2 >= this.modules.size()) {
            appendEmptyModule();
        }
        int previous = this.assignments[i];
        if (previous >= 0) {
            ((Module) this.modules.get(previous)).remove(this.containing_network.Nodes().get(i));
        }
        if (i2 >= 0) {
            ((Module) this.modules.get(i2)).add(this.containing_network.Nodes().get(i));
        }
        if (BiolearnApplication.debug && previous >= 0) {
            Module old = (Module) this.modules.get(previous);
            System.err.println("reassigning " + i + " to " + i2 + ", old assignment " + previous + ' ' + old.Name() + " now has " + old.Members().size() + " members");
        }
        this.assignments[i] = i2;
    }

    /** Returns the module containing variable {@code i}, or {@code null} if it is unassigned. */
    public Module assignment(int i) {
        int target = this.assignments[i];
        return target >= 0 ? (Module) this.modules.get(target) : null;
    }

    /**
     * Clears every module's parent set and replaces its CPD with a fresh instance of the same class,
     * created through its no-argument constructor.
     *
     * @throws Inconsistency if a CPD is missing or cannot be instantiated reflectively
     */
    public void clear_regulation_programs() {
        for (ModelNode module : this.modules) {
            clearParents(module.Index());
            try {
                module.setCPD((VariableCPD) module.CPD().getClass().getDeclaredConstructor().newInstance());
            } catch (InvocationTargetException e) {
                // Report what the constructor itself threw rather than the reflective wrapper.
                Throwable reason = e.getCause() != null ? e.getCause() : e;
                throw new Inconsistency(reason.toString());
            } catch (Exception e) {
                throw new Inconsistency(e.toString());
            }
        }
    }

    /**
     * Unassigns poorly fitting variables and dissolves modules that lose support, then calls
     * {@link #synchronize()}.
     *
     * <p>For each assigned variable two gains are computed. Each gain stays +infinity when its
     * criterion is disabled.
     * <ul>
     *   <li>Score gain (when {@code f >= 0}): the score of a candidate with a
     *       {@link ReassignMember} move of the variable to module {@code -1}, minus the score of the
     *       current network.</li>
     *   <li>Program gain (when {@code f2 >= 0}): the sum over {@code wholeData.Data()} of the
     *       module tree's {@code logPDF} minus that of a freshly constructed {@link RegressionTree}.
     *       Both trees have leaves set for this single variable.</li>
     * </ul>
     *
     * <p>A variable is unassigned when both of these hold:
     * <ul>
     *   <li>Its score gain is at least the score threshold. That threshold is the module's mean
     *       gain plus {@code f} times its standard deviation when {@code f > 0}, otherwise 0.</li>
     *   <li>Its program gain is at most the program threshold. That threshold is
     *       {@code mean - f2*std} over all variables when {@code f2 > 1}, the
     *       {@code f2}-quantile when {@code 0 < f2 <= 1}, and +infinity otherwise.</li>
     * </ul>
     * A module with no member whose program gain exceeds the program threshold is emptied
     * entirely.
     *
     * <p>NOTE(review): when {@code f2 <= 0} the program threshold is +infinity, so every module is
     * emptied. Confirm callers never pass such values, or that this is intended.
     *
     * <p>NOTE(review): {@code setLeaves} is called on each module's own CPD, so that CPD is left
     * configured for the last member processed. Verify this side effect is intended.
     *
     * <p>Score caching is disabled while gains are computed and re-enabled afterwards, including
     * when an exception escapes.
     *
     * @param scoringFunction scores candidate moves
     * @param wholeData data used for the regression-tree log-densities
     * @param sufficientStatistic statistic passed to the scoring function
     * @param f score-gain criterion multiplier; negative disables the criterion
     * @param f2 program-gain criterion: std multiplier when {@code > 1}, quantile when in
     *     {@code (0, 1]}; negative skips computing program gains
     */
    public void filter(ScoringFunction scoringFunction, WholeData wholeData, SufficientStatistic sufficientStatistic, float f, float f2) {
        int numVariables = this.assignments.length;
        // +infinity marks "not computed" (unassigned variable or disabled criterion).
        float[] scoreGain = new float[numVariables];
        float[] programGain = new float[numVariables];
        Arrays.fill(scoreGain, Float.POSITIVE_INFINITY);
        Arrays.fill(programGain, Float.POSITIVE_INFINITY);
        scoringFunction.resetCache();
        scoringFunction.setCaching(false);
        try {
            NormalGammaStat.Stat programStat = new NormalGammaStat.Stat();
            NormalGammaStat.Stat[] moduleScoreStat = new NormalGammaStat.Stat[this.modules.size()];
            for (int m = 0; m < moduleScoreStat.length; m++) {
                moduleScoreStat[m] = new NormalGammaStat.Stat();
            }
            boolean[] keepModule = new boolean[this.modules.size()];
            Candidate current = null;
            for (int v = 0; v < numVariables; v++) {
                int module = this.assignments[v];
                if (module < 0) {
                    continue;
                }
                if (f >= 0.0f) {
                    if (current == null) {
                        current = new Candidate(this.containing_network);
                        scoringFunction.score(this.containing_network, current, sufficientStatistic);
                    }
                    Candidate without = new Candidate(current, new ReassignMember(module, -1, v));
                    scoringFunction.score(this.containing_network, without, sufficientStatistic);
                    scoreGain[v] = (float) (without.score - current.score);
                    moduleScoreStat[module].add(scoreGain[v]);
                }
                if (f2 >= 0.0f) {
                    RegressionTree program = (RegressionTree) this.modules.get(module).CPD();
                    RegressionTree baseline = new RegressionTree();
                    List<Integer> member = Arrays.asList(new Integer[] {Integer.valueOf(v)});
                    program.setLeaves(wholeData.Data(), member);
                    baseline.setLeaves(wholeData.Data(), member);
                    float gain = 0.0f;
                    Iterator<RTDP> points = wholeData.Data().iterator();
                    while (points.hasNext()) {
                        RTDP point = points.next();
                        gain += (float) (program.logPDF(point, member) - baseline.logPDF(point, member));
                    }
                    programGain[v] = gain;
                    programStat.add(gain);
                }
            }

            float programThreshold = Float.POSITIVE_INFINITY;
            if (f2 > 1.0f) {
                programThreshold = programStat.mean() - (f2 * ((float) programStat.std()));
            } else if (f2 > 0.0f && numVariables > 0) {
                float[] sorted = programGain.clone();
                Arrays.sort(sorted);
                // round(count * f2) may equal numVariables when every variable is assigned; clamp
                // to the last slot. Either index accepts every assigned variable, so the outcome
                // matches the non-overflowing case.
                int rank = Math.min(Math.round(programStat.count * f2), sorted.length - 1);
                programThreshold = sorted[rank];
            }

            for (int v = 0; v < numVariables; v++) {
                int module = this.assignments[v];
                if (module < 0) {
                    continue;
                }
                float scoreThreshold = f > 0.0f ? moduleScoreStat[module].mean() + (f * ((float) moduleScoreStat[module].std())) : 0.0f;
                boolean scoreSaysRemove = scoreGain[v] >= scoreThreshold;
                boolean programSaysRemove = programGain[v] <= programThreshold;
                if (scoreSaysRemove && programSaysRemove) {
                    reassign(v, -1);
                } else if (!programSaysRemove) {
                    keepModule[module] = true;
                }
            }

            // Dissolve unsupported modules; iterate over a copy because reassign mutates membership.
            for (int m = 0; m < keepModule.length; m++) {
                if (!keepModule[m]) {
                    for (Integer variable : new HashSet<Integer>(constituents(m))) {
                        reassign(variable.intValue(), -1);
                    }
                }
            }
            synchronize();
        } finally {
            scoringFunction.setCaching(true);
        }
    }
}
