TWM: Transformer-based World Models Are Happy With 100k Interactions

年份 2023 ICLR

作者 Jan Robine, Marc Höftmann, Tobias Uelwer, Stefan Harmeling

Department of Computer Science, Technical University of Dortmund, Germany

摘要

深度神经网络在许多强化学习设置中都取得了成功。然而,与人类学习者相比,他们过于依赖数据。为了构建一个样本高效的世界模型,我们以自回归的方式将 transformer 应用于现实世界的事件:不仅紧凑的潜在状态和采取的行动,而且经验或预测的奖励也被输入到 transformer 中,以便它可以灵活地关注不同时间步的所有三种模式。transformer 允许我们的世界模型直接访问以前的状态,而不是通过压缩的递归状态来查看它们。通过利用 Transformer-XL 架构,它能够在保持计算效率的同时学习长期依赖关系。我们基于 transformer 的世界模型 (TWM) 生成有意义的新体验,用于训练策略,该策略在 Atari 100k 基准测试中优于以前的无模型和基于模型的强化学习算法。我们的代码可在 https://github.com/jrobine/twm获取。

主要结论

把Dreamer V2里的RSSM换成了Transformer-XL,也就是世界模型不再使用RSSM来学习了,而是用Transformer学习,突然发现他不是简单的替换,他是使用Transformer来训练一个WorldModel(一个Model Free的Agent),然后在真正使用的时候其实不用Transformer

然后在损失函数等地方做了一定的针对性改进,取得了不错的效果,比如提出了平衡交叉熵损失函数

Balanced Dataset Sampling

是大于0的温度超参数,v表示一个计数器,随着采样的次数增加而增加

学习的内容

1、平衡交叉熵损失函数

2、Transformer擅长长期记忆保存

3、平均采样经验数据集的方法,能够更关注于最新的经验(这跟我之前的一个想法比较类似,但是感觉还能有其他的采集方法)

4、消融实验

疑问

1、但是我看实验中都是和比较一般的模型做的比较,好像都没跟Dreamer V2做实验,这都能发ICLR吗?

2、看到这个我就想起了模拟退火,这个参数能不能随着训练时间的增加则改变

1、采样方法,将数据集平均分成初、中、高三个级别,高级多采样,初级和中级部分采样(以史为鉴)

2、多搞几个buffer

3、既然要从经验数据集中学习世界模型,那么我们保存在buffer中的数据就得是优质的数据,如何界定优秀,可以想一个方法给buffer中的数据打分,根据其特征,甚至可以训练一个模型