/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.shuffle.celeborn;

import java.io.IOException;
import java.io.OutputStream;
import java.util.concurrent.atomic.LongAdder;
import javax.annotation.Nullable;
import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.client.write.DataPusher;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.shaded.com.google.common.annotations.VisibleForTesting;
import org.apache.spark.Aggregator;
import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkEnv;
import org.apache.spark.TaskContext;
import org.apache.spark.annotation.Private;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.serializer.SerializationStream;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.shuffle.celeborn.OpenByteArrayOutputStream;
import org.apache.spark.shuffle.celeborn.RssShuffleHandle;
import org.apache.spark.shuffle.celeborn.SendBufferPool;
import org.apache.spark.shuffle.celeborn.SparkUtils;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.execution.PartitionIdPassthrough;
import org.apache.spark.sql.execution.UnsafeRowSerializer;
import org.apache.spark.sql.execution.metric.SQLMetric;
import org.apache.spark.storage.BlockManagerId;
import org.apache.spark.unsafe.Platform;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Option;
import scala.Product2;
import scala.collection.Iterator;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;

/**
 * Celeborn shuffle writer that serializes each record into a per-partition send buffer and
 * pushes those buffers to Celeborn.
 *
 * <p>Send buffers that cannot fit the next record are handed to a {@link DataPusher}; records
 * larger than the configured max push buffer size bypass the buffers and are pushed directly
 * with {@link ShuffleClient#pushData}. Once all records are written, whatever is still buffered
 * is merged and pushed, the mapper end is reported, and a {@link MapStatus} is built.
 *
 * <p>When the dependency uses an {@link UnsafeRowSerializer} together with a
 * {@link PartitionIdPassthrough} partitioner, rows are copied straight into the send buffers
 * instead of going through the serializer (see {@link #canUseFastWrite()}).
 *
 * @param <K> record key type
 * @param <V> record value type
 * @param <C> combiner type used when map-side combine is enabled
 */
@Private
public class HashBasedShuffleWriter<K, V, C>
extends ShuffleWriter<K, V> {
    private static final Logger logger = LoggerFactory.getLogger(HashBasedShuffleWriter.class);
    private static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();
    /** Initial capacity in bytes (1 MiB) of the scratch buffer each record is serialized into. */
    private static final int DEFAULT_INITIAL_SER_BUFFER_SIZE = 1024 * 1024;
    /** Size in bytes of the row-length prefix written before each row on the fast write path. */
    private static final int ROW_SIZE_HEADER_BYTES = 4;

    /** Capacity of a partition's send buffer when it is first allocated. */
    private final int pushBufferInitSize;
    /** Send buffers are not grown past this size; larger records are pushed on their own. */
    private final int pushBufferMaxSize;
    private final ShuffleDependency<K, V, C> dep;
    private final Partitioner partitioner;
    private final ShuffleWriteMetrics writeMetrics;
    private final String appId;
    private final int shuffleId;
    private final int mapId;
    private final TaskContext taskContext;
    private final ShuffleClient rssShuffleClient;
    private final int numMappers;
    private final int numPartitions;
    /** Built at the end of {@link #write}; stays {@code null} until then. */
    @Nullable
    private MapStatus mapStatus;
    /** Bytes of send-buffer memory this writer allocated itself. */
    private long peakMemoryUsedBytes = 0L;
    /** Scratch buffer that a single record is serialized into before being copied. */
    private final OpenByteArrayOutputStream serBuffer;
    private final SerializationStream serOutputStream;
    /** One lazily allocated send buffer per partition (possibly taken from the pool). */
    private byte[][] sendBuffers;
    /** Number of bytes currently buffered in each partition's send buffer. */
    private int[] sendOffsets;
    /** Bytes pushed per partition; also handed to {@link #dataPusher}. */
    private final LongAdder[] mapStatusLengths;
    /** Records per partition already added to the write metrics. */
    private final long[] mapStatusRecords;
    /** Records per partition written since the last {@link #updateMapStatus()} call. */
    private final long[] tmpRecords;
    private SendBufferPool sendBufferPool;
    private volatile boolean stopping = false;
    private final DataPusher dataPusher;

    public HashBasedShuffleWriter(
            RssShuffleHandle<K, V, C> handle,
            int mapId,
            TaskContext taskContext,
            CelebornConf conf,
            ShuffleClient client,
            SendBufferPool sendBufferPool)
            throws IOException {
        this.mapId = mapId;
        this.dep = handle.dependency();
        this.appId = handle.newAppId();
        this.shuffleId = this.dep.shuffleId();
        SerializerInstance serializer = this.dep.serializer().newInstance();
        this.partitioner = this.dep.partitioner();
        this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics();
        this.taskContext = taskContext;
        this.numMappers = handle.numMaps();
        this.numPartitions = this.dep.partitioner().numPartitions();
        this.rssShuffleClient = client;

        this.serBuffer = new OpenByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE);
        this.serOutputStream = serializer.serializeStream(this.serBuffer);

        this.mapStatusLengths = new LongAdder[this.numPartitions];
        for (int i = 0; i < this.numPartitions; i++) {
            this.mapStatusLengths[i] = new LongAdder();
        }
        this.mapStatusRecords = new long[this.numPartitions];
        this.tmpRecords = new long[this.numPartitions];

        this.pushBufferInitSize = conf.pushBufferInitialSize();
        this.pushBufferMaxSize = conf.pushBufferMaxSize();

        this.sendBufferPool = sendBufferPool;
        // The pool may have no array to hand out; fall back to a fresh, empty one.
        byte[][] pooledBuffers = sendBufferPool.acquireBuffer(this.numPartitions);
        this.sendBuffers = pooledBuffers != null ? pooledBuffers : new byte[this.numPartitions][];
        this.sendOffsets = new int[this.numPartitions];

        // The callback forwards the counts the pusher reports (presumably bytes pushed per
        // batch) to the task's shuffle-write metrics.
        this.dataPusher = new DataPusher(
                this.appId,
                this.shuffleId,
                mapId,
                taskContext.attemptNumber(),
                taskContext.taskAttemptId(),
                this.numMappers,
                this.numPartitions,
                conf,
                this.rssShuffleClient,
                bytes -> this.writeMetrics.incBytesWritten(bytes),
                this.mapStatusLengths);
    }

    /**
     * Writes all records, choosing the fast UnsafeRow path, the map-side-combine path, or the
     * generic serializer path, and then finishes the map output via {@link #close()}.
     *
     * @throws UnsupportedOperationException if map-side combine is requested without an
     *     aggregator
     */
    @Override
    public void write(Iterator<Product2<K, V>> records) throws IOException {
        if (this.canUseFastWrite()) {
            this.fastWrite0(records);
        } else if (this.dep.mapSideCombine()) {
            if (this.dep.aggregator().isEmpty()) {
                throw new UnsupportedOperationException("When using map side combine, an aggregator must be specified.");
            }
            Aggregator<K, V, C> aggregator = this.dep.aggregator().get();
            this.write0(aggregator.combineValuesByKey(records, this.taskContext));
        } else {
            this.write0(records);
        }
        this.close();
    }

    /** Whether rows can be copied directly instead of being run through the serializer. */
    @VisibleForTesting
    boolean canUseFastWrite() {
        return this.dep.serializer() instanceof UnsafeRowSerializer
                && this.partitioner instanceof PartitionIdPassthrough;
    }

    /**
     * Fast path: each record's key is already the partition id and its value is an
     * {@link UnsafeRow}. The row is written as a 4-byte length prefix followed by its bytes.
     */
    private void fastWrite0(Iterator<? extends Product2<?, ?>> records) throws IOException {
        SQLMetric dataSize = SparkUtils.getUnsafeRowSerializerDataSizeMetric(
                (UnsafeRowSerializer) this.dep.serializer());
        while (records.hasNext()) {
            Product2<?, ?> record = records.next();
            int partitionId = (Integer) record._1();
            UnsafeRow row = (UnsafeRow) record._2();
            int serializedRecordSize = ROW_SIZE_HEADER_BYTES + row.getSizeInBytes();
            if (dataSize != null) {
                dataSize.add(serializedRecordSize);
            }
            if (serializedRecordSize > this.pushBufferMaxSize) {
                byte[] giantBuffer = new byte[serializedRecordSize];
                writeLengthPrefixedRow(row, giantBuffer, 0);
                this.pushGiantRecord(partitionId, giantBuffer, serializedRecordSize);
            } else {
                int offset = this.getOrUpdateOffset(partitionId, serializedRecordSize);
                // Re-read the buffer: getOrUpdateOffset may have replaced it with a larger one.
                byte[] buffer = this.getOrCreateBuffer(partitionId);
                writeLengthPrefixedRow(row, buffer, offset);
                this.sendOffsets[partitionId] = offset + serializedRecordSize;
            }
            this.tmpRecords[partitionId]++;
        }
    }

    /**
     * Copies {@code row} into {@code target} at {@code offset}: a 4-byte big-endian length
     * followed by the row's raw bytes. NOTE(review): big-endian is assumed to match the length
     * prefix UnsafeRowSerializer writes (DataOutputStream.writeInt) — confirm.
     */
    private static void writeLengthPrefixedRow(UnsafeRow row, byte[] target, int offset) {
        int rowSize = row.getSizeInBytes();
        writeIntBigEndian(target, offset, rowSize);
        Platform.copyMemory(
                row.getBaseObject(),
                row.getBaseOffset(),
                target,
                Platform.BYTE_ARRAY_OFFSET + offset + ROW_SIZE_HEADER_BYTES,
                rowSize);
    }

    /**
     * Stores {@code value} high byte first, independent of the host's native byte order.
     * (The previous native-order write of a byte-reversed int was only correct on
     * little-endian hosts.)
     */
    private static void writeIntBigEndian(byte[] target, int offset, int value) {
        target[offset] = (byte) (value >>> 24);
        target[offset + 1] = (byte) (value >>> 16);
        target[offset + 2] = (byte) (value >>> 8);
        target[offset + 3] = (byte) value;
    }

    /** Generic path: serializes each key/value pair with the dependency's serializer. */
    private void write0(Iterator<? extends Product2<?, ?>> records) throws IOException {
        while (records.hasNext()) {
            Product2<?, ?> record = records.next();
            Object key = record._1();
            int partitionId = this.partitioner.getPartition(key);
            this.serBuffer.reset();
            this.serOutputStream.writeKey(key, OBJECT_CLASS_TAG);
            this.serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG);
            this.serOutputStream.flush();
            int serializedRecordSize = this.serBuffer.size();
            assert serializedRecordSize > 0;
            if (serializedRecordSize > this.pushBufferMaxSize) {
                this.pushGiantRecord(partitionId, this.serBuffer.getBuf(), serializedRecordSize);
            } else {
                int offset = this.getOrUpdateOffset(partitionId, serializedRecordSize);
                // Re-read the buffer: getOrUpdateOffset may have replaced it with a larger one.
                byte[] buffer = this.getOrCreateBuffer(partitionId);
                System.arraycopy(this.serBuffer.getBuf(), 0, buffer, offset, serializedRecordSize);
                this.sendOffsets[partitionId] = offset + serializedRecordSize;
            }
            this.tmpRecords[partitionId]++;
        }
    }

    /** Returns the partition's send buffer, allocating one of the initial size on first use. */
    private byte[] getOrCreateBuffer(int partitionId) {
        byte[] buffer = this.sendBuffers[partitionId];
        if (buffer == null) {
            buffer = new byte[this.pushBufferInitSize];
            this.sendBuffers[partitionId] = buffer;
            this.peakMemoryUsedBytes += this.pushBufferInitSize;
        }
        return buffer;
    }

    /** Pushes a record too large for any send buffer directly, on the calling thread. */
    private void pushGiantRecord(int partitionId, byte[] buffer, int numBytes) throws IOException {
        logger.debug("Push giant record for partition {}, size {}.", partitionId, numBytes);
        long pushStartTime = System.nanoTime();
        int bytesWritten = this.rssShuffleClient.pushData(
                this.appId,
                this.shuffleId,
                this.mapId,
                this.taskContext.attemptNumber(),
                partitionId,
                buffer,
                0,
                numBytes,
                this.numMappers,
                this.numPartitions);
        this.mapStatusLengths[partitionId].add(bytesWritten);
        this.writeMetrics.incBytesWritten(bytesWritten);
        this.writeMetrics.incWriteTime(System.nanoTime() - pushStartTime);
    }

    /**
     * Makes room for {@code serializedRecordSize} bytes in the partition's send buffer and
     * returns the offset at which the record should be written.
     *
     * <p>The buffer is first grown by doubling, capped at the max push buffer size. If the
     * record still does not fit, the buffered bytes are flushed to the data pusher and writing
     * restarts at offset 0.
     */
    private int getOrUpdateOffset(int partitionId, int serializedRecordSize) throws IOException {
        int offset = this.sendOffsets[partitionId];
        byte[] buffer = this.getOrCreateBuffer(partitionId);

        // Compute the final capacity up front so the buffered bytes are copied only once. Long
        // arithmetic keeps the doubling from overflowing, and starting from at least 1 keeps a
        // zero-length buffer from looping forever.
        int capacity = buffer.length;
        while (capacity - offset < serializedRecordSize && capacity < this.pushBufferMaxSize) {
            capacity = (int) Math.min(Math.max(2L * capacity, 1L), this.pushBufferMaxSize);
        }
        if (capacity > buffer.length) {
            byte[] grown = new byte[capacity];
            System.arraycopy(buffer, 0, grown, 0, offset);
            this.peakMemoryUsedBytes += capacity - buffer.length;
            this.sendBuffers[partitionId] = grown;
            buffer = grown;
        }

        if (buffer.length - offset < serializedRecordSize) {
            this.flushSendBuffer(partitionId, buffer, offset);
            this.updateMapStatus();
            offset = 0;
        }
        return offset;
    }

    /** Hands the first {@code size} bytes of {@code buffer} to the data pusher. */
    private void flushSendBuffer(int partitionId, byte[] buffer, int size) throws IOException {
        long pushStartTime = System.nanoTime();
        logger.debug("Flush buffer for partition {}, size {}.", partitionId, size);
        this.dataPusher.addTask(partitionId, buffer, size);
        this.writeMetrics.incWriteTime(System.nanoTime() - pushStartTime);
    }

    /**
     * Finishes the map output. It waits for the data pusher, merges and pushes the residual
     * buffered data, returns the buffers to the pool, reports the mapper end, and builds
     * {@link #mapStatus}.
     */
    private void close() throws IOException {
        this.dataPusher.waitOnTermination();
        this.rssShuffleClient.prepareForMergeData(
                this.shuffleId, this.mapId, this.taskContext.attemptNumber());
        for (int i = 0; i < this.sendBuffers.length; i++) {
            int size = this.sendOffsets[i];
            if (size <= 0) {
                continue;
            }
            int bytesWritten = this.rssShuffleClient.mergeData(
                    this.appId,
                    this.shuffleId,
                    this.mapId,
                    this.taskContext.attemptNumber(),
                    i,
                    this.sendBuffers[i],
                    0,
                    size,
                    this.numMappers,
                    this.numPartitions);
            // Drop this buffer before the array is handed back to the pool below.
            this.sendBuffers[i] = null;
            this.mapStatusLengths[i].add(bytesWritten);
            this.writeMetrics.incBytesWritten(bytesWritten);
        }
        this.rssShuffleClient.pushMergedData(
                this.appId, this.shuffleId, this.mapId, this.taskContext.attemptNumber());
        this.updateMapStatus();

        this.sendBufferPool.returnBuffer(this.sendBuffers);
        this.sendBuffers = null;
        this.sendOffsets = null;

        long waitStartTime = System.nanoTime();
        this.rssShuffleClient.mapperEnd(
                this.appId, this.shuffleId, this.mapId, this.taskContext.attemptNumber(), this.numMappers);
        this.writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);

        BlockManagerId bmId = SparkEnv.get().blockManager().shuffleServerId();
        this.mapStatus = SparkUtils.createMapStatus(
                bmId, SparkUtils.unwrap(this.mapStatusLengths), this.mapStatusRecords);
    }

    /**
     * Moves the pending per-partition record counts into {@link #mapStatusRecords} and adds
     * their total to the records-written metric.
     */
    private void updateMapStatus() {
        long recordsWritten = 0L;
        for (int i = 0; i < this.numPartitions; i++) {
            long pending = this.tmpRecords[i];
            this.mapStatusRecords[i] += pending;
            recordsWritten += pending;
            this.tmpRecords[i] = 0L;
        }
        this.writeMetrics.incRecordsWritten(recordsWritten);
    }

    /**
     * Stops the writer and always cleans up the client state for this map attempt.
     *
     * @return the map status on the first successful stop; empty on failure or on any repeated
     *     call
     * @throws IllegalStateException if {@code success} is true but {@link #write} never
     *     completed
     */
    @Override
    public Option<MapStatus> stop(boolean success) {
        try {
            if (this.stopping) {
                return Option.apply(null);
            }
            this.stopping = true;
            // Report peak memory only on the first stop() so repeated calls do not double count.
            this.taskContext.taskMetrics().incPeakExecutionMemory(this.peakMemoryUsedBytes);
            if (!success) {
                return Option.apply(null);
            }
            if (this.mapStatus == null) {
                throw new IllegalStateException("Cannot call stop(true) without having called write()");
            }
            return Option.apply(this.mapStatus);
        } finally {
            this.rssShuffleClient.cleanup(
                    this.appId, this.shuffleId, this.mapId, this.taskContext.attemptNumber());
        }
    }
}

