旋轉(zhuǎn)位置編碼(Rotary Position Embedding,RoPE)是論文 Roformer: Enhanced Transformer With Rotray Position Embedding 提出的一種能夠?qū)⑾鄬ξ恢?a target="_blank">信息依賴集成到 self-attention 中并提升 transformer 架構(gòu)性能的位置編碼方式。而目前很火的 LLaMA、GLM 模型也是采用該位置編碼方式。
和相對位置編碼相比,RoPE 具有更好的外推性,目前是大模型相對位置編碼中應(yīng)用最廣的方式之一。
備注:什么是大模型外推性?
外推性是指大模型在訓練時和預(yù)測時的輸入長度不一致,導致模型的泛化能力下降的問題。例如,如果一個模型在訓練時只使用了 512 個 token 的文本,那么在預(yù)測時如果輸入超過 512 個 token,模型可能無法正確處理。這就限制了大模型在處理長文本或多輪對話等任務(wù)時的效果。
旋轉(zhuǎn)編碼RoPE
1.1 基本概念
在介紹 RoPE 之前,先給出一些符號定義,以及基本背景。
首先定義一個長度為 的輸入序列為:
1.2 絕對位置編碼
對于位置編碼,常規(guī)的做法是在計算 query,key 和 value 向量之前,會計算一個位置編碼向量 加到詞嵌入 上,位置編碼向量 同樣也是 維向量,然后再乘以對應(yīng)的變換矩陣 :
![46588f4e-4bca-11ee-a25d-92fbcf53809c.png](https://file1.elecfans.com/web2/M00/A1/B1/wKgZomT28QWAThK8AAA11idTdIU511.png)
而經(jīng)典的位置編碼向量 的計算方式是使用 Sinusoidal 函數(shù):
![46604568-4bca-11ee-a25d-92fbcf53809c.png](https://file1.elecfans.com/web2/M00/A1/B1/wKgZomT28QWAGeJsAABR7o8iyC0201.png)
其中 表示位置 維度向量 中的第 位置分量也就是偶數(shù)索引位置的計算公式,而 就對應(yīng)第 位置分量也就是奇數(shù)索引位置的計算公式。
1.3 2維旋轉(zhuǎn)位置編碼
論文中提出為了能利用上 token 之間的相對位置信息,假定 query 向量 和 key 向量 之間的內(nèi)積操作可以被一個函數(shù) 表示,該函數(shù) 的輸入是詞嵌入向量 , 和它們之間的相對位置 :
![46651980-4bca-11ee-a25d-92fbcf53809c.png](https://file1.elecfans.com/web2/M00/A1/B1/wKgZomT28QWAX-eEAAA3_Fx7-O8140.png)
![467a59c6-4bca-11ee-a25d-92fbcf53809c.png](https://file1.elecfans.com/web2/M00/A1/B1/wKgZomT28QWAdLweAAB1QDC65Yc487.png)
![468cdf4c-4bca-11ee-a25d-92fbcf53809c.png](https://file1.elecfans.com/web2/M00/A1/B1/wKgZomT28QaALX_FAACQ5oqF5yY950.png)
![469c2614-4bca-11ee-a25d-92fbcf53809c.png](https://file1.elecfans.com/web2/M00/A1/B1/wKgZomT28QaAQt6WAACGRDN4uaY788.png)
將2維推廣到任意維度,可以表示如下:
![46cf41ac-4bca-11ee-a25d-92fbcf53809c.png](https://file1.elecfans.com/web2/M00/A1/B1/wKgZomT28QaAZVPGAAAyvGXEcC8314.png)
![46e42edc-4bca-11ee-a25d-92fbcf53809c.png](https://file1.elecfans.com/web2/M00/A1/B1/wKgZomT28QaAONp0AACCi-wJ1vU537.png)
![46fdb5be-4bca-11ee-a25d-92fbcf53809c.png](https://file1.elecfans.com/web2/M00/A1/B1/wKgZomT28QaAT5UlAABAlikQicQ045.png)
其中,。
值得指出的是,由于 是一個正交矩陣,它不會改變向量的模長,因此通常來說它不會改變原模型的穩(wěn)定性。 1.5 RoPE 的高效計算由于 的稀疏性,所以直接用矩陣乘法來實現(xiàn)會很浪費算力,推薦通過下述方式來實現(xiàn) RoPE:
1.6 遠程衰減
可以看到,RoPE 形式上和前面公式(6)Sinusoidal 位置編碼有點相似,只不過 Sinusoidal 位置編碼是加性的,而 RoPE 可以視為乘性的。在 的選擇上,RoPE 同樣沿用了 Sinusoidal 位置編碼的方案,即 ,它可以帶來一定的遠程衰減性。
具體證明如下:將 兩兩分組后,它們加上 RoPE 后的內(nèi)積可以用復數(shù)乘法表示為:
![476650ce-4bca-11ee-a25d-92fbcf53809c.png](https://file1.elecfans.com/web2/M00/A1/B1/wKgZomT28QeACL_4AAA2jGY1FBI521.png)
并約定 ,那么由 Abel 變換(分部求和法)可以得到:
RoPE實驗
我們看一下 RoPE 在預(yù)訓練階段的實驗效果:
![47bcf0d2-4bca-11ee-a25d-92fbcf53809c.png](https://file1.elecfans.com/web2/M00/A1/B1/wKgZomT28QiAYxFOAABRMKSzyAg941.png)
RoPE代碼實現(xiàn)
Meta 的 LLAMA 和 清華的 ChatGLM 都使用了 RoPE 編碼,下面看一下具體實現(xiàn)。
3.1 在LLAMA中的實現(xiàn)
#生成旋轉(zhuǎn)矩陣
defprecompute_freqs_cis(dim:int,seq_len:int,theta:float=10000.0):
#計算詞向量元素兩兩分組之后,每組元素對應(yīng)的旋轉(zhuǎn)角度 heta_i
freqs=1.0/(theta**(torch.arange(0,dim,2)[:(dim//2)].float()/dim))
#生成token序列索引t=[0,1,...,seq_len-1]
t=torch.arange(seq_len,device=freqs.device)
#freqs.shape=[seq_len,dim//2]
freqs=torch.outer(t,freqs).float()#計算m* heta
#計算結(jié)果是個復數(shù)向量
#假設(shè)freqs=[x,y]
#則freqs_cis=[cos(x)+sin(x)i,cos(y)+sin(y)i]
freqs_cis=torch.polar(torch.ones_like(freqs),freqs)
returnfreqs_cis
#旋轉(zhuǎn)位置編碼計算
defapply_rotary_emb(
xq:torch.Tensor,
xk:torch.Tensor,
freqs_cis:torch.Tensor,
)->Tuple[torch.Tensor,torch.Tensor]:
#xq.shape=[batch_size,seq_len,dim]
#xq_.shape=[batch_size,seq_len,dim//2,2]
xq_=xq.float().reshape(*xq.shape[:-1],-1,2)
xk_=xk.float().reshape(*xk.shape[:-1],-1,2)
#轉(zhuǎn)為復數(shù)域
xq_=torch.view_as_complex(xq_)
xk_=torch.view_as_complex(xk_)
#應(yīng)用旋轉(zhuǎn)操作,然后將結(jié)果轉(zhuǎn)回實數(shù)域
#xq_out.shape=[batch_size,seq_len,dim]
xq_out=torch.view_as_real(xq_*freqs_cis).flatten(2)
xk_out=torch.view_as_real(xk_*freqs_cis).flatten(2)
returnxq_out.type_as(xq),xk_out.type_as(xk)
classAttention(nn.Module):
def__init__(self,args:ModelArgs):
super().__init__()
self.wq=Linear(...)
self.wk=Linear(...)
self.wv=Linear(...)
self.freqs_cis=precompute_freqs_cis(dim,max_seq_len*2)
defforward(self,x:torch.Tensor):
bsz,seqlen,_=x.shape
xq,xk,xv=self.wq(x),self.wk(x),self.wv(x)
xq=xq.view(batch_size,seq_len,dim)
xk=xk.view(batch_size,seq_len,dim)
xv=xv.view(batch_size,seq_len,dim)
#attention操作之前,應(yīng)用旋轉(zhuǎn)位置編碼
xq,xk=apply_rotary_emb(xq,xk,freqs_cis=freqs_cis)
#scores.shape=(bs,seqlen,seqlen)
scores=torch.matmul(xq,xk.transpose(1,2))/math.sqrt(dim)
scores=F.softmax(scores.float(),dim=-1)
output=torch.matmul(scores,xv)#(batch_size,seq_len,dim)
#......
這里舉一個例子,假設(shè) batch_size=10, seq_len=3, d=8,則調(diào)用函數(shù) precompute_freqs_cis(d, seq_len) 后,生成結(jié)果為:
In[239]:freqs_cis
Out[239]:
tensor([[1.0000+0.0000j,1.0000+0.0000j,1.0000+0.0000j,1.0000+0.0000j],
[0.5403+0.8415j,0.9950+0.0998j,0.9999+0.0100j,1.0000+0.0010j],
[-0.4161+0.9093j,0.9801+0.1987j,0.9998+0.0200j,1.0000+0.0020j]])
以結(jié)果中的第二行為例(對應(yīng)的 m = 1),也就是:
![47cc4bea-4bca-11ee-a25d-92fbcf53809c.png](https://file1.elecfans.com/web2/M00/A1/B1/wKgZomT28QiADQF4AACkQm0xi7c801.png)
In[351]:q_=q.float().reshape(*q.shape[:-1],-1,2)
In[352]:q_[0]
Out[352]:
tensor([[[1.0247,0.4782],
[1.5593,0.2119],
[0.4175,0.5309],
[0.4858,0.1850]],
[[-1.7456,0.6849],
[0.3844,1.1492],
[0.1700,0.2106],
[0.5433,0.2261]],
[[-1.1206,0.6969],
[0.8371,-0.7765],
[-0.3076,0.1704],
[-0.5999,-1.7029]]])
In[353]:xq=torch.view_as_complex(q_)
In[354]:xq[0]
Out[354]:
tensor([[1.0247+0.4782j,1.5593+0.2119j,0.4175+0.5309j,0.4858+0.1850j],
[-1.7456+0.6849j,0.3844+1.1492j,0.1700+0.2106j,0.5433+0.2261j],
[-1.1206+0.6969j,0.8371-0.7765j,-0.3076+0.1704j,-0.5999-1.7029j]])
這里為什么可以這樣計算?
主要是利用了復數(shù)的乘法性質(zhì)。
我們首先來復習一下復數(shù)乘法的性質(zhì):
classRotaryEmbedding(torch.nn.Module):
def__init__(self,dim,base=10000,precision=torch.half,learnable=False):
super().__init__()
#計算 heta_i
inv_freq=1./(base**(torch.arange(0,dim,2).float()/dim))
inv_freq=inv_freq.half()
self.learnable=learnable
iflearnable:
self.inv_freq=torch.nn.Parameter(inv_freq)
self.max_seq_len_cached=None
else:
self.register_buffer('inv_freq',inv_freq)
self.max_seq_len_cached=None
self.cos_cached=None
self.sin_cached=None
self.precision=precision
defforward(self,x,seq_dim=1,seq_len=None):
ifseq_lenisNone:
seq_len=x.shape[seq_dim]
ifself.max_seq_len_cachedisNoneor(seq_len>self.max_seq_len_cached):
self.max_seq_len_cached=Noneifself.learnableelseseq_len
#生成token序列索引t=[0,1,...,seq_len-1]
t=torch.arange(seq_len,device=x.device,dtype=self.inv_freq.dtype)
#對應(yīng)m* heta
freqs=torch.einsum('i,j->ij',t,self.inv_freq)
#將m* heta拼接兩次,對應(yīng)復數(shù)的實部和虛部
emb=torch.cat((freqs,freqs),dim=-1).to(x.device)
ifself.precision==torch.bfloat16:
emb=emb.float()
#[sx,1(b*np),hn]
cos_cached=emb.cos()[:,None,:]#計算得到cos(m* heta)
sin_cached=emb.sin()[:,None,:]#計算得到cos(m* heta)
ifself.precision==torch.bfloat16:
cos_cached=cos_cached.bfloat16()
sin_cached=sin_cached.bfloat16()
ifself.learnable:
returncos_cached,sin_cached
self.cos_cached,self.sin_cached=cos_cached,sin_cached
returnself.cos_cached[:seq_len,...],self.sin_cached[:seq_len,...]
def_apply(self,fn):
ifself.cos_cachedisnotNone:
self.cos_cached=fn(self.cos_cached)
ifself.sin_cachedisnotNone:
self.sin_cached=fn(self.sin_cached)
returnsuper()._apply(fn)
defrotate_half(x):
x1,x2=x[...,:x.shape[-1]//2],x[...,x.shape[-1]//2:]
returntorch.cat((-x2,x1),dim=x1.ndim-1)
RoPE的外推性
我們都知道 RoPE 具有很好的外推性,前面的實驗結(jié)果也證明了這一點。這里解釋下具體原因。 RoPE 可以通過旋轉(zhuǎn)矩陣來實現(xiàn)位置編碼的外推,即可以通過旋轉(zhuǎn)矩陣來生成超過預(yù)期訓練長度的位置編碼。這樣可以提高模型的泛化能力和魯棒性。 我們回顧一下 RoPE 的工作原理:假設(shè)我們有一個 維的絕對位置編碼 ,其中 是位置索引。我們可以將 看成一個 維空間中的一個點。我們可以定義一個 維空間中的一個旋轉(zhuǎn)矩陣 ,它可以將任意一個點沿著某個軸旋轉(zhuǎn)一定的角度。我們可以用 來變換 ,得到一個新的點 。我們可以發(fā)現(xiàn), 和 的距離是相等的,即 。這意味著 和 的相對關(guān)系沒有改變。但是, 和 的距離可能發(fā)生改變,即 。這意味著 和 的相對關(guān)系有所改變。因此,我們可以用 來調(diào)整不同位置之間的相對關(guān)系。 如果我們想要生成超過預(yù)訓練長度的位置編碼,我們只需要用 來重復變換最后一個預(yù)訓練位置編碼 ,得到新的位置編碼
![480fb7ae-4bca-11ee-a25d-92fbcf53809c.png](https://file1.elecfans.com/web2/M00/A1/B1/wKgZomT28QiAanqSAAAx7lICIkg146.png)
總結(jié)
最近一直聽到旋轉(zhuǎn)編碼這個詞,但是一直沒有仔細看具體原理。今天花時間仔細看了一遍,確實理論寫的比較完備,而且實驗效果也不錯。目前很多的大模型,都選擇了使用了這種編碼方式(LLAMA、GLM 等)。
附錄
這里補充一下前面公式 1.3.2 節(jié)中,公式(8)~(11)是怎么推導出來的。 回到之前的公式(8),編碼之后的 以及內(nèi)積 的形式如下:
![487135f6-4bca-11ee-a25d-92fbcf53809c.png](https://file1.elecfans.com/web2/M00/A1/B1/wKgZomT28QmAUOdEAAB4t4Gglac805.png)
![48799c82-4bca-11ee-a25d-92fbcf53809c.png](https://file1.elecfans.com/web2/M00/A1/B1/wKgZomT28QmAbxlpAAArQIqpGeM824.png)
![489594dc-4bca-11ee-a25d-92fbcf53809c.png](https://file1.elecfans.com/web2/M00/A1/B1/wKgZomT28QmAdYZfAAAu15teVng060.png)
![49122d12-4bca-11ee-a25d-92fbcf53809c.png](https://file1.elecfans.com/web2/M00/A1/B1/wKgZomT28QqAAhE7AAA_4gGGWNc219.png)
![49a28f74-4bca-11ee-a25d-92fbcf53809c.png](https://file1.elecfans.com/web2/M00/A1/B1/wKgZomT28QuAWdUDAAEbO_qZUiQ862.png)
-
向量
+關(guān)注
關(guān)注
0文章
55瀏覽量
11711 -
旋轉(zhuǎn)編碼
+關(guān)注
關(guān)注
0文章
6瀏覽量
10531 -
大模型
+關(guān)注
關(guān)注
2文章
2603瀏覽量
3215
原文標題:十分鐘讀懂旋轉(zhuǎn)編碼(RoPE)
文章出處:【微信號:zenRRan,微信公眾號:深度學習自然語言處理】歡迎添加關(guān)注!文章轉(zhuǎn)載請注明出處。
發(fā)布評論請先 登錄
相關(guān)推薦
快充技術(shù)&芯片詳解 十分鐘讓你的手機滿血復活
ModelSim SE 十分鐘入門
全球首發(fā)十分鐘快速充滿電移動電源
采集系統(tǒng)需要隔十分鐘采集10S數(shù)據(jù),怎么實現(xiàn)?
基于STM32F103RB的數(shù)碼管如何去實現(xiàn)十分鐘計時呢
遇到SE5經(jīng)常自動重啟,大約十幾分鐘到二十分鐘左右重啟一次的問題如何解決?
十分鐘學會Xilinx FPGA 設(shè)計
三星改革智能手機充電技術(shù),充滿只需十分鐘
英國搭建太陽能汽車充電網(wǎng)試點項目,電動汽車在三十分鐘內(nèi)完成充電
十分鐘分析穩(wěn)壓三極管工作原理資料下載
![<b class='flag-5'>十分鐘</b>分析穩(wěn)壓三極管工作原理資料下載](https://file.elecfans.com/web1/M00/D9/4E/pIYBAF_1ac2Ac0EEAABDkS1IP1s689.png)
評論