/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.hops;

import java.util.ArrayList;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.hops.AggBinaryOp;
import org.apache.sysml.hops.DataGenOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.IndexingOp;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.hops.MemoTable;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.hops.TernaryOp;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.lops.Aggregate;
import org.apache.sysml.lops.Group;
import org.apache.sysml.lops.Lop;
import org.apache.sysml.lops.LopProperties;
import org.apache.sysml.lops.LopsException;
import org.apache.sysml.lops.SortKeys;
import org.apache.sysml.lops.Transform;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;

public class ReorgOp
extends Hop
implements Hop.MultiThreadedHop {
    public static boolean FORCE_DIST_SORT_INDEXES = false;
    public boolean bSortSPRewriteApplicable = false;
    private Hop.ReOrgOp op;
    private int _maxNumThreads = -1;

    private ReorgOp() {
    }

    public ReorgOp(String l, Expression.DataType dt, Expression.ValueType vt, Hop.ReOrgOp o, Hop inp) {
        super(l, dt, vt);
        this.op = o;
        this.getInput().add(0, inp);
        inp.getParent().add(this);
        this.refreshSizeInformation();
    }

    public ReorgOp(String l, Expression.DataType dt, Expression.ValueType vt, Hop.ReOrgOp o, ArrayList<Hop> inp) {
        super(l, dt, vt);
        this.op = o;
        for (int i = 0; i < inp.size(); ++i) {
            Hop in = inp.get(i);
            this.getInput().add(i, in);
            in.getParent().add(this);
        }
        this.refreshSizeInformation();
    }

    @Override
    public void checkArity() throws HopsException {
        int sz = this._input.size();
        switch (this.op) {
            case TRANSPOSE: 
            case DIAG: 
            case REV: {
                HopsException.check(sz == 1, this, "should have arity 1 for op %s but has arity %d", new Object[]{this.op, sz});
                break;
            }
            case RESHAPE: 
            case SORT: {
                HopsException.check(sz == 4, this, "should have arity 4 for op %s but has arity %d", new Object[]{this.op, sz});
                break;
            }
            default: {
                throw new HopsException("Unsupported lops construction for operation type '" + (Object)((Object)this.op) + "'.");
            }
        }
    }

    @Override
    public void setMaxNumThreads(int k) {
        this._maxNumThreads = k;
    }

    @Override
    public int getMaxNumThreads() {
        return this._maxNumThreads;
    }

    public Hop.ReOrgOp getOp() {
        return this.op;
    }

    @Override
    public String getOpString() {
        String s = new String("");
        s = s + "r(" + (String)HopsTransf2String.get((Object)this.op) + ")";
        return s;
    }

    @Override
    public boolean isGPUEnabled() {
        if (!DMLScript.USE_ACCELERATOR) {
            return false;
        }
        switch (this.op) {
            case TRANSPOSE: {
                Lop lin;
                try {
                    lin = this.getInput().get(0).constructLops();
                }
                catch (HopsException | LopsException e) {
                    throw new RuntimeException("Unable to create child lop", e);
                }
                if (lin instanceof Transform && ((Transform)lin).getOperationType() == Transform.OperationTypes.Transpose) {
                    return false;
                }
                return this.getDim1() != 1L || this.getDim2() != 1L;
            }
            case DIAG: 
            case REV: 
            case RESHAPE: 
            case SORT: {
                return false;
            }
        }
        throw new RuntimeException("Unsupported operator:" + this.op.name());
    }

    @Override
    public Lop constructLops() throws HopsException, LopsException {
        if (this.getLops() != null) {
            return this.getLops();
        }
        LopProperties.ExecType et = this.optFindExecType();
        switch (this.op) {
            case TRANSPOSE: {
                Lop lin = this.getInput().get(0).constructLops();
                if (lin instanceof Transform && ((Transform)lin).getOperationType() == Transform.OperationTypes.Transpose) {
                    this.setLops(lin.getInputs().get(0));
                    break;
                }
                if (this.getDim1() == 1L && this.getDim2() == 1L) {
                    this.setLops(lin);
                    break;
                }
                int k = OptimizerUtils.getConstrainedNumThreads(this._maxNumThreads);
                Transform transform1 = new Transform(lin, (Transform.OperationTypes)((Object)HopsTransf2Lops.get((Object)this.op)), this.getDataType(), this.getValueType(), et, k);
                this.setOutputDimensions(transform1);
                this.setLineNumbers(transform1);
                this.setLops(transform1);
                break;
            }
            case DIAG: {
                Transform transform1 = new Transform(this.getInput().get(0).constructLops(), (Transform.OperationTypes)((Object)HopsTransf2Lops.get((Object)this.op)), this.getDataType(), this.getValueType(), et);
                this.setOutputDimensions(transform1);
                this.setLineNumbers(transform1);
                this.setLops(transform1);
                break;
            }
            case REV: {
                Lop rev = null;
                if (et == LopProperties.ExecType.MR) {
                    Transform tmp = new Transform(this.getInput().get(0).constructLops(), (Transform.OperationTypes)((Object)HopsTransf2Lops.get((Object)this.op)), this.getDataType(), this.getValueType(), et);
                    this.setOutputDimensions(tmp);
                    this.setLineNumbers(tmp);
                    Group group1 = new Group(tmp, Group.OperationTypes.Sort, Expression.DataType.MATRIX, this.getValueType());
                    this.setOutputDimensions(group1);
                    this.setLineNumbers(group1);
                    rev = new Aggregate(group1, Aggregate.OperationTypes.Sum, Expression.DataType.MATRIX, this.getValueType(), et);
                } else {
                    rev = new Transform(this.getInput().get(0).constructLops(), (Transform.OperationTypes)((Object)HopsTransf2Lops.get((Object)this.op)), this.getDataType(), this.getValueType(), et);
                }
                this.setOutputDimensions(rev);
                this.setLineNumbers(rev);
                this.setLops(rev);
                break;
            }
            case RESHAPE: {
                if (et == LopProperties.ExecType.MR) {
                    Transform transform1 = new Transform(this.getInput().get(0).constructLops(), (Transform.OperationTypes)((Object)HopsTransf2Lops.get((Object)this.op)), this.getDataType(), this.getValueType(), et);
                    this.setOutputDimensions(transform1);
                    this.setLineNumbers(transform1);
                    for (int i = 1; i <= 3; ++i) {
                        Lop ltmp = this.getInput().get(i).constructLops();
                        transform1.addInput(ltmp);
                        ltmp.addOutput(transform1);
                    }
                    transform1.setLevel();
                    Group group1 = new Group(transform1, Group.OperationTypes.Sort, Expression.DataType.MATRIX, this.getValueType());
                    this.setOutputDimensions(group1);
                    this.setLineNumbers(group1);
                    Aggregate agg1 = new Aggregate(group1, Aggregate.OperationTypes.Sum, Expression.DataType.MATRIX, this.getValueType(), et);
                    this.setOutputDimensions(agg1);
                    this.setLineNumbers(agg1);
                    this.setLops(agg1);
                    break;
                }
                Transform transform1 = new Transform(this.getInput().get(0).constructLops(), (Transform.OperationTypes)((Object)HopsTransf2Lops.get((Object)this.op)), this.getDataType(), this.getValueType(), et);
                this.setOutputDimensions(transform1);
                this.setLineNumbers(transform1);
                for (int i = 1; i <= 3; ++i) {
                    Lop ltmp = this.getInput().get(i).constructLops();
                    transform1.addInput(ltmp);
                    ltmp.addOutput(transform1);
                }
                transform1.setLevel();
                this.setLops(transform1);
                break;
            }
            case SORT: {
                Hop input = this.getInput().get(0);
                Hop by = this.getInput().get(1);
                Hop desc = this.getInput().get(2);
                Hop ixret = this.getInput().get(3);
                if (et == LopProperties.ExecType.MR) {
                    if (!(desc instanceof LiteralOp) || !(ixret instanceof LiteralOp)) {
                        LOG.warn((Object)"Unsupported non-constant ordering parameters, using defaults and mark for recompilation.");
                        this.setRequiresRecompile();
                        desc = new LiteralOp(false);
                        ixret = new LiteralOp(false);
                    }
                    Hop vinput = input;
                    if (input.getDim2() != 1L) {
                        vinput = new IndexingOp("tmp1", this.getDataType(), this.getValueType(), input, new LiteralOp(1L), HopRewriteUtils.createValueHop(input, true), by, by, false, true);
                        vinput.refreshSizeInformation();
                        vinput.setOutputBlocksizes(this.getRowsInBlock(), this.getColsInBlock());
                        HopRewriteUtils.copyLineNumbers(this, vinput);
                    }
                    ReorgOp voutput = null;
                    if ((double)(2L * OptimizerUtils.estimateSize(vinput.getDim1(), vinput.getDim2())) > OptimizerUtils.getLocalMemBudget() || FORCE_DIST_SORT_INDEXES) {
                        SortKeys sort = new SortKeys(vinput.constructLops(), HopRewriteUtils.getBooleanValueSafe((LiteralOp)desc), SortKeys.OperationTypes.Indexes, vinput.getDataType(), vinput.getValueType(), LopProperties.ExecType.MR);
                        sort.getOutputParameters().setDimensions(vinput.getDim1(), 1L, vinput.getRowsInBlock(), vinput.getColsInBlock(), vinput.getNnz());
                        this.setLineNumbers(sort);
                        this.setLops(sort);
                        voutput = this;
                    } else {
                        ArrayList<Hop> sinputs = new ArrayList<Hop>();
                        sinputs.add(vinput);
                        sinputs.add(new LiteralOp(1L));
                        sinputs.add(desc);
                        sinputs.add(new LiteralOp(true));
                        voutput = new ReorgOp("tmp3", this.getDataType(), this.getValueType(), Hop.ReOrgOp.SORT, sinputs);
                        HopRewriteUtils.copyLineNumbers(this, voutput);
                        voutput.setLops(ReorgOp.constructCPOrSparkSortLop(vinput, sinputs.get(1), sinputs.get(2), sinputs.get(3), LopProperties.ExecType.CP, false));
                        voutput.getLops().getOutputParameters().setDimensions(vinput.getDim1(), vinput.getDim2(), vinput.getRowsInBlock(), vinput.getColsInBlock(), vinput.getNnz());
                        this.setLops(((Hop)voutput).constructLops());
                    }
                    if (HopRewriteUtils.getBooleanValueSafe((LiteralOp)ixret)) break;
                    DataGenOp seq = HopRewriteUtils.createSeqDataGenOp(voutput);
                    seq.setName("tmp4");
                    seq.refreshSizeInformation();
                    seq.computeMemEstimate(new MemoTable());
                    HopRewriteUtils.copyLineNumbers(this, seq);
                    TernaryOp table = new TernaryOp("tmp5", Expression.DataType.MATRIX, Expression.ValueType.DOUBLE, Hop.OpOp3.CTABLE, seq, voutput, new LiteralOp(1L));
                    table.setOutputBlocksizes(this.getRowsInBlock(), this.getColsInBlock());
                    table.refreshSizeInformation();
                    table.setForcedExecType(LopProperties.ExecType.MR);
                    HopRewriteUtils.copyLineNumbers(this, table);
                    table.setDisjointInputs(true);
                    table.setOutputEmptyBlocks(false);
                    AggBinaryOp mmult = HopRewriteUtils.createMatrixMultiply(table, input);
                    mmult.setForcedExecType(LopProperties.ExecType.MR);
                    this.setLops(mmult.constructLops());
                    HopRewriteUtils.removeChildReference(table, input);
                    break;
                }
                if (et == LopProperties.ExecType.SPARK && !FORCE_DIST_SORT_INDEXES) {
                    this.bSortSPRewriteApplicable = this.isSortSPRewriteApplicable();
                }
                Lop transform1 = ReorgOp.constructCPOrSparkSortLop(input, by, desc, ixret, et, this.bSortSPRewriteApplicable);
                this.setOutputDimensions(transform1);
                this.setLineNumbers(transform1);
                this.setLops(transform1);
                break;
            }
            default: {
                throw new HopsException("Unsupported lops construction for operation type '" + (Object)((Object)this.op) + "'.");
            }
        }
        this.constructAndSetLopsDataFlowProperties();
        return this.getLops();
    }

    private static Lop constructCPOrSparkSortLop(Hop input, Hop by, Hop desc, Hop ixret, LopProperties.ExecType et, boolean bSortIndInMem) throws HopsException, LopsException {
        Transform transform1 = new Transform(input.constructLops(), (Transform.OperationTypes)((Object)HopsTransf2Lops.get((Object)Hop.ReOrgOp.SORT)), input.getDataType(), input.getValueType(), et, bSortIndInMem);
        for (Hop c : new Hop[]{by, desc, ixret}) {
            Lop ltmp = c.constructLops();
            transform1.addInput(ltmp);
            ltmp.addOutput(transform1);
        }
        transform1.setLevel();
        return transform1;
    }

    @Override
    protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) {
        double sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz);
        return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity);
    }

    @Override
    protected double computeIntermediateMemEstimate(long dim1, long dim2, long nnz) {
        Hop ixreturn;
        if (this.op == Hop.ReOrgOp.SORT && (!((ixreturn = this.getInput().get(3)) instanceof LiteralOp) || HopRewriteUtils.getBooleanValueSafe((LiteralOp)ixreturn) || dim2 != 1L && nnz != 0L)) {
            return dim1 * 4L;
        }
        return 0.0;
    }

    @Override
    protected long[] inferOutputCharacteristics(MemoTable memo) {
        long[] ret = null;
        Hop input = this.getInput().get(0);
        MatrixCharacteristics mc = memo.getAllInputStats(input);
        switch (this.op) {
            case TRANSPOSE: {
                if (!mc.dimsKnown()) break;
                ret = new long[]{mc.getCols(), mc.getRows(), mc.getNonZeros()};
                break;
            }
            case REV: {
                if (!mc.dimsKnown()) break;
                ret = new long[]{mc.getRows(), mc.getCols(), mc.getNonZeros()};
                break;
            }
            case DIAG: {
                long k = mc.getRows();
                if (k == 1L) {
                    ret = new long[]{k, k, mc.getNonZeros() >= 0L ? mc.getNonZeros() : k};
                }
                if (k <= 1L) break;
                ret = new long[]{k, 1L, mc.getNonZeros() >= 0L ? Math.min(k, mc.getNonZeros()) : k};
                break;
            }
            case RESHAPE: {
                if (!mc.dimsKnown()) break;
                if (this._dim1 > 0L) {
                    ret = new long[]{this._dim1, mc.getRows() * mc.getCols() / this._dim1, mc.getNonZeros()};
                    break;
                }
                if (this._dim2 <= 0L) break;
                ret = new long[]{mc.getRows() * mc.getCols() / this._dim2, this._dim2, mc.getNonZeros()};
                break;
            }
            case SORT: {
                boolean unknownIxRet;
                Hop input4 = this.getInput().get(3);
                boolean bl = unknownIxRet = !(input4 instanceof LiteralOp);
                if (!unknownIxRet) {
                    boolean ixret = HopRewriteUtils.getBooleanValueSafe((LiteralOp)input4);
                    long dim2 = ixret ? 1L : mc.getCols();
                    long nnz = ixret ? mc.getRows() : mc.getNonZeros();
                    ret = new long[]{mc.getRows(), dim2, nnz};
                    break;
                }
                ret = new long[]{mc.getRows(), -1L, -1L};
            }
        }
        return ret;
    }

    @Override
    public boolean allowsAllExecTypes() {
        return true;
    }

    @Override
    protected LopProperties.ExecType optFindExecType() throws HopsException {
        LopProperties.ExecType REMOTE;
        this.checkAndSetForcedPlatform();
        LopProperties.ExecType execType = REMOTE = OptimizerUtils.isSparkExecutionMode() ? LopProperties.ExecType.SPARK : LopProperties.ExecType.MR;
        if (this._etypeForced != null) {
            this._etype = this._etypeForced;
        } else {
            this._etype = OptimizerUtils.isMemoryBasedOptLevel() ? this.findExecTypeByMemEstimate() : (this.getInput().get(0).areDimsBelowThreshold() || this.getInput().get(0).isVector() ? LopProperties.ExecType.CP : REMOTE);
            this.checkAndSetInvalidCPDimsAndSize();
        }
        this.setRequiresRecompileIfNecessary();
        return this._etype;
    }

    @Override
    public void refreshSizeInformation() {
        Hop input1 = this.getInput().get(0);
        switch (this.op) {
            case TRANSPOSE: {
                this.setDim1(input1.getDim2());
                this.setDim2(input1.getDim1());
                this.setNnz(input1.getNnz());
                break;
            }
            case REV: {
                this.setDim1(input1.getDim1());
                this.setDim2(input1.getDim2());
                this.setNnz(input1.getNnz());
                break;
            }
            case DIAG: {
                long k = input1.getDim1();
                this.setDim1(k);
                if (input1.getDim2() == 1L) {
                    this.setDim2(k);
                    this.setNnz(input1.getNnz() >= 0L ? input1.getNnz() : k);
                }
                if (input1.getDim2() <= 1L) break;
                this.setDim2(1L);
                this.setNnz(input1.getNnz() >= 0L ? Math.min(k, input1.getNnz()) : k);
                break;
            }
            case RESHAPE: {
                Hop input2 = this.getInput().get(1);
                Hop input3 = this.getInput().get(2);
                this.refreshRowsParameterInformation(input2);
                this.refreshColsParameterInformation(input3);
                this.setNnz(input1.getNnz());
                if (this.dimsKnown() || !input1.dimsKnown()) break;
                if (this._dim1 > 0L) {
                    this._dim2 = input1._dim1 * input1._dim2 / this._dim1;
                    break;
                }
                if (this._dim2 <= 0L) break;
                this._dim1 = input1._dim1 * input1._dim2 / this._dim2;
                break;
            }
            case SORT: {
                Hop input4 = this.getInput().get(3);
                boolean unknownIxRet = !(input4 instanceof LiteralOp);
                this._dim1 = input1.getDim1();
                if (!unknownIxRet) {
                    boolean ixret = HopRewriteUtils.getBooleanValueSafe((LiteralOp)input4);
                    this._dim2 = ixret ? 1L : input1.getDim2();
                    this._nnz = ixret ? input1.getDim1() : input1.getNnz();
                    break;
                }
                this._dim2 = -1L;
                this._nnz = -1L;
                break;
            }
        }
    }

    @Override
    public Object clone() throws CloneNotSupportedException {
        ReorgOp ret = new ReorgOp();
        ret.clone(this, false);
        ret.op = this.op;
        ret._maxNumThreads = this._maxNumThreads;
        return ret;
    }

    @Override
    public boolean compare(Hop that) {
        boolean ret;
        if (!(that instanceof ReorgOp)) {
            return false;
        }
        ReorgOp that2 = (ReorgOp)that;
        boolean bl = ret = this.op == that2.op && this._maxNumThreads == that2._maxNumThreads && this.getInput().size() == that.getInput().size();
        if (ret) {
            for (int i = 0; i < this._input.size(); ++i) {
                ret &= this.getInput().get(i) == that2.getInput().get(i);
            }
        }
        return ret;
    }

    private boolean isSortSPRewriteApplicable() {
        double size;
        boolean ret = false;
        Hop input = this.getInput().get(0);
        double d = size = input.dimsKnown() ? (double)OptimizerUtils.estimateSize(input.getDim1(), 1L) : input.getOutputMemEstimate();
        if (OptimizerUtils.checkSparkBroadcastMemoryBudget(size)) {
            ret = true;
        }
        return ret;
    }
}

