/*
 * Decompiled with CFR 0.152.
 */
package org.apache.samza.container.grouper.task;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.samza.SamzaException;
import org.apache.samza.config.Config;
import org.apache.samza.config.JobConfig;
import org.apache.samza.container.LocalityManager;
import org.apache.samza.container.TaskName;
import org.apache.samza.container.grouper.task.BalancingTaskNameGrouper;
import org.apache.samza.container.grouper.task.TaskAssignmentManager;
import org.apache.samza.job.model.ContainerModel;
import org.apache.samza.job.model.TaskModel;
import org.apache.samza.metrics.MetricsRegistry;
import org.apache.samza.metrics.MetricsRegistryMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Groups tasks into a fixed number of containers, where the count comes from
 * {@link JobConfig#getContainerCount()}.
 *
 * <p>{@link #group(Set)} sorts the tasks and deals them round-robin across containers "0".."N-1".
 * {@link #balance(Set, LocalityManager)} additionally reads the previously saved task-to-container
 * assignment and, when the container count has changed, moves only the tasks in excess of each
 * container's new target count, then persists the resulting assignment.
 */
public class GroupByContainerCount
implements BalancingTaskNameGrouper {
    private static final Logger log = LoggerFactory.getLogger(GroupByContainerCount.class);
    private final int containerCount;
    private final Config config;

    /**
     * @param config job configuration; the container count is read from it via {@link JobConfig}
     * @throws IllegalArgumentException if the configured container count is not positive
     */
    public GroupByContainerCount(Config config) {
        this.containerCount = new JobConfig(config).getContainerCount();
        this.config = config;
        if (this.containerCount <= 0) {
            throw new IllegalArgumentException("Must have at least one container");
        }
    }

    /**
     * Assigns tasks round-robin, in sorted order, to containers with ids "0".."containerCount-1".
     *
     * @param tasks the tasks to group; must be non-empty and at least as many as the container count
     * @return an unmodifiable set with exactly {@code containerCount} container models
     * @throws IllegalArgumentException if {@code tasks} fails validation
     */
    @Override
    public Set<ContainerModel> group(Set<TaskModel> tasks) {
        this.validateTasks(tasks);
        // Sorting makes the assignment independent of the input set's iteration order.
        List<TaskModel> sortedTasks = new ArrayList<>(tasks);
        Collections.sort(sortedTasks);
        List<Map<TaskName, TaskModel>> taskGroups = new ArrayList<>(this.containerCount);
        for (int i = 0; i < this.containerCount; ++i) {
            taskGroups.add(new HashMap<TaskName, TaskModel>());
        }
        for (int i = 0; i < sortedTasks.size(); ++i) {
            TaskModel taskModel = sortedTasks.get(i);
            taskGroups.get(i % this.containerCount).put(taskModel.getTaskName(), taskModel);
        }
        Set<ContainerModel> containerModels = new HashSet<>();
        for (int i = 0; i < this.containerCount; ++i) {
            containerModels.add(new ContainerModel(String.valueOf(i), taskGroups.get(i)));
        }
        return Collections.unmodifiableSet(containerModels);
    }

    /**
     * Groups tasks while trying to keep the previously saved task-to-container assignment.
     *
     * <p>If {@code localityManager} is null, this simply delegates to {@link #group(Set)} without
     * reading or writing assignments. Otherwise the saved assignment is read through a
     * {@link TaskAssignmentManager} built from the job config. If no usable assignment exists, or
     * the previous or current container count is 1, it falls back to {@link #group(Set)}. If the
     * count is unchanged, the saved assignment is reused as-is. If the count changed, tasks are
     * moved only out of containers that exceed their new target size (or are being removed) and
     * into containers below it.
     *
     * @param tasks the tasks to group; must be non-empty and at least as many as the container count
     * @param localityManager only checked for null here
     * @return an unmodifiable set of container models
     * @throws IllegalArgumentException if {@code tasks} fails validation
     */
    @Override
    public Set<ContainerModel> balance(Set<TaskModel> tasks, LocalityManager localityManager) {
        this.validateTasks(tasks);
        if (localityManager == null) {
            log.info("Locality manager is null. Cannot read or write task assignments. Invoking grouper.");
            return this.group(tasks);
        }
        TaskAssignmentManager taskAssignmentManager = new TaskAssignmentManager(this.config, new MetricsRegistryMap());
        taskAssignmentManager.init();
        try {
            List<TaskGroup> containers = this.getPreviousContainers(taskAssignmentManager, tasks);
            if (containers == null || containers.size() == 1 || this.containerCount == 1) {
                log.info("Balancing does not apply. Invoking grouper.");
                Set<ContainerModel> models = this.group(tasks);
                this.saveTaskAssignments(models, taskAssignmentManager);
                return models;
            }
            int prevContainerCount = containers.size();
            int containerDelta = this.containerCount - prevContainerCount;
            if (containerDelta == 0) {
                log.info("Container count has not changed. Reusing previous container models.");
                return this.buildContainerModels(tasks, containers);
            }
            log.info("Container count changed from {} to {}. Balancing tasks.", prevContainerCount, this.containerCount);
            int[] expectedTaskCountPerContainer = this.calculateTaskCountPerContainer(tasks.size(), prevContainerCount, this.containerCount);
            // Trim every previous container down to its new target; containers being removed
            // have a target of 0, so all of their tasks end up in the reassignment list.
            List<String> taskNamesToReassign = new LinkedList<>();
            for (int i = 0; i < prevContainerCount; ++i) {
                TaskGroup taskGroup = containers.get(i);
                while (taskGroup.size() > expectedTaskCountPerContainer[i]) {
                    taskNamesToReassign.add(taskGroup.removeTask());
                }
            }
            if (containerDelta > 0) {
                containers.addAll(this.createContainers(prevContainerCount, this.containerCount));
            } else {
                containers = containers.subList(0, this.containerCount);
            }
            this.assignTasksToContainers(expectedTaskCountPerContainer, taskNamesToReassign, containers);
            Set<ContainerModel> models = this.buildContainerModels(tasks, containers);
            this.saveTaskAssignments(models, taskAssignmentManager);
            return models;
        } finally {
            taskAssignmentManager.close();
        }
    }

    /**
     * Reads the saved task assignment and turns it into containers ordered by numeric id.
     *
     * @return the previous containers, or {@code null} if there is no saved assignment, it does not
     *     match the current task set (in which case it is deleted), or it cannot be parsed
     * @throws SamzaException if any saved container id is not an integer
     */
    private List<TaskGroup> getPreviousContainers(TaskAssignmentManager taskAssignmentManager, Set<TaskModel> tasks) {
        Map<String, String> taskToContainerId = taskAssignmentManager.readTaskAssignment();
        for (String containerId : taskToContainerId.values()) {
            try {
                Integer.parseInt(containerId);
            } catch (NumberFormatException nfe) {
                throw new SamzaException("GroupByContainerCount cannot handle non-integer processorIds!", nfe);
            }
        }
        if (taskToContainerId.isEmpty()) {
            log.info("No task assignment map was saved.");
            return null;
        }
        int taskCount = tasks.size();
        if (taskCount != taskToContainerId.size()) {
            log.warn("Current task count {} does not match saved task count {}. Stateful jobs may observe misalignment of keys!", taskCount, taskToContainerId.size());
            taskAssignmentManager.deleteTaskContainerMappings(taskToContainerId.keySet());
            return null;
        }
        // Same count but different names would otherwise leave saved names with no TaskModel
        // (NullPointerException in buildContainerModels) and drop current tasks silently.
        Set<String> currentTaskNames = new HashSet<>();
        for (TaskModel model : tasks) {
            currentTaskNames.add(model.getTaskName().getTaskName());
        }
        if (!currentTaskNames.equals(taskToContainerId.keySet())) {
            log.warn("Saved task assignment does not match the current task names. Stateful jobs may observe misalignment of keys!");
            taskAssignmentManager.deleteTaskContainerMappings(taskToContainerId.keySet());
            return null;
        }
        try {
            return this.getOrderedContainers(taskToContainerId);
        } catch (RuntimeException e) {
            // Best-effort: an unusable mapping makes the caller fall back to plain grouping.
            log.error("Exception while parsing task mapping", e);
            return null;
        }
    }

    /** Writes one task-name-to-container-id mapping for every task in {@code containers}. */
    private void saveTaskAssignments(Set<ContainerModel> containers, TaskAssignmentManager taskAssignmentManager) {
        for (ContainerModel container : containers) {
            for (TaskName taskName : container.getTasks().keySet()) {
                taskAssignmentManager.writeTaskContainerMapping(taskName.getTaskName(), container.getId());
            }
        }
    }

    /**
     * @throws IllegalArgumentException if there are no tasks or fewer tasks than containers
     */
    private void validateTasks(Set<TaskModel> tasks) {
        if (tasks.isEmpty()) {
            throw new IllegalArgumentException("No tasks found. Likely due to no input partitions. Can't run a job with no tasks.");
        }
        if (tasks.size() < this.containerCount) {
            throw new IllegalArgumentException(String.format("Your container count (%s) is larger than your task count (%s). Can't have containers with nothing to do, so aborting.", this.containerCount, tasks.size()));
        }
    }

    /** Creates empty task groups with ids {@code startContainerId} (inclusive) to {@code endContainerId} (exclusive). */
    private List<TaskGroup> createContainers(int startContainerId, int endContainerId) {
        List<TaskGroup> containers = new ArrayList<>(endContainerId - startContainerId);
        for (int i = startContainerId; i < endContainerId; ++i) {
            containers.add(new TaskGroup(String.valueOf(i), new ArrayList<String>()));
        }
        return containers;
    }

    /**
     * Fills each container up to its target count by taking task names from the front of
     * {@code taskNamesToAssign}; consumed names are removed from that list.
     */
    private void assignTasksToContainers(int[] taskCountPerContainer, List<String> taskNamesToAssign, List<TaskGroup> containers) {
        for (TaskGroup taskGroup : containers) {
            int targetCount = taskCountPerContainer[Integer.parseInt(taskGroup.getContainerId())];
            while (taskGroup.size() < targetCount) {
                String taskName = taskNamesToAssign.remove(0);
                taskGroup.addTaskName(taskName);
                log.info("Assigned task {} to container {}", taskName, taskGroup.getContainerId());
            }
        }
    }

    /**
     * Computes the target task count for each container id.
     *
     * <p>The array is sized to cover both the previous and current container counts. Ids below
     * {@code currentContainerCount} get an even share, with the first {@code taskCount %
     * currentContainerCount} ids getting one extra. Ids only present previously stay at 0.
     */
    private int[] calculateTaskCountPerContainer(int taskCount, int prevContainerCount, int currentContainerCount) {
        int[] newTaskCountPerContainer = new int[Math.max(currentContainerCount, prevContainerCount)];
        int baseCount = taskCount / currentContainerCount;
        int remainder = taskCount % currentContainerCount;
        for (int i = 0; i < currentContainerCount; ++i) {
            newTaskCountPerContainer[i] = i < remainder ? baseCount + 1 : baseCount;
        }
        return newTaskCountPerContainer;
    }

    /**
     * Converts task groups (task names per container id) into container models using the
     * {@link TaskModel}s from {@code tasks}.
     */
    private Set<ContainerModel> buildContainerModels(Set<TaskModel> tasks, List<TaskGroup> containerTasks) {
        Map<String, TaskModel> taskNameToModel = new HashMap<>();
        for (TaskModel model : tasks) {
            taskNameToModel.put(model.getTaskName().getTaskName(), model);
        }
        Set<ContainerModel> containerModels = new HashSet<>();
        for (TaskGroup container : containerTasks) {
            Map<TaskName, TaskModel> containerTaskModels = new HashMap<>();
            for (String taskName : container.taskNames) {
                TaskModel model = taskNameToModel.get(taskName);
                containerTaskModels.put(model.getTaskName(), model);
            }
            containerModels.add(new ContainerModel(container.containerId, containerTaskModels));
        }
        return Collections.unmodifiableSet(containerModels);
    }

    /**
     * Inverts a task-to-container map into task groups ordered by container id "0".."n-1".
     *
     * @throws IllegalStateException if the container ids are not exactly "0".."n-1"
     */
    private List<TaskGroup> getOrderedContainers(Map<String, String> taskToContainerId) {
        log.debug("Got task to container map: {}", taskToContainerId);
        Map<String, List<String>> containerIdToTaskNames = new HashMap<>();
        for (Map.Entry<String, String> entry : taskToContainerId.entrySet()) {
            containerIdToTaskNames.computeIfAbsent(entry.getValue(), id -> new ArrayList<>()).add(entry.getKey());
        }
        List<TaskGroup> containerTasks = new ArrayList<>(containerIdToTaskNames.size());
        for (int i = 0; i < containerIdToTaskNames.size(); ++i) {
            String containerId = String.valueOf(i);
            List<String> taskNames = containerIdToTaskNames.get(containerId);
            if (taskNames == null) {
                throw new IllegalStateException("Task mapping is missing container: " + i);
            }
            containerTasks.add(new TaskGroup(containerId, taskNames));
        }
        return containerTasks;
    }

    /** Mutable list of task names assigned to one container, kept sorted at construction. */
    private static class TaskGroup {
        private final List<String> taskNames = new LinkedList<String>();
        private final String containerId;

        private TaskGroup(String containerId, List<String> taskNames) {
            this.containerId = containerId;
            // Copy before sorting so the caller's list is not reordered.
            this.taskNames.addAll(taskNames);
            Collections.sort(this.taskNames);
        }

        public String getContainerId() {
            return this.containerId;
        }

        public void addTaskName(String taskName) {
            this.taskNames.add(taskName);
        }

        /** Removes and returns the last task name in the group. */
        public String removeTask() {
            return this.taskNames.remove(this.taskNames.size() - 1);
        }

        public int size() {
            return this.taskNames.size();
        }
    }
}

