Transformer: Multi-Head Attention Mechanism (PyTorch)
Founder
2025-01-11 08:34:26

1. Schematic
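The original post's schematic image is not reproduced here. As a reference for what it depicts, the standard scaled dot-product and multi-head attention formulas from "Attention Is All You Need" (which the code in the next section implements) are:

\[
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V
\]

\[
\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\,W^{O}, \qquad \mathrm{head}_i = \mathrm{Attention}(QW_i^{Q}, KW_i^{K}, VW_i^{V})
\]

In the code below, h = heads = 8 and d_k = head_dim = embed_size / heads = 512 / 8 = 64.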

2. Code

import torch
import torch.nn as nn


class Multi_Head_Self_Attention(nn.Module):
    def __init__(self, embed_size, heads):
        super(Multi_Head_Self_Attention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        # the reshape into heads below is only valid if embed_size divides evenly
        assert self.head_dim * heads == embed_size, "embed_size must be divisible by heads"

        self.queries = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.keys = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.values = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.fc_out = nn.Linear(self.embed_size, self.embed_size, bias=False)

    def forward(self, queries, keys, values, mask):
        N = queries.shape[0]          # batch_size
        query_len = queries.shape[1]  # sequence_length
        key_len = keys.shape[1]       # sequence_length
        value_len = values.shape[1]   # sequence_length

        queries = self.queries(queries)
        keys = self.keys(keys)
        values = self.values(values)

        # Split the embedding into self.heads pieces:
        # batch_size, sequence_length, embed_size(512) -->
        # batch_size, sequence_length, heads(8), head_dim(64)
        queries = queries.reshape(N, query_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        values = values.reshape(N, value_len, self.heads, self.head_dim)

        # batch_size, sequence_length, heads(8), head_dim(64) -->
        # batch_size, heads(8), sequence_length, head_dim(64)
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # Scaled dot-product attention
        score = torch.matmul(queries, keys.transpose(-2, -1)) / (self.head_dim ** 0.5)

        if mask is not None:
            score = score.masked_fill(mask == 0, float("-inf"))
        # batch_size, heads(8), sequence_length, sequence_length
        attention = torch.softmax(score, dim=-1)

        out = torch.matmul(attention, values)
        # batch_size, heads(8), sequence_length, head_dim(64) -->
        # batch_size, sequence_length, heads(8), head_dim(64) -->
        # batch_size, sequence_length, embed_size(512)
        # merge the heads back so the output can be fed into subsequent layers
        out = out.transpose(1, 2).contiguous().reshape(N, query_len, self.embed_size)
        out = self.fc_out(out)

        return out


batch_size = 64
sequence_length = 10
embed_size = 512
heads = 8
mask = None

Q = torch.randn(batch_size, sequence_length, embed_size)
K = torch.randn(batch_size, sequence_length, embed_size)
V = torch.randn(batch_size, sequence_length, embed_size)

model = Multi_Head_Self_Attention(embed_size, heads)
output = model(Q, K, V, mask)
print(output.shape)  # torch.Size([64, 10, 512])
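The demo above passes mask=None, so the mask branch in forward is never exercised. As a minimal sketch (not from the original post, and assuming the Multi_Head_Self_Attention class above is in scope), here is one way to build a causal (look-ahead) mask compatible with the masked_fill(mask == 0, float("-inf")) convention used in forward; the (1, 1, seq_len, seq_len) shape broadcasts over the batch and head dimensions of the (N, heads, query_len, key_len) score tensor:

import torch

seq_len = 10
# Lower-triangular matrix of ones: position i may only attend to positions <= i.
# The zeros above the diagonal become -inf scores via masked_fill.
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).reshape(1, 1, seq_len, seq_len)

model = Multi_Head_Self_Attention(embed_size=512, heads=8)
x = torch.randn(2, seq_len, 512)    # a small batch of 2 sequences
out = model(x, x, x, causal_mask)   # self-attention with causal masking
print(out.shape)                    # torch.Size([2, 10, 512])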

 
