0x0. 前言
本文解析一下mlc-llm(https://github.com/mlc-ai/mlc-llm)對(duì)大模型推理的流程以及使用的圖優(yōu)化,算子優(yōu)化策略。mlc-llm的模型部署流程可以查看官方文檔:https://mlc.ai/mlc-llm/docs/ ,也可以參考我前段時(shí)間寫(xiě)的這篇MLC-LLM 部署RWKV World系列模型實(shí)戰(zhàn)(3B模型Mac M2解碼可達(dá)26tokens/s) 。
此外,閱讀mlc-llm的代碼還需要理解一些TVM Unify的一些基礎(chǔ)概念,可以參考TVM 學(xué)習(xí)指南(個(gè)人版) ,Relax: TVM 的下一代圖層級(jí) IR,新一代深度學(xué)習(xí)編譯技術(shù)變革和展望等等。從 https://github.com/BBuf/tvm_mlir_learn 這里可以查看更多相關(guān)博客和資料。
在 MLC-LLM 部署RWKV World系列模型實(shí)戰(zhàn)(3B模型Mac M2解碼可達(dá)26tokens/s) 中提到要使用mlc-llm部署模型首先需要一個(gè)編譯過(guò)程,將原始的基于Realx搭建的模型比如RWKV和給定的device信息一起編譯為T(mén)VM中的runtime.Module(在linux上編譯的產(chǎn)物就是.so文件)提供mlc-llm的c++推理接口調(diào)用 。我們就從這里看起:
由于mlc-llm上游更新很快,為了準(zhǔn)確標(biāo)定代碼位置我fork了一份2023年9月17號(hào)的mlc-llm代碼 :https://github.com/BBuf/mlc-llm-code-analysis ,本文的注釋以及指出的代碼位置均以這個(gè)fork倉(cāng)庫(kù)為準(zhǔn)。
0x1. 編譯流程解析
編譯的入口在:https://github.com/BBuf/mlc-llm-code-analysis/blob/main/mlc_llm/build.py 。
這個(gè)腳本構(gòu)建了一個(gè)模型build的入口,可以通過(guò)傳入不同的參數(shù)來(lái)構(gòu)建不同配置的模型。參數(shù)解析和模型編譯都在 https://github.com/BBuf/mlc-llm-code-analysis/blob/main/mlc_llm/core.py 中實(shí)現(xiàn),模型編譯準(zhǔn)備(mod_transform_before_build函數(shù))和編譯(build函數(shù))兩個(gè)階段。在模型編譯準(zhǔn)備階段,包含準(zhǔn)備需要優(yōu)化的算子,執(zhí)行一些基礎(chǔ)的圖變換,針對(duì)cuda做進(jìn)一步優(yōu)化,做算子fuse等優(yōu)化,詳細(xì)的解釋清閱讀這里的注釋?zhuān)篽ttps://github.com/BBuf/mlc-llm-code-analysis/blob/main/mlc_llm/core.py#L378 。
在這之后會(huì)執(zhí)行編譯過(guò)程:https://github.com/BBuf/mlc-llm-code-analysis/blob/main/mlc_llm/core.py#L378 。從這里我們可以看到,對(duì)于GPU來(lái)說(shuō)使用的是默認(rèn)的schedule模板,并沒(méi)有使用AutoTVM/Ansor等等調(diào)優(yōu)工具,這一點(diǎn)是很友好的,個(gè)人猜測(cè)也是因?yàn)門(mén)ransformer架構(gòu)的模型是很固定的,然后優(yōu)化方法也比較統(tǒng)一。
上面的編譯前準(zhǔn)備和編譯都是針對(duì)IRModule來(lái)說(shuō)的,那么這個(gè)IRModule是怎么來(lái)的呢?以及量化是在哪里做的?這兩個(gè)問(wèn)題都是在 build_model_from_args 函數(shù): https://github.com/BBuf/mlc-llm-code-analysis/blob/main/mlc_llm/core.py#L627 處理的,發(fā)生在 mod_transform_before_build 函數(shù)調(diào)用之前。以 RWKV 模型為例,通過(guò)這行 mod, param_manager, params, model_config = rwkv.get_model(args, config) 代碼完成了從原始的 huggingface 模型到初始的 IRModule 的轉(zhuǎn)換,在這個(gè)過(guò)程中也包含了量化。
0x2. 模型搭建解析
0x2.1 模型組件搭建
首先在 https://github.com/BBuf/mlc-llm-code-analysis/blob/main/mlc_llm/relax_model/modules.py 這里基于Relax的內(nèi)部接口(relax.Expr,relax.testing.nn.Module,relax.op.xxx等等)定義了搭建LLM模型需要的一些組件比如 ModuleList,Linear,Embedding,LayerNorm,RotaryEmbedding等等。這個(gè)地方我添加了一些解釋?zhuān)?qǐng)點(diǎn)上面的源碼鏈接查看。然后這個(gè)地方需要注意2個(gè)特殊的op,第一個(gè)是來(lái)自 https://github.com/mlc-ai/relax/blob/ceaf7b0156524d30537a3de5fa30764eaff4edb8/python/tvm/relax/op/index.py#L28 的:
def?take(x:?Expr,?indices:?Expr,?axis:?Optional[int]?=?None)?->?Expr: ????return?_ffi_api.take(x,?indices,?axis)??#?type:?ignore
這個(gè)函數(shù),實(shí)現(xiàn)了take的核心功能,與numpy和pytorch的take語(yǔ)義類(lèi)似,都可以通過(guò)指定indices來(lái)從輸入張量中抽取值。主要調(diào)用了_ffi_api.take進(jìn)行取值操作, 這個(gè)_ffi_api是relax底層實(shí)現(xiàn), take操作的實(shí)際計(jì)算會(huì)在這里進(jìn)行。這個(gè)函數(shù)被用于Embedding組件的搭建中。
另外nn.emit這個(gè)接口的作用是將一個(gè)relax.Expr表達(dá)式轉(zhuǎn)化為relax.Var變量,并保存該變量。
最后我們注意到這里搭建的Relax模塊風(fēng)格和PyTorch的模塊風(fēng)格基本一致,也可以看出Relax前端是不斷靠近動(dòng)態(tài)圖風(fēng)格,追求更佳的易用性。
0x2.2 模型搭建
首先看一些準(zhǔn)備工作:
#?@dataclass:這個(gè)裝飾器用于指示RWKVConfig類(lèi)是一個(gè)數(shù)據(jù)類(lèi)。用于存儲(chǔ)RWKVModel的配置信息。 @dataclass class?RWKVConfig: ????"""The?configuration?class?to?store?the?configuration?of?a?`RWKVModel`.""" ????num_hidden_layers:?int?#?類(lèi)中的一個(gè)屬性,用于存儲(chǔ)隱藏層的數(shù)量,類(lèi)型為整數(shù)。 ????vocab_size:?int?#?類(lèi)中的一個(gè)屬性,用于存儲(chǔ)詞匯表的大小,類(lèi)型為整數(shù)。 ????hidden_size:?int?#?類(lèi)中的一個(gè)屬性,用于存儲(chǔ)隱藏層的大小,類(lèi)型為整數(shù)。 ????intermediate_size:?int?#?類(lèi)中的一個(gè)屬性,用于存儲(chǔ)中間層的大小,類(lèi)型為整數(shù)。 ????rescale_every:?int?=?0?#?類(lèi)中的一個(gè)屬性,默認(rèn)值為0,用于存儲(chǔ)重新縮放的頻率,類(lèi)型為整數(shù)。 ????layer_norm_epsilon:?float?=?1e-5?#?類(lèi)中的一個(gè)屬性,默認(rèn)值為1e-5,用于存儲(chǔ)層歸一化的epsilon值,類(lèi)型為浮點(diǎn)數(shù)。 ????max_sequence_length:?int?=?1024?#?類(lèi)中的一個(gè)屬性,默認(rèn)值為1024,用于存儲(chǔ)最大序列長(zhǎng)度,類(lèi)型為整數(shù)。 ????dtype:?str?=?"float32"?#?類(lèi)中的一個(gè)屬性,默認(rèn)值為"float32",用于存儲(chǔ)數(shù)據(jù)類(lèi)型,類(lèi)型為字符串。 ????def?__init__( ????????self, ????????num_hidden_layers:?int, ????????vocab_size:?int, ????????hidden_size:?int, ????????intermediate_size:?int, ????????rescale_every:?int?=?0, ????????layer_norm_epsilon:?float?=?1e-5, ????????context_length:?int?=?1024, ????????dtype:?str?=?"float32", ????????**kwargs, ????)?->?None: ????????self.num_hidden_layers?=?num_hidden_layers ????????self.vocab_size?=?vocab_size ????????self.hidden_size?=?hidden_size ????????self.intermediate_size?=?intermediate_size ????????self.rescale_every?=?rescale_every ????????self.layer_norm_epsilon?=?layer_norm_epsilon ????????self.max_sequence_length?=?context_length ????????self.dtype?=?dtype ????????self.kwargs?=?kwargs #?用來(lái)索引RWKV的Attention和FFN部分存儲(chǔ)的狀態(tài)或者叫Cache。 #?python代碼可以參考:?https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L858-L867 class?State: ????ATT_X?=?0 ????ATT_A?=?1 ????ATT_B?=?2 ????ATT_P?=?3 ????FFN_X?=?4
這里的State是用來(lái)索引RWKV的Attention和FFN部分存儲(chǔ)的狀態(tài)或者叫Cache,每一個(gè)Layer有5個(gè)不同的State,并且每個(gè)State的shape都是[1, hidden_size],這里的1代表的應(yīng)該是batch緯度。
#?義了一個(gè)名為_(kāi)load_state的函數(shù),它接受一個(gè)名為state的參數(shù),類(lèi)型為Expr,一個(gè)名為hidden_size的參數(shù),類(lèi)型為整數(shù), #?一個(gè)名為dtype的參數(shù),類(lèi)型為字符串。函數(shù)的返回類(lèi)型為Expr。 def?_load_state(state:?Expr,?hidden_size:?int,?dtype:?str)?->?Expr: ????#?Reuse?`attention_kv_cache_view` ????#?將外部函數(shù)vm.builtin.attention_kv_cache_view賦值給變量f_load_cache。relax.extern是一個(gè)外部函數(shù)調(diào)用的語(yǔ)法, ????#?它指示編譯器在編譯時(shí)將該函數(shù)調(diào)用轉(zhuǎn)換為相應(yīng)的外部函數(shù)調(diào)用。 ????f_load_cache?=?relax.extern("vm.builtin.attention_kv_cache_view") ????#?使用nn.emit方法生成一個(gè)表達(dá)式對(duì)象,該表達(dá)式表示對(duì)外部函數(shù)f_load_cache的調(diào)用。 ????#?調(diào)用的參數(shù)是一個(gè)列表,包含state和R.shape([1,?hidden_size]),以及sinfo_args參數(shù)指定的一個(gè)R.Tensor對(duì)象。 ????cache?=?nn.emit( ????????relax.Call( ????????????f_load_cache, ????????????[state,?R.shape([1,?hidden_size])], ????????????sinfo_args=[R.Tensor((1,?hidden_size),?dtype)], ????????) ????) ????return?cache #?定義了一個(gè)名為_(kāi)store_state的函數(shù),它接受一個(gè)名為state的參數(shù),類(lèi)型為Expr,一個(gè)名為value的參數(shù),類(lèi)型為Expr。 def?_store_state(state:?Expr,?value:?Expr): ????#?Reuse?`attention_kv_cache_update` ????#?將外部函數(shù)vm.builtin.attention_kv_cache_update賦值給變量f_store_cache。 ????#?relax.extern是一個(gè)外部函數(shù)調(diào)用的語(yǔ)法,它指示編譯器在編譯時(shí)將該函數(shù)調(diào)用轉(zhuǎn)換為相應(yīng)的外部函數(shù)調(diào)用。 ????f_store_cache?=?relax.extern("vm.builtin.attention_kv_cache_update") ????#?使用nn.emit方法生成一個(gè)表達(dá)式對(duì)象,該表達(dá)式表示對(duì)外部函數(shù)f_store_cache的調(diào)用。 ????#?調(diào)用的參數(shù)是一個(gè)列表,包含state和value,以及sinfo_args參數(shù)指定的一個(gè)R.Object()對(duì)象。 ????return?nn.emit( ????????relax.Call( ????????????f_store_cache, ????????????[state,?value], ????????????sinfo_args=[R.Object()], ????????) ????)
這兩個(gè)函數(shù)用來(lái)加載和存儲(chǔ)RWKV模型的State。接下來(lái)看一下對(duì)應(yīng) https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L741 這里的torch.ops.rwkv.wkv_forward(1, T, C, w, u, k, v, y, aa, bb, pp) 的Relax實(shí)現(xiàn),為了方便對(duì)照先貼一下原始的wkv forward cuda kernel:
?
?
template?__global__?void?kernel_wkv_forward(const?int?B,?const?int?T,?const?int?C, ???????????????????????????????const?float?*__restrict__?const?_w,?const?float?*__restrict__?const?_u,?const?F?*__restrict__?const?_k,?const?F?*__restrict__?const?_v, ???????????????????????????????F?*__restrict__?const?_y,?float?*__restrict__?const?_aa,?float?*__restrict__?const?_bb,?float?*__restrict__?const?_pp)?{ ????const?int?idx?=?blockIdx.x?*?blockDim.x?+?threadIdx.x; ????const?int?_b?=?idx?/?C; ????const?int?_c?=?idx?%?C; ????const?int?_offset?=?_b?*?T?*?C?+?_c; ????const?int?_state_offset?=?_b?*?C?+?_c; ????float?u?=?_u[_c]; ????float?w?=?_w[_c]; ????const?F?*__restrict__?const?k?=?_k?+?_offset; ????const?F?*__restrict__?const?v?=?_v?+?_offset; ????F?*__restrict__?const?y?=?_y?+?_offset; ????float?aa?=?_aa[_state_offset]; ????float?bb?=?_bb[_state_offset]; ????float?pp?=?_pp[_state_offset]; ????for?(int?i?=?0;?i? void?cuda_wkv_forward(int?B,?int?T,?int?C,?float?*w,?float?*u,?F?*k,?F?*v,?F?*y,?float?*aa,?float?*bb,?float?*pp)?{ ????dim3?threadsPerBlock(?min(C,?32)?); ????assert(B?*?C?%?threadsPerBlock.x?==?0); ????dim3?numBlocks(B?*?C?/?threadsPerBlock.x); ????kernel_wkv_forward<< >>(B,?T,?C,?w,?u,?k,?v,?y,?aa,?bb,?pp); }
這個(gè)cuda kernel里面,B表示batch_size,在mlc-llm的實(shí)現(xiàn)默認(rèn)為1。然后T表示序列長(zhǎng)度,C表示隱藏層緯度。然后我們就可以對(duì)應(yīng)來(lái)看mlc-llm的wkv實(shí)現(xiàn)了。
#?定義了一個(gè)名為create_wkv_func的函數(shù),它接受一個(gè)名為hidden_size的參數(shù), #?類(lèi)型為整數(shù),一個(gè)名為dtype的參數(shù),類(lèi)型為字符串,一個(gè)名為out_dtype的參數(shù),類(lèi)型為字符串。 def?create_wkv_func(hidden_size:?int,?dtype:?str,?out_dtype:?str): ????@T.prim_func ????def?wkv_func( ????????k:?T.handle, ????????v:?T.handle, ????????time_decay:?T.handle, ????????time_first:?T.handle, ????????saved_a:?T.handle, ????????saved_b:?T.handle, ????????saved_p:?T.handle, ????????wkv:?T.handle, ????????out_a:?T.handle, ????????out_b:?T.handle, ????????out_p:?T.handle, ????): ????????#?設(shè)置TIR函數(shù)的屬性。這里設(shè)置了三個(gè)屬性,包括op_pattern、tir.noalias和tir.is_scheduled。 ????????T.func_attr({"op_pattern":?8,?"tir.noalias":?True,?"tir.is_scheduled":?1}) ????????#?聲明一個(gè)名為context_length的變量,類(lèi)型為T(mén).int64(),用于存儲(chǔ)上下文長(zhǎng)度。 ????????context_length?=?T.int64() ????????#?創(chuàng)建一個(gè)名為K的匹配緩沖區(qū),通過(guò)T.match_buffer方法匹配參數(shù)k的形狀和數(shù)據(jù)類(lèi)型。 ????????#?K的形狀在原始的ChatRWKV中為B,T,C,只不過(guò)這里B=1 ????????#?這里的k就是上面cuda?kernel的_k ????????K?=?T.match_buffer(k,?(context_length,?hidden_size),?dtype=dtype) ????????#?創(chuàng)建一個(gè)名為V的匹配緩沖區(qū),通過(guò)T.match_buffer方法匹配參數(shù)v的形狀和數(shù)據(jù)類(lèi)型。 ????????#?這里的v就是上面cuda?kernel的_v ????????V?=?T.match_buffer(v,?(context_length,?hidden_size),?dtype=dtype) ????????#?創(chuàng)建一個(gè)名為T(mén)imeDecay的匹配緩沖區(qū),通過(guò)T.match_buffer方法匹配參數(shù)time_decay的形狀和數(shù)據(jù)類(lèi)型。 ????????#?這里的TimeDecay就是上面的w ????????TimeDecay?=?T.match_buffer(time_decay,?(hidden_size,),?dtype=dtype) ????????#?創(chuàng)建一個(gè)名為T(mén)imeFirst的匹配緩沖區(qū),通過(guò)T.match_buffer方法匹配參數(shù)time_first的形狀和數(shù)據(jù)類(lèi)型。 ????????#?這里的TimeFirst對(duì)應(yīng)上面的u ????????TimeFirst?=?T.match_buffer(time_first,?(hidden_size,),?dtype=dtype) ????????#?對(duì)應(yīng)kernel里面的_aa的上一個(gè)token的狀態(tài) ????????SavedA?=?T.match_buffer(saved_a,?(1,?hidden_size),?dtype=dtype) ????????#?對(duì)應(yīng)kernel里面的_bb的上一個(gè)token的狀態(tài) ????????SavedB?=?T.match_buffer(saved_b,?(1,?hidden_size),?dtype=dtype) ????????#?對(duì)應(yīng)kernel里面的_pp的上一個(gè)token的狀態(tài) ????????SavedP?=?T.match_buffer(saved_p,?(1,?hidden_size),?dtype=dtype) ????????#?對(duì)應(yīng)_aa的當(dāng)前token狀態(tài) ????????OutA?=?T.match_buffer(out_a,?(1,?hidden_size),?dtype=dtype) ????????#?對(duì)應(yīng)_bb的當(dāng)前token狀態(tài) ????????OutB?=?T.match_buffer(out_b,?(1,?hidden_size),?dtype=dtype) ????????#?對(duì)應(yīng)_pp的當(dāng)前token狀態(tài) ????????OutP?=?T.match_buffer(out_p,?(1,?hidden_size),?dtype=dtype) ????????#?對(duì)應(yīng)kernel里面的p ????????P?=?T.alloc_buffer((hidden_size,),?dtype=dtype,?scope="local") ????????#?對(duì)應(yīng)kernel里面的e1 ????????E1?=?T.alloc_buffer((hidden_size,),?dtype=dtype,?scope="local") ????????#?對(duì)應(yīng)kernel里面的e2 ????????E2?=?T.alloc_buffer((hidden_size,),?dtype=dtype,?scope="local") ????????#?對(duì)應(yīng)kernel里面的aa ????????A_local?=?T.alloc_buffer((hidden_size,),?dtype=dtype,?scope="local") ????????#?對(duì)應(yīng)kernel里面的bb ????????B_local?=?T.alloc_buffer((hidden_size,),?dtype=dtype,?scope="local") ????????#?對(duì)應(yīng)kernel里面的cc ????????P_local?=?T.alloc_buffer((hidden_size,),?dtype=dtype,?scope="local") ????????#?迭代hidden_size?//?32次,使用T.thread_binding方法進(jìn)行線(xiàn)程綁定,其中hidden_size?//?32是塊索引的范圍。 ????????#?這里的線(xiàn)程塊劃分和rwkv?kernel里面保持一致:即每個(gè)block?32個(gè)線(xiàn)程,一共((B=1)*C)/32個(gè)blcok ????????for?bx?in?T.thread_binding(hidden_size?//?32,?thread="blockIdx.x"): ????????????#?迭代32次,使用T.thread_binding方法進(jìn)行線(xiàn)程綁定,其中32是線(xiàn)程索引的范圍。 ????????????for?tx?in?T.thread_binding(32,?thread="threadIdx.x"): ????????????????#?創(chuàng)建一個(gè)名為"init"的塊,用于初始化局部變量。 ????????????????with?T.block("init"): ????????????????????#?對(duì)應(yīng)?const?int?_state_offset?=?_b?*?C?+?_c; ????????????????????vi?=?T.axis.S(hidden_size,?bx?*?32?+?tx) ????????????????????#?對(duì)應(yīng)?float?aa?=?_aa[_state_offset]; ????????????????????A_local[vi]?=?SavedA[0,?vi] ????????????????????#?對(duì)應(yīng)?float?bb?=?_bb[_state_offset]; ????????????????????B_local[vi]?=?SavedB[0,?vi] ????????????????????#?對(duì)應(yīng)?float?pp?=?_pp[_state_offset]; ????????????????????P_local[vi]?=?SavedP[0,?vi] ????????????????for?j?in?range(context_length):?#?對(duì)應(yīng)?for?(int?i?=?0;?i?我們可以看到mlc-llm里面的wkv forward實(shí)現(xiàn)基本就是用基于Relax的api將cuda函數(shù)翻譯成了TIR。注釋里面給了一些下標(biāo)的推導(dǎo)以及每一行Relax的代碼是如何對(duì)應(yīng)到原始的cuda kernel。
#?定義了一個(gè)名為_(kāi)te_concat_saved_x的函數(shù),它接受兩個(gè)參數(shù)saved_x和x,都是te.Tensor類(lèi)型的張量。 #?使用TVM的te.compute函數(shù)計(jì)算一個(gè)新的張量,該張量的形狀與x相同,元素根據(jù)條件判斷進(jìn)行選擇。如果i等于0, #?則選擇saved_x[0,?j]作為元素值,否則選擇x[i?-?1,?j]作為元素值。其中i和j是迭代變量。 def?_te_concat_saved_x(saved_x:?te.Tensor,?x:?te.Tensor): ????return?te.compute( ????????x.shape, ????????lambda?i,?j:?tir.if_then_else(i?==?0,?saved_x[0,?j],?x[i?-?1,?j]), ????) #?定義了一個(gè)名為_(kāi)te_get_last_x的函數(shù),它接受一個(gè)參數(shù)x,是一個(gè)te.Tensor類(lèi)型的張量。 #?a.?seq_len,?hidden_size?=?x.shape:獲取x張量的形狀,其中seq_len表示序列長(zhǎng)度,hidden_size表示隱藏大小。 #?b.?return?te.compute(...):使用TVM的te.compute函數(shù)計(jì)算一個(gè)新的張量,該張量的形狀為(1,?hidden_size), #?元素值為x[seq_len?-?1,?j],其中j是迭代變量。 def?_te_get_last_x(x:?te.Tensor): ????seq_len,?hidden_size?=?x.shape ????return?te.compute((1,?hidden_size),?lambda?_,?j:?x[seq_len?-?1,?j])這兩個(gè)函數(shù)應(yīng)該對(duì)應(yīng)了 https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L455 這里代碼里面的sx = torch.cat((sx.unsqueeze(0), xx[:-1,:]))和xx[-1, :]:
@MyFunction ????def?ffn_seq(self,?x,?sx,?ln_w,?ln_b,?k_mix,?r_mix,?kw,?vw,?rw,?kmx,?krx,?kmy,?kry,?vmx,?vrx,?vmy,?vry,?rmx,?rrx,?rmy,?rry): ????????xx?=?F.layer_norm(x,?(x.shape[-1],),?weight=ln_w,?bias=ln_b) ????????sx?=?torch.cat((sx.unsqueeze(0),?xx[:-1,:])) ????????kx?=?xx?*?k_mix?+?sx?*?(1?-?k_mix) ????????rx?=?xx?*?r_mix?+?sx?*?(1?-?r_mix) ????????r?=?torch.sigmoid(gemm(rx,?rw)) ????????vx?=?torch.square(torch.relu(gemm(kx,?kw))) ????????out?=?r?*?gemm(vx,?vw) ????????return?x?+?out,?xx[-1,:]接著對(duì)Embedding函數(shù)進(jìn)行解析:
#?定義了一個(gè)名為RWKV_Embedding的PyTorch模塊。 class?RWKV_Embedding(nn.Module): ????#?定義了RWKV_Embedding類(lèi)的構(gòu)造函數(shù),接受三個(gè)參數(shù)num_embeddings、embedding_dim和dtype。 ????def?__init__(self,?num_embeddings,?embedding_dim,?dtype): ????????self.num_embeddings?=?num_embeddings?#?將num_embeddings賦值給類(lèi)成員變量self.num_embeddings。 ????????self.embedding_dim?=?embedding_dim?#?將embedding_dim賦值給類(lèi)成員變量self.embedding_dim。 ????????#?創(chuàng)建一個(gè)名為weight的Parameter,形狀為(num_embeddings,?embedding_dim), ????????#?數(shù)據(jù)類(lèi)型為dtype,并將其賦值給類(lèi)成員變量self.weight。 ????????self.weight?=?nn.Parameter( ????????????(num_embeddings,?embedding_dim),?dtype=dtype,?name="weight" ????????) ????def?forward(self,?x:?relax.Expr)?->?relax.Var: ????????#?調(diào)用op.reshape函數(shù)將輸入張量x進(jìn)行reshape,將其展平為一維張量,并將結(jié)果重新賦值給x。 ????????#?nn.emit是將一個(gè)relax.Expr表達(dá)式轉(zhuǎn)化為relax.Var變量,并保存該變量。 ????????x?=?nn.emit(op.reshape(x,?shape=[-1])) ????????#?使用op.take操作從self.weight中按照索引x提取對(duì)應(yīng)的嵌入向量,并返回結(jié)果。這里的axis=0表示在第一個(gè)維度上進(jìn)行索引操作。 ????????return?nn.emit(op.take(self.weight,?x,?axis=0))以及LayerNorm:
#?這段代碼定義了一個(gè)名為RWKV_LayerNorm的PyTorch模塊,它實(shí)現(xiàn)了一個(gè)Layer?Normalization層。 class?RWKV_LayerNorm(nn.Module): ????#?定義了RWKV_LayerNorm類(lèi)的構(gòu)造函數(shù),接受四個(gè)參數(shù)intermediate_size、dtype、eps和name_prefix。 ????def?__init__(self,?intermediate_size,?dtype,?eps=1e-5,?name_prefix=""): ????????super().__init__() ????????self.eps?=?eps ????????self.weight?=?nn.Parameter( ????????????(intermediate_size,),?dtype=dtype,?name=f"{name_prefix}_ln_weight" ????????) ????????self.bias?=?nn.Parameter( ????????????(intermediate_size,),?dtype=dtype,?name=f"{name_prefix}_ln_bias" ????????) ????def?forward(self,?x:?relax.Expr)?->?relax.Var: ????????#?使用op.nn.layer_norm操作對(duì)輸入張量x進(jìn)行Layer?Normalization,其中使用Parameter?self.weight作為縮放參數(shù)(gamma), ????????#?使用可學(xué)習(xí)參數(shù)self.bias作為偏移參數(shù)(beta),在最后一個(gè)維度(axes=-1)上進(jìn)行標(biāo)準(zhǔn)化操作, ????????#?并設(shè)置小數(shù)值修正項(xiàng)為self.eps。將標(biāo)準(zhǔn)化后的結(jié)果重新賦值給x。 ????????x?=?nn.emit( ????????????op.nn.layer_norm( ????????????????x, ????????????????gamma=self.weight, ????????????????beta=self.bias, ????????????????axes=-1, ????????????????epsilon=self.eps, ????????????) ????????) ????????return?x接著對(duì)FFN層做一個(gè)詳細(xì)的解析:
#?這段代碼定義了一個(gè)名為RWKV_FFN的PyTorch模塊,它實(shí)現(xiàn)了Feed-Forward?Network(FFN)。 class?RWKV_FFN(nn.Module): ????#?定義了RWKV_FFN類(lèi)的構(gòu)造函數(shù),接受兩個(gè)參數(shù)RWKVConfig和index。 ????def?__init__(self,?config:?RWKVConfig,?index:?int)?->?None: ????????super().__init__() ????????#?將config.hidden_size賦值給類(lèi)成員變量self.hidden_size,表示隱藏大小。 ????????self.hidden_size?=?config.hidden_size ????????#?將config.dtype賦值給類(lèi)成員變量self.dtype,表示數(shù)據(jù)類(lèi)型。 ????????self.dtype?=?config.dtype ????????#?將index賦值給類(lèi)成員變 ????????self.index?=?index ????????#?建一個(gè)名為time_mix_key的可學(xué)習(xí)參數(shù),形狀為(self.hidden_size,), ????????#?數(shù)據(jù)類(lèi)型為config.dtype,命名為"ffn_{index}_time_mix_k",并將其賦值給類(lèi)成員變量self.time_mix_key。 ????????self.time_mix_key?=?nn.Parameter( ????????????(self.hidden_size,),?dtype=config.dtype,?name=f"ffn_{index}_time_mix_k" ????????) ????????#?創(chuàng)建一個(gè)名為time_mix_receptance的可學(xué)習(xí)參數(shù),形狀為(self.hidden_size,),數(shù)據(jù)類(lèi)型為config.dtype, ????????#?命名為"ffn_{index}_time_mix_r",并將其賦值給類(lèi)成員變量self.time_mix_receptance。 ????????self.time_mix_receptance?=?nn.Parameter( ????????????(self.hidden_size,),?dtype=config.dtype,?name=f"ffn_{index}_time_mix_r" ????????) ????????#?創(chuàng)建一個(gè)線(xiàn)性層,輸入大小為self.hidden_size,輸出大小為config.intermediate_size, ????????#?數(shù)據(jù)類(lèi)型為config.dtype,沒(méi)有偏置項(xiàng),并將其賦值給類(lèi)成員變量self.key。 ????????self.key?=?Linear( ????????????self.hidden_size,?config.intermediate_size,?dtype=config.dtype,?bias=False ????????) ????????#?創(chuàng)建一個(gè)線(xiàn)性層,輸入大小為self.hidden_size,輸出大小為self.hidden_size,數(shù)據(jù)類(lèi)型為config.dtype, ????????#?沒(méi)有偏置項(xiàng),并將其賦值給類(lèi)成員變量self.receptance。 ????????self.receptance?=?Linear( ????????????self.hidden_size,?self.hidden_size,?dtype=config.dtype,?bias=False ????????) ????????self.value?=?Linear( ????????????config.intermediate_size,?self.hidden_size,?dtype=config.dtype,?bias=False ????????) ????def?forward(self,?x:?Expr,?state:?Expr)?->?Expr: ????????#?計(jì)算偏移量,用于在state中獲取對(duì)應(yīng)的保存狀態(tài)。 ????????offset?=?self.index?*?5?+?State.FFN_X ????????#?獲取x的shape[0]表示上下文長(zhǎng)度。 ????????context_length?=?x.struct_info.shape[0] ????????#?獲取隱藏層大小。 ????????hidden_size?=?self.hidden_size ????????#?調(diào)用_load_state函數(shù)從state中加載保存的狀態(tài)state[offset],并將結(jié)果賦值給saved_x。 ????????saved_x?=?_load_state(state[offset],?hidden_size,?self.dtype) ????????#?如果上下文長(zhǎng)度不為1,則執(zhí)行下面的操作。 ????????if?not?is_one(context_length): ????????????#?調(diào)用nn.emit_te函數(shù),將saved_x和x作為參數(shù)傳遞給 ????????????#?_te_concat_saved_x函數(shù)進(jìn)行計(jì)算,并將結(jié)果重新賦值給saved_x。 ????????????#?類(lèi)似于transformer?里面的KV?Cache的,但是這里的concat是緯度不變的 ????????????#?對(duì)應(yīng)?sx?=?torch.cat((sx.unsqueeze(0),?xx[:-1,:]))?這行代碼 ????????????saved_x?=?nn.emit_te(_te_concat_saved_x,?saved_x,?x) ????????#?創(chuàng)建一個(gè)全為1的張量,形狀為(hidden_size,),數(shù)據(jù)類(lèi)型為self.dtype,并將其賦值給ones。 ????????ones?=?nn.emit(relax.op.ones((hidden_size,),?self.dtype)) ????????#?計(jì)算xk,根據(jù)時(shí)間混合參數(shù)self.time_mix_key和保存的狀態(tài)saved_x,使用加權(quán)求和的方式得到。 ????????#?其中,x和saved_x分別乘以self.time_mix_key和(ones?-?self.time_mix_key),然后相加。將計(jì)算結(jié)果賦值給xk。 ????????#?對(duì)應(yīng)?kx?=?xx?*?k_mix?+?sx?*?(1?-?k_mix)?這行代碼 ????????xk?=?nn.emit(x?*?self.time_mix_key?+?saved_x?*?(ones?-?self.time_mix_key)) ????????#?計(jì)算xr,根據(jù)時(shí)間混合參數(shù)self.time_mix_receptance和保存的狀態(tài)saved_x,使用加權(quán)求和的方式得到。 ????????#?其中,x和saved_x分別乘以self.time_mix_receptance和(ones?-?self.time_mix_receptance),然后相加。 ????????#?將計(jì)算結(jié)果賦值給xr。 ????????#?對(duì)應(yīng)?rx?=?xx?*?r_mix?+?sx?*?(1?-?r_mix) ????????xr?=?nn.emit( ????????????x?*?self.time_mix_receptance?+?saved_x?*?(ones?-?self.time_mix_receptance) ????????) ????????#?#?如果上下文長(zhǎng)度不為1,則執(zhí)行下面的操作。 ????????if?not?is_one(context_length): ????????????#?調(diào)用nn.emit_te函數(shù),使用_te_get_last_x函數(shù)從x中獲取最后一個(gè)token對(duì)應(yīng)的tensor,并將結(jié)果重新賦值給x。 ????????????#?對(duì)應(yīng)?xx[-1,:] ????????????x?=?nn.emit_te(_te_get_last_x,?x) ????????#?斷言x的結(jié)構(gòu)信息(shape)的第一個(gè)維度為1。 ????????assert?is_one(x.struct_info.shape[0]) ????????#?調(diào)用_store_state函數(shù),將x保存到state[offset]中,并將結(jié)果重新賦值給saved_x。 ????????#?對(duì)應(yīng):https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L921 ????????saved_x?=?_store_state(state[offset],?x) ????????#?將xr作為輸入,經(jīng)過(guò)sigmoid激活函數(shù)計(jì)算得到r。對(duì)應(yīng):r?=?torch.sigmoid(gemm(rx,?rw)) ????????r?=?nn.emit(op.sigmoid(self.receptance(xr))) ????????#?對(duì)應(yīng)?vx?=?torch.square(torch.relu(gemm(kx,?kw))) ????????xv?=?nn.emit(op.square(op.nn.relu(self.key(xk)))) ????????return?nn.emit(r?*?self.value(xv)),?[saved_x]接下來(lái)對(duì)Attention部分的實(shí)現(xiàn)進(jìn)行解析,注意這部分對(duì)應(yīng)的代碼在 https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L728-L747 。貼一下python代碼防止看錯(cuò)位置產(chǎn)生疑問(wèn):
if?os.environ["RWKV_CUDA_ON"]?==?'1': ????????@MyFunction ????????def?cuda_att_seq(self,?x,?sx,?aa,?bb,?pp,?ln_w,?ln_b,?k_mix,?v_mix,?r_mix,?t_decay,?t_first,?kw,?vw,?rw,?ow,?kmx,?krx,?kmy,?kry,?vmx,?vrx,?vmy,?vry,?rmx,?rrx,?rmy,?rry,?omx,?orx,?omy,?ory): ????????????T,?C?=?x.shape ????????????xx?=?F.layer_norm(x,?(C,),?weight=ln_w,?bias=ln_b) ????????????sx?=?torch.cat((sx.unsqueeze(0),?xx[:-1,:])) ????????????kx?=?xx?*?k_mix?+?sx?*?(1?-?k_mix) ????????????vx?=?xx?*?v_mix?+?sx?*?(1?-?v_mix) ????????????rx?=?xx?*?r_mix?+?sx?*?(1?-?r_mix) ????????????r?=?torch.sigmoid(gemm(rx,?rw)) ????????????k?=?gemm(kx,?kw,?output_dtype=torch.float32) ????????????v?=?gemm(vx,?vw,?output_dtype=torch.float32) ????????????y,?aa,?bb,?pp?=?cuda_wkv(T,?aa.shape[0],?t_decay,?t_first,?k,?v,?aa,?bb,?pp) ???????????? ????????????out?=?gemm(r?*?y.to(x.dtype),?ow) ????????????return?x?+?out,?xx[-1,:],?aa,?bb,?pp對(duì)應(yīng)mlc-llm RWKV Attention的代碼解析為:
#?實(shí)現(xiàn)RWKV?Attention,對(duì)應(yīng)?https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L479 class?RWKV_Attention(nn.Module): ????#?初始化函數(shù),接受一個(gè)config對(duì)象和一個(gè)整數(shù)index作為參數(shù)。其中config是一個(gè)RWKVConfig類(lèi)型的對(duì)象,index表示當(dāng)前層的索引。 ????def?__init__(self,?config:?RWKVConfig,?index:?int)?->?None: ????????super().__init__() ????????self.index?=?index ????????self.dtype?=?config.dtype ????????self.hidden_size?=?config.hidden_size ????????#?創(chuàng)建一些可學(xué)習(xí)的參數(shù),如time_decay、time_first、time_mix_key等,這些參數(shù)會(huì)在模型的前向傳播中使用。 ????????self.time_decay?=?nn.Parameter( ????????????(self.hidden_size,),?dtype="float32",?name=f"att_{index}_time_decay" ????????) ????????self.time_first?=?nn.Parameter( ????????????(self.hidden_size,),?dtype="float32",?name=f"att_{index}_time_first" ????????) ????????self.time_mix_key?=?nn.Parameter( ????????????(self.hidden_size,),?dtype=config.dtype,?name=f"att_{index}_time_mix_k" ????????) ????????self.time_mix_value?=?nn.Parameter( ????????????(self.hidden_size,),?dtype=config.dtype,?name=f"att_{index}_time_mix_v" ????????) ????????self.time_mix_receptance?=?nn.Parameter( ????????????(self.hidden_size,),?dtype=config.dtype,?name=f"att_{index}_time_mix_r" ????????) ????????#?前向傳播用到的線(xiàn)性層 ????????self.key?=?Linear( ????????????self.hidden_size,?self.hidden_size,?dtype=config.dtype,?bias=False ????????) ????????self.value?=?Linear( ????????????self.hidden_size,?self.hidden_size,?dtype=config.dtype,?bias=False ????????) ????????self.receptance?=?Linear( ????????????self.hidden_size,?self.hidden_size,?dtype=config.dtype,?bias=False ????????) ????????self.output?=?Linear( ????????????self.hidden_size,?self.hidden_size,?dtype=config.dtype,?bias=False ????????) ????#?前向傳播函數(shù),接受輸入張量x和狀態(tài)張量state作為參數(shù),并返回輸出張量 ????def?forward(self,?x:?Expr,?state:?Expr)?->?Expr: ????????#?Load?current?state ????????#?定義了一些局部變量,如ones、index、hidden_size、context_length等。 ????????ones?=?nn.emit(relax.op.ones((self.hidden_size,),?self.dtype)) ????????index?=?self.index ????????hidden_size?=?self.hidden_size ????????context_length?=?x.struct_info.shape[0] ????????bb?=?relax.BlockBuilder.current() ????????#?_load_state函數(shù)從state中加載保存的狀態(tài),賦值給saved_a、saved_b、saved_p和saved_x。 ????????saved_a?=?_load_state(state[index?*?5?+?State.ATT_A],?hidden_size,?"float32") ????????saved_b?=?_load_state(state[index?*?5?+?State.ATT_B],?hidden_size,?"float32") ????????saved_p?=?_load_state(state[index?*?5?+?State.ATT_P],?hidden_size,?"float32") ????????saved_x?=?_load_state(state[index?*?5?+?State.ATT_X],?hidden_size,?self.dtype) ???????? ????????#?調(diào)用nn.emit_te函數(shù),將saved_x和x作為參數(shù)傳遞給 ????????#?_te_concat_saved_x函數(shù)進(jìn)行計(jì)算,并將結(jié)果重新賦值給saved_x。 ????????#?對(duì)應(yīng)?sx?=?torch.cat((sx.unsqueeze(0),?xx[:-1,:])) ????????if?not?is_one(context_length): ????????????saved_x?=?nn.emit_te(_te_concat_saved_x,?saved_x,?x) ????????#?對(duì)應(yīng)?kx?=?xx?*?k_mix?+?sx?*?(1?-?k_mix) ????????xk?=?nn.emit(x?*?self.time_mix_key?+?saved_x?*?(ones?-?self.time_mix_key)) ????????#?對(duì)應(yīng)?vx?=?xx?*?v_mix?+?sx?*?(1?-?v_mix) ????????xv?=?nn.emit(x?*?self.time_mix_value?+?saved_x?*?(ones?-?self.time_mix_value)) ????????#?對(duì)應(yīng)?rx?=?xx?*?r_mix?+?sx?*?(1?-?r_mix) ????????xr?=?nn.emit( ????????????x?*?self.time_mix_receptance?+?saved_x?*?(ones?-?self.time_mix_receptance) ????????) ????????#?對(duì)應(yīng)?r?=?torch.sigmoid(gemm(rx,?rw)) ????????r?=?nn.emit(op.sigmoid(self.receptance(xr))) ????????#?對(duì)應(yīng)?k?=?gemm(kx,?kw,?output_dtype=torch.float32) ????????k?=?nn.emit(op.astype(self.key(xk),?"float32")) ????????#?對(duì)應(yīng)?v?=?gemm(vx,?vw,?output_dtype=torch.float32) ????????v?=?nn.emit(op.astype(self.value(xv),?"float32")) ????????#?這部分對(duì)應(yīng)?y,?aa,?bb,?pp?=?cuda_wkv(T,?aa.shape[0],?t_decay,?t_first,?k,?v,?aa,?bb,?pp) ????????#?這里的?create_wkv_func?在上面已經(jīng)解析了 ????????gv?=?bb.add_func(create_wkv_func(hidden_size,?"float32",?self.dtype),?"wkv") ????????ret?=?nn.emit( ????????????relax.call_tir( ????????????????gv, ????????????????[k,?v,?self.time_decay,?self.time_first,?saved_a,?saved_b,?saved_p], ????????????????[ ????????????????????R.Tensor((context_length,?hidden_size),?self.dtype),?#?對(duì)應(yīng)wkv ????????????????????R.Tensor((1,?hidden_size),?"float32"),?#?對(duì)應(yīng)out_a ????????????????????R.Tensor((1,?hidden_size),?"float32"),?#?對(duì)應(yīng)out_b ????????????????????R.Tensor((1,?hidden_size),?"float32"),?#?對(duì)應(yīng)out_p ????????????????], ????????????) ????????) ????????if?not?is_one(context_length): ????????????#?對(duì)應(yīng)?xx[-1,:] ????????????x?=?nn.emit_te(_te_get_last_x,?x) ????????assert?is_one(x.struct_info.shape[0]) ????????saved_x?=?_store_state(state[self.index?*?5?+?State.ATT_X],?x) ????????saved_a?=?_store_state(state[self.index?*?5?+?State.ATT_A],?ret[1]) ????????saved_b?=?_store_state(state[self.index?*?5?+?State.ATT_B],?ret[2]) ????????saved_p?=?_store_state(state[self.index?*?5?+?State.ATT_P],?ret[3]) ????????#?需要注意一下,python代碼里面的?return?x?+?out,?xx[-1,:],?aa,?bb,?pp ????????#?這里的?x?+?out被放在attention外面做了,因?yàn)檫@里的x已經(jīng)是被修改之后好的結(jié)果而不是原始的x ????????return?nn.emit(self.output(r?*?ret[0])),?[ ????????????saved_x, ????????????saved_a, ????????????saved_b, ????????????saved_p, ????????]接著解析一下RWKVLayer的實(shí)現(xiàn),請(qǐng)注意下面的最后一行代碼的解釋?zhuān)?/span>
class?RWKVLayer(nn.Module): ????#?初始化函數(shù),接受一個(gè)config對(duì)象和一個(gè)整數(shù)index作為參數(shù)。其中config是一個(gè)RWKVConfig類(lèi)型的對(duì)象,index表示層的索引。 ????def?__init__(self,?config:?RWKVConfig,?index:?int)?->?None: ????????super().__init__() ????????#?如果index為0,創(chuàng)建一個(gè)RWKV_LayerNorm對(duì)象pre_ln,用于對(duì)輸入進(jìn)行Layer?Normalization操作。 ????????if?index?==?0: ????????????self.pre_ln?=?RWKV_LayerNorm( ????????????????config.hidden_size, ????????????????config.dtype, ????????????????eps=config.layer_norm_epsilon, ????????????????name_prefix="pre_ln", ????????????) ????????#?創(chuàng)建兩個(gè)RWKV_LayerNorm對(duì)象,分別命名為ln1和ln2, ????????#?用于對(duì)注意力機(jī)制和前饋神經(jīng)網(wǎng)絡(luò)的輸出進(jìn)行Layer?Normalization操作。 ????????self.ln1?=?RWKV_LayerNorm( ????????????config.hidden_size, ????????????config.dtype, ????????????eps=config.layer_norm_epsilon, ????????????name_prefix=f"att_{index}", ????????) ????????self.ln2?=?RWKV_LayerNorm( ????????????config.hidden_size, ????????????config.dtype, ????????????eps=config.layer_norm_epsilon, ????????????name_prefix=f"ffn_{index}", ????????) ????????#?創(chuàng)建一個(gè)RWKV_Attention對(duì)象attention,用于實(shí)現(xiàn)注意力機(jī)制。 ????????self.attention?=?RWKV_Attention(config,?index) ????????#?創(chuàng)建一個(gè)RWKV_FFN對(duì)象feed_forward,用于實(shí)現(xiàn)前饋神經(jīng)網(wǎng)絡(luò)。 ????????self.feed_forward?=?RWKV_FFN(config,?index) ????????self.rescale_every?=?config.rescale_every ????????self.dtype?=?config.dtype ????????self.index?=?index ????#?前向傳播函數(shù),接受輸入張量x和狀態(tài)張量state作為參數(shù),并返回輸出張量和更新后的狀態(tài)列表。 ????def?forward(self,?x:?Expr,?state:?Expr)?->?Tuple[Expr,?List[Expr]]: ????????#?如果index為0,則將輸入張量x傳入pre_ln進(jìn)行Layer?Normalization操作。 ????????if?self.index?==?0: ????????????x?=?self.pre_ln(x) ????????#?將經(jīng)過(guò)ln1的輸入張量x和狀態(tài)張量state傳入attention進(jìn)行計(jì)算,得到注意力機(jī)制的輸出att和更新后的狀態(tài)列表att_state。 ????????att,?att_state?=?self.attention(self.ln1(x),?state) ????????#?將輸入張量x和注意力機(jī)制的輸出att相加,并將結(jié)果賦值給x。 ????????x?=?nn.emit(x?+?att) ????????#?將經(jīng)過(guò)ln2的輸入張量x和狀態(tài)張量state傳入feed_forward進(jìn)行計(jì)算,得到前饋神經(jīng)網(wǎng)絡(luò)的輸出ffn和更新后的狀態(tài)列表ffn_state。 ????????ffn,?ffn_state?=?self.feed_forward(self.ln2(x),?state) ????????#?將輸入張量x和前饋神經(jīng)網(wǎng)絡(luò)的輸出ffn相加,并將結(jié)果賦值給x。 ????????x?=?nn.emit(x?+?ffn) ????????#?如果滿(mǎn)足self.rescale_every?>?0且(self.index?+?1)?%?self.rescale_every?==?0,則對(duì)輸入張量x進(jìn)行縮放操作。 ????????if?self.rescale_every?>?0?and?(self.index?+?1)?%?self.rescale_every?==?0: ????????????x?=?nn.emit(x?/?relax.const(2,?dtype=self.dtype)) ????????#?返回輸出張量x和注意力機(jī)制和前饋神經(jīng)網(wǎng)絡(luò)的更新后的狀態(tài)列表的拼接。 ????????return?x,?att_state?+?ffn_state注意這里的attn_state是[saved_x,saved_a,saved_b,saved_p,] ,然后ffn_state是[saved_x],注意這兩個(gè)x是不一樣的,這5個(gè)狀態(tài)也和本節(jié)開(kāi)頭的class State的成員定義一致。
接下來(lái)對(duì)RWKV模型定義進(jìn)行了解析:
#?該代碼是一個(gè)自定義的PyTorch模型類(lèi)RWKVModel,繼承自nn.Module class?RWKVModel(nn.Module): ????#?初始化函數(shù),接受一個(gè)config對(duì)象作為參數(shù)。其中config是一個(gè)RWKVConfig類(lèi)型的對(duì)象。 ????def?__init__(self,?config:?RWKVConfig)?->?None: ????????super().__init__() ????????#?創(chuàng)建一個(gè)RWKV_Embedding對(duì)象embeddings,用于實(shí)現(xiàn)輸入的嵌入操作。 ????????self.embeddings?=?RWKV_Embedding( ????????????num_embeddings=config.vocab_size, ????????????embedding_dim=config.hidden_size, ????????????dtype=config.dtype, ????????) ????????#?創(chuàng)建一個(gè)ModuleList對(duì)象blocks,其中包含了config.num_hidden_layers個(gè)RWKVLayer對(duì)象, ????????#?每個(gè)對(duì)象的索引從0到config.num_hidden_layers-1。 ????????self.blocks?=?ModuleList( ????????????[RWKVLayer(config,?i)?for?i?in?range(config.num_hidden_layers)] ????????) ????????#?創(chuàng)建一個(gè)RWKV_LayerNorm對(duì)象ln_out,用于對(duì)輸出進(jìn)行Layer?Normalization操作。 ????????self.ln_out?=?RWKV_LayerNorm( ????????????config.hidden_size, ????????????config.dtype, ????????????eps=config.layer_norm_epsilon, ????????????name_prefix="out_ln", ????????) ????????self.hidden_size?=?config.hidden_size ????????self.dtype?=?config.dtype ????#?前向傳播函數(shù),接受輸入張量input_ids和狀態(tài)張量state作為參數(shù),并返回輸出張量和更新后的狀態(tài)列表。 ????def?forward(self,?input_ids:?Expr,?state:?Expr)?->?Tuple[Expr,?List[Expr]]: ????????#?將輸入張量input_ids傳入embeddings進(jìn)行嵌入操作,得到隱藏狀態(tài)張量hidden_states。 ????????hidden_states?=?self.embeddings(input_ids) ????????#?創(chuàng)建一個(gè)空列表states,用于存儲(chǔ)每個(gè)RWKVLayer對(duì)象的更新后的狀態(tài)列表。 ????????states?=?[] ????????#?遍歷blocks中的每個(gè)RWKVLayer對(duì)象,將隱藏狀態(tài)張量hidden_states和狀態(tài)張量state傳入 ????????#?每個(gè)RWKVLayer對(duì)象的前向傳播函數(shù)進(jìn)行計(jì)算,得到更新后的隱藏狀態(tài)張量和更新后的狀態(tài)列表, ????????#?并將更新后的狀態(tài)列表添加到states中。 ????????for?_,?layer?in?enumerate(self.blocks): ????????????hidden_states,?layer_states?=?layer(hidden_states,?state) ????????????states?+=?layer_states ????????#?獲取隱藏狀態(tài)張量的上下文長(zhǎng)度context_length。 ????????context_length?=?hidden_states.struct_info.shape[0] ????????#?如果context_length不為1,則調(diào)用_te_get_last_x函數(shù)獲取最后一個(gè)token對(duì)應(yīng)的張量。 ????????if?not?is_one(context_length): ????????????hidden_states?=?nn.emit_te(_te_get_last_x,?hidden_states) ????????#?將隱藏狀態(tài)張量傳入ln_out進(jìn)行Layer?Normalization操作。 ????????hidden_states?=?self.ln_out(hidden_states) ????????#?返回輸出隱藏狀態(tài)張量和所有RWKVLayer對(duì)象的更新后的狀態(tài)列表。 ????????return?hidden_states,?states #?該代碼是一個(gè)自定義的PyTorch模型類(lèi)RWKVForCausalLM,繼承自nn.Module。 class?RWKVForCausalLM(nn.Module): ????#?初始化函數(shù),接受一個(gè)config對(duì)象作為參數(shù)。其中config是一個(gè)RWKVConfig類(lèi)型的對(duì)象。 ????def?__init__(self,?config:?RWKVConfig): ????????#?創(chuàng)建一個(gè)RWKVModel對(duì)象rwkv,用于實(shí)現(xiàn)序列模型的計(jì)算。 ????????self.rwkv?=?RWKVModel(config) ????????#?創(chuàng)建一個(gè)Linear對(duì)象head,用于將隱藏狀態(tài)映射到詞匯表大小的輸出空間。 ????????self.head?=?Linear( ????????????config.hidden_size,?config.vocab_size,?dtype=config.dtype,?bias=False ????????) ????????self.vocab_size?=?config.vocab_size ????????############?End?############ ????#?前向傳播函數(shù),接受輸入張量input_ids和狀態(tài)張量state作為參數(shù),并返回預(yù)測(cè)的logits和更新后的kv?cache。 ????def?forward( ????????self, ????????input_ids:?relax.Expr, ????????state:?relax.Expr, ????): ????????#?將輸入張量input_ids和狀態(tài)張量state傳入rwkv對(duì)象的前向傳播函數(shù)進(jìn)行計(jì)算, ????????#?得到更新后的隱藏狀態(tài)張量hidden_states和key-value緩存key_value_cache。 ????????hidden_states,?key_value_cache?=?self.rwkv(input_ids,?state) ????????#?將隱藏狀態(tài)張量hidden_states傳入head進(jìn)行線(xiàn)性映射操作,得到logits。 ????????logits?=?nn.emit(self.head(hidden_states)) ????????#?對(duì)logits進(jìn)行形狀重塑,將其reshape為形狀為(1,?1,?self.vocab_size)的張量。 ????????logits?=?nn.emit(op.reshape(logits,?(1,?1,?self.vocab_size))) ????????#?如果logits的數(shù)據(jù)類(lèi)型不是float32,則將其轉(zhuǎn)換為float32類(lèi)型。 ????????if?logits.struct_info.dtype?!=?"float32": ????????????logits?=?nn.emit(relax.op.astype(logits,?"float32")) ????????return?logits,?key_value_cache解下是一個(gè)根據(jù)參數(shù)的名字確定量化參數(shù)類(lèi)型的函數(shù):
#?該代碼定義了一個(gè)函數(shù)get_param_quant_kind,用于根據(jù)參數(shù)名稱(chēng)和參數(shù)信息確定參數(shù)的量化類(lèi)型。 def?get_param_quant_kind( ????name:?str,?param_info:?relax.TensorStructInfo )?->?ParamQuantKind: ????#?如果參數(shù)名稱(chēng)以"embeddings.weight"結(jié)尾,返回ParamQuantKind.embedding_table表示該參數(shù)是嵌入表的權(quán)重。 ????if?name.endswith("embeddings.weight"): ????????return?ParamQuantKind.embedding_table ????#?如果參數(shù)名稱(chēng)為"head.weight",返回ParamQuantKind.final_fc_weight表示該參數(shù)是最后一個(gè)全連接層的權(quán)重。 ????elif?name?==?"head.weight": ????????return?ParamQuantKind.final_fc_weight ????#?如果參數(shù)的維度為2且名稱(chēng)以".weight"結(jié)尾,返回ParamQuantKind.linear_weight表示該參數(shù)是線(xiàn)性層的權(quán)重。 ????elif?param_info.ndim?==?2?and?name.endswith(".weight"): ????????return?ParamQuantKind.linear_weight ????else: ????????return?ParamQuantKind.others上面已經(jīng)完成了RWKV模型的定義,接下來(lái)是定義幾個(gè)相關(guān)的TIR函數(shù)并定義一個(gè)最終的TIR模型獲取函數(shù)。這里對(duì)創(chuàng)建prefill和decode的create_func函數(shù)以及最終的TIR模型獲取函數(shù)get_model進(jìn)行解析:
由于字?jǐn)?shù)被公眾號(hào)限制了,請(qǐng)?jiān)谥跷恼虏榭催@部分,https://zhuanlan.zhihu.com/p/658354795
自此,我們基本就有了搭建RWKV模型的全部流程,說(shuō)白了就是用TVM的Relax語(yǔ)言手動(dòng)一對(duì)一的把PyTorch實(shí)現(xiàn)翻譯過(guò)去。
0x3. Transform舉例
在mlc-llm有一些圖層的優(yōu)化,在 https://github.com/BBuf/mlc-llm-code-analysis/tree/main/mlc_llm/transform 這個(gè)文件里面,我們對(duì)其中的一些優(yōu)化Pass做一下解析。
0x3.1 rewrite attention
代碼如下:
#?導(dǎo)入了TVM的relax模塊中的一些函數(shù)和類(lèi),以及TVM的script模塊中的relax別名。 from?tvm.relax.dpl?import?PatternContext,?is_const,?is_op,?rewrite_call,?wildcard from?tvm.script?import?relax?as?R #?定義了一個(gè)名為rewrite_attention的函數(shù),接收一個(gè)參數(shù)f。 def?rewrite_attention(f): ????#?使用wildcard()創(chuàng)建了三個(gè)通配符,分別賦值給Q、K和V。 ????Q?=?wildcard() ????K?=?wildcard() ????V?=?wildcard() ????#?使用is_op()函數(shù)創(chuàng)建了三個(gè)操作模式,分別對(duì)應(yīng)Q、K和V的維度重排操作,并將結(jié)果分別賦值給Q_BNSH、K_BNSH和V_BNSH。 ????Q_BNSH?=?is_op("relax.permute_dims")(Q) ????K_BNSH?=?is_op("relax.permute_dims")(K) ????V_BNSH?=?is_op("relax.permute_dims")(V) ????#?使用is_op()函數(shù)創(chuàng)建了一個(gè)操作模式,對(duì)應(yīng)K_BNSH的維度重排操作,并將結(jié)果賦值給K_BNSH_T。 ????K_BNSH_T?=?is_op("relax.permute_dims")(K_BNSH) ????#?使用is_op()函數(shù)創(chuàng)建了一系列操作模式,對(duì)應(yīng)矩陣乘法、除法、最大值、最小值、softmax以及另一個(gè)矩陣乘法操作。 ????#?這些操作模式(Attention)根據(jù)之前定義的通配符和常數(shù)匹配不同的計(jì)算圖節(jié)點(diǎn)。 ????matmul1?=?is_op("relax.matmul")(Q_BNSH,?K_BNSH_T) ????divide?=?is_op("relax.divide")(matmul1,?is_const()) ????max?=?is_op("relax.maximum")(divide,?is_const()) ????min?=?is_op("relax.minimum")(max,?wildcard()) ????softmax?=?is_op("relax.nn.softmax")(is_op("relax.astype")(min)) ????matmul2?=?is_op("relax.matmul")(is_op("relax.astype")(softmax),?V_BNSH) ????#?使用is_op()函數(shù)創(chuàng)建了一個(gè)操作模式,對(duì)應(yīng)matmul2的維度重排操作,并將結(jié)果賦值給pattern。 ????pattern?=?is_op("relax.permute_dims")(matmul2) ????#?定義了一個(gè)名為callback的回調(diào)函數(shù),接收兩個(gè)參數(shù)_和matchings。 ????#?該回調(diào)函數(shù)使用R.nn.attention函數(shù)構(gòu)建一個(gè)新的計(jì)算圖節(jié)點(diǎn),并使用matchings字典中的匹配結(jié)果來(lái)填充該節(jié)點(diǎn)的參數(shù)。 ????def?callback(_,?matchings): ????????return?R.nn.attention( ????????????matchings[Q],?matchings[K],?matchings[V],?causal_mask="BottomRight" ????????) ????#?使用rewrite_call函數(shù)將pattern、callback和輸入的計(jì)算圖f傳遞給它,以便在計(jì)算圖中應(yīng)用模式匹配和重寫(xiě)。 ????#?最后,將重寫(xiě)后的計(jì)算圖返回。 ????return?rewrite_call(pattern,?callback,?f)雖然沒(méi)有完全看懂這里的操作比如max和min的含義,但是從后面的callback_可以猜測(cè)出這里的Pass就是把打散的Self Attention模塊融合為一個(gè)relax.nn.attention操作。在cuda后端,如果支持了cutlass,那么relax.nn.attention操作就對(duì)應(yīng)了Flash Attention。
0x3.2 Transpose MatMul
代碼實(shí)現(xiàn)解析如下:
#?這段代碼定義了一個(gè)名為T(mén)ransposeMatmulCodeGenerator的類(lèi),該類(lèi)繼承自relax.PyExprMutator。 #?通過(guò)@relax.expr_functor.mutator裝飾器將該類(lèi)聲明為一個(gè)表達(dá)式重寫(xiě)器。 @relax.expr_functor.mutator class?TransposeMatmulCodeGenerator(relax.PyExprMutator): ????def?__init__(self,?mod): ????????super().__init__(mod) ????@staticmethod ????def?pattern(): ????????#?定義了靜態(tài)方法pattern(),該方法返回一個(gè)描述模式的元組。 ????????#?通過(guò)使用通配符(wildcard())和操作模式(is_op())來(lái)匹配計(jì)算圖中的特定模式。 ????????#?在這個(gè)例子中,模式匹配了一個(gè)矩陣乘法操作中矩陣w的維度重排操作,并將匹配的結(jié)果保存在字典annotations中。 ????????w?=?wildcard() ????????x?=?wildcard() ????????wT?=?is_op("relax.permute_dims")(w) ????????o?=?is_op("relax.matmul")(x,?wT) ????????annotations?=?{"o":?o,?"w":?w,?"x":?x,?"wT":?wT} ????????#?定義了內(nèi)部函數(shù)_check(),用于檢查模式匹配的結(jié)果是否滿(mǎn)足特定的條件。 ????????#?在這個(gè)例子中,檢查了維度重排操作的維度數(shù)和軸的順序是否正確。 ????????def?_check(context:?relax.transform.PatternCheckContext)?->?bool: ????????????transpose_call?=?context.annotated_expr["wT"] ????????????ndim?=?transpose_call.args[0].struct_info.ndim ????????????if?ndim?==?-1: ????????????????return?False ????????????if?ndim?==?2?and?transpose_call.attrs.axes?is?None: ????????????????return?True ????????????axes?=?list(range(ndim)) ????????????axes[-1],?axes[-2]?=?axes[-2],?axes[-1] ????????????return?list(transpose_call.attrs.axes)?==?axes ????????#?將匹配的計(jì)算圖節(jié)點(diǎn)、注解和檢查函數(shù)作為元組返回。 ????????return?o,?annotations,?_check ????#?重寫(xiě)了父類(lèi)的visit_call_()方法,用于處理特定類(lèi)型的計(jì)算圖節(jié)點(diǎn)。 ????def?visit_call_(self,?call:?relax.Call)?->?relax.Expr: ????????#?定義了一個(gè)變量out_dtype,用于保存輸出的數(shù)據(jù)類(lèi)型。 ????????out_dtype?=?None ????????#?定義了一個(gè)內(nèi)部函數(shù)te_transposed_matmul(),該函數(shù)實(shí)現(xiàn)了矩陣乘法的計(jì)算邏輯。 ????????def?te_transposed_matmul(a:?te.Tensor,?b:?te.Tensor)?->?te.Tensor: ????????????nonlocal?out_dtype ????????????#?將輸入張量?a?和?b?的形狀轉(zhuǎn)換為列表形式,分別保存在變量?a_shape?和?b_shape?中。 ????????????a_shape?=?list(a.shape) ????????????b_shape?=?list(b.shape) ????????????#?定義了兩個(gè)布爾變量?a_prepended?和?b_appended,用于標(biāo)記是否在相應(yīng)的形狀的前面或后面添加了維度。 ????????????a_prepended?=?False ????????????b_appended?=?False ????????????#?如果輸入張量?a?的形狀為一維,則在其前面添加一個(gè)維度,將其形狀修改為?(1,?original_shape)。 ????????????#?同樣地,如果輸入張量?b?的形狀為一維,則在其后面添加一個(gè)維度,將其形狀修改為?(original_shape,?1)。 ????????????if?len(a_shape)?==?1: ????????????????a_prepended?=?True ????????????????a_shape.insert(0,?1) ????????????if?len(b_shape)?==?1: ????????????????b_appended?=?True ????????????????b_shape.append(1) ????????????#?比較?a_shape?和?b_shape?的長(zhǎng)度,將結(jié)果保存在布爾變量?is_a_larger?中。 ????????????#?offset?表示兩個(gè)形狀長(zhǎng)度之差,用于后續(xù)處理。 ????????????is_a_larger?=?len(a_shape)?>?len(b_shape) ????????????offset?=?( ????????????????len(a_shape)?-?len(b_shape) ????????????????if?is_a_larger ????????????????else?len(b_shape)?-?len(a_shape) ????????????) ????????????#?創(chuàng)建兩個(gè)?relax.Var?對(duì)象?a_relax?和?bT_relax,用于表示張量?a?和轉(zhuǎn)置后的張量?bT?的結(jié)構(gòu)信息。 ????????????#?a_relax?的形狀和?a?的形狀相同,bT_relax?的形狀是?b?的形狀經(jīng)過(guò)維度互換后的結(jié)果。 ????????????a_relax?=?relax.Var("a",?relax.TensorStructInfo(a.shape)) ????????????bT_shape?=?list(b.shape) ????????????bT_shape[-1],?bT_shape[-2]?=?bT_shape[-2],?bT_shape[-1] ????????????bT_relax?=?relax.Var("b",?relax.TensorStructInfo(bT_shape)) ????????????#?使用?relax.op.matmul()?方法對(duì)?a_relax?和?bT_relax?進(jìn)行矩陣乘法運(yùn)算。 ????????????#?然后,通過(guò)?self.builder_.normalize()?方法對(duì)結(jié)果進(jìn)行歸一化處理,并獲取最終的輸出形狀。 ????????????output_shape?=?self.builder_.normalize( ????????????????relax.op.matmul(a_relax,?bT_relax) ????????????).struct_info.shape ????????????#?該函數(shù)接受可變數(shù)量的空間索引參數(shù)?idx_spatial, ????????????def?matmul_compute(*idx_spatial): ????????????????#?并定義了一個(gè)名為?k?的規(guī)約軸(reduce?axis),其范圍為?0?到?a_shape[-1]。 ????????????????k?=?te.reduce_axis((0,?a_shape[-1]),?name="k") ????????????????#?定義了一個(gè)名為?multiply_compute?的內(nèi)部函數(shù),用于計(jì)算乘法操作時(shí)的索引。 ????????????????def?multiply_compute(idx_reduce): ????????????????????a_indices?=?[] ????????????????????b_indices?=?[] ????????????????????#?根據(jù)?is_a_larger?的值,將?idx_spatial?中的索引分配給?a_indices?或?b_indices,用于處理形狀長(zhǎng)度差異的維度。 ????????????????????for?i?in?range(offset): ????????????????????????if?is_a_larger: ????????????????????????????a_indices.append(idx_spatial[i]) ????????????????????????else: ????????????????????????????b_indices.append(idx_spatial[i]) ????????????????????for?i?in?range( ????????????????????????offset,?len(output_shape)?-?(2?-?a_prepended?-?b_appended) ????????????????????): ????????????????????????#?根據(jù)維度的相等性,將適當(dāng)?shù)乃饕砑拥?a_indices?和?b_indices?中。 ????????????????????????#?如果維度不相等或無(wú)法確定是否相等,則將索引設(shè)為?0?或保持不變。 ????????????????????????a_dim?=?a_shape[i?if?is_a_larger?else?i?-?offset] ????????????????????????b_dim?=?b_shape[i?if?not?is_a_larger?else?i?-?offset] ????????????????????????dim_equal?=?a_dim?==?b_dim ????????????????????????if?not?isinstance(dim_equal,?tir.IntImm)?or?dim_equal?==?0: ????????????????????????????a_dim_is_one?=?isinstance(a_dim,?tir.IntImm)?and?a_dim?==?1 ????????????????????????????b_dim_is_one?=?isinstance(b_dim,?tir.IntImm)?and?b_dim?==?1 ????????????????????????????a_indices.append(0?if?a_dim_is_one?else?idx_spatial[i]) ????????????????????????????b_indices.append(0?if?b_dim_is_one?else?idx_spatial[i]) ????????????????????????else: ????????????????????????????a_indices.append(idx_spatial[i]) ????????????????????????????b_indices.append(idx_spatial[i]) ????????????????????#?在乘法操作的索引中添加規(guī)約軸?idx_reduce,并根據(jù)?a_prepended?和?b_appended?的值, ????????????????????#?將適當(dāng)?shù)乃饕砑拥?a_indices?和?b_indices?中。 ????????????????????if?not?a_prepended: ????????????????????????a_indices.append(idx_spatial[-2?+?b_appended]) ????????????????????a_indices.append(idx_reduce) ????????????????????if?not?b_appended: ????????????????????????b_indices.append(idx_spatial[-1]) ????????????????????b_indices.append(idx_reduce) ????????????????????#?根據(jù)?out_dtype?的值,選擇是否進(jìn)行數(shù)據(jù)類(lèi)型轉(zhuǎn)換,并返回乘法操作的結(jié)果。 ????????????????????dtype?=?out_dtype ????????????????????if?dtype?!=?"": ????????????????????????return?a(*a_indices).astype(dtype)?*?b(*b_indices).astype(dtype) ????????????????????return?a(*a_indices)?*?b(*b_indices) ????????????????#?在縮減軸?k?上對(duì)?multiply_compute?的結(jié)果進(jìn)行求和操作。 ????????????????return?te.sum(multiply_compute(k),?axis=k) ????????????#?使用?te.compute()?函數(shù)計(jì)算最終的輸出,其中使用一個(gè)?lambda?函數(shù)將輸入索引傳遞給?matmul_compute?函數(shù), ????????????#?并將結(jié)果命名為?"NT_matmul"。整個(gè)計(jì)算過(guò)程將根據(jù)?output_shape?進(jìn)行執(zhí)行。 ????????????return?te.compute( ????????????????output_shape, ????????????????lambda?*idx:?matmul_compute(*idx),??#?pylint:?disable=unnecessary-lambda ????????????????name="NT_matmul", ????????????) ????????#?首先,檢查函數(shù)調(diào)用的操作符?call.op?是否是?relax.GlobalVar?類(lèi)型。如果是,獲取與該操作符對(duì)應(yīng)的函數(shù)對(duì)象, ????????#?并檢查函數(shù)的屬性中是否包含鍵?"Composite",且其值為?"transpose_matmul_fuse"。 ????????if?isinstance(call.op,?relax.GlobalVar): ????????????function?=?self.builder_.get()[call.op] ????????????if?( ????????????????"Composite"?in?function.attrs ????????????????and?function.attrs["Composite"]?==?"transpose_matmul_fuse" ????????????): ????????????????#?將函數(shù)的返回類(lèi)型?function.ret_struct_info.dtype?賦值給變量?out_dtype ????????????????out_dtype?=?function.ret_struct_info.dtype ????????????????#?然后調(diào)用?self.builder_.call_te()?方法,傳遞?te_transposed_matmul?函數(shù)作為參數(shù), ????????????????#?以及調(diào)用的參數(shù)?call.args[1]?和?call.args[0],并指定?primfunc_name_hint?為?"NT_matmul"。 ????????????????return?self.builder_.call_te( ????????????????????te_transposed_matmul, ????????????????????call.args[1], ????????????????????call.args[0], ????????????????????primfunc_name_hint="NT_matmul", ????????????????) ????????return?super().visit_call_(call) #?使用?@tvm.transform.module_pass?裝飾器定義了一個(gè)名為?FuseTransposeMatmul?的類(lèi), #?并指定了優(yōu)化級(jí)別?opt_level=0?和?pass?的名稱(chēng)為?"FuseTransposeMatmul"。 @tvm.transform.module_pass(opt_level=0,?name="FuseTransposeMatmul") class?FuseTransposeMatmul: ????#?定義了?transform_module?方法,接受一個(gè)名為?mod?的?IRModule?對(duì)象和 ????#?tvm.transform.PassContext?對(duì)象作為參數(shù),并返回一個(gè)?IRModule?對(duì)象。 ????def?transform_module( ????????self,?mod:?IRModule,?ctx:?tvm.transform.PassContext ????)?->?IRModule: ????????#?通過(guò)調(diào)用?relax.transform.FuseOpsByPattern?并傳遞一個(gè)包含單個(gè)模式元組的列表, ????????#?對(duì)模塊?mod?進(jìn)行融合的轉(zhuǎn)置矩陣乘法操作。 ????????mod?=?relax.transform.FuseOpsByPattern( ????????????[("transpose_matmul_fuse",?*TransposeMatmulCodeGenerator.pattern())] ????????)(mod) ????????#?創(chuàng)建一個(gè)名為?transpose_matmul_codegen?的?TransposeMatmulCodeGenerator?對(duì)象, ????????#?并對(duì)模塊中的每個(gè)函數(shù)進(jìn)行遍歷。如果函數(shù)是?relax.Function?類(lèi)型,則調(diào)用?transpose_matmul_codegen.visit_expr? ????????#?方法對(duì)函數(shù)進(jìn)行轉(zhuǎn)置矩陣乘法代碼生成,并通過(guò)?transpose_matmul_codegen.builder_.update_func?方法更新函數(shù)。 ????????transpose_matmul_codegen?=?TransposeMatmulCodeGenerator(mod) ????????for?gv?in?mod.functions: ????????????func?=?mod[gv] ????????????if?not?isinstance(func,?relax.Function): ????????????????continue ????????????func?=?transpose_matmul_codegen.visit_expr(func) ????????????transpose_matmul_codegen.builder_.update_func(gv,?func) ????????#?返回轉(zhuǎn)置矩陣乘法代碼生成器的?builder?對(duì)象中的模塊。 ????????return?transpose_matmul_codegen.builder_.get()?
?
這個(gè)Pass將Transpose算子和一個(gè)MatMul算子替換為一個(gè)TE表達(dá)式的實(shí)現(xiàn)來(lái)達(dá)到融合算子的目的。
除了上面2種Pass,MLC-LLM還有不少的圖變換Pass,這篇文章就不一一去解析了,大多數(shù)優(yōu)化的目的都是匹配某種Pattern然后用更優(yōu)秀的算子去完成計(jì)算。
量化策略這一塊就不在這篇文章解析了。
0x4. MLC-LLM優(yōu)缺點(diǎn)個(gè)人評(píng)價(jià)和期待
0x4.1 優(yōu)點(diǎn)
Tune Free。mlc-llm不需要用TVM的AutoTVM/Ansor等等程序去執(zhí)行算子搜索過(guò)程,對(duì)跨平臺(tái)部署是比原始的TVM搭建的模型更清真的。
TIR的語(yǔ)法很大程度靠近了PyTorch的API,使得用戶(hù)在模型搭建部分不會(huì)很困難。
文檔寫(xiě)得不錯(cuò),跟隨教程基本可以完成大多數(shù)平臺(tái)的模型部署,并且單Batch下的吞吐和延遲表現(xiàn)都是不錯(cuò)的。
0x4.2 缺點(diǎn)
不支持從onnx或者h(yuǎn)uggingface模型直接轉(zhuǎn)換出TIR,手工實(shí)現(xiàn)模型的時(shí)候需要相當(dāng)多的先驗(yàn)知識(shí),比如在上面的RWKV模型中如果有自定義的cuda kernel,那么這個(gè)模型的實(shí)現(xiàn)可能只能全權(quán)委托給mlc-ai社區(qū)的核心開(kāi)發(fā)人員了。
KV Cache開(kāi)的是max_sequence_length這么長(zhǎng),顯然會(huì)有顯存的浪費(fèi),Serving的時(shí)候極限情況下可以服務(wù)的用戶(hù)數(shù)量應(yīng)該比VLLM/TGI等要小?
CUDA后端Decoding的Attention我看起來(lái)好像還是會(huì)用Flash Attention?也許是我看錯(cuò)了,這條暫時(shí)存疑。
在RWKV模型實(shí)現(xiàn)里,看到Batch維度寫(xiě)死為1了,應(yīng)該不支持動(dòng)態(tài)Batch?這樣對(duì)于啟真實(shí)服務(wù)來(lái)說(shuō)會(huì)有一些限制。
0x4.3 期待
如果短期內(nèi)能讓一個(gè)對(duì)TVM只有輕度依賴(lài)的社區(qū)開(kāi)發(fā)者新增一個(gè)新的模型。
如果模型存在自定義CUDA Kernel,需要一個(gè)詳細(xì)的教程來(lái)指引。
模型逐層打印來(lái)debug精度缺一個(gè)教程。
Paged Attention類(lèi)似策略的引入。
動(dòng)態(tài)Batch的支持。
暫時(shí)就想到這些,歡迎斧正。
編輯:黃飛
?
評(píng)論