/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.hops.codegen.template;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.codegen.cplan.CNode;
import org.apache.sysml.hops.codegen.cplan.CNodeData;
import org.apache.sysml.hops.codegen.cplan.CNodeMultiAgg;
import org.apache.sysml.hops.codegen.cplan.CNodeTpl;
import org.apache.sysml.hops.codegen.cplan.CNodeUnary;
import org.apache.sysml.hops.codegen.template.CPlanMemoTable;
import org.apache.sysml.hops.codegen.template.TemplateBase;
import org.apache.sysml.hops.codegen.template.TemplateCell;
import org.apache.sysml.hops.codegen.template.TemplateUtils;
import org.apache.sysml.runtime.matrix.data.Pair;

public class TemplateMultiAgg
extends TemplateCell {
    public TemplateMultiAgg() {
        super(TemplateBase.TemplateType.MultiAggTpl, false);
    }

    public TemplateMultiAgg(boolean closed) {
        super(TemplateBase.TemplateType.MultiAggTpl, closed);
    }

    @Override
    public boolean open(Hop hop) {
        return false;
    }

    @Override
    public boolean fuse(Hop hop, Hop input) {
        return false;
    }

    @Override
    public boolean merge(Hop hop, Hop input) {
        return false;
    }

    @Override
    public TemplateBase.CloseType close(Hop hop) {
        return TemplateBase.CloseType.CLOSED_INVALID;
    }

    @Override
    public Pair<Hop[], CNodeTpl> constructCplan(Hop hop, CPlanMemoTable memo, boolean compileLiterals) {
        CPlanMemoTable.MemoTableEntry multiAgg = memo.getBest(hop.getHopID(), TemplateBase.TemplateType.MultiAggTpl);
        ArrayList<Hop> roots = new ArrayList<Hop>();
        for (int i = 0; i < 3; ++i) {
            if (!multiAgg.isPlanRef(i)) continue;
            roots.add(memo._hopRefs.get(multiAgg.input(i)));
        }
        Hop.resetVisitStatus(roots);
        HashSet<Hop> inHops = new HashSet<Hop>();
        HashMap<Long, CNode> tmp = new HashMap<Long, CNode>();
        for (Hop root : roots) {
            super.rConstructCplan(root, memo, tmp, inHops, compileLiterals);
        }
        Hop.resetVisitStatus(roots);
        List<Hop> sinHops = inHops.stream().filter(h -> !h.getDataType().isScalar() || !((CNode)tmp.get(h.getHopID())).isLiteral()).sorted(new TemplateCell.HopInputComparator()).collect(Collectors.toList());
        ArrayList<CNode> inputs = new ArrayList<CNode>();
        for (Hop in : sinHops) {
            inputs.add(tmp.get(in.getHopID()));
        }
        ArrayList<CNode> outputs = new ArrayList<CNode>();
        ArrayList<Hop.AggOp> aggOps = new ArrayList<Hop.AggOp>();
        for (Hop root : roots) {
            CNode node = tmp.get(root.getHopID());
            if (node instanceof CNodeData && ((CNodeData)inputs.get(0)).getHopID() != ((CNodeData)node).getHopID()) {
                node = new CNodeUnary(node, roots.get(0).getDim2() == 1L ? CNodeUnary.UnaryType.LOOKUP_R : CNodeUnary.UnaryType.LOOKUP_RC);
            }
            outputs.add(node);
            aggOps.add(TemplateUtils.getAggOp(root));
        }
        CNodeMultiAgg tpl = new CNodeMultiAgg(inputs, outputs);
        tpl.setAggOps(aggOps);
        tpl.setRootNodes(roots);
        return new Pair<Hop[], CNodeTpl>(sinHops.toArray(new Hop[0]), tpl);
    }
}

