pytorch transformers BertModel 中的tgt_key_padding_mask

tgt_key_padding_mask in pytorch transformers BertModel

提问人:carpet119 提问时间:11/8/2023 最后编辑:carpet119 更新时间:11/9/2023 访问量:29

问:

在浏览 PyTorch 中的转换器文档时,我看到形状(batch_size、tgt_seq_len)的tgt_key_padding_mask用于表示由于填充而使 tgt 的某些部分无关紧要。当我从 transformers 库查看 Pytorch 的 BertModel 实现时,我在 forward 函数中没有看到此类掩码的选项。使用 BertModel 时如何提供tgt_key_padding_mask?

注意:这里的首要答案解释了什么是tgt_key_padding_mask。

BertModel 可以选择提供一个head_mask,该乘以 attention_probs(在 softmax 之后,但在与值相乘之前)。我还没有看到任何关于head_mask的预期形状的提及/文档。似乎它将被广播为乘以 attention_probs 所需的任何形状(如果形状为 (batch_size, num_heads, query_seq_len=tgt_seq_len, key_seq_len=src_seq_len)),否则将触发形状不匹配错误。我在想,我可以通过这个传递我的tgt_key_padding_mask来简单地使用/滥用这个head_mask。但是,在使用 BertModel 时,有没有另一种更合适的方法来指定tgt_key_padding_mask呢?

pytorch 填充 掩码 bert-language-model transformer-model

评论

0赞 Karl 11/8/2023
您链接的 Bert 实现来自 Huggingface,它与 Pytorch 是不同的实体。Huggingface团队很可能没有实现该功能。tgt_key_padding_mask
0赞 carpet119 11/8/2023
好的,对实现 tgt_key_padding 的 BERT 的可靠 PyTorch 实现的任何建议?只是问

答:

1赞 carpet119 11/9/2023 #1

在 BertModel 中,如果 attention_mask 作为 2D 张量传递,则假定它是用于指示序列的哪些部分被填充(因此需要忽略)的掩码,其中 I (和 nn.Transformer 文档)简称tgt_key_padding_mask。在这种情况下,因果掩码(tgt_mask,根据 nn.Transformers documentation)根据序列长度自动计算,并与传递的attention_mask相结合。我希望这在某个地方被记录下来。