GitHub Project

Flow Matching

Preliminaries

Continuous Normalizing Flows (CNF):

Consider data points $x = (x^1, \ldots, x^d) \in \mathbb{R}^d$.

Two important objects are involved:

  1. Probability density path: $p_t : [0, 1] \times \mathbb{R}^d \to \mathbb{R}_{>0}$, with $\int p_t(x)\,dx = 1$
  2. Time-dependent vector field: $v_t : [0, 1] \times \mathbb{R}^d \to \mathbb{R}^d$

A vector field $v_t$ can be used to construct a time-dependent diffeomorphic map, called a flow $\phi_t$, defined via the ordinary differential equation (ODE):

$$\frac{d}{dt} \phi_t(x) = v_t(\phi_t(x)), \qquad \phi_0(x) = x$$

$\phi_t(x)$ can be viewed as the state of $x$ at time $t$, corresponding to the noisy image at time $t$ in a diffusion process.
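As an illustration (a minimal sketch, not part of the reference implementation below, which uses zuko's adaptive odeint), the flow can be simulated by fixed-step Euler integration of this ODE:

import torch

def euler_flow(v_t, x, n_steps=100):
    # Integrate dx/dt = v_t(t, x) from t = 0 to t = 1 with Euler steps;
    # v_t is any callable mapping (t, x) to a velocity, e.g. a trained network.
    dt = 1.0 / n_steps
    t = torch.zeros(len(x), device=x.device)
    for _ in range(n_steps):
        x = x + dt * v_t(t, x)
        t = t + dt
    return x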

A CNF can transform pure noise into a more complex distribution via the push-forward equation (note the convention differs from diffusion: here $p_0$ is pure noise, whereas in diffusion the time-0 state is the real sample):

$$p_t = [\phi_t]_* p_0, \qquad [\phi_t]_* p_0(x) = p_0\big(\phi_t^{-1}(x)\big)\, \left| \det \frac{\partial \phi_t^{-1}}{\partial x}(x) \right|$$

If a flow $\phi_t$ satisfies the equation above, we say that the vector field $v_t$ generates the probability density path $p_t$.
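Equivalently, a standard test (and the one the paper works with) is the continuity equation: $v_t$ generates $p_t$ when

$$\frac{\partial}{\partial t} p_t(x) + \nabla \cdot \big(p_t(x)\, v_t(x)\big) = 0$$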

Flow Matching

Start from a simple pure-noise distribution, e.g. $p_0 = \mathcal{N}(x \mid 0, I)$, and let $p_1$ be approximately equal to the target distribution $q$. Flow matching then matches such a path, "flowing" from $p_0$ to $p_1$.

The flow matching objective is:

$$\mathcal{L}_{FM}(\theta) = \mathbb{E}_{t \sim \mathcal{U}[0,1],\, x \sim p_t(x)} \big\| v_t(x; \theta) - u_t(x) \big\|^2$$

In short, flow matching trains $v_t(\cdot\,; \theta)$ as a regression onto the target field $u_t$; the learned $v_t$ can then generate the probability path $p_t$, from which we can sample.

However, since $p_t$ and $u_t$ are unknown, this loss cannot be computed directly.

Constructing $p_t$, $u_t$ from conditional probability paths and vector fields

Given a sample $x_1 \sim q$, let $p_t(x \mid x_1)$ denote a conditional probability path. Once it is specified, integrating over $x_1$ yields the marginal distribution, and hence the probability path.

The conditional path $p_t(x \mid x_1)$ needs to satisfy:

  • at $t = 0$: $p_0(x \mid x_1) = \mathcal{N}(x \mid 0, I)$, i.e. a fully-noised distribution independent of $x_1$
  • at $t = 1$: $p_1(x \mid x_1)$ is a distribution concentrated around $x_1$ (e.g. $\mathcal{N}(x \mid x_1, \sigma^2 I)$ with small $\sigma$). That is, marginalizing over $x_1$ recovers $p_1 \approx q$.

The marginal is then obtained by integration:

$$p_t(x) = \int p_t(x \mid x_1)\, q(x_1)\, dx_1$$

In particular:

$$p_1(x) = \int p_1(x \mid x_1)\, q(x_1)\, dx_1 \approx q(x)$$

Similarly, the conditional vector fields can be aggregated by integration:

$$u_t(x) = \int u_t(x \mid x_1)\, \frac{p_t(x \mid x_1)\, q(x_1)}{p_t(x)}\, dx_1$$

where $u_t(\cdot \mid x_1)$ is the conditional vector field that generates $p_t(\cdot \mid x_1)$. A key result of the paper (its Theorem 1) is that this marginal $u_t$ generates the marginal probability path $p_t$.
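A quick sanity check of that result, following the paper's proof sketch: differentiating the marginal path in time and applying the continuity equation to each conditional pair gives

$$\frac{d}{dt} p_t(x) = \int \frac{d}{dt} p_t(x \mid x_1)\, q(x_1)\, dx_1 = -\int \nabla \cdot \big(u_t(x \mid x_1)\, p_t(x \mid x_1)\big)\, q(x_1)\, dx_1 = -\nabla \cdot \big(u_t(x)\, p_t(x)\big)$$

so the pair $(p_t, u_t)$ satisfies the continuity equation as well.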

Conditional Flow Matching

Since these integrals are still intractable, a simpler objective is proposed:

$$\mathcal{L}_{CFM}(\theta) = \mathbb{E}_{t,\, q(x_1),\, p_t(x \mid x_1)} \big\| v_t(x; \theta) - u_t(x \mid x_1) \big\|^2$$

As long as $u_t(x \mid x_1)$ can be computed, this gives an unbiased estimate. And because everything is conditioned on a single sample $x_1$, it is easy to compute.

The paper proves that $\nabla_\theta \mathcal{L}_{CFM}(\theta) = \nabla_\theta \mathcal{L}_{FM}(\theta)$, so optimizing CFM (Conditional Flow Matching) is, in expectation, equivalent to optimizing FM (Flow Matching).
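The implementation below instantiates the paper's Gaussian conditional paths with the optimal-transport (OT) schedule $\mu_t(x_1) = t\, x_1$ and $\sigma_t(x_1) = 1 - (1 - \sigma_{\min})\, t$, whose conditional flow and target velocity are:

$$\psi_t(x_0 \mid x_1) = \big(1 - (1 - \sigma_{\min})\, t\big)\, x_0 + t\, x_1, \qquad \frac{d}{dt} \psi_t(x_0 \mid x_1) = x_1 - (1 - \sigma_{\min})\, x_0$$

With $x_0 \sim \mathcal{N}(0, I)$, the CFM loss reduces to regressing $v_t(\psi_t(x_0 \mid x_1); \theta)$ onto $x_1 - (1 - \sigma_{\min})\, x_0$, which is exactly what `psi_t` and `d_psi` compute in the code.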

Implementation

# Imports
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from sklearn.datasets import make_swiss_roll
from torch.utils.data import TensorDataset, DataLoader
from zuko.utils import odeint
 
# Define the OTFlowMatching class (OT conditional flow matching loss)
class OTFlowMatching:
    def __init__(self, sig_min=0.001):
        self.sig_min = sig_min
        self.eps = 1e-5
 
    def psi_t(self, x, x_1, t):
        """条件流函数"""
        return (1 - (1 - self.sig_min) * t) * x + t * x_1
 
    def loss(self, v_t, x_1):
        """Conditional flow matching loss."""
        # Stratified t in [0, 1): one shared random shift on an evenly
        # spaced grid covers time more uniformly than i.i.d. draws
        t = (torch.rand(1, device=x_1.device) +
             torch.arange(len(x_1), device=x_1.device) / len(x_1)) % (1 - self.eps)
        t = t[:, None].expand(x_1.shape)
        x_0 = torch.randn_like(x_1)
        v_psi = v_t(t[:, 0], self.psi_t(x_0, x_1, t))
        d_psi = x_1 - (1 - self.sig_min) * x_0  # target velocity d/dt psi_t
        return torch.mean((v_psi - d_psi) ** 2)
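 
# Quick smoke test of the loss (illustrative, not from the original post);
# any callable (t, x) -> velocity can stand in for the trained field:
#   fm = OTFlowMatching()
#   x_1 = torch.randn(8, 2)
#   print(fm.loss(lambda t, x: x, x_1))  # prints a scalar tensor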
 
# Define the conditional vector field model
class CondVF(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net = net
 
    def forward(self, t, x):
        return self.net(t, x)
 
    def wrapper(self, t, x):
        # Broadcast the ODE solver's scalar time to one entry per sample
        t = t * torch.ones(len(x), device=x.device)
        return self(t, x)
 
    def decode(self, x_0):
        """从先验分布解码到数据分布"""
        return odeint(self.wrapper, x_0, 0., 1., self.parameters())
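 
    def encode(self, x_1):
        """Inverse map (data -> noise); not in the original post. A hedged
        sketch assuming zuko's odeint integrates backwards when t1 < t0."""
        return odeint(self.wrapper, x_1, 1., 0., self.parameters())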
 
# Define the neural network (vector field with Fourier time features)
class Net(nn.Module):
    def __init__(self, in_dim, out_dim, h_dims, n_frequencies):
        super().__init__()
        self.n_frequencies = n_frequencies
        ins = [in_dim + 2 * n_frequencies] + h_dims
        outs = h_dims + [out_dim]
        self.layers = nn.ModuleList([
            nn.Sequential(nn.Linear(in_d, out_d), nn.LeakyReLU()) 
            for in_d, out_d in zip(ins, outs)])
        self.top = nn.Linear(out_dim, out_dim)
 
    def time_encoder(self, t):
        """时间编码"""
        freq = 2 * torch.pi * torch.arange(self.n_frequencies, device=t.device)
        t = freq * t[..., None]
        return torch.cat((t.cos(), t.sin()), dim=-1)
 
    def forward(self, t, x):
        t_enc = self.time_encoder(t)
        x = torch.cat((x, t_enc), dim=-1)
        for layer in self.layers:
            x = layer(x)
        return self.top(x)
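 
# Shape check (illustrative): t of shape (B,) maps to time_encoder(t) of
# shape (B, 2 * n_frequencies), which is why the first Linear layer takes
# in_dim + 2 * n_frequencies input features.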
 
# Training loop
def train(model, v_t, n_epochs=100, n_points=100_000):
    # Toy dataset: 2-D swiss roll, keeping only the x and z coordinates
    data, _ = make_swiss_roll(n_points)
    data = torch.tensor(data[:, [0, 2]], dtype=torch.float32)
    dataloader = DataLoader(TensorDataset(data), batch_size=2048, shuffle=True)

    optimizer = torch.optim.Adam(v_t.parameters(), lr=1e-3)

    for epoch in range(n_epochs):
        for batch in dataloader:
            x_1 = batch[0].to(device)
            loss = model.loss(v_t, x_1)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
 
# Generate samples by pushing prior noise through the learned flow
def sample(v_t, n_samples=10000):
    with torch.no_grad():
        x_0 = torch.randn(n_samples, 2, device=device)
        x_1_hat = v_t.decode(x_0)
    return x_1_hat.cpu().numpy()

Usage:

# Initialization
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = OTFlowMatching()
net = Net(2, 2, [512]*5, 10).to(device)
v_t = CondVF(net)
 
# Training
train(model, v_t)
 
# Sampling and visualization
samples = sample(v_t)
plt.hist2d(samples[:, 0], samples[:, 1], bins=128)
plt.show()