注册 登录  
 加关注
   显示下一条  |  关闭
温馨提示!由于新浪微博认证机制调整,您的新浪微博帐号绑定已过期,请重新绑定!立即重新绑定新浪微博》  |  关闭

Koala++'s blog

计算广告学 RTB

 
 
 

日志

 
 

Weka开发[53]——KDTree源代码分析  

2011-08-06 11:56:27|  分类: 机器学习 |  标签: |举报 |字号 订阅

  下载LOFTER 我的照片书  |

         最近在做范围索引的工作,读了一些范围查找的资料,KDTree当然是一个经典结构了,它在Weka中实现了,Weka还实现了CoverTreeBallTree。下面就分析一下Weka是如何实现KDTree的。

         推荐资料:An intoductory tutorial on kd-trees,也是源码列出的参考论文中的一篇。

protected void buildKDTree(Instances instances) throws Exception {

    // ... ...

    double[][] universe = m_EuclideanDistance.getRanges();

 

    // building tree

    m_NumNodes = m_NumLeaves = 1;

    m_MaxDepth = 0;

    m_Root = new KDTreeNode(m_NumNodes, 0,

m_Instances.numInstances() - 1, universe);

 

    splitNodes(m_Root, universe, m_MaxDepth + 1);

}

         m_EuclideanDistance是欧几里德距离计算类对象,getRange得到一个二维数组,第一维是属性维,第二维有三个元素,最小值,最大值,和最大值与最小值的差值,即宽度。m_Root当然就是根结点了,然后调用splitNodes建树。

protected void splitNodes(KDTreeNode node, double[][] universe, int depth) throws Exception {

    double[][] nodeRanges = m_EuclideanDistance.initializeRanges(

           m_InstList, node.m_Start, node.m_End);

 

    // splitting a node so it is no longer a leaf

    m_NumLeaves--;

 

    if (depth > m_MaxDepth)

       m_MaxDepth = depth;

 

    m_Splitter.splitNode(node, m_NumNodes, nodeRanges, universe);

    m_NumNodes += 2;

    m_NumLeaves += 2;

 

    splitNodes(node.m_Left, universe, depth + 1);

    splitNodes(node.m_Right, universe, depth + 1);

}

         nodeRanges是得到当前结点的范围信息,m_NumLeaves—是因为当前分裂的结点当然是个叶子结点,它分裂后,它就是非叶子结点,所以叶子个数减一。下面是得到树的深度,再下来是用m_Splitternode进行分裂,它的结果就是node.m_Leftnode.m_Right,然后对分裂出来的叶子递归调用。

public void splitNode(KDTreeNode node, int numNodesCreated,

       double[][] nodeRanges, double[][] universe) throws Exception {

// ... ...

    double splitVal = node.m_NodesRectBounds[MIN][splitDim]

           + (node.m_NodesRectBounds[MAX][splitDim] –

node.m_NodesRectBounds[MIN][splitDim]) * 0.5;

    if (splitVal < node.m_NodeRanges[splitDim][MIN])

       splitVal = node.m_NodeRanges[splitDim][MIN];

    else if (splitVal >= node.m_NodeRanges[splitDim][MAX])

       splitVal = node.m_NodeRanges[splitDim][MAX]

              - node.m_NodeRanges[splitDim][WIDTH] * 0.001;

 

    int rightStart = rearrangePoints(m_InstList, node.m_Start,

node.m_End, splitDim, splitVal);

 

    node.m_SplitDim = splitDim;

    node.m_SplitValue = splitVal;

 

    double[][] widths = new double[2][node.m_NodesRectBounds[0].length];

 

    System.arraycopy(node.m_NodesRectBounds[MIN], 0, widths[MIN], 0,

           node.m_NodesRectBounds[MIN].length);

    System.arraycopy(node.m_NodesRectBounds[MAX], 0, widths[MAX], 0,

           node.m_NodesRectBounds[MAX].length);

    widths[MAX][splitDim] = splitVal;

 

    node.m_Left = new KDTreeNode(numNodesCreated + 1, node.m_Start,

           rightStart - 1, m_EuclideanDistance.initializeRanges(

                  m_InstList, node.m_Start, rightStart - 1), widths);

 

    widths = new double[2][node.m_NodesRectBounds[0].length];

    System.arraycopy(node.m_NodesRectBounds[MIN], 0, widths[MIN], 0,

           node.m_NodesRectBounds[MIN].length);

    System.arraycopy(node.m_NodesRectBounds[MAX], 0, widths[MAX], 0,

           node.m_NodesRectBounds[MAX].length);

    widths[MIN][splitDim] = splitVal;

 

    node.m_Right = new KDTreeNode(numNodesCreated + 2, rightStart,

           node.m_End, m_EuclideanDistance.initializeRanges(

m_InstList, rightStart, node.m_End), widths);

}

         splitValuemin + (max-min)/2,也就是中位数,下面判断它是不是与MinMax值相等,这是因为相等,就可能分的非常不合理,比如有7个数,0, 1, 4, 4, 4, 4, 4,那么4是中位数,数字相等当然是要被分到同一个结点中去,如果要把中位数分到左子树,那问题就出来了,所有的值都被划到左子树了。就需要对分裂值进行调整,这时向左调整。如果中位数等于最大值,就会把中位数向左移。

         下面将最大值,最小值信息拷贝到最的左右子结点,将splitDim维的值在左子树中的最大值设置为splitVal,而在右子树中将小值设置为splitVal

public Instances kNearestNeighbours(Instance target, int k)

       throws Exception {

    MyHeap heap = new MyHeap(k);

    findNearestNeighbours(target, m_Root, k, heap, 0.0);

}

         这里用了堆结构,在N个数中找个M个最小数,当然是也堆的经典应用了,也就不多解释了。

KDTreeNode nearer, further;

boolean targetInLeft = m_EuclideanDistance.valueIsSmallerEqual(

       target, node.m_SplitDim, node.m_SplitValue);

 

if (targetInLeft) {

    nearer = node.m_Left;

    further = node.m_Right;

} else {

    nearer = node.m_Right;

    further = node.m_Left;

}

findNearestNeighbours(target, nearer, k, heap, distanceToParents);

         node.m_SplitDim维看target样本会被分到左子树还是右子树,如果它被分到左子树,那么nearer就是node._m_Left,其它的就不解释了,如果你手里有我推荐的论文,你就会知道这是Algorithm: Nearest Neighbour in a kd tree中的第27步。findNearestNeighours递归调用,这是第8步。

if (heap.size() < k) { // if haven't found the first k

    double distanceToSplitPlane = distanceToParents

           + m_EuclideanDistance.sqDifference(node.m_SplitDim,

                  target.value(node.m_SplitDim),

                  node.m_SplitValue);

    findNearestNeighbours(target, further, k, heap,

           distanceToSplitPlane);

    return;

} else {

    double distanceToSplitPlane = distanceToParents

           + m_EuclideanDistance.sqDifference(node.m_SplitDim,

                  target.value(node.m_SplitDim),

                  node.m_SplitValue);

    if (heap.peek().distance >= distanceToSplitPlane) {

       findNearestNeighbours(target, further, k, heap,

              distanceToSplitPlane);

    }

}

计算距离与距离的代码,我不明白作者怎么不提到外边来。如果堆的大小还没有到指定的邻居数,就会也递归右子树,如果达到了k,最检查现在堆中的最远距离是否比当前的更大,如果是,递归调用左子树。其实不难理解,父结点是某一维的中位数,如果是further,那只能距离更远。

for (int idx = node.m_Start; idx <= node.m_End; idx++) {

    if (target == m_Instances.instance(m_InstList[idx]))

       continue;

    if (heap.size() < k) {

       distance = m_EuclideanDistance.distance(target,

              m_Instances.instance(m_InstList[idx]),

              Double.POSITIVE_INFINITY, m_Stats);

       heap.put(m_InstList[idx], distance);

    } else {

       MyHeapElement temp = heap.peek();

       distance = m_EuclideanDistance.distance(target,

              m_Instances.instance(m_InstList[idx]),

              temp.distance, m_Stats);

       if (distance < temp.distance) {

           heap.putBySubstitute(m_InstList[idx], distance);

       } else if (distance == temp.distance) {

           heap.putKthNearest(m_InstList[idx], distance);

       }

    }// end else heap.size==k

}// end for

         在叶子结点上,和在非叶子结点上的处理逻辑相似,没有达到k,就计算距离加入堆,如果没有达到k,计算距离,并与最远距离比较,如果更近,才加入heap,这里有个特殊处理就是对相等情况的处理,这就不是关键了。

 

 

 

 

  评论这张
 
阅读(3694)| 评论(0)
推荐 转载

历史上的今天

评论

<#--最新日志,群博日志--> <#--推荐日志--> <#--引用记录--> <#--博主推荐--> <#--随机阅读--> <#--首页推荐--> <#--历史上的今天--> <#--被推荐日志--> <#--上一篇,下一篇--> <#-- 热度 --> <#-- 网易新闻广告 --> <#--右边模块结构--> <#--评论模块结构--> <#--引用模块结构--> <#--博主发起的投票-->
 
 
 
 
 
 
 
 
 
 
 
 
 
 

页脚

网易公司版权所有 ©1997-2017