/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.calcite.rules;

import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import org.apache.calcite.linq4j.Ord;
import org.apache.calcite.plan.RelOptCost;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.rules.AggregateJoinTransposeRule;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlSplittableAggFunction;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.mapping.Mapping;
import org.apache.calcite.util.mapping.Mappings;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveJoin;

public class HiveAggregateJoinTransposeRule
extends AggregateJoinTransposeRule {
    /**
     * Shared rule instance: matches {@link HiveAggregate} over {@link HiveJoin}, builds
     * replacement nodes with the Hive factories, and allows aggregate functions.
     */
    public static final HiveAggregateJoinTransposeRule INSTANCE =
            new HiveAggregateJoinTransposeRule(
                    HiveAggregate.class,
                    HiveRelFactories.HIVE_AGGREGATE_FACTORY,
                    HiveJoin.class,
                    HiveRelFactories.HIVE_JOIN_FACTORY,
                    HiveRelFactories.HIVE_PROJECT_FACTORY,
                    true);

    /** Creates the aggregates placed below and above the rewritten join. */
    private final RelFactories.AggregateFactory aggregateFactory;

    /** Creates the rewritten join. */
    private final RelFactories.JoinFactory joinFactory;

    /** Stored by the constructor; not read by the code visible in this class. */
    private final RelFactories.ProjectFactory projectFactory;

    /** When false, the rule only fires for aggregates that have no aggregate calls. */
    private final boolean allowFunctions;

    /**
     * Creates the rule and keeps the factories that {@code onMatch} uses to build nodes.
     *
     * @param aggregateClass   aggregate operand class, forwarded to the superclass
     * @param aggregateFactory factory for new aggregates
     * @param joinClass        join operand class, forwarded to the superclass
     * @param joinFactory      factory for the new join
     * @param projectFactory   factory for projections
     * @param allowFunctions   whether aggregates carrying aggregate calls may be rewritten
     */
    private HiveAggregateJoinTransposeRule(
            Class<? extends Aggregate> aggregateClass,
            RelFactories.AggregateFactory aggregateFactory,
            Class<? extends Join> joinClass,
            RelFactories.JoinFactory joinFactory,
            RelFactories.ProjectFactory projectFactory,
            boolean allowFunctions) {
        // NOTE(review): the superclass is always handed a literal true here, not the
        // allowFunctions argument; only this class's own field honours the argument.
        // Confirm this is intentional before changing it.
        super(aggregateClass, aggregateFactory, joinClass, joinFactory, projectFactory, true);
        this.allowFunctions = allowFunctions;
        this.projectFactory = projectFactory;
        this.joinFactory = joinFactory;
        this.aggregateFactory = aggregateFactory;
    }

    /*
     * Unable to fully structure code
     */
    public void onMatch(RelOptRuleCall call) {
        block20: {
            aggregate = (Aggregate)call.rel(0);
            join = (Join)call.rel(1);
            rexBuilder = aggregate.getCluster().getRexBuilder();
            for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
                if (aggregateCall.getAggregation().unwrap(SqlSplittableAggFunction.class) == null) {
                    return;
                }
                if (aggregateCall.filterArg < 0) continue;
                return;
            }
            if (join.getJoinType() != JoinRelType.INNER) {
                return;
            }
            if (!this.allowFunctions && !aggregate.getAggCallList().isEmpty()) {
                return;
            }
            mq = RelMetadataQuery.instance();
            aggregateColumns = aggregate.getGroupSet();
            keyColumns = HiveAggregateJoinTransposeRule.keyColumns(aggregateColumns, (ImmutableList<RexNode>)mq.getPulledUpPredicates((RelNode)join).pulledUpPredicates);
            joinColumns = RelOptUtil.InputFinder.bits((RexNode)join.getCondition());
            allColumnsInAggregate = keyColumns.contains(joinColumns);
            belowAggregateColumns = aggregateColumns.union(joinColumns);
            leftKeys = Lists.newArrayList();
            rightKeys = Lists.newArrayList();
            filterNulls = Lists.newArrayList();
            nonEquiConj = RelOptUtil.splitJoinCondition((RelNode)join.getLeft(), (RelNode)join.getRight(), (RexNode)join.getCondition(), (List)leftKeys, (List)rightKeys, (List)filterNulls);
            if (!nonEquiConj.isAlwaysTrue()) {
                return;
            }
            map = new HashMap<Object, Integer>();
            sides = new ArrayList<Side>();
            uniqueCount = 0;
            offset = 0;
            belowOffset = 0;
            for (s = 0; s < 2; ++s) {
                side = new Side();
                joinInput = join.getInput(s);
                fieldCount = joinInput.getRowType().getFieldCount();
                fieldSet = ImmutableBitSet.range((int)offset, (int)(offset + fieldCount));
                belowAggregateKeyNotShifted = belowAggregateColumns.intersect(fieldSet);
                for (Ord c : Ord.zip((Iterable)belowAggregateKeyNotShifted)) {
                    map.put(c.e, belowOffset + c.i);
                }
                belowAggregateKey = belowAggregateKeyNotShifted.shift(-offset);
                if (!this.allowFunctions) {
                    if (!HiveAggregateJoinTransposeRule.$assertionsDisabled && !aggregate.getAggCallList().isEmpty()) {
                        throw new AssertionError();
                    }
                    unique = true;
                } else {
                    unique0 = mq.areColumnsUnique(joinInput, belowAggregateKey);
                    v0 = unique = unique0 != null && unique0 != false;
                }
                if (unique) {
                    ++uniqueCount;
                    side.newInput = joinInput;
                } else {
                    belowAggCalls = new ArrayList<E>();
                    belowAggCallRegistry = HiveAggregateJoinTransposeRule.registry(belowAggCalls);
                    mapping = s == 0 ? Mappings.createIdentity((int)fieldCount) : Mappings.createShiftMapping((int)(fieldCount + offset), (int[])new int[]{0, offset, fieldCount});
                    for (Ord aggCall : Ord.zip((List)aggregate.getAggCallList())) {
                        aggregation = ((AggregateCall)aggCall.e).getAggregation();
                        splitter = (SqlSplittableAggFunction)Preconditions.checkNotNull((Object)aggregation.unwrap(SqlSplittableAggFunction.class));
                        call1 = fieldSet.contains(ImmutableBitSet.of((Iterable)((AggregateCall)aggCall.e).getArgList())) != false ? splitter.split((AggregateCall)aggCall.e, (Mappings.TargetMapping)mapping) : splitter.other(rexBuilder.getTypeFactory(), (AggregateCall)aggCall.e);
                        if (call1 == null) continue;
                        side.split.put(aggCall.i, belowAggregateKey.cardinality() + belowAggCallRegistry.register((Object)call1));
                    }
                    side.newInput = this.aggregateFactory.createAggregate(joinInput, false, belowAggregateKey, null, belowAggCalls);
                }
                offset += fieldCount;
                belowOffset += side.newInput.getRowType().getFieldCount();
                sides.add(side);
            }
            if (uniqueCount == 2) {
                return;
            }
            mapping = (Mapping)Mappings.target((Function)new Function<Integer, Integer>(){

                public Integer apply(Integer a0) {
                    return (Integer)map.get(a0);
                }
            }, (int)join.getRowType().getFieldCount(), (int)belowOffset);
            newCondition = RexUtil.apply((Mappings.TargetMapping)mapping, (RexNode)join.getCondition());
            newJoin = this.joinFactory.createJoin(((Side)sides.get((int)0)).newInput, ((Side)sides.get((int)1)).newInput, newCondition, join.getJoinType(), join.getVariablesStopped(), join.isSemiJoinDone());
            newAggCalls = new ArrayList<AggregateCall>();
            groupIndicatorCount = aggregate.getGroupCount() + aggregate.getIndicatorCount();
            newLeftWidth = ((Side)sides.get((int)0)).newInput.getRowType().getFieldCount();
            projects = new ArrayList<E>(rexBuilder.identityProjects(newJoin.getRowType()));
            for (Ord aggCall : Ord.zip((List)aggregate.getAggCallList())) {
                aggregation = ((AggregateCall)aggCall.e).getAggregation();
                splitter = (SqlSplittableAggFunction)Preconditions.checkNotNull((Object)aggregation.unwrap(SqlSplittableAggFunction.class));
                leftSubTotal = ((Side)sides.get((int)0)).split.get(aggCall.i);
                rightSubTotal = ((Side)sides.get((int)1)).split.get(aggCall.i);
                newAggCalls.add(splitter.topSplit(rexBuilder, HiveAggregateJoinTransposeRule.registry(projects), groupIndicatorCount, newJoin.getRowType(), (AggregateCall)aggCall.e, leftSubTotal == null ? -1 : leftSubTotal, rightSubTotal == null ? -1 : rightSubTotal + newLeftWidth));
            }
            r = newJoin;
            if (allColumnsInAggregate && newAggCalls.isEmpty() && RelOptUtil.areRowTypesEqual((RelDataType)r.getRowType(), (RelDataType)aggregate.getRowType(), (boolean)false)) break block20;
            r = RelOptUtil.createProject((RelNode)r, projects, null, (boolean)true, (RelBuilder)this.relBuilderFactory.create(aggregate.getCluster(), null));
            if (!allColumnsInAggregate) ** GOTO lbl-1000
            projects2 = new ArrayList<Object>();
            aggregation = Mappings.apply((Mapping)mapping, (ImmutableBitSet)aggregate.getGroupSet()).iterator();
            while (aggregation.hasNext()) {
                key = (Integer)aggregation.next();
                projects2.add(rexBuilder.makeInputRef(r, key));
            }
            for (AggregateCall newAggCall : newAggCalls) {
                splitter = (SqlSplittableAggFunction)newAggCall.getAggregation().unwrap(SqlSplittableAggFunction.class);
                if (splitter == null) continue;
                projects2.add(splitter.singleton(rexBuilder, r.getRowType(), newAggCall));
            }
            if (projects2.size() == aggregate.getGroupSet().cardinality() + newAggCalls.size()) {
                r = RelOptUtil.createProject((RelNode)r, projects2, null, (boolean)true, (RelBuilder)this.relBuilderFactory.create(aggregate.getCluster(), null));
            } else lbl-1000:
            // 2 sources

            {
                r = this.aggregateFactory.createAggregate(r, aggregate.indicator, Mappings.apply((Mapping)mapping, (ImmutableBitSet)aggregate.getGroupSet()), Mappings.apply2((Mapping)mapping, (Iterable)aggregate.getGroupSets()), newAggCalls);
            }
        }
        afterCost = mq.getCumulativeCost(r);
        beforeCost = mq.getCumulativeCost((RelNode)aggregate);
        if (afterCost.isLt(beforeCost)) {
            call.transformTo(r);
        }
    }

    /**
     * Extends the group columns with every column that one of the given predicates equates
     * directly to a group column.
     *
     * @param aggregateColumns group-by columns of the aggregate
     * @param predicates       predicates scanned for column-to-column equalities
     * @return {@code aggregateColumns} plus the columns directly equated to any of them
     */
    private static ImmutableBitSet keyColumns(ImmutableBitSet aggregateColumns, ImmutableList<RexNode> predicates) {
        final Map<Integer, BitSet> equalColumns = new TreeMap<Integer, BitSet>();
        for (RexNode predicate : predicates) {
            populateEquivalences(equalColumns, predicate);
        }

        ImmutableBitSet result = aggregateColumns;
        for (Integer column : aggregateColumns) {
            final BitSet peers = equalColumns.get(column);
            if (peers != null) {
                result = result.union(peers);
            }
        }
        return result;
    }

    /**
     * Records, in both directions, the equality expressed by {@code predicate} when it is an
     * {@code EQUALS} call whose first two operands are both input references. Any other
     * predicate leaves {@code equivalence} untouched.
     *
     * @param equivalence column index to the set of column indexes known equal to it
     * @param predicate   predicate to inspect
     */
    private static void populateEquivalences(Map<Integer, BitSet> equivalence, RexNode predicate) {
        switch (predicate.getKind()) {
            case EQUALS: {
                final List<RexNode> args = ((RexCall) predicate).getOperands();
                if (args.get(0) instanceof RexInputRef && args.get(1) instanceof RexInputRef) {
                    final int first = ((RexInputRef) args.get(0)).getIndex();
                    final int second = ((RexInputRef) args.get(1)).getIndex();
                    populateEquivalence(equivalence, first, second);
                    populateEquivalence(equivalence, second, first);
                }
                break;
            }
            default:
                break;
        }
    }

    /**
     * Adds {@code i1} to the set stored under {@code i0}, creating and registering that set
     * first when {@code i0} has no entry yet.
     *
     * @param equivalence column index to the set of column indexes known equal to it
     * @param i0          key column
     * @param i1          column to record as equal to {@code i0}
     */
    private static void populateEquivalence(Map<Integer, BitSet> equivalence, int i0, int i1) {
        final BitSet existing = equivalence.get(i0);
        if (existing != null) {
            existing.set(i1);
            return;
        }
        final BitSet created = new BitSet();
        equivalence.put(i0, created);
        created.set(i1);
    }

    /**
     * Wraps {@code list} as a registry that hands out list positions.
     *
     * <p>Registering an element already in the list (per {@code indexOf}) returns its
     * existing position; otherwise the element is appended and its new position returned.
     *
     * @param list backing list, mutated by the returned registry
     * @param <E>  element type
     * @return registry backed by {@code list}
     */
    private static <E> SqlSplittableAggFunction.Registry<E> registry(final List<E> list) {
        return new SqlSplittableAggFunction.Registry<E>() {
            @Override
            public int register(E e) {
                final int known = list.indexOf(e);
                if (known >= 0) {
                    return known;
                }
                final int appendedAt = list.size();
                list.add(e);
                return appendedAt;
            }
        };
    }

    /** Per-join-input working state collected while the rule rewrites the plan. */
    private static class Side {
        /** Replacement for this join input: the original input or an aggregate over it. */
        RelNode newInput;

        /** Ordinal of an original aggregate call to the field holding its sub-total on this side. */
        final Map<Integer, Integer> split;

        private Side() {
            this.split = new HashMap<Integer, Integer>();
        }
    }
}

