/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.nemo.runtime.common.message.ncs;

import org.apache.nemo.runtime.common.ReplyFutureMap;
import org.apache.nemo.runtime.common.comm.ControlMessage;
import org.apache.nemo.runtime.common.message.*;
import org.apache.reef.exception.evaluator.NetworkException;
import org.apache.reef.io.network.Connection;
import org.apache.reef.io.network.ConnectionFactory;
import org.apache.reef.io.network.Message;
import org.apache.reef.io.network.NetworkConnectionService;
import org.apache.reef.tang.annotations.Parameter;
import org.apache.reef.wake.EventHandler;
import org.apache.reef.wake.IdentifierFactory;
import org.apache.reef.wake.remote.transport.LinkListener;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.inject.Inject;
import java.net.SocketAddress;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Future;

/**
 * Message environment for NCS.
 * <p>
 * On construction, a connection factory for {@link ControlMessage.Message} is registered to the
 * {@link NetworkConnectionService}. Incoming messages are classified as {@link MessageType#Send},
 * {@link MessageType#Request} or {@link MessageType#Reply}: the first two are dispatched to the listener
 * registered under the message's listener id, and replies are handed over to the {@link ReplyFutureMap}.
 */
public final class NcsMessageEnvironment implements MessageEnvironment {
  private static final Logger LOG = LoggerFactory.getLogger(NcsMessageEnvironment.class.getName());

  private static final String NCS_CONN_FACTORY_ID = "NCS_CONN_FACTORY_ID";

  private final NetworkConnectionService networkConnectionService;
  private final IdentifierFactory idFactory;
  private final String senderId;

  private final ReplyFutureMap<ControlMessage.Message> replyFutureMap;
  private final ConcurrentMap<String, MessageListener> listenerConcurrentMap;
  private final Map<String, Connection> receiverToConnectionMap;
  private final ConnectionFactory<ControlMessage.Message> connectionFactory;

  /**
   * Creates the environment and registers its connection factory to the network connection service.
   *
   * @param networkConnectionService the service to which the connection factory is registered.
   * @param idFactory                the factory used to build identifiers from string ids.
   * @param senderId                 the id of this environment, also used as the local end point id.
   */
  @Inject
  private NcsMessageEnvironment(
    final NetworkConnectionService networkConnectionService,
    final IdentifierFactory idFactory,
    @Parameter(MessageParameters.SenderId.class) final String senderId) {
    this.networkConnectionService = networkConnectionService;
    this.idFactory = idFactory;
    this.senderId = senderId;
    this.replyFutureMap = new ReplyFutureMap<>();
    this.listenerConcurrentMap = new ConcurrentHashMap<>();
    this.receiverToConnectionMap = new ConcurrentHashMap<>();
    this.connectionFactory = networkConnectionService.registerConnectionFactory(
      idFactory.getNewInstance(NCS_CONN_FACTORY_ID),
      new ControlMessageCodec(),
      new NcsMessageHandler(),
      new NcsLinkListener(),
      idFactory.getNewInstance(senderId));
  }

  /**
   * Registers a listener under the given id.
   *
   * @param listenerId the id under which incoming messages are routed to the listener.
   * @param listener   the listener to register.
   * @param <T>        the type of the messages handled by the listener.
   * @throws RuntimeException if a listener is already registered under the id.
   */
  @Override
  public <T> void setupListener(final String listenerId, final MessageListener<T> listener) {
    if (listenerConcurrentMap.putIfAbsent(listenerId, listener) != null) {
      throw new RuntimeException("A listener for " + listenerId + " was already setup");
    }
  }

  /**
   * Unregisters the listener registered under the given id, if any.
   *
   * @param listenerId the id of the listener to remove.
   */
  @Override
  public void removeListener(final String listenerId) {
    listenerConcurrentMap.remove(listenerId);
  }

  /**
   * Opens a connection toward the receiver and wraps it into a message sender.
   * The returned future is already completed when this method returns: either with the sender,
   * or exceptionally with the {@link NetworkException} thrown while opening the connection.
   *
   * @param receiverId the id of the remote end point.
   * @param listenerId the id of the remote listener (not used by this implementation).
   * @param <T>        the type of the messages to send.
   * @return a completed future holding the sender, or the failure.
   */
  @Override
  public <T> Future<MessageSender<T>> asyncConnect(final String receiverId, final String listenerId) {
    // If the connection toward the receiver exists already, reuses it.
    // A single get() is used rather than containsKey() followed by get(), so that the concurrent map
    // is consulted only once and the two reads cannot disagree.
    // NOTE(review): nothing in this class inserts into receiverToConnectionMap, so this lookup currently
    // always misses and a connection is requested from the factory on every call. Populating the map
    // would presumably also require evicting connections that have been closed - TODO confirm.
    final Connection cachedConnection = receiverToConnectionMap.get(receiverId);
    final Connection connection;

    if (cachedConnection != null) {
      connection = cachedConnection;
    } else {
      connection = connectionFactory.newConnection(idFactory.getNewInstance(receiverId));
      try {
        connection.open();
      } catch (final NetworkException e) {
        // Best-effort cleanup: the failure reported to the caller is the one from open().
        try {
          connection.close();
        } catch (final NetworkException exceptionToIgnore) {
          LOG.info("Can't close the broken connection.", exceptionToIgnore);
        }

        final CompletableFuture<MessageSender<T>> failedFuture = new CompletableFuture<>();
        failedFuture.completeExceptionally(e);
        return failedFuture;
      }
    }

    return CompletableFuture.completedFuture((MessageSender) new NcsMessageSender(connection, replyFutureMap));
  }

  /**
   * @return the sender id given at construction.
   */
  @Override
  public String getId() {
    return senderId;
  }

  /**
   * Closes the underlying network connection service.
   *
   * @throws Exception if the network connection service fails to close.
   */
  @Override
  public void close() throws Exception {
    networkConnectionService.close();
  }

  /**
   * Message handler for NCS.
   */
  private final class NcsMessageHandler implements EventHandler<Message<ControlMessage.Message>> {

    /**
     * Classifies the incoming message and dispatches it accordingly.
     *
     * @param messages the incoming NCS message, of which the first element is processed.
     * @throws IllegalArgumentException if the type of the control message is not recognized.
     * @throws IllegalStateException    if the message targets a listener that is not registered.
     */
    @Override
    public void onNext(final Message<ControlMessage.Message> messages) {
      final ControlMessage.Message controlMessage = extractSingleMessage(messages);
      final MessageType messageType = getMsgType(controlMessage);
      switch (messageType) {
        case Send:
          processSendMessage(controlMessage);
          break;
        case Request:
          processRequestMessage(controlMessage);
          break;
        case Reply:
          processReplyMessage(controlMessage);
          break;
        default:
          throw new IllegalArgumentException(controlMessage.toString());
      }
    }

    /**
     * Looks up the listener targeted by the message.
     * Fails with a descriptive exception rather than a bare NullPointerException when the listener
     * was never set up or has been removed already.
     *
     * @param controlMessage the incoming message.
     * @return the listener registered under the message's listener id.
     * @throws IllegalStateException if no listener is registered under that id.
     */
    private MessageListener findListener(final ControlMessage.Message controlMessage) {
      final String listenerId = controlMessage.getListenerId();
      final MessageListener listener = listenerConcurrentMap.get(listenerId);
      if (listener == null) {
        throw new IllegalStateException("No listener is registered for " + listenerId
          + " (message type: " + controlMessage.getType() + ")");
      }
      return listener;
    }

    /**
     * Delivers a one-way message to its listener.
     *
     * @param controlMessage the incoming message.
     */
    private void processSendMessage(final ControlMessage.Message controlMessage) {
      findListener(controlMessage).onMessage(controlMessage);
    }

    /**
     * Delivers a request to its listener, along with a context built from the requesting executor's id.
     *
     * @param controlMessage the incoming request.
     */
    private void processRequestMessage(final ControlMessage.Message controlMessage) {
      // Resolve the listener first, so that no context is built for an undeliverable request.
      final MessageListener listener = findListener(controlMessage);
      final String executorId = getExecutorId(controlMessage);
      final MessageContext messageContext = new NcsMessageContext(executorId, connectionFactory, idFactory);
      listener.onMessageWithContext(controlMessage, messageContext);
    }

    /**
     * Hands a reply over to the reply future map, keyed by the id of the request it answers.
     *
     * @param controlMessage the incoming reply.
     */
    private void processReplyMessage(final ControlMessage.Message controlMessage) {
      final long requestId = getRequestId(controlMessage);
      replyFutureMap.onSuccessMessage(requestId, controlMessage);
    }
  }

  /**
   * LinkListener for NCS.
   */
  private final class NcsLinkListener implements LinkListener<Message<ControlMessage.Message>> {

    /**
     * Does nothing.
     *
     * @param messages the message reported as successful.
     */
    @Override
    public void onSuccess(final Message<ControlMessage.Message> messages) {
      // No-ops.
    }

    /**
     * Logs the failure along with the remote address, without the stacktrace.
     *
     * @param throwable     the reported failure.
     * @param socketAddress the remote address associated with the failure.
     * @param messages      the message associated with the failure.
     */
    @Override
    public void onException(final Throwable throwable,
                            final SocketAddress socketAddress,
                            final Message<ControlMessage.Message> messages) {
      // TODO #140: Properly classify and handle each RPC failure
      // Not logging the stacktrace here, as it's not very useful.
      // The throwable is converted to a string so that the logger does not print its stacktrace.
      LOG.error("NCS Exception: {} (remote address: {})", String.valueOf(throwable), socketAddress);
    }
  }

  /**
   * Extracts the first control message carried by the given NCS message.
   *
   * @param messages the NCS message.
   * @return the first element of the message's data.
   */
  private ControlMessage.Message extractSingleMessage(final Message<ControlMessage.Message> messages) {
    return messages.getData().iterator().next();
  }

  /**
   * Send: Messages sent without expecting a reply.
   * Request: Messages sent to get a reply.
   * Reply: Messages that reply to a request.
   * <p>
   * These names may not match the terminology conventionally used in RPC frameworks,
   * and may be revisited later.
   */
  enum MessageType {
    Send,
    Request,
    Reply
  }

  /**
   * Classifies a control message as a send, a request, or a reply, based on its type.
   *
   * @param controlMessage the message to classify.
   * @return the kind of the message.
   * @throws IllegalArgumentException if the type of the message is not listed below.
   */
  private MessageType getMsgType(final ControlMessage.Message controlMessage) {
    switch (controlMessage.getType()) {
      case TaskStateChanged:
      case ScheduleTask:
      case BlockStateChanged:
      case ExecutorFailed:
      case RunTimePassMessage:
      case ExecutorDataCollected:
      case MetricMessageReceived:
      case RequestMetricFlush:
      case MetricFlushed:
      case PipeInit:
        return MessageType.Send;
      case RequestBlockLocation:
      case RequestBroadcastVariable:
      case RequestPipeLoc:
        return MessageType.Request;
      case BlockLocationInfo:
      case InMasterBroadcastVariable:
      case PipeLocInfo:
        return MessageType.Reply;
      default:
        throw new IllegalArgumentException(controlMessage.toString());
    }
  }

  /**
   * Reads the id of the requesting executor out of a request message.
   *
   * @param controlMessage the request message.
   * @return the executor id carried by the request.
   * @throws IllegalArgumentException if the message is not one of the request types.
   */
  private String getExecutorId(final ControlMessage.Message controlMessage) {
    switch (controlMessage.getType()) {
      case RequestBlockLocation:
        return controlMessage.getRequestBlockLocationMsg().getExecutorId();
      case RequestBroadcastVariable:
        return controlMessage.getRequestbroadcastVariableMsg().getExecutorId();
      case RequestPipeLoc:
        return controlMessage.getRequestPipeLocMsg().getExecutorId();
      default:
        throw new IllegalArgumentException(controlMessage.toString());
    }
  }

  /**
   * Reads the id of the answered request out of a reply message.
   *
   * @param controlMessage the reply message.
   * @return the request id carried by the reply.
   * @throws IllegalArgumentException if the message is not one of the reply types.
   */
  private long getRequestId(final ControlMessage.Message controlMessage) {
    switch (controlMessage.getType()) {
      case BlockLocationInfo:
        return controlMessage.getBlockLocationInfoMsg().getRequestId();
      case InMasterBroadcastVariable:
        return controlMessage.getBroadcastVariableMsg().getRequestId();
      case PipeLocInfo:
        return controlMessage.getPipeLocInfoMsg().getRequestId();
      default:
        throw new IllegalArgumentException(controlMessage.toString());
    }
  }
}
