/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.matrix.data;

import java.util.ArrayList;
import java.util.concurrent.Callable;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.runtime.matrix.data.ConvolutionParameters;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNN;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNNHelper;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNNIm2ColHelper;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.utils.NativeHelper;

public class LibMatrixDNNConv2dHelper {

    public static class SparseNativeConv2d
    implements Callable<Long> {
        public int _rl;
        public int _ru;
        private final ConvolutionParameters _params;

        public SparseNativeConv2d(int rl, int ru, ConvolutionParameters params) {
            this._rl = rl;
            this._ru = ru;
            this._params = params;
        }

        @Override
        public Long call() throws Exception {
            int KPQ = this._params.K * this._params.P * this._params.Q;
            double[] temp = new double[KPQ];
            for (int n = this._rl; n < this._ru; ++n) {
                if (this._params.input1.getSparseBlock().isEmpty(n)) continue;
                int apos = this._params.input1.getSparseBlock().pos(n);
                int alen = this._params.input1.getSparseBlock().size(n);
                int[] aix = this._params.input1.getSparseBlock().indexes(n);
                double[] avals = this._params.input1.getSparseBlock().values(n);
                NativeHelper.conv2dSparse(apos, alen, aix, avals, this._params.input2.getDenseBlock(), temp, 1, this._params.C, this._params.H, this._params.W, this._params.K, this._params.R, this._params.S, this._params.stride_h, this._params.stride_w, this._params.pad_h, this._params.pad_w, this._params.P, this._params.Q, 1);
                System.arraycopy(temp, 0, this._params.output.denseBlock, n * KPQ, KPQ);
            }
            return 0L;
        }
    }

    public static class LoopedIm2ColConv2dAllChannels
    implements Callable<Long> {
        public int _rl;
        public int _ru;
        private final ConvolutionParameters _params;

        public LoopedIm2ColConv2dAllChannels(int rl, int ru, ConvolutionParameters params) {
            this._rl = rl;
            this._ru = ru;
            this._params = params;
        }

        @Override
        public Long call() throws Exception {
            int PQ = this._params.P * this._params.Q;
            int K = this._params.K;
            int CRS = this._params.C * this._params.R * this._params.S;
            MatrixBlock im2ColOutBlock = new MatrixBlock(CRS, PQ, false);
            LibMatrixDNNIm2ColHelper.Im2colWorker im2ColWorker = LibMatrixDNNIm2ColHelper.Im2colWorker.getWorker(this._params.input1, im2ColOutBlock, this._params, true);
            long time1 = 0L;
            long time2 = 0L;
            for (int n = this._rl; n < this._ru; ++n) {
                long t3;
                long t1 = DMLScript.STATISTICS && LibMatrixDNN.DISPLAY_STATISTICS ? System.nanoTime() : 0L;
                im2ColWorker.execute(n);
                long t2 = DMLScript.STATISTICS && LibMatrixDNN.DISPLAY_STATISTICS ? System.nanoTime() : 0L;
                MatrixBlock matMultOutBlock = new MatrixBlock(K, PQ, false);
                LibMatrixDNNHelper.singleThreadedMatMult(this._params.input2, im2ColOutBlock, matMultOutBlock, false, true, this._params);
                long l = t3 = DMLScript.STATISTICS && LibMatrixDNN.DISPLAY_STATISTICS ? System.nanoTime() : 0L;
                if (DMLScript.STATISTICS && LibMatrixDNN.DISPLAY_STATISTICS) {
                    time1 += t2 - t1;
                    time2 += t3 - t2;
                }
                this.partialCopy1(matMultOutBlock, this._params.output.getDenseBlock(), n * K * PQ, K, PQ);
            }
            if (this._params.bias != null) {
                LibMatrixDNNHelper.addBias(this._rl, this._ru, this._params.output.getDenseBlock(), this._params.bias.getDenseBlock(), K, PQ);
            }
            if (DMLScript.STATISTICS && LibMatrixDNN.DISPLAY_STATISTICS) {
                LibMatrixDNN.loopedConvIm2ColTime.addAndGet(time1);
                LibMatrixDNN.loopedConvMatMultTime.addAndGet(time2);
            }
            return 0L;
        }

        private void partialCopy1(MatrixBlock src, double[] dest, int destPos, int K, int PQ) {
            if (!src.isEmptyBlock()) {
                if (src.isInSparseFormat()) {
                    for (int k = 0; k < src.getNumRows(); ++k) {
                        if (src.sparseBlock.isEmpty(k)) continue;
                        int apos = src.sparseBlock.pos(k);
                        int alen = src.sparseBlock.size(k);
                        int[] aix = src.sparseBlock.indexes(k);
                        double[] avals = src.sparseBlock.values(k);
                        for (int j = apos; j < apos + alen; ++j) {
                            int pqIndex = aix[j];
                            dest[destPos + k * PQ + pqIndex] = avals[j];
                        }
                    }
                } else {
                    System.arraycopy(src.denseBlock, 0, dest, destPos, K * PQ);
                }
            }
        }
    }

    public static class LoopedIm2ColConv2dOneChannel
    implements Callable<Long> {
        public int _rl;
        public int _ru;
        private final ConvolutionParameters _params;
        ArrayList<MatrixBlock> _filters;

        public LoopedIm2ColConv2dOneChannel(int rl, int ru, ConvolutionParameters params, ArrayList<MatrixBlock> filters) {
            this._rl = rl;
            this._ru = ru;
            this._params = params;
            this._filters = filters;
        }

        @Override
        public Long call() throws Exception {
            int PQ = this._params.P * this._params.Q;
            int K = this._params.K;
            int RS = this._params.R * this._params.S;
            MatrixBlock im2ColOutBlock = new MatrixBlock(RS, PQ, false);
            LibMatrixDNNIm2ColHelper.Im2colWorker im2ColWorker = LibMatrixDNNIm2ColHelper.Im2colWorker.getWorker(this._params.input1, im2ColOutBlock, this._params, false);
            long time1 = 0L;
            long time2 = 0L;
            for (int n = this._rl; n < this._ru; ++n) {
                for (int c = 0; c < this._params.C; ++c) {
                    long t3;
                    long t1 = DMLScript.STATISTICS && LibMatrixDNN.DISPLAY_STATISTICS ? System.nanoTime() : 0L;
                    im2ColWorker.execute(n, c);
                    long t2 = DMLScript.STATISTICS && LibMatrixDNN.DISPLAY_STATISTICS ? System.nanoTime() : 0L;
                    MatrixBlock matMultOutBlock = new MatrixBlock(K, PQ, false);
                    LibMatrixDNNHelper.singleThreadedMatMult(this._filters.get(c), im2ColOutBlock, matMultOutBlock, false, true, this._params);
                    long l = t3 = DMLScript.STATISTICS && LibMatrixDNN.DISPLAY_STATISTICS ? System.nanoTime() : 0L;
                    if (DMLScript.STATISTICS && LibMatrixDNN.DISPLAY_STATISTICS) {
                        time1 += t2 - t1;
                        time2 += t3 - t2;
                    }
                    this.add(matMultOutBlock, this._params.output.getDenseBlock(), n * K * PQ, K, PQ);
                }
            }
            if (this._params.bias != null) {
                LibMatrixDNNHelper.addBias(this._rl, this._ru, this._params.output.getDenseBlock(), this._params.bias.getDenseBlock(), K, PQ);
            }
            if (DMLScript.STATISTICS && LibMatrixDNN.DISPLAY_STATISTICS) {
                LibMatrixDNN.loopedConvIm2ColTime.addAndGet(time1);
                LibMatrixDNN.loopedConvMatMultTime.addAndGet(time2);
            }
            return 0L;
        }

        private void add(MatrixBlock src, double[] dest, int destPos, int K, int PQ) {
            block5: {
                if (src.isEmptyBlock()) break block5;
                if (src.isInSparseFormat()) {
                    for (int k = 0; k < src.getNumRows(); ++k) {
                        if (src.sparseBlock.isEmpty(k)) continue;
                        int apos = src.sparseBlock.pos(k);
                        int alen = src.sparseBlock.size(k);
                        int[] aix = src.sparseBlock.indexes(k);
                        double[] avals = src.sparseBlock.values(k);
                        for (int j = apos; j < apos + alen; ++j) {
                            int pqIndex = aix[j];
                            int n = destPos + k * PQ + pqIndex;
                            dest[n] = dest[n] + avals[j];
                        }
                    }
                } else {
                    for (int i = 0; i < K * PQ; ++i) {
                        int n = destPos + i;
                        dest[n] = dest[n] + src.denseBlock[i];
                    }
                }
            }
        }
    }
}

