/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.ml.feature.lsh;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.functions.AggregateFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.JoinFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.api.Model;
import org.apache.flink.ml.common.broadcast.BroadcastUtils;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
import org.apache.flink.ml.common.datastream.TableUtils;
import org.apache.flink.ml.common.typeinfo.PriorityQueueTypeInfo;
import org.apache.flink.ml.feature.lsh.LSHModelData;
import org.apache.flink.ml.feature.lsh.LSHModelParams;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.windowing.assigners.WindowAssigner;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;

abstract class LSHModel<T extends LSHModel<T>>
implements Model<T>,
LSHModelParams<T> {
    private static final String MODEL_DATA_BC_KEY = "modelData";
    private final Map<Param<?>, Object> paramMap = new HashMap();
    private final Class<? extends LSHModelData> modelDataClass;
    protected Table modelDataTable;

    public LSHModel(Class<? extends LSHModelData> modelDataClass) {
        this.modelDataClass = modelDataClass;
        ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
    }

    @Override
    public T setModelData(Table ... inputs) {
        Preconditions.checkArgument((inputs.length == 1 ? 1 : 0) != 0);
        this.modelDataTable = inputs[0];
        return (T)this;
    }

    @Override
    public Table[] getModelData() {
        return new Table[]{this.modelDataTable};
    }

    @Override
    public Map<Param<?>, Object> getParamMap() {
        return this.paramMap;
    }

    @Override
    public Table[] transform(Table ... inputs) {
        Preconditions.checkArgument((inputs.length == 1 ? 1 : 0) != 0);
        StreamTableEnvironment tEnv = (StreamTableEnvironment)((TableImpl)inputs[0]).getTableEnvironment();
        DataStream modelData = tEnv.toDataStream(this.modelDataTable, this.modelDataClass);
        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
        TypeInformation outputType = TypeInformation.of(DenseVector[].class);
        RowTypeInfo outputTypeInfo = new RowTypeInfo((TypeInformation[])ArrayUtils.addAll((Object[])inputTypeInfo.getFieldTypes(), (Object[])new TypeInformation[]{outputType}), (String[])ArrayUtils.addAll((Object[])inputTypeInfo.getFieldNames(), (Object[])new String[]{this.getOutputCol()}));
        DataStream output = BroadcastUtils.withBroadcastStream(Collections.singletonList(tEnv.toDataStream(inputs[0])), Collections.singletonMap(MODEL_DATA_BC_KEY, modelData), inputList -> {
            DataStream data = (DataStream)inputList.get(0);
            return data.map((MapFunction)new PredictFunction(this.getInputCol()), (TypeInformation)outputTypeInfo);
        });
        return new Table[]{tEnv.fromDataStream(output)};
    }

    public Table approxNearestNeighbors(Table dataset, Vector key, int k, String distCol) {
        StreamTableEnvironment tEnv = (StreamTableEnvironment)((TableImpl)dataset).getTableEnvironment();
        Table transformedTable = dataset.getResolvedSchema().getColumnNames().contains(this.getOutputCol()) ? dataset : this.transform(dataset)[0];
        DataStream modelData = tEnv.toDataStream(this.modelDataTable, this.modelDataClass);
        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(transformedTable.getResolvedSchema());
        RowTypeInfo outputTypeInfo = new RowTypeInfo((TypeInformation[])ArrayUtils.addAll((Object[])inputTypeInfo.getFieldTypes(), (Object[])new TypeInformation[]{Types.DOUBLE}), (String[])ArrayUtils.addAll((Object[])inputTypeInfo.getFieldNames(), (Object[])new String[]{distCol}));
        DataStream filteredData = BroadcastUtils.withBroadcastStream(Collections.singletonList(tEnv.toDataStream(transformedTable)), Collections.singletonMap(MODEL_DATA_BC_KEY, modelData), inputList -> {
            DataStream data = (DataStream)inputList.get(0);
            return data.flatMap((FlatMapFunction)new FilterByBucketFunction(this.getInputCol(), this.getOutputCol(), key), (TypeInformation)outputTypeInfo);
        });
        TopKFunction topKFunction = new TopKFunction(distCol, k);
        DataStream<List<Row>> topKList = DataStreamUtils.aggregate(filteredData, topKFunction, new PriorityQueueTypeInfo(topKFunction.getComparator(), outputTypeInfo), Types.LIST((TypeInformation)outputTypeInfo));
        SingleOutputStreamOperator topKData = topKList.flatMap((FlatMapFunction & Serializable)(value, out) -> {
            for (Row row : value) {
                out.collect((Object)row);
            }
        });
        topKData.getTransformation().setOutputType((TypeInformation)outputTypeInfo);
        return tEnv.fromDataStream((DataStream)topKData);
    }

    public Table approxNearestNeighbors(Table dataset, Vector key, int k) {
        return this.approxNearestNeighbors(dataset, key, k, "distCol");
    }

    public Table approxSimilarityJoin(Table datasetA, Table datasetB, double threshold, String idCol, String distCol) {
        StreamTableEnvironment tEnv = (StreamTableEnvironment)((TableImpl)datasetA).getTableEnvironment();
        DataStream<Row> explodedA = this.preprocessData(datasetA, idCol);
        DataStream<Row> explodedB = this.preprocessData(datasetB, idCol);
        RowTypeInfo inputTypeInfo = this.getOutputType(datasetA, idCol);
        RowTypeInfo outputTypeInfo = new RowTypeInfo(new TypeInformation[]{inputTypeInfo.getTypeAt(0), inputTypeInfo.getTypeAt(0), inputTypeInfo.getTypeAt(1), inputTypeInfo.getTypeAt(1)});
        DataStream modelData = tEnv.toDataStream(this.modelDataTable, this.modelDataClass);
        DataStream sameBucketPairs = explodedA.join(explodedB).where((KeySelector)new IndexHashValueKeySelector()).equalTo((KeySelector)new IndexHashValueKeySelector()).window((WindowAssigner)EndOfStreamWindows.get()).apply((JoinFunction & Serializable)(r0, r1) -> Row.of((Object[])new Object[]{r0.getField(0), r1.getField(0), r0.getField(1), r1.getField(1)}), (TypeInformation)outputTypeInfo);
        DataStream distinctSameBucketPairs = DataStreamUtils.reduce(sameBucketPairs.keyBy((KeySelector)new KeySelector<Row, Tuple2<Integer, Integer>>(){

            public Tuple2<Integer, Integer> getKey(Row r) {
                return Tuple2.of((Object)((Integer)r.getFieldAs(0)), (Object)((Integer)r.getFieldAs(1)));
            }
        }), (ReduceFunction & Serializable)(r0, r1) -> r0, outputTypeInfo);
        TypeInformation idColType = TableUtils.getRowTypeInfo(datasetA.getResolvedSchema()).getTypeAt(idCol);
        DataStream pairsWithDists = BroadcastUtils.withBroadcastStream(Collections.singletonList(distinctSameBucketPairs), Collections.singletonMap(MODEL_DATA_BC_KEY, modelData), inputList -> {
            DataStream data = (DataStream)inputList.get(0);
            return data.flatMap((FlatMapFunction)new FilterByDistanceFunction(threshold), (TypeInformation)new RowTypeInfo(new TypeInformation[]{idColType, idColType, Types.DOUBLE}, new String[]{"datasetA.id", "datasetB.id", distCol}));
        });
        return tEnv.fromDataStream(pairsWithDists);
    }

    public Table approxSimilarityJoin(Table datasetA, Table datasetB, double threshold, String idCol) {
        return this.approxSimilarityJoin(datasetA, datasetB, threshold, idCol, "distCol");
    }

    private DataStream<Row> preprocessData(Table dataTable, String idCol) {
        StreamTableEnvironment tEnv = (StreamTableEnvironment)((TableImpl)dataTable).getTableEnvironment();
        dataTable = dataTable.getResolvedSchema().getColumnNames().contains(this.getOutputCol()) ? dataTable : this.transform(dataTable)[0];
        RowTypeInfo outputTypeInfo = this.getOutputType(dataTable, idCol);
        return tEnv.toDataStream(dataTable).flatMap((FlatMapFunction)new ExplodeHashValuesFunction(idCol, this.getInputCol(), this.getOutputCol()), (TypeInformation)outputTypeInfo);
    }

    private RowTypeInfo getOutputType(Table dataTable, String idCol) {
        String indexCol = "index";
        String hashValueCol = "hashValue";
        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(dataTable.getResolvedSchema());
        TypeInformation idColType = inputTypeInfo.getTypeAt(idCol);
        RowTypeInfo outputTypeInfo = new RowTypeInfo(new TypeInformation[]{idColType, VectorTypeInfo.INSTANCE, Types.INT, DenseVectorTypeInfo.INSTANCE}, new String[]{idCol, this.getInputCol(), "index", "hashValue"});
        return outputTypeInfo;
    }

    private static class FilterByDistanceFunction
    extends RichFlatMapFunction<Row, Row> {
        private final double threshold;
        private LSHModelData modelData;

        public FilterByDistanceFunction(double threshold) {
            this.threshold = threshold;
        }

        public void flatMap(Row value, Collector<Row> out) throws Exception {
            double dist;
            if (null == this.modelData) {
                this.modelData = (LSHModelData)this.getRuntimeContext().getBroadcastVariable(LSHModel.MODEL_DATA_BC_KEY).get(0);
            }
            if ((dist = this.modelData.keyDistance((Vector)value.getFieldAs(2), (Vector)value.getFieldAs(3))) <= this.threshold) {
                out.collect((Object)Row.of((Object[])new Object[]{value.getFieldAs(0), value.getFieldAs(1), dist}));
            }
        }
    }

    private static class IndexHashValueKeySelector
    implements KeySelector<Row, Tuple2<Integer, DenseVector>> {
        private IndexHashValueKeySelector() {
        }

        public Tuple2<Integer, DenseVector> getKey(Row value) throws Exception {
            return Tuple2.of((Object)((Integer)value.getFieldAs(2)), (Object)((DenseVector)value.getFieldAs(3)));
        }
    }

    private static class ExplodeHashValuesFunction
    implements FlatMapFunction<Row, Row> {
        private final String idCol;
        private final String inputCol;
        private final String outputCol;

        public ExplodeHashValuesFunction(String idCol, String inputCol, String outputCol) {
            this.idCol = idCol;
            this.inputCol = inputCol;
            this.outputCol = outputCol;
        }

        public void flatMap(Row value, Collector<Row> out) throws Exception {
            Row kept = Row.of((Object[])new Object[]{value.getField(this.idCol), value.getField(this.inputCol)});
            DenseVector[] hashValues = (DenseVector[])value.getFieldAs(this.outputCol);
            for (int i = 0; i < hashValues.length; ++i) {
                out.collect((Object)Row.join((Row)kept, (Row[])new Row[]{Row.of((Object[])new Object[]{i, hashValues[i]})}));
            }
        }
    }

    private static class TopKFunction
    implements AggregateFunction<Row, PriorityQueue<Row>, List<Row>> {
        private final int numNearestNeighbors;
        private final String distCol;

        public TopKFunction(String distCol, int numNearestNeighbors) {
            this.distCol = distCol;
            this.numNearestNeighbors = numNearestNeighbors;
        }

        public PriorityQueue<Row> createAccumulator() {
            return new PriorityQueue<Row>(this.numNearestNeighbors, this.getComparator());
        }

        public PriorityQueue<Row> add(Row value, PriorityQueue<Row> accumulator) {
            if (accumulator.size() == this.numNearestNeighbors) {
                Row peek = accumulator.peek();
                if (accumulator.comparator().compare(value, peek) < 0) {
                    accumulator.poll();
                }
            }
            accumulator.add(value);
            return accumulator;
        }

        public List<Row> getResult(PriorityQueue<Row> accumulator) {
            return new ArrayList<Row>(accumulator);
        }

        public PriorityQueue<Row> merge(PriorityQueue<Row> a, PriorityQueue<Row> b) {
            PriorityQueue<Row> merged = new PriorityQueue<Row>(a);
            for (Row row : b) {
                this.add(row, merged);
            }
            return merged;
        }

        private Comparator<Row> getComparator() {
            return new DistColComparator(this.distCol);
        }

        private static class DistColComparator
        implements Comparator<Row>,
        Serializable {
            private final String distCol;

            private DistColComparator(String distCol) {
                this.distCol = distCol;
            }

            @Override
            public int compare(Row o1, Row o2) {
                return Double.compare((Double)o1.getFieldAs(this.distCol), (Double)o2.getFieldAs(this.distCol));
            }
        }
    }

    private static class FilterByBucketFunction
    extends RichFlatMapFunction<Row, Row> {
        private final String inputCol;
        private final String outputCol;
        private final Vector key;
        private LSHModelData modelData;
        private DenseVector[] keyHashes;

        public FilterByBucketFunction(String inputCol, String outputCol, Vector key) {
            this.inputCol = inputCol;
            this.outputCol = outputCol;
            this.key = key;
        }

        public void flatMap(Row value, Collector<Row> out) throws Exception {
            if (null == this.modelData) {
                this.modelData = (LSHModelData)this.getRuntimeContext().getBroadcastVariable(LSHModel.MODEL_DATA_BC_KEY).get(0);
                this.keyHashes = this.modelData.hashFunction(this.key);
            }
            DenseVector[] hashes = (DenseVector[])value.getFieldAs(this.outputCol);
            boolean sameBucket = false;
            for (int i = 0; i < this.keyHashes.length; ++i) {
                if (!this.keyHashes[i].equals(hashes[i])) continue;
                sameBucket = true;
                break;
            }
            if (!sameBucket) {
                return;
            }
            Vector vec = (Vector)value.getFieldAs(this.inputCol);
            double dist = this.modelData.keyDistance(this.key, vec);
            out.collect((Object)Row.join((Row)value, (Row[])new Row[]{Row.of((Object[])new Object[]{dist})}));
        }
    }

    private static class PredictFunction
    extends RichMapFunction<Row, Row> {
        private final String inputCol;
        private LSHModelData modelData;

        public PredictFunction(String inputCol) {
            this.inputCol = inputCol;
        }

        public Row map(Row value) throws Exception {
            if (null == this.modelData) {
                this.modelData = (LSHModelData)this.getRuntimeContext().getBroadcastVariable(LSHModel.MODEL_DATA_BC_KEY).get(0);
            }
            DenseVector[] hashValues = this.modelData.hashFunction((Vector)value.getFieldAs(this.inputCol));
            return Row.join((Row)value, (Row[])new Row[]{Row.of((Object[])new Object[]{hashValues})});
        }
    }
}

