/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.hops.rewrite;

import java.io.IOException;
import java.util.ArrayList;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.hops.BinaryOp;
import org.apache.sysml.hops.DataOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.hops.UnaryOp;
import org.apache.sysml.hops.recompile.Recompiler;
import org.apache.sysml.hops.rewrite.HopRewriteRule;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.hops.rewrite.ProgramRewriteStatus;
import org.apache.sysml.lops.Lop;
import org.apache.sysml.lops.LopsException;
import org.apache.sysml.lops.compile.Dag;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.Program;
import org.apache.sysml.runtime.controlprogram.ProgramBlock;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.cp.ScalarObject;

/**
 * Hop rewrite rule that folds constant scalar subexpressions into {@link LiteralOp}s.
 *
 * <p>Three kinds of folding are applied bottom-up over each hop DAG:
 * <ul>
 *   <li>scalar binary ops whose two inputs are literals (except cbind/rbind) and scalar unary
 *       ops whose input is a literal (except print/stop) are evaluated by compiling them to
 *       runtime instructions and executing those in a private program block;</li>
 *   <li>an AND with at least one literal {@code false} input becomes literal {@code false};</li>
 *   <li>an OR with at least one literal {@code true} input becomes literal {@code true}.</li>
 * </ul>
 *
 * <p>NOTE(review): the temporary program block and execution context are lazily created,
 * cached per instance and reused without synchronization; presumably a rule instance is
 * only used by one thread at a time - confirm against callers.
 */
public class RewriteConstantFolding extends HopRewriteRule {
    /** Variable name the evaluated scalar is written to during instruction-based folding. */
    private static final String TMP_VARNAME = "__cf_tmp";

    /** Lazily created program block reused to execute folding instructions. */
    private ProgramBlock _tmpPB = null;
    /** Lazily created execution context reused to execute folding instructions. */
    private ExecutionContext _tmpEC = null;

    /**
     * Applies constant folding to every DAG root in {@code roots}, in place.
     *
     * @param roots DAG roots; each entry may be replaced by a literal if the root itself folds
     * @param state rewrite status (unused by this rule)
     * @return the same list instance, or {@code null} if {@code roots} is {@code null}
     * @throws HopsException if inspecting a hop fails
     */
    @Override
    public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) throws HopsException {
        if (roots == null) {
            return null;
        }
        for (int i = 0; i < roots.size(); ++i) {
            roots.set(i, this.rule_ConstantFolding(roots.get(i)));
        }
        return roots;
    }

    /**
     * Applies constant folding to a single DAG.
     *
     * @param root DAG root, may be {@code null}
     * @param state rewrite status (unused by this rule)
     * @return the (possibly replaced) root, or {@code null} if {@code root} is {@code null}
     * @throws HopsException if inspecting a hop fails
     */
    @Override
    public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) throws HopsException {
        if (root == null) {
            return null;
        }
        return this.rule_ConstantFolding(root);
    }

    private Hop rule_ConstantFolding(Hop hop) throws HopsException {
        return this.rConstantFoldingExpression(hop);
    }

    /**
     * Recursively folds constants below and at {@code root}, children first.
     *
     * @return {@code root} itself, or the replacing literal if {@code root} has no parents
     *         (i.e. it is a DAG root) and was folded
     */
    private Hop rConstantFoldingExpression(Hop root) throws HopsException {
        if (root.isVisited()) {
            return root;
        }
        // Index loop on purpose: folding a child replaces it inside root's input list,
        // which would invalidate an iterator (ConcurrentModificationException).
        for (int i = 0; i < root.getInput().size(); ++i) {
            this.rConstantFoldingExpression(root.getInput().get(i));
        }

        LiteralOp literal = null;
        if (root.getDataType() == Expression.DataType.SCALAR
            && (this.isApplicableBinaryOp(root) || this.isApplicableUnaryOp(root))) {
            try {
                literal = this.evalScalarOperation(root);
            }
            catch (Exception ex) {
                // Best effort: a failed fold leaves the expression unchanged.
                LOG.error("Failed to execute constant folding instructions. No abort.", ex);
            }
        }
        else if (this.isApplicableFalseConjunctivePredicate(root)) {
            literal = new LiteralOp(false);
        }
        else if (this.isApplicableTrueDisjunctivePredicate(root)) {
            literal = new LiteralOp(true);
        }

        if (literal != null) {
            Hop folded = root;
            if (folded.getParent().isEmpty()) {
                // folded hop is a DAG root: hand the literal back to the caller
                root = literal;
            }
            else {
                replaceInParents(folded, literal);
            }
            // the folded hop is now unreachable; drop its stale parent links from its inputs
            detachFromInputs(folded);
        }
        root.setVisited();
        return root;
    }

    /** Rewires every parent of {@code oldHop} to reference {@code literal} at the same input positions. */
    private static void replaceInParents(Hop oldHop, LiteralOp literal) {
        ArrayList<Hop> parents = oldHop.getParent();
        for (int i = 0; i < parents.size(); ++i) {
            Hop parent = parents.get(i);
            ArrayList<Hop> inputs = parent.getInput();
            for (int j = 0; j < inputs.size(); ++j) {
                if (inputs.get(j) == oldHop) {
                    inputs.remove(j);
                    HopRewriteUtils.addChildReference(parent, literal, j);
                }
            }
        }
        parents.clear();
    }

    /** Removes {@code hop} from the parent lists of all its inputs and clears its input list. */
    private static void detachFromInputs(Hop hop) {
        for (Hop input : hop.getInput()) {
            // remove every occurrence, since the same input may be referenced more than once
            while (input.getParent().remove(hop)) {
                // keep removing
            }
        }
        hop.getInput().clear();
    }

    /**
     * Evaluates the scalar hop {@code bop} by compiling it (wrapped in a temporary transient
     * write) to runtime instructions and executing them.
     *
     * <p>The temporary write, the loaded instructions and the execution-context variables are
     * removed in a {@code finally} block. This matters because the caller swallows exceptions
     * and keeps using {@code bop}.
     *
     * @return a literal of the observed scalar result type, with scalar output parameters set
     * @throws HopsException if no scalar result is produced or its value type is unsupported
     */
    private LiteralOp evalScalarOperation(Hop bop) throws LopsException, DMLRuntimeException, IOException, HopsException {
        // Acquire reusable runtime objects first, so a failure here cannot leave the
        // temporary write attached to bop.
        ExecutionContext ec = this.getExecutionContext();
        ProgramBlock pb = this.getProgramBlock();

        // Constructing the write registers it as an additional parent of bop.
        DataOp tmpWrite = new DataOp(TMP_VARNAME, bop.getDataType(), bop.getValueType(), bop,
            Hop.DataOpTypes.TRANSIENTWRITE, TMP_VARNAME);
        LiteralOp literal;
        try {
            Dag<Lop> dag = new Dag<Lop>();
            Recompiler.rClearLops(tmpWrite); // force fresh lops instead of reusing existing ones
            Lop lops = tmpWrite.constructLops();
            lops.addToDag(dag);
            ArrayList<Instruction> inst = dag.getJobs(null, ConfigurationManager.getDMLConfig());

            pb.setInstructions(inst);
            pb.execute(ec);

            Object result = ec.getVariable(TMP_VARNAME);
            if (!(result instanceof ScalarObject)) {
                throw new HopsException("Constant folding produced no scalar result in variable " + TMP_VARNAME + ".");
            }
            literal = createLiteral((ScalarObject) result);
        }
        finally {
            detachFromInputs(tmpWrite);
            pb.setInstructions(null);
            ec.getVariables().removeAll();
        }
        HopRewriteUtils.setOutputParametersForScalar(literal);
        return literal;
    }

    /** Creates a literal matching the runtime value type of {@code so}. */
    private static LiteralOp createLiteral(ScalarObject so) throws HopsException {
        switch (so.getValueType()) {
            case DOUBLE:
                return new LiteralOp(so.getDoubleValue());
            case INT:
                return new LiteralOp(so.getLongValue());
            case BOOLEAN:
                return new LiteralOp(so.getBooleanValue());
            case STRING:
                return new LiteralOp(so.getStringValue());
            default:
                // report the type actually switched on (the observed runtime type)
                throw new HopsException("Unsupported literal value type: " + so.getValueType());
        }
    }

    private ProgramBlock getProgramBlock() throws DMLRuntimeException {
        if (this._tmpPB == null) {
            this._tmpPB = new ProgramBlock(new Program());
        }
        return this._tmpPB;
    }

    private ExecutionContext getExecutionContext() {
        if (this._tmpEC == null) {
            this._tmpEC = ExecutionContextFactory.createContext();
        }
        return this._tmpEC;
    }

    /** True for a binary op with two literal inputs that is neither cbind nor rbind. */
    private boolean isApplicableBinaryOp(Hop hop) {
        if (!(hop instanceof BinaryOp)) {
            return false;
        }
        ArrayList<Hop> in = hop.getInput();
        Hop.OpOp2 op = ((BinaryOp) hop).getOp();
        return in.get(0) instanceof LiteralOp && in.get(1) instanceof LiteralOp
            && op != Hop.OpOp2.CBIND && op != Hop.OpOp2.RBIND;
    }

    /** True for a scalar unary op with a literal input that is neither print nor stop. */
    private boolean isApplicableUnaryOp(Hop hop) {
        if (!(hop instanceof UnaryOp)) {
            return false;
        }
        Hop.OpOp1 op = ((UnaryOp) hop).getOp();
        return hop.getInput().get(0) instanceof LiteralOp
            && op != Hop.OpOp1.PRINT && op != Hop.OpOp1.STOP
            && hop.getDataType() == Expression.DataType.SCALAR;
    }

    /** True for an AND where at least one input is the literal {@code false}. */
    private boolean isApplicableFalseConjunctivePredicate(Hop hop) throws HopsException {
        if (!HopRewriteUtils.isBinary(hop, Hop.OpOp2.AND)) {
            return false;
        }
        ArrayList<Hop> in = hop.getInput();
        return isBooleanLiteral(in.get(0), false) || isBooleanLiteral(in.get(1), false);
    }

    /** True for an OR where at least one input is the literal {@code true}. */
    private boolean isApplicableTrueDisjunctivePredicate(Hop hop) throws HopsException {
        if (!HopRewriteUtils.isBinary(hop, Hop.OpOp2.OR)) {
            return false;
        }
        ArrayList<Hop> in = hop.getInput();
        return isBooleanLiteral(in.get(0), true) || isBooleanLiteral(in.get(1), true);
    }

    /** True if {@code hop} is a literal whose boolean value equals {@code value}. */
    private static boolean isBooleanLiteral(Hop hop, boolean value) throws HopsException {
        return hop instanceof LiteralOp && ((LiteralOp) hop).getBooleanValue() == value;
    }
}

