/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.api.ml;

import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext$;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.sysml.api.ml.BaseSystemMLClassifierModel;
import org.apache.sysml.api.ml.PredictionUtils$;
import org.apache.sysml.api.mlcontext.BinaryBlockMatrix;
import org.apache.sysml.api.mlcontext.MLContext;
import org.apache.sysml.api.mlcontext.MLResults;
import org.apache.sysml.api.mlcontext.MatrixMetadata;
import org.apache.sysml.api.mlcontext.Script;
import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils;
import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.Seq;

/**
 * Static implementation holder for the methods of {@link BaseSystemMLClassifierModel}.
 *
 * <p>NOTE(review): the {@code $class} suffix, the explicit {@code $this} receiver, and the
 * {@code $init$} / {@code $default$N} members look like the Scala 2.11 trait encoding as
 * recovered by a decompiler. Treat this file as generated output rather than hand-written API.
 */
public abstract class BaseSystemMLClassifierModel$class {

    /**
     * Scores a single in-memory matrix and returns the predicted class labels.
     *
     * <p>The prediction script is built in single-node mode. {@code X} is bound to the script's
     * input variable together with row, column and non-zero metadata. The script is executed.
     * The "Prediction" matrix is then derived from the output named by {@code probVar}.
     *
     * @param $this    the model instance supplying {@code getPredictionScript}
     * @param X        feature matrix to score
     * @param mloutput results from training, passed through to {@code getPredictionScript}
     * @param sc       Spark context used to build the {@link MLContext}
     * @param probVar  name of the probability output variable in the script results
     * @return the predicted labels as a matrix block
     * @throws RuntimeException if the predicted labels do not have exactly one column
     */
    public static MatrixBlock baseTransform(BaseSystemMLClassifierModel $this, MatrixBlock X, MLResults mloutput, SparkContext sc, String probVar) {
        MLContext context = new MLContext(sc);
        final boolean singleNode = true;

        Tuple2<Script, String> scriptAndInput = $this.getPredictionScript(mloutput, singleNode);
        Script predictionScript = (Script) scriptAndInput._1();
        String inputName = (String) scriptAndInput._2();
        MatrixMetadata inputMeta = new MatrixMetadata(
                Long.valueOf(X.getNumRows()),
                Long.valueOf(X.getNumColumns()),
                Long.valueOf(X.getNonZeros()));

        MLResults scored = context.execute(predictionScript.in(inputName, X, inputMeta));
        MatrixBlock labels = PredictionUtils$.MODULE$
                .computePredictedClassLabelsFromProbability(scored, singleNode, sc, probVar)
                .getBinaryBlockMatrix("Prediction")
                .getMatrixBlock();

        // Guard: labels are expected to be a column vector; anything wider is rejected.
        if (labels.getNumColumns() == 1) {
            return labels;
        }
        throw new RuntimeException("Expected predicted label to be a column vector");
    }

    /**
     * Scores a DataFrame in distributed mode and returns it with a "prediction" column attached.
     *
     * <p>Processing steps:
     * <ul>
     *   <li>The "features" column of {@code df} is converted to binary blocks.</li>
     *   <li>The prediction script is run in non-single-node mode.</li>
     *   <li>The script's "Prediction" frame is reduced to the row-ID column plus {@code C1},
     *       and {@code C1} is renamed to "prediction".</li>
     *   <li>When {@code outputProb} is set, the frame named by {@code probVar} contributes a
     *       "probability" column as well.</li>
     *   <li>The result is joined back onto {@code df} after an ID column
     *       ({@code RDDConverterUtils.DF_ID_COLUMN}) is added to it.</li>
     * </ul>
     *
     * @param $this      the model instance supplying {@code getPredictionScript}
     * @param df         input DataFrame; must contain a "features" column
     * @param mloutput   results from training, passed through to {@code getPredictionScript}
     * @param sc         Spark context used to build the {@link MLContext}
     * @param probVar    name of the probability output variable in the script results
     * @param outputProb whether to also attach a "probability" column
     * @return {@code df} (with ID column) joined with the prediction (and optionally probability) columns
     */
    public static Dataset baseTransform(BaseSystemMLClassifierModel $this, Dataset df, MLResults mloutput, SparkContext sc, String probVar, boolean outputProb) {
        final boolean singleNode = false;
        MLContext context = new MLContext(sc);

        // The converter fills in mcIn as a side effect.
        // NOTE(review): the trailing false/true flags are passed through unchanged; presumably
        // "input has no ID column" and "features are a vector column" — confirm against RDDConverterUtils.
        MatrixCharacteristics mcIn = new MatrixCharacteristics();
        JavaPairRDD<MatrixIndexes, MatrixBlock> blocks = RDDConverterUtils.dataFrameToBinaryBlock(
                JavaSparkContext$.MODULE$.fromSparkContext(df.rdd().sparkContext()),
                (Dataset<Row>) df.select("features", columnSeq()),
                mcIn, false, true);

        Tuple2<Script, String> scriptAndInput = $this.getPredictionScript(mloutput, singleNode);
        BinaryBlockMatrix blockInput = new BinaryBlockMatrix(blocks, mcIn);
        Script predictionScript = (Script) scriptAndInput._1();
        MLResults scored = context.execute(predictionScript.in((String) scriptAndInput._2(), blockInput));

        MLResults labelled = PredictionUtils$.MODULE$.computePredictedClassLabelsFromProbability(scored, singleNode, sc, probVar);
        Dataset<Row> predictions = (Dataset<Row>) labelled.getDataFrame("Prediction")
                .select(RDDConverterUtils.DF_ID_COLUMN, columnSeq("C1"))
                .withColumnRenamed("C1", "prediction");

        // The probability frame is only materialised when requested. It is built before the
        // ID column is added to df, which preserves the original call order.
        Dataset<Row> probabilities = null;
        if (outputProb) {
            probabilities = (Dataset<Row>) scored.getDataFrame(probVar, true)
                    .withColumnRenamed("C1", "probability")
                    .select(RDDConverterUtils.DF_ID_COLUMN, columnSeq("probability"));
        }

        Dataset<Row> withIds = RDDConverterUtilsExt.addIDToDataFrame((Dataset<Row>) df, df.sparkSession(), RDDConverterUtils.DF_ID_COLUMN);
        Dataset<Row> scoredColumns = outputProb
                ? PredictionUtils$.MODULE$.joinUsingID(probabilities, predictions)
                : predictions;
        return PredictionUtils$.MODULE$.joinUsingID(withIds, scoredColumns);
    }

    /**
     * Default value for the fifth parameter ({@code outputProb}) of the DataFrame
     * {@code baseTransform} overload.
     *
     * @param $this unused model instance
     * @return always {@code true}
     */
    public static boolean baseTransform$default$5(BaseSystemMLClassifierModel $this) {
        final boolean outputProbDefault = true;
        return outputProbDefault;
    }

    /**
     * Trait initializer hook; this trait has no state to initialize, so it does nothing.
     *
     * @param $this the instance being initialized (unused)
     */
    public static void $init$(BaseSystemMLClassifierModel $this) {
        // intentionally empty
    }

    /** Wraps column names as the Scala {@code Seq} expected by {@code Dataset.select(String, Seq)}. */
    private static Seq columnSeq(String... names) {
        return (Seq) Predef$.MODULE$.wrapRefArray((Object[]) names);
    }
}

