/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.instructions.gpu.context;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.Map;
import jcuda.Pointer;
import jcuda.jcublas.JCublas2;
import jcuda.jcublas.cublasHandle;
import jcuda.jcudnn.JCudnn;
import jcuda.jcudnn.cudnnHandle;
import jcuda.jcusolver.JCusolverDn;
import jcuda.jcusolver.JCusolverSp;
import jcuda.jcusolver.cusolverDnHandle;
import jcuda.jcusolver.cusolverSpHandle;
import jcuda.jcusparse.JCusparse;
import jcuda.jcusparse.cusparseHandle;
import jcuda.runtime.JCuda;
import jcuda.runtime.cudaDeviceProp;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.instructions.gpu.context.CSRPointer;
import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool;
import org.apache.sysml.runtime.instructions.gpu.context.GPUObject;
import org.apache.sysml.runtime.instructions.gpu.context.JCudaKernels;
import org.apache.sysml.utils.GPUStatistics;
import org.apache.sysml.utils.LRUCacheMap;

/**
 * Represents the state associated with one CUDA device: the CUDA library handles
 * used to issue work on it and the bookkeeping for device memory allocated
 * through this object.
 *
 * <p>Memory freed "lazily" is not returned to CUDA; it is parked in a free list keyed
 * by block size so that a later {@link #allocate(String, long, int)} of the same size
 * can reuse it. When the device runs low on memory, {@link #evict(String, long)} first
 * releases those parked blocks and then evicts unlocked {@link GPUObject}s.
 *
 * <p>Thread-safety: the library handles and kernels are stored in {@link ThreadLocal}s,
 * so every thread that uses this context must call {@link #initializeThread()} first.
 * The memory bookkeeping collections are plain, unsynchronized collections.
 */
public class GPUContext {
    protected static final Log LOG = LogFactory.getLog(GPUContext.class.getName());

    /** Policy used to order {@link GPUObject}s when choosing eviction candidates. */
    public final EvictionPolicy evictionPolicy = EvictionPolicy.LRU;

    /** Minimum major compute capability accepted by {@link #ensureComputeCapability()}. */
    final int MAJOR_REQUIRED = 3;

    /** Minimum minor compute capability (for the required major version). */
    final int MINOR_REQUIRED = 0;

    /** CUDA device ordinal this context is bound to. */
    private final int deviceNum;

    /** Fraction of the free device memory that {@link #getAvailableMemory()} reports as usable. */
    public double GPU_MEMORY_UTILIZATION_FACTOR = ConfigurationManager.getDMLConfig().getDoubleValue("gpu.memory.util.factor");

    /** Lazily freed blocks available for reuse, keyed by block size in bytes. */
    private final LRUCacheMap<Long, LinkedList<Pointer>> freeCUDASpaceMap = new LRUCacheMap<>();

    /** Size in bytes of every block handed out by {@link #allocate(String, long, int)} and not yet eagerly freed. */
    private final HashMap<Pointer, Long> cudaBlockSizeMap = new HashMap<>();

    /** Objects that have registered device memory with this context. */
    private final ArrayList<GPUObject> allocatedGPUObjects = new ArrayList<>();

    // Per-thread library handles; populated by initializeCudaLibraryHandles().
    private final ThreadLocal<cudnnHandle> cudnnHandle = new ThreadLocal<>();
    private final ThreadLocal<cublasHandle> cublasHandle = new ThreadLocal<>();
    private final ThreadLocal<cusparseHandle> cusparseHandle = new ThreadLocal<>();
    private final ThreadLocal<cusolverDnHandle> cusolverDnHandle = new ThreadLocal<>();
    private final ThreadLocal<cusolverSpHandle> cusolverSpHandle = new ThreadLocal<>();
    private final ThreadLocal<JCudaKernels> kernels = new ThreadLocal<>();

    /**
     * Binds the calling thread to the given device, creates the library handles for
     * the calling thread and logs the device's total and free memory.
     *
     * @param deviceNum CUDA device ordinal
     * @throws DMLRuntimeException if the library handles cannot be initialized
     */
    protected GPUContext(int deviceNum) throws DMLRuntimeException {
        this.deviceNum = deviceNum;
        JCuda.cudaSetDevice(deviceNum);
        // Same value (4) the original passed as a bare literal.
        JCuda.cudaSetDeviceFlags(JCuda.cudaDeviceScheduleBlockingSync);

        long[] free = {0L};
        long[] total = {0L};
        JCuda.cudaMemGetInfo(free, total);

        long start = DMLScript.STATISTICS ? System.nanoTime() : -1L;
        this.initializeCudaLibraryHandles();
        if (DMLScript.STATISTICS) {
            GPUStatistics.cudaLibrariesInitTime = System.nanoTime() - start;
        }
        LOG.info(" GPU memory - Total: " + (double) total[0] * 1.0E-6 + " MB, Available: " + (double) free[0] * 1.0E-6 + " MB on " + this);
    }

    /**
     * Creates, for the calling thread only, every library handle (and the kernel
     * holder) that has not been created on this thread yet.
     *
     * @throws DMLRuntimeException declared for callers; propagated from the kernel holder construction
     */
    private void initializeCudaLibraryHandles() throws DMLRuntimeException {
        if (this.cudnnHandle.get() == null) {
            this.cudnnHandle.set(new cudnnHandle());
            JCudnn.cudnnCreate(this.cudnnHandle.get());
        }
        if (this.cublasHandle.get() == null) {
            this.cublasHandle.set(new cublasHandle());
            JCublas2.cublasCreate(this.cublasHandle.get());
        }
        if (this.cusparseHandle.get() == null) {
            this.cusparseHandle.set(new cusparseHandle());
            JCusparse.cusparseCreate(this.cusparseHandle.get());
        }
        if (this.cusolverDnHandle.get() == null) {
            this.cusolverDnHandle.set(new cusolverDnHandle());
            JCusolverDn.cusolverDnCreate(this.cusolverDnHandle.get());
        }
        if (this.cusolverSpHandle.get() == null) {
            this.cusolverSpHandle.set(new cusolverSpHandle());
            JCusolverSp.cusolverSpCreate(this.cusolverSpHandle.get());
        }
        if (this.kernels.get() == null) {
            this.kernels.set(new JCudaKernels());
        }
    }

    /**
     * Queries CUDA for the device currently selected on the calling thread.
     *
     * @return the device ordinal reported by {@code cudaGetDevice}
     */
    public static int cudaGetDevice() {
        int[] device = new int[1];
        JCuda.cudaGetDevice(device);
        return device[0];
    }

    /**
     * @return the CUDA device ordinal this context was created for
     */
    public int getDeviceNum() {
        return this.deviceNum;
    }

    /**
     * Prepares the calling thread to use this context: selects this context's device
     * and creates the thread's library handles if they do not exist yet.
     *
     * @throws DMLRuntimeException if the library handles cannot be initialized
     */
    public void initializeThread() throws DMLRuntimeException {
        JCuda.cudaSetDevice(this.deviceNum);
        this.initializeCudaLibraryHandles();
    }

    /**
     * Allocates a zeroed device block without attributing the time to an instruction.
     *
     * @param size block size in bytes
     * @return pointer to the zeroed block
     * @throws DMLRuntimeException if space cannot be made available
     */
    public Pointer allocate(long size) throws DMLRuntimeException {
        return this.allocate(null, size, 1);
    }

    /**
     * Allocates a zeroed device block, counting it as a single allocation in the statistics.
     *
     * @param instructionName instruction to attribute timings to, or {@code null}
     * @param size            block size in bytes
     * @return pointer to the zeroed block
     * @throws DMLRuntimeException if space cannot be made available
     */
    public Pointer allocate(String instructionName, long size) throws DMLRuntimeException {
        return this.allocate(instructionName, size, 1);
    }

    /**
     * Returns a device block of exactly {@code size} bytes, set to zero.
     *
     * <p>A lazily freed block of the same size is reused when one is parked in the free
     * list; otherwise free space is ensured (possibly evicting) and a new block is
     * obtained with {@code cudaMalloc}. The block is recorded in the size map either way.
     *
     * @param instructionName instruction to attribute timings to, or {@code null}
     * @param size            block size in bytes; must not be negative
     * @param statsCount      amount added to the allocation counter when a new block is allocated
     * @return pointer to the zeroed block
     * @throws DMLRuntimeException if {@code size} is negative or space cannot be made available
     */
    public Pointer allocate(String instructionName, long size, int statsCount) throws DMLRuntimeException {
        if (size < 0L) {
            throw new DMLRuntimeException("Cannot allocate a block of negative size (" + size + ") on " + this);
        }
        final boolean traceEnabled = LOG.isTraceEnabled();
        final boolean globalStats = DMLScript.STATISTICS;
        final boolean instructionStats = instructionName != null && GPUStatistics.DISPLAY_STATISTICS;
        // Timestamps are needed whenever either statistics consumer is active; taking them
        // only for the global statistics would feed garbage into the per-instruction timers.
        final boolean timed = globalStats || instructionStats;

        Pointer block;
        if (this.freeCUDASpaceMap.containsKey(size)) {
            if (traceEnabled) {
                LOG.trace("GPU : in allocate from instruction " + instructionName + ", found free block of size " + (double) size / 1024.0 + " Kbytes from previously allocated block on " + this);
            }
            long t0 = instructionStats ? System.nanoTime() : 0L;
            LinkedList<Pointer> freeList = this.freeCUDASpaceMap.get(size);
            block = freeList.pop();
            if (freeList.isEmpty()) {
                // Never keep empty lists around: containsKey(size) must imply a reusable block.
                this.freeCUDASpaceMap.remove(size);
            }
            if (instructionStats) {
                GPUStatistics.maintainCPMiscTimes(instructionName, "r", System.nanoTime() - t0);
            }
        } else {
            if (traceEnabled) {
                LOG.trace("GPU : in allocate from instruction " + instructionName + ", allocating new block of size " + (double) size / 1024.0 + " Kbytes on " + this);
            }
            long t0 = timed ? System.nanoTime() : 0L;
            this.ensureFreeSpace(instructionName, size);
            block = new Pointer();
            JCuda.cudaMalloc(block, size);
            if (globalStats) {
                GPUStatistics.cudaAllocTime.add(System.nanoTime() - t0);
                GPUStatistics.cudaAllocCount.add(statsCount);
            }
            if (instructionStats) {
                GPUStatistics.maintainCPMiscTimes(instructionName, "a", System.nanoTime() - t0);
            }
        }

        if (traceEnabled) {
            LOG.trace("GPU : in allocate from instruction " + instructionName + ", setting block of size " + (double) size / 1024.0 + " Kbytes to zero on " + this);
        }
        long memsetStart = timed ? System.nanoTime() : 0L;
        JCuda.cudaMemset(block, 0, size);
        long memsetTime = timed ? System.nanoTime() - memsetStart : 0L;
        if (instructionStats) {
            GPUStatistics.maintainCPMiscTimes(instructionName, "az", memsetTime);
        }
        if (globalStats) {
            GPUStatistics.cudaMemSet0Time.add(memsetTime);
            GPUStatistics.cudaMemSet0Count.add(1L);
        }

        this.cudaBlockSizeMap.put(block, size);
        return block;
    }

    /**
     * Lazily frees a block without attributing the time to an instruction.
     *
     * @param toFree block previously returned by {@code allocate}
     */
    public void cudaFreeHelper(Pointer toFree) {
        this.cudaFreeHelper(null, toFree, false);
    }

    /**
     * Frees a block without attributing the time to an instruction.
     *
     * @param toFree block previously returned by {@code allocate}
     * @param eager  {@code true} to release the memory to CUDA now, {@code false} to park it for reuse
     */
    public void cudaFreeHelper(Pointer toFree, boolean eager) {
        this.cudaFreeHelper(null, toFree, eager);
    }

    /**
     * Lazily frees a block.
     *
     * @param instructionName instruction to attribute timings to, or {@code null}
     * @param toFree          block previously returned by {@code allocate}
     */
    public void cudaFreeHelper(String instructionName, Pointer toFree) {
        this.cudaFreeHelper(instructionName, toFree, false);
    }

    /**
     * Frees a block that was handed out by {@code allocate}.
     *
     * <p>With {@code eager == true} the memory is released with {@code cudaFree} and the
     * block is forgotten. Otherwise the block stays allocated on the device and is added
     * to the free list for its size so that a later allocation can reuse it.
     *
     * @param instructionName instruction to attribute timings to, or {@code null}
     * @param toFree          block to free; {@code null} and null pointers are ignored
     * @param eager           whether to release the memory to CUDA immediately
     * @throws RuntimeException if the block is not tracked by this context, or if it is
     *                          already present in the free list (double free)
     */
    public void cudaFreeHelper(String instructionName, Pointer toFree, boolean eager) {
        // Freeing a null pointer is a no-op. The comparison must use equals(): comparing
        // references against a freshly constructed Pointer can never be true.
        if (toFree == null || toFree.equals(new Pointer())) {
            return;
        }
        // Explicit check instead of an assert, so that an untracked block is reported
        // clearly even when assertions are disabled (rather than as an unboxing NPE).
        Long trackedSize = this.cudaBlockSizeMap.get(toFree);
        if (trackedSize == null) {
            throw new RuntimeException("ERROR : Internal state corrupted, cache block size map is not aware of a block it trying to free up");
        }

        if (eager) {
            if (LOG.isTraceEnabled()) {
                LOG.trace("GPU : eagerly freeing cuda memory [ " + toFree + " ] for instruction " + instructionName + " on " + this);
            }
            final boolean globalStats = DMLScript.STATISTICS;
            final boolean instructionStats = instructionName != null && GPUStatistics.DISPLAY_STATISTICS;
            long t0 = (globalStats || instructionStats) ? System.nanoTime() : 0L;
            JCuda.cudaFree(toFree);
            this.cudaBlockSizeMap.remove(toFree);
            if (globalStats) {
                GPUStatistics.cudaDeAllocTime.add(System.nanoTime() - t0);
                GPUStatistics.cudaDeAllocCount.add(1L);
            }
            if (instructionStats) {
                GPUStatistics.maintainCPMiscTimes(instructionName, "f", System.nanoTime() - t0);
            }
            return;
        }

        if (LOG.isTraceEnabled()) {
            LOG.trace("GPU : lazily freeing cuda memory for instruction " + instructionName + " on " + this);
        }
        LinkedList<Pointer> freeList = this.freeCUDASpaceMap.get(trackedSize);
        if (freeList == null) {
            freeList = new LinkedList<>();
            this.freeCUDASpaceMap.put(trackedSize, freeList);
        }
        if (freeList.contains(toFree)) {
            throw new RuntimeException("GPU : Internal state corrupted, double free");
        }
        freeList.add(toFree);
    }

    /**
     * Makes room for {@code size} bytes without attributing the time to an instruction.
     *
     * @param size number of bytes about to be allocated
     * @throws DMLRuntimeException if eviction fails
     */
    void ensureFreeSpace(long size) throws DMLRuntimeException {
        this.ensureFreeSpace(null, size);
    }

    /**
     * Triggers eviction when the requested size is not smaller than the memory
     * currently reported by {@link #getAvailableMemory()}.
     *
     * @param instructionName instruction to attribute timings to, or {@code null}
     * @param size            number of bytes about to be allocated
     * @throws DMLRuntimeException if eviction fails
     */
    void ensureFreeSpace(String instructionName, long size) throws DMLRuntimeException {
        if (size >= this.getAvailableMemory()) {
            this.evict(instructionName, size);
        }
    }

    /**
     * Evicts memory without attributing the time to an instruction.
     *
     * @param GPUSize number of bytes that must become available
     * @throws DMLRuntimeException if not enough memory can be reclaimed
     */
    protected void evict(long GPUSize) throws DMLRuntimeException {
        this.evict(null, GPUSize);
    }

    /**
     * Tries to make {@code neededSize} bytes available on the device.
     *
     * <p>Lazily freed blocks are released first. If that is not enough, the registered
     * {@link GPUObject}s are sorted so that the preferred eviction candidates are at the
     * end of the list, and objects are evicted from the end (copying dirty data back to
     * the host first) until enough memory is available or a locked object is reached.
     *
     * @param instructionName instruction to attribute timings to, or {@code null}
     * @param neededSize      number of bytes that must become available
     * @throws DMLRuntimeException if no object is registered, or the next candidate is
     *                             locked, while memory is still insufficient
     */
    protected void evict(String instructionName, final long neededSize) throws DMLRuntimeException {
        if (LOG.isTraceEnabled()) {
            LOG.trace("GPU : evict called from " + instructionName + " for size " + neededSize + " on " + this);
        }
        GPUStatistics.cudaEvictionCount.add(1L);

        this.releaseCachedFreeBlocks(instructionName, neededSize);
        if (neededSize <= this.getAvailableMemory()) {
            return;
        }
        if (this.allocatedGPUObjects.isEmpty()) {
            throw new DMLRuntimeException("There is not enough memory on device for this matrix, request (" + neededSize + ")");
        }

        Comparator<GPUObject> evictionOrder = (p1, p2) -> this.compareForEviction(p1, p2, neededSize);
        Collections.sort(this.allocatedGPUObjects, evictionOrder);

        while (neededSize > this.getAvailableMemory() && !this.allocatedGPUObjects.isEmpty()) {
            GPUObject toBeRemoved = this.allocatedGPUObjects.get(this.allocatedGPUObjects.size() - 1);
            if (toBeRemoved.locks.get() > 0L) {
                // Locked objects sort before unlocked ones, so nothing evictable is left.
                throw new DMLRuntimeException("There is not enough memory on device for this matrix, request (" + neededSize + ")");
            }
            if (toBeRemoved.dirty) {
                toBeRemoved.copyFromDeviceToHost();
            }
            toBeRemoved.clearData(true);
            // Guarantees loop progress even if clearData did not de-register the object.
            this.removeRecordedUsage(toBeRemoved);
        }
    }

    /**
     * Eagerly frees lazily freed blocks, least recently used size first, until
     * {@code neededSize} bytes are available or no parked block is left.
     */
    private void releaseCachedFreeBlocks(String instructionName, long neededSize) throws DMLRuntimeException {
        while (this.freeCUDASpaceMap.size() > 0 && neededSize > this.getAvailableMemory()) {
            Map.Entry<Long, LinkedList<Pointer>> lruEntry = this.freeCUDASpaceMap.removeAndGetLRUEntry();
            LinkedList<Pointer> freeList = lruEntry.getValue();
            while (!freeList.isEmpty()) {
                this.cudaFreeHelper(instructionName, freeList.pop(), true);
                if (neededSize <= this.getAvailableMemory()) {
                    break;
                }
            }
            // The whole list was taken out of the map above; blocks that were not needed
            // must be put back, otherwise they can neither be reused nor freed later.
            if (!freeList.isEmpty()) {
                this.freeCUDASpaceMap.put(lruEntry.getKey(), freeList);
            }
        }
    }

    /**
     * Orders objects so that eviction candidates end up at the END of a sorted list:
     * locked objects come first; unlocked objects are ordered by the eviction policy.
     */
    private int compareForEviction(GPUObject p1, GPUObject p2, long neededSize) {
        long p1Locks = p1.locks.get();
        long p2Locks = p2.locks.get();
        if (p1Locks > 0L && p2Locks > 0L) {
            return 0;
        }
        if (p1Locks > 0L || p2Locks > 0L) {
            return Long.compare(p2Locks, p1Locks);
        }

        if (this.evictionPolicy == EvictionPolicy.MIN_EVICT) {
            long p1Surplus;
            long p2Surplus;
            try {
                p1Surplus = p1.getSizeOnDevice() - neededSize;
                p2Surplus = p2.getSizeOnDevice() - neededSize;
            } catch (DMLRuntimeException e) {
                throw new RuntimeException(e);
            }
            if (p1Surplus >= 0L && p2Surplus >= 0L) {
                return Long.compare(p2Surplus, p1Surplus);
            }
            return Long.compare(p1Surplus, p2Surplus);
        }
        if (this.evictionPolicy == EvictionPolicy.LRU || this.evictionPolicy == EvictionPolicy.LFU) {
            // Descending timestamp: the smallest timestamp ends up last.
            return Long.compare(p2.timestamp.get(), p1.timestamp.get());
        }
        throw new RuntimeException("Unsupported eviction policy:" + this.evictionPolicy.name());
    }

    /**
     * @param o object to look up
     * @return whether the object is currently registered with this context
     */
    public boolean isBlockRecorded(GPUObject o) {
        return this.allocatedGPUObjects.contains(o);
    }

    /**
     * Registers an object as holding device memory, making it an eviction candidate.
     *
     * @param o object to register
     */
    public void recordBlockUsage(GPUObject o) {
        this.allocatedGPUObjects.add(o);
    }

    /**
     * Removes every registration that is {@code equals} to the given object.
     *
     * @param o object to de-register
     */
    public void removeRecordedUsage(GPUObject o) {
        this.allocatedGPUObjects.removeIf(a -> a.equals(o));
    }

    /**
     * @return the free device memory reported by {@code cudaMemGetInfo}, scaled by
     *         {@link #GPU_MEMORY_UTILIZATION_FACTOR}, in bytes
     */
    public long getAvailableMemory() {
        long[] free = {0L};
        long[] total = {0L};
        JCuda.cudaMemGetInfo(free, total);
        return (long) ((double) free[0] * this.GPU_MEMORY_UTILIZATION_FACTOR);
    }

    /**
     * Verifies that every CUDA device on the system has at least compute capability
     * {@code MAJOR_REQUIRED.MINOR_REQUIRED}.
     *
     * @throws DMLRuntimeException if no device is reported or a device is below the requirement
     */
    public void ensureComputeCapability() throws DMLRuntimeException {
        int[] devices = {-1};
        JCuda.cudaGetDeviceCount(devices);
        // -1 means the count was never written; 0 means there is nothing to check.
        if (devices[0] <= 0) {
            throw new DMLRuntimeException("Call to cudaGetDeviceCount returned 0 devices");
        }
        boolean isComputeCapable = true;
        for (int i = 0; i < devices[0]; ++i) {
            cudaDeviceProp properties = GPUContextPool.getGPUProperties(i);
            int major = properties.major;
            int minor = properties.minor;
            if (major < this.MAJOR_REQUIRED || (major == this.MAJOR_REQUIRED && minor < this.MINOR_REQUIRED)) {
                isComputeCapable = false;
            }
        }
        if (!isComputeCapable) {
            throw new DMLRuntimeException("One of the CUDA cards on the system has compute capability lower than " + this.MAJOR_REQUIRED + "." + this.MINOR_REQUIRED);
        }
    }

    /**
     * @param mo matrix object the new GPU object belongs to
     * @return a new {@link GPUObject} bound to this context
     */
    public GPUObject createGPUObject(MatrixObject mo) {
        return new GPUObject(this, mo);
    }

    /**
     * @return the device properties of this context's device, as provided by {@link GPUContextPool}
     * @throws DMLRuntimeException if the properties cannot be obtained
     */
    public cudaDeviceProp getGPUProperties() throws DMLRuntimeException {
        return GPUContextPool.getGPUProperties(this.deviceNum);
    }

    /**
     * @return the device's {@code maxThreadsPerBlock} property
     * @throws DMLRuntimeException if the properties cannot be obtained
     */
    public int getMaxThreadsPerBlock() throws DMLRuntimeException {
        return this.getGPUProperties().maxThreadsPerBlock;
    }

    /**
     * @return the first component of the device's {@code maxGridSize} property
     * @throws DMLRuntimeException if the properties cannot be obtained
     */
    public int getMaxBlocks() throws DMLRuntimeException {
        return this.getGPUProperties().maxGridSize[0];
    }

    /**
     * @return the device's {@code sharedMemPerBlock} property
     * @throws DMLRuntimeException if the properties cannot be obtained
     */
    public long getMaxSharedMemory() throws DMLRuntimeException {
        return this.getGPUProperties().sharedMemPerBlock;
    }

    /**
     * @return the device's {@code warpSize} property
     * @throws DMLRuntimeException if the properties cannot be obtained
     */
    public int getWarpSize() throws DMLRuntimeException {
        return this.getGPUProperties().warpSize;
    }

    /**
     * @return the calling thread's cuDNN handle, or {@code null} if the thread is not initialized
     */
    public cudnnHandle getCudnnHandle() {
        return this.cudnnHandle.get();
    }

    /**
     * @return the calling thread's cuBLAS handle, or {@code null} if the thread is not initialized
     */
    public cublasHandle getCublasHandle() {
        return this.cublasHandle.get();
    }

    /**
     * @return the calling thread's cuSPARSE handle, or {@code null} if the thread is not initialized
     */
    public cusparseHandle getCusparseHandle() {
        return this.cusparseHandle.get();
    }

    /**
     * @return the calling thread's dense cuSOLVER handle, or {@code null} if the thread is not initialized
     */
    public cusolverDnHandle getCusolverDnHandle() {
        return this.cusolverDnHandle.get();
    }

    /**
     * @return the calling thread's sparse cuSOLVER handle, or {@code null} if the thread is not initialized
     */
    public cusolverSpHandle getCusolverSpHandle() {
        return this.cusolverSpHandle.get();
    }

    /**
     * @return the calling thread's kernel holder, or {@code null} if the thread is not initialized
     */
    public JCudaKernels getKernels() {
        return this.kernels.get();
    }

    /**
     * Clears all device memory tracked by this context and destroys the calling
     * thread's library handles.
     *
     * <p>Handles are thread-local, so only the handles of the calling thread can be
     * destroyed here; handles that were never created on this thread are skipped.
     * Destroyed handles are removed so that they cannot be handed out afterwards.
     *
     * @throws DMLRuntimeException if clearing the memory fails
     */
    public void destroy() throws DMLRuntimeException {
        LOG.trace("GPU : this context was destroyed, this = " + this.toString());
        this.clearMemory();
        if (this.cudnnHandle.get() != null) {
            JCudnn.cudnnDestroy(this.cudnnHandle.get());
            this.cudnnHandle.remove();
        }
        if (this.cublasHandle.get() != null) {
            JCublas2.cublasDestroy(this.cublasHandle.get());
            this.cublasHandle.remove();
        }
        if (this.cusparseHandle.get() != null) {
            JCusparse.cusparseDestroy(this.cusparseHandle.get());
            this.cusparseHandle.remove();
        }
        if (this.cusolverDnHandle.get() != null) {
            JCusolverDn.cusolverDnDestroy(this.cusolverDnHandle.get());
            this.cusolverDnHandle.remove();
        }
        if (this.cusolverSpHandle.get() != null) {
            JCusolverSp.cusolverSpDestroy(this.cusolverSpHandle.get());
            this.cusolverSpHandle.remove();
        }
    }

    /**
     * Releases the lazily freed blocks and then clears the device data of every
     * registered object. Dirty objects are read back to the host first.
     *
     * @throws DMLRuntimeException if copying back or clearing an object fails
     */
    public void clearMemory() throws DMLRuntimeException {
        this.clearTemporaryMemory();
        while (!this.allocatedGPUObjects.isEmpty()) {
            GPUObject o = this.allocatedGPUObjects.get(0);
            if (o.isDirty()) {
                LOG.warn("Attempted to free GPU Memory when a block[" + o + "] is still on GPU memory, copying it back to host.");
                o.acquireHostRead();
            }
            o.clearData(true);
            // Guarantees loop progress even if clearData did not de-register the object.
            this.removeRecordedUsage(o);
        }
        this.allocatedGPUObjects.clear();
    }

    /**
     * Eagerly frees every lazily freed block and resets the block size map so that it
     * only contains the blocks referenced by dirty registered objects.
     *
     * <p>NOTE(review): block sizes of registered objects that are NOT dirty are dropped
     * here as well; a later {@code cudaFreeHelper} on such a block would report corrupted
     * state. This mirrors the original behavior - confirm it is intended.
     *
     * @throws RuntimeException if a dirty object has no device pointer
     */
    public void clearTemporaryMemory() {
        // Sizes of blocks that must survive the reset below.
        HashMap<Pointer, Long> retainedBlockSizes = new HashMap<>();
        for (GPUObject o : this.allocatedGPUObjects) {
            if (!o.isDirty()) {
                continue;
            }
            if (o.isSparse()) {
                CSRPointer csr = o.getSparseMatrixCudaPointer();
                if (csr == null) {
                    throw new RuntimeException("CSRPointer is null in clearTemporaryMemory");
                }
                this.retainBlockSize(retainedBlockSizes, csr.rowPtr);
                this.retainBlockSize(retainedBlockSizes, csr.colInd);
                this.retainBlockSize(retainedBlockSizes, csr.val);
            } else {
                Pointer dense = o.getJcudaDenseMatrixPtr();
                if (dense == null) {
                    throw new RuntimeException("Pointer is null in clearTemporaryMemory");
                }
                this.retainBlockSize(retainedBlockSizes, dense);
            }
        }

        // Eager frees only touch cudaBlockSizeMap, so iterating the free lists is safe.
        for (LinkedList<Pointer> freeList : this.freeCUDASpaceMap.values()) {
            for (Pointer p : freeList) {
                this.cudaFreeHelper(p, true);
            }
        }
        this.cudaBlockSizeMap.clear();
        this.freeCUDASpaceMap.clear();
        this.cudaBlockSizeMap.putAll(retainedBlockSizes);
    }

    /** Copies the tracked size of {@code block} into {@code target} if the block is non-null and tracked. */
    private void retainBlockSize(Map<Pointer, Long> target, Pointer block) {
        if (block != null && this.cudaBlockSizeMap.containsKey(block)) {
            target.put(block, this.cudaBlockSizeMap.get(block));
        }
    }

    @Override
    public String toString() {
        return "GPUContext{deviceNum=" + this.deviceNum + '}';
    }

    /** Strategies for choosing which {@link GPUObject} to evict first. */
    public enum EvictionPolicy {
        LRU,
        LFU,
        MIN_EVICT
    }
}

