Skip to content

时序差分学习

Temporal-Difference

TD 是什么

它最核心的思想可以先记成一句话:

不必等整条轨迹结束,只要看到一步转移,就用“当前观察到的一步奖励 + 对下一状态的当前估计”,来修正当前状态的价值估计。

如果再说得更口语一点:

TD 不等“标准答案”全部公布,它会边走边改。

这里最重要的两个关键词是:

  • temporal difference:不是拿“预测和最终真实结果”的差学习,而是拿“相邻时刻预测之间的差”学习;
  • bootstrapping:不是完全依赖真实完整回报,而是用“下一步的当前估计”来帮助更新当前估计。

回顾

  • MDP 里,一步交互可以写成 StAtSt+1,Rt+1
  • 贝尔曼公式 里,固定策略 π 下的状态价值函数记为 vπ(s)
  • 蒙特卡洛方法 里,我们用采样得到的完整回报 Gt 来估计价值。

TD 可以理解成:

不再等完整的 Gt 出来,而是用一步样本 Rt+1 加上下一状态的当前估计,马上更新当前状态。

约定:

  • vπ(s):固定策略 π 下真实的状态价值函数,是我们想估计的目标;
  • V(s):算法当前维护的估计值,用来近似 vπ(s)
  • Gt:从时刻 t 开始的真实折扣回报;

TD 在学什么

TD 最经典的任务是 策略评估:给定一个固定策略 π,估计每个状态的价值 vπ(s)

状态值函数定义为:

vπ(s)=Eπ[l=0γlRt+1+l|St=s]

这里:

  • St 是时刻 t 的状态;
  • Rt+1 是从 St 走到下一步后得到的奖励;
  • γ[0,1) 是折扣因子;
  • π 是固定策略。

这个定义的意思是:

一个状态值多少钱,取决于从这里出发将来能拿到多少折扣后的累计奖励。

从 return 到 Bellman 递推

先定义从时刻 t 开始的 return:

Gt=Rt+1+γRt+2+γ2Rt+3+

把它改写成递推形式,就是:

Gt=Rt+1+γGt+1

这条式子非常重要,因为它说明:

当前总回报 = 一步奖励 + 下一时刻总回报的折扣。

对上式取条件期望,就得到值函数满足的 Bellman 关系:

vπ(s)=Eπ[Rt+1+γvπ(St+1)|St=s]

TD 学习正是把这条“期望上的递推关系”,变成“样本上的在线更新规则”。

TD target 是怎么来的

如果你真的知道上面那个条件期望,那么就不需要强化学习了。现实里我们通常只能看到一条条经验样本。

假设某一步真实发生了这样一次转移:

StAtπ(St)St+1,reward=Rt+1

那么 Bellman 关系里的期望项

Eπ[Rt+1+γvπ(St+1)St=s]

就可以用一次真实样本近似。由于真实的 vπ(St+1) 也不知道,所以 TD 用当前估计 V(St+1) 代替它:

Rt+1+γV(St+1)

于是定义 TD target 为:

ytTD=Rt+1+γV(St+1)

这就是 TD 的核心目标值。

它的意思是:

当前状态不再等整局结束才打分,而是先看眼前这一步奖励,再加上对下一状态未来价值的当前估计。

如果 St+1 是终止状态,通常令 V(St+1)=0

TD(0) 的更新公式

有了 TD target,最经典的一步 TD,也就是 TD(0),更新规则就是:

V(St)V(St)+α(Rt+1+γV(St+1)V(St))

其中 α 是学习率。

这个式子也可以改写成:

V(St)(1α)V(St)+α(Rt+1+γV(St+1))

所以 TD(0) 本质上就是:

让当前估计朝着 TD target 靠近一点。

什么是 TD error

括号里的那一项:

δt=Rt+1+γV(St+1)V(St)

就叫 TD error,也叫 TD residual

它特别值得单独记住,因为几乎整个 TD 家族都围绕它转。

它的直觉很简单:

  • 如果 δt>0,说明当前状态值被低估了,应该往上调;
  • 如果 δt<0,说明当前状态值被高估了,应该往下调;
  • 如果 δt=0,说明当前估计已经和“一步奖励 + 下一状态估计”一致了。

于是 TD(0) 还可以写成:

V(St)V(St)+αδt

所以 TD 学习也可以看成一句话:

TD 学习可以看成不断根据 TD error 调整 V 的过程。

一个具体小例子

假设某一步里:

  • 当前状态估计值:V(St)=10
  • 奖励:Rt+1=2
  • 下一状态估计值:V(St+1)=12
  • 折扣因子:γ=0.9
  • 学习率:α=0.1

那么 TD target 是:

ytTD=2+0.9×12=12.8

TD error 是:

δt=12.810=2.8

于是更新后:

V(St)10+0.1×2.8=10.28

也就是说,这一步以后,你会把当前状态的价值估计从 10 稍微往上修到 10.28。

TD 和 Monte Carlo 的区别

TD 最容易和 Monte Carlo(MC)混在一起,但它们并不一样。

Monte Carlo 的想法

MC 会一直等到 episode 结束,再拿到真实完整回报:

Gt=Rt+1+γRt+2+γ2Rt+3+

然后更新:

V(St)V(St)+α(GtV(St))

所以 MC 的特点是:

  • 目标是完整真实回报;
  • 不做 bootstrapping;
  • 必须等 episode 结束后才能更新。

TD 的想法

TD 不等结束,只要走一步就更新:

V(St)V(St)+α(Rt+1+γV(St+1)V(St))

所以 TD 的特点是:

  • 不需要等完整轨迹结束;
  • 更新更及时;
  • 但目标里用了当前估计,因此会 bootstrapping。

一句话对比

可以把它们记成:

  • MC:用真实整段回报学;
  • TD:用一步奖励 + 下一步当前估计学。

也可以记成一句更经典的话:

TD 像 MC 一样采样,又像 DP 一样自举。

什么叫 bootstrapping

bootstrapping 的意思是:

用一个估计去更新另一个估计。

在 TD 里,这一点表现得非常明显:

V(St)V(St)+α(Rt+1+γV(St+1)V(St))

右边的 V(St+1) 并不是真正的未来回报真值,而只是我们当前手里的一个估计。也就是说,TD 不是拿“完整标准答案”来学,而是拿“下一步的当前估计”来帮助更新当前估计。

这就是 TD 的力量所在,也是它和 MC 最根本的区别之一。

n-step TD:不只看一步

TD(0) 只看一步。但你也可以不是只看一步,而是往前看 n 步。

这时定义 n-step return

Gt(n)=Rt+1+γRt+2++γn1Rt+n+γnV(St+n)

它的意思是:

  • 前面 n 步用真实奖励;
  • n 步之后,再接一个 bootstrap 的值函数估计。

于是更新可以写成:

V(St)V(St)+α(Gt(n)V(St))

这样看就很清楚了:

  • n=1 时,就是 TD(0);
  • n 一直大到 episode 结束时,就越来越接近 Monte Carlo。

TIP

如果在第 n 步之前 episode 已经终止,那么后面的 bootstrap 项 γnV(St+n) 就不再出现。

TD(λ):在短视和长视之间折中

既然一步 TD 和长回报各有优缺点,一个自然想法就是:

能不能把不同步长的目标揉在一起?

这就是 TD(λ) 的核心思想。

它引入一个参数 λ[0,1],用它来混合不同的 n-step return。对应的 λ-return 可以写成:

Gtλ=(1λ)n=1λn1Gt(n)

这个式子的意思是:

  • 同时使用很多个不同步长的 return;
  • 但离得越远,权重越小;
  • λ 决定你更偏向短视还是长视。

TIP

在 episodic 场景里,更严格的 forward-view 写法通常会在终止时刻截断,并把最后的完整回报单独补进去。

Gtλ=(1λ)n=1Tt1λn1Gt(n)+λTt1Gt

两个极端特别值得记:

λ=0

只剩下 n=1 的项,所以:

Gtλ=Gt(1)

也就是一步 TD。

λ1

会越来越接近更长回报;在 episodic 场景下,它趋向于更接近完整回报。

所以一句话记忆就是:

TD(λ) 是从 TD(0) 平滑过渡到更长回报目标的一座桥。

为什么 TD 很重要

TD 重要,不是因为它只有一个简单公式,而是因为它抓住了强化学习里一个非常实用的更新模式:

  • 只要有一步经验,就能立即学习;
  • 不需要等整局结束;
  • 不需要环境模型;
  • 还能利用已经学到的估计继续学习。

这使得 TD 成为很多后续方法的基础。无论是值函数学习、actor-critic,还是很多更复杂的 RL 算法,背后都会反复出现 TD target 和 TD error 的影子。

总结

TD 学习的核心,是用一步真实反馈和对下一状态的当前估计,来在线修正当前状态价值:

V(St)V(St)+α(Rt+1+γV(St+1)V(St))

与 Monte Carlo 相比,它不需要等整条轨迹结束;与动态规划相比,它又不需要环境模型。因此,TD 可以理解成一种“边交互、边自举、边更新”的价值学习方法。

TD target:

ytTD=Rt+1+γV(St+1)

TD error:

δt=Rt+1+γV(St+1)V(St)

TD(0) 更新:

V(St)V(St)+αδt