1. summary
本文以Omnigen项目(https://github.com/VectorSpaceLab/OmniGen)为例,对LLM训练过程中涉及与存储交互的部分在代码逻辑上做了梳理。整体分为模型文件加载侧以及训练数据加载侧两部分,训练数据除包含常规结构化的文本数据之外,又包含了图像相关数据的读写逻辑的梳理。整体包含了Python\Cpython\Rust\Cpp语言的safetensors\torch storage\torch dataset\torch dataloader\huggingface datasets\PIL\cpython open等实现逻辑。
2. 模型文件加载侧
2.1. 简介
当前项目涉及多个模态,与加载模型文件相关的部分共涉及如下4部分与存储交互
- 主模型
- vae编码器
- tokenizer
- lora
根据模型文件类型的不同,走了不同的加载方法和库
model.safetensors -> huggingface satetensor -> rust
model.pt -> torch -> python内置函数open -> cpython(torch.serialization._open_file.__init__ 本质上还是走了python内置方法open,我们将在builtins.open中做展开)
2.2. rust safetensors
https://github.com/huggingface/safetensors
safetensors._safetensors_rust.safe_open.__init__ 这个本质上走了huggingface safetensor的rust实现
bindings python function:
https://github.com/huggingface/safetensors/blob/7d5af853631628137a79341ddc5611d18a17f3fe/bindings/python/src/lib.rs#L709
这里可以看到该rust实现的逻辑,new这个关联函数会带着filename\framework(pt)\device(cpu)(后面两个都是enum对象,可以在源码中找到对应string的定义)去new一个Open实例,模型层的key以及对应的tensor均是Open实例(通过self.inner引用)的方法
lib.rs中的Open是实际处理打开文件以及内容获取:https://github.com/huggingface/safetensors/blob/7d5af853631628137a79341ddc5611d18a17f3fe/bindings/python/src/lib.rs#L399
Open结构体定义如下:
其中比较重要的几个逻辑
- std::fs::File::open打开filename的fd
- memmap2::{Mmap, MmapOptions}进行文件的虚拟内存映射,并设置为只读模式
- 本例中,frameworks为pytorch,这里storage会根据pytorch的版本是否大于1.11走torch.storage.UntypedStorage所继承基类torch.storage._StorageBase的from_file方法或者mmap(使用逻辑2中的buffer,这里不展开,基本上torch的版本目前已经2.3+ ~ 2.5),并在rust中通过std::sync::Arc封装确保线程安全。
get_tensor
- 获取metadata中指定key的tensor info,dtypes等,主要是offsets
- 根据data.offsets创建slice切片
- 通过torch的storage的__getitem__方法结合切片来获取tensor并使用asarray创建
- 进行dtype转换和reshape
这里torch框架的storage其实完成了mmap的实现
https://pytorch.org/docs/stable/storage.html#torch.UntypedStorage.from_file
所以这里基本就是处理一些异常,做兼容,剩下都是在调pytorch的python库来读weight到tensor里
2.3. torch.storage
通过代码结构我们可以看出,torch.UntypedStorage.from_file方法实际上是torch._C.StorageBase基类中的方法
https://github.com/pytorch/pytorch/blob/ea906a47c009df49971a6707b167ec20037a055a/torch/storage.py#L205
在Pytorch源码的csrc下我们可以找到Storage中from_file方法对应的cpp实现是THPStorage_fromFile,这里包含了解析入参,创建storage对象,以及包装返回python对象PyObject*
这里主要是at::MapAllocator::makeDataPtr这个方法对filename入参及其他判断的情况下做了mmap的存储映射
这里与存储相关比较核心的就是这个context对象,这个对象是通过初始化MapAllocator的构造方法得到的
MapAllocator::MapAllocator在代码中有重载的两个方法,最开始是带着filename调用的如下方法:
MapAllocator::MapAllocator(c10::string_view filename, int flags, size_t size)
该方法内部做了重载方法的调用
MapAllocator::MapAllocator(WithxFd, c10::string_view filename, int fd, int flags, size_t size)
在该方法的内部根据操作系统的种类(_WIN32)以及是否支持MMAP等做了一些校验和容错,核心还是通过mmap做的
3. 训练数据加载侧
3.1. torch.utils.data.dataloader.DataLoader
https://github.com/pytorch/pytorch/blob/main/torch/utils/data/dataloader.py
for data in dataloader: model.forward(data)
这个过程调用了
torch.utils.data.dataloader.DataLoader.__iter__生成迭代器。生成的对象的基类为:torch.utils.data.dataloader._BaseDataLoaderIter
torch.utils.data.dataloader.DataLoader._get_iterator。根据num_workers不同配置生成并初始化继承类,后续以_SingleProcessDataLoaderIter
为例,init方法的入参是dataloader对象
torch.utils.data.dataloader._SingleProcessDataLoaderIter.__init__中对基类进行了初始化,比较重要的是torch.utils.data.dataloader._DatasetKind.create_fetcher用于后面的数据获取。
create_fetcher有两个细节,会判断kind来决定Fecher类
- _DatasetKind.Map -> torch.utils.data._utils.fetch._MapDatasetFetcher
- _DatasetKind.Iterable -> torch.utils.data._utils.fetch._IterableDatasetFetcher
kind来自于DataLoader的attribute:_dataset_kind,主要判断依据如下图所示是看dataset的类型,本例中dataset由OmniGen.train_helper.data.DatasetFromJson实例化而来,继承自torch.utils.data.Dataset所以实例化的是_MapDatasetFetcher
Fetcher的fetch方法中的逻辑依赖_auto_collation做判断,这个属性来自DataLoader如下property的装饰方法
@property def _auto_collation(self): return self.batch_sampler is not None
从这里看应该是不为空的,Map的datasetkind会根据配置了的shuffle,初始化RandomSampler,配置batch_size就会被Batchsampler包装,间接导致_auto_collation为True。
通过DataLoader传入的dataset的__getitems__方法来获取数据,后面for的时候会来这里调用。
for循环遍历过程通过torch.utils.data.dataloader._BaseDataLoaderIter.__next__来获取数据,__next__方法中调用了_next_data方法获取数据,这个方法需要子类单独实现。
torch.utils.data.dataloader._SingleProcessDataLoaderIter._next_data中有比较重要的几部分
- 通过sampler获取index(torch.utils.data.dataloader._BaseDataLoaderIter._next_index)
- 通过index配合fecher的fetch方法,这里就进到dataset的逻辑了,我们下一章说
- 获取到数据
- pin memory的处理加速到GPU的数据传输(pin memory参考)
这里看下index获取的逻辑,
torch.utils.data.dataloader._BaseDataLoaderIter._next_index有如下调用逻辑:
def _next_index(self): return next(self._sampler_iter) # may raise StopIteration self._sampler_iter = iter(self._index_sampler) self._index_sampler = loader._index_sampler @property def _index_sampler(self): if self._auto_collation: return self.batch_sampler else: return self.sampler
这部分在_auto_collation的时候提到过这部分的类的实例化
self.batch_sampler 即 torch.utils.data.sampler.BatchSampler 封装了 torch.utils.data.sampler.SequentialSampler 或 torch.utils.data.sampler.RandomSampler
这三个类的实例化比较简单,通过next + iter调用得到batch的index,这里根据是否放弃最后不足一个batch的数据分成了两个逻辑
如果放弃最后一个batch的数据,这里就是暴力的next个batch_size个index,如果报错就跳出循环;
如果保留所有的数据,则会根据batch_size初始化特定长度的列表容器,通过for循环的方式不断的往容器里塞index->到达batch大小->重新初始化列表容器\idx_in_batch归零去yield batch index,最后一个不足一个batch的切片后返回。
SequentialSampler的实现:比较简单,直接就是range(n length dataset)
RandomSampler的实现:这里根据是否配置了有放回抽样,分别使用了torch.randint和torch.randperm来实现抽样
3.2. torch.utils.data.Dataset
https://github.com/pytorch/pytorch/blob/main/torch/utils/data/dataset.py
OmniGen.train_helper.data.DatasetFromJson继承自torch.utils.data.Dataset,初始化方法:
dataset = DatasetFromJson(json_file=args.json_file, image_path=args.image_path, processer=processor, image_transform=image_transform, max_input_length_limit=args.max_input_length_limit, condition_dropout_prob=args.condition_dropout_prob, keep_raw_resolution=args.keep_raw_resolution )
通过上面DataLoader代码的分析,我们知道了对于dataset总共就进行了两件事
- len(dataset) 获取数据集长度
- __getitem__获取某个index的数据
DatasetFromJson在Dataset基类的基础上实现了上述两个方法,这里主要看__getitem__的方法,这个方法调用了get_example方法获取数据
这里通过给定的index索引在data中取数据,除此之外还涉及了针对指令进行条件dropout成<cfg>的逻辑、图像加载与处理、多模态数据处理等逻辑
这里与存储交互涉及两部分
- 通过索引在hg的datasets库所加载的json文件中取数据,我们在下章展开
- 通过获取到的图像文件通过PIL读取图像文件,我们在下下章展开
3.3. datasets.load.load_dataset
https://github.com/huggingface/datasets
3.3.1. 上下游调用解析
OmniGen.train_helper.data.DatasetFromJson类中的self.data,这里使用的load_dataset方法是huggingface的官方库datasets内置(datasets.load.load_dataset),这里使用path(json) + data_files的入参方式来读取本地文件并封装,后续被torch的DataLoader通过
OmniGen.train_helper.data.DatasetFromJson.__getitem__ ->
OmniGen.train_helper.data.DatasetFromJson.get_example ->
self.data[index]来调用,实际上调用的是
datasets.arrow_dataset.Dataset类对象的__getitem__方法。
datasets是huggingface社区的开源的轻量级但功能强大的数据集库,社区500+ contributer,支持下载公共数据集、加载本地多格式的文件,支持对接torch\tf等主流深度学习框架,支持音视频文本等多模态数据,支持apache arrow的memory map读取海量数据集\高效的数据预处理及缓存机制。load_dataset方法同样功能强大,本文仅对访问存储部分做展开。
3.3.2. 关键源码分析
3.3.3. datasets.load.load_dataset
该方法中有如上三个方法比较重要
- builder_instance = load_dataset_builder(...) -> 初始化datasets.builder.DatasetBuilder
- builder_instance.download_and_prepare(...) 这里的builder对象就已经有数据集的大小等信息在info里了,读本地json的数据集的话相当于读了两遍
- return ds = builder_instance.as_dataset(...) -> datasets.arrow_dataset.Dataset mmap映射\读取数据、batch封装
3.3.3.1. load_dataset_builder
datasets.load.load_dataset_builder -> datasets.builder.DatasetBuilder
load_dataset的入参默认值均为None或者False,本地调用仅path='json',data_files为待读取数据。
这个方法中有如下三个比较重要的方法
- dataset_module_factory
- get_dataset_builder_class
- builder_cls()
3.3.3.1.1. dataset_module_factory
-- datasets.load.dataset_module_factory
默认download_config为空,初始化DownloadConfig,对download_config的一些参数做赋值,不过咱们这里是读本地json文件,后续也没用到
如果path(这里是json)在_PACKAGED_DATASETS_MODULES中的话,_PACKAGED_DATASETS_MODULES是一些基础的存储格式,像csv,json,pandas,arrow等,见上图,则去初始化datasets.load.PackagedDatasetModuleFactory 类,__init__方法中主要就是赋值和调用increase_load_count方法来做数据集下载的次数上报,网络问题会捕获exception直接pass。这里比较重要的是
datasets.load.PackagedDatasetModuleFactory.get_module方法
base_path获取当前工作路径
patterns 会返回一个字典 key为split(这里还有test和validation,入参是string的话默认只给一个train),value是对应的数据的路径
data_files这里递归调用DataFilesDict.from_patterns来解析所有的数据文件确保列表patterns中都是DataFilesList对象
这里比较主要的就是module_path,这里是通过查hash表,找到“json”对应的类的module_path:"json": (json.__name__, _hash_python_lines(inspect.getsource(json).splitlines())),是datasets.packaged_modules.json.json。
最后将上述信息用DatasetModule这个dataclass包装。
3.3.3.1.2. get_dataset_builder_class
datasets.load.get_dataset_builder_class -> {type} <class 'datasets.packaged_modules.json.json.Json'>
这里主要是这个方法比较重要
datasets.load.import_main_class,这里主要是通过datasets.packaged_modules.json.json这个module_path进行加载,并找到对应的DatasetBuilder的子类把对象吐出去,准备后面的init
3.3.3.1.3. builder_cls
datasets.packaged_modules.json.json.Json其中init方法在这里datasets.builder.DatasetBuilder.__init__
这里最开始没有local的cache dir,所以这里的逻辑会较为简单,只是初始化了cache_dir,fs等相关信息并创建好实例
3.3.3.2. download_and_prepare
datasets.builder.DatasetBuilder.download_and_prepare
方法的入参:
有如下几个逻辑,前面的一些比较简单
- 通过url_to_fs获取到文件系统的实例以及output dir
- download_config以及dl_manager的初始化
- 如果是本地文件系统,初始化output dir的父目录,并且对output dir创建锁,避免并行的本地操作
- 这里有个datasets.utils._filelock.FileLock 锁方法,会创建self._output_dir + "_builder.lock"锁文件并进行相关的操作
- DATASET_INFO_FILENAME = "dataset_info.json" 判断info文件是否存在
- has_sufficient_disk_space 判断底盘空间
- 确认 _dest的写出目录空间
- _check_manual_download 检查是否多模态数据下载
incomplete_dir方法有上下文管理的装饰器 @contextlib.contextmanager,在with语法块的加持下可以做到在is_local=True的环境下,创建名为「原文件名+.incomplete」的文件夹用来做后续的xxx,并通过yield返回回去,在再次获取到执行权限的时候删除老目录,重命名incomplete目录为老目录,最后还会尝试删除临时目录
temporary_assignment方法与incomplete_dir类似,只不过这个作用在当前对象,通过这个方法可以在with语法块中临时替换_output_dir
3.3.3.2.1. _download_and_prepare
datasets.builder.DatasetBuilder._download_and_prepare
_split_generators
该方法走的json继承类的方法(datasets.packaged_modules.json.json.Json._split_generators)
用datasets.splits.SplitGenerator封装下面方法:datasets.download.download_manager.DownloadManager.download_and_extract的返回
download_and_extract这个方法调用链路如下,主要逻辑是根据输入的data_files进行下载和提取本地原始文件或者缓存的path路径,内部包含了嵌套数据结构的递归处理、单\多线程下载以及计时、checksum校验,主要调用链路如下:
datasets.download.download_manager.DownloadManager.download_and_extract ->
datasets.download.download_manager.DownloadManager.extract ->
datasets.download.download_manager.DownloadManager.download ->
datasets.download.download_manager.DownloadManager._download_batched \ _download_single
这里会对 _split_generators 生成的对象做for loop的循环调用,这里单个item对应的就是train的数据集,即我们在最上层方法中配置的 /Users/adamsun/pai/OmniGen/OmniGen/toy_data/sunyf.jsonl
datasets.builder.ArrowBasedBuilder._prepare_split
主要的逻辑:
- max_shard_size 每个shard最大的大小,默认500MB
- split_info 切分信息获取,这里直接用了split_generator.split_info
- 根据默认的前缀SUFFIX,初始化fname和fpath,封装进_prepare_split_args
- 后续根据num_proc的不同取值区分了单独处理和并发处理(multiprocess.Pool)两种,最后执行的都是_prepare_split_single 这个方法
datasets.builder.ArrowBasedBuilder._prepare_split_single的逻辑:
- 这里重写了一个list的子类,看着是用来追踪当前iter的是什么,不知道再整个项目里是啥用处,先mark下,tracked_list
- _generate_tables 返回的是个迭代器对象,这里没有进到具体的实现方法,下面for loop的时候具体看,这里应该是涉及具体的读取数据->封装pyarrow的部分
- writer_class被设定为datasets.arrow_writer.ArrowWriter
- 这里embed_local_files有点意思,判断的是format是不是parquet,暂时没多看
- datasets.arrow_writer.ArrowWriter的初始化,这里主要replace了shardid和jobid,类的初始化里判断没有schema和features,初始化空字符串的hashkey,并通过指定filesystem的实现类(这里是本地)的open方法wb二进制打开了文件流准备写入
- 初始化计时,并进入循环,for _, table in generatoras_dataset
- 进到_generate_tables的方法,实现类是datasets.packaged_modules.json.json.Json,返回的是文件索引以及pyarrow.lib.Table.Table
- 这里会根据json数据的配置判断走哪个部分,如果config.field进行了配置的话会默认数据是一个大json里面有一个字段是存储数据的,这时候会通过pandas先读取,拿到这个字段的数据,再dump序列化,通过pandas+io stream的方式读回到内存,转换成pyarrow的Table对象,这部分的逻辑在每行一个json数据读取异常降级的逻辑一致
- 每行一个json数据的形式会通过python buildin open对象打开文件,每次读取一个trunk size,并通过io.stream的readline或者自定义的readline方法保证最后一行数据完整性,根据数据编码decode并encode成utf8,通过pyarrow.json直接读取bytesio,并在异常的时候回退到pandas处理
- 这里_generate_tables方法的yield会返回pa.Tables对象,并通过ArrowWriter写入到cache目录下,写入的过程中还会涉及一个按shard的切分逻辑writer._num_bytes > max_shard_size,方法就是判断超过了max_shard_size后重新开个新文件的writer
- 最后对上述读过的信息做汇总加和,返回total_num_examples, total_num_bytes, writer._features, num_shards, shard_lengths等相关信息,包含了数据总数,bytes数,字段数,shard数,每个shard的长度(数据量)
整体_download_and_prepare在读取完数据更新到pyarrow格式的文件后,update下info的信息就返回了,info的信息除了在_download_and_prepare的内部方法中有update之外,还在download_and_prepare的方法中通过self._save_info()方法调用datasets.info.DatasetInfo.write_to_directory写了dataset_info.json以及LICENSE两个元数据文件到本地
其实这里有些疑问,目前看好像读取本地数据都会在本地的cache目录下再写一份arrow格式的数据,这里以本机为例是~/.cache/huggingface/xxx/xxx的目录下,虽然所占空间有压缩,如果数据集足够大,理论上本地盘是肯定不够的,这里怎么处理?
这里做了相关的测试,oss json数据集190G,当前pod本地盘100G左右,因为huggingface datasets会读一遍然后写到~/.cache下变成arrow格式存储,这个格式占用空间会小
如果数据集足够大的情况下会出现osError,同时contextmanager的逻辑会将临时数据删除。
目前看这个部分的逻辑适用于小数据量的sft后训练,大数据量的预训练是有磁盘打满的风险的。
3.3.3.3. as_dataset
datasets.builder.DatasetBuilder.as_dataset -> datasets.arrow_dataset.Dataset
这里的外层逻辑比较简单,map_nested解嵌套,dpartial绑定部分参数到atasets.builder.DatasetBuilder._build_single_dataset方法mmap读取数据
datasets.builder.DatasetBuilder._as_dataset
用途:读取预处理的r数据集文件并且生成一个Dataset对象
初始化datasets.arrow_reader.ArrowReader.__init__对象,并用read方法读取数据,这里cache_dir\dataset_name都在DatasetBuilder的对象中有,前面处理过相关的info
read方法通过get_file_instructions获取到上面arrow的文件路径+文件名,通过datasets.arrow_reader.BaseReader.read_files方法读取数据
read_files通过调用datasets.arrow_reader.BaseReader._read_files读取数据,封装为pyarrow Table,同时将info和split的相关信息整合返回
datasets.arrow_reader.BaseReader._read_files使用thread_map方法(concurrent.futures.ThreadPoolExecutor的实现)来进行并发读取,对应的fn是datasets.arrow_reader.BaseReader._get_table_from_filename
datasets.arrow_reader.BaseReader._get_table_from_filename有两个实现,当前我们在读取的arrow类型的文件,所以实现类是datasets.arrow_reader.ArrowReader._get_table_from_filename
这里主要是调用datasets.arrow_reader.ArrowReader.read_table方法对数据进行mmap读取
in_memory是False的情况下,这里走的是datasets.table.MemoryMappedTable类的类方法from_file
调用链路如下
datasets.table._memory_mapped_arrow_table_from_file ->
datasets.table._memory_mapped_record_batch_reader_from_file
pyarrow.lib.memory_map
pyarrow.lib.RecordBatchReader.RecordBatchReader.read_all
通过datasets.arrow_dataset.Dataset封装后返回
3.4. PIL.Image.open
https://github.com/python-pillow/Pillow/
图片这里直接通过buitins的open方法打开文件,移到文件头并读取16个bytes用于后面的是否是指定文件类型校验(_accept方法)。
同时preinit()方法导入了5中图片格式,并在导入的过程中通过register_open等方法注册不同图片类型的解析类以及识别方法accept到OPEN对象(下面的方法会用到)中
如JPEG对应的工厂方法和_accept方法如下所示
'JPEG': (<function PIL.JpegImagePlugin.jpeg_factory(fp: 'IO[bytes]', filename: 'str | bytes | None' = None) -> 'JpegImageFile | MpoImageFile'>, <function PIL.JpegImagePlugin._accept(prefix: 'bytes') -> 'bool'>)
PIL.JpegImagePlugin.jpeg_factory 主要逻辑在于带着fp初始化JpegImageFile,JpegImageFile继承自PIL.ImageFile.ImageFile,初始化会调用_open方法读取数据,里面图片相关技术点涉及过多,这里了解不多不展开讲,从与存储交互的角度来讲,都是对open后的BufferedReader对象进行read。这里都是builtins function open的逻辑,这部分是CPython进行的实现,我们讲在下一章做下学习。
3.5. cpython builtins.open\io.open
python的open方法在_io.py中是没有实现的,只写了用法
通过type我们可以看出open是个builtin function,这部分要看逻辑需要看cpython的代码了
以cpython 3.10这个branch的代码为例
(base) adamsun@B-30DUQ05P-0039 cpython % git status On branch 3.10 Your branch is up to date with 'origin/3.10'. nothing to commit, working tree clean
简单说下cpython项目源码的结构
通过线程抓取我们可以看到如下的读文件调用栈(采样频率导致的可能不全)
- io.open
- _io.TextIOWrapper.read
- _io.TextIOWrapper.__exit__
带着如上的宏观了解和基础知识我们到cpython的源码中看一下相关的部分
Lib._pyio.open是具体的实现方法,在
Python/pylifecycle.c:2277(init_set_builtins_open)方法中被从io module绑定到builtins module中
io.open初始化:
Lib._pyio.FileIO # 这里有os.open(),在没有自定义实现opener的情况下走os.open()方法获取文件描述符fd
Lib._pyio.BufferedReader.__init__
Lib._pyio.TextIOWrapper.__init__
返回TextIOWrapper对象。
读取数据,对应python代码中 open对象的read()方法,读取的时候这边没有传参,默认size是None后改为-1,这里会带着decoder直接调用buffer的read方法
Lib._pyio.BufferedReader.read 会在内置线程锁的实现下,调用Lib._pyio.BufferedReader._read_unlocked,如果raw有readall方法则直接调用,如果没有的话则会死循环用read方法读完拼接成Bytes返回
Lib._pyio.FileIO.readall。这里我们可以看出同时也封装了read方法,在未指定size的时候使用readall读取数据,readall中也是通过buffer缓冲,以os.read()分chunk读取,拼接到result中返回数据。
3.6. cpython os.open
通过python的builtins.open的实现逻辑我们可以看出,底层最后调用的是os库的open以及read方法。
同样在cpython的Modules中我们可以找到io.open和io.read的实现
Modules/posixmodule.c:9166 os_open_impl
这个方法主要是在不同平台和不同的env的条件下调用操作系统的_wopen(windows)\openat\open来打开文件返回文件描述符fd
Modules/posixmodule.c:9492 os_read_impl
这里主要逻辑就是初始化buffer以及通过_Py_read方法读数据并return buffer
Python/fileutils.c:1722 _Py_read
这里是封装方法,实际还是调用操作系统的read方法并在释放GIL的情况下高效读取数据