/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.pinot.core.query.optimizer.filter;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
import org.apache.pinot.common.request.Expression;
import org.apache.pinot.common.request.ExpressionType;
import org.apache.pinot.common.request.Function;
import org.apache.pinot.common.utils.request.RequestUtils;
import org.apache.pinot.spi.data.FieldSpec;
import org.apache.pinot.spi.data.FieldSpec.DataType;
import org.apache.pinot.spi.data.Schema;
import org.apache.pinot.sql.FilterKind;


/**
 * Filter optimizer that collapses the range predicates (e.g. {@code >}, {@code >=}, {@code <}, {@code <=},
 * {@code BETWEEN}, {@code RANGE}) applied to the same single-value column under one AND into a single RANGE predicate
 * covering the intersection of the individual ranges. Whenever a merge happens, every eligible column's range is
 * re-emitted as a RANGE predicate after the remaining children. If that leaves exactly one RANGE predicate and nothing
 * else, the predicate replaces the AND itself.
 *
 * NOTE: This optimizer is expected to run after {@link FlattenAndOrFilterOptimizer}, so AND/OR filters are already
 *       flattened and an AND never directly contains another AND.
 */
public class MergeRangeFilterOptimizer implements FilterOptimizer {

  /**
   * Merges range predicates inside AND filters, descending through OR and NOT filters to reach nested ANDs.
   *
   * @param filterExpression the filter to optimize; its function operands may be modified in place
   * @param schema the table schema used to resolve column data types, or null to skip this optimization
   * @return either {@code filterExpression} itself or a newly built RANGE predicate that replaces it
   */
  @Override
  public Expression optimize(Expression filterExpression, @Nullable Schema schema) {
    // Column data types come from the schema, so nothing can be merged without one
    if (schema == null || filterExpression.getType() != ExpressionType.FUNCTION) {
      return filterExpression;
    }
    Function function = filterExpression.getFunctionCall();
    String operator = function.getOperator();
    if (operator.equals(FilterKind.AND.name())) {
      return mergeAndChildren(filterExpression, function, schema);
    }
    if (operator.equals(FilterKind.OR.name()) || operator.equals(FilterKind.NOT.name())) {
      optimizeOperandsInPlace(function, schema);
    }
    return filterExpression;
  }

  /**
   * Intersects the range predicates that share a single-value column among the children of an AND filter.
   *
   * @param andExpression the AND filter expression
   * @param andFunction the function call of {@code andExpression}
   * @param schema the non-null table schema
   * @return {@code andExpression} (possibly with rewritten operands), or a lone RANGE predicate when it is the only
   *         remaining child
   */
  private Expression mergeAndChildren(Expression andExpression, Function andFunction, Schema schema) {
    Map<String, Range> mergedRanges = new HashMap<>();
    List<Expression> otherChildren = new ArrayList<>();
    boolean merged = false;

    for (Expression child : andFunction.getOperands()) {
      Function childFunction = child.getFunctionCall();
      FilterKind childKind = FilterKind.valueOf(childFunction.getOperator());
      assert childKind != FilterKind.AND;

      if (childKind == FilterKind.OR || childKind == FilterKind.NOT) {
        // Nested OR/NOT filters may hold AND filters of their own
        optimizeOperandsInPlace(childFunction, schema);
        otherChildren.add(child);
        continue;
      }
      if (!childKind.isRange()) {
        otherChildren.add(child);
        continue;
      }

      List<Expression> operands = childFunction.getOperands();
      Expression lhs = operands.get(0);
      // Only plain single-value columns are eligible; transform expressions are kept untouched.
      // NOTE: Multi-value columns cannot be merged because [0, 10] matches filter "col < 1 AND col > 9", but not the
      //       merged one.
      FieldSpec fieldSpec =
          lhs.getType() == ExpressionType.IDENTIFIER ? schema.getFieldSpecFor(lhs.getIdentifier().getName()) : null;
      if (fieldSpec == null || !fieldSpec.isSingleValueField()) {
        otherChildren.add(child);
        continue;
      }

      String column = lhs.getIdentifier().getName();
      Range range = getRange(childKind, operands, fieldSpec.getDataType());
      Range existing = mergedRanges.putIfAbsent(column, range);
      if (existing != null) {
        // A second range on the same column: narrow the stored range in place
        existing.intersect(range);
        merged = true;
      }
    }

    if (!merged) {
      // No column had more than one range; OR/NOT children were still optimized in place above
      return andExpression;
    }
    if (otherChildren.isEmpty() && mergedRanges.size() == 1) {
      // The merged range is the only predicate left, so it replaces the AND
      Map.Entry<String, Range> onlyRange = mergedRanges.entrySet().iterator().next();
      return getRangeFilterExpression(onlyRange.getKey(), onlyRange.getValue());
    }
    mergedRanges.forEach(
        (columnName, columnRange) -> otherChildren.add(getRangeFilterExpression(columnName, columnRange)));
    andFunction.setOperands(otherChildren);
    return andExpression;
  }

  /**
   * Replaces each operand of the given function with the result of optimizing it.
   */
  private void optimizeOperandsInPlace(Function function, Schema schema) {
    function.getOperands().replaceAll(operand -> optimize(operand, schema));
  }

  /**
   * Builds the {@link Range} matched by a single range predicate.
   *
   * @param filterKind the kind of the range predicate
   * @param operands the predicate operands; index 0 is the column and the following ones hold the literal bound(s)
   * @param dataType the column data type, used to convert the literal bounds
   * @return the range described by the predicate
   * @throws IllegalStateException if {@code filterKind} is not one of the handled range kinds
   */
  private static Range getRange(FilterKind filterKind, List<Expression> operands, DataType dataType) {
    switch (filterKind) {
      case RANGE:
        // The range is carried as a string literal and parsed by Range itself
        return Range.getRange(operands.get(1).getLiteral().getStringValue(), dataType);
      case BETWEEN:
        return new Range(getComparable(operands.get(1), dataType), true, getComparable(operands.get(2), dataType),
            true);
      case GREATER_THAN:
      case GREATER_THAN_OR_EQUAL:
        // Lower bound only; inclusive for ">="
        return new Range(getComparable(operands.get(1), dataType), filterKind == FilterKind.GREATER_THAN_OR_EQUAL,
            null, false);
      case LESS_THAN:
      case LESS_THAN_OR_EQUAL:
        // Upper bound only; inclusive for "<="
        return new Range(null, false, getComparable(operands.get(1), dataType),
            filterKind == FilterKind.LESS_THAN_OR_EQUAL);
      default:
        throw new IllegalStateException("Unsupported filter kind: " + filterKind);
    }
  }

  /**
   * Converts a literal expression into a Comparable of the given data type, via the literal's string form.
   */
  @SuppressWarnings("rawtypes")
  private static Comparable getComparable(Expression literalExpression, DataType dataType) {
    Object literalValue = literalExpression.getLiteral().getFieldValue();
    return dataType.convertInternal(literalValue.toString());
  }

  /**
   * Creates a {@code RANGE(column, '<range string>')} filter expression for the given column and range.
   */
  private static Expression getRangeFilterExpression(String column, Range range) {
    Expression rangeFilter = RequestUtils.getFunctionExpression(FilterKind.RANGE.name());
    Expression columnExpression = RequestUtils.getIdentifierExpression(column);
    Expression rangeLiteral = RequestUtils.getLiteralExpression(range.getRangeString());
    rangeFilter.getFunctionCall().setOperands(Arrays.asList(columnExpression, rangeLiteral));
    return rangeFilter;
  }
}
