Transformer模型:Decoder的self-attention mask实现
创始人
2025-01-08 16:06:28
0

前言

        这是对Transformer模型Word Embedding、Postion Embedding、Encoder self-attention mask、intra-attention mask内容的续篇。

视频链接:20、Transformer模型Decoder原理精讲及其PyTorch逐行实现_哔哩哔哩_bilibili

文章链接:Transformer模型:WordEmbedding实现-CSDN博客

                  Transformer模型:Postion Embedding实现-CSDN博客

                  Transformer模型:Encoder的self-attention mask实现-CSDN博客

                  Transformer模型:intra-attention mask实现-CSDN博客


 正文

        首先介绍一下Deoder的self-attention mask,它与前面的两个mask不一样地方在于Decoder是生成一个单词之后,将改单词作为输入给到Decoder中继续生成下一个,也就是相当于下三角矩阵,一次多一个,直到完成整个预测。

        先生成一个下三角矩阵:

tri_matrix = [torch.tril(torch.ones(L, L)) for L in tgt_len]

         这里生成的两个下三角矩阵的维度是不一样的,首先要统一维度:

valid_decoder_tri_matrix = [F.pad(torch.tril(torch.ones(L, L)), (0, max_tgt_seg_len-L, 0, max_tgt_seg_len-L)) for L in tgt_len]

        然后就是将它转为1个3维的张量形式,过程跟先前类似,这里就不一步步拆解了:

valid_decoder_tri_matrix = torch.cat([torch.unsqueeze(F.pad(torch.tril(torch.ones(L, L)), (0, max_tgt_seg_len-L, 0, max_tgt_seg_len-L)),0) for L in tgt_len]) 

        后续掩码过程还是跟前两篇一样,这里也不多解释了:

invalid_decoder_tri_matrix = 1 - valid_decoder_tri_matrix mask_decoder_self_attention = invalid_decoder_tri_matrix.to(torch.bool) score2 = torch.randn(batch_size, max_tgt_seg_len, max_tgt_seg_len) mask_score3 = score2.masked_fill(mask_decoder_self_attention, -1e9) prob3 = F.softmax(mask_score3, -1)

 代码

import torch import numpy as np import torch.nn as nn import torch.nn.functional as F  # 句子数 batch_size = 2  # 单词表大小 max_num_src_words = 10 max_num_tgt_words = 10  # 序列的最大长度 max_src_seg_len = 12 max_tgt_seg_len = 12 max_position_len = 12  # 模型的维度 model_dim = 8  # 生成固定长度的序列 src_len = torch.Tensor([11, 9]).to(torch.int32) tgt_len = torch.Tensor([10, 11]).to(torch.int32)  # 单词索引构成的句子 src_seq = torch.cat(     [torch.unsqueeze(F.pad(torch.randint(1, max_num_src_words, (L,)), (0, max_src_seg_len - L)), 0) for L in src_len]) tgt_seq = torch.cat(     [torch.unsqueeze(F.pad(torch.randint(1, max_num_tgt_words, (L,)), (0, max_tgt_seg_len - L)), 0) for L in tgt_len])  # Part1:构造Word Embedding src_embedding_table = nn.Embedding(max_num_src_words + 1, model_dim) tgt_embedding_table = nn.Embedding(max_num_tgt_words + 1, model_dim) src_embedding = src_embedding_table(src_seq) tgt_embedding = tgt_embedding_table(tgt_seq)  # 构造Pos序列跟i序列 pos_mat = torch.arange(max_position_len).reshape((-1, 1)) i_mat = torch.pow(10000, torch.arange(0, 8, 2) / model_dim)  # Part2:构造Position Embedding pe_embedding_table = torch.zeros(max_position_len, model_dim) pe_embedding_table[:, 0::2] = torch.sin(pos_mat / i_mat) pe_embedding_table[:, 1::2] = torch.cos(pos_mat / i_mat)  pe_embedding = nn.Embedding(max_position_len, model_dim) pe_embedding.weight = nn.Parameter(pe_embedding_table, requires_grad=False)  # 构建位置索引 src_pos = torch.cat([torch.unsqueeze(torch.arange(max_position_len), 0) for _ in src_len]).to(torch.int32) tgt_pos = torch.cat([torch.unsqueeze(torch.arange(max_position_len), 0) for _ in tgt_len]).to(torch.int32)  src_pe_embedding = pe_embedding(src_pos) tgt_pe_embedding = pe_embedding(tgt_pos)  # Part3:构造encoder self-attention mask valid_encoder_pos = torch.unsqueeze(     torch.cat([torch.unsqueeze(F.pad(torch.ones(L), (0, max_src_seg_len - L)), 0) for L in src_len]), 2) valid_encoder_pos_matrix = torch.bmm(valid_encoder_pos, valid_encoder_pos.transpose(1, 2)) invalid_encoder_pos_matrix = 1 - torch.bmm(valid_encoder_pos, valid_encoder_pos.transpose(1, 2)) mask_encoder_self_attention = invalid_encoder_pos_matrix.to(torch.bool) score = torch.randn(batch_size, max_src_seg_len, max_src_seg_len) mask_score1 = score.masked_fill(mask_encoder_self_attention, -1e9) prob1 = F.softmax(mask_score1, -1)  # Part4:构造intra-attention mask valid_encoder_pos = torch.unsqueeze(     torch.cat([torch.unsqueeze(F.pad(torch.ones(L), (0, max_src_seg_len - L)), 0) for L in src_len]), 2) valid_decoder_pos = torch.unsqueeze(     torch.cat([torch.unsqueeze(F.pad(torch.ones(L), (0, max_tgt_seg_len - L)), 0) for L in tgt_len]), 2)  valid_cross_pos_matrix = torch.bmm(valid_decoder_pos, valid_encoder_pos.transpose(1, 2)) invalid_cross_pos_matrix = 1 - valid_cross_pos_matrix mask_cross_attention = invalid_cross_pos_matrix.to(torch.bool) mask_score2 = score.masked_fill(mask_cross_attention, -1e9) prob2 = F.softmax(mask_score2, -1)  # Part5:构造Decoder self-attention mask valid_decoder_tri_matrix = torch.cat([torch.unsqueeze(F.pad(torch.tril(torch.ones(L, L)), (0, max_tgt_seg_len-L, 0, max_tgt_seg_len-L)),0) for L in tgt_len]) invalid_decoder_tri_matrix = 1 - valid_decoder_tri_matrix mask_decoder_self_attention = invalid_decoder_tri_matrix.to(torch.bool) score2 = torch.randn(batch_size, max_tgt_seg_len, max_tgt_seg_len) mask_score3 = score2.masked_fill(mask_decoder_self_attention, -1e9) prob3 = F.softmax(mask_score3, -1)

相关内容

热门资讯

绝活儿辅助!广东雀神智能插件是... 绝活儿辅助!广东雀神智能插件是真的(辅助挂)其实是有辅助软件(存在有挂)1、广东雀神智能插件是真的公...
绝活辅助!天天爱消除自动消除辅... 绝活辅助!天天爱消除自动消除辅助(辅助挂)一贯是有辅助工具(有挂透明挂);运天天爱消除自动消除辅助辅...
模块辅助!凑一桌关春天怎么才能... 模块辅助!凑一桌关春天怎么才能开挂(辅助挂)果然真的有辅助挂(有挂技术)1、凑一桌关春天怎么才能开挂...
模块辅助!聚友联盟辅助器(辅助... 模块辅助!聚友联盟辅助器(辅助挂)一直真的是有辅助器(证实有挂)1、起透看视 聚友联盟辅助器辅助软件...
指引辅助!途游小程序辅助器(辅... 指引辅助!途游小程序辅助器(辅助挂)果然确实有辅助神器(新版有挂)1、在途游小程序辅助器插件功能辅助...
阶段辅助!手机卡五星辅助软件(... 阶段辅助!手机卡五星辅助软件(辅助挂)确实是真的有辅助方法(确实有挂)1、手机卡五星辅助软件免费辅助...
手段辅助!芒果辅助器安卓版(辅... 手段辅助!芒果辅助器安卓版(辅助挂)原来真的有辅助脚本(有挂解惑)1、这是跨平台的芒果辅助器安卓版轻...
诀窍辅助!免费宝宝浙江游戏安装... 诀窍辅助!免费宝宝浙江游戏安装(辅助挂)竟然真的是有辅助攻略(有挂实锤)1、免费宝宝浙江游戏安装脚本...
办法辅助!透视盒子(辅助挂)一... 办法辅助!透视盒子(辅助挂)一贯是真的有辅助教程(有挂总结)透视盒子辅助器是一种具有地方特色的麻将游...
策略辅助!凑一桌游戏辅助(辅助... 策略辅助!凑一桌游戏辅助(辅助挂)本来真的有辅助插件(有挂技巧)1、点击下载安装,凑一桌游戏辅助脚本...