1

我的 APCS 期末项目是一个 OCR(手写数字识别)应用程序,我的问题是:如何训练其中的神经网络?该程序:

  • 允许用户在绘图面板上手写数字;
  • 将每个笔画(由 x-y 坐标列表表示)缩放并平移到 100x100 的区域内;
  • 根据缩放后的笔画生成图像;
  • 由该图像生成二值二维数组(白色像素为 0,其余为 1);
  • 最后将该二值数组传给神经元对象进行字符识别。

下面是表示神经元的类:

import java.awt.*; 
import java.util.*; 
import java.io.*; 

/**
 * A single linear unit over a 2D grid of binary pixels, used as a digit "recognizer".
 *
 * <p>The weights form one {@code double[rows][cols]} matrix. {@link #feedforward} takes the
 * weighted sum of a binary image and squashes it with the logistic sigmoid. It then buckets
 * the result into ten equal-width ranges and returns the bucket index 0-9 as the predicted
 * digit.
 *
 * <p>NOTE(review): there is only one output value, so all ten digits are forced onto a single
 * ordered axis (the model ranks rather than classifies). A multiclass classifier normally uses
 * one output unit per digit and picks the largest one (argmax). That change would alter this
 * class's public interface, so it is only flagged here.
 */
public class Neuron
{
    /** Name of the file that the random-weight constructor writes its initial weights to. */
    private static final String TRAINING_FILE = "training.txt";

    /** Weight matrix, indexed [row][column] in the same layout as the binary input images. */
    private double[][] weights;

    /**
     * Step size applied to every weight update in {@link #train}.
     * Kept public and non-final to preserve the existing API.
     */
    public static double LEARNING_RATE = 0.01;

    /**
     * Creates a neuron with random weights in [-1, 1) and saves them to "training.txt".
     *
     * <p>The file holds one matrix row per line, with values separated by single spaces.
     * This is the format read back by {@link #Neuron(String)}.
     *
     * @param r number of weight rows
     * @param c number of weight columns
     * @throws UncheckedIOException if "training.txt" cannot be created or written
     */
    public Neuron(int r, int c)
    {
        weights = new double[r][c];
        for (int i = 0; i < r; i++)
        {
            for (int j = 0; j < c; j++)
                weights[i][j] = 2 * Math.random() - 1; // Math.random() is in [0, 1), so this is in [-1, 1)
        }

        // Fail with the real cause instead of continuing with a null writer.
        try (PrintWriter printer = new PrintWriter(TRAINING_FILE))
        {
            writeMatrix(printer, weights);
            // PrintWriter never throws on write errors; it only records them.
            if (printer.checkError())
                throw new UncheckedIOException(new IOException("Error writing " + TRAINING_FILE));
        }
        catch (FileNotFoundException e)
        {
            throw new UncheckedIOException("Could not create " + TRAINING_FILE, e);
        }
    }

    /**
     * Creates a neuron whose weights are read from a whitespace-separated text file.
     *
     * <p>The matrix is sized {@code Drawing.DEF_HEIGHT} x {@code Drawing.DEF_WIDTH} and filled
     * in row-major order. Entries missing from the file stay 0, and values beyond the matrix
     * size are ignored. The loaded matrix is printed to standard output.
     *
     * <p>If the file cannot be opened, an error message is printed and the JVM exits with
     * status 1.
     *
     * @param fileName path of the weight file, e.g. one written by {@link #Neuron(int, int)}
     * @throws InputMismatchException if a token in the file is not a number
     */
    public Neuron(String fileName)
    {
        int r = Drawing.DEF_HEIGHT, c = Drawing.DEF_WIDTH;
        weights = new double[r][c];

        try (Scanner input = new Scanner(new File(fileName)))
        {
            int i = 0, j = 0;
            // Stop once the matrix is full so extra tokens cannot index past the last row.
            while (i < r && input.hasNext())
            {
                weights[i][j] = input.nextDouble();
                if (++j == c)
                {
                    i++;
                    j = 0;
                }
            }
        }
        catch (FileNotFoundException e)
        {
            System.out.println("Error: could not open " + fileName);
            System.exit(1);
        }

        for (double[] a : weights)
            System.out.println(Arrays.toString(a));
    }

    /** Writes each row of {@code m} on its own line, with values separated by single spaces. */
    private static void writeMatrix(PrintWriter out, double[][] m)
    {
        for (double[] row : m)
        {
            for (int j = 0; j < row.length; j++)
            {
                if (j > 0)
                    out.print(' ');
                out.print(row[j]);
            }
            out.println();
        }
    }

    /**
     * Predicts a digit for a binary image.
     *
     * @param bin image values, expected to be 0 (white) or 1 (ink). It must be at least as
     *            large as the weight matrix in both dimensions.
     * @return the digitized activation, 0-9 (see {@link #activate})
     */
    public int feedforward(int[][] bin)
    {
        double sum = 0;
        for (int i = 0; i < weights.length; i++)
        {
            for (int j = 0; j < weights[i].length; j++)
                sum += weights[i][j] * bin[i][j];
        }
        return activate(sum);
    }

    /**
     * Squashes a weighted sum with the logistic sigmoid and buckets the result into a digit.
     *
     * <p>Bucket d covers sigmoid values in [d/10, (d+1)/10), and everything at or above 0.9 maps
     * to 9. The sigmoid value and the bucket are both printed to standard output.
     *
     * @param n the weighted input sum
     * @return the bucket index, 0-9 (0 if the sigmoid is NaN)
     */
    public int activate(double n)
    {
        double sig = 1.0 / (1 + Math.exp(-n));
        int digitized = digitize(sig);
        System.out.println("Sigmoid value: " + sig + "\nDigitized value: " + digitized);
        return digitized;
    }

    /** Returns the highest d in 1..9 with {@code sig >= d/10}, or 0 if there is none. */
    private static int digitize(double sig)
    {
        // d / 10.0 is the correctly rounded double nearest d/10. That is the same value as the
        // literal 0.d, so these thresholds match a hand-written 0.1 ... 0.9 comparison chain.
        for (int d = 9; d >= 1; d--)
        {
            if (sig >= d / 10.0)
                return d;
        }
        return 0;
    }

    /**
     * Performs one in-memory, error-driven weight update for a single labelled example.
     *
     * <p>The error is {@code desired - feedforward(bin)}, measured in digit units. Every weight
     * moves by {@code LEARNING_RATE * error * bin[i][j]}, so only weights under non-zero pixels
     * change. The updated weights are not saved to disk.
     *
     * @param bin binary input image, in the same layout as for {@link #feedforward}
     * @param desired the correct digit for {@code bin}
     */
    public void train(int[][] bin, int desired)
    {
        int error = desired - feedforward(bin);

        for (int i = 0; i < weights.length; i++)
        {
            for (int j = 0; j < weights[i].length; j++)
                weights[i][j] += LEARNING_RATE * error * bin[i][j];
        }
    }

}

我用另一个类来“训练”这个神经元。这个类 TrainingConsole.java 读取由随机生成的权重组成的“training.txt”,向其输入训练样本(图像 -> 二值二维数组),并根据误差、学习率以及 bin 中的对应值来调整权重:

import java.awt.image.BufferedImage; 
import java.io.*; 
import java.util.Arrays; 
import java.util.Scanner; 

import javax.imageio.ImageIO; 

public class TrainingConsole 
{ 

    private File folder; 
    private File data; 

    public TrainingConsole(String dataFileName, String folderName) 
    { 
     data = new File(dataFileName); 
     folder = new File(folderName); 
    } 

    public void changeFolder(String folderName) 
    { 
     folder = new File(folderName); 
    } 

    public void feedAll(int desired) 
    { 
     System.out.println(Arrays.toString(folder.listFiles())); 
     for (int i = 1; i < folder.listFiles().length; i++) //To exclude folder 
     { 
      BufferedImage img = new BufferedImage(Drawing.DEF_WIDTH,Drawing.DEF_HEIGHT,BufferedImage.TYPE_INT_RGB); 
      try 
      { 

       String name = folder.listFiles()[i].getName(); 
       if (name.substring(name.length()-4).equals(".png")) 
        img = ImageIO.read(folder.listFiles()[i]); 
      } 
      catch(IOException e) 
      {System.out.println("Error?");} 

      int[][] bin = new int[Drawing.DEF_WIDTH][Drawing.DEF_HEIGHT]; 

      if (img != null) 
      { 
       for (int y = 0; y < img.getHeight(); y++) 
       { 
        for (int x = 0; x < img.getWidth(); x++) 
        { 
         int rgb = img.getRGB(x,y); 
         //System.out.println(rgb); 
         if (rgb == -1) //White 
          bin[y][x] = 0; 
         else 
          bin[y][x] = 1; 
        } 
       } 
       for (int[] a : bin) 
        System.out.println(Arrays.toString(a)); 
       train(bin,desired); 
      } 
     } 
    } 

    public void train(int[][] bin, int desired) { 
     int guess = feedforward(bin); 
     int error = desired - guess; 

     Scanner input = null; 
     try { 
      input = new Scanner(data); 
     } catch (FileNotFoundException e) { 
      System.exit(1); 
     } 
     double[][] weights = new double[Drawing.DEF_HEIGHT][Drawing.DEF_WIDTH]; 
     int i = 0, j = 0; 
     while (input.hasNext() && i < Drawing.DEF_HEIGHT) { 
      weights[i][j] = input.nextDouble(); 
      j++; 
      if (j > weights[i].length - 1) { 
       i++; 
       j = 0; 
      } 
     } 

     for (int k = 0; k < weights.length; k++) { 
      for (int l = 0; l < weights[k].length; l++) 
       weights[k][l] += IMGNeuron.LEARNING_RATE * error * bin[k][l]; 
     } 

     data = new File(data.getName()); 
     PrintWriter output = null; 
     try { 
      output = new PrintWriter(data); 
     } catch (FileNotFoundException e) { 
      System.out.println("Cannot find data"); 
     } 
     for (int m = 0; m < weights.length; m++) { 
      for (int n = 0; n < weights[m].length - 1; n++) 
       output.print(weights[m][n] + " "); 
      output.print(weights[m][weights[m].length - 1]); 
      output.println(); 
     } 
     output.close(); 
    } 

    public int feedforward(int[][] bin) 
    { 
     double sum = 0; 

     Scanner input = null; 
     try 
     { 
      input = new Scanner(data); 
     } 
     catch(FileNotFoundException e) 
     { 
      System.out.println("Could not locate data"); 
     } 
     double[][] weights = new double[Drawing.DEF_HEIGHT][Drawing.DEF_WIDTH]; 
     int i = 0, j = 0; 
     while (i < Drawing.DEF_HEIGHT && j < Drawing.DEF_WIDTH) 
     { 
      //System.out.println("(" + i + " , " + j + ")"); 
      weights[i][j] = input.nextDouble(); 
      j++; 
      if (j > weights[i].length - 1) 
      { 
       i++; 
       j = 0; 
      } 
     } 

     for (int m = 0; m < weights.length; m++) 
     { 
      for (int n = 0; n < weights[m].length; n++) 
       sum += weights[m][n] * bin[m][n]; 
     } 
     return activate(sum); 
    } 

    public int activate(double n) 
    { 
     double sig = 1.0/(1+Math.exp(-1*n)); 
     int digitized = 0; 

     if (sig < 0.1) 
      digitized = 0; 
     else if (sig >= 0.1 && sig < 0.2) 
      digitized = 1; 
     else if (sig >= 0.2 && sig < 0.3) 
      digitized = 2; 
     else if (sig >= 0.3 && sig < 0.4) 
      digitized = 3; 
     else if (sig >= 0.4 && sig < 0.5) 
      digitized = 4; 
     else if (sig >= 0.5 && sig < 0.6) 
      digitized = 5; 
     else if (sig >= 0.6 && sig < 0.7) 
      digitized = 6; 
     else if (sig >= 0.7 && sig < 0.8) 
      digitized = 7; 
     else if (sig >= 0.8 && sig < 0.9) 
      digitized = 8; 
     else if (sig >= 0.9) 
      digitized = 9; 

     return digitized; 
    } 

    public static void main(String[] args) 
    { 
     Scanner input = new Scanner(System.in); 
     TrainingConsole trainer = new TrainingConsole("training.txt","Training_000"); 

     System.out.println("--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------"); 
     System.out.println("Training Console"); 
     System.out.println("--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------"); 

     for (int i = 0; i <= 9; i++) { 
      //System.out.print("Folder with training data for desired = " + i + ", or enter \"skip\" to skip: "); 
      //String folderName = input.nextLine().trim(); 
      String folderName = "Training_00" + i; 
      //System.out.println(folderName); 
      if (!folderName.toLowerCase().equals("skip")) 
      { 
       trainer.changeFolder(folderName); 
//    System.out.print("Press enter to run: "); 
//    String noReason = input.nextLine(); 
       trainer.feedAll(i); 
      } 
      System.out.println("----------------------------------------------------------------------------------------------------ava----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------"); 
     } 
    } 

} 

之后构造神经元时,我把“training.txt”作为权重矩阵传入。但显然这并不奏效:[截图]

请帮忙!我对神经网络和机器学习完全是新手,目前不知道自己哪里做错了:是需要更多训练样本吗?还是激活函数实现得不好?任何建议都不胜感激。如有需要,我可以提供更多代码。

+0

也许是我看漏了,但你的网络里似乎只有一个神经元。无论训练多少次,这样的网络效果都会很差。建议阅读 http://neuralnetworksanddeeplearning.com/chap1.html,其中有一个与你要做的事情非常相似的例子。 – Chill

+0

另外,你用 sigmoid(再取整为 int)作为激活函数也不对……多分类问题不是这样做的。 – user2717954

回答

0

正如评论中指出的,这里有两个主要问题,下面我更详细地说明。

  1. 你的整个模型只是一个单层感知器,也就是说,你建立的是从输入空间(像素)到类别(数字)的线性模型。这根本行不通,它也不是现代意义上的神经网络。用于图像处理的“现代”神经网络由成千上万个神经元组成,按层连接,层与层之间使用非线性激活,而且往往以卷积核的形式排列(因为这是目前图像识别领域最先进的架构)。

  2. 你要解决的是多分类问题,但你实际上做的是排序(ranking)。要让神经网络分成 K 类,应该有 K 个输出神经元,每个神经元产生一个信号,可解释为属于对应类别的“概率”(并非严格数学意义上的概率);分类时取 arg max,即输出值最大的那个神经元的编号。

一旦解决了整体架构上的这两个重要问题,你就应该能开始得到合理的结果;之后剩下的就只是调整超参数和获取更多训练数据了。