paper
Motivation: 尝试解决Diffusion Policy等方法对机器人action的训练效果低下的问题(尝试解决loss collapse问题)
loss collapse
损失崩塌: 由于condition差距过小, 导致模型忽略了condition, 不学习条件概率而去学习边缘概率: 由于条件过于相似, 模型忽略了条件, 直接学习了”move”这个action. 均有0.5的概率进行向左或向右移动(忽略了指令的条件)
数学推导: 根据Loss Collapse部分的公式, 标准的Flow Matching的损失函数为:
在这个loss计算中, 需要提供这些内容: 条件, 噪声采样, 原始action . 但是注意, 噪声分布与条件无关.
当模型无法区分两个相似的condition 时, 数学表示为时, loss函数的梯度会有:
其中:
- 上限
- 上限
于是模型对相似condition的loss优化相同, 最终目标loss函数会退化成边缘分布, 学习的向量场与输入的条件c无关:
最终会导致在training的时候看起来模型能力比较好, 但是inference的时候发现模型没有学到任何成功的内容
Rather than adopting a standard Gaussian prior q(z), Cocos anchors the source distribution around the semantics of each condition q(z|c), theoretically preventing training loss collapse and forcing the policy network to remain responsive to condition inputs.
为了解决上面说到的loss collapse的问题, 文章提出了一种 condition-conditioned source distribution(cocos) 的方法: noise不再是一个标准的正态分布, 而是一个锚定在给定condition周围的一个分布
pipeline:
这个是Cocos中创建与条件相关的噪声的方法.
首先将language instructions和images提取embeddings:
注意, 该步骤不参与loss的计算
然后, 设计一个AE, 有编码器()和解码器(). 参考这一部分, AutoEncoder的Encoder和Decoder都是由一个single-layer的Transformer组成.
由于文章中并没有给出Transformer层的具体架构, 因此根据pipeline猜测, 最可能的架构为:
- T5-base和DINOv2提取的features拼接一个nn.Embedding的可学习的query进行Transformer layer(self attention + MLP)的过程
- 把query位置的Encoder输出作为condition(可能会使用MLP投影到相应维度), 用于噪声采样
- 根据其他的embedding(T5和DINOv2的embedding对应的位置)送给Decoder进行Self Attention + MLP, 将输出与原始的embedding vector做余弦相似度的loss
autoencoding objective
对Decoder重建得到的embedding vectors与原始输入的Feature Embeddings进行reconstruction loss, 用于训练Encoder:
However, in practical scenarios, this two-step pipeline may introduce additional inflexibility.
default的settings是, 首先训练AE, 然后再固定AE的权重, 使用AE训练Flow Matching(或者说, Policy Model)
但是使用2-stage的方法有一定缺点:
- 不灵活: 无法针对困难的条件进行特调, 只是学习了一个平均的一个condition(针对输入平均, 没有针对难度进行特调)
- 流程繁琐
因此提出了一个端到端的训练策略. 但是由于在训练的时候, 会更新原始分布, 导致得到的policy model不稳定, 于是提出了一个方法: EMA
Exponential Moving Average
Exponential Moving Average
Exponential Moving Average
指数移动平均 (EMA) 是一种平滑技术,它通过维护一个模型的两个副本来解决训练不稳定的问题:
- 在线网络 (Online Network): 正常接收梯度并快速更新的网络。
- 目标网络 (Target Network): 从不接收梯度,其权重是“在线网络”过去所有权重的一个指数加权平均。
它的工作机制是,在每次更新“在线网络”后,都通过以下公式极其缓慢地更新“目标网络”:
Link to original通过EMA算法更新权重, 使权重的变化尽量保持平滑.