這兩天,FlashAttention 團隊推出了新作:
一種給 Transformer 架構大模型推理加速的新方法,最高可提速 8 倍。
該方法尤其造福于長上下文LLM,在64k 長度的 CodeLlama-34B上通過了驗證。
甚至得到了PyTorch 官方認可:
如果你之前有所關注,就會記得用給大模型加速效果真的很驚豔。
不過它僅限于訓練階段。
因此,這一新成果一出,就有網友表示:
等推理加速等了好久,終于來了。
據介紹,這個新方法也是在 FlashAttention 的基礎之上衍生而出,主要思想也不複雜:
用并行操作盡快加載 Key 和 Value 緩存,然後分别重新縮放再合并結果,最終獲得推理速度上的大幅提升。
提速 8 倍的長上下文推理方法來了
該方法被命名爲Flash-Decoding。
背景與動機
根據作者介紹:
LLM 的推理(即 " 解碼 ")過程是叠代的,即一次生成一個 token,組成一個完整句子需要 n 個 token 以及 n 次前向傳遞。
不過,由于我們可以緩存之前計算出來的 token,所以單個生成步驟并不總是依賴于上下文長度。
但有一個操作例外:注意力 ( attention ) ,它不能随着上下文長度靈活擴展。
鑒于長上下文已成趨勢,比如目前最大的開源 LLM 已達 100k(CodeLlama),我們不得不注意到 attention 在大模型推理過程中浪費了太多時間,時間就是金錢。
更别提 attention 在 batch size 上進行擴展時,即使模型上下文相對較短,它也可能成爲性能瓶頸(因爲模型要讀取的内存量與 batch size 成比例,而它僅取決于模型其餘部分的大小)。
怎麽破解?
不可複用的 FlashAttention 優化
模型在推理也就是解碼過程中,爲了計算 softmax ( queries @keys.transpose ) @values 這兩個值,生成的每個新 token 都需要關注先前的所有 token。
團隊先前的工作 FlashAttention,已經在訓練階段對此操作進行了優化。
當時,FlashAttention 解決的主要瓶頸是讀寫中間結果的内存帶寬(例如,Q @ K^T)。
然而,在推理階段,我們要面對的瓶頸變了,導緻 FlashAttention 所做的優化并不能直接拿過來應用。
具體而言:
在階段階段,FlashAttention 在batch size 和查詢長度維度上進行并行化。
在推理階段,查詢長度通常爲 1,這意味着如果 batch size 小于 GPU 上的流式多處理器數量(例如,A100 爲 108),該操作将僅使用 GPU 的一小部分。
這對于長上下文情況尤甚,因爲長上下文需要較小的 batch size 才能适應 GPU 内存。
所以,結果就是,當 batch size 爲 1 時,FlashAttention 将隻占用不足 1% 的 GPU,非常不劃算。
當然,你可能會說,不用 FlashAttention 也行,用矩陣乘法原語來完注意力操作。
不過,作者指出,這種情況又會完全占用 GPU,并啓動非常多的寫入和讀取中間結果的内核,也不是最佳辦法。
Flash-Decoding 誕生
最終,基于以上考量,作者在 FlashAttention 的基礎上,添加了一個新的并行化緯度:key 和 value 序列長度。
這個方法(即 Flash-Decoding)結合上述兩種方法的優點:
與 FlashAttention 一樣,它在全局内存中存儲的額外數據非常少,但隻要上下文長度足夠大,即使 batch size 很小,它也可以充分利用 GPU。
詳細來看,Flash-Decoding 一共分爲三個步驟:
1、先将 key 和 value 值分成更小的塊。
2、用 FlashAttention 并行計算每塊分割的查詢注意力。并爲每行和每塊分割寫入一個額外标量:注意力值的 log-sum-exp。
3、最後,通過減少所有分割來計算實際輸出,使用 log-sum-exp 來 scale 每塊分割的貢獻。
作者指出,由于 attention/softmax 可以叠代計算,以上所有操作均可行。
并且在 Flash-Decoding 中,ttention/softmax 既可以在分割塊内,也可以跨分割塊來執行最終的縮減,隻不過後者可縮減的步驟很少。
而在實際操作中,步驟 1 不涉及任何 GPU 操作,因爲 key 和 value 塊是完整的張量視圖。然後由 2 個獨立的内核分别執行步驟 2 和 3。
最高提速 8 倍
驗證環節,作者在 CodeLLaMa-34b(架構與 Llama 2 相同)上對其解碼吞吐量進行了基準測試。
具體以 tok/s 爲單位,測量了 512 到 64k 序列長度下的解碼速度(上限爲從内存中讀取整個模型以及 KV 緩存所需的時間),并和多種計算注意力的方法進行對比,包括:
Pytorch,使用純 PyTorch 原語運行注意力
FlashAttention v2
FasterTransformer:使用 FasterTransformer 注意力内核
最終,Flash-Decoding 最高可将長序列解碼速度提升 8 倍,并比其他方法具 有更好的擴展性(受長度影響較小)
此外,作者還在 A100 上對各種序列長度和 batch size 的縮放多頭注意力進行了微基準測試。
結果顯示,當序列長度擴展到 64k 時,Flash-Decoding 實現了幾乎恒定的運行時間。
如何使用?
以下是 Flash-Decoding 的獲取途徑,戳文末官方博客即可找到地址:
FlashAttention 包,2.2 版本及以上
xFormers 包,0.0.22 版本及以上
調度程序将根據問題的大小自動使用 Flash-Decoding 或 FlashAttention 方法。
團隊介紹
目前 Flash-Decoding 還沒出論文,但作者團隊已透露,這次不再是Tri Dao" 單打獨鬥 ",不過一作仍然是他。
Tri Dao 今年博士畢業于斯坦福,7 月份加盟大模型創業公司 Together AI 擔任首席科學家。
明年 9 月将上任普林斯頓大學助理教授,他是 FlashAttention v1 和 v2 的主要作者。
剩下三位作者分别是:
Daniel Haziza,Facebook AI Research 研究工程師,主要負責 xformers(用于訓練加速的開源框架);
Francisco Massa,同 Facebook AI Research 研究工程師, 主要從事 PyTorch 相關工作;
Grigory Sizov,Meta 機器學習工程師,主要工作是優化 GPU 上的 LLM 推理和其他 AI 工作負載,爲 PyTorch 生态做出過貢獻。
官方博客:
https://princeton-nlp.github.io/flash-decoding/
參考鏈接:
https://twitter.com/tri_dao/status/1712904220519944411?s=20