
Koala++'s blog


 
 
 


 
 

Weka Development [46]: HotSpot Source Code Analysis

2010-07-29 14:20:22 | Category: Machine Learning


First, run HotSpot on the iris.arff dataset; the result is:

Hot Spot
========
Total population: 150 instances
Target attribute: class
Target value: Iris-setosa [value count in total population: 50 instances (33.33%)]
Minimum value count for segments: 50 instances (33% of total population)
Maximum branching factor: 2
Minimum improvement in target: 1%

class=Iris-setosa (33.33% [50/150])
  petallength <= 1.9 (100% [50/50])
  petalwidth <= 0.6 (100% [50/50])

The output reads roughly as follows. Total population is the number of instances, here 150. Target attribute is the attribute of interest, here the class attribute. Target value is the value of interest within that attribute, here Iris-setosa; 50 instances carry this class value, so its support count is 50 and its support is 33%. The maximum branching factor is 2, and the minimum improvement in the target is 1%. Next come the target attribute name (class) and target value (Iris-setosa), repeating the Target value information, and finally the rules are printed as a tree.

Let's start from HotSpot's buildAssociations method:

In the code below, m_target is the index of the designated class attribute, m_targetIndex indicates which value of that attribute we are interested in, m_support is the minimum support expressed as a percentage, and m_supportCount is the minimum support count:

if (inst.attribute(m_target).isNumeric()) {
    if (m_supportCount > m_numInstances) {
        m_errorMessage = "Error: support set to more instances than "
                + "there are in the data!";
        return;
    }
    m_globalTarget = inst.meanOrMode(m_target);
} else {
    double[] probs = new double[inst.attributeStats(m_target)
            .nominalCounts.length];
    for (int i = 0; i < probs.length; i++) {
        probs[i] = (double) inst.attributeStats(m_target).nominalCounts[i];
    }
    m_globalSupport = (int) probs[m_targetIndex];
    // check that global support is greater than min support
    if (m_globalSupport < m_supportCount) {
        m_errorMessage = "Error: minimum support " + m_supportCount
                + " is too high. Target value "
                + m_header.attribute(m_target).value(m_targetIndex)
                + " has support " + m_globalSupport + ".";
    }

    Utils.normalize(probs);
    m_globalTarget = probs[m_targetIndex];
}

If the class attribute is numeric, m_globalTarget is its mean (the "Mode" part of meanOrMode, the mode, is returned only when the attribute is nominal). If it is nominal, the code checks whether the count of the value of interest falls below m_supportCount, and then assigns the normalized probability to m_globalTarget.
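The nominal branch above can be replayed in isolation. The following is a minimal, self-contained sketch with hypothetical toy numbers (it mirrors only the logic, not Weka's API): count each class value, check the target's count against the minimum support count, then normalize to obtain the global target probability.

```java
// Hypothetical sketch of buildAssociations' nominal branch, iris-style numbers.
public class GlobalTargetSketch {
    public static void main(String[] args) {
        int[] nominalCounts = {50, 50, 50}; // class value counts (toy input)
        int targetIndex = 0;                // e.g. Iris-setosa
        int supportCount = 50;              // minimum support count (33% of 150)

        int globalSupport = nominalCounts[targetIndex];
        if (globalSupport < supportCount) {
            System.out.println("Error: minimum support is too high");
            return;
        }
        double sum = 0;
        for (int c : nominalCounts) sum += c;   // plays the role of Utils.normalize
        double globalTarget = nominalCounts[targetIndex] / sum;
        System.out.println("globalSupport=" + globalSupport
                + " globalTarget=" + Math.round(globalTarget * 10000) / 10000.0);
    }
}
```

With the iris counts this prints globalSupport=50 and a global target probability of about 0.3333, matching the 33.33% in the HotSpot output above.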

m_ruleLookup = new HashMap<HotSpotHashKey, String>();
double[] splitVals = new double[m_header.numAttributes()];
byte[] tests = new byte[m_header.numAttributes()];

m_head = new HotNode(inst, m_globalTarget, splitVals, tests);

Here a HotNode is constructed to carry out the next step of the computation.

m_insts = insts;
m_targetValue = targetValue;
PriorityQueue<HotTestDetails> splitQueue = new PriorityQueue<HotTestDetails>();

// Consider each attribute
for (int i = 0; i < m_insts.numAttributes(); i++) {
    if (i != m_target) {
        if (m_insts.attribute(i).isNominal()) {
            evaluateNominal(i, splitQueue);
        } else {
            evaluateNumeric(i, splitQueue);
        }
    }
}

For each attribute other than the target, evaluateNominal or evaluateNumeric is called. Below is the code of evaluateNominal:

int[] counts = m_insts.attributeStats(attIndex).nominalCounts;
boolean ok = false;
// only consider attribute values that result in subsets that
// meet/exceed min support
for (int i = 0; i < m_insts.attribute(attIndex).numValues(); i++) {
    if (counts[i] >= m_supportCount) {
        ok = true;
        break;
    }
}

counts holds, for attribute attIndex, the number of occurrences of each attribute value. These counts are compared against m_supportCount; as long as at least one value meets the minimum support count, we move on to the next step:

double[] subsetMerit = new double[m_insts.attribute(attIndex)
        .numValues()];

for (int i = 0; i < m_insts.numInstances(); i++) {
    Instance temp = m_insts.instance(i);
    if (!temp.isMissing(attIndex)) {
        int attVal = (int) temp.value(attIndex);
        if (m_insts.attribute(m_target).isNumeric()) {
            subsetMerit[attVal] += temp.value(m_target);
        } else {
            subsetMerit[attVal] +=
                    ((int) temp.value(m_target) == m_targetIndex) ? 1.0 : 0;
        }
    }
}

For each non-missing value: if the class attribute is numeric, subsetMerit[attVal] accumulates the target attribute's value; if it is nominal, it counts the occurrences of the target value.
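The accumulation loop can be sketched with toy arrays (hypothetical data, nominal target, not Weka's API): for every instance, bump subsetMerit at the instance's attribute value whenever its class equals the target value.

```java
import java.util.Arrays;

// Toy sketch of the subsetMerit accumulation for a nominal target attribute.
public class SubsetMeritSketch {
    public static void main(String[] args) {
        int[] attVals    = {0, 0, 1, 1, 1, 2}; // value of attIndex per instance
        int[] targetVals = {0, 1, 0, 0, 1, 0}; // class value per instance
        int targetIndex = 0;                   // class value of interest
        int numAttValues = 3;

        double[] subsetMerit = new double[numAttValues];
        for (int i = 0; i < attVals.length; i++) {
            // nominal target: count occurrences of the target value
            subsetMerit[attVals[i]] += (targetVals[i] == targetIndex) ? 1.0 : 0;
        }
        System.out.println(Arrays.toString(subsetMerit)); // [1.0, 2.0, 1.0]
    }
}
```

So attribute value 1 covers three instances, two of which carry the target class, which is exactly the numerator merit will use next.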

// add to queue if it meets min support and exceeds the merit
// for the full set
for (int i = 0; i < m_insts.attribute(attIndex).numValues(); i++) {
    // does the subset based on this value have enough
    // instances, and, furthermore, does the target value
    // (nominal only) occur enough times to exceed min support
    if (counts[i] >= m_supportCount
            && ((m_insts.attribute(m_target).isNominal())
                    ? (subsetMerit[i] >= m_supportCount) : true)) {
        double merit = subsetMerit[i] / counts[i];
        double delta = (m_minimize) ? m_targetValue - merit
                : merit - m_targetValue;

        if (delta / m_targetValue >= m_minImprovement) {
            double support = (m_insts.attribute(m_target)
                    .isNominal()) ? subsetMerit[i] : counts[i];

            HotTestDetails newD = new HotTestDetails(attIndex,
                    (double) i, false, (int) support, counts[i], merit);
            pq.add(newD);
        }
    }
}

In the if condition: when m_target is nominal, the test checks whether the target value occurs at least m_supportCount times under this attribute value; when it is numeric, the condition is simply true. merit is the proportion of the target class value among the instances with this attribute value, and the sign of delta depends on the m_minimize parameter. If the relative improvement delta / m_targetValue reaches m_minImprovement, the attribute value is worth keeping, i.e., it improves the target by the required ratio. A HotTestDetails object newD is then created and added to the queue (pq here is splitQueue).
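The merit/improvement test can be replayed with the iris numbers (a hypothetical stand-alone sketch with m_minimize false): for petallength <= 1.9, all 50 of the 50 covered instances are Iris-setosa.

```java
// Hypothetical replay of the improvement test for the rule petallength <= 1.9.
public class ImprovementSketch {
    public static void main(String[] args) {
        double targetValue = 50.0 / 150.0; // merit of the full population (33.33%)
        double subsetMeritSum = 50;        // target-value count in the subset
        int count = 50;                    // subset size
        boolean minimize = false;
        double minImprovement = 0.01;      // 1%

        double merit = subsetMeritSum / count; // proportion of target in subset
        double delta = minimize ? targetValue - merit : merit - targetValue;
        boolean accept = delta / targetValue >= minImprovement;
        System.out.println("merit=" + merit + " accept=" + accept);
    }
}
```

The relative improvement is (1.0 - 0.3333) / 0.3333, about 200%, far above the 1% threshold, so the test is accepted and would be queued.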

Back in the HotNode constructor: if splitQueue is not empty, then:

int queueSize = splitQueue.size();

// count how many of the potential children are unique
ArrayList<HotTestDetails> newCandidates = new ArrayList<HotTestDetails>();
ArrayList<HotSpotHashKey> keyList = new ArrayList<HotSpotHashKey>();

for (int i = 0; i < queueSize; i++) {
    if (newCandidates.size() < m_maxBranchingFactor) {
        HotTestDetails temp = splitQueue.poll();
        double[] newSplitVals = splitVals.clone();
        byte[] newTests = tests.clone();
        newSplitVals[temp.m_splitAttIndex] = temp.m_splitValue + 1;
        newTests[temp.m_splitAttIndex] = (m_header
                .attribute(temp.m_splitAttIndex).isNominal()) ? (byte) 2
                : (temp.m_lessThan) ? (byte) 1 : (byte) 3;
        HotSpotHashKey key = new HotSpotHashKey(newSplitVals, newTests);
        m_lookups++;
        if (!m_ruleLookup.containsKey(key)) {
            // insert it into the hash table
            m_ruleLookup.put(key, "");
            newCandidates.add(temp);
            keyList.add(key);
            m_insertions++;
        }
    } else {
        break;
    }
}

Here the tests just found are recorded in m_ruleLookup. newSplitVals records where the split happens: for a nominal attribute it is an attribute value, and for a numeric attribute it is the point that splits the attribute into two parts. newTests records the type of test relative to the split value: 2 means equals, 1 means less-than-or-equal, and 3 means greater-than.
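The deduplication idea can be sketched without Weka: the key is built from the per-attribute split values and test types along the path, so two paths that apply the same tests (in any order) map to the same key. The real HotSpotHashKey overrides equals/hashCode on those arrays; this hypothetical sketch just serializes them into a string.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

public class RuleKeySketch {
    // Stand-in for HotSpotHashKey: the arrays are indexed by attribute,
    // so the key does not depend on the order in which tests were added.
    static String key(double[] splitVals, byte[] tests) {
        return Arrays.toString(splitVals) + "|" + Arrays.toString(tests);
    }

    public static void main(String[] args) {
        Map<String, String> ruleLookup = new HashMap<>();
        double[] splitVals = {0, 1.9 + 1, 0}; // split attribute 1 at 1.9 (+1 offset)
        byte[] tests = {0, 1, 0};             // 1 = "<=", 2 = "=", 3 = ">"
        int insertions = 0;
        for (int i = 0; i < 2; i++) {         // the same path proposed twice
            String k = key(splitVals, tests);
            if (!ruleLookup.containsKey(k)) {
                ruleLookup.put(k, "");
                insertions++;                  // duplicate is suppressed
            }
        }
        System.out.println("insertions=" + insertions);
    }
}
```

The second proposal of the identical path is rejected, which is exactly why m_insertions can be smaller than m_lookups in the real code.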

m_children = new HotNode[(newCandidates.size() < m_maxBranchingFactor)
        ? newCandidates.size() : m_maxBranchingFactor];
// save the details of the tests at this node
m_testDetails = new HotTestDetails[m_children.length];
for (int i = 0; i < m_children.length; i++) {
    m_testDetails[i] = newCandidates.get(i);
}

// save memory
splitQueue = null;
newCandidates = null;
m_insts = new Instances(m_insts, 0);

// process the children
for (int i = 0; i < m_children.length; i++) {
    Instances subset = subset(insts, m_testDetails[i]);
    HotSpotHashKey tempKey = keyList.get(i);
    m_children[i] = new HotNode(subset, m_testDetails[i].m_merit,
            tempKey.m_splitValues, tempKey.m_testTypes);
}

The code above produced the first level of the tree; this last loop recurses into each child node on its data subset to build the next level.

 

 

 
