0x1. OpenAI Triton介紹閱讀
這里來看官方的介紹:https://openai.com/research/triton ,從官方的介紹中我們可以看到OpenAI Triton的產生動機以及它的目標是什么,還可以看到一些經典算法的實現例子展示。
這里的標題是 Introducing Triton: Open-source GPU programming for neural networks ,翻譯就是《介紹 Triton:用于神經網絡的開源 GPU 編程語言》。然后下面的一句話翻譯過來是:我們發布了 Triton 1.0,這是一種開源的類 Python 編程語言,它使得沒有 CUDA 經驗的研究人員能夠編寫高效的 GPU 代碼——大多數情況下,其效能與專家所能編寫的代碼相當。這里指出了triton的目的,就是讓編寫cuda kernrl變得更簡單。接下來就逐步看一下介紹里的具體內容,為了更加準確這里會截圖對應的原文然后放上我的翻譯或者理解。
這里的意思是Triton可以使得用戶用較少的努力就寫出一個達到硬件峰值性能的kernel,比如使用 Triton 可以編寫 FP16 矩陣乘法的核函數,其性能能夠匹配 cuBLAS,并且這個代碼不超過25行。然后研究者已經用Triton開發了一些高效的實現,和功能相同的Torch實現相比,性能可以達到兩倍提升。后面一段就是強調了使用CUDA來把一些原始的PyTorch實現寫一個算子一般會更加高效,但是這個難度不小,并且目前已有工作也不能很好覆蓋這種情況,所以OpenAI Triton誕生。
這里講的是GPU編程的挑戰,現代 GPU 的架構大致可以分為三個主要部分——DRAM、SRAM 和 ALU。在優化 CUDA 代碼時,必須考慮到這些組件:
從 DRAM 的內存傳輸必須合并成大型事務,以利用現代內存接口的大總線寬度(內存合并訪問)。
數據必須在重復使用前手動存儲到 SRAM 中,并進行管理來最小化bank conflict。
計算必須仔細地進行劃分和調度,不僅是在流式多處理器(SMs)之間,還包括在其內部,以促進指令/線程級并行性,并利用專用的 ALU(例如,Tensor Cores)。
考慮所有這些因素可能對于擁有多年經驗的資深 CUDA 程序員來說都是一個挑戰。Triton 的目的是完全自動化這些優化,以便開發者能夠更好地專注于他們并行代碼的高層邏輯。Triton 旨在廣泛適用,因此不會自動在流式多處理器(SMs)之間調度工作——留下一些重要的算法考慮(例如,tiling,跨 SM 同步)由開發者自行決定。
然后給了一個表格展示cuda的編譯器和triton的區別。
在所有可用的領域特定語言和即時編譯器中,Triton可能和Numba最相似:kernel被定義為一個裝飾過的函數,并以不同的 program_id 并行啟動在所謂的網格實例上。然而,正如下面的代碼片段所示,相似之處僅此而已:Triton 通過對塊上的操作來暴露實例內部的并行性——這些小數組的尺寸是二的冪次方——而不是單指令多線程(SIMT)執行模型。這樣做,Triton 有效地抽象出了所有與 CUDA 線程塊內部并發相關的問題(例如,內存合并、共享內存同步/沖突、Tensor Cores調度)。
注意,Triton 的即時編譯器將 X 和 Y 視為指針而不是張量;我們認為保留對內存訪問的低級控制對于處理更復雜的數據結構(例如,塊稀疏張量)是重要的。重要的是,這種特定的 softmax 實現在整個標準化過程中將 X 的行保留在 SRAM 中,這在適用時最大化了數據重用(約 <32K 列)。這與 PyTorch 的內部 CUDA 代碼不同,后者使用臨時內存使其更具通用性,但顯著更慢(如下所示)。這里的關鍵不是 Triton 本質上更好,而是它簡化了專用kernel的開發,這些內核可能比在通用庫中找到的內核快得多。
Torch(v1.9)JIT編譯器的較低性能凸顯了從高級張量操作序列自動生成 CUDA 代碼的難度。
這里是說Triton大概只需要25行Python代碼就可以實現一個接近峰值的矩陣乘法。(后面有專門的一大節講這個代碼的原理)代碼如下:
@triton.jit defmatmul(A,B,C,M,N,K,stride_am,stride_ak, stride_bk,stride_bn,stride_cm,stride_cn, **META): #extractmetaparameters BLOCK_M,GROUP_M=META['BLOCK_M'],META['GROUP_M'] BLOCK_N=META['BLOCK_N'] BLOCK_K=META['BLOCK_K'] #programsaregroupedtogethertoimproveL2hitrate _pid_m=tl.program_id(0) _pid_n=tl.program_id(1) pid_m=_pid_m//GROUP_M pid_n=(_pid_n*GROUP_M)+(_pid_m%GROUP_M) #rm(resp.rn)denotesarangeofindices #forrows(resp.col)ofC rm=pid_m*BLOCK_M+tl.arange(0,BLOCK_M) rn=pid_n*BLOCK_N+tl.arange(0,BLOCK_N) #rkdenotesarangeofindicesforcolumns #(resp.rows)ofA(resp.B) rk=tl.arange(0,BLOCK_K) #thememoryaddressesofelementsinthefirstblockof #AandBcanbecomputedusingnumpy-stylebroadcasting A=A+(rm[:,None]*stride_am+rk[None,:]*stride_ak) B=B+(rk[:,None]*stride_bk+rn[None,:]*stride_bn) #initializeanditerativelyupdateaccumulator acc=tl.zeros((BLOCK_M,BLOCK_N),dtype=tl.float32) forkinrange(K,0,-BLOCK_K): a=tl.load(A) b=tl.load(B) #blocklevelmatrixmultiplication acc+=tl.dot(a,b) #incrementpointerssothatthenextblocksofAandB #areloadedduringthenextiteration A+=BLOCK_K*stride_ak B+=BLOCK_K*stride_bk #fuseleakyReLUifdesired #acc=tl.where(acc>=0,acc,alpha*acc) #writebackresult C=C+(rm[:,None]*stride_cm+rn[None,:]*stride_cn) mask=(rm[:,None]
手寫矩陣乘法kernel的一個重要優勢是,它們可以根據需要定制,以適應輸入(例如,切片)和輸出(例如,LeakyReLU)的融合轉換。如果沒有像 Triton 這樣的系統,沒有出色的 GPU 編程專長的開發者將無法進行矩陣乘法內核的定制修改。
這里是說Triton 的良好性能源于一個以 Triton-IR 為中心的模塊化系統架構,Triton-IR 是一個基于 LLVM 的中間表示,在這個系統中,多維值塊(這個是MLIR的概念)是一等公民。GPT@triton.jit 裝飾器的工作原理是遍歷提供的 Python 函數的抽象語法樹(AST),以便使用常見的 SSA 構建算法即時生成 Triton-IR。然后,編譯器后端會簡化、優化并自動并行化所產生的 IR 代碼,再將其轉換為高質量的 LLVM-IR —— 最終生成 PTX —— 以在近期的 NVIDIA GPU 上執行。目前不支持 CPU 和 AMD GPU,但我們歡迎社區貢獻,旨在解決這一限制。
我們發現,通過 Triton-IR 使用塊級別程序表示,使我們的編譯器能夠自動執行各種重要的程序優化。例如,可以通過觀察計算密集型塊級操作(例如,tl.dot)的操作數,自動將數據暫存到共享內存中,并使用標準的活性分析技術進行分配和同步。另一方面,如下所示,Triton 程序可以高效且自動地并行化,既可以(1)通過并發執行不同的kernel實例在流式多處理器(SMs)間并行,也可以(2)通過分析每個塊級操作的迭代空間,并在不同的 SIMD 單元間適當分配,從而在 SMs 內部并行。
0x2. 教程1 Vector Addition閱讀
意思是這一節教程會介紹Triton編程模型定義kernel的基本寫法,此外也會介紹一下怎么實現一個良好的benchmark測試。下面來看計算kernel實現,我把注釋改成中文了:
importtorch importtriton importtriton.languageastl @triton.jit defadd_kernel(x_ptr,#*指針*,指向第一個輸入向量。 y_ptr,#*指針*,指向第二個輸入向量。 output_ptr,#*指針*,指向輸出向量。 n_elements,#向量的大小。 BLOCK_SIZE:tl.constexpr,#每個程序應處理的元素數量。 #注意:`constexpr`這樣可以被用作形狀值。 ): #這里有多個“程序”處理不同的數據。我們在這里識別我們是哪一個程序: pid=tl.program_id(axis=0)#我們使用一維啟動網格,所以軸是0。 #該程序將處理從初始數據偏移的輸入。 #例如,如果你有一個長度為256的向量和塊大小為64,那么程序 #將分別訪問元素[0:64,64:128,128:192,192:256]。 #注意偏移量是一個指針列表: block_start=pid*BLOCK_SIZE offsets=block_start+tl.arange(0,BLOCK_SIZE) #創建一個掩碼以防止內存操作越界訪問。 mask=offsets
這里還聲明了一個輔助函數來(1)分配z張量,(2)使用適當的網格/塊大小排隊上面的kernel:
defadd(x:torch.Tensor,y:torch.Tensor): #我們需要預分配輸出。 output=torch.empty_like(x) assertx.is_cudaandy.is_cudaandoutput.is_cuda n_elements=output.numel() #SPMD啟動網格表示并行運行的kernel實例的數量。 #它類似于CUDA啟動網格。它可以是Tuple[int],也可以是Callable(metaparameters)->Tuple[int]。 #在這種情況下,我們使用一個1D網格,其大小是塊的數量: grid=lambdameta:(triton.cdiv(n_elements,meta['BLOCK_SIZE']),) #注意: #-每個torch.tensor對象都隱式地轉換為指向其第一個元素的指針。 #-使用`triton.jit`裝飾的函數可以用一個啟動網格索引來獲得可調用的GPU內核。 #-不要忘記將元參數作為關鍵字參數傳遞。 add_kernel[grid](x,y,output,n_elements,BLOCK_SIZE=1024) #我們返回一個指向z的句柄,但是因為`torch.cuda.synchronize()`還沒有被調用,所以這時kernel仍然 #在異步運行。 returnoutput
我們現在可以使用上面定義的函數來計算兩個torch.tensor對象的逐元素求和,并測試其正確性:
torch.manual_seed(0) size=98432 x=torch.rand(size,device='cuda') y=torch.rand(size,device='cuda') output_torch=x+y output_triton=add(x,y) print(output_torch) print(output_triton) print(f'Themaximumdifferencebetweentorchandtritonis' f'{torch.max(torch.abs(output_torch-output_triton))}')
輸出:
tensor([1.3713,1.3076,0.4940,...,0.6724,1.2141,0.9733],device='cuda:0') tensor([1.3713,1.3076,0.4940,...,0.6724,1.2141,0.9733],device='cuda:0') Themaximumdifferencebetweentorchandtritonis0.0
我們可以對不同大小的向量進行自定義操作的性能基準測試,以了解它相對于PyTorch的表現如何。為了簡化操作,Triton提供了一系列內置工具,使我們能夠簡潔地繪制出自定義操作在不同問題規模下的性能圖表。
@triton.testing.perf_report( triton.testing.Benchmark( x_names=['size'],#用作繪圖x軸的參數名。 x_vals=[2**iforiinrange(12,28,1)],#`x_name`的不同可能值。 x_log=True,#x軸是對數的。 line_arg='provider',#其值對應于圖中不同線條的參數名。 line_vals=['triton','torch'],#`line_arg`的可能值。 line_names=['Triton','Torch'],#線條的標簽名稱。 styles=[('blue','-'),('green','-')],#線條樣式。 ylabel='GB/s',#y軸的標簽名稱。 plot_name='vector-add-performance',#繪圖的名稱。也用作保存繪圖的文件名。 args={},#不在`x_names`和`y_name`中的函數參數的值。 )) defbenchmark(size,provider): x=torch.rand(size,device='cuda',dtype=torch.float32) y=torch.rand(size,device='cuda',dtype=torch.float32) quantiles=[0.5,0.2,0.8] ifprovider=='torch': ms,min_ms,max_ms=triton.testing.do_bench(lambda:x+y,quantiles=quantiles) ifprovider=='triton': ms,min_ms,max_ms=triton.testing.do_bench(lambda:add(x,y),quantiles=quantiles) gbps=lambdams:12*size/ms*1e-6 returngbps(ms),gbps(max_ms),gbps(min_ms)
gbps = lambda ms: 12 * size / ms * 1e-6這里的12表示的是數據讀寫的bit,因為有x和y以及z的存在,所以是3*4=12bit。現在可以運行上面的裝飾函數了。傳遞 print_data=True 參數來查看性能數據,傳遞 show_plots=True 參數來繪制圖表,和/或傳遞 save_path='/path/to/results/' 參數來將它們連同原始CSV數據一起保存到磁盤上:
benchmark.run(print_data=True,show_plots=True)
可以看到,對于elementwise任務,Triton的性能幾乎和PyTorch持平,但是Triton寫起來很簡單。0x3. 教程2 Fused Softmax閱讀
在這個教程中,我們將編寫一個融合的softmax操作,這個操作對于特定類型的矩陣來說比PyTorch的原生操作要快得多:那些行的大小可以放入GPU的SRAM中的矩陣。
通過這樣做,我們將學習到:
kernel融合對于帶寬受限操作的好處。
Triton中的reduce操作符。
動機
自定義GPU kernel用于逐元素加法在教育上是有價值的,但在實際應用中可能作用有限。讓我們考慮一個簡單的(數值穩定的)softmax操作的情況:
importtorch importtriton importtriton.languageastl @torch.jit.script defnaive_softmax(x): """使用原生pytorch計算X的逐行softmax 我們減去最大元素是為了避免溢出。Softmax對這種偏移是不變的。 """ #讀取MN個元素;寫入M個元素 x_max=x.max(dim=1)[0] #讀取MN+M個元素;寫入MN個元素 z=x-x_max[:,None] #讀取MN個元素;寫入MN個元素 numerator=torch.exp(z) #讀取MN個元素;寫入M個元素 denominator=numerator.sum(dim=1) #讀取MN+M個元素;寫入MN個元素 ret=numerator/denominator[:,None] #總計:讀取5MN+2M個元素;寫入3MN+2M個元素 returnret
計算kernel
我們的softmax kernel的工作方式如下:每個程序加載輸入矩陣X的一行,對其進行歸一化處理,然后將結果寫回到輸出Y中。需要注意的是,Triton的一個重要限制是每個塊必須包含2的冪次方個元素,因此如果我們想處理任何可能的輸入形狀,我們需要在內部對每行進行“pad”以及對內存訪問操作進行保護(也就是防止越界):
@triton.jit defsoftmax_kernel(output_ptr,input_ptr,input_row_stride,output_row_stride,n_cols,BLOCK_SIZE:tl.constexpr): #softmax的各行是獨立的,所以我們在這些行上進行并行處理 row_idx=tl.program_id(0) #步長代表我們需要增加多少指針來前進1行 row_start_ptr=input_ptr+row_idx*input_row_stride #塊大小是大于n_cols的下一個2的冪次,因此我們可以將每一行放入單個塊中 col_offsets=tl.arange(0,BLOCK_SIZE) input_ptrs=row_start_ptr+col_offsets #將行加載到SRAM中,使用掩碼因為BLOCK_SIZE可能大于n_cols row=tl.load(input_ptrs,mask=col_offsets
解析來創建一個輔助函數,該函數為任何給定的輸入張量排隊執行kernel并且設置了啟動參數。
defsoftmax(x): n_rows,n_cols=x.shape #塊大小是大于`x`中列數的最小2的冪 BLOCK_SIZE=triton.next_power_of_2(n_cols) #我們可以使用的另一個技巧是要求編譯器通過增加每行分布的warp數(`num_warps`)來使用更多的線程。 #在下一個教程中,你將看到如何以更自然的方式自動調整這個值,這樣你就不必自己想出手動啟發式方法。 num_warps=4 ifBLOCK_SIZE>=2048: num_warps=8 ifBLOCK_SIZE>=4096: num_warps=16 #分配輸出 y=torch.empty_like(x) #排隊執行內核。一維啟動網格很簡單:我們有每行一個內核實例 #輸入矩陣 softmax_kernel[(n_rows,)]( y, x, x.stride(0), y.stride(0), n_cols, num_warps=num_warps, BLOCK_SIZE=BLOCK_SIZE, ) returny
這里是驗證Triton實現的fuse softmax和PyTorch的naive實現等價,顯然他們是等價的。BenchMark
這里設定矩陣的行數為固定的4096來做benchmark。
@triton.testing.perf_report( triton.testing.Benchmark( x_names=['N'],#用作繪圖x軸的參數名 x_vals=[128*iforiinrange(2,100)],#`x_name`的不同可能值 line_arg='provider',#其值對應于圖中不同線條的參數名 line_vals=[ 'triton', 'torch-native', 'torch-jit', ],#`line_arg`的可能值 line_names=[ "Triton", "Torch(原生)", "Torch(jit)", ],#線條的標簽名稱 styles=[('blue','-'),('green','-'),('green','--')],#線條樣式 ylabel="GB/s",#y軸的標簽名稱 plot_name="softmax-performance",#繪圖的名稱。也用作保存繪圖的文件名。 args={'M':4096},#不在`x_names`和`y_name`中的函數參數的值 )) defbenchmark(M,N,provider): x=torch.randn(M,N,device='cuda',dtype=torch.float32) quantiles=[0.5,0.2,0.8] ifprovider=='torch-native': ms,min_ms,max_ms=triton.testing.do_bench(lambda:torch.softmax(x,axis=-1),quantiles=quantiles) ifprovider=='triton': ms,min_ms,max_ms=triton.testing.do_bench(lambda:softmax(x),quantiles=quantiles) ifprovider=='torch-jit': ms,min_ms,max_ms=triton.testing.do_bench(lambda:naive_softmax(x),quantiles=quantiles) gbps=lambdams:2*x.nelement()*x.element_size()*1e-9/(ms*1e-3) returngbps(ms),gbps(max_ms),gbps(min_ms) benchmark.run(show_plots=True,print_data=True)
這里提到雖然Triton實現的softmax性能更好并且易于理解和維護,但PyTorch的torch.softmax則更加通用。0x4. 教程3 Matrix Multiply閱讀
首先教程指出這里就是要寫一個Block級別的矩陣乘法,然后這里會涉及到多維度的指針操作,程序重排以更好的命中l2 cache以及自動調優。動機
矩陣乘法是大多數現代高性能計算系統的關鍵構建塊。它們眾所周知難以優化,因此它們的實現通常由硬件供應商自己作為所謂的“內核庫”(例如,cuBLAS)的一部分來完成。不幸的是,這些庫通常是專有的,無法輕易地定制以適應現代深度學習工作負載的需求(例如,融合激活函數)。在這個教程中,你將學習如何使用Triton自己實現高效的矩陣乘法,這種方法易于定制和擴展。
大致來說,我們將要編寫的內核將實現以下塊級算法來乘以一個 (M, K) 矩陣和一個 (K, N) 矩陣:
#Doinparallel forminrange(0,M,BLOCK_SIZE_M): #Doinparallel forninrange(0,N,BLOCK_SIZE_N): acc=zeros((BLOCK_SIZE_M,BLOCK_SIZE_N),dtype=float32) forkinrange(0,K,BLOCK_SIZE_K): a=A[m:m+BLOCK_SIZE_M,k:k+BLOCK_SIZE_K] b=B[k:k+BLOCK_SIZE_K,n:n+BLOCK_SIZE_N] acc+=dot(a,b) C[m:m+BLOCK_SIZE_M,n:n+BLOCK_SIZE_N]=acc
其中,雙重嵌套的for循環的每次迭代都由一個專用的Triton program實例執行。
計算kernel
上述算法實際上在Triton中相當容易實現。主要的難點來自于在內循環中計算必須讀取A和B塊的內存位置。為此,我們需要多維指針運算。
指針運算
對于一個2D Tensor X,X[i, j]的內存位置為&X[i, j] = X + i*stride_xi + j*stride_xj。因此,對于A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K]和B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]的塊指針可以用下面的偽代碼定義:
&A[m:m+BLOCK_SIZE_M,k:k+BLOCK_SIZE_K]=a_ptr+(m:m+BLOCK_SIZE_M)[:,None]*A.stride(0)+(k:k+BLOCK_SIZE_K)[None,:]*A.stride(1); &B[k:k+BLOCK_SIZE_K,n:n+BLOCK_SIZE_N]=b_ptr+(k:k+BLOCK_SIZE_K)[:,None]*B.stride(0)+(n:n+BLOCK_SIZE_N)[None,:]*B.stride(1);
這意味著A和B塊的指針可以在Triton中初始化,比如 k=0 如下代碼所示。另外注意,我們需要一個額外的模運算來處理M不是BLOCK_SIZE_M的倍數或N不是BLOCK_SIZE_N的倍數的情況,在這種情況下,我們可以用一些無用的值填充數據,這些值不會對結果產生影響。對于K維度,我們稍后將使用掩碼加載語義來處理。
offs_am=(pid_m*BLOCK_SIZE_M+tl.arange(0,BLOCK_SIZE_M))%M offs_bn=(pid_n*BLOCK_SIZE_N+tl.arange(0,BLOCK_SIZE_N))%N offs_k=tl.arange(0,BLOCK_SIZE_K) a_ptrs=a_ptr+(offs_am[:,None]*stride_am+offs_k[None,:]*stride_ak) b_ptrs=b_ptr+(offs_k[:,None]*stride_bk+offs_bn[None,:]*stride_bn)
然后在內循環中按如下方式更新:
a_ptrs+=BLOCK_SIZE_K*stride_ak; b_ptrs+=BLOCK_SIZE_K*stride_bk;
如上所述,每個program實例計算一個 [BLOCK_SIZE_M, BLOCK_SIZE_N] 大小的C矩陣塊。重要的是要記住,這些塊的計算順序是很重要的,因為它會影響我們程序的L2緩存命中率,不幸的是,一個簡單的行優先順序是不夠的。
pid=triton.program_id(0); grid_m=(M+BLOCK_SIZE_M-1)//BLOCK_SIZE_M; grid_n=(N+BLOCK_SIZE_N-1)//BLOCK_SIZE_N; pid_m=pid/grid_n; pid_n=pid%grid_n;
L2 Cache優化
如上所述,每個程序實例計算一個 [BLOCK_SIZE_M, BLOCK_SIZE_N] 大小的C矩陣塊。重要的是要記住,這些塊的計算順序很重要,因為它會影響我們程序的L2緩存命中率,不幸的是,一個簡單的行主序排序是不夠的。
一個可能的解決方案是以一種促進數據重用的順序啟動塊。這可以通過在切換到下一列之前將塊在GROUP_M行的super group中分組來實現:
#程序ID pid=tl.program_id(axis=0) #沿M軸的程序ID數量 num_pid_m=tl.cdiv(M,BLOCK_SIZE_M) #沿N軸的程序ID數量 num_pid_n=tl.cdiv(N,BLOCK_SIZE_N) #組中的程序數量 num_pid_in_group=GROUP_SIZE_M*num_pid_n #該程序所在組的ID group_id=pid//num_pid_in_group #組中第一個程序的行ID first_pid_m=group_id*GROUP_SIZE_M #如果`num_pid_m`不能被`GROUP_SIZE_M`整除,最后一個組更小 group_size_m=min(num_pid_m-first_pid_m,GROUP_SIZE_M) #*在組內*,程序按列主序排列 #程序在*啟動網格*中的行ID pid_m=first_pid_m+(pid%group_size_m) #程序在*啟動網格*中的列ID pid_n=(pid%num_pid_in_group)//group_size_m
例如,在下面的矩陣乘法中,每個矩陣由9個塊乘以9個塊組成,我們可以看到,如果我們按行主序計算輸出,我們需要將90個塊加載到SRAM中以計算前9個輸出塊,但如果我們按grouped ordering進行計算,我們只需要加載54個塊。
在實際應用中,這可以在某些硬件架構上提高我們矩陣乘法內核的性能超過10%(例如,在A100上從220提升到245 TFLOPS)。
L2 Cache優化原理補充講解
上面的group oredering的訪問代碼比較難理解,這里來更詳細的解析一下。
#程序ID pid=tl.program_id(axis=0) #沿M軸的程序ID數量 num_pid_m=tl.cdiv(M,BLOCK_SIZE_M) #沿N軸的程序ID數量 num_pid_n=tl.cdiv(N,BLOCK_SIZE_N)
這里的num_pid_m和num_pid_n就是求分別要在M和N方向循環多少次。
然后上面圖中的黑色數字其實就可以理解為program id,我們可以看到program id增加的方向其實就代表了遍歷的ordering,對于row major來說就是在行方向上順序遍歷,而對于group ordering來說就是按照一個BLOCK_SIZE_M*BLOCK_SIZE_N這么大的一個小組來遍歷。其實這段代碼就是完成group ordering的遍歷:
num_pid_in_group=GROUP_SIZE_M*num_pid_n group_id=pid//num_pid_in_group first_pid_m=group_id*GROUP_SIZE_M group_size_m=min(num_pid_m-first_pid_m,GROUP_SIZE_M) pid_m=first_pid_m+(pid%group_size_m) pid_n=(pid%num_pid_in_group)//group_size_m
以上面圖來看,num_pid_m=3,num_pid_n=3,num_pid_in_group=group_id * GROUP_SIZE_M=9*3=27,也就是下面的紅色框里面的program個數,從名字也可以看出來這個紅色框劃分的區域也是一個group。
group_id 就表示當前的這次 "循環", 是在第幾個紅色框里,以program 0為例,這里為group_id = pid // num_pid_in_group=0//27=0。而first_pid_m 代表當前 group 中的第一個黃色program在全局的M維度上是第幾個program ,這里為first_pid_m = group_id * GROUP_SIZE_M=0,group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)這里是考慮到最后一個group可能占不滿數據(存在padding),所以就做一個截斷處理。
pid_m=first_pid_m+(pid%group_size_m) pid_n=(pid%num_pid_in_group)//group_size_m
這兩行代碼計算當前的program處理的黃色小塊坐標([pid_m, pid_n]),pid_m這行是在行方向上移動,pid_n這行則是保證在上面的紅色框里面一定是一列一列來訪問的。
作為對比,在Row-major的方法中,訪問方式應該是這樣的:
pid_m=pid//num_pid_n pid_n=pid%num_pid_n
計算最后的結果
有了上面的鋪墊,我們就可以計算最終的結果了,下面的代碼展示了完整的Triton 矩陣乘法kernel實現。
#使用`triton.jit`裝飾的函數可以通過`triton.autotune`裝飾器進行自動調優,該裝飾器包括: #-一系列定義不同配置的`triton.Config`對象, #這些配置涉及元參數(例如`BLOCK_SIZE_M`)和編譯選項(例如`num_warps`)的不同設置 #-一個自動調優*關鍵字*,其值的變化將觸發對所有 #提供的配置的評估 @triton.autotune( configs=[ #每個Config定義了一組特定的配置參數和編譯選項 triton.Config({'BLOCK_SIZE_M':128,'BLOCK_SIZE_N':256,'BLOCK_SIZE_K':64,'GROUP_SIZE_M':8},num_stages=3, num_warps=8), triton.Config({'BLOCK_SIZE_M':64,'BLOCK_SIZE_N':256,'BLOCK_SIZE_K':32,'GROUP_SIZE_M':8},num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M':128,'BLOCK_SIZE_N':128,'BLOCK_SIZE_K':32,'GROUP_SIZE_M':8},num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M':128,'BLOCK_SIZE_N':64,'BLOCK_SIZE_K':32,'GROUP_SIZE_M':8},num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M':64,'BLOCK_SIZE_N':128,'BLOCK_SIZE_K':32,'GROUP_SIZE_M':8},num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M':128,'BLOCK_SIZE_N':32,'BLOCK_SIZE_K':32,'GROUP_SIZE_M':8},num_stages=4, num_warps=4), triton.Config({'BLOCK_SIZE_M':64,'BLOCK_SIZE_N':32,'BLOCK_SIZE_K':32,'GROUP_SIZE_M':8},num_stages=5, num_warps=2), triton.Config({'BLOCK_SIZE_M':32,'BLOCK_SIZE_N':64,'BLOCK_SIZE_K':32,'GROUP_SIZE_M':8},num_stages=5, num_warps=2), ], key=['M','N','K'],#自動調優關鍵字 ) @triton.jit defmatmul_kernel( #指向矩陣的指針 a_ptr,b_ptr,c_ptr, #矩陣維度 M,N,K, #步長變量表示在特定維度上移動1個元素時指針增加的量。 #例如`stride_am`是將`a_ptr`增加多少以獲取下一行的元素(A有M行)。 stride_am,stride_ak,#A矩陣的步長 stride_bk,stride_bn,#B矩陣的步長 stride_cm,stride_cn,#C矩陣的步長 #元參數 BLOCK_SIZE_M:tl.constexpr,BLOCK_SIZE_N:tl.constexpr,BLOCK_SIZE_K:tl.constexpr,# GROUP_SIZE_M:tl.constexpr,# ACTIVATION:tl.constexpr#激活函數 ): """用于計算矩陣乘法C=AxB的內核。 A的形狀為(M,K),B的形狀為(K,N),C的形狀為(M,N)。 """ #----------------------------------------------------------- #將程序ID`pid`映射到它應該計算的C矩陣的塊。 #這是以groupedordering完成的,以促進L2數據重用。 #詳細解釋看一節 pid=tl.program_id(axis=0) num_pid_m=tl.cdiv(M,BLOCK_SIZE_M) num_pid_n=tl.cdiv(N,BLOCK_SIZE_N) num_pid_in_group=GROUP_SIZE_M*num_pid_n group_id=pid//num_pid_in_group first_pid_m=group_id*GROUP_SIZE_M group_size_m=min(num_pid_m-first_pid_m,GROUP_SIZE_M) pid_m=first_pid_m+(pid%group_size_m) pid_n=(pid%num_pid_in_group)//group_size_m #---------------------------------------------------------- #為A和B的第一個塊創建指針。 #我們將在K方向移動時推進這個指針并累加 #`a_ptrs`是[BLOCK_SIZE_M,BLOCK_SIZE_K]塊的指針 #`b_ptrs`是[BLOCK_SIZE_K,BLOCK_SIZE_N]塊的指針 #有關詳細信息,請參閱上方“指針算術”部分 offs_am=(pid_m*BLOCK_SIZE_M+tl.arange(0,BLOCK_SIZE_M))%M offs_bn=(pid_n*BLOCK_SIZE_N+tl.arange(0,BLOCK_SIZE_N))%N offs_k=tl.arange(0,BLOCK_SIZE_K) a_ptrs=a_ptr+(offs_am[:,None]*stride_am+offs_k[None,:]*stride_ak) b_ptrs=b_ptr+(offs_k[:,None]*stride_bk+offs_bn[None,:]*stride_bn) #----------------------------------------------------------- #迭代以計算C矩陣的一個塊。 #我們將累加到一個`[BLOCK_SIZE_M,BLOCK_SIZE_N]`塊 #的fp32值以獲得更高的精度。 #`accumulator`在循環后會轉換回fp16。 accumulator=tl.zeros((BLOCK_SIZE_M,BLOCK_SIZE_N),dtype=tl.float32) forkinrange(0,tl.cdiv(K,BLOCK_SIZE_K)): #LoadthenextblockofAandB,generateamaskbycheckingtheKdimension. #Ifitisoutofbounds,setitto0. a=tl.load(a_ptrs,mask=offs_k[None,:]=0,x,0.01*x)
我們現在可以創建一個方便的封裝函數,它只需要兩個輸入張量,并且會:(1)檢查任何形狀約束;(2)分配輸出;(3)啟動上述kernel。
defmatmul(a,b,activation=""): #Checkconstraints. asserta.shape[1]==b.shape[0],"Incompatibledimensions" asserta.is_contiguous(),"MatrixAmustbecontiguous" assertb.is_contiguous(),"MatrixBmustbecontiguous" M,K=a.shape K,N=b.shape #Allocatesoutput. c=torch.empty((M,N),device=a.device,dtype=a.dtype) #1Dlaunchkernelwhereeachblockgetsitsownprogram. grid=lambdaMETA:(triton.cdiv(M,META['BLOCK_SIZE_M'])*triton.cdiv(N,META['BLOCK_SIZE_N']),) matmul_kernel[grid]( a,b,c,# M,N,K,# a.stride(0),a.stride(1),# b.stride(0),b.stride(1),# c.stride(0),c.stride(1),# ACTIVATION=activation# ) returnc
計算過程的補充說明
上面的《L2 Cache優化原理補充講解》這一節明確了kernel的group ordering的訪問方式以及實現,現在來看對于當前的program實例具體是怎么計算的。現在以計算C中的第一個Block的(0, 0)為例子,它需要從A和B分別加載9個黃色的小塊數據相乘并累加最后得到C中的(0, 0)位置結果。如下圖所示:
下面的代碼先把program實例當前要處理A和B的第一個Block加載上來:
#---------------------------------------------------------- #為A和B的第一個塊創建指針。 #我們將在K方向移動時推進這個指針并累加 #`a_ptrs`是[BLOCK_SIZE_M,BLOCK_SIZE_K]塊的指針 #`b_ptrs`是[BLOCK_SIZE_K,BLOCK_SIZE_N]塊的指針 #有關詳細信息,請參閱上方“指針算術”部分 offs_am=(pid_m*BLOCK_SIZE_M+tl.arange(0,BLOCK_SIZE_M))%M offs_bn=(pid_n*BLOCK_SIZE_N+tl.arange(0,BLOCK_SIZE_N))%N offs_k=tl.arange(0,BLOCK_SIZE_K) a_ptrs=a_ptr+(offs_am[:,None]*stride_am+offs_k[None,:]*stride_ak) b_ptrs=b_ptr+(offs_k[:,None]*stride_bk+offs_bn[None,:]*stride_bn)
這里的a_ptr 是整個 A 矩陣第一個元素的地址,offs_am和offs_bn表示當前的program id在M維度和K維度的坐標,這個坐標是一個list,用tl.arange(0, BLOCK_SIZE_K)來獲取。
得到 M 維度 和 K 維度的坐標后, 就可以讓它們各自和 M 維度 和 K 維度的 stride 相乘, 然后和 a_ptr 相加, 就可以得到 A 矩陣 9 個 block 中第一個 block 中每個元素的地址了。 b_ptr也是同理。
最后一部分就是累加了,這里會在K維度上進行累加,每次計算輸出的一個塊。
#迭代以計算C矩陣的一個塊。 #我們將累加到一個`[BLOCK_SIZE_M,BLOCK_SIZE_N]`塊 #的fp32值以獲得更高的精度。 #`accumulator`在循環后會轉換回fp16。 accumulator=tl.zeros((BLOCK_SIZE_M,BLOCK_SIZE_N),dtype=tl.float32) forkinrange(0,tl.cdiv(K,BLOCK_SIZE_K)): #LoadthenextblockofAandB,generateamaskbycheckingtheKdimension. #Ifitisoutofbounds,setitto0. a=tl.load(a_ptrs,mask=offs_k[None,:]
這行代碼a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)考慮到 K 可能不能被 BLOCK_SIZE_K 整除, 到每一行最后一個 block 的時候, 實際大小是不足 BLOCK_SIZE_K 的,所以需要把超出的那部分元素mask掉。
最后這部分代碼是把當前的算子和LeakyReLU激活函數進行融合:
#當累加器仍然是FP32時,可以融合任意激活函數 ifACTIVATION=="leaky_relu": accumulator=leaky_relu(accumulator) c=accumulator.to(tl.float16)
單元測試
Benchmark
這里使用一個方陣來對比Triton實現的matmul kernel和cublas的matmul kernel的性能。
@triton.testing.perf_report( triton.testing.Benchmark( x_names=['M','N','K'],#用作圖表x軸的參數名 x_vals=[128*iforiinrange(2,33)],#`x_name`的不同可能值 line_arg='provider',#其值對應于圖表中不同線條的參數名 #`line_arg`的可能值 line_vals=['cublas','triton'], #線條的標簽名稱 line_names=["cuBLAS","Triton"], #線條樣式 styles=[('green','-'),('blue','-')], ylabel="TFLOPS",#y軸的標簽名稱 plot_name="matmul-performance",#圖表的名稱,也用作保存圖表的文件名。 args={},#其他參數 )) defbenchmark(M,N,K,provider): #初始化張量 a=torch.randn((M,K),device='cuda',dtype=torch.float16) b=torch.randn((K,N),device='cuda',dtype=torch.float16) quantiles=[0.5,0.2,0.8]#分位數 #如果提供者是cublas ifprovider=='cublas': ms,min_ms,max_ms=triton.testing.do_bench(lambda:torch.matmul(a,b),quantiles=quantiles) #如果提供者是triton ifprovider=='triton': ms,min_ms,max_ms=triton.testing.do_bench(lambda:matmul(a,b),quantiles=quantiles) #性能計算函數 perf=lambdams:2*M*N*K*1e-12/(ms*1e-3) returnperf(ms),perf(max_ms),perf(min_ms) #運行基準測試,展示圖表和打印數據 benchmark.run(show_plots=True,print_data=True)
可以看到基于Triton實現的矩陣乘kernel性能大體可以和高度優化的cuBlas持平。
審核編輯:劉清
-
sram
+關注
關注
6文章
769瀏覽量
114916 -
多處理器
+關注
關注
0文章
22瀏覽量
8988 -
Cache
+關注
關注
0文章
129瀏覽量
28452 -
python
+關注
關注
56文章
4811瀏覽量
85076 -
OpenAI
+關注
關注
9文章
1158瀏覽量
6748
原文標題:【BBuf的CUDA筆記】十三,OpenAI Triton 入門筆記一
文章出處:【微信號:GiantPandaCV,微信公眾號:GiantPandaCV】歡迎添加關注!文章轉載請注明出處。
發布評論請先 登錄
相關推薦
評論