/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.instructions.spark;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.sysml.lops.BinaryM;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.cp.ScalarObject;
import org.apache.sysml.runtime.instructions.spark.BinaryMatrixBVectorSPInstruction;
import org.apache.sysml.runtime.instructions.spark.BinaryMatrixMatrixSPInstruction;
import org.apache.sysml.runtime.instructions.spark.BinaryMatrixScalarSPInstruction;
import org.apache.sysml.runtime.instructions.spark.ComputationSPInstruction;
import org.apache.sysml.runtime.instructions.spark.SPInstruction;
import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast;
import org.apache.sysml.runtime.instructions.spark.functions.MatrixMatrixBinaryOpFunction;
import org.apache.sysml.runtime.instructions.spark.functions.MatrixScalarUnaryFunction;
import org.apache.sysml.runtime.instructions.spark.functions.MatrixVectorBinaryOpPartitionFunction;
import org.apache.sysml.runtime.instructions.spark.functions.OuterVectorBinaryOpFunction;
import org.apache.sysml.runtime.instructions.spark.functions.ReplicateVectorFunction;
import org.apache.sysml.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
import org.apache.sysml.runtime.matrix.operators.Operator;
import org.apache.sysml.runtime.matrix.operators.ScalarOperator;

public abstract class BinarySPInstruction
extends ComputationSPInstruction {
    /**
     * Creates a binary Spark instruction. All arguments are passed through unchanged to
     * {@link ComputationSPInstruction}.
     *
     * @param type the Spark instruction type
     * @param op the operator applied to the two inputs
     * @param in1 the first input operand
     * @param in2 the second input operand
     * @param out the output operand
     * @param opcode the instruction opcode
     * @param istr the original instruction string
     */
    protected BinarySPInstruction(SPInstruction.SPType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
        super(type, op, in1, in2, out, opcode, istr);
    }

    /**
     * Parses a serialized binary Spark instruction and instantiates the matching subclass.
     *
     * <p>Some instruction texts are treated as broadcast instructions: those that start with
     * {@code SPARK}, then the U+00B0 operand delimiter, then {@code map}. A broadcast instruction
     * must carry five operand fields:
     * <ul>
     *   <li>fields 1-3 are {@code in1}, {@code in2} and {@code out};</li>
     *   <li>field 4 is not read here;</li>
     *   <li>field 5 names the {@link BinaryM.VectorType}. An unknown name makes {@code valueOf}
     *       throw {@link IllegalArgumentException}.</li>
     * </ul>
     * Every other instruction is parsed as {@code opcode in1 in2 out}.
     *
     * @param str the serialized instruction
     * @return one of the following:
     *         <ul>
     *           <li>a {@link BinaryMatrixBVectorSPInstruction} for a broadcast instruction with
     *               two matrix inputs;</li>
     *           <li>a {@link BinaryMatrixMatrixSPInstruction} for any other instruction with two
     *               matrix inputs;</li>
     *           <li>a {@link BinaryMatrixScalarSPInstruction} when exactly one input is a
     *               matrix;</li>
     *           <li>{@code null} when neither input is a matrix.</li>
     *         </ul>
     * @throws DMLRuntimeException propagated from the field-count check and operator parsing
     */
    public static BinarySPInstruction parseInstruction(String str) throws DMLRuntimeException {
        CPOperand in1 = new CPOperand("", Expression.ValueType.UNKNOWN, Expression.DataType.UNKNOWN);
        CPOperand in2 = new CPOperand("", Expression.ValueType.UNKNOWN, Expression.DataType.UNKNOWN);
        CPOperand out = new CPOperand("", Expression.ValueType.UNKNOWN, Expression.DataType.UNKNOWN);
        String opcode = null;
        boolean isBroadcast = false;
        BinaryM.VectorType vtype = null;
        if (str.startsWith("SPARK\u00b0map")) {
            String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
            InstructionUtils.checkNumFields(parts, 5);
            opcode = parts[0];
            in1.split(parts[1]);
            in2.split(parts[2]);
            out.split(parts[3]);
            vtype = BinaryM.VectorType.valueOf(parts[5]);
            isBroadcast = true;
        } else {
            opcode = BinarySPInstruction.parseBinaryInstruction(str, in1, in2, out);
        }
        Expression.DataType dt1 = in1.getDataType();
        Expression.DataType dt2 = in2.getDataType();
        Operator operator = InstructionUtils.parseExtendedBinaryOrBuiltinOperator(opcode, in1, in2);
        if (dt1 == Expression.DataType.MATRIX || dt2 == Expression.DataType.MATRIX) {
            if (dt1 == Expression.DataType.MATRIX && dt2 == Expression.DataType.MATRIX) {
                if (isBroadcast) {
                    return new BinaryMatrixBVectorSPInstruction(operator, in1, in2, out, vtype, opcode, str);
                }
                return new BinaryMatrixMatrixSPInstruction(operator, in1, in2, out, opcode, str);
            }
            return new BinaryMatrixScalarSPInstruction(operator, in1, in2, out, opcode, str);
        }
        // Neither input is a matrix: this parser produces no instruction.
        return null;
    }

    /**
     * Parses an instruction of the form {@code opcode in1 in2 out} into the given operands.
     *
     * @param instr the serialized instruction
     * @param in1 receives the first input operand
     * @param in2 receives the second input operand
     * @param out receives the output operand
     * @return the opcode (field 0)
     * @throws DMLRuntimeException propagated from the field-count check
     */
    protected static String parseBinaryInstruction(String instr, CPOperand in1, CPOperand in2, CPOperand out) throws DMLRuntimeException {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(instr);
        InstructionUtils.checkNumFields(parts, 3);
        String opcode = parts[0];
        in1.split(parts[1]);
        in2.split(parts[2]);
        out.split(parts[3]);
        return opcode;
    }

    /**
     * Parses an instruction of the form {@code opcode in1 in2 in3 out} into the given operands.
     *
     * @param instr the serialized instruction
     * @param in1 receives the first input operand
     * @param in2 receives the second input operand
     * @param in3 receives the third input operand
     * @param out receives the output operand
     * @return the opcode (field 0)
     * @throws DMLRuntimeException propagated from the field-count check
     */
    protected static String parseBinaryInstruction(String instr, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out) throws DMLRuntimeException {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(instr);
        InstructionUtils.checkNumFields(parts, 4);
        String opcode = parts[0];
        in1.split(parts[1]);
        in2.split(parts[2]);
        in3.split(parts[3]);
        out.split(parts[4]);
        return opcode;
    }

    /**
     * Executes a matrix-matrix binary operation where both inputs are RDDs.
     *
     * <p>First the input characteristics are validated and the output characteristics are
     * updated. A vector input is then replicated via {@link ReplicateVectorFunction}, with the
     * replica count taken from {@link #getNumReplicas}. Finally the two RDDs are joined on their
     * keys, and {@link MatrixMatrixBinaryOpFunction} is applied to each joined value pair.
     *
     * @param ec the execution context; cast to {@link SparkExecutionContext}
     * @throws DMLRuntimeException if the input characteristics are unknown or incompatible
     */
    protected void processMatrixMatrixBinaryInstruction(ExecutionContext ec) throws DMLRuntimeException {
        SparkExecutionContext sec = (SparkExecutionContext)ec;
        this.checkMatrixMatrixBinaryCharacteristics(sec);
        // Must run before mcOut is read below, because mcOut feeds the partition-count choice.
        this.updateBinaryOutputMatrixCharacteristics(sec);
        JavaPairRDD in1 = sec.getBinaryBlockRDDHandleForVariable(this.input1.getName());
        JavaPairRDD in2 = sec.getBinaryBlockRDDHandleForVariable(this.input2.getName());
        MatrixCharacteristics mc1 = sec.getMatrixCharacteristics(this.input1.getName());
        MatrixCharacteristics mc2 = sec.getMatrixCharacteristics(this.input2.getName());
        MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(this.output.getName());
        BinaryOperator bop = (BinaryOperator)this._optr;
        boolean rowvector = mc2.getRows() == 1L && mc1.getRows() > 1L;
        long numRepLeft = this.getNumReplicas(mc1, mc2, true);
        long numRepRight = this.getNumReplicas(mc1, mc2, false);
        if (numRepLeft > 1L) {
            in1 = in1.flatMapToPair((PairFlatMapFunction)new ReplicateVectorFunction(false, numRepLeft));
        }
        if (numRepRight > 1L) {
            in2 = in2.flatMapToPair((PairFlatMapFunction)new ReplicateVectorFunction(rowvector, numRepRight));
        }
        // Partition count for the join:
        // - reuse the count of a hash-partitioned input if there is one;
        // - otherwise use the sum of the input partition counts, capped at twice the preferred
        //   partition count of the output.
        int numPrefPart = SparkUtils.isHashPartitioned(in1) ? in1.getNumPartitions() : (SparkUtils.isHashPartitioned(in2) ? in2.getNumPartitions() : Math.min(in1.getNumPartitions() + in2.getNumPartitions(), 2 * SparkUtils.getNumPreferredPartitions(mcOut)));
        JavaPairRDD out = in1.join(in2, numPrefPart).mapValues((Function)new MatrixMatrixBinaryOpFunction(bop));
        sec.setRDDHandleForVariable(this.output.getName(), out);
        sec.addLineageRDD(this.output.getName(), this.input1.getName());
        sec.addLineageRDD(this.output.getName(), this.input2.getName());
    }

    /**
     * Executes a matrix-vector binary operation.
     *
     * <p>{@code input1} is read as an RDD and {@code input2} as a partitioned broadcast. The
     * operation is an outer operation when {@code input1} is a column vector (more than one row,
     * exactly one column) and {@code input2} is a row vector (exactly one row, more than one
     * column):
     * <ul>
     *   <li>outer: each RDD entry is expanded via {@link OuterVectorBinaryOpFunction};</li>
     *   <li>otherwise: the operation runs per partition via
     *       {@link MatrixVectorBinaryOpPartitionFunction}, with partitioning preserved.</li>
     * </ul>
     *
     * @param ec the execution context; cast to {@link SparkExecutionContext}
     * @param vtype the vector type forwarded to the per-partition function
     * @throws DMLRuntimeException if the input characteristics are unknown or incompatible
     */
    protected void processMatrixBVectorBinaryInstruction(ExecutionContext ec, BinaryM.VectorType vtype) throws DMLRuntimeException {
        SparkExecutionContext sec = (SparkExecutionContext)ec;
        this.checkMatrixMatrixBinaryCharacteristics(sec);
        String rddVar = this.input1.getName();
        String bcastVar = this.input2.getName();
        JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable(rddVar);
        PartitionedBroadcast<MatrixBlock> in2 = sec.getBroadcastForVariable(bcastVar);
        MatrixCharacteristics mc1 = sec.getMatrixCharacteristics(rddVar);
        MatrixCharacteristics mc2 = sec.getMatrixCharacteristics(bcastVar);
        BinaryOperator bop = (BinaryOperator)this._optr;
        boolean isOuter = mc1.getRows() > 1L && mc1.getCols() == 1L && mc2.getRows() == 1L && mc2.getCols() > 1L;
        JavaPairRDD out = isOuter
            ? in1.flatMapToPair((PairFlatMapFunction)new OuterVectorBinaryOpFunction(bop, in2))
            : in1.mapPartitionsToPair((PairFlatMapFunction)new MatrixVectorBinaryOpPartitionFunction(bop, in2, vtype), true);
        this.updateBinaryOutputMatrixCharacteristics(sec);
        sec.setRDDHandleForVariable(this.output.getName(), out);
        sec.addLineageRDD(this.output.getName(), rddVar);
        sec.addLineageBroadcast(this.output.getName(), bcastVar);
    }

    /**
     * Executes a matrix-scalar binary operation.
     *
     * <p>The input whose data type is {@code MATRIX} is read as an RDD. The other input is read
     * as a scalar, and its double value is set as the constant of the {@link ScalarOperator}.
     * The operator is then applied to every value of the RDD.
     *
     * @param ec the execution context; cast to {@link SparkExecutionContext}
     * @throws DMLRuntimeException propagated from the scalar lookup and RDD access
     */
    protected void processMatrixScalarBinaryInstruction(ExecutionContext ec) throws DMLRuntimeException {
        SparkExecutionContext sec = (SparkExecutionContext)ec;
        String rddVar = this.input1.getDataType() == Expression.DataType.MATRIX ? this.input1.getName() : this.input2.getName();
        JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable(rddVar);
        CPOperand scalar = this.input1.getDataType() == Expression.DataType.MATRIX ? this.input2 : this.input1;
        ScalarObject constant = ec.getScalarInput(scalar.getName(), scalar.getValueType(), scalar.isLiteral());
        ScalarOperator sc_op = (ScalarOperator)this._optr;
        sc_op = sc_op.setConstant(constant.getDoubleValue());
        JavaPairRDD out = in1.mapValues((Function)new MatrixScalarUnaryFunction(sc_op));
        this.updateUnaryOutputMatrixCharacteristics(sec, rddVar, this.output.getName());
        sec.setRDDHandleForVariable(this.output.getName(), out);
        sec.addLineageRDD(this.output.getName(), rddVar);
    }

    /**
     * Infers the output characteristics from the two inputs if the output dimensions are unknown.
     *
     * <p>The inferred output is {@code rows(input1) x cols(input2)} with the block sizes of
     * {@code input1}. Nothing is changed if the output dimensions are already known.
     *
     * @param sec the Spark execution context
     * @param checkCommonDim whether to require {@code cols(input1) == rows(input2)}
     * @return the (possibly updated) output characteristics
     * @throws DMLRuntimeException if inference is needed but input dimensions are unknown, block
     *         sizes differ, or the common dimension check fails
     */
    protected MatrixCharacteristics updateBinaryMMOutputMatrixCharacteristics(SparkExecutionContext sec, boolean checkCommonDim) throws DMLRuntimeException {
        MatrixCharacteristics mc1 = sec.getMatrixCharacteristics(this.input1.getName());
        MatrixCharacteristics mc2 = sec.getMatrixCharacteristics(this.input2.getName());
        MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(this.output.getName());
        if (!mcOut.dimsKnown()) {
            if (!mc1.dimsKnown() || !mc2.dimsKnown()) {
                throw new DMLRuntimeException("The output dimensions are not specified and cannot be inferred from inputs.");
            }
            if (mc1.getRowsPerBlock() != mc2.getRowsPerBlock() || mc1.getColsPerBlock() != mc2.getColsPerBlock()) {
                throw new DMLRuntimeException("Incompatible block sizes for BinarySPInstruction.");
            }
            if (checkCommonDim && mc1.getCols() != mc2.getRows()) {
                throw new DMLRuntimeException("Incompatible dimensions for BinarySPInstruction");
            }
            mcOut.set(mc1.getRows(), mc2.getCols(), mc1.getRowsPerBlock(), mc1.getColsPerBlock());
        }
        return mcOut;
    }

    /**
     * Infers the output characteristics of an append (cbind/rbind) from its two inputs.
     *
     * <p>If the output dimensions are unknown, they are set from the inputs, using the block
     * sizes of {@code input1}:
     * <ul>
     *   <li>cbind: column counts are summed;</li>
     *   <li>rbind: row counts are summed.</li>
     * </ul>
     * Separately, if the output non-zero count is unknown but both input counts are known, it is
     * set to their sum.
     *
     * @param sec the Spark execution context
     * @param cbind {@code true} for column append, {@code false} for row append
     * @throws DMLRuntimeException if output dimensions must be inferred but an input's
     *         dimensions are unknown
     */
    protected void updateBinaryAppendOutputMatrixCharacteristics(SparkExecutionContext sec, boolean cbind) throws DMLRuntimeException {
        MatrixCharacteristics mc1 = sec.getMatrixCharacteristics(this.input1.getName());
        MatrixCharacteristics mc2 = sec.getMatrixCharacteristics(this.input2.getName());
        MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(this.output.getName());
        if (!mcOut.dimsKnown()) {
            if (!mc1.dimsKnown() || !mc2.dimsKnown()) {
                throw new DMLRuntimeException("The output dimensions are not specified and cannot be inferred from inputs.");
            }
            if (cbind) {
                mcOut.set(mc1.getRows(), mc1.getCols() + mc2.getCols(), mc1.getRowsPerBlock(), mc1.getColsPerBlock());
            } else {
                mcOut.set(mc1.getRows() + mc2.getRows(), mc1.getCols(), mc1.getRowsPerBlock(), mc1.getColsPerBlock());
            }
        }
        if (!mcOut.nnzKnown() && mc1.nnzKnown() && mc2.nnzKnown()) {
            mcOut.setNonZeros(mc1.getNonZeros() + mc2.getNonZeros());
        }
    }

    /**
     * Returns how many replicas of the left or right input a matrix-matrix operation needs.
     *
     * <ul>
     *   <li>Left ({@code left == true}): if {@code mc1} has exactly one column, the result is the
     *       number of column blocks of {@code mc2}.</li>
     *   <li>Right: if {@code mc2} has one row and {@code mc1} more than one, the result is the
     *       number of row blocks of {@code mc1}.</li>
     *   <li>Right: otherwise, if {@code mc2} has one column and {@code mc1} more than one, the
     *       result is the number of column blocks of {@code mc1}.</li>
     * </ul>
     * In all other cases the result is 1.
     *
     * @param mc1 characteristics of the left input
     * @param mc2 characteristics of the right input
     * @param left {@code true} to compute the left replication count, {@code false} for the right
     * @return the replication count, at least 1 for valid block sizes
     */
    protected long getNumReplicas(MatrixCharacteristics mc1, MatrixCharacteristics mc2, boolean left) {
        if (left) {
            if (mc1.getCols() == 1L) {
                return (long)Math.ceil((double)mc2.getCols() / (double)mc2.getColsPerBlock());
            }
        } else {
            if (mc2.getRows() == 1L && mc1.getRows() > 1L) {
                return (long)Math.ceil((double)mc1.getRows() / (double)mc1.getRowsPerBlock());
            }
            if (mc2.getCols() == 1L && mc1.getCols() > 1L) {
                return (long)Math.ceil((double)mc1.getCols() / (double)mc1.getColsPerBlock());
            }
        }
        return 1L;
    }

    /**
     * Validates the input characteristics of a matrix-matrix binary operation.
     *
     * <p>Both inputs must have known dimensions and equal block sizes. Their shapes must match
     * one of the following:
     * <ul>
     *   <li>identical dimensions;</li>
     *   <li>equal row counts with {@code input2} a column vector;</li>
     *   <li>equal column counts with {@code input2} a row vector;</li>
     *   <li>{@code input1} a column vector and {@code input2} a row vector.</li>
     * </ul>
     *
     * @param sec the Spark execution context
     * @throws DMLRuntimeException if any of these conditions is violated
     */
    protected void checkMatrixMatrixBinaryCharacteristics(SparkExecutionContext sec) throws DMLRuntimeException {
        MatrixCharacteristics mc1 = sec.getMatrixCharacteristics(this.input1.getName());
        MatrixCharacteristics mc2 = sec.getMatrixCharacteristics(this.input2.getName());
        if (!mc1.dimsKnown() || !mc2.dimsKnown()) {
            throw new DMLRuntimeException("Unknown dimensions matrix-matrix binary operations: [" + mc1.getRows() + "x" + mc1.getCols() + " vs " + mc2.getRows() + "x" + mc2.getCols() + "]");
        }
        if (!(mc1.getRows() == mc2.getRows() && mc1.getCols() == mc2.getCols() || mc1.getRows() == mc2.getRows() && mc2.getCols() == 1L || mc1.getCols() == mc2.getCols() && mc2.getRows() == 1L || mc1.getCols() == 1L && mc2.getRows() == 1L)) {
            throw new DMLRuntimeException("Dimensions mismatch matrix-matrix binary operations: [" + mc1.getRows() + "x" + mc1.getCols() + " vs " + mc2.getRows() + "x" + mc2.getCols() + "]");
        }
        if (mc1.getRowsPerBlock() != mc2.getRowsPerBlock() || mc1.getColsPerBlock() != mc2.getColsPerBlock()) {
            throw new DMLRuntimeException("Blocksize mismatch matrix-matrix binary operations: [" + mc1.getRowsPerBlock() + "x" + mc1.getColsPerBlock() + " vs " + mc2.getRowsPerBlock() + "x" + mc2.getColsPerBlock() + "]");
        }
    }

    /**
     * Validates the input characteristics of an append (cbind/rbind) operation.
     *
     * <p>Both inputs must have known dimensions and equal block sizes. cbind requires equal row
     * counts; rbind requires equal column counts.
     *
     * @param sec the Spark execution context
     * @param cbind {@code true} for column append, {@code false} for row append
     * @param checkSingleBlk if set, the combined column count may not exceed one column block.
     *        NOTE(review): this inspects columns regardless of {@code cbind}; confirm it is never
     *        requested for rbind.
     * @param checkAligned if set, {@code input1} must end on a block boundary along the append
     *        dimension: columns for cbind, rows for rbind
     * @throws DMLRuntimeException if any requested check fails
     */
    protected void checkBinaryAppendInputCharacteristics(SparkExecutionContext sec, boolean cbind, boolean checkSingleBlk, boolean checkAligned) throws DMLRuntimeException {
        MatrixCharacteristics mc1 = sec.getMatrixCharacteristics(this.input1.getName());
        MatrixCharacteristics mc2 = sec.getMatrixCharacteristics(this.input2.getName());
        if (!mc1.dimsKnown() || !mc2.dimsKnown()) {
            throw new DMLRuntimeException("The dimensions unknown for inputs");
        }
        if (cbind && mc1.getRows() != mc2.getRows()) {
            throw new DMLRuntimeException("The number of rows of inputs should match for append-cbind instruction");
        }
        if (!cbind && mc1.getCols() != mc2.getCols()) {
            throw new DMLRuntimeException("The number of columns of inputs should match for append-rbind instruction");
        }
        if (mc1.getRowsPerBlock() != mc2.getRowsPerBlock() || mc1.getColsPerBlock() != mc2.getColsPerBlock()) {
            throw new DMLRuntimeException("The block sizes donot match for input matrices");
        }
        if (checkSingleBlk && mc1.getCols() + mc2.getCols() > (long)mc1.getColsPerBlock()) {
            throw new DMLRuntimeException("Output must have at most one column block");
        }
        if (checkAligned) {
            // Alignment matters along the dimension being appended: columns for cbind, rows for
            // rbind. The previous check always used columns, which was wrong for rbind.
            long appendDimLen = cbind ? mc1.getCols() : mc1.getRows();
            long appendDimBlk = cbind ? (long)mc1.getColsPerBlock() : (long)mc1.getRowsPerBlock();
            if (appendDimLen % appendDimBlk != 0L) {
                throw new DMLRuntimeException("Input matrices are not aligned to blocksize boundaries. Wrong append selected");
            }
        }
    }
}

