当前位置: 代码迷 >> 综合 >> 关于torch.max(a,dim)中维度的选取
  详细解决方案

关于torch.max(a,dim)中维度的选取

热度:47   发布时间:2024-02-29 12:22:05.0

Pytorch中,经常会使用torch.max(a,dim)对tensor进行处理,特别是针对多维的tensor,就感觉对dim的选取似懂非懂。

一、针对1维的数据

这个比较好理解,就是针对1维的数据取最大值,返回一个tensor类型的数值,和该数值对应的下标,合起来就是一个tuple类型。

import torcha = torch.randn(3) #随机生成数组
max=torch.max(a,dim=0) #默认dim=0
print("a:\n", a)
print("************************************************")
print("max(a):", max) #输出最大值,以及对应的索引,tuple类型
print("max(a)_value:", max[0]) #只返回tensor数值
print("max(a)_index:", max[1]) #只返回对应的索引<<
a:tensor([ 1.5691, -0.7801, -1.2262])
************************************************
max(a): (tensor(1.5691), tensor(0))
max(a)_value: tensor(1.5691)
max(a)_index: tensor(0)

二、针对2维数据

此时的tensor(行,列),可以理解为一张特征图,dim=0就是行间进行比较,dim=1就是列间进行比较

import torcha = torch.randn(2,3) #随机生成数组
max_0=torch.max(a,dim=0) #针对第1个元素“2”,对应的是行
max_1=torch.max(a,dim=1) #针对第2个元素“3”,对应的是列
print("a:\n", a)
print("************************************************")
print("max(a)_0:", max_0)  #dim=0,行与行之间进行比较,所以返回每一列的最大值
print("max(a)_1:", max_1)  #dim=1,列与列之间进行比较,所以返回每一行的最大值<<
a:tensor([[ 0.1734, -0.7264,  0.6981],[ 0.0859,  1.2663, -0.0851]])
************************************************
max(a)_0: (tensor([ 0.1734,  1.2663,  0.6981]), tensor([ 0,  1,  0]))
max(a)_1: (tensor([ 0.6981,  1.2663]), tensor([ 2,  1]))

三、针对3维数据

此时的tensor(通道,行,列),可以理解为很多张特征图叠加在一起,dim=0就是通道间进行比较,dim=1就是行间进行比较,dim=2就是列间进行比较。

import torcha = torch.randn(2,3,4) #随机生成数组
max_0=torch.max(a,dim=0) #针对第1个元素“2”,对应的是通道
max_1=torch.max(a,dim=1) #针对第2个元素“3”,对应的是行
max_2=torch.max(a,dim=2) #针对第2个元素“4”,对应的是列
print("a:\n", a)
print("************************************************")
print("max(a)_0:", max_0)  #dim=0,通道间进行比较,所以返回每一张特征图,同一像素位置上的最大值
print("max(a)_1:", max_1)  #dim=1,行与行之间进行比较,所以返回每一张特征图,每一列的最大值
print("max(a)_2:", max_1)  #dim=1,列与列之间进行比较,所以返回每一张特征图,每一行的最大值<<
a:tensor([[[ 0.5323,  1.5229, -0.6122,  0.6054],[ 1.2424, -1.6005,  0.0779,  0.9227],[-0.6340, -0.5770, -0.1672,  0.3598]],[[-0.3770, -0.4992,  1.8444, -1.1040],[ 1.2238,  0.7283, -1.6462,  0.0325],[-0.3555, -0.2599,  1.5741,  1.0683]]])
************************************************
max(a)_0: (tensor([[ 0.5323,  1.5229,  1.8444,  0.6054],[ 1.2424,  0.7283,  0.0779,  0.9227],[-0.3555, -0.2599,  1.5741,  1.0683]]), tensor([[ 0,  0,  1,  0],[ 0,  1,  0,  0],[ 1,  1,  1,  1]]))
max(a)_1: (tensor([[ 1.2424,  1.5229,  0.0779,  0.9227],[ 1.2238,  0.7283,  1.8444,  1.0683]]), tensor([[ 1,  0,  1,  1],[ 1,  1,  0,  2]]))
max(a)_2: (tensor([[ 1.2424,  1.5229,  0.0779,  0.9227],[ 1.2238,  0.7283,  1.8444,  1.0683]]), tensor([[ 1,  0,  1,  1],[ 1,  1,  0,  2]]))

四、针对4维数据

此时的tensor(batch_size,channel, 行,列),可以理解为一个批次的训练数据的集合,dim=0,是批次间的比较;dim=1,是每个批次,自己通道间的比较;dim=2对应的行比较;dim=3对应的是列比较

import torcha = torch.randn(1,2,3,4) #随机生成数组
max_0=torch.max(a,dim=0) #针对第1个元素“1”,对应的是batch_size
max_1=torch.max(a,dim=1) #针对第2个元素“2”,对应的是通道
max_2=torch.max(a,dim=2) #针对第2个元素“3”,对应的是行
max_3=torch.max(a,dim=3) #针对第2个元素“4”,对应的是列print("a:\n", a)
print("************************************************")print("max(a)_0:", max_0)  #dim=0,多个输入之间进行比较,返回的是每次输入时,每一张特征图,同一像素位置上的最大值#一般为1,一张张读的
print("max(a)_1:", max_1)  #dim=1,通道间进行比较,所以返回每一张特征图,同一像素位置上的最大值
print("max(a)_2:", max_2)  #dim=2,行与行之间进行比较,所以返回每一张特征图,每一列的最大值
print("max(a)_3:", max_3)  #dim=3,列与列之间进行比较,所以返回每一张特征图,每一行的最大值<<a:tensor([[[[ 0.6404,  0.5116, -0.5562,  2.2283],[-0.6507,  0.4440,  0.8723, -0.6538],[ 0.0352,  1.0738,  0.2382,  0.7763]],[[-0.5208,  0.4854, -0.0950,  1.3100],[ 0.0433, -0.6561,  0.1956, -0.3584],[-1.0031, -1.7104,  0.6768, -0.1648]]]])
************************************************
max(a)_0: (tensor([[[ 0.6404,  0.5116, -0.5562,  2.2283],[-0.6507,  0.4440,  0.8723, -0.6538],[ 0.0352,  1.0738,  0.2382,  0.7763]],[[-0.5208,  0.4854, -0.0950,  1.3100],[ 0.0433, -0.6561,  0.1956, -0.3584],[-1.0031, -1.7104,  0.6768, -0.1648]]]), tensor([[[ 0,  0,  0,  0],[ 0,  0,  0,  0],[ 0,  0,  0,  0]],[[ 0,  0,  0,  0],[ 0,  0,  0,  0],[ 0,  0,  0,  0]]]))
max(a)_1: (tensor([[[ 0.6404,  0.5116, -0.0950,  2.2283],[ 0.0433,  0.4440,  0.8723, -0.3584],[ 0.0352,  1.0738,  0.6768,  0.7763]]]), tensor([[[ 0,  0,  1,  0],[ 1,  0,  0,  1],[ 0,  0,  1,  0]]]))
max(a)_2: (tensor([[[ 0.6404,  1.0738,  0.8723,  2.2283],[ 0.0433,  0.4854,  0.6768,  1.3100]]]), tensor([[[ 0,  2,  1,  0],[ 1,  0,  2,  0]]]))
max(a)_3: (tensor([[[ 2.2283,  0.8723,  1.0738],[ 1.3100,  0.1956,  0.6768]]]), tensor([[[ 3,  2,  1],[ 3,  2,  2]]]))