谷歌NLP模型的官方TensorFlow實(shí)現(xiàn)很強(qiáng),現(xiàn)在,它的PyTorch版本來(lái)了!只需簡(jiǎn)單運(yùn)行一次轉(zhuǎn)換腳本,就可得到一個(gè)PyTorch模型,且結(jié)果與原始版本相近,甚至更好。
上周,谷歌最強(qiáng)NLP模型BERT開源了官方TensorFlow代碼和預(yù)訓(xùn)練模型,引起大量關(guān)注。
現(xiàn)在,PyTorch用戶的福利來(lái)了:一個(gè)名為Hugging Face的團(tuán)隊(duì)近日公開了BERT模型的谷歌官方TensorFlow庫(kù)的op-for-op PyTorch重新實(shí)現(xiàn):
https://github.com/huggingface/pytorch-pretrained-BERT
這個(gè)實(shí)現(xiàn)可以為BERT加載任何預(yù)訓(xùn)練的TensorFlow checkpoint(特別是谷歌的官方預(yù)訓(xùn)練模型),并提供一個(gè)轉(zhuǎn)換腳本。
BERT-base和BERT-large模型的參數(shù)數(shù)量分別為110M和340M,為了獲得良好的性能,很難使用推薦的batch size在單個(gè)GPU上對(duì)其進(jìn)行微調(diào)。為了幫助微調(diào)模型,這個(gè)repo還提供了3種可以在微調(diào)腳本中激活技術(shù):梯度累積(gradient-accumulation)、multi-GPU和分布式訓(xùn)練。
其結(jié)果如下:
在序列級(jí)MRPC分類任務(wù)上,該實(shí)現(xiàn)使用小型BERT-base模型再現(xiàn)了原始實(shí)現(xiàn)的84%-88%的準(zhǔn)確率。
在token級(jí)的SQuAD 任務(wù)上,該個(gè)實(shí)現(xiàn)使用小型BERT-base模型再現(xiàn)了原始實(shí)現(xiàn)的88.52 F1的結(jié)果。
作者表示,正致力于在其他任務(wù)以及更大的BERT模型上重現(xiàn)結(jié)果。
BERT模型的PyTorch實(shí)現(xiàn)
這個(gè)存儲(chǔ)庫(kù)包含了谷歌BERT模型的官方TensorFlow存儲(chǔ)庫(kù)的op-for-op PyTorch重新實(shí)現(xiàn)。谷歌的官方存儲(chǔ)庫(kù)是與BERT論文一起發(fā)布的:BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding,作者是Jacob Devlin、Ming-Wei Chang、Kenton Lee和Kristina Toutanova。
這個(gè)實(shí)現(xiàn)可以為BERT加載任何預(yù)訓(xùn)練的TensorFlow checkpoint(特別是谷歌的預(yù)訓(xùn)練模型),并提供了一個(gè)轉(zhuǎn)換腳本(見下文)。
此外,我們將在本周晚些時(shí)候添加多語(yǔ)言版本和中文版本的模型代碼。
腳本:加載任何TensorFlow檢查點(diǎn)
使用convert_tf_checkpoint_to_pytorch.py腳本,你可以在PyTorch保存文件中轉(zhuǎn)換BERT的任何TensorFlow檢查點(diǎn)(尤其是谷歌發(fā)布的官方預(yù)訓(xùn)練模型)。
這個(gè)腳本將TensorFlow checkpoint(以bert_model.ckpt開頭的三個(gè)文件)和相關(guān)的配置文件(bert_config.json)作為輸入,并為此配置創(chuàng)建PyTorch模型,從PyTorch模型的TensorFlow checkpoint加載權(quán)重并保存生成的模型在一個(gè)標(biāo)準(zhǔn)PyTorch保存文件中,可以使用 torch.load()導(dǎo)入(請(qǐng)參閱extract_features.py,run_classifier.py和run_squad.py中的示例)。
只需要運(yùn)行一次這個(gè)轉(zhuǎn)換腳本,就可以得到一個(gè)PyTorch模型。然后,你可以忽略TensorFlow checkpoint(以bert_model.ckpt開頭的三個(gè)文件),但是一定要保留配置文件(bert_config.json)和詞匯表文件(vocab.txt),因?yàn)镻yTorch模型也需要這些文件。
要運(yùn)行這個(gè)特定的轉(zhuǎn)換腳本,你需要安裝TensorFlow和PyTorch。該庫(kù)的其余部分只需要PyTorch。
下面是一個(gè)預(yù)訓(xùn)練的BERT-Base Uncased 模型的轉(zhuǎn)換過(guò)程示例:
export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12 python convert_tf_checkpoint_to_pytorch.py --tf_checkpoint_path $BERT_BASE_DIR/bert_model.ckpt --bert_config_file $BERT_BASE_DIR/bert_config.json --pytorch_dump_path $BERT_BASE_DIR/pytorch_model.bin
你可以在這里下載Google的預(yù)訓(xùn)練轉(zhuǎn)換模型:
https://github.com/google-research/bert#pre-trained-models
BERT的PyTorch模型
在這個(gè)庫(kù)里,我們提供了三個(gè)PyTorch模型,你可以在modeling.py中找到:
BertModel- 基本的BERT Transformer 模型
BertForSequenceClassification- 頂部帶有sequence classification head的BERT模型
BertForQuestionAnswering- 頂部帶有token classification head 的BERT模型,
以下是每類模型的一些細(xì)節(jié)。
1 . BertModel
BertModel是一個(gè)基本的BERT Transformer模型,包含一個(gè)summed token、位置和序列嵌入層,然后是一系列相同的self-attention blocks(BERT-base是12個(gè)blocks, BERT-large是24個(gè)blocks)。
輸入和輸出與TensorFlow 模型的輸入和輸出相同。
具體來(lái)說(shuō),該模型的輸入是:
input_ids:一個(gè)形狀為[batch_size, sequence_length]的torch.LongTensor,在詞匯表中包含單詞的token索引
token_type_ids:形狀[batch_size, sequence_length]的可選torch.LongTensor,在[0,1]中選擇token類型索引。類型0對(duì)應(yīng)于句子A,類型1對(duì)應(yīng)于句子B。
attention_mask:一個(gè)可選的torch.LongTensor,形狀為[batch_size, sequence_length],索引在[0,1]中選擇。
模型的輸出是由以下內(nèi)容組成的一個(gè)元組:
all_encoder_layers:一個(gè)大小為[batch_size, sequence_length,hidden_size]的torch.FloatTensor列表,它是每個(gè)注意塊末端隱藏狀態(tài)的完整序列列表(即BERT-base的12個(gè)完整序列,BERT-large的24個(gè)完整序列)
pooled_output:一個(gè)大小為[batch_size, hidden_size]的torch.FloatTensor,它是在與輸入(CLF)的第一個(gè)字符相關(guān)聯(lián)的隱藏狀態(tài)之上預(yù)訓(xùn)練的分類器的輸出,用于訓(xùn)練Next-Sentence任務(wù)(參見BERT的論文)。
extract_features.py腳本提供了有關(guān)如何使用這類模型的示例,該腳本可用于為給定輸入提取模型的隱藏狀態(tài)。
2 . BertForSequenceClassification
BertForSequenceClassification是一個(gè)fine-tuning 模型,包括BertModel,以及BertModel頂部的一個(gè)序列級(jí)分類器(sequence-level classifier)。
序列級(jí)分類器是一個(gè)線性層,它將輸入序列中第一個(gè)字符的最后隱藏狀態(tài)作為輸入(參見BERT論文中的圖3a和3b)。
run_classifier.py腳本提供了關(guān)于如何使用此類模型的示例,該腳本可用于使用BERT微調(diào)單個(gè)序列(或序列對(duì))分類器,例如用于MRPC任務(wù)。
3. BertForQuestionAnswering
BertForQuestionAnswering是一個(gè)fine-tuning 模型,包括BertModel,它在最后隱藏狀態(tài)的完整序列之上具有token級(jí)分類器(token-level classifiers)。
token-level 分類器將最后隱藏狀態(tài)的完整序列作為輸入,并為每個(gè)token計(jì)算得分,(參見BERT論文的圖3c和3d)。
run_squad.py腳本提供了有關(guān)如何使用此類模型的示例,該腳本可用于使用BERT微調(diào)token分類器,例如用于SQuAD任務(wù)。
安裝、要求、測(cè)試
這段代碼在Python 3.5+上進(jìn)行了測(cè)試。必備條件是:
PyTorch (> = 0.4.1)
tqdm
安裝dependencies:
pip install -r ./requirements.txt
測(cè)試文件夾中包含一系列測(cè)試,可以使用pytest運(yùn)行(如果需要,請(qǐng)安裝pytest: pip install pytest)。
你可以使用以下命令運(yùn)行測(cè)試:
python -m pytest -sv tests/大批量訓(xùn)練:梯度積累、多GPU、分布式訓(xùn)練
BERT-base和BERT-large的模型參數(shù)分別是110M和340M,為了獲得良好的性能(大多數(shù)情況下批大小是32),很難在單個(gè)GPU上對(duì)它們進(jìn)行微調(diào)。
為了幫助對(duì)這些模型進(jìn)行微調(diào),我們介紹了在微調(diào)腳本run_classifier.py和run_squad中可以激活的四種技術(shù):優(yōu)化CPU、梯度積累、multi-gpu和分布式訓(xùn)練。
有關(guān)如何使用這些技術(shù)的更多細(xì)節(jié),你可以閱讀這篇關(guān)于PyTorch批量訓(xùn)練技巧的文章:
https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255
BERT的微調(diào):運(yùn)行示例
我們展示了與原始實(shí)現(xiàn)相同的示例:在MRPC分類語(yǔ)料庫(kù)上微調(diào)sequence級(jí)分類器和在問題回答數(shù)據(jù)集SQuAD上微調(diào)token級(jí)分類器。
在運(yùn)行這些示例之前,應(yīng)該先下載GLUE數(shù)據(jù),并將其解壓縮到某個(gè)目錄$GLUE_DIR。還需下載BERT-Base checkpoint,將其解壓縮到某個(gè)目錄$BERT_BASE_DIR,并將其轉(zhuǎn)換為上一節(jié)所述的PyTorch版本。
這個(gè)示例代碼基于微軟研究意譯語(yǔ)料庫(kù)(MRPC)調(diào)優(yōu)了BERT-Base,在單個(gè)K-80上運(yùn)行只需不到10分鐘。
export GLUE_DIR=/path/to/glue python run_classifier.py --task_name MRPC --do_train --do_eval --do_lower_case --data_dir $GLUE_DIR/MRPC/ --vocab_file $BERT_BASE_DIR/vocab.txt --bert_config_file $BERT_BASE_DIR/bert_config.json --init_checkpoint $BERT_PYTORCH_DIR/pytorch_model.bin --max_seq_length 128 --train_batch_size 32 --learning_rate 2e-5 --num_train_epochs 3.0 --output_dir /tmp/mrpc_output/
基于原始實(shí)現(xiàn)的超參數(shù)進(jìn)行測(cè)試,評(píng)估結(jié)果達(dá)到84%到88%。
第二個(gè)示例是基于SQuAD問題回答任務(wù)微調(diào)BERT-Base。
export SQUAD_DIR=/path/to/SQUAD python run_squad.py --vocab_file $BERT_BASE_DIR/vocab.txt --bert_config_file $BERT_BASE_DIR/bert_config.json --init_checkpoint $BERT_PYTORCH_DIR/pytorch_model.bin --do_train --do_predict --do_lower_case --train_file $SQUAD_DIR/train-v1.1.json --predict_file $SQUAD_DIR/dev-v1.1.json --train_batch_size 12 --learning_rate 3e-5 --num_train_epochs 2.0 --max_seq_length 384 --doc_stride 128 --output_dir ../debug_squad/
使用之前的超參數(shù)進(jìn)行訓(xùn)練,得到如下結(jié)果:
{"f1": 88.52381567990474, "exact_match": 81.22043519394512}在GPU上微調(diào)BERT-large
上面列出的選項(xiàng)允許在GPU上很容易地對(duì)BERT-large進(jìn)行微調(diào),而不是像原始實(shí)現(xiàn)那樣使用TPU。
例如,針對(duì)SQuAD任務(wù)微調(diào)BERT-large模型,可以在服務(wù)器上用4個(gè)k-80在18個(gè)小時(shí)內(nèi)完成。我們的結(jié)果與TensorFlow的實(shí)現(xiàn)結(jié)果相似(實(shí)際上是略高):
{"exact_match": 84.56953642384106, "f1": 91.04028647786927}
為了得到這些結(jié)果,我們使用了以下組合:
多GPU訓(xùn)練(在多GPU服務(wù)器上自動(dòng)激活),
梯度累積
在CPU上執(zhí)行優(yōu)化步驟,將Adam的平均值存儲(chǔ)在RAM中。
以下是我們?cè)诖舜芜\(yùn)行中使用的超參數(shù)的完整列表:
python ./run_squad.py --vocab_file $BERT_LARGE_DIR/vocab.txt --bert_config_file $BERT_LARGE_DIR/bert_config.json --init_checkpoint $BERT_LARGE_DIR/pytorch_model.bin --do_lower_case --do_train --do_predict --train_file $SQUAD_TRAIN --predict_file $SQUAD_EVAL --learning_rate 3e-5 --num_train_epochs 2 --max_seq_length 384 --doc_stride 128 --output_dir $OUTPUT_DIR/bert_large_bsz_24 --train_batch_size 24 --gradient_accumulation_steps 2 --optimize_on_cpu
-
谷歌
+關(guān)注
關(guān)注
27文章
6224瀏覽量
107589 -
tensorflow
+關(guān)注
關(guān)注
13文章
330瀏覽量
61040 -
pytorch
+關(guān)注
關(guān)注
2文章
809瀏覽量
13777
原文標(biāo)題:橫掃各項(xiàng)NLP任務(wù)的BERT模型有了PyTorch實(shí)現(xiàn)!提供轉(zhuǎn)換腳本
文章出處:【微信號(hào):AI_era,微信公眾號(hào):新智元】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
Pytorch模型訓(xùn)練實(shí)用PDF教程【中文】
怎樣去解決pytorch模型一直無(wú)法加載的問題呢
將pytorch模型轉(zhuǎn)化為onxx模型的步驟有哪些
怎樣使用PyTorch Hub去加載YOLOv5模型
通過(guò)Cortex來(lái)非常方便的部署PyTorch模型
將Pytorch模型轉(zhuǎn)換為DeepViewRT模型時(shí)出錯(cuò)怎么解決?
pytorch模型轉(zhuǎn)換需要注意的事項(xiàng)有哪些?
圖解BERT預(yù)訓(xùn)練模型!
如何使用BERT模型進(jìn)行抽取式摘要

PyTorch教程15.9之預(yù)訓(xùn)練BERT的數(shù)據(jù)集

PyTorch教程15.10之預(yù)訓(xùn)練BERT

PyTorch教程16.6之針對(duì)序列級(jí)和令牌級(jí)應(yīng)用程序微調(diào)BERT

PyTorch教程16.7之自然語(yǔ)言推理:微調(diào)BERT

評(píng)論