Last updated
Last updated
以下示例均基于 einops 库进行演示。
einops: Deep learning operations reinvented (for pytorch, tensorflow, jax and others)
einops.rearrange:更多示例见其 docstring
import torch
import einops
# Example 1: reorder axes
x = torch.randn(2, 3, 4, 5)
# Plain torch needs three pairwise swaps to reach the target layout.
o1 = (
    x.transpose(1, 2)  # [B H W C] -> [B W H C]
    .transpose(2, 3)  # [B W H C] -> [B W C H]
    .transpose(0, 1)  # [B W C H] -> [W B C H]
)
# einops states the source and target layouts in a single pattern.
o2 = einops.rearrange(x, "B H W C -> W B C H")
# The same permutation written with torch's built-in einsum.
o3 = torch.einsum("BHWC->WBCH", x)
assert torch.allclose(o1, o2)
assert torch.allclose(o2, o3)
# Example 2: flatten (merge two axes into one)
x = torch.randn(2, 3, 4, 5)
# [B H W C] -> [B W H C] -> [B*W H C]; B*W = 2*4 = 8
o1 = torch.reshape(x.transpose(1, 2), (8, 3, 5))
# Parentheses on the output side merge B and W into a single axis.
o2 = einops.rearrange(x, "B H W C -> (B W) H C")
assert torch.allclose(o1, o2)
# Example 3: split (one axis into two)
x = torch.randn(6, 4, 5)
# [B*W H C] -> [B W H C] -> [B H W C], with B=2 and W=3
o1 = torch.reshape(x, (2, 3, 4, 5)).transpose(1, 2)
# [B*W H C] -> [B H W C]; only B is passed alongside the pattern.
o2 = einops.rearrange(x, "(B W) H C -> B H W C", B=2)
assert torch.allclose(o1, o2)
# Example 4: unsqueeze and squeeze
x = torch.randn(2, 4, 3)
# [B L C] -> [B L C 1] -> [B 1 L C 1]
o1 = torch.unsqueeze(torch.unsqueeze(x, -1), 1)
# Unsqueeze: a literal 1 in the output pattern inserts a unit axis.
o2 = einops.rearrange(x, "B L C -> B 1 L C 1")
# Squeeze: listing the 1s only on the input side removes them again.
o3 = einops.rearrange(o2, "B 1 L C 1 -> B L C")
assert torch.allclose(o1, o2)
assert torch.allclose(o3, x)
# Example 5: concat
x = torch.randn(2, 4, 3)
xs = [x] * 3
# N * [B L C] -> [N*B L C]: the tensors are joined along the existing first axis
o1 = torch.concat(xs, dim=0)
# The list itself is addressed as the leading axis N of the pattern.
o2 = einops.rearrange(xs, "N B L C -> (N B) L C")
assert torch.allclose(o1, o2)
# Example 6: stack
x = torch.randn(4, 4, 3)
xs = [x] * 3
# B * [H W C] -> [B H W C]: the list becomes a new leading axis
o1 = torch.stack(xs, dim=0)
# The list is addressed as axis B; the pattern leaves the layout unchanged.
o2 = einops.rearrange(xs, "B H W C -> B H W C")
assert torch.allclose(o1, o2)
# Example 7: cut the last two axes into 2x2 patches and flatten each patch
# [B I H W] -> [B N P]: N = number of patches, P = values per patch
x = torch.randn(2, 3, 8, 8)
o = einops.rearrange(
    x,
    "B I (H P1) (W P2) -> B (H W) (P1 P2 I)",
    P1=2,
    P2=2,
)
# Expected shape [2, 16, 12]: N = 4 * 4 patches, P = 2 * 2 * 3 values each.
print(o.shape)
einops.repeat:更多示例见其 docstring
# Example 1: copy along a new trailing axis
x = torch.randn(2, 3, 4)
# [B H W] -> [B H W 1] -> [B H W 3]
o1 = torch.tile(torch.unsqueeze(x, -1), (1, 1, 1, 3))
# The new axis C appears only on the output side; its length is passed as C=3.
o2 = einops.repeat(x, "B H W -> B H W C", C=3)
assert torch.allclose(o1, o2)
# Example 2: upsampling by tiling (the whole tensor is repeated, not each element)
x = torch.randn(2, 3)
# [B C] -> [2*B 3*C]: 2 copies along the first axis, 3 along the second
o1 = torch.tile(x, (2, 3))
# The literal factor comes first inside each group, matching torch's tile order.
# NOTE: repeating each element in place would be "(B 2) (C 3)" instead.
o2 = einops.repeat(x, "B C -> (2 B) (3 C)")
assert torch.allclose(o1, o2)
einops.reduce:更多示例见其 docstring
支持的归约方法:('min', 'max', 'sum', 'mean', 'prod')
import torch
import einops
# Example 1: 1-D pooling (downsampling) over the last axis
x = torch.randn(2, 3, 4, 5)  # [B H W C]
# Mean over C; the reduced axis is dropped from the result.
o1 = torch.mean(x, dim=-1)
o2 = einops.reduce(x, "B H W C -> B H W", reduction="mean")
assert torch.allclose(o1, o2)
# Max over C, keeping the reduced axis as a unit axis.
o1, _ = x.max(dim=-1, keepdim=True)
o2 = einops.reduce(x, "B H W C -> B H W 1", reduction="max")
# "()" is an equivalent spelling of the unit axis "1".
o3 = einops.reduce(x, "B H W C -> B H W ()", reduction="max")
assert torch.allclose(o1, o2)
assert torch.allclose(o1, o3)
# Example 2: 2-D pooling
x = torch.randn(20, 30, 40, 40)  # [B C H W]
# Global average pooling: mean over H and W, kept as unit axes.
o1 = x.mean((-2, -1), keepdim=True)
o2 = einops.reduce(x, 'B C H W -> B C 1 1', reduction='mean')
assert torch.allclose(o1, o2)
# 2x2 max pooling with stride 2 (non-overlapping windows).
# MaxPool2d is referenced through the fully qualified torch.nn so that the
# snippet only depends on `import torch`.
o1 = torch.nn.MaxPool2d(kernel_size=2, stride=2)(x)
# Each spatial axis is split into (output position, window offset) and the
# window offsets S1/S2 are reduced with max.
o2 = einops.reduce(x, 'B C (H S1) (W S2) -> B C H W', reduction='max', S1=2, S2=2)
assert torch.allclose(o1, o2)
# MaxPool2d also supports kernel_size != stride, which einops.reduce cannot do.