发起 PyTorch DDP 分布式训练
PyTorch DistributedDataParallel (DDP) 是一种高效的数据并行分布式训练方法。本文详细描述了如何使用 AIStudio 「任务」功能配置和启动 PyTorch DDP 分布式训练任务。
基本流程
在 AIStudio 任务中,通过以下步骤启动 PyTorch DDP 分布式训练任务:
- 创建训练任务:配置任务 Worker 规格、Worker 数量,设置「分布式框架」为 Pytorch DDP。
- 初始化:平台创建对应的 pod,并注入必要环境变量,供用户代码获取环境信息。
- 容错检查:训练开始前,平台发起自检,确保所有 Worker 之间网络通畅、GPU、存储工作正常。
- 训练执行:
- 若任意 pod 失败(退出码非 0),训练任务失败。
- 若所有 pod 成功完成(退出码为 0),训练任务成功。
环境变量
我们平台在分布式训练任务中采用了与 PyTorch 原生定义不同的环境变量设置。以下是平台提供的内置环境变量:
MASTER_ADDR
:DDP 分布式通信的主机 IP 地址或名称(通常为 worker0 的 pod 名称),由系统自动解析。MASTER_PORT
:主节点上开启的端口号,默认值为 29500。WORLD_SIZE
:启动的 pod 数量(即节点数量),注意:不同于 PyTorch 原生的WORLD_SIZE
(总进程数)。RANK
:每个 pod 的编号(从 0 到WORLD_SIZE - 1
),用于设置torchrun
的--node_rank
参数。
NOTE
平台将 WORLD_SIZE
定义为 pod(节点)的数量,而非 PyTorch 中的总进程数。这一选择与 Kubeflow 的 pytorch-operator
保持一致,方便从 Kubeflow 迁移的用户无缝过渡。
Pytorch DDP 关键参数
以下参数用于配置 torchrun
命令:
--master_addr
:主节点地址,对应平台注入的环境变量MASTER_ADDR
。--master_port
:主节点端口,对应平台注入的环境变量MASTER_PORT
。--nnodes
:节点数量,对应平台注入的环境变量WORLD_SIZE
。--node_rank
:节点编号,对应平台注入的环境变量RANK
。--nproc_per_node
:每个 pod 的进程数,通常等于 GPU 数量。例如:shellGPUS_PER_NODE=$(nvidia-smi --query-gpu=index --format=csv,noheader | wc -l)
NOTE
您可以硬编码参数(如 --nnodes=2
或 --nproc_per_node=4
),但注意这会降低灵活性,硬编码值在 Worker 数量或 GPU 数量变化时可能会失效。
启动命令
为实现统一配置管理、和更好的复用性,通常使用 Shell 脚本封装复杂的 PyTorch DDP torchrun
bash 命令。
通过网页提交训练任务时,可以在界面中填写 Bash 启动命令(entrypoint),用于触发封装的 PyTorch DDP Shell 脚本。以下是详细说明:
Bash 启动命令:用户在网页表单中输入的 Bash 命令是任务的启动指令,通常用于调用封装 DDP 训练的 Shell 脚本。例如,输入
./launch_ddp.sh train.py
指示平台执行名为launch_ddp.sh
的脚本,可按需传递训练脚本参数。shell# 推荐设置 set -o pipefail,如果管道命令中的任何一个命令失败,则整个管道命令将被视为失败,返回非零。 set -o pipefail # 可选:覆盖 OUTPUT_DIR 或其他关键变量 export MODEL_NAME="meta-llama/Llama-3.2-3B-Instruct" export OUTPUT_DIR=${OUTPUT:-"/mnt/zhaoyinghao/train_lora_llama3"} # 用户自定义输出目录,可被网页环境变量覆盖 # 调用训练脚本 # 由于 tee 的命令总是会执行成功(退出码为 0),上方设置 set -o pipefail 后可保证捕获管道中异常 bash /mnt/zhaoyinghao/train_lora_llama3/launch_ddp.sh | tee ${OUTPUT_DIR}/launch_ddp.log # 获取上一条命令的返回值,便于排查问题 # 若为0,表示正常退出,否则表示出现异常,而可以进行异常检测。 ret=$? if [[ ${ret} -ne 0 ]]; then echo "[$(date +"%Y-%m-%d %H:%M:%S")] $HOSTNAME 训练失败!现场保留时间为 10000s" sleep 10000 fi exit $?
Shell 训练脚本:Shell 脚本常用于封装 PyTorch DDP 训练的 bash 命令,该脚本(例如
launch_ddp.sh
)封装了启动 PyTorch DDP 训练所需的torchrun
命令及环境变量配置(如MASTER_ADDR
、WORLD_SIZE
)。您可以在脚本中检测 GPU 数量、设置环境参数、写日志等复杂逻辑,确保训练任务正确运行。NOTE
请确保 Shell 训练脚本已上传至任务挂载的共享存储目录,并具有可执行权限(
chmod +x launch_ddp.sh
)。Bash 启动命令应正确指向脚本路径。假设通过
torchrun
启动 DDP 分布式训练,以下是通用启动脚本,适用于单节点和多节点场景:bash# 获取每个 Worker 的 GPU 数量 GPUS_PER_NODE=$(nvidia-smi --query-gpu=index --format=csv,noheader | wc -l) # 设置分布式训练参数 MASTER_ADDR=${MASTER_ADDR:-'127.0.0.1'} MASTER_PORT=${MASTER_PORT:-'29500'} NNODES=${WORLD_SIZE:-'1'} # 平台提供的 WORLD_SIZE 是节点数 NODE_RANK=${RANK:-'0'} # 平台提供的 RANK 是 pod 编号,用于 --node_rank # 计算 PyTorch 所需的 WORLD_SIZE(总进程数) WORLD_SIZE=$(($GPUS_PER_NODE * $NNODES)) # 设置 torchrun 参数 DISTRIBUTED_ARGS=" --nproc_per_node $GPUS_PER_NODE \ --nnodes $NNODES \ --node_rank $NODE_RANK \ --master_addr $MASTER_ADDR \ --master_port $MASTER_PORT" # 运行 torchrun torchrun $DISTRIBUTED_ARGS your_script.py
对于 2 个 Worker,每个 Worker 4 个 GPU 的设置:
- 平台设置:
WORLD_SIZE=2
,RANK=0
(第一个 pod),RANK=1
(第二个 pod)。 - 脚本计算:
NNODES=2
,GPUS_PER_NODE=4
,WORLD_SIZE=8
(总进程数)。 torchrun
启动 8 个进程(每个 pod 4 个),全局编号为 0-3(第一个 pod)和 4-7(第二个 pod)。
与 PyTorch 环境变量命名区别
如果您熟悉 PyTorch 原生的分布式训练(如使用 torchrun
),可能会对平台的环境变量设置感到困惑,以下是详细的区别说明。
平台与 PyTorch 的区别
平台:
WORLD_SIZE
:表示 pod(节点)的数量。RANK
:表示每个 pod 的编号(从 0 到WORLD_SIZE - 1
)。
PyTorch(torchrun):
WORLD_SIZE
:表示分布式训练中的总进程数(通常为节点数 × 每节点 GPU 数
)。RANK
:表示每个进程的全局编号(从 0 到WORLD_SIZE - 1
)。
由于平台将 WORLD_SIZE
定义为节点数,而 torchrun
期望的是总进程数,因此需要在脚本中进行转换。
在平台上运行 torchrun
请参考 Pytorch DDP 关键参数,查看如何使用平台环境变量配置 torchrun
。
IMPORTANT
- 在多节点训练中,
torchrun
命令需要--node-rank
参数来指定节点的排名。该参数的值可以直接从平台提供的RANK
环境变量中获取,因为RANK
表示 pod(节点)的编号。 - 平台提供的
WORLD_SIZE
和RANK
不能直接用于torchrun
,需要通过脚本进行转换。
与 Kubeflow 兼容
针对 Kubeflow 和 PyTorch Operator 背景的用户,平台的环境变量设置与 Kubeflow 的 pytorch-operator
一致,使得从 Kubeflow 迁移到平台更加简单,可以复用相似的配置和脚本。
在 Kubeflow 的 pytorch-operator
中:
WORLD_SIZE
:表示 worker pod 的数量(即节点数)。RANK
:表示每个 worker pod 的编号。
平台的环境变量定义与 Kubeflow 相同:
WORLD_SIZE
:pod(节点)数量。RANK
:pod 编号。
故障排除
- 在任务运行过程中,可从网页端登录任务 Worker,访问 Web Terminal。在任务详情页底部可看到登录按钮。
- 任务功能提供 atlctl 命令行调试工具,您可以从 Web Terminal 登录任意的任务 Worker,执行停止任务、统一下发测试命令等调试工作。