
I am using ChainerRL and tried Breakout-v0. Why doesn't my DQN agent learn on Breakout-v0?

I ran the code below. It does run, but my agent barely earns any reward (the total reward per episode is always below 5 points).

Python 2.7, Ubuntu 14.04

Please tell me why it doesn't learn.

Also, I can't understand where the number 972 in l5 = L.Linear(972, 512) comes from.
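A rough shape trace, assuming Chainer's defaults (stride 1 for Convolution2D, cover_all=True for max_pooling_2d), suggests where the 972 comes from; this is a sketch of my own, not part of the original script:

# Rough shape trace for the network below. obs_size = 210 is passed as the
# input *channel* count, so the convolutions see inputs of shape
# (batch, 210, 160, 3), i.e. height 160 and width 3.
def conv_out(size, k=2, p=1, s=1):      # Convolution2D output size
    return (size + 2 * p - k) // s + 1

def pool_out(size, k=2, s=2, p=0):      # max_pooling_2d with cover_all=True
    return (size + 2 * p - k + s - 1) // s + 1

h, w = 160, 3
h, w = conv_out(h), conv_out(w)         # l1:      161 x 4
h, w = conv_out(h), conv_out(w)         # l2:      162 x 5
h, w = pool_out(h), pool_out(w)         # pooling:  81 x 3
print(4 * h * w)                        # 4 channels * 81 * 3 = 972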

import chainer
import chainer.functions as F
import chainer.links as L
import chainerrl
import gym
import numpy as np

from chainer import cuda

import datetime
from skimage.color import rgb2gray
from skimage.transform import resize

env = gym.make('Breakout-v0')
obs = env.reset()

print("observation space : {}".format(env.observation_space))
print("action space      : {}".format(env.action_space))

action = env.action_space.sample()
obs, r, done, info = env.step(action)


class QFunction(chainer.Chain):
    def __init__(self, obs_size, n_action):
        super(QFunction, self).__init__(
            l1=L.Convolution2D(obs_size, 4, ksize=2, pad=1),  # 210x160
            bn1=L.BatchNormalization(4),
            l2=L.Convolution2D(4, 4, ksize=2, pad=1),  # 105x80
            bn2=L.BatchNormalization(4),
            # l3=L.Convolution2D(64, 64, ksize=2, pad=1),  # 100x100
            # bn3=L.BatchNormalization(64),
            # l4=L.Convolution2D(64, 3, ksize=2, pad=1),  # 50x50
            # bn4=L.BatchNormalization(3),
            l5=L.Linear(972, 512),
            out=L.Linear(512, n_action,
                         initialW=np.zeros((n_action, 512), dtype=np.float32))
        )

    def __call__(self, x, test=False):
        h1 = F.relu(self.bn1(self.l1(x)))
        h2 = F.max_pooling_2d(F.relu(self.bn2(self.l2(h1))), 2)
        # h3 = F.relu(self.bn3(self.l3(h2)))
        # h4 = F.max_pooling_2d(F.relu(self.bn4(self.l4(h3))), 2)
        # print h4.shape

        return chainerrl.action_value.DiscreteActionValue(self.out(self.l5(h2)))


n_action = env.action_space.n
obs_size = env.observation_space.shape[0]  # (210, 160, 3)
q_func = QFunction(obs_size, n_action)

optimizer = chainer.optimizers.Adam(eps=1e-2)
optimizer.setup(q_func)

gamma = 0.99

explorer = chainerrl.explorers.ConstantEpsilonGreedy(
    epsilon=0.2, random_action_func=env.action_space.sample)

replay_buffer = chainerrl.replay_buffer.ReplayBuffer(capacity=10 ** 6)

phi = lambda x: x.astype(np.float32, copy=False)
agent = chainerrl.agents.DoubleDQN(
    q_func, optimizer, replay_buffer, gamma, explorer,
    minibatch_size=4, replay_start_size=100, update_interval=10,
    target_update_interval=10, phi=phi)

last_time = datetime.datetime.now()
n_episodes = 10000
for i in range(1, n_episodes + 1):
    obs = env.reset()

    reward = 0
    done = False
    R = 0

    while not done:
        env.render()
        action = agent.act_and_train(obs, reward)
        obs, reward, done, _ = env.step(action)

        if reward != 0:
            R += reward

    elapsed_time = datetime.datetime.now() - last_time
    print('episode:', i,
          'reward:', R,
          )
    last_time = datetime.datetime.now()

    if i % 100 == 0:
        filename = 'agent_Breakout' + str(i)
        agent.save(filename)

    agent.stop_episode_and_train(obs, reward, done)

print('Finished.')
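Side note: the rgb2gray and resize imports above are never actually used. For context only, a typical Atari preprocessing function built from them might look roughly like the sketch below; the 84x84 target size and single grayscale channel are assumptions borrowed from common DQN setups, not something this script does (using it would also require changing the Q-function's input channels and the 972).

# Hypothetical preprocessing sketch (not used by the script above):
# grayscale, downsample to 84x84, and return a channels-first float32 array.
def preprocess(obs):
    gray = rgb2gray(obs)                   # (210, 160, 3) -> (210, 160), values in [0, 1]
    small = resize(gray, (84, 84))         # downsample to 84x84
    return small[np.newaxis, :, :].astype(np.float32)  # shape (1, 84, 84)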

Answer


As the author of ChainerRL, if you want to solve Atari environments, I recommend starting from examples/ale/train_*.py and customizing it step by step. Deep reinforcement learning is very sensitive to hyperparameters and network architectures, so if you introduce many changes at once it is hard to tell which change caused training to fail.

I tried running your script while printing statistics via agent.get_statistics(), and found that the Q values grow larger and larger, which suggests that training is not going well.
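For reference, those statistics can be printed by adding one line to the episode loop of the script above, along these lines:

# Inside the episode loop, after the per-episode reward print:
print('episode:', i, 'reward:', R)
print(agent.get_statistics())   # e.g. [('average_q', ...), ('average_loss', ...)]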

$ python yourscript.py 
[2017-07-10 18:14:45,309] Making new env: Breakout-v0 
observation space : Box(210, 160, 3) 
action space  : Discrete(6) 
episode: 1 reward: 0 
[('average_q', 0.0), ('average_loss', 0.0)] 
episode: 2 reward: 1.0 
[('average_q', 0.0), ('average_loss', 0.0)] 
episode: 3 reward: 0 
[('average_q', 0.0), ('average_loss', 0.0)] 
episode: 4 reward: 0 
[('average_q', 0.0), ('average_loss', 0.0)] 
episode: 5 reward: 2.0 
[('average_q', 0.0), ('average_loss', 0.0)] 
episode: 6 reward: 0 
[('average_q', 0.0), ('average_loss', 0.0)] 
episode: 7 reward: 1.0 
[('average_q', 0.0), ('average_loss', 0.0)] 
episode: 8 reward: 2.0 
[('average_q', 0.0), ('average_loss', 0.0)] 
episode: 9 reward: 1.0 
[('average_q', 0.0), ('average_loss', 0.0)] 
episode: 10 reward: 2.0 
[('average_q', 0.05082079044988309), ('average_loss', 0.0028927958279822935)] 
episode: 11 reward: 4.0 
[('average_q', 7.09331367665307), ('average_loss', 0.0706595716528489)] 
episode: 12 reward: 0 
[('average_q', 17.418094266218915), ('average_loss', 0.251431955409951)] 
episode: 13 reward: 1.0 
[('average_q', 40.903169833428954), ('average_loss', 1.0959175910071859)] 
episode: 14 reward: 2.0 
[('average_q', 115.25579476118122), ('average_loss', 2.513677824600575)] 
episode: 15 reward: 2.0 
[('average_q', 258.7392539556941), ('average_loss', 6.20968827451279)] 
episode: 16 reward: 1.0 
[('average_q', 569.6735852049942), ('average_loss', 19.295426012437833)] 
episode: 17 reward: 4.0 
[('average_q', 1403.8461185742353), ('average_loss', 32.6092646561004)] 
episode: 18 reward: 1.0 
[('average_q', 2138.438909199657), ('average_loss', 44.90832410172697)] 
episode: 19 reward: 1.0 
[('average_q', 3112.752923036582), ('average_loss', 88.50687458947431)] 
episode: 20 reward: 1.0 
[('average_q', 4138.601621651058), ('average_loss', 106.09160137599618)] 

OK, thank you. I tried to use your example, but I can't run it. – KEN


Please tell me how to run your example. I have already installed ALE. – KEN


Can you describe the command you ran and its output? – muupan