提问人:carpet119 提问时间:11/8/2023 最后编辑:carpet119 更新时间:11/9/2023 访问量:29
pytorch transformers BertModel 中的tgt_key_padding_mask
tgt_key_padding_mask in pytorch transformers BertModel
问:
在浏览 PyTorch 中的转换器文档时,我看到形状(batch_size、tgt_seq_len)的tgt_key_padding_mask用于表示由于填充而使 tgt 的某些部分无关紧要。当我从 transformers 库查看 Pytorch 的 BertModel 实现时,我在 forward 函数中没有看到此类掩码的选项。使用 BertModel 时如何提供tgt_key_padding_mask?
注意:这里的首要答案解释了什么是tgt_key_padding_mask。
BertModel 可以选择提供一个head_mask,该乘以 attention_probs(在 softmax 之后,但在与值相乘之前)。我还没有看到任何关于head_mask的预期形状的提及/文档。似乎它将被广播为乘以 attention_probs 所需的任何形状(如果形状为 (batch_size, num_heads, query_seq_len=tgt_seq_len, key_seq_len=src_seq_len)),否则将触发形状不匹配错误。我在想,我可以通过这个传递我的tgt_key_padding_mask来简单地使用/滥用这个head_mask。但是,在使用 BertModel 时,有没有另一种更合适的方法来指定tgt_key_padding_mask呢?
答:
上一个:Flutter 填充和边距问题
评论
tgt_key_padding_mask