/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.optim.aggregator;

import java.io.Serializable;
import java.util.Arrays;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.internal.Logging;
import org.apache.spark.ml.feature.InstanceBlock;
import org.apache.spark.ml.impl.Utils$;
import org.apache.spark.ml.linalg.BLAS$;
import org.apache.spark.ml.linalg.DenseMatrix;
import org.apache.spark.ml.linalg.DenseMatrix$;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.DenseVector$;
import org.apache.spark.ml.linalg.Matrix;
import org.apache.spark.ml.linalg.SparseMatrix;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator;
import org.slf4j.Logger;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.Function3;
import scala.MatchError;
import scala.Option;
import scala.Predef$;
import scala.collection.mutable.ArrayOps;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.java8.JFunction1;

@ScalaSignature(bytes="\u0006\u0001\u0005\u0015a!\u0002\u000b\u0016\u0001e\t\u0003\u0002C\u001d\u0001\u0005\u0003\u0005\u000b\u0011B\u001e\t\u0011\u001d\u0003!\u0011!Q\u0001\nmB\u0001\u0002\u0013\u0001\u0003\u0002\u0003\u0006I!\u0013\u0005\t\u0019\u0002\u0011\t\u0011)A\u0005\u0013\"AQ\n\u0001B\u0001B\u0003%a\nC\u0003V\u0001\u0011\u0005a\u000bC\u0004^\u0001\t\u0007I\u0011\u00020\t\r\t\u0004\u0001\u0015!\u0003`\u0011\u001d\u0019\u0007A1A\u0005RyCa\u0001\u001a\u0001!\u0002\u0013y\u0006bB3\u0001\u0005\u0004%IA\u0018\u0005\u0007M\u0002\u0001\u000b\u0011B0\t\u000f\u001d\u0004!\u0019!C\u0005=\"1\u0001\u000e\u0001Q\u0001\n}C\u0001\"\u001b\u0001\t\u0006\u0004%IA\u001b\u0005\t_\u0002A)\u0019!C\u0005a\"AQ\u000f\u0001EC\u0002\u0013%a\u000f\u0003\u0005|\u0001!\u0015\r\u0011\"\u0003w\u0011\u0015i\b\u0001\"\u0001\u007f\u0005\tjU\u000f\u001c;j]>l\u0017.\u00197M_\u001eL7\u000f^5d\u00052|7m[!hOJ,w-\u0019;pe*\u0011acF\u0001\u000bC\u001e<'/Z4bi>\u0014(B\u0001\r\u001a\u0003\u0015y\u0007\u000f^5n\u0015\tQ2$\u0001\u0002nY*\u0011A$H\u0001\u0006gB\f'o\u001b\u0006\u0003=}\ta!\u00199bG\",'\"\u0001\u0011\u0002\u0007=\u0014xm\u0005\u0003\u0001E!\u001a\u0004CA\u0012'\u001b\u0005!#\"A\u0013\u0002\u000bM\u001c\u0017\r\\1\n\u0005\u001d\"#AB!osJ+g\r\u0005\u0003*U1\u0012T\"A\u000b\n\u0005-*\"\u0001\b#jM\u001a,'/\u001a8uS\u0006\u0014G.\u001a'pgN\fum\u001a:fO\u0006$xN\u001d\t\u0003[Aj\u0011A\f\u0006\u0003_e\tqAZ3biV\u0014X-\u0003\u00022]\ti\u0011J\\:uC:\u001cWM\u00117pG.\u0004\"!\u000b\u0001\u0011\u0005Q:T\"A\u001b\u000b\u0005YZ\u0012\u0001C5oi\u0016\u0014h.\u00197\n\u0005a*$a\u0002'pO\u001eLgnZ\u0001\rE\u000eLeN^3sg\u0016\u001cF\u000fZ\u0002\u0001!\rat(Q\u0007\u0002{)\u0011ahG\u0001\nEJ|\u0017\rZ2bgRL!\u0001Q\u001f\u0003\u0013\t\u0013x.\u00193dCN$\bcA\u0012C\t&\u00111\t\n\u0002\u0006\u0003J\u0014\u0018-\u001f\t\u0003G\u0015K!A\u0012\u0013\u0003\r\u0011{WO\u00197f\u00031\u00117mU2bY\u0016$W*Z1o\u000311\u0017\u000e^%oi\u0016\u00148-\u001a9u!\t\u0019#*\u0003\u0002LI\t9!i\\8mK\u0006t\u0017a\u00034ji^KG\u000f['fC:\faBY2D_\u00164g-[2jK:$8\u000fE\u0002=\u007f=\u0003\"\u0001U*\u000e\u0003ES!AU\r\u0002\r1Lg.\u00197h\u0013\t!\u0016K\u0001\u0004WK\u000e$xN]\u0001\u0007y%t\u0017\u000e\u001e \u0015\u000b]K&l\u0017/\u0015\u0005IB\u0006\"B'\u0007\u0001\u0004q\u0005\"B\u001d\u0007\u0001\u0004Y\u0004\"B$\u0007\u0001\u0004Y\u0004\"\u0002%\u0007\u0001\u0004I\u0005\"\u0002'\u0007\u0001\u0004I\u0015a\u00038v[\u001a+\u0017\r^;sKN,\u0012a\u0018\t\u0003G\u0001L!!\u0019\u0013\u0003\u0007%sG/\u0001\u0007ok64U-\u0019;ve\u0016\u001c\b%A\u0002eS6\fA\u0001Z5nA\u0005Ab.^7GK\u0006$XO]3t!2,8/\u00138uKJ\u001cW\r\u001d;\u000239,XNR3biV\u0014Xm\u001d)mkNLe\u000e^3sG\u0016\u0004H\u000fI\u0001\u000b]Vl7\t\\1tg\u0016\u001c\u0018a\u00038v[\u000ec\u0017m]:fg\u0002\n\u0011cY8fM\u001aL7-[3oiN\f%O]1z+\u0005\t\u0005FA\bm!\t\u0019S.\u0003\u0002oI\tIAO]1og&,g\u000e^\u0001\u0007Y&tW-\u0019:\u0016\u0003E\u0004\"\u0001\u0015:\n\u0005M\f&a\u0003#f]N,W*\u0019;sSbD#\u0001\u00057\u0002\u0013%tG/\u001a:dKB$X#A<\u0011\u0005AC\u0018BA=R\u0005-!UM\\:f-\u0016\u001cGo\u001c:)\u0005Ea\u0017\u0001D7be\u001eLgn\u00144gg\u0016$\bF\u0001\nm\u0003\r\tG\r\u001a\u000b\u0004\u007f\u0006\u0005Q\"\u0001\u0001\t\r\u0005\r1\u00031\u0001-\u0003\u0015\u0011Gn\\2l\u0001")
public class MultinomialLogisticBlockAggregator
implements DifferentiableLossAggregator<InstanceBlock, MultinomialLogisticBlockAggregator>,
Logging {
    private transient double[] coefficientsArray;
    private transient DenseMatrix linear;
    private transient DenseVector intercept;
    private transient DenseVector marginOffset;
    private final Broadcast<double[]> bcScaledMean;
    private final boolean fitIntercept;
    private final boolean fitWithMean;
    private final Broadcast<Vector> bcCoefficients;
    private final int numFeatures;
    private final int dim;
    private final int numFeaturesPlusIntercept;
    private final int numClasses;
    private transient Logger org$apache$spark$internal$Logging$$log_;
    private double weightSum;
    private double lossSum;
    private double[] gradientSumArray;
    private volatile transient byte bitmap$trans$0;
    private volatile boolean bitmap$0;

    public String logName() {
        return Logging.logName$((Logging)this);
    }

    public Logger log() {
        return Logging.log$((Logging)this);
    }

    public void logInfo(Function0<String> msg) {
        Logging.logInfo$((Logging)this, msg);
    }

    public void logDebug(Function0<String> msg) {
        Logging.logDebug$((Logging)this, msg);
    }

    public void logTrace(Function0<String> msg) {
        Logging.logTrace$((Logging)this, msg);
    }

    public void logWarning(Function0<String> msg) {
        Logging.logWarning$((Logging)this, msg);
    }

    public void logError(Function0<String> msg) {
        Logging.logError$((Logging)this, msg);
    }

    public void logInfo(Function0<String> msg, Throwable throwable) {
        Logging.logInfo$((Logging)this, msg, (Throwable)throwable);
    }

    public void logDebug(Function0<String> msg, Throwable throwable) {
        Logging.logDebug$((Logging)this, msg, (Throwable)throwable);
    }

    public void logTrace(Function0<String> msg, Throwable throwable) {
        Logging.logTrace$((Logging)this, msg, (Throwable)throwable);
    }

    public void logWarning(Function0<String> msg, Throwable throwable) {
        Logging.logWarning$((Logging)this, msg, (Throwable)throwable);
    }

    public void logError(Function0<String> msg, Throwable throwable) {
        Logging.logError$((Logging)this, msg, (Throwable)throwable);
    }

    public boolean isTraceEnabled() {
        return Logging.isTraceEnabled$((Logging)this);
    }

    public void initializeLogIfNecessary(boolean isInterpreter) {
        Logging.initializeLogIfNecessary$((Logging)this, (boolean)isInterpreter);
    }

    public boolean initializeLogIfNecessary(boolean isInterpreter, boolean silent) {
        return Logging.initializeLogIfNecessary$((Logging)this, (boolean)isInterpreter, (boolean)silent);
    }

    public boolean initializeLogIfNecessary$default$2() {
        return Logging.initializeLogIfNecessary$default$2$((Logging)this);
    }

    public void initializeForcefully(boolean isInterpreter, boolean silent) {
        Logging.initializeForcefully$((Logging)this, (boolean)isInterpreter, (boolean)silent);
    }

    @Override
    public DifferentiableLossAggregator merge(DifferentiableLossAggregator other) {
        return DifferentiableLossAggregator.merge$(this, other);
    }

    @Override
    public Vector gradient() {
        return DifferentiableLossAggregator.gradient$(this);
    }

    @Override
    public double weight() {
        return DifferentiableLossAggregator.weight$(this);
    }

    @Override
    public double loss() {
        return DifferentiableLossAggregator.loss$(this);
    }

    public Logger org$apache$spark$internal$Logging$$log_() {
        return this.org$apache$spark$internal$Logging$$log_;
    }

    public void org$apache$spark$internal$Logging$$log__$eq(Logger x$1) {
        this.org$apache$spark$internal$Logging$$log_ = x$1;
    }

    @Override
    public double weightSum() {
        return this.weightSum;
    }

    @Override
    public void weightSum_$eq(double x$1) {
        this.weightSum = x$1;
    }

    @Override
    public double lossSum() {
        return this.lossSum;
    }

    @Override
    public void lossSum_$eq(double x$1) {
        this.lossSum = x$1;
    }

    private double[] gradientSumArray$lzycompute() {
        MultinomialLogisticBlockAggregator multinomialLogisticBlockAggregator = this;
        synchronized (multinomialLogisticBlockAggregator) {
            if (!this.bitmap$0) {
                this.gradientSumArray = DifferentiableLossAggregator.gradientSumArray$(this);
                this.bitmap$0 = true;
            }
        }
        return this.gradientSumArray;
    }

    @Override
    public double[] gradientSumArray() {
        return !this.bitmap$0 ? this.gradientSumArray$lzycompute() : this.gradientSumArray;
    }

    private int numFeatures() {
        return this.numFeatures;
    }

    @Override
    public int dim() {
        return this.dim;
    }

    private int numFeaturesPlusIntercept() {
        return this.numFeaturesPlusIntercept;
    }

    private int numClasses() {
        return this.numClasses;
    }

    private double[] coefficientsArray$lzycompute() {
        MultinomialLogisticBlockAggregator multinomialLogisticBlockAggregator = this;
        synchronized (multinomialLogisticBlockAggregator) {
            if ((byte)(this.bitmap$trans$0 & 1) == 0) {
                double[] values;
                DenseVector denseVector;
                Option option;
                Vector vector = (Vector)this.bcCoefficients.value();
                if (!(vector instanceof DenseVector) || (option = DenseVector$.MODULE$.unapply(denseVector = (DenseVector)vector)).isEmpty()) {
                    throw new IllegalArgumentException(new StringBuilder(55).append("coefficients only supports dense vector but ").append("got type ").append(this.bcCoefficients.value().getClass()).append(".)").toString());
                }
                double[] dArray = values = (double[])option.get();
                this.coefficientsArray = dArray;
                this.bitmap$trans$0 = (byte)(this.bitmap$trans$0 | 1);
            }
        }
        return this.coefficientsArray;
    }

    private double[] coefficientsArray() {
        return (byte)(this.bitmap$trans$0 & 1) == 0 ? this.coefficientsArray$lzycompute() : this.coefficientsArray;
    }

    private DenseMatrix linear$lzycompute() {
        MultinomialLogisticBlockAggregator multinomialLogisticBlockAggregator = this;
        synchronized (multinomialLogisticBlockAggregator) {
            if ((byte)(this.bitmap$trans$0 & 2) == 0) {
                this.linear = this.fitIntercept ? new DenseMatrix(this.numClasses(), this.numFeatures(), (double[])new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(this.coefficientsArray())).take(this.numClasses() * this.numFeatures())) : new DenseMatrix(this.numClasses(), this.numFeatures(), this.coefficientsArray());
                this.bitmap$trans$0 = (byte)(this.bitmap$trans$0 | 2);
            }
        }
        return this.linear;
    }

    private DenseMatrix linear() {
        return (byte)(this.bitmap$trans$0 & 2) == 0 ? this.linear$lzycompute() : this.linear;
    }

    private DenseVector intercept$lzycompute() {
        MultinomialLogisticBlockAggregator multinomialLogisticBlockAggregator = this;
        synchronized (multinomialLogisticBlockAggregator) {
            if ((byte)(this.bitmap$trans$0 & 4) == 0) {
                this.intercept = this.fitIntercept ? new DenseVector((double[])new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(this.coefficientsArray())).takeRight(this.numClasses())) : null;
                this.bitmap$trans$0 = (byte)(this.bitmap$trans$0 | 4);
            }
        }
        return this.intercept;
    }

    private DenseVector intercept() {
        return (byte)(this.bitmap$trans$0 & 4) == 0 ? this.intercept$lzycompute() : this.intercept;
    }

    /*
     * WARNING - void declaration
     */
    private DenseVector marginOffset$lzycompute() {
        MultinomialLogisticBlockAggregator multinomialLogisticBlockAggregator = this;
        synchronized (multinomialLogisticBlockAggregator) {
            if ((byte)(this.bitmap$trans$0 & 8) == 0) {
                Object v0;
                if (this.fitWithMean) {
                    void var2_2;
                    DenseVector offset = new DenseVector((double[])new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(this.coefficientsArray())).takeRight(this.numClasses()));
                    BLAS$.MODULE$.gemv(-1.0, (Matrix)this.linear(), Vectors$.MODULE$.dense((double[])this.bcScaledMean.value()), 1.0, offset);
                    v0 = var2_2;
                } else {
                    v0 = null;
                }
                this.marginOffset = v0;
                this.bitmap$trans$0 = (byte)(this.bitmap$trans$0 | 8);
            }
        }
        return this.marginOffset;
    }

    private DenseVector marginOffset() {
        return (byte)(this.bitmap$trans$0 & 8) == 0 ? this.marginOffset$lzycompute() : this.marginOffset;
    }

    @Override
    public MultinomialLogisticBlockAggregator add(InstanceBlock block) {
        block13: {
            BoxedUnit boxedUnit;
            Predef$.MODULE$.require(block.matrix().isTransposed());
            Predef$.MODULE$.require(this.numFeatures() == block.numFeatures(), (Function0 & Serializable & scala.Serializable)() -> new StringBuilder(66).append("Dimensions mismatch when adding new ").append("instance. Expecting ").append(this.numFeatures()).append(" but got ").append(block.numFeatures()).append(".").toString());
            Predef$.MODULE$.require(block.weightIter().forall((Function1)(JFunction1.mcZD.sp & Serializable & scala.Serializable)x$1 -> x$1 >= 0.0), (Function0 & Serializable & scala.Serializable)() -> new StringBuilder(34).append("instance weights ").append(block.weightIter().mkString("[", ",", "]")).append(" has to be >= 0.0").toString());
            if (block.weightIter().forall((Function1)(JFunction1.mcZD.sp & Serializable & scala.Serializable)x$2 -> x$2 == 0.0)) {
                return this;
            }
            int size = block.size();
            DenseMatrix mat = DenseMatrix$.MODULE$.zeros(size, this.numClasses());
            double[] arr = mat.values();
            if (this.fitIntercept) {
                DenseVector offset = this.fitWithMean ? this.marginOffset() : this.intercept();
                for (int j2 = 0; j2 < this.numClasses(); ++j2) {
                    if (offset.apply(j2) == 0.0) continue;
                    Arrays.fill(arr, j2 * size, (j2 + 1) * size, offset.apply(j2));
                }
            }
            BLAS$.MODULE$.gemm(1.0, block.matrix(), this.linear().transpose(), 1.0, mat);
            double localLossSum = 0.0;
            double localWeightSum = 0.0;
            for (int i2 = 0; i2 < size; ++i2) {
                double weight = block.getWeight().apply$mcDI$sp(i2);
                localWeightSum += weight;
                if (weight > 0.0) {
                    int labelIndex = i2 + (int)block.getLabel(i2) * size;
                    Utils$.MODULE$.softmax(arr, this.numClasses(), i2, size, arr);
                    localLossSum -= weight * package$.MODULE$.log(arr[labelIndex]);
                    if (weight != 1.0) {
                        BLAS$.MODULE$.javaBLAS().dscal(this.numClasses(), weight, arr, i2, size);
                    }
                    arr[labelIndex] = arr[labelIndex] - weight;
                    continue;
                }
                BLAS$.MODULE$.javaBLAS().dscal(this.numClasses(), 0.0, arr, i2, size);
            }
            this.lossSum_$eq(this.lossSum() + localLossSum);
            this.weightSum_$eq(this.weightSum() + localWeightSum);
            Matrix matrix = block.matrix();
            if (matrix instanceof DenseMatrix) {
                DenseMatrix denseMatrix = (DenseMatrix)matrix;
                BLAS$.MODULE$.nativeBLAS().dgemm("T", "T", this.numClasses(), this.numFeatures(), size, 1.0, mat.values(), size, denseMatrix.values(), this.numFeatures(), 1.0, this.gradientSumArray(), this.numClasses());
                boxedUnit = BoxedUnit.UNIT;
            } else if (matrix instanceof SparseMatrix) {
                SparseMatrix sparseMatrix = (SparseMatrix)matrix;
                DenseMatrix linearGradSumMat = DenseMatrix$.MODULE$.zeros(this.numFeatures(), this.numClasses());
                BLAS$.MODULE$.gemm(1.0, (Matrix)sparseMatrix.transpose(), mat, 0.0, linearGradSumMat);
                linearGradSumMat.foreachActive((Function3 & Serializable & scala.Serializable)(i, j, v) -> {
                    MultinomialLogisticBlockAggregator.$anonfun$add$5(this, BoxesRunTime.unboxToInt((Object)i), BoxesRunTime.unboxToInt((Object)j), BoxesRunTime.unboxToDouble((Object)v));
                    return BoxedUnit.UNIT;
                });
                boxedUnit = BoxedUnit.UNIT;
            } else {
                throw new MatchError((Object)matrix);
            }
            if (!this.fitIntercept) break block13;
            double[] multiplierSum = (double[])Array$.MODULE$.ofDim(this.numClasses(), ClassTag$.MODULE$.Double());
            for (int j3 = 0; j3 < this.numClasses(); ++j3) {
                int i3;
                int end = i3 + size;
                for (i3 = j3 * size; i3 < end; ++i3) {
                    int n = j3;
                    multiplierSum[n] = multiplierSum[n] + arr[i3];
                }
            }
            if (this.fitWithMean) {
                BLAS$.MODULE$.nativeBLAS().dger(this.numClasses(), this.numFeatures(), -1.0, multiplierSum, 1, (double[])this.bcScaledMean.value(), 1, this.gradientSumArray(), this.numClasses());
            }
            BLAS$.MODULE$.javaBLAS().daxpy(this.numClasses(), 1.0, multiplierSum, 0, 1, this.gradientSumArray(), this.numClasses() * this.numFeatures(), 1);
        }
        return this;
    }

    public static final /* synthetic */ void $anonfun$add$5(MultinomialLogisticBlockAggregator $this, int i, int j, double v) {
        int n = i * $this.numClasses() + j;
        $this.gradientSumArray()[n] = $this.gradientSumArray()[n] + v;
    }

    public MultinomialLogisticBlockAggregator(Broadcast<double[]> bcInverseStd, Broadcast<double[]> bcScaledMean, boolean fitIntercept, boolean fitWithMean, Broadcast<Vector> bcCoefficients) {
        this.bcScaledMean = bcScaledMean;
        this.fitIntercept = fitIntercept;
        this.fitWithMean = fitWithMean;
        this.bcCoefficients = bcCoefficients;
        DifferentiableLossAggregator.$init$(this);
        Logging.$init$((Logging)this);
        if (fitWithMean) {
            Predef$.MODULE$.require(fitIntercept, (Function0 & Serializable & scala.Serializable)() -> "for training without intercept, should not center the vectors");
            Predef$.MODULE$.require(bcScaledMean != null && ((double[])bcScaledMean.value()).length == ((double[])bcInverseStd.value()).length, (Function0 & Serializable & scala.Serializable)() -> "scaled means is required when center the vectors");
        }
        this.numFeatures = ((double[])bcInverseStd.value()).length;
        this.dim = ((Vector)bcCoefficients.value()).size();
        this.numFeaturesPlusIntercept = fitIntercept ? this.numFeatures() + 1 : this.numFeatures();
        this.numClasses = this.dim() / this.numFeaturesPlusIntercept();
        Predef$.MODULE$.require(this.dim() == this.numClasses() * this.numFeaturesPlusIntercept());
    }
}

