/*
 * Decompiled with CFR 0.152.
 */
package org.apache.samza.sql.translator;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import org.apache.calcite.adapter.enumerable.EnumerableTableScan;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.logical.LogicalJoin;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlExplainFormat;
import org.apache.calcite.sql.SqlExplainLevel;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.commons.lang.Validate;
import org.apache.samza.SamzaException;
import org.apache.samza.operators.KV;
import org.apache.samza.operators.MessageStream;
import org.apache.samza.operators.functions.MapFunction;
import org.apache.samza.operators.functions.StreamTableJoinFunction;
import org.apache.samza.serializers.JsonSerdeV2;
import org.apache.samza.serializers.KVSerde;
import org.apache.samza.serializers.Serde;
import org.apache.samza.sql.data.SamzaSqlCompositeKey;
import org.apache.samza.sql.interfaces.SqlIOConfig;
import org.apache.samza.sql.interfaces.SqlIOResolver;
import org.apache.samza.sql.serializers.SamzaSqlRelMessageSerdeFactory;
import org.apache.samza.sql.translator.SamzaSqlRelMessageJoinFunction;
import org.apache.samza.sql.translator.TranslatorContext;
import org.apache.samza.table.Table;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Translates a Calcite {@link LogicalJoin} with exactly one 'table' input and one 'stream'
 * input into Samza operators: the table-side messages are sent to a {@link Table}, and the
 * stream-side messages are re-partitioned by the join key and joined against that table.
 *
 * <p>Only INNER, LEFT and RIGHT joins whose condition is an equality between input references
 * (optionally combined with AND) are accepted; anything else is rejected with a
 * {@link SamzaException}.
 *
 * <p>NOTE(review): the raw {@code MessageStream}/{@code Table}/serde types and the intersection
 * casts on the lambdas are kept as they appear in this (decompiled) source; restoring the
 * generic parameters needs the project types, which are not visible here.
 */
class JoinTranslator {
    private static final Logger log = LoggerFactory.getLogger(JoinTranslator.class);

    /** Shared so the top-level and the nested always-true checks report the same text. */
    private static final String CROSS_JOIN_ERROR = "Query results in a cross join, which is not supported. Please optimize the query. It is expected that the joins should include JOIN ON operator in the sql query.";

    private final int joinId;
    private final SqlIOResolver ioResolver;

    /**
     * @param joinId     identifier appended to the {@code "stream_"} prefix to form the id of
     *                   the partition-by operator
     * @param ioResolver resolver used to look up the source config of each join input
     */
    JoinTranslator(int joinId, SqlIOResolver ioResolver) {
        this.joinId = joinId;
        this.ioResolver = ioResolver;
    }

    /**
     * Validates the join, wires the table side into its table, and registers the joined
     * output stream in {@code context} under the id of {@code join}.
     *
     * @param join    the join node to translate
     * @param context translator context providing the input streams and receiving the output
     * @throws SamzaException if the join type, input kinds, condition or key types are not supported
     */
    void translate(LogicalJoin join, TranslatorContext context) {
        this.validateJoinQuery(join);
        boolean isTablePosOnRight = this.isTable(join.getRight());

        // ArrayList rather than LinkedList: the key ids are read by index below.
        List<Integer> streamKeyIds = new ArrayList<Integer>();
        List<Integer> tableKeyIds = new ArrayList<Integer>();
        this.populateStreamAndTableKeyIds(((RexCall)join.getCondition()).getOperands(), join, isTablePosOnRight, streamKeyIds, tableKeyIds);

        Table table = this.loadLocalTable(isTablePosOnRight, tableKeyIds, join, context);

        RelNode streamNode = isTablePosOnRight ? join.getLeft() : join.getRight();
        RelNode tableNode = isTablePosOnRight ? join.getRight() : join.getLeft();
        MessageStream inputStream = context.getMessageStream(streamNode.getId());
        List<String> streamFieldNames = streamNode.getRowType().getFieldNames();
        List<String> tableFieldNames = tableNode.getRowType().getFieldNames();

        Validate.isTrue(streamKeyIds.size() == tableKeyIds.size(), "Stream and table join key counts differ.");
        log.info("Joining on the following Stream and Table field(s): ");
        for (int i = 0; i < streamKeyIds.size(); ++i) {
            log.info("{} with {}", streamFieldNames.get(streamKeyIds.get(i)), tableFieldNames.get(tableKeyIds.get(i)));
        }

        SamzaSqlRelMessageJoinFunction joinFn = new SamzaSqlRelMessageJoinFunction(join.getJoinType(), isTablePosOnRight, streamKeyIds, streamFieldNames, tableFieldNames);
        JsonSerdeV2 keySerde = new JsonSerdeV2(SamzaSqlCompositeKey.class);
        SamzaSqlRelMessageSerdeFactory.SamzaSqlRelMessageSerde valueSerde = (SamzaSqlRelMessageSerdeFactory.SamzaSqlRelMessageSerde)new SamzaSqlRelMessageSerdeFactory().getSerde(null, null);

        // Re-key the stream by the composite join key, drop the key again, then join with the table.
        MessageStream outputStream = inputStream
            .partitionBy(
                (MapFunction & Serializable)m -> SamzaSqlCompositeKey.createSamzaSqlCompositeKey(m, streamKeyIds),
                (MapFunction & Serializable)m -> m,
                KVSerde.of((Serde)keySerde, (Serde)valueSerde),
                "stream_" + this.joinId)
            .map(KV::getValue)
            .join(table, (StreamTableJoinFunction)joinFn);
        context.registerMessageStream(join.getId(), outputStream);
    }

    /**
     * Rejects joins this translator cannot handle: join types other than INNER/LEFT/RIGHT,
     * stream-stream and table-table joins, outer joins whose outer side is the table, cross
     * joins, and unsupported join conditions.
     *
     * @throws SamzaException describing the first violated rule
     */
    private void validateJoinQuery(LogicalJoin join) {
        JoinRelType joinRelType = join.getJoinType();
        if (joinRelType != JoinRelType.INNER && joinRelType != JoinRelType.LEFT && joinRelType != JoinRelType.RIGHT) {
            throw new SamzaException("Query with only INNER and LEFT/RIGHT OUTER join are supported.");
        }
        boolean isTablePosOnLeft = this.isTable(join.getLeft());
        boolean isTablePosOnRight = this.isTable(join.getRight());
        if (!isTablePosOnLeft && !isTablePosOnRight) {
            throw new SamzaException("Invalid query with both sides of join being denoted as 'stream'. Stream-stream join is not yet supported. " + this.dumpRelPlanForNode(join));
        }
        if (isTablePosOnLeft && isTablePosOnRight) {
            throw new SamzaException("Invalid query with both sides of join being denoted as 'table'. " + this.dumpRelPlanForNode(join));
        }
        // From here on exactly one side is the table.
        if (joinRelType == JoinRelType.LEFT && isTablePosOnLeft) {
            throw new SamzaException("Invalid query for outer left join. Left side of the join should be a 'stream' and right side of join should be a 'table'. " + this.dumpRelPlanForNode(join));
        }
        if (joinRelType == JoinRelType.RIGHT && isTablePosOnRight) {
            throw new SamzaException("Invalid query for outer right join. Left side of the join should be a 'table' and right side of join should be a 'stream'. " + this.dumpRelPlanForNode(join));
        }
        RexNode condition = join.getCondition();
        // An always-true condition need not be a RexCall (e.g. a literal TRUE), so it must be
        // detected before the RexCall check below to report the cross join accurately.
        if (condition.isAlwaysTrue()) {
            throw new SamzaException(CROSS_JOIN_ERROR);
        }
        this.validateJoinCondition(condition);
    }

    /**
     * Checks that {@code operand} is a call whose kind is EQUALS or AND and that it is not
     * always true.
     *
     * @throws SamzaException if the operand is not a supported join condition
     */
    private void validateJoinCondition(RexNode operand) {
        if (!(operand instanceof RexCall)) {
            throw new SamzaException("SQL Query is not supported. Join condition operand " + operand + " is of type " + operand.getClass());
        }
        RexCall condition = (RexCall)operand;
        if (condition.isAlwaysTrue()) {
            throw new SamzaException(CROSS_JOIN_ERROR);
        }
        if (condition.getKind() != SqlKind.EQUALS && condition.getKind() != SqlKind.AND) {
            throw new SamzaException("Only equi-joins and AND operator is supported in join condition.");
        }
    }

    /**
     * Walks the join condition and appends, for every leaf equality, the index of the
     * stream-side field to {@code streamKeyIds} and the index of the table-side field to
     * {@code tableKeyIds}. Both indices are relative to their own input's row type.
     *
     * @param operands          operands of the condition (or sub-condition) being visited
     * @param join              the join being translated
     * @param isTablePosOnRight whether the table is the right input of the join
     * @param streamKeyIds      output list of stream-side key field indices
     * @param tableKeyIds       output list of table-side key field indices
     * @throws SamzaException if an operand is not supported, a key type is not supported, or
     *                        both operands of an equality refer to the same join input
     */
    private void populateStreamAndTableKeyIds(List<RexNode> operands, LogicalJoin join, boolean isTablePosOnRight, List<Integer> streamKeyIds, List<Integer> tableKeyIds) {
        // Non-leaf level: every operand must itself be a supported condition; recurse into each.
        if (operands.get(0) instanceof RexCall) {
            for (RexNode operand : operands) {
                this.validateJoinCondition(operand);
                this.populateStreamAndTableKeyIds(((RexCall)operand).getOperands(), join, isTablePosOnRight, streamKeyIds, tableKeyIds);
            }
            return;
        }

        // Leaf level: a binary comparison between two input references.
        Validate.isTrue(operands.size() == 2, "Join condition comparison must have exactly two operands.");
        if (!(operands.get(0) instanceof RexInputRef) || !(operands.get(1) instanceof RexInputRef)) {
            throw new SamzaException("SQL query is not supported. Join condition " + join.getCondition() + " should have reference operands but the types are " + operands.get(0).getClass() + " and " + operands.get(1).getClass());
        }
        RexInputRef leftRef = (RexInputRef)operands.get(0);
        RexInputRef rightRef = (RexInputRef)operands.get(1);
        this.validateKey(leftRef);
        this.validateKey(rightRef);

        // Equality is commutative: order the references so leftRef has the smaller index.
        if (leftRef.getIndex() > rightRef.getIndex()) {
            RexInputRef tmpRef = leftRef;
            leftRef = rightRef;
            rightRef = tmpRef;
        }

        // Indices below leftFieldCount are treated as left-input fields and the rest as
        // right-input fields (the subtraction below relies on this). If both references fall
        // on the same side, the computed key index would be negative or out of range.
        int leftFieldCount = join.getLeft().getRowType().getFieldCount();
        if (leftRef.getIndex() >= leftFieldCount || rightRef.getIndex() < leftFieldCount) {
            throw new SamzaException("SQL query is not supported. Join condition " + join.getCondition() + " should compare a field from each side of the join, but both operands of a comparison refer to the same side.");
        }

        // Convert the right-side index into an index relative to the right input's row type.
        int deltaKeyIdx = rightRef.getIndex() - leftFieldCount;
        streamKeyIds.add(isTablePosOnRight ? leftRef.getIndex() : deltaKeyIdx);
        tableKeyIds.add(isTablePosOnRight ? deltaKeyIdx : leftRef.getIndex());
    }

    /**
     * Ensures the referenced field has one of the SQL types accepted as a join key
     * (BOOLEAN, TINYINT, SMALLINT, INTEGER, BIGINT, CHAR, VARCHAR, DOUBLE, FLOAT).
     *
     * @throws SamzaException if the field's type is anything else
     */
    private void validateKey(RexInputRef ref) {
        SqlTypeName sqlTypeName = ref.getType().getSqlTypeName();
        boolean supported = sqlTypeName == SqlTypeName.BOOLEAN
            || sqlTypeName == SqlTypeName.TINYINT
            || sqlTypeName == SqlTypeName.SMALLINT
            || sqlTypeName == SqlTypeName.INTEGER
            || sqlTypeName == SqlTypeName.CHAR
            || sqlTypeName == SqlTypeName.BIGINT
            || sqlTypeName == SqlTypeName.VARCHAR
            || sqlTypeName == SqlTypeName.DOUBLE
            || sqlTypeName == SqlTypeName.FLOAT;
        if (!supported) {
            log.error("Unsupported key type {} used in join condition.", sqlTypeName);
            // Carry the offending type in the exception so callers see it without the log.
            throw new SamzaException("Unsupported key type " + sqlTypeName + " used in join condition.");
        }
    }

    /** Renders the plan rooted at {@code relNode} as text for inclusion in error messages. */
    private String dumpRelPlanForNode(RelNode relNode) {
        return RelOptUtil.dumpPlan("Rel expression: ", relNode, SqlExplainFormat.TEXT, SqlExplainLevel.EXPPLAN_ATTRIBUTES);
    }

    /**
     * Looks up the source config for the table scanned by {@code relNode}, using the
     * dot-joined qualified table name as the source name.
     *
     * @throws SamzaException if the resolver returns no config for that source name
     */
    private SqlIOConfig resolveSourceConfig(RelNode relNode) {
        String sourceName = String.join(".", relNode.getTable().getQualifiedName());
        SqlIOConfig sourceConfig = this.ioResolver.fetchSourceInfo(sourceName);
        if (sourceConfig == null) {
            throw new SamzaException("Unsupported source found in join statement: " + sourceName);
        }
        return sourceConfig;
    }

    /**
     * Returns {@code true} only when {@code relNode} is an {@link EnumerableTableScan} whose
     * resolved source config carries a table descriptor; every other node counts as a stream.
     */
    private boolean isTable(RelNode relNode) {
        if (!(relNode instanceof EnumerableTableScan)) {
            return false;
        }
        return this.resolveSourceConfig(relNode).getTableDescriptor().isPresent();
    }

    /**
     * Obtains the table for the table-side input and sends that input's messages to it, keyed
     * by the composite key built from {@code tableKeyIds}.
     *
     * @param isTablePosOnRight whether the table is the right input of the join
     * @param tableKeyIds       indices of the table-side fields forming the key
     * @param join              the join being translated
     * @param context           translator context providing the input stream and application descriptor
     * @return the table the table-side messages are sent to
     * @throws SamzaException if the table-side source has no table descriptor
     */
    private Table loadLocalTable(boolean isTablePosOnRight, List<Integer> tableKeyIds, LogicalJoin join, TranslatorContext context) {
        RelNode relNode = isTablePosOnRight ? join.getRight() : join.getLeft();
        MessageStream relOutputStream = context.getMessageStream(relNode.getId());
        SqlIOConfig sourceConfig = this.resolveSourceConfig(relNode);
        if (!sourceConfig.getTableDescriptor().isPresent()) {
            String errMsg = "Failed to resolve table source in join operation: node=" + relNode;
            log.error(errMsg);
            throw new SamzaException(errMsg);
        }
        Table table = context.getStreamAppDescriptor().getTable(sourceConfig.getTableDescriptor().get());
        relOutputStream.map((MapFunction & Serializable)m -> new KV((Object)SamzaSqlCompositeKey.createSamzaSqlCompositeKey(m, tableKeyIds), m)).sendTo(table);
        return table;
    }
}

