1. 簡介
Prompt通過將輸入文本填入預設prompt模板的方式,將下游NLP任務形式與語言模型預訓練任務統一起來,來更好地利用預訓練階段學習到的知識,使模型更容易適應于下游任務,在一系列NLP任務上取得了很好的效果[1]。Soft prompt方法使用可學習的參數來替代prompt模板中固定的token,盡管在少標注文本分類任務上性能優異[2],但是其表現隨模型初始化參數不同會出現很大的波動[1, 3]。人工選擇soft prompt模型參數需要對語言模型內部工作機理的深入理解和大量試錯,并且在遇到不同少標注任務時難以復用。
圖1 MetaPrompting幫助模型找到一個更優參數初始化點,以更快、更好地適應于新的少標注任務
為了解決上述問題,本文將目光從任務專用的soft prompt模型設計轉移到任務通用的模型參數初始化點搜索,以幫助模型快速適應到不同的少標注任務上。本文采用近年提出的基于優化的元學習方法,例如MAML[4]、Reptile[5]等,來搜索更優的soft prompt模型參數初始化點,以解決模型對初始化點過于敏感的問題。
本文在四個常用的少標注文本分類數據集上進行了充分的實驗,結果表明MetaPrompting相比其他基于元學習和prompt方法的強基線模型取得了更好的效果,達到了新的SOTA。
2. 方法
2.1 Soft prompt方法
Prompt方法通過將下游任務轉化成語言模型預訓練目標的形式,幫助模型更好地在下游任務上發揮性能。如圖2所示,對于一個新聞文本分類任務,可以通過將輸入文本填入prompt模板的方式,將該文本分類任務轉化為MLM任務形式。之后將模型在[MASK]位置填入各個詞語的概率映射到不同標簽上,即可完成文本分類任務的處理。
Soft prompt模型中,部分prompt tokens以可訓練embedding的形式給出,并可以和預訓練模型的參數一起進行優化,在保留離散token中語義信息的同時,給予模型更多的靈活性。
圖2 Soft prompt方法
2.2 將基于優化的元學習方法應用于soft prompt模型
少標注任務構建
本文使用元階段(episode)風格的少標注學習范式。具體而言,每一個少標注任務包含支持集和查詢集兩個部分,支持集中每個類別所對應標注樣本數量極少,本文通過將模型在支持集上進行適配,在查詢集上進行測試的方法,衡量模型的少標注學習性能。本文將不同標簽對應的樣本分別劃分成用于訓練、驗證和測試的少標注任務,以衡量模型從源領域學習通用元知識來處理目標領域少標注任務的能力。
基于元學習的soft prompt模型優化過程
MetaPrompting的整體優化過程如圖3所示。元訓練階段,模型在少標注任務的支持集上進行試探性參數更新,并在查詢集上進行梯度回傳。元測試階段,模型在未見過的少標注任務上進行適配和預測。令和分別表示預訓練模型和soft prompt的參數,在元訓練階段,模型在一個少標注任務支持集上進行適配的過程如下式所示:
其中α是適配過程的學習率,表示模型進行適配學習的步數。令模型在少標注任務上適配學習之后的參數為和,可將模型在該少標注任務上的優化目標描述為:
該優化目標模擬了模型在少標注場景下進行試探性參數更新,并根據試探性更新之后的情況優化模型參數的策略。這種策略更多關注了模型在一步或多步更新之后的情況,因而可以幫助模型找到一個能快速適應于新的少標注任務的參數初始化點。
圖3 MetaPrompting模型參數更新過程
實驗中,本文還使用了MAML++[6]中的多步梯度回傳技巧,來使得優化過程更加穩定,達到更好的效果。
3. 實驗
本文分別采用5way 1shot和5way 5shot的少標注學習設定來測試模型性能。實驗選擇了HuffPost、Amazon、Reuters和20newsgroup四個廣泛使用的文本分類數據集,結果以分類準確率%給出。
實驗結果如表1所示,表中20newsgroup數據集性能由于數據構造問題與原文略有出入,現為勘誤后結果,勘誤不影響實驗結論。由實驗結果可見,MetaPrompting性能優于當前的SOTA模型ContrastNet[7]和其他基于元學習和提示學習的方法,取得了明顯的性能提升。相比于不使用元學習優化目標的Ours (Pretrain Init),引入元學習搜索模型參數初始化點的Ours (Meta Init)也得到了更好的性能,說明了元學習方法在soft prompt模型參數優化中的有效性。
表1 MetaPrompting主實驗結果
主實驗中,為了與其他基線模型進行公平的對比,將soft prompt參數和預訓練模型參數一起進行了優化。為了更好地說明MetaPrompting針對soft prompt參數初始化的作用,本文還參數進行了固定預訓練模型的實驗。實驗結果如表2所示,相比于參數隨機初始化的soft prompt模型,MetaPrompting取得了明顯的性能提升。
表2 MetaPrompting在固定預訓練模型參數時的性能
現實應用場景中,往往難以得到內容、形式十分相近的源領域數據。因此本文還對MetaPrompting在分布外數據上的性能進行了測試。實驗結果如表3所示,即使源領域的數據內容、形式上有較大的差異,MetaPrompting仍然可以學習到任務通用的元知識,來輔助在目標領域少標注任務上的學習。
表3 MetaPrompting在不同內容、形式的源領域數據上進行元學習的性能
本文還對MetaPrompting對于不同prompt模板的魯棒性進行了測試。如表4所示,相比于隨機初始化的soft prompt模型,MetaPrompting尋找到的參數初始化點在不同prompt模板下性能方差更小,魯棒性更強。
表4 MetaPrompting在不同prompt模板下性能的方差
4. 總結
本文提出了MetaPrompting,將基于優化的元學習方法推廣到soft prompt模型中,來處理少標注文本任務。MetaPrompting利用源領域數據進行元學習,搜索能夠更快、更好地適應于新的少標注人物的模型參數初始化點。在4個少標注文本分類數據集上的實驗結果表明,MetaPrompting相比于樸素的soft prompt模型以及其他基于元學習的基線模型取得了更好的效果,達到了新的SOTA性能。
審核編輯 :李倩
-
模型
+關注
關注
1文章
3341瀏覽量
49270 -
數據集
+關注
關注
4文章
1209瀏覽量
24848 -
nlp
+關注
關注
1文章
489瀏覽量
22116
原文標題:參考文獻
文章出處:【微信號:zenRRan,微信公眾號:深度學習自然語言處理】歡迎添加關注!文章轉載請注明出處。
發布評論請先 登錄
相關推薦
帶通濾波器的設計步驟與優化方法
如何快速學習硬件電路
![如何快速<b class='flag-5'>學習</b>硬件電路](https://file1.elecfans.com/web3/M00/06/96/wKgZPGeNv1iAat6ZAAAZVw_Jvtk110.jpg)
焊接技術流程優化方法
傳統機器學習方法和應用指導
![傳統機器<b class='flag-5'>學習方法</b>和應用指導](https://file1.elecfans.com/web3/M00/04/33/wKgZPGdx9NKAcZdAAABMVybzcFI029.png)
什么是機器學習?通過機器學習方法能解決哪些問題?
![什么是機器<b class='flag-5'>學習</b>?通過機器<b class='flag-5'>學習方法</b>能解決哪些問題?](https://file.elecfans.com/web2/M00/4E/DC/poYBAGLCjeiALm_WAAAYmfR7Qec474.png)
評論