/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.util.quantization;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.stream.IntStream;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.HitQueue;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.util.IntroSelector;
import org.apache.lucene.util.Selector;
import org.apache.lucene.util.quantization.ScalarQuantizedVectorSimilarity;

/**
 * Scalar quantizer that maps float vector components to integer codes of at most 8 bits.
 *
 * <p>Each component {@code v} is clamped into {@code [minQuantile, maxQuantile]} and encoded as
 * {@code round((v - minQuantile) * (2^bits - 1) / (maxQuantile - minQuantile))}. Alongside the
 * codes, {@link #quantize} returns a per-vector corrective offset; these offsets together with
 * {@link #getConstantMultiplier()} are what {@link ScalarQuantizedVectorSimilarity} is given when
 * scoring quantized vectors.
 *
 * <p>Quantiles may be supplied directly, estimated from (sampled) vectors for a fixed confidence
 * interval via {@link #fromVectors(FloatVectorValues, float, int, byte)}, or chosen by a small grid
 * search via {@link #fromVectorsAutoInterval}.
 */
public class ScalarQuantizer {
    /** Default number of vectors sampled by {@link #fromVectors(FloatVectorValues, float, int, byte)}. */
    public static final int SCALAR_QUANTIZATION_SAMPLE_SIZE = 25000;
    /** Number of vectors whose components are pooled together for one quantile estimate. */
    static final int SCRATCH_SIZE = 20;
    /** Maximum number of vectors sampled by {@link #fromVectorsAutoInterval}. */
    private static final int AUTO_INTERVAL_SAMPLE_SIZE = 1000;
    /** Number of nearest neighbours tracked per sampled vector during the grid search. */
    private static final int NEAREST_NEIGHBOR_COUNT = 10;
    /** Number of lower (and of upper) quantile candidates evaluated by the grid search. */
    private static final int GRID_CANDIDATES = 16;
    /** Stride of the coarse grid-search pass; the best coarse cell is then searched in full. */
    private static final int GRID_STRIDE = 4;
    /** Seed for reservoir sampling, so the same arguments always produce the same sample. */
    private static final long SAMPLING_SEED = 42L;
    /** Doc id returned by {@code nextDoc()} once the iterator is exhausted (Lucene's NO_MORE_DOCS). */
    private static final int NO_MORE_DOCS = Integer.MAX_VALUE;

    /** Width of one quantization step: {@code (maxQuantile - minQuantile) / (2^bits - 1)}. */
    private final float alpha;
    /** Inverse of {@link #alpha}: {@code (2^bits - 1) / (maxQuantile - minQuantile)}. */
    private final float scale;
    private final byte bits;
    private final float minQuantile;
    private final float maxQuantile;

    /**
     * Creates a quantizer for a fixed quantile range.
     *
     * @param minQuantile lower bound of the encoded range; smaller values are clamped to it
     * @param maxQuantile upper bound of the encoded range; larger values are clamped to it
     * @param bits number of bits per quantized component, expected in {@code [1, 8]}
     * @throws IllegalStateException if either quantile is NaN or infinite
     */
    public ScalarQuantizer(float minQuantile, float maxQuantile, byte bits) {
        if (Float.isNaN(minQuantile) || Float.isInfinite(minQuantile) || Float.isNaN(maxQuantile) || Float.isInfinite(maxQuantile)) {
            throw new IllegalStateException("Scalar quantizer does not support infinite or NaN values");
        }
        assert maxQuantile >= minQuantile;
        assert bits > 0 && bits <= 8;
        this.minQuantile = minQuantile;
        this.maxQuantile = maxQuantile;
        this.bits = bits;
        // Largest code representable with `bits` bits, e.g. 127 for 7 bits.
        float divisor = (1 << bits) - 1;
        // When minQuantile == maxQuantile, scale becomes +Infinity and alpha becomes 0.
        this.scale = divisor / (maxQuantile - minQuantile);
        this.alpha = (maxQuantile - minQuantile) / divisor;
    }

    /**
     * Quantizes {@code src} into {@code dest} and returns the vector's corrective offset.
     *
     * @param src the float vector to quantize
     * @param dest receives one code per component; must have the same length as {@code src}. With
     *     8 bits, codes above 127 are stored as negative bytes and must be read as unsigned.
     * @param similarityFunction the similarity the quantized vector will be scored with
     * @return the sum of the per-component corrections, or {@code 0} for
     *     {@link VectorSimilarityFunction#EUCLIDEAN}
     */
    public float quantize(float[] src, byte[] dest, VectorSimilarityFunction similarityFunction) {
        assert src.length == dest.length;
        float correction = 0f;
        for (int i = 0; i < src.length; i++) {
            correction += quantizeFloat(src[i], dest, i);
        }
        // No corrective offset is reported for EUCLIDEAN; dest has still been filled above.
        if (similarityFunction == VectorSimilarityFunction.EUCLIDEAN) {
            return 0f;
        }
        return correction;
    }

    /**
     * Quantizes one value and returns its contribution to the corrective offset.
     *
     * @param v the value to quantize
     * @param dest array receiving the code at {@code destIndex}, or {@code null} to compute only the
     *     correction
     * @param destIndex position written in {@code dest}; ignored when {@code dest} is null
     * @return {@code minQuantile * (v - minQuantile / 2) + (dx - dxq) * dxq}, see body
     */
    private float quantizeFloat(float v, byte[] dest, int destIndex) {
        assert dest == null || destIndex < dest.length;
        // Offset of the raw, unclamped value from the lower quantile.
        float dx = v - minQuantile;
        // Same offset after clamping v into [minQuantile, maxQuantile].
        float dxc = Math.max(minQuantile, Math.min(maxQuantile, v)) - minQuantile;
        // Clamped offset scaled to [0, 2^bits - 1]; rounding it yields the code.
        float dxs = scale * dxc;
        int code = Math.round(dxs);
        // The code mapped back to an offset in the original value range.
        float dxq = code * alpha;
        if (dest != null) {
            dest[destIndex] = (byte) code;
        }
        // (dx - dxq) is the clamping + rounding error of this component.
        return minQuantile * (v - minQuantile / 2f) + (dx - dxq) * dxq;
    }

    /**
     * Computes this quantizer's corrective offset for a vector that was quantized by
     * {@code oldQuantizer}. Each code is first dequantized with the old quantizer's parameters and
     * the resulting value is then run through this quantizer. {@code quantizedVector} is not
     * modified.
     *
     * @param quantizedVector codes produced by {@code oldQuantizer}
     * @param oldQuantizer the quantizer that produced {@code quantizedVector}
     * @param similarityFunction the similarity the vector will be scored with
     * @return the recomputed offset, or {@code 0} for {@link VectorSimilarityFunction#EUCLIDEAN}
     */
    public float recalculateCorrectiveOffset(byte[] quantizedVector, ScalarQuantizer oldQuantizer, VectorSimilarityFunction similarityFunction) {
        if (similarityFunction == VectorSimilarityFunction.EUCLIDEAN) {
            return 0f;
        }
        float correctiveOffset = 0f;
        for (byte code : quantizedVector) {
            // Codes are unsigned: with 8 bits, codes 128..255 are stored as negative bytes.
            float v = oldQuantizer.alpha * Byte.toUnsignedInt(code) + oldQuantizer.minQuantile;
            correctiveOffset += quantizeFloat(v, null, 0);
        }
        return correctiveOffset;
    }

    /**
     * Maps codes back to approximate float values: {@code alpha * code + minQuantile}.
     *
     * @param src codes produced by {@link #quantize}
     * @param dest receives the approximate values; must have the same length as {@code src}
     */
    void deQuantize(byte[] src, float[] dest) {
        assert src.length == dest.length;
        for (int i = 0; i < src.length; i++) {
            // Codes are unsigned: with 8 bits, codes 128..255 are stored as negative bytes.
            dest[i] = alpha * Byte.toUnsignedInt(src[i]) + minQuantile;
        }
    }

    /** Returns the lower bound of the encoded range. */
    public float getLowerQuantile() {
        return minQuantile;
    }

    /** Returns the upper bound of the encoded range. */
    public float getUpperQuantile() {
        return maxQuantile;
    }

    /** Returns {@code alpha^2}, the squared width of one quantization step. */
    public float getConstantMultiplier() {
        return alpha * alpha;
    }

    /** Returns the number of bits per quantized component. */
    public byte getBits() {
        return bits;
    }

    @Override
    public String toString() {
        return "ScalarQuantizer{minQuantile=" + minQuantile + ", maxQuantile=" + maxQuantile + ", bits=" + bits + "}";
    }

    /**
     * Selects {@code sampleSize} distinct indices from {@code [0, numFloatVecs)} by reservoir
     * sampling and returns them in ascending order.
     *
     * <p>A fresh {@link Random} with a fixed seed is used on every call, so the sample depends only
     * on the arguments and not on earlier calls or on other threads.
     */
    private static int[] reservoirSampleIndices(int numFloatVecs, int sampleSize) {
        Random random = new Random(SAMPLING_SEED);
        int[] vectorsToTake = IntStream.range(0, sampleSize).toArray();
        for (int i = sampleSize; i < numFloatVecs; i++) {
            int j = random.nextInt(i + 1);
            if (j < sampleSize) {
                vectorsToTake[j] = i;
            }
        }
        // Ascending order lets callers advance the vector iterator forward exactly once.
        Arrays.sort(vectorsToTake);
        return vectorsToTake;
    }

    /**
     * Builds a quantizer whose quantiles are estimated from the given vectors, sampling at most
     * {@link #SCALAR_QUANTIZATION_SAMPLE_SIZE} of them.
     *
     * @see #fromVectors(FloatVectorValues, float, int, byte, int)
     */
    public static ScalarQuantizer fromVectors(FloatVectorValues floatVectorValues, float confidenceInterval, int totalVectorCount, byte bits) throws IOException {
        return fromVectors(floatVectorValues, confidenceInterval, totalVectorCount, bits, SCALAR_QUANTIZATION_SAMPLE_SIZE);
    }

    /**
     * Builds a quantizer whose quantiles are estimated from the given vectors.
     *
     * <p>If {@code confidenceInterval} is exactly 1, the exact minimum and maximum over all vectors
     * are used. Otherwise, vectors are grouped into batches of up to {@link #SCRATCH_SIZE}. For each
     * batch, the lower and upper quantiles of the pooled components are computed, and the averages
     * of those per-batch quantiles become the range. A trailing partial batch is not used. When
     * there are more than {@code quantizationSampleSize} vectors, only a reservoir sample of that
     * many is read.
     *
     * @param floatVectorValues iterator over the vectors; it is consumed by this call
     * @param confidenceInterval fraction of values to keep inside the range, in {@code [0.9, 1]}
     * @param totalVectorCount number of vectors the iterator yields
     * @param bits number of bits per quantized component
     * @param quantizationSampleSize maximum number of vectors to read; must exceed
     *     {@link #SCRATCH_SIZE}
     * @throws IOException if reading a vector fails
     * @throws IllegalStateException if the estimated quantiles are NaN or infinite
     */
    static ScalarQuantizer fromVectors(FloatVectorValues floatVectorValues, float confidenceInterval, int totalVectorCount, byte bits, int quantizationSampleSize) throws IOException {
        assert 0.9f <= confidenceInterval && confidenceInterval <= 1f;
        assert quantizationSampleSize > SCRATCH_SIZE;
        if (totalVectorCount == 0) {
            return new ScalarQuantizer(0f, 0f, bits);
        }
        if (confidenceInterval == 1f) {
            // Full confidence: the exact extremes over every component of every vector.
            float min = Float.POSITIVE_INFINITY;
            float max = Float.NEGATIVE_INFINITY;
            while (floatVectorValues.nextDoc() != NO_MORE_DOCS) {
                for (float v : floatVectorValues.vectorValue()) {
                    min = Math.min(min, v);
                    max = Math.max(max, v);
                }
            }
            return new ScalarQuantizer(min, max, bits);
        }
        float[] quantileGatheringScratch = new float[floatVectorValues.dimension() * Math.min(SCRATCH_SIZE, totalVectorCount)];
        int batches = 0;
        double[] upperSum = new double[1];
        double[] lowerSum = new double[1];
        float[] confidenceIntervals = {confidenceInterval};
        if (totalVectorCount <= quantizationSampleSize) {
            // Few enough vectors: read all of them.
            int scratchSize = Math.min(SCRATCH_SIZE, totalVectorCount);
            int fill = 0;
            while (floatVectorValues.nextDoc() != NO_MORE_DOCS) {
                float[] vectorValue = floatVectorValues.vectorValue();
                System.arraycopy(vectorValue, 0, quantileGatheringScratch, fill * vectorValue.length, vectorValue.length);
                if (++fill == scratchSize) {
                    extractQuantiles(confidenceIntervals, quantileGatheringScratch, upperSum, lowerSum);
                    fill = 0;
                    batches++;
                }
            }
            return new ScalarQuantizer((float) lowerSum[0] / batches, (float) upperSum[0] / batches, bits);
        }
        // Too many vectors: read only a reservoir sample, walking the iterator forward once.
        int[] vectorsToTake = reservoirSampleIndices(totalVectorCount, quantizationSampleSize);
        int docsConsumed = 0;
        int fill = 0;
        for (int target : vectorsToTake) {
            while (docsConsumed <= target) {
                floatVectorValues.nextDoc();
                docsConsumed++;
            }
            assert floatVectorValues.docID() != NO_MORE_DOCS;
            float[] vectorValue = floatVectorValues.vectorValue();
            System.arraycopy(vectorValue, 0, quantileGatheringScratch, fill * vectorValue.length, vectorValue.length);
            if (++fill == SCRATCH_SIZE) {
                extractQuantiles(confidenceIntervals, quantileGatheringScratch, upperSum, lowerSum);
                batches++;
                fill = 0;
            }
        }
        return new ScalarQuantizer((float) lowerSum[0] / batches, (float) upperSum[0] / batches, bits);
    }

    /**
     * Builds a quantizer whose quantiles are chosen by a grid search over candidate ranges.
     *
     * <p>The steps are:
     * <ol>
     *   <li>Up to 1000 vectors are sampled.</li>
     *   <li>Batch-averaged quantiles are estimated for two dimension-dependent confidence intervals.
     *       The two lower quantiles bound a set of 16 evenly spaced lower candidates, and the two
     *       upper quantiles bound a set of 16 upper candidates.</li>
     *   <li>Each candidate pair is scored by how closely quantized similarity scores track the
     *       float scores of every sampled vector's nearest neighbours within the sample.</li>
     *   <li>The best-scoring pair becomes the quantizer's range.</li>
     * </ol>
     *
     * @param floatVectorValues iterator over the vectors; it is consumed by this call
     * @param function similarity used both for the neighbour search and for quantized scoring
     * @param totalVectorCount number of vectors the iterator yields
     * @param bits number of bits per quantized component
     * @throws IOException if reading a vector fails
     * @throws IllegalStateException if the estimated quantiles are NaN or infinite
     */
    public static ScalarQuantizer fromVectorsAutoInterval(FloatVectorValues floatVectorValues, VectorSimilarityFunction function, int totalVectorCount, byte bits) throws IOException {
        if (totalVectorCount == 0) {
            return new ScalarQuantizer(0f, 0f, bits);
        }
        int sampleSize = Math.min(totalVectorCount, AUTO_INTERVAL_SAMPLE_SIZE);
        int dimension = floatVectorValues.dimension();
        float[] quantileGatheringScratch = new float[dimension * Math.min(SCRATCH_SIZE, totalVectorCount)];
        int batches = 0;
        double[] upperSum = new double[2];
        double[] lowerSum = new double[2];
        List<float[]> sampledDocs = new ArrayList<>(sampleSize);
        float[] confidenceIntervals = {
            1f - Math.min(32f, dimension / 10f) / (dimension + 1),
            1f - 1f / (dimension + 1)
        };
        if (totalVectorCount <= sampleSize) {
            int scratchSize = Math.min(SCRATCH_SIZE, totalVectorCount);
            int fill = 0;
            while (floatVectorValues.nextDoc() != NO_MORE_DOCS) {
                gatherSample(floatVectorValues, quantileGatheringScratch, sampledDocs, fill);
                if (++fill == scratchSize) {
                    extractQuantiles(confidenceIntervals, quantileGatheringScratch, upperSum, lowerSum);
                    fill = 0;
                    batches++;
                }
            }
        } else {
            int[] vectorsToTake = reservoirSampleIndices(totalVectorCount, AUTO_INTERVAL_SAMPLE_SIZE);
            int docsConsumed = 0;
            int fill = 0;
            for (int target : vectorsToTake) {
                while (docsConsumed <= target) {
                    floatVectorValues.nextDoc();
                    docsConsumed++;
                }
                assert floatVectorValues.docID() != NO_MORE_DOCS;
                gatherSample(floatVectorValues, quantileGatheringScratch, sampledDocs, fill);
                if (++fill == SCRATCH_SIZE) {
                    extractQuantiles(confidenceIntervals, quantileGatheringScratch, upperSum, lowerSum);
                    batches++;
                    fill = 0;
                }
            }
        }
        // Lower/upper quantiles for the second and first confidence intervals respectively.
        float al = (float) lowerSum[1] / batches;
        float bu = (float) upperSum[1] / batches;
        float au = (float) lowerSum[0] / batches;
        float bl = (float) upperSum[0] / batches;
        if (Float.isNaN(al) || Float.isInfinite(al) || Float.isNaN(au) || Float.isInfinite(au) || Float.isNaN(bl) || Float.isInfinite(bl) || Float.isNaN(bu) || Float.isInfinite(bu)) {
            throw new IllegalStateException("Quantile calculation resulted in NaN or infinite values");
        }
        // Candidates at fractions 0/32, 2/32, ..., 30/32 of the way from al to au (bl to bu).
        float[] lowerCandidates = new float[GRID_CANDIDATES];
        float[] upperCandidates = new float[GRID_CANDIDATES];
        for (int k = 0; k < GRID_CANDIDATES; k++) {
            float step = 2f * k;
            lowerCandidates[k] = al + step * (au - al) / 32f;
            upperCandidates[k] = bl + step * (bu - bl) / 32f;
        }
        List<ScoreDocsAndScoreVariance> nearestNeighbors = findNearestNeighbors(sampledDocs, function);
        float[] bestPair = candidateGridSearch(nearestNeighbors, sampledDocs, lowerCandidates, upperCandidates, function, bits);
        return new ScalarQuantizer(bestPair[0], bestPair[1], bits);
    }

    /**
     * Adds the lower and upper quantile of the pooled scratch values for each confidence interval
     * to the running sums. The scratch array is reordered in place.
     */
    private static void extractQuantiles(float[] confidenceIntervals, float[] quantileGatheringScratch, double[] upperSum, double[] lowerSum) {
        assert confidenceIntervals.length == upperSum.length && confidenceIntervals.length == lowerSum.length;
        for (int i = 0; i < confidenceIntervals.length; i++) {
            float[] lowerAndUpper = getUpperAndLowerQuantile(quantileGatheringScratch, confidenceIntervals[i]);
            upperSum[i] += lowerAndUpper[1];
            lowerSum[i] += lowerAndUpper[0];
        }
    }

    /**
     * Copies the iterator's current vector into {@code sampledDocs} (as a new array) and into slot
     * {@code i} of the scratch buffer.
     */
    private static void gatherSample(FloatVectorValues floatVectorValues, float[] quantileGatheringScratch, List<float[]> sampledDocs, int i) throws IOException {
        float[] vectorValue = floatVectorValues.vectorValue();
        sampledDocs.add(Arrays.copyOf(vectorValue, vectorValue.length));
        System.arraycopy(vectorValue, 0, quantileGatheringScratch, i * vectorValue.length, vectorValue.length);
    }

    /**
     * Returns the {@code {lower, upper}} candidate pair with the highest score-error correlation.
     *
     * <p>First every {@link #GRID_STRIDE}-th pair is tried. Then indices {@code +1..+3} on both axes
     * around the best coarse pair are tried. Pairs with {@code upper <= lower} are skipped. If no
     * pair qualifies, {@code {0, 0}} is returned.
     */
    private static float[] candidateGridSearch(List<ScoreDocsAndScoreVariance> nearestNeighbors, List<float[]> vectors, float[] lowerCandidates, float[] upperCandidates, VectorSimilarityFunction function, byte bits) {
        double maxCorr = Double.NEGATIVE_INFINITY;
        float bestLower = 0f;
        float bestUpper = 0f;
        ScoreErrorCorrelator scoreErrorCorrelator = new ScoreErrorCorrelator(function, nearestNeighbors, vectors, bits);
        int bestCellLower = 0;
        int bestCellUpper = 0;
        // Coarse pass over the strided grid.
        for (int i = 0; i < lowerCandidates.length; i += GRID_STRIDE) {
            float lower = lowerCandidates[i];
            if (Float.isNaN(lower) || Float.isInfinite(lower)) {
                assert false : "Lower candidate is NaN or infinite";
                continue;
            }
            for (int j = 0; j < upperCandidates.length; j += GRID_STRIDE) {
                float upper = upperCandidates[j];
                if (Float.isNaN(upper) || Float.isInfinite(upper)) {
                    assert false : "Upper candidate is NaN or infinite";
                    continue;
                }
                if (upper <= lower) {
                    continue;
                }
                double corr = scoreErrorCorrelator.scoreErrorCorrelation(lower, upper);
                if (corr > maxCorr) {
                    maxCorr = corr;
                    bestLower = lower;
                    bestUpper = upper;
                    bestCellLower = i;
                    bestCellUpper = j;
                }
            }
        }
        // Fine pass inside the best coarse cell.
        int lowerEnd = Math.min(bestCellLower + GRID_STRIDE, lowerCandidates.length);
        int upperEnd = Math.min(bestCellUpper + GRID_STRIDE, upperCandidates.length);
        for (int i = bestCellLower + 1; i < lowerEnd; i++) {
            for (int j = bestCellUpper + 1; j < upperEnd; j++) {
                float lower = lowerCandidates[i];
                float upper = upperCandidates[j];
                if (Float.isNaN(lower) || Float.isInfinite(lower) || Float.isNaN(upper) || Float.isInfinite(upper)) {
                    assert false : "Lower or upper candidate is NaN or infinite";
                    continue;
                }
                if (upper <= lower) {
                    continue;
                }
                double corr = scoreErrorCorrelator.scoreErrorCorrelation(lower, upper);
                if (corr > maxCorr) {
                    maxCorr = corr;
                    bestLower = lower;
                    bestUpper = upper;
                }
            }
        }
        return new float[] {bestLower, bestUpper};
    }

    /**
     * For each vector, finds its {@link #NEAREST_NEIGHBOR_COUNT} highest-scoring other vectors
     * under {@code similarityFunction} (brute force over all pairs). It also records the sample
     * variance of those neighbour scores.
     *
     * @return one entry per input vector, in input order; each entry's neighbours are ordered by
     *     descending score
     */
    private static List<ScoreDocsAndScoreVariance> findNearestNeighbors(List<float[]> vectors, VectorSimilarityFunction similarityFunction) {
        List<HitQueue> queues = new ArrayList<>(vectors.size());
        queues.add(new HitQueue(NEAREST_NEIGHBOR_COUNT, false));
        for (int i = 0; i < vectors.size(); i++) {
            float[] vector = vectors.get(i);
            for (int j = i + 1; j < vectors.size(); j++) {
                float score = similarityFunction.compare(vector, vectors.get(j));
                // Queues are created lazily, the first time index j is reached.
                if (queues.size() <= j) {
                    queues.add(new HitQueue(NEAREST_NEIGHBOR_COUNT, false));
                }
                queues.get(i).insertWithOverflow(new ScoreDoc(j, score));
                queues.get(j).insertWithOverflow(new ScoreDoc(i, score));
            }
        }
        List<ScoreDocsAndScoreVariance> result = new ArrayList<>(vectors.size());
        OnlineMeanAndVar meanAndVar = new OnlineMeanAndVar();
        for (int i = 0; i < vectors.size(); i++) {
            HitQueue queue = queues.get(i);
            ScoreDoc[] scoreDocs = new ScoreDoc[queue.size()];
            // pop() yields the lowest remaining score first, so fill from the back.
            for (int j = queue.size() - 1; j >= 0; j--) {
                scoreDocs[j] = queue.pop();
                assert scoreDocs[j] != null;
                meanAndVar.add(scoreDocs[j].score);
            }
            result.add(new ScoreDocsAndScoreVariance(scoreDocs, meanAndVar.var()));
            meanAndVar.reset();
        }
        return result;
    }

    /**
     * Returns {@code {lower, upper}} such that roughly {@code (1 - confidenceInterval) / 2} of the
     * values lie at or below the lower bound's position and the same share lies above the upper
     * bound's position. Arrays of length 1 or 2 are sorted and their extremes returned.
     *
     * <p>{@code arr} is reordered in place.
     */
    static float[] getUpperAndLowerQuantile(float[] arr, float confidenceInterval) {
        assert arr.length > 0;
        if (arr.length <= 2) {
            Arrays.sort(arr);
            return new float[] {arr[0], arr[arr.length - 1]};
        }
        // Number of values trimmed from each tail, rounded to nearest.
        int selectorIndex = (int) ((float) arr.length * (1f - confidenceInterval) / 2f + 0.5f);
        if (selectorIndex > 0) {
            // Partition so the untrimmed middle ends up in [selectorIndex, arr.length - selectorIndex).
            Selector selector = new FloatSelector(arr);
            selector.select(0, arr.length, arr.length - selectorIndex);
            selector.select(0, arr.length - selectorIndex, selectorIndex);
        }
        float min = Float.POSITIVE_INFINITY;
        float max = Float.NEGATIVE_INFINITY;
        for (int i = selectorIndex; i < arr.length - selectorIndex; i++) {
            min = Math.min(arr[i], min);
            max = Math.max(arr[i], max);
        }
        return new float[] {min, max};
    }

    /**
     * Scores a candidate quantile range by how well quantized similarity scores reproduce the float
     * scores of each sampled vector's nearest neighbours. Query and vector code buffers are reused
     * across calls, so an instance is not safe for concurrent use.
     */
    private static class ScoreErrorCorrelator {
        private final OnlineMeanAndVar corr = new OnlineMeanAndVar();
        private final OnlineMeanAndVar errors = new OnlineMeanAndVar();
        private final VectorSimilarityFunction function;
        private final List<ScoreDocsAndScoreVariance> nearestNeighbors;
        private final List<float[]> vectors;
        private final byte[] query;
        private final byte[] vector;
        private final byte bits;

        public ScoreErrorCorrelator(VectorSimilarityFunction function, List<ScoreDocsAndScoreVariance> nearestNeighbors, List<float[]> vectors, byte bits) {
            this.function = function;
            this.nearestNeighbors = nearestNeighbors;
            this.vectors = vectors;
            this.query = new byte[vectors.get(0).length];
            this.vector = new byte[vectors.get(0).length];
            this.bits = bits;
        }

        /**
         * Returns the mean over sampled vectors of
         * {@code 1 - var(quantizedScore - floatScore) / var(floatNeighbourScores)}, or {@code 0}
         * if that mean is NaN.
         */
        double scoreErrorCorrelation(float lowerQuantile, float upperQuantile) {
            corr.reset();
            ScalarQuantizer quantizer = new ScalarQuantizer(lowerQuantile, upperQuantile, bits);
            ScalarQuantizedVectorSimilarity similarity = ScalarQuantizedVectorSimilarity.fromVectorSimilarity(function, quantizer.getConstantMultiplier(), quantizer.bits);
            for (int i = 0; i < nearestNeighbors.size(); i++) {
                float queryCorrection = quantizer.quantize(vectors.get(i), query, function);
                ScoreDocsAndScoreVariance neighbors = nearestNeighbors.get(i);
                errors.reset();
                for (ScoreDoc scoreDoc : neighbors.getScoreDocs()) {
                    float vectorCorrection = quantizer.quantize(vectors.get(scoreDoc.doc), vector, function);
                    float qScore = similarity.score(query, queryCorrection, vector, vectorCorrection);
                    errors.add(qScore - scoreDoc.score);
                }
                corr.add(1f - errors.var() / neighbors.scoreVariance);
            }
            return Double.isNaN(corr.mean) ? 0.0 : corr.mean;
        }
    }

    /** Running mean and sample variance, updated one value at a time with Welford's algorithm. */
    private static class OnlineMeanAndVar {
        private double mean = 0.0;
        /** Sum of squared deviations from the running mean. */
        private double var = 0.0;
        private int n = 0;

        private OnlineMeanAndVar() {
        }

        void reset() {
            mean = 0.0;
            var = 0.0;
            n = 0;
        }

        void add(double x) {
            n++;
            double delta = x - mean;
            mean += delta / n;
            var += delta * (x - mean);
        }

        /** Sample variance ({@code n - 1} denominator); not meaningful for fewer than two values. */
        float var() {
            return (float) (var / (n - 1));
        }
    }

    /** A vector's nearest neighbours together with the sample variance of their scores. */
    private static class ScoreDocsAndScoreVariance {
        private final ScoreDoc[] scoreDocs;
        private final float scoreVariance;

        public ScoreDocsAndScoreVariance(ScoreDoc[] scoreDocs, float scoreVariance) {
            this.scoreDocs = scoreDocs;
            this.scoreVariance = scoreVariance;
        }

        public ScoreDoc[] getScoreDocs() {
            return scoreDocs;
        }
    }

    /** {@link IntroSelector} over a float array, ordered by {@link Float#compare}. */
    private static class FloatSelector extends IntroSelector {
        float pivot = Float.NaN;
        private final float[] arr;

        private FloatSelector(float[] arr) {
            this.arr = arr;
        }

        @Override
        protected void setPivot(int i) {
            pivot = arr[i];
        }

        @Override
        protected int comparePivot(int j) {
            return Float.compare(pivot, arr[j]);
        }

        @Override
        protected void swap(int i, int j) {
            float tmp = arr[i];
            arr[i] = arr[j];
            arr[j] = tmp;
        }
    }
}

