/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.matrix.data;

import java.util.concurrent.Callable;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.runtime.matrix.data.ConvolutionParameters;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNN;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNNHelper;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNNIm2ColHelper;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNNRotate180Helper;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.util.ConvolutionUtils;
import org.apache.sysml.utils.NativeHelper;

public class LibMatrixDNNConv2dBackwardFilterHelper {
    /**
     * Adds a partial result into {@code params.output.denseBlock}, transposing it on the fly.
     *
     * <p>{@code partialRetBlock} is consumed sequentially as {@code CRS} groups of {@code K}
     * values, where {@code CRS = params.C * params.R * params.S} and {@code K = params.K}.
     * The value at group {@code crs}, offset {@code k} is added to output position
     * {@code k * CRS + crs}.
     *
     * <p>The method is {@code static synchronized}, so concurrent invocations are serialized
     * on this class's monitor.
     *
     * @param partialRetBlock source values; at least {@code CRS * K} entries are read
     * @param params supplies the dimensions {@code C}, {@code R}, {@code S}, {@code K} and the
     *               output block whose dense array is updated in place
     */
    private static synchronized void inplaceTransposedAddition(double[] partialRetBlock, ConvolutionParameters params) {
        final int crsLen = params.C * params.R * params.S;
        final int numK = params.K;
        final double[] dest = params.output.denseBlock;
        // Running read position in partialRetBlock; advances by one per inner iteration.
        int src = 0;
        for (int crs = 0; crs < crsLen; crs++) {
            for (int k = 0; k < numK; k++) {
                dest[k * crsLen + crs] += partialRetBlock[src++];
            }
        }
    }

    /**
     * Task that processes the index range {@code [_rl, _ru)} and adds its accumulated result
     * to the shared output through {@code inplaceTransposedAddition}.
     *
     * <p>For each index {@code n} it runs the rotate-180 worker and the im2col worker, multiplies
     * the two intermediate blocks with {@code LibMatrixDNNHelper.singleThreadedMatMult}, and adds
     * the non-empty product into a task-local {@code double[CRS * K]} buffer.
     */
    public static class Conv2dBackwardFilter implements Callable<Long> {
        public int _rl;
        public int _ru;
        private final ConvolutionParameters _params;

        /**
         * @param rl first index (inclusive) handled by this task
         * @param ru end index (exclusive) handled by this task
         * @param params convolution parameters, including both inputs and the output block
         */
        public Conv2dBackwardFilter(int rl, int ru, ConvolutionParameters params) {
            _rl = rl;
            _ru = ru;
            _params = params;
        }

        /** Returns true when both statistics flags are currently switched on. */
        private static boolean statsEnabled() {
            return DMLScript.STATISTICS && LibMatrixDNN.DISPLAY_STATISTICS;
        }

        /** Returns {@code System.nanoTime()} when statistics are enabled, otherwise 0. */
        private static long timestamp() {
            return statsEnabled() ? System.nanoTime() : 0L;
        }

        /**
         * Runs the task over {@code [_rl, _ru)}.
         *
         * @return always {@code 0}
         * @throws Exception propagated from the workers or the matrix multiplication
         */
        @Override
        public Long call() throws Exception {
            int PQ = _params.P * _params.Q;
            int K = _params.K;
            int CRS = _params.C * _params.R * _params.S;
            MatrixBlock dout = _params.input2;
            // Scratch blocks reused across iterations: the im2col worker writes into the first,
            // the rotate-180 worker writes into the dense array of the second.
            MatrixBlock im2ColOutBlock = new MatrixBlock(CRS, PQ, false);
            MatrixBlock dout_reshaped = new MatrixBlock(PQ, K, false);
            dout_reshaped.allocateDenseBlock();
            LibMatrixDNNIm2ColHelper.Im2colWorker im2ColWorker =
                    LibMatrixDNNIm2ColHelper.Im2colWorker.getWorker(_params.input1, im2ColOutBlock, _params, true);
            LibMatrixDNNRotate180Helper.Rotate180Worker rotate180Worker =
                    LibMatrixDNNRotate180Helper.Rotate180Worker.getWorker(dout, dout_reshaped.getDenseBlock(), _params, true);
            double[] partialRetBlock = new double[CRS * _params.K];
            long im2colNanos = 0L;
            long matMultNanos = 0L;
            for (int n = _rl; n < _ru; n++) {
                rotate180Worker.execute(n, 0);

                long beforeIm2col = timestamp();
                im2ColWorker.execute(n);
                long afterIm2col = timestamp();

                MatrixBlock product = new MatrixBlock(CRS, K, false);
                LibMatrixDNNHelper.singleThreadedMatMult(im2ColOutBlock, dout_reshaped, product, true, true, _params);
                long afterMatMult = timestamp();

                if (!product.isEmptyBlock()) {
                    ConvolutionUtils.binaryOperationInPlace(product, partialRetBlock, 0, K, 0, CRS,
                            LibMatrixDNN._binaryElementWiseAddition);
                }

                if (statsEnabled()) {
                    im2colNanos += afterIm2col - beforeIm2col;
                    matMultNanos += afterMatMult - afterIm2col;
                }
            }
            // Fold this task's accumulated values into the shared output.
            LibMatrixDNNConv2dBackwardFilterHelper.inplaceTransposedAddition(partialRetBlock, _params);
            if (statsEnabled()) {
                LibMatrixDNN.loopedConvBwdFilterIm2ColTime.addAndGet(im2colNanos);
                LibMatrixDNN.loopedConvBwdFilterMatMultTime.addAndGet(matMultNanos);
            }
            return 0L;
        }
    }

    /**
     * Task that processes the index range {@code [_rl, _ru)} of a sparse {@code input1} using
     * {@code NativeHelper.conv2dBackwardFilterSparseDense}, then adds its accumulated result to
     * the shared output through {@code inplaceTransposedAddition}.
     *
     * <p>Indices whose sparse row reports {@code isEmpty(n)} are skipped entirely.
     */
    public static class SparseNativeConv2dBackwardFilterDense implements Callable<Long> {
        public int _rl;
        public int _ru;
        private final ConvolutionParameters _params;

        /**
         * @param rl first index (inclusive) handled by this task
         * @param ru end index (exclusive) handled by this task
         * @param params convolution parameters, including both inputs and the output block
         */
        public SparseNativeConv2dBackwardFilterDense(int rl, int ru, ConvolutionParameters params) {
            _rl = rl;
            _ru = ru;
            _params = params;
        }

        /**
         * Runs the task over {@code [_rl, _ru)}.
         *
         * @return always {@code 0}
         * @throws Exception propagated from the rotate-180 worker or the native call
         */
        @Override
        public Long call() throws Exception {
            int CRS = _params.C * _params.R * _params.S;
            // Scratch buffer the rotate-180 worker fills for the current index.
            double[] dout_n = new double[_params.P * _params.Q * _params.K];
            LibMatrixDNNRotate180Helper.Rotate180Worker rotate180Worker =
                    LibMatrixDNNRotate180Helper.Rotate180Worker.getWorker(_params.input2, dout_n, _params, true);
            double[] partialRetBlock = new double[CRS * _params.K];
            for (int n = _rl; n < _ru; n++) {
                MatrixBlock sparseInput = _params.input1;
                if (!sparseInput.getSparseBlock().isEmpty(n)) {
                    rotate180Worker.execute(n, 0);
                    int apos = sparseInput.getSparseBlock().pos(n);
                    int alen = sparseInput.getSparseBlock().size(n);
                    int[] aix = sparseInput.getSparseBlock().indexes(n);
                    double[] avals = sparseInput.getSparseBlock().values(n);
                    // NOTE(review): the two literal 1 arguments are presumably the batch size and
                    // the thread count for the native kernel; confirm against NativeHelper.
                    NativeHelper.conv2dBackwardFilterSparseDense(
                            apos, alen, aix, avals,
                            dout_n, partialRetBlock,
                            1, _params.C, _params.H, _params.W,
                            _params.K, _params.R, _params.S,
                            _params.stride_h, _params.stride_w,
                            _params.pad_h, _params.pad_w,
                            _params.P, _params.Q,
                            1);
                }
            }
            // Fold this task's accumulated values into the shared output.
            LibMatrixDNNConv2dBackwardFilterHelper.inplaceTransposedAddition(partialRetBlock, _params);
            return 0L;
        }
    }
}

