################################################################################
#  Licensed to the Apache Software Foundation (ASF) under one
#  or more contributor license agreements.  See the NOTICE file
#  distributed with this work for additional information
#  regarding copyright ownership.  The ASF licenses this file
#  to you under the Apache License, Version 2.0 (the
#  "License"); you may not use this file except in compliance
#  with the License.  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import unittest

from pyflink.table import DataTypes
from pyflink.table.udf import TableFunction, udtf, ScalarFunction, udf
from pyflink.testing import source_sink_utils
from pyflink.testing.test_case_utils import PyFlinkOldStreamTableTestCase, \
    PyFlinkBlinkStreamTableTestCase, PyFlinkOldBatchTableTestCase, PyFlinkBlinkBatchTableTestCase


class UserDefinedTableFunctionTests(object):
    """Planner-agnostic test cases for Python user-defined table functions.

    This is a mixin: concrete subclasses combine it with a PyFlink table test
    case base that provides ``self.t_env`` and ``assert_equals``. Subclasses
    may override ``_register_table_sink`` and ``_get_output`` to change how
    query results are gathered.
    """

    def test_table_function(self):
        sink_fields = ['a', 'b', 'c']
        self._register_table_sink(sink_fields, [DataTypes.BIGINT() for _ in sink_fields])

        multi_emit = udtf(MultiEmit(), result_types=[DataTypes.BIGINT(), DataTypes.BIGINT()])
        multi_num = udf(MultiNum(), result_type=DataTypes.BIGINT())

        source = self.t_env.from_elements([(1, 1, 3), (2, 1, 6), (3, 2, 9)], ['a', 'b', 'c'])

        # Inner lateral join: every source row fans out into rows (a, i)
        # for i in range(b * 2), exposed as columns x and y.
        expanded = source.join_lateral(
            multi_emit(source.a, multi_num(source.b)).alias('x', 'y'))

        # Outer lateral joins keep rows for which the function emits nothing,
        # padding the new column with null.
        with_m = expanded.left_outer_join_lateral(
            condition_multi_emit(expanded.x, expanded.y).alias('m')).select("x, y, m")
        with_n = with_m.left_outer_join_lateral(
            identity(with_m.m).alias('n')).select("x, y, n")

        expected = ["+I[1, 0, null]", "+I[1, 1, null]", "+I[2, 0, null]", "+I[2, 1, null]",
                    "+I[3, 0, 0]", "+I[3, 0, 1]", "+I[3, 0, 2]", "+I[3, 1, 1]",
                    "+I[3, 1, 2]", "+I[3, 2, 2]", "+I[3, 3, null]"]
        self.assert_equals(self._get_output(with_n), expected)

    def test_table_function_with_sql_query(self):
        sink_fields = ['a', 'b', 'c']
        self._register_table_sink(sink_fields, [DataTypes.BIGINT() for _ in sink_fields])

        self.t_env.create_temporary_system_function(
            "multi_emit", udtf(MultiEmit(), result_types=[DataTypes.BIGINT(), DataTypes.BIGINT()]))

        source = self.t_env.from_elements([(1, 1, 3), (2, 1, 6), (3, 2, 9)], ['a', 'b', 'c'])
        self.t_env.register_table("MyTable", source)

        query = ("SELECT a, x, y FROM MyTable LEFT JOIN LATERAL TABLE(multi_emit(a, b)) as T(x, y)"
                 " ON TRUE")
        result = self.t_env.sql_query(query)

        expected = ["+I[1, 1, 0]", "+I[2, 2, 0]", "+I[3, 3, 0]", "+I[3, 3, 1]"]
        self.assert_equals(self._get_output(result), expected)

    def _register_table_sink(self, field_names: list, field_types: list):
        """Register a ``TestAppendSink`` named "Results" with the given schema."""
        self.t_env.register_table_sink(
            "Results", source_sink_utils.TestAppendSink(field_names, field_types))

    def _get_output(self, t):
        """Insert ``t`` into "Results", wait for the job, and return the sink's rows."""
        t.execute_insert("Results").wait()
        return source_sink_utils.results()


class PyFlinkStreamUserDefinedTableFunctionTests(UserDefinedTableFunctionTests,
                                                 PyFlinkOldStreamTableTestCase):
    """Runs the shared UDTF tests on ``PyFlinkOldStreamTableTestCase``."""


class PyFlinkBlinkStreamUserDefinedFunctionTests(UserDefinedTableFunctionTests,
                                                 PyFlinkBlinkStreamTableTestCase):
    """Runs the shared UDTF tests on ``PyFlinkBlinkStreamTableTestCase``."""


class PyFlinkBlinkBatchUserDefinedFunctionTests(UserDefinedTableFunctionTests,
                                                PyFlinkBlinkBatchTableTestCase):
    """Runs the shared UDTF tests on ``PyFlinkBlinkBatchTableTestCase``."""


class PyFlinkBatchUserDefinedTableFunctionTests(UserDefinedTableFunctionTests,
                                                PyFlinkOldBatchTableTestCase):
    """Runs the shared UDTF tests on ``PyFlinkOldBatchTableTestCase``.

    Output is gathered with ``self.collect`` instead of through a registered
    append sink, so sink registration is a no-op in this variant.
    """

    def _register_table_sink(self, field_names: list, field_types: list):
        # Intentionally does nothing: ``_get_output`` collects rows directly.
        return None

    def _get_output(self, t):
        collected = self.collect(t)
        return collected

    def test_row_type_as_input_types_and_result_types(self):
        # A single ROW type passed as input_types / result_types is expected
        # to be stored as a one-element list.
        def make_row_type():
            return DataTypes.ROW([DataTypes.FIELD("a", DataTypes.BIGINT())])

        a = udtf(lambda i: i,
                 input_types=make_row_type(),
                 result_types=make_row_type())

        self.assertEqual(a._input_types, [make_row_type()])
        self.assertEqual(a._result_types, [make_row_type()])


class MultiEmit(TableFunction):
    """Table function that emits ``(x, i)`` for every ``i`` in ``range(y)``.

    It also exercises the metrics API. ``open`` registers the counter
    ``my_counter`` in the group returned by ``add_group("key", "value")``.
    Each ``eval`` call increments that counter by ``y``.

    The class previously also inherited from ``unittest.TestCase`` but used
    none of its API. Mixing a test case into a shipped UDF is misleading, so
    the mixin was dropped.
    """

    def open(self, function_context):
        mg = function_context.get_metric_group()
        self.counter = mg.add_group("key", "value").counter("my_counter")
        # Local running total mirroring what has been added to ``self.counter``.
        # NOTE(review): it is not asserted anywhere in this file.
        self.counter_sum = 0

    def eval(self, x, y):
        self.counter.inc(y)
        self.counter_sum += y
        for i in range(y):
            yield x, i


@udtf(result_types=[DataTypes.BIGINT()])
def identity(x):
    """Emit ``x`` as a single one-field ``Row``; emit no rows when ``x`` is None."""
    if x is None:
        return
    from pyflink.common import Row
    yield Row(x)


# Exercises a UDTF declared with explicit input_types.
@udtf(input_types=[DataTypes.BIGINT(), DataTypes.BIGINT()],
      result_types=DataTypes.BIGINT())
def condition_multi_emit(x, y):
    """Emit ``range(y, x)`` when ``x == 3``; otherwise return None (no rows)."""
    return range(y, x) if x == 3 else None


class MultiNum(ScalarFunction):
    """Scalar function returning its argument multiplied by two."""

    def eval(self, x):
        doubled = x * 2
        return doubled


if __name__ == '__main__':
    # ``unittest`` is already imported at module level; only the optional
    # JUnit-XML runner needs a guarded import here.
    try:
        import xmlrunner

        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
    except ImportError:
        # xmlrunner is not installed: fall back to unittest's default runner.
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)
