/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.instructions.spark;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.functionobjects.ReduceAll;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.cp.DoubleObject;
import org.apache.sysml.runtime.instructions.spark.ComputationSPInstruction;
import org.apache.sysml.runtime.instructions.spark.SPInstruction;
import org.apache.sysml.runtime.instructions.spark.functions.AggregateDropCorrectionFunction;
import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.operators.AggregateTernaryOperator;
import org.apache.sysml.runtime.matrix.operators.Operator;
import scala.Tuple2;

/**
 * Spark instruction for aggregate ternary operations, accepting the opcodes
 * {@code tak+*} and {@code tack+*}.
 *
 * <p>The first two inputs are always binary-block RDDs. The third input is either a
 * binary-block RDD or a literal; when it is a literal no third RDD is fetched and
 * {@code null} is handed to the block-level operation in its place.
 */
public class AggregateTernarySPInstruction extends ComputationSPInstruction {
    private AggregateTernarySPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr) {
        super(SPInstruction.SPType.AggregateTernary, op, in1, in2, in3, out, opcode, istr);
    }

    /**
     * Builds an instruction from its string encoding.
     *
     * <p>The opcode is read from the first part; parts 1 to 3 are the input operands and
     * part 4 is the output operand.
     *
     * @param str the encoded instruction
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is neither {@code tak+*} nor {@code tack+*}
     *         (compared case-insensitively)
     */
    public static AggregateTernarySPInstruction parseInstruction(String str) throws DMLRuntimeException {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        if (opcode.equalsIgnoreCase("tak+*") || opcode.equalsIgnoreCase("tack+*")) {
            InstructionUtils.checkNumFields(parts, 4);
            CPOperand in1 = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand in3 = new CPOperand(parts[3]);
            CPOperand out = new CPOperand(parts[4]);
            AggregateTernaryOperator op = InstructionUtils.parseAggregateTernaryOperator(opcode);
            return new AggregateTernarySPInstruction(op, in1, in2, in3, out, opcode, str);
        }
        throw new DMLRuntimeException("AggregateTernaryInstruction.parseInstruction():: Unknown opcode " + opcode);
    }

    /**
     * Executes the instruction: joins the input RDDs on their block indexes, computes one
     * partial result per joined entry, and then aggregates the partial results.
     *
     * <p>The aggregation takes one of three paths:
     * <ul>
     *   <li>index function is {@link ReduceAll}: all partial blocks are summed and cell
     *       (0,0) of the sum is stored as a {@link DoubleObject} output variable;</li>
     *   <li>input dimensions are known and the column count of the first input does not
     *       exceed its columns-per-block: the partials are aggregated into one block,
     *       the correction rows/columns are dropped, and the block is set as matrix
     *       output;</li>
     *   <li>otherwise: the partials are aggregated by key, the correction is dropped per
     *       block, and the resulting RDD handle is registered together with lineage to
     *       each RDD input.</li>
     * </ul>
     *
     * @param ec execution context; must be a {@link SparkExecutionContext}
     * @throws DMLRuntimeException declared by this method's signature; NOTE(review): the
     *         conditions under which the called helpers raise it are not visible here
     */
    @Override
    public void processInstruction(ExecutionContext ec) throws DMLRuntimeException {
        SparkExecutionContext sec = (SparkExecutionContext)ec;

        // Fetch inputs; a literal third operand has no RDD and is represented by null.
        MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(this.input1.getName());
        JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable(this.input1.getName());
        JavaPairRDD<MatrixIndexes, MatrixBlock> in2 = sec.getBinaryBlockRDDHandleForVariable(this.input2.getName());
        JavaPairRDD<MatrixIndexes, MatrixBlock> in3 = this.input3.isLiteral() ? null : sec.getBinaryBlockRDDHandleForVariable(this.input3.getName());

        // Compute per-entry partial results over the joined inputs.
        AggregateTernaryOperator aggop = (AggregateTernaryOperator)this._optr;
        JavaPairRDD<MatrixIndexes, MatrixBlock> out;
        if (in3 != null) {
            out = in1.join(in2).join(in3).mapToPair(new RDDAggregateTernaryFunction(aggop));
        } else {
            out = in1.join(in2).mapToPair(new RDDAggregateTernaryFunction2(aggop));
        }

        // Aggregate the partial results.
        if (aggop.indexFn instanceof ReduceAll) {
            // Scalar result: no RDD handle and no lineage are registered on this path.
            MatrixBlock tmp = RDDAggregateUtils.sumStable(out.values());
            DoubleObject ret = new DoubleObject(tmp.getValue(0, 0));
            sec.setVariable(this.output.getName(), ret);
        } else if (mcIn.dimsKnown() && mcIn.getCols() <= mcIn.getColsPerBlock()) {
            // Output fits in a single column block: aggregate into one in-memory block.
            MatrixBlock ret = RDDAggregateUtils.aggStable(out, aggop.aggOp);
            ret.dropLastRowsOrColumns(aggop.aggOp.correctionLocation);
            sec.setMatrixOutput(this.output.getName(), ret, this.getExtendedOpcode());
        } else {
            // Multiple column blocks: keep the result distributed and aggregate per key.
            out = RDDAggregateUtils.aggByKeyStable(out, aggop.aggOp, false);
            out = out.mapValues(new AggregateDropCorrectionFunction(aggop.aggOp));
            this.updateUnaryAggOutputMatrixCharacteristics(sec, aggop.indexFn);
            sec.setRDDHandleForVariable(this.output.getName(), out);
            sec.addLineageRDD(this.output.getName(), this.input1.getName());
            sec.addLineageRDD(this.output.getName(), this.input2.getName());
            if (in3 != null) {
                sec.addLineageRDD(this.output.getName(), this.input3.getName());
            }
        }
    }

    /**
     * Partial aggregation for the two-RDD case (third operand is a literal): applies the
     * ternary aggregate to the two joined blocks with {@code null} as third block.
     */
    private static class RDDAggregateTernaryFunction2
    implements PairFunction<Tuple2<MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock>>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -6615412819746331700L;
        private final AggregateTernaryOperator _aggop;

        public RDDAggregateTernaryFunction2(AggregateTernaryOperator aggop) {
            this._aggop = aggop;
        }

        /**
         * @param arg0 block index paired with the two joined input blocks
         * @return the partial result, keyed by a {@link MatrixIndexes} built from 1 and
         *         the column index of the input key
         */
        @Override
        public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock>> arg0) throws Exception {
            MatrixIndexes ix = arg0._1();
            MatrixBlock in1 = arg0._2()._1();
            MatrixBlock in2 = arg0._2()._2();
            return new Tuple2<MatrixIndexes, MatrixBlock>(new MatrixIndexes(1L, ix.getColumnIndex()), in1.aggregateTernaryOperations(in1, in2, null, new MatrixBlock(), this._aggop, false));
        }
    }

    /**
     * Partial aggregation for the three-RDD case: applies the ternary aggregate to the
     * three joined blocks.
     */
    private static class RDDAggregateTernaryFunction
    implements PairFunction<Tuple2<MatrixIndexes, Tuple2<Tuple2<MatrixBlock, MatrixBlock>, MatrixBlock>>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 6410232464410434210L;
        private final AggregateTernaryOperator _aggop;

        public RDDAggregateTernaryFunction(AggregateTernaryOperator aggop) {
            this._aggop = aggop;
        }

        /**
         * @param arg0 block index paired with the three joined input blocks, nested as
         *        ((first, second), third)
         * @return the partial result, keyed by a {@link MatrixIndexes} built from 1 and
         *         the column index of the input key
         */
        @Override
        public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, Tuple2<Tuple2<MatrixBlock, MatrixBlock>, MatrixBlock>> arg0) throws Exception {
            MatrixIndexes ix = arg0._1();
            MatrixBlock in1 = arg0._2()._1()._1();
            MatrixBlock in2 = arg0._2()._1()._2();
            MatrixBlock in3 = arg0._2()._2();
            return new Tuple2<MatrixIndexes, MatrixBlock>(new MatrixIndexes(1L, ix.getColumnIndex()), in1.aggregateTernaryOperations(in1, in2, in3, new MatrixBlock(), this._aggop, false));
        }
    }
}

