/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package utils

import (
	"bytes"
	"encoding/binary"
	"math"
	"math/big"

	"github.com/icza/bitio"
	"github.com/pkg/errors"
)

// ReadBufferByteBased is a ReadBuffer backed by an in-memory byte slice. In
// addition to the ReadBuffer operations it exposes the raw data and allows
// peeking at bytes relative to the current position.
type ReadBufferByteBased interface {
	ReadBuffer
	// GetTotalBytes returns the size of the backing data in bytes.
	GetTotalBytes() uint64
	// GetBytes returns the backing data.
	GetBytes() []byte
	// PeekByte returns the byte offset bytes past the current byte position
	// without consuming it.
	PeekByte(offset uint8) uint8
}

// NewReadBufferByteBased creates a ReadBufferByteBased reading from data.
// The buffer starts at bit position 0 with big-endian byte order; the given
// options are then applied in order and may override those defaults.
// data is not copied.
func NewReadBufferByteBased(data []byte, options ...ReadBufferByteBasedOptions) ReadBufferByteBased {
	rb := &byteReadBuffer{
		data:      data,
		reader:    bitio.NewReader(bytes.NewBuffer(data)),
		byteOrder: binary.BigEndian,
	}
	for i := range options {
		options[i](rb)
	}
	return rb
}

// ReadBufferByteBasedOptions configures a buffer created by
// NewReadBufferByteBased; options run after the defaults have been set.
type ReadBufferByteBasedOptions = func(*byteReadBuffer)

// WithByteOrderForReadBufferByteBased returns an option that makes the buffer
// decode multi-byte values using byteOrder instead of the big-endian default.
func WithByteOrderForReadBufferByteBased(byteOrder binary.ByteOrder) ReadBufferByteBasedOptions {
	return func(b *byteReadBuffer) { b.SetByteOrder(byteOrder) }
}

///////////////////////////////////////
///////////////////////////////////////
//
// Internal section
//

// byteReadBuffer implements ReadBufferByteBased. It keeps the original input
// slice next to a bit-level reader over that same slice, which lets Reset
// rebuild the reader and lets GetBytes/PeekByte access the raw data.
type byteReadBuffer struct {
	// data is the complete input as passed to NewReadBufferByteBased.
	data      []byte
	// reader performs the actual bit-level reads over data.
	reader    *bitio.Reader
	// pos is the current read position in bits (GetPos reports pos/8).
	pos       uint64
	// byteOrder selects how multi-byte values are decoded; set to
	// binary.BigEndian by NewReadBufferByteBased unless an option overrides it.
	byteOrder binary.ByteOrder
}

//
// Internal section
//
///////////////////////////////////////
///////////////////////////////////////

// SetByteOrder changes the byte order used by subsequent multi-byte reads.
func (rb *byteReadBuffer) SetByteOrder(order binary.ByteOrder) {
	rb.byteOrder = order
}

// GetByteOrder returns the byte order currently used for multi-byte reads.
func (rb *byteReadBuffer) GetByteOrder() binary.ByteOrder {
	return rb.byteOrder
}

// GetPos returns the current position in whole bytes. Bits of a partially
// consumed byte are not counted, and the result is truncated to uint16.
func (rb *byteReadBuffer) GetPos() uint16 {
	return uint16(rb.pos >> 3)
}

// Reset rewinds the buffer and positions it pos bytes from the start of the
// backing data. The bit reader is rebuilt from scratch over rb.data and then
// pos bytes are read and discarded.
//
// It panics if discarding the leading bytes fails.
// NOTE(review): whether a pos beyond len(data) is reported as an error here
// depends on how the underlying reader handles short reads — confirm.
func (rb *byteReadBuffer) Reset(pos uint16) {
	rb.pos = 0
	rb.reader = bitio.NewReader(bytes.NewBuffer(rb.data))
	bytesToSkip := make([]byte, pos)
	if _, err := rb.reader.Read(bytesToSkip); err != nil {
		// TODO: pos is caller-supplied, so this can actually happen for
		// out-of-range positions; returning an error would be better.
		panic(errors.Wrap(err, "Should not happen"))
	}
	// Widen before multiplying: pos*8 evaluated in uint16 wraps around for
	// pos >= 8192 and would record a bogus bit position.
	rb.pos = uint64(pos) * 8
}

// GetBytes returns the complete backing data, including bytes already read.
// The slice is returned as-is, not copied, so it is shared with the buffer.
func (rb *byteReadBuffer) GetBytes() []byte {
	return rb.data
}

// GetTotalBytes returns the size of the backing data in bytes, regardless of
// how much of it has already been read.
func (rb *byteReadBuffer) GetTotalBytes() uint64 {
	size := len(rb.data)
	return uint64(size)
}

// HasMore reports whether at least bitLength more bits lie between the
// current position and the end of the backing data.
func (rb *byteReadBuffer) HasMore(bitLength uint8) bool {
	totalBits := uint64(len(rb.data)) * 8
	return rb.pos+uint64(bitLength) <= totalBits
}

// PeekByte returns the byte located offset bytes after the current byte
// position (see GetPos) without advancing the reader. It panics if that index
// falls outside the backing data.
func (rb *byteReadBuffer) PeekByte(offset uint8) uint8 {
	index := rb.GetPos() + uint16(offset)
	return rb.data[index]
}

// PullContext is a no-op for this buffer: the logical name and reader args
// are ignored and nil is always returned.
func (rb *byteReadBuffer) PullContext(string, ...WithReaderArgs) error {
	return nil
}

// ReadBit reads a single bit and reports whether it is set. The position
// advances by one bit even if the read fails.
func (rb *byteReadBuffer) ReadBit(_ string, _ ...WithReaderArgs) (bool, error) {
	bit, err := rb.reader.ReadBool()
	rb.pos++
	return bit, err
}

// ReadByte reads the next 8 bits as a byte. The position advances by 8 bits
// even if the read fails.
func (rb *byteReadBuffer) ReadByte(_ string, _ ...WithReaderArgs) (byte, error) {
	value, err := rb.reader.ReadByte()
	rb.pos += 8
	return value, err
}

// ReadByteArray reads numberOfBytes bytes one at a time. The position
// advances by 8 bits for every attempted byte, including a failing one. On
// error it returns nil and the reader's error unchanged. A negative
// numberOfBytes panics in make.
func (rb *byteReadBuffer) ReadByteArray(_ string, numberOfBytes int, _ ...WithReaderArgs) ([]byte, error) {
	result := make([]byte, numberOfBytes)
	for idx := range result {
		rb.pos += 8
		b, err := rb.reader.ReadByte()
		if err != nil {
			return nil, err
		}
		result[idx] = b
	}
	return result, nil
}

// ReadUint8 reads bitLength bits and returns them as a uint8; anything wider
// than 8 bits is truncated by the conversion. The position is advanced before
// the read, so it moves even if the read fails.
func (rb *byteReadBuffer) ReadUint8(_ string, bitLength uint8, _ ...WithReaderArgs) (uint8, error) {
	rb.pos += uint64(bitLength)
	raw, err := rb.reader.ReadBits(bitLength)
	if err == nil {
		return uint8(raw), nil
	}
	return 0, errors.Wrapf(err, "error reading %d bits", bitLength)
}

// ReadUint16 reads bitLength bits and returns them as a uint16 (wider values
// are truncated by the conversion).
//
// For any byte order other than binary.LittleEndian the bits come straight
// from the bit reader, with the position advanced before the read. In
// little-endian mode the read is delegated to ReadBigInt, which maintains the
// position itself.
func (rb *byteReadBuffer) ReadUint16(logicalName string, bitLength uint8, _ ...WithReaderArgs) (uint16, error) {
	if rb.byteOrder != binary.LittleEndian {
		rb.pos += uint64(bitLength)
		raw, err := rb.reader.ReadBits(bitLength)
		if err != nil {
			return 0, errors.Wrapf(err, "error reading %d bits", bitLength)
		}
		return uint16(raw), nil
	}
	// TODO: indirection till we have a native LE implementation
	value, err := rb.ReadBigInt(logicalName, uint64(bitLength))
	if err != nil {
		return 0, err
	}
	return uint16(value.Uint64()), nil
}

// ReadUint32 reads bitLength bits and returns them as a uint32 (wider values
// are truncated by the conversion).
//
// For any byte order other than binary.LittleEndian the bits come straight
// from the bit reader, with the position advanced before the read. In
// little-endian mode the read is delegated to ReadBigInt, which maintains the
// position itself.
func (rb *byteReadBuffer) ReadUint32(logicalName string, bitLength uint8, _ ...WithReaderArgs) (uint32, error) {
	if rb.byteOrder != binary.LittleEndian {
		rb.pos += uint64(bitLength)
		raw, err := rb.reader.ReadBits(bitLength)
		if err != nil {
			return 0, errors.Wrapf(err, "error reading %d bits", bitLength)
		}
		return uint32(raw), nil
	}
	// TODO: indirection till we have a native LE implementation
	value, err := rb.ReadBigInt(logicalName, uint64(bitLength))
	if err != nil {
		return 0, err
	}
	return uint32(value.Uint64()), nil
}

// ReadUint64 reads bitLength bits and returns them as a uint64.
//
// For any byte order other than binary.LittleEndian the bits come straight
// from the bit reader, with the position advanced before the read. In
// little-endian mode the read is delegated to ReadBigInt, which maintains the
// position itself.
func (rb *byteReadBuffer) ReadUint64(logicalName string, bitLength uint8, _ ...WithReaderArgs) (uint64, error) {
	if rb.byteOrder != binary.LittleEndian {
		rb.pos += uint64(bitLength)
		raw, err := rb.reader.ReadBits(bitLength)
		if err != nil {
			return 0, errors.Wrapf(err, "error reading %d bits", bitLength)
		}
		return raw, nil
	}
	// TODO: indirection till we have a native LE implementation
	value, err := rb.ReadBigInt(logicalName, uint64(bitLength))
	if err != nil {
		return 0, err
	}
	return value.Uint64(), nil
}

// ReadInt8 reads bitLength bits and converts them to an int8. The position is
// advanced before the read, so it moves even if the read fails.
//
// NOTE(review): no sign extension is performed. Assuming ReadBits returns at
// most bitLength significant bits, a field narrower than 8 bits never yields
// a negative value — confirm callers expect that.
func (rb *byteReadBuffer) ReadInt8(_ string, bitLength uint8, _ ...WithReaderArgs) (int8, error) {
	rb.pos += uint64(bitLength)
	raw, err := rb.reader.ReadBits(bitLength)
	if err == nil {
		return int8(raw), nil
	}
	return 0, errors.Wrapf(err, "error reading %d bits", bitLength)
}

// ReadInt16 reads bitLength bits and converts them to an int16.
//
// For any byte order other than binary.LittleEndian the bits come straight
// from the bit reader, with the position advanced before the read. In
// little-endian mode the read is delegated to ReadBigInt, and the result is
// truncated to 16 bits.
//
// NOTE(review): the non-little-endian path does not sign-extend fields
// narrower than 16 bits — confirm callers expect that.
func (rb *byteReadBuffer) ReadInt16(logicalName string, bitLength uint8, _ ...WithReaderArgs) (int16, error) {
	if rb.byteOrder != binary.LittleEndian {
		rb.pos += uint64(bitLength)
		raw, err := rb.reader.ReadBits(bitLength)
		if err != nil {
			return 0, errors.Wrapf(err, "error reading %d bits", bitLength)
		}
		return int16(raw), nil
	}
	// TODO: indirection till we have a native LE implementation
	value, err := rb.ReadBigInt(logicalName, uint64(bitLength))
	if err != nil {
		return 0, err
	}
	return int16(value.Int64()), nil
}

// ReadInt32 reads bitLength bits and converts them to an int32.
//
// For any byte order other than binary.LittleEndian the bits come straight
// from the bit reader, with the position advanced before the read. In
// little-endian mode the read is delegated to ReadBigInt, and the result is
// truncated to 32 bits.
//
// NOTE(review): the non-little-endian path does not sign-extend fields
// narrower than 32 bits. ReadFloat32 relies on this to read its 5-bit
// exponent as a non-negative value.
func (rb *byteReadBuffer) ReadInt32(logicalName string, bitLength uint8, _ ...WithReaderArgs) (int32, error) {
	if rb.byteOrder != binary.LittleEndian {
		rb.pos += uint64(bitLength)
		raw, err := rb.reader.ReadBits(bitLength)
		if err != nil {
			return 0, errors.Wrapf(err, "error reading %d bits", bitLength)
		}
		return int32(raw), nil
	}
	// TODO: indirection till we have a native LE implementation
	value, err := rb.ReadBigInt(logicalName, uint64(bitLength))
	if err != nil {
		return 0, err
	}
	return int32(value.Int64()), nil
}

// ReadInt64 reads bitLength bits and converts them to an int64.
//
// For any byte order other than binary.LittleEndian the bits come straight
// from the bit reader, with the position advanced before the read. In
// little-endian mode the read is delegated to ReadBigInt, which maintains the
// position itself.
func (rb *byteReadBuffer) ReadInt64(logicalName string, bitLength uint8, _ ...WithReaderArgs) (int64, error) {
	if rb.byteOrder == binary.LittleEndian {
		// TODO: indirection till we have a native LE implementation
		bigInt, err := rb.ReadBigInt(logicalName, uint64(bitLength))
		if err != nil {
			return 0, err
		}
		// ReadBigInt yields a non-negative value that can be >= 2^63 for a
		// 64-bit read, where big.Int.Int64 is documented as undefined.
		// Uint64 is well-defined below 2^64, and the conversion reinterprets
		// those bits as two's complement.
		return int64(bigInt.Uint64()), nil
	}
	rb.pos += uint64(bitLength)
	res, err := rb.reader.ReadBits(bitLength)
	if err != nil {
		return 0, errors.Wrapf(err, "error reading %d bits", bitLength)
	}
	return int64(res), nil
}

// ReadBigInt reads bitLength bits and returns them as a non-negative big.Int.
//
// bitio can read at most 64 bits per call, so the bits are consumed in chunks
// of up to 64. Each chunk is left-aligned into 8 big-endian bytes. Only the
// last chunk can be shorter than 64 bits; its padding ("correction") is
// shifted out after all chunks have been concatenated. In little-endian mode
// the value is then left-padded back to its original byte width and its bytes
// are reversed.
//
// The position advances by the number of bits read in each successful chunk.
//
// TODO: highly experimental, remove this comment when tested or verified.
func (rb *byteReadBuffer) ReadBigInt(_ string, bitLength uint64, _ ...WithReaderArgs) (*big.Int, error) {
	res := big.NewInt(0)

	// TODO: maybe we can use left shift and or of big int
	// Deliberately not pre-sized: bitLength may come from untrusted data, and
	// an oversized request should fail on EOF rather than on allocation.
	var rawBytes []byte
	correction := uint8(0)

	for remainingBits := bitLength; remainingBits > 0; {
		// we can max read 64 bit with bitio
		bitToRead := uint8(64)
		if remainingBits < 64 {
			bitToRead = uint8(remainingBits)
		}
		data, err := rb.reader.ReadBits(bitToRead)
		if err != nil {
			return nil, errors.Wrapf(err, "error reading %d bits", bitLength)
		}
		// Advance by this chunk only. Adding the full bitLength here
		// over-counted the position for every chunk of a read longer than
		// 64 bits.
		rb.pos += uint64(bitToRead)

		// Left-align the chunk and remember the padding for the final shift.
		correction = 64 - bitToRead
		data <<= correction

		var chunk [8]byte
		binary.BigEndian.PutUint64(chunk[:], data)
		rawBytes = append(rawBytes, chunk[:]...)

		remainingBits -= uint64(bitToRead)
	}

	res.SetBytes(rawBytes)
	// Drop the padding that was added to the last chunk.
	res.Rsh(res, uint(correction))

	if rb.byteOrder == binary.LittleEndian {
		// big.Int.Bytes omits leading zero bytes; restore them so the
		// reversal operates on the full original byte width.
		originalByteLength := len(rawBytes) - int(correction/8)
		resBytes := res.Bytes()
		padding := make([]byte, originalByteLength-len(resBytes))
		resBytes = append(padding, resBytes...)
		for i, j := 0, len(resBytes)-1; i < j; i, j = i+1, j-1 {
			resBytes[i], resBytes[j] = resBytes[j], resBytes[i]
		}
		res.SetBytes(resBytes)
	}

	return res, nil
}

// ReadFloat32 reads a floating point value whose encoding depends on bitLength:
//
//   - bitLength == 32: an IEEE 754 single-precision value. In little-endian
//     mode the four bytes are swapped after reading.
//   - bitLength < 32: a 16-bit encoding made of 1 sign bit, a 5-bit exponent
//     (via ReadInt32) and a 10-bit mantissa (via ReadUint32). The value is
//     0.01 * mantissa * 2^exponent, negated when the sign bit is set. Exactly
//     16 bits are consumed whichever bitLength below 32 was requested.
//   - bitLength > 32: an error.
func (rb *byteReadBuffer) ReadFloat32(logicalName string, bitLength uint8, _ ...WithReaderArgs) (float32, error) {
	switch {
	case bitLength == 32:
		rb.pos += uint64(bitLength)
		rawBits, err := rb.reader.ReadBits(bitLength)
		if err != nil {
			return 0, errors.Wrapf(err, "error reading %d bits", bitLength)
		}
		uintValue := uint32(rawBits)
		if rb.byteOrder == binary.LittleEndian {
			array := make([]byte, 4)
			binary.LittleEndian.PutUint32(array, uintValue)
			uintValue = binary.BigEndian.Uint32(array)
		}
		return math.Float32frombits(uintValue), nil
	case bitLength < 32:
		// TODO: Note ... this is the format as described in the KNX specification
		sign, err := rb.ReadBit(logicalName)
		if err != nil {
			return 0.0, errors.Wrap(err, "error reading sign")
		}
		exp, err := rb.ReadInt32(logicalName, 5)
		if err != nil {
			return 0.0, errors.Wrap(err, "error reading exponent")
		}
		// This error used to be ignored, which silently produced a bogus value
		// from a truncated buffer.
		mantissa, err := rb.ReadUint32(logicalName, 10)
		if err != nil {
			return 0.0, errors.Wrap(err, "error reading mantissa")
		}
		// NOTE(review): an earlier comment here said the first mantissa bit is
		// omitted and must be added back, but no such bit is added — confirm
		// the decoding against the KNX specification.
		f := (0.01 * float64(mantissa)) * math.Pow(float64(2), float64(exp))
		if sign {
			return -float32(f), nil
		}
		return float32(f), nil
	default:
		return 0.0, errors.New("too many bits for float32")
	}
}

// ReadFloat64 reads bitLength bits and reinterprets them as an IEEE 754
// double. In little-endian mode the eight bytes are swapped after reading.
// The position is advanced before the read, so it moves even if the read fails.
func (rb *byteReadBuffer) ReadFloat64(_ string, bitLength uint8, _ ...WithReaderArgs) (float64, error) {
	rb.pos += uint64(bitLength)
	bitsRead, err := rb.reader.ReadBits(bitLength)
	if err != nil {
		return 0, errors.Wrapf(err, "error reading %d bits", bitLength)
	}
	if rb.byteOrder == binary.LittleEndian {
		var scratch [8]byte
		binary.LittleEndian.PutUint64(scratch[:], bitsRead)
		bitsRead = binary.BigEndian.Uint64(scratch[:])
	}
	return math.Float64frombits(bitsRead), nil
}

// ReadBigFloat reads a float64 of bitLength bits (see ReadFloat64) and wraps
// it in a big.Float.
//
// big.NewFloat panics with ErrNaN when given NaN, and the decoded bits may
// encode NaN. That case is reported as an error instead of crashing the caller.
func (rb *byteReadBuffer) ReadBigFloat(logicalName string, bitLength uint8, _ ...WithReaderArgs) (*big.Float, error) {
	readFloat64, err := rb.ReadFloat64(logicalName, bitLength)
	if err != nil {
		return nil, errors.Wrap(err, "Error reading float64")
	}
	if math.IsNaN(readFloat64) {
		return nil, errors.New("cannot represent NaN as big.Float")
	}
	return big.NewFloat(readFloat64), nil
}

// ReadString reads bitLength/8 bytes and returns them as a string, cut at the
// first 0x00 byte if one is present. Any remaining bits (when bitLength is
// not a multiple of 8) are not read. The encoding argument is currently
// ignored; the bytes are converted to a Go string as-is.
func (rb *byteReadBuffer) ReadString(logicalName string, bitLength uint32, encoding string, _ ...WithReaderArgs) (string, error) {
	stringBytes, err := rb.ReadByteArray(logicalName, int(bitLength/8))
	if err != nil {
		return "", errors.Wrap(err, "error reading string bytes")
	}
	// TODO: make the null-termination a reader arg
	// End the string at the first 0-character.
	if end := bytes.IndexByte(stringBytes, 0x00); end >= 0 {
		stringBytes = stringBytes[:end]
	}
	return string(stringBytes), nil
}

// CloseContext is a no-op for this buffer: the logical name and reader args
// are ignored and nil is always returned.
func (rb *byteReadBuffer) CloseContext(string, ...WithReaderArgs) error {
	return nil
}
