我对this SARSA FA有疑问。
在输入单元格142中,我看到了此修改后的更新
w += alpha * (reward - discount * q_hat_next) * q_hat_grad
其中,
q_hat_next
是Q(S', a')
,q_hat_grad
是Q(S, a)
的派生(假定S, a, R, S' a'
序列)。我的问题是更新不应该这样吗?
w += alpha * (reward + discount * q_hat_next - q_hat) * q_hat_grad
修改后的更新背后的直觉是什么?
最佳答案
我认为你是对的。我也希望更新包含TD错误项,应为reward + discount * q_hat_next - q_hat
。
供参考,这是实现:
if done: # (terminal state reached)
w += alpha*(reward - q_hat) * q_hat_grad
break
else:
next_action = policy(env, w, next_state, epsilon)
q_hat_next = approx(w, next_state, next_action)
w += alpha*(reward - discount*q_hat_next)*q_hat_grad
state = next_state
这是来自Reinforcement Learning:An Introduction (by Sutton & Barto)的伪代码(第171页):
由于实现为TD(0),因此
n
为1。然后可以简化伪代码中的更新:w <- w + a[G - v(S_t,w)] * dv(S_t,w)
变为(通过替换
G == reward + discount*v(S_t+1,w))
)w <- w + a[reward + discount*v(S_t+1,w) - v(S_t,w)] * dv(S_t,w)
或在原始代码示例中使用变量名称:
w += alpha * (reward + discount * q_hat_next - q_hat) * q_hat_grad
我最终得到了与您相同的更新公式。看起来像是非终端状态更新中的错误。
只有末尾的情况(如果
done
为true)才是正确的,因为根据定义,q_hat_next
始终为0,因为情节结束了,无法获得更多奖励。