/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import smile.classification.ClassifierTrainer;
import smile.classification.OnlineClassifier;
import smile.classification.SoftClassifier;
import smile.stat.distribution.Distribution;

/**
 * Naive Bayes classifier over real-valued feature vectors.
 *
 * <p>Three variants are supported (see {@link Model}):
 * <ul>
 *   <li>{@code GENERAL}: every (class, feature) pair has a caller-supplied {@link Distribution};
 *       the model is fixed at construction and cannot be trained online.</li>
 *   <li>{@code MULTINOMIAL}: feature values are accumulated as (possibly fractional) counts and
 *       the conditional probabilities are estimated with add-{@code sigma} smoothing.</li>
 *   <li>{@code BERNOULLI}: only feature presence ({@code x[j] > 0}) is counted; conditional
 *       probabilities are estimated with add-{@code sigma} smoothing.</li>
 * </ul>
 *
 * <p>Instances are mutable ({@link #learn} updates internal counts) and this class performs no
 * synchronization of its own.
 */
public class NaiveBayes
implements OnlineClassifier<double[]>,
SoftClassifier<double[]> {
    /** Tiny pseudo-count that keeps learned priors strictly positive for unseen classes. */
    private static final double EPSILON = 1.0E-20;

    /** Which likelihood model is used. */
    private final Model model;
    /** Number of classes. */
    private final int k;
    /** Dimension of the input vectors. */
    private final int p;
    /** Class prior probabilities; recomputed by {@link #update()} unless predefined. */
    private final double[] priori;
    /** GENERAL mode only: {@code prob[i][j]} is the distribution of feature j given class i. */
    private Distribution[][] prob;
    /** Add-k smoothing parameter (unused in GENERAL mode). */
    private double sigma;
    /** True when the priors were supplied by the caller and must not be re-estimated. */
    private final boolean predefinedPriori;
    /** Number of training samples seen so far. */
    private int n;
    /** Number of training samples seen per class. */
    private int[] nc;
    /** MULTINOMIAL only: total feature mass seen per class. */
    private double[] nt;
    /** Per-class, per-feature accumulated counts (mass for MULTINOMIAL, presence for BERNOULLI). */
    private double[][] ntc;
    /** Smoothed estimate of the per-class, per-feature conditional probability. */
    private double[][] condprob;

    /**
     * Creates a general-mode classifier from fixed class priors and per-feature distributions.
     *
     * @param priori the prior probability of each class; each entry must lie strictly inside
     *     (0, 1) and the entries must sum to one. The array is copied.
     * @param condprob {@code condprob[i][j]} is the distribution of feature {@code j} given class
     *     {@code i}; all rows must have the same positive length. The arrays are copied, the
     *     distribution objects themselves are shared.
     * @throws IllegalArgumentException if the arguments are invalid or inconsistent
     */
    public NaiveBayes(double[] priori, Distribution[][] condprob) {
        if (priori.length != condprob.length) {
            throw new IllegalArgumentException("The number of priori probabilities and that of the classes are not same.");
        }
        checkPriori(priori);
        int dim = condprob[0].length;
        if (dim <= 0) {
            throw new IllegalArgumentException("Invalid dimension: " + dim);
        }
        Distribution[][] copy = new Distribution[condprob.length][];
        for (int i = 0; i < condprob.length; ++i) {
            // Ragged rows would make predict() either skip distributions or index past the end.
            if (condprob[i].length != dim) {
                throw new IllegalArgumentException(String.format(
                        "Invalid number of conditional distributions for class %d: %d, expected: %d",
                        i, condprob[i].length, dim));
            }
            copy[i] = condprob[i].clone();
        }
        this.model = Model.GENERAL;
        this.k = priori.length;
        this.p = dim;
        this.priori = priori.clone();
        this.prob = copy;
        this.predefinedPriori = true;
    }

    /**
     * Creates a trainable classifier whose priors are estimated from data, with add-one smoothing.
     *
     * @param model the likelihood model
     * @param k the number of classes (at least 2)
     * @param p the input dimension (positive)
     * @throws IllegalArgumentException if {@code k} or {@code p} is invalid
     */
    public NaiveBayes(Model model, int k, int p) {
        this(model, k, p, 1.0);
    }

    /**
     * Creates a trainable classifier whose priors are estimated from data.
     *
     * <p>Note: no distributions are set by this constructor, so {@code Model.GENERAL} is not
     * usable here ({@link #learn} rejects it).
     *
     * @param model the likelihood model
     * @param k the number of classes (at least 2)
     * @param p the input dimension (positive)
     * @param sigma the add-k smoothing parameter; must be finite and non-negative
     * @throws IllegalArgumentException if any argument is invalid
     */
    public NaiveBayes(Model model, int k, int p, double sigma) {
        if (k < 2) {
            throw new IllegalArgumentException("Invalid number of classes: " + k);
        }
        if (p <= 0) {
            throw new IllegalArgumentException("Invalid dimension: " + p);
        }
        checkSigma(sigma);
        this.model = model;
        this.k = k;
        this.p = p;
        this.sigma = sigma;
        this.predefinedPriori = false;
        this.priori = new double[k];
        this.allocateCounts();
    }

    /**
     * Creates a trainable classifier with fixed class priors, with add-one smoothing.
     *
     * @param model the likelihood model
     * @param priori the class priors (copied); see {@link #NaiveBayes(double[], Distribution[][])}
     * @param p the input dimension (positive)
     * @throws IllegalArgumentException if any argument is invalid
     */
    public NaiveBayes(Model model, double[] priori, int p) {
        this(model, priori, p, 1.0);
    }

    /**
     * Creates a trainable classifier with fixed class priors.
     *
     * @param model the likelihood model
     * @param priori the class priors; at least two entries, each strictly inside (0, 1), summing
     *     to one. The array is copied.
     * @param p the input dimension (positive)
     * @param sigma the add-k smoothing parameter; must be finite and non-negative
     * @throws IllegalArgumentException if any argument is invalid
     */
    public NaiveBayes(Model model, double[] priori, int p, double sigma) {
        if (p <= 0) {
            throw new IllegalArgumentException("Invalid dimension: " + p);
        }
        checkSigma(sigma);
        checkPriori(priori);
        this.model = model;
        this.k = priori.length;
        this.p = p;
        this.sigma = sigma;
        this.priori = priori.clone();
        this.predefinedPriori = true;
        this.allocateCounts();
    }

    /**
     * Validates a vector of class priors.
     *
     * @throws IllegalArgumentException if there are fewer than two priors, any prior is outside
     *     the open interval (0, 1) or NaN, or the priors do not sum to one within 1e-10
     */
    private static void checkPriori(double[] priori) {
        if (priori.length < 2) {
            throw new IllegalArgumentException("Invalid number of classes: " + priori.length);
        }
        double sum = 0.0;
        for (double pr : priori) {
            // Negated range test so that NaN (for which every comparison is false) is rejected.
            if (!(pr > 0.0 && pr < 1.0)) {
                throw new IllegalArgumentException("Invalid priori probability: " + pr);
            }
            sum += pr;
        }
        if (Math.abs(sum - 1.0) > 1.0E-10) {
            throw new IllegalArgumentException("The sum of priori probabilities is not one: " + sum);
        }
    }

    /**
     * Validates the add-k smoothing parameter.
     *
     * @throws IllegalArgumentException if {@code sigma} is negative, NaN or infinite
     */
    private static void checkSigma(double sigma) {
        // An infinite sigma would make every smoothed estimate Inf/Inf = NaN.
        if (!(sigma >= 0.0) || Double.isInfinite(sigma)) {
            throw new IllegalArgumentException("Invalid add-k smoothing parameter: " + sigma);
        }
    }

    /** Allocates empty sufficient statistics for the trainable (non-GENERAL) modes. */
    private void allocateCounts() {
        this.n = 0;
        this.nc = new int[this.k];
        this.nt = new double[this.k];
        this.ntc = new double[this.k][this.p];
        this.condprob = new double[this.k][this.p];
    }

    /**
     * Returns the class prior probabilities.
     *
     * @return a copy of the priors; modifying it does not affect this classifier
     */
    public double[] getPriori() {
        return this.priori.clone();
    }

    /**
     * Updates the model with a single labelled sample.
     *
     * @param x the feature vector, of length {@code p}
     * @param y the class label, in {@code [0, k)}
     * @throws UnsupportedOperationException in GENERAL mode
     * @throws IllegalArgumentException if {@code x} has the wrong length or {@code y} is out of
     *     range; the model is left unchanged in that case
     */
    @Override
    public void learn(double[] x, int y) {
        if (this.model == Model.GENERAL) {
            throw new UnsupportedOperationException("General-mode Naive Bayes classifier doesn't support online learning.");
        }
        this.checkSample(x, y);
        this.accumulate(x, y);
        this.update();
    }

    /**
     * Updates the model with a batch of labelled samples.
     *
     * <p>All samples are validated before any of them is applied, so an invalid sample leaves the
     * model unchanged.
     *
     * @param x the feature vectors, each of length {@code p}
     * @param y the class labels, each in {@code [0, k)}; same length as {@code x}
     * @throws UnsupportedOperationException in GENERAL mode
     * @throws IllegalArgumentException if the sizes disagree or any sample is invalid
     */
    public void learn(double[][] x, int[] y) {
        if (this.model == Model.GENERAL) {
            throw new UnsupportedOperationException("General-mode Naive Bayes classifier doesn't support online learning.");
        }
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        for (int i = 0; i < x.length; ++i) {
            this.checkSample(x[i], y[i]);
        }
        for (int i = 0; i < x.length; ++i) {
            this.accumulate(x[i], y[i]);
        }
        this.update();
    }

    /** Rejects a training sample whose vector length or class label is invalid. */
    private void checkSample(double[] x, int y) {
        if (x.length != this.p) {
            throw new IllegalArgumentException("Invalid input vector size: " + x.length);
        }
        if (y < 0 || y >= this.k) {
            throw new IllegalArgumentException("Invalid class label: " + y);
        }
    }

    /**
     * Adds one already-validated sample to the sufficient statistics without refreshing the
     * derived probabilities; callers must invoke {@link #update()} afterwards.
     */
    private void accumulate(double[] x, int y) {
        double[] counts = this.ntc[y];
        if (this.model == Model.MULTINOMIAL) {
            // Counts are doubles so that fractional feature values are not truncated.
            for (int j = 0; j < this.p; ++j) {
                counts[j] += x[j];
                this.nt[y] += x[j];
            }
        } else {
            for (int j = 0; j < this.p; ++j) {
                if (x[j] > 0.0) {
                    counts[j] += 1.0;
                }
            }
        }
        ++this.n;
        ++this.nc[y];
    }

    /**
     * Recomputes the priors (unless predefined) and the smoothed conditional probabilities from
     * the accumulated counts.
     */
    private void update() {
        if (!this.predefinedPriori) {
            double total = this.n + this.k * EPSILON;
            for (int c = 0; c < this.k; ++c) {
                this.priori[c] = (this.nc[c] + EPSILON) / total;
            }
        }
        for (int c = 0; c < this.k; ++c) {
            // MULTINOMIAL normalizes over the p feature outcomes; BERNOULLI over present/absent.
            double denominator = this.model == Model.MULTINOMIAL
                    ? this.nt[c] + this.sigma * this.p
                    : this.nc[c] + this.sigma * 2.0;
            for (int t = 0; t < this.p; ++t) {
                this.condprob[c][t] = (this.ntc[c][t] + this.sigma) / denominator;
            }
        }
    }

    /**
     * Predicts the class of {@code x}.
     *
     * @param x the feature vector, of length {@code p}
     * @return the predicted class index, or -1 when no class can be chosen
     */
    @Override
    public int predict(double[] x) {
        return this.predict(x, null);
    }

    /**
     * Predicts the class of {@code x} and optionally fills in the posterior probabilities.
     *
     * <p>Ties are broken in favor of the lowest class index. In MULTINOMIAL and BERNOULLI mode
     * -1 is returned when no feature of {@code x} is positive; in any mode -1 is also returned
     * when no class has a log-probability above negative infinity. When -1 is returned, any
     * {@code posteriori} array holds the unnormalized log-probabilities instead of
     * probabilities.
     *
     * @param x the feature vector, of length {@code p}
     * @param posteriori optional output array of length {@code k}; may be {@code null}
     * @return the predicted class index, or -1 as described above
     * @throws IllegalArgumentException if an array has the wrong length
     */
    @Override
    public int predict(double[] x, double[] posteriori) {
        if (x.length != this.p) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d", x.length));
        }
        if (posteriori != null && posteriori.length != this.k) {
            throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, this.k));
        }
        int label = -1;
        double max = Double.NEGATIVE_INFINITY;
        boolean any = this.model == Model.GENERAL;
        for (int i = 0; i < this.k; ++i) {
            double logprob = Math.log(this.priori[i]);
            for (int j = 0; j < this.p; ++j) {
                switch (this.model) {
                    case GENERAL:
                        logprob += this.prob[i][j].logp(x[j]);
                        break;
                    case MULTINOMIAL:
                        if (x[j] > 0.0) {
                            logprob += x[j] * Math.log(this.condprob[i][j]);
                            any = true;
                        }
                        break;
                    case BERNOULLI:
                        if (x[j] > 0.0) {
                            logprob += Math.log(this.condprob[i][j]);
                            any = true;
                        } else {
                            logprob += Math.log(1.0 - this.condprob[i][j]);
                        }
                        break;
                }
            }
            if (any && logprob > max) {
                max = logprob;
                label = i;
            }
            if (posteriori != null) {
                posteriori[i] = logprob;
            }
        }
        // Normalize only when a finite maximum exists; otherwise exp(-Inf - -Inf) would be NaN.
        if (posteriori != null && label >= 0) {
            double z = 0.0;
            for (int i = 0; i < this.k; ++i) {
                posteriori[i] = Math.exp(posteriori[i] - max);
                z += posteriori[i];
            }
            for (int i = 0; i < this.k; ++i) {
                posteriori[i] /= z;
            }
        }
        return label;
    }

    /** Builder-style trainer that creates and batch-trains a {@link NaiveBayes} classifier. */
    public static class Trainer
    extends ClassifierTrainer<double[]> {
        private final Model model;
        private final int k;
        private final int p;
        /** Fixed class priors, or {@code null} to estimate them from the training data. */
        private double[] priori;
        private double sigma = 1.0;

        /**
         * Creates a trainer that estimates the class priors from data.
         *
         * @param model the likelihood model
         * @param k the number of classes (at least 2)
         * @param p the input dimension (positive)
         * @throws IllegalArgumentException if {@code k} or {@code p} is invalid
         */
        public Trainer(Model model, int k, int p) {
            if (k < 2) {
                throw new IllegalArgumentException("Invalid number of classes: " + k);
            }
            if (p <= 0) {
                throw new IllegalArgumentException("Invalid dimension: " + p);
            }
            this.model = model;
            this.k = k;
            this.p = p;
        }

        /**
         * Creates a trainer with fixed class priors.
         *
         * @param model the likelihood model
         * @param priori the class priors (copied); at least two entries, each strictly inside
         *     (0, 1), summing to one
         * @param p the input dimension (positive)
         * @throws IllegalArgumentException if any argument is invalid
         */
        public Trainer(Model model, double[] priori, int p) {
            if (p <= 0) {
                throw new IllegalArgumentException("Invalid dimension: " + p);
            }
            checkPriori(priori);
            this.model = model;
            this.priori = priori.clone();
            this.k = priori.length;
            this.p = p;
        }

        /**
         * Sets fixed class priors, or clears them with {@code null} so they are learned from data.
         *
         * @param priori the class priors (copied), or {@code null}
         * @return this trainer
         * @throws IllegalArgumentException if {@code priori} is non-null and invalid
         */
        public Trainer setPriori(double[] priori) {
            if (priori != null) {
                checkPriori(priori);
            }
            this.priori = priori == null ? null : priori.clone();
            return this;
        }

        /**
         * Sets the add-k smoothing parameter.
         *
         * @param sigma the smoothing parameter; must be finite and non-negative
         * @return this trainer
         * @throws IllegalArgumentException if {@code sigma} is invalid
         */
        public Trainer setSmooth(double sigma) {
            checkSigma(sigma);
            this.sigma = sigma;
            return this;
        }

        /**
         * Creates a classifier from the current settings and trains it on the given samples.
         *
         * @param x the feature vectors
         * @param y the class labels
         * @return the trained classifier
         */
        public NaiveBayes train(double[][] x, int[] y) {
            NaiveBayes bayes = this.priori == null
                    ? new NaiveBayes(this.model, this.k, this.p, this.sigma)
                    : new NaiveBayes(this.model, this.priori, this.p, this.sigma);
            bayes.learn(x, y);
            return bayes;
        }
    }

    /** The likelihood model used by a {@link NaiveBayes} classifier. */
    public enum Model {
        /** Per-feature {@link Distribution} objects supplied at construction; no online learning. */
        GENERAL,
        /** Feature values are treated as counts: each contributes {@code x[j] * log P(j | class)}. */
        MULTINOMIAL,
        /** Only feature presence ({@code x[j] > 0}) versus absence is modelled. */
        BERNOULLI;
    }
}

