Attention
1. Input
The attention mechanism takes three inputs: Q (query), K (key), and V (value). When Q = K = V, it is called 'self-attention'. There are several common ways to compute it:
$$Attention(Q,K,V) = Softmax(Linear([Q,K])) \cdot V$$
$$Attention(Q,K,V) = Softmax(sum(\tanh(Linear([Q,K])))) \cdot V$$
$$Attention(Q,K,V) = Softmax(\frac{Q \cdot K^T}{\sqrt{d_k}}) \cdot V$$
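As a quick sanity check, here is a minimal sketch of the third (scaled dot-product) rule; the shapes below (batch 1, sequence length 4, d_k = 8) are arbitrary choices for illustration:

import torch
import torch.nn.functional as F

Q = torch.randn(1, 4, 8)   # [batch, seq_len, d_k]
K = torch.randn(1, 4, 8)
V = torch.randn(1, 4, 8)

d_k = Q.size(-1)
scores = torch.bmm(Q, K.transpose(1, 2)) / d_k ** 0.5  # [1, 4, 4]
weights = F.softmax(scores, dim=-1)                     # each row sums to 1
out = torch.bmm(weights, V)                             # [1, 4, 8]
print(out.shape)  # torch.Size([1, 4, 8])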
Here 'bmm' refers to batch matrix multiplication (torch.bmm), which multiplies two batches of matrices element-wise along the batch dimension:
$$(b, n, m)*(b, m, p) \rightarrow (b, n, p)$$
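A quick check of this shape rule (the sizes below are arbitrary):

import torch
a = torch.randn(2, 3, 4)   # (b, n, m)
b = torch.randn(2, 4, 5)   # (b, m, p)
c = torch.bmm(a, b)        # (b, n, p)
print(c.shape)             # torch.Size([2, 3, 5])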
Attention is commonly used in seq2seq tasks. The module below implements the first rule above: a linear layer scores the concatenated query and key, and the softmax-ed scores weight V.
import torch
import torch.nn as nn
import torch.nn.functional as F
class Attn(nn.Module):
    def __init__(self, query_size, key_size, value_size1, value_size2, output_size):
        # query_size and key_size are the sizes of the last dimension of Q and K
        # V has shape [1, value_size1, value_size2]
        super(Attn, self).__init__()
        self.query_size = query_size
        self.key_size = key_size
        self.value_size1 = value_size1
        self.value_size2 = value_size2
        self.output_size = output_size
        self.attn = nn.Linear(self.query_size + self.key_size, self.value_size1)
        self.attn_combine = nn.Linear(self.query_size + self.value_size2, self.output_size)

    def forward(self, Q, K, V):
        # one attention weight per row of V: [1, value_size1]
        attn_weights = F.softmax(self.attn(torch.cat((Q[0], K[0]), 1)), dim=1)
        # [1, 1, value_size1] bmm [1, value_size1, value_size2] -> [1, 1, value_size2]
        attn_applied = torch.bmm(attn_weights.unsqueeze(0), V)
        # drop the leading dimension before combining with the query
        output = torch.cat((Q[0], attn_applied[0]), 1)
        output = self.attn_combine(output).unsqueeze(0)
        return output, attn_weights
query_size = 32
key_size = 32
value_size1 = 32
value_size2 = 64
output_size = 64
attn = Attn(query_size, key_size, value_size1, value_size2, output_size)
Q = torch.randn(1, 1, 32)
K = torch.randn(1, 1, 32)
V = torch.randn(1, 32, 64)  # [1, value_size1, value_size2]
output = attn(Q, K, V)
print(output[0])  # combined output: [1, 1, 64]
print(output[1])  # attention weights: [1, 32]
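Note the design choice in this module: because the scores come from a linear layer that outputs value_size1 values rather than from a dot product between Q and K, the number of positions attended over is fixed by value_size1, and the weights do not depend on the actual key length.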