2014-09-11 139 views
2

我一直在努力训练下面的网络,并获得合适的权重,但它一直在运行。任何人都可以告诉我代码中可能有什么错误?这里{8,1}是输入,{-1}}预期使用一个符号函数输出。单层感知器培训?

import java.util.Arrays; 

public class ANN { 

    public static void main(String args[]) { 

     double threshold = 1.2; 
     double learningRate = 0.08; 

     // Init weights 

     double[] weights = { -1.4, 1.8 }; 

     int[][][] trainingData = { 
      {{8, 1}, {-1}}, 
      {{3, 2}, {-1}}, 
      {{6, 3}, {-1}}, 
      {{1, 4}, {-1}}, 
      {{9, 5}, {1}}, 
      {{5, 6}, {1}}, 
      {{2, 7}, {1}}, 
      {{4, 8}, {1}}, 
      {{7, 9}, {1}}, 
     }; 

     // Start training loop 
     while (true) { 
      int errorCount = 0; 
      // Loop over training data 
      for (int i = 0; i < trainingData.length; i++) { 
       System.out.println("Starting weights: " + Arrays.toString(weights)); 
       // Calculate weighted input 
       double weightedSum = 0; 
       for (int ii = 0; ii < trainingData[i][0].length; ii++) { 
        weightedSum += trainingData[i][0][ii] * weights[ii]; 
       } 

       // Calculate output 
       int output = 0; 
       if (threshold <= weightedSum) { 
        output = 1; 
       } 

       System.out.println("Target output: " + trainingData[i][1][0] 
         + ", " + "Actual Output: " + output); 

       // Calculate error 
       int error = trainingData[i][1][0] - output; 
       System.out.println("Error: " + error); 
       // Increase error count for incorrect output 
       if (error != 0) { 
        errorCount++; 
       } 

       // Update weights 
       for (int ii = 0; ii < trainingData[i][0].length; ii++) { 
        weights[ii] += learningRate * error 
          * trainingData[i][0][ii]; 
       } 

       System.out.println("New weights: " + Arrays.toString(weights)); 
       System.out.println(); 
      } 

      // If there are no errors, stop 
      if (errorCount == 0) { 
       System.out 
         .println("Final weights: " + Arrays.toString(weights)); 
       System.exit(0); 
      } 
     } 
    } 

} 

编辑:我认为问题出现在计算输出的代码片段中。它应该翻转,以便如果总和大于阈值,则输出为1,否则为0。

// Calculate output 
       int output = 0; 
       if (weightedSum > threshold) { 
        output = 1; 
       } 

回答

1

我遇到你的代码,并添加了一行之前的(ERRORCOUNT == 0)检查:

System.out.println(errorCount); 

这看似6和7之间振荡,这意味着神经网络总是无论训练数量如何,都会生成对训练数据的无效估计。如果训练数据无法达到100%的正确率,那么预计这种训练将永远持续。

希望这有助于!

1

您的错误可能是正面和负面的。在第一次运行中,错误是-1。因此,errorCount递增,退出循环的代码从不执行。

完整培训的条件应该基于错误本身,而不是errorCount。当错误达到最低水平(您将根据您的输入设置)时,培训将被视为已完成。