/*
 * Decompiled with CFR 0.152.
 */
package org.apache.iotdb.db.queryengine.plan.planner.distribution;

import com.google.common.collect.ImmutableList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.iotdb.common.rpc.thrift.TRegionReplicaSet;
import org.apache.iotdb.commons.partition.DataPartition;
import org.apache.iotdb.db.queryengine.plan.analyze.Analysis;
import org.apache.iotdb.db.queryengine.plan.planner.distribution.NodeDistribution;
import org.apache.iotdb.db.queryengine.plan.planner.distribution.NodeDistributionType;
import org.apache.iotdb.db.queryengine.plan.planner.distribution.NodeGroupContext;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanVisitor;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.WritePlanNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.metedata.read.AbstractSchemaMergeNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.metedata.read.CountSchemaMergeNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.metedata.read.SchemaFetchMergeNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.metedata.read.SchemaFetchScanNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.metedata.read.SchemaQueryMergeNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.metedata.read.SchemaQueryOrderByHeatNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.metedata.read.SchemaQueryScanNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.AggregationMergeSortNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.AggregationNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.DeviceMergeNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.DeviceViewNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.ExchangeNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.FilterNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.GroupByLevelNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.GroupByTagNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.HorizontallyConcatNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.LimitNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.MergeSortNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.MultiChildProcessNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.SingleDeviceViewNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.SlidingWindowAggregationNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.SortNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.TopKNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.TransformNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.join.FullOuterTimeJoinNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.join.InnerTimeJoinNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.join.LeftOuterTimeJoinNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.last.LastQueryCollectNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.last.LastQueryMergeNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.last.LastQueryNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.process.last.LastQueryTransformNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.source.AlignedLastQueryScanNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.source.AlignedSeriesAggregationScanNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.source.AlignedSeriesScanNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.source.LastQueryScanNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.source.SeriesAggregationScanNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.source.SeriesScanNode;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.source.SourceNode;

/**
 * Plan visitor that prepares a logical plan for distributed execution by inserting an
 * {@link ExchangeNode} between a parent node and any child that is placed on a different region.
 *
 * <p>Most handlers return a rewritten copy of the visited node and record the
 * {@link NodeDistribution} chosen for it in the supplied {@link NodeGroupContext}. Whenever an
 * exchange node is created, {@code context.hasExchangeNode} is set to {@code true}.
 */
public class ExchangeNodeAdder
extends PlanVisitor<PlanNode, NodeGroupContext> {
    private final Analysis analysis;

    public ExchangeNodeAdder(Analysis analysis) {
        this.analysis = analysis;
    }

    /**
     * Fallback for node types without a dedicated handler. Write nodes are returned as-is. Any
     * other node has its children rewritten, is recorded as {@code SAME_WITH_ALL_CHILDREN} with no
     * explicit region, and is cloned with the rewritten children.
     */
    @Override
    public PlanNode visitPlan(PlanNode node, NodeGroupContext context) {
        if (node instanceof WritePlanNode) {
            return node;
        }
        List<PlanNode> children = node.getChildren().stream()
                .map(child -> this.visit(child, context))
                .collect(ImmutableList.toImmutableList());
        context.putNodeDistribution(node.getPlanNodeId(), new NodeDistribution(NodeDistributionType.SAME_WITH_ALL_CHILDREN, null));
        return node.cloneWithChildren(children);
    }

    @Override
    public PlanNode visitSchemaQueryMerge(SchemaQueryMergeNode node, NodeGroupContext context) {
        return this.internalVisitSchemaMerge(node, context);
    }

    /**
     * Shared handler for schema merge nodes.
     *
     * <p>The merge node is recorded as {@code DIFFERENT_FROM_ALL_CHILDREN} and placed on the region
     * of its first child. Every child on another region is wrapped in an exchange node.
     */
    private PlanNode internalVisitSchemaMerge(AbstractSchemaMergeNode node, NodeGroupContext context) {
        // Keep the rewritten children: attaching the originals would drop any exchange nodes that
        // were inserted further down a child's subtree.
        List<PlanNode> visitedChildren = node.getChildren().stream()
                .map(child -> this.visit(child, context))
                .collect(Collectors.toList());
        NodeDistribution nodeDistribution = new NodeDistribution(NodeDistributionType.DIFFERENT_FROM_ALL_CHILDREN);
        PlanNode newNode = node.clone();
        TRegionReplicaSet mergeRegion = this.calculateSchemaRegionByChildren(visitedChildren, context);
        nodeDistribution.setRegion(mergeRegion);
        context.putNodeDistribution(newNode.getPlanNodeId(), nodeDistribution);
        for (PlanNode child : visitedChildren) {
            if (mergeRegion.equals(context.getNodeDistribution(child.getPlanNodeId()).getRegion())) {
                newNode.addChild(child);
            } else {
                newNode.addChild(this.genExchangeNode(context, child));
            }
        }
        return newNode;
    }

    @Override
    public PlanNode visitCountMerge(CountSchemaMergeNode node, NodeGroupContext context) {
        return this.internalVisitSchemaMerge(node, context);
    }

    @Override
    public PlanNode visitSchemaFetchMerge(SchemaFetchMergeNode node, NodeGroupContext context) {
        return this.internalVisitSchemaMerge(node, context);
    }

    @Override
    public PlanNode visitSchemaQueryScan(SchemaQueryScanNode node, NodeGroupContext context) {
        return this.processNoChildSourceNode(node, context);
    }

    @Override
    public PlanNode visitSchemaFetchScan(SchemaFetchScanNode node, NodeGroupContext context) {
        return this.processNoChildSourceNode(node, context);
    }

    @Override
    public PlanNode visitSeriesScan(SeriesScanNode node, NodeGroupContext context) {
        return this.processNoChildSourceNode(node, context);
    }

    @Override
    public PlanNode visitAlignedSeriesScan(AlignedSeriesScanNode node, NodeGroupContext context) {
        return this.processNoChildSourceNode(node, context);
    }

    @Override
    public PlanNode visitLastQueryScan(LastQueryScanNode node, NodeGroupContext context) {
        return this.processNoChildSourceNode(node, context);
    }

    @Override
    public PlanNode visitAlignedLastQueryScan(AlignedLastQueryScanNode node, NodeGroupContext context) {
        return this.processNoChildSourceNode(node, context);
    }

    @Override
    public PlanNode visitSeriesAggregationScan(SeriesAggregationScanNode node, NodeGroupContext context) {
        return this.processNoChildSourceNode(node, context);
    }

    @Override
    public PlanNode visitAlignedSeriesAggregationScan(AlignedSeriesAggregationScanNode node, NodeGroupContext context) {
        return this.processNoChildSourceNode(node, context);
    }

    /**
     * Records a leaf source node as {@code NO_CHILD} on its own region replica set and returns a
     * clone of it.
     */
    private PlanNode processNoChildSourceNode(SourceNode node, NodeGroupContext context) {
        context.putNodeDistribution(node.getPlanNodeId(), new NodeDistribution(NodeDistributionType.NO_CHILD, node.getRegionReplicaSet()));
        return node.clone();
    }

    @Override
    public PlanNode visitDeviceView(DeviceViewNode node, NodeGroupContext context) {
        return this.processMultiChildNode(node, context);
    }

    @Override
    public PlanNode visitAggregationMergeSort(AggregationMergeSortNode node, NodeGroupContext context) {
        return this.processMultiChildNode(node, context);
    }

    @Override
    public PlanNode visitDeviceMerge(DeviceMergeNode node, NodeGroupContext context) {
        return this.processMultiChildNode(node, context);
    }

    @Override
    public PlanNode visitSingleDeviceView(SingleDeviceViewNode node, NodeGroupContext context) {
        return this.processOneChildNode(node, context);
    }

    @Override
    public PlanNode visitMergeSort(MergeSortNode node, NodeGroupContext context) {
        return this.processMultiChildNode(node, context);
    }

    @Override
    public PlanNode visitTopK(TopKNode node, NodeGroupContext context) {
        return this.processMultiChildNode(node, context);
    }

    @Override
    public PlanNode visitLastQueryMerge(LastQueryMergeNode node, NodeGroupContext context) {
        return this.processMultiChildNode(node, context);
    }

    @Override
    public PlanNode visitLastQueryCollect(LastQueryCollectNode node, NodeGroupContext context) {
        return this.processMultiChildNode(node, context);
    }

    @Override
    public PlanNode visitLastQuery(LastQueryNode node, NodeGroupContext context) {
        return this.processMultiChildNode(node, context);
    }

    @Override
    public PlanNode visitLastQueryTransform(LastQueryTransformNode node, NodeGroupContext context) {
        return this.processOneChildNode(node, context);
    }

    @Override
    public PlanNode visitFullOuterTimeJoin(FullOuterTimeJoinNode node, NodeGroupContext context) {
        return this.processMultiChildNode(node, context);
    }

    /**
     * Handles the two-child left outer time join.
     *
     * <p>The region is chosen the same way as in {@link #processMultiChildNode}. When both children
     * already share a region they are attached directly. Otherwise each child on a different region
     * is wrapped in an exchange node.
     */
    @Override
    public PlanNode visitLeftOuterTimeJoin(LeftOuterTimeJoinNode node, NodeGroupContext context) {
        LeftOuterTimeJoinNode newNode = (LeftOuterTimeJoinNode) node.clone();
        PlanNode leftChild = this.visit(node.getLeftChild(), context);
        PlanNode rightChild = this.visit(node.getRightChild(), context);
        List<PlanNode> visitedChildren = Arrays.asList(leftChild, rightChild);
        boolean isChildrenDistributionSame = this.nodeDistributionIsSame(visitedChildren, context);
        NodeDistributionType distributionType = isChildrenDistributionSame
                ? NodeDistributionType.SAME_WITH_ALL_CHILDREN
                : NodeDistributionType.SAME_WITH_SOME_CHILD;
        TRegionReplicaSet dataRegion = this.resolveDataRegion(visitedChildren, isChildrenDistributionSame, context);
        context.putNodeDistribution(newNode.getPlanNodeId(), new NodeDistribution(distributionType, dataRegion));
        if (isChildrenDistributionSame) {
            newNode.setLeftChild(leftChild);
            newNode.setRightChild(rightChild);
            return newNode;
        }
        // Left before right, so exchange node ids are generated in the same order as before.
        newNode.setLeftChild(this.attachToRegion(leftChild, dataRegion, context));
        newNode.setRightChild(this.attachToRegion(rightChild, dataRegion, context));
        return newNode;
    }

    @Override
    public PlanNode visitInnerTimeJoin(InnerTimeJoinNode node, NodeGroupContext context) {
        return this.processMultiChildNode(node, context);
    }

    @Override
    public PlanNode visitAggregation(AggregationNode node, NodeGroupContext context) {
        return this.processMultiChildNode(node, context);
    }

    @Override
    public PlanNode visitSchemaQueryOrderByHeat(SchemaQueryOrderByHeatNode node, NodeGroupContext context) {
        return this.processMultiChildNode(node, context);
    }

    @Override
    public PlanNode visitGroupByLevel(GroupByLevelNode node, NodeGroupContext context) {
        return this.processMultiChildNode(node, context);
    }

    @Override
    public PlanNode visitTransform(TransformNode node, NodeGroupContext context) {
        return this.processOneChildNode(node, context);
    }

    @Override
    public PlanNode visitFilter(FilterNode node, NodeGroupContext context) {
        return this.processOneChildNode(node, context);
    }

    @Override
    public PlanNode visitGroupByTag(GroupByTagNode node, NodeGroupContext context) {
        return this.processMultiChildNode(node, context);
    }

    @Override
    public PlanNode visitHorizontallyConcat(HorizontallyConcatNode node, NodeGroupContext context) {
        return this.processMultiChildNode(node, context);
    }

    @Override
    public PlanNode visitSort(SortNode node, NodeGroupContext context) {
        return this.processOneChildNode(node, context);
    }

    @Override
    public PlanNode visitLimit(LimitNode node, NodeGroupContext context) {
        return this.processOneChildNode(node, context);
    }

    /**
     * Shared handler for multi-child process nodes.
     *
     * <ul>
     *   <li>Virtual sources are delegated to {@link #processMultiChildNodeByLocation}.
     *   <li>Otherwise the children are rewritten, a region is chosen for the node, and the node is
     *       recorded as {@code SAME_WITH_ALL_CHILDREN} when every child shares one non-null region,
     *       or {@code SAME_WITH_SOME_CHILD} otherwise.
     *   <li>In the latter case, {@link TopKNode}s are split per region by {@link #processTopKNode}.
     *       Any other node gets an exchange node above each child located elsewhere.
     * </ul>
     */
    private PlanNode processMultiChildNode(MultiChildProcessNode node, NodeGroupContext context) {
        if (this.analysis.isVirtualSource()) {
            return this.processMultiChildNodeByLocation(node, context);
        }
        MultiChildProcessNode newNode = (MultiChildProcessNode) node.clone();
        List<PlanNode> visitedChildren = node.getChildren().stream()
                .map(child -> this.visit(child, context))
                .collect(Collectors.toList());
        boolean isChildrenDistributionSame = this.nodeDistributionIsSame(visitedChildren, context);
        NodeDistributionType distributionType = isChildrenDistributionSame
                ? NodeDistributionType.SAME_WITH_ALL_CHILDREN
                : NodeDistributionType.SAME_WITH_SOME_CHILD;
        TRegionReplicaSet dataRegion = this.resolveDataRegion(visitedChildren, isChildrenDistributionSame, context);
        context.putNodeDistribution(newNode.getPlanNodeId(), new NodeDistribution(distributionType, dataRegion));
        if (isChildrenDistributionSame) {
            newNode.setChildren(visitedChildren);
            return newNode;
        }
        if (node instanceof TopKNode) {
            return this.processTopKNode(node, visitedChildren, context, newNode, dataRegion);
        }
        for (PlanNode child : visitedChildren) {
            newNode.addChild(this.attachToRegion(child, dataRegion, context));
        }
        return newNode;
    }

    /**
     * Chooses the region for a multi-child node.
     *
     * <p>In align-by-device mode this is the shared child region when all children agree, or the
     * context's mostly-used data region otherwise. In other modes it is computed by
     * {@link #calculateDataRegionByChildren}.
     */
    private TRegionReplicaSet resolveDataRegion(List<PlanNode> visitedChildren, boolean isChildrenDistributionSame, NodeGroupContext context) {
        if (!context.isAlignByDevice()) {
            return this.calculateDataRegionByChildren(visitedChildren, context);
        }
        return isChildrenDistributionSame
                ? context.getNodeDistribution(visitedChildren.get(0).getPlanNodeId()).getRegion()
                : context.getMostlyUsedDataRegion();
    }

    /**
     * Returns {@code child} unchanged when it is recorded on {@code parentRegion}. Otherwise returns
     * a new exchange node wrapping it; a {@link SingleDeviceViewNode} child is first switched to
     * caching its output column names.
     */
    private PlanNode attachToRegion(PlanNode child, TRegionReplicaSet parentRegion, NodeGroupContext context) {
        if (parentRegion.equals(context.getNodeDistribution(child.getPlanNodeId()).getRegion())) {
            return child;
        }
        if (child instanceof SingleDeviceViewNode) {
            ((SingleDeviceViewNode) child).setCacheOutputColumnNames(true);
        }
        return this.genExchangeNode(context, child);
    }

    /**
     * Virtual-source variant.
     *
     * <p>The first child is attached directly and every later child is wrapped in an exchange node.
     * The children are not visited, and no distribution is recorded for the node. Expects at least
     * one child, since {@code children.get(0)} is read unconditionally.
     */
    private PlanNode processMultiChildNodeByLocation(MultiChildProcessNode node, NodeGroupContext context) {
        MultiChildProcessNode newNode = (MultiChildProcessNode) node.clone();
        List<PlanNode> children = node.getChildren();
        newNode.addChild(children.get(0));
        for (int i = 1; i < children.size(); ++i) {
            newNode.addChild(this.genExchangeNode(context, children.get(i)));
        }
        return newNode;
    }

    /**
     * Splits a TopK across regions.
     *
     * <p>Children are grouped by recorded region under one new per-region {@link TopKNode}, built
     * from the root's top value, merge order and output columns. Each per-region node is recorded as
     * {@code SAME_WITH_ALL_CHILDREN} on its region and attached to {@code newNode}: directly when
     * its region equals {@code dataRegion}, or through an exchange node otherwise.
     */
    private PlanNode processTopKNode(MultiChildProcessNode node, List<PlanNode> visitedChildren, NodeGroupContext context, MultiChildProcessNode newNode, TRegionReplicaSet dataRegion) {
        TopKNode rootNode = (TopKNode) node;
        Map<TRegionReplicaSet, TopKNode> regionTopKNodeMap = new HashMap<>();
        for (PlanNode child : visitedChildren) {
            TRegionReplicaSet region = context.getNodeDistribution(child.getPlanNodeId()).getRegion();
            regionTopKNodeMap.computeIfAbsent(region, k -> {
                TopKNode childTopKNode = new TopKNode(context.queryContext.getQueryId().genPlanNodeId(), rootNode.getTopValue(), rootNode.getMergeOrderParameter(), rootNode.getOutputColumnNames());
                context.putNodeDistribution(childTopKNode.getPlanNodeId(), new NodeDistribution(NodeDistributionType.SAME_WITH_ALL_CHILDREN, region));
                return childTopKNode;
            }).addChild(child);
        }
        for (Map.Entry<TRegionReplicaSet, TopKNode> entry : regionTopKNodeMap.entrySet()) {
            TopKNode topKNode = entry.getValue();
            if (dataRegion.equals(entry.getKey())) {
                newNode.addChild(topKNode);
            } else {
                newNode.addChild(this.genExchangeNode(context, topKNode));
            }
        }
        return newNode;
    }

    /**
     * Creates an exchange node above {@code child}.
     *
     * <p>The new node gets a fresh plan node id and the child's output column names, and
     * {@code context.hasExchangeNode} is set.
     */
    private ExchangeNode genExchangeNode(NodeGroupContext context, PlanNode child) {
        ExchangeNode exchangeNode = new ExchangeNode(context.queryContext.getQueryId().genPlanNodeId());
        exchangeNode.setChild(child);
        exchangeNode.setOutputColumnNames(child.getOutputColumnNames());
        context.hasExchangeNode = true;
        return exchangeNode;
    }

    @Override
    public PlanNode visitSlidingWindowAggregation(SlidingWindowAggregationNode node, NodeGroupContext context) {
        return this.processOneChildNode(node, context);
    }

    /**
     * Rewrites the single child of {@code node} and attaches it to a clone of the node. The clone is
     * recorded as {@code SAME_WITH_ALL_CHILDREN} on the child's region.
     */
    private PlanNode processOneChildNode(PlanNode node, NodeGroupContext context) {
        PlanNode newNode = node.clone();
        PlanNode child = this.visit(node.getChildren().get(0), context);
        newNode.addChild(child);
        TRegionReplicaSet dataRegion = context.getNodeDistribution(child.getPlanNodeId()).getRegion();
        context.putNodeDistribution(newNode.getPlanNodeId(), new NodeDistribution(NodeDistributionType.SAME_WITH_ALL_CHILDREN, dataRegion));
        return newNode;
    }

    /**
     * Picks the region a parent should run on, based on where its children are recorded.
     *
     * <p>A child with no region and type {@code SAME_WITH_ALL_CHILDREN} is counted under its first
     * grandchild's region. The result is chosen in this order:
     *
     * <ol>
     *   <li>If there is exactly one distinct region, that region.
     *   <li>Otherwise the query's main-fragment region, if it is non-null and present among the
     *       children.
     *   <li>Otherwise the context's mostly-used data region, if it is non-null and present.
     *   <li>Otherwise the region holding the most children, ignoring
     *       {@code DataPartition.NOT_ASSIGNED}; ties go to the smaller region id.
     *   <li>With no children at all, {@code DataPartition.NOT_ASSIGNED}.
     * </ol>
     *
     * <p>NOTE(review): {@code Collectors.groupingBy} rejects null keys, so a child whose resolved
     * region is still null makes this method throw NullPointerException.
     */
    private TRegionReplicaSet calculateDataRegionByChildren(List<PlanNode> children, NodeGroupContext context) {
        Map<TRegionReplicaSet, Long> childCountByRegion = children.stream().collect(Collectors.groupingBy(child -> {
            NodeDistribution distribution = context.getNodeDistribution(child.getPlanNodeId());
            TRegionReplicaSet region = distribution.getRegion();
            if (region == null && distribution.getType() == NodeDistributionType.SAME_WITH_ALL_CHILDREN) {
                return this.calculateSchemaRegionByChildren(child.getChildren(), context);
            }
            return region;
        }, Collectors.counting()));
        if (childCountByRegion.size() == 1) {
            return childCountByRegion.keySet().iterator().next();
        }
        TRegionReplicaSet mainFragmentRegion = context.queryContext.getMainFragmentLocatedRegion();
        if (mainFragmentRegion != null && childCountByRegion.containsKey(mainFragmentRegion)) {
            return mainFragmentRegion;
        }
        TRegionReplicaSet mostlyUsedRegion = context.getMostlyUsedDataRegion();
        if (mostlyUsedRegion != null && childCountByRegion.containsKey(mostlyUsedRegion)) {
            return mostlyUsedRegion;
        }
        long maxCount = -1L;
        TRegionReplicaSet result = DataPartition.NOT_ASSIGNED;
        for (Map.Entry<TRegionReplicaSet, Long> entry : childCountByRegion.entrySet()) {
            TRegionReplicaSet region = entry.getKey();
            if (DataPartition.NOT_ASSIGNED.equals(region)) {
                continue;
            }
            long planNodeCount = entry.getValue();
            if (planNodeCount > maxCount) {
                maxCount = planNodeCount;
                result = region;
            } else if (planNodeCount == maxCount && region.getRegionId().getId() < result.getRegionId().getId()) {
                // Deterministic tie-break: prefer the smaller region id.
                result = region;
            }
        }
        return result;
    }

    /** Returns the region recorded for the first child. Expects a non-empty list. */
    private TRegionReplicaSet calculateSchemaRegionByChildren(List<PlanNode> children, NodeGroupContext context) {
        return context.getNodeDistribution(children.get(0).getPlanNodeId()).getRegion();
    }

    /**
     * Reports whether all children share one region.
     *
     * <p>Returns true when every later child is recorded on the same non-null region as the first
     * child. A single child always yields true. Expects a non-empty list.
     */
    private boolean nodeDistributionIsSame(List<PlanNode> children, NodeGroupContext context) {
        NodeDistribution first = context.getNodeDistribution(children.get(0).getPlanNodeId());
        for (int i = 1; i < children.size(); ++i) {
            NodeDistribution next = context.getNodeDistribution(children.get(i).getPlanNodeId());
            if (first.getRegion() == null || !first.getRegion().equals(next.getRegion())) {
                return false;
            }
        }
        return true;
    }

    /** Dispatches {@code node} back into this visitor. */
    public PlanNode visit(PlanNode node, NodeGroupContext context) {
        return node.accept(this, context);
    }
}

