/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.controlprogram.parfor.opt;

import java.util.ArrayList;
import java.util.Set;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.hops.ipa.InterProceduralAnalysis;
import org.apache.sysml.hops.recompile.Recompiler;
import org.apache.sysml.hops.rewrite.HopRewriteRule;
import org.apache.sysml.hops.rewrite.ProgramRewriteStatus;
import org.apache.sysml.hops.rewrite.ProgramRewriter;
import org.apache.sysml.hops.rewrite.RewriteConstantFolding;
import org.apache.sysml.hops.rewrite.RewriteRemoveUnnecessaryBranches;
import org.apache.sysml.hops.rewrite.StatementBlockRewriteRule;
import org.apache.sysml.parser.DMLProgram;
import org.apache.sysml.parser.ForStatement;
import org.apache.sysml.parser.ParForStatementBlock;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
import org.apache.sysml.runtime.controlprogram.ParForProgramBlock;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.parfor.opt.CostEstimator;
import org.apache.sysml.runtime.controlprogram.parfor.opt.CostEstimatorHops;
import org.apache.sysml.runtime.controlprogram.parfor.opt.CostEstimatorRuntime;
import org.apache.sysml.runtime.controlprogram.parfor.opt.OptTree;
import org.apache.sysml.runtime.controlprogram.parfor.opt.OptTreeConverter;
import org.apache.sysml.runtime.controlprogram.parfor.opt.Optimizer;
import org.apache.sysml.runtime.controlprogram.parfor.opt.OptimizerConstrained;
import org.apache.sysml.runtime.controlprogram.parfor.opt.OptimizerHeuristic;
import org.apache.sysml.runtime.controlprogram.parfor.opt.OptimizerRuleBased;
import org.apache.sysml.runtime.controlprogram.parfor.opt.ProgramRecompiler;
import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysml.runtime.controlprogram.parfor.stat.Stat;
import org.apache.sysml.runtime.controlprogram.parfor.stat.StatisticMonitor;
import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysml.runtime.util.UtilFunctions;
import org.apache.sysml.utils.Statistics;

/**
 * Entry point for optimizing a single parfor loop.
 *
 * <p>When dynamic recompilation is enabled, the loop body is first specialized (constant scalar
 * propagation, program rewrites, recompilation). Then an {@link OptTree} is built and handed to
 * the {@link Optimizer} selected by the requested {@link ParForProgramBlock.POptMode}.
 */
public class OptimizationWrapper {
    private static final Log LOG = LogFactory.getLog(OptimizationWrapper.class.getName());

    /** Scaling factor applied to the infrastructure's maximum degree of parallelism. */
    public static final double PAR_FACTOR_INFRASTRUCTURE = 1.0;

    /** Name of the log4j logger whose level {@link #setLogLevel(Level)} adjusts. */
    private static final String OPT_LOGGER_NAME = "org.apache.sysml.runtime.controlprogram.parfor.opt";

    /**
     * Optimizes the given parfor program block using the resource limits reported by
     * {@link InfrastructureAnalyzer}.
     *
     * @param type    optimizer mode to use
     * @param sb      parfor statement block (compile-time representation of the loop)
     * @param pb      parfor program block (runtime representation of the loop)
     * @param ec      execution context providing the current variables
     * @param monitor if true, optimization statistics are recorded via {@link StatisticMonitor}
     * @throws DMLRuntimeException if optimizer creation, recompilation, or optimization fails
     */
    public static void optimize(ParForProgramBlock.POptMode type, ParForStatementBlock sb, ParForProgramBlock pb,
            ExecutionContext ec, boolean monitor) throws DMLRuntimeException {
        Timing time = new Timing(true);
        LOG.debug("ParFOR Opt: Running optimization for ParFOR(" + pb.getID() + ")");

        // parallelism budget: the larger of the CP and MR maxima, scaled by the infrastructure factor
        int ck = UtilFunctions.toInt((double) Math.max(InfrastructureAnalyzer.getCkMaxCP(),
            InfrastructureAnalyzer.getCkMaxMR()) * PAR_FACTOR_INFRASTRUCTURE);
        // memory budget: the reported maximum scaled by the global memory utilization factor
        double cm = (double) InfrastructureAnalyzer.getCmMax() * OptimizerUtils.MEM_UTIL_FACTOR;

        optimize(type, ck, cm, sb, pb, ec, monitor);

        double timeVal = time.stop();
        LOG.debug("ParFOR Opt: Finished optimization for PARFOR(" + pb.getID() + ") in " + timeVal + "ms.");
        if (monitor) {
            StatisticMonitor.putPFStat(pb.getID(), Stat.OPT_T, timeVal);
        }
    }

    /**
     * Sets the log4j level of the parfor optimizer package logger.
     *
     * @param optLogLevel the new log level
     */
    public static void setLogLevel(Level optLogLevel) {
        Logger.getLogger(OPT_LOGGER_NAME).setLevel(optLogLevel);
    }

    /**
     * Runs the optimization pipeline with explicit resource limits.
     *
     * @param otype   optimizer mode
     * @param ck      parallelism budget
     * @param cm      memory budget (units as reported by InfrastructureAnalyzer — not visible here)
     * @param sb      parfor statement block
     * @param pb      parfor program block
     * @param ec      execution context
     * @param monitor whether to record optimizer statistics
     * @throws DMLRuntimeException if any step of the pipeline fails
     */
    private static void optimize(ParForProgramBlock.POptMode otype, int ck, double cm, ParForStatementBlock sb,
            ParForProgramBlock pb, ExecutionContext ec, boolean monitor) throws DMLRuntimeException {
        Timing time = new Timing(true);
        if (DMLScript.STATISTICS) {
            Statistics.incrementParForOptimCount();
        }

        Optimizer opt = createOptimizer(otype);
        Optimizer.CostModelType cmtype = opt.getCostModelType();
        LOG.trace("ParFOR Opt: Created optimizer (" + otype + "," + opt.getPlanInputType() + "," + cmtype + ")");

        if (ConfigurationManager.isDynamicRecompilation()) {
            if (LOG.isDebugEnabled()) {
                logPlanBeforeRecompilation(ck, cm, opt, sb, pb, ec);
            }
            replaceConstantScalarVariables(sb, ec);
            applyProgramRewrites(sb, pb);
            recompileProgramBlocks(sb, pb, ec);
        }

        try {
            OptTree tree;
            try {
                tree = OptTreeConverter.createOptTree(ck, cm, opt.getPlanInputType(), sb, pb, ec);
                // explain() walks the whole tree, so only build it when it will actually be logged
                if (LOG.isDebugEnabled()) {
                    LOG.debug("ParFOR Opt: Input plan (before optimization):\n" + tree.explain(false));
                }
            }
            catch (Exception ex) {
                throw new DMLRuntimeException("Unable to create opt tree.", ex);
            }

            CostEstimator est = createCostEstimator(cmtype, ec.getVariables());
            LOG.trace("ParFOR Opt: Created cost estimator (" + cmtype + ")");
            opt.optimize(sb, pb, tree, est, ec);
            if (LOG.isDebugEnabled()) {
                LOG.debug("ParFOR Opt: Optimized plan (after optimization): \n" + tree.explain(false));
            }
        }
        finally {
            // release the converter's static plan mappings even when optimization fails
            OptTreeConverter.clear();
        }

        long ltime = (long) time.stop();
        LOG.trace("ParFOR Opt: Optimized plan in " + ltime + "ms.");
        if (DMLScript.STATISTICS) {
            Statistics.incrementParForOptimTime(ltime);
        }
        if (monitor) {
            StatisticMonitor.putPFStat(pb.getID(), Stat.OPT_OPTIMIZER, otype.ordinal());
            StatisticMonitor.putPFStat(pb.getID(), Stat.OPT_NUMTPLANS, opt.getNumTotalPlans());
            StatisticMonitor.putPFStat(pb.getID(), Stat.OPT_NUMEPLANS, opt.getNumEvaluatedPlans());
        }
    }

    /** Logs the loop's plan as it looks before dynamic recompilation (call only when debug is enabled). */
    private static void logPlanBeforeRecompilation(int ck, double cm, Optimizer opt, ParForStatementBlock sb,
            ParForProgramBlock pb, ExecutionContext ec) throws DMLRuntimeException {
        try {
            OptTree tree = OptTreeConverter.createOptTree(ck, cm, opt.getPlanInputType(), sb, pb, ec);
            LOG.debug("ParFOR Opt: Input plan (before recompilation):\n" + tree.explain(false));
        }
        catch (Exception ex) {
            throw new DMLRuntimeException("Unable to create opt tree.", ex);
        }
        finally {
            OptTreeConverter.clear();
        }
    }

    /**
     * Substitutes into the loop body the scalar variables that
     * {@code ProgramRecompiler.getReusableScalarVariables} reports for the current variables.
     */
    private static void replaceConstantScalarVariables(ParForStatementBlock sb, ExecutionContext ec)
            throws DMLRuntimeException {
        try {
            LocalVariableMap constVars =
                ProgramRecompiler.getReusableScalarVariables(sb.getDMLProg(), sb, ec.getVariables());
            ProgramRecompiler.replaceConstantScalarVariables(sb, constVars);
        }
        catch (Exception ex) {
            throw new DMLRuntimeException("ParFOR Opt: Failed to replace constant scalar variables.", ex);
        }
    }

    /**
     * Applies constant folding (HOP level) and unnecessary-branch removal (statement-block level)
     * to the loop body, and rebuilds the runtime child blocks if any branch was removed.
     */
    private static void applyProgramRewrites(ParForStatementBlock sb, ParForProgramBlock pb)
            throws DMLRuntimeException {
        try {
            ForStatement fs = (ForStatement) sb.getStatement(0);
            ProgramRewriter rewriter = createProgramRewriterWithRuleSets();
            ProgramRewriteStatus state = new ProgramRewriteStatus();
            rewriter.rRewriteStatementBlockHopDAGs(sb, state);
            fs.setBody(rewriter.rRewriteStatementBlocks(fs.getBody(), state, true));
            // the runtime program was generated from the old body; regenerate it if its structure changed
            if (state.getRemovedBranches()) {
                LOG.debug("ParFOR Opt: Removed branches during program rewrites, rebuilding runtime program");
                pb.setChildBlocks(ProgramRecompiler.generatePartitialRuntimeProgram(pb.getProgram(), fs.getBody()));
            }
        }
        catch (Exception ex) {
            throw new DMLRuntimeException("ParFOR Opt: Failed to apply program rewrites.", ex);
        }
    }

    /**
     * Recompiles the loop's child blocks against a copy of the current variables. If the loop calls
     * functions, it also recompiles the candidates reported by inter-procedural analysis.
     */
    private static void recompileProgramBlocks(ParForStatementBlock sb, ParForProgramBlock pb, ExecutionContext ec)
            throws DMLRuntimeException {
        try {
            // recompile against a copy so the live variable map is not modified
            LocalVariableMap tmp = (LocalVariableMap) ec.getVariables().clone();
            Recompiler.ResetType reset = ConfigurationManager.isCodegenEnabled()
                ? Recompiler.ResetType.RESET_KNOWN_DIMS
                : Recompiler.ResetType.RESET;
            Recompiler.recompileProgramBlockHierarchy(pb.getChildBlocks(), tmp, 0L, reset);

            if (pb.hasFunctions()) {
                InterProceduralAnalysis ipa = new InterProceduralAnalysis(sb);
                Set<String> fcand = ipa.analyzeSubProgram();
                for (String func : fcand) {
                    String[] funcparts = DMLProgram.splitFunctionKey(func);
                    FunctionProgramBlock fpb = pb.getProgram().getFunctionProgramBlock(funcparts[0], funcparts[1]);
                    // only recompile-once functions get the reset; all others are recompiled without reset
                    Recompiler.ResetType funcReset = fpb.isRecompileOnce() ? reset : Recompiler.ResetType.NO_RESET;
                    Recompiler.recompileProgramBlockHierarchy(fpb.getChildBlocks(), new LocalVariableMap(), 0L, funcReset);
                }
            }
        }
        catch (Exception ex) {
            throw new DMLRuntimeException("ParFOR Opt: Failed to recompile parfor body.", ex);
        }
    }

    /**
     * Instantiates the optimizer for the given mode.
     *
     * @throws DMLRuntimeException if the mode has no optimizer implementation
     */
    private static Optimizer createOptimizer(ParForProgramBlock.POptMode otype) throws DMLRuntimeException {
        switch (otype) {
            case HEURISTIC:
                return new OptimizerHeuristic();
            case RULEBASED:
                return new OptimizerRuleBased();
            case CONSTRAINED:
                return new OptimizerConstrained();
            default:
                throw new DMLRuntimeException("Undefined optimizer: '" + otype + "'.");
        }
    }

    /**
     * Instantiates the cost estimator for the given cost model over the converter's current
     * abstract plan mapping.
     *
     * @param cmtype cost model requested by the optimizer
     * @param vars   current variables; the runtime-metrics estimator receives a clone
     * @throws DMLRuntimeException if the cost model type is not supported
     */
    private static CostEstimator createCostEstimator(Optimizer.CostModelType cmtype, LocalVariableMap vars)
            throws DMLRuntimeException {
        switch (cmtype) {
            case STATIC_MEM_METRIC:
                return new CostEstimatorHops(OptTreeConverter.getAbstractPlanMapping());
            case RUNTIME_METRICS:
                return new CostEstimatorRuntime(OptTreeConverter.getAbstractPlanMapping(), (LocalVariableMap) vars.clone());
            default:
                throw new DMLRuntimeException("Undefined cost model type: '" + cmtype + "'.");
        }
    }

    /** Builds a rewriter with constant folding (HOP) and unnecessary-branch removal (statement block). */
    private static ProgramRewriter createProgramRewriterWithRuleSets() {
        ArrayList<HopRewriteRule> hRewrites = new ArrayList<>();
        hRewrites.add(new RewriteConstantFolding());
        ArrayList<StatementBlockRewriteRule> sbRewrites = new ArrayList<>();
        sbRewrites.add(new RewriteRemoveUnnecessaryBranches());
        return new ProgramRewriter(hRewrites, sbRewrites);
    }
}

