/*
 * Decompiled with CFR 0.152.
 */
package weka.filters.supervised.instance;

import java.util.Collections;
import java.util.Comparator;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.Vector;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.SupervisedFilter;

/**
 * Supervised instance filter that resamples a dataset by applying the
 * Synthetic Minority Oversampling TEchnique (SMOTE).
 *
 * <p>The first batch is buffered completely; when it is finished every
 * original instance is emitted, followed by synthetic instances created for
 * one class value (either the auto-detected non-empty minority class or the
 * class value selected with {@code -C}). Later batches are passed through
 * unchanged.
 *
 * <p>Options: {@code -S <num>} random seed (default 1), {@code -P <percentage>}
 * percentage of SMOTE instances to create (default 100.0),
 * {@code -K <nearest-neighbors>} number of nearest neighbors (default 5),
 * {@code -C <value-index>} 1-based class value index, {@code first},
 * {@code last}, or {@code 0} for auto-detection (default 0).
 */
public class SMOTE
        extends Filter
        implements SupervisedFilter, OptionHandler, TechnicalInformationHandler {

    /** For serialization. */
    static final long serialVersionUID = -1653880819059250364L;

    /** Number of nearest neighbors considered when synthesizing an instance. */
    protected int m_NearestNeighbors = 5;

    /** Seed for the random number generator used during synthesis. */
    protected int m_RandomSeed = 1;

    /** Percentage of synthetic instances to create, relative to the size of the target class. */
    protected double m_Percentage = 100.0;

    /** 1-based index of the class value to oversample ("first"/"last" allowed); "0" means auto-detect. */
    protected String m_ClassValueIndex = "0";

    /** Whether the non-empty minority class is detected automatically. */
    protected boolean m_DetectMinorityClass = true;

    /**
     * Pairs a candidate neighbor with its distance to the instance currently
     * being processed. The natural ordering is by ascending distance only, so
     * it is intentionally not consistent with {@code equals}.
     */
    private static final class NeighborCandidate implements Comparable<NeighborCandidate> {
        private final double distance;
        private final Instance instance;

        NeighborCandidate(double distance, Instance instance) {
            this.distance = distance;
            this.instance = instance;
        }

        public int compareTo(NeighborCandidate other) {
            // Double.compare gives a total order (NaN sorts last); truncating the
            // difference to an int would treat sub-unit differences as ties.
            return Double.compare(this.distance, other.distance);
        }
    }

    /**
     * Returns a description of this filter.
     *
     * @return the description, including the technical reference
     */
    public String globalInfo() {
        return "Resamples a dataset by applying the Synthetic Minority Oversampling TEchnique (SMOTE). The original dataset must fit entirely in memory. The amount of SMOTE and number of nearest neighbors may be specified. For more information, see \n\n" + this.getTechnicalInformation().toString();
    }

    /**
     * Returns the bibliographic reference for the SMOTE technique.
     *
     * @return the technical information about this filter
     */
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.ARTICLE);
        result.setValue(TechnicalInformation.Field.AUTHOR, "Nitesh V. Chawla et. al.");
        result.setValue(TechnicalInformation.Field.TITLE, "Synthetic Minority Over-sampling Technique");
        result.setValue(TechnicalInformation.Field.JOURNAL, "Journal of Artificial Intelligence Research");
        result.setValue(TechnicalInformation.Field.YEAR, "2002");
        result.setValue(TechnicalInformation.Field.VOLUME, "16");
        result.setValue(TechnicalInformation.Field.PAGES, "321-357");
        return result;
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 5542 $");
    }

    /**
     * Returns the capabilities of this filter: all attribute types, missing
     * values, a nominal class and missing class values.
     *
     * @return the capabilities
     */
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enableAllAttributes();
        result.enable(Capabilities.Capability.MISSING_VALUES);
        result.enable(Capabilities.Capability.NOMINAL_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        return result;
    }

    /**
     * Returns an enumeration describing the available options.
     *
     * @return an enumeration of {@link Option} objects
     */
    public Enumeration listOptions() {
        Vector<Option> newVector = new Vector<Option>();
        newVector.addElement(new Option("\tSpecifies the random number seed\n\t(default 1)", "S", 1, "-S <num>"));
        newVector.addElement(new Option("\tSpecifies percentage of SMOTE instances to create.\n\t(default 100.0)\n", "P", 1, "-P <percentage>"));
        newVector.addElement(new Option("\tSpecifies the number of nearest neighbors to use.\n\t(default 5)\n", "K", 1, "-K <nearest-neighbors>"));
        newVector.addElement(new Option("\tSpecifies the index of the nominal class value to SMOTE\n\t(default 0: auto-detect non-empty minority class))\n", "C", 1, "-C <value-index>"));
        return newVector.elements();
    }

    /**
     * Parses the given options; every option that is absent is reset to its
     * default ({@code -S 1}, {@code -P 100.0}, {@code -K 5}, {@code -C 0}).
     *
     * @param options the options as an array of strings
     * @throws Exception if an option cannot be extracted or parsed
     */
    public void setOptions(String[] options) throws Exception {
        String seedStr = Utils.getOption('S', options);
        this.setRandomSeed(seedStr.length() != 0 ? Integer.parseInt(seedStr) : 1);

        String percentageStr = Utils.getOption('P', options);
        this.setPercentage(percentageStr.length() != 0 ? Double.parseDouble(percentageStr) : 100.0);

        String nnStr = Utils.getOption('K', options);
        this.setNearestNeighbors(nnStr.length() != 0 ? Integer.parseInt(nnStr) : 5);

        String classValueIndexStr = Utils.getOption('C', options);
        // Going through the setter keeps the stored index and the auto-detect
        // flag in sync, so getOptions() never reports a stale -C value.
        this.setClassValue(classValueIndexStr.length() != 0 ? classValueIndexStr : "0");
    }

    /**
     * Returns the current option settings.
     *
     * @return the options as an array of strings ({@code -C}, {@code -K}, {@code -P}, {@code -S})
     */
    public String[] getOptions() {
        Vector<String> result = new Vector<String>();
        result.add("-C");
        result.add(this.getClassValue());
        result.add("-K");
        result.add(String.valueOf(this.getNearestNeighbors()));
        result.add("-P");
        result.add(String.valueOf(this.getPercentage()));
        result.add("-S");
        result.add(String.valueOf(this.getRandomSeed()));
        return result.toArray(new String[result.size()]);
    }

    /**
     * Returns the tip text for the random seed property.
     *
     * @return the tip text
     */
    public String randomSeedTipText() {
        return "The seed used for random sampling.";
    }

    /**
     * Returns the seed used for random sampling.
     *
     * @return the random seed
     */
    public int getRandomSeed() {
        return this.m_RandomSeed;
    }

    /**
     * Sets the seed used for random sampling.
     *
     * @param value the new seed
     */
    public void setRandomSeed(int value) {
        this.m_RandomSeed = value;
    }

    /**
     * Returns the tip text for the percentage property.
     *
     * @return the tip text
     */
    public String percentageTipText() {
        return "The percentage of SMOTE instances to create.";
    }

    /**
     * Sets the percentage of SMOTE instances to create. Negative, NaN and
     * infinite values are rejected: a message is printed to {@code System.err}
     * and the current value is kept.
     *
     * @param value the percentage, must be finite and {@code >= 0}
     */
    public void setPercentage(double value) {
        if (Double.isInfinite(value)) {
            // An infinite percentage would request an effectively endless
            // number of synthetic instances per original instance.
            System.err.println("Percentage must be finite!");
        } else if (value >= 0.0) {
            this.m_Percentage = value;
        } else {
            System.err.println("Percentage must be >= 0!");
        }
    }

    /**
     * Returns the percentage of SMOTE instances to create.
     *
     * @return the percentage
     */
    public double getPercentage() {
        return this.m_Percentage;
    }

    /**
     * Returns the tip text for the nearest neighbors property.
     *
     * @return the tip text
     */
    public String nearestNeighborsTipText() {
        return "The number of nearest neighbors to use.";
    }

    /**
     * Sets the number of nearest neighbors to use. Values below 1 are
     * rejected: a message is printed to {@code System.err} and the current
     * value is kept.
     *
     * @param value the number of neighbors, must be {@code >= 1}
     */
    public void setNearestNeighbors(int value) {
        if (value >= 1) {
            this.m_NearestNeighbors = value;
        } else {
            System.err.println("At least 1 neighbor necessary!");
        }
    }

    /**
     * Returns the number of nearest neighbors to use.
     *
     * @return the number of neighbors
     */
    public int getNearestNeighbors() {
        return this.m_NearestNeighbors;
    }

    /**
     * Returns the tip text for the class value property.
     *
     * @return the tip text
     */
    public String classValueTipText() {
        return "The index of the class value to which SMOTE should be applied. Use a value of 0 to auto-detect the non-empty minority class.";
    }

    /**
     * Sets the index of the class value to which SMOTE is applied. The exact
     * string {@code "0"} switches on auto-detection of the minority class; any
     * other value switches it off.
     *
     * @param value the 1-based class value index, {@code "first"}, {@code "last"} or {@code "0"}
     */
    public void setClassValue(String value) {
        this.m_ClassValueIndex = value;
        this.m_DetectMinorityClass = this.m_ClassValueIndex.equals("0");
    }

    /**
     * Returns the index of the class value to which SMOTE is applied.
     *
     * @return the class value index as given to {@link #setClassValue(String)}
     */
    public String getClassValue() {
        return this.m_ClassValueIndex;
    }

    /**
     * Sets the format of the input instances; the output format is identical.
     *
     * @param instanceInfo the structure of the input instances
     * @return always {@code true}, the output format is available immediately
     * @throws Exception if the superclass rejects the input format
     */
    public boolean setInputFormat(Instances instanceInfo) throws Exception {
        super.setInputFormat(instanceInfo);
        super.setOutputFormat(instanceInfo);
        return true;
    }

    /**
     * Inputs an instance. During the first batch instances are only buffered;
     * once the first batch is done they are pushed straight to the output.
     *
     * @param instance the input instance
     * @return {@code true} if the instance was made available for output
     * @throws IllegalStateException if no input format has been defined
     */
    public boolean input(Instance instance) {
        if (this.getInputFormat() == null) {
            throw new IllegalStateException("No input instance format defined");
        }
        if (this.m_NewBatch) {
            this.resetQueue();
            this.m_NewBatch = false;
        }
        if (!this.m_FirstBatchDone) {
            this.bufferInput(instance);
            return false;
        }
        this.push(instance);
        return true;
    }

    /**
     * Signals the end of a batch. SMOTE is applied to the buffered instances
     * of the first batch only.
     *
     * @return {@code true} if there are instances pending output
     * @throws IllegalStateException if no input format has been defined
     * @throws Exception if SMOTE cannot be applied (see {@link #doSMOTE()})
     */
    public boolean batchFinished() throws Exception {
        if (this.getInputFormat() == null) {
            throw new IllegalStateException("No input instance format defined");
        }
        if (!this.m_FirstBatchDone) {
            this.doSMOTE();
        }
        this.flushInput();
        this.m_NewBatch = true;
        this.m_FirstBatchDone = true;
        return this.numPendingOutput() != 0;
    }

    /**
     * Applies SMOTE to the buffered input: pushes a copy of every buffered
     * instance, then pushes the synthetic instances generated for the target
     * class value.
     *
     * <p>The number of neighbors actually used is capped at the number of
     * target-class instances minus one. If the target class has no instances
     * at all, only the original instances are pushed.
     *
     * <p>NOTE(review): missing attribute values are not treated specially;
     * nominal ones are read through an {@code int} cast of the stored value
     * and numeric ones enter the distance as-is - confirm whether that is
     * acceptable for the data this filter is used on.
     *
     * @throws Exception if the configured class value index is out of range or
     *         fewer than one neighbor is available
     */
    protected void doSMOTE() throws Exception {
        Instances data = this.getInputFormat();
        int targetClass = this.resolveTargetClassIndex(data);

        // Collect the instances of the target class. Instances whose class is
        // missing are skipped explicitly: casting their class value to int
        // would otherwise make them look like class 0.
        Instances sample = data.stringFreeStructure();
        Enumeration instanceEnum = data.enumerateInstances();
        while (instanceEnum.hasMoreElements()) {
            Instance instance = (Instance)instanceEnum.nextElement();
            if (!instance.classIsMissing() && (int)instance.classValue() == targetClass) {
                sample.add(instance);
            }
        }

        // Cap the neighbor count by what the sample can provide, regardless of
        // whether the class was auto-detected or chosen explicitly, so every
        // slot of the neighbor array is always filled.
        int sampleSize = sample.numInstances();
        int nearestNeighbors = this.getNearestNeighbors();
        if (sampleSize > 0 && sampleSize <= nearestNeighbors) {
            nearestNeighbors = sampleSize - 1;
        }
        if (nearestNeighbors < 1) {
            throw new Exception("Cannot use 0 neighbors!");
        }

        // Originals go out first, in their input order.
        instanceEnum = data.enumerateInstances();
        while (instanceEnum.hasMoreElements()) {
            Instance instance = (Instance)instanceEnum.nextElement();
            this.push((Instance)instance.copy());
        }
        if (sampleSize == 0) {
            return;
        }

        double[][][] vdmByAttribute = SMOTE.computeValueDifferenceMetrics(data);

        Random rand = new Random(this.getRandomSeed());
        double factor = this.getPercentage() / 100.0;
        int fullRounds = (int)Math.floor(factor);
        Set<Integer> extraIndexSet = SMOTE.chooseExtraIndices(factor - Math.floor(factor), sampleSize, rand);

        Instance[] neighbors = new Instance[nearestNeighbors];
        for (int i = 0; i < sampleSize; ++i) {
            Instance instanceI = sample.instance(i);

            List<NeighborCandidate> candidates = new LinkedList<NeighborCandidate>();
            for (int j = 0; j < sampleSize; ++j) {
                if (i == j) {
                    continue;
                }
                Instance instanceJ = sample.instance(j);
                candidates.add(new NeighborCandidate(
                        SMOTE.distanceBetween(data, instanceI, instanceJ, vdmByAttribute), instanceJ));
            }
            // Stable sort: equally distant candidates keep their sample order.
            Collections.sort(candidates);
            Iterator<NeighborCandidate> candidateIterator = candidates.iterator();
            for (int j = 0; candidateIterator.hasNext() && j < nearestNeighbors; ++j) {
                neighbors[j] = candidateIterator.next().instance;
            }

            // Every instance gets the whole-number share; the randomly chosen
            // "extra" instances get one more to cover the fractional share.
            int copies = fullRounds + (extraIndexSet.contains(i) ? 1 : 0);
            for (int c = 0; c < copies; ++c) {
                this.push(SMOTE.createSynthetic(data, instanceI, neighbors, targetClass, rand));
            }
        }
    }

    /**
     * Determines the 0-based index of the class value to oversample: the
     * smallest non-empty class when auto-detection is on, otherwise the
     * configured 1-based index ({@code "first"} and {@code "last"} allowed).
     *
     * @param data the buffered input instances
     * @return the 0-based class value index
     * @throws Exception if the configured index is below 1 or above the number of classes
     */
    private int resolveTargetClassIndex(Instances data) throws Exception {
        if (this.m_DetectMinorityClass) {
            int[] classCounts = data.attributeStats(data.classIndex()).nominalCounts;
            int minIndex = 0;
            int min = Integer.MAX_VALUE;
            for (int i = 0; i < classCounts.length; ++i) {
                if (classCounts[i] != 0 && classCounts[i] < min) {
                    min = classCounts[i];
                    minIndex = i;
                }
            }
            return minIndex;
        }

        String classVal = this.getClassValue();
        int oneBasedIndex;
        if (classVal.equalsIgnoreCase("first")) {
            oneBasedIndex = 1;
        } else if (classVal.equalsIgnoreCase("last")) {
            oneBasedIndex = data.numClasses();
        } else {
            oneBasedIndex = Integer.parseInt(classVal);
        }
        if (oneBasedIndex > data.numClasses()) {
            throw new Exception("value index must be <= the number of classes");
        }
        if (oneBasedIndex < 1) {
            throw new Exception("value index must be >= 1");
        }
        return oneBasedIndex - 1;
    }

    /**
     * Builds, for every nominal or string attribute other than the class, the
     * value difference matrix used as the distance between two of its values:
     * the sum over all classes of the absolute difference between the
     * per-value class frequencies.
     *
     * @param data the buffered input instances
     * @return matrices indexed by attribute index; {@code null} for attributes without a matrix
     */
    private static double[][][] computeValueDifferenceMetrics(Instances data) {
        double[][][] vdmByAttribute = new double[data.numAttributes()][][];
        int numClassValues = data.classAttribute().numValues();

        Enumeration attrEnum = data.enumerateAttributes();
        while (attrEnum.hasMoreElements()) {
            Attribute attr = (Attribute)attrEnum.nextElement();
            if (attr.index() == data.classIndex() || (!attr.isNominal() && !attr.isString())) {
                continue;
            }
            int numValues = attr.numValues();
            int[] valueCounts = new int[numValues];
            int[][] valueCountsByClass = new int[numClassValues][numValues];

            Enumeration instanceEnum = data.enumerateInstances();
            while (instanceEnum.hasMoreElements()) {
                Instance instance = (Instance)instanceEnum.nextElement();
                if (instance.classIsMissing()) {
                    // Without a class there is nothing to attribute the value to.
                    continue;
                }
                int value = (int)instance.value(attr);
                valueCounts[value]++;
                valueCountsByClass[(int)instance.classValue()][value]++;
            }

            // Per-class relative frequency of each value, computed once instead
            // of once per pair of values. Unseen values get 0 rather than 0/0.
            double[][] frequency = new double[numClassValues][numValues];
            for (int c = 0; c < numClassValues; ++c) {
                for (int v = 0; v < numValues; ++v) {
                    frequency[c][v] = valueCounts[v] == 0
                            ? 0.0
                            : (double)valueCountsByClass[c][v] / (double)valueCounts[v];
                }
            }

            double[][] vdm = new double[numValues][numValues];
            for (int v1 = 0; v1 < numValues; ++v1) {
                for (int v2 = 0; v2 < numValues; ++v2) {
                    double sum = 0.0;
                    for (int c = 0; c < numClassValues; ++c) {
                        sum += Math.abs(frequency[c][v1] - frequency[c][v2]);
                    }
                    vdm[v1][v2] = sum;
                }
            }
            vdmByAttribute[attr.index()] = vdm;
        }
        return vdmByAttribute;
    }

    /**
     * Picks the instances that receive one additional synthetic instance to
     * cover the fractional part of the requested percentage.
     *
     * @param fraction the fractional part of percentage / 100
     * @param sampleSize the number of target-class instances
     * @param rand the random number generator (consumed by the shuffle)
     * @return the indices (into the sample) of the chosen instances
     */
    private static Set<Integer> chooseExtraIndices(double fraction, int sampleSize, Random rand) {
        int extraCount = (int)(fraction * (double)sampleSize);
        List<Integer> indices = new LinkedList<Integer>();
        if (extraCount >= 1) {
            for (int i = 0; i < sampleSize; ++i) {
                indices.add(i);
            }
        }
        // Shuffling is done unconditionally; on an empty list it draws no
        // random numbers, so the generator state is unaffected in that case.
        Collections.shuffle(indices, rand);
        return new HashSet<Integer>(indices.subList(0, extraCount));
    }

    /**
     * Computes the distance between two instances over all non-class
     * attributes: squared difference for numeric (and date) attributes, value
     * difference matrix entry for the others, square root of the total.
     *
     * @param data the dataset providing the attributes
     * @param first the first instance
     * @param second the second instance
     * @param vdmByAttribute the matrices from {@link #computeValueDifferenceMetrics(Instances)}
     * @return the distance
     */
    private static double distanceBetween(Instances data, Instance first, Instance second,
            double[][][] vdmByAttribute) {
        double total = 0.0;
        Enumeration attrEnum = data.enumerateAttributes();
        while (attrEnum.hasMoreElements()) {
            Attribute attr = (Attribute)attrEnum.nextElement();
            if (attr.index() == data.classIndex()) {
                continue;
            }
            double firstValue = first.value(attr);
            double secondValue = second.value(attr);
            if (attr.isNumeric() || attr.isDate()) {
                double difference = firstValue - secondValue;
                total += difference * difference;
            } else {
                total += vdmByAttribute[attr.index()][(int)firstValue][(int)secondValue];
            }
        }
        return Math.sqrt(total);
    }

    /**
     * Creates one synthetic instance for {@code seed}. Numeric and date values
     * are interpolated towards one randomly chosen neighbor (date values are
     * truncated to whole numbers); every other attribute takes the most
     * frequent value among the seed and all of its neighbors.
     *
     * @param data the dataset providing the attributes
     * @param seed the instance the synthetic instance is derived from
     * @param neighbors the nearest neighbors of {@code seed}, completely filled
     * @param targetClass the 0-based class value assigned to the new instance
     * @param rand the random number generator
     * @return the synthetic instance, with weight 1.0
     */
    private static Instance createSynthetic(Instances data, Instance seed, Instance[] neighbors,
            int targetClass, Random rand) {
        double[] values = new double[data.numAttributes()];
        // The partner is drawn before any attribute is processed, so a given
        // seed always consumes the random sequence in the same order.
        Instance partner = neighbors[rand.nextInt(neighbors.length)];

        Enumeration attrEnum = data.enumerateAttributes();
        while (attrEnum.hasMoreElements()) {
            Attribute attr = (Attribute)attrEnum.nextElement();
            if (attr.index() == data.classIndex()) {
                continue;
            }
            // NOTE(review): Weka's Attribute.isNumeric() appears to report true
            // for date attributes as well, which would make a date check placed
            // after it unreachable. Testing isDate() first is safe either way.
            if (attr.isDate()) {
                double dif = partner.value(attr) - seed.value(attr);
                double gap = rand.nextDouble();
                values[attr.index()] = (long)(seed.value(attr) + gap * dif);
            } else if (attr.isNumeric()) {
                double dif = partner.value(attr) - seed.value(attr);
                double gap = rand.nextDouble();
                values[attr.index()] = seed.value(attr) + gap * dif;
            } else {
                int[] valueCounts = new int[attr.numValues()];
                valueCounts[(int)seed.value(attr)]++;
                for (int k = 0; k < neighbors.length; ++k) {
                    valueCounts[(int)neighbors[k].value(attr)]++;
                }
                // The first value with the highest count wins.
                int maxIndex = 0;
                int max = Integer.MIN_VALUE;
                for (int index = 0; index < valueCounts.length; ++index) {
                    if (valueCounts[index] > max) {
                        max = valueCounts[index];
                        maxIndex = index;
                    }
                }
                values[attr.index()] = maxIndex;
            }
        }
        values[data.classIndex()] = targetClass;
        return new Instance(1.0, values);
    }

    /**
     * Runs this filter from the command line.
     *
     * @param args the command-line options
     */
    public static void main(String[] args) {
        SMOTE.runFilter(new SMOTE(), args);
    }
}

