/*
 * Decompiled with CFR 0.152.
 */
package org.apache.celeborn.client;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.client.ShuffleClientHelper;
import org.apache.celeborn.client.compress.Compressor;
import org.apache.celeborn.client.read.RssInputStream;
import org.apache.celeborn.client.write.DataBatches;
import org.apache.celeborn.client.write.PushState;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.haclient.RssHARetryClient;
import org.apache.celeborn.common.identity.UserIdentifier;
import org.apache.celeborn.common.network.TransportContext;
import org.apache.celeborn.common.network.buffer.NettyManagedBuffer;
import org.apache.celeborn.common.network.client.RpcResponseCallback;
import org.apache.celeborn.common.network.client.TransportClient;
import org.apache.celeborn.common.network.client.TransportClientFactory;
import org.apache.celeborn.common.network.protocol.PushData;
import org.apache.celeborn.common.network.protocol.PushDataHandShake;
import org.apache.celeborn.common.network.protocol.PushMergedData;
import org.apache.celeborn.common.network.protocol.RegionFinish;
import org.apache.celeborn.common.network.protocol.RegionStart;
import org.apache.celeborn.common.network.server.BaseMessageHandler;
import org.apache.celeborn.common.network.util.TransportConf;
import org.apache.celeborn.common.protocol.PartitionLocation;
import org.apache.celeborn.common.protocol.PbChangeLocationResponse;
import org.apache.celeborn.common.protocol.PbRegisterShuffleResponse;
import org.apache.celeborn.common.protocol.RpcNameConstants;
import org.apache.celeborn.common.protocol.message.ControlMessages;
import org.apache.celeborn.common.protocol.message.ControlMessages$PartitionSplit$;
import org.apache.celeborn.common.protocol.message.ControlMessages$RegisterMapPartitionTask$;
import org.apache.celeborn.common.protocol.message.ControlMessages$RegisterShuffle$;
import org.apache.celeborn.common.protocol.message.ControlMessages$Revive$;
import org.apache.celeborn.common.protocol.message.ControlMessages$UnregisterShuffle$;
import org.apache.celeborn.common.protocol.message.StatusCode;
import org.apache.celeborn.common.rpc.RpcAddress;
import org.apache.celeborn.common.rpc.RpcEndpointRef;
import org.apache.celeborn.common.rpc.RpcEnv;
import org.apache.celeborn.common.unsafe.Platform;
import org.apache.celeborn.common.util.PackedPartitionId;
import org.apache.celeborn.common.util.PbSerDeUtils;
import org.apache.celeborn.common.util.ThreadUtils;
import org.apache.celeborn.common.util.Utils;
import org.apache.celeborn.shaded.com.google.common.annotations.VisibleForTesting;
import org.apache.celeborn.shaded.com.google.common.util.concurrent.Uninterruptibles;
import org.apache.celeborn.shaded.io.netty.buffer.CompositeByteBuf;
import org.apache.celeborn.shaded.io.netty.buffer.Unpooled;
import org.apache.celeborn.shaded.io.netty.channel.ChannelFuture;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.reflect.ClassTag$;

public class ShuffleClientImpl
extends ShuffleClient {
    private static final Logger logger = LoggerFactory.getLogger(ShuffleClientImpl.class);
    /** Partition-mode byte stamped on every PushData / PushMergedData request built by this class. */
    private static final byte MASTER_MODE = PartitionLocation.Mode.MASTER.mode();
    /** Shared random source (revive jitter, choosing which merged buffer to flush next). */
    private static final Random RND = new Random();
    private final CelebornConf conf;
    private final UserIdentifier userIdentifier;
    /** Maximum number of attempts made by registerShuffleInternal. */
    private final int registerShuffleMaxRetries;
    /** Wait between register attempts, in milliseconds. */
    private final long registerShuffleRetryWaitMs;
    /** In-flight batch limit per map attempt enforced through limitMaxInFlight. */
    private final int maxInFlight;
    /** Size bound passed to DataBatches.requireBatches when flushing merged data. */
    private final int pushBufferMaxSize;
    private final RpcEnv rpcEnv;
    /** Driver-side meta service endpoint; register, revive and split requests are sent here. */
    private RpcEndpointRef driverRssMetaService;
    protected TransportClientFactory dataClientFactory;
    /** shuffleId -> (partitionId -> current PartitionLocation); entries are replaced on revive. */
    private final Map<Integer, ConcurrentHashMap<Integer, PartitionLocation>> reducePartitionMap = new ConcurrentHashMap<Integer, ConcurrentHashMap<Integer, PartitionLocation>>();
    /** shuffleId -> map keys (Utils.makeMapKey) of map attempts reported as ended. */
    private final ConcurrentHashMap<Integer, Set<String>> mapperEndMap = new ConcurrentHashMap<Integer, Set<String>>();
    /** mapKey -> push bookkeeping (in-flight batches, first failure, pending merged batches). */
    private final Map<String, PushState> pushStates = new ConcurrentHashMap<String, PushState>();
    /** Runs revive-and-resend work after failed or hard-split pushes. */
    private final ExecutorService pushDataRetryPool;
    /** Executor handed to ShuffleClientHelper for asynchronous partition split requests. */
    private final ExecutorService partitionSplitPool;
    /** shuffleId -> partition ids that currently have a split request outstanding. */
    private final Map<Integer, Set<Integer>> splitting = new ConcurrentHashMap<Integer, Set<Integer>>();
    /** One Compressor per thread, created lazily from {@link #conf}. */
    ThreadLocal<Compressor> compressorThreadLocal = new ThreadLocal<Compressor>(){

        @Override
        protected Compressor initialValue() {
            return Compressor.getCompressor(ShuffleClientImpl.this.conf);
        }
    };
    // NOTE(review): not used in the visible part of the file; presumably shuffleId -> reducer file groups.
    private final Map<Integer, ReduceFileGroups> reduceFileGroupsMap = new ConcurrentHashMap<Integer, ReduceFileGroups>();

    /**
     * Creates a shuffle client for the given configuration and user.
     *
     * <p>Caches the register/push limits from {@code conf}, creates the client RPC environment,
     * the transport client factory used for data pushes, and the thread pools used for push
     * retries and partition splits.
     *
     * @param conf Celeborn configuration
     * @param userIdentifier identity of the user owning this client
     */
    public ShuffleClientImpl(CelebornConf conf, UserIdentifier userIdentifier) {
        this.conf = conf;
        this.userIdentifier = userIdentifier;
        this.registerShuffleMaxRetries = conf.registerShuffleMaxRetry();
        this.registerShuffleRetryWaitMs = conf.registerShuffleRetryWaitMs();
        this.maxInFlight = conf.pushMaxReqsInFlight();
        this.pushBufferMaxSize = conf.pushBufferMaxSize();
        this.rpcEnv = RpcEnv.create("ShuffleClient", Utils.localHostName(), 0, conf);
        String module = "data";
        // The IO-thread key is "celeborn.<module>.io.threads". Older builds concatenated it without
        // the dot ("celeborndata.io.threads"); keep honoring that spelling as a fallback so existing
        // configurations continue to work.
        int legacyIoThreads = conf.getInt("celeborn" + module + ".io.threads", 8);
        int ioThreads = conf.getInt("celeborn." + module + ".io.threads", legacyIoThreads);
        TransportConf dataTransportConf = Utils.fromCelebornConf(conf, module, ioThreads);
        TransportContext context = new TransportContext(dataTransportConf, new BaseMessageHandler(), true);
        this.dataClientFactory = context.createClientFactory();
        int pushDataRetryThreads = conf.pushRetryThreads();
        this.pushDataRetryPool = ThreadUtils.newDaemonCachedThreadPool("celeborn-retry-sender", pushDataRetryThreads, 60);
        int pushSplitPartitionThreads = conf.pushSplitPartitionThreads();
        this.partitionSplitPool = ThreadUtils.newDaemonCachedThreadPool("celeborn-shuffle-split", pushSplitPartitionThreads, 60);
    }

    /**
     * Revives the partition behind {@code loc} and re-sends one already-framed batch to the
     * revived location.
     *
     * <p>A failed revive is reported through {@code callback.onFailure}. If the mapper has ended
     * meanwhile, the batch is simply dropped from {@code pushState}.
     */
    private void submitRetryPushData(String applicationId, int shuffleId, int mapId, int attemptId, byte[] body, int batchId, PartitionLocation loc, RpcResponseCallback callback, PushState pushState, StatusCode cause) {
        final int partitionId = loc.getId();
        boolean revived = this.revive(applicationId, shuffleId, mapId, attemptId, partitionId, loc.getEpoch(), loc, cause);
        if (!revived) {
            callback.onFailure(new IOException("Revive Failed"));
            return;
        }
        if (this.mapperEnded(shuffleId, mapId, attemptId)) {
            logger.debug("Retrying push data, but the mapper(map {} attempt {}) has ended.", (Object)mapId, (Object)attemptId);
            pushState.removeBatch(batchId);
            return;
        }
        PartitionLocation revivedLoc = this.reducePartitionMap.get(shuffleId).get(partitionId);
        logger.info("Revive success, new location for reduce {} is {}.", (Object)partitionId, (Object)revivedLoc);
        try {
            TransportClient targetClient = this.dataClientFactory.createClient(revivedLoc.getHost(), revivedLoc.getPushPort(), partitionId);
            NettyManagedBuffer payload = new NettyManagedBuffer(Unpooled.wrappedBuffer(body));
            String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId);
            PushData retryRequest = new PushData(MASTER_MODE, shuffleKey, revivedLoc.getUniqueId(), payload);
            ChannelFuture sendFuture = targetClient.pushData(retryRequest, callback);
            pushState.pushStarted(batchId, sendFuture, callback);
        } catch (Exception failure) {
            logger.warn("Exception raised while pushing data for shuffle {} map {} attempt {} batch {}.", new Object[]{shuffleId, mapId, attemptId, batchId, failure});
            callback.onFailure(failure);
        }
    }

    /**
     * Revives every partition of a failed (or hard-split) merged push, regroups the batches by the
     * address pair of their revived locations and re-pushes each group as a revived merged push.
     *
     * <p>If any revive fails, the failure is recorded in {@code pushState.exception} and nothing
     * is re-pushed. Batches whose mapper has ended are skipped. On completion the original grouped
     * batch id is removed from {@code pushState}.
     *
     * @param oldGroupedBatchId grouped batch id of the merged push being retried
     */
    private void submitRetryPushMergedData(PushState pushState, String applicationId, int shuffleId, int mapId, int attemptId, ArrayList<DataBatches.DataBatch> batches, StatusCode cause, Integer oldGroupedBatchId) {
        HashMap<String, DataBatches> newDataBatchesMap = new HashMap<String, DataBatches>();
        for (DataBatches.DataBatch dataBatch : batches) {
            int partitionId = dataBatch.loc.getId();
            if (!this.revive(applicationId, shuffleId, mapId, attemptId, partitionId, dataBatch.loc.getEpoch(), dataBatch.loc, cause)) {
                pushState.exception.compareAndSet(null, new IOException("Revive Failed in retry push merged data for location: " + dataBatch.loc));
                return;
            }
            if (this.mapperEnded(shuffleId, mapId, attemptId)) {
                logger.debug("Retrying push data, but the mapper(map {} attempt {}) has ended.", (Object)mapId, (Object)attemptId);
                continue;
            }
            PartitionLocation newLoc = this.reducePartitionMap.get(shuffleId).get(partitionId);
            logger.info("Revive success, new location for reduce {} is {}.", (Object)partitionId, (Object)newLoc);
            DataBatches newDataBatches = newDataBatchesMap.computeIfAbsent(this.genAddressPair(newLoc), s -> new DataBatches());
            newDataBatches.addDataBatch(newLoc, dataBatch.batchId, dataBatch.body);
        }
        for (Map.Entry<String, DataBatches> entry : newDataBatchesMap.entrySet()) {
            // The address pair is "masterHost:port[-peerHost:port]"; data goes to the first address.
            String masterHostPort = entry.getKey().split("-")[0];
            this.doPushMergedData(masterHostPort, applicationId, shuffleId, mapId, attemptId, entry.getValue().requireBatches(), pushState, true);
        }
        pushState.removeBatch(oldGroupedBatchId);
    }

    /**
     * Builds the key under which merged batches are grouped: the location's host:pushPort,
     * followed by "-" and the peer's host:pushPort when a peer exists.
     */
    private String genAddressPair(PartitionLocation loc) {
        PartitionLocation peer = loc.getPeer();
        if (peer == null) {
            return loc.hostAndPushPort();
        }
        return loc.hostAndPushPort() + "-" + peer.hostAndPushPort();
    }

    /**
     * Registers a shuffle with the driver meta service and returns its initial partition locations.
     *
     * @return partition id -> location, or {@code null} if registration failed
     */
    private ConcurrentHashMap<Integer, PartitionLocation> registerShuffle(String appId, int shuffleId, int numMappers, int numPartitions) {
        // Pass numPartitions (not numMappers twice) so failure logs report the real partition count.
        return this.registerShuffleInternal(shuffleId, numMappers, numPartitions, () -> (PbRegisterShuffleResponse)this.driverRssMetaService.askSync(ControlMessages$RegisterShuffle$.MODULE$.apply(appId, shuffleId, numMappers, numPartitions), this.conf.registerShuffleRpcAskTimeout(), ClassTag$.MODULE$.apply(PbRegisterShuffleResponse.class)));
    }

    /**
     * Registers a single map-partition task (identified by {@code mapId}/{@code attemptId}) and
     * returns the location assigned to its packed partition id.
     *
     * @return the assigned location, or {@code null} if registration failed or the response did
     *     not contain a location for the packed partition id
     */
    @VisibleForTesting
    public PartitionLocation registerMapPartitionTask(String appId, int shuffleId, int numMappers, int mapId, int attemptId) {
        int partitionId = PackedPartitionId.packedPartitionId(mapId, attemptId);
        logger.info("register mapPartitionTask, mapId: {}, attemptId: {}, partitionId: {}", new Object[]{mapId, attemptId, partitionId});
        ConcurrentHashMap<Integer, PartitionLocation> partitionLocationMap = this.registerShuffleInternal(shuffleId, numMappers, numMappers, () -> (PbRegisterShuffleResponse)this.driverRssMetaService.askSync(ControlMessages$RegisterMapPartitionTask$.MODULE$.apply(appId, shuffleId, numMappers, mapId, attemptId, partitionId), this.conf.registerShuffleRpcAskTimeout(), ClassTag$.MODULE$.apply(PbRegisterShuffleResponse.class)));
        // registerShuffleInternal returns null on failure; avoid an opaque NullPointerException here.
        if (partitionLocationMap == null) {
            return null;
        }
        return partitionLocationMap.get(partitionId);
    }

    /**
     * Invokes {@code callable} (a register request) up to {@code registerShuffleMaxRetries} times
     * and converts the first successful response into a partition id -> location map.
     *
     * <p>Non-success statuses are retried after {@code registerShuffleRetryWaitMs}. An exception
     * thrown by {@code callable}, or an interrupt while waiting, aborts immediately; in the
     * interrupt case the thread's interrupt flag is restored.
     *
     * @param shuffleId shuffle id (used for logging)
     * @param numMappers number of mappers (used for logging)
     * @param numPartitions number of partitions (used for logging)
     * @param callable performs one register request
     * @return partition id -> location, or {@code null} if registration did not succeed
     */
    private ConcurrentHashMap<Integer, PartitionLocation> registerShuffleInternal(int shuffleId, int numMappers, int numPartitions, Callable<PbRegisterShuffleResponse> callable) {
        for (int numRetries = this.registerShuffleMaxRetries; numRetries > 0; --numRetries) {
            try {
                PbRegisterShuffleResponse response = callable.call();
                StatusCode respStatus = Utils.toStatusCode(response.getStatus());
                if (StatusCode.SUCCESS.equals(respStatus)) {
                    ConcurrentHashMap<Integer, PartitionLocation> result = new ConcurrentHashMap<Integer, PartitionLocation>();
                    for (int i = 0; i < response.getPartitionLocationsList().size(); ++i) {
                        PartitionLocation partitionLoc = PbSerDeUtils.fromPbPartitionLocation(response.getPartitionLocationsList().get(i));
                        result.put(partitionLoc.getId(), partitionLoc);
                    }
                    return result;
                }
                if (StatusCode.SLOT_NOT_AVAILABLE.equals(respStatus)) {
                    logger.warn("LifecycleManager request slots return {}, retry again, remain retry times {}", (Object)respStatus, (Object)(numRetries - 1));
                } else {
                    // Log the status actually returned, not a hard-coded REQUEST_FAILED.
                    logger.error("LifecycleManager request slots return {}, retry again, remain retry times {}", (Object)respStatus, (Object)(numRetries - 1));
                }
            }
            catch (Exception e) {
                logger.error("Exception raised while registering shuffle {} with {} mapper and {} partitions.", new Object[]{shuffleId, numMappers, numPartitions, e});
                break;
            }
            if (numRetries <= 1) {
                // Last attempt failed: no point in waiting before giving up.
                break;
            }
            try {
                TimeUnit.MILLISECONDS.sleep(this.registerShuffleRetryWaitMs);
            }
            catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                break;
            }
        }
        return null;
    }

    /**
     * Blocks until at most {@code limit} batches of {@code pushState} are in flight.
     *
     * <p>Polls every {@code pushLimitInFlightSleepDeltaMs} for at most
     * {@code pushLimitInFlightTimeoutMs}, calling {@code pushState.failExpiredBatch()} on each
     * round. An interrupt is recorded as the push failure (unless one is already recorded) and the
     * thread's interrupt flag is restored.
     *
     * @throws IOException the push failure recorded in {@code pushState}, or a timeout error if the
     *     in-flight count did not drop to {@code limit} in time
     */
    private void limitMaxInFlight(String mapKey, PushState pushState, int limit) throws IOException {
        IOException failure = pushState.exception.get();
        if (failure != null) {
            throw failure;
        }
        long timeoutMs = this.conf.pushLimitInFlightTimeoutMs();
        long delta = this.conf.pushLimitInFlightSleepDeltaMs();
        // Assigned before the try so it is definitely assigned when checked after the catch.
        long times = timeoutMs / delta;
        try {
            while (times > 0L && pushState.inflightBatchCount() > limit) {
                failure = pushState.exception.get();
                if (failure != null) {
                    throw failure;
                }
                pushState.failExpiredBatch();
                Thread.sleep(delta);
                --times;
            }
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            pushState.exception.compareAndSet(null, new IOException(e));
        }
        if (times <= 0L) {
            logger.error("After waiting for {} ms, there are still {} batches in flight for map {}, which exceeds the limit {}.", new Object[]{timeoutMs, pushState.inflightBatchCount(), mapKey, limit});
            throw new IOException("wait timeout for task " + mapKey, pushState.exception.get());
        }
        failure = pushState.exception.get();
        if (failure != null) {
            throw failure;
        }
    }

    /**
     * Checks whether {@code partitionId} already has a location with an epoch newer than
     * {@code epoch}. If not, it sleeps for a short random time (only when the drawn value in
     * [0, 50) exceeds 30 ms) and checks once more.
     *
     * @return {@code true} if a newer location is present after the (optional) wait
     */
    private boolean waitRevivedLocation(ConcurrentHashMap<Integer, PartitionLocation> map, int partitionId, int epoch) {
        PartitionLocation latest = map.get(partitionId);
        boolean newer = latest != null && latest.getEpoch() > epoch;
        if (newer) {
            return true;
        }
        long jitterMs = RND.nextInt(50);
        if (jitterMs > 30L) {
            try {
                TimeUnit.MILLISECONDS.sleep(jitterMs);
            } catch (InterruptedException interrupted) {
                logger.warn("Wait revived location interrupted", (Throwable)interrupted);
                Thread.currentThread().interrupt();
            }
        }
        latest = map.get(partitionId);
        return latest != null && latest.getEpoch() > epoch;
    }

    /**
     * Ensures {@code partitionId} of {@code shuffleId} has a location newer than {@code epoch}.
     * It sends a Revive request to the driver meta service only if no other thread has already
     * revived the partition and the mapper has not ended.
     *
     * @return {@code true} when a newer location exists, or when the mapper has ended (either
     *     locally known or reported as MAP_ENDED); {@code false} when the request fails or returns
     *     any other status
     */
    private boolean revive(String applicationId, int shuffleId, int mapId, int attemptId, int partitionId, int epoch, PartitionLocation oldLocation, StatusCode cause) {
        ConcurrentHashMap<Integer, PartitionLocation> partitionLocs = this.reducePartitionMap.get(shuffleId);
        boolean alreadyRevived = this.waitRevivedLocation(partitionLocs, partitionId, epoch);
        if (alreadyRevived) {
            logger.debug("Has already revived for shuffle {} map {} reduce {} epoch {}, just return(Assume revive successfully).", new Object[]{shuffleId, mapId, partitionId, epoch});
            return true;
        }
        String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
        if (this.mapperEnded(shuffleId, mapId, attemptId)) {
            logger.debug("The mapper(shuffle {} map {}) has already ended, just return(Assume revive successfully).", (Object)shuffleId, (Object)mapId);
            return true;
        }
        try {
            PbChangeLocationResponse reply = (PbChangeLocationResponse)this.driverRssMetaService.askSync(ControlMessages$Revive$.MODULE$.apply(applicationId, shuffleId, mapId, attemptId, partitionId, epoch, oldLocation, cause), this.conf.requestPartitionLocationRpcAskTimeout(), ClassTag$.MODULE$.apply(PbChangeLocationResponse.class));
            StatusCode status = Utils.toStatusCode(reply.getStatus());
            boolean revived = false;
            if (StatusCode.SUCCESS.equals(status)) {
                partitionLocs.put(partitionId, PbSerDeUtils.fromPbPartitionLocation(reply.getLocation()));
                revived = true;
            } else if (StatusCode.MAP_ENDED.equals(status)) {
                this.mapperEndMap.computeIfAbsent(shuffleId, key -> ConcurrentHashMap.newKeySet()).add(mapKey);
                revived = true;
            }
            return revived;
        } catch (Exception failure) {
            logger.error("Exception raised while reviving for shuffle {} reduce {} epoch {}.", new Object[]{shuffleId, partitionId, epoch, failure});
            return false;
        }
    }

    /**
     * Compresses {@code data[offset, offset + length)} for {@code partitionId} and frames it as one
     * batch. The batch is either pushed immediately ({@code doPush == true}) or buffered per worker
     * address pair for a later merged push.
     *
     * <p>The frame starts with a 16-byte header of four ints (mapId, attemptId, batchId,
     * compressed size) followed by the compressed bytes. The shuffle is registered lazily on first
     * use. If the mapper has already ended, this map attempt's {@link PushState} (if any) is
     * cleaned up and nothing is sent.
     *
     * @param doPush {@code true} to push right away, {@code false} to merge into pending batches
     * @return the frame size in bytes (header + compressed payload), or 0 if the mapper has ended
     * @throws IOException if registration or revive fails, the partition location is missing, or
     *     {@code limitMaxInFlight} reports an earlier push failure or a timeout
     */
    public int pushOrMergeData(final String applicationId, final int shuffleId, final int mapId, final int attemptId, final int partitionId, byte[] data, int offset, int length, int numMappers, int numPartitions, boolean doPush) throws IOException {
        final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
        String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId);
        if (this.mapperEnded(shuffleId, mapId, attemptId)) {
            logger.debug("The mapper(shuffle {} map {} attempt {}) has already ended while pushing data.", new Object[]{shuffleId, mapId, attemptId});
            PushState pushState = this.pushStates.get(mapKey);
            if (pushState != null) {
                pushState.cleanup();
            }
            return 0;
        }
        // Lazily register; computeIfAbsent records nothing when registration returns null.
        ConcurrentHashMap<Integer, PartitionLocation> map = this.reducePartitionMap.computeIfAbsent(shuffleId, id -> this.registerShuffle(applicationId, shuffleId, numMappers, numPartitions));
        if (map == null) {
            throw new IOException("Register shuffle failed for shuffle " + shuffleKey);
        }
        if (!map.containsKey(partitionId)) {
            logger.warn("It should never reach here!");
            if (!this.revive(applicationId, shuffleId, mapId, attemptId, partitionId, -1, null, StatusCode.PUSH_DATA_FAIL_NON_CRITICAL_CAUSE)) {
                throw new IOException("Revive for shuffle " + shuffleKey + " partitionId " + partitionId + " failed.");
            }
        }
        // The revive above may have learned that the mapper ended.
        if (this.mapperEnded(shuffleId, mapId, attemptId)) {
            logger.debug("The mapper(shuffle {} map {} attempt {}) has already ended while pushing data.", new Object[]{shuffleId, mapId, attemptId});
            PushState pushState = this.pushStates.get(mapKey);
            if (pushState != null) {
                pushState.cleanup();
            }
            return 0;
        }
        final PartitionLocation loc = map.get(partitionId);
        if (loc == null) {
            throw new IOException("Partition location for shuffle " + shuffleKey + " partitionId " + partitionId + " is NULL!");
        }
        final PushState pushState = this.pushStates.computeIfAbsent(mapKey, s -> new PushState(this.conf));
        final int nextBatchId = pushState.batchId.addAndGet(1);
        Compressor compressor = this.compressorThreadLocal.get();
        compressor.compress(data, offset, length);
        int compressedTotalSize = compressor.getCompressedTotalSize();
        final int batchHeaderSize = 16;
        final byte[] body = new byte[batchHeaderSize + compressedTotalSize];
        Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET, mapId);
        Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 4, attemptId);
        Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 8, nextBatchId);
        Platform.putInt(body, Platform.BYTE_ARRAY_OFFSET + 12, compressedTotalSize);
        System.arraycopy(compressor.getCompressedBuffer(), 0, body, batchHeaderSize, compressedTotalSize);
        if (doPush) {
            logger.debug("Do push data for app {} shuffle {} map {} attempt {} reduce {} batch {}.", new Object[]{applicationId, shuffleId, mapId, attemptId, partitionId, nextBatchId});
            this.limitMaxInFlight(mapKey, pushState, this.maxInFlight);
            pushState.addBatch(nextBatchId);
            NettyManagedBuffer buffer = new NettyManagedBuffer(Unpooled.wrappedBuffer(body));
            PushData pushData = new PushData(MASTER_MODE, shuffleKey, loc.getUniqueId(), buffer);
            // Terminal callback: completes the batch, or records the failure of a revived push.
            final RpcResponseCallback callback = new RpcResponseCallback(){

                @Override
                public void onSuccess(ByteBuffer response) {
                    pushState.removeBatch(nextBatchId);
                    if (response.remaining() > 0 && response.get() == StatusCode.STAGE_ENDED.getValue()) {
                        ShuffleClientImpl.this.mapperEndMap.computeIfAbsent(shuffleId, id -> ConcurrentHashMap.newKeySet()).add(mapKey);
                    }
                    logger.debug("Push data to {}:{} success for map {} attempt {} batch {}.", new Object[]{loc.getHost(), loc.getPushPort(), mapId, attemptId, nextBatchId});
                }

                @Override
                public void onFailure(Throwable e) {
                    pushState.exception.compareAndSet(null, new IOException("Revived PushData failed!", e));
                    logger.error("Push data to {}:{} failed for map {} attempt {} batch {}.", new Object[]{loc.getHost(), loc.getPushPort(), mapId, attemptId, nextBatchId, e});
                }
            };
            // First-attempt callback: handles split responses and schedules retries on failure.
            RpcResponseCallback wrappedCallback = new RpcResponseCallback(){

                @Override
                public void onSuccess(ByteBuffer response) {
                    if (response.remaining() > 0) {
                        byte reason = response.get();
                        if (reason == StatusCode.SOFT_SPLIT.getValue()) {
                            logger.debug("Push data split required for map {} attempt {} batch {}", new Object[]{mapId, attemptId, nextBatchId});
                            ShuffleClientImpl.this.splitPartition(shuffleId, partitionId, applicationId, loc);
                            callback.onSuccess(response);
                        } else if (reason == StatusCode.HARD_SPLIT.getValue()) {
                            logger.debug("Push data split for map {} attempt {} batch {}.", new Object[]{mapId, attemptId, nextBatchId});
                            // `this` is this wrapped callback, so a split of the retried push is handled again.
                            ShuffleClientImpl.this.pushDataRetryPool.submit(() -> ShuffleClientImpl.this.submitRetryPushData(applicationId, shuffleId, mapId, attemptId, body, nextBatchId, loc, this, pushState, StatusCode.HARD_SPLIT));
                        } else {
                            response.rewind();
                            callback.onSuccess(response);
                        }
                    } else {
                        callback.onSuccess(response);
                    }
                }

                @Override
                public void onFailure(Throwable e) {
                    if (pushState.exception.get() != null) {
                        return;
                    }
                    logger.error("Push data to {}:{} failed for map {} attempt {} batch {}.", new Object[]{loc.getHost(), loc.getPushPort(), mapId, attemptId, nextBatchId, e});
                    if (!ShuffleClientImpl.this.mapperEnded(shuffleId, mapId, attemptId)) {
                        ShuffleClientImpl.this.pushDataRetryPool.submit(() -> ShuffleClientImpl.this.submitRetryPushData(applicationId, shuffleId, mapId, attemptId, body, nextBatchId, loc, callback, pushState, ShuffleClientImpl.this.getPushDataFailCause(e.getMessage())));
                    } else {
                        pushState.removeBatch(nextBatchId);
                        logger.info("Mapper shuffleId:{} mapId:{} attempt:{} already ended, remove batchId:{}.", new Object[]{shuffleId, mapId, attemptId, nextBatchId});
                    }
                }
            };
            try {
                TransportClient client = this.dataClientFactory.createClient(loc.getHost(), loc.getPushPort(), partitionId);
                ChannelFuture future = client.pushData(pushData, wrappedCallback);
                pushState.pushStarted(nextBatchId, future, wrappedCallback);
            }
            catch (Exception e) {
                logger.warn("PushData failed", (Throwable)e);
                wrappedCallback.onFailure(new Exception(this.getPushDataFailCause(e.getMessage()).toString(), e));
            }
        } else {
            logger.debug("Merge batch {}.", (Object)nextBatchId);
            String addressPair = this.genAddressPair(loc);
            boolean shouldPush = pushState.addBatchData(addressPair, loc, nextBatchId, body);
            if (shouldPush) {
                this.limitMaxInFlight(mapKey, pushState, this.maxInFlight);
                DataBatches dataBatches = pushState.takeDataBatches(addressPair);
                this.doPushMergedData(addressPair.split("-")[0], applicationId, shuffleId, mapId, attemptId, dataBatches.requireBatches(), pushState, false);
            }
        }
        return body.length;
    }

    /**
     * Sends an asynchronous PartitionSplit request for {@code partitionId} of {@code shuffleId},
     * unless one is already outstanding for that partition.
     *
     * <p>The partition id is marked in the per-shuffle {@code splitting} set under that set's
     * monitor. The same set is handed to {@code ShuffleClientHelper.sendShuffleSplitAsync}.
     * NOTE(review): presumably the helper clears the mark when the split completes, synchronizing
     * on the same set; confirm.
     */
    private void splitPartition(int shuffleId, int partitionId, String applicationId, PartitionLocation loc) {
        Set<Integer> splittingSet = this.splitting.computeIfAbsent(shuffleId, id -> ConcurrentHashMap.newKeySet());
        synchronized (splittingSet) {
            if (splittingSet.contains(partitionId)) {
                logger.debug("shuffle {} partitionId {} is splitting, skip split request ", (Object)shuffleId, (Object)partitionId);
                return;
            }
            splittingSet.add(partitionId);
        }
        ConcurrentHashMap<Integer, PartitionLocation> currentShuffleLocs = this.reducePartitionMap.get(shuffleId);
        ShuffleClientHelper.sendShuffleSplitAsync(this.driverRssMetaService, this.conf, ControlMessages$PartitionSplit$.MODULE$.apply(applicationId, shuffleId, partitionId, loc.getEpoch(), loc), this.partitionSplitPool, splittingSet, partitionId, shuffleId, currentShuffleLocs);
    }

    /**
     * Pushes one chunk of map output for {@code partitionId} immediately.
     *
     * @return the framed batch size in bytes, or 0 if the mapper has already ended
     */
    @Override
    public int pushData(String applicationId, int shuffleId, int mapId, int attemptId, int partitionId, byte[] data, int offset, int length, int numMappers, int numPartitions) throws IOException {
        final boolean pushImmediately = true;
        return this.pushOrMergeData(applicationId, shuffleId, mapId, attemptId, partitionId, data, offset, length, numMappers, numPartitions, pushImmediately);
    }

    /**
     * Waits, within the configured in-flight timeout, until this map attempt has no batches in
     * flight. Returns immediately if the attempt has never pushed.
     *
     * @throws IOException if an earlier push failed or the wait timed out
     */
    @Override
    public void prepareForMergeData(int shuffleId, int mapId, int attemptId) throws IOException {
        String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
        PushState pushState = this.pushStates.get(mapKey);
        if (pushState == null) {
            return;
        }
        // A limit of zero means "drain everything".
        this.limitMaxInFlight(mapKey, pushState, 0);
    }

    /**
     * Buffers one chunk of map output for {@code partitionId} for a later merged push. A merged
     * push may be triggered early when the per-address buffer asks for it.
     *
     * @return the framed batch size in bytes, or 0 if the mapper has already ended
     */
    @Override
    public int mergeData(String applicationId, int shuffleId, int mapId, int attemptId, int partitionId, byte[] data, int offset, int length, int numMappers, int numPartitions) throws IOException {
        final boolean pushImmediately = false;
        return this.pushOrMergeData(applicationId, shuffleId, mapId, attemptId, partitionId, data, offset, length, numMappers, numPartitions, pushImmediately);
    }

    /**
     * Flushes all buffered batches of this map attempt as merged pushes.
     *
     * <p>Address pairs are picked at random. Each pick sends at most {@code pushBufferMaxSize}
     * worth of batches, and a pair is dropped from the rotation once its buffer is empty.
     * Before every push, the in-flight limit is enforced.
     *
     * @throws IOException if an earlier push failed or the in-flight wait timed out
     */
    @Override
    public void pushMergedData(String applicationId, int shuffleId, int mapId, int attemptId) throws IOException {
        String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
        PushState pushState = this.pushStates.get(mapKey);
        if (pushState == null) {
            return;
        }
        ArrayList<Map.Entry<String, DataBatches>> pending = new ArrayList<Map.Entry<String, DataBatches>>(pushState.batchesMap.entrySet());
        while (pending.size() > 0) {
            this.limitMaxInFlight(mapKey, pushState, this.maxInFlight);
            int pick = RND.nextInt(pending.size());
            Map.Entry<String, DataBatches> target = pending.get(pick);
            DataBatches buffered = target.getValue();
            ArrayList<DataBatches.DataBatch> chunk = buffered.requireBatches(this.pushBufferMaxSize);
            if (buffered.getTotalSize() == 0) {
                // Remove by index; entries have distinct keys, so this is the same element.
                pending.remove(pick);
            }
            String masterHostPort = target.getKey().split("-")[0];
            this.doPushMergedData(masterHostPort, applicationId, shuffleId, mapId, attemptId, chunk, pushState, false);
        }
    }

    private void doPushMergedData(String hostPort, final String applicationId, final int shuffleId, final int mapId, final int attemptId, final ArrayList<DataBatches.DataBatch> batches, final PushState pushState, final boolean revived) {
        String[] splits = hostPort.split(":");
        final String host = splits[0];
        final int port = Integer.parseInt(splits[1]);
        final int groupedBatchId = pushState.batchId.addAndGet(1);
        pushState.addBatch(groupedBatchId);
        int numBatches = batches.size();
        String[] partitionUniqueIds = new String[numBatches];
        int[] offsets = new int[numBatches];
        final int[] batchIds = new int[numBatches];
        int currentSize = 0;
        CompositeByteBuf byteBuf = Unpooled.compositeBuffer();
        for (int i = 0; i < numBatches; ++i) {
            DataBatches.DataBatch batch = batches.get(i);
            partitionUniqueIds[i] = batch.loc.getUniqueId();
            offsets[i] = currentSize;
            batchIds[i] = batch.batchId;
            currentSize += batch.body.length;
            byteBuf.addComponent(true, Unpooled.wrappedBuffer(batch.body));
        }
        NettyManagedBuffer buffer = new NettyManagedBuffer(byteBuf);
        String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId);
        PushMergedData mergedData = new PushMergedData(MASTER_MODE, shuffleKey, partitionUniqueIds, offsets, buffer);
        final RpcResponseCallback callback = new RpcResponseCallback(){

            @Override
            public void onSuccess(ByteBuffer response) {
                logger.debug("Push data success for map {} attempt {} grouped batch {}.", new Object[]{mapId, attemptId, groupedBatchId});
                pushState.removeBatch(groupedBatchId);
                if (response.remaining() > 0 && response.get() == StatusCode.STAGE_ENDED.getValue()) {
                    ShuffleClientImpl.this.mapperEndMap.computeIfAbsent(shuffleId, id -> ConcurrentHashMap.newKeySet()).add(Utils.makeMapKey(shuffleId, mapId, attemptId));
                }
            }

            @Override
            public void onFailure(Throwable e) {
                String errorMsg = (revived ? "Revived push" : "Push") + " merged data to " + host + ":" + port + " failed for map " + mapId + " attempt " + attemptId + " batches " + Arrays.toString(batchIds) + ".";
                pushState.exception.compareAndSet(null, new IOException(errorMsg, e));
                if (logger.isDebugEnabled()) {
                    for (int batchId : batchIds) {
                        logger.debug("Push data failed for map {} attempt {} batch {}.", new Object[]{mapId, attemptId, batchId});
                    }
                }
            }
        };
        RpcResponseCallback wrappedCallback = new RpcResponseCallback(){

            @Override
            public void onSuccess(ByteBuffer response) {
                if (response.remaining() > 0) {
                    byte reason = response.get();
                    if (reason == StatusCode.HARD_SPLIT.getValue()) {
                        logger.info("Push merged data return hard split for map " + mapId + " attempt " + attemptId + " batches " + Arrays.toString(batchIds) + ".");
                        ShuffleClientImpl.this.pushDataRetryPool.submit(() -> ShuffleClientImpl.this.submitRetryPushMergedData(pushState, applicationId, shuffleId, mapId, attemptId, batches, StatusCode.HARD_SPLIT, groupedBatchId));
                    } else {
                        response.rewind();
                        logger.error("Push merged data should not receive this response");
                        callback.onSuccess(response);
                    }
                } else {
                    callback.onSuccess(response);
                }
            }

            @Override
            public void onFailure(Throwable e) {
                if (pushState.exception.get() != null) {
                    return;
                }
                if (revived) {
                    callback.onFailure(e);
                    return;
                }
                logger.error("Push merged data to " + host + ":" + port + " failed for map " + mapId + " attempt " + attemptId + " batches " + Arrays.toString(batchIds) + ".", e);
                if (!ShuffleClientImpl.this.mapperEnded(shuffleId, mapId, attemptId)) {
                    ShuffleClientImpl.this.pushDataRetryPool.submit(() -> ShuffleClientImpl.this.submitRetryPushMergedData(pushState, applicationId, shuffleId, mapId, attemptId, batches, ShuffleClientImpl.this.getPushDataFailCause(e.getMessage()), groupedBatchId));
                }
            }
        };
        try {
            TransportClient client = this.dataClientFactory.createClient(host, port);
            ChannelFuture future = client.pushMergedData(mergedData, wrappedCallback);
            pushState.pushStarted(groupedBatchId, future, wrappedCallback);
        }
        catch (Exception e) {
            logger.warn("PushMergedData failed", (Throwable)e);
            wrappedCallback.onFailure(new Exception(this.getPushDataFailCause(e.getMessage()).toString(), e));
        }
    }

    /**
     * Reports to the driver that this map attempt has finished pushing data.
     *
     * <p>First calls {@code limitMaxInFlight} with a limit of 0 for the attempt's push state
     * (presumably waiting for every outstanding push to complete - confirm against
     * limitMaxInFlight). It then sends a {@code MapperEnd} RPC to the driver meta-service. The
     * push state for the attempt is removed from {@code pushStates} whether or not this succeeds.
     *
     * @throws IOException if limitMaxInFlight fails or the driver answers with a non-SUCCESS status
     */
    @Override
    public void mapperEnd(String applicationId, int shuffleId, int mapId, int attemptId, int numMappers) throws IOException {
        final String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
        final PushState state = this.pushStates.computeIfAbsent(mapKey, key -> new PushState(this.conf));
        try {
            this.limitMaxInFlight(mapKey, state, 0);
            ControlMessages.MapperEnd request = new ControlMessages.MapperEnd(applicationId, shuffleId, mapId, attemptId, numMappers);
            ControlMessages.MapperEndResponse response = (ControlMessages.MapperEndResponse) this.driverRssMetaService.askSync(request, ClassTag$.MODULE$.apply(ControlMessages.MapperEndResponse.class));
            StatusCode status = response.status();
            if (status != StatusCode.SUCCESS) {
                throw new IOException("MapperEnd failed! StatusCode: " + status);
            }
        } finally {
            this.pushStates.remove(mapKey);
        }
    }

    /**
     * Drops the push state of one map attempt, if one is registered.
     *
     * <p>Before the state's own {@code cleanup()} runs, an {@code IOException("Cleaned Up")} is
     * recorded in its {@code exception} slot. This happens only when no earlier failure was
     * recorded, so later checks of that slot observe the attempt as failed.
     */
    @Override
    public void cleanup(String applicationId, int shuffleId, int mapId, int attemptId) {
        PushState removed = this.pushStates.remove(Utils.makeMapKey(shuffleId, mapId, attemptId));
        if (removed == null) {
            return;
        }
        removed.exception.compareAndSet(null, new IOException("Cleaned Up"));
        removed.cleanup();
    }

    /**
     * Forgets all client-side bookkeeping kept for {@code shuffleId}.
     *
     * <p>When {@code isDriver} is true, an {@code UnregisterShuffle} message carrying a fresh
     * request id is first sent through the driver meta-service reference. Any exception raised
     * while doing so is logged and ignored. The per-shuffle partition, file-group, mapper-end and
     * splitting maps are cleared in every case.
     *
     * @return always {@code true}
     */
    @Override
    public boolean unregisterShuffle(String applicationId, int shuffleId, boolean isDriver) {
        if (isDriver) {
            try {
                this.driverRssMetaService.send(ControlMessages$UnregisterShuffle$.MODULE$.apply(applicationId, shuffleId, RssHARetryClient.genRequestId()));
            } catch (Exception sendFailure) {
                // Best effort: local state is released below regardless of the driver's fate.
                logger.warn("Send UnregisterShuffle failed, ignore.", sendFailure);
            }
        }
        this.reducePartitionMap.remove(shuffleId);
        this.reduceFileGroupsMap.remove(shuffleId);
        this.mapperEndMap.remove(shuffleId);
        this.splitting.remove(shuffleId);
        logger.info("Unregistered shuffle {}.", shuffleId);
        return true;
    }

    /**
     * Opens a stream over {@code partitionId} without restricting the map index range.
     *
     * <p>Delegates to the six-argument overload with a start index of 0 and an end index of
     * {@code Integer.MAX_VALUE}.
     */
    @Override
    public RssInputStream readPartition(String applicationId, int shuffleId, int partitionId, int attemptNumber) throws IOException {
        final int firstMapIndex = 0;
        final int unboundedEndMapIndex = Integer.MAX_VALUE;
        return readPartition(applicationId, shuffleId, partitionId, attemptNumber, firstMapIndex, unboundedEndMapIndex);
    }

    /**
     * Opens a stream over the data of one reduce partition.
     *
     * <p>The reducer file groups of a shuffle are requested from the driver with
     * {@code GetReducerFileGroup} and cached in {@code reduceFileGroupsMap}. On any failure the
     * mapping function yields {@code null}, so nothing is cached and a later call asks the driver
     * again. Failures include a missing driver endpoint, a non-SUCCESS status, or an exception.
     *
     * <p>NOTE(review): the driver RPC runs inside {@code computeIfAbsent}. If the map is a
     * ConcurrentHashMap, other updates to the same bin wait for it - confirm this is acceptable.
     *
     * @throws IOException if no file groups could be obtained for the shuffle
     */
    @Override
    public RssInputStream readPartition(String applicationId, int shuffleId, int partitionId, int attemptNumber, int startMapIndex, int endMapIndex) throws IOException {
        String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId);
        ReduceFileGroups groups = this.reduceFileGroupsMap.computeIfAbsent(shuffleId, ignored -> {
            long startNanos = System.nanoTime();
            try {
                if (this.driverRssMetaService == null) {
                    logger.warn("Driver endpoint is null!");
                    return null;
                }
                ControlMessages.GetReducerFileGroup request = new ControlMessages.GetReducerFileGroup(applicationId, shuffleId);
                ControlMessages.GetReducerFileGroupResponse response = (ControlMessages.GetReducerFileGroupResponse) this.driverRssMetaService.askSync(request, this.conf.getReducerFileGroupRpcAskTimeout(), ClassTag$.MODULE$.apply(ControlMessages.GetReducerFileGroupResponse.class));
                StatusCode status = response.status();
                if (status == StatusCode.SUCCESS) {
                    long elapsedMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNanos);
                    logger.info("Shuffle {} request reducer file group success using time:{} ms", shuffleId, elapsedMs);
                    return new ReduceFileGroups(response.fileGroup(), response.attempts());
                }
                // Only these two statuses are reported; any other failure status is silently not cached.
                if (status == StatusCode.STAGE_END_TIME_OUT || status == StatusCode.SHUFFLE_DATA_LOST) {
                    logger.warn("Request {} return {} for {}", request, status.toString(), shuffleKey);
                }
            } catch (Exception e) {
                logger.error("Exception raised while call GetReducerFileGroup for " + shuffleKey + ".", e);
            }
            return null;
        });
        if (groups == null) {
            String msg = "Shuffle data lost for shuffle " + shuffleId + " reduce " + partitionId + "!";
            logger.error(msg);
            throw new IOException(msg);
        }
        PartitionLocation[][] partitionGroups = groups.partitionGroups;
        if (partitionGroups.length == 0) {
            logger.warn("Shuffle data is empty for shuffle {} reduce {}.", shuffleId, partitionId);
            return RssInputStream.empty();
        }
        return RssInputStream.create(this.conf, this.dataClientFactory, shuffleKey, partitionGroups[partitionId], groups.mapAttempts, attemptNumber, startMapIndex, endMapIndex);
    }

    /**
     * Releases the client's resources and forgets the driver endpoint reference.
     *
     * <p>Resources are released in this order: the RPC environment, then the data-plane client
     * factory, then the push-retry and partition-split pools. Any component that was never
     * created ({@code null}) is skipped.
     */
    @Override
    public void shutdown() {
        if (this.rpcEnv != null) {
            this.rpcEnv.shutdown();
        }
        if (this.dataClientFactory != null) {
            this.dataClientFactory.close();
        }
        if (this.pushDataRetryPool != null) {
            this.pushDataRetryPool.shutdown();
        }
        if (this.partitionSplitPool != null) {
            this.partitionSplitPool.shutdown();
        }
        this.driverRssMetaService = null;
        logger.warn("Shuffle client has been shutdown!");
    }

    /**
     * Resolves the driver's RSS meta-service endpoint at {@code host:port}.
     *
     * <p>The resolved reference is stored for use by subsequent driver RPCs.
     */
    @Override
    public void setupMetaServiceRef(String host, int port) {
        RpcAddress driverAddress = new RpcAddress(host, port);
        this.driverRssMetaService = this.rpcEnv.setupEndpointRef(driverAddress, RpcNameConstants.RSS_METASERVICE_EP);
    }

    /** Uses an already-resolved endpoint reference for all subsequent driver RPCs. */
    @Override
    public void setupMetaServiceRef(RpcEndpointRef endpointRef) {
        driverRssMetaService = endpointRef;
    }

    /**
     * Returns whether the given map attempt has been recorded as ended for {@code shuffleId}.
     *
     * <p>The per-shuffle set is read with a single lookup. The previous {@code containsKey}
     * followed by {@code get} could observe the entry disappearing in between, because
     * {@code unregisterShuffle} removes it. That case threw a NullPointerException; it now
     * simply reports "not ended".
     */
    private boolean mapperEnded(int shuffleId, int mapId, int attemptId) {
        Set<String> endedMapKeys = this.mapperEndMap.get(shuffleId);
        return endedMapKeys != null && endedMapKeys.contains(Utils.makeMapKey(shuffleId, mapId, attemptId));
    }

    /**
     * Classifies a push-failure message as a {@link StatusCode}.
     *
     * <p>The result is used as the retry cause, e.g. the one handed to submitRetryPushMergedData.
     * Checks run in this order:
     * <ol>
     *   <li>the slave-failure message;</li>
     *   <li>the master-failure message, or anything {@code connectFail} accepts;</li>
     *   <li>the timeout message.</li>
     * </ol>
     * Every other message maps to {@code PUSH_DATA_FAIL_NON_CRITICAL_CAUSE}.
     *
     * <p>NOTE(review): {@code message} may come from {@code Throwable.getMessage()} and be null;
     * confirm that {@code connectFail} tolerates that.
     */
    private StatusCode getPushDataFailCause(String message) {
        logger.info("[getPushDataFailCause] message: " + message);
        if (StatusCode.PUSH_DATA_FAIL_SLAVE.getMessage().equals(message)) {
            return StatusCode.PUSH_DATA_FAIL_SLAVE;
        }
        if (StatusCode.PUSH_DATA_FAIL_MASTER.getMessage().equals(message) || this.connectFail(message)) {
            return StatusCode.PUSH_DATA_FAIL_MASTER;
        }
        if (StatusCode.PUSH_DATA_TIMEOUT.getMessage().equals(message)) {
            return StatusCode.PUSH_DATA_TIMEOUT;
        }
        return StatusCode.PUSH_DATA_FAIL_NON_CRITICAL_CAUSE;
    }

    /**
     * Returns whether {@code message} reads like a connection-level failure.
     *
     * <p>Three patterns are recognised: a "Connection from ... closed" message, "Connection reset
     * by peer", and a "Failed to send RPC ..." message.
     *
     * <p>Exception messages may be {@code null} (see {@link Throwable#getMessage()}). A null
     * message is treated as "not a connection failure" rather than throwing a
     * NullPointerException from inside a failure callback.
     */
    private boolean connectFail(String message) {
        if (message == null) {
            return false;
        }
        boolean connectionClosed = message.startsWith("Connection from ") && message.endsWith(" closed");
        boolean connectionReset = message.equals("Connection reset by peer");
        boolean rpcNotSent = message.startsWith("Failed to send RPC ");
        return connectionClosed || connectionReset || rpcNotSent;
    }

    /**
     * Sends a {@code PushDataHandShake} to the worker hosting {@code location}.
     *
     * <p>The message carries the partition count and buffer size. It is sent through
     * {@code sendMessageInternal}, which:
     * <ul>
     *   <li>skips it when the mapper has already ended;</li>
     *   <li>tracks it as an in-flight batch of the attempt;</li>
     *   <li>retries failures that {@code shouldRetry} accepts.</li>
     * </ul>
     * The RPC uses {@code sendRpcSync} with {@code conf.pushDataTimeoutMs()} as its timeout.
     *
     * @throws IOException if the handshake ultimately cannot be delivered
     */
    @Override
    public void pushDataHandShake(String applicationId, int shuffleId, int mapId, int attemptId, int numPartitions, int bufferSize, PartitionLocation location) throws IOException {
        this.sendMessageInternal(shuffleId, mapId, attemptId, () -> {
            String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId);
            String locationId = location.getUniqueId();
            logger.info("pushDataHandShake shuffleKey:{}, attemptId:{}, locationId:{}", shuffleKey, attemptId, locationId);
            logger.debug("pushDataHandShake location:{}", location.toString());
            PushDataHandShake handShake = new PushDataHandShake(MASTER_MODE, shuffleKey, locationId, attemptId, numPartitions, bufferSize);
            TransportClient client = this.dataClientFactory.createClient(location.getHost(), location.getPushPort());
            client.sendRpcSync(handShake.toByteBuffer(), this.conf.pushDataTimeoutMs());
            return null;
        });
    }

    /**
     * Announces the start of region {@code currentRegionIdx} to the worker hosting {@code location}.
     *
     * <p>If the worker's reply begins with {@code HARD_SPLIT}, the driver is asked to revive the
     * partition ({@code Revive} with cause {@code HARD_SPLIT}):
     * <ul>
     *   <li>{@code SUCCESS}: the new location is returned;</li>
     *   <li>{@code MAP_ENDED}: the attempt is recorded in {@code mapperEndMap} and an empty
     *       result is returned;</li>
     *   <li>any other status: an IOException is thrown.</li>
     * </ul>
     *
     * <p>NOTE(review): that IOException is accepted by {@code shouldRetry}. As a result,
     * {@code retrySendMessage} repeats the whole RegionStart-plus-revive sequence; confirm this
     * is intended.
     *
     * @return the revived location, or empty when no revive happened, the mapper had ended, or
     *     the message was skipped. Never {@code null}.
     * @throws IOException if the message cannot be delivered or the revive fails
     */
    @Override
    public Optional<PartitionLocation> regionStart(String applicationId, int shuffleId, int mapId, int attemptId, PartitionLocation location, int currentRegionIdx, boolean isBroadcast) throws IOException {
        Optional<PartitionLocation> revivedLocation = this.sendMessageInternal(shuffleId, mapId, attemptId, () -> {
            String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId);
            logger.info("regionStart shuffleKey:{}, attemptId:{}, locationId:{}", new Object[]{shuffleKey, attemptId, location.getUniqueId()});
            logger.debug("regionStart location:{}", (Object)location.toString());
            TransportClient client = this.dataClientFactory.createClient(location.getHost(), location.getPushPort());
            RegionStart regionStart = new RegionStart(MASTER_MODE, shuffleKey, location.getUniqueId(), attemptId, currentRegionIdx, isBroadcast);
            ByteBuffer regionStartResponse = client.sendRpcSync(regionStart.toByteBuffer(), this.conf.pushDataTimeoutMs());
            if (regionStartResponse.hasRemaining() && regionStartResponse.get() == StatusCode.HARD_SPLIT.getValue()) {
                PbChangeLocationResponse response = (PbChangeLocationResponse)this.driverRssMetaService.askSync(ControlMessages$Revive$.MODULE$.apply(applicationId, shuffleId, mapId, attemptId, location.getId(), location.getEpoch(), location, StatusCode.HARD_SPLIT), this.conf.requestPartitionLocationRpcAskTimeout(), ClassTag$.MODULE$.apply(PbChangeLocationResponse.class));
                StatusCode respStatus = Utils.toStatusCode(response.getStatus());
                if (respStatus == StatusCode.SUCCESS) {
                    return Optional.of(PbSerDeUtils.fromPbPartitionLocation(response.getLocation()));
                }
                if (respStatus == StatusCode.MAP_ENDED) {
                    String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
                    this.mapperEndMap.computeIfAbsent(shuffleId, id -> ConcurrentHashMap.newKeySet()).add(mapKey);
                    return Optional.empty();
                }
                logger.error("Exception raised while reviving for shuffle {} reduce {} epoch {}.", new Object[]{shuffleId, location.getId(), location.getEpoch()});
                throw new IOException("regionStart revive failed with status " + respStatus);
            }
            return Optional.empty();
        });
        // sendMessageInternal returns null without sending when the mapper has already ended;
        // an Optional-returning method must never hand null to its callers.
        return revivedLocation == null ? Optional.empty() : revivedLocation;
    }

    /**
     * Announces the end of the current region to the worker hosting {@code location}.
     *
     * <p>The message is sent through {@code sendMessageInternal}, which skips it when the mapper
     * has already ended and retries failures that {@code shouldRetry} accepts. The RPC uses
     * {@code sendRpcSync} with {@code conf.pushDataTimeoutMs()} as its timeout; its reply is
     * ignored.
     *
     * @throws IOException if the message ultimately cannot be delivered
     */
    @Override
    public void regionFinish(String applicationId, int shuffleId, int mapId, int attemptId, PartitionLocation location) throws IOException {
        this.sendMessageInternal(shuffleId, mapId, attemptId, () -> {
            String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId);
            String locationId = location.getUniqueId();
            logger.info("regionFinish shuffleKey:{}, attemptId:{}, locationId:{}", shuffleKey, attemptId, locationId);
            logger.debug("regionFinish location:{}", location.toString());
            RegionFinish finishMessage = new RegionFinish(MASTER_MODE, shuffleKey, locationId, attemptId);
            TransportClient client = this.dataClientFactory.createClient(location.getHost(), location.getPushPort());
            client.sendRpcSync(finishMessage.toByteBuffer(), this.conf.pushDataTimeoutMs());
            return null;
        });
    }

    /**
     * Runs one synchronous control message on behalf of a map attempt.
     *
     * <p>If the mapper has already ended, returns {@code null} without calling {@code supplier}.
     * Otherwise it:
     * <ol>
     *   <li>gets or creates the attempt's push state;</li>
     *   <li>calls {@code limitMaxInFlight} with a limit of 0;</li>
     *   <li>registers a fresh batch id in the push state;</li>
     *   <li>delegates to {@code retrySendMessage}.</li>
     * </ol>
     * The batch id is removed again in a {@code finally}. That id is 0 if the failure happened
     * before a batch was registered.
     *
     * @return the supplier's result, or {@code null} when the mapper had already ended
     * @throws IOException if the in-flight limit or the message itself fails
     */
    private <R> R sendMessageInternal(int shuffleId, int mapId, int attemptId, ThrowingExceptionSupplier<R, Exception> supplier) throws IOException {
        String mapKey = Utils.makeMapKey(shuffleId, mapId, attemptId);
        if (this.mapperEnded(shuffleId, mapId, attemptId)) {
            logger.debug("The mapper(shuffle {} map {} attempt {}) has already ended while pushing data.", shuffleId, mapId, attemptId);
            return null;
        }
        PushState state = this.pushStates.computeIfAbsent(mapKey, key -> new PushState(this.conf));
        int registeredBatchId = 0;
        try {
            this.limitMaxInFlight(mapKey, state, 0);
            registeredBatchId = state.batchId.incrementAndGet();
            state.addBatch(registeredBatchId);
            return this.retrySendMessage(supplier);
        } finally {
            state.removeBatch(registeredBatchId);
        }
    }

    /**
     * Invokes {@code supplier}, retrying failures that {@code shouldRetry} accepts.
     *
     * <p>At most {@code conf.networkIoMaxRetries("push")} attempts are made. Between attempts the
     * thread waits {@code conf.networkIoRetryWaitMs("push")} milliseconds (uninterruptibly), but
     * no longer sleeps after the final attempt. The loop stops early once the thread is
     * interrupted. An {@link InterruptedException} thrown by the supplier restores the interrupt
     * flag.
     *
     * @return the result of the first successful attempt
     * @throws IOException the last failure (wrapped when it is not an IOException). If no attempt
     *     ran at all (thread already interrupted, or no attempts configured), a new IOException
     *     is thrown instead; previously this case failed with a NullPointerException.
     */
    private <R> R retrySendMessage(ThrowingExceptionSupplier<R, Exception> supplier) throws IOException {
        int retryTimes = 0;
        Exception lastFailure = null;
        while (!Thread.currentThread().isInterrupted() && retryTimes < this.conf.networkIoMaxRetries("push")) {
            logger.debug("retrySendMessage times: {}", retryTimes);
            try {
                return supplier.get();
            } catch (Exception e) {
                lastFailure = e;
                if (e instanceof InterruptedException) {
                    Thread.currentThread().interrupt();
                }
                if (!this.shouldRetry(e)) {
                    break;
                }
                ++retryTimes;
                // Sleeping is pointless when the loop is about to give up anyway.
                if (retryTimes < this.conf.networkIoMaxRetries("push")) {
                    Uninterruptibles.sleepUninterruptibly(this.conf.networkIoRetryWaitMs("push"), TimeUnit.MILLISECONDS);
                }
            }
        }
        if (lastFailure == null) {
            String reason = Thread.currentThread().isInterrupted()
                    ? "thread was interrupted before the first attempt"
                    : "no attempts allowed by networkIoMaxRetries(push)";
            throw new IOException("Message was not sent: " + reason);
        }
        if (lastFailure instanceof IOException) {
            throw (IOException) lastFailure;
        }
        throw new IOException(lastFailure.getMessage(), lastFailure);
    }

    /**
     * Decides whether a failed control message is worth retrying.
     *
     * <p>Retries are allowed in three cases:
     * <ul>
     *   <li>{@link IOException} or {@link TimeoutException}, thrown directly or as the immediate
     *       cause;</li>
     *   <li>a {@link RuntimeException} whose message starts with the {@code java.io.IOException}
     *       class name (presumably a wrapped IOException rendered through toString - TODO
     *       confirm where these come from).</li>
     * </ul>
     *
     * <p>A RuntimeException with a {@code null} message used to throw a NullPointerException
     * here, masking the original failure. It is now simply not retried.
     */
    private boolean shouldRetry(Throwable e) {
        if (e instanceof IOException || e instanceof TimeoutException) {
            return true;
        }
        Throwable cause = e.getCause();
        if (cause instanceof TimeoutException || cause instanceof IOException) {
            return true;
        }
        if (!(e instanceof RuntimeException)) {
            return false;
        }
        String message = e.getMessage();
        return message != null && message.startsWith(IOException.class.getName());
    }

    /**
     * A supplier whose {@link #get()} may throw a checked exception of type {@code E}.
     *
     * @param <R> type of the produced result
     * @param <E> exception type that {@code get()} is allowed to throw
     */
    @FunctionalInterface
    interface ThrowingExceptionSupplier<R, E extends Exception> {
        /** Produces the result, possibly failing with {@code E}. */
        R get() throws E;
    }

    /**
     * Cached result of a {@code GetReducerFileGroup} request for one shuffle.
     */
    private static class ReduceFileGroups {
        // Locations per reduce partition; readPartition indexes the outer array by partition id.
        final PartitionLocation[][] partitionGroups;
        // Attempt numbers returned by the driver with the groups; presumably indexed by map id - TODO confirm.
        final int[] mapAttempts;

        ReduceFileGroups(PartitionLocation[][] groups, int[] attempts) {
            partitionGroups = groups;
            mapAttempts = attempts;
        }
    }
}

