
PyTorch distributed: RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0])


Error message: RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0]) but found one of them on device: cuda:1

Solution: switching from DataParallel (DP) to DistributedDataParallel (DDP) fixed it. The error occurs because DataParallel replicates the model from device_ids[0] (cuda:0 by default), so every parameter and buffer must already be on that device; here part of the model was sitting on cuda:1.
Code that raised the error:

# DataParallel (DP) mode; module_a is a classifier
module_a = torch.nn.DataParallel(module_a)
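
If you prefer to keep DataParallel, the same error normally goes away once the whole model is moved to device_ids[0] before wrapping. A minimal sketch, assuming a toy classifier in place of the real module_a:

import torch
import torch.nn as nn

# hypothetical classifier standing in for module_a
module_a = nn.Linear(128, 10)

# DataParallel replicates from device_ids[0], so the model must live there first
module_a = module_a.to("cuda:0")
module_a = nn.DataParallel(module_a, device_ids=[0, 1])

x = torch.randn(32, 128).to("cuda:0")  # inputs also start on device_ids[0]
logits = module_a(x)                   # DP scatters the batch across cuda:0 and cuda:1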

Code after the change:

local_rank = 0  # rank of the GPU owned by this process
# DDP mode; module_a should already be on cuda:local_rank and the process group must be initialized
module_a = torch.nn.parallel.DistributedDataParallel(module=module_a, broadcast_buffers=False, device_ids=[local_rank])
module_a.train()
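
The DDP wrapper above only works inside an initialized process group, with one process per GPU. A minimal single-node sketch of the surrounding setup, assuming a placeholder classifier and a launch via torchrun --nproc_per_node=2 train_ddp.py (the file name is illustrative):

import os
import torch
import torch.distributed as dist
import torch.nn as nn

# torchrun sets LOCAL_RANK for each spawned process
local_rank = int(os.environ["LOCAL_RANK"])
dist.init_process_group(backend="nccl")
torch.cuda.set_device(local_rank)

# placeholder classifier standing in for module_a
module_a = nn.Linear(128, 10).to(local_rank)
module_a = torch.nn.parallel.DistributedDataParallel(
    module=module_a, broadcast_buffers=False, device_ids=[local_rank]
)
module_a.train()

Unlike DP, each DDP process owns exactly one GPU, so there is no single device_ids[0] that the whole model must sit on, which is why the original error no longer appears.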