/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.nemo.runtime.common.message.local;

import org.apache.nemo.runtime.common.message.*;
import org.apache.reef.tang.Injector;
import org.apache.reef.tang.Tang;
import org.junit.Assert;
import org.junit.Test;

import java.io.Serializable;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Tests local messaging components.
 */
public class LocalMessageTest {
  private static final Tang TANG = Tang.Factory.getTang();

  /**
   * Exercises the local message environment end to end.
   *
   * <p>It covers:</p>
   * <ul>
   *   <li>one-way sends and request/reply from two executors to the driver;</li>
   *   <li>request/reply between the two executors;</li>
   *   <li>removing a listener and registering a new one under the same id, after which only the new
   *       listener receives messages.</li>
   * </ul>
   *
   * @throws Exception if injection fails or a reply future completes exceptionally.
   */
  @Test
  public void testLocalMessages() throws Exception {
    final String driverNodeId = "DRIVER_NODE";
    final String executorOneNodeId = "EXECUTOR_ONE_NODE";
    final String executorTwoNodeId = "EXECUTOR_TWO_NODE";

    final String listenerIdToDriver = "ToDriver";
    final String secondListenerIdToDriver = "SecondToDriver";
    final String listenerIdBetweenExecutors = "BetweenExecutors";

    final Injector injector = TANG.newInjector(TANG.newConfigurationBuilder()
      .bindImplementation(MessageEnvironment.class, LocalMessageEnvironment.class).build());
    // The dispatcher is instantiated on the parent injector before forking.
    // Presumably this lets every forked environment share the same instance.
    injector.getInstance(LocalMessageDispatcher.class);

    final MessageEnvironment driverEnv = newEnvironment(injector, driverNodeId);
    final MessageEnvironment executorOneEnv = newEnvironment(injector, executorOneNodeId);
    final MessageEnvironment executorTwoEnv = newEnvironment(injector, executorTwoNodeId);

    final AtomicInteger toDriverMessageUsingSend = new AtomicInteger();

    driverEnv.setupListener(listenerIdToDriver, new MessageListener<ToDriver>() {
      @Override
      public void onMessage(final ToDriver message) {
        toDriverMessageUsingSend.incrementAndGet();
      }

      @Override
      public void onMessageWithContext(final ToDriver message, final MessageContext messageContext) {
        messageContext.reply(true);
      }
    });

    // Setup multiple listeners. This one only checks that a second listener can be registered
    // on the same environment; it never receives messages in this test.
    driverEnv.setupListener(secondListenerIdToDriver, new MessageListener<SecondToDriver>() {
      @Override
      public void onMessage(final SecondToDriver message) {
      }

      @Override
      public void onMessageWithContext(final SecondToDriver message, final MessageContext messageContext) {
      }
    });

    // Test sending message from executors to the driver.
    // The local environment is expected to complete the connect future immediately.

    final Future<MessageSender<ToDriver>> messageSenderFuture1 = executorOneEnv.asyncConnect(
      driverNodeId, listenerIdToDriver);
    Assert.assertTrue(messageSenderFuture1.isDone());
    final MessageSender<ToDriver> messageSender1 = messageSenderFuture1.get();

    final Future<MessageSender<ToDriver>> messageSenderFuture2 = executorTwoEnv.asyncConnect(
      driverNodeId, listenerIdToDriver);
    Assert.assertTrue(messageSenderFuture2.isDone());
    final MessageSender<ToDriver> messageSender2 = messageSenderFuture2.get();

    messageSender1.send(new ExecutorStarted());
    messageSender2.send(new ExecutorStarted());

    Assert.assertEquals(2, toDriverMessageUsingSend.get());
    Assert.assertTrue(messageSender1.<Boolean>request(new ExecutorStarted()).get());
    Assert.assertTrue(messageSender2.<Boolean>request(new ExecutorStarted()).get());

    // Test exchanging messages between executors.

    final AtomicInteger executorOneMessageCount = new AtomicInteger();
    final AtomicInteger executorTwoMessageCount = new AtomicInteger();

    executorOneEnv.setupListener(listenerIdBetweenExecutors, new SimpleMessageListener(executorOneMessageCount));
    executorTwoEnv.setupListener(listenerIdBetweenExecutors, new SimpleMessageListener(executorTwoMessageCount));

    final MessageSender<BetweenExecutors> oneToTwo = executorOneEnv.<BetweenExecutors>asyncConnect(
      executorTwoNodeId, listenerIdBetweenExecutors).get();
    final MessageSender<BetweenExecutors> twoToOne = executorTwoEnv.<BetweenExecutors>asyncConnect(
      executorOneNodeId, listenerIdBetweenExecutors).get();

    Assert.assertEquals("oneToTwo", oneToTwo.<String>request(new SimpleMessage("oneToTwo")).get());
    Assert.assertEquals("twoToOne", twoToOne.<String>request(new SimpleMessage("twoToOne")).get());
    Assert.assertEquals(1, executorOneMessageCount.get());
    Assert.assertEquals(1, executorTwoMessageCount.get());

    // Test deletion and re-setting of listener.

    final AtomicInteger newExecutorOneMessageCount = new AtomicInteger();
    final AtomicInteger newExecutorTwoMessageCount = new AtomicInteger();

    executorOneEnv.removeListener(listenerIdBetweenExecutors);
    executorTwoEnv.removeListener(listenerIdBetweenExecutors);

    executorOneEnv.setupListener(listenerIdBetweenExecutors, new SimpleMessageListener(newExecutorOneMessageCount));
    executorTwoEnv.setupListener(listenerIdBetweenExecutors, new SimpleMessageListener(newExecutorTwoMessageCount));

    final MessageSender<BetweenExecutors> newOneToTwo = executorOneEnv.<BetweenExecutors>asyncConnect(
      executorTwoNodeId, listenerIdBetweenExecutors).get();
    final MessageSender<BetweenExecutors> newTwoToOne = executorTwoEnv.<BetweenExecutors>asyncConnect(
      executorOneNodeId, listenerIdBetweenExecutors).get();

    Assert.assertEquals("newOneToTwo", newOneToTwo.<String>request(new SimpleMessage("newOneToTwo")).get());
    Assert.assertEquals("newTwoToOne", newTwoToOne.<String>request(new SimpleMessage("newTwoToOne")).get());
    // The removed listeners must not have been invoked again; only the new ones count the requests.
    Assert.assertEquals(1, executorOneMessageCount.get());
    Assert.assertEquals(1, executorTwoMessageCount.get());
    Assert.assertEquals(1, newExecutorOneMessageCount.get());
    Assert.assertEquals(1, newExecutorTwoMessageCount.get());
  }

  /**
   * Creates a {@link MessageEnvironment} whose sender id is {@code senderId}.
   *
   * <p>The environment is obtained from a fork of {@code parent}.</p>
   *
   * @param parent   injector that binds {@link MessageEnvironment} to its implementation.
   * @param senderId value bound to {@link MessageParameters.SenderId} in the forked injector.
   * @return the environment instantiated by the forked injector.
   * @throws Exception if forking the injector or instantiating the environment fails.
   */
  private static MessageEnvironment newEnvironment(final Injector parent, final String senderId) throws Exception {
    return parent.forkInjector(TANG.newConfigurationBuilder()
      .bindNamedParameter(MessageParameters.SenderId.class, senderId).build())
      .getInstance(MessageEnvironment.class);
  }

  /**
   * Listener that answers every request by echoing the message's data back.
   *
   * <p>It also increments a shared counter on each request. It is static so it does not capture
   * the enclosing test instance.</p>
   */
  static final class SimpleMessageListener implements MessageListener<SimpleMessage> {
    private final AtomicInteger messageCount;

    private SimpleMessageListener(final AtomicInteger messageCount) {
      this.messageCount = messageCount;
    }

    @Override
    public void onMessage(final SimpleMessage message) {
      // Expected not reached here: this test only uses request/reply between executors.
      throw new RuntimeException("SimpleMessageListener only expects request messages, got a one-way send");
    }

    @Override
    public void onMessageWithContext(final SimpleMessage message, final MessageContext messageContext) {
      messageCount.getAndIncrement();
      messageContext.reply(message.getData());
    }
  }

  /**
   * Marker type for messages sent to the driver's {@code "ToDriver"} listener.
   */
  interface ToDriver extends Serializable {
  }

  /**
   * Payload-less message sent from an executor to the driver.
   *
   * <p>It is static so that serializing it does not drag in the non-serializable enclosing test
   * instance.</p>
   */
  static final class ExecutorStarted implements ToDriver {
  }

  /**
   * Marker type for the driver's second listener.
   */
  interface SecondToDriver extends Serializable {
  }

  /**
   * Marker type for messages exchanged between executors.
   */
  interface BetweenExecutors extends Serializable {
  }

  /**
   * Executor-to-executor message carrying a single string.
   *
   * <p>The string is echoed back by {@link SimpleMessageListener}. It is static so that serializing
   * it does not drag in the non-serializable enclosing test instance.</p>
   */
  static final class SimpleMessage implements BetweenExecutors {
    private final String data;

    SimpleMessage(final String data) {
      this.data = data;
    }

    public String getData() {
      return data;
    }
  }
}
