BERT finetune

BERT finetune

本文代码部分参考github项目:
https://github.com/BSlience/search-engine-zerotohero/tree/main/public/bert_wwm_pretrain

项目概述

本文的主要内容是基于huggingface transformer的chinese-bert-wwm模型,在自己的语料集上进行finetune的整体步骤和代码实现。

关于chinese-bert-wwm:

https://huggingface.co/hfl/chinese-bert-wwm

https://github.com/ymcui/Chinese-BERT-wwm

主要步骤包括:预处理和训练两个部分

预处理(pre-processing)

  1. 下载chinese-bert-wwm模型的预训练词表(vocab.txt)、config.json和pytorch_model.bin;
  2. 读取自己的原始数据集(比如大量的文章、文本),做句子分割,然后保存成语料集;
  3. 根据自己的语料集进行分词(BERT是分割成单个字),并将自己语料集中相比原始的词表多的字(或者词)添加到原始词表中(就是一个扩充操作),然后就生成了自己的词表;

步骤1的下载地址:https://huggingface.co/hfl/chinese-bert-wwm/tree/main

如果生成的语料集比较大,为了后续加载方便可以存储至内存型的数据库中(比如Redis)

训练(train)

  1. 加载BertTokenizer,读取语料集,生成数据;
  2. 针对1中得到的数据,进行填充、截断和mask等操作,借助data collator类;
  3. 加载chinese-bert-wwm模型的预训练权重文件,基于当前数据开始训练(微调);
  4. 保存模型,测试效果。

关于data collator的实现可以查看我的上一篇文章whole word mask;

步骤3加载的预训练权重文件就是上述预处理步骤1下载的,不过要注意把config.json和pytorch_model.bin放在同一个目录下,然后加载这个目录即可。

预处理

这里原始数据(大量的知乎文章)存储在mongodb中,我们读取出来,然后执行预处理步骤2、3

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
def clean_html_tag(data: str):
"""
清除给定文本中包含的HTML标签
:param data:
:return:
"""
if isinstance(data, float):
print(data)
return ""
result = bs(data).get_text()
return result


def read_data_from_mongodb(mongodb_url, db, collection):
"""
从mongodb中读取数据
:param mongodb_url: 连接mongodb的地址
:param db: 数据库名称
:param collection: 集合名称
:return:
"""
client = MongoClient(mongodb_url)
collection = client[db][collection]
data = defaultdict(list) # 可以将字典的value设置list类型
for each in tqdm(collection.find(batch_size=10)): # tqdm用于显示进度
try:
if each["title"]:
title = clean_html_tag(each["title"])
data['title'].append(title)
if each["excerpt"]:
summary = clean_html_tag(each['excerpt'])
data['summary'].append(summary)
if each['content']:
clean_content = clean_html_tag(each['content'])
data['content'].append(clean_content)
except Exception as _e:
print(_e)
print(each)
return data


def get_split_sentences(data):
split_sentence = SplitSentence()
data_new = defaultdict(list)
for key, value in data.items():
if key == 'title':
data_new[key].extend(data[key])
else:
sentences_list = []
for sentences in data[key]:
for each in split_sentence.split_sentence(sentences):
sentences_list.append(each)
data_new[key].extend(sentences_list)
return data_new


def pre_processing(redis_host, redis_port, redis_pwd, mongodb_host, db, collection):
"""
预处理操作,包括从mongodb读取数据,分割、存放至redis
:param redis_host:
:param redis_port:
:param redis_pwd
:param mongodb_host:
:param db:
:param collection:
:return:
"""
print('start reading data....')
data = read_data_from_mongodb(mongodb_host, db, collection)
print('read data from mongodb finished')
print('*' * 50)
print('start saving data')
data = get_split_sentences(data)
res = redis.StrictRedis(host=redis_host, port=redis_port, db=0, password=redis_pwd)
res.set('sentences', json.dumps(data))
print('save data to redis finished')


def main():
# 预训练词表存储的位置
# 下载地址:https://huggingface.co/hfl/chinese-bert-wwm/tree/main
original_vocab_file_path = 'data/chinese_bert_wwm/vocab.txt'
# 语料存储位置
corpus_file_path = 'data/pretrain_corpus.txt'
# 自己的词表(训练效果一般取决于,词表中的词在语料中出现的次数,如果重要的词在语料中只出现了一次效果就不好)
vocab_file_path = 'data/vocab.txt'

# mongo和redis配置
mongodb_config = {'host': 'mongodb://127.0.0.1:27017', 'db': 'zhihu_new', 'collection': 'articles'}
redis_config = {'host': '127.0.0.1', 'port': '6379', 'db_index': '0', 'pwd': 'xxx'}
# 预处理,生成预料(存至Redis,方便后续使用)
pre_processing(
redis_host=redis_config['host'], redis_port=redis_config['port'], redis_pwd=redis_config['pwd'],
mongodb_host=mongodb_config['host'], db=mongodb_config['db'], collection=mongodb_config['collection']
)
res = redis.StrictRedis(
host=redis_config['host'],
port=redis_config['port'],
db=0,
password=redis_config['pwd']
)
data = json.loads(res.get('sentences'))
data_list = []
for each in data.keys():
data_list.extend(data[each])
print('start producing corpus')
# 保存语料到文件
save_corpus(data_list, corpus_file_path)
print('start producing vocab')
# 生成词表(其实就是单个的字符)
generate_vocab(data_list, vocab_file_path, original_vocab_file_path)


if __name__ == "__main__":
main()

这里句子分割单独封装了一个类SplitSentence,具体实现如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
def replace_with_separator(text, separator, regexs):
replacement = r"\1" + separator + r"\2"
result = text
for regex in regexs:
result = regex.sub(replacement, result)
return result


class SplitSentence:
"""
这个分割的方法需要根据你的数据集调整,比如针对一些没有分开的句子,添加分割符到这里
"""
def __init__(self):
self.separator = r'@'
self.re_sentence = re.compile(r'(\S.+?[.!?])(?=\s+|$)|(\S.+?)(?=[\n]|$)', re.UNICODE)
self.ab_senior = re.compile(r'([A-Z][a-z]{1,2}\.)\s(\w)', re.UNICODE)
self.ab_acronym = re.compile(r'(\.[a-zA-Z]\.)\s(\w)', re.UNICODE)
self.undo_ab_senior = re.compile(r'([A-Z][a-z]{1,2}\.)' + self.separator + r'(\w)', re.UNICODE)
self.undo_ab_acronym = re.compile(r'(\.[a-zA-Z]\.)' + self.separator + r'(\w)', re.UNICODE)

def split_sentence(self, text, best=True):
# 句子分割,主要是通过标点符号,如果分割结果发现有些句子分割效果不好,再增加相应的分割符号
text = re.sub('([。!??])([^”’])', r"\1\n\2", text)
text = re.sub('(\.{6})([^”’])', r"\1\n\2", text)
text = re.sub('(…{2})([^”’])', r"\1\n\2", text)
text = re.sub('([。!??][”’])([^,。!??])', r'\1\n\2', text)
for chunk in text.split("\n"):
chunk = chunk.strip()
if not chunk:
continue
if not best:
yield chunk
continue
processed = replace_with_separator(chunk, self.separator, [self.ab_senior, self.ab_acronym])
for sentence in self.re_sentence.finditer(processed):
sentence = replace_with_separator(sentence.group(), r" ",
[self.undo_ab_senior, self.undo_ab_acronym])
yield sentence

上面是pre_processing()方法的实现,接下来还有保存语料到文件和生成词表的操作

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def save_corpus(data, corpus_file_path):
# 语料其实就是一段段的文本(分割后的一句话)
with open(corpus_file_path, 'w', encoding='utf-8') as f:
for row in tqdm(data, total=len(data)):
f.write(row + '\n')


def generate_vocab(total_data, vocab_file_path, original_vocab_file_path):
# 以单个的字作为词表(BERT用的是字,也有其他方法是用词的)
total_tokens = [token for sent in total_data for token in sent]
counter = Counter(total_tokens)
vocab = [token for token, freq in counter.items()]
# 更新下载的预训练词表,也就是把自己词表添加到原始词表中
# 如果只使用自己的词表,则无法fine tune成功,一定是扩充原始词表
original_vocab = []
with open(original_vocab_file_path, 'r', encoding='utf-8') as f:
for line in f.readlines():
line = line.strip('\n')
original_vocab.append(line)
need_add_token = [each for each in vocab if each not in original_vocab]

original_vocab.extend(need_add_token)
with open(vocab_file_path, 'w', encoding='utf-8') as f:
f.write('\n'.join(original_vocab))

上面的代码是把所有的语料都放到列表中,然后全部操作完再一条条的写到文件中,很容易在执行的时候超出内存卡死,所以我优化了下,读取一条、处理一条、保存一条,运行瞬间丝滑。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""
根据原项目自己优化的版本
原项目会把所有的句子都放在内存中,很容易电脑内存超出了
这里改为一句一句的读,分割,然后存到文件,并且不使用redis
"""

from bs4 import BeautifulSoup as bs
from pymongo import MongoClient
from tqdm import tqdm

from processing import SplitSentence


def clean_html_tag(data: str):
"""
清除给定文本中包含的HTML标签
:param data:
:return:
"""
if isinstance(data, float):
print(data)
return ""
result = bs(data).get_text()
return result


def read_data_from_mongodb(mongodb_url, db, collection):
"""
从mongodb中读取数据
:param mongodb_url: 连接mongodb的地址
:param db: 数据库名称
:param collection: 集合名称
:return:
"""
client = MongoClient(mongodb_url)
collection = client[db][collection]
for each in tqdm(collection.find(batch_size=10)): # tqdm用于显示进度
data = {}
try:
for key in ['title', 'excerpt', 'content']:
if each[key]:
value = clean_html_tag(each[key])
data[key] = value
except Exception as _e:
print(_e)
print(each)
else:
yield data


def generate_vocab(corpus_file_path, vocab_file_path, original_vocab_file_path):
# 以单个的字作为词表(BERT用的是字,也有其他方法是用词的)
vocab = set() # 词表,不重复
with open(corpus_file_path, 'r', encoding='utf-8') as fr:
for line in fr.readlines():
for word in line:
vocab.add(word)
# 更新下载的预训练词表,也就是把自己词表添加到原始词表中
# 如果只使用自己的词表,则无法fine tune成功,一定是扩充原始词表
original_vocab = []
with open(original_vocab_file_path, 'r', encoding='utf-8') as f:
for line in f.readlines(): # 原始词表中,每行就是一个字符
line = line.strip('\n')
original_vocab.append(line)
need_add_token = vocab.difference(set(original_vocab))
# new_vocab = vocab.union(original_vocab) # 并集(这样会打乱原始此表的顺序)
with open(vocab_file_path, 'w', encoding='utf-8') as f:
f.write('\n'.join(original_vocab))
f.write('\n')
f.write('\n'.join(need_add_token))


def main():
# 预训练词表存储的位置
# 下载地址:https://huggingface.co/hfl/chinese-bert-wwm/tree/main
original_vocab_file_path = 'data/chinese_bert_wwm/vocab.txt'
# 语料存储位置
corpus_file_path = 'data/pretrain_corpus.txt'
# 自己的词表(训练效果一般取决于,词表中的词在语料中出现的次数,如果重要的词在语料中只出现了一次效果就不好)
vocab_file_path = 'data/vocab.txt'

# mongodb配置
mongodb_config = {'host': 'mongodb://127.0.0.1:27017', 'db': 'zhihu_new', 'collection': 'articles'}

# 读取数据预处理,分割句子,生成语料
data_list = read_data_from_mongodb(
mongodb_url=mongodb_config['host'],
db=mongodb_config['db'],
collection=mongodb_config['collection']
)
# 语料文件,每行一句文本
print('start producing corpus')
fw = open(corpus_file_path, 'w', encoding='utf-8')

for data in data_list: # data_list是一个生成器
split_sentence = SplitSentence()
for key, value in data.items():
if key == 'title':
fw.write(value+"\n")
else:
for each in split_sentence.split_sentence(value):
fw.write(each+"\n")
fw.close()
print('saving corpus finished')

print('start producing vocab')
# 生成词表(其实就是单个的字符)
generate_vocab(corpus_file_path, vocab_file_path, original_vocab_file_path)


if __name__ == "__main__":
main()

不过这里需要注意的是,扩充原词表时,不要改变原始词表的顺序,保持原词表顺序不变,在后面添加新词,新添加的顺序无所谓。

训练

先配置一个整体的配置文件,方便后面管理使用

1
2
3
4
5
6
7
8
9
10
11
CONFIG = {
'corpus_file_path': 'data/pretrain_corpus.txt', # 训练样本(语料)
'vocab_file_path': 'data/vocab.txt', # 这个词表是根据自己的数据集扩充过的
'redis_url': '127.0.0.1',
'redis_port': 6379,
'max_seq_len': 102,
'batch_size': 32,
'output_dir': 'data/whole_word_mask_bert_output',
'bert_model_dir': 'data/chinese_bert_wwm', # 存放config.json和pytorch_model.bin的路径
'debug': False # 调试用的
}

训练代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
def seed_everyone(seed_):
torch.manual_seed(seed_)
torch.cuda.manual_seed_all(seed_)
np.random.seed(seed_)
random.seed(seed_)
return seed_


def check_dir(path):
if not os.path.exists(path):
os.makedirs(path)


class SearchDataSet(Dataset):
def __init__(self, data_dict: dict):
self.data_dict = data_dict

# 重写魔术方法,可以通过索引来访问对象
def __getitem__(self, index: int) -> tuple:
data = (self.data_dict['input_ids'][index],
self.data_dict['token_type_ids'][index],
self.data_dict['attention_mask'][index])
return data

def __len__(self) -> int:
return len(self.data_dict['input_ids'])


def read_data(train_file_path, tokenizer: BertTokenizer, debug=False) -> dict:
train_data = open(train_file_path, 'r', encoding='utf-8').readlines()
if debug:
train_data = train_data[:2000]
inputs = defaultdict(list)
for row in tqdm(train_data, desc='Preprocessing train data', total=len(train_data)):
sentence = row.strip()
# encode
inputs_dict = tokenizer.encode_plus(sentence, add_special_tokens=True,
return_token_type_ids=True, return_attention_mask=True)
inputs['input_ids'].append(inputs_dict['input_ids'])
inputs['token_type_ids'].append(inputs_dict['token_type_ids'])
inputs['attention_mask'].append(inputs_dict['attention_mask'])
return inputs


def main():
seed_everyone(20220531) # 统一设置随机数种子
train_file_path = CONFIG['corpus_file_path']
# 加载预训练的分词器
tokenizer = BertTokenizer.from_pretrained(CONFIG['vocab_file_path'], local_file_only=True)
# 使用分词器读取数据(语料)
data = read_data(train_file_path, tokenizer, CONFIG['debug'])

train_dataset = SearchDataSet(data)
# huggingface transformer中特有的概念,data collator 数据修补(截断和填充)
# https://huggingface.co/docs/transformers/main/en/main_classes/data_collator#transformers.DataCollatorForLanguageModeling
# 上面的train_dataset中只有单独的句子,还没有标签(BERT中就是被MASK的值,随机遮蔽一些值,然后预测)
data_collator = SearchCollator(max_seq_len=CONFIG['max_seq_len'], tokenizer=tokenizer, mlm_probability=0.15)
# 测试
data_collator(list(train_dataset))

output_dir = CONFIG['output_dir']
model = BertForMaskedLM.from_pretrained(CONFIG['bert_model_dir'])

model_save_dir = (os.path.join(output_dir, 'best_model_dir'))
tokenizer_and_config = os.path.join(output_dir, 'tokenizer_and_config')
check_dir(model_save_dir)
check_dir(tokenizer_and_config)

training_args = TrainingArguments(
output_dir=output_dir,
overwrite_output_dir=True,
num_train_epochs=20,
fp16_backend='auto',
per_device_train_batch_size=128,
save_steps=500,
logging_steps=500,
save_total_limit=5,
prediction_loss_only=True,
# report_to='comet_ml',
logging_first_step=True,
dataloader_num_workers=4,
disable_tqdm=False,
seed=202203
)

trainer = Trainer(
model=model,
args=training_args,
data_collator=data_collator,
train_dataset=train_dataset,
)

trainer.train()
trainer.save_model(model_save_dir)
tokenizer.save_pretrained(tokenizer_and_config)


if __name__ == "__main__":
main()

训练的时候可能会出现这个问题

Failed to create a directory: data/whole_word_mask_bert_output\runs\Jun21_16-47-59_user; No such file or directory

我在whole_word_mask_bert_output文件夹下又创建了一个runs文件夹解决了。

另外,如果你的debug配置忘记改为False,那么传入trainer.train()的数据只有2000条,是会报错的,IndexError: index out of range in self

报这个错误是embedding层的张量输入超过了合法范围,embedding层的合法张量输入数值范围应该在[0, num_embeddings-1]的范围内,过大过小都会报错。

关于tokenizer.encode_plus

1
2
3
4
5
6
7
8
9
10
11
tokenizer = BertTokenizer.from_pretrained(CONFIG['vocab_file_path'], local_file_only=True)

inputs_dict = tokenizer.encode_plus(
sentence,
add_special_tokens=True,
return_token_type_ids=True,
return_attention_mask=True
)
inputs['input_ids'].append(inputs_dict['input_ids'])
inputs['token_type_ids'].append(inputs_dict['token_type_ids'])
inputs['attention_mask'].append(inputs_dict['attention_mask'])

[CLS] 标志放在第一个句子的首位,经过 BERT 得到的的表征向量 C 可以用于后续的分类任务。
[SEP] 标志用于分开两个输入句子,例如输入句子 A 和 B,要在句子 A,B 后面增加 [SEP] 标志。
[UNK]标志指的是未知字符
[MASK] 标志用于遮盖句子中的一些单词,将单词用 [MASK] 遮盖之后,再利用 BERT 输出的 [MASK] 向量预测单词是什么。

特征抽取

训练(finetune)完成后,就可以使用训练得到的模型来抽取文本的特征,其实这里所说的抽取文本的特征实际就是把自然语言文本转为向量,我们直接使用原始的模型也是可以进行特征抽取的,只不过在自己的数据集上finetune之后效果会更好,而具体效果要在下游的实际任务中才能评估,仅通过finetune后的模型来将文本转为特征向量无法评估效果的好坏。

特征抽取代码示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
"""
使用BERT抽取自然语言特征
过程分为两步:1、在自己的数据集上进行finetune;2、利用finetune后的模型进行抽取
这里是特征抽取过程
"""
import torch
from transformers import BertModel, BertTokenizer


class TextVector:
def __init__(self, device_id):
self.tokenizer = BertTokenizer.from_pretrained(
# 文件夹下包括config.json、pytorch_model.bin、vocab.txt
'./data/best_model_ckpt'
)
self.model = BertModel.from_pretrained(
'./data/best_model_ckpt'
)
if torch.cuda.is_available():
self.device = "cuda:" + str(device_id)
self.model.to(self.device)
print(f"bert model 加载到了cuda:{self.device}.")
else:
self.device = 'cpu'

def run(self, data):
# bert长度限制一般为512,超过长度截断(长度过长会导致参数量太大)
if len(data) > 510:
data = data[:510]
inputs = self.tokenizer(data, return_tensors='pt')
print(inputs)
# 将数据放入cpu或者gpu
inputs = {key: value.to(self.device) for key, value in inputs.items()}
outputs = self.model(**inputs)
print(outputs.pooler_output.detach().size()) # tensor类型, torch.Size([1, 768])
print(outputs.pooler_output.detach().to("cpu").numpy().shape) # 转为numpy, shape(1, 768)
# 这里取0号元素后,得到的就是一个列表,reshape(1, -1)转为1行,列自动计算,又变成了(1, 768)
data_vector = outputs.pooler_output.detach().to("cpu").numpy()[0].reshape(1, -1)
return data_vector[0].tolist()


text_vector = TextVector(device_id=0)
res = text_vector.run("我喜欢学习")
print(len(res)) # 768
print(res)