/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.hops.rewrite;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.hops.DataOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.LeftIndexingOp;
import org.apache.sysml.hops.UnaryOp;
import org.apache.sysml.hops.rewrite.ProgramRewriteStatus;
import org.apache.sysml.hops.rewrite.StatementBlockRewriteRule;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.parser.ForStatement;
import org.apache.sysml.parser.ForStatementBlock;
import org.apache.sysml.parser.IfStatement;
import org.apache.sysml.parser.IfStatementBlock;
import org.apache.sysml.parser.Statement;
import org.apache.sysml.parser.StatementBlock;
import org.apache.sysml.parser.VariableSet;
import org.apache.sysml.parser.WhileStatement;
import org.apache.sysml.parser.WhileStatementBlock;

/**
 * Statement-block rewrite that marks matrix variables of {@link WhileStatementBlock}s and
 * {@link ForStatementBlock}s (including any subclasses) as update-in-place candidates.
 *
 * <p>A variable becomes a candidate of a loop block when all of the following hold:
 * <ul>
 *   <li>it is a matrix;</li>
 *   <li>it is updated by the loop;</li>
 *   <li>it is live-out of the loop;</li>
 *   <li>every statement block in the loop body that reads or updates it passes
 *       {@code rIsApplicableForUpdateInPlace}.</li>
 * </ul>
 * The qualifying names are stored on the loop block via
 * {@link StatementBlock#setUpdateInPlaceVars}. The rule does nothing when the runtime
 * platform is HADOOP or SPARK.
 */
public class RewriteMarkLoopVariablesUpdateInPlace
extends StatementBlockRewriteRule {
    /**
     * This rule only annotates statement blocks; it never splits DAGs.
     *
     * @return always {@code false}
     */
    @Override
    public boolean createsSplitDag() {
        return false;
    }

    /**
     * Computes and stores the update-in-place candidates of a single loop statement block.
     *
     * <p>Non-loop blocks, and all blocks under the HADOOP or SPARK runtime platforms, are
     * returned without modification.
     *
     * @param sb the statement block to inspect
     * @param status rewrite status (not used by this rule)
     * @return a single-element list containing {@code sb}
     * @throws HopsException propagated from inspecting the loop body
     */
    @Override
    public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus status) throws HopsException {
        if (DMLScript.rtplatform == DMLScript.RUNTIME_PLATFORM.HADOOP || DMLScript.rtplatform == DMLScript.RUNTIME_PLATFORM.SPARK) {
            // nothing to do for these platforms, return the original statement block
            return Arrays.asList(sb);
        }
        if (!(sb instanceof WhileStatementBlock || sb instanceof ForStatementBlock)) {
            return Arrays.asList(sb);
        }
        // the loop body is the same for every variable, so resolve it once
        List<StatementBlock> body = (sb instanceof WhileStatementBlock)
            ? ((WhileStatement) sb.getStatement(0)).getBody()
            : ((ForStatement) sb.getStatement(0)).getBody();

        ArrayList<String> candidates = new ArrayList<String>();
        VariableSet updated = sb.variablesUpdated();
        VariableSet liveout = sb.liveOut();
        for (String varname : updated.getVariableNames()) {
            // only matrices that are live after the loop; loop-local variables are excluded
            if (updated.getVariable(varname).getDataType() == Expression.DataType.MATRIX
                && liveout.containsVariable(varname)
                && rIsApplicableForUpdateInPlace(body, varname)) {
                candidates.add(varname);
            }
        }
        sb.setUpdateInPlaceVars(candidates);
        return Arrays.asList(sb);
    }

    /**
     * Recursively checks whether every statement block in {@code sbs} that reads or updates
     * {@code varname} permits update-in-place of that variable.
     *
     * <ul>
     *   <li>Nested while/for blocks qualify only if they already list {@code varname} in their
     *       update-in-place variables.</li>
     *   <li>Nested if blocks qualify only if the if-body and, when present, the else-body
     *       qualify.</li>
     *   <li>Any other block qualifies only if each of its root hops passes
     *       {@link #isApplicableForUpdateInPlace(Hop, String)}.</li>
     * </ul>
     *
     * @param sbs the statement blocks to check
     * @param varname the candidate variable name
     * @return {@code true} if no block invalidates update-in-place of {@code varname}
     * @throws HopsException propagated from the statement-block accessors
     */
    private static boolean rIsApplicableForUpdateInPlace(List<StatementBlock> sbs, String varname) throws HopsException {
        for (StatementBlock sb : sbs) {
            // blocks that neither read nor update the variable cannot invalidate it
            if (!sb.variablesRead().containsVariable(varname) && !sb.variablesUpdated().containsVariable(varname)) {
                continue;
            }
            boolean applicable;
            if (sb instanceof WhileStatementBlock || sb instanceof ForStatementBlock) {
                // NOTE(review): relies on the nested loop having been marked already
                // (presumably children are rewritten first; confirm). An unset (null) list is
                // conservatively treated as "not applicable".
                List<String> nestedVars = sb.getUpdateInPlaceVars();
                applicable = nestedVars != null && nestedVars.contains(varname);
            } else if (sb instanceof IfStatementBlock) {
                IfStatement istmt = (IfStatement) sb.getStatement(0);
                applicable = rIsApplicableForUpdateInPlace(istmt.getIfBody(), varname)
                    && (istmt.getElseBody() == null || rIsApplicableForUpdateInPlace(istmt.getElseBody(), varname));
            } else {
                applicable = true;
                if (sb.getHops() != null) {
                    for (Hop hop : sb.getHops()) {
                        if (!isApplicableForUpdateInPlace(hop, varname)) {
                            applicable = false;
                            break;
                        }
                    }
                }
            }
            if (!applicable) {
                return false; // early abort: one invalid block disqualifies the variable
            }
        }
        return true;
    }

    /**
     * Checks a single root hop.
     *
     * <p>A root whose name differs from {@code varname} is accepted. A root named
     * {@code varname} is accepted only if all of the following hold:
     * <ul>
     *   <li>it is a matrix {@link DataOp};</li>
     *   <li>its first input is a matrix {@link LeftIndexingOp};</li>
     *   <li>the first input of that {@link LeftIndexingOp} is a {@link DataOp} named
     *       {@code varname};</li>
     *   <li>every parent of that {@link DataOp} is either the {@link LeftIndexingOp} itself
     *       or an {@code NROW}/{@code NCOL} {@link UnaryOp}.</li>
     * </ul>
     *
     * @param hop the root hop to check
     * @param varname the candidate variable name
     * @return {@code true} if this root does not invalidate update-in-place of {@code varname}
     */
    private static boolean isApplicableForUpdateInPlace(Hop hop, String varname) {
        // NOTE(review): any root with a different name is accepted, even if it reads varname or
        // (e.g. a function call) may write it under another name; confirm this is intended.
        if (!hop.getName().equals(varname)) {
            return true;
        }
        // guard against roots without inputs, which previously raised IndexOutOfBoundsException
        if (!(hop instanceof DataOp) || !hop.isMatrix() || hop.getInput().isEmpty()) {
            return false;
        }
        Hop lix = hop.getInput().get(0);
        if (!lix.isMatrix() || !(lix instanceof LeftIndexingOp) || lix.getInput().isEmpty()) {
            return false;
        }
        Hop target = lix.getInput().get(0);
        if (!(target instanceof DataOp) || !target.getName().equals(varname)) {
            return false;
        }
        // the indexed matrix may only be consumed by this left indexing or by nrow/ncol
        for (Hop p : target.getParent()) {
            boolean safe = p == lix
                || p instanceof UnaryOp && ((UnaryOp) p).getOp() == Hop.OpOp1.NROW
                || p instanceof UnaryOp && ((UnaryOp) p).getOp() == Hop.OpOp1.NCOL;
            if (!safe) {
                return false;
            }
        }
        return true;
    }

    /**
     * Multi-block variant; this rule does not rewrite sequences of statement blocks.
     *
     * @param sbs the statement blocks
     * @param status rewrite status (not used by this rule)
     * @return {@code sbs}, unchanged
     * @throws HopsException never thrown here; declared by the overridden method
     */
    @Override
    public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus status) throws HopsException {
        return sbs;
    }
}

