/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.hops.codegen.template;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.hops.AggBinaryOp;
import org.apache.sysml.hops.AggUnaryOp;
import org.apache.sysml.hops.BinaryOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.IndexingOp;
import org.apache.sysml.hops.ParameterizedBuiltinOp;
import org.apache.sysml.hops.ReorgOp;
import org.apache.sysml.hops.TernaryOp;
import org.apache.sysml.hops.UnaryOp;
import org.apache.sysml.hops.codegen.template.CPlanMemoTable;
import org.apache.sysml.hops.codegen.template.PlanSelection;
import org.apache.sysml.hops.codegen.template.TemplateBase;
import org.apache.sysml.hops.codegen.template.TemplateRow;
import org.apache.sysml.hops.codegen.template.TemplateUtils;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence;

public class PlanSelectionFuseCostBased
extends PlanSelection {
    /** Logger for this plan selector, named after the fully qualified class name. */
    private static final Log LOG = LogFactory.getLog(PlanSelectionFuseCostBased.class.getName());

    // Throughput constants used as divisors by the cost model in rGetPlanCosts
    // (costs there are computed as size * 8.0 / bandwidth; units presumably bytes/s -- TODO confirm).
    // Written as power-of-two products; the values are exactly 2.147483648E9 and 3.4359738368E10.
    /** Write bandwidth: 2 * 1024^3. */
    private static final double WRITE_BANDWIDTH = 2d * 1024 * 1024 * 1024;
    /** Read bandwidth: 32 * 1024^3. */
    private static final double READ_BANDWIDTH = 32d * 1024 * 1024 * 1024;
    /** Compute throughput: 2 * 1024^3 per core, scaled by the local degree of parallelism. */
    private static final double COMPUTE_BANDWIDTH = 2d * 1024 * 1024 * 1024 * (double) InfrastructureAnalyzer.getLocalParallelism();
    /** ID sequence, presumably used to tag cost vectors (usage not visible in this excerpt). */
    private static final IDSequence COST_ID = new IDSequence();
    /** Shared row template, queried via open() when validating row plans. */
    private static final TemplateRow ROW_TPL = new TemplateRow();

    /**
     * Selects fusion plans for the given DAG roots.
     *
     * <p>The memo table is split into connected sub-graphs (partitions). For each partition the
     * root nodes and candidate materialization points are determined, within-partition
     * multi-aggregate plans are added, and plans are selected. Afterwards multi-aggregate plans
     * across partitions are added and the memo table is restricted to the recorded best plans.
     *
     * @param memo  memo table of candidate plans; modified in place
     * @param roots DAG roots
     */
    @Override
    public void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots) {
        Collection<HashSet<Long>> partitions = getConnectedSubGraphs(memo, roots);
        if (LOG.isTraceEnabled()) {
            LOG.trace("Connected sub graphs: " + partitions.size());
        }
        for (HashSet<Long> partition : partitions) {
            // partition roots: members not referenced as input by any memo entry of the partition
            HashSet<Long> partRoots = getPartitionRootNodes(memo, partition);
            if (LOG.isTraceEnabled()) {
                LOG.trace("Partition root points: " + Arrays.toString(partRoots.toArray(new Long[0])));
            }
            // multi-consumer hops whose materialization decision is enumerated
            ArrayList<Long> matPoints = getMaterializationPoints(partRoots, partition, memo);
            if (LOG.isTraceEnabled()) {
                LOG.trace("Partition materialization points: " + Arrays.toString(matPoints.toArray(new Long[0])));
            }
            createAndAddMultiAggPlans(memo, partition, partRoots);
            selectPlans(memo, partition, partRoots, matPoints);
        }
        createAndAddMultiAggPlans(memo, roots);
        // keep only the selected plans per hop
        for (Map.Entry<?, ?> best : getBestPlans().entrySet()) {
            memo.setDistinct((Long) best.getKey(), (List) best.getValue());
        }
    }

    /**
     * Partitions the memo table into connected sub-graphs.
     *
     * <p>Builds reverse plan-reference edges (input hop ID to referencing hop IDs), then starts a
     * traversal from every memo entry that no other plan references. Each non-empty component
     * becomes one partition.
     *
     * @param memo  memo table of candidate plans
     * @param roots DAG roots (not used by the current implementation)
     * @return list of partitions, each a set of hop IDs
     */
    private static Collection<HashSet<Long>> getConnectedSubGraphs(CPlanMemoTable memo, ArrayList<Hop> roots) {
        // reverse edges: referenced input hop ID -> IDs of hops whose plans reference it
        HashMap<Long, HashSet<Long>> refBy = new HashMap<Long, HashSet<Long>>();
        for (Map.Entry<Long, List<CPlanMemoTable.MemoTableEntry>> plans : memo._plans.entrySet()) {
            Long consumer = plans.getKey();
            for (CPlanMemoTable.MemoTableEntry me : plans.getValue()) {
                for (int pos = 0; pos < 3; pos++) {
                    if (me.isPlanRef(pos)) {
                        refBy.computeIfAbsent(me.input(pos), k -> new HashSet<Long>()).add(consumer);
                    }
                }
            }
        }

        ArrayList<HashSet<Long>> components = new ArrayList<HashSet<Long>>();
        HashSet<Long> seen = new HashSet<Long>();
        for (Long hopID : memo._plans.keySet()) {
            if (refBy.containsKey(hopID)) {
                continue; // referenced by another plan, so not a starting point
            }
            HashSet<Long> component = rGetConnectedSubGraphs(hopID, memo, refBy, seen, new HashSet<Long>());
            if (!component.isEmpty()) {
                components.add(component);
            }
        }
        return components;
    }

    /**
     * Recursively grows {@code partition} with every hop connected to {@code hopID}.
     *
     * <p>The walk follows consumer edges ({@code refBy}) and plan input references
     * ({@code memo.getAllRefs}, where -1 marks an absent reference). Only hops present in the memo
     * table are added to the partition and marked as visited.
     *
     * @return {@code partition}, for convenience
     */
    private static HashSet<Long> rGetConnectedSubGraphs(long hopID, CPlanMemoTable memo, HashMap<Long, HashSet<Long>> refBy, HashSet<Long> visited, HashSet<Long> partition) {
        if (visited.contains(hopID)) {
            return partition;
        }
        boolean inMemo = memo.contains(hopID);
        if (inMemo) {
            partition.add(hopID);
            visited.add(hopID);
        }
        // upwards: all hops whose plans reference this hop
        HashSet<Long> consumers = refBy.get(hopID);
        if (consumers != null) {
            for (Long consumer : consumers) {
                rGetConnectedSubGraphs(consumer, memo, refBy, visited, partition);
            }
        }
        // downwards: all inputs referenced by this hop's plans
        if (inMemo) {
            long[] inputs = memo.getAllRefs(hopID);
            for (int pos = 0; pos < 3; pos++) {
                if (inputs[pos] != -1L) {
                    rGetConnectedSubGraphs(inputs[pos], memo, refBy, visited, partition);
                }
            }
        }
        return partition;
    }

    /**
     * Returns the partition members that no memo entry of the partition uses as an input.
     *
     * @param memo      memo table of candidate plans
     * @param partition hop IDs of one connected sub-graph
     * @return root hop IDs of the partition
     */
    private static HashSet<Long> getPartitionRootNodes(CPlanMemoTable memo, HashSet<Long> partition) {
        // every input ID of every entry (including -1 placeholders, which never match a hop ID)
        HashSet<Long> referenced = new HashSet<Long>();
        for (Long id : partition) {
            if (memo.contains(id)) {
                for (CPlanMemoTable.MemoTableEntry me : memo.get(id)) {
                    referenced.addAll(Arrays.asList(me.input1, me.input2, me.input3));
                }
            }
        }
        return partition.stream()
            .filter(id -> !referenced.contains(id))
            .collect(Collectors.toCollection(HashSet::new));
    }

    /**
     * Collects candidate materialization points of a partition.
     *
     * <p>Candidates are partition members with at least two parents, gathered in post-order from
     * the partition roots. Roots themselves and inputs of tsmm operations are excluded.
     *
     * @param roots     partition root hop IDs
     * @param partition hop IDs of the partition
     * @param memo      memo table (for hop lookup)
     * @return ordered list of materialization point hop IDs
     */
    private static ArrayList<Long> getMaterializationPoints(HashSet<Long> roots, HashSet<Long> partition, CPlanMemoTable memo) {
        ArrayList<Long> points = new ArrayList<Long>();
        HashSet<Long> seen = new HashSet<Long>();
        for (Long rootID : roots) {
            rCollectMaterializationPoints(memo._hopRefs.get(rootID), seen, partition, points);
        }
        // roots and tsmm inputs are never treated as materialization points
        points.removeIf(id -> roots.contains(id) || HopRewriteUtils.isTsmmInput(memo._hopRefs.get(id)));
        return points;
    }

    /**
     * Post-order DAG walk that appends each not-yet-visited hop satisfying
     * {@link #isMaterializationPointCandidate} to {@code M}.
     */
    private static void rCollectMaterializationPoints(Hop current, HashSet<Long> visited, HashSet<Long> partition, ArrayList<Long> M) {
        long id = current.getHopID();
        if (!visited.contains(id)) {
            current.getInput().forEach(in -> rCollectMaterializationPoints(in, visited, partition, M));
            if (isMaterializationPointCandidate(current, partition)) {
                M.add(id);
            }
            visited.add(id);
        }
    }

    /** A hop is a candidate if it has more than one parent and belongs to the partition. */
    private static boolean isMaterializationPointCandidate(Hop hop, HashSet<Long> partition) {
        int numParents = hop.getParent().size();
        if (numParents <= 1) {
            return false;
        }
        return partition.contains(hop.getHopID());
    }

    /**
     * Adds multi-aggregate plans for full (row-and-column) aggregations among the partition roots.
     *
     * <p>Roots that are {@code AggUnaryOp} with direction {@code RowCol} and are not consumed by
     * any hop with memo entries are grouped in chunks of up to three, in iteration order of
     * {@code R}. Each chunk of at least two forms one {@code MultiAggTpl} entry. The entry is added
     * to every member if {@link #isValidMultiAggregate} accepts it.
     *
     * @param memo      memo table; modified in place
     * @param partition hop IDs of the partition (not used by the current implementation)
     * @param R         partition root hop IDs
     */
    private static void createAndAddMultiAggPlans(CPlanMemoTable memo, HashSet<Long> partition, HashSet<Long> R) {
        // inputs of all hops that have at least one memo entry
        HashSet<Long> refHops = new HashSet<Long>();
        for (Map.Entry<Long, List<CPlanMemoTable.MemoTableEntry>> e : memo._plans.entrySet()) {
            if (!e.getValue().isEmpty()) {
                for (Hop in : memo._hopRefs.get(e.getKey()).getInput()) {
                    refHops.add(in.getHopID());
                }
            }
        }

        // unreferenced ua(RC) roots
        ArrayList<Long> fullAggs = new ArrayList<Long>();
        for (Long rootID : R) {
            if (refHops.contains(rootID)) {
                continue;
            }
            Hop root = memo._hopRefs.get(rootID);
            if (root instanceof AggUnaryOp && ((AggUnaryOp) root).getDirection() == Hop.Direction.RowCol) {
                fullAggs.add(rootID);
            }
        }
        if (LOG.isTraceEnabled()) {
            LOG.trace("Found within-partition ua(RC) aggregations: " + Arrays.toString(fullAggs.toArray(new Long[0])));
        }

        for (int from = 0; from < fullAggs.size(); from += 3) {
            int to = Math.min(from + 3, fullAggs.size());
            int groupSize = to - from;
            if (groupSize < 2) {
                continue; // a single aggregate cannot form a multi-aggregate
            }
            long in3 = groupSize == 3 ? fullAggs.get(from + 2) : -1L;
            CPlanMemoTable.MemoTableEntry me = new CPlanMemoTable.MemoTableEntry(
                TemplateBase.TemplateType.MultiAggTpl, fullAggs.get(from), fullAggs.get(from + 1), in3);
            if (isValidMultiAggregate(memo, me)) {
                for (int j = from; j < to; j++) {
                    memo.add(memo._hopRefs.get(fullAggs.get(j)), me);
                    if (LOG.isTraceEnabled()) {
                        LOG.trace("Added multiagg plan: " + fullAggs.get(j) + " " + me);
                    }
                }
            } else if (LOG.isTraceEnabled()) {
                LOG.trace("Removed invalid multiagg plan: " + me);
            }
        }
    }

    /*
     * WARNING - void declaration
     */
    private void createAndAddMultiAggPlans(CPlanMemoTable memo, ArrayList<Hop> roots) {
        HashSet<Long> fullAggs = new HashSet<Long>();
        Hop.resetVisitStatus(roots);
        for (Hop hop : roots) {
            PlanSelectionFuseCostBased.rCollectFullAggregates(hop, fullAggs);
        }
        Hop.resetVisitStatus(roots);
        Iterator iter = fullAggs.iterator();
        while (iter.hasNext()) {
            if (!memo.contains((Long)iter.next(), TemplateBase.TemplateType.MultiAggTpl)) continue;
            iter.remove();
        }
        if (fullAggs.size() <= 1) {
            return;
        }
        if (LOG.isTraceEnabled()) {
            LOG.trace((Object)("Found across-partition ua(RC) aggregations: " + Arrays.toString((Object[])fullAggs.toArray(new Long[0]))));
        }
        List<Object> aggInfos = new ArrayList();
        for (Long l : fullAggs) {
            Hop hop = memo._hopRefs.get(l);
            AggregateInfo tmp = new AggregateInfo(hop);
            for (Hop c : hop.getInput()) {
                PlanSelectionFuseCostBased.rExtractAggregateInfo(memo, c, tmp, TemplateBase.TemplateType.CellTpl);
            }
            if (tmp._fusedInputs.isEmpty()) {
                tmp.addFusedInput(hop.getInput().get(0).getHopID());
            }
            aggInfos.add(tmp);
        }
        if (LOG.isTraceEnabled()) {
            LOG.trace((Object)"Extracted across-partition ua(RC) aggregation info: ");
            for (AggregateInfo aggregateInfo : aggInfos) {
                LOG.trace((Object)aggregateInfo);
            }
        }
        aggInfos = aggInfos.stream().filter(a -> !a.containsMatMult).sorted(Comparator.comparing(a -> a._inputAggs.size())).collect(Collectors.toList());
        boolean converged = false;
        while (!converged) {
            void var7_13;
            Object var7_12 = null;
            for (int i = 0; i < aggInfos.size(); ++i) {
                AggregateInfo current = (AggregateInfo)aggInfos.get(i);
                for (int j = i + 1; j < aggInfos.size(); ++j) {
                    AggregateInfo that = (AggregateInfo)aggInfos.get(j);
                    if (!current.isMergable(that)) continue;
                    AggregateInfo aggregateInfo = current.merge(that);
                    aggInfos.remove(j);
                    --j;
                }
            }
            converged = var7_13 == null;
        }
        if (LOG.isTraceEnabled()) {
            LOG.trace((Object)"Merged across-partition ua(RC) aggregation info: ");
            for (AggregateInfo aggregateInfo : aggInfos) {
                LOG.trace((Object)aggregateInfo);
            }
        }
        for (AggregateInfo aggregateInfo : aggInfos) {
            if (aggregateInfo._aggregates.size() <= 1) continue;
            Long[] aggs = aggregateInfo._aggregates.keySet().toArray(new Long[0]);
            CPlanMemoTable.MemoTableEntry me = new CPlanMemoTable.MemoTableEntry(TemplateBase.TemplateType.MultiAggTpl, aggs[0], aggs[1], aggs.length > 2 ? aggs[2] : -1L);
            for (int i = 0; i < aggs.length; ++i) {
                memo.add(memo._hopRefs.get(aggs[i]), me);
                this.addBestPlan(aggs[i], me);
                if (!LOG.isTraceEnabled()) continue;
                LOG.trace((Object)("Added multiagg* plan: " + aggs[i] + " " + me));
            }
        }
    }

    /**
     * Checks whether a multi-aggregate memo entry is valid.
     *
     * <p>Two conditions must hold:
     * <ol>
     *   <li>The first input of every referenced aggregate has the same size as that of
     *       {@code input1}.</li>
     *   <li>No referenced aggregate has another input of the entry in its own input DAG.</li>
     * </ol>
     */
    private static boolean isValidMultiAggregate(CPlanMemoTable memo, CPlanMemoTable.MemoTableEntry me) {
        Hop refSize = memo._hopRefs.get(me.input1).getInput().get(0);
        for (int pos = 1; pos < 3; pos++) {
            if (me.isPlanRef(pos)
                && !HopRewriteUtils.isEqualSize(refSize, memo._hopRefs.get(me.input(pos)).getInput().get(0))) {
                return false;
            }
        }
        for (int pos = 0; pos < 3; pos++) {
            if (!me.isPlanRef(pos)) {
                continue;
            }
            HashSet<Long> others = new HashSet<Long>();
            for (int k = 0; k < 3; k++) {
                if (k != pos) {
                    others.add(me.input(k));
                }
            }
            if (!rCheckMultiAggregate(memo._hopRefs.get(me.input(pos)), others)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns true iff neither {@code current} nor any hop reachable through its inputs has an ID
     * contained in {@code probe}.
     *
     * <p>The previous version re-walked shared sub-DAGs once per path, which is exponential in the
     * worst case, and never stopped early. This version visits each hop at most once and returns
     * on the first hit. The result is unchanged.
     */
    private static boolean rCheckMultiAggregate(Hop current, HashSet<Long> probe) {
        return rCheckMultiAggregateVisited(current, probe, new HashSet<Long>());
    }

    /** Memoized worker for {@link #rCheckMultiAggregate(Hop, HashSet)}. */
    private static boolean rCheckMultiAggregateVisited(Hop current, HashSet<Long> probe, HashSet<Long> visited) {
        if (!visited.add(current.getHopID())) {
            return true; // already checked (a failure would have aborted the walk)
        }
        if (probe.contains(current.getHopID())) {
            return false;
        }
        for (Hop c : current.getInput()) {
            if (!rCheckMultiAggregateVisited(c, probe, visited)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Collects IDs of all full aggregations (sum, sumsq, min, max with direction RowCol) reachable
     * from {@code current}. The walk uses the hops' visit flags, which callers must reset.
     */
    private static void rCollectFullAggregates(Hop current, HashSet<Long> aggs) {
        if (!current.isVisited()) {
            boolean isFullAgg = HopRewriteUtils.isAggUnaryOp(current, Hop.AggOp.SUM, Hop.AggOp.SUM_SQ, Hop.AggOp.MIN, Hop.AggOp.MAX)
                && ((AggUnaryOp) current).getDirection() == Hop.Direction.RowCol;
            if (isFullAgg) {
                aggs.add(current.getHopID());
            }
            current.getInput().forEach(in -> rCollectFullAggregates(in, aggs));
            current.setVisited();
        }
    }

    /**
     * Recursively records aggregate dependencies of {@code current} into {@code aggInfo}.
     *
     * <p>{@code type != null} means {@code current} is still inside the fused region. Inside it,
     * the following happens:
     * <ul>
     *   <li>matrix multiplications are flagged.</li>
     *   <li>the best memo entry decides which inputs stay fused.</li>
     *   <li>matrix inputs that leave the fused region are recorded as fused inputs.</li>
     * </ul>
     * Full aggregations are recorded as input aggregates anywhere in the walk.
     */
    private static void rExtractAggregateInfo(CPlanMemoTable memo, Hop current, AggregateInfo aggInfo, TemplateBase.TemplateType type) {
        boolean inFusedRegion = type != null;
        if (HopRewriteUtils.isAggUnaryOp(current, Hop.AggOp.SUM, Hop.AggOp.SUM_SQ, Hop.AggOp.MIN, Hop.AggOp.MAX)
            && ((AggUnaryOp) current).getDirection() == Hop.Direction.RowCol) {
            aggInfo.addInputAggregate(current.getHopID());
        }
        if (inFusedRegion && HopRewriteUtils.isMatrixMultiply(current)) {
            aggInfo.setContainsMatMult();
        }
        CPlanMemoTable.MemoTableEntry best = null;
        if (inFusedRegion) {
            best = memo.getBest(current.getHopID());
        }
        int pos = 0;
        for (Hop in : current.getInput()) {
            if (best != null && best.isPlanRef(pos)) {
                // input fused into the same operator: stay in the fused region
                rExtractAggregateInfo(memo, in, aggInfo, type);
            } else {
                if (inFusedRegion && in.getDataType().isMatrix()) {
                    aggInfo.addFusedInput(in.getHopID());
                }
                rExtractAggregateInfo(memo, in, aggInfo, null);
            }
            pos++;
        }
    }

    /**
     * Selects plans for a single partition.
     *
     * <p>Without materialization points, every root is processed with fuse-all selection.
     * Otherwise all {@code 2^|M|} materialization assignments are enumerated and costed, and the
     * cheapest is kept (the first one found wins ties). The memo table is then pruned according to
     * that assignment, invalid row plans are converted, and fuse-all selection runs on the pruned
     * table.
     *
     * <p>The decompiled original reused an {@code Object} local as both {@code boolean[]} and
     * {@code Iterator}, which does not compile. The locals are now properly typed.
     *
     * @param memo      memo table; modified in place
     * @param partition hop IDs of the partition
     * @param R         partition root hop IDs
     * @param M         materialization point hop IDs (may be null or empty)
     */
    private void selectPlans(CPlanMemoTable memo, HashSet<Long> partition, HashSet<Long> R, ArrayList<Long> M) {
        if (M == null || M.isEmpty()) {
            for (Long hopID : R) {
                this.rSelectPlansFuseAll(memo, memo._hopRefs.get(hopID), null, partition);
            }
            return;
        }

        // per-operator compute costs, shared by all enumerated plans
        HashMap<Long, Double> computeCosts = new HashMap<Long, Double>();
        for (Long hopID : R) {
            rGetComputeCosts(memo._hopRefs.get(hopID), partition, computeCosts);
        }

        // exhaustive enumeration of materialization assignments
        int len = (int) Math.pow(2.0, M.size());
        boolean[] bestPlan = null;
        double bestC = Double.MAX_VALUE;
        for (int i = 0; i < len; ++i) {
            boolean[] plan = createAssignment(M.size(), i);
            double C = getPlanCost(memo, partition, R, M, plan, computeCosts);
            if (LOG.isTraceEnabled()) {
                LOG.trace("Enum: " + Arrays.toString(plan) + " -> " + C);
            }
            if (bestPlan == null || C < bestC) {
                bestC = C;
                bestPlan = plan;
                if (LOG.isTraceEnabled()) {
                    LOG.trace("Enum: Found new best plan.");
                }
            }
        }

        // prune memo entries inconsistent with the chosen assignment
        HashSet<Long> visited = new HashSet<Long>();
        for (Long hopID : R) {
            rPruneSuboptimalPlans(memo, memo._hopRefs.get(hopID), visited, partition, M, bestPlan);
        }
        HashSet<Long> visited2 = new HashSet<Long>();
        for (Long hopID : R) {
            rPruneInvalidPlans(memo, memo._hopRefs.get(hopID), visited2, partition, M, bestPlan);
        }
        for (Long hopID : R) {
            this.rSelectPlansFuseAll(memo, memo._hopRefs.get(hopID), null, partition);
        }
    }

    /**
     * Pre-order DAG walk over the partition that removes every memo entry which references a
     * materialization point according to {@code plan}. {@code OuterProdTpl} entries are always
     * kept.
     */
    private static void rPruneSuboptimalPlans(CPlanMemoTable memo, Hop current, HashSet<Long> visited, HashSet<Long> partition, ArrayList<Long> M, boolean[] plan) {
        long hopID = current.getHopID();
        if (visited.contains(hopID)) {
            return;
        }
        if (partition.contains(hopID) && memo.contains(hopID)) {
            for (Iterator<CPlanMemoTable.MemoTableEntry> it = memo.get(hopID).iterator(); it.hasNext(); ) {
                CPlanMemoTable.MemoTableEntry me = it.next();
                boolean keep = PlanSelectionFuseCostBased.hasNoRefToMaterialization(me, M, plan)
                    || me.type == TemplateBase.TemplateType.OuterProdTpl;
                if (!keep) {
                    it.remove();
                    if (LOG.isTraceEnabled()) {
                        LOG.trace("Removed memo table entry: " + me);
                    }
                }
            }
        }
        for (Hop in : current.getInput()) {
            rPruneSuboptimalPlans(memo, in, visited, partition, M, plan);
        }
        visited.add(hopID);
    }

    /**
     * Post-order DAG walk that converts invalid row-template entries to cell-template entries.
     *
     * <ul>
     *   <li>A row entry without plan references becomes a cell entry if the hop has no matrix
     *       input.</li>
     *   <li>A row entry with plan references, whose hop the row template does not open, becomes a
     *       cell entry if none of its referenced inputs has a row entry.</li>
     * </ul>
     * Inputs are processed first, so their conversions are visible here.
     */
    private static void rPruneInvalidPlans(CPlanMemoTable memo, Hop current, HashSet<Long> visited, HashSet<Long> partition, ArrayList<Long> M, boolean[] plan) {
        long hopID = current.getHopID();
        if (visited.contains(hopID)) {
            return;
        }
        for (Hop in : current.getInput()) {
            rPruneInvalidPlans(memo, in, visited, partition, M, plan);
        }
        if (partition.contains(hopID) && memo.contains(hopID, TemplateBase.TemplateType.RowTpl)) {
            for (CPlanMemoTable.MemoTableEntry me : memo.get(hopID)) {
                if (me.type != TemplateBase.TemplateType.RowTpl) {
                    continue;
                }
                if (!me.hasPlanRef()) {
                    // leaf row entry
                    if (!TemplateUtils.hasMatrixInput(current)) {
                        me.type = TemplateBase.TemplateType.CellTpl;
                        if (LOG.isTraceEnabled()) {
                            LOG.trace("Converted leaf memo table entry from row to cell: " + me);
                        }
                    }
                    continue;
                }
                if (ROW_TPL.open(current)) {
                    continue;
                }
                boolean hasRowInput = false;
                for (int pos = 0; pos < 3; pos++) {
                    if (me.isPlanRef(pos)) {
                        hasRowInput |= memo.contains(me.input(pos), TemplateBase.TemplateType.RowTpl);
                    }
                }
                if (!hasRowInput) {
                    me.type = TemplateBase.TemplateType.CellTpl;
                    if (LOG.isTraceEnabled()) {
                        LOG.trace("Converted inner memo table entry from row to cell: " + me);
                    }
                }
            }
        }
        visited.add(hopID);
    }

    /**
     * Fuse-all plan selection for the partition reachable from {@code current}.
     *
     * <p>Entries subsumed by another entry of the same hop are removed first. Then the best entry
     * is chosen and recorded via {@code addBestPlan}:
     * <ul>
     *   <li>Without a required template ({@code currentType == null}): the minimum by
     *       {@code BasicPlanComparator} among valid entries.</li>
     *   <li>Otherwise: among entries of {@code currentType} or {@code CellTpl}, the lowest score
     *       {@code 7 - (sameType ? 4 : 0) - countPlanRefs()}. This prefers the same template,
     *       then more plan references.</li>
     * </ul>
     * Inputs referenced by the chosen entry are visited with its template type; other inputs are
     * visited with none.
     */
    private void rSelectPlansFuseAll(CPlanMemoTable memo, Hop current, TemplateBase.TemplateType currentType, HashSet<Long> partition) {
        long hopID = current.getHopID();
        if (this.isVisited(hopID, currentType) || !partition.contains(hopID)) {
            return;
        }
        if (memo.contains(hopID)) {
            List<CPlanMemoTable.MemoTableEntry> candidates = memo.get(hopID);
            HashSet<CPlanMemoTable.MemoTableEntry> subsumed = new HashSet<CPlanMemoTable.MemoTableEntry>();
            for (CPlanMemoTable.MemoTableEntry a : candidates) {
                for (CPlanMemoTable.MemoTableEntry b : candidates) {
                    if (a != b && a.subsumes(b)) {
                        subsumed.add(b);
                    }
                }
            }
            memo.remove(current, subsumed);
        }

        CPlanMemoTable.MemoTableEntry best = null;
        if (memo.contains(hopID)) {
            if (currentType == null) {
                best = memo.get(hopID).stream()
                    .filter(p -> PlanSelectionFuseCostBased.isValid(p, current))
                    .min(new PlanSelection.BasicPlanComparator())
                    .orElse(null);
            } else {
                best = memo.get(hopID).stream()
                    .filter(p -> p.type == currentType || p.type == TemplateBase.TemplateType.CellTpl)
                    .min(Comparator.comparingInt((CPlanMemoTable.MemoTableEntry p) ->
                        7 - (p.type == currentType ? 4 : 0) - p.countPlanRefs()))
                    .orElse(null);
            }
            this.addBestPlan(hopID, best);
        }

        int pos = 0;
        for (Hop in : current.getInput()) {
            TemplateBase.TemplateType inherited = null;
            if (best != null && best.isPlanRef(pos)) {
                inherited = best.type;
            }
            this.rSelectPlansFuseAll(memo, in, inherited, partition);
            pos++;
        }
        this.setVisited(hopID, currentType);
    }

    /**
     * Decodes the enumeration index {@code pos} into a boolean vector of length {@code len}.
     *
     * <p>Working from the most significant position, entry {@code k} is true iff the remaining
     * value is below {@code 2^(len-1-k)}, i.e. iff that binary digit of {@code pos} is 0 for
     * {@code 0 <= pos < 2^len}.
     *
     * @param len number of positions
     * @param pos enumeration index
     * @return decoded assignment
     */
    private static boolean[] createAssignment(int len, int pos) {
        boolean[] assignment = new boolean[len];
        int remainder = pos;
        int exponent = len - 1;
        for (int k = 0; k < len; k++, exponent--) {
            double weight = Math.pow(2.0, exponent);
            assignment[k] = remainder < (int) weight;
            remainder = (int) (remainder % weight);
        }
        return assignment;
    }

    /**
     * Total estimated cost of a materialization assignment, summed over all partition roots in
     * iteration order. The visited set is shared across roots, so common sub-DAGs are counted once.
     */
    private static double getPlanCost(CPlanMemoTable memo, HashSet<Long> partition, HashSet<Long> R, ArrayList<Long> M, boolean[] plan, HashMap<Long, Double> computeCosts) {
        HashSet<Pair<Long, Long>> costed = new HashSet<Pair<Long, Long>>();
        double total = 0.0;
        Iterator<Long> rootIDs = R.iterator();
        while (rootIDs.hasNext()) {
            Hop root = memo._hopRefs.get(rootIDs.next());
            total += rGetPlanCosts(memo, root, costed, partition, M, plan, computeCosts, null, null);
        }
        return total;
    }

    /**
     * Recursively estimates the cost of the plan rooted at {@code current} under the given
     * materialization assignment.
     *
     * <p>How a fused operator is opened:
     * <ul>
     *   <li>When no template type is inherited ({@code currentType == null}) and the hop has memo
     *       entries, a new fused operator is opened with a fresh {@code CostVector} sized by the
     *       output cells. Otherwise the parent's vector ({@code costsCurrent}) is extended.</li>
     *   <li>Inputs referenced by the chosen entry are costed within the same operator. Other
     *       inputs are costed independently and added as operator inputs if they are matrices.</li>
     * </ul>
     *
     * <p>Cost of an opened operator in the partition:
     * {@code outSize*8/WRITE_BANDWIDTH + max(compute*maxInput/COMPUTE_BANDWIDTH, sumInputs*8/READ_BANDWIDTH)}.
     * A non-opened partition hop with consumers outside the partition is additionally costed as
     * its own operator.
     *
     * <p>Each (hop, enclosing-operator ID) pair is costed at most once via {@code visited}; ID 0
     * stands for "no enclosing operator".
     *
     * @throws RuntimeException if the estimate is negative, NaN, or infinite
     */
    private static double rGetPlanCosts(CPlanMemoTable memo, Hop current, HashSet<Pair<Long, Long>> visited, HashSet<Long> partition, ArrayList<Long> M, boolean[] plan, HashMap<Long, Double> computeCosts, CostVector costsCurrent, TemplateBase.TemplateType currentType) {
        Pair<Long, Long> tag = Pair.of(current.getHopID(), costsCurrent == null ? 0L : costsCurrent.ID);
        if (!visited.add(tag)) {
            return 0.0;
        }

        // choose the best entry consistent with the materialization assignment
        CPlanMemoTable.MemoTableEntry best = null;
        boolean opened = false;
        if (memo.contains(current.getHopID())) {
            if (currentType == null) {
                best = memo.get(current.getHopID()).stream()
                    .filter(p -> PlanSelectionFuseCostBased.isValid(p, current))
                    .filter(p -> PlanSelectionFuseCostBased.hasNoRefToMaterialization(p, M, plan))
                    .min(new PlanSelection.BasicPlanComparator())
                    .orElse(null);
                opened = true;
            } else {
                best = memo.get(current.getHopID()).stream()
                    .filter(p -> p.type == currentType || p.type == TemplateBase.TemplateType.CellTpl)
                    .filter(p -> PlanSelectionFuseCostBased.hasNoRefToMaterialization(p, M, plan))
                    .min(Comparator.comparing(p -> 7 - (p.type == currentType ? 4 : 0) - p.countPlanRefs()))
                    .orElse(null);
            }
        }

        // new cost vector for an opened operator, otherwise extend the enclosing one
        CostVector costVect = opened
            ? new CostVector(Math.max(current.getDim1(), 1L) * Math.max(current.getDim2(), 1L))
            : costsCurrent;
        if (partition.contains(current.getHopID())) {
            costVect.computeCosts += computeCosts.get(current.getHopID()).doubleValue();
        }

        double costs = 0.0;
        for (int i = 0; i < current.getInput().size(); ++i) {
            Hop c = current.getInput().get(i);
            if (best != null && best.isPlanRef(i)) {
                // fused input: costed within the same operator
                costs += rGetPlanCosts(memo, c, visited, partition, M, plan, computeCosts, costVect, best.type);
                continue;
            }
            costs += rGetPlanCosts(memo, c, visited, partition, M, plan, computeCosts, null, null);
            if (costVect != null && c.getDataType().isMatrix()) {
                costVect.addInputSize(c.getHopID(), Math.max(c.getDim1(), 1L) * Math.max(c.getDim2(), 1L));
            }
        }

        if (partition.contains(current.getHopID())) {
            if (opened) {
                if (LOG.isTraceEnabled()) {
                    LOG.trace("Cost vector for fused operator: " + costVect);
                }
                // write costs + max(compute costs, read costs)
                costs += costVect.outSize * 8.0 / WRITE_BANDWIDTH;
                costs += Math.max(costVect.computeCosts * costVect.getMaxInputSize() / COMPUTE_BANDWIDTH,
                    costVect.getSumInputSizes() * 8.0 / READ_BANDWIDTH);
            } else if (PlanSelectionFuseCostBased.hasNonPartitionConsumer(current, partition)) {
                // output consumed outside the partition: cost it as a separate operator
                costs += rGetPlanCosts(memo, current, visited, partition, M, plan, computeCosts, null, null);
            }
        }
        if (costs < 0.0 || Double.isNaN(costs) || Double.isInfinite(costs)) {
            throw new RuntimeException("Wrong cost estimate: " + costs);
        }
        return costs;
    }

    private static void rGetComputeCosts(Hop current, HashSet<Long> partition, HashMap<Long, Double> computeCosts) {
        double costs;
        block44: {
            block46: {
                block45: {
                    block43: {
                        if (computeCosts.containsKey(current.getHopID())) {
                            return;
                        }
                        for (Hop c : current.getInput()) {
                            PlanSelectionFuseCostBased.rGetComputeCosts(c, partition, computeCosts);
                        }
                        costs = 0.0;
                        if (!(current instanceof UnaryOp)) break block43;
                        switch (((UnaryOp)current).getOp()) {
                            case ABS: 
                            case ROUND: 
                            case CEIL: 
                            case FLOOR: 
                            case SIGN: 
                            case SELP: {
                                costs = 1.0;
                                break block44;
                            }
                            case SPROP: 
                            case SQRT: {
                                costs = 2.0;
                                break block44;
                            }
                            case EXP: {
                                costs = 18.0;
                                break block44;
                            }
                            case SIGMOID: {
                                costs = 21.0;
                                break block44;
                            }
                            case LOG: 
                            case LOG_NZ: {
                                costs = 32.0;
                                break block44;
                            }
                            case NCOL: 
                            case NROW: 
                            case PRINT: 
                            case CAST_AS_BOOLEAN: 
                            case CAST_AS_DOUBLE: 
                            case CAST_AS_INT: 
                            case CAST_AS_MATRIX: 
                            case CAST_AS_SCALAR: {
                                costs = 1.0;
                                break block44;
                            }
                            case SIN: {
                                costs = 18.0;
                                break block44;
                            }
                            case COS: {
                                costs = 22.0;
                                break block44;
                            }
                            case TAN: {
                                costs = 42.0;
                                break block44;
                            }
                            case ASIN: {
                                costs = 93.0;
                                break block44;
                            }
                            case ACOS: {
                                costs = 103.0;
                                break block44;
                            }
                            case ATAN: {
                                costs = 40.0;
                                break block44;
                            }
                            case CUMSUM: 
                            case CUMMIN: 
                            case CUMMAX: 
                            case CUMPROD: {
                                costs = 1.0;
                                break block44;
                            }
                            default: {
                                throw new RuntimeException("Cost model not implemented yet for: " + (Object)((Object)((UnaryOp)current).getOp()));
                            }
                        }
                    }
                    if (!(current instanceof BinaryOp)) break block45;
                    switch (((BinaryOp)current).getOp()) {
                        case MULT: 
                        case PLUS: 
                        case MINUS: 
                        case MIN: 
                        case MAX: 
                        case AND: 
                        case OR: 
                        case EQUAL: 
                        case NOTEQUAL: 
                        case LESS: 
                        case LESSEQUAL: 
                        case GREATER: 
                        case GREATEREQUAL: 
                        case CBIND: 
                        case RBIND: {
                            costs = 1.0;
                            break block44;
                        }
                        case INTDIV: {
                            costs = 6.0;
                            break block44;
                        }
                        case MODULUS: {
                            costs = 8.0;
                            break block44;
                        }
                        case DIV: {
                            costs = 22.0;
                            break block44;
                        }
                        case LOG: 
                        case LOG_NZ: {
                            costs = 32.0;
                            break block44;
                        }
                        case POW: {
                            costs = HopRewriteUtils.isLiteralOfValue(current.getInput().get(1), 2.0) ? 1 : 16;
                            break block44;
                        }
                        case MINUS_NZ: 
                        case MINUS1_MULT: {
                            costs = 2.0;
                            break block44;
                        }
                        default: {
                            throw new RuntimeException("Cost model not implemented yet for: " + (Object)((Object)((BinaryOp)current).getOp()));
                        }
                    }
                }
                if (!(current instanceof TernaryOp)) break block46;
                switch (((TernaryOp)current).getOp()) {
                    case PLUS_MULT: 
                    case MINUS_MULT: {
                        costs = 2.0;
                        break block44;
                    }
                    default: {
                        throw new RuntimeException("Cost model not implemented yet for: " + (Object)((Object)((TernaryOp)current).getOp()));
                    }
                }
            }
            if (current instanceof ParameterizedBuiltinOp) {
                costs = 1.0;
            } else if (current instanceof IndexingOp) {
                costs = 1.0;
            } else if (current instanceof ReorgOp) {
                costs = 1.0;
            } else if (current instanceof AggBinaryOp) {
                costs = 2.0;
            } else if (current instanceof AggUnaryOp) {
                switch (((AggUnaryOp)current).getOp()) {
                    case SUM: {
                        costs = 4.0;
                        break;
                    }
                    case SUM_SQ: {
                        costs = 5.0;
                        break;
                    }
                    case MIN: 
                    case MAX: {
                        costs = 1.0;
                        break;
                    }
                    default: {
                        throw new RuntimeException("Cost model not implemented yet for: " + (Object)((Object)((AggUnaryOp)current).getOp()));
                    }
                }
            }
        }
        computeCosts.put(current.getHopID(), costs);
    }

    /**
     * Checks that none of the first three inputs of a memo entry refers to a
     * materialization candidate whose flag is set in the given plan.
     *
     * @param me   memo table entry whose inputs 0, 1 and 2 are inspected
     * @param M    candidate materialization points as hop IDs; index-aligned with {@code plan}
     * @param plan per-candidate flags, {@code plan[i]} belongs to {@code M.get(i)}
     * @return {@code false} as soon as an input is found in {@code M} with its
     *         flag set, {@code true} otherwise
     */
    private static boolean hasNoRefToMaterialization(CPlanMemoTable.MemoTableEntry me, ArrayList<Long> M, boolean[] plan) {
        for (int k = 0; k < 3; k++) {
            // indexOf < 0 is exactly the "not contained" case of the original check
            int pos = M.indexOf(me.input(k));
            if (pos >= 0 && plan[pos]) {
                return false;
            }
        }
        return true;
    }

    /**
     * Checks whether a hop has at least one parent (consumer) outside the given partition.
     *
     * @param hop       the hop whose parents are inspected
     * @param partition hop IDs forming the partition
     * @return {@code true} if some parent's hop ID is not in {@code partition}
     */
    private static boolean hasNonPartitionConsumer(Hop hop, HashSet<Long> partition) {
        for (Hop consumer : hop.getParent()) {
            if (!partition.contains(consumer.getHopID())) {
                return true;
            }
        }
        return false;
    }

    /**
     * Candidate group of aggregation hops that may be merged into one
     * multi-aggregate. Tracks the grouped aggregate hops (by hop ID), the hop IDs
     * registered as input aggregates and as fused inputs, and whether any member
     * was flagged as containing a matrix multiplication.
     */
    private static class AggregateInfo {
        /** Aggregate hops of this group keyed by hop ID; the constructor seeds one entry. */
        public final HashMap<Long, Hop> _aggregates;
        /** Hop IDs added via {@link #addInputAggregate} (presumably aggregates this group depends on — confirm at callers). */
        public final HashSet<Long> _inputAggs = new HashSet<>();
        /** Hop IDs added via {@link #addFusedInput}. */
        public final HashSet<Long> _fusedInputs = new HashSet<>();
        /** Set via {@link #setContainsMatMult}; kept as the OR over all merged groups. */
        public boolean containsMatMult = false;

        public AggregateInfo(Hop aggregate) {
            _aggregates = new HashMap<>();
            _aggregates.put(aggregate.getHopID(), aggregate);
        }

        public void addInputAggregate(long hopID) {
            _inputAggs.add(hopID);
        }

        public void addFusedInput(long hopID) {
            _fusedInputs.add(hopID);
        }

        public void setContainsMatMult() {
            containsMatMult = true;
        }

        /**
         * Decides whether this group can be merged with {@code that}.
         *
         * @param that the other candidate group
         * @return {@code true} if the merged group would have at most three
         *         aggregates, neither group lists an aggregate of the other as an
         *         input aggregate, both share at least one fused input, and the
         *         first inputs of one representative aggregate from each group
         *         have equal size
         */
        public boolean isMergable(AggregateInfo that) {
            // size bound: at most three aggregates per merged group
            boolean ret = _aggregates.size() < 3
                && _aggregates.size() + that._aggregates.size() <= 3;
            // independence: no group may depend on an aggregate of the other
            for (Long hopID : that._aggregates.keySet()) {
                ret &= !_inputAggs.contains(hopID);
            }
            for (Long hopID : _aggregates.keySet()) {
                ret &= !that._inputAggs.contains(hopID);
            }
            // require at least one shared fused input
            ret &= !CollectionUtils.intersection(_fusedInputs, that._fusedInputs).isEmpty();
            // size consistency, checked on an arbitrary member of each group
            // (HashMap iteration order); only evaluated if all checks above passed
            return ret && HopRewriteUtils.isEqualSize(
                _aggregates.values().iterator().next().getInput().get(0),
                that._aggregates.values().iterator().next().getInput().get(0));
        }

        /**
         * Merges {@code that} into this group in place.
         *
         * @param that the group to absorb
         * @return this group, now containing the union of both groups' state
         */
        public AggregateInfo merge(AggregateInfo that) {
            _aggregates.putAll(that._aggregates);
            _inputAggs.addAll(that._inputAggs);
            _fusedInputs.addAll(that._fusedInputs);
            // previously dropped: a merged group contains a matmult if either side did
            containsMatMult |= that.containsMatMult;
            return this;
        }

        @Override
        public String toString() {
            return "[" + Arrays.toString(_aggregates.keySet().toArray(new Long[0]))
                + ": {" + Arrays.toString(_inputAggs.toArray(new Long[0]))
                + "},{" + Arrays.toString(_fusedInputs.toArray(new Long[0])) + "}]";
        }
    }

    /**
     * Cost record with a unique ID (drawn from {@code COST_ID}), an output size,
     * accumulated compute costs, and input sizes keyed by hop ID.
     */
    private static class CostVector {
        public final long ID;
        public final double outSize;
        public double computeCosts = 0.0;
        public final HashMap<Long, Double> inSizes = new HashMap<>();

        /**
         * @param outputSize size recorded as {@link #outSize}
         */
        public CostVector(double outputSize) {
            ID = COST_ID.getNextID();
            outSize = outputSize;
        }

        /** Records (or overwrites) the input size for the given hop ID. */
        public void addInputSize(long hopID, double inputSize) {
            inSizes.put(hopID, inputSize);
        }

        /** @return the sum of all recorded input sizes (0.0 if none) */
        public double getSumInputSizes() {
            return inSizes.values().stream()
                .mapToDouble(Double::doubleValue)
                .sum();
        }

        /** @return the largest recorded input size, or 0.0 if none is recorded */
        public double getMaxInputSize() {
            return inSizes.values().stream()
                .mapToDouble(Double::doubleValue)
                .max()
                .orElse(0.0);
        }

        @Override
        public String toString() {
            // keySet() and values() of the same HashMap iterate in matching order
            Long[] hopIds = inSizes.keySet().toArray(new Long[0]);
            Double[] sizes = inSizes.values().toArray(new Double[0]);
            return new StringBuilder("[")
                .append(outSize).append(", ")
                .append(computeCosts).append(", {")
                .append(Arrays.toString(hopIds)).append(", ")
                .append(Arrays.toString(sizes)).append("}]")
                .toString();
        }
    }
}

