Skip to the content.

Natural Gradient

Original Paper

FIM

Two Approximation Methods

K-FAC (Kronecker-factored Approximate Curvature) :

One of approximation methods for Natural Gradient.

Original Paper

Screenshot 2022-03-17 at 3 06 25 PMScreenshot 2022-03-17 at 3 14 56 PM

ML で使われる Fisher は x の条件付き分布における Fisher 情報行列 (他分野では Empirical Fisher と言われるけど、 K-FAC とか機械学習の最適化コミュニティでは True Fisher として扱われてる、唐木田さんの論文や理論の人は忠実な Fisher のみ True Fisher としてる) 機械学習の最適化コミュニティでは、Covariance のことを Empirical Fisher という

実装

疑問 -> Answer は最後に PDF を添付

    for batch_idx, (inputs, targets) in prog_bar:
        inputs, targets = inputs.to(args.device), targets.to(args.device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        if optim_name in ['kfac', 'ekfac'] and optimizer.steps % optimizer.TCov == 0:
            # compute true fisher
            optimizer.acc_stats = True
            with torch.no_grad():
                # 一番確率の高い sample を持ってくる / 厳密には pred の分布をもとに バッチ内の output 1つにつき、1だけ smaple を持ってくる
                sampled_y = torch.multinomial(torch.nn.functional.softmax(outputs.cpu().data, dim=1),1).squeeze().cuda() 
            loss_sample = criterion(outputs, sampled_y)
            loss_sample.backward(retain_graph=True)
            optimizer.acc_stats = False
            optimizer.zero_grad()  # clear the gradient for computing true-fisher.
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
>>> weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) 重みのテンソルを作成する
>>> torch.multinomial(weights, 2)
tensor([1, 2])
>>> torch.multinomial(weights, 4) # ERROR!
RuntimeError: invalid argument 2: invalid multinomial distribution (with replacement=False,
not enough non-negative category to sample) at ../aten/src/TH/generic/THTensorRandom.cpp:320
>>> torch.multinomial(weights, 4, replacement=True)
tensor([ 2,  1,  1,  1])

I would like to thank Chaoqi Wang for his help on clarifing my understanding.