package biolearn.ModuleNetwork.Learning;

import biolearn.Applications.BiolearnApplication;
import biolearn.GraphicalModel.Learning.Structure.Scores.NormalGamma;
import biolearn.GraphicalModel.Learning.SuffStat.NormalGammaStat;
import biolearn.GraphicalModel.Learning.SuffStat.Util.RTDP;
import biolearn.GraphicalModel.Learning.SuffStat.Util.RTDPSet;
import biolearn.GraphicalModel.Learning.SuffStat.WholeData;
import biolearn.ModuleNetwork.ModuleAssignment;
import biolearn.ModuleNetwork.Network;
import java.io.IOException;
import java.io.PrintStream;
import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import java.util.TreeSet;
import java.util.Vector;

/* loaded from: input_file:biolearn/ModuleNetwork/Learning/DoublePCluster.class */
public class DoublePCluster implements InitialClustering {
    private static final int cache_size = 3000;
    private int K;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:biolearn/ModuleNetwork/Learning/DoublePCluster$Cluster.class */
    /**
     * A group of item indices of an {@link RTDPSet} together with the per-column
     * sufficient statistics and the total score accumulated for that group.
     *
     * <p>Natural ordering is by smallest member index, then by member count.
     * Equality is defined by the member set, so list-level comparisons of two
     * clusterings (for example {@code Vector.equals}) compare the actual
     * grouping instead of falling back to object identity.
     */
    public static class Cluster implements Comparable<Cluster> {
        /** Indices of the items that belong to this cluster (sorted). */
        final TreeSet<Integer> members = new TreeSet<>();
        /** One statistic per position (or per cluster of the other clustering). */
        final NormalGammaStat.Stat[] stats;
        /** Sum of the scoring function over every entry of {@link #stats}. */
        final double score;

        /**
         * Creates a singleton cluster holding item {@code i} of {@code rTDPSet}.
         *
         * @param i          index of the item inside {@code rTDPSet}
         * @param rTDPSet    data set the item is read from
         * @param clustering clustering of the other dimension, or {@code null};
         *                   when {@code null} one statistic is built per position
         *                   of the item's value array, otherwise one statistic is
         *                   built per cluster of {@code clustering}, pooling the
         *                   positions listed in that cluster's members
         */
        Cluster(int i, RTDPSet rTDPSet, Clustering clustering) {
            this.members.add(Integer.valueOf(i));
            int statCount = clustering == null ? rTDPSet.get(0).val.length : clustering.size();
            this.stats = new NormalGammaStat.Stat[statCount];
            // Both lookups are loop-invariant, so resolve them once.
            NormalGamma scorer = (NormalGamma) BiolearnApplication.scoring_function;
            RTDP rtdp = rTDPSet.get(i);
            double total = 0.0d;
            for (int s = 0; s < statCount; s++) {
                NormalGammaStat.Stat stat = new NormalGammaStat.Stat();
                if (clustering == null) {
                    stat.add(rtdp, s);
                } else {
                    for (Integer position : clustering.get(s).members) {
                        stat.add(rtdp, position.intValue());
                    }
                }
                this.stats[s] = stat;
                total += scorer.score(stat);
            }
            this.score = total;
        }

        /**
         * Creates the union of two clusters: members are merged and each
         * statistic is a copy of {@code cluster}'s with {@code cluster2}'s added.
         * Both inputs are expected to have the same number of statistics.
         *
         * @param cluster  first cluster to merge (left unmodified)
         * @param cluster2 second cluster to merge (left unmodified)
         */
        Cluster(Cluster cluster, Cluster cluster2) {
            this.members.addAll(cluster.members);
            this.members.addAll(cluster2.members);
            this.stats = new NormalGammaStat.Stat[cluster.stats.length];
            NormalGamma scorer = (NormalGamma) BiolearnApplication.scoring_function;
            double total = 0.0d;
            for (int s = 0; s < this.stats.length; s++) {
                NormalGammaStat.Stat merged = new NormalGammaStat.Stat(cluster.stats[s]);
                merged.add(cluster2.stats[s]);
                this.stats[s] = merged;
                total += scorer.score(merged);
            }
            this.score = total;
        }

        /**
         * Orders clusters by their smallest member index and, on a tie, by the
         * number of members.
         */
        @Override // java.lang.Comparable
        public int compareTo(Cluster cluster) {
            // Integer.compare avoids the overflow pitfalls of comparing by subtraction.
            int byFirst = Integer.compare(this.members.first().intValue(), cluster.members.first().intValue());
            return byFirst != 0 ? byFirst : Integer.compare(this.members.size(), cluster.members.size());
        }

        /**
         * Two clusters are equal when they contain exactly the same members.
         * Statistics and score are deliberately ignored: they depend on the
         * clustering of the other dimension, not on the grouping itself.
         */
        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof Cluster)) {
                return false;
            }
            return this.members.equals(((Cluster) obj).members);
        }

        /** Hash code consistent with {@link #equals(Object)}. */
        @Override
        public int hashCode() {
            return this.members.hashCode();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:biolearn/ModuleNetwork/Learning/DoublePCluster$ClusterJoin.class */
    /**
     * A candidate merge of two clusters, carrying the merged cluster and the
     * score change the merge would produce. Natural ordering puts the largest
     * score gain first.
     */
    public static class ClusterJoin implements Comparable<ClusterJoin> {
        /** The smaller of the two inputs according to {@code Cluster.compareTo}. */
        Cluster cl1;
        /** The other input cluster. */
        Cluster cl2;
        /** Result of merging {@link #cl1} and {@link #cl2}. */
        Cluster joined;
        /** Score of the merged cluster minus the scores of both inputs. */
        double deltascore;

        /**
         * Builds the merge of the two given clusters; argument order does not
         * matter because the pair is normalised so that {@code cl1} sorts first.
         */
        ClusterJoin(Cluster cluster, Cluster cluster2) {
            if (cluster.compareTo(cluster2) < 0) {
                this.cl1 = cluster;
                this.cl2 = cluster2;
            } else {
                this.cl1 = cluster2;
                this.cl2 = cluster;
            }
            this.joined = new Cluster(this.cl1, this.cl2);
            this.deltascore = this.joined.score - cluster.score - cluster2.score;
        }

        /**
         * Sorts by descending {@code deltascore}; ties (including incomparable
         * NaN differences) are broken by {@code cl1}, and by {@code cl2} when
         * both joins share the very same {@code cl1} object.
         */
        @Override // java.lang.Comparable
        public int compareTo(ClusterJoin clusterJoin) {
            double gap = clusterJoin.deltascore - this.deltascore;
            if (gap > 0.0d) {
                return 1;
            }
            if (gap < 0.0d) {
                return -1;
            }
            if (this.cl1 == clusterJoin.cl1) {
                return this.cl2.compareTo(clusterJoin.cl2);
            }
            return this.cl1.compareTo(clusterJoin.cl1);
        }
    }

    /* loaded from: input_file:biolearn/ModuleNetwork/Learning/DoublePCluster$Clustering.class */
    /**
     * An ordered list of clusters plus the sum of their scores. The list is
     * searched with binary search in {@link #join}, so it relies on staying
     * sorted according to {@code Cluster.compareTo}.
     */
    private static class Clustering extends Vector<Cluster> {
        /** Number of items the clustering was created with. */
        int numMembers;
        /** Running total of the scores of all clusters in the list. */
        double score = 0.0d;

        /**
         * Starts with one singleton cluster per item of {@code rTDPSet}.
         *
         * @param rTDPSet    items to cluster
         * @param clustering clustering of the other dimension handed to every
         *                   new {@code Cluster}; may be {@code null}
         */
        Clustering(RTDPSet rTDPSet, Clustering clustering) {
            int itemCount = rTDPSet.size();
            this.numMembers = itemCount;
            for (int item = 0; item < itemCount; item++) {
                Cluster singleton = new Cluster(item, rTDPSet, clustering);
                add(singleton);
                this.score += singleton.score;
            }
        }

        /**
         * Applies a merge: the merged cluster takes the slot of {@code cl1},
         * the slot of {@code cl2} is removed and the total score is adjusted.
         */
        void join(ClusterJoin clusterJoin) {
            // Locate both inputs before the list is touched.
            int keepAt = Collections.binarySearch(this, clusterJoin.cl1);
            int dropAt = Collections.binarySearch(this, clusterJoin.cl2);
            set(keepAt, clusterJoin.joined);
            removeElementAt(dropAt);
            this.score += clusterJoin.deltascore;
        }

        /**
         * Evaluates every pair of clusters and returns the joins with a
         * positive score gain, capped at the cache size by discarding the
         * entry that sorts last whenever the cap is exceeded.
         */
        TreeSet<ClusterJoin> JoinCache() {
            TreeSet<ClusterJoin> best = new TreeSet<>();
            int clusterCount = size();
            for (int a = 0; a < clusterCount; a++) {
                Cluster left = get(a);
                for (int b = a + 1; b < clusterCount; b++) {
                    ClusterJoin candidate = new ClusterJoin(left, get(b));
                    if (candidate.deltascore > 0.0d) {
                        best.add(candidate);
                        if (best.size() > DoublePCluster.cache_size) {
                            best.pollLast();
                        }
                    }
                }
            }
            return best;
        }
    }

    /**
     * Reads the configuration options. An option starting with {@code k=}
     * (case-insensitive) sets K to the integer that follows; when several
     * such options are present the last one wins, and K is 0 when none is.
     *
     * @param vector option strings
     * @throws NumberFormatException if the text after {@code k=} is not an integer
     */
    public DoublePCluster(Vector<String> vector) {
        int requestedK = 0;
        for (String option : vector) {
            if (option.toLowerCase().startsWith("k=")) {
                requestedK = Integer.parseInt(option.substring(2));
            }
        }
        this.K = requestedK;
    }

    /**
     * Computes an initial module assignment by greedy agglomerative clustering
     * that alternates between the two views of the data (index 0, reported as
     * "vars" in debug output, and index 1, reported as "experiments").
     *
     * <p>Each pass starts from singleton clusters of one view, scored against
     * the latest clustering of the other view, and repeatedly applies the join
     * with the largest positive score gain. Passes alternate until a view-0
     * pass fails to beat the previous view-0 score, or a view-1 pass yields a
     * clustering equal to the previous view-1 clustering.
     *
     * @param network   network whose module members are clustered
     * @param wholeData data source providing one value vector per variable
     * @return assignment mapping every clustered node to the index of its
     *         cluster in the final view-0 clustering; other nodes get -1
     */
    @Override // biolearn.ModuleNetwork.Learning.InitialClustering
    public ModuleAssignment clusters(Network network, WholeData wholeData) {
        List<List<Number>> VarVectors = wholeData.VarVectors();
        // nodeOfRow.get(r) is the network node index that row r of VarVectors refers to.
        Vector<Integer> nodeOfRow = new Vector<>();
        for (int node = 0; node < network.Nodes().size(); node++) {
            nodeOfRow.add(Integer.valueOf(node));
        }
        // Drop rows that are not module members; walking backwards keeps the
        // indices still to be visited valid.
        // NOTE(review): this removes entries from the list returned by
        // wholeData.VarVectors() - confirm that list is not shared state.
        for (int idx = network.CandidateParents().size() - 1; idx >= 0; idx--) {
            if (!network.isModuleMember(idx)) {
                VarVectors.remove(idx);
                nodeOfRow.remove(idx); // int argument: removal by position
            }
        }
        // View 0 and view 1 of the same data (presumably row-wise and transposed - TODO confirm).
        RTDPSet[] dataByView = {new RTDPSet(VarVectors, false), new RTDPSet(VarVectors, true)};
        if (BiolearnApplication.debugModules) {
            System.err.println(dataByView[0].size() + " vars, " + dataByView[1].size() + " experiments");
        }
        // Latest accepted clustering per view; both start out as null.
        Clustering[] accepted = new Clustering[2];
        boolean converged = false;
        int view = 0;
        while (!converged) {
            Clustering candidate = new Clustering(dataByView[view], accepted[1 - view]);
            TreeSet<ClusterJoin> cache = new TreeSet<>();
            // Largest gain that may have been evicted from the capped cache;
            // cached joins at or below it cannot be trusted to be the best.
            double floor = 0.0d;
            while (true) {
                if (cache.size() == cache_size && cache.last().deltascore > floor) {
                    floor = cache.last().deltascore;
                }
                if (cache.isEmpty() || cache.first().deltascore <= floor) {
                    cache = candidate.JoinCache();
                    if (cache.isEmpty()) {
                        // No join improves the score any more.
                        break;
                    }
                    floor = cache.last().deltascore;
                    if (BiolearnApplication.debugModules) {
                        System.err.println("Cache recalculated " + cache.size() + " joins, " + candidate.size() + " clusters, range " + cache.first().deltascore + " - " + floor);
                    }
                }
                ClusterJoin best = cache.pollFirst();
                // Discard cached joins that involve either cluster consumed by the chosen join.
                Iterator<ClusterJoin> stale = cache.iterator();
                while (stale.hasNext()) {
                    ClusterJoin other = stale.next();
                    if (other.cl1 == best.cl1 || other.cl2 == best.cl2 || other.cl1 == best.cl2 || other.cl2 == best.cl1) {
                        stale.remove();
                    }
                }
                candidate.join(best);
                // Offer joins between the new cluster and every other cluster to the cache.
                for (Cluster existing : candidate) {
                    if (existing != best.joined) {
                        ClusterJoin fresh = new ClusterJoin(existing, best.joined);
                        if (fresh.deltascore > 0.0d) {
                            cache.add(fresh);
                            if (cache.size() > cache_size) {
                                cache.pollLast();
                            }
                        }
                    }
                }
                // Only view 0 is limited to K clusters.
                if (view == 0 && candidate.size() <= this.K) {
                    break;
                }
            }
            Clustering previous = accepted[view];
            if (previous == null) {
                converged = false;
            } else if (view != 0) {
                // Element-wise list comparison of the clusters.
                converged = previous.equals(candidate);
            } else {
                converged = previous.score >= candidate.score;
            }
            if (!converged) {
                accepted[view] = candidate;
            }
            if (BiolearnApplication.debugModules) {
                System.err.println("clustering on " + view + ", " + candidate.size() + " clusters, score " + candidate.score + " at " + new Date().toString());
            }
            view = 1 - view;
        }
        // Translate cluster membership (row indices) back to network node indices.
        int[] moduleOfNode = new int[network.Nodes().size()];
        Arrays.fill(moduleOfNode, -1);
        Clustering modules = accepted[0];
        for (int module = 0; module < modules.size(); module++) {
            for (Integer row : modules.get(module).members) {
                moduleOfNode[nodeOfRow.get(row.intValue()).intValue()] = module;
            }
        }
        return new ModuleAssignment(network, moduleOfNode);
    }

    /**
     * Writes the one-line record for this component, including the
     * configured K.
     *
     * @param printStream destination of the record
     * @param z           when {@code true} the line is prefixed with {@code "# "}
     *                    (presumably to emit it as a comment - TODO confirm)
     */
    @Override // biolearn.BiolearnComponent
    public void WriteRecord(PrintStream printStream, boolean z) throws IOException {
        String record = "ModuleInitiation DoublePCluster K=" + this.K;
        if (z) {
            printStream.print("# ");
        }
        printStream.println(record);
    }
}
