/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.matrix.data;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.util.Arrays;
import java.util.stream.IntStream;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.matrix.data.ConvolutionParameters;
import org.apache.sysml.runtime.matrix.data.LibMatrixDNN;
import org.apache.sysml.runtime.matrix.data.LibMatrixMult;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.utils.NativeHelper;
import org.apache.sysml.utils.Statistics;

/**
 * Dispatches dense matrix multiplication and 2D convolution operators to the
 * native library exposed through {@link NativeHelper}.
 *
 * <p>Every public operator follows the same pattern: if the native library is
 * loaded and all relevant inputs are dense, the native kernel is invoked; if
 * the native path is not applicable or the native call reports a failure, the
 * call is forwarded to the Java implementation in {@link LibMatrixMult} or
 * {@link LibMatrixDNN}.
 *
 * <p>When single precision is configured, inputs are converted into
 * per-thread direct {@link FloatBuffer}s that are cached and reused across
 * calls on the same thread.
 */
public class LibMatrixNative {
    private static final Log LOG = LogFactory.getLog(LibMatrixNative.class.getName());

    // Per-thread scratch buffers for the single-precision code paths. They are grown
    // on demand by toFloatBuffer and never shrunk.
    private static final ThreadLocal<FloatBuffer> inBuff = new ThreadLocal<>();
    private static final ThreadLocal<FloatBuffer> biasBuff = new ThreadLocal<>();
    private static final ThreadLocal<FloatBuffer> filterBuff = new ThreadLocal<>();
    private static final ThreadLocal<FloatBuffer> outBuff = new ThreadLocal<>();

    /**
     * Tells whether a matrix multiplication is treated as memory bound, which is the
     * case as soon as one of the three dimensions equals 1.
     *
     * @param m1Rlen number of rows of the left-hand side
     * @param m1Clen number of columns of the left-hand side
     * @param m2Clen number of columns of the right-hand side
     * @return {@code true} if any of the given dimensions is 1
     */
    private static boolean isMatMultMemoryBound(int m1Rlen, int m1Clen, int m2Clen) {
        return m1Rlen == 1 || m1Clen == 1 || m2Clen == 1;
    }

    /**
     * Multiplies {@code m1} by {@code m2} into {@code ret} and examines the sparsity of
     * the result afterwards.
     *
     * @param m1 left-hand side matrix
     * @param m2 right-hand side matrix
     * @param ret output matrix
     * @param k number of threads; values {@code <= 0} select {@link NativeHelper#getMaxNumThreads()}
     * @throws DMLRuntimeException if the underlying operator fails
     */
    public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k) throws DMLRuntimeException {
        matrixMult(m1, m2, ret, k, true);
    }

    /**
     * Multiplies {@code m1} by {@code m2} into {@code ret}.
     *
     * <p>The native kernel is used only if the native library is loaded, the operation
     * is not memory bound (see {@link #isMatMultMemoryBound(int, int, int)}) and both
     * inputs are dense. In all other cases, and when the native call reports a failure,
     * {@link LibMatrixMult} performs the multiplication.
     *
     * @param m1 left-hand side matrix
     * @param m2 right-hand side matrix
     * @param ret output matrix
     * @param k number of threads; values {@code <= 0} select {@link NativeHelper#getMaxNumThreads()}
     * @param examSparsity whether to call {@code ret.examSparsity()} once the result is available
     * @throws DMLRuntimeException if the underlying operator fails
     */
    public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k, boolean examSparsity) throws DMLRuntimeException {
        k = k <= 0 ? NativeHelper.getMaxNumThreads() : k;

        // An empty operand yields an empty product; nothing has to be computed.
        if (m1.isEmptyBlock(false) || m2.isEmptyBlock(false)) {
            ret.setNonZeros(0L);
            if (examSparsity) {
                ret.examSparsity();
            }
            return;
        }

        if (NativeHelper.isNativeLibraryLoaded()
                && !isMatMultMemoryBound(m1.rlen, m1.clen, m2.clen)
                && !m1.isInSparseFormat()
                && !m2.isInSparseFormat()) {
            ret.sparse = false;
            ret.allocateDenseBlock();
            long start = DMLScript.STATISTICS ? System.nanoTime() : 0L;
            boolean rccode;
            if (isSinglePrecision()) {
                FloatBuffer fin1 = toFloatBuffer(m1.getDenseBlockValues(), inBuff, true);
                FloatBuffer fin2 = toFloatBuffer(m2.getDenseBlockValues(), filterBuff, true);
                FloatBuffer fout = toFloatBuffer(ret.getDenseBlockValues(), outBuff, false);
                rccode = NativeHelper.smmdd(fin1, fin2, fout,
                        m1.getNumRows(), m1.getNumColumns(), m2.getNumColumns(), k);
                // Copy the result back only on success, mirroring conv2d; after a failure
                // the buffer content is not a valid result and the Java fallback takes over.
                if (rccode) {
                    fromFloatBuffer(fout, ret.getDenseBlockValues());
                }
            } else {
                rccode = NativeHelper.dmmdd(m1.getDenseBlockValues(), m2.getDenseBlockValues(), ret.getDenseBlockValues(),
                        m1.getNumRows(), m1.getNumColumns(), m2.getNumColumns(), k);
            }
            if (rccode) {
                if (DMLScript.STATISTICS) {
                    Statistics.nativeLibMatrixMultTime += System.nanoTime() - start;
                    Statistics.numNativeLibMatrixMultCalls.increment();
                }
                ret.recomputeNonZeros();
                if (examSparsity) {
                    ret.examSparsity();
                }
                return;
            }
            // Native call failed: record it and fall through to the Java operator.
            Statistics.incrementNativeFailuresCounter();
        }

        if (k == 1) {
            LibMatrixMult.matrixMult(m1, m2, ret, !examSparsity);
        } else {
            LibMatrixMult.matrixMult(m1, m2, ret, k);
        }
    }

    /**
     * Performs a 2D convolution of {@code input} with {@code filter}, adding
     * {@code params.bias} if it is set.
     *
     * <p>The native kernel is used only if the native library is loaded and both
     * {@code input} and {@code filter} are dense. A native return value of {@code -1}
     * is treated as failure, in which case (as well as when the native path is not
     * applicable) {@link LibMatrixDNN#conv2d} computes the result.
     *
     * @param input input matrix
     * @param filter filter matrix
     * @param outputBlock output matrix
     * @param params convolution parameters; {@code numThreads} and a sparse {@code bias} may be modified
     * @throws DMLRuntimeException if input validation or the underlying operator fails
     */
    public static void conv2d(MatrixBlock input, MatrixBlock filter, MatrixBlock outputBlock, ConvolutionParameters params) throws DMLRuntimeException {
        LibMatrixDNN.checkInputsConv2d(input, filter, outputBlock, params);
        params.numThreads = params.numThreads <= 0 ? NativeHelper.getMaxNumThreads() : params.numThreads;

        if (NativeHelper.isNativeLibraryLoaded() && !input.isInSparseFormat() && !filter.isInSparseFormat()) {
            setNumThreads(params);
            long start = DMLScript.STATISTICS ? System.nanoTime() : 0L;
            int nnz;
            // True only if the native kernel worked on float buffers, i.e. it never
            // received the output block's own dense array.
            boolean viaFloatBuffers = false;
            if (params.bias == null) {
                nnz = NativeHelper.conv2dDense(input.getDenseBlockValues(), filter.getDenseBlockValues(),
                        outputBlock.getDenseBlockValues(), params.N, params.C, params.H, params.W,
                        params.K, params.R, params.S, params.stride_h, params.stride_w, params.pad_h, params.pad_w,
                        params.P, params.Q, params.numThreads);
            } else {
                if (params.bias.isInSparseFormat()) {
                    params.bias.sparseToDense();
                }
                if (isSinglePrecision() && !NativeHelper.getCurrentBLAS().equalsIgnoreCase("mkl")) {
                    viaFloatBuffers = true;
                    FloatBuffer finput = toFloatBuffer(input.getDenseBlockValues(), inBuff, true);
                    FloatBuffer fbias = toFloatBuffer(params.bias.getDenseBlockValues(), biasBuff, true);
                    FloatBuffer ffilter = toFloatBuffer(filter.getDenseBlockValues(), filterBuff, true);
                    FloatBuffer foutput = toFloatBuffer(outputBlock.getDenseBlockValues(), outBuff, false);
                    nnz = NativeHelper.sconv2dBiasAddDense(finput, fbias, ffilter, foutput,
                            params.N, params.C, params.H, params.W,
                            params.K, params.R, params.S, params.stride_h, params.stride_w, params.pad_h, params.pad_w,
                            params.P, params.Q, params.numThreads);
                    if (nnz != -1) {
                        fromFloatBuffer(foutput, outputBlock.getDenseBlockValues());
                    }
                } else {
                    nnz = NativeHelper.dconv2dBiasAddDense(input.getDenseBlockValues(), params.bias.getDenseBlockValues(),
                            filter.getDenseBlockValues(), outputBlock.getDenseBlockValues(),
                            params.N, params.C, params.H, params.W,
                            params.K, params.R, params.S, params.stride_h, params.stride_w, params.pad_h, params.pad_w,
                            params.P, params.Q, params.numThreads);
                }
            }

            if (nnz != -1) {
                if (DMLScript.STATISTICS) {
                    Statistics.nativeConv2dTime += System.nanoTime() - start;
                    Statistics.numNativeConv2dCalls.increment();
                }
                outputBlock.setNonZeros(nnz);
                return;
            }

            LOG.warn("Native conv2d call returned with error - falling back to java operator.");
            // Whenever the native kernel was handed the output array directly it may have
            // written partial results, so clear the block before the Java fallback. The
            // decision is tied to the path actually taken above rather than re-derived
            // from the configuration, which previously missed the single-precision + MKL
            // case that runs the double-precision kernel.
            if (!viaFloatBuffers) {
                outputBlock.reset();
            }
            Statistics.incrementNativeFailuresCounter();
        }

        LibMatrixDNN.conv2d(input, filter, outputBlock, params);
    }

    /**
     * Constrains {@code params.numThreads} via
     * {@link OptimizerUtils#getConstrainedNumThreads(int)} and forces it to 1 if the
     * output is not thread safe or the constrained value is not greater than 1.
     *
     * @param params convolution parameters whose {@code numThreads} field is updated in place
     */
    private static void setNumThreads(ConvolutionParameters params) {
        params.numThreads = OptimizerUtils.getConstrainedNumThreads(params.numThreads);
        if (!params.isOutputThreadSafe() || params.numThreads <= 1) {
            params.numThreads = 1;
        }
    }

    /**
     * Computes the backward pass of a 2D convolution with respect to the filter.
     *
     * <p>The native kernel is used only if the native library is loaded and both
     * {@code input} and {@code dout} are dense. A native return value of {@code -1} is
     * treated as failure, in which case (as well as when the native path is not
     * applicable) {@link LibMatrixDNN#conv2dBackwardFilter} computes the result.
     *
     * @param input input matrix
     * @param dout matrix passed to the kernel as {@code dout}
     * @param outputBlock output matrix
     * @param params convolution parameters; {@code numThreads} may be modified
     * @throws DMLRuntimeException if input validation or the underlying operator fails
     */
    public static void conv2dBackwardFilter(MatrixBlock input, MatrixBlock dout, MatrixBlock outputBlock, ConvolutionParameters params) throws DMLRuntimeException {
        LibMatrixDNN.checkInputsConv2dBackwardFilter(input, dout, outputBlock, params);
        params.numThreads = params.numThreads <= 0 ? NativeHelper.getMaxNumThreads() : params.numThreads;

        if (NativeHelper.isNativeLibraryLoaded() && !dout.isInSparseFormat() && !input.isInSparseFormat()) {
            setNumThreads(params);
            long start = DMLScript.STATISTICS ? System.nanoTime() : 0L;
            int nnz = NativeHelper.conv2dBackwardFilterDense(input.getDenseBlockValues(), dout.getDenseBlockValues(),
                    outputBlock.getDenseBlockValues(), params.N, params.C, params.H, params.W,
                    params.K, params.R, params.S, params.stride_h, params.stride_w, params.pad_h, params.pad_w,
                    params.P, params.Q, params.numThreads);
            if (nnz != -1) {
                if (DMLScript.STATISTICS) {
                    Statistics.nativeConv2dBwdFilterTime += System.nanoTime() - start;
                    Statistics.numNativeConv2dBwdFilterCalls.increment();
                }
                outputBlock.setNonZeros(nnz);
                return;
            }
            // NOTE(review): unlike conv2d, the output block is not reset before the Java
            // fallback here - confirm that the Java operator does not depend on it.
            LOG.warn("Native conv2dBackwardFilter call returned with error - falling back to java operator.");
            Statistics.incrementNativeFailuresCounter();
        }

        LibMatrixDNN.conv2dBackwardFilter(input, dout, outputBlock, params);
    }

    /**
     * Computes the backward pass of a 2D convolution with respect to the data.
     *
     * <p>The native kernel is used only if the native library is loaded and both
     * {@code filter} and {@code dout} are dense. A native return value of {@code -1} is
     * treated as failure, in which case (as well as when the native path is not
     * applicable) {@link LibMatrixDNN#conv2dBackwardData} computes the result.
     *
     * @param filter filter matrix
     * @param dout matrix passed to the kernel as {@code dout}
     * @param outputBlock output matrix
     * @param params convolution parameters; {@code numThreads} may be modified
     * @throws DMLRuntimeException if input validation or the underlying operator fails
     */
    public static void conv2dBackwardData(MatrixBlock filter, MatrixBlock dout, MatrixBlock outputBlock, ConvolutionParameters params) throws DMLRuntimeException {
        LibMatrixDNN.checkInputsConv2dBackwardData(filter, dout, outputBlock, params);
        params.numThreads = params.numThreads <= 0 ? NativeHelper.getMaxNumThreads() : params.numThreads;

        if (NativeHelper.isNativeLibraryLoaded() && !dout.isInSparseFormat() && !filter.isInSparseFormat()) {
            setNumThreads(params);
            long start = DMLScript.STATISTICS ? System.nanoTime() : 0L;
            int nnz = NativeHelper.conv2dBackwardDataDense(filter.getDenseBlockValues(), dout.getDenseBlockValues(),
                    outputBlock.getDenseBlockValues(), params.N, params.C, params.H, params.W,
                    params.K, params.R, params.S, params.stride_h, params.stride_w, params.pad_h, params.pad_w,
                    params.P, params.Q, params.numThreads);
            if (nnz != -1) {
                if (DMLScript.STATISTICS) {
                    Statistics.nativeConv2dBwdDataTime += System.nanoTime() - start;
                    Statistics.numNativeConv2dBwdDataCalls.increment();
                }
                outputBlock.setNonZeros(nnz);
                return;
            }
            // NOTE(review): unlike conv2d, the output block is not reset before the Java
            // fallback here - confirm that the Java operator does not depend on it.
            LOG.warn("Native conv2dBackwardData call returned with error - falling back to java operator.");
            Statistics.incrementNativeFailuresCounter();
        }

        LibMatrixDNN.conv2dBackwardData(filter, dout, outputBlock, params);
    }

    /**
     * Tells whether the configuration value {@code sysml.floating.point.precision}
     * equals {@code "single"}.
     *
     * @return {@code true} if single precision is configured; {@code false} otherwise,
     *         including when the configuration value is {@code null}
     */
    private static boolean isSinglePrecision() {
        // Constant-first comparison so that a missing value does not raise an NPE.
        return "single".equals(ConfigurationManager.getDMLConfig().getTextValue("sysml.floating.point.precision"));
    }

    /**
     * Returns the calling thread's cached direct float buffer from {@code buff},
     * (re)allocating it in native byte order if it is missing or holds fewer than
     * {@code input.length} floats, and optionally fills it with {@code input} narrowed
     * to {@code float}.
     *
     * <p>The returned buffer may be larger than {@code input.length}; only the first
     * {@code input.length} positions are written when {@code copy} is set.
     *
     * @param input source values; also determines the minimum buffer capacity
     * @param buff thread-local holder of the buffer to reuse
     * @param copy whether to copy {@code input} into the buffer
     * @return the buffer now stored in {@code buff} for the calling thread
     * @throws ArithmeticException if the required size in bytes does not fit into an {@code int}
     */
    private static FloatBuffer toFloatBuffer(double[] input, ThreadLocal<FloatBuffer> buff, boolean copy) {
        FloatBuffer cached = buff.get();
        if (cached == null || cached.capacity() < input.length) {
            // Fail loudly on overflow: a wrapped size could otherwise allocate (and
            // cache) an undersized buffer that is later handed to native code.
            int numBytes = Math.multiplyExact(Float.BYTES, input.length);
            cached = ByteBuffer.allocateDirect(numBytes).order(ByteOrder.nativeOrder()).asFloatBuffer();
            buff.set(cached);
        }
        final FloatBuffer target = cached;
        if (copy) {
            // Absolute puts to distinct indices; the buffer position is left untouched.
            IntStream.range(0, input.length).parallel().forEach(i -> target.put(i, (float) input[i]));
        }
        return target;
    }

    /**
     * Copies the first {@code output.length} floats of {@code buff} into {@code output},
     * widening each value to {@code double}.
     *
     * @param buff source buffer; must hold at least {@code output.length} floats
     * @param output destination array, overwritten completely
     */
    private static void fromFloatBuffer(FloatBuffer buff, double[] output) {
        Arrays.parallelSetAll(output, i -> buff.get(i));
    }
}

