实现缓冲区协议

简介: 实现缓冲区协议


楔子




了解完缓冲区协议之后,我们就来手动实现一下。而实现方式可以使用原生的 Python/C API,也可以使用 Cython,但前者实现起来会非常麻烦,牵扯的知识也非常多,而 Cython 则简化了这一点。

我们来分别介绍一下这两种方式。


使用 Cython 实现缓冲区协议




Cython 对缓冲区协议也有着强大的支持,我们只需要定义一个魔法方法即可实现缓冲区协议。

from cpython cimport Py_buffer
from cpython.mem cimport PyMem_Malloc, PyMem_Free
cdef class Matrix:
    cdef Py_ssize_t shape[2]  # 数组的形状
    cdef Py_ssize_t strides[2]  # 数组的 stride
    cdef float *array
    def __cinit__(self, row, col):
        self.shape[0] = <Py_ssize_t> row
        self.shape[1] = <Py_ssize_t> col
        self.strides[1] = sizeof(float)
        self.strides[0] = self.strides[1] * self.shape[1]
        self.array = <float *> PyMem_Malloc(
            self.shape[0] * self.shape[1] * sizeof(float))
    def set_item_by_index(self, int index, float value):
        """留一个接口,用来设置元素"""
        if index >= self.shape[0] * self.shape[1] or index < 0:
            raise ValueError("索引无效")
        self.array[index] = value
    def __getbuffer__(self, Py_buffer *buffer, int flags):
        """自定义缓冲区需要实现 __getbuffer__ 方法"""
        cdef int i;
        for i in range(self.shape[0] * self.shape[1]):
            self.array[i] = float(i)
        # 缓冲区,这里是就是 array 本身,但是需要转成 void *
        buffer.buf = <void *> self.array
        # 实现缓冲区协议的对象,显然是 selfself.shape[0] * self.shape[1]
        buffer.obj = self
        # 缓冲区的总大小
        buffer.len = self.shape[0] * self.shape[1] * sizeof(float)
        # 读写权限,这里让缓冲区可读写
        buffer.readonly = 0
        # 缓冲区每个元素的大小
        buffer.itemsize = sizeof(float)
        # 元素类型,"f" 表示 float
        buffer.format = "f"
        # 该对象的维度
        buffer.ndim = 2
        # shape
        buffer.shape = self.shape
        # strides
        buffer.strides = self.strides
        # 直接设置为 NULL 即可
        buffer.suboffsets = NULL
    def dealloc(self):
        if self.array != NULL:
            PyMem_Free(<void *> self.array)

在 Cython 中我们只需要实现一个相应的魔法方法即可,真的是非常方便,当然我们为了验证是否共享内存,专门定义了一个方法。

import pyximport
pyximport.install(language_level=3)
import cython_test
import numpy as np
m = cython_test.Matrix(5, 4)
# 基于 m 创建 Numpy 数组
np_m = np.asarray(m)
# m 和 np_m 是共享内存的
print(m) 
"""
<cython_test.Matrix object at 0x7f96ba55a3f0>
"""
print(np_m)
"""
[[ 0.  1.  2.  3.]
 [ 4.  5.  6.  7.]
 [ 8.  9. 10. 11.]
 [12. 13. 14. 15.]
 [16. 17. 18. 19.]]
"""
# 通过 m 修改元素,然后打印 np_m
m.set_item_by_index(13, 666.666)
print(np_m)
"""
[[  0.      1.      2.      3.   ]
 [  4.      5.      6.      7.   ]
 [  8.      9.     10.     11.   ]
 [ 12.    666.666  14.     15.   ]
 [ 16.     17.     18.     19.   ]]
"""

结果没有任何问题,以上就是 Cython 实现缓冲区协议,其实在日常工作中我们不需要直接面对它,但了解一下总是好的。


使用 Python/C API 实现缓冲区协议




注:通过原生的 Python/C API 实现缓冲区协议,这个过程非常麻烦,因为这需要你熟悉解释器源代码,以及这些 API 本身。我们的重点是 Cython,只要知道 Cython 如何实现缓冲区协议即可。至于原生的 Python/C API,感兴趣的话可以看一看,不感兴趣的话跳过即可。

下面编写 C 源文件,文件名为 py_array.c。

#include <stdio.h>
#include <stdlib.h>
#include <Python.h>
// 定义一个一维的数组
typedef struct {
    int *arr;
    int length;
} Array;
// 初始化函数
void initial_Array(Array *array, int length) {
    array->length = length;
    if (length == 0) {
        array->arr = NULL;
    } else {
        array->arr = (int *) malloc(sizeof(int) * length);
        for (int i = 0; i < length; i++) {
            array->arr[i] = i;
        }
    }
}
// 释放内存
void dealloc_Array(Array *array) {
    if (array->arr != NULL) free(array->arr);
    array->arr = NULL;
}
// Python 的对象在 C 中都嵌套了 PyObject
typedef struct {
    PyObject_HEAD
    Array array;
} PyArray;
// 初始化 __init__ 函数
static int
PyArray_init(PyArray *self, PyObject *args, PyObject *kwargs) {
    if (self->array.arr != NULL) {
        dealloc_Array(&self->array);
    }
    int length = 0;
    static char *kwlist[] = {"length", NULL};
    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|i", kwlist, &length)) {
        return -1;
    }
    // 因为给 Python 调用,所以这里额外对 length 进行一下检测
    if (length < 0) {
        PyErr_SetString(PyExc_ValueError, "argument 'length' can not be negative");
        return -1;
    }
    initial_Array(&self->array, length);
    return 0;
}
// 析构函数
static void
PyArray_dealloc(PyArray *self) {
    dealloc_Array(&self->array);
    Py_TYPE(self)->tp_free((PyObject *) self);
}
static PyObject *
PyArray_repr(PyArray *self) {
    //转成列表打印
    PyObject *list = PyList_New(self->array.length);
    Py_ssize_t i;
    for (i=0; i<self->array.length; i++){
        PyList_SetItem(list, i, PyLong_FromLong(*(self->array.arr + i)));
    }
    PyObject *ret = PyObject_Str(list);Py_DECREF(list);
    return ret;
}
// 实现缓冲区协议
static int
PyArray_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
    if (view == NULL) {
        PyErr_SetString(PyExc_ValueError,
                        "NULL view in getbuffer");
        return -1;
    }
    PyArray* self = (PyArray *)obj;
    view->obj = (PyObject*)self;
    view->buf = (void*)self->array.arr;
    view->len = self->array.length * sizeof(int);
    view->readonly = 0;
    view->itemsize = sizeof(int);
    view->format = "i";
    view->ndim = 1;
    view->shape = (Py_ssize_t *) &self->array.length;
    view->strides = &view->itemsize;
    view->suboffsets = NULL;
    view->internal = NULL;
    Py_INCREF(self);
    return 0;
}
// 将上面的函数放入到 PyBufferProcs 结构体中
static PyBufferProcs PyArray_as_buffer = {
        (getbufferproc)PyArray_getbuffer,
        (releasebufferproc)0
};
static PyTypeObject PyArrayType = {
    PyVarObject_HEAD_INIT(NULL, 0)
    "py_my_array.PyArray",
    sizeof(PyArray),
    0,
    (destructor) PyArray_dealloc,
    0,
    0,                               /* tp_getattr        */
    0,                               /* tp_setattr        */
    0,                               /* tp_reserved       */
    (reprfunc)PyArray_repr,          /* tp_repr           */
    0,                               /* tp_as_number      */
    0,                               /* tp_as_sequence    */
    0,                               /* tp_as_mapping     */
    0,                               /* tp_hash           */
    0,                               /* tp_call           */
    0,                               /* tp_str            */
    0,                               /* tp_getattro       */
    0,                               /* tp_setattro       */
    // 指定 tp_as_buffer
    &PyArray_as_buffer,              /* tp_as_buffer      */
    Py_TPFLAGS_DEFAULT,              /* tp_flags          */
    "PyArray object",              /* tp_doc            */
    0,                               /* tp_traverse       */
    0,                               /* tp_clear          */
    0,                               /* tp_richcompare    */
    0,                               /* tp_weaklistoffset */
    0,                               /* tp_iter           */
    0,                               /* tp_iternext       */
    0,                               /* tp_methods        */
    0,                               /* tp_members        */
    0,                               /* tp_getset         */
    0,                               /* tp_base           */
    0,                               /* tp_dict           */
    0,                               /* tp_descr_get      */
    0,                               /* tp_descr_set      */
    0,                               /* tp_dictoffset     */
    (initproc) PyArray_init,         /* tp_init           */
};
static PyModuleDef py_array_module = {
    PyModuleDef_HEAD_INIT,
    "py_array",
    "this is a module named py_array",
    -1,
    0,
    NULL,
    NULL,
    NULL,
    NULL
};
PyMODINIT_FUNC
PyInit_py_array(void) {
    PyObject *m;
    PyArrayType.tp_new = PyType_GenericNew;
    if (PyType_Ready(&PyArrayType) < 0) return NULL;
    m = PyModule_Create(&py_array_module);
    if (m == NULL) return NULL;
    Py_XINCREF(&PyArrayType);
    PyModule_AddObject(m, "PyArray", 
                      (PyObject *) &PyArrayType);
    return m;
}

现在相信你一定能体会到为什么要有 Cython 存在,因为写原生的 Python/C API 太痛苦了,而且为了简便我们这里使用的是一维数组,但即便如此,也已经很麻烦了。

我们编译成扩展,首先编写 setup.py:

from distutils.core import setup, Extension
from Cython.Build import cythonize
ext = [Extension("py_array",
                 sources=["py_array.c"])]
setup(ext_modules=cythonize(ext, language_level=3))

执行 python setup.py build 生成扩展模块,然后我们来导入它。

import numpy as np
import py_array
print(py_array)
"""
<module 'py_array' from ..\\py_array.cp38-win_amd64.pyd'>
"""
arr = py_array.PyArray(5)
print(arr)
"""
[0, 1, 2, 3, 4]
"""
np_arr = np.asarray(arr)
print(np_arr)
"""
[0 1 2 3 4]
"""
# 两者也是共享内存
np_arr[0] = 123
print(arr)
print(np_arr)
"""
[123, 1, 2, 3, 4]
[123   1   2   3   4]
"""

显然此时万事大吉了,因为实现了缓冲区协议,Numpy 知道了缓冲区数据,因此会在此基础之上建一个 view,并且 array 和 np_arr 是共享内存的。

因此核心就在于对缓冲区协议的理解,它本质上就是一个结构体,内部的成员描述了缓冲区数据的所有信息。而我们只需要定义一个函数,然后根据数组初始化这些信息即可,最后构建 PyBufferProcs 实例作为 tp_as_buffer 成员的值。


小结




以上我们就介绍了如何通过原生的 Python/C API 和 Cython 实现缓冲区协议,通过两个例子的对比,我们算是体会到了 Cython 的难能可贵,对于想写扩展的人来说,Cython 无疑是一大福音。即使是缓冲区协议,Cython 也有着超一流的支持。

但说实话,缓冲区协议我们在工作中几乎不用手动实现,因为它还是比较原始和底层的,我们知道就好。而 Cython 在缓冲区协议的基础上提供了一种新的数据结构:内存视图,它屏蔽了缓冲区协议的具体细节,可以让我们在不和缓冲区协议直接打交道的情况下,使用缓冲区协议。关于内存视图,下一篇文章来介绍。


E N D



相关文章
|
3天前
|
C语言
数据流与缓冲区
数据流 就C程序而言,从程序移进,移出字节,这种字节流就叫做流。程序与数据的交互是以流的形式进行的。进行C语言文件的读写时,都会先进行“打开文件”操作,这个操作就是在打开数据流,而“关闭文件”操作就是关闭数据流。 缓冲区 在程序执行时,所提供的额外内存,可用来暂时存放准备执行的数据。它的设置是为了提高存取效率,因为内存的存取速度比磁盘驱动器快得多。 当使用标准I/O函数(包含在头文件stdio.h中)时,系统会自动设置缓冲区,并通过数据流来读写文件。当进行文件读取时,是先打开数据流,将磁盘上的文件信息拷贝到缓冲区内,然后程序再从缓冲区中读取所需数据。事实
|
2月前
|
存储 算法 API
解密缓冲区协议
解密缓冲区协议
33 0
|
5月前
|
缓存 移动开发 网络协议
你真的知道TCP协议中的序列号确认、上层协议及记录标识问题吗?
【6月更文挑战第3天】本文探讨了TCP的场景问题,包括序列号确认、上层协议确定及应用程序记录标识。在序列号确认问题中,解释了B主机为何确认号为1000。确定上层协议通过解析IP头的协议字段实现。应用程序可通过特定协议头、固定长度数据块或消息边界标记来标识记录。此外,文章还对比了TCP与UDP在连接性、可靠性、速度及资源占用上的差异。理解这些概念有助于解决面试和实际应用中的问题。
107 6
|
6月前
|
存储 缓存 小程序
详细讲解缓冲区
详细讲解缓冲区
修改udp的缓冲区大小
修改udp的缓冲区大小
189 0
修改udp的缓冲区大小
|
网络架构
为什么udp流设置1316字节
为什么udp流设置1316字节
103 0
|
存储 网络协议 Linux
网络缓冲区
网络缓冲区
70 0
|
C语言
理解缓冲区
理解缓冲区
105 0
|
存储 消息中间件 NoSQL
计网 - 流和缓冲区:缓冲区的 flip 是怎么回事?
计网 - 流和缓冲区:缓冲区的 flip 是怎么回事?
94 0
|
机器学习/深度学习 索引
【Netty】NIO 缓冲区 ( Buffer ) ( 缓冲区读写类型 | 只读缓冲区 | 映射字节缓冲区 )(二)
【Netty】NIO 缓冲区 ( Buffer ) ( 缓冲区读写类型 | 只读缓冲区 | 映射字节缓冲区 )(二)
134 0
【Netty】NIO 缓冲区 ( Buffer ) ( 缓冲区读写类型 | 只读缓冲区 | 映射字节缓冲区 )(二)