लेख CODA: Transformer ब्लॉक्स को GEMM-Epilogue प्रोग्राम्स के रूप में पुनः लिखना नामक एक नए अनुसंधान का परिचय देता है, जिसका मुख्य लक्ष्य Transformer मॉडल ट्रेनिंग की दक्षता को अनुकूलित करना है, विशेष रूप से उन ऐसे "मेमोरी-भारी" ऑपरेशन्स को हल करना जो अकेले तो छोटे लगते हैं लेकिन जमा होकर बहुत अधिक समय लेते हैं।
लेखक, स्रोत: मशीन ऑफ़ द इंडिया
22 मई को, ट्री डाओ ने सोशल मीडिया पर हान गुओ का एक ट्वीट शेयर किया। उन्होंने लिखा: "कुछ गणितीय पुनर्लेखन के बाद, पता चला कि ट्रांसफॉर्मर का सारा कंटेंट GEMM + epilogue (मैट्रिक्स गुणन और एपिलॉग) की एक श्रृंखला है। कुछ ऑप्टिमाइज़्ड प्राइमिटिव्स के साथ, LLM (और नए) सभी ट्रांसफॉर्मर ऑपरेशन्स के लिए लाइटनिंग-स्पीड कर्नेल लिख सकते हैं!"

Tri Dao, FlashAttention श्रृंखला के प्रमुख लेखकों में से एक हैं, और यह ट्वीट उनके द्वारा उसी दिन प्रकाशित एक पेपर: CODA की ओर इशारा करता है।

- शीर्षक: CODA: ट्रांसफॉर्मर ब्लॉक्स को GEMM-एपिलॉग प्रोग्राम के रूप में पुनः लिखना
- कागजात का पता: https://arxiv.org/abs/2605.19269
- कोड का पता: https://github.com/HanGuo97/coda-kernels
इस नाम को पढ़ने पर 「अंतिम संगीत」 जैसा लगता है, और बोलने पर 「CUDA」 जैसा लगता है। MIT, प्रिंसटन, Together AI और Meta के शोधकर्ता, Transformer प्रशिक्षण में उन दुर्लभ रूप से ध्यान दिए जाने वाले, लेकिन लगातार समय खर्च करने वाले 「टुकड़े-टुकड़े की गणनाओं」 को एक नए प्रोग्रामिंग अमूर्तीकरण के साथ व्यवस्थित रूप से समाप्त करने का प्रयास कर रहे हैं।
बैकग्राउंड: लार्ज मॉडल के प्रशिक्षण का 'लाज़ी टैक्स'
CODA क्या समस्या हल कर रहा है, इसे समझने के लिए, सबसे पहले यह समझना आवश्यक है कि बड़े मॉडल प्रशिक्षण का समय कहाँ जाता है।
एक न्वीडिया H100 पर एक 1B पैरामीटर LLaMA-3 शैली का मॉडल ट्रेन करते समय, अधिकांश लोग यह मानेंगे कि समय मैट्रिक्स गुणन और ध्यान गणना पर खर्च हो रहा है, क्योंकि वे ही 'वास्तविक गणना' हैं। यह सीधी समझ लगभग सही है: मैट्रिक्स गुणन (GEMM) और ध्यान वास्तव में प्रमुख कंप्यूटेशनल भार लेते हैं।

लेकिन अगर आप प्रदर्शन विश्लेषक को ध्यान से देखते हैं, तो आप देखेंगे कि कुछ "छोटे ऑपरेटर" भी चुपचाप समय खर्च कर रहे हैं: नॉर्मलाइज़ेशन (RMSNorm), एक्टिवेशन फ़ंक्शन (SwiGLU, RoPE), रेसिड्यूअल जोड़, क्रॉस-लेयर रिडक्शन... उनकी अकेली गणना मात्रा छोटी है, लेकिन वे बार-बार बड़े मध्यवर्ती टेंसर को वीडियो मेमोरी से बाहर और अंदर ले जाते हैं।

यही «मेमोरी बैंडविड्थ बॉटलनेक» कहलाता है: एक ऐसे शीर्ष रसोइये की तरह जिसे हर व्यंजन बनाने के लिए सामग्री को दूर के गोदाम से लाना पड़ता है और इस्तेमाल के बाद वापस भेजना पड़ता है, न कि इसे अपनी मेज पर रखकर। जितना तेज़ रसोइया हाथ चलाए, ले जाने-लाने का समय वास्तविक बर्बादी है।
और खराब बात यह है कि जैसे-जैसे निविडा के FP8, FP4 जैसे निम्न-सटीकता फॉर्मेट मैट्रिक्स कैलकुलेशन को और तेज़ बना रहे हैं, इन "ट्रांसफर" ऑपरेशन की सापेक्ष लागत बढ़ रही है: मैट्रिक्स गुणन तेज़ हो रहा है, लेकिन टेंसर को अंदर-बाहर ले जाने की लागत समान अनुपात में कम नहीं हो रही है।
एक अध्ययन में एक स्पष्ट डेटा सेट है: H100 पर TorchTitan का उपयोग करके 1B पैरामीटर मॉडल को प्रशिक्षित करते समय, मैट्रिक्स गुणन के अलावा के ऑपरेशन एंड-टू-एंड रनटाइम का काफी हिस्सा लेते हैं, और FP8 प्रिसिजन के शुरू होने से यह अनुपात और अधिक बढ़ जाता है।
मौजूदा प्रोग्रामिंग फ्रेमवर्क इसके लिए लगभग असमर्थ हैं। PyTorch Transformer की गणना को ऑपरेटर सीरीज के रूप में व्यक्त करता है, जिनके बीच स्पष्ट सीमाएँ होती हैं। ये सीमाएँ स्वचालित अवकलन (autograd) के लिए अत्यंत अनुकूल हैं, लेकिन ऑपरेटर के बीच संलयन अनुकूलन को ठीक से रोकती हैं: प्रत्येक ऑपरेटर सीमा, अक्सर एक अनावश्यक GPU मेमोरी राइटबैक होती है।
CODA: "अंत" में खजाना छिपा है
CODA का उद्देश्य एक साधारण अवलोकन से शुरू होता है।
GPU पर, एक उच्च प्रदर्शन वाला मैट्रिक्स गुणन (GEMM) कर्नल संरचनात्मक रूप से दो भागों में विभाजित होता है: मुख्य लूप (mainloop) मैट्रिक्स ब्लॉक गुणन-जोड़ की मूल गणना करता है, और अंतिम भाग (epilogue) परिणाम को VRAM में लिखने से पहले, जैसे कि विस्थापन जोड़ना, प्रकार परिवर्तन, सरल स्केलिंग, जैसे कुछ अंतिम संसाधनों को संभालता है।

अंत का अर्थ यह है कि इस समय मैट्रिक्स गुणन का आउटपुट अभी भी ऑन-चिप रजिस्टर में "जीवित" है, और अभी तक ग्लोबल वीएमएम में नहीं उतरा है। यह एक क्षणिक स्वर्णिम खिड़की है: यदि इस समय अतिरिक्त गणनाएँ की जा सकती हैं, तो एक वीएमएम लिखने और पुनः पढ़ने की पूरी यात्रा को पूरी तरह से बचाया जा सकता है।
CODA का मुख्य अवलोकन है: ट्रांसफॉर्मर में उन स्मृति-घने संचालनों को, जिनमें से कई को बीजगणितीय रूप से पुनः पैरामीटराइज़ किया जा सकता है और इस 「अंतिम」 खिड़की में निष्पादित किया जा सकता है।
इसके लिए कुछ गणितीय कौशल की आवश्यकता होती है। सबसे सामान्य GEMM-RMSNorm-GEMM पैटर्न के साथ उदाहरण लें: एक मैट्रिक्स गुणन का परिणाम, जिसके बाद रेसिडुअल जोड़, RMS नॉर्मलाइज़ेशन, और फिर एक अन्य मैट्रिक्स गुणन होता है। पारंपरिक तरीके में, तीन स्वतंत्र ऑपरेटर क्रमिक रूप से निष्पादित होते हैं, और मध्यवर्ती परिणाम दो बार डिस्प्ले मेमोरी में सहेजे जाते हैं।

CODA टीम ने पाया कि RMS नॉर्मलाइज़ेशन में प्रत्येक पंक्ति के लिए साझा स्केलर r, इसके बाद के मैट्रिक्स गुणन के साथ क्रमविनिमेय है: r के अनुप्रयोग को 'दूसरे GEMM से पहले' से 'दूसरे GEMM के अंत तक' स्थानांतरित किया जा सकता है। इस स्थानांतरण के बाद, पहले GEMM के अंत में केवल स्थानीय 'पार्शियल RMS' की गणना की जाती है, जिसे एक अत्यंत हल्का सहायक रिडक्शन कर्नेल मिलाता है, और पूर्ण RMSNorm गणना समाप्त हो जाती है।
इसी प्रकार का पुनःपैरामीटरीकरण SwiGLU, RoPE (घूर्णन स्थिति कोडिंग), क्रॉस-एंट्रॉपी लॉस आदि ऑपरेशन्स के लिए भी लागू होता है, और यह बैकप्रोपगेशन के लिए भी सत्य है। पेपर में एक प्रमेय साबित किया गया है: जब तक फॉरवर्ड पास का अंत "ब्लॉक-लोकल" है, बैकप्रोपगेशन स्वतः उसी संरचना को विरासत में प्राप्त कर लेता है। विस्तार से जानने के लिए मूल पेपर देखें।
पाँच 「ब्लॉक्स」 और एक 「लेगो भाषा」
CODA एक विशिष्ट फ्यूजन कर्नेल नहीं है, बल्कि एक प्रोग्रामिंग अमूर्ति का सेट है।
यह विशेषज्ञों द्वारा अनुकूलित GEMM मुख्य लूप को स्थिर करता है और अंत में पांच प्रकार के संयोज्य मूल तत्वों को उजागर करता है:
- तत्ववार परिवर्तन (अवशेष जोड़, सक्रियण फ़ंक्शन, RoPE)
- वेक्टर लोड और स्टोर (ब्रॉडकास्ट RMSNorm वजन)
- मैट्रिक्स ब्लॉक लोडिंग और स्टोरिंग (रिवर्स प्रोपेगेशन के लिए मध्यवर्ती एक्टिवेशन सहेजें)
- Block Reduction (Local RMS, Block Log-Sum-Exp)
- स्थिति परिवर्तन (ऑनलाइन नॉर्मलाइजेशन के लिए आवश्यक max और sum-exp सांख्यिकी)
इन पाँच प्रकार के ब्लॉक्स का उपयोग करके, एक मानक Transformer के फॉरवर्ड और बैकवर्ड प्रोपेगेशन में ध्यान के अलावा लगभग सभी ऑपरेशन को कवर किया जा सकता है।
अधिक दिलचस्प बात यह है कि यह अमूर्तीकरण «कौन कोड लिखे» के लिए कितना लचीला है। शोध पत्र में दो कार्यान्वयन पैटर्न का मूल्यांकन किया गया: एक मानव प्रोग्रामर द्वारा लिखा गया, और दूसरा Claude Code द्वारा उत्पन्न — CODA के प्राथमिक विवरण, कुछ उदाहरण और कार्यान्वयन लॉग के साथ, AI द्वारा अधिकांश कोर कोड तैयार किया गया, और मानव द्वारा हल्की निगरानी की गई।
दोनों मोड का प्रदर्शन उच्च स्तर पर पहुंच गया। ट्री डाओ ने ट्वीट में कहा कि "LLM और नए उपयोगकर्ता भी लाइटस्पीड कर्न लिख सकते हैं", जो पेपर के परीक्षण परिणामों का वास्तविक दुनिया में प्रतिबिंब है।
परीक्षण परिणाम
CODA के बेंचमार्क में अपेक्षाकृत कठिन प्रतिद्वंद्वी चुने गए हैं: cuBLAS और torch.compile, और LLM के लिए अनुकूलित Liger Kernel और FlashInfer।
प्रत्येक कोर के लिए दो वास्तुकला मूल्यांकन किए गए: CODA (LLM), जिसे Claude Code द्वारा उत्पन्न किया गया, जहाँ शोधकर्ताओं ने प्राइमिटिव्स के विवरण, कुछ उदाहरण और एक निरंतर अद्यतन किए जाने वाले लागू करने के टिप्स लॉग प्रदान किए, जिससे AI मुख्य कोड पूरा करता है और मानवीय निगरानी हल्की होती है; CODA (Human), जिसे मानव प्रोग्रामर्स द्वारा स्वतंत्र रूप से लिखा गया, जो समान उच्च-स्तरीय पुनर-पैरामीट्राइजेशन दृष्टिकोण का उपयोग करते हैं, लेकिन CODA प्राइमिटिव्स सेट पर निर्भर नहीं होते। दोनों समूहों के परिणाम cuBLAS + torch.compile, Liger Kernel, FlashInfer आदि अनुकूलित पुस्तकालयों के साथ तुलना किए गए।
एकल ऑपरेटर स्तर पर, GEMM-RMSNorm-GEMM जैसे एक आम पैटर्न के संदर्भ में, CODA ने 1B, 7B, 70B तीन मॉडल साइज़ के लिए छिपे हुए आयाम पर cuBLAS + PyTorch बेसलाइन को पार कर लिया है। SwiGLU, RoPE, क्रॉस-एंट्रॉपी जैसे अंतिम संयोजनों में भी समान प्रदर्शन देखा गया है।
LLM द्वारा उत्पन्न कर्नल अधिकांश बेंचमार्क पर मानव द्वारा हस्तलिखित संस्करण के समान प्रदर्शन करते हैं, और कुछ विशिष्ट कॉन्फ़िगरेशन में थोड़ा बेहतर भी हैं। यह GPU कर्नल अनुकूलन जैसे पारंपरिक रूप से अत्यधिक कठिन क्षेत्र में एक अत्यंत दुर्लभ निष्कर्ष है।



रिवर्स प्रोपेगेशन के लिए लाभ विशेष रूप से स्पष्ट हैं: GEMM-Residual-PartialRMS-GEMM के रिवर्स कर्नेल में बेसलाइन की तुलना में 1.6 से 1.8 गुना तक की त्वरितता है, और SwiGLU रिवर्स में लगभग 1.4 से 1.6 गुना की वृद्धि है। इस दिशा में, LLM और मैनुअल अनुकूलन के बीच का अंतर भी बहुत कम है। यह आश्चर्यजनक नहीं है: रिवर्स प्रोपेगेशन स्वाभाविक रूप से अधिक मध्यवर्ती टेंसर्स के एक्सेस को शामिल करता है, जिससे टेल-फ्यूजन का लाभ अधिक होता है; और CODA के प्राइमिटिव्स का डिज़ाइन पर्याप्त स्पष्ट है, जिससे AI मॉडल सही ढंग से संयोजन पूरा कर सकता है।

पूर्ण Transformer स्तर के एंड-टू-एंड बेंचमार्क में, CODA का फॉरवर्ड त्वरण विभिन्न आकारों पर लगभग 5% से 20% है, और बड़े मॉडल आकारों (70B स्केल के छिपे हुए आयाम के संगत) में यह प्रभाव अधिक स्पष्ट होता है।
संख्यात्मक सटीकता के संदर्भ में, CODA का पुनः पैरामीटरीकरण RMSNorm स्केलिंग फैक्टर के अनुप्रयोग के समय को समायोजित करता है, लेकिन प्रयोगों से पता चलता है कि इसकी संख्यात्मक त्रुटि PyTorch संदर्भ कार्यान्वयन के समान है, और कुछ कॉन्फ़िगरेशन में यह त्रुटि और भी कम है — GEMM मुख्य लूप में उच्चतर सटीकता वाले एक्यूमुलेटर के कारण।
CODA क्या कर सकता है: बड़े दृश्य में प्रवेश करने से पहले, CODA की क्षमताओं की सीमा स्पष्ट करते हैं।
- कवरेज: स्टैंडर्ड ट्रांसफॉर्मर (जैसे LLaMA आर्किटेक्चर) के फॉरवर्ड और बैकवर्ड प्रोपेगेशन में, ध्यान और शब्द एम्बेडिंग के अलावा, RMSNorm, रेसिड्यूअल जोड़, SwiGLU एक्टिवेशन, RoPE रोटेशनल पोजिशनल कोडिंग, क्रॉस-एंट्रॉपी लॉस, और उपरोक्त ऑपरेशन्स के बैकवर्ड ग्रेडिएंट कैलकुलेशन शामिल हैं।
- त्वरण प्रभाव: 1B से 70B छिपे हुए आयामों के लिए, cuBLAS + torch.compile बेसलाइन की तुलना में एकल ऑपरेटर स्तर पर विभिन्न स्तरों का त्वरण प्राप्त होता है, जिसमें प्रतिगमन का लाभ सबसे अधिक है (कुछ कर्नेल्स 1.6x से अधिक); पूर्ण Transformer परत का एंड-टू-एंड फॉरवर्ड त्वरण लगभग 5% से 20% है, और बड़े मॉडल आकारों में यह प्रभाव अधिक स्पष्ट होता है।
- CODA, जो CuTeDSL (NVIDIA CUTLASS का Python DSL) पर आधारित है, दोनों मानव प्रोग्रामर और AI मॉडल के लिए कर्नेल लिखने की सुविधा प्रदान करता है, और दोनों तरीके उच्च प्रदर्शन प्राप्त करते हैं।
- वर्तमान सीमाएँ: वर्तमान में केवल एकल GPU स्थिति का समर्थन किया जाता है, वितरित प्रशिक्षण शामिल नहीं है; पुनर्पैरामीटरीकरण मुख्य रूप से मानक Transformer आर्किटेक्चर के लिए है, अन्य आर्किटेक्चर की उपयुक्तता की पुष्टि अभी नहीं की गई है।
अंतिम शब्द
CODA एक अलग कार्य नहीं है। यह एक विचार का वास्तविक कार्यान्वयन है: GPU पर, वास्तविक अनुकूलन का क्षेत्र अक्सर 'क्या गणना करें' में नहीं, बल्कि 'कैसे स्थानांतरित करें' में होता है।
FlashAttention ने ध्यान गणना को चिप पर मेमोरी में "रख दिया", CODA ने सामान्यीकरण और सक्रियण फ़ंक्शन को भी "रखने" की कोशिश की। Triton ने कस्टम कर्न लिखने की बाधा कम की, ThunderKittens, TileLang आदि ने विभिन्न स्तरों पर इस क्षेत्र का अध्ययन किया। ये सभी कार्य एक ही दिशा की ओर इशारा करते हैं: PyTorch ऑपरेटर ग्राफ़ की अभिव्यक्ति की सुविधा और हस्तलिखित CUDA के समीप की निष्पादन दक्षता को एक समान प्रोग्राम करने योग्य ढांचे में वास्तविक रूप से एकीकृत करना।
ट्री डाओ के ट्वीट का अंतिम वाक्य फिर से सोचने लायक है: "LLM और नए उपयोगकर्ता सभी Transformer ऑपरेशन के लिए प्रकाश की गति के कर्नेल लिख सकते हैं।" इसके पीछे एक गहरा तर्क है: जब प्रोग्रामिंग अमूर्तीकरण पर्याप्त रूप से अच्छा होता है, तो AI मॉडल स्वयं अपनी प्रशिक्षण बुनियादी ढांचे के अनुकूलन में भाग ले सकता है। यह चक्र, CODA का सबसे रोचक पहलू है।
इस दृष्टिकोण से, "CODA" नाम का अर्थ शायद अलग है। शास्त्रीय संगीत में, Coda गीत के अंत में समाप्ति करने वाला अंश होता है। यहाँ, यह GEMM कोर का "अंतिम भाग" है—और इस अंतिम भाग को अच्छी तरह से लिखना, Transformer प्रशिक्षण प्रणाली की दक्षता में सुधार का अगला महत्वपूर्ण अध्याय हो सकता है।
