1. 原理图
2. 代码
import torch
import torch.nn as nn


class Multi_Head_Self_Attention(nn.Module):
    """Multi-head scaled dot-product attention.

    The inputs are projected with bias-free linear layers, split into
    ``heads`` slices of width ``embed_size // heads``, attended per head,
    concatenated again and passed through a final bias-free linear layer.

    Args:
        embed_size: Width of the last dimension of the inputs and the output.
        heads: Number of attention heads; must be positive and divide
            ``embed_size`` evenly.

    Raises:
        ValueError: If ``heads`` is not positive or does not divide
            ``embed_size``.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        # Fail at construction time: otherwise head_dim * heads != embed_size
        # and the reshape in forward() dies with an opaque shape error.
        if heads <= 0 or embed_size % heads != 0:
            raise ValueError(
                f"embed_size ({embed_size}) must be divisible by "
                f"a positive number of heads (got {heads})"
            )

        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        self.queries = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.keys = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.values = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.fc_out = nn.Linear(self.embed_size, self.embed_size, bias=False)

    def forward(self, queries, keys, values, mask=None):
        """Apply attention of ``queries`` over ``keys`` / ``values``.

        Args:
            queries: Tensor of shape (batch_size, query_len, embed_size).
            keys: Tensor of shape (batch_size, key_len, embed_size).
            values: Tensor of shape (batch_size, value_len, embed_size);
                the attention-weighted sum requires value_len == key_len.
            mask: Optional tensor broadcastable to
                (batch_size, heads, query_len, key_len). Positions where
                ``mask == 0`` are set to ``-inf`` before the softmax and so
                receive zero attention weight. A query row whose keys are
                all masked yields NaN, because the softmax of an all
                ``-inf`` row is undefined. ``None`` disables masking.

        Returns:
            Tensor of shape (batch_size, query_len, embed_size).
        """
        N = queries.shape[0]  # batch_size
        query_len = queries.shape[1]
        key_len = keys.shape[1]
        value_len = values.shape[1]

        queries = self.queries(queries)
        keys = self.keys(keys)
        values = self.values(values)

        # Split the embedding into self.heads pieces:
        # (batch, seq_len, embed_size) -> (batch, seq_len, heads, head_dim)
        queries = queries.reshape(N, query_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        values = values.reshape(N, value_len, self.heads, self.head_dim)

        # Move the head axis ahead of the sequence axis so matmul batches
        # over (batch, heads):
        # (batch, seq_len, heads, head_dim) -> (batch, heads, seq_len, head_dim)
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # Scaled dot-product attention: divide by sqrt(head_dim).
        score = torch.matmul(queries, keys.transpose(-2, -1)) / (self.head_dim ** 0.5)

        if mask is not None:
            score = score.masked_fill(mask == 0, float("-inf"))

        # (batch, heads, query_len, key_len), normalised over the key axis.
        attention = torch.softmax(score, dim=-1)

        out = torch.matmul(attention, values)

        # Merge the heads back so the result can be fed to the following
        # layers:
        # (batch, heads, query_len, head_dim)
        #   -> (batch, query_len, heads, head_dim)
        #   -> (batch, query_len, embed_size)
        out = out.transpose(1, 2).contiguous().reshape(N, query_len, self.embed_size)

        out = self.fc_out(out)

        return out


# Demo: run one forward pass on random inputs and print the output shape.
batch_size = 64
sequence_length = 10
embed_size = 512
heads = 8
mask = None

Q = torch.randn(batch_size, sequence_length, embed_size)
K = torch.randn(batch_size, sequence_length, embed_size)
V = torch.randn(batch_size, sequence_length, embed_size)

model = Multi_Head_Self_Attention(embed_size, heads)
output = model(Q, K, V, mask)
print(output.shape)