实现缓冲区协议

简介: 实现缓冲区协议


楔子




了解完缓冲区协议之后,我们就来手动实现一下。而实现方式可以使用原生的 Python/C API,也可以使用 Cython,但前者实现起来会非常麻烦,牵扯的知识也非常多,而 Cython 则简化了这一点。

我们来分别介绍一下这两种方式。


使用 Cython 实现缓冲区协议




Cython 对缓冲区协议也有着强大的支持,我们只需要定义一个魔法方法即可实现缓冲区协议。

from cpython cimport Py_buffer
from cpython.mem cimport PyMem_Malloc, PyMem_Free
cdef class Matrix:
    cdef Py_ssize_t shape[2]  # 数组的形状
    cdef Py_ssize_t strides[2]  # 数组的 stride
    cdef float *array
    def __cinit__(self, row, col):
        self.shape[0] = <Py_ssize_t> row
        self.shape[1] = <Py_ssize_t> col
        self.strides[1] = sizeof(float)
        self.strides[0] = self.strides[1] * self.shape[1]
        self.array = <float *> PyMem_Malloc(
            self.shape[0] * self.shape[1] * sizeof(float))
    def set_item_by_index(self, int index, float value):
        """留一个接口,用来设置元素"""
        if index >= self.shape[0] * self.shape[1] or index < 0:
            raise ValueError("索引无效")
        self.array[index] = value
    def __getbuffer__(self, Py_buffer *buffer, int flags):
        """自定义缓冲区需要实现 __getbuffer__ 方法"""
        cdef int i;
        for i in range(self.shape[0] * self.shape[1]):
            self.array[i] = float(i)
        # 缓冲区,这里是就是 array 本身,但是需要转成 void *
        buffer.buf = <void *> self.array
        # 实现缓冲区协议的对象,显然是 selfself.shape[0] * self.shape[1]
        buffer.obj = self
        # 缓冲区的总大小
        buffer.len = self.shape[0] * self.shape[1] * sizeof(float)
        # 读写权限,这里让缓冲区可读写
        buffer.readonly = 0
        # 缓冲区每个元素的大小
        buffer.itemsize = sizeof(float)
        # 元素类型,"f" 表示 float
        buffer.format = "f"
        # 该对象的维度
        buffer.ndim = 2
        # shape
        buffer.shape = self.shape
        # strides
        buffer.strides = self.strides
        # 直接设置为 NULL 即可
        buffer.suboffsets = NULL
    def dealloc(self):
        if self.array != NULL:
            PyMem_Free(<void *> self.array)

在 Cython 中我们只需要实现一个相应的魔法方法即可,真的是非常方便,当然我们为了验证是否共享内存,专门定义了一个方法。

import pyximport
pyximport.install(language_level=3)
import cython_test
import numpy as np
m = cython_test.Matrix(5, 4)
# 基于 m 创建 Numpy 数组
np_m = np.asarray(m)
# m 和 np_m 是共享内存的
print(m) 
"""
<cython_test.Matrix object at 0x7f96ba55a3f0>
"""
print(np_m)
"""
[[ 0.  1.  2.  3.]
 [ 4.  5.  6.  7.]
 [ 8.  9. 10. 11.]
 [12. 13. 14. 15.]
 [16. 17. 18. 19.]]
"""
# 通过 m 修改元素,然后打印 np_m
m.set_item_by_index(13, 666.666)
print(np_m)
"""
[[  0.      1.      2.      3.   ]
 [  4.      5.      6.      7.   ]
 [  8.      9.     10.     11.   ]
 [ 12.    666.666  14.     15.   ]
 [ 16.     17.     18.     19.   ]]
"""

结果没有任何问题,以上就是 Cython 实现缓冲区协议,其实在日常工作中我们不需要直接面对它,但了解一下总是好的。


使用 Python/C API 实现缓冲区协议




注:通过原生的 Python/C API 实现缓冲区协议,这个过程非常麻烦,因为这需要你熟悉解释器源代码,以及这些 API 本身。我们的重点是 Cython,只要知道 Cython 如何实现缓冲区协议即可。至于原生的 Python/C API,感兴趣的话可以看一看,不感兴趣的话跳过即可。

下面编写 C 源文件,文件名为 py_array.c。

#include <stdio.h>
#include <stdlib.h>
#include <Python.h>
// 定义一个一维的数组
typedef struct {
    int *arr;
    int length;
} Array;
// 初始化函数
void initial_Array(Array *array, int length) {
    array->length = length;
    if (length == 0) {
        array->arr = NULL;
    } else {
        array->arr = (int *) malloc(sizeof(int) * length);
        for (int i = 0; i < length; i++) {
            array->arr[i] = i;
        }
    }
}
// 释放内存
void dealloc_Array(Array *array) {
    if (array->arr != NULL) free(array->arr);
    array->arr = NULL;
}
// Python 的对象在 C 中都嵌套了 PyObject
typedef struct {
    PyObject_HEAD
    Array array;
} PyArray;
// 初始化 __init__ 函数
static int
PyArray_init(PyArray *self, PyObject *args, PyObject *kwargs) {
    if (self->array.arr != NULL) {
        dealloc_Array(&self->array);
    }
    int length = 0;
    static char *kwlist[] = {"length", NULL};
    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|i", kwlist, &length)) {
        return -1;
    }
    // 因为给 Python 调用,所以这里额外对 length 进行一下检测
    if (length < 0) {
        PyErr_SetString(PyExc_ValueError, "argument 'length' can not be negative");
        return -1;
    }
    initial_Array(&self->array, length);
    return 0;
}
// 析构函数
static void
PyArray_dealloc(PyArray *self) {
    dealloc_Array(&self->array);
    Py_TYPE(self)->tp_free((PyObject *) self);
}
static PyObject *
PyArray_repr(PyArray *self) {
    //转成列表打印
    PyObject *list = PyList_New(self->array.length);
    Py_ssize_t i;
    for (i=0; i<self->array.length; i++){
        PyList_SetItem(list, i, PyLong_FromLong(*(self->array.arr + i)));
    }
    PyObject *ret = PyObject_Str(list);Py_DECREF(list);
    return ret;
}
// 实现缓冲区协议
static int
PyArray_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
    if (view == NULL) {
        PyErr_SetString(PyExc_ValueError,
                        "NULL view in getbuffer");
        return -1;
    }
    PyArray* self = (PyArray *)obj;
    view->obj = (PyObject*)self;
    view->buf = (void*)self->array.arr;
    view->len = self->array.length * sizeof(int);
    view->readonly = 0;
    view->itemsize = sizeof(int);
    view->format = "i";
    view->ndim = 1;
    view->shape = (Py_ssize_t *) &self->array.length;
    view->strides = &view->itemsize;
    view->suboffsets = NULL;
    view->internal = NULL;
    Py_INCREF(self);
    return 0;
}
// 将上面的函数放入到 PyBufferProcs 结构体中
static PyBufferProcs PyArray_as_buffer = {
        (getbufferproc)PyArray_getbuffer,
        (releasebufferproc)0
};
static PyTypeObject PyArrayType = {
    PyVarObject_HEAD_INIT(NULL, 0)
    "py_my_array.PyArray",
    sizeof(PyArray),
    0,
    (destructor) PyArray_dealloc,
    0,
    0,                               /* tp_getattr        */
    0,                               /* tp_setattr        */
    0,                               /* tp_reserved       */
    (reprfunc)PyArray_repr,          /* tp_repr           */
    0,                               /* tp_as_number      */
    0,                               /* tp_as_sequence    */
    0,                               /* tp_as_mapping     */
    0,                               /* tp_hash           */
    0,                               /* tp_call           */
    0,                               /* tp_str            */
    0,                               /* tp_getattro       */
    0,                               /* tp_setattro       */
    // 指定 tp_as_buffer
    &PyArray_as_buffer,              /* tp_as_buffer      */
    Py_TPFLAGS_DEFAULT,              /* tp_flags          */
    "PyArray object",              /* tp_doc            */
    0,                               /* tp_traverse       */
    0,                               /* tp_clear          */
    0,                               /* tp_richcompare    */
    0,                               /* tp_weaklistoffset */
    0,                               /* tp_iter           */
    0,                               /* tp_iternext       */
    0,                               /* tp_methods        */
    0,                               /* tp_members        */
    0,                               /* tp_getset         */
    0,                               /* tp_base           */
    0,                               /* tp_dict           */
    0,                               /* tp_descr_get      */
    0,                               /* tp_descr_set      */
    0,                               /* tp_dictoffset     */
    (initproc) PyArray_init,         /* tp_init           */
};
static PyModuleDef py_array_module = {
    PyModuleDef_HEAD_INIT,
    "py_array",
    "this is a module named py_array",
    -1,
    0,
    NULL,
    NULL,
    NULL,
    NULL
};
PyMODINIT_FUNC
PyInit_py_array(void) {
    PyObject *m;
    PyArrayType.tp_new = PyType_GenericNew;
    if (PyType_Ready(&PyArrayType) < 0) return NULL;
    m = PyModule_Create(&py_array_module);
    if (m == NULL) return NULL;
    Py_XINCREF(&PyArrayType);
    PyModule_AddObject(m, "PyArray", 
                      (PyObject *) &PyArrayType);
    return m;
}

现在相信你一定能体会到为什么要有 Cython 存在,因为写原生的 Python/C API 太痛苦了,而且为了简便我们这里使用的是一维数组,但即便如此,也已经很麻烦了。

我们编译成扩展,首先编写 setup.py:

from distutils.core import setup, Extension
from Cython.Build import cythonize
ext = [Extension("py_array",
                 sources=["py_array.c"])]
setup(ext_modules=cythonize(ext, language_level=3))

执行 python setup.py build 生成扩展模块,然后我们来导入它。

import numpy as np
import py_array
print(py_array)
"""
<module 'py_array' from ..\\py_array.cp38-win_amd64.pyd'>
"""
arr = py_array.PyArray(5)
print(arr)
"""
[0, 1, 2, 3, 4]
"""
np_arr = np.asarray(arr)
print(np_arr)
"""
[0 1 2 3 4]
"""
# 两者也是共享内存
np_arr[0] = 123
print(arr)
print(np_arr)
"""
[123, 1, 2, 3, 4]
[123   1   2   3   4]
"""

显然此时万事大吉了,因为实现了缓冲区协议,Numpy 知道了缓冲区数据,因此会在此基础之上建一个 view,并且 array 和 np_arr 是共享内存的。

因此核心就在于对缓冲区协议的理解,它本质上就是一个结构体,内部的成员描述了缓冲区数据的所有信息。而我们只需要定义一个函数,然后根据数组初始化这些信息即可,最后构建 PyBufferProcs 实例作为 tp_as_buffer 成员的值。


小结




以上我们就介绍了如何通过原生的 Python/C API 和 Cython 实现缓冲区协议,通过两个例子的对比,我们算是体会到了 Cython 的难能可贵,对于想写扩展的人来说,Cython 无疑是一大福音。即使是缓冲区协议,Cython 也有着超一流的支持。

但说实话,缓冲区协议我们在工作中几乎不用手动实现,因为它还是比较原始和底层的,我们知道就好。而 Cython 在缓冲区协议的基础上提供了一种新的数据结构:内存视图,它屏蔽了缓冲区协议的具体细节,可以让我们在不和缓冲区协议直接打交道的情况下,使用缓冲区协议。关于内存视图,下一篇文章来介绍。


E N D



相关文章
|
2天前
|
存储 算法 API
解密缓冲区协议
解密缓冲区协议
8 0
修改udp的缓冲区大小
修改udp的缓冲区大小
167 0
修改udp的缓冲区大小
|
网络架构
为什么udp流设置1316字节
为什么udp流设置1316字节
88 0
|
12月前
|
存储
14.3 Socket 字符串分块传输
首先为什么要实行分块传输字符串,一般而言`Socket`套接字最长发送的字节数为`8192`字节,如果发送的字节超出了此范围则后续部分会被自动截断,此时将字符串进行分块传输将显得格外重要,分块传输的关键在于封装实现一个字符串切割函数,将特定缓冲区内的字串动态切割成一个个小的子块,当切割结束后会得到该数据块的个数,此时通过套接字将个数发送至服务端此时服务端在依次循环接收数据包直到接收完所有数据包之后在组合并显示即可。
57 0
14.3 Socket 字符串分块传输
|
11月前
|
存储 网络协议 Linux
网络缓冲区
网络缓冲区
59 0
|
机器学习/深度学习 索引
【Netty】NIO 缓冲区 ( Buffer ) ( 缓冲区读写类型 | 只读缓冲区 | 映射字节缓冲区 )(二)
【Netty】NIO 缓冲区 ( Buffer ) ( 缓冲区读写类型 | 只读缓冲区 | 映射字节缓冲区 )(二)
131 0
【Netty】NIO 缓冲区 ( Buffer ) ( 缓冲区读写类型 | 只读缓冲区 | 映射字节缓冲区 )(二)
|
存储
【Netty】NIO 缓冲区 ( Buffer ) ( 缓冲区读写类型 | 只读缓冲区 | 映射字节缓冲区 )(一)
【Netty】NIO 缓冲区 ( Buffer ) ( 缓冲区读写类型 | 只读缓冲区 | 映射字节缓冲区 )(一)
193 0
|
SQL 网络协议 关系型数据库
TCP性能和发送接收Buffer的关系
# 前言 本文希望解析清楚,当我们在代码中写下 socket.setSendBufferSize 和 sysctl 看到的rmem/wmem系统参数以及最终我们在TCP常常谈到的接收发送窗口的关系,以及他们怎样影响TCP传输的性能。 先明确一下:**文章标题中所说的Buffer指的是sysctl中的 rmem或者wmem,如果是代码中指定的话对应着SO_SNDBUF或者SO_RCVB
5808 0
|
数据处理 C# 索引
网络数据处理缓冲区和缓冲池实现
在编写网络应用的时候数据缓冲区是应该比较常用的方式,主要用构建一个内存区用于存储发送的数据和接收的数据;为了更好的利用已有数据缓冲区所以构造一个缓冲池来存放相关数据方便不同连接更好地利用缓冲区,节省不停的构造新的缓冲区所带的损耗问题。
868 0
|
存储 缓存 网络协议
Linux网络协议栈(二)——套接字缓存(socket buffer)
Linux网络核心数据结构是套接字缓存(socket buffer),简称skb。它代表一个要发送或处理的报文,并贯穿于整个协议栈。1、    套接字缓存skb由两部分组成:(1)    报文数据:它保存了实际在网络中传输的数据;(2)    管理数据:供内核处理报文的额外数据,这些数据构成了协议之间交换的控制信息。
1436 0