/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.optim.aggregator;

import java.io.Serializable;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.feature.InstanceBlock;
import org.apache.spark.ml.linalg.BLAS$;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.math.Numeric;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;
import scala.runtime.java8.JFunction1;

@ScalaSignature(bytes="\u0006\u0001!4QAD\b\u0001'mA\u0001\"\f\u0001\u0003\u0002\u0003\u0006Ia\f\u0005\te\u0001\u0011\t\u0011)A\u0005_!A1\u0007\u0001B\u0001B\u0003%A\u0007\u0003\u00058\u0001\t\u0005\t\u0015!\u00039\u0011!\t\u0005A!A!\u0002\u0013A\u0004\u0002\u0003\"\u0001\u0005\u0003\u0005\u000b\u0011B\"\t\u000b)\u0003A\u0011A&\t\u000fM\u0003!\u0019!C\u0005)\"1\u0001\f\u0001Q\u0001\nUCq!\u0017\u0001C\u0002\u0013EC\u000b\u0003\u0004[\u0001\u0001\u0006I!\u0016\u0005\t7\u0002A)\u0019!C\u00059\")A\r\u0001C\u0001K\nY\"\t\\8dW2+\u0017m\u001d;TcV\f'/Z:BO\u001e\u0014XmZ1u_JT!\u0001E\t\u0002\u0015\u0005<wM]3hCR|'O\u0003\u0002\u0013'\u0005)q\u000e\u001d;j[*\u0011A#F\u0001\u0003[2T!AF\f\u0002\u000bM\u0004\u0018M]6\u000b\u0005aI\u0012AB1qC\u000eDWMC\u0001\u001b\u0003\ry'oZ\n\u0004\u0001q\u0011\u0003CA\u000f!\u001b\u0005q\"\"A\u0010\u0002\u000bM\u001c\u0017\r\\1\n\u0005\u0005r\"AB!osJ+g\r\u0005\u0003$I\u0019bS\"A\b\n\u0005\u0015z!\u0001\b#jM\u001a,'/\u001a8uS\u0006\u0014G.\u001a'pgN\fum\u001a:fO\u0006$xN\u001d\t\u0003O)j\u0011\u0001\u000b\u0006\u0003SM\tqAZ3biV\u0014X-\u0003\u0002,Q\ti\u0011J\\:uC:\u001cWM\u00117pG.\u0004\"a\t\u0001\u0002\u00111\f'-\u001a7Ti\u0012\u001c\u0001\u0001\u0005\u0002\u001ea%\u0011\u0011G\b\u0002\u0007\t>,(\r\\3\u0002\u00131\f'-\u001a7NK\u0006t\u0017\u0001\u00044ji&sG/\u001a:dKB$\bCA\u000f6\u0013\t1dDA\u0004C_>dW-\u00198\u0002\u001b\t\u001cg)Z1ukJ,7o\u0015;e!\rIDHP\u0007\u0002u)\u00111(F\u0001\nEJ|\u0017\rZ2bgRL!!\u0010\u001e\u0003\u0013\t\u0013x.\u00193dCN$\bcA\u000f@_%\u0011\u0001I\b\u0002\u0006\u0003J\u0014\u0018-_\u0001\u000fE\u000e4U-\u0019;ve\u0016\u001cX*Z1o\u00039\u00117mQ8fM\u001aL7-[3oiN\u00042!\u000f\u001fE!\t)\u0005*D\u0001G\u0015\t95#\u0001\u0004mS:\fGnZ\u0005\u0003\u0013\u001a\u0013aAV3di>\u0014\u0018A\u0002\u001fj]&$h\b\u0006\u0004M\u001d>\u0003\u0016K\u0015\u000b\u0003Y5CQAQ\u0004A\u0002\rCQ!L\u0004A\u0002=BQAM\u0004A\u0002=BQaM\u0004A\u0002QBQaN\u0004A\u0002aBQ!Q\u0004A\u0002a\n1B\\;n\r\u0016\fG/\u001e:fgV\tQ\u000b\u0005\u0002\u001e-&\u0011qK\b\u0002\u0004\u0013:$\u0018\u0001\u00048v[\u001a+\u0017\r^;sKN\u0004\u0013a\u00013j[\u0006!A-[7!\u0003Y)gMZ3di&4XmQ8fM\u0006sGm\u00144gg\u0016$X#A/\u0011\tuqFiL\u0005\u0003?z\u0011a\u0001V;qY\u0016\u0014\u0004F\u0001\u0007b!\ti\"-\u0003\u0002d=\tIAO]1og&,g\u000e^\u0001\u0004C\u0012$GC\u0001\u0017g\u0011\u00159W\u00021\u0001'\u0003\u0015\u0011Gn\\2l\u0001")
public class BlockLeastSquaresAggregator
implements DifferentiableLossAggregator<InstanceBlock, BlockLeastSquaresAggregator> {
    private transient Tuple2<Vector, Object> effectiveCoefAndOffset;
    private final double labelStd;
    private final double labelMean;
    private final boolean fitIntercept;
    private final Broadcast<double[]> bcFeaturesStd;
    private final Broadcast<double[]> bcFeaturesMean;
    private final Broadcast<Vector> bcCoefficients;
    private final int numFeatures;
    private final int dim;
    private double weightSum;
    private double lossSum;
    private double[] gradientSumArray;
    private volatile transient boolean bitmap$trans$0;
    private volatile boolean bitmap$0;

    @Override
    public DifferentiableLossAggregator merge(DifferentiableLossAggregator other) {
        return DifferentiableLossAggregator.merge$(this, other);
    }

    @Override
    public Vector gradient() {
        return DifferentiableLossAggregator.gradient$(this);
    }

    @Override
    public double weight() {
        return DifferentiableLossAggregator.weight$(this);
    }

    @Override
    public double loss() {
        return DifferentiableLossAggregator.loss$(this);
    }

    @Override
    public double weightSum() {
        return this.weightSum;
    }

    @Override
    public void weightSum_$eq(double x$1) {
        this.weightSum = x$1;
    }

    @Override
    public double lossSum() {
        return this.lossSum;
    }

    @Override
    public void lossSum_$eq(double x$1) {
        this.lossSum = x$1;
    }

    private double[] gradientSumArray$lzycompute() {
        BlockLeastSquaresAggregator blockLeastSquaresAggregator = this;
        synchronized (blockLeastSquaresAggregator) {
            if (!this.bitmap$0) {
                this.gradientSumArray = DifferentiableLossAggregator.gradientSumArray$(this);
                this.bitmap$0 = true;
            }
        }
        return this.gradientSumArray;
    }

    @Override
    public double[] gradientSumArray() {
        return !this.bitmap$0 ? this.gradientSumArray$lzycompute() : this.gradientSumArray;
    }

    private int numFeatures() {
        return this.numFeatures;
    }

    @Override
    public int dim() {
        return this.dim;
    }

    private Tuple2<Vector, Object> effectiveCoefAndOffset$lzycompute() {
        BlockLeastSquaresAggregator blockLeastSquaresAggregator = this;
        synchronized (blockLeastSquaresAggregator) {
            if (!this.bitmap$trans$0) {
                double[] coefficientsArray = (double[])((Vector)this.bcCoefficients.value()).toArray().clone();
                double[] featuresMean = (double[])this.bcFeaturesMean.value();
                double[] featuresStd = (double[])this.bcFeaturesStd.value();
                double sum = 0.0;
                int len = coefficientsArray.length;
                for (int i = 0; i < len; ++i) {
                    if (featuresStd[i] != 0.0) {
                        sum += coefficientsArray[i] / featuresStd[i] * featuresMean[i];
                        continue;
                    }
                    coefficientsArray[i] = 0.0;
                }
                double offset = this.fitIntercept ? this.labelMean / this.labelStd - sum : 0.0;
                this.effectiveCoefAndOffset = new Tuple2((Object)Vectors$.MODULE$.dense(coefficientsArray), (Object)BoxesRunTime.boxToDouble((double)offset));
                this.bitmap$trans$0 = true;
            }
        }
        return this.effectiveCoefAndOffset;
    }

    private Tuple2<Vector, Object> effectiveCoefAndOffset() {
        return !this.bitmap$trans$0 ? this.effectiveCoefAndOffset$lzycompute() : this.effectiveCoefAndOffset;
    }

    @Override
    public BlockLeastSquaresAggregator add(InstanceBlock block) {
        Predef$.MODULE$.require(block.matrix().isTransposed());
        Predef$.MODULE$.require(this.numFeatures() == block.numFeatures(), (Function0 & Serializable & scala.Serializable)() -> new StringBuilder(66).append("Dimensions mismatch when adding new ").append("instance. Expecting ").append(this.numFeatures()).append(" but got ").append(block.numFeatures()).append(".").toString());
        Predef$.MODULE$.require(block.weightIter().forall((Function1)(JFunction1.mcZD.sp & Serializable & scala.Serializable)x$1 -> x$1 >= 0.0), (Function0 & Serializable & scala.Serializable)() -> new StringBuilder(34).append("instance weights ").append(block.weightIter().mkString("[", ",", "]")).append(" has to be >= 0.0").toString());
        if (block.weightIter().forall((Function1)(JFunction1.mcZD.sp & Serializable & scala.Serializable)x$2 -> x$2 == 0.0)) {
            return this;
        }
        int size = block.size();
        Tuple2<Vector, Object> tuple2 = this.effectiveCoefAndOffset();
        if (tuple2 == null) {
            throw new MatchError(tuple2);
        }
        Vector effectiveCoefficientsVec = (Vector)tuple2._1();
        double offset = tuple2._2$mcD$sp();
        Tuple2 tuple22 = new Tuple2((Object)effectiveCoefficientsVec, (Object)BoxesRunTime.boxToDouble((double)offset));
        Tuple2 tuple23 = tuple22;
        Vector effectiveCoefficientsVec2 = (Vector)tuple23._1();
        double offset2 = tuple23._2$mcD$sp();
        DenseVector vec = new DenseVector((double[])Array$.MODULE$.tabulate(size, (Function1)(JFunction1.mcDI.sp & Serializable & scala.Serializable)i -> offset2 - block.getLabel(i) / $this.labelStd, ClassTag$.MODULE$.Double()));
        BLAS$.MODULE$.gemv(1.0, block.matrix(), effectiveCoefficientsVec2, 1.0, vec);
        double localLossSum = 0.0;
        for (int i2 = 0; i2 < size; ++i2) {
            double multiplier;
            double weight = block.getWeight().apply$mcDI$sp(i2);
            double diff = vec.apply(i2);
            localLossSum += weight * diff * diff / (double)2;
            vec.values()[i2] = multiplier = weight * diff;
        }
        this.lossSum_$eq(this.lossSum() + localLossSum);
        this.weightSum_$eq(this.weightSum() + BoxesRunTime.unboxToDouble((Object)block.weightIter().sum((Numeric)Numeric.DoubleIsFractional$.MODULE$)));
        DenseVector gradSumVec = new DenseVector(this.gradientSumArray());
        BLAS$.MODULE$.gemv(1.0, block.matrix().transpose(), (Vector)vec, 1.0, gradSumVec);
        return this;
    }

    public BlockLeastSquaresAggregator(double labelStd, double labelMean, boolean fitIntercept, Broadcast<double[]> bcFeaturesStd, Broadcast<double[]> bcFeaturesMean, Broadcast<Vector> bcCoefficients) {
        this.labelStd = labelStd;
        this.labelMean = labelMean;
        this.fitIntercept = fitIntercept;
        this.bcFeaturesStd = bcFeaturesStd;
        this.bcFeaturesMean = bcFeaturesMean;
        this.bcCoefficients = bcCoefficients;
        DifferentiableLossAggregator.$init$(this);
        Predef$.MODULE$.require(labelStd > 0.0, (Function0 & Serializable & scala.Serializable)() -> new StringBuilder(54).append(this.getClass().getName()).append(" requires the label standard ").append("deviation to be positive.").toString());
        this.numFeatures = ((double[])bcFeaturesStd.value()).length;
        this.dim = this.numFeatures();
    }
}

