/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.plan.rules.logical;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptTable;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.plan.hep.HepRelVertex;
import org.apache.calcite.plan.volcano.RelSubset;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.SingleRel;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinInfo;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.logical.LogicalJoin;
import org.apache.calcite.rel.logical.LogicalSnapshot;
import org.apache.calcite.rel.logical.LogicalTableScan;
import org.apache.calcite.rel.metadata.RelColumnOrigin;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.rules.MultiJoin;
import org.apache.calcite.rel.rules.TransformationRule;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.calcite.util.Pair;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.calcite.shaded.com.google.common.collect.ImmutableMap;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalJoin;
import org.apache.flink.table.planner.plan.rules.logical.ImmutableJoinToMultiJoinRule;
import org.apache.flink.table.planner.plan.utils.IntervalJoinUtil;
import org.immutables.value.Value;

/**
 * Planner rule that converts an INNER or LEFT {@link Join} into a {@link MultiJoin}.
 *
 * <p>The left input is flattened into the new {@link MultiJoin} only when it is itself a
 * non-full-outer {@link MultiJoin} that shares at least one join key with the join being
 * converted. Join keys are compared by origin table name plus column name, see {@link
 * #getJoinKeys(RelNode)}.
 *
 * <p>The right input is always kept as one single input of the new {@link MultiJoin}. This rule
 * emits exactly one join spec for the right side and does not shift a right-hand MultiJoin's
 * conditions or filters. Flattening it would therefore leave the MultiJoin with fewer join types
 * than inputs and drop the right side's join filter. For LEFT joins it would also merge the
 * null-generating side.
 *
 * <p>Interval joins and temporal (snapshot) joins are never converted.
 */
@Value.Enclosing
public class JoinToMultiJoinRule extends RelRule<JoinToMultiJoinRule.Config>
        implements TransformationRule {

    public static final JoinToMultiJoinRule INSTANCE = Config.DEFAULT.toRule();

    /** Creates a JoinToMultiJoinRule from the given configuration. */
    public JoinToMultiJoinRule(Config config) {
        super(config);
    }

    /**
     * Creates a rule whose root operand matches {@code clazz}.
     *
     * <p>NOTE(review): {@link #isIntervalJoin(Join)} treats every join that is not a {@link
     * LogicalJoin} as an interval join, so a rule built for another join class never matches.
     *
     * @deprecated use {@link Config#withOperandFor(Class)} followed by {@link Config#toRule()}
     */
    @Deprecated
    public JoinToMultiJoinRule(Class<? extends Join> clazz) {
        this(Config.DEFAULT.withOperandFor(clazz));
    }

    /**
     * Creates a rule whose root operand matches {@code joinClass} and which uses the given
     * builder factory.
     *
     * @deprecated use {@link Config#withOperandFor(Class)} followed by {@link Config#toRule()}
     */
    @Deprecated
    public JoinToMultiJoinRule(
            Class<? extends Join> joinClass, RelBuilderFactory relBuilderFactory) {
        this(
                Config.DEFAULT
                        .withRelBuilderFactory(relBuilderFactory)
                        .as(Config.class)
                        .withOperandFor(joinClass));
    }

    /**
     * Accepts only INNER and LEFT joins that are neither interval joins nor temporal joins.
     */
    @Override
    public boolean matches(RelOptRuleCall call) {
        final Join origJoin = call.rel(0);
        final JoinRelType joinType = origJoin.getJoinType();
        if (joinType != JoinRelType.INNER && joinType != JoinRelType.LEFT) {
            return false;
        }
        // Combining interval joins / temporal joins into a multi join is not supported.
        if (isIntervalJoin(origJoin) || isTemporalJoin(call)) {
            return false;
        }
        return joinType.projectsRight();
    }

    /**
     * Replaces the matched join with a {@link MultiJoin}.
     *
     * <p>If the left input can be combined, its inputs, join specs and filters are pulled up into
     * the new MultiJoin.
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        final Join origJoin = call.rel(0);
        final RelNode left = call.rel(1);
        final RelNode right = call.rel(2);

        // canCombine runs column-origin metadata queries; evaluate it once per match.
        final boolean leftCombined = canCombine(left, origJoin);

        final List<ImmutableBitSet> projFieldsList = new ArrayList<>();
        final List<int[]> joinFieldRefCountsList = new ArrayList<>();
        final List<RelNode> newInputs =
                combineInputs(
                        leftCombined, left, right, projFieldsList, joinFieldRefCountsList);

        final List<Pair<JoinRelType, RexNode>> joinSpecs = new ArrayList<>();
        combineJoinInfo(origJoin, leftCombined, left, joinSpecs);

        final List<RexNode> newJoinFilters = combineJoinFilters(origJoin, leftCombined, left);

        final Map<Integer, ImmutableIntList> newJoinFieldRefCountsMap =
                addOnJoinFieldRefCounts(
                        newInputs,
                        origJoin.getRowType().getFieldCount(),
                        origJoin.getCondition(),
                        joinFieldRefCountsList);

        final List<RexNode> newPostJoinFilters = combinePostJoinFilters(leftCombined, left);

        final RexBuilder rexBuilder = origJoin.getCluster().getRexBuilder();
        final MultiJoin multiJoin =
                new MultiJoin(
                        origJoin.getCluster(),
                        newInputs,
                        RexUtil.composeConjunction(rexBuilder, newJoinFilters),
                        origJoin.getRowType(),
                        origJoin.getJoinType() == JoinRelType.FULL,
                        Pair.right(joinSpecs),
                        Pair.left(joinSpecs),
                        projFieldsList,
                        ImmutableMap.copyOf(newJoinFieldRefCountsMap),
                        RexUtil.composeConjunction(rexBuilder, newPostJoinFilters, true));
        call.transformTo(multiJoin);
    }

    /**
     * Builds the input list of the new MultiJoin and fills the parallel per-input lists.
     *
     * <p>A combinable left MultiJoin contributes all of its inputs, projection fields and
     * reference counts. Otherwise the left input is added as is. The right input is always added
     * as a single input; see the class documentation for why it is never flattened.
     *
     * @param leftCombined whether {@code left} is a MultiJoin that may be flattened
     * @param projFieldsList output: one entry per new input ({@code null} means all fields)
     * @param joinFieldRefCountsList output: one reference-count array per new input
     * @return the inputs of the new MultiJoin
     */
    private List<RelNode> combineInputs(
            boolean leftCombined,
            RelNode left,
            RelNode right,
            List<ImmutableBitSet> projFieldsList,
            List<int[]> joinFieldRefCountsList) {
        final List<RelNode> newInputs = new ArrayList<>();
        if (leftCombined) {
            final MultiJoin leftMultiJoin = (MultiJoin) left;
            for (int i = 0; i < leftMultiJoin.getInputs().size(); i++) {
                newInputs.add(leftMultiJoin.getInput(i));
                projFieldsList.add(leftMultiJoin.getProjFields().get(i));
                joinFieldRefCountsList.add(
                        leftMultiJoin.getJoinFieldRefCountsMap().get(i).toIntArray());
            }
        } else {
            addSingleInput(left, newInputs, projFieldsList, joinFieldRefCountsList);
        }
        addSingleInput(right, newInputs, projFieldsList, joinFieldRefCountsList);
        return newInputs;
    }

    /** Appends {@code input} as one un-flattened MultiJoin input with zeroed reference counts. */
    private static void addSingleInput(
            RelNode input,
            List<RelNode> newInputs,
            List<ImmutableBitSet> projFieldsList,
            List<int[]> joinFieldRefCountsList) {
        newInputs.add(input);
        projFieldsList.add(null);
        joinFieldRefCountsList.add(new int[input.getRowType().getFieldCount()]);
    }

    /**
     * Appends one (join type, join condition) spec per new MultiJoin input to {@code joinSpecs}.
     *
     * <p>Left-side specs are copied from a combined left MultiJoin, or a single {@code (INNER,
     * TRUE)} spec is used. The right input then receives {@code (joinType, join condition)}.
     *
     * @throws TableException if the join is neither INNER nor LEFT
     */
    private void combineJoinInfo(
            Join joinRel,
            boolean leftCombined,
            RelNode left,
            List<Pair<JoinRelType, RexNode>> joinSpecs) {
        final JoinRelType joinType = joinRel.getJoinType();
        final RexBuilder rexBuilder = joinRel.getCluster().getRexBuilder();
        switch (joinType) {
            case LEFT:
            case INNER:
                if (leftCombined) {
                    copyJoinInfo((MultiJoin) left, joinSpecs);
                } else {
                    joinSpecs.add(Pair.of(JoinRelType.INNER, rexBuilder.makeLiteral(true)));
                }
                joinSpecs.add(Pair.of(joinType, joinRel.getCondition()));
                break;
            default:
                throw new TableException(
                        "This is a bug. This rule only supports left and inner joins");
        }
    }

    /** Appends the (join type, outer-join condition) pairs of {@code multiJoin}'s inputs. */
    private void copyJoinInfo(
            MultiJoin multiJoin, List<Pair<JoinRelType, RexNode>> destJoinSpecs) {
        destJoinSpecs.addAll(
                Pair.zip(multiJoin.getJoinTypes(), multiJoin.getOuterJoinConditions()));
    }

    /**
     * Collects the conjuncts of the new MultiJoin's join filter.
     *
     * <p>An INNER join contributes its own condition; a LEFT join's condition is kept only in its
     * join spec. A combined left MultiJoin contributes its existing join filter.
     *
     * @throws TableException if the join is neither INNER nor LEFT
     */
    private List<RexNode> combineJoinFilters(Join join, boolean leftCombined, RelNode left) {
        final JoinRelType joinType = join.getJoinType();
        if (joinType != JoinRelType.INNER && joinType != JoinRelType.LEFT) {
            throw new TableException(
                    "This is a bug. This rule only supports left and inner joins");
        }
        final List<RexNode> filters = new ArrayList<>();
        if (joinType == JoinRelType.INNER) {
            filters.add(join.getCondition());
        }
        if (leftCombined) {
            filters.add(((MultiJoin) left).getJoinFilter());
        }
        return filters;
    }

    /**
     * Returns whether {@code input} is a non-full-outer MultiJoin that shares at least one join
     * key with {@code origJoin}.
     */
    private boolean canCombine(RelNode input, Join origJoin) {
        if (!(input instanceof MultiJoin)) {
            return false;
        }
        final MultiJoin join = (MultiJoin) input;
        return !join.isFullOuterJoin() && haveCommonJoinKey(origJoin, join);
    }

    /** Returns whether the two joins have at least one join key ("table.column") in common. */
    private boolean haveCommonJoinKey(Join origJoin, MultiJoin otherJoin) {
        final Set<String> origJoinKeys = getJoinKeys(origJoin);
        origJoinKeys.retainAll(getJoinKeys(otherJoin));
        return !origJoinKeys.isEmpty();
    }

    /**
     * Returns the join keys of a {@link Join} or {@link MultiJoin}.
     *
     * <p>Each key is written as "{@code <last name of origin table>.<origin column name>}". The
     * keys come from every {@link RexInputRef} operand of every call conjunct of the join
     * condition(s). For a MultiJoin, its outer-join conditions are used. Other node types yield
     * an empty set.
     *
     * @return a new, mutable set of join keys
     */
    public Set<String> getJoinKeys(RelNode join) {
        final Set<String> joinKeys = new HashSet<>();
        final List<RexCall> conditions;
        if (join instanceof Join) {
            conditions = collectConjunctions(((Join) join).getCondition());
        } else if (join instanceof MultiJoin) {
            conditions =
                    ((MultiJoin) join)
                            .getOuterJoinConditions().stream()
                                    .flatMap(cond -> collectConjunctions(cond).stream())
                                    .collect(Collectors.toList());
        } else {
            conditions = Collections.emptyList();
        }
        final List<RelNode> inputs = join.getInputs();
        final RelMetadataQuery mq = join.getCluster().getMetadataQuery();
        for (RexCall condition : conditions) {
            for (RexNode operand : condition.getOperands()) {
                if (operand instanceof RexInputRef) {
                    addJoinKeysByOperand((RexInputRef) operand, inputs, mq, joinKeys);
                }
            }
        }
        return joinKeys;
    }

    /**
     * Splits a condition into its AND-ed conjuncts and keeps only the {@link RexCall}s.
     *
     * <p>Other conjuncts, such as a bare boolean field reference, carry no comparison operands
     * and are skipped instead of failing a cast.
     */
    private List<RexCall> collectConjunctions(RexNode joinCondition) {
        return RelOptUtil.conjunctions(joinCondition).stream()
                .filter(RexCall.class::isInstance)
                .map(RexCall.class::cast)
                .collect(Collectors.toList());
    }

    /**
     * Traces {@code ref} to its origin table column(s) and adds "{@code table.column}" keys.
     *
     * <p>Adds nothing when the reference cannot be traced to a {@link LogicalTableScan} or has
     * no known column origin.
     */
    private void addJoinKeysByOperand(
            RexInputRef ref, List<RelNode> inputs, RelMetadataQuery mq, Set<String> joinKeys) {
        final Tuple2<RelNode, Integer> targetInputAndIdx =
                getTargetInputAndIdx(ref.getIndex(), inputs);
        if (targetInputAndIdx == null) {
            return;
        }
        final Set<RelColumnOrigin> origins =
                mq.getColumnOrigins(targetInputAndIdx.f0, targetInputAndIdx.f1);
        if (origins == null) {
            return;
        }
        for (RelColumnOrigin origin : origins) {
            final RelOptTable originTable = origin.getOriginTable();
            final List<String> qualifiedName = originTable.getQualifiedName();
            final String fieldName =
                    originTable
                            .getRowType()
                            .getFieldList()
                            .get(origin.getOriginColumnOrdinal())
                            .getName();
            joinKeys.add(qualifiedName.get(qualifiedName.size() - 1) + "." + fieldName);
        }
    }

    /**
     * Resolves a field index over the concatenated row of {@code inputs} down to a {@link
     * LogicalTableScan} and the field's index within that scan.
     *
     * <p>{@link HepRelVertex} and {@link RelSubset} wrappers are unwrapped. Every other
     * intermediate node is descended into by passing the local index through unchanged to the
     * concatenation of its inputs.
     *
     * <p>NOTE(review): that positional pass-through is only exact for nodes whose output row is
     * the concatenation of their inputs' rows (e.g. filters, joins, MultiJoins). Through a
     * projection that reorders or computes columns it may attribute the wrong column. Consider
     * mapping through {@code Project} expressions.
     *
     * @return the scan and the field index within it, or {@code null} if the field cannot be
     *     traced to a {@link LogicalTableScan}
     */
    private Tuple2<RelNode, Integer> getTargetInputAndIdx(
            int inputRefIndex, List<RelNode> inputs) {
        RelNode targetInput = null;
        int idxInTargetInput = 0;
        int inputFieldStart = 0;
        for (RelNode input : inputs) {
            final int fieldCount = input.getRowType().getFieldCount();
            if (inputRefIndex < inputFieldStart + fieldCount) {
                targetInput = input;
                idxInTargetInput = inputRefIndex - inputFieldStart;
                break;
            }
            inputFieldStart += fieldCount;
        }
        targetInput = unwrap(targetInput);
        if (targetInput == null) {
            // Index out of range, or a leaf that is not a LogicalTableScan was reached.
            return null;
        }
        if (targetInput instanceof LogicalTableScan) {
            return new Tuple2<>(targetInput, idxInTargetInput);
        }
        return getTargetInputAndIdx(idxInTargetInput, targetInput.getInputs());
    }

    /**
     * Returns the node wrapped by a {@link HepRelVertex} or {@link RelSubset}, or {@code relNode}
     * itself (possibly {@code null}).
     */
    private static RelNode unwrap(RelNode relNode) {
        if (relNode instanceof HepRelVertex) {
            return ((HepRelVertex) relNode).getCurrentRel();
        }
        if (relNode instanceof RelSubset) {
            return ((RelSubset) relNode).getOriginal();
        }
        return relNode;
    }

    /**
     * Adds, to the given per-input reference counts, the number of times the join condition
     * references each field.
     *
     * @param multiJoinInputs inputs of the new MultiJoin
     * @param nTotalFields total number of fields across all inputs
     * @param joinCondition condition of the join being converted
     * @param origJoinFieldRefCounts per-input reference counts collected so far (not modified)
     * @return an unmodifiable map from input ordinal to that input's updated reference counts
     */
    private Map<Integer, ImmutableIntList> addOnJoinFieldRefCounts(
            List<RelNode> multiJoinInputs,
            int nTotalFields,
            RexNode joinCondition,
            List<int[]> origJoinFieldRefCounts) {
        // Count references in the join condition over the concatenated input row.
        final int[] joinCondRefCounts = new int[nTotalFields];
        joinCondition.accept(new InputReferenceCounter(joinCondRefCounts));

        // Copy the existing counters so the caller's arrays are not mutated.
        final Map<Integer, int[]> refCountsMap = new HashMap<>();
        int currInput = 0;
        for (int[] origRefCounts : origJoinFieldRefCounts) {
            refCountsMap.put(currInput, origRefCounts.clone());
            currInput++;
        }

        // Distribute the condition's counts over the inputs.
        final int nInputs = multiJoinInputs.size();
        currInput = -1;
        int startField = 0;
        int nFields = 0;
        for (int i = 0; i < nTotalFields; i++) {
            if (joinCondRefCounts[i] == 0) {
                continue;
            }
            while (i >= startField + nFields) {
                startField += nFields;
                // Advance outside the assert: assertions may be disabled at runtime.
                currInput++;
                assert currInput < nInputs;
                nFields = multiJoinInputs.get(currInput).getRowType().getFieldCount();
            }
            final int[] refCounts = refCountsMap.get(currInput);
            refCounts[i - startField] += joinCondRefCounts[i];
        }

        final Map<Integer, ImmutableIntList> result = new HashMap<>();
        for (Map.Entry<Integer, int[]> entry : refCountsMap.entrySet()) {
            result.put(entry.getKey(), ImmutableIntList.of(entry.getValue()));
        }
        return Collections.unmodifiableMap(result);
    }

    /**
     * Collects post-join filters to pull up into the new MultiJoin.
     *
     * <p>Only a left MultiJoin that is actually flattened contributes its post-join filter. An
     * input that stays a single MultiJoin input still applies its own filter. Re-applying it
     * above a LEFT join would also evaluate it on null-padded rows.
     */
    private List<RexNode> combinePostJoinFilters(boolean leftCombined, RelNode left) {
        final List<RexNode> filters = new ArrayList<>();
        if (leftCombined) {
            filters.add(((MultiJoin) left).getPostJoinFilter());
        }
        return filters;
    }

    /**
     * Returns whether {@code join} must be treated as an interval join.
     *
     * <p>Joins that are not {@link LogicalJoin}s are conservatively reported as interval joins.
     */
    private boolean isIntervalJoin(Join join) {
        if (!(join instanceof LogicalJoin)) {
            return true;
        }
        final FlinkLogicalJoin flinkLogicalJoin =
                (FlinkLogicalJoin) FlinkLogicalJoin.CONVERTER().convert(join);
        return IntervalJoinUtil.satisfyIntervalJoin(flinkLogicalJoin);
    }

    /** Returns whether either join input contains a {@link LogicalSnapshot}. */
    private boolean isTemporalJoin(RelOptRuleCall call) {
        return containsSnapshot(call.rel(1)) || containsSnapshot(call.rel(2));
    }

    /**
     * Returns whether {@code relNode} is, or has beneath a chain of {@link SingleRel}s, a {@link
     * LogicalSnapshot}.
     */
    private boolean containsSnapshot(RelNode relNode) {
        final RelNode original = unwrap(relNode);
        if (original instanceof LogicalSnapshot) {
            return true;
        }
        if (original instanceof SingleRel) {
            return containsSnapshot(((SingleRel) original).getInput());
        }
        return false;
    }

    /** Rule configuration. */
    @Value.Immutable(singleton = false)
    public interface Config extends RelRule.Config {
        Config DEFAULT =
                ImmutableJoinToMultiJoinRule.Config.builder()
                        .build()
                        .as(Config.class)
                        .withOperandFor(LogicalJoin.class);

        @Override
        default JoinToMultiJoinRule toRule() {
            return new JoinToMultiJoinRule(this);
        }

        /** Matches a {@code joinClass} node with two arbitrary inputs. */
        default Config withOperandFor(Class<? extends Join> joinClass) {
            return withOperandSupplier(
                            b0 ->
                                    b0.operand(joinClass)
                                            .inputs(
                                                    b1 -> b1.operand(RelNode.class).anyInputs(),
                                                    b2 -> b2.operand(RelNode.class).anyInputs()))
                    .as(Config.class);
        }
    }

    /** Visitor that counts, per field index, how often a {@link RexInputRef} refers to it. */
    private static class InputReferenceCounter extends RexVisitorImpl<Void> {
        private final int[] refCounts;

        InputReferenceCounter(int[] refCounts) {
            super(true);
            this.refCounts = refCounts;
        }

        @Override
        public Void visitInputRef(RexInputRef inputRef) {
            refCounts[inputRef.getIndex()]++;
            return null;
        }
    }
}

