To pretrain the BERT model as implemented in Section 15.8, we need to generate the dataset in an ideal format to facilitate the two pretraining tasks: masked language modeling and next sentence prediction. On the one hand, the original BERT model is pretrained on the concatenation of two huge corpora, BookCorpus and English Wikipedia (see Section 15.8.5), making it hard to run for most readers of this book. On the other hand, an off-the-shelf pretrained BERT model may not fit applications in specific domains such as medicine. Thus, it is getting popular to pretrain BERT on a customized dataset. To facilitate the demonstration of BERT pretraining, we use a smaller corpus, WikiText-2 (Merity et al., 2016).
Compared with the PTB dataset used for pretraining word2vec in Section 15.3, WikiText-2 (i) retains the original punctuation, making it suitable for next sentence prediction; (ii) retains the original case and numbers; and (iii) is over twice as large.
In the WikiText-2 dataset, each line represents a paragraph where a space is inserted between any punctuation mark and its preceding token. Paragraphs with at least two sentences are retained. To split sentences, we only use the period as the delimiter for simplicity. We leave discussions of more complex sentence splitting techniques to the exercises at the end of this section.
import os
import random
# `d2l` provides the shared utilities from earlier chapters; the PyTorch
# flavor of the package is assumed here
from d2l import torch as d2l

#@save
d2l.DATA_HUB['wikitext-2'] = (
    'https://s3.amazonaws.com/research.metamind.io/wikitext/'
    'wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe')

#@save
def _read_wiki(data_dir):
    file_name = os.path.join(data_dir, 'wiki.train.tokens')
    with open(file_name, 'r') as f:
        lines = f.readlines()
    # Uppercase letters are converted to lowercase ones
    paragraphs = [line.strip().lower().split(' . ')
                  for line in lines if len(line.split(' . ')) >= 2]
    random.shuffle(paragraphs)
    return paragraphs
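As a quick sanity check, the corpus can be downloaded via the 'wikitext-2' entry registered above and read back. This is a minimal sketch, assuming the d2l download utilities from earlier chapters are available:

data_dir = d2l.download_extract('wikitext-2', 'wikitext-2')
paragraphs = _read_wiki(data_dir)
# Each paragraph is a list of sentences (still raw strings at this point)
print(len(paragraphs), paragraphs[0][:2])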
15.9.1. Defining Helper Functions for Pretraining Tasks
In the following, we begin by implementing helper functions for the two BERT pretraining tasks: next sentence prediction and masked language modeling. These helper functions will be invoked later when transforming the raw text corpus into a dataset of the ideal format for pretraining BERT.
15.9.1.1. Generating the Next Sentence Prediction Task
According to the descriptions in Section 15.8.5.2, the `_get_next_sentence` function generates a training example for the binary classification task.
#@save
def _get_next_sentence(sentence, next_sentence, paragraphs):
    if random.random() < 0.5:
        is_next = True
    else:
        # `paragraphs` is a list of lists of lists
        next_sentence = random.choice(random.choice(paragraphs))
        is_next = False
    return sentence, next_sentence, is_next
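To see what this helper returns, we can call it on a tiny hand-made corpus. This is a minimal, hypothetical example; the toy sentences below are purely illustrative:

paragraphs_toy = [[['hello', 'world', '.'], ['nice', 'to', 'meet', 'you', '.']],
                  [['the', 'cat', 'sat', '.'], ['it', 'purred', '.']]]
sentence, next_sentence, is_next = _get_next_sentence(
    paragraphs_toy[0][0], paragraphs_toy[0][1], paragraphs_toy)
# About half of the time `is_next` is True and `next_sentence` is the true
# successor; otherwise `next_sentence` is a random sentence from the corpus
print(is_next, next_sentence)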
The following function generates training examples for next sentence prediction from the input `paragraph` by invoking the `_get_next_sentence` function. Here `paragraph` is a list of sentences, where each sentence is a list of tokens. The argument `max_len` specifies the maximum length of a BERT input sequence during pretraining.
#@save
def _get_nsp_data_from_paragraph(paragraph, paragraphs, vocab, max_len):
    nsp_data_from_paragraph = []
    for i in range(len(paragraph) - 1):
        tokens_a, tokens_b, is_next = _get_next_sentence(
            paragraph[i], paragraph[i + 1], paragraphs)
        # Consider 1 '<cls>' token and 2 '<sep>' tokens
        if len(tokens_a) + len(tokens_b) + 3 > max_len:
            continue
        tokens, segments = d2l.get_tokens_and_segments(tokens_a, tokens_b)
        nsp_data_from_paragraph.append((tokens, segments, is_next))
    return nsp_data_from_paragraph
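For a paragraph with two tokenized sentences, this yields at most one (tokens, segments, is_next) triple. Below is a minimal sketch that reuses the toy paragraphs above; since the `vocab` argument is not used inside this function, None is passed for it:

examples = _get_nsp_data_from_paragraph(
    paragraphs_toy[0], paragraphs_toy, None, max_len=64)
tokens, segments, is_next = examples[0]
# `tokens` starts with '<cls>' and ends each of the two sentences with '<sep>';
# `segments` marks whether a position belongs to the first or second sentence
print(tokens, segments, is_next)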
15.9.1.2. Generating the Masked Language Modeling Task
To generate training examples for the masked language modeling task from a BERT input sequence, we define the following `_replace_mlm_tokens` function. In its inputs, `tokens` is a list of tokens representing a BERT input sequence, `candidate_pred_positions` is a list of token indices of the BERT input sequence excluding those of special tokens (special tokens are not predicted in the masked language modeling task), and `num_mlm_preds` indicates the number of predictions (recall that 15% of random tokens are predicted). Following the definition of the masked language modeling task in Section 15.8.5.1, at each prediction position the input may be replaced by a special '<mask>' token or a random token, or remain unchanged. In the end, the function returns the input tokens after possible replacement, the token indices where predictions take place, and the labels for these predictions.
#@save
def _replace_mlm_tokens(tokens, candidate_pred_positions, num_mlm_preds,
                        vocab):
    # For the input of a masked language model, make a new copy of tokens and
    # replace some of them by '<mask>' or random tokens
    mlm_input_tokens = [token for token in tokens]
    pred_positions_and_labels = []
    # Shuffle for getting 15% random tokens for prediction in the masked
    # language modeling task
    random.shuffle(candidate_pred_positions)
    for mlm_pred_position in candidate_pred_positions:
        if len(pred_positions_and_labels) >= num_mlm_preds:
            break
        masked_token = None
        # 80% of the time: replace the word with the '<mask>' token
        if random.random() < 0.8:
            masked_token = '<mask>'
        else:
            # 10% of the time: keep the word unchanged
            if random.random() < 0.5:
                masked_token = tokens[mlm_pred_position]
            # 10% of the time: replace the word with a random word
            else:
                masked_token = random.choice(vocab.idx_to_token)
        mlm_input_tokens[mlm_pred_position] = masked_token
        pred_positions_and_labels.append(
            (mlm_pred_position, tokens[mlm_pred_position]))
    return mlm_input_tokens, pred_positions_and_labels
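The effect of the 80%/10%/10% replacement rule can be inspected on a short sequence. This is a minimal sketch, assuming the d2l.Vocab class from earlier chapters; the toy vocabulary built here is only for illustration:

toy_tokens = ['<cls>', 'the', 'cat', 'sat', 'on', 'the', 'mat', '<sep>']
toy_vocab = d2l.Vocab([toy_tokens],
                      reserved_tokens=['<pad>', '<mask>', '<cls>', '<sep>'])
# Positions 1-6 hold the non-special tokens that are eligible for masking
mlm_input_tokens, pred_positions_and_labels = _replace_mlm_tokens(
    toy_tokens, list(range(1, 7)), num_mlm_preds=2, vocab=toy_vocab)
print(mlm_input_tokens, pred_positions_and_labels)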
By invoking the aforementioned `_replace_mlm_tokens` function, the following function takes a BERT input sequence (`tokens`) as the input and returns the indices of the input tokens (after possible token replacement as described in Section 15.8.5.1), the token indices where predictions take place, and the label indices for these predictions.
#@save
def _get_mlm_data_from_tokens(tokens, vocab):
    candidate_pred_positions = []
    # `tokens` is a list of strings
    for i, token in enumerate(tokens):
        # Special tokens are not predicted in the masked language modeling
        # task
        if token in ['<cls>', '<sep>']:
            continue
        candidate_pred_positions.append(i)
    # 15% of random tokens are predicted in the masked language modeling task
    num_mlm_preds = max(1, round(len(tokens) * 0.15))
    mlm_input_tokens, pred_positions_and_labels = _replace_mlm_tokens(
        tokens, candidate_pred_positions, num_mlm_preds, vocab)
    pred_positions_and_labels = sorted(pred_positions_and_labels,
                                       key=lambda x: x[0])
    pred_positions = [v[0] for v in pred_positions_and_labels]
    mlm_pred_labels = [v[1] for v in pred_positions_and_labels]
    return vocab[mlm_input_tokens], pred_positions, vocab[mlm_pred_labels]
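Continuing the toy example, the masked language modeling data for one input sequence can be generated as follows (a minimal sketch; toy_tokens and toy_vocab are the illustrative objects defined above):

token_ids, pred_positions, mlm_label_ids = _get_mlm_data_from_tokens(
    toy_tokens, toy_vocab)
# `token_ids` are the (possibly masked) input token indices, `pred_positions`
# the predicted positions, and `mlm_label_ids` the indices of the original tokens
print(token_ids, pred_positions, mlm_label_ids)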