// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Contains functions and function factories to compare arrays.

use arrow_array::cast::AsArray;
use arrow_array::types::*;
use arrow_array::*;
use arrow_buffer::ArrowNativeType;
use arrow_schema::{ArrowError, DataType};
use std::cmp::Ordering;

/// Compare the values at two arbitrary indices in two arrays.
///
/// The first argument is an index into the left array and the second an index
/// into the right array given to [`build_compare`]; the result is the
/// [`Ordering`] of the left value relative to the right value. The closure is
/// `Send + Sync`, so a single comparator can be shared across threads.
pub type DynComparator = Box<dyn Fn(usize, usize) -> Ordering + Send + Sync>;

/// Builds a comparator over two arrays holding primitive values of type `T`.
///
/// Both inputs are converted into owned [`PrimitiveArray<T>`] handles so the
/// returned closure does not borrow from `left` or `right`. Values are ordered
/// with [`ArrowNativeTypeOp::compare`].
fn compare_primitives<T: ArrowPrimitiveType>(
    left: &dyn Array,
    right: &dyn Array,
) -> DynComparator
where
    T::Native: ArrowNativeTypeOp,
{
    let (lhs, rhs) = (
        PrimitiveArray::<T>::from(left.to_data()),
        PrimitiveArray::<T>::from(right.to_data()),
    );

    Box::new(move |i, j| {
        let a = lhs.value(i);
        let b = rhs.value(j);
        a.compare(b)
    })
}

/// Builds a comparator over two [`BooleanArray`]s using Rust's `bool`
/// ordering (`false` sorts before `true`).
fn compare_boolean(left: &dyn Array, right: &dyn Array) -> DynComparator {
    let lhs = BooleanArray::from(left.to_data());
    let rhs = BooleanArray::from(right.to_data());

    Box::new(move |i, j| Ord::cmp(&lhs.value(i), &rhs.value(j)))
}

fn compare_string(left: &dyn Array, right: &dyn Array) -> DynComparator {
    let left: StringArray = StringArray::from(left.to_data());
    let right: StringArray = StringArray::from(right.to_data());

    Box::new(move |i, j| left.value(i).cmp(right.value(j)))
}

/// Builds a comparator over two dictionary arrays with key type `K` whose
/// dictionary values are primitives of type `V`.
///
/// Each index is first mapped through its own array's keys, and the
/// referenced dictionary values are then compared with
/// [`ArrowNativeTypeOp::compare`]. The two arrays may therefore use different
/// dictionaries. Key validity is not consulted.
fn compare_dict_primitive<K, V>(left: &dyn Array, right: &dyn Array) -> DynComparator
where
    K: ArrowDictionaryKeyType,
    V: ArrowPrimitiveType,
    V::Native: ArrowNativeTypeOp,
{
    let lhs_dict = left.as_dictionary::<K>();
    let rhs_dict = right.as_dictionary::<K>();

    let lhs_keys = PrimitiveArray::<K>::from(lhs_dict.keys().to_data());
    let rhs_keys = PrimitiveArray::<K>::from(rhs_dict.keys().to_data());
    let lhs_values = PrimitiveArray::<V>::from(lhs_dict.values().to_data());
    let rhs_values = PrimitiveArray::<V>::from(rhs_dict.values().to_data());

    Box::new(move |i: usize, j: usize| {
        let (li, ri) = (lhs_keys.value(i).as_usize(), rhs_keys.value(j).as_usize());
        let (a, b) = (lhs_values.value(li), rhs_values.value(ri));
        a.compare(b)
    })
}

/// Builds a comparator over two dictionary arrays with key type `K` and
/// `Utf8` dictionary values.
///
/// Each index is first mapped through its own array's keys, and the
/// referenced strings are then compared with `str`'s `Ord`. The two arrays may
/// therefore use different dictionaries. Key validity is not consulted.
fn compare_dict_string<K>(left: &dyn Array, right: &dyn Array) -> DynComparator
where
    K: ArrowDictionaryKeyType,
{
    let lhs_dict = left.as_dictionary::<K>();
    let rhs_dict = right.as_dictionary::<K>();

    let lhs_keys = PrimitiveArray::<K>::from(lhs_dict.keys().to_data());
    let rhs_keys = PrimitiveArray::<K>::from(rhs_dict.keys().to_data());
    let lhs_values = StringArray::from(lhs_dict.values().to_data());
    let rhs_values = StringArray::from(rhs_dict.values().to_data());

    Box::new(move |i: usize, j: usize| {
        let (li, ri) = (lhs_keys.value(i).as_usize(), rhs_keys.value(j).as_usize());
        lhs_values.value(li).cmp(rhs_values.value(ri))
    })
}

fn cmp_dict_primitive<VT>(
    key_type: &DataType,
    left: &dyn Array,
    right: &dyn Array,
) -> Result<DynComparator, ArrowError>
where
    VT: ArrowPrimitiveType,
    VT::Native: ArrowNativeTypeOp,
{
    use DataType::*;

    Ok(match key_type {
        UInt8 => compare_dict_primitive::<UInt8Type, VT>(left, right),
        UInt16 => compare_dict_primitive::<UInt16Type, VT>(left, right),
        UInt32 => compare_dict_primitive::<UInt32Type, VT>(left, right),
        UInt64 => compare_dict_primitive::<UInt64Type, VT>(left, right),
        Int8 => compare_dict_primitive::<Int8Type, VT>(left, right),
        Int16 => compare_dict_primitive::<Int16Type, VT>(left, right),
        Int32 => compare_dict_primitive::<Int32Type, VT>(left, right),
        Int64 => compare_dict_primitive::<Int64Type, VT>(left, right),
        t => {
            return Err(ArrowError::InvalidArgumentError(format!(
                "Dictionaries do not support keys of type {t:?}"
            )));
        }
    })
}

// Callback for `downcast_primitive!` in `build_compare`: once the dictionary's
// value type has been resolved to the primitive Arrow type `$t`, dispatch on
// the key type through `cmp_dict_primitive::<$t>` and propagate any error with
// `?`. Because of the `?`, it must be expanded inside a function returning
// `Result<_, ArrowError>`.
macro_rules! cmp_dict_primitive_helper {
    ($t:ty, $key_type_lhs:expr, $left:expr, $right:expr) => {
        cmp_dict_primitive::<$t>($key_type_lhs, $left, $right)?
    };
}

/// Returns a comparison function that compares the value at index `i` of
/// `left` with the value at index `j` of `right`.
///
/// The two arrays must have exactly the same [`DataType`], including any type
/// parameters (e.g. timestamp timezone, decimal precision/scale, dictionary key
/// and value types).
///
/// The comparator does not inspect validity: a null slot is compared using
/// whatever value is physically stored there (for dictionaries, whatever key is
/// stored). Callers that need null-aware ordering must check validity first.
///
/// # Errors
///
/// Returns [`ArrowError::InvalidArgumentError`] if the data types differ, if a
/// dictionary uses an unsupported key or value type, or if the data type has no
/// supported natural order.
///
/// # Example
/// ```
/// use arrow_array::Int32Array;
/// use arrow_ord::ord::build_compare;
///
/// let array1 = Int32Array::from(vec![1, 2]);
/// let array2 = Int32Array::from(vec![3, 4]);
///
/// let cmp = build_compare(&array1, &array2).unwrap();
///
/// // 1 (index 0 of array1) is smaller than 4 (index 1 of array2)
/// assert_eq!(std::cmp::Ordering::Less, (cmp)(0, 1));
/// ```
// This is a factory of comparisons. `DynComparator` carries no lifetime, so
// every returned closure owns its own handles to the array data (obtained via
// `to_data()`) and may outlive `left` and `right`.
pub fn build_compare(
    left: &dyn Array,
    right: &dyn Array,
) -> Result<DynComparator, ArrowError> {
    use arrow_schema::{DataType::*, IntervalUnit::*, TimeUnit::*};
    Ok(match (left.data_type(), right.data_type()) {
        (a, b) if a != b => {
            return Err(ArrowError::InvalidArgumentError(
                "Can't compare arrays of different types".to_string(),
            ));
        }
        (Boolean, Boolean) => compare_boolean(left, right),
        (UInt8, UInt8) => compare_primitives::<UInt8Type>(left, right),
        (UInt16, UInt16) => compare_primitives::<UInt16Type>(left, right),
        (UInt32, UInt32) => compare_primitives::<UInt32Type>(left, right),
        (UInt64, UInt64) => compare_primitives::<UInt64Type>(left, right),
        (Int8, Int8) => compare_primitives::<Int8Type>(left, right),
        (Int16, Int16) => compare_primitives::<Int16Type>(left, right),
        (Int32, Int32) => compare_primitives::<Int32Type>(left, right),
        (Int64, Int64) => compare_primitives::<Int64Type>(left, right),
        (Float16, Float16) => compare_primitives::<Float16Type>(left, right),
        (Float32, Float32) => compare_primitives::<Float32Type>(left, right),
        (Float64, Float64) => compare_primitives::<Float64Type>(left, right),
        (Decimal128(_, _), Decimal128(_, _)) => {
            compare_primitives::<Decimal128Type>(left, right)
        }
        (Decimal256(_, _), Decimal256(_, _)) => {
            compare_primitives::<Decimal256Type>(left, right)
        }
        (Date32, Date32) => compare_primitives::<Date32Type>(left, right),
        (Date64, Date64) => compare_primitives::<Date64Type>(left, right),
        (Time32(Second), Time32(Second)) => {
            compare_primitives::<Time32SecondType>(left, right)
        }
        (Time32(Millisecond), Time32(Millisecond)) => {
            compare_primitives::<Time32MillisecondType>(left, right)
        }
        (Time64(Microsecond), Time64(Microsecond)) => {
            compare_primitives::<Time64MicrosecondType>(left, right)
        }
        (Time64(Nanosecond), Time64(Nanosecond)) => {
            compare_primitives::<Time64NanosecondType>(left, right)
        }
        (Timestamp(Second, _), Timestamp(Second, _)) => {
            compare_primitives::<TimestampSecondType>(left, right)
        }
        (Timestamp(Millisecond, _), Timestamp(Millisecond, _)) => {
            compare_primitives::<TimestampMillisecondType>(left, right)
        }
        (Timestamp(Microsecond, _), Timestamp(Microsecond, _)) => {
            compare_primitives::<TimestampMicrosecondType>(left, right)
        }
        (Timestamp(Nanosecond, _), Timestamp(Nanosecond, _)) => {
            compare_primitives::<TimestampNanosecondType>(left, right)
        }
        (Interval(YearMonth), Interval(YearMonth)) => {
            compare_primitives::<IntervalYearMonthType>(left, right)
        }
        (Interval(DayTime), Interval(DayTime)) => {
            compare_primitives::<IntervalDayTimeType>(left, right)
        }
        (Interval(MonthDayNano), Interval(MonthDayNano)) => {
            compare_primitives::<IntervalMonthDayNanoType>(left, right)
        }
        (Duration(Second), Duration(Second)) => {
            compare_primitives::<DurationSecondType>(left, right)
        }
        (Duration(Millisecond), Duration(Millisecond)) => {
            compare_primitives::<DurationMillisecondType>(left, right)
        }
        (Duration(Microsecond), Duration(Microsecond)) => {
            compare_primitives::<DurationMicrosecondType>(left, right)
        }
        (Duration(Nanosecond), Duration(Nanosecond)) => {
            compare_primitives::<DurationNanosecondType>(left, right)
        }
        (Utf8, Utf8) => compare_string(left, right),
        (LargeUtf8, LargeUtf8) => compare_string(left, right),
        (Dictionary(key_type, value_type), Dictionary(_, _)) => {
            // The `a != b` guard above already rejected unequal data types, and
            // `DataType` equality covers the dictionary's key and value types,
            // so both sides are known to share them here.
            let key_type = key_type.as_ref();
            downcast_primitive! {
                value_type.as_ref() => (cmp_dict_primitive_helper, key_type, left, right),
                Utf8 => match key_type {
                    UInt8 => compare_dict_string::<UInt8Type>(left, right),
                    UInt16 => compare_dict_string::<UInt16Type>(left, right),
                    UInt32 => compare_dict_string::<UInt32Type>(left, right),
                    UInt64 => compare_dict_string::<UInt64Type>(left, right),
                    Int8 => compare_dict_string::<Int8Type>(left, right),
                    Int16 => compare_dict_string::<Int16Type>(left, right),
                    Int32 => compare_dict_string::<Int32Type>(left, right),
                    Int64 => compare_dict_string::<Int64Type>(left, right),
                    lhs => {
                        return Err(ArrowError::InvalidArgumentError(format!(
                            "Dictionaries do not support keys of type {lhs:?}"
                        )));
                    }
                },
                t => {
                    return Err(ArrowError::InvalidArgumentError(format!(
                        "Dictionaries of value data type {t:?} are not supported"
                    )));
                }
            }
        }
        (FixedSizeBinary(_), FixedSizeBinary(_)) => {
            let left: FixedSizeBinaryArray = left.to_data().into();
            let right: FixedSizeBinaryArray = right.to_data().into();

            Box::new(move |i, j| left.value(i).cmp(right.value(j)))
        }
        (lhs, _) => {
            return Err(ArrowError::InvalidArgumentError(format!(
                "The data type {lhs:?} has no natural order"
            )));
        }
    })
}

// Unit tests for `build_compare`, covering primitive, float, decimal,
// fixed-size binary and dictionary-encoded inputs.
#[cfg(test)]
pub mod tests {
    use super::*;
    use arrow_array::{FixedSizeBinaryArray, Float64Array, Int32Array};
    use arrow_buffer::i256;
    use half::f16;
    use std::cmp::Ordering;
    use std::sync::Arc;

    #[test]
    fn test_fixed_size_binary() {
        let items = vec![vec![1u8], vec![2u8]];
        let array = FixedSizeBinaryArray::try_from_iter(items.into_iter()).unwrap();

        let cmp = build_compare(&array, &array).unwrap();

        assert_eq!(Ordering::Less, (cmp)(0, 1));
    }

    #[test]
    fn test_fixed_size_binary_fixed_size_binary() {
        let items = vec![vec![1u8]];
        let array1 = FixedSizeBinaryArray::try_from_iter(items.into_iter()).unwrap();
        let items = vec![vec![2u8]];
        let array2 = FixedSizeBinaryArray::try_from_iter(items.into_iter()).unwrap();

        let cmp = build_compare(&array1, &array2).unwrap();

        assert_eq!(Ordering::Less, (cmp)(0, 0));
    }

    #[test]
    fn test_i32() {
        let array = Int32Array::from(vec![1, 2]);

        let cmp = build_compare(&array, &array).unwrap();

        assert_eq!(Ordering::Less, (cmp)(0, 1));
    }

    #[test]
    fn test_i32_i32() {
        let array1 = Int32Array::from(vec![1]);
        let array2 = Int32Array::from(vec![2]);

        let cmp = build_compare(&array1, &array2).unwrap();

        assert_eq!(Ordering::Less, (cmp)(0, 0));
    }

    #[test]
    fn test_f16() {
        let array = Float16Array::from(vec![f16::from_f32(1.0), f16::from_f32(2.0)]);

        let cmp = build_compare(&array, &array).unwrap();

        assert_eq!(Ordering::Less, (cmp)(0, 1));
    }

    #[test]
    fn test_f64() {
        let array = Float64Array::from(vec![1.0, 2.0]);

        let cmp = build_compare(&array, &array).unwrap();

        assert_eq!(Ordering::Less, (cmp)(0, 1));
    }

    // NaN is ordered after a finite value and compares Equal to itself.
    #[test]
    fn test_f64_nan() {
        let array = Float64Array::from(vec![1.0, f64::NAN]);

        let cmp = build_compare(&array, &array).unwrap();

        assert_eq!(Ordering::Less, (cmp)(0, 1));
        assert_eq!(Ordering::Equal, (cmp)(1, 1));
    }

    // Signed zeros are distinguished: -0.0 is ordered strictly before +0.0.
    #[test]
    fn test_f64_zeros() {
        let array = Float64Array::from(vec![-0.0, 0.0]);

        let cmp = build_compare(&array, &array).unwrap();

        assert_eq!(Ordering::Less, (cmp)(0, 1));
        assert_eq!(Ordering::Greater, (cmp)(1, 0));
    }

    #[test]
    fn test_decimal() {
        let array = vec![Some(5_i128), Some(2_i128), Some(3_i128)]
            .into_iter()
            .collect::<Decimal128Array>()
            .with_precision_and_scale(23, 6)
            .unwrap();

        let cmp = build_compare(&array, &array).unwrap();
        assert_eq!(Ordering::Less, (cmp)(1, 0));
        assert_eq!(Ordering::Greater, (cmp)(0, 2));
    }

    #[test]
    fn test_decimali256() {
        let array = vec![
            Some(i256::from_i128(5_i128)),
            Some(i256::from_i128(2_i128)),
            Some(i256::from_i128(3_i128)),
        ]
        .into_iter()
        .collect::<Decimal256Array>()
        .with_precision_and_scale(53, 6)
        .unwrap();

        let cmp = build_compare(&array, &array).unwrap();
        assert_eq!(Ordering::Less, (cmp)(1, 0));
        assert_eq!(Ordering::Greater, (cmp)(0, 2));
    }

    #[test]
    fn test_dict() {
        let data = vec!["a", "b", "c", "a", "a", "c", "c"];
        let array = data.into_iter().collect::<DictionaryArray<Int16Type>>();

        let cmp = build_compare(&array, &array).unwrap();

        assert_eq!(Ordering::Less, (cmp)(0, 1));
        assert_eq!(Ordering::Equal, (cmp)(3, 4));
        assert_eq!(Ordering::Greater, (cmp)(2, 3));
    }

    // The two arrays have different dictionaries: "a" is key 0 in `a1` but key
    // 3 in `a2`, yet they compare Equal, so comparison is by resolved value.
    #[test]
    fn test_multiple_dict() {
        let d1 = vec!["a", "b", "c", "d"];
        let a1 = d1.into_iter().collect::<DictionaryArray<Int16Type>>();
        let d2 = vec!["e", "f", "g", "a"];
        let a2 = d2.into_iter().collect::<DictionaryArray<Int16Type>>();

        let cmp = build_compare(&a1, &a2).unwrap();

        assert_eq!(Ordering::Less, (cmp)(0, 0));
        assert_eq!(Ordering::Equal, (cmp)(0, 3));
        assert_eq!(Ordering::Greater, (cmp)(1, 3));
    }

    // Logical values after key resolution: array1 = [1, 1, 0, 5] and
    // array2 = [2, 3, 3, 5]. The remaining *_dict tests reuse the same layout
    // with other primitive value types.
    #[test]
    fn test_primitive_dict() {
        let values = Int32Array::from(vec![1_i32, 0, 2, 5]);
        let keys = Int8Array::from_iter_values([0, 0, 1, 3]);
        let array1 = DictionaryArray::new(keys, Arc::new(values));

        let values = Int32Array::from(vec![2_i32, 3, 4, 5]);
        let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
        let array2 = DictionaryArray::new(keys, Arc::new(values));

        let cmp = build_compare(&array1, &array2).unwrap();

        assert_eq!(Ordering::Less, (cmp)(0, 0));
        assert_eq!(Ordering::Less, (cmp)(0, 3));
        assert_eq!(Ordering::Equal, (cmp)(3, 3));
        assert_eq!(Ordering::Greater, (cmp)(3, 1));
        assert_eq!(Ordering::Greater, (cmp)(3, 2));
    }

    #[test]
    fn test_float_dict() {
        let values = Float32Array::from(vec![1.0, 0.5, 2.1, 5.5]);
        let keys = Int8Array::from_iter_values([0, 0, 1, 3]);
        let array1 = DictionaryArray::try_new(keys, Arc::new(values)).unwrap();

        let values = Float32Array::from(vec![1.2, 3.2, 4.0, 5.5]);
        let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
        let array2 = DictionaryArray::new(keys, Arc::new(values));

        let cmp = build_compare(&array1, &array2).unwrap();

        assert_eq!(Ordering::Less, (cmp)(0, 0));
        assert_eq!(Ordering::Less, (cmp)(0, 3));
        assert_eq!(Ordering::Equal, (cmp)(3, 3));
        assert_eq!(Ordering::Greater, (cmp)(3, 1));
        assert_eq!(Ordering::Greater, (cmp)(3, 2));
    }

    #[test]
    fn test_timestamp_dict() {
        let values = TimestampSecondArray::from(vec![1, 0, 2, 5]);
        let keys = Int8Array::from_iter_values([0, 0, 1, 3]);
        let array1 = DictionaryArray::new(keys, Arc::new(values));

        let values = TimestampSecondArray::from(vec![2, 3, 4, 5]);
        let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
        let array2 = DictionaryArray::new(keys, Arc::new(values));

        let cmp = build_compare(&array1, &array2).unwrap();

        assert_eq!(Ordering::Less, (cmp)(0, 0));
        assert_eq!(Ordering::Less, (cmp)(0, 3));
        assert_eq!(Ordering::Equal, (cmp)(3, 3));
        assert_eq!(Ordering::Greater, (cmp)(3, 1));
        assert_eq!(Ordering::Greater, (cmp)(3, 2));
    }

    #[test]
    fn test_interval_dict() {
        let values = IntervalDayTimeArray::from(vec![1, 0, 2, 5]);
        let keys = Int8Array::from_iter_values([0, 0, 1, 3]);
        let array1 = DictionaryArray::new(keys, Arc::new(values));

        let values = IntervalDayTimeArray::from(vec![2, 3, 4, 5]);
        let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
        let array2 = DictionaryArray::new(keys, Arc::new(values));

        let cmp = build_compare(&array1, &array2).unwrap();

        assert_eq!(Ordering::Less, (cmp)(0, 0));
        assert_eq!(Ordering::Less, (cmp)(0, 3));
        assert_eq!(Ordering::Equal, (cmp)(3, 3));
        assert_eq!(Ordering::Greater, (cmp)(3, 1));
        assert_eq!(Ordering::Greater, (cmp)(3, 2));
    }

    #[test]
    fn test_duration_dict() {
        let values = DurationSecondArray::from(vec![1, 0, 2, 5]);
        let keys = Int8Array::from_iter_values([0, 0, 1, 3]);
        let array1 = DictionaryArray::new(keys, Arc::new(values));

        let values = DurationSecondArray::from(vec![2, 3, 4, 5]);
        let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
        let array2 = DictionaryArray::new(keys, Arc::new(values));

        let cmp = build_compare(&array1, &array2).unwrap();

        assert_eq!(Ordering::Less, (cmp)(0, 0));
        assert_eq!(Ordering::Less, (cmp)(0, 3));
        assert_eq!(Ordering::Equal, (cmp)(3, 3));
        assert_eq!(Ordering::Greater, (cmp)(3, 1));
        assert_eq!(Ordering::Greater, (cmp)(3, 2));
    }

    #[test]
    fn test_decimal_dict() {
        let values = Decimal128Array::from(vec![1, 0, 2, 5]);
        let keys = Int8Array::from_iter_values([0, 0, 1, 3]);
        let array1 = DictionaryArray::new(keys, Arc::new(values));

        let values = Decimal128Array::from(vec![2, 3, 4, 5]);
        let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
        let array2 = DictionaryArray::new(keys, Arc::new(values));

        let cmp = build_compare(&array1, &array2).unwrap();

        assert_eq!(Ordering::Less, (cmp)(0, 0));
        assert_eq!(Ordering::Less, (cmp)(0, 3));
        assert_eq!(Ordering::Equal, (cmp)(3, 3));
        assert_eq!(Ordering::Greater, (cmp)(3, 1));
        assert_eq!(Ordering::Greater, (cmp)(3, 2));
    }

    #[test]
    fn test_decimal256_dict() {
        let values = Decimal256Array::from(vec![
            i256::from_i128(1),
            i256::from_i128(0),
            i256::from_i128(2),
            i256::from_i128(5),
        ]);
        let keys = Int8Array::from_iter_values([0, 0, 1, 3]);
        let array1 = DictionaryArray::new(keys, Arc::new(values));

        let values = Decimal256Array::from(vec![
            i256::from_i128(2),
            i256::from_i128(3),
            i256::from_i128(4),
            i256::from_i128(5),
        ]);
        let keys = Int8Array::from_iter_values([0, 1, 1, 3]);
        let array2 = DictionaryArray::new(keys, Arc::new(values));

        let cmp = build_compare(&array1, &array2).unwrap();

        assert_eq!(Ordering::Less, (cmp)(0, 0));
        assert_eq!(Ordering::Less, (cmp)(0, 3));
        assert_eq!(Ordering::Equal, (cmp)(3, 3));
        assert_eq!(Ordering::Greater, (cmp)(3, 1));
        assert_eq!(Ordering::Greater, (cmp)(3, 2));
    }
}
