PyTorch 提供丰富的张量维度操作函数,涵盖形状调整、维度变换、拼接分割等场景。
维度调整
| 函数 | 功能 | 示例 |
|---|
squeeze() | 去除尺寸为 1 的维度 | (1,3,1,5) → (3,5) |
unsqueeze(dim) | 插入尺寸为 1 的维度 | (3,5) → (1,3,5) |
view(*shape) | 重塑形状 | (2,3,4) → (3,8) |
reshape(*shape) | 重塑形状(返回新张量) | 同 view |
flatten() | 展平 | (2,3,4) → (24) |
维度变换
| 函数 | 功能 | 示例 |
|---|
permute(*dims) | 重排维度顺序 | (2,3,5) → (5,2,3) |
transpose(dim0, dim1) | 交换两个维度 | (2,3) → (3,2) |
拼接与堆叠
# 沿现有维度拼接
torch.cat([t1, t2], dim=0) # (2,3) + (2,3) → (4,3)
# 沿新维度堆叠
torch.stack([t1, t2], dim=0) # (3,4) + (3,4) → (2,3,4)
分割与选择
| 函数 | 功能 |
|---|
split(tensor, size, dim) | 按块大小分割 |
chunk(tensor, chunks, dim) | 按块数量分割 |
index_select(tensor, dim, index) | 按索引选择 |
示例
import torch
x = torch.randn(2, 3, 4)
# 展平
x.flatten() # (24,)
x.flatten(start_dim=1) # (2, 12)
# 转置
x.permute(2, 0, 1) # (4, 2, 3)
x.transpose(0, 1) # (3, 2, 4)
# view vs reshape
x.view(3, 8) # 共享内存
x.reshape(3, 8) # 可能返回副本
# 广播
a = torch.randn(3, 1)
b = torch.randn(3, 4)
torch.broadcast_tensors(a, b) # a 广播为 (3, 4)
view vs reshape
view:要求张量连续,共享内存
reshape:自动处理不连续情况,可能返回副本