/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hertzbeat.collector.collect.websocket;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketTimeoutException;
import java.net.UnknownHostException;
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.commons.lang3.StringUtils;
import org.apache.hertzbeat.collector.collect.AbstractCollect;
import org.apache.hertzbeat.common.entity.job.Metrics;
import org.apache.hertzbeat.common.entity.job.protocol.WebsocketProtocol;
import org.apache.hertzbeat.common.entity.message.CollectRep;
import org.apache.hertzbeat.common.util.CommonUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.Assert;

/**
 * Collects availability metrics from a WebSocket endpoint by performing the client opening
 * handshake over a plain TCP socket (no TLS) and recording the response's status line, header
 * fields and the time taken to establish the connection.
 *
 * <p>Only the HTTP handshake is exchanged; no WebSocket frames are sent or read.
 */
public class WebsocketCollectImpl extends AbstractCollect {
    private static final Logger log = LoggerFactory.getLogger(WebsocketCollectImpl.class);

    /**
     * Upper bound, in milliseconds, for establishing the TCP connection and for each blocking read
     * of the handshake response, so an unresponsive peer cannot stall the collector thread forever.
     * NOTE(review): not configurable via {@link WebsocketProtocol}; wire it through if the protocol
     * gains a timeout field.
     */
    private static final int DEFAULT_TIMEOUT_MILLIS = 6000;

    /** Placeholder emitted for a requested field that the handshake response did not provide. */
    private static final String NULL_VALUE = "&nbsp;";

    /** Shared nonce generator; {@link SecureRandom} instances are safe for concurrent use. */
    private static final SecureRandom SECURE_RANDOM = new SecureRandom();

    /**
     * Validates that the metrics definition carries WebSocket parameters.
     *
     * @throws IllegalArgumentException if {@code metrics} or its websocket protocol is null
     */
    @Override
    public void preCheck(Metrics metrics) throws IllegalArgumentException {
        if (metrics == null || metrics.getWebsocket() == null) {
            throw new IllegalArgumentException("Websocket collect must has Websocket params");
        }
    }

    /**
     * Performs a WebSocket opening handshake against the configured endpoint and adds one row to
     * {@code builder}. For each alias field, the row holds the matching value parsed from the
     * response, or {@code "responseTime"}: the milliseconds from the start of this call until the
     * TCP connection was established. Fields the response lacks are emitted as {@code "&nbsp;"}.
     *
     * <p>A blank path defaults to {@code "/"}; note this is written back into the protocol object.
     * Network failures are reported on {@code builder} as {@code UN_CONNECTABLE} instead of being
     * thrown. Invalid parameters (blank host/port, non-numeric or out-of-range port) still surface
     * as {@link IllegalArgumentException}.
     */
    @Override
    public void collect(CollectRep.MetricsData.Builder builder, long monitorId, String app, Metrics metrics) {
        long startTime = System.currentTimeMillis();
        WebsocketProtocol websocketProtocol = metrics.getWebsocket();
        if (StringUtils.isBlank(websocketProtocol.getPath())) {
            websocketProtocol.setPath("/");
        }
        this.checkParam(websocketProtocol);

        Map<String, String> resultMap;
        try {
            resultMap = this.handshake(websocketProtocol, startTime);
        } catch (UnknownHostException unknownHostException) {
            this.markUnconnectable(builder, "UnknownHost:", unknownHostException);
            return;
        } catch (SocketTimeoutException socketTimeoutException) {
            // Raised by either the connect timeout or the read timeout set in handshake().
            this.markUnconnectable(builder, "Socket connect timeout: ", socketTimeoutException);
            return;
        } catch (IOException ioException) {
            this.markUnconnectable(builder, "Connect may fail:", ioException);
            return;
        }

        CollectRep.ValueRow.Builder valueRowBuilder = CollectRep.ValueRow.newBuilder();
        for (String field : metrics.getAliasFields()) {
            valueRowBuilder.addColumns(Objects.requireNonNullElse(resultMap.get(field), NULL_VALUE));
        }
        builder.addValues(valueRowBuilder.build());
    }

    @Override
    public String supportProtocol() {
        return "websocket";
    }

    /**
     * Connects to the protocol's host/port, sends the handshake request and parses the response
     * head. The socket is closed before this method returns, on success or failure; closing a
     * socket also closes its input and output streams.
     *
     * @param websocketProtocol endpoint to probe; host, port and path must already be validated
     * @param startTime         {@link System#currentTimeMillis()} at the start of the collection
     * @return the fields produced by {@link #readHeaders} plus {@code "responseTime"}
     * @throws IOException if connecting, writing, reading or closing fails, including timeouts
     */
    private Map<String, String> handshake(WebsocketProtocol websocketProtocol, long startTime) throws IOException {
        try (Socket socket = new Socket()) {
            InetSocketAddress socketAddress =
                    new InetSocketAddress(websocketProtocol.getHost(), Integer.parseInt(websocketProtocol.getPort()));
            socket.connect(socketAddress, DEFAULT_TIMEOUT_MILLIS);
            // connect() returns only once the connection is established (it throws otherwise),
            // so this measures TCP connection setup alone.
            long responseTime = System.currentTimeMillis() - startTime;
            socket.setSoTimeout(DEFAULT_TIMEOUT_MILLIS);
            this.send(socket.getOutputStream(), websocketProtocol);
            Map<String, String> resultMap = this.readHeaders(socket.getInputStream());
            resultMap.put("responseTime", Long.toString(responseTime));
            return resultMap;
        }
    }

    /**
     * Logs {@code cause} and marks the collection as unconnectable. The builder message is
     * {@code msgPrefix} followed by the error text.
     */
    private void markUnconnectable(CollectRep.MetricsData.Builder builder, String msgPrefix, IOException cause) {
        String errorMsg = CommonUtil.getMessageFromThrowable(cause);
        log.info(errorMsg);
        builder.setCode(CollectRep.Code.UN_CONNECTABLE);
        builder.setMsg(msgPrefix + errorMsg);
    }

    /**
     * Writes the client opening-handshake request (GET with Upgrade headers and a fresh random
     * {@code Sec-WebSocket-Key}) to {@code out} in a single write, then flushes.
     */
    private void send(OutputStream out, WebsocketProtocol websocketProtocol) throws IOException {
        String base64Key = this.base64Encode(this.generateRandomKey());
        String request = "GET " + websocketProtocol.getPath() + " HTTP/1.1\r\n"
                // The Host header must name the target server, not this collector.
                + "Host: " + hostHeaderValue(websocketProtocol.getHost(), websocketProtocol.getPort()) + "\r\n"
                + "Upgrade: websocket\r\n"
                + "Connection: Upgrade\r\n"
                + "Sec-WebSocket-Version: 13\r\n"
                // NOTE(review): "chat, superchat" are the sub-protocol names from RFC 6455's example,
                // not registered extensions; kept for compatibility with existing behavior.
                + "Sec-WebSocket-Extensions: chat, superchat\r\n"
                + "Sec-WebSocket-Key: " + base64Key + "\r\n"
                + "Content-Length: 0\r\n"
                + "\r\n";
        out.write(request.getBytes(StandardCharsets.UTF_8));
        out.flush();
    }

    /**
     * Builds the {@code Host} header value {@code host:port}. IPv6 literals are enclosed in
     * brackets, as authority syntax requires.
     */
    private static String hostHeaderValue(String host, String port) {
        String authorityHost = host.indexOf(':') >= 0 && !host.startsWith("[") ? "[" + host + "]" : host;
        return authorityHost + ":" + port;
    }

    /**
     * Reads the response head, up to the first empty line or end of stream.
     *
     * <p>If the first line starts with {@code "HTTP/"}, it is parsed positionally as the status
     * line into {@code httpVersion}, {@code responseCode} and {@code statusMessage}. Every
     * following {@code name: value} line is stored with its name uncapitalized (e.g.
     * {@code "Upgrade"} becomes {@code "upgrade"}) and its value trimmed. Lines without a colon
     * are ignored.
     *
     * <p>The reader is intentionally not closed here: the caller owns the socket and closes it.
     *
     * @return a mutable map of the parsed fields
     */
    private Map<String, String> readHeaders(InputStream in) throws IOException {
        Map<String, String> map = new HashMap<>(8);
        BufferedReader reader = new BufferedReader(new InputStreamReader(in, StandardCharsets.UTF_8));
        boolean firstLine = true;
        String line;
        while ((line = reader.readLine()) != null && !line.isEmpty()) {
            if (firstLine) {
                firstLine = false;
                if (line.startsWith("HTTP/")) {
                    parseStatusLine(line, map);
                    continue;
                }
            }
            int separatorIndex = line.indexOf(':');
            if (separatorIndex != -1) {
                String key = line.substring(0, separatorIndex).trim();
                String value = line.substring(separatorIndex + 1).trim();
                map.put(StringUtils.uncapitalize(key), value);
            }
        }
        return map;
    }

    /**
     * Splits a status line of the form {@code HTTP-version SP status-code SP reason-phrase}
     * positionally. The reason phrase may contain spaces, and may be missing entirely.
     */
    private static void parseStatusLine(String statusLine, Map<String, String> map) {
        String[] parts = statusLine.split("\\s+", 3);
        map.put("httpVersion", parts[0]);
        if (parts.length > 1) {
            map.put("responseCode", parts[1]);
        }
        if (parts.length > 2) {
            map.put("statusMessage", parts[2]);
        }
    }

    /** Returns 16 fresh random bytes for the {@code Sec-WebSocket-Key} nonce. */
    private byte[] generateRandomKey() {
        byte[] key = new byte[16];
        SECURE_RANDOM.nextBytes(key);
        return key;
    }

    /**
     * Asserts that host, port and path are non-blank.
     *
     * @throws IllegalArgumentException if any of them is blank
     */
    private void checkParam(WebsocketProtocol protocol) {
        Assert.hasText(protocol.getHost(), "Websocket Protocol host is required.");
        Assert.hasText(protocol.getPort(), "Websocket Protocol port is required.");
        Assert.hasText(protocol.getPath(), "Websocket Protocol path is required.");
    }

    /** Encodes {@code data} with the standard (non-URL-safe, padded) Base64 alphabet. */
    private String base64Encode(byte[] data) {
        return Base64.getEncoder().encodeToString(data);
    }
}

