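Keras CNN for multi-class classification: categorical cross-entropy loss function

This is the code I wrote to classify images into three classes: birds, dogs and cats. It is the same as my binary-classification code, except that I added another class and changed the loss function in the compile call to categorical_crossentropy. Now it gives me the following error (at the end of the code). Can anyone explain what is wrong here or what mistake I made?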
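For reference, categorical_crossentropy expects one-hot targets with one column per class, which is what flow_from_directory(..., class_mode='categorical') yields. The NumPy sketch below assumes three hypothetical class folders named birds, cats and dogs (flow_from_directory maps subdirectories to indices in alphanumeric order) and computes the loss by hand. It mirrors the formula only and is not Keras's own implementation.

```python
import numpy as np

# Hypothetical folder names; flow_from_directory would index them
# alphanumerically: birds=0, cats=1, dogs=2.
class_names = ["birds", "cats", "dogs"]

def one_hot(indices, num_classes):
    """Turn integer class indices into one-hot rows (one column per class)."""
    return np.eye(num_classes)[indices]

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Batch mean of -sum(y_true * log(y_pred)), with y_pred clipped away from 0 and 1."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=1)))

labels = one_hot(np.array([0, 2]), len(class_names))  # a bird and a dog
preds = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.1, 0.8]])

print(labels.shape)  # (2, 3): one column per class
print(round(categorical_crossentropy(labels, preds), 4))  # 0.2899
```

Only the predicted probability of the true class contributes to each row's loss, so the model's output must have exactly as many columns as there are classes.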
# Importing Keras and Tensorflow modules
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Conv2D
from keras.layers import MaxPooling2D
from keras.layers import Flatten
from keras.layers import Dense
from keras.preprocessing.image import ImageDataGenerator
from keras.models import load_model
from keras.utils.np_utils import to_categorical
import os.path
# Initialize the CNN
classifier = Sequential()
# Step 1 - Convolution
classifier.add(Conv2D(32, (3, 3), input_shape = (64, 64, 3), activation = 'relu'))
# Step 2 - Pooling
classifier.add(MaxPooling2D(pool_size = (2, 2)))
# Step 2(b) - Add 2nd Convolution Layer making it Deep followed by a Pooling Layer
classifier.add(Conv2D(32, (3, 3), activation = 'relu'))
classifier.add(MaxPooling2D(pool_size = (2, 2)))
# Step 3 - Flattening
classifier.add(Flatten())
# Step 4 - Fully Connected Neural Network
# Hidden Layer - Activation Function RELU
classifier.add(Dense(units = 128, activation = 'relu'))
# Output Layer - Activation Function Softmax (to classify multiple classes)
# One unit per class (birds, cats, dogs): a softmax over a single unit always outputs 1.0
classifier.add(Dense(units = 3, activation = 'softmax'))
# Compile the CNN
# Categorical Crossentropy - to classify between multiple classes of images
classifier.compile(optimizer = 'adam', loss = 'categorical_crossentropy',
metrics = ['accuracy'])
# Image Augmentation and Training Section
# Image Augmentation to prevent Overfitting (applying random transformations to
# the training-set images, i.e. shearing, zooming and horizontal flipping)
train_datagen = ImageDataGenerator(
    rescale=1./255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)
test_datagen = ImageDataGenerator(rescale=1./255)
training_set = train_datagen.flow_from_directory(
    'dataset/training_set',
    target_size=(64, 64),
    batch_size=8,
    class_mode='categorical')
test_set = test_datagen.flow_from_directory(
    'dataset/test_set',
    target_size=(64, 64),
    batch_size=8,
    class_mode='categorical')
# Fit the classifier on the CNN data
if not os.path.isfile('my_model.h5'):
    classifier.fit_generator(
        training_set,
        # Steps are counted in batches, not images: len() gives batches per full pass
        steps_per_epoch=len(training_set),
        epochs=2,
        validation_data=test_set,
        validation_steps=len(test_set)
    )
    # Save the generated model to my_model.h5
    classifier.save('my_model.h5')
else:
    classifier = load_model('my_model.h5')
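A likely culprit is the output layer. A softmax over a single unit is always exactly 1.0, and with class_mode='categorical' the generators yield one-hot labels with one column per class (three here), which a one-unit output cannot match. The NumPy sketch below is not Keras code; the softmax helper is written out by hand, for illustration only, to compare one output unit with three.

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax: exponentiate, then divide by the row sum."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

# One output unit: every logit is divided by itself, so each row is 1.0.
one_unit = softmax(np.array([[-3.0], [0.5], [7.2]]))
print(one_unit.ravel())  # [1. 1. 1.]

# Three output units (one per class): a genuine probability distribution.
three_units = softmax(np.array([[2.0, 1.0, 0.1]]))
print(three_units.round(3))
```

With one unit the prediction never changes, so nothing can be learned; with units equal to the number of classes, the output shape matches the one-hot labels.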
Thanks! It works. I don't know how I missed that, but I'm new to deep learning! Thanks for the help! – BasuruK