/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.plan.rules.logical;

import java.util.ArrayList;
import java.util.List;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.logical.LogicalJoin;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Pair;

public class FlinkSemiAntiJoinJoinTransposeRule
extends RelOptRule {
    public static final FlinkSemiAntiJoinJoinTransposeRule INSTANCE = new FlinkSemiAntiJoinJoinTransposeRule();

    private FlinkSemiAntiJoinJoinTransposeRule() {
        super(FlinkSemiAntiJoinJoinTransposeRule.operand(LogicalJoin.class, FlinkSemiAntiJoinJoinTransposeRule.some(FlinkSemiAntiJoinJoinTransposeRule.operand(LogicalJoin.class, FlinkSemiAntiJoinJoinTransposeRule.any()), new RelOptRuleOperand[0])));
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        RelNode rightJoinRel;
        RelNode leftJoinRel;
        RexNode newSemiAntiJoinFilter;
        int i;
        LogicalJoin semiAntiJoin = (LogicalJoin)call.rel(0);
        if (semiAntiJoin.getJoinType() != JoinRelType.SEMI && semiAntiJoin.getJoinType() != JoinRelType.ANTI) {
            return;
        }
        Join join = (Join)call.rel(1);
        if (join.getJoinType() == JoinRelType.SEMI || join.getJoinType() == JoinRelType.ANTI) {
            return;
        }
        if (join.getJoinType() != JoinRelType.INNER) {
            return;
        }
        Pair<ImmutableBitSet, ImmutableBitSet> inputRefs = this.getSemiAntiJoinConditionInputRefs(semiAntiJoin);
        ImmutableBitSet leftInputRefs = (ImmutableBitSet)inputRefs.left;
        ImmutableBitSet rightInputRefs = (ImmutableBitSet)inputRefs.right;
        if (leftInputRefs.isEmpty() || rightInputRefs.isEmpty()) {
            return;
        }
        int nFieldsX = join.getLeft().getRowType().getFieldList().size();
        int nFieldsY = join.getRight().getRowType().getFieldList().size();
        int nFieldsZ = semiAntiJoin.getRight().getRowType().getFieldList().size();
        int nTotalFields = nFieldsX + nFieldsY + nFieldsZ;
        ArrayList<RelDataTypeField> fields = new ArrayList<RelDataTypeField>();
        List<RelDataTypeField> joinFields = semiAntiJoin.getRowType().getFieldList();
        for (i = 0; i < nFieldsX + nFieldsY; ++i) {
            fields.add(joinFields.get(i));
        }
        joinFields = semiAntiJoin.getRight().getRowType().getFieldList();
        for (i = 0; i < nFieldsZ; ++i) {
            fields.add(joinFields.get(i));
        }
        int nKeysFromX = 0;
        int nKeysFromY = 0;
        for (int leftKey : leftInputRefs) {
            if (leftKey < nFieldsX) {
                ++nKeysFromX;
                continue;
            }
            ++nKeysFromY;
        }
        if (nKeysFromX > 0 && nKeysFromY > 0) {
            return;
        }
        assert (nKeysFromX == 0 || nKeysFromX == leftInputRefs.cardinality());
        assert (nKeysFromY == 0 || nKeysFromY == leftInputRefs.cardinality());
        int[] adjustments = new int[nTotalFields];
        if (nKeysFromX > 0) {
            this.setJoinAdjustments(adjustments, nFieldsX, nFieldsY, nFieldsZ, 0, -nFieldsY);
            newSemiAntiJoinFilter = semiAntiJoin.getCondition().accept(new RelOptUtil.RexInputConverter(semiAntiJoin.getCluster().getRexBuilder(), fields, adjustments));
        } else {
            this.setJoinAdjustments(adjustments, nFieldsX, nFieldsY, nFieldsZ, -nFieldsX, -nFieldsX);
            newSemiAntiJoinFilter = semiAntiJoin.getCondition().accept(new RelOptUtil.RexInputConverter(semiAntiJoin.getCluster().getRexBuilder(), fields, adjustments));
        }
        RelNode newSemiAntiJoinLeft = nKeysFromX > 0 ? join.getLeft() : join.getRight();
        LogicalJoin newSemiAntiJoin = LogicalJoin.create(newSemiAntiJoinLeft, semiAntiJoin.getRight(), newSemiAntiJoinFilter, semiAntiJoin.getVariablesSet(), semiAntiJoin.getJoinType());
        if (nKeysFromX > 0) {
            leftJoinRel = newSemiAntiJoin;
            rightJoinRel = join.getRight();
        } else {
            leftJoinRel = join.getLeft();
            rightJoinRel = newSemiAntiJoin;
        }
        Join newJoinRel = join.copy(join.getTraitSet(), join.getCondition(), leftJoinRel, rightJoinRel, join.getJoinType(), join.isSemiJoinDone());
        call.transformTo(newJoinRel);
    }

    private void setJoinAdjustments(int[] adjustments, int nFieldsX, int nFieldsY, int nFieldsZ, int adjustY, int adjustZ) {
        int i;
        for (i = 0; i < nFieldsX; ++i) {
            adjustments[i] = 0;
        }
        for (i = nFieldsX; i < nFieldsX + nFieldsY; ++i) {
            adjustments[i] = adjustY;
        }
        for (i = nFieldsX + nFieldsY; i < nFieldsX + nFieldsY + nFieldsZ; ++i) {
            adjustments[i] = adjustZ;
        }
    }

    private Pair<ImmutableBitSet, ImmutableBitSet> getSemiAntiJoinConditionInputRefs(Join semiAntiJoin) {
        final int leftInputFieldCount = semiAntiJoin.getLeft().getRowType().getFieldCount();
        final ImmutableBitSet.Builder leftInputBitSet = ImmutableBitSet.builder();
        final ImmutableBitSet.Builder rightInputBitSet = ImmutableBitSet.builder();
        semiAntiJoin.getCondition().accept(new RexVisitorImpl<Void>(true){

            @Override
            public Void visitInputRef(RexInputRef inputRef) {
                int index = inputRef.getIndex();
                if (index < leftInputFieldCount) {
                    leftInputBitSet.set(index);
                } else {
                    rightInputBitSet.set(index);
                }
                return null;
            }
        });
        return new Pair<ImmutableBitSet, ImmutableBitSet>(leftInputBitSet.build(), rightInputBitSet.build());
    }
}

