注册 登录  
 加关注
   显示下一条  |  关闭
温馨提示!由于新浪微博认证机制调整,您的新浪微博帐号绑定已过期,请重新绑定!立即重新绑定新浪微博》  |  关闭

Koala++'s blog

计算广告学 RTB

 
 
 

日志

 
 

Weka开发[39]——CVParameterSelection源代码分析  

2010-04-15 14:24:09|  分类: 机器学习 |  标签: |举报 |字号 订阅

  下载LOFTER 我的照片书  |

       CVParmeterSelection是在Weka[37]中提到的寻找最优参数的类,它是继承自RandomizableSingleClassifierEnhanceRancomizableSingleClassifierEnhance的注释写到:Abstract utility class for handling settings common to randomizable meta classifiers that build an ensemble from a single base learner

       还是先从buildClassifier开始:

public void buildClassifier(Instances instances) throws Exception {

    Instances trainData = new Instances(instances);

    trainData.deleteWithMissingClass();

   

    m_InitOptions = ((OptionHandler) m_Classifier).getOptions();

    m_BestPerformance = -99;

    m_NumAttributes = trainData.numAttributes();

    Random random = new Random(m_Seed);

    trainData.randomize(random);

    m_TrainFoldSize = trainData.trainCV(m_NumFolds, 0).numInstances();

 

    // Check whether there are any parameters to optimize

    if (m_CVParams.size() == 0) {

       m_Classifier.buildClassifier(trainData);

       m_BestClassifierOptions = m_InitOptions;

       return;

    }

 

    if (trainData.classAttribute().isNominal()) {

       trainData.stratify(m_NumFolds);

    }

    m_BestClassifierOptions = null;

 

    // Set up m_ClassifierOptions -- take getOptions() and remove

    // those being optimised.

    m_ClassifierOptions = ((OptionHandler) m_Classifier).getOptions();

    for (int i = 0; i < m_CVParams.size(); i++) {

       Utils.getOption(

              ((CVParameter) m_CVParams.elementAt(i)).m_ParamChar,

              m_ClassifierOptions);

    }

    findParamsByCrossValidation(0, trainData, random);

 

    String[] options = (String[]) m_BestClassifierOptions.clone();

    ((OptionHandler) m_Classifier).setOptions(options);

    m_Classifier.buildClassifier(trainData);

}

       trainCV得到(numFolds-1)/numFolds比例的样本,用findParamsByCrossValidation得到最优参数,最后再用m_BestClassifierOptions最优参数进行训练,得到最优的分类器。

       findParamsByCrossValidation中的depth表示现在是针对第几个参数进行的优化,现在将ifelse分开:

if (depth < m_CVParams.size()) {

    CVParameter cvParam = (CVParameter) m_CVParams.elementAt(depth);

 

    double upper;

    switch ((int) (cvParam.m_Lower - cvParam.m_Upper + 0.5)) {

    case 1:

       upper = m_NumAttributes;

       break;

    case 2:

       upper = m_TrainFoldSize;

       break;

    default:

       upper = cvParam.m_Upper;

       break;

    }

    double increment = (upper - cvParam.m_Lower)

           / (cvParam.m_Steps - 1);

    for (cvParam.m_ParamValue = cvParam.m_Lower;

cvParam.m_ParamValue <= upper;

cvParam.m_ParamValue += increment) {

       findParamsByCrossValidation(depth + 1, trainData, random);

    }

}

       depth<m_CVParams.size()表示还优化到最后一个参数,这里m_Lower-m_Upper还能是正值的原因是在CVParameter类中特别设置的。

if (st.sval.toUpperCase().charAt(0) == 'A') {

    m_Upper = m_Lower - 1;

} else if (st.sval.toUpperCase().charAt(0) == 'I') {

    m_Upper = m_Lower - 2;

} else {

    throw new Exception("CVParameter " + param

           + ": Upper bound must be numeric, or 'A' or 'N'");

}

       ‘A’对所有属性进行循环,也就对应case 1‘I’表示的是以样本数进行循环。Increment是平均步长,递归调用findParamsByCrossValidation对下一个参数进行循环。

else {

    Evaluation evaluation = new Evaluation(trainData);

 

    // Set the classifier options

    String[] options = createOptions();

    ((OptionHandler) m_Classifier).setOptions(options);

    for (int j = 0; j < m_NumFolds; j++) {

 

       // We want to randomize the data the same way for every

       // learning scheme.

       Instances train = trainData.trainCV(m_NumFolds, j,

              new Random(1));

       Instances test = trainData.testCV(m_NumFolds, j);

       m_Classifier.buildClassifier(train);

       evaluation.setPriors(train);

       evaluation.evaluateModel(m_Classifier, test);

    }

    double error = evaluation.errorRate();

    if ((m_BestPerformance == -99) || (error < m_BestPerformance)) {

 

       m_BestPerformance = error;

       m_BestClassifierOptions = createOptions();

    }

}

       create得到参数后,对m_ClassifiersetOptions设置参数,进行进行m_NumFolds折交叉验证,这里是用错误率(1-正确率)来判断哪个分类器更好,m_BestPerformance得到的是当前最低的错误率,m_BestClassifierOptions是当前的最好参数。

 

 

 

 

 

  评论这张
 
阅读(1516)| 评论(0)
推荐 转载

历史上的今天

评论

<#--最新日志,群博日志--> <#--推荐日志--> <#--引用记录--> <#--博主推荐--> <#--随机阅读--> <#--首页推荐--> <#--历史上的今天--> <#--被推荐日志--> <#--上一篇,下一篇--> <#-- 热度 --> <#-- 网易新闻广告 --> <#--右边模块结构--> <#--评论模块结构--> <#--引用模块结构--> <#--博主发起的投票-->
 
 
 
 
 
 
 
 
 
 
 
 
 
 

页脚

网易公司版权所有 ©1997-2017