/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.hops.rewrite;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import org.apache.sysml.hops.FunctionOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.hops.rewrite.ProgramRewriteStatus;
import org.apache.sysml.hops.rewrite.ProgramRewriter;
import org.apache.sysml.hops.rewrite.RewriteCommonSubexpressionElimination;
import org.apache.sysml.hops.rewrite.StatementBlockRewriteRule;
import org.apache.sysml.parser.ExternalFunctionStatement;
import org.apache.sysml.parser.FunctionStatementBlock;
import org.apache.sysml.parser.StatementBlock;
import org.apache.sysml.parser.VariableSet;

/**
 * Statement-block rewrite that merges adjacent last-level statement blocks into a
 * single block, so that their HOP DAGs can be optimized together.
 * <p>
 * Pairs of neighbouring blocks are merged repeatedly until no further pair qualifies.
 * During a merge, transient writes of the first block are wired directly into the
 * matching transient reads of the second block. Afterwards, common subexpression
 * elimination is re-applied to the merged DAG.
 */
public class RewriteMergeBlockSequence
extends StatementBlockRewriteRule {
    /** Rewriter used to re-run common subexpression elimination on a merged HOP DAG. */
    private final ProgramRewriter rewriter = new ProgramRewriter(new RewriteCommonSubexpressionElimination(true));

    @Override
    public boolean createsSplitDag() {
        return false;
    }

    /**
     * Returns the given block unchanged; merging is only defined on block sequences
     * (see {@link #rewriteStatementBlocks}).
     */
    @Override
    public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) throws HopsException {
        return Arrays.asList(sb);
    }

    /**
     * Merges adjacent mergeable statement blocks until a fixpoint is reached.
     *
     * @param sbs  sequence of statement blocks; returned as-is if {@code null} or empty
     * @param sate rewrite status (unused by this rule)
     * @return a new list with every merged-away first block removed; the second block
     *         of each merged pair carries the combined DAG
     * @throws HopsException if inspecting or rewriting a HOP DAG fails
     */
    @Override
    public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus sate) throws HopsException {
        if (sbs == null || sbs.isEmpty()) {
            return sbs;
        }
        ArrayList<StatementBlock> tmpList = new ArrayList<StatementBlock>(sbs);
        // A merge changes the neighbourhood of the merged block, so the scan restarts
        // from the beginning after every successful merge (binary merges to fixpoint).
        boolean merged = true;
        while (merged) {
            merged = false;
            for (int i = 0; i < tmpList.size() - 1; ++i) {
                StatementBlock sb1 = tmpList.get(i);
                StatementBlock sb2 = tmpList.get(i + 1);
                if (isMergeable(sb1, sb2)) {
                    mergeInto(sb1, sb2);
                    tmpList.remove(i);
                    merged = true;
                    break;
                }
            }
        }
        return tmpList;
    }

    /**
     * Decides whether {@code sb1} followed by {@code sb2} may be merged. Checks are
     * evaluated in a fixed order and short-circuit on the first failing one.
     */
    private static boolean isMergeable(StatementBlock sb1, StatementBlock sb2) throws HopsException {
        if (!HopRewriteUtils.isLastLevelStatementBlock(sb1) || !HopRewriteUtils.isLastLevelStatementBlock(sb2)) {
            return false;
        }
        if (sb1.isSplitDag() || sb2.isSplitDag()) {
            return false;
        }
        // Two blocks that both call side-effecting external functions are kept apart.
        if (hasExternalFunctionOpRootWithSideEffect(sb1) && hasExternalFunctionOpRootWithSideEffect(sb2)) {
            return false;
        }
        // Function outputs must not be read or updated by the neighbouring block.
        if (hasFunctionOpRoot(sb1) && hasFunctionIOConflict(sb1, sb2)) {
            return false;
        }
        if (hasFunctionOpRoot(sb2) && hasFunctionIOConflict(sb2, sb1)) {
            return false;
        }
        return true;
    }

    /**
     * Merges the HOP DAG and live-variable information of {@code sb1} into {@code sb2}.
     * On return, {@code sb1}'s hop list is empty and {@code sb2} holds the combined DAG.
     */
    private void mergeInto(StatementBlock sb1, StatementBlock sb2) throws HopsException {
        ArrayList<Hop> sb1Hops = sb1.getHops();
        ArrayList<Hop> sb2Hops = sb2.getHops();
        ArrayList<Hop> newHops = new ArrayList<Hop>();

        // Index sb2's transient reads/writes (and function outputs) by variable name.
        Hop.resetVisitStatus(sb2Hops);
        HashMap<String, Hop> treads = new HashMap<String, Hop>();
        HashMap<String, Hop> twrites = new HashMap<String, Hop>();
        for (Hop root : sb2Hops) {
            this.rCollectTransientReadWrites(root, treads, twrites);
        }
        Hop.resetVisitStatus(sb2Hops);

        // Carry sb1's roots over, placing them before sb2's roots.
        Hop.resetVisitStatus(sb1Hops);
        for (Hop root : sb1Hops) {
            boolean isTWrite = HopRewriteUtils.isData(root, Hop.DataOpTypes.TRANSIENTWRITE);
            if (isTWrite && treads.containsKey(root.getName())) {
                // Bypass the write/read pair: sb2's consumers read sb1's value directly.
                Hop tread = treads.get(root.getName());
                Hop in = root.getInput().get(0);
                for (Hop parent : new ArrayList<Hop>(tread.getParent())) {
                    HopRewriteUtils.replaceChildReference(parent, tread, in);
                }
                HopRewriteUtils.removeAllChildReferences(root);
                // Re-materialize the variable only if it is still needed after sb2.
                if (!isObsoleteWrite(root.getName(), sb2, twrites)) {
                    newHops.add(HopRewriteUtils.createDataOp(root.getName(), in, Hop.DataOpTypes.TRANSIENTWRITE));
                }
            } else if (!(isTWrite && isObsoleteWrite(root.getName(), sb2, twrites))) {
                newHops.add(root);
            }
        }
        sb1Hops.clear();
        newHops.addAll(sb2Hops);
        sb2.setHops(newHops);
        Hop.resetVisitStatus(sb2.getHops());
        this.rewriter.rewriteHopDAG(sb2.getHops(), new ProgramRewriteStatus());

        // Live-variable info for the sequence sb1;sb2. Assuming gen = upward-exposed reads
        // and kill = defined variables, gen(sb1;sb2) = gen1 U (gen2 - kill1). Variables
        // that sb1 reads before defining (in both gen1 and kill1) must stay in gen.
        sb2.setLiveIn(sb1.liveIn());
        sb2.setGen(VariableSet.union(sb1.getGen(), VariableSet.minus(sb2.getGen(), sb1.getKill())));
        sb2.setKill(VariableSet.union(sb1.getKill(), sb2.getKill()));
        sb2.setReadVariables(VariableSet.union(sb1.variablesRead(), sb2.variablesRead()));
        sb2.setUpdatedVariables(VariableSet.union(sb1.variablesUpdated(), sb2.variablesUpdated()));
        // Logged before the begin position is overwritten, so both original ranges appear.
        LOG.debug("Applied mergeStatementBlockSequences (blocks of lines " + sb1.getBeginLine() + "-" + sb1.getEndLine() + " and " + sb2.getBeginLine() + "-" + sb2.getEndLine() + ").");
        sb2.setBeginLine(sb1.getBeginLine());
        sb2.setBeginColumn(sb1.getBeginColumn());
    }

    /**
     * Returns true if a transient write of {@code name} from the first block is not needed
     * after the merge. This holds when {@code sb2} writes the variable itself, or when
     * the variable is not live-out of {@code sb2}.
     */
    private static boolean isObsoleteWrite(String name, StatementBlock sb2, HashMap<String, Hop> twrites) throws HopsException {
        return twrites.containsKey(name) || !sb2.liveOut().containsVariable(name);
    }

    /**
     * Post-order DAG walk that records transient reads and transient writes by variable
     * name. Output variables of function calls are recorded as writes with a
     * {@code null} hop, so they are visible through {@code containsKey} only.
     * Visit status must be reset by the caller before and after the walk.
     */
    private void rCollectTransientReadWrites(Hop current, HashMap<String, Hop> treads, HashMap<String, Hop> twrites) {
        if (current.isVisited()) {
            return;
        }
        for (Hop c : current.getInput()) {
            this.rCollectTransientReadWrites(c, treads, twrites);
        }
        if (HopRewriteUtils.isData(current, Hop.DataOpTypes.TRANSIENTREAD)) {
            treads.put(current.getName(), current);
        } else if (HopRewriteUtils.isData(current, Hop.DataOpTypes.TRANSIENTWRITE)) {
            twrites.put(current.getName(), current);
        } else if (current instanceof FunctionOp) {
            for (String output : ((FunctionOp) current).getOutputVariableNames()) {
                twrites.put(output, null);
            }
        }
        current.setVisited();
    }

    /** Returns true if any root of {@code sb}'s HOP DAG is a function call. */
    private static boolean hasFunctionOpRoot(StatementBlock sb) throws HopsException {
        if (sb == null || sb.getHops() == null) {
            return false;
        }
        for (Hop root : sb.getHops()) {
            if (root instanceof FunctionOp) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns true if some root of {@code sb} is a call to a function whose statement
     * block holds an {@link ExternalFunctionStatement} flagged with side effects.
     */
    private static boolean hasExternalFunctionOpRootWithSideEffect(StatementBlock sb) throws HopsException {
        if (sb == null || sb.getHops() == null) {
            return false;
        }
        for (Hop root : sb.getHops()) {
            if (!(root instanceof FunctionOp)) {
                continue;
            }
            FunctionStatementBlock fsb = sb.getDMLProg().getFunctionStatementBlock(((FunctionOp) root).getFunctionKey());
            if (fsb != null
                && fsb.getStatement(0) instanceof ExternalFunctionStatement
                && ((ExternalFunctionStatement) fsb.getStatement(0)).hasSideEffects()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns true if {@code sb2} reads or updates any output variable of a function
     * call rooted in {@code sb1}. Callers guard this with {@link #hasFunctionOpRoot}.
     */
    private static boolean hasFunctionIOConflict(StatementBlock sb1, StatementBlock sb2) throws HopsException {
        HashSet<String> outSb1 = new HashSet<String>();
        for (Hop root : sb1.getHops()) {
            if (root instanceof FunctionOp) {
                outSb1.addAll(Arrays.asList(((FunctionOp) root).getOutputVariableNames()));
            }
        }
        return sb2.variablesRead().containsAnyName(outSb1) || sb2.variablesUpdated().containsAnyName(outSb1);
    }
}

