/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.hops.rewrite;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import org.apache.sysml.hops.FunctionOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.hops.rewrite.ProgramRewriteStatus;
import org.apache.sysml.hops.rewrite.ProgramRewriter;
import org.apache.sysml.hops.rewrite.RewriteCommonSubexpressionElimination;
import org.apache.sysml.hops.rewrite.StatementBlockRewriteRule;
import org.apache.sysml.parser.ForStatementBlock;
import org.apache.sysml.parser.FunctionStatementBlock;
import org.apache.sysml.parser.IfStatementBlock;
import org.apache.sysml.parser.StatementBlock;
import org.apache.sysml.parser.VariableSet;
import org.apache.sysml.parser.WhileStatementBlock;

/**
 * Statement-block rewrite that merges adjacent "last-level" statement blocks (blocks that are
 * not function, while, if, or for blocks) by splicing the hop DAG of the first block into the
 * hop DAG of the second block.
 *
 * <p>Pairs are merged one at a time and the scan is restarted after every merge, until no
 * further pair qualifies. After each merge the combined DAG is passed through common
 * subexpression elimination.
 */
public class RewriteMergeBlockSequence
extends StatementBlockRewriteRule {
    // Rewriter applied to every merged DAG; it only contains the CSE rewrite (the boolean
    // constructor argument is passed through unchanged from the original implementation).
    private final ProgramRewriter rewriter = new ProgramRewriter(new RewriteCommonSubexpressionElimination(true));

    /**
     * Single-block entry point. This rule only operates on sequences of blocks, so the given
     * block is returned unchanged, wrapped in a list.
     *
     * @param sb    the statement block
     * @param state rewrite status (unused)
     * @return a fixed-size list containing only {@code sb}
     */
    @Override
    public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) throws HopsException {
        return Arrays.asList(sb);
    }

    /**
     * Repeatedly merges the first mergeable pair of adjacent blocks in the sequence until no
     * pair qualifies any more. When blocks {@code sb1, sb2} are merged, {@code sb2} is modified
     * in place to represent both blocks and {@code sb1} is removed from the result.
     *
     * @param sbs  the statement block sequence; returned as-is when null or empty
     * @param sate rewrite status (unused)
     * @return a new list holding the (possibly merged) statement blocks
     * @throws HopsException if hop inspection or the CSE rewrite fails
     */
    @Override
    public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus sate) throws HopsException {
        if (sbs == null || sbs.isEmpty()) {
            return sbs;
        }
        ArrayList<StatementBlock> tmpList = new ArrayList<StatementBlock>(sbs);
        boolean merged = true;
        while (merged) {
            merged = false;
            for (int i = 0; i < tmpList.size() - 1; ++i) {
                StatementBlock sb1 = tmpList.get(i);
                StatementBlock sb2 = tmpList.get(i + 1);
                if (!isMergeCandidate(sb1, sb2)) {
                    continue;
                }
                mergeHopDags(sb1, sb2);
                mergeVariableSets(sb1, sb2);
                // logged before sb2's begin position is overwritten below
                LOG.debug("Applied mergeStatementBlockSequences (blocks of lines " + sb1.getBeginLine() + "-" + sb1.getEndLine() + " and " + sb2.getBeginLine() + "-" + sb2.getEndLine() + ").");
                sb2.setBeginLine(sb1.getBeginLine());
                sb2.setBeginColumn(sb1.getBeginColumn());
                // sb2 now stands for both blocks; drop sb1 and restart the scan from the beginning
                tmpList.remove(i);
                merged = true;
                break;
            }
        }
        return tmpList;
    }

    /**
     * Decides whether two adjacent blocks may be merged: both must be last-level blocks with
     * hop DAGs, neither may be marked as a split DAG, and the first must not have a function
     * call as a DAG root.
     */
    private static boolean isMergeCandidate(StatementBlock sb1, StatementBlock sb2) throws HopsException {
        return isLastLevelStatementBlock(sb1) && isLastLevelStatementBlock(sb2)
            && !hasFunctionOpRoot(sb1)
            && !sb1.isSplitDag() && !sb2.isSplitDag()
            // merging requires both hop lists; a missing list would otherwise cause an NPE below
            && sb1.get_hops() != null && sb2.get_hops() != null;
    }

    /**
     * Moves the hop DAG roots of {@code sb1} into {@code sb2}'s root list. Each transient write
     * in sb1 whose variable sb2 reads is wired directly into sb2's consumers. The combined DAG
     * is then passed through CSE.
     */
    private void mergeHopDags(StatementBlock sb1, StatementBlock sb2) throws HopsException {
        ArrayList<Hop> sb1Hops = sb1.get_hops();
        ArrayList<Hop> sb2Hops = sb2.get_hops();

        // collect the transient reads and writes (incl. function outputs) of sb2
        Hop.resetVisitStatus(sb2Hops);
        HashMap<String, Hop> treads = new HashMap<String, Hop>();
        HashMap<String, Hop> twrites = new HashMap<String, Hop>();
        for (Hop root : sb2Hops) {
            rCollectTransientReadWrites(root, treads, twrites);
        }
        Hop.resetVisitStatus(sb2Hops);
        Hop.resetVisitStatus(sb1Hops);

        for (Hop root : sb1Hops) {
            boolean isTWrite = HopRewriteUtils.isData(root, Hop.DataOpTypes.TRANSIENTWRITE);
            if (isTWrite && treads.containsKey(root.getName())) {
                // sb2 reads what sb1 writes: feed sb1's value straight into sb2's consumers
                Hop tread = treads.get(root.getName());
                Hop in = root.getInput().get(0);
                for (Hop parent : new ArrayList<Hop>(tread.getParent())) {
                    HopRewriteUtils.replaceChildReference(parent, tread, in);
                }
                HopRewriteUtils.removeAllChildReferences(root);
                // keep the write only if sb2 does not overwrite it and it is live after sb2
                if (!twrites.containsKey(root.getName()) && sb2.liveOut().containsVariable(root.getName())) {
                    sb2Hops.add(HopRewriteUtils.createDataOp(root.getName(), in, Hop.DataOpTypes.TRANSIENTWRITE));
                }
            } else if (isTWrite && twrites.containsKey(root.getName())) {
                // sb2 overwrites this variable without reading it: drop sb1's write and detach
                // it from its input so no stale parent reference remains
                HopRewriteUtils.removeAllChildReferences(root);
            } else {
                sb2Hops.add(root);
            }
        }
        Hop.resetVisitStatus(sb2Hops);
        this.rewriter.rewriteHopDAGs(sb2Hops, new ProgramRewriteStatus());
    }

    /**
     * Updates the live-in, gen, kill, read and updated variable sets of {@code sb2} so that
     * they describe the sequential composition {@code sb1; sb2}. The live-out set of sb2 is
     * left unchanged.
     */
    private static void mergeVariableSets(StatementBlock sb1, StatementBlock sb2) throws HopsException {
        sb2.setLiveIn(sb1.liveIn());
        // upward-exposed reads of the sequence: all of sb1's, plus sb2's not defined earlier by sb1
        // (subtracting kill1 from gen1 as well would lose reads such as "X = X + 1" in sb1)
        sb2.setGen(VariableSet.union(sb1.getGen(), VariableSet.minus(sb2.getGen(), sb1.getKill())));
        sb2.setKill(VariableSet.union(sb1.getKill(), sb2.getKill()));
        sb2.setReadVariables(VariableSet.union(sb1.variablesRead(), sb2.variablesRead()));
        sb2.setUpdatedVariables(VariableSet.union(sb1.variablesUpdated(), sb2.variablesUpdated()));
    }

    /**
     * Recursively collects the transient reads and transient writes reachable from
     * {@code current}, keyed by variable name. The output variables of function calls are
     * recorded as writes with a {@code null} hop. Uses and sets the hops' visit status, so
     * callers reset it before and after.
     */
    private static void rCollectTransientReadWrites(Hop current, HashMap<String, Hop> treads, HashMap<String, Hop> twrites) {
        if (current.isVisited()) {
            return;
        }
        for (Hop c : current.getInput()) {
            rCollectTransientReadWrites(c, treads, twrites);
        }
        if (HopRewriteUtils.isData(current, Hop.DataOpTypes.TRANSIENTREAD)) {
            treads.put(current.getName(), current);
        } else if (HopRewriteUtils.isData(current, Hop.DataOpTypes.TRANSIENTWRITE)) {
            twrites.put(current.getName(), current);
        } else if (current instanceof FunctionOp) {
            for (String output : ((FunctionOp) current).getOutputVariableNames()) {
                twrites.put(output, null);
            }
        }
        current.setVisited();
    }

    /** Returns true if any DAG root of {@code sb} is a function call; false for null blocks/hops. */
    private static boolean hasFunctionOpRoot(StatementBlock sb) throws HopsException {
        if (sb == null || sb.get_hops() == null) {
            return false;
        }
        for (Hop root : sb.get_hops()) {
            if (root instanceof FunctionOp) {
                return true;
            }
        }
        return false;
    }

    /** Returns true unless {@code sb} is a function, while, if, or for statement block. */
    private static boolean isLastLevelStatementBlock(StatementBlock sb) {
        return !(sb instanceof FunctionStatementBlock) && !(sb instanceof WhileStatementBlock) && !(sb instanceof IfStatementBlock) && !(sb instanceof ForStatementBlock);
    }
}

