/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.instructions.spark;

import java.util.ArrayList;
import java.util.Iterator;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysml.hops.AggBinaryOp;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.lops.MapMult;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysml.runtime.functionobjects.Multiply;
import org.apache.sysml.runtime.functionobjects.Plus;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.spark.BinarySPInstruction;
import org.apache.sysml.runtime.instructions.spark.SPInstruction;
import org.apache.sysml.runtime.instructions.spark.data.LazyIterableIterator;
import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast;
import org.apache.sysml.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction;
import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysml.runtime.matrix.operators.AggregateBinaryOperator;
import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
import org.apache.sysml.runtime.matrix.operators.Operator;
import scala.Tuple2;

/**
 * Spark instruction for the {@code mapmm} opcode: a matrix multiplication in which one operand
 * is distributed as a {@link PartitionedBroadcast} and the other is processed as a blocked RDD of
 * ({@link MatrixIndexes}, {@link MatrixBlock}) pairs.
 *
 * <p>The {@link MapMult.CacheType} selects which operand is broadcast: for a right-cached type
 * the RDD is {@code input1} and the broadcast is {@code input2}; otherwise the roles are swapped.
 * The {@link AggBinaryOp.SparkAggType} selects how partial block products are combined:
 * {@code SINGLE_BLOCK} sums all of them into one in-memory matrix, {@code MULTI_BLOCK} sums them
 * per output block index, and any other value leaves the block products unaggregated.
 */
public class MapmmSPInstruction
extends BinarySPInstruction {
    /** Upper bound (1GB) on the estimated output size per partition before repartitioning. */
    private static final long MAX_OUTPUT_PARTITION_SIZE = 1024L * 1024 * 1024;

    /** Which operand is broadcast; processInstruction may flip a local copy of it. */
    private final MapMult.CacheType _type;
    /** When false, empty blocks are filtered out of the rdd input (and of the block output). */
    private final boolean _outputEmpty;
    /** How partial block products are aggregated. */
    private final AggBinaryOp.SparkAggType _aggtype;

    public MapmmSPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, MapMult.CacheType type, boolean outputEmpty, AggBinaryOp.SparkAggType aggtype, String opcode, String istr) {
        super(op, in1, in2, out, opcode, istr);
        this._sptype = SPInstruction.SPINSTRUCTION_TYPE.MAPMM;
        this._type = type;
        this._outputEmpty = outputEmpty;
        this._aggtype = aggtype;
    }

    /**
     * Parses a {@code mapmm} instruction string.
     *
     * <p>Expected fields after splitting: opcode, in1, in2, out, cache type, output-empty flag,
     * spark aggregation type.
     *
     * @param str the serialized instruction
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is not {@code mapmm} or fields are missing
     */
    public static MapmmSPInstruction parseInstruction(String str) throws DMLRuntimeException {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        if (!opcode.equalsIgnoreCase("mapmm")) {
            throw new DMLRuntimeException("MapmmSPInstruction.parseInstruction():: Unknown opcode " + opcode);
        }
        // fail with a descriptive error instead of an ArrayIndexOutOfBoundsException
        if (parts.length < 7) {
            throw new DMLRuntimeException("MapmmSPInstruction.parseInstruction():: Invalid number of fields (" + parts.length + ") in instruction: " + str);
        }
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand out = new CPOperand(parts[3]);
        MapMult.CacheType type = MapMult.CacheType.valueOf(parts[4]);
        boolean outputEmpty = Boolean.parseBoolean(parts[5]);
        AggBinaryOp.SparkAggType aggtype = AggBinaryOp.SparkAggType.valueOf(parts[6]);
        return new MapmmSPInstruction(createMatMultOperator(), in1, in2, out, type, outputEmpty, aggtype, opcode, str);
    }

    /**
     * Executes the broadcast-based matrix multiplication.
     *
     * <p>If the rdd input would need repartitioning to bound output partition sizes, the rdd and
     * broadcast roles are swapped when that yields a larger repartitioning target. The result is
     * either a single in-memory matrix ({@code SINGLE_BLOCK}) or a blocked RDD registered with
     * lineage to both inputs.
     */
    @Override
    public void processInstruction(ExecutionContext ec) throws DMLRuntimeException {
        SparkExecutionContext sec = (SparkExecutionContext)ec;
        MapMult.CacheType type = this._type;
        String rddVar = type.isRight() ? this.input1.getName() : this.input2.getName();
        String bcastVar = type.isRight() ? this.input2.getName() : this.input1.getName();
        MatrixCharacteristics mcRdd = sec.getMatrixCharacteristics(rddVar);
        MatrixCharacteristics mcBc = sec.getMatrixCharacteristics(bcastVar);
        JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable(rddVar);

        // flip rdd and broadcast inputs if that allows more (hence smaller) output partitions
        if (requiresFlatMapFunction(type, mcBc)
            && requiresRepartitioning(type, mcRdd, mcBc, in1.getNumPartitions())) {
            int numParts = getNumRepartitioning(type, mcRdd, mcBc);
            int numParts2 = getNumRepartitioning(type.getFlipped(), mcBc, mcRdd);
            if (numParts2 > numParts) {
                type = type.getFlipped();
                rddVar = type.isRight() ? this.input1.getName() : this.input2.getName();
                bcastVar = type.isRight() ? this.input2.getName() : this.input1.getName();
                mcRdd = sec.getMatrixCharacteristics(rddVar);
                mcBc = sec.getMatrixCharacteristics(bcastVar);
                in1 = sec.getBinaryBlockRDDHandleForVariable(rddVar);
                LOG.warn("Mapmm: Switching rdd ('" + bcastVar + "') and broadcast ('" + rddVar + "') inputs for repartitioning because this allows better control of output partition sizes (" + numParts + " < " + numParts2 + ").");
            }
        }

        PartitionedBroadcast<MatrixBlock> in2 = sec.getBroadcastForVariable(bcastVar);
        if (!this._outputEmpty) {
            in1 = in1.filter(new FilterNonEmptyBlocksFunction());
        }

        if (this._aggtype == AggBinaryOp.SparkAggType.SINGLE_BLOCK) {
            // all partial products are summed into one in-memory output matrix
            JavaRDD<MatrixBlock> out = in1.map(new RDDMapMMFunction2(type, in2));
            MatrixBlock out2 = RDDAggregateUtils.sumStable(out);
            sec.setMatrixOutput(this.output.getName(), out2);
            return;
        }

        JavaPairRDD<MatrixIndexes, MatrixBlock> out;
        if (requiresFlatMapFunction(type, mcBc)) {
            if (requiresRepartitioning(type, mcRdd, mcBc, in1.getNumPartitions())) {
                int numParts = getNumRepartitioning(type, mcRdd, mcBc);
                LOG.warn("Mapmm: Repartition input rdd '" + rddVar + "' from " + in1.getNumPartitions() + " to " + numParts + " partitions to satisfy size restrictions of output partitions.");
                in1 = in1.repartition(numParts);
            }
            out = in1.flatMapToPair(new RDDFlatMapMMFunction(type, in2));
        } else if (preservesPartitioning(mcRdd, type)) {
            // output block index equals input block index, so the partitioner is kept
            out = in1.mapPartitionsToPair(new RDDMapMMPartitionFunction(type, in2), true);
        } else {
            out = in1.mapToPair(new RDDMapMMFunction(type, in2));
        }
        if (!this._outputEmpty) {
            out = out.filter(new FilterNonEmptyBlocksFunction());
        }
        if (this._aggtype == AggBinaryOp.SparkAggType.MULTI_BLOCK) {
            out = RDDAggregateUtils.sumByKeyStable(out, false);
        }
        sec.setRDDHandleForVariable(this.output.getName(), out);
        sec.addLineageRDD(this.output.getName(), rddVar);
        sec.addLineageBroadcast(this.output.getName(), bcastVar);
        this.updateBinaryMMOutputMatrixCharacteristics(sec, true);
    }

    /** Creates the (multiply, plus) aggregate-binary operator, i.e. a standard matrix product. */
    private static AggregateBinaryOperator createMatMultOperator() {
        AggregateOperator agg = new AggregateOperator(0.0, Plus.getPlusFnObject());
        return new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg);
    }

    /**
     * Returns true if every output block keeps the index of its rdd input block, i.e. the rdd
     * input has a single row block (LEFT) or a single column block (otherwise) and known dims.
     */
    private static boolean preservesPartitioning(MatrixCharacteristics mcIn, MapMult.CacheType type) {
        if (type == MapMult.CacheType.LEFT) {
            return mcIn.dimsKnown() && mcIn.getRows() <= mcIn.getRowsPerBlock();
        }
        return mcIn.dimsKnown() && mcIn.getCols() <= mcIn.getColsPerBlock();
    }

    /**
     * Returns true if one rdd block produces several output blocks, i.e. the broadcast spans
     * more than one row block (LEFT) or more than one column block (RIGHT).
     */
    private static boolean requiresFlatMapFunction(MapMult.CacheType type, MatrixCharacteristics mcBc) {
        return (type == MapMult.CacheType.LEFT && mcBc.getRows() > mcBc.getRowsPerBlock())
            || (type == MapMult.CacheType.RIGHT && mcBc.getCols() > mcBc.getColsPerBlock());
    }

    /**
     * Estimates the output size for the given cache type, assuming a dense (sparsity 1.0)
     * result, i.e. a worst-case estimate.
     */
    private static long estimateOutputSize(MapMult.CacheType type, MatrixCharacteristics mcRdd, MatrixCharacteristics mcBc) {
        boolean isLeft = type == MapMult.CacheType.LEFT;
        return OptimizerUtils.estimatePartitionedSizeExactSparsity(
            isLeft ? mcBc.getRows() : mcRdd.getRows(),
            isLeft ? mcRdd.getCols() : mcBc.getCols(),
            isLeft ? mcBc.getRowsPerBlock() : mcRdd.getRowsPerBlock(),
            isLeft ? mcRdd.getColsPerBlock() : mcBc.getColsPerBlock(),
            1.0);
    }

    /**
     * Returns true if the rdd input is a single row/column block (an outer-product-like shape),
     * both inputs have known dims, and the estimated output per current partition exceeds
     * {@link #MAX_OUTPUT_PARTITION_SIZE}.
     */
    private static boolean requiresRepartitioning(MapMult.CacheType type, MatrixCharacteristics mcRdd, MatrixCharacteristics mcBc, int numPartitions) {
        // check known dims first so no size estimate is computed from unknown (-1) dims
        if (!mcRdd.dimsKnown() || !mcBc.dimsKnown()) {
            return false;
        }
        boolean isLeft = type == MapMult.CacheType.LEFT;
        boolean isOuter = isLeft
            ? mcRdd.getRows() <= mcRdd.getRowsPerBlock()
            : mcRdd.getCols() <= mcRdd.getColsPerBlock();
        // guard against an rdd with zero partitions, which would otherwise divide by zero
        long sizePerPartition = estimateOutputSize(type, mcRdd, mcBc) / Math.max(numPartitions, 1);
        return isOuter && sizePerPartition > MAX_OUTPUT_PARTITION_SIZE;
    }

    /**
     * Computes the target number of rdd partitions: the estimated output size divided by the
     * HDFS block size, capped by the number of rdd column blocks (LEFT) or row blocks (otherwise).
     */
    private static int getNumRepartitioning(MapMult.CacheType type, MatrixCharacteristics mcRdd, MatrixCharacteristics mcBc) {
        long numParts = estimateOutputSize(type, mcRdd, mcBc) / InfrastructureAnalyzer.getHDFSBlockSize();
        long maxParts = type == MapMult.CacheType.LEFT ? mcRdd.getNumColBlocks() : mcRdd.getNumRowBlocks();
        return (int)Math.min(numParts, maxParts);
    }

    /**
     * Multiplies one rdd block with every matching broadcast block, emitting one output block
     * per broadcast row block (LEFT) or column block (RIGHT).
     */
    private static class RDDFlatMapMMFunction
    implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -6076256569118957281L;
        private final MapMult.CacheType _type;
        private final AggregateBinaryOperator _op;
        private final PartitionedBroadcast<MatrixBlock> _pbc;

        public RDDFlatMapMMFunction(MapMult.CacheType type, PartitionedBroadcast<MatrixBlock> binput) {
            this._type = type;
            this._pbc = binput;
            this._op = createMatMultOperator();
        }

        @Override
        public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception {
            MatrixIndexes ixIn = arg0._1();
            MatrixBlock blkIn = arg0._2();
            ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> ret;
            if (this._type == MapMult.CacheType.LEFT) {
                // left operand: broadcast blocks (i, rowIndex of rdd block) for all row blocks i
                int len = this._pbc.getNumRowBlocks();
                ret = new ArrayList<Tuple2<MatrixIndexes, MatrixBlock>>(len);
                for (int i = 1; i <= len; ++i) {
                    MatrixBlock left = this._pbc.getBlock(i, (int)ixIn.getRowIndex());
                    MatrixIndexes ixOut = new MatrixIndexes();
                    MatrixBlock blkOut = new MatrixBlock();
                    OperationsOnMatrixValues.performAggregateBinary(new MatrixIndexes(i, ixIn.getRowIndex()), left, ixIn, blkIn, ixOut, blkOut, this._op);
                    ret.add(new Tuple2<MatrixIndexes, MatrixBlock>(ixOut, blkOut));
                }
            } else {
                // right operand: broadcast blocks (colIndex of rdd block, j) for all column blocks j
                int len = this._pbc.getNumColumnBlocks();
                ret = new ArrayList<Tuple2<MatrixIndexes, MatrixBlock>>(len);
                for (int j = 1; j <= len; ++j) {
                    MatrixBlock right = this._pbc.getBlock((int)ixIn.getColumnIndex(), j);
                    MatrixIndexes ixOut = new MatrixIndexes();
                    MatrixBlock blkOut = new MatrixBlock();
                    OperationsOnMatrixValues.performAggregateBinary(ixIn, blkIn, new MatrixIndexes(ixIn.getColumnIndex(), j), right, ixOut, blkOut, this._op);
                    ret.add(new Tuple2<MatrixIndexes, MatrixBlock>(ixOut, blkOut));
                }
            }
            return ret.iterator();
        }
    }

    /**
     * Partition-wise multiplication that keeps each input block index as the output index, so
     * the caller can preserve the rdd partitioner.
     */
    private static class RDDMapMMPartitionFunction
    implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 1886318890063064287L;
        private final MapMult.CacheType _type;
        private final AggregateBinaryOperator _op;
        private final PartitionedBroadcast<MatrixBlock> _pbc;

        public RDDMapMMPartitionFunction(MapMult.CacheType type, PartitionedBroadcast<MatrixBlock> binput) {
            this._type = type;
            this._pbc = binput;
            this._op = createMatMultOperator();
        }

        @Override
        public LazyIterableIterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> arg0) throws Exception {
            return new MapMMPartitionIterator(arg0);
        }

        /** Lazily multiplies each block of the partition with its single matching broadcast block. */
        private class MapMMPartitionIterator
        extends LazyIterableIterator<Tuple2<MatrixIndexes, MatrixBlock>> {
            public MapMMPartitionIterator(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> in) {
                super(in);
            }

            @Override
            protected Tuple2<MatrixIndexes, MatrixBlock> computeNext(Tuple2<MatrixIndexes, MatrixBlock> arg) throws Exception {
                MatrixIndexes ixIn = arg._1();
                MatrixBlock blkIn = arg._2();
                MatrixBlock blkOut = new MatrixBlock();
                if (_type == MapMult.CacheType.LEFT) {
                    MatrixBlock left = _pbc.getBlock(1, (int)ixIn.getRowIndex());
                    left.aggregateBinaryOperations(left, blkIn, blkOut, _op);
                } else {
                    MatrixBlock right = _pbc.getBlock((int)ixIn.getColumnIndex(), 1);
                    blkIn.aggregateBinaryOperations(blkIn, right, blkOut, _op);
                }
                // index-preserving: output block reuses the input block index
                return new Tuple2<MatrixIndexes, MatrixBlock>(ixIn, blkOut);
            }
        }
    }

    /**
     * Multiplies one rdd block with its single matching broadcast block and drops the index;
     * used when all partial products are summed into one output matrix.
     */
    private static class RDDMapMMFunction2
    implements Function<Tuple2<MatrixIndexes, MatrixBlock>, MatrixBlock> {
        private static final long serialVersionUID = -2753453898072910182L;
        private final MapMult.CacheType _type;
        private final AggregateBinaryOperator _op;
        private final PartitionedBroadcast<MatrixBlock> _pbc;

        public RDDMapMMFunction2(MapMult.CacheType type, PartitionedBroadcast<MatrixBlock> binput) {
            this._type = type;
            this._pbc = binput;
            this._op = createMatMultOperator();
        }

        @Override
        public MatrixBlock call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception {
            MatrixIndexes ixIn = arg0._1();
            MatrixBlock blkIn = arg0._2();
            if (this._type == MapMult.CacheType.LEFT) {
                MatrixBlock left = this._pbc.getBlock(1, (int)ixIn.getRowIndex());
                return (MatrixBlock)OperationsOnMatrixValues.performAggregateBinaryIgnoreIndexes(left, blkIn, new MatrixBlock(), this._op);
            }
            MatrixBlock right = this._pbc.getBlock((int)ixIn.getColumnIndex(), 1);
            return (MatrixBlock)OperationsOnMatrixValues.performAggregateBinaryIgnoreIndexes(blkIn, right, new MatrixBlock(), this._op);
        }
    }

    /**
     * Multiplies one rdd block with its single matching broadcast block (the broadcast has one
     * row block for LEFT or one column block otherwise), emitting one indexed output block.
     */
    private static class RDDMapMMFunction
    implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 8197406787010296291L;
        private final MapMult.CacheType _type;
        private final AggregateBinaryOperator _op;
        private final PartitionedBroadcast<MatrixBlock> _pbc;

        public RDDMapMMFunction(MapMult.CacheType type, PartitionedBroadcast<MatrixBlock> binput) {
            this._type = type;
            this._pbc = binput;
            this._op = createMatMultOperator();
        }

        @Override
        public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception {
            MatrixIndexes ixIn = arg0._1();
            MatrixBlock blkIn = arg0._2();
            MatrixIndexes ixOut = new MatrixIndexes();
            MatrixBlock blkOut = new MatrixBlock();
            if (this._type == MapMult.CacheType.LEFT) {
                MatrixBlock left = this._pbc.getBlock(1, (int)ixIn.getRowIndex());
                OperationsOnMatrixValues.performAggregateBinary(new MatrixIndexes(1L, ixIn.getRowIndex()), left, ixIn, blkIn, ixOut, blkOut, this._op);
            } else {
                MatrixBlock right = this._pbc.getBlock((int)ixIn.getColumnIndex(), 1);
                OperationsOnMatrixValues.performAggregateBinary(ixIn, blkIn, new MatrixIndexes(ixIn.getColumnIndex(), 1L), right, ixOut, blkOut, this._op);
            }
            return new Tuple2<MatrixIndexes, MatrixBlock>(ixOut, blkOut);
        }
    }
}

