/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.classify;

import edu.stanford.nlp.classify.GeneralDataset;
import edu.stanford.nlp.classify.LinearClassifier;
import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.optimization.AbstractCachingDiffFunction;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.Triple;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Objective function for training a {@link LinearClassifier} with a generalized-expectation (GE)
 * criterion over unlabeled data.
 *
 * <p>At construction time, for every GE feature {@code n} this class records:
 * <ul>
 *   <li>an empirical label distribution: counts of the labels co-occurring with the feature in the
 *       labeled dataset, normalized and smoothed;</li>
 *   <li>the indices of the unlabeled data that contain the feature.</li>
 * </ul>
 * {@link #calculate(double[])} then sets {@code value} to the sum over GE features of
 * {@code -sum_c empirical[n][c] * log(model[n][c])}. Here {@code model[n]} is the smoothed average
 * of the classifier's label probabilities over those unlabeled data. {@code derivative} is filled
 * with the gradient of that sum; the gradient uses the smoothed model distribution but does not
 * differentiate through the smoothing step.
 *
 * <p>Weights are flattened feature-major: the weight of feature {@code f} for class {@code c} lives
 * at index {@code f * numClasses + c} (see {@link #indexOf(int, int)}).
 *
 * @param <L> label type
 * @param <F> feature type
 */
public class GeneralizedExpectationObjectiveFunction<L, F>
extends AbstractCachingDiffFunction {
    private final GeneralDataset<L, F> labeledDataset;
    private final List<? extends Datum<L, F>> unlabeledDataList;
    private final List<F> geFeatures;
    /** Classifier whose weights are overwritten with the current point on every evaluation. */
    private final LinearClassifier<L, F> classifier;
    /** Row n: smoothed empirical label distribution of GE feature n, indexed by label id. */
    private double[][] geFeature2EmpiricalDist;
    /** Row n: indices into {@code unlabeledDataList} of the unlabeled data containing GE feature n. */
    private List<List<Integer>> geFeature2DatumList;
    /** Number of features in the labeled dataset, captured at construction. */
    protected int numFeatures = 0;
    /** Number of classes in the labeled dataset, captured at construction. */
    protected int numClasses = 0;

    /** Returns the number of weights: one per (feature, class) pair. */
    @Override
    public int domainDimension() {
        return this.numFeatures * this.numClasses;
    }

    /** Returns the class component of a flattened weight index. */
    int classOf(int index) {
        return index % this.numClasses;
    }

    /** Returns the feature component of a flattened weight index. */
    int featureOf(int index) {
        return index / this.numClasses;
    }

    /** Returns the flattened weight index of feature {@code f} and class {@code c}. */
    protected int indexOf(int f, int c) {
        return f * this.numClasses + c;
    }

    /**
     * Unflattens a weight vector into a {@code [numFeatures][numClasses]} matrix.
     *
     * @param x weights in the layout described by {@link #indexOf(int, int)}
     * @return a new matrix with {@code x2[f][c] == x[indexOf(f, c)]}
     */
    public double[][] to2D(double[] x) {
        double[][] x2 = new double[this.numFeatures][this.numClasses];
        for (int i = 0; i < this.numFeatures; ++i) {
            for (int j = 0; j < this.numClasses; ++j) {
                x2[i][j] = x[this.indexOf(i, j)];
            }
        }
        return x2;
    }

    /**
     * Computes the GE objective and its gradient at {@code x}, storing them in {@code value} and
     * {@code derivative}.
     *
     * <p>Notation for a GE feature with active unlabeled data {@code D}:
     * <ul>
     *   <li>{@code p_d} is the model's label distribution for datum {@code d};</li>
     *   <li>{@code m} is the smoothed mean of {@code p_d} over {@code D};</li>
     *   <li>{@code e} is the feature's empirical distribution;</li>
     *   <li>{@code w(c) = e(c) / m(c)}.</li>
     * </ul>
     * The feature adds {@code -sum_c e(c) log m(c)} to the value. It adds
     * {@code (1/|D|) sum_d x_{d,f} p_d(c) (sum_c' w(c') p_d(c') - w(c))} to the partial derivative
     * for feature {@code f} and class {@code c}. Each GE feature's gradient term is computed and
     * normalized independently of the other GE features.
     *
     * @param x weights in the layout described by {@link #indexOf(int, int)}
     */
    @Override
    protected void calculate(double[] x) {
        this.classifier.setWeights(this.to2D(x));
        if (this.derivative == null) {
            this.derivative = new double[x.length];
        } else {
            Arrays.fill(this.derivative, 0.0);
        }
        this.value = 0.0;
        for (int n = 0; n < this.geFeatures.size(); ++n) {
            this.addGeFeatureTerm(n);
        }
    }

    /**
     * Adds GE feature {@code n}'s term to {@code value} and {@code derivative}. Does nothing when no
     * unlabeled datum contains the feature.
     */
    private void addGeFeatureTerm(int n) {
        List<Integer> activeData = this.geFeature2DatumList.get(n);
        int numActive = activeData.size();
        if (numActive == 0) {
            return;
        }

        // First pass: model distribution of each active datum, and their average.
        double[][] datumProbs = new double[numActive][];
        double[] modelDist = new double[this.numClasses];
        for (int i = 0; i < numActive; ++i) {
            double[] probs = this.getModelProbs(this.unlabeledDataList.get(activeData.get(i)));
            datumProbs[i] = probs;
            for (int c = 0; c < this.numClasses; ++c) {
                modelDist[c] += probs[c];
            }
        }
        for (int c = 0; c < this.numClasses; ++c) {
            modelDist[c] /= numActive;
        }
        this.smoothDistribution(modelDist);

        // Cross-entropy term, plus w(c) = e(c) / m(c) for the gradient. Smoothing keeps m(c) > 0.
        double[] empiricalDist = this.geFeature2EmpiricalDist[n];
        double[] ratio = new double[this.numClasses];
        for (int c = 0; c < this.numClasses; ++c) {
            this.value += -empiricalDist[c] * Math.log(modelDist[c]);
            ratio[c] = empiricalDist[c] / modelDist[c];
        }

        // Second pass: this GE feature's gradient, averaged over its own active data only.
        double scale = 1.0 / numActive;
        for (int i = 0; i < numActive; ++i) {
            this.addDatumGradient(this.unlabeledDataList.get(activeData.get(i)), datumProbs[i], ratio, scale);
        }
    }

    /**
     * Adds one active datum's share of a GE feature's gradient to {@code derivative}.
     *
     * <p>For weight (f, c) the share is {@code scale * x_f * p(c) * (sum_c' w(c') p(c') - w(c))}.
     * This equals {@code sum_c' T(c, c') w(c')} with {@code T(c, c) = -p(c)(1 - p(c)) x_f} and
     * {@code T(c, c') = p(c) p(c') x_f} for {@code c' != c}.
     *
     * @param datum an unlabeled datum containing the GE feature
     * @param probs the model's label distribution for {@code datum}, indexed by label id
     * @param ratio the GE feature's empirical-to-model probability ratio {@code w(c)}, per class
     * @param scale the GE feature's averaging factor {@code 1 / |D|}
     */
    private void addDatumGradient(Datum<L, F> datum, double[] probs, double[] ratio, double scale) {
        double expectedRatio = 0.0;
        for (int c = 0; c < this.numClasses; ++c) {
            expectedRatio += ratio[c] * probs[c];
        }
        for (F feature : datum.asFeatures()) {
            int fID = this.labeledDataset.featureIndex.indexOf(feature);
            if (fID < 0) {
                continue; // not in the labeled feature index: there is no weight to update
            }
            double featureValue = scale * this.valueOfFeature(feature, datum);
            for (int c = 0; c < this.numClasses; ++c) {
                this.derivative[this.indexOf(fID, c)] += featureValue * probs[c] * (expectedRatio - ratio[c]);
            }
        }
    }

    /**
     * Returns the value of {@code feature} in {@code datum}. For an {@link RVFDatum} this is the
     * feature's count in its feature counter; for any other datum it is 1.
     */
    private double valueOfFeature(F feature, Datum<L, F> datum) {
        if (datum instanceof RVFDatum) {
            return ((RVFDatum<L, F>) datum).asFeaturesCounter().getCount(feature);
        }
        return 1.0;
    }

    /**
     * Fills {@code geFeature2EmpiricalDist} and {@code geFeature2DatumList} for the given GE features,
     * and prints the number of distinct unlabeled data containing at least one of them.
     */
    private void computeEmpiricalStatistics(List<F> geFeatures) {
        int numGeFeatures = geFeatures.size();
        this.geFeature2EmpiricalDist = new double[numGeFeatures][this.labeledDataset.labelIndex.size()];
        this.geFeature2DatumList = new ArrayList<List<Integer>>(numGeFeatures);
        HashMap<F, Integer> geFeatureMap = new HashMap<F, Integer>();
        HashSet<Integer> activeUnlabeledExamples = new HashSet<Integer>();
        for (int n = 0; n < numGeFeatures; ++n) {
            this.geFeature2DatumList.add(new ArrayList<Integer>());
            geFeatureMap.put(geFeatures.get(n), n);
        }

        // Count how often each label co-occurs with each GE feature in the labeled data.
        for (int i = 0; i < this.labeledDataset.size(); ++i) {
            Datum<L, F> datum = this.labeledDataset.getDatum(i);
            int labelID = this.labeledDataset.labelIndex.indexOf(datum.label());
            for (F feature : datum.asFeatures()) {
                Integer geFnum = geFeatureMap.get(feature);
                if (geFnum != null) {
                    this.geFeature2EmpiricalDist[geFnum][labelID] += 1.0;
                }
            }
        }
        // NOTE(review): a GE feature that never occurs in the labeled data leaves an all-zero row
        // here; confirm how ArrayMath.normalize treats a zero-sum array.
        for (int n = 0; n < numGeFeatures; ++n) {
            ArrayMath.normalize(this.geFeature2EmpiricalDist[n]);
            this.smoothDistribution(this.geFeature2EmpiricalDist[n]);
        }

        // Record which unlabeled data contain each GE feature (once per occurrence in asFeatures()).
        for (int i = 0; i < this.unlabeledDataList.size(); ++i) {
            Datum<L, F> datum = this.unlabeledDataList.get(i);
            for (F feature : datum.asFeatures()) {
                Integer geFnum = geFeatureMap.get(feature);
                if (geFnum != null) {
                    this.geFeature2DatumList.get(geFnum).add(i);
                    activeUnlabeledExamples.add(i);
                }
            }
        }
        System.out.println("Number of active unlabeled examples:" + activeUnlabeledExamples.size());
    }

    /** Adds 1e-6 to every entry of {@code dist} and renormalizes it in place. */
    private void smoothDistribution(double[] dist) {
        double epsilon = 1.0E-6;
        for (int i = 0; i < dist.length; ++i) {
            dist[i] += epsilon;
        }
        ArrayMath.normalize(dist);
    }

    /**
     * Returns the probabilities reported by {@code classifier.probabilityOf(datum)}, as an array
     * indexed by label id. Labels absent from that counter are left at 0.
     */
    private double[] getModelProbs(Datum<L, F> datum) {
        double[] condDist = new double[this.labeledDataset.numClasses()];
        Counter<L> probCounter = this.classifier.probabilityOf(datum);
        for (L label : probCounter.keySet()) {
            int labelID = this.labeledDataset.labelIndex.indexOf(label);
            condDist[labelID] = probCounter.getCount(label);
        }
        return condDist;
    }

    /**
     * Creates the objective and precomputes the GE statistics. Dataset sizes and the number of
     * active unlabeled examples are printed to standard output.
     *
     * @param labeledDataset labeled data; supplies the feature/label indices and empirical distributions
     * @param unlabeledDataList unlabeled data over which model expectations are taken
     * @param geFeatures the features whose label distributions are constrained
     */
    public GeneralizedExpectationObjectiveFunction(GeneralDataset<L, F> labeledDataset, List<? extends Datum<L, F>> unlabeledDataList, List<F> geFeatures) {
        System.out.println("Number of labeled examples:" + labeledDataset.size() + "\nNumber of unlabeled examples:" + unlabeledDataList.size());
        System.out.println("Number of GE features:" + geFeatures.size());
        this.numFeatures = labeledDataset.numFeatures();
        this.numClasses = labeledDataset.numClasses();
        this.labeledDataset = labeledDataset;
        this.unlabeledDataList = unlabeledDataList;
        this.geFeatures = geFeatures;
        // Weights start as null; calculate() installs the current point before every use.
        this.classifier = new LinearClassifier<L, F>(null, labeledDataset.featureIndex, labeledDataset.labelIndex);
        this.computeEmpiricalStatistics(geFeatures);
    }
}

