/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.instructions.mr;

import java.util.ArrayList;
import org.apache.sysml.lops.MapMultChain;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.mr.IDistributedCacheConsumer;
import org.apache.sysml.runtime.instructions.mr.MRInstruction;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.MatrixValue;
import org.apache.sysml.runtime.matrix.mapred.CachedValueMap;
import org.apache.sysml.runtime.matrix.mapred.DistributedCacheInput;
import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue;
import org.apache.sysml.runtime.matrix.mapred.MRBaseForCommonInstructions;

/**
 * MR instruction for a fused matrix-multiplication chain of type {@link MapMultChain.ChainType},
 * evaluated block by block.
 *
 * <p>The first input ({@code _input1}) is read from the cached values of the current task. The
 * second input ({@code _input2}) and, for chain types other than {@code XtXv}, the third input
 * ({@code _input3}) are read from the distributed cache. Every per-block result is written under
 * output index (1,1); combining those partial results is presumably done downstream — TODO
 * confirm against the job's aggregation setup.
 */
public class MapMultChainInstruction
extends MRInstruction
implements IDistributedCacheConsumer {
    /** Sentinel byte index used when there is no third input (two-input {@code XtXv} form). */
    private static final byte NO_INPUT = (byte) -1;

    private final MapMultChain.ChainType _chainType;
    private final byte _input1;
    private final byte _input2;
    private final byte _input3;

    /**
     * Creates a two-input instruction; the third input is marked as absent.
     *
     * @param type chain type (expected to be {@code XtXv} for this form)
     * @param in1 index of the input processed block-wise
     * @param in2 index of the input read from the distributed cache
     * @param out output index
     * @param istr original instruction string
     */
    private MapMultChainInstruction(MapMultChain.ChainType type, byte in1, byte in2, byte out, String istr) {
        this(type, in1, in2, NO_INPUT, out, istr);
    }

    /**
     * Creates a three-input instruction.
     *
     * @param type chain type
     * @param in1 index of the input processed block-wise
     * @param in2 index of the second input (read from the distributed cache)
     * @param in3 index of the third input (read from the distributed cache unless type is {@code XtXv})
     * @param out output index
     * @param istr original instruction string
     */
    public MapMultChainInstruction(MapMultChain.ChainType type, byte in1, byte in2, byte in3, byte out, String istr) {
        super(MRInstruction.MRType.MapMultChain, null, out);
        this._chainType = type;
        this._input1 = in1;
        this._input2 = in2;
        this._input3 = in3;
        this.instString = istr;
    }

    public MapMultChain.ChainType getChainType() {
        return this._chainType;
    }

    public byte getInput1() {
        return this._input1;
    }

    public byte getInput2() {
        return this._input2;
    }

    /** @return the third input index, or -1 for the two-input form */
    public byte getInput3() {
        return this._input3;
    }

    /**
     * Parses an instruction string with either 4 fields (in1, in2, out, type) or 5 fields
     * (in1, in2, in3, out, type) following the opcode.
     *
     * @param str instruction string
     * @return the parsed instruction
     * @throws DMLRuntimeException if the field count is invalid, or if the two-input form names a
     *     chain type other than {@code XtXv} (such a type needs a third input that would be missing)
     */
    public static MapMultChainInstruction parseInstruction(String str) throws DMLRuntimeException {
        InstructionUtils.checkNumFields(str, 4, 5);
        String[] parts = InstructionUtils.getInstructionParts(str);
        byte in1 = Byte.parseByte(parts[1]);
        byte in2 = Byte.parseByte(parts[2]);
        if (parts.length == 5) {
            byte out = Byte.parseByte(parts[3]);
            MapMultChain.ChainType type = MapMultChain.ChainType.valueOf(parts[4]);
            // Without this check, a non-XtXv type would leave _input3 at -1 and only fail later
            // (at execution) on a missing distributed-cache input.
            if (type != MapMultChain.ChainType.XtXv) {
                throw new DMLRuntimeException("MapMultChain instruction with two inputs requires chain type "
                        + MapMultChain.ChainType.XtXv + ", but got " + type + ": " + str);
            }
            return new MapMultChainInstruction(type, in1, in2, out, str);
        }
        byte in3 = Byte.parseByte(parts[3]);
        byte out = Byte.parseByte(parts[4]);
        MapMultChain.ChainType type = MapMultChain.ChainType.valueOf(parts[5]);
        return new MapMultChainInstruction(type, in1, in2, in3, out, str);
    }

    /**
     * Returns whether {@code index} is consumed only through the distributed cache: it must be
     * one of the cached inputs ({@code _input2}, plus {@code _input3} for non-{@code XtXv} types)
     * and must not also be the block-wise input {@code _input1}.
     */
    @Override
    public boolean isDistCacheOnlyIndex(String inst, byte index) {
        boolean cachedViaInput2 = index == this._input2 && index != this._input1;
        if (this._chainType == MapMultChain.ChainType.XtXv) {
            return cachedViaInput2;
        }
        return cachedViaInput2 || (index == this._input3 && index != this._input1);
    }

    /** Adds the distributed-cache input indexes ({@code _input2}, and {@code _input3} unless XtXv). */
    @Override
    public void addDistCacheIndex(String inst, ArrayList<Byte> indexes) {
        indexes.add(this._input2);
        if (this._chainType != MapMultChain.ChainType.XtXv) {
            indexes.add(this._input3);
        }
    }

    @Override
    public byte[] getInputIndexes() {
        if (this._chainType == MapMultChain.ChainType.XtXv) {
            return new byte[]{this._input1, this._input2};
        }
        return new byte[]{this._input1, this._input2, this._input3};
    }

    @Override
    public byte[] getAllIndexes() {
        if (this._chainType == MapMultChain.ChainType.XtXv) {
            return new byte[]{this._input1, this._input2, this.output};
        }
        return new byte[]{this._input1, this._input2, this._input3, this.output};
    }

    /**
     * Applies the chain operation to every cached block of {@code _input1} and writes each result
     * to the output index. If the output index equals {@code _input1}, {@code tempValue} is used
     * as the result holder and then added to the cache; otherwise a slot is reserved via
     * {@code holdPlace}.
     */
    @Override
    public void processInstruction(Class<? extends MatrixValue> valueClass, CachedValueMap cachedValues, IndexedMatrixValue tempValue, IndexedMatrixValue zeroInput, int blockRowFactor, int blockColFactor) throws DMLRuntimeException {
        ArrayList<IndexedMatrixValue> blkList = cachedValues.get(this._input1);
        if (blkList == null) {
            return;
        }
        for (IndexedMatrixValue imv : blkList) {
            if (imv == null) {
                continue;
            }
            MatrixIndexes inIx = imv.getIndexes();
            MatrixValue inVal = imv.getValue();
            IndexedMatrixValue iout = (this.output == this._input1)
                    ? tempValue
                    : cachedValues.holdPlace(this.output, valueClass);
            MatrixIndexes outIx = iout.getIndexes();
            MatrixValue outVal = iout.getValue();
            if (this._chainType == MapMultChain.ChainType.XtXv) {
                this.processXtXvOperations(inVal, outIx, outVal);
            } else {
                this.processXtwXvOperations(inIx, inVal, outIx, outVal, this._chainType);
            }
            // Only the reused temp holder needs an explicit add; holdPlace already registered its slot.
            if (iout == tempValue) {
                cachedValues.add(this.output, iout);
            }
        }
    }

    /**
     * XtXv chain on one block: {@code v} is taken from block (1,1) of {@code _input2}, no weight
     * block is passed, and the result is placed at index (1,1).
     */
    private void processXtXvOperations(MatrixValue inVal, MatrixIndexes outIx, MatrixValue outVal) throws DMLRuntimeException {
        MatrixBlock Xi = (MatrixBlock) inVal;
        MatrixBlock v = (MatrixBlock) getDistCacheInput(this._input2).getDataBlock(1, 1).getValue();
        Xi.chainMatrixMultOperations(v, null, (MatrixBlock) outVal, MapMultChain.ChainType.XtXv);
        outIx.setIndexes(1L, 1L);
    }

    /**
     * Weighted chain on one block: {@code v} is taken from block (1,1) of {@code _input2}; the
     * weight block {@code w} is the block of {@code _input3} at (row index of the input block, 1).
     * The result is placed at index (1,1).
     */
    private void processXtwXvOperations(MatrixIndexes inIx, MatrixValue inVal, MatrixIndexes outIx, MatrixValue outVal, MapMultChain.ChainType chain) throws DMLRuntimeException {
        DistributedCacheInput dcInput2 = getDistCacheInput(this._input2);
        DistributedCacheInput dcInput3 = getDistCacheInput(this._input3);
        MatrixBlock Xi = (MatrixBlock) inVal;
        MatrixBlock v = (MatrixBlock) dcInput2.getDataBlock(1, 1).getValue();
        MatrixBlock w = (MatrixBlock) dcInput3.getDataBlock((int) inIx.getRowIndex(), 1).getValue();
        Xi.chainMatrixMultOperations(v, w, (MatrixBlock) outVal, chain);
        outIx.setIndexes(1L, 1L);
    }

    /**
     * Looks up a distributed-cache input, failing with a descriptive error instead of an NPE.
     *
     * @param index input index to look up
     * @return the cached input (never null)
     * @throws DMLRuntimeException if no cached input is registered for {@code index}
     */
    private static DistributedCacheInput getDistCacheInput(byte index) throws DMLRuntimeException {
        DistributedCacheInput dcInput = MRBaseForCommonInstructions.dcValues.get(index);
        if (dcInput == null) {
            throw new DMLRuntimeException("MapMultChain: distributed cache input " + index + " is not available.");
        }
        return dcInput;
    }
}

