/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.api.mlcontext;

import java.io.IOException;
import java.util.HashSet;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import org.apache.commons.lang3.StringUtils;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.api.ScriptExecutorUtils;
import org.apache.sysml.api.jmlc.JMLCUtils;
import org.apache.sysml.api.mlcontext.MLContext;
import org.apache.sysml.api.mlcontext.MLContextException;
import org.apache.sysml.api.mlcontext.MLContextUtil;
import org.apache.sysml.api.mlcontext.MLResults;
import org.apache.sysml.api.mlcontext.Metadata;
import org.apache.sysml.api.mlcontext.Script;
import org.apache.sysml.conf.CompilerConfig;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.conf.DMLConfig;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.hops.globalopt.GlobalOptimizerWrapper;
import org.apache.sysml.hops.rewrite.ProgramRewriter;
import org.apache.sysml.hops.rewrite.RewriteRemovePersistentReadWrite;
import org.apache.sysml.lops.LopsException;
import org.apache.sysml.parser.DMLProgram;
import org.apache.sysml.parser.DMLTranslator;
import org.apache.sysml.parser.LanguageException;
import org.apache.sysml.parser.ParseException;
import org.apache.sysml.parser.ParserFactory;
import org.apache.sysml.parser.ParserWrapper;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
import org.apache.sysml.runtime.controlprogram.Program;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool;
import org.apache.sysml.utils.Explain;
import org.apache.sysml.utils.Statistics;

/**
 * Compiles and executes a {@link Script}.
 *
 * <p>{@link #compile(Script, boolean)} runs these steps in order: setup, parse,
 * live variable analysis, validation, HOP construction, optional HOP rewrites,
 * rewriting of persistent reads/writes, LOP construction, runtime program
 * generation, optional explain output, global data flow optimization, counting of
 * compiled jobs, caching/scratch-space initialization and runtime program cleanup.
 * {@link #execute(Script)} compiles the script, creates the execution context, runs
 * the runtime program and finally restores inputs and global flags.
 *
 * <p>The individual steps are {@code protected} so that subclasses can replace them.
 * Several steps read and write static state on {@link DMLScript}; the previous
 * values of the flags overridden by {@link #setGlobalFlags()} are restored by
 * {@link #resetGlobalFlags()}.
 */
public class ScriptExecutor {
    protected DMLConfig config;
    protected DMLProgram dmlProgram;
    protected DMLTranslator dmlTranslator;
    protected Program runtimeProgram;
    protected ExecutionContext executionContext;
    protected Script script;
    protected boolean init = false;
    protected boolean explain = false;
    protected boolean gpu = false;
    protected boolean oldGPU = false;
    protected boolean forceGPU = false;
    protected boolean oldForceGPU = false;
    protected boolean statistics = false;
    protected boolean oldStatistics = false;
    protected MLContext.ExplainLevel explainLevel;
    protected MLContext.ExecutionType executionType;
    protected int statisticsMaxHeavyHitters = 10;
    protected boolean maintainSymbolTable = false;

    /**
     * True between a call to {@link #setGlobalFlags()} and the matching
     * {@link #resetGlobalFlags()}, i.e. while the {@code old*} fields hold values
     * that still have to be written back to {@link DMLScript}.
     */
    private boolean globalFlagsSaved = false;

    /**
     * Creates an executor that uses the configuration currently returned by
     * {@link ConfigurationManager#getDMLConfig()}.
     */
    public ScriptExecutor() {
        this.config = ConfigurationManager.getDMLConfig();
    }

    /**
     * Creates an executor with the given configuration and installs that
     * configuration as the global configuration.
     *
     * @param config the configuration to use
     */
    public ScriptExecutor(DMLConfig config) {
        this.config = config;
        ConfigurationManager.setGlobalConfig(config);
    }

    /**
     * Constructs the HOPs (high-level operators) for the parsed program.
     *
     * @throws MLContextException if the translator reports a language or parse error
     */
    protected void constructHops() {
        try {
            this.dmlTranslator.constructHops(this.dmlProgram);
        }
        catch (LanguageException | ParseException e) {
            throw new MLContextException("Exception occurred while constructing HOPS (high-level operators)", e);
        }
    }

    /**
     * Applies the translator's HOP DAG rewrites to the program.
     *
     * @throws MLContextException if rewriting fails
     */
    protected void rewriteHops() {
        try {
            this.dmlTranslator.rewriteHopsDAG(this.dmlProgram);
        }
        catch (HopsException | LanguageException | ParseException | DMLRuntimeException e) {
            throw new MLContextException("Exception occurred while rewriting HOPS (high-level operators)", e);
        }
    }

    /**
     * Prints the program explanation to standard output when explain is enabled.
     * Uses the configured explain level, or {@code RUNTIME} when none was set.
     *
     * @throws MLContextException if producing the explanation fails
     */
    protected void showExplanation() {
        if (!this.explain) {
            return;
        }
        try {
            Explain.ExplainType explainType = this.explainLevel != null ? this.explainLevel.getExplainType() : Explain.ExplainType.RUNTIME;
            System.out.println(Explain.display(this.dmlProgram, this.runtimeProgram, explainType, null));
        }
        catch (Exception e) {
            throw new MLContextException("Exception occurred while explaining dml program", e);
        }
    }

    /**
     * Constructs the LOPs (low-level operators) for the program.
     *
     * @throws MLContextException if LOP construction fails
     */
    protected void constructLops() {
        try {
            this.dmlTranslator.constructLops(this.dmlProgram);
        }
        catch (HopsException | LopsException | LanguageException | ParseException e) {
            throw new MLContextException("Exception occurred while constructing LOPS (low-level operators)", e);
        }
    }

    /**
     * Generates the runtime program from the DML program and the current configuration.
     *
     * @throws MLContextException if generation fails
     */
    protected void generateRuntimeProgram() {
        try {
            this.runtimeProgram = this.dmlTranslator.getRuntimeProgram(this.dmlProgram, this.config);
        }
        catch (IOException | HopsException | LopsException | LanguageException | DMLRuntimeException e) {
            throw new MLContextException("Exception occurred while generating runtime program", e);
        }
    }

    /**
     * Counts the distributed operations of the runtime program and resets the
     * compiled-job statistic to the counted number of jobs.
     */
    protected void countCompiledMRJobsAndSparkInstructions() {
        Explain.ExplainCounts counts = Explain.countDistributedOperations(this.runtimeProgram);
        Statistics.resetNoOfCompiledJobs(counts.numJobs);
    }

    /**
     * Creates the execution context for the runtime program, attaches the script's
     * symbol table (when present) and registers the script's output variables.
     */
    protected void createAndInitializeExecutionContext() {
        this.executionContext = ExecutionContextFactory.createContext(this.runtimeProgram);
        LocalVariableMap symbolTable = this.script.getSymbolTable();
        if (symbolTable != null) {
            this.executionContext.setVariables(symbolTable);
        }
        // Other steps in this class treat a null output-variable collection as "no outputs";
        // do the same here instead of failing inside the HashSet copy constructor.
        HashSet<String> registeredOutputs = this.script.getOutputVariables() == null
                ? new HashSet<String>()
                : new HashSet<String>(this.script.getOutputVariables());
        this.executionContext.getVariables().setRegisteredOutputs(registeredOutputs);
    }

    /**
     * Saves the current statistics and GPU flags of {@link DMLScript}, overrides them
     * with this executor's settings, rebuilds the global compiler configuration and
     * applies the GPU settings read from the configuration.
     *
     * @throws RuntimeException if the compiler configuration cannot be constructed or
     *         the configured eviction policy is not a known policy
     */
    protected void setGlobalFlags() {
        // Capture every old value first so a failure further down can still be rolled back.
        this.oldStatistics = DMLScript.STATISTICS;
        this.oldForceGPU = DMLScript.FORCE_ACCELERATOR;
        this.oldGPU = DMLScript.USE_ACCELERATOR;
        this.globalFlagsSaved = true;

        DMLScript.STATISTICS = this.statistics;
        DMLScript.FORCE_ACCELERATOR = this.forceGPU;
        DMLScript.USE_ACCELERATOR = this.gpu;
        DMLScript.STATISTICS_COUNT = this.statisticsMaxHeavyHitters;
        try {
            OptimizerUtils.resetStaticCompilerFlags();
            CompilerConfig cconf = OptimizerUtils.constructCompilerConfig(ConfigurationManager.getCompilerConfig(), this.config);
            ConfigurationManager.setGlobalConfig(cconf);
        }
        catch (DMLRuntimeException ex) {
            throw new RuntimeException(ex);
        }
        GPUContextPool.AVAILABLE_GPUS = this.config.getTextValue("sysml.gpu.availableGPUs");
        // Locale.ROOT keeps the enum lookup independent of the JVM default locale
        // (e.g. a Turkish locale upper-cases 'i' to a dotted capital I).
        String evictionPolicy = this.config.getTextValue("sysml.gpu.eviction.policy").toUpperCase(Locale.ROOT);
        try {
            DMLScript.GPU_EVICTION_POLICY = DMLScript.EvictionPolicy.valueOf(evictionPolicy);
        }
        catch (IllegalArgumentException e) {
            throw new RuntimeException("Unsupported eviction policy:" + evictionPolicy, e);
        }
    }

    /**
     * Restores the statistics and GPU flags saved by {@link #setGlobalFlags()} and
     * resets the statistics count to the default option value.
     */
    protected void resetGlobalFlags() {
        DMLScript.STATISTICS = this.oldStatistics;
        DMLScript.FORCE_ACCELERATOR = this.oldForceGPU;
        DMLScript.USE_ACCELERATOR = this.oldGPU;
        DMLScript.STATISTICS_COUNT = DMLScript.DMLOptions.defaultOptions.statsCount;
        this.globalFlagsSaved = false;
    }

    /**
     * Compiles the script with HOP rewrites enabled.
     *
     * @param script the script to compile
     */
    public void compile(Script script) {
        this.compile(script, true);
    }

    /**
     * Compiles the script into a runtime program.
     *
     * @param script the script to compile
     * @param performHOPRewrites whether to run {@link #rewriteHops()}
     * @throws MLContextException if the script is invalid or a compilation step fails
     */
    public void compile(Script script, boolean performHOPRewrites) {
        this.setup(script);
        if (this.statistics) {
            Statistics.startCompileTimer();
        }
        this.parseScript();
        this.liveVariableAnalysis();
        this.validateScript();
        this.constructHops();
        if (performHOPRewrites) {
            this.rewriteHops();
        }
        this.rewritePersistentReadsAndWrites();
        this.constructLops();
        this.generateRuntimeProgram();
        this.showExplanation();
        this.globalDataFlowOptimization();
        this.countCompiledMRJobsAndSparkInstructions();
        this.initializeCachingAndScratchSpace();
        this.cleanupRuntimeProgram();
        if (this.statistics) {
            Statistics.stopCompileTimer();
        }
    }

    /**
     * Compiles and executes the script and attaches the results to it.
     *
     * <p>The global flags overridden during setup are restored whether compilation or
     * execution succeeds or fails.
     *
     * @param script the script to execute
     * @return the results of the execution
     * @throws MLContextException if compilation or execution fails
     */
    public MLResults execute(Script script) {
        boolean compiled = false;
        try {
            this.compile(script);
            compiled = true;
            this.createAndInitializeExecutionContext();
            this.executeRuntimeProgram();
        }
        finally {
            if (compiled) {
                this.cleanupAfterExecution();
            } else if (this.globalFlagsSaved) {
                // Compilation failed after setup() overrode the global flags; nothing was
                // executed, so only the flags need to be rolled back.
                this.resetGlobalFlags();
            }
        }
        MLResults mlResults = new MLResults(script);
        script.setResults(mlResults);
        return mlResults;
    }

    /**
     * Prepares this executor for the given script: validates it, links it to this
     * executor, sets the global script type and flags, and resets statistics.
     *
     * @param script the script about to be compiled
     * @throws MLContextException if the script is null or lacks a type or script string
     */
    protected void setup(Script script) {
        this.script = script;
        this.checkScriptHasTypeAndString();
        script.setScriptExecutor(this);
        DMLScript.SCRIPT_TYPE = script.getScriptType();
        this.setGlobalFlags();
        Statistics.resetNoOfExecutedJobs();
        if (this.statistics) {
            Statistics.reset();
        }
    }

    /**
     * Restores script inputs in the symbol table and then restores the global flags.
     * The flags are restored even if restoring the inputs throws.
     */
    protected void cleanupAfterExecution() {
        try {
            this.restoreInputsInSymbolTable();
        }
        finally {
            this.resetGlobalFlags();
        }
    }

    /**
     * Re-registers every script input variable that is no longer present in the
     * script's symbol table, using the originally supplied value and metadata.
     * Does nothing when the script has no symbol table or no input variables.
     */
    protected void restoreInputsInSymbolTable() {
        LocalVariableMap symbolTable = this.script.getSymbolTable();
        Set<String> inputVariables = this.script.getInputVariables();
        // Both are treated as nullable elsewhere in this class; with either missing
        // there is nothing to compare against or restore.
        if (symbolTable == null || inputVariables == null) {
            return;
        }
        Map<String, Object> inputs = this.script.getInputs();
        Map<String, Metadata> inputMetadata = this.script.getInputMetadata();
        for (String inputVariable : inputVariables) {
            if (symbolTable.get(inputVariable) != null) {
                continue;
            }
            // Metadata is optional; a missing entry yields null.
            Metadata m = inputMetadata.get(inputVariable);
            this.script.in(inputVariable, inputs.get(inputVariable), m);
        }
    }

    /**
     * Post-processes the runtime program. When the symbol table is maintained, the
     * remove-variable instructions are deleted; otherwise the program is cleaned up
     * with the script's output variables passed along.
     */
    protected void cleanupRuntimeProgram() {
        if (this.maintainSymbolTable) {
            MLContextUtil.deleteRemoveVariableInstructions(this.runtimeProgram);
        } else {
            JMLCUtils.cleanupRuntimeProgram(this.runtimeProgram, this.script.getOutputVariables() == null ? new String[]{} : this.script.getOutputVariables().toArray(new String[0]));
        }
    }

    /**
     * Executes the runtime program, passing the heavy-hitter count when statistics
     * are enabled and {@code 0} otherwise.
     *
     * @throws MLContextException if execution fails
     */
    protected void executeRuntimeProgram() {
        try {
            ScriptExecutorUtils.executeRuntimeProgram(this, this.statistics ? this.statisticsMaxHeavyHitters : 0);
        }
        catch (DMLRuntimeException e) {
            throw new MLContextException("Exception occurred while executing runtime program", e);
        }
    }

    /**
     * Initializes Hadoop execution (caching and scratch space) when {@code init} is set.
     *
     * @throws MLContextException if initialization fails
     */
    protected void initializeCachingAndScratchSpace() {
        if (!this.init) {
            return;
        }
        try {
            DMLScript.initHadoopExecution(this.config);
        }
        catch (ParseException | DMLRuntimeException e) {
            throw new MLContextException("Exception occurred initializing caching and scratch space", e);
        }
        catch (IOException e) {
            throw new MLContextException("Exception occurred initializing caching and scratch space", e);
        }
    }

    /**
     * Runs the global data flow optimizer on the runtime program when the optimizer
     * level is {@code O4_GLOBAL_TIME_MEMORY}.
     *
     * @throws MLContextException if the optimization fails
     */
    protected void globalDataFlowOptimization() {
        if (!OptimizerUtils.isOptLevel(OptimizerUtils.OptimizationLevel.O4_GLOBAL_TIME_MEMORY)) {
            return;
        }
        try {
            this.runtimeProgram = GlobalOptimizerWrapper.optimizeProgram(this.dmlProgram, this.runtimeProgram);
        }
        catch (DMLRuntimeException | HopsException | LopsException e) {
            throw new MLContextException("Exception occurred during global data flow optimization", e);
        }
    }

    /**
     * Parses the script's execution string, with its input parameters converted for
     * the parser, into a {@link DMLProgram}.
     *
     * @throws MLContextException if parsing fails
     */
    protected void parseScript() {
        try {
            ParserWrapper parser = ParserFactory.createParser(this.script.getScriptType());
            Map<String, Object> inputParameters = this.script.getInputParameters();
            Map<String, String> inputParametersStringMaps = MLContextUtil.convertInputParametersForParser(inputParameters, this.script.getScriptType());
            String scriptExecutionString = this.script.getScriptExecutionString();
            this.dmlProgram = parser.parse(null, scriptExecutionString, inputParametersStringMaps);
        }
        catch (ParseException e) {
            throw new MLContextException("Exception occurred while parsing script", e);
        }
    }

    /**
     * Rewrites persistent reads and writes of the script's input and output variables
     * against the script's symbol table. Skipped when the script has no symbol table.
     *
     * @throws MLContextException if the rewrite fails
     */
    protected void rewritePersistentReadsAndWrites() {
        LocalVariableMap symbolTable = this.script.getSymbolTable();
        if (symbolTable == null) {
            return;
        }
        String[] inputs = this.script.getInputVariables() == null ? new String[]{} : this.script.getInputVariables().toArray(new String[0]);
        String[] outputs = this.script.getOutputVariables() == null ? new String[]{} : this.script.getOutputVariables().toArray(new String[0]);
        RewriteRemovePersistentReadWrite rewrite = new RewriteRemovePersistentReadWrite(inputs, outputs, this.script.getSymbolTable());
        ProgramRewriter programRewriter = new ProgramRewriter(rewrite);
        try {
            programRewriter.rewriteProgramHopDAGs(this.dmlProgram);
        }
        catch (HopsException | LanguageException e) {
            throw new MLContextException("Exception occurred while rewriting persistent reads and writes", e);
        }
    }

    /**
     * Sets the configuration and installs it as the global configuration.
     *
     * @param config the configuration to use
     */
    public void setConfig(DMLConfig config) {
        this.config = config;
        ConfigurationManager.setGlobalConfig(config);
    }

    /**
     * Creates the translator for the parsed program and runs live variable analysis.
     *
     * @throws MLContextException if the analysis fails
     */
    protected void liveVariableAnalysis() {
        try {
            this.dmlTranslator = new DMLTranslator(this.dmlProgram);
            this.dmlTranslator.liveVariableAnalysis(this.dmlProgram);
        }
        catch (DMLRuntimeException | LanguageException e) {
            throw new MLContextException("Exception occurred during live variable analysis", e);
        }
    }

    /**
     * Validates the parse tree of the program.
     *
     * @throws MLContextException if validation fails
     */
    protected void validateScript() {
        try {
            this.dmlTranslator.validateParseTree(this.dmlProgram);
        }
        catch (LanguageException | ParseException e) {
            throw new MLContextException("Exception occurred while validating script", e);
        }
        catch (IOException e) {
            throw new MLContextException("Exception occurred while validating script", e);
        }
    }

    /**
     * Verifies that the current script is non-null, has a script type and has a
     * non-blank script string.
     *
     * @throws MLContextException if any of these conditions is not met
     */
    protected void checkScriptHasTypeAndString() {
        if (this.script == null) {
            throw new MLContextException("Script is null");
        }
        if (this.script.getScriptType() == null) {
            throw new MLContextException("ScriptType (DML or PYDML) needs to be specified");
        }
        if (this.script.getScriptString() == null) {
            throw new MLContextException("Script string is null");
        }
        if (StringUtils.isBlank((CharSequence)this.script.getScriptString())) {
            throw new MLContextException("Script string is blank");
        }
    }

    /** @return the parsed DML program, or {@code null} if nothing has been parsed yet */
    public DMLProgram getDmlProgram() {
        return this.dmlProgram;
    }

    /** @return the translator created during live variable analysis, or {@code null} */
    public DMLTranslator getDmlTranslator() {
        return this.dmlTranslator;
    }

    /** @return the generated runtime program, or {@code null} if not yet generated */
    public Program getRuntimeProgram() {
        return this.runtimeProgram;
    }

    /** @return the execution context, or {@code null} if not yet created */
    public ExecutionContext getExecutionContext() {
        return this.executionContext;
    }

    /** @return the script most recently passed to setup, or {@code null} */
    public Script getScript() {
        return this.script;
    }

    /**
     * Enables or disables printing of the program explanation during compilation.
     *
     * @param explain {@code true} to print the explanation
     */
    public void setExplain(boolean explain) {
        this.explain = explain;
    }

    /**
     * Enables or disables statistics for subsequent compilations and executions.
     *
     * @param statistics {@code true} to enable statistics
     */
    public void setStatistics(boolean statistics) {
        this.statistics = statistics;
    }

    /**
     * Sets the heavy-hitter count used when statistics are enabled.
     *
     * @param maxHeavyHitters the maximum number of heavy hitters
     */
    public void setStatisticsMaxHeavyHitters(int maxHeavyHitters) {
        this.statisticsMaxHeavyHitters = maxHeavyHitters;
    }

    /** @return whether the symbol table is maintained across runtime program cleanup */
    public boolean isMaintainSymbolTable() {
        return this.maintainSymbolTable;
    }

    /**
     * Sets whether the symbol table is maintained; this selects the cleanup strategy
     * used by {@link #cleanupRuntimeProgram()}.
     *
     * @param maintainSymbolTable {@code true} to maintain the symbol table
     */
    public void setMaintainSymbolTable(boolean maintainSymbolTable) {
        this.maintainSymbolTable = maintainSymbolTable;
    }

    /**
     * Sets whether caching and scratch space are initialized during compilation.
     *
     * @param init {@code true} to initialize
     */
    public void setInit(boolean init) {
        this.init = init;
    }

    /**
     * Sets the explain level and updates the global explain type accordingly
     * ({@code NONE} when the level is {@code null}).
     *
     * @param explainLevel the explain level, may be {@code null}
     */
    public void setExplainLevel(MLContext.ExplainLevel explainLevel) {
        this.explainLevel = explainLevel;
        if (explainLevel == null) {
            DMLScript.EXPLAIN = Explain.ExplainType.NONE;
        } else {
            DMLScript.EXPLAIN = explainLevel.getExplainType();
        }
    }

    /**
     * Sets whether the accelerator flag is enabled for subsequent runs.
     *
     * @param enabled {@code true} to enable GPU usage
     */
    public void setGPU(boolean enabled) {
        this.gpu = enabled;
    }

    /**
     * Sets whether the force-accelerator flag is enabled for subsequent runs.
     *
     * @param enabled {@code true} to force GPU usage
     */
    public void setForceGPU(boolean enabled) {
        this.forceGPU = enabled;
    }

    /** @return the configuration used by this executor */
    public DMLConfig getConfig() {
        return this.config;
    }

    /** @return the execution type, or {@code null} if none was set */
    public MLContext.ExecutionType getExecutionType() {
        return this.executionType;
    }

    /**
     * Sets the execution type and updates the global runtime platform to match.
     *
     * @param executionType the execution type; must not be {@code null}
     */
    public void setExecutionType(MLContext.ExecutionType executionType) {
        DMLScript.rtplatform = executionType.getRuntimePlatform();
        this.executionType = executionType;
    }
}

