/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.lib;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.commons.lang.NotImplementedException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
import org.apache.sysds.runtime.compress.lib.CLALibScalar;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.functionobjects.Divide;
import org.apache.sysds.runtime.functionobjects.Minus;
import org.apache.sysds.runtime.functionobjects.MinusMultiply;
import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.functionobjects.PlusMultiply;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
import org.apache.sysds.runtime.matrix.data.LibMatrixBincell;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.LeftScalarOperator;
import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.util.CommonThreadPool;

public class CLALibBinaryCellOp {
    /** Class logger; used in this class to warn when an operation falls back to decompression. */
    private static final Log LOG = LogFactory.getLog((String)CLALibBinaryCellOp.class.getName());

    /**
     * Applies a cell-wise binary operation with the compressed block as the left-hand operand
     * ({@code m1 op thatValue}).
     *
     * @param op        binary operator to apply
     * @param m1        compressed operand
     * @param thatValue other operand; an uncompressed view of it is obtained before dispatching
     * @param result    optional output block that is handed on to the selected kernel
     * @return the block produced by the selected kernel
     */
    public static MatrixBlock binaryOperations(BinaryOperator op, CompressedMatrixBlock m1, MatrixValue thatValue, MatrixValue result) {
        final MatrixBlock rhs = CompressedMatrixBlock.getUncompressed(thatValue);
        // NOTE(review): whatever this check reports is not inspected here - confirm it throws on mismatch.
        LibMatrixBincell.isValidDimensionsBinary(m1, rhs);
        final LibMatrixBincell.BinaryAccessType accessType = LibMatrixBincell.getBinaryAccessType(m1, rhs);
        return selectProcessingBasedOnAccessType(op, m1, rhs, thatValue, result, accessType, false);
    }

    /**
     * Applies a cell-wise binary operation with the compressed block as the right-hand operand
     * ({@code thatValue op m1}).
     *
     * @param op        binary operator to apply
     * @param m1        compressed operand
     * @param thatValue other operand; an uncompressed view of it is obtained before dispatching
     * @param result    optional output block that is handed on to the selected kernel
     * @return the block produced by the selected kernel
     */
    public static MatrixBlock binaryOperationsLeft(BinaryOperator op, CompressedMatrixBlock m1, MatrixValue thatValue, MatrixValue result) {
        final MatrixBlock lhs = CompressedMatrixBlock.getUncompressed(thatValue);
        // NOTE(review): whatever this check reports is not inspected here - confirm it throws on mismatch.
        LibMatrixBincell.isValidDimensionsBinary(lhs, m1);
        // The access type is derived with the uncompressed operand first, mirroring the operand order.
        final LibMatrixBincell.BinaryAccessType accessType = LibMatrixBincell.getBinaryAccessType(lhs, m1);
        return selectProcessingBasedOnAccessType(op, m1, lhs, thatValue, result, accessType, true);
    }

    /**
     * Picks the kernel for a binary cell operation based on the access type of the two operands.
     *
     * @param op        binary operator to apply
     * @param m1        compressed operand
     * @param that      uncompressed view of the other operand
     * @param thatValue the other operand as originally supplied (used by the decompression fallback)
     * @param result    optional output block
     * @param atype     access type computed by the caller
     * @param left      {@code true} when the uncompressed operand is on the left of the operator
     * @return the block produced by the selected kernel
     */
    private static MatrixBlock selectProcessingBasedOnAccessType(BinaryOperator op, CompressedMatrixBlock m1, MatrixBlock that, MatrixValue thatValue, MatrixValue result, LibMatrixBincell.BinaryAccessType atype, boolean left) {
        if (atype == LibMatrixBincell.BinaryAccessType.MATRIX_MATRIX) {
            if (that.isEmpty()) {
                // An empty matrix operand is treated as the scalar 0 on the matching side.
                final ScalarOperator zeroOp;
                if (left) {
                    zeroOp = new LeftScalarOperator(op.fn, 0.0, -1);
                } else {
                    zeroOp = new RightScalarOperator(op.fn, 0.0, -1);
                }
                return CLALibScalar.scalarOperations(zeroOp, m1, result);
            }
            if (that.isInSparseFormat()) {
                return binaryMMSparse(m1, that, op, left);
            }
            return binaryMMDense(m1, that, op, left);
        }

        if (atype == LibMatrixBincell.BinaryAccessType.MATRIX_COL_VECTOR) {
            return binaryMVCol(m1, that, op, left);
        }

        if (!isSupportedBinaryCellOp(op.fn)) {
            LOG.warn("Decompressing since Binary Ops" + op.fn + " is not supported compressed");
            // NOTE(review): this fallback always evaluates m1 op thatValue, regardless of 'left' - confirm.
            final MatrixBlock uncompressed = CompressedMatrixBlock.getUncompressed(m1);
            return uncompressed.binaryOperations(op, thatValue, result);
        }

        return bincellOp(m1, that, setupCompressedReturnMatrixBlock(m1, result), op, left);
    }

    /**
     * Tells whether the given function is one of the operations handled without decompression:
     * plus, minus, multiply, divide, minus-multiply or plus-multiply.
     *
     * @param fn value function of the binary operator
     * @return {@code true} if the function is in the supported set
     */
    private static boolean isSupportedBinaryCellOp(ValueFunction fn) {
        if (fn instanceof Plus || fn instanceof Minus) {
            return true;
        }
        if (fn instanceof Multiply || fn instanceof Divide) {
            return true;
        }
        return fn instanceof PlusMultiply || fn instanceof MinusMultiply;
    }

    /**
     * Provides the compressed output block for an operation on {@code m1}.
     *
     * @param m1     compressed input whose dimensions the output takes
     * @param result candidate output; reused (after a reset to the dimensions of {@code m1}) only if
     *               it is a {@link CompressedMatrixBlock}
     * @return the reused block, or a newly created one with the dimensions of {@code m1}
     */
    private static CompressedMatrixBlock setupCompressedReturnMatrixBlock(CompressedMatrixBlock m1, MatrixValue result) {
        // instanceof is false for null, so a missing result takes the allocation path as well.
        if (result instanceof CompressedMatrixBlock) {
            final CompressedMatrixBlock reused = (CompressedMatrixBlock) result;
            reused.reset(m1.getNumRows(), m1.getNumColumns());
            return reused;
        }
        return new CompressedMatrixBlock(m1.getNumRows(), m1.getNumColumns());
    }

    /**
     * Runs a supported binary cell operation, choosing between the overlapping and the
     * non-overlapping implementation.
     *
     * @param m1   compressed operand
     * @param m2   uncompressed operand
     * @param ret  compressed output block
     * @param op   binary operator to apply
     * @param left {@code true} when {@code m2} is on the left of the operator
     * @return {@code ret}
     */
    private static MatrixBlock bincellOp(CompressedMatrixBlock m1, MatrixBlock m2, CompressedMatrixBlock ret, BinaryOperator op, boolean left) {
        final boolean useOverlappingPath = isValidForOverlappingBinaryCellOperations(m1, op);
        if (!useOverlappingPath) {
            nonOverlappingBinaryCellOp(m1, m2, ret, op, left);
            return ret;
        }
        overlappingBinaryCellOp(m1, m2, ret, op, left);
        return ret;
    }

    /**
     * Executes a binary cell operation for the non-overlapping case and writes the outcome to
     * {@code ret}.
     *
     * <ul>
     * <li>{@code MATRIX_ROW_VECTOR}: delegated to {@code binaryMVRow}.</li>
     * <li>{@code OUTER_VECTOR_VECTOR} with a 1x1 {@code m2}: executed as a scalar operation, with
     * the scalar placed on the side indicated by {@code left}.</li>
     * <li>Everything else, including an outer operand that is not 1x1: {@code m1} is decompressed
     * and the generic binary operation is used. Earlier a non-1x1 outer operand returned without
     * touching {@code ret}, which handed an unpopulated block back to the caller.</li>
     * </ul>
     *
     * @param m1   compressed operand
     * @param m2   uncompressed operand
     * @param ret  compressed output block
     * @param op   binary operator to apply
     * @param left {@code true} when {@code m2} is on the left of the operator
     */
    private static void nonOverlappingBinaryCellOp(CompressedMatrixBlock m1, MatrixBlock m2, CompressedMatrixBlock ret, BinaryOperator op, boolean left) {
        final LibMatrixBincell.BinaryAccessType atype = LibMatrixBincell.getBinaryAccessType(m1, m2);

        if (atype == LibMatrixBincell.BinaryAccessType.MATRIX_ROW_VECTOR) {
            CLALibBinaryCellOp.binaryMVRow(m1, m2, ret, op, left);
            return;
        }

        final boolean scalarOperand = m2.getNumRows() == 1 && m2.getNumColumns() == 1;
        if (atype == LibMatrixBincell.BinaryAccessType.OUTER_VECTOR_VECTOR && scalarOperand) {
            final double scalar = m2.quickGetValue(0, 0);
            // Keep the operand order: a left operation needs the scalar on the left-hand side.
            // The left operator is built with the three-argument form used elsewhere in this class.
            final ScalarOperator sop = left
                ? new LeftScalarOperator(op.fn, scalar, -1)
                : new RightScalarOperator(op.fn, scalar);
            CLALibScalar.scalarOperations(sop, m1, ret);
            return;
        }

        LOG.warn("Inefficient Decompression for " + op + "  " + atype);
        // NOTE(review): this fallback evaluates m1 op m2 even when 'left' is set - confirm whether
        // the operands must be swapped for non-commutative operators.
        m1.decompress().binaryOperations(op, m2, ret);
    }

    /**
     * Decides whether the overlapping code path is taken: {@code m1} must report itself as
     * overlapping and the operator must be neither multiply nor divide.
     *
     * @param m1 compressed operand
     * @param op binary operator to apply
     * @return {@code true} if the overlapping implementation should be used
     */
    private static boolean isValidForOverlappingBinaryCellOperations(CompressedMatrixBlock m1, BinaryOperator op) {
        if (!m1.isOverlapping()) {
            return false;
        }
        final boolean multiplicative = op.fn instanceof Multiply || op.fn instanceof Divide;
        return !multiplicative;
    }

    /**
     * Executes a binary cell operation on an overlapping compressed block. Only plus and minus
     * are accepted; both are forwarded to {@code binaryMVPlusStack}.
     *
     * @param m1   compressed operand
     * @param m2   uncompressed operand
     * @param ret  compressed output block
     * @param op   binary operator to apply
     * @param left {@code true} when {@code m2} is on the left of the operator
     * @throws NotImplementedException if the operator is neither plus nor minus
     */
    private static void overlappingBinaryCellOp(CompressedMatrixBlock m1, MatrixBlock m2, CompressedMatrixBlock ret, BinaryOperator op, boolean left) {
        final boolean additive = op.fn instanceof Plus || op.fn instanceof Minus;
        if (additive) {
            binaryMVPlusStack(m1, m2, ret, op, left);
            return;
        }
        throw new NotImplementedException(op + " not implemented for Overlapping CLA");
    }

    /**
     * Applies a binary operation between every row of {@code m1} and the row vector {@code v},
     * one task per column group, and stores the resulting column groups in {@code ret}.
     *
     * <p>Changes compared to the previous implementation:
     * <ul>
     * <li>The sparse-safety probe honours {@code left}: for a left operation the vector value is
     * the first argument ({@code fn(x, 0)}), so e.g. {@code x / 0} is no longer classified as
     * sparse-safe.</li>
     * <li>The thread pool is shut down in a {@code finally} block, also when a task fails.</li>
     * <li>The non-zero upper bound is computed in {@code long} arithmetic to avoid {@code int}
     * overflow for large matrices.</li>
     * <li>The overlapping flag of {@code m1} is carried over, since each output group is derived
     * from exactly one input group.</li>
     * <li>Interruption restores the thread's interrupt flag; stack traces are no longer printed
     * because the exception is rethrown with its cause.</li>
     * </ul>
     *
     * @param m1   compressed operand
     * @param v    row vector values
     * @param ret  output block; a new one with the dimensions of {@code m1} is created when {@code null}
     * @param op   binary operator; also supplies the number of threads
     * @param left {@code true} when the vector is on the left of the operator
     * @return the populated output block
     * @throws DMLRuntimeException if a task fails or the calling thread is interrupted
     */
    public static CompressedMatrixBlock binaryMVRow(CompressedMatrixBlock m1, double[] v, CompressedMatrixBlock ret, BinaryOperator op, boolean left) {
        final List<AColGroup> oldColGroups = m1.getColGroups();
        if (ret == null) {
            ret = new CompressedMatrixBlock(m1.getNumRows(), m1.getNumColumns());
        }

        // The operation is sparse-safe only if a zero cell stays zero for every vector entry.
        boolean sparseSafe = true;
        for (double x : v) {
            final double onZero = left ? op.fn.execute(x, 0.0) : op.fn.execute(0.0, x);
            if (onZero != 0.0) {
                sparseSafe = false;
                break;
            }
        }

        final List<BinaryMVRowTask> tasks = new ArrayList<BinaryMVRowTask>(oldColGroups.size());
        for (AColGroup group : oldColGroups) {
            tasks.add(new BinaryMVRowTask(group, v, sparseSafe, op, left));
        }

        final ArrayList<AColGroup> newColGroups = new ArrayList<AColGroup>(oldColGroups.size());
        final ExecutorService pool = CommonThreadPool.get(op.getNumThreads());
        try {
            for (Future<AColGroup> future : pool.invokeAll(tasks)) {
                newColGroups.add(future.get());
            }
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException(e);
        }
        catch (ExecutionException e) {
            throw new DMLRuntimeException(e);
        }
        finally {
            pool.shutdown();
        }

        ret.allocateColGroupList(newColGroups);
        ret.setOverlapping(m1.isOverlapping());
        // Upper bound (all cells); widen before multiplying so large matrices do not overflow int.
        ret.setNonZeros((long) m1.getNumColumns() * m1.getNumRows());
        return ret;
    }

    /**
     * Matrix-block variant of the row-vector operation: converts {@code m2} to a dense value array
     * and forwards to the array-based overload.
     *
     * @param m1   compressed operand
     * @param m2   row vector operand
     * @param ret  output block (may be {@code null}, see the array-based overload)
     * @param op   binary operator to apply
     * @param left {@code true} when the vector is on the left of the operator
     * @return the populated output block
     */
    protected static CompressedMatrixBlock binaryMVRow(CompressedMatrixBlock m1, MatrixBlock m2, CompressedMatrixBlock ret, BinaryOperator op, boolean left) {
        final double[] rowValues = forceMatrixBlockToDense(m2);
        return binaryMVRow(m1, rowValues, ret, op, left);
    }

    /**
     * Returns the values of row 0 of {@code m2} as a dense array.
     *
     * <p>For a dense block the underlying dense value array is returned as is. For a sparse block
     * a new array of length {@code m2.getNumColumns()} is filled from the entries of row 0; only
     * row 0 is read, so the block is presumably expected to be a single-row vector.
     *
     * <p>Fixes: the entries of a sparse row occupy {@code [pos, pos + size)} of the arrays
     * returned by {@code values}/{@code indexes}, so the loop now ends at {@code pos + size}
     * instead of {@code size}; and the arrays are only fetched when the row holds entries.
     *
     * @param m2 vector block to convert
     * @return dense values of row 0
     * @throws DMLRuntimeException if the block reports sparse format but has no sparse block
     */
    private static double[] forceMatrixBlockToDense(MatrixBlock m2) {
        if (!m2.isInSparseFormat()) {
            return m2.getDenseBlockValues();
        }

        final SparseBlock sb = m2.getSparseBlock();
        if (sb == null) {
            throw new DMLRuntimeException("Unknown matrix block type");
        }

        final double[] v = new double[m2.getNumColumns()];
        final int nnzInRow = sb.size(0);
        if (nnzInRow > 0) {
            final int start = sb.pos(0);
            final int end = start + nnzInRow;
            final double[] spV = sb.values(0);
            final int[] spI = sb.indexes(0);
            for (int i = start; i < end; ++i) {
                v[spI[i]] = spV[i];
            }
        }
        return v;
    }

    /**
     * Applies a row-vector operation to an overlapping compressed block by modifying a single
     * column group instead of all of them.
     *
     * <p>The column group that spans all columns and has the fewest values receives the row
     * operation. If no group spans all columns, a constant group built from {@code m2} is
     * appended instead. The output is marked as overlapping with an unknown non-zero count.
     *
     * <p>Fixes:
     * <ul>
     * <li>When {@code m2} is empty the column groups of {@code m1} are now placed into
     * {@code ret}. Before, {@code m1} was returned and {@code ret} stayed unpopulated, although the
     * caller in this class ignores the return value and hands back {@code ret}.</li>
     * <li>The vector is converted with {@code forceMatrixBlockToDense}, so a row vector held in
     * sparse format is handled the same way as in {@code binaryMVRow}.</li>
     * </ul>
     *
     * @param m1   overlapping compressed operand
     * @param m2   row vector operand
     * @param ret  output block; when {@code null} and {@code m2} is empty, {@code m1} is returned
     * @param op   binary operator to apply
     * @param left {@code true} when the vector is on the left of the operator
     * @return the populated output block
     */
    protected static CompressedMatrixBlock binaryMVPlusStack(CompressedMatrixBlock m1, MatrixBlock m2, CompressedMatrixBlock ret, BinaryOperator op, boolean left) {
        final List<AColGroup> oldColGroups = m1.getColGroups();

        if (m2.isEmpty()) {
            // NOTE(review): an all-zero vector is treated as a no-op; for a left minus (0 - m1) the
            // result would presumably have to be negated - confirm.
            if (ret == null) {
                return m1;
            }
            ret.allocateColGroupList(new ArrayList<AColGroup>(oldColGroups));
            ret.setOverlapping(m1.isOverlapping());
            ret.setNonZeros(-1L);
            return ret;
        }

        final int size = oldColGroups.size();
        final ArrayList<AColGroup> newColGroups = new ArrayList<AColGroup>(size);
        final int nCol = m1.getNumColumns();

        // Find the cheapest group (fewest values) among those that cover every column.
        int smallestIndex = 0;
        int smallestSize = Integer.MAX_VALUE;
        for (int i = 0; i < size; ++i) {
            final AColGroup g = oldColGroups.get(i);
            newColGroups.add(g);
            final int numValues = g.getNumValues();
            if (numValues < smallestSize && g.getNumCols() == nCol) {
                smallestIndex = i;
                smallestSize = numValues;
            }
        }

        if (smallestSize == Integer.MAX_VALUE) {
            // No group covers all columns: stack a constant group holding the vector on top.
            final int[] colIndexes = new int[nCol];
            for (int i = 0; i < nCol; ++i) {
                colIndexes[i] = i;
            }
            // NOTE(review): the dictionary holds m2 unchanged, which matches plus; for minus a
            // negated copy is presumably needed - confirm.
            final MatrixBlockDictionary newDict = new MatrixBlockDictionary(m2);
            newColGroups.add(new ColGroupConst(colIndexes, m1.getNumRows(), newDict));
        } else {
            final double[] rowValues = forceMatrixBlockToDense(m2);
            final AColGroup g = newColGroups.get(smallestIndex).binaryRowOp(op, rowValues, false, left);
            newColGroups.set(smallestIndex, g);
        }

        ret.allocateColGroupList(newColGroups);
        ret.setOverlapping(true);
        ret.setNonZeros(-1L);
        return ret;
    }

    /**
     * Applies a binary operation between {@code m1} and a column vector {@code m2} by
     * decompressing row blocks of {@code m1} into a dense output in parallel.
     *
     * <p>Fixes: row-block boundaries are computed in {@code long} arithmetic (the previous
     * {@code i * 65535} could overflow {@code int} for very tall matrices), the thread pool is
     * shut down in a {@code finally} block, interruption restores the interrupt flag, and stack
     * traces are no longer printed because the exception is rethrown with its cause.
     *
     * @param m1   compressed operand
     * @param m2   column vector operand
     * @param op   binary operator to apply
     * @param left {@code true} when the vector is on the left of the operator
     * @return dense result block; its non-zero count is the sum of the values returned by the tasks
     * @throws DMLRuntimeException if a task fails or the calling thread is interrupted
     */
    private static MatrixBlock binaryMVCol(CompressedMatrixBlock m1, MatrixBlock m2, BinaryOperator op, boolean left) {
        final int nRows = m1.getNumRows();
        final MatrixBlock ret = new MatrixBlock(nRows, m1.getNumColumns(), false, -1L).allocateBlock();

        // Number of rows handled by one task.
        final int blkz = 65535;
        final List<Callable<Integer>> tasks = new ArrayList<Callable<Integer>>();
        for (long rl = 0; rl < nRows; rl += blkz) {
            final int start = (int) rl;
            final int end = (int) Math.min((long) nRows, rl + blkz);
            if (left) {
                tasks.add(new BinaryMVColLeftTask(m1, m2, ret, start, end, op));
            } else {
                tasks.add(new BinaryMVColTask(m1, m2, ret, start, end, op));
            }
        }

        final int k = OptimizerUtils.getConstrainedNumThreads(-1);
        final ExecutorService pool = CommonThreadPool.get(k);
        try {
            long nnz = 0L;
            for (Future<Integer> f : pool.invokeAll(tasks)) {
                nnz += f.get();
            }
            ret.setNonZeros(nnz);
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException(e);
        }
        catch (ExecutionException e) {
            throw new DMLRuntimeException(e);
        }
        finally {
            pool.shutdown();
        }
        return ret;
    }

    /**
     * Applies a binary operation between {@code m1} and a dense matrix {@code m2} of the same
     * shape by decompressing row blocks of {@code m1} into a dense output in parallel.
     *
     * <p>Fixes: the rows-per-task value is clamped to at least 1. Previously a matrix with more
     * than 393210 columns produced a block size of 0, so the task-creation loop never advanced,
     * and a matrix with 0 columns caused a division by zero. Block boundaries are computed in
     * {@code long} arithmetic, the thread pool is shut down in a {@code finally} block,
     * interruption restores the interrupt flag, and stack traces are no longer printed because
     * the exception is rethrown with its cause.
     *
     * @param m1   compressed operand
     * @param m2   dense matrix operand
     * @param op   binary operator to apply
     * @param left {@code true} when {@code m2} is on the left of the operator
     * @return dense result block with its non-zero count set to the sum reported by the tasks
     * @throws DMLRuntimeException if a task fails or the calling thread is interrupted
     */
    private static MatrixBlock binaryMMDense(CompressedMatrixBlock m1, MatrixBlock m2, BinaryOperator op, boolean left) {
        final int nRows = m1.getNumRows();
        final int nCols = m1.getNumColumns();
        final MatrixBlock ret = new MatrixBlock(nRows, nCols, false, -1L).allocateBlock();

        // Aim for roughly 393210 cells per task, but always make progress by at least one row.
        final int blkz = Math.max(1, 393210 / Math.max(1, nCols));
        final List<Callable<Integer>> tasks = new ArrayList<Callable<Integer>>();
        for (long rl = 0; rl < nRows; rl += blkz) {
            final int start = (int) rl;
            final int end = (int) Math.min((long) nRows, rl + blkz);
            if (left) {
                tasks.add(new BinaryMMLeftTask(m1, m2, ret, start, end, op));
            } else {
                tasks.add(new BinaryMMTask(m1, m2, ret, start, end, op));
            }
        }

        final int k = OptimizerUtils.getConstrainedNumThreads(-1);
        final ExecutorService pool = CommonThreadPool.get(k);
        try {
            long nnz = 0L;
            for (Future<Integer> f : pool.invokeAll(tasks)) {
                nnz += f.get();
            }
            ret.setNonZeros(nnz);
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException(e);
        }
        catch (ExecutionException e) {
            throw new DMLRuntimeException(e);
        }
        finally {
            pool.shutdown();
        }
        return ret;
    }

    /**
     * Placeholder for the matrix-matrix operation with a sparse second operand; it always throws.
     *
     * @param m1   compressed operand
     * @param m2   sparse matrix operand
     * @param op   binary operator whose function is named in the error message
     * @param left {@code true} when {@code m2} is on the left of the operator
     * @return never returns normally
     * @throws NotImplementedException always
     */
    private static MatrixBlock binaryMMSparse(CompressedMatrixBlock m1, MatrixBlock m2, BinaryOperator op, boolean left) {
        final String message = "not implemented sparse Binary MM " + op.fn;
        throw new NotImplementedException(message);
    }

    /**
     * Task that applies a row-vector binary operation to a single column group.
     */
    private static class BinaryMVRowTask
    implements Callable<AColGroup> {
        /** Operator to apply. */
        private final BinaryOperator _op;
        /** Column group the operation is applied to. */
        private final AColGroup _group;
        /** Row vector values. */
        private final double[] _v;
        /** Sparse-safety flag forwarded to the column group. */
        private final boolean _sparseSafe;
        /** Whether the vector is the left operand. */
        private final boolean _left;

        protected BinaryMVRowTask(AColGroup group, double[] v, boolean sparseSafe, BinaryOperator op, boolean left) {
            _op = op;
            _group = group;
            _v = v;
            _sparseSafe = sparseSafe;
            _left = left;
        }

        /**
         * @return the column group returned by {@code binaryRowOp} for the stored arguments
         */
        @Override
        public AColGroup call() {
            final AColGroup updated = _group.binaryRowOp(_op, _v, _sparseSafe, _left);
            return updated;
        }
    }

    /**
     * Task computing {@code m2 op m1} for the rows {@code [rl, ru)}: the rows of {@code m1} are
     * decompressed into the shared output, then combined in place with the dense values of
     * {@code m2}.
     */
    private static class BinaryMMLeftTask
    implements Callable<Integer> {
        private final CompressedMatrixBlock _m1;
        private final MatrixBlock _m2;
        private final MatrixBlock _ret;
        private final BinaryOperator _op;
        /** First row (inclusive) handled by this task. */
        private final int _rl;
        /** Last row (exclusive) handled by this task. */
        private final int _ru;

        protected BinaryMMLeftTask(CompressedMatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int rl, int ru, BinaryOperator op) {
            _rl = rl;
            _ru = ru;
            _m1 = m1;
            _m2 = m2;
            _ret = ret;
            _op = op;
        }

        /**
         * @return number of non-zero results written for the handled rows
         * @throws NotImplementedException if {@code m2} is in sparse format
         */
        @Override
        public Integer call() {
            AColGroup.decompressColumnToBlockUnSafe(_ret, _rl, _ru, _m1.getColGroups());
            if (_m2.isInSparseFormat()) {
                throw new NotImplementedException("Not Implemented sparse Format execution for MM.");
            }
            final double[] out = _ret.getDenseBlockValues();
            final double[] lhs = _m2.getDenseBlockValues();
            final int numCols = _m1.getNumColumns();

            int nonZeros = 0;
            int idx = _rl * numCols;
            for (int r = _rl; r < _ru; r++) {
                for (int c = 0; c < numCols; c++, idx++) {
                    // m2 is the left operand, the decompressed value of m1 the right one.
                    final double cell = _op.fn.execute(lhs[idx], out[idx]);
                    if (cell != 0.0) {
                        nonZeros++;
                    }
                    out[idx] = cell;
                }
            }
            return nonZeros;
        }
    }

    /**
     * Task computing {@code m1 op m2} for the rows {@code [rl, ru)}: the rows of {@code m1} are
     * decompressed into the shared output, then combined in place with the dense values of
     * {@code m2}.
     */
    private static class BinaryMMTask
    implements Callable<Integer> {
        private final CompressedMatrixBlock _m1;
        private final MatrixBlock _m2;
        private final MatrixBlock _ret;
        private final BinaryOperator _op;
        /** First row (inclusive) handled by this task. */
        private final int _rl;
        /** Last row (exclusive) handled by this task. */
        private final int _ru;

        protected BinaryMMTask(CompressedMatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int rl, int ru, BinaryOperator op) {
            _rl = rl;
            _ru = ru;
            _m1 = m1;
            _m2 = m2;
            _ret = ret;
            _op = op;
        }

        /**
         * @return number of non-zero results written for the handled rows
         * @throws NotImplementedException if {@code m2} is in sparse format
         */
        @Override
        public Integer call() {
            AColGroup.decompressColumnToBlockUnSafe(_ret, _rl, _ru, _m1.getColGroups());
            if (_m2.isInSparseFormat()) {
                throw new NotImplementedException("Not Implemented sparse Format execution for MM.");
            }
            final double[] out = _ret.getDenseBlockValues();
            final double[] rhs = _m2.getDenseBlockValues();
            final int numCols = _m1.getNumColumns();

            int nonZeros = 0;
            int idx = _rl * numCols;
            for (int r = _rl; r < _ru; r++) {
                for (int c = 0; c < numCols; c++, idx++) {
                    // The decompressed value of m1 is the left operand, m2 the right one.
                    final double cell = _op.fn.execute(out[idx], rhs[idx]);
                    if (cell != 0.0) {
                        nonZeros++;
                    }
                    out[idx] = cell;
                }
            }
            return nonZeros;
        }
    }

    /**
     * Task computing {@code m2 op m1} for the rows {@code [rl, ru)}, where {@code m2} is a dense
     * column vector: the rows of {@code m1} are decompressed into the shared output and each cell
     * is combined in place with the vector entry of its row.
     *
     * <p>Fix: the task now returns the number of non-zero cells it wrote. Previously every task
     * returned the cell count of the whole output matrix, so the caller's sum over all tasks
     * exceeded the matrix size and the {@code int} product could overflow.
     */
    private static class BinaryMVColLeftTask
    implements Callable<Integer> {
        /** First row (inclusive) handled by this task. */
        private final int _rl;
        /** Last row (exclusive) handled by this task. */
        private final int _ru;
        private final CompressedMatrixBlock _m1;
        private final MatrixBlock _m2;
        private final MatrixBlock _ret;
        private final BinaryOperator _op;

        protected BinaryMVColLeftTask(CompressedMatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int rl, int ru, BinaryOperator op) {
            this._m1 = m1;
            this._m2 = m2;
            this._ret = ret;
            this._op = op;
            this._rl = rl;
            this._ru = ru;
        }

        /**
         * @return number of non-zero results written for the handled rows, capped at
         *         {@code Integer.MAX_VALUE}
         * @throws NotImplementedException if {@code m2} is in sparse format
         */
        @Override
        public Integer call() {
            // Reject unsupported input before spending time on decompression.
            if (this._m2.isInSparseFormat()) {
                throw new NotImplementedException("Not Implemented sparse Format execution for MM.");
            }
            for (AColGroup g : this._m1.getColGroups()) {
                g.decompressToBlockUnSafe(this._ret, this._rl, this._ru);
            }
            final int numCols = this._m1.getNumColumns();
            final double[] retDense = this._ret.getDenseBlockValues();
            final double[] m2Dense = this._m2.getDenseBlockValues();
            long nnz = 0L;
            int offset = this._rl * numCols;
            for (int row = this._rl; row < this._ru; ++row) {
                final double vr = m2Dense[row];
                for (int col = 0; col < numCols; ++col) {
                    // The vector entry is the left operand.
                    final double v = this._op.fn.execute(vr, retDense[offset]);
                    retDense[offset] = v;
                    if (v != 0.0) {
                        ++nnz;
                    }
                    ++offset;
                }
            }
            return (int) Math.min(nnz, (long) Integer.MAX_VALUE);
        }
    }

    /**
     * Task computing {@code m1 op m2} for the rows {@code [rl, ru)}, where {@code m2} is a dense
     * column vector: the rows of {@code m1} are decompressed into the shared output and each cell
     * is combined in place with the vector entry of its row.
     *
     * <p>Fix: the task now returns the number of non-zero cells it wrote. Previously every task
     * returned the cell count of the whole output matrix, so the caller's sum over all tasks
     * exceeded the matrix size and the {@code int} product could overflow.
     */
    private static class BinaryMVColTask
    implements Callable<Integer> {
        /** First row (inclusive) handled by this task. */
        private final int _rl;
        /** Last row (exclusive) handled by this task. */
        private final int _ru;
        private final CompressedMatrixBlock _m1;
        private final MatrixBlock _m2;
        private final MatrixBlock _ret;
        private final BinaryOperator _op;

        protected BinaryMVColTask(CompressedMatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int rl, int ru, BinaryOperator op) {
            this._m1 = m1;
            this._m2 = m2;
            this._ret = ret;
            this._op = op;
            this._rl = rl;
            this._ru = ru;
        }

        /**
         * @return number of non-zero results written for the handled rows, capped at
         *         {@code Integer.MAX_VALUE}
         * @throws NotImplementedException if {@code m2} is in sparse format
         */
        @Override
        public Integer call() {
            // Reject unsupported input before spending time on decompression.
            if (this._m2.isInSparseFormat()) {
                throw new NotImplementedException("Not Implemented sparse Format execution for MM.");
            }
            for (AColGroup g : this._m1.getColGroups()) {
                g.decompressToBlockUnSafe(this._ret, this._rl, this._ru);
            }
            final int numCols = this._m1.getNumColumns();
            final double[] retDense = this._ret.getDenseBlockValues();
            final double[] m2Dense = this._m2.getDenseBlockValues();
            long nnz = 0L;
            int offset = this._rl * numCols;
            for (int row = this._rl; row < this._ru; ++row) {
                final double vr = m2Dense[row];
                for (int col = 0; col < numCols; ++col) {
                    // The vector entry is the right operand.
                    final double v = this._op.fn.execute(retDense[offset], vr);
                    retDense[offset] = v;
                    if (v != 0.0) {
                        ++nnz;
                    }
                    ++offset;
                }
            }
            return (int) Math.min(nnz, (long) Integer.MAX_VALUE);
        }
    }
}

