/*
 * Decompiled with CFR 0.152.
 */
package org.apache.arrow.flight.grpc;

import io.grpc.Context;
import io.grpc.Contexts;
import io.grpc.ForwardingServerCall;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.apache.arrow.flight.CallInfo;
import org.apache.arrow.flight.CallStatus;
import org.apache.arrow.flight.FlightMethod;
import org.apache.arrow.flight.FlightRuntimeException;
import org.apache.arrow.flight.FlightServerMiddleware;
import org.apache.arrow.flight.grpc.MetadataAdapter;
import org.apache.arrow.flight.grpc.RequestContextAdapter;
import org.apache.arrow.flight.grpc.StatusUtils;

public class ServerInterceptorAdapter
implements ServerInterceptor {
    public static final Context.Key<Map<FlightServerMiddleware.Key<?>, FlightServerMiddleware>> SERVER_MIDDLEWARE_KEY = Context.key((String)"arrow.flight.server_middleware");
    private final List<KeyFactory<?>> factories;

    public ServerInterceptorAdapter(List<KeyFactory<?>> factories) {
        this.factories = factories;
    }

    public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
        CallInfo info = new CallInfo(FlightMethod.fromProtocol(call.getMethodDescriptor().getFullMethodName()));
        final ArrayList middleware = new ArrayList();
        LinkedHashMap middlewareMap = new LinkedHashMap();
        MetadataAdapter headerAdapter = new MetadataAdapter(headers);
        RequestContextAdapter requestContextAdapter = new RequestContextAdapter();
        for (KeyFactory<?> factory : this.factories) {
            Object m;
            try {
                m = ((KeyFactory)factory).factory.onCallStarted(info, headerAdapter, requestContextAdapter);
            }
            catch (FlightRuntimeException e) {
                call.close(StatusUtils.toGrpcStatus(e.status()), new Metadata());
                return new ServerCall.Listener<ReqT>(){};
            }
            middleware.add(m);
            middlewareMap.put(((KeyFactory)factory).key, m);
        }
        Context contextWithMiddlewareAndRequestsOptions = Context.current().withValue(SERVER_MIDDLEWARE_KEY, Collections.unmodifiableMap(middlewareMap)).withValue(RequestContextAdapter.REQUEST_CONTEXT_KEY, (Object)requestContextAdapter);
        ForwardingServerCall.SimpleForwardingServerCall forwardingServerCall = new ForwardingServerCall.SimpleForwardingServerCall<ReqT, RespT>(call){
            boolean sentHeaders;
            {
                super(delegate);
                this.sentHeaders = false;
            }

            public void sendHeaders(Metadata headers) {
                this.sentHeaders = true;
                try {
                    MetadataAdapter headerAdapter = new MetadataAdapter(headers);
                    middleware.forEach(m -> m.onBeforeSendingHeaders(headerAdapter));
                }
                finally {
                    super.sendHeaders(headers);
                }
            }

            /*
             * WARNING - Removed try catching itself - possible behaviour change.
             */
            public void close(Status status, Metadata trailers) {
                try {
                    if (!this.sentHeaders) {
                        MetadataAdapter headerAdapter = new MetadataAdapter(trailers);
                        middleware.forEach(m -> m.onBeforeSendingHeaders(headerAdapter));
                    }
                }
                finally {
                    super.close(status, trailers);
                }
                CallStatus flightStatus = StatusUtils.fromGrpcStatus(status);
                middleware.forEach(m -> m.onCallCompleted(flightStatus));
            }
        };
        return Contexts.interceptCall((Context)contextWithMiddlewareAndRequestsOptions, (ServerCall)forwardingServerCall, (Metadata)headers, next);
    }

    public static class KeyFactory<T extends FlightServerMiddleware> {
        private final FlightServerMiddleware.Key<T> key;
        private final FlightServerMiddleware.Factory<T> factory;

        public KeyFactory(FlightServerMiddleware.Key<T> key, FlightServerMiddleware.Factory<T> factory) {
            this.key = key;
            this.factory = factory;
        }
    }
}

