Block Flow Matching 详解
本文档详细讲解 DiffRhythm2 中的 Block Flow Matching 机制,包括数学原理、训练流程、推理流程和代码实现。
目录
- 概述
- 背景知识
- Flow Matching 数学原理
- Block-wise 生成策略
- 训练机制详解
- 推理机制详解
- KV Cache 机制
- Classifier-Free Guidance
- 代码实现解析
- 与其他方法对比
1. 概述
1.1 什么是 Block Flow Matching?
Block Flow Matching 是一种半自回归生成架构,专为长序列音频生成设计。它结合了两种生成范式的优点:
| 生成范式 | 特点 | 优势 | 劣势 |
|---|---|---|---|
| 自回归 (AR) | 逐 token 生成 | 可处理任意长度 | 速度慢,误差累积 |
| 非自回归 (NAR) | 一次性并行生成 | 速度快,质量高 | 难以处理长序列 |
| Block Flow Matching | 逐 block 生成,block 内 NAR | 兼顾两者优势 | 需要精心设计 |
1.2 核心思想
┌─────────────────────────────────────────────────────────────┐
│ Block Flow Matching │
├─────────────────────────────────────────────────────────────┤
│ │
│ 歌词条件 ──────────────────────────────────────────────► │
│ ↓ │
│ ┌─────────┐ ┌─────────┐ ┌─────────┐ │
│ │ Block 0 │ → │ Block 1 │ → │ Block 2 │ → ... │
│ │ (NAR) │ │ (NAR) │ │ (NAR) │ │
│ └─────────┘ └─────────┘ └─────────┘ │
│ ↓ ↓ ↓ │
│ KV Cache ────► KV Cache ────► KV Cache │
│ │
│ 每个 Block: 10帧 × 64维 = 2秒音频 │
└─────────────────────────────────────────────────────────────┘
关键设计:
- Block 间:自回归(顺序生成,利用历史信息)
- Block 内:非自回归(通过 Flow Matching 并行去噪)
2. 背景知识
2.1 扩散模型回顾
传统扩散模型(如 DDPM)通过离散去噪步骤生成数据:
噪声 x_T → x_{T-1} → x_{T-2} → ... → x_1 → x_0 (干净数据)
问题:需要大量步骤(50-1000步),速度慢。
2.2 Flow Matching 的改进
Flow Matching 将离散步骤替换为连续的 ODE 轨迹:
噪声 x_0 ──────────────────────────────► x_1 (干净数据)
t=0 t=1
连续变换路径(ODE)
优势:
- 采样步数大幅减少(16步 vs 1000步)
- 可以使用高效的 ODE 求解器
- 支持精确的似然计算
2.3 为什么需要 Block-wise?
音乐生成的挑战:
- 一首歌 3-4 分钟 = 180-240 秒
- Latent 帧率 5Hz → 900-1200 帧
- 直接生成整首歌:显存爆炸 + 质量下降
解决方案:分 Block 生成,每个 Block 约 2 秒。
3. Flow Matching 数学原理
3.1 概率路径定义
Flow Matching 定义了从噪声分布 p_0 到数据分布 p_1 的线性插值路径:
x_t = (1 - t) · x_0 + t · x_1
其中:
x_0 ~ N(0, I):标准高斯噪声x_1:真实数据(mel-spectrogram latent)t ∈ [0, 1]:时间参数
3.2 速度场(Velocity Field)
对路径求导,得到速度场:
v(x_t, t) = dx_t/dt = x_1 - x_0
这是一个常数速度场,表示从噪声到数据的直线运动。
3.3 训练目标
模型 v_θ 学习预测速度场:
L_FM = E_{t, x_0, x_1} [ || v_θ(x_t, t) - (x_1 - x_0) ||² ]
直观理解:
- 给定带噪声的
x_t和时间t - 模型预测"应该往哪个方向走"
- 目标是走向干净数据
x_1
3.4 采样过程(ODE 求解)
推理时,从噪声出发,沿着学到的速度场积分:
dx/dt = v_θ(x, t)
x(0) = x_0 ~ N(0, I)
x(1) = 生成的干净数据
使用 ODE 求解器(如 Euler 方法):
# 伪代码
x = noise
for t in [0, dt, 2*dt, ..., 1]:
v = model(x, t) # 预测速度
x = x + v * dt # 沿速度方向移动
return x
3.5 图解
t=0 t=0.5 t=1
│ │ │
▼ ▼ ▼
┌─────────┐ ┌─────────┐ ┌─────────┐
│ ░░░░░░░ │ │ ▒▒▒▒▒▒▒ │ │ ███████ │
│ ░ 噪声 ░ │ ────────► │ ▒ 中间 ▒ │ ────────► │ █ 干净 █ │
│ ░░░░░░░ │ v_θ │ ▒▒▒▒▒▒▒ │ v_θ │ ███████ │
└─────────┘ └─────────┘ └─────────┘
│ │ │
└───────────────────────┴──────────────────────┘
ODE 轨迹
4. Block-wise 生成策略
4.1 序列布局
DiffRhythm2 将整个序列组织为三部分:
┌──────────────────────────────────────────────────────────────────┐
│ 完整输入序列 │
├────────────┬─────────────────────────┬──────────────────────────┤
│ Lyrics │ Clean Blocks │ Noisy Blocks │
│ (歌词) │ (干净音频 latent) │ (带噪音频 latent) │
├────────────┼─────────────────────────┼──────────────────────────┤
│ L tokens │ N × block_size frames │ N × block_size frames │
│ (可变) │ (历史已生成) │ (当前待生成) │
└────────────┴─────────────────────────┴──────────────────────────┘
4.2 Block 参数
| 参数 | 默认值 | 说明 |
|---|---|---|
block_size |
10 | 每个 block 的帧数 |
latent_rate |
5 Hz | Latent 帧率 |
block_duration |
2 秒 | 每个 block 对应的音频时长 |
num_history_block |
可配置 | 保留的历史 block 数量 |
4.3 生成流程图
时间线:
Block 0 Block 1 Block 2 Block 3
│ │ │ │
▼ ▼ ▼ ▼
┌──────┐ ┌──────┐ ┌──────┐ ┌──────┐
│噪声→ │ │噪声→ │ │噪声→ │ │噪声→ │
│干净 │ ──► │干净 │ ──► │干净 │ ──► │干净 │
└──────┘ └──────┘ └──────┘ └──────┘
│ │ │ │
│ ┌──────────┘ │ │
│ │ ┌─────────────────────┘ │
│ │ │ ┌────────────────────────────────┘
▼ ▼ ▼ ▼
┌─────────────────────────────────────────────────┐
│ 最终输出序列 │
│ [Block 0] [Block 1] [Block 2] [Block 3] ... │
└─────────────────────────────────────────────────┘
4.4 条件信息流
每个 Block 生成时可以访问:
- 歌词条件:完整的歌词 token 序列
- 风格条件:MuLan 编码的风格嵌入
- 历史上下文:之前生成的 Clean Blocks(通过 KV Cache)
5. 训练机制详解
5.1 训练时的序列布局
训练时,所有 Block 同时处理(并行训练):
[Lyrics] [Clean₁] [Clean₂] ... [Clean_N] [Noisy₁] [Noisy₂] ... [Noisy_N]
↑ ↑ ↑ ↑ ↑ ↑ ↑
t=-1 t=1 t=1 t=1 t~U[0,1] t~U[0,1] t~U[0,1]
5.2 时间步分配
不同位置使用不同的时间步:
| 序列部分 | 时间步 t | 含义 |
|---|---|---|
| Lyrics | t = -1 | 条件信息,不参与去噪 |
| Clean Blocks | t = 1 | 完全干净的数据 |
| Noisy Blocks | t ~ U[0,1] | 随机噪声级别 |
代码位置:training/utils/attention_mask.py:104-170
# 时间步分配逻辑
time_ids[:, :lyrics_len] = -1 # 歌词: t=-1
time_ids[:, lyrics_len:clean_end] = 1 # Clean: t=1
# Noisy blocks: 每个 block 独立采样
for i in range(num_blocks):
t_i = torch.rand(1) # 随机采样 t ∈ [0,1]
time_ids[:, noisy_start:noisy_end] = t_i
5.3 Attention Mask 设计
训练时的注意力掩码是 Block Flow Matching 的核心:
K (被关注的位置)
↓
┌─────────────────────────────────────────────┐
│ Lyrics │ Clean₁ │ Clean₂ │ Noisy₁ │ Noisy₂ │
─────┼────────┼────────┼────────┼────────┼────────┤
Q Lyrics│ ✓ │ ✗ │ ✗ │ ✗ │ ✗ │
(查 ─────┼────────┼────────┼────────┼────────┼────────┤
询 Clean₁│ ✓ │ ✓ │ ✗ │ ✗ │ ✗ │
的 ─────┼────────┼────────┼────────┼────────┼────────┤
位 Clean₂│ ✓ │ ✓ │ ✓ │ ✗ │ ✗ │
置) ─────┼────────┼────────┼────────┼────────┼────────┤
Noisy₁│ ✓ │ ✗ │ ✗ │ ✓ │ ✗ │
─────┼────────┼────────┼────────┼────────┼────────┤
Noisy₂│ ✓ │ ✓ │ ✗ │ ✗ │ ✓ │
└─────────────────────────────────────────────┘
✓ = 可以关注 (mask=True)
✗ = 不能关注 (mask=False)
掩码规则:
- Lyrics:只能看到自己(歌词内部互相关注)
- Clean block i:可以看到 Lyrics + Clean blocks 1~i
- Noisy block i:可以看到 Lyrics + Clean blocks 1~(i-1) + 自己
5.4 为什么 Noisy 不能看到对应的 Clean?
这是关键设计!如果 Noisy block i 能看到 Clean block i:
- 模型会直接"抄答案"
- 不会学习真正的去噪能力
通过遮蔽,模型必须:
- 从歌词理解要生成什么内容
- 从历史 Clean blocks 理解上下文
- 真正学会去噪
5.5 训练损失计算
# 伪代码
def compute_loss(model, batch):
lyrics = batch["lyrics"]
clean_latent = batch["latent_z"] # 真实的干净 latent
# 1. 为每个 block 采样时间步 t
t = torch.rand(num_blocks) # t ∈ [0, 1]
# 2. 采样噪声
noise = torch.randn_like(clean_latent)
# 3. 构造 noisy latent: x_t = (1-t)*noise + t*clean
noisy_latent = (1 - t) * noise + t * clean_latent
# 4. 模型预测速度场
pred_velocity = model(noisy_latent, t, lyrics, ...)
# 5. 计算 Flow Matching Loss
target_velocity = clean_latent - noise # 真实速度
loss = MSE(pred_velocity, target_velocity)
return loss
6. 推理机制详解
6.1 推理流程概览
推理时采用逐 Block 生成,每个 Block 通过 ODE 求解去噪:
┌─────────────────────────────────────────────────────────────┐
│ 推理流程 │
├─────────────────────────────────────────────────────────────┤
│ │
│ Step 1: 预计算歌词 KV Cache │
│ ↓ │
│ Step 2: for block_id in range(num_blocks): │
│ │ │
│ ├─► 采样噪声 x_0 ~ N(0,I) │
│ │ │
│ ├─► ODE 求解: x_0 → x_1 (16步) │
│ │ │
│ ├─► 更新 KV Cache │
│ │ │
│ └─► 检测 EOS │
│ │
│ Step 3: 拼接所有 blocks → 完整 latent │
│ │
└─────────────────────────────────────────────────────────────┘
6.2 ODE 求解过程
每个 Block 的去噪通过 ODE 积分完成:
# cfm.py:175-181 - 核心采样代码
noisy_emb = torch.randn(batch, block_size, channels) # 采样噪声
t_set = torch.linspace(0, 1, steps) # 时间步: [0, 0.0625, ..., 1]
# ODE 求解
outputs = odeint(fn, noisy_emb, t_set, method="euler")
sampled = outputs[-1] # 取最终结果
Euler 方法图解:
t=0.0 t=0.25 t=0.5 t=0.75 t=1.0
│ │ │ │ │
▼ ▼ ▼ ▼ ▼
┌────┐ ┌────┐ ┌────┐ ┌────┐ ┌────┐
│ x₀ │→ │ x₁ │→ │ x₂ │→ │ x₃ │→ │ x₄ │
│噪声│ │ │ │ │ │ │ │干净│
└────┘ └────┘ └────┘ └────┘ └────┘
│ │ │ │
└───────┴───────┴───────┘
x_{i+1} = x_i + v_θ(x_i, t_i) × Δt
6.3 EOS 检测机制
模型学习输出全1向量来表示结束:
# cfm.py:209-225 - EOS 检测
curr_frame = clean_emb_stream[:, -1, :] # 最后一帧
eos = torch.ones_like(curr_frame) # 全1向量
mse = F.mse_loss(curr_frame, eos)
if mse <= 0.05: # 接近全1,检测到 EOS
# 回溯找到真正的结束位置
while mse <= 0.05:
pos -= 1
...
break # 提前结束生成
这允许可变长度生成,无需预先指定精确时长。
7. KV Cache 机制
7.1 为什么需要 KV Cache?
每个 Block 生成时需要 attend 到:
- 歌词嵌入(固定不变)
- 之前所有 Block(逐渐增长)
如果不缓存,计算量会随序列长度平方增长。
7.2 双缓存设计
DiffRhythm2 维护两套独立的 KV Cache:
┌──────────────────────────────────────────┐
│ KV Cache 结构 │
├──────────────────────────────────────────┤
│ kv_cache ← 条件路径(有 style) │
│ cfg_kv_cache ← 无条件路径(无 style) │
├──────────────────────────────────────────┤
│ 每个 cache 包含: │
│ ├─ text_key_cache (固定,只计算一次) │
│ ├─ text_value_cache (固定,只计算一次) │
│ ├─ context_key_cache (随 block 增长) │
│ └─ context_value_cache (随 block 增长) │
└──────────────────────────────────────────┘
7.3 缓存更新流程
# cfm.py:108-128 - 预计算歌词缓存
with kv_cache.cache_text():
self.transformer(x=text_emb, time=-1, ...)
# cfm.py:183-204 - 每个 block 后更新上下文缓存
with kv_cache.cache_context():
self.transformer(x=sampled_block, time=1, ...)
7.4 历史限制
为防止内存无限增长,可设置 num_history_block:
# 只保留最近 N 个 block 的 KV Cache
if num_history_block is not None:
context_cache = context_cache[-num_history_block * block_size:]
8. Classifier-Free Guidance
8.1 CFG 原理
CFG 通过对比有条件和无条件预测来增强条件控制:
output = pred + cfg_strength × (pred - null_pred)
pred:有条件预测(使用 style prompt)null_pred:无条件预测(style = 0)cfg_strength:控制强度(默认 1.0)
8.2 代码实现
# cfm.py:143-172 - ODE 函数中的 CFG
def fn(t, x):
# 条件预测
pred = self.transformer(x, style_prompt=style_prompt, ...)
# 无条件预测
null_pred = self.transformer(x, style_prompt=zeros, ...)
# CFG 组合
return pred + (pred - null_pred) * cfg_strength
8.3 CFG 强度效果
| cfg_strength | 效果 |
|---|---|
| 0.0 | 纯条件生成,风格影响弱 |
| 1.0 | 标准 CFG,平衡质量和多样性 |
| 2.0+ | 强条件控制,可能过拟合 |
9. 代码实现解析
9.1 关键文件
| 文件 | 功能 |
|---|---|
diffrhythm2/cfm.py |
推理主逻辑,Block 循环生成 |
diffrhythm2/backbones/dit.py |
DiT Transformer 模型 |
diffrhythm2/cache_utils.py |
KV Cache 管理 |
training/utils/attention_mask.py |
训练时的 Mask 生成 |
training/train_dit.py |
训练主循环 |
9.2 推理核心代码
# cfm.py:72-231 - sample_block_cache 方法
# 1. 初始化
num_blocks = duration // block_size
kv_cache = BlockFlowMatchingCache(...)
# 2. 预计算歌词 KV Cache
with kv_cache.cache_text():
self.transformer(x=text_emb, time=-1, ...)
# 3. Block 循环
for bid in range(num_blocks):
# 3.1 采样噪声
noisy_emb = torch.randn(batch, block_size, channels)
# 3.2 ODE 求解
outputs = odeint(fn, noisy_emb, t_set)
sampled = outputs[-1]
# 3.3 更新 KV Cache
with kv_cache.cache_context():
self.transformer(x=sampled, time=1, ...)
# 3.4 拼接结果
clean_emb_stream = torch.cat([clean_emb_stream, sampled], dim=1)
# 3.5 EOS 检测
if detect_eos(sampled):
break
return clean_emb_stream
10. 与其他方法对比
10.1 生成范式对比
| 方法 | 代表模型 | 优势 | 劣势 |
|---|---|---|---|
| 纯自回归 | AudioLM | 长序列,流式 | 慢,误差累积 |
| 纯非自回归 | FastSpeech | 快,并行 | 长度受限 |
| DDPM | DiffWave | 高质量 | 步数多,慢 |
| Block Flow Matching | DiffRhythm2 | 快+长+高质量 | 设计复杂 |
10.2 采样效率对比
| 方法 | 采样步数 | 生成 3 分钟音乐 |
|---|---|---|
| DDPM | 1000 | ~10 分钟 |
| DDIM | 50 | ~30 秒 |
| Flow Matching | 16 | ~5 秒 |
10.3 Block Flow Matching 的创新点
- 半自回归架构:Block 间 AR + Block 内 NAR
- 高效 KV Cache:避免重复计算历史
- 精心设计的 Mask:训练时模拟推理行为
- EOS 检测:支持可变长度生成
- 双路 CFG:增强条件控制
参考资料
- 论文:DiffRhythm 2: High-Fidelity Lyrics-to-Song Generation
- 代码:
diffrhythm2/cfm.py,training/utils/attention_mask.py

Comments NOTHING