// Copyright 2018 The Cockroach Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
// implied. See the License for the specific language governing
// permissions and limitations under the License.

package xform

import (
	"github.com/cockroachdb/cockroach/pkg/sql/opt"
	"github.com/cockroachdb/cockroach/pkg/sql/opt/cat"
	"github.com/cockroachdb/cockroach/pkg/sql/opt/memo"
	"github.com/cockroachdb/cockroach/pkg/sql/opt/norm"
	"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror"
)

// indexScanBuilder composes a constrained, limited scan over a table index.
// Any filters are created as close to the scan as possible, and index joins can
// be used to scan a non-covering index. For example, in order to construct:
//
//   (IndexJoin
//     (Select (Scan $scanPrivate) $filters)
//     $indexJoinPrivate
//   )
//
// make the following calls:
//
//   var sb indexScanBuilder
//   sb.init(c, tabID)
//   sb.setScan(scanPrivate)
//   sb.addSelect(filters)
//   sb.addIndexJoin(cols)
//   sb.build(grp)
//
// build does not return the composed expression; it adds the outermost
// expression to the memo group grp. The inner parts are constructed through the
// factory.
//
// Filters recorded before addIndexJoin is called become the "inner" Select,
// which sits directly on the Scan. Filters recorded after it become the "outer"
// Select, which sits on the IndexJoin. At most one Select may be recorded on
// each side; addSelect panics otherwise.
//
// Fields:
//   - c, f, mem: the CustomFuncs, factory and memo captured by init.
//   - tabID: the table being scanned; also the IndexJoin's table.
//   - pkCols: lazily computed cache of the primary index key columns (see
//     primaryKeyCols).
//   - scanPrivate: the Scan definition recorded by setScan.
//   - innerFilters / outerFilters: filters below / above the index join.
//   - indexJoinPrivate: non-zero Table iff addIndexJoin has been called.
type indexScanBuilder struct {
	c                *CustomFuncs
	f                *norm.Factory
	mem              *memo.Memo
	tabID            opt.TableID
	pkCols           opt.ColSet
	scanPrivate      memo.ScanPrivate
	innerFilters     memo.FiltersExpr
	outerFilters     memo.FiltersExpr
	indexJoinPrivate memo.IndexJoinPrivate
}

// init prepares the builder to compose expressions over the table tabID, using
// the factory and memo reachable from c.
//
// Every field is overwritten, not just the four set explicitly. Omitted fields
// in the composite literal get their zero values. This means a builder value
// that is re-initialized cannot leak state from an earlier use. In particular,
// it cannot leak primary key columns cached for a different table by
// primaryKeyCols, or filters and an index join recorded for a previous scan.
func (b *indexScanBuilder) init(c *CustomFuncs, tabID opt.TableID) {
	*b = indexScanBuilder{
		c:     c,
		f:     c.e.f,
		mem:   c.e.mem,
		tabID: tabID,
	}
}

// primaryKeyCols returns the key columns of the scanned table's primary index.
//
// On the first call, the set is built from the table metadata in the memo
// captured by init. It is then cached in b.pkCols and returned directly on
// later calls. An empty cache is treated as "not yet computed".
//
// NOTE(review): the cached set is returned by value. If opt.ColSet can share
// backing storage between copies for large sets, callers should Copy it before
// mutating the result. Confirm against the ColSet implementation.
func (b *indexScanBuilder) primaryKeyCols() opt.ColSet {
	if b.pkCols.Empty() {
		// Use the memo cached by init, like the other builder methods do.
		primaryIndex := b.mem.Metadata().Table(b.tabID).Index(cat.PrimaryIndex)
		for i, cnt := 0, primaryIndex.KeyColumnCount(); i < cnt; i++ {
			b.pkCols.Add(int(b.tabID.ColumnID(primaryIndex.Column(i).Ordinal)))
		}
	}
	return b.pkCols
}

// setScan records the Scan that forms the base of the expression tree; nothing
// is constructed until build is called. The pointed-to ScanPrivate is copied
// into the builder so that the pointer itself does not escape.
//
// Any inner/outer filters and index join recorded by earlier uses of the
// builder are discarded, so each call starts a fresh expression.
func (b *indexScanBuilder) setScan(scanPrivate *memo.ScanPrivate) {
	// The right-hand side (including the dereference) is evaluated before any
	// field is assigned.
	b.scanPrivate, b.innerFilters, b.outerFilters, b.indexJoinPrivate =
		*scanPrivate, nil, nil, memo.IndexJoinPrivate{}
}

// addSelect records filters for a Select operator; an empty filter list is
// ignored. The filters are placed according to whether an index join has been
// recorded yet:
//   - Before addIndexJoin, they become the inner Select, directly on the Scan.
//   - After addIndexJoin, they become the outer Select, on the IndexJoin.
// It panics if filters were already recorded for that position.
func (b *indexScanBuilder) addSelect(filters memo.FiltersExpr) {
	if len(filters) == 0 {
		return
	}

	switch {
	case b.indexJoinPrivate.Table != 0:
		// An index join is present, so these filters go above it.
		if b.outerFilters != nil {
			panic(pgerror.NewAssertionErrorf("cannot call addSelect methods twice after index join is added"))
		}
		b.outerFilters = filters

	case b.innerFilters != nil:
		panic(pgerror.NewAssertionErrorf("cannot call addSelect methods twice before index join is added"))

	default:
		b.innerFilters = filters
	}
}

// addSelectAfterSplit splits filters into two parts:
//   - the conditions that involve only columns in cols, and
//   - the remaining conditions.
//
// The bound part is recorded with addSelect, so it lands in the inner or outer
// Select depending on whether an index join has been added. The remainder is
// returned to the caller.
//
// The result is nil when filters is empty or fully bound by cols. It is the
// unchanged filters when none of its conditions is bound by cols.
func (b *indexScanBuilder) addSelectAfterSplit(
	filters memo.FiltersExpr, cols opt.ColSet,
) (remainingFilters memo.FiltersExpr) {
	switch {
	case len(filters) == 0:
		return nil
	case b.c.FiltersBoundBy(filters, cols):
		// Nothing to split: the whole filter can be applied here.
		b.addSelect(filters)
		return nil
	}

	bound := b.c.ExtractBoundConditions(filters, cols)
	if len(bound) != 0 {
		b.addSelect(bound)
		return b.c.ExtractUnboundConditions(filters, cols)
	}

	// No condition is bound by cols, so nothing can be applied at this level.
	return filters
}

// addIndexJoin records an IndexJoin over the builder's table that produces
// cols by lookup in the primary index. The index join wraps the Scan and any
// inner filters recorded so far; filters added afterwards go above it.
//
// It panics if an index join was already recorded, or if outer filters are
// already present.
func (b *indexScanBuilder) addIndexJoin(cols opt.ColSet) {
	switch {
	case b.indexJoinPrivate.Table != 0:
		panic(pgerror.NewAssertionErrorf("cannot call addIndexJoin twice"))
	case b.outerFilters != nil:
		panic(pgerror.NewAssertionErrorf("cannot add index join after an outer filter has been added"))
	}

	b.indexJoinPrivate = memo.IndexJoinPrivate{Table: b.tabID, Cols: cols}
}

// build composes the expressions recorded by setScan, addSelect and
// addIndexJoin, and adds the outermost (root) expression to the memo group grp.
// Non-root children are constructed through the factory. The possible shapes
// are:
//
//   Scan
//   Select(Scan)
//   IndexJoin(Scan)
//   IndexJoin(Select(Scan))
//   Select(IndexJoin(Scan))
//   Select(IndexJoin(Select(Scan)))
func (b *indexScanBuilder) build(grp memo.RelExpr) {
	hasInner := len(b.innerFilters) != 0
	hasIndexJoin := b.indexJoinPrivate.Table != 0
	hasOuter := len(b.outerFilters) != 0

	// Bare scan: the Scan itself is the root.
	if !hasInner && !hasIndexJoin {
		b.mem.AddScanToGroup(&memo.ScanExpr{ScanPrivate: b.scanPrivate}, grp)
		return
	}

	// Every other shape has the Scan as a child.
	scan := b.f.ConstructScan(&b.scanPrivate)

	// Without an index join, the only remaining shape is Select(Scan). The
	// inner filters must be non-empty here, because the bare-scan case
	// returned above.
	if !hasIndexJoin {
		b.mem.AddSelectToGroup(&memo.SelectExpr{Input: scan, Filters: b.innerFilters}, grp)
		return
	}

	// An index join is present. Its input is the Scan, optionally wrapped in
	// the inner Select.
	joinInput := scan
	if hasInner {
		joinInput = b.f.ConstructSelect(scan, b.innerFilters)
	}

	if !hasOuter {
		// The index join is the root.
		indexJoin := &memo.IndexJoinExpr{Input: joinInput, IndexJoinPrivate: b.indexJoinPrivate}
		b.mem.AddIndexJoinToGroup(indexJoin, grp)
		return
	}

	// The outer Select is the root, over a child index join. Every path that
	// reaches this point has non-empty outer filters, so no extra assertion is
	// needed.
	indexJoin := b.f.ConstructIndexJoin(joinInput, &b.indexJoinPrivate)
	b.mem.AddSelectToGroup(&memo.SelectExpr{Input: indexJoin, Filters: b.outerFilters}, grp)
}
