/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.udf.generic;

import java.util.ArrayList;
import java.util.List;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.NGramEstimator;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.Text;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@Description(name="context_ngrams", value="_FUNC_(expr, array<string1, string2, ...>, k, pf) estimates the top-k most frequent n-grams that fit into the specified context. The second parameter specifies a string of words that specify the positions of the n-gram elements, with a null value standing in for a 'blank' that must be filled by an n-gram element.", extended="The primary expression must be an array of strings, or an array of arrays of strings, such as the return type of the sentences() UDF. The second parameter specifies the context -- for example, array(\"i\", \"love\", null) -- which would estimate the top 'k' words that follow the phrase \"i love\" in the primary expression. The optional fourth parameter 'pf' controls the memory used by the heuristic. Larger values will yield better accuracy, but use more memory. Example usage:\n  SELECT context_ngrams(sentences(lower(review)), array(\"i\", \"love\", null, null), 10) FROM movies\nwould attempt to determine the 10 most common two-word phrases that follow \"i love\" in a database of free-form natural language movie reviews.")
public class GenericUDAFContextNGrams
implements GenericUDAFResolver {
    static final Logger LOG = LoggerFactory.getLogger(GenericUDAFContextNGrams.class.getName());

    /**
     * Validates the argument types of a {@code context_ngrams} call and creates its evaluator.
     *
     * <p>The accepted shape is {@code (array<string> | array<array<string>>, array<string>, k [, pf])}.
     *
     * @param parameters type information of the call arguments, in call order
     * @return a new {@link GenericUDAFContextNGramEvaluator}
     * @throws SemanticException (as {@link UDFArgumentTypeException}) if the argument count is not
     *         three or four, or if any argument has an unsupported type
     */
    @Override
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {
        if (parameters.length != 3 && parameters.length != 4) {
            throw new UDFArgumentTypeException(parameters.length - 1, "Please specify either three or four arguments.");
        }

        // Argument 1: array<string> or array<array<string>>.
        if (parameters[0].getCategory() != ObjectInspector.Category.LIST) {
            throw new UDFArgumentTypeException(0, "Only list type arguments are accepted but " + parameters[0].getTypeName() + " was passed as parameter 1.");
        }
        TypeInfo wordType = ((ListTypeInfo)parameters[0]).getListElementTypeInfo();
        if (wordType.getCategory() == ObjectInspector.Category.LIST) {
            // array<array<...>>: descend exactly one more level to reach the word type.
            wordType = ((ListTypeInfo)wordType).getListElementTypeInfo();
        }
        // Checked before the cast so that deeper nesting or non-primitive elements
        // (e.g. array<array<array<string>>>) are rejected with a type error rather than
        // failing with a ClassCastException.
        if (wordType.getCategory() != ObjectInspector.Category.PRIMITIVE) {
            throw new UDFArgumentTypeException(0, "Only arrays of strings or arrays of arrays of strings are accepted but " + parameters[0].getTypeName() + " was passed as parameter 1.");
        }
        if (((PrimitiveTypeInfo)wordType).getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.STRING) {
            throw new UDFArgumentTypeException(0, "Only array<string> or array<array<string>> is allowed, but " + parameters[0].getTypeName() + " was passed as parameter 1.");
        }

        // Argument 2: the context, which must be array<string>.
        boolean contextIsStringArray = false;
        if (parameters[1].getCategory() == ObjectInspector.Category.LIST) {
            TypeInfo contextElementType = ((ListTypeInfo)parameters[1]).getListElementTypeInfo();
            contextIsStringArray = contextElementType.getCategory() == ObjectInspector.Category.PRIMITIVE
                    && ((PrimitiveTypeInfo)contextElementType).getPrimitiveCategory() == PrimitiveObjectInspector.PrimitiveCategory.STRING;
        }
        if (!contextIsStringArray) {
            throw new UDFArgumentTypeException(1, "Only arrays of strings are accepted but " + parameters[1].getTypeName() + " was passed as parameter 2.");
        }

        // Argument 3 (k) and the optional argument 4 (pf) share the same validation.
        GenericUDAFContextNGrams.checkIntegerArgument(parameters, 2);
        if (parameters.length == 4) {
            GenericUDAFContextNGrams.checkIntegerArgument(parameters, 3);
        }
        return new GenericUDAFContextNGramEvaluator();
    }

    /**
     * Verifies that the argument at {@code index} has one of the accepted primitive categories.
     *
     * @param parameters type information of all call arguments
     * @param index zero-based position of the argument to check
     * @throws SemanticException (as {@link UDFArgumentTypeException}) if the argument is not accepted
     */
    private static void checkIntegerArgument(TypeInfo[] parameters, int index) throws SemanticException {
        TypeInfo type = parameters[index];
        boolean accepted = false;
        if (type.getCategory() == ObjectInspector.Category.PRIMITIVE) {
            switch (((PrimitiveTypeInfo)type).getPrimitiveCategory()) {
                case BYTE:
                case SHORT:
                case INT:
                case LONG:
                // NOTE(review): TIMESTAMP is accepted although the error message speaks of
                // integers only; kept unchanged for backward compatibility -- confirm intent.
                case TIMESTAMP: {
                    accepted = true;
                    break;
                }
                default: {
                    break;
                }
            }
        }
        if (!accepted) {
            throw new UDFArgumentTypeException(index, "Only integers are accepted but " + type.getTypeName() + " was passed as parameter " + (index + 1) + ".");
        }
    }

    public static class GenericUDAFContextNGramEvaluator
    extends GenericUDAFEvaluator {
        // Object inspectors for the original arguments; set in PARTIAL1 and COMPLETE modes only.
        private transient ListObjectInspector outerInputOI;
        // Non-null only when the first argument is a list of lists. Typed as the general
        // ListObjectInspector interface because only getList() is needed; requiring the concrete
        // StandardListObjectInspector made init() fail with a ClassCastException for any other
        // list inspector implementation.
        private transient ListObjectInspector innerInputOI;
        private transient ListObjectInspector contextListOI;
        private PrimitiveObjectInspector contextOI;
        private PrimitiveObjectInspector inputOI;
        private transient PrimitiveObjectInspector kOI;
        private transient PrimitiveObjectInspector pOI;
        // Object inspector for the serialized partial aggregation; set in the remaining modes.
        private transient ListObjectInspector loi;

        /**
         * Captures the object inspectors needed for the given mode and reports the output type.
         *
         * <p>In {@code PARTIAL1} and {@code COMPLETE} modes {@code parameters} describe the original
         * call arguments (input, context, k and optionally pf). In the other modes
         * {@code parameters[0]} describes the list produced by {@link #terminatePartial}.
         *
         * @param m the evaluation mode
         * @param parameters object inspectors for the values passed in this mode
         * @return for {@code PARTIAL1}/{@code PARTIAL2}, a list-of-strings inspector; otherwise an
         *         inspector for a list of structs with fields {@code ngram} (list of strings) and
         *         {@code estfrequency} (double)
         * @throws HiveException if the superclass initialization fails
         */
        @Override
        public ObjectInspector init(GenericUDAFEvaluator.Mode m, ObjectInspector[] parameters) throws HiveException {
            super.init(m, parameters);
            if (m == GenericUDAFEvaluator.Mode.PARTIAL1 || m == GenericUDAFEvaluator.Mode.COMPLETE) {
                this.outerInputOI = (ListObjectInspector)parameters[0];
                ObjectInspector elementOI = this.outerInputOI.getListElementObjectInspector();
                if (elementOI.getCategory() == ObjectInspector.Category.LIST) {
                    // array<array<string>>: words live one level further down.
                    this.innerInputOI = (ListObjectInspector)elementOI;
                    this.inputOI = (PrimitiveObjectInspector)this.innerInputOI.getListElementObjectInspector();
                } else {
                    // array<string>: the elements are the words themselves.
                    this.inputOI = (PrimitiveObjectInspector)elementOI;
                    this.innerInputOI = null;
                }
                this.contextListOI = (ListObjectInspector)parameters[1];
                this.contextOI = (PrimitiveObjectInspector)this.contextListOI.getListElementObjectInspector();
                this.kOI = (PrimitiveObjectInspector)parameters[2];
                this.pOI = parameters.length == 4 ? (PrimitiveObjectInspector)parameters[3] : null;
            } else {
                this.loi = (ListObjectInspector)parameters[0];
            }
            if (m == GenericUDAFEvaluator.Mode.PARTIAL1 || m == GenericUDAFEvaluator.Mode.PARTIAL2) {
                // Partial results are shipped as a flat list of strings.
                return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
            }
            // Final result: array<struct<ngram:array<string>, estfrequency:double>>.
            ArrayList<ObjectInspector> foi = new ArrayList<ObjectInspector>();
            foi.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableStringObjectInspector));
            foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
            ArrayList<String> fname = new ArrayList<String>();
            fname.add("ngram");
            fname.add("estfrequency");
            return ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorFactory.getStandardStructObjectInspector(fname, foi));
        }

        /**
         * Merges one serialized partial aggregation into the given buffer.
         *
         * <p>The partial has the layout written by {@link #terminatePartial}: the estimator's
         * serialized state, followed by the context words (an empty string standing for a null
         * context slot), followed by the number of context words as the last element.
         *
         * <p>Every partial is merged into the estimator. Previously only the first partial to
         * reach an empty buffer was merged and all later ones were validated and then dropped.
         * The list obtained from the object inspector is no longer modified.
         *
         * @param agg the aggregation buffer to merge into; must be an {@code NGramAggBuf}
         * @param obj the partial aggregation, or {@code null} to do nothing
         * @throws HiveException if the partial is malformed, if its context length differs from the
         *         buffer's, or if the estimator rejects the merge
         */
        @Override
        public void merge(GenericUDAFEvaluator.AggregationBuffer agg, Object obj) throws HiveException {
            if (obj == null) {
                return;
            }
            NGramAggBuf myagg = (NGramAggBuf)agg;
            List<?> partial = this.loi.getList(obj);
            if (partial == null || partial.isEmpty()) {
                throw new HiveException(this.getClass().getSimpleName() + ": received a malformed partial aggregation.");
            }

            // The last element is the context length; the context words sit right before it.
            int sizeIndex = partial.size() - 1;
            int contextSize = Integer.parseInt(partial.get(sizeIndex).toString());
            int contextStart = sizeIndex - contextSize;
            if (contextSize < 0 || contextStart < 0) {
                throw new HiveException(this.getClass().getSimpleName() + ": received a malformed partial aggregation.");
            }

            if (myagg.context.size() > 0) {
                if (contextSize != myagg.context.size()) {
                    throw new HiveException(this.getClass().getSimpleName() + ": found a mismatch in the context string lengths. This is usually caused by passing a non-constant expression for the context.");
                }
            } else {
                // First partial for this buffer: adopt its context.
                for (int i = contextStart; i < sizeIndex; ++i) {
                    String word = partial.get(i).toString();
                    if (word.equals("")) {
                        myagg.context.add(null);
                        continue;
                    }
                    myagg.context.add(word);
                }
            }

            // Hand only the estimator's own state to the estimator, for every partial.
            List<?> estimatorState = partial.subList(0, contextStart);
            myagg.nge.merge(estimatorState);
        }

        /**
         * Serializes the buffer into a flat list of {@link Text} values.
         *
         * <p>The list starts with whatever the estimator's {@code serialize()} returns, then one
         * entry per context word (a null slot is written as an empty string), and finally the
         * number of context words.
         *
         * @param agg the aggregation buffer; must be an {@code NGramAggBuf}
         * @return the serialized partial aggregation
         * @throws HiveException if the estimator cannot be serialized
         */
        @Override
        public Object terminatePartial(GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
            NGramAggBuf buffer = (NGramAggBuf)agg;
            ArrayList<Text> serialized = buffer.nge.serialize();
            for (String contextWord : buffer.context) {
                // Null context slots ("blanks") are encoded as empty strings.
                serialized.add(new Text(contextWord == null ? "" : contextWord));
            }
            int contextLength = buffer.context.size();
            serialized.add(new Text(Integer.toString(contextLength)));
            return serialized;
        }

        /**
         * Slides the context over {@code seq}, from the last possible position down to the first,
         * and records an n-gram for every position where all non-null context words match.
         *
         * <p>The recorded n-gram consists of the words of {@code seq} found at the positions where
         * the context holds {@code null}.
         *
         * @param agg buffer holding the context and the estimator that receives the n-grams
         * @param seq the word sequence to scan
         * @throws HiveException if the estimator rejects an n-gram
         */
        private void processNgrams(NGramAggBuf agg, ArrayList<String> seq) throws HiveException {
            assert (agg.context.size() > 0);
            final int contextLength = agg.context.size();
            ArrayList<String> candidate = new ArrayList<String>();
            for (int start = seq.size() - contextLength; start >= 0; --start) {
                candidate.clear();
                boolean matched = true;
                int offset = 0;
                while (matched && offset < contextLength) {
                    String expected = agg.context.get(offset);
                    String actual = seq.get(start + offset);
                    if (expected == null) {
                        // Blank slot: the word at this position belongs to the n-gram.
                        candidate.add(actual);
                    } else if (!expected.equals(actual)) {
                        matched = false;
                    }
                    ++offset;
                }
                if (matched) {
                    agg.nge.add(candidate);
                    // The estimator now owns that list; continue with a fresh one.
                    candidate = new ArrayList<String>();
                }
            }
        }

        /**
         * Feeds one input row into the aggregation.
         *
         * <p>Rows whose first, second or third argument is {@code null} are ignored. The first
         * accepted row initializes the estimator from {@code k}, the optional {@code pf}
         * (1 when absent) and the context array. Each word sequence of the row is then scanned
         * for context matches. Null sentences inside an array of arrays are skipped instead of
         * causing a {@code NullPointerException}.
         *
         * @param agg the aggregation buffer; must be an {@code NGramAggBuf}
         * @param parameters the row's argument values: input, context, k and optionally pf
         * @throws HiveException if {@code k} or {@code pf} is below 1, if the context is empty or
         *         has no null element, or if the estimator reports an error
         */
        @Override
        public void iterate(GenericUDAFEvaluator.AggregationBuffer agg, Object[] parameters) throws HiveException {
            assert (parameters.length == 3 || parameters.length == 4);
            if (parameters[0] == null || parameters[1] == null || parameters[2] == null) {
                return;
            }
            NGramAggBuf myagg = (NGramAggBuf)agg;
            if (!myagg.nge.isInitialized()) {
                this.initializeEstimator(myagg, parameters);
            }

            List<?> outer = this.outerInputOI.getList(parameters[0]);
            if (outer == null) {
                // Defensive: nothing to scan if the inspector yields no list.
                return;
            }
            if (this.innerInputOI == null) {
                // array<string>: the whole argument is a single word sequence.
                this.processNgrams(myagg, this.extractWords(outer));
                return;
            }
            // array<array<string>>: every element is its own word sequence.
            for (Object sentence : outer) {
                List<?> inner = this.innerInputOI.getList(sentence);
                if (inner == null) {
                    // A null sentence contributes no n-grams.
                    continue;
                }
                this.processNgrams(myagg, this.extractWords(inner));
            }
        }

        /** Reads k, pf and the context from the row, validates them and initializes the estimator. */
        private void initializeEstimator(NGramAggBuf myagg, Object[] parameters) throws HiveException {
            int k = PrimitiveObjectInspectorUtils.getInt(parameters[2], this.kOI);
            if (k < 1) {
                throw new HiveException(this.getClass().getSimpleName() + " needs 'k' to be at least 1, but you supplied " + k);
            }
            int pf = 1;
            if (parameters.length == 4) {
                pf = PrimitiveObjectInspectorUtils.getInt(parameters[3], this.pOI);
                if (pf < 1) {
                    throw new HiveException(this.getClass().getSimpleName() + " needs 'pf' to be at least 1, but you supplied " + pf);
                }
            }

            myagg.context.clear();
            List<?> context = this.contextListOI.getList(parameters[1]);
            int contextNulls = 0;
            for (Object element : context) {
                String word = PrimitiveObjectInspectorUtils.getString(element, this.contextOI);
                if (word == null) {
                    // Each null marks one blank to be filled, i.e. one element of the n-gram.
                    ++contextNulls;
                }
                myagg.context.add(word);
            }
            if (context.size() == 0) {
                throw new HiveException(this.getClass().getSimpleName() + " needs a context array with at least one element.");
            }
            if (contextNulls == 0) {
                throw new HiveException(this.getClass().getSimpleName() + " the context array needs to contain at least one 'null' value to indicate what should be counted.");
            }
            myagg.nge.initialize(k, pf, contextNulls);
        }

        /** Converts the elements of a list into Java strings using the word object inspector. */
        private ArrayList<String> extractWords(List<?> elements) throws HiveException {
            ArrayList<String> words = new ArrayList<String>(elements.size());
            for (Object element : elements) {
                words.add(PrimitiveObjectInspectorUtils.getString(element, this.inputOI));
            }
            return words;
        }

        /**
         * Produces the final result by delegating to the estimator's {@code getNGrams()}.
         *
         * @param agg the aggregation buffer; must be an {@code NGramAggBuf}
         * @return whatever the estimator reports as its n-grams
         * @throws HiveException if the estimator fails to produce the result
         */
        @Override
        public Object terminate(GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
            return ((NGramAggBuf)agg).nge.getNGrams();
        }

        /**
         * Creates a buffer with a new estimator and an empty context, then resets it.
         *
         * @return the newly created {@code NGramAggBuf}
         * @throws HiveException if resetting the new buffer fails
         */
        @Override
        public GenericUDAFEvaluator.AggregationBuffer getNewAggregationBuffer() throws HiveException {
            NGramAggBuf buffer = new NGramAggBuf();
            buffer.nge = new NGramEstimator();
            buffer.context = new ArrayList<String>();
            this.reset(buffer);
            return buffer;
        }

        /**
         * Clears the buffer's context and resets its estimator so the buffer can be reused.
         *
         * @param agg the aggregation buffer; must be an {@code NGramAggBuf}
         * @throws HiveException if the estimator cannot be reset
         */
        @Override
        public void reset(GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
            NGramAggBuf buffer = (NGramAggBuf)agg;
            // Context first, then the estimator.
            buffer.context.clear();
            buffer.nge.reset();
        }

        /**
         * Aggregation state: the context pattern plus the estimator that collects n-grams.
         * Both fields are assigned by the enclosing evaluator after construction.
         */
        static class NGramAggBuf
        extends GenericUDAFEvaluator.AbstractAggregationBuffer {
            /** Context words; a {@code null} entry marks a blank that an n-gram element fills. */
            ArrayList<String> context;
            /** Estimator receiving the n-grams found by the evaluator. */
            NGramEstimator nge;
        }
    }
}

