Hiroki Naganuma

ここでは、Tr(H) を計算する方法についてまとめる。使うのは Hutchinson Trace Estimation。 詳細はこちらを参照。

Code

以下を hessian.py として保存する。


import torch
import numpy as np

def compute_hessian_trace(model, data_loader, num_samples, batch_size, device):
    """
    Huchinson's Hessian trace estimation

    Args:
        model: PyTorch model
        data_loader: DataLoader
        num_samples: Number of samples to use for the estimation
        batch_size: Batch size
        device: Device to run the computation on

    Returns:
        float: Average Hessian trace
    """

    model.eval()
    hessian_traces = []

    for _ in range(num_samples // batch_size):
        batch = next(iter(data_loader))
        batch = {k: v.to(device) for k, v in batch.items()}

        # Forward pass
        outputs = model(**batch)
        loss = outputs.loss

        # Hessian trace computation
        hessian_trace = _compute_hessian_trace(model, loss)
        hessian_traces.append(hessian_trace)

    return np.mean(hessian_traces)


def _compute_hessian_trace(model, loss):
    """
    Function to compute the Hessian trace of a PyTorch model

    Args:
        model: PyTorch model
        loss: Loss tensor

    Returns:
        float: trace of the Hessian
    """

    filtered_params = []
    # Filter out the parameters that we don't want to compute the Hessian for
    for name, param in model.named_parameters():
        # TODO / Modify for your model
        if not (name in ['bert.pooler.dense_act.weight', 'bert.pooler.dense_act.bias', 'cls.predictions.bias']):
            filtered_params.append(param)

    # Compute the gradient of the loss w.r.t. the parameters = jacobian
    # ∇L
    j_vals = torch.autograd.grad(
        loss, filtered_params, create_graph=True, allow_unused=True
    )

    # Initialize random vectors as same size as j_vals
    # v
    rand_vec_list = []
    for ind, j_val in enumerate(j_vals):
        if j_val is not None:
            rand_vec_list.append(torch.zeros_like(j_val).normal_(0, 1))

    grad_dot = 0
    # Compute the dot product of the jacobian and the random vector
    # ∇L * v
    for ind, (j_val, vec) in enumerate(zip(j_vals, rand_vec_list)):
        if j_val is not None:
            grad_dot += torch.sum(j_val * vec)

    # Compute the Hessian-vector product
    # H * v = ∇(∇L * v)
    hessian_vec_prod_dict = torch.autograd.grad(
        grad_dot, filtered_params, allow_unused=True
    )

    hvp_sum = 0
    # Compute the dot product of the random vector and the Hessian-vector product
    # v * H * v = v * ∇(∇L * v)
    # Then, sum all the values
    # ∑ v * H * v
    for ind, (vec, hv) in enumerate(zip(rand_vec_list, hessian_vec_prod_dict)):
        if hv is not None:
            hvp_sum += torch.sum(hv * vec).item()

    return hvp_sum

使い方は以下の通り。


import hessian

"""
define model, train_loader, device
Then, run the following code
"""

print("Calc Hessian")
hessian_trace = hessian.compute_hessian_trace(copy.deepcopy(model), 
                                              train_loader, 
                                              num_samples=1000, 
                                              batch_size=16, 
                                              device=device)
print(f"Hessian Trace: {hessian_trace}")

除外すべきレイヤー

Hessian トレースの計算では、モデルの全パラメータに対して勾配を計算する必要があるため、計算コストが非常に高くなります。そのため、Hessian トレースの計算から除外しても精度に大きな影響を与えないレイヤーを特定し、計算コストを削減することが重要です。 一般的に、以下のレイヤーは Hessian トレースの計算から除外されることが多いです。

モデル別の除外レイヤーの例

モデル 除外されるレイヤー
BERT bert.pooler, cls.predictions.bias
RoBERTa roberta.pooler, lm_head.bias
GPT-2 lm_head.bias
ResNet Batch Normalization レイヤーの bias パラメータ