Multi-Head Self-Attention
def forward(x, mask, H, D):
    # x: [B, L, N], where N = H * D
    B, L, N = x.shape
    # self-attention: queries, keys, and values all come from x
    q = k = v = x                                              # [B, L, N]
    # linear projections, then split the model dimension into H heads
    q = W_q(q).reshape([B, L, H, D]).transpose(1, 2)           # [B, H, L, D]
    k = W_k(k).reshape([B, L, H, D]).transpose(1, 2)           # [B, H, L, D]
    v = W_v(v).reshape([B, L, H, D]).transpose(1, 2)           # [B, H, L, D]
    # scaled dot-product attention with an additive mask
    logits = matmul(q, k.transpose(-2, -1)) / sqrt(D) + mask   # [B, H, L, L]
    a = softmax(logits, dim=-1)                                # [B, H, L, L]
    # weighted sum of values, then merge heads and project back to N
    o = matmul(a, v)                                           # [B, H, L, D]
    o = o.transpose(1, 2).reshape([B, L, N])                   # [B, L, N]
    o = W_o(o)                                                 # [B, L, N]
    return o
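
For reference, here is a minimal runnable sketch of the same block as a PyTorch module, under the assumption that W_q, W_k, W_v, and W_o are N-by-N nn.Linear layers (with N = H * D) and that mask is additive (0 for visible positions, -inf for masked ones). The class name and the example shapes below are illustrative, not from the original.

import math
import torch
import torch.nn as nn

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, N, H):
        super().__init__()
        assert N % H == 0, "model dim N must be divisible by head count H"
        self.H, self.D = H, N // H
        self.W_q = nn.Linear(N, N)
        self.W_k = nn.Linear(N, N)
        self.W_v = nn.Linear(N, N)
        self.W_o = nn.Linear(N, N)

    def forward(self, x, mask):
        B, L, N = x.shape
        H, D = self.H, self.D
        # project and split into heads
        q = self.W_q(x).reshape(B, L, H, D).transpose(1, 2)        # [B, H, L, D]
        k = self.W_k(x).reshape(B, L, H, D).transpose(1, 2)        # [B, H, L, D]
        v = self.W_v(x).reshape(B, L, H, D).transpose(1, 2)        # [B, H, L, D]
        # scaled dot-product attention with additive mask
        logits = q @ k.transpose(-2, -1) / math.sqrt(D) + mask     # [B, H, L, L]
        a = torch.softmax(logits, dim=-1)                          # [B, H, L, L]
        # weighted sum of values, merge heads, output projection
        o = (a @ v).transpose(1, 2).reshape(B, L, N)               # [B, L, N]
        return self.W_o(o)

# Usage (assumed shapes): a causal mask stops position i attending to j > i.
B, L, N, H = 2, 8, 64, 4
mask = torch.triu(torch.full((L, L), float("-inf")), diagonal=1)   # [L, L]
attn = MultiHeadSelfAttention(N, H)
out = attn(torch.randn(B, L, N), mask)                             # [2, 8, 64]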