/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.hops;

import java.util.ArrayList;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.MemoTable;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.lops.ConvolutionTransform;
import org.apache.sysml.lops.Lop;
import org.apache.sysml.lops.LopProperties;
import org.apache.sysml.lops.LopsException;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.ConvolutionParameters;
import org.apache.sysml.runtime.util.ConvolutionUtils;

public class ConvolutionOp
extends Hop
implements Hop.MultiThreadedHop {
    // Compile-time switches. Neither constant is referenced by any other line of this decompiled source
    // (NOTE(review): presumably their uses were folded away by the compiler - confirm against the original sources).
    private static final boolean INFER_TENSOR_SHAPE_FROM_PARENT_CONV_OP = true;
    private static final boolean THROW_ERROR_IF_INFERRED_SHAPE_MISMATCH = true;
    // The convolution-family operation this hop represents.
    private Hop.ConvOp op;
    // Thread limit for this hop; stays at -1 until setMaxNumThreads is called.
    private int _maxNumThreads = -1;
    // Cached shape/stride/padding parameters. Every entry starts at -1; the rest of the class treats
    // negative entries as "unknown".
    private ConvolutionParameters _cachedParams = new ConvolutionParameters(-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, this._maxNumThreads);

    /**
     * Creates an uninitialised hop. Within this class it is only used by {@link #clone()},
     * which fills in the attributes afterwards.
     */
    private ConvolutionOp() {
        super();
    }

    public ConvolutionOp(String l, Expression.DataType dt, Expression.ValueType vt, Hop.ConvOp o, ArrayList<Hop> inp) {
        super(l, dt, vt);
        this.op = o;
        for (int i = 0; i < inp.size(); ++i) {
            Hop in = inp.get(i);
            this.getInput().add(i, in);
            in.getParent().add(this);
        }
        this.refreshSizeInformation();
    }

    /**
     * Verifies that this hop has at least one input.
     *
     * @throws HopsException declared by the check helper; presumably raised when the condition fails
     */
    @Override
    public void checkArity() throws HopsException {
        int numInputs = this._input.size();
        boolean hasInput = numInputs >= 1;
        HopsException.check(hasInput, this, "should have at least one input but has %d inputs", numInputs);
    }

    /**
     * Returns the convolution-family operation represented by this hop.
     *
     * @return the current operation type
     */
    public Hop.ConvOp getOp() {
        return op;
    }

    /**
     * Returns the textual form of the entry that {@code HopsConv2Lops} holds for this hop's operation.
     *
     * @return the string form of the mapped value, or {@code "null"} if there is no mapping
     */
    @Override
    public String getOpString() {
        Object mapped = HopsConv2Lops.get((Object)this.op);
        return String.valueOf(mapped);
    }

    /**
     * Tells whether this operator may run on a remote backend. It always answers {@code false};
     * {@code optFindExecType} uses the answer to replace a remote execution type with CP.
     *
     * @return {@code false}
     */
    private static boolean isEligibleForSpark() {
        final boolean remoteExecutionSupported = false;
        return remoteExecutionSupported;
    }

    /**
     * Reports whether GPU execution is enabled for this hop, which mirrors the global
     * {@code DMLScript.USE_ACCELERATOR} flag.
     *
     * @return the current value of {@code DMLScript.USE_ACCELERATOR}
     */
    @Override
    public boolean isGPUEnabled() {
        boolean acceleratorRequested = DMLScript.USE_ACCELERATOR;
        return acceleratorRequested;
    }

    /**
     * Returns the lop for this hop, creating it on first use.
     *
     * @return the cached or newly constructed lop
     * @throws HopsException if the operation type is not supported or the selected execution type is
     *                       neither CP nor GPU
     * @throws LopsException declared for lop construction failures
     */
    @Override
    public Lop constructLops() throws HopsException, LopsException {
        Lop existing = this.getLops();
        if (existing != null) {
            return existing;
        }
        LopProperties.ExecType et = this.optFindExecType();
        ArrayList<Hop> inputs = this.getInput();
        // First make sure the operation is one we know how to lower at all.
        switch (this.op) {
            case MAX_POOLING:
            case MAX_POOLING_BACKWARD:
            case AVG_POOLING:
            case AVG_POOLING_BACKWARD:
            case DIRECT_CONV2D:
            case DIRECT_CONV2D_BACKWARD_DATA:
            case DIRECT_CONV2D_BACKWARD_FILTER:
            case BIAS_ADD:
            case BIAS_MULTIPLY:
                break;
            default:
                throw new HopsException("Unsupported lops construction for operation type '" + this.op + "'.");
        }
        // Only single-node CPU and GPU lops exist for these operations.
        boolean supportedBackend = et == LopProperties.ExecType.CP || et == LopProperties.ExecType.GPU;
        if (!supportedBackend) {
            throw new HopsException("Unimplemented ConvolutionOp for execution type: " + et.name());
        }
        this.setLops(this.constructConvolutionLops(et, inputs));
        this.constructAndSetLopsDataFlowProperties();
        return this.getLops();
    }

    /**
     * Replaces the operation represented by this hop. Cached sizes and parameters are not
     * refreshed by this call.
     *
     * @param op the new operation type
     */
    public void setOp(Hop.ConvOp op) {
        this.op = op;
    }

    /**
     * Returns how many input hops the current operation must have.
     *
     * @return 2 for the bias operations, 14 for conv2d, its backward variants and the backward
     *         pooling operations, 13 for everything else
     */
    private int getNumExpectedInputs() {
        switch (this.op) {
            case BIAS_ADD:
            case BIAS_MULTIPLY:
                return 2;
            case DIRECT_CONV2D:
            case DIRECT_CONV2D_BACKWARD_DATA:
            case DIRECT_CONV2D_BACKWARD_FILTER:
            case MAX_POOLING_BACKWARD:
            case AVG_POOLING_BACKWARD:
                return 14;
            default:
                return 13;
        }
    }

    /**
     * Detects the pattern {@code max(0, X)} / {@code max(X, 0)} and returns {@code X}.
     *
     * @param input hop to inspect
     * @return the non-zero operand of the max, or {@code null} if the hop is not such a max
     */
    private static Hop isInputReLU(Hop input) {
        if (!HopRewriteUtils.isBinary(input, Hop.OpOp2.MAX)) {
            return null;
        }
        ArrayList<Hop> operands = input.getInput();
        if (HopRewriteUtils.isLiteralOfValue(operands.get(0), 0.0)) {
            return operands.get(1);
        }
        return HopRewriteUtils.isLiteralOfValue(operands.get(1), 0.0) ? operands.get(0) : null;
    }

    /**
     * Checks whether the given hop is a forward conv2d operator.
     *
     * @param input hop to inspect
     * @return {@code true} if it is a ConvolutionOp whose operation is DIRECT_CONV2D
     */
    private static boolean isInputConv2d(Hop input) {
        if (!(input instanceof ConvolutionOp)) {
            return false;
        }
        ConvolutionOp candidate = (ConvolutionOp)input;
        return candidate.getOp() == Hop.ConvOp.DIRECT_CONV2D;
    }

    /**
     * Compares the stride, padding, window (R, S) and image (N, C, H, W) entries of two parameter sets.
     *
     * @param param1 first parameter set
     * @param param2 second parameter set
     * @return {@code true} only if every compared entry is non-negative in both sets and identical
     */
    private static boolean isPoolingParametersEqualAndKnown(ConvolutionParameters param1, ConvolutionParameters param2) {
        boolean sameWindow = ConvolutionOp.isEqualAndKnown(param1.stride_h, param2.stride_h)
            && ConvolutionOp.isEqualAndKnown(param1.stride_w, param2.stride_w)
            && ConvolutionOp.isEqualAndKnown(param1.pad_h, param2.pad_h)
            && ConvolutionOp.isEqualAndKnown(param1.pad_w, param2.pad_w)
            && ConvolutionOp.isEqualAndKnown(param1.R, param2.R)
            && ConvolutionOp.isEqualAndKnown(param1.S, param2.S);
        if (!sameWindow) {
            return false;
        }
        return ConvolutionOp.isEqualAndKnown(param1.N, param2.N)
            && ConvolutionOp.isEqualAndKnown(param1.C, param2.C)
            && ConvolutionOp.isEqualAndKnown(param1.H, param2.H)
            && ConvolutionOp.isEqualAndKnown(param1.W, param2.W);
    }

    /**
     * Tests whether two values are both known (non-negative) and identical.
     *
     * @param val1 first value; negative means unknown
     * @param val2 second value; negative means unknown
     * @return {@code true} if neither value is negative and they are equal
     */
    private static boolean isEqualAndKnown(int val1, int val2) {
        if (val1 < 0 || val2 < 0) {
            return false;
        }
        return val1 == val2;
    }

    /**
     * For a backward pooling hop, looks for the matching forward pooling hop on the same input image
     * and returns its lop.
     *
     * A sibling matches when it is a ConvolutionOp with the corresponding forward operation and its
     * cached parameters are known and equal to this hop's cached parameters.
     *
     * @return the lop of the first matching forward pooling hop, or {@code null} if this hop is not a
     *         backward pooling operation or no sibling matches
     * @throws HopsException declared by the sibling's lop construction
     * @throws LopsException declared by the sibling's lop construction
     */
    private Lop getMaxPoolOutputLop() throws HopsException, LopsException {
        Hop.ConvOp forwardOp;
        if (this.op == Hop.ConvOp.MAX_POOLING_BACKWARD) {
            forwardOp = Hop.ConvOp.MAX_POOLING;
        } else if (this.op == Hop.ConvOp.AVG_POOLING_BACKWARD) {
            forwardOp = Hop.ConvOp.AVG_POOLING;
        } else {
            return null;
        }
        Hop image = this.getInput().get(0);
        for (Hop consumer : image.getParent()) {
            if (consumer instanceof ConvolutionOp) {
                ConvolutionOp candidate = (ConvolutionOp)consumer;
                if (candidate.getOp() == forwardOp && ConvolutionOp.isPoolingParametersEqualAndKnown(candidate._cachedParams, this._cachedParams)) {
                    return candidate.constructLops();
                }
            }
        }
        return null;
    }

    /**
     * Builds the ConvolutionTransform lop for this hop, fusing it with a neighbouring operator where possible.
     *
     * The first input lop is passed to the ConvolutionTransform constructor. After that an optional
     * bias lop (fused conv2d + bias_add only), the remaining inputs in order, and - for GPU execution -
     * the forward pooling lop found by getMaxPoolOutputLop() are added as further inputs.
     *
     * @param et     execution type the lop is created for
     * @param inputs input hops of this operator; their count must equal getNumExpectedInputs()
     * @return the newly created lop, wired to all of its input lops
     * @throws HopsException if the number of inputs does not match the operation
     * @throws LopsException NOTE(review): presumably propagated from constructLops() of the inputs - confirm
     */
    public Lop constructConvolutionLops(LopProperties.ExecType et, ArrayList<Hop> inputs) throws HopsException, LopsException {
        if (inputs.size() != this.getNumExpectedInputs()) {
            throw new HopsException("Incorrect number of inputs for " + this.op.name());
        }
        Lop lhsInputLop = null;
        Lop optionalRhsInputLop = null;
        // Inputs 1..n of this list become the trailing lop inputs; it is swapped for the conv2d's
        // input list when bias_add is fused into the preceding conv2d.
        ArrayList<Hop> inputsOfPotentiallyFusedOp = inputs;
        ConvolutionTransform.OperationTypes lopOp = (ConvolutionTransform.OperationTypes)((Object)HopsConv2Lops.get((Object)this.op));
        // Non-null when the first input has the form max(X, 0); the value is X.
        Hop parentReLU = ConvolutionOp.isInputReLU(inputs.get(0));
        if (OptimizerUtils.ALLOW_OPERATOR_FUSION && et == LopProperties.ExecType.CP && this.op == Hop.ConvOp.MAX_POOLING && parentReLU != null) {
            // Fuse relu + max pooling (CP only): read directly from the relu's operand.
            lhsInputLop = parentReLU.constructLops();
            lopOp = ConvolutionTransform.OperationTypes.RELU_MAX_POOLING;
        } else if (OptimizerUtils.ALLOW_OPERATOR_FUSION && et == LopProperties.ExecType.CP && this.op == Hop.ConvOp.MAX_POOLING_BACKWARD && parentReLU != null) {
            // Same fusion for the backward max pooling operator.
            lhsInputLop = parentReLU.constructLops();
            lopOp = ConvolutionTransform.OperationTypes.RELU_MAX_POOLING_BACKWARD;
        } else if (OptimizerUtils.ALLOW_OPERATOR_FUSION && this.op == Hop.ConvOp.BIAS_ADD && ConvolutionOp.isInputConv2d(inputs.get(0))) {
            // Fuse conv2d + bias_add: the conv2d's first input is the lhs, the bias is the rhs and
            // the remaining inputs are taken from the conv2d.
            lopOp = ConvolutionTransform.OperationTypes.DIRECT_CONV2D_BIAS_ADD;
            lhsInputLop = inputs.get(0).getInput().get(0).constructLops();
            optionalRhsInputLop = inputs.get(1).constructLops();
            inputsOfPotentiallyFusedOp = inputs.get(0).getInput();
        } else {
            // No fusion: the first input is used as is.
            lhsInputLop = inputs.get(0).constructLops();
        }
        // The three arguments are ignored by computeIntermediateMemEstimate.
        double intermediateMemEstimate = this.computeIntermediateMemEstimate(-1L, -1L, -1L);
        if (et == LopProperties.ExecType.GPU && this._dim1 >= 0L && this._dim2 >= 0L) {
            // On GPU with known output dimensions, also consider what is left of the initial GPU
            // budget after the output, the first input and (if fused) the bias; keep the larger value.
            double optimisticIntermediateMemEstimate = (double)GPUContextPool.initialGPUMemBudget() - this.getOutputMemEstimate() - inputs.get(0).getOutputMemEstimate();
            if (optionalRhsInputLop != null) {
                optimisticIntermediateMemEstimate -= inputs.get(1).getOutputMemEstimate();
            }
            intermediateMemEstimate = Math.max(intermediateMemEstimate, optimisticIntermediateMemEstimate);
        }
        // Only looked up for GPU execution; null for CP.
        Lop optionalMaxPoolOutput = et == LopProperties.ExecType.GPU ? this.getMaxPoolOutputLop() : null;
        // Lops for every input except the first one.
        Lop[] l2inputs = new Lop[inputsOfPotentiallyFusedOp.size() - 1];
        for (int i = 1; i < inputsOfPotentiallyFusedOp.size(); ++i) {
            l2inputs[i - 1] = inputsOfPotentiallyFusedOp.get(i).constructLops();
        }
        ConvolutionTransform convolutionLop = new ConvolutionTransform(lhsInputLop, lopOp, this.getDataType(), this.getValueType(), et, OptimizerUtils.getConstrainedNumThreads(this._maxNumThreads), intermediateMemEstimate);
        this.setOutputDimensions(convolutionLop);
        this.setLineNumbers(convolutionLop);
        // Wire the data flow in both directions for every input lop.
        lhsInputLop.addOutput(convolutionLop);
        if (optionalRhsInputLop != null) {
            convolutionLop.addInput(optionalRhsInputLop);
            optionalRhsInputLop.addOutput(convolutionLop);
        }
        for (int i = 0; i < l2inputs.length; ++i) {
            convolutionLop.addInput(l2inputs[i]);
            l2inputs[i].addOutput(convolutionLop);
        }
        if (optionalMaxPoolOutput != null) {
            convolutionLop.addInput(optionalMaxPoolOutput);
            optionalMaxPoolOutput.addOutput(convolutionLop);
        }
        convolutionLop.updateLopProperties();
        return convolutionLop;
    }

    /**
     * Estimates the memory needed for the output.
     *
     * The output is assumed fully dense, except for bias_multiply without the accelerator flag,
     * where the sparsity of the first input is used instead.
     *
     * @param dim1 number of output rows
     * @param dim2 number of output columns
     * @param nnz  not used by this implementation
     * @return the size estimate produced by OptimizerUtils.estimateSizeExactSparsity
     */
    @Override
    protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) {
        boolean useInputSparsity = this.getOp() == Hop.ConvOp.BIAS_MULTIPLY && !DMLScript.USE_ACCELERATOR;
        double sparsity = useInputSparsity ? this.getInput().get(0).getSparsity() : 1.0;
        return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity);
    }

    /**
     * Combines the GPU and CP intermediate size lists into one memory estimate.
     *
     * The CP estimate is scaled by the number of workers, i.e. the constrained thread count capped
     * by the batch dimension N (at least 1). With the accelerator flag set, the result is the guarded
     * maximum of the CP and GPU estimates, where the CP estimate falls back to its single-worker
     * value if that one fits below the GPU estimate.
     *
     * @param gpuIntermediates intermediates assumed for GPU execution
     * @param cpIntermediates  intermediates assumed for CP execution
     * @return the combined intermediate memory estimate
     */
    private double computeIntermediateMemEstimateHelper(ArrayList<IntermediateDimensions> gpuIntermediates, ArrayList<IntermediateDimensions> cpIntermediates) {
        long threadCap = (long)OptimizerUtils.getConstrainedNumThreads(this._maxNumThreads);
        long batchRows = Math.max(this.getDim("N"), 1L);
        int numWorkers = (int)Math.min(threadCap, batchRows);
        if (!DMLScript.USE_ACCELERATOR) {
            return IntermediateDimensions.addEstimateSizes(cpIntermediates, numWorkers);
        }
        double gpuMemBudget = IntermediateDimensions.addEstimateSizes(gpuIntermediates, 1);
        double cpMemoryBudget = IntermediateDimensions.addEstimateSizes(cpIntermediates, numWorkers);
        if (cpMemoryBudget > gpuMemBudget) {
            double singleWorkerBudget = IntermediateDimensions.addEstimateSizes(cpIntermediates, 1);
            if (singleWorkerBudget <= gpuMemBudget) {
                cpMemoryBudget = singleWorkerBudget;
            }
        }
        return IntermediateDimensions.guardedMax(cpMemoryBudget, gpuMemBudget);
    }

    /**
     * Estimates the memory needed for intermediates of the current operation.
     *
     * For each operation a list of GPU intermediates and a list of CP intermediates is assembled and
     * handed to computeIntermediateMemEstimateHelper. Operations without any listed intermediate
     * (for example the bias operations) yield 0.
     *
     * @param ignoreDim1 ignored
     * @param ignoreDim2 ignored
     * @param ignoreNnz  ignored
     * @return the intermediate memory estimate, or 0.0 if no intermediates are listed
     */
    @Override
    protected double computeIntermediateMemEstimate(long ignoreDim1, long ignoreDim2, long ignoreNnz) {
        ArrayList<IntermediateDimensions> gpuIntermediates = new ArrayList<IntermediateDimensions>();
        ArrayList<IntermediateDimensions> cpIntermediates = new ArrayList<IntermediateDimensions>();
        Hop.ConvOp kind = this.getOp();
        if (kind != null) {
            switch (kind) {
                case DIRECT_CONV2D:
                    gpuIntermediates.add(new IntermediateDimensions(this, 1, "CHW"));
                    gpuIntermediates.add(new IntermediateDimensions(this, "K", "CRS"));
                    cpIntermediates.add(new IntermediateDimensions(this, "CRS", "PQ", this.getInput().get(0).getSparsity()));
                    break;
                case DIRECT_CONV2D_BACKWARD_DATA:
                    gpuIntermediates.add(new IntermediateDimensions(this, 1, "KPQ"));
                    gpuIntermediates.add(new IntermediateDimensions(this, "K", "CRS"));
                    cpIntermediates.add(new IntermediateDimensions(this, "PQ", "K", this.getInput().get(1).getSparsity()));
                    cpIntermediates.add(new IntermediateDimensions(this, "PQ", "CRS"));
                    break;
                case DIRECT_CONV2D_BACKWARD_FILTER:
                    gpuIntermediates.add(new IntermediateDimensions(this, 1, "CHW"));
                    gpuIntermediates.add(new IntermediateDimensions(this, 1, "KPQ"));
                    cpIntermediates.add(new IntermediateDimensions(this, "PQ", "K", this.getInput().get(1).getSparsity()));
                    cpIntermediates.add(new IntermediateDimensions(this, "CRS", "PQ", this.getInput().get(0).getSparsity()));
                    break;
                case MAX_POOLING:
                case AVG_POOLING:
                    gpuIntermediates.add(new IntermediateDimensions(this, 1, "CHW"));
                    break;
                case MAX_POOLING_BACKWARD:
                case AVG_POOLING_BACKWARD:
                    gpuIntermediates.add(new IntermediateDimensions(this, 1, "CHW"));
                    gpuIntermediates.add(new IntermediateDimensions(this, 1, "CPQ"));
                    break;
                default:
                    // No intermediates are listed for the remaining operations.
                    break;
            }
        }
        if (gpuIntermediates.isEmpty() && cpIntermediates.isEmpty()) {
            return 0.0;
        }
        return this.computeIntermediateMemEstimateHelper(gpuIntermediates, cpIntermediates);
    }

    /**
     * Infers {@code [rows, cols, nnz]} of the output.
     *
     * Bias operations take rows and columns from the memo statistics of the first input and need
     * both to be known (non-negative). All other operations refresh the size information of this
     * hop and need both dimensions to be strictly positive.
     *
     * @param memo memo table queried for the input statistics of bias operations
     * @return a three element array, or {@code null} if the dimensions are not sufficiently known
     */
    @Override
    protected long[] inferOutputCharacteristics(MemoTable memo) {
        boolean biasOperation = this.op == Hop.ConvOp.BIAS_ADD || this.op == Hop.ConvOp.BIAS_MULTIPLY;
        if (!biasOperation) {
            this.refreshSizeInformation();
            long[] dims = new long[]{this._dim1, this._dim2, this._nnz};
            if (dims[0] > 0L && dims[1] > 0L) {
                return dims;
            }
            return null;
        }
        MatrixCharacteristics firstInput = memo.getAllInputStats(this.getInput())[0];
        long rows = firstInput.rowsKnown() ? firstInput.getRows() : -1L;
        long cols = firstInput.colsKnown() ? firstInput.getCols() : -1L;
        if (rows < 0L || cols < 0L) {
            return null;
        }
        return new long[]{rows, cols, -1L};
    }

    /**
     * Declares that this hop does not restrict the execution types it can be assigned.
     *
     * @return always {@code true}
     */
    @Override
    public boolean allowsAllExecTypes() {
        final boolean unrestricted = true;
        return unrestricted;
    }

    /**
     * Selects and stores the execution type of this hop.
     *
     * A forced execution type wins. Otherwise the type is chosen by memory estimate when the
     * optimizer is memory based, or set to the remote backend (Spark or MR), followed by the CP
     * dimension/size validity check. Because isEligibleForSpark() answers false, a remote choice is
     * finally replaced by CP.
     *
     * @return the selected execution type
     * @throws HopsException declared by the helper calls
     */
    @Override
    protected LopProperties.ExecType optFindExecType() throws HopsException {
        this.checkAndSetForcedPlatform();
        LopProperties.ExecType remoteType = LopProperties.ExecType.MR;
        if (OptimizerUtils.isSparkExecutionMode()) {
            remoteType = LopProperties.ExecType.SPARK;
        }
        if (this._etypeForced == null) {
            if (OptimizerUtils.isMemoryBasedOptLevel()) {
                this._etype = this.findExecTypeByMemEstimate();
            } else {
                this._etype = remoteType;
            }
            this.checkAndSetInvalidCPDimsAndSize();
        } else {
            this._etype = this._etypeForced;
        }
        // Remote execution is not available for this operator, so fall back to CP.
        if (!ConvolutionOp.isEligibleForSpark() && this._etype == remoteType) {
            this._etype = LopProperties.ExecType.CP;
        }
        this.setRequiresRecompileIfNecessary();
        return this._etype;
    }

    /**
     * Fills the cached convolution parameters from the scalar input hops and derives what is missing.
     *
     * Steps: (1) pass the scalar inputs to {@code setIfUnknown}; (2) for forward pooling and conv2d
     * with unknown C/H/W/P/Q, try to infer them from the producing operator; (3) if the image-height
     * and filter-height hops are the same object, R is unknown and H is positive, set R to H;
     * (4) compute P and Q via ConvolutionUtils when all of their ingredients are known.
     *
     * @return the cached parameter object of this hop (not a copy)
     * @throws DMLRuntimeException declared by the inference from the parent operator
     */
    ConvolutionParameters parseInput() throws DMLRuntimeException {
        Hop.ConvOp kind = this.op;
        // These operators take 14 inputs instead of 13, which shifts every scalar parameter by one slot.
        boolean shiftedLayout = kind == Hop.ConvOp.MAX_POOLING_BACKWARD
            || kind == Hop.ConvOp.AVG_POOLING_BACKWARD
            || kind == Hop.ConvOp.DIRECT_CONV2D
            || kind == Hop.ConvOp.DIRECT_CONV2D_BACKWARD_FILTER
            || kind == Hop.ConvOp.DIRECT_CONV2D_BACKWARD_DATA;
        int shift = shiftedLayout ? 1 : 0;
        Hop imageHeightHop = this.getInput().get(7 + shift);
        Hop filterHeightHop = this.getInput().get(11 + shift);
        // NOTE(review): the argument order presumably is N, C, H, W, K, R, S, stride_h, stride_w,
        // pad_h, pad_w - confirm against ConvolutionParameters.setIfUnknown.
        this._cachedParams.setIfUnknown(
            this.getInput().get(5 + shift),
            this.getInput().get(6 + shift),
            imageHeightHop,
            this.getInput().get(8 + shift),
            this.getInput().get(9 + shift),
            filterHeightHop,
            this.getInput().get(12 + shift),
            this.getInput().get(1 + shift),
            this.getInput().get(2 + shift),
            this.getInput().get(3 + shift),
            this.getInput().get(4 + shift),
            this._maxNumThreads);

        boolean forwardPooling = this.getOp() == Hop.ConvOp.MAX_POOLING || this.getOp() == Hop.ConvOp.AVG_POOLING;
        boolean forwardConv = this.getOp() == Hop.ConvOp.DIRECT_CONV2D;
        boolean shapeIncomplete = this._cachedParams.C < 0
            || this._cachedParams.H < 0
            || this._cachedParams.W < 0
            || this._cachedParams.P < 0
            || this._cachedParams.Q < 0;
        if (shapeIncomplete && (forwardPooling || forwardConv)) {
            this.inferCHWPQFromParentOp();
        }

        // Same hop used for image height and filter height: R equals H once H is known.
        boolean sharedHeightHop = imageHeightHop == filterHeightHop;
        if (sharedHeightHop && this._cachedParams.R < 0 && this._cachedParams.H > 0) {
            this._cachedParams.R = this._cachedParams.H;
        }

        boolean canComputeP = this._cachedParams.H >= 0
            && this._cachedParams.R >= 0
            && this._cachedParams.stride_h >= 0
            && this._cachedParams.pad_h >= 0;
        if (this._cachedParams.P < 0 && canComputeP) {
            this._cachedParams.P = (int)ConvolutionUtils.getP(this._cachedParams.H, this._cachedParams.R, this._cachedParams.stride_h, this._cachedParams.pad_h);
        }

        boolean canComputeQ = this._cachedParams.W >= 0
            && this._cachedParams.S >= 0
            && this._cachedParams.stride_w >= 0
            && this._cachedParams.pad_w >= 0;
        if (this._cachedParams.Q < 0 && canComputeQ) {
            this._cachedParams.Q = (int)ConvolutionUtils.getQ(this._cachedParams.W, this._cachedParams.S, this._cachedParams.stride_w, this._cachedParams.pad_w);
        }
        return this._cachedParams;
    }

    /**
     * Checks whether the given hop is a bias_add operator.
     *
     * @param hop hop to inspect
     * @return {@code true} if it is a ConvolutionOp whose operation is BIAS_ADD
     */
    private static boolean isInputBiasAdd(Hop hop) {
        if (!(hop instanceof ConvolutionOp)) {
            return false;
        }
        ConvolutionOp candidate = (ConvolutionOp)hop;
        return candidate.getOp() == Hop.ConvOp.BIAS_ADD;
    }

    /**
     * Fails if two known values of the same parameter disagree. Negative values count as unknown
     * and never cause a failure.
     *
     * @param dim1      first value
     * @param dim2      second value
     * @param paramType parameter name used in the error message
     * @throws DMLRuntimeException if both values are non-negative and differ
     */
    private void throwExceptionIfNotEqual(int dim1, int dim2, String paramType) throws DMLRuntimeException {
        boolean bothKnown = dim1 >= 0 && dim2 >= 0;
        if (!bothKnown || dim1 == dim2) {
            return;
        }
        throw new DMLRuntimeException("Inferred " + paramType + " from parent doesn't match with given " + paramType + ":" + dim1 + " != " + dim2);
    }

    /**
     * Fills unknown C, H and W of this hop from the operator that produces its first input.
     *
     * An intermediate bias_add and an intermediate relu ({@code max(X, 0)}) are skipped. If the
     * producer is a forward pooling operator, its [C, P, Q] become this hop's [C, H, W]; if it is a
     * conv2d, its [K, P, Q] are used. Values that are already known are never overwritten.
     *
     * Changes compared with the previous version:
     * - The mismatch check now compares the given values against the values derived from the
     *   producer. It used to compare each given value with itself, so it could never fire.
     * - The debug message of the conv2d branch no longer claims a "maxpool" parent.
     *
     * @throws DMLRuntimeException if a known C, H or W disagrees with the known value derived from
     *                             the producer
     */
    private void inferCHWPQFromParentOp() throws DMLRuntimeException {
        Hop producer = this.getInput().get(0);
        if (ConvolutionOp.isInputBiasAdd(producer)) {
            producer = producer.getInput().get(0);
        }
        Hop reluOperand = ConvolutionOp.isInputReLU(producer);
        if (reluOperand != null) {
            producer = reluOperand;
        }
        if (!(producer instanceof ConvolutionOp)) {
            return;
        }
        ConvolutionOp parentOp = (ConvolutionOp)producer;
        Hop.ConvOp parentKind = parentOp.getOp();
        boolean poolingParent = parentKind == Hop.ConvOp.MAX_POOLING || parentKind == Hop.ConvOp.AVG_POOLING;
        boolean conv2dParent = parentKind == Hop.ConvOp.DIRECT_CONV2D;
        if (!poolingParent && !conv2dParent) {
            return;
        }

        ConvolutionParameters parentParam = parentOp.parseInput();
        // Pooling keeps the channel count (C); conv2d emits one channel per filter (K).
        int inferredC = poolingParent ? parentParam.C : parentParam.K;
        int inferredH = parentParam.P;
        int inferredW = parentParam.Q;

        int prevC = this._cachedParams.C;
        int prevH = this._cachedParams.H;
        int prevW = this._cachedParams.W;
        if (prevC < 0) {
            this._cachedParams.C = inferredC;
        }
        if (prevH < 0) {
            this._cachedParams.H = inferredH;
        }
        if (prevW < 0) {
            this._cachedParams.W = inferredW;
        }
        if (LOG.isDebugEnabled()) {
            String parentName = poolingParent ? "maxpool" : "conv2d";
            LOG.debug((Object)("Inferring [C,H,W] from " + parentName + " parent: [" + prevC + "," + prevH + "," + prevW + "]-> [" + this._cachedParams.C + "," + this._cachedParams.H + "," + this._cachedParams.W + "]"));
        }
        if (THROW_ERROR_IF_INFERRED_SHAPE_MISMATCH) {
            // Unknown (negative) values on either side are tolerated by the helper.
            this.throwExceptionIfNotEqual(inferredC, prevC, "C");
            this.throwExceptionIfNotEqual(inferredH, prevH, "H");
            this.throwExceptionIfNotEqual(inferredW, prevW, "W");
        }
    }

    /**
     * Recomputes the output dimensions of this hop.
     *
     * Bias operations copy the dimensions of their first input. All other operations reset the
     * cached parameters and derive rows/columns via getDim: forward pooling yields N x CPQ, conv2d
     * N x KPQ, backward pooling and conv2d-backward-data N x CHW, and conv2d-backward-filter
     * K x CRS. The number of non-zeros is always set to unknown (-1).
     *
     * @throws RuntimeException for any other operation type
     */
    @Override
    public void refreshSizeInformation() {
        boolean biasOperation = this.op == Hop.ConvOp.BIAS_ADD || this.op == Hop.ConvOp.BIAS_MULTIPLY;
        if (biasOperation) {
            Hop first = this.getInput().get(0);
            this.setDim1(first.getDim1());
            this.setDim2(first.getDim2());
            this._nnz = -1L;
            return;
        }
        // Start from a clean slate so stale values cannot survive a change of the inputs.
        this._cachedParams = new ConvolutionParameters(-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, this._maxNumThreads);
        final String rowDim;
        final String colDim;
        switch (this.op) {
            case MAX_POOLING:
            case AVG_POOLING:
                rowDim = "N";
                colDim = "CPQ";
                break;
            case DIRECT_CONV2D:
                rowDim = "N";
                colDim = "KPQ";
                break;
            case MAX_POOLING_BACKWARD:
            case AVG_POOLING_BACKWARD:
            case DIRECT_CONV2D_BACKWARD_DATA:
                rowDim = "N";
                colDim = "CHW";
                break;
            case DIRECT_CONV2D_BACKWARD_FILTER:
                rowDim = "K";
                colDim = "CRS";
                break;
            default:
                throw new RuntimeException("The sizes are not refreshed for " + this.op.name());
        }
        this._dim1 = this.getDim(rowDim);
        this._dim2 = this.getDim(colDim);
        this._nnz = -1L;
    }

    /**
     * Creates a copy of this hop: the generic Hop attributes via {@code clone(this, false)}, then
     * the operation type and the thread limit. The cached parameters are not copied.
     *
     * @return the new ConvolutionOp
     * @throws CloneNotSupportedException declared by the generic attribute copy
     */
    @Override
    public Object clone() throws CloneNotSupportedException {
        ConvolutionOp copy = new ConvolutionOp();
        // Generic attributes first, operator specific ones afterwards.
        copy.clone(this, false);
        copy._maxNumThreads = this._maxNumThreads;
        copy.op = this.op;
        return copy;
    }

    /**
     * Compares this hop with another one.
     *
     * Two hops match if the other one is a ConvolutionOp with the same operation, the same number
     * of inputs, the same thread limit and the very same input hop objects (reference equality) at
     * every position.
     *
     * @param that hop to compare with
     * @return {@code true} if all of the above holds
     */
    @Override
    public boolean compare(Hop that) {
        if (!(that instanceof ConvolutionOp)) {
            return false;
        }
        ConvolutionOp other = (ConvolutionOp)that;
        if (this.op != other.op) {
            return false;
        }
        if (this.getInput().size() != that.getInput().size()) {
            return false;
        }
        if (this._maxNumThreads != other._maxNumThreads) {
            return false;
        }
        for (int pos = 0; pos < this._input.size(); ++pos) {
            if (this.getInput().get(pos) != other.getInput().get(pos)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Sets the thread limit of this hop. The already cached parameter object is not updated.
     *
     * @param k new thread limit
     */
    @Override
    public void setMaxNumThreads(int k) {
        _maxNumThreads = k;
    }

    /**
     * Returns the thread limit of this hop.
     *
     * @return the stored limit; -1 if it was never set
     */
    @Override
    public int getMaxNumThreads() {
        return _maxNumThreads;
    }

    /**
     * Resolves one symbolic dimension of this operator.
     *
     * Supported names are "K", "CRS", "N", "CHW", "KPQ", "PQ" and "CPQ". The value is computed from
     * the cached parameters (refreshed via parseInput()) and, where a suitable matrix input exists,
     * cross-checked against that input's row or column count by getNonNegative.
     *
     * @param dimString symbolic dimension name
     * @return the dimension, or -1 if it cannot be determined
     * @throws RuntimeException for bias operations, for unsupported names, when parseInput() fails,
     *                          or when two known sources of the same dimension disagree
     */
    private long getDim(String dimString) {
        if (this.op == Hop.ConvOp.BIAS_ADD || this.op == Hop.ConvOp.BIAS_MULTIPLY) {
            throw new RuntimeException("getDim method should not be invoked for bias_add and bias_multiply");
        }
        try {
            this.parseInput();
        }
        catch (DMLRuntimeException e) {
            // Rethrown unchecked, keeping the cause.
            throw new RuntimeException(e);
        }
        // Matrix inputs by role; a role that the current operation does not have stays null.
        Hop filter = null;
        Hop input = null;
        Hop dout = null;
        Hop dout1 = null;
        if (this.getOp() == Hop.ConvOp.DIRECT_CONV2D) {
            input = this.getInput().get(0);
            filter = this.getInput().get(1);
        } else if (this.getOp() == Hop.ConvOp.DIRECT_CONV2D_BACKWARD_DATA) {
            filter = this.getInput().get(0);
            dout = this.getInput().get(1);
        } else if (this.getOp() == Hop.ConvOp.DIRECT_CONV2D_BACKWARD_FILTER) {
            input = this.getInput().get(0);
            dout = this.getInput().get(1);
        } else if (this.getOp() == Hop.ConvOp.MAX_POOLING || this.getOp() == Hop.ConvOp.AVG_POOLING) {
            input = this.getInput().get(0);
        } else if (this.getOp() == Hop.ConvOp.MAX_POOLING_BACKWARD || this.getOp() == Hop.ConvOp.AVG_POOLING_BACKWARD) {
            input = this.getInput().get(0);
            dout1 = this.getInput().get(1);
        }
        long ret = -1L;
        // First group: a matrix input with the needed role exists, so the parameter-based value is
        // reconciled with that input's dimension. The branch order decides which input is consulted.
        if (dimString.equals("K") && filter != null) {
            ret = ConvolutionOp.getNonNegative(ret, ConvolutionOp.getNonNegative(this._cachedParams.K, filter._dim1));
        } else if (dimString.equals("CRS") && filter != null) {
            ret = ConvolutionOp.getNonNegative(ret, ConvolutionOp.getNonNegative(ConvolutionOp.nonNegativeMultiply(this._cachedParams.C, this._cachedParams.R, this._cachedParams.S), filter._dim2));
        } else if (dimString.equals("N") && input != null) {
            ret = ConvolutionOp.getNonNegative(ret, ConvolutionOp.getNonNegative(this._cachedParams.N, input._dim1));
        } else if (dimString.equals("CHW") && input != null) {
            ret = ConvolutionOp.getNonNegative(ret, ConvolutionOp.getNonNegative(ConvolutionOp.nonNegativeMultiply(this._cachedParams.C, this._cachedParams.H, this._cachedParams.W), input._dim2));
        } else if (dimString.equals("N") && dout != null) {
            ret = ConvolutionOp.getNonNegative(ret, ConvolutionOp.getNonNegative(this._cachedParams.N, dout._dim1));
        } else if (dimString.equals("KPQ") && dout != null) {
            ret = ConvolutionOp.getNonNegative(ret, ConvolutionOp.getNonNegative(ConvolutionOp.nonNegativeMultiply(this._cachedParams.K, this._cachedParams.P, this._cachedParams.Q), dout._dim2));
        } else if (dimString.equals("N") && dout1 != null) {
            ret = ConvolutionOp.getNonNegative(ret, ConvolutionOp.getNonNegative(this._cachedParams.N, dout1._dim1));
        } else if (dimString.equals("CPQ") && dout1 != null) {
            ret = ConvolutionOp.getNonNegative(ret, ConvolutionOp.getNonNegative(ConvolutionOp.nonNegativeMultiply(this._cachedParams.C, this._cachedParams.P, this._cachedParams.Q), dout1._dim2));
        // Second group: no matching matrix input, so only the cached parameters are used.
        } else if (dimString.equals("K")) {
            ret = ConvolutionOp.getNonNegative(ret, this._cachedParams.K >= 0 ? (long)this._cachedParams.K : -1L);
        } else if (dimString.equals("CRS")) {
            ret = ConvolutionOp.getNonNegative(ret, ConvolutionOp.nonNegativeMultiply(this._cachedParams.C, this._cachedParams.R, this._cachedParams.S));
        } else if (dimString.equals("N")) {
            ret = ConvolutionOp.getNonNegative(ret, this._cachedParams.N >= 0 ? (long)this._cachedParams.N : -1L);
        } else if (dimString.equals("CHW")) {
            ret = ConvolutionOp.getNonNegative(ret, ConvolutionOp.nonNegativeMultiply(this._cachedParams.C, this._cachedParams.H, this._cachedParams.W));
        } else if (dimString.equals("KPQ")) {
            ret = ConvolutionOp.getNonNegative(ret, ConvolutionOp.nonNegativeMultiply(this._cachedParams.K, this._cachedParams.P, this._cachedParams.Q));
        } else if (dimString.equals("PQ")) {
            ret = ConvolutionOp.getNonNegative(ret, ConvolutionOp.nonNegativeMultiply(this._cachedParams.P, this._cachedParams.Q));
        } else if (dimString.equals("CPQ")) {
            ret = ConvolutionOp.getNonNegative(ret, ConvolutionOp.nonNegativeMultiply(this._cachedParams.C, this._cachedParams.P, this._cachedParams.Q));
        } else {
            throw new RuntimeException("Unsupported dimension:" + dimString + " for operator " + this.getOp().name());
        }
        // Dump all cached parameters when the requested dimension could not be resolved.
        if (LOG.isDebugEnabled() && ret < 0L) {
            LOG.debug((Object)("Unknown dimension " + dimString + " for ConvolutionOp:" + this.op.name() + " img_dim=[" + this._cachedParams.N + " " + this._cachedParams.C + " " + this._cachedParams.H + " " + this._cachedParams.W + "] filter_dim=[" + this._cachedParams.K + " " + this._cachedParams.C + " " + this._cachedParams.R + " " + this._cachedParams.S + "] output_feature_map=[" + this._cachedParams.P + " " + this._cachedParams.Q + "] stride=[" + this._cachedParams.stride_h + " " + this._cachedParams.stride_w + "] pad=[" + this._cachedParams.pad_h + " " + this._cachedParams.pad_w + "]"));
        }
        return ret;
    }

    /**
     * Multiplies three values, treating negative values as unknown.
     *
     * @param val1 first factor
     * @param val2 second factor
     * @param val3 third factor
     * @return the product, or -1 if any factor is negative
     */
    private static long nonNegativeMultiply(long val1, long val2, long val3) {
        boolean anyUnknown = val1 < 0L || val2 < 0L || val3 < 0L;
        if (anyUnknown) {
            return -1L;
        }
        return val1 * val2 * val3;
    }

    /**
     * Multiplies two values, treating negative values as unknown.
     *
     * @param val1 first factor
     * @param val2 second factor
     * @return the product, or -1 if either factor is negative
     */
    private static long nonNegativeMultiply(long val1, long val2) {
        boolean anyUnknown = val1 < 0L || val2 < 0L;
        if (anyUnknown) {
            return -1L;
        }
        return val1 * val2;
    }

    /**
     * Reconciles two candidate values for the same dimension, where negative means unknown.
     *
     * @param val1 first candidate
     * @param val2 second candidate
     * @return the known candidate, the common value if both are known, or -1 if neither is known
     * @throws RuntimeException if both candidates are known but differ
     */
    private static long getNonNegative(long val1, long val2) {
        boolean firstKnown = val1 >= 0L;
        boolean secondKnown = val2 >= 0L;
        if (firstKnown && secondKnown) {
            if (val1 != val2) {
                throw new RuntimeException("Incorrect dimensions in Convolution Hop: " + val1 + " != " + val2);
            }
            return val1;
        }
        if (firstKnown) {
            return val1;
        }
        return secondKnown ? val2 : -1L;
    }

    /**
     * Shape and sparsity of one intermediate matrix used for memory estimation.
     *
     * Dimensions are stored as {@code long}. They come from {@code ConvolutionOp.getDim}, which
     * returns {@code long}; the previous {@code int} fields silently truncated large products such
     * as K*P*Q. A negative dimension means "unknown".
     */
    private static class IntermediateDimensions {
        long dim1;
        long dim2;
        double sp;

        /**
         * @param h       operator whose symbolic dimensions are resolved
         * @param dim1Str symbolic name of the row dimension
         * @param dim2Str symbolic name of the column dimension
         * @param sp      sparsity assumed for the intermediate
         */
        public IntermediateDimensions(ConvolutionOp h, String dim1Str, String dim2Str, double sp) {
            this.dim1 = h.getDim(dim1Str);
            this.dim2 = h.getDim(dim2Str);
            this.sp = sp;
        }

        /**
         * Dense intermediate (sparsity 1.0) with both dimensions given symbolically.
         *
         * @param h       operator whose symbolic dimensions are resolved
         * @param dim1Str symbolic name of the row dimension
         * @param dim2Str symbolic name of the column dimension
         */
        public IntermediateDimensions(ConvolutionOp h, String dim1Str, String dim2Str) {
            this(h, dim1Str, dim2Str, 1.0);
        }

        /**
         * Dense intermediate (sparsity 1.0) with a fixed number of rows.
         *
         * @param h       operator whose symbolic dimension is resolved
         * @param dim1    number of rows
         * @param dim2Str symbolic name of the column dimension
         */
        public IntermediateDimensions(ConvolutionOp h, int dim1, String dim2Str) {
            this.dim1 = dim1;
            this.dim2 = h.getDim(dim2Str);
            this.sp = 1.0;
        }

        /**
         * Adds two estimates; a negative operand or a sum at or above
         * {@code OptimizerUtils.DEFAULT_SIZE} yields {@code OptimizerUtils.DEFAULT_SIZE}.
         */
        static double guardedAdd(double val1, double val2) {
            if (val1 < 0.0 || val2 < 0.0) {
                return OptimizerUtils.DEFAULT_SIZE;
            }
            double ret = val1 + val2;
            if (ret >= OptimizerUtils.DEFAULT_SIZE) {
                return OptimizerUtils.DEFAULT_SIZE;
            }
            return ret;
        }

        /**
         * Sums the size estimates of all intermediates, each multiplied by the number of workers.
         *
         * @param intermediates intermediates to account for
         * @param numWorkers    number of workers that each hold their own copy
         * @return the guarded sum of the per-intermediate estimates
         */
        public static double addEstimateSizes(ArrayList<IntermediateDimensions> intermediates, int numWorkers) {
            double memBudget = 0.0;
            for (IntermediateDimensions dims : intermediates) {
                // Multiply in double so a large estimate cannot wrap around if the estimate is integral.
                double perWorker = (double)OptimizerUtils.estimateSizeExactSparsity(dims.dim1, dims.dim2, dims.sp);
                memBudget = IntermediateDimensions.guardedAdd(memBudget, perWorker * numWorkers);
            }
            return memBudget;
        }

        /**
         * Returns the larger estimate; a negative operand or a maximum at or above
         * {@code OptimizerUtils.DEFAULT_SIZE} yields {@code OptimizerUtils.DEFAULT_SIZE}.
         */
        public static double guardedMax(double val1, double val2) {
            if (val1 < 0.0 || val2 < 0.0) {
                return OptimizerUtils.DEFAULT_SIZE;
            }
            double ret = Math.max(val1, val2);
            if (ret >= OptimizerUtils.DEFAULT_SIZE) {
                return OptimizerUtils.DEFAULT_SIZE;
            }
            return ret;
        }
    }
}

