当前位置: 代码迷 >> 综合 >> pytorch结构函数tensor.view()和terson.reshape(),以及contiguous()存在和正解
  详细解决方案

pytorch结构函数tensor.view()和terson.reshape(),以及contiguous()存在和正解

热度:9   发布时间:2024-02-10 09:00:44.0

前要

这两个函数都是用来改变tensor的形状的,但是他们两是有内存共享的区别,因为这个区别。

三个提前说的注意点:

  1. x是个tensorx += 1x= x+1pytorch中是有区别的, 后者会重开内存地址。
  2. 学过面向对象语言都知道,对象和数据的地址是分开的,比如x = [1,2,3],对象是x,数据是[1,2,3]
  3. id()函数只能看对象内存地址,storage()函数可以看数据内存地址。

1. 文档解释

view()

# 使用方法
x_tensor.view(*args) → Tensor
  1. 返回一个有相同数据结构不同的tensor。
  2. 返回的tensor必须有与原tensor相同的数据和相同数目的元素,但可以有不同的大小。
  3. 一个tensor必须是连续的contiguous()才能被view()。

参数:

  • x_tensor: 想改变结构的tensor
  • args: 目标结构

例:

x = torch.randn(4, 4) 
x.view(2,8) 

reshape()

reshape和上面的是一样的使用方法,但是是对于不连续的数据可以改变形状。

2. 区别

view()reshape()有两点以下区别。

  1. view()是内存共享的,reshape()不是共享的,可以理解为reshape获取的是真实副本而不是改变观察视角,然而实际上并不会这样使用它,因为reshape通常获取不到原数据的拷贝(改变了data的内存地址),所以一般用clone()克隆后再view
x = torch.randint(0,20,(2,3))
''' tensor([[ 7., 2., 16.],[17., 19., 15.]]) '''
y = x.view(6)
''' tensor([ 7., 2., 16., 17., 19., 15.]) '''
z = x.view(-1,2)
''' tensor([[ 7., 2.],[16., 17.],[19., 15.]]) '''
##########注意区别##############
# 看看对象的地址是否一样:
print(id(x)==id(y), id(x)==id(z), id(y)==id(z))
print(id(x.data)==id(y.data), id(x.data)==id(z.data),id(y.data)==id(z.data))
''' False False False True True True '''
# 这时候我们改变x
x[0] = 1
''' 都改变了 tensor([[ 1., 1., 1.],[17., 19., 15.]]) tensor([ 1., 1., 1., 17., 19., 15.]) tensor([[ 1., 1.],[ 1., 17.],[19., 15.]]) '''
# 改变z呢?
z[0] = 0
''' # 也都改变了 tensor([[ 0., 0., 1.],[17., 19., 15.]]) tensor([ 0., 0., 1., 17., 19., 15.]) tensor([[ 0., 0.],[ 1., 17.],[19., 15.]]) '''
但是!!!要注意我开头说到的第一点
x  = x + 1
''' tensor([[ 1., 1., 2.],[18., 20., 16.]]) tensor([ 0., 0., 1., 17., 19., 15.]) tensor([[ 0., 0.],[ 1., 17.],[19., 15.]]) '''
会发现只有x改变了,其他没变,
这是因为这样会开辟新的内存空间
所以啊,写论文尽量用x = x + ?这个操作
'你如果把x=x +1改成x+=1会发现y,z输出没有变,只有x变了'
  1. 对于不连续的数据,reshape()可以改变,例子如下:
x = torch.rand(3,4)
x = x.transpose(0,1)
print(x.is_contiguous())
'False'
# 会发现
x.view(3,4)
''' RuntimeError: invalid argument 2: view size is not compatible with input tensor's.... 就是不连续导致的 '''
# 但是这样是可以的。
x = x.contiguous()
x.view(3,4)

我们再看看reshape()

x = torch.rand(3,4)
x = x.permute(1,0) # 等价x = x.transpose(0,1)
x.reshape(3,4)
'''这就不报错了 说明x.reshape(3,4) 这个操作 等于x = x.contiguous().view() 尽管如此,但是我们还是不推荐使用reshape 除非为了获取完全不同但是数据相同的克隆体 '''

3.关于连续contiguous()

目前pytorch只要是transpose()permute()这两个函数用过后,tensor都会变得不再连续,就不可以使用view().

调用contiguous()时,会强制拷贝一份tensor,让它的布局和从头创建的一毛一样。
(这一段看文字你肯定不理解,你也可以不用理解,有空我会画图补上)

你只需要记住了,使用view()之前,只要使用了transpose()permute()这两个函数一定要contiguous()

  相关解决方案