使用 Java WEKA 库时正确标记预测类

Correctly labeling predicted classes when using the Java WEKA library

提问人:Hack-R 提问时间:1/31/2017 最后编辑:CommunityHack-R 更新时间:1/31/2017 访问量:676

问:

我有一个程序,它训练一个具有 2 类分类结果的算法,然后运行并写出未标记数据集的预测(2 个类中每个类的概率)。

针对此程序运行的所有数据集将具有与结果相同的 2 个类。考虑到这一点,我运行了预测,并使用了一些事后统计来确定哪一列结果描述了哪个结果,并继续对它们进行硬编码:

public class runPredictions {
public static void runPredictions(ArrayList al2) throws IOException, Exception{
    // Retrieve objects
    Instances newTest = (Instances) al2.get(0);
    Classifier clf = (Classifier) al2.get(1);

    // Print status
    System.out.println("Generating predictions...");

    // create copy
    Instances labeled = new Instances(newTest);

    BufferedWriter outFile = new BufferedWriter(new FileWriter("silverbullet_rro_output.csv"));
    StringBuilder builder = new StringBuilder();

    builder.append("Prob_Retain"+","+"Prob_Attrite"+"\n");
    for (int i = 0; i < labeled.size(); i++)      
    {
        double[] clsLabel = clf.distributionForInstance(newTest.instance(i));
        for(int j=0;j<2;j++){
           builder.append(clsLabel[j]+""); 
           if(j < clsLabel.length - 1)
               builder.append(",");
        }
        builder.append("\n");
    }
    outFile.write(builder.toString());//save the string representation
    System.out.println("Output file written.");
    System.out.println("Completed successfully!");
    outFile.close();    
}    
}

这样做的问题在于,事实证明,2 列中的哪一列描述了 2 个结果类别中的哪一列不是固定的。这似乎与哪个类别首先出现在训练数据集中有关,这完全是任意的。因此,当其他数据集与该程序一起使用时,硬编码标签是向后编码的。

因此,我需要一种更好的方法来标记它们,但是查看 和 的文档,我没有看到任何有用的东西。ClassifierdistributionForInstance

更新

我想出了如何将其打印到屏幕上(多亏了这一点),但在将其写入 csv 时仍然遇到问题:

for (int i = 0; i < labeled.size(); i++)      
    {
        // Discreet prediction
        double predictionIndex = 
            clf.classifyInstance(newTest.instance(i)); 

        // Get the predicted class label from the predictionIndex.
        String predictedClassLabel =
            newTest.classAttribute().value((int) predictionIndex);

        // Get the prediction probability distribution.
        double[] predictionDistribution = 
            clf.distributionForInstance(newTest.instance(i)); 

        // Print out the true predicted label, and the distribution
        System.out.printf("%5d: predicted=%-10s, distribution=", 
                          i, predictedClassLabel); 

        // Loop over all the prediction labels in the distribution.
        for (int predictionDistributionIndex = 0; 
             predictionDistributionIndex < predictionDistribution.length; 
             predictionDistributionIndex++)
        {
            // Get this distribution index's class label.
            String predictionDistributionIndexAsClassLabel = 
                newTest.classAttribute().value(
                    predictionDistributionIndex);

            // Get the probability.
            double predictionProbability = 
                predictionDistribution[predictionDistributionIndex];

            System.out.printf("[%10s : %6.3f]", 
                              predictionDistributionIndexAsClassLabel, 
                              predictionProbability );

            // Attempt to write to CSV
            builder.append(i+","+predictedClassLabel+","+
                    predictionDistributionIndexAsClassLabel+","+predictionProbability);
                            //.charAt(0)+','+predictionProbability.charAt(0));

        }

        System.out.printf("\n");
        builder.append("\n");
爪哇 维卡

评论


答:

1赞 Walter 1/31/2017 #1

我从这个答案和这个答案改编了下面的代码。基本上,您可以查询类属性的测试数据,然后获取每个可能的类的特定值。

for (int i = 0; i < labeled.size(); i++)      
{
// Discreet prediction

double predictionIndex = 
    clf.classifyInstance(newTest.instance(i)); 

// Get the predicted class label from the predictionIndex.
String predictedClassLabel =
    newTest.classAttribute().value((int) predictionIndex);

// Get the prediction probability distribution.
double[] predictionDistribution = 
    clf.distributionForInstance(newTest.instance(i)); 

// Print out the true predicted label, and the distribution
System.out.printf("%5d: predicted=%-10s, distribution=", 
                  i, predictedClassLabel); 

// Loop over all the prediction labels in the distribution.
for (int predictionDistributionIndex = 0; 
     predictionDistributionIndex < predictionDistribution.length; 
     predictionDistributionIndex++)
{
    // Get this distribution index's class label.
    String predictionDistributionIndexAsClassLabel = 
        newTest.classAttribute().value(
            predictionDistributionIndex);

    // Get the probability.
    double predictionProbability = 
        predictionDistribution[predictionDistributionIndex];

    System.out.printf("[%10s : %6.3f]", 
                      predictionDistributionIndexAsClassLabel, 
                      predictionProbability );

    // Write to CSV
    builder.append(i+","+
            predictionDistributionIndexAsClassLabel+","+predictionProbability);


}

System.out.printf("\n");
builder.append("\n");

}


// Save results in .csv file
outFile.write(builder.toString());//save the string representation

评论

1赞 Walter 1/31/2017
你是绝对正确的,我应该是一个不同的索引!它只是您正在评估的实例。我会纠正的
0赞 Hack-R 1/31/2017
再次感谢。所以第一行的 for 循环应该是这样的,对吧?对于 2 类(2 个标签)的情况,将始终为 0 或 1,但是在您拥有的行上,我们不需要某个地方来访问正确的部分?for (int j = 0; j < newTest.size(); j++)inewTest.classAttribute().value(i)jnewTest
1赞 Hack-R 1/31/2017
我认为可能有一个小错误或缺少组件(见上面的评论),但我刚才在我的版本中得到了这个工作,你值得称赞,所以我要用我的版本编辑你的帖子并将其标记为解决方案。如果你想回滚我将要进行的编辑,只需调整你的版本,这非常酷。再次感谢您的帮助,我非常感谢!