package biolearn.ModuleNetwork.Learning;

import biolearn.GraphicalModel.Learning.SuffStat.NormalGammaStat;
import biolearn.GraphicalModel.Learning.SuffStat.WholeData;
import biolearn.ModuleNetwork.ModuleAssignment;
import biolearn.ModuleNetwork.Network;
import java.io.IOException;
import java.io.PrintStream;
import java.util.List;
import java.util.ListIterator;
import java.util.Locale;
import java.util.Random;
import java.util.Vector;
import org.apache.commons.math.stat.regression.SimpleRegression;

/* loaded from: input_file:biolearn/ModuleNetwork/Learning/RandomKMeansAssignment.class */
public class RandomKMeansAssignment implements InitialClustering {
    public int num_clusters;
    int max_iterations;
    int num_kmeans_runs;
    public int current_run = 0;
    private List<String> args_cache;
    static Random generator = new Random();
    public static boolean pearson = false;
    static SimpleRegression regression = null;

    public RandomKMeansAssignment(Vector<String> vector) {
        this.num_clusters = 500;
        this.max_iterations = 10;
        this.num_kmeans_runs = 1;
        this.args_cache = new Vector(vector);
        ListIterator<String> listIterator = vector.listIterator();
        while (listIterator.hasNext()) {
            String lowerCase = listIterator.next().toLowerCase();
            if (lowerCase.startsWith("k=")) {
                this.num_clusters = Integer.parseInt(lowerCase.substring(2));
            } else if (lowerCase.startsWith("maxiterations=")) {
                this.max_iterations = Integer.parseInt(lowerCase.substring(14));
            } else if (lowerCase.startsWith("restarts=")) {
                this.num_kmeans_runs = Integer.parseInt(lowerCase.substring(9)) + 1;
            } else if (Character.isDigit(lowerCase.charAt(0))) {
                this.num_clusters = Integer.parseInt(lowerCase);
            }
        }
    }

    @Override // biolearn.ModuleNetwork.Learning.InitialClustering
    public ModuleAssignment clusters(Network network, WholeData wholeData) {
        List<List<Number>> VarVectors = wholeData.VarVectors();
        int[] iArr = (int[]) null;
        float f = Float.NEGATIVE_INFINITY;
        this.current_run = 1;
        while (this.current_run <= this.num_kmeans_runs && !Thread.interrupted()) {
            Vector vector = new Vector();
            for (int i = 0; i < this.num_clusters; i++) {
                vector.add(new Vector());
            }
            int[] iArr2 = new int[network.Nodes().size()];
            for (int i2 = 0; i2 < network.Nodes().size(); i2++) {
                if (network.isModuleMember(i2)) {
                    iArr2[i2] = generator.nextInt(this.num_clusters);
                    ((List) vector.get(iArr2[i2])).add(VarVectors.get(i2));
                } else {
                    iArr2[i2] = -1;
                }
            }
            performKMeans(VarVectors, vector, null, iArr2, this.max_iterations);
            float scoreClusters = this.num_kmeans_runs > 1 ? scoreClusters(vector) : 0.0f;
            if (scoreClusters > f) {
                f = scoreClusters;
                iArr = iArr2;
                int[] iArr3 = new int[this.num_clusters];
                int i3 = 0;
                ListIterator listIterator = vector.listIterator();
                while (listIterator.hasNext()) {
                    if (!((List) listIterator.next()).isEmpty()) {
                        int i4 = i3;
                        i3++;
                        iArr3[listIterator.previousIndex()] = i4;
                    }
                }
                for (int i5 = 0; i5 < network.Nodes().size(); i5++) {
                    if (network.isModuleMember(i5)) {
                        iArr[i5] = iArr3[iArr[i5]];
                    }
                }
            }
            this.current_run++;
        }
        return new ModuleAssignment(network, iArr);
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    public static List<List<Number>> performKMeans(List<List<Number>> list, List<List<List<Number>>> list2, List<List<Number>> list3, int[] iArr, int i) {
        for (int i2 = 0; i2 < i && !Thread.interrupted(); i2++) {
            if (i2 > 0 || list3 == null) {
                list3 = new Vector();
                ListIterator<List<List<Number>>> listIterator = list2.listIterator();
                while (listIterator.hasNext()) {
                    list3.add(compute_center(listIterator.next()));
                }
            }
            ListIterator<List<List<Number>>> listIterator2 = list2.listIterator();
            while (listIterator2.hasNext()) {
                listIterator2.next().clear();
            }
            int i3 = 0;
            for (int i4 = 0; i4 < iArr.length; i4++) {
                if (iArr[i4] >= 0) {
                    int closest_point = closest_point(list3, list.get(i4));
                    if (closest_point != iArr[i4]) {
                        i3++;
                    }
                    iArr[i4] = closest_point;
                    list2.get(iArr[i4]).add(list.get(i4));
                }
            }
            if (i3 == 0) {
                break;
            }
        }
        return list3;
    }

    private static float scoreClusters(List<List<List<Number>>> list) {
        NormalGammaStat.Stat stat = new NormalGammaStat.Stat();
        Vector vector = new Vector();
        ListIterator<List<List<Number>>> listIterator = list.listIterator();
        while (listIterator.hasNext()) {
            List<List<Number>> next = listIterator.next();
            if (!next.isEmpty()) {
                vector.add(compute_center(next));
                ListIterator<List<Number>> listIterator2 = next.listIterator();
                while (listIterator2.hasNext()) {
                    List<Number> next2 = listIterator2.next();
                    ListIterator<List<Number>> listIterator3 = next.listIterator(listIterator2.nextIndex());
                    while (listIterator3.hasNext()) {
                        stat.add((float) distance(next2, listIterator3.next()));
                    }
                }
            }
        }
        NormalGammaStat.Stat stat2 = new NormalGammaStat.Stat();
        ListIterator listIterator4 = vector.listIterator();
        while (listIterator4.hasNext()) {
            List list2 = (List) listIterator4.next();
            ListIterator listIterator5 = vector.listIterator(listIterator4.nextIndex());
            while (listIterator5.hasNext()) {
                stat2.add((float) distance(list2, (List) listIterator5.next()));
            }
        }
        return stat2.mean() / stat.mean();
    }

    static List<Number> compute_center(List<List<Number>> list) {
        if (list.isEmpty()) {
            return null;
        }
        Vector vector = new Vector();
        for (int i = 0; i < list.get(0).size(); i++) {
            NormalGammaStat.Stat stat = new NormalGammaStat.Stat();
            ListIterator<List<Number>> listIterator = list.listIterator();
            while (listIterator.hasNext()) {
                stat.add(listIterator.next().get(i).floatValue());
            }
            vector.add(new Float(stat.mean()));
        }
        return vector;
    }

    /* JADX WARN: Type inference failed for: r0v25, types: [double[], java.lang.Object[], double[][]] */
    public static double distance(List<Number> list, List<Number> list2) {
        if (list == null || list2 == null) {
            return Double.POSITIVE_INFINITY;
        }
        if (!pearson) {
            NormalGammaStat.Stat stat = new NormalGammaStat.Stat();
            ListIterator<Number> listIterator = list.listIterator();
            ListIterator<Number> listIterator2 = list2.listIterator();
            while (listIterator.hasNext()) {
                stat.add(listIterator.next().floatValue() - listIterator2.next().floatValue());
            }
            return stat.sumsq / stat.count;
        }
        Vector vector = new Vector();
        ListIterator<Number> listIterator3 = list.listIterator();
        ListIterator<Number> listIterator4 = list2.listIterator();
        while (listIterator3.hasNext()) {
            double[] dArr = {listIterator3.next().doubleValue(), listIterator4.next().doubleValue()};
            if (!Double.isNaN(dArr[0]) && !Double.isNaN(dArr[1])) {
                vector.add(dArr);
            }
        }
        ?? r0 = new double[vector.size()];
        vector.toArray((Object[]) r0);
        if (regression == null) {
            regression = new SimpleRegression();
        } else {
            regression.clear();
        }
        regression.addData((double[][]) r0);
        if (regression.getR() < 0.0d) {
            return Double.POSITIVE_INFINITY;
        }
        return 1.0d / regression.getR();
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    public static int closest_point(List<List<Number>> list, List<Number> list2) {
        double d = Double.POSITIVE_INFINITY;
        int i = -1;
        ListIterator<List<Number>> listIterator = list.listIterator();
        while (listIterator.hasNext()) {
            double distance = distance(listIterator.next(), list2);
            if (distance < d) {
                d = distance;
                i = listIterator.previousIndex();
            }
        }
        return i;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    public static double[] distances(List<List<Number>> list, List<Number> list2) {
        double[] dArr = new double[list.size()];
        int i = 0;
        ListIterator<List<Number>> listIterator = list.listIterator();
        while (listIterator.hasNext()) {
            int i2 = i;
            i++;
            dArr[i2] = distance(listIterator.next(), list2);
        }
        return dArr;
    }

    @Override // biolearn.BiolearnComponent
    public void WriteRecord(PrintStream printStream, boolean z) throws IOException {
        if (z) {
            printStream.print("# ");
        }
        printStream.print("ModuleInitiation RandomKMeansAssignment");
        ListIterator<String> listIterator = this.args_cache.listIterator();
        while (listIterator.hasNext()) {
            printStream.print(" " + listIterator.next());
        }
        printStream.println();
    }
}
