/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.hops.codegen.template;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.sysml.hops.AggBinaryOp;
import org.apache.sysml.hops.AggUnaryOp;
import org.apache.sysml.hops.BinaryOp;
import org.apache.sysml.hops.DataOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.IndexingOp;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.hops.ParameterizedBuiltinOp;
import org.apache.sysml.hops.TernaryOp;
import org.apache.sysml.hops.UnaryOp;
import org.apache.sysml.hops.codegen.cplan.CNode;
import org.apache.sysml.hops.codegen.cplan.CNodeBinary;
import org.apache.sysml.hops.codegen.cplan.CNodeCell;
import org.apache.sysml.hops.codegen.cplan.CNodeData;
import org.apache.sysml.hops.codegen.cplan.CNodeTernary;
import org.apache.sysml.hops.codegen.cplan.CNodeTpl;
import org.apache.sysml.hops.codegen.cplan.CNodeUnary;
import org.apache.sysml.hops.codegen.template.CPlanMemoTable;
import org.apache.sysml.hops.codegen.template.TemplateBase;
import org.apache.sysml.hops.codegen.template.TemplateUtils;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.runtime.matrix.data.Pair;

/**
 * Code generation template for cell-wise fused operators ({@code CellTpl}).
 * <p>
 * A cell template fuses chains of cell-wise unary, binary, ternary and
 * {@code replace} operations. It is closed (validly) by a full or row-wise unary
 * aggregation ({@code SUM}, {@code SUM_SQ}, {@code MIN}, {@code MAX}) or by a
 * matrix multiply producing a 1x1 result.
 */
public class TemplateCell
extends TemplateBase {
    /** Unary aggregation functions that may be fused into (and close) a cell template. */
    private static final Hop.AggOp[] SUPPORTED_AGG = new Hop.AggOp[]{Hop.AggOp.SUM, Hop.AggOp.SUM_SQ, Hop.AggOp.MIN, Hop.AggOp.MAX};

    /** Creates an open cell template. */
    public TemplateCell() {
        super(TemplateBase.TemplateType.CellTpl);
    }

    /**
     * Creates a cell template.
     *
     * @param closed whether the template starts in the closed state
     */
    public TemplateCell(boolean closed) {
        super(TemplateBase.TemplateType.CellTpl, closed);
    }

    /**
     * Creates a template of the given type that reuses the cell-template rules
     * (used by subclasses).
     *
     * @param type   template type reported to the base class
     * @param closed whether the template starts in the closed state
     */
    public TemplateCell(TemplateBase.TemplateType type, boolean closed) {
        super(type, closed);
    }

    /**
     * A cell template can be opened at any valid cell-wise operation, or at an
     * indexing op whose column lower bound equals its upper bound
     * (presumably a single-column selection).
     */
    @Override
    public boolean open(Hop hop) {
        return TemplateCell.isValidOperation(hop) || hop instanceof IndexingOp && ((IndexingOp)hop).isColLowerEqualsUpper();
    }

    /**
     * An open cell template can absorb a valid cell-wise operation, a supported
     * non-column-wise unary aggregation, or a 1x1 matrix multiply whose left
     * input is a transpose.
     */
    @Override
    public boolean fuse(Hop hop, Hop input) {
        if (this.isClosed()) {
            return false;
        }
        return TemplateCell.isValidOperation(hop)
            || HopRewriteUtils.isAggUnaryOp(hop, SUPPORTED_AGG) && ((AggUnaryOp)hop).getDirection() != Hop.Direction.Col
            || HopRewriteUtils.isMatrixMultiply(hop) && hop.getDim1() == 1L && hop.getDim2() == 1L && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0));
    }

    /** An open cell template can merge with another plan only at a valid cell-wise operation. */
    @Override
    public boolean merge(Hop hop, Hop input) {
        return !this.isClosed() && TemplateCell.isValidOperation(hop);
    }

    /**
     * Decides whether the template must be closed after {@code hop}.
     *
     * @return {@code CLOSED_VALID} after a supported aggregation or a 1x1 matrix
     *         multiply, {@code CLOSED_INVALID} after any other aggregation or
     *         matrix multiply, otherwise {@code OPEN}
     */
    @Override
    public TemplateBase.CloseType close(Hop hop) {
        if (HopRewriteUtils.isAggUnaryOp(hop, SUPPORTED_AGG) && ((AggUnaryOp)hop).getDirection() != Hop.Direction.Col || HopRewriteUtils.isMatrixMultiply(hop) && hop.getDim1() == 1L && hop.getDim2() == 1L) {
            return TemplateBase.CloseType.CLOSED_VALID;
        }
        if (hop instanceof AggUnaryOp || hop instanceof AggBinaryOp) {
            return TemplateBase.CloseType.CLOSED_INVALID;
        }
        return TemplateBase.CloseType.OPEN;
    }

    /**
     * Builds the cell-template CPlan rooted at {@code hop}.
     *
     * @param hop             root of the fused region
     * @param memo            memo table with the selected fusion plans
     * @param compileLiterals whether literal inputs are compiled into the plan
     * @return the ordered non-literal input hops and the constructed template
     */
    @Override
    public Pair<Hop[], CNodeTpl> constructCplan(Hop hop, CPlanMemoTable memo, boolean compileLiterals) {
        HashSet<Hop> inHops = new HashSet<Hop>();
        HashMap<Long, CNode> tmp = new HashMap<Long, CNode>();
        hop.resetVisitStatus();
        this.rConstructCplan(hop, memo, tmp, inHops, compileLiterals);
        hop.resetVisitStatus();
        // drop scalar inputs that were compiled as literals; order the rest so the
        // largest matrix (the main input) comes first
        List<Hop> sinHops = inHops.stream()
            .filter(h -> !h.getDataType().isScalar() || !tmp.get(h.getHopID()).isLiteral())
            .sorted(new HopInputComparator())
            .collect(Collectors.toList());
        ArrayList<CNode> inputs = new ArrayList<CNode>();
        for (Hop in : sinHops) {
            inputs.add(tmp.get(in.getHopID()));
        }
        CNode output = tmp.get(hop.getHopID());
        CNodeCell tpl = new CNodeCell(inputs, output);
        tpl.setCellType(TemplateUtils.getCellType(hop));
        tpl.setAggOp(TemplateUtils.getAggOp(hop));
        // sparse-safe if the root is a multiply with the main input as a direct
        // operand, or a division with the main input as the left operand;
        // guard against an empty input list instead of failing on get(0)
        boolean sparseSafe = !sinHops.isEmpty()
            && (HopRewriteUtils.isBinary(hop, Hop.OpOp2.MULT) && hop.getInput().contains(sinHops.get(0))
                || HopRewriteUtils.isBinary(hop, Hop.OpOp2.DIV) && hop.getInput().get(0) == sinHops.get(0));
        tpl.setSparseSafe(sparseSafe);
        tpl.setRequiresCastDtm(hop instanceof AggBinaryOp);
        return new Pair<Hop[], CNodeTpl>(sinHops.toArray(new Hop[0]), tpl);
    }

    /**
     * Recursively translates {@code hop} and its fused inputs into CNodes.
     * <p>
     * Constructed nodes are stored in {@code tmp} keyed by hop ID. Hops that
     * become template data inputs are added to {@code inHops}.
     *
     * @param hop             current hop
     * @param memo            memo table with the selected fusion plans
     * @param tmp             hop ID to constructed CNode (in/out)
     * @param inHops          collected template input hops (in/out)
     * @param compileLiterals whether literal inputs are compiled into the plan
     */
    protected void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp, HashSet<Hop> inHops, boolean compileLiterals) {
        // already constructed via another path
        if (tmp.containsKey(hop.getHopID())) {
            return;
        }
        CPlanMemoTable.MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateBase.TemplateType.CellTpl);
        // hops whose best plan is a row or outer-product template become plain data inputs
        if (me != null && (me.type == TemplateBase.TemplateType.RowTpl || me.type == TemplateBase.TemplateType.OuterProdTpl)) {
            CNodeData cdata = TemplateUtils.createCNodeData(hop, compileLiterals);
            tmp.put(hop.getHopID(), cdata);
            inHops.add(hop);
            return;
        }
        // recurse into inputs referenced by the plan; all other inputs become data inputs
        for (int i = 0; i < hop.getInput().size(); ++i) {
            Hop c = hop.getInput().get(i);
            if (me != null && me.isPlanRef(i) && !(c instanceof DataOp) && (me.type != TemplateBase.TemplateType.MultiAggTpl || memo.contains(c.getHopID(), TemplateBase.TemplateType.CellTpl))) {
                this.rConstructCplan(c, memo, tmp, inHops, compileLiterals);
                continue;
            }
            CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);
            tmp.put(c.getHopID(), cdata);
            inHops.add(c);
        }
        // NOTE(review): hop types not handled below leave out == null, which is stored as-is
        CNode out = null;
        if (hop instanceof UnaryOp) {
            Hop in1 = hop.getInput().get(0);
            CNode cdata1 = TemplateUtils.wrapLookupIfNecessary(tmp.get(in1.getHopID()), in1);
            String primitiveOpName = ((UnaryOp)hop).getOp().name();
            out = new CNodeUnary(cdata1, CNodeUnary.UnaryType.valueOf(primitiveOpName));
        } else if (hop instanceof BinaryOp) {
            BinaryOp bop = (BinaryOp)hop;
            Hop in1 = hop.getInput().get(0);
            Hop in2 = hop.getInput().get(1);
            CNode cdata1 = tmp.get(in1.getHopID());
            CNode cdata2 = tmp.get(in2.getHopID());
            String primitiveOpName = bop.getOp().name();
            cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, in1);
            cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, in2);
            Hop.OpOp2 op = bop.getOp();
            // X^2 and X*2 with a literal 2 are specialized into unary nodes
            if ((op == Hop.OpOp2.POW || op == Hop.OpOp2.MULT) && cdata2.isLiteral() && cdata2.getVarname().equals("2")) {
                out = new CNodeUnary(cdata1, op == Hop.OpOp2.POW ? CNodeUnary.UnaryType.POW2 : CNodeUnary.UnaryType.MULT2);
            } else {
                out = new CNodeBinary(cdata1, cdata2, CNodeBinary.BinType.valueOf(primitiveOpName));
            }
        } else if (hop instanceof TernaryOp) {
            TernaryOp top = (TernaryOp)hop;
            Hop in1 = hop.getInput().get(0);
            Hop in3 = hop.getInput().get(2);
            CNode cdata1 = TemplateUtils.wrapLookupIfNecessary(tmp.get(in1.getHopID()), in1);
            // the middle input is not lookup-wrapped (a scalar per isValidOperation)
            CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
            CNode cdata3 = TemplateUtils.wrapLookupIfNecessary(tmp.get(in3.getHopID()), in3);
            out = new CNodeTernary(cdata1, cdata2, cdata3, CNodeTernary.TernaryType.valueOf(top.getOp().name()));
        } else if (hop instanceof ParameterizedBuiltinOp) {
            ParameterizedBuiltinOp pbop = (ParameterizedBuiltinOp)hop;
            Hop target = pbop.getTargetHop();
            // wrap based on the target hop itself, i.e., the hop the CNode was built from
            CNode cdata1 = TemplateUtils.wrapLookupIfNecessary(tmp.get(target.getHopID()), target);
            CNode cdata2 = tmp.get(pbop.getParameterHop("pattern").getHopID());
            CNode cdata3 = tmp.get(pbop.getParameterHop("replacement").getHopID());
            // a literal NaN pattern gets its own ternary type
            CNodeTernary.TernaryType ttype = cdata2.isLiteral() && cdata2.getVarname().equals("Double.NaN") ? CNodeTernary.TernaryType.REPLACE_NAN : CNodeTernary.TernaryType.REPLACE;
            out = new CNodeTernary(cdata1, cdata2, cdata3, ttype);
        } else if (hop instanceof IndexingOp) {
            Hop in1 = hop.getInput().get(0);
            CNode cdata1 = tmp.get(in1.getHopID());
            // lookup into input 0 using its number of columns and the hop's input 4
            out = new CNodeTernary(cdata1, TemplateUtils.createCNodeData(new LiteralOp(in1.getDim2()), true), TemplateUtils.createCNodeData(hop.getInput().get(4), true), CNodeTernary.TernaryType.LOOKUP_RC1);
        } else if (HopRewriteUtils.isTransposeOperation(hop)) {
            // transpose is a pass-through in a cell-wise plan
            out = tmp.get(hop.getInput().get(0).getHopID());
        } else if (hop instanceof AggUnaryOp) {
            // the aggregation itself is configured on the template in constructCplan
            out = tmp.get(hop.getInput().get(0).getHopID());
        } else if (hop instanceof AggBinaryOp) {
            Hop in1 = hop.getInput().get(0);
            Hop in2 = hop.getInput().get(1);
            if (HopRewriteUtils.isTransposeOfItself(in1, in2)) {
                // presumably t(X) %*% X: squares each cell of X; the template's aggregation does the rest
                CNode cdata1 = tmp.get(in2.getHopID());
                out = new CNodeUnary(cdata1, CNodeUnary.UnaryType.POW2);
            } else {
                // cell-wise product of both sides (transpose of the left input skipped)
                CNode cdata1 = TemplateUtils.skipTranspose(tmp.get(in1.getHopID()), in1, tmp, compileLiterals);
                if (TemplateUtils.isColVector(cdata1)) {
                    cdata1 = new CNodeUnary(cdata1, CNodeUnary.UnaryType.LOOKUP_R);
                }
                CNode cdata2 = tmp.get(in2.getHopID());
                if (TemplateUtils.isColVector(cdata2)) {
                    cdata2 = new CNodeUnary(cdata2, CNodeUnary.UnaryType.LOOKUP_R);
                }
                out = new CNodeBinary(cdata1, cdata2, CNodeBinary.BinType.MULT);
            }
        }
        tmp.put(hop.getHopID(), out);
    }

    /**
     * Checks whether {@code hop} is a cell-wise operation this template supports.
     * <p>
     * The hop must be matrix-typed and pass {@link TemplateUtils#isOperationSupported}.
     * It must also be one of:
     * <ul>
     * <li>a unary op;</li>
     * <li>a binary op that is matrix-scalar;</li>
     * <li>a binary op that is matrix-vector (dims known);</li>
     * <li>a binary op between equal-size, non-sparse matrices (dims known);</li>
     * <li>a ternary matrix-scalar-matrix op (dims known) over two vectors or two
     * equal-size non-sparse matrices;</li>
     * <li>a {@code REPLACE} parameterized builtin.</li>
     * </ul>
     */
    protected static boolean isValidOperation(Hop hop) {
        boolean isBinaryMatrixScalar = false;
        boolean isBinaryMatrixVector = false;
        boolean isBinaryMatrixMatrixDense = false;
        if (hop instanceof BinaryOp && hop.getDataType().isMatrix()) {
            Hop left = hop.getInput().get(0);
            Hop right = hop.getInput().get(1);
            Expression.DataType ldt = left.getDataType();
            Expression.DataType rdt = right.getDataType();
            isBinaryMatrixScalar = ldt.isScalar() || rdt.isScalar();
            isBinaryMatrixVector = hop.dimsKnown() && (ldt.isMatrix() && TemplateUtils.isVectorOrScalar(right) || rdt.isMatrix() && TemplateUtils.isVectorOrScalar(left));
            isBinaryMatrixMatrixDense = hop.dimsKnown() && HopRewriteUtils.isEqualSize(left, right) && ldt.isMatrix() && rdt.isMatrix() && !HopRewriteUtils.isSparse(left) && !HopRewriteUtils.isSparse(right);
        }
        boolean isTernaryVectorScalarVector = false;
        boolean isTernaryMatrixScalarMatrixDense = false;
        if (hop instanceof TernaryOp && hop.getInput().size() == 3 && hop.dimsKnown() && HopRewriteUtils.checkInputDataTypes(hop, Expression.DataType.MATRIX, Expression.DataType.SCALAR, Expression.DataType.MATRIX)) {
            Hop left = hop.getInput().get(0);
            Hop right = hop.getInput().get(2);
            isTernaryVectorScalarVector = TemplateUtils.isVector(left) && TemplateUtils.isVector(right);
            isTernaryMatrixScalarMatrixDense = HopRewriteUtils.isEqualSize(left, right) && !HopRewriteUtils.isSparse(left) && !HopRewriteUtils.isSparse(right);
        }
        return hop.getDataType() == Expression.DataType.MATRIX && TemplateUtils.isOperationSupported(hop) && (hop instanceof UnaryOp || isBinaryMatrixScalar || isBinaryMatrixVector || isBinaryMatrixMatrixDense || isTernaryVectorScalarVector || isTernaryMatrixScalarMatrixDense || hop instanceof ParameterizedBuiltinOp && ((ParameterizedBuiltinOp)hop).getOp() == Hop.ParamBuiltinOp.REPLACE);
    }

    /**
     * Orders template input hops by descending number of cells.
     * <p>
     * Scalars sort last and matrices with unknown dims sort first. Ties are broken
     * by ascending number of non-zeros when known, otherwise by the cell count.
     */
    public static class HopInputComparator
    implements Comparator<Hop> {
        @Override
        public int compare(Hop h1, Hop h2) {
            long ncells1 = HopInputComparator.getNumCells(h1);
            long ncells2 = HopInputComparator.getNumCells(h2);
            if (ncells1 > ncells2) {
                return -1;
            }
            if (ncells1 < ncells2) {
                return 1;
            }
            // dimsKnown(true) presumably also requires known nnz
            return Long.compare(h1.dimsKnown(true) ? h1.getNnz() : ncells1, h2.dimsKnown(true) ? h2.getNnz() : ncells2);
        }

        /** Cell count used for ordering: MIN_VALUE for scalars, MAX_VALUE for unknown dims. */
        private static long getNumCells(Hop h) {
            if (h.getDataType() == Expression.DataType.SCALAR) {
                return Long.MIN_VALUE;
            }
            return h.dimsKnown() ? h.getDim1() * h.getDim2() : Long.MAX_VALUE;
        }
    }
}

