/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.instructions.spark.utils;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.sysml.lops.PartialAggregate;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.functionobjects.KahanPlus;
import org.apache.sysml.runtime.instructions.cp.KahanObject;
import org.apache.sysml.runtime.instructions.spark.data.CorrMatrixBlock;
import org.apache.sysml.runtime.instructions.spark.data.RowMatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysml.runtime.matrix.operators.AggregateOperator;

/**
 * Static helpers for aggregating and merging {@link MatrixBlock}s held in Spark RDDs.
 *
 * <p>The public methods fall into three groups:
 * <ul>
 *   <li>{@code sum*Stable}: sums built on the {@link KahanPlus} function object, which
 *       carries a separate correction term next to the running sum;</li>
 *   <li>{@code agg*Stable}: aggregation with a caller-supplied {@link AggregateOperator};</li>
 *   <li>{@code merge*ByKey}: merging of blocks (or single rows) that share a key.</li>
 * </ul>
 *
 * <p>The nested classes are the Spark function objects handed to {@code fold} and
 * {@code combineByKey}.
 */
public class RDDAggregateUtils {
    // Not referenced anywhere in this class; retained from the original (decompiled) source.
    private static final boolean TREE_AGGREGATION = false;

    /**
     * Sums all block values of a pair RDD into a single block, ignoring the keys.
     *
     * @param in pair RDD whose values are summed
     * @return the result of {@link #sumStable(JavaRDD)} applied to {@code in.values()}
     */
    public static MatrixBlock sumStable(JavaPairRDD<MatrixIndexes, MatrixBlock> in) {
        return sumStable(in.values());
    }

    /**
     * Sums all blocks of an RDD into a single block.
     *
     * <p>Implemented as a {@code fold} that starts from an empty {@code new MatrixBlock()}
     * and combines blocks with {@link SumSingleBlockFunction} (created with
     * {@code deep = false}, i.e. without copying the accumulator on each step).
     *
     * @param in RDD of blocks to sum
     * @return the folded block
     */
    public static MatrixBlock sumStable(JavaRDD<MatrixBlock> in) {
        return in.fold(new MatrixBlock(), new SumSingleBlockFunction(false));
    }

    /**
     * Sums blocks per key, keeping the input's number of partitions and deep-copying
     * the first block of every key.
     *
     * @param in pair RDD of blocks
     * @return pair RDD holding one summed block per key
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> sumByKeyStable(JavaPairRDD<MatrixIndexes, MatrixBlock> in) {
        return sumByKeyStable(in, in.getNumPartitions(), true);
    }

    /**
     * Sums blocks per key, keeping the input's number of partitions.
     *
     * @param in pair RDD of blocks
     * @param deepCopyCombiner whether the first block of a key is copied before it is used
     *        as the combiner (see {@link CreateCorrBlockCombinerFunction})
     * @return pair RDD holding one summed block per key
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> sumByKeyStable(JavaPairRDD<MatrixIndexes, MatrixBlock> in, boolean deepCopyCombiner) {
        return sumByKeyStable(in, in.getNumPartitions(), deepCopyCombiner);
    }

    /**
     * Sums blocks per key via {@code combineByKey}.
     *
     * <p>The combiner type is {@link CorrMatrixBlock}, which pairs the running value with a
     * correction block; the correction is dropped again by {@link ExtractMatrixBlock} once
     * all values of a key have been combined.
     *
     * @param in pair RDD of blocks
     * @param numPartitions number of partitions passed to {@code combineByKey}
     * @param deepCopyCombiner whether the first block of a key is copied before it is used
     *        as the combiner
     * @return pair RDD holding one summed block per key
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> sumByKeyStable(JavaPairRDD<MatrixIndexes, MatrixBlock> in, int numPartitions, boolean deepCopyCombiner) {
        JavaPairRDD<MatrixIndexes, CorrMatrixBlock> tmp = in.combineByKey(
                new CreateCorrBlockCombinerFunction(deepCopyCombiner),
                new MergeSumBlockValueFunction(),
                new MergeSumBlockCombinerFunction(),
                numPartitions);
        return tmp.mapValues(new ExtractMatrixBlock());
    }

    /**
     * Sums scalar cell values per key via {@code combineByKey}.
     *
     * <p>Each key's combiner is a {@link KahanObject}; after combining, only its
     * {@code _sum} field is returned.
     *
     * @param in pair RDD of cell values
     * @return pair RDD holding one summed value per key
     */
    public static JavaPairRDD<MatrixIndexes, Double> sumCellsByKeyStable(JavaPairRDD<MatrixIndexes, Double> in) {
        JavaPairRDD<MatrixIndexes, KahanObject> tmp = in.combineByKey(
                new CreateCellCombinerFunction(),
                new MergeSumCellValueFunction(),
                new MergeSumCellCombinerFunction());
        return tmp.mapValues(new ExtractDoubleCell());
    }

    /**
     * Aggregates all block values of a pair RDD into a single block, ignoring the keys.
     *
     * @param in pair RDD whose values are aggregated
     * @param aop aggregate operator to apply
     * @return the result of {@link #aggStable(JavaRDD, AggregateOperator)} on {@code in.values()}
     */
    public static MatrixBlock aggStable(JavaPairRDD<MatrixIndexes, MatrixBlock> in, AggregateOperator aop) {
        return aggStable(in.values(), aop);
    }

    /**
     * Aggregates all blocks of an RDD into a single block.
     *
     * <p>Implemented as a {@code fold} that starts from an empty {@code new MatrixBlock()}
     * and combines blocks with {@link AggregateSingleBlockFunction}.
     *
     * @param in RDD of blocks to aggregate
     * @param aop aggregate operator to apply
     * @return the folded block
     */
    public static MatrixBlock aggStable(JavaRDD<MatrixBlock> in, AggregateOperator aop) {
        return in.fold(new MatrixBlock(), new AggregateSingleBlockFunction(aop));
    }

    /**
     * Aggregates blocks per key, keeping the input's number of partitions and deep-copying
     * the first block of every key.
     *
     * @param in pair RDD of blocks
     * @param aop aggregate operator to apply
     * @return pair RDD holding one aggregated block per key
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> aggByKeyStable(JavaPairRDD<MatrixIndexes, MatrixBlock> in, AggregateOperator aop) {
        return aggByKeyStable(in, aop, in.getNumPartitions(), true);
    }

    /**
     * Aggregates blocks per key, keeping the input's number of partitions.
     *
     * @param in pair RDD of blocks
     * @param aop aggregate operator to apply
     * @param deepCopyCombiner whether the first block of a key is copied before it is used
     *        as the combiner
     * @return pair RDD holding one aggregated block per key
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> aggByKeyStable(JavaPairRDD<MatrixIndexes, MatrixBlock> in, AggregateOperator aop, boolean deepCopyCombiner) {
        return aggByKeyStable(in, aop, in.getNumPartitions(), deepCopyCombiner);
    }

    /**
     * Aggregates blocks per key via {@code combineByKey}, using {@link CorrMatrixBlock} as
     * the intermediate combiner type and returning only its value block.
     *
     * @param in pair RDD of blocks
     * @param aop aggregate operator to apply
     * @param numPartitions number of partitions passed to {@code combineByKey}
     * @param deepCopyCombiner whether the first block of a key is copied before it is used
     *        as the combiner
     * @return pair RDD holding one aggregated block per key
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> aggByKeyStable(JavaPairRDD<MatrixIndexes, MatrixBlock> in, AggregateOperator aop, int numPartitions, boolean deepCopyCombiner) {
        JavaPairRDD<MatrixIndexes, CorrMatrixBlock> tmp = in.combineByKey(
                new CreateCorrBlockCombinerFunction(deepCopyCombiner),
                new MergeAggBlockValueFunction(aop),
                new MergeAggBlockCombinerFunction(aop),
                numPartitions);
        return tmp.mapValues(new ExtractMatrixBlock());
    }

    /**
     * Merges blocks per key, keeping the input's number of partitions and deep-copying
     * the first block of every key.
     *
     * @param in pair RDD of blocks
     * @return pair RDD holding one merged block per key
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> mergeByKey(JavaPairRDD<MatrixIndexes, MatrixBlock> in) {
        return mergeByKey(in, in.getNumPartitions(), true);
    }

    /**
     * Merges blocks per key, keeping the input's number of partitions.
     *
     * @param in pair RDD of blocks
     * @param deepCopyCombiner whether the first block of a key is copied before it is used
     *        as the combiner
     * @return pair RDD holding one merged block per key
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> mergeByKey(JavaPairRDD<MatrixIndexes, MatrixBlock> in, boolean deepCopyCombiner) {
        return mergeByKey(in, in.getNumPartitions(), deepCopyCombiner);
    }

    /**
     * Merges blocks per key via {@code combineByKey}.
     *
     * <p>Both merge steps use {@code MergeBlocksFunction(false)}, i.e. they merge into the
     * existing combiner object rather than into a copy of it.
     *
     * @param in pair RDD of blocks
     * @param numPartitions number of partitions passed to {@code combineByKey}
     * @param deepCopyCombiner whether the first block of a key is copied before it is used
     *        as the combiner
     * @return pair RDD holding one merged block per key
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> mergeByKey(JavaPairRDD<MatrixIndexes, MatrixBlock> in, int numPartitions, boolean deepCopyCombiner) {
        return in.combineByKey(
                new CreateBlockCombinerFunction(deepCopyCombiner),
                new MergeBlocksFunction(false),
                new MergeBlocksFunction(false),
                numPartitions);
    }

    /**
     * Assembles blocks from individual rows that share a key via {@code combineByKey}.
     *
     * @param in pair RDD of row blocks
     * @return pair RDD holding one assembled block per key
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> mergeRowsByKey(JavaPairRDD<MatrixIndexes, RowMatrixBlock> in) {
        return in.combineByKey(
                new CreateRowBlockCombinerFunction(),
                new MergeRowBlockValueFunction(),
                new MergeBlocksFunction(false));
    }

    /**
     * Merges two blocks of equal dimensions and verifies that the number of non-zeros of
     * the result equals the sum of the inputs' non-zeros.
     */
    private static class MergeBlocksFunction
    implements Function2<MatrixBlock, MatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = -8881019027250258850L;
        // When true, the merge target is a copy of the first argument instead of the argument itself.
        private final boolean _deep;

        public MergeBlocksFunction() {
            this(true);
        }

        public MergeBlocksFunction(boolean deep) {
            this._deep = deep;
        }

        /**
         * @param b1 first block (merge target, or source of the copy when {@code _deep})
         * @param b2 second block, merged into the target
         * @return the merged block
         * @throws DMLRuntimeException if the dimensions differ or the non-zero count of the
         *         result is not the sum of the input counts
         */
        public MatrixBlock call(MatrixBlock b1, MatrixBlock b2) throws Exception {
            // Capture the counts up front: with _deep == false the merge target is b1 itself.
            long b1nnz = b1.getNonZeros();
            long b2nnz = b2.getNonZeros();
            if (b1.getNumRows() != b2.getNumRows() || b1.getNumColumns() != b2.getNumColumns()) {
                throw new DMLRuntimeException("Mismatched block sizes for: " + b1.getNumRows() + " " + b1.getNumColumns() + " " + b2.getNumRows() + " " + b2.getNumColumns());
            }
            MatrixBlock ret = this._deep ? new MatrixBlock(b1) : b1;
            ret.merge(b2, false);
            ret.examSparsity();
            if (ret.getNonZeros() != b1nnz + b2nnz) {
                throw new DMLRuntimeException("Number of non-zeros does not match: " + ret.getNonZeros() + " != " + b1nnz + " + " + b2nnz);
            }
            return ret;
        }
    }

    /**
     * Fold function that aggregates the second block into the first using a caller-supplied
     * {@link AggregateOperator}.
     *
     * <p>NOTE(review): the correction block is cached in this function object and reused for
     * every subsequent call on the same instance.
     */
    private static class AggregateSingleBlockFunction
    implements Function2<MatrixBlock, MatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = -3672377410407066396L;
        private final AggregateOperator _op;
        // Created on first use, and only when the operator declares a correction.
        private MatrixBlock _corr = null;

        public AggregateSingleBlockFunction(AggregateOperator op) {
            this._op = op;
        }

        /**
         * @param arg0 accumulator block; returned after aggregation
         * @param arg1 block to aggregate into the accumulator
         * @return {@code arg0}
         */
        public MatrixBlock call(MatrixBlock arg0, MatrixBlock arg1) throws Exception {
            // An accumulator without dimensions (e.g. the fold's zero value) simply takes over arg1.
            if (arg0.getNumRows() <= 0 || arg0.getNumColumns() <= 0) {
                arg0.copy(arg1);
                return arg0;
            }
            // Nothing to add from a block without dimensions.
            if (arg1.getNumRows() <= 0 || arg1.getNumColumns() <= 0) {
                return arg0;
            }
            if (this._op.correctionExists && this._corr == null) {
                this._corr = new MatrixBlock(arg0.getNumRows(), arg0.getNumColumns(), false);
            }
            OperationsOnMatrixValues.incrementalAggregation(arg0, this._op.correctionExists ? this._corr : null, arg1, this._op, true);
            return arg0;
        }
    }

    /**
     * Fold function that sums the second block into the first using a {@link KahanPlus}
     * based {@link AggregateOperator}.
     *
     * <p>NOTE(review): the correction block is cached in this function object and reused for
     * every subsequent call on the same instance.
     */
    private static class SumSingleBlockFunction
    implements Function2<MatrixBlock, MatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = 1737038715965862222L;
        private final AggregateOperator _op = new AggregateOperator(0.0, KahanPlus.getKahanPlusFnObject(), true, PartialAggregate.CorrectionLocationType.NONE);
        // Created on first use with the dimensions of the first non-empty accumulator.
        private MatrixBlock _corr = null;
        // When true, the sum is written into a copy of the accumulator instead of the accumulator itself.
        private final boolean _deep;

        public SumSingleBlockFunction(boolean deep) {
            this._deep = deep;
        }

        /**
         * @param arg0 accumulator block
         * @param arg1 block to add to the accumulator
         * @return the block holding the sum ({@code arg0}, or a copy of it when {@code _deep})
         */
        public MatrixBlock call(MatrixBlock arg0, MatrixBlock arg1) throws Exception {
            // An accumulator without dimensions (e.g. the fold's zero value) simply takes over arg1.
            if (arg0.getNumRows() <= 0 || arg0.getNumColumns() <= 0) {
                arg0.copy(arg1);
                return arg0;
            }
            // Nothing to add from a block without dimensions.
            if (arg1.getNumRows() <= 0 || arg1.getNumColumns() <= 0) {
                return arg0;
            }
            if (this._corr == null) {
                this._corr = new MatrixBlock(arg0.getNumRows(), arg0.getNumColumns(), false);
            }
            MatrixBlock out = this._deep ? new MatrixBlock(arg0) : arg0;
            OperationsOnMatrixValues.incrementalAggregation(out, this._corr, arg1, this._op, false);
            return out;
        }
    }

    /** Returns the {@code _sum} field of a {@link KahanObject}. */
    private static class ExtractDoubleCell
    implements Function<KahanObject, Double> {
        private static final long serialVersionUID = -2873241816558275742L;

        private ExtractDoubleCell() {
        }

        public Double call(KahanObject arg0) throws Exception {
            return arg0._sum;
        }
    }

    /** Returns the value block of a {@link CorrMatrixBlock}, leaving its correction behind. */
    private static class ExtractMatrixBlock
    implements Function<CorrMatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = 5242158678070843495L;

        private ExtractMatrixBlock() {
        }

        public MatrixBlock call(CorrMatrixBlock arg0) throws Exception {
            return arg0.getValue();
        }
    }

    /**
     * {@code combineByKey} merge-combiners step for generic aggregation: aggregates the value
     * of the second combiner into the value of the first.
     */
    private static class MergeAggBlockCombinerFunction
    implements Function2<CorrMatrixBlock, CorrMatrixBlock, CorrMatrixBlock> {
        private static final long serialVersionUID = 4803711632648880797L;
        private final AggregateOperator _op;

        public MergeAggBlockCombinerFunction(AggregateOperator aop) {
            this._op = aop;
        }

        /**
         * @param arg0 first combiner; its value block receives the aggregate
         * @param arg1 second combiner
         * @return a new combiner wrapping the first value block and the correction in use
         */
        public CorrMatrixBlock call(CorrMatrixBlock arg0, CorrMatrixBlock arg1) throws Exception {
            MatrixBlock value1 = arg0.getValue();
            MatrixBlock value2 = arg1.getValue();
            MatrixBlock corr = arg0.getCorrection();
            // Correction fallback order: first combiner, then second combiner, then a fresh block.
            if (corr == null && this._op.correctionExists) {
                corr = arg1.getCorrection() != null
                        ? arg1.getCorrection()
                        : new MatrixBlock(value1.getNumRows(), value1.getNumColumns(), false);
            }
            // The correction is handed to the aggregation only if the operator declares one.
            MatrixBlock aggCorr = this._op.correctionExists ? corr : null;
            OperationsOnMatrixValues.incrementalAggregation(value1, aggCorr, value2, this._op, true);
            return new CorrMatrixBlock(value1, corr);
        }
    }

    /**
     * {@code combineByKey} merge-value step for generic aggregation: aggregates an input
     * block into the combiner's value.
     */
    private static class MergeAggBlockValueFunction
    implements Function2<CorrMatrixBlock, MatrixBlock, CorrMatrixBlock> {
        private static final long serialVersionUID = 389422125491172011L;
        private final AggregateOperator _op;

        public MergeAggBlockValueFunction(AggregateOperator aop) {
            this._op = aop;
        }

        /**
         * @param arg0 combiner; its value block receives the aggregate
         * @param arg1 input block to aggregate
         * @return a new combiner wrapping the value block and the correction in use
         */
        public CorrMatrixBlock call(CorrMatrixBlock arg0, MatrixBlock arg1) throws Exception {
            MatrixBlock value = arg0.getValue();
            MatrixBlock corr = arg0.getCorrection();
            // Create the correction block on demand, sized like the value block.
            if (corr == null && this._op.correctionExists) {
                corr = new MatrixBlock(value.getNumRows(), value.getNumColumns(), false);
            }
            // The correction is handed to the aggregation only if the operator declares one.
            MatrixBlock aggCorr = this._op.correctionExists ? corr : null;
            OperationsOnMatrixValues.incrementalAggregation(value, aggCorr, arg1, this._op, true);
            return new CorrMatrixBlock(value, corr);
        }
    }

    /**
     * {@code combineByKey} merge-combiners step for cell sums: adds the {@code _sum} of the
     * second {@link KahanObject} to the first via {@code KahanPlus.execute2}.
     */
    private static class MergeSumCellCombinerFunction
    implements Function2<KahanObject, KahanObject, KahanObject> {
        private static final long serialVersionUID = 8726716909849119657L;

        private MergeSumCellCombinerFunction() {
        }

        public KahanObject call(KahanObject arg0, KahanObject arg1) throws Exception {
            KahanPlus kplus = KahanPlus.getKahanPlusFnObject();
            // Only arg1's _sum is passed on here; arg0 is the object that is returned.
            kplus.execute2(arg0, arg1._sum);
            return arg0;
        }
    }

    /**
     * {@code combineByKey} merge-value step for cell sums: adds an input value to the
     * {@link KahanObject} combiner via {@code KahanPlus.execute2}.
     */
    private static class MergeSumCellValueFunction
    implements Function2<KahanObject, Double, KahanObject> {
        private static final long serialVersionUID = 468335171573184825L;

        private MergeSumCellValueFunction() {
        }

        public KahanObject call(KahanObject arg0, Double arg1) throws Exception {
            KahanPlus kplus = KahanPlus.getKahanPlusFnObject();
            kplus.execute2(arg0, arg1);
            return arg0;
        }
    }

    /**
     * {@code combineByKey} create-combiner step for cell sums: wraps the first value of a key
     * in a {@link KahanObject} constructed with a second argument of {@code 0.0}.
     */
    private static class CreateCellCombinerFunction
    implements Function<Double, KahanObject> {
        private static final long serialVersionUID = 3697505233057172994L;

        private CreateCellCombinerFunction() {
        }

        public KahanObject call(Double arg0) throws Exception {
            return new KahanObject(arg0, 0.0);
        }
    }

    /**
     * {@code combineByKey} merge-value step for row merging: copies a row into the combiner
     * block at the row index carried by the {@link RowMatrixBlock}.
     */
    private static class MergeRowBlockValueFunction
    implements Function2<MatrixBlock, RowMatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = -803689998683298516L;

        private MergeRowBlockValueFunction() {
        }

        /**
         * @param arg0 combiner block that receives the row
         * @param arg1 row to copy, together with its target row index
         * @return {@code arg0}
         */
        public MatrixBlock call(MatrixBlock arg0, RowMatrixBlock arg1) throws Exception {
            MatrixBlock row = arg1.getValue();
            int rowIndex = arg1.getRow();
            // Target range: the single row rowIndex, columns 0 .. numColumns-1 of the row block.
            arg0.copy(rowIndex, rowIndex, 0, row.getNumColumns() - 1, row, true);
            arg0.examSparsity();
            return arg0;
        }
    }

    /**
     * {@code combineByKey} create-combiner step for row merging: allocates a block with
     * {@code getLen()} rows and copies the first row of a key into it.
     */
    private static class CreateRowBlockCombinerFunction
    implements Function<RowMatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = 2866598914232118425L;

        private CreateRowBlockCombinerFunction() {
        }

        /**
         * @param arg0 first row seen for a key
         * @return a new block containing that row at its row index
         */
        public MatrixBlock call(RowMatrixBlock arg0) throws Exception {
            MatrixBlock row = arg0.getValue();
            int rowIndex = arg0.getRow();
            MatrixBlock out = new MatrixBlock(arg0.getLen(), row.getNumColumns(), true);
            out.copy(rowIndex, rowIndex, 0, row.getNumColumns() - 1, row, false);
            // The non-zero count is set explicitly from the copied row.
            out.setNonZeros(row.getNonZeros());
            out.examSparsity();
            return out;
        }
    }

    /**
     * {@code combineByKey} create-combiner step for block merging: uses the first block of a
     * key as the combiner, optionally copying it first.
     */
    private static class CreateBlockCombinerFunction
    implements Function<MatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = 1987501624176848292L;
        private final boolean _deep;

        public CreateBlockCombinerFunction(boolean deep) {
            this._deep = deep;
        }

        public MatrixBlock call(MatrixBlock arg0) throws Exception {
            return this._deep ? new MatrixBlock(arg0) : arg0;
        }
    }

    /**
     * {@code combineByKey} merge-combiners step for sums: adds the value of the second
     * combiner to the value of the first using a {@link KahanPlus} based operator.
     */
    private static class MergeSumBlockCombinerFunction
    implements Function2<CorrMatrixBlock, CorrMatrixBlock, CorrMatrixBlock> {
        private static final long serialVersionUID = 7664941774566119853L;
        private final AggregateOperator _op = new AggregateOperator(0.0, KahanPlus.getKahanPlusFnObject(), true, PartialAggregate.CorrectionLocationType.NONE);

        private MergeSumBlockCombinerFunction() {
        }

        /**
         * @param arg0 first combiner; updated via {@code set} and returned
         * @param arg1 second combiner
         * @return the result of {@code arg0.set(value, correction)}
         */
        public CorrMatrixBlock call(CorrMatrixBlock arg0, CorrMatrixBlock arg1) throws Exception {
            MatrixBlock value1 = arg0.getValue();
            MatrixBlock value2 = arg1.getValue();
            MatrixBlock corr = arg0.getCorrection();
            // Correction fallback order: first combiner, then second combiner, then a fresh block.
            if (corr == null) {
                corr = arg1.getCorrection() != null
                        ? arg1.getCorrection()
                        : new MatrixBlock(value1.getNumRows(), value1.getNumColumns(), false);
            }
            OperationsOnMatrixValues.incrementalAggregation(value1, corr, value2, this._op, false);
            return arg0.set(value1, corr);
        }
    }

    /**
     * {@code combineByKey} merge-value step for sums: adds an input block to the combiner's
     * value using a {@link KahanPlus} based operator.
     */
    private static class MergeSumBlockValueFunction
    implements Function2<CorrMatrixBlock, MatrixBlock, CorrMatrixBlock> {
        private static final long serialVersionUID = 3703543699467085539L;
        private final AggregateOperator _op = new AggregateOperator(0.0, KahanPlus.getKahanPlusFnObject(), true, PartialAggregate.CorrectionLocationType.NONE);

        private MergeSumBlockValueFunction() {
        }

        /**
         * @param arg0 combiner; updated via {@code set} and returned
         * @param arg1 input block to add
         * @return the result of {@code arg0.set(value, correction)}
         */
        public CorrMatrixBlock call(CorrMatrixBlock arg0, MatrixBlock arg1) throws Exception {
            MatrixBlock value = arg0.getValue();
            MatrixBlock corr = arg0.getCorrection();
            // Create the correction block on demand, sized like the value block.
            if (corr == null) {
                corr = new MatrixBlock(value.getNumRows(), value.getNumColumns(), false);
            }
            OperationsOnMatrixValues.incrementalAggregation(value, corr, arg1, this._op, false);
            return arg0.set(value, corr);
        }
    }

    /**
     * {@code combineByKey} create-combiner step for sums and aggregates: wraps the first
     * block of a key in a {@link CorrMatrixBlock}, optionally copying the block first.
     */
    private static class CreateCorrBlockCombinerFunction
    implements Function<MatrixBlock, CorrMatrixBlock> {
        private static final long serialVersionUID = -3666451526776017343L;
        private final boolean _deep;

        public CreateCorrBlockCombinerFunction(boolean deep) {
            this._deep = deep;
        }

        public CorrMatrixBlock call(MatrixBlock arg0) throws Exception {
            return new CorrMatrixBlock(this._deep ? new MatrixBlock(arg0) : arg0);
        }
    }
}

