/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.instructions.spark;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysml.hops.AggBinaryOp;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.functionobjects.Multiply;
import org.apache.sysml.runtime.functionobjects.Plus;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.spark.BinarySPInstruction;
import org.apache.sysml.runtime.instructions.spark.SPInstruction;
import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue;
import org.apache.sysml.runtime.matrix.operators.AggregateBinaryOperator;
import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
import org.apache.sysml.runtime.matrix.operators.Operator;
import scala.Tuple2;

/**
 * Spark instruction for the CPMM matrix-multiplication strategy.
 *
 * <p>Blocks of the left input are re-keyed by their column block index and blocks of the
 * right input by their row block index. The two keyed RDDs are joined on that common
 * dimension, every joined block pair is multiplied, and the partial products are summed
 * either into a single {@link MatrixBlock} or per output block index, depending on the
 * {@link AggBinaryOp.SparkAggType} parsed from the instruction string.
 */
public class CpmmSPInstruction
extends BinarySPInstruction {
    /** How the partial block products are aggregated (single block vs. keyed RDD). */
    private final AggBinaryOp.SparkAggType _aggtype;

    /**
     * Creates a CPMM instruction.
     *
     * @param op      the aggregate-binary operator (multiply / plus)
     * @param in1     left matrix operand
     * @param in2     right matrix operand
     * @param out     output operand
     * @param aggtype aggregation mode for the partial products
     * @param opcode  instruction opcode
     * @param istr    the original instruction string
     */
    public CpmmSPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, AggBinaryOp.SparkAggType aggtype, String opcode, String istr) {
        super(op, in1, in2, out, opcode, istr);
        this._sptype = SPInstruction.SPINSTRUCTION_TYPE.CPMM;
        this._aggtype = aggtype;
    }

    /**
     * Parses a {@code cpmm} instruction string.
     *
     * <p>Expected fields after the opcode: left input, right input, output, and the
     * {@link AggBinaryOp.SparkAggType} name. Additional trailing fields are ignored.
     *
     * @param str the instruction string
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is not {@code cpmm}, fields are missing,
     *                             or the aggregation type is not a valid enum constant
     */
    public static CpmmSPInstruction parseInstruction(String str) throws DMLRuntimeException {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        if (!opcode.equalsIgnoreCase("cpmm")) {
            throw new DMLRuntimeException("CpmmSPInstruction.parseInstruction():: Unknown opcode " + opcode);
        }
        // opcode + in1 + in2 + out + aggtype; fail with the declared exception type
        // instead of an ArrayIndexOutOfBoundsException on truncated instructions.
        if (parts.length < 5) {
            throw new DMLRuntimeException("CpmmSPInstruction.parseInstruction():: Expected at least 4 fields after opcode but got "
                    + (parts.length - 1) + " in instruction: " + str);
        }
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand out = new CPOperand(parts[3]);
        AggBinaryOp.SparkAggType aggtype;
        try {
            aggtype = AggBinaryOp.SparkAggType.valueOf(parts[4]);
        } catch (IllegalArgumentException ex) {
            // Enum.valueOf throws an unchecked exception; surface it as the declared type.
            throw new DMLRuntimeException("CpmmSPInstruction.parseInstruction():: Invalid aggregation type '"
                    + parts[4] + "' in instruction: " + str, ex);
        }
        return new CpmmSPInstruction(createMultiplyPlusOperator(), in1, in2, out, aggtype, opcode, str);
    }

    /**
     * Builds the matrix-multiply operator used by CPMM: element-wise multiply combined
     * with a plus aggregation whose initial value is 0.
     */
    private static AggregateBinaryOperator createMultiplyPlusOperator() {
        AggregateOperator agg = new AggregateOperator(0.0, Plus.getPlusFnObject());
        return new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg);
    }

    /**
     * Executes the CPMM multiply on the two input RDDs and stores the result.
     *
     * <p>With {@code SINGLE_BLOCK} aggregation, all partial products are summed into one
     * block and stored via {@code setMatrixOutput}; no RDD lineage is registered on that
     * path. Otherwise, the products are summed per output block index and the resulting
     * RDD is registered with lineage to both inputs, and the output matrix characteristics
     * are updated.
     *
     * @param ec the execution context; must be a {@link SparkExecutionContext}
     * @throws DMLRuntimeException if input handles cannot be obtained or outputs cannot be set
     */
    @Override
    public void processInstruction(ExecutionContext ec) throws DMLRuntimeException {
        SparkExecutionContext sec = (SparkExecutionContext)ec;
        JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable(this.input1.getName());
        JavaPairRDD<MatrixIndexes, MatrixBlock> in2 = sec.getBinaryBlockRDDHandleForVariable(this.input2.getName());

        // Key left blocks by column block index and right blocks by row block index so the
        // join pairs blocks that share the inner (common) dimension.
        JavaPairRDD<Long, IndexedMatrixValue> tmp1 = in1.mapToPair(new CpmmIndexFunction(true));
        JavaPairRDD<Long, IndexedMatrixValue> tmp2 = in2.mapToPair(new CpmmIndexFunction(false));
        JavaPairRDD<MatrixIndexes, MatrixBlock> out = tmp1
                .join(tmp2)
                .mapToPair(new CpmmMultiplyFunction());

        if (this._aggtype == AggBinaryOp.SparkAggType.SINGLE_BLOCK) {
            // NOTE(review): sumStable presumably performs a numerically stable sum — see RDDAggregateUtils.
            MatrixBlock out2 = RDDAggregateUtils.sumStable(out);
            sec.setMatrixOutput(this.output.getName(), out2);
        } else {
            // NOTE(review): meaning of the boolean flag is defined in RDDAggregateUtils — confirm there.
            out = RDDAggregateUtils.sumByKeyStable(out, false);
            sec.setRDDHandleForVariable(this.output.getName(), out);
            sec.addLineageRDD(this.output.getName(), this.input1.getName());
            sec.addLineageRDD(this.output.getName(), this.input2.getName());
            this.updateBinaryMMOutputMatrixCharacteristics(sec, true);
        }
    }

    /**
     * Multiplies one joined (left, right) block pair. The output block index is the
     * left block's row index and the right block's column index.
     */
    private static class CpmmMultiplyFunction
    implements PairFunction<Tuple2<Long, Tuple2<IndexedMatrixValue, IndexedMatrixValue>>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -2009255629093036642L;
        private final AggregateBinaryOperator _op;

        public CpmmMultiplyFunction() {
            this._op = createMultiplyPlusOperator();
        }

        @Override
        public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<Long, Tuple2<IndexedMatrixValue, IndexedMatrixValue>> arg0) throws Exception {
            IndexedMatrixValue left = arg0._2()._1();
            IndexedMatrixValue right = arg0._2()._2();
            MatrixBlock blkIn1 = (MatrixBlock)left.getValue();
            MatrixBlock blkIn2 = (MatrixBlock)right.getValue();

            MatrixBlock blkOut = new MatrixBlock();
            blkIn1.aggregateBinaryOperations(blkIn1, blkIn2, blkOut, this._op);

            MatrixIndexes ixOut = new MatrixIndexes();
            ixOut.setIndexes(left.getIndexes().getRowIndex(), right.getIndexes().getColumnIndex());
            return new Tuple2<MatrixIndexes, MatrixBlock>(ixOut, blkOut);
        }
    }

    /**
     * Re-keys an input block by the index of the join dimension: the column block index
     * for the left input, the row block index for the right input. The value carries the
     * original indexes together with a copy of the block.
     */
    private static class CpmmIndexFunction
    implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, Long, IndexedMatrixValue> {
        private static final long serialVersionUID = -1187183128301671162L;
        private final boolean _left;

        public CpmmIndexFunction(boolean left) {
            this._left = left;
        }

        @Override
        public Tuple2<Long, IndexedMatrixValue> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception {
            MatrixIndexes ix = arg0._1();
            IndexedMatrixValue value = new IndexedMatrixValue();
            // A copy of the block is stored rather than the input instance
            // (presumably to avoid aliasing blocks of the source RDD).
            value.set(ix, new MatrixBlock(arg0._2()));
            Long key = this._left ? ix.getColumnIndex() : ix.getRowIndex();
            return new Tuple2<Long, IndexedMatrixValue>(key, value);
        }
    }
}

