/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.mi;

import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.MultiInstanceCapabilitiesHandler;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;

/**
 * Multiple-Instance Nearest Neighbour with Distribution learner (MINND).
 *
 * <p>Each training bag (exemplar) is summarised by the per-dimension mean and
 * variance of its min-max scaled instances. A weight per dimension and per bag
 * is fitted by gradient descent, the weights are used to cleanse the training
 * bags, and they are then refitted on the statistics of the cleansed bags. A
 * test bag is cleansed with the final weights and classified by its
 * partly-weighted Kullback-Leibler distance to the cleansed training bags.
 *
 * <p>The code reads the bag identifier from attribute 0, the bag contents
 * (relational value) from attribute 1 and the class from attribute 2.
 */
public class MINND extends Classifier
    implements OptionHandler, MultiInstanceCapabilitiesHandler, TechnicalInformationHandler {
    static final long serialVersionUID = -4512599203273864994L;
    /** Number of nearest training bags that vote on a test bag (option -K). */
    protected int m_Neighbour = 1;
    /** Per-bag, per-dimension means of the scaled training bags. */
    protected double[][] m_Mean = null;
    /** Per-bag, per-dimension variances of the scaled training bags (near-zero values replaced by m_ZERO). */
    protected double[][] m_Variance = null;
    /** Number of attributes inside a bag. */
    protected int m_Dimension = 0;
    /** Training header; after buildClassifier it also holds the scaled training bags. */
    protected Instances m_Attributes;
    /** Class value of each training bag. */
    protected double[] m_Class = null;
    /** Number of class values. */
    protected int m_NumClasses = 0;
    /** Weight of each training bag, summed per class when voting. */
    protected double[] m_Weights = null;
    /** Value substituted for a (near-)zero variance. */
    private static final double m_ZERO = 1.0E-45;
    /** Fixed learning rate, or -1 for an adaptive rate that starts at 0.05 and decays by m_Decay. */
    protected double m_Rate = -1.0;
    /** Per-dimension minimum over all training instances (NaN if none seen). */
    private double[] m_MinArray = null;
    /** Per-dimension maximum over all training instances (NaN if none seen). */
    private double[] m_MaxArray = null;
    /** Gradient descent stops once an iteration improves the objective by no more than this. */
    private double m_STOP = 1.0E-45;
    /** Fitted per-bag, per-dimension weights (initialised to 1.0). */
    private double[][] m_Change = null;
    /** Per-bag means of the instances flagged as noise, or null when there are none. */
    private double[][] m_NoiseM = null;
    /** Per-bag variances of the instances flagged as noise, or null when there are none. */
    private double[][] m_NoiseV = null;
    /** Per-bag means after cleansing, or null when nothing survived cleansing. */
    private double[][] m_ValidM = null;
    /** Per-bag variances after cleansing, or null when nothing survived cleansing. */
    private double[][] m_ValidV = null;
    /** Number of nearest training bags used to cleanse the training data (option -S). */
    private int m_Select = 1;
    /** Number of nearest valid/noise summaries used to cleanse a test bag (option -E). */
    private int m_Choose = 1;
    /** Factor applied to the adaptive learning rate when a step makes the objective worse. */
    private double m_Decay = 0.5;

    /**
     * Returns a description of the classifier suitable for display in the GUI.
     *
     * @return the description including the technical reference
     */
    public String globalInfo() {
        return "Multiple-Instance Nearest Neighbour with Distribution learner.\n\nIt uses gradient descent to find the weight for each dimension of each exeamplar from the starting point of 1.0. In order to avoid overfitting, it uses mean-square function (i.e. the Euclidean distance) to search for the weights.\n It then uses the weights to cleanse the training data. After that it searches for the weights again from the starting points of the weights searched before.\n Finally it uses the most updated weights to cleanse the test exemplar and then finds the nearest neighbour of the test exemplar using partly-weighted Kullback distance. But the variances in the Kullback distance are the ones before cleansing.\n\nFor more information see:\n\n" + this.getTechnicalInformation().toString();
    }

    /**
     * Returns the bibliographic reference for this method.
     *
     * @return the technical information
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.MISC);
        result.setValue(TechnicalInformation.Field.AUTHOR, "Xin Xu");
        result.setValue(TechnicalInformation.Field.YEAR, "2001");
        result.setValue(TechnicalInformation.Field.TITLE, "A nearest distribution approach to multiple-instance learning");
        result.setValue(TechnicalInformation.Field.SCHOOL, "University of Waikato");
        result.setValue(TechnicalInformation.Field.ADDRESS, "Hamilton, NZ");
        result.setValue(TechnicalInformation.Field.NOTE, "0657.591B");
        return result;
    }

    /**
     * Returns the capabilities for the bag-level data.
     *
     * @return nominal and relational attributes, nominal class, multi-instance only
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.RELATIONAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.MISSING_VALUES);
        result.enable(Capabilities.Capability.NOMINAL_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        result.enable(Capabilities.Capability.ONLY_MULTIINSTANCE);
        return result;
    }

    /**
     * Returns the capabilities for the instances inside a bag.
     *
     * @return numeric and date attributes, no class
     */
    @Override
    public Capabilities getMultiInstanceCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capabilities.Capability.DATE_ATTRIBUTES);
        result.enable(Capabilities.Capability.MISSING_VALUES);
        result.disableAllClasses();
        result.enable(Capabilities.Capability.NO_CLASS);
        return result;
    }

    /**
     * Builds the classifier from multi-instance training data.
     *
     * <p>Bags with a missing class are dropped. The remaining bags are min-max
     * scaled and summarised, weights are fitted on the raw bag statistics, each
     * bag is cleansed with {@link #preprocess(Instances, int)}, and the weights
     * of bags with surviving instances are refitted on the cleansed statistics.
     *
     * @param exs the multi-instance training data
     * @throws Exception if the data are not supported or training fails
     */
    @Override
    public void buildClassifier(Instances exs) throws Exception {
        this.getCapabilities().testWithFail(exs);
        Instances newData = new Instances(exs);
        newData.deleteWithMissingClass();

        int numegs = newData.numInstances();
        this.m_Dimension = newData.attribute(1).relation().numAttributes();
        this.m_Attributes = newData.stringFreeStructure();
        this.m_NumClasses = exs.numClasses();
        this.m_Change = new double[numegs][this.m_Dimension];
        this.m_Mean = new double[numegs][this.m_Dimension];
        this.m_Variance = new double[numegs][this.m_Dimension];
        this.m_Class = new double[numegs];
        this.m_Weights = new double[numegs];
        this.m_NoiseM = new double[numegs][this.m_Dimension];
        this.m_NoiseV = new double[numegs][this.m_Dimension];
        this.m_ValidM = new double[numegs][this.m_Dimension];
        this.m_ValidV = new double[numegs][this.m_Dimension];

        // Scaling range over every training instance; NaN means "no value seen yet".
        this.m_MinArray = new double[this.m_Dimension];
        this.m_MaxArray = new double[this.m_Dimension];
        for (int v = 0; v < this.m_Dimension; v++) {
            this.m_MaxArray[v] = Double.NaN;
            this.m_MinArray[v] = Double.NaN;
        }
        for (int w = 0; w < numegs; w++) {
            this.updateMinMax(newData.instance(w));
        }

        // m_Attributes doubles as the store of scaled training bags: preprocess()
        // and the debug output below read the bags back from it.
        Instances data = this.m_Attributes;
        for (int x = 0; x < numegs; x++) {
            Instance example = this.scale(newData.instance(x));
            this.fillStatistics(example.relationalValue(1), this.m_Mean[x], this.m_Variance[x]);
            for (int i = 0; i < this.m_Dimension; i++) {
                this.m_Change[x][i] = 1.0;
            }
            data.add(example);
            this.m_Class[x] = example.classValue();
            // scale() carries the bag weight over, so this is the user-supplied weight.
            this.m_Weights[x] = example.weight();
        }

        for (int z = 0; z < numegs; z++) {
            this.findWeights(z, this.m_Mean);
        }

        for (int x = 0; x < numegs; x++) {
            Instance example = this.preprocess(data, x);
            if (this.getDebug()) {
                System.out.println("???Exemplar " + x + " has been pre-processed:" + data.instance(x).relationalValue(1).sumOfWeights() + "|" + example.relationalValue(1).sumOfWeights() + "; class:" + this.m_Class[x]);
            }
            Instances valid = example.relationalValue(1);
            if (Utils.gr(valid.sumOfWeights(), 0.0)) {
                this.fillStatistics(valid, this.m_ValidM[x], this.m_ValidV[x]);
            } else {
                // Nothing survived cleansing: this bag is ignored from now on.
                this.m_ValidM[x] = null;
                this.m_ValidV[x] = null;
            }
        }

        for (int z = 0; z < numegs; z++) {
            if (this.m_ValidM[z] != null) {
                this.findWeights(z, this.m_ValidM);
            }
        }
    }

    /**
     * Cleanses one training bag.
     *
     * <p>Every instance of bag {@code pos} is labelled by a vote among the
     * {@code m_Select} nearest other training bags, using
     * {@code distance(...)} to each bag's mean. Instances whose vote disagrees
     * with the bag's own class are treated as noise; their mean and variance
     * are stored in {@code m_NoiseM[pos]}/{@code m_NoiseV[pos]}, or those rows
     * are set to null when there is no noise. If fewer than {@code m_Select}
     * other bags exist, only the available ones vote. An instance that
     * receives no vote at all is kept. Bags whose class index is 0 are returned
     * unchanged (presumably the negative class — confirm against the data
     * format).
     *
     * @param data the scaled training bags
     * @param pos  index of the bag to cleanse
     * @return a bag containing only the instances that were not flagged as noise
     * @throws Exception if building the result bag fails
     */
    public Instance preprocess(Instances data, int pos) throws Exception {
        Instance before = data.instance(pos);
        int ownClass = (int) before.classValue();
        if (ownClass == 0) {
            this.m_NoiseM[pos] = null;
            this.m_NoiseV[pos] = null;
            return before;
        }

        Instances bag = before.relationalValue(1);
        Instances keptInsts = before.attribute(1).relation().stringFreeStructure();
        Instances noiseInsts = before.attribute(1).relation().stringFreeStructure();
        int numBags = data.numInstances();

        for (int g = 0; g < bag.numInstances(); g++) {
            Instance datum = bag.instance(g);
            double[] dists = new double[numBags];
            for (int i = 0; i < numBags; i++) {
                // The bag itself never votes on its own instances.
                dists[i] = i != pos ? this.distance(datum, this.m_Mean[i], this.m_Variance[i], i) : Double.POSITIVE_INFINITY;
            }

            int[] pred = new int[this.m_NumClasses];
            int votes = 0;
            for (int o = 0; o < this.m_Select; o++) {
                int index = Utils.minIndex(dists);
                if (dists[index] == Double.POSITIVE_INFINITY) {
                    // Fewer than m_Select other bags: without this guard minIndex
                    // would keep returning bag 0 and give it extra votes.
                    break;
                }
                pred[(int) this.m_Class[index]]++;
                dists[index] = Double.POSITIVE_INFINITY;
                votes++;
            }

            if (votes > 0 && Utils.maxIndex(pred) != ownClass) {
                noiseInsts.add(datum);
            } else {
                keptInsts.add(datum);
            }
        }

        // Same relation-registration order as before: noise bag first, then kept bag.
        Instance noises = this.newBag(before, noiseInsts, before.classValue());
        Instance after = this.newBag(before, keptInsts, before.classValue());

        Instances noiseBag = noises.relationalValue(1);
        if (Utils.gr(noiseBag.sumOfWeights(), 0.0)) {
            // Fresh rows: an earlier call may have set these entries to null.
            this.m_NoiseM[pos] = new double[this.m_Dimension];
            this.m_NoiseV[pos] = new double[this.m_Dimension];
            this.fillStatistics(noiseBag, this.m_NoiseM[pos], this.m_NoiseV[pos]);
        } else {
            this.m_NoiseM[pos] = null;
            this.m_NoiseV[pos] = null;
        }
        return after;
    }

    /**
     * Weighted squared distance from one instance to a bag summary.
     *
     * <p>For each numeric dimension the term is scaled by the fitted weight
     * {@code m_Change[pos][i]} and multiplied by {@code var[i]}, or by 1 when
     * the variance is not greater than {@code m_ZERO} (per {@code Utils.gr}).
     * A missing value contributes that scale factor alone instead of a squared
     * difference.
     */
    private double distance(Instance first, double[] mean, double[] var, int pos) {
        double distance = 0.0;
        for (int i = 0; i < this.m_Dimension; i++) {
            if (!first.attribute(i).isNumeric()) {
                continue;
            }
            double spread = Utils.gr(var[i], m_ZERO) ? var[i] : 1.0;
            if (first.isMissing(i)) {
                distance += this.m_Change[pos][i] * spread;
            } else {
                double diff = first.value(i) - mean[i];
                distance += this.m_Change[pos][i] * spread * diff * diff;
            }
        }
        return distance;
    }

    /** Widens m_MinArray/m_MaxArray to cover the non-missing numeric values of one bag. */
    private void updateMinMax(Instance ex) {
        Instances insts = ex.relationalValue(1);
        for (int j = 0; j < this.m_Dimension; j++) {
            if (!insts.attribute(j).isNumeric()) {
                continue;
            }
            for (int k = 0; k < insts.numInstances(); k++) {
                Instance ins = insts.instance(k);
                if (ins.isMissing(j)) {
                    continue;
                }
                double value = ins.value(j);
                if (Double.isNaN(this.m_MinArray[j])) {
                    this.m_MinArray[j] = value;
                    this.m_MaxArray[j] = value;
                } else if (value < this.m_MinArray[j]) {
                    this.m_MinArray[j] = value;
                } else if (value > this.m_MaxArray[j]) {
                    this.m_MaxArray[j] = value;
                }
            }
        }
    }

    /**
     * Min-max scales the numeric values of a bag using the training range.
     *
     * <p>Missing values stay missing. An attribute whose training range is
     * empty or zero (never seen, or constant) is mapped to 0, as Weka's
     * Normalize filter does; the old code divided by zero here. The bag weight
     * is preserved.
     */
    private Instance scale(Instance before) throws Exception {
        Instances bag = before.relationalValue(1);
        Instances scaledInsts = bag.stringFreeStructure();
        for (int i = 0; i < bag.numInstances(); i++) {
            Instance datum = bag.instance(i);
            Instance inst = (Instance) datum.copy();
            for (int j = 0; j < this.m_Dimension; j++) {
                if (!bag.attribute(j).isNumeric() || datum.isMissing(j)) {
                    continue;
                }
                double range = this.m_MaxArray[j] - this.m_MinArray[j];
                double scaled = (Double.isNaN(range) || range == 0.0) ? 0.0 : (datum.value(j) - this.m_MinArray[j]) / range;
                inst.setValue(j, scaled);
            }
            scaledInsts.add(inst);
        }
        Instance after = this.newBag(before, scaledInsts, before.value(2));
        // new Instance(int) starts with weight 1; keep the caller's bag weight.
        after.setWeight(before.weight());
        return after;
    }

    /**
     * Fits the dimension weights of bag {@code row} by gradient descent.
     *
     * <p>Starting from {@code m_Change[row]}, steps along {@link #delta} until
     * an iteration improves {@link #target} by no more than {@code m_STOP}.
     * When a step makes the objective worse, an adaptive rate ({@code m_Rate}
     * == -1) is multiplied by {@code m_Decay} and the step is retried. A fixed
     * rate instead restores the previous weights and stops. Only dimensions
     * whose variance in bag {@code row} exceeds 0 (per {@code Utils.gr}) are
     * updated; the other dimensions are set to 0 after the first step.
     *
     * @param row  index of the bag whose weights are fitted
     * @param mean per-bag mean vectors to fit against (rows may be null)
     */
    public void findWeights(int row, double[][] mean) {
        double[] neww = new double[this.m_Dimension];
        System.arraycopy(this.m_Change[row], 0, neww, 0, this.m_Dimension);
        double newresult = this.target(neww, mean, row, this.m_Class);
        double result = Double.POSITIVE_INFINITY;
        boolean adaptive = this.m_Rate == -1.0;
        double rate = adaptive ? 0.05 : this.m_Rate;

        search:
        while (Utils.gr(result - newresult, this.m_STOP)) {
            double[] oldw = neww;
            double[] delta = this.delta(oldw, mean, row, this.m_Class);
            neww = new double[this.m_Dimension];
            this.applyStep(neww, oldw, delta, rate, row);
            result = newresult;
            newresult = this.target(neww, mean, row, this.m_Class);

            while (Utils.gr(newresult, result)) {
                if (!adaptive) {
                    // Fixed rate cannot back off: undo the step and stop.
                    System.arraycopy(oldw, 0, neww, 0, this.m_Dimension);
                    break search;
                }
                rate *= this.m_Decay;
                this.applyStep(neww, oldw, delta, rate, row);
                newresult = this.target(neww, mean, row, this.m_Class);
            }
        }
        this.m_Change[row] = neww;
    }

    /** Sets dest[i] = from[i] + rate * direction[i] for every active dimension of {@code row}. */
    private void applyStep(double[] dest, double[] from, double[] direction, double rate, int row) {
        for (int i = 0; i < this.m_Dimension; i++) {
            if (this.isActive(row, i)) {
                dest[i] = from[i] + rate * direction[i];
            }
        }
    }

    /** True when dimension {@code dim} of bag {@code row} has variance greater than 0 (per Utils.gr). */
    private boolean isActive(int row, int dim) {
        return Utils.gr(this.m_Variance[row][dim], 0.0);
    }

    /** Desired distance between two bags: 0 for the same class, sqrt(m_Dimension - 1) otherwise. */
    private double targetDistance(double[] Y, int rowpos, int other) {
        return Y[rowpos] == Y[other] ? 0.0 : Math.sqrt((double) this.m_Dimension - 1.0);
    }

    /** Weighted Euclidean distance between X[rowpos] and X[other] over the active dimensions of rowpos. */
    private double weightedDistance(double[] x, double[][] X, int rowpos, int other) {
        double sum = 0.0;
        for (int j = 0; j < this.m_Dimension; j++) {
            if (this.isActive(rowpos, j)) {
                double diff = X[rowpos][j] - X[other][j];
                sum += x[j] * diff * diff;
            }
        }
        return Math.sqrt(sum);
    }

    /**
     * Negative gradient of {@link #target} with respect to the weights {@code x}.
     * Pairs at zero distance are skipped, because the gradient divides by the
     * distance.
     */
    private double[] delta(double[] x, double[][] X, int rowpos, double[] Y) {
        double[] delta = new double[this.m_Dimension];
        for (int i = 0; i < X.length; i++) {
            if (i == rowpos || X[i] == null) {
                continue;
            }
            double var = this.targetDistance(Y, rowpos, i);
            double distance = this.weightedDistance(x, X, rowpos, i);
            if (distance == 0.0) {
                continue;
            }
            for (int k = 0; k < this.m_Dimension; k++) {
                if (this.isActive(rowpos, k)) {
                    double diff = X[rowpos][k] - X[i][k];
                    delta[k] += (var / distance - 1.0) * 0.5 * diff * diff;
                }
            }
        }
        return delta;
    }

    /**
     * Least-squares objective minimised by {@link #findWeights}.
     *
     * <p>Sums {@code 0.5 * (f - t)^2} over every other bag {@code i} with
     * non-null {@code X[i]}. Here {@code f} is the {@code x}-weighted Euclidean
     * distance between {@code X[rowpos]} and {@code X[i]}, and {@code t} is 0
     * for bags of the same class or {@code sqrt(m_Dimension - 1)} otherwise.
     *
     * @param x      candidate weights for bag {@code rowpos}
     * @param X      per-bag mean vectors (rows may be null)
     * @param rowpos index of the bag being fitted
     * @param Y      class value of each bag
     * @return the objective value
     * @throws IllegalStateException if a weighted distance is infinite (this
     *     used to terminate the JVM via System.exit)
     */
    public double target(double[] x, double[][] X, int rowpos, double[] Y) {
        double result = 0.0;
        for (int i = 0; i < X.length; i++) {
            if (i == rowpos || X[i] == null) {
                continue;
            }
            double var = this.targetDistance(Y, rowpos, i);
            double f = this.weightedDistance(x, X, rowpos, i);
            if (Double.isInfinite(f)) {
                throw new IllegalStateException("Infinite weighted distance between exemplar " + rowpos + " and exemplar " + i);
            }
            result += 0.5 * (f - var) * (f - var);
        }
        return result;
    }

    /**
     * Classifies a test bag.
     *
     * <p>The bag is scaled and its per-dimension variances are taken before
     * cleansing. It is then cleansed with {@link #cleanse(Instance)}. If no
     * instance survives, 1.0 is returned. Otherwise the cleansed mean is
     * compared to each training bag with {@link #kullback}, and the
     * {@code m_Neighbour} closest valid bags vote with their bag weights. If
     * fewer valid bags exist, only those vote.
     *
     * @param ex the test bag
     * @return the index of the predicted class
     * @throws Exception if scaling or cleansing fails
     */
    @Override
    public double classifyInstance(Instance ex) throws Exception {
        Instance scaled = this.scale(ex);
        Instances scaledBag = scaled.relationalValue(1);
        double[] var = new double[this.m_Dimension];
        for (int i = 0; i < this.m_Dimension; i++) {
            var[i] = floorVariance(scaledBag.variance(i));
        }

        Instance cleansed = this.cleanse(scaled);
        Instances cleansedBag = cleansed.relationalValue(1);
        if (cleansedBag.numInstances() == 0) {
            if (this.getDebug()) {
                System.out.println("???Whole exemplar falls into ambiguous area!");
            }
            return 1.0;
        }

        double[] mean = new double[this.m_Dimension];
        for (int i = 0; i < this.m_Dimension; i++) {
            mean[i] = cleansedBag.meanOrMode(i);
        }

        double[] divergences = new double[this.m_Class.length];
        for (int i = 0; i < divergences.length; i++) {
            divergences[i] = this.m_ValidM[i] != null ? this.kullback(mean, this.m_ValidM[i], var, this.m_Variance[i], i) : Double.POSITIVE_INFINITY;
        }

        double[] predict = new double[this.m_NumClasses];
        int neighbours = Math.min(this.m_Neighbour, divergences.length);
        for (int j = 0; j < neighbours; j++) {
            int pos = Utils.minIndex(divergences);
            if (divergences[pos] == Double.POSITIVE_INFINITY) {
                // No valid training bags left; avoid re-voting for bag 0.
                break;
            }
            predict[(int) this.m_Class[pos]] += this.m_Weights[pos];
            divergences[pos] = Double.POSITIVE_INFINITY;
        }

        int prediction = Utils.maxIndex(predict);
        if (this.getDebug()) {
            System.out.println("???There are still some unambiguous instances in this exemplar! Predicted as: " + prediction);
        }
        return prediction;
    }

    /**
     * Removes ambiguous instances from a (scaled) test bag.
     *
     * <p>For each instance, the {@code m_Choose} smallest distances to the
     * valid summaries and to the noise summaries of the training bags are
     * merged in ascending order, with ties going to valid, until
     * {@code m_Choose} entries are taken. The instance is kept when at least
     * as many of those entries came from valid summaries as from noise
     * summaries.
     *
     * @param before the scaled test bag
     * @return a bag containing the kept instances
     * @throws Exception if building the result bag fails
     */
    public Instance cleanse(Instance before) throws Exception {
        Instances bag = before.relationalValue(1);
        Instances keptInsts = bag.stringFreeStructure();
        int numBags = this.m_Mean.length;

        for (int g = 0; g < bag.numInstances(); g++) {
            Instance datum = bag.instance(g);
            double[] vDist = new double[numBags];
            double[] nDist = new double[numBags];
            for (int h = 0; h < numBags; h++) {
                vDist[h] = this.m_ValidM[h] == null ? Double.POSITIVE_INFINITY : this.distance(datum, this.m_ValidM[h], this.m_ValidV[h], h);
                nDist[h] = this.m_NoiseM[h] == null ? Double.POSITIVE_INFINITY : this.distance(datum, this.m_NoiseM[h], this.m_NoiseV[h], h);
            }

            double[] nearestValid = extractSmallest(vDist, this.m_Choose);
            double[] nearestNoise = extractSmallest(nDist, this.m_Choose);

            int validTaken = 0;
            int noiseTaken = 0;
            while (validTaken + noiseTaken < this.m_Choose) {
                if (nearestValid[validTaken] <= nearestNoise[noiseTaken]) {
                    validTaken++;
                } else {
                    noiseTaken++;
                }
            }
            if (validTaken >= noiseTaken) {
                keptInsts.add(datum);
            }
        }
        return this.newBag(before, keptInsts, before.value(2));
    }

    /** Removes and returns the {@code count} smallest entries of {@code values}, ascending (overwriting them with +Inf). */
    private static double[] extractSmallest(double[] values, int count) {
        double[] smallest = new double[count];
        for (int k = 0; k < count; k++) {
            int pos = Utils.minIndex(values);
            smallest[k] = values[pos];
            values[pos] = Double.POSITIVE_INFINITY;
        }
        return smallest;
    }

    /**
     * Partly-weighted Kullback-Leibler divergence between two diagonal Gaussians.
     *
     * <p>Per dimension {@code y}, the term is
     * {@code ln(sqrt(var2/var1)) + var1/(2 var2) + w (mu1-mu2)^2/(2 var2) - 1/2},
     * where {@code w = m_Change[pos][y]}. Dimensions where either variance is
     * not greater than 0 (per {@code Utils.gr}) are skipped.
     *
     * @param mu1  means of the first distribution
     * @param mu2  means of the second distribution
     * @param var1 variances of the first distribution
     * @param var2 variances of the second distribution
     * @param pos  training bag whose weights scale the mean term
     * @return the divergence
     */
    public double kullback(double[] mu1, double[] mu2, double[] var1, double[] var2, int pos) {
        double result = 0.0;
        for (int y = 0; y < mu1.length; y++) {
            if (Utils.gr(var1[y], 0.0) && Utils.gr(var2[y], 0.0)) {
                double diff = mu1[y] - mu2[y];
                result += Math.log(Math.sqrt(var2[y] / var1[y])) + var1[y] / (2.0 * var2[y]) + this.m_Change[pos][y] * diff * diff / (2.0 * var2[y]) - 0.5;
            }
        }
        return result;
    }

    /** Writes the per-dimension mean and (floored) variance of {@code bag} into the given arrays. */
    private void fillStatistics(Instances bag, double[] means, double[] variances) {
        for (int i = 0; i < this.m_Dimension; i++) {
            means[i] = bag.meanOrMode(i);
            variances[i] = floorVariance(bag.variance(i));
        }
    }

    /** Replaces a variance that is zero within Utils.eq tolerance by m_ZERO. */
    private static double floorVariance(double variance) {
        return Utils.eq(variance, 0.0) ? m_ZERO : variance;
    }

    /**
     * Creates a bag in the m_Attributes format that holds {@code contents},
     * copying the bag identifier from {@code source}.
     */
    private Instance newBag(Instance source, Instances contents, double classValue) {
        Instance bag = new Instance(source.numAttributes());
        bag.setDataset(this.m_Attributes);
        int relationIndex = bag.attribute(1).addRelation(contents);
        bag.setValue(0, source.value(0));
        bag.setValue(1, (double) relationIndex);
        bag.setValue(2, classValue);
        return bag;
    }

    /**
     * Lists the available options.
     *
     * @return an enumeration of the -K, -S and -E options
     */
    @Override
    public Enumeration listOptions() {
        Vector<Option> result = new Vector<Option>();
        result.addElement(new Option("\tSet number of nearest neighbour for prediction\n\t(default 1)", "K", 1, "-K <number of neighbours>"));
        result.addElement(new Option("\tSet number of nearest neighbour for cleansing the training data\n\t(default 1)", "S", 1, "-S <number of neighbours>"));
        result.addElement(new Option("\tSet number of nearest neighbour for cleansing the testing data\n\t(default 1)", "E", 1, "-E <number of neighbours>"));
        return result.elements();
    }

    /**
     * Parses the options -D, -K, -S and -E. Absent numeric options reset to 1.
     *
     * @param options the option array
     * @throws Exception if an option cannot be parsed
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        this.setDebug(Utils.getFlag('D', options));
        this.setNumNeighbours(intOption('K', options, 1));
        this.setNumTrainingNoises(intOption('S', options, 1));
        this.setNumTestingNoises(intOption('E', options, 1));
    }

    /** Reads an integer option, or returns {@code defaultValue} when it is absent. */
    private static int intOption(char flag, String[] options, int defaultValue) throws Exception {
        String value = Utils.getOption(flag, options);
        return value.length() != 0 ? Integer.parseInt(value) : defaultValue;
    }

    /**
     * Returns the current option settings.
     *
     * @return the options as a string array
     */
    @Override
    public String[] getOptions() {
        Vector<String> result = new Vector<String>();
        if (this.getDebug()) {
            result.add("-D");
        }
        result.add("-K");
        result.add("" + this.getNumNeighbours());
        result.add("-S");
        result.add("" + this.getNumTrainingNoises());
        result.add("-E");
        result.add("" + this.getNumTestingNoises());
        return result.toArray(new String[result.size()]);
    }

    /** @return tip text for the numNeighbours property */
    public String numNeighboursTipText() {
        return "The number of nearest neighbours to the estimate the class prediction of test bags.";
    }

    /** @param numNeighbour number of training bags that vote on a test bag */
    public void setNumNeighbours(int numNeighbour) {
        this.m_Neighbour = numNeighbour;
    }

    /** @return number of training bags that vote on a test bag */
    public int getNumNeighbours() {
        return this.m_Neighbour;
    }

    /** @return tip text for the numTrainingNoises property */
    public String numTrainingNoisesTipText() {
        return "The number of nearest neighbour instances in the selection of noises in the training data.";
    }

    /** @param numTraining number of neighbouring bags used to cleanse the training data */
    public void setNumTrainingNoises(int numTraining) {
        this.m_Select = numTraining;
    }

    /** @return number of neighbouring bags used to cleanse the training data */
    public int getNumTrainingNoises() {
        return this.m_Select;
    }

    /** @return tip text for the numTestingNoises property */
    public String numTestingNoisesTipText() {
        return "The number of nearest neighbour instances in the selection of noises in the test data.";
    }

    /** @return number of nearest summaries used to cleanse a test bag */
    public int getNumTestingNoises() {
        return this.m_Choose;
    }

    /** @param numTesting number of nearest summaries used to cleanse a test bag */
    public void setNumTestingNoises(int numTesting) {
        this.m_Choose = numTesting;
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 5527 $");
    }

    /**
     * Runs the classifier from the command line.
     *
     * @param args command-line options
     */
    public static void main(String[] args) {
        MINND.runClassifier(new MINND(), args);
    }
}

