/*
 * Decompiled with CFR 0.152.
 */
package org.apache.seatunnel.engine.server.task;

import com.hazelcast.cluster.Address;
import com.hazelcast.logging.ILogger;
import com.hazelcast.logging.Logger;
import com.hazelcast.spi.impl.AbstractInvocationFuture;
import com.hazelcast.spi.impl.operationservice.Operation;
import java.io.IOException;
import java.io.Serializable;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.function.Function;
import java.util.stream.Collectors;
import lombok.NonNull;
import org.apache.seatunnel.api.serialization.Serializer;
import org.apache.seatunnel.api.source.SourceSplit;
import org.apache.seatunnel.api.source.SourceSplitEnumerator;
import org.apache.seatunnel.engine.common.utils.ExceptionUtil;
import org.apache.seatunnel.engine.core.dag.actions.SourceAction;
import org.apache.seatunnel.engine.server.checkpoint.ActionSubtaskState;
import org.apache.seatunnel.engine.server.checkpoint.CheckpointBarrier;
import org.apache.seatunnel.engine.server.checkpoint.operation.TaskAcknowledgeOperation;
import org.apache.seatunnel.engine.server.execution.ProgressState;
import org.apache.seatunnel.engine.server.execution.TaskLocation;
import org.apache.seatunnel.engine.server.task.CoordinatorTask;
import org.apache.seatunnel.engine.server.task.context.SeaTunnelSplitEnumeratorContext;
import org.apache.seatunnel.engine.server.task.operation.checkpoint.BarrierFlowOperation;
import org.apache.seatunnel.engine.server.task.record.Barrier;
import org.apache.seatunnel.engine.server.task.statemachine.SeaTunnelTaskState;

/**
 * Coordinator-side task that hosts a source's {@link SourceSplitEnumerator}.
 *
 * <p>The task records which readers have registered and on which cluster member each one runs.
 * It drives the enumerator through a small state machine (see {@link #stateProcess()}), forwards
 * checkpoint barriers to every registered reader, and acknowledges enumerator snapshots to the
 * master.
 *
 * @param <SplitT> the split type produced by the enumerator
 */
public class SourceSplitEnumeratorTask<SplitT extends SourceSplit> extends CoordinatorTask {
    private static final ILogger LOGGER = Logger.getLogger(SourceSplitEnumeratorTask.class);
    private static final long serialVersionUID = -3713701594297977775L;

    /** The source whose enumerator this task runs. */
    private final SourceAction<?, SplitT, Serializable> source;
    /** Created or restored in {@link #restoreState(List)}; null before that. */
    private SourceSplitEnumerator<SplitT, Serializable> enumerator;
    /** Context handed to the enumerator; also used as the lock around barrier handling. */
    private SeaTunnelSplitEnumeratorContext<SplitT> enumeratorContext;
    /** Serializes/deserializes the enumerator checkpoint state. */
    private Serializer<Serializable> enumeratorStateSerializer;
    /** Number of readers expected to register; set to the source parallelism in {@link #init()}. */
    private int maxReaderSize;
    /** Task IDs of registered readers that have not yet reported completion. */
    private Set<Long> unfinishedReaders;
    /** Registered reader location -> address of the member hosting it. */
    private Map<TaskLocation, Address> taskMemberMapping;
    /** Reader task ID -> reader location. */
    private Map<Long, TaskLocation> taskIDToTaskLocationMapping;
    /** Reader task index -> reader location. */
    private Map<Integer, TaskLocation> taskIndexToTaskLocationMapping;
    /** Current state of the enumerator state machine. */
    private volatile SeaTunnelTaskState currState;
    /** Becomes true once {@link #maxReaderSize} distinct readers have registered. */
    private volatile boolean readerRegisterComplete;

    /**
     * Creates the task in state {@code CREATED}.
     *
     * @param jobID  the owning job
     * @param taskID location of this coordinator task
     * @param source the source action whose enumerator this task runs
     */
    public SourceSplitEnumeratorTask(long jobID, TaskLocation taskID, SourceAction<?, SplitT, ?> source) {
        super(jobID, taskID);
        // The enumerator state type is only known as a wildcard here, while the task handles it as
        // Serializable everywhere (serializer, snapshot, restore), so narrow it once.
        @SuppressWarnings("unchecked")
        SourceAction<?, SplitT, Serializable> typedSource = (SourceAction<?, SplitT, Serializable>) source;
        this.source = typedSource;
        this.currState = SeaTunnelTaskState.CREATED;
    }

    /**
     * Moves to {@code INIT} and allocates the enumerator context, the state serializer and the
     * reader bookkeeping maps. The enumerator itself is created later in {@link #restoreState(List)}.
     */
    @Override
    public void init() throws Exception {
        this.currState = SeaTunnelTaskState.INIT;
        super.init();
        this.readerRegisterComplete = false;
        LOGGER.info("starting seatunnel source split enumerator task, source name: " + this.source.getName());
        this.enumeratorContext = new SeaTunnelSplitEnumeratorContext<>(this.source.getParallelism(), this);
        this.enumeratorStateSerializer = this.source.getSource().getEnumeratorStateSerializer();
        this.taskMemberMapping = new ConcurrentHashMap<>();
        this.taskIDToTaskLocationMapping = new ConcurrentHashMap<>();
        this.taskIndexToTaskLocationMapping = new ConcurrentHashMap<>();
        this.maxReaderSize = this.source.getParallelism();
        this.unfinishedReaders = new CopyOnWriteArraySet<>();
    }

    /**
     * Closes the enumerator (if one was created) and marks the task's progress as done.
     *
     * @throws IOException if closing the enumerator fails; progress is then not marked done
     */
    @Override
    public void close() throws IOException {
        if (this.enumerator != null) {
            this.enumerator.close();
        }
        this.progress.done();
    }

    /**
     * Advances the state machine by one step and reports the resulting progress.
     */
    @Override
    @NonNull
    public ProgressState call() throws Exception {
        this.stateProcess();
        return this.progress.toState();
    }

    /**
     * Handles a checkpoint barrier.
     *
     * <p>A prepare-close barrier moves the task to {@code PREPARE_CLOSE} and remembers its ID.
     * If the barrier requests a snapshot, the enumerator state is captured. The barrier is then
     * forwarded to every registered reader; the call blocks until all sends complete. Finally,
     * for snapshot barriers, the serialized state is acknowledged to the master.
     */
    @Override
    public void triggerBarrier(Barrier barrier) throws Exception {
        if (barrier.prepareClose()) {
            this.currState = SeaTunnelTaskState.PREPARE_CLOSE;
            this.prepareCloseBarrierId.set(barrier.getId());
        }
        long barrierId = barrier.getId();
        Serializable snapshotState = null;
        // Snapshot and barrier fan-out happen under the context's monitor.
        // NOTE(review): presumably other code that mutates enumerator state (e.g. split assignment
        // via the context) locks the same object — confirm in SeaTunnelSplitEnumeratorContext.
        synchronized (this.enumeratorContext) {
            if (barrier.snapshot()) {
                snapshotState = this.enumerator.snapshotState(barrierId);
            }
            this.sendToAllReader(location -> new BarrierFlowOperation(barrier, location));
        }
        if (barrier.snapshot()) {
            byte[] serialize = this.enumeratorStateSerializer.serialize(snapshotState);
            this.getExecutionContext().sendToMaster(new TaskAcknowledgeOperation(
                    this.taskLocation,
                    (CheckpointBarrier) barrier,
                    Collections.singletonList(new ActionSubtaskState(
                            this.source.getId(), -1, Collections.singletonList(serialize)))));
        }
    }

    /**
     * Creates the enumerator, restoring it from the first state entry found in
     * {@code actionStateList} or creating a fresh one when there is none, then flags restore as
     * complete.
     */
    @Override
    public void restoreState(List<ActionSubtaskState> actionStateList) throws Exception {
        Optional<Serializable> state = actionStateList.stream()
                .map(ActionSubtaskState::getState)
                .flatMap(Collection::stream)
                .map(bytes -> ExceptionUtil.sneaky(() -> this.enumeratorStateSerializer.deserialize((byte[]) bytes)))
                .findFirst();
        if (state.isPresent()) {
            this.enumerator = this.source.getSource().restoreEnumerator(this.enumeratorContext, state.get());
        } else {
            this.enumerator = this.source.getSource().createEnumerator(this.enumeratorContext);
        }
        this.restoreComplete = true;
    }

    /** Returns splits to the enumerator on behalf of reader {@code subtaskId}. */
    public void addSplitsBack(List<SplitT> splits, int subtaskId) {
        this.enumerator.addSplitsBack(splits, subtaskId);
    }

    /**
     * Registers a reader and the member hosting it, and informs the enumerator using the reader's
     * task index. Once {@link #maxReaderSize} readers are known, the enumerator may be started.
     */
    public void receivedReader(TaskLocation readerId, Address memberAddr) {
        LOGGER.info("received reader register, readerID: " + readerId);
        this.addTaskMemberMapping(readerId, memberAddr);
        this.enumerator.registerReader(readerId.getTaskIndex());
        if (this.maxReaderSize == this.taskMemberMapping.size()) {
            this.readerRegisterComplete = true;
        }
    }

    /**
     * Forwards a split request to the enumerator.
     *
     * <p>NOTE(review): the value is narrowed to {@code int} and passed as the enumerator's subtask
     * id, so callers presumably pass the reader's task index rather than its task ID — confirm.
     */
    public void requestSplit(long taskID) {
        this.enumerator.handleSplitRequest((int) taskID);
    }

    /** Records a reader's location under its task ID and task index and marks it unfinished. */
    public void addTaskMemberMapping(TaskLocation taskID, Address memberAdder) {
        this.taskMemberMapping.put(taskID, memberAdder);
        this.taskIDToTaskLocationMapping.put(taskID.getTaskID(), taskID);
        this.taskIndexToTaskLocationMapping.put(taskID.getTaskIndex(), taskID);
        this.unfinishedReaders.add(taskID.getTaskID());
    }

    /**
     * Returns the member address of the reader with the given task ID, or {@code null} if no such
     * reader has registered.
     */
    public Address getTaskMemberAddress(long taskID) {
        TaskLocation location = this.taskIDToTaskLocationMapping.get(taskID);
        // ConcurrentHashMap.get(null) throws NPE, so an unknown reader must be handled here.
        return location == null ? null : this.taskMemberMapping.get(location);
    }

    /** Returns the location of the reader with the given task ID, or {@code null} if unknown. */
    public TaskLocation getTaskMemberLocation(long taskID) {
        return this.taskIDToTaskLocationMapping.get(taskID);
    }

    /**
     * Returns the member address of the reader with the given task index, or {@code null} if no
     * such reader has registered.
     */
    public Address getTaskMemberAddressByIndex(int taskIndex) {
        TaskLocation location = this.taskIndexToTaskLocationMapping.get(taskIndex);
        // ConcurrentHashMap.get(null) throws NPE, so an unknown reader must be handled here.
        return location == null ? null : this.taskMemberMapping.get(location);
    }

    /** Returns the location of the reader with the given task index, or {@code null} if unknown. */
    public TaskLocation getTaskMemberLocationByIndex(int taskIndex) {
        return this.taskIndexToTaskLocationMapping.get(taskIndex);
    }

    /**
     * Marks a reader as finished; when no unfinished readers remain, sets
     * {@code prepareCloseStatus} so the state machine starts closing.
     */
    public void readerFinished(long taskID) {
        this.unfinishedReaders.remove(taskID);
        if (this.unfinishedReaders.isEmpty()) {
            this.prepareCloseStatus = true;
        }
    }

    /**
     * Performs one transition of the enumerator state machine:
     * INIT -> WAITING_RESTORE -> READY_START (after restore) -> STARTING (after start is called
     * and all readers registered; opens the enumerator) -> RUNNING (runs the enumerator) ->
     * PREPARE_CLOSE (after all readers finish; sends a completed barrier) -> CLOSED (after close
     * is called) -> close. CANCELLING closes and moves to CANCELED.
     *
     * @throws IllegalArgumentException for any state not handled here
     */
    private void stateProcess() throws Exception {
        switch (this.currState) {
            case INIT: {
                this.currState = SeaTunnelTaskState.WAITING_RESTORE;
                this.reportTaskStatus(SeaTunnelTaskState.WAITING_RESTORE);
                break;
            }
            case WAITING_RESTORE: {
                if (!this.restoreComplete) break;
                this.currState = SeaTunnelTaskState.READY_START;
                this.reportTaskStatus(SeaTunnelTaskState.READY_START);
                break;
            }
            case READY_START: {
                if (!this.startCalled || !this.readerRegisterComplete) break;
                this.currState = SeaTunnelTaskState.STARTING;
                this.enumerator.open();
                break;
            }
            case STARTING: {
                this.currState = SeaTunnelTaskState.RUNNING;
                LOGGER.info("received enough reader, starting enumerator...");
                this.enumerator.run();
                break;
            }
            case RUNNING: {
                if (!this.prepareCloseStatus) break;
                this.triggerBarrier(Barrier.completedBarrier());
                this.currState = SeaTunnelTaskState.PREPARE_CLOSE;
                break;
            }
            case PREPARE_CLOSE: {
                if (!this.closeCalled) break;
                this.currState = SeaTunnelTaskState.CLOSED;
                break;
            }
            case CLOSED: {
                this.close();
                return;
            }
            case CANCELLING: {
                this.close();
                this.currState = SeaTunnelTaskState.CANCELED;
                return;
            }
            default: {
                throw new IllegalArgumentException("Unknown Enumerator State: " + this.currState);
            }
        }
    }

    /** Returns the task indexes of all registered readers. */
    public Set<Integer> getRegisteredReaders() {
        return this.taskMemberMapping.keySet().stream().map(TaskLocation::getTaskIndex).collect(Collectors.toSet());
    }

    /**
     * Sends the operation built by {@code function} to every registered reader's member, then
     * waits for all invocations to complete.
     */
    private void sendToAllReader(Function<TaskLocation, Operation> function) {
        List<AbstractInvocationFuture<?>> futures = new ArrayList<>();
        // Fire all sends first so they proceed concurrently, then join them.
        this.taskMemberMapping.forEach((location, address) ->
                futures.add(this.getExecutionContext().sendToMember(function.apply(location), address)));
        futures.forEach(AbstractInvocationFuture::join);
    }

    /** Returns a fresh mutable copy of the source's jar URLs. */
    @Override
    public Set<URL> getJarsUrl() {
        return new HashSet<>(this.source.getJarUrls());
    }

    /**
     * Notifies the enumerator of a completed checkpoint; if it is the prepare-close barrier while
     * in {@code PREPARE_CLOSE}, requests close.
     */
    @Override
    public void notifyCheckpointComplete(long checkpointId) throws Exception {
        this.enumerator.notifyCheckpointComplete(checkpointId);
        if (this.currState == SeaTunnelTaskState.PREPARE_CLOSE && this.prepareCloseBarrierId.get() == checkpointId) {
            this.closeCall();
        }
    }

    /**
     * Notifies the enumerator of an aborted checkpoint; if it is the prepare-close barrier while
     * in {@code PREPARE_CLOSE}, requests close anyway.
     */
    @Override
    public void notifyCheckpointAborted(long checkpointId) throws Exception {
        this.enumerator.notifyCheckpointAborted(checkpointId);
        if (this.currState == SeaTunnelTaskState.PREPARE_CLOSE && this.prepareCloseBarrierId.get() == checkpointId) {
            this.closeCall();
        }
    }
}

