目录
简述
JAX是谷歌开发的一个机器学习库,专门用于高性能的数值计算和自动微分。它建立在NumPy、SciPy和Cython的基础上,并结合了XLA(Accelerated Linear Algebra)编译器,可以实现高效的计算。
JAX的一个核心特性是支持自动微分。它可以自动计算函数的梯度和高阶导数,这对于训练神经网络等优化问题非常有用。JAX还支持将Python函数转换为可以在GPU或TPU上运行的高性能代码,并且通过JIT(Just-In-Time Compilation)技术可以实现动态编译,提高计算效率。
JAX的设计理念是简洁、灵活和可扩展的。它提供了一组简单而强大的API,使得用户可以方便地定义计算图和模型,并且可以通过组合和转换这些API来构建复杂的模型。JAX还提供了一些高级功能,如并行计算、分布式训练和模型评估等,以满足不同应用场景的需求。
JAX是一个开源项目,广泛应用于谷歌的机器学习项目中,如DeepMind的AlphaGo和OpenAI的GPT-3等。它的出现大大简化了机器学习领域的开发和研究工作,提高了计算效率和模型性能。
项目地址为:https://github.com/google/jax
文档地址:JAX: High-Performance Array Computing — JAX documentation
开始
JAX 的定位:
一个 Python 库,用于面向加速器的数组计算和程序转换,专为高性能数值计算和大规模机器学习而设计。
凭借其升级版 Autograd,JAX 可以自动区分本地 Python 和 NumPy 函数。它可以通过循环、分支、递归和闭包进行微分,还可以求导数的导数。它支持通过 grad 进行反向模式微分(又称反向传播),也支持正向模式微分,二者可以任意顺序组合。
新功能是,JAX 使用 XLA 在 GPU 和 TPU 上编译和运行 NumPy 程序。编译默认在引擎盖下进行,库调用会得到及时编译和执行。编译和自动微分可任意组合,因此您可以表达复杂的算法,并在不离开 Python 的情况下获得最高性能。您甚至可以使用 pmap 同时为多个 GPU 或 TPU 内核编程,并在整个过程中进行微分。
转换
JAX 的核心是一个用于转换数值函数的可扩展系统。下面是我们最感兴趣的四种变换:grad、jit、vmap 和 pmap。
自动分级
JAX 的 API 与 Autograd 大致相同。最常用的函数是用于反向模式梯度的 grad:
from jax import grad
import jax.numpy as jnp
def tanh(x): # Define a function
y = jnp.exp(-2.0 * x)
return (1.0 - y) / (1.0 + y)
grad_tanh = grad(tanh) # Obtain its gradient function
print(grad_tanh(1.0)) # Evaluate it at x = 1.0
# prints 0.4199743
您可以通过 grad 区分任何顺序。
print(grad(grad(grad(tanh)))(1.0))
# prints 0.62162673
对于更高级的自动衍射,可以使用 jax.vjp 进行反向模式的向量-雅各布乘积,使用 jax.jvp 进行正向模式的雅各布-向量乘积。两者可以任意相互组合,也可以与其他 JAX 变换组合。
下面是一种组合方法,可以制作一个高效计算全 Hessian 矩阵的函数:
from jax import jit, jacfwd, jacrev
def hessian(fun):
return jit(jacfwd(jacrev(fun)))
与 Autograd 一样,您可以自由使用 Python 控制结构的差异化:
def abs_val(x):
if x > 0:
return x
else:
return -x
abs_val_grad = grad(abs_val)
print(abs_val_grad(1.0)) # prints 1.0
print(abs_val_grad(-1.0)) # prints -1.0 (abs_val is re-evaluated)
使用 jit 进行编译
您可以使用 XLA 将您的函数与 jit 进行端到端编译,将其用作 @jit 装饰器或高阶函数。
import jax.numpy as jnp
from jax import jit
def slow_f(x):
# Element-wise ops see a large benefit from fusion
return x * x + x * 2.0
x = jnp.ones((5000, 5000))
fast_f = jit(slow_f)
%timeit -n10 -r3 fast_f(x) # ~ 4.5 ms / loop on Titan X
%timeit -n10 -r3 slow_f(x) # ~ 14.5 ms / loop (also on GPU via JAX)
您可以随意混合使用 jit 和 grad 以及任何其他 JAX 转换。
使用 jit 会对函数可以使用的 Python 控制流类型造成限制。
使用 vmap 自动矢量化
vmap 是矢量化映射。它具有我们熟悉的沿着数组轴映射函数的语义,但它并没有将循环保留在外部,而是将循环向下推入函数的原始操作中,以获得更好的性能。
使用 vmap 可以让您不必在代码中携带批量维数。
例如,请看这个简单的无批处理神经网络预测函数:
def predict(params, input_vec):
assert input_vec.ndim == 1
activations = input_vec
for W, b in params:
outputs = jnp.dot(W, activations) + b # `activations` on the right-hand side!
activations = jnp.tanh(outputs) # inputs to the next layer
return outputs # no activation on last layer
我们通常会改写 jnp.dot(activations,W),以便在激活左侧考虑批量维度,但我们编写的这个特定预测函数只适用于单个输入向量。如果我们想同时对一批输入应用此函数,从语义上讲,我们只需写下:
from functools import partial
predictions = jnp.stack(list(map(partial(predict, params), input_batch)))
但一次通过网络推送一个例子会很慢!最好的办法是将计算矢量化,这样我们在每一层做的都是矩阵-矩阵乘法,而不是矩阵-向量乘法。
vmap 函数会为我们完成这种转换。也就是说,如果我们写:
from functools import partial
predictions = jnp.stack(list(map(partial(predict, params), input_batch)))
那么 vmap 函数将把外循环推入函数内部,我们的机器最终将执行矩阵-矩阵乘法,就像我们手工完成批处理一样。
有了 vmap,这个问题就好办了:
per_example_gradients = vmap(partial(grad(loss), params))(inputs, targets)
当然,vmap 可以与 jit、grad 和其他任何 JAX 变换任意组合!我们在 jax.jacfwd、jax.jacrev 和 jax.hessian 中使用 vmap 与正向和反向模式自动微分,以快速计算雅各布矩阵和哈希值矩阵。
使用 pmap 进行 SPMD 编程
要对多个加速器(如多个 GPU)进行并行编程,请使用 pmap。使用 pmap,您可以编写单程序多数据(SPMD)程序,包括快速并行集体通信操作。使用 pmap 意味着您编写的函数将由 XLA 进行编译(类似于 jit),然后在各设备间复制并并行执行。
下面是一个 8 GPU 机器上的示例:
from jax import random, pmap
import jax.numpy as jnp
# Create 8 random 5000 x 6000 matrices, one per GPU
keys = random.split(random.PRNGKey(0), 8)
mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)
# Run a local matmul on each device in parallel (no data transfer)
result = pmap(lambda x: jnp.dot(x, x.T))(mats) # result.shape is (8, 5000, 5000)
# Compute the mean on each device in parallel and print the result
print(pmap(jnp.mean)(result))
# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157]
除了表达纯粹的地图外,您还可以在设备之间使用快速的集体通信操作:
from functools import partial
from jax import lax
@partial(pmap, axis_name='i')
def normalize(x):
return x / lax.psum(x, 'i')
print(normalize(jnp.arange(4.)))
# prints [0. 0.16666667 0.33333334 0.5 ]
您甚至可以嵌套 pmap 函数,以实现更复杂的通信模式。
所有这些都是组合在一起的,所以你可以通过并行计算自由地进行区分:
from jax import grad
@pmap
def f(x):
y = jnp.sin(x)
@pmap
def g(z):
return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()
return grad(lambda w: jnp.sum(g(w)))(x)
print(f(x))
# [[ 0. , -0.7170853 ],
# [-3.1085174 , -0.4824318 ],
# [10.366636 , 13.135289 ],
# [ 0.22163185, -0.52112055]]
print(grad(lambda x: jnp.sum(f(x)))(x))
# [[ -3.2369726, -1.6356447],
# [ 4.7572474, 11.606951 ],
# [-98.524414 , 42.76499 ],
# [ -1.6007166, -1.2568436]]
在对 pmap 函数进行反向模式微分(例如使用 grad)时,计算的反向过程会像正向过程一样并行化。
让我们时刻关注JAX的动态,也许会给我们惊喜。