/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.instructions.spark;

import java.util.ArrayList;
import java.util.Iterator;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.functionobjects.DiagIndex;
import org.apache.sysml.runtime.functionobjects.RevIndex;
import org.apache.sysml.runtime.functionobjects.SortIndex;
import org.apache.sysml.runtime.functionobjects.SwapIndex;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.spark.SPInstruction;
import org.apache.sysml.runtime.instructions.spark.UnarySPInstruction;
import org.apache.sysml.runtime.instructions.spark.functions.FilterDiagBlocksFunction;
import org.apache.sysml.runtime.instructions.spark.functions.IsBlockInRange;
import org.apache.sysml.runtime.instructions.spark.functions.ReorgMapFunction;
import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysml.runtime.instructions.spark.utils.RDDSortUtils;
import org.apache.sysml.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue;
import org.apache.sysml.runtime.matrix.operators.Operator;
import org.apache.sysml.runtime.matrix.operators.ReorgOperator;
import org.apache.sysml.runtime.util.UtilFunctions;
import scala.Tuple2;

/**
 * Spark instruction for reorganization operations on binary-block matrix RDDs:
 * transpose ({@code r'}), reverse ({@code rev}), diagonal ({@code rdiag}) and
 * sort ({@code rsort}).
 */
public class ReorgSPInstruction extends UnarySPInstruction {
    // rsort-only operands: sort column, descending flag, return-indexes flag (null for other opcodes)
    private CPOperand _col = null;
    private CPOperand _desc = null;
    private CPOperand _ixret = null;
    // rsort-only: selects RDDSortUtils.sortDataByValMemSort instead of sortDataByVal
    private boolean _bSortIndInMem = false;

    /**
     * Creates a reorg instruction for the operand-less opcodes ({@code r'}, {@code rev}, {@code rdiag}).
     */
    public ReorgSPInstruction(Operator op, CPOperand in, CPOperand out, String opcode, String istr) {
        super(op, in, out, opcode, istr);
        this._sptype = SPInstruction.SPINSTRUCTION_TYPE.Reorg;
    }

    /**
     * Creates a sort ({@code rsort}) instruction with its additional scalar operands.
     */
    public ReorgSPInstruction(Operator op, CPOperand in, CPOperand col, CPOperand desc, CPOperand ixret, CPOperand out, String opcode, boolean bSortIndInMem, String istr) {
        this(op, in, out, opcode, istr); // also sets _sptype
        this._col = col;
        this._desc = desc;
        this._ixret = ixret;
        this._bSortIndInMem = bSortIndInMem;
    }

    /**
     * Parses a reorg instruction string.
     *
     * @param str instruction string
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is unknown or the field count is invalid
     */
    public static ReorgSPInstruction parseInstruction(String str) throws DMLRuntimeException {
        CPOperand in = new CPOperand("", Expression.ValueType.UNKNOWN, Expression.DataType.UNKNOWN);
        CPOperand out = new CPOperand("", Expression.ValueType.UNKNOWN, Expression.DataType.UNKNOWN);
        String opcode = InstructionUtils.getOpCode(str);
        if (opcode.equalsIgnoreCase("r'")) {
            parseUnaryInstruction(str, in, out);
            return new ReorgSPInstruction(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), in, out, opcode, str);
        }
        if (opcode.equalsIgnoreCase("rev")) {
            parseUnaryInstruction(str, in, out);
            return new ReorgSPInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str);
        }
        if (opcode.equalsIgnoreCase("rdiag")) {
            parseUnaryInstruction(str, in, out);
            return new ReorgSPInstruction(new ReorgOperator(DiagIndex.getDiagIndexFnObject()), in, out, opcode, str);
        }
        if (opcode.equalsIgnoreCase("rsort")) {
            String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
            InstructionUtils.checkNumFields(parts, 5, 6);
            // parts[1..5]: input, sort column, descending flag, return-indexes flag, output
            in.split(parts[1]);
            out.split(parts[5]);
            CPOperand col = new CPOperand(parts[2]);
            CPOperand desc = new CPOperand(parts[3]);
            CPOperand ixret = new CPOperand(parts[4]);
            // the optional in-memory sort flag exists only when parts[6] is present;
            // checking "> 5" would index past the end for the 5-field form
            boolean bSortIndInMem = false;
            if (parts.length > 6) {
                bSortIndInMem = Boolean.parseBoolean(parts[6]);
            }
            return new ReorgSPInstruction(new ReorgOperator(SortIndex.getSortIndexFnObject(1, false, false)), in, col, desc, ixret, out, opcode, bSortIndInMem, str);
        }
        throw new DMLRuntimeException("Unknown opcode while parsing a ReorgInstruction: " + str);
    }

    /**
     * Builds the output RDD for this reorg operation, updates the output matrix
     * characteristics, and registers the output RDD with lineage to the input.
     *
     * @param ec execution context (must be a {@link SparkExecutionContext})
     * @throws DMLRuntimeException on an unsupported opcode or unknown input dimensions
     */
    @Override
    public void processInstruction(ExecutionContext ec) throws DMLRuntimeException {
        SparkExecutionContext sec = (SparkExecutionContext)ec;
        String opcode = this.getOpcode();
        JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable(this.input1.getName());
        MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(this.input1.getName());
        JavaPairRDD<MatrixIndexes, MatrixBlock> out;

        if (opcode.equalsIgnoreCase("r'")) {
            out = in1.mapToPair(new ReorgMapFunction(opcode));
        }
        else if (opcode.equalsIgnoreCase("rev")) {
            out = in1.flatMapToPair(new RDDRevFunction(mcIn));
            // only when the row count is not block-aligned: merge blocks emitted under the same index
            if (mcIn.getRows() % mcIn.getRowsPerBlock() != 0) {
                out = RDDAggregateUtils.mergeByKey(out, false);
            }
        }
        else if (opcode.equalsIgnoreCase("rdiag")) {
            if (mcIn.getCols() == 1) {
                // column vector input: expand into a diagonal matrix
                out = in1.flatMapToPair(new RDDDiagV2MFunction(mcIn));
            } else {
                // matrix input: keep only the blocks selected by FilterDiagBlocksFunction
                out = in1.filter(new FilterDiagBlocksFunction()).mapToPair(new ReorgMapFunction(opcode));
            }
        }
        else if (opcode.equalsIgnoreCase("rsort")) {
            long col = ec.getScalarInput(this._col.getName(), this._col.getValueType(), this._col.isLiteral()).getLongValue();
            boolean desc = ec.getScalarInput(this._desc.getName(), this._desc.getValueType(), this._desc.isLiteral()).getBooleanValue();
            boolean ixret = ec.getScalarInput(this._ixret.getName(), this._ixret.getValueType(), this._ixret.isLiteral()).getBooleanValue();
            boolean singleCol = mcIn.getCols() == 1;

            // sort keys: the input itself for a single column, otherwise column `col`
            JavaPairRDD<MatrixIndexes, MatrixBlock> keys = in1;
            if (!singleCol) {
                keys = in1.filter(new IsBlockInRange(1L, mcIn.getRows(), col, col, mcIn))
                    .mapValues(new ExtractColumn(UtilFunctions.computeCellInBlock(col, mcIn.getColsPerBlock())));
            }

            if (ixret) {
                out = RDDSortUtils.sortIndexesByVal(keys, !desc, mcIn.getRows(), mcIn.getRowsPerBlock());
            } else if (singleCol && !desc) {
                out = RDDSortUtils.sortByVal(keys, mcIn.getRows(), mcIn.getRowsPerBlock());
            } else if (!this._bSortIndInMem) {
                out = RDDSortUtils.sortDataByVal(keys, in1, !desc, mcIn.getRows(), mcIn.getCols(), mcIn.getRowsPerBlock(), mcIn.getColsPerBlock());
            } else {
                out = RDDSortUtils.sortDataByValMemSort(keys, in1, !desc, mcIn.getRows(), mcIn.getCols(), mcIn.getRowsPerBlock(), mcIn.getColsPerBlock(), sec, (ReorgOperator)this._optr);
            }
        }
        else {
            throw new DMLRuntimeException("Error: Incorrect opcode in ReorgSPInstruction:" + opcode);
        }

        this.updateReorgMatrixCharacteristics(sec);
        sec.setRDDHandleForVariable(this.output.getName(), out);
        sec.addLineageRDD(this.output.getName(), this.input1.getName());
    }

    /**
     * Infers unknown output dimensions and non-zero count from the input's
     * matrix characteristics.
     *
     * @param sec Spark execution context holding both characteristics
     * @throws DMLRuntimeException if output dims are unknown and cannot be derived
     */
    private void updateReorgMatrixCharacteristics(SparkExecutionContext sec) throws DMLRuntimeException {
        MatrixCharacteristics mc1 = sec.getMatrixCharacteristics(this.input1.getName());
        MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(this.output.getName());
        String opcode = this.getOpcode();

        if (!mcOut.dimsKnown()) {
            if (!mc1.dimsKnown()) {
                throw new DMLRuntimeException("Unable to compute output matrix characteristics from input.");
            }
            if (opcode.equalsIgnoreCase("r'")) {
                mcOut.set(mc1.getCols(), mc1.getRows(), mc1.getColsPerBlock(), mc1.getRowsPerBlock());
            } else if (opcode.equalsIgnoreCase("rev")) {
                // reversal keeps the input shape
                mcOut.set(mc1.getRows(), mc1.getCols(), mc1.getRowsPerBlock(), mc1.getColsPerBlock());
            } else if (opcode.equalsIgnoreCase("rdiag")) {
                mcOut.set(mc1.getRows(), mc1.getCols() > 1 ? 1L : mc1.getRows(), mc1.getRowsPerBlock(), mc1.getColsPerBlock());
            } else if (opcode.equalsIgnoreCase("rsort")) {
                boolean ixret = sec.getScalarInput(this._ixret.getName(), this._ixret.getValueType(), this._ixret.isLiteral()).getBooleanValue();
                mcOut.set(mc1.getRows(), ixret ? 1L : mc1.getCols(), mc1.getRowsPerBlock(), mc1.getColsPerBlock());
            }
        }

        if (!mcOut.nnzKnown() && mc1.nnzKnown()) {
            boolean sortIx = opcode.equalsIgnoreCase("rsort")
                && sec.getScalarInput(this._ixret.getName(), this._ixret.getValueType(), this._ixret.isLiteral()).getBooleanValue();
            // diagonal extraction from a matrix keeps only part of the input's
            // non-zeros, so the input nnz is not the output nnz; leave it unknown
            boolean diagExtract = opcode.equalsIgnoreCase("rdiag") && mc1.getCols() > 1;
            if (sortIx) {
                mcOut.setNonZeros(mc1.getRows());
            } else if (!diagExtract) {
                mcOut.setNonZeros(mc1.getNonZeros());
            }
        }
    }

    /**
     * Slices a single column (block-local index) out of each block.
     */
    private static class ExtractColumn implements Function<MatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = -1472164797288449559L;
        private final int _col;

        public ExtractColumn(int col) {
            this._col = col;
        }

        @Override
        public MatrixBlock call(MatrixBlock arg0) throws Exception {
            return arg0.sliceOperations(0, arg0.getNumRows() - 1, this._col, this._col, new MatrixBlock());
        }
    }

    /**
     * Applies {@code LibMatrixReorg.rev} to one block; may emit several output blocks.
     */
    private static class RDDRevFunction implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 1183373828539843938L;
        private final MatrixCharacteristics _mcIn;

        public RDDRevFunction(MatrixCharacteristics mcIn) throws DMLRuntimeException {
            this._mcIn = mcIn;
        }

        @Override
        public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception {
            IndexedMatrixValue in = SparkUtils.toIndexedMatrixBlock(arg0);
            ArrayList<IndexedMatrixValue> out = new ArrayList<IndexedMatrixValue>();
            LibMatrixReorg.rev(in, this._mcIn.getRows(), this._mcIn.getRowsPerBlock(), out);
            return SparkUtils.fromIndexedMatrixBlock(out).iterator();
        }
    }

    /**
     * Expands one block of a column vector into row-block {@code rix} of the
     * diagonal output matrix: the diagonal block plus empty blocks for every
     * other column-block position.
     */
    private static class RDDDiagV2MFunction implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 31065772250744103L;
        private final ReorgOperator _reorgOp = new ReorgOperator(DiagIndex.getDiagIndexFnObject());
        private final MatrixCharacteristics _mcIn;

        public RDDDiagV2MFunction(MatrixCharacteristics mcIn) throws DMLRuntimeException {
            this._mcIn = mcIn;
        }

        @Override
        public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception {
            ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> ret = new ArrayList<Tuple2<MatrixIndexes, MatrixBlock>>();
            MatrixIndexes ixIn = arg0._1();
            MatrixBlock blkIn = arg0._2();
            long rix = ixIn.getRowIndex();

            // diagonal block at (rix, rix)
            MatrixIndexes ixOut = new MatrixIndexes(rix, rix);
            MatrixBlock blkOut = (MatrixBlock)blkIn.reorgOperations(this._reorgOp, new MatrixBlock(), -1, -1, -1);
            ret.add(new Tuple2<MatrixIndexes, MatrixBlock>(ixOut, blkOut));

            // empty off-diagonal blocks for the remaining column-block positions of row-block rix
            long rows = this._mcIn.getRows();
            int brlen = this._mcIn.getRowsPerBlock();
            int numBlocks = (int)Math.ceil((double)rows / (double)brlen);
            for (int i = 1; i <= numBlocks; ++i) {
                if (i == ixOut.getColumnIndex()) {
                    continue;
                }
                int lrlen = UtilFunctions.computeBlockSize(rows, rix, brlen);
                int lclen = UtilFunctions.computeBlockSize(rows, i, brlen);
                MatrixBlock emptyBlk = new MatrixBlock(lrlen, lclen, true);
                ret.add(new Tuple2<MatrixIndexes, MatrixBlock>(new MatrixIndexes(rix, i), emptyBlk));
            }
            return ret.iterator();
        }
    }
}

