/*
 * Decompiled with CFR 0.152.
 */
package org.apache.doris.load.loadv2.dpp;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.stream.Collectors;
import org.apache.commons.collections.map.MultiValueMap;
import org.apache.commons.lang3.StringUtils;
import org.apache.spark.sql.AnalysisException;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalog.Column;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Builds Hive-backed global dictionaries for Doris dictionary columns by issuing Spark SQL.
 *
 * <p>Each public step reads the Hive table written by the previous one:
 * {@link #createHiveIntermediateTable()} fills {@code dorisIntermediateHiveTable},
 * {@link #extractDistinctColumn()} reads it into {@code distinctKeyTableName},
 * {@link #buildGlobalDict()} reads that into {@code globalDictTableName}, and
 * {@link #encodeDorisIntermediateHiveTable()} joins the dictionary back onto the intermediate
 * table. NOTE(review): the calling order itself is enforced by callers not visible here.
 *
 * <p>Per-column work in {@link #extractDistinctColumn()} and {@link #buildGlobalDict()} runs
 * concurrently on an internal fixed-size thread pool. NOTE(review): this class never shuts that
 * pool down; confirm the owner's lifecycle handles it.
 */
public class GlobalDictBuilder {
    protected static final Logger LOG = LoggerFactory.getLogger(GlobalDictBuilder.class);

    /**
     * Dictionary column -> columns that are encoded with that column's dictionary
     * (see {@link #getEncodeDorisIntermediateHiveTableSql}).
     */
    private final MultiValueMap dictColumn;
    private final List<String> dorisOlapTableColumnList;
    /** Dictionary columns whose encode join gets a {@code BROADCAST} hint. */
    private final List<String> mapSideJoinColumns;
    private final String sourceHiveDBTableName;
    /** Optional {@code where} clause applied when copying the source table; may be empty. */
    private final String sourceHiveFilter;
    private final String distinctKeyTableName;
    private final String globalDictTableName;
    private final String dorisIntermediateHiveTable;
    private final SparkSession spark;
    /** Lower-cased Doris column name -> lower-cased Hive column type. */
    private final Map<String, String> dorisColumnNameTypeMap = new HashMap<String, String>();
    /** Columns whose new distinct keys are encoded in several random splits. */
    private final List<String> veryHighCardinalityColumn;
    private final int veryHighCardinalityColumnSplitNum;
    private final ExecutorService pool;
    private StructType distinctValueSchema;

    /**
     * Creates a builder and switches the Spark session to {@code dorisHiveDB}.
     *
     * @param buildConcurrency number of pool threads; values {@code <= 0} are treated as 1
     */
    public GlobalDictBuilder(MultiValueMap dictColumn, List<String> dorisOlapTableColumnList,
            List<String> mapSideJoinColumns, String sourceHiveDBTableName, String sourceHiveFilter,
            String dorisHiveDB, String distinctKeyTableName, String globalDictTableName,
            String dorisIntermediateHiveTable, int buildConcurrency,
            List<String> veryHighCardinalityColumn, int veryHighCardinalityColumnSplitNum,
            SparkSession spark) {
        this.dictColumn = dictColumn;
        this.dorisOlapTableColumnList = dorisOlapTableColumnList;
        this.mapSideJoinColumns = mapSideJoinColumns;
        this.sourceHiveDBTableName = sourceHiveDBTableName;
        this.sourceHiveFilter = sourceHiveFilter;
        this.distinctKeyTableName = distinctKeyTableName;
        this.globalDictTableName = globalDictTableName;
        this.dorisIntermediateHiveTable = dorisIntermediateHiveTable;
        this.spark = spark;
        // Executors.newFixedThreadPool throws for nThreads <= 0, so clamp 0 as well as negatives.
        this.pool = Executors.newFixedThreadPool(Math.max(1, buildConcurrency));
        this.veryHighCardinalityColumn = veryHighCardinalityColumn;
        this.veryHighCardinalityColumnSplitNum = veryHighCardinalityColumnSplitNum;
        spark.sql("use " + dorisHiveDB);
    }

    /**
     * Records the Hive type of every Doris column, then (re)creates the intermediate Hive table and
     * fills it from the source table (optionally filtered).
     *
     * @throws AnalysisException if the source table's columns cannot be listed
     * @throws RuntimeException if a Doris column does not exist in the source Hive table
     */
    public void createHiveIntermediateTable() throws AnalysisException {
        Map<String, String> sourceHiveTableColumn = this.spark.catalog()
                .listColumns(this.sourceHiveDBTableName)
                .collectAsList()
                .stream()
                .collect(Collectors.toMap(Column::name, Column::dataType));
        // Compare names case-insensitively by normalizing both sides to lower case.
        Map<String, String> sourceHiveTableColumnInLowercase = new HashMap<String, String>();
        for (Map.Entry<String, String> entry : sourceHiveTableColumn.entrySet()) {
            sourceHiveTableColumnInLowercase.put(toLower(entry.getKey()), toLower(entry.getValue()));
        }
        for (String dorisColumn : this.dorisOlapTableColumnList) {
            String columnName = toLower(dorisColumn);
            String columnType = sourceHiveTableColumnInLowercase.get(columnName);
            if (StringUtils.isEmpty(columnType)) {
                throw new RuntimeException(String.format("doris column %s not in source hive table", columnName));
            }
            this.dorisColumnNameTypeMap.put(columnName, columnType);
        }
        this.spark.sql(String.format("drop table if exists %s ", this.dorisIntermediateHiveTable));
        this.spark.sql(getCreateIntermediateHiveTableSql());
        this.spark.sql(getInsertIntermediateHiveTableSql());
    }

    /**
     * Writes the distinct values of every dictionary column into its own partition of the
     * distinct-key table, one concurrent worker per column.
     */
    public void extractDistinctColumn() {
        this.spark.sql(getCreateDistinctKeyTableSql());
        List<GlobalDictBuildWorker> workerList = new ArrayList<GlobalDictBuildWorker>();
        for (Object column : this.dictColumn.keySet()) {
            String columnName = column.toString();
            workerList.add(() -> this.spark.sql(
                    getInsertDistinctKeyTableSql(columnName, this.dorisIntermediateHiveTable)));
        }
        submitWorker(workerList);
    }

    /**
     * Appends dictionary entries for keys not yet present in the global dictionary, one concurrent
     * worker per dictionary column.
     */
    public void buildGlobalDict() throws ExecutionException, InterruptedException {
        this.spark.sql(getCreateGlobalDictHiveTableSql());
        List<GlobalDictBuildWorker> globalDictBuildWorkers = new ArrayList<GlobalDictBuildWorker>();
        for (Object distinctColumnNameOrigin : this.dictColumn.keySet()) {
            String distinctColumnName = distinctColumnNameOrigin.toString();
            globalDictBuildWorkers.add(() -> buildGlobalDictForColumn(distinctColumnName));
        }
        submitWorker(globalDictBuildWorkers);
    }

    /**
     * Rewrites the intermediate table once per dictionary column, replacing that column and its
     * child columns with dictionary values.
     */
    @SuppressWarnings("unchecked")
    public void encodeDorisIntermediateHiveTable() {
        for (Object distinctColumnObj : this.dictColumn.keySet()) {
            String distinctColumn = distinctColumnObj.toString();
            List<String> childColumns = (List<String>) this.dictColumn.get(distinctColumn);
            this.spark.sql(getEncodeDorisIntermediateHiveTableSql(distinctColumn, childColumns));
        }
    }

    /** Locale-independent lower-casing, so results don't depend on the JVM default locale. */
    private static String toLower(String value) {
        return value.toLowerCase(Locale.ROOT);
    }

    /**
     * Builds and runs the dictionary extension for one column. New keys are numbered after the
     * current maximum dictionary value.
     */
    private void buildGlobalDictForColumn(String distinctColumnName) {
        List<Row> maxGlobalDictValueRow =
                this.spark.sql(getMaxGlobalDictValueSql(distinctColumnName)).collectAsList();
        if (maxGlobalDictValueRow.isEmpty()) {
            throw new RuntimeException(String.format("get max dict value failed: %s", distinctColumnName));
        }
        long maxDictValue = 0L;
        long minDictValue = 0L;
        Row row = maxGlobalDictValueRow.get(0);
        // A null max presumably means the column has no dictionary entries yet; start numbering at 0.
        if (row != null && row.get(0) != null) {
            maxDictValue = (Long) row.get(0);
            minDictValue = (Long) row.get(1);
        }
        LOG.info(" column {} 's max value in dict is {}, min value is {}",
                distinctColumnName, maxDictValue, minDictValue);
        // Values are row_number() offsets from the previous max; a negative one means bigint wrapped.
        if (minDictValue < 0L) {
            throw new RuntimeException(String.format(" column %s 's cardinality has exceed bigint's max value", distinctColumnName));
        }
        boolean splitBuild = this.veryHighCardinalityColumn != null
                && this.veryHighCardinalityColumn.contains(distinctColumnName)
                && this.veryHighCardinalityColumnSplitNum > 1;
        if (splitBuild) {
            buildGlobalDictBySplit(maxDictValue, distinctColumnName);
        } else {
            this.spark.sql(getBuildGlobalDictSql(maxDictValue, distinctColumnName));
        }
    }

    /**
     * DDL for the intermediate table. Dictionary columns and their children are declared
     * {@code string}; every other column keeps its source Hive type.
     */
    private String getCreateIntermediateHiveTableSql() {
        StringBuilder sql = new StringBuilder();
        sql.append("create table if not exists " + this.dorisIntermediateHiveTable + " ( ");
        Set<Object> allDictColumn = new HashSet<Object>();
        allDictColumn.addAll(this.dictColumn.keySet());
        allDictColumn.addAll(this.dictColumn.values());
        for (String columnName : this.dorisOlapTableColumnList) {
            sql.append(columnName).append(" ");
            if (allDictColumn.contains(columnName)) {
                sql.append(" string ,");
            } else {
                // The type map is keyed by lower-cased names (see createHiveIntermediateTable).
                sql.append(this.dorisColumnNameTypeMap.get(toLower(columnName))).append(" ,");
            }
        }
        // Drop the trailing comma.
        return sql.deleteCharAt(sql.length() - 1).append(" )").append(" stored as sequencefile ").toString();
    }

    /** Copies the Doris columns from the source table into the intermediate table. */
    private String getInsertIntermediateHiveTableSql() {
        StringBuilder sql = new StringBuilder();
        sql.append("insert overwrite table ").append(this.dorisIntermediateHiveTable).append(" select ");
        for (String columnName : this.dorisOlapTableColumnList) {
            sql.append(columnName).append(" ,");
        }
        sql.deleteCharAt(sql.length() - 1).append(" from ").append(this.sourceHiveDBTableName);
        if (!StringUtils.isEmpty(this.sourceHiveFilter)) {
            sql.append(" where ").append(this.sourceHiveFilter);
        }
        return sql.toString();
    }

    private String getCreateDistinctKeyTableSql() {
        return "create table if not exists " + this.distinctKeyTableName + "(dict_key string) partitioned by (dict_column string) stored as sequencefile ";
    }

    /** Overwrites the {@code dict_column} partition with the grouped values of that column. */
    private String getInsertDistinctKeyTableSql(String distinctColumnName, String sourceHiveTable) {
        StringBuilder sql = new StringBuilder();
        sql.append("insert overwrite table ").append(this.distinctKeyTableName)
                .append(" partition(dict_column='").append(distinctColumnName).append("')")
                .append(" select ").append(distinctColumnName)
                .append(" from ").append(sourceHiveTable)
                .append(" group by ").append(distinctColumnName);
        return sql.toString();
    }

    private String getCreateGlobalDictHiveTableSql() {
        return "create table if not exists " + this.globalDictTableName + "(dict_key string, dict_value bigint) partitioned by(dict_column string) stored as sequencefile ";
    }

    private String getMaxGlobalDictValueSql(String distinctColumnName) {
        return "select max(dict_value) as max_value,min(dict_value) as min_value from " + this.globalDictTableName + " where dict_column='" + distinctColumnName + "'";
    }

    /**
     * Encodes a very-high-cardinality column's new keys in random splits. Each split gets a
     * contiguous value range that starts after the previous split's range. The per-split temp views
     * are dropped once the insert has run.
     *
     * <p>NOTE(review): the ranges assume {@code count()} of each split matches the rows later read
     * through {@code toJavaRDD()}. That relies on {@code randomSplit} being deterministic across
     * re-evaluation of the same lineage; confirm against the Spark version in use.
     */
    private void buildGlobalDictBySplit(long maxGlobalDictValue, String distinctColumnName) {
        Dataset<Row> newDistinctValue = this.spark.sql(getNewDistinctValue(distinctColumnName));
        Dataset<Row>[] splitedDistinctValue = newDistinctValue.randomSplit(getRandomSplitWeights());
        long currentMaxDictValue = maxGlobalDictValue;
        Map<String, Long> distinctKeyMap = new HashMap<String, Long>();
        try {
            for (int i = 0; i < splitedDistinctValue.length; ++i) {
                long currentDatasetStartDictValue = currentMaxDictValue;
                long splitDistinctValueCount = splitedDistinctValue[i].count();
                currentMaxDictValue += splitDistinctValueCount;
                String tmpDictTableName = String.format("%s_%s_tmp_dict_%s", i, currentDatasetStartDictValue, distinctColumnName);
                distinctKeyMap.put(tmpDictTableName, currentDatasetStartDictValue);
                Dataset<Row> distinctValueFrame = this.spark.createDataFrame(
                        splitedDistinctValue[i].toJavaRDD(), getDistinctValueSchema());
                distinctValueFrame.createOrReplaceTempView(tmpDictTableName);
            }
            this.spark.sql(getSplitBuildGlobalDictSql(distinctKeyMap, distinctColumnName));
        } finally {
            // Temp views otherwise stay registered for the lifetime of the SparkSession.
            for (String tmpDictTableName : distinctKeyMap.keySet()) {
                this.spark.catalog().dropTempView(tmpDictTableName);
            }
        }
    }

    /**
     * Rewrites the column's dictionary partition as its existing entries plus one
     * {@code union all} branch per split view. Each branch is numbered from that view's start value.
     */
    private String getSplitBuildGlobalDictSql(Map<String, Long> distinctKeyMap, String distinctColumnName) {
        StringBuilder sql = new StringBuilder();
        sql.append("insert overwrite table ").append(this.globalDictTableName)
                .append(" partition(dict_column='").append(distinctColumnName).append("') ")
                .append(" select dict_key,dict_value from ").append(this.globalDictTableName)
                .append(" where dict_column='").append(distinctColumnName).append("' ");
        for (Map.Entry<String, Long> entry : distinctKeyMap.entrySet()) {
            sql.append(" union all select dict_key, (row_number() over(order by dict_key)) ")
                    .append(String.format(" +(%s) as dict_value from %s", entry.getValue(), entry.getKey()));
        }
        return sql.toString();
    }

    /**
     * Single non-nullable {@code dict_key string} schema for the split views.
     *
     * <p>May be reached concurrently from pool threads (via buildGlobalDictBySplit). The
     * unsynchronized lazy init can at worst build an equivalent schema more than once.
     */
    private StructType getDistinctValueSchema() {
        if (this.distinctValueSchema == null) {
            List<StructField> fieldList = new ArrayList<StructField>();
            fieldList.add(DataTypes.createStructField("dict_key", (DataType) DataTypes.StringType, false));
            this.distinctValueSchema = DataTypes.createStructType(fieldList);
        }
        return this.distinctValueSchema;
    }

    /** Equal weights summing to 1, one per split. */
    private double[] getRandomSplitWeights() {
        double[] weights = new double[this.veryHighCardinalityColumnSplitNum];
        Arrays.fill(weights, 1.0 / this.veryHighCardinalityColumnSplitNum);
        return weights;
    }

    /**
     * Rewrites the column's dictionary partition as its existing entries plus every new distinct
     * key, numbered {@code maxGlobalDictValue + row_number()}.
     */
    private String getBuildGlobalDictSql(long maxGlobalDictValue, String distinctColumnName) {
        return "insert overwrite table " + this.globalDictTableName
                + " partition(dict_column='" + distinctColumnName + "') "
                + " select dict_key,dict_value from " + this.globalDictTableName
                + " where dict_column='" + distinctColumnName + "' "
                + " union all select t1.dict_key as dict_key,"
                + "(row_number() over(order by t1.dict_key)) + (" + maxGlobalDictValue + ") as dict_value"
                + " from (select dict_key from " + this.distinctKeyTableName
                + " where dict_column='" + distinctColumnName + "' and dict_key is not null)t1"
                + " left join "
                + " (select dict_key,dict_value from " + this.globalDictTableName
                + " where dict_column='" + distinctColumnName + "' )t2"
                + " on t1.dict_key = t2.dict_key where t2.dict_value is null";
    }

    /** Selects non-null distinct keys of the column that have no dictionary value yet. */
    private String getNewDistinctValue(String distinctColumnName) {
        return "select t1.dict_key from "
                + " (select dict_key from " + this.distinctKeyTableName
                + " where dict_column='" + distinctColumnName + "' and dict_key is not null)t1"
                + " left join "
                + " (select dict_key,dict_value from " + this.globalDictTableName
                + " where dict_column='" + distinctColumnName + "' )t2"
                + " on t1.dict_key = t2.dict_key where t2.dict_value is null";
    }

    /**
     * Rewrites the intermediate table with {@code dictColumn} replaced by its dictionary value.
     * Each child column is replaced by the same dictionary value, or kept null if it was null.
     *
     * @param childColumn columns sharing {@code dictColumn}'s dictionary; may be {@code null}
     */
    private String getEncodeDorisIntermediateHiveTableSql(String dictColumn, List<String> childColumn) {
        StringBuilder sql = new StringBuilder();
        sql.append("insert overwrite table ").append(this.dorisIntermediateHiveTable).append(" select ");
        if (this.mapSideJoinColumns.contains(dictColumn)) {
            sql.append(" /*+ BROADCAST (t) */ ");
        }
        for (String columnName : this.dorisOlapTableColumnList) {
            if (dictColumn.equals(columnName)) {
                sql.append("t.dict_value").append(" ,");
            } else if (childColumn != null && childColumn.contains(columnName)) {
                sql.append(String.format(" if(%s is null, null, t.dict_value) ", columnName)).append(" ,");
            } else {
                sql.append(this.dorisIntermediateHiveTable).append(".").append(columnName).append(" ,");
            }
        }
        sql.deleteCharAt(sql.length() - 1)
                .append(" from ").append(this.dorisIntermediateHiveTable)
                .append(" LEFT OUTER JOIN ( select dict_key,dict_value from ").append(this.globalDictTableName)
                .append(" where dict_column='").append(dictColumn).append("' ) t on ")
                .append(this.dorisIntermediateHiveTable).append(".").append(dictColumn)
                .append(" = t.dict_key ");
        return sql.toString();
    }

    /**
     * Runs all workers on the pool and waits for them.
     *
     * @throws RuntimeException if any worker failed or waiting was interrupted (interrupt status
     *     is restored in that case)
     */
    private void submitWorker(List<GlobalDictBuildWorker> workerList) {
        try {
            List<Future<Boolean>> futureList = new ArrayList<Future<Boolean>>();
            for (GlobalDictBuildWorker worker : workerList) {
                futureList.add(this.pool.submit(() -> {
                    try {
                        worker.work();
                        return Boolean.TRUE;
                    } catch (Exception e) {
                        LOG.error("BuildGlobalDict failed", e);
                        return Boolean.FALSE;
                    }
                }));
            }
            LOG.info("begin to fetch worker result");
            for (Future<Boolean> future : futureList) {
                if (!future.get()) {
                    throw new RuntimeException("detect one worker failed");
                }
            }
            LOG.info("fetch worker result complete");
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            LOG.error("submit worker failed", e);
            throw new RuntimeException("submit worker failed", e);
        } catch (Exception e) {
            LOG.error("submit worker failed", e);
            throw new RuntimeException("submit worker failed", e);
        }
    }

    /** One unit of per-column work submitted to the pool. */
    @FunctionalInterface
    private interface GlobalDictBuildWorker {
        void work();
    }
}

