/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.util;

import java.io.IOException;
import java.util.Arrays;
import java.util.Random;
import java.util.stream.IntStream;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.IntroSelector;
import org.apache.lucene.util.Selector;

/**
 * Scalar quantizer that clamps float vector components into {@code [minQuantile, maxQuantile]} and
 * maps that range linearly onto the integer values {@code [0, 127]}, stored as bytes.
 */
public class ScalarQuantizer {
    /** Maximum number of vectors sampled when estimating quantiles from a large vector set. */
    public static final int SCALAR_QUANTIZATION_SAMPLE_SIZE = 25000;

    /**
     * Seed for reservoir sampling. A new {@link Random} is created from it on every call, so the same
     * input always yields the same sample regardless of earlier calls or concurrent callers.
     */
    private static final long SAMPLING_SEED = 42L;

    /** Width of one quantization step: {@code (maxQuantile - minQuantile) / 127}. */
    private final float alpha;
    /** {@code 127 / (maxQuantile - minQuantile)}, or 0 when the range has zero width. */
    private final float scale;
    private final float minQuantile;
    private final float maxQuantile;
    private final float confidenceInterval;

    /**
     * Creates a quantizer for the given clipping bounds.
     *
     * @param minQuantile lower clipping bound; values below it quantize to 0
     * @param maxQuantile upper clipping bound; values above it quantize to 127. Must be
     *     {@code >= minQuantile} (checked only by an assertion)
     * @param confidenceInterval the confidence interval the bounds were derived from; stored and
     *     reported by {@link #getConfidenceInterval()} but not used in the arithmetic here
     */
    public ScalarQuantizer(float minQuantile, float maxQuantile, float confidenceInterval) {
        assert maxQuantile >= minQuantile;
        this.minQuantile = minQuantile;
        this.maxQuantile = maxQuantile;
        float range = maxQuantile - minQuantile;
        // A zero-width range (e.g. the quantizer built for empty input) would otherwise give
        // scale = +/-Infinity and NaN intermediates (Infinity * 0) in quantize(). Math.round(NaN) is 0,
        // so every value maps to 0 either way; a zero scale just reaches that result directly.
        this.scale = range == 0.0f ? 0.0f : 127.0f / range;
        this.alpha = range / 127.0f;
        this.confidenceInterval = confidenceInterval;
    }

    /**
     * Quantizes {@code src} into {@code dest} and computes the corrective offset for the vector.
     *
     * @param src the float vector to quantize
     * @param dest receives the quantized components, each in {@code [0, 127]}; must have the same
     *     length as {@code src} (checked only by an assertion)
     * @param similarityFunction similarity the vector will be scored with
     * @return the accumulated corrective offset, or 0 for {@link VectorSimilarityFunction#EUCLIDEAN}
     */
    public float quantize(float[] src, byte[] dest, VectorSimilarityFunction similarityFunction) {
        assert src.length == dest.length;
        float correctiveOffset = 0.0f;
        for (int i = 0; i < src.length; ++i) {
            float v = src[i];
            // Distance above the lower bound after clamping v into [minQuantile, maxQuantile].
            float dx = Math.max(this.minQuantile, Math.min(this.maxQuantile, v)) - this.minQuantile;
            float dxs = this.scale * dx;
            int q = Math.round(dxs);
            // dxq is the de-quantized distance; (dx - dxq) is this component's rounding error.
            float dxq = q * this.alpha;
            correctiveOffset += this.minQuantile * (v - this.minQuantile / 2.0f) + (dx - dxq) * dxq;
            dest[i] = (byte) q;
        }
        if (similarityFunction.equals(VectorSimilarityFunction.EUCLIDEAN)) {
            return 0.0f;
        }
        return correctiveOffset;
    }

    /**
     * Recomputes the corrective offset of a vector that was quantized by {@code oldQuantizer},
     * as if it were quantized with this quantizer's bounds. The vector is first de-quantized with
     * the old quantizer's parameters.
     *
     * @param quantizedVector bytes produced by {@code oldQuantizer}
     * @param oldQuantizer the quantizer that produced {@code quantizedVector}
     * @param similarityFunction similarity the vector will be scored with
     * @return the corrective offset under this quantizer, or 0 for
     *     {@link VectorSimilarityFunction#EUCLIDEAN}
     */
    public float recalculateCorrectiveOffset(byte[] quantizedVector, ScalarQuantizer oldQuantizer, VectorSimilarityFunction similarityFunction) {
        if (similarityFunction.equals(VectorSimilarityFunction.EUCLIDEAN)) {
            return 0.0f;
        }
        float correctiveOffset = 0.0f;
        for (int i = 0; i < quantizedVector.length; ++i) {
            // Approximate original value reconstructed with the old quantizer's parameters.
            float v = oldQuantizer.alpha * quantizedVector[i] + oldQuantizer.minQuantile;
            float dx = Math.max(this.minQuantile, Math.min(this.maxQuantile, v)) - this.minQuantile;
            float dxs = this.scale * dx;
            float dxq = Math.round(dxs) * this.alpha;
            correctiveOffset += this.minQuantile * (v - this.minQuantile / 2.0f) + (dx - dxq) * dxq;
        }
        return correctiveOffset;
    }

    /**
     * Maps quantized bytes back to approximate float values: {@code alpha * q + minQuantile}.
     *
     * @param src quantized components
     * @param dest receives the reconstructed floats; must have the same length as {@code src}
     *     (checked only by an assertion)
     */
    public void deQuantize(byte[] src, float[] dest) {
        assert src.length == dest.length;
        for (int i = 0; i < src.length; ++i) {
            dest[i] = this.alpha * src[i] + this.minQuantile;
        }
    }

    /** Returns the lower clipping bound. */
    public float getLowerQuantile() {
        return this.minQuantile;
    }

    /** Returns the upper clipping bound. */
    public float getUpperQuantile() {
        return this.maxQuantile;
    }

    /** Returns the confidence interval this quantizer was created with. */
    public float getConfidenceInterval() {
        return this.confidenceInterval;
    }

    /**
     * Returns {@code alpha * alpha}, the squared quantization step width. Presumably callers use it
     * to rescale products of two quantized values; confirm against the call sites.
     */
    public float getConstantMultiplier() {
        return this.alpha * this.alpha;
    }

    @Override
    public String toString() {
        return "ScalarQuantizer{minQuantile=" + this.minQuantile + ", maxQuantile=" + this.maxQuantile + ", confidenceInterval=" + this.confidenceInterval + "}";
    }

    /**
     * Chooses {@code sampleSize} distinct positions from {@code [0, numFloatVecs)} by reservoir
     * sampling (Algorithm R) and returns them sorted ascending.
     *
     * <p>The result is deterministic for a given {@code (numFloatVecs, sampleSize)}. Callers are
     * expected to pass {@code sampleSize <= numFloatVecs}. Otherwise the result is simply
     * {@code 0 .. sampleSize-1}, which includes positions outside the requested range.
     */
    static int[] reservoirSampleIndices(int numFloatVecs, int sampleSize) {
        // Per-call generator: a shared static Random would make the sample (and therefore the
        // estimated quantiles) depend on how many samples were drawn earlier in the JVM.
        Random random = new Random(SAMPLING_SEED);
        int[] vectorsToTake = IntStream.range(0, sampleSize).toArray();
        for (int i = sampleSize; i < numFloatVecs; ++i) {
            int j = random.nextInt(i + 1);
            if (j < sampleSize) {
                vectorsToTake[j] = i;
            }
        }
        // Sorted so sampleVectors can walk the forward-only iterator in a single pass.
        Arrays.sort(vectorsToTake);
        return vectorsToTake;
    }

    /**
     * Copies the vectors at the given iteration positions into one flat array of
     * {@code vectorsToTake.length * dimension} floats.
     *
     * @param floatVectorValues iterator to read from; it is advanced and not reset
     * @param vectorsToTake zero-based positions in iteration order (not doc IDs), sorted ascending
     */
    static float[] sampleVectors(FloatVectorValues floatVectorValues, int[] vectorsToTake) throws IOException {
        int dim = floatVectorValues.dimension();
        float[] values = new float[vectorsToTake.length * dim];
        int copyOffset = 0;
        // Number of nextDoc() calls made so far; the iterator is positioned on position index - 1.
        int index = 0;
        for (int target : vectorsToTake) {
            while (index <= target) {
                floatVectorValues.nextDoc();
                ++index;
            }
            // Integer.MAX_VALUE is the exhaustion sentinel returned by nextDoc().
            assert floatVectorValues.docID() != Integer.MAX_VALUE;
            float[] floatVector = floatVectorValues.vectorValue();
            System.arraycopy(floatVector, 0, values, copyOffset, floatVector.length);
            copyOffset += dim;
        }
        return values;
    }

    /**
     * Builds a quantizer from all vectors in {@code floatVectorValues}, sampling at most
     * {@link #SCALAR_QUANTIZATION_SAMPLE_SIZE} of them.
     */
    public static ScalarQuantizer fromVectors(FloatVectorValues floatVectorValues, float confidenceInterval) throws IOException {
        return fromVectors(floatVectorValues, confidenceInterval, floatVectorValues.size(), SCALAR_QUANTIZATION_SAMPLE_SIZE);
    }

    /**
     * Builds a quantizer from {@code floatVectorValues}, which is expected to yield
     * {@code totalVectorCount} vectors. At most {@link #SCALAR_QUANTIZATION_SAMPLE_SIZE} of them are
     * sampled.
     */
    public static ScalarQuantizer fromVectors(FloatVectorValues floatVectorValues, float confidenceInterval, int totalVectorCount) throws IOException {
        return fromVectors(floatVectorValues, confidenceInterval, totalVectorCount, SCALAR_QUANTIZATION_SAMPLE_SIZE);
    }

    /**
     * Estimates the lower and upper quantiles of the vector components and returns a quantizer for
     * them.
     *
     * <ul>
     *   <li>No vectors: returns a quantizer with both bounds at 0.
     *   <li>{@code confidenceInterval == 1}: uses the exact min and max over all components.
     *   <li>{@code totalVectorCount <= quantizationSampleSize}: uses all components.
     *   <li>Otherwise: uses a reservoir sample of {@code quantizationSampleSize} vectors.
     * </ul>
     *
     * @param floatVectorValues source vectors; the iterator is consumed
     * @param confidenceInterval fraction of component values to keep inside the bounds, in
     *     {@code [0.9, 1]} (checked only by an assertion)
     * @param totalVectorCount number of vectors the iterator is expected to yield. If it yields more
     *     on the all-components path, {@link System#arraycopy} throws
     * @param quantizationSampleSize maximum number of vectors to sample
     */
    static ScalarQuantizer fromVectors(FloatVectorValues floatVectorValues, float confidenceInterval, int totalVectorCount, int quantizationSampleSize) throws IOException {
        assert 0.9f <= confidenceInterval && confidenceInterval <= 1.0f;
        if (totalVectorCount == 0) {
            return new ScalarQuantizer(0.0f, 0.0f, confidenceInterval);
        }
        if (confidenceInterval == 1.0f) {
            float min = Float.POSITIVE_INFINITY;
            float max = Float.NEGATIVE_INFINITY;
            // Integer.MAX_VALUE is the exhaustion sentinel returned by nextDoc().
            while (floatVectorValues.nextDoc() != Integer.MAX_VALUE) {
                for (float v : floatVectorValues.vectorValue()) {
                    min = Math.min(min, v);
                    max = Math.max(max, v);
                }
            }
            if (min > max) {
                // Nothing was read even though totalVectorCount > 0. Use the empty-input quantizer
                // instead of the inverted (+Inf, -Inf) bounds.
                return new ScalarQuantizer(0.0f, 0.0f, confidenceInterval);
            }
            return new ScalarQuantizer(min, max, confidenceInterval);
        }
        int dim = floatVectorValues.dimension();
        if (totalVectorCount <= quantizationSampleSize) {
            int copyOffset = 0;
            float[] values = new float[totalVectorCount * dim];
            while (floatVectorValues.nextDoc() != Integer.MAX_VALUE) {
                float[] floatVector = floatVectorValues.vectorValue();
                System.arraycopy(floatVector, 0, values, copyOffset, floatVector.length);
                copyOffset += dim;
            }
            if (copyOffset == 0) {
                return new ScalarQuantizer(0.0f, 0.0f, confidenceInterval);
            }
            if (copyOffset < values.length) {
                // Fewer vectors were read than totalVectorCount announced. Drop the zero-filled
                // tail so it does not pull the quantiles toward 0.
                values = Arrays.copyOf(values, copyOffset);
            }
            float[] lowerAndUpper = getUpperAndLowerQuantile(values, confidenceInterval);
            return new ScalarQuantizer(lowerAndUpper[0], lowerAndUpper[1], confidenceInterval);
        }
        int[] vectorsToTake = reservoirSampleIndices(totalVectorCount, quantizationSampleSize);
        float[] values = sampleVectors(floatVectorValues, vectorsToTake);
        float[] lowerAndUpper = getUpperAndLowerQuantile(values, confidenceInterval);
        return new ScalarQuantizer(lowerAndUpper[0], lowerAndUpper[1], confidenceInterval);
    }

    /**
     * Returns {@code {lower, upper}} quantiles of {@code arr}. About
     * {@code arr.length * (1 - confidenceInterval) / 2} values (rounded) are excluded from each end.
     *
     * <p>{@code arr} is partially reordered in place. An empty array yields
     * {@code {+Infinity, -Infinity}}.
     *
     * @param arr values to analyze; modified
     * @param confidenceInterval fraction of values to keep, in {@code [0.9, 1]} (checked only by an
     *     assertion)
     */
    static float[] getUpperAndLowerQuantile(float[] arr, float confidenceInterval) {
        assert 0.9f <= confidenceInterval && confidenceInterval <= 1.0f;
        int selectorIndex = (int) ((float) arr.length * (1.0f - confidenceInterval) / 2.0f + 0.5f);
        if (selectorIndex > 0) {
            Selector selector = new FloatSelector(arr);
            // Move the largest selectorIndex values to the end, then the smallest selectorIndex
            // values to the front of the remaining prefix.
            selector.select(0, arr.length, arr.length - selectorIndex);
            selector.select(0, arr.length - selectorIndex, selectorIndex);
        }
        float min = Float.POSITIVE_INFINITY;
        float max = Float.NEGATIVE_INFINITY;
        for (int i = selectorIndex; i < arr.length - selectorIndex; ++i) {
            min = Math.min(arr[i], min);
            max = Math.max(arr[i], max);
        }
        return new float[] {min, max};
    }

    /** {@link IntroSelector} over a float array, ordered by {@link Float#compare}. */
    private static class FloatSelector extends IntroSelector {
        float pivot = Float.NaN;
        private final float[] arr;

        private FloatSelector(float[] arr) {
            this.arr = arr;
        }

        @Override
        protected void setPivot(int i) {
            this.pivot = this.arr[i];
        }

        @Override
        protected int comparePivot(int j) {
            return Float.compare(this.pivot, this.arr[j]);
        }

        @Override
        protected void swap(int i, int j) {
            float tmp = this.arr[i];
            this.arr[i] = this.arr[j];
            this.arr[j] = tmp;
        }
    }
}

