TWM: Transformer-based World Models Are Happy With 100k Interactions
- RL笔记
- 2025-07-02
- 3热度
- 0评论
年份 2023 ICLR
作者 Jan Robine, Marc Höftmann, Tobias Uelwer, Stefan Harmeling
Department of Computer Science, Technical University of Dortmund, Germany
摘要
深度神经网络在许多强化学习设置中都取得了成功。然而,与人类学习者相比,他们过于依赖数据。为了构建一个样本高效的世界模型,我们以自回归的方式将 transformer 应用于现实世界的事件:不仅紧凑的潜在状态和采取的行动,而且经验或预测的奖励也被输入到 transformer 中,以便它可以灵活地关注不同时间步的所有三种模式。transformer 允许我们的世界模型直接访问以前的状态,而不是通过压缩的递归状态来查看它们。通过利用 Transformer-XL 架构,它能够在保持计算效率的同时学习长期依赖关系。我们基于 transformer 的世界模型 (TWM) 生成有意义的新体验,用于训练策略,该策略在 Atari 100k 基准测试中优于以前的无模型和基于模型的强化学习算法。我们的代码可在 https://github.com/jrobine/twm获取。
主要结论
把Dreamer V2里的RSSM换成了Transformer-XL,也就是世界模型不再使用RSSM来学习了,而是用Transformer学习,突然发现他不是简单的替换,他是使用Transformer来训练一个WorldModel(一个Model Free的Agent),然后在真正使用的时候其实不用Transformer
然后在损失函数等地方做了一定的针对性改进,取得了不错的效果,比如提出了平衡交叉熵损失函数
Balanced Dataset Sampling
是大于0的温度超参数,v表示一个计数器,随着采样的次数增加而增加
学习的内容
1、平衡交叉熵损失函数
2、Transformer擅长长期记忆保存
3、平均采样经验数据集的方法,能够更关注于最新的经验(这跟我之前的一个想法比较类似,但是感觉还能有其他的采集方法)
4、消融实验
疑问
1、但是我看实验中都是和比较一般的模型做的比较,好像都没跟Dreamer V2做实验,这都能发ICLR吗?
2、看到这个我就想起了模拟退火,这个参数能不能随着训练时间的增加则改变
1、采样方法,将数据集平均分成初、中、高三个级别,高级多采样,初级和中级部分采样(以史为鉴)
2、多搞几个buffer
3、既然要从经验数据集中学习世界模型,那么我们保存在buffer中的数据就得是优质的数据,如何界定优秀,可以想一个方法给buffer中的数据打分,根据其特征,甚至可以训练一个模型