楔子
了解完缓冲区协议之后,我们就来手动实现一下。而实现方式可以使用原生的 Python/C API,也可以使用 Cython,但前者实现起来会非常麻烦,牵扯的知识也非常多,而 Cython 则简化了这一点。
我们来分别介绍一下这两种方式。
使用 Cython 实现缓冲区协议
Cython 对缓冲区协议也有着强大的支持,我们只需要定义一个魔法方法即可实现缓冲区协议。
from cpython cimport Py_buffer from cpython.mem cimport PyMem_Malloc, PyMem_Free cdef class Matrix: cdef Py_ssize_t shape[2] # 数组的形状 cdef Py_ssize_t strides[2] # 数组的 stride cdef float *array def __cinit__(self, row, col): self.shape[0] = <Py_ssize_t> row self.shape[1] = <Py_ssize_t> col self.strides[1] = sizeof(float) self.strides[0] = self.strides[1] * self.shape[1] self.array = <float *> PyMem_Malloc( self.shape[0] * self.shape[1] * sizeof(float)) def set_item_by_index(self, int index, float value): """留一个接口,用来设置元素""" if index >= self.shape[0] * self.shape[1] or index < 0: raise ValueError("索引无效") self.array[index] = value def __getbuffer__(self, Py_buffer *buffer, int flags): """自定义缓冲区需要实现 __getbuffer__ 方法""" cdef int i; for i in range(self.shape[0] * self.shape[1]): self.array[i] = float(i) # 缓冲区,这里是就是 array 本身,但是需要转成 void * buffer.buf = <void *> self.array # 实现缓冲区协议的对象,显然是 selfself.shape[0] * self.shape[1] buffer.obj = self # 缓冲区的总大小 buffer.len = self.shape[0] * self.shape[1] * sizeof(float) # 读写权限,这里让缓冲区可读写 buffer.readonly = 0 # 缓冲区每个元素的大小 buffer.itemsize = sizeof(float) # 元素类型,"f" 表示 float buffer.format = "f" # 该对象的维度 buffer.ndim = 2 # shape buffer.shape = self.shape # strides buffer.strides = self.strides # 直接设置为 NULL 即可 buffer.suboffsets = NULL def dealloc(self): if self.array != NULL: PyMem_Free(<void *> self.array)
在 Cython 中我们只需要实现一个相应的魔法方法即可,真的是非常方便,当然我们为了验证是否共享内存,专门定义了一个方法。
import pyximport pyximport.install(language_level=3) import cython_test import numpy as np m = cython_test.Matrix(5, 4) # 基于 m 创建 Numpy 数组 np_m = np.asarray(m) # m 和 np_m 是共享内存的 print(m) """ <cython_test.Matrix object at 0x7f96ba55a3f0> """ print(np_m) """ [[ 0. 1. 2. 3.] [ 4. 5. 6. 7.] [ 8. 9. 10. 11.] [12. 13. 14. 15.] [16. 17. 18. 19.]] """ # 通过 m 修改元素,然后打印 np_m m.set_item_by_index(13, 666.666) print(np_m) """ [[ 0. 1. 2. 3. ] [ 4. 5. 6. 7. ] [ 8. 9. 10. 11. ] [ 12. 666.666 14. 15. ] [ 16. 17. 18. 19. ]] """
结果没有任何问题,以上就是 Cython 实现缓冲区协议,其实在日常工作中我们不需要直接面对它,但了解一下总是好的。
使用 Python/C API 实现缓冲区协议
注:通过原生的 Python/C API 实现缓冲区协议,这个过程非常麻烦,因为这需要你熟悉解释器源代码,以及这些 API 本身。我们的重点是 Cython,只要知道 Cython 如何实现缓冲区协议即可。至于原生的 Python/C API,感兴趣的话可以看一看,不感兴趣的话跳过即可。
下面编写 C 源文件,文件名为 py_array.c。
#include <stdio.h> #include <stdlib.h> #include <Python.h> // 定义一个一维的数组 typedef struct { int *arr; int length; } Array; // 初始化函数 void initial_Array(Array *array, int length) { array->length = length; if (length == 0) { array->arr = NULL; } else { array->arr = (int *) malloc(sizeof(int) * length); for (int i = 0; i < length; i++) { array->arr[i] = i; } } } // 释放内存 void dealloc_Array(Array *array) { if (array->arr != NULL) free(array->arr); array->arr = NULL; } // Python 的对象在 C 中都嵌套了 PyObject typedef struct { PyObject_HEAD Array array; } PyArray; // 初始化 __init__ 函数 static int PyArray_init(PyArray *self, PyObject *args, PyObject *kwargs) { if (self->array.arr != NULL) { dealloc_Array(&self->array); } int length = 0; static char *kwlist[] = {"length", NULL}; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|i", kwlist, &length)) { return -1; } // 因为给 Python 调用,所以这里额外对 length 进行一下检测 if (length < 0) { PyErr_SetString(PyExc_ValueError, "argument 'length' can not be negative"); return -1; } initial_Array(&self->array, length); return 0; } // 析构函数 static void PyArray_dealloc(PyArray *self) { dealloc_Array(&self->array); Py_TYPE(self)->tp_free((PyObject *) self); } static PyObject * PyArray_repr(PyArray *self) { //转成列表打印 PyObject *list = PyList_New(self->array.length); Py_ssize_t i; for (i=0; i<self->array.length; i++){ PyList_SetItem(list, i, PyLong_FromLong(*(self->array.arr + i))); } PyObject *ret = PyObject_Str(list);Py_DECREF(list); return ret; } // 实现缓冲区协议 static int PyArray_getbuffer(PyObject *obj, Py_buffer *view, int flags) { if (view == NULL) { PyErr_SetString(PyExc_ValueError, "NULL view in getbuffer"); return -1; } PyArray* self = (PyArray *)obj; view->obj = (PyObject*)self; view->buf = (void*)self->array.arr; view->len = self->array.length * sizeof(int); view->readonly = 0; view->itemsize = sizeof(int); view->format = "i"; view->ndim = 1; view->shape = (Py_ssize_t *) &self->array.length; view->strides = &view->itemsize; view->suboffsets = NULL; view->internal = NULL; Py_INCREF(self); return 0; } // 将上面的函数放入到 PyBufferProcs 结构体中 static PyBufferProcs PyArray_as_buffer = { (getbufferproc)PyArray_getbuffer, (releasebufferproc)0 }; static PyTypeObject PyArrayType = { PyVarObject_HEAD_INIT(NULL, 0) "py_my_array.PyArray", sizeof(PyArray), 0, (destructor) PyArray_dealloc, 0, 0, /* tp_getattr */ 0, /* tp_setattr */ 0, /* tp_reserved */ (reprfunc)PyArray_repr, /* tp_repr */ 0, /* tp_as_number */ 0, /* tp_as_sequence */ 0, /* tp_as_mapping */ 0, /* tp_hash */ 0, /* tp_call */ 0, /* tp_str */ 0, /* tp_getattro */ 0, /* tp_setattro */ // 指定 tp_as_buffer &PyArray_as_buffer, /* tp_as_buffer */ Py_TPFLAGS_DEFAULT, /* tp_flags */ "PyArray object", /* tp_doc */ 0, /* tp_traverse */ 0, /* tp_clear */ 0, /* tp_richcompare */ 0, /* tp_weaklistoffset */ 0, /* tp_iter */ 0, /* tp_iternext */ 0, /* tp_methods */ 0, /* tp_members */ 0, /* tp_getset */ 0, /* tp_base */ 0, /* tp_dict */ 0, /* tp_descr_get */ 0, /* tp_descr_set */ 0, /* tp_dictoffset */ (initproc) PyArray_init, /* tp_init */ }; static PyModuleDef py_array_module = { PyModuleDef_HEAD_INIT, "py_array", "this is a module named py_array", -1, 0, NULL, NULL, NULL, NULL }; PyMODINIT_FUNC PyInit_py_array(void) { PyObject *m; PyArrayType.tp_new = PyType_GenericNew; if (PyType_Ready(&PyArrayType) < 0) return NULL; m = PyModule_Create(&py_array_module); if (m == NULL) return NULL; Py_XINCREF(&PyArrayType); PyModule_AddObject(m, "PyArray", (PyObject *) &PyArrayType); return m; }
现在相信你一定能体会到为什么要有 Cython 存在,因为写原生的 Python/C API 太痛苦了,而且为了简便我们这里使用的是一维数组,但即便如此,也已经很麻烦了。
我们编译成扩展,首先编写 setup.py:
from distutils.core import setup, Extension from Cython.Build import cythonize ext = [Extension("py_array", sources=["py_array.c"])] setup(ext_modules=cythonize(ext, language_level=3))
执行 python setup.py build 生成扩展模块,然后我们来导入它。
import numpy as np import py_array print(py_array) """ <module 'py_array' from ..\\py_array.cp38-win_amd64.pyd'> """ arr = py_array.PyArray(5) print(arr) """ [0, 1, 2, 3, 4] """ np_arr = np.asarray(arr) print(np_arr) """ [0 1 2 3 4] """ # 两者也是共享内存 np_arr[0] = 123 print(arr) print(np_arr) """ [123, 1, 2, 3, 4] [123 1 2 3 4] """
显然此时万事大吉了,因为实现了缓冲区协议,Numpy 知道了缓冲区数据,因此会在此基础之上建一个 view,并且 array 和 np_arr 是共享内存的。
因此核心就在于对缓冲区协议的理解,它本质上就是一个结构体,内部的成员描述了缓冲区数据的所有信息。而我们只需要定义一个函数,然后根据数组初始化这些信息即可,最后构建 PyBufferProcs 实例作为 tp_as_buffer 成员的值。
小结
以上我们就介绍了如何通过原生的 Python/C API 和 Cython 实现缓冲区协议,通过两个例子的对比,我们算是体会到了 Cython 的难能可贵,对于想写扩展的人来说,Cython 无疑是一大福音。即使是缓冲区协议,Cython 也有着超一流的支持。
但说实话,缓冲区协议我们在工作中几乎不用手动实现,因为它还是比较原始和底层的,我们知道就好。而 Cython 在缓冲区协议的基础上提供了一种新的数据结构:内存视图,它屏蔽了缓冲区协议的具体细节,可以让我们在不和缓冲区协议直接打交道的情况下,使用缓冲区协议。关于内存视图,下一篇文章来介绍。
E N D