LLM并行训练2-张量并行

切分方式

前置知识

矩阵乘法求导

\[Y=f(AB)=f(C) \]

\[\frac{\partial Y}{\partial A} = \frac{\partial Y}{\partial C} \cdot B^{T} \]

\[\frac{\partial Y}{\partial B} = A^{T} \cdot \frac{\partial Y}{\partial C} \]

以下定义X的dim为(M,K), W的dim为(K, N), 平均切分z次

行式切分

forward

\[Y= X_1W_1 + X_2W_2 \]

\[X= concat(X_1, X_2, axis=1) \]

\[W = concat(W_1, W_2, axis= 0) \]

先把X按列切分每个子块的dim都是 (M, K/z), W1的dim(K/z, N), 这里利用了分块矩阵乘法的性质, 把切分好的Xi scatter到对应W的卡上, 计算完成后相加结果矩阵即可拿到Y的前向结果

backward:

\[\frac{\partial L}{\partial W_i} = \frac{\partial L}{\partial Y}\cdot \frac{\partial Y}{\partial Y_i}\cdot \frac{\partial Y_i}{\partial W_i} \\ \]

Y对Yi的偏导因为 Y= Y1 + Y2求导偏导是1, 可以直接省略. 只需要把L对Y的偏导广播到W1, W2各自的卡上, 他们就能各自计算对应的梯度来更新W. L对X的偏导也是两张卡各自计算后(L对Y的偏导 * Wi的转置), 最后按列concat到一起就能得到最终X的偏导

列式切分

forward:

\[Y= concat(X_1W_1, X_2W_2, axis=1) \\ \]

因为按列切分没有改变矩阵乘法的中间dim, 前向只需要concat起来两个切分后的乘法结果

backward:

\[\frac{\partial L}{\partial W_i} = \frac{\partial L}{\partial Y}\cdot \frac{\partial Y_i}{\partial W_i} \]

\[\frac{\partial L}{\partial X} = \frac{\partial L}{\partial X_1} + \frac{\partial Y_i}{\partial X_2} \\ \]

这里是需要先把L对Y的导数切分后再传给各张卡, L对W的偏导计算方法和行切分一样, L对X的偏导因为对于损失L,X既参与了XW1的计算,也参与了XW2的计算, 所以需要把两张卡上对X1,X2的偏导求和. 得到最终的结果

MLP并行

以Y = GELU(X * A) * B 为例

forward: 把参数A进行列切分, B进行行切分. 先把X广播到每张卡上, 每张卡直接算完从A->B的所有流程后, AllReduce计算结果就能得到Y

Backward: 把Grad(y)广播到各张卡上独立反向, 然后allreduce所有的grad(xi), 就能得到grad(x)

这个设计真挺巧妙的. 如果我们只用行切分或者列切分, 在两个矩阵计算的中间必然会进行一次集合通信的同步. 列切分是AllGather, 行切分是AllReduce. 然而先行后列, 中间除了节省掉集合通信的成本, 连第二次列切分的时候需要先对X做分块操作的步骤都节省了. 牛啊

MultiHeadAttention并行

如果有两个头两张卡, 把V,Q,K权重矩阵进行列切分后. 算出来的Q1,Q2 通过concat就能得到Q, 完美的切分了数据和算力..真的感觉天然适配张量并行, 只要我们保证head数能整除卡数就能完全利用起来所有的卡.

总结

张量并行结合了分块矩阵运算的性质, 通过合理的切分输入和参数, 再加上行列切分的合理配置. 就能节省掉很多过程中的不必要通信和冗余计算. 而且对效果无损, 看的过程中感觉好神奇.

参考

https://zhuanlan.zhihu.com/p/622212228

千百度
© 版权声明
THE END
喜欢就支持一下吧
点赞8 分享
评论 抢沙发
头像
欢迎您留下宝贵的见解!
提交
头像

昵称

取消
昵称表情代码图片

    暂无评论内容