/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.instructions.spark;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.stream.IntStream;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.lops.PartialAggregate;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.codegen.CodegenUtils;
import org.apache.sysml.runtime.codegen.LibSpoofPrimitives;
import org.apache.sysml.runtime.codegen.SpoofCellwise;
import org.apache.sysml.runtime.codegen.SpoofMultiAggregate;
import org.apache.sysml.runtime.codegen.SpoofOperator;
import org.apache.sysml.runtime.codegen.SpoofOuterProduct;
import org.apache.sysml.runtime.codegen.SpoofRowwise;
import org.apache.sysml.runtime.controlprogram.caching.CacheableData;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.functionobjects.Builtin;
import org.apache.sysml.runtime.functionobjects.KahanPlus;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.cp.DoubleObject;
import org.apache.sysml.runtime.instructions.cp.ScalarObject;
import org.apache.sysml.runtime.instructions.spark.SPInstruction;
import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast;
import org.apache.sysml.runtime.instructions.spark.functions.ReplicateBlockFunction;
import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
import scala.Tuple2;

public class SpoofSPInstruction
extends SPInstruction {
    /** Generated operator class resolved while parsing the instruction. */
    private final Class<?> _class;
    /** Byte code of the generated class; passed on to the Spark functions together with the class name. */
    private final byte[] _classBytes;
    /** Input operands (matrices and scalars) in instruction order. */
    private final CPOperand[] _in;
    /** Output operand (matrix or scalar). */
    private final CPOperand _out;

    /**
     * Creates a fused-operator Spark instruction; instances are obtained via
     * {@link #parseInstruction(String)}.
     *
     * @param cls        generated operator class
     * @param classBytes byte code of {@code cls}
     * @param in         input operands
     * @param out        output operand
     * @param opcode     instruction opcode
     * @param str        full instruction string
     */
    private SpoofSPInstruction(Class<?> cls, byte[] classBytes, CPOperand[] in, CPOperand out, String opcode, String str) {
        super(SPInstruction.SPType.SpoofFused, opcode, str);
        _in = in;
        _out = out;
        _class = cls;
        _classBytes = classBytes;
    }

    /**
     * Parses a fused-operator instruction string.
     * <p>
     * Token layout as consumed here: {@code [0]} opcode prefix, {@code [1]} name of the
     * generated class, {@code [2 .. n-3]} input operands, {@code [n-2]} output operand.
     * The last token is not read by this method.
     *
     * @param str instruction string
     * @return the parsed instruction
     * @throws DMLRuntimeException declared by the signature; propagated from the helpers called here
     */
    public static SpoofSPInstruction parseInstruction(String str) throws DMLRuntimeException {
        final String[] tokens = InstructionUtils.getInstructionPartsWithValueType(str);
        final String className = tokens[1];

        // resolve the generated class and derive the extended opcode from its spoof type
        final Class<?> generated = CodegenUtils.getClass(className);
        final byte[] byteCode = CodegenUtils.getClassData(className);
        final String fusedOpcode = tokens[0] + CodegenUtils.createInstance(generated).getSpoofType();

        // everything between the class name and the output operand is an input
        final int outPos = tokens.length - 2;
        ArrayList<CPOperand> operands = new ArrayList<CPOperand>();
        int pos = 2;
        while (pos < outPos) {
            operands.add(new CPOperand(tokens[pos]));
            pos++;
        }
        CPOperand result = new CPOperand(tokens[outPos]);

        CPOperand[] inputs = operands.toArray(new CPOperand[0]);
        return new SpoofSPInstruction(generated, byteCode, inputs, result, fusedOpcode, str);
    }

    /**
     * Executes the generated (fused) operator on Spark.
     * <p>
     * The method decides which matrix inputs are broadcast, joins the remaining matrix
     * inputs into one RDD of block arrays, and then dispatches on the superclass of the
     * generated class (cellwise, multi-aggregate, outer-product, or rowwise). Matrix
     * results are registered as RDD handles (with lineage to the inputs) or as matrix
     * outputs; scalar results are bound as {@link DoubleObject}s.
     *
     * @param ec execution context; cast to {@link SparkExecutionContext}
     * @throws DMLRuntimeException if a rowwise operator is applied to an input with more
     *         columns than fit in one block, or if the operator superclass is not handled here
     */
    @Override
    public void processInstruction(ExecutionContext ec) throws DMLRuntimeException {
        SparkExecutionContext sec = (SparkExecutionContext)ec;
        // bcVect is indexed like _in (including scalars); bcVect2 only covers the matrix inputs
        boolean[] bcVect = SpoofSPInstruction.determineBroadcastInputs(sec, this._in);
        boolean[] bcVect2 = SpoofSPInstruction.getMatrixBroadcastVector(sec, this._in, bcVect);
        // main input = first matrix input that is not broadcast
        int main = SpoofSPInstruction.getMainInputIndex(this._in, bcVect);
        MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(this._in[main].getName());
        // join all non-broadcast matrix inputs on block index; last flag marks outer-product operators
        JavaPairRDD<MatrixIndexes, MatrixBlock[]> in = SpoofSPInstruction.createJoinedInputRDD(sec, this._in, bcVect, this._class.getSuperclass() == SpoofOuterProduct.class);
        JavaPairRDD<MatrixIndexes, MatrixBlock> out = null;
        // collect broadcast handles for broadcast matrices and the values of all scalar inputs
        ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices = new ArrayList<PartitionedBroadcast<MatrixBlock>>();
        ArrayList<ScalarObject> scalars = new ArrayList<ScalarObject>();
        for (int i = 0; i < this._in.length; ++i) {
            if (this._in[i].getDataType() == Expression.DataType.MATRIX && bcVect[i]) {
                bcMatrices.add(sec.getBroadcastForVariable(this._in[i].getName()));
                continue;
            }
            if (this._in[i].getDataType() != Expression.DataType.SCALAR) continue;
            scalars.add(sec.getScalarInput(this._in[i].getName(), this._in[i].getValueType(), this._in[i].isLiteral()));
        }
        if (this._class.getSuperclass() == SpoofCellwise.class) {
            // --- cellwise operator ---
            SpoofCellwise op = (SpoofCellwise)CodegenUtils.createInstance(this._class);
            AggregateOperator aggop = SpoofSPInstruction.getAggregateOperator(op.getAggOp());
            if (this._out.getDataType() == Expression.DataType.MATRIX) {
                out = in.mapPartitionsToPair((PairFlatMapFunction)new CellwiseFunction(this._class.getName(), this._classBytes, bcVect2, bcMatrices, scalars), true);
                // row/column aggregates spanning several blocks in the aggregated dimension
                // produce multiple partial blocks per output key and need a final aggregation
                if (op.getCellType() == SpoofCellwise.CellType.ROW_AGG && mcIn.getCols() > (long)mcIn.getColsPerBlock() || op.getCellType() == SpoofCellwise.CellType.COL_AGG && mcIn.getRows() > (long)mcIn.getRowsPerBlock()) {
                    long numBlocks = op.getCellType() == SpoofCellwise.CellType.ROW_AGG ? mcIn.getNumRowBlocks() : mcIn.getNumColBlocks();
                    out = RDDAggregateUtils.aggByKeyStable(out, aggop, (int)Math.min((long)out.getNumPartitions(), numBlocks), false);
                }
                sec.setRDDHandleForVariable(this._out.getName(), out);
                SpoofSPInstruction.maintainLineageInfo(sec, this._in, bcVect, this._out);
                this.updateOutputMatrixCharacteristics(sec, op);
            } else {
                // scalar output: aggregate all partial blocks into one block and read cell (0,0)
                out = in.mapPartitionsToPair((PairFlatMapFunction)new CellwiseFunction(this._class.getName(), this._classBytes, bcVect2, bcMatrices, scalars), true);
                MatrixBlock tmpMB = RDDAggregateUtils.aggStable(out, aggop);
                sec.setVariable(this._out.getName(), new DoubleObject(tmpMB.getValue(0, 0)));
            }
        } else if (this._class.getSuperclass() == SpoofMultiAggregate.class) {
            // --- multi-aggregate operator: per-block partial results folded into a single block ---
            SpoofMultiAggregate op = (SpoofMultiAggregate)CodegenUtils.createInstance(this._class);
            SpoofCellwise.AggOp[] aggOps = op.getAggOps();
            MatrixBlock tmpMB = (MatrixBlock)in.mapToPair((PairFunction)new MultiAggregateFunction(this._class.getName(), this._classBytes, bcVect2, bcMatrices, scalars)).values().fold((Object)new MatrixBlock(), (Function2)new MultiAggAggregateFunction(aggOps));
            sec.setMatrixOutput(this._out.getName(), tmpMB, this.getExtendedOpcode());
        } else if (this._class.getSuperclass() == SpoofOuterProduct.class) {
            // --- outer-product operator ---
            if (this._out.getDataType() == Expression.DataType.MATRIX) {
                SpoofOperator op = CodegenUtils.createInstance(this._class);
                SpoofOuterProduct.OutProdType type = ((SpoofOuterProduct)op).getOuterProdType();
                // output characteristics are updated first because numBlocks below is derived from them
                this.updateOutputMatrixCharacteristics(sec, op);
                MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(this._out.getName());
                out = in.mapPartitionsToPair((PairFlatMapFunction)new OuterProductFunction(this._class.getName(), this._classBytes, bcVect2, bcMatrices, scalars), true);
                // left/right outer products re-key blocks to (index, 1), so partial blocks are summed per key
                if (type == SpoofOuterProduct.OutProdType.LEFT_OUTER_PRODUCT || type == SpoofOuterProduct.OutProdType.RIGHT_OUTER_PRODUCT) {
                    long numBlocks = mcOut.getNumRowBlocks() * mcOut.getNumColBlocks();
                    out = RDDAggregateUtils.sumByKeyStable(out, (int)Math.min((long)out.getNumPartitions(), numBlocks), false);
                }
                sec.setRDDHandleForVariable(this._out.getName(), out);
                SpoofSPInstruction.maintainLineageInfo(sec, this._in, bcVect, this._out);
            } else {
                // scalar output: sum all 1x1 partial blocks and read cell (0,0)
                out = in.mapPartitionsToPair((PairFlatMapFunction)new OuterProductFunction(this._class.getName(), this._classBytes, bcVect2, bcMatrices, scalars), true);
                MatrixBlock tmp = RDDAggregateUtils.sumStable(out);
                sec.setVariable(this._out.getName(), new DoubleObject(tmp.getValue(0, 0)));
            }
        } else if (this._class.getSuperclass() == SpoofRowwise.class) {
            // --- rowwise operator: only valid if a row fits into a single column block ---
            if (mcIn.getCols() > (long)mcIn.getColsPerBlock()) {
                throw new DMLRuntimeException("Invalid spark rowwise operator w/ ncol=" + mcIn.getCols() + ", ncolpb=" + mcIn.getColsPerBlock() + ".");
            }
            SpoofRowwise op = (SpoofRowwise)CodegenUtils.createInstance(this._class);
            // second column dimension: the operator's constant if isConstDim2 accepts it, else the
            // column count of the second input for "B1" row types, else -1
            long clen2 = op.getRowType().isConstDim2(op.getConstDim2()) ? op.getConstDim2() : (op.getRowType().isRowTypeB1() ? sec.getMatrixCharacteristics(this._in[1].getName()).getCols() : -1L);
            RowwiseFunction fmmc = new RowwiseFunction(this._class.getName(), this._classBytes, bcVect2, bcMatrices, scalars, (int)mcIn.getCols(), (int)clen2);
            // partitioning is only preserved for ROW_AGG and NO_AGG
            out = in.mapPartitionsToPair((PairFlatMapFunction)fmmc, op.getRowType() == SpoofRowwise.RowType.ROW_AGG || op.getRowType() == SpoofRowwise.RowType.NO_AGG);
            if (op.getRowType().isColumnAgg() || op.getRowType() == SpoofRowwise.RowType.FULL_AGG) {
                // column/full aggregates: sum the per-partition partial blocks into a single block
                MatrixBlock tmpMB = RDDAggregateUtils.sumStable(out);
                if (op.getRowType().isColumnAgg()) {
                    sec.setMatrixOutput(this._out.getName(), tmpMB, this.getExtendedOpcode());
                } else {
                    sec.setScalarOutput(this._out.getName(), new DoubleObject(tmpMB.quickGetValue(0, 0)));
                }
            } else {
                // NOTE(review): this condition cannot hold here, since the same
                // cols > colsPerBlock check throws at the top of this branch
                if (op.getRowType() == SpoofRowwise.RowType.ROW_AGG && mcIn.getCols() > (long)mcIn.getColsPerBlock()) {
                    out = RDDAggregateUtils.sumByKeyStable(out, (int)Math.min((long)out.getNumPartitions(), mcIn.getNumRowBlocks()), false);
                }
                sec.setRDDHandleForVariable(this._out.getName(), out);
                SpoofSPInstruction.maintainLineageInfo(sec, this._in, bcVect, this._out);
                this.updateOutputMatrixCharacteristics(sec, op);
            }
        } else {
            throw new DMLRuntimeException("Operator " + this._class.getSuperclass() + " is not supported on Spark");
        }
    }

    /**
     * Decides, per input operand, whether it is passed to the operator as a broadcast.
     * <p>
     * Matrix inputs are visited in order; an input is broadcast if its estimated in-memory
     * size plus its estimated partitioned size fits the remaining local budget and its
     * partitioned size fits the remaining broadcast budget. Both budgets are reduced by the
     * partitioned size of every input chosen. Non-matrix inputs are never flagged.
     * If no matrix input would remain as an RDD, the flag of input 0 is cleared.
     *
     * @param sec    Spark execution context used to look up matrix characteristics
     * @param inputs all input operands
     * @return broadcast flags, indexed like {@code inputs}
     * @throws DMLRuntimeException declared by the signature; propagated from the lookups
     */
    private static boolean[] determineBroadcastInputs(SparkExecutionContext sec, CPOperand[] inputs) throws DMLRuntimeException {
        final boolean[] bcFlags = new boolean[inputs.length];
        // local budget excludes memory already taken by other broadcasts
        double remainingLocal = OptimizerUtils.getLocalMemBudget() - (double)CacheableData.getBroadcastSize();
        double remainingBc = SparkExecutionContext.getBroadcastMemoryBudget();

        for (int pos = 0; pos < inputs.length; pos++) {
            if (inputs[pos].getDataType().isMatrix()) {
                MatrixCharacteristics mc = sec.getMatrixCharacteristics(inputs[pos].getName());
                double sizeLocal = OptimizerUtils.estimateSizeExactSparsity(mc);
                double sizePartitioned = (double)OptimizerUtils.estimatePartitionedSizeExactSparsity(mc);
                boolean fits = remainingLocal > sizeLocal + sizePartitioned && remainingBc > sizePartitioned;
                bcFlags[pos] = fits;
                if (fits) {
                    remainingLocal -= sizePartitioned;
                    remainingBc -= sizePartitioned;
                }
            }
        }

        // guarantee at least one RDD input (scalars do not count)
        boolean noRddMatrix = IntStream.range(0, bcFlags.length).noneMatch(k -> inputs[k].isMatrix() && !bcFlags[k]);
        if (noRddMatrix) {
            bcFlags[0] = false;
        }
        return bcFlags;
    }

    /**
     * Projects the per-operand broadcast flags onto the matrix operands only.
     *
     * @param sec     Spark execution context (not used by this method)
     * @param inputs  all input operands
     * @param bcVect  broadcast flags indexed like {@code inputs}
     * @return broadcast flags of the matrix inputs, in their original relative order
     * @throws DMLRuntimeException declared by the signature
     */
    private static boolean[] getMatrixBroadcastVector(SparkExecutionContext sec, CPOperand[] inputs, boolean[] bcVect) throws DMLRuntimeException {
        // first pass: number of matrix operands
        int matrixCount = 0;
        for (CPOperand operand : inputs) {
            if (operand.getDataType().isMatrix()) {
                matrixCount++;
            }
        }
        // second pass: copy the flags of the matrix operands
        boolean[] matrixFlags = new boolean[matrixCount];
        int next = 0;
        for (int idx = 0; idx < inputs.length; idx++) {
            if (inputs[idx].getDataType().isMatrix()) {
                matrixFlags[next] = bcVect[idx];
                next++;
            }
        }
        return matrixFlags;
    }

    /**
     * Builds the input RDD of the fused operator: every key is a block index of the main
     * input and every value is the array of blocks of all non-broadcast matrix inputs,
     * main input first, side inputs appended in operand order.
     * <p>
     * Side inputs are replicated before the join when they have fewer row blocks or fewer
     * column blocks than the main input; for outer-product operators the operand at
     * index 2 is always replicated via {@link ReplicateRightFactorFunction}.
     *
     * @param sec    Spark execution context providing RDD handles and matrix characteristics
     * @param inputs all input operands
     * @param bcVect broadcast flags indexed like {@code inputs}
     * @param outer  true if the generated operator is an outer-product operator
     * @return joined RDD of block arrays
     * @throws DMLRuntimeException declared by the signature; propagated from the lookups
     */
    private static JavaPairRDD<MatrixIndexes, MatrixBlock[]> createJoinedInputRDD(SparkExecutionContext sec, CPOperand[] inputs, boolean[] bcVect, boolean outer) throws DMLRuntimeException {
        int main = SpoofSPInstruction.getMainInputIndex(inputs, bcVect);
        MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(inputs[main].getName());
        JavaPairRDD<MatrixIndexes, MatrixBlock> in = sec.getBinaryBlockRDDHandleForVariable(inputs[main].getName());
        // wrap each main-input block into a single-element array
        JavaPairRDD<MatrixIndexes, MatrixBlock[]> ret = in.mapValues(new MapInputSignature());
        for (int i = 0; i < inputs.length; ++i) {
            // only non-broadcast matrix side inputs take part in the join
            if (i == main || !inputs[i].getDataType().isMatrix() || bcVect[i]) {
                continue;
            }
            String varname = inputs[i].getName();
            JavaPairRDD<MatrixIndexes, MatrixBlock> tmp = sec.getBinaryBlockRDDHandleForVariable(varname);
            MatrixCharacteristics mcTmp = sec.getMatrixCharacteristics(varname);
            // replicate side-input blocks where the block grid does not match the main input
            if (outer && i == 2) {
                tmp = tmp.flatMapToPair(new ReplicateRightFactorFunction(mcIn.getRows(), mcIn.getRowsPerBlock()));
            } else if (mcIn.getNumRowBlocks() > mcTmp.getNumRowBlocks()) {
                tmp = tmp.flatMapToPair(new ReplicateBlockFunction(mcIn.getRows(), mcIn.getRowsPerBlock(), false));
            } else if (mcIn.getNumColBlocks() > mcTmp.getNumColBlocks()) {
                tmp = tmp.flatMapToPair(new ReplicateBlockFunction(mcIn.getCols(), mcIn.getColsPerBlock(), true));
            }
            // join on block index and append the side-input block to the value array
            ret = ret.join(tmp).mapValues(new MapJoinSignature());
        }
        return ret;
    }

    /**
     * Registers every matrix input as a lineage parent of the output variable.
     *
     * @param sec    Spark execution context
     * @param inputs all input operands
     * @param bcVect broadcast flags indexed like {@code inputs}; passed through to {@code addLineage}
     * @param output output operand
     * @throws DMLRuntimeException declared by the signature; propagated from {@code addLineage}
     */
    private static void maintainLineageInfo(SparkExecutionContext sec, CPOperand[] inputs, boolean[] bcVect, CPOperand output) throws DMLRuntimeException {
        int idx = 0;
        while (idx < inputs.length) {
            CPOperand operand = inputs[idx];
            if (operand.getDataType().isMatrix()) {
                sec.addLineage(output.getName(), operand.getName(), bcVect[idx]);
            }
            idx++;
        }
    }

    /**
     * Returns the index of the main input, i.e. the first matrix operand that is not broadcast.
     *
     * @param inputs all input operands
     * @param bcVect broadcast flags indexed like {@code inputs}
     * @return smallest index of a non-broadcast matrix operand, or 0 if there is none
     */
    private static int getMainInputIndex(CPOperand[] inputs, boolean[] bcVect) {
        int firstRddMatrix = -1;
        for (int pos = 0; pos < bcVect.length; pos++) {
            boolean rddMatrix = inputs[pos].isMatrix() && !bcVect[pos];
            if (rddMatrix && firstRddMatrix < 0) {
                firstRddMatrix = pos;
            }
        }
        return firstRddMatrix >= 0 ? firstRddMatrix : 0;
    }

    /**
     * Sets the dimensions and block sizes of the output matrix according to the operator type.
     * <ul>
     * <li>Cellwise: {@code ROW_AGG} yields rows x 1 of the first input, {@code NO_AGG} copies
     *     the first input's characteristics; other cell types leave the output untouched.</li>
     * <li>Outer product: cellwise copies the dimensions of input 0, left outer product those
     *     of input 2, right outer product those of input 1.</li>
     * <li>Rowwise: {@code NO_AGG}, {@code ROW_AGG}, {@code COL_AGG} and {@code COL_AGG_T} are
     *     derived from the first input; other row types leave the output untouched.</li>
     * </ul>
     *
     * @param sec Spark execution context providing the matrix characteristics
     * @param op  instance of the generated operator
     * @throws DMLRuntimeException declared by the signature; propagated from the lookups
     */
    private void updateOutputMatrixCharacteristics(SparkExecutionContext sec, SpoofOperator op) throws DMLRuntimeException {
        if (op instanceof SpoofCellwise) {
            MatrixCharacteristics first = sec.getMatrixCharacteristics(_in[0].getName());
            MatrixCharacteristics target = sec.getMatrixCharacteristics(_out.getName());
            SpoofCellwise.CellType cellType = ((SpoofCellwise)op).getCellType();
            if (cellType == SpoofCellwise.CellType.ROW_AGG) {
                target.set(first.getRows(), 1L, first.getRowsPerBlock(), first.getColsPerBlock());
            } else if (cellType == SpoofCellwise.CellType.NO_AGG) {
                target.set(first);
            }
            return;
        }

        if (op instanceof SpoofOuterProduct) {
            MatrixCharacteristics mcX = sec.getMatrixCharacteristics(_in[0].getName());
            MatrixCharacteristics mcU = sec.getMatrixCharacteristics(_in[1].getName());
            MatrixCharacteristics mcV = sec.getMatrixCharacteristics(_in[2].getName());
            MatrixCharacteristics target = sec.getMatrixCharacteristics(_out.getName());
            SpoofOuterProduct.OutProdType prodType = ((SpoofOuterProduct)op).getOuterProdType();
            if (prodType == SpoofOuterProduct.OutProdType.CELLWISE_OUTER_PRODUCT) {
                target.set(mcX.getRows(), mcX.getCols(), mcX.getRowsPerBlock(), mcX.getColsPerBlock());
            } else if (prodType == SpoofOuterProduct.OutProdType.LEFT_OUTER_PRODUCT) {
                target.set(mcV.getRows(), mcV.getCols(), mcV.getRowsPerBlock(), mcV.getColsPerBlock());
            } else if (prodType == SpoofOuterProduct.OutProdType.RIGHT_OUTER_PRODUCT) {
                target.set(mcU.getRows(), mcU.getCols(), mcU.getRowsPerBlock(), mcU.getColsPerBlock());
            }
            return;
        }

        if (op instanceof SpoofRowwise) {
            MatrixCharacteristics first = sec.getMatrixCharacteristics(_in[0].getName());
            MatrixCharacteristics target = sec.getMatrixCharacteristics(_out.getName());
            SpoofRowwise.RowType rowType = ((SpoofRowwise)op).getRowType();
            if (rowType == SpoofRowwise.RowType.NO_AGG) {
                target.set(first);
            } else if (rowType == SpoofRowwise.RowType.ROW_AGG) {
                target.set(first.getRows(), 1L, first.getRowsPerBlock(), first.getColsPerBlock());
            } else if (rowType == SpoofRowwise.RowType.COL_AGG) {
                target.set(1L, first.getCols(), first.getRowsPerBlock(), first.getColsPerBlock());
            } else if (rowType == SpoofRowwise.RowType.COL_AGG_T) {
                target.set(first.getCols(), 1L, first.getRowsPerBlock(), first.getColsPerBlock());
            }
        }
    }

    /**
     * Maps a cellwise aggregation type to the aggregate operator used for combining
     * partial results.
     *
     * @param aggop aggregation type of the generated operator
     * @return a Kahan-plus operator with initial value 0 for {@code SUM}/{@code SUM_SQ},
     *         a min operator with initial value +infinity for {@code MIN},
     *         a max operator with initial value -infinity for {@code MAX},
     *         or {@code null} for any other value
     */
    public static AggregateOperator getAggregateOperator(SpoofCellwise.AggOp aggop) {
        AggregateOperator result = null;
        boolean additive = aggop == SpoofCellwise.AggOp.SUM || aggop == SpoofCellwise.AggOp.SUM_SQ;
        if (additive) {
            result = new AggregateOperator(0.0, KahanPlus.getKahanPlusFnObject(), true, PartialAggregate.CorrectionLocationType.NONE);
        } else if (aggop == SpoofCellwise.AggOp.MIN) {
            result = new AggregateOperator(Double.POSITIVE_INFINITY, Builtin.getBuiltinFnObject(Builtin.BuiltinCode.MIN), false, PartialAggregate.CorrectionLocationType.NONE);
        } else if (aggop == SpoofCellwise.AggOp.MAX) {
            result = new AggregateOperator(Double.NEGATIVE_INFINITY, Builtin.getBuiltinFnObject(Builtin.BuiltinCode.MAX), false, PartialAggregate.CorrectionLocationType.NONE);
        }
        return result;
    }

    public static class ReplicateRightFactorFunction
    implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -7295989688796126442L;
        private final long _len;
        private final long _blen;

        public ReplicateRightFactorFunction(long len, long blen) {
            this._len = len;
            this._blen = blen;
        }

        public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception {
            LinkedList<Tuple2> ret = new LinkedList<Tuple2>();
            MatrixIndexes ixIn = (MatrixIndexes)arg0._1();
            MatrixBlock blkIn = (MatrixBlock)arg0._2();
            long numBlocks = (long)Math.ceil((double)this._len / (double)this._blen);
            long j = ixIn.getRowIndex();
            for (long i = 1L; i <= numBlocks; ++i) {
                MatrixIndexes tmpix = new MatrixIndexes(i, j);
                MatrixBlock tmpblk = blkIn;
                ret.add(new Tuple2((Object)tmpix, (Object)tmpblk));
            }
            return ret.iterator();
        }
    }

    private static class OuterProductFunction
    extends SpoofFunction
    implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -8209188316939435099L;
        private SpoofOperator _op = null;

        public OuterProductFunction(String className, byte[] classBytes, boolean[] bcInd, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars) throws DMLRuntimeException {
            super(className, classBytes, bcInd, bcMatrices, scalars);
        }

        public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>> arg) throws Exception {
            if (this._op == null) {
                Class<?> loadedClass = CodegenUtils.getClassSync(this._className, this._classBytes);
                this._op = CodegenUtils.createInstance(loadedClass);
            }
            ArrayList<Tuple2> ret = new ArrayList<Tuple2>();
            while (arg.hasNext()) {
                Tuple2<MatrixIndexes, MatrixBlock[]> tmp = arg.next();
                MatrixIndexes ixIn = (MatrixIndexes)tmp._1();
                MatrixBlock[] blkIn = (MatrixBlock[])tmp._2();
                MatrixBlock blkOut = new MatrixBlock();
                ArrayList<MatrixBlock> inputs = this.getAllMatrixInputs(ixIn, blkIn, true);
                if (((SpoofOuterProduct)this._op).getOuterProdType() == SpoofOuterProduct.OutProdType.AGG_OUTER_PRODUCT) {
                    ScalarObject obj = this._op.execute(inputs, (ArrayList<ScalarObject>)this._scalars, 1);
                    blkOut.reset(1, 1);
                    blkOut.quickSetValue(0, 0, obj.getDoubleValue());
                } else {
                    blkOut = this._op.execute(inputs, (ArrayList<ScalarObject>)this._scalars, blkOut);
                }
                ret.add(new Tuple2((Object)OuterProductFunction.createOutputIndexes(ixIn, this._op), (Object)blkOut));
            }
            return ret.iterator();
        }

        private static MatrixIndexes createOutputIndexes(MatrixIndexes in, SpoofOperator spoofOp) {
            if (((SpoofOuterProduct)spoofOp).getOuterProdType() == SpoofOuterProduct.OutProdType.LEFT_OUTER_PRODUCT) {
                return new MatrixIndexes(in.getColumnIndex(), 1L);
            }
            if (((SpoofOuterProduct)spoofOp).getOuterProdType() == SpoofOuterProduct.OutProdType.RIGHT_OUTER_PRODUCT) {
                return new MatrixIndexes(in.getRowIndex(), 1L);
            }
            return in;
        }
    }

    /**
     * Fold function that merges two partial multi-aggregate results in place into the
     * first argument.
     */
    private static class MultiAggAggregateFunction
    implements Function2<MatrixBlock, MatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = 5978731867787952513L;
        /** Aggregation type per aggregate of the generated operator. */
        private SpoofCellwise.AggOp[] _ops;

        public MultiAggAggregateFunction(SpoofCellwise.AggOp[] ops) {
            _ops = ops;
        }

        /**
         * @param arg0 accumulated partial result (modified and returned)
         * @param arg1 next partial result
         * @return {@code arg0} after merging
         */
        public MatrixBlock call(MatrixBlock arg0, MatrixBlock arg1) throws Exception {
            boolean accumulatorEmpty = arg0.getNumRows() <= 0 || arg0.getNumColumns() <= 0;
            if (accumulatorEmpty) {
                // nothing accumulated yet: take over the incoming block
                arg0.copy(arg1);
            } else if (arg1.getNumRows() > 0 && arg1.getNumColumns() > 0) {
                // both sides carry data: combine according to the aggregation types
                SpoofMultiAggregate.aggregatePartialResults(_ops, arg0, arg1);
            }
            return arg0;
        }
    }

    /**
     * Block-level function that applies a generated multi-aggregate operator to one joined
     * input block and returns its partial result under the unchanged block index.
     */
    private static class MultiAggregateFunction
    extends SpoofFunction
    implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock[]>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -5224519291577332734L;
        /** Operator instance, created lazily on first call from the class name and byte code. */
        private SpoofOperator _op = null;

        public MultiAggregateFunction(String className, byte[] classBytes, boolean[] bcInd, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars) throws DMLRuntimeException {
            super(className, classBytes, bcInd, bcMatrices, scalars);
        }

        /**
         * @param arg block index and RDD input blocks
         * @return the input index paired with the operator's output block
         */
        public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock[]> arg) throws Exception {
            if (_op == null) {
                _op = CodegenUtils.createInstance(CodegenUtils.getClassSync(_className, _classBytes));
            }
            MatrixIndexes blockIx = arg._1();
            ArrayList<MatrixBlock> operands = getAllMatrixInputs(blockIx, arg._2());
            MatrixBlock partial = _op.execute(operands, _scalars, new MatrixBlock());
            return new Tuple2<>(arg._1(), partial);
        }
    }

    private static class CellwiseFunction
    extends SpoofFunction
    implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -8209188316939435099L;
        private SpoofOperator _op = null;

        public CellwiseFunction(String className, byte[] classBytes, boolean[] bcInd, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars) throws DMLRuntimeException {
            super(className, classBytes, bcInd, bcMatrices, scalars);
        }

        public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>> arg) throws Exception {
            if (this._op == null) {
                Class<?> loadedClass = CodegenUtils.getClassSync(this._className, this._classBytes);
                this._op = CodegenUtils.createInstance(loadedClass);
            }
            ArrayList<Tuple2> ret = new ArrayList<Tuple2>();
            while (arg.hasNext()) {
                Tuple2<MatrixIndexes, MatrixBlock[]> tmp = arg.next();
                MatrixIndexes ixIn = (MatrixIndexes)tmp._1();
                MatrixBlock[] blkIn = (MatrixBlock[])tmp._2();
                MatrixIndexes ixOut = ixIn;
                MatrixBlock blkOut = new MatrixBlock();
                ArrayList<MatrixBlock> inputs = this.getAllMatrixInputs(ixIn, blkIn);
                if (((SpoofCellwise)this._op).getCellType() == SpoofCellwise.CellType.FULL_AGG) {
                    ScalarObject obj = this._op.execute(inputs, (ArrayList<ScalarObject>)this._scalars, 1);
                    blkOut.reset(1, 1);
                    blkOut.quickSetValue(0, 0, obj.getDoubleValue());
                } else {
                    if (((SpoofCellwise)this._op).getCellType() == SpoofCellwise.CellType.ROW_AGG) {
                        ixOut = new MatrixIndexes(ixOut.getRowIndex(), 1L);
                    } else if (((SpoofCellwise)this._op).getCellType() == SpoofCellwise.CellType.COL_AGG) {
                        ixOut = new MatrixIndexes(1L, ixOut.getColumnIndex());
                    }
                    blkOut = this._op.execute(inputs, (ArrayList<ScalarObject>)this._scalars, blkOut);
                }
                ret.add(new Tuple2((Object)ixOut, (Object)blkOut));
            }
            return ret.iterator();
        }
    }

    private static class RowwiseFunction
    extends SpoofFunction
    implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -7926980450209760212L;
        private final int _clen;
        private final int _clen2;
        private SpoofRowwise _op = null;

        public RowwiseFunction(String className, byte[] classBytes, boolean[] bcInd, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars, int clen, int clen2) throws DMLRuntimeException {
            super(className, classBytes, bcInd, bcMatrices, scalars);
            this._clen = clen;
            this._clen2 = clen;
        }

        public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock[]>> arg) throws Exception {
            MatrixBlock blkOut;
            if (this._op == null) {
                Class<?> loadedClass = CodegenUtils.getClassSync(this._className, this._classBytes);
                this._op = (SpoofRowwise)CodegenUtils.createInstance(loadedClass);
            }
            LibSpoofPrimitives.setupThreadLocalMemory(this._op.getNumIntermediates(), this._clen, this._clen2);
            ArrayList<Tuple2> ret = new ArrayList<Tuple2>();
            boolean aggIncr = this._op.getRowType().isColumnAgg() || this._op.getRowType() == SpoofRowwise.RowType.FULL_AGG;
            MatrixBlock matrixBlock = blkOut = aggIncr ? new MatrixBlock() : null;
            while (arg.hasNext()) {
                Tuple2<MatrixIndexes, MatrixBlock[]> e = arg.next();
                MatrixIndexes ixIn = (MatrixIndexes)e._1();
                MatrixBlock[] blkIn = (MatrixBlock[])e._2();
                ArrayList<MatrixBlock> inputs = this.getAllMatrixInputs(ixIn, blkIn);
                blkOut = aggIncr ? blkOut : new MatrixBlock();
                blkOut = this._op.execute(inputs, this._scalars, blkOut, false, aggIncr);
                if (aggIncr) continue;
                MatrixIndexes ixOut = new MatrixIndexes(ixIn.getRowIndex(), this._op.getRowType() != SpoofRowwise.RowType.NO_AGG ? 1L : ixIn.getColumnIndex());
                ret.add(new Tuple2((Object)ixOut, (Object)blkOut));
            }
            LibSpoofPrimitives.cleanupThreadLocalMemory();
            if (aggIncr) {
                blkOut.recomputeNonZeros();
                blkOut.examSparsity();
                ret.add(new Tuple2((Object)new MatrixIndexes(1L, 1L), (Object)blkOut));
            }
            return ret.iterator();
        }
    }

    /**
     * Serializable base of the Spark functions in this file. It carries the generated
     * class (name and byte code), the broadcast handles and the scalar inputs, and
     * assembles the full list of matrix inputs for one block.
     */
    private static class SpoofFunction
    implements Serializable {
        private static final long serialVersionUID = 2953479427746463003L;
        /** Broadcast flag per matrix input. */
        protected final boolean[] _bcInd;
        /** Broadcast handles, one per matrix input flagged in {@link #_bcInd}, in the same order. */
        protected final ArrayList<PartitionedBroadcast<MatrixBlock>> _inputs;
        /** Scalar inputs of the operator. */
        protected final ArrayList<ScalarObject> _scalars;
        /** Byte code of the generated class. */
        protected final byte[] _classBytes;
        /** Name of the generated class. */
        protected final String _className;

        protected SpoofFunction(String className, byte[] classBytes, boolean[] bcInd, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars) {
            _className = className;
            _classBytes = classBytes;
            _bcInd = bcInd;
            _inputs = bcMatrices;
            _scalars = scalars;
        }

        /** Same as {@link #getAllMatrixInputs(MatrixIndexes, MatrixBlock[], boolean)} with {@code outer=false}. */
        protected ArrayList<MatrixBlock> getAllMatrixInputs(MatrixIndexes ixIn, MatrixBlock[] blkIn) throws DMLRuntimeException {
            return getAllMatrixInputs(ixIn, blkIn, false);
        }

        /**
         * Collects the blocks of all matrix inputs for the given block index, in input order.
         * RDD inputs are taken from {@code blkIn} in sequence; broadcast inputs are looked up
         * in their partitioned broadcast.
         *
         * @param ixIn  block index of the main input
         * @param blkIn blocks of the RDD inputs
         * @param outer if true, the broadcast at matrix position 2 is addressed by
         *              (column index of {@code ixIn}, 1)
         * @return blocks of all matrix inputs
         */
        protected ArrayList<MatrixBlock> getAllMatrixInputs(MatrixIndexes ixIn, MatrixBlock[] blkIn, boolean outer) throws DMLRuntimeException {
            ArrayList<MatrixBlock> blocks = new ArrayList<MatrixBlock>();
            int nextRdd = 0;
            int nextBc = 0;
            for (int k = 0; k < _bcInd.length; k++) {
                if (!_bcInd[k]) {
                    blocks.add(blkIn[nextRdd++]);
                    continue;
                }
                PartitionedBroadcast<MatrixBlock> pb = _inputs.get(nextBc++);
                long rowIx;
                long colIx;
                if (outer && k == 2) {
                    rowIx = ixIn.getColumnIndex();
                    colIx = 1L;
                } else {
                    // fall back to block 1 where the broadcast has fewer blocks than the main input
                    rowIx = (long)pb.getNumRowBlocks() >= ixIn.getRowIndex() ? ixIn.getRowIndex() : 1L;
                    colIx = (long)pb.getNumColumnBlocks() >= ixIn.getColumnIndex() ? ixIn.getColumnIndex() : 1L;
                }
                blocks.add(pb.getBlock((int)rowIx, (int)colIx));
            }
            return blocks;
        }
    }

    /**
     * Flattens a join result: appends the joined side-input block to the array of
     * blocks collected so far.
     */
    private static class MapJoinSignature
    implements Function<Tuple2<MatrixBlock[], MatrixBlock>, MatrixBlock[]> {
        private static final long serialVersionUID = -704403012606821854L;

        private MapJoinSignature() {
        }

        /**
         * @param v1 pair of (blocks so far, side-input block)
         * @return new array with the side-input block appended
         */
        public MatrixBlock[] call(Tuple2<MatrixBlock[], MatrixBlock> v1) throws Exception {
            MatrixBlock[] collected = v1._1();
            MatrixBlock joined = v1._2();
            return ArrayUtils.add(collected, joined);
        }
    }

    /**
     * Wraps a block of the main input into a single-element block array.
     */
    private static class MapInputSignature
    implements Function<MatrixBlock, MatrixBlock[]> {
        private static final long serialVersionUID = -816443970067626102L;

        private MapInputSignature() {
        }

        /**
         * @param v1 block of the main input
         * @return array containing only {@code v1}
         */
        public MatrixBlock[] call(MatrixBlock v1) throws Exception {
            MatrixBlock[] wrapped = new MatrixBlock[1];
            wrapped[0] = v1;
            return wrapped;
        }
    }
}

