当前位置: 代码迷 >> 综合 >> einops库的rearrange、repeat、reduce 表达式怎么写
  详细解决方案

einops库的rearrange、repeat、reduce 表达式怎么写

热度:126   发布时间:2023-10-11 12:44:43.0

einops库的rearrange、repeat、reduce 表达式怎么写

最近发现一个B格很高的库。einops库。这个库主要是用来对张量进行形态操作的。网上查了一下,说是灵感来源于爱因斯坦求和约定。然后去查了一下求和约定,也没看出来有啥关联。网上查了很多资料,还是看不太懂这个表达式是怎么写的。弄了好久弄明白了,现在来写一下这玩意怎么写的。

这玩意主要有三个操作,rearrange,repeat和reduce。分别对应重新排列形状(reshape),复制和缩减这三种功能。这些函数不仅仅可以用来对numpy进行处理,还可以直接再深度学习中作为torch的一个模块来使用(具体可以百度einops.layers.torch),可以说是非常的方便。下面给个torch中使用的例子。

self.to_patch_embedding = nn.Sequential(Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),nn.Linear(16*16*16, dim),)

这句话是直接在ViT网络中的第一层里面取出来的,可以看到它确实非常的方便。
那么这些函数里面的表达式是怎么写的呢。我就以上面那句话为例。

Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width)

作为transformer输入的第一层,它并没有任何训练参数,目的只是为了实现输入数据的变形。表达式”->“的左边是原始的输入大小的维度。b标识batchsize,c表示图像的channel数(这里是3,因为是RGB3通道)。括号括起来的表示一个维度的分解表达。比如说输入是8×3×256×512的张量(为了方便阅读,这里选择了256和512不同的长宽)。其中b就是表示8,c表示3,(h p1)表示256,(w p2)表示512。h 和p1相乘等于256,w和p2相乘等于512.

但是这里两个数相乘等于256的情况有无数种,那函数怎么知道h和p1分别取多少呢?所有在函数的后面传入了p1,p2的值。他们分别是patch_height(32)和patch_width(32).这样函数就可以根据p1和p2的值,自动的去顶h和w。就是说即使不告诉函数h和w等于多少,他也能自己算出来。

->后面的表示式表示想要输出的形状。上面说了输入是8×3×256×512的张量时b c (h p1) (w p2) 换算成数字的话就是 8 3 (32 8) (32 16) ,而输出是b (h w) (p1 p2 c) 换算成数字就是 8 (8 16) (32 32 3) 。之前说了,括号里面的所有值相乘,等于一个维度。那么进一步的缩减,就是输入是8 3 256 512输出是8 8*16 32*32*3 总的来说就是把输入的3通道图像,切割成816格3232*3的一维向量。因为transformer就是以一维向量输入的。

可以发现这个效率高的惊人,用一句话就能实现这么复杂的变换!

既然看懂了这句话,那么其他变形也很容易看懂了。举几个例子。

repeat(image, 'h w c-> (repeat h) w c', repeat=2)

这里定义了repeat=2,意思就是把图像沿着高方向扩展两倍,也就是下面这两格图像的变形:
einops库的rearrange、repeat、reduce 表达式怎么写
那这里就有人问了,为什么(repeat h)是复制两份,而不是直接把图像扩大成两倍呢。这里是因为repeat在h的前面,就是说有2个h,而不是把h扩大成两倍。那么只要换一下repeat和h的位置,就可以实现扩大两倍的效果。如下:

repeat(image, 'h w c-> (h repeat) w c', repeat=2)

einops库的rearrange、repeat、reduce 表达式怎么写
可以看到,这里的高从930变成了1860.这也是我们在使用的时候需要注意的点,拆分和重组的时候要注意字母的顺序。

具体还有啥用法大家可以去百度,上面有很多例子,这里就不一一讲解了。说白了这个库就是实现高效的维度变化的,学会了逼格贼高,而且超级方便!

  相关解决方案