前言 這是一篇來(lái)自于 ACL 2022 的關(guān)于跨語(yǔ)言的 NER 蒸餾模型。主要的過(guò)程還是兩大塊:1)Teacher Model 的訓(xùn)練;2)從 Teacher Model 蒸餾到 Student Model。采用了類(lèi)似傳統(tǒng)的 Soft 蒸餾方式,其中利用了多任務(wù)的方式對(duì) Teacher Model 進(jìn)行訓(xùn)練,一個(gè)任務(wù)是 NER 訓(xùn)練的任務(wù),另一個(gè)是計(jì)算句對(duì)的相似性任務(wù)。整體思路還是采用了序列標(biāo)注的方法,也是一個(gè)不錯(cuò)的 IDEA。
論文標(biāo)題:
An Unsupervised Multiple-Task and Multiple-Teacher Model for Cross-lingual Named Entity Recognition
論文鏈接:
https://aclanthology.org/2022.acl-long.14.pdf
模型架構(gòu)
2.1 Teacher Model
以上就是 Teacher Model 的第一個(gè)任務(wù),直接對(duì)標(biāo)注序列進(jìn)行 NER,并且采用交叉熵?fù)p失函數(shù)作為 loss_function,計(jì)算如下:
2.2 Student Model Distilled
獲得兩個(gè)序列的hidden_state后進(jìn)行一個(gè)線(xiàn)性計(jì)算,然后利用softmax進(jìn)行歸一化,得到每個(gè)Token預(yù)測(cè)的標(biāo)簽,計(jì)算如下:
這里也類(lèi)似 Teacher Model 的計(jì)算方式,計(jì)算 target 序列間的Token相似度,計(jì)算如下所示:
當(dāng)然,這里做的是蒸餾模型,所以對(duì)于輸入到 Student Model 的序列對(duì),也是Teacher Model Inference 預(yù)測(cè)模型的輸入,通過(guò) Teacher Model 的預(yù)測(cè)計(jì)算得到一個(gè) teacher_ner_logits 和 teacher_similar_logits,將 teacher_ner_logits 分別與 和 通過(guò) CrossEntropyLoss 來(lái)計(jì)算 TS_ _Loss 和 TS_ _Loss,teacher_similar_logits 與 通過(guò) 計(jì)算 Similar_Loss,最終將幾個(gè) loss 進(jìn)行相加作為 DistilldeLoss。
這里作者還對(duì)每個(gè) TS_ _Loss,TS_ _Loss 分別賦予了權(quán)重 ,對(duì) Similar_Loss 賦予了權(quán)重 ,對(duì)最終的 DistilldeLoss 賦予權(quán)重 ,這樣的權(quán)重賦予能夠使得 Student Model 從 Teacher Model 學(xué)習(xí)到的噪聲減少。最終的 Loss 計(jì)算如下所示:
這里的權(quán)重 筆者認(rèn)為是用來(lái)控制 Student Model 學(xué)習(xí)傾向的參數(shù),首先對(duì)于 來(lái)說(shuō),由于 Student Model 輸入的是 Unlabeled 數(shù)據(jù),所以在進(jìn)行蒸餾學(xué)習(xí)時(shí),需要盡可能使得 Student Model 的輸出的 student_ner_logits 來(lái)對(duì)齊 Teacher Model 預(yù)測(cè)輸出的 teacher_ner_logits,由于不知道輸入的無(wú)標(biāo)簽數(shù)據(jù)的數(shù)據(jù)分布,所以設(shè)置一個(gè)權(quán)重參數(shù)來(lái)對(duì)整個(gè) Teacher Model 的預(yù)測(cè)標(biāo)簽進(jìn)行加權(quán),將各個(gè)無(wú)標(biāo)簽的輸入序列看作一個(gè)數(shù)據(jù)量較少的類(lèi)別。這里可以參考 在進(jìn)行數(shù)據(jù)標(biāo)簽不平衡時(shí)使用權(quán)重系數(shù)對(duì)各個(gè)標(biāo)簽進(jìn)行加權(quán)的操作。而且作者也分析了, 參數(shù)是一個(gè)隨著 Teacher Model 輸出而遞增的一個(gè)參數(shù)。如下圖所示:
實(shí)驗(yàn)結(jié)果
作者分別在 CoNLL 和 WiKiAnn 數(shù)據(jù)集上進(jìn)行了實(shí)驗(yàn),數(shù)據(jù)使用量如下圖所示:
簡(jiǎn)單代碼實(shí)現(xiàn)
#!/usr/bin/envpython
#-*-coding:utf-8-*-
#@Time:2022/5/3013:59
#@Author:SinGaln
"""
AnUnsupervisedMultiple-TaskandMultiple-TeacherModelforCross-lingualNamedEntityRecognition
"""
importtorch
importtorch.nnasnn
importtorch.nn.functionalasF
fromtransformersimportBertModel,BertPreTrainedModel,logging
logging.set_verbosity_error()
classTeacherNER(BertPreTrainedModel):
def__init__(self,config,num_labels):
"""
teacher模型是在標(biāo)簽數(shù)據(jù)上訓(xùn)練得到的,
主要分為三個(gè)encoder.
:paramconfig:
:paramnum_labels:
"""
super(TeacherNER,self).__init__(config)
self.config=config
self.num_labels=num_labels
self.mbert=BertModel(config=config)
self.fc=nn.Linear(config.hidden_size,num_labels)
defforward(self,batch_token_input_ids,batch_attention_mask,batch_token_type_ids,batch_labels,training=True,
batch_pair_input_ids=None,batch_pair_attention_mask=None,batch_pair_token_type_ids=None,
batch_t=None):
"""
:parambatch_token_input_ids:單句子token序列
:parambatch_attention_mask:單句子attention_mask
:parambatch_token_type_ids:單句子token_type_ids
:parambatch_pair_input_ids:句對(duì)token序列
:parambatch_pair_attention_mask:句對(duì)attention_mask
:parambatch_pair_token_type_ids:句對(duì)token_type_ids
"""
#RecognizerTeacher
single_output=self.mbert(input_ids=batch_token_input_ids,attention_mask=batch_attention_mask,
token_type_ids=batch_token_type_ids).last_hidden_state
single_output=F.softmax(self.fc(single_output),dim=-1)
#EvaluatorTeacher(類(lèi)似雙塔模型)
pair_output1=self.mbert(input_ids=batch_pair_input_ids[0],attention_mask=batch_pair_attention_mask[0],
token_type_ids=batch_pair_token_type_ids[0]).last_hidden_state
pair_output2=self.mbert(input_ids=batch_pair_input_ids[1],attention_mask=batch_pair_attention_mask[1],
token_type_ids=batch_pair_token_type_ids[1]).last_hidden_state
pair_output=torch.sigmoid(torch.cosine_similarity(pair_output1,pair_output2,dim=-1))#計(jì)算兩個(gè)輸出的cosine相似度
iftraining:
#計(jì)算loss,訓(xùn)練時(shí)采用平均loss作為模型最終的loss
loss1=F.cross_entropy(single_output.view(-1,self.num_labels),batch_labels.view(-1))
loss2=F.binary_cross_entropy(pair_output,batch_t.type(torch.float))
loss=loss1+loss2
returnsingle_output,loss
else:
returnsingle_output,pair_output
classStudentNER(BertPreTrainedModel):
def__init__(self,config,num_labels):
"""
student模型采用的也是一個(gè)雙塔結(jié)構(gòu)
:paramconfig:mBert的配置文件
:paramnum_labels:標(biāo)簽數(shù)量
"""
super(StudentNER,self).__init__(config)
self.config=config
self.num_labels=num_labels
self.mbert=BertModel(config=config)
self.fc1=nn.Linear(config.hidden_size,num_labels)
self.fc2=nn.Linear(config.hidden_size,num_labels)
defforward(self,batch_pair_input_ids,batch_pair_attention_mask,batch_pair_token_type_ids,batch_pair_labels,
teacher_logits,teacher_similar):
"""
:parambatch_pair_input_ids:句對(duì)token序列
:parambatch_pair_attention_mask:句對(duì)attention_mask
:parambatch_pair_token_type_ids:句對(duì)token_type_ids
"""
output1=self.mbert(input_ids=batch_pair_input_ids[0],attention_mask=batch_pair_attention_mask[0],
token_type_ids=batch_pair_token_type_ids[0]).last_hidden_state
output2=self.mbert(input_ids=batch_pair_input_ids[1],attention_mask=batch_pair_attention_mask[1],
token_type_ids=batch_pair_token_type_ids[1]).last_hidden_state
soft_output1,soft_output2=self.fc1(output1),self.fc2(output2)
soft_logits1,soft_logits2=F.softmax(soft_output1,dim=-1),F.softmax(soft_output2,dim=-1)
alpha1,alpha2=torch.square(torch.max(input=soft_logits1,dim=-1)[0]).mean(),torch.square(
torch.max(soft_logits2,dim=-1)[0]).mean()
output_similar=torch.sigmoid(torch.cosine_similarity(soft_output1,soft_output2,dim=-1))
soft_similar=torch.sigmoid(torch.cosine_similarity(soft_logits1,soft_logits2,dim=-1))
beta=torch.square(2*output_similar-1).mean()
gamma=1-torch.abs(soft_similar-output_similar).mean()
#計(jì)算蒸餾的loss
#teacherlogits與studentlogits1的loss
loss1=alpha1*(F.cross_entropy(soft_logits1,teacher_logits))
#teachersimilar與studentsimilar的loss
loss2=beta*(F.binary_cross_entropy(soft_similar,teacher_similar))
#teacherlogits與studentlogits2的loss
loss3=alpha2*(F.cross_entropy(soft_logits2,teacher_logits))
#finalloss
loss=gamma*(loss1+loss2+loss3).mean()
returnloss
if__name__=="__main__":
fromtransformersimportBertConfig
pretarin_path="./pytorch_mbert_model"
batch_pair1_input_ids=torch.randint(1,100,(2,128))
batch_pair1_attention_mask=torch.ones_like(batch_pair1_input_ids)
batch_pair1_token_type_ids=torch.zeros_like(batch_pair1_input_ids)
batch_labels1=torch.randint(1,10,(2,128))
batch_labels2=torch.randint(1,10,(2,128))
#t(對(duì)比兩個(gè)序列標(biāo)簽,相同為1,不同為0)
batch_t=torch.as_tensor(batch_labels1.numpy()==batch_labels2.numpy()).float()
batch_pair2_input_ids=torch.randint(1,100,(2,128))
batch_pair2_attention_mask=torch.ones_like(batch_pair2_input_ids)
batch_pair2_token_type_ids=torch.zeros_like(batch_pair2_input_ids)
batch_all_input_ids,batch_all_attention_mask,batch_all_token_type_ids,batch_all_labels=[],[],[],[]
batch_all_labels.append(batch_labels1)
batch_all_labels.append(batch_labels2)
batch_all_input_ids.append(batch_pair1_input_ids)
batch_all_input_ids.append(batch_pair2_input_ids)
batch_all_attention_mask.append(batch_pair1_attention_mask)
batch_all_attention_mask.append(batch_pair2_attention_mask)
batch_all_token_type_ids.append(batch_pair1_token_type_ids)
batch_all_token_type_ids.append(batch_pair2_token_type_ids)
config=BertConfig.from_pretrained(pretarin_path)
#teacher模型訓(xùn)練
teacher_model=TeacherNER.from_pretrained(pretarin_path,config=config,num_labels=10)
outputs,loss=teacher_model(batch_token_input_ids=batch_pair1_input_ids,
batch_attention_mask=batch_pair1_attention_mask,
batch_token_type_ids=batch_pair1_token_type_ids,batch_labels=batch_labels1,
batch_pair_input_ids=batch_all_input_ids,
batch_pair_attention_mask=batch_all_attention_mask,
batch_pair_token_type_ids=batch_all_token_type_ids,
training=True,batch_t=batch_t)
#student模型蒸餾
teacher_logits,teacher_similar=teacher_model(batch_token_input_ids=batch_pair1_input_ids,
batch_attention_mask=batch_pair1_attention_mask,
batch_token_type_ids=batch_pair1_token_type_ids,
batch_labels=batch_labels1,
batch_pair_input_ids=batch_all_input_ids,
batch_pair_attention_mask=batch_all_attention_mask,
batch_pair_token_type_ids=batch_all_token_type_ids,
training=False)
student_model=StudentNER.from_pretrained(pretarin_path,config=config,num_labels=10)
loss_all=student_model(batch_pair_input_ids=batch_all_input_ids,
batch_pair_attention_mask=batch_all_attention_mask,
batch_pair_token_type_ids=batch_all_token_type_ids,
batch_pair_labels=batch_all_labels,teacher_logits=teacher_logits,
teacher_similar=teacher_similar)
print(loss_all)
筆者自己實(shí)現(xiàn)的一部分代碼,可能不是原論文作者想表達(dá)的意思,讀者有疑問(wèn)的話(huà)可以一起討論一下^~^。
審核編輯 :李倩
-
編碼器
+關(guān)注
關(guān)注
45文章
3776瀏覽量
137201 -
模型
+關(guān)注
關(guān)注
1文章
3488瀏覽量
50021 -
標(biāo)簽
+關(guān)注
關(guān)注
0文章
145瀏覽量
18188
原文標(biāo)題:ACL2022 | 跨語(yǔ)言命名實(shí)體識(shí)別:無(wú)監(jiān)督多任務(wù)多教師蒸餾模型
文章出處:【微信號(hào):zenRRan,微信公眾號(hào):深度學(xué)習(xí)自然語(yǔ)言處理】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
自然語(yǔ)言基礎(chǔ)技術(shù)之命名實(shí)體識(shí)別相對(duì)全面的介紹

HanLP分詞命名實(shí)體提取詳解
基于結(jié)構(gòu)化感知機(jī)的詞性標(biāo)注與命名實(shí)體識(shí)別框架
HanLP-命名實(shí)體識(shí)別總結(jié)
基于神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)在命名實(shí)體識(shí)別中應(yīng)用的分析與總結(jié)

深度學(xué)習(xí):四種利用少量標(biāo)注數(shù)據(jù)進(jìn)行命名實(shí)體識(shí)別的方法

思必馳中文命名實(shí)體識(shí)別任務(wù)助力AI落地應(yīng)用
新型中文旅游文本命名實(shí)體識(shí)別設(shè)計(jì)方案

知識(shí)圖譜與訓(xùn)練模型相結(jié)合和命名實(shí)體識(shí)別的研究工作

命名實(shí)體識(shí)別的遷移學(xué)習(xí)相關(guān)研究分析

基于字語(yǔ)言模型的中文命名實(shí)體識(shí)別系統(tǒng)

評(píng)論