2016-12-30 183 views
0

我正在尝试编写生成决策树的ID3算法,但运行我的代码时出现StackOverflowError错误。 当调试时,我注意到循环开始时,属性下降到4(从最初9)。 树生成的代码如下。我打电话的所有功能都正常工作,它们已经过测试。 但是,错误代码指出问题出在另一个使用流的函数上,但它已经单独测试了 ,我知道它正常工作。请记住,我正在处理随机数据,因此该函数有时会抛出错误,有时不会。我在其下面发布了错误代码,但熵函数和信息增益工作。StackOverflowError决策树生成JAVA

这是树节点结构:

public class TreeNode { 
    List<Patient> samples; 
    List<TreeNode> children; 
    TreeNode parent; 
    Integer attribute; 
    String attributeValue; 
    String className; 

    public TreeNode(List<Patient> samples, List<TreeNode> children, TreeNode parent, Integer attribute, 
      String attributeValue, String className) { 
     this.samples = samples; 
     this.children = children; 
     this.parent = parent; 
     this.attribute = attribute; 
     this.attributeValue = attributeValue; 
     this.className = className; 
    } 
} 

这就是抛出错误代码:

public TreeNode id3(List<Patient> patients, List<Integer> attributes, TreeNode root) { 
     boolean isLeaf = patients.stream().collect(Collectors.groupingBy(i -> i.className)).keySet().size() == 1; 
     if (isLeaf) { 
      root.setClassName(patients.get(0).className); 
      return root; 
     } 
     if (attributes.size() == 0) { 
      root.setClassName(mostCommonClass(patients)); 
      return root; 
     } 
     int bestAttribute = maxInformationGainAttribute(patients, attributes); 
     Set<String> attributeValues = attributeValues(patients, bestAttribute); 
     for (String value : attributeValues) { 
      List<Patient> branch = patients.stream().filter(i -> i.patientData[bestAttribute].equals(value)) 
        .collect(Collectors.toList()); 

      TreeNode child = new TreeNode(branch, new ArrayList<>(), root, bestAttribute, value, null); 

      if (branch.isEmpty()) { 
       child.setClassName(mostCommonClass(patients)); 
       root.addChild(new TreeNode(child)); 
      } else { 
       List<Integer> newAttributes = new ArrayList<>(); 
       newAttributes.addAll(attributes); 
       newAttributes.remove(new Integer(bestAttribute)); 
       root.addChild(new TreeNode(id3(branch, newAttributes, child))); 
      } 
     } 
     return root; 
    } 

这些都是其他功能:

public static double entropy(List<Patient> patients) { 
     double entropy = 0.0; 
     double recurP = (double) patients.stream().filter(i -> i.className.equals("recurrence-events")).count() 
       /(double) patients.size(); 
     double noRecurP = (double) patients.stream().filter(i -> i.className.equals("no-recurrence-events")).count() 
       /(double) patients.size(); 
     entropy -= (recurP * (recurP > 0 ? Math.log(recurP) : 0/Math.log(2)) 
       + noRecurP * (noRecurP > 0 ? Math.log(noRecurP) : 0/Math.log(2))); 
     return entropy; 
    } 



public static double informationGain(List<Patient> patients, int attribute) { 
     double informationGain = entropy(patients); 
     Map<String, List<Patient>> patientsGroupedByAttribute = patients.stream() 
       .collect(Collectors.groupingBy(i -> i.patientData[attribute])); 
     List<List<Patient>> subsets = new ArrayList<>(); 
     for (String i : patientsGroupedByAttribute.keySet()) { 
      subsets.add(patientsGroupedByAttribute.get(i)); 
     } 

     for (List<Patient> lp : subsets) { 
      informationGain -= proportion(lp, patients) * entropy(lp); 
     } 
     return informationGain; 
    } 


private static int maxInformationGainAttribute(List<Patient> patients, List<Integer> attributes) { 
     int maxAttribute = 0; 
     double maxInformationGain = 0; 
     for (int i : attributes) { 
      if (informationGain(patients, i) > maxInformationGain) { 
       maxAttribute = i; 
       maxInformationGain = informationGain(patients, i); 
      } 
     } 
     return maxAttribute; 
    } 

例外:

Exception in thread "main" java.lang.StackOverflowError 
    at java.util.stream.ReferencePipeline$2$1.accept(Unknown Source) 
    at java.util.ArrayList$ArrayListSpliterator.forEachRemaining(Unknown Source) 
    at java.util.stream.AbstractPipeline.copyInto(Unknown Source) 
    at java.util.stream.AbstractPipeline.wrapAndCopyInto(Unknown Source) 
    at java.util.stream.ReduceOps$ReduceOp.evaluateSequential(Unknown Source) 
    at java.util.stream.AbstractPipeline.evaluate(Unknown Source) 
    at java.util.stream.LongPipeline.reduce(Unknown Source) 
    at java.util.stream.LongPipeline.sum(Unknown Source) 
    at java.util.stream.ReferencePipeline.count(Unknown Source) 
    at Patient.entropy(Patient.java:39) 
    at Patient.informationGain(Patient.java:67) 
    at Patient.maxInformationGainAttribute(Patient.java:85) 
    at Patient.id3(Patient.java:109) 

回答

0

行:

root.addChild(new TreeNode(id3(branch, newAttributes, child)));

被调用每一个方法递归时间,从而导致堆栈溢出。这告诉我你的逻辑中有什么错误,没有任何结束递归的“基本情况”,即返回根目录。我对预期的行为或开始的数据知之甚少,无法确定发生了什么问题,但我会先用调试器逐步完成代码,并确保该方法中的逻辑表现出您期望的行为。我知道这不是一个很好的答案,但它是一个起点,希望帮助或其他人会用更具体的解决方案加以注意。

+0

我一直在一遍又一遍的调试它,它的工作原理直到属性降到4,这是奇怪的部分。当属性下降到4时,它开始回退一步,并再次向前走。但它直到那时才生成适当的树。 :( – vixenn

+0

我会看看两种方法, maxInformationGainAttribute(患者,属性); 和 attributeValues(patients,bestAttribute); ,并确保它们返回您所期望的值,以防止它卡住。 –

+0

确保maxInformationGainAttribute(patients,attributes);正在做它应该做的事情,因为如果它不修改属性列表,那么您将在此行传递相同的值: newAttributes.addAll(attributes); –