/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice;

import com.google.common.annotations.VisibleForTesting;
import java.io.IOException;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.client.api.AppAdminClient;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.service.api.records.Service;
import org.apache.hadoop.yarn.service.utils.ServiceApiUtil;
import org.apache.hadoop.yarn.submarine.client.cli.param.ParametersHolder;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.PyTorchRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.runjob.Framework;
import org.apache.hadoop.yarn.submarine.common.ClientContext;
import org.apache.hadoop.yarn.submarine.runtimes.common.JobSubmitter;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.FileSystemOperations;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.ServiceSpecFileGenerator;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.ServiceWrapper;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.YarnServiceUtils;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.PyTorchLaunchCommandFactory;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.TensorFlowLaunchCommandFactory;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.pytorch.PyTorchServiceSpec;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowServiceSpec;
import org.apache.hadoop.yarn.submarine.utils.Localizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link JobSubmitter} that launches a run-job request as a YARN service.
 * <p>
 * It builds a TensorFlow or PyTorch service spec from the parsed parameters and
 * launches it through an {@link AppAdminClient}. It then polls the service status
 * until an application id is reported.
 */
public class YarnServiceJobSubmitter implements JobSubmitter {
    private static final Logger LOG = LoggerFactory.getLogger(YarnServiceJobSubmitter.class);

    /** Maximum number of extra status polls while waiting for an application id. */
    private static final int MAX_APP_ID_RETRIES = 30;

    /** Delay between two status polls, in milliseconds. */
    private static final long APP_ID_RETRY_INTERVAL_MS = 1000L;

    private final ClientContext clientContext;

    /** Wrapper produced by the most recent service-spec creation; exposed for tests. */
    private ServiceWrapper serviceWrapper;

    YarnServiceJobSubmitter(ClientContext clientContext) {
        this.clientContext = clientContext;
    }

    /**
     * Submits a run-job request for the framework selected in {@code paramsHolder}.
     *
     * @param paramsHolder parsed CLI parameters; the framework decides which
     *                     parameter subtype the parameters are cast to
     * @return the YARN application id of the launched service
     * @throws IOException if building or launching the service fails, or the
     *                     wait for the application id is interrupted
     * @throws YarnException if the launch returns a non-zero exit code or no
     *                       application id becomes available
     * @throws UnsupportedOperationException if the framework is neither
     *                                       TensorFlow nor PyTorch
     */
    public ApplicationId submitJob(ParametersHolder paramsHolder) throws IOException, YarnException {
        Framework framework = paramsHolder.getFramework();
        RunJobParameters parameters = (RunJobParameters) paramsHolder.getParameters();
        if (framework == Framework.TENSORFLOW) {
            return submitTensorFlowJob((TensorFlowRunJobParameters) parameters);
        }
        if (framework == Framework.PYTORCH) {
            return submitPyTorchJob((PyTorchRunJobParameters) parameters);
        }
        throw new UnsupportedOperationException("TensorFlow and PyTorch are the only supported frameworks for now!");
    }

    /** Builds the TensorFlow service spec and launches it. */
    private ApplicationId submitTensorFlowJob(TensorFlowRunJobParameters parameters) throws IOException, YarnException {
        FileSystemOperations fsOperations = new FileSystemOperations(clientContext);
        HadoopEnvironmentSetup hadoopEnvSetup = new HadoopEnvironmentSetup(clientContext, fsOperations);
        Service serviceSpec = createTensorFlowServiceSpec(parameters, fsOperations, hadoopEnvSetup);
        return submitJobInternal(serviceSpec);
    }

    /** Builds the PyTorch service spec and launches it. */
    private ApplicationId submitPyTorchJob(PyTorchRunJobParameters parameters) throws IOException, YarnException {
        FileSystemOperations fsOperations = new FileSystemOperations(clientContext);
        HadoopEnvironmentSetup hadoopEnvSetup = new HadoopEnvironmentSetup(clientContext, fsOperations);
        Service serviceSpec = createPyTorchServiceSpec(parameters, fsOperations, hadoopEnvSetup);
        return submitJobInternal(serviceSpec);
    }

    /**
     * Serializes {@code serviceSpec}, launches it and waits for its application id.
     * The admin client is stopped on every exit path, including failures.
     */
    private ApplicationId submitJobInternal(Service serviceSpec) throws IOException, YarnException {
        String serviceSpecFile = ServiceSpecFileGenerator.generateJson(serviceSpec);
        String serviceName = serviceSpec.getName();
        AppAdminClient appAdminClient = YarnServiceUtils.createServiceClient(clientContext.getYarnConfig());
        try {
            int code = appAdminClient.actionLaunch(serviceSpecFile, serviceName, null, null);
            if (code != 0) {
                throw new YarnException("Fail to launch application with exit code:" + code);
            }
            String appId = waitForApplicationId(appAdminClient, serviceName);
            return ApplicationId.fromString(appId);
        } finally {
            // Previously the client was stopped only on success, leaking it on any failure.
            appAdminClient.stop();
        }
    }

    /**
     * Polls the service status until it carries an application id.
     * It makes one initial poll plus up to {@link #MAX_APP_ID_RETRIES} retries,
     * sleeping {@link #APP_ID_RETRY_INTERVAL_MS} ms between them.
     *
     * @return the non-null application id string reported by the service status
     * @throws IOException if the wait is interrupted (the interrupt flag is restored)
     * @throws YarnException if no application id appears within the retry budget
     */
    private static String waitForApplicationId(AppAdminClient appAdminClient, String serviceName)
            throws IOException, YarnException {
        String appStatus = appAdminClient.getStatusString(serviceName);
        Service app = ServiceApiUtil.jsonSerDeser.fromJson(appStatus);
        for (int count = 0; app.getId() == null && count < MAX_APP_ID_RETRIES; ++count) {
            LOG.info("Waiting for application Id. AppStatusString=\n {}", appStatus);
            try {
                Thread.sleep(APP_ID_RETRY_INTERVAL_MS);
            } catch (InterruptedException e) {
                // Preserve the interrupt for callers further up the stack.
                Thread.currentThread().interrupt();
                throw new IOException(e);
            }
            appStatus = appAdminClient.getStatusString(serviceName);
            app = ServiceApiUtil.jsonSerDeser.fromJson(appStatus);
        }
        if (app.getId() == null) {
            throw new YarnException("Can't get application id for Service " + serviceName);
        }
        return app.getId();
    }

    /** Creates the TensorFlow service wrapper, records it, and returns its {@link Service}. */
    private Service createTensorFlowServiceSpec(TensorFlowRunJobParameters parameters,
            FileSystemOperations fsOperations, HadoopEnvironmentSetup hadoopEnvSetup) throws IOException {
        TensorFlowLaunchCommandFactory launchCommandFactory =
                new TensorFlowLaunchCommandFactory(hadoopEnvSetup, parameters, clientContext.getYarnConfig());
        Localizer localizer = new Localizer(fsOperations, clientContext.getRemoteDirectoryManager(), parameters);
        TensorFlowServiceSpec tensorFlowServiceSpec = new TensorFlowServiceSpec(
                parameters, clientContext, fsOperations, launchCommandFactory, localizer);
        serviceWrapper = tensorFlowServiceSpec.create();
        return serviceWrapper.getService();
    }

    /** Creates the PyTorch service wrapper, records it, and returns its {@link Service}. */
    private Service createPyTorchServiceSpec(PyTorchRunJobParameters parameters,
            FileSystemOperations fsOperations, HadoopEnvironmentSetup hadoopEnvSetup) throws IOException {
        PyTorchLaunchCommandFactory launchCommandFactory =
                new PyTorchLaunchCommandFactory(hadoopEnvSetup, parameters, clientContext.getYarnConfig());
        Localizer localizer = new Localizer(fsOperations, clientContext.getRemoteDirectoryManager(), parameters);
        PyTorchServiceSpec pyTorchServiceSpec = new PyTorchServiceSpec(
                parameters, clientContext, fsOperations, launchCommandFactory, localizer);
        serviceWrapper = pyTorchServiceSpec.create();
        return serviceWrapper.getService();
    }

    /**
     * Returns the wrapper built by the last submission, or {@code null} if none has run.
     */
    @VisibleForTesting
    public ServiceWrapper getServiceWrapper() {
        return serviceWrapper;
    }
}

