/*
 * Decompiled with CFR 0.152.
 */
package org.apache.samza.execution;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.HashMultimap;
import java.util.Collection;
import java.util.HashSet;
import java.util.Optional;
import org.apache.samza.SamzaException;
import org.apache.samza.config.Config;
import org.apache.samza.config.JobConfig;
import org.apache.samza.execution.ExecutionPlanner;
import org.apache.samza.execution.JobGraph;
import org.apache.samza.execution.StreamEdge;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Computes and assigns partition counts for the intermediate streams of a {@link JobGraph}.
 *
 * <p>Partition counts are assigned in three steps:
 * <ol>
 *   <li>Partition counts are propagated across joined stream sets.</li>
 *   <li>A default count is assigned to every intermediate stream that is still unassigned.</li>
 *   <li>The result is validated.</li>
 * </ol>
 */
class IntermediateStreamManager {
    private static final Logger log = LoggerFactory.getLogger(IntermediateStreamManager.class);

    /**
     * Sentinel partition count. It is used as the default when the config key is absent, returned by
     * {@link #maxPartitions} for an empty collection, and treated as "no partition count set yet" by the
     * join-propagation step.
     */
    private static final int PARTITIONS_UNKNOWN = -1;

    private final Config config;

    /** Upper bound applied to a partition count that is inferred rather than explicitly configured. */
    @VisibleForTesting
    static final int MAX_INFERRED_PARTITIONS = 256;

    IntermediateStreamManager(Config config) {
        this.config = config;
    }

    /**
     * Assigns partition counts to the intermediate streams of {@code jobGraph}.
     *
     * <p>Streams in {@code joinedStreamSets} first inherit the partition count of a stream in the same
     * set whose count is already set. Intermediate streams still without a positive count then receive
     * the configured or inferred default. Finally, every intermediate stream is validated.
     *
     * @param jobGraph the graph whose intermediate stream edges are updated in place
     * @param joinedStreamSets groups of streams that are joined together
     * @throws SamzaException if the configured partition count is not positive, or if any intermediate
     *     stream ends up without a positive partition count
     */
    void calculatePartitions(JobGraph jobGraph, Collection<ExecutionPlanner.StreamSet> joinedStreamSets) {
        setJoinedIntermediateStreamPartitions(joinedStreamSets);
        setIntermediateStreamPartitions(jobGraph);
        validateIntermediateStreamPartitions(jobGraph);
    }

    /**
     * Assigns a default partition count to every intermediate stream whose count is not positive.
     *
     * <p>The default comes from the {@code JobConfig.JOB_INTERMEDIATE_STREAM_PARTITIONS()} config key.
     * When that key is absent, the default is inferred as the maximum partition count over the job's
     * input and output streams, capped at {@link #MAX_INFERRED_PARTITIONS}.
     *
     * @throws SamzaException if the config key holds a value other than -1 that is not positive
     */
    private void setIntermediateStreamPartitions(JobGraph jobGraph) {
        String defaultPartitionsConfigProperty = JobConfig.JOB_INTERMEDIATE_STREAM_PARTITIONS();
        int partitions = this.config.getInt(defaultPartitionsConfigProperty, PARTITIONS_UNKNOWN);
        if (partitions == PARTITIONS_UNKNOWN) {
            // Not configured: infer MAX(max input partitions, max output partitions). If the job has no
            // input or output streams, this stays -1 and validation rejects it later.
            int maxInPartitions = maxPartitions(jobGraph.getInputStreams());
            int maxOutPartitions = maxPartitions(jobGraph.getOutputStreams());
            int inferredPartitions = Math.max(maxInPartitions, maxOutPartitions);
            if (inferredPartitions > MAX_INFERRED_PARTITIONS) {
                // Report the un-clamped value so the warning shows the actual inferred count.
                log.warn("Inferred intermediate stream partition count {} is greater than the max {}. Using the max.",
                        inferredPartitions, MAX_INFERRED_PARTITIONS);
                partitions = MAX_INFERRED_PARTITIONS;
            } else {
                partitions = inferredPartitions;
            }
        } else {
            // Reject zero or negative values that were explicitly configured.
            if (partitions <= 0) {
                throw new SamzaException(String.format("Invalid value %d specified for config property %s",
                        partitions, defaultPartitionsConfigProperty));
            }
            log.info("Using partition count value {} specified for config property {}",
                    partitions, defaultPartitionsConfigProperty);
        }

        for (StreamEdge edge : jobGraph.getIntermediateStreamEdges()) {
            if (edge.getPartitionCount() <= 0) {
                log.info("Set the partition count for intermediate stream {} to {}.", edge.getName(), partitions);
                edge.setPartitionCount(partitions);
            }
        }
    }

    /**
     * Propagates partition counts through joined stream sets.
     *
     * <p>For each set that contains at least one stream with a partition count set, every unset stream in
     * that set receives the same count. Each newly set stream may, in turn, unlock other not-yet-processed
     * sets it belongs to, so those sets are queued for another pass.
     */
    private static void setJoinedIntermediateStreamPartitions(Collection<ExecutionPlanner.StreamSet> joinedStreamSets) {
        // Map every stream with an unset partition count to all the joined stream sets it appears in.
        HashMultimap<StreamEdge, ExecutionPlanner.StreamSet> intermediateStreamToStreamSets = HashMultimap.create();
        for (ExecutionPlanner.StreamSet streamSet : joinedStreamSets) {
            for (StreamEdge streamEdge : streamSet.getStreamEdges()) {
                if (streamEdge.getPartitionCount() == PARTITIONS_UNKNOWN) {
                    intermediateStreamToStreamSets.put(streamEdge, streamSet);
                }
            }
        }

        HashSet<ExecutionPlanner.StreamSet> pendingStreamSets = new HashSet<>(joinedStreamSets);
        HashSet<ExecutionPlanner.StreamSet> processedStreamSets = new HashSet<>();
        while (!pendingStreamSets.isEmpty()) {
            ExecutionPlanner.StreamSet streamSet = pendingStreamSets.iterator().next();
            pendingStreamSets.remove(streamSet);

            Optional<StreamEdge> streamWithSetPartitions = streamSet.getStreamEdges().stream()
                    .filter(edge -> edge.getPartitionCount() != PARTITIONS_UNKNOWN)
                    .findAny();
            if (!streamWithSetPartitions.isPresent()) {
                // Nothing is set in this set yet. It is re-queued below if one of its unset streams
                // gets a count through another set.
                continue;
            }

            processedStreamSets.add(streamSet);
            int partitions = streamWithSetPartitions.get().getPartitionCount();
            for (StreamEdge streamEdge : streamSet.getStreamEdges()) {
                if (streamEdge.getPartitionCount() != PARTITIONS_UNKNOWN) {
                    continue;
                }
                streamEdge.setPartitionCount(partitions);
                // Queue every unprocessed set that shares this stream, because it can now be resolved.
                intermediateStreamToStreamSets.get(streamEdge).stream()
                        .filter(s -> !processedStreamSets.contains(s))
                        .forEach(pendingStreamSets::add);
            }
        }
    }

    /**
     * Ensures every intermediate stream has a positive partition count.
     *
     * @throws SamzaException naming the first intermediate stream whose count is not positive
     */
    private static void validateIntermediateStreamPartitions(JobGraph jobGraph) {
        for (StreamEdge edge : jobGraph.getIntermediateStreamEdges()) {
            if (edge.getPartitionCount() <= 0) {
                throw new SamzaException(String.format("Failed to assign valid partition count to Stream %s", edge.getName()));
            }
        }
    }

    /**
     * Returns the largest partition count among {@code edges}, or -1 if {@code edges} is empty.
     */
    static int maxPartitions(Collection<StreamEdge> edges) {
        return edges.stream().mapToInt(StreamEdge::getPartitionCount).max().orElse(PARTITIONS_UNKNOWN);
    }
}

