本文代码和实际情况有所出入,写本文主要是通过文章来帮助刚入门的同学快速理解BERT原文中的思路,文章分为三部分:数据集的构建和选取,网络结构和loss计算
数据集的构建和提取
数据集的格式
在BERT原文中采用了Bookcorpus和Wikipedia数据集,并且是document-level的,也就是说他们的sequence选取是在文档中截取的连续token,类似于如下形式
i wish i had a better answer to that question .
然后采用WordPiece的方法对其进行截取,截取完之后得到如下token,一共9个token,并用数字对其编号
i wish had a better answer to that question
4, 5, 6, 7, 8, 9, 10, 11, 12
至于为什么从4开始,这是因为原文中有「CLS」,「MASK」,「Seq」等token,我们对其编码0,1,2等。以此类推,对于整个数据集,语料库就是这样构建的,按照原文,一共30000个token,即编号从0一直到30000
数据集的提取
构建完数据集之后,我们需要对其进行提取,需要编写dataset.py文件,这里展示getitem函数的一些核心内容
- 首先从文档中提取两句话(这里为了方便,我只提取一句话,但是将这句话分为两段),返回t1,t2和is_next_label,其中is_next_label表示这两个sentence是不是连续的,代码如下
1 | t1, t2, is_next_label = self.random_sent(item) |
- 随后提取t1和t2的token编号(就是之前介绍的数字编号)以及进行mask操作, 看最后的返回,t1_random就是返回的编号,比如在这里i这个单词对应的索引编号为12,have的编号为37,t1_label都为0表示这个t1_random没有任何单词被mask过,我们再来看t2_random,其中有一个为4,代表被mask了(因为事先已经将mask编号为了4),这里tell和".\n"被mask了,t2_label里面有两个非0元素,182和5,表示这两个被mask的token被mask实际编号为182和5,我们之后MLM预测时就要预测这两个分类值。
1 | # mask以及mask位置原来的label |
- 这一步是额外加的,因为我这里的sentence太短了,我的toy example只设置了max sentence为20,所以会出现没有mask的情况,也就是上面的t1_label,这样计算MLM的时候会出现Nan问题(但是在训练大模型中不会出现这个问题,因为大模型的token很长,有512或者1024等,也就是说有500+的单词,而且mask的概率也是提前设置的,所以会mask固定的数量,不会存在没被mask的情况)。所以我们接下来check以下t1_label和t2_label是否有没有被mask的,如果没有,将没有被mask的sentence至少mask一个。这里只有t1_label存在非mask的元素,所以check后只有t1变化了
1 | # 因为这里的token很短,会出现没有token被mask的情况 |
- 这里我们加上「CLS」和「SEP」这两个token,他们在我们的设置中编号分别为3和2,然后将t1_label和t2_label也进行相应更改,因为不是mask区域,所以加入0即可
1 | # [CLS] tag = SOS tag, [SEP] tag = EOS tag |
- 最后定义一下segment_label,即句子的标号,由于我们设置了最大sentence的长度,截取一下即可,注意如果整个sentence达不到预先定义的长度,比如本例一共18,我预先设置的max为20,就会进行padding,将18填充到20,padding的编号默认为0,注意箭头后面是我特意标定的补充值
1 | segment_label = ([1 for _ in range(len(t1))] + [2 for _ in range(len(t2))])[:self.seq_len] |
至此数据集的预处理和提取就好了,可以输入到网络进行训练
网络结构
网络结构分为两种
- BERT Base
- BERT Large
Loss计算
网络输出有两个,分别是MLM的输出和NSP的输出,其中二者均是分类问题
其中MLM输出的是mask部分的结果,由于非mask部分的编号为0,所以loss=nn.NLLLoss(ignore_index=0)要忽略掉0类,如果在网络输出没用logsoftmax,这里要用crossentropy损失
NSP输出的是0或者1,loss=nn.NLLLoss()就不需要忽略掉0类了