Bài viết giới thiệu một nghiên cứu mới có tên CODA: Rewriting Transformer Blocks as GEMM-Epilogue Programs, với mục tiêu chính là tối ưu hóa hiệu suất huấn luyện mô hình Transformer, đặc biệt là giải quyết những phép toán “tốn bộ nhớ” tưởng chừng rời rạc nhưng tích lũy lại gây tiêu tốn nhiều thời gian.
Tác giả bài viết, nguồn: Machine Heart
Ngày 22 tháng 5, Tri Dao đã chia sẻ một bài đăng trên mạng xã hội của Han Guo. Anh ấy còn viết: “Sau một số phép biến đổi toán học, có thể thấy rằng mọi thứ trong Transformer đều là một chuỗi GEMM + epilogue (phép nhân ma trận cộng phần kết thúc). Với một số nguyên tố đã được tối ưu, LLM (và người mới bắt đầu) có thể viết các kernel tốc độ ánh sáng cho mọi thao tác Transformer!”

Tri Dao là một trong những tác giả cốt lõi của chuỗi FlashAttention, và tweet này dẫn đến bài báo mà họ công bố cùng ngày: CODA.

- Tiêu đề bài báo: CODA: Viết lại các khối Transformer dưới dạng chương trình GEMM-Epilogue
- Địa chỉ bài báo: https://arxiv.org/abs/2605.19269
- Địa chỉ mã nguồn: https://github.com/HanGuo97/coda-kernels
Tên này, khi đọc giống như “chung khúc”, khi phát âm lại giống như “CUDA”. Các nhà nghiên cứu đến từ MIT, Princeton, Together AI và Meta đã cố gắng sử dụng một hệ thống trừu tượng lập trình mới để hệ thống hóa xử lý những phép tính rải rác, ít được chú ý nhưng liên tục tốn thời gian trong quá trình huấn luyện Transformer.
Chi phí lười biếng khi huấn luyện mô hình lớn
Để hiểu CODA đang giải quyết vấn đề gì, trước tiên cần hiểu thời gian trong quá trình huấn luyện mô hình lớn đã được sử dụng vào đâu.
Khi huấn luyện một mô hình 1B tham số phong cách LLaMA-3 trên một GPU NVIDIA H100, phần lớn mọi người sẽ trực觉 cho rằng: thời gian chủ yếu dành cho các phép nhân ma trận và tính toán chú ý, vì đó mới là “tính toán thực sự”. Trực giác này nhìn chung là đúng: phép nhân ma trận (GEMM) và cơ chế chú ý chiếm phần lớn năng lực tính toán.

Nhưng nếu bạn mở công cụ phân tích hiệu năng và xem kỹ, bạn sẽ thấy một loạt các “toán tử nhỏ” đang âm thầm tiêu tốn thời gian: chuẩn hóa (RMSNorm), hàm kích hoạt (SwiGLU, RoPE), phép cộng dư, giảm toàn lớp... Chúng từng cái có khối lượng tính toán nhỏ, nhưng thường xuyên di chuyển các tensor trung gian lớn vào ra bộ nhớ GPU.

Đây chính là cái được gọi là “nút thắt băng thông bộ nhớ”: giống như một đầu bếp tài ba, nhưng mỗi khi nấu một món ăn đều phải mang nguyên liệu từ kho ở xa đến, dùng xong lại phải trả lại, thay vì để ngay trên quầy cạnh bên. Dù tay đầu bếp nhanh đến đâu, thời gian chờ đợi vận chuyển vẫn là sự lãng phí thực sự.
Tệ hơn nữa, khi các định dạng độ chính xác thấp như FP8 và FP4 của NVIDIA làm cho các phép tính ma trận ngày càng nhanh hơn, chi phí tương đối của các thao tác “di chuyển” này lại đang tăng lên: phép nhân ma trận đã được tăng tốc, nhưng chi phí di chuyển tensor vào và ra không giảm tương ứng.
Một nhóm dữ liệu trong bài báo rất trực quan: khi huấn luyện mô hình 1 tỷ tham số trên H100 bằng TorchTitan, các thao tác không phải nhân ma trận chiếm một tỷ lệ đáng kể thời gian chạy end-to-end, và tỷ lệ này còn trở nên nổi bật hơn khi sử dụng độ chính xác FP8.
Các khung lập trình hiện có gần như bất lực trước vấn đề này. PyTorch biểu diễn tính toán Transformer dưới dạng một chuỗi các toán tử, giữa các toán tử có ranh giới rõ ràng. Ranh giới này rất thân thiện với tự động vi phân (autograd), nhưng lại chính xác ngăn cản việc tối ưu hóa kết hợp giữa các toán tử: mỗi ranh giới toán tử thường là một lần ghi lại bộ nhớ hiển thị không cần thiết.
CODA: Bí mật ẩn chứa trong «Phần kết»
Điểm xuất phát của CODA là một quan sát đơn giản.
Trên GPU, một nhân nhân ma trận hiệu suất cao (GEMM) được chia thành hai phần cấu trúc: vòng lặp chính (mainloop) chịu trách nhiệm thực hiện các phép tính nhân-chia ma trận theo khối, và phần kết thúc (epilogue) thực hiện các xử lý cuối cùng trước khi ghi kết quả trở lại bộ nhớ video, chẳng hạn như cộng độ lệch, chuyển đổi kiểu dữ liệu và thu nhỏ đơn giản.

Ý nghĩa của giai đoạn cuối nằm ở chỗ, đầu ra của phép nhân ma trận vẫn còn “sống” trong thanh ghi trên chip, chưa được ghi xuống bộ nhớ hiển thị toàn cục. Đây là một cửa sổ vàng ngắn ngủi: nếu có thể thực hiện thêm một số phép tính vào thời điểm này, ta có thể hoàn toàn loại bỏ một lần ghi và đọc lại bộ nhớ hiển thị.
Sự hiểu biết cốt lõi của CODA là: nhiều thao tác tốn bộ nhớ trong Transformer thực ra có thể được tái tham số hóa đại số và thực hiện trong cửa sổ "kết thúc" này.
Điều này đòi hỏi một chút kỹ năng toán học. Lấy mô hình GEMM-RMSNorm-GEMM phổ biến nhất làm ví dụ: kết quả của một phép nhân ma trận, sau đó thực hiện phép cộng dư, chuẩn hóa RMS, rồi tiếp tục thực hiện một phép nhân ma trận khác. Cách làm truyền thống là thực hiện ba toán tử độc lập theo chuỗi, với kết quả trung gian được ghi xuống bộ nhớ GPU hai lần.

Đội ngũ CODA phát hiện rằng, hệ số tỷ lệ hàng r trong chuẩn hóa RMS, do là một đại lượng vô hướng chung cho từng hàng, nên có tính giao hoán với phép nhân ma trận sau đó: có thể dời việc áp dụng r từ “trước GEMM thứ hai” sang “cuối GEMM thứ hai”. Sau khi dời, ở cuối GEMM thứ nhất, chỉ cần tính “căn trung bình bình phương cục bộ” (partial RMS), được tổng hợp bởi một nhân quy約 nhẹ nhàng, và phép tính RMSNorm đầy đủ biến mất.
Sự tái tham số hóa tương tự cũng áp dụng cho các thao tác như SwiGLU, RoPE (mã hóa vị trí xoay), hàm mất mát cross-entropy, và thậm chí cả lan truyền ngược. Một định lý trong bài báo chứng minh rằng: miễn là phần cuối của tiến trình thuận là “phân khối cục bộ”, thì lan truyền ngược sẽ tự động kế thừa cùng cấu trúc đó. Vui lòng truy cập bài báo gốc để xem chi tiết.
Năm loại "khối xây dựng" và một bộ "ngôn ngữ LEGO"
CODA không phải là một kernel tích hợp cụ thể, mà là một bộ trừu tượng lập trình.
Nó cố định vòng lặp chính GEMM đã được chuyên gia tối ưu, sau đó phơi bày năm loại nguyên tố cơ bản có thể kết hợp tại vị trí kết thúc:
- Biến đổi từng phần tử (phép cộng residual, hàm kích hoạt, RoPE)
- Tải và lưu vector (phát sóng trọng số RMSNorm)
- Nạp và lưu trữ khối ma trận (lưu hoạt hóa trung gian để sử dụng trong lan truyền ngược)
- Chunk reduction (local RMS, chunk log-sum-exp)
- Có trạng thái biến đổi (thống kê max và sum-exp cần thiết cho chuẩn hóa trực tuyến)
Với năm loại khối xây dựng này, hầu như tất cả các thao tác trong quá trình lan truyền thuận và lan truyền ngược của một Transformer chuẩn, ngoại trừ sự chú ý, đều có thể được bao phủ.
Điều thú vị hơn là mức độ linh hoạt của hệ thống trừu tượng này đối với việc ai sẽ viết mã. Trong thí nghiệm, bài báo đánh giá hai mô hình triển khai: một là do lập trình viên con người viết, và một là do Claude Code tạo ra — với các nguyên tố cơ bản của CODA, một số ví dụ và nhật ký triển khai, AI sẽ thực hiện phần lớn mã lõi, trong khi con người chỉ giám sát nhẹ.
Hiệu suất của cả hai chế độ đều đạt mức cao. Tri Dao viết trên Twitter: “LLM và người mới bắt đầu có thể viết kernel với tốc độ ánh sáng”, điều này chính là sự phản ánh thực tế của kết quả thí nghiệm trong bài báo.
Kết quả thí nghiệm
Các bài kiểm tra hiệu năng của CODA chọn những đối thủ khá khắt khe: cuBLAS kết hợp với torch.compile, cùng với Liger Kernel và FlashInfer được tối ưu hóa riêng cho LLM.
Bài báo đánh giá hai cách triển khai cho mỗi kernel: CODA (LLM) được tạo bởi Claude Code, với các hướng dẫn nguyên tố, một số ví dụ và nhật ký kỹ thuật triển khai được cập nhật liên tục do nhà nghiên cứu cung cấp, AI thực hiện phần lớn mã nguồn và con người giám sát nhẹ; CODA (Human) được lập trình viên con người tự viết độc lập, sử dụng cùng tư tưởng tái tham số hóa cấp cao, nhưng không phụ thuộc vào bộ nguyên tố CODA. Kết quả của cả hai nhóm đều được so sánh với các thư viện tối ưu hóa như cuBLAS + torch.compile, Liger Kernel và FlashInfer.
Ở cấp độ đơn bộ xử lý, lấy mô hình điển hình GEMM-RMSNorm-GEMM làm ví dụ, CODA đều vượt trội so với cơ sở cuBLAS + PyTorch trên các kích thước ẩn tương ứng với ba mô hình 1B, 7B và 70B. Các tổ hợp cuối như SwiGLU, RoPE và cross-entropy cũng cho kết quả tương tự.
Các kernel do LLM tạo ra có hiệu năng tương đương với phiên bản do con người viết tay trên hầu hết các benchmark, và trong một số cấu hình cụ thể còn vượt nhẹ. Đây là một kết luận khá hiếm hoi trong lĩnh vực tối ưu hóa kernel GPU, vốn luôn có rào cản cực kỳ cao.



Lợi ích từ việc lan truyền ngược đặc biệt nổi bật: nhân tố lan truyền ngược của GEMM-Residual-PartialRMS-GEMM có thể tăng tốc từ 1,6 đến 1,8 lần so với cơ sở, trong khi SwiGLU lan truyền ngược cũng cải thiện khoảng 1,4 đến 1,6 lần. Trên hướng này, khoảng cách giữa LLM và các triển khai thủ công cũng rất nhỏ. Điều này không gây ngạc nhiên: lan truyền ngược tự nhiên liên quan đến nhiều truy cập tensor trung gian hơn, do đó lợi ích từ việc hợp nhất cuối cùng lớn hơn; và thiết kế nguyên tử của CODA đủ rõ ràng để mô hình AI có thể thực hiện chính xác việc kết hợp.

Trong benchmark end-to-end của lớp Transformer đầy đủ, tốc độ tăng tiến của CODA dao động từ khoảng 5% đến 20% ở các quy mô khác nhau, với hiệu quả nổi bật hơn ở các kích thước mô hình lớn (tương ứng với kích thước ẩn 70B).
Về độ chính xác số, việc tái tham số hóa CODA điều chỉnh thời điểm áp dụng hệ số tỷ lệ RMSNorm, nhưng các thí nghiệm cho thấy sai số số học của nó tương đương với phiên bản tham chiếu PyTorch, và trong một số cấu hình, sai số thậm chí còn nhỏ hơn — nhờ vào bộ tích lũy có độ chính xác cao hơn trong vòng lặp chính GEMM.
CODA có thể làm gì: Một danh sách tra cứu nhanh để làm rõ phạm vi khả năng của CODA trước khi nhìn vào bức tranh lớn hơn.
- Phạm vi bao gồm: gần như toàn bộ các phép tính trong quá trình lan truyền thuận và lan truyền ngược của Transformer chuẩn (như kiến trúc LLaMA), ngoại trừ sự chú ý và nhúng từ, bao gồm RMSNorm, phép cộng残差, kích hoạt SwiGLU, mã hóa vị trí xoay RoPE, tổn thất entropy chéo, cùng với các phép tính gradient ngược của các thao tác trên.
- Hiệu ứng tăng tốc: Trong các chiều ẩn từ quy mô 1B đến 70B, mức độ cải thiện ở cấp độ toán tử so với cơ sở cuBLAS + torch.compile là khác nhau, trong đó lợi ích trong quá trình lan truyền ngược nổi bật nhất (một số nhân có thể đạt hơn 1,6 lần); tốc độ tăng tiến trực tiếp toàn bộ lớp Transformer dao động khoảng 5% đến 20%, hiệu quả càng rõ rệt hơn với các mô hình lớn hơn.
- Ai cũng có thể sử dụng: CODA được triển khai dựa trên CuTeDSL (Python DSL của NVIDIA CUTLASS), hỗ trợ cả hai phương pháp viết kernel: do lập trình viên thủ công và do mô hình AI, cả hai đều đạt hiệu năng cao.
- Hạn chế hiện tại: Hiện chỉ hỗ trợ môi trường một GPU, không bao gồm đào tạo phân tán; tái tham số hóa chủ yếu áp dụng cho kiến trúc Transformer tiêu chuẩn, tính phù hợp với các kiến trúc khác vẫn cần được xác minh.
Kết luận
CODA không phải là một công việc cô lập. Đó là sự hiện thực hóa cụ thể của một lớp ý tưởng: trên GPU, không gian tối ưu hóa thực sự thường không nằm ở "tính gì", mà nằm ở "di chuyển thế nào".
FlashAttention giúp các phép tính chú ý được “di chuyển” vào bộ nhớ trên chip, CODA cố gắng làm tương tự với các hàm chuẩn hóa và kích hoạt. Triton làm giảm rào cản khi viết các kernel tùy chỉnh, trong khi ThunderKittens, TileLang và các công cụ khác tiếp tục khám phá không gian này ở các cấp độ khác nhau. Những nỗ lực này cùng hướng tới một mục tiêu chung: thống nhất sự tiện lợi trong biểu diễn đồ thị toán tử PyTorch với hiệu suất thực thi gần như CUDA viết tay, trong một khung lập trình duy nhất.
Câu cuối cùng trong tweet của Tri Dao đáng để suy ngẫm thêm: “LLM và người mới bắt đầu có thể viết các nhân tốc độ ánh sáng cho mọi thao tác Transformer.” Đằng sau điều này là một logic sâu sắc hơn: khi thiết kế trừu tượng lập trình được thực hiện đủ tốt, các mô hình AI chính chúng có thể tham gia vào việc tối ưu hóa cơ sở hạ tầng đào tạo của chính chúng. Vòng lặp này mới chính là điểm khiến CODA trở nên đáng chú ý nhất.
Từ góc độ này, cái tên “CODA” có lẽ mang một ý nghĩa sâu sắc hơn. Trong âm nhạc cổ điển, Coda là đoạn kết thúc toàn bộ tác phẩm. Ở đây, nó là “kết thúc” của lõi GEMM — và việc viết nên đoạn kết thúc này có lẽ chính là chương quan trọng tiếp theo trong việc nâng cao hiệu suất hệ thống huấn luyện Transformer.
