Github1.3万星,迅猛发展的JAX对比TensorFlow、PyTorch

机器之心报道
机器之心编辑部
JAX是机器学习(ML)领域的新生力量 , 它有望使ML编程更加直观、结构化和简洁 。
在机器学习领域 , 大家可能对TensorFlow和PyTorch已经耳熟能详 , 但除了这两个框架 , 一些新生力量也不容小觑 , 它就是谷歌推出的JAX 。 很多研究者对其寄予厚望 , 希望它可以取代TensorFlow等众多机器学习框架 。
JAX最初由谷歌大脑团队的MattJohnson、RoyFrostig、DougalMaclaurin和ChrisLeary等人发起 。
目前 , JAX在GitHub上已累积13.7K星 。
Github1.3万星,迅猛发展的JAX对比TensorFlow、PyTorch
文章图片
项目地址:https://GitHub.com/google/jax
迅速发展的JAX
JAX的前身是Autograd , 其借助Autograd的更新版本 , 并且结合了XLA , 可对Python程序与NumPy运算执行自动微分 , 支持循环、分支、递归、闭包函数求导 , 也可以求三阶导数;依赖于XLA , JAX可以在GPU和TPU上编译和运行NumPy程序;通过grad , 可以支持自动模式反向传播和正向传播 , 且二者可以任意组合成任何顺序 。
Github1.3万星,迅猛发展的JAX对比TensorFlow、PyTorch
文章图片
开发JAX的出发点是什么?说到这 , 就不得不提NumPy 。 NumPy是Python中的一个基础数值运算库 , 被广泛使用 。 但是numpy不支持GPU或其他硬件加速器 , 也没有对反向传播的内置支持 , 此外 , Python本身的速度限制阻碍了NumPy使用 , 所以少有研究者在生产环境下直接用numpy训练或部署深度学习模型 。
在此情况下 , 出现了众多的深度学习框架 , 如PyTorch、TensorFlow等 。 但是numpy具有灵活、调试方便、API稳定等独特的优势 。 而JAX的主要出发点就是将numpy的以上优势与硬件加速结合 。
目前 , 基于JAX已有很多优秀的开源项目 , 如谷歌的神经网络库团队开发了Haiku , 这是一个面向Jax的深度学习代码库 , 通过Haiku , 用户可以在Jax上进行面向对象开发;又比如RLax , 这是一个基于Jax的强化学习库 , 用户使用RLax就能进行Q-learning模型的搭建和训练;此外还包括基于JAX的深度学习库JAXnet , 该库一行代码就能定义计算图、可进行GPU加速 。 可以说 , 在过去几年中 , JAX掀起了深度学习研究的风暴 , 推动了科学研究迅速发展 。
JAX的安装
如何使用JAX呢?首先你需要在Python环境或Googlecolab中安装JAX , 使用pip进行安装:
注意 , 上述安装方式只是支持在CPU上运行 , 如果你想在GPU执行程序 , 首先你需要有CUDA、cuDNN , 然后运行以下命令(确保将jaxlib版本映射到CUDA版本):
现在将JAX与Numpy一起导入:
JAX的一些特性
使用grad()函数自动微分:这对深度学习应用非常有用 , 这样就可以很容易地运行反向传播 , 下面为一个简单的二次函数并在点1.0上求导的示例:
Github1.3万星,迅猛发展的JAX对比TensorFlow、PyTorch】jit(Justintime):为了利用XLA的强大功能 , 必须将代码编译到XLA内核中 。 这就是jit发挥作用的地方 。 要使用XLA和jit , 用户可以使用jit()函数或@jit注释 。
pmap:自动将计算分配到所有当前设备 , 并处理它们之间的所有通信 。 JAX通过pmap转换支持大规模的数据并行 , 从而将单个处理器无法处理的大数据进行处理 。 要检查可用设备 , 可以运行jax.devices():
vmap:是一种函数转换 , JAX通过vmap变换提供了自动矢量化算法 , 大大简化了这种类型的计算 , 这使得研究人员在处理新算法时无需再去处理批量化的问题 。 示例如下:
TensorFlowvsPyTorchvsJax
在深度学习领域有几家巨头公司 , 他们所提出的框架被广大研究者使用 。 比如谷歌的TensorFlow、Facebook的PyTorch、微软的CNTK、亚马逊AWS的MXnet等 。