/*
 * Decompiled with CFR 0.152.
 */
package smile.clustering;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.clustering.ClusteringDistance;
import smile.clustering.PartitionClustering;
import smile.data.SparseDataset;
import smile.math.Math;
import smile.math.SparseArray;
import smile.util.MulticoreExecutor;

public class SIB
extends PartitionClustering<double[]> {
    private static final Logger logger = LoggerFactory.getLogger(SIB.class);
    private double distortion;
    private double[][] centroids;

    /**
     * Clusters dense data into {@code k} groups, delegating to
     * {@link #SIB(double[][], int, int)} with a cap of 100 iterations.
     *
     * @param data the observations, one row per sample.
     * @param k the number of clusters; the delegate rejects values below 2.
     */
    public SIB(double[][] data, int k) {
        this(data, k, 100);
    }

    /**
     * Clusters dense data into {@code k} groups with a single run.
     *
     * <p>The labels are seeded with
     * {@code seed(data, k, ClusteringDistance.JENSEN_SHANNON_DIVERGENCE)}, the
     * centroids are set to the per-cluster means, and then every row is moved,
     * one at a time, to the centroid with the smallest Jensen-Shannon
     * divergence. Centroids and sizes are updated immediately after each move.
     * The sweep repeats until a full pass moves nothing or {@code maxIter}
     * passes have been made. Finally {@code distortion} is set to the sum of
     * the divergences between each row and its own centroid.
     *
     * @param data the observations, one row per sample; all rows are read with
     *             the column count of {@code data[0]}.
     * @param k the number of clusters.
     * @param maxIter the maximum number of passes over the data.
     * @throws IllegalArgumentException if {@code k < 2} or {@code maxIter <= 0}.
     */
    public SIB(double[][] data, int k, int maxIter) {
        if (k < 2) {
            throw new IllegalArgumentException("Invalid parameter k = " + k);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }

        int n = data.length;
        int d = data[0].length;

        this.k = k;
        this.size = new int[k];
        this.centroids = new double[k][d];
        this.y = seed(data, k, ClusteringDistance.JENSEN_SHANNON_DIVERGENCE);

        // Accumulate per-cluster member counts and coordinate sums.
        for (int i = 0; i < n; ++i) {
            int cluster = y[i];
            size[cluster]++;
            double[] sum = centroids[cluster];
            for (int j = 0; j < d; ++j) {
                sum[j] += data[i][j];
            }
        }

        // Turn the sums into means. A cluster that received no member from the
        // seeding keeps an all-zero centroid instead of becoming NaN (0 / 0),
        // which is the same value an emptied cluster gets further below.
        for (int i = 0; i < k; ++i) {
            if (size[i] == 0) {
                continue;
            }
            for (int j = 0; j < d; ++j) {
                centroids[i][j] /= size[i];
            }
        }

        int reassignment = n;
        for (int iter = 1; iter <= maxIter && reassignment > 0; ++iter) {
            reassignment = 0;
            for (int i = 0; i < n; ++i) {
                double nearest = Double.MAX_VALUE;
                int c = -1;
                for (int j = 0; j < k; ++j) {
                    double dist = Math.JensenShannonDivergence(data[i], centroids[j]);
                    if (nearest > dist) {
                        nearest = dist;
                        c = j;
                    }
                }

                // c stays -1 when no divergence compares below Double.MAX_VALUE
                // (NaN or infinite results); leave such a row where it is
                // rather than indexing the arrays with -1.
                if (c < 0 || c == y[i]) {
                    continue;
                }

                int o = y[i];
                if (size[o] > 1) {
                    // Remove the row from the running mean of its old cluster.
                    int remaining = size[o] - 1;
                    for (int j = 0; j < d; ++j) {
                        double value = (centroids[o][j] * size[o] - data[i][j]) / remaining;
                        // Clamp negative results to zero.
                        centroids[o][j] = value < 0.0 ? 0.0 : value;
                    }
                } else {
                    // The old cluster loses its last member.
                    Arrays.fill(centroids[o], 0.0);
                }

                // Add the row to the running mean of its new cluster.
                int grown = size[c] + 1;
                for (int j = 0; j < d; ++j) {
                    centroids[c][j] = (centroids[c][j] * size[c] + data[i][j]) / grown;
                }

                size[o]--;
                size[c]++;
                y[i] = c;
                ++reassignment;
            }
        }

        this.distortion = 0.0;
        for (int i = 0; i < n; ++i) {
            this.distortion += Math.JensenShannonDivergence(data[i], centroids[y[i]]);
        }
    }

    /**
     * Clusters dense data {@code runs} times and keeps the run with the
     * smallest distortion.
     *
     * <p>The runs are submitted to {@link MulticoreExecutor}. If that throws,
     * the failure is logged and the same number of runs is executed one after
     * another on the calling thread instead.
     *
     * @param data the observations, one row per sample.
     * @param k the number of clusters.
     * @param maxIter the maximum number of passes per run.
     * @param runs the number of independent runs.
     * @throws IllegalArgumentException if {@code k < 2}, {@code maxIter <= 0}
     *         or {@code runs <= 0}.
     */
    public SIB(double[][] data, int k, int maxIter, int runs) {
        if (k < 2) {
            throw new IllegalArgumentException("Invalid number of clusters: " + k);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }
        if (runs <= 0) {
            throw new IllegalArgumentException("Invalid number of runs: " + runs);
        }

        List<SIBThread> tasks = new ArrayList<SIBThread>(runs);
        for (int i = 0; i < runs; ++i) {
            tasks.add(new SIBThread(data, k, maxIter));
        }

        SIB best;
        try {
            // Typed result list: no raw List and no unchecked casts.
            List<SIB> clusters = MulticoreExecutor.run(tasks);
            best = clusters.get(0);
            for (int i = 1; i < runs; ++i) {
                SIB candidate = clusters.get(i);
                if (candidate.distortion < best.distortion) {
                    best = candidate;
                }
            }
        } catch (Exception ex) {
            logger.error("Failed to run Sequential Information Bottleneck on multi-core", ex);

            // Fall back to running every attempt on the calling thread.
            best = new SIB(data, k, maxIter);
            for (int i = 1; i < runs; ++i) {
                SIB candidate = new SIB(data, k, maxIter);
                if (candidate.distortion < best.distortion) {
                    best = candidate;
                }
            }
        }

        // Adopt the state of the winning run.
        this.k = best.k;
        this.distortion = best.distortion;
        this.centroids = best.centroids;
        this.y = best.y;
        this.size = best.size;
    }

    /**
     * Produces initial cluster labels for a sparse dataset.
     *
     * <p>The first centre is a uniformly random row. Each further centre is
     * drawn with probability proportional to the current Jensen-Shannon
     * divergence of a row to its closest centre so far. While centres are
     * added, every row is labelled with the index of the centre it is closest
     * to.
     *
     * @param data the sparse dataset to seed.
     * @param k the number of centres to pick.
     * @return an array of {@code data.size()} labels in {@code [0, k)}.
     */
    private static int[] seed(SparseDataset data, int k) {
        int n = data.size();
        int[] y = new int[n];
        SparseArray centroid = (SparseArray) data.get(Math.randomInt(n)).x;

        // D[j] is the divergence from row j to the closest centre chosen so far.
        double[] D = new double[n];
        Arrays.fill(D, Double.MAX_VALUE);

        for (int i = 1; i < k; ++i) {
            // Fold the most recently chosen centre (label i - 1) into D and y.
            for (int j = 0; j < n; ++j) {
                double dist = Math.JensenShannonDivergence((SparseArray) data.get(j).x, centroid);
                if (dist < D[j]) {
                    D[j] = dist;
                    y[j] = i - 1;
                }
            }

            // Roulette-wheel selection of the next centre, weighted by D.
            double cutoff = Math.random() * Math.sum(D);
            double cost = 0.0;
            int index = 0;
            for (; index < n; ++index) {
                cost += D[index];
                if (cost >= cutoff) {
                    break;
                }
            }

            // The scan can run off the end when the running total never
            // compares >= cutoff (e.g. a NaN cutoff or floating-point
            // rounding); fall back to the last row instead of reading
            // past the dataset.
            if (index >= n) {
                index = n - 1;
            }
            centroid = (SparseArray) data.get(index).x;
        }

        // Fold the final centre (label k - 1) into the labelling.
        for (int j = 0; j < n; ++j) {
            double dist = Math.JensenShannonDivergence((SparseArray) data.get(j).x, centroid);
            if (dist < D[j]) {
                D[j] = dist;
                y[j] = k - 1;
            }
        }
        return y;
    }

    /**
     * Clusters a sparse dataset into {@code k} groups, delegating to
     * {@link #SIB(SparseDataset, int, int)} with a cap of 100 iterations.
     *
     * @param data the sparse dataset to cluster.
     * @param k the number of clusters; the delegate rejects values below 2.
     */
    public SIB(SparseDataset data, int k) {
        this(data, k, 100);
    }

    /**
     * Clusters a sparse dataset into {@code k} groups with a single run.
     *
     * <p>The labels come from {@code seed(data, k)} and the centroids are the
     * dense per-cluster means over {@code data.ncols()} columns. Every row is
     * then moved, one at a time, to the centroid with the smallest
     * Jensen-Shannon divergence, with both affected centroids updated
     * immediately. The sweep repeats until a full pass moves nothing or
     * {@code maxIter} passes have been made. Finally {@code distortion} is set
     * to the sum of the divergences between each row and its own centroid.
     *
     * @param data the sparse dataset to cluster.
     * @param k the number of clusters.
     * @param maxIter the maximum number of passes over the data.
     * @throws IllegalArgumentException if {@code k < 2} or {@code maxIter <= 0}.
     */
    public SIB(SparseDataset data, int k, int maxIter) {
        if (k < 2) {
            throw new IllegalArgumentException("Invalid parameter k = " + k);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }

        int n = data.size();
        int d = data.ncols();

        this.k = k;
        this.size = new int[k];
        this.centroids = new double[k][d];
        this.y = seed(data, k);

        // Accumulate per-cluster member counts and coordinate sums.
        for (int i = 0; i < n; ++i) {
            int cluster = y[i];
            size[cluster]++;
            double[] sum = centroids[cluster];
            for (SparseArray.Entry e : (SparseArray) data.get(i).x) {
                sum[e.i] += e.x;
            }
        }

        // Turn the sums into means. A cluster that received no member from the
        // seeding keeps an all-zero centroid instead of becoming NaN (0 / 0).
        for (int i = 0; i < k; ++i) {
            if (size[i] == 0) {
                continue;
            }
            for (int j = 0; j < d; ++j) {
                centroids[i][j] /= size[i];
            }
        }

        int reassignment = n;
        for (int iter = 1; iter <= maxIter && reassignment > 0; ++iter) {
            reassignment = 0;
            for (int i = 0; i < n; ++i) {
                SparseArray row = (SparseArray) data.get(i).x;

                double nearest = Double.MAX_VALUE;
                int c = -1;
                for (int j = 0; j < k; ++j) {
                    double dist = Math.JensenShannonDivergence(row, centroids[j]);
                    if (nearest > dist) {
                        nearest = dist;
                        c = j;
                    }
                }

                // c stays -1 when no divergence compares below Double.MAX_VALUE
                // (NaN or infinite results); leave such a row where it is
                // rather than indexing the arrays with -1.
                if (c < 0 || c == y[i]) {
                    continue;
                }

                int o = y[i];

                // Scale both centroids back from means to sums.
                for (int j = 0; j < d; ++j) {
                    centroids[c][j] *= size[c];
                    centroids[o][j] *= size[o];
                }

                // Move the row's mass from the old cluster to the new one.
                for (SparseArray.Entry e : row) {
                    int col = e.i;
                    double p = e.x;
                    centroids[c][col] += p;
                    centroids[o][col] -= p;
                    // Clamp negative results to zero.
                    if (centroids[o][col] < 0.0) {
                        centroids[o][col] = 0.0;
                    }
                }

                size[o]--;
                size[c]++;

                // Scale back from sums to means; an emptied old cluster is
                // left as the (clamped) sums to avoid dividing by zero.
                for (int j = 0; j < d; ++j) {
                    centroids[c][j] /= size[c];
                }
                if (size[o] > 0) {
                    for (int j = 0; j < d; ++j) {
                        centroids[o][j] /= size[o];
                    }
                }

                y[i] = c;
                ++reassignment;
            }
        }

        this.distortion = 0.0;
        for (int i = 0; i < n; ++i) {
            this.distortion += Math.JensenShannonDivergence((SparseArray) data.get(i).x, centroids[y[i]]);
        }
    }

    /**
     * Clusters a sparse dataset {@code runs} times and keeps the run with the
     * smallest distortion.
     *
     * <p>The runs are submitted to {@link MulticoreExecutor}. If that throws,
     * the failure is logged and the same number of runs is executed one after
     * another on the calling thread instead.
     *
     * @param data the sparse dataset to cluster.
     * @param k the number of clusters.
     * @param maxIter the maximum number of passes per run.
     * @param runs the number of independent runs.
     * @throws IllegalArgumentException if {@code k < 2}, {@code maxIter <= 0}
     *         or {@code runs <= 0}.
     */
    public SIB(SparseDataset data, int k, int maxIter, int runs) {
        if (k < 2) {
            throw new IllegalArgumentException("Invalid number of clusters: " + k);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }
        if (runs <= 0) {
            throw new IllegalArgumentException("Invalid number of runs: " + runs);
        }

        List<SIBThread> tasks = new ArrayList<SIBThread>(runs);
        for (int i = 0; i < runs; ++i) {
            tasks.add(new SIBThread(data, k, maxIter));
        }

        SIB best;
        try {
            // Typed result list: no raw List and no unchecked casts.
            List<SIB> clusters = MulticoreExecutor.run(tasks);
            best = clusters.get(0);
            for (int i = 1; i < runs; ++i) {
                SIB candidate = clusters.get(i);
                if (candidate.distortion < best.distortion) {
                    best = candidate;
                }
            }
        } catch (Exception ex) {
            logger.error("Failed to run Sequential Information Bottleneck on multi-core", ex);

            // Fall back to running every attempt on the calling thread.
            best = new SIB(data, k, maxIter);
            for (int i = 1; i < runs; ++i) {
                SIB candidate = new SIB(data, k, maxIter);
                if (candidate.distortion < best.distortion) {
                    best = candidate;
                }
            }
        }

        // Adopt the state of the winning run.
        this.k = best.k;
        this.distortion = best.distortion;
        this.centroids = best.centroids;
        this.y = best.y;
        this.size = best.size;
    }

    /**
     * Assigns a dense vector to the centroid with the smallest Jensen-Shannon
     * divergence.
     *
     * @param x the vector to classify.
     * @return the index of the closest centroid; ties keep the lowest index,
     *         and 0 is returned when no divergence is below
     *         {@code Double.MAX_VALUE}.
     */
    @Override
    public int predict(double[] x) {
        int label = 0;
        double smallest = Double.MAX_VALUE;
        for (int c = 0; c < k; ++c) {
            double divergence = Math.JensenShannonDivergence(x, centroids[c]);
            if (divergence < smallest) {
                smallest = divergence;
                label = c;
            }
        }
        return label;
    }

    /**
     * Assigns a sparse vector to the centroid with the smallest Jensen-Shannon
     * divergence.
     *
     * @param x the sparse vector to classify.
     * @return the index of the closest centroid; ties keep the lowest index,
     *         and 0 is returned when no divergence is below
     *         {@code Double.MAX_VALUE}.
     */
    @Override
    public int predict(SparseArray x) {
        int label = 0;
        double smallest = Double.MAX_VALUE;
        for (int c = 0; c < k; ++c) {
            double divergence = Math.JensenShannonDivergence(x, centroids[c]);
            if (divergence < smallest) {
                smallest = divergence;
                label = c;
            }
        }
        return label;
    }

    /**
     * Returns the distortion stored by the constructor: the sum of the
     * Jensen-Shannon divergences between each training row and the centroid
     * of the cluster it was assigned to.
     *
     * @return the stored distortion.
     */
    public double distortion() {
        return distortion;
    }

    /**
     * Returns the cluster centroids, one row per cluster.
     *
     * <p>The internal array itself is returned, not a copy, so changes made
     * by the caller are visible to this model.
     *
     * @return the centroid matrix.
     */
    public double[][] centroids() {
        return centroids;
    }

    /**
     * Builds a multi-line summary: the distortion, the number of data points
     * and their dimension, then one line per cluster with its index, its size
     * and its share of the data as a percentage with one decimal place.
     *
     * @return the formatted summary.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(String.format("Sequential Information Bottleneck distortion: %.5f\n", distortion));
        sb.append(String.format("Clusters of %d data points of dimension %d:\n", y.length, centroids[0].length));
        for (int i = 0; i < k; ++i) {
            // Share of the data in tenths of a percent; split below into the
            // integer part (r / 10) and the single decimal digit (r % 10).
            int r = (int) Math.round(1000.0 * size[i] / y.length);
            sb.append(String.format("%3d\t%5d (%2d.%1d%%)\n", i, size[i], r / 10, r % 10));
        }
        return sb.toString();
    }

    /**
     * Task that performs one clustering run so that several runs can be
     * handed to an executor. Exactly one of {@code data} and {@code sparse}
     * is set, depending on which constructor was used.
     */
    static class SIBThread implements Callable<SIB> {
        /** Dense input; stays {@code null} when built from a sparse dataset. */
        double[][] data;
        /** Sparse input; stays {@code null} when built from dense data. */
        SparseDataset sparse;
        /** Number of clusters passed on to the run. */
        final int k;
        /** Maximum number of iterations passed on to the run. */
        final int maxIter;

        /** Creates a task that clusters dense data. */
        SIBThread(double[][] data, int k, int maxIter) {
            this.k = k;
            this.maxIter = maxIter;
            this.data = data;
        }

        /** Creates a task that clusters a sparse dataset. */
        SIBThread(SparseDataset sparse, int k, int maxIter) {
            this.k = k;
            this.maxIter = maxIter;
            this.sparse = sparse;
        }

        /**
         * Runs the clustering on whichever input this task holds, preferring
         * the dense data when it is present.
         *
         * @return the fitted model.
         */
        @Override
        public SIB call() {
            return data != null
                    ? new SIB(data, k, maxIter)
                    : new SIB(sparse, k, maxIter);
        }
    }
}

