/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.instructions.spark;

import java.util.List;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysml.lops.PickByCount;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.cp.DoubleObject;
import org.apache.sysml.runtime.instructions.cp.ScalarObject;
import org.apache.sysml.runtime.instructions.spark.BinarySPInstruction;
import org.apache.sysml.runtime.instructions.spark.SPInstruction;
import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.operators.Operator;
import org.apache.sysml.runtime.util.UtilFunctions;
import scala.Tuple2;

/**
 * Spark instruction for quantile picks ({@code qpick}) over a column vector.
 * <p>
 * Supported operation types:
 * <ul>
 *   <li>{@code VALUEPICK}: the value at the quantile given by the second (scalar) input.</li>
 *   <li>{@code MEDIAN}: the value at the 0.5 quantile.</li>
 *   <li>{@code IQM}: the interquartile mean.</li>
 * </ul>
 * <p>
 * An input with exactly two columns is treated as weighted. In that case the total
 * weight is the sum of its second column; otherwise it is the number of rows.
 * <p>
 * NOTE(review): values are looked up by row position, so the input is presumably
 * already sorted by its first column. This class does not sort; TODO confirm with
 * the instruction that produces the input.
 */
public class QuantilePickSPInstruction
extends BinarySPInstruction {
    private final PickByCount.OperationTypes _type;

    public QuantilePickSPInstruction(Operator op, CPOperand in, CPOperand out, PickByCount.OperationTypes type, boolean inmem, String opcode, String istr) {
        this(op, in, null, out, type, inmem, opcode, istr);
    }

    public QuantilePickSPInstruction(Operator op, CPOperand in, CPOperand in2, CPOperand out, PickByCount.OperationTypes type, boolean inmem, String opcode, String istr) {
        super(op, in, in2, out, opcode, istr);
        this._sptype = SPInstruction.SPINSTRUCTION_TYPE.QPick;
        this._type = type;
    }

    /**
     * Parses a {@code qpick} instruction string.
     * <ul>
     *   <li>4 parts: {@code qpick in1 in2 out}, an IQM pick.</li>
     *   <li>5 parts: {@code qpick in1 out type inmem}.</li>
     *   <li>6 parts: {@code qpick in1 in2 out type inmem}.</li>
     * </ul>
     *
     * @param str the instruction string
     * @return the parsed instruction, or {@code null} if the number of parts matches
     *         none of the supported forms
     * @throws DMLRuntimeException if the opcode is not {@code qpick}
     */
    public static QuantilePickSPInstruction parseInstruction(String str) throws DMLRuntimeException {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        if (!opcode.equalsIgnoreCase("qpick")) {
            throw new DMLRuntimeException("Unknown opcode while parsing a QuantilePickSPInstruction: " + str);
        }
        if (parts.length == 4) {
            CPOperand in1 = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand out = new CPOperand(parts[3]);
            return new QuantilePickSPInstruction(null, in1, in2, out, PickByCount.OperationTypes.IQM, false, opcode, str);
        }
        if (parts.length == 5) {
            CPOperand in1 = new CPOperand(parts[1]);
            CPOperand out = new CPOperand(parts[2]);
            PickByCount.OperationTypes ptype = PickByCount.OperationTypes.valueOf(parts[3]);
            boolean inmem = Boolean.parseBoolean(parts[4]);
            return new QuantilePickSPInstruction(null, in1, out, ptype, inmem, opcode, str);
        }
        if (parts.length == 6) {
            CPOperand in1 = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand out = new CPOperand(parts[3]);
            PickByCount.OperationTypes ptype = PickByCount.OperationTypes.valueOf(parts[4]);
            boolean inmem = Boolean.parseBoolean(parts[5]);
            return new QuantilePickSPInstruction(null, in1, in2, out, ptype, inmem, opcode, str);
        }
        return null;
    }

    /**
     * Computes the requested quantile statistic and stores it as a double scalar
     * under the output operand's name.
     *
     * @throws DMLRuntimeException for an unsupported operation type
     */
    @Override
    public void processInstruction(ExecutionContext ec) throws DMLRuntimeException {
        SparkExecutionContext sec = (SparkExecutionContext)ec;
        MatrixCharacteristics mc = sec.getMatrixCharacteristics(this.input1.getName());
        boolean weighted = mc.getCols() == 2L;
        int brlen = mc.getRowsPerBlock();
        JavaPairRDD<MatrixIndexes, MatrixBlock> in = sec.getBinaryBlockRDDHandleForVariable(this.input1.getName());
        double result;
        switch (this._type) {
            case VALUEPICK: {
                double sumWt = this.totalWeight(in, mc, weighted);
                ScalarObject quantile = ec.getScalarInput(this.input2.getName(), this.input2.getValueType(), this.input2.isLiteral());
                long key = (long)Math.ceil(quantile.getDoubleValue() * sumWt);
                result = this.lookupKey(in, key, brlen);
                break;
            }
            case MEDIAN: {
                double sumWt = this.totalWeight(in, mc, weighted);
                long key = (long)Math.ceil(0.5 * sumWt);
                result = this.lookupKey(in, key, brlen);
                break;
            }
            case IQM: {
                double sumWt = this.totalWeight(in, mc, weighted);
                long key25 = (long)Math.ceil(0.25 * sumWt);
                long key75 = (long)Math.ceil(0.75 * sumWt);
                double val25 = this.lookupKey(in, key25, brlen);
                double val75 = this.lookupKey(in, key75, brlen);
                // Sum of the values at 1-based positions (key25, key75]. For tiny inputs
                // (e.g. a single row) this range is empty. Slicing it would request an
                // inverted row range, so skip the Spark job and contribute 0.
                double midSum = 0;
                if (key25 + 1L <= key75) {
                    JavaPairRDD<MatrixIndexes, MatrixBlock> partialSums = in
                        .filter(new FilterFunction(key25 + 1L, key75, brlen))
                        .mapToPair(new ExtractAndSumFunction(key25 + 1L, key75, brlen));
                    MatrixBlock mb = RDDAggregateUtils.sumStable(partialSums);
                    midSum = mb.getValue(0, 0);
                }
                // The elements at key25 and key75 straddle the quartile boundaries.
                // Add the fraction of val25 that lies above the 25% mark, and remove the
                // fraction of val75 that lies beyond the 75% mark, then average over the
                // middle half of the total weight.
                result = (midSum + ((double)key25 - 0.25 * sumWt) * val25 - ((double)key75 - 0.75 * sumWt) * val75) / (0.5 * sumWt);
                break;
            }
            default: {
                throw new DMLRuntimeException("Unsupported qpick operation type: " + this._type);
            }
        }
        ec.setScalarOutput(this.output.getName(), new DoubleObject(result));
    }

    /**
     * Returns the total weight: the sum of the weight column for weighted inputs,
     * otherwise the row count.
     * NOTE(review): for weighted inputs the keys derived from this total are still
     * used as row positions by {@link #lookupKey}; confirm that weighted picks are
     * meant to be supported by this Spark path.
     */
    private double totalWeight(JavaPairRDD<MatrixIndexes, MatrixBlock> in, MatrixCharacteristics mc, boolean weighted) {
        return weighted ? this.sumWeights(in) : (double)mc.getRows();
    }

    /**
     * Returns the first-column value at 1-based row {@code key}. Looks up the single
     * block with row-block index {@code computeBlockIndex(key, brlen)} in block
     * column 1.
     */
    private double lookupKey(JavaPairRDD<MatrixIndexes, MatrixBlock> in, long key, int brlen) {
        long rix = UtilFunctions.computeBlockIndex(key, brlen);
        int pos = UtilFunctions.computeCellInBlock(key, brlen);
        List<MatrixBlock> blocks = in.lookup(new MatrixIndexes(rix, 1L));
        return blocks.get(0).quickGetValue(pos, 0);
    }

    /** Sums the second (weight) column across all blocks. */
    private double sumWeights(JavaPairRDD<MatrixIndexes, MatrixBlock> in) {
        JavaPairRDD<MatrixIndexes, MatrixBlock> tmp = in.mapValues(new ExtractAndSumWeightsFunction());
        MatrixBlock val = RDDAggregateUtils.sumStable(tmp);
        return val.quickGetValue(0, 0);
    }

    /**
     * Maps a block to a 1x2 block whose cell (0,0) holds the sum of the block's
     * second column. The second cell is left unset, presumably as room for the
     * correction term used by {@code sumStable}; TODO confirm.
     */
    private static class ExtractAndSumWeightsFunction
    implements Function<MatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = 7169831202450745373L;

        private ExtractAndSumWeightsFunction() {
        }

        @Override
        public MatrixBlock call(MatrixBlock arg0) throws Exception {
            MatrixBlock mb = arg0.sliceOperations(0, arg0.getNumRows() - 1, 1, 1, new MatrixBlock());
            MatrixBlock ret = new MatrixBlock(1, 2, false);
            ret.setValue(0, 0, mb.sum());
            return ret;
        }
    }

    /**
     * Maps each block in the key range to a partial sum keyed by (1,1).
     * <p>
     * The range covers 1-based rows {@code [minKey, maxKey]}. Boundary blocks are
     * sliced to the in-range rows of the first column using 0-based inclusive
     * positions. Interior blocks are summed whole.
     */
    private static class ExtractAndSumFunction
    implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -584044441055250489L;
        private final long _minRowIndex;
        private final long _maxRowIndex;
        private final int _minPos;
        private final int _maxPos;

        public ExtractAndSumFunction(long key25, long key75, int brlen) {
            this._minRowIndex = UtilFunctions.computeBlockIndex(key25, brlen);
            this._maxRowIndex = UtilFunctions.computeBlockIndex(key75, brlen);
            this._minPos = UtilFunctions.computeCellInBlock(key25, brlen);
            this._maxPos = UtilFunctions.computeCellInBlock(key75, brlen);
        }

        @Override
        public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception {
            MatrixIndexes ix = arg0._1();
            MatrixBlock mb = arg0._2();
            if (this._minRowIndex == this._maxRowIndex) {
                // Whole range inside one block. The positions are already 0-based and
                // inclusive, exactly as in the multi-block branches below.
                mb = mb.sliceOperations(this._minPos, this._maxPos, 0, 0, new MatrixBlock());
            } else if (ix.getRowIndex() == this._minRowIndex) {
                mb = mb.sliceOperations(this._minPos, mb.getNumRows() - 1, 0, 0, new MatrixBlock());
            } else if (ix.getRowIndex() == this._maxRowIndex) {
                mb = mb.sliceOperations(0, this._maxPos, 0, 0, new MatrixBlock());
            }
            // NOTE(review): interior blocks are summed unsliced, so for a two-column
            // (weighted) input their weight column would be included. Confirm that
            // weighted IQM is not routed here.
            MatrixBlock ret = new MatrixBlock(1, 2, false);
            ret.setValue(0, 0, mb.sum());
            return new Tuple2<MatrixIndexes, MatrixBlock>(new MatrixIndexes(1L, 1L), ret);
        }
    }

    /** Keeps only blocks whose row-block index lies between those of the two keys (inclusive). */
    private static class FilterFunction
    implements Function<Tuple2<MatrixIndexes, MatrixBlock>, Boolean> {
        private static final long serialVersionUID = -8249102381116157388L;
        private final long _minRowIndex;
        private final long _maxRowIndex;

        public FilterFunction(long key25, long key75, int brlen) {
            this._minRowIndex = UtilFunctions.computeBlockIndex(key25, brlen);
            this._maxRowIndex = UtilFunctions.computeBlockIndex(key75, brlen);
        }

        @Override
        public Boolean call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception {
            long rowIndex = arg0._1().getRowIndex();
            return rowIndex >= this._minRowIndex && rowIndex <= this._maxRowIndex;
        }
    }
}

