/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kafka.streams.processor.internals.assignment;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.SortedMap;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.UUID;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.BiPredicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.kafka.common.Cluster;
import org.apache.kafka.common.Node;
import org.apache.kafka.common.PartitionInfo;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.TopicPartitionInfo;
import org.apache.kafka.common.utils.Time;
import org.apache.kafka.streams.KeyValue;
import org.apache.kafka.streams.processor.TaskId;
import org.apache.kafka.streams.processor.internals.InternalTopicManager;
import org.apache.kafka.streams.processor.internals.TopologyMetadata;
import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration;
import org.apache.kafka.streams.processor.internals.assignment.ClientState;
import org.apache.kafka.streams.processor.internals.assignment.Graph;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Rack-aware optimizer for an existing Kafka Streams task assignment.
 *
 * <p>It re-distributes tasks that are already assigned to clients so that fewer task partitions
 * have no replica on the owning client's rack. The problem is modelled as a min-cost flow over a
 * {@link Graph}:
 * <ul>
 *   <li>node {@code -1} is the source, with a capacity-1 edge to every task node;</li>
 *   <li>task nodes use ids {@code 0 .. numTasks-1}, each with a capacity-1 edge to every client
 *       node, weighted by the task/client cost;</li>
 *   <li>client nodes use ids {@code numTasks .. numTasks+numClients-1}, each with an edge to the
 *       sink whose capacity equals the number of the considered tasks that client currently owns;</li>
 *   <li>the sink uses id {@code numTasks+numClients}.</li>
 * </ul>
 * Because client-to-sink capacities equal the current per-client counts, an optimization only
 * re-shuffles tasks between clients and never changes how many of the considered tasks each
 * client owns (this is re-checked after every reassignment).
 *
 * <p>Partition rack information is loaded by {@link #canEnableRackAwareAssignor()}; cost
 * computation throws {@link IllegalStateException} if the required rack information is missing.
 */
public class RackAwareTaskAssignor {
    // NOTE(review): presumably the default costs callers use for stateless tasks — confirm at call sites.
    public static final int STATELESS_TRAFFIC_COST = 1;
    public static final int STATELESS_NON_OVERLAP_COST = 0;
    private static final Logger log = LoggerFactory.getLogger(RackAwareTaskAssignor.class);
    // Node id of the flow-graph source; task nodes occupy ids 0..numTasks-1, so -1 cannot collide.
    private static final int SOURCE_ID = -1;
    // Upper bound on the number of full pairwise passes performed by optimizeStandbyTasks.
    private static final int STANDBY_OPTIMIZER_MAX_ITERATION = 4;

    private final Cluster fullMetadata;
    private final Map<TaskId, Set<TopicPartition>> partitionsForTask;
    private final Map<TaskId, Set<TopicPartition>> changelogPartitionsForTask;
    private final AssignorConfiguration.AssignmentConfigs assignmentConfigs;
    // Racks hosting at least one replica of each partition; filled while validating partition racks.
    private final Map<TopicPartition, Set<String>> racksForPartition;
    // Rack of each client process; filled by validateClientRack during construction.
    private final Map<UUID, String> racksForProcess;
    private final InternalTopicManager internalTopicManager;
    private final boolean validClientRack;
    private final Time time;
    // Cached result of canEnableRackAwareAssignor(); null until it is first computed.
    private Boolean canEnable = null;

    /**
     * Creates an assignor and eagerly validates the racks reported by the clients.
     *
     * @param fullMetadata               cluster metadata used to look up replica racks of source partitions
     * @param partitionsForTask          source partitions of each task (used for active-task costs)
     * @param changelogPartitionsForTask changelog partitions of each task (used for standby-task costs)
     * @param tasksForTopicGroup         not used by this constructor
     * @param racksForProcessConsumer    per process, the optional rack id reported by each of its consumers
     * @param internalTopicManager       used to describe topics whose replica information is not in {@code fullMetadata}
     * @param assignmentConfigs          assignment configuration (rack-aware strategy, number of standby replicas)
     * @param time                       clock used to measure optimization duration; must not be null
     * @throws NullPointerException if {@code time} is null
     */
    public RackAwareTaskAssignor(Cluster fullMetadata, Map<TaskId, Set<TopicPartition>> partitionsForTask, Map<TaskId, Set<TopicPartition>> changelogPartitionsForTask, Map<TopologyMetadata.Subtopology, Set<TaskId>> tasksForTopicGroup, Map<UUID, Map<String, Optional<String>>> racksForProcessConsumer, InternalTopicManager internalTopicManager, AssignorConfiguration.AssignmentConfigs assignmentConfigs, Time time) {
        this.fullMetadata = fullMetadata;
        this.partitionsForTask = partitionsForTask;
        this.changelogPartitionsForTask = changelogPartitionsForTask;
        this.internalTopicManager = internalTopicManager;
        this.assignmentConfigs = assignmentConfigs;
        this.racksForPartition = new HashMap<>();
        this.racksForProcess = new HashMap<>();
        this.time = Objects.requireNonNull(time, "Time was not specified");
        this.validClientRack = this.validateClientRack(racksForProcessConsumer);
    }

    /**
     * @return whether every client process reported exactly one, non-empty rack id
     */
    public boolean validClientRack() {
        return this.validClientRack;
    }

    /**
     * Decides whether rack-aware assignment can be used, loading partition rack information as a
     * side effect.
     *
     * <p>Returns false immediately (without caching) when the configured strategy is {@code "none"}.
     * Otherwise the result is computed once and cached: client racks must be valid, and every source
     * partition must have rack information. When standby replicas are configured, every changelog
     * partition must have rack information as well.
     *
     * @return true if rack-aware assignment can be enabled
     */
    public synchronized boolean canEnableRackAwareAssignor() {
        if ("none".equals(this.assignmentConfigs.rackAwareAssignmentStrategy)) {
            return false;
        }
        if (this.canEnable != null) {
            return this.canEnable;
        }
        this.canEnable = this.validClientRack && this.validateTopicPartitionRack(false);
        if (this.assignmentConfigs.numStandbyReplicas == 0 || !this.canEnable) {
            return this.canEnable;
        }
        this.canEnable = this.validateTopicPartitionRack(true);
        return this.canEnable;
    }

    /**
     * Collects the topics whose replica racks still have to be fetched via the topic manager.
     *
     * <p>For changelog topics, every changelog topic is added. For source topics, partitions whose
     * replicas are present in {@code fullMetadata} have their racks recorded directly. Only topics
     * with a partition lacking replica information are added to {@code topicsToDescribe}.
     *
     * @param topicsToDescribe output set that receives topic names to describe
     * @param changelog        true to process changelog partitions, false for source partitions
     * @return false if a source partition is missing from the cluster metadata or one of its replicas
     *         has no rack; true otherwise
     */
    public boolean populateTopicsToDescribe(Set<String> topicsToDescribe, boolean changelog) {
        if (changelog) {
            for (Set<TopicPartition> topicPartitions : this.changelogPartitionsForTask.values()) {
                for (TopicPartition topicPartition : topicPartitions) {
                    topicsToDescribe.add(topicPartition.topic());
                }
            }
            return true;
        }
        for (Set<TopicPartition> topicPartitions : this.partitionsForTask.values()) {
            for (TopicPartition topicPartition : topicPartitions) {
                PartitionInfo partitionInfo = this.fullMetadata.partition(topicPartition);
                if (partitionInfo == null) {
                    log.error("TopicPartition {} doesn't exist in cluster", topicPartition);
                    return false;
                }
                Node[] replicas = partitionInfo.replicas();
                if (replicas == null || replicas.length == 0) {
                    // Replica info missing from metadata: fetch it later through the topic manager.
                    topicsToDescribe.add(topicPartition.topic());
                    continue;
                }
                for (Node node : replicas) {
                    if (!node.hasRack()) {
                        log.warn("Node {} for topic partition {} doesn't have rack", node, topicPartition);
                        return false;
                    }
                    this.racksForPartition.computeIfAbsent(topicPartition, k -> new HashSet<>()).add(node.rack());
                }
            }
        }
        return true;
    }

    /**
     * Ensures rack information is available for all source (or changelog) partitions, describing
     * topics through the topic manager when the cluster metadata lacks replica information.
     *
     * @param changelogTopics true to validate changelog partitions, false for source partitions
     * @return true if every relevant partition has at least one replica and all replicas have racks
     */
    private boolean validateTopicPartitionRack(boolean changelogTopics) {
        Set<String> topicsToDescribe = new HashSet<>();
        if (!this.populateTopicsToDescribe(topicsToDescribe, changelogTopics)) {
            return false;
        }
        if (topicsToDescribe.isEmpty()) {
            return true;
        }
        log.info("Fetching PartitionInfo for topics {}", topicsToDescribe);
        try {
            Map<String, List<TopicPartitionInfo>> topicPartitionInfo = this.internalTopicManager.getTopicPartitionInfo(topicsToDescribe);
            // Check membership rather than size, so an unexpected extra topic cannot mask a missing one.
            if (!topicPartitionInfo.keySet().containsAll(topicsToDescribe)) {
                topicsToDescribe.removeAll(topicPartitionInfo.keySet());
                log.error("Failed to describe topic for {}", topicsToDescribe);
                return false;
            }
            for (Map.Entry<String, List<TopicPartitionInfo>> entry : topicPartitionInfo.entrySet()) {
                String topic = entry.getKey();
                for (TopicPartitionInfo partitionInfo : entry.getValue()) {
                    int partition = partitionInfo.partition();
                    List<Node> replicas = partitionInfo.replicas();
                    if (replicas == null || replicas.isEmpty()) {
                        log.error("No replicas found for topic partition {}: {}", topic, partition);
                        return false;
                    }
                    TopicPartition topicPartition = new TopicPartition(topic, partition);
                    for (Node node : replicas) {
                        if (!node.hasRack()) {
                            // Mirror populateTopicsToDescribe so this rejection is not silent.
                            log.warn("Node {} for topic partition {} doesn't have rack", node, topicPartition);
                            return false;
                        }
                        this.racksForPartition.computeIfAbsent(topicPartition, k -> new HashSet<>()).add(node.rack());
                    }
                }
            }
        } catch (Exception e) {
            log.error("Failed to describe topics {}", topicsToDescribe, e);
            return false;
        }
        return true;
    }

    /**
     * Records the rack of every client process and checks that it is well defined.
     *
     * @param racksForProcessConsumer per process, the optional rack id reported by each consumer; may be null
     * @return false if the map is null, a process has no consumers, a consumer has no rack id, or two
     *         consumers of the same process report different rack ids; true otherwise
     */
    private boolean validateClientRack(Map<UUID, Map<String, Optional<String>>> racksForProcessConsumer) {
        if (racksForProcessConsumer == null) {
            return false;
        }
        for (Map.Entry<UUID, Map<String, Optional<String>>> entry : racksForProcessConsumer.entrySet()) {
            UUID processId = entry.getKey();
            KeyValue<String, String> previousRackInfo = null;
            for (Map.Entry<String, Optional<String>> rackEntry : entry.getValue().entrySet()) {
                Optional<String> rack = rackEntry.getValue();
                if (!rack.isPresent()) {
                    log.error("RackId doesn't exist for process {} and consumer {}", processId, rackEntry.getKey());
                    return false;
                }
                if (previousRackInfo == null) {
                    previousRackInfo = KeyValue.pair(rackEntry.getKey(), rack.get());
                } else if (!previousRackInfo.value.equals(rack.get())) {
                    log.error("Consumers {} and {} for same process {} has different rackId {} and {}. File a ticket for this bug",
                        previousRackInfo.key, rackEntry.getKey(), processId, previousRackInfo.value, rack.get());
                    return false;
                }
            }
            if (previousRackInfo == null) {
                log.error("RackId doesn't exist for process {}", processId);
                return false;
            }
            this.racksForProcess.put(processId, previousRackInfo.value);
        }
        return true;
    }

    /**
     * @return read-only view of the rack of each client process
     */
    public Map<UUID, String> racksForProcess() {
        return Collections.unmodifiableMap(this.racksForProcess);
    }

    /**
     * @return read-only view of the racks hosting replicas of each partition loaded so far
     */
    public Map<TopicPartition, Set<String>> racksForPartition() {
        return Collections.unmodifiableMap(this.racksForPartition);
    }

    /**
     * Returns the rack of a client process.
     *
     * @throws IllegalStateException if no rack is known for {@code processId}
     */
    private String rackForProcess(UUID processId) {
        String clientRack = this.racksForProcess.get(processId);
        if (clientRack == null) {
            throw new IllegalStateException("Client " + processId + " doesn't have rack configured. Maybe forgot to call canEnableRackAwareAssignor first");
        }
        return clientRack;
    }

    /**
     * Computes the cost of placing a task on a client.
     *
     * <p>{@code trafficCost} is added once for every partition of the task (source partitions for
     * active tasks, changelog partitions for standbys) that has no replica on the client's rack.
     * {@code nonOverlapCost} is added once if the client does not currently own the task.
     *
     * @throws IllegalStateException if the client has no rack, the task has no partitions, or a
     *                               partition has no rack information
     */
    private int getCost(TaskId taskId, UUID processId, boolean inCurrentAssignment, int trafficCost, int nonOverlapCost, boolean isStandby) {
        String clientRack = this.rackForProcess(processId);
        Set<TopicPartition> topicPartitions = isStandby ? this.changelogPartitionsForTask.get(taskId) : this.partitionsForTask.get(taskId);
        if (topicPartitions == null || topicPartitions.isEmpty()) {
            throw new IllegalStateException("Task " + taskId + " has no TopicPartitions");
        }
        int cost = 0;
        for (TopicPartition tp : topicPartitions) {
            Set<String> tpRacks = this.racksForPartition.get(tp);
            if (tpRacks == null || tpRacks.isEmpty()) {
                throw new IllegalStateException("TopicPartition " + tp + " has no rack information. Maybe forgot to call canEnableRackAwareAssignor first");
            }
            if (!tpRacks.contains(clientRack)) {
                cost += trafficCost;
            }
        }
        if (!inCurrentAssignment) {
            cost += nonOverlapCost;
        }
        return cost;
    }

    // Sink node id: placed right after all task and client nodes.
    private static int getSinkNodeID(List<UUID> clientList, List<TaskId> taskIdList) {
        return clientList.size() + taskIdList.size();
    }

    // Client node ids follow the task node ids (0..numTasks-1).
    private static int getClientNodeId(List<TaskId> taskIdList, int clientIndex) {
        return clientIndex + taskIdList.size();
    }

    // Inverse of getClientNodeId.
    private static int getClientIndex(List<TaskId> taskIdList, int clientNodeId) {
        return clientNodeId - taskIdList.size();
    }

    /**
     * @return total cost of the current active-task assignment of {@code activeTasks}
     */
    long activeTasksCost(SortedSet<TaskId> activeTasks, SortedMap<UUID, ClientState> clientStates, int trafficCost, int nonOverlapCost) {
        return this.tasksCost(activeTasks, clientStates, trafficCost, nonOverlapCost, ClientState::hasActiveTask, false, false);
    }

    /**
     * @return total cost of the current standby-task assignment of {@code standbyTasks}
     */
    long standByTasksCost(SortedSet<TaskId> standbyTasks, SortedMap<UUID, ClientState> clientStates, int trafficCost, int nonOverlapCost) {
        return this.tasksCost(standbyTasks, clientStates, trafficCost, nonOverlapCost, ClientState::hasStandbyTask, true, true);
    }

    // Builds the flow graph for the current assignment (without solving it) and returns its cost.
    private long tasksCost(SortedSet<TaskId> tasks, SortedMap<UUID, ClientState> clientStates, int trafficCost, int nonOverlapCost, BiPredicate<ClientState, TaskId> hasAssignedTask, boolean hasReplica, boolean isStandby) {
        if (tasks.isEmpty()) {
            return 0L;
        }
        List<UUID> clientList = new ArrayList<>(clientStates.keySet());
        List<TaskId> taskIdList = new ArrayList<>(tasks);
        Graph<Integer> graph = this.constructTaskGraph(clientList, taskIdList, clientStates, new HashMap<>(), new HashMap<>(), hasAssignedTask, trafficCost, nonOverlapCost, hasReplica, isStandby);
        return graph.totalCost();
    }

    /**
     * Re-assigns active tasks across all clients to minimize total cost, keeping each client's number
     * of active tasks (among {@code activeTasks}) unchanged. {@code clientStates} is mutated in place.
     *
     * @return the cost of the optimized assignment, or 0 if {@code activeTasks} is empty
     * @throws IllegalArgumentException if a task is unassigned or assigned to more than one client
     */
    public long optimizeActiveTasks(SortedSet<TaskId> activeTasks, SortedMap<UUID, ClientState> clientStates, int trafficCost, int nonOverlapCost) {
        if (activeTasks.isEmpty()) {
            return 0L;
        }
        log.info("Assignment before active task optimization is {}\n with cost {}", clientStates, this.activeTasksCost(activeTasks, clientStates, trafficCost, nonOverlapCost));
        long startTime = this.time.milliseconds();
        List<UUID> clientList = new ArrayList<>(clientStates.keySet());
        List<TaskId> taskIdList = new ArrayList<>(activeTasks);
        Map<TaskId, UUID> taskClientMap = new HashMap<>();
        Map<UUID, Integer> originalAssignedTaskNumber = new HashMap<>();
        Graph<Integer> graph = this.constructTaskGraph(clientList, taskIdList, clientStates, taskClientMap, originalAssignedTaskNumber, ClientState::hasActiveTask, trafficCost, nonOverlapCost, false, false);
        graph.solveMinCostFlow();
        long cost = graph.totalCost();
        this.assignTaskFromMinCostFlow(graph, clientList, taskIdList, clientStates, originalAssignedTaskNumber, taskClientMap, ClientState::assignActive, ClientState::unassignActive, ClientState::hasActiveTask);
        long duration = this.time.milliseconds() - startTime;
        log.info("Assignment after {} milliseconds for active task optimization is {}\n with cost {}", duration, clientStates, cost);
        return cost;
    }

    /**
     * Improves the standby-task assignment by repeatedly re-balancing standbys between pairs of
     * clients located on different racks. Each pair exchange preserves both clients' standby counts.
     * Passes stop once a pass moves nothing or after {@code STANDBY_OPTIMIZER_MAX_ITERATION} passes.
     * {@code clientStates} is mutated in place.
     *
     * @param moveStandbyTask decides whether a standby may move from a source to a destination client
     * @return the cost of the resulting standby assignment
     * @throws IllegalStateException if a client has no known rack
     */
    public long optimizeStandbyTasks(SortedMap<UUID, ClientState> clientStates, int trafficCost, int nonOverlapCost, MoveStandbyTaskPredicate moveStandbyTask) {
        // Standbys of `source` that `destination` doesn't already own and that the predicate allows moving.
        BiFunction<ClientState, ClientState, List<TaskId>> getMovableTasks = (source, destination) -> source.standbyTasks().stream()
            .filter(task -> !destination.hasAssignedTask(task))
            .filter(task -> moveStandbyTask.canMove(source, destination, task, clientStates))
            .sorted()
            .collect(Collectors.toList());
        long startTime = this.time.milliseconds();
        List<UUID> clientList = new ArrayList<>(clientStates.keySet());
        SortedSet<TaskId> standbyTasks = new TreeSet<>();
        clientStates.values().forEach(clientState -> standbyTasks.addAll(clientState.standbyTasks()));
        log.info("Assignment before standby task optimization is {}\n with cost {}", clientStates, this.standByTasksCost(standbyTasks, clientStates, trafficCost, nonOverlapCost));
        boolean taskMoved = true;
        int round = 0;
        while (taskMoved && round < STANDBY_OPTIMIZER_MAX_ITERATION) {
            taskMoved = false;
            round++;
            for (int i = 0; i < clientList.size(); i++) {
                ClientState clientState1 = clientStates.get(clientList.get(i));
                for (int j = i + 1; j < clientList.size(); j++) {
                    ClientState clientState2 = clientStates.get(clientList.get(j));
                    String rack1 = this.rackForProcess(clientState1.processId());
                    String rack2 = this.rackForProcess(clientState2.processId());
                    // Moving standbys between clients on the same rack cannot change the traffic cost.
                    if (rack1.equals(rack2)) {
                        continue;
                    }
                    List<TaskId> movable1 = getMovableTasks.apply(clientState1, clientState2);
                    List<TaskId> movable2 = getMovableTasks.apply(clientState2, clientState1);
                    // Per-client counts are preserved, so a move is only possible with candidates on both sides.
                    if (movable1.isEmpty() || movable2.isEmpty()) {
                        continue;
                    }
                    List<TaskId> taskIdList = Stream.concat(movable1.stream(), movable2.stream()).sorted().collect(Collectors.toList());
                    List<UUID> clients = Stream.of(clientList.get(i), clientList.get(j)).sorted().collect(Collectors.toList());
                    Map<TaskId, UUID> taskClientMap = new HashMap<>();
                    Map<UUID, Integer> originalAssignedTaskNumber = new HashMap<>();
                    Graph<Integer> graph = this.constructTaskGraph(clients, taskIdList, clientStates, taskClientMap, originalAssignedTaskNumber, ClientState::hasStandbyTask, trafficCost, nonOverlapCost, true, true);
                    graph.solveMinCostFlow();
                    taskMoved |= this.assignTaskFromMinCostFlow(graph, clients, taskIdList, clientStates, originalAssignedTaskNumber, taskClientMap, ClientState::assignStandby, ClientState::unassignStandby, ClientState::hasStandbyTask);
                }
            }
        }
        long cost = this.standByTasksCost(standbyTasks, clientStates, trafficCost, nonOverlapCost);
        long duration = this.time.milliseconds() - startTime;
        log.info("Assignment after {} rounds and {} milliseconds for standby task optimization is {}\n with cost {}", round, duration, clientStates, cost);
        return cost;
    }

    /**
     * Adds, for every task in {@code taskIdList}, one to the count of each client in
     * {@code clientStates} that currently owns it.
     */
    private static void countAssignedTasks(List<TaskId> taskIdList, Map<UUID, ClientState> clientStates, BiPredicate<ClientState, TaskId> hasAssignedTask, Map<UUID, Integer> counts) {
        for (TaskId taskId : taskIdList) {
            for (Map.Entry<UUID, ClientState> clientState : clientStates.entrySet()) {
                if (hasAssignedTask.test(clientState.getValue(), taskId)) {
                    counts.merge(clientState.getKey(), 1, Integer::sum);
                }
            }
        }
    }

    /**
     * Builds the flow graph described in the class documentation, with the current assignment
     * loaded as the initial flow.
     *
     * @param taskClientMap              output: current owner of each task (last owner wins when {@code hasReplica})
     * @param originalAssignedTaskNumber output: per client in {@code clientStates}, how many tasks of
     *                                   {@code taskIdList} it currently owns
     * @param hasReplica                 if false, a task owned by more than one listed client is an error
     * @throws IllegalArgumentException if a task is owned by no listed client, or by several when
     *                                  {@code hasReplica} is false
     */
    private Graph<Integer> constructTaskGraph(List<UUID> clientList, List<TaskId> taskIdList, Map<UUID, ClientState> clientStates, Map<TaskId, UUID> taskClientMap, Map<UUID, Integer> originalAssignedTaskNumber, BiPredicate<ClientState, TaskId> hasAssignedTask, int trafficCost, int nonOverlapCost, boolean hasReplica, boolean isStandby) {
        Graph<Integer> graph = new Graph<>();
        countAssignedTasks(taskIdList, clientStates, hasAssignedTask, originalAssignedTaskNumber);
        for (int taskNodeId = 0; taskNodeId < taskIdList.size(); taskNodeId++) {
            TaskId taskId = taskIdList.get(taskNodeId);
            for (int j = 0; j < clientList.size(); j++) {
                int clientNodeId = getClientNodeId(taskIdList, j);
                UUID processId = clientList.get(j);
                boolean assigned = hasAssignedTask.test(clientStates.get(processId), taskId);
                int cost = this.getCost(taskId, processId, assigned, trafficCost, nonOverlapCost, isStandby);
                if (assigned) {
                    if (!hasReplica && taskClientMap.containsKey(taskId)) {
                        throw new IllegalArgumentException("Task " + taskId + " assigned to multiple clients " + processId + ", " + taskClientMap.get(taskId));
                    }
                    taskClientMap.put(taskId, processId);
                }
                // The current assignment becomes the initial flow on the task -> client edge.
                graph.addEdge(taskNodeId, clientNodeId, 1, cost, assigned ? 1 : 0);
            }
            if (!taskClientMap.containsKey(taskId)) {
                throw new IllegalArgumentException("Task " + taskId + " not assigned to any client");
            }
            graph.addEdge(SOURCE_ID, taskNodeId, 1, 0, 1);
        }
        int sinkId = getSinkNodeID(clientList, taskIdList);
        for (int i = 0; i < clientList.size(); i++) {
            int clientNodeId = getClientNodeId(taskIdList, i);
            // Capacity equal to the current count pins how many tasks each client ends up with.
            int capacity = originalAssignedTaskNumber.getOrDefault(clientList.get(i), 0);
            graph.addEdge(clientNodeId, sinkId, capacity, 0, capacity);
        }
        graph.setSourceNode(SOURCE_ID);
        graph.setSinkNode(sinkId);
        return graph;
    }

    /**
     * Applies a solved flow graph to {@code clientStates}: every task whose flow now goes to a
     * different client is unassigned from its original owner and assigned to the new one.
     *
     * @return true if at least one task was moved
     * @throws IllegalStateException if the flow does not assign every task exactly once, or the
     *                               per-client task counts differ from before the reassignment
     */
    private boolean assignTaskFromMinCostFlow(Graph<Integer> graph, List<UUID> clientList, List<TaskId> taskIdList, Map<UUID, ClientState> clientStates, Map<UUID, Integer> originalAssignedTaskNumber, Map<TaskId, UUID> taskClientMap, BiConsumer<ClientState, TaskId> assignTask, BiConsumer<ClientState, TaskId> unAssignTask, BiPredicate<ClientState, TaskId> hasAssignedTask) {
        int tasksAssigned = 0;
        boolean taskMoved = false;
        for (int taskNodeId = 0; taskNodeId < taskIdList.size(); taskNodeId++) {
            TaskId taskId = taskIdList.get(taskNodeId);
            for (Graph<Integer>.Edge edge : graph.edges(taskNodeId).values()) {
                if (edge.flow <= 0) {
                    continue;
                }
                tasksAssigned++;
                int clientIndex = getClientIndex(taskIdList, edge.destination);
                UUID processId = clientList.get(clientIndex);
                UUID originalProcessId = taskClientMap.get(taskId);
                // Task stays on its original client; nothing to change for this task.
                if (processId.equals(originalProcessId)) {
                    break;
                }
                unAssignTask.accept(clientStates.get(originalProcessId), taskId);
                assignTask.accept(clientStates.get(processId), taskId);
                taskMoved = true;
            }
        }
        if (tasksAssigned != taskIdList.size()) {
            throw new IllegalStateException("Computed active task assignment number " + tasksAssigned + " is different size " + taskIdList.size());
        }
        Map<UUID, Integer> assignedTaskNumber = new HashMap<>();
        countAssignedTasks(taskIdList, clientStates, hasAssignedTask, assignedTaskNumber);
        if (originalAssignedTaskNumber.size() != assignedTaskNumber.size()) {
            throw new IllegalStateException("There are " + originalAssignedTaskNumber.size() + " clients have  active tasks before assignment, but " + assignedTaskNumber.size() + " clients have active tasks after assignment");
        }
        for (Map.Entry<UUID, Integer> entry : originalAssignedTaskNumber.entrySet()) {
            int capacity = assignedTaskNumber.getOrDefault(entry.getKey(), 0);
            if (!Objects.equals(entry.getValue(), capacity)) {
                throw new IllegalStateException("There are " + entry.getValue() + " tasks assigned to client " + entry.getKey() + " before assignment, but " + capacity + " tasks  are assigned to it after assignment");
            }
        }
        return taskMoved;
    }

    /**
     * Decides whether a standby task may be moved from one client to another during standby
     * optimization.
     */
    @FunctionalInterface
    public interface MoveStandbyTaskPredicate {
        /**
         * @param source         client currently owning the standby
         * @param destination    candidate client to receive it
         * @param taskId         the standby task
         * @param clientStateMap all client states being optimized
         * @return true if the move is allowed
         */
        boolean canMove(ClientState source, ClientState destination, TaskId taskId, Map<UUID, ClientState> clientStateMap);
    }
}

