Weka Development [28]: EM Source Code Analysis (1)

2009-10-06 19:55:41 | Category: Machine Learning


The EM algorithm lives in Weka's clusterers package; I mention this only because I did not expect to find it there.

Quoting Andrew Ng's lecture notes, "Mixtures of Gaussians and the EM algorithm": "The EM-algorithm is also reminiscent of the K-means clustering algorithm, except that instead of the 'hard' cluster assignments c(i), we instead have the 'soft' assignments w_j^(i). Similar to K-means, it is also susceptible to local optima, so reinitializing at several different initial parameters may be a good idea."

"Soft" means that our guesses are probabilities taking values in [0, 1]; by contrast, a "hard" guess is a single best guess, taking values in {0,1} or {1,…,k}. In the original wording: "The term 'soft' refers to our guesses being probabilities and taking values in [0,1]; in contrast, a 'hard' guess is one that represents a single best guess (such as taking values in {0,1} or {1,…,k})."
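To make the distinction concrete, here is a minimal sketch (my own illustration, not Weka code; all names are made up) computing both kinds of assignment for one point under a two-component 1-D Gaussian mixture:

// Illustration only (not Weka code): hard vs. soft assignment of a
// single point x under a two-component 1-D Gaussian mixture.
public class SoftVsHard {
    // density of N(mean, stdDev) evaluated at x
    static double density(double x, double mean, double stdDev) {
        double z = (x - mean) / stdDev;
        return Math.exp(-0.5 * z * z) / (stdDev * Math.sqrt(2 * Math.PI));
    }

    public static void main(String[] args) {
        double x = 1.2;
        double[] mean = { 0.0, 3.0 }, std = { 1.0, 1.0 }, prior = { 0.5, 0.5 };

        // "soft" assignment: w_j proportional to prior_j * N(x; mean_j, std_j)
        double[] w = new double[2];
        double sum = 0;
        for (int j = 0; j < 2; j++) {
            w[j] = prior[j] * density(x, mean[j], std[j]);
            sum += w[j];
        }
        for (int j = 0; j < 2; j++)
            w[j] /= sum;

        // "hard" assignment: keep only the single best guess
        int c = (w[0] >= w[1]) ? 0 : 1;
        System.out.printf("w = [%.3f, %.3f], hard c = %d%n", w[0], w[1], c);
    }
}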

The figures below are from Andrew Ng and Christopher Bishop: in the first set, K-Means guesses one of two centers for each point, while in the second set, EM guesses a probability.

[Figures omitted: K-Means hard assignments vs. EM soft assignments]

The other point, reinitializing from several starting parameters as the quote suggests, also shows up in the code.

After proving that EM converges, Ng explains: "Hence, EM causes the likelihood to converge monotonically. In our description of the EM algorithm, we said we'd run it until convergence. Given the result that we just showed, one reasonable convergence test would be to check if the increase in l(θ) between successive iterations is smaller than some tolerance parameter, and to declare convergence if EM is improving l(θ) too slowly."
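That test is easy to picture in code; a minimal sketch (hypothetical names, not Weka's iterate):

// Sketch of the convergence test Ng describes (hypothetical names,
// not Weka's code): stop once the log-likelihood gain falls below tol.
interface EmStep {
    double iterate(); // one E-step + M-step, returns log-likelihood
}

static double runUntilConverged(EmStep em, int maxIter, double tol) {
    double prev = em.iterate();
    for (int it = 1; it < maxIter; it++) {
        double ll = em.iterate();
        if (ll - prev < tol) {
            return ll; // improving too slowly: declare convergence
        }
        prev = ll;
    }
    return prev;
}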

We start with buildClusterer:

if (data.checkForStringAttributes()) {
    throw new Exception("Can't handle string attributes!");
}

m_replaceMissing = new ReplaceMissingValues();
Instances instances = new Instances(data);
instances.setClassIndex(-1);
m_replaceMissing.setInputFormat(instances);
data = weka.filters.Filter.useFilter(instances, m_replaceMissing);
instances = null;

m_theInstances = data;

// calculate min and max values for attributes
m_minValues = new double[m_theInstances.numAttributes()];
m_maxValues = new double[m_theInstances.numAttributes()];
for (int i = 0; i < m_theInstances.numAttributes(); i++) {
    m_minValues[i] = m_maxValues[i] = Double.NaN;
}
for (int i = 0; i < m_theInstances.numInstances(); i++) {
    updateMinMax(m_theInstances.instance(i));
}

ReplaceMissingValues replaces missing values with the attribute mean (numeric attributes) or mode (nominal attributes). m_minValues and m_maxValues are arrays holding, respectively, the minimum and maximum value of each attribute.

private void updateMinMax(Instance instance) {
    for (int j = 0; j < m_theInstances.numAttributes(); j++) {
        if (!instance.isMissing(j)) {
            if (Double.isNaN(m_minValues[j])) {
                m_minValues[j] = instance.value(j);
                m_maxValues[j] = instance.value(j);
            } else if (instance.value(j) < m_minValues[j]) {
                m_minValues[j] = instance.value(j);
            } else if (instance.value(j) > m_maxValues[j]) {
                m_maxValues[j] = instance.value(j);
            }
        }
    }
}

Double.isNaN here tests whether the slot has ever been filled with a real attribute value; the rest of the code just tracks the minimum and maximum of the j-th attribute.

doEM();

// save memory
m_theInstances = new Instances(m_theInstances, 0);

After doEM the training data is discarded to save memory, so all the real work happens inside doEM:

private void doEM() throws Exception {

    m_rr = new Random(m_rseed);

    // throw away numbers to avoid problem of similar initial numbers
    // from a similar seed
    for (int i = 0; i < 10; i++)
        m_rr.nextDouble();

    m_num_instances = m_theInstances.numInstances();
    m_num_attribs = m_theInstances.numAttributes();

    // setDefaultStdDevs(theInstances);
    // cross validate to determine number of clusters?
    if (m_initialNumClusters == -1) {
        if (m_theInstances.numInstances() > 9) {
            CVClusters();
            m_rr = new Random(m_rseed);
            for (int i = 0; i < 10; i++)
                m_rr.nextDouble();
        } else {
            m_num_clusters = 1;
        }
    }

    // fit full training set
    EM_Init(m_theInstances);
    m_loglikely = iterate(m_theInstances, m_verbose);
}

The first ten random numbers are thrown away so that similar seeds do not produce similar initial values (presumably this matters for the code below). If m_initialNumClusters == -1, the number of clusters was not specified, so cross-validation has to decide it: with more than 9 instances, CVClusters makes the choice; with 9 or fewer, a single cluster is assumed. EM_Init then initializes the model, and iterate runs the EM iterations. We will not go into CVClusters; assume m_initialNumClusters was given and turn to EM_Init.
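Still, the idea behind CVClusters is easy to sketch (an outline under hypothetical names, not a transcription of Weka's code): try increasing cluster counts and keep the k with the best 10-fold held-out log-likelihood.

// Rough outline of CV-based selection of the cluster count
// (hypothetical names; not Weka's actual CVClusters).
interface FoldScorer {
    // train EM with k clusters on the training folds and
    // return the log-likelihood on the held-out fold
    double score(int k, int fold);
}

static int selectNumClusters(FoldScorer scorer, int maxK, int numFolds) {
    int bestK = 1;
    double best = Double.NEGATIVE_INFINITY;
    for (int k = 1; k <= maxK; k++) {
        double cv = 0;
        for (int fold = 0; fold < numFolds; fold++) {
            cv += scorer.score(k, fold);
        }
        if (cv > best) {
            best = cv;
            bestK = k;
        }
    }
    return bestK;
}

Now, the opening of EM_Init: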

// run k means 10 times and choose best solution
SimpleKMeans bestK = null;
double bestSqE = Double.MAX_VALUE;
for (i = 0; i < 10; i++) {
    SimpleKMeans sk = new SimpleKMeans();
    sk.setSeed(m_rr.nextInt());
    sk.setNumClusters(m_num_clusters);
    sk.buildClusterer(inst);
    if (sk.getSquaredError() < bestSqE) {
        bestSqE = sk.getSquaredError();
        bestK = sk;
    }
}

SimpleKMeans is run here with ten different random seeds, and the run with the smallest squared error is kept as the best SimpleKMeans object.

// initialize with best k-means solution
m_num_clusters = bestK.numberOfClusters();
m_weights = new double[inst.numInstances()][m_num_clusters];
m_model = new DiscreteEstimator[m_num_clusters][m_num_attribs];
m_modelNormal = new double[m_num_clusters][m_num_attribs][3];
m_priors = new double[m_num_clusters];
Instances centers = bestK.getClusterCentroids();
Instances stdD = bestK.getClusterStandardDevs();
int[][][] nominalCounts = bestK.getClusterNominalCounts();
int[] clusterSizes = bestK.getClusterSizes();

centers holds the centroids of the clusters and stdD their per-attribute standard deviations, while nominalCounts is indexed first by cluster, then by attribute, then by attribute value, giving the count of each nominal value within each cluster.

for (i = 0; i < m_num_clusters; i++) {
    Instance center = centers.instance(i);
    for (j = 0; j < m_num_attribs; j++) {
        if (inst.attribute(j).isNominal()) {
            m_model[i][j] = new DiscreteEstimator(m_theInstances
                    .attribute(j).numValues(), true);
            for (k = 0; k < inst.attribute(j).numValues(); k++) {
                m_model[i][j].addValue(k, nominalCounts[i][j][k]);
            }
        } else {
            double minStdD = (m_minStdDevPerAtt != null) ?
                m_minStdDevPerAtt[j] : m_minStdDev;
            double mean = (center.isMissing(j)) ? inst.meanOrMode(j)
                    : center.value(j);
            m_modelNormal[i][j][0] = mean;
            double stdv = (stdD.instance(i).isMissing(j)) ?
                ((m_maxValues[j] - m_minValues[j]) / (2 * m_num_clusters))
                    : stdD.instance(i).value(j);
            if (stdv < minStdD) {
                stdv = inst.attributeStats(j).numericStats.stdDev;
                if (Double.isInfinite(stdv)) {
                    stdv = minStdD;
                }
                if (stdv < minStdD) {
                    stdv = minStdD;
                }
            }
            if (stdv <= 0) {
                stdv = m_minStdDev;
            }

            m_modelNormal[i][j][1] = stdv;
            m_modelNormal[i][j][2] = 1.0;
        }
    }
}

DiscreteEstimator is a class that maintains counts for discrete (nominal) values; its constructor looks like this:

public DiscreteEstimator(int numSymbols, boolean laplace) {

    m_Counts = new double[numSymbols];
    m_SumOfCounts = 0;
    if (laplace) {
        for (int i = 0; i < numSymbols; i++) {
            m_Counts[i] = 1;
        }
        m_SumOfCounts = (double) numSymbols;
    }
}

Laplace smoothing is used here: each entry of m_Counts starts at 1 (the familiar add-one in the numerator of the estimate), and m_SumOfCounts starts at the number of possible values (the extra term added to the denominator).
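So with N observations in total and numSymbols possible values, the estimate for a value v becomes (count(v) + 1) / (N + numSymbols). A tiny sketch of that arithmetic (my own illustration, not Weka code):

// Laplace-smoothed estimate produced by the initialization above:
// P(v) = (count(v) + 1) / (N + numSymbols).
static double laplaceProbability(double[] rawCounts, int v) {
    double n = 0;
    for (double c : rawCounts) {
        n += c; // N: total unsmoothed observations
    }
    return (rawCounts[v] + 1.0) / (n + rawCounts.length);
}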

public void addValue(double data, double weight) {

    m_Counts[(int) data] += weight;
    m_SumOfCounts += weight;
}

addValue is straightforward: it adds the given weight to the count of the corresponding value. The indirection exists because a count-based estimator like this cannot model continuous values. Back in EM_Init, minStdD is a lower bound on the standard deviation, the mean is taken from the centroid, and stdv is the value computed by SimpleKMeans. So m_modelNormal[i][j][0] is the mean, m_modelNormal[i][j][1] the standard deviation, and m_modelNormal[i][j][2] holds a weight, initialized here to 1.0.
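For orientation, here is roughly how such a (mean, standard deviation) pair can score an instance for one cluster, accumulated in log space to avoid underflow; this is a sketch of the idea, not a transcription of Weka's E-step, and all names are mine:

// Sketch (not Weka's exact code): score instance x for one cluster
// from its per-attribute triples modelNormal_i[j] = {mean, stdDev, weight}.
static double logClusterScore(double[] x, double[][] modelNormal_i,
        double prior) {
    double logScore = Math.log(prior);
    for (int j = 0; j < x.length; j++) {
        double mean = modelNormal_i[j][0]; // slot 0: mean
        double sd = modelNormal_i[j][1];   // slot 1: standard deviation
        double z = (x[j] - mean) / sd;
        // add log N(x_j; mean, sd)
        logScore += -0.5 * z * z - Math.log(sd * Math.sqrt(2.0 * Math.PI));
    }
    return logScore;
}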

for (j = 0; j < m_num_clusters; j++) {
    //m_priors[j] += 1.0;
    m_priors[j] = clusterSizes[j];
}
Utils.normalize(m_priors);

The prior probability of each cluster is computed from its size and then normalized to sum to one.
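Utils.normalize divides each entry by the array's sum, so m_priors[j] ends up as clusterSizes[j] divided by the total number of instances. A minimal equivalent of what it does here:

// Minimal equivalent of what Utils.normalize does here:
// scale the priors so they sum to 1 (assumes a non-zero sum).
static void normalize(double[] a) {
    double sum = 0;
    for (double v : a) {
        sum += v;
    }
    for (int i = 0; i < a.length; i++) {
        a[i] /= sum;
    }
}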
