/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hertzbeat.collector.collect.database;

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import org.apache.hertzbeat.collector.collect.AbstractCollect;
import org.apache.hertzbeat.collector.collect.common.cache.CacheIdentifier;
import org.apache.hertzbeat.collector.collect.common.cache.ConnectionCommonCache;
import org.apache.hertzbeat.collector.collect.common.cache.JdbcConnect;
import org.apache.hertzbeat.collector.util.CollectUtil;
import org.apache.hertzbeat.common.entity.job.Metrics;
import org.apache.hertzbeat.common.entity.job.protocol.JdbcProtocol;
import org.apache.hertzbeat.common.entity.message.CollectRep;
import org.apache.hertzbeat.common.util.CommonUtil;
import org.postgresql.util.PSQLException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.io.FileSystemResource;
import org.springframework.core.io.Resource;
import org.springframework.jdbc.datasource.init.ScriptUtils;
import org.springframework.util.StringUtils;

/**
 * Collector for the {@code jdbc} protocol.
 *
 * <p>For each collection it obtains a {@link Statement} (reusing a connection held in
 * {@link #connectionCommonCache} when one is available), runs the configured SQL according to the
 * configured query type and writes the resulting rows, or a failure code and message, into the
 * supplied {@link CollectRep.MetricsData.Builder}.
 */
public class JdbcCommonCollect extends AbstractCollect {

    private static final Logger log = LoggerFactory.getLogger(JdbcCommonCollect.class);

    /** Query type: read at most one row, addressing values by column label. */
    private static final String QUERY_TYPE_ONE_ROW = "oneRow";
    /** Query type: read every returned row, addressing values by column label. */
    private static final String QUERY_TYPE_MULTI_ROW = "multiRow";
    /** Query type: result rows are (name, value) pairs that are folded into a single row. */
    private static final String QUERY_TYPE_COLUMNS = "columns";
    /** Query type: the configured "sql" is a file-system path of a script to execute. */
    private static final String RUN_SCRIPT = "runScript";

    /** URL parameters that are rejected by {@link #preCheck(Metrics)}. */
    private static final String[] VULNERABLE_KEYWORDS = {"allowLoadLocalInfile", "allowLoadLocalInfileInPath", "useLocalInfile"};

    /** Alias field that is filled with the elapsed collection time instead of a result-set value. */
    private static final String RESPONSE_TIME = "responseTime";
    /** Placeholder written for SQL NULL (or missing) values. */
    private static final String NULL_VALUE = "&nbsp;";
    /** SQLSTATE 08001: the client was unable to establish the SQL connection. */
    private static final String SQL_STATE_CONNECTION_FAILURE = "08001";
    /** Upper bound on rows fetched by any statement created by this collector. */
    private static final int DEFAULT_MAX_ROWS = 1000;

    private final ConnectionCommonCache<CacheIdentifier, JdbcConnect> connectionCommonCache = new ConnectionCommonCache<>();

    /**
     * Validates the jdbc parameters before collection.
     *
     * @param metrics metrics definition to validate
     * @throws IllegalArgumentException if there are no jdbc parameters, or the explicit jdbc url
     *     contains one of the {@link #VULNERABLE_KEYWORDS}
     */
    @Override
    public void preCheck(Metrics metrics) throws IllegalArgumentException {
        if (metrics == null || metrics.getJdbc() == null) {
            throw new IllegalArgumentException("Database collect must has jdbc params");
        }
        String url = metrics.getJdbc().getUrl();
        if (!StringUtils.hasText(url)) {
            return;
        }
        // Compare case-insensitively: a driver may accept property names regardless of case, so an
        // exact-case match could be bypassed by simply changing the case of the parameter name.
        String normalizedUrl = url.toLowerCase(Locale.ROOT);
        for (String keyword : VULNERABLE_KEYWORDS) {
            if (normalizedUrl.contains(keyword.toLowerCase(Locale.ROOT))) {
                throw new IllegalArgumentException("Jdbc url prohibit contains vulnerable param " + keyword);
            }
        }
    }

    /**
     * Runs the configured query and stores the outcome in {@code builder}.
     *
     * <p>SQL and other failures raised while connecting or querying are not propagated; they are
     * translated into a code and message on {@code builder}. The statement is always closed, the
     * underlying connection is left open so that it can be reused from the cache.
     *
     * @param builder   receives the value rows, or the failure code and message
     * @param monitorId monitor id (not used by this collector)
     * @param app       app name (not used by this collector)
     * @param metrics   metrics definition holding the jdbc parameters and alias fields
     */
    @Override
    public void collect(CollectRep.MetricsData.Builder builder, long monitorId, String app, Metrics metrics) {
        long startTime = System.currentTimeMillis();
        JdbcProtocol jdbcProtocol = metrics.getJdbc();
        String databaseUrl = constructDatabaseUrl(jdbcProtocol);
        int timeout = CollectUtil.getTimeout(jdbcProtocol.getTimeout());
        Statement statement = null;
        try {
            statement = getConnection(jdbcProtocol.getUsername(), jdbcProtocol.getPassword(), databaseUrl, timeout);
            switch (jdbcProtocol.getQueryType()) {
                case QUERY_TYPE_ONE_ROW ->
                    queryOneRow(statement, jdbcProtocol.getSql(), metrics.getAliasFields(), builder, startTime);
                case QUERY_TYPE_MULTI_ROW ->
                    queryMultiRow(statement, jdbcProtocol.getSql(), metrics.getAliasFields(), builder, startTime);
                case QUERY_TYPE_COLUMNS ->
                    queryOneRowByMatchTwoColumns(statement, jdbcProtocol.getSql(), metrics.getAliasFields(), builder, startTime);
                case RUN_SCRIPT -> {
                    // For this query type the "sql" field is the path of a script file.
                    Connection connection = statement.getConnection();
                    Resource script = new FileSystemResource(jdbcProtocol.getSql());
                    ScriptUtils.executeSqlScript(connection, script);
                }
                default -> {
                    builder.setCode(CollectRep.Code.FAIL);
                    builder.setMsg("Not support database query type: " + jdbcProtocol.getQueryType());
                }
            }
        } catch (PSQLException psqlException) {
            // Must precede the SQLException handler: PSQLException is caught by it otherwise.
            if (SQL_STATE_CONNECTION_FAILURE.equals(psqlException.getSQLState())) {
                builder.setCode(CollectRep.Code.UN_REACHABLE);
            } else {
                builder.setCode(CollectRep.Code.FAIL);
            }
            builder.setMsg("Error: " + psqlException.getMessage() + " Code: " + psqlException.getSQLState());
        } catch (SQLException sqlException) {
            log.warn("Jdbc sql error: {}, code: {}.", sqlException.getMessage(), sqlException.getErrorCode());
            builder.setCode(CollectRep.Code.FAIL);
            builder.setMsg("Query Error: " + sqlException.getMessage() + " Code: " + sqlException.getErrorCode());
        } catch (Exception e) {
            String errorMessage = CommonUtil.getMessageFromThrowable(e);
            log.error("Jdbc error: {}.", errorMessage, e);
            builder.setCode(CollectRep.Code.FAIL);
            builder.setMsg("Query Error: " + errorMessage);
        } finally {
            if (statement != null) {
                try {
                    statement.close();
                } catch (Exception e) {
                    log.error("Jdbc close statement error: {}", e.getMessage());
                }
            }
        }
    }

    @Override
    public String supportProtocol() {
        return "jdbc";
    }

    /**
     * Returns a configured statement, preferring a cached connection for the given credentials.
     *
     * <p>If the cached connection can no longer produce a statement it is closed and evicted, and a
     * new connection is opened and cached instead.
     *
     * @param username database user
     * @param password database password
     * @param url      jdbc url, also used as part of the cache key
     * @param timeout  query timeout in milliseconds
     * @return a statement with query timeout and max rows applied; the caller must close it
     * @throws Exception if a new connection or statement cannot be created
     */
    private Statement getConnection(String username, String password, String url, Integer timeout) throws Exception {
        CacheIdentifier identifier = CacheIdentifier.builder().ip(url).username(username).password(password).build();
        Optional<JdbcConnect> cacheOption = connectionCommonCache.getCache(identifier, true);
        if (cacheOption.isPresent()) {
            JdbcConnect jdbcConnect = cacheOption.get();
            Statement cachedStatement = null;
            try {
                cachedStatement = jdbcConnect.getConnection().createStatement();
                applyStatementLimits(cachedStatement, timeout);
                return cachedStatement;
            } catch (Exception e) {
                log.info("The jdbc connect from cache, create statement error: {}", e.getMessage());
                // Close statement and connection independently so a failure on the first
                // does not prevent releasing the second.
                if (cachedStatement != null) {
                    try {
                        cachedStatement.close();
                    } catch (Exception e2) {
                        log.error(e2.getMessage());
                    }
                }
                try {
                    jdbcConnect.close();
                } catch (Exception e2) {
                    log.error(e2.getMessage());
                }
                connectionCommonCache.removeCache(identifier);
            }
        }

        Connection connection = DriverManager.getConnection(url, username, password);
        Statement statement = null;
        try {
            statement = connection.createStatement();
            applyStatementLimits(statement, timeout);
        } catch (Exception e) {
            // The connection has not been cached yet, so nothing else would ever close it.
            if (statement != null) {
                try {
                    statement.close();
                } catch (Exception closeError) {
                    e.addSuppressed(closeError);
                }
            }
            try {
                connection.close();
            } catch (Exception closeError) {
                e.addSuppressed(closeError);
            }
            throw e;
        }
        connectionCommonCache.addCache(identifier, new JdbcConnect(connection));
        return statement;
    }

    /**
     * Applies the query timeout (whole seconds, at least one) and the default row limit.
     *
     * @param statement statement to configure
     * @param timeout   timeout in milliseconds
     * @throws SQLException if the driver rejects either setting
     */
    private static void applyStatementLimits(Statement statement, Integer timeout) throws SQLException {
        int timeoutSecond = Math.max(timeout / 1000, 1);
        statement.setQueryTimeout(timeoutSecond);
        statement.setMaxRows(DEFAULT_MAX_ROWS);
    }

    /**
     * Builds one value row from the current result-set row.
     *
     * <p>Each entry of {@code columns} is looked up by label; {@code responseTime} is replaced by
     * the milliseconds elapsed since {@code startTime}, SQL NULL by {@link #NULL_VALUE}.
     */
    private static CollectRep.ValueRow readCurrentRow(ResultSet resultSet, List<String> columns, long startTime) throws SQLException {
        CollectRep.ValueRow.Builder valueRowBuilder = CollectRep.ValueRow.newBuilder();
        for (String column : columns) {
            if (RESPONSE_TIME.equals(column)) {
                valueRowBuilder.addColumns(String.valueOf(System.currentTimeMillis() - startTime));
            } else {
                String value = resultSet.getString(column);
                valueRowBuilder.addColumns(value == null ? NULL_VALUE : value);
            }
        }
        return valueRowBuilder.build();
    }

    /**
     * Executes {@code sql} and adds the first returned row, if any, to {@code builder}.
     */
    private void queryOneRow(Statement statement, String sql, List<String> columns, CollectRep.MetricsData.Builder builder, long startTime) throws Exception {
        statement.setMaxRows(1);
        try (ResultSet resultSet = statement.executeQuery(sql)) {
            if (resultSet.next()) {
                builder.addValues(readCurrentRow(resultSet, columns, startTime));
            }
        }
    }

    /**
     * Executes {@code sql}, treats every returned row as a (name, value) pair taken from the first
     * two columns, and adds a single row holding the values of the requested {@code columns}.
     *
     * <p>Names are matched case-insensitively; rows whose name is NULL are skipped, and requested
     * columns without a matching name get {@link #NULL_VALUE}.
     */
    private void queryOneRowByMatchTwoColumns(Statement statement, String sql, List<String> columns, CollectRep.MetricsData.Builder builder, long startTime) throws Exception {
        try (ResultSet resultSet = statement.executeQuery(sql)) {
            Map<String, String> values = new HashMap<>();
            while (resultSet.next()) {
                String name = resultSet.getString(1);
                if (name == null) {
                    continue;
                }
                // Locale.ROOT keeps the match independent of the JVM default locale.
                values.put(name.toLowerCase(Locale.ROOT).trim(), resultSet.getString(2));
            }
            CollectRep.ValueRow.Builder valueRowBuilder = CollectRep.ValueRow.newBuilder();
            for (String column : columns) {
                if (RESPONSE_TIME.equals(column)) {
                    valueRowBuilder.addColumns(String.valueOf(System.currentTimeMillis() - startTime));
                } else {
                    String value = values.get(column.toLowerCase(Locale.ROOT));
                    valueRowBuilder.addColumns(value == null ? NULL_VALUE : value);
                }
            }
            builder.addValues(valueRowBuilder.build());
        }
    }

    /**
     * Executes {@code sql} and adds every returned row to {@code builder}.
     */
    private void queryMultiRow(Statement statement, String sql, List<String> columns, CollectRep.MetricsData.Builder builder, long startTime) throws Exception {
        try (ResultSet resultSet = statement.executeQuery(sql)) {
            while (resultSet.next()) {
                builder.addValues(readCurrentRow(resultSet, columns, startTime));
            }
        }
    }

    /**
     * Resolves the jdbc url: an explicit url starting with {@code jdbc} wins, otherwise the url is
     * assembled from platform, host, port and database.
     *
     * @param jdbcProtocol jdbc parameters
     * @return the jdbc url to connect to
     * @throws IllegalArgumentException if the platform is missing or not supported
     */
    private String constructDatabaseUrl(JdbcProtocol jdbcProtocol) {
        String explicitUrl = jdbcProtocol.getUrl();
        if (explicitUrl != null && explicitUrl.startsWith("jdbc")) {
            return explicitUrl;
        }
        String platform = jdbcProtocol.getPlatform();
        if (platform == null) {
            // Switching on a null string would fail with a bare NullPointerException.
            throw new IllegalArgumentException("Not support database platform: " + platform);
        }
        String hostAndPort = jdbcProtocol.getHost() + ":" + jdbcProtocol.getPort();
        String database = Objects.toString(jdbcProtocol.getDatabase(), "");
        return switch (platform) {
            case "mysql", "mariadb" -> "jdbc:mysql://" + hostAndPort + "/" + database
                    + "?useUnicode=true&characterEncoding=utf-8&useSSL=false";
            case "postgresql" -> "jdbc:postgresql://" + hostAndPort + "/" + database;
            case "clickhouse" -> "jdbc:clickhouse://" + hostAndPort + "/" + database;
            case "sqlserver" -> "jdbc:sqlserver://" + hostAndPort + ";"
                    + (jdbcProtocol.getDatabase() == null ? "" : "DatabaseName=" + jdbcProtocol.getDatabase())
                    + ";trustServerCertificate=true;";
            case "oracle" -> "jdbc:oracle:thin:@" + hostAndPort + "/" + database;
            case "dm" -> "jdbc:dm://" + hostAndPort;
            default -> throw new IllegalArgumentException("Not support database platform: " + platform);
        };
    }
}

