/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.instructions.gpu.context;

import java.util.HashMap;
import jcuda.driver.CUdevice;
import jcuda.driver.CUstream;
import jcuda.driver.JCudaDriver;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.instructions.gpu.context.JCudaKernels;

/**
 * Holds grid/block dimensions, a shared-memory size and a stream for a GPU
 * kernel launch, plus static factories that derive those dimensions from the
 * size of the data to process.
 *
 * <p>The factories query the maximum block dimension of the device once and
 * cache it in a static map. That cache is a plain {@link HashMap} and is not
 * synchronized.
 */
public class ExecutionConfig {
    /** Grid dimension along X. */
    public int gridDimX;
    /** Grid dimension along Y; defaults to 1. */
    public int gridDimY = 1;
    /** Grid dimension along Z; defaults to 1. */
    public int gridDimZ = 1;
    /** Block dimension along X. */
    public int blockDimX;
    /** Block dimension along Y; defaults to 1. */
    public int blockDimY = 1;
    /** Block dimension along Z; defaults to 1. */
    public int blockDimZ = 1;
    /** Shared memory size in bytes; defaults to 0. */
    public int sharedMemBytes = 0;
    /** Stream associated with this configuration; {@code null} unless set by the caller. */
    public CUstream stream = null;

    /**
     * Attribute id passed to {@code cuDeviceGetAttribute}. The value 2 is
     * {@code CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X} in the CUDA driver API
     * (the decompiler inlined the constant).
     */
    private static final int MAX_BLOCK_DIM_X_ATTRIBUTE = 2;

    /** Cache: device number -> value reported for {@link #MAX_BLOCK_DIM_X_ATTRIBUTE}. */
    private static final HashMap<Integer, Integer> maxBlockDimForDevice = new HashMap<>();

    /**
     * Creates a one-dimensional configuration with an explicit shared-memory size.
     *
     * @param gridDimX       grid dimension along X
     * @param blockDimX      block dimension along X
     * @param sharedMemBytes shared memory size in bytes
     */
    public ExecutionConfig(int gridDimX, int blockDimX, int sharedMemBytes) {
        this.gridDimX = gridDimX;
        this.blockDimX = blockDimX;
        this.sharedMemBytes = sharedMemBytes;
    }

    /**
     * Builds a one-dimensional configuration covering {@code numCells} elements:
     * the block size is the device's maximum block dimension and the grid size is
     * {@code ceil(numCells / blockDimX)}.
     *
     * <p>Always queries device number 0.
     *
     * @param numCells number of elements to cover
     * @return a configuration with only the X dimensions set
     * @throws DMLRuntimeException if the device query fails
     */
    public static ExecutionConfig getConfigForSimpleVectorOperations(int numCells) throws DMLRuntimeException {
        int deviceNumber = 0;
        int blockDimX = getMaxBlockDim(deviceNumber);
        // Floating-point ceil keeps the original rounding and avoids int overflow
        // that (numCells + blockDimX - 1) could cause for very large inputs.
        int gridDimX = (int) Math.ceil((double) numCells / (double) blockDimX);
        return new ExecutionConfig(gridDimX, blockDimX);
    }

    /**
     * Builds a two-dimensional configuration for an {@code rlen x clen} matrix.
     * Rows map to X and columns to Y. The X block size is
     * {@code min(maxBlockDim, rlen)}; the Y block size is capped so that
     * {@code blockDimX * blockDimY} does not exceed {@code maxBlockDim}.
     *
     * <p>Always queries device number 0.
     *
     * @param rlen number of rows
     * @param clen number of columns
     * @return a configuration with the X and Y dimensions set
     * @throws DMLRuntimeException if the device query fails
     */
    public static ExecutionConfig getConfigForSimpleMatrixOperations(int rlen, int clen) throws DMLRuntimeException {
        int deviceNumber = 0;
        int maxBlockDim = getMaxBlockDim(deviceNumber);
        int blockDimX = Math.min(maxBlockDim, rlen);
        int gridDimX = (int) Math.ceil((double) rlen / (double) blockDimX);
        // Whatever is left of the per-block budget after X goes to Y, but never more than clen.
        int blockDimY = (int) Math.min(Math.floor((double) maxBlockDim / (double) blockDimX), (double) clen);
        int gridDimY = (int) Math.ceil((double) clen / (double) blockDimY);
        return new ExecutionConfig(gridDimX, gridDimY, blockDimX, blockDimY);
    }

    /**
     * Creates a one-dimensional configuration; shared memory stays at its default of 0.
     *
     * @param gridDimX  grid dimension along X
     * @param blockDimX block dimension along X
     */
    public ExecutionConfig(int gridDimX, int blockDimX) {
        this.gridDimX = gridDimX;
        this.blockDimX = blockDimX;
    }

    /**
     * Creates a two-dimensional configuration; Z dimensions stay at 1.
     *
     * @param gridDimX  grid dimension along X
     * @param gridDimY  grid dimension along Y
     * @param blockDimX block dimension along X
     * @param blockDimY block dimension along Y
     */
    public ExecutionConfig(int gridDimX, int gridDimY, int blockDimX, int blockDimY) {
        this.gridDimX = gridDimX;
        this.gridDimY = gridDimY;
        this.blockDimX = blockDimX;
        this.blockDimY = blockDimY;
    }

    /**
     * Returns the maximum block dimension for the given device, querying the
     * driver on first use and serving later calls from the cache.
     *
     * @param deviceNumber ordinal of the device to query
     * @return the value the driver reports for {@link #MAX_BLOCK_DIM_X_ATTRIBUTE}
     * @throws DMLRuntimeException if a driver call reports a failure
     *                             (as determined by {@code JCudaKernels.checkResult})
     */
    private static int getMaxBlockDim(int deviceNumber) throws DMLRuntimeException {
        Integer cached = maxBlockDimForDevice.get(deviceNumber);
        if (cached != null) {
            return cached;
        }
        CUdevice device = new CUdevice();
        JCudaKernels.checkResult(JCudaDriver.cuDeviceGet(device, deviceNumber));
        int[] maxBlockDimX = {0};
        // Check the status here as well: an unchecked failure would cache 0 and
        // make every later grid-size computation divide by zero.
        JCudaKernels.checkResult(
                JCudaDriver.cuDeviceGetAttribute(maxBlockDimX, MAX_BLOCK_DIM_X_ATTRIBUTE, device));
        maxBlockDimForDevice.put(deviceNumber, maxBlockDimX[0]);
        return maxBlockDimX[0];
    }
}

