/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.cf.taste.example.kddcup.track1.svd;

import java.util.Collection;
import java.util.Random;
import org.apache.mahout.cf.taste.common.Refreshable;
import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.example.kddcup.track1.svd.DataModelFactorizablePreferences;
import org.apache.mahout.cf.taste.example.kddcup.track1.svd.FactorizablePreferences;
import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.impl.recommender.svd.Factorization;
import org.apache.mahout.cf.taste.impl.recommender.svd.Factorizer;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.model.Preference;
import org.apache.mahout.common.RandomUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ParallelArraysSGDFactorizer
implements Factorizer {
    public static final double DEFAULT_LEARNING_RATE = 0.005;
    public static final double DEFAULT_PREVENT_OVERFITTING = 0.02;
    public static final double DEFAULT_RANDOM_NOISE = 0.005;
    private final int numFeatures;
    private final int numIterations;
    private final float minPreference;
    private final float maxPreference;
    private final Random random;
    private final double learningRate;
    private final double preventOverfitting;
    private final FastByIDMap<Integer> userIDMapping;
    private final FastByIDMap<Integer> itemIDMapping;
    private final double[][] userFeatures;
    private final double[][] itemFeatures;
    private final int[] userIndexes;
    private final int[] itemIndexes;
    private final float[] values;
    private final double defaultValue;
    private final double interval;
    private final double[] cachedEstimates;
    private static final Logger log = LoggerFactory.getLogger(ParallelArraysSGDFactorizer.class);

    public ParallelArraysSGDFactorizer(DataModel dataModel, int numFeatures, int numIterations) {
        this(new DataModelFactorizablePreferences(dataModel), numFeatures, numIterations, 0.005, 0.02, 0.005);
    }

    public ParallelArraysSGDFactorizer(DataModel dataModel, int numFeatures, int numIterations, double learningRate, double preventOverfitting, double randomNoise) {
        this(new DataModelFactorizablePreferences(dataModel), numFeatures, numIterations, learningRate, preventOverfitting, randomNoise);
    }

    public ParallelArraysSGDFactorizer(FactorizablePreferences factorizablePrefs, int numFeatures, int numIterations) {
        this(factorizablePrefs, numFeatures, numIterations, 0.005, 0.02, 0.005);
    }

    public ParallelArraysSGDFactorizer(FactorizablePreferences factorizablePreferences, int numFeatures, int numIterations, double learningRate, double preventOverfitting, double randomNoise) {
        this.numFeatures = numFeatures;
        this.numIterations = numIterations;
        this.minPreference = factorizablePreferences.getMinPreference();
        this.maxPreference = factorizablePreferences.getMaxPreference();
        this.random = RandomUtils.getRandom();
        this.learningRate = learningRate;
        this.preventOverfitting = preventOverfitting;
        int numUsers = factorizablePreferences.numUsers();
        int numItems = factorizablePreferences.numItems();
        int numPrefs = factorizablePreferences.numPreferences();
        log.info("Mapping {} users...", (Object)numUsers);
        this.userIDMapping = new FastByIDMap(numUsers);
        int index = 0;
        LongPrimitiveIterator userIterator = factorizablePreferences.getUserIDs();
        while (userIterator.hasNext()) {
            this.userIDMapping.put(userIterator.nextLong(), (Object)index++);
        }
        log.info("Mapping {} items", (Object)numItems);
        this.itemIDMapping = new FastByIDMap(numItems);
        index = 0;
        LongPrimitiveIterator itemIterator = factorizablePreferences.getItemIDs();
        while (itemIterator.hasNext()) {
            this.itemIDMapping.put(itemIterator.nextLong(), (Object)index++);
        }
        this.userIndexes = new int[numPrefs];
        this.itemIndexes = new int[numPrefs];
        this.values = new float[numPrefs];
        this.cachedEstimates = new double[numPrefs];
        index = 0;
        log.info("Loading {} preferences into memory", (Object)numPrefs);
        FullRunningAverage average = new FullRunningAverage();
        for (Preference preference : factorizablePreferences.getPreferences()) {
            this.userIndexes[index] = (Integer)this.userIDMapping.get(preference.getUserID());
            this.itemIndexes[index] = (Integer)this.itemIDMapping.get(preference.getItemID());
            this.values[index] = preference.getValue();
            this.cachedEstimates[index] = 0.0;
            average.addDatum((double)preference.getValue());
            if (++index % 1000000 != 0) continue;
            log.info("Processed {} preferences", (Object)index);
        }
        log.info("Processed {} preferences, done.", (Object)index);
        double averagePreference = average.getAverage();
        log.info("Average preference value is {}", (Object)averagePreference);
        double prefInterval = factorizablePreferences.getMaxPreference() - factorizablePreferences.getMinPreference();
        this.defaultValue = Math.sqrt((averagePreference - prefInterval * 0.1) / (double)numFeatures);
        this.interval = prefInterval * 0.1 / (double)numFeatures;
        this.userFeatures = new double[numUsers][numFeatures];
        this.itemFeatures = new double[numItems][numFeatures];
        log.info("Initializing feature vectors...");
        for (int feature = 0; feature < numFeatures; ++feature) {
            for (int userIndex = 0; userIndex < numUsers; ++userIndex) {
                this.userFeatures[userIndex][feature] = this.defaultValue + (this.random.nextDouble() - 0.5) * this.interval * randomNoise;
            }
            for (int itemIndex = 0; itemIndex < numItems; ++itemIndex) {
                this.itemFeatures[itemIndex][feature] = this.defaultValue + (this.random.nextDouble() - 0.5) * this.interval * randomNoise;
            }
        }
    }

    public Factorization factorize() throws TasteException {
        for (int feature = 0; feature < this.numFeatures; ++feature) {
            log.info("Shuffling preferences...");
            this.shufflePreferences();
            log.info("Starting training of feature {} ...", (Object)feature);
            for (int currentIteration = 0; currentIteration < this.numIterations; ++currentIteration) {
                if (currentIteration == this.numIterations - 1) {
                    double rmse = this.trainingIterationWithRmse(feature);
                    log.info("Finished training feature {} with RMSE {}", (Object)feature, (Object)rmse);
                    continue;
                }
                this.trainingIteration(feature);
            }
            if (feature >= this.numFeatures - 1) continue;
            log.info("Updating cache...");
            for (int index = 0; index < this.userIndexes.length; ++index) {
                this.cachedEstimates[index] = this.estimate(this.userIndexes[index], this.itemIndexes[index], feature, this.cachedEstimates[index], false);
            }
        }
        log.info("Factorization done");
        return new Factorization(this.userIDMapping, this.itemIDMapping, this.userFeatures, this.itemFeatures);
    }

    private void trainingIteration(int feature) {
        for (int index = 0; index < this.userIndexes.length; ++index) {
            this.train(this.userIndexes[index], this.itemIndexes[index], feature, this.values[index], this.cachedEstimates[index]);
        }
    }

    private double trainingIterationWithRmse(int feature) {
        double rmse = 0.0;
        for (int index = 0; index < this.userIndexes.length; ++index) {
            double error = this.train(this.userIndexes[index], this.itemIndexes[index], feature, this.values[index], this.cachedEstimates[index]);
            rmse += error * error;
        }
        return Math.sqrt(rmse / (double)this.userIndexes.length);
    }

    private double estimate(int userIndex, int itemIndex, int feature, double cachedEstimate, boolean trailing) {
        double sum = cachedEstimate;
        sum += this.userFeatures[userIndex][feature] * this.itemFeatures[itemIndex][feature];
        if (trailing) {
            if ((sum += (double)(this.numFeatures - feature - 1) * (this.defaultValue + this.interval) * (this.defaultValue + this.interval)) > (double)this.maxPreference) {
                sum = this.maxPreference;
            } else if (sum < (double)this.minPreference) {
                sum = this.minPreference;
            }
        }
        return sum;
    }

    public double train(int userIndex, int itemIndex, int feature, double original, double cachedEstimate) {
        double error = original - this.estimate(userIndex, itemIndex, feature, cachedEstimate, true);
        double[] userVector = this.userFeatures[userIndex];
        double[] itemVector = this.itemFeatures[itemIndex];
        int n = feature;
        userVector[n] = userVector[n] + this.learningRate * (error * itemVector[feature] - this.preventOverfitting * userVector[feature]);
        int n2 = feature;
        itemVector[n2] = itemVector[n2] + this.learningRate * (error * userVector[feature] - this.preventOverfitting * itemVector[feature]);
        return error;
    }

    protected void shufflePreferences() {
        for (int currentPos = this.userIndexes.length - 1; currentPos > 0; --currentPos) {
            int swapPos = this.random.nextInt(currentPos + 1);
            this.swapPreferences(currentPos, swapPos);
        }
    }

    private void swapPreferences(int posA, int posB) {
        int tmpUserIndex = this.userIndexes[posA];
        int tmpItemIndex = this.itemIndexes[posA];
        float tmpValue = this.values[posA];
        double tmpEstimate = this.cachedEstimates[posA];
        this.userIndexes[posA] = this.userIndexes[posB];
        this.itemIndexes[posA] = this.itemIndexes[posB];
        this.values[posA] = this.values[posB];
        this.cachedEstimates[posA] = this.cachedEstimates[posB];
        this.userIndexes[posB] = tmpUserIndex;
        this.itemIndexes[posB] = tmpItemIndex;
        this.values[posB] = tmpValue;
        this.cachedEstimates[posB] = tmpEstimate;
    }

    public void refresh(Collection<Refreshable> alreadyRefreshed) {
    }
}

