DeepSeekV3的attention模塊采用了MLA(Multi-head Latent Attention,多頭潛注意力)結(jié)構(gòu),通過對attention過程中的Key和Value進行低秩聯(lián)合壓縮,降低推理過程中需要的KV cache,提升推理效率。MLA對attention過程中的Query也進行了低秩壓縮,可以減少訓(xùn)練過程中激活的內(nèi)存。
大模型的推理分為兩階段,處理所有輸入prompt并產(chǎn)生首個token的過程稱為prefill,此后至產(chǎn)生所有token結(jié)束推理的過程稱為decode,本文的MLA算子融合及優(yōu)化特指decode過程。
MLA的計算過程比較復(fù)雜,包括下投影、上投影、attention和輸出投影,為了減少數(shù)據(jù)搬運和任務(wù)調(diào)度帶來的時間開銷,提升芯片效率,我們在SC11上,將上投影和attention過程融合成MLA大算子,如圖1所示。DeepSeekV3提供了兩種計算模式:na?ve和absorb,我們采用計算量更少的absorb方式實現(xiàn)MLA decode過程,步驟如下:

圖1-SC11 MLA decode融合算子示意圖
常用的attention并行部署方案有兩種,TP(Tensor Parallel,張量并行)和DP(Data parallel,數(shù)據(jù)并行)。TP將權(quán)重切分到多顆芯片,每顆芯片會重復(fù)加載KV cache。DP將數(shù)據(jù)按batch分配到多顆芯片,每顆芯片處理不同batch的數(shù)據(jù),但會重復(fù)加載權(quán)重。實際應(yīng)用過程中,可以根據(jù)權(quán)重和緩存的大小選擇并行部署方案,權(quán)重和緩存大小如表1所示。
表1 權(quán)重與緩存數(shù)據(jù)大小
#seqlen指所有batch數(shù)據(jù)序列長度總和。
在SC11部署DeepSeekV3模型時,由于應(yīng)用場景中的權(quán)重數(shù)據(jù)多于KV cache數(shù)據(jù),所以MLA階段采用TP方案進行部署,即將Query、Key和Value的上投影權(quán)重矩陣按head切分,部署到四張SC11。DeepSeekV3的參數(shù)中,上投影權(quán)重有128頭,因此每張板卡處理32頭。每顆芯片有多個核,上投影權(quán)重會繼續(xù)按head切分到多核。由于低秩的KV cache不包含head維度,無法對KV cache進行TP,為了充分利用多核優(yōu)勢,我們對MLA的實現(xiàn)方式進行了探索,優(yōu)化了不同batch數(shù)目和序列長度下的實現(xiàn)方案,如表2所示。
表2 MLA decode多核實現(xiàn)方案
除了算子融合與動態(tài)調(diào)用優(yōu)化后的實現(xiàn)方案,MLA的實現(xiàn)過程也采用了業(yè)界常用的Flash Attention和Page Attention等優(yōu)化方法,進一步減少數(shù)據(jù)搬運和內(nèi)存占用。在Page Attention過程中,我們采用兩塊buffer優(yōu)化KV cache搬運,使得數(shù)據(jù)搬運和MLA計算同步進行,優(yōu)化過程如圖2所示。圖中SDMA代表負責(zé)DDR和L2 SRAM之間或內(nèi)部的數(shù)據(jù)搬運模塊,GDMA代表負責(zé)任意內(nèi)存之間數(shù)據(jù)搬運的模塊,BDC代表負責(zé)數(shù)據(jù)計算的單元。
在時刻T0同時進行兩個操作:
SDMA將batch 0以page方式存儲的KV cache從DDR搬到L2 SRAM中的Buffer0,形成連續(xù)存儲的緩存數(shù)據(jù);
GDMA將上投影權(quán)重從DDR搬到芯片的片上內(nèi)存(local memory)。
在時刻T1同時進行三個操作:
SDMA將batch 1以page方式存儲的KV cache從DDR搬到L2 SRAM中的Buffer1,形成連續(xù)存儲的緩存數(shù)據(jù);
GDMA將Buffer0中連續(xù)存儲的batch 0的KV cache數(shù)據(jù)從L2 SRAM搬到localmemory;
BDC對batch 0進行MLA計算。
時刻T2和T3的操作可依此類推。測試數(shù)據(jù)表明,在128 batch 512序列的decode過程,使用雙buffer優(yōu)化page attention實現(xiàn)過程后,可以節(jié)省30%的推理時間。
圖2-雙buffer優(yōu)化Page Attention實現(xiàn)過程
經(jīng)過融合與優(yōu)化后的MLA,助力了DeepSeekV3全流程的性能,當(dāng)模型處理128 batch數(shù)據(jù),每batch輸入序列長度為128,輸出序列長度為1024時,DeepSeekV3全流程在4卡SC11上能達到532 token/s。
作者:周文婧,陳學(xué)儒,溫舉發(fā)
-
AI
+關(guān)注
關(guān)注
88文章
35153瀏覽量
279845 -
人工智能
+關(guān)注
關(guān)注
1806文章
49021瀏覽量
249487 -
大模型
+關(guān)注
關(guān)注
2文章
3144瀏覽量
4067
發(fā)布評論請先 登錄
進迭時空同構(gòu)融合RISC-V AI CPU的Triton算子編譯器實踐

鴻蒙應(yīng)用px,vp,fp概念詳解

摩爾線程GPU原生FP8計算助力AI訓(xùn)練

sc跳線是什么口
FP8在大模型訓(xùn)練中的應(yīng)用

EE-401:ADSP-SC5xx/215xx SHARC處理器系統(tǒng)優(yōu)化技術(shù)

FP7127/FP7128 降壓雙路調(diào)光調(diào)色方案 輸入48V,輸出36V,12W功率

光纖口是sc-sc什么樣
sc光纖是什么意思
基于 DSP5509 進行數(shù)字圖像處理中 Sobel 算子邊緣檢測的硬件連接電路圖
FP8模型訓(xùn)練中Debug優(yōu)化思路

TMP300 采用 SC70 封裝的 1.8V 電阻器可編程溫度開關(guān)和模擬輸出溫度傳感器數(shù)據(jù)表

評論