/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.command;

import java.io.IOException;
import java.util.Objects;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.service.api.records.Component;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.RunJobParameters;
import org.apache.hadoop.yarn.submarine.client.cli.param.runjob.TensorFlowRunJobParameters;
import org.apache.hadoop.yarn.submarine.common.api.Role;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.HadoopEnvironmentSetup;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.AbstractLaunchCommand;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.command.LaunchScriptBuilder;
import org.apache.hadoop.yarn.submarine.runtimes.yarnservice.tensorflow.TensorFlowCommons;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base launch command for TensorFlow roles.
 *
 * <p>The launch script is produced by the builder obtained from {@link AbstractLaunchCommand}.
 * When the job is distributed, an {@code export TF_CONFIG="..."} line is appended to that
 * builder before the launch command is attached. The TF_CONFIG value is built by
 * {@link TensorFlowCommons#getTFConfigEnv} from:
 * <ul>
 *   <li>the role's component name;</li>
 *   <li>the worker and parameter-server counts;</li>
 *   <li>the job name;</li>
 *   <li>{@link TensorFlowCommons#getUserName()};</li>
 *   <li>the DNS domain resolved from the YARN configuration.</li>
 * </ul>
 *
 * <p>The command itself comes from {@code createLaunchCommand()}, which this class does not
 * define.
 */
public abstract class TensorFlowLaunchCommand
extends AbstractLaunchCommand {
    private static final Logger LOG = LoggerFactory.getLogger(TensorFlowLaunchCommand.class);
    private final Configuration yarnConfig;
    private final boolean distributed;
    private final int numberOfWorkers;
    private final int numberOfPS;
    /** Job name taken from {@code parameters.getName()}. */
    private final String name;
    private final Role role;

    /**
     * Captures the job settings needed to build the TensorFlow launch script.
     *
     * @param hadoopEnvSetup Hadoop environment setup, forwarded to the superclass
     * @param role the TensorFlow role this command launches; must not be null
     * @param component the YARN service component, forwarded to the superclass
     * @param parameters run-job parameters supplying the job name, the distributed flag and the
     *     worker / parameter-server counts
     * @param yarnConfig configuration later used to resolve the DNS domain for TF_CONFIG
     * @throws IOException if the superclass constructor throws it
     * @throws NullPointerException if {@code role} is null
     */
    TensorFlowLaunchCommand(HadoopEnvironmentSetup hadoopEnvSetup, Role role, Component component, TensorFlowRunJobParameters parameters, Configuration yarnConfig) throws IOException {
        // Validate inside the super(...) argument so a null role fails before the superclass
        // constructor runs, instead of silently reaching it as an empty role name.
        // The explicit upcast is retained from the decompiled source; it is harmless if redundant.
        super(hadoopEnvSetup, component, (RunJobParameters) parameters,
                Objects.requireNonNull(role, "TensorFlowRole must not be null!").getName());
        this.role = role;
        this.name = parameters.getName();
        this.distributed = parameters.isDistributed();
        this.numberOfWorkers = parameters.getNumWorkers();
        this.numberOfPS = parameters.getNumPS();
        this.yarnConfig = yarnConfig;
        this.logReceivedParameters();
    }

    /**
     * Logs a warning for each non-positive worker or parameter-server count that was received.
     *
     * <p>NOTE(review): the PS check does not consult {@code distributed}, so a non-distributed
     * job with zero PS also triggers the warning. Confirm whether that is intended.
     */
    private void logReceivedParameters() {
        if (this.numberOfWorkers <= 0) {
            LOG.warn("Received number of workers: {}", this.numberOfWorkers);
        }
        if (this.numberOfPS <= 0) {
            LOG.warn("Received number of PS: {}", this.numberOfPS);
        }
    }

    /**
     * Builds the launch script for this role.
     *
     * <p>For distributed jobs, an {@code export TF_CONFIG="<value>"} line is appended to the
     * builder. The launch command from {@code createLaunchCommand()} is then attached and the
     * script is built.
     *
     * @return the generated launch script
     * @throws IOException if building the script or creating the launch command throws it
     */
    @Override
    public String generateLaunchScript() throws IOException {
        LaunchScriptBuilder builder = this.getBuilder();
        if (this.distributed) {
            String tfConfigEnvValue = TensorFlowCommons.getTFConfigEnv(this.role.getComponentName(), this.numberOfWorkers, this.numberOfPS, this.name, TensorFlowCommons.getUserName(), TensorFlowCommons.getDNSDomain(this.yarnConfig));
            // NOTE(review): the value is placed inside double quotes without escaping here; this
            // assumes getTFConfigEnv already returns shell-safe (quote-escaped) text. Confirm.
            String tfConfig = "export TF_CONFIG=\"" + tfConfigEnvValue + "\"\n";
            builder.append(tfConfig);
        }
        return builder.withLaunchCommand(this.createLaunchCommand()).build();
    }
}

