2017-08-31 40 views
1

我正在尝试在Unity中实现DeepQ学习模拟蚂蚁。遵循雅阁Animat示例,我设法实现了该算法的要点。捕获状态作为DeepQ中的数组使用Accord.net学习

现在我的代理有5个状态输入 - 其中三个来自探测前方障碍物的传感器(RayCasts在Unity中),剩下的两个是它在地图上的X和Y位置。

我的问题是,qLearning.GetAction(currentState)只接受一个int作为参数。如何使用数组(或张量)为代理当前状态实现我的算法?

这是我的代码:

using System.Collections; 
using System.Collections.Generic; 
using UnityEngine; 
using Accord.MachineLearning; 
using System; 

public class AntManager : MonoBehaviour { 
    float direction = 0.01f; 
    float rotation = 0; 

    // learning settings 
    int learningIterations = 100; 
    private double explorationRate = 0.5; 
    private double learningRate = 0.5; 

    private double moveReward = 0; 
    private double wallReward = -1; 
    private double goalReward = 1; 

    private float lastDistance = 0; 

    private RaycastHit hit; 
    private int hitInteger = 0; 

    // Q-Learning algorithm 
    private QLearning qLearning = null; 


    // Use this for initialization 
    void Start() { 
     qLearning = new QLearning(256, 4, new TabuSearchExploration(4, new EpsilonGreedyExploration(explorationRate))); 
    } 

    // Update is called once per frame 
    void Update() {   

     // curent coordinates of the agent 
     float agentCurrentX = transform.position.x; 
     float agentCurrentY = transform.position.y; 
     // exploration policy 
     TabuSearchExploration tabuPolicy = (TabuSearchExploration)qLearning.ExplorationPolicy; 

     EpsilonGreedyExploration explorationPolicy = (EpsilonGreedyExploration)tabuPolicy.BasePolicy; 

     // set exploration rate for this iteration 
     explorationPolicy.Epsilon = explorationRate - learningIterations * explorationRate; 
     // set learning rate for this iteration 
     qLearning.LearningRate = learningRate - learningIterations * learningRate; 
     // clear tabu list 
     tabuPolicy.ResetTabuList(); 

     // get agent's current state 
     int currentState = ((int)Math.Round(transform.position.x, 0) + (int)Math.Round(transform.position.y, 0) + hitInteger); 
     // get the action for this state 
     int action = qLearning.GetAction(currentState); 
     // update agent's current position and get his reward 
     double reward = UpdateAgentPosition(ref agentCurrentX, ref agentCurrentY, action); 
     // get agent's next state 
     int nextState = currentState; 
     // do learning of the agent - update his Q-function 
     qLearning.UpdateState(currentState, action, reward, nextState); 

     // set tabu action 
     tabuPolicy.SetTabuAction((action + 2) % 4, 1); 


    } 

    // Update agent position and return reward for the move 
    private double UpdateAgentPosition(ref float currentX, ref float currentY, int action) 
    { 
     // default reward is equal to moving reward 
     double reward = moveReward; 
     GameObject food = GameObject.FindGameObjectWithTag("Food"); 

     float distance = Vector3.Distance(transform.position, food.transform.position); 

     if (distance < lastDistance) 
      reward = 0.2f; 

     lastDistance = distance; 

     Debug.Log(distance); 

     switch (action) 
     { 
      case 0:   // go to north (up) 
       rotation += -1f; 
       break; 
      case 1:   // go to east (right) 
       rotation += 1f; 
       break; 
      case 2:   // go to south (down) 
       rotation += 1f; 
       break; 
      case 3:   // go to west (left) 
       rotation += -1f; 
       break; 
     } 

     //transform.eulerAngles = new Vector3(10, rotation, 0); 
     transform.Rotate(0, rotation * Time.deltaTime, 0); 
     transform.Translate(new Vector3(0, 0, 0.01f)); 



     float newX = transform.localRotation.x; 
     float newY = transform.localRotation.y; 

     Ray sensorForward = new Ray(transform.position, transform.forward); 
     Debug.DrawRay(transform.position, transform.forward * 1); 

     if (Physics.Raycast(sensorForward, out hit, 1)) 
     { 
      if (hit.collider.tag != "Terrain") 
      { 
       Debug.Log("Sensor Forward hit!"); 

       reward = wallReward; 
      } 
      if (hit.collider.tag == "Food") 
      { 
       Debug.Log("Sensor Found Food!"); 
       Destroy(food); 
       reward = goalReward; 
       hitInteger = 1; 
      } 
      hitInteger = 0; 
     } 

     return reward; 
    } 
} 

回答

0

documentation提供此作为一个例子:

c1 | (c2 << 1) | (c3 << 2) | (c4 << 3) | (c5 << 4) | (c6 << 5) | (c7 << 6) | (c8 << 7) 

这似乎是位移动两个值的整数进入状态的二进制编码。您的代码可能需要这样的事:

int currentState = ((int)Math.Round(transform.position.x, 0) | ((int)Math.Round(transform.position.y, 0) << 1) | (hitInteger << 2)) 

但是,您首先需要将您的状态映射到二元变量,所以这段代码将只与一个2x2的网格工作。尽管该示例声明了整数,但它们是二进制值:将位移为2或更大的值是没有意义的。

Convert.ToString(1 | (0 << 1) | (1 << 2), 2) 

以可视化的状态,一个有用的方法是直接在二进制找