/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.sdk.transforms;

import com.google.auto.value.AutoValue;
import java.io.Serializable;
import java.nio.ByteBuffer;
import java.util.UUID;
import javax.annotation.Nullable;
import org.apache.beam.sdk.annotations.Experimental;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.state.BagState;
import org.apache.beam.sdk.state.CombiningState;
import org.apache.beam.sdk.state.StateSpec;
import org.apache.beam.sdk.state.StateSpecs;
import org.apache.beam.sdk.state.TimeDomain;
import org.apache.beam.sdk.state.Timer;
import org.apache.beam.sdk.state.TimerSpec;
import org.apache.beam.sdk.state.TimerSpecs;
import org.apache.beam.sdk.transforms.AutoValue_GroupIntoBatches_BatchingParams;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.util.ShardedKey;
import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.joda.time.ReadableDuration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link PTransform} that groups the values of each key into batches and emits every batch as a
 * {@code KV<K, Iterable<InputT>>}.
 *
 * <p>Buffered values of a key are flushed as one batch when any of the following happens:
 *
 * <ul>
 *   <li>the number of buffered values reaches the configured batch size ({@link #ofSize});
 *   <li>the summed weight of the buffered values reaches the configured byte limit ({@link
 *       #ofByteSize});
 *   <li>the processing-time buffering timer fires, if a positive duration was configured via
 *       {@link #withMaxBufferingDuration};
 *   <li>the event-time timer set to the window's max timestamp plus the allowed lateness fires;
 *   <li>the window expires.
 * </ul>
 *
 * <p>Empty batches are never emitted.
 *
 * @param <K> key type of the input
 * @param <InputT> value type of the input
 */
public class GroupIntoBatches<K, InputT>
extends PTransform<PCollection<KV<K, InputT>>, PCollection<KV<K, Iterable<InputT>>>> {
    private final BatchingParams<InputT> params;
    /** Random per-JVM id; combined with the thread id to build shard ids in {@link WithShardedKey}. */
    private static final UUID workerUuid = UUID.randomUUID();

    private GroupIntoBatches(BatchingParams<InputT> params) {
        this.params = params;
    }

    /**
     * Creates a transform that emits a batch once it holds {@code batchSize} elements.
     *
     * @param batchSize maximum number of elements per batch
     * @throws IllegalStateException if {@code batchSize} is {@link Long#MAX_VALUE}, which is
     *     reserved to mean "no element-count limit"
     */
    public static <K, InputT> GroupIntoBatches<K, InputT> ofSize(long batchSize) {
        Preconditions.checkState(
            batchSize < Long.MAX_VALUE, "batchSize must be less than Long.MAX_VALUE, but was %s", batchSize);
        return new GroupIntoBatches<K, InputT>(
            BatchingParams.create(batchSize, Long.MAX_VALUE, null, Duration.ZERO));
    }

    /**
     * Creates a transform that emits a batch once the summed encoded size of its elements reaches
     * {@code batchSizeBytes}. Element sizes are measured with the input value coder's byte-size
     * observer.
     *
     * @param batchSizeBytes byte limit per batch
     * @throws IllegalStateException if {@code batchSizeBytes} is {@link Long#MAX_VALUE}, which is
     *     reserved to mean "no byte limit"
     */
    public static <K, InputT> GroupIntoBatches<K, InputT> ofByteSize(long batchSizeBytes) {
        Preconditions.checkState(
            batchSizeBytes < Long.MAX_VALUE,
            "batchSizeBytes must be less than Long.MAX_VALUE, but was %s",
            batchSizeBytes);
        return new GroupIntoBatches<K, InputT>(
            BatchingParams.create(Long.MAX_VALUE, batchSizeBytes, null, Duration.ZERO));
    }

    /**
     * Like {@link #ofByteSize(long)}, but weighs each element with {@code getElementByteSize}
     * instead of the value coder.
     *
     * @param batchSizeBytes weight limit per batch
     * @param getElementByteSize function returning the weight of a single value
     * @throws IllegalStateException if {@code batchSizeBytes} is {@link Long#MAX_VALUE}
     */
    public static <K, InputT> GroupIntoBatches<K, InputT> ofByteSize(long batchSizeBytes, SerializableFunction<InputT, Long> getElementByteSize) {
        Preconditions.checkState(
            batchSizeBytes < Long.MAX_VALUE,
            "batchSizeBytes must be less than Long.MAX_VALUE, but was %s",
            batchSizeBytes);
        return new GroupIntoBatches<K, InputT>(
            BatchingParams.create(Long.MAX_VALUE, batchSizeBytes, getElementByteSize, Duration.ZERO));
    }

    /** Returns the batching parameters this transform was configured with. */
    public BatchingParams<InputT> getBatchingParams() {
        return this.params;
    }

    /**
     * Returns a copy of this transform that also flushes a key's buffered values after
     * {@code duration} of processing time. A zero duration disables the buffering timer.
     *
     * @throws IllegalArgumentException if {@code duration} is null or negative
     */
    public GroupIntoBatches<K, InputT> withMaxBufferingDuration(Duration duration) {
        Preconditions.checkArgument(
            duration != null && !duration.isShorterThan(Duration.ZERO),
            "max buffering duration should be a non-negative value");
        return new GroupIntoBatches<K, InputT>(
            BatchingParams.create(
                this.params.getBatchSize(),
                this.params.getBatchSizeBytes(),
                this.params.getElementByteSize(),
                duration));
    }

    /**
     * Returns a variant of this transform that first wraps every key in a {@link ShardedKey}. The
     * shard id is derived from this JVM's random id and the processing thread's id, so values of a
     * hot key can be batched in parallel.
     */
    @Experimental
    public WithShardedKey withShardedKey() {
        return new WithShardedKey();
    }

    /**
     * Applies the batching {@link DoFn} to {@code input}.
     *
     * @throws IllegalArgumentException if the input coder is not a {@link KvCoder}
     */
    @Override
    public PCollection<KV<K, Iterable<InputT>>> expand(PCollection<KV<K, InputT>> input) {
        Duration allowedLateness = input.getWindowingStrategy().getAllowedLateness();
        Preconditions.checkArgument(
            input.getCoder() instanceof KvCoder,
            "coder specified in the input PCollection is not a KvCoder");
        KvCoder<K, InputT> inputCoder = (KvCoder<K, InputT>) input.getCoder();
        Coder<InputT> valueCoder = inputCoder.getValueCoder();
        SerializableFunction<InputT, Long> weigher = this.params.getWeigher(valueCoder);
        return input.apply(
            ParDo.of(
                new GroupIntoBatchesDoFn<K, InputT>(
                    this.params.getBatchSize(),
                    this.params.getBatchSizeBytes(),
                    weigher,
                    allowedLateness,
                    this.params.getMaxBufferingDuration(),
                    valueCoder)));
    }

    /**
     * Stateful {@link DoFn} that buffers the values of each key in a {@link BagState}, tracks the
     * element count and (optionally) the total weight, and flushes them as one batch.
     */
    @VisibleForTesting
    private static class GroupIntoBatchesDoFn<K, InputT>
    extends DoFn<KV<K, InputT>, KV<K, Iterable<InputT>>> {
        private static final Logger LOG = LoggerFactory.getLogger(GroupIntoBatchesDoFn.class);

        private static final String END_OF_WINDOW_ID = "endOFWindow";
        private static final String END_OF_BUFFERING_ID = "endOfBuffering";
        private static final String BATCH_ID = "batch";
        private static final String NUM_ELEMENTS_IN_BATCH_ID = "numElementsInBatch";
        private static final String NUM_BYTES_IN_BATCH_ID = "numBytesInBatch";

        private final long batchSize;
        private final long batchSizeBytes;
        /** Weight function for values; null when there is no byte limit. */
        @Nullable
        private final SerializableFunction<InputT, Long> weigher;
        private final Duration allowedLateness;
        private final Duration maxBufferingDuration;

        /** Event-time timer set to the window's max timestamp plus the allowed lateness. */
        @DoFn.TimerId(END_OF_WINDOW_ID)
        private final TimerSpec windowTimer = TimerSpecs.timer(TimeDomain.EVENT_TIME);

        /** Processing-time timer bounding how long values stay buffered. */
        @DoFn.TimerId(END_OF_BUFFERING_ID)
        private final TimerSpec bufferingTimer = TimerSpecs.timer(TimeDomain.PROCESSING_TIME);

        /** Values buffered for the current batch. */
        @DoFn.StateId(BATCH_ID)
        private final StateSpec<BagState<InputT>> batchSpec;

        /** Number of values in the current batch. */
        @DoFn.StateId(NUM_ELEMENTS_IN_BATCH_ID)
        private final StateSpec<CombiningState<Long, long[], Long>> batchSizeSpec;

        /** Summed weight of the values in the current batch (only updated when a weigher is set). */
        @DoFn.StateId(NUM_BYTES_IN_BATCH_ID)
        private final StateSpec<CombiningState<Long, long[], Long>> batchSizeBytesSpec;

        /** Every {@code prefetchFrequency}-th element triggers a {@code readLater} on the bag. */
        private final long prefetchFrequency;

        GroupIntoBatchesDoFn(long batchSize, long batchSizeBytes, @Nullable SerializableFunction<InputT, Long> weigher, Duration allowedLateness, Duration maxBufferingDuration, Coder<InputT> inputValueCoder) {
            this.batchSize = batchSize;
            this.batchSizeBytes = batchSizeBytes;
            this.weigher = weigher;
            this.allowedLateness = allowedLateness;
            this.maxBufferingDuration = maxBufferingDuration;
            this.batchSpec = StateSpecs.bag(inputValueCoder);
            Combine.BinaryCombineLongFn sumCombineFn = new Combine.BinaryCombineLongFn() {

                @Override
                public long identity() {
                    return 0L;
                }

                @Override
                public long apply(long left, long right) {
                    return left + right;
                }
            };
            this.batchSizeSpec = StateSpecs.combining(sumCombineFn);
            this.batchSizeBytesSpec = StateSpecs.combining(sumCombineFn);
            // Prefetch about every fifth of a batch; batches of fewer than 10 elements never prefetch.
            long fifthOfBatch = batchSize / 5L;
            this.prefetchFrequency = fifthOfBatch <= 1L ? Long.MAX_VALUE : fifthOfBatch;
        }

        /**
         * Buffers one value and flushes the batch once it reaches the element-count or weight limit.
         */
        @DoFn.ProcessElement
        public void processElement(@DoFn.TimerId(END_OF_WINDOW_ID) Timer windowTimer, @DoFn.TimerId(END_OF_BUFFERING_ID) Timer bufferingTimer, @DoFn.StateId(BATCH_ID) BagState<InputT> batch, @DoFn.StateId(NUM_ELEMENTS_IN_BATCH_ID) CombiningState<Long, long[], Long> storedBatchSize, @DoFn.StateId(NUM_BYTES_IN_BATCH_ID) CombiningState<Long, long[], Long> storedBatchSizeBytes, @DoFn.Element KV<K, InputT> element, BoundedWindow window, DoFn.OutputReceiver<KV<K, Iterable<InputT>>> receiver) {
            Instant windowEnds = window.maxTimestamp().plus(this.allowedLateness);
            LOG.debug("*** SET TIMER *** to point in time {} for window {}", windowEnds, window);
            windowTimer.set(windowEnds);
            LOG.debug("*** BATCH *** Add element for window {} ", window);
            batch.add(element.getValue());
            storedBatchSize.add(1L);
            if (this.weigher != null) {
                storedBatchSizeBytes.add(this.weigher.apply(element.getValue()));
                storedBatchSizeBytes.readLater();
            }
            long num = storedBatchSize.read();
            boolean bufferingEnabled = this.maxBufferingDuration.isLongerThan(Duration.ZERO);
            // Start the buffering timer when the first value of a new batch arrives.
            if (bufferingEnabled && num == 1L) {
                bufferingTimer.offset(this.maxBufferingDuration).setRelative();
            }
            if (num % this.prefetchFrequency == 0L) {
                batch.readLater();
            }
            boolean countLimitReached = num >= this.batchSize;
            boolean byteLimitReached = this.batchSizeBytes != Long.MAX_VALUE
                && storedBatchSizeBytes.read() >= this.batchSizeBytes;
            if (countLimitReached || byteLimitReached) {
                LOG.debug("*** END OF BATCH *** for window {}", window);
                this.flushBatch(receiver, element.getKey(), batch, storedBatchSize, storedBatchSizeBytes);
                // The state is empty now; push the buffering timer out again rather than letting
                // it fire for a batch that was already emitted.
                if (bufferingEnabled) {
                    bufferingTimer.offset(this.maxBufferingDuration).setRelative();
                }
            }
        }

        /** Flushes whatever is buffered when the processing-time buffering timer fires. */
        @DoFn.OnTimer(END_OF_BUFFERING_ID)
        public void onBufferingTimer(DoFn.OutputReceiver<KV<K, Iterable<InputT>>> receiver, @DoFn.Timestamp Instant timestamp, @DoFn.Key K key, @DoFn.StateId(BATCH_ID) BagState<InputT> batch, @DoFn.StateId(NUM_ELEMENTS_IN_BATCH_ID) CombiningState<Long, long[], Long> storedBatchSize, @DoFn.StateId(NUM_BYTES_IN_BATCH_ID) CombiningState<Long, long[], Long> storedBatchSizeBytes, @DoFn.TimerId(END_OF_BUFFERING_ID) Timer bufferingTimer) {
            LOG.debug("*** END OF BUFFERING *** for timer timestamp {} with buffering duration {}", timestamp, this.maxBufferingDuration);
            this.flushBatch(receiver, key, batch, storedBatchSize, storedBatchSizeBytes);
        }

        /** Flushes whatever is still buffered when the window expires. */
        @DoFn.OnWindowExpiration
        public void onWindowExpiration(DoFn.OutputReceiver<KV<K, Iterable<InputT>>> receiver, @DoFn.Key K key, @DoFn.StateId(BATCH_ID) BagState<InputT> batch, @DoFn.StateId(NUM_ELEMENTS_IN_BATCH_ID) CombiningState<Long, long[], Long> storedBatchSize, @DoFn.StateId(NUM_BYTES_IN_BATCH_ID) CombiningState<Long, long[], Long> storedBatchSizeBytes) {
            this.flushBatch(receiver, key, batch, storedBatchSize, storedBatchSizeBytes);
        }

        /** Flushes whatever is buffered when the end-of-window event-time timer fires. */
        @DoFn.OnTimer(END_OF_WINDOW_ID)
        public void onWindowTimer(DoFn.OutputReceiver<KV<K, Iterable<InputT>>> receiver, @DoFn.Timestamp Instant timestamp, @DoFn.Key K key, @DoFn.StateId(BATCH_ID) BagState<InputT> batch, @DoFn.StateId(NUM_ELEMENTS_IN_BATCH_ID) CombiningState<Long, long[], Long> storedBatchSize, @DoFn.StateId(NUM_BYTES_IN_BATCH_ID) CombiningState<Long, long[], Long> storedBatchSizeBytes, BoundedWindow window) {
            LOG.debug("*** END OF WINDOW *** for timer timestamp {} in windows {}", timestamp, window);
            this.flushBatch(receiver, key, batch, storedBatchSize, storedBatchSizeBytes);
        }

        /**
         * Emits the buffered values as one batch (skipping empty batches) and clears the buffer and
         * both counters.
         */
        private void flushBatch(DoFn.OutputReceiver<KV<K, Iterable<InputT>>> receiver, K key, BagState<InputT> batch, CombiningState<Long, long[], Long> storedBatchSize, CombiningState<Long, long[], Long> storedBatchSizeBytes) {
            Iterable<InputT> values = batch.read();
            if (!Iterables.isEmpty(values)) {
                receiver.output(KV.of(key, values));
            }
            batch.clear();
            LOG.debug("*** BATCH *** clear");
            storedBatchSize.clear();
            storedBatchSizeBytes.clear();
        }
    }

    /** Observer that sums every size reported to it by a coder. */
    private static class ByteSizeObserver
    extends ElementByteSizeObserver {
        private long elementByteSize = 0L;

        private ByteSizeObserver() {
        }

        @Override
        protected void reportElementSize(long elementByteSize) {
            this.elementByteSize += elementByteSize;
        }

        /** Returns the accumulated size reported so far. */
        public long getElementByteSize() {
            return this.elementByteSize;
        }
    }

    /**
     * Variant of {@link GroupIntoBatches} whose output keys are {@link ShardedKey}s. The shard id of
     * a key is the 24-byte concatenation of this JVM's random UUID and the current thread id.
     */
    public class WithShardedKey
    extends PTransform<PCollection<KV<K, InputT>>, PCollection<KV<ShardedKey<K>, Iterable<InputT>>>> {
        private WithShardedKey() {
        }

        /** Returns the batching parameters of the enclosing {@link GroupIntoBatches}. */
        public BatchingParams<InputT> getBatchingParams() {
            return GroupIntoBatches.this.params;
        }

        /**
         * Shards each key and batches the sharded input.
         *
         * @throws IllegalArgumentException if the input coder is not a {@link KvCoder}
         */
        @Override
        public PCollection<KV<ShardedKey<K>, Iterable<InputT>>> expand(PCollection<KV<K, InputT>> input) {
            Preconditions.checkArgument(
                input.getCoder() instanceof KvCoder,
                "coder specified in the input PCollection is not a KvCoder");
            KvCoder<K, InputT> inputCoder = (KvCoder<K, InputT>) input.getCoder();
            Coder<K> keyCoder = inputCoder.getKeyCoder();
            Coder<InputT> valueCoder = inputCoder.getValueCoder();
            return input
                .apply(MapElements.via(new SimpleFunction<KV<K, InputT>, KV<ShardedKey<K>, InputT>>() {

                    @Override
                    public KV<ShardedKey<K>, InputT> apply(KV<K, InputT> element) {
                        long tid = Thread.currentThread().getId();
                        // Shard id layout: UUID most-significant bits, least-significant bits, thread id.
                        ByteBuffer buffer = ByteBuffer.allocate(24);
                        buffer.putLong(workerUuid.getMostSignificantBits());
                        buffer.putLong(workerUuid.getLeastSignificantBits());
                        buffer.putLong(tid);
                        return KV.of(ShardedKey.of(element.getKey(), buffer.array()), element.getValue());
                    }
                }))
                .setCoder(KvCoder.of(ShardedKey.Coder.of(keyCoder), valueCoder))
                .apply(new GroupIntoBatches<ShardedKey<K>, InputT>(this.getBatchingParams()));
        }
    }

    /**
     * Immutable configuration of a {@link GroupIntoBatches}. {@link Long#MAX_VALUE} for
     * {@code batchSize} or {@code batchSizeBytes} means "no limit"; a zero
     * {@code maxBufferingDuration} disables the buffering timer.
     */
    @AutoValue
    public static abstract class BatchingParams<InputT>
    implements Serializable {
        /** Creates a parameter set; {@code elementByteSize} may be null. */
        public static <InputT> BatchingParams<InputT> create(long batchSize, long batchSizeBytes, @Nullable SerializableFunction<InputT, Long> elementByteSize, Duration maxBufferingDuration) {
            return new AutoValue_GroupIntoBatches_BatchingParams<InputT>(batchSize, batchSizeBytes, elementByteSize, maxBufferingDuration);
        }

        /** Maximum number of elements per batch. */
        public abstract long getBatchSize();

        /** Maximum summed element weight per batch. */
        public abstract long getBatchSizeBytes();

        /** User-supplied weight function, or null if none was given. */
        @Nullable
        public abstract SerializableFunction<InputT, Long> getElementByteSize();

        /** Maximum processing-time duration a value is buffered; zero disables the timer. */
        public abstract Duration getMaxBufferingDuration();

        /**
         * Returns the function used to weigh values. This is the user-supplied function if present.
         * Otherwise, when a byte limit is configured, it is a function that measures the value with
         * {@code valueCoder}'s byte-size observer.
         *
         * @param valueCoder coder used to measure values when no explicit weigher was given
         * @return the weigher, or null when there is no byte limit and no explicit weigher
         */
        @Nullable
        public SerializableFunction<InputT, Long> getWeigher(Coder<InputT> valueCoder) {
            SerializableFunction<InputT, Long> weigher = this.getElementByteSize();
            if (this.getBatchSizeBytes() < Long.MAX_VALUE && weigher == null) {
                weigher = element -> {
                    try {
                        ByteSizeObserver observer = new ByteSizeObserver();
                        valueCoder.registerByteSizeObserver(element, observer);
                        // advance() reports the accumulated size to reportElementSize.
                        observer.advance();
                        return observer.getElementByteSize();
                    }
                    catch (Exception e) {
                        throw new RuntimeException(e);
                    }
                };
            }
            return weigher;
        }
    }
}

