分布式训练教程:从通信原语到大模型并行训练

开篇总结

大模型训练的核心张力是:显存希望每个 rank 少存一点,吞吐希望每个 rank 多算一点,数学等价性又要求该同步的数据必须同步。分布式训练的难点正是在这三者之间做切分和通信的交换。

  • 通信原语不是背景知识,而是并行策略的“接口”:DDP 的梯度同步落在 All-Reduce,FSDP/ZeRO-3 的参数重建落在 All-Gather,梯度分片落在 Reduce-Scatter,MoE token 路由落在 All-to-All。
  • 并行策略的区别不在名字,而在切分对象:DP 切 batch,TP 切单层矩阵,PP 切层,SP/CP 切序列或上下文,EP 切 expert,ZeRO/FSDP 切模型状态。
  • 显存估算必须逐项拆开:parameter、gradient、optimizer state、activation、communication buffer、临时 full parameter 的峰值不能混在一句“省 N 倍”里。
  • 3D/4D/5D 并行不是固定标准术语;更可追问的表达是说明每个 process group 负责哪种切分,以及 forward/backward/optimizer step 中触发什么通信。
  • 没有多 GPU 环境时,CPU-only 仿真仍然有价值:它不能替代真实性能 profiling,但能把“每个 rank 通信前后持有什么”固定成可打印的 tensor。

主线可以压缩成一条链:

单卡训练为什么不够
-> 哪些 tensor 或状态可以被切分
-> 切分后需要哪些 collective 通信
-> 通信量和显存如何变化
-> 多种并行策略如何组合成 3D/4D/5D 训练
-> PyTorch FSDP、DeepSpeed ZeRO、Megatron-LM 等框架分别落在哪些抽象层

相比普通通识介绍,本文更关注三个抓手:

  • 数据流:一次 step 中 batch、activation、parameter、gradient、optimizer state 到底在哪里产生、在哪里通信、在哪里释放。
  • 通信成本:All-Gather、Reduce-Scatter、All-Reduce 不只看输入输出,还要看每个 rank 发送多少字节、需要多少通信轮次。
  • 组合视角:DP/DDP、TP、PP、SP/CP、EP、ZeRO/FSDP 不是互斥选项,而是分别切不同维度,最终组成大模型训练中的多维并行。

配套代码:

  • examples/fsdp_zero3_sim.py:CPU-only FSDP / ZeRO-3 数值实验,用两个虚拟 rank 展示参数分片、all-gather、reduce-scatter 和 shard optimizer update。
  • examples/collectives_cost_sim.py:CPU 上模拟常见 collective 的输入输出语义,并用简单 alpha-beta 模型估算通信开销。
  • examples/collective_cost_model.py:共享的 alpha-beta 与 ring collective 成本模型,供 collective 语义脚本和策略估算脚本复用。
  • examples/memory_comm_estimator.py:把 DDP、ZeRO-1/2/3、FSDP 的模型状态显存和通信量放在同一张估算表里。
  • examples/tp_linear_sim.py:手写 column-parallel linear 和 row-parallel linear,并对照完整 Linear 验证 forward/backward 等价。

目录

1. 学习边界和心智模型

本项目面向 AI infra 大模型算法岗位,不以 CUDA 算子开发或 NCCL kernel 实现为主线。更重要的是能解释清楚:

  • 大模型训练的瓶颈来自哪里:显存、算力、通信、训练稳定性。
  • 哪些对象可以切:数据、参数、层、张量维度、序列维度、专家、优化器状态。
  • 切完以后如何恢复数学等价性:通过 All-Reduce、All-Gather、Reduce-Scatter、All-to-All 等 collective。
  • 每种并行策略牺牲什么换来什么:显存、吞吐、通信量、bubble、负载均衡和实现复杂度。

可以用四层心智模型组织所有内容:

资源层:多 GPU / 多机提供显存、算力和网络带宽
通信层:collective 定义跨 rank tensor 如何交换
并行层:DP、TP、PP、SP/CP、EP、ZeRO/FSDP 决定切分方式
训练层:一次 step 中 parameter、activation、gradient、optimizer state 的生命周期

不要把“模型并行”理解成一个单一方法。更准确的说法是:不同并行策略切的是不同对象。

策略 主要切分对象 主要解决问题 典型通信
DP/DDP batch/data 提升吞吐 gradient All-Reduce
ZeRO/FSDP parameter/gradient/optimizer state 降低模型状态显存 All-Gather、Reduce-Scatter
TP 单层矩阵维度 单层太大或计算太重 All-Reduce、All-Gather、Reduce-Scatter
PP layer/stage 整模型层数太多 activation send/recv
SP/CP sequence/context 长上下文 activation/attention 显存 All-Gather、Reduce-Scatter、ring exchange
EP expert/token routing MoE 大参数规模 All-to-All / All-to-AllV

2. 从单卡训练 step 出发

单卡训练的抽象非常简单:

pred = model(x)
loss = criterion(pred, y)
loss.backward()
optimizer.step()
optimizer.zero_grad()

分布式训练的所有复杂性,本质上都来自这几行代码里的对象被拆开了:

  • x 被数据并行切成不同 rank 的 batch shard。
  • model 的参数可能被 DDP 复制,也可能被 FSDP/ZeRO-3 分片。
  • activation 可能跨 PP stage 传递,也可能按 sequence 维度分片。
  • loss.backward() 产生的梯度可能需要 All-Reduce,也可能只保留 Reduce-Scatter 后的 shard。
  • optimizer.step() 可能在每个 rank 更新完整参数,也可能只更新本 rank 的参数分片。

纯 DDP 的 step

global batch
-> DistributedSampler 按 DP rank 切分 batch
-> rank i 拿到 local batch shard
-> 每个 rank 用完整模型副本 forward
-> 每个 rank 计算 local loss
-> backward 产生本地梯度
-> DDP bucket 触发梯度 All-Reduce
-> 每个 rank 用同步后的梯度执行 optimizer step
-> 所有 rank 继续保持相同参数

关键点:

  • DDP 梯度同步通常发生在每个 training step 的 backward 过程中,而不是每个 epoch 结束。
  • PyTorch DDP 会按 bucket 组织梯度,某个 bucket 中的梯度就绪后就可以启动通信,从而和剩余 backward 计算重叠。
  • DDP 的数学目标是让每个 rank 都拿到等价的全局平均梯度。

FSDP 的 step

persistent:
rank i 常驻 parameter_shard_i 和 optimizer_state_shard_i

forward:
for each wrapped module:
    all-gather 当前 module 的 parameter shards
    用完整参数执行本地 batch shard 的 forward
    reshard/free 完整参数

backward:
for each wrapped module in reverse:
    all-gather 当前 module 的 parameter shards
    计算本地完整梯度
    reduce-scatter 梯度,得到 rank i 的 gradient shard
    free 完整参数和完整梯度

optimizer:
rank i 只用自己的 gradient shard 和 optimizer state shard 更新 parameter shard

一句话记忆:

FSDP 参数需要时临时 all-gather,梯度产生后 reduce-scatter 回 shard,optimizer 始终只更新 shard。

3. 通信原语:语义、出现位置和通信量

通信原语要同时记三件事:

  • 语义:输入输出 tensor 如何变化。
  • 位置:在训练 step 的哪里出现。
  • 成本:每个 rank 大概发送多少数据,需要多少通信轮次。

下面假设有 N 个 rank,完整 tensor 大小为 D bytes。为了做 chunk-based collective,完整 tensor 通常被切成 N 块,每块大小约 D/N

实际通信时间可以粗略写成:

time ~= alpha * num_steps + beta * bytes_per_rank

其中 alpha 表示每轮通信延迟,beta 表示带宽倒数。大 tensor 更受带宽影响,小 tensor 更受延迟影响。

All-Gather

语义:每个 rank 贡献一个 shard,所有 rank 最终都拿到完整拼接结果。

before:
rank0 = [a0]
rank1 = [a1]
rank2 = [a2]
rank3 = [a3]

after:
rank0 = [a0, a1, a2, a3]
rank1 = [a0, a1, a2, a3]
rank2 = [a0, a1, a2, a3]
rank3 = [a0, a1, a2, a3]

特点:

  • 不做 sum/max 等规约,只收集并拼接。
  • 常用于 FSDP/ZeRO-3 中从参数 shard 临时重建完整参数。
  • Ring all-gather 中,每个 rank 发送约 (N-1)/N * D bytes,通信轮次约 N-1
  • 不应简单说“通信量与节点数无关”。更准确是:当 N 增大时,单 rank 发送量趋近于 D,但通信轮次仍随 N 增长。

Reduce-Scatter

语义:先对所有 rank 的完整 tensor 做规约,再把规约结果按 chunk 分散给不同 rank。

before:
rank0 = [x_00, x_01, x_02, x_03]
rank1 = [x_10, x_11, x_12, x_13]
rank2 = [x_20, x_21, x_22, x_23]
rank3 = [x_30, x_31, x_32, x_33]

after with SUM:
rank0 = [x_00 + x_10 + x_20 + x_30]
rank1 = [x_01 + x_11 + x_21 + x_31]
rank2 = [x_02 + x_12 + x_22 + x_32]
rank3 = [x_03 + x_13 + x_23 + x_33]

特点:

  • 每个 rank 只得到规约结果的 1/N
  • 常用于 ZeRO-2 和 FSDP/ZeRO-3 的梯度同步。
  • Ring reduce-scatter 中,每个 rank 发送约 (N-1)/N * D bytes,通信轮次约 N-1
  • 它比 All-Reduce 更适合“后续只需要梯度分片”的场景,因为没有必要把完整规约梯度再复制给每个 rank。

All-Reduce

语义:所有 rank 的 tensor 做规约,然后每个 rank 都得到完整规约结果。

All-Reduce = Reduce-Scatter + All-Gather

DDP 中典型用法是对梯度 bucket 做 All-Reduce,使所有 rank 拿到相同的全局平均梯度。实现上经常先 sum,再除以 world size。

Ring all-reduce 成本:

bytes_per_rank ~= 2 * (N - 1) / N * D
steps ~= 2 * (N - 1)

它的带宽利用率通常较好,但小 tensor 或 rank 数很多时,延迟项会变明显。

Broadcast

语义:root rank 把一份完整数据复制给其他 rank。

常见位置:

  • DDP 初始化时同步参数。
  • 某些全局元数据或 checkpoint state 分发。

All-to-All

语义:每个 rank 都把不同 chunk 发给不同目标 rank,每个 rank 也从所有 rank 接收属于自己的 chunk。

常见位置:

  • MoE expert parallelism 中 token dispatch 和 combine。
  • 某些 sequence/context parallel 的 token 或 block 重排。

All-to-AllV 表示不同 rank 到不同目标 rank 的数据量可以不同。MoE 的 token routing 往往动态且不均匀,所以工程上经常需要处理 All-to-AllV 式的不规则通信。

4. 数据并行:DP 与 DDP

数据并行切的是 batch,不切模型。每个 rank 都有完整模型副本,处理不同数据子集。

DP/DDP 的核心逻辑

rank0: model(theta), batch_0 -> grad_0
rank1: model(theta), batch_1 -> grad_1
rank2: model(theta), batch_2 -> grad_2
rank3: model(theta), batch_3 -> grad_3

All-Reduce:
grad = average(grad_0, grad_1, grad_2, grad_3)

每个 rank:
theta <- optimizer(theta, grad)

优点:

  • 最容易理解和落地。
  • 对模型代码侵入小。
  • 适合通过增大 global batch 提升吞吐。

限制:

  • 每个 rank 都复制完整 parameters、gradients、optimizer states。
  • 模型状态显存随模型规模线性增长,不会因为 DP rank 增加而下降。
  • 跨 rank 同步完整梯度,通信量随模型参数量增长。

PyTorch DDP 的实现抓手

面试中可以这样描述 DDP:

DDP 初始化时同步参数。
它给参数注册 autograd hook。
backward 时梯度逐步产生,DDP Reducer 把梯度放入 bucket。
bucket ready 后启动 All-Reduce。
All-Reduce 完成后,每个 rank 的梯度一致。
optimizer 在每个 rank 本地执行相同更新。

关键词:

  • DistributedSampler
  • gradient bucket
  • autograd hook
  • overlap communication with backward
  • no_sync() gradient accumulation

5. 状态分片:ZeRO 与 FSDP

普通 DDP 最大的问题是模型状态冗余。以 Adam mixed precision 训练为例,单个参数可能对应:

  • BF16/FP16 parameter:2 bytes
  • BF16/FP16 gradient:2 bytes
  • FP32 master parameter:4 bytes
  • FP32 Adam first moment m:4 bytes
  • FP32 Adam second moment v:4 bytes

粗略就是 16 bytes * 参数量,还没算 activation、通信 buffer 和临时 workspace。

ZeRO 的核心思想是:数据并行 rank 之间不必重复保存所有模型状态。按分片对象不同,分成三个阶段。

ZeRO-1:切 optimizer states

常驻状态:

  • parameters:完整复制。
  • gradients:完整复制。
  • optimizer states:按 DP rank 分片。

流程:

  1. 每个 rank 用完整参数 forward/backward。
  2. 梯度通过 All-Reduce 得到全局平均梯度。
  3. 每个 rank 只更新自己负责的 parameter shard,因为只有这部分 optimizer states。
  4. 更新后的 parameter shard 再 All-Gather 成完整参数。

ZeRO-2:继续切 gradients

常驻状态:

  • parameters:完整复制。
  • gradients:分片。
  • optimizer states:分片。

流程:

  1. forward/backward 仍使用完整参数。
  2. 梯度产生后通过 Reduce-Scatter 做规约并分片。
  3. 每个 rank 用自己的 gradient shard 和 optimizer state shard 更新 parameter shard。
  4. parameter shard 再 All-Gather 成完整参数。

ZeRO-3 / FSDP:继续切 parameters

常驻状态:

  • parameters:分片。
  • gradients:分片。
  • optimizer states:分片。

流程:

  1. 进入某个 module 前 All-Gather 当前 module 的参数 shard。
  2. 使用临时完整参数计算 forward。
  3. forward 后释放完整参数,只保留 shard。
  4. backward 到该 module 时再次 All-Gather 参数。
  5. 计算本地梯度后 Reduce-Scatter,得到 gradient shard。
  6. optimizer 只更新本 rank 的 parameter shard。

FSDP 和 ZeRO-3 的关系

可以说 FSDP 思想上接近 ZeRO-3,但不要把二者说成完全同一个东西。

  • ZeRO 是 DeepSpeed 提出的减少数据并行冗余状态的一组方法。
  • FSDP 是 PyTorch 中围绕 module wrapping、FlatParameter、reshard、prefetch、mixed precision 等机制实现的 fully sharded data parallel。
  • FSDP 的工程核心是“按 wrapped module 管理参数 all-gather 和 reshard”,而不是一次性 all-gather 整个模型。

显存估算要逐项算

不要脱离假设说“ZeRO-3 一定降低 64 倍”。更专业的说法是:

如果 world size = N,被分片的状态项理论上常驻显存近似降低到 1/N。
但实际峰值还要加上当前 FSDP unit 的完整参数、prefetch buffer、通信 buffer、activation、padding 和 allocator 碎片。

因此分析 FSDP 显存时,应至少拆成:

  • persistent parameter shards
  • persistent optimizer state shards
  • gradient shards
  • gathered full parameter for current module
  • activation
  • communication buffer

以 mixed precision Adam 为例,如果每个参数对应 BF16/FP16 parameter、BF16/FP16 gradient、FP32 master parameter、Adam FP32 m/v,模型状态可以粗略估成:

策略 每 rank 常驻模型状态 额外峰值
DDP param + grad + master + m + v 无参数 all-gather 峰值
ZeRO-1 param + grad + (master + m + v) / N 更新参数 shard 后可能需要 all-gather
ZeRO-2 param + (grad + master + m + v) / N 更新参数 shard 后可能需要 all-gather
ZeRO-3 / FSDP (param + grad + master + m + v) / N 当前 FSDP unit 的完整参数、通信 buffer、prefetch buffer

通信量也要和 collective 对上:

DDP gradient sync      ~= All-Reduce(full_grad)
ZeRO-1 optimizer step  ~= All-Reduce(full_grad) + All-Gather(updated_param)
ZeRO-2 gradient sync   ~= Reduce-Scatter(full_grad) + All-Gather(updated_param)
FSDP/ZeRO-3 one step   ~= All-Gather(param) in forward
                         + All-Gather(param) in backward
                         + Reduce-Scatter(full_grad)

examples/collective_cost_model.py 保存这套共享 ring 公式;examples/memory_comm_estimator.py 用它估算策略级通信项,examples/collectives_cost_sim.py 用它标注单个 collective 的语义输出。两个可运行脚本配合看,可以把“公式里的 D bytes”和“rank 上实际拿到什么 tensor”对齐起来。

6. 张量并行:把单层矩阵乘法切开

张量并行切的是单层内部的矩阵计算。以 PyTorch 线性层为例:

Y = X @ W.T + b
W shape = [out_features, in_features]

Column Parallel Linear

按 output dimension 切 W

W = concat([W_0, W_1, ..., W_k], dim=0)
Y_i = X @ W_i.T
Y = concat([Y_0, Y_1, ..., Y_k], dim=-1)

特点:

  • 输入 X 通常复制到 TP group 内所有 rank。
  • 每个 rank 计算一部分 output features。
  • 如果下一层可以直接消费分片输出,就可以延迟 All-Gather。

反向:

  • 每个 rank 计算自己的 dW_i
  • 因为每个 rank 的输出都依赖同一个 XdX 需要把各 rank 的局部贡献相加,常见通信是 All-Reduce。

Row Parallel Linear

按 input dimension 切 W,同时切 X

X = concat([X_0, X_1, ..., X_k], dim=-1)
W = concat([W_0, W_1, ..., W_k], dim=1)
Y_i = X_i @ W_i.T
Y = sum_i(Y_i)

特点:

  • 每个 rank 只持有一段 input features 和对应权重。
  • forward 的 partial output 需要求和,常见通信是 All-Reduce 或 Reduce-Scatter 变体。
  • dX_i 可以留在本 rank,完整 dX 是各分片拼接。

Transformer 中的典型组合

Megatron-style TP 常把 MLP 写成:

X
-> ColumnParallelLinear: 得到分片 hidden
-> GeLU/SwiGLU: 本地计算
-> RowParallelLinear: 得到 partial output
-> All-Reduce partial output
-> residual add

这个设计的直觉是:尽量让中间大 hidden tensor 保持分片,只在必要位置通信。

examples/tp_linear_sim.py 把这里的两个 Linear 拆法都写成了 CPU 数值实验:

  • Column parallel:每个 rank 持有 output-feature shard,forward 输出需要 concat/all-gather,backward 的 dX 需要 sum/all-reduce。
  • Row parallel:每个 rank 持有 input-feature shard,forward partial output 需要 sum/all-reduce,backward 的 dX 是各 input shard 的 concat。

脚本会对照完整 Linear 层打印 YdXdWdb 的最大误差。误差为 0 或浮点舍入级别,说明 TP 只是重排计算和通信位置,不改变线性层的数学结果。

7. 流水线并行:把层按 stage 切开

流水线并行切的是层。它解决的是“整模型太深,或者完整模型无法放进单个 rank/group”的问题。

核心概念:

  • stage:一组连续层。
  • micro-batch:把一个 global batch 切成多个小批次,用于填满流水线。
  • bubble:stage 等待输入或等待反向梯度时的空闲。

GPipe

GPipe 的基本调度是:

先跑完所有 micro-batch forward
再跑所有 micro-batch backward

优点是调度简单,bubble 相对容易分析。缺点是 activation 存活时间较长,因为较早 micro-batch 的 forward activation 要等到 backward 才能释放。

1F1B

1F1B 的基本思想是 warmup 后交替执行:

one forward, one backward

它可以缩短 activation 存活时间,显存更友好,是实际大模型流水线训练中更常见的调度思想。

Interleaved 1F1B

一个 rank 持有多个 virtual stage,减少 pipeline bubble,但代价是调度和通信更复杂。

PP 的通信主要是相邻 stage 间发送 activation 和 activation gradient,不是 All-Reduce。把 PP 和 TP/DP 混合时,同一 step 中会同时出现 stage send/recv、TP collective 和 DP/FSDP 梯度同步。

8. 序列并行和上下文并行

序列相关并行容易混淆,建议拆成两个概念。

Megatron-style Sequence Parallelism

它通常和 TP 配套,把 LayerNorm、Dropout 等 activation 沿 sequence 维度切分。核心收益是减少 activation 显存,并把部分 TP 中的 All-Reduce 拆成:

Reduce-Scatter + All-Gather

这样通信量可以接近等价,但 activation 不再完整复制在每个 TP rank 上。

Context Parallelism

CP 面向长上下文 attention。它把 sequence/context 维度切给多个 rank,使每个 rank 不必保存完整长序列的 attention 中间状态。

难点是 attention 的 softmax 需要看到足够的 K/V 信息。常见实现会通过:

  • all-gather K/V
  • ring attention
  • block-wise exchange
  • reduce-scatter output

来避免单 rank 上构造完整 seq_len x seq_len 的注意力矩阵。

9. 专家并行和 MoE 通信

专家并行用于 MoE。它切的是 expert,而不是普通 dense layer 的矩阵维度。

一次 MoE layer 的数据流:

hidden states
-> router 为每个 token 选择 top-k expert
-> 按目标 expert/rank 打包 token
-> All-to-All dispatch
-> local experts compute
-> All-to-All combine
-> 按原 token 顺序恢复输出

EP 的主要难点:

  • 负载不均衡:某些 expert 收到过多 token。
  • capacity factor:限制每个 expert 的最大 token 数。
  • token drop 或 padding:处理超过容量和通信 shape 对齐。
  • All-to-AllV:不同 expert 的 token 数不均匀,通信量动态变化。

MoE 的独特性在于:参数量可以很大,但每个 token 只激活少量 expert。因此它常常提升“总参数规模”,而不是等比例提升“每 token 计算量”。

10. 从 3D 并行到 5D 并行

所谓 3D/4D/5D 并行不是固定标准术语,而是一种训练配置的组织方式。关键是说清楚每一维切什么。

3D 并行

常见指:

DP + TP + PP
  • DP 切 batch。
  • TP 切单层矩阵。
  • PP 切层/stage。

3D 并行适合 dense Transformer 大模型,是理解 Megatron-LM 类训练栈的基础。

4D 并行

在 3D 基础上再加入一种维度,常见有两种语境:

DP + TP + PP + ZeRO/FSDP

或:

DP + TP + PP + SP/CP

前者强调模型状态分片,后者强调长序列 activation/attention 分片。回答时不要只说“4D”,要明确第四维到底是什么。

5D 并行

大模型 MoE 或长上下文训练中,常见可以组织成:

DP/FSDP + TP + PP + SP/CP + EP

每一维的职责:

  • DP/FSDP:不同数据 shard 和模型状态分片。
  • TP:单层 dense 计算切分。
  • PP:层间 stage 切分。
  • SP/CP:序列或上下文维度切分。
  • EP:MoE expert 和 token routing 切分。

面试中更专业的表达是:

我不会把 3D/5D 当成死记术语,而会先说明每个 process group 对应哪种切分维度,以及这些 group 在 forward/backward/optimizer step 中分别产生哪些通信。

11. 数值精度如何配合分布式训练

数值精度不是分布式训练的附属概念。它直接影响显存、通信量、吞吐和训练稳定性。

常见格式

格式 大致特点 常见用途
FP32 精度和动态范围较好,成本高 optimizer state、master weights、部分 reduction
TF32 NVIDIA Ampere 后常见的 matmul 加速格式 FP32 输入的矩阵乘加加速
FP16 显存低、吞吐高,但动态范围小 mixed precision training
BF16 与 FP32 类似的 exponent,动态范围更好 大模型训练常用
FP8 更低精度,常见 E4M3/E5M2 需要配套 scaling 和框架支持

如果看到“BF8”这个说法,需要谨慎。更常见、标准的术语是 BF16 和 FP8。FP8 又常细分为 E4M3、E5M2 等格式。

为什么 BF16 常用于大模型

FP16 的 mantissa 和 exponent 都较小,容易 overflow/underflow,因此经常需要 loss scaling。BF16 的 mantissa 更短,但 exponent 和 FP32 接近,动态范围更大,对大模型训练更稳。

精度和分布式通信的关系

精度会影响:

  • 参数和 activation 显存。
  • 通信 tensor 的字节数,例如 BF16 梯度通信约为 FP32 的一半。
  • reduction 的数值稳定性。有些框架会用低精度通信,但在 FP32 中累积或维护 optimizer state。
  • FSDP/ZeRO 的状态估算。optimizer state 常用 FP32,即使 forward/backward 用 BF16。

一个更专业的表述是:

mixed precision 不是简单把所有 tensor 改成 FP16/BF16。
通常 forward/backward 用低精度提升吞吐和降低显存,optimizer state 或 master weight 保留高精度保证更新稳定。
分布式场景下,通信 tensor 的 dtype 还会直接影响通信量。

12. 主流训练框架的关系

这些框架可以按抽象层理解,而不是互相替代地背名字。

PyTorch DDP / FSDP

  • DDP:数据并行,完整模型副本,gradient All-Reduce。
  • FSDP:PyTorch 原生 fully sharded data parallel,参数、梯度、optimizer state 分片。

PyTorch 的 collective API 通过 c10d ProcessGroup 调到底层 backend。GPU 训练通常使用 NCCL,CPU 或测试场景可能使用 Gloo。

DeepSpeed ZeRO

DeepSpeed 的 ZeRO 系列专注减少数据并行冗余状态:

  • ZeRO-1:optimizer state sharding。
  • ZeRO-2:optimizer state + gradient sharding。
  • ZeRO-3:optimizer state + gradient + parameter sharding。

DeepSpeed 还常和 offload、activation checkpointing、pipeline parallelism 等组合使用。

Megatron-LM

Megatron-LM 的代表性价值在于大模型并行策略组合,尤其是:

  • tensor parallelism
  • pipeline parallelism
  • sequence parallelism
  • context parallelism
  • 和 data parallel / distributed optimizer 的组合

理解 Megatron-LM 时,重点不是“它有几个并行开关”,而是它如何构造不同 process group,并让每个 group 承担不同维度的通信。

Accelerate、NeMo、Colossal-AI 等

这些工具或框架往往提供更高层的训练配置、launcher、策略封装和生态集成。面试中可以把它们放在“编排层”理解:它们可能调用 PyTorch DDP/FSDP、DeepSpeed 或 Megatron 风格并行策略,而不是每个都从零实现所有 collective。

13. 关键时间线和出处

这些条目用于给主教程中的判断建立出处锚点,不是完整论文综述。

时间 工作或文档 和本文的关系
2018 GPipe 把 micro-batch pipeline 作为训练大模型的一种系统化调度方式。
2019 Megatron-LM 系统展示 Transformer tensor model parallelism,并奠定 TP + PP + DP 组合讨论的基础。
2019/2020 ZeRO 提出按 optimizer state、gradient、parameter 逐步消除数据并行冗余,是理解 ZeRO-1/2/3 和 FSDP 的核心来源。
PyTorch 官方文档 DDP notes / FSDP docs DDP bucket、autograd hook、FSDP wrapping、sharding、reshard 等工程语义以官方文档为准。
NVIDIA Transformer Engine 文档 FP8 primer FP8 E4M3/E5M2、scaling 与低精度训练相关概念的工程参考。

参考来源:

14. 如何使用本项目的代码实验

这些实验不是为了替代真实多 GPU profiling,而是在没有多 GPU 环境时,把分布式训练中最容易抽象化的对象固定成可打印、可检查的 CPU tensor。阅读时应关注每个虚拟 rank 在通信前后持有什么,而不是把脚本当作高性能实现。

FSDP 数值实验

运行:

python examples/fsdp_zero3_sim.py

这个脚本用两个虚拟 rank 模拟:

  • 参数 flatten 后分片。
  • forward 前 all-gather 当前 layer 的完整参数。
  • forward 后释放完整参数。
  • backward 再次 all-gather 参数。
  • 梯度 reduce-scatter 回 shard。
  • Adam 的 m/v optimizer state 始终保持分片。

验证点是:FSDP shard update 后重新 gather 成完整模型,应该与单进程完整模型训练一步的结果一致,差异只来自浮点误差。

Collective 和通信开销实验

运行:

python examples/collectives_cost_sim.py

这个脚本不是模拟真实网络,而是做两件事:

  • 用 tensor 展示 All-Gather、Reduce-Scatter、All-Reduce、All-to-All 的输入输出。
  • 按 ring 或 pairwise 直觉统计每个 rank 发送 bytes 和通信 step,并用 alpha-beta 模型估算通信时间。

它和本文第 3 节对应:先理解 collective 语义,再把通信量和训练中的使用位置对上。

通信成本公式来自 examples/collective_cost_model.py。该文件只保存共享的 alpha-beta 和 ring collective 记账逻辑,不单独作为实验入口。

显存与通信量估算实验

运行:

python examples/memory_comm_estimator.py

这个脚本和 collective 实验共享 examples/collective_cost_model.py 中的 ring 成本公式,但视角从“单个 collective 怎么变换 tensor”切到“一种并行策略在一个 training step 中触发哪些 collective”。它会打印:

  • DDP、ZeRO-1、ZeRO-2、ZeRO-3/FSDP 的 per-rank persistent model-state memory。
  • FSDP 当前 unit all-gather 带来的额外峰值。
  • DDP/ZeRO/FSDP 每 step 的主要模型状态通信量。

Tensor Parallel Linear 数值实验

运行:

python examples/tp_linear_sim.py

这个脚本和本文第 6 节对应:完整 Linear 作为 reference,column-parallel 和 row-parallel 分别手写 forward/backward,再比较 YdXdWdb。它适合用来检查自己是否真的理解“列并行为什么要 all-reduce dX,行并行为什么要 all-reduce output”。

15. 面试表达线索

回答分布式训练问题时,可以固定按这条链路组织:

  1. 训练瓶颈是什么:显存、计算、通信、长上下文、MoE 负载。
  2. 切分对象是什么:data、parameter、gradient、optimizer state、layer、tensor、sequence、expert。
  3. 引入什么通信:All-Reduce、All-Gather、Reduce-Scatter、All-to-All、send/recv。
  4. 显存和通信怎么变:哪些状态从完整复制变成 1/N,哪些地方增加临时 buffer 或额外通信。
  5. 数据流是否仍数学等价:梯度如何聚合,参数如何保持一致或按 shard 更新。
  6. 工程代价是什么:bucket、prefetch、reshard、bubble、load balance、activation checkpointing、mixed precision。

可以用一句完整的话收束:

分布式训练的核心不是某个并行名词,而是为每类 tensor 选择合适的切分维度,并用相应 collective 在需要的位置恢复数学等价性,同时控制显存峰值和通信开销。



Enjoy Reading This Article?

Here are some more articles you might like to read next:

  • Google Gemini updates: Flash 1.5, Gemma 2 and Project Astra
  • Displaying External Posts on Your al-folio Blog
  • nano-vllm Pipeline Tutorial:从 Prompt 到 Output 的推理闭环
  • Attention 原理、实现与演进教程
  • 基于流的生成模型:从 Normalizing Flow 到 Stochastic Interpolants 与 MeanFlow