/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.matrix.data;

import org.apache.sysml.api.DMLScript;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.matrix.data.ConvolutionParameters;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNN;
import org.apache.sysml.runtime.matrix.data.LibMatrixMult;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.utils.NativeHelper;
import org.apache.sysml.utils.Statistics;

/**
 * Entry points that try to run dense matrix multiplication and 2D convolution
 * operators through the native library exposed by {@link NativeHelper}.
 *
 * <p>Every public method follows the same pattern: if the native library is loaded
 * and the relevant operands are not in sparse format, the native kernel is invoked;
 * when the native kernel is skipped or reports failure, the call is forwarded to the
 * Java implementation in {@link LibMatrixMult} or {@link LibMatrixDNN}. Native
 * failures are counted via {@code Statistics.incrementNativeFailuresCounter()}.
 */
public class LibMatrixNative {
    /**
     * Returns {@code true} when any of the three dimensions equals 1 (i.e. at least
     * one operand or the result is a vector). Such multiplications are not sent to
     * the native library by {@link #matrixMult(MatrixBlock, MatrixBlock, MatrixBlock, int, boolean)}.
     *
     * @param m1Rlen number of rows of the left operand
     * @param m1Clen number of columns of the left operand
     * @param m2Clen number of columns of the right operand
     * @return whether the multiplication is treated as memory-bound
     */
    private static boolean isMatMultMemoryBound(int m1Rlen, int m1Clen, int m2Clen) {
        return m1Rlen == 1 || m1Clen == 1 || m2Clen == 1;
    }

    /**
     * Multiplies {@code m1} by {@code m2} into {@code ret}, examining the sparsity of
     * the result afterwards. Equivalent to calling the five-argument overload with
     * {@code examSparsity = true}.
     *
     * @param m1 left operand
     * @param m2 right operand
     * @param ret output block
     * @param k requested number of threads; non-positive selects the native maximum
     * @throws DMLRuntimeException if the underlying multiplication fails
     */
    public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k) throws DMLRuntimeException {
        LibMatrixNative.matrixMult(m1, m2, ret, k, true);
    }

    /**
     * Multiplies {@code m1} by {@code m2} into {@code ret}, preferring the native
     * dense-dense kernel and falling back to {@link LibMatrixMult} otherwise.
     *
     * @param m1 left operand
     * @param m2 right operand
     * @param ret output block; switched to dense format and allocated when the native
     *            path is attempted
     * @param k requested number of threads; non-positive selects
     *          {@code NativeHelper.getMaxNumThreads()}
     * @param examSparsity whether to call {@code ret.examSparsity()} after the result
     *                     has been produced (empty-input and native paths, and the
     *                     single-threaded Java fallback)
     * @throws DMLRuntimeException if the underlying multiplication fails
     */
    public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k, boolean examSparsity) throws DMLRuntimeException {
        if (k <= 0) {
            k = NativeHelper.getMaxNumThreads();
        }

        // An empty operand yields an empty product: no kernel needs to run.
        if (m1.isEmptyBlock() || m2.isEmptyBlock()) {
            ret.setNonZeros(0L);
            if (examSparsity) {
                ret.examSparsity();
            }
            return;
        }

        boolean nativeEligible = NativeHelper.isNativeLibraryLoaded()
                && !LibMatrixNative.isMatMultMemoryBound(m1.rlen, m1.clen, m2.clen)
                && !m1.isInSparseFormat()
                && !m2.isInSparseFormat();
        if (nativeEligible) {
            // The native kernel writes straight into the dense array of the output.
            ret.sparse = false;
            ret.allocateDenseBlock();
            long start = DMLScript.STATISTICS ? System.nanoTime() : 0L;
            boolean succeeded = NativeHelper.matrixMultDenseDense(
                    m1.denseBlock, m2.denseBlock, ret.denseBlock,
                    m1.getNumRows(), m1.getNumColumns(), m2.getNumColumns(), k);
            if (succeeded) {
                if (DMLScript.STATISTICS) {
                    Statistics.nativeLibMatrixMultTime += System.nanoTime() - start;
                    Statistics.numNativeLibMatrixMultCalls.increment();
                }
                ret.recomputeNonZeros();
                if (examSparsity) {
                    ret.examSparsity();
                }
                return;
            }
            // Native call reported failure: record it and fall through to Java.
            Statistics.incrementNativeFailuresCounter();
        }

        if (k == 1) {
            LibMatrixMult.matrixMult(m1, m2, ret, examSparsity);
        } else {
            // NOTE(review): examSparsity is not forwarded on this path because the
            // overload called here takes the thread count instead; confirm that the
            // multi-threaded Java implementation handles sparsity as callers expect.
            LibMatrixMult.matrixMult(m1, m2, ret, k);
        }
    }

    /**
     * Runs a 2D convolution of {@code input} with {@code filter} into
     * {@code outputBlock}, adding {@code params.bias} when it is non-null.
     *
     * <p>The native kernel is used when the native library is loaded and neither
     * {@code input} nor {@code filter} is in sparse format; a sparse bias is converted
     * to dense in place first. A native return value of {@code -1} is treated as
     * failure and the call falls back to {@link LibMatrixDNN#conv2d}.
     *
     * @param input input block
     * @param filter filter block
     * @param outputBlock output block; its non-zero count is set from the native
     *                    return value on success
     * @param params convolution parameters; {@code numThreads} may be rewritten
     * @throws DMLRuntimeException if input validation or the convolution fails
     */
    public static void conv2d(MatrixBlock input, MatrixBlock filter, MatrixBlock outputBlock, ConvolutionParameters params) throws DMLRuntimeException {
        LibMatrixDNN.checkInputsConv2d(input, filter, outputBlock, params);
        LibMatrixNative.applyDefaultNumThreads(params);
        if (NativeHelper.isNativeLibraryLoaded() && !input.isInSparseFormat() && !filter.isInSparseFormat()) {
            LibMatrixNative.setNumThreads(params);
            boolean hasBias = params.bias != null;
            if (hasBias && params.bias.isInSparseFormat()) {
                // The native kernel reads the bias through its dense array.
                params.bias.sparseToDense();
            }
            // NOTE(review): outputBlock.denseBlock is passed as-is; presumably it was
            // allocated by the caller before reaching this point - confirm.
            long start = DMLScript.STATISTICS ? System.nanoTime() : 0L;
            int nnz = hasBias
                    ? NativeHelper.conv2dBiasAddDense(input.denseBlock, params.bias.denseBlock, filter.denseBlock, outputBlock.denseBlock,
                            params.N, params.C, params.H, params.W, params.K, params.R, params.S,
                            params.stride_h, params.stride_w, params.pad_h, params.pad_w, params.P, params.Q, params.numThreads)
                    : NativeHelper.conv2dDense(input.denseBlock, filter.denseBlock, outputBlock.denseBlock,
                            params.N, params.C, params.H, params.W, params.K, params.R, params.S,
                            params.stride_h, params.stride_w, params.pad_h, params.pad_w, params.P, params.Q, params.numThreads);
            if (nnz != -1) {
                if (DMLScript.STATISTICS) {
                    Statistics.nativeConv2dTime += System.nanoTime() - start;
                    Statistics.numNativeConv2dCalls.increment();
                }
                outputBlock.setNonZeros(nnz);
                return;
            }
            Statistics.incrementNativeFailuresCounter();
        }
        LibMatrixDNN.conv2d(input, filter, outputBlock, params);
    }

    /**
     * Replaces a non-positive {@code params.numThreads} with
     * {@code NativeHelper.getMaxNumThreads()}; positive values are left untouched.
     *
     * @param params convolution parameters whose thread count is normalized in place
     */
    private static void applyDefaultNumThreads(ConvolutionParameters params) {
        if (params.numThreads <= 0) {
            params.numThreads = NativeHelper.getMaxNumThreads();
        }
    }

    /**
     * Limits {@code params.numThreads} to the value returned by
     * {@code OptimizerUtils.getConstrainedNumThreads} and forces a single thread when
     * {@code params.isOutputThreadSafe()} is false or the constrained value is at
     * most 1.
     *
     * @param params convolution parameters whose thread count is adjusted in place
     */
    private static void setNumThreads(ConvolutionParameters params) {
        params.numThreads = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
        if (!params.isOutputThreadSafe() || params.numThreads <= 1) {
            params.numThreads = 1;
        }
    }

    /**
     * Computes the backward pass of a 2D convolution with respect to the filter.
     *
     * <p>The native kernel is used when the native library is loaded and neither
     * {@code dout} nor {@code input} is in sparse format. A native return value of
     * {@code -1} is treated as failure and the call falls back to
     * {@link LibMatrixDNN#conv2dBackwardFilter}.
     *
     * @param input input block of the forward convolution
     * @param dout block passed to the kernels as {@code dout}
     * @param outputBlock output block; its non-zero count is set from the native
     *                    return value on success
     * @param params convolution parameters; {@code numThreads} may be rewritten
     * @throws DMLRuntimeException if input validation or the computation fails
     */
    public static void conv2dBackwardFilter(MatrixBlock input, MatrixBlock dout, MatrixBlock outputBlock, ConvolutionParameters params) throws DMLRuntimeException {
        LibMatrixDNN.checkInputsConv2dBackwardFilter(input, dout, outputBlock, params);
        LibMatrixNative.applyDefaultNumThreads(params);
        if (NativeHelper.isNativeLibraryLoaded() && !dout.isInSparseFormat() && !input.isInSparseFormat()) {
            LibMatrixNative.setNumThreads(params);
            long start = DMLScript.STATISTICS ? System.nanoTime() : 0L;
            int nnz = NativeHelper.conv2dBackwardFilterDense(input.denseBlock, dout.denseBlock, outputBlock.denseBlock,
                    params.N, params.C, params.H, params.W, params.K, params.R, params.S,
                    params.stride_h, params.stride_w, params.pad_h, params.pad_w, params.P, params.Q, params.numThreads);
            if (nnz != -1) {
                if (DMLScript.STATISTICS) {
                    Statistics.nativeConv2dBwdFilterTime += System.nanoTime() - start;
                    Statistics.numNativeConv2dBwdFilterCalls.increment();
                }
                outputBlock.setNonZeros(nnz);
                return;
            }
            Statistics.incrementNativeFailuresCounter();
        }
        LibMatrixDNN.conv2dBackwardFilter(input, dout, outputBlock, params);
    }

    /**
     * Computes the backward pass of a 2D convolution with respect to the input data.
     *
     * <p>The native kernel is used when the native library is loaded and neither
     * {@code dout} nor {@code filter} is in sparse format. A native return value of
     * {@code -1} is treated as failure and the call falls back to
     * {@link LibMatrixDNN#conv2dBackwardData}.
     *
     * @param filter filter block of the forward convolution
     * @param dout block passed to the kernels as {@code dout}
     * @param outputBlock output block; its non-zero count is set from the native
     *                    return value on success
     * @param params convolution parameters; {@code numThreads} may be rewritten
     * @throws DMLRuntimeException if input validation or the computation fails
     */
    public static void conv2dBackwardData(MatrixBlock filter, MatrixBlock dout, MatrixBlock outputBlock, ConvolutionParameters params) throws DMLRuntimeException {
        LibMatrixDNN.checkInputsConv2dBackwardData(filter, dout, outputBlock, params);
        LibMatrixNative.applyDefaultNumThreads(params);
        if (NativeHelper.isNativeLibraryLoaded() && !dout.isInSparseFormat() && !filter.isInSparseFormat()) {
            LibMatrixNative.setNumThreads(params);
            long start = DMLScript.STATISTICS ? System.nanoTime() : 0L;
            int nnz = NativeHelper.conv2dBackwardDataDense(filter.denseBlock, dout.denseBlock, outputBlock.denseBlock,
                    params.N, params.C, params.H, params.W, params.K, params.R, params.S,
                    params.stride_h, params.stride_w, params.pad_h, params.pad_w, params.P, params.Q, params.numThreads);
            if (nnz != -1) {
                if (DMLScript.STATISTICS) {
                    Statistics.nativeConv2dBwdDataTime += System.nanoTime() - start;
                    Statistics.numNativeConv2dBwdDataCalls.increment();
                }
                outputBlock.setNonZeros(nnz);
                return;
            }
            Statistics.incrementNativeFailuresCounter();
        }
        LibMatrixDNN.conv2dBackwardData(filter, dout, outputBlock, params);
    }
}

