vlambda博客
学习文章列表

API统一、干净,适配PyTorch、TF,新型EagerPy实现多框架无缝衔接

选自arXiv

编辑:杜伟、小舟

这个新型 Python 框架对库开发者和用户都大有裨益。


近年来,深度学习领域的进展与深度学习框架的开发同步进行。这些框架为自动微分和 GPU 加速提供了高级且高效的 API,从而可以利用相对较少和简单的代码实现极度复杂和强大的深度学习模型。

最初,Theano、Caffe、MXNet、TensorFlow 和 CNTK 等很多流行的深度学习框架使用的是基于图的方法。用户首先需要定义一个静态数据流图(static data flow graph),然后可以对它进行高效地微分、编译并在 GPU 上执行。所以,提前了解整个计算图有助于实现高性能。

但是,这种方法导致难以调试模型以及实现具有变化图(changing graph)的动态模型(如 RNN)。

所以,针对这种方法的局限性,深度学习模型的 Eager Execution 成为了深度学习研究领域的主流方法。用户不再需要提前构建静态数据流图,Eager Execution 框架自身就可以提供 define-by-run 的 API,它可以高速地构建临时的动态图。目前,两大主流深度学习框架 PyTorch 和 TensorFlow 都在使用 eager execution 方法。

而在本文中,来自德国图宾根大学和图宾根伯恩斯坦计算神经科学中心的研究者将 eager execution 进行了扩展,提供了一个新的 Python 框架 EagerPy,它可以编写自动且原生地适配 PyTorch、TensorFlow、Jax 和 Numpy 的代码。EagerPy 对库开发者和用户都有裨益。

API统一、干净,适配PyTorch、TF,新型EagerPy实现多框架无缝衔接



EagerPy 能够编写与框架无关(framework-agnostic)的代码,这些代码可以与 PyTorch、TensorFlow、Jax 和 NumPy 实现原生地适配。

这样一来,首先对于新库开发者而言,他们不仅可以选择同时支持上述这几个主流深度学习框架或者为每个框架重新实现库,而且可以对代码重复进行处理。

其次对于这些库的使用者而言,他们也可以更轻松地切换深度学习框架,并且不会被特定的第三方库锁定。

不仅如此,单个框架的使用者也会从 EagerPy 中获益,这是因为 EagerPy 提供了全面的类型注释以及对方法链接到任何框架的一致支持。

接下来我们来看 EagerPy 的具体设计与实现。

EagerPy 的设计与实现

EagerPy 的构建考虑到了 4 个设计目标。两个主要的目标是为需要执行操作的人提供统一的 API,并维护底层框架的原始性能。这两个主要目标定义了 EagerPy 是什么,所以是设计的核心。

与底层框架特定的 API 相比,完全可链接的 API 和全面的类型检查支持这两个附加目标使 EagerPy 更加易于使用,也更安全。

尽管进行了这些更改和改进,但研究者尝试避免不必要的熟悉度(familiarity)损失。只要有意义,EagerPy API 都会遵循 NumPy、PyTorch 和 JAX 设置的标准。

统一的 API

为了实现语法上的一致性,研究者使用适当的方法定义了一个抽象 Tensor 类,并使用一个实例变量来保存原生张量(native tensor),然后为每个支持的框架实现一个特定的子类。对于诸如 sum 或 log 的很多操作,这就像调用底层框架一样简单;而对于其他操作,则工作量会稍大一些。

最困难的部分是统一自动微分 API。PyTorch 使用了一个低级的 autograd API,该 API 允许但也需要对反向传播的精确控制。TensorFlow 使用基于梯度磁带(gradient tapes)的更高级 API。而 JAX 使用基于微分函数的相当高级的 API。

所以,为了统一它们,EagerPy 模仿了 JAX 的高级功能 API,并在 PyTorch 和 TensorFlow 中重新实现。EagerPy 通过 value_and_grad_fn() 函数将其开放。

此外,能够编写自动与所有支持的框架一起运行的代码,不仅需要语法,还需要语义统一。为了确保这一点,EagerPy 附带了一个庞大的测试套件,该套件可以验证不同框架特定子类之间的一致性。它会在所有 pull-request 上自动运行,并且需要通过之后才能合并新代码。

测试套件还可以作为所支持的操作和参数组合的最终参考。这样就可以避免文档和实现之间出现不一致,并在实践中引出测试驱动开发过程。

原始性能

没有 EagerPy,想要与不同深度学习框架进行交互的代码必须经过 NumPy 实现。这需要在 CPU(NumPy)和 GPU(PyTorch、TensorFlow 和 JAX)之间进行高成本的内存复制,反之亦然。

此外,许多计算仅在 CPU 上执行,为了避免这种情况,EagerPy 仅保留对原始框架特定张量的引用(例如 GPU 上的 PyTorch 张量),并将所有的操作委托给相应的框架。这几乎不产生任何的计算开销。

完全可链接的 API

求和或平方之类的许多运算都要采用张量并返回一个张量。通常情况下,这些运算按顺序被调用。例如使用平方、求和和开平方根以计算 L2 范数。

在 EagerPy 中,所有运算都成为了张量对象(tensor object)上可用的方法。这样就可以按照它们的自然顺序(x.square().sum().sqrt())来链接操作。相反,例如,NumPy 需要相反的操作顺序,即 np.sqrt(np.square(x).sum())。

类型检查

在 Python3.5 中,Python 语法的扩展已经实现了对类型注释的支持(van Rossum 等人,2015 年)。即使具有类型注释,Python 仍然是一种动态类型化的编程语言,并且当前在运行时会忽略所有类型注释。但是,我们可以在运行代码之前通过静态代码分析器检查这些类型注释。

EagerPy 带有所有参数和返回值的全面类型注释,并使用 Mypy(Lehtosalo 等人,2016 年)对这些注释进行检查。这有助于我们捕获 EagerPy 中的漏洞,否则这些漏洞将一直不会被发现。

EagerPy 用户可以通过键入自己代码的注释,并根据 EagerPy 的函数签名(function signature)自动检查代码来进一步优化。这一点很关键,因为 TensorFlow、NumPy 和 JAX 当前自身不提供类型注释。

EagerPy 的代码实例解析

如下代码 1 为一个通用 EagerPy 范数函数,它可以通过任何框架中的原生张量被调用,并且返回的范数依然作为同一个框架中的原生张量。

API统一、干净,适配PyTorch、TF,新型EagerPy实现多框架无缝衔接

代码 1:框架无关的范数函数。

EagerPy 和原生张量之间的转换

原生张量可以是 PyTorch GPU 或 CPU 张量,如下代码 2 所示:

API统一、干净,适配PyTorch、TF,新型EagerPy实现多框架无缝衔接

代码 2:原生 PyTorch 张量。

可以是 TensorFlow 张量,如下代码 3 所示:

API统一、干净,适配PyTorch、TF,新型EagerPy实现多框架无缝衔接

代码 3:原生 TensorFlow 张量。


可以是 JAX 数组,如下代码 4 所示:

API统一、干净,适配PyTorch、TF,新型EagerPy实现多框架无缝衔接

代码 4:原生 JAX 数组。

可以是 NumPy 数组,如下代码 5 所示:

API统一、干净,适配PyTorch、TF,新型EagerPy实现多框架无缝衔接

代码 5:原生 NumPy 数组。

无论是哪种原生张量,通常都可以使用 ep.astensor 将它转换为适当的 EagerPy 张量。在此步骤中,通过使用正确的 EagerPy 张量类来自动封装原生张量。此外,最初的原生张量通常可以利用. raw 属性实现访问。完整示例如下代码 6 所示:

API统一、干净,适配PyTorch、TF,新型EagerPy实现多框架无缝衔接

EagerPy 和原生张量之间的转换。

在函数中通常将所有输入转换为 EagerPy 张量。这可以通过单独调用 ep.astensor 完成,但在使用 ep.astensors 时,代码可以更加简洁,如下:

API统一、干净,适配PyTorch、TF,新型EagerPy实现多框架无缝衔接


实现框架无关的通用函数

通过上文中的转换函数,我们可以定义一个简单的框架无关函数,如下代码 8 所示:

API统一、干净,适配PyTorch、TF,新型EagerPy实现多框架无缝衔接

代码 8:一个简单的框架无关范数函数。

如下代码 9 所示,通过一个 PyTorch 张量来调用范数函数:

API统一、干净,适配PyTorch、TF,新型EagerPy实现多框架无缝衔接


 如下代码 10 所示,通过一个 TensorFlow 张量来调用范数函数:

API统一、干净,适配PyTorch、TF,新型EagerPy实现多框架无缝衔接


此外,还需要注意一点,如果如上代码 8 所示使用 EagerPy 张量来调用函数,则 ep.astensor 调用只会返回它的输入。但是,最后一行代码中的 result.raw 调用依然会提取底层原生张量。通常而言,实现的通用函数最好可以透明地操控任何原生张量和 EagerPy 张量,也就是说返回类型应该总是与输入类型相匹配。

这在 Foolbox 等库中非常有用,可以使用户同时处理 EagerPy 和原生张量。

为此,EagerPy 提供上述转换函数的两种派生函数,分别是 ep.astensor_和 ep.astensors_,它们可以返回一个能够恢复输入类型的反转函数。

如果 astensor_的输入是一个原生张量,则 restore_type 等同于. raw;而如果原输入是一个 EagerPy 张量,则 restore_type 将不会调用. raw。因此,我们可以编写对任何输入都透明的改进版框架无关通用函数,如下代码 11 所示:

API统一、干净,适配PyTorch、TF,新型EagerPy实现多框架无缝衔接


最后,如下代码 12 所示,使用 ep.astensors_来转换和恢复多个输入:


不久之前, KDD 2020 公布了最佳论文、最佳学生论文等多个奖项。其中,最佳学生论文奖由杜克大学的李昂、杨幻睿、陈怡然和北航段逸骁、杨建磊摘得。


为了帮助读者们更细致的了解这篇论文,9月3日最新一期的机器之心线上论文分享邀请到最佳学生论文一作李昂,为我们介绍该研究。



© THE END 

投稿或寻求报道:[email protected]