/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.ratis.client.impl;

import org.apache.ratis.client.RaftClient;
import org.apache.ratis.client.RaftClientRpc;
import org.apache.ratis.conf.RaftProperties;
import org.apache.ratis.proto.RaftProtos.RaftClientRequestProto.TypeCase;
import org.apache.ratis.proto.RaftProtos.ReplicationLevel;
import org.apache.ratis.proto.RaftProtos.SlidingWindowEntry;
import org.apache.ratis.protocol.*;
import org.apache.ratis.retry.RetryPolicy;
import org.apache.ratis.util.CollectionUtils;
import org.apache.ratis.util.JavaUtils;
import org.apache.ratis.util.Preconditions;
import org.apache.ratis.util.TimeoutScheduler;

import java.io.IOException;
import java.io.InterruptedIOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Stream;

/**
 * A client who sends requests to a raft service.
 *
 * <p>Blocking calls are sent through the {@link RaftClientRpc} directly, while asynchronous calls
 * are delegated to {@link OrderedAsync} (write, read and stale-read requests)
 * or {@link UnorderedAsync} (watch requests).
 */
final class RaftClientImpl implements RaftClient {
  /** The source of call ids; it is static and therefore shared by all clients in this JVM. */
  private static final AtomicLong CALL_ID_COUNTER = new AtomicLong();

  /** @return the next call id; the sign bit is masked off so that the id is never negative. */
  static long nextCallId() {
    return CALL_ID_COUNTER.getAndIncrement() & Long.MAX_VALUE;
  }

  /**
   * A request waiting for its reply.
   * Each call of {@link #newRequest()} is counted as one attempt.
   */
  abstract static class PendingClientRequest {
    private final CompletableFuture<RaftClientReply> replyFuture = new CompletableFuture<>();
    private final AtomicInteger attemptCount = new AtomicInteger();

    /** Create the request object to be sent for the current attempt. */
    abstract RaftClientRequest newRequestImpl();

    /** Increase the attempt count and then create a request object by {@link #newRequestImpl()}. */
    final RaftClientRequest newRequest() {
      attemptCount.incrementAndGet();
      return newRequestImpl();
    }

    /** @return the future of the reply of this request. */
    CompletableFuture<RaftClientReply> getReplyFuture() {
      return replyFuture;
    }

    /** @return the number of times that {@link #newRequest()} has been called. */
    int getAttemptCount() {
      return attemptCount.get();
    }
  }

  private final ClientId clientId;
  private final RaftClientRpc clientRpc;
  /** The known peers; the content is replaced when a {@link NotLeaderException} carries a peer list. */
  private final Collection<RaftPeer> peers;
  private final RaftGroupId groupId;
  private final RetryPolicy retryPolicy;

  /** The server that requests are sent to when the caller does not specify one. */
  private volatile RaftPeerId leaderId;

  private final TimeoutScheduler scheduler;

  /** Created lazily on the first ordered asynchronous call. */
  private final Supplier<OrderedAsync> orderedAsync;

  /**
   * @param clientId the id of this client.
   * @param group the group providing the group id and the initial peers.
   * @param leaderId the initial leader; if it is null, the first peer (if there is any) is used.
   * @param clientRpc the rpc used for sending requests; the peers of the group are added to it.
   * @param properties passed to {@link OrderedAsync} when it is created.
   * @param retryPolicy the policy for retrying blocking requests; must be non-null.
   */
  RaftClientImpl(ClientId clientId, RaftGroup group, RaftPeerId leaderId,
      RaftClientRpc clientRpc, RaftProperties properties, RetryPolicy retryPolicy) {
    this.clientId = clientId;
    this.clientRpc = clientRpc;
    this.peers = new ConcurrentLinkedQueue<>(group.getPeers());
    this.groupId = group.getGroupId();
    if (leaderId != null) {
      this.leaderId = leaderId;
    } else if (!peers.isEmpty()) {
      // No leader is given: start with the first known peer.
      this.leaderId = peers.iterator().next().getId();
    } else {
      this.leaderId = null;
    }
    Preconditions.assertTrue(retryPolicy != null, "retry policy can't be null");
    this.retryPolicy = retryPolicy;

    scheduler = TimeoutScheduler.newInstance(0);
    clientRpc.addServers(peers);

    this.orderedAsync = JavaUtils.memoize(() -> new OrderedAsync(this, properties));
  }

  @Override
  public ClientId getId() {
    return clientId;
  }

  RetryPolicy getRetryPolicy() {
    return retryPolicy;
  }

  TimeoutScheduler getScheduler() {
    return scheduler;
  }

  /** @return the {@link OrderedAsync} of this client, creating it if necessary. */
  OrderedAsync getOrderedAsync() {
    return orderedAsync.get();
  }

  @Override
  public CompletableFuture<RaftClientReply> sendAsync(Message message) {
    return sendAsync(RaftClientRequest.writeRequestType(), message, null);
  }

  @Override
  public CompletableFuture<RaftClientReply> sendReadOnlyAsync(Message message) {
    return sendAsync(RaftClientRequest.readRequestType(), message, null);
  }

  @Override
  public CompletableFuture<RaftClientReply> sendStaleReadAsync(Message message, long minIndex, RaftPeerId server) {
    return sendAsync(RaftClientRequest.staleReadRequestType(minIndex), message, server);
  }

  @Override
  public CompletableFuture<RaftClientReply> sendWatchAsync(long index, ReplicationLevel replication) {
    return UnorderedAsync.send(RaftClientRequest.watchRequestType(index, replication), this);
  }

  /** Send the given message asynchronously via {@link OrderedAsync}. */
  private CompletableFuture<RaftClientReply> sendAsync(
      RaftClientRequest.Type type, Message message, RaftPeerId server) {
    return getOrderedAsync().send(type, message, server);
  }

  /**
   * Create a request for the group of this client.
   *
   * @param server the server to send the request to; if it is null, the current leader is used.
   */
  RaftClientRequest newRaftClientRequest(
      RaftPeerId server, long callId, Message message, RaftClientRequest.Type type,
      SlidingWindowEntry slidingWindowEntry) {
    final RaftPeerId target = server != null ? server : leaderId;
    return new RaftClientRequest(clientId, target, groupId, callId, message, type, slidingWindowEntry);
  }

  @Override
  public RaftClientReply send(Message message) throws IOException {
    return send(RaftClientRequest.writeRequestType(), message, null);
  }

  @Override
  public RaftClientReply sendReadOnly(Message message) throws IOException {
    return send(RaftClientRequest.readRequestType(), message, null);
  }

  @Override
  public RaftClientReply sendStaleRead(Message message, long minIndex, RaftPeerId server)
      throws IOException {
    return send(RaftClientRequest.staleReadRequestType(minIndex), message, server);
  }

  @Override
  public RaftClientReply sendWatch(long index, ReplicationLevel replication) throws IOException {
    return send(RaftClientRequest.watchRequestType(index, replication), null, null);
  }

  /**
   * Send a blocking request with retry.
   * The same call id is used for all the attempts.
   *
   * @param message must be non-null unless the type is {@link TypeCase#WATCH}.
   * @param server the server to send to; if it is null, the current leader at each attempt is used.
   */
  private RaftClientReply send(RaftClientRequest.Type type, Message message, RaftPeerId server)
      throws IOException {
    // Only watch requests are allowed to have no message.
    if (!type.is(TypeCase.WATCH)) {
      Objects.requireNonNull(message, "message == null");
    }

    final long callId = nextCallId();
    return sendRequestWithRetry(() -> newRaftClientRequest(server, callId, message, type, null));
  }

  // TODO: change peersInNewConf to List<RaftPeer>
  @Override
  public RaftClientReply setConfiguration(RaftPeer[] peersInNewConf)
      throws IOException {
    Objects.requireNonNull(peersInNewConf, "peersInNewConf == null");

    final long callId = nextCallId();
    // also refresh the rpc proxies for these peers
    addServers(Arrays.stream(peersInNewConf));
    // The request is rebuilt at each attempt so that it is sent to the current leader.
    return sendRequestWithRetry(() -> new SetConfigurationRequest(
        clientId, leaderId, groupId, callId, Arrays.asList(peersInNewConf)));
  }

  @Override
  public RaftClientReply groupAdd(RaftGroup newGroup, RaftPeerId server) throws IOException {
    Objects.requireNonNull(newGroup, "newGroup == null");
    Objects.requireNonNull(server, "server == null");

    final long callId = nextCallId();
    addServers(newGroup.getPeers().stream());
    return sendRequest(GroupManagementRequest.newAdd(clientId, server, callId, newGroup));
  }

  @Override
  public RaftClientReply groupRemove(RaftGroupId grpId, boolean deleteDirectory, RaftPeerId server)
      throws IOException {
    // Validate the parameter grpId but not the field groupId of this client.
    Objects.requireNonNull(grpId, "groupId == null");
    Objects.requireNonNull(server, "server == null");

    final long callId = nextCallId();
    return sendRequest(GroupManagementRequest.newRemove(clientId, server, callId, grpId, deleteDirectory));
  }

  @Override
  public GroupListReply getGroupList(RaftPeerId server) throws IOException {
    Objects.requireNonNull(server, "server == null");

    final RaftClientReply reply = sendRequest(new GroupListRequest(clientId, server, groupId, nextCallId()));
    Preconditions.assertTrue(reply instanceof GroupListReply, () -> "Unexpected reply: " + reply);
    return (GroupListReply)reply;
  }

  @Override
  public GroupInfoReply getGroupInfo(RaftGroupId raftGroupId, RaftPeerId server) throws IOException {
    Objects.requireNonNull(server, "server == null");
    // Fall back to the group of this client when no group id is given.
    final RaftGroupId rgi = raftGroupId == null ? groupId : raftGroupId;
    final RaftClientReply reply = sendRequest(new GroupInfoRequest(clientId, server, rgi, nextCallId()));
    Preconditions.assertTrue(reply instanceof GroupInfoReply, () -> "Unexpected reply: " + reply);
    return (GroupInfoReply)reply;
  }

  /** Add the given peers, excluding the ones already contained in {@link #peers}, to the rpc. */
  private void addServers(Stream<RaftPeer> peersInNewConf) {
    clientRpc.addServers(
        peersInNewConf.filter(p -> !peers.contains(p))::iterator);
  }

  /**
   * Send the request created by the given supplier repeatedly
   * until a non-null reply is received or the retry policy gives up.
   * A new request object is obtained from the supplier at each attempt.
   *
   * @return the (non-null) reply.
   * @throws GroupMismatchException rethrown immediately without retrying.
   * @throws StateMachineException rethrown immediately without retrying.
   * @throws InterruptedIOException if the thread is interrupted while sleeping between attempts;
   *                                the interrupt status of the thread is restored.
   * @throws IOException if the retry policy decides not to retry; see
   *                     {@link #noMoreRetries(RaftClientRequest, int, Throwable)}.
   */
  private RaftClientReply sendRequestWithRetry(Supplier<RaftClientRequest> supplier) throws IOException {
    for(int attemptCount = 1;; attemptCount++) {
      final RaftClientRequest request = supplier.get();
      IOException ioe = null;
      try {
        final RaftClientReply reply = sendRequest(request);
        if (reply != null) {
          return reply;
        }
        // A null reply means that the request has to be retried.
      } catch (GroupMismatchException | StateMachineException e) {
        throw e;
      } catch (IOException e) {
        ioe = e;
      }
      if (!retryPolicy.shouldRetry(attemptCount, request)) {
        throw (IOException)noMoreRetries(request, attemptCount, ioe);
      }

      try {
        retryPolicy.getSleepTime(attemptCount, request).sleep();
      } catch (InterruptedException e) {
        // Restore the interrupt status so that the callers can observe the interruption.
        Thread.currentThread().interrupt();
        final InterruptedIOException iioe = new InterruptedIOException("retry policy=" + retryPolicy);
        iioe.initCause(e);
        throw iioe;
      }
    }
  }

  /**
   * Build the exception to be reported when a request is not retried anymore.
   *
   * @return the given throwable if it is non-null and the request has been attempted only once;
   *         otherwise, a new {@link RaftRetryFailureException} created with the given throwable.
   */
  Throwable noMoreRetries(RaftClientRequest request, int attemptCount, Throwable throwable) {
    if (attemptCount == 1 && throwable != null) {
      return throwable;
    }
    return new RaftRetryFailureException(request, attemptCount, retryPolicy, throwable);
  }

  /**
   * Send the given request once.
   *
   * @return the reply, or null if the request should be retried;
   *         see {@link #handleLeaderException(RaftClientRequest, RaftClientReply, Consumer)}.
   * @throws GroupMismatchException rethrown as is, without updating the leader or the connection.
   * @throws IOException if the rpc fails, after {@link #handleIOException(RaftClientRequest, IOException)}
   *                     has been invoked, or if the reply carries a {@link RaftException}.
   */
  private RaftClientReply sendRequest(RaftClientRequest request) throws IOException {
    LOG.debug("{}: send {}", clientId, request);
    RaftClientReply reply;
    try {
      reply = clientRpc.sendRequest(request);
    } catch (GroupMismatchException gme) {
      throw gme;
    } catch (IOException ioe) {
      handleIOException(request, ioe);
      throw ioe;
    }
    LOG.debug("{}: receive {}", clientId, reply);
    reply = handleLeaderException(request, reply, null);
    reply = handleRaftException(reply, Function.identity());
    return reply;
  }

  /**
   * Throw the exception carried by the given reply, if there is any.
   *
   * @param reply the reply to be checked; may be null.
   * @param converter converts the {@link RaftException} in the reply to the exception to be thrown.
   * @return the same reply if it is null or it has no exception.
   * @throws E the converted exception if the reply has an exception.
   */
  static <E extends Throwable> RaftClientReply handleRaftException(
      RaftClientReply reply, Function<RaftException, E> converter) throws E {
    if (reply == null) {
      return null;
    }
    final RaftException e = reply.getException();
    if (e != null) {
      throw converter.apply(e);
    }
    return reply;
  }

  /**
   * @param handler invoked with the request when the reply has a {@link NotLeaderException}; may be null.
   * @return null if the reply is null or it has
   * {@link NotLeaderException} or {@link LeaderNotReadyException}
   * otherwise return the same reply.
   */
  RaftClientReply handleLeaderException(RaftClientRequest request, RaftClientReply reply,
                                        Consumer<RaftClientRequest> handler) {
    if (reply == null || reply.getException() instanceof LeaderNotReadyException) {
      return null;
    }
    final NotLeaderException nle = reply.getNotLeaderException();
    if (nle == null) {
      return reply;
    }
    return handleNotLeaderException(request, nle, handler);
  }

  /**
   * Refresh the peers using the given exception and then process it as an {@link IOException}
   * together with the leader suggested by the exception, if there is any.
   *
   * @return always null.
   */
  RaftClientReply handleNotLeaderException(RaftClientRequest request, NotLeaderException nle,
      Consumer<RaftClientRequest> handler) {
    refreshPeers(nle.getPeers());
    final RaftPeerId newLeader = nle.getSuggestedLeader() == null ? null
        : nle.getSuggestedLeader().getId();
    handleIOException(request, nle, newLeader, handler);
    return null;
  }

  /** Replace the known peers with the given peers; a null or empty collection is ignored. */
  private void refreshPeers(Collection<RaftPeer> newPeers) {
    if (newPeers == null || newPeers.isEmpty()) {
      return;
    }
    peers.clear();
    peers.addAll(newPeers);
    // also refresh the rpc proxies for these peers
    clientRpc.addServers(newPeers);
  }

  /** The same as handleIOException(request, ioe, null, null). */
  void handleIOException(RaftClientRequest request, IOException ioe) {
    handleIOException(request, ioe, null, null);
  }

  /**
   * Process a failure of the given request: possibly change the leader and/or reconnect.
   *
   * <p>The handler, if it is non-null, is always invoked first.
   * Nothing else is done for a {@link LeaderNotReadyException}.
   * Otherwise, the leader is changed only if the request was sent to the current leader,
   * a new leader is available and the exception is neither an {@link AlreadyClosedException}
   * nor a {@link TimeoutIOException}.
   * The rpc is notified when the leader is to be changed or when it asks for reconnecting.
   *
   * @param request the failed request.
   * @param ioe the failure.
   * @param newLeader the suggested leader; if it is null and the request was sent to the current leader,
   *                  a candidate is chosen from the known peers by
   *                  {@code CollectionUtils.random(oldLeader, ...)}
   *                  (presumably a random peer other than the old leader -- verify against CollectionUtils).
   * @param handler a callback to be invoked with the request; may be null.
   */
  void handleIOException(RaftClientRequest request, IOException ioe,
      RaftPeerId newLeader, Consumer<RaftClientRequest> handler) {
    LOG.debug("{}: suggested new leader: {}. Failed {} with {}",
        clientId, newLeader, request, ioe);
    if (LOG.isTraceEnabled()) {
      LOG.trace("Stack trace", new Throwable("TRACE"));
    }

    Optional.ofNullable(handler).ifPresent(h -> h.accept(request));

    if (ioe instanceof LeaderNotReadyException) {
      return;
    }

    final RaftPeerId oldLeader = request.getServerId();
    final RaftPeerId curLeader = leaderId;
    // false when the leader has already been changed since the request was created
    final boolean stillLeader = oldLeader.equals(curLeader);
    if (newLeader == null && stillLeader) {
      newLeader = CollectionUtils.random(oldLeader,
          CollectionUtils.as(peers, RaftPeer::getId));
    }
    LOG.debug("{}: oldLeader={},  curLeader={}, newLeader={}", clientId, oldLeader, curLeader, newLeader);

    final boolean leaderUnaffected = ioe instanceof AlreadyClosedException || ioe instanceof TimeoutIOException;
    final boolean changeLeader = !leaderUnaffected && newLeader != null && stillLeader;
    final boolean reconnect = changeLeader || clientRpc.shouldReconnect(ioe);
    if (reconnect) {
      // Re-read leaderId since it may have been updated concurrently after curLeader was read.
      if (changeLeader && oldLeader.equals(leaderId)) {
        LOG.debug("{} {}: client change Leader from {} to {} ex={}", groupId,
            clientId, oldLeader, newLeader, ioe.getClass().getName());
        this.leaderId = newLeader;
      }
      clientRpc.handleException(oldLeader, ioe, reconnect);
    }
  }

  /**
   * @return the call id that the next {@link #nextCallId()} invocation would return,
   *         masked in the same way so that both methods remain consistent after the counter wraps around.
   */
  long getCallId() {
    return CALL_ID_COUNTER.get() & Long.MAX_VALUE;
  }

  @Override
  public RaftClientRpc getClientRpc() {
    return clientRpc;
  }

  /** Close the scheduler and then the rpc; the rpc is closed even if closing the scheduler fails. */
  @Override
  public void close() throws IOException {
    try {
      scheduler.close();
    } finally {
      clientRpc.close();
    }
  }
}
