张芷铭的个人博客

PyTorch 提供丰富的张量维度操作函数,涵盖形状调整、维度变换、拼接分割等场景。

维度调整

函数功能示例
squeeze()去除尺寸为 1 的维度(1,3,1,5) → (3,5)
unsqueeze(dim)插入尺寸为 1 的维度(3,5) → (1,3,5)
view(*shape)重塑形状(2,3,4) → (3,8)
reshape(*shape)重塑形状(返回新张量)同 view
flatten()展平(2,3,4) → (24)

维度变换

函数功能示例
permute(*dims)重排维度顺序(2,3,5) → (5,2,3)
transpose(dim0, dim1)交换两个维度(2,3) → (3,2)

拼接与堆叠

1
2
3
4
5
# 沿现有维度拼接
torch.cat([t1, t2], dim=0)  # (2,3) + (2,3) → (4,3)

# 沿新维度堆叠
torch.stack([t1, t2], dim=0)  # (3,4) + (3,4) → (2,3,4)

分割与选择

函数功能
split(tensor, size, dim)按块大小分割
chunk(tensor, chunks, dim)按块数量分割
index_select(tensor, dim, index)按索引选择

示例

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
import torch

x = torch.randn(2, 3, 4)

# 展平
x.flatten()                    # (24,)
x.flatten(start_dim=1)         # (2, 12)

# 转置
x.permute(2, 0, 1)             # (4, 2, 3)
x.transpose(0, 1)              # (3, 2, 4)

# view vs reshape
x.view(3, 8)                   # 共享内存
x.reshape(3, 8)                # 可能返回副本

# 广播
a = torch.randn(3, 1)
b = torch.randn(3, 4)
torch.broadcast_tensors(a, b)  # a 广播为 (3, 4)

view vs reshape

  • view:要求张量连续,共享内存
  • reshape:自动处理不连续情况,可能返回副本

Comments