Transformer: Multi-Head Attention (PyTorch)
创始人
2025-01-11 08:34:26

1. Schematic
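(Figure: the multi-head attention schematic from the original post.)

In the figure's terms, each head computes scaled dot-product attention, which is exactly what the code in Section 2 implements:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V$$

where $d_k = \text{embed\_size}/\text{heads}$ is the per-head dimension (here $512/8 = 64$). The $h = 8$ head outputs are concatenated back to $\text{embed\_size} = 512$ and passed through the final linear layer `fc_out`.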

2. Code

import torch
import torch.nn as nn


class Multi_Head_Self_Attention(nn.Module):
    def __init__(self, embed_size, heads):
        super(Multi_Head_Self_Attention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        # Each head gets an equal slice of the embedding.
        assert self.head_dim * heads == embed_size, "embed_size must be divisible by heads"

        self.queries = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.keys = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.values = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.fc_out = nn.Linear(self.embed_size, self.embed_size, bias=False)

    def forward(self, queries, keys, values, mask):
        N = queries.shape[0]          # batch_size
        query_len = queries.shape[1]  # sequence_length
        key_len = keys.shape[1]       # sequence_length
        value_len = values.shape[1]   # sequence_length

        queries = self.queries(queries)
        keys = self.keys(keys)
        values = self.values(values)

        # Split the embedding into self.heads pieces:
        # batch_size, sequence_length, embed_size(512) -->
        # batch_size, sequence_length, heads(8), head_dim(64)
        queries = queries.reshape(N, query_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        values = values.reshape(N, value_len, self.heads, self.head_dim)

        # batch_size, sequence_length, heads(8), head_dim(64) -->
        # batch_size, heads(8), sequence_length, head_dim(64)
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # Scaled dot-product attention
        score = torch.matmul(queries, keys.transpose(-2, -1)) / (self.head_dim ** 0.5)

        if mask is not None:
            score = score.masked_fill(mask == 0, float("-inf"))

        # batch_size, heads(8), sequence_length, sequence_length
        attention = torch.softmax(score, dim=-1)

        out = torch.matmul(attention, values)

        # batch_size, heads(8), sequence_length, head_dim(64) -->
        # batch_size, sequence_length, heads(8), head_dim(64) -->
        # batch_size, sequence_length, embed_size(512)
        # Concatenate the heads so the output can be fed to the following layers.
        out = out.transpose(1, 2).contiguous().reshape(N, query_len, self.embed_size)
        out = self.fc_out(out)

        return out


batch_size = 64
sequence_length = 10
embed_size = 512
heads = 8
mask = None

Q = torch.randn(batch_size, sequence_length, embed_size)
K = torch.randn(batch_size, sequence_length, embed_size)
V = torch.randn(batch_size, sequence_length, embed_size)

model = Multi_Head_Self_Attention(embed_size, heads)
output = model(Q, K, V, mask)
print(output.shape)  # torch.Size([64, 10, 512])
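The demo above runs with mask = None, so the masked_fill branch is never exercised. As a minimal sketch (the causal mask below is an illustration, not part of the original post), a lower-triangular 0/1 mask of shape (sequence_length, sequence_length) broadcasts over the (batch_size, heads) dimensions of score and blocks attention to future positions:

# Hypothetical usage: causal (lower-triangular) mask; positions where the
# mask is 0 receive a score of -inf and thus zero attention weight.
causal_mask = torch.tril(torch.ones(sequence_length, sequence_length))
# causal_mask: (seq_len, seq_len); broadcasts against score of shape
# (batch_size, heads, seq_len, seq_len) inside forward().
masked_output = model(Q, K, V, causal_mask)
print(masked_output.shape)  # torch.Size([64, 10, 512])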

 
