types-paddle: 为Paddle增加Tensor类型注释特性

简介: Paddle中没有Tensor类,导致在IDE中写代码时没有办法进行智能提示,我提供了一个解决方案。

Paddle中没有Tensor类,导致在IDE中写代码时没有办法进行智能提示,我提供了一个解决方案。

types-paddle-mini.gif

思路

  • 修改已安装Paddle包的paddle/tensor/tensor.py文件,添加未实现的Tensor类。
  • 添加tensor.pyi文件到paddle包当中,从而实现代码智能提示。

类型注释的三种解决方案

在python当中有三种方式给代码提供类型注释,从而让IDE能够实现智能提示:

  1. 直接在代码中写上类型注释
def add(x: int, y: int) -> int:
    return x + y

此方法也是在python3.7+中最为推荐的方式。

  1. 原代码中并没有类型注释,便在包中添加pyi文件
什么是 pyi文件?可理解为 python interface文件,为某个python module提供接口定义信息。
# foo.py
def add(x, y):
    return x + y
# foo.pyi
def add(x: int, y: int) -> int: ...

此情况下,pyi文件名必须和py文件名一致,这样浏览器在加载原始文件类型信息时,直接从pyi文件中加载。其中PyGithub就是使用这种方式来提供类型注释。

  1. 作者不想把pyi stub 文件添加到包中

如果每个文件都要添加一个pyi文件,则代码文件数量直接增加一倍,这将会增加维护的难度,于是可将pyi文件通过第三方包发布。 详细原理可阅读PEP 561 – Distributing and Packaging Type Information

为什么选择这个解决方案?

面临的问题

我最初的做法也是使用第三种方法,可是发现如果用第三方包发布的话,paddle的所有类型提示将会从我的包走:也就是说我要给paddle所有的module都添加上pyi stub 文件。 这工作量很大,而且很多模块都在发生变动,我没有办法及时获取到最新的变动,很容易导致版本接口不兼容的问题。

解决方案

于是,我返回使用第二种方法,只是使用types-paddle来修改paddle包从而支持类型注释。

pyi 是如何生成的?

其中最核心的模块属于pyi文件是如何生成的?毕竟这个代表着Tensor所有类的所有属性:根据runtime Tensor来生成,伪代码如下所示:

import paddle
tensor = paddle.randn([3,4])
members = get_members(tensor)
gen_stub_file_by_tensor_member(members)

详细可见:gen_tensor_stub.py

这个包的未来

我相信,未来paddle肯定是会支持类型注释的,毕竟原始paddle/tensor/tensor.py文件已经写上了TODO。只是该特性还没有完成的时候,这个工具可以提升大家的编码体验。

希望这个工具能够让大家写paddle越来越爽。

参考资料

相关文章
|
3月前
|
TensorFlow 算法框架/工具
【Tensorflow+Keras】tf.keras.backend.image_data_format()的解析与举例使用
介绍了TensorFlow和Keras中tf.keras.backend.image_data_format()函数的用法。
46 5
|
12月前
|
Docker 容器
求助: 运行模型时报错module 'megatron_util.mpu' has no attribute 'get_model_parallel_rank'
运行ZhipuAI/Multilingual-GLM-Summarization-zh的官方代码范例时,报错AttributeError: MGLMTextSummarizationPipeline: module 'megatron_util.mpu' has no attribute 'get_model_parallel_rank' 环境是基于ModelScope官方docker镜像,尝试了各个版本结果都是一样的。
398 5
|
11月前
问题出在`megatron_util.mpu`模块中没有找到`get_model_parallel_rank`属性
问题出在`megatron_util.mpu`模块中没有找到`get_model_parallel_rank`属性
111 1
torch.argmax(dim=1)用法
)torch.argmax(input, dim=None, keepdim=False)返回指定维度最大值的序号;
630 0
|
API 数据格式
TensorFlow2._:model.summary() Output Shape为multiple解决方法
TensorFlow2._:model.summary() Output Shape为multiple解决方法
274 0
TensorFlow2._:model.summary() Output Shape为multiple解决方法
Paddle 1.8 与 Paddle 2.0 API 映射表
本文档基于 Paddle 1.8 梳理了常用 API 与 Paddle 2.0 对应关系。你可以根据对应关系,快速熟悉 Paddle 2.0 的接口使用。
|
机器学习/深度学习 人工智能 PyTorch
PyTorch与Paddle映射表
PyTorch与Paddle映射表
172 0
|
存储 测试技术
测试模型时,为什么要with torch.no_grad(),为什么要model.eval(),如何使用with torch.no_grad(),model.eval(),同时使用还是只用其中之一
在测试模型时,我们通常使用with torch.no_grad()和model.eval()这两个方法来确保模型在评估过程中的正确性和效率。
944 0
|
并行计算 编译器 TensorFlow
[不用回退keras版本的解决方法]AttributeError: module 'tensorflow.python.ops.nn' has no attribute 'leaky_relu'
[不用回退keras版本的解决方法]AttributeError: module 'tensorflow.python.ops.nn' has no attribute 'leaky_relu'
269 0
[不用回退keras版本的解决方法]AttributeError: module 'tensorflow.python.ops.nn' has no attribute 'leaky_relu'
|
机器学习/深度学习 PyTorch 算法框架/工具
pytorch中nn.Parameter()使用方法
pytorch中nn.Parameter()使用方法
1347 1