/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.lib;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.commons.lang.NotImplementedException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
import org.apache.sysds.runtime.compress.lib.CLALibScalar;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.functionobjects.Divide;
import org.apache.sysds.runtime.functionobjects.Minus;
import org.apache.sysds.runtime.functionobjects.Minus1Multiply;
import org.apache.sysds.runtime.functionobjects.MinusMultiply;
import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.functionobjects.PlusMultiply;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
import org.apache.sysds.runtime.matrix.data.LibMatrixBincell;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.LeftScalarOperator;
import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
import org.apache.sysds.runtime.util.CommonThreadPool;

public class CLALibBinaryCellOp {
    // Class-wide commons-logging logger, registered under this class's fully qualified name.
    private static final Log LOG = LogFactory.getLog((String)CLALibBinaryCellOp.class.getName());

    /**
     * Entry point for a cell-wise binary operation where the compressed block is the left operand.
     *
     * <p>A 1x1 {@code that} is routed to the scalar library as a right-scalar operation, an empty
     * {@code that} is routed to {@link #binaryOperationsEmpty}, and everything else is dispatched on the
     * binary access type computed for {@code (m1, that)}.
     *
     * @param op     the binary operator to apply
     * @param m1     the compressed left operand
     * @param that   the right operand
     * @param result an optional output block handed on to the selected implementation
     * @return the block produced by the selected implementation
     * @throws DMLCompressionException wrapping any exception raised while processing
     */
    public static MatrixBlock binaryOperations(BinaryOperator op, CompressedMatrixBlock m1, MatrixBlock that, MatrixBlock result) {
        try {
            final boolean rhsIsScalar = that.getNumRows() == 1 && that.getNumColumns() == 1;
            if (rhsIsScalar) {
                return CLALibScalar.scalarOperations(
                    new RightScalarOperator(op.fn, that.getValue(0, 0), op.getNumThreads()), m1, result);
            }
            if (that.isEmpty()) {
                return binaryOperationsEmpty(op, m1, that, result);
            }
            final MatrixBlock rhs = CompressedMatrixBlock.getUncompressed(that, "Decompressing right side in BinaryOps");
            LibMatrixBincell.isValidDimensionsBinary(m1, rhs);
            final LibMatrixBincell.BinaryAccessType accessType = LibMatrixBincell.getBinaryAccessType(m1, rhs);
            return selectProcessingBasedOnAccessType(op, m1, rhs, result, accessType, false);
        }
        catch (Exception ex) {
            throw new DMLCompressionException("Failed to perform compressed binary operation: " + op, ex);
        }
    }

    /**
     * Entry point for a cell-wise binary operation where the compressed block is the right operand
     * ({@code that op m1}).
     *
     * <p>A 1x1 {@code that} is routed to the scalar library as a left-scalar operation. An empty
     * {@code that} is not supported. Otherwise the call is dispatched on the binary access type
     * computed for {@code (that, m1)}.
     *
     * @param op     the binary operator to apply
     * @param m1     the compressed right operand
     * @param that   the left operand
     * @param result an optional output block handed on to the selected implementation
     * @return the block produced by the selected implementation
     * @throws NotImplementedException if {@code that} is empty
     */
    public static MatrixBlock binaryOperationsLeft(BinaryOperator op, CompressedMatrixBlock m1, MatrixBlock that, MatrixBlock result) {
        final boolean lhsIsScalar = that.getNumRows() == 1 && that.getNumColumns() == 1;
        if (lhsIsScalar) {
            return CLALibScalar.scalarOperations(
                new LeftScalarOperator(op.fn, that.getValue(0, 0), op.getNumThreads()), m1, result);
        }
        if (that.isEmpty()) {
            throw new NotImplementedException("Not handling left empty yet");
        }
        final MatrixBlock lhs = CompressedMatrixBlock.getUncompressed(that, "Decompressing left side in BinaryOps");
        LibMatrixBincell.isValidDimensionsBinary(lhs, m1);
        final LibMatrixBincell.BinaryAccessType accessType = LibMatrixBincell.getBinaryAccessType(lhs, m1);
        return selectProcessingBasedOnAccessType(op, m1, lhs, result, accessType, true);
    }

    /**
     * Handles {@code m1 op that} when {@code that} is an empty block.
     *
     * <p>Only inputs with matching column counts are handled. The outcome depends on the function:
     * {@code Multiply} yields a constant 0 block, {@code Minus1Multiply} a constant 1 block, and
     * {@code Minus}, {@code Plus}, {@code MinusMultiply} and {@code PlusMultiply} yield a copy of
     * {@code m1}.
     *
     * @param op     the binary operator; only its function object is inspected
     * @param m1     the compressed left operand
     * @param that   the empty right operand (only its dimensions are read)
     * @param result unused by this implementation
     * @return the resulting block
     * @throws NotImplementedException for any other function, or when the column counts differ
     */
    private static MatrixBlock binaryOperationsEmpty(BinaryOperator op, CompressedMatrixBlock m1, MatrixBlock that, MatrixBlock result) {
        final ValueFunction fn = op.fn;

        // The original two-part condition collapses to a plain column-count comparison.
        if (m1.getNumColumns() != that.getNumColumns()) {
            final long leftRows = m1.getNumRows();
            final long rightRows = that.getNumRows();
            final long leftCols = m1.getNumColumns();
            final long rightCols = that.getNumColumns();
            throw new NotImplementedException("Not Implemented sizes: left(" + leftRows + ", " + leftCols + ") right(" + rightRows + ", " + rightCols + ")");
        }

        if (fn instanceof Multiply) {
            return CompressedMatrixBlockFactory.createConstant(m1.getNumRows(), m1.getNumColumns(), 0.0);
        }
        if (fn instanceof Minus1Multiply) {
            return CompressedMatrixBlockFactory.createConstant(m1.getNumRows(), m1.getNumColumns(), 1.0);
        }

        final boolean leavesLeftUnchanged = fn instanceof Minus || fn instanceof Plus
            || fn instanceof MinusMultiply || fn instanceof PlusMultiply;
        if (!leavesLeftUnchanged) {
            throw new NotImplementedException("Function Type: " + fn);
        }
        final CompressedMatrixBlock copy = new CompressedMatrixBlock();
        copy.copy(m1);
        return copy;
    }

    /**
     * Chooses the execution strategy for a binary operation based on the access type.
     *
     * <ul>
     * <li>{@code MATRIX_COL_VECTOR}: uses the cached decompressed version of {@code m1} when one is
     * available, otherwise the blocked column-vector kernel {@link #binaryMVCol}.</li>
     * <li>{@code MATRIX_MATRIX}: obtains an uncompressed {@code m1} and delegates to the uncompressed
     * binary operation.</li>
     * <li>Any other access type: uses the compressed kernels when the function is supported, otherwise
     * falls back to the uncompressed binary operation.</li>
     * </ul>
     *
     * <p>In every uncompressed path the operand order follows {@code left}: {@code that op m1} when
     * {@code left} is true, {@code m1 op that} otherwise.
     *
     * @param op     the binary operator to apply
     * @param m1     the compressed operand
     * @param that   the other (uncompressed) operand
     * @param result an optional output block
     * @param atype  the access type computed by the caller
     * @param left   true if {@code that} is the left operand of the operation
     * @return the block holding the result
     */
    private static MatrixBlock selectProcessingBasedOnAccessType(BinaryOperator op, CompressedMatrixBlock m1, MatrixBlock that, MatrixBlock result, LibMatrixBincell.BinaryAccessType atype, boolean left) {
        if (atype == LibMatrixBincell.BinaryAccessType.MATRIX_COL_VECTOR) {
            MatrixBlock d_compressed = m1.getCachedDecompressed();
            if (d_compressed != null) {
                if (left) {
                    return that.binaryOperations(op, d_compressed);
                }
                return d_compressed.binaryOperations(op, that);
            }
            return CLALibBinaryCellOp.binaryMVCol(m1, that, op, left);
        }
        if (atype == LibMatrixBincell.BinaryAccessType.MATRIX_MATRIX) {
            MatrixBlock d_compressed = m1.getUncompressed("MatrixMatrix " + op);
            if (left) {
                return that.binaryOperations(op, d_compressed);
            }
            return d_compressed.binaryOperations(op, that);
        }
        if (CLALibBinaryCellOp.isSupportedBinaryCellOp(op.fn)) {
            return CLALibBinaryCellOp.bincellOp(m1, that, CLALibBinaryCellOp.setupCompressedReturnMatrixBlock(m1, result), op, left);
        }
        // Unsupported function: fall back to the uncompressed implementation. The operand order must
        // follow `left` here as well, otherwise non-commutative functions would be evaluated as
        // `m1 op that` even though the caller asked for `that op m1`.
        MatrixBlock uncompressed = CompressedMatrixBlock.getUncompressed(m1, "BinaryOp: " + op.fn);
        if (left) {
            return that.binaryOperations(op, uncompressed, result);
        }
        return uncompressed.binaryOperations(op, that, result);
    }

    /**
     * Tells whether the given function has a compressed cell-wise implementation in this class.
     *
     * @param fn the function object of the binary operator
     * @return true for Multiply, Divide, Plus, Minus, MinusMultiply and PlusMultiply; false otherwise
     */
    private static boolean isSupportedBinaryCellOp(ValueFunction fn) {
        if (fn instanceof Multiply || fn instanceof Divide) {
            return true;
        }
        if (fn instanceof Plus || fn instanceof Minus) {
            return true;
        }
        return fn instanceof MinusMultiply || fn instanceof PlusMultiply;
    }

    /**
     * Provides the compressed output block for an operation on {@code m1}.
     *
     * <p>If {@code result} already is a {@link CompressedMatrixBlock} it is reset to the dimensions of
     * {@code m1} and reused; otherwise (including {@code null}) a new block of those dimensions is
     * created.
     *
     * @param m1     the compressed input whose dimensions the output takes
     * @param result a candidate output block, may be null
     * @return a compressed block with the dimensions of {@code m1}
     */
    private static CompressedMatrixBlock setupCompressedReturnMatrixBlock(CompressedMatrixBlock m1, MatrixValue result) {
        // instanceof is false for null, so no separate null check is needed.
        if (result instanceof CompressedMatrixBlock) {
            final CompressedMatrixBlock reused = (CompressedMatrixBlock) result;
            reused.reset(m1.getNumRows(), m1.getNumColumns());
            return reused;
        }
        return new CompressedMatrixBlock(m1.getNumRows(), m1.getNumColumns());
    }

    /**
     * Runs the compressed cell-wise operation, choosing between the overlapping and the
     * non-overlapping kernel.
     *
     * @param m1   the compressed operand
     * @param m2   the other operand
     * @param ret  the compressed output block handed to the kernels
     * @param op   the binary operator to apply
     * @param left true if {@code m2} is the left operand of the operation
     * @return {@code ret}
     */
    private static MatrixBlock bincellOp(CompressedMatrixBlock m1, MatrixBlock m2, CompressedMatrixBlock ret, BinaryOperator op, boolean left) {
        final boolean useOverlappingKernel = isValidForOverlappingBinaryCellOperations(m1, op);
        if (!useOverlappingKernel) {
            nonOverlappingBinaryCellOp(m1, m2, ret, op, left);
            return ret;
        }
        overlappingBinaryCellOp(m1, m2, ret, op, left);
        return ret;
    }

    /**
     * Applies the operation on a compressed block through the non-overlapping kernels.
     *
     * <ul>
     * <li>{@code MATRIX_ROW_VECTOR}: handled per column group by {@link #binaryMVRow}.</li>
     * <li>{@code OUTER_VECTOR_VECTOR}: only a 1x1 {@code m2} is supported, via a scalar operation
     * whose side follows {@code left}.</li>
     * <li>Anything else: {@code m1} is decompressed and the uncompressed operation is used.</li>
     * </ul>
     *
     * @param m1   the compressed operand
     * @param m2   the other operand
     * @param ret  the compressed output block
     * @param op   the binary operator to apply
     * @param left true if {@code m2} is the left operand of the operation
     * @throws NotImplementedException for an outer vector-vector operation whose {@code m2} is not 1x1
     */
    private static void nonOverlappingBinaryCellOp(CompressedMatrixBlock m1, MatrixBlock m2, CompressedMatrixBlock ret, BinaryOperator op, boolean left) {
        LibMatrixBincell.BinaryAccessType atype = LibMatrixBincell.getBinaryAccessType(m1, m2);
        switch (atype) {
            case MATRIX_ROW_VECTOR: {
                CLALibBinaryCellOp.binaryMVRow(m1, m2, ret, op, left);
                return;
            }
            case OUTER_VECTOR_VECTOR: {
                if (m2.getNumRows() != 1 || m2.getNumColumns() != 1) {
                    // Previously this returned silently and left `ret` without any column groups;
                    // fail loudly instead of handing back an unpopulated result.
                    throw new NotImplementedException("Outer vector-vector " + op + " not implemented for compressed input with other side of size (" + m2.getNumRows() + ", " + m2.getNumColumns() + ")");
                }
                double scalar = m2.quickGetValue(0, 0);
                // The scalar must sit on the same side of the function as m2 does.
                if (left) {
                    CLALibScalar.scalarOperations(new LeftScalarOperator(op.fn, scalar, op.getNumThreads()), m1, ret);
                } else {
                    CLALibScalar.scalarOperations(new RightScalarOperator(op.fn, scalar, op.getNumThreads()), m1, ret);
                }
                return;
            }
        }
        LOG.warn((Object)("Inefficient Decompression for " + op + "  " + (Object)((Object)atype)));
        // NOTE(review): this fallback always evaluates `m1 op m2` and ignores `left` — confirm
        // whether callers can reach it with left == true.
        m1.decompress().binaryOperations(op, m2, ret);
    }

    /**
     * Decides whether the overlapping kernel is used for this input and operator.
     *
     * @param m1 the compressed operand
     * @param op the binary operator
     * @return true if {@code m1} reports itself as overlapping and the function is neither Multiply
     *         nor Divide
     */
    private static boolean isValidForOverlappingBinaryCellOperations(CompressedMatrixBlock m1, BinaryOperator op) {
        if (!m1.isOverlapping()) {
            return false;
        }
        final ValueFunction fn = op.fn;
        return !(fn instanceof Multiply || fn instanceof Divide);
    }

    /**
     * Applies a Plus or Minus operation to an overlapping compressed block by delegating to
     * {@link #binaryMVPlusStack}.
     *
     * @param m1   the compressed operand
     * @param m2   the other operand
     * @param ret  the compressed output block
     * @param op   the binary operator; must carry a Plus or Minus function
     * @param left true if {@code m2} is the left operand of the operation
     * @throws NotImplementedException if the function is neither Plus nor Minus
     */
    private static void overlappingBinaryCellOp(CompressedMatrixBlock m1, MatrixBlock m2, CompressedMatrixBlock ret, BinaryOperator op, boolean left) {
        final boolean additive = op.fn instanceof Plus || op.fn instanceof Minus;
        if (additive) {
            binaryMVPlusStack(m1, m2, ret, op, left);
            return;
        }
        throw new NotImplementedException(op + " not implemented for Overlapping CLA");
    }

    /**
     * Applies a matrix / row-vector operation to every column group of {@code m1} and collects the
     * resulting column groups in {@code ret}.
     *
     * <p>One task per column group is submitted to a thread pool sized by {@code op.getNumThreads()}.
     * The non-zero count of the result is set to the full cell count {@code rows * columns}.
     *
     * @param m1   the compressed input
     * @param v    the row vector as a dense array
     * @param ret  the compressed output block; a new one with the dimensions of {@code m1} is created
     *             when null
     * @param op   the binary operator to apply
     * @param left true if the vector is the left operand of the operation
     * @return the populated output block
     * @throws DMLRuntimeException if a task fails or the calling thread is interrupted while waiting
     */
    public static CompressedMatrixBlock binaryMVRow(CompressedMatrixBlock m1, double[] v, CompressedMatrixBlock ret, BinaryOperator op, boolean left) {
        List<AColGroup> oldColGroups = m1.getColGroups();
        if (ret == null) {
            ret = new CompressedMatrixBlock(m1.getNumRows(), m1.getNumColumns());
        }

        // The operation is treated as sparse safe only if a zero cell stays zero for every vector
        // entry. The zero has to be placed on the side the matrix cell actually occupies, e.g.
        // 0 / x is 0 but x / 0 is not.
        boolean sparseSafe = true;
        for (double x : v) {
            double onZeroCell = left ? op.fn.execute(x, 0.0) : op.fn.execute(0.0, x);
            if (onZeroCell != 0.0) {
                sparseSafe = false;
                break;
            }
        }

        ArrayList<BinaryMVRowTask> tasks = new ArrayList<BinaryMVRowTask>(oldColGroups.size());
        for (AColGroup group : oldColGroups) {
            tasks.add(new BinaryMVRowTask(group, v, sparseSafe, op, left));
        }

        ArrayList<AColGroup> newColGroups = new ArrayList<AColGroup>(oldColGroups.size());
        ExecutorService pool = CommonThreadPool.get(op.getNumThreads());
        try {
            for (Future<AColGroup> future : pool.invokeAll(tasks)) {
                newColGroups.add(future.get());
            }
        }
        catch (InterruptedException e) {
            // Keep the interrupt visible to callers further up the stack.
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException(e);
        }
        catch (ExecutionException e) {
            throw new DMLRuntimeException(e);
        }
        finally {
            // Release the pool on the failure paths too.
            pool.shutdown();
        }

        ret.allocateColGroupList(newColGroups);
        // Widen before multiplying so large matrices do not overflow int.
        ret.setNonZeros((long) m1.getNumRows() * m1.getNumColumns());
        return ret;
    }

    /**
     * Matrix / row-vector operation where the vector is given as a {@link MatrixBlock}.
     *
     * <p>The vector is first turned into a dense array via {@link #forceMatrixBlockToDense} and then
     * passed to the array-based overload.
     *
     * @param m1   the compressed input
     * @param m2   the row vector
     * @param ret  the compressed output block, may be null
     * @param op   the binary operator to apply
     * @param left true if the vector is the left operand of the operation
     * @return the block returned by the array-based overload
     */
    protected static CompressedMatrixBlock binaryMVRow(CompressedMatrixBlock m1, MatrixBlock m2, CompressedMatrixBlock ret, BinaryOperator op, boolean left) {
        final double[] rowVector = forceMatrixBlockToDense(m2);
        return binaryMVRow(m1, rowVector, ret, op, left);
    }

    /**
     * Returns the first row of {@code m2} as a dense array.
     *
     * <p>For a block in sparse format, row 0 of the sparse block is scattered into a new array of
     * length {@code m2.getNumColumns()}. For a block in dense format the block's own dense values
     * array is returned (not a copy); if that array is null an all-zero array is returned instead.
     *
     * @param m2 the block to read; callers in this class pass a row vector
     * @return the dense values
     * @throws DMLRuntimeException if the block reports sparse format but has no sparse block
     */
    private static double[] forceMatrixBlockToDense(MatrixBlock m2) {
        if (!m2.isInSparseFormat()) {
            double[] dense = m2.getDenseBlockValues();
            // NOTE(review): a null dense array is assumed to mean an unallocated (all-zero) block.
            return dense != null ? dense : new double[m2.getNumColumns()];
        }

        SparseBlock sb = m2.getSparseBlock();
        if (sb == null) {
            throw new DMLRuntimeException("Unknown matrix block type");
        }

        double[] v = new double[m2.getNumColumns()];
        // An empty row has nothing to scatter; return before touching its value/index arrays.
        if (sb.isEmpty(0)) {
            return v;
        }
        double[] spV = sb.values(0);
        int[] spI = sb.indexes(0);
        // The row occupies [pos, pos + size) in the arrays; using size alone as the upper bound
        // drops entries whenever pos is not 0.
        int start = sb.pos(0);
        int end = start + sb.size(0);
        for (int i = start; i < end; ++i) {
            v[spI[i]] = spV[i];
        }
        return v;
    }

    /**
     * Applies a row-vector operation to an overlapping compressed block by changing a single column
     * group instead of all of them.
     *
     * <p>Among the column groups that span all columns of {@code m1}, the one with the fewest values
     * gets the row operation applied. If no group spans all columns, a constant column group over all
     * columns with a dictionary built from {@code m2} is appended instead. The output is marked as
     * overlapping and its non-zero count is set to -1.
     *
     * @param m1   the compressed input
     * @param m2   the row vector
     * @param ret  the compressed output block that receives the column groups
     * @param op   the binary operator to apply
     * @param left true if the vector is the left operand of the operation
     * @return {@code ret}, or {@code m1} itself when {@code m2} is empty
     */
    protected static CompressedMatrixBlock binaryMVPlusStack(CompressedMatrixBlock m1, MatrixBlock m2, CompressedMatrixBlock ret, BinaryOperator op, boolean left) {
        if (m2.isEmpty()) {
            // NOTE(review): `ret` stays untouched on this path and the caller in this class ignores
            // the return value — confirm an empty m2 cannot reach this method.
            return m1;
        }
        List<AColGroup> oldColGroups = m1.getColGroups();
        int size = oldColGroups.size();
        int nCol = m1.getNumColumns();
        ArrayList<AColGroup> newColGroups = new ArrayList<AColGroup>(size);

        // Copy every group and remember the full-width group with the fewest values.
        int smallestIndex = 0;
        int smallestSize = Integer.MAX_VALUE;
        for (int i = 0; i < size; ++i) {
            AColGroup g = oldColGroups.get(i);
            int numValues = g.getNumValues();
            newColGroups.add(g);
            if (numValues < smallestSize && g.getNumCols() == nCol) {
                smallestIndex = i;
                smallestSize = numValues;
            }
        }

        if (smallestSize == Integer.MAX_VALUE) {
            // No group covers all columns: stack a constant group on top.
            int[] colIndexes = new int[nCol];
            for (int i = 0; i < nCol; ++i) {
                colIndexes[i] = i;
            }
            // NOTE(review): the constant group is built from m2 as-is, independent of `op` and
            // `left`; that looks right for Plus only — verify the Minus case.
            MatrixBlockDictionary newDict = new MatrixBlockDictionary(m2);
            newColGroups.add(new ColGroupConst(colIndexes, newDict));
        } else {
            // Go through forceMatrixBlockToDense so a sparse m2 also yields a usable array; reading
            // the dense values directly is expected to give null for a block in sparse format.
            double[] rowVector = CLALibBinaryCellOp.forceMatrixBlockToDense(m2);
            AColGroup g = newColGroups.get(smallestIndex).binaryRowOp(op, rowVector, false, left);
            newColGroups.set(smallestIndex, g);
        }
        ret.allocateColGroupList(newColGroups);
        ret.setOverlapping(true);
        ret.setNonZeros(-1L);
        return ret;
    }

    /**
     * Matrix / column-vector operation on a compressed block, producing a dense uncompressed result.
     *
     * <p>The rows are split into blocks; each block is decompressed into the output and combined with
     * the vector by a {@link BinaryMVColTask} or, when {@code left} is true, a
     * {@link BinaryMVColLeftTask}. With at most one thread the blocks run inline, otherwise on a
     * thread pool. The values returned by the tasks are summed into the non-zero count of the result.
     *
     * @param m1   the compressed input
     * @param m2   the column vector
     * @param op   the binary operator to apply
     * @param left true if the vector is the left operand of the operation
     * @return a newly allocated dense block with the dimensions of {@code m1}
     * @throws DMLRuntimeException if a task fails or the calling thread is interrupted while waiting
     */
    private static MatrixBlock binaryMVCol(CompressedMatrixBlock m1, MatrixBlock m2, BinaryOperator op, boolean left) {
        final int nRows = m1.getNumRows();
        final int nCols = m1.getNumColumns();
        MatrixBlock ret = new MatrixBlock(nRows, nCols, false, -1L).allocateBlock();

        // Rows per block. Clamp to at least one row: with more than 65535 columns the integer
        // division gives 0 and the block loop would never advance. Zero columns would divide by zero.
        final int blkz = Math.max(1, 65535 / Math.max(1, nCols) * 5);
        final int k = op.getNumThreads();
        long nnz = 0L;

        if (k <= 1) {
            for (int rl = 0; rl < nRows; ) {
                // Compute the block end in long so rl + blkz cannot wrap around.
                final int ru = (int) Math.min((long) nRows, (long) rl + blkz);
                if (left) {
                    nnz += new BinaryMVColLeftTask(m1, m2, ret, rl, ru, op).call();
                } else {
                    nnz += new BinaryMVColTask(m1, m2, ret, rl, ru, op).call();
                }
                rl = ru;
            }
        } else {
            ArrayList<Callable<Integer>> tasks = new ArrayList<Callable<Integer>>();
            for (int rl = 0; rl < nRows; ) {
                final int ru = (int) Math.min((long) nRows, (long) rl + blkz);
                if (left) {
                    tasks.add(new BinaryMVColLeftTask(m1, m2, ret, rl, ru, op));
                } else {
                    tasks.add(new BinaryMVColTask(m1, m2, ret, rl, ru, op));
                }
                rl = ru;
            }
            ExecutorService pool = CommonThreadPool.get(k);
            try {
                for (Future<Integer> f : pool.invokeAll(tasks)) {
                    nnz += f.get();
                }
            }
            catch (InterruptedException e) {
                // Keep the interrupt visible to callers further up the stack.
                Thread.currentThread().interrupt();
                throw new DMLRuntimeException(e);
            }
            catch (ExecutionException e) {
                throw new DMLRuntimeException(e);
            }
            finally {
                // Release the pool on the failure paths too.
                pool.shutdown();
            }
        }
        ret.setNonZeros(nnz);
        return ret;
    }

    /**
     * Task that applies a row-vector binary operation to one column group.
     */
    private static class BinaryMVRowTask
    implements Callable<AColGroup> {
        /** Column group the operation is applied to. */
        private final AColGroup _group;
        /** Row vector operand. */
        private final double[] _v;
        /** Sparse-safety flag forwarded to the column group. */
        private final boolean _sparseSafe;
        /** Operator to apply. */
        private final BinaryOperator _op;
        /** Whether the vector is the left operand. */
        private final boolean _left;

        protected BinaryMVRowTask(AColGroup group, double[] v, boolean sparseSafe, BinaryOperator op, boolean left) {
            _group = group;
            _v = v;
            _sparseSafe = sparseSafe;
            _op = op;
            _left = left;
        }

        /**
         * @return the column group produced by {@code binaryRowOp} on the wrapped group
         */
        @Override
        public AColGroup call() {
            final AColGroup transformed = _group.binaryRowOp(_op, _v, _sparseSafe, _left);
            return transformed;
        }
    }

    /**
     * Task computing {@code m2[row] op m1[row, col]} for the rows {@code [rl, ru)}: the column vector
     * is the left operand. The rows are first decompressed into the shared output block and then
     * overwritten in place with the result.
     */
    private static class BinaryMVColLeftTask
    implements Callable<Integer> {
        /** First row (inclusive) handled by this task. */
        private final int _rl;
        /** Last row (exclusive) handled by this task. */
        private final int _ru;
        private final CompressedMatrixBlock _m1;
        private final MatrixBlock _m2;
        /** Shared output block; the loop below writes only cells of rows {@code [_rl, _ru)}. */
        private final MatrixBlock _ret;
        private final BinaryOperator _op;

        protected BinaryMVColLeftTask(CompressedMatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int rl, int ru, BinaryOperator op) {
            this._m1 = m1;
            this._m2 = m2;
            this._ret = ret;
            this._op = op;
            this._rl = rl;
            this._ru = ru;
        }

        /**
         * @return the number of non-zero result cells written for the rows of this task
         * @throws NotImplementedException if the vector is in sparse format
         */
        @Override
        public Integer call() {
            // Reject unsupported input before doing any decompression work.
            if (this._m2.isInSparseFormat()) {
                throw new NotImplementedException("Not Implemented sparse Format execution for MM.");
            }
            for (AColGroup g : this._m1.getColGroups()) {
                g.decompressToBlock(this._ret, this._rl, this._ru);
            }
            final int nCol = this._m1.getNumColumns();
            final double[] retDense = this._ret.getDenseBlockValues();
            final double[] m2Dense = this._m2.getDenseBlockValues();
            int offset = this._rl * nCol;
            // Count only the cells of this row range. Returning the cell count of the whole matrix
            // from every task would inflate the summed total by the number of row blocks.
            int nnz = 0;
            for (int row = this._rl; row < this._ru; ++row) {
                final double vr = m2Dense[row];
                for (int col = 0; col < nCol; ++col) {
                    final double v = this._op.fn.execute(vr, retDense[offset]);
                    retDense[offset] = v;
                    if (v != 0.0) {
                        ++nnz;
                    }
                    ++offset;
                }
            }
            return nnz;
        }
    }

    /**
     * Task computing {@code m1[row, col] op m2[row]} for the rows {@code [rl, ru)}: the column vector
     * is the right operand. The rows are first decompressed into the shared output block and then
     * overwritten in place with the result.
     */
    private static class BinaryMVColTask
    implements Callable<Integer> {
        /** First row (inclusive) handled by this task. */
        private final int _rl;
        /** Last row (exclusive) handled by this task. */
        private final int _ru;
        private final CompressedMatrixBlock _m1;
        private final MatrixBlock _m2;
        /** Shared output block; the loop below writes only cells of rows {@code [_rl, _ru)}. */
        private final MatrixBlock _ret;
        private final BinaryOperator _op;

        protected BinaryMVColTask(CompressedMatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int rl, int ru, BinaryOperator op) {
            this._m1 = m1;
            this._m2 = m2;
            this._ret = ret;
            this._op = op;
            this._rl = rl;
            this._ru = ru;
        }

        /**
         * @return the number of non-zero result cells written for the rows of this task
         * @throws NotImplementedException if the vector is in sparse format
         */
        @Override
        public Integer call() {
            // Reject unsupported input before doing any decompression work.
            if (this._m2.isInSparseFormat()) {
                throw new NotImplementedException("Not Implemented sparse Format execution for MM.");
            }
            for (AColGroup g : this._m1.getColGroups()) {
                g.decompressToBlock(this._ret, this._rl, this._ru);
            }
            final int nCol = this._m1.getNumColumns();
            final double[] retDense = this._ret.getDenseBlockValues();
            final double[] m2Dense = this._m2.getDenseBlockValues();
            int offset = this._rl * nCol;
            // Count only the cells of this row range. Returning the cell count of the whole matrix
            // from every task would inflate the summed total by the number of row blocks.
            int nnz = 0;
            for (int row = this._rl; row < this._ru; ++row) {
                final double vr = m2Dense[row];
                for (int col = 0; col < nCol; ++col) {
                    final double v = this._op.fn.execute(retDense[offset], vr);
                    retDense[offset] = v;
                    if (v != 0.0) {
                        ++nnz;
                    }
                    ++offset;
                }
            }
            return nnz;
        }
    }
}

