ConRFT

在fine-tuning一个VLM使其执行robotic manipulation的时候, 可能会由于 有限且不一致 的demonstrations(特别是在contact-rich的环境中), 导致无法得到robust performance

问题:

  1. 在fine-tuning中严重依赖于数据集的质量与数量
  2. VLA需要有安全性和成本限制

提出reinforced fine-tuning:

  1. offline阶段, 使用监督学习(BehaviorClone)+Q-learning结合
  2. online阶段, 通过consistency policy的方式进行RL训练

Problem Setup and Preliminaries

定义:

  1. 是pretrained VLA model, 可以编码visual input(如, RGB image)以及language instructions
  2. : 是任务的trajectory
  3. : negative log-likelihood 或者Mean-Squared Error

在SFT(Supervise Fine-Tuning)任务中, 目的是用一个小的labeled demonstrations集合作为训练数据.

VLA目的是, 即最小化loss(NLL或者MSE)

定义MDP:

其中是state, 是action. 定义是environment transition probability. 是初始状态分布. 是reward, 使用作为discount factor. 作为policy, 需要maximize reward

Method

Stage 1: Offline Fine-tuning with Cal-ConRFT

pretrained VLA对zero-shot的novel robotic configurations缺乏泛化性, 因此在online之前, 使用小数据集的demonstrations(20-30 trajectory)

为了解决这个问题, 引入BehaviorClone loss(BC loss)来让model模仿演示中的行为, 提供了额外的supervisory signals(监督信号)

将BC loss和Cal-QL结合在consistency-based objective中, 提出了Cal-ConRFT的方法. 这个方法使用consistency policy作为action head来fine-tuning VLA, 解决两个主要的问题:

  1. pre-collected dataset中的inconsistency和sub-optimal的演示示例
  2. Diffusion policy相比, 这个方法更加轻量级

对于diffusion horizon(diffusion的时间范围) , 将其离散化为个子区间, 其边界为. 这种情况下的consistency policy为:

其中:

  • 参数化的consistency policy model, 从步噪声生成action
  • : diffusion noise step
  • : 经过步加噪声之后的action
  • : encoded state, 由参数化的pretrained VLA生成

那么consistency-based objective为:

其中:

    • 表示欧几里得距离
  • : 超参数

Stage 2: Online Fine-tuning with HIL-ConRFT

offline stage提供了初始化的policy, 但是由于其从small dataset中学习, 因此可能有limited performance

提出HIL-ConRFT, 通过与真实世界互动, 使用consistency policy来fine-tuning VLA

训练过程中, offline的数据集仍然保存, 同时使用replay buffer 来保存online数据, 使用symmetric sampling对每一个batch采样(每一个batch在中等量采样), 目的是减少offline的distribution-shift问题

因此直接的loss为:

使用consistency policy的loss为:

其中:

注意到这个和stage 1的consistency policy loss非常接近, 这可以快速的进行训练(代码修改少)

在online阶段, 降低增加:

  1. 确保policy能够不遗忘演示数据(continues to align with demonstration data)
  2. 降低BC loss以防止突然崩溃, 导致安全问题