/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.table.planner.plan.utils

import org.apache.flink.table.functions.UserDefinedFunction
import org.apache.flink.table.planner.typeutils.DataViewUtils.DistinctViewSpec
import org.apache.flink.table.runtime.dataview.DataViewSpec
import org.apache.flink.table.types.DataType

import org.apache.calcite.rel.core.AggregateCall

import scala.collection.mutable.ArrayBuffer

/**
  * Describes a single aggregate function call.
  *
  * Note: several fields are `Array`s, so the `equals`/`hashCode` generated for this case class
  * compare those fields by reference rather than by content.
  *
  * @param agg the Calcite aggregate call
  * @param function the function backing the call, an AggregateFunction or a
  *                 DeclarativeAggregateFunction
  * @param aggIndex the index of the aggregate call in the aggregation list
  * @param argIndexes the indexes of the aggregate arguments in the input
  * @param externalArgTypes the types of the aggregate arguments (input types)
  * @param externalAccTypes the accumulator types of the aggregate
  * @param viewSpecs the data view specs of the aggregate
  * @param externalResultType the result type of the aggregate
  * @param consumeRetraction whether the aggregate consumes retractions
  */
case class AggregateInfo(
    agg: AggregateCall,
    function: UserDefinedFunction,
    aggIndex: Int,
    argIndexes: Array[Int],
    externalArgTypes: Array[DataType],
    externalAccTypes: Array[DataType],
    viewSpecs: Array[DataViewSpec],
    externalResultType: DataType,
    consumeRetraction: Boolean)

/**
  * Describes a distinct that is shared between aggregates. It indicates which aggregates are
  * distinct aggregates.
  *
  * @param argIndexes the indexes of the distinct aggregate arguments in the input
  * @param keyType the type of the distinct key
  * @param accType the accumulator type of the shared distinct
  * @param excludeAcc whether the distinct accumulator should be excluded from the aggregate
  *                   accumulator, e.g. true when this works in incremental mode,
  *                   otherwise false.
  * @param dataViewSpec the data view spec of this distinct aggregate, used to generate state
  *                     access; None when the data view does not work in state mode
  * @param consumeRetraction whether the distinct aggregate consumes retractions
  * @param filterArgs the ordinal of the filter argument for each aggregate, -1 means no filter.
  *                   This is a mutable buffer; presumably aligned position-by-position with
  *                   `aggIndexes` -- confirm against the code that populates it.
  * @param aggIndexes the indexes, in the aggregation list, of the distinct aggregates sharing
  *                   this distinct. This is a mutable buffer.
  */
case class DistinctInfo(
    argIndexes: Array[Int],
    keyType: DataType,
    accType: DataType,
    excludeAcc: Boolean,
    dataViewSpec: Option[DistinctViewSpec],
    consumeRetraction: Boolean,
    filterArgs: ArrayBuffer[Int],
    aggIndexes: ArrayBuffer[Int])


/**
  * Holds the infos of all aggregates together with the input count (count(*)) information and
  * the shared distinct information.
  *
  * @param aggInfos the information about every aggregate
  * @param indexOfCountStar None if the input count is not needed; otherwise the index of the
  *                         count(*) aggregate in `aggInfos`
  * @param countStarInserted true when the count(*) was inserted into the agg list,
  *                          false when the count(*) already existed in the agg list
  * @param distinctInfos the distinct information, empty if none of the aggregates is distinct
  */
case class AggregateInfoList(
    aggInfos: Array[AggregateInfo],
    indexOfCountStar: Option[Int],
    countStarInserted: Boolean,
    distinctInfos: Array[DistinctInfo]) {

  /** Returns the name of every aggregate call, in the order of `aggInfos`. */
  def getAggNames: Array[String] = aggInfos.map(info => info.agg.getName)

  /**
    * Returns the accumulator types of all aggregates (flattened, in the order of `aggInfos`),
    * followed by the accumulator types of the distinct infos whose `excludeAcc` is false.
    */
  def getAccTypes: Array[DataType] = {
    val aggAccTypes: Array[DataType] = aggInfos.flatMap(_.externalAccTypes)
    val distinctAccTypes: Array[DataType] = distinctInfos.collect {
      case distinct if !distinct.excludeAcc => distinct.accType
    }
    aggAccTypes ++ distinctAccTypes
  }

  /** Returns the Calcite aggregate calls of [[getActualAggregateInfos]]. */
  def getActualAggregateCalls: Array[AggregateCall] = {
    getActualAggregateInfos.map(info => info.agg)
  }

  /** Returns the functions of [[getActualAggregateInfos]]. */
  def getActualFunctions: Array[UserDefinedFunction] = {
    getActualAggregateInfos.map(info => info.function)
  }

  /** Returns the result types of [[getActualAggregateInfos]]. */
  def getActualValueTypes: Array[DataType] = {
    getActualAggregateInfos.map(info => info.externalResultType)
  }

  /**
    * Returns the offset of the count(*) within the flattened accumulator types, i.e. the total
    * number of accumulator types of all aggregates located before `indexOfCountStar`.
    * Returns -1 if `indexOfCountStar` is empty.
    */
  def getIndexOfCountStar: Int = {
    indexOfCountStar
      .map { countStarIndex =>
        // sum up the accumulator counts of the aggregates in front of the count(*)
        aggInfos.take(countStarIndex).map(_.externalAccTypes.length).sum
      }
      .getOrElse(-1)
  }

  /**
    * Returns the aggregate infos that contribute to the result value. When the input count is
    * needed and the count(*) was inserted into the agg list, that count(*) is left out;
    * otherwise `aggInfos` is returned as is.
    */
  def getActualAggregateInfos: Array[AggregateInfo] = {
    // only an inserted count(*) has to be hidden, it shouldn't be calculated in value
    val insertedCountStar = indexOfCountStar.filter(_ => countStarInserted)
    insertedCountStar.fold(aggInfos) { countStarIndex =>
      aggInfos.zipWithIndex.collect {
        case (aggInfo, position) if position != countStarIndex => aggInfo
      }
    }
  }
}
