使用爱因斯坦标记法操作张量

Last modified

核心操作

以下示例均基于 einops 库进行演示。

einops: Deep learning operations reinvented (for pytorch, tensorflow, jax and others)

维度变换 einops.rearrange

更多示例见 docstring

import torch
import einops

# Example 1: reorder axes
x = torch.randn(2, 3, 4, 5)
# Plain torch: one axis swap per step
o1 = x.transpose(1, 2)   # [B H W C] -> [B W H C]
o1 = o1.transpose(2, 3)  # [B W H C] -> [B W C H]
o1 = o1.transpose(0, 1)  # [B W C H] -> [W B C H]
# The same permutation written as a single pattern
o2 = einops.rearrange(x, 'B H W C -> W B C H')
o3 = torch.einsum('BHWC->WBCH', x)  # the equivalent provided by torch itself
assert torch.allclose(o1, o2)
assert torch.allclose(o2, o3)

# Example 2: flatten (merge two axes into one)
x = torch.randn(2, 3, 4, 5)
# Plain torch: bring W next to B, then merge them
o1 = x.transpose(1, 2)    # [B H W C] -> [B W H C]
o1 = o1.reshape(8, 3, 5)  # [B W H C] -> [B*W H C], with B*W = 2*4
# einops: the parenthesised group (B W) expresses the merge directly
o2 = einops.rearrange(x, 'B H W C -> (B W) H C')
assert torch.allclose(o1, o2)

# Example 3: split (break one axis into two)
x = torch.randn(6, 4, 5)
# Plain torch: split the leading axis, then move W behind H
o1 = x.reshape(2, 3, 4, 5)  # [B*W H C] -> [B W H C]
o1 = o1.transpose(1, 2)     # [B W H C] -> [B H W C]
# einops: B=2 is given explicitly, W is inferred from the size of the first axis
o2 = einops.rearrange(x, '(B W) H C -> B H W C', B=2)
assert torch.allclose(o1, o2)

# Example 4: unsqueeze and squeeze
x = torch.randn(2, 4, 3)
# Plain torch: insert the two size-1 axes one at a time
o1 = x.unsqueeze(-1)  # [B L C] -> [B L C 1]
o1 = o1.unsqueeze(1)  # [B L C 1] -> [B 1 L C 1]
# einops: a literal 1 in the pattern adds (or removes) a size-1 axis
o2 = einops.rearrange(x, 'B L C -> B 1 L C 1')
o3 = einops.rearrange(o2, 'B 1 L C 1 -> B L C')  # squeeze back to the input layout
assert torch.allclose(o1, o2)
assert torch.allclose(o3, x)

# Example 5: concat
x = torch.randn(2, 4, 3)
xs = [x] * 3  # a list holding the same tensor three times
# N tensors of [B L C] -> one tensor of [N*B L C]
o1 = torch.cat(xs, dim=0)
# einops treats the list itself as an extra leading axis N
o2 = einops.rearrange(xs, 'N B L C -> (N B) L C')
assert torch.allclose(o1, o2)

# Example 6: stack
x = torch.randn(4, 4, 3)
xs = [x] * 3  # a list holding the same tensor three times
# B tensors of [H W C] -> one tensor of [B H W C]
o1 = torch.stack(xs, dim=0)
# einops: the list is the leading axis B, so an identity pattern stacks it
o2 = einops.rearrange(xs, 'B H W C -> B H W C')
assert torch.allclose(o1, o2)

综合示例:image2patch

# Cut each 8x8 image into 2x2 patches and flatten every patch into one vector.
x = torch.randn(2, 3, 8, 8)  # [B I H W]
patch_pattern = 'B I (H P1) (W P2) -> B (H W) (P1 P2 I)'
# Result is [B N P]: N = H*W = 4*4 patches, P = P1*P2*I = 2*2*3 values per patch
o = einops.rearrange(x, patch_pattern, P1=2, P2=2)
print(o.shape)  # torch.Size([2, 16, 12])

复制 einops.repeat

更多示例见 docstring

# Example 1: repeat along a new axis
x = torch.randn(2, 3, 4)
# Plain torch: add a trailing size-1 axis, then tile it three times
o1 = x.unsqueeze(-1)          # [B H W] -> [B H W 1]
o1 = o1.tile((1, 1, 1, 3))    # [B H W 1] -> [B H W 3]
# einops: name the new axis C and give its length
o2 = einops.repeat(x, 'B H W -> B H W C', C=3)
assert torch.allclose(o1, o2)

# Example 2: upsampling
x = torch.randn(2, 3)
# Plain torch: tile the whole tensor 2x along the first axis and 3x along the second
o1 = torch.tile(x, (2, 3))
# einops: the literal 2 and 3 are the repeat counts placed in front of B and C
o2 = einops.repeat(x, 'B C -> (2 B) (3 C)')
assert torch.allclose(o1, o2)

归约 einops.reduce

更多示例见 docstring

  • 支持的归约方法:('min', 'max', 'sum', 'mean', 'prod')

import torch
import einops

# Example 1: 1-D pooling (downsampling)
x = torch.randn(2, 3, 4, 5)  # [B H W C]
# Mean over the last axis; the axis is dropped from the result
o1 = torch.mean(x, dim=-1)
o2 = einops.reduce(x, 'B H W C -> B H W', reduction='mean')
assert torch.allclose(o1, o2)
# Max over the last axis, kept as a size-1 axis
o1 = x.max(dim=-1, keepdim=True).values
o2 = einops.reduce(x, 'B H W C -> B H W 1', reduction='max')
o3 = einops.reduce(x, 'B H W C -> B H W ()', reduction='max')  # equivalent spelling of the pattern above
assert torch.allclose(o1, o2)
assert torch.allclose(o1, o3)

# Example 2: 2-D pooling
x = torch.randn(20, 30, 40, 40)
# Global average pooling: reduce H and W but keep them as size-1 axes
o1 = x.mean((-2, -1), keepdim=True)
o2 = einops.reduce(x, 'B C H W -> B C 1 1', reduction='mean')
assert torch.allclose(o1, o2)
# 2x2 max pooling with stride 2. Only `torch` is imported in this snippet, so
# the layer is reached through `torch.nn`; a bare `nn` would raise NameError.
o1 = torch.nn.MaxPool2d(kernel_size=2, stride=2)(x)
o2 = einops.reduce(x, 'B C (H S1) (W S2) -> B C H W', reduction='max', S1=2, S2=2)
assert torch.allclose(o1, o2)
# MaxPool2d also supports kernel_size != stride, which einops.reduce cannot do

Last updated