实现缓冲区协议

简介: 实现缓冲区协议


楔子




了解完缓冲区协议之后,我们就来手动实现一下。而实现方式可以使用原生的 Python/C API,也可以使用 Cython,但前者实现起来会非常麻烦,牵扯的知识也非常多,而 Cython 则简化了这一点。

我们来分别介绍一下这两种方式。


使用 Cython 实现缓冲区协议




Cython 对缓冲区协议也有着强大的支持,我们只需要定义一个魔法方法即可实现缓冲区协议。

from cpython cimport Py_buffer
from cpython.mem cimport PyMem_Malloc, PyMem_Free
cdef class Matrix:
    cdef Py_ssize_t shape[2]  # 数组的形状
    cdef Py_ssize_t strides[2]  # 数组的 stride
    cdef float *array
    def __cinit__(self, row, col):
        self.shape[0] = <Py_ssize_t> row
        self.shape[1] = <Py_ssize_t> col
        self.strides[1] = sizeof(float)
        self.strides[0] = self.strides[1] * self.shape[1]
        self.array = <float *> PyMem_Malloc(
            self.shape[0] * self.shape[1] * sizeof(float))
    def set_item_by_index(self, int index, float value):
        """留一个接口,用来设置元素"""
        if index >= self.shape[0] * self.shape[1] or index < 0:
            raise ValueError("索引无效")
        self.array[index] = value
    def __getbuffer__(self, Py_buffer *buffer, int flags):
        """自定义缓冲区需要实现 __getbuffer__ 方法"""
        cdef int i;
        for i in range(self.shape[0] * self.shape[1]):
            self.array[i] = float(i)
        # 缓冲区,这里是就是 array 本身,但是需要转成 void *
        buffer.buf = <void *> self.array
        # 实现缓冲区协议的对象,显然是 selfself.shape[0] * self.shape[1]
        buffer.obj = self
        # 缓冲区的总大小
        buffer.len = self.shape[0] * self.shape[1] * sizeof(float)
        # 读写权限,这里让缓冲区可读写
        buffer.readonly = 0
        # 缓冲区每个元素的大小
        buffer.itemsize = sizeof(float)
        # 元素类型,"f" 表示 float
        buffer.format = "f"
        # 该对象的维度
        buffer.ndim = 2
        # shape
        buffer.shape = self.shape
        # strides
        buffer.strides = self.strides
        # 直接设置为 NULL 即可
        buffer.suboffsets = NULL
    def dealloc(self):
        if self.array != NULL:
            PyMem_Free(<void *> self.array)

在 Cython 中我们只需要实现一个相应的魔法方法即可,真的是非常方便,当然我们为了验证是否共享内存,专门定义了一个方法。

import pyximport
pyximport.install(language_level=3)
import cython_test
import numpy as np
m = cython_test.Matrix(5, 4)
# 基于 m 创建 Numpy 数组
np_m = np.asarray(m)
# m 和 np_m 是共享内存的
print(m) 
"""
<cython_test.Matrix object at 0x7f96ba55a3f0>
"""
print(np_m)
"""
[[ 0.  1.  2.  3.]
 [ 4.  5.  6.  7.]
 [ 8.  9. 10. 11.]
 [12. 13. 14. 15.]
 [16. 17. 18. 19.]]
"""
# 通过 m 修改元素,然后打印 np_m
m.set_item_by_index(13, 666.666)
print(np_m)
"""
[[  0.      1.      2.      3.   ]
 [  4.      5.      6.      7.   ]
 [  8.      9.     10.     11.   ]
 [ 12.    666.666  14.     15.   ]
 [ 16.     17.     18.     19.   ]]
"""

结果没有任何问题,以上就是 Cython 实现缓冲区协议,其实在日常工作中我们不需要直接面对它,但了解一下总是好的。


使用 Python/C API 实现缓冲区协议




注:通过原生的 Python/C API 实现缓冲区协议,这个过程非常麻烦,因为这需要你熟悉解释器源代码,以及这些 API 本身。我们的重点是 Cython,只要知道 Cython 如何实现缓冲区协议即可。至于原生的 Python/C API,感兴趣的话可以看一看,不感兴趣的话跳过即可。

下面编写 C 源文件,文件名为 py_array.c。

#include <stdio.h>
#include <stdlib.h>
#include <Python.h>
// 定义一个一维的数组
typedef struct {
    int *arr;
    int length;
} Array;
// 初始化函数
void initial_Array(Array *array, int length) {
    array->length = length;
    if (length == 0) {
        array->arr = NULL;
    } else {
        array->arr = (int *) malloc(sizeof(int) * length);
        for (int i = 0; i < length; i++) {
            array->arr[i] = i;
        }
    }
}
// 释放内存
void dealloc_Array(Array *array) {
    if (array->arr != NULL) free(array->arr);
    array->arr = NULL;
}
// Python 的对象在 C 中都嵌套了 PyObject
typedef struct {
    PyObject_HEAD
    Array array;
} PyArray;
// 初始化 __init__ 函数
static int
PyArray_init(PyArray *self, PyObject *args, PyObject *kwargs) {
    if (self->array.arr != NULL) {
        dealloc_Array(&self->array);
    }
    int length = 0;
    static char *kwlist[] = {"length", NULL};
    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|i", kwlist, &length)) {
        return -1;
    }
    // 因为给 Python 调用,所以这里额外对 length 进行一下检测
    if (length < 0) {
        PyErr_SetString(PyExc_ValueError, "argument 'length' can not be negative");
        return -1;
    }
    initial_Array(&self->array, length);
    return 0;
}
// 析构函数
static void
PyArray_dealloc(PyArray *self) {
    dealloc_Array(&self->array);
    Py_TYPE(self)->tp_free((PyObject *) self);
}
static PyObject *
PyArray_repr(PyArray *self) {
    //转成列表打印
    PyObject *list = PyList_New(self->array.length);
    Py_ssize_t i;
    for (i=0; i<self->array.length; i++){
        PyList_SetItem(list, i, PyLong_FromLong(*(self->array.arr + i)));
    }
    PyObject *ret = PyObject_Str(list);Py_DECREF(list);
    return ret;
}
// 实现缓冲区协议
static int
PyArray_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
    if (view == NULL) {
        PyErr_SetString(PyExc_ValueError,
                        "NULL view in getbuffer");
        return -1;
    }
    PyArray* self = (PyArray *)obj;
    view->obj = (PyObject*)self;
    view->buf = (void*)self->array.arr;
    view->len = self->array.length * sizeof(int);
    view->readonly = 0;
    view->itemsize = sizeof(int);
    view->format = "i";
    view->ndim = 1;
    view->shape = (Py_ssize_t *) &self->array.length;
    view->strides = &view->itemsize;
    view->suboffsets = NULL;
    view->internal = NULL;
    Py_INCREF(self);
    return 0;
}
// 将上面的函数放入到 PyBufferProcs 结构体中
static PyBufferProcs PyArray_as_buffer = {
        (getbufferproc)PyArray_getbuffer,
        (releasebufferproc)0
};
static PyTypeObject PyArrayType = {
    PyVarObject_HEAD_INIT(NULL, 0)
    "py_my_array.PyArray",
    sizeof(PyArray),
    0,
    (destructor) PyArray_dealloc,
    0,
    0,                               /* tp_getattr        */
    0,                               /* tp_setattr        */
    0,                               /* tp_reserved       */
    (reprfunc)PyArray_repr,          /* tp_repr           */
    0,                               /* tp_as_number      */
    0,                               /* tp_as_sequence    */
    0,                               /* tp_as_mapping     */
    0,                               /* tp_hash           */
    0,                               /* tp_call           */
    0,                               /* tp_str            */
    0,                               /* tp_getattro       */
    0,                               /* tp_setattro       */
    // 指定 tp_as_buffer
    &PyArray_as_buffer,              /* tp_as_buffer      */
    Py_TPFLAGS_DEFAULT,              /* tp_flags          */
    "PyArray object",              /* tp_doc            */
    0,                               /* tp_traverse       */
    0,                               /* tp_clear          */
    0,                               /* tp_richcompare    */
    0,                               /* tp_weaklistoffset */
    0,                               /* tp_iter           */
    0,                               /* tp_iternext       */
    0,                               /* tp_methods        */
    0,                               /* tp_members        */
    0,                               /* tp_getset         */
    0,                               /* tp_base           */
    0,                               /* tp_dict           */
    0,                               /* tp_descr_get      */
    0,                               /* tp_descr_set      */
    0,                               /* tp_dictoffset     */
    (initproc) PyArray_init,         /* tp_init           */
};
static PyModuleDef py_array_module = {
    PyModuleDef_HEAD_INIT,
    "py_array",
    "this is a module named py_array",
    -1,
    0,
    NULL,
    NULL,
    NULL,
    NULL
};
PyMODINIT_FUNC
PyInit_py_array(void) {
    PyObject *m;
    PyArrayType.tp_new = PyType_GenericNew;
    if (PyType_Ready(&PyArrayType) < 0) return NULL;
    m = PyModule_Create(&py_array_module);
    if (m == NULL) return NULL;
    Py_XINCREF(&PyArrayType);
    PyModule_AddObject(m, "PyArray", 
                      (PyObject *) &PyArrayType);
    return m;
}

现在相信你一定能体会到为什么要有 Cython 存在,因为写原生的 Python/C API 太痛苦了,而且为了简便我们这里使用的是一维数组,但即便如此,也已经很麻烦了。

我们编译成扩展,首先编写 setup.py:

from distutils.core import setup, Extension
from Cython.Build import cythonize
ext = [Extension("py_array",
                 sources=["py_array.c"])]
setup(ext_modules=cythonize(ext, language_level=3))

执行 python setup.py build 生成扩展模块,然后我们来导入它。

import numpy as np
import py_array
print(py_array)
"""
<module 'py_array' from ..\\py_array.cp38-win_amd64.pyd'>
"""
arr = py_array.PyArray(5)
print(arr)
"""
[0, 1, 2, 3, 4]
"""
np_arr = np.asarray(arr)
print(np_arr)
"""
[0 1 2 3 4]
"""
# 两者也是共享内存
np_arr[0] = 123
print(arr)
print(np_arr)
"""
[123, 1, 2, 3, 4]
[123   1   2   3   4]
"""

显然此时万事大吉了,因为实现了缓冲区协议,Numpy 知道了缓冲区数据,因此会在此基础之上建一个 view,并且 array 和 np_arr 是共享内存的。

因此核心就在于对缓冲区协议的理解,它本质上就是一个结构体,内部的成员描述了缓冲区数据的所有信息。而我们只需要定义一个函数,然后根据数组初始化这些信息即可,最后构建 PyBufferProcs 实例作为 tp_as_buffer 成员的值。


小结




以上我们就介绍了如何通过原生的 Python/C API 和 Cython 实现缓冲区协议,通过两个例子的对比,我们算是体会到了 Cython 的难能可贵,对于想写扩展的人来说,Cython 无疑是一大福音。即使是缓冲区协议,Cython 也有着超一流的支持。

但说实话,缓冲区协议我们在工作中几乎不用手动实现,因为它还是比较原始和底层的,我们知道就好。而 Cython 在缓冲区协议的基础上提供了一种新的数据结构:内存视图,它屏蔽了缓冲区协议的具体细节,可以让我们在不和缓冲区协议直接打交道的情况下,使用缓冲区协议。关于内存视图,下一篇文章来介绍。


E N D



相关文章
|
SQL Java 数据库连接
MyBatis 优秀的持久层框架(一)
MyBatis 优秀的持久层框架
375 0
|
SQL Java 数据库连接
挺详细的spring+springmvc+mybatis配置整合|含源代码
挺详细的spring+springmvc+mybatis配置整合|含源代码
|
NoSQL 索引
MongoDB查询优化:从 10s 到 10ms
本文是我前同事付秋雷最近遇到到一个关于MongoDB执行计划选择的问题,非常有意思,在探索源码之后,他将整个问题搞明白并整理分享出来。付秋雷(他的博客)曾是Tair(阿里内部用得非常官方的KV存储系统)的核心开发,目前就职于蘑菇街。
|
11月前
|
缓存 JSON 数据处理
Python进阶:深入理解import机制与importlib的妙用
本文深入解析了Python的`import`机制及其背后的原理,涵盖基本用法、模块缓存、导入搜索路径和导入钩子等内容。通过理解这些机制,开发者可以优化模块加载速度并确保代码的一致性。文章还介绍了`importlib`的强大功能,如动态模块导入、实现插件系统及重新加载模块,展示了如何利用这些特性编写更加灵活和高效的代码。掌握这些知识有助于提升编程技能,充分利用Python的强大功能。
589 4
|
运维 负载均衡 网络协议
OSPF的主要特点与优势
OSPF的主要特点与优势
1121 0
|
前端开发 Docker 容器
主机host服务器和Docker容器之间的文件互传方法汇总
Docker 成为前端工具,可实现跨设备兼容。本文介绍主机与 Docker 容器/镜像间文件传输的三种方法:1. 构建镜像时使用 `COPY` 或 `ADD` 指令;2. 启动容器时使用 `-v` 挂载卷;3. 运行时使用 `docker cp` 命令。每种方法适用于不同场景,如静态文件打包、开发时文件同步及临时文件传输。注意权限问题、容器停止后的文件传输及性能影响。
3334 0
|
数据采集 监控 数据可视化
用Python构建动态折线图:实时展示爬取数据的指南
本文介绍了如何利用Python的爬虫技术从“财富吧”获取中国股市的实时数据,并使用动态折线图展示股价变化。文章详细讲解了如何通过设置代理IP和请求头来绕过反爬机制,确保数据稳定获取。通过示例代码展示了如何使用`requests`和`matplotlib`库实现这一过程,最终生成每秒自动更新的动态股价图。这种方法不仅适用于股市分析,还可广泛应用于其他需要实时监控的数据源,帮助用户快速做出决策。
583 0
|
网络协议 Serverless API
现代化 Web 应用构建问题之验证各个服务是否已成功部署如何解决
现代化 Web 应用构建问题之验证各个服务是否已成功部署如何解决
197 1
uniapp 打包成 apk(原生APP-云打包)免费
uniapp 打包成 apk(原生APP-云打包)免费
1645 1
|
监控 数据挖掘 数据安全/隐私保护
ERP系统中的质量管理与控制
【7月更文挑战第25天】 ERP系统中的质量管理与控制
948 2