/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.meta;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.RandomizableParallelIteratedSingleClassifierEnhancer;
import weka.classifiers.trees.RandomTree;
import weka.core.BatchPredictor;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.PartitionGenerator;
import weka.core.Randomizable;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

public class RandomCommittee
extends RandomizableParallelIteratedSingleClassifierEnhancer
implements WeightedInstancesHandler,
PartitionGenerator {
    static final long serialVersionUID = -9204394360557300093L;
    protected Instances m_data;

    public RandomCommittee() {
        this.m_Classifier = new RandomTree();
    }

    @Override
    protected String defaultClassifierString() {
        return "weka.classifiers.trees.RandomTree";
    }

    public String globalInfo() {
        return "Class for building an ensemble of randomizable base classifiers. Each base classifiers is built using a different random number seed (but based one the same data). The final prediction is a straight average of the predictions generated by the individual base classifiers.";
    }

    @Override
    public void buildClassifier(Instances data) throws Exception {
        this.getCapabilities().testWithFail(data);
        this.m_data = new Instances(data);
        super.buildClassifier(this.m_data);
        if (!(this.m_Classifier instanceof Randomizable)) {
            throw new IllegalArgumentException("Base learner must implement Randomizable!");
        }
        this.m_Classifiers = AbstractClassifier.makeCopies(this.m_Classifier, this.m_NumIterations);
        Random random = this.m_data.getRandomNumberGenerator(this.m_Seed);
        if (!(this.m_Classifier instanceof WeightedInstancesHandler) && !this.m_data.allInstanceWeightsIdentical()) {
            this.m_data = this.m_data.resampleWithWeights(random);
        }
        for (int j = 0; j < this.m_Classifiers.length; ++j) {
            ((Randomizable)((Object)this.m_Classifiers[j])).setSeed(random.nextInt());
        }
        this.buildClassifiers();
        this.m_data = null;
    }

    @Override
    protected synchronized Instances getTrainingSet(int iteration) throws Exception {
        return this.m_data;
    }

    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        double[] sums = new double[instance.numClasses()];
        double numPreds = 0.0;
        for (int i = 0; i < this.m_NumIterations; ++i) {
            if (instance.classAttribute().isNumeric()) {
                double pred = this.m_Classifiers[i].classifyInstance(instance);
                if (Utils.isMissingValue(pred)) continue;
                sums[0] = sums[0] + pred;
                numPreds += 1.0;
                continue;
            }
            double[] newProbs = this.m_Classifiers[i].distributionForInstance(instance);
            for (int j = 0; j < newProbs.length; ++j) {
                int n = j;
                sums[n] = sums[n] + newProbs[j];
            }
        }
        if (instance.classAttribute().isNumeric()) {
            sums[0] = numPreds == 0.0 ? Utils.missingValue() : sums[0] / numPreds;
            return sums;
        }
        if (Utils.eq(Utils.sum(sums), 0.0)) {
            return sums;
        }
        Utils.normalize(sums);
        return sums;
    }

    @Override
    public String batchSizeTipText() {
        return "Batch size to use if base learner is a BatchPredictor";
    }

    @Override
    public void setBatchSize(String size) {
        if (this.getClassifier() instanceof BatchPredictor) {
            ((BatchPredictor)((Object)this.getClassifier())).setBatchSize(size);
        } else {
            super.setBatchSize(size);
        }
    }

    @Override
    public String getBatchSize() {
        if (this.getClassifier() instanceof BatchPredictor) {
            return ((BatchPredictor)((Object)this.getClassifier())).getBatchSize();
        }
        return super.getBatchSize();
    }

    /*
     * WARNING - void declaration
     */
    @Override
    public double[][] distributionsForInstances(final Instances insts) throws Exception {
        if (this.getClassifier() instanceof BatchPredictor) {
            ExecutorService pool = Executors.newFixedThreadPool(this.m_numExecutionSlots);
            int chunksize = this.m_Classifiers.length / this.m_numExecutionSlots;
            HashSet<Future<double[][]>> results = new HashSet<Future<double[][]>>();
            for (int j = 0; j < this.m_numExecutionSlots; ++j) {
                final int lo = j * chunksize;
                final int n = j < this.m_numExecutionSlots - 1 ? lo + chunksize : this.m_Classifiers.length;
                Future<double[][]> futureT = pool.submit(new Callable<double[][]>(){

                    @Override
                    public double[][] call() throws Exception {
                        if (insts.classAttribute().isNumeric()) {
                            double[][] ensemblePreds = new double[insts.numInstances()][2];
                            for (int i = lo; i < n; ++i) {
                                double[][] preds = ((BatchPredictor)((Object)RandomCommittee.this.m_Classifiers[i])).distributionsForInstances(insts);
                                for (int j = 0; j < preds.length; ++j) {
                                    if (Utils.isMissingValue(preds[j][0])) continue;
                                    double[] dArray = ensemblePreds[j];
                                    dArray[0] = dArray[0] + preds[j][0];
                                    double[] dArray2 = ensemblePreds[j];
                                    dArray2[1] = dArray2[1] + 1.0;
                                }
                            }
                            return ensemblePreds;
                        }
                        double[][] ensemblePreds = new double[insts.numInstances()][insts.numClasses()];
                        for (int i = lo; i < n; ++i) {
                            double[][] preds = ((BatchPredictor)((Object)RandomCommittee.this.m_Classifiers[i])).distributionsForInstances(insts);
                            for (int j = 0; j < preds.length; ++j) {
                                for (int k = 0; k < preds[j].length; ++k) {
                                    double[] dArray = ensemblePreds[j];
                                    int n2 = k;
                                    dArray[n2] = dArray[n2] + preds[j][k];
                                }
                            }
                        }
                        return ensemblePreds;
                    }
                });
                results.add(futureT);
            }
            double[][] ensemblePreds = new double[insts.numInstances()][insts.classAttribute().isNumeric() ? 2 : insts.numClasses()];
            try {
                for (Future future : results) {
                    double[][] preds = (double[][])future.get();
                    for (int j = 0; j < preds.length; ++j) {
                        for (int k = 0; k < preds[j].length; ++k) {
                            double[] dArray = ensemblePreds[j];
                            int n = k;
                            dArray[n] = dArray[n] + preds[j][k];
                        }
                    }
                }
            }
            catch (Exception e) {
                System.out.println("RandomCommittee: predictions could not be generated by thread.");
                e.printStackTrace();
            }
            pool.shutdown();
            if (insts.classAttribute().isNumeric()) {
                void var7_16;
                double[][] finalPreds = new double[ensemblePreds.length][1];
                boolean bl = false;
                while (var7_16 < ensemblePreds.length) {
                    finalPreds[var7_16][0] = ensemblePreds[var7_16][1] == 0.0 ? Utils.missingValue() : ensemblePreds[var7_16][0] / ensemblePreds[var7_16][1];
                    ++var7_16;
                }
                return finalPreds;
            }
            for (int j = 0; j < ensemblePreds.length; ++j) {
                double d = Utils.sum(ensemblePreds[j]);
                if (Utils.eq(d, 0.0)) continue;
                Utils.normalize(ensemblePreds[j], d);
            }
            return ensemblePreds;
        }
        double[][] result = new double[insts.numInstances()][insts.numClasses()];
        for (int i = 0; i < insts.numInstances(); ++i) {
            result[i] = this.distributionForInstance(insts.instance(i));
        }
        return result;
    }

    @Override
    public boolean implementsMoreEfficientBatchPrediction() {
        if (!(this.getClassifier() instanceof BatchPredictor)) {
            return super.implementsMoreEfficientBatchPrediction();
        }
        return ((BatchPredictor)((Object)this.getClassifier())).implementsMoreEfficientBatchPrediction();
    }

    public String toString() {
        if (this.m_Classifiers == null) {
            return "RandomCommittee: No model built yet.";
        }
        StringBuffer text = new StringBuffer();
        text.append("All the base classifiers: \n\n");
        for (int i = 0; i < this.m_Classifiers.length; ++i) {
            text.append(this.m_Classifiers[i].toString() + "\n\n");
        }
        return text.toString();
    }

    @Override
    public void generatePartition(Instances data) throws Exception {
        if (!(this.m_Classifier instanceof PartitionGenerator)) {
            throw new Exception("Classifier: " + this.getClassifierSpec() + " cannot generate a partition");
        }
        this.buildClassifier(data);
    }

    @Override
    public double[] getMembershipValues(Instance inst) throws Exception {
        if (this.m_Classifier instanceof PartitionGenerator) {
            ArrayList<double[]> al = new ArrayList<double[]>();
            int size = 0;
            for (int i = 0; i < this.m_Classifiers.length; ++i) {
                double[] r = ((PartitionGenerator)((Object)this.m_Classifiers[i])).getMembershipValues(inst);
                size += r.length;
                al.add(r);
            }
            double[] values = new double[size];
            int pos = 0;
            for (double[] v : al) {
                System.arraycopy(v, 0, values, pos, v.length);
                pos += v.length;
            }
            return values;
        }
        throw new Exception("Classifier: " + this.getClassifierSpec() + " cannot generate a partition");
    }

    @Override
    public int numElements() throws Exception {
        if (this.m_Classifier instanceof PartitionGenerator) {
            int size = 0;
            for (int i = 0; i < this.m_Classifiers.length; ++i) {
                size += ((PartitionGenerator)((Object)this.m_Classifiers[i])).numElements();
            }
            return size;
        }
        throw new Exception("Classifier: " + this.getClassifierSpec() + " cannot generate a partition");
    }

    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision$");
    }

    public static void main(String[] argv) {
        RandomCommittee.runClassifier(new RandomCommittee(), argv);
    }
}

