/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.instructions.spark.utils;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import org.apache.spark.HashPartitioner;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.storage.StorageLevel;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.lops.Checkpoint;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysml.runtime.instructions.spark.functions.CopyBinaryCellFunction;
import org.apache.sysml.runtime.instructions.spark.functions.CopyBlockFunction;
import org.apache.sysml.runtime.instructions.spark.functions.CopyBlockPairFunction;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.FrameBlock;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixCell;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.MatrixValue;
import org.apache.sysml.runtime.matrix.data.Pair;
import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue;
import org.apache.sysml.runtime.util.UtilFunctions;
import scala.Tuple2;

public class SparkUtils {
    /** Storage level used for temporary RDD caching; aliases {@link Checkpoint#DEFAULT_STORAGE_LEVEL}. */
    public static final StorageLevel DEFAULT_TMP = Checkpoint.DEFAULT_STORAGE_LEVEL;

    /**
     * Wraps a Spark (indexes, block) tuple into an {@link IndexedMatrixValue}.
     *
     * @param in tuple of block indexes and matrix block
     * @return an indexed matrix value built from the tuple's two components
     */
    public static IndexedMatrixValue toIndexedMatrixBlock(Tuple2<MatrixIndexes, MatrixBlock> in) {
        return new IndexedMatrixValue(in._1(), in._2());
    }

    /**
     * Wraps block indexes and a matrix block into an {@link IndexedMatrixValue}.
     *
     * @param ix block indexes
     * @param mb matrix block
     * @return an indexed matrix value built from {@code ix} and {@code mb}
     */
    public static IndexedMatrixValue toIndexedMatrixBlock(MatrixIndexes ix, MatrixBlock mb) {
        return new IndexedMatrixValue(ix, mb);
    }

    /**
     * Converts an {@link IndexedMatrixValue} into a Spark tuple.
     *
     * @param in indexed value; its value is cast to {@link MatrixBlock}
     * @return tuple of (indexes, block)
     * @throws ClassCastException if the contained value is not a {@link MatrixBlock}
     */
    public static Tuple2<MatrixIndexes, MatrixBlock> fromIndexedMatrixBlock(IndexedMatrixValue in) {
        return new Tuple2<>(in.getIndexes(), (MatrixBlock) in.getValue());
    }

    /**
     * Converts a list of {@link IndexedMatrixValue}s into a list of Spark tuples, preserving order.
     *
     * @param in indexed values to convert
     * @return new list of (indexes, block) tuples
     */
    public static ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> fromIndexedMatrixBlock(ArrayList<IndexedMatrixValue> in) {
        ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> ret = new ArrayList<>(in.size());
        for (IndexedMatrixValue imv : in) {
            ret.add(fromIndexedMatrixBlock(imv));
        }
        return ret;
    }

    /**
     * Converts an {@link IndexedMatrixValue} into a SystemML {@link Pair}.
     *
     * @param in indexed value; its value is cast to {@link MatrixBlock}
     * @return pair of (indexes, block)
     * @throws ClassCastException if the contained value is not a {@link MatrixBlock}
     */
    public static Pair<MatrixIndexes, MatrixBlock> fromIndexedMatrixBlockToPair(IndexedMatrixValue in) {
        return new Pair<>(in.getIndexes(), (MatrixBlock) in.getValue());
    }

    /**
     * Converts a list of {@link IndexedMatrixValue}s into a list of {@link Pair}s, preserving order.
     *
     * @param in indexed values to convert
     * @return new list of (indexes, block) pairs
     */
    public static ArrayList<Pair<MatrixIndexes, MatrixBlock>> fromIndexedMatrixBlockToPair(ArrayList<IndexedMatrixValue> in) {
        ArrayList<Pair<MatrixIndexes, MatrixBlock>> ret = new ArrayList<>(in.size());
        for (IndexedMatrixValue imv : in) {
            ret.add(fromIndexedMatrixBlockToPair(imv));
        }
        return ret;
    }

    /**
     * Converts a (row offset, frame block) {@link Pair} into a Spark tuple.
     *
     * @param in pair to convert
     * @return tuple with the same key and value
     */
    public static Tuple2<Long, FrameBlock> fromIndexedFrameBlock(Pair<Long, FrameBlock> in) {
        return new Tuple2<>(in.getKey(), in.getValue());
    }

    /**
     * Converts a list of frame-block {@link Pair}s into Spark tuples, preserving order.
     *
     * @param in pairs to convert
     * @return new list of tuples
     */
    public static ArrayList<Tuple2<Long, FrameBlock>> fromIndexedFrameBlock(ArrayList<Pair<Long, FrameBlock>> in) {
        ArrayList<Tuple2<Long, FrameBlock>> ret = new ArrayList<>(in.size());
        for (Pair<Long, FrameBlock> ifv : in) {
            ret.add(fromIndexedFrameBlock(ifv));
        }
        return ret;
    }

    /**
     * Converts a list of (long, long) Spark tuples into SystemML {@link Pair}s, preserving order.
     *
     * @param in tuples to convert
     * @return new list of pairs
     */
    public static ArrayList<Pair<Long, Long>> toIndexedLong(List<Tuple2<Long, Long>> in) {
        ArrayList<Pair<Long, Long>> ret = new ArrayList<>(in.size());
        for (Tuple2<Long, Long> e : in) {
            ret.add(new Pair<>(e._1(), e._2()));
        }
        return ret;
    }

    /**
     * Converts a (row offset, frame block) Spark tuple into a SystemML {@link Pair}.
     *
     * @param in tuple to convert
     * @return pair with the same key and value
     */
    public static Pair<Long, FrameBlock> toIndexedFrameBlock(Tuple2<Long, FrameBlock> in) {
        return new Pair<>(in._1(), in._2());
    }

    /**
     * Checks whether the given RDD carries a {@link HashPartitioner}.
     *
     * @param in RDD to inspect
     * @return true iff the RDD has a partitioner and it is a {@link HashPartitioner}
     */
    public static boolean isHashPartitioned(JavaPairRDD<?, ?> in) {
        return !in.rdd().partitioner().isEmpty() && in.rdd().partitioner().get() instanceof HashPartitioner;
    }

    /**
     * Returns the preferred number of partitions, falling back to the RDD's current partitioning
     * when the dimensions (including nnz, per {@code dimsKnown(true)}) are unknown.
     *
     * @param mc matrix characteristics
     * @param in optional input RDD; may be {@code null}
     * @return preferred partition count
     */
    public static int getNumPreferredPartitions(MatrixCharacteristics mc, JavaPairRDD<?, ?> in) {
        if (!mc.dimsKnown(true) && in != null) {
            return in.getNumPartitions();
        }
        return getNumPreferredPartitions(mc);
    }

    /**
     * Returns the preferred number of partitions for a matrix of the given characteristics.
     *
     * <p>If the dimensions are unknown, this is Spark's default parallelism. Otherwise it is the
     * estimated partitioned size divided by the HDFS block size, rounded up, and at least 1.
     *
     * @param mc matrix characteristics
     * @return preferred partition count (>= 1 when dimensions are known)
     */
    public static int getNumPreferredPartitions(MatrixCharacteristics mc) {
        if (!mc.dimsKnown()) {
            return SparkExecutionContext.getDefaultParallelism(true);
        }
        double hdfsBlockSize = InfrastructureAnalyzer.getHDFSBlockSize();
        double matrixPSize = OptimizerUtils.estimatePartitionedSizeExactSparsity(mc);
        return (int) Math.max(Math.ceil(matrixPSize / hdfsBlockSize), 1.0);
    }

    /**
     * Copies a binary-block matrix RDD using a deep copy.
     *
     * @param in input RDD
     * @return copied RDD
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> copyBinaryBlockMatrix(JavaPairRDD<MatrixIndexes, MatrixBlock> in) {
        return copyBinaryBlockMatrix(in, true);
    }

    /**
     * Copies a binary-block matrix RDD.
     *
     * @param in input RDD
     * @param deep if false, maps values through {@code CopyBlockFunction(false)}; if true, maps
     *     partitions through {@code CopyBlockPairFunction(true)} while preserving partitioning
     * @return copied RDD
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> copyBinaryBlockMatrix(JavaPairRDD<MatrixIndexes, MatrixBlock> in, boolean deep) {
        if (!deep) {
            return in.mapValues(new CopyBlockFunction(false));
        }
        return in.mapPartitionsToPair(new CopyBlockPairFunction(deep), true);
    }

    /**
     * Extracts the start-line token from a Spark debug-info line.
     *
     * <p>Drops the first 4 characters and returns everything before the first ':' (or the whole
     * remainder if there is no ':').
     *
     * @param line debug-info line
     * @return start-line token
     * @throws DMLRuntimeException if {@code line} is null or shorter than 4 characters
     */
    public static String getStartLineFromSparkDebugInfo(String line) throws DMLRuntimeException {
        if (line == null || line.length() < 4) {
            throw new DMLRuntimeException("Invalid Spark debug info line (expected at least 4 characters): " + line);
        }
        String withoutPrefix = line.substring(4);
        // Equivalent to split(":")[0], but does not throw when the remainder consists only of ':'.
        int colon = withoutPrefix.indexOf(':');
        return colon >= 0 ? withoutPrefix.substring(0, colon) : withoutPrefix;
    }

    /**
     * Rebuilds the tree-drawing prefix of a Spark debug-info line.
     *
     * <p>The line is split on "|" and "+-". All but the last segment are re-joined with "|". The
     * result ends with "+- " if the line contains "+-", otherwise with "|" followed by two spaces.
     *
     * @param line debug-info line
     * @return reconstructed prefix
     */
    public static String getPrefixFromSparkDebugInfo(String line) {
        String[] lines = line.split("\\||\\+-");
        // String.split drops trailing empty strings, so a delimiter-only line yields an empty array.
        StringBuilder retVal = new StringBuilder(lines.length > 0 ? lines[0] : "");
        for (int i = 1; i < lines.length - 1; ++i) {
            retVal.append('|').append(lines[i]);
        }
        if (line.contains("+-")) {
            return retVal.append("+- ").toString();
        }
        return retVal.append("|  ").toString();
    }

    /**
     * Creates an RDD containing one empty block per block position of a matrix.
     *
     * <p>The degree of parallelism is at least Spark's default parallelism, or the estimated
     * output size divided by the HDFS block size if that is larger. It is capped at the number of
     * blocks and is never less than 1.
     *
     * @param sc Spark context
     * @param mc matrix characteristics (dimensions and block sizes must be set)
     * @return RDD of (indexes, empty block) tuples
     */
    public static JavaPairRDD<MatrixIndexes, MatrixBlock> getEmptyBlockRDD(JavaSparkContext sc, MatrixCharacteristics mc) {
        long numBlocks = mc.getNumBlocks();
        long size = numBlocks * OptimizerUtils.estimateSizeEmptyBlock(
                Math.min(Math.max(mc.getRows(), 1L), (long) mc.getRowsPerBlock()),
                Math.min(Math.max(mc.getCols(), 1L), (long) mc.getColsPerBlock()));
        // Divide in floating point; integer division would make Math.ceil a no-op.
        double sizeInHdfsBlocks = Math.ceil((double) size / InfrastructureAnalyzer.getHDFSBlockSize());
        int par = (int) Math.min(
                Math.max((double) SparkExecutionContext.getDefaultParallelism(true), sizeInHdfsBlocks),
                (double) numBlocks);
        // A matrix with zero blocks would otherwise request zero partitions from parallelize.
        par = Math.max(par, 1);
        final long pNumBlocks = (long) Math.ceil((double) numBlocks / par);

        // One start offset per partition; each task emits blocks [offset, offset + pNumBlocks).
        List<Long> offsets = LongStream.iterate(0L, n -> n + pNumBlocks)
                .limit(par).boxed().collect(Collectors.toList());
        return sc.parallelize(offsets, par).flatMapToPair(new GenerateEmptyBlocks(mc, pNumBlocks));
    }

    /**
     * Caches a binary-cell RDD at {@link #DEFAULT_TMP}.
     *
     * <p>If the RDD is not already persisted at that level, each cell is first mapped through
     * {@code CopyBinaryCellFunction}. NOTE(review): the copy is presumably needed because input
     * formats reuse cell objects — confirm.
     *
     * @param input binary-cell RDD
     * @return {@code input} if already stored at {@link #DEFAULT_TMP}, otherwise a persisted copy
     */
    public static JavaPairRDD<MatrixIndexes, MatrixCell> cacheBinaryCellRDD(JavaPairRDD<MatrixIndexes, MatrixCell> input) {
        if (input.getStorageLevel().equals(DEFAULT_TMP)) {
            return input;
        }
        return input.mapToPair(new CopyBinaryCellFunction()).persist(DEFAULT_TMP);
    }

    /**
     * Computes matrix characteristics from a binary-cell RDD.
     *
     * <p>The result holds the maximum row and column index, block sizes of 0, and the count of
     * non-zero cells.
     *
     * @param input binary-cell RDD (must be non-empty, since {@code reduce} is used)
     * @return aggregated characteristics
     */
    public static MatrixCharacteristics computeMatrixCharacteristics(JavaPairRDD<MatrixIndexes, MatrixCell> input) {
        return input.map(new AnalyzeCellMatrixCharacteristics())
                .reduce(new AggregateMatrixCharacteristics());
    }

    /**
     * Sums the non-zero counts of all blocks in a binary-block RDD.
     *
     * @param input binary-block RDD
     * @return total non-zeros; 0 for an empty RDD
     */
    public static long getNonZeros(JavaPairRDD<MatrixIndexes, MatrixBlock> input) {
        // fold (not reduce) so that an empty RDD yields 0 instead of failing.
        return input.values().map(MatrixBlock::getNonZeros).fold(0L, Long::sum);
    }

    /** Emits the empty blocks for one partition, given the partition's linear start block offset. */
    private static class GenerateEmptyBlocks
    implements PairFlatMapFunction<Long, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 630129586089106855L;
        private final MatrixCharacteristics _mc;
        private final long _pNumBlocks;

        public GenerateEmptyBlocks(MatrixCharacteristics mc, long pNumBlocks) {
            this._mc = mc;
            this._pNumBlocks = pNumBlocks;
        }

        @Override
        public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Long offset) throws Exception {
            long start = offset;
            long ncblks = _mc.getNumColBlocks();
            long end = Math.min(start + _pNumBlocks, _mc.getNumBlocks());
            List<Tuple2<MatrixIndexes, MatrixBlock>> list = new ArrayList<>();
            for (long i = start; i < end; ++i) {
                // Linear block index -> 1-based (row block, column block), row-major.
                long rix = 1L + i / ncblks;
                long cix = 1L + i % ncblks;
                int lrlen = UtilFunctions.computeBlockSize(_mc.getRows(), rix, _mc.getRowsPerBlock());
                int lclen = UtilFunctions.computeBlockSize(_mc.getCols(), cix, _mc.getColsPerBlock());
                list.add(new Tuple2<>(new MatrixIndexes(rix, cix), new MatrixBlock(lrlen, lclen, true)));
            }
            return list.iterator();
        }
    }

    /** Combines two partial characteristics: max of dims, sum of non-zeros. */
    private static class AggregateMatrixCharacteristics
    implements Function2<MatrixCharacteristics, MatrixCharacteristics, MatrixCharacteristics> {
        private static final long serialVersionUID = 4263886749699779994L;

        private AggregateMatrixCharacteristics() {
        }

        @Override
        public MatrixCharacteristics call(MatrixCharacteristics arg0, MatrixCharacteristics arg1) throws Exception {
            return new MatrixCharacteristics(
                    Math.max(arg0.getRows(), arg1.getRows()),
                    Math.max(arg0.getCols(), arg1.getCols()),
                    arg0.getRowsPerBlock(), arg0.getColsPerBlock(),
                    arg0.getNonZeros() + arg1.getNonZeros());
        }
    }

    /** Maps one cell to characteristics: (row index, col index, 0, 0, 1 if non-zero else 0). */
    private static class AnalyzeCellMatrixCharacteristics
    implements Function<Tuple2<MatrixIndexes, MatrixCell>, MatrixCharacteristics> {
        private static final long serialVersionUID = 8899395272683723008L;

        private AnalyzeCellMatrixCharacteristics() {
        }

        @Override
        public MatrixCharacteristics call(Tuple2<MatrixIndexes, MatrixCell> arg0) throws Exception {
            long rix = arg0._1().getRowIndex();
            long cix = arg0._1().getColumnIndex();
            long nnz = arg0._2().getValue() != 0.0 ? 1L : 0L;
            return new MatrixCharacteristics(rix, cix, 0, 0, nnz);
        }
    }
}

