注册 登录  
 加关注
   显示下一条  |  关闭
温馨提示!由于新浪微博认证机制调整,您的新浪微博帐号绑定已过期,请重新绑定!立即重新绑定新浪微博》  |  关闭

Koala++'s blog

计算广告学 RTB

 
 
 

日志

 
 

Weka开发[22]——REPTree源代码分析(1)  

2009-05-28 12:46:41|  分类: 机器学习 |  标签: |举报 |字号 订阅

  下载LOFTER 我的照片书  |

    如果你分析完了ID3,还想进一步学习,最好还是先学习REPTree,它没有牵扯到那么多类,两个类完成了全部的工作,看起来比较清楚,J48虽然有很强的可扩展性,但是初看起来还是有些费力,REPTree也是我卖算法时(为了买一台运算能力强一点的计算机,我也不得不赚钱),顺便分析的,但因为我以前介绍过J48了,重复的东西不想再次介绍了,如果有什么不明白的,就把我两篇写的结合起来看吧。

    我们再次从buildClassifier开始。

Random random = new Random(m_Seed);

 

m_zeroR = null;

if (data.numAttributes() == 1) {

    m_zeroR = new ZeroR();

    m_zeroR.buildClassifier(data);

    return;

}

         如果就只有一个属性,也就是类别属性,就用ZeroR分类器学习,ZeroR分类器返回训练集中出现最多的类别值,已经讲过了Weka开发[15]

// Randomize and stratify

data.randomize(random);

if (data.classAttribute().isNominal()) {

    data.stratify(m_NumFolds);

}

         randomize就是把data中的数据重排一下,如果类别属性是离散值,那么用stratify函数,stratify意思是分层,现在把这个函数列出来:

public void stratify(int numFolds) {

if (classAttribute().isNominal()) {

       // sort by class

       int index = 1;

       while (index < numInstances()) {

           Instance instance1 = instance(index - 1);

           for (int j = index; j < numInstances(); j++) {

              Instance instance2 = instance(j);

              if ((instance1.classValue() == instance2.classValue())

                     || (instance1.classIsMissing() && instance2

                                .classIsMissing())) {

                  swap(index, j);

                  index++;

              }

           }

           index++;

       }

       stratStep(numFolds);

    }

}

         上面这两重循环,就是根据类别值进行冒泡。下面有调用了stratStep函数:

protected void stratStep(int numFolds) {

 

    FastVector newVec = new FastVector(m_Instances.capacity());

    int start = 0, j;

 

    // create stratified batch

    while (newVec.size() < numInstances()) {

       j = start;

       while (j < numInstances()) {

           newVec.addElement(instance(j));

           j = j + numFolds;

       }

       start++;

    }

    m_Instances = newVec;

}

         这里我举一个例子说明:j=0时,numFolds10时,newVec加入的instance下标就为0,10,20…。这样的好处就是我们把各种类别的样本类似平均分布了。

// Split data into training and pruning set

Instances train = null;

Instances prune = null;

if (!m_NoPruning) {

    train = data.trainCV(m_NumFolds, 0, random);

    prune = data.testCV(m_NumFolds, 0);

} else {

    train = data;

}

关于trainCV这个就不讲了,就是crossValidation的第0个训练集作为这次的训练集(train)。而作为剪枝的数据集prune为第0个测试集。

// Create array of sorted indices and weights

int[][] sortedIndices = new int[train.numAttributes()][0];

double[][] weights = new double[train.numAttributes()][0];

double[] vals = new double[train.numInstances()];

for (int j = 0; j < train.numAttributes(); j++) {

    if (j != train.classIndex()) {

       weights[j] = new double[train.numInstances()];

       if (train.attribute(j).isNominal()) {

 

           // Handling nominal attributes. Putting indices of

           // instances with missing values at the end.

           sortedIndices[j] = new int[train.numInstances()];

           int count = 0;

           for (int i = 0; i < train.numInstances(); i++) {

              Instance inst = train.instance(i);

              if (!inst.isMissing(j)) {

                  sortedIndices[j][count] = i;

                  weights[j][count] = inst.weight();

                  count++;

              }

           }

           for (int i = 0; i < train.numInstances(); i++) {

              Instance inst = train.instance(i);

              if (inst.isMissing(j)) {

                  sortedIndices[j][count] = i;

                  weights[j][count] = inst.weight();

                  count++;

              }

           }

       } else {

           // Sorted indices are computed for numeric attributes

           for (int i = 0; i < train.numInstances(); i++) {

              Instance inst = train.instance(i);

              vals[i] = inst.value(j);

           }

           sortedIndices[j] = Utils.sort(vals);

           for (int i = 0; i < train.numInstances(); i++) {

              weights[j][i] = train.instance(sortedIndices[j][i])

                     .weight();

           }

       }

    }

}

         sortedIndices表示第j属性的第count个样本下标是多少,weights表示第j个属性第count个样本的权重,如果j属性是离散值,通过两个for循环,在sortedIndicesweights中在j属性上是缺失值的样本就排在了后面。如果是连续值,那么就把全部样本j属性值得到,再排序,最后记录权重。

// Compute initial class counts

double[] classProbs = new double[train.numClasses()];

double totalWeight = 0, totalSumSquared = 0;

for (int i = 0; i < train.numInstances(); i++) {

    Instance inst = train.instance(i);

    if (data.classAttribute().isNominal()) {

       classProbs[(int) inst.classValue()] += inst.weight();

       totalWeight += inst.weight();

    } else {

       classProbs[0] += inst.classValue() * inst.weight();

       totalSumSquared += inst.classValue() * inst.classValue()

              * inst.weight();

       totalWeight += inst.weight();

    }

}

m_Tree = new Tree();

double trainVariance = 0;

if (data.classAttribute().isNumeric()) {

    trainVariance = m_Tree.singleVariance(classProbs[0],

           totalSumSquared, totalWeight) / totalWeight;

    classProbs[0] /= totalWeight;

}

         计算初始化类别概率,如果类别是离散值,classProbs中记录的是属性类别inst.classValue()的样本权重之和,totalWeight是全部样本权重和。如果类别是连续值,classProbs[0]中是权重乘以类别值,它还有一个totalSumSquared是类别值平方乘以权重之和。

         m_Tree是一个Tree对象,如果是连续值类别,用m_Tree的成员函数来计算trainVariance这个带权重的方差,最后classProbs[0]相当于期望。

// Build tree

m_Tree.buildTree(sortedIndices, weights, train, totalWeight,

       classProbs, new Instances(train, 0), m_MinNum,

       m_MinVarianceProp * trainVariance, 0, m_MaxDepth);

  评论这张
 
阅读(1843)| 评论(0)
推荐 转载

历史上的今天

评论

<#--最新日志,群博日志--> <#--推荐日志--> <#--引用记录--> <#--博主推荐--> <#--随机阅读--> <#--首页推荐--> <#--历史上的今天--> <#--被推荐日志--> <#--上一篇,下一篇--> <#-- 热度 --> <#-- 网易新闻广告 --> <#--右边模块结构--> <#--评论模块结构--> <#--引用模块结构--> <#--博主发起的投票-->
 
 
 
 
 
 
 
 
 
 
 
 
 
 

页脚

网易公司版权所有 ©1997-2017