/*
 * Decompiled with CFR 0.152.
 */
package org.apache.celeborn.service.deploy.worker.congestcontrol;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.celeborn.common.identity.UserIdentifier;
import org.apache.celeborn.common.util.JavaUtils;
import org.apache.celeborn.common.util.ThreadUtils;
import org.apache.celeborn.service.deploy.worker.WorkerSource;
import org.apache.celeborn.service.deploy.worker.congestcontrol.BufferStatusHub;
import org.apache.celeborn.service.deploy.worker.memory.MemoryManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Worker-side congestion controller that decides whether a user's pushes should be throttled.
 *
 * <p>It samples the bytes each user produces and the bytes the worker consumes, each through a
 * {@link BufferStatusHub} built with {@code sampleTimeWindowSeconds}. When the memory usage
 * reported by {@link MemoryManager} rises above {@code highWatermark}, congestion control switches
 * on and stays on until usage falls below {@code lowWatermark}. While it is on, a user is considered
 * congested when its average produce speed exceeds its share of the worker's consume speed (the
 * consume speed divided by the number of tracked users).
 *
 * <p>Users that have not produced anything for {@code userInactiveTimeMills} are dropped by a
 * periodic background task, together with their per-user produce-speed gauge.
 */
public class CongestionController {
    private static final Logger logger = LoggerFactory.getLogger(CongestionController.class);
    private static volatile CongestionController _INSTANCE = null;
    private final WorkerSource workerSource;
    private final int sampleTimeWindowSeconds;
    private final long highWatermark;
    private final long lowWatermark;
    private final long userInactiveTimeMills;
    /** Set once pending bytes exceed the high watermark; cleared once they drop below the low one. */
    private final AtomicBoolean overHighWatermark = new AtomicBoolean(false);
    /** Samples of bytes consumed by this worker, across all users. */
    private final BufferStatusHub consumedBufferStatusHub;
    /** Per-user produce samples plus the time of that user's last produce call. */
    private final ConcurrentHashMap<UserIdentifier, UserBufferInfo> userBufferStatuses;
    /** Runs {@link #removeInactiveUsers()} on a fixed delay. */
    private final ScheduledExecutorService removeUserExecutorService;

    /**
     * Creates a controller, schedules the inactive-user sweep and registers the worker-level gauges.
     *
     * @param workerSource metrics source that receives the consume-speed and produce-speed gauges
     * @param sampleTimeWindowSeconds window passed to every {@link BufferStatusHub} created here
     * @param highWatermark pending-bytes level above which congestion control switches on
     * @param lowWatermark pending-bytes level below which congestion control switches off again
     * @param userInactiveTimeMills idle time after which a user is dropped; also used as the delay
     *     between two sweeps
     */
    protected CongestionController(WorkerSource workerSource, int sampleTimeWindowSeconds, long highWatermark, long lowWatermark, long userInactiveTimeMills) {
        // Only enforced when the JVM runs with assertions enabled (-ea).
        assert highWatermark > lowWatermark;
        this.workerSource = workerSource;
        this.sampleTimeWindowSeconds = sampleTimeWindowSeconds;
        this.highWatermark = highWatermark;
        this.lowWatermark = lowWatermark;
        this.userInactiveTimeMills = userInactiveTimeMills;
        this.consumedBufferStatusHub = new BufferStatusHub(sampleTimeWindowSeconds);
        this.userBufferStatuses = JavaUtils.newConcurrentHashMap();
        this.removeUserExecutorService = ThreadUtils.newDaemonSingleThreadScheduledExecutor("remove-inactive-user");
        this.removeUserExecutorService.scheduleWithFixedDelay(this::removeInactiveUsers, 0L, userInactiveTimeMills, TimeUnit.MILLISECONDS);
        this.workerSource.addGauge(WorkerSource.POTENTIAL_CONSUME_SPEED(), this::getPotentialConsumeSpeed);
        this.workerSource.addGauge(WorkerSource.WORKER_CONSUME_SPEED(), this.consumedBufferStatusHub::avgBytesPerSec);
    }

    /**
     * Creates a new controller and publishes it as the shared instance.
     *
     * <p>Any previously initialized instance is replaced without being closed; call
     * {@link #destroy()} first to release it.
     *
     * @return the newly created controller
     */
    public static synchronized CongestionController initialize(WorkerSource workSource, int sampleTimeWindowSeconds, long highWatermark, long lowWatermark, long userInactiveTimeMills) {
        _INSTANCE = new CongestionController(workSource, sampleTimeWindowSeconds, highWatermark, lowWatermark, userInactiveTimeMills);
        return _INSTANCE;
    }

    /** Returns the shared instance, or {@code null} if it was never initialized or was destroyed. */
    public static CongestionController instance() {
        return _INSTANCE;
    }

    /**
     * Decides whether pushes from {@code userIdentifier} should currently be throttled.
     *
     * <p>This call may flip the watermark state and, while over the high watermark, asks
     * {@link MemoryManager} to trim memory before re-checking the low watermark.
     *
     * @param userIdentifier the user to check; an unknown user has a produce speed of 0
     * @return {@code true} if congestion control is active and the user's produce speed exceeds its
     *     share of the consume speed
     */
    public boolean isUserCongested(UserIdentifier userIdentifier) {
        if (this.userBufferStatuses.isEmpty()) {
            return false;
        }
        long pendingConsumed = this.getTotalPendingBytes();
        if (pendingConsumed > this.highWatermark && this.overHighWatermark.compareAndSet(false, true)) {
            logger.info("Pending consume bytes: {} higher than high watermark, need to congest it", pendingConsumed);
        }
        if (!this.overHighWatermark.get()) {
            return false;
        }
        this.trimMemoryUsage();
        pendingConsumed = this.getTotalPendingBytes();
        if (pendingConsumed < this.lowWatermark && this.overHighWatermark.compareAndSet(true, false)) {
            logger.info("Lower than low watermark, exit congestion control");
        }
        if (!this.overHighWatermark.get()) {
            return false;
        }
        long userProduceSpeed = this.getUserProduceSpeed(this.userBufferStatuses.get(userIdentifier));
        long avgConsumeSpeed = this.getPotentialConsumeSpeed();
        boolean congested = userProduceSpeed > avgConsumeSpeed;
        if (logger.isDebugEnabled()) {
            logger.debug("The user {}, produceSpeed is {}, while consumeSpeed is {}, need to congest it: {}", new Object[]{userIdentifier, userProduceSpeed, avgConsumeSpeed, congested});
        }
        return congested;
    }

    /**
     * Records {@code numBytes} produced by {@code userIdentifier} and refreshes its last-active time.
     * The first call for a user also registers that user's produce-speed gauge.
     */
    public void produceBytes(UserIdentifier userIdentifier, int numBytes) {
        long currentTimeMillis = System.currentTimeMillis();
        UserBufferInfo userBufferInfo = this.userBufferStatuses.computeIfAbsent(userIdentifier, user -> {
            logger.info("New user {} comes, initializing its rate status", user);
            BufferStatusHub bufferStatusHub = new BufferStatusHub(this.sampleTimeWindowSeconds);
            UserBufferInfo userInfo = new UserBufferInfo(currentTimeMillis, bufferStatusHub);
            this.workerSource.addGauge(WorkerSource.USER_PRODUCE_SPEED(), user.toJMap(), () -> this.getUserProduceSpeed(userInfo));
            return userInfo;
        });
        userBufferInfo.updateInfo(currentTimeMillis, new BufferStatusHub.BufferStatusNode(numBytes));
    }

    /** Records {@code numBytes} consumed by this worker. */
    public void consumeBytes(int numBytes) {
        long currentTimeMillis = System.currentTimeMillis();
        this.consumedBufferStatusHub.add(currentTimeMillis, new BufferStatusHub.BufferStatusNode(numBytes));
    }

    /** Returns the memory usage currently reported by {@link MemoryManager}. */
    public long getTotalPendingBytes() {
        return MemoryManager.instance().getMemoryUsage();
    }

    /** Asks {@link MemoryManager} to trim all of its registered listeners. */
    public void trimMemoryUsage() {
        MemoryManager.instance().trimAllListeners();
    }

    /**
     * Returns the worker's average consume speed divided evenly among the tracked users, or 0 when
     * no user is tracked.
     */
    public long getPotentialConsumeSpeed() {
        // Read the user count exactly once: the sweeper thread may remove the last user between a
        // size check and the division, which would otherwise throw ArithmeticException (/ by zero).
        int trackedUsers = this.userBufferStatuses.size();
        if (trackedUsers == 0) {
            return 0L;
        }
        return this.consumedBufferStatusHub.avgBytesPerSec() / trackedUsers;
    }

    /** Returns the average produce speed for {@code userBufferInfo}, or 0 if it is {@code null}. */
    private long getUserProduceSpeed(UserBufferInfo userBufferInfo) {
        if (userBufferInfo == null) {
            return 0L;
        }
        return userBufferInfo.getBufferStatusHub().avgBytesPerSec();
    }

    /**
     * Drops users idle for at least {@code userInactiveTimeMills} and removes their gauges.
     * Exceptions are logged rather than rethrown: an exception escaping a scheduled task would
     * suppress all of its future runs.
     */
    private void removeInactiveUsers() {
        try {
            long currentTimeMillis = System.currentTimeMillis();
            for (Map.Entry<UserIdentifier, UserBufferInfo> entry : this.userBufferStatuses.entrySet()) {
                UserIdentifier userIdentifier = entry.getKey();
                UserBufferInfo userBufferInfo = entry.getValue();
                if (currentTimeMillis - userBufferInfo.getTimestamp() >= this.userInactiveTimeMills) {
                    this.userBufferStatuses.remove(userIdentifier);
                    // NOTE(review): the gauge is added with toJMap() but removed with toMap();
                    // presumably WorkerSource overloads accept both forms — verify.
                    this.workerSource.removeGauge(WorkerSource.USER_PRODUCE_SPEED(), userIdentifier.toMap());
                    logger.info("User {} has been expired, remove from rate limit list", userIdentifier);
                }
            }
        } catch (Exception e) {
            logger.error("Error occurs when removing inactive users, ", e);
        }
    }

    /** Stops the sweeper and clears all tracked producer and consumer samples. */
    public void close() {
        this.removeUserExecutorService.shutdownNow();
        this.userBufferStatuses.clear();
        this.consumedBufferStatusHub.clear();
    }

    /** Closes and clears the shared instance, if there is one. */
    public static synchronized void destroy() {
        if (_INSTANCE != null) {
            _INSTANCE.close();
            _INSTANCE = null;
        }
    }

    /** Produce samples and the time of the last produce call for a single user. */
    private static final class UserBufferInfo {
        // volatile: written by producer threads in updateInfo() but read unsynchronized by the
        // sweeper thread through getTimestamp().
        volatile long timestamp;
        final BufferStatusHub bufferStatusHub;

        public UserBufferInfo(long timestamp, BufferStatusHub bufferStatusHub) {
            this.timestamp = timestamp;
            this.bufferStatusHub = bufferStatusHub;
        }

        /** Records a produce sample and marks the user as active at {@code timestamp}. */
        synchronized void updateInfo(long timestamp, BufferStatusHub.BufferStatusNode node) {
            this.timestamp = timestamp;
            this.bufferStatusHub.add(timestamp, node);
        }

        public long getTimestamp() {
            return this.timestamp;
        }

        public BufferStatusHub getBufferStatusHub() {
            return this.bufferStatusHub;
        }
    }
}

