paper
Motivation: 将NLP领域的遮罩预测方法(以BERT为例子)应用到视觉领域
- 使用非对称的Encoder-Decoder框架
- 使用高遮盖率, 迫使模型学习全局的语义信息, 从而克服图像的空间冗余
- 高可迁移能力, 高可拓展性
Related Work
Convolutions typically operate on regular grids and it is not straightforward to integrate ‘indicators’ such as mask tokens or positional embeddings into convolutional networks.
CNN网络架构导致了很难将位置编码或者掩码引入, 因此在架构上很难做到学习被移除的内容
但是现在由于ViT模型架构的产生, 能够引入Transformer模型来支持在图像中嵌入Mask, 因此可以实现在视觉任务中的Masked Auto-Encoder
In vision, the decoder reconstructs pixels, hence its output is of a lower semantic level than common recognition tasks. This is in contrast to language, where the decoder predicts missing words that contain rich semantic information.
语言是高语义级别的decoder, 因为每一个文字都含有丰富的语义信息. decoder重建的时候, 从latent space的隐变量转换成输出的过程比较简单(都含有丰富语义)
但是对于图片不同. 图片的latent space有丰富的语义信息, 但是最终的输出是pixel, 语义信息非常稀少. 如果使用架构简单的decoder(如, MLP), 那么可能迫使Encoder学习非常local的信息(如, 纹理, 光影等pixel level的信息, 而忽略了全局信息以及语义信息).
优点:
- 由于Mask掉了大量(75%)的输入, 因此Encoder的训练更加轻量级(输入少了), 因此很容易Scaling
- 泛化能力强. Scaling并扩大训练数据可以增强模型能力, 与LLM相似
Pipeline:
Approach:
Masking
使用类似ViT的方法, 对图片进行patch, 然后随机mask
MAE encoder
是一个ViT, 但是仅针对没有被mask掉的部分.
- Embedding Image Patch + Positional Embedding
- Transformer blocks
MAE decoder
将image embeddings和一个特殊的embedding vector(mask embedding)按照顺序进行拼接, 得到最终的
input_embeds, 然后送给Transformer进行Self Attention, 最终得到一个完整的image的embeddings.随后, 使用MLP或其他方法将hidden states重建成为image
Simple implementation.
这一部分是比较巧妙的地方, 使用了一些trick, 能够加速mask的replace.
首先, 使用randomly shuffle随机打乱patch embeds:
input_patches = patch(image) # (bs, seq_len, image_dim) shuffle_indices = torch.rand(bs, seq_len).argsort(dim=1) unshuffle_indices = shuffle_indices.argsort(dim=1) shuffle_patches = torch.gather(input_patches, 1, shuffled_indices.unsqueeze(-1).expand(-1, -1, image_dim)) num_visible = int(0.25 * seq_len) # Maksed 75% patches visible_patches = shuffle_patches[:, :num_visible, :] visible_embeds = transformer(visible_patches) # Concat visible patches and mask embeds # masked_token = nn.Parameter(...) full_shuffled_embeds = torch.cat([visible_embeds, masked_token.expand(bs, seq_len - num_visible, -1)]) full_unshuffled_embeds = torch.gather(full_shuffled_embeds, 1, unshuffled_indices.unsqueeze(-1).expand(-1, -1, hidden_dim)) # decode by Transformer Blocks return decoder(full_unshuffled_embeds)上面的这个是全部的伪代码. 这个tricky的点在于使用了
rand+argsort配合torch.gather快速的记录了mask的顺序, 并快速进行了采样shuffled作用机理
步骤 1: 生成随机打乱索引 (
shuffled_indices)这个索引的作用是:告诉我如何将原始序列打乱。
生成随机数: 我们为4个位置各自生成一个随机数。假设生成的随机数是:
- 位置0:
0.7- 位置1:
0.1- 位置2:
0.9- 位置3:
0.4对随机数进行排序 (
argsort):argsort操作会返回原始数组中元素按从小到大排序后的索引。
- 最小的数是
0.1,它在位置1。- 第二小的数是
0.4,它在位置3。- 第三小的数是
0.7,它在位置0。- 最大的数是
0.9,它在位置2。因此,
argsort的结果就是[1, 3, 0, 2]。得到
shuffled_indices:shuffled_indices = [1, 3, 0, 2]这个索引的含义是: “新的序列中,第0个位置应该放原始序列的第1个元素;第1个位置应该放原始序列的第3个元素;第2个位置应该放原始序列的第0个元素;第3个位置应该放原始序列的第2个元素。”
应用
shuffled_indices如果我们用这个索引来打乱原始序列
[P0, P1, P2, P3],我们会得到一个新的序列shuffled_sequence:
shuffled_sequence[0] = original_sequence[1] = P1shuffled_sequence[1] = original_sequence[3] = P3shuffled_sequence[2] = original_sequence[0] = P0shuffled_sequence[3] = original_sequence[2] = P2所以,打乱后的序列是:
[P1, P3, P0, P2]
步骤 2: 生成恢复顺序索引 (
unshuffle_indices)这个索引的作用是:告诉我如何将打乱后的序列恢复到原始顺序。
它的生成方法非常巧妙:对
shuffled_indices本身再做一次argsort操作。
我们的
shuffled_indices是:[1, 3, 0, 2]对它进行
argsort:
shuffled_indices中最小的数是0,它在位置2。- 第二小的数是
1,它在位置0。- 第三小的数是
2,它在位置3。- 最大的数是
3,它在位置1。因此,
argsort的结果是[2, 0, 3, 1]。得到
unshuffle_indices:unshuffle_indices = [2, 0, 3, 1]这个索引的含义是: “恢复后的序列中,第0个位置应该放打乱序列的第2个元素;第1个位置应该放打乱序列的第0个元素;第2个位置应该放打乱序列的第3个元素;第3个位置应该放打乱序列的第1个元素。”
应用
unshuffle_indices现在,我们用这个索引来恢复打乱后的序列
shuffled_sequence = [P1, P3, P0, P2]:
unshuffled_sequence[0] = shuffled_sequence[2] = P0unshuffled_sequence[1] = shuffled_sequence[0] = P1unshuffled_sequence[2] = shuffled_sequence[3] = P2unshuffled_sequence[3] = shuffled_sequence[1] = P3所以,恢复后的序列是:
[P0, P1, P2, P3]。 我们成功地恢复了原始顺序!