記事は、CODA: Rewriting Transformer Blocks as GEMM-Epilogue Programsという新研究を紹介しており、その核心的な目標は、Transformerモデルのトレーニング効率を最適化し、特に散在しているが蓄積すると深刻な時間消費を引き起こす「メモリ集約型」演算を解決することである。
記事執筆者、出典:機械の心
5月22日、Tri Daoはソーシャルメディア上でHan Guoのツイートを共有し、次のように記した。「いくつかの数学的再記述により、Transformerのすべての内容が一連のGEMM + epilogue(行列乗算とエピローグ)であることが明らかになった。最適化されたプリミティブを用いれば、LLM(および初心者)もすべてのTransformer操作用に光速カーネルを記述できる!」

Tri DaoはFlashAttentionシリーズの主要な著者の一人であり、このツイートはその日発表された論文「CODA」を指しています。

- 論文タイトル:CODA: TransformerブロックをGEMMエピローグプログラムとして書き直す
- 論文のアドレス:https://arxiv.org/abs/2605.19269
- コードのアドレス:https://github.com/HanGuo97/coda-kernels
この名前は「終曲」と読め、かつ「CUDA」と発音される。MIT、プリンストン大学、Together AI、Metaの研究者たちは、Transformerの訓練においてあまり注目されず、継続的に時間を消費する「散在する計算」を、新しいプログラミング抽象化によって体系的に解消しようとしている。
大モデルを訓練するための「怠け税」
CODAが解決しようとしている問題を理解するには、大規模モデルのトレーニングに時間がどこに使われているかを理解する必要があります。
NVIDIA H1001枚でLLaMA-3スタイルの1Bパラメータモデルをトレーニングする際、ほとんどの人は直感的に、行列乗算と注意メカニズムの計算に時間がかかっていると考えるだろう。なぜなら、それらが「本質的な計算」だからだ。この直感は基本的に正しい:行列乗算(GEMM)と注意メカニズムが主要な計算リソースを占めている。

しかし、パフォーマンスアナライザを開いて詳細を確認すると、正規化(RMSNorm)、活性化関数(SwiGLU、RoPE)、残差加算、層間集約などの「小さな演算子」が静かに時間を消費していることがわかります。これらは個々の計算量は大きくありませんが、大型の中間テンソルをGPUメモリから頻繁に読み出し、書き込みしています。

これがいわゆる「メモリ帯域幅のボトルネック」です。まるで最高のシェフが、毎回料理をするたびに材料を遠くの倉庫から運び出し、使い終わったらまた戻さなければならないようなものです。手元の作業台に材料を置いておけばいいのに。シェフの手の速さがどれほど速くても、運搬を待つ時間は確かに無駄です。
さらに悪いことに、NVIDIAのFP8やFP4などの低精度フォーマットにより行列計算がますます高速化される中で、これらの「移動」操作の相対的コストは逆に上昇している:行列乗算は加速したが、テンソルの入出力コストはそれに比例して短縮されていない。
論文には、H100でTorchTitanを用いて1Bパラメータモデルを訓練する際、行列乗算以外の操作がエンドツーエンドの実行時間の相当部分を占めており、FP8精度の導入によりこの割合がさらに顕著になるという直観的なデータがあります。
従来のプログラミングフレームワークでは、これに対してほとんど対応できません。PyTorchはTransformerの計算を一連のオペレーターの列として表現し、オペレーター間には明確な境界があります。この境界は自動微分(autograd)には非常に適していますが、逆にオペレーター間の融合最適化を妨げます。各オペレーターの境界は、しばしば不要なメモリ書き戻しを意味します。
CODA:「尾声」に隠された宝物
CODAの出発点は、単純な観察である。
GPU上で、高性能な行列乗算(GEMM)カーネルは、主ループ(mainloop)とエピローグ(epilogue)の二つの部分で構成されています。主ループは、行列のブロック乗算と加算の核心処理を担当し、エピローグは、結果をVRAMに書き戻す前にバイアスの加算、型変換、簡単なスケーリングなどの後処理を行います。

フィナーレの意義は、この時点での行列乗算の出力がまだオンチップレジスタに「存在」し、グローバルVRAMに書き込まれていないことにあります。これは短い黄金の窓口です:この瞬間にさらに計算を追加すれば、VRAMへの書き込みと読み込みの往復を完全に省略できます。
CODAの核心的な洞察は、Transformerにおけるメモリ集約型の操作の多くを、代数的に再パラメータ化してこの「尾声」ウィンドウ内で実行できるということである。
これは少し数学的なテクニックを要します。最も一般的なGEMM-RMSNorm-GEMMパターンを例に挙げると:1つの行列乗算の結果が残差加算、RMS正規化を経て、さらに別の行列乗算が行われます。従来の方法では、3つの独立したオペレーターが直列で実行され、中間結果が2回GPUメモリに書き込まれます。

CODAチームは、RMS正規化における行スケーリング因子rが、各行で共有されるスカラーであるため、後続の行列乗算と交換法則を満たすことを発見しました。つまり、rの適用を「2番目のGEMMの前」から「2番目のGEMMの終了時」に遅延させることができます。この遅延により、1番目のGEMMの終了時には、軽量な補助還元カーネルによって統合される局所的な「ブロック単位の二乗平均平方根」(partial RMS)を計算するだけで済み、完全なRMSNorm計算は不要になります。
同様の再パラメトリゼーションは、SwiGLU、RoPE(回転位置エンコーディング)、交差エントロピー損失などの操作にも適用でき、逆伝播にも成り立ちます。論文には、前方計算の最終段階が「ブロック局所的」である限り、逆伝播が自動的に同じ構造を継承することを示す定理が記載されています。詳細については、オリジナル論文をご覧ください。
五つの「ブロック」と一組の「レゴ言語」
CODAは具体的な融合カーネルではなく、一連のプログラミング抽象です。
それは専門家によって最適化されたGEMMのメインループを固定し、末尾に5種類の組み合わせ可能な基本プリミティブを公開します:
- 要素ごとの変換(残差加算、活性化関数、RoPE)
- ベクトルの読み込みと保存(ブロードキャスト RMSNorm 重み)
- 行列のブロック読み込みと保存(逆伝播用の中間活性化を保存)
- ブロック還元(局所的平均平方根、ブロック log-sum-exp)
- ステートフル変換(オンライン正規化に必要な最大値およびsum-exp統計)
この5種類のブロックを使用することで、アテンションを除く標準的なTransformerのフォワードおよびバックワードプロパゲーションのほぼすべての操作をカバーできます。
より興味深いのは、この抽象が「誰がコードを書くか」に対して持つ寛容性である。論文では、二つの実装モデルを評価した:一つは人間のプログラマーが作成するもの、もう一つはClaude Codeが生成するものである。CODAの原語説明、いくつかの例、および実装ログをもとに、AIが大部分のカーネルコードを生成し、人間は軽度の監督を行う。
両方のモードのパフォーマンスはいずれも高いレベルに達しました。Tri Daoはツイートで「LLMと初心者でも光速カーネルを記述できる」と述べており、これは論文の実験結果が現実で実現されたことを示しています。
実験結果
CODAのベンチマークでは、cuBLASとtorch.compile、およびLLMに最適化されたLiger KernelとFlashInferという厳しい対手を選択しています。
論文では、各カーネルに対して2つの実装を評価した。CODA (LLM) はClaude Codeによって生成され、研究者が原語の説明、いくつかの例、および継続的に更新される実装テクニックログを提供し、AIがメインコードを生成し、人間が軽度の監督を行う。CODA (Human) は、人間のプログラマーが独立して作成したもので、同じ高レベルのリパラメトライゼーションのアイデアを使用するが、CODA原語セット自体には依存しない。両グループの結果は、cuBLAS + torch.compile、Liger Kernel、FlashInferなどの最適化ライブラリと比較された。
単一オペレータレベルで、GEMM-RMSNorm-GEMMという典型的なパターンを例に挙げると、CODAは1B、7B、70Bの3つのモデル規模における隠れ次元に対して、cuBLAS + PyTorchのベースラインを上回りました。SwiGLU、RoPE、交差エントロピーなどの後続コンポーネントでも同様の性能を示しました。
LLMが生成したカーネルは、ほとんどのベンチマークで手作業で記述されたバージョンと同等の性能を発揮し、一部の設定ではやや上回る結果を示しています。これは、従来非常に高いハードルがあったGPUカーネル最適化の分野において、非常に珍しい結論です。



逆伝播の利点が特に顕著である:GEMM-Residual-PartialRMS-GEMMの逆伝播カーネルはベースラインに対して1.6〜1.8倍の高速化を達成し、SwiGLUの逆伝播も約1.4〜1.6倍の向上を示した。この分野では、LLMと手動実装との差も非常に小さい。これは驚くべきことではない:逆伝播は自然に多くの中間テンソルのアクセスを伴い、テール融合の利点がより大きくなるためである。また、CODAのプリミティブ設計は十分に明確であり、AIモデルが正しく組み合わせを実行できるようにしている。

完全なTransformer層のエンドツーエンドベンチマークにおいて、CODAのフォワード加速はスケールごとに約5%〜20%であり、より大きなモデルサイズ(70Bスケールの隠れ次元に対応)で効果がより顕著です。
数値精度に関して、CODAの再パラメータ化はRMSNormのスケーリングファクターの適用タイミングを調整しているが、実験によりその数値誤差はPyTorchの参照実装と同等であり、一部の設定ではさらに小さいことが示された——これはGEMMのメインループ自体がより高精度の累算器を備えているためである。
CODAの機能:より広い視野に入る前に、CODAの機能範囲を明確にしましょう。
- 対象範囲:標準的なTransformer(LLaMAアーキテクチャなど)における、アテンションと単語埋め込みを除く、ほぼすべての計算。RMSNorm、残差加算、SwiGLU活性化、RoPE回転位置エンコーディング、交差エントロピー損失、およびこれらの操作の逆伝播勾配計算を含む。
- 加速効果:1B~70B規模の隠れ次元において、単一オペレーター単位でcuBLAS + torch.compileベースラインと比較して、不同程度の性能向上が見られます。特にバックプロパゲーションの効果が顕著で、一部のカーネルでは1.6倍以上を達成。完全なTransformer層のエンドツーエンドフロワード加速は約5%~20%で、モデルサイズが大きくなるほど効果が顕著になります。
- CODAは、CuTeDSL(NVIDIA CUTLASSのPython DSL)を基に実装されており、人間のプログラマとAIモデルの両方のカーネル記述方式をサポートし、どちらの方式でも高性能を実現します。
- 現在の制限:現在は単一GPU環境のみをサポートしており、分散学習は対象外です。再パラメータ化は標準的なTransformerアーキテクチャに主に適用され、他のアーキテクチャへの適用性は検証中です。
まとめ
CODAは孤立した作業ではありません。これは、GPU上で真の最適化の余地が「何を計算するか」ではなく「どのようにデータを移動するか」にあるという思想の具体的な実装です。
FlashAttention はアテンション計算をオンチップメモリに「収容」し、CODA は正規化と活性化関数も「収容」しようと試みている。Triton はカスタムカーネルの作成障壁を下げ、ThunderKittens、TileLang などのツールは異なるレベルでこの領域をさらに探求している。これらの取り組みはすべて、PyTorch のオペレータグラフの表現のしやすさと、手書き CUDA に近い実行効率を、一つのプログラマブルなフレームワーク内で真正に統一することを指向している。
Tri Daoのツイートの最後の文は再考に値する:「LLMと初心者でも、すべてのTransformer操作に光速カーネルを書ける。」その背後にはより深いロジックがある:プログラミングの抽象化が十分に優れていれば、AIモデル自体が自身のトレーニングインフラの最適化に参加できるということだ。この循環こそ、CODAの最も興味深い点である。
この観点から見ると、「CODA」という名前には別の深い意味があるのかもしれない。古典音楽において、Codaは楽曲の終わりに全体を締めくくる部分である。ここでは、それはGEMMカーネルの「尾声」であり、この尾声をしっかりと書き上げることが、Transformerトレーニングシステムの効率向上における次の重要な章となるかもしれない。
