/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.controlprogram.parfor.opt;

import java.util.HashMap;
import java.util.HashSet;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.lops.LopProperties;
import org.apache.sysml.parser.ParForStatementBlock;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
import org.apache.sysml.runtime.controlprogram.ParForProgramBlock;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.parfor.opt.CostEstimator;
import org.apache.sysml.runtime.controlprogram.parfor.opt.OptNode;
import org.apache.sysml.runtime.controlprogram.parfor.opt.OptTree;
import org.apache.sysml.runtime.controlprogram.parfor.opt.OptTreeConverter;
import org.apache.sysml.runtime.controlprogram.parfor.opt.OptimizerRuleBased;

public class OptimizerConstrained
extends OptimizerRuleBased {
    /**
     * Identifies the optimization mode implemented by this optimizer.
     *
     * @return always {@code ParForProgramBlock.POptMode.CONSTRAINED}
     */
    @Override
    public ParForProgramBlock.POptMode getOptMode() {
        final ParForProgramBlock.POptMode constrainedMode = ParForProgramBlock.POptMode.CONSTRAINED;
        return constrainedMode;
    }

    /**
     * Runs the constrained optimization pipeline on the plan rooted at {@code plan.getRoot()}.
     *
     * <p>The method applies a fixed sequence of rewrites to the root node, re-estimating
     * memory between steps. Some rewrites are invoked through {@code this} and others
     * explicitly through {@code super}; the ones overridden in this class first check whether
     * the node already carries a forced setting and only otherwise delegate to the parent.
     *
     * @param sb   statement block of the parfor (not referenced in this method body)
     * @param pb   program block of the parfor (not referenced in this method body)
     * @param plan optimizer tree; only its root node is used here
     * @param est  cost estimator, stored in {@code _cost} and used for all memory estimates
     * @param ec   execution context; supplies the variable map passed to the rewrites
     * @return {@code true} on every path visible in this method
     * @throws DMLRuntimeException declared by the signature; presumably raised by the invoked
     *         rewrites or estimates (their bodies are not visible here)
     */
    @Override
    public boolean optimize(ParForStatementBlock sb, ParForProgramBlock pb, OptTree plan, CostEstimator est, ExecutionContext ec) throws DMLRuntimeException {
        LOG.debug((Object)("--- " + (Object)((Object)this.getOptMode()) + " OPTIMIZER -------"));
        OptNode pn = plan.getRoot();
        // Early exit: a leaf root gets no rewrites at all.
        if (pn.isLeaf()) {
            return true;
        }
        // NOTE(review): presumably initializes _lm, _rm, _rnk, _rk etc. used below - confirm in parent class.
        super.analyzeProblemAndInfrastructure(pn);
        this._cost = est;
        LOG.debug((Object)((Object)((Object)this.getOptMode()) + " OPT: Optimize with local_max_mem=" + OptimizerConstrained.toMB(this._lm) + " and remote_max_mem=" + OptimizerConstrained.toMB(this._rm) + ")."));
        // Only a warning: optimization continues even if the cluster looks inactive.
        if (this._rnk <= 0 || this._rk <= 0) {
            LOG.warn((Object)((Object)((Object)this.getOptMode()) + " OPT: Optimize for inactive cluster (num_nodes=" + this._rnk + ", num_map_slots=" + this._rk + ")."));
        }
        // Estimate memory for serial execution (M0a) without losing the node's current
        // exec type and k: save both, switch to serial, estimate, then restore.
        OptNode.ExecType oldET = pn.getExecType();
        int oldK = pn.getK();
        pn.setSerialParFor();
        double M0a = this._cost.getEstimate(CostEstimator.TestMeasure.MEMORY_USAGE, pn);
        pn.setExecType(oldET);
        pn.setK(oldK);
        LOG.debug((Object)((Object)((Object)this.getOptMode()) + " OPT: estimated mem (serial exec) M=" + OptimizerConstrained.toMB(M0a)));
        // Data partitioning rewrite, using the local memory budget as threshold; the map is
        // passed on to later rewrites (replication factor, fused execution, ...).
        HashMap<String, ParForProgramBlock.PartitionFormat> partitionedMatrices = new HashMap<String, ParForProgramBlock.PartitionFormat>();
        this.rewriteSetDataPartitioner(pn, ec.getVariables(), partitionedMatrices, OptimizerUtils.getLocalMemBudget());
        // Re-estimate after data partitioning (M0b feeds the result partitioning rewrite).
        double M0b = this._cost.getEstimate(CostEstimator.TestMeasure.MEMORY_USAGE, pn);
        this.rewriteRemoveUnnecessaryCompareMatrix(pn, ec);
        boolean flagLIX = super.rewriteSetResultPartitioning(pn, M0b, ec.getVariables());
        // Three further memory estimates: current plan (M1), all operations forced to CP (M2),
        // and a variant requested with a boolean flag that the log message below labels
        // "cond partitioning" (M3).
        double M1 = this._cost.getEstimate(CostEstimator.TestMeasure.MEMORY_USAGE, pn);
        LOG.debug((Object)((Object)((Object)this.getOptMode()) + " OPT: estimated new mem (serial exec) M=" + OptimizerConstrained.toMB(M1)));
        double M2 = this._cost.getEstimate(CostEstimator.TestMeasure.MEMORY_USAGE, pn, LopProperties.ExecType.CP);
        LOG.debug((Object)((Object)((Object)this.getOptMode()) + " OPT: estimated new mem (serial exec, all CP) M=" + OptimizerConstrained.toMB(M2)));
        double M3 = this._cost.getEstimate(CostEstimator.TestMeasure.MEMORY_USAGE, pn, true);
        LOG.debug((Object)((Object)((Object)this.getOptMode()) + " OPT: estimated new mem (cond partitioning) M=" + OptimizerConstrained.toMB(M3)));
        // Capture the program block's exec mode BEFORE the execution-strategy rewrite, which
        // may overwrite it; the captured value is later passed to the fused-partitioning rewrite.
        ParForProgramBlock.PExecMode tmpmode = this.getPExecMode(pn);
        boolean flagRecompMR = this.rewriteSetExecutionStategy(pn, M0a, M1, M2, M3, flagLIX);
        if (pn.getExecType() == this.getRemoteExecType()) {
            // Remote execution branch.
            // If the current plan exceeds the remote budget but the M3 estimate fits,
            // redo data partitioning with M3 as threshold and re-estimate.
            if (M1 > this._rm && M3 <= this._rm) {
                this.rewriteSetDataPartitioner(pn, ec.getVariables(), partitionedMatrices, M3);
                M1 = this._cost.getEstimate(CostEstimator.TestMeasure.MEMORY_USAGE, pn);
            }
            // Apply the operation exec-type rewrite only when the strategy rewrite asked for it.
            if (flagRecompMR) {
                this.rewriteSetOperationsExecType(pn, flagRecompMR);
                M1 = this._cost.getEstimate(CostEstimator.TestMeasure.MEMORY_USAGE, pn);
            }
            super.rewriteDataColocation(pn, ec.getVariables());
            super.rewriteSetPartitionReplicationFactor(pn, partitionedMatrices, ec.getVariables());
            super.rewriteSetExportReplicationFactor(pn, ec.getVariables());
            boolean flagNested = super.rewriteNestedParallelism(pn, M1, flagLIX);
            this.rewriteSetDegreeOfParallelism(pn, M1, flagNested);
            this.rewriteSetTaskPartitioner(pn, flagNested, flagLIX);
            this.rewriteSetFusedDataPartitioningExecution(pn, M1, flagLIX, partitionedMatrices, ec.getVariables(), tmpmode);
            super.rewriteSetTranposeSparseVectorOperations(pn, partitionedMatrices, ec.getVariables());
            HashSet<String> inplaceResultVars = new HashSet<String>();
            super.rewriteSetInPlaceResultIndexing(pn, M1, ec.getVariables(), inplaceResultVars, ec);
            super.rewriteDisableCPCaching(pn, inplaceResultVars, ec.getVariables());
        } else {
            // Non-remote branch: no nested parallelism and no LIX flag are passed on.
            this.rewriteSetDegreeOfParallelism(pn, M1, false);
            this.rewriteSetTaskPartitioner(pn, false, false);
            HashSet<String> inplaceResultVars = new HashSet<String>();
            super.rewriteSetInPlaceResultIndexing(pn, M1, ec.getVariables(), inplaceResultVars, ec);
            // Backend-specific rewrites: runtime piggybacking when not in Spark execution mode,
            // otherwise the Spark checkpoint/repartition/eager-caching rewrites.
            if (!OptimizerUtils.isSparkExecutionMode()) {
                super.rewriteEnableRuntimePiggybacking(pn, ec.getVariables(), partitionedMatrices);
            } else {
                super.rewriteInjectSparkLoopCheckpointing(pn);
                super.rewriteInjectSparkRepartition(pn, ec.getVariables());
                super.rewriteSetSparkEagerRDDCaching(pn, ec.getVariables());
            }
        }
        // Rewrites applied in both branches.
        this.rewriteSetResultMerge(pn, ec.getVariables(), true);
        super.rewriteSetRecompileMemoryBudget(pn);
        super.rewriteRemoveRecursiveParFor(pn, ec.getVariables());
        super.rewriteRemoveUnnecessaryParFor(pn);
        // This method walks a single fixed pipeline, so the counter is set to one.
        this._numEvaluatedPlans = 1L;
        return true;
    }

    /**
     * Sets the data partitioner of a parfor node while honoring a forced setting.
     *
     * <p>If the node's {@code DATA_PARTITIONER} parameter differs from
     * {@code PDataPartitioner.UNSPECIFIED}, that value is written directly to the mapped
     * {@link ParForProgramBlock} and the rule-based rewrite is skipped. Otherwise the call is
     * delegated to the parent class and the parent's result is returned. Previously that
     * result was dropped and {@code false} was returned unconditionally.
     *
     * @param n                   parfor node to rewrite
     * @param vars                variable map, forwarded to the parent rewrite
     * @param partitionedMatrices map forwarded to the parent rewrite; not touched on the forced path
     * @param thetaM              memory threshold, forwarded to the parent rewrite
     * @return the parent rewrite's result when delegating; {@code false} when a partitioner
     *         was forced (no block-wise decision is made on that path)
     * @throws DMLRuntimeException if the parent rewrite fails
     */
    @Override
    protected boolean rewriteSetDataPartitioner(OptNode n, LocalVariableMap vars, HashMap<String, ParForProgramBlock.PartitionFormat> partitionedMatrices, double thetaM) throws DMLRuntimeException {
        String requested = n.getParam(OptNode.ParamType.DATA_PARTITIONER);
        if (!requested.equals(ParForProgramBlock.PDataPartitioner.UNSPECIFIED.toString())) {
            // Constraint awareness: apply the forced partitioner as-is.
            Object[] mapped = OptTreeConverter.getAbstractPlanMapping().getMappedProg(n.getID());
            ParForProgramBlock pfpb = (ParForProgramBlock)mapped[1];
            pfpb.setDataPartitioner(ParForProgramBlock.PDataPartitioner.valueOf(requested));
            LOG.debug(this.getOptMode() + " OPT: forced 'set data partitioner' - result=" + requested);
            return false;
        }
        // No constraint: use the rule-based decision and propagate its result.
        return super.rewriteSetDataPartitioner(n, vars, partitionedMatrices, thetaM);
    }

    /**
     * Sets the execution strategy of a parfor node while honoring a forced execution type.
     *
     * <p>When the node already has an execution type and
     * {@code ConfigurationManager.isParallelParFor()} is true, the matching execution mode
     * ({@code REMOTE_MR} for {@code MR}, {@code REMOTE_SPARK} for {@code SPARK}, otherwise
     * {@code LOCAL}) is written to the mapped program block. In every other case the decision
     * is left to the parent class.
     *
     * @param n       parfor node to rewrite
     * @param M0      memory estimate forwarded to the parent rewrite
     * @param M       memory estimate forwarded to the parent rewrite
     * @param M2      memory estimate forwarded to the parent rewrite
     * @param M3      memory estimate forwarded to the parent rewrite
     * @param flagLIX flag forwarded to the parent rewrite
     * @return {@code false} when the mode was forced; otherwise the parent rewrite's result
     * @throws DMLRuntimeException if the parent rewrite fails
     */
    @Override
    protected boolean rewriteSetExecutionStategy(OptNode n, double M0, double M, double M2, double M3, boolean flagLIX) throws DMLRuntimeException {
        // No forced execution type (or parallel parfor disabled): use the rule-based decision.
        if (n.getExecType() == null || !ConfigurationManager.isParallelParFor()) {
            return super.rewriteSetExecutionStategy(n, M0, M, M2, M3, flagLIX);
        }

        Object[] mapped = OptTreeConverter.getAbstractPlanMapping().getMappedProg(n.getID());
        ParForProgramBlock target = (ParForProgramBlock)mapped[1];

        // Translate the node's execution type into the program block's execution mode.
        ParForProgramBlock.PExecMode forcedMode = n.getExecType() == OptNode.ExecType.MR
                ? ParForProgramBlock.PExecMode.REMOTE_MR
                : (n.getExecType() == OptNode.ExecType.SPARK
                        ? ParForProgramBlock.PExecMode.REMOTE_SPARK
                        : ParForProgramBlock.PExecMode.LOCAL);
        target.setExecMode(forcedMode);
        LOG.debug(this.getOptMode() + " OPT: forced 'set execution strategy' - result=" + forcedMode);
        return false;
    }

    /**
     * Sets the degree of parallelism of a parfor node while honoring a forced value.
     *
     * <p>A positive {@code k} on the node, combined with
     * {@code ConfigurationManager.isParallelParFor()} being true, is copied to the mapped
     * program block unchanged. Otherwise the parent class decides.
     *
     * @param n          parfor node to rewrite
     * @param M          memory estimate forwarded to the parent rewrite
     * @param flagNested flag forwarded to the parent rewrite
     * @throws DMLRuntimeException if the parent rewrite fails
     */
    @Override
    protected void rewriteSetDegreeOfParallelism(OptNode n, double M, boolean flagNested) throws DMLRuntimeException {
        boolean kIsForced = n.getK() > 0 && ConfigurationManager.isParallelParFor();
        if (!kIsForced) {
            super.rewriteSetDegreeOfParallelism(n, M, flagNested);
            return;
        }

        Object[] mapped = OptTreeConverter.getAbstractPlanMapping().getMappedProg(n.getID());
        ParForProgramBlock target = (ParForProgramBlock)mapped[1];
        target.setDegreeOfParallelism(n.getK());
        LOG.debug(this.getOptMode() + " OPT: forced 'set degree of parallelism' - result=(see EXPLAIN)");
    }

    /**
     * Sets the task partitioner of a parfor node while honoring forced settings.
     *
     * <p>If the node's {@code TASK_PARTITIONER} parameter differs from
     * {@code PTaskPartitioner.UNSPECIFIED}, it is applied to the mapped program block,
     * together with {@code TASK_SIZE} when that parameter is present. If no partitioner is
     * forced, a task size on its own cannot be applied: a warning is logged and the parent
     * class decides.
     *
     * @param pn         parfor node to rewrite
     * @param flagNested flag forwarded to the parent rewrite
     * @param flagLIX    flag forwarded to the parent rewrite
     */
    @Override
    protected void rewriteSetTaskPartitioner(OptNode pn, boolean flagNested, boolean flagLIX) {
        boolean partitionerUnspecified = pn.getParam(OptNode.ParamType.TASK_PARTITIONER)
                .equals(ParForProgramBlock.PTaskPartitioner.UNSPECIFIED.toString());

        if (partitionerUnspecified) {
            if (pn.getParam(OptNode.ParamType.TASK_SIZE) != null) {
                LOG.warn("Cannot force task size without forcing task partitioner.");
            }
            super.rewriteSetTaskPartitioner(pn, flagNested, flagLIX);
            return;
        }

        Object[] mapped = OptTreeConverter.getAbstractPlanMapping().getMappedProg(pn.getID());
        ParForProgramBlock target = (ParForProgramBlock)mapped[1];
        target.setTaskPartitioner(ParForProgramBlock.PTaskPartitioner.valueOf(pn.getParam(OptNode.ParamType.TASK_PARTITIONER)));

        // The optional task size is applied and appended to the log output as ",<size>".
        StringBuilder sizeSuffix = new StringBuilder();
        if (pn.getParam(OptNode.ParamType.TASK_SIZE) != null) {
            target.setTaskSize(Integer.parseInt(pn.getParam(OptNode.ParamType.TASK_SIZE)));
            sizeSuffix.append(",").append(pn.getParam(OptNode.ParamType.TASK_SIZE));
        }
        LOG.debug(this.getOptMode() + " OPT: forced 'set task partitioner' - result=" + pn.getParam(OptNode.ParamType.TASK_PARTITIONER) + sizeSuffix);
    }

    /**
     * Sets the result merge strategy of a parfor node while honoring a forced setting.
     *
     * <p>If the node's {@code RESULT_MERGE} parameter differs from
     * {@code PResultMerge.UNSPECIFIED}, it is applied to the mapped program block. Otherwise
     * the parent class decides.
     *
     * @param n       parfor node to rewrite
     * @param vars    variable map forwarded to the parent rewrite
     * @param inLocal flag forwarded to the parent rewrite
     * @throws DMLRuntimeException if the parent rewrite fails
     */
    @Override
    protected void rewriteSetResultMerge(OptNode n, LocalVariableMap vars, boolean inLocal) throws DMLRuntimeException {
        boolean mergeUnspecified = n.getParam(OptNode.ParamType.RESULT_MERGE)
                .equals(ParForProgramBlock.PResultMerge.UNSPECIFIED.toString());
        if (mergeUnspecified) {
            super.rewriteSetResultMerge(n, vars, inLocal);
            return;
        }

        Object[] mapped = OptTreeConverter.getAbstractPlanMapping().getMappedProg(n.getID());
        ParForProgramBlock target = (ParForProgramBlock)mapped[1];
        target.setResultMerge(ParForProgramBlock.PResultMerge.valueOf(n.getParam(OptNode.ParamType.RESULT_MERGE)));
        LOG.debug(this.getOptMode() + " OPT: force 'set result merge' - result=" + n.getParam(OptNode.ParamType.RESULT_MERGE));
    }

    /**
     * Applies fused data partitioning and execution, honoring a forced fused execution mode.
     *
     * <p>When {@code emode} is {@code REMOTE_MR_DP} or {@code REMOTE_SPARK_DP}, the rewrite
     * is attempted for the first entry of {@code partitionedMatrices}. It is applied only if
     * that matrix is accessed by the iteration variable and its partition format and
     * dimensions match the iteration count {@code _N} (see the condition below). For any other
     * {@code emode} the call is delegated to the parent's five-argument rewrite.
     *
     * <p>The final debug message reports whether the rewrite was actually applied.
     * Previously it always printed {@code result=true}, even when the check failed and
     * nothing had been changed.
     *
     * @param pn                  parfor node to rewrite
     * @param M                   memory estimate forwarded to the parent rewrite
     * @param flagLIX             flag forwarded to the parent rewrite
     * @param partitionedMatrices partitioned matrices by variable name; only the first entry
     *                            returned by the key iterator is considered
     * @param vars                variable map used to look up the partitioned matrix
     * @param emode               execution mode the program block had before the
     *                            execution-strategy rewrite ran
     * @throws DMLRuntimeException if the parent rewrite or a helper fails
     */
    protected void rewriteSetFusedDataPartitioningExecution(OptNode pn, double M, boolean flagLIX, HashMap<String, ParForProgramBlock.PartitionFormat> partitionedMatrices, LocalVariableMap vars, ParForProgramBlock.PExecMode emode) throws DMLRuntimeException {
        if (emode == ParForProgramBlock.PExecMode.REMOTE_MR_DP || emode == ParForProgramBlock.PExecMode.REMOTE_SPARK_DP) {
            ParForProgramBlock pfpb = (ParForProgramBlock)OptTreeConverter.getAbstractPlanMapping().getMappedProg(pn.getID())[1];
            // Fused execution needs at least one partitioned matrix.
            if (partitionedMatrices.size() <= 0) {
                LOG.debug(this.getOptMode() + " OPT: unable to force 'set fused data partitioning and execution' - result=" + false);
                return;
            }
            String moVarname = partitionedMatrices.keySet().iterator().next();
            ParForProgramBlock.PartitionFormat moDpf = partitionedMatrices.get(moVarname);
            MatrixObject mo = (MatrixObject)vars.get(moVarname);
            String iterVarname = pfpb.getIterablePredicateVars()[0];
            boolean applied = false;
            // Applicable only if the matrix is indexed by the loop variable and the
            // partitioning lines up with the number of iterations (_N).
            if (this.rIsAccessByIterationVariable(pn, moVarname, iterVarname)
                    && ((moDpf == ParForProgramBlock.PartitionFormat.ROW_WISE && mo.getNumRows() == this._N)
                        || (moDpf == ParForProgramBlock.PartitionFormat.COLUMN_WISE && mo.getNumColumns() == this._N)
                        || (moDpf._dpf == ParForProgramBlock.PDataPartitionFormat.ROW_BLOCK_WISE_N && mo.getNumRows() <= this._N * (long)moDpf._N)
                        || (moDpf._dpf == ParForProgramBlock.PDataPartitionFormat.COLUMN_BLOCK_WISE_N && mo.getNumColumns() <= this._N * (long)moDpf._N))) {
                int k = (int)Math.min(this._N, (long)this._rk2);
                if (emode == ParForProgramBlock.PExecMode.REMOTE_MR_DP) {
                    pn.addParam(OptNode.ParamType.DATA_PARTITIONER, "REMOTE_MR(fused)");
                    pfpb.setExecMode(ParForProgramBlock.PExecMode.REMOTE_MR_DP);
                } else {
                    pn.addParam(OptNode.ParamType.DATA_PARTITIONER, "REMOTE_SPARK(fused)");
                    pfpb.setExecMode(ParForProgramBlock.PExecMode.REMOTE_SPARK_DP);
                }
                pn.setK(k);
                // Partitioning happens inside the fused job, so no separate partitioner is set.
                pfpb.setDataPartitioner(ParForProgramBlock.PDataPartitioner.NONE);
                pfpb.enableColocatedPartitionedMatrix(moVarname);
                pfpb.setDegreeOfParallelism(k);
                applied = true;
            }
            LOG.debug(this.getOptMode() + " OPT: force 'set fused data partitioning and execution' - result=" + applied);
        } else {
            super.rewriteSetFusedDataPartitioningExecution(pn, M, flagLIX, partitionedMatrices, vars);
        }
    }

    /**
     * Reads the current execution mode of the program block mapped to the given node.
     *
     * @param pn node whose ID is used to look up the mapped program block
     * @return the execution mode reported by that program block
     */
    private ParForProgramBlock.PExecMode getPExecMode(OptNode pn) {
        Object[] mapped = OptTreeConverter.getAbstractPlanMapping().getMappedProg(pn.getID());
        ParForProgramBlock block = (ParForProgramBlock)mapped[1];
        return block.getExecMode();
    }
}

