TensorFlow,危!抛弃者正是谷歌自己

萧箫丰色发自凹非寺
量子位|公众号QbitAI收获接近16.6万个Star、见证深度学习崛起的TensorFlow , 地位已岌岌可危 。
并且这次 , 冲击不是来自老对手PyTorch , 而是自家新秀JAX 。
最新一波AI圈热议中 , 连fast.ai创始人JeremyHoward都下场表示:
JAX正逐渐取代TensorFlow这件事 , 早已广为人知了 。 现在它就在发生(至少在谷歌内部是这样) 。
TensorFlow,危!抛弃者正是谷歌自己
文章图片
LeCun更是认为 , 深度学习框架之间的激烈竞争 , 已经进入了一个新的阶段 。
TensorFlow,危!抛弃者正是谷歌自己
文章图片
LeCun表示 , 当初谷歌的TensorFlow确实比Torch更火 。 然而Meta的PyTorch出现之后 , 现在其受欢迎程度已经超过TensorFlow了 。
现在 , 包括GoogleBrain、DeepMind以及不少外部项目 , 都已经开始用上JAX 。
典型例子就是最近爆火的DALL·EMini , 为了充分利用TPU , 作者采用了JAX进行编程 。 有人用过后感叹:
这可比PyTorch快多了 。
TensorFlow,危!抛弃者正是谷歌自己
文章图片
据《商业内幕》透露 , 预计在未来几年内 , JAX将覆盖谷歌所有采用机器学习技术的产品 。
这样看来 , 如今大力在内部推广JAX , 更像是谷歌在框架上发起的一场“自救” 。
JAX从何而来?关于JAX , 谷歌其实是有备而来 。
早在2018年的时候 , 它就由谷歌大脑的一个三人小团队给搭出来了 。
研究成果发表在了题为Compilingmachinelearningprogramsviahigh-leveltracing的论文中:
TensorFlow,危!抛弃者正是谷歌自己
文章图片
Jax是一个用于高性能数值计算的Python库 , 而深度学习只是其中的功能之一 。
TensorFlow,危!抛弃者正是谷歌自己
文章图片
自诞生以来 , 它受欢迎的程度就一直在上升 。
最大的特点就是快 。
一个例子感受一下 。
比如求矩阵的前三次幂的和 , 用NumPy实现 , 计算需要约478毫秒 。
TensorFlow,危!抛弃者正是谷歌自己
文章图片
用JAX就只需要5.54毫秒 , 比NumPy快86倍 。
TensorFlow,危!抛弃者正是谷歌自己
文章图片
为什么这么快?原因有很多 , 包括:
1、NumPy加速器 。 NumPy的重要性不用多说 , 用Python搞科学计算和机器学习 , 没人离得开它 , 但它原生一直不支持GPU等硬件加速 。
JAX的计算函数API则全部基于NumPy , 可以让模型很轻松在GPU和TPU上运行 。 这一点就拿捏住了很多人 。
2、XLA 。 XLA(AcceleratedLinearAlgebra)就是加速线性代数 , 一个优化编译器 。 JAX建立在XLA之上 , 大幅提高了JAX计算速度的上限 。
3、JIT 。 研究人员可使用XLA将自己的函数转换为实时编译(JIT)版本 , 相当于通过向计算函数添加一个简单的函数修饰符 , 就可以将计算速度提高几个数量级 。
除此之外 , JAX与Autograd完全兼容 , 支持自动差分 , 通过grad、hessian、jacfwd和jacrev等函数转换 , 支持反向模式和正向模式微分 , 并且两者可以任意顺序组成 。
当然 , JAX也是有一些缺点在身上的 。
比如:
1、虽然JAX以加速器著称 , 但它并没有针对CPU计算中的每个操作进行充分优化 。
2、JAX还太新 , 没有形成像TensorFlow那样完整的基础生态 。 因此它还没有被谷歌以成型产品的形式推出 。
3、debug需要的时间和成本不确定 , “副作用”也不完全明确 。