注册 登录  
 加关注
   显示下一条  |  关闭
温馨提示!由于新浪微博认证机制调整,您的新浪微博帐号绑定已过期,请重新绑定!立即重新绑定新浪微博》  |  关闭

Koala++'s blog

计算广告学 RTB

 
 
 

日志

 
 

Weka开发[25]——Bagging源代码分析  

2009-08-12 14:24:24|  分类: 机器学习 |  标签: |举报 |字号 订阅

  下载LOFTER 我的照片书  |

         先翻译一段Bagging的介绍,Breimanbagging算法,是bootstrap aggregating的缩写,是最早的Ensemble算法之一,它也是最直接容易实现,又有着另人惊讶的好效果的算法之一。Bagging中的多样性是由有放回抽取训练样本来实现的,用这种方式随机产生多个训练数据的子集,在每一个训练集的子集上训练一个同种分类器,最终分类结果是由多个分类器的分类结果多数投票而产生的。Breiman’s bagging, short for bootstrap aggregating, is one of the earliest ensemble based algorithms. It is also one of the most intuitive and simplest to implement, with a surprisingly good performance . Diversity in bagging is obtained by using bootstrapped replicas of the training data: different training data subsets are randomly drawn—with replacement—from the entire training data. Each training data subset is used to train a different classifier of the same type. Individual classifiers are then combined by taking a majority vote of their decisions. For any given instance, the class chosen by most classifiers is the ensemble decision.

Bagging类在weka.classifiers.meta包下面。Bagging继承自RandomizeableInteratedSingleClassifierEnhancer,而它又继承自IteratedSingleClassifierEnhancer,它再继承自SingleClassifierEnhancer,最后一个继承自Classifier。我的UML工具似乎过期了,有空补上。

         看一下构造函数:

public Bagging() {

    m_Classifier = new weka.classifiers.trees.REPTree();

}

         可以看到默认的基分类器是REPTree

         接下来看buildClassifier函数:

// can classifier handle the data?

getCapabilities().testWithFail(data);

 

// remove instances with missing class

data = new Instances(data);

data.deleteWithMissingClass();

 

super.buildClassifier(data);

 

if (m_CalcOutOfBag && (m_BagSizePercent != 100)) {

    throw new IllegalArgumentException("Bag size needs to be 100% if "

           + "out-of-bag error is to be calculated!");

}

         只有一行代码值得看一下super.buildClassifier

public void buildClassifier(Instances data) throws Exception {

 

    if (m_Classifier == null) {

       throw new Exception("A base classifier has not been specified!");

    }

    m_Classifiers = Classifier.makeCopies(m_Classifier,

       m_NumIterations);

}

         这里将m_Classifier复制m_NumIterations份到m_Classifiers数组中去。

int bagSize = data.numInstances() * m_BagSizePercent / 100;

Random random = new Random(m_Seed);

 

boolean[][] inBag = null;

if (m_CalcOutOfBag)

    inBag = new boolean[m_Classifiers.length][];

         bagSize是一个Bag的大小,也就是它里面有多少样本。

for (int j = 0; j < m_Classifiers.length; j++) {

    Instances bagData = null;

 

    // create the in-bag dataset

    if (m_CalcOutOfBag) {

       inBag[j] = new boolean[data.numInstances()];

       bagData = resampleWithWeights(data, random, inBag[j]);

    } else {

       bagData = data.resampleWithWeights(random);

       if (bagSize < data.numInstances()) {

           bagData.randomize(random);

           Instances newBagData = new Instances(bagData, 0, bagSize);

           bagData = newBagData;

       }

    }

 

    if (m_Classifier instanceof Randomizable) {

       ((Randomizable) m_Classifiers[j]).setSeed(random.nextInt());

    }

 

    // build the classifier

    m_Classifiers[j].buildClassifier(bagData);

}

         暂时不去看m_CalcOutOfBag的情况,当然最关键的是resampleWithWeights

/**

 * Creates a new dataset of the same size using random sampling with

 * replacement according to the current instance weights. The weights of

 * the instances in the new dataset are set to one.

*/

public Instances resampleWithWeights(Random random) {

 

    double[] weights = new double[numInstances()];

    for (int i = 0; i < weights.length; i++) {

       weights[i] = instance(i).weight();

    }

    return resampleWithWeights(random, weights);

}

         注释上写的是根据当前样本的权重用有放回取样的方法创建一个同样大小的新数据集,新数据集中的样本权重为1。这里先是把权重记录下来,再用一个重载函数去做:

 

接下来是看数据集中的样本是否大于bagSize,如果不大于,其实就没什么意思了。如果大于,再把bagData随机一次,取前面的bagSize个样本,下面如果m_ClassifierRandomizable的一个实例,那么就给它再指定一个新的随机种子,这点很关键,自己写的时候,常常忘记。最后训练第j个分类器。

         现在再看resampleWithWeights

public Instances resampleWithWeights(Random random, double[] weights) {

 

    if (weights.length != numInstances()) {

       throw new IllegalArgumentException(

              "weights.length != numInstances.");

    }

    Instances newData = new Instances(this, numInstances());

    if (numInstances() == 0) {

       return newData;

    }

    double[] probabilities = new double[numInstances()];

    double sumProbs = 0, sumOfWeights = Utils.sum(weights);

    for (int i = 0; i < numInstances(); i++) {

       sumProbs += random.nextDouble();

       probabilities[i] = sumProbs;

    }

    Utils.normalize(probabilities, sumProbs / sumOfWeights);

 

    // Make sure that rounding errors don't mess things up

    probabilities[numInstances() - 1] = sumOfWeights;

    int k = 0;

    int l = 0;

    sumProbs = 0;

    while ((k < numInstances() && (l < numInstances()))) {

       if (weights[l] < 0) {

           throw new IllegalArgumentException(

                  "Weights have to be positive.");

       }

       sumProbs += weights[l];

       while ((k < numInstances()) && (probabilities[k] <= sumProbs)) {

           newData.add(instance(l));

           newData.instance(k).setWeight(1);

           k++;

       }

       l++;

    }

    return newData;

}

         sumProbs是产生的随机数的总和,而probabilities是第i次的总和,Utils.normalize的代码如下:

public static void normalize(double[] doubles, double sum) {

    if (Double.isNaN(sum)) {

       throw new IllegalArgumentException(

              "Can't normalize array. Sum is NaN.");

    }

    if (sum == 0) {

       // Maybe this should just be a return.

       throw new IllegalArgumentException(

              "Can't normalize array. Sum is zero.");

    }

    for (int i = 0; i < doubles.length; i++) {

       doubles[i] /= sum;

    }

}

这一步是将所产生的随机数与权重对应起来,因为产生的probabilities(0,1)范围内,可能与样本权重对应不起来,在下面的二重循环中,看到sumProbs重新记数,它的意义就是加上weights[l]之后,probability[k] 如果到不到相应的sumProbs,就重复地加这一个相同的样本。通过这种方式来产生有放回的取样样本。

现在看m_CalcOutOfBagtrue的时候,首先会有一个inBag二维数组,第一维大小为分类器个数,第二维为样本个数。public final Instances resampleWithWeights(Instances data, Random random, boolean[] sampled)这个函数与Intances中的差不多,只多了一句话就是sampled[l] = true,表示这个样本采样时有它。接下来看buildClassifier的后面一部分,看起来很长,其实蛮简单的。

// calc OOB error?

if (getCalcOutOfBag()) {

    double outOfBagCount = 0.0;

    double errorSum = 0.0;

    boolean numeric = data.classAttribute().isNumeric();

 

    for (int i = 0; i < data.numInstances(); i++) {

       double vote;

       double[] votes;

       if (numeric)

           votes = new double[1];

       else

           votes = new double[data.numClasses()];

 

       // determine predictions for instance

       int voteCount = 0;

       for (int j = 0; j < m_Classifiers.length; j++) {

           if (inBag[j][i])

              continue;

 

           voteCount++;

           double pred = m_Classifiers[j].classifyInstance(data

                  .instance(i));

           if (numeric)

              votes[0] += pred;

           else

              votes[(int) pred]++;

       }

 

       // "vote"

       if (numeric) {

           vote = votes[0];

           if (voteCount > 0) {

              vote /= voteCount; // average

           }

       } else {

           vote = Utils.maxIndex(votes); // majority vote

       }

 

       // error for instance

       outOfBagCount += data.instance(i).weight();

       if (numeric) {

           errorSum += StrictMath.abs(vote

                  - data.instance(i).classValue())

                  * data.instance(i).weight();

       } else {

           if (vote != data.instance(i).classValue())

              errorSum += data.instance(i).weight();

       }

    }

 

    m_OutOfBagError = errorSum / outOfBagCount;

} else {

    m_OutOfBagError = 0;

}

         这里inBag就可以判断哪几个分类器学习的时候有某一个样本,看out of bag错误率的时候,也就是用那些在学习时没有见过这个样本的分类器去分类这个样本,再用多数投票(majority vote)的方法决定分类结果。对于数值型的属性,就是将结果减去真实值,再乘权重,而对于离散型属性,只需要在分类错时,乘以权重累加到errorSum上。

 

 

 

 

 

 

 

  评论这张
 
阅读(2766)| 评论(0)
推荐 转载

历史上的今天

评论

<#--最新日志,群博日志--> <#--推荐日志--> <#--引用记录--> <#--博主推荐--> <#--随机阅读--> <#--首页推荐--> <#--历史上的今天--> <#--被推荐日志--> <#--上一篇,下一篇--> <#-- 热度 --> <#-- 网易新闻广告 --> <#--右边模块结构--> <#--评论模块结构--> <#--引用模块结构--> <#--博主发起的投票-->
 
 
 
 
 
 
 
 
 
 
 
 
 
 

页脚

网易公司版权所有 ©1997-2017