Error: RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0]) but found one of them on device: cuda:1
Solution: switched from DP to DDP. The error means DataParallel expected all parameters and buffers on device_ids[0] (cuda:0), but some of them were on cuda:1; moving the module to cuda:0 before wrapping would also work, but DistributedDataParallel is the recommended replacement for DataParallel anyway.
Code that raised the error:
# DP mode; module_a is a classifier
module_a = torch.nn.DataParallel(module_a)
Fixed code:
# DDP mode; requires torch.distributed.init_process_group() to have been
# called first, and module_a to already be on the device for local_rank
local_rank = 0  # in a multi-process launch, read this from the launcher (e.g. the LOCAL_RANK env var) instead of hardcoding
module_a = torch.nn.parallel.DistributedDataParallel(module=module_a, broadcast_buffers=False, device_ids=[local_rank])
module_a.train()
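For context, a minimal end-to-end sketch of the DDP setup the snippet above assumes. This is a single-process illustration (world_size=1, gloo backend, with a stand-in `torch.nn.Linear` classifier); in a real multi-GPU run the rank, world size, and env vars come from a launcher such as torchrun, not from hardcoded values.

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# In a real launch these are set by torchrun; hardcoded here for illustration.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

# The process group must be initialized before any module is wrapped in DDP.
dist.init_process_group(backend="gloo", rank=0, world_size=1)

local_rank = int(os.environ.get("LOCAL_RANK", 0))

module_a = torch.nn.Linear(16, 2)  # stand-in for the classifier in the post

if torch.cuda.is_available():
    # Move the module to its own GPU *before* wrapping, so all parameters
    # and buffers live on the device named in device_ids.
    torch.cuda.set_device(local_rank)
    module_a = module_a.to(local_rank)
    module_a = DDP(module_a, device_ids=[local_rank], broadcast_buffers=False)
else:
    # CPU fallback: DDP without device_ids
    module_a = DDP(module_a, broadcast_buffers=False)

module_a.train()
out = module_a(torch.randn(4, 16))
print(out.shape)  # torch.Size([4, 2])

dist.destroy_process_group()
```

Unlike DP, DDP never scatters the module from one device to the others at forward time; each process owns one replica on one device, which is why the "found one of them on device: cuda:1" mismatch cannot occur.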