/*
 * Decompiled with CFR 0.152.
 */
package org.apache.samza.execution;

import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Multimap;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.collections4.ListUtils;
import org.apache.samza.SamzaException;
import org.apache.samza.application.LegacyTaskApplication;
import org.apache.samza.application.descriptors.ApplicationDescriptor;
import org.apache.samza.application.descriptors.ApplicationDescriptorImpl;
import org.apache.samza.config.ApplicationConfig;
import org.apache.samza.config.ClusterManagerConfig;
import org.apache.samza.config.Config;
import org.apache.samza.config.JobConfig;
import org.apache.samza.config.StreamConfig;
import org.apache.samza.execution.ExecutionPlan;
import org.apache.samza.execution.IntermediateStreamManager;
import org.apache.samza.execution.JobGraph;
import org.apache.samza.execution.JobNode;
import org.apache.samza.execution.OperatorSpecGraphAnalyzer;
import org.apache.samza.execution.StreamEdge;
import org.apache.samza.execution.StreamManager;
import org.apache.samza.operators.spec.InputOperatorSpec;
import org.apache.samza.operators.spec.OperatorSpec;
import org.apache.samza.operators.spec.StreamTableJoinOperatorSpec;
import org.apache.samza.system.StreamSpec;
import org.apache.samza.table.TableSpec;
import org.apache.samza.table.descriptors.BaseTableDescriptor;
import org.apache.samza.util.StreamUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Builds the execution plan for a Samza application.
 *
 * <p>Planning creates a single-node {@link JobGraph} from the application descriptor. It then
 * fetches the partition counts of the graph's input, side-input and output streams from the
 * {@link StreamManager}. If the graph has intermediate streams, it lets
 * {@link IntermediateStreamManager} compute their partitions. Finally, it checks that every group
 * of joined streams agrees on a single partition count.
 */
public class ExecutionPlanner {
    private static final Logger log = LoggerFactory.getLogger(ExecutionPlanner.class);
    private final Config config;
    private final StreamManager streamManager;
    private final StreamConfig streamConfig;

    /**
     * @param config the job configuration; also used to build the {@link StreamConfig}
     * @param streamManager queried for the partition counts of streams in the graph
     */
    public ExecutionPlanner(Config config, StreamManager streamManager) {
        this.config = config;
        this.streamManager = streamManager;
        this.streamConfig = new StreamConfig(config);
    }

    /**
     * Plans the execution of the given application.
     *
     * @param appDesc the application descriptor to plan
     * @return the resulting {@link JobGraph}, returned as the {@link ExecutionPlan}
     * @throws SamzaException if host affinity is enabled in batch mode, or if the streams taking
     *     part in a join do not all have the same partition count
     */
    public ExecutionPlan plan(ApplicationDescriptorImpl<? extends ApplicationDescriptor> appDesc) {
        validateConfig();
        JobGraph jobGraph = createJobGraph(appDesc);
        // Partition counts of input/side-input/output streams are assigned before the
        // intermediate-stream calculation runs on the same graph.
        setInputAndOutputStreamPartitionCount(jobGraph);
        List<StreamSet> joinedStreamSets = groupJoinedStreams(jobGraph);
        if (!jobGraph.getIntermediateStreamEdges().isEmpty()) {
            new IntermediateStreamManager(config).calculatePartitions(jobGraph, joinedStreamSets);
        }
        joinedStreamSets.forEach(ExecutionPlanner::validatePartitions);
        return jobGraph;
    }

    /**
     * Rejects configurations that cannot be planned.
     *
     * @throws SamzaException if the application mode is BATCH and host affinity is enabled
     */
    private void validateConfig() {
        ApplicationConfig appConfig = new ApplicationConfig(config);
        ClusterManagerConfig clusterConfig = new ClusterManagerConfig(config);
        if (appConfig.getAppMode() == ApplicationConfig.ApplicationMode.BATCH && clusterConfig.getHostAffinityEnabled()) {
            throw new SamzaException(String.format("Host affinity is not supported in batch mode. Please configure %s=false.", "job.host-affinity.enabled"));
        }
    }

    /**
     * Creates a job graph with a single job node and attaches all streams and tables to it.
     *
     * <p>A stream that appears among both the application's input and output stream ids is added
     * as an intermediate stream from the node to itself. Streams that appear only among the inputs
     * become input streams, and streams that appear only among the outputs become output streams.
     * Each table's side-input streams are also added to the graph.
     *
     * @param appDesc the application descriptor
     * @return the created graph; it has already been validated unless the application class is a
     *     {@link LegacyTaskApplication}
     */
    JobGraph createJobGraph(ApplicationDescriptorImpl<? extends ApplicationDescriptor> appDesc) {
        JobGraph jobGraph = new JobGraph(config, appDesc);
        Set<StreamSpec> sourceStreams = StreamUtil.getStreamSpecs(appDesc.getInputStreamIds(), streamConfig);
        Set<StreamSpec> sinkStreams = StreamUtil.getStreamSpecs(appDesc.getOutputStreamIds(), streamConfig);
        Set<StreamSpec> intermediateStreams = Sets.intersection(sourceStreams, sinkStreams);
        Set<StreamSpec> inputStreams = Sets.difference(sourceStreams, intermediateStreams);
        Set<StreamSpec> outputStreams = Sets.difference(sinkStreams, intermediateStreams);
        Set<TableSpec> tables = appDesc.getTableDescriptors().stream()
            .map(tableDescriptor -> ((BaseTableDescriptor) tableDescriptor).getTableSpec())
            .collect(Collectors.toSet());

        String jobName = config.get(JobConfig.JOB_NAME());
        String jobId = config.get(JobConfig.JOB_ID(), "1");
        JobNode node = jobGraph.getOrCreateJobNode(jobName, jobId);

        inputStreams.forEach(spec -> jobGraph.addInputStream(spec, node));
        outputStreams.forEach(spec -> jobGraph.addOutputStream(spec, node));
        intermediateStreams.forEach(spec -> jobGraph.addIntermediateStream(spec, node, node));
        for (TableSpec table : tables) {
            jobGraph.addTable(table, node);
            for (String sideInput : ListUtils.emptyIfNull(table.getSideInputs())) {
                jobGraph.addSideInputStream(StreamUtil.getStreamSpec(sideInput, streamConfig));
            }
        }

        // Graphs of LegacyTaskApplication are intentionally not validated.
        if (!LegacyTaskApplication.class.isAssignableFrom(appDesc.getAppClass())) {
            jobGraph.validate();
        }
        return jobGraph;
    }

    /**
     * Fetches partition counts for the graph's input, side-input and output streams and stores
     * them on the corresponding {@link StreamEdge}s.
     *
     * <p>Edges are grouped by system, and the {@link StreamManager} is queried once per system.
     *
     * @param jobGraph the graph whose existing stream edges receive partition counts
     */
    void setInputAndOutputStreamPartitionCount(JobGraph jobGraph) {
        Set<StreamEdge> existingStreams = new HashSet<>();
        existingStreams.addAll(jobGraph.getInputStreams());
        existingStreams.addAll(jobGraph.getSideInputStreams());
        existingStreams.addAll(jobGraph.getOutputStreams());

        Multimap<String, StreamEdge> systemToStreamEdges = HashMultimap.create();
        for (StreamEdge streamEdge : existingStreams) {
            systemToStreamEdges.put(streamEdge.getSystemStream().getSystem(), streamEdge);
        }

        for (String system : systemToStreamEdges.keySet()) {
            Map<String, StreamEdge> streamToStreamEdge = new HashMap<>();
            for (StreamEdge streamEdge : systemToStreamEdges.get(system)) {
                streamToStreamEdge.put(streamEdge.getSystemStream().getStream(), streamEdge);
            }
            Map<String, Integer> streamToPartitionCount =
                streamManager.getStreamPartitionCounts(system, streamToStreamEdge.keySet());
            for (Map.Entry<String, Integer> entry : streamToPartitionCount.entrySet()) {
                String stream = entry.getKey();
                Integer partitionCount = entry.getValue();
                streamToStreamEdge.get(stream).setPartitionCount(partitionCount);
                log.info("Fetched partition count value {} for stream {}", partitionCount, stream);
            }
        }
    }

    /**
     * Builds one {@link StreamSet} per join operator, holding the stream edges of the join's inputs.
     *
     * <p>For a stream-table join, the table's side-input streams are added to the same set. That
     * way they take part in the same partition-count checks as the join's input streams.
     *
     * @param jobGraph the graph whose application descriptor supplies the input operators
     * @return an unmodifiable list with one stream set per join operator
     */
    private static List<StreamSet> groupJoinedStreams(JobGraph jobGraph) {
        Multimap<OperatorSpec, InputOperatorSpec> joinOpSpecToInputOpSpecs =
            OperatorSpecGraphAnalyzer.getJoinToInputOperatorSpecs(
                jobGraph.getApplicationDescriptorImpl().getInputOperators().values());
        List<StreamSet> streamSets = new ArrayList<>();
        for (OperatorSpec joinOpSpec : joinOpSpecToInputOpSpecs.keySet()) {
            Collection<InputOperatorSpec> joinedInputOpSpecs = joinOpSpecToInputOpSpecs.get(joinOpSpec);
            StreamSet streamSet = getStreamSet(joinOpSpec.getOpId(), joinedInputOpSpecs, jobGraph);
            if (joinOpSpec instanceof StreamTableJoinOperatorSpec) {
                StreamTableJoinOperatorSpec streamTableJoinOpSpec = (StreamTableJoinOperatorSpec) joinOpSpec;
                List<String> sideInputs = ListUtils.emptyIfNull(streamTableJoinOpSpec.getTableSpec().getSideInputs());
                List<StreamEdge> sideInputStreams = sideInputs.stream()
                    .map(jobGraph::getStreamEdge)
                    .collect(Collectors.toList());
                streamSet = new StreamSet(streamSet.getSetId(),
                    Iterables.concat(streamSet.getStreamEdges(), sideInputStreams));
            }
            streamSets.add(streamSet);
        }
        return Collections.unmodifiableList(streamSets);
    }

    /**
     * Collects the stream edges read by the given input operators into a {@link StreamSet}.
     *
     * @param setId identifier of the resulting set (the join operator id)
     * @param inputOpSpecs input operators whose streams are collected
     * @param jobGraph graph used to look up the stream edge for each stream id
     * @return a stream set containing the distinct edges of the given inputs
     */
    private static StreamSet getStreamSet(String setId, Iterable<InputOperatorSpec> inputOpSpecs, JobGraph jobGraph) {
        Set<StreamEdge> streamEdges = new HashSet<>();
        for (InputOperatorSpec inputOpSpec : inputOpSpecs) {
            streamEdges.add(jobGraph.getStreamEdge(inputOpSpec.getStreamId()));
        }
        return new StreamSet(setId, streamEdges);
    }

    /**
     * Verifies that every stream edge in the set has the same partition count as the first edge.
     *
     * @param streamSet the set to check; an empty set trivially passes
     * @throws SamzaException naming the first stream whose partition count differs
     */
    private static void validatePartitions(StreamSet streamSet) {
        Set<StreamEdge> streamEdges = streamSet.getStreamEdges();
        if (streamEdges.isEmpty()) {
            // Nothing to compare; previously this failed with NoSuchElementException.
            return;
        }
        StreamEdge referenceStreamEdge = streamEdges.iterator().next();
        int referencePartitions = referenceStreamEdge.getPartitionCount();
        for (StreamEdge streamEdge : streamEdges) {
            int partitions = streamEdge.getPartitionCount();
            if (partitions != referencePartitions) {
                // Report the stream that actually mismatched, whose count is the "Actual" value.
                throw new SamzaException(String.format(
                    "Unable to resolve input partitions of stream %s for the join %s. Expected: %d, Actual: %d",
                    streamEdge.getName(), streamSet.getSetId(), referencePartitions, partitions));
            }
        }
    }

    /**
     * An identified, immutable group of stream edges whose partition counts are validated
     * together (one group per join operator).
     */
    static class StreamSet {
        private final String setId;
        private final Set<StreamEdge> streamEdges;

        StreamSet(String setId, Iterable<StreamEdge> streamEdges) {
            this.setId = setId;
            // Defensive immutable copy; iteration order follows the source iterable.
            this.streamEdges = ImmutableSet.copyOf(streamEdges);
        }

        /** Returns the edges of this set; the returned set is immutable. */
        Set<StreamEdge> getStreamEdges() {
            return streamEdges;
        }

        String getSetId() {
            return setId;
        }
    }
}

