Personal blog of 张芷铭 (Zhang Zhiming)

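einops expresses tensor operations by naming each dimension in a pattern string, so the code doubles as its own documentation. A single call replaces chains of reshape, transpose, squeeze, and similar operations.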

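Key advantages

Advantage | Description
--- | ---
Readability | rearrange(x, 'b c h w -> b h w c') documents itself
Reliability | dimension consistency is checked automatically
Unified interface | replaces reshape/transpose/squeeze/stack
Multi-framework support | NumPy, PyTorch, TensorFlow, JAX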

Installation

pip install einops

rearrange: reorder dimensions

import numpy as np
from einops import rearrange

x = np.random.randn(2, 32, 32, 3)  # NHWC

# Transpose: NHWC -> NCHW
y = rearrange(x, 'b h w c -> b c h w')  # (2, 3, 32, 32)

# Flatten the spatial axes
y = rearrange(x, 'b h w c -> b (h w) c')  # (2, 1024, 3)

# Decompose (split) an axis: h = 32 becomes h1=2, h2=16
y = rearrange(x, 'b (h1 h2) w c -> b h1 h2 w c', h1=2)  # (2, 2, 16, 32, 3)

# Image patches (ViT)
y = rearrange(x, 'b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=8, p2=8)
# (2, 16, 192): 16 patches, each 8x8x3 = 192 values

reduce: aggregate dimensions

import numpy as np
from einops import reduce

x = np.random.randn(2, 32, 32, 3)  # NHWC, as above

# Global average pooling
y = reduce(x, 'b h w c -> b c', 'mean')  # (2, 3)

# 2x2 max pooling
y = reduce(x, 'b (h h2) (w w2) c -> b h w c', 'max', h2=2, w2=2)  # (2, 16, 16, 3)

# Sum over the batch
y = reduce(x, 'b h w c -> h w c', 'sum')  # (32, 32, 3)

repeat: repeat tensors

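import numpy as np
from einops import repeat

x = np.random.randn(2, 32, 32, 3)  # NHWC, as above

# Repeat along a new axis
y = repeat(x, 'b h w c -> g b h w c', g=3)  # (3, 2, 32, 32, 3)

# Repeat along an existing axis: each channel is duplicated in place
y = repeat(x, 'b h w c -> b h w (c repeat)', repeat=2)  # (2, 32, 32, 6)

# Tile-like: the repeat factor comes first, so the whole image is copied
# ((h rh) would instead duplicate each pixel, i.e. nearest-neighbour upsampling)
y = repeat(x, 'b h w c -> b (rh h) (rw w) c', rh=2, rw=3)  # (2, 64, 96, 3)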

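Pattern syntax

Syntax | Meaning
--- | ---
space-separated names | b c h w denotes four axes
() | composes axes: (h w) merges h and w into a single axis
decomposition | (h1 h2) on the input side plus a keyword argument such as h1=2; the other length is inferred
... | any number of axes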
# Transpose only the last two axes
y = rearrange(x, '... a b -> ... b a')  # (2, 32, 32, 3) -> (2, 32, 3, 32)

In practice: ViT patch embedding

import torch
from einops import rearrange

x = torch.randn(4, 3, 224, 224)  # NCHW images
patch_size = 16

patches = rearrange(x,
    'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
    p1=patch_size, p2=patch_size)

print(patches.shape)  # torch.Size([4, 196, 768])
# 196 = 14x14 patches, 768 = 16x16x3 values per patch

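Function comparison

Function | Purpose | Replaces
--- | --- | ---
rearrange | reshape, transpose, flatten | reshape + transpose + flatten
reduce | aggregation | mean/sum/max + reshape
repeat | repetition | repeat/tile/expand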
