Github1.3万星,迅猛发展的JAX对比TensorFlow、PyTorch
JAX 是机器学习 (ML) 领域的新生力量,它有望使 ML 编程更加直观、结构化和简洁。
$ pip install --upgrade jax jaxlib
$ pip install --upgrade jax jaxlib==0.1.61+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html
import jax
import jax.numpy as jnp
import numpy as np
from jax import grad
def f(x):
return 3*x**2 + 2*x + 5
def f_prime(x):
return 6*x +2
grad(f)(1.0)
# DeviceArray(8., dtype=float32)
f_prime(1.0)
# 8.0
from jax import jit
x = np.random.rand(1000,1000)
y = jnp.array(x)
def f(x):
for _ in range(10):
x = 0.5*x + 0.1* jnp.sin(x)
return x
g = jit(f)
%timeit -n 5 -r 5 f(y).block_until_ready()
# 5 loops, best of 5: 10.8 ms per loop
%timeit -n 5 -r 5 g(y).block_until_ready()
# 5 loops, best of 5: 341 µs per loop
from jax import pmap
def f(x):
return jnp.sin(x) + x**2
f(np.arange(4))
#DeviceArray([0. , 1.841471 , 4.9092975, 9.14112 ], dtype=float32)
pmap(f)(np.arange(4))
#ShardedDeviceArray([0. , 1.841471 , 4.9092975, 9.14112 ], dtype=float32)
from jax import vmap
def f(x):
return jnp.square(x)
f(jnp.arange(10))
#DeviceArray([ 0, 1, 4, 9, 16, 25, 36, 49, 64, 81], dtype=int32)
vmap(f)(jnp.arange(10))
#DeviceArray([ 0, 1, 4, 9, 16, 25, 36, 49, 64, 81], dtype=int32)
它们是开源的。这意味着如果库中存在错误,使用者可以在 GitHub 中发布问题(并修复),此外你也可以在库中添加自己的功能;
由于全局解释器锁,Python 在内部运行缓慢。所以这些框架使用 C/C++ 作为后端来处理所有的计算和并行过程。
这是一个非常友好的框架,高级 API-Keras 的可用性使得模型层定义、损失函数和模型创建变得非常容易;
TensorFlow2.0 带有 Eager Execution(动态图机制),这使得该库更加用户友好,并且是对以前版本的重大升级;
Keras 这种高级接口有一定的缺点,由于 TensorFlow 抽象了许多底层机制(只是为了方便最终用户),这让研究人员在处理模型方面的自由度更小;
Tensorflow 提供了 TensorBoard,它实际上是 Tensorflow 可视化工具包。它允许研究者可视化损失函数、模型图、模型分析等。
与 TensorFlow 不同,PyTorch 使用动态类型图,这意味着执行图是在运行中创建的。它允许我们随时修改和检查图的内部结构;
除了用户友好的高级 API 之外,PyTorch 还包括精心构建的低级 API,允许对机器学习模型进行越来越多的控制。我们可以在训练期间对模型的前向和后向传递进行检查和修改输出。这被证明对于梯度裁剪和神经风格迁移非常有效;
PyTorch 允许用户扩展代码,可以轻松添加新的损失函数和用户定义的层。PyTorch 的 Autograd 模块实现了深度学习算法中的反向传播求导数,在 Tensor 类上的所有操作, Autograd 都能自动提供微分,简化了手动计算导数的复杂过程;
PyTorch 对数据并行和 GPU 的使用具有广泛的支持;
PyTorch 比 TensorFlow 更 Python 化。PyTorch 非常适合 Python 生态系统,它允许使用 Python 类调试器工具来调试 PyTorch 代码。
正如官方网站所描述的那样,JAX 能够执行 Python+NumPy 程序的可组合转换:向量化、JIT 到 GPU/TPU 等等;
与 PyTorch 相比,JAX 最重要的方面是如何计算梯度。在 Torch 中,图是在前向传递期间创建的,梯度在后向传递期间计算, 另一方面,在 JAX 中,计算表示为函数。在函数上使用 grad() 返回一个梯度函数,该函数直接计算给定输入的函数梯度;
JAX 是一个 autograd 工具,不建议单独使用。有各种基于 JAX 的机器学习库,其中值得注意的是 ObJax、Flax 和 Elegy。由于它们都使用相同的核心并且接口只是 JAX 库的 wrapper,因此可以将它们放在同一个 bracket 下;
Flax 最初是在 PyTorch 生态系统下开发的,更注重使用的灵活性。另一方面,Elegy 受 Keras 启发。ObJAX 主要是为以研究为导向的目的而设计的,它更注重简单性和可理解性。
NVIDIA对话式AI开发工具NeMo的应用
8月12日开始,英伟达专家将带来三期直播分享,通过理论解读和实战演示,展示如何使用 NeMo 快速完成文本分类任务、快速构建智能问答系统、构建智能对话机器人。
直播链接:https://jmq.h5.xeknow.com/s/how4w(点击阅读原文直达)
报名方式:进入直播间——移动端点击底部「观看直播」、PC端点击「立即学习」——填写报名表单后即可进入直播间观看。
交流答疑群:直播间详情页扫码即可加入。
© THE END
投稿或寻求报道:[email protected]