// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

#pragma once

#include "common/exception.h"
#include "common/status.h"
#include "exprs/runtime_filter.h"
#include "runtime/runtime_filter_mgr.h"
#include "runtime/runtime_state.h"
#include "vec/columns/column_nullable.h"
#include "vec/columns/columns_number.h"
#include "vec/common/assert_cast.h"
#include "vec/core/block.h" // IWYU pragma: keep
#include "vec/runtime/shared_hash_table_controller.h"

namespace doris {
// Used in hash join node.
//
// Groups the runtime filters of a join build side by the index of the build expr they are
// fed from (IRuntimeFilter::expr_order()). init() decides which filters are kept and which
// are ignored; insert(), publish() and the shared-context helpers then operate only on the
// filters that were kept.
class VRuntimeFilterSlots {
public:
    // `build_expr_ctxs` is stored by reference, so it must outlive this object.
    // `runtime_filters` is copied; init() sorts that private copy.
    // `need_local_merge` selects which runtime filter manager is consulted when a local
    // filter is ignored: the global one when true, the local one otherwise.
    VRuntimeFilterSlots(
            const std::vector<std::shared_ptr<vectorized::VExprContext>>& build_expr_ctxs,
            const std::vector<IRuntimeFilter*>& runtime_filters, bool need_local_merge = false)
            : _build_expr_context(build_expr_ctxs),
              _runtime_filters(runtime_filters),
              _need_local_merge(need_local_merge) {}

    // Decides, per runtime filter, whether it is kept (appended to _runtime_filters_map)
    // or ignored, based on `hash_table_size`.
    //
    // Returns InternalError when a filter's expr_order() is outside
    // [0, _build_expr_context.size()), and forwards any error Status returned by the
    // filter / filter-manager calls made here. Throws Exception(INTERNAL_ERROR) when a
    // local filter has to be ignored but get_consume_filters() yields no filter for its id.
    Status init(RuntimeState* state, int64_t hash_table_size) {
        // runtime filter effect strategy
        // 1. we will ignore IN filter when hash_table_size is too big
        // 2. we will ignore BLOOM filter and MinMax filter when hash_table_size
        // is too small and IN filter has effect
        //
        // expr_order -> whether a filter that acts as an IN filter was already kept for it.
        std::map<int, bool> has_in_filter;

        // Marks every filter registered under `filter_id` as ignored and signals it.
        auto ignore_local_filter = [&](int filter_id) {
            auto runtime_filter_mgr = _need_local_merge ? state->global_runtime_filter_mgr()
                                                        : state->local_runtime_filter_mgr();

            std::vector<IRuntimeFilter*> filters;
            RETURN_IF_ERROR(runtime_filter_mgr->get_consume_filters(filter_id, filters));
            if (filters.empty()) {
                // Note: this is reported by exception, not by the returned Status.
                throw Exception(ErrorCode::INTERNAL_ERROR, "filters empty, filter_id={}",
                                filter_id);
            }
            for (auto* filter : filters) {
                filter->set_ignored("");
                filter->signal();
            }
            return Status::OK();
        };

        // Marks the filter as ignored with `msg` and publishes it right away.
        auto ignore_remote_filter = [](IRuntimeFilter* runtime_filter, const std::string& msg) {
            runtime_filter->set_ignored(msg);
            RETURN_IF_ERROR(runtime_filter->publish());
            return Status::OK();
        };

        // ordered vector: IN, IN_OR_BLOOM, others.
        // so we can ignore other filter if IN Predicate exists.
        auto compare_desc = [](IRuntimeFilter* d1, IRuntimeFilter* d2) {
            if (d1->type() == d2->type()) {
                return false;
            } else if (d1->type() == RuntimeFilterType::IN_FILTER) {
                return true;
            } else if (d2->type() == RuntimeFilterType::IN_FILTER) {
                return false;
            } else if (d1->type() == RuntimeFilterType::IN_OR_BLOOM_FILTER) {
                return true;
            } else if (d2->type() == RuntimeFilterType::IN_OR_BLOOM_FILTER) {
                return false;
            } else {
                return d1->type() < d2->type();
            }
        };
        std::sort(_runtime_filters.begin(), _runtime_filters.end(), compare_desc);

        // do not create 'in filter' when hash_table size over limit
        const auto max_in_num = state->runtime_filter_max_in_num();
        const bool over_max_in_num = (hash_table_size >= max_in_num);

        for (auto* runtime_filter : _runtime_filters) {
            const auto expr_order = runtime_filter->expr_order();
            // The negative check comes first, so the cast below never sees a negative value.
            if (expr_order < 0 ||
                static_cast<size_t>(expr_order) >= _build_expr_context.size()) {
                return Status::InternalError(
                        "runtime_filter meet invalid expr_order, expr_order={}, "
                        "_build_expr_context.size={}",
                        expr_order, _build_expr_context.size());
            }

            bool is_in_filter = (runtime_filter->type() == RuntimeFilterType::IN_FILTER);

            // An IN_OR_BLOOM filter falls back to its bloom form once the IN limit is hit.
            if (over_max_in_num &&
                runtime_filter->type() == RuntimeFilterType::IN_OR_BLOOM_FILTER) {
                RETURN_IF_ERROR(runtime_filter->change_to_bloom_filter());
            }

            if (runtime_filter->is_bloomfilter()) {
                RETURN_IF_ERROR(runtime_filter->init_bloom_filter(hash_table_size));
            }

            // Note:
            // In the case that exist *remote target* and in filter and other filter,
            // we must merge other filter whatever in filter is over the max num in current node,
            // because:
            // case 1: (in filter >= max num) in current node, so in filter will be ignored,
            //         and then other filter can be used
            // case 2: (in filter < max num) in current node, we don't know whether the in filter
            //         will be ignored in merge node, so we must transfer other filter to merge node
            if (!runtime_filter->has_remote_target()) {
                // operator[] default-inserts `false` for an expr_order not seen yet.
                bool exists_in_filter = has_in_filter[expr_order];
                if (is_in_filter && over_max_in_num) {
                    VLOG_DEBUG << "fragment instance " << print_id(state->fragment_instance_id())
                               << " ignore runtime filter(in filter id "
                               << runtime_filter->filter_id() << ") because: in_num("
                               << hash_table_size << ") >= max_in_num(" << max_in_num << ")";
                    RETURN_IF_ERROR(ignore_local_filter(runtime_filter->filter_id()));
                    continue;
                } else if (!is_in_filter && exists_in_filter) {
                    // do not create 'bloom filter' and 'minmax filter' when 'in filter' has created
                    // because in filter is exactly filter, so it is enough to filter data
                    VLOG_DEBUG << "fragment instance " << print_id(state->fragment_instance_id())
                               << " ignore runtime filter("
                               << IRuntimeFilter::to_string(runtime_filter->type()) << " id "
                               << runtime_filter->filter_id()
                               << ") because: already exists in filter";
                    RETURN_IF_ERROR(ignore_local_filter(runtime_filter->filter_id()));
                    continue;
                }
            } else if (is_in_filter && over_max_in_num) {
                const std::string msg = fmt::format(
                        "fragment instance {} ignore runtime filter(in filter id {}) because: "
                        "in_num({}) >= max_in_num({})",
                        print_id(state->fragment_instance_id()), runtime_filter->filter_id(),
                        hash_table_size, max_in_num);
                RETURN_IF_ERROR(ignore_remote_filter(runtime_filter, msg));
                continue;
            }

            // A kept IN filter, or an IN_OR_BLOOM filter still below the IN limit, counts as
            // an IN filter for later filters on the same build expr.
            if ((runtime_filter->type() == RuntimeFilterType::IN_FILTER) ||
                (runtime_filter->type() == RuntimeFilterType::IN_OR_BLOOM_FILTER &&
                 !over_max_in_num)) {
                has_in_filter[expr_order] = true;
            }
            _runtime_filters_map[expr_order].push_back(runtime_filter);
        }

        return Status::OK();
    }

    // Feeds `block` into every kept filter: for each build expr that has filters, the column
    // at that expr's last result column id is passed to insert_batch().
    // NOTE(review): the second insert_batch argument (1) is presumably the first row to
    // insert from, i.e. row 0 is skipped -- confirm against IRuntimeFilter::insert_batch.
    void insert(const vectorized::Block* block) {
        const int expr_count = static_cast<int>(_build_expr_context.size());
        for (int i = 0; i < expr_count; ++i) {
            auto iter = _runtime_filters_map.find(i);
            if (iter == _runtime_filters_map.end()) {
                continue;
            }

            int result_column_id = _build_expr_context[i]->get_last_result_column_id();
            const auto& column = block->get_by_position(result_column_id).column;
            for (auto* filter : iter->second) {
                filter->insert_batch(column, 1);
            }
        }
    }

    // publish runtime filter
    // Calls publish(publish_local) on every kept filter and returns the first error, if any.
    Status publish(bool publish_local = false) {
        for (auto& pair : _runtime_filters_map) {
            for (auto& filter : pair.second) {
                RETURN_IF_ERROR(filter->publish(publish_local));
            }
        }
        return Status::OK();
    }

    // Stores the shared context of every kept filter into `context->runtime_filters`,
    // keyed by filter id (an existing entry with the same id is overwritten).
    void copy_to_shared_context(vectorized::SharedHashTableContextPtr& context) {
        for (auto& it : _runtime_filters_map) {
            for (auto& filter : it.second) {
                context->runtime_filters[filter->filter_id()] = filter->get_shared_context_ref();
            }
        }
    }

    // Replaces the shared context of every kept filter with the entry stored under its
    // filter id in `context->runtime_filters`. Returns Aborted on the first id that is
    // missing; filters processed before that point have already been updated.
    Status copy_from_shared_context(vectorized::SharedHashTableContextPtr& context) {
        for (auto& it : _runtime_filters_map) {
            for (auto& filter : it.second) {
                auto filter_id = filter->filter_id();
                auto ret = context->runtime_filters.find(filter_id);
                if (ret == context->runtime_filters.end()) {
                    return Status::Aborted("invalid runtime filter id: {}", filter_id);
                }
                filter->get_shared_context_ref() = ret->second;
            }
        }
        return Status::OK();
    }

    // True when no filter has been kept (init() not called yet, or it ignored all of them).
    bool empty() const { return _runtime_filters_map.empty(); }

private:
    // Not owned; refers to the vector passed to the constructor.
    const std::vector<std::shared_ptr<vectorized::VExprContext>>& _build_expr_context;
    // Private copy of the filter list; sorted in place by init().
    std::vector<IRuntimeFilter*> _runtime_filters;
    const bool _need_local_merge = false;
    // build expr index (expr_order) -> runtime filters kept for that expr
    std::map<int, std::list<IRuntimeFilter*>> _runtime_filters_map;
};

} // namespace doris
