/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.quantization.models.quantizationState;

import java.io.IOException;
import java.util.Arrays;
import lombok.Generated;
import lombok.NonNull;
import org.apache.lucene.util.RamUsageEstimator;
import org.opensearch.Version;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateSerializer;

/**
 * Quantization state for multi-bit scalar quantization.
 *
 * <p>Holds the {@link ScalarQuantizationParams} used for quantization, a {@code thresholds} matrix and an
 * optional {@code rotationMatrix}. {@link #getDimensions()} and {@link #getBytesPerVector()} use the row count
 * of {@code thresholds} as the number of bits per dimension and the length of its first row as the original
 * vector dimension.
 *
 * <p>Serialized layout written by {@link #writeTo(StreamOutput)}:
 * <ol>
 *   <li>writer version id (VInt)</li>
 *   <li>quantization params</li>
 *   <li>thresholds: row count (VInt) followed by each row as a float array</li>
 *   <li>only when the writer version is 3.2.0 or later: a boolean presence flag, then the rotation matrix in the
 *       same row-count-plus-rows layout if the flag is {@code true}</li>
 * </ol>
 *
 * <p>Members annotated {@code @Generated} mirror Lombok-generated accessors, constructors and builder code.
 */
public final class MultiBitScalarQuantizationState implements QuantizationState {
    @NonNull
    private ScalarQuantizationParams quantizationParams;
    @NonNull
    private float[][] thresholds;
    /** Optional rotation matrix; {@code null} when absent. */
    private float[][] rotationMatrix;

    @Override
    public ScalarQuantizationParams getQuantizationParams() {
        return this.quantizationParams;
    }

    /**
     * Serializes this state to {@code out} using the layout described on the class.
     *
     * @param out destination stream
     * @throws IOException if writing to the stream fails
     */
    public void writeTo(StreamOutput out) throws IOException {
        out.writeVInt(Version.CURRENT.id);
        this.quantizationParams.writeTo(out);
        writeFloatMatrix(out, this.thresholds);
        // The rotation section only exists in streams written by 3.2.0+; the reading constructor
        // keys off the version id written above to decide whether to expect it.
        if (Version.CURRENT.onOrAfter(Version.V_3_2_0)) {
            boolean hasRotation = this.rotationMatrix != null;
            out.writeBoolean(hasRotation);
            if (hasRotation) {
                writeFloatMatrix(out, this.rotationMatrix);
            }
        }
    }

    /**
     * Deserializes a state previously produced by {@link #writeTo(StreamOutput)}.
     *
     * @param in source stream
     * @throws IOException if reading fails or the stream encodes an invalid (negative) matrix row count
     */
    public MultiBitScalarQuantizationState(StreamInput in) throws IOException {
        int version = in.readVInt();
        this.quantizationParams = new ScalarQuantizationParams(in, version);
        this.thresholds = readFloatMatrix(in);
        // Older writers never emitted the rotation flag, so only read it for 3.2.0+ streams.
        if (Version.fromId(version).onOrAfter(Version.V_3_2_0) && in.readBoolean()) {
            this.rotationMatrix = readFloatMatrix(in);
        }
    }

    /** Writes {@code matrix} as a VInt row count followed by each row via {@code writeFloatArray}. */
    private static void writeFloatMatrix(StreamOutput out, float[][] matrix) throws IOException {
        out.writeVInt(matrix.length);
        for (float[] row : matrix) {
            out.writeFloatArray(row);
        }
    }

    /** Reads a matrix in the layout produced by {@link #writeFloatMatrix}, rejecting negative row counts. */
    private static float[][] readFloatMatrix(StreamInput in) throws IOException {
        int rows = in.readVInt();
        if (rows < 0) {
            // A negative VInt here means the bytes are corrupt; report it as an I/O failure rather than
            // letting array allocation throw NegativeArraySizeException.
            throw new IOException("Invalid matrix row count in serialized quantization state: " + rows);
        }
        float[][] matrix = new float[rows][];
        for (int i = 0; i < rows; ++i) {
            matrix[i] = in.readFloatArray();
        }
        return matrix;
    }

    @Override
    public byte[] toByteArray() throws IOException {
        return QuantizationStateSerializer.serialize(this);
    }

    /**
     * Reconstructs a state from bytes produced by {@link #toByteArray()}.
     *
     * @param bytes serialized state
     * @return the deserialized state
     * @throws IOException if deserialization fails
     */
    public static MultiBitScalarQuantizationState fromByteArray(byte[] bytes) throws IOException {
        return (MultiBitScalarQuantizationState) QuantizationStateSerializer.deserialize(bytes, MultiBitScalarQuantizationState::new);
    }

    /**
     * Returns the number of bytes needed to store one quantized vector, i.e. the total bit count
     * ({@code bitsPerDimension * originalDimensions}) rounded up to whole bytes.
     *
     * @throws IllegalStateException if the thresholds matrix is null, empty, or has a null first row
     */
    @Override
    public int getBytesPerVector() {
        requireThresholds("Error in getBytesStoredPerVector: The thresholds array is not initialized.");
        int totalBits = this.thresholds.length * this.thresholds[0].length;
        return (totalBits + 7) / 8;
    }

    /**
     * Returns the quantized (binary) dimension count: {@code bitsPerDimension * originalDimensions}
     * rounded up to the next multiple of 8.
     *
     * @throws IllegalStateException if the thresholds matrix is null, empty, or has a null first row
     */
    @Override
    public int getDimensions() {
        requireThresholds("Error in getting Dimension: The thresholds array is not initialized.");
        int originalDimensions = this.thresholds[0].length;
        int bitsPerDimension = this.thresholds.length;
        int totalBinaryDimensions = originalDimensions * bitsPerDimension;
        // Round up to a multiple of 8 (~7 clears the low three bits).
        return (totalBinaryDimensions + 7) & ~7;
    }

    /** Throws {@link IllegalStateException} with {@code message} unless thresholds has a usable first row. */
    private void requireThresholds(String message) {
        if (this.thresholds == null || this.thresholds.length == 0 || this.thresholds[0] == null) {
            throw new IllegalStateException(message);
        }
    }

    /**
     * Estimates the heap footprint of this state: the instance itself, a shallow size of the params object,
     * and the threshold and (if present) rotation matrices. A null matrix, e.g. after the no-arg constructor,
     * contributes zero instead of failing.
     */
    @Override
    public long ramBytesUsed() {
        long size = RamUsageEstimator.shallowSizeOfInstance(MultiBitScalarQuantizationState.class);
        size += RamUsageEstimator.shallowSizeOf((Object) this.quantizationParams);
        size += sizeOfMatrix(this.thresholds);
        size += sizeOfMatrix(this.rotationMatrix);
        return size;
    }

    /** Returns the outer array's shallow size plus every row's size, or 0 for a null matrix. */
    private static long sizeOfMatrix(float[][] matrix) {
        if (matrix == null) {
            return 0L;
        }
        long size = RamUsageEstimator.shallowSizeOf((Object[]) matrix);
        for (float[] row : matrix) {
            size += RamUsageEstimator.sizeOf(row);
        }
        return size;
    }

    @Generated
    private static float[][] $default$rotationMatrix() {
        return null;
    }

    /** Returns a new builder; {@code rotationMatrix} defaults to {@code null} when not set. */
    @Generated
    public static MultiBitScalarQuantizationStateBuilder builder() {
        return new MultiBitScalarQuantizationStateBuilder();
    }

    /** Returns the internal thresholds matrix (not copied). */
    @NonNull
    @Generated
    public float[][] getThresholds() {
        return this.thresholds;
    }

    /** Returns the internal rotation matrix (not copied), or {@code null} if none is set. */
    @Generated
    public float[][] getRotationMatrix() {
        return this.rotationMatrix;
    }

    /**
     * Creates a state from its components.
     *
     * @param quantizationParams quantization parameters; must not be null
     * @param thresholds thresholds matrix; must not be null (stored without copying)
     * @param rotationMatrix optional rotation matrix; may be null (stored without copying)
     * @throws NullPointerException if {@code quantizationParams} or {@code thresholds} is null
     */
    @Generated
    public MultiBitScalarQuantizationState(@NonNull ScalarQuantizationParams quantizationParams, @NonNull float[][] thresholds, float[][] rotationMatrix) {
        if (quantizationParams == null) {
            throw new NullPointerException("quantizationParams is marked non-null but is null");
        }
        if (thresholds == null) {
            throw new NullPointerException("thresholds is marked non-null but is null");
        }
        this.quantizationParams = quantizationParams;
        this.thresholds = thresholds;
        this.rotationMatrix = rotationMatrix;
    }

    /**
     * No-arg constructor. Leaves {@code quantizationParams} and {@code thresholds} null despite their
     * {@code @NonNull} markers, and sets {@code rotationMatrix} to its default ({@code null}).
     */
    @Generated
    public MultiBitScalarQuantizationState() {
        this.rotationMatrix = MultiBitScalarQuantizationState.$default$rotationMatrix();
    }

    /** Builder for {@link MultiBitScalarQuantizationState}. */
    @Generated
    public static class MultiBitScalarQuantizationStateBuilder {
        @Generated
        private ScalarQuantizationParams quantizationParams;
        @Generated
        private float[][] thresholds;
        /** Tracks whether {@link #rotationMatrix(float[][])} was called, so the default applies otherwise. */
        @Generated
        private boolean rotationMatrix$set;
        @Generated
        private float[][] rotationMatrix$value;

        @Generated
        MultiBitScalarQuantizationStateBuilder() {
        }

        @Generated
        public MultiBitScalarQuantizationStateBuilder quantizationParams(@NonNull ScalarQuantizationParams quantizationParams) {
            if (quantizationParams == null) {
                throw new NullPointerException("quantizationParams is marked non-null but is null");
            }
            this.quantizationParams = quantizationParams;
            return this;
        }

        @Generated
        public MultiBitScalarQuantizationStateBuilder thresholds(@NonNull float[][] thresholds) {
            if (thresholds == null) {
                throw new NullPointerException("thresholds is marked non-null but is null");
            }
            this.thresholds = thresholds;
            return this;
        }

        @Generated
        public MultiBitScalarQuantizationStateBuilder rotationMatrix(float[][] rotationMatrix) {
            this.rotationMatrix$value = rotationMatrix;
            this.rotationMatrix$set = true;
            return this;
        }

        /**
         * Builds the state. Throws {@link NullPointerException} (from the constructor) if
         * {@code quantizationParams} or {@code thresholds} was never set.
         */
        @Generated
        public MultiBitScalarQuantizationState build() {
            float[][] rotationMatrix$value = this.rotationMatrix$value;
            if (!this.rotationMatrix$set) {
                rotationMatrix$value = MultiBitScalarQuantizationState.$default$rotationMatrix();
            }
            return new MultiBitScalarQuantizationState(this.quantizationParams, this.thresholds, rotationMatrix$value);
        }

        @Generated
        public String toString() {
            return "MultiBitScalarQuantizationState.MultiBitScalarQuantizationStateBuilder(quantizationParams=" + String.valueOf(this.quantizationParams) + ", thresholds=" + Arrays.deepToString((Object[])this.thresholds) + ", rotationMatrix$value=" + Arrays.deepToString((Object[])this.rotationMatrix$value) + ")";
        }
    }
}

