内存占用小,訓練表現也要好……大模型訓練成功實現二者兼得。
來自北理、北大和港中文 MMLab 的研究團隊提出了一種滿足低秩約束的大模型全秩訓練框架——Fira,成功打破了傳統低秩方法中内存占用與訓練表現的 " 非此即彼 " 僵局。
展開來說——
爲了突破内存瓶頸,許多低秩訓練方法應運而生,如 LoRA(分解參數矩陣)和 GaLore(分解梯度矩陣)。
△圖 1:從宏觀層面分析三種内存高效低秩訓練方法
然而,如上圖所示,LoRA 将訓練局限于參數的低秩子空間,降低了模型的表征能力,難以實現預訓練;GaLore 将訓練局限于梯度的低秩子空間,造成了子空間外梯度的信息損失。
相較于全秩訓練,這兩種方法由于施加了低秩約束,會導緻訓練表現有所下降。
但是,若提高秩值,則會相應地增加内存占用。
因此,在實際應用中,它們需要在确保訓練表現與降低内存消耗之間找到一個恰當的平衡點。
這引發了一個核心問題:
能否在維持低秩約束以确保内存高效的同時,實現全秩參數、全秩梯度的訓練以提升表現?
Fira 即爲最新答案,它有三大亮點:
即插即用:Fira 簡單易用,其核心實現僅涉及兩行關鍵公式,現已封裝進 Python 庫,可直接融入現有的大模型訓練流程中,替換原有優化器。代碼示例如下:
from fira import FiraAdamW, divide_paramsparam_groups = divide_params ( model, target_modules_list = [ "Linear" ] , rank=8 ) optimizer = FiraAdamW ( param_groups, lr=learning_rate )
雙赢解決方案:在維持低秩約束的前提下,Fira 實現了大模型的全秩訓練,打破了内存占用與訓練表現的取舍難題。與此同時,區别于系統方法(如梯度檢查點),Fira 不以時間換内存;
實驗驗證:Fira 在多種規模的模型(60M 至 7B 參數)以及預訓練和微調任務中均展現出卓越性能,優于現有的 LoRA 和 GaLore,甚至能達到或超越全秩訓練的效果。
打造 Fira 訓練框架
Fira 訓練框架由兩部分組成:
1 ) 基于梯度模長的縮放策略:利用了團隊在大模型低秩和全秩訓練中發現的共通點——自适應優化器對原始梯度的修正效應,實現了低秩約束下的全秩訓練。
2 ) 梯度模長限制器,通過限制梯度模長的相對增長比例,解決了大模型訓練中常出現的損失尖峰問題。
背景動機
大模型訓練常常面臨顯著的内存瓶頸,尤其是其中的優化器狀态。
舉例來說,使用 Adam 優化器從頭預訓練一個 LLaMA 7B 模型(batchsize 爲 1,精度爲 BF16)可能需要至少 58GB 内存。
其中 14GB 用于加載參數,14GB 用于儲存梯度,28GB 用于儲存優化器狀态,剩下 2GB 用于儲存激活值。
在這之中,優化器狀态所占内存甚至要大于參數本身。
因此,使用低秩方法來減少這一部分内存,實現大模型的内存高效訓練十分重要。
而在現有的低秩方法中,LoRA 通過分解參數矩陣,使用低秩适配器來減少内存占用;Galore 通過分解梯度矩陣,在自适應優化器中儲存低秩梯度來減少内存占用。
鑒于使用 LoRA 低秩适配器方法來實現全參數訓練的困難性,團隊選擇拓展 Galore 的梯度投影方法來實現全秩訓練。
在 Galore 中,全秩梯度 G ∊ ℝ mxn,會被投影矩陣 P ∊ ℝ mxr 分解成兩項低秩梯度 PR 和(G — PR),其中。
爲減少像 Adam 這樣的自适應優化器在内存中對應的狀态占用,Galore 僅在優化器核心中保留低秩梯度 R,而非全秩梯度 G。
而另一項梯度(G — PR),則會因爲缺少對應的優化器狀态,被 Galore 直接丢棄,從而造成嚴重的信息損失。
這也解釋了,爲什麽 Galore 的性能會在 rank 值減小時,顯著衰減。
△圖 2:Fira 與 Galore 及其變體的訓練損失對比
爲了彌補上述信息損失,最直觀的方法是直接加上這一部分梯度(G — PR):
其中,W 是參數矩陣, 是學習率。
然而,如圖所示,使用這種方法(Galore-add)不僅未能帶來性能提升,反而可能導緻訓練過程更加不穩定,且結果更差。
分析原因可歸結于這一部分的梯度缺乏優化器狀态,直接使用會退化爲單純的 SGD 算法,并且可能與前面使用的 Adam 優化器的梯度不匹配,導緻效果不佳。
基于梯度模長的縮放策略
爲了解決上述挑戰,團隊提出了scaling factor 概念,來描述 Adam 這樣的自适應優化器對原始梯度的修正效應,并揭示了它在大模型的低秩訓練和全秩訓練之間的相似性。
其中, 就是 scaling factor,代表經過優化器修正過的梯度與原始梯度的模長比例。
如下圖,如果根據 scaling factor 的平均值對參數矩陣進行排序,可以發現低秩和全秩之間的排序非常相似。
△圖 3:scaling factor 在大模型低秩和全秩訓練間的相似性
基于這個觀察,團隊就嘗試在矩陣層面用低秩梯度 R 的 scaling factor,作爲全秩梯度 G 的 scaling factor 的替代,從而近似地修正(G — PR),彌補其缺少的優化器狀态:
這樣團隊就在低秩約束下成功實現了全秩訓練。
進一步來說,剛才是從矩陣層面來考慮 scaling factor。
順理成章地,團隊可以從更細粒度的角度——列的層面,來考慮 scaling factor,實現更加精細地修正。
其中 R,:, 是低秩梯度 R 的第 i 列,
是 scaling factor 的第 i 項。
梯度模長限制器
在訓練過程中,梯度常常會突然增大,導緻損失函數出現尖峰,從而影響訓練的表現。
經過分析,可能原因是 Galore 在切換投影矩陣時存在不穩定性,以及維持(G — PR)這種原始梯度的方向的方式,無法像 Adam 這樣的自适應算法,有效應對大模型訓練中存在的陡峭損失景觀。
△圖 4:3 種 Fira 變體的訓練損失與梯度模長
然而,常見的梯度裁剪方法(如圖中的 Fira-gradient-clipping)由于采用絕對裁剪,難以适應不同參數矩陣間梯度的較大差異,從而可能導緻次優的訓練結果。
爲此,團隊提出了一種新的梯度模長限制器,它通過限制梯度模長的相對增長比例,來更好地适應不同梯度的變化:
其中是比例增長的上限,S=(R ) ( G — PR)是原始梯度(G — PR)修正後的結果。
通過提出的控制梯度相對增長比例的方法,能夠将梯度的驟然增大轉化爲平緩的上升,從而有效穩定訓練過程。
如圖 2 和圖 3 所示,團隊的限制器成功避免了損失函數的尖峰情況,并顯著提升了訓練表現。
實驗結果
如下表所示,在預訓練任務中,Fira 在保持内存高效的前提下,驗證集困惑度(↓)顯著超過各類基線方法,甚至超越全秩方法。
具體來說,在預訓練 LLaMA 1B 模型時,Fira 節約了61.1%優化器狀态所占内存,并且取得了比全秩訓練更加好的結果。
△使用 C4 數據集預訓練不同大小的 LLaMA 模型驗證集困惑度(↓)對比
在預訓練LLaMA 7B模型時,Fira 在使用了比 Galore 小 8 倍的秩 rank 的情況下,訓練表現遠超 Galore。
這展現了 Fira 在大規模大模型上的有效性,以及相較 Galore 更高的内存減少能力。
△使用 C4 數據集預訓練 LLaMA 7B 的驗證集困惑度(↓)對比
在八個常識推理數據集微調 LLaMA 7B 的任務中,相較其他基線方法,Fira 在一半的數據集下表現最好,平均準确率最高的同時實現了内存高效。
△在八個常識推理數據集微調 LLaMA 7B 準确率對比
另外,消融實驗也顯示了:
Fira-w.o.-scaling 說明了 Fira 使用基于梯度模長的縮放策略的有效性;
Fira-matrix 說明了從更細粒度的列級别,而不是矩陣級别,考慮 scaling factor 的有效性;
Fira-w.o.-limiter 說明了 Fira 中梯度模長限制器的有效性;
Fira-gradient-clipping 說明了梯度裁剪可能無法完全解決損失尖峰問題,導緻結果次優。
△消融實驗
與 GaLore 相比,Fira 的表現幾乎不受秩 rank 值減少的影響。
在低秩的情況下(rank=16, rank=4),Fira 仍然能與全秩訓練相當,相較 Galore 更加内存高效。
△不同 rank 下的預訓練驗證集困惑度(↓)
最後,團隊在不同模型大小,以及低秩和全秩條件下,訓練 10,000 步,并對得到的矩陣和列級别上 Scaling factor 做平均。
接着,使用了斯皮爾曼(Spearman)和肯德爾(Kendall)相關系數分析了 Scaling factor 在矩陣和列級别上大小順序的相關性。
其中,Coefficient 中 1 代表完全正相關,-1 代表完全負相關,而 P-value 越小越好(通常小于 0.05 爲顯著)。
在所有規模的 LLaMA 模型中,Scaling factor 在矩陣和列的級别上都表現出很強的正相關關系,并且所有的 P-value 小于 0.05,非常顯著,爲 Fira 中基于梯度模長的縮放策略提供了堅實的實驗基礎。
△矩陣和列級别上的 Scaling factor 低秩與全秩相似性分析
更多細節歡迎查閱原論文。
論文鏈接:https://arxiv.org/abs/2410.01623
代碼倉庫:https://github.com/xichen-fy/Fira
— 完 —
投稿請發郵件到:
标題注明【投稿】,告訴我們:
你是誰,從哪來,投稿内容
附上論文 / 項目主頁鏈接,以及聯系方式哦
我們會(盡量)及時回複你
點這裏關注我,記得标星哦~
一鍵三連「分享」、「點贊」和「在看」
科技前沿進展日日相見 ~
>