/*
 * Decompiled with CFR 0.152.
 */
package org.apache.celeborn.service.deploy.master;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
import org.apache.celeborn.common.meta.DiskInfo;
import org.apache.celeborn.common.meta.DiskStatus;
import org.apache.celeborn.common.meta.WorkerInfo;
import org.apache.celeborn.common.protocol.PartitionLocation;
import org.apache.celeborn.common.protocol.StorageInfo;
import org.apache.commons.lang3.StringUtils;
import org.roaringbitmap.RoaringBitmap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

public class SlotsAllocator {
    /** Class-wide SLF4J logger. */
    private static final Logger logger = LoggerFactory.getLogger(SlotsAllocator.class);
    /** Shared random source; {@code roundRobin} uses it to choose the worker index it starts from. */
    private static final Random rand = new Random();
    /** Set to {@code true} by {@code initLoadAwareAlgorithm} once {@code taskAllocationRatio} has been computed. */
    private static boolean initialized = false;
    /** Share of allocations per disk group, filled by {@code initLoadAwareAlgorithm}; {@code null} until then. */
    private static double[] taskAllocationRatio = null;

    /**
     * Assigns each partition a primary location (and a replica location when requested) by walking
     * the worker list round-robin.
     *
     * <p>Up to three passes are made; each later pass only sees the partitions the previous pass
     * could not place:
     * <ol>
     *   <li>limited by the available slots of every worker's {@code HEALTHY} disks, with rack
     *       awareness as requested;</li>
     *   <li>without any disk-slot limit, rack awareness still as requested;</li>
     *   <li>without disk-slot limit and without rack awareness.</li>
     * </ol>
     *
     * @param workers candidate workers
     * @param partitionIds ids of the partitions to place
     * @param shouldReplicate whether every partition also needs a replica on another worker
     * @param shouldRackAware whether primary and replica must have different network locations
     * @return worker to (primary locations, replica locations); an empty map when there is nothing
     *     to place, when {@code workers} is empty, or when replication is requested with fewer
     *     than two workers
     */
    public static Map<WorkerInfo, Tuple2<List<PartitionLocation>, List<PartitionLocation>>> offerSlotsRoundRobin(List<WorkerInfo> workers, List<Integer> partitionIds, boolean shouldReplicate, boolean shouldRackAware) {
        Map<WorkerInfo, Tuple2<List<PartitionLocation>, List<PartitionLocation>>> slots = new HashMap<>();
        // An empty worker list must be rejected here: roundRobin draws its start index with
        // Random.nextInt(workers.size()), which throws for a bound of 0.
        if (partitionIds.isEmpty() || workers.isEmpty() || (shouldReplicate && workers.size() < 2)) {
            return slots;
        }

        // Slot budget per worker: one entry per healthy disk, seeded with that disk's available slots.
        Map<WorkerInfo, List<UsableDiskInfo>> restrictions = new HashMap<>();
        for (WorkerInfo worker : workers) {
            List<UsableDiskInfo> usableDisks = restrictions.computeIfAbsent(worker, v -> new ArrayList<>());
            for (DiskInfo disk : worker.diskInfos().values()) {
                if (disk.status().equals(DiskStatus.HEALTHY)) {
                    usableDisks.add(new UsableDiskInfo(disk, disk.availableSlots()));
                }
            }
        }

        List<Integer> remain = roundRobin(slots, partitionIds, workers, restrictions, shouldReplicate, shouldRackAware);
        if (!remain.isEmpty()) {
            // Second pass: drop the disk-slot limit.
            remain = roundRobin(slots, remain, workers, null, shouldReplicate, shouldRackAware);
        }
        if (!remain.isEmpty()) {
            // Last pass: drop rack awareness as well.
            roundRobin(slots, remain, workers, null, shouldReplicate, false);
        }
        return slots;
    }

    /**
     * Assigns partitions to workers while favouring disks with lower flush/fetch times.
     *
     * <p>Disks that are {@code HEALTHY} and have more than {@code minimumUsableSize} usable space are
     * sorted into groups by {@code placeDisksToGroups}; {@code getRestriction} converts the groups
     * into a per-worker slot budget, and a first round-robin pass allocates within that budget.
     * Partitions that do not fit are placed by a second pass over all workers without a slot budget,
     * and finally by a third pass that also ignores rack awareness.
     *
     * <p>Falls back to {@link #offerSlotsRoundRobin} when no disk is usable, or when replication is
     * requested but the usable disks belong to at most one worker.
     *
     * @param workers candidate workers
     * @param partitionIds ids of the partitions to place
     * @param shouldReplicate whether every partition also needs a replica on another worker
     * @param shouldRackAware whether primary and replica must have different network locations
     * @param minimumUsableSize a disk is only considered when its usable space is above this value
     * @param diskGroupCount number of groups the usable disks are split into
     * @param diskGroupGradient growth factor between the allocation ratios of adjacent groups
     * @param flushTimeWeight weight of the average flush time when ranking disks
     * @param fetchTimeWeight weight of the average fetch time when ranking disks
     * @return worker to (primary locations, replica locations); an empty map when there is nothing
     *     to place, when {@code workers} is empty, or when replication is requested with fewer
     *     than two workers
     */
    public static Map<WorkerInfo, Tuple2<List<PartitionLocation>, List<PartitionLocation>>> offerSlotsLoadAware(List<WorkerInfo> workers, List<Integer> partitionIds, boolean shouldReplicate, boolean shouldRackAware, long minimumUsableSize, int diskGroupCount, double diskGroupGradient, double flushTimeWeight, double fetchTimeWeight) {
        // Without workers nothing can be placed; bail out before any path reaches
        // Random.nextInt(0) inside roundRobin.
        if (partitionIds.isEmpty() || workers.isEmpty() || (shouldReplicate && workers.size() < 2)) {
            return new HashMap<>();
        }

        List<DiskInfo> usableDisks = new ArrayList<>();
        Map<DiskInfo, WorkerInfo> diskToWorkerMap = new HashMap<>();
        for (WorkerInfo worker : workers) {
            for (DiskInfo diskInfo : worker.diskInfos().values()) {
                diskToWorkerMap.put(diskInfo, worker);
                if (diskInfo.actualUsableSpace() > minimumUsableSize && diskInfo.status().equals(DiskStatus.HEALTHY)) {
                    usableDisks.add(diskInfo);
                }
            }
        }
        HashSet<WorkerInfo> usableWorkers = new HashSet<>();
        for (DiskInfo disk : usableDisks) {
            usableWorkers.add(diskToWorkerMap.get(disk));
        }
        if (shouldReplicate && usableWorkers.size() <= 1 || usableDisks.isEmpty()) {
            logger.warn("offer slots for {} fallback to roundrobin because there is no usable disks", StringUtils.join(partitionIds, ','));
            return offerSlotsRoundRobin(workers, partitionIds, shouldReplicate, shouldRackAware);
        }

        // The ratio table is cached in a static field. Rebuild it when it does not match the
        // requested group count, otherwise getRestriction could index past its end.
        // A change of diskGroupGradient alone is not detected here.
        if (!initialized || taskAllocationRatio.length != diskGroupCount) {
            initLoadAwareAlgorithm(diskGroupCount, diskGroupGradient);
        }

        Map<WorkerInfo, Tuple2<List<PartitionLocation>, List<PartitionLocation>>> slots = new HashMap<>();
        List<List<DiskInfo>> diskGroups = placeDisksToGroups(usableDisks, diskGroupCount, flushTimeWeight, fetchTimeWeight);
        int requiredSlots = shouldReplicate ? partitionIds.size() * 2 : partitionIds.size();
        Map<WorkerInfo, List<UsableDiskInfo>> restriction = getRestriction(diskGroups, diskToWorkerMap, requiredSlots);

        // getRestriction yields no worker at all when every usable disk reports zero available
        // slots; skip the restricted pass then, because roundRobin cannot start on an empty list.
        List<Integer> remainPartitions = partitionIds;
        if (!restriction.isEmpty()) {
            remainPartitions = roundRobin(slots, partitionIds, new ArrayList<>(restriction.keySet()), restriction, shouldReplicate, shouldRackAware);
        }
        if (!remainPartitions.isEmpty()) {
            // Second pass: all workers, no slot budget.
            remainPartitions = roundRobin(slots, remainPartitions, new ArrayList<>(workers), null, shouldReplicate, shouldRackAware);
        }
        if (!remainPartitions.isEmpty()) {
            // Last pass: drop rack awareness as well.
            roundRobin(slots, remainPartitions, new ArrayList<>(workers), null, shouldReplicate, false);
        }
        return slots;
    }

    /**
     * Takes one slot from a disk of the selected worker and returns a {@link StorageInfo} built from
     * that disk's mount point.
     *
     * <p>The search starts at the index remembered for the worker in {@code workerDiskIndex} (0 on
     * first use) and moves forward cyclically to the first disk whose {@code usableSlots} is
     * positive. The index after the chosen disk is stored back so the next call for the same worker
     * starts there.
     *
     * <p>The caller must make sure the worker has at least one disk with a positive slot count;
     * otherwise the search loop never terminates.
     *
     * @param workers worker list that {@code workerIndex} refers to
     * @param workerIndex index of the selected worker in {@code workers}
     * @param restrictions remaining slot budget per worker; the chosen disk's budget is decremented
     * @param workerDiskIndex next disk index to try per worker; updated by this call
     * @return storage info for the chosen disk
     */
    private static StorageInfo getStorageInfo(List<WorkerInfo> workers, int workerIndex, Map<WorkerInfo, List<UsableDiskInfo>> restrictions, Map<WorkerInfo, Integer> workerDiskIndex) {
        WorkerInfo worker = workers.get(workerIndex);
        List<UsableDiskInfo> disks = restrictions.get(worker);
        int cursor = workerDiskIndex.computeIfAbsent(worker, v -> 0);

        UsableDiskInfo chosen = disks.get(cursor);
        while (chosen.usableSlots <= 0L) {
            cursor = (cursor + 1) % disks.size();
            chosen = disks.get(cursor);
        }

        chosen.usableSlots--;
        StorageInfo storageInfo = new StorageInfo(chosen.diskInfo.mountPoint());
        workerDiskIndex.put(worker, (cursor + 1) % disks.size());
        return storageInfo;
    }

    /**
     * Places partitions on workers in round-robin order, starting from a randomly chosen worker.
     *
     * <p>For every partition a primary worker is chosen and, when {@code shouldReplicate} is set, a
     * replica worker is searched starting from the worker after the primary. With a non-null
     * {@code restrictions} map only workers that still have usable slots are eligible and one slot
     * is consumed per location. With {@code shouldRackAware} the replica must have a network
     * location different from the primary's. The pass stops as soon as a search wraps around
     * without finding an eligible worker.
     *
     * @param slots result map; the created locations are appended to the primary list
     *     ({@code _1}) or replica list ({@code _2}) of their worker
     * @param partitionIds ids of the partitions to place; this list is not modified
     * @param workers workers eligible in this pass
     * @param restrictions remaining slot budget per worker, or {@code null} for no limit; when
     *     non-null it must contain an entry for every worker in {@code workers}
     * @param shouldReplicate whether a replica location is created for every partition
     * @param shouldRackAware whether primary and replica must be in different network locations
     * @return the partition ids that could not be placed in this pass. All of them are returned
     *     when {@code workers} is empty, or when replication is requested with fewer than two
     *     workers.
     */
    private static List<Integer> roundRobin(Map<WorkerInfo, Tuple2<List<PartitionLocation>, List<PartitionLocation>>> slots, List<Integer> partitionIds, List<WorkerInfo> workers, Map<WorkerInfo, List<UsableDiskInfo>> restrictions, boolean shouldReplicate, boolean shouldRackAware) {
        List<Integer> partitionIdList = new ArrayList<>(partitionIds);
        int workerCount = workers.size();
        // Random.nextInt rejects a bound of 0, and with a single worker the replica index would
        // equal the primary index, putting both copies on the same worker. Leave everything for
        // the caller's next pass in both cases.
        if (workerCount == 0 || (shouldReplicate && workerCount < 2)) {
            return partitionIdList;
        }

        // Next disk index to try per worker, tracked separately for primary and replica picks.
        Map<WorkerInfo, Integer> workerDiskIndexForPrimary = new HashMap<>();
        Map<WorkerInfo, Integer> workerDiskIndexForReplica = new HashMap<>();
        int primaryIndex = rand.nextInt(workerCount);
        Iterator<Integer> iter = partitionIdList.iterator();
        outer:
        while (iter.hasNext()) {
            int nextPrimaryInd = primaryIndex;
            int partitionId = iter.next();
            StorageInfo storageInfo = new StorageInfo();
            if (restrictions != null) {
                while (!haveUsableSlots(restrictions, workers, nextPrimaryInd)) {
                    nextPrimaryInd = (nextPrimaryInd + 1) % workerCount;
                    if (nextPrimaryInd == primaryIndex) {
                        // Wrapped around: no worker has slots left. The current id stays in the
                        // list because iter.remove() is never reached.
                        break outer;
                    }
                }
                storageInfo = getStorageInfo(workers, nextPrimaryInd, restrictions, workerDiskIndexForPrimary);
            }
            PartitionLocation primaryPartition = createLocation(partitionId, workers.get(nextPrimaryInd), null, storageInfo, true);

            if (shouldReplicate) {
                int nextReplicaInd = (nextPrimaryInd + 1) % workerCount;
                if (restrictions != null) {
                    while (!haveUsableSlots(restrictions, workers, nextReplicaInd) || !satisfyRackAware(shouldRackAware, workers, nextPrimaryInd, nextReplicaInd)) {
                        nextReplicaInd = (nextReplicaInd + 1) % workerCount;
                        if (nextReplicaInd == nextPrimaryInd) {
                            break outer;
                        }
                    }
                    storageInfo = getStorageInfo(workers, nextReplicaInd, restrictions, workerDiskIndexForReplica);
                } else if (shouldRackAware) {
                    while (!satisfyRackAware(true, workers, nextPrimaryInd, nextReplicaInd)) {
                        nextReplicaInd = (nextReplicaInd + 1) % workerCount;
                        if (nextReplicaInd == nextPrimaryInd) {
                            break outer;
                        }
                    }
                }
                // Without restrictions the replica shares the primary's StorageInfo instance.
                PartitionLocation replicaPartition = createLocation(partitionId, workers.get(nextReplicaInd), primaryPartition, storageInfo, false);
                primaryPartition.setPeer(replicaPartition);
                Tuple2<List<PartitionLocation>, List<PartitionLocation>> replicaLocations = slots.computeIfAbsent(workers.get(nextReplicaInd), v -> new Tuple2<>(new ArrayList<>(), new ArrayList<>()));
                replicaLocations._2.add(replicaPartition);
            }

            Tuple2<List<PartitionLocation>, List<PartitionLocation>> primaryLocations = slots.computeIfAbsent(workers.get(nextPrimaryInd), v -> new Tuple2<>(new ArrayList<>(), new ArrayList<>()));
            primaryLocations._1.add(primaryPartition);
            primaryIndex = (nextPrimaryInd + 1) % workerCount;
            iter.remove();
        }
        return partitionIdList;
    }

    /**
     * Tells whether the worker at {@code index} still has slot budget left.
     *
     * @param restrictions remaining slot budget per worker; must contain the worker at {@code index}
     * @param workers worker list that {@code index} refers to
     * @param index position of the worker in {@code workers}
     * @return {@code true} when the {@code usableSlots} of all the worker's disks add up to more
     *     than zero
     */
    private static boolean haveUsableSlots(Map<WorkerInfo, List<UsableDiskInfo>> restrictions, List<WorkerInfo> workers, int index) {
        long totalSlots = 0L;
        for (UsableDiskInfo disk : restrictions.get(workers.get(index))) {
            totalSlots += disk.usableSlots;
        }
        return totalSlots > 0L;
    }

    /**
     * Checks the rack-awareness constraint for a primary/replica pair.
     *
     * @param shouldRackAware whether the constraint is enforced at all
     * @param workers worker list both indices refer to
     * @param primaryIndex index of the primary's worker
     * @param nextReplicaInd index of the candidate replica worker
     * @return {@code true} when rack awareness is off, or when the two workers report different
     *     network locations
     */
    private static boolean satisfyRackAware(boolean shouldRackAware, List<WorkerInfo> workers, int primaryIndex, int nextReplicaInd) {
        if (!shouldRackAware) {
            return true;
        }
        Object primaryLocation = workers.get(primaryIndex).networkLocation();
        Object replicaLocation = workers.get(nextReplicaInd).networkLocation();
        return !Objects.equals(primaryLocation, replicaLocation);
    }

    /**
     * Computes the share of allocations every disk group receives and stores it in
     * {@code taskAllocationRatio}, then marks the algorithm as initialized.
     *
     * <p>Group {@code i} gets the weight {@code (1 + diskGroupGradient)^(diskGroups - 1 - i)}; the
     * weights are divided by their sum, so the ratios add up to 1.
     *
     * @param diskGroups number of disk groups
     * @param diskGroupGradient growth factor between the weights of adjacent groups
     */
    private static void initLoadAwareAlgorithm(int diskGroups, double diskGroupGradient) {
        double[] ratios = new double[diskGroups];
        taskAllocationRatio = ratios;

        double base = 1.0 + diskGroupGradient;
        double totalAllocations = 0.0;
        for (int group = 0; group < diskGroups; ++group) {
            ratios[group] = Math.pow(base, diskGroups - 1 - group);
            totalAllocations += ratios[group];
        }
        // Normalize the raw weights into ratios.
        for (int group = 0; group < diskGroups; ++group) {
            ratios[group] = ratios[group] / totalAllocations;
        }

        logger.info("load-aware offer slots algorithm init with taskAllocationRatio {}", StringUtils.join(taskAllocationRatio, ','));
        initialized = true;
    }

    /**
     * Sorts the disks by weighted flush/fetch time (ascending) and cuts the sorted list into
     * consecutive groups.
     *
     * <p>Every group holds {@code ceil(diskCount / diskGroupCount)} disks, except the last one which
     * takes whatever is left. Groups that would be empty are not created, so fewer than
     * {@code diskGroupCount} groups may be returned. {@code usableDisks} is sorted in place.
     *
     * @param usableDisks disks to group; reordered by this call
     * @param diskGroupCount maximum number of groups
     * @param flushTimeWeight weight applied to a disk's average flush time
     * @param fetchTimeWeight weight applied to a disk's average fetch time
     * @return the groups in ascending order of weighted time
     */
    private static List<List<DiskInfo>> placeDisksToGroups(List<DiskInfo> usableDisks, int diskGroupCount, double flushTimeWeight, double fetchTimeWeight) {
        usableDisks.sort((left, right) -> {
            double leftCost = (double) left.avgFlushTime() * flushTimeWeight + (double) left.avgFetchTime() * fetchTimeWeight;
            double rightCost = (double) right.avgFlushTime() * flushTimeWeight + (double) right.avgFetchTime() * fetchTimeWeight;
            double delta = leftCost - rightCost;
            if (delta < 0.0) {
                return -1;
            }
            if (delta > 0.0) {
                return 1;
            }
            return 0;
        });

        int diskCount = usableDisks.size();
        int disksPerGroup = (int) Math.ceil((double) diskCount / (double) diskGroupCount);
        List<List<DiskInfo>> diskGroups = new ArrayList<>();
        int from = 0;
        for (int group = 0; group < diskGroupCount && from < diskCount; ++group) {
            // The last group may be shorter than disksPerGroup.
            int to = Math.min(from + disksPerGroup, diskCount);
            diskGroups.add(new ArrayList<>(usableDisks.subList(from, to)));
            from = to;
        }
        return diskGroups;
    }

    /**
     * Converts the ordered disk groups into a per-worker slot budget for {@code partitionCnt}
     * allocations.
     *
     * <p>Works in two phases: first the requested count is split across the groups according to
     * {@code taskAllocationRatio} (re-normalized over the non-empty groups) and capped by each
     * group's available slots; then each group's share is spread over the disks of that group,
     * capped by each disk's available slots.
     *
     * <p>Reads the static {@code taskAllocationRatio}, which must already be initialized and have at
     * least {@code groups.size()} entries.
     *
     * @param groups disk groups, in the order the allocation ratios apply to them
     * @param diskWorkerMap owning worker of every disk in {@code groups}
     * @param partitionCnt total number of slots to distribute
     * @return worker to the budgets of its disks; a disk that the distribution loop never reaches
     *     has no entry, and a worker without any such disk has no key
     * @throws ArithmeticException from {@code Math.toIntExact} if a single disk's budget does not
     *     fit in an {@code int}
     */
    private static Map<WorkerInfo, List<UsableDiskInfo>> getRestriction(List<List<DiskInfo>> groups, Map<DiskInfo, WorkerInfo> diskWorkerMap, int partitionCnt) {
        int i;
        int groupSize = groups.size();
        long[] groupAllocations = new long[groupSize];
        HashMap<WorkerInfo, List<UsableDiskInfo>> restrictions = new HashMap<WorkerInfo, List<UsableDiskInfo>>();
        // Total available slots of every group.
        long[] groupAvailableSlots = new long[groupSize];
        for (int i2 = 0; i2 < groupSize; ++i2) {
            for (DiskInfo disk : groups.get(i2)) {
                int n = i2;
                groupAvailableSlots[n] = groupAvailableSlots[n] + disk.availableSlots();
            }
        }
        // Re-normalize the configured ratios over the groups that actually contain disks.
        double[] currentAllocation = new double[groupSize];
        double currentAllocationSum = 0.0;
        for (i = 0; i < groupSize; ++i) {
            if (groups.get(i).isEmpty()) continue;
            currentAllocationSum += taskAllocationRatio[i];
        }
        for (i = 0; i < groupSize; ++i) {
            if (groups.get(i).isEmpty()) continue;
            currentAllocation[i] = taskAllocationRatio[i] / currentAllocationSum;
        }
        // Phase 1: decide how many slots each group gets. toNextGroup carries the part of a
        // group's estimate that exceeded its available slots; left is what remains to hand out.
        // NOTE(review): toNextGroup is not reset after a group absorbs it, and it is added on top
        // of an estimate that was already capped at left - confirm this is intended.
        long toNextGroup = 0L;
        long left = partitionCnt;
        for (int i3 = 0; i3 < groupSize && left > 0L; left -= groupAllocations[i3], ++i3) {
            long estimateAllocation = (int)Math.ceil((double)partitionCnt * currentAllocation[i3]);
            if (estimateAllocation > left) {
                estimateAllocation = left;
            }
            if (estimateAllocation + toNextGroup > groupAvailableSlots[i3]) {
                // Group is saturated: give it everything it has and carry the overflow forward.
                groupAllocations[i3] = groupAvailableSlots[i3];
                toNextGroup = estimateAllocation - groupAvailableSlots[i3] + toNextGroup;
                continue;
            }
            groupAllocations[i3] = estimateAllocation + toNextGroup;
        }
        // Phase 2: spread each group's share over its disks. groupLeft carries whatever a group
        // could not place on its own disks into the next group.
        long groupLeft = 0L;
        for (int i4 = 0; i4 < groups.size(); ++i4) {
            int disksInsideGroup = groups.get(i4).size();
            long groupRequired = groupAllocations[i4] + groupLeft;
            for (DiskInfo disk : groups.get(i4)) {
                if (groupRequired <= 0L) break;
                List diskAllocation = restrictions.computeIfAbsent(diskWorkerMap.get(disk), v -> new ArrayList());
                // Even share per disk, capped by the disk's available slots and by what the
                // group still needs.
                long allocated = (int)Math.ceil((double)(groupAllocations[i4] + groupLeft) / (double)disksInsideGroup);
                if (allocated > disk.availableSlots()) {
                    allocated = disk.availableSlots();
                }
                if (allocated > groupRequired) {
                    allocated = groupRequired;
                }
                diskAllocation.add(new UsableDiskInfo(disk, Math.toIntExact(allocated)));
                groupRequired -= allocated;
            }
            groupLeft = groupRequired;
        }
        // Debug output: for every group, each disk's worker host, mount point, flush/fetch time
        // and the budget it received (0 when the disk has no entry in restrictions).
        if (logger.isDebugEnabled()) {
            StringBuilder sb = new StringBuilder();
            for (int i5 = 0; i5 < groups.size(); ++i5) {
                sb.append("| group ").append(i5).append(" ");
                for (DiskInfo diskInfo : groups.get(i5)) {
                    WorkerInfo workerInfo = diskWorkerMap.get(diskInfo);
                    String workerHost = workerInfo.host();
                    long allocation = 0L;
                    if (restrictions.get(workerInfo) != null) {
                        for (UsableDiskInfo usableInfo : (List)restrictions.get(workerInfo)) {
                            if (!usableInfo.diskInfo.equals(diskInfo)) continue;
                            allocation = usableInfo.usableSlots;
                        }
                    }
                    sb.append(workerHost).append("-").append(diskInfo.mountPoint()).append(" flushtime:").append(diskInfo.avgFlushTime()).append(" fetchtime:").append(diskInfo.avgFetchTime()).append(" allocation: ").append(allocation).append(" ");
                }
                sb.append(" | ");
            }
            logger.debug("total {} allocate with group {} with allocations {}", new Object[]{partitionCnt, StringUtils.join((long[])groupAllocations, (char)','), sb});
        }
        return restrictions;
    }

    /**
     * Builds a {@link PartitionLocation} for a partition on the given worker.
     *
     * <p>Host and ports are taken from {@code workerInfo}; the location is created with {@code 0} as
     * its second constructor argument and a fresh, empty {@link RoaringBitmap}.
     *
     * @param partitionIndex id of the partition
     * @param workerInfo worker hosting the location
     * @param peer location of the other copy, or {@code null}
     * @param storageInfo storage the location is bound to
     * @param isPrimary {@code true} for mode {@code PRIMARY}, {@code false} for {@code REPLICA}
     * @return the new location
     */
    private static PartitionLocation createLocation(int partitionIndex, WorkerInfo workerInfo, PartitionLocation peer, StorageInfo storageInfo, boolean isPrimary) {
        PartitionLocation.Mode mode;
        if (isPrimary) {
            mode = PartitionLocation.Mode.PRIMARY;
        } else {
            mode = PartitionLocation.Mode.REPLICA;
        }
        return new PartitionLocation(
            partitionIndex,
            0,
            workerInfo.host(),
            workerInfo.rpcPort(),
            workerInfo.pushPort(),
            workerInfo.fetchPort(),
            workerInfo.replicatePort(),
            mode,
            peer,
            storageInfo,
            new RoaringBitmap());
    }

    /**
     * Counts, for every worker, how many of its locations (primary and replica together) are bound
     * to each mount point.
     *
     * @param slots worker to (primary locations, replica locations)
     * @return worker to (mount point to number of locations); every worker in {@code slots} has an
     *     entry, possibly with an empty inner map
     */
    public static Map<WorkerInfo, Map<String, Integer>> slotsToDiskAllocations(Map<WorkerInfo, Tuple2<List<PartitionLocation>, List<PartitionLocation>>> slots) {
        Map<WorkerInfo, Map<String, Integer>> workerToSlots = new HashMap<>();
        for (Map.Entry<WorkerInfo, Tuple2<List<PartitionLocation>, List<PartitionLocation>>> entry : slots.entrySet()) {
            Map<String, Integer> slotsPerDisk = workerToSlots.computeIfAbsent(entry.getKey(), v -> new HashMap<>());
            Tuple2<List<PartitionLocation>, List<PartitionLocation>> locations = entry.getValue();

            List<PartitionLocation> jointLocations = new ArrayList<>(locations._1);
            jointLocations.addAll(locations._2);
            for (PartitionLocation location : jointLocations) {
                // One more location on this mount point.
                slotsPerDisk.merge(location.getStorageInfo().getMountPoint(), 1, Integer::sum);
            }
        }
        return workerToSlots;
    }

    /**
     * Mutable pairing of a disk with the number of slots that may still be allocated on it during
     * one allocation run.
     */
    static class UsableDiskInfo {
        /** The disk this budget belongs to. */
        DiskInfo diskInfo;
        /** Slots still available for allocation; decremented as slots are handed out. */
        long usableSlots;

        /**
         * @param disk the disk the budget belongs to
         * @param slots initial number of allocatable slots
         */
        UsableDiskInfo(DiskInfo disk, long slots) {
            diskInfo = disk;
            usableSlots = slots;
        }
    }
}

