package biolearn.GraphicalModel.Learning.Structure.Priors;

import biolearn.GraphicalModel.CPDs.RegressionTree;
import biolearn.GraphicalModel.Learning.Structure.Candidate;
import biolearn.GraphicalModel.Learning.Structure.Modifications.RegressionTreeModification;
import biolearn.GraphicalModel.Learning.Structure.ScorePrior;
import biolearn.GraphicalModel.Learning.SufficientStatistic;
import biolearn.ModuleNetwork.Learning.ReassignMember;
import biolearn.NotImplementedYet;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Vector;

/* loaded from: input_file:biolearn/GraphicalModel/Learning/Structure/Priors/RegulatorPenalty.class */
public class RegulatorPenalty implements ScorePrior {

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:biolearn/GraphicalModel/Learning/Structure/Priors/RegulatorPenalty$counts.class */
    /**
     * Per-regulator bookkeeping: the inherited {@link HashSet} holds the indices
     * of the targets recorded for the regulator, and {@code total_members} keeps
     * a running sum of the weights passed to {@link #addTarget(int, int)}.
     */
    public static class counts extends HashSet<Integer> {
        /** Running sum of the weights accepted by {@link #addTarget(int, int)}. */
        private int total_members;

        /** Creates an empty record: no targets and a total of zero. */
        counts() {
            super();
            total_members = 0;
        }

        /**
         * Copy constructor: duplicates both the target set and the running total
         * of {@code countsVar}.
         */
        counts(counts countsVar) {
            super(countsVar);
            total_members = countsVar.total_members;
        }

        /** @return the current running total */
        int numTargetMembers() {
            return total_members;
        }

        /**
         * Adjusts the running total by {@code i2}.
         *
         * <p>A negative {@code i} means "adjust the total only": the set is not
         * touched. A non-negative {@code i} is inserted into the set, and the
         * total changes only when {@code i} was not already present.
         *
         * @param i  target index, or a negative value to bypass the set
         * @param i2 amount added to the running total (may be negative)
         */
        void addTarget(int i, int i2) {
            // Short-circuit matters: add() must not be called for negative indices.
            boolean accepted = i < 0 || add(Integer.valueOf(i));
            if (accepted) {
                total_members += i2;
            }
        }
    }

    /**
     * Returns the regulator-count map stored in {@code candidate.cache} under the
     * key {@code "Regulators"}, building and caching it first when it is absent.
     *
     * <p>When built, every index {@code i} of {@code candidate.CPDs} is recorded
     * as a target of each of its parents (as returned by
     * {@code candidate.getParents(i)}), weighted by the size of
     * {@code candidate.structure.constituents(i)}.
     *
     * @param candidate candidate whose cache is read and, if necessary, populated
     * @return the cached map from parent index to its {@code counts}; never null
     */
    static HashMap<Integer, counts> getCounts(Candidate candidate) {
        if (candidate.cache == null) {
            candidate.cache = new HashMap();
        }

        HashMap<Integer, counts> regulators = (HashMap) candidate.cache.get("Regulators");
        if (regulators != null) {
            return regulators;
        }

        // Register the (still empty) map before filling it, as the cache expects.
        regulators = new HashMap<>();
        candidate.cache.put("Regulators", regulators);

        for (int target = 0; target < candidate.CPDs.length; target++) {
            for (Integer parent : candidate.getParents(target)) {
                counts entry = regulators.get(parent);
                if (entry == null) {
                    entry = new counts();
                    regulators.put(parent, entry);
                }
                entry.addTarget(target, candidate.structure.constituents(target).size());
            }
        }
        return regulators;
    }

    /**
     * Updates the cached regulator counts of {@code candidate} to reflect its
     * {@code modification}. Does nothing when the candidate has no cache or no
     * {@code "Regulators"} entry in it, or when the modification is neither a
     * {@link RegressionTreeModification} nor a {@link ReassignMember}.
     *
     * <ul>
     *   <li>{@link RegressionTreeModification} (type 0 only): records
     *       {@code to} as a target of regulator {@code from}, weighted by the
     *       size of {@code candidate.structure.constituents(to)}.</li>
     *   <li>{@link ReassignMember}: each distinct split variable of the
     *       {@code from} tree loses one member, and each distinct split variable
     *       of the {@code to} tree gains one.</li>
     * </ul>
     *
     * @param candidate candidate whose cache is updated in place
     * @throws NotImplementedYet if a {@link RegressionTreeModification} has a
     *         {@code type} other than 0
     */
    public static void updateCounts(Candidate candidate) {
        if (candidate.cache == null) {
            return;
        }
        HashMap<Integer, counts> regulators = (HashMap) candidate.cache.get("Regulators");
        if (regulators == null) {
            return;
        }

        if (candidate.modification instanceof RegressionTreeModification) {
            RegressionTreeModification split = (RegressionTreeModification) candidate.modification;
            // Validate before touching the cache so an unsupported operator
            // cannot leave a half-applied update behind.
            if (split.type != 0) {
                throw new NotImplementedYet("Regulator or Split Point penalty with unsplitting");
            }
            counts regulatorCounts = regulators.get(Integer.valueOf(split.from));
            if (regulatorCounts == null) {
                regulatorCounts = new counts();
                regulators.put(Integer.valueOf(split.from), regulatorCounts);
            }
            regulatorCounts.addTarget(split.to, candidate.structure.constituents(split.to).size());
            return;
        }

        if (candidate.modification instanceof ReassignMember) {
            HashSet<Integer> seen = new HashSet<>();

            // Same guard priorDelta applies: a negative 'from' means there is
            // no source tree to take the member away from.
            if (candidate.modification.from >= 0) {
                RegressionTree source = (RegressionTree) candidate.getCPD(candidate.modification.from);
                for (RegressionTree.Node node : source.InnerNodes()) {
                    int regulator = node.var.Index();
                    if (seen.add(Integer.valueOf(regulator))) {
                        counts regulatorCounts = regulators.get(Integer.valueOf(regulator));
                        // A missing entry is treated as zero members (as in
                        // priorDelta), so there is nothing to decrement.
                        if (regulatorCounts != null) {
                            regulatorCounts.addTarget(-1, -1);
                        }
                    }
                }
            }

            seen.clear();
            if (candidate.modification.to >= 0) {
                RegressionTree destination = (RegressionTree) candidate.getCPD(candidate.modification.to);
                for (RegressionTree.Node node : destination.InnerNodes()) {
                    int regulator = node.var.Index();
                    if (seen.add(Integer.valueOf(regulator))) {
                        counts regulatorCounts = regulators.get(Integer.valueOf(regulator));
                        if (regulatorCounts == null) {
                            // First time this regulator is seen: register the
                            // destination as its target with a single member.
                            counts fresh = new counts();
                            fresh.addTarget(candidate.modification.to, 1);
                            regulators.put(Integer.valueOf(regulator), fresh);
                        } else {
                            // Negative index: bump the total without touching the set.
                            regulatorCounts.addTarget(-1, 1);
                        }
                    }
                }
            }
        }
    }

    /**
     * Makes a copy of a regulator-count map in which every {@code counts} value
     * is itself duplicated through the {@code counts} copy constructor, so the
     * result shares no mutable state with {@code map}.
     *
     * @param map map from {@code Integer} keys to {@code counts} values, or null
     * @return the copy, or {@code null} when {@code map} is {@code null}
     * @throws ClassCastException if a key is not an {@code Integer} or a value is
     *         not a {@code counts}
     */
    public static HashMap<Integer, counts> clonecache(Map map) {
        if (map == null) {
            return null;
        }
        HashMap<Integer, counts> copy = new HashMap<>();
        // The parameter is a raw Map, so entrySet() yields plain Objects; each
        // one has to be cast to Map.Entry explicitly for this to compile.
        for (Object item : map.entrySet()) {
            Map.Entry<?, ?> entry = (Map.Entry<?, ?>) item;
            copy.put((Integer) entry.getKey(), new counts((counts) entry.getValue()));
        }
        return copy;
    }

    /**
     * Creates the prior. The constructor body is empty: {@code vector} is
     * accepted but never read here.
     *
     * @param vector unused by this constructor
     */
    public RegulatorPenalty(Vector<String> vector) {
    }

    /**
     * Computes the prior for {@code candidate} as the negated sum, over every
     * regulator in the cached counts, of {@code log(numTargetMembers + 1)}.
     *
     * @param candidate candidate whose regulator counts are read (and built on
     *                  first use by {@code getCounts})
     * @return the prior value; {@code 0.0} when no regulator is recorded
     */
    @Override // biolearn.GraphicalModel.Learning.Structure.ScorePrior
    public double priorValue(Candidate candidate) {
        double logPrior = 0.0d;
        for (counts regulatorCounts : getCounts(candidate).values()) {
            logPrior -= Math.log(regulatorCounts.numTargetMembers() + 1);
        }
        return logPrior;
    }

    /**
     * Computes the change in the prior caused by {@code candidate.modification},
     * using the cached regulator counts.
     *
     * <ul>
     *   <li>{@link RegressionTreeModification} (type 0 only): returns 0 when
     *       {@code to} is already a recorded target of regulator {@code from};
     *       otherwise {@code log(n + 1) - log(n + size + 1)}, where {@code n} is
     *       the regulator's current total (0 if it has no entry) and
     *       {@code size} is the size of {@code candidate.constituents(to)}.</li>
     *   <li>{@link ReassignMember}: every split variable of the {@code from}
     *       tree is marked -1, every distinct split variable of the {@code to}
     *       tree is incremented by 1, and the non-zero changes are summed as
     *       {@code log(n + 1) - log(n + change + 1)}.</li>
     * </ul>
     *
     * @param candidate candidate carrying the modification to evaluate
     * @return the prior delta
     * @throws NotImplementedYet for a non-split tree modification or any other
     *         modification class
     */
    @Override // biolearn.GraphicalModel.Learning.Structure.ScorePrior
    public double priorDelta(Candidate candidate) {
        HashMap<Integer, counts> regulators = getCounts(candidate);

        if (candidate.modification instanceof RegressionTreeModification) {
            RegressionTreeModification split = (RegressionTreeModification) candidate.modification;
            if (split.type != 0) {
                throw new NotImplementedYet("Regulator penalty with non-split operator");
            }
            counts existing = regulators.get(Integer.valueOf(split.from));
            int membersBefore = 0;
            if (existing != null) {
                // The regulator already targets this module: nothing changes.
                if (existing.contains(Integer.valueOf(split.to))) {
                    return 0.0d;
                }
                membersBefore = existing.numTargetMembers();
            }
            return Math.log(membersBefore + 1)
                    - Math.log((membersBefore + candidate.constituents(split.to).size()) + 1);
        }

        if (!(candidate.modification instanceof ReassignMember)) {
            throw new NotImplementedYet("leaf penalty for modification of class " + candidate.modification.getClass().getName());
        }

        // One slot per candidate parent; a new int[] is already all zeros.
        int[] memberChange = new int[candidate.model.CandidateParents().size()];

        if (candidate.modification.from >= 0) {
            RegressionTree source = (RegressionTree) candidate.getCPD(candidate.modification.from);
            for (RegressionTree.Node node : source.InnerNodes()) {
                memberChange[node.var.Index()] = -1;
            }
        }

        HashSet<Integer> counted = new HashSet<>();
        if (candidate.modification.to >= 0) {
            RegressionTree destination = (RegressionTree) candidate.getCPD(candidate.modification.to);
            for (RegressionTree.Node node : destination.InnerNodes()) {
                int regulator = node.var.Index();
                // Count each split variable of the destination tree only once.
                if (counted.add(Integer.valueOf(regulator))) {
                    memberChange[regulator] += 1;
                }
            }
        }

        double delta = 0.0d;
        for (int regulator = 0; regulator < memberChange.length; regulator++) {
            int change = memberChange[regulator];
            if (change == 0) {
                continue;
            }
            counts existing = regulators.get(Integer.valueOf(regulator));
            int members = existing == null ? 0 : existing.numTargetMembers();
            delta += Math.log(members + 1) - Math.log((members + change) + 1);
        }
        return delta;
    }

    /**
     * No-op: the body is empty, so {@code sufficientStatistic} is ignored.
     *
     * @param sufficientStatistic unused by this implementation
     */
    @Override // biolearn.GraphicalModel.Learning.Structure.ScorePrior
    public void setData(SufficientStatistic sufficientStatistic) {
    }

    /**
     * @return the fixed string {@code "RegulatorPenalty"}
     */
    public String toString() {
        return "RegulatorPenalty";
    }
}
