
Koala++'s blog


Weka Development [21]: IBk (KNN) Source Code Analysis

2009-05-17 16:49:37 | Category: Machine Learning


If you haven't read the previous post on IB1, please read it first, because I won't repeat here the material it already covers.

Let's go straight to buildClassifier. Only the code that did not already appear in IB1 is listed here:

try {
    m_NumClasses = instances.numClasses();
    m_ClassType = instances.classAttribute().type();
} catch (Exception ex) {
    throw new Error("This should never be reached");
}

// Throw away initial instances until within the specified window size
if ((m_WindowSize > 0) && (instances.numInstances() > m_WindowSize)) {
    m_Train = new Instances(m_Train, m_Train.numInstances()
            - m_WindowSize, m_WindowSize);
}

// Compute the number of attributes that contribute
// to each prediction
m_NumAttributesUsed = 0.0;
for (int i = 0; i < m_Train.numAttributes(); i++) {
    if ((i != m_Train.classIndex())
            && (m_Train.attribute(i).isNominal()
                    || m_Train.attribute(i).isNumeric())) {
        m_NumAttributesUsed += 1.0;
    }
}

// Invalidate any currently cross-validation selected k
m_kNNValid = false;

IB1 doesn't care about m_NumClasses because it looks for exactly one neighbour, so there is only ever one value. m_WindowSize caps how many instances are kept for classification; note that the code does not sample randomly but keeps the most recent m_WindowSize instances, throwing the earliest ones away. The loop below that counts how many attributes contribute to each prediction.
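To make the window slicing concrete, here is a minimal self-contained sketch; the train.arff path and the window size of 20 are illustrative assumptions of mine, not values from the Weka source:

import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class WindowDemo {
    public static void main(String[] args) throws Exception {
        // Hypothetical dataset path, for illustration only.
        Instances train = DataSource.read("train.arff");
        train.setClassIndex(train.numAttributes() - 1);

        int windowSize = 20; // plays the role of m_WindowSize
        if (windowSize > 0 && train.numInstances() > windowSize) {
            // Copies windowSize instances starting at offset
            // numInstances - windowSize, so the LAST windowSize
            // instances survive, not the first ones.
            train = new Instances(train,
                    train.numInstances() - windowSize, windowSize);
        }
        System.out.println("Instances kept: " + train.numInstances());
    }
}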

KNN is also a classifier capable of incremental learning. Let's look at its updateClassifier code:

public void updateClassifier(Instance instance) throws Exception {
    if (m_Train.equalHeaders(instance.dataset()) == false) {
        throw new Exception("Incompatible instance types");
    }
    if (instance.classIsMissing()) {
        return;
    }
    if (!m_DontNormalize) {
        updateMinMax(instance);
    }
    m_Train.add(instance);
    m_kNNValid = false;
    if ((m_WindowSize > 0) && (m_Train.numInstances() > m_WindowSize)) {
        while (m_Train.numInstances() > m_WindowSize) {
            m_Train.delete(0);
        }
    }
}

Again this is simple: updateMinMax refreshes the per-attribute minimum and maximum used for normalisation, and if the training set has grown beyond the window size, the loop keeps deleting the first (oldest) instance until it fits.
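Since IBk implements Weka's UpdateableClassifier interface, it can be trained on a stream of instances one at a time. A minimal usage sketch (the dataset path, k = 3, and the window of 100 are assumed values of mine):

import weka.classifiers.lazy.IBk;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class IncrementalIBk {
    public static void main(String[] args) throws Exception {
        Instances data = DataSource.read("train.arff"); // assumed path
        data.setClassIndex(data.numAttributes() - 1);

        IBk knn = new IBk(3);   // k = 3
        knn.setWindowSize(100); // keep at most 100 training instances

        // Seed the classifier with an empty copy of the header,
        // then feed the instances one by one.
        knn.buildClassifier(new Instances(data, 0));
        for (int i = 0; i < data.numInstances(); i++) {
            Instance inst = data.instance(i);
            knn.updateClassifier(inst);
        }
        System.out.println(knn);
    }
}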

Note that IBk does not implement classifyInstance; it only implements distributionForInstance:

public double[] distributionForInstance(Instance instance) throws Exception {
    if (m_Train.numInstances() == 0) {
        throw new Exception("No training instances!");
    }
    if ((m_WindowSize > 0) && (m_Train.numInstances() > m_WindowSize)) {
        m_kNNValid = false;
        boolean deletedInstance = false;
        while (m_Train.numInstances() > m_WindowSize) {
            m_Train.delete(0);
        }
        // rebuild datastructure KDTree currently can't delete
        if (deletedInstance == true)
            m_NNSearch.setInstances(m_Train);
    }

    // Select k by cross validation
    if (!m_kNNValid && (m_CrossValidate) && (m_kNNUpper >= 1)) {
        crossValidate();
    }

    m_NNSearch.addInstanceInfo(instance);

    Instances neighbours = m_NNSearch.kNearestNeighbours(instance, m_kNN);
    double[] distances = m_NNSearch.getDistances();
    double[] distribution = makeDistribution(neighbours, distances);

    return distribution;
}

I'll skip the first two checks; crossValidate() is covered in a moment, and finding the k nearest neighbours was already explained in part [18] of this series. Now let's look at the makeDistribution function.

protected double[] makeDistribution(Instances neighbours,
        double[] distances) throws Exception {

    double total = 0, weight;
    double[] distribution = new double[m_NumClasses];

    // Set up a correction to the estimator
    if (m_ClassType == Attribute.NOMINAL) {
        for (int i = 0; i < m_NumClasses; i++) {
            distribution[i] = 1.0 / Math.max(1, m_Train.numInstances());
        }
        total = (double) m_NumClasses / Math.max(1,
                m_Train.numInstances());
    }

    for (int i = 0; i < neighbours.numInstances(); i++) {
        // Collect class counts
        Instance current = neighbours.instance(i);
        distances[i] = distances[i] * distances[i];
        distances[i] = Math.sqrt(distances[i] / m_NumAttributesUsed);
        switch (m_DistanceWeighting) {
        case WEIGHT_INVERSE:
            weight = 1.0 / (distances[i] + 0.001); // to avoid div by zero
            break;
        case WEIGHT_SIMILARITY:
            weight = 1.0 - distances[i];
            break;
        default: // WEIGHT_NONE:
            weight = 1.0;
            break;
        }
        weight *= current.weight();
        try {
            switch (m_ClassType) {
            case Attribute.NOMINAL:
                distribution[(int) current.classValue()] += weight;
                break;
            case Attribute.NUMERIC:
                distribution[0] += current.classValue() * weight;
                break;
            }
        } catch (Exception ex) {
            throw new Error("Data has no class attribute!");
        }
        total += weight;
    }

    // Normalise distribution
    if (total > 0) {
        Utils.normalize(distribution, total);
    }
    return distribution;
}

The first comment, "Set up a correction to the estimator", strikes me as unnecessary: this is not Bayes, where a division by zero has to be corrected; there is nothing to correct here (it merely seeds every class bin with a 1/n prior). You can also see that three distance-weighting schemes are implemented: the inverse of the distance, one minus the distance, and a fixed weight of 1. Then, if the class is nominal, the weight is added to the bin for the neighbour's class value; if it is numeric, the neighbour's class value multiplied by the weight is accumulated.
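To see what the three schemes do to a weight, here is a small standalone sketch; the WEIGHT_* constants mirror the ones in IBk (their numeric values here are my reading of the source), and the example distances are made up:

public class WeightingDemo {
    // Assumed to match IBk's constants.
    static final int WEIGHT_NONE = 1, WEIGHT_INVERSE = 2, WEIGHT_SIMILARITY = 4;

    // Mirrors the weighting switch in makeDistribution.
    static double weight(double distance, int scheme) {
        switch (scheme) {
        case WEIGHT_INVERSE:
            return 1.0 / (distance + 0.001); // avoid division by zero
        case WEIGHT_SIMILARITY:
            return 1.0 - distance; // distances are normalised into [0,1]
        default: // WEIGHT_NONE
            return 1.0;
        }
    }

    public static void main(String[] args) {
        double[] distances = {0.1, 0.4}; // made-up neighbour distances
        for (double d : distances) {
            System.out.printf("d=%.1f inverse=%.3f similarity=%.3f none=%.1f%n",
                    d, weight(d, WEIGHT_INVERSE),
                    weight(d, WEIGHT_SIMILARITY), weight(d, WEIGHT_NONE));
        }
    }
}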

crossValidate, put simply, brute-forces the question of how many neighbours to use: it loops over the instances in m_Train, finds each one's neighbours, and tallies which neighbourhood size performs best.

protected void crossValidate() {
    double[] performanceStats = new double[m_kNNUpper];
    double[] performanceStatsSq = new double[m_kNNUpper];

    for (int i = 0; i < m_kNNUpper; i++) {
        performanceStats[i] = 0;
        performanceStatsSq[i] = 0;
    }

    m_kNN = m_kNNUpper;
    Instance instance;
    Instances neighbours;
    double[] origDistances, convertedDistances;
    for (int i = 0; i < m_Train.numInstances(); i++) {
        instance = m_Train.instance(i);
        neighbours = m_NNSearch.kNearestNeighbours(instance, m_kNN);
        origDistances = m_NNSearch.getDistances();

        for (int j = m_kNNUpper - 1; j >= 0; j--) {
            // Update the performance stats
            convertedDistances = new double[origDistances.length];
            System.arraycopy(origDistances, 0, convertedDistances, 0,
                    origDistances.length);
            double[] distribution = makeDistribution(neighbours,
                    convertedDistances);
            double thisPrediction = Utils.maxIndex(distribution);
            if (m_Train.classAttribute().isNumeric()) {
                thisPrediction = distribution[0];
                double err = thisPrediction - instance.classValue();
                performanceStatsSq[j] += err * err; // Squared error
                performanceStats[j] += Math.abs(err); // Absolute error
            } else {
                if (thisPrediction != instance.classValue()) {
                    performanceStats[j]++; // Classification error
                }
            }
            if (j >= 1) {
                neighbours = pruneToK(neighbours, convertedDistances, j);
            }
        }
    }

    // Check through the performance stats and select the best
    // k value (or the lowest k if more than one best)
    double[] searchStats = performanceStats;
    if (m_Train.classAttribute().isNumeric() && m_MeanSquared) {
        searchStats = performanceStatsSq;
    }
    double bestPerformance = Double.NaN;
    int bestK = 1;
    for (int i = 0; i < m_kNNUpper; i++) {
        if (Double.isNaN(bestPerformance)
                || (bestPerformance > searchStats[i])) {
            bestPerformance = searchStats[i];
            bestK = i + 1;
        }
    }
    m_kNN = bestK;

    m_kNNValid = true;
}

m_kNNUpper is another parameter, an upper bound on how many neighbours to try. The outer loop enumerates each training instance (instance) and finds its neighbours (neighbours) and their distances (origDistances). The inner loop then accumulates, for every neighbourhood size j from 1 to m_kNNUpper, the squared error (performanceStatsSq) and the absolute error (performanceStats); for nominal classes, performanceStats counts misclassifications instead. pruneToK trims the neighbour list down to j instances, keeping extras beyond j only when their distance ties the j-th. The rest is easy to follow: for a numeric class, m_MeanSquared selects whether squared or absolute error drives the choice, and finally the code scans k from 1 to m_kNNUpper and takes the k with the smallest accumulated error as the best number of neighbours.
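Stripped of Weka's data structures, the selection boils down to leave-one-out evaluation over every candidate k. Here is a simplified sketch of that idea on a made-up toy dataset; it is not the actual Weka code, which shares one neighbour search across all k (via pruneToK) instead of re-sorting for each instance:

import java.util.Arrays;

public class SelectK {
    // Toy training set: two numeric features, nominal class 0 or 1.
    static double[][] X = {{0, 0}, {0, 1}, {1, 0}, {5, 5}, {5, 6}, {6, 5}};
    static int[] y = {0, 0, 0, 1, 1, 1};

    static double dist(int a, int b) {
        double s = 0;
        for (int d = 0; d < X[a].length; d++) {
            double diff = X[a][d] - X[b][d];
            s += diff * diff;
        }
        return Math.sqrt(s);
    }

    // Majority vote among the k nearest neighbours of X[test],
    // leaving X[test] itself out.
    static int predict(int test, int k) {
        Integer[] idx = new Integer[X.length];
        for (int i = 0; i < X.length; i++) idx[i] = i;
        Arrays.sort(idx, (a, b) -> Double.compare(dist(test, a), dist(test, b)));
        int[] votes = new int[2];
        int taken = 0;
        for (int i : idx) {
            if (i == test) continue; // skip the held-out instance
            votes[y[i]]++;
            if (++taken == k) break;
        }
        return votes[1] > votes[0] ? 1 : 0;
    }

    public static void main(String[] args) {
        int kUpper = 5, bestK = 1, bestErrors = Integer.MAX_VALUE;
        for (int k = 1; k <= kUpper; k++) { // brute force over k
            int errors = 0;
            for (int i = 0; i < X.length; i++)
                if (predict(i, k) != y[i]) errors++;
            if (errors < bestErrors) { // strict '<' keeps the lowest k on ties
                bestErrors = errors;
                bestK = k;
            }
        }
        System.out.println("best k = " + bestK + ", errors = " + bestErrors);
    }
}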
