注册 登录  
 加关注
   显示下一条  |  关闭
温馨提示!由于新浪微博认证机制调整,您的新浪微博帐号绑定已过期,请重新绑定!立即重新绑定新浪微博》  |  关闭

Koala++'s blog

计算广告学 RTB

 
 
 

日志

 
 

Weka开发[19]——NaiveBayes源代码分析  

2009-05-04 18:20:51|  分类: 机器学习 |  标签: |举报 |字号 订阅

  下载LOFTER 我的照片书  |

         本来不想自己写Naïve Bayes这篇源码分析的,叫不动人,没办法,只好自己写了。请读者自己看一下论文Estimating Continuous Distributions in Bayesian Classifiers。什么都不懂的读者请看Tom mitchellGenerative and Discriminative Classifiers: Naive Bayes and Logistic Regression

         直接看buildClassifier函数,下面我只把我认为重要的代码列出来:

if (m_UseDiscretization) {

    m_Disc = new weka.filters.supervised.attribute.Discretize();

    m_Disc.setInputFormat(m_Instances);

    m_Instances = weka.filters.Filter.useFilter(m_Instances, m_Disc);

} else {

    m_Disc = null;

}

         如果需要进行离散化,就进行离散化,有人问过我如何离散化,上面就是。

// Reserve space for the distributions

m_Distributions = new Estimator[m_Instances.numAttributes() - 1]

[m_Instances.numClasses()];

m_ClassDistribution = new DiscreteEstimator(m_Instances.numClasses(),

true);

m_Distributions就是P(C)m_ClassDistribution就是P(X|C)

int attIndex = 0;

Enumeration enu = m_Instances.enumerateAttributes();

while (enu.hasMoreElements()) {

    Attribute attribute = (Attribute) enu.nextElement();

    ……

    attIndex++;

}

    循环对每一个特征进行处理

// If the attribute is numeric, determine the estimator

// numeric precision from differences between adjacent values

double numPrecision = DEFAULT_NUM_PRECISION;

if (attribute.type() == Attribute.NUMERIC) {

    m_Instances.sort(attribute);

    if ((m_Instances.numInstances() > 0)

       && !m_Instances.instance(0).isMissing(attribute)) {

       double lastVal = m_Instances.instance(0).value(attribute);

       double currentVal, deltaSum = 0;

       int distinct = 0;

       for (int i = 1; i < m_Instances.numInstances(); i++) {

           Instance currentInst = m_Instances.instance(i);

           if (currentInst.isMissing(attribute)) {

              break;

           }

           currentVal = currentInst.value(attribute);

           if (currentVal != lastVal) {

              deltaSum += currentVal - lastVal;

              lastVal = currentVal;

              distinct++;

           }

       }

       if (distinct > 0) {

           numPrecision = deltaSum / distinct;

       }

    }

}

这一大段代码令我惊讶的是只是为了确定精度(其实它是为了可以增量式的学习),精度就是平时说的保留几位小数,不过这里不是保留多少位。先看一下代码,先对m_Instances进行排序,以前也说过排序后,这个属性的上是缺失值的样本就排到了最前面,判断如果第一个样本在这个属性上缺失值,那么就不用执行了(instances.deleteWithMissingClass();这一句已经执行了,所以不太可能发生)。接下来,得到每个样本的在当前属性的属性值currentVal,如果与前一个样本在当前属性的属性值不同,则相减,将每次差值累加至deltaSum中,最后numPrecision就是差值之和deltaSum除所有不同的属性值。

for (int j = 0; j < m_Instances.numClasses(); j++) {

    switch (attribute.type()) {

       case Attribute.NUMERIC:

           if (m_UseKernelEstimator) {

               m_Distributions[attIndex][j] = new

                  KernelEstimator(numPrecision);

           } else {

              m_Distributions[attIndex][j] = new

                  NormalEstimator(numPrecision);

           }

           break;

       case Attribute.NOMINAL:

           m_Distributions[attIndex][j] = new

                  DiscreteEstimator(attribute.numValues(), true);

           break;

       default:

           throw new Exception("Attribute type unknown to NaiveBayes");

    }

}

这段代码写的看起来有点怪。判断当前属性的类型,如果是NUMERIC也就是连续属值,你可以选择KernelEstimator也可以用NormalEstimator,都用numPrecision构造参数。区别在论文中已经讲的很清楚了,两者都是用平均值和方差来计算,这也是常识了。如果是NOMINAL也就是离散值,那就用DiscreteEstimator

// Compute counts

Enumeration enumInsts = m_Instances.enumerateInstances();

while (enumInsts.hasMoreElements()) {

    Instance instance = (Instance) enumInsts.nextElement();

    updateClassifier(instance);

}

终于到了有点意义的代码,对每一个样本进行统计。updateClassifier就是根据样本更新分类器,Naïve Bayes可以是增量式的,这总是知道的吧。

public void updateClassifier(Instance instance) throws Exception {

    if (!instance.classIsMissing()) {

       Enumeration enumAtts = m_Instances.enumerateAttributes();

       int attIndex = 0;

       while (enumAtts.hasMoreElements()) {

           Attribute attribute = (Attribute) enumAtts.nextElement();

           if (!instance.isMissing(attribute)) {

              m_Distributions[attIndex][(int) instance.classValue()]

              .addValue(instance.value(attribute),instance.weight());

           }

           attIndex++;

       }

m_ClassDistribution.addValue(instance.classValue(),

instance.weight());

    }

}

进行统计,m_Distributions第一个下标就是当前属性的下标,第二个下标是类别值。最重要的函数就是addValue了,它对样本的对应类别属性值分布进行统计。最后m_ClassDistribution是对类别进行统计。

         下面看一下最简单的DiscreteEstimator的构造函数:

public DiscreteEstimator(int numSymbols, boolean laplace) {

    m_Counts = new double[numSymbols];

    m_SumOfCounts = 0;

    if (laplace) {

       for (int i = 0; i < numSymbols; i++) {

           m_Counts[i] = 1;

       }

       m_SumOfCounts = (double) numSymbols;

    }

}

public DiscreteEstimator(int nSymbols, double fPrior) {

    m_Counts = new double[nSymbols];

    for (int iSymbol = 0; iSymbol < nSymbols; iSymbol++) {

       m_Counts[iSymbol] = fPrior;

    }

    m_SumOfCounts = fPrior * (double) nSymbols;

}

         也没什么区别,第一个用Laplace,第二个不知道是什么,反正差不多。

public void addValue(double data, double weight) {

    m_Counts[(int) data] += weight;

    m_SumOfCounts += weight;

}

离散型的addValue非常简单,就是在对应的属性值上加上这个样本的权重。

再看一下NormalEstimator的构造函数:

public NormalEstimator(double precision) {

    m_Precision = precision;

    // Allow at most 3 sd's within one interval

    m_StandardDev = m_Precision / (2 * 3);

}

         精度已经解释过了,再说一次,这里的精度不是精准到第几位,而是一个值。下面的2我想我应该是对的,可是我怕我的想法是错的,讲出来会被人笑死,如果有人有想法,讲一声,3的意思是在这个精度范围内最多能有三个标准差。

public void addValue(double data, double weight) {

    if (weight == 0) {

       return;

    }

    data = round(data);

    m_SumOfWeights += weight;

    m_SumOfValues += data * weight;

    m_SumOfValuesSq += data * data * weight;

 

    if (m_SumOfWeights > 0) {

       m_Mean = m_SumOfValues / m_SumOfWeights;

       double stdDev = Math.sqrt(Math.abs(m_SumOfValuesSq - m_Mean

              * m_SumOfValues)

              / m_SumOfWeights);

       // If the stdDev ~= 0, we really have no idea of scale yet,

       // so stick with the default. Otherwise...

       if (stdDev > 1e-10) {

           m_StandardDev = Math.max(m_Precision / (2 * 3),

           // allow at most 3sd's within one interval

              stdDev);

       }

    }

}

这段程序没什么好讲的,有兴趣可以去Wiki搜索Algorithms for calculating variance词条,里有Weighted incremental algorithm可能看起来更清楚一点。

         下面看一下distributionForInstance函数。

public double[] distributionForInstance(Instance instance) {

    double[] probs = new double[m_NumClasses];

    for (int j = 0; j < m_NumClasses; j++) {

       probs[j] = m_ClassDistribution.getProbability(j);

    }

    Enumeration enumAtts = instance.enumerateAttributes();

    int attIndex = 0;

    while (enumAtts.hasMoreElements()) {

       Attribute attribute = (Attribute) enumAtts.nextElement();

       if (!instance.isMissing(attribute)) {

           double temp, max = 0;

           for (int j = 0; j < m_NumClasses; j++) {

              temp = Math.max(1e-75, Math.pow(

                     m_Distributions[attIndex][j]

                     .getProbability(instance.value(attribute)),

                      m_Instances.attribute(attIndex).weight()));

              probs[j] *= temp;

              if (probs[j] > max) {

                  max = probs[j];

              }

           }

       }

    }

    attIndex++;

    // Display probabilities

    Utils.normalize(probs);

    return probs;

}

首先得到类别的概率,希望你还记得公式是什么,对于每一个类别,计算在每个类别上的概率,也就是tempprobs[j] *= temp还是公式。最后看一下哪一个类别是最有可能的类别。

         DiscreteEstimatorNormalEstimatorgetProbability函数分别如下:

public double getProbability(double data) {

    if (m_SumOfCounts == 0) {

       return 0;

    }

    return (double) m_Counts[(int) data] / m_SumOfCounts;

}

public double getProbability(double data) {

    data = round(data);

    double zLower = (data - m_Mean - (m_Precision / 2)) / m_StandardDev;

    double zUpper = (data - m_Mean + (m_Precision / 2)) / m_StandardDev;

 

    double pLower = Statistics.normalProbability(zLower);

    double pUpper = Statistics.normalProbability(zUpper);

    return pUpper - pLower;

}

         第一个没什么好讲的,直接返回,第二个我只明白+-(m_Precision/2)的意思是根据精度求它可能的最小值和最大值。

         如果本科时多学点计算方法,概率统计可能今天不会这么痛苦。

  评论这张
 
阅读(4423)| 评论(8)
推荐 转载

历史上的今天

评论

<#--最新日志,群博日志--> <#--推荐日志--> <#--引用记录--> <#--博主推荐--> <#--随机阅读--> <#--首页推荐--> <#--历史上的今天--> <#--被推荐日志--> <#--上一篇,下一篇--> <#-- 热度 --> <#-- 网易新闻广告 --> <#--右边模块结构--> <#--评论模块结构--> <#--引用模块结构--> <#--博主发起的投票-->
 
 
 
 
 
 
 
 
 
 
 
 
 
 

页脚

网易公司版权所有 ©1997-2017