当前位置: 代码迷 >> 综合 >> pytorch转onnx模型多输入问题(如:Bert)
  详细解决方案

pytorch转onnx模型多输入问题(如:Bert)

热度:10   发布时间:2023-12-03 11:42:49.0

举个例子:
Bert模型有三个输入,因此就要创建三个dummy_input,然后利用一个tuple,传入函数中。

dummy_input0 = torch.LongTensor(Batch_size, seg_length).to(torch.device("cuda"))
dummy_input1 = torch.LongTensor(Batch_size, seg_length).to(torch.device("cuda"))
dummy_input2 = torch.LongTensor(Batch_size, seg_length).to(torch.device("cuda"))
torch.onnx.export(model. (dummy_input0, dummy_input1, dummy_input2), filepath)