ここでは、Tr(H) を計算する方法についてまとめる。使うのは Hutchinson Trace Estimation。 詳細はこちらを参照。
以下を hessian.py として保存する。
import torch
import numpy as np
def compute_hessian_trace(model, data_loader, num_samples, batch_size, device):
"""
Huchinson's Hessian trace estimation
Args:
model: PyTorch model
data_loader: DataLoader
num_samples: Number of samples to use for the estimation
batch_size: Batch size
device: Device to run the computation on
Returns:
float: Average Hessian trace
"""
model.eval()
hessian_traces = []
for _ in range(num_samples // batch_size):
batch = next(iter(data_loader))
batch = {k: v.to(device) for k, v in batch.items()}
# Forward pass
outputs = model(**batch)
loss = outputs.loss
# Hessian trace computation
hessian_trace = _compute_hessian_trace(model, loss)
hessian_traces.append(hessian_trace)
return np.mean(hessian_traces)
def _compute_hessian_trace(model, loss):
"""
Function to compute the Hessian trace of a PyTorch model
Args:
model: PyTorch model
loss: Loss tensor
Returns:
float: trace of the Hessian
"""
filtered_params = []
# Filter out the parameters that we don't want to compute the Hessian for
for name, param in model.named_parameters():
# TODO / Modify for your model
if not (name in ['bert.pooler.dense_act.weight', 'bert.pooler.dense_act.bias', 'cls.predictions.bias']):
filtered_params.append(param)
# Compute the gradient of the loss w.r.t. the parameters = jacobian
# ∇L
j_vals = torch.autograd.grad(
loss, filtered_params, create_graph=True, allow_unused=True
)
# Initialize random vectors as same size as j_vals
# v
rand_vec_list = []
for ind, j_val in enumerate(j_vals):
if j_val is not None:
rand_vec_list.append(torch.zeros_like(j_val).normal_(0, 1))
grad_dot = 0
# Compute the dot product of the jacobian and the random vector
# ∇L * v
for ind, (j_val, vec) in enumerate(zip(j_vals, rand_vec_list)):
if j_val is not None:
grad_dot += torch.sum(j_val * vec)
# Compute the Hessian-vector product
# H * v = ∇(∇L * v)
hessian_vec_prod_dict = torch.autograd.grad(
grad_dot, filtered_params, allow_unused=True
)
hvp_sum = 0
# Compute the dot product of the random vector and the Hessian-vector product
# v * H * v = v * ∇(∇L * v)
# Then, sum all the values
# ∑ v * H * v
for ind, (vec, hv) in enumerate(zip(rand_vec_list, hessian_vec_prod_dict)):
if hv is not None:
hvp_sum += torch.sum(hv * vec).item()
return hvp_sum
使い方は以下の通り。
import hessian
"""
define model, train_loader, device
Then, run the following code
"""
print("Calc Hessian")
hessian_trace = hessian.compute_hessian_trace(copy.deepcopy(model),
train_loader,
num_samples=1000,
batch_size=16,
device=device)
print(f"Hessian Trace: {hessian_trace}")
Hessian トレースの計算では、モデルの全パラメータに対して勾配を計算する必要があるため、計算コストが非常に高くなります。そのため、Hessian トレースの計算から除外しても精度に大きな影響を与えないレイヤーを特定し、計算コストを削減することが重要です。 一般的に、以下のレイヤーは Hessian トレースの計算から除外されることが多いです。
モデル | 除外されるレイヤー |
---|---|
BERT | bert.pooler, cls.predictions.bias |
RoBERTa | roberta.pooler, lm_head.bias |
GPT-2 | lm_head.bias |
ResNet | Batch Normalization レイヤーの bias パラメータ |