Hiroki Naganuma

「Modular Manifolds」(Jeremy Bernstein, Thinking Machines, 2025)

https://thinkingmachines.ai/blog/modular-manifolds/

Introduction

Normalization

Necessity of Weight Normalization

重み行列の正規化は、活性化や勾配に比べると一般的ではありませんが、無視できない効果があります。 たとえば:

Benefits of Weight Normalization

重み行列を正規化・制約することには以下のような利点があります:

Purpose of This Article

本記事では、ニューラルネットワークの各層において 重み行列を多様体 (manifold) 上に制約するというアプローチを解説します。 これは単なる制約ではなく、最適化アルゴリズムと多様体構造を共同設計(co-design)するための枠組みです。 その一例として、著者らは Stiefel 多様体(単位条件数をもつ行列の集合)上で動作する Muon Optimizer の多様体版 を提案します。 これを “Manifold Muon” と呼びます。

The Shape of a Manifold Optimizer(多様体最適化器の形状)

この章では、多様体上での学習の最も単純な例として、
$\mathbb{R}^d$ 空間内のベクトルパラメータが 超球面(hypersphere)上に制約されている場合 を考えます。

このベクトルパラメータは、全空間 $\mathbb{R}^d$ 上で定義された損失関数を最小化するよう訓練されます。
この設定は、例えば Transformer モデルにおける個々の埋め込みベクトルの学習 などに応用できます。
この章は、次章「Manifold Muon(行列パラメータ版)」の理解のウォームアップになります。

多様体とは何か?

ここでは形式的な定義に立ち入らず、多様体とは「局所的には平らに見える曲面」 であると理解すれば十分です。
多様体上のある点において、その局所的な平面近似を 接空間 (tangent space) と呼びます。

図1:3次元球面(または高次元の超球面)は多様体であり、
その局所的な平坦近似(赤い平面)は、接空間として可視化できる。

超球面と接空間

$d$ 次元空間における単位超球面は、次のように定義されます:

\[S^{d-1} = \{\, \mathbf{w} \in \mathbb{R}^d \mid \| \mathbf{w} \|_2 = 1 \,\}\]

超球面上の点 $\mathbf{w}$ における接空間は、$\mathbf{w}$ に直交するすべてのベクトルの集合です:

\[T_{\mathbf{w}}S^{d-1} = \{\, \mathbf{a} \in \mathbb{R}^d \mid \mathbf{a}^\top \mathbf{w} = 0 \,\}\]

接空間内での更新と射影

重みを多様体上に保つ最も単純な方法は、通常のオプティマイザを使い、
更新後に重みを再び多様体上へ射影 (projection) することです。

しかし、ここではより洗練された方法を採ります。
すなわち、接空間内で直接ステップを取るように最適化を設計します。

その理由は、学習率 (learning rate) をステップの実際の長さと一致させたいからです。
更新方向が多様体から外れすぎて投影によって戻される場合、この関係が崩れてしまいます。

接空間での距離の定義

多様体をリーマン多様体とするためには、距離が内積によって誘導される (ベクトルの『内積』という操作さえ定義すれば、そこから自動的に『ベクトルの長さ(ノルム)』と『2つのベクトル間の距離』を自然な形で計算できる) 必要があります。
最も一般的な距離尺度はユークリッド距離ですが、他の距離も選択できます。

図2:接空間における単位球の形状。
$\ell_2$ ノルムでは円、$\ell_1$ ノルムでは菱形になります。

距離尺度の選択は、最適化ステップの方向を大きく変えます。
もし非ユークリッド距離を採用すれば、同じステップ長でも勾配方向とは異なる方向に進むほうが
より損失を下げられることがあります。

図3:幾何的形状が更新方向に与える影響。
ピンクの矢印が勾配、黄の菱形が $\ell_1$ 単位球、緑の矢印が最適な更新方向です。

数式での定式化

次に、ユークリッド距離を仮定した超球面上での最適更新方向を導出します。
勾配を $\mathbf{g}$、現在の重みを $\mathbf{w}$、更新方向を $\mathbf{a}$、学習率を $\eta$ とします。

求めるべき問題は次のように書けます:

\[\begin{aligned} \min_{\mathbf{a} \in \mathbb{R}^d} \quad & \mathbf{a}^\top \mathbf{g} \\ \text{s.t.} \quad & \| \mathbf{a} \|_2 = \eta \quad \text{(大きさ制約)} \\ & \mathbf{a}^\top \mathbf{w} = 0 \quad \text{(接空間制約)} \end{aligned} \tag{★}\]

ラグランジュ法による解法

ラグランジュ関数を

\[\mathcal{L}(\mathbf{a}, \lambda, \mu) = \mathbf{a}^\top \mathbf{g} + \frac{\lambda}{2} (\mathbf{a}^\top \mathbf{a} - \eta) + \mu (\mathbf{a}^\top \mathbf{w})\]

とおきます。
これを $\mathbf{a}$ で微分し、制約を適用して解くと、最適解は次の形になります:

\[\mathbf{a}_{\text{opt}} = -\eta \frac{ \mathbf{g} - \mathbf{w}(\mathbf{w}^\top \mathbf{g}) }{ \| \mathbf{g} - \mathbf{w}(\mathbf{w}^\top \mathbf{g}) \|_2 }\]

すなわち、「勾配のうち $\mathbf{w}$ に沿う成分を除去し、正規化して学習率を掛ける」という更新です。

射影(リトラクションマップ)

更新後に多様体上へ戻すための操作を リトラクションマップ (retraction map) と呼びます。
ユークリッドノルムの単位球面では、ピタゴラスの定理から次の式が導かれます:

\[\mathbf{w} \leftarrow \frac{1}{\sqrt{1+\eta^2}} \left[ \mathbf{w} - \eta \frac{ \mathbf{g} - \mathbf{w}(\mathbf{w}^\top \mathbf{g}) } { \| \mathbf{g} - \mathbf{w}(\mathbf{w}^\top \mathbf{g}) \|_2 } \right]\]

多様体最適化アルゴリズム(要約)

  1. 勾配方向で最も遠くへ進む単位長の接ベクトルを求める
  2. それに学習率を掛けて重みから減算する
  3. リトラクションマップで多様体上に戻す

他の最適化手法との対応関係

多様体 ノルム オプティマイザ
ユークリッド空間 $\mathbb{R}^n$ ユークリッドノルム 標準的な勾配降下法
ユークリッド空間 $\mathbb{R}^n$ 無限ノルム sign勾配降下法
超球面 $S^n$ ユークリッドノルム 球面勾配降下法
行列空間 $\mathbb{R}^{m\times n}$ スペクトルノルム Muon
Stiefel多様体 $\subset \mathbb{R}^{m\times n}$ スペクトルノルム Manifold Muon

Manifold Muon(多様体版 Muon Optimizer)

Transformer における典型的な重み行列 $\mathbf{W}$ は、ベクトル変換器 (vector-multiplier) として機能します。
すなわち、入力ベクトル $\mathbf{x}$ を受け取り、出力ベクトル $\mathbf{y} = \mathbf{W}\mathbf{x}$ を生成します。

このとき、行列 $\mathbf{W}$ が入力に対して「過剰に引き伸ばす」または「過剰に縮める」ことのないように、
良い振る舞いを保証する多様体制約と距離関数を設計することが目的です。
また、更新によって出力ベクトルが極端に変化しないようにしたいという動機もあります。

行列の伸縮を理解する:特異値分解 (SVD)

行列がどのようにベクトルを変換するかを理解するための自然な方法は、特異値分解 (Singular Value Decomposition) です。

\[\mathbf{M} \in \mathbb{R}^{m \times n}, \quad \text{rank}(\mathbf{M}) = k\] \[\mathbf{M} = \mathbf{U} \mathbf{\Sigma} \mathbf{V}^\top\]

ここで:

図5:SVD は行列の伸縮作用を、軸ごとのスケーリングとして可視化する。

Stiefel 多様体(単位特異値制約)

理想的な行列は、ベクトルをちょうど「1倍」に伸縮する(=単位特異値を持つ)ものです。
そのような行列の集合が Stiefel 多様体 と呼ばれます。

\[\text{Stiefel}(m, n) := \{\, \mathbf{W} \in \mathbb{R}^{m \times n} \mid \mathbf{W}^\top \mathbf{W} = \mathbf{I}_n \,\}\]

ここでは $m \ge n$(縦長行列)を仮定します。
この制約は、前章の超球面制約 $\mathbf{w}^\top \mathbf{w} = 1$ の自然な一般化です。

接空間の条件

行列 $\mathbf{A} \in \mathbb{R}^{m \times n}$ が $\mathbf{W}$ における Stiefel 多様体の接空間に属するのは、次の条件を満たすときです:

\[\mathbf{A}^\top \mathbf{W} + \mathbf{W}^\top \mathbf{A} = 0\]

これは、$\mathbf{A}$ が「多様体上で許される方向」であることを意味します。

スペクトルノルムによる距離の定義

行列更新の際に、入力ベクトルを過剰に伸縮させないようにするため、
距離関数として スペクトルノルム (spectral norm) を採用します。
これは行列の最大特異値を測るノルムです:

\[\|\mathbf{A}\|_{\text{spectral}} = \sigma_{\max}(\mathbf{A})\]

この制約により、「最大限どれだけ出力を変化させうるか」を抑えつつ、
更新が十分な大きさを持つように調整できます。

Manifold Muon の最適化問題

超球面版(前章の式★)を一般化して、Manifold Muon の最適化問題は次のように書けます:

\[\begin{aligned} \min_{\mathbf{A} \in \mathbb{R}^{m \times n}} \quad & \text{Tr}(\mathbf{G}^\top \mathbf{A}) \\ \text{s.t.} \quad & \|\mathbf{A}\|_{\text{spectral}} \le \eta \quad \text{(大きさ制約)} \\ & \mathbf{A}^\top \mathbf{W} + \mathbf{W}^\top \mathbf{A} = 0 \quad \text{(接空間制約)} \end{aligned} \tag{†}\]

ここで $\mathbf{G}$ は勾配です。

Jianlin Su と Franz Cesista による洞察

この問題 (†) は凸最適化問題であり、標準的な手法である 双対上昇法 (dual ascent) によって解くことができます。
著者はこの方法を、Jianlin Su と Franz Louis Cesista の先行研究を基に発展させました。

双対変数としてラグランジュ乗数 $\mathbf{\Lambda} \in \mathbb{R}^{n \times n}$ を導入すると、
次のように変形できます:

\[\max_{\mathbf{\Lambda}} \; -\eta \, \|\, \mathbf{G} + 2\mathbf{W}(\mathbf{\Lambda} + \mathbf{\Lambda}^\top) \,\|_{\text{nuclear}} \tag{1}\]

ここで $|\cdot|_{\text{nuclear}}$ は 核ノルム (nuclear norm)、すなわち特異値の総和です。

双対関数の勾配

双対関数の勾配は次式で表されます:

\[\mathbf{H}(\mathbf{\Lambda}) = -\eta \Big[ \mathbf{W}^\top \operatorname{msign}(\mathbf{G} + 2\mathbf{W}(\mathbf{\Lambda} + \mathbf{\Lambda}^\top)) + \operatorname{msign}(\mathbf{G} + 2\mathbf{W}(\mathbf{\Lambda} + \mathbf{\Lambda}^\top))^\top \mathbf{W} \Big]\]

ここで $\operatorname{msign}$ は 行列符号関数 (matrix sign function) であり、
行列の特異値を ±1 に「スナップ」させる演算です。
(Newton–Schulz 反復法や Polar Express アルゴリズムで効率的にGPU実装可能。)

Manifold Muon アルゴリズム

最終的な Manifold Muon Optimizer は以下のステップで構成されます:

  1. 双対変数更新(dual ascent) \(\mathbf{\Lambda} \leftarrow \mathbf{\Lambda} + \alpha \, \mathbf{H}(\mathbf{\Lambda})\) (ここで $\alpha$ はステップサイズ)

  2. 最適更新の計算 \(\mathbf{A}_{\text{opt}} = -\eta \, \operatorname{msign}\!\left( \mathbf{G} + 2\mathbf{W}(\mathbf{\Lambda}_{\text{opt}} + \mathbf{\Lambda}_{\text{opt}}^\top) \right)\)

  3. 重み更新 \(\mathbf{W} \leftarrow \mathbf{W} + \mathbf{A}_{\text{opt}}\)

  4. リトラクション(多様体への射影) \(\mathbf{W} \leftarrow \operatorname{msign}(\mathbf{W})\)

実験結果(Sanity Check)

著者は小規模な MLP を CIFAR-10 データセット上で3エポック訓練し、
Manifold MuonAdamW を比較しました。

Manifold Muon の1ステップ当たりの時間は AdamW より長いものの、
双対上昇のステップ数を減らすか、モメンタムを導入すれば改善可能です。
システム的なボトルネックがなければ、このオーバーヘッドは実用上問題にならないと考えられます。

Modular Manifolds(モジュラー多様体)

ここまでの議論では、個々のパラメータテンソル(例:重み行列)に対して
多様体制約を課し、それに適合する最適化手法を設計してきました。

では、複数の層を組み合わせてネットワーク全体を構築する場合はどうなるでしょうか?
層ごとに独立して考えられるのでしょうか?
それとも、層間の相互作用を考慮して最適化ロジックを調整する必要があるでしょうか?

この章では、そうした問題を解決する枠組みとして
モジュラー多様体(Modular Manifold)理論を紹介します。

基本理念

モジュラー多様体の目的は、
ネットワーク全体における「層ごとの学習率の予算配分 (learning rate budgeting)」を
数学的に一貫して扱う抽象化を構築することです。

層内の最適化ロジック自体は、これまで導出したものと同じですが、
層の位置に応じて学習率がスケールされるように調整します。

この抽象化は、筆者らの論文「The Modular Norm」で導入された概念に基づいており、
重み変化に対する出力感度(Lipschitz感度)の理解を軸としています。
多様体制約を導入することで、この感度の解析がより厳密に行えるようになります。

ニューラルネットワークモジュールの形式化

任意のニューラルネットワークモジュール(例:層〜トランスフォーマーブロック)は、
次の3つの属性をもつ数学的対象として定式化できます:

  1. 順伝播関数 (forward function)
    \(f : \mathcal{W} \times \mathcal{X} \to \mathcal{Y}\) ここで $\mathcal{W} = \mathbb{R}^d$ はパラメータ空間、
    $\mathcal{X}$ は入力空間、$\mathcal{Y}$ は出力空間。

  2. 多様体制約 (manifold constraint)
    \(\mathcal{M} \subseteq \mathcal{W}\)

  3. ノルム (norm)
    \(\|\cdot\| : \mathcal{W} \to \mathbb{R}\)

例:Stiefel Linear モジュール

たとえば、スペクトルノルムを持ち、Stiefel 多様体に制約された線形層は次のように表せます:

\[\text{StiefelLinear} = \Big\{ \begin{array}{ll} (\mathbf{W}, \mathbf{x}) \mapsto \mathbf{W}\mathbf{x}, & \text{(forward function)} \\ \text{Stiefel}(m,n), & \text{(manifold)} \\ \|\cdot\|_{\text{spectral}} & \text{(norm)} \end{array} \Big\}\]

入力 $\mathbf{x}$ の $\ell_2$ ノルムが 1 のとき、
このモジュールは重み $\mathbf{W}$ に関してリプシッツ定数 1 を持ちます:

\[\|(\mathbf{W} + \Delta\mathbf{W})\mathbf{x} - \mathbf{W}\mathbf{x}\|_2 \le \|\Delta\mathbf{W}\|_{\text{spectral}} \cdot \|\mathbf{x}\|_2 = \|\Delta\mathbf{W}\|_{\text{spectral}}\]

したがって、重み更新のスケールが出力変化をどの程度引き起こすかを
定量的に把握することができます。

モジュールの合成(Composition)

では、2つのモジュールを組み合わせたとき、
新しいモジュールに対しても同様のリプシッツ境界を自動的に導けるでしょうか?

答えは Yes です。
ただし、いくつかの特別なルールに従って新しいモジュールを構築する必要があります。

1. 新しい順伝播関数

\[f_3((\mathbf{w}_1, \mathbf{w}_2), \mathbf{x}) = f_2(\mathbf{w}_2, f_1(\mathbf{w}_1, \mathbf{x}))\]

すなわち、既存のモジュール $f_1, f_2$ の順伝播を関数合成するだけです。


2. 新しい多様体制約

\[\mathcal{M}_3 = \mathcal{M}_1 \times \mathcal{M}_2\]

これは単なる 直積多様体 (Cartesian product) です。


3. 新しいノルム関数

新しいノルム $|\cdot|_3$ は、各モジュールのノルムをスカラー係数 $s_1, s_2$ で重みづけた最大値で定義します:

\[\|(\mathbf{w}_1, \mathbf{w}_2)\|_3 = \max(s_1 \|\mathbf{w}_1\|_1, \; s_2 \|\mathbf{w}_2\|_2)\]

モジュラー最適化の結果

この複合ノルムに基づいて最適化アルゴリズムを導出すると、
各層ごとに独立したオプティマイザが得られます。
ただし、係数 $s_i$ が層ごとの学習率の配分比率(budgeting factor)となります。

各層の学習率は、ネットワーク内での位置やリプシッツ感度に応じて調整される。

関連研究

この構成のより詳細な理論は以下で扱われています:

直積多様体の可視化

図7:多様体の直積。
線(1次元多様体)と円盤(2次元多様体)の直積は円柱。
各点に円盤が1枚ずつ貼り付けられた形をしている。