注册 登录  
 加关注
   显示下一条  |  关闭
温馨提示!由于新浪微博认证机制调整,您的新浪微博帐号绑定已过期,请重新绑定!立即重新绑定新浪微博》  |  关闭

Koala++'s blog

计算广告学 RTB

 
 
 

日志

 
 

Weka开发[22]——REPTree源代码分析(2)  

2009-05-28 12:48:18|  分类: 机器学习 |  标签: |举报 |字号 订阅

  下载LOFTER 我的照片书  |

    有长度限制,我拆成了两部分。   

    好了,终于可以建树了,除了VC,我还真没怎么见过这么多参数。现在把它拆开分析:

// Store structure of dataset, set minimum number of instances

// and make space for potential info from pruning data

m_Info = header;

m_HoldOutDist = new double[data.numClasses()];

 

// Make leaf if there are no training instances

int helpIndex = 0;

if (data.classIndex() == 0) {

    helpIndex = 1;

}

if (sortedIndices[helpIndex].length == 0) {

    if (data.classAttribute().isNumeric()) {

       m_Distribution = new double[2];

    } else {

       m_Distribution = new double[data.numClasses()];

    }

    m_ClassProbs = null;

    return;

}

         m_Info保存的是数据集的表头结构,m_HoldOutDist后面会讲到,是用于剪枝的。这面这个有点意思,helpIndex在类别index不是0的情况下是1,否则是0,因为sortedIndices中没有类别列。初始化m_Distribution,如果是连续值,数组长度是2,第一个保存方差,后面是样本总权重。离散值不会说,当然是类别值个数。

double priorVar = 0;

if (data.classAttribute().isNumeric()) {

 

    // Compute prior variance

    double totalSum = 0, totalSumSquared = 0, totalSumOfWeights = 0;

    for (int i = 0; i < sortedIndices[helpIndex].length; i++) {

       Instance inst = data.instance(sortedIndices[helpIndex][i]);

       totalSum += inst.classValue() * weights[helpIndex][i];

       totalSumSquared += inst.classValue() * inst.classValue()

              * weights[helpIndex][i];

       totalSumOfWeights += weights[helpIndex][i];

    }

    priorVar = singleVariance(totalSum, totalSumSquared,

           totalSumOfWeights);

}

         这个就非常简单了,如果类别是连续值。再说一下,这里helpIndex无所谓,只要不是类别index就好。totalSum是类别值与样本权重的乘积和,totalSumSquared是类别值平方乘样本权重和,totalSumOfWeights是权重和。这里还是说一下,singleVariance就是变换后的期望计算公式。

// Check if node doesn't contain enough instances, is pure

// or the maximum tree depth is reached

m_ClassProbs = new double[classProbs.length];

System.arraycopy(classProbs, 0, m_ClassProbs, 0, classProbs.length);

if ((totalWeight < (2 * minNum))

       ||

 

       // Nominal case

       (data.classAttribute().isNominal() && Utils.eq(

              m_ClassProbs[Utils.maxIndex(m_ClassProbs)], Utils

                     .sum(m_ClassProbs)))

       ||

 

       // Numeric case

       (data.classAttribute().isNumeric() && ((priorVar / totalWeight)

 < minVariance))

       ||

 

       // Check tree depth

       ((m_MaxDepth >= 0) && (depth >= maxDepth))) {

 

    // Make leaf

    m_Attribute = -1;

    if (data.classAttribute().isNominal()) {

 

       // Nominal case

       m_Distribution = new double[m_ClassProbs.length];

       for (int i = 0; i < m_ClassProbs.length; i++) {

           m_Distribution[i] = m_ClassProbs[i];

       }

       Utils.normalize(m_ClassProbs);

    } else {

 

       // Numeric case

       m_Distribution = new double[2];

       m_Distribution[0] = priorVar;

       m_Distribution[1] = totalWeight;

    }

    return;

}

         先看一下不会再分裂的情况,第一种,总样本权重还不到最小分裂样本数的2(因为至少要分出来两个子结点嘛),第二种,类别是离散值的情况下,如果样本都属于一个类别(以前讲过为什么)。第三种,类别是连续值的情况下,如果方差小于一个最小方差,最小方差是由一个定义的常数与总方差的积。最后一种如果超过了定义的树的深度。

         如果是离散值,就将m_ClassProbs数组中的内容复制到m_Distribution中,再进行规范化,如果是连续值,把方差和总权重保存。

// Compute class distributions and value of splitting

// criterion for each attribute

double[] vals = new double[data.numAttributes()];

double[][][] dists = new double[data.numAttributes()][0][0];

double[][] props = new double[data.numAttributes()][0];

double[][] totalSubsetWeights = new double[data.numAttributes()][0];

double[] splits = new double[data.numAttributes()];

if (data.classAttribute().isNominal()) {

 

    // Nominal case

    for (int i = 0; i < data.numAttributes(); i++) {

       if (i != data.classIndex()) {

           splits[i] = distribution(props, dists, i,

                  sortedIndices[i], weights[i],

                  totalSubsetWeights, data);

           vals[i] = gain(dists[i], priorVal(dists[i]));

       }

    }

} else {

 

    // Numeric case

    for (int i = 0; i < data.numAttributes(); i++) {

       if (i != data.classIndex()) {

           splits[i] = numericDistribution(props, dists, i,

                  sortedIndices[i], weights[i],

                  totalSubsetWeights, data, vals);

       }

    }

}

         这里出现了一下ditribution函数,也是非常长,但是又很重要,所以我还是先介绍它:

double splitPoint = Double.NaN;

Attribute attribute = data.attribute(att);

double[][] dist = null;

int i;

 

if (attribute.isNominal()) {

 

    // For nominal attributes

    dist = new double[attribute.numValues()][data.numClasses()];

    for (i = 0; i < sortedIndices.length; i++) {

       Instance inst = data.instance(sortedIndices[i]);

       if (inst.isMissing(att)) {

           break;

       }

       dist[(int) inst.value(att)][(int) inst.classValue()] +=

           weights[i];

    }

}

         先讲一下离散值的情况,实现与j48包下面的Distribution非常相似,dist第一维是属性值,第二维是类别值,元素值是样本权重累加值。

else {

    // For numeric attributes

    double[][] currDist = new double[2][data.numClasses()];

    dist = new double[2][data.numClasses()];

 

    // Move all instances into second subset

    for (int j = 0; j < sortedIndices.length; j++) {

        Instance inst = data.instance(sortedIndices[j]);

       if (inst.isMissing(att)) {

           break;

       }

       currDist[1][(int) inst.classValue()] += weights[j];

    }

    double priorVal = priorVal(currDist);

    System.arraycopy(currDist[1], 0, dist[1], 0, dist[1].length);

 

    // Try all possible split points

    double currSplit = data.instance(sortedIndices[0]).value(att);

    double currVal, bestVal = -Double.MAX_VALUE;

    for (i = 0; i < sortedIndices.length; i++) {

       Instance inst = data.instance(sortedIndices[i]);

       if (inst.isMissing(att)) {

           break;

       }

       if (inst.value(att) > currSplit) {

           currVal = gain(currDist, priorVal);

           if (currVal > bestVal) {

              bestVal = currVal;

              splitPoint = (inst.value(att) + currSplit) / 2.0;

              for (int j = 0; j < currDist.length; j++) {

                  System.arraycopy(currDist[j], 0, dist[j], 0,

                         dist[j].length);

              }

           }

       }

       currSplit = inst.value(att);

       currDist[0][(int) inst.classValue()] += weights[i];

       currDist[1][(int) inst.classValue()] -= weights[i];

    }

}

         不想讲了,和J48也是一样,先把样本存在后一子结点中currDist[1],然后依次试属性值,找到一个最好看分裂点。

// Compute weights

props[att] = new double[dist.length];

for (int k = 0; k < props[att].length; k++) {

    props[att][k] = Utils.sum(dist[k]);

}

if (!(Utils.sum(props[att]) > 0)) {

    for (int k = 0; k < props[att].length; k++) {

       props[att][k] = 1.0 / (double) props[att].length;

    }

} else {

    Utils.normalize(props[att]);

}

         props中保存的就是第att个属性的第k个属性值的样本权重之和。如果这个值不太于0,就给它赋值为1除以这个属性的全部可能取值。否则规范化。

// Distribute counts

while (i < sortedIndices.length) {

    Instance inst = data.instance(sortedIndices[i]);

    for (int j = 0; j < dist.length; j++) {

       dist[j][(int) inst.classValue()] += props[att][j]

              * weights[i];

    }

    i++;

}

 

// Compute subset weights

subsetWeights[att] = new double[dist.length];

for (int j = 0; j < dist.length; j++) {

    subsetWeights[att][j] += Utils.sum(dist[j]);

}

 

// Return distribution and split point

dists[att] = dist;

return splitPoint;

         i这里初始是有确定属性值与缺失值的分界下标值,开始一时头晕还没看出来,调试才看出来。如果有缺失值,就用每一个属性值都加上相应的权重来代替。在att属性上分裂,那种子结点的权重和为distj这种属性取值上的和。最后把dist赋值给dists[att],返回分裂点。

         现在再跳回到buildTree函数,接着讲gain函数就是计算信息增益,不讲了。numericDistribution还是这么长,而且也差不多,也就算了吧。

// Find best attribute

m_Attribute = Utils.maxIndex(vals);

int numAttVals = dists[m_Attribute].length;

 

// Check if there are at least two subsets with

// required minimum number of instances

int count = 0;

for (int i = 0; i < numAttVals; i++) {

    if (totalSubsetWeights[m_Attribute][i] >= minNum) {

       count++;

    }

    if (count > 1) {

       break;

    }

}

         vals中信息增益值,m_Attribute就是有最大信息增益值的属性下标,再下来看是否这个属性可以分出两个大于minNum样本数的子结点。

// Any useful split found?

if ((vals[m_Attribute] > 0) && (count > 1)) {

 

    // Build subtrees

    m_SplitPoint = splits[m_Attribute];

    m_Prop = props[m_Attribute];

    int[][][] subsetIndices = new int[numAttVals][data

           .numAttributes()][0];

    double[][][] subsetWeights = new double[numAttVals][data

           .numAttributes()][0];

    splitData(subsetIndices, subsetWeights, m_Attribute,

           m_SplitPoint, sortedIndices, weights, data);

    m_Successors = new Tree[numAttVals];

    for (int i = 0; i < numAttVals; i++) {

       m_Successors[i] = new Tree();

       m_Successors[i].buildTree(subsetIndices[i],

              subsetWeights[i], data,

              totalSubsetWeights[m_Attribute][i],

              dists[m_Attribute][i], header, minNum, minVariance,

              depth + 1, maxDepth);

    }

} else {

 

    // Make leaf

    m_Attribute = -1;

}

         如果找到了可以分裂的属性,那我们就可以建立了树了,看起来乱七八糟很复杂的样子,其实如果你把上面讲的搞清楚了,这里和ID3J48没有什么区别。如果不能分裂,就把m_Attribute1,标记一下。

// Normalize class counts

if (data.classAttribute().isNominal()) {

    m_Distribution = new double[m_ClassProbs.length];

    for (int i = 0; i < m_ClassProbs.length; i++) {

       m_Distribution[i] = m_ClassProbs[i];

    }

    Utils.normalize(m_ClassProbs);

} else {

    m_Distribution = new double[2];

    m_Distribution[0] = priorVar;

    m_Distribution[1] = totalWeight;

}

         这个其实没什么好讲的,只是赋值到m_Distribution,建树就已经讲完了。但是在buildClassifier我们还剩下三行,是关于剪枝的,当时在介绍J48的时候,就没有讲,因为我不需要用那部分,当时也没怎么看。

// Insert pruning data and perform reduced error pruning

if (!m_NoPruning) {

    m_Tree.insertHoldOutSet(prune);

    m_Tree.reducedErrorPrune();

    m_Tree.backfitHoldOutSet(prune);

}

         如果非不剪枝,那么就是剪枝了,先看第一个被调用的函数:

protected void insertHoldOutSet(Instances data) throws Exception {

 

    for (int i = 0; i < data.numInstances(); i++) {

       insertHoldOutInstance(data.instance(i), data.instance(i)

              .weight(), this);

    }

}

         prune数据集中的每一个样本作为参数调用insertHoldOutInstance,它也有点长,把它一部分一部分列出来:

// Insert instance into hold-out class distribution

if (inst.classAttribute().isNominal()) {

 

    // Nominal case

    m_HoldOutDist[(int) inst.classValue()] += weight;

    int predictedClass = 0;

    if (m_ClassProbs == null) {

       predictedClass = Utils.maxIndex(parent.m_ClassProbs);

    } else {

       predictedClass = Utils.maxIndex(m_ClassProbs);

    }

    if (predictedClass != (int) inst.classValue()) {

       m_HoldOutError += weight;

    }

} else {

 

    // Numeric case

    m_HoldOutDist[0] += weight;

    double diff = 0;

    if (m_ClassProbs == null) {

       diff = parent.m_ClassProbs[0] - inst.classValue();

    } else {

       diff = m_ClassProbs[0] - inst.classValue();

    }

    m_HoldOutError += diff * diff * weight;

}

         看一下离散的情况,如果是离散类别,看它预测出的类别是否与真实类别相同,如果不同,就把样本权重累加到m_HoldOutError上,其中==null的情况应该是这个叶子结点上曾经分的时候就没样本。在连续类别时,是把预测值与真实值的差的平方乘权重加到m_holdOutError上,

// The process is recursive

if (m_Attribute != -1) {

 

    // If node is not a leaf

    if (inst.isMissing(m_Attribute)) {

 

       // Distribute instance

       for (int i = 0; i < m_Successors.length; i++) {

           if (m_Prop[i] > 0) {

              m_Successors[i].insertHoldOutInstance(inst, weight

                     * m_Prop[i], this);

           }

       }

    } else {

 

       if (m_Info.attribute(m_Attribute).isNominal()) {

 

           // Treat nominal attributes

           m_Successors[(int) inst.value(m_Attribute)]

                  .insertHoldOutInstance(inst, weight, this);

       } else {

 

           // Treat numeric attributes

           if (inst.value(m_Attribute) < m_SplitPoint) {

              m_Successors[0].insertHoldOutInstance(inst, weight,

                      this);

           } else {

              m_Successors[1].insertHoldOutInstance(inst, weight,

                     this);

           }

       }

    }

}

         m_Attribute等于-1时就是叶子结点,前面已经讲过了,如果是缺失值的情况,又是把所有可能算一遍(前两天看论文,有一篇论文提到对缺失值的运行,在C4.5中占到了80%的时间)。如果不是缺失值就递归。这个函数整体的含义就是计算父结点和子结点,为最后看分还是不分好做准备。

         好了,看第二个函数:

protected double reducedErrorPrune() throws Exception {

 

    // Is node leaf ?

    if (m_Attribute == -1) {

       return m_HoldOutError;

    }

 

    // Prune all sub trees

    double errorTree = 0;

    for (int i = 0; i < m_Successors.length; i++) {

       errorTree += m_Successors[i].reducedErrorPrune();

    }

 

    // Replace sub tree with leaf if error doesn't get worse

    if (errorTree >= m_HoldOutError) {

       m_Attribute = -1;

       m_Successors = null;

       return m_HoldOutError;

    } else {

       return errorTree;

    }

}

         如果开始就是叶子结点,太不可思议了,直接返回。接下来,这是一个递归,递归就在做一件事情,如果几个子结点的错误加起来比父结点还高,意思也就是说分裂比不分裂还要差,那么我们就把子结点剪去,也就是剪枝,在这里是剪叶子?剪枝的时候,设置m_Attribute,然后把子结点置空,返回父结点的错误值。

         最后一个函数:

protected void backfitHoldOutSet(Instances data) throws Exception {

 

    for (int i = 0; i < data.numInstances(); i++) {

       backfitHoldOutInstance(data.instance(i), data.instance(i)

              .weight(), this);

    }

}

         backfitHoldOutInstance不难,但是还有有点长,分开贴出来:

// Insert instance into hold-out class distribution

if (inst.classAttribute().isNominal()) {

 

    // Nominal case

    if (m_ClassProbs == null) {

       m_ClassProbs = new double[inst.numClasses()];

    }

    System.arraycopy(m_Distribution, 0, m_ClassProbs, 0, inst

           .numClasses());

    m_ClassProbs[(int) inst.classValue()] += weight;

    Utils.normalize(m_ClassProbs);

} else {

 

    // Numeric case

    if (m_ClassProbs == null) {

       m_ClassProbs = new double[1];

    }

    m_ClassProbs[0] *= m_Distribution[1];

    m_ClassProbs[0] += weight * inst.classValue();

    m_ClassProbs[0] /= (m_Distribution[1] + weight);

}

         这个函数主要是把以前用训练集测出来的值,现在把剪枝集的样本信息也加进去。这些以前也都讲过。

// The process is recursive

if (m_Attribute != -1) {

 

    // If node is not a leaf

    if (inst.isMissing(m_Attribute)) {

 

       // Distribute instance

       for (int i = 0; i < m_Successors.length; i++) {

           if (m_Prop[i] > 0) {

              m_Successors[i].backfitHoldOutInstance(inst, weight

                     * m_Prop[i], this);

           }

       }

    } else {

 

       if (m_Info.attribute(m_Attribute).isNominal()) {

 

           // Treat nominal attributes

           m_Successors[(int) inst.value(m_Attribute)]

                  .backfitHoldOutInstance(inst, weight, this);

       } else {

 

           // Treat numeric attributes

           if (inst.value(m_Attribute) < m_SplitPoint) {

              m_Successors[0].backfitHoldOutInstance(inst,

                     weight, this);

           } else {

              m_Successors[1].backfitHoldOutInstance(inst,

                     weight, this);

           }

       }

    }

}

         不想讲了,自己看吧,distributionForInstance也不讲了,如果是一直看我的东西过来的,到现在还不明白,我也没话说了。

  评论这张
 
阅读(1973)| 评论(0)
推荐 转载

历史上的今天

评论

<#--最新日志,群博日志--> <#--推荐日志--> <#--引用记录--> <#--博主推荐--> <#--随机阅读--> <#--首页推荐--> <#--历史上的今天--> <#--被推荐日志--> <#--上一篇,下一篇--> <#-- 热度 --> <#-- 网易新闻广告 --> <#--右边模块结构--> <#--评论模块结构--> <#--引用模块结构--> <#--博主发起的投票-->
 
 
 
 
 
 
 
 
 
 
 
 
 
 

页脚

网易公司版权所有 ©1997-2017