当前位置: 代码迷 >> 综合 >> Torch load载入数据集的标准流程
  详细解决方案

Torch load载入数据集的标准流程

热度:49   发布时间:2023-12-24 13:39:33.0

不管是图像识别还是自然语言处理任务,在训练时都需要将数据按照batch的方式载入。本文记录了用Torch load数据的流程

  •  第一步当然是先有数据的原始文件,图像处理就是一堆图片,NLP就是很多句子。这个原始文件保存成任何数据格式都没关系,因为之后会先处理成Torch.utils.data.Dataset的一个子类。
  •  转化数据为Torch.utils.data.Dataset,代码如下:
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSamplerclass myDataset(Dataset):  # 这是一个Dataset子类def __init__(self):self.Data = np.asarray([[2, 2], [1, 4], [4, 1], [12, 4], [8, 5]])  self.Label = np.asarray([4, 2, 2, 4, 2]) def __getitem__(self, index):x = torch.from_numpy(self.Data[index])y = (self.Label[index])return x, y  def __len__(self):return len(self.Data)mydata = myDataset()

这个时候可以用下标或者len访问数据集的某个元素或者长度

  • 定义一个sampler,也就是数据被返回的规则。这个不一定需要,如果定义这个,在下一步的shuffle参数需要改为false

rand_sampler = RandomSampler(mydata)
  •  定义Dataloader 
data_loader = DataLoader(mydata,batch_size=2,shuffle=False,sampler=rand_sampler,num_workers=4)
  • 然后就可以用循环访问出一个batch一个batch的数据了。
for i,traindata in enumerate(data_loader):print('i:',i)Data,Label=traindataprint('data:',Data)print('Label:',Label)

  相关解决方案