vLLM 中,LLM 推理的 prefill 階段 attention 計(jì)算使用第三方庫(kù) xformers 的優(yōu)化實(shí)現(xiàn),decoding 階段 attention 計(jì)算則使用項(xiàng)目編譯 CUDA 代碼實(shí)現(xiàn)。具體代碼在 vllm 的 csrc/attention/attention_kernels.cu 文件里,開發(fā)者洋洋灑灑寫了八百多行 CUDA 代碼。
Attention 計(jì)算時(shí)使用頁(yè)式(paged)管理 KVCache 用于增加服務(wù)吞吐率,但對(duì)延遲有負(fù)面影響,因此高效的 PA 實(shí)現(xiàn)方法,利用頁(yè)式內(nèi)存管理同時(shí)盡量降低其負(fù)面影響,對(duì)框架的綜合性能表現(xiàn)至關(guān)重要。
本文章將描述 PA CUDA Kernel 的實(shí)現(xiàn)細(xì)節(jié),這些細(xì)節(jié)是公開的論文和博客所不涉及的,但卻對(duì)框架的速度至關(guān)重要。另外,PA 實(shí)現(xiàn)改編自 FasterTransformers 某個(gè)版本的 MHA 實(shí)現(xiàn),NV 原始版本對(duì) GPU 特性的運(yùn)用也是相當(dāng)老道的,值得大家借鑒。
vLLM 中有兩個(gè)版本 PA,使用一個(gè)簡(jiǎn)單的啟發(fā)式方法來(lái)決定是使用 V1 還是 V2 版本。V1 是本文介紹的版本,改編自 FasterTransformers 的 MHA 實(shí)現(xiàn)。V2 是參考 FlashDecoding 方式進(jìn)行實(shí)現(xiàn),對(duì) sequence 維度進(jìn)行切分以增加并行粒度,關(guān)于 FlashDecoding 可以參考本人知乎文章。V1 適合長(zhǎng)度小于 8192 或者 num_seqs * num_heads>512 的情況。
參數(shù)定義和數(shù)據(jù)結(jié)構(gòu)
num_seq:本次推理請(qǐng)求 sequence 數(shù)目。
num_head:Query 的 head 數(shù)目。
num_kv_heads:Key、Value 的 head 數(shù)目,對(duì)于 MHA 和 num_head 相同,如果是 GQA、MQA 則 num_kv_heads 小于 num_head。
head_size hidden dimension,特征的維度。
PA 使用 tensor 的維度信息:
out [num_seqs, num_heads, head_size]
Q [num_seqs, num_heads, head_size]
KCache [num_blocks, num_kv_heads, head_size/x, block_size, x]:x 表示一個(gè)向量化的大小,如 float16 -> 16 / sizeof(float16) = 8。
VCache [num_blocks, num_kv_heads, head_size, block_size]
Paged 內(nèi)存管理相關(guān)的輔助數(shù)據(jù)結(jié)構(gòu):
blk_size:也就是 block_size,是 KVCache page 的最高維,KVCache 是若干個(gè) page 的集合,每個(gè) page 存(blk_size, num_head,head_size)個(gè) K、V 的元素。
head_mapping [num_heads] 用于 MQA, GQA,確定用的 KV_head
block_tables [num_seqs, max_num_blocks_per_seq] block_tables 映射表,表示每個(gè) sequence 映射到哪幾個(gè) block 上
context_lens [num_seqs] 用于變長(zhǎng)
課前問題
如果你能回答以下兩個(gè)問題,那么說(shuō)明你已經(jīng)非常熟練地掌握了 PA 實(shí)現(xiàn),并可以用批判性的眼光審閱本文,找出其中可能存在的錯(cuò)誤。如果你暫時(shí)無(wú)法回答這些問題,請(qǐng)不要擔(dān)憂,閱讀完本文后會(huì)給你答案。
Q1:為什么 K Cache 的 layout 和 V Cache layout 不一樣?
Q2:PA 實(shí)現(xiàn)和 FlashAttention 有什么區(qū)別?
PagedAttention算子計(jì)算流程
首先,按照 CUDA 編程模型對(duì)任務(wù)進(jìn)行并行劃分,grid 大小(num_heads, num_seqs),grid 中每個(gè) CUDA thread block 大小(NUM_THREADS),NUM_THREADS 是常量默認(rèn)為 128,也就說(shuō)每個(gè) thread block 包含 128 個(gè)線程,負(fù)責(zé)完成 output 矩陣一行(包含 head_size 個(gè)元素)結(jié)果的 attention 計(jì)算任務(wù)。thread block 中的線程進(jìn)一步劃分若干個(gè)WARP。
眾所周知,WARP 是 GPU 一個(gè)基本的執(zhí)行單元,由 32 個(gè)線程組成,這些線程以 SMIT 方式在硬件上同時(shí)執(zhí)行相同的指令,在不同的數(shù)據(jù)上進(jìn)行操作。在 PA 中比較特殊的是,warp 內(nèi) 32 個(gè)線程進(jìn)一步劃分為 blk_size 個(gè) thread group,這和 paged KVCache 設(shè)計(jì) x 息息相關(guān)的,馬上會(huì)細(xì)講。
Attention 計(jì)算 softmax(QK^T)V,一圖勝前言,后面流程介紹將圍繞下面這幅圖展開。其中 thread block, warp, thread group, thread 別用不同顏色表示。
▲ 圖1:PagedAttention CUDA計(jì)算流程
在上圖的左側(cè)部分,我們看到了 Q 矩陣,這部分描述了從顯存讀取 Q 數(shù)據(jù)到共享內(nèi)存的過(guò)程。在這個(gè)過(guò)程中,一個(gè) CUDA 線程塊會(huì)讀取圖中 Q 矩陣的一行(包含 head_size個(gè)元素)并將其存入共享內(nèi)存。
這個(gè)過(guò)程是通過(guò)一個(gè)循環(huán)來(lái)實(shí)現(xiàn)的,在每次迭代中,每個(gè) thread group 會(huì)讀取 16 字節(jié)的 Q 數(shù)據(jù)(例如,如果使用 float16,那么就是 8 個(gè)元素)。每個(gè) warp 會(huì)讀取 16*blk_size 字節(jié)的 Q 數(shù)據(jù),這些數(shù)據(jù)對(duì)應(yīng)于一個(gè) sequence 的一個(gè) head,由 CUDA grid 索引指定。當(dāng)循環(huán)訪問結(jié)束后,共享內(nèi)存存儲(chǔ) Q 行的一部分。如下圖所示,綠色部分表示存儲(chǔ)在一個(gè)線程讀入共享內(nèi)存中的數(shù)據(jù)。
圖 1 中上面部分 K 矩陣部分描述了從顯存讀取 K Cache 到寄存器的過(guò)程。每個(gè)序列的 K Cache 包含 cxt_length * num_kv_heads * head_size 個(gè)元素,但由于采用了頁(yè)式內(nèi)存管理,這些元素在內(nèi)存中的存儲(chǔ)并不連續(xù)。每個(gè) thread block 只負(fù)責(zé)計(jì)算一個(gè) sequence 一個(gè) head 的 QK^T,因此只需要 ctx_length * head_size 個(gè) K Cache 元素。
然而,由于 ctx_length 維度的存儲(chǔ)是不連續(xù)的,并且以 blk_size 個(gè) token 為粒度分布在不同的內(nèi)存地址,我們需要根據(jù)query的head_idx和 seq_idx 訪問 block_table 以找到 K Cache的physical_block_num。為了方便后續(xù)的描述,我們可以將 K Cache 視為(:, head_size)的形狀,其中 head_size 個(gè)元素組成一行。
K Cache 的布局為 [num_blocks, num_kv_heads, head_size/x, block_size, x],這是為了優(yōu)化寫入 shared memory 的操作。在 Q 和 K 矩陣的同一行元素被讀入寄存器并進(jìn)行點(diǎn)乘運(yùn)算后,結(jié)果需要被存入 shared memory。
如果一個(gè) warp 中所有線程都計(jì)算 Q、K 同一行數(shù)據(jù),會(huì)導(dǎo)致寫入 shared memory 的同一個(gè)位置,這將造成 warp 內(nèi)不同線程順序地寫入。因此,為了優(yōu)化,warp的線程最好計(jì)算 Q 和 K 的不同行數(shù)據(jù)。因此,在設(shè)計(jì) K Cache 布局時(shí),我們將 block_size 放在比 head_size 更低的維度。
由于 warp size 大于 block_size,我們需要將 head_size 拆分為 head_size/x 和 x 兩個(gè)維度,借 x 到最低維度,以確保每個(gè)線程讀入的數(shù)據(jù)量和計(jì)算量都足夠大。最后,每個(gè)線程組派一個(gè)線程去寫入 shared memory,這樣一個(gè) warp 有 blk_size 個(gè)線程并行寫入 shared memory,從而增加了 shared memory 的訪問帶寬。這種設(shè)計(jì)策略是為了實(shí)現(xiàn)高效的并行計(jì)算和內(nèi)存訪問,以提高整體的計(jì)算性能。
在代碼實(shí)現(xiàn)中,訪問 K 矩陣需要一個(gè)循環(huán),該循環(huán)使得 CUDA 線程塊中的所有 warp 依次訪問 num_block 個(gè)頁(yè)面。在每次循環(huán)迭代中,每個(gè) warp 負(fù)責(zé)訪問連續(xù)的 blk_size個(gè)K Cache 行,這涉及到的數(shù)據(jù)量為 blk_size * head_size 個(gè)元素。同時(shí),每個(gè) thread group 負(fù)責(zé)訪問 K Cache 的一行,將 head_size 個(gè)元素加載到自己的寄存器中。
接著,寄存器中的 Q 和 K 數(shù)據(jù)元素立即進(jìn)行點(diǎn)乘運(yùn)算,運(yùn)算結(jié)果被寫入 shared memory 中。因此,線程塊的 shared memory 存儲(chǔ)了一行 QK^T 的結(jié)果,包含 ctx_length 個(gè)元素。這種實(shí)現(xiàn)方式充分利用了 CUDA 的并行計(jì)算能力,以提高數(shù)據(jù)處理的效率。
然后,thread block 對(duì) shared memory 中元素進(jìn)行 max,sum 方式 reduction,然后計(jì)算得到 softmax 結(jié)果。
圖 1 右邊 V 矩陣部分描述從顯存讀 V Cache 到寄存器過(guò)程。和 K Cache 一樣,CUDA thread block 依次訪問 num_blk 個(gè)物理塊到寄存器,每個(gè) warp 負(fù)責(zé) blk_size 個(gè) token 的 page 內(nèi)存,page 的真實(shí)物理地址同樣需要進(jìn)行索引。
不過(guò)這里不需要以 thread group 為單位訪問 16 字節(jié),而是每個(gè) thread 訪問 16 字節(jié)的元素。訪問完就可以與 shared memory 的 softmax(QK^T) 中間結(jié)果對(duì)應(yīng)位置 16 字節(jié)的數(shù)據(jù)進(jìn)行點(diǎn)乘,得到一個(gè) float 結(jié)果,寫到 output 對(duì)應(yīng)位置中。
為什么V Cache的layout是 [num_blocks, num_kv_heads, head_size, block_size],和 K Cache layout 不一樣?這是因?yàn)?V 要去做點(diǎn)乘的對(duì)象在shared memory,只需要讀,不涉及并行寫的問題。
和 FlashAttention(FA)有什么不同?結(jié)合我的圖和中間 FAv2 的流程圖對(duì)比就一目了然了。FA 用了兩層循環(huán),每次寫一個(gè) Tile 的 output tensor,而 PA 一直只有一層循環(huán),每次寫一行 output tensor。因?yàn)槊看味加姓械?QK^T 中間結(jié)果,不需要 online softmax 這種花哨技巧。
PAv1的問題
以我粗淺的理解指出幾點(diǎn) vLLM PAv1 的問題。一、和 MHA 相比,MQA 和 GAQ 沒有減少對(duì) KV Cache 的讀寫次數(shù)。讀 K、V Cache 時(shí)候只是做了一個(gè) head_idx 的轉(zhuǎn)換,會(huì)重復(fù)從顯存讀相同的 head。二、對(duì)于 seq length 很長(zhǎng)情況沒法適應(yīng),因?yàn)闆]有沿著 ctx_length 或者 batch 維度做切分。這點(diǎn) FlashAttention 和 FlashDecoding 就做了,因此 PAv2 借鑒了 FA 的切分思想。
總結(jié)
vLLM 的 paged attention v1 實(shí)現(xiàn)繼承自 FasterTransformers MHA 實(shí)現(xiàn),它和 FlashAttention 的并行任務(wù)劃分方式不同。其中對(duì) KVCache layout 的設(shè)計(jì)比較巧妙,充分利用了 shared memory 寫帶寬,是一種常用 CUDA 編程技巧。
審核編輯:劉清
-
寄存器
+關(guān)注
關(guān)注
31文章
5369瀏覽量
121275 -
Cache
+關(guān)注
關(guān)注
0文章
129瀏覽量
28448 -
內(nèi)存管理
+關(guān)注
關(guān)注
0文章
168瀏覽量
14194 -
MQA
+關(guān)注
關(guān)注
0文章
3瀏覽量
6058
原文標(biāo)題:vLLM皇冠上的明珠:深入淺出理解PagedAttention CUDA實(shí)現(xiàn)
文章出處:【微信號(hào):zenRRan,微信公眾號(hào):深度學(xué)習(xí)自然語(yǔ)言處理】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
相關(guān)推薦
評(píng)論