Whole Word Masking (wwm)

Whole Word Masking (wwm)

本文代码部分参考github项目:
https://github.com/BSlience/search-engine-zerotohero/tree/main/public/bert_wwm_pretrain

Whole Word Masking (wwm),暂翻译为全词Mask或整词Mask,是谷歌在2019年5月31日发布的一项BERT的升级版本,主要更改了原预训练阶段的训练样本生成策略。我们先看下BERT原文的遮蔽语言模型。

BERT–遮蔽语言模型

在BERT之前,标准的条件语言模型只能从左到右或从右到左进行训练,因为双向条件作用将允许每个单词在多层上下文中间接地看到自己,为了训练深度双向表示,BERT采用了一种简单的方法,即随机遮蔽一定比例的输入标记,然后仅预测那些被遮蔽的标记,这一过程被称为遮蔽语言模型(MLM, masked language model),尽管在文献中它通常被称为完型填词任务。

在这种情况下,就像在标准语言模型中一样,与遮蔽标记相对应的最终隐藏向量被输入到与词汇表对应的输出 softmax 中(也就是要把被遮蔽的标记对应为词汇表中的一个词语)。在所有的实验中,BERT在每个序列中随机遮蔽 15% 的标记

虽然这确实允许我们获得一个双向预训练模型,但这种方法有两个缺点。第一个缺点是,我们在预训练和微调之间造成了不匹配,因为 [MASK] 标记在微调期间从未出现过。为了缓和这种情况,我们并不总是用真的用 [MASK] 标记替换被选择的单词。而是,训练数据生成器随机选择 15% 的标记,例如,在my dog is hairy 这句话中,它选择 hairy。然后执行以下步骤:

  • 80% 的情况下:用 [MASK] 替换被选择的单词,例如,my dog is hairy → my dog is [MASK]
  • 10% 的情况下:用一个随机单词替换被选择的单词,例如,my dog is hairy → my dog is apple
  • 10% 的情况下:保持被选择的单词不变,例如,my dog is hairy → my dog is hairy。这样做的目
    的是使表示偏向于实际观察到的词。

Transformer 编码器不知道它将被要求预测哪些单词,或者哪些单词已经被随机单词替换,因此它被迫保持每个输入标记的分布的上下文表示。另外,因为随机替换只发生在 1.5% 的标记(即,15% 的10%)这似乎不会损害模型的语言理解能力。

第二个缺点是,使用 Transformer 的每批次数据中只有 15% 的标记被预测,这意味着模型可能需要更多的预训练步骤来收敛。在 5.3 节中,我们证明了 Transformer 确实比从左到右的模型(预测每个标记)稍微慢一点,但是 Transformer 模型的实验效果远远超过了它增加的预训练模型的成本。

WordPiece

BERT原文中的遮蔽语言模型是基于wordPiece拆词后的子词进行MASK,所谓的wordPiece其实是把word再进一步的拆分,拆分为piece,得到更细粒度。

比如“loved”,”loving”,”loves”这三个单词。其实本身的语义都是“爱”的意思,但是如果我们以单词为单位,那它们就算作是不一样的词,在英语中不同后缀的词非常的多,就会使得词表变的很大,训练速度变慢,训练的效果也不是太好。

WordPiece与BPE(Byte-Pair Encoding)双字节编码算法比较相似,它们是两种不同的子词切分算法,主要区别在于如何选择两个子词进行合并。

例如WordPiece(或BPE)通过训练,能够把上面的”loved”,”loving”,”loves”3个单词拆分成”lov”,”ed”,”ing”,”es”几部分,这样可以把词的本身的意思和时态分开,有效的减少了词表的数量。

Whole Word Masking策略

在BERT中,原有基于WordPiece的分词方式会把一个完整的词切分成若干个子词,在生成训练样本时,这些被分开的子词会随机被mask。 在全词Mask中,如果一个完整的词的部分WordPiece子词被mask,则同属该词的其他部分也会被mask,即全词Mask。

需要注意的是,这里的mask指的是广义的mask(替换成[MASK];保持原词汇;随机替换成另外一个词),并非只局限于单词替换成[MASK]标签的情况。

由于谷歌官方发布的BERT-base, Chinese中,中文是以字为粒度进行切分,没有考虑到传统NLP中的中文分词(CWS, chinese word segment),所以全词Mask可以用在中文预训练中。

数据示例(方便理解)

  • 原始文本: 使用语言模型来预测下一个词的probability。
  • 分词文本: 使用 语言 模型 来 预测 下 一个 词 的 probability 。
  • 原始Mask输入(mlm): 使 用 语 言 [MASK] 型 来 [MASK] 测 下 一 个 词 的 pro [MASK] ##lity 。
  • 全词Mask输入(wwm): 使 用 语 言 [MASK] [MASK] 来 [MASK] [MASK] 下 一 个 词 的 [MASK] [MASK] [MASK] 。

代码实现

因为后面我会针对huggingface transformer中的chinese_bert wwm模型进行fine tune,该模型使用的是wwm(也就是全词MASK方法),所以这里记录whole Word Masking的一种实现方式。

huggingface transformer中有一个data collator的概念,数据整理器(data collator)是通过使用数据集元素列表作为输入来形成批次的对象。这些元素与train_dataset或eval_dataset的元素类型相同。

为了能够构建批处理,数据整理器可能会应用一些处理(如填充、截断)。其中一些(如DataCollatorForLanguageModeling)还对所形成的批处理应用了一些随机数据扩充(如随机屏蔽)。

huggingface transformer中关于data collator的文档

当然MASK操作也属于数据整理器的功能之一,整个data collator的步骤如下:

  1. 先获得这个批次数据的最大长度max_seq_len;
  2. 对句子进行补齐和截断;
  3. 对于每个样本的input_ids,随机选择20%字(token),认为其和前面一个词可能组成词;
  4. 在对应的token前添加特殊符号##比如 4 -> ##4
  5. 将带特征符号##的token传入mask方法(这里是self._whole_word_mask),随机选择15%的字认为是需要mask的,如果选到的字是带##标记的,那么就把它前面的字一起mask,返回mask_label;
  6. 根据mask_label和input_ids进行mask(80%进行mask掉,10%进行随机替换,10%选择保持不变)

注意:步骤3中选择的20%,是认为可能组成词的字(并不是需要mask的字),因为是随机选的,所以可能根本不是词,因为参考的这个项目就是这么实现的,所以在我看来是一个不完整的实现方案,如果有能力、有兴趣的小伙伴可以完整实现,也就是找到真正的词,可以借助一些分词工具。

下面是实现代码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
class DataCollator:
def __init__(self, max_seq_len: int, tokenizer: BertTokenizer, mlm_probability=0.15):
# max_seq_len 用于截断的最大长度
self.max_seq_len = max_seq_len
self.tokenizer = tokenizer
self.mlm_probability = mlm_probability # 遮词概率

# 截断和填充
def truncate_and_pad(self, input_ids_list, token_type_ids_list, attention_mask_list, max_seq_len):
# 初始化一个样本数量 * max_seq_len 的二维tensor
input_ids = torch.zeros((len(input_ids_list), max_seq_len), dtype=torch.long)
token_type_ids = torch.zeros_like(input_ids)
attention_mask = torch.zeros_like(input_ids)
for i in range(len(input_ids_list)):
seq_len = len(input_ids_list[i]) # 当前句子的长度
# 如果长度小于最大长度
if seq_len <= max_seq_len:
# 把input_ids_list中的值赋值给input_ids
input_ids[i, :seq_len] = torch.tensor(input_ids_list[i][:seq_len], dtype=torch.long)
else: # self.tokenizer.sep_token_id = 102
# 度超过最大长度的句子,input_ids最后一个值设置为102即分割词
# input_ids[i, :seq_len] = torch.tensor(input_ids_list[i][:seq_len - 1] +
# [self.tokenizer.sep_token_id], dtype=torch.long)
input_ids[i, :seq_len] = torch.tensor(input_ids_list[i][:max_seq_len - 1] +
[self.tokenizer.sep_token_id], dtype=torch.long)
print(input_ids[i])
seq_len = min(len(input_ids_list[i]), max_seq_len)
token_type_ids[i, :seq_len] = torch.tensor(token_type_ids_list[i][:seq_len], dtype=torch.long)
attention_mask[i, :seq_len] = torch.tensor(attention_mask_list[i][:seq_len], dtype=torch.long)
# print('截断和填充之前' + '*' * 30)
# print(input_ids_list) # 每个句子向量长度不一
# print('截断和填充之后' + '*' * 30)
# print(input_ids) # 每个句子向量长度统一
return input_ids, token_type_ids, attention_mask

def _whole_word_mask(self, input_ids_list: List[str], max_seq_len: int, max_predictions=512):
cand_indexes = []
for (i, token) in enumerate(input_ids_list):
# 跳过开头与结尾
if (token == str(self.tokenizer.cls_token_id) # 101
or token == str(self.tokenizer.sep_token_id)): # 102
continue

if len(cand_indexes) >= 1 and token.startswith("##"):
cand_indexes[-1].append(i)
else:
cand_indexes.append([i])

random.shuffle(cand_indexes) # 打乱
# 根据句子长度*遮词概率算出要预测的个数,最大预测不超过512,不足1的按1
# round()四舍五入,但是偶数.5会舍去,不过这是细节问题,影响不是很大
num_to_predict = min(max_predictions, max(1, int(round(len(input_ids_list) * self.mlm_probability))))
masked_lms = []
covered_indexes = set()
for index_set in cand_indexes:
if len(masked_lms) >= num_to_predict:
break
if len(masked_lms) + len(index_set) > num_to_predict:
continue
is_any_index_covered = False
for index in index_set:
if index in covered_indexes:
is_any_index_covered = True
break
if is_any_index_covered:
continue
for index in index_set:
covered_indexes.add(index)
masked_lms.append(index)

assert len(covered_indexes) == len(masked_lms)
# mask 掉的 token 使用 1 来进行标记,否则使用 0 来标记
mask_labels = [1 if i in covered_indexes else 0 for i in range(min(len(input_ids_list), max_seq_len))]
mask_labels += [0] * (max_seq_len - len(mask_labels))
return torch.tensor(mask_labels)

def whole_word_mask(self, input_ids_list: List[list], max_seq_len: int) -> torch.Tensor:
mask_labels = []
for input_ids in input_ids_list:
# 随机选取20%的字,认为其和前一个字可以组成词(实际不一定)
# choices是一个有放回抽样,会重复,可能实际会少于20%,细节问题影响不大
wwm_id = random.choices(range(len(input_ids)), k=int(len(input_ids)*0.2))
# 给挑选出来的位置添加 "##"标记
input_id_str = [f'##{id_}' if i in wwm_id else str(id_) for i, id_ in enumerate(input_ids)]
mask_label = self._whole_word_mask(input_id_str, max_seq_len)
mask_labels.append(mask_label)
return torch.stack(mask_labels, dim=0)

def mask_tokens(self, inputs: torch.Tensor, mask_labels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
labels = inputs.clone()

probability_matrix = mask_labels

special_tokens_mask = [
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
]
probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
if self.tokenizer.pad_token is not None:
padding_mask = labels.eq(self.tokenizer.pad_token_id)
probability_matrix.masked_fill_(padding_mask, value=0.0)

masked_indices = probability_matrix.bool()
labels[~masked_indices] = -100

indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
inputs[indices_random] = random_words[indices_random]

return inputs, labels

# 重写魔术方法,可以把类的对象当做函数去调用
def __call__(self, examples: list) -> dict:
# pad前的(句子不一样长,需要填充)
input_ids_list, token_type_ids_list, attention_mask_list = list(zip(*examples))
# 动态识别batch中最大长度,用于padding操作
cur_max_seq_len = max(len(input_id) for input_id in input_ids_list)
# 如果这一批中,所有句子都比设定的最大长度还小,那直接使用该批次的最大长度,
# 可以减少运算数据量,加快速度
# 如果这一批中有句子比设定的最大长度还长,后续就会被截断
max_seq_len = min(cur_max_seq_len, self.max_seq_len)

# pad后的
input_ids, token_type_ids, attention_mask = self.truncate_and_pad(
input_ids_list, token_type_ids_list, attention_mask_list, max_seq_len
)

# 遮蔽单词,whole word mask策略
batch_mask = self.whole_word_mask(input_ids_list, max_seq_len)
# 针对得到的需要mask的词,进行实际mask操作(80%进行mask掉,10%进行随机替换,10%选择保持不变)
input_ids, mlm_labels = self.mask_tokens(input_ids, batch_mask)
data_dict = {
'input_ids': input_ids,
'attention_mask': attention_mask,
'token_type_ids': token_type_ids,
'labels': mlm_labels
}
return data_dict

测试数据样例

输入data collator的数据一般是BertTokenizer encode_plus得到的输出,即

  • input_ids:输入句子中每个词的编号(在词表中的序号),101代表[cls],102代表[sep];
  • token_type_ids:单词属于哪个句子,第一个句子为0,第二句子为1;
  • attention_mask:需要对哪些单词做self_attention。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
input_ids = [
[101, 4078, 3828, 7029, 4344, 2768, 2642, 8024, 1220, 4289, 924, 2844, 5442, 1316, 2456, 21128, 7344, 4344, 7270, 1814, 21129, 2828, 3315, 1759, 4289, 4905, 1750, 1075, 2768, 4635, 4590, 102],
...
]
token_type_ids = [
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
...
]
attention_mask = [
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
...
]

data = (input_ids, token_type_ids, attention_mask)
data_collator = DataCollator()
data_collator(data) # 直接调用__call__方法