vlambda博客
学习文章列表

深入理解TVM:Python/C++互调(下)

一、基本原理

TVM使用python的ctypes模块来调用c++代码提供的API,ctypes是python内建的可以用于调用C/C++动态链接库函数的功能模块,ctypes官方文档(https://docs.python.org/3/library/ctypes.html)是这样介绍的:

ctypes is a foreign function library for Python.It provides C compatible data types, and allows calling functions in DLLs or shared libraries. It can be used to wrap these libraries in pure Python.

对于动态链接库提供的API,需要使用符合c语言编译和链接约定的API,因为python的ctype只和c兼容,而c++编译器会对函数和变量名进行name mangling,所以需要使用__cplusplus宏和extern "C"来得到符合c语言编译和链接约定的API,以TVM给python提供的接口为例:

// TVM给python提供的接口主要都在这个文件:// include/tvm/runtime/c_runtime_api.h,// 下面主要展示了__cplusplus和extern "C"的用法,// 以及几个关键的API。#ifdef __cplusplusextern "C" {#endif
int TVMFuncListGlobalNames(...);int TVMFuncGetGlobal(...);int TVMFuncCall(...);    #ifdef __cplusplus} // TVM_EXTERN_C#endif

二、加载TVM动态库

TVM的python代码从python/tvm/__init__.py中开始真正执行,即:
from ._ffi.base import TVMError, __version__

这句简单的import代码,会执行python/tvm/_ffi/__init__.py:

from .base import register_errorfrom .registry import register_funcfrom .registry import _init_api, get_global_func

上面的第一句,会导致python/tvm/_ffi/base.py中的下面代码被执行:

def _load_lib(): lib = ctypes.CDLL(lib_path[0], ctypes.RTLD_GLOBAL) return lib, os.path.basename(lib_path[0])
_LIB, _LIB_NAME = _load_lib()

上面的lib_path[0]是TVM动态链接库的全路径名称,我是在linux系统做的试验,链接库的名称是/xxx/libtvm.so(不同的系统动态库的名字会有所不同,windows系统是.dll,苹果系统是.dylib,linux系统是.so),在_load_lib函数执行完成后,_LIB和_LIB_NAME都完成了初始化,其中_LIB是一个ctypes.CDLL类型的变量,可以认为它是能够操作TVM动态链接库的export symbols的一个全局句柄,_LIB_NAME是libtvm.so这个字符串。这样后续在python中,我们就能通过_LIB这个桥梁不断的和c++的部分进行交互。

三、python怎么关联c++的PackedFunc

在这个系列的中,已经对c++中的PackedFunc做了详细的剖析,这里主要来理清楚python的代码中是怎么使用这个核心组件的,还是通过代码,一步步来看。

python中来获取c++API的底层函数是_get_global_func:

# python/tvm/_ffi/_ctypes/packed_func.pydef _get_global_func(func_name): handle = ctypes.c_void_p()    _LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(handle))    return _make_packed_func(handle, False)

这里面handle是一个相当于void类型的指针变量,因为从ctypes的官方文档中可以查到,c_void_p对应的primitive C compatible data type是:

ctype type
c type python type
c_void_p
void * int or None

_get_global_func中调用了TVMFuncGetGlobal这个API,看下这个API的实现就可以发现,handle最终保存了一个c++代码在堆中new出来的PackedFunc对象指针:

// src/runtime/registry.ccint TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) { const tvm::runtime::PackedFunc* fp       = tvm::runtime::Registry::Get(name);  *out = new tvm::runtime::PackedFunc(*fp);}

和c++PackedFunc的关联工作这时候才完成一半,在_get_global_func的最后调用了_make_packed_func这个函数:

# python/tvm/_ffi/_ctypes/packed_func.pydef _make_packed_func(handle, is_global): obj = PackedFunc.__new__(PackedFuncBase) obj.is_global = is_global obj.handle = handle return obj

可以看到_make_packed_func函数中创建了一个定义在python/tvm/runtime/packed_func.py中的python PackedFunc对象,PackedFunc其实是一个空实现,它继承自PackedFuncBase类,PackedFuncBase类中定义了一个__call__函数:

# python/tvm/_ffi/_ctypes/packed_func.pyclass PackedFuncBase(object):  def __call__(self, *args):    values, tcodes, num_args = _make_tvm_args(args, temp_args)    ret_val = TVMValue()    ret_tcode = ctypes.c_int()    _LIB.TVMFuncCall( self.handle, values, tcodes, ctypes.c_int(num_args), ctypes.byref(ret_val), ctypes.byref(ret_tcode),    )    return ret_val
// src/runtime/c_runtime_api.ccint TVMFuncCall(TVMFunctionHandle handle, TVMValue* args, ...)  (*static_cast<const PackedFunc*>(handle))      .CallPacked(TVMArgs(args, arg_type_codes, num_args), &rv);}

这样就完成了把c++中的PackedFunc映射到了python中的PackedFunc,在python代码中只需要调用python中创建好的PackedFunc对象,就会通过上面分析的过程来一步步调到c++的代码中。

四、把注册的函数关联到python各个模块

注册的函数既包括c++中注册的函数,也包括python中注册的函数,其中主要是c++中注册的函数,通过list_global_func_names函数(实际上调用的TVMFuncListGlobalNames这个c++API)可以得到c++中注册的所有函数,目前有1500多个,截图了最开始的十个作为示例给大家看一下:

先看_init_api这个函数,这个函数是把注册函数关联到各个模块的关键:

# python/tvm/_ffi/registry.pydef _init_api(prefix, module_name): target_module = sys.modules[module_name]
for name in list_global_func_names(): if not name.startswith(prefix):            continue        fname = name[len(prefix) + 1 :] f = get_global_func(name) ff = _get_api(f) ff.__name__ = fname ff.__doc__ = "TVM PackedFunc %s. " % fname setattr(target_module, ff.__name__, ff)

这里面有三个最主要的点:

  • line3:sys.modules是一个全局字典,每当程序员导入新的模块,sys.modules将自动记录该模块。 当第二次再导入该模块时,python会直接到字典中查找,从而加快了程序运行的速度。

  • line13:把前面代码构造的python端PackedFunc对象作为属性设置到相应的模块上

然后各个模块中对_init_api来全局调用一次,就完成了关联,我在代码中找了几个作为示例,如下所示:

# python/tvm/runtime/_ffi_api.pytvm._ffi._init_api("runtime", __name__)
# python/tvm/relay/op/op.pytvm._ffi._init_api("relay.op", __name__)
# python/tvm/relay/backend/_backend.pytvm._ffi._init_api("relay.backend", __name__)

五、举一个例子

以TVM中求绝对值的函数abs为例,这个函数实现在tir模块,函数的功能很简单,不会造成额外的理解负担,我们只关注从python调用是怎么映射到c++中的,先看在c++中abs函数的定义和注册:

// src/tir/op/op.cc// 函数定义PrimExpr abs(PrimExpr x, Span span) { ... }
// 函数注册TVM_REGISTER_GLOBAL("tir.abs").set_body_typed(tvm::abs);

再看python端的调用:

# python/tvm/tir/_ffi_api.py# 把c++ tir中注册的函数以python PackedFunc# 对象的形式关联到了_ffi_api这个模块tvm._ffi._init_api("tir", __name__)
# python/tvm/tir/op.py# 定义了abs的python函数,其实内部调用了前面# 关联到_ffi_api这个模块的python PackedFunc对象def abs(x, span=None): return _ffi_api.abs(x, span)

最后用户可以这样来使用这个函数:

import tvmfrom tvm import tir
rlt = tir.abs(-100)print("abs(-100) = %d" % (rlt)

六、最后

现在为止,python/c++互调这个系列就讲完了,后续还会继续写TVM为主题的文章,自己的理解有限,这里面也许有说的不对的地方,欢迎大家留言指出,最后附上前两篇的链接: