/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.sequences;

import edu.stanford.nlp.ling.HasWord;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.sequences.BestSequenceFinder;
import edu.stanford.nlp.sequences.CoolingSchedule;
import edu.stanford.nlp.sequences.SequenceListener;
import edu.stanford.nlp.sequences.SequenceModel;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.StringUtils;
import edu.stanford.nlp.util.concurrent.MulticoreWrapper;
import edu.stanford.nlp.util.concurrent.ThreadsafeProcessor;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

/**
 * Gibbs sampler over label sequences of a {@link SequenceModel}.
 *
 * <p>A single sampling step redraws the label at one position from the model's
 * conditional distribution at that position ({@link SequenceModel#scoresOf},
 * log-normalized and optionally tempered). Positions are visited according to the
 * configured sampling style:
 * <ul>
 *   <li>{@code SEQUENTIAL_SAMPLING} (1): left to right, once each;</li>
 *   <li>{@code RANDOM_SAMPLING} (0): {@code sequence.length} positions drawn uniformly
 *       with replacement;</li>
 *   <li>{@code CHROMATIC_SAMPLING} (2): group by group from a caller-supplied partition;
 *       groups larger than {@code chromaticSize} are sampled on {@code chromaticSize}
 *       worker threads.</li>
 * </ul>
 * Any other style value makes a sweep a no-op.
 *
 * <p>All instances share one static, fixed-seed {@link Random}. Results therefore depend
 * on the global order of calls.
 */
public class SequenceGibbsSampler implements BestSequenceFinder {

    /** Shared generator; fixed seed so identical call orders give identical draws. */
    private static final Random random = new Random(Integer.MAX_VALUE);

    /** Diagnostic level: {@code > 0} prints progress to stderr, {@code > 1} also prints collected samples. */
    public static int verbose = 0;

    private static final int RANDOM_SAMPLING = 0;
    private static final int SEQUENTIAL_SAMPLING = 1;
    private static final int CHROMATIC_SAMPLING = 2;

    /** Tokens used only as row labels by {@link #printSamples}; may be {@code null}. */
    private final List<?> document;
    /** Number of samples drawn by {@link #bestSequence}. */
    private final int numSamples;
    /** Full forward sweeps performed between consecutive samples in {@link #bestSequence}. */
    private final int sampleInterval;
    /** Receives the starting sequence and every single-position update made by {@link #samplePosition}. */
    private final SequenceListener listener;

    /** If true, annealing returns its final sample instead of scoring and keeping the best one. */
    public boolean returnLastFoundSequence = false;

    /** One of the {@code *_SAMPLING} constants. */
    private final int samplingStyle;
    /** Chromatic mode only: group-size threshold for going parallel and the worker-thread count. */
    private final int chromaticSize;
    /**
     * Chromatic mode only: groups of positions sampled together. Positions in one group are all
     * sampled against the same sequence state, so they are presumably meant to be mutually
     * independent (e.g. one colour class of a graph colouring). NOTE(review): must be non-null
     * and {@code chromaticSize} must be positive when chromatic sampling is used.
     */
    private final List<List<Integer>> partition;

    /**
     * Returns a copy of {@code a}.
     *
     * @param a array to copy; must not be {@code null}
     * @return a new array with the same contents
     */
    public static int[] copy(int[] a) {
        return a.clone();
    }

    /**
     * Builds a starting sequence by picking, independently at each position, one of the
     * values the model allows there, uniformly at random.
     *
     * @param model model supplying the sequence length and per-position candidate values
     * @return a fresh random sequence of length {@code model.length()}
     */
    public static int[] getRandomSequence(SequenceModel model) {
        int[] result = new int[model.length()];
        for (int i = 0; i < result.length; ++i) {
            int[] classes = model.getPossibleValues(i);
            result[i] = classes[random.nextInt(classes.length)];
        }
        return result;
    }

    /**
     * Runs the sampler from a random start, using this instance's {@code numSamples} and
     * {@code sampleInterval}, and returns the highest-scoring sample.
     */
    @Override
    public int[] bestSequence(SequenceModel model) {
        int[] initialSequence = getRandomSequence(model);
        return findBestUsingSampling(model, numSamples, sampleInterval, initialSequence);
    }

    /**
     * Collects samples from {@code initialSequence} and returns the one with the highest
     * {@link SequenceModel#scoreOf} score. On ties, the earliest sample is kept.
     *
     * @return the best sample, or {@code null} when no sample was drawn or no sample scored
     *     above negative infinity
     */
    public int[] findBestUsingSampling(SequenceModel model, int numSamples, int sampleInterval, int[] initialSequence) {
        List<int[]> samples = collectSamples(model, numSamples, sampleInterval, initialSequence);
        int[] best = null;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int[] sequence : samples) {
            double score = model.scoreOf(sequence);
            // NaN scores fail this comparison and are never selected.
            if (score > bestScore) {
                best = sequence;
                bestScore = score;
                if (verbose > 0) {
                    System.err.println("found new best (" + bestScore + ")");
                    System.err.println(ArrayMath.toString(best));
                }
            }
        }
        return best;
    }

    /** Simulated annealing from a random start; see the three-argument overload. */
    public int[] findBestUsingAnnealing(SequenceModel model, CoolingSchedule schedule) {
        int[] initialSequence = getRandomSequence(model);
        return findBestUsingAnnealing(model, schedule, initialSequence);
    }

    /**
     * Simulated annealing. For each iteration {@code i} of {@code schedule}, one forward sweep
     * is performed at temperature {@code schedule.getTemperature(i)}.
     *
     * @param initialSequence starting point; not modified (each sweep works on a fresh copy)
     * @return the last sample if {@link #returnLastFoundSequence} is set, otherwise the
     *     highest-scoring sample; {@code null} if the schedule has no iterations
     */
    public int[] findBestUsingAnnealing(SequenceModel model, CoolingSchedule schedule, int[] initialSequence) {
        if (verbose > 0) {
            System.err.println("Doing annealing");
        }
        listener.setInitialSequence(initialSequence);
        List<int[]> result = new ArrayList<int[]>();
        int[] sequence = initialSequence;
        int[] best = null;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < schedule.numIterations(); ++i) {
            // Fresh copy per sweep so stored samples (and the caller's array) stay intact.
            sequence = copy(sequence);
            sampleSequenceForward(model, sequence, schedule.getTemperature(i));
            result.add(sequence);
            if (returnLastFoundSequence) {
                best = sequence;
            } else {
                double score = model.scoreOf(sequence);
                if (score > bestScore) {
                    best = sequence;
                    bestScore = score;
                }
            }
            if (verbose > 0) {
                System.err.print(".");
            }
        }
        if (verbose > 1) {
            System.err.println();
            printSamples(result, System.err);
        }
        if (verbose > 0) {
            System.err.println("done.");
        }
        return best;
    }

    /** Collects samples from a random start; see the four-argument overload. */
    public List<int[]> collectSamples(SequenceModel model, int numSamples, int sampleInterval) {
        int[] initialSequence = getRandomSequence(model);
        return collectSamples(model, numSamples, sampleInterval, initialSequence);
    }

    /**
     * Runs one Markov chain from {@code initialSequence}. A sample is recorded after every
     * {@code sampleInterval} forward sweeps.
     *
     * @param initialSequence starting point; not modified
     * @return {@code numSamples} samples in chain order, each a distinct array
     */
    public List<int[]> collectSamples(SequenceModel model, int numSamples, int sampleInterval, int[] initialSequence) {
        if (verbose > 0) {
            System.err.print("Collecting samples");
        }
        listener.setInitialSequence(initialSequence);
        List<int[]> result = new ArrayList<int[]>();
        int[] sequence = initialSequence;
        for (int i = 0; i < numSamples; ++i) {
            // Continue the chain on a copy of the previous sample. The sweeps must mutate
            // this copy: the public sampleSequenceRepeatedly samples a private copy and
            // discards it, which previously made every stored sample equal the start.
            sequence = copy(sequence);
            sampleInPlace(model, sequence, sampleInterval);
            result.add(sequence);
            if (verbose > 0) {
                System.err.print(".");
                System.err.flush();
            }
        }
        if (verbose > 1) {
            System.err.println();
            printSamples(result, System.err);
        }
        if (verbose > 0) {
            System.err.println("done.");
        }
        return result;
    }

    /**
     * Performs {@code numSamples} forward sweeps on a private copy of {@code sequence}.
     * The caller's array is not modified; the sweeps are observable only through the listener.
     */
    public void sampleSequenceRepeatedly(SequenceModel model, int[] sequence, int numSamples) {
        sampleInPlace(model, copy(sequence), numSamples);
    }

    /** Performs {@code numSamples} forward sweeps starting from a random sequence. */
    public void sampleSequenceRepeatedly(SequenceModel model, int numSamples) {
        int[] sequence = getRandomSequence(model);
        sampleSequenceRepeatedly(model, sequence, numSamples);
    }

    /** Resets the listener to {@code sequence}, then performs {@code numSweeps} forward sweeps that mutate it. */
    private void sampleInPlace(SequenceModel model, int[] sequence, int numSweeps) {
        listener.setInitialSequence(sequence);
        for (int sweep = 0; sweep < numSweeps; ++sweep) {
            sampleSequenceForward(model, sequence);
        }
    }

    /** One forward sweep at temperature 1.0; modifies {@code sequence} in place. */
    public void sampleSequenceForward(SequenceModel model, int[] sequence) {
        sampleSequenceForward(model, sequence, 1.0);
    }

    /**
     * One sweep over {@code sequence} in the configured sampling style, modifying it in place.
     *
     * @param temperature scores are divided by this before normalising; 0 means argmax
     */
    public void sampleSequenceForward(final SequenceModel model, final int[] sequence, final double temperature) {
        switch (samplingStyle) {
            case SEQUENTIAL_SAMPLING:
                for (int pos = 0; pos < sequence.length; ++pos) {
                    samplePosition(model, sequence, pos, temperature);
                }
                break;
            case RANDOM_SAMPLING:
                for (int itr = 0; itr < sequence.length; ++itr) {
                    int pos = random.nextInt(sequence.length);
                    samplePosition(model, sequence, pos, temperature);
                }
                break;
            case CHROMATIC_SAMPLING:
                sampleChromatic(model, sequence, temperature);
                break;
            default:
                // Unknown style: nothing is sampled (unchanged legacy behaviour).
                break;
        }
    }

    /**
     * Chromatic sweep: processes each group in {@code partition} in order. Small groups are
     * sampled sequentially; larger groups are split into chunks and sampled by a
     * {@link MulticoreWrapper} with {@code chromaticSize} threads. The new labels for a
     * large group are written back only after all its chunks finish.
     *
     * <p>NOTE(review): unlike {@link #samplePosition}, this path does not call
     * {@code listener.updateSequenceElement}; the listener is not told about these changes.
     */
    private void sampleChromatic(final SequenceModel model, final int[] sequence, final double temperature) {
        for (List<Integer> indieList : partition) {
            if (indieList.size() <= chromaticSize) {
                for (int pos : indieList) {
                    sequence[pos] = samplePositionHelper(model, sequence, pos, temperature).first();
                }
                continue;
            }
            MulticoreWrapper<List<Integer>, List<Pair<Integer, Integer>>> wrapper =
                new MulticoreWrapper<List<Integer>, List<Pair<Integer, Integer>>>(chromaticSize,
                    new ThreadsafeProcessor<List<Integer>, List<Pair<Integer, Integer>>>() {

                        @Override
                        public List<Pair<Integer, Integer>> process(List<Integer> posList) {
                            // Each pair is (position, newly sampled label); sequence is only read here.
                            List<Pair<Integer, Integer>> allPos = new ArrayList<Pair<Integer, Integer>>(posList.size());
                            for (int pos : posList) {
                                Pair<Integer, Double> newPosProb =
                                    SequenceGibbsSampler.this.samplePositionHelper(model, sequence, pos, temperature);
                                allPos.add(new Pair<Integer, Integer>(pos, newPosProb.first()));
                            }
                            return allPos;
                        }

                        @Override
                        public ThreadsafeProcessor<List<Integer>, List<Pair<Integer, Integer>>> newInstance() {
                            return this;
                        }
                    });
            List<Pair<Integer, Integer>> results = new ArrayList<Pair<Integer, Integer>>();
            int indieListSize = indieList.size();
            int interval = Math.max(1, indieListSize / chromaticSize);
            for (int begin = 0, end = 0; end < indieListSize; begin += interval) {
                end = Math.min(begin + interval, indieListSize);
                wrapper.submit(indieList.subList(begin, end));
                // Drain whatever results are already available while submitting.
                while (wrapper.hasNext()) {
                    results.addAll(wrapper.next());
                }
            }
            wrapper.join();
            while (wrapper.hasNext()) {
                results.addAll(wrapper.next());
            }
            for (Pair<Integer, Integer> posVal : results) {
                sequence[posVal.first()] = posVal.second();
            }
        }
    }

    /** One right-to-left sweep at temperature 1.0; modifies {@code sequence} in place. */
    public void sampleSequenceBackward(SequenceModel model, int[] sequence) {
        sampleSequenceBackward(model, sequence, 1.0);
    }

    /** One right-to-left sweep over every position at the given temperature; modifies {@code sequence} in place. */
    public void sampleSequenceBackward(SequenceModel model, int[] sequence, double temperature) {
        for (int pos = sequence.length - 1; pos >= 0; --pos) {
            samplePosition(model, sequence, pos, temperature);
        }
    }

    /** Resamples one position at temperature 1.0; see the four-argument overload. */
    public double samplePosition(SequenceModel model, int[] sequence, int pos) {
        return samplePosition(model, sequence, pos, 1.0);
    }

    /**
     * Draws a new label for {@code pos} without modifying {@code sequence}.
     *
     * <p>{@code model.scoresOf(sequence, pos)} is treated as log-scores. They are divided by
     * {@code temperature} (or, at temperature 0, collapsed onto the argmax), then
     * log-normalised and exponentiated into a distribution that is sampled once.
     *
     * @return the sampled label and its probability under that distribution
     */
    private Pair<Integer, Double> samplePositionHelper(SequenceModel model, int[] sequence, int pos, double temperature) {
        double[] distribution = model.scoresOf(sequence, pos);
        if (temperature == 0.0) {
            int argmax = ArrayMath.argmax(distribution);
            Arrays.fill(distribution, Double.NEGATIVE_INFINITY);
            distribution[argmax] = 0.0;
        } else if (temperature != 1.0) {
            ArrayMath.multiplyInPlace(distribution, 1.0 / temperature);
        }
        ArrayMath.logNormalize(distribution);
        ArrayMath.expInPlace(distribution);
        int newTag = ArrayMath.sampleFromDistribution(distribution, random);
        double newProb = distribution[newTag];
        return new Pair<Integer, Double>(newTag, newProb);
    }

    /**
     * Resamples {@code sequence[pos]} in place and reports the change to the listener
     * ({@code updateSequenceElement(sequence, pos, oldTag)}).
     *
     * @return the probability of the label that was drawn
     */
    public double samplePosition(SequenceModel model, int[] sequence, int pos, double temperature) {
        Pair<Integer, Double> newPosProb = samplePositionHelper(model, sequence, pos, temperature);
        int newTag = newPosProb.first();
        int oldTag = sequence[pos];
        sequence[pos] = newTag;
        listener.updateSequenceElement(sequence, pos, oldTag);
        return newPosProb.second();
    }

    /**
     * Prints a table with one row per position: a word label, then that position's label
     * in each sample.
     *
     * <p>When no document was supplied, the row count comes from the first sample and every
     * row is labelled "null" (previously this threw a {@link NullPointerException}).
     *
     * @param samples list of {@code int[]} samples
     * @param out2 destination stream
     */
    public void printSamples(List<?> samples, PrintStream out2) {
        int numRows;
        if (document != null) {
            numRows = document.size();
        } else if (samples.isEmpty()) {
            numRows = 0;
        } else {
            numRows = ((int[]) samples.get(0)).length;
        }
        for (int i = 0; i < numRows; ++i) {
            String s = "null";
            if (document != null) {
                HasWord word = (HasWord) document.get(i);
                if (word != null) {
                    s = word.word();
                }
            }
            out2.print(StringUtils.padOrTrim(s, 10));
            for (Object sample : samples) {
                int[] sequence = (int[]) sample;
                out2.print(" " + StringUtils.padLeft(sequence[i], 2));
            }
            out2.println();
        }
    }

    /**
     * Full constructor.
     *
     * @param numSamples samples drawn by {@link #bestSequence}
     * @param sampleInterval forward sweeps between samples in {@link #bestSequence}
     * @param listener notified of the starting sequence and of single-position updates
     * @param document tokens used as row labels when printing samples; may be {@code null}
     * @param returnLastFoundSequence see {@link #returnLastFoundSequence}
     * @param samplingStyle 0 = random, 1 = sequential, 2 = chromatic
     * @param chromaticSize chromatic mode: parallel threshold and thread count
     * @param partition chromatic mode: groups of positions to sample together
     */
    public SequenceGibbsSampler(int numSamples, int sampleInterval, SequenceListener listener, List<?> document, boolean returnLastFoundSequence, int samplingStyle, int chromaticSize, List<List<Integer>> partition) {
        this.numSamples = numSamples;
        this.sampleInterval = sampleInterval;
        this.listener = listener;
        this.document = document;
        this.returnLastFoundSequence = returnLastFoundSequence;
        this.samplingStyle = samplingStyle;
        this.chromaticSize = chromaticSize;
        this.partition = partition;
        if (verbose > 0) {
            switch (samplingStyle) {
                case RANDOM_SAMPLING:
                    System.err.println("Using random sampling");
                    break;
                case CHROMATIC_SAMPLING:
                    System.err.println("Using chromatic sampling with " + chromaticSize + " threads");
                    break;
                case SEQUENTIAL_SAMPLING:
                    System.err.println("Using sequential sampling");
                    break;
                default:
                    break;
            }
        }
    }

    /** Sequential sampling, keeping the best sample during annealing. */
    public SequenceGibbsSampler(int numSamples, int sampleInterval, SequenceListener listener, List<?> document) {
        this(numSamples, sampleInterval, listener, document, false, SEQUENTIAL_SAMPLING, 0, null);
    }

    /** Sequential sampling without a document. */
    public SequenceGibbsSampler(int numSamples, int sampleInterval, SequenceListener listener) {
        this(numSamples, sampleInterval, listener, null);
    }

    /** Explicit sampling style without a document. */
    public SequenceGibbsSampler(int numSamples, int sampleInterval, SequenceListener listener, int samplingStyle, int chromaticSize, List<List<Integer>> partition) {
        this(numSamples, sampleInterval, listener, null, false, samplingStyle, chromaticSize, partition);
    }
}

