/*
 * Decompiled with CFR 0.152.
 */
package org.apache.celeborn.client.write;

import java.io.IOException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.celeborn.client.write.DataBatches;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.network.client.RpcResponseCallback;
import org.apache.celeborn.common.protocol.PartitionLocation;
import org.apache.celeborn.common.protocol.message.StatusCode;
import org.apache.celeborn.shaded.io.netty.channel.ChannelFuture;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Per-writer push bookkeeping: tracks in-flight batches (with the channel future and callback of
 * the push that carries them), buffers not-yet-pushed batch data per address pair, and holds a
 * slot for a push failure.
 *
 * <p>NOTE(review): the use of {@link ConcurrentHashMap} and a {@code synchronized} timeout sweep
 * suggests this object is shared between the writing thread and at least one other thread (e.g.
 * a timeout checker or network callbacks) — confirm against callers.
 */
public class PushState {
    private static final Logger logger = LoggerFactory.getLogger(PushState.class);

    /** Buffered bytes per address pair above which {@link #addBatchData} reports "flush now". */
    private final int pushBufferMaxSize;

    /** Push timeout in milliseconds (compared against {@link System#currentTimeMillis()}). */
    private final long pushTimeout;

    /** Shared batch-id counter; this class itself never reads or updates it. */
    public final AtomicInteger batchId = new AtomicInteger();

    /** In-flight batches keyed by batch id. */
    private final ConcurrentHashMap<Integer, BatchInfo> inflightBatchInfos = new ConcurrentHashMap<>();

    /** Holder for a push failure; not written by this class. */
    public AtomicReference<IOException> exception = new AtomicReference<>();

    /** Buffered, not-yet-pushed batch data keyed by address pair. */
    public final ConcurrentHashMap<String, DataBatches> batchesMap = new ConcurrentHashMap<>();

    /**
     * Creates push state using the buffer-size and push-timeout settings of {@code conf}.
     *
     * @param conf source of {@code pushBufferMaxSize()} and {@code pushDataTimeoutMs()}
     */
    public PushState(CelebornConf conf) {
        this.pushBufferMaxSize = conf.pushBufferMaxSize();
        this.pushTimeout = conf.pushDataTimeoutMs();
    }

    /**
     * Registers {@code batchId} as in flight. Does nothing if it is already registered.
     *
     * @param batchId id of the batch being sent
     */
    public void addBatch(int batchId) {
        this.inflightBatchInfos.computeIfAbsent(batchId, id -> new BatchInfo());
    }

    /**
     * Stops tracking {@code batchId} and cancels its channel future, if one was recorded.
     *
     * @param batchId id of the batch to drop
     */
    public void removeBatch(int batchId) {
        BatchInfo info = this.inflightBatchInfos.remove(batchId);
        if (info != null) {
            cancelFuture(info.channelFuture);
        }
    }

    /** Returns the number of batches currently tracked as in flight. */
    public int inflightBatchCount() {
        return this.inflightBatchInfos.size();
    }

    /**
     * Fails every in-flight batch whose push started more than {@code pushTimeout} ms ago and
     * still has a callback attached. For each such batch, it cancels the channel future and
     * invokes {@code onFailure} with an {@link IOException} carrying the
     * {@code PUSH_DATA_TIMEOUT} message. The batch entry itself stays registered; only its
     * future and callback are detached, so each callback is notified at most once.
     */
    public synchronized void failExpiredBatch() {
        long currentTime = System.currentTimeMillis();
        this.inflightBatchInfos.values().forEach(info -> {
            long pushTime = info.pushTime;
            if (pushTime == -1L || currentTime - pushTime <= this.pushTimeout) {
                return;
            }
            RpcResponseCallback callback = info.callback;
            if (callback == null) {
                return;
            }
            ChannelFuture future = info.channelFuture;
            // Detach before notifying. If onFailure throws, or re-enters this object (e.g. via
            // removeBatch), the batch cannot be reported as timed out a second time.
            info.channelFuture = null;
            info.callback = null;
            cancelFuture(future);
            callback.onFailure(new IOException(StatusCode.PUSH_DATA_TIMEOUT.getMessage()));
        });
    }

    /**
     * Records that the push carrying {@code batchId} has been started. Ignored if the batch is
     * not registered via {@link #addBatch}.
     *
     * @param batchId id of the batch being pushed
     * @param future channel future of the write; cancelled on timeout, removal or cleanup
     * @param callback notified with a timeout failure by {@link #failExpiredBatch}
     */
    public void pushStarted(int batchId, ChannelFuture future, RpcResponseCallback callback) {
        BatchInfo info = this.inflightBatchInfos.get(batchId);
        if (info != null) {
            info.channelFuture = future;
            info.callback = callback;
            // Written last: a reader that observes the new pushTime also observes the
            // future/callback written above (volatile happens-before).
            info.pushTime = System.currentTimeMillis();
        }
    }

    /**
     * Cancels the channel futures of all in-flight batches and stops tracking them.
     */
    public void cleanup() {
        if (this.inflightBatchInfos.isEmpty()) {
            return;
        }
        logger.debug("Cancel all {} futures.", this.inflightBatchInfos.size());
        // Remove-and-cancel per entry instead of cancel-all-then-clear(). A batch registered
        // concurrently is then either cancelled here or left tracked, and never silently
        // dropped while its future is still live.
        this.inflightBatchInfos.keySet().forEach(this::removeBatch);
    }

    /**
     * Buffers {@code body} for {@code batchId} under {@code addressPair}.
     *
     * @param addressPair key grouping batches that are pushed together
     * @param loc partition location the data belongs to
     * @param batchId id of the batch
     * @param body serialized batch bytes
     * @return {@code true} when the buffered total for {@code addressPair} now exceeds
     *     {@code pushBufferMaxSize}
     */
    public boolean addBatchData(String addressPair, PartitionLocation loc, int batchId, byte[] body) {
        DataBatches batches = this.batchesMap.computeIfAbsent(addressPair, s -> new DataBatches());
        batches.addDataBatch(loc, batchId, body);
        return batches.getTotalSize() > this.pushBufferMaxSize;
    }

    /**
     * Removes and returns the buffered batches for {@code addressPair}.
     *
     * @return the buffered batches, or {@code null} if none are buffered for that key
     */
    public DataBatches takeDataBatches(String addressPair) {
        return this.batchesMap.remove(addressPair);
    }

    /** Cancels {@code future} (interrupting if running) when it is non-null. */
    private static void cancelFuture(ChannelFuture future) {
        if (future != null) {
            future.cancel(true);
        }
    }

    /**
     * Bookkeeping for one in-flight batch. Fields are volatile because {@link #pushStarted}
     * writes them without holding the lock that {@link #failExpiredBatch} holds, and the two
     * may run on different threads.
     */
    static class BatchInfo {
        volatile ChannelFuture channelFuture;
        /** Push start time in ms, or -1 if the push has not started yet. */
        volatile long pushTime = -1L;
        volatile RpcResponseCallback callback;

        BatchInfo() {
        }
    }
}

