/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.hops.codegen.opt;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.hops.AggBinaryOp;
import org.apache.sysml.hops.AggUnaryOp;
import org.apache.sysml.hops.BinaryOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.IndexingOp;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.hops.ParameterizedBuiltinOp;
import org.apache.sysml.hops.ReorgOp;
import org.apache.sysml.hops.TernaryOp;
import org.apache.sysml.hops.UnaryOp;
import org.apache.sysml.hops.codegen.opt.InterestingPoint;
import org.apache.sysml.hops.codegen.opt.PlanAnalyzer;
import org.apache.sysml.hops.codegen.opt.PlanPartition;
import org.apache.sysml.hops.codegen.opt.PlanSelection;
import org.apache.sysml.hops.codegen.opt.ReachabilityGraph;
import org.apache.sysml.hops.codegen.template.CPlanMemoTable;
import org.apache.sysml.hops.codegen.template.TemplateBase;
import org.apache.sysml.hops.codegen.template.TemplateOuterProduct;
import org.apache.sysml.hops.codegen.template.TemplateRow;
import org.apache.sysml.hops.codegen.template.TemplateUtils;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.runtime.codegen.LibSpoofPrimitives;
import org.apache.sysml.runtime.controlprogram.caching.LazyWriteBuffer;
import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysml.runtime.util.UtilFunctions;
import org.apache.sysml.utils.Statistics;

public class PlanSelectionFuseCostBasedV2
extends PlanSelection {
    private static final Log LOG = LogFactory.getLog(PlanSelectionFuseCostBasedV2.class.getName());

    // Cost model rates: the cost functions divide data sizes (e.g., cells * 8) by these
    // values. They correspond to 512MB, 2GB, 32GB and 128MB in binary units and are
    // presumably bytes per second -- NOTE(review): confirm units against the cost model.
    private static final double WRITE_BANDWIDTH_IO = 512.0 * 1024 * 1024;
    private static final double WRITE_BANDWIDTH_MEM = 2.0 * 1024 * 1024 * 1024;
    private static final double READ_BANDWIDTH_MEM = 32.0 * 1024 * 1024 * 1024;
    private static final double READ_BANDWIDTH_BROADCAST = 128.0 * 1024 * 1024;
    // compute rate scales with the local degree of parallelism
    private static final double COMPUTE_BANDWIDTH = 2.0 * 1024 * 1024 * 1024 * InfrastructureAnalyzer.getLocalParallelism();
    // fallback sparsity for inputs with unknown dimensions (the value used by minOuterSparsity)
    private static final double SPARSE_SAFE_SPARSITY_EST = 0.1;
    // enumPlans skips the exhaustive enumeration of search spaces with at least
    // COST_MIN_EPS_NUM_POINTS points if the best heuristic plan costs at most
    // (1 + COST_MIN_EPS) times the static minimal costs
    public static final double COST_MIN_EPS = 0.01;
    public static final int COST_MIN_EPS_NUM_POINTS = 20;
    // plan cache parameters; presumably the minimal number of interesting points for
    // caching and the maximal number of cached plans -- NOTE(review): see probePlanCache/putPlan
    private static final int PLAN_CACHE_NUM_POINTS = 10;
    private static final int PLAN_CACHE_SIZE = 1024;
    // cache of optimized plans per partition signature
    private static final LinkedHashMap<PartitionSignature, boolean[]> _planCache = new LinkedHashMap<>();
    // toggles for cost-based (branch and bound) and structural pruning in enumPlans,
    // and for the partition plan cache
    public static boolean COST_PRUNING = true;
    public static boolean STRUCTURAL_PRUNING = true;
    public static boolean PLAN_CACHING = true;
    // shared row template instance, used to probe whether a hop can open a row template
    private static final TemplateRow ROW_TPL = new TemplateRow();
    private final IDSequence COST_ID = new IDSequence();

    /**
     * Selects the fused plans for the given DAG roots by optimizing each plan partition
     * independently, adding within- and across-partition multi-aggregate plans, and finally
     * keeping only the distinct best plans in the memo table.
     *
     * @param memo memo table of candidate plans (modified in place)
     * @param roots roots of the HOP DAG
     */
    @Override
    public void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots) {
        // decompose the memo table into independent plan partitions
        Collection<PlanPartition> parts = PlanAnalyzer.analyzePlanPartitions(memo, roots, true);

        // optimize each partition, including within-partition multi-aggregates
        int sumMatPoints = 0;
        for (PlanPartition part : parts) {
            createAndAddMultiAggPlans(memo, part.getPartition(), part.getRoots());
            selectPlans(memo, part);
            sumMatPoints += part.getMatPointsExt().length;
        }

        // add multi-aggregate plans that span multiple partitions
        createAndAddMultiAggPlans(memo, roots);

        // take all distinct best plans
        for (Map.Entry<Long, List<CPlanMemoTable.MemoTableEntry>> e : getBestPlans().entrySet()) {
            memo.setDistinct(e.getKey(), e.getValue());
        }

        if (DMLScript.STATISTICS) {
            // 2^sumMatPoints no longer fits into a long beyond 62 points
            if (sumMatPoints >= 63) {
                LOG.warn("Long overflow on maintaining codegen statistics for a DAG with " + sumMatPoints + " interesting points.");
            }
            Statistics.incrementCodegenEnumAll(UtilFunctions.pow(2, sumMatPoints));
        }
    }

    /**
     * Selects plans for a single partition. Without interesting points all operations are
     * fused; otherwise the materialization decisions are enumerated cost-based, suboptimal
     * and invalid memo entries are pruned, and the remaining plans are selected.
     *
     * @param memo memo table of candidate plans (modified in place)
     * @param part plan partition to optimize
     */
    private void selectPlans(CPlanMemoTable memo, PlanPartition part) {
        // prune invalid row plans and handle special cases upfront
        pruneInvalidAndSpecialCasePlans(memo, part);

        // no interesting points: nothing to enumerate, fuse all
        if (part.getMatPointsExt() == null || part.getMatPointsExt().length == 0) {
            for (Long hopID : part.getRoots()) {
                rSelectPlansFuseAll(memo, memo.getHopRefs().get(hopID), null, part.getPartition());
            }
            return;
        }

        // precompute static costs of the partition
        HashMap<Long, Double> computeCosts = new HashMap<Long, Double>();
        for (Long hopID : part.getPartition()) {
            getComputeCosts(memo.getHopRefs().get(hopID), computeCosts);
        }
        StaticCosts costs = new StaticCosts(computeCosts, sumComputeCost(computeCosts), getReadCost(part, memo), getWriteCost(part.getRoots(), memo), minOuterSparsity(part, memo));

        // structural pruning: reorder the search space and prune redundant entries
        ReachabilityGraph rgraph = STRUCTURAL_PRUNING ? new ReachabilityGraph(part, memo) : null;
        if (STRUCTURAL_PRUNING) {
            part.setMatPointsExt(rgraph.getSortedSearchSpace());
            for (Long hopID : part.getPartition()) {
                memo.pruneRedundant(hopID, true, part.getMatPointsExt());
            }
        }

        // enumerate materialization decisions and pick the cheapest plan
        boolean[] bestPlan = enumPlans(memo, part, costs, rgraph, part.getMatPointsExt(), 0);

        // prune memo entries that are suboptimal w.r.t. the chosen plan
        HashSet<Long> visited = new HashSet<Long>();
        for (Long hopID : part.getRoots()) {
            rPruneSuboptimalPlans(memo, memo.getHopRefs().get(hopID), visited, part, part.getMatPointsExt(), bestPlan);
        }

        // prune or convert row entries that became invalid after the previous step
        HashSet<Long> visited2 = new HashSet<Long>();
        for (Long hopID : part.getRoots()) {
            rPruneInvalidPlans(memo, memo.getHopRefs().get(hopID), visited2, part, bestPlan);
        }

        // select the remaining plans
        for (Long hopID : part.getRoots()) {
            rSelectPlansFuseAll(memo, memo.getHopRefs().get(hopID), null, part.getPartition());
        }
    }

    /*
     * Unable to fully structure code
     */
    private boolean[] enumPlans(CPlanMemoTable memo, PlanPartition part, StaticCosts costs, ReachabilityGraph rgraph, InterestingPoint[] matPoints, int off) {
        Mlen = matPoints.length - off;
        len = UtilFunctions.pow(2, Mlen);
        numEvalPlans = 2L;
        numEvalPartPlans = 0L;
        plan0 = PlanSelectionFuseCostBasedV2.createAssignment(Mlen, off, 0L);
        planN = PlanSelectionFuseCostBasedV2.createAssignment(Mlen, off, len - 1L);
        C0 = this.getPlanCost(memo, part, matPoints, plan0, costs._computeCosts, 1.7976931348623157E308);
        bestPlan = C0 <= (CN = this.getPlanCost(memo, part, matPoints, planN, costs._computeCosts, 1.7976931348623157E308)) ? plan0 : planN;
        bestC = Math.min(C0, CN);
        v0 = evalRemain = Mlen < 20 || PlanSelectionFuseCostBasedV2.COST_PRUNING == false || bestC > 1.01 * costs.getMinCosts();
        if (PlanSelectionFuseCostBasedV2.LOG.isTraceEnabled()) {
            PlanSelectionFuseCostBasedV2.LOG.trace((Object)("Enum opening: " + Arrays.toString(bestPlan) + " -> " + bestC));
        }
        if (!evalRemain) {
            PlanSelectionFuseCostBasedV2.LOG.warn((Object)("Skip enum for |M|=" + Mlen + ", C=" + bestC + ", Cmin=" + costs.getMinCosts()));
        }
        pKey = null;
        if (PlanSelectionFuseCostBasedV2.probePlanCache(matPoints) && (plan = PlanSelectionFuseCostBasedV2.getPlan(pKey = new PartitionSignature(part, matPoints.length, costs, C0, CN))) != null) {
            Statistics.incrementCodegenEnumAllP(rgraph != null || PlanSelectionFuseCostBasedV2.STRUCTURAL_PRUNING == false ? len : 0L);
            return plan;
        }
        i = 1L;
        while (i < len - 1L & evalRemain) {
            block17: {
                plan = PlanSelectionFuseCostBasedV2.createAssignment(Mlen, off, i);
                pskip = 0L;
                if (!PlanSelectionFuseCostBasedV2.STRUCTURAL_PRUNING || rgraph == null || !rgraph.isCutSet(plan)) break block17;
                pskip = rgraph.getNumSkipPlans(plan);
                if (PlanSelectionFuseCostBasedV2.LOG.isTraceEnabled()) {
                    PlanSelectionFuseCostBasedV2.LOG.trace((Object)("Enum: Structural pruning for cut set: " + rgraph.getCutSet(plan)));
                }
                prob = rgraph.getSubproblems(plan);
                for (j = 0; j < prob.length; ++j) {
                    if (PlanSelectionFuseCostBasedV2.LOG.isTraceEnabled()) {
                        PlanSelectionFuseCostBasedV2.LOG.trace((Object)("Enum: Subproblem " + (j + 1) + "/" + prob.length + ": " + prob[j]));
                    }
                    bestTmp = this.enumPlans(memo, part, costs, null, prob[j].freeMat, prob[j].offset);
                    LibSpoofPrimitives.vectWrite(bestTmp, plan, prob[j].freePos);
                }
                ** GOTO lbl-1000
            }
            if (PlanSelectionFuseCostBasedV2.COST_PRUNING && (lbC = PlanSelectionFuseCostBasedV2.getLowerBoundCosts(part, matPoints, memo, costs, plan)) >= bestC) {
                skip = PlanSelectionFuseCostBasedV2.getNumSkipPlans(plan);
                if (PlanSelectionFuseCostBasedV2.LOG.isTraceEnabled()) {
                    PlanSelectionFuseCostBasedV2.LOG.trace((Object)("Enum: Skip " + skip + " plans (by cost)."));
                }
                i += skip - 1L;
            } else lbl-1000:
            // 2 sources

            {
                pCBound = PlanSelectionFuseCostBasedV2.COST_PRUNING != false ? bestC : 1.7976931348623157E308;
                C = this.getPlanCost(memo, part, matPoints, plan, costs._computeCosts, pCBound);
                if (PlanSelectionFuseCostBasedV2.LOG.isTraceEnabled()) {
                    PlanSelectionFuseCostBasedV2.LOG.trace((Object)("Enum: " + Arrays.toString(plan) + " -> " + C));
                }
                numEvalPartPlans += C == Infinity ? 1L : 0L;
                ++numEvalPlans;
                if (bestPlan == null || C < bestC) {
                    bestC = C;
                    bestPlan = plan;
                    if (PlanSelectionFuseCostBasedV2.LOG.isTraceEnabled()) {
                        PlanSelectionFuseCostBasedV2.LOG.trace((Object)"Enum: Found new best plan.");
                    }
                }
                i += pskip;
                if (pskip != 0L && PlanSelectionFuseCostBasedV2.LOG.isTraceEnabled()) {
                    PlanSelectionFuseCostBasedV2.LOG.trace((Object)("Enum: Skip " + pskip + " plans (by structure)."));
                }
            }
            ++i;
        }
        if (DMLScript.STATISTICS) {
            Statistics.incrementCodegenEnumAllP(rgraph != null || PlanSelectionFuseCostBasedV2.STRUCTURAL_PRUNING == false ? len : 0L);
            Statistics.incrementCodegenEnumEval(numEvalPlans);
            Statistics.incrementCodegenEnumEvalP(numEvalPartPlans);
        }
        if (PlanSelectionFuseCostBasedV2.LOG.isTraceEnabled()) {
            PlanSelectionFuseCostBasedV2.LOG.trace((Object)("Enum: Optimal plan: " + Arrays.toString(bestPlan)));
        }
        if (PlanSelectionFuseCostBasedV2.probePlanCache(matPoints)) {
            PlanSelectionFuseCostBasedV2.putPlan(pKey, bestPlan);
        }
        return bestPlan == null ? new boolean[Mlen] : Arrays.copyOfRange(bestPlan, off, bestPlan.length);
    }

    /**
     * Builds a plan assignment of length {@code off + len}. The first {@code off} entries
     * are set to {@code true}; the following {@code len} entries hold the binary digits of
     * {@code pos}, most significant first.
     */
    private static boolean[] createAssignment(int len, int off, long pos) {
        final int total = off + len;
        boolean[] assignment = new boolean[total];
        Arrays.fill(assignment, 0, off, true);
        long remainder = pos;
        int exponent = len - 1;
        for (int k = off; k < total; k++, exponent--) {
            long weight = UtilFunctions.pow(2, exponent);
            assignment[k] = remainder >= weight;
            remainder = remainder % weight;
        }
        return assignment;
    }

    /**
     * Returns the number of plans that share the prefix up to the last {@code true} entry,
     * i.e., 2 to the power of the number of trailing {@code false} entries.
     */
    private static long getNumSkipPlans(boolean[] plan) {
        int lastSet = ArrayUtils.lastIndexOf(plan, true);
        int trailing = plan.length - lastSet - 1;
        return UtilFunctions.pow(2, trailing);
    }

    /**
     * Computes a lower bound of the costs of the given plan. The bound is the larger of the
     * static read and compute costs, plus write and materialization costs. For partitions
     * with outer products it is scaled by the minimal input sparsity.
     */
    private static double getLowerBoundCosts(PlanPartition part, InterestingPoint[] M, CPlanMemoTable memo, StaticCosts costs, boolean[] plan) {
        double staticPart = Math.max(costs._read, costs._compute) + costs._write;
        double bound = staticPart + getMaterializationCost(part, M, memo, plan);
        return part.hasOuter() ? bound * costs._minSparsity : bound;
    }

    /**
     * Estimates the costs of materializing intermediates for the given plan.
     * <ul>
     * <li>Each distinct target of a selected interesting point is charged one in-memory
     * write and one in-memory read.</li>
     * <li>Each distinct externally consumed hop not already counted is charged one write.</li>
     * </ul>
     */
    private static double getMaterializationCost(PlanPartition part, InterestingPoint[] M, CPlanMemoTable memo, boolean[] plan) {
        double total = 0.0;
        HashSet<Long> seen = new HashSet<Long>();
        for (int pos = 0; pos < plan.length; pos++) {
            long target = M[pos].getToHopID();
            if (plan[pos] && seen.add(target)) {
                long bytes = getSize(memo.getHopRefs().get(target)) * 8L;
                total += bytes / WRITE_BANDWIDTH_MEM + bytes / READ_BANDWIDTH_MEM;
            }
        }
        for (Long consumed : part.getExtConsumed()) {
            if (seen.add(consumed)) {
                total += getSize(memo.getHopRefs().get(consumed)) * 8L / WRITE_BANDWIDTH_MEM;
            }
        }
        return total;
    }

    /** Sums the in-memory read costs of all partition inputs, based on their memory estimates. */
    private static double getReadCost(PlanPartition part, CPlanMemoTable memo) {
        double total = 0.0;
        for (Long inputID : part.getInputs()) {
            Hop input = memo.getHopRefs().get(inputID);
            total += getSafeMemEst(input) / READ_BANDWIDTH_MEM;
        }
        return total;
    }

    /** Sums the in-memory write costs of the given hops, assuming 8 bytes per cell. */
    private static double getWriteCost(Collection<Long> R, CPlanMemoTable memo) {
        double total = 0.0;
        for (Long rootID : R) {
            long bytes = getSize(memo.getHopRefs().get(rootID)) * 8L;
            total += bytes / WRITE_BANDWIDTH_MEM;
        }
        return total;
    }

    /** Sums all per-hop compute costs, each divided by the compute bandwidth. */
    private static double sumComputeCost(HashMap<Long, Double> computeCosts) {
        // keep the stream-based sum (compensated summation) for identical results
        return computeCosts.values().stream()
            .mapToDouble(Double::doubleValue)
            .map(cost -> cost / COMPUTE_BANDWIDTH)
            .sum();
    }

    /**
     * Returns the minimal sparsity over the largest inputs of all partition hops for
     * partitions with outer products; other partitions return 1.0. Inputs with unknown
     * dimensions are assumed to have sparsity {@code SPARSE_SAFE_SPARSITY_EST}.
     */
    private static double minOuterSparsity(PlanPartition part, CPlanMemoTable memo) {
        if (!part.hasOuter()) {
            return 1.0;
        }
        return part.getPartition().stream()
            .map(id -> HopRewriteUtils.getLargestInput(memo.getHopRefs().get(id)))
            .mapToDouble(in -> in.dimsKnown(true) ? in.getSparsity() : SPARSE_SAFE_SPARSITY_EST)
            .min()
            .orElse(SPARSE_SAFE_SPARSITY_EST);
    }

    /** Sums the output size and the sizes of all inputs that are not transient reads. */
    private static double sumTmpInputOutputSize(CPlanMemoTable memo, CostVector vect) {
        double inputs = vect.inSizes.entrySet().stream()
            .filter(e -> !HopRewriteUtils.isData(memo.getHopRefs().get(e.getKey()), Hop.DataOpTypes.TRANSIENTREAD))
            .mapToDouble(e -> (Double) e.getValue())
            .sum();
        return vect.outSize + inputs;
    }

    /** Sums the (safe) memory estimates of all inputs recorded in the cost vector. */
    private static double sumInputMemoryEstimates(CPlanMemoTable memo, CostVector vect) {
        return vect.inSizes.keySet().stream()
            .map(id -> memo.getHopRefs().get(id))
            .mapToDouble(PlanSelectionFuseCostBasedV2::getSafeMemEst)
            .sum();
    }

    /**
     * Returns the hop's output memory estimate if its dimensions are known. Otherwise it
     * returns a dense estimate of 8 bytes per cell.
     */
    private static double getSafeMemEst(Hop hop) {
        if (hop.dimsKnown()) {
            return hop.getOutputMemEstimate();
        }
        return getSize(hop) * 8L;
    }

    /** Returns the number of cells of the hop, treating unknown or zero dimensions as 1. */
    private static long getSize(Hop hop) {
        long rows = Math.max(hop.getDim1(), 1L);
        long cols = Math.max(hop.getDim2(), 1L);
        return rows * cols;
    }

    /**
     * Creates multi-aggregate (MAGG) plans for full aggregations among the partition roots.
     * Only roots that no non-empty memo entry references as an input are considered. They
     * are grouped in chunks of at most three, and a chunk's plan is added to every member
     * if the chunk forms a valid multi-aggregate.
     *
     * @param memo memo table (modified in place)
     * @param partition hops of the partition (unused here)
     * @param R root hop ids of the partition
     */
    private static void createAndAddMultiAggPlans(CPlanMemoTable memo, HashSet<Long> partition, HashSet<Long> R) {
        // collect all hops referenced as inputs by non-empty memo entries
        HashSet<Long> refHops = new HashSet<Long>();
        for (Map.Entry<Long, List<CPlanMemoTable.MemoTableEntry>> entry : memo.getPlans().entrySet()) {
            if (!entry.getValue().isEmpty()) {
                for (Hop in : memo.getHopRefs().get(entry.getKey()).getInput()) {
                    refHops.add(in.getHopID());
                }
            }
        }

        // unreferenced roots that are full aggregations
        ArrayList<Long> fullAggs = new ArrayList<Long>();
        for (Long rootID : R) {
            if (!refHops.contains(rootID) && isMultiAggregateRoot(memo.getHopRefs().get(rootID))) {
                fullAggs.add(rootID);
            }
        }
        if (LOG.isTraceEnabled()) {
            LOG.trace("Found within-partition ua(RC) aggregations: " + Arrays.toString(fullAggs.toArray(new Long[0])));
        }

        // chunks of up to three aggregations; a trailing singleton chunk is ignored
        for (int from = 0; from < fullAggs.size(); from += 3) {
            int to = Math.min(from + 3, fullAggs.size());
            int count = to - from;
            if (count < 2) {
                continue;
            }
            long third = (count == 3) ? fullAggs.get(from + 2) : -1L;
            CPlanMemoTable.MemoTableEntry me = new CPlanMemoTable.MemoTableEntry(TemplateBase.TemplateType.MAGG, fullAggs.get(from), fullAggs.get(from + 1), third, count);
            if (!isValidMultiAggregate(memo, me)) {
                if (LOG.isTraceEnabled()) {
                    LOG.trace("Removed invalid multiagg plan: " + me);
                }
                continue;
            }
            for (int k = from; k < to; k++) {
                memo.add(memo.getHopRefs().get(fullAggs.get(k)), me);
                if (LOG.isTraceEnabled()) {
                    LOG.trace("Added multiagg plan: " + fullAggs.get(k) + " " + me);
                }
            }
        }
    }

    /*
     * WARNING - void declaration
     */
    private void createAndAddMultiAggPlans(CPlanMemoTable memo, ArrayList<Hop> roots) {
        boolean bl;
        HashSet<Long> fullAggs = new HashSet<Long>();
        Hop.resetVisitStatus(roots);
        for (Hop hop : roots) {
            PlanSelectionFuseCostBasedV2.rCollectFullAggregates(hop, fullAggs);
        }
        Hop.resetVisitStatus(roots);
        fullAggs.removeIf(p -> memo.contains((long)p, TemplateBase.TemplateType.MAGG));
        if (fullAggs.size() <= 1) {
            return;
        }
        if (LOG.isTraceEnabled()) {
            LOG.trace((Object)("Found across-partition ua(RC) aggregations: " + Arrays.toString((Object[])fullAggs.toArray(new Long[0]))));
        }
        List<Object> aggInfos = new ArrayList();
        for (Long l : fullAggs) {
            Hop hop = memo.getHopRefs().get(l);
            AggregateInfo tmp = new AggregateInfo(hop);
            for (int i = 0; i < hop.getInput().size(); ++i) {
                Hop c = HopRewriteUtils.isMatrixMultiply(hop) && i == 0 ? hop.getInput().get(0).getInput().get(0) : hop.getInput().get(i);
                PlanSelectionFuseCostBasedV2.rExtractAggregateInfo(memo, c, tmp, TemplateBase.TemplateType.CELL);
            }
            if (tmp._fusedInputs.isEmpty()) {
                if (HopRewriteUtils.isMatrixMultiply(hop)) {
                    tmp.addFusedInput(hop.getInput().get(0).getInput().get(0).getHopID());
                    tmp.addFusedInput(hop.getInput().get(1).getHopID());
                } else {
                    tmp.addFusedInput(hop.getInput().get(0).getHopID());
                }
            }
            aggInfos.add(tmp);
        }
        if (LOG.isTraceEnabled()) {
            LOG.trace((Object)"Extracted across-partition ua(RC) aggregation info: ");
            for (AggregateInfo aggregateInfo : aggInfos) {
                LOG.trace((Object)aggregateInfo);
            }
        }
        aggInfos = aggInfos.stream().sorted(Comparator.comparing(a -> a._inputAggs.size())).collect(Collectors.toList());
        boolean bl2 = false;
        while (!bl) {
            void var6_16;
            Object var6_15 = null;
            for (int i = 0; i < aggInfos.size(); ++i) {
                AggregateInfo current = (AggregateInfo)aggInfos.get(i);
                for (int j = i + 1; j < aggInfos.size(); ++j) {
                    AggregateInfo that = (AggregateInfo)aggInfos.get(j);
                    if (!current.isMergable(that)) continue;
                    AggregateInfo aggregateInfo = current.merge(that);
                    aggInfos.remove(j);
                    --j;
                }
            }
            bl = var6_16 == null;
        }
        if (LOG.isTraceEnabled()) {
            LOG.trace((Object)"Merged across-partition ua(RC) aggregation info: ");
            for (AggregateInfo aggregateInfo : aggInfos) {
                LOG.trace((Object)aggregateInfo);
            }
        }
        for (AggregateInfo aggregateInfo : aggInfos) {
            if (aggregateInfo._aggregates.size() <= 1) continue;
            Long[] aggs = aggregateInfo._aggregates.keySet().toArray(new Long[0]);
            CPlanMemoTable.MemoTableEntry me = new CPlanMemoTable.MemoTableEntry(TemplateBase.TemplateType.MAGG, aggs[0], aggs[1], aggs.length > 2 ? aggs[2] : -1L, aggs.length);
            for (int i = 0; i < aggs.length; ++i) {
                memo.add(memo.getHopRefs().get(aggs[i]), me);
                this.addBestPlan(aggs[i], me);
                if (!LOG.isTraceEnabled()) continue;
                LOG.trace((Object)("Added multiagg* plan: " + aggs[i] + " " + me));
            }
        }
    }

    /**
     * Indicates whether the hop is a full aggregation eligible for multi-aggregates. This is
     * either a row-column sum/sumsq/min/max, or a 1x1 matrix multiplication whose left
     * input is a transpose.
     */
    private static boolean isMultiAggregateRoot(Hop root) {
        boolean fullUnaryAgg = HopRewriteUtils.isAggUnaryOp(root, Hop.AggOp.SUM, Hop.AggOp.SUM_SQ, Hop.AggOp.MIN, Hop.AggOp.MAX)
            && ((AggUnaryOp) root).getDirection() == Hop.Direction.RowCol;
        if (fullUnaryAgg) {
            return true;
        }
        return root instanceof AggBinaryOp
            && root.getDim1() == 1L && root.getDim2() == 1L
            && HopRewriteUtils.isTransposeOperation(root.getInput().get(0));
    }

    /**
     * Checks whether a MAGG entry is valid. All referenced aggregations must have
     * equally-sized first inputs, and no aggregation may reach one of the other entry
     * inputs through its inputs.
     */
    private static boolean isValidMultiAggregate(CPlanMemoTable memo, CPlanMemoTable.MemoTableEntry me) {
        Hop refInput = memo.getHopRefs().get(me.input1).getInput().get(0);
        boolean valid = true;

        // equal input sizes w.r.t. the first aggregation
        for (int k = 1; valid && k < 3; k++) {
            if (me.isPlanRef(k)) {
                valid = HopRewriteUtils.isEqualSize(refInput, memo.getHopRefs().get(me.input(k)).getInput().get(0));
            }
        }

        // no dependencies between the aggregations
        for (int k = 0; valid && k < 3; k++) {
            if (!me.isPlanRef(k)) {
                continue;
            }
            HashSet<Long> others = new HashSet<Long>();
            for (int m = 0; m < 3; m++) {
                if (m != k) {
                    others.add(me.input(m));
                }
            }
            valid = rCheckMultiAggregate(memo.getHopRefs().get(me.input(k)), others);
        }
        return valid;
    }

    /** Returns false if the hop or any hop reachable through its inputs is in {@code probe}. */
    private static boolean rCheckMultiAggregate(Hop current, HashSet<Long> probe) {
        boolean ok = !probe.contains(current.getHopID());
        for (Hop in : current.getInput()) {
            ok &= rCheckMultiAggregate(in, probe);
        }
        return ok;
    }

    /**
     * Collects the ids of all multi-aggregate roots reachable from {@code current}, using
     * the hops' visit status to process each hop once.
     */
    private static void rCollectFullAggregates(Hop current, HashSet<Long> aggs) {
        if (!current.isVisited()) {
            if (isMultiAggregateRoot(current)) {
                aggs.add(current.getHopID());
            }
            for (Hop in : current.getInput()) {
                rCollectFullAggregates(in, aggs);
            }
            current.setVisited();
        }
    }

    /**
     * Walks the DAG below {@code current}, recording nested multi-aggregate roots as input
     * aggregates and matrix inputs of fused operations as fused inputs.
     *
     * <p>While {@code type} is non-null, inputs referenced by the best memo entry are
     * followed as fused. Any other input ends the fused region: the walk continues below it
     * with {@code type == null}.
     */
    private static void rExtractAggregateInfo(CPlanMemoTable memo, Hop current, AggregateInfo aggInfo, TemplateBase.TemplateType type) {
        if (isMultiAggregateRoot(current)) {
            aggInfo.addInputAggregate(current.getHopID());
        }
        CPlanMemoTable.MemoTableEntry best = (type == null) ? null : memo.getBest(current.getHopID());
        int pos = 0;
        for (Hop in : current.getInput()) {
            if (best != null && best.isPlanRef(pos)) {
                rExtractAggregateInfo(memo, in, aggInfo, type);
            } else {
                if (type != null && in.getDataType().isMatrix()) {
                    aggInfo.addFusedInput(in.getHopID());
                }
                rExtractAggregateInfo(memo, in, aggInfo, null);
            }
            pos++;
        }
    }

    /**
     * Collects the ids of partition hops whose row plans must not be replaced by cell plans.
     * The search starts from every partition root.
     */
    private static HashSet<Long> collectIrreplaceableRowOps(CPlanMemoTable memo, PlanPartition part) {
        HashSet<Long> result = new HashSet<Long>();
        HashSet<Pair<Long, Integer>> seen = new HashSet<Pair<Long, Integer>>();
        for (Long rootID : part.getRoots()) {
            Hop root = memo.getHopRefs().get(rootID);
            rCollectDependentRowOps(root, memo, part, result, seen, null, false);
        }
        return result;
    }

    /**
     * Recursively adds to {@code blacklist} the partition hops whose row plans must be kept.
     * <ul>
     * <li>Row aggregation operations.</li>
     * <li>Hops with non-equivalent row/cell alternatives.</li>
     * <li>Row-fused hops below such an operation.</li>
     * <li>Row-fused hops whose fused inputs are blacklisted.</li>
     * </ul>
     *
     * @param hop current hop
     * @param memo memo table
     * @param part partition under optimization (hops outside are ignored)
     * @param blacklist output set of irreplaceable hop ids
     * @param visited processed (hop id, context) pairs
     * @param type template type of the consuming memo entry, or {@code null}
     * @param foundRowOp whether a row operation was found above within the fused region
     */
    private static void rCollectDependentRowOps(Hop hop, CPlanMemoTable memo, PlanPartition part, HashSet<Long> blacklist, HashSet<Pair<Long, Integer>> visited, TemplateBase.TemplateType type, boolean foundRowOp) {
        final long hopID = hop.getHopID();
        // memoize per (hop, context): a hop reached with a different template type or
        // row-op state must be evaluated again
        Pair<Long, Integer> key = Pair.of(hopID, (foundRowOp ? Short.MAX_VALUE : 0) + (type != null ? type.ordinal() + 1 : 0));
        if (visited.contains(key) || !part.getPartition().contains(hopID)) {
            return;
        }

        CPlanMemoTable.MemoTableEntry me = (type == null) ? memo.getBest(hopID) : memo.getBest(hopID, type);
        boolean inRow = me != null && me.type == TemplateBase.TemplateType.ROW && type == TemplateBase.TemplateType.ROW;
        // row plans without an exactly matching cell alternative
        boolean diffPlans = part.getMatPointsExt().length > 0
            && memo.contains(hopID, TemplateBase.TemplateType.ROW)
            && !memo.hasOnlyExactMatches(hopID, TemplateBase.TemplateType.ROW, TemplateBase.TemplateType.CELL);

        // row-fused operations below a row operation
        if (inRow && foundRowOp) {
            blacklist.add(hopID);
        }
        // the row operations themselves
        if (isRowAggOp(hop, inRow) || diffPlans) {
            blacklist.add(hopID);
            foundRowOp = true;
        }

        // process inputs; the row-op state propagates only along fused inputs
        for (int i = 0; i < hop.getInput().size(); ++i) {
            boolean lfoundRowOp = foundRowOp && me != null && (me.isPlanRef(i) || isImplicitlyFused(hop, i, me.type));
            rCollectDependentRowOps(hop.getInput().get(i), memo, part, blacklist, visited, me != null ? me.type : null, lfoundRowOp);
        }

        // row operations whose fused inputs are blacklisted
        if (!blacklist.contains(hopID) && me != null && me.type == TemplateBase.TemplateType.ROW) {
            for (int i = 0; i < hop.getInput().size(); ++i) {
                boolean fused = me.isPlanRef(i) || isImplicitlyFused(hop, i, me.type);
                if (fused && blacklist.contains(hop.getInput().get(i).getHopID())) {
                    blacklist.add(hopID);
                }
            }
        }

        visited.add(key);
    }

    /**
     * Indicates whether the hop is a row operation that cannot be replaced by a cell
     * template. Such operations are:
     * <ul>
     * <li>cbind.</li>
     * <li>Matrix multiplications inside row templates, with unknown dimensions, or without
     * a vector output.</li>
     * <li>Non-vector transposes.</li>
     * <li>Unary aggregations inside row templates.</li>
     * </ul>
     */
    private static boolean isRowAggOp(Hop hop, boolean inRow) {
        if (HopRewriteUtils.isBinary(hop, Hop.OpOp2.CBIND) || HopRewriteUtils.isNary(hop, Hop.OpOpN.CBIND)) {
            return true;
        }
        if (hop instanceof AggBinaryOp && (inRow || !hop.dimsKnown() || hop.getDim1() != 1L && hop.getDim2() != 1L)) {
            return true;
        }
        if (HopRewriteUtils.isReorg(hop, Hop.ReOrgOp.TRANSPOSE) && hop.getDim1() != 1L && hop.getDim2() != 1L) {
            return true;
        }
        return hop instanceof AggUnaryOp && inRow;
    }

    /**
     * Indicates whether a row entry of this hop may be converted to a cell entry. A binary
     * cbind cannot be converted; a matrix multiplication can only if it has a vector output.
     */
    private static boolean isValidRow2CellOp(Hop hop) {
        if (HopRewriteUtils.isBinary(hop, Hop.OpOp2.CBIND)) {
            return false;
        }
        if (!(hop instanceof AggBinaryOp)) {
            return true;
        }
        return hop.getDim1() == 1L || hop.getDim2() == 1L;
    }

    /**
     * Prunes invalid memo entries and handles special cases of a partition before
     * enumeration, in three independent passes (see the helpers below).
     *
     * @param memo memo table (modified in place)
     * @param part partition under optimization
     */
    private static void pruneInvalidAndSpecialCasePlans(CPlanMemoTable memo, PlanPartition part) {
        if (OptimizerUtils.isSparkExecutionMode()) {
            pruneRowPlansWithInvalidBlocksize(memo, part);
        }
        pruneReplaceableRowPlans(memo, part);
        pruneDominatedOuterProductPlans(memo, part);
    }

    /**
     * Removes row entries (and all row references to them) of hops that run in Spark but
     * whose input or output columns exceed a single block.
     */
    private static void pruneRowPlansWithInvalidBlocksize(CPlanMemoTable memo, PlanPartition part) {
        for (Long hopID : part.getPartition()) {
            if (!memo.contains(hopID, TemplateBase.TemplateType.ROW)) {
                continue;
            }
            Hop hop = memo.getHopRefs().get(hopID);
            // executed in Spark if forced, or if the operation exceeds the local memory budget
            boolean isSpark = DMLScript.rtplatform == DMLScript.RUNTIME_PLATFORM.SPARK
                || OptimizerUtils.getTotalMemEstimate(hop.getInput().toArray(new Hop[0]), hop, true) > OptimizerUtils.getLocalMemBudget();
            // output and all matrix inputs must fit into one column block
            boolean validNcol = hop.getDataType().isScalar()
                || (HopRewriteUtils.isTransposeOperation(hop) ? hop.getDim1() <= hop.getRowsInBlock() : hop.getDim2() <= hop.getColsInBlock());
            for (Hop in : hop.getInput()) {
                validNcol &= in.getDataType().isScalar()
                    || in.getDim2() <= in.getColsInBlock()
                    || hop instanceof AggBinaryOp && in.getDim1() <= in.getRowsInBlock() && HopRewriteUtils.isTransposeOperation(in);
            }
            if (!isSpark || validNcol) {
                continue;
            }
            List<CPlanMemoTable.MemoTableEntry> removed = memo.get(hopID, TemplateBase.TemplateType.ROW);
            memo.remove(memo.getHopRefs().get(hopID), TemplateBase.TemplateType.ROW);
            memo.removeAllRefTo(hopID, TemplateBase.TemplateType.ROW);
            if (LOG.isTraceEnabled()) {
                LOG.trace("Removed row memo table entries w/ violated blocksize constraint (" + hopID + "): " + Arrays.toString(removed.toArray(new CPlanMemoTable.MemoTableEntry[0])));
            }
        }
    }

    /**
     * Removes row entries of hops that are not irreplaceable row operations, whose best
     * entry is a row entry, and for which hasOnlyExactMatches(ROW, CELL) holds. The latter
     * presumably means each row entry has an equivalent cell entry -- NOTE(review): confirm
     * in CPlanMemoTable.
     */
    private static void pruneReplaceableRowPlans(CPlanMemoTable memo, PlanPartition part) {
        HashSet<Long> blacklist = collectIrreplaceableRowOps(memo, part);
        for (Long hopID : part.getPartition()) {
            if (blacklist.contains(hopID)) {
                continue;
            }
            CPlanMemoTable.MemoTableEntry me = memo.getBest(hopID, TemplateBase.TemplateType.ROW);
            if (me == null || me.type != TemplateBase.TemplateType.ROW
                || !memo.hasOnlyExactMatches(hopID, TemplateBase.TemplateType.ROW, TemplateBase.TemplateType.CELL)) {
                continue;
            }
            List<CPlanMemoTable.MemoTableEntry> rmList = memo.get(hopID, TemplateBase.TemplateType.ROW);
            memo.remove(memo.getHopRefs().get(hopID), new HashSet<CPlanMemoTable.MemoTableEntry>(rmList));
            if (LOG.isTraceEnabled()) {
                LOG.trace("Removed row memo table entries w/o aggregation: " + Arrays.toString(rmList.toArray(new CPlanMemoTable.MemoTableEntry[0])));
            }
        }
    }

    /**
     * For hops with exactly two outer-product entries, removes the entry reported as
     * dominated by TemplateOuterProduct.dropAlternativePlan. It also removes the removed
     * entry's plan-ref input from the memo table's blacklisted plans.
     */
    private static void pruneDominatedOuterProductPlans(CPlanMemoTable memo, PlanPartition part) {
        for (Long hopID : part.getPartition()) {
            if (memo.countEntries(hopID, TemplateBase.TemplateType.OUTER) != 2) {
                continue;
            }
            List<CPlanMemoTable.MemoTableEntry> entries = memo.get(hopID, TemplateBase.TemplateType.OUTER);
            CPlanMemoTable.MemoTableEntry rmEntry = TemplateOuterProduct.dropAlternativePlan(memo, entries.get(0), entries.get(1));
            if (rmEntry == null) {
                continue;
            }
            memo.remove(memo.getHopRefs().get(hopID), Collections.singleton(rmEntry));
            memo.getPlansBlacklisted().remove(rmEntry.input(rmEntry.getPlanRefIndex()));
            if (LOG.isTraceEnabled()) {
                LOG.trace("Removed dominated outer product memo table entry: " + rmEntry);
            }
        }
    }

    /**
     * Recursively removes memo entries of partition hops that are inconsistent with the
     * chosen plan. An entry is kept if hasNoRefToMatPoint holds for it, or if it is an
     * outer-product entry. Each hop is processed once, via {@code visited}.
     */
    private static void rPruneSuboptimalPlans(CPlanMemoTable memo, Hop current, HashSet<Long> visited, PlanPartition part, InterestingPoint[] matPoints, boolean[] plan) {
        final long hopID = current.getHopID();
        if (visited.contains(hopID)) {
            return;
        }
        if (part.getPartition().contains(hopID) && memo.contains(hopID)) {
            for (Iterator<CPlanMemoTable.MemoTableEntry> it = memo.get(hopID).iterator(); it.hasNext();) {
                CPlanMemoTable.MemoTableEntry entry = it.next();
                boolean keep = hasNoRefToMatPoint(hopID, entry, matPoints, plan) || entry.type == TemplateBase.TemplateType.OUTER;
                if (keep) {
                    continue;
                }
                it.remove();
                if (LOG.isTraceEnabled()) {
                    LOG.trace("Removed memo table entry: " + entry);
                }
            }
        }
        for (Hop in : current.getInput()) {
            rPruneSuboptimalPlans(memo, in, visited, part, matPoints, plan);
        }
        visited.add(hopID);
    }

    /**
     * Post-order pass over the DAG below {@code current} that repairs ROW memo table entries of
     * partition hops which are no longer valid row-template plans.
     *
     * <p>An entry is affected if it is either
     * <ul>
     *   <li>a "leaf": it has no plan reference and {@code current} has no matrix input; or</li>
     *   <li>an "inner" node: {@code ROW_TPL.open(current)} is false, and none of its plan
     *       references points to a hop that still has a ROW entry.</li>
     * </ul>
     * Affected entries are converted to CELL if {@code isValidRow2CellOp(current)}; otherwise
     * they are removed through the memo list iterator.
     *
     * @param memo    memo table, modified in place
     * @param current hop to process
     * @param visited ids of already processed hops, updated in place
     * @param part    plan partition whose hops may be modified
     * @param plan    materialization decisions (only forwarded to the recursion)
     */
    private static void rPruneInvalidPlans(CPlanMemoTable memo, Hop current, HashSet<Long> visited, PlanPartition part, boolean[] plan) {
        final long hopID = current.getHopID();
        if (visited.contains(hopID)) {
            return;
        }
        // inputs first, so decisions below see already repaired child entries
        for (Hop input : current.getInput()) {
            rPruneInvalidPlans(memo, input, visited, part, plan);
        }
        boolean hasRowEntries = part.getPartition().contains(hopID) && memo.contains(hopID, TemplateBase.TemplateType.ROW);
        if (hasRowEntries) {
            for (Iterator<CPlanMemoTable.MemoTableEntry> it = memo.get(hopID, TemplateBase.TemplateType.ROW).iterator(); it.hasNext(); ) {
                CPlanMemoTable.MemoTableEntry entry = it.next();
                boolean leaf = !entry.hasPlanRef() && !TemplateUtils.hasMatrixInput(current);
                boolean inner = !leaf && !ROW_TPL.open(current);
                for (int k = 0; inner && k < 3; k++) {
                    if (entry.isPlanRef(k)) {
                        inner = !memo.contains(entry.input(k), TemplateBase.TemplateType.ROW);
                    }
                }
                if (!leaf && !inner) {
                    continue;
                }
                String kind = leaf ? "leaf" : "inner";
                if (isValidRow2CellOp(current)) {
                    entry.type = TemplateBase.TemplateType.CELL;
                    if (LOG.isTraceEnabled()) {
                        LOG.trace("Converted " + kind + " memo table entry from row to cell: " + entry);
                    }
                } else {
                    if (LOG.isTraceEnabled()) {
                        LOG.trace("Removed " + kind + " memo table entry row (unsupported cell): " + entry);
                    }
                    it.remove();
                }
            }
        }
        visited.add(hopID);
    }

    /**
     * Sums the estimated costs of all roots of {@code part} under the materialization decisions
     * {@code plan}, sharing one set of visit marks across the roots.
     *
     * <p>Each root is costed with the remaining budget {@code costBound - total}. The counter of
     * remaining roots is only decremented once the running total reaches {@code costBound}. As a
     * result, with more than one root the first time the bound is reached returns
     * {@link Double#POSITIVE_INFINITY}. With a single root the (finite, {@code >= costBound}) total
     * is returned.
     *
     * @return the total plan costs, or {@link Double#POSITIVE_INFINITY} on early abort
     */
    private double getPlanCost(CPlanMemoTable memo, PlanPartition part, InterestingPoint[] matPoints, boolean[] plan, HashMap<Long, Double> computeCosts, double costBound) {
        HashSet<PlanSelection.VisitMarkCost> marks = new HashSet<>();
        double total = 0.0;
        int remainingRoots = part.getRoots().size();
        for (Long rootId : part.getRoots()) {
            Hop root = memo.getHopRefs().get(rootId);
            total += rGetPlanCosts(memo, root, marks, part, matPoints, plan, computeCosts, null, null, costBound - total);
            if (total >= costBound) {
                remainingRoots--;
                if (remainingRoots > 0) {
                    return Double.POSITIVE_INFINITY;
                }
            }
        }
        return total;
    }

    /**
     * Recursively estimates the costs of the plan rooted at {@code current} under the
     * materialization decisions {@code plan} for the interesting points {@code matPoints}.
     *
     * <p>Hops fused into one operator share a single {@link CostVector}. A new operator is opened
     * at {@code current} when there is no enclosing template ({@code currentType == null}). In that
     * case a fresh cost vector is created, and the operator's I/O and compute costs are added
     * after all of its inputs were visited. Visits are memoized per (hop id, cost vector id) in
     * {@code visited}; a repeated visit contributes {@code 0}.
     *
     * @param memo         memo table with the candidate fusion plans
     * @param current      hop to cost
     * @param visited      memoization marks, updated in place
     * @param part         plan partition being costed
     * @param matPoints    interesting points of the partition
     * @param plan         materialization decision per interesting point
     * @param computeCosts precomputed compute costs per hop id (must contain {@code current})
     * @param costsCurrent cost vector of the enclosing fused operator, or {@code null}
     * @param currentType  template type of the enclosing fused operator, or {@code null}
     * @param costBound    remaining cost budget; reaching it aborts the traversal early
     * @return the estimated costs, or one of:
     *         <ul>
     *           <li>{@code 0} for an already visited (hop, cost vector) pair, or for a
     *               multi-aggregate root other than the entry's first input;</li>
     *           <li>{@link Double#POSITIVE_INFINITY} as soon as the accumulated costs reach
     *               {@code costBound}.</li>
     *         </ul>
     * @throws RuntimeException if a completed (non-aborted) estimate is negative, NaN or infinite
     */
    private double rGetPlanCosts(CPlanMemoTable memo, Hop current, HashSet<PlanSelection.VisitMarkCost> visited, PlanPartition part, InterestingPoint[] matPoints, boolean[] plan, HashMap<Long, Double> computeCosts, CostVector costsCurrent, TemplateBase.TemplateType currentType, double costBound) {
        // Divisors of the I/O cost terms below. The names reflect how each one is used and are
        // presumably bandwidths in bytes per time unit -- TODO confirm against the cost model docs.
        final double writeBandwidthMem = 2.147483648E9;     // 2 * 2^30
        final double readBandwidthMem = 3.4359738368E10;    // 32 * 2^30
        final double readBandwidthBroadcast = 1.34217728E8; // 128 * 2^20
        final double writeBandwidthIO = 5.36870912E8;       // 512 * 2^20
        final double unknownDriverSparsity = 0.1;           // used when the driver's dims are unknown

        final long currentHopId = current.getHopID();
        // memoize per hop and cost vector; unfused visits and MAGG visits share the mark -1
        long visitVectorId = (costsCurrent == null || currentType == TemplateBase.TemplateType.MAGG) ? -1L : costsCurrent.ID;
        if (!visited.add(new PlanSelection.VisitMarkCost(currentHopId, visitVectorId))) {
            return 0.0;
        }

        // select the best memo entry: any valid entry if no operator is open, otherwise an entry
        // compatible with the enclosing template (same type or CELL)
        CPlanMemoTable.MemoTableEntry best = null;
        final boolean opened = currentType == null;
        if (memo.contains(currentHopId)) {
            if (opened) {
                for (CPlanMemoTable.MemoTableEntry me : memo.get(currentHopId)) {
                    best = me.isValid() && hasNoRefToMatPoint(currentHopId, me, matPoints, plan) && PlanSelection.BasicPlanComparator.icompare(me, best) < 0 ? me : best;
                }
            } else {
                for (CPlanMemoTable.MemoTableEntry me : memo.get(currentHopId)) {
                    best = (me.type == currentType || me.type == TemplateBase.TemplateType.CELL) && hasNoRefToMatPoint(currentHopId, me, matPoints, plan) && PlanSelection.TypedPlanComparator.icompare(me, best, currentType) < 0 ? me : best;
                }
            }
        }

        // a newly opened operator starts a fresh cost vector initialized with its output size
        CostVector costVect = opened ? new CostVector(getSize(current)) : costsCurrent;
        double costs = 0.0;

        // multi-aggregate: account the shared costs of all MAGG roots to the first root only
        if (opened && best != null && best.type == TemplateBase.TemplateType.MAGG) {
            if (best.input1 != currentHopId) {
                return 0.0;
            }
            for (int i = 1; i < 3; i++) {
                if (!best.isPlanRef(i)) {
                    continue;
                }
                costs += rGetPlanCosts(memo, memo.getHopRefs().get(best.input(i)), visited, part, matPoints, plan, computeCosts, costVect, TemplateBase.TemplateType.MAGG, costBound - costs);
                if (costs >= costBound) {
                    return Double.POSITIVE_INFINITY;
                }
            }
        }

        costVect.computeCosts += computeCosts.get(currentHopId);

        // process inputs: fused references, implicitly fused transposes, or materialized inputs
        for (int i = 0; i < current.getInput().size(); i++) {
            Hop c = current.getInput().get(i);
            if (best != null && best.isPlanRef(i)) {
                costs += rGetPlanCosts(memo, c, visited, part, matPoints, plan, computeCosts, costVect, best.type, costBound - costs);
            } else if (best != null && isImplicitlyFused(current, i, best.type)) {
                costVect.addInputSize(c.getInput().get(0).getHopID(), getSize(c));
            } else {
                if (part.getPartition().contains(c.getHopID())) {
                    costs += rGetPlanCosts(memo, c, visited, part, matPoints, plan, computeCosts, null, null, costBound - costs);
                }
                if (costVect != null && c.getDataType().isMatrix()) {
                    costVect.addInputSize(c.getHopID(), getSize(c));
                }
            }
            if (costs >= costBound) {
                return Double.POSITIVE_INFINITY;
            }
        }

        if (opened) {
            double memInputs = sumInputMemoryEstimates(memo, costVect);
            double tmpCosts = costVect.outSize * 8.0 / writeBandwidthMem + Math.max(memInputs / readBandwidthMem, costVect.computeCosts / COMPUTE_BANDWIDTH);
            // inputs beyond the local memory budget: add read costs for the side inputs
            if (memInputs > OptimizerUtils.getLocalMemBudget()) {
                tmpCosts += costVect.getSideInputSize() * 8.0 / readBandwidthBroadcast;
            }
            if (best != null && best.type == TemplateBase.TemplateType.OUTER) {
                // scale by the sparsity of the largest input (the driver)
                Hop driver = memo.getHopRefs().get(costVect.getMaxInputSizeHopID());
                tmpCosts *= driver.dimsKnown(true) ? driver.getSparsity() : unknownDriverSparsity;
            } else if (memInputs <= OptimizerUtils.getLocalMemBudget() && sumTmpInputOutputSize(memo, costVect) * 8.0 > (double) LazyWriteBuffer.getWriteBufferLimit()) {
                // intermediates exceed the write buffer limit: add output write costs
                tmpCosts += costVect.outSize * 8.0 / writeBandwidthIO;
            }
            costs += tmpCosts;
            if (LOG.isTraceEnabled()) {
                String type = best != null ? best.type.name() : "HOP";
                LOG.trace("Cost vector (" + type + " " + currentHopId + "): " + costVect + " -> " + tmpCosts);
            }
        } else if (part.getExtConsumed().contains(currentHopId)) {
            // hop is fused here but also consumed outside the partition: cost it as materialized
            costs += rGetPlanCosts(memo, current, visited, part, matPoints, plan, computeCosts, null, null, costBound - costs);
            // propagate the early-abort signal (which may be +Infinity) instead of letting the
            // sanity check below mistake it for a wrong estimate
            if (costs >= costBound) {
                return Double.POSITIVE_INFINITY;
            }
        }

        if (costs < 0.0 || Double.isNaN(costs) || Double.isInfinite(costs)) {
            throw new RuntimeException("Wrong cost estimate: " + costs);
        }
        return costs;
    }

    /**
     * Estimates the compute costs of {@code current} and stores them in {@code computeCosts}
     * under the hop's id.
     *
     * <p>The estimate is a per-operation weight, multiplied by the hop's output size
     * ({@code getSize(current)}). The weight is chosen by hop type and opcode and is presumably a
     * relative instruction count (TODO confirm calibration). Unsupported opcodes log a warning and
     * use weight 1.
     *
     * @param current      hop to cost
     * @param computeCosts output map from hop id to compute costs, updated in place
     */
    private static void getComputeCosts(Hop current, HashMap<Long, Double> computeCosts) {
        double costs = 1.0;
        if (current instanceof UnaryOp) {
            costs = getUnaryOpComputeCosts((UnaryOp) current);
        } else if (current instanceof BinaryOp) {
            costs = getBinaryOpComputeCosts((BinaryOp) current);
        } else if (current instanceof TernaryOp) {
            costs = getTernaryOpComputeCosts((TernaryOp) current);
        } else if (current instanceof ParameterizedBuiltinOp || current instanceof IndexingOp || current instanceof ReorgOp) {
            costs = 1.0;
        } else if (current instanceof AggBinaryOp) {
            Hop left = current.getInput().get(0);
            // 2 ops per element of the common dimension. Unknown (non-positive) dims are clamped
            // to 1, as for AggUnaryOp, so that the costs can never become negative.
            costs = 2.0 * Math.max(left.getDim2(), 1L);
            if (left.dimsKnown(true)) {
                costs *= left.getSparsity();
            }
        } else if (current instanceof AggUnaryOp) {
            costs = getAggUnaryOpComputeCosts((AggUnaryOp) current);
        }
        // scale by the output size so that mixed row/cell operations within one fused operator
        // are weighted consistently
        costs *= getSize(current);
        computeCosts.put(current.getHopID(), costs);
    }

    /** Per-cell compute weight of a unary operation. */
    private static double getUnaryOpComputeCosts(UnaryOp op) {
        switch (op.getOp()) {
            case ABS:
            case ROUND:
            case CEIL:
            case FLOOR:
            case SIGN:
            case NCOL:
            case NROW:
            case PRINT:
            case ASSERT:
            case CAST_AS_BOOLEAN:
            case CAST_AS_DOUBLE:
            case CAST_AS_INT:
            case CAST_AS_MATRIX:
            case CAST_AS_SCALAR:
            case CUMSUM:
            case CUMMIN:
            case CUMMAX:
            case CUMPROD:
                return 1.0;
            case SPROP:
            case SQRT:
                return 2.0;
            case EXP:
            case SIN:
                return 18.0;
            case SIGMOID:
                return 21.0;
            case COS:
                return 22.0;
            case LOG:
            case LOG_NZ:
                return 32.0;
            case ATAN:
            case TANH:
                return 40.0;
            case TAN:
                return 42.0;
            case ASIN:
            case SINH:
                return 93.0;
            case ACOS:
            case COSH:
                return 103.0;
            default:
                LOG.warn("Cost model not implemented yet for: " + op.getOp());
                return 1.0;
        }
    }

    /** Per-cell compute weight of a binary operation. */
    private static double getBinaryOpComputeCosts(BinaryOp op) {
        switch (op.getOp()) {
            case MULT:
            case PLUS:
            case MINUS:
            case MIN:
            case MAX:
            case AND:
            case OR:
            case EQUAL:
            case NOTEQUAL:
            case LESS:
            case LESSEQUAL:
            case GREATER:
            case GREATEREQUAL:
            case CBIND:
            case RBIND:
                return 1.0;
            case INTDIV:
                return 6.0;
            case MODULUS:
                return 8.0;
            case DIV:
                return 22.0;
            case LOG:
            case LOG_NZ:
                return 32.0;
            case POW:
                // squaring is as cheap as a multiplication
                return HopRewriteUtils.isLiteralOfValue(op.getInput().get(1), 2.0) ? 1.0 : 16.0;
            case MINUS_NZ:
            case MINUS1_MULT:
                return 2.0;
            case CENTRALMOMENT:
                // cm(X, order): the order is the last input
                switch (getCentralMomentOrder(op.getInput().get(op.getInput().size() - 1))) {
                    case 0: return 1.0;
                    case 1: return 8.0;
                    case 2: return 16.0;
                    case 3: return 31.0;
                    case 4: return 51.0;
                    case 5: return 16.0;
                    default: return 1.0;
                }
            case COVARIANCE:
                return 23.0;
            default:
                LOG.warn("Cost model not implemented yet for: " + op.getOp());
                return 1.0;
        }
    }

    /** Per-cell compute weight of a ternary operation. */
    private static double getTernaryOpComputeCosts(TernaryOp op) {
        switch (op.getOp()) {
            case IFELSE:
            case PLUS_MULT:
            case MINUS_MULT:
                return 2.0;
            case CTABLE:
                return 3.0;
            case CENTRALMOMENT:
                // weighted cm(X, W, order): the order is the last input. Input 1 holds the
                // weights, which previously made this always fall back to order 2.
                switch (getCentralMomentOrder(op.getInput().get(op.getInput().size() - 1))) {
                    case 0: return 2.0;
                    case 1: return 9.0;
                    case 2: return 17.0;
                    case 3: return 32.0;
                    case 4: return 52.0;
                    case 5: return 17.0;
                    default: return 1.0;
                }
            case COVARIANCE:
                return 23.0;
            default:
                LOG.warn("Cost model not implemented yet for: " + op.getOp());
                return 1.0;
        }
    }

    /** Per-output compute costs of a unary aggregate, scaled by the aggregated extent. */
    private static double getAggUnaryOpComputeCosts(AggUnaryOp op) {
        double costs = 1.0;
        switch (op.getOp()) {
            case SUM:
                costs = 4.0;
                break;
            case SUM_SQ:
                costs = 5.0;
                break;
            case MIN:
            case MAX:
                costs = 1.0;
                break;
            default:
                LOG.warn("Cost model not implemented yet for: " + op.getOp());
        }
        Hop input = op.getInput().get(0);
        switch (op.getDirection()) {
            case Col:
                costs *= Math.max(input.getDim1(), 1L);
                break;
            case Row:
                costs *= Math.max(input.getDim2(), 1L);
                break;
            case RowCol:
                costs *= getSize(input);
                break;
        }
        return costs;
    }

    /** Central moment order from a literal order input; non-literal orders default to 2. */
    private static int getCentralMomentOrder(Hop orderInput) {
        return (int) (orderInput instanceof LiteralOp ? HopRewriteUtils.getIntValueSafe((LiteralOp) orderInput) : 2L);
    }

    /**
     * Negation of {@code InterestingPoint.isMatPoint(M, hopID, me, plan)}. It is true when, under
     * {@code plan}, entry {@code me} of hop {@code hopID} is not treated as referencing a
     * materialization point.
     */
    private static boolean hasNoRefToMatPoint(long hopID, CPlanMemoTable.MemoTableEntry me, InterestingPoint[] M, boolean[] plan) {
        final boolean refersToMatPoint = InterestingPoint.isMatPoint(M, hopID, me, plan);
        return !refersToMatPoint;
    }

    /**
     * Returns true if input {@code index} of {@code hop} is a transpose that a ROW template
     * absorbs implicitly. This requires that {@code type} is ROW, {@code hop} is a matrix
     * multiply, {@code index} is 0, and that input is a transpose. Callers presumably then read
     * the transpose's own input directly.
     */
    private static boolean isImplicitlyFused(Hop hop, int index, TemplateBase.TemplateType type) {
        if (type != TemplateBase.TemplateType.ROW || index != 0) {
            return false;
        }
        return HopRewriteUtils.isMatrixMultiply(hop)
            && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0));
    }

    /**
     * Returns true if there are at least 10 interesting points. Callers presumably use this to
     * decide whether consulting the plan cache is worthwhile.
     */
    private static boolean probePlanCache(InterestingPoint[] matPoints) {
        final int minInterestingPoints = 10;
        return !(matPoints.length < minInterestingPoints);
    }

    /**
     * Looks up a cached plan for the given partition signature.
     *
     * <p>The lookup is synchronized on {@code _planCache}. When statistics are enabled, the total
     * probe counter is incremented, and the hit counter as well if a plan was found.
     *
     * @param pKey partition signature to look up
     * @return the cached plan, or {@code null} if none is cached
     */
    private static boolean[] getPlan(PartitionSignature pKey) {
        final boolean[] cached;
        synchronized (_planCache) {
            cached = _planCache.get(pKey);
        }
        if (DMLScript.STATISTICS) {
            boolean hit = cached != null;
            if (hit) {
                Statistics.incrementCodegenPlanCacheHits();
            }
            Statistics.incrementCodegenPlanCacheTotal();
        }
        return cached;
    }

    /**
     * Caches {@code plan} under the partition signature {@code pKey}.
     *
     * <p>The cache is bounded to 1024 entries. When a new key is inserted into a full cache, the
     * first entry in the map's iteration order (the eldest) is evicted first. Overwriting an
     * existing key does not grow the map and therefore evicts nothing. Access is synchronized on
     * {@code _planCache}.
     *
     * @param pKey partition signature
     * @param plan selected plan to cache
     */
    private static void putPlan(PartitionSignature pKey, boolean[] plan) {
        synchronized (_planCache) {
            // previously an overwrite of an existing key in a full cache needlessly evicted the
            // eldest entry
            if (_planCache.size() >= 1024 && !_planCache.containsKey(pKey)) {
                Iterator<Map.Entry<PartitionSignature, boolean[]>> eldest = _planCache.entrySet().iterator();
                eldest.next();
                eldest.remove();
            }
            _planCache.put(pKey, plan);
        }
    }

    private class PartitionSignature {
        private final int partNodes;
        private final int inputNodes;
        private final int rootNodes;
        private final int matPoints;
        private final double cCompute;
        private final double cRead;
        private final double cWrite;
        private final double cPlan0;
        private final double cPlanN;

        public PartitionSignature(PlanPartition part, int M, StaticCosts costs, double cP0, double cPN) {
            this.partNodes = part.getPartition().size();
            this.inputNodes = part.getInputs().size();
            this.rootNodes = part.getRoots().size();
            this.matPoints = M;
            this.cCompute = costs._compute;
            this.cRead = costs._read;
            this.cWrite = costs._write;
            this.cPlan0 = cP0;
            this.cPlanN = cPN;
        }

        public int hashCode() {
            return UtilFunctions.intHashCode(Arrays.hashCode(new int[]{this.partNodes, this.inputNodes, this.rootNodes, this.matPoints}), Arrays.hashCode(new double[]{this.cCompute, this.cRead, this.cWrite, this.cPlan0, this.cPlanN}));
        }

        public boolean equals(Object o) {
            if (!(o instanceof PartitionSignature)) {
                return false;
            }
            PartitionSignature that = (PartitionSignature)o;
            return this.partNodes == that.partNodes && this.inputNodes == that.inputNodes && this.rootNodes == that.rootNodes && this.matPoints == that.matPoints && this.cCompute == that.cCompute && this.cRead == that.cRead && this.cWrite == that.cWrite && this.cPlan0 == that.cPlan0 && this.cPlanN == that.cPlanN;
        }
    }

    /**
     * Group of aggregate hops (keyed by hop id) that are candidates for a joint
     * multi-aggregate. It also tracks the ids of aggregates consumed as inputs and the ids of the
     * inputs read by the fused aggregates.
     */
    private static class AggregateInfo {
        public final HashMap<Long, Hop> _aggregates;
        public final HashSet<Long> _inputAggs = new HashSet<>();
        public final HashSet<Long> _fusedInputs = new HashSet<>();

        /** Creates a group containing only {@code aggregate}. */
        public AggregateInfo(Hop aggregate) {
            this._aggregates = new HashMap<>();
            this._aggregates.put(aggregate.getHopID(), aggregate);
        }

        /** Records that an aggregate with the given id is an input of this group. */
        public void addInputAggregate(long hopID) {
            this._inputAggs.add(hopID);
        }

        /** Records that the group reads the input with the given id. */
        public void addFusedInput(long hopID) {
            this._fusedInputs.add(hopID);
        }

        /**
         * Returns true if this group can be merged with {@code that}. The following must hold:
         * <ul>
         *   <li>this group has fewer than 3 aggregates, and both together have at most 3;</li>
         *   <li>neither group's aggregates appear among the other's input aggregates;</li>
         *   <li>the groups share at least one fused input;</li>
         *   <li>the first aggregate of each group has an equally sized main input (input 1 for
         *       matrix multiplies, input 0 otherwise).</li>
         * </ul>
         */
        public boolean isMergable(AggregateInfo that) {
            boolean ret = this._aggregates.size() < 3 && this._aggregates.size() + that._aggregates.size() <= 3;
            for (Long hopID : that._aggregates.keySet()) {
                ret &= !this._inputAggs.contains(hopID);
            }
            for (Long hopID : this._aggregates.keySet()) {
                ret &= !that._inputAggs.contains(hopID);
            }
            // shared reads: same result as a non-empty intersection, without materializing it
            ret &= !Collections.disjoint(this._fusedInputs, that._fusedInputs);
            Hop in1 = this._aggregates.values().iterator().next();
            Hop in2 = that._aggregates.values().iterator().next();
            return ret && HopRewriteUtils.isEqualSize(
                in1.getInput().get(HopRewriteUtils.isMatrixMultiply(in1) ? 1 : 0),
                in2.getInput().get(HopRewriteUtils.isMatrixMultiply(in2) ? 1 : 0));
        }

        /** Adds all aggregates, input aggregates and fused inputs of {@code that}; returns this. */
        public AggregateInfo merge(AggregateInfo that) {
            this._aggregates.putAll(that._aggregates);
            this._inputAggs.addAll(that._inputAggs);
            this._fusedInputs.addAll(that._fusedInputs);
            return this;
        }

        @Override
        public String toString() {
            return "[" + Arrays.toString(this._aggregates.keySet().toArray(new Long[0]))
                + ": {" + Arrays.toString(this._inputAggs.toArray(new Long[0]))
                + "},{" + Arrays.toString(this._fusedInputs.toArray(new Long[0])) + "}]";
        }
    }

    /**
     * Immutable holder of partition-level cost components plus the per-hop compute costs map.
     */
    private static class StaticCosts {
        // per-hop compute costs, keyed by hop id (the map itself is shared, not copied)
        public final HashMap<Long, Double> _computeCosts;
        public final double _compute;
        public final double _read;
        public final double _write;
        public final double _minSparsity;

        public StaticCosts(HashMap<Long, Double> allComputeCosts, double computeCost, double readCost, double writeCost, double minSparsity) {
            _computeCosts = allComputeCosts;
            _compute = computeCost;
            _read = readCost;
            _write = writeCost;
            _minSparsity = minSparsity;
        }

        /** Returns {@code max(read, compute) + write}. */
        public double getMinCosts() {
            final double dominant = Math.max(_read, _compute);
            return dominant + _write;
        }
    }

    /**
     * Accumulates the costs of one fused operator: output size, compute costs, and the sizes of
     * its distinct inputs (keyed by hop id). Each vector receives a fresh id from the enclosing
     * optimizer's {@code COST_ID} sequence.
     */
    private class CostVector {
        public final long ID;
        public final double outSize;
        public double computeCosts = 0.0;
        public final HashMap<Long, Double> inSizes = new HashMap<>();

        public CostVector(double outputSize) {
            ID = PlanSelectionFuseCostBasedV2.this.COST_ID.getNextID();
            outSize = outputSize;
        }

        /** Records (or overwrites) the size of the input with the given hop id. */
        public void addInputSize(long hopID, double inputSize) {
            inSizes.put(hopID, inputSize);
        }

        /** Sum of all recorded input sizes. */
        public double getInputSize() {
            return inSizes.values().stream().mapToDouble(Double::doubleValue).sum();
        }

        /** Sum of the input sizes strictly smaller than the largest one. */
        public double getSideInputSize() {
            final double largest = getMaxInputSize();
            return inSizes.values().stream().mapToDouble(Double::doubleValue).filter(size -> size < largest).sum();
        }

        /** Largest recorded input size, or {@code 0} if there are no inputs. */
        public double getMaxInputSize() {
            return inSizes.values().stream().mapToDouble(Double::doubleValue).max().orElse(0.0);
        }

        /** Hop id of the first input (in map order) with the largest positive size, else -1. */
        public long getMaxInputSizeHopID() {
            long bestId = -1L;
            double bestSize = 0.0;
            for (Map.Entry<Long, Double> entry : inSizes.entrySet()) {
                double size = entry.getValue();
                if (size > bestSize) {
                    bestId = entry.getKey();
                    bestSize = size;
                }
            }
            return bestId;
        }

        @Override
        public String toString() {
            // collection toString yields the same "[a, b]" format as Arrays.toString
            return "[" + outSize + ", " + computeCosts + ", {" + inSizes.keySet() + ", " + inSizes.values() + "}]";
        }
    }
}

