/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.ie.crf;

import edu.stanford.nlp.ie.crf.CRFLabel;
import edu.stanford.nlp.ie.crf.FactorTable;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.sequences.SequenceListener;
import edu.stanford.nlp.sequences.SequenceModel;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.GeneralizedCounter;
import edu.stanford.nlp.util.Index;
import java.util.Arrays;

public class CRFCliqueTree
implements SequenceModel,
SequenceListener {
    private FactorTable[] factorTables;
    private double z;
    private Index classIndex;
    private String backgroundSymbol;
    private int backgroundIndex;
    private int windowSize;
    private int numClasses;
    private int[] possibleValues;

    private CRFCliqueTree() {
    }

    private CRFCliqueTree(FactorTable[] factorTables, Index classIndex, String backgroundSymbol) {
        this.factorTables = factorTables;
        this.classIndex = classIndex;
        this.backgroundSymbol = backgroundSymbol;
        this.backgroundIndex = classIndex.indexOf(backgroundSymbol);
        this.z = factorTables[0].totalMass();
        this.windowSize = factorTables[0].windowSize();
        this.numClasses = classIndex.size();
        this.possibleValues = new int[this.numClasses];
        for (int i = 0; i < this.numClasses; ++i) {
            this.possibleValues[i] = i;
        }
    }

    public Index classIndex() {
        return this.classIndex;
    }

    public int length() {
        return this.factorTables.length;
    }

    public int leftWindow() {
        return this.windowSize;
    }

    public int rightWindow() {
        return 0;
    }

    public int[] getPossibleValues(int position) {
        return this.possibleValues;
    }

    public double scoreOf(int[] sequence, int pos) {
        return this.scoresOf(sequence, pos)[sequence[pos]];
    }

    public double[] scoresOf(int[] sequence, int position) {
        int i;
        if (position >= this.factorTables.length) {
            throw new RuntimeException("Index out of bounds: " + position);
        }
        double[] probThisGivenPrev = new double[this.numClasses];
        double[] probNextGivenThis = new double[this.numClasses];
        int prevLength = this.windowSize - 1;
        int[] prev = new int[prevLength + 1];
        for (i = 0; i < prevLength - position; ++i) {
            prev[i] = this.classIndex.indexOf(this.backgroundSymbol);
        }
        while (i < prevLength) {
            prev[i] = sequence[position - prevLength + i];
            ++i;
        }
        for (int label = 0; label < this.numClasses; ++label) {
            prev[prev.length - 1] = label;
            probThisGivenPrev[label] = this.factorTables[position].unnormalizedLogProb(prev);
        }
        int nextLength = this.windowSize - 1;
        if (position + nextLength >= this.length()) {
            nextLength = this.length() - position - 1;
        }
        FactorTable nextFactorTable = this.factorTables[position + nextLength];
        if (nextLength != this.windowSize - 1) {
            for (int j = 0; j < this.windowSize - 1 - nextLength; ++j) {
                nextFactorTable = nextFactorTable.sumOutFront();
            }
        }
        if (nextLength == 0) {
            Arrays.fill(probNextGivenThis, 1.0);
        } else {
            int[] next = new int[nextLength];
            System.arraycopy(sequence, position + 1, next, 0, nextLength);
            for (int label = 0; label < this.numClasses; ++label) {
                probNextGivenThis[label] = nextFactorTable.unnormalizedConditionalLogProbGivenFirst(label, next);
            }
        }
        return ArrayMath.pairwiseAdd(probThisGivenPrev, probNextGivenThis);
    }

    public double scoreOf(int[] sequence) {
        int[] given = new int[this.window() - 1];
        Arrays.fill(given, this.classIndex.indexOf(this.backgroundSymbol));
        double logProb = 0.0;
        for (int i = 0; i < this.length(); ++i) {
            int label = sequence[i];
            logProb += this.condLogProbGivenPrevious(i, label, given);
            System.arraycopy(given, 1, given, 0, given.length - 1);
            given[given.length - 1] = label;
        }
        return logProb;
    }

    public int window() {
        return this.windowSize;
    }

    public int getNumClasses() {
        return this.numClasses;
    }

    public double totalMass() {
        return this.z;
    }

    public int backgroundIndex() {
        return this.backgroundIndex;
    }

    public String backgroundSymbol() {
        return this.backgroundSymbol;
    }

    public double logProb(int position, int label) {
        double u = this.factorTables[position].unnormalizedLogProbEnd(label);
        return u - this.z;
    }

    public double prob(int position, int label) {
        return Math.exp(this.logProb(position, label));
    }

    public double logProb(int position, Object label) {
        return this.logProb(position, this.classIndex.indexOf(label));
    }

    public double prob(int position, Object label) {
        return Math.exp(this.logProb(position, label));
    }

    public ClassicCounter probs(int position) {
        ClassicCounter c = new ClassicCounter();
        for (int i = 0; i < this.classIndex.size(); ++i) {
            Object label = this.classIndex.get(i);
            c.incrementCount(label, this.prob(position, i));
        }
        return c;
    }

    public ClassicCounter logProbs(int position) {
        ClassicCounter c = new ClassicCounter();
        for (int i = 0; i < this.classIndex.size(); ++i) {
            Object label = this.classIndex.get(i);
            c.incrementCount(label, this.logProb(position, i));
        }
        return c;
    }

    public double logProb(int position, int[] labels) {
        if (labels.length < this.windowSize) {
            return this.factorTables[position].unnormalizedLogProbEnd(labels) - this.z;
        }
        if (labels.length == this.windowSize) {
            return this.factorTables[position].unnormalizedLogProb(labels) - this.z;
        }
        int[] l = new int[this.windowSize];
        System.arraycopy(labels, 0, l, 0, l.length);
        int position1 = position - labels.length + this.windowSize;
        double p = this.factorTables[position1].unnormalizedLogProb(l) - this.z;
        l = new int[this.windowSize - 1];
        System.arraycopy(labels, 1, l, 0, l.length);
        ++position1;
        for (int i = this.windowSize; i < labels.length; ++i) {
            p += this.condLogProbGivenPrevious(position1++, labels[i], l);
            System.arraycopy(l, 1, l, 0, l.length - 1);
            l[this.windowSize - 2] = labels[i];
        }
        return p;
    }

    public double prob(int position, int[] labels) {
        return Math.exp(this.logProb(position, labels));
    }

    public double logProb(int position, Object[] labels) {
        return this.logProb(position, this.objectArrayToIntArray(labels));
    }

    public double prob(int position, Object[] labels) {
        return Math.exp(this.logProb(position, labels));
    }

    /*
     * Unable to fully structure code
     */
    public GeneralizedCounter logProbs(int position, int window) {
        gc = new GeneralizedCounter<Object>(window);
        labels = new int[window];
        block0: while (true) {
            labelsList = Arrays.asList(this.intArrayToObjectArray(labels));
            gc.incrementCount(labelsList, this.logProb(position, labels));
            i = 0;
            while (true) {
                if (i >= labels.length) continue block0;
                v0 = i;
                labels[v0] = labels[v0] + 1;
                if (labels[i] >= this.numClasses) ** break;
                continue block0;
                if (i == labels.length - 1) break block0;
                labels[i] = 0;
                ++i;
            }
            break;
        }
        return gc;
    }

    /*
     * Unable to fully structure code
     */
    public GeneralizedCounter probs(int position, int window) {
        gc = new GeneralizedCounter<Object>(window);
        labels = new int[window];
        block0: while (true) {
            labelsList = Arrays.asList(this.intArrayToObjectArray(labels));
            gc.incrementCount(labelsList, this.prob(position, labels));
            i = 0;
            while (true) {
                if (i >= labels.length) continue block0;
                v0 = i;
                labels[v0] = labels[v0] + 1;
                if (labels[i] >= this.numClasses) ** break;
                continue block0;
                if (i == labels.length - 1) break block0;
                labels[i] = 0;
                ++i;
            }
            break;
        }
        return gc;
    }

    private int[] objectArrayToIntArray(Object[] os) {
        int[] is = new int[os.length];
        for (int i = 0; i < os.length; ++i) {
            is[i] = this.classIndex.indexOf(os[i]);
        }
        return is;
    }

    private Object[] intArrayToObjectArray(int[] is) {
        Object[] os = new Object[is.length];
        for (int i = 0; i < is.length; ++i) {
            os[i] = this.classIndex.get(is[i]);
        }
        return os;
    }

    public double condLogProbGivenPrevious(int position, int label, int[] prevLabels) {
        if (prevLabels.length + 1 == this.windowSize) {
            return this.factorTables[position].conditionalLogProbGivenPrevious(prevLabels, label);
        }
        if (prevLabels.length + 1 < this.windowSize) {
            FactorTable ft = this.factorTables[position].sumOutFront();
            while (ft.windowSize() > prevLabels.length + 1) {
                ft = ft.sumOutFront();
            }
            return ft.conditionalLogProbGivenPrevious(prevLabels, label);
        }
        int[] p = new int[this.windowSize - 1];
        System.arraycopy(prevLabels, prevLabels.length - p.length, p, 0, p.length);
        return this.factorTables[position].conditionalLogProbGivenPrevious(p, label);
    }

    public double condLogProbGivenPrevious(int position, Object label, Object[] prevLabels) {
        return this.condLogProbGivenPrevious(position, this.classIndex.indexOf(label), this.objectArrayToIntArray(prevLabels));
    }

    public double condProbGivenPrevious(int position, int label, int[] prevLabels) {
        return Math.exp(this.condLogProbGivenPrevious(position, label, prevLabels));
    }

    public double condProbGivenPrevious(int position, Object label, Object[] prevLabels) {
        return Math.exp(this.condLogProbGivenPrevious(position, label, prevLabels));
    }

    public ClassicCounter condLogProbsGivenPrevious(int position, int[] prevlabels) {
        ClassicCounter c = new ClassicCounter();
        for (int i = 0; i < this.classIndex.size(); ++i) {
            Object label = this.classIndex.get(i);
            c.incrementCount(label, this.condLogProbGivenPrevious(position, i, prevlabels));
        }
        return c;
    }

    public ClassicCounter condLogProbsGivenPrevious(int position, Object[] prevlabels) {
        ClassicCounter c = new ClassicCounter();
        for (int i = 0; i < this.classIndex.size(); ++i) {
            Object label = this.classIndex.get(i);
            c.incrementCount(label, this.condLogProbGivenPrevious(position, (Object)i, prevlabels));
        }
        return c;
    }

    public double condLogProbGivenNext(int position, int label, int[] nextLabels) {
        position += nextLabels.length;
        if (nextLabels.length + 1 == this.windowSize) {
            return this.factorTables[position].conditionalLogProbGivenNext(nextLabels, label);
        }
        if (nextLabels.length + 1 < this.windowSize) {
            FactorTable ft = this.factorTables[position].sumOutFront();
            while (ft.windowSize() > nextLabels.length + 1) {
                ft = ft.sumOutFront();
            }
            return ft.conditionalLogProbGivenPrevious(nextLabels, label);
        }
        int[] p = new int[this.windowSize - 1];
        System.arraycopy(nextLabels, 0, p, 0, p.length);
        return this.factorTables[position].conditionalLogProbGivenPrevious(p, label);
    }

    public double condLogProbGivenNext(int position, Object label, Object[] nextLabels) {
        return this.condLogProbGivenNext(position, this.classIndex.indexOf(label), this.objectArrayToIntArray(nextLabels));
    }

    public double condProbGivenNext(int position, int label, int[] nextLabels) {
        return Math.exp(this.condLogProbGivenNext(position, label, nextLabels));
    }

    public double condProbGivenNext(int position, Object label, Object[] nextLabels) {
        return Math.exp(this.condLogProbGivenNext(position, label, nextLabels));
    }

    public ClassicCounter condLogProbsGivenNext(int position, int[] nextlabels) {
        ClassicCounter c = new ClassicCounter();
        for (int i = 0; i < this.classIndex.size(); ++i) {
            Object label = this.classIndex.get(i);
            c.incrementCount(label, this.condLogProbGivenNext(position, i, nextlabels));
        }
        return c;
    }

    public ClassicCounter condLogProbsGivenNext(int position, Object[] nextlabels) {
        ClassicCounter c = new ClassicCounter();
        for (int i = 0; i < this.classIndex.size(); ++i) {
            Object label = this.classIndex.get(i);
            c.incrementCount(label, this.condLogProbGivenNext(position, (Object)i, nextlabels));
        }
        return c;
    }

    public static CRFCliqueTree getCalibratedCliqueTree(double[][] weights, int[][][] data, Index[] labelIndices, int numClasses, Index classIndex, String backgroundSymbol) {
        int i;
        FactorTable[] factorTables = new FactorTable[data.length];
        FactorTable[] messages = new FactorTable[data.length - 1];
        for (i = 0; i < data.length; ++i) {
            factorTables[i] = CRFCliqueTree.getFactorTable(weights, data[i], labelIndices, numClasses);
            if (i <= 0) continue;
            messages[i - 1] = factorTables[i - 1].sumOutFront();
            factorTables[i].multiplyInFront(messages[i - 1]);
        }
        for (i = factorTables.length - 2; i >= 0; --i) {
            FactorTable summedOut = factorTables[i + 1].sumOutEnd();
            summedOut.divideBy(messages[i]);
            factorTables[i].multiplyInEnd(summedOut);
        }
        return new CRFCliqueTree(factorTables, classIndex, backgroundSymbol);
    }

    public static CRFCliqueTree getCalibratedCliqueTree(double[] weights, double wscale, int[][] weightIndices, int[][][] data, Index[] labelIndices, int numClasses, Index classIndex, String backgroundSymbol) {
        int i;
        FactorTable[] factorTables = new FactorTable[data.length];
        FactorTable[] messages = new FactorTable[data.length - 1];
        for (i = 0; i < data.length; ++i) {
            factorTables[i] = CRFCliqueTree.getFactorTable(weights, wscale, weightIndices, data[i], labelIndices, numClasses);
            if (i <= 0) continue;
            messages[i - 1] = factorTables[i - 1].sumOutFront();
            factorTables[i].multiplyInFront(messages[i - 1]);
        }
        for (i = factorTables.length - 2; i >= 0; --i) {
            FactorTable summedOut = factorTables[i + 1].sumOutEnd();
            summedOut.divideBy(messages[i]);
            factorTables[i].multiplyInEnd(summedOut);
        }
        return new CRFCliqueTree(factorTables, classIndex, backgroundSymbol);
    }

    private static FactorTable getFactorTable(double[] weights, double wscale, int[][] weightIndices, int[][] data, Index[] labelIndices, int numClasses) {
        FactorTable factorTable = null;
        for (int j = 0; j < labelIndices.length; ++j) {
            Index labelIndex = labelIndices[j];
            FactorTable ft = new FactorTable(numClasses, j + 1);
            int liSize = labelIndex.size();
            for (int k = 0; k < liSize; ++k) {
                int[] label = ((CRFLabel)labelIndex.get(k)).getLabel();
                double weight2 = 0.0;
                for (int m = 0; m < data[j].length; ++m) {
                    int wi = weightIndices[data[j][m]][k];
                    weight2 += wscale * weights[wi];
                }
                ft.setValue(label, weight2);
            }
            if (j > 0) {
                ft.multiplyInEnd(factorTable);
            }
            factorTable = ft;
        }
        return factorTable;
    }

    private static FactorTable getFactorTable(double[][] weights, int[][] data, Index[] labelIndices, int numClasses) {
        FactorTable factorTable = null;
        for (int j = 0; j < labelIndices.length; ++j) {
            Index labelIndex = labelIndices[j];
            FactorTable ft = new FactorTable(numClasses, j + 1);
            int liSize = labelIndex.size();
            for (int k = 0; k < liSize; ++k) {
                int[] label = ((CRFLabel)labelIndex.get(k)).getLabel();
                double weight2 = 0.0;
                for (int m = 0; m < data[j].length; ++m) {
                    weight2 += weights[data[j][m]][k];
                }
                ft.setValue(label, weight2);
            }
            if (j > 0) {
                ft.multiplyInEnd(factorTable);
            }
            factorTable = ft;
        }
        return factorTable;
    }

    public double[] getConditionalDistribution(int[] sequence, int position) {
        double[] result = this.scoresOf(sequence, position);
        ArrayMath.logNormalize(result);
        result = ArrayMath.exp(result);
        return result;
    }

    public void updateSequenceElement(int[] sequence, int pos, int oldVal) {
    }

    public void setInitialSequence(int[] sequence) {
    }

    public int getNumValues() {
        return this.numClasses;
    }
}

