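einops expresses tensor operations as patterns over named axes, so the code doubles as documentation and replaces chained calls to reshape, transpose, squeeze, and similar functions.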

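Key advantages

Readable: rearrange(x, 'b c h w -> b h w c') documents itself
Reliable: axis lengths are checked against the pattern automatically
Unified interface: replaces reshape/transpose/squeeze/stack
Multi-framework: NumPy, PyTorch, TensorFlow, JAX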

Installation

pip install einops

rearrange: rearranging axes

import numpy as np
from einops import rearrange

x = np.random.randn(2, 32, 32, 3)  # NHWC

# Transpose: NHWC -> NCHW
y = rearrange(x, 'b h w c -> b c h w')  # (2, 3, 32, 32)

# Flatten: merge h and w into one axis
y = rearrange(x, 'b h w c -> b (h w) c')  # (2, 1024, 3)

# Decompose: split h = 32 into h1 = 2 and the inferred h2 = 16
y = rearrange(x, 'b (h1 h2) w c -> b h1 h2 w c', h1=2)  # (2, 2, 16, 32, 3)

# Image patches (ViT)
y = rearrange(x, 'b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=8, p2=8)
# (2, 16, 192): 16 patches, each 8x8x3 = 192 values

reduce: aggregating over axes

from einops import reduce
 
# Global average pooling
y = reduce(x, 'b h w c -> b c', 'mean')  # (2, 3)

# 2x2 max pooling
y = reduce(x, 'b (h h2) (w w2) c -> b h w c', 'max', h2=2, w2=2)  # (2, 16, 16, 3)

# Sum over the batch axis
y = reduce(x, 'b h w c -> h w c', 'sum')  # (32, 32, 3)

repeat: repeating tensors

from einops import repeat
 
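# Repeat along a new axis
y = repeat(x, 'b h w c -> g b h w c', g=3)  # (3, 2, 32, 32, 3)

# Repeat along an existing axis: each channel is duplicated in place, like np.repeat
y = repeat(x, 'b h w c -> b h w (c r)', r=2)  # (2, 32, 32, 6)

# Like np.tile: the repeat axes go before h and w, so the whole image is copied
y = repeat(x, 'b h w c -> b (rh h) (rw w) c', rh=2, rw=3)  # (2, 64, 96, 3)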

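Pattern syntax

Spaces separate axes: b c h w denotes four axes
Parentheses group axes: (h w) merges h and w into one axis
Decomposition: (h1 h2) plus the keyword argument h1=2; at most one length per group can be left for einops to infer
Ellipsis: ... stands for any number of axes
# Transpose only the last two axes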
y = rearrange(x, '... a b -> ... b a')

In practice: ViT patch embedding

import torch
from einops import rearrange
 
x = torch.randn(4, 3, 224, 224)
patch_size = 16
 
patches = rearrange(x,
    'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
    p1=patch_size, p2=patch_size)
 
print(patches.shape)  # torch.Size([4, 196, 768])
# 196 = 14x14 patches, 768 = 16x16x3 values per patch

Function comparison

rearrange: reshape, transpose, flatten (replaces reshape + transpose + flatten)
reduce: aggregation (replaces mean/sum/max + reshape)
repeat: repetition (replaces repeat/tile/expand)