Koala++'s blog


Weka Development [35]: StringToWordVector Source Code Analysis (2)

2010-01-16 17:03:43 | Category: Machine Learning


The word counts of each class are put into an array and the array is sorted. If the number of words in a class is smaller than m_WordsToKeep, all of them are kept, so a word only needs to appear once (given the default minimum term frequency). Otherwise prune[z] is set to the maximum of m_minTermFreq and the m_WordsToKeep-th element of the sorted array. totalsize is the total number of words over all classes.
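As a minimal sketch of this pruning step (an assumed reconstruction, not the verbatim Weka source; getCountsForClass is a hypothetical helper standing in for the count array built earlier in the method):

// Sketch: compute the per-class pruning threshold prune[z]
int[] array = getCountsForClass(z);   // hypothetical helper: word counts of class z
java.util.Arrays.sort(array);         // ascending order
if (array.length < m_WordsToKeep) {
    // fewer distinct words than we want to keep:
    // prune only by the minimum term frequency (1 by default)
    prune[z] = m_minTermFreq;
} else {
    // keep roughly the m_WordsToKeep most frequent words,
    // but never go below m_minTermFreq
    prune[z] = Math.max(m_minTermFreq, array[array.length - m_WordsToKeep]);
}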

// Convert the dictionary into an attribute index

// and create one attribute per word

FastVector attributes = new FastVector(totalsize

       + getInputFormat().numAttributes());

 

// Add the non-converted attributes

int classIndex = -1;

for (int i = 0; i < getInputFormat().numAttributes(); i++) {

    if (!m_SelectedRange.isInRange(i)) {

       if (getInputFormat().classIndex() == i) {

           classIndex = attributes.size();

       }

       attributes.addElement(getInputFormat().attribute(i).copy());

    }

}

The attributes vector is sized for the total number of words (counted with duplicates across classes) plus all of the original attributes. Next, the original attributes that are not converted to word vectors are copied into attributes, and the new position of the class attribute is remembered in classIndex.

// Add the word vector attributes (eliminating duplicates

// that occur in multiple classes)

TreeMap newDictionary = new TreeMap();

int index = attributes.size();

for (int z = 0; z < values; z++) {

    Iterator it = dictionaryArr[z].keySet().iterator();

    while (it.hasNext()) {

       String word = (String) it.next();

       Count count = (Count) dictionaryArr[z].get(word);

       if (count.count >= prune[z]) {

           if (newDictionary.get(word) == null) {

               newDictionary.put(word, new Integer(index++));

              attributes.addElement(new Attribute(m_Prefix + word));

           }

       }

    }

}

Here the per-class dictionaries are merged into newDictionary: if a word survives pruning (its count is at least prune[z]), it is added as a new attribute whose name is m_Prefix + word.

// Compute document frequencies

m_DocsCounts = new int[attributes.size()];

Iterator it = newDictionary.keySet().iterator();

while (it.hasNext()) {

    String word = (String) it.next();

    int idx = ((Integer) newDictionary.get(word)).intValue();

    int docsCount = 0;

    for (int j = 0; j < values; j++) {

       Count c = (Count) dictionaryArr[j].get(word);

       if (c != null)

           docsCount += c.docCount;

    }

    m_DocsCounts[idx] = docsCount;

}

The docCount values of a word across the different classes are summed, and the combined count is stored in m_DocsCounts.

// Trim vector and set instance variables

attributes.trimToSize();

m_Dictionary = newDictionary;

m_NumInstances = getInputFormat().numInstances();

 

// Set the filter's output format

Instances outputFormat = new Instances(getInputFormat().relationName(),

       attributes, 0);

outputFormat.setClassIndex(classIndex);

setOutputFormat(outputFormat);

trimToSize releases the capacity that was allocated earlier but never used. Then a new output format is created; relationName is the name of the relation, which normally appears as the @relation line at the top of the .arff file.

// Convert all instances w/o normalization

FastVector fv = new FastVector();

int firstCopy = 0;

for (int i = 0; i < m_NumInstances; i++) {

    firstCopy = convertInstancewoDocNorm(getInputFormat().instance(

           i), fv);

}

The code of convertInstancewoDocNorm, taken apart piece by piece:

// Convert the instance into a sorted set of indexes

TreeMap contained = new TreeMap();

 

// Copy all non-converted attributes from input to output

int firstCopy = 0;

for (int i = 0; i < getInputFormat().numAttributes(); i++) {

    if (!m_SelectedRange.isInRange(i)) {

       if (getInputFormat().attribute(i).type() != Attribute.STRING) {

           // Add simple nominal and numeric attributes directly

           if (instance.value(i) != 0.0) {

              contained.put(new Integer(firstCopy), new Double(

                     instance.value(i)));

           }

       } else {

           if (instance.isMissing(i)) {

              contained.put(new Integer(firstCopy), new Double(

                     Instance.missingValue()));

           } else {

              // If this is a string attribute, we have to first add

              // this value to the range of possible values, then add

              // its new internal index.

              if (outputFormatPeek().attribute(firstCopy).numValues() == 0) {

                  // Note that the first string value in a

                  // SparseInstance doesn't get printed.

                  outputFormatPeek()

                         .attribute(firstCopy)

                         .addStringValue(

                                "Hack to defeat SparseInstance bug");

              }

              int newIndex = outputFormatPeek().attribute(firstCopy)

                     .addStringValue(instance.stringValue(i));

              contained.put(new Integer(firstCopy), new Double(

                     newIndex));

           }

       }

       firstCopy++;

    }

}

Loop over all attributes of this instance. If an attribute is outside the selected range it is copied over: for non-STRING attributes there is nothing to convert to a word vector, so the value is added to contained whenever it is non-zero. For a STRING attribute, a missing value is stored as Instance.missingValue(); otherwise the string value is first added to the output attribute's range of possible values and its new internal index is stored.

for (int j = 0; j < instance.numAttributes(); j++) {

    if (m_SelectedRange.isInRange(j)

           && (instance.isMissing(j) == false)) {

 

       m_Tokenizer.tokenize(instance.stringValue(j));

 

       while (m_Tokenizer.hasMoreElements()) {

           String word = (String) m_Tokenizer.nextElement();

           if (this.m_lowerCaseTokens == true)

              word = word.toLowerCase();

           word = m_Stemmer.stem(word);

           Integer index = (Integer) m_Dictionary.get(word);

           if (index != null) {

              if (m_OutputCounts) { // Separate if here rather than

                                   // two lines down to avoid

                                   // hashtable lookup

                  Double count = (Double) contained.get(index);

                  if (count != null) {

                     contained.put(index, new Double(count

                            .doubleValue() + 1.0));

                  } else {

                     contained.put(index, new Double(1));

                  }

              } else {

                  contained.put(index, new Double(1));

              }

           }

       }

    }

}

This is almost the same as building the dictionary earlier, except there is no explicit stopword step: any word not found in the dictionary is simply skipped (which also drops the stopwords and the low-frequency words that were pruned), i.e. the case index == null. When m_OutputCounts is true the number of occurrences is recorded; when it is false only presence or absence matters.

// Doing TFTransform

if (m_TFTransform == true) {

    Iterator it = contained.keySet().iterator();

    for (int i = 0; it.hasNext(); i++) {

       Integer index = (Integer) it.next();

       if (index.intValue() >= firstCopy) {

           double val = ((Double) contained.get(index)).doubleValue();

           val = Math.log(val + 1);

           contained.put(index, new Double(val));

       }

    }

}

If TFTransform is enabled, each value val in the word vector is replaced by log(val + 1). The +1 is presumably there to handle a count of 1, since log 1 would be 0.

// Doing IDFTransform

if (m_IDFTransform == true) {

    Iterator it = contained.keySet().iterator();

    for (int i = 0; it.hasNext(); i++) {

       Integer index = (Integer) it.next();

       if (index.intValue() >= firstCopy) {

           double val = ((Double) contained.get(index)).doubleValue();

           val = val * Math.log(m_NumInstances

                  / (double) m_DocsCounts[index.intValue()]);

           contained.put(index, new Double(val));

       }

    }

}

This effectively computes tf-idf: val on the right is the tf, and the idf is log(N/n), where N is the total number of documents (m_NumInstances) and n is the number of documents containing the word (m_DocsCounts).
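As a hypothetical standalone helper (illustrative only, not part of the Weka source), the weighting applied to each kept word value can be written as:

// tf is the (possibly log-transformed) in-document frequency,
// numDocs corresponds to m_NumInstances, docFreq to m_DocsCounts[index].
static double tfIdf(double tf, int numDocs, int docFreq) {
    return tf * Math.log(numDocs / (double) docFreq);
}

For example, with tf = 3, numDocs = 100 and docFreq = 10 the weighted value is 3 * log(100/10) ≈ 6.91 (natural logarithm).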

// Convert the set to structures needed to create a sparse instance.

double[] values = new double[contained.size()];

int[] indices = new int[contained.size()];

Iterator it = contained.keySet().iterator();

for (int i = 0; it.hasNext(); i++) {

    Integer index = (Integer) it.next();

    Double value = (Double) contained.get(index);

    values[i] = value.doubleValue();

    indices[i] = index.intValue();

}

 

Instance inst = new SparseInstance(instance.weight(), values, indices,

       outputFormatPeek().numAttributes());

inst.setDataset(outputFormatPeek());

 

v.addElement(inst);

The values in contained are then copied into values and indices and stored in a SparseInstance (the sparse-instance class); the resulting instance is appended to the FastVector v.

// Need to compute average document length if necessary

if (m_filterType != FILTER_NONE) {

    m_AvgDocLength = 0;

    for (int i = 0; i < fv.size(); i++) {

       Instance inst = (Instance) fv.elementAt(i);

       double docLength = 0;

       for (int j = 0; j < inst.numValues(); j++) {

           if (inst.index(j) >= firstCopy) {

              docLength += inst.valueSparse(j)

                     * inst.valueSparse(j);

           }

       }

       m_AvgDocLength += Math.sqrt(docLength);

    }

    m_AvgDocLength /= m_NumInstances;

}

 

// Perform normalization if necessary.

if (m_filterType == FILTER_NORMALIZE_ALL) {

    for (int i = 0; i < fv.size(); i++) {

       normalizeInstance((Instance) fv.elementAt(i), firstCopy);

    }

}

 

// Push all instances into the output queue

for (int i = 0; i < fv.size(); i++) {

    push((Instance) fv.elementAt(i));

}

Back in batchFinished: if normalization is requested, the filter types are defined as follows:

/** normalization: No normalization. */

public static final int FILTER_NONE = 0;

/** normalization: Normalize all data. */

public static final int FILTER_NORMALIZE_ALL = 1;

/** normalization: Normalize test data only. */

public static final int FILTER_NORMALIZE_TEST_ONLY = 2;

m_AvgDocLength is not the number of words; it is the average over all documents of the Euclidean (L2) length of the document vector, computed from the tf-idf values (if tf-idf is used). If all instances are to be normalized:

private void normalizeInstance(Instance inst, int firstCopy)

       throws Exception {

 

    double docLength = 0;

 

    // Compute length of document vector

    for (int j = 0; j < inst.numValues(); j++) {

       if (inst.index(j) >= firstCopy) {

           docLength += inst.valueSparse(j) * inst.valueSparse(j);

       }

    }

    docLength = Math.sqrt(docLength);

 

    // Normalize document vector

    for (int j = 0; j < inst.numValues(); j++) {

       if (inst.index(j) >= firstCopy) {

           double val = inst.valueSparse(j) * m_AvgDocLength / docLength;

           inst.setValueSparse(j, val);

       }

    }

}

Very simple: docLength is computed exactly as before, then each word value is multiplied by m_AvgDocLength and divided by docLength, so every document vector ends up with length m_AvgDocLength.

protected void push(Instance instance) {

 

    if (instance != null) {

       if (instance.dataset() != null)

           copyValues(instance, false);

       instance.setDataset(m_OutputFormat);

       m_OutputQueue.push(instance);

    }

}

push sets the instance's dataset to the output format (the .arff header) and appends the instance to m_OutputQueue.
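To put the walkthrough in context, here is a sketch of how the filter is typically driven from user code. The file name and option values are made up for illustration; the setters shown are the standard StringToWordVector options corresponding to the member variables discussed above.

import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.StringToWordVector;

public class StringToWordVectorDemo {
    public static void main(String[] args) throws Exception {
        // Load a dataset containing string attributes (hypothetical file name),
        // and assume the class attribute is the last one.
        Instances data = DataSource.read("documents.arff");
        data.setClassIndex(data.numAttributes() - 1);

        StringToWordVector filter = new StringToWordVector();
        filter.setWordsToKeep(1000);          // m_WordsToKeep
        filter.setMinTermFreq(1);             // m_minTermFreq
        filter.setOutputWordCounts(true);     // m_OutputCounts
        filter.setTFTransform(true);          // m_TFTransform
        filter.setIDFTransform(true);         // m_IDFTransform
        filter.setLowerCaseTokens(true);      // m_lowerCaseTokens

        // setInputFormat + useFilter drives the batchFinished() path analyzed above.
        filter.setInputFormat(data);
        Instances wordVectors = Filter.useFilter(data, filter);
        System.out.println(wordVectors.numAttributes() + " attributes after conversion");
    }
}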
