#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Additional Spark functions used in pandas-on-Spark.
"""
from pyspark.sql.column import Column
from pyspark.sql.utils import is_remote


def product(col: Column, dropna: bool) -> Column:
    """
    Return a Column applying the pandas-on-Spark ``pandas_product`` function.

    With Spark Connect the function is invoked by name. Otherwise the
    expression is built by the JVM helper ``PythonSQLUtils.pandasProduct``.

    Parameters
    ----------
    col : Column
        Input column.
    dropna : bool
        Flag forwarded unchanged to the underlying function. It presumably
        controls null handling; confirm against the JVM implementation.
    """
    if not is_remote():
        from pyspark import SparkContext

        jvm_utils = SparkContext._active_spark_context._jvm.PythonSQLUtils
        return Column(jvm_utils.pandasProduct(col._jc, dropna))

    from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns, lit

    return _invoke_function_over_columns("pandas_product", col, lit(dropna))


def stddev(col: Column, ddof: int) -> Column:
    """
    Return a Column applying the pandas-on-Spark ``pandas_stddev`` function.

    With Spark Connect the function is invoked by name. Otherwise the
    expression is built by the JVM helper ``PythonSQLUtils.pandasStddev``.

    Parameters
    ----------
    col : Column
        Input column.
    ddof : int
        Forwarded unchanged. It is presumably pandas' "delta degrees of
        freedom"; confirm against the JVM implementation.
    """
    if not is_remote():
        from pyspark import SparkContext

        jvm_utils = SparkContext._active_spark_context._jvm.PythonSQLUtils
        return Column(jvm_utils.pandasStddev(col._jc, ddof))

    from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns, lit

    return _invoke_function_over_columns("pandas_stddev", col, lit(ddof))


def var(col: Column, ddof: int) -> Column:
    """
    Return a Column applying the pandas-on-Spark ``pandas_var`` function.

    With Spark Connect the function is invoked by name. Otherwise the
    expression is built by the JVM helper ``PythonSQLUtils.pandasVariance``.

    Parameters
    ----------
    col : Column
        Input column.
    ddof : int
        Forwarded unchanged. It is presumably pandas' "delta degrees of
        freedom"; confirm against the JVM implementation.
    """
    if not is_remote():
        from pyspark import SparkContext

        jvm_utils = SparkContext._active_spark_context._jvm.PythonSQLUtils
        return Column(jvm_utils.pandasVariance(col._jc, ddof))

    from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns, lit

    return _invoke_function_over_columns("pandas_var", col, lit(ddof))


def skew(col: Column) -> Column:
    """
    Return a Column applying the pandas-on-Spark ``pandas_skew`` function.

    With Spark Connect the function is invoked by name. Otherwise the
    expression is built by the JVM helper ``PythonSQLUtils.pandasSkewness``.

    Parameters
    ----------
    col : Column
        Input column.
    """
    if not is_remote():
        from pyspark import SparkContext

        jvm_utils = SparkContext._active_spark_context._jvm.PythonSQLUtils
        return Column(jvm_utils.pandasSkewness(col._jc))

    from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns

    return _invoke_function_over_columns("pandas_skew", col)


def kurt(col: Column) -> Column:
    """
    Return a Column applying the pandas-on-Spark ``pandas_kurt`` function.

    With Spark Connect the function is invoked by name. Otherwise the
    expression is built by the JVM helper ``PythonSQLUtils.pandasKurtosis``.

    Parameters
    ----------
    col : Column
        Input column.
    """
    if not is_remote():
        from pyspark import SparkContext

        jvm_utils = SparkContext._active_spark_context._jvm.PythonSQLUtils
        return Column(jvm_utils.pandasKurtosis(col._jc))

    from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns

    return _invoke_function_over_columns("pandas_kurt", col)


def mode(col: Column, dropna: bool) -> Column:
    """
    Return a Column applying the pandas-on-Spark ``pandas_mode`` function.

    With Spark Connect the function is invoked by name. Otherwise the
    expression is built by the JVM helper ``PythonSQLUtils.pandasMode``.

    Parameters
    ----------
    col : Column
        Input column.
    dropna : bool
        Flag forwarded unchanged to the underlying function. It presumably
        controls null handling; confirm against the JVM implementation.
    """
    if not is_remote():
        from pyspark import SparkContext

        jvm_utils = SparkContext._active_spark_context._jvm.PythonSQLUtils
        return Column(jvm_utils.pandasMode(col._jc, dropna))

    from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns, lit

    return _invoke_function_over_columns("pandas_mode", col, lit(dropna))


def covar(col1: Column, col2: Column, ddof: int) -> Column:
    """
    Return a Column applying the pandas-on-Spark ``pandas_covar`` function.

    With Spark Connect the function is invoked by name. Otherwise the
    expression is built by the JVM helper ``PythonSQLUtils.pandasCovar``.

    Parameters
    ----------
    col1 : Column
        First input column.
    col2 : Column
        Second input column.
    ddof : int
        Forwarded unchanged. It is presumably pandas' "delta degrees of
        freedom"; confirm against the JVM implementation.
    """
    if not is_remote():
        from pyspark import SparkContext

        jvm_utils = SparkContext._active_spark_context._jvm.PythonSQLUtils
        return Column(jvm_utils.pandasCovar(col1._jc, col2._jc, ddof))

    from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns, lit

    return _invoke_function_over_columns("pandas_covar", col1, col2, lit(ddof))


def ewm(col: Column, alpha: float, ignore_na: bool) -> Column:
    """
    Return a Column applying the pandas-on-Spark ``ewm`` function.

    With Spark Connect the function is invoked by name. Otherwise the
    expression is built by the JVM helper ``PythonSQLUtils.ewm``.

    Parameters
    ----------
    col : Column
        Input column.
    alpha : float
        Forwarded unchanged. It is presumably the smoothing factor, as in
        pandas ``ewm``; confirm against the JVM implementation.
    ignore_na : bool
        Flag forwarded unchanged. It presumably mirrors pandas' ``ignore_na``;
        confirm against the JVM implementation.
    """
    if not is_remote():
        from pyspark import SparkContext

        jvm_utils = SparkContext._active_spark_context._jvm.PythonSQLUtils
        return Column(jvm_utils.ewm(col._jc, alpha, ignore_na))

    from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns, lit

    return _invoke_function_over_columns("ewm", col, lit(alpha), lit(ignore_na))


def null_index(col: Column) -> Column:
    """
    Return a Column applying the pandas-on-Spark ``null_index`` function.

    With Spark Connect the function is invoked by name. Otherwise the
    expression is built by the JVM helper ``PythonSQLUtils.nullIndex``.

    Parameters
    ----------
    col : Column
        Input column.
    """
    if not is_remote():
        from pyspark import SparkContext

        jvm_utils = SparkContext._active_spark_context._jvm.PythonSQLUtils
        return Column(jvm_utils.nullIndex(col._jc))

    from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns

    return _invoke_function_over_columns("null_index", col)
