import jax.numpy as jnp from jax import grad def tanh(x): y = jnp.exp(-2.0 * x) return (1.0 - y) / (1.0 + y) grad_tanh = grad(tanh) print(grad_tanh(1.0)) # 0.4199743 print(grad(grad(grad(tanh)))(1.0)) # 0.62162673
vmap() や pmap()
vmap()
pmap()
JAX入門~高速なNumPyとして使いこなすためのチュートリアル~
JAX学習記録③ーAutomatic Vectorization and Differentiation
JAX/Flaxを使ってMNISTを学習させてみる