目录

简述

开始

转换

自动分级

使用 jit 进行编译

使用 vmap 自动矢量化

使用 pmap 进行 SPMD 编程


简述

JAX是谷歌开发的一个机器学习库,专门用于高性能的数值计算和自动微分。它建立在NumPy、SciPy和Cython的基础上,并结合了XLA(Accelerated Linear Algebra)编译器,可以实现高效的计算。

JAX的一个核心特性是支持自动微分。它可以自动计算函数的梯度和高阶导数,这对于训练神经网络等优化问题非常有用。JAX还支持将Python函数转换为可以在GPU或TPU上运行的高性能代码,并且通过JIT(Just-In-Time Compilation)技术可以实现动态编译,提高计算效率。

JAX的设计理念是简洁、灵活和可扩展的。它提供了一组简单而强大的API,使得用户可以方便地定义计算图和模型,并且可以通过组合和转换这些API来构建复杂的模型。JAX还提供了一些高级功能,如并行计算、分布式训练和模型评估等,以满足不同应用场景的需求。

JAX是一个开源项目,广泛应用于谷歌的机器学习项目中,如DeepMind的AlphaGo和OpenAI的GPT-3等。它的出现大大简化了机器学习领域的开发和研究工作,提高了计算效率和模型性能。

项目地址为:https://github.com/google/jax

文档地址:JAX: High-Performance Array Computing — JAX documentation


政安晨:【示例演绎】【Python】【Google/JAX】(一)—— 专为高性能与大规模机器学习设计-LMLPHP

开始

JAX 的定位:

一个 Python 库,用于面向加速器的数组计算和程序转换,专为高性能数值计算和大规模机器学习而设计。

凭借其升级版 Autograd,JAX 可以自动区分本地 Python 和 NumPy 函数。它可以通过循环、分支、递归和闭包进行微分,还可以求导数的导数。它支持通过 grad 进行反向模式微分(又称反向传播),也支持正向模式微分,二者可以任意顺序组合。

新功能是,JAX 使用 XLA 在 GPU 和 TPU 上编译和运行 NumPy 程序。编译默认在引擎盖下进行,库调用会得到及时编译和执行。编译和自动微分可任意组合,因此您可以表达复杂的算法,并在不离开 Python 的情况下获得最高性能。您甚至可以使用 pmap 同时为多个 GPU 或 TPU 内核编程,并在整个过程中进行微分。

转换

JAX 的核心是一个用于转换数值函数的可扩展系统。下面是我们最感兴趣的四种变换:grad、jit、vmap 和 pmap。

自动分级

JAX 的 API 与 Autograd 大致相同。最常用的函数是用于反向模式梯度的 grad:

from jax import grad
import jax.numpy as jnp

def tanh(x):  # Define a function
  y = jnp.exp(-2.0 * x)
  return (1.0 - y) / (1.0 + y)

grad_tanh = grad(tanh)  # Obtain its gradient function
print(grad_tanh(1.0))   # Evaluate it at x = 1.0
# prints 0.4199743

您可以通过 grad 区分任何顺序。

print(grad(grad(grad(tanh)))(1.0))
# prints 0.62162673

对于更高级的自动衍射,可以使用 jax.vjp 进行反向模式的向量-雅各布乘积,使用 jax.jvp 进行正向模式的雅各布-向量乘积。两者可以任意相互组合,也可以与其他 JAX 变换组合。

下面是一种组合方法,可以制作一个高效计算全 Hessian 矩阵的函数:

from jax import jit, jacfwd, jacrev

def hessian(fun):
  return jit(jacfwd(jacrev(fun)))

与 Autograd 一样,您可以自由使用 Python 控制结构的差异化:

def abs_val(x):
  if x > 0:
    return x
  else:
    return -x

abs_val_grad = grad(abs_val)
print(abs_val_grad(1.0))   # prints 1.0
print(abs_val_grad(-1.0))  # prints -1.0 (abs_val is re-evaluated)

使用 jit 进行编译

您可以使用 XLA 将您的函数与 jit 进行端到端编译,将其用作 @jit 装饰器或高阶函数。

import jax.numpy as jnp
from jax import jit

def slow_f(x):
  # Element-wise ops see a large benefit from fusion
  return x * x + x * 2.0

x = jnp.ones((5000, 5000))
fast_f = jit(slow_f)
%timeit -n10 -r3 fast_f(x)  # ~ 4.5 ms / loop on Titan X
%timeit -n10 -r3 slow_f(x)  # ~ 14.5 ms / loop (also on GPU via JAX)

您可以随意混合使用 jit 和 grad 以及任何其他 JAX 转换。

使用 jit 会对函数可以使用的 Python 控制流类型造成限制。

使用 vmap 自动矢量化

vmap 是矢量化映射。它具有我们熟悉的沿着数组轴映射函数的语义,但它并没有将循环保留在外部,而是将循环向下推入函数的原始操作中,以获得更好的性能。

使用 vmap 可以让您不必在代码中携带批量维数。

例如,请看这个简单的无批处理神经网络预测函数:

def predict(params, input_vec):
  assert input_vec.ndim == 1
  activations = input_vec
  for W, b in params:
    outputs = jnp.dot(W, activations) + b  # `activations` on the right-hand side!
    activations = jnp.tanh(outputs)        # inputs to the next layer
  return outputs                           # no activation on last layer

我们通常会改写 jnp.dot(activations,W),以便在激活左侧考虑批量维度,但我们编写的这个特定预测函数只适用于单个输入向量。如果我们想同时对一批输入应用此函数,从语义上讲,我们只需写下:

from functools import partial
predictions = jnp.stack(list(map(partial(predict, params), input_batch)))

但一次通过网络推送一个例子会很慢!最好的办法是将计算矢量化,这样我们在每一层做的都是矩阵-矩阵乘法,而不是矩阵-向量乘法。

vmap 函数会为我们完成这种转换。也就是说,如果我们写:

from functools import partial
predictions = jnp.stack(list(map(partial(predict, params), input_batch)))

那么 vmap 函数将把外循环推入函数内部,我们的机器最终将执行矩阵-矩阵乘法,就像我们手工完成批处理一样。

有了 vmap,这个问题就好办了:

per_example_gradients = vmap(partial(grad(loss), params))(inputs, targets)

当然,vmap 可以与 jit、grad 和其他任何 JAX 变换任意组合!我们在 jax.jacfwd、jax.jacrev 和 jax.hessian 中使用 vmap 与正向和反向模式自动微分,以快速计算雅各布矩阵和哈希值矩阵。

使用 pmap 进行 SPMD 编程

要对多个加速器(如多个 GPU)进行并行编程,请使用 pmap。使用 pmap,您可以编写单程序多数据(SPMD)程序,包括快速并行集体通信操作。使用 pmap 意味着您编写的函数将由 XLA 进行编译(类似于 jit),然后在各设备间复制并并行执行。

下面是一个 8 GPU 机器上的示例:

from jax import random, pmap
import jax.numpy as jnp

# Create 8 random 5000 x 6000 matrices, one per GPU
keys = random.split(random.PRNGKey(0), 8)
mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)

# Run a local matmul on each device in parallel (no data transfer)
result = pmap(lambda x: jnp.dot(x, x.T))(mats)  # result.shape is (8, 5000, 5000)

# Compute the mean on each device in parallel and print the result
print(pmap(jnp.mean)(result))
# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157]

除了表达纯粹的地图外,您还可以在设备之间使用快速的集体通信操作:

from functools import partial
from jax import lax

@partial(pmap, axis_name='i')
def normalize(x):
  return x / lax.psum(x, 'i')

print(normalize(jnp.arange(4.)))
# prints [0.         0.16666667 0.33333334 0.5       ]

您甚至可以嵌套 pmap 函数,以实现更复杂的通信模式。
所有这些都是组合在一起的,所以你可以通过并行计算自由地进行区分:

from jax import grad

@pmap
def f(x):
  y = jnp.sin(x)
  @pmap
  def g(z):
    return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()
  return grad(lambda w: jnp.sum(g(w)))(x)

print(f(x))
# [[ 0.        , -0.7170853 ],
#  [-3.1085174 , -0.4824318 ],
#  [10.366636  , 13.135289  ],
#  [ 0.22163185, -0.52112055]]

print(grad(lambda x: jnp.sum(f(x)))(x))
# [[ -3.2369726,  -1.6356447],
#  [  4.7572474,  11.606951 ],
#  [-98.524414 ,  42.76499  ],
#  [ -1.6007166,  -1.2568436]]

在对 pmap 函数进行反向模式微分(例如使用 grad)时,计算的反向过程会像正向过程一样并行化。


让我们时刻关注JAX的动态,也许会给我们惊喜。

04-03 17:14