/*
 * Decompiled with CFR 0.152.
 */
package org.apache.cassandra.index.sai.memory;

import io.github.jbellis.jvector.util.Bits;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.NavigableSet;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.concurrent.atomic.LongAdder;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import javax.annotation.Nullable;
import org.apache.cassandra.db.Clustering;
import org.apache.cassandra.db.DecoratedKey;
import org.apache.cassandra.db.PartitionPosition;
import org.apache.cassandra.dht.AbstractBounds;
import org.apache.cassandra.index.sai.IndexContext;
import org.apache.cassandra.index.sai.QueryContext;
import org.apache.cassandra.index.sai.VectorQueryContext;
import org.apache.cassandra.index.sai.disk.format.IndexDescriptor;
import org.apache.cassandra.index.sai.disk.v1.segment.SegmentMetadata;
import org.apache.cassandra.index.sai.disk.v1.vector.OnHeapGraph;
import org.apache.cassandra.index.sai.iterators.KeyRangeIterator;
import org.apache.cassandra.index.sai.iterators.KeyRangeListIterator;
import org.apache.cassandra.index.sai.memory.MemoryIndex;
import org.apache.cassandra.index.sai.plan.Expression;
import org.apache.cassandra.index.sai.utils.PrimaryKey;
import org.apache.cassandra.index.sai.utils.PrimaryKeys;
import org.apache.cassandra.index.sai.utils.RangeUtil;
import org.apache.cassandra.index.sai.utils.TypeUtil;
import org.apache.cassandra.tracing.Tracing;
import org.apache.cassandra.utils.Pair;
import org.apache.cassandra.utils.bytecomparable.ByteComparable;

/**
 * Memtable-side vector index that stores vectors in an {@link OnHeapGraph} keyed by
 * {@link PrimaryKey} and answers ANN expressions against it.
 *
 * <p>Besides the graph, the index tracks the sorted set of primary keys that currently have a
 * vector and the smallest/largest key ever indexed. The sorted set lets a range-restricted query
 * enumerate its candidate rows without touching the graph; the bounds are used to trim externally
 * supplied key lists and are handed to the iterators returned from queries.
 *
 * <p>NOTE(review): {@link #add} is {@code synchronized} but {@link #update} is not, although both
 * mutate the (non-volatile) key bounds through {@code updateKeyBounds}. Confirm with the callers
 * whether concurrent updates are possible.
 */
public class VectorMemoryIndex
extends MemoryIndex {
    private final OnHeapGraph<PrimaryKey> graph;
    private final LongAdder writeCount = new LongAdder();
    /** Smallest primary key passed to {@code updateKeyBounds}; {@code null} until the first write. */
    private PrimaryKey minimumKey;
    /** Largest primary key passed to {@code updateKeyBounds}; {@code null} until the first write. */
    private PrimaryKey maximumKey;
    /** Keys of the rows that currently have a vector in {@link #graph}, in key order. */
    private final NavigableSet<PrimaryKey> primaryKeys = new ConcurrentSkipListSet<PrimaryKey>();

    /**
     * Creates an empty index whose graph is configured from the context's validator and index
     * writer config.
     *
     * @param indexContext context of the index this memory index belongs to
     */
    public VectorMemoryIndex(IndexContext indexContext) {
        super(indexContext);
        this.graph = new OnHeapGraph<>(indexContext.getValidator(), indexContext.getIndexWriterConfig());
    }

    /**
     * Indexes the vector of a newly written row.
     *
     * @param key        partition key of the row
     * @param clustering clustering of the row; only used when the index context has clustering
     * @param value      serialized vector; {@code null} or empty values are not indexed
     * @return the value reported by the graph for the insertion, or {@code 0} when nothing was indexed
     */
    @Override
    public synchronized long add(DecoratedKey key, Clustering<?> clustering, ByteBuffer value) {
        if (value == null || value.remaining() == 0) {
            return 0L;
        }
        return this.index(this.createPrimaryKey(key, clustering), value);
    }

    /** Builds the primary key for a row, including the clustering only when the index context has one. */
    private PrimaryKey createPrimaryKey(DecoratedKey key, Clustering<?> clustering) {
        return this.indexContext.hasClustering()
               ? this.indexContext.keyFactory().create(key, clustering)
               : this.indexContext.keyFactory().create(key);
    }

    /** Records the key in the bounds and the key set, then inserts the vector into the graph. */
    private long index(PrimaryKey primaryKey, ByteBuffer value) {
        this.updateKeyBounds(primaryKey);
        this.writeCount.increment();
        this.primaryKeys.add(primaryKey);
        return this.graph.add(value, primaryKey, OnHeapGraph.InvalidVectorBehavior.FAIL);
    }

    /**
     * Replaces the indexed vector of an existing row.
     *
     * <p>Nothing happens when both values are null/empty or when their remaining bytes are equal.
     * Otherwise the new vector (if any) is added before the old one (if any) is removed.
     *
     * @param key        partition key of the row
     * @param clustering clustering of the row; only used when the index context has clustering
     * @param oldValue   previously written serialized vector, may be {@code null} or empty
     * @param newValue   newly written serialized vector, may be {@code null} or empty
     * @return what the graph reported for the addition minus what it reported for the removal
     */
    @Override
    public long update(DecoratedKey key, Clustering<?> clustering, ByteBuffer oldValue, ByteBuffer newValue) {
        int oldRemaining = oldValue == null ? 0 : oldValue.remaining();
        int newRemaining = newValue == null ? 0 : newValue.remaining();
        if (oldRemaining == 0 && newRemaining == 0) {
            return 0L;
        }

        boolean different;
        if (oldRemaining != newRemaining) {
            // lengths may only differ when one side is missing
            assert (oldRemaining == 0 || newRemaining == 0);
            different = true;
        } else {
            // Same non-zero length, so both buffers are non-null. ByteBuffer.equals compares the
            // remaining bytes relative to each buffer's position, which absolute get(i) would not.
            different = !oldValue.equals(newValue);
        }
        if (!different) {
            return 0L;
        }

        PrimaryKey primaryKey = this.createPrimaryKey(key, clustering);
        // add() skips rows without a vector, so this key may not be inside the bounds yet
        this.updateKeyBounds(primaryKey);

        long bytesUsed = 0L;
        // add before remove, as in the original statement order
        if (newRemaining > 0) {
            bytesUsed += this.graph.add(newValue, primaryKey, OnHeapGraph.InvalidVectorBehavior.FAIL);
            // add() never registered the key if the old vector was null/empty; range-restricted
            // searches enumerate primaryKeys, so the key has to be present (no-op if it already is)
            this.primaryKeys.add(primaryKey);
        }
        if (oldRemaining > 0) {
            bytesUsed -= this.graph.remove(oldValue, primaryKey);
            if (newRemaining == 0) {
                // the row no longer has a vector
                this.primaryKeys.remove(primaryKey);
            }
        }
        return bytesUsed;
    }

    /** Widens {@link #minimumKey}/{@link #maximumKey} so that they include {@code primaryKey}. */
    private void updateKeyBounds(PrimaryKey primaryKey) {
        if (this.minimumKey == null || primaryKey.compareTo(this.minimumKey) < 0) {
            this.minimumKey = primaryKey;
        }
        if (this.maximumKey == null || primaryKey.compareTo(this.maximumKey) > 0) {
            this.maximumKey = primaryKey;
        }
    }

    /**
     * Runs an ANN query against this index.
     *
     * <p>When {@code keyRange} covers the full ring the graph is searched directly, filtered only
     * by the query's shadowed-key bitset. Otherwise the candidate keys inside the range are taken
     * from the sorted key set; if there are fewer of them than the brute-force threshold they are
     * returned as-is (in key order), else the graph is searched with a range filter.
     *
     * @param queryContext query state; its vector context supplies the limit and shadowed keys
     * @param expr         ANN expression whose lower bound holds the serialized query vector
     * @param keyRange     partition range the query is restricted to
     * @return an iterator over the matching primary keys in key order, possibly empty
     */
    @Override
    public KeyRangeIterator search(QueryContext queryContext, Expression expr, AbstractBounds<PartitionPosition> keyRange) {
        assert (expr.getOp() == Expression.IndexOperator.ANN) : "Only ANN is supported for vector search, received " + expr.getOp();
        VectorQueryContext vectorQueryContext = queryContext.vectorContext();
        ByteBuffer buffer = expr.lower.value.raw;
        float[] qv = TypeUtil.decomposeVector(this.indexContext, buffer);

        Bits bits;
        if (RangeUtil.coversFullRing(keyRange)) {
            bits = vectorQueryContext.bitsetForShadowedPrimaryKeys(this.graph);
        } else {
            PartitionPosition leftBound = (PartitionPosition)keyRange.left;
            PartitionPosition rightBound = (PartitionPosition)keyRange.right;
            boolean leftInclusive = leftBound.kind() != PartitionPosition.Kind.MAX_BOUND;
            boolean rightInclusive = rightBound.kind() != PartitionPosition.Kind.MIN_BOUND;
            // a minimum token on the right means the range is open-ended, so only the left bound applies
            boolean isMaxToken = rightBound.getToken().isMinimum();
            PrimaryKey left = this.indexContext.keyFactory().create(leftBound.getToken());

            Set<PrimaryKey> resultKeys;
            if (isMaxToken) {
                resultKeys = this.primaryKeys.tailSet(left, leftInclusive);
            } else {
                PrimaryKey right = this.indexContext.keyFactory().create(rightBound.getToken());
                resultKeys = this.primaryKeys.subSet(left, leftInclusive, right, rightInclusive);
            }
            if (!vectorQueryContext.getShadowedPrimaryKeys().isEmpty()) {
                resultKeys = resultKeys.stream()
                                       .filter(pk -> !vectorQueryContext.containsShadowedPrimaryKey(pk))
                                       .collect(Collectors.toSet());
            }
            if (resultKeys.isEmpty()) {
                return KeyRangeIterator.empty();
            }

            // size() on a skip-list view walks the view, so read it (and the other inputs) once
            int candidateCount = resultKeys.size();
            int graphSize = this.graph.size();
            int limit = vectorQueryContext.limit();
            int bruteForceRows = this.maxBruteForceRows(limit, candidateCount, graphSize);
            Tracing.trace("Search range covers {} rows; max brute force rows is {} for memtable index with {} nodes, LIMIT {}", candidateCount, bruteForceRows, graphSize, limit);
            if (candidateCount < Math.max(limit, bruteForceRows)) {
                // few enough candidates: hand them all back without consulting the graph
                return new ReorderingRangeIterator(new PriorityQueue<PrimaryKey>(resultKeys));
            }
            bits = new KeyRangeFilteringBits(keyRange, vectorQueryContext.bitsetForShadowedPrimaryKeys(this.graph));
        }

        PriorityQueue<PrimaryKey> keyQueue = this.graph.search(qv, vectorQueryContext.limit(), bits);
        if (keyQueue.isEmpty()) {
            return KeyRangeIterator.empty();
        }
        return new ReorderingRangeIterator(keyQueue);
    }

    /**
     * Reduces an externally produced key list to the keys this index should return for an ANN
     * expression.
     *
     * <p>The list is first trimmed to this index's key bounds using {@code dropWhile}/{@code takeWhile},
     * which presumes {@code primaryKeys} is sorted ascending. If few enough keys remain they are
     * returned unchanged; otherwise the graph is searched with only those keys permitted.
     *
     * @param primaryKeys candidate keys, expected in ascending order
     * @param expression  ANN expression whose lower bound holds the serialized query vector
     * @param limit       number of results requested from the graph search
     * @return an iterator over the retained keys, possibly empty
     */
    @Override
    public KeyRangeIterator limitToTopResults(List<PrimaryKey> primaryKeys, Expression expression, int limit) {
        if (this.minimumKey == null) {
            // nothing has been indexed yet
            return KeyRangeIterator.empty();
        }
        List<PrimaryKey> results = primaryKeys.stream()
                                              .dropWhile(k -> k.compareTo(this.minimumKey) < 0)
                                              .takeWhile(k -> k.compareTo(this.maximumKey) <= 0)
                                              .collect(Collectors.toList());
        int graphSize = this.graph.size();
        int maxBruteForceRows = this.maxBruteForceRows(limit, results.size(), graphSize);
        Tracing.trace("SAI materialized {} rows; max brute force rows is {} for memtable index with {} nodes, LIMIT {}", results.size(), maxBruteForceRows, graphSize, limit);
        if (results.size() <= maxBruteForceRows) {
            if (results.isEmpty()) {
                return KeyRangeIterator.empty();
            }
            return new KeyRangeListIterator(this.minimumKey, this.maximumKey, results);
        }

        ByteBuffer buffer = expression.lower.value.raw;
        float[] qv = TypeUtil.decomposeVector(this.indexContext, buffer);
        KeyFilteringBits bits = new KeyFilteringBits(results);
        PriorityQueue<PrimaryKey> keyQueue = this.graph.search(qv, limit, bits);
        if (keyQueue.isEmpty()) {
            return KeyRangeIterator.empty();
        }
        return new ReorderingRangeIterator(keyQueue);
    }

    /**
     * Computes the candidate-count threshold used by the query methods to decide between returning
     * candidates directly and searching the graph.
     *
     * @return the larger of {@code limit} and a quarter of the estimated graph comparisons
     */
    private int maxBruteForceRows(int limit, int nPermittedOrdinals, int graphSize) {
        int expectedNodesVisited = VectorMemoryIndex.expectedNodesVisited(limit, nPermittedOrdinals, graphSize);
        // multiply as double: connections * nodes can exceed Integer.MAX_VALUE for a large graph,
        // and an int product would wrap to a negative value
        double expectedComparisons = (double)this.indexContext.getIndexWriterConfig().getMaximumNodeConnections() * (double)expectedNodesVisited;
        double memoryToDiskFactor = 0.25;
        // the double -> int cast saturates at Integer.MAX_VALUE
        return (int)Math.max((double)limit, memoryToDiskFactor * expectedComparisons);
    }

    /**
     * Estimates how many graph nodes a search will visit.
     *
     * <p>The raw estimate is an empirical-looking formula over the three inputs; it is then clamped
     * to {@code [min(limit, graphSize), graphSize]}. Degenerate inputs (e.g. {@code limit == 1},
     * {@code graphSize == 0} or {@code nPermittedOrdinals == 0}) make the formula evaluate to
     * {@code 0}, {@code NaN} or infinity; the {@code int} cast maps {@code NaN} to {@code 0} and
     * infinity to {@code Integer.MAX_VALUE}, so the clamp still yields a value in range.
     *
     * @param limit              number of results requested
     * @param nPermittedOrdinals number of graph entries the search is allowed to return
     * @param graphSize          number of nodes in the graph
     * @return the clamped estimate
     */
    public static int expectedNodesVisited(int limit, int nPermittedOrdinals, int graphSize) {
        int sizeRestriction = Math.min(nPermittedOrdinals, graphSize);
        int raw = (int)(0.7 * Math.pow(Math.log(graphSize), 2.0) * Math.pow(graphSize, 0.33) * Math.pow(Math.log(limit), 2.0) * Math.pow(Math.log((double)graphSize / (double)sizeRestriction), 2.0) / Math.pow(sizeRestriction, 0.13));
        // never fewer than min(limit, graphSize) and never more than the graph holds
        return Math.min(Math.max(raw, Math.min(limit, graphSize)), graphSize);
    }

    /** Not supported by this index; always throws {@link UnsupportedOperationException}. */
    @Override
    public Iterator<Pair<ByteComparable, PrimaryKeys>> iterator() {
        throw new UnsupportedOperationException();
    }

    /** Delegates to the graph's {@code writeData} with the same arguments and returns its result. */
    @Override
    public SegmentMetadata.ComponentMetadataMap writeDirect(IndexDescriptor indexDescriptor, IndexContext indexContext, Function<PrimaryKey, Integer> postingTransformer) throws IOException {
        return this.graph.writeData(indexDescriptor, indexContext, postingTransformer);
    }

    /** Reports whether the underlying graph is empty. */
    @Override
    public boolean isEmpty() {
        return this.graph.isEmpty();
    }

    /** Always {@code null}: this index does not expose a minimum term. */
    @Override
    @Nullable
    public ByteBuffer getMinTerm() {
        return null;
    }

    /** Always {@code null}: this index does not expose a maximum term. */
    @Override
    @Nullable
    public ByteBuffer getMaxTerm() {
        return null;
    }

    /** {@link Bits} that accepts a graph ordinal when one of its keys is in a given list. */
    private class KeyFilteringBits
    implements Bits {
        private final List<PrimaryKey> results;

        public KeyFilteringBits(List<PrimaryKey> results) {
            this.results = results;
        }

        /** Returns {@code true} if any permitted key is among the keys stored for ordinal {@code i}. */
        @Override
        public boolean get(int i) {
            Collection<PrimaryKey> pk = VectorMemoryIndex.this.graph.keysFromOrdinal(i);
            return this.results.stream().anyMatch(pk::contains);
        }

        // NOTE(review): this is the number of permitted keys, not a number of graph ordinals,
        // unlike KeyRangeFilteringBits.length() — confirm which one the graph search expects.
        public int length() {
            return this.results.size();
        }
    }

    /**
     * Iterator that drains a priority queue of primary keys in queue order.
     * The enclosing index's current key bounds and the queue size are passed to the superclass.
     */
    private class ReorderingRangeIterator
    extends KeyRangeIterator {
        private final PriorityQueue<PrimaryKey> keyQueue;

        ReorderingRangeIterator(PriorityQueue<PrimaryKey> keyQueue) {
            super(VectorMemoryIndex.this.minimumKey, VectorMemoryIndex.this.maximumKey, keyQueue.size());
            this.keyQueue = keyQueue;
        }

        /** Discards queued keys that sort before {@code nextKey}. */
        @Override
        protected void performSkipTo(PrimaryKey nextKey) {
            while (!this.keyQueue.isEmpty() && this.keyQueue.peek().compareTo(nextKey) < 0) {
                this.keyQueue.poll();
            }
        }

        /** Nothing to release. */
        @Override
        public void close() {
        }

        /** Returns the next queued key, or signals end of data once the queue is drained. */
        @Override
        protected PrimaryKey computeNext() {
            if (this.keyQueue.isEmpty()) {
                return (PrimaryKey)this.endOfData();
            }
            return this.keyQueue.poll();
        }
    }

    /**
     * {@link Bits} that accepts a graph ordinal when the optional delegate accepts it and at least
     * one of its keys has a partition key inside a given range.
     */
    private class KeyRangeFilteringBits
    implements Bits {
        private final AbstractBounds<PartitionPosition> keyRange;
        @Nullable
        private final Bits bits;

        /**
         * @param keyRange range the partition keys must fall in; dereferenced on every lookup
         * @param bits     additional filter applied first, or {@code null} for none
         */
        public KeyRangeFilteringBits(AbstractBounds<PartitionPosition> keyRange, @Nullable Bits bits) {
            this.keyRange = keyRange;
            this.bits = bits;
        }

        @Override
        public boolean get(int ordinal) {
            if (this.bits != null && !this.bits.get(ordinal)) {
                return false;
            }
            Collection<PrimaryKey> keys = VectorMemoryIndex.this.graph.keysFromOrdinal(ordinal);
            return keys.stream().anyMatch(k -> this.keyRange.contains(k.partitionKey()));
        }

        public int length() {
            return VectorMemoryIndex.this.graph.size();
        }
    }
}

