/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.api.ml;

import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext$;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.sysml.api.ml.BaseSystemMLClassifierModel;
import org.apache.sysml.api.ml.PredictionUtils$;
import org.apache.sysml.api.mlcontext.MLContext;
import org.apache.sysml.api.mlcontext.MLResults;
import org.apache.sysml.api.mlcontext.Matrix;
import org.apache.sysml.api.mlcontext.MatrixMetadata;
import org.apache.sysml.api.mlcontext.Script;
import org.apache.sysml.api.mlcontext.ScriptFactory;
import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils;
import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.Seq;
import scala.runtime.BoxesRunTime;

/**
 * Static method bodies backing the {@code BaseSystemMLClassifierModel} trait. Each method receives the
 * trait instance as {@code $this} and reaches other trait members through it, so overrides on a concrete
 * model are honored.
 *
 * <p>NOTE(review): the {@code $class} suffix and the {@code $init$} hook suggest this is a Scala trait
 * implementation class recovered by a decompiler; edits here may be lost if it is regenerated.
 */
public abstract class BaseSystemMLClassifierModel$class {
    /** {@code C} forwarded by the short overloads; presumably means "unspecified" to predict_class (TODO confirm). */
    private static final int DEFAULT_C = -1;
    /** {@code H} forwarded by the short overloads. */
    private static final int DEFAULT_H = 1;
    /** {@code W} forwarded by the short overloads. */
    private static final int DEFAULT_W = 1;
    /** DML that sources nn/util.dml and maps the input matrix {@code Prob} to the output {@code Prediction}. */
    private static final String PREDICT_CLASS_DML =
            "source(\"nn/util.dml\") as util; Prediction = util::predict_class(Prob, C, H, W);";

    /**
     * Predicts labels for an in-memory matrix using the default shape ({@code C=-1, H=1, W=1}).
     * Forwards to the seven-argument overload through {@code $this}.
     */
    public static MatrixBlock baseTransform(BaseSystemMLClassifierModel $this, MatrixBlock X, SparkContext sc, String probVar) {
        return $this.baseTransform(X, sc, probVar, DEFAULT_C, DEFAULT_H, DEFAULT_W);
    }

    /**
     * Predicts labels for an in-memory matrix.
     *
     * <p>The probability matrix produced by {@code $this.baseTransformHelper} is bound as {@code Prob} to
     * {@code util::predict_class}, which runs in a freshly constructed {@link MLContext} (note that
     * {@code updateML} is not applied to that context).
     *
     * @return the {@code Prediction} matrix as a {@link MatrixBlock}
     * @throws RuntimeException if {@code H == 1}, {@code W == 1} and the result is not exactly one column wide
     */
    public static MatrixBlock baseTransform(BaseSystemMLClassifierModel $this, MatrixBlock X, SparkContext sc, String probVar, int C, int H, int W) {
        Matrix probabilities = $this.baseTransformHelper(X, sc, probVar, C, H, W);
        Script labelScript = bindShape(
                newPredictClassScript().in("Prob", probabilities.toMatrixBlock(), probabilities.getMatrixMetadata()),
                C, H, W);
        MLContext labelContext = new MLContext(sc);
        MatrixBlock labels = labelContext.execute(labelScript).getMatrix("Prediction").toMatrixBlock();
        // Only the non-spatial case (H == 1 && W == 1) is required to yield one label per row.
        boolean nonSpatial = H == 1 && W == 1;
        if (nonSpatial && labels.getNumColumns() != 1) {
            throw new RuntimeException("Expected predicted label to be a column vector");
        }
        return labels;
    }

    /**
     * Runs the model's prediction script on an in-memory matrix and returns the output named {@code probVar}.
     * The script is requested with {@code isSingleNode = true}, and {@code X} is bound together with its row,
     * column and non-zero counts. {@code C}, {@code H} and {@code W} are accepted but not used here.
     */
    public static Matrix baseTransformHelper(BaseSystemMLClassifierModel $this, MatrixBlock X, SparkContext sc, String probVar, int C, int H, int W) {
        MLContext context = new MLContext(sc);
        $this.updateML(context);
        Tuple2<Script, String> prediction = $this.getPredictionScript(true);
        Script scoring = prediction._1();
        String inputName = prediction._2();
        MatrixMetadata inputMetadata = new MatrixMetadata(
                Long.valueOf(X.getNumRows()),
                Long.valueOf(X.getNumColumns()),
                Long.valueOf(X.getNonZeros()));
        MLResults results = context.execute(scoring.in(inputName, X, inputMetadata));
        return results.getMatrix(probVar);
    }

    /**
     * Computes the {@code probVar} output for an in-memory matrix using the default shape
     * ({@code C=-1, H=1, W=1}). Forwards to the seven-argument overload through {@code $this}.
     */
    public static MatrixBlock baseTransformProbability(BaseSystemMLClassifierModel $this, MatrixBlock X, SparkContext sc, String probVar) {
        return $this.baseTransformProbability(X, sc, probVar, DEFAULT_C, DEFAULT_H, DEFAULT_W);
    }

    /** Returns the {@code probVar} output of {@code $this.baseTransformHelper} as a local {@link MatrixBlock}. */
    public static MatrixBlock baseTransformProbability(BaseSystemMLClassifierModel $this, MatrixBlock X, SparkContext sc, String probVar, int C, int H, int W) {
        Matrix probabilities = $this.baseTransformHelper(X, sc, probVar, C, H, W);
        return probabilities.toMatrixBlock();
    }

    /**
     * Predicts labels for a DataFrame using the default shape ({@code C=-1, H=1, W=1}).
     * Forwards to the eight-argument overload through {@code $this}.
     */
    public static Dataset baseTransform(BaseSystemMLClassifierModel $this, Dataset df, SparkContext sc, String probVar, boolean outputProb) {
        return $this.baseTransform(df, sc, probVar, outputProb, DEFAULT_C, DEFAULT_H, DEFAULT_W);
    }

    /**
     * Runs the model's prediction script on the {@code features} column of {@code df} and returns the
     * output named {@code probVar}.
     *
     * <p>The script is requested with {@code isSingleNode = false}. The column is converted to binary blocks
     * with {@code RDDConverterUtils.dataFrameToBinaryBlock}, whose matrix characteristics become the input
     * metadata. {@code outputProb}, {@code C}, {@code H} and {@code W} are accepted but not used here.
     */
    public static Matrix baseTransformHelper(BaseSystemMLClassifierModel $this, Dataset df, SparkContext sc, String probVar, boolean outputProb, int C, int H, int W) {
        MLContext context = new MLContext(sc);
        $this.updateML(context);
        // Populated by dataFrameToBinaryBlock; read only after the conversion call below.
        MatrixCharacteristics featureCharacteristics = new MatrixCharacteristics();
        JavaPairRDD<MatrixIndexes, MatrixBlock> featureBlocks = RDDConverterUtils.dataFrameToBinaryBlock(
                JavaSparkContext$.MODULE$.fromSparkContext(df.rdd().sparkContext()),
                (Dataset<Row>) df.select("features"),
                featureCharacteristics,
                false,
                true);
        Tuple2<Script, String> prediction = $this.getPredictionScript(false);
        Matrix featureMatrix = new Matrix(featureBlocks, new MatrixMetadata(featureCharacteristics));
        MLResults results = context.execute(prediction._1().in(prediction._2(), featureMatrix));
        return results.getMatrix(probVar);
    }

    /**
     * Predicts labels for a DataFrame and joins them back onto the input rows.
     *
     * <p>The {@code probVar} matrix from {@code $this.baseTransformHelper} is fed to {@code util::predict_class}
     * in a fresh {@link MLContext}. Its first column is exposed as {@code prediction}, keyed by
     * {@link RDDConverterUtils#DF_ID_COLUMN}. When {@code outputProb} is true, the probability matrix is also
     * joined in as a vector column named {@code probability}. The input {@code df} receives an ID column via
     * {@code RDDConverterUtilsExt.addIDToDataFrame}, and all parts are joined with
     * {@code PredictionUtils.joinUsingID}.
     */
    public static Dataset baseTransform(BaseSystemMLClassifierModel $this, Dataset df, SparkContext sc, String probVar, boolean outputProb, int C, int H, int W) {
        Matrix probabilities = $this.baseTransformHelper(df, sc, probVar, outputProb, C, H, W);
        Script labelScript = bindShape(newPredictClassScript().in("Prob", probabilities), C, H, W);
        MLResults labelResults = new MLContext(sc).execute(labelScript);
        Dataset<Row> labels = labelResults.getDataFrame("Prediction")
                .select(RDDConverterUtils.DF_ID_COLUMN, "C1")
                .withColumnRenamed("C1", "prediction");
        if (!outputProb) {
            return PredictionUtils$.MODULE$.joinUsingID(withRowIds(df), labels);
        }
        Dataset<Row> probabilityColumn = probabilities.toDFVectorWithIDColumn()
                .withColumnRenamed("C1", "probability")
                .select(RDDConverterUtils.DF_ID_COLUMN, "probability");
        // The ID-tagged input is evaluated before the inner join, matching the original call order.
        return PredictionUtils$.MODULE$.joinUsingID(
                withRowIds(df), PredictionUtils$.MODULE$.joinUsingID(probabilityColumn, labels));
    }

    /**
     * Returns {@code true}; the {@code $default$4} name suggests this is the Scala default-argument accessor,
     * presumably for {@code outputProb} (TODO confirm).
     */
    public static boolean baseTransform$default$4(BaseSystemMLClassifierModel $this) {
        return true;
    }

    /** Trait initializer hook; this implementation does nothing. */
    public static void $init$(BaseSystemMLClassifierModel $this) {
    }

    /** Creates the predict_class script with its {@code Prediction} output declared; callers bind {@code Prob}. */
    private static Script newPredictClassScript() {
        return ScriptFactory.dml(PREDICT_CLASS_DML).out("Prediction");
    }

    /** Binds the {@code C}, {@code H} and {@code W} scalar inputs of a predict_class script, in that order. */
    private static Script bindShape(Script script, int c, int h, int w) {
        return script.in("C", Integer.valueOf(c)).in("H", Integer.valueOf(h)).in("W", Integer.valueOf(w));
    }

    /** Adds the {@link RDDConverterUtils#DF_ID_COLUMN} column to {@code df} via {@code addIDToDataFrame}. */
    private static Dataset<Row> withRowIds(Dataset df) {
        return RDDConverterUtilsExt.addIDToDataFrame((Dataset<Row>) df, df.sparkSession(), RDDConverterUtils.DF_ID_COLUMN);
    }
}

