Data Parallel v.s. Distributed Data Parallel
本文由 Gui-Ru Li 撰寫後編修,且持續更新中
參考資料
DataParallel
single process, multi-thread
將輸入一個 batch 的數據均分成多份,分別送到對應的 GPU 進行計算,各個 GPU 得到的梯度累加。
與 Module 相關的所有數據也都會以淺複製的方式複製多份。
每個 GPU 將針對各自的輸入數據獨立進行 forward 計算,在 backward 時,每個卡上的梯度會匯總到原始的 module 上,再用反向傳播更新單個 GPU 上的模型參數,再將更新後的模型參數複製到剩餘指定的 GPU 中,以此來實現並行。
由於 GPU 0 作為master來進行梯度的匯總和模型的更新,再將計算任務下發給其他GPU,所以他的記憶體和使用率會比其他的高。
全程只維護一個 optimizer,對各 GPU 上梯度進行求和,而在主 GPU 進行參數更新,之後再將模型參數 broadcast 到其他 GPU。
DistributedDataParallel
multi process
在每次迭代中,每個process具有自己的 optimizer ,並獨立完成所有的優化步驟,進程內與一般的訓練無異。
在各process梯度計算完成之後需要將梯度進行匯總平均,然後再由 rank=0 的進程,將其 broadcast 到所有進程,接著各進程用該梯度來獨立的更新參數。
由於各進程中的模型,初始參數一致 (初始時刻進行一次 broadcast),而每次用於更新參數的梯度也一致,因此各進程的模型參數始終保持一致。
相較於 DataParallel,torch.distributed 傳輸的數據量更少,因此速度更快,效率更高。