/*
 * Decompiled with CFR 0.152.
 */
package org.apache.doris.planner;

import com.google.common.base.Joiner;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import org.apache.doris.analysis.Analyzer;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.FunctionCallExpr;
import org.apache.doris.analysis.SelectStmt;
import org.apache.doris.analysis.SlotDescriptor;
import org.apache.doris.analysis.SlotRef;
import org.apache.doris.analysis.TableRef;
import org.apache.doris.analysis.TupleDescriptor;
import org.apache.doris.analysis.TupleId;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.KeysType;
import org.apache.doris.catalog.MaterializedIndexMeta;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.common.UserException;
import org.apache.doris.planner.OlapScanNode;
import org.apache.doris.planner.ScanNode;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.rewrite.mvrewrite.MVExprEquivalent;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Selects the best index (base index, rollup or materialized view) of an OLAP table for one scan
 * node of a {@link SelectStmt}.
 *
 * <p>Selection runs in two phases:
 * <ol>
 *   <li>{@link #predicates(OlapScanNode)} prunes the visible indexes of the scanned table to those
 *       whose columns cover the query's predicates, grouping columns, aggregate functions and
 *       output columns.</li>
 *   <li>{@link #priorities(OlapScanNode, Map)} ranks the survivors: longest prefix-index match
 *       first, then the smallest total row count over the selected partitions, then the fewest
 *       columns.</li>
 * </ol>
 *
 * <p>The per-table column usage of the statement is collected once in the constructor. The
 * pre-aggregation outputs are reset at the start of every {@link #selectBestMV(ScanNode)} call.
 */
public class MaterializedViewSelector {
    private static final Logger LOG = LogManager.getLogger(MaterializedViewSelector.class);

    private final SelectStmt selectStmt;
    private final Analyzer analyzer;

    // Per table id: column names referenced by the WHERE clause and by every ON clause.
    private final Map<Long, Set<String>> columnNamesInPredicates = Maps.newHashMap();
    // True when the statement has no aggregate info (selectStmt.getAggInfo() == null).
    private boolean isSPJQuery;
    // Per table id: column names referenced by the GROUP BY expressions.
    private final Map<Long, Set<String>> columnNamesInGrouping = Maps.newHashMap();
    // Per table id: aggregate calls whose referenced columns all belong to that single table.
    private final Map<Long, Set<FunctionCallExpr>> aggColumnsInQuery = Maps.newHashMap();
    // Per table id: column names collected from the tuple descriptors of the statement's table refs
    // (inline views excluded).
    private final Map<Long, Set<String>> columnNamesInQueryOutput = Maps.newHashMap();
    // Set by init() when an aggregate expression does not reference exactly one table; indexes that
    // carry aggregated columns are then rejected.
    private boolean disableSPJGView;
    // Per-call outputs copied into BestIndexInfo; compensateCandidateIndex() switches
    // pre-aggregation off and records the reason.
    private boolean isPreAggregation = true;
    private String reasonOfDisable;

    /**
     * Creates a selector for {@code selectStmt} and eagerly collects the per-table column usage
     * (predicates, grouping, aggregates, outputs) that every later selection relies on.
     *
     * @param selectStmt the analysed statement whose scan nodes will be resolved
     * @param analyzer   the analyzer used to look up tuple descriptors and collect scan columns
     */
    public MaterializedViewSelector(SelectStmt selectStmt, Analyzer analyzer) {
        this.selectStmt = selectStmt;
        this.analyzer = analyzer;
        init();
    }

    /**
     * Chooses the best index of the table scanned by {@code scanNode}.
     *
     * @param scanNode the scan node to resolve; must be an {@link OlapScanNode}
     * @return the chosen index id plus the pre-aggregation decision, or {@code null} when no visible
     *         index of the table passes the candidate checks
     * @throws UserException propagated from the aggregate-expression equivalence check
     */
    public BestIndexInfo selectBestMV(ScanNode scanNode) throws UserException {
        resetPreAggregationVariables();
        long start = System.currentTimeMillis();
        Preconditions.checkState(scanNode instanceof OlapScanNode,
                "materialized view selection requires an OlapScanNode");
        OlapScanNode olapScanNode = (OlapScanNode) scanNode;
        Map<Long, List<Column>> candidateIndexIdToSchema = predicates(olapScanNode);
        if (candidateIndexIdToSchema.isEmpty()) {
            return null;
        }
        long bestIndexId = priorities(olapScanNode, candidateIndexIdToSchema);
        if (LOG.isDebugEnabled()) {
            // toSql() re-renders the whole statement, so only pay for it when the line is logged.
            LOG.debug("The best materialized view is {} for scan node {} in query {}, "
                            + "isPreAggregation: {}, reasonOfDisable: {}, cost {}",
                    bestIndexId, scanNode.getId(), selectStmt.toSql(), isPreAggregation,
                    reasonOfDisable, System.currentTimeMillis() - start);
        }
        return new BestIndexInfo(bestIndexId, isPreAggregation, reasonOfDisable);
    }

    /** Restores the per-call pre-aggregation outputs to "enabled, no reason". */
    private void resetPreAggregationVariables() {
        isPreAggregation = true;
        // NOTE(review): this also discards the reason init() may have recorded for disableSPJGView.
        reasonOfDisable = null;
    }

    /**
     * Phase 1: keeps only the visible indexes that can answer the query on the scanned table.
     *
     * @return candidate index id to index schema; empty when nothing qualifies
     */
    private Map<Long, List<Column>> predicates(OlapScanNode scanNode) throws AnalysisException {
        OlapTable table = scanNode.getOlapTable();
        // Check before the first dereference; previously the table was used before this check.
        Preconditions.checkState(table != null, "scan node has no olap table");
        long tableId = table.getId();
        // The checks below remove entries from this map, so prune a private copy rather than a map
        // handed out by the table.
        Map<Long, MaterializedIndexMeta> candidateIndexIdToMeta =
                new HashMap<>(table.getVisibleIndexIdToMeta());
        // Step1: columns of the compensating predicates must be non-aggregated index columns.
        checkCompensatingPredicates(columnNamesInPredicates.get(tableId), candidateIndexIdToMeta);
        // Step2: grouping columns of the query must be non-aggregated index columns.
        checkGrouping(columnNamesInGrouping.get(tableId), candidateIndexIdToMeta);
        // Step3: aggregate functions of the query must be available among the index's aggregates.
        checkAggregationFunction(aggColumnsInQuery.get(tableId), candidateIndexIdToMeta);
        // Step4: every output column must exist in the index.
        checkOutputColumns(columnNamesInQueryOutput.get(tableId), candidateIndexIdToMeta);
        // Step5: an aggregate/unique table may still be scanned without pre-aggregation, so fall
        // back to the indexes that keep the full key of the base index.
        KeysType keysType = table.getKeysType();
        if ((keysType == KeysType.AGG_KEYS || keysType == KeysType.UNIQUE_KEYS)
                && candidateIndexIdToMeta.isEmpty()) {
            compensateCandidateIndex(candidateIndexIdToMeta, table.getVisibleIndexIdToMeta(), table);
            checkOutputColumns(columnNamesInQueryOutput.get(tableId), candidateIndexIdToMeta);
        }
        Map<Long, List<Column>> result = Maps.newHashMap();
        for (Map.Entry<Long, MaterializedIndexMeta> entry : candidateIndexIdToMeta.entrySet()) {
            result.put(entry.getKey(), entry.getValue().getSchema());
        }
        return result;
    }

    /**
     * Phase 2: ranks the candidates and returns the winning index id.
     *
     * <p>The segment-v2 rollup is returned directly when the session enables {@code use_v2_rollup}
     * and it is a candidate. Otherwise, without that session flag, it is removed from
     * {@code candidateIndexIdToSchema}, which is modified in place.
     */
    private long priorities(OlapScanNode scanNode, Map<Long, List<Column>> candidateIndexIdToSchema) {
        OlapTable tbl = scanNode.getOlapTable();
        Long v2RollupIndexId = tbl.getSegmentV2FormatIndexId();
        if (v2RollupIndexId != null) {
            ConnectContext connectContext = ConnectContext.get();
            if (connectContext != null && connectContext.getSessionVariable().isUseV2Rollup()) {
                if (candidateIndexIdToSchema.containsKey(v2RollupIndexId)) {
                    return v2RollupIndexId;
                }
            } else {
                candidateIndexIdToSchema.remove(v2RollupIndexId);
            }
        }
        // NOTE(review): if the removal above empties the map, selectBestRowCountIndex returns 0.
        // Step1: candidates with the longest prefix-index match.
        Set<String> equivalenceColumns = Sets.newHashSet();
        Set<String> unequivalenceColumns = Sets.newHashSet();
        scanNode.collectColumns(analyzer, equivalenceColumns, unequivalenceColumns);
        Set<Long> indexesMatchingBestPrefixIndex =
                matchBestPrefixIndex(candidateIndexIdToSchema, equivalenceColumns, unequivalenceColumns);
        // Step2: among those, the one with the fewest rows over the selected partitions.
        return selectBestRowCountIndex(indexesMatchingBestPrefixIndex, tbl, scanNode.getSelectedPartitionIds());
    }

    /**
     * Returns the candidates whose leading schema columns match the most predicate columns.
     *
     * <p>Equivalence columns extend the match column by column. A single unequivalence column may
     * end the match (counting one), and any other column stops it. With no predicate columns at
     * all, every candidate id is returned (a live view of the map's key set).
     */
    private Set<Long> matchBestPrefixIndex(Map<Long, List<Column>> candidateIndexIdToSchema,
            Set<String> equivalenceColumns, Set<String> unequivalenceColumns) {
        if (equivalenceColumns.isEmpty() && unequivalenceColumns.isEmpty()) {
            return candidateIndexIdToSchema.keySet();
        }
        Set<Long> indexesMatchingBestPrefixIndex = Sets.newHashSet();
        int maxPrefixMatchCount = 0;
        for (Map.Entry<Long, List<Column>> entry : candidateIndexIdToSchema.entrySet()) {
            long indexId = entry.getKey();
            int prefixMatchCount = 0;
            for (Column col : entry.getValue()) {
                if (equivalenceColumns.contains(col.getName())) {
                    prefixMatchCount++;
                } else {
                    if (unequivalenceColumns.contains(col.getName())) {
                        // A range predicate can only extend the prefix by this one column.
                        prefixMatchCount++;
                    }
                    break;
                }
            }
            if (prefixMatchCount == maxPrefixMatchCount) {
                LOG.debug("find a equal prefix match index {}. match count: {}", indexId, prefixMatchCount);
                indexesMatchingBestPrefixIndex.add(indexId);
            } else if (prefixMatchCount > maxPrefixMatchCount) {
                LOG.debug("find a better prefix match index {}. match count: {}", indexId, prefixMatchCount);
                maxPrefixMatchCount = prefixMatchCount;
                indexesMatchingBestPrefixIndex.clear();
                indexesMatchingBestPrefixIndex.add(indexId);
            }
        }
        logCandidates("Those mv match the best prefix index:", indexesMatchingBestPrefixIndex);
        return indexesMatchingBestPrefixIndex;
    }

    /**
     * Picks the index with the smallest total row count over {@code partitionIds}. Ties go to the
     * index with fewer schema columns, and after that to the first one iterated.
     *
     * @return the chosen index id, or 0 when {@code indexesMatchingBestPrefixIndex} is empty
     */
    private long selectBestRowCountIndex(Set<Long> indexesMatchingBestPrefixIndex, OlapTable olapTable,
            Collection<Long> partitionIds) {
        long minRowCount = Long.MAX_VALUE;
        long selectedIndexId = 0L;
        for (Long indexId : indexesMatchingBestPrefixIndex) {
            long rowCount = 0L;
            for (Long partitionId : partitionIds) {
                rowCount += olapTable.getPartition(partitionId).getIndex(indexId).getRowCount();
            }
            LOG.debug("rowCount={} for table={}", rowCount, indexId);
            if (rowCount < minRowCount) {
                minRowCount = rowCount;
                selectedIndexId = indexId;
            } else if (rowCount == minRowCount) {
                int selectedColumnSize = olapTable.getSchemaByIndexId(selectedIndexId).size();
                int currColumnSize = olapTable.getSchemaByIndexId(indexId).size();
                if (currColumnSize < selectedColumnSize) {
                    selectedIndexId = indexId;
                }
            }
        }
        return selectedIndexId;
    }

    /**
     * Removes candidates whose non-aggregated columns (compared case-insensitively) do not cover
     * every column used by the WHERE/ON predicates. No-op when the table has no predicate columns.
     */
    private void checkCompensatingPredicates(Set<String> columnsInPredicates,
            Map<Long, MaterializedIndexMeta> candidateIndexIdToMeta) {
        if (columnsInPredicates == null) {
            return;
        }
        candidateIndexIdToMeta.entrySet().removeIf(entry ->
                !nonAggregatedColumnNames(entry.getValue().getSchema()).containsAll(columnsInPredicates));
        logCandidates("Those mv pass the test of compensating predicates:", candidateIndexIdToMeta.keySet());
    }

    /** Removes candidates that fail {@link #indexPassesGrouping(Set, MaterializedIndexMeta)}. */
    private void checkGrouping(Set<String> columnsInGrouping,
            Map<Long, MaterializedIndexMeta> candidateIndexIdToMeta) {
        candidateIndexIdToMeta.entrySet().removeIf(entry -> !indexPassesGrouping(columnsInGrouping, entry.getValue()));
        logCandidates("Those mv pass the test of grouping:", candidateIndexIdToMeta.keySet());
    }

    /**
     * Grouping test for a single index.
     *
     * @param columnsInGrouping GROUP BY column names of this table, or {@code null} if none
     * @return {@code true} if the index may stay a candidate
     */
    private boolean indexPassesGrouping(Set<String> columnsInGrouping, MaterializedIndexMeta indexMeta) {
        List<Column> schema = indexMeta.getSchema();
        Set<String> indexNonAggregatedColumnNames = nonAggregatedColumnNames(schema);
        // A duplicate-key index without aggregated columns passes regardless of the query's grouping.
        if (indexNonAggregatedColumnNames.size() == schema.size()
                && indexMeta.getKeysType() == KeysType.DUP_KEYS) {
            return true;
        }
        // An index that aggregates cannot serve a query without aggregation, nor one whose
        // aggregates could not be attributed to a single table.
        if (isSPJQuery || disableSPJGView) {
            return false;
        }
        // Otherwise the query's grouping columns must be a subset of the index's non-aggregated ones.
        return columnsInGrouping == null || indexNonAggregatedColumnNames.containsAll(columnsInGrouping);
    }

    /**
     * Removes candidates whose aggregated columns cannot provide every aggregate call of the query
     * on this table. Duplicate-key indexes without aggregated columns always pass. Indexes with
     * aggregated columns are removed outright for non-aggregate queries or when
     * {@code disableSPJGView} is set.
     *
     * @throws AnalysisException propagated from {@link MVExprEquivalent#mvExprEqual}
     */
    private void checkAggregationFunction(Set<FunctionCallExpr> aggregatedColumnsInQueryOutput,
            Map<Long, MaterializedIndexMeta> candidateIndexIdToMeta) throws AnalysisException {
        // An explicit iterator rather than removeIf because the match check throws a checked exception.
        Iterator<Map.Entry<Long, MaterializedIndexMeta>> iterator = candidateIndexIdToMeta.entrySet().iterator();
        while (iterator.hasNext()) {
            MaterializedIndexMeta candidateIndexMeta = iterator.next().getValue();
            List<FunctionCallExpr> indexAggColumnExprs = mvAggColumnsToExprList(candidateIndexMeta);
            if (indexAggColumnExprs.isEmpty() && candidateIndexMeta.getKeysType() == KeysType.DUP_KEYS) {
                continue;
            }
            if (isSPJQuery || disableSPJGView) {
                iterator.remove();
            } else if (aggregatedColumnsInQueryOutput != null
                    && !aggFunctionsMatchAggColumns(aggregatedColumnsInQueryOutput, indexAggColumnExprs)) {
                iterator.remove();
            }
        }
        logCandidates("Those mv pass the test of aggregation function:", candidateIndexIdToMeta.keySet());
    }

    /**
     * Removes candidates whose columns (compared case-insensitively) do not include every output
     * column of the query. No-op when there are no output columns for the table.
     */
    private void checkOutputColumns(Set<String> columnNamesInQueryOutput,
            Map<Long, MaterializedIndexMeta> candidateIndexIdToMeta) {
        if (columnNamesInQueryOutput == null) {
            return;
        }
        candidateIndexIdToMeta.entrySet().removeIf(entry -> {
            Set<String> indexColumnNames = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
            for (Column column : entry.getValue().getSchema()) {
                indexColumnNames.add(column.getName());
            }
            return !indexColumnNames.containsAll(columnNamesInQueryOutput);
        });
        logCandidates("Those mv pass the test of output columns:", candidateIndexIdToMeta.keySet());
    }

    /**
     * Fallback for aggregate/unique tables with no candidate left. It disables pre-aggregation and
     * re-adds every visible index whose key-column count equals that of the base index.
     */
    private void compensateCandidateIndex(Map<Long, MaterializedIndexMeta> candidateIndexIdToMeta,
            Map<Long, MaterializedIndexMeta> allVisibleIndexes, OlapTable table) {
        isPreAggregation = false;
        reasonOfDisable = "The aggregate operator does not match";
        int keySizeOfBaseIndex = table.getKeyColumnsByIndexId(table.getBaseIndexId()).size();
        for (Map.Entry<Long, MaterializedIndexMeta> index : allVisibleIndexes.entrySet()) {
            long mvIndexId = index.getKey();
            if (table.getKeyColumnsByIndexId(mvIndexId).size() == keySizeOfBaseIndex) {
                candidateIndexIdToMeta.put(mvIndexId, index.getValue());
            }
        }
        // The previous message wrongly claimed these indexes passed the output-column test.
        logCandidates("Those mv are compensating candidates of the aggregate table:",
                candidateIndexIdToMeta.keySet());
    }

    /**
     * Collects, per table id, the columns the statement uses in predicates, grouping, aggregate
     * calls and outputs. Called once from the constructor.
     */
    private void init() {
        // Step1: compensating predicates = WHERE clause plus every ON clause.
        Expr whereClause = selectStmt.getWhereClause();
        if (whereClause != null) {
            whereClause.getTableIdToColumnNames(columnNamesInPredicates);
        }
        for (TableRef tableRef : selectStmt.getTableRefs()) {
            Expr onClause = tableRef.getOnClause();
            if (onClause != null) {
                onClause.getTableIdToColumnNames(columnNamesInPredicates);
            }
        }
        // Step2: grouping columns and aggregate calls, when the statement aggregates at all.
        if (selectStmt.getAggInfo() == null) {
            isSPJQuery = true;
        } else {
            List<Expr> groupingExprs = selectStmt.getAggInfo().getGroupingExprs();
            if (groupingExprs != null) {
                for (Expr expr : groupingExprs) {
                    expr.getTableIdToColumnNames(columnNamesInGrouping);
                }
            }
            for (FunctionCallExpr aggExpr : selectStmt.getAggInfo().getAggregateExprs()) {
                Map<Long, Set<String>> tableIdToAggColumnNames = Maps.newHashMap();
                aggExpr.getTableIdToColumnNames(tableIdToAggColumnNames);
                // An aggregate that references no table (e.g. no column) or several tables cannot be
                // mapped to one index, so aggregating indexes are ruled out entirely.
                if (tableIdToAggColumnNames.size() != 1) {
                    reasonOfDisable = "aggExpr[" + aggExpr.debugString() + "] should involved only one column";
                    disableSPJGView = true;
                    break;
                }
                addAggColumnInQuery(tableIdToAggColumnNames.keySet().iterator().next(), aggExpr);
            }
        }
        // Step3: output columns of the non-inline-view table refs.
        List<TupleId> tupleIds = selectStmt.getTableRefIdsWithoutInlineView();
        for (TupleId tupleId : tupleIds) {
            TupleDescriptor tupleDescriptor = analyzer.getTupleDesc(tupleId);
            tupleDescriptor.getTableIdToColumnNames(columnNamesInQueryOutput);
        }
    }

    /** Records {@code fnExpr} under {@code tableId}, creating the per-table set on first use. */
    private void addAggColumnInQuery(Long tableId, FunctionCallExpr fnExpr) {
        aggColumnsInQuery.computeIfAbsent(tableId, id -> Sets.newHashSet()).add(fnExpr);
    }

    /**
     * Returns {@code true} when every aggregate call of the query is equivalent, per
     * {@link MVExprEquivalent#mvExprEqual}, to at least one aggregate column of the index.
     */
    private boolean aggFunctionsMatchAggColumns(Set<FunctionCallExpr> queryExprList,
            List<FunctionCallExpr> mvColumnExprList) throws AnalysisException {
        for (Expr queryExpr : queryExprList) {
            boolean match = false;
            for (Expr mvExpr : mvColumnExprList) {
                if (MVExprEquivalent.mvExprEqual(queryExpr, mvExpr)) {
                    match = true;
                    break;
                }
            }
            if (!match) {
                return false;
            }
        }
        return true;
    }

    /**
     * Builds one {@code AGG_TYPE(column)} call per aggregated column of the index. The function
     * name is the column's aggregation type name, and the argument is a slot ref bound to the
     * column.
     */
    private List<FunctionCallExpr> mvAggColumnsToExprList(MaterializedIndexMeta mvMeta) {
        List<FunctionCallExpr> result = Lists.newArrayList();
        for (Column column : mvMeta.getSchema()) {
            if (!column.isAggregated()) {
                continue;
            }
            SlotRef slotRef = new SlotRef(null, column.getName());
            SlotDescriptor slotDescriptor = new SlotDescriptor(null, null);
            slotDescriptor.setColumn(column);
            slotRef.setDesc(slotDescriptor);
            // Typed as List<Expr>: the decompiled ArrayList<Object> cast did not compile.
            List<Expr> params = new ArrayList<>();
            params.add(slotRef);
            result.add(new FunctionCallExpr(column.getAggregationType().name(), params));
        }
        return result;
    }

    /** Names of the columns in {@code schema} that are not aggregated, compared case-insensitively. */
    private static Set<String> nonAggregatedColumnNames(List<Column> schema) {
        Set<String> names = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
        for (Column column : schema) {
            if (!column.isAggregated()) {
                names.add(column.getName());
            }
        }
        return names;
    }

    /** Logs {@code prefix} followed by the comma-joined ids, building the string only at debug level. */
    private static void logCandidates(String prefix, Collection<Long> indexIds) {
        if (LOG.isDebugEnabled()) {
            LOG.debug(prefix + Joiner.on(",").join(indexIds));
        }
    }

    /**
     * Result of {@link MaterializedViewSelector#selectBestMV(ScanNode)}: the chosen index id and the
     * pre-aggregation decision made while choosing it.
     */
    public class BestIndexInfo {
        private final long bestIndexId;
        private final boolean isPreAggregation;
        private final String reasonOfDisable;

        public BestIndexInfo(long bestIndexId, boolean isPreAggregation, String reasonOfDisable) {
            this.bestIndexId = bestIndexId;
            this.isPreAggregation = isPreAggregation;
            this.reasonOfDisable = reasonOfDisable;
        }

        /** Id of the selected index. */
        public long getBestIndexId() {
            return bestIndexId;
        }

        /** {@code false} when the selector had to fall back to compensating candidate indexes. */
        public boolean isPreAggregation() {
            return isPreAggregation;
        }

        /** Why pre-aggregation was disabled, or {@code null} when it was not. */
        public String getReasonOfDisable() {
            return reasonOfDisable;
        }
    }
}

