/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.plan.rules.logical;

import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.rel.core.Calc;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexProgram;
import org.apache.calcite.rex.RexProgramBuilder;
import org.apache.flink.table.functions.python.PythonFunctionKind;
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCalc;
import org.apache.flink.table.planner.plan.utils.PythonUtil;

/**
 * Planner rule that merges two consecutive Python "map" calls that are separated by a flattening
 * {@link FlinkLogicalCalc}.
 *
 * <p>The matched pattern is three stacked calcs without filter conditions:
 *
 * <ul>
 *   <li>bottom: projects exactly one Python call;
 *   <li>middle: flattens the bottom's single (struct-typed) field, i.e. projects {@code $0.<0>,
 *       $0.<1>, ...} in field order;
 *   <li>top: projects exactly one Python call for which {@code PythonUtil.takesRowAsInput} holds
 *       and whose operands are exactly the middle calc's output columns, in order.
 * </ul>
 *
 * <p>The top and middle calcs are replaced by one calc that passes the bottom's field directly to
 * the top Python call. When either both or neither of the two Python calls are of kind {@link
 * PythonFunctionKind#GENERAL}, the bottom calc is merged into the same calc as well.
 */
public class PythonMapMergeRule extends RelOptRule {

    /** Singleton instance of this rule. */
    public static final PythonMapMergeRule INSTANCE = new PythonMapMergeRule();

    private PythonMapMergeRule() {
        super(
                operand(
                        FlinkLogicalCalc.class,
                        operand(FlinkLogicalCalc.class, operand(FlinkLogicalCalc.class, none()))),
                "PythonMapMergeRule");
    }

    @Override
    public boolean matches(RelOptRuleCall call) {
        FlinkLogicalCalc topCalc = call.rel(0);
        FlinkLogicalCalc middleCalc = call.rel(1);
        FlinkLogicalCalc bottomCalc = call.rel(2);

        RexProgram topProgram = topCalc.getProgram();
        List<RexNode> topProjects = expandProjects(topProgram);
        // isPythonCall is checked first so the RexCall cast below only runs on call nodes.
        if (topProjects.size() != 1
                || !PythonUtil.isPythonCall(topProjects.get(0), null)
                || !PythonUtil.takesRowAsInput((RexCall) topProjects.get(0))) {
            return false;
        }

        RexProgram bottomProgram = bottomCalc.getProgram();
        List<RexNode> bottomProjects = expandProjects(bottomProgram);
        if (bottomProjects.size() != 1 || !PythonUtil.isPythonCall(bottomProjects.get(0), null)) {
            return false;
        }

        RexProgram middleProgram = middleCalc.getProgram();
        if (topProgram.getCondition() != null
                || middleProgram.getCondition() != null
                || bottomProgram.getCondition() != null) {
            return false;
        }

        // The middle calc's input has exactly one field (the bottom projects one expression).
        // Only a struct-typed field can be flattened; asking a non-struct type for its field
        // list fails in Calcite, so reject such plans up front instead of crashing.
        RelDataType bottomResultType =
                middleProgram.getInputRowType().getFieldList().get(0).getType();
        if (!bottomResultType.isStruct()) {
            return false;
        }
        int inputRowFieldCount = bottomResultType.getFieldList().size();

        List<RexNode> middleProjects = expandProjects(middleProgram);
        return isFlattenCalc(middleProjects, inputRowFieldCount)
                && isTopCalcTakesWholeMiddleCalcAsInputs(
                        (RexCall) topProjects.get(0), middleProjects.size());
    }

    /**
     * Checks whether the operands of {@code pythonCall} are exactly {@code $0, $1, ..., $(n-1)}
     * where {@code n == inputColumnCount}.
     *
     * @param pythonCall the top calc's Python call
     * @param inputColumnCount number of columns produced by the middle calc
     * @return {@code true} if the call consumes every middle column once, in order
     */
    private boolean isTopCalcTakesWholeMiddleCalcAsInputs(RexCall pythonCall, int inputColumnCount) {
        List<RexNode> pythonCallInputs = pythonCall.getOperands();
        if (pythonCallInputs.size() != inputColumnCount) {
            return false;
        }
        for (int i = 0; i < pythonCallInputs.size(); i++) {
            RexNode input = pythonCallInputs.get(i);
            if (!(input instanceof RexInputRef) || ((RexInputRef) input).getIndex() != i) {
                return false;
            }
        }
        return true;
    }

    /**
     * Checks whether {@code middleProjects} is exactly {@code $0.<0>, $0.<1>, ...}, i.e. a
     * flattening of the struct in input field 0 that keeps every sub-field in order.
     *
     * @param middleProjects expanded project expressions of the middle calc
     * @param inputRowFieldCount number of sub-fields of the struct in input field 0
     * @return {@code true} if the projections flatten that struct completely and in order
     */
    private boolean isFlattenCalc(List<RexNode> middleProjects, int inputRowFieldCount) {
        if (inputRowFieldCount != middleProjects.size()) {
            return false;
        }
        for (int i = 0; i < inputRowFieldCount; i++) {
            RexNode middleProject = middleProjects.get(i);
            if (!(middleProject instanceof RexFieldAccess)) {
                return false;
            }
            RexFieldAccess fieldAccess = (RexFieldAccess) middleProject;
            RexNode referenceExpr = fieldAccess.getReferenceExpr();
            if (fieldAccess.getField().getIndex() != i
                    || !(referenceExpr instanceof RexInputRef)
                    || ((RexInputRef) referenceExpr).getIndex() != 0) {
                return false;
            }
        }
        return true;
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        FlinkLogicalCalc topCalc = call.rel(0);
        FlinkLogicalCalc middleCalc = call.rel(1);
        FlinkLogicalCalc bottomCalc = call.rel(2);
        RexBuilder rexBuilder = call.builder().getRexBuilder();

        // matches() guarantees that top and bottom each project exactly one Python call.
        RexCall topPythonCall = (RexCall) expandProjects(topCalc.getProgram()).get(0);
        RexCall bottomPythonCall = (RexCall) expandProjects(bottomCalc.getProgram()).get(0);

        // The middle calc only flattens the bottom's single field, and the top call consumes all
        // flattened columns in order, so the top call can take that field ($0) directly.
        RexCall newPythonCall =
                topPythonCall.clone(
                        topPythonCall.getType(),
                        Collections.singletonList(RexInputRef.of(0, bottomCalc.getRowType())));
        FlinkLogicalCalc topMiddleMergedCalc =
                new FlinkLogicalCalc(
                        middleCalc.getCluster(),
                        middleCalc.getTraitSet(),
                        bottomCalc,
                        RexProgram.create(
                                bottomCalc.getRowType(),
                                Collections.singletonList(newPythonCall),
                                null,
                                Collections.singletonList("f0"),
                                rexBuilder));

        if (PythonUtil.isPythonCall(topPythonCall, PythonFunctionKind.GENERAL)
                ^ PythonUtil.isPythonCall(bottomPythonCall, PythonFunctionKind.GENERAL)) {
            // Exactly one of the two calls is GENERAL: keep them in separate calcs and only
            // drop the flattening middle calc.
            call.transformTo(topMiddleMergedCalc);
        } else {
            // Same generality on both sides: also fold the bottom calc into a single program.
            RexProgram mergedProgram =
                    RexProgramBuilder.mergePrograms(
                            topMiddleMergedCalc.getProgram(), bottomCalc.getProgram(), rexBuilder);
            Calc newCalc =
                    topMiddleMergedCalc.copy(
                            topMiddleMergedCalc.getTraitSet(), bottomCalc.getInput(), mergedProgram);
            call.transformTo(newCalc);
        }
    }

    /** Returns the project expressions of {@code program} with all local references expanded. */
    private static List<RexNode> expandProjects(RexProgram program) {
        return program.getProjectList().stream()
                .map(program::expandLocalRef)
                .collect(Collectors.toList());
    }
}

