/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.runtime.operators.rank;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.TimeUnit;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.state.MapState;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.state.StateTtlConfig;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.shaded.guava30.com.google.common.cache.Cache;
import org.apache.flink.shaded.guava30.com.google.common.cache.CacheBuilder;
import org.apache.flink.shaded.guava30.com.google.common.cache.RemovalCause;
import org.apache.flink.shaded.guava30.com.google.common.cache.RemovalListener;
import org.apache.flink.shaded.guava30.com.google.common.cache.RemovalNotification;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.runtime.generated.GeneratedRecordComparator;
import org.apache.flink.table.runtime.keyselector.RowDataKeySelector;
import org.apache.flink.table.runtime.operators.rank.AbstractTopNFunction;
import org.apache.flink.table.runtime.operators.rank.RankRange;
import org.apache.flink.table.runtime.operators.rank.RankType;
import org.apache.flink.table.runtime.operators.rank.TopNBuffer;
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class UpdatableTopNFunction
extends AbstractTopNFunction
implements CheckpointedFunction {
    private static final long serialVersionUID = 6786508184355952781L;
    private static final Logger LOG = LoggerFactory.getLogger(UpdatableTopNFunction.class);
    /** Type information of the row key; used as the key type of {@code dataState} in {@code open}. */
    private final InternalTypeInfo<RowData> rowKeyType;
    /** Cache budget; {@code open} divides it by the default Top-N size to size the LRU cache. */
    private final long cacheSize;
    /** Keyed map state: row key -> (row, inner rank). Written by {@code flushBufferToState}. */
    private transient MapState<RowData, Tuple2<RowData, Integer>> dataState;
    /** Sorted buffer (sort key -> row keys) of the partition key currently being processed. */
    private transient TopNBuffer buffer;
    /** Row key -> cached row of the partition key currently being processed. */
    private transient Map<RowData, RankRow> rowKeyMap;
    /** LRU cache: partition key -> (buffer, rowKeyMap) pair for that partition. */
    private transient Cache<RowData, Tuple2<TopNBuffer, Map<RowData, RankRow>>> kvRowKeyMap;
    /** Serializer used to copy input rows before they are stored in {@code rowKeyMap}. */
    private final TypeSerializer<RowData> inputRowSer;
    /** Extracts the row key from an input row. */
    private final KeySelector<RowData, RowData> rowKeySelector;

    /**
     * Creates the function.
     *
     * <p>Everything except {@code rowKeySelector} and {@code cacheSize} is forwarded to the
     * superclass constructor. {@code inputRowType} is additionally used to create the serializer
     * that copies input rows.
     *
     * @param ttlConfig state TTL configuration, forwarded to the superclass
     * @param inputRowType type information of the input rows
     * @param rowKeySelector extracts the row key from an input row
     * @param generatedRecordComparator forwarded to the superclass
     * @param sortKeySelector forwarded to the superclass
     * @param rankType forwarded to the superclass
     * @param rankRange forwarded to the superclass
     * @param generateUpdateBefore forwarded to the superclass
     * @param outputRankNumber forwarded to the superclass
     * @param cacheSize cache budget used in {@code open} to size the LRU cache
     */
    public UpdatableTopNFunction(
            StateTtlConfig ttlConfig,
            InternalTypeInfo<RowData> inputRowType,
            RowDataKeySelector rowKeySelector,
            GeneratedRecordComparator generatedRecordComparator,
            RowDataKeySelector sortKeySelector,
            RankType rankType,
            RankRange rankRange,
            boolean generateUpdateBefore,
            boolean outputRankNumber,
            long cacheSize) {
        super(
                ttlConfig,
                inputRowType,
                generatedRecordComparator,
                sortKeySelector,
                rankType,
                rankRange,
                generateUpdateBefore,
                outputRankNumber);
        this.cacheSize = cacheSize;
        this.rowKeySelector = rowKeySelector;
        this.rowKeyType = rowKeySelector.getProducedType();
        this.inputRowSer = inputRowType.createSerializer(new ExecutionConfig());
    }

    /**
     * Builds the LRU cache of per-partition heap structures and registers the keyed map state.
     *
     * <p>The number of cached partition keys is {@code cacheSize / getDefaultTopNSize()}, with a
     * minimum of one. When TTL is enabled, both the cache entries and the state expire with the
     * configured TTL.
     *
     * @param parameters forwarded to the superclass {@code open}
     * @throws Exception if the superclass {@code open} or state registration fails
     */
    @Override
    public void open(Configuration parameters) throws Exception {
        super.open(parameters);

        // Keep the quotient as a long. Narrowing it to int (as was done before) wraps for large
        // configured cache sizes and can silently shrink the cache to a single entry.
        long lruCacheSize = Math.max(1L, cacheSize / getDefaultTopNSize());

        CacheBuilder<Object, Object> cacheBuilder = CacheBuilder.newBuilder();
        if (ttlConfig.isEnabled()) {
            cacheBuilder.expireAfterWrite(ttlConfig.getTtl().toMilliseconds(), TimeUnit.MILLISECONDS);
        }
        kvRowKeyMap =
                cacheBuilder
                        .maximumSize(lruCacheSize)
                        .removalListener(new CacheRemovalListener())
                        .build();
        LOG.info("Top{} operator is using LRU caches key-size: {}", getDefaultTopNSize(), lruCacheSize);

        // State layout: row key -> (row, inner rank).
        TupleTypeInfo<Tuple2<RowData, Integer>> valueTypeInfo =
                new TupleTypeInfo<>(inputRowType, Types.INT);
        MapStateDescriptor<RowData, Tuple2<RowData, Integer>> mapStateDescriptor =
                new MapStateDescriptor<>("data-state-with-update", rowKeyType, valueTypeInfo);
        if (ttlConfig.isEnabled()) {
            mapStateDescriptor.enableTimeToLive(ttlConfig);
        }
        dataState = getRuntimeContext().getMapState(mapStateDescriptor);

        registerMetric(cacheSize);
    }

    /**
     * Intentionally empty: this function restores nothing here. The keyed state is registered in
     * {@code open}.
     *
     * @param context unused
     */
    @Override
    public void initializeState(FunctionInitializationContext context) throws Exception {
        // Nothing to initialize.
    }

    /**
     * Processes one input row for the current partition key.
     *
     * <p>Loads the partition's heap structures (from cache or state), initializes the rank end,
     * then dispatches to the row-number algorithm when rank numbers are output or an offset is
     * present, and to the algorithm without row numbers otherwise.
     *
     * @param input the incoming row
     * @param context the process-function context (not used here)
     * @param out collector for emitted changes
     */
    public void processElement(RowData input, KeyedProcessFunction.Context context, Collector<RowData> out) throws Exception {
        initHeapStates();
        initRankEnd(input);

        boolean needsRowNumber = outputRankNumber || hasOffset();
        if (!needsRowNumber) {
            processElementWithoutRowNumber(input, out);
            return;
        }
        processElementWithRowNumber(input, out);
    }

    /**
     * Flushes the dirty rows of every cached partition into the keyed state.
     *
     * <p>For each cache entry the current key is switched to that entry's partition key before
     * its rows are written. The current key is left at the last visited partition key.
     *
     * @param context unused
     */
    @Override
    public void snapshotState(FunctionSnapshotContext context) throws Exception {
        for (Map.Entry<RowData, Tuple2<TopNBuffer, Map<RowData, RankRow>>> cached :
                kvRowKeyMap.asMap().entrySet()) {
            RowData cachedPartitionKey = cached.getKey();
            Map<RowData, RankRow> cachedRows = cached.getValue().f1;
            keyContext.setCurrentKey(cachedPartitionKey);
            flushBufferToState(cachedRows);
        }
    }

    /**
     * Points {@code buffer} and {@code rowKeyMap} at the heap structures of the current partition
     * key.
     *
     * <p>On a cache hit the cached pair is reused. On a miss a fresh pair is created, put into
     * the cache, and then rebuilt from {@code dataState}: rows are grouped by sort key and
     * re-inserted into the buffer in ascending inner-rank order. A warning is logged whenever the
     * stored inner rank differs from the size returned by the buffer insert.
     */
    private void initHeapStates() throws Exception {
        requestCount += 1;
        RowData partitionKey = (RowData) keyContext.getCurrentKey();
        Tuple2<TopNBuffer, Map<RowData, RankRow>> cached = kvRowKeyMap.getIfPresent(partitionKey);

        if (cached != null) {
            hitCount += 1;
            buffer = cached.f0;
            rowKeyMap = cached.f1;
            return;
        }

        // Cache miss: register empty structures first, then refill them from state.
        buffer = new TopNBuffer(sortKeyComparator, LinkedHashSet::new);
        rowKeyMap = new HashMap<>();
        kvRowKeyMap.put(partitionKey, new Tuple2<>(buffer, rowKeyMap));

        Iterator<Map.Entry<RowData, Tuple2<RowData, Integer>>> stateIter = dataState.iterator();
        if (stateIter == null) {
            return;
        }

        // sort key -> (inner rank -> row key); the TreeMap restores the inner-rank order.
        Map<RowData, TreeMap<Integer, RowData>> rowKeysBySortKey = new HashMap<>();
        while (stateIter.hasNext()) {
            Map.Entry<RowData, Tuple2<RowData, Integer>> stored = stateIter.next();
            RowData rowKey = stored.getKey();
            Tuple2<RowData, Integer> recordAndInnerRank = stored.getValue();
            RowData record = recordAndInnerRank.f0;
            Integer innerRank = recordAndInnerRank.f1;
            rowKeyMap.put(rowKey, new RankRow(record, innerRank, false));
            RowData sortKey = (RowData) sortKeySelector.getKey(record);
            rowKeysBySortKey.computeIfAbsent(sortKey, k -> new TreeMap<>()).put(innerRank, rowKey);
        }

        for (Map.Entry<RowData, TreeMap<Integer, RowData>> group : rowKeysBySortKey.entrySet()) {
            RowData sortKey = group.getKey();
            TreeMap<Integer, RowData> treeMap = group.getValue();
            for (Map.Entry<Integer, RowData> ranked : treeMap.entrySet()) {
                Integer innerRank = ranked.getKey();
                RowData recordRowKey = ranked.getValue();
                int size = buffer.put(sortKey, recordRowKey);
                if (innerRank != size) {
                    LOG.warn("Failed to build sorted map from state, this may result in wrong result. The sort key is {}, partition key is {}, treeMap is {}. The expected inner rank is {}, but current size is {}.", sortKey, partitionKey, treeMap, innerRank, size);
                }
            }
        }
    }

    /**
     * Handles one input row when row numbers must be tracked.
     *
     * <ul>
     *   <li>Unknown row key: inserted only if its sort key is within the buffer range, then the
     *       affected ranks are emitted.
     *   <li>Known row key, same sort key: the stored row is replaced and an update pair is
     *       emitted at the unchanged rank.
     *   <li>Known row key, different sort key: the row key is moved in the buffer, the inner
     *       ranks under the old sort key are refreshed, and the affected ranks are emitted.
     * </ul>
     *
     * @param inputRow the incoming row
     * @param out collector for emitted changes
     */
    private void processElementWithRowNumber(RowData inputRow, Collector<RowData> out) throws Exception {
        RowData sortKey = (RowData) sortKeySelector.getKey(inputRow);
        RowData rowKey = rowKeySelector.getKey(inputRow);

        if (!rowKeyMap.containsKey(rowKey)) {
            if (checkSortKeyInBufferRange(sortKey, buffer)) {
                int size = buffer.put(sortKey, rowKey);
                rowKeyMap.put(rowKey, new RankRow(inputRowSer.copy(inputRow), size, true));
                emitRecordsWithRowNumber(sortKey, inputRow, out);
            }
            return;
        }

        RankRow oldRow = rowKeyMap.get(rowKey);
        RowData oldSortKey = (RowData) sortKeySelector.getKey(oldRow.row);

        if (!oldSortKey.equals(sortKey)) {
            // The old rank must be read before the row key is moved in the buffer.
            Tuple2<Integer, Integer> oldPosition = rowNumber(oldSortKey, rowKey, buffer);
            int oldRank = oldPosition.f0;
            buffer.remove(oldSortKey, rowKey);
            int size = buffer.put(sortKey, rowKey);
            rowKeyMap.put(rowKey, new RankRow(inputRowSer.copy(inputRow), size, true));
            updateInnerRank(oldSortKey);
            emitRecordsWithRowNumber(sortKey, inputRow, out, oldSortKey, oldRow, oldRank);
            return;
        }

        // Sort key unchanged: only the row content is replaced, at the same position.
        Tuple2<Integer, Integer> position = rowNumber(sortKey, rowKey, buffer);
        int rank = position.f0;
        int innerRank = position.f1;
        rowKeyMap.put(rowKey, new RankRow(inputRowSer.copy(inputRow), innerRank, true));
        collectUpdateBefore(out, oldRow.row, rank);
        collectUpdateAfter(out, inputRow, rank);
    }

    /**
     * Finds the position of a row key inside the buffer.
     *
     * @param sortKey sort key under which the row key is stored
     * @param rowKey row key to locate
     * @param buffer buffer to search, iterated in its entry order
     * @return a tuple of (1-based rank across the whole buffer, 1-based rank within the sort key)
     * @throws RuntimeException if the row key is not found under the sort key
     */
    private Tuple2<Integer, Integer> rowNumber(RowData sortKey, RowData rowKey, TopNBuffer buffer) {
        int curRank = 1;
        for (Map.Entry<RowData, Collection<RowData>> entry : buffer.entrySet()) {
            RowData curKey = entry.getKey();
            Collection<RowData> rowKeys = entry.getValue();
            if (!curKey.equals(sortKey)) {
                // Skip the whole group of a different sort key.
                curRank += rowKeys.size();
                continue;
            }
            int innerRank = 1;
            for (RowData candidate : rowKeys) {
                if (rowKey.equals(candidate)) {
                    return Tuple2.of(curRank, innerRank);
                }
                innerRank++;
                curRank++;
            }
        }
        LOG.error("Failed to find the sortKey: {}, rowkey: {} in the buffer. This should never happen", sortKey, rowKey);
        throw new RuntimeException("Failed to find the sortKey, rowkey in the buffer. This should never happen");
    }

    /**
     * Emits rank changes for a newly inserted row. Delegates to the six-argument overload with no
     * old sort key, no old row and an old rank of {@code -1}.
     *
     * @param sortKey sort key of the inserted row
     * @param inputRow the inserted row
     * @param out collector for emitted changes
     */
    private void emitRecordsWithRowNumber(RowData sortKey, RowData inputRow, Collector<RowData> out) throws Exception {
        RowData noOldSortKey = null;
        RankRow noOldRow = null;
        int noOldRank = -1;
        emitRecordsWithRowNumber(sortKey, inputRow, out, noOldSortKey, noOldRow, noOldRank);
    }

    /**
     * Emits the rank changes caused by inserting or moving {@code inputRow} and prunes buffer
     * entries that fall past the rank end.
     *
     * <p>The buffer is walked in its entry order while a running rank is accumulated. Once the
     * entry for {@code sortKey} has been passed, each following row is emitted as an
     * UPDATE_BEFORE (its stored row) and an UPDATE_AFTER (the row carried from the previous
     * position) at the running rank. If the walk ends while still within the rank end, a final
     * INSERT (new row) or UPDATE_BEFORE/UPDATE_AFTER pair (updated row) is emitted and the method
     * returns. Otherwise every remaining buffer entry is removed from {@code rowKeyMap},
     * {@code dataState} and the buffer.
     *
     * @param sortKey sort key of {@code inputRow}
     * @param inputRow the new or updated row
     * @param out collector for emitted changes
     * @param oldSortKey previous sort key of the row, or {@code null} for a new row
     * @param oldRow previous cached row, or {@code null} for a new row
     * @param oldRank previous rank of the row; the three-argument overload passes {@code -1}
     */
    private void emitRecordsWithRowNumber(RowData sortKey, RowData inputRow, Collector<RowData> out, RowData oldSortKey, RankRow oldRow, int oldRank) throws Exception {
        Collection<RowData> rowKeys;
        Iterator<Map.Entry<RowData, Collection<RowData>>> iterator = this.buffer.entrySet().iterator();
        // Running rank: number of row keys counted so far.
        int currentRank = 0;
        // The row that is emitted at the next rank position; starts as inputRow, then the displaced row.
        RowData currentRow2 = null;
        // Set once the entry for sortKey has been passed.
        boolean findsSortKey = false;
        // Ensures the UPDATE_BEFORE for oldRow is emitted at most once.
        boolean oldRowRetracted = false;
        while (iterator.hasNext() && this.isInRankEnd(currentRank)) {
            Map.Entry<RowData, Collection<RowData>> entry = iterator.next();
            RowData curSortKey = entry.getKey();
            rowKeys = entry.getValue();
            if (!findsSortKey && curSortKey.equals(sortKey)) {
                // The input row's own group: count the whole group and start carrying inputRow.
                // NOTE(review): this treats inputRow as the last member of its group — presumably
                // because callers append its row key via buffer.put beforehand; confirm.
                currentRank += rowKeys.size();
                currentRow2 = inputRow;
                findsSortKey = true;
                continue;
            }
            if (findsSortKey) {
                if (oldSortKey == null) {
                    // New-row path: shift every following row until the rank end is reached.
                    Iterator<RowData> rowKeyIter = rowKeys.iterator();
                    while (rowKeyIter.hasNext() && this.isInRankEnd(currentRank)) {
                        RowData rowKey = rowKeyIter.next();
                        RankRow prevRow = this.rowKeyMap.get(rowKey);
                        this.collectUpdateBefore(out, prevRow.row, currentRank);
                        // NOTE(review): both visible callers pass oldRow == null whenever
                        // oldSortKey == null, so this guard looks unreachable here — confirm.
                        if (currentRow2 == inputRow && oldRow != null && !oldRowRetracted) {
                            this.collectUpdateBefore(out, oldRow.row, oldRank);
                            oldRowRetracted = true;
                        }
                        this.collectUpdateAfter(out, currentRow2, currentRank);
                        currentRow2 = prevRow.row;
                        ++currentRank;
                    }
                    continue;
                }
                // Update path: stop once the current sort key compares greater than the old one.
                int compare = this.sortKeyComparator.compare(curSortKey, oldSortKey);
                if (compare > 0) break;
                Iterator<RowData> rowKeyIter = rowKeys.iterator();
                // Only ranks before the row's old rank are shifted.
                while (rowKeyIter.hasNext() && currentRank < oldRank) {
                    RowData rowKey = rowKeyIter.next();
                    RankRow prevRow = this.rowKeyMap.get(rowKey);
                    this.collectUpdateBefore(out, prevRow.row, currentRank);
                    // Retract the old version of the input row once, before its first UPDATE_AFTER.
                    if (currentRow2 == inputRow && oldRow != null && !oldRowRetracted) {
                        this.collectUpdateBefore(out, oldRow.row, oldRank);
                        oldRowRetracted = true;
                    }
                    this.collectUpdateAfter(out, currentRow2, currentRank);
                    currentRow2 = prevRow.row;
                    ++currentRank;
                }
                continue;
            }
            // Groups before the input row's sort key only advance the running rank.
            currentRank += rowKeys.size();
        }
        if (this.isInRankEnd(currentRank)) {
            if (oldRow == null) {
                // New row and still within the rank end: emit the carried row as an INSERT.
                this.collectInsert(out, currentRow2, currentRank);
            } else {
                // Updated row: the walk is expected to have stopped exactly at the old rank.
                Preconditions.checkArgument((currentRank == oldRank ? 1 : 0) != 0);
                if (!oldRowRetracted) {
                    this.collectUpdateBefore(out, oldRow.row, oldRank);
                }
                this.collectUpdateAfter(out, currentRow2, currentRank);
            }
            // Returning here skips the pruning below.
            return;
        }
        // Past the rank end: drop every entry the iterator has not yet visited.
        ArrayList<RowData> toDeleteSortKeys = new ArrayList<RowData>();
        while (iterator.hasNext()) {
            Map.Entry<RowData, Collection<RowData>> entry = iterator.next();
            rowKeys = entry.getValue();
            for (RowData rowKey : rowKeys) {
                this.rowKeyMap.remove(rowKey);
                this.dataState.remove((Object)rowKey);
            }
            // Collected first and removed afterwards, outside the iteration over the buffer.
            toDeleteSortKeys.add(entry.getKey());
        }
        for (RowData toDeleteKey : toDeleteSortKeys) {
            this.buffer.removeAll(toDeleteKey);
        }
    }

    /**
     * Handles one input row when row numbers are not tracked.
     *
     * <p>A known row key is updated in place (and moved in the buffer if its sort key changed),
     * followed by an UPDATE_BEFORE/UPDATE_AFTER pair. An unknown row key is inserted only if its
     * sort key is within the buffer range. If the buffer then holds more than {@code rankEnd}
     * rows, the last row is removed and emitted as a DELETE before the INSERT.
     *
     * @param inputRow the incoming row
     * @param out collector for emitted changes
     */
    private void processElementWithoutRowNumber(RowData inputRow, Collector<RowData> out) throws Exception {
        RowData sortKey = (RowData) sortKeySelector.getKey(inputRow);
        RowData rowKey = rowKeySelector.getKey(inputRow);

        if (!rowKeyMap.containsKey(rowKey)) {
            if (!checkSortKeyInBufferRange(sortKey, buffer)) {
                return;
            }
            int size = buffer.put(sortKey, rowKey);
            rowKeyMap.put(rowKey, new RankRow(inputRowSer.copy(inputRow), size, true));
            if (buffer.getCurrentTopNum() > rankEnd) {
                // Evict the row that no longer fits.
                RowData lastRowKey = buffer.removeLast();
                if (lastRowKey != null) {
                    RankRow lastRow = rowKeyMap.remove(lastRowKey);
                    dataState.remove(lastRowKey);
                    collectDelete(out, lastRow.row);
                }
            }
            collectInsert(out, inputRow);
            return;
        }

        RankRow oldRow = rowKeyMap.get(rowKey);
        RowData oldSortKey = (RowData) sortKeySelector.getKey(oldRow.row);
        if (oldSortKey.equals(sortKey)) {
            // Same sort key: keep the inner rank, replace the row content.
            rowKeyMap.put(rowKey, new RankRow(inputRowSer.copy(inputRow), oldRow.innerRank, true));
        } else {
            buffer.remove(oldSortKey, rowKey);
            int size = buffer.put(sortKey, rowKey);
            rowKeyMap.put(rowKey, new RankRow(inputRowSer.copy(inputRow), size, true));
            updateInnerRank(oldSortKey);
        }
        collectUpdateBefore(out, oldRow.row);
        collectUpdateAfter(out, inputRow);
    }

    /**
     * Writes every dirty row of the given map into {@code dataState} under the current key and
     * clears its dirty flag. Clean rows are skipped.
     *
     * @param curRowKeyMap row key -> cached row of one partition
     */
    private void flushBufferToState(Map<RowData, RankRow> curRowKeyMap) throws Exception {
        for (Map.Entry<RowData, RankRow> cachedEntry : curRowKeyMap.entrySet()) {
            RowData key = cachedEntry.getKey();
            RankRow cachedRow = cachedEntry.getValue();
            if (cachedRow.dirty) {
                dataState.put(key, Tuple2.of(cachedRow.row, cachedRow.innerRank));
                cachedRow.dirty = false;
            }
        }
    }

    /**
     * Renumbers the inner ranks of the rows stored under the given sort key to 1, 2, 3, ... in
     * the buffer's iteration order. Rows whose inner rank changes are marked dirty. Does nothing
     * if the buffer has no entry for the sort key.
     *
     * @param oldSortKey sort key whose group should be renumbered
     */
    private void updateInnerRank(RowData oldSortKey) {
        Collection<RowData> remainingRowKeys = buffer.get(oldSortKey);
        if (remainingRowKeys == null) {
            return;
        }
        int position = 0;
        for (RowData rowKey : remainingRowKeys) {
            position++;
            RankRow rankRow = rowKeyMap.get(rowKey);
            if (rankRow.innerRank != position) {
                rankRow.innerRank = position;
                rankRow.dirty = true;
            }
        }
    }

    /** A cached row together with its rank inside its sort-key group and a dirty flag. */
    private static class RankRow {
        /** The row content. */
        private final RowData row;
        /** 1-based position of the row among the rows sharing its sort key. */
        private int innerRank;
        /** True while the row has changes that have not been written to state. */
        private boolean dirty;

        private RankRow(RowData rowData, int rankWithinSortKey, boolean needsFlush) {
            row = rowData;
            innerRank = rankWithinSortKey;
            dirty = needsFlush;
        }
    }

    private class CacheRemovalListener
    implements RemovalListener<RowData, Tuple2<TopNBuffer, Map<RowData, RankRow>>> {
        private CacheRemovalListener() {
        }

        public void onRemoval(RemovalNotification<RowData, Tuple2<TopNBuffer, Map<RowData, RankRow>>> notification) {
            if (notification.getCause() != RemovalCause.SIZE) {
                return;
            }
            RowData partitionKey = (RowData)notification.getKey();
            Tuple2 value2 = (Tuple2)notification.getValue();
            if (partitionKey == null || value2 == null) {
                return;
            }
            RowData previousKey = (RowData)UpdatableTopNFunction.this.keyContext.getCurrentKey();
            UpdatableTopNFunction.this.keyContext.setCurrentKey((Object)partitionKey);
            try {
                UpdatableTopNFunction.this.flushBufferToState((Map)value2.f1);
            }
            catch (Throwable e) {
                LOG.error("Fail to synchronize state!", e);
                throw new RuntimeException(e);
            }
            finally {
                UpdatableTopNFunction.this.keyContext.setCurrentKey((Object)previousKey);
            }
        }
    }
}

