/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.matrix.data;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.concurrent.Callable;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.matrix.data.ConvolutionParameters;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNN;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNNConv2dBackwardDataHelper;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNNConv2dBackwardFilterHelper;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNNConv2dHelper;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNNPoolingBackwardHelper;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNNPoolingHelper;
import org.apache.sysml.runtime.matrix.data.LibMatrixMult;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.util.ConvolutionUtils;
import org.apache.sysml.utils.NativeHelper;

public class LibMatrixDNNHelper {
    /**
     * Builds the max-pooling tasks by cutting the index range {@code [0, params.N)} into
     * contiguous chunks of {@code ceil(N / k)} entries, where {@code k} is the constrained
     * thread count derived from {@code params.numThreads}.
     *
     * <p>Each chunk becomes a {@code SparseMaxPooling} task when {@code params.input1} reports
     * sparse format, and a {@code DenseMaxPooling} task otherwise.
     *
     * @param params convolution parameters; {@code N}, {@code numThreads} and {@code input1} are read here
     * @return the created tasks in ascending chunk order
     * @throws DMLRuntimeException declared by the signature; this body contains no explicit throw
     */
    public static ArrayList<Callable<Long>> getMaxPoolingWorkers(ConvolutionParameters params) throws DMLRuntimeException {
        ArrayList<Callable<Long>> tasks = new ArrayList<Callable<Long>>();
        int numThreads = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
        int chunk = (int)Math.ceil((double)params.N / (double)numThreads);
        for (int rl = 0; rl < params.N; rl += chunk) {
            // The last chunk is clipped so it never runs past N.
            int ru = Math.min(rl + chunk, params.N);
            Callable<Long> task;
            if (params.input1.isInSparseFormat()) {
                task = new LibMatrixDNNPoolingHelper.SparseMaxPooling(rl, ru, params);
            } else {
                task = new LibMatrixDNNPoolingHelper.DenseMaxPooling(rl, ru, params);
            }
            tasks.add(task);
        }
        return tasks;
    }

    /**
     * Builds the max-pooling-backward tasks, one per contiguous chunk of {@code [0, params.N)}.
     *
     * <p>The task class is picked from the storage formats of the two inputs:
     * <ul>
     *   <li>input1 dense, input2 dense: {@code PoolingBackwardDenseDense}</li>
     *   <li>input1 dense, input2 sparse: {@code PoolingBackwardDenseSparse}</li>
     *   <li>input1 sparse, input2 dense: {@code PoolingBackwardSparseDense}</li>
     *   <li>input1 sparse, input2 sparse: {@code PoolingBackwardSparseSparse}</li>
     * </ul>
     *
     * @param params convolution parameters; {@code N}, {@code numThreads}, {@code input1} and {@code input2} are read here
     * @param performReluBackward flag forwarded unchanged to every created task
     * @return the created tasks in ascending chunk order
     * @throws DMLRuntimeException declared by the signature; this body contains no explicit throw
     */
    public static ArrayList<Callable<Long>> getMaxPoolingBackwardWorkers(ConvolutionParameters params, boolean performReluBackward) throws DMLRuntimeException {
        ArrayList<Callable<Long>> tasks = new ArrayList<Callable<Long>>();
        int numThreads = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
        int chunk = (int)Math.ceil((double)params.N / (double)numThreads);
        for (int rl = 0; rl < params.N; rl += chunk) {
            int ru = Math.min(rl + chunk, params.N);
            boolean firstSparse = params.input1.isInSparseFormat();
            boolean secondSparse = params.input2.isInSparseFormat();
            Callable<Long> task;
            if (firstSparse && secondSparse) {
                task = new LibMatrixDNNPoolingBackwardHelper.PoolingBackwardSparseSparse(rl, ru, params, performReluBackward);
            } else if (firstSparse) {
                task = new LibMatrixDNNPoolingBackwardHelper.PoolingBackwardSparseDense(rl, ru, params, performReluBackward);
            } else if (secondSparse) {
                task = new LibMatrixDNNPoolingBackwardHelper.PoolingBackwardDenseSparse(rl, ru, params, performReluBackward);
            } else {
                task = new LibMatrixDNNPoolingBackwardHelper.PoolingBackwardDenseDense(rl, ru, params, performReluBackward);
            }
            tasks.add(task);
        }
        return tasks;
    }

    /**
     * Builds one {@link ReluBackward} task per contiguous chunk of {@code [0, params.N)}.
     * The chunk size is {@code ceil(N / k)}, with {@code k} the constrained thread count.
     *
     * @param params convolution parameters; {@code N} and {@code numThreads} are read here
     * @return the created tasks in ascending chunk order
     * @throws DMLRuntimeException declared by the signature; this body contains no explicit throw
     */
    public static ArrayList<Callable<Long>> getReluBackwardWorkers(ConvolutionParameters params) throws DMLRuntimeException {
        ArrayList<Callable<Long>> tasks = new ArrayList<Callable<Long>>();
        int numThreads = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
        int chunk = (int)Math.ceil((double)params.N / (double)numThreads);
        for (int rl = 0; rl < params.N; rl += chunk) {
            int ru = Math.min(rl + chunk, params.N);
            tasks.add(new ReluBackward(rl, ru, params));
        }
        return tasks;
    }

    /**
     * Builds the conv2d tasks, one per contiguous chunk of {@code [0, params.N)}.
     *
     * <p>Per chunk the task class is chosen as follows:
     * <ol>
     *   <li>{@code SparseNativeConv2d} when {@code LibMatrixDNN.isEligibleForConv2dSparse(params)} holds;</li>
     *   <li>otherwise, if input1 is not a dense block with a {@code null} array,
     *       {@code LoopedIm2ColConv2dAllChannels} (or {@code LoopedIm2ColConv2dOneChannel}
     *       when the all-channels flag is off);</li>
     *   <li>otherwise the operation is rejected.</li>
     * </ol>
     *
     * <p>The all-channels flag is hard-wired to {@code true} in this method, so the one-channel
     * variant and the filter split that feeds it are currently never taken.
     *
     * @param params convolution parameters; {@code N}, {@code numThreads} and {@code input1} are read here
     * @return the created tasks in ascending chunk order
     * @throws DMLRuntimeException with message "Unsupported operator" when a chunk is neither
     *         eligible for the sparse native path nor backed by a non-empty input
     */
    public static ArrayList<Callable<Long>> getConv2dWorkers(ConvolutionParameters params) throws DMLRuntimeException {
        ArrayList<Callable<Long>> tasks = new ArrayList<Callable<Long>>();
        int numThreads = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
        int chunk = (int)Math.ceil((double)params.N / (double)numThreads);

        boolean allChannels = true;
        ArrayList<MatrixBlock> perChannelFilters = null;
        if (!allChannels) {
            perChannelFilters = LibMatrixDNNHelper.splitFilter(params);
        }

        // Dense format with no backing array: nothing the looped im2col tasks could read.
        boolean emptyDenseInput = !params.input1.isInSparseFormat() && params.input1.denseBlock == null;

        for (int rl = 0; rl < params.N; rl += chunk) {
            int ru = Math.min(rl + chunk, params.N);
            if (LibMatrixDNN.isEligibleForConv2dSparse(params)) {
                tasks.add(new LibMatrixDNNConv2dHelper.SparseNativeConv2d(rl, ru, params));
                continue;
            }
            if (emptyDenseInput) {
                throw new DMLRuntimeException("Unsupported operator");
            }
            if (allChannels) {
                tasks.add(new LibMatrixDNNConv2dHelper.LoopedIm2ColConv2dAllChannels(rl, ru, params));
            } else {
                tasks.add(new LibMatrixDNNConv2dHelper.LoopedIm2ColConv2dOneChannel(rl, ru, params, perChannelFilters));
            }
        }
        return tasks;
    }

    /**
     * Builds the conv2d-backward-filter tasks, one per contiguous chunk of {@code [0, params.N)}.
     *
     * <p>A chunk gets a {@code SparseNativeConv2dBackwardFilterDense} task when
     * {@code LibMatrixDNN.isEligibleForConv2dBackwardFilterSparseDense(params)} holds. Otherwise
     * it gets a {@code Conv2dBackwardFilter} task, provided neither input is a dense block
     * whose array is {@code null}.
     *
     * @param params convolution parameters; {@code N}, {@code numThreads}, {@code input1} and {@code input2} are read here
     * @return the created tasks in ascending chunk order
     * @throws DMLRuntimeException with message "Unsupported operator" when a chunk is not eligible
     *         for the sparse native path and one of the inputs is an unallocated dense block
     */
    public static ArrayList<Callable<Long>> getConv2dBackwardFilterWorkers(ConvolutionParameters params) throws DMLRuntimeException {
        ArrayList<Callable<Long>> tasks = new ArrayList<Callable<Long>>();
        int numThreads = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
        int chunk = (int)Math.ceil((double)params.N / (double)numThreads);
        boolean anyEmptyDenseInput = (!params.input1.isInSparseFormat() && params.input1.denseBlock == null)
            || (!params.input2.isInSparseFormat() && params.input2.denseBlock == null);
        for (int rl = 0; rl < params.N; rl += chunk) {
            int ru = Math.min(rl + chunk, params.N);
            if (LibMatrixDNN.isEligibleForConv2dBackwardFilterSparseDense(params)) {
                tasks.add(new LibMatrixDNNConv2dBackwardFilterHelper.SparseNativeConv2dBackwardFilterDense(rl, ru, params));
                continue;
            }
            if (anyEmptyDenseInput) {
                throw new DMLRuntimeException("Unsupported operator");
            }
            tasks.add(new LibMatrixDNNConv2dBackwardFilterHelper.Conv2dBackwardFilter(rl, ru, params));
        }
        return tasks;
    }

    /**
     * Builds the conv2d-backward-data tasks, one per contiguous chunk of {@code [0, params.N)}.
     *
     * <p>A chunk gets a {@code SparseNativeConv2dBackwardDataDense} task when
     * {@code LibMatrixDNN.isEligibleForConv2dBackwardDataDense(params)} holds. Otherwise it gets
     * a {@code Conv2dBackwardData} task, provided neither input is a dense block whose array
     * is {@code null}.
     *
     * @param params convolution parameters; {@code N}, {@code numThreads}, {@code input1} and {@code input2} are read here
     * @return the created tasks in ascending chunk order
     * @throws DMLRuntimeException with message "Unsupported operator" when a chunk is not eligible
     *         for the native path and one of the inputs is an unallocated dense block
     */
    public static ArrayList<Callable<Long>> getConv2dBackwardDataWorkers(ConvolutionParameters params) throws DMLRuntimeException {
        ArrayList<Callable<Long>> tasks = new ArrayList<Callable<Long>>();
        int numThreads = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
        int chunk = (int)Math.ceil((double)params.N / (double)numThreads);
        boolean anyEmptyDenseInput = (!params.input1.isInSparseFormat() && params.input1.denseBlock == null)
            || (!params.input2.isInSparseFormat() && params.input2.denseBlock == null);
        for (int rl = 0; rl < params.N; rl += chunk) {
            int ru = Math.min(rl + chunk, params.N);
            if (LibMatrixDNN.isEligibleForConv2dBackwardDataDense(params)) {
                tasks.add(new LibMatrixDNNConv2dBackwardDataHelper.SparseNativeConv2dBackwardDataDense(rl, ru, params));
                continue;
            }
            if (anyEmptyDenseInput) {
                throw new DMLRuntimeException("Unsupported operator");
            }
            tasks.add(new LibMatrixDNNConv2dBackwardDataHelper.Conv2dBackwardData(rl, ru, params));
        }
        return tasks;
    }

    /**
     * Decomposes a linearized offset {@code j} into three indexes for a layout in which the
     * last dimension has extent {@code W} and the middle one has extent {@code H}:
     * {@code j = ret[0] * H * W + ret[1] * W + ret[2]}.
     *
     * @param j   linearized offset to decompose
     * @param ret output array of length at least 3; entries 0..2 are overwritten
     * @param H   extent of the middle dimension
     * @param W   extent of the innermost dimension
     */
    static void computeTensorIndexes(int j, int[] ret, int H, int W) {
        int planeSize = H * W;
        int outer = j / planeSize;
        ret[0] = outer;
        int withinPlane = j - outer * planeSize;
        ret[1] = withinPlane / W;
        ret[2] = j % W;
    }

    /**
     * Splits the filter held in {@code _params.input2} into {@code C} dense blocks, one per
     * value of the leading index {@code c}, each of size {@code K x (R*S)}.
     *
     * <p>Column {@code c*R*S + rs} of filter row {@code k} is copied to column {@code rs} of
     * row {@code k} in the {@code c}-th result block. If the filter's dense array is
     * {@code null} the filter is read through {@code input2.sparseBlock} instead. The
     * non-zero count of every result block is set from the values actually written.
     *
     * @param _params convolution parameters; {@code K}, {@code C}, {@code R}, {@code S} and {@code input2} are read here
     * @return list of {@code C} newly allocated dense blocks, in ascending {@code c} order
     */
    private static ArrayList<MatrixBlock> splitFilter(ConvolutionParameters _params) {
        ArrayList<MatrixBlock> slices = new ArrayList<MatrixBlock>();
        final int S = _params.S;
        final int RS = _params.R * S;
        final int CRS = _params.C * RS;
        final double[] denseFilter = _params.input2.getDenseBlock();
        // Scratch buffer for the (c, r, s) decomposition; fully overwritten on every use.
        final int[] crs = new int[3];

        for (int c = 0; c < _params.C; c++) {
            MatrixBlock slice = new MatrixBlock(_params.K, RS, false);
            slice.allocateDenseBlock();
            double[] dst = slice.getDenseBlock();
            long nonZeros = 0L;

            for (int k = 0; k < _params.K; k++) {
                final int dstRow = k * RS;

                if (denseFilter != null) {
                    // Dense filter: copy the RS-wide stripe belonging to index c.
                    final int srcRow = k * CRS + c * RS;
                    for (int rs = 0; rs < RS; rs++) {
                        double v = denseFilter[srcRow + rs];
                        dst[dstRow + rs] = v;
                        if (v != 0.0) {
                            nonZeros++;
                        }
                    }
                    continue;
                }

                // Sparse filter: scan the row and keep only entries whose leading index is c.
                if (_params.input2.sparseBlock.isEmpty(k)) {
                    continue;
                }
                int begin = _params.input2.sparseBlock.pos(k);
                int count = _params.input2.sparseBlock.size(k);
                int[] colIdx = _params.input2.sparseBlock.indexes(k);
                double[] vals = _params.input2.sparseBlock.values(k);
                for (int j = begin; j < begin + count; j++) {
                    LibMatrixDNNHelper.computeTensorIndexes(colIdx[j], crs, _params.R, _params.S);
                    if (crs[0] != c) {
                        continue;
                    }
                    int target = dstRow + crs[1] * S + crs[2];
                    double v = vals[j];
                    dst[target] = v;
                    if (v != 0.0) {
                        nonZeros++;
                    }
                }
            }

            slice.setNonZeros(nonZeros);
            slices.add(slice);
        }
        return slices;
    }

    /**
     * Multiplies {@code m1} by {@code m2} into {@code ret}, using the native dense kernel when
     * possible and {@code LibMatrixMult.matrixMult} otherwise.
     *
     * <p>The native kernel is used only if {@code params.enableNative} is set and both operands
     * are in dense format. In that case {@code ret} is forced to dense format, allocated if
     * needed, and its non-zero count is recomputed afterwards. On the other path the
     * non-zero counts of the operands are refreshed first, as requested by the two flags.
     *
     * @param m1             left operand
     * @param m2             right operand
     * @param ret            output block, written in place
     * @param recomputeNNZM1 whether to refresh {@code m1}'s non-zero count (non-native path only)
     * @param recomputeNNZM2 whether to refresh {@code m2}'s non-zero count (non-native path only)
     * @param params         convolution parameters; only {@code enableNative} is read here
     * @throws DMLRuntimeException declared by the signature; this body contains no explicit throw
     */
    static void singleThreadedMatMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, boolean recomputeNNZM1, boolean recomputeNNZM2, ConvolutionParameters params) throws DMLRuntimeException {
        boolean useNative = params.enableNative && !m1.isInSparseFormat() && !m2.isInSparseFormat();
        if (useNative) {
            ret.sparse = false;
            if (ret.getDenseBlock() == null) {
                ret.allocateDenseBlock();
            }
            int rows = m1.getNumRows();
            int inner = m1.getNumColumns();
            int cols = m2.getNumColumns();
            NativeHelper.matrixMultDenseDense(m1.denseBlock, m2.denseBlock, ret.denseBlock, rows, inner, cols, 1);
            ret.recomputeNonZeros();
            return;
        }
        LibMatrixDNNHelper.prepNonZerosForMatrixMult(m1, recomputeNNZM1);
        LibMatrixDNNHelper.prepNonZerosForMatrixMult(m2, recomputeNNZM2);
        LibMatrixMult.matrixMult(m1, m2, ret, false);
    }

    /**
     * Adds a per-{@code k} bias to a slice of a flat array laid out as {@code [n][k][pq]}.
     *
     * <p>For every {@code n} in {@code [_rl, _ru)}, {@code k} in {@code [0, K)} and {@code pq}
     * in {@code [0, PQ)}, {@code biasArr[k]} is added to
     * {@code outputArr[n*K*PQ + k*PQ + pq]}.
     *
     * @param _rl       first outer index to process (inclusive)
     * @param _ru       last outer index to process (exclusive)
     * @param outputArr array updated in place
     * @param biasArr   bias values, indexed by {@code k}
     * @param K         extent of the {@code k} dimension
     * @param PQ        extent of the innermost dimension
     */
    static void addBias(int _rl, int _ru, double[] outputArr, double[] biasArr, int K, int PQ) {
        final int rowStride = K * PQ;
        for (int n = _rl; n < _ru; n++) {
            final int rowBase = n * rowStride;
            for (int k = 0; k < K; k++) {
                final int base = rowBase + k * PQ;
                for (int pq = 0; pq < PQ; pq++) {
                    outputArr[base + pq] += biasArr[k];
                }
            }
        }
    }

    /**
     * Finds the position of the largest value inside one window of a dense array.
     *
     * <p>The window spans rows {@code [start_indexes_h[p], end_indexes_h[p])} and columns
     * {@code [start_indexes_w[q], end_indexes_w[q])}; the cell {@code (h, w)} lives at
     * {@code inputOffset + h * params.W + w}. With {@code performReluBackward} set, negative
     * values are treated as {@code 0.0} for the comparison. Ties keep the first position
     * encountered in row-major order, because the comparison is strict.
     *
     * @param p                   index into the {@code h} window bounds
     * @param q                   index into the {@code w} window bounds
     * @param inputOffset         offset of the window's plane inside {@code inputArray}
     * @param inputArray          dense values to scan
     * @param params              supplies the window bounds and {@code W}
     * @param performReluBackward whether to clamp negative values to zero before comparing
     * @return absolute index into {@code inputArray} of the maximum, or {@code -1} if no cell
     *         compared greater than {@code -Double.MAX_VALUE} (for example an empty window)
     */
    static int getMaxIndex(int p, int q, int inputOffset, double[] inputArray, ConvolutionParameters params, boolean performReluBackward) {
        final int hBegin = params.start_indexes_h[p];
        final int hEnd = params.end_indexes_h[p];
        final int wBegin = params.start_indexes_w[q];
        final int wEnd = params.end_indexes_w[q];

        int bestIndex = -1;
        double bestValue = -Double.MAX_VALUE;
        for (int h = hBegin; h < hEnd; h++) {
            final int rowBase = inputOffset + h * params.W;
            for (int w = wBegin; w < wEnd; w++) {
                double candidate = inputArray[rowBase + w];
                if (performReluBackward && candidate < 0.0) {
                    candidate = 0.0;
                }
                if (bestValue < candidate) {
                    bestIndex = rowBase + w;
                    bestValue = candidate;
                }
            }
        }
        return bestIndex;
    }

    /**
     * Sparse counterpart of {@code getMaxIndex}: scans the stored entries of sparse row
     * {@code n} and returns the position of the largest one that belongs to index {@code c}
     * and falls inside the window selected by {@code p} and {@code q}.
     *
     * <p>Only explicitly stored entries take part in the comparison. With
     * {@code performReluBackward} set, negative stored values are compared as {@code 0.0}.
     *
     * <p>NOTE(review): when row {@code n} has no stored entries the result is
     * {@code inputOffset}, whereas a non-empty row with no entry inside the window yields
     * {@code -1}. Confirm that callers handle both outcomes.
     *
     * @param p                   index into the {@code h} window bounds
     * @param q                   index into the {@code w} window bounds
     * @param inputOffset         base offset added to {@code h * params.W + w} in the result
     * @param n                   sparse row to scan
     * @param c                   leading tensor index the entry must match
     * @param input               sparse-format block to read
     * @param params              supplies the window bounds, {@code H} and {@code W}
     * @param performReluBackward whether to clamp negative values to zero before comparing
     * @return {@code inputOffset + h * params.W + w} of the selected entry, {@code inputOffset}
     *         for an empty row, or {@code -1} if no stored entry qualified
     * @throws DMLRuntimeException if {@code input} is not in sparse format
     */
    static int getMaxIndexSparse(int p, int q, int inputOffset, int n, int c, MatrixBlock input, ConvolutionParameters params, boolean performReluBackward) throws DMLRuntimeException {
        if (!input.isInSparseFormat()) {
            throw new DMLRuntimeException("Incorrect usage: Only sparse format supported");
        }
        final int[] chw = new int[3];
        final int hBegin = params.start_indexes_h[p];
        final int hEnd = params.end_indexes_h[p];
        final int wBegin = params.start_indexes_w[q];
        final int wEnd = params.end_indexes_w[q];

        if (input.sparseBlock.isEmpty(n)) {
            return inputOffset;
        }

        final int begin = input.sparseBlock.pos(n);
        final int count = input.sparseBlock.size(n);
        final int[] colIdx = input.sparseBlock.indexes(n);
        final double[] vals = input.sparseBlock.values(n);

        int bestIndex = -1;
        double bestValue = -Double.MAX_VALUE;
        for (int j = begin; j < begin + count; j++) {
            LibMatrixDNNHelper.computeTensorIndexes(colIdx[j], chw, params.H, params.W);
            if (chw[0] != c) {
                continue;
            }
            final int h = chw[1];
            final int w = chw[2];
            boolean insideWindow = h >= hBegin && h < hEnd && w >= wBegin && w < wEnd;
            if (!insideWindow) {
                continue;
            }
            double candidate = vals[j];
            if (performReluBackward && candidate < 0.0) {
                candidate = 0.0;
            }
            if (bestValue < candidate) {
                bestIndex = inputOffset + h * params.W + w;
                bestValue = candidate;
            }
        }
        return bestIndex;
    }

    /**
     * Copies row {@code n} of {@code input} into {@code ret} as a dense vector.
     *
     * <p>For a sparse-format block, {@code ret} is zero-filled and the stored entries of the
     * row are then scattered into it. For a dense-format block the row is copied straight
     * from the block's backing array.
     *
     * @param input block to read from
     * @param n     row to extract
     * @param ret   destination; its length must equal the block's column count
     * @throws DMLRuntimeException with message "Invalid parameters" if {@code ret.length}
     *         differs from the block's column count
     */
    static void getRowInDenseFormat(MatrixBlock input, int n, double[] ret) throws DMLRuntimeException {
        if (input.getNumColumns() != ret.length) {
            throw new DMLRuntimeException("Invalid parameters");
        }
        if (!input.isInSparseFormat()) {
            System.arraycopy(input.getDenseBlock(), n * input.getNumColumns(), ret, 0, input.getNumColumns());
            return;
        }
        Arrays.fill(ret, 0.0);
        if (input.sparseBlock.isEmpty(n)) {
            return;
        }
        final int begin = input.sparseBlock.pos(n);
        final int end = begin + input.sparseBlock.size(n);
        final int[] colIdx = input.sparseBlock.indexes(n);
        final double[] vals = input.sparseBlock.values(n);
        for (int j = begin; j < end; j++) {
            ret[colIdx[j]] = vals[j];
        }
    }

    /**
     * Accumulates a {@code (P*Q) x (C*R*S)} block into slot {@code outputN} of the dense
     * output array {@code params.output}.
     *
     * <p>Each input cell at row {@code p*Q + q} and column {@code c*R*S + r*S + s} is added to
     * the output cell {@code (outputN, c, h, w)} with {@code h = p*stride_h + r - pad_h} and
     * {@code w = q*stride_w + s - pad_w}. On the sparse path, entries whose {@code h} or
     * {@code w} falls outside {@code [0, H)} / {@code [0, W)} are skipped. A dense input is
     * delegated to {@code doCol2IMDenseInput}; an empty sparse input leaves the output untouched.
     *
     * @param outputN slot of the output tensor to accumulate into
     * @param input   block to scatter; must have {@code P*Q} rows and {@code C*R*S} columns
     * @param params  convolution parameters and the output block
     * @throws DMLRuntimeException if the input dimensions do not match, if the output is in
     *         sparse format, or if a row index decomposes to a non-zero leading index
     */
    static void doCol2imOverSingleImage(int outputN, MatrixBlock input, ConvolutionParameters params) throws DMLRuntimeException {
        if (input.rlen != params.P * params.Q || input.clen != params.C * params.R * params.S) {
            throw new DMLRuntimeException("Incorrect input dimensions");
        }
        if (params.output.isInSparseFormat()) {
            throw new DMLRuntimeException("Only dense output is implemented");
        }
        double[] outputArray = params.output.getDenseBlock();
        if (!input.isInSparseFormat()) {
            double[] inputArray = input.getDenseBlock();
            LibMatrixDNNHelper.doCol2IMDenseInput(0, outputN, inputArray, outputArray, params);
        } else if (!input.isEmptyBlock()) {
            int[] tensorIndexes = new int[3];
            for (int i = 0; i < input.getNumRows(); ++i) {
                if (input.sparseBlock.isEmpty(i)) continue;
                // Row index i encodes (0, p, q); the leading component must be zero.
                LibMatrixDNNHelper.computeTensorIndexes(i, tensorIndexes, params.P, params.Q);
                int p = tensorIndexes[1];
                int q = tensorIndexes[2];
                if (tensorIndexes[0] != 0) {
                    // Every value is space-separated; the leading index and P were previously fused.
                    throw new DMLRuntimeException("Incorrect tensor indexes: " + tensorIndexes[0] + " != 0 <" + p + " " + q + " " + tensorIndexes[0] + " " + params.P + " " + params.Q + ">");
                }
                int apos = input.sparseBlock.pos(i);
                int alen = input.sparseBlock.size(i);
                int[] aix = input.sparseBlock.indexes(i);
                double[] avals = input.sparseBlock.values(i);
                for (int j = apos; j < apos + alen; ++j) {
                    // Column index encodes (c, r, s).
                    LibMatrixDNNHelper.computeTensorIndexes(aix[j], tensorIndexes, params.R, params.S);
                    int c = tensorIndexes[0];
                    int r = tensorIndexes[1];
                    int s = tensorIndexes[2];
                    int h = p * params.stride_h + r - params.pad_h;
                    int w = q * params.stride_w + s - params.pad_w;
                    // Positions that land outside the H x W plane are dropped.
                    if (h < 0 || h >= params.H || w < 0 || w >= params.W) continue;
                    int outIndex = outputN * params.C * params.H * params.W + c * params.H * params.W + h * params.W + w;
                    outputArray[outIndex] += avals[j];
                }
            }
        }
    }

    /**
     * Dense col2im accumulation: adds every in-bounds cell of the {@code (P*Q) x (C*R*S)}
     * slice {@code inputN} of {@code inputArray} to slot {@code outputN} of {@code outputArray}.
     *
     * <p>For output position {@code (p, q)} the patch origin is
     * {@code (p*stride_h - pad_h, q*stride_w - pad_w)}. The {@code r} and {@code s} ranges are
     * clipped up front so that {@code origin + r} stays in {@code [0, H)} and
     * {@code origin + s} stays in {@code [0, W)}; no per-cell bounds test is needed.
     *
     * @param inputN      slice of {@code inputArray} to read
     * @param outputN     slot of {@code outputArray} to accumulate into
     * @param inputArray  flat input laid out as {@code [n][p][q][c][r][s]}
     * @param outputArray flat output laid out as {@code [n][c][h][w]}, updated in place
     * @param params      convolution parameters (dimensions, strides, padding)
     * @throws DMLRuntimeException declared by the signature; this body contains no explicit throw
     */
    private static void doCol2IMDenseInput(int inputN, int outputN, double[] inputArray, double[] outputArray, ConvolutionParameters params) throws DMLRuntimeException {
        final int planeSize = params.H * params.W;
        final int kernelSize = params.R * params.S;
        final int patchSize = params.C * kernelSize;
        final int outBase = outputN * params.C * planeSize;
        final int inBase = inputN * params.P * params.Q;

        for (int p = 0; p < params.P; p++) {
            final int hOrigin = p * params.stride_h - params.pad_h;
            final int rFrom = Math.max(0, -hOrigin);
            final int rTo = Math.min(params.R, params.H - hOrigin);
            for (int q = 0; q < params.Q; q++) {
                final int wOrigin = q * params.stride_w - params.pad_w;
                final int sFrom = Math.max(0, -wOrigin);
                final int sTo = Math.min(params.S, params.W - wOrigin);
                final int patchBase = (inBase + p * params.Q + q) * patchSize;
                for (int c = 0; c < params.C; c++) {
                    final int dstPlane = outBase + c * planeSize;
                    final int srcKernel = patchBase + c * kernelSize;
                    for (int r = rFrom; r < rTo; r++) {
                        final int dstRow = dstPlane + (hOrigin + r) * params.W + wOrigin;
                        final int srcRow = srcKernel + r * params.S;
                        for (int s = sFrom; s < sTo; s++) {
                            outputArray[dstRow + s] += inputArray[srcRow + s];
                        }
                    }
                }
            }
        }
    }

    /**
     * Refreshes the non-zero count of {@code mb} before a matrix multiplication, if requested.
     *
     * <p>A dense-format block is marked as fully populated (rows times columns) without
     * scanning its values. A sparse-format block has its count recomputed.
     *
     * @param mb     block whose non-zero count is updated in place
     * @param update when {@code false} the method returns without touching {@code mb}
     */
    private static void prepNonZerosForMatrixMult(MatrixBlock mb, boolean update) {
        if (!update) {
            return;
        }
        if (!mb.isInSparseFormat()) {
            // Multiply in long: an int product of the two dimensions can wrap around
            // before it is widened for setNonZeros.
            mb.setNonZeros((long)mb.getNumRows() * (long)mb.getNumColumns());
        } else {
            mb.recomputeNonZeros();
        }
    }

    /**
     * Task that computes the ReLU backward pass for rows {@code [_rl, _ru)} and writes the
     * result into the dense array of {@code params.output}, which is captured at construction.
     *
     * <p>When both inputs are dense, each output cell receives the matching {@code input2}
     * value where the {@code input1} value is strictly positive, and {@code 0.0} elsewhere.
     * Otherwise the work is delegated to {@code ConvolutionUtils}: first a scalar {@code ">"}
     * operation against {@code 0.0} over {@code input1}, then an in-place binary operation with
     * {@code input2} using {@code LibMatrixDNN._binaryElementWiseMultiplication}.
     */
    public static class ReluBackward
    implements Callable<Long> {
        /** First row handled by this task (inclusive). */
        public int _rl;
        /** Last row handled by this task (exclusive). */
        public int _ru;
        private final ConvolutionParameters _params;
        /** Dense backing array of {@code params.output}, captured in the constructor. */
        double[] outputArray;
        /** Column count of {@code params.input1}; used as the row stride of the flat arrays. */
        int numOutCols;

        /**
         * @param rl     first row to process (inclusive)
         * @param ru     last row to process (exclusive)
         * @param params convolution parameters holding the two inputs and the output block
         */
        public ReluBackward(int rl, int ru, ConvolutionParameters params) {
            this._params = params;
            this._rl = rl;
            this._ru = ru;
            this.numOutCols = params.input1.getNumColumns();
            this.outputArray = params.output.getDenseBlock();
        }

        /**
         * Runs the task for its row range.
         *
         * @return always {@code 0L}
         */
        @Override
        public Long call() throws Exception {
            final MatrixBlock in = this._params.input1;
            final MatrixBlock dout = this._params.input2;
            final int from = this._rl * this.numOutCols;

            boolean bothDense = !in.isInSparseFormat() && !dout.isInSparseFormat();
            if (bothDense) {
                final double[] inputArr = in.getDenseBlock();
                final double[] doutArr = dout.getDenseBlock();
                final int to = this._ru * this.numOutCols;
                for (int i = from; i < to; i++) {
                    double gradient;
                    if (inputArr[i] > 0.0) {
                        gradient = doutArr[i];
                    } else {
                        gradient = 0.0;
                    }
                    this.outputArray[i] = gradient;
                }
                return 0L;
            }

            ConvolutionUtils.scalarOperations(in, this.outputArray, from, this.numOutCols, this._rl, this._ru, InstructionUtils.parseScalarBinaryOperator(">", false, 0.0));
            ConvolutionUtils.binaryOperationInPlace(dout, this.outputArray, from, this.numOutCols, this._rl, this._ru, LibMatrixDNN._binaryElementWiseMultiplication);
            return 0L;
        }
    }
}

