/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ratis.netty.server;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.ratis.client.DataStreamOutputRpc;
import org.apache.ratis.client.impl.ClientProtoUtils;
import org.apache.ratis.conf.RaftProperties;
import org.apache.ratis.datastream.impl.DataStreamReplyByteBuffer;
import org.apache.ratis.io.StandardWriteOption;
import org.apache.ratis.io.WriteOption;
import org.apache.ratis.metrics.Timekeeper;
import org.apache.ratis.netty.metrics.NettyServerStreamRpcMetrics;
import org.apache.ratis.netty.server.DataStreamRequestByteBuf;
import org.apache.ratis.proto.RaftProtos;
import org.apache.ratis.protocol.ClientId;
import org.apache.ratis.protocol.ClientInvocationId;
import org.apache.ratis.protocol.DataStreamReply;
import org.apache.ratis.protocol.RaftClientReply;
import org.apache.ratis.protocol.RaftClientRequest;
import org.apache.ratis.protocol.RaftGroupId;
import org.apache.ratis.protocol.RaftPeer;
import org.apache.ratis.protocol.RaftPeerId;
import org.apache.ratis.protocol.RoutingTable;
import org.apache.ratis.protocol.exceptions.AlreadyExistsException;
import org.apache.ratis.protocol.exceptions.DataStreamException;
import org.apache.ratis.server.RaftConfiguration;
import org.apache.ratis.server.RaftServer;
import org.apache.ratis.server.RaftServerConfigKeys;
import org.apache.ratis.statemachine.StateMachine;
import org.apache.ratis.thirdparty.io.netty.buffer.ByteBuf;
import org.apache.ratis.thirdparty.io.netty.channel.ChannelHandlerContext;
import org.apache.ratis.thirdparty.io.netty.channel.ChannelId;
import org.apache.ratis.util.ConcurrentUtils;
import org.apache.ratis.util.JavaUtils;
import org.apache.ratis.util.MemoizedSupplier;
import org.apache.ratis.util.Preconditions;
import org.apache.ratis.util.ReferenceCountedObject;
import org.apache.ratis.util.TimeDuration;
import org.apache.ratis.util.TimeoutExecutor;
import org.apache.ratis.util.function.CheckedBiFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Server-side bookkeeping for Netty data streams.
 *
 * <p>Each client stream is keyed by a {@link ClientInvocationId} and tracked by a {@link StreamInfo}.
 * The {@link StreamInfo} holds the local {@link StateMachine.DataStream} and one {@link RemoteStream}
 * per successor peer. A STREAM_DATA packet is written locally and forwarded to every remote. The
 * client gets a reply once all of those writes finish. Packets of one stream are handled in order by
 * chaining them on {@link StreamInfo#getPrevious()}.
 */
public class DataStreamManagement {
    public static final Logger LOG = LoggerFactory.getLogger(DataStreamManagement.class);
    private final RaftServer server;
    private final String name;
    private final StreamMap streams = new StreamMap();
    private final ChannelMap channels;
    /** Runs reply/forwarding work. */
    private final Executor requestExecutor;
    /** Default executor for local writes when the DataStream does not supply its own. */
    private final Executor writeExecutor;
    private final NettyServerStreamRpcMetrics nettyServerStreamRpcMetrics;

    DataStreamManagement(RaftServer server, NettyServerStreamRpcMetrics metrics) {
        this.server = server;
        this.name = server.getId() + "-" + JavaUtils.getClassSimpleName(this.getClass());
        this.channels = new ChannelMap();
        final RaftProperties properties = server.getProperties();
        final boolean useCachedThreadPool = RaftServerConfigKeys.DataStream.asyncRequestThreadPoolCached(properties);
        this.requestExecutor = ConcurrentUtils.newThreadPoolWithMax(useCachedThreadPool,
            RaftServerConfigKeys.DataStream.asyncRequestThreadPoolSize(properties), this.name + "-request-");
        this.writeExecutor = ConcurrentUtils.newThreadPoolWithMax(useCachedThreadPool,
            RaftServerConfigKeys.DataStream.asyncWriteThreadPoolSize(properties), this.name + "-write-");
        this.nettyServerStreamRpcMetrics = metrics;
    }

    /** Asks the state machine to open a DataStream for {@code request}, recording STATE_MACHINE_STREAM metrics. */
    private CompletableFuture<StateMachine.DataStream> stream(RaftClientRequest request, StateMachine stateMachine) {
        final NettyServerStreamRpcMetrics.RequestMetrics metrics =
            this.getMetrics().newRequestMetrics(NettyServerStreamRpcMetrics.RequestType.STATE_MACHINE_STREAM);
        final Timekeeper.Context context = metrics.start();
        return stateMachine.data().stream(request).whenComplete((r, e) -> metrics.stop(context, e == null));
    }

    /**
     * Registers a new DataStream future in the division's data stream map, then starts opening the stream.
     *
     * @throws AlreadyExistsException if a DataStream is already registered for the request's invocation id
     * @throws IOException if the division cannot be found
     */
    private CompletableFuture<StateMachine.DataStream> computeDataStreamIfAbsent(RaftClientRequest request) throws IOException {
        final RaftServer.Division division = this.server.getDivision(request.getRaftGroupId());
        final ClientInvocationId invocationId = ClientInvocationId.valueOf(request);
        final CompletableFuture<StateMachine.DataStream> created = new CompletableFuture<>();
        final CompletableFuture<StateMachine.DataStream> returned = division.getDataStreamMap().computeIfAbsent(invocationId, key -> created);
        if (returned != created) {
            throw new AlreadyExistsException("A DataStream already exists for " + invocationId);
        }
        this.stream(request, division.getStateMachine()).whenComplete(JavaUtils.asBiConsumer(created));
        return created;
    }

    /**
     * Builds a {@link StreamInfo} from a STREAM_HEADER payload (a serialized RaftClientRequestProto).
     *
     * @throws CompletionException wrapping any failure
     */
    private StreamInfo newStreamInfo(ByteBuf buf, CheckedBiFunction<RaftClientRequest, Set<RaftPeer>, Set<DataStreamOutputRpc>, IOException> getStreams) {
        try {
            final RaftClientRequest request = ClientProtoUtils.toRaftClientRequest(
                RaftProtos.RaftClientRequestProto.parseFrom(buf.nioBuffer()));
            final boolean isPrimary = this.server.getId().equals(request.getServerId());
            return new StreamInfo(request, isPrimary, this.computeDataStreamIfAbsent(request), this.server,
                getStreams, this.getMetrics()::newRequestMetrics);
        } catch (Throwable e) {
            throw new CompletionException(e);
        }
    }

    /**
     * Atomically appends {@code function} to the future chain held in {@code future} and runs it on
     * {@code executor}. Successive calls therefore run in order. Returns the new tail of the chain.
     */
    static <T> CompletableFuture<T> composeAsync(AtomicReference<CompletableFuture<T>> future, Executor executor, Function<T, CompletableFuture<T>> function) {
        return future.updateAndGet(previous -> previous.thenComposeAsync(function, executor));
    }

    /** Runs {@link #writeTo} on the stream's own executor, or on {@code defaultExecutor} if the stream has none. */
    static CompletableFuture<Long> writeToAsync(ByteBuf buf, Iterable<WriteOption> options, StateMachine.DataStream stream, Executor defaultExecutor) {
        final Executor e = Optional.ofNullable(stream.getExecutor()).orElse(defaultExecutor);
        return CompletableFuture.supplyAsync(() -> writeTo(buf, options, stream), e);
    }

    /**
     * Writes every NIO buffer of {@code buf} to the stream's data channel. Then it forces the channel
     * if {@code options} contains SYNC, and closes the stream if {@code options} contains CLOSE.
     *
     * @return the total number of bytes reported written by the channel
     * @throws CompletionException wrapping any write, force or close failure
     */
    static long writeTo(ByteBuf buf, Iterable<WriteOption> options, StateMachine.DataStream stream) {
        final StateMachine.DataChannel channel = stream.getDataChannel();
        long byteWritten = 0L;
        for (ByteBuffer buffer : buf.nioBuffers()) {
            // Expose each NIO buffer as a reference-counted object backed by buf's own retain/release.
            final ReferenceCountedObject<ByteBuffer> wrapped = ReferenceCountedObject.wrap(buffer, buf::retain, buf::release);
            wrapped.retain();
            try {
                byteWritten += (long) channel.write(wrapped);
            } catch (Throwable t) {
                throw new CompletionException(t);
            } finally {
                wrapped.release();
            }
        }
        if (WriteOption.containsOption(options, (WriteOption) StandardWriteOption.SYNC)) {
            try {
                channel.force(false);
            } catch (IOException e) {
                throw new CompletionException(e);
            }
        }
        if (WriteOption.containsOption(options, (WriteOption) StandardWriteOption.CLOSE)) {
            close(stream);
        }
        return byteWritten;
    }

    /** Closes the stream's data channel, wrapping an {@link IOException} in a {@link CompletionException}. */
    static void close(StateMachine.DataStream stream) {
        try {
            stream.getDataChannel().close();
        } catch (IOException e) {
            throw new CompletionException("Failed to close " + stream, e);
        }
    }

    static DataStreamReplyByteBuffer newDataStreamReplyByteBuffer(DataStreamRequestByteBuf request, RaftClientReply reply) {
        final ByteBuffer buffer = ClientProtoUtils.toRaftClientReplyProto(reply).toByteString().asReadOnlyByteBuffer();
        return DataStreamReplyByteBuffer.newBuilder()
            .setDataStreamPacket(request)
            .setBuffer(buffer)
            .setSuccess(reply.isSuccess())
            .build();
    }

    /**
     * Writes the reply for {@code request}. The reply is successful only if every remote write succeeded
     * and reported {@code bytesWritten} bytes. The byte count is included only on success.
     */
    static void sendReply(List<CompletableFuture<DataStreamReply>> remoteWrites, DataStreamRequestByteBuf request, long bytesWritten, Collection<RaftProtos.CommitInfoProto> commitInfos, ChannelHandlerContext ctx) {
        final boolean success = checkSuccessRemoteWrite(remoteWrites, bytesWritten, request);
        final DataStreamReplyByteBuffer.Builder builder = DataStreamReplyByteBuffer.newBuilder()
            .setDataStreamPacket(request)
            .setSuccess(success)
            .setCommitInfos(commitInfos);
        if (success) {
            builder.setBytesWritten(bytesWritten);
        }
        ctx.writeAndFlush(builder.build());
    }

    /** Builds a reply to {@code raftClientRequest} carrying a {@link DataStreamException} for {@code cause}. */
    private static RaftClientReply newDataStreamExceptionReply(RaftServer server, Throwable cause, RaftClientRequest raftClientRequest) {
        return RaftClientReply.newBuilder()
            .setRequest(raftClientRequest)
            .setException(new DataStreamException(server.getId(), cause))
            .build();
    }

    /** Replies {@code cause} for a known client request and releases {@code request}. */
    static void replyDataStreamException(RaftServer server, Throwable cause, RaftClientRequest raftClientRequest, DataStreamRequestByteBuf request, ChannelHandlerContext ctx) {
        sendDataStreamException(cause, request, newDataStreamExceptionReply(server, cause, raftClientRequest), ctx);
    }

    /** Replies {@code cause} when no client request is available (empty client/group ids) and releases {@code request}. */
    void replyDataStreamException(Throwable cause, DataStreamRequestByteBuf request, ChannelHandlerContext ctx) {
        final RaftClientReply reply = RaftClientReply.newBuilder()
            .setClientId(ClientId.emptyClientId())
            .setServerId(this.server.getId())
            .setGroupId(RaftGroupId.emptyGroupId())
            .setException(new DataStreamException(this.server.getId(), cause))
            .build();
        sendDataStreamException(cause, request, reply, ctx);
    }

    /** Logs {@code throwable}, writes the error {@code reply} and always releases {@code request}. */
    static void sendDataStreamException(Throwable throwable, DataStreamRequestByteBuf request, RaftClientReply reply, ChannelHandlerContext ctx) {
        try {
            writeDataStreamException(throwable, request, reply, ctx);
        } finally {
            request.release();
        }
    }

    /**
     * Logs {@code throwable} and writes the error {@code reply}. It does NOT release {@code request}.
     * A failure to write the reply is logged, not thrown.
     */
    private static void writeDataStreamException(Throwable throwable, DataStreamRequestByteBuf request, RaftClientReply reply, ChannelHandlerContext ctx) {
        LOG.warn("Failed to process {}", request, throwable);
        try {
            ctx.writeAndFlush(newDataStreamReplyByteBuffer(request, reply));
        } catch (Throwable t) {
            LOG.warn("Failed to sendDataStreamException {} for {}", throwable, request, t);
        }
    }

    /** Removes the given streams and cleans up their local DataStreams. */
    void cleanUp(Set<ClientInvocationId> ids) {
        for (ClientInvocationId id : ids) {
            Optional.ofNullable(this.streams.remove(id)).map(StreamInfo::getLocal).ifPresent(LocalStream::cleanUp);
        }
    }

    /** Forgets the channel and, after {@code channelInactiveGracePeriod}, cleans up the streams it was tracking. */
    void cleanUpOnChannelInactive(ChannelId channelId, TimeDuration channelInactiveGracePeriod) {
        Optional.ofNullable(this.channels.remove(channelId)).ifPresent(ids -> {
            LOG.info("Channel {} is inactive, cleanup clientInvocationIds={}", channelId, ids);
            TimeoutExecutor.getInstance().onTimeout(channelInactiveGracePeriod, () -> this.cleanUp(ids), LOG,
                () -> "Timeout check failed, clientInvocationIds=" + ids);
        });
    }

    /** Entry point for an incoming packet. A synchronous failure is replied to the client and the request released. */
    void read(DataStreamRequestByteBuf request, ChannelHandlerContext ctx, CheckedBiFunction<RaftClientRequest, Set<RaftPeer>, Set<DataStreamOutputRpc>, IOException> getStreams) {
        LOG.debug("{}: read {}", this, request);
        try {
            this.readImpl(request, ctx, getStreams);
        } catch (Throwable t) {
            this.replyDataStreamException(t, request, ctx);
        }
    }

    /**
     * Resolves the {@link StreamInfo} for the packet and starts the local and remote writes. It then
     * chains the reply behind the previous packet of the same stream. Once this method returns
     * normally, {@code request} is released exactly once when the chained work completes.
     */
    private void readImpl(DataStreamRequestByteBuf request, ChannelHandlerContext ctx, CheckedBiFunction<RaftClientRequest, Set<RaftPeer>, Set<DataStreamOutputRpc>, IOException> getStreams) {
        final RaftProtos.DataStreamPacketHeaderProto.Type type = request.getType();
        final boolean close = request.getWriteOptionList().contains(StandardWriteOption.CLOSE);
        final ClientInvocationId key = ClientInvocationId.valueOf(request.getClientId(), request.getStreamId());
        // Track the key per channel so that cleanUpOnChannelInactive can find it.
        final ChannelId channelId = ctx.channel().id();
        this.channels.add(channelId, key);

        final StreamInfo info;
        if (type == RaftProtos.DataStreamPacketHeaderProto.Type.STREAM_HEADER) {
            final MemoizedSupplier<StreamInfo> supplier = JavaUtils.memoize(() -> this.newStreamInfo(request.slice(), getStreams));
            info = this.streams.computeIfAbsent(key, id -> supplier.get());
            if (!supplier.isInitialized()) {
                // A stream with this key already existed: drop it too, then fail this header.
                Optional.ofNullable(this.streams.remove(key)).map(StreamInfo::getLocal).ifPresent(LocalStream::cleanUp);
                throw new IllegalStateException("Failed to create a new stream for " + request + " since a stream already exists Key: " + key + " StreamInfo:" + info);
            }
            this.getMetrics().onRequestCreate(NettyServerStreamRpcMetrics.RequestType.HEADER);
        } else if (close) {
            info = Optional.ofNullable(this.streams.remove(key))
                .orElseThrow(() -> new IllegalStateException("Failed to remove StreamInfo for " + request));
        } else {
            info = Optional.ofNullable(this.streams.get(key))
                .orElseThrow(() -> new IllegalStateException("Failed to get StreamInfo for " + request));
        }

        final CompletableFuture<Long> localWrite;
        final List<CompletableFuture<DataStreamReply>> remoteWrites;
        if (type == RaftProtos.DataStreamPacketHeaderProto.Type.STREAM_HEADER) {
            localWrite = CompletableFuture.completedFuture(0L);
            remoteWrites = Collections.emptyList();
        } else if (type == RaftProtos.DataStreamPacketHeaderProto.Type.STREAM_DATA) {
            localWrite = info.getLocal().write(request.slice(), request.getWriteOptionList(), this.writeExecutor);
            remoteWrites = info.applyToRemotes(out -> out.write(request, this.requestExecutor));
        } else {
            throw new IllegalStateException(this + ": Unexpected type " + type + ", request=" + request);
        }

        // Only STREAM_HEADER and STREAM_DATA reach this point, so no type check is needed inside.
        composeAsync(info.getPrevious(), this.requestExecutor, n -> JavaUtils.allOf(remoteWrites).thenCombineAsync(localWrite, (v, bytesWritten) -> {
            sendReply(remoteWrites, request, bytesWritten, info.getCommitInfos(), ctx);
            return null;
        }, this.requestExecutor)).whenComplete((v, exception) -> {
            try {
                if (exception != null) {
                    this.handleReadFailure(exception, info, close, key, request, ctx);
                }
            } finally {
                // Single release point for both the success and the failure path.
                request.release();
                this.channels.remove(channelId, key);
            }
        });
    }

    /**
     * Replies {@code exception} to the client and cleans up the failed stream. It does not release
     * {@code request}; the caller releases it.
     *
     * <p>Whoever removed the {@link StreamInfo} from the map owns its clean-up. A CLOSE packet already
     * removed {@code info} in {@link #readImpl}, so it cleans {@code info} up here. For other packets, a
     * null result from {@code remove} means another path already took ownership.
     */
    private void handleReadFailure(Throwable exception, StreamInfo info, boolean close, ClientInvocationId key, DataStreamRequestByteBuf request, ChannelHandlerContext ctx) {
        final StreamInfo removed = this.streams.remove(key);
        final StreamInfo toCleanUp = removed != null ? removed : (close ? info : null);
        try {
            writeDataStreamException(exception, request, newDataStreamExceptionReply(this.server, exception, info.getRequest()), ctx);
        } finally {
            if (toCleanUp != null) {
                toCleanUp.getLocal().cleanUp();
            }
        }
    }

    static void assertReplyCorrespondingToRequest(DataStreamRequestByteBuf request, DataStreamReply reply) {
        Preconditions.assertTrue(request.getClientId().equals(reply.getClientId()));
        Preconditions.assertTrue(request.getType() == reply.getType());
        Preconditions.assertTrue(request.getStreamId() == reply.getStreamId());
        Preconditions.assertTrue(request.getStreamOffset() == reply.getStreamOffset());
    }

    /**
     * Returns true iff every remote reply is successful and reports {@code bytesWritten} bytes.
     * It calls {@code join()}, so the caller should pass futures that have already completed.
     */
    static boolean checkSuccessRemoteWrite(List<CompletableFuture<DataStreamReply>> replyFutures, long bytesWritten, DataStreamRequestByteBuf request) {
        for (CompletableFuture<DataStreamReply> replyFuture : replyFutures) {
            final DataStreamReply reply = replyFuture.join();
            assertReplyCorrespondingToRequest(request, reply);
            if (!reply.isSuccess()) {
                LOG.warn("reply is not success, request: {}", request);
                return false;
            }
            if (reply.getBytesWritten() != bytesWritten) {
                LOG.warn("reply written bytes not match, local size: {} remote size: {} request: {}", bytesWritten, reply.getBytesWritten(), request);
                return false;
            }
        }
        return true;
    }

    NettyServerStreamRpcMetrics getMetrics() {
        return this.nettyServerStreamRpcMetrics;
    }

    @Override
    public String toString() {
        return this.name;
    }

    /** Tracks, per Netty channel, the stream keys that have packets registered on that channel. */
    public static class ChannelMap {
        private final Map<ChannelId, Map<ClientInvocationId, ClientInvocationId>> map = new ConcurrentHashMap<>();

        public void add(ChannelId channelId, ClientInvocationId clientInvocationId) {
            this.map.computeIfAbsent(channelId, e -> new ConcurrentHashMap<>()).put(clientInvocationId, clientInvocationId);
        }

        public void remove(ChannelId channelId, ClientInvocationId clientInvocationId) {
            Optional.ofNullable(this.map.get(channelId)).ifPresent(ids -> ids.remove(clientInvocationId));
        }

        /** Removes the channel and returns its keys, or an empty set if the channel is unknown. */
        public Set<ClientInvocationId> remove(ChannelId channelId) {
            return Optional.ofNullable(this.map.remove(channelId)).map(Map::keySet).orElse(Collections.emptySet());
        }
    }

    /** A concurrent map from stream key to {@link StreamInfo} that debug-logs every access. */
    static class StreamMap {
        private final ConcurrentMap<ClientInvocationId, StreamInfo> map = new ConcurrentHashMap<>();

        StreamMap() {
        }

        StreamInfo computeIfAbsent(ClientInvocationId key, Function<ClientInvocationId, StreamInfo> function) {
            final StreamInfo info = this.map.computeIfAbsent(key, function);
            LOG.debug("computeIfAbsent({}) returns {}", key, info);
            return info;
        }

        StreamInfo get(ClientInvocationId key) {
            final StreamInfo info = this.map.get(key);
            LOG.debug("get({}) returns {}", key, info);
            return info;
        }

        StreamInfo remove(ClientInvocationId key) {
            final StreamInfo info = this.map.remove(key);
            LOG.debug("remove({}) returns {}", key, info);
            return info;
        }
    }

    /** State of one stream: its header request, local stream, remote streams and the per-stream packet chain. */
    static class StreamInfo {
        private final RaftClientRequest request;
        private final boolean primary;
        private final LocalStream local;
        private final Set<RemoteStream> remotes;
        private final RaftServer server;
        /** Tail of the chain of packet-processing futures for this stream. */
        private final AtomicReference<CompletableFuture<Void>> previous = new AtomicReference<>(CompletableFuture.completedFuture(null));

        StreamInfo(RaftClientRequest request, boolean primary, CompletableFuture<StateMachine.DataStream> stream, RaftServer server, CheckedBiFunction<RaftClientRequest, Set<RaftPeer>, Set<DataStreamOutputRpc>, IOException> getStreams, Function<NettyServerStreamRpcMetrics.RequestType, NettyServerStreamRpcMetrics.RequestMetrics> metricsConstructor) throws IOException {
            this.request = request;
            this.primary = primary;
            this.local = new LocalStream(stream, metricsConstructor.apply(NettyServerStreamRpcMetrics.RequestType.LOCAL_WRITE));
            this.server = server;
            // getSuccessors reads request, primary and server, so they must be assigned first.
            final Set<RaftPeer> successors = this.getSuccessors(server.getId());
            final Set<DataStreamOutputRpc> outs = getStreams.apply(request, successors);
            this.remotes = outs.stream()
                .map(o -> new RemoteStream(o, metricsConstructor.apply(NettyServerStreamRpcMetrics.RequestType.REMOTE_WRITE)))
                .collect(Collectors.toSet());
        }

        AtomicReference<CompletableFuture<Void>> getPrevious() {
            return this.previous;
        }

        RaftClientRequest getRequest() {
            return this.request;
        }

        RaftServer.Division getDivision() throws IOException {
            return this.server.getDivision(this.request.getRaftGroupId());
        }

        /** @throws IllegalStateException wrapping an {@link IOException} from {@link #getDivision()} */
        Collection<RaftProtos.CommitInfoProto> getCommitInfos() {
            try {
                return this.getDivision().getCommitInfos();
            } catch (IOException e) {
                throw new IllegalStateException(e);
            }
        }

        boolean isPrimary() {
            return this.primary;
        }

        LocalStream getLocal() {
            return this.local;
        }

        <T> List<T> applyToRemotes(Function<RemoteStream, T> function) {
            return this.remotes.isEmpty() ? Collections.emptyList() : this.remotes.stream().map(function).collect(Collectors.toList());
        }

        @Override
        public String toString() {
            return JavaUtils.getClassSimpleName(this.getClass()) + ":" + this.request;
        }

        /**
         * Chooses the peers to forward to. It uses the request's routing table when present. Otherwise
         * the primary forwards to every other current peer and a non-primary forwards to no one.
         */
        private Set<RaftPeer> getSuccessors(RaftPeerId peerId) throws IOException {
            final RaftConfiguration conf = this.getDivision().getRaftConf();
            final RoutingTable routingTable = this.request.getRoutingTable();
            if (routingTable != null) {
                // NOTE(review): conf.getPeer may return null for ids outside the configuration — TODO confirm.
                return routingTable.getSuccessors(peerId).stream().map(id -> conf.getPeer(id)).collect(Collectors.toSet());
            }
            if (this.isPrimary()) {
                return conf.getCurrentPeers().stream().filter(p -> !p.getId().equals(this.server.getId())).collect(Collectors.toSet());
            }
            return Collections.emptySet();
        }
    }

    /** Forwards packets to one successor peer, in order. */
    static class RemoteStream {
        private final DataStreamOutputRpc out;
        private final AtomicReference<CompletableFuture<DataStreamReply>> sendFuture = new AtomicReference<>(CompletableFuture.completedFuture(null));
        private final NettyServerStreamRpcMetrics.RequestMetrics metrics;

        RemoteStream(DataStreamOutputRpc out, NettyServerStreamRpcMetrics.RequestMetrics metrics) {
            this.metrics = metrics;
            this.out = out;
        }

        CompletableFuture<DataStreamReply> write(DataStreamRequestByteBuf request, Executor executor) {
            final Timekeeper.Context context = this.metrics.start();
            return composeAsync(this.sendFuture, executor,
                n -> this.out.writeAsync(request.slice().nioBuffer(), request.getWriteOptionList())
                    .whenComplete((l, e) -> this.metrics.stop(context, e == null)));
        }
    }

    /** Writes packets to the local state-machine DataStream, in order, once the stream has been opened. */
    static class LocalStream {
        private final CompletableFuture<StateMachine.DataStream> streamFuture;
        private final AtomicReference<CompletableFuture<Long>> writeFuture;
        private final NettyServerStreamRpcMetrics.RequestMetrics metrics;

        LocalStream(CompletableFuture<StateMachine.DataStream> streamFuture, NettyServerStreamRpcMetrics.RequestMetrics metrics) {
            this.streamFuture = streamFuture;
            this.writeFuture = new AtomicReference<>(streamFuture.thenApply(s -> 0L));
            this.metrics = metrics;
        }

        CompletableFuture<Long> write(ByteBuf buf, Iterable<WriteOption> options, Executor executor) {
            final Timekeeper.Context context = this.metrics.start();
            return composeAsync(this.writeFuture, executor,
                n -> this.streamFuture.thenCompose(stream -> writeToAsync(buf, options, stream, executor)
                    .whenComplete((l, e) -> this.metrics.stop(context, e == null))));
        }

        /** Cleans up the DataStream once it has been opened; a stream that failed to open is skipped. */
        void cleanUp() {
            this.streamFuture.thenAccept(StateMachine.DataStream::cleanUp);
        }
    }
}

