我对this SARSA FA有疑问。

在输入单元格142中,我看到了此修改后的更新

w += alpha * (reward - discount * q_hat_next) * q_hat_grad


其中,q_hat_nextQ(S', a')q_hat_gradQ(S, a)的派生(假定S, a, R, S' a'序列)。

我的问题是更新不应该这样吗?

w += alpha * (reward + discount * q_hat_next - q_hat) * q_hat_grad


修改后的更新背后的直觉是什么?

最佳答案

我认为你是对的。我也希望更新包含TD错误项,应为reward + discount * q_hat_next - q_hat

供参考,这是实现:

if done: # (terminal state reached)
   w += alpha*(reward - q_hat) * q_hat_grad
   break
else:
   next_action = policy(env, w, next_state, epsilon)
   q_hat_next = approx(w, next_state, next_action)
   w += alpha*(reward - discount*q_hat_next)*q_hat_grad
   state = next_state


这是来自Reinforcement Learning:An Introduction (by Sutton & Barto)的伪代码(第171页):

machine-learning - 车杆的SARSA值近似值-LMLPHP

由于实现为TD(0),因此n为1。然后可以简化伪代码中的更新:

w <- w + a[G - v(S_t,w)] * dv(S_t,w)


变为(通过替换G == reward + discount*v(S_t+1,w))

w <- w + a[reward + discount*v(S_t+1,w) - v(S_t,w)] * dv(S_t,w)


或在原始代码示例中使用变量名称:

w += alpha * (reward + discount * q_hat_next - q_hat) * q_hat_grad


我最终得到了与您相同的更新公式。看起来像是非终端状态更新中的错误。

只有末尾的情况(如果done为true)才是正确的,因为根据定义,q_hat_next始终为0,因为情节结束了,无法获得更多奖励。

08-04 09:21