Hiroki Naganuma

Overview

import jax.numpy as jnp
from jax import grad

def tanh(x):
    """Return the hyperbolic tangent of *x*, written so JAX can differentiate it.

    Uses the identity tanh(a) = (1 - exp(-2a)) / (1 + exp(-2a)) evaluated on
    the magnitude of ``x`` and then restores the sign (tanh is odd). Because
    the exponent is never positive, ``exp`` cannot overflow to ``inf``; the
    direct formula returns NaN for large negative inputs (``inf / inf``).

    Args:
        x: A scalar or array-like of real values.

    Returns:
        ``tanh(x)`` with the same shape as ``x``.
    """
    is_nonneg = x >= 0
    # Take the magnitude with `where` rather than `jnp.abs`: the derivative of
    # `abs` at 0 is 0, which would zero out the gradient of tanh at the origin.
    magnitude = jnp.where(is_nonneg, x, -x)
    y = jnp.exp(-2.0 * magnitude)  # always in (0, 1], so no overflow
    t = (1.0 - y) / (1.0 + y)
    # Odd symmetry: tanh(-a) == -tanh(a).
    return jnp.where(is_nonneg, t, -t)

# `grad` turns a scalar-valued function into a function computing its derivative.
grad_tanh = grad(tanh)
first_derivative_at_one = grad_tanh(1.0)
print(first_derivative_at_one)  # 0.4199743

# `grad` composes with itself: two more applications on top of `grad_tanh`
# give the third derivative of `tanh`.
third_derivative = grad(grad(grad_tanh))
print(third_derivative(1.0))  # 0.62162673

Misc