/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.hops.codegen.template;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import org.apache.commons.lang.ArrayUtils;
import org.apache.sysml.hops.AggBinaryOp;
import org.apache.sysml.hops.AggUnaryOp;
import org.apache.sysml.hops.BinaryOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.hops.ParameterizedBuiltinOp;
import org.apache.sysml.hops.TernaryOp;
import org.apache.sysml.hops.UnaryOp;
import org.apache.sysml.hops.codegen.cplan.CNode;
import org.apache.sysml.hops.codegen.cplan.CNodeBinary;
import org.apache.sysml.hops.codegen.cplan.CNodeData;
import org.apache.sysml.hops.codegen.cplan.CNodeTernary;
import org.apache.sysml.hops.codegen.cplan.CNodeTpl;
import org.apache.sysml.hops.codegen.cplan.CNodeUnary;
import org.apache.sysml.hops.codegen.template.CPlanMemoTable;
import org.apache.sysml.hops.codegen.template.TemplateBase;
import org.apache.sysml.hops.codegen.template.TemplateCell;
import org.apache.sysml.hops.codegen.template.TemplateMultiAgg;
import org.apache.sysml.hops.codegen.template.TemplateOuterProduct;
import org.apache.sysml.hops.codegen.template.TemplateRow;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.runtime.codegen.SpoofCellwise;
import org.apache.sysml.runtime.codegen.SpoofOuterProduct;
import org.apache.sysml.runtime.codegen.SpoofRowwise;
import org.apache.sysml.runtime.util.UtilFunctions;

/**
 * Static helpers for the code-generation templates.
 *
 * <p>Provides shape predicates on {@link Hop}s and {@link CNode}s, template
 * factories, classification of the fused operator type (cellwise, rowwise,
 * outer product), and small utilities for inspecting and assembling cplan
 * inputs.
 */
public class TemplateUtils {
    /** Shared default instances of the row, cell and outer-product templates, in that order. */
    public static final TemplateBase[] TEMPLATES = new TemplateBase[]{new TemplateRow(), new TemplateCell(), new TemplateOuterProduct()};

    /**
     * Tests whether a hop is a matrix with exactly one of its two dimensions equal to 1.
     *
     * <p>Dimensions are only compared against 1. Any other value counts as "not 1",
     * including whatever value an unknown dimension carries, so a 1x1 matrix is
     * not a vector.
     *
     * @param hop the hop to test
     * @return true if {@code hop} is a row or a column vector
     */
    public static boolean isVector(Hop hop) {
        if (hop.getDataType() != Expression.DataType.MATRIX) {
            return false;
        }
        boolean oneRow = hop.getDim1() == 1L;
        boolean oneCol = hop.getDim2() == 1L;
        // Exactly one of the two dimensions must be 1.
        return oneRow != oneCol;
    }

    /**
     * Tests whether a cplan node is a matrix with a single column and a row count other than 1.
     *
     * @param hop the cplan node to test
     * @return true if the node is a column vector
     */
    public static boolean isColVector(CNode hop) {
        return hop.getDataType() == Expression.DataType.MATRIX
            && hop.getNumRows() != 1L
            && hop.getNumCols() == 1L;
    }

    /**
     * Tests whether a cplan node is a matrix with a single row and a column count other than 1.
     *
     * @param hop the cplan node to test
     * @return true if the node is a row vector
     */
    public static boolean isRowVector(CNode hop) {
        return hop.getDataType() == Expression.DataType.MATRIX
            && hop.getNumRows() == 1L
            && hop.getNumCols() != 1L;
    }

    /**
     * Wraps a cplan node into the lookup variant that matches its shape.
     *
     * <ul>
     *   <li>Column vectors are wrapped in {@code LOOKUP_R}.</li>
     *   <li>Row vectors are wrapped in {@code LOOKUP_C}.</li>
     *   <li>Other data nodes whose hop is matrix-typed are wrapped in {@code LOOKUP_RC}.</li>
     * </ul>
     * Anything else is returned unchanged.
     *
     * @param node the cplan node to (possibly) wrap
     * @param hop the hop the node corresponds to; its data type decides the {@code LOOKUP_RC} case
     * @return a new unary lookup node wrapping {@code node}, or {@code node} itself
     */
    public static CNode wrapLookupIfNecessary(CNode node, Hop hop) {
        if (TemplateUtils.isColVector(node)) {
            return new CNodeUnary(node, CNodeUnary.UnaryType.LOOKUP_R);
        }
        if (TemplateUtils.isRowVector(node)) {
            return new CNodeUnary(node, CNodeUnary.UnaryType.LOOKUP_C);
        }
        if (node instanceof CNodeData && hop.getDataType().isMatrix()) {
            return new CNodeUnary(node, CNodeUnary.UnaryType.LOOKUP_RC);
        }
        return node;
    }

    /**
     * Tests whether a hop is a matrix with neither dimension equal to 1.
     *
     * @param hop the hop to test
     * @return true if {@code hop} is matrix-typed and is neither a row nor a column vector
     */
    public static boolean isMatrix(Hop hop) {
        return hop.getDataType() == Expression.DataType.MATRIX
            && hop.getDim1() != 1L
            && hop.getDim2() != 1L;
    }

    /**
     * Tests whether a hop has known dimensions and is either a scalar or a vector.
     *
     * @param hop the hop to test
     * @return true for scalars and vectors with known dimensions
     */
    public static boolean isVectorOrScalar(Hop hop) {
        return hop.dimsKnown()
            && (hop.getDataType() == Expression.DataType.SCALAR || TemplateUtils.isVector(hop));
    }

    /**
     * Tests whether a hop is a binary operation with two matrix inputs of known size,
     * where the left input has more rows than the right one.
     *
     * @param hop the hop to test
     * @return true for such a binary operation, false for any other hop
     */
    public static boolean isBinaryMatrixRowVector(Hop hop) {
        if (!(hop instanceof BinaryOp)) {
            return false;
        }
        Hop left = hop.getInput().get(0);
        Hop right = hop.getInput().get(1);
        return hasKnownMatrixInputs(left, right) && left.getDim1() > right.getDim1();
    }

    /**
     * Tests whether a hop is a binary operation with two matrix inputs of known size,
     * where the left input has more columns than the right one.
     *
     * @param hop the hop to test
     * @return true for such a binary operation, false for any other hop
     */
    public static boolean isBinaryMatrixColVector(Hop hop) {
        if (!(hop instanceof BinaryOp)) {
            return false;
        }
        Hop left = hop.getInput().get(0);
        Hop right = hop.getInput().get(1);
        return hasKnownMatrixInputs(left, right) && left.getDim2() > right.getDim2();
    }

    /** Returns true if both hops have known dimensions and are matrix-typed. */
    private static boolean hasKnownMatrixInputs(Hop left, Hop right) {
        return left.dimsKnown()
            && right.dimsKnown()
            && left.getDataType().isMatrix()
            && right.getDataType().isMatrix();
    }

    /**
     * Tests whether any input of a hop satisfies {@link #isMatrix(Hop)}.
     *
     * @param hop the hop whose inputs are inspected
     * @return true if at least one input is a (non-vector) matrix
     */
    public static boolean hasMatrixInput(Hop hop) {
        for (Hop c : hop.getInput()) {
            if (TemplateUtils.isMatrix(c)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Tests whether the operation of a hop has a cplan counterpart.
     *
     * <p>The hop's operator name is looked up in the matching cplan type enum.
     * Unary ops use {@code CNodeUnary.UnaryType}; binary ops use
     * {@code CNodeBinary.BinType}; ternary and parameterized builtin ops both
     * use {@code CNodeTernary.TernaryType}. All other hops are unsupported.
     *
     * @param h the hop to test
     * @return true if the operation can be represented in a cplan
     */
    public static boolean isOperationSupported(Hop h) {
        if (h instanceof UnaryOp) {
            return CNodeUnary.UnaryType.contains(((UnaryOp) h).getOp().name());
        }
        if (h instanceof BinaryOp) {
            return CNodeBinary.BinType.contains(((BinaryOp) h).getOp().name());
        }
        if (h instanceof TernaryOp) {
            return CNodeTernary.TernaryType.contains(((TernaryOp) h).getOp().name());
        }
        if (h instanceof ParameterizedBuiltinOp) {
            // Parameterized builtins are matched against the ternary cplan types.
            return CNodeTernary.TernaryType.contains(((ParameterizedBuiltinOp) h).getOp().name());
        }
        return false;
    }

    /**
     * Collects {@code hop} and, while the current hop is a unary op or a binary
     * matrix/vector-or-scalar op, recursively one of its inputs.
     *
     * <p>Hops are added before descending, so a {@link LinkedHashSet} receives
     * them in traversal order.
     */
    private static void rfindChildren(Hop hop, HashSet<Hop> children) {
        children.add(hop);
        if (!isUnaryOrMatrixVectorBinary(hop)) {
            return;
        }
        Hop first = hop.getInput().get(0);
        // Single-input hops (e.g. unary ops) can only continue through their sole input.
        // Without this guard, a non-matrix sole input (vector/scalar) fell through to
        // getInput().get(1) and threw an IndexOutOfBoundsException.
        Hop next = (hop.getInput().size() < 2 || TemplateUtils.isMatrix(first)) ? first : hop.getInput().get(1);
        TemplateUtils.rfindChildren(next, children);
    }

    /**
     * Returns true for the hops that {@link #rfindChildren} descends through.
     *
     * <ul>
     *   <li>Any unary op.</li>
     *   <li>A binary op whose left input is matrix-typed and whose right input is a
     *       vector or scalar.</li>
     *   <li>A matrix-typed binary op whose left input is a vector or scalar and whose
     *       right input is matrix-typed.</li>
     * </ul>
     */
    private static boolean isUnaryOrMatrixVectorBinary(Hop hop) {
        if (hop instanceof UnaryOp) {
            return true;
        }
        if (!(hop instanceof BinaryOp)) {
            return false;
        }
        Hop left = hop.getInput().get(0);
        Hop right = hop.getInput().get(1);
        return (left.getDataType() == Expression.DataType.MATRIX && TemplateUtils.isVectorOrScalar(right))
            || (TemplateUtils.isVectorOrScalar(left)
                && right.getDataType() == Expression.DataType.MATRIX
                && hop.getDataType() == Expression.DataType.MATRIX);
    }

    /**
     * Returns the first hop on the chain of {@code hop1} that also lies on the chain
     * of {@code hop2}, where the chains are collected by {@link #rfindChildren}.
     *
     * @return the first shared hop in {@code hop1}'s traversal order, or null if none
     */
    private static Hop findCommonChild(Hop hop1, Hop hop2) {
        LinkedHashSet<Hop> children1 = new LinkedHashSet<Hop>();
        LinkedHashSet<Hop> children2 = new LinkedHashSet<Hop>();
        TemplateUtils.rfindChildren(hop1, children1);
        TemplateUtils.rfindChildren(hop2, children2);
        for (Hop candidate : children1) {
            if (children2.contains(candidate)) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Finds a single hop that {@code input} has in common with every hop of {@code addedMatrices}.
     *
     * @param addedMatrices the hops to compare against
     * @param input the hop whose common child is sought
     * @return the common child, if it is the same hop (by hop ID) for every entry; null if some
     *         entry shares no child with {@code input}, if the entries disagree, or if the list is empty
     */
    public static Hop commonChild(ArrayList<Hop> addedMatrices, Hop input) {
        Hop currentChild = null;
        for (Hop addedMatrix : addedMatrices) {
            Hop child = TemplateUtils.findCommonChild(addedMatrix, input);
            if (child == null) {
                return null;
            }
            if (currentChild == null) {
                currentChild = child;
            } else if (child.getHopID() != currentChild.getHopID()) {
                return null;
            }
        }
        return currentChild;
    }

    /**
     * Recursively collects the hop IDs of all non-literal data nodes in a cplan subtree.
     *
     * @param node the root of the cplan subtree
     * @param ids the set that receives the IDs (modified in place)
     * @return {@code ids}, for call chaining
     */
    public static HashSet<Long> rGetInputHopIDs(CNode node, HashSet<Long> ids) {
        if (node instanceof CNodeData && !node.isLiteral()) {
            ids.add(((CNodeData) node).getHopID());
        }
        for (CNode c : node.getInput()) {
            TemplateUtils.rGetInputHopIDs(c, ids);
        }
        return ids;
    }

    /**
     * Merges two hop arrays, keeping only hops whose ID is in {@code ids}.
     *
     * <p>Entries of {@code input1} come first, followed by those of {@code input2},
     * each in array order. A hop ID is taken at most once, so hops occurring in both
     * arrays (or repeatedly within one) no longer overflow the result array.
     * Slots for IDs not found in either array remain null at the end.
     *
     * @param ids the hop IDs to keep; determines the length of the result
     * @param input1 first candidate array
     * @param input2 second candidate array
     * @return an array of length {@code ids.size()} with the selected, distinct hops
     */
    public static Hop[] mergeDistinct(HashSet<Long> ids, Hop[] input1, Hop[] input2) {
        Hop[] ret = new Hop[ids.size()];
        HashSet<Long> taken = new HashSet<Long>();
        int pos = 0;
        for (Hop[] input : new Hop[][]{input1, input2}) {
            for (Hop c : input) {
                // add() returns false for an ID that was already taken earlier
                if (ids.contains(c.getHopID()) && taken.add(c.getHopID())) {
                    ret[pos++] = c;
                }
            }
        }
        return ret;
    }

    /**
     * Creates a non-closed template of the given type.
     *
     * @param type the template type
     * @return the new template, or null for an unhandled type
     */
    public static TemplateBase createTemplate(TemplateBase.TemplateType type) {
        return TemplateUtils.createTemplate(type, false);
    }

    /**
     * Creates a template of the given type.
     *
     * @param type the template type
     * @param closed value passed to the template constructor
     * @return the new template, or null for an unhandled type
     */
    public static TemplateBase createTemplate(TemplateBase.TemplateType type, boolean closed) {
        switch (type) {
            case CellTpl:
                return new TemplateCell(closed);
            case RowTpl:
                return new TemplateRow(closed);
            case MultiAggTpl:
                return new TemplateMultiAgg(closed);
            case OuterProdTpl:
                return new TemplateOuterProduct(closed);
            default:
                return null;
        }
    }

    /**
     * Creates the templates considered compatible with the given type.
     *
     * <p>A cell template is paired with a row template; every other type maps
     * to a single template of that type.
     *
     * @param type the template type
     * @param closed value passed to each template constructor
     * @return the new templates, or null for an unhandled type
     */
    public static TemplateBase[] createCompatibleTemplates(TemplateBase.TemplateType type, boolean closed) {
        switch (type) {
            case CellTpl:
                return new TemplateBase[]{new TemplateCell(closed), new TemplateRow(closed)};
            case RowTpl:
                return new TemplateBase[]{new TemplateRow(closed)};
            case MultiAggTpl:
                return new TemplateBase[]{new TemplateMultiAgg(closed)};
            case OuterProdTpl:
                return new TemplateBase[]{new TemplateOuterProduct(closed)};
            default:
                return null;
        }
    }

    /**
     * Determines the cellwise template type for a root hop.
     *
     * <ul>
     *   <li>An {@link AggBinaryOp} gives {@code FULL_AGG}.</li>
     *   <li>An {@link AggUnaryOp} gives {@code FULL_AGG} for direction {@code RowCol},
     *       otherwise {@code ROW_AGG}.</li>
     *   <li>Anything else gives {@code NO_AGG}.</li>
     * </ul>
     *
     * @param hop the root hop of the fused operator
     * @return the corresponding cell type
     */
    public static SpoofCellwise.CellType getCellType(Hop hop) {
        if (hop instanceof AggBinaryOp) {
            return SpoofCellwise.CellType.FULL_AGG;
        }
        if (hop instanceof AggUnaryOp) {
            return ((AggUnaryOp) hop).getDirection() == Hop.Direction.RowCol
                ? SpoofCellwise.CellType.FULL_AGG
                : SpoofCellwise.CellType.ROW_AGG;
        }
        return SpoofCellwise.CellType.NO_AGG;
    }

    /**
     * Determines the rowwise template type from the output and main input sizes.
     *
     * <ul>
     *   <li>Equal sizes give {@code NO_AGG}.</li>
     *   <li>A single output column with as many rows as the input gives {@code ROW_AGG}.</li>
     *   <li>A single output column with as many rows as the input has columns gives
     *       {@code COL_AGG_T}.</li>
     *   <li>Anything else gives {@code COL_AGG}.</li>
     * </ul>
     *
     * @param output the output hop of the fused operator
     * @param input the main input hop
     * @return the corresponding row type
     */
    public static SpoofRowwise.RowType getRowType(Hop output, Hop input) {
        if (HopRewriteUtils.isEqualSize(output, input)) {
            return SpoofRowwise.RowType.NO_AGG;
        }
        boolean singleOutputCol = output.getDim2() == 1L;
        if (output.getDim1() == input.getDim1() && singleOutputCol) {
            return SpoofRowwise.RowType.ROW_AGG;
        }
        if (output.getDim1() == input.getDim2() && singleOutputCol) {
            return SpoofRowwise.RowType.COL_AGG_T;
        }
        return SpoofRowwise.RowType.COL_AGG;
    }

    /**
     * Returns the aggregation operator of a hop.
     *
     * @param hop the hop to inspect
     * @return the op of an {@link AggUnaryOp}, {@code SUM} for an {@link AggBinaryOp}, null otherwise
     */
    public static Hop.AggOp getAggOp(Hop hop) {
        if (hop instanceof AggUnaryOp) {
            return ((AggUnaryOp) hop).getOp();
        }
        return hop instanceof AggBinaryOp ? Hop.AggOp.SUM : null;
    }

    /**
     * Classifies an outer-product template by its output hop. The checks are applied in order:
     *
     * <ol>
     *   <li>A scalar output gives {@code AGG_OUTER_PRODUCT}.</li>
     *   <li>A matrix multiply whose left input is {@code U} or {@code t(U)}, or any transpose
     *       output, gives {@code LEFT_OUTER_PRODUCT}.</li>
     *   <li>A matrix multiply whose right input is {@code V} or {@code t(V)} gives
     *       {@code RIGHT_OUTER_PRODUCT}.</li>
     *   <li>A binary op with equally sized inputs gives {@code CELLWISE_OUTER_PRODUCT}.</li>
     * </ol>
     *
     * @param X not inspected by this method
     * @param U the left factor
     * @param V the right factor
     * @param out the output hop of the fused operator
     * @return the outer product type
     * @throws RuntimeException if none of the patterns above applies
     */
    public static SpoofOuterProduct.OutProdType getOuterProductType(Hop X, Hop U, Hop V, Hop out) {
        if (out.getDataType() == Expression.DataType.SCALAR) {
            return SpoofOuterProduct.OutProdType.AGG_OUTER_PRODUCT;
        }
        boolean isMatMult = out instanceof AggBinaryOp;
        if ((isMatMult && isSameOrTransposeOf(out.getInput().get(0), U)) || HopRewriteUtils.isTransposeOperation(out)) {
            return SpoofOuterProduct.OutProdType.LEFT_OUTER_PRODUCT;
        }
        if (isMatMult && isSameOrTransposeOf(out.getInput().get(1), V)) {
            return SpoofOuterProduct.OutProdType.RIGHT_OUTER_PRODUCT;
        }
        if (out instanceof BinaryOp && HopRewriteUtils.isEqualSize(out.getInput().get(0), out.getInput().get(1))) {
            return SpoofOuterProduct.OutProdType.CELLWISE_OUTER_PRODUCT;
        }
        throw new RuntimeException("Undefined outer product type for hop " + out.getHopID());
    }

    /** Returns true if {@code hop} is {@code target} itself or a transpose directly over {@code target}. */
    private static boolean isSameOrTransposeOf(Hop hop, Hop target) {
        return hop == target
            || (HopRewriteUtils.isTransposeOperation(hop) && hop.getInput().get(0) == target);
    }

    /**
     * Tests whether a cplan node is a lookup.
     *
     * @param node the cplan node to test
     * @return true for unary {@code LOOKUP_R}, {@code LOOKUP_C}, {@code LOOKUP_RC} and
     *         ternary {@code LOOKUP_RC1} nodes
     */
    public static boolean isLookup(CNode node) {
        return TemplateUtils.isUnary(node, CNodeUnary.UnaryType.LOOKUP_R, CNodeUnary.UnaryType.LOOKUP_C, CNodeUnary.UnaryType.LOOKUP_RC)
            || TemplateUtils.isTernary(node, CNodeTernary.TernaryType.LOOKUP_RC1);
    }

    /**
     * Tests whether a cplan node is a unary node of one of the given types.
     *
     * @param node the cplan node to test
     * @param types the accepted unary types
     * @return true if {@code node} is a {@link CNodeUnary} whose type is in {@code types}
     */
    public static boolean isUnary(CNode node, CNodeUnary.UnaryType ... types) {
        return node instanceof CNodeUnary
            && ArrayUtils.contains(types, ((CNodeUnary) node).getType());
    }

    /**
     * Tests whether a cplan node is a ternary node of one of the given types.
     *
     * @param node the cplan node to test
     * @param types the accepted ternary types
     * @return true if {@code node} is a {@link CNodeTernary} whose type is in {@code types}
     */
    public static boolean isTernary(CNode node, CNodeTernary.TernaryType ... types) {
        return node instanceof CNodeTernary
            && ArrayUtils.contains(types, ((CNodeTernary) node).getType());
    }

    /**
     * Creates a cplan data node for a hop.
     *
     * <p>The node is flagged as literal only for {@link LiteralOp}s. A literal qualifies
     * when {@code compileLiterals} is set, or when its string value is an integer number.
     *
     * @param hop the hop to represent
     * @param compileLiterals whether all literals should be marked as literal nodes
     * @return the new data node
     */
    public static CNodeData createCNodeData(Hop hop, boolean compileLiterals) {
        CNodeData cdata = new CNodeData(hop);
        boolean literal = hop instanceof LiteralOp
            && (compileLiterals || UtilFunctions.isIntegerNumber(((LiteralOp) hop).getStringValue()));
        cdata.setLiteral(literal);
        return cdata;
    }

    /**
     * Bypasses a transpose hop by mapping it to the cplan node of its input.
     *
     * <p>If {@code hop} is a transpose, the cplan node of its input is taken from
     * {@code tmp}. If absent, it is created via {@link #createCNodeData} and stored
     * in {@code tmp}. The transpose hop's ID is then mapped to that same node,
     * which is returned. For any other hop, {@code cdataOrig} is returned and
     * {@code tmp} is untouched.
     *
     * @param cdataOrig the node to return when {@code hop} is not a transpose
     * @param hop the hop to inspect
     * @param tmp map from hop ID to cplan node (modified in place)
     * @param compileLiterals forwarded to {@link #createCNodeData}
     * @return the node representing {@code hop}
     */
    public static CNode skipTranspose(CNode cdataOrig, Hop hop, HashMap<Long, CNode> tmp, boolean compileLiterals) {
        if (!HopRewriteUtils.isTransposeOperation(hop)) {
            return cdataOrig;
        }
        Hop inner = hop.getInput().get(0);
        CNode cdata = tmp.get(inner.getHopID());
        if (cdata == null) {
            cdata = TemplateUtils.createCNodeData(inner, compileLiterals);
            tmp.put(inner.getHopID(), cdata);
        }
        tmp.put(hop.getHopID(), cdata);
        return cdata;
    }

    /**
     * Tests whether a hop feeds a transpose that itself feeds an outer-product-like matrix multiply.
     *
     * @param hop the hop whose parents are inspected
     * @return true if some parent is a transpose with an outer-product-like parent
     */
    public static boolean hasTransposeParentUnderOuterProduct(Hop hop) {
        for (Hop p : hop.getParent()) {
            if (!HopRewriteUtils.isTransposeOperation(p)) {
                continue;
            }
            for (Hop p2 : p.getParent()) {
                if (HopRewriteUtils.isOuterProductLikeMM(p2)) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Tests whether a template's output is a single unary, binary or ternary node
     * whose inputs are all data nodes or supported lookups.
     *
     * @param tpl the template to inspect
     * @return true if the template consists of just one operation
     */
    public static boolean hasSingleOperation(CNodeTpl tpl) {
        CNode output = tpl.getOutput();
        boolean isOperation = output instanceof CNodeUnary
            || output instanceof CNodeBinary
            || output instanceof CNodeTernary;
        return isOperation && TemplateUtils.hasOnlyDataNodeOrLookupInputs(output);
    }

    /**
     * Tests whether a template's output is a plain data node or a lookup (no actual operation).
     *
     * @param tpl the template to inspect
     * @return true if the template performs no operation
     */
    public static boolean hasNoOperation(CNodeTpl tpl) {
        return tpl.getOutput() instanceof CNodeData || TemplateUtils.isLookup(tpl.getOutput());
    }

    /**
     * Tests whether every input of a cplan node is a data node or a unary lookup.
     * Only {@code LOOKUP0}, {@code LOOKUP_R} and {@code LOOKUP_RC} count as lookups here;
     * {@code LOOKUP_C} does not.
     *
     * @param node the cplan node whose inputs are inspected
     * @return true if all inputs qualify (also true for a node without inputs)
     */
    public static boolean hasOnlyDataNodeOrLookupInputs(CNode node) {
        for (CNode c : node.getInput()) {
            if (c instanceof CNodeData) {
                continue;
            }
            if (!(c instanceof CNodeUnary)) {
                return false;
            }
            CNodeUnary.UnaryType type = ((CNodeUnary) c).getType();
            if (type != CNodeUnary.UnaryType.LOOKUP0
                && type != CNodeUnary.UnaryType.LOOKUP_R
                && type != CNodeUnary.UnaryType.LOOKUP_RC) {
                return false;
            }
        }
        return true;
    }

    /**
     * Counts the binary and unary nodes in a cplan DAG whose type reports
     * {@code isVectorScalarPrimitive()}. Each node ID is visited at most once.
     *
     * @param node the root of the cplan DAG
     * @param memo IDs of already visited nodes (modified in place)
     * @return the number of such nodes not yet in {@code memo}
     */
    public static int countVectorIntermediates(CNode node, HashSet<Long> memo) {
        // add() returns false if the node was already visited
        if (!memo.add(node.getID())) {
            return 0;
        }
        int count = 0;
        for (CNode c : node.getInput()) {
            count += TemplateUtils.countVectorIntermediates(c, memo);
        }
        if (node instanceof CNodeBinary && ((CNodeBinary) node).getType().isVectorScalarPrimitive()) {
            count++;
        }
        if (node instanceof CNodeUnary && ((CNodeUnary) node).getType().isVectorScalarPrimitive()) {
            count++;
        }
        return count;
    }

    /**
     * Tests whether a template type is one of the given types.
     *
     * @param type the type to test
     * @param validTypes the accepted types
     * @return true if {@code type} is contained in {@code validTypes}
     */
    public static boolean isType(TemplateBase.TemplateType type, TemplateBase.TemplateType ... validTypes) {
        return ArrayUtils.contains(validTypes, type);
    }

    /**
     * Tests whether two hops resolve to the same row-template matrix input
     * (see {@link #getRowTemplateMatrixInput}).
     *
     * @param input1 the first hop
     * @param input2 the second hop
     * @param memo the memo table of candidate plans
     * @return true if {@code input2} has no row-template plan, or if both hops
     *         resolve to the same matrix input hop ID
     */
    public static boolean hasCommonRowTemplateMatrixInput(Hop input1, Hop input2, CPlanMemoTable memo) {
        if (!memo.contains(input2.getHopID(), TemplateBase.TemplateType.RowTpl)) {
            return true;
        }
        long matrixInput1 = TemplateUtils.getRowTemplateMatrixInput(input1, memo);
        long matrixInput2 = TemplateUtils.getRowTemplateMatrixInput(input2, memo);
        return matrixInput1 == matrixInput2;
    }

    /**
     * Finds the hop ID of the first matrix input of the best row-template plan for a hop.
     *
     * <p>Inputs are scanned in order until a result is found:
     * <ul>
     *   <li>Plan references to hops with a row-template plan are searched recursively.</li>
     *   <li>Other plan references are skipped.</li>
     *   <li>A non-reference input that satisfies {@link #isMatrix(Hop)} is the result.</li>
     * </ul>
     * NOTE(review): the best entry from {@code memo.getBest} is dereferenced without a null
     * check, so this presumably requires a row-template entry for {@code current} to exist.
     *
     * @param current the hop whose row-template plan is inspected
     * @param memo the memo table of candidate plans
     * @return the matrix input's hop ID, or -1 if none is found
     */
    public static long getRowTemplateMatrixInput(Hop current, CPlanMemoTable memo) {
        CPlanMemoTable.MemoTableEntry me = memo.getBest(current.getHopID(), TemplateBase.TemplateType.RowTpl);
        long ret = -1L;
        for (int i = 0; ret < 0L && i < current.getInput().size(); ++i) {
            Hop input = current.getInput().get(i);
            if (me.isPlanRef(i)) {
                if (memo.contains(input.getHopID(), TemplateBase.TemplateType.RowTpl)) {
                    ret = TemplateUtils.getRowTemplateMatrixInput(input, memo);
                }
            } else if (TemplateUtils.isMatrix(input)) {
                ret = input.getHopID();
            }
        }
        return ret;
    }
}

