/*
 * Decompiled with CFR 0.152.
 */
package com.linecorp.armeria.server.websocket;

import com.linecorp.armeria.common.HttpData;
import com.linecorp.armeria.common.HttpHeaderNames;
import com.linecorp.armeria.common.HttpRequest;
import com.linecorp.armeria.common.HttpResponse;
import com.linecorp.armeria.common.HttpStatus;
import com.linecorp.armeria.common.MediaType;
import com.linecorp.armeria.common.RequestHeaders;
import com.linecorp.armeria.common.ResponseHeaders;
import com.linecorp.armeria.common.ResponseHeadersBuilder;
import com.linecorp.armeria.common.SessionProtocol;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.common.stream.ClosedStreamException;
import com.linecorp.armeria.common.stream.StreamMessage;
import com.linecorp.armeria.common.websocket.WebSocket;
import com.linecorp.armeria.common.websocket.WebSocketFrame;
import com.linecorp.armeria.internal.common.websocket.WebSocketFrameDecoder;
import com.linecorp.armeria.internal.common.websocket.WebSocketFrameEncoder;
import com.linecorp.armeria.internal.common.websocket.WebSocketUtil;
import com.linecorp.armeria.internal.common.websocket.WebSocketWrapper;
import com.linecorp.armeria.internal.shaded.guava.base.Splitter;
import com.linecorp.armeria.internal.shaded.guava.hash.Hashing;
import com.linecorp.armeria.internal.shaded.guava.net.HostAndPort;
import com.linecorp.armeria.server.AbstractHttpService;
import com.linecorp.armeria.server.ServiceRequestContext;
import com.linecorp.armeria.server.websocket.WebSocketServiceBuilder;
import com.linecorp.armeria.server.websocket.WebSocketServiceHandler;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.websocketx.WebSocketVersion;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * An {@link AbstractHttpService} that upgrades incoming requests to WebSocket sessions and hands the
 * decoded inbound frames to a {@link WebSocketServiceHandler}, streaming the handler's outbound frames
 * back to the peer.
 *
 * <p>HTTP/1.1 upgrades arrive via {@link #doGet(ServiceRequestContext, HttpRequest)} and HTTP/2
 * upgrades (a CONNECT request carrying {@code :protocol = websocket}) via
 * {@link #doConnect(ServiceRequestContext, HttpRequest)}. Both paths validate the {@code Origin} and
 * {@code Sec-WebSocket-Version} headers and negotiate a subprotocol before invoking the handler.
 */
public final class WebSocketService extends AbstractHttpService {
    private static final Logger logger = LoggerFactory.getLogger(WebSocketService.class);
    // Fixed GUID appended to Sec-WebSocket-Key when computing Sec-WebSocket-Accept (RFC 6455).
    private static final String WEBSOCKET_13_ACCEPT_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
    // When present in the server's configured subprotocols, any client-offered subprotocol is accepted.
    private static final String SUB_PROTOCOL_WILDCARD = "*";
    private static final ResponseHeaders UNSUPPORTED_WEB_SOCKET_VERSION =
            ResponseHeaders.builder(HttpStatus.BAD_REQUEST)
                           .add(HttpHeaderNames.SEC_WEBSOCKET_VERSION,
                                WebSocketVersion.V13.toHttpHeaderValue())
                           .contentType(MediaType.PLAIN_TEXT_UTF_8)
                           .build();
    private static final Splitter commaSplitter = Splitter.on(',').trimResults().omitEmptyStrings();
    // Server-to-client frames are encoded with WebSocketFrameEncoder.of(false).
    private static final WebSocketFrameEncoder encoder = WebSocketFrameEncoder.of(false);

    private final WebSocketServiceHandler handler;
    private final int maxFramePayloadLength;
    private final boolean allowMaskMismatch;
    private final Set<String> subprotocols;
    private final Set<String> allowedOrigins;
    private final boolean allowAnyOrigin;

    /**
     * Returns a new {@link WebSocketService} with default settings that delegates to the given handler.
     */
    public static WebSocketService of(WebSocketServiceHandler handler) {
        return new WebSocketServiceBuilder(handler).build();
    }

    /**
     * Returns a new {@link WebSocketServiceBuilder} for a service that delegates to the given handler.
     */
    public static WebSocketServiceBuilder builder(WebSocketServiceHandler handler) {
        return new WebSocketServiceBuilder(handler);
    }

    WebSocketService(WebSocketServiceHandler handler, int maxFramePayloadLength, boolean allowMaskMismatch,
                     Set<String> subprotocols, Set<String> allowedOrigins, boolean allowAnyOrigin) {
        this.handler = handler;
        this.maxFramePayloadLength = maxFramePayloadLength;
        this.allowMaskMismatch = allowMaskMismatch;
        this.subprotocols = subprotocols;
        this.allowedOrigins = allowedOrigins;
        this.allowAnyOrigin = allowAnyOrigin;
    }

    /**
     * Handles an HTTP/1.1 WebSocket upgrade request.
     *
     * <p>Responds with {@code 405 Method Not Allowed} for sessions that are not explicitly HTTP/1,
     * {@code 400 Bad Request} when the upgrade headers or {@code Sec-WebSocket-Key} are missing, and the
     * origin/version check response when either check fails. Otherwise replies with
     * {@code 101 Switching Protocols} and starts the WebSocket session.
     */
    @Override
    protected HttpResponse doGet(ServiceRequestContext ctx, HttpRequest req) throws Exception {
        if (!ctx.sessionProtocol().isExplicitHttp1()) {
            return HttpResponse.of(HttpStatus.METHOD_NOT_ALLOWED);
        }
        final RequestHeaders headers = req.headers();
        if (!WebSocketUtil.isHttp1WebSocketUpgradeRequest(headers)) {
            return HttpResponse.of(HttpStatus.BAD_REQUEST, MediaType.PLAIN_TEXT_UTF_8,
                                   "The upgrade header must contain:\n  Upgrade: websocket\n  Connection: Upgrade");
        }
        HttpResponse invalidResponse = checkOrigin(ctx, headers);
        if (invalidResponse != null) {
            return invalidResponse;
        }
        invalidResponse = checkVersion(headers);
        if (invalidResponse != null) {
            return invalidResponse;
        }
        final String webSocketKey = headers.get(HttpHeaderNames.SEC_WEBSOCKET_KEY, "");
        if (webSocketKey.isEmpty()) {
            return HttpResponse.of(HttpStatus.BAD_REQUEST, MediaType.PLAIN_TEXT_UTF_8,
                                   "missing Sec-WebSocket-Key header");
        }
        final String accept = generateSecWebSocketAccept(webSocketKey);
        final ResponseHeadersBuilder responseHeadersBuilder =
                ResponseHeaders.builder(HttpStatus.SWITCHING_PROTOCOLS)
                               .add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET.toString())
                               .add(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE.toString())
                               .add(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT, accept);
        maybeAddSubprotocol(headers, responseHeadersBuilder);
        return handleUpgradeRequest(ctx, req, responseHeadersBuilder.build());
    }

    /**
     * Selects the first client-offered subprotocol the server supports and adds it to the response.
     *
     * <p>Per RFC 6455, the selected value must be one the server is willing to use. The wildcard is
     * therefore honored only when it appears in the server's configured subprotocols, in which case the
     * client's first offer is accepted. A client merely offering {@code "*"} is not echoed back unless the
     * server supports it. Nothing is added when the server has no subprotocols, the client offered none,
     * or no offer matches.
     */
    private void maybeAddSubprotocol(RequestHeaders headers, ResponseHeadersBuilder responseHeadersBuilder) {
        if (subprotocols.isEmpty()) {
            return;
        }
        final String requested = headers.get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, "");
        if (requested.isEmpty()) {
            return;
        }
        final boolean acceptAny = subprotocols.contains(SUB_PROTOCOL_WILDCARD);
        commaSplitter.splitToStream(requested)
                     .filter(sub -> acceptAny || subprotocols.contains(sub))
                     .findFirst()
                     .ifPresent(selected -> responseHeadersBuilder.add(
                             HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, selected));
    }

    /**
     * Computes the {@code Sec-WebSocket-Accept} value: Base64 of the SHA-1 digest of the client key
     * concatenated with the fixed WebSocket GUID.
     */
    private static String generateSecWebSocketAccept(String webSocketKey) {
        final String acceptSeed = webSocketKey + WEBSOCKET_13_ACCEPT_GUID;
        final byte[] sha1 = Hashing.sha1().hashBytes(acceptSeed.getBytes(StandardCharsets.US_ASCII)).asBytes();
        return Base64.getEncoder().encodeToString(sha1);
    }

    /**
     * Wires the request body through a frame decoder into the handler and returns a response that streams
     * the handler's outbound frames, encoded, after {@code responseHeaders}.
     *
     * <p>If the outbound stream fails with a {@link ClosedStreamException}, the response is aborted with
     * that cause. Any other failure is recorded as the response cause and converted into a close frame
     * sent to the peer.
     */
    private HttpResponse handleUpgradeRequest(ServiceRequestContext ctx, HttpRequest req,
                                              ResponseHeaders responseHeaders) {
        final WebSocketFrameDecoder decoder =
                new WebSocketFrameDecoder(ctx, maxFramePayloadLength, allowMaskMismatch, true);
        final StreamMessage<WebSocketFrame> inboundFrames = req.decode(decoder, ctx.alloc());
        final WebSocket outboundFrames = handler.handle(ctx, new WebSocketWrapper(inboundFrames));
        decoder.setOutboundWebSocket(outboundFrames);
        return HttpResponse.of(responseHeaders, outboundFrames.recoverAndResume(cause -> {
            if (cause instanceof ClosedStreamException) {
                return StreamMessage.aborted(cause);
            }
            ctx.logBuilder().responseCause(cause);
            return StreamMessage.of(WebSocketUtil.newCloseWebSocketFrame(cause));
        }).map(frame -> HttpData.wrap(encoder.encode(ctx, frame))));
    }

    /**
     * Handles an HTTP/2 WebSocket upgrade request (a CONNECT request with {@code :protocol = websocket}).
     *
     * <p>Responds with {@code 405 Method Not Allowed} for sessions that are not explicitly HTTP/2,
     * {@code 400 Bad Request} when the upgrade pseudo-header is missing, and the origin/version check
     * response when either check fails. Otherwise replies with {@code 200 OK} and starts the session.
     */
    @Override
    protected HttpResponse doConnect(ServiceRequestContext ctx, HttpRequest req) throws Exception {
        if (!ctx.sessionProtocol().isExplicitHttp2()) {
            return HttpResponse.of(HttpStatus.METHOD_NOT_ALLOWED);
        }
        final RequestHeaders headers = req.headers();
        if (!WebSocketUtil.isHttp2WebSocketUpgradeRequest(headers)) {
            logger.trace("RequestHeaders does not contain headers for WebSocket upgrade. headers: {}", headers);
            return HttpResponse.of(HttpStatus.BAD_REQUEST, MediaType.PLAIN_TEXT_UTF_8,
                                   "The upgrade header must contain:\n  :protocol = websocket");
        }
        HttpResponse invalidResponse = checkOrigin(ctx, headers);
        if (invalidResponse != null) {
            return invalidResponse;
        }
        invalidResponse = checkVersion(headers);
        if (invalidResponse != null) {
            return invalidResponse;
        }
        final ResponseHeadersBuilder responseHeadersBuilder = ResponseHeaders.builder(HttpStatus.OK);
        maybeAddSubprotocol(headers, responseHeadersBuilder);
        return handleUpgradeRequest(ctx, req, responseHeadersBuilder.build());
    }

    /**
     * Validates the {@code Origin} header.
     *
     * @return {@code null} if the origin is acceptable, otherwise a {@code 403 Forbidden} response.
     *         When any origin is allowed, the check is skipped. When no explicit allow-list is configured,
     *         only same-origin requests are accepted. Otherwise, the origin must be in the allow-list.
     */
    @Nullable
    private HttpResponse checkOrigin(ServiceRequestContext ctx, RequestHeaders headers) {
        if (allowAnyOrigin) {
            return null;
        }
        final String origin = headers.get(HttpHeaderNames.ORIGIN, "");
        if (origin.isEmpty()) {
            return HttpResponse.of(HttpStatus.FORBIDDEN, MediaType.PLAIN_TEXT_UTF_8, "missing the origin header");
        }
        if (allowedOrigins.isEmpty()) {
            if (!isSameOrigin(ctx, headers, origin)) {
                return HttpResponse.of(HttpStatus.FORBIDDEN, MediaType.PLAIN_TEXT_UTF_8,
                                       "not allowed origin: " + origin);
            }
            return null;
        }
        if (!allowedOrigins.contains(origin)) {
            return HttpResponse.of(HttpStatus.FORBIDDEN, MediaType.PLAIN_TEXT_UTF_8,
                                   "not allowed origin: " + origin + ", allowed: " + allowedOrigins);
        }
        return null;
    }

    /**
     * Returns whether {@code origin} denotes the same scheme family, host and port as the request's
     * authority.
     *
     * <p>Any origin that cannot be parsed is treated as a different origin. This includes origins with no
     * {@code ://}, an unknown scheme, or a malformed host/port.
     */
    private static boolean isSameOrigin(ServiceRequestContext ctx, RequestHeaders headers, String origin) {
        final int schemeDelimiter = origin.indexOf("://");
        if (schemeDelimiter < 0) {
            return false;
        }
        final SessionProtocol originSessionProtocol = SessionProtocol.find(origin.substring(0, schemeDelimiter));
        if (originSessionProtocol == null) {
            return false;
        }
        final SessionProtocol sessionProtocol = ctx.sessionProtocol();
        final boolean sameSecurity = (sessionProtocol.isHttp() && originSessionProtocol.isHttp()) ||
                                     (sessionProtocol.isHttps() && originSessionProtocol.isHttps());
        if (!sameSecurity) {
            return false;
        }

        final HostAndPort originHostAndPort;
        try {
            originHostAndPort = HostAndPort.fromString(origin.substring(schemeDelimiter + 3));
        } catch (IllegalArgumentException ignored) {
            // The Origin header is client-controlled. HostAndPort rejects e.g. non-numeric or out-of-range
            // ports and unbalanced IPv6 brackets. Such an origin cannot match, so report "not same origin"
            // (which checkOrigin turns into a 403) instead of letting the parse failure escape the service.
            return false;
        }

        final String authority = headers.authority();
        assert authority != null;
        final HostAndPort authorityHostAndPort = HostAndPort.fromString(authority);
        final int authorityPort = authorityHostAndPort.getPortOrDefault(sessionProtocol.defaultPort());
        final int originPort = originHostAndPort.getPortOrDefault(originSessionProtocol.defaultPort());
        // Host names are case-insensitive, so "Example.com" and "example.com" must match.
        return authorityPort == originPort &&
               authorityHostAndPort.getHost().equalsIgnoreCase(originHostAndPort.getHost());
    }

    /**
     * Ensures the client requested WebSocket protocol version 13.
     *
     * @return {@code null} if {@code Sec-WebSocket-Version} is {@code 13}, otherwise a
     *         {@code 400 Bad Request} response advertising the supported version.
     */
    @Nullable
    private static HttpResponse checkVersion(RequestHeaders headers) {
        final String version = headers.get(HttpHeaderNames.SEC_WEBSOCKET_VERSION);
        if (!WebSocketVersion.V13.toHttpHeaderValue().equalsIgnoreCase(version)) {
            return HttpResponse.of(UNSUPPORTED_WEB_SOCKET_VERSION, HttpData.ofUtf8("Only 13 version is supported."));
        }
        return null;
    }
}

