CODA:新研究透過 GEMM-Epilogue 編程優化 Transformer 訓練

iconMetaEra
分享
Share IconShare IconShare IconShare IconShare IconShare IconCopy
AI summary icon精華摘要

expand icon
一篇名為 "CODA: 將 Transformer 模塊重寫為 GEMM-Epilogue 程式" 的新研究論文,提出了一種提升 Transformer 訓練效率的方法。此研究由麻省理工學院、普林斯頓大學、Together AI 和 Meta 共同完成,將記憶體密集型運算重構為 GEMM epilogues,以減少記憶體轉帳。這不僅能加快執行速度,還讓開發者和 LLM 能夠撰寫優化的 CUDA 核心。鏈上新聞強調了 AI 基礎設施在效能提升方面的日益關注,此方法可能影響未來與 AI 進展相關的新代幣上線發展。
本文介紹了一項名為 CODA: Rewriting Transformer Blocks as GEMM-Epilogue Programs 的新研究,核心目標是優化 Transformer 模型訓練的效率,特別是解決那些看似零散但累積起來耗時嚴重的「記憶體密集型」運算。

文章作者、來源:機器之心

5 月 22 日,Tri Dao 在社交媒體上轉發了 Han Guo 的一條推文。他還寫道:「經過一些數學重寫,發現 Transformer 的所有內容都是一系列 GEMM + epilogue(矩陣乘法加尾聲)。給定一些優化的原語,LLM(以及新手)就可以為所有 Transformer 操作編寫光速核心!」

Tri Dao 是 FlashAttention 系列的核心作者之一,而這條推文指向了他們當天發布的一篇論文:CODA。

  • 論文標題:CODA:將 Transformer 模塊重寫為 GEMM-Epilogue 程序
  • 論文地址:https://arxiv.org/abs/2605.19269
  • 代碼地址:https://github.com/HanGuo97/coda-kernels

這個名字,讀起來像「終曲」,念起來像「CUDA」。來自 MIT、普林斯頓、Together AI 和 Meta 的研究者,試圖用一套新的程式設計抽象,把 Transformer 訓練裡那些鮮少被人關注、卻持續消耗時間的「散碎計算」,系統性地消化掉。

訓練大模型的「偷懶稅」

要理解 CODA 在解決什麼問題,先要明白大模型訓練的時間都去哪了。

在一塊英偉達 H100 上訓練一個 LLaMA-3 風格的 1B 參數模型,大多數人會直覺地認為:時間都花在矩陣乘法和注意力計算上,畢竟那才是「真正的計算」。這個直覺大體上沒錯:矩陣乘法(GEMM)和注意力確實佔據了主要算力。

但如果你打開性能分析器仔細查看,會發現還有一批「小算子」在安靜地消耗時間:歸一化(RMSNorm)、激活函數(SwiGLU、RoPE)、殘差加法、跨層規約…… 它們單個計算量不大,卻頻繁地將大型中間張量從顯存中搬進搬出。

這就是所謂的「記憶體頻寬瓶頸」:好比一位廚藝絕頂的廚師,但每做一道菜都要從遠處的倉庫搬來食材,用完再送回去,而不是放在手邊的台面上。無論廚師的手速多快,等待搬運的時間都是真實的浪費。

更糟糕的是,隨著英偉達的 FP8、FP4 等低精度格式讓矩陣計算越來越快,這些「搬運」操作的相對成本反而在上升:矩陣乘法加速了,但張量搬進搬出的成本並沒有同比縮短。

論文中有一組數據非常直觀:在 H100 上使用 TorchTitan 訓練 1B 參數模型時,非矩陣乘法操作佔據了相當大比例的端到端運行時間,且隨著 FP8 精度的引入,這一比例還會進一步凸顯。

現有的編程框架對此幾乎無能為力。PyTorch 將 Transformer 的計算表達為一串算子序列,算子之間有清晰的邊界。這種邊界對於自動微分(autograd)非常友好,卻恰好阻止了跨算子的融合優化:每一個算子邊界,往往就是一次不必要的顯存寫回。

CODA:「尾聲」中藏有寶藏

Coda 的出發點是一個樸素的觀察。

在 GPU 上,一個高性能的矩陣乘法(GEMM)核心在結構上分為兩個部分:主迴圈(mainloop)負責核心的矩陣分塊乘加計算,尾聲(epilogue)負責在結果寫回顯存之前做一些收尾處理,比如加偏置、類型轉換、簡單縮放。

尾聲存在的意義,在於此時矩陣乘法的輸出還「活在」片上寄存器裡,還沒有落地到全局顯存。這是一個短暫的黃金窗口:如果能在這個時刻多做一些計算,就可以完全省掉一次顯存寫入再讀出的往返。

Coda 的核心洞察是:Transformer 中那些記憶體密集型操作,其實很多可以被代數地重新參數化,塞進這個「尾聲」窗口裡執行。

這需要一點數學技巧。以最常見的 GEMM-RMSNorm-GEMM 模式為例:一個矩陣乘法的結果,經過殘差加法、RMS 歸一化,再進行另一個矩陣乘法。傳統做法是三個獨立運算子串行執行,中間結果兩次寫入顯存。

CODA 團隊發現,RMS 正規化中的行縮放因子 r,由於是每行共享的標量,因此與後續的矩陣乘法滿足交換律:可將 r 的應用從「第二個 GEMM 之前」推遲至「第二個 GEMM 的尾聲」。推遲後,第一個 GEMM 的尾聲只需計算局部的「分塊均方根」(partial RMS),並由一個極輕量的輔助歸約內核合併,而完整的 RMSNorm 計算則消失。

類似的重新參數化同樣適用於 SwiGLU、RoPE(旋轉位置編碼)、交叉熵損失等操作,甚至對反向傳播也成立。論文中有一個定理證明:只要前向過程的尾端是「分塊局部」的,反向傳播就會自動繼承相同的結構。詳情請參閱原論文。

五種「積木」和一套「樂高語言」

CODA 不是一個具體的融合內核,而是一套程式設計抽象。

它固定住經過專家優化的 GEMM 主迴圈,然後在尾聲位置暴露五類可組合的基本原語:

  • 逐元素變換(residual 加法、激活函數、RoPE)
  • 向量加載與存儲(廣播 RMSNorm 權重)
  • 矩陣分塊加載與存儲(保存中間激活以供反向傳播使用)
  • Block Reduction (Local RMS, Block Log-Sum-Exp)
  • 有狀態變換(在線歸一化所需的 max 和 sum-exp 統計)

使用這五類積木,幾乎可以覆蓋標準 Transformer 在前向和反向傳播中除注意力之外的所有操作。

更有意思的是,這套抽象對於「誰來寫代碼」的包容度。論文在實驗中評估了兩種實現模式:一種是由人工程式員撰寫,另一種則是使用 Claude Code 生成——根據 CODA 的原語說明、若干示例和實現日誌,由 AI 完成大部分核心代碼,人工進行輕度監督。

兩種模式的性能表現均達到了較高水平。Tri Dao 在推文中表示「LLM 以及新手就可以編寫光速內核」,這正是論文實驗結果在現實層面的映射。

實驗結果

CODA 的基準測試選擇了較為苛刻的對手:cuBLAS 搭配 torch.compile,以及專為 LLM 優化的 Liger Kernel 和 FlashInfer。

論文對每個核心評估了兩種實現:CODA (LLM) 由 Claude Code 生成,研究者提供原語說明、若干示例和一份持續更新的實現技巧日誌,AI 完成主體代碼,人工進行輕度監督;CODA (Human) 由人工程式設計師獨立編寫,使用相同的高層重參數化思路,但不依賴 CODA 原語集本身。兩組結果均與 cuBLAS + torch.compile、Liger Kernel、FlashInfer 等優化庫進行對比。

在單運算子層面,以 GEMM-RMSNorm-GEMM 這一典型模式為例,CODA 在對應 1B、7B、70B 三個模型規模的隱藏維度下,均實現了對 cuBLAS + PyTorch 基線的超越。SwiGLU、RoPE、交叉熵等尾端組合亦有類似表現。

LLM 生成的內核在大多數基準上與人工手寫版本不相上下,個別配置下甚至略有超越。這在 GPU 內核優化這個歷來門檻極高的領域,是一個頗為罕見的結論。

反向傳播的收益尤為突出:GEMM-Residual-PartialRMS-GEMM 的反向核心相比基線加速幅度可達 1.6 至 1.8 倍,SwiGLU 反向也有約 1.4 至 1.6 倍的提升。在這個方向上,LLM 與人工實現的差距同樣微小。這並不奇怪:反向傳播天然涉及更多中間張量的存取,尾端融合的收益就更大;而 CODA 的原語設計足夠清晰,使得 AI 模型能夠正確地完成組合。

在完整 Transformer 層的端到端基準中,CODA 的前向加速在不同規模下約為 5% 至 20%,在較大模型尺寸(對應 70B 規模的隱藏維度)下效果更為顯著。

在數值精度方面,CODA 的重參數化調整了 RMSNorm 縮放因子的應用時機,但實驗表明其數值誤差與 PyTorch 參考實現相當,在某些配置下誤差甚至更小——得益於 GEMM 主迴圈本身具有更高精度的累加器。

CODA 能做什麼:在進入更宏觀的視角之前,先釐清 CODA 的能力範圍。

  • 覆蓋範圍:在標準 Transformer(如 LLaMA 架構)的前向和反向傳播中,除注意力和詞嵌入之外的幾乎全部計算,包括 RMSNorm、殘差加法、SwiGLU 激活、RoPE 旋轉位置編碼、交叉熵損失,以及上述操作的反向梯度計算。
  • 加速效果:在對應 1B 至 70B 規模的隱藏維度下,單算子層面相比 cuBLAS + torch.compile 基線有不同程度的提升,其中反向傳播收益最為顯著(部分核心可達 1.6 倍以上);完整 Transformer 層的端到端前向加速約為 5% 至 20%,在較大模型尺寸下效果更突出。
  • 誰能使用:CODA 基於 CuTeDSL(NVIDIA CUTLASS 的 Python DSL)實現,支援人工程式設計師和 AI 模型兩種核心編寫方式,且兩種方式均能達到高性能。
  • 目前限制:目前僅支援單 GPU 場景,不涉及分散式訓練;重參數化主要針對標準 Transformer 架構,其他架構的適用性仍有待驗證。

結語

CODA 不是孤立的工作。它是這類思想的具體實現:在 GPU 上,真正的優化空間往往不在「算什麼」,而在「怎麼搬」。

FlashAttention 讓注意力計算「住進」了片上記憶體,CODA 試圖讓歸一化和激活函數也「住進去」。Triton 降低了撰寫自定義核心的門檻,ThunderKittens、TileLang 等進一步在不同層次上探索這一空間。這些工作共同指向同一個方向:將 PyTorch 算子圖的表達便利性,與接近手寫 CUDA 的執行效率,真正統一在一套可程式化的框架裡。

Tri Dao 推文的最後一句話值得再細品:「LLM 以及新手就可以為所有 Transformer 操作編寫光速內核。」這背後有一個更深的邏輯:當程式設計抽象設計得足夠好,AI 模型本身就可以參與到自身訓練基礎設施的優化中。這個循環,才是 CODA 最耐人尋味的地方。

從這個角度看,「CODA」這個名字或許另有深意。在古典音樂中,Coda 是樂曲末尾收束全篇的段落。在這裡,它是 GEMM 核心的「尾聲」—— 而寫好這段尾聲,或許正是 Transformer 訓練系統效率提升的下一個重要章節。

免責聲明:本頁面資訊可能來自第三方,不一定反映KuCoin的觀點或意見。本內容僅供一般參考之用,不構成任何形式的陳述或保證,也不應被解釋為財務或投資建議。 KuCoin 對任何錯誤或遺漏,或因使用該資訊而導致的任何結果不承擔任何責任。 虛擬資產投資可能存在風險。請您根據自身的財務狀況仔細評估產品的風險以及您的風險承受能力。如需了解更多信息,請參閱我們的使用條款風險披露