/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.matrix.data;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.matrix.data.ConvolutionParameters;
import org.apache.sysml.runtime.matrix.data.IJV;
import org.apache.sysml.runtime.matrix.data.LibMatrixMult;
import org.apache.sysml.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;

public class LibMatrixDNN {
    /** Class logger (not referenced in the visible part of this file). */
    protected static final Log LOG = LogFactory.getLog((String)LibMatrixDNN.class.getName());
    // NOTE(review): the next three constants are not referenced in the visible part of this file;
    // confirm their exact effect (sparse outputs / multi-threading / tasks-per-thread factor) at their use sites.
    public static final boolean SUPPORTS_SPARSE_OUTPUTS = false;
    private static final boolean ALLOW_MULTI_THREADED_OPS = true;
    private static final int NUM_TASK_FACTOR = 2;
    /** When true and DMLScript.STATISTICS is enabled, the counters and timers below are updated and reported. */
    public static boolean DISPLAY_STATISTICS = false;
    // Invocation counters per operation, split by whether any operand was in sparse format.
    private static AtomicLong conv2dSparseCount = new AtomicLong(0L);
    private static AtomicLong conv2dDenseCount = new AtomicLong(0L);
    private static AtomicLong conv2dBwdFilterSparseCount = new AtomicLong(0L);
    private static AtomicLong conv2dBwdFilterDenseCount = new AtomicLong(0L);
    private static AtomicLong conv2dBwdDataSparseCount = new AtomicLong(0L);
    private static AtomicLong conv2dBwdDataDenseCount = new AtomicLong(0L);
    private static AtomicLong im2colSparseCount = new AtomicLong(0L);
    private static AtomicLong im2colDenseCount = new AtomicLong(0L);
    private static AtomicLong maxPoolBwdSparseCount = new AtomicLong(0L);
    private static AtomicLong maxPoolBwdDenseCount = new AtomicLong(0L);
    // Cumulative System.nanoTime() deltas for the per-image phases of the looped convolution kernels
    // (reported in seconds by appendStatistics).
    private static AtomicLong loopedConvMatMultTime = new AtomicLong(0L);
    private static AtomicLong loopedConvIm2ColTime = new AtomicLong(0L);
    private static AtomicLong loopedConvBwdFilterMatMultTime = new AtomicLong(0L);
    private static AtomicLong loopedConvBwdFilterIm2ColTime = new AtomicLong(0L);
    private static AtomicLong loopedConvBwdDataMatMultTime = new AtomicLong(0L);
    private static AtomicLong loopedConvBwdDataCol2ImTime = new AtomicLong(0L);
    // Test hooks: when set, conv2d/biasAdd/biasMultiply convert a dense input (TEST_SPARSE_INPUT)
    // or a dense filter/bias (TEST_SPARSE_FILTER) to sparse format before computing.
    public static boolean TEST_SPARSE_INPUT = false;
    public static boolean TEST_SPARSE_FILTER = false;

    /**
     * Appends LibMatrixDNN operation counts and per-phase timings to {@code sb}.
     * Nothing is appended unless statistics are enabled (DMLScript.STATISTICS and
     * DISPLAY_STATISTICS) and at least one conv2d call has been recorded.
     *
     * @param sb builder receiving the statistics lines
     */
    public static void appendStatistics(StringBuilder sb) {
        if (!(DMLScript.STATISTICS && DISPLAY_STATISTICS)) {
            return;
        }
        boolean convRecorded = conv2dDenseCount.get() != 0L || conv2dSparseCount.get() != 0L;
        if (!convRecorded) {
            return;
        }
        sb.append("LibMatrixDNN dense count (conv/bwdF/bwdD/im2col/maxBwd):\t")
            .append(conv2dDenseCount.get()).append('/')
            .append(conv2dBwdFilterDenseCount.get()).append('/')
            .append(conv2dBwdDataDenseCount.get()).append('/')
            .append(im2colDenseCount.get()).append('/')
            .append(maxPoolBwdDenseCount.get()).append(".\n");
        sb.append("LibMatrixDNN sparse count (conv/bwdF/bwdD/im2col/maxBwd):\t")
            .append(conv2dSparseCount.get()).append('/')
            .append(conv2dBwdFilterSparseCount.get()).append('/')
            .append(conv2dBwdDataSparseCount.get()).append('/')
            .append(im2colSparseCount.get()).append('/')
            .append(maxPoolBwdSparseCount.get()).append(".\n");
        boolean timesRecorded = loopedConvMatMultTime.get() != 0L || loopedConvIm2ColTime.get() != 0L;
        if (timesRecorded) {
            // Timers hold nanoseconds; each is printed in seconds with three decimals.
            String seconds = String.format("%.3f/%.3f/%.3f/%.3f/%.3f/%.3f",
                (double) loopedConvIm2ColTime.get() * 1.0E-9,
                (double) loopedConvMatMultTime.get() * 1.0E-9,
                (double) loopedConvBwdFilterIm2ColTime.get() * 1.0E-9,
                (double) loopedConvBwdFilterMatMultTime.get() * 1.0E-9,
                (double) loopedConvBwdDataCol2ImTime.get() * 1.0E-9,
                (double) loopedConvBwdDataMatMultTime.get() * 1.0E-9);
            sb.append("LibMatrixDNN conv(im2col/matmult), bwdF (im2col/matmult), bwdD (col2im/matmult) time:\t")
                .append(seconds)
                .append(" sec.\n");
        }
    }

    /** Zeroes every operation counter and phase timer tracked by this class. */
    public static void resetStatistics() {
        AtomicLong[] trackers = {
            conv2dDenseCount, conv2dBwdFilterDenseCount, conv2dBwdDataDenseCount,
            im2colDenseCount, maxPoolBwdDenseCount,
            conv2dSparseCount, conv2dBwdFilterSparseCount, conv2dBwdDataSparseCount,
            im2colSparseCount, maxPoolBwdSparseCount,
            loopedConvIm2ColTime, loopedConvMatMultTime,
            loopedConvBwdFilterMatMultTime, loopedConvBwdFilterIm2ColTime,
            loopedConvBwdDataMatMultTime, loopedConvBwdDataCol2ImTime
        };
        for (AtomicLong tracker : trackers) {
            tracker.set(0L);
        }
    }

    /**
     * Computes the conv2d gradient with respect to the data (input) using the looped
     * rotate/matmult/col2im kernel, then recomputes the output's non-zero count.
     *
     * @param filter filter matrix, expected shape K x (C*R*S)
     * @param dout output gradient, expected shape N x (K*P*Q)
     * @param outputBlock block receiving the result
     * @param params convolution parameters; input1/input2/output are set by this method
     * @throws DMLRuntimeException if the operand shapes do not match {@code params} or a stride is not positive
     */
    public static void conv2dBackwardData(MatrixBlock filter, MatrixBlock dout, MatrixBlock outputBlock, ConvolutionParameters params) throws DMLRuntimeException {
        params.input1 = filter;
        params.input2 = dout;
        params.output = outputBlock;
        if (filter.getNumRows() != params.K || filter.getNumColumns() != params.C * params.R * params.S || dout.getNumRows() != params.N || dout.getNumColumns() != params.K * params.P * params.Q) {
            // Name the actual operation and report actual vs. expected shapes to ease debugging.
            throw new DMLRuntimeException("Incorrect input to conv2d_backward_data: filter[" + filter.getNumRows() + " X " + filter.getNumColumns()
                + "] (expected [" + params.K + " X " + params.C * params.R * params.S + "]), dout[" + dout.getNumRows() + " X " + dout.getNumColumns()
                + "] (expected [" + params.N + " X " + params.K * params.P * params.Q + "])");
        }
        if (params.stride_h <= 0 || params.stride_w <= 0) {
            throw new DMLRuntimeException("Only positive strides supported");
        }
        if (DMLScript.STATISTICS && DISPLAY_STATISTICS) {
            if (filter.isInSparseFormat() || dout.isInSparseFormat()) {
                conv2dBwdDataSparseCount.addAndGet(1L);
            } else {
                conv2dBwdDataDenseCount.addAndGet(1L);
            }
        }
        LibMatrixDNN.runConvTask(TaskType.LoopedIm2ColConv2dBwdData, params);
        outputBlock.recomputeNonZeros();
    }

    /**
     * Computes the conv2d gradient with respect to the filter using the looped
     * im2col/matmult kernel, then recomputes the output's non-zero count.
     *
     * @param input forward input, expected shape N x (C*H*W)
     * @param dout output gradient, expected shape N x (K*P*Q)
     * @param outputBlock block receiving the result
     * @param params convolution parameters; input1/input2/output are set by this method
     * @throws DMLRuntimeException if the operand shapes do not match {@code params} or a stride is not positive
     */
    public static void conv2dBackwardFilter(MatrixBlock input, MatrixBlock dout, MatrixBlock outputBlock, ConvolutionParameters params) throws DMLRuntimeException {
        params.input1 = input;
        params.input2 = dout;
        params.output = outputBlock;
        boolean inputShapeOk = input.getNumRows() == params.N && input.getNumColumns() == params.C * params.H * params.W;
        boolean doutShapeOk = dout.getNumRows() == params.N && dout.getNumColumns() == params.K * params.P * params.Q;
        if (!inputShapeOk || !doutShapeOk) {
            throw new DMLRuntimeException("Incorrect input to conv2d_backward_filter");
        }
        boolean stridesPositive = params.stride_h > 0 && params.stride_w > 0;
        if (!stridesPositive) {
            throw new DMLRuntimeException("Only positive strides supported");
        }
        if (DMLScript.STATISTICS && DISPLAY_STATISTICS) {
            AtomicLong counter = (input.isInSparseFormat() || dout.isInSparseFormat())
                ? conv2dBwdFilterSparseCount
                : conv2dBwdFilterDenseCount;
            counter.incrementAndGet();
        }
        LibMatrixDNN.runConvTask(TaskType.LoopedIm2ColConv2dBwdFilter, params);
        outputBlock.recomputeNonZeros();
    }

    /**
     * Adds {@code elem} into {@code ret} cell by cell, in place.
     *
     * @param ret dense accumulator (sparse accumulators are rejected)
     * @param elem dense or sparse addend with the same shape as {@code ret}
     * @throws DMLRuntimeException on shape mismatch or a sparse {@code ret}
     */
    private static void elementWiseInPlaceAddition(MatrixBlock ret, MatrixBlock elem) throws DMLRuntimeException {
        int rows = ret.getNumRows();
        int cols = ret.getNumColumns();
        if (rows != elem.getNumRows() || cols != elem.getNumColumns()) {
            throw new DMLRuntimeException("Incorrect dimensions");
        }
        if (ret.isInSparseFormat()) {
            throw new DMLRuntimeException("Sparse return format not supported");
        }
        double[] dst = ret.denseBlock;
        if (elem.isInSparseFormat()) {
            // Only stored cells contribute; an empty sparse block adds nothing.
            if (elem.isEmptyBlock()) {
                return;
            }
            Iterator<IJV> cells = elem.sparseBlock.getIterator();
            while (cells.hasNext()) {
                IJV cell = cells.next();
                dst[cell.getI() * cols + cell.getJ()] += cell.getV();
            }
        } else {
            double[] src = elem.denseBlock;
            int total = rows * cols;
            for (int idx = 0; idx < total; ++idx) {
                dst[idx] += src[idx];
            }
        }
    }

    /**
     * Sums all dense partial blocks into {@code elem[0]} (mutating it), then writes the
     * transpose of that sum into {@code ret}.
     *
     * @param ret dense destination of the transposed sum
     * @param elem non-empty array of dense partial results of identical size
     * @throws DMLRuntimeException if {@code elem} is null/empty or any block is sparse
     */
    private static void elementWiseInPlaceTransposedAddition(MatrixBlock ret, MatrixBlock[] elem) throws DMLRuntimeException {
        if (elem == null || elem.length == 0) {
            throw new DMLRuntimeException("Empty input not supported.");
        }
        for (int i = 0; i < elem.length; ++i) {
            if (elem[i].isInSparseFormat()) {
                throw new DMLRuntimeException("Sparse input format not supported.");
            }
        }
        if (ret.isInSparseFormat()) {
            throw new DMLRuntimeException("Sparse output format not supported.");
        }
        // Reuse the first partial block's array as the accumulator to avoid an extra allocation.
        MatrixBlock accumulator = elem[0];
        double[] acc = accumulator.denseBlock;
        for (int k = 1; k < elem.length; ++k) {
            double[] part = elem[k].denseBlock;
            for (int i = 0; i < acc.length; ++i) {
                acc[i] += part[i];
            }
        }
        // Mark the non-zero count as unknown since the accumulator was modified directly.
        accumulator.setNonZeros(-1L);
        LibMatrixReorg.transpose(accumulator, ret);
    }

    /**
     * Processes image {@code n} for conv2d backward-data: rotates its dout slice into
     * {@code dout_reshaped}, multiplies by the filter, and scatters the product via col2im.
     * Matmult and col2im times are accumulated when statistics are enabled.
     *
     * @param n image (row) index
     * @param dout_reshaped scratch block whose dense array receives the rotated dout slice
     * @param params convolution parameters (input1 = filter, input2 = dout)
     */
    private static void doLoopedIm2ColConv2dBwdData(int n, MatrixBlock dout_reshaped, ConvolutionParameters params) throws DMLRuntimeException {
        MatrixBlock filter = params.input1;
        MatrixBlock dout = params.input2;
        boolean timed = DMLScript.STATISTICS && DISPLAY_STATISTICS;
        LibMatrixDNN.doRotate180(n, 0, dout, dout_reshaped.denseBlock, params, true);
        dout_reshaped.recomputeNonZeros();
        MatrixBlock product = new MatrixBlock(params.P * params.Q, params.C * params.R * params.S, false);
        long matMultStart = timed ? System.nanoTime() : 0L;
        LibMatrixMult.matrixMult(dout_reshaped, filter, product, false);
        long col2imStart = timed ? System.nanoTime() : 0L;
        LibMatrixDNN.doCol2imOverSingleImage(n, product, params);
        if (timed) {
            long col2imEnd = System.nanoTime();
            loopedConvBwdDataMatMultTime.addAndGet(col2imStart - matMultStart);
            loopedConvBwdDataCol2ImTime.addAndGet(col2imEnd - col2imStart);
        }
    }

    /**
     * Processes image {@code n} for conv2d backward-filter: im2col of the input, rotation of
     * the dout slice, their product, and accumulation of that product into {@code partialRetBlock}.
     * Im2col and matmult times are accumulated when statistics are enabled.
     *
     * @param n image (row) index
     * @param im2ColOutBlock scratch block for the im2col result
     * @param dout_reshaped scratch block whose dense array receives the rotated dout slice
     * @param partialRetBlock per-task accumulator of shape (C*R*S) x K
     * @param params convolution parameters (input2 = dout)
     * @return {@code partialRetBlock}, updated in place
     */
    private static MatrixBlock doLoopedIm2ColConv2dBwdFilter(int n, MatrixBlock im2ColOutBlock, MatrixBlock dout_reshaped, MatrixBlock partialRetBlock, ConvolutionParameters params) throws DMLRuntimeException {
        boolean timed = DMLScript.STATISTICS && DISPLAY_STATISTICS;
        long im2colStart = timed ? System.nanoTime() : 0L;
        LibMatrixDNN.doIm2col(n, im2ColOutBlock, params);
        im2ColOutBlock.recomputeNonZeros();
        long im2colEnd = timed ? System.nanoTime() : 0L;
        LibMatrixDNN.doRotate180(n, 0, params.input2, dout_reshaped.denseBlock, params, true);
        dout_reshaped.recomputeNonZeros();
        MatrixBlock product = new MatrixBlock(params.C * params.R * params.S, params.K, false);
        long matMultStart = timed ? System.nanoTime() : 0L;
        LibMatrixMult.matrixMult(im2ColOutBlock, dout_reshaped, product, false);
        if (timed) {
            long matMultEnd = System.nanoTime();
            loopedConvBwdFilterMatMultTime.addAndGet(matMultEnd - matMultStart);
            loopedConvBwdFilterIm2ColTime.addAndGet(im2colEnd - im2colStart);
        }
        // An all-zero product contributes nothing to the accumulator.
        if (!product.isEmptyBlock()) {
            LibMatrixDNN.elementWiseInPlaceAddition(partialRetBlock, product);
        }
        return partialRetBlock;
    }

    /**
     * Splits a flat row-major column index {@code j} over a (channel, H, W) layout into
     * {@code ret[0] = channel}, {@code ret[1] = row}, {@code ret[2] = column}.
     *
     * @param j flat column index
     * @param ret output array of length at least 3
     * @param H plane height
     * @param W plane width
     */
    private static void computeTensorIndexes(int j, int[] ret, int H, int W) throws DMLRuntimeException {
        int planeSize = H * W;
        int channel = j / planeSize;
        int withinPlane = j - channel * planeSize;
        ret[0] = channel;
        ret[1] = withinPlane / W;
        ret[2] = j % W;
    }

    /**
     * Computes the conv2d forward pass using the looped im2col/matmult kernel, then
     * recomputes the output's non-zero count.
     *
     * @param input input images, expected shape N x (C*H*W)
     * @param filter filters, expected shape K x (C*R*S)
     * @param outputBlock block receiving the result
     * @param params convolution parameters; input1/input2/output are set by this method
     * @throws DMLRuntimeException if the operand shapes do not match {@code params} or a stride is not positive
     */
    public static void conv2d(MatrixBlock input, MatrixBlock filter, MatrixBlock outputBlock, ConvolutionParameters params) throws DMLRuntimeException {
        params.input1 = input;
        params.input2 = filter;
        params.output = outputBlock;
        if (input.getNumRows() != params.N || input.getNumColumns() != params.C * params.H * params.W || filter.getNumRows() != params.K || filter.getNumColumns() != params.C * params.R * params.S) {
            // Report actual vs. expected shapes of both operands (previously only the input row count).
            throw new DMLRuntimeException("Incorrect input to conv2d: input[" + input.getNumRows() + " X " + input.getNumColumns()
                + "] (expected [" + params.N + " X " + params.C * params.H * params.W + "]), filter[" + filter.getNumRows() + " X " + filter.getNumColumns()
                + "] (expected [" + params.K + " X " + params.C * params.R * params.S + "])");
        }
        // Same stride validation as conv2dBackwardData/conv2dBackwardFilter.
        if (params.stride_h <= 0 || params.stride_w <= 0) {
            throw new DMLRuntimeException("Only positive strides supported");
        }
        if (DMLScript.STATISTICS && DISPLAY_STATISTICS) {
            if (input.isInSparseFormat() || filter.isInSparseFormat()) {
                conv2dSparseCount.addAndGet(1L);
            } else {
                conv2dDenseCount.addAndGet(1L);
            }
        }
        // Test hooks that force the sparse code paths.
        if (!input.isInSparseFormat() && TEST_SPARSE_INPUT) {
            input.denseToSparse();
        }
        if (!filter.isInSparseFormat() && TEST_SPARSE_FILTER) {
            filter.denseToSparse();
        }
        LibMatrixDNN.runConvTask(TaskType.LoopedIm2ColConv2d, params);
        outputBlock.recomputeNonZeros();
    }

    /**
     * Processes image {@code n} for conv2d forward: im2col, then filter x im2col, with the
     * K x (P*Q) product copied into the dense output at offset {@code n*K*P*Q}.
     * Im2col and matmult times are accumulated when statistics are enabled.
     *
     * @param n image (row) index
     * @param im2ColOutBlock scratch block for the im2col result
     * @param params convolution parameters (input2 = filter, output = dense result)
     */
    private static void doLoopedIm2ColConv2d(int n, MatrixBlock im2ColOutBlock, ConvolutionParameters params) throws DMLRuntimeException {
        boolean timed = DMLScript.STATISTICS && DISPLAY_STATISTICS;
        long im2colStart = timed ? System.nanoTime() : 0L;
        LibMatrixDNN.doIm2col(n, im2ColOutBlock, params);
        im2ColOutBlock.recomputeNonZeros();
        long matMultStart = timed ? System.nanoTime() : 0L;
        MatrixBlock matMultOutBlock = new MatrixBlock(params.K, params.P * params.Q, false);
        LibMatrixMult.matrixMult(params.input2, im2ColOutBlock, matMultOutBlock, false);
        if (timed) {
            long matMultEnd = System.nanoTime();
            loopedConvIm2ColTime.addAndGet(matMultStart - im2colStart);
            loopedConvMatMultTime.addAndGet(matMultEnd - matMultStart);
        }
        // Nothing is written for an all-zero product; the output region is presumably pre-zeroed by the caller.
        if (matMultOutBlock.isEmptyBlock()) {
            return;
        }
        int PQ = params.P * params.Q;
        int imageSize = params.K * PQ;
        int outOffset = n * imageSize;
        double[] out = params.output.denseBlock;
        if (matMultOutBlock.isInSparseFormat()) {
            // Product column j already equals p*Q + q, so row k maps to outOffset + k*PQ + j.
            Iterator<IJV> cells = matMultOutBlock.sparseBlock.getIterator();
            while (cells.hasNext()) {
                IJV cell = cells.next();
                out[outOffset + cell.getI() * PQ + cell.getJ()] = cell.getV();
            }
        } else {
            System.arraycopy(matMultOutBlock.denseBlock, 0, out, outOffset, imageSize);
        }
    }

    /**
     * Computes the max-pooling gradient: each dout value is added to the output cell holding
     * the maximum of its pooling window in {@code input}.
     *
     * @param input forward input, expected shape N x (C*H*W)
     * @param dout output gradient, expected shape N x (C*P*Q)
     * @param outputBlock dense block receiving the result
     * @param params convolution parameters; input1/input2/output and window index arrays are set here
     * @throws DMLRuntimeException on shape mismatch or a sparse output block
     */
    public static void maxpoolingBackward(MatrixBlock input, MatrixBlock dout, MatrixBlock outputBlock, ConvolutionParameters params) throws DMLRuntimeException {
        params.input1 = input;
        params.input2 = dout;
        params.output = outputBlock;
        if (input.getNumColumns() != params.C * params.H * params.W || input.getNumRows() != params.N) {
            // actual rows, actual cols, expected rows, expected cols
            throw new DMLRuntimeException("Incorrect input dimensions in maxpooling_backward:" + input.getNumRows() + " " + input.getNumColumns() + " " + params.N + " " + params.C * params.H * params.W);
        }
        if (dout.getNumColumns() != params.C * params.P * params.Q || dout.getNumRows() != params.N) {
            // Report dout's own shape and the expected C*P*Q (previously printed input's shape and K*P*Q).
            throw new DMLRuntimeException("Incorrect dout dimensions in maxpooling_backward:" + dout.getNumRows() + " " + dout.getNumColumns() + " " + params.N + " " + params.C * params.P * params.Q);
        }
        if (DMLScript.STATISTICS && DISPLAY_STATISTICS) {
            if (input.isInSparseFormat() || dout.isInSparseFormat()) {
                maxPoolBwdSparseCount.addAndGet(1L);
            } else {
                maxPoolBwdDenseCount.addAndGet(1L);
            }
        }
        if (params.output.isInSparseFormat()) {
            throw new DMLRuntimeException("Sparse maxpooling_backward is not supported");
        }
        LibMatrixDNN.fillIndexesArray(params);
        LibMatrixDNN.runConvTask(TaskType.MaxPooling_Backward, params);
        outputBlock.recomputeNonZeros();
    }

    /**
     * Precomputes, for every output row p and column q, the clipped input window bounds
     * [start, end) along H and W, given stride, padding and filter size R x S.
     *
     * @param params convolution parameters; its start/end index arrays are (re)allocated and filled
     */
    private static void fillIndexesArray(ConvolutionParameters params) {
        params.start_indexes_h = new int[params.P];
        params.end_indexes_h = new int[params.P];
        params.start_indexes_w = new int[params.Q];
        params.end_indexes_w = new int[params.Q];
        int strideH = params.stride_h;
        int padH = params.pad_h;
        for (int p = 0; p < params.P; ++p) {
            int windowStart = p * strideH - padH;
            // Clip the window to the valid input range [0, H).
            params.start_indexes_h[p] = windowStart < 0 ? 0 : windowStart;
            params.end_indexes_h[p] = Math.min(windowStart + params.R, params.H);
        }
        int strideW = params.stride_w;
        int padW = params.pad_w;
        for (int q = 0; q < params.Q; ++q) {
            int windowStart = q * strideW - padW;
            // Clip the window to the valid input range [0, W).
            params.start_indexes_w[q] = windowStart < 0 ? 0 : windowStart;
            params.end_indexes_w[q] = Math.min(windowStart + params.S, params.W);
        }
    }

    /**
     * Dispatches max-pooling backward for image {@code n} to the kernel that matches the
     * dense/sparse formats of the input (input1) and dout (input2).
     *
     * @param n image (row) index
     * @param params convolution parameters with a dense output block
     * @throws DMLRuntimeException if the output is sparse
     */
    private static void doPoolingBackward(int n, ConvolutionParameters params) throws DMLRuntimeException {
        double[] inputArray = params.input1.isInSparseFormat() ? null : params.input1.getDenseBlock();
        double[] doutArray = params.input2.isInSparseFormat() ? null : params.input2.getDenseBlock();
        if (params.output.isInSparseFormat()) {
            throw new DMLRuntimeException("Only dense output supported for pooling_backward");
        }
        double[] outputArray = params.output.getDenseBlock();
        // A null array means "no dense array available", which routes to the sparse kernels.
        boolean inputDense = inputArray != null;
        boolean doutDense = doutArray != null;
        if (inputDense && doutDense) {
            LibMatrixDNN.doPoolingBackwardDenseDense(n, inputArray, doutArray, outputArray, params);
        } else if (inputDense) {
            LibMatrixDNN.doPoolingBackwardDenseSparse(n, inputArray, params.input2, outputArray, params);
        } else if (doutDense) {
            LibMatrixDNN.doPoolingBackwardSparseDense(n, doutArray, outputArray, params);
        } else {
            LibMatrixDNN.doPoolingBackwardSparseSparse(n, outputArray, params);
        }
    }

    /**
     * Max-pooling backward for image {@code n} with a sparse input and dense dout: every
     * non-zero dout value is added at the argmax of its pooling window.
     *
     * @param n image (row) index
     * @param doutArray dense dout values, row-major N x (C*P*Q)
     * @param outputArray dense output, row-major N x (C*H*W)
     * @param params convolution parameters (input1 must be sparse)
     * @throws DMLRuntimeException if input1 is not sparse
     */
    private static void doPoolingBackwardSparseDense(int n, double[] doutArray, double[] outputArray, ConvolutionParameters params) throws DMLRuntimeException {
        if (!params.input1.isInSparseFormat()) {
            throw new DMLRuntimeException("Incorrect usage: Call optimized versions");
        }
        int HW = params.H * params.W;
        int PQ = params.P * params.Q;
        for (int c = 0; c < params.C; ++c) {
            int inputOffset = n * params.C * HW + c * HW;
            int doutOffset = n * params.C * PQ + c * PQ;
            for (int p = 0; p < params.P; ++p) {
                for (int q = 0; q < params.Q; ++q) {
                    double grad = doutArray[doutOffset + p * params.Q + q];
                    // Zero gradients cannot change the output, so skip the costly sparse window scan.
                    if (grad == 0.0) {
                        continue;
                    }
                    int maxIndex = LibMatrixDNN.getMaxIndexSparse(p, q, inputOffset, n, c, params.input1, params);
                    if (maxIndex != -1) {
                        outputArray[maxIndex] += grad;
                    }
                }
            }
        }
    }

    /**
     * Max-pooling backward for image {@code n} with both input and dout sparse: every stored
     * dout value is added at the argmax of its pooling window.
     *
     * @param n image (row) index
     * @param outputArray dense output, row-major N x (C*H*W)
     * @param params convolution parameters (input1 must be sparse; input2 is sparse dout)
     * @throws DMLRuntimeException if input1 is not sparse
     */
    private static void doPoolingBackwardSparseSparse(int n, double[] outputArray, ConvolutionParameters params) throws DMLRuntimeException {
        if (!params.input1.isInSparseFormat()) {
            throw new DMLRuntimeException("Incorrect usage: Call optimized versions");
        }
        // A sparse dout without an allocated sparse block has no values to route (previously an NPE).
        if (params.input2.sparseBlock == null) {
            return;
        }
        Iterator<IJV> iter = params.input2.sparseBlock.getIterator(n, n + 1);
        int[] tensorIndexes = new int[3];
        while (iter.hasNext()) {
            IJV ijv = iter.next();
            // dout columns are laid out as (c, p, q) over a P x Q plane.
            LibMatrixDNN.computeTensorIndexes(ijv.getJ(), tensorIndexes, params.P, params.Q);
            int c = tensorIndexes[0];
            int p = tensorIndexes[1];
            int q = tensorIndexes[2];
            int inputOffset = n * params.C * params.H * params.W + c * params.H * params.W;
            int maxIndex = LibMatrixDNN.getMaxIndexSparse(p, q, inputOffset, n, c, params.input1, params);
            if (maxIndex != -1) {
                outputArray[maxIndex] += ijv.getV();
            }
        }
    }

    /**
     * Max-pooling backward for image {@code n} with a dense input and sparse dout: every stored
     * dout value is added at the argmax of its pooling window.
     *
     * @param n image (row) index
     * @param inputArray dense input values, row-major N x (C*H*W)
     * @param dout sparse dout block
     * @param outputArray dense output, row-major N x (C*H*W)
     * @param params convolution parameters with precomputed window bounds
     */
    private static void doPoolingBackwardDenseSparse(int n, double[] inputArray, MatrixBlock dout, double[] outputArray, ConvolutionParameters params) throws DMLRuntimeException {
        // A sparse dout without an allocated sparse block has no values to route (previously an NPE).
        if (dout.sparseBlock == null) {
            return;
        }
        Iterator<IJV> iter = dout.sparseBlock.getIterator(n, n + 1);
        int[] tensorIndexes = new int[3];
        while (iter.hasNext()) {
            IJV ijv = iter.next();
            // dout columns are laid out as (c, p, q) over a P x Q plane.
            LibMatrixDNN.computeTensorIndexes(ijv.getJ(), tensorIndexes, params.P, params.Q);
            int c = tensorIndexes[0];
            int p = tensorIndexes[1];
            int q = tensorIndexes[2];
            int inputOffset = n * params.C * params.H * params.W + c * params.H * params.W;
            int maxIndex = LibMatrixDNN.getMaxIndex(p, q, inputOffset, inputArray, params);
            if (maxIndex != -1) {
                outputArray[maxIndex] += ijv.getV();
            }
        }
    }

    /**
     * Max-pooling backward for image {@code n} with dense input and dense dout: each dout
     * value is added at the argmax of its pooling window.
     *
     * @param n image (row) index
     * @param inputArray dense input values, row-major N x (C*H*W)
     * @param doutArray dense dout values, row-major N x (C*P*Q)
     * @param outputArray dense output, row-major N x (C*H*W)
     * @param params convolution parameters with precomputed window bounds
     */
    private static void doPoolingBackwardDenseDense(int n, double[] inputArray, double[] doutArray, double[] outputArray, ConvolutionParameters params) {
        int HW = params.H * params.W;
        int PQ = params.P * params.Q;
        for (int channel = 0; channel < params.C; ++channel) {
            int inBase = n * params.C * HW + channel * HW;
            int gradBase = n * params.C * PQ + channel * PQ;
            for (int p = 0; p < params.P; ++p) {
                for (int q = 0; q < params.Q; ++q) {
                    int argmax = LibMatrixDNN.getMaxIndex(p, q, inBase, inputArray, params);
                    if (argmax != -1) {
                        outputArray[argmax] += doutArray[gradBase + p * params.Q + q];
                    }
                }
            }
        }
    }

    /**
     * Returns the flat output index ({@code inputOffset + h*W + w}) of the maximum input value in
     * pooling window (p, q) of channel {@code c} of image {@code n}, reading a sparse input.
     * Cells not stored in the sparse row are treated as 0 and are valid max candidates, so the
     * result matches the dense {@code getMaxIndex}. Ties resolve to the first cell in row-major
     * window order. Returns -1 if the window is empty or no value exceeds -Double.MAX_VALUE.
     *
     * @param p output row of the pooling window
     * @param q output column of the pooling window
     * @param inputOffset flat offset of (n, c, 0, 0)
     * @param n image (row) index
     * @param c channel index
     * @param input sparse input block
     * @param params convolution parameters with precomputed window bounds
     * @return flat index of the window maximum, or -1
     * @throws DMLRuntimeException if {@code input} is not sparse
     */
    private static int getMaxIndexSparse(int p, int q, int inputOffset, int n, int c, MatrixBlock input, ConvolutionParameters params) throws DMLRuntimeException {
        if (!input.isInSparseFormat()) {
            throw new DMLRuntimeException("Incorrect usage: Only sparse format supported");
        }
        int start_index_h = params.start_indexes_h[p];
        int end_index_h = params.end_indexes_h[p];
        int start_index_w = params.start_indexes_w[q];
        int end_index_w = params.end_indexes_w[q];
        int windowH = end_index_h - start_index_h;
        int windowW = end_index_w - start_index_w;
        if (windowH <= 0 || windowW <= 0) {
            return -1;
        }
        // Materialize the window starting from zeros: unstored cells are real zeros and must
        // compete for the maximum (ignoring them dropped gradients for all-zero windows and
        // picked negative values over zeros).
        double[] window = new double[windowH * windowW];
        if (input.sparseBlock != null) {
            Iterator<IJV> iter = input.sparseBlock.getIterator(n, n + 1);
            int[] tensorIndexes = new int[3];
            while (iter.hasNext()) {
                IJV ijv = iter.next();
                LibMatrixDNN.computeTensorIndexes(ijv.getJ(), tensorIndexes, params.H, params.W);
                int h = tensorIndexes[1];
                int w = tensorIndexes[2];
                if (c != tensorIndexes[0] || h < start_index_h || h >= end_index_h || w < start_index_w || w >= end_index_w) {
                    continue;
                }
                window[(h - start_index_h) * windowW + (w - start_index_w)] = ijv.getV();
            }
        }
        // Row-major scan with a strict comparison, mirroring getMaxIndex's tie-breaking.
        int maxIndex = -1;
        double maxVal = -Double.MAX_VALUE;
        int cell = 0;
        for (int h = start_index_h; h < end_index_h; ++h) {
            for (int w = start_index_w; w < end_index_w; ++w, ++cell) {
                double val = window[cell];
                if (maxVal < val) {
                    maxIndex = inputOffset + h * params.W + w;
                    maxVal = val;
                }
            }
        }
        return maxIndex;
    }

    /**
     * Returns the flat index ({@code inputOffset + h*W + w}) of the maximum value in pooling
     * window (p, q) of a dense input. Ties resolve to the first cell in row-major order.
     * Returns -1 if the window is empty or no value exceeds -Double.MAX_VALUE.
     *
     * @param p output row of the pooling window
     * @param q output column of the pooling window
     * @param inputOffset flat offset of the (image, channel) plane
     * @param inputArray dense input values
     * @param params convolution parameters with precomputed window bounds
     * @return flat index of the window maximum, or -1
     */
    private static int getMaxIndex(int p, int q, int inputOffset, double[] inputArray, ConvolutionParameters params) {
        int hFrom = params.start_indexes_h[p];
        int hTo = params.end_indexes_h[p];
        int wFrom = params.start_indexes_w[q];
        int wTo = params.end_indexes_w[q];
        int best = -1;
        double bestVal = -Double.MAX_VALUE;
        for (int h = hFrom; h < hTo; ++h) {
            int rowBase = inputOffset + h * params.W;
            for (int w = wFrom; w < wTo; ++w) {
                double candidate = inputArray[rowBase + w];
                if (bestVal < candidate) {
                    best = rowBase + w;
                    bestVal = candidate;
                }
            }
        }
        return best;
    }

    /**
     * Computes relu_backward ({@code input > 0 ? dout : 0}) into {@code outputBlock}.
     *
     * @param input forward input of relu
     * @param dout output gradient; must have the same shape as {@code input}
     * @param outputBlock block receiving the result
     * @param numThreads degree of parallelism passed to the task runner
     * @throws DMLRuntimeException if the shapes of {@code input} and {@code dout} differ
     */
    public static void reluBackward(MatrixBlock input, MatrixBlock dout, MatrixBlock outputBlock, int numThreads) throws DMLRuntimeException {
        // Only N (the row count) and the thread count are meaningful for this op.
        ConvolutionParameters params = new ConvolutionParameters(input.getNumRows(), -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, numThreads);
        params.input1 = input;
        params.input2 = dout;
        params.output = outputBlock;
        boolean sameShape = input.getNumRows() == dout.getNumRows() && input.getNumColumns() == dout.getNumColumns();
        if (!sameShape) {
            throw new DMLRuntimeException("Incorrect dimensions for relu_backward:" + input.getNumRows() + " != " + dout.getNumRows() + " || " + input.getNumColumns() + " != " + dout.getNumColumns());
        }
        LibMatrixDNN.runConvTask(TaskType.ReluBackward, params);
    }

    /**
     * Computes relu_backward for rows [rl, ru): {@code output = (input > 0) ? dout : 0},
     * written into the dense output array. All format combinations produce the same result.
     *
     * @param params parameters holding input1 (relu input), input2 (dout) and the dense output
     * @param rl first row (inclusive)
     * @param ru last row (exclusive)
     * @return number of non-zeros in output rows [rl, ru)
     * @throws DMLRuntimeException propagated from the non-zero recomputation
     */
    private static long doReluBackward(ConvolutionParameters params, int rl, int ru) throws DMLRuntimeException {
        MatrixBlock input = params.input1;
        MatrixBlock dout = params.input2;
        double[] outputArray = params.output.getDenseBlock();
        int numOutCols = input.getNumColumns();
        int from = rl * numOutCols;
        int to = ru * numOutCols;
        if (!input.isInSparseFormat() && !dout.isInSparseFormat()) {
            double[] inputArr = input.getDenseBlock();
            double[] doutArr = dout.getDenseBlock();
            for (int i = from; i < to; ++i) {
                outputArray[i] = inputArr[i] > 0.0 ? doutArr[i] : 0.0;
            }
        } else if (dout.isInSparseFormat()) {
            // Only stored dout cells can yield a non-zero result. Clearing first matters: a cell with
            // input > 0 whose dout is an unstored zero must end up 0, not keep a relu mask value.
            Arrays.fill(outputArray, from, to, 0.0);
            if (dout.sparseBlock != null) {
                double[] inputArr = input.isInSparseFormat() ? null : input.getDenseBlock();
                Iterator<IJV> iter = dout.sparseBlock.getIterator(rl, ru);
                while (iter.hasNext()) {
                    IJV ijv = iter.next();
                    int i = ijv.getI();
                    int j = ijv.getJ();
                    int index = i * numOutCols + j;
                    double inVal = (inputArr != null) ? inputArr[index] : input.getValue(i, j);
                    if (inVal > 0.0) {
                        outputArray[index] = ijv.getV();
                    }
                }
            }
        } else {
            // Sparse input, dense dout: only cells with a stored positive input pass the gradient through.
            Arrays.fill(outputArray, from, to, 0.0);
            double[] doutArr = dout.getDenseBlock();
            if (input.sparseBlock != null && doutArr != null) {
                Iterator<IJV> iter = input.sparseBlock.getIterator(rl, ru);
                while (iter.hasNext()) {
                    IJV ijv = iter.next();
                    if (ijv.getV() > 0.0) {
                        int index = ijv.getI() * numOutCols + ijv.getJ();
                        outputArray[index] = doutArr[index];
                    }
                }
            }
        }
        return params.output.recomputeNonZeros(rl, ru - 1, 0, numOutCols - 1);
    }

    /**
     * Adds a per-channel bias to {@code input}: input is N x (K*PQ), bias is K x 1, and
     * bias[k] is added to every one of the PQ entries of channel k.
     *
     * @param input input matrix of shape N x (K*PQ)
     * @param bias column vector of K bias values
     * @param outputBlock dense block receiving the result
     * @param numThreads degree of parallelism passed to the task runner
     * @throws DMLRuntimeException if bias is not a column vector or does not divide the input width
     */
    public static void biasAdd(MatrixBlock input, MatrixBlock bias, MatrixBlock outputBlock, int numThreads) throws DMLRuntimeException {
        int numImages = input.getNumRows();
        int numChannels = bias.getNumRows();
        int channelSize = input.getNumColumns() / numChannels;
        // The PQ slot of ConvolutionParameters carries the per-channel size here.
        ConvolutionParameters params = new ConvolutionParameters(numImages, channelSize, -1, -1, numChannels, -1, -1, -1, -1, -1, -1, numThreads);
        params.input1 = input;
        params.input2 = bias;
        params.output = outputBlock;
        // Test hooks that force the sparse code paths.
        if (TEST_SPARSE_INPUT && !input.isInSparseFormat()) {
            input.denseToSparse();
        }
        if (TEST_SPARSE_FILTER && !bias.isInSparseFormat()) {
            bias.denseToSparse();
        }
        boolean validShapes = bias.getNumColumns() == 1 && input.getNumColumns() % numChannels == 0;
        if (!validShapes) {
            throw new DMLRuntimeException("Incorrect inputs for bias_add: input[" + numImages + " X " + input.getNumColumns() + "] and bias[" + numChannels + " X " + bias.getNumColumns() + "]");
        }
        if (!input.isEmptyBlock()) {
            LibMatrixDNN.runConvTask(TaskType.BiasAdd, params);
        } else {
            // An all-zero input makes the result just the broadcast bias.
            double[] outputArray = outputBlock.getDenseBlock();
            for (int img = 0; img < numImages; ++img) {
                LibMatrixDNN.fillBias(bias, outputArray, img, img + 1, numImages, numChannels, channelSize);
            }
        }
        params.output.recomputeNonZeros();
    }

    /**
     * Multiplies {@code input} by a per-channel bias: input is N x (K*PQ), bias is K x 1, and
     * every one of the PQ entries of channel k is scaled by bias[k].
     *
     * @param input input matrix of shape N x (K*PQ)
     * @param bias column vector of K scale factors
     * @param outputBlock dense block receiving the result
     * @param numThreads degree of parallelism passed to the task runner
     * @throws DMLRuntimeException if bias is not a column vector or does not divide the input width
     */
    public static void biasMultiply(MatrixBlock input, MatrixBlock bias, MatrixBlock outputBlock, int numThreads) throws DMLRuntimeException {
        int numImages = input.getNumRows();
        int numChannels = bias.getNumRows();
        int channelSize = input.getNumColumns() / numChannels;
        // The PQ slot of ConvolutionParameters carries the per-channel size here.
        ConvolutionParameters params = new ConvolutionParameters(numImages, channelSize, -1, -1, numChannels, -1, -1, -1, -1, -1, -1, numThreads);
        params.input1 = input;
        params.input2 = bias;
        params.output = outputBlock;
        // Test hooks that force the sparse code paths.
        if (TEST_SPARSE_INPUT && !input.isInSparseFormat()) {
            input.denseToSparse();
        }
        if (TEST_SPARSE_FILTER && !bias.isInSparseFormat()) {
            bias.denseToSparse();
        }
        boolean validShapes = bias.getNumColumns() == 1 && input.getNumColumns() % numChannels == 0;
        if (!validShapes) {
            throw new DMLRuntimeException("Incorrect inputs for bias_multiply: input[" + numImages + " X " + input.getNumColumns() + "] and bias[" + numChannels + " X " + bias.getNumColumns() + "]");
        }
        boolean hasWork = !input.isEmptyBlock() && !bias.isEmptyBlock();
        if (!hasWork) {
            // A zero operand makes the whole product zero.
            params.output.setNonZeros(0L);
            return;
        }
        LibMatrixDNN.runConvTask(TaskType.BiasMultiply, params);
        params.output.recomputeNonZeros();
    }

    /**
     * Computes rows [rl, ru) of bias_multiply: {@code out[n, k*PQ + pq] = input[n, k*PQ + pq] * bias[k]},
     * written into the dense output array.
     *
     * @param params parameters holding input1 (N x K*PQ), input2 (bias, K x 1), the dense
     *               output, K and PQ (stored in the C slot)
     * @param rl first row (inclusive)
     * @param ru last row (exclusive)
     */
    private static void doBiasMultiply(ConvolutionParameters params, int rl, int ru) throws DMLRuntimeException {
        double[] outputArray = params.output.getDenseBlock();
        int PQ = params.C;
        int K = params.K;
        int numOutCols = params.input1.getNumColumns();
        if (!params.input1.isInSparseFormat() && !params.input2.isInSparseFormat()) {
            double[] inputArr = params.input1.getDenseBlock();
            double[] biasArr = params.input2.getDenseBlock();
            int index = rl * K * PQ;
            for (int n = rl; n < ru; ++n) {
                for (int k = 0; k < K; ++k) {
                    for (int pq = 0; pq < PQ; ++pq, ++index) {
                        outputArray[index] = inputArr[index] * biasArr[k];
                    }
                }
            }
            return;
        }
        // Restrict every write to this task's rows [rl, ru). The (rl, ru) split suggests other
        // row ranges are handled by other tasks on the same output array.
        int from = rl * numOutCols;
        int to = ru * numOutCols;
        if (params.input1.isInSparseFormat()) {
            Arrays.fill(outputArray, from, to, 0.0);
            if (params.input1.sparseBlock != null) {
                Iterator<IJV> iter = params.input1.sparseBlock.getIterator(rl, ru);
                while (iter.hasNext()) {
                    IJV ijv = iter.next();
                    outputArray[ijv.getI() * numOutCols + ijv.getJ()] = ijv.getV();
                }
            }
        } else {
            System.arraycopy(params.input1.getDenseBlock(), from, outputArray, from, to - from);
        }
        // Bias is validated to be K x 1, so its only column index is 0.
        double[] biasVals = new double[K];
        for (int k = 0; k < K; ++k) {
            biasVals[k] = params.input2.getValue(k, 0);
        }
        // Walk the output in its n*K*PQ + k*PQ + pq layout, same as the dense branch above.
        int index = from;
        for (int n = rl; n < ru; ++n) {
            for (int k = 0; k < K; ++k) {
                double val = biasVals[k];
                for (int pq = 0; pq < PQ; ++pq, ++index) {
                    outputArray[index] *= val;
                }
            }
        }
    }

    /**
     * Computes rows [rl, ru) of bias_add: {@code out[n, k*PQ + pq] = input[n, k*PQ + pq] + bias[k]},
     * written into the dense output array.
     *
     * @param params parameters holding input1 (N x K*PQ), input2 (bias, K x 1), the dense
     *               output, N, K and PQ (stored in the C slot)
     * @param rl first row (inclusive)
     * @param ru last row (exclusive)
     */
    private static void doBiasAdd(ConvolutionParameters params, int rl, int ru) throws DMLRuntimeException {
        double[] outputArray = params.output.getDenseBlock();
        int PQ = params.C;
        int numOutCols = params.input1.getNumColumns();
        boolean allDense = !params.input1.isInSparseFormat() && !params.input2.isInSparseFormat();
        if (allDense) {
            double[] inputArr = params.input1.getDenseBlock();
            double[] biasArr = params.input2.getDenseBlock();
            int K = params.K;
            int cursor = rl * K * PQ;
            for (int n = rl; n < ru; ++n) {
                for (int k = 0; k < K; ++k) {
                    for (int pq = 0; pq < PQ; ++pq) {
                        outputArray[cursor] = inputArr[cursor] + biasArr[k];
                        ++cursor;
                    }
                }
            }
            return;
        }
        // Broadcast the bias first, then accumulate the input on top of it.
        LibMatrixDNN.fillBias(params.input2, outputArray, rl, ru, params.N, params.K, PQ);
        if (!params.input1.isInSparseFormat()) {
            double[] inputArr = params.input1.getDenseBlock();
            int end = ru * numOutCols;
            for (int idx = rl * numOutCols; idx < end; ++idx) {
                outputArray[idx] += inputArr[idx];
            }
        } else {
            Iterator<IJV> cells = params.input1.sparseBlock.getIterator(rl, ru);
            while (cells.hasNext()) {
                IJV cell = cells.next();
                outputArray[cell.getI() * numOutCols + cell.getJ()] += cell.getV();
            }
        }
    }

    /**
     * Overwrites output rows {@code [n1, n2)} so that every cell of channel {@code k} holds
     * {@code bias[k]}. Row {@code n} occupies {@code outputArray[n*K*PQ, (n+1)*K*PQ)} and channel
     * {@code k} inside it spans {@code PQ} consecutive cells.
     *
     * <p>Every cell in the row range is written in all cases. A sparse bias only enumerates its
     * non-zero channels, so the range is zeroed first. An empty bias (no sparse block or no dense
     * array) yields all zeros instead of an NPE.
     *
     * @param bias        K x 1 bias block
     * @param outputArray dense destination array
     * @param n1          first image row (inclusive)
     * @param n2          last image row (exclusive)
     * @param N           total number of images (not used here; kept for signature compatibility)
     * @param K           number of channels
     * @param PQ          cells per channel
     */
    private static void fillBias(MatrixBlock bias, double[] outputArray, int n1, int n2, int N, int K, int PQ) {
        int rowLen = K * PQ;
        if (bias.isInSparseFormat()) {
            // Channels without a stored non-zero would otherwise keep stale output values.
            Arrays.fill(outputArray, n1 * rowLen, n2 * rowLen, 0.0);
            if (bias.sparseBlock == null) {
                return;
            }
            Iterator<IJV> iter = bias.sparseBlock.getIterator();
            while (iter.hasNext()) {
                IJV ijv = iter.next();
                int k = ijv.getI();
                double val = ijv.getV();
                for (int n = n1; n < n2; ++n) {
                    int fromIndex = n * rowLen + k * PQ;
                    Arrays.fill(outputArray, fromIndex, fromIndex + PQ, val);
                }
            }
        } else {
            double[] biasArr = bias.getDenseBlock();
            if (biasArr == null) {
                // Unallocated dense bias: all entries are zero.
                Arrays.fill(outputArray, n1 * rowLen, n2 * rowLen, 0.0);
                return;
            }
            for (int n = n1; n < n2; ++n) {
                for (int k = 0; k < K; ++k) {
                    int fromIndex = n * rowLen + k * PQ;
                    Arrays.fill(outputArray, fromIndex, fromIndex + PQ, biasArr[k]);
                }
            }
        }
    }

    /**
     * Forward max pooling of an N x (C*H*W) input into {@code outputBlock}.
     *
     * <p>The operands are attached to {@code params}, the input shape is validated, and the
     * pooling work is dispatched as {@code TaskType.MaxPooling_Forward}. The output non-zero
     * count is recomputed afterwards.
     *
     * @throws DMLRuntimeException if the input is not N x (C*H*W), or if pooling fails
     */
    public static void maxpooling(MatrixBlock input, MatrixBlock outputBlock, ConvolutionParameters params) throws DMLRuntimeException {
        params.input1 = input;
        params.output = outputBlock;
        final int expectedCols = params.C * params.H * params.W;
        final boolean shapeOk = input.getNumColumns() == expectedCols && input.getNumRows() == params.N;
        if (!shapeOk) {
            throw new DMLRuntimeException("Incorrect input dimensions in maxpooling:" + input.getNumRows() + " " + input.getNumColumns() + " " + params.N + " " + expectedCols);
        }
        // Presumably populates params.start_/end_indexes_{h,w}, which doPooling reads.
        LibMatrixDNN.fillIndexesArray(params);
        LibMatrixDNN.runConvTask(TaskType.MaxPooling_Forward, params);
        outputBlock.recomputeNonZeros();
    }

    /**
     * Max-pools image row {@code n} of {@code params.input1} into the dense output.
     *
     * <p>Output cell (c, p, q) lives at {@code n*C*P*Q + (c*P + p)*Q + q}. It is folded with
     * {@code Math.max} against every input cell in the window
     * {@code [start_indexes_h[p], end_indexes_h[p]) x [start_indexes_w[q], end_indexes_w[q])}.
     * The output's existing value acts as the initial maximum.
     *
     * @throws DMLRuntimeException if the output block is in sparse format
     */
    private static void doPooling(int n, ConvolutionParameters params) throws DMLRuntimeException {
        final MatrixBlock in = params.input1;
        final double[] denseIn = in.isInSparseFormat() ? null : in.getDenseBlock();
        if (params.output.isInSparseFormat()) {
            throw new DMLRuntimeException("Expected the output to be allocated in dense format");
        }
        final double[] out = params.output.getDenseBlock();
        final int W = params.W;
        final int HW = params.H * W;
        final int imageBase = n * params.C * HW;
        int dst = n * params.C * params.P * params.Q;
        for (int c = 0; c < params.C; ++c) {
            final int channelCol = c * HW;
            for (int p = 0; p < params.P; ++p) {
                final int hLo = params.start_indexes_h[p];
                final int hHi = params.end_indexes_h[p];
                for (int q = 0; q < params.Q; ++q, ++dst) {
                    final int wLo = params.start_indexes_w[q];
                    final int wHi = params.end_indexes_w[q];
                    double best = out[dst];
                    if (denseIn != null) {
                        for (int h = hLo; h < hHi; ++h) {
                            final int rowBase = imageBase + channelCol + h * W;
                            for (int w = wLo; w < wHi; ++w) {
                                best = Math.max(best, denseIn[rowBase + w]);
                            }
                        }
                    } else {
                        // Sparse (or unallocated) input: fall back to per-cell lookups.
                        for (int h = hLo; h < hHi; ++h) {
                            final int colBase = channelCol + h * W;
                            for (int w = wLo; w < wHi; ++w) {
                                best = Math.max(best, in.quickGetValue(n, colBase + w));
                            }
                        }
                    }
                    out[dst] = best;
                }
            }
        }
    }

    /**
     * Copies row {@code inputN} of {@code input}, read as a (K, P, Q) tensor, into slice
     * {@code outputN} of {@code outputArray} in (P, Q, K) order. Element (k, p, q) is written to
     * {@code outputN*K*P*Q + p*Q*K + q*K + k}.
     *
     * @param inputN              source row of {@code input}
     * @param outputN             destination slice index in {@code outputArray}
     * @param input               source block (dense or sparse)
     * @param outputArray         dense destination; must not be null
     * @param params              supplies K, P and Q
     * @param zeroOutSparseOutput if the input has no dense array, clear the destination slice
     *                            before scattering the non-zeros
     * @throws DMLRuntimeException if {@code outputArray} is null
     */
    private static void doRotate180(int inputN, int outputN, MatrixBlock input, double[] outputArray, ConvolutionParameters params, boolean zeroOutSparseOutput) throws DMLRuntimeException {
        double[] inputArray = input.isInSparseFormat() ? null : input.getDenseBlock();
        if (outputArray == null) {
            throw new DMLRuntimeException("Sparse output is not supported for rotate180");
        }
        final int K = params.K;
        final int P = params.P;
        final int Q = params.Q;
        final int PQ = P * Q;
        final int sliceLen = K * PQ;
        final int outputOffset = outputN * sliceLen;
        if (inputArray != null) {
            final int inputOffset = inputN * sliceLen;
            for (int k = 0; k < K; ++k) {
                for (int p = 0; p < P; ++p) {
                    for (int q = 0; q < Q; ++q) {
                        outputArray[outputOffset + p * Q * K + q * K + k] = inputArray[inputOffset + k * PQ + p * Q + q];
                    }
                }
            }
            return;
        }
        if (zeroOutSparseOutput) {
            // Clear only this image's slice so other slices of a multi-image array survive.
            Arrays.fill(outputArray, outputOffset, Math.min(outputArray.length, outputOffset + sliceLen), 0.0);
        }
        if (input.isEmptyBlock() || input.sparseBlock == null) {
            return;
        }
        Iterator<IJV> iter = input.sparseBlock.getIterator(inputN, inputN + 1);
        int[] tensorIndexes = new int[3];
        while (iter.hasNext()) {
            IJV ijv = iter.next();
            LibMatrixDNN.computeTensorIndexes(ijv.getJ(), tensorIndexes, P, Q);
            int k = tensorIndexes[0];
            int p = tensorIndexes[1];
            int q = tensorIndexes[2];
            outputArray[outputOffset + p * Q * K + q * K + k] = ijv.getV();
        }
    }

    /**
     * Pre-allocates {@code poolSize} dense scratch blocks per queue that {@code type} needs.
     * <ul>
     *   <li>im2col blocks, (C*R*S) x (P*Q): forward conv and backward-filter tasks</li>
     *   <li>partial-result blocks, (C*R*S) x K: backward-filter tasks</li>
     *   <li>reshaped-dout blocks, (P*Q) x K: backward-data and backward-filter tasks</li>
     * </ul>
     */
    private static void addMatrixBlocks(int poolSize, TaskType type, ConvolutionParameters params, ConcurrentLinkedQueue<MatrixBlock> im2ColOutBlocks, ConcurrentLinkedQueue<MatrixBlock> doutReshapedBlocks, ConcurrentLinkedQueue<MatrixBlock> partialRetBlocks) {
        final boolean bwdFilter = type == TaskType.LoopedIm2ColConv2dBwdFilter;
        final boolean wantIm2Col = bwdFilter || type == TaskType.LoopedIm2ColConv2d;
        final boolean wantDout = bwdFilter || type == TaskType.LoopedIm2ColConv2dBwdData;
        final int CRS = params.C * params.R * params.S;
        final int PQ = params.P * params.Q;
        int remaining = poolSize;
        while (remaining-- > 0) {
            if (wantIm2Col) {
                MatrixBlock block = new MatrixBlock(CRS, PQ, false);
                block.allocateDenseBlock();
                im2ColOutBlocks.add(block);
            }
            if (bwdFilter) {
                MatrixBlock block = new MatrixBlock(CRS, params.K, false);
                block.allocateDenseBlock();
                partialRetBlocks.add(block);
            }
            if (wantDout) {
                MatrixBlock block = new MatrixBlock(PQ, params.K, false);
                block.allocateDenseBlock();
                doutReshapedBlocks.add(block);
            }
        }
    }

    /**
     * Runs a convolution-related task over all {@code params.N} images. The summed non-zero
     * count reported by the tasks is stored on {@code params.output}.
     *
     * <p>If the output is thread-safe and more than one worker is useful, the image range is
     * split into about {@code NUM_TASK_FACTOR} tasks per worker. These run on a fixed thread
     * pool, which is always shut down. Otherwise a single task handles every image on the calling
     * thread. For {@code LoopedIm2ColConv2dBwdFilter}, the per-worker partial results are
     * finally folded into the output.
     *
     * @param type   the kind of work to perform
     * @param params parameters including input/output blocks and the requested thread count
     * @throws DMLRuntimeException if any task fails; the original cause is attached
     */
    private static void runConvTask(TaskType type, ConvolutionParameters params) throws DMLRuntimeException {
        int k = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
        ConcurrentLinkedQueue<MatrixBlock> im2ColOutBlocks = new ConcurrentLinkedQueue<MatrixBlock>();
        ConcurrentLinkedQueue<MatrixBlock> doutReshapedBlocks = new ConcurrentLinkedQueue<MatrixBlock>();
        ConcurrentLinkedQueue<MatrixBlock> partialRetBlocks = new ConcurrentLinkedQueue<MatrixBlock>();
        // Never use more workers than images; a pool of 0 is illegal and 1 gains nothing.
        int poolSize = Math.min(k, params.N);
        if (params.isOutputThreadSafe() && poolSize > 1) {
            LibMatrixDNN.addMatrixBlocks(poolSize, type, params, im2ColOutBlocks, doutReshapedBlocks, partialRetBlocks);
            ArrayList<ConvTask> tasks = new ArrayList<ConvTask>();
            int blklen = (int)Math.ceil((double)params.N / (double)poolSize / (double)NUM_TASK_FACTOR);
            for (int i = 0; i < poolSize * NUM_TASK_FACTOR && i * blklen < params.N; ++i) {
                tasks.add(new ConvTask(i * blklen, Math.min((i + 1) * blklen, params.N), type, params, im2ColOutBlocks, doutReshapedBlocks, partialRetBlocks));
            }
            ExecutorService pool = null;
            try {
                pool = Executors.newFixedThreadPool(poolSize);
                List<Future<Long>> taskret = pool.invokeAll(tasks);
                long nnz = 0L;
                for (Future<Long> task : taskret) {
                    nnz += task.get().longValue();
                }
                // Overwrite rather than accumulate, matching the single-threaded path below.
                params.output.setNonZeros(nnz);
                if (type == TaskType.LoopedIm2ColConv2dBwdFilter) {
                    LibMatrixDNN.elementWiseInPlaceTransposedAddition(params.output, partialRetBlocks.toArray(new MatrixBlock[0]));
                }
            }
            catch (Exception e) {
                if (e instanceof InterruptedException) {
                    Thread.currentThread().interrupt();
                }
                throw new DMLRuntimeException("Error while executing multi-threaded " + type.name(), e);
            }
            finally {
                if (pool != null) {
                    pool.shutdown();
                }
            }
            return;
        }
        LibMatrixDNN.addMatrixBlocks(1, type, params, im2ColOutBlocks, doutReshapedBlocks, partialRetBlocks);
        try {
            params.output.setNonZeros(new ConvTask(0, params.N, type, params, im2ColOutBlocks, doutReshapedBlocks, partialRetBlocks).call());
            if (type == TaskType.LoopedIm2ColConv2dBwdFilter) {
                LibMatrixDNN.elementWiseInPlaceTransposedAddition(params.output, partialRetBlocks.toArray(new MatrixBlock[0]));
            }
        }
        catch (Exception e) {
            throw new DMLRuntimeException("Error while executing single-threaded " + type.name(), e);
        }
    }

    /**
     * Adds {@code params.bias[k]} in place to all {@code P*Q} output cells of channel {@code k}
     * for image rows {@code [rl, ru)}. Row {@code n}, channel {@code k} starts at
     * {@code n*K*P*Q + k*P*Q} in the dense output.
     *
     * <p>A dense-format bias without an allocated array is all zeros, so nothing is added.
     * Previously that case threw an NPE.
     */
    private static void addBias(ConvolutionParameters params, int rl, int ru) {
        final int PQ = params.P * params.Q;
        final int K = params.K;
        double[] outputArr = params.output.getDenseBlock();
        if (!params.bias.isInSparseFormat()) {
            double[] biasArr = params.bias.getDenseBlock();
            if (biasArr == null) {
                return;
            }
            int index = rl * K * PQ;
            for (int n = rl; n < ru; ++n) {
                for (int k = 0; k < K; ++k) {
                    double val = biasArr[k];
                    for (int pq = 0; pq < PQ; ++pq) {
                        outputArr[index++] += val;
                    }
                }
            }
        } else {
            Iterator<IJV> iter = params.bias.getSparseBlockIterator();
            while (iter.hasNext()) {
                IJV ijv = iter.next();
                int k = ijv.getI();
                double val = ijv.getV();
                for (int n = rl; n < ru; ++n) {
                    int index = n * K * PQ + k * PQ;
                    for (int pq = 0; pq < PQ; ++pq) {
                        outputArr[index++] += val;
                    }
                }
            }
        }
    }

    /**
     * Accumulates a (P*Q) x (C*R*S) column matrix {@code input} into image {@code outputN} of
     * the dense {@code params.output} (col2im). Dense and sparse inputs are delegated to the
     * dedicated helpers; an empty sparse input contributes nothing.
     *
     * @throws DMLRuntimeException on a shape mismatch or a sparse-format output
     */
    private static void doCol2imOverSingleImage(int outputN, MatrixBlock input, ConvolutionParameters params) throws DMLRuntimeException {
        final boolean rowsOk = input.rlen == params.P * params.Q;
        final boolean colsOk = input.clen == params.C * params.R * params.S;
        if (!(rowsOk && colsOk)) {
            throw new DMLRuntimeException("Incorrect input dimensions");
        }
        if (params.output.isInSparseFormat()) {
            throw new DMLRuntimeException("Only dense output is implemented");
        }
        final double[] outputArray = params.output.getDenseBlock();
        if (input.isInSparseFormat()) {
            if (!input.isEmptyBlock()) {
                LibMatrixDNN.doCol2IMSparseInput(0, outputN, input.getSparseBlockIterator(), outputArray, params);
            }
        } else {
            LibMatrixDNN.doCol2IMDenseInput(0, outputN, input.getDenseBlock(), outputArray, params);
        }
    }

    /**
     * Sparse col2im. Each non-zero at (row = n*P*Q + p*Q + q, col = c*R*S + r*S + s) is added to
     * image {@code outputN} at (c, p*stride_h + r - pad_h, q*stride_w + s - pad_w). Contributions
     * that land in the padding are dropped.
     *
     * @throws DMLRuntimeException if a non-zero belongs to an image other than {@code inputN}
     */
    private static void doCol2IMSparseInput(int inputN, int outputN, Iterator<IJV> inputIter, double[] outputArray, ConvolutionParameters params) throws DMLRuntimeException {
        final int HW = params.H * params.W;
        final int imageBase = outputN * params.C * HW;
        int[] tensorIndexes = new int[3];
        while (inputIter.hasNext()) {
            IJV ijv = inputIter.next();
            LibMatrixDNN.computeTensorIndexes(ijv.getJ(), tensorIndexes, params.R, params.S);
            int c = tensorIndexes[0];
            int r = tensorIndexes[1];
            int s = tensorIndexes[2];
            LibMatrixDNN.computeTensorIndexes(ijv.getI(), tensorIndexes, params.P, params.Q);
            int p = tensorIndexes[1];
            int q = tensorIndexes[2];
            if (inputN != tensorIndexes[0]) {
                // Separators between the row index and P were previously missing, fusing the numbers.
                throw new DMLRuntimeException("Incorrect tensor indexes: " + inputN + " != " + tensorIndexes[0] + " <" + p + " " + q + " " + ijv.getI() + " " + params.P + " " + params.Q + ">");
            }
            int h = p * params.stride_h + r - params.pad_h;
            int w = q * params.stride_w + s - params.pad_w;
            if (h < 0 || h >= params.H || w < 0 || w >= params.W) {
                continue;
            }
            outputArray[imageBase + c * HW + h * params.W + w] += ijv.getV();
        }
    }

    /**
     * Dense col2im. For every output position (p, q) and channel c, adds the R x S patch stored
     * at row {@code inputN*P*Q + p*Q + q} of {@code inputArray} into image {@code outputN}.
     * Patch cell (r, s) is added at (p*stride_h - pad_h + r, q*stride_w - pad_w + s). The r/s
     * ranges are clipped up front so that no padding cell is touched.
     */
    private static void doCol2IMDenseInput(int inputN, int outputN, double[] inputArray, double[] outputArray, ConvolutionParameters params) throws DMLRuntimeException {
        final int C = params.C;
        final int H = params.H;
        final int W = params.W;
        final int R = params.R;
        final int S = params.S;
        final int P = params.P;
        final int Q = params.Q;
        final int HW = H * W;
        final int RS = R * S;
        final int CRS = C * RS;
        final int imageBase = outputN * C * HW;
        for (int p = 0; p < P; ++p) {
            final int hOff = p * params.stride_h - params.pad_h;
            final int rLo = Math.max(0, -hOff);
            final int rHi = Math.min(R, H - hOff);
            for (int q = 0; q < Q; ++q) {
                final int wOff = q * params.stride_w - params.pad_w;
                final int sLo = Math.max(0, -wOff);
                final int sHi = Math.min(S, W - wOff);
                final int patchBase = (inputN * P * Q + p * Q + q) * CRS;
                for (int c = 0; c < C; ++c) {
                    final int dstChannel = imageBase + c * HW;
                    final int srcChannel = patchBase + c * RS;
                    for (int r = rLo; r < rHi; ++r) {
                        final int dstRow = dstChannel + (hOff + r) * W + wOff;
                        final int srcRow = srcChannel + r * S;
                        for (int s = sLo; s < sHi; ++s) {
                            outputArray[dstRow + s] += inputArray[srcRow + s];
                        }
                    }
                }
            }
        }
    }

    /**
     * Dense im2col for image row {@code n}. Reads the C x H x W image starting at
     * {@code n*C*H*W} in {@code inputArray}. Writes element (c, h*Q + w) of a row-major
     * (C*R*S) x (P*Q) matrix at index {@code (c*P + h)*Q + w} of {@code outputArray}. Here
     * {@code c} enumerates (channel, r, s) triples. Cells that fall into the padding are
     * written as 0.0.
     */
    private static void doIm2colDense(int n, double[] inputArray, double[] outputArray, ConvolutionParameters params) {
        int CRS = params.C * params.R * params.S;
        int nOffset = n * params.C * params.H * params.W;
        // Fast path: unit stride and no padding, so each output row is a contiguous input run.
        if (params.stride_h == 1 && params.stride_w == 1 && params.pad_h == 0 && params.pad_w == 0) {
            for (int c = 0; c < CRS; ++c) {
                // Decompose c into (cInput, hOffset, wOffset) = (channel, r, s).
                int wOffset = c % params.S;
                int hOffset = c / params.S % params.R;
                int cInput = c / params.R / params.S;
                for (int h = 0; h < params.P; ++h) {
                    int hPadded = h + hOffset;
                    int outOffset = (c * params.P + h) * params.Q;
                    int inputOffset = nOffset + (cInput * params.H + hPadded) * params.W;
                    // Bulk-copy Q values, then recompute the last one with an explicit bounds
                    // check. The statement order matters: the overwrite must follow the copy.
                    // NOTE(review): if P = H-R+1 and Q = W-S+1 this check always holds; confirm
                    // why the last element is special-cased.
                    System.arraycopy(inputArray, inputOffset + wOffset, outputArray, outOffset, params.Q);
                    int w = params.Q - 1;
                    int wPadded = w + wOffset;
                    outputArray[outOffset + w] = hPadded < params.H && wPadded < params.W ? inputArray[inputOffset + wPadded] : 0.0;
                }
            }
        } else {
            // General path: arbitrary stride/padding, checked cell by cell.
            for (int c = 0; c < CRS; ++c) {
                int wOffset = c % params.S;
                int hOffset = c / params.S % params.R;
                int cInput = c / params.R / params.S;
                for (int h = 0; h < params.P; ++h) {
                    int outOffset = (c * params.P + h) * params.Q;
                    int hPadded = h * params.stride_h - params.pad_h + hOffset;
                    int inputOffset = nOffset + (cInput * params.H + hPadded) * params.W;
                    // Whole output row lies in vertical padding: zero it and move on.
                    if (hPadded < 0 || hPadded >= params.H) {
                        Arrays.fill(outputArray, outOffset, outOffset + params.Q, 0.0);
                        continue;
                    }
                    for (int w = 0; w < params.Q; ++w) {
                        int wPadded = w * params.stride_w - params.pad_w + wOffset;
                        outputArray[outOffset + w] = wPadded >= 0 && wPadded < params.W ? inputArray[inputOffset + wPadded] : 0.0;
                    }
                }
            }
        }
    }

    /**
     * im2col for image row {@code n} of a sparse (or unallocated) input. Values are fetched
     * cell by cell with {@code input.getValue}. The output layout matches the dense variant:
     * element (crs, p*Q + q) is stored at {@code (crs*P + p)*Q + q}, and padding cells become
     * 0.0.
     */
    private static void doIm2colSparse(int n, MatrixBlock input, double[] outputArray, ConvolutionParameters params) {
        final int R = params.R;
        final int S = params.S;
        final int P = params.P;
        final int Q = params.Q;
        final int H = params.H;
        final int W = params.W;
        final int CRS = params.C * R * S;
        for (int crs = 0; crs < CRS; ++crs) {
            final int s = crs % S;
            final int r = crs / S % R;
            final int channel = crs / R / S;
            for (int p = 0; p < P; ++p) {
                final int outRow = (crs * P + p) * Q;
                final int hIn = p * params.stride_h - params.pad_h + r;
                if (hIn < 0 || hIn >= H) {
                    Arrays.fill(outputArray, outRow, outRow + Q, 0.0);
                } else {
                    final int colBase = (channel * H + hIn) * W;
                    for (int q = 0; q < Q; ++q) {
                        final int wIn = q * params.stride_w - params.pad_w + s;
                        outputArray[outRow + q] = (wIn < 0 || wIn >= W) ? 0.0 : input.getValue(n, colBase + wIn);
                    }
                }
            }
        }
    }

    /**
     * im2col of image row {@code n} of {@code params.input1} into the dense block
     * {@code output}. Dispatches to the dense kernel when the input has a dense array, and to
     * the cell-wise kernel otherwise.
     *
     * @throws DMLRuntimeException if {@code output} is in sparse format
     */
    private static void doIm2col(int n, MatrixBlock output, ConvolutionParameters params) throws DMLRuntimeException {
        final MatrixBlock src = params.input1;
        final double[] srcDense = src.isInSparseFormat() ? null : src.getDenseBlock();
        if (output.isInSparseFormat()) {
            throw new DMLRuntimeException("Sparse output is not supported for im2col");
        }
        final double[] dst = output.getDenseBlock();
        if (srcDense == null) {
            LibMatrixDNN.doIm2colSparse(n, src, dst, params);
        } else {
            LibMatrixDNN.doIm2colDense(n, srcDense, dst, params);
        }
    }

    /**
     * Unit of work that runs one {@link TaskType} over the image rows {@code [_rl, _ru)}.
     *
     * <p>Looped-im2col variants borrow scratch blocks from the shared queues. They always
     * return them in a {@code finally} block, so a failing task cannot drain a queue and make
     * later tasks fail with an unrelated {@code NoSuchElementException}. {@link #call()} returns
     * a non-zero count only for {@code ReluBackward}; all other task types return 0.
     */
    private static class ConvTask implements Callable<Long> {
        public int _rl;
        public int _ru;
        private final ConvolutionParameters _params;
        private final TaskType _type;
        private final ConcurrentLinkedQueue<MatrixBlock> _im2ColOutBlocks;
        private final ConcurrentLinkedQueue<MatrixBlock> _partialRetBlocks;
        private final ConcurrentLinkedQueue<MatrixBlock> _doutReshapedBlocks;

        public ConvTask(int rl, int ru, TaskType type, ConvolutionParameters params, ConcurrentLinkedQueue<MatrixBlock> im2ColOutBlocks, ConcurrentLinkedQueue<MatrixBlock> doutReshapedBlocks, ConcurrentLinkedQueue<MatrixBlock> partialRetBlocks) {
            this._rl = rl;
            this._ru = ru;
            this._type = type;
            this._params = params;
            this._im2ColOutBlocks = im2ColOutBlocks;
            this._partialRetBlocks = partialRetBlocks;
            this._doutReshapedBlocks = doutReshapedBlocks;
        }

        @Override
        public Long call() throws DMLRuntimeException {
            long lnnz = 0L;
            switch (this._type) {
                case MaxPooling_Forward: {
                    for (int n = this._rl; n < this._ru; ++n) {
                        LibMatrixDNN.doPooling(n, this._params);
                    }
                    break;
                }
                case MaxPooling_Backward: {
                    for (int n = this._rl; n < this._ru; ++n) {
                        LibMatrixDNN.doPoolingBackward(n, this._params);
                    }
                    break;
                }
                case BiasAdd: {
                    LibMatrixDNN.doBiasAdd(this._params, this._rl, this._ru);
                    break;
                }
                case BiasMultiply: {
                    LibMatrixDNN.doBiasMultiply(this._params, this._rl, this._ru);
                    break;
                }
                case ReluBackward: {
                    lnnz = LibMatrixDNN.doReluBackward(this._params, this._rl, this._ru);
                    break;
                }
                case LoopedIm2ColConv2d: {
                    MatrixBlock im2ColOutBlock = this._im2ColOutBlocks.remove();
                    try {
                        for (int n = this._rl; n < this._ru; ++n) {
                            LibMatrixDNN.doLoopedIm2ColConv2d(n, im2ColOutBlock, this._params);
                        }
                    } finally {
                        this._im2ColOutBlocks.add(im2ColOutBlock);
                    }
                    if (this._params.bias != null) {
                        LibMatrixDNN.addBias(this._params, this._rl, this._ru);
                    }
                    break;
                }
                case LoopedIm2ColConv2dBwdFilter: {
                    MatrixBlock im2ColOutBlock = this._im2ColOutBlocks.remove();
                    MatrixBlock partialRetBlock = this._partialRetBlocks.remove();
                    MatrixBlock doutReshapedBlock = this._doutReshapedBlocks.remove();
                    try {
                        for (int n = this._rl; n < this._ru; ++n) {
                            // The helper returns the (possibly replaced) partial-result block.
                            partialRetBlock = LibMatrixDNN.doLoopedIm2ColConv2dBwdFilter(n, im2ColOutBlock, doutReshapedBlock, partialRetBlock, this._params);
                        }
                    } finally {
                        this._im2ColOutBlocks.add(im2ColOutBlock);
                        this._partialRetBlocks.add(partialRetBlock);
                        this._doutReshapedBlocks.add(doutReshapedBlock);
                    }
                    break;
                }
                case LoopedIm2ColConv2dBwdData: {
                    MatrixBlock doutReshapedBlock = this._doutReshapedBlocks.remove();
                    try {
                        for (int n = this._rl; n < this._ru; ++n) {
                            LibMatrixDNN.doLoopedIm2ColConv2dBwdData(n, doutReshapedBlock, this._params);
                        }
                    } finally {
                        this._doutReshapedBlocks.add(doutReshapedBlock);
                    }
                    break;
                }
                default: {
                    throw new DMLRuntimeException("Unsupported ConvTask:" + this._type.name());
                }
            }
            return lnnz;
        }
    }

    private static enum TaskType {
        MaxPooling_Forward,
        MaxPooling_Backward,
        LoopedIm2ColConv2d,
        LoopedIm2ColConv2dBwdFilter,
        LoopedIm2ColConv2dBwdData,
        BiasAdd,
        ReluBackward,
        BiasMultiply;

    }
}

