Caffe::Snapshot的运行过程

简介: Snapshot的存储概述Snapshot的存储格式有两种,分别是BINARYPROTO格式和hdf5格式。BINARYPROTO是一种二进制文件,并且可以通过修改shapshot_format来设置存储类型。

Snapshot的存储

概述

Snapshot的存储格式有两种,分别是BINARYPROTO格式和hdf5格式。BINARYPROTO是一种二进制文件,并且可以通过修改shapshot_format来设置存储类型。该项的默认是BINARYPROTO不管哪种格式,运行的过程是类似的,都是从Solver<Dtype>::Snapshot()函数进入,首先调用Net网络的方法,再操作网络中的每一层,最后再操作每一层中blob,最后调用write函数写入输出。源码入口:

 1 void Solver<Dtype>::Snapshot() {
 2   CHECK(Caffe::root_solver());
 3   string model_filename;
 4   switch (param_.snapshot_format()) {
 5   case caffe::SolverParameter_SnapshotFormat_BINARYPROTO:
 6     model_filename = SnapshotToBinaryProto();
 7     break;
 8   case caffe::SolverParameter_SnapshotFormat_HDF5:
 9     model_filename = SnapshotToHDF5();
10     break;
11   default:
12     LOG(FATAL) << "Unsupported snapshot format.";
13   }

 

BINARYPROTO格式

如果是BINARYPROTO的存储格式,就执行如下代码:

1 string Solver<Dtype>::SnapshotToBinaryProto() {
2   string model_filename = SnapshotFilename(".caffemodel");
3   LOG(INFO) << "Snapshotting to binary proto file " << model_filename;
4   NetParameter net_param;
5   net_->ToProto(&net_param, param_.snapshot_diff());
6   WriteProtoToBinaryFile(net_param, model_filename);
7   return model_filename;
8 }   

 

首先会执行SnapshotFilename(“.caffemodel”)函数,识别出sovler.prototxt文件中snapshot_prefix的内容,作用该snapshot文件的文件名前缀。然后调用net_->ToProto(),具体的代码如下:

 1 void Net<Dtype>::ToProto(NetParameter* param, bool write_diff) const {
 2   param->Clear();
 3   param->set_name(name_);
 4   for (int i = 0; i < net_input_blob_indices_.size(); ++i) {
 5     param->add_input(blob_names_[net_input_blob_indices_[i]]);
 6   }
 7   for (int i = 0; i < layers_.size(); ++i) {
 8     LayerParameter* layer_param = param->add_layer();
 9     layers_[i]->ToProto(layer_param, write_diff);
10   }
11 }  

 

获取到网络中的每层的名字等参数后,调用layers_[i]->ToProto()每一层的ToProto方法,接下来

1 void Layer<Dtype>::ToProto(LayerParameter* param, bool write_diff) {
2   param->Clear();
3   param->CopyFrom(layer_param_);
4   param->clear_blobs();
5   for (int i = 0; i < blobs_.size(); ++i) {
6     blobs_[i]->ToProto(param->add_blobs(), write_diff);
7   }
8 } 

然后调用当前层下的所有blobToProto方法,即:

 1 void Blob<double>::ToProto(BlobProto* proto, bool write_diff) const {
 2   proto->clear_shape();
 3   for (int i = 0; i < shape_.size(); ++i) {
 4     proto->mutable_shape()->add_dim(shape_[i]);
 5   }
 6   proto->clear_double_data();
 7   proto->clear_double_diff();
 8   const double* data_vec = cpu_data();
 9   for (int i = 0; i < count_; ++i) {
10     proto->add_double_data(data_vec[i]);
11   }
12   if (write_diff) {
13     const double* diff_vec = cpu_diff();
14     for (int i = 0; i < count_; ++i) {
15       proto->add_double_diff(diff_vec[i]);
16     }
17   }

 

在每一个blob中,会调用add_double_data()函数,把data添加到snapshot文件中,同时会判断是否当前blob参与diff的计算,如果需要当前blob需要diff参数,就调用add_double_diff()添加到snapshot文件中。

调用完所有的blobToProto()方法后,会执行WriteProtoToBinaryFile()把该文件写出即可。

1 void WriteProtoToBinaryFile(const Message& proto, const char* filename) {
2   fstream output(filename, ios::out | ios::trunc | ios::binary);
3   CHECK(proto.SerializeToOstream(&output));
4 }

在该方法里调用FStreamoutput方法进行输出。

Hdf5格式

Hdf5格式的运行过程和BINARYPROTO格式的过程类似,首先会调用SnapshotToHDF5()函数,即:

1 string Solver<Dtype>::SnapshotToHDF5() {
2   string model_filename = SnapshotFilename(".caffemodel.h5");
3   LOG(INFO) << "Snapshotting to HDF5 file " << model_filename;
4   net_->ToHDF5(model_filename, param_.snapshot_diff());
5   return model_filename;
6 }

首先会执行SnapshotFilename(“.caffemodel.h5”)函数,识别出sovler.prototxt文件中snapshot_prefix的内容,作用该snapshot文件的文件名前缀。然后调用net_->ToHDF5(),即:

 1 void Net<Dtype>::ToHDF5(const string& filename, bool write_diff) const {
 2   hid_t file_hid = H5Fcreate(filename.c_str(), H5F_ACC_TRUNC, H5P_DEFAULT,
 3       H5P_DEFAULT);
 4   hid_t data_hid = H5Gcreate2(file_hid, "data", H5P_DEFAULT, H5P_DEFAULT,
 5       H5P_DEFAULT);
 6     hid_t diff_hid = -1;
 7   if (write_diff) {
 8     diff_hid = H5Gcreate2(file_hid, "diff", H5P_DEFAULT, H5P_DEFAULT,
 9         H5P_DEFAULT);
10    }
11   for (int layer_id = 0; layer_id < layers_.size(); ++layer_id) {
12     const LayerParameter& layer_param = layers_[layer_id]->layer_param();
13     string layer_name = layer_param.name();
14     hid_t layer_data_hid = H5Gcreate2(data_hid, layer_name.c_str(),
15     hid_t layer_diff_hid = -1;
16     if (write_diff) {
17       layer_diff_hid = H5Gcreate2(diff_hid, layer_name.c_str(),
18           H5P_DEFAULT, H5P_DEFAULT, H5P_DEFAULT);  
19  }
20     int num_params = layers_[layer_id]->blobs().size();
21     for (int param_id = 0; param_id < num_params; ++param_id) {
22       ostringstream dataset_name;
23       dataset_name << param_id;
24       const int net_param_id = param_id_vecs_[layer_id][param_id];
25       if (param_owners_[net_param_id] == -1) {
26         hdf5_save_nd_dataset<Dtype>(layer_data_hid, dataset_name.str(),
27             *params_[net_param_id]);
28       }
29       if (write_diff) {
30         hdf5_save_nd_dataset<Dtype>(layer_diff_hid, dataset_name.str(),
31             *params_[net_param_id], true);
32       }
33 ...............
34 H5Fclose(file_hid);
35 }

该函数首先调用H5Fcreate()创建一个file文件,然后循环调用每一层,通过调用每一层的H5Gcreate2函数记录出该层的data_hid或者diff_hid(如果该层需要参与计算),然后进入每一层内部的blob,然后在当前blob内调用hdf5_save_nd_dataset()hdf5_save_nd_dataset()(如果当前blob需要参与计算diff),将data添加到hdf5格式的文件中,最后调用H5Fclose(file_hid)函数,输出该文件。

 

Snapshot的恢复

概述

想在已经训练好的网络上继续训练,那么需要调用Restore()方法从snapshot的文件中恢复成网络,从而缩短了训练时间。方法的入口是Solver<Dtype>::Restore(const char* state_file)函数,即:

1 void Solver<Dtype>::Restore(const char* state_file) {
2   CHECK(Caffe::root_solver());
3   string state_filename(state_file);
4   if (state_filename.size() >= 3 &&
5       state_filename.compare(state_filename.size() - 3, 3, ".h5") == 0) {
6     RestoreSolverStateFromHDF5(state_filename);
7   } else {
8     RestoreSolverStateFromBinaryProto(state_filename);
9   }

该函数会解析snapshot文件是BINARYPROTO格式还是Hdf5格式,如果是BINARYPROTO格式的话就调用RestoreSolverStateFromBinaryProto()函数,如果格式Hdf5的格式,就执行RestoreSolverStateFromHDF5()

BINARYPROOTO格式

如果是BINARYPROTO格式,则执行下列代码:

 

 1 void SGDSolver<Dtype>::RestoreSolverStateFromBinaryProto(
 2     const string& state_file) {
 3   SolverState state;
 4   ReadProtoFromBinaryFile(state_file, &state);
 5   this->iter_ = state.iter();
 6   if (state.has_learned_net()) {
 7     NetParameter net_param;
 8     ReadNetParamsFromBinaryFileOrDie(state.learned_net().c_str(), &net_param);
 9     this->net_->CopyTrainedLayersFrom(net_param);
10   }
11   this->current_step_ = state.current_step();
12   CHECK_EQ(state.history_size(), history_.size())
13       << "Incorrect length of history blobs.";
14   for (int i = 0; i < history_.size(); ++i) {
15     history_[i]->FromProto(state.history(i));
16   }
17 }

 

该函数会大量调用googleprotobuf包内的函数,首先会通过ReadProtoFromBinaryFile()函数读取BINARYPROTO格式的文件来返回是否可以成功读取。然后判断该snapshot是否有曾经训练过的网络,如果有,则调用函数ReadNetParamsFromBinaryFileOrDie()读取出该Net网络,然后调用函数CopyTrainedLayersFrom(net_param)具体恢复该网络的每一层以及当前层内的所有blob,具体数据恢复的工作就是CopyTrainedLayersFrom()函数内部变量调用FromProto()函数来实现blob复制的。然后会通过函数current_step()来判断上次训练的位置(迭代到多少次),然后通过循环把训练过的data数据通过FromProto()完成数据的复制。

Hdf5格式

 1 void SGDSolver<Dtype>::RestoreSolverStateFromHDF5(const string& state_file) {
 2   hid_t file_hid = H5Fopen(state_file.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT);
 3   CHECK_GE(file_hid, 0) << "Couldn't open solver state file " << state_file;
 4   this->iter_ = hdf5_load_int(file_hid, "iter");
 5   if (H5LTfind_dataset(file_hid, "learned_net")) {
 6     string learned_net = hdf5_load_string(file_hid, "learned_net");
 7     this->net_->CopyTrainedLayersFrom(learned_net);
 8   }
 9   this->current_step_ = hdf5_load_int(file_hid, "current_step");
10   hid_t history_hid = H5Gopen2(file_hid, "history", H5P_DEFAULT);
11   CHECK_GE(history_hid, 0) << "Error reading history from " << state_file;
12   int state_history_size = hdf5_get_num_links(history_hid);
13   CHECK_EQ(state_history_size, history_.size())
14       << "Incorrect length of history blobs.";
15   for (int i = 0; i < history_.size(); ++i) {
16     ostringstream oss;
17     oss << i;
18     hdf5_load_nd_dataset<Dtype>(history_hid, oss.str().c_str(), 0,
19                                 kMaxBlobAxes, history_[i].get());
20   }
21   H5Gclose(history_hid);
22   H5Fclose(file_hid);
23 }

该函数会识别hdf5格式存储的snapshot文件的file_hid编号,会判断是否存在之前训练过的网络,如果存在则执行CopyTrainedLayersFrom()函数,完成网络的每层以及每层内的blob的数据的恢复复制,然后或取上一次的训练位置(进行的迭代),并且调用函数hdf5_load_nd_dataset()具体把每次迭代的数据恢复复制,最后再调用H5Fclose()关闭。

 

 

 

 

 

当神已无能为力,那便是魔渡众生
目录
相关文章
|
9月前
|
自然语言处理 监控 安全
SmolLM2:多阶段训练策略优化和高质量数据集,小型语言模型同样可以实现卓越的性能表现
SmolLM2 通过创新的多阶段训练策略、高质量数据集的构建与优化,以及精细的模型后训练调优,在 1.7B 参数规模下实现了卓越的性能表现,并在多个基准测试中超越了同等规模甚至更大规模的语言模型。
342 73
SmolLM2:多阶段训练策略优化和高质量数据集,小型语言模型同样可以实现卓越的性能表现
|
9月前
|
JavaScript 编译器 开发工具
【02】鸿蒙实战应用开发-华为鸿蒙纯血操作系统Harmony OS NEXT-项目开发实战-准备工具安装-编译器DevEco Studio安装-arkts编程语言认识-编译器devco-鸿蒙SDK安装-模拟器环境调试-hyper虚拟化开启-全过程实战项目分享-从零开发到上线-优雅草卓伊凡
【02】鸿蒙实战应用开发-华为鸿蒙纯血操作系统Harmony OS NEXT-项目开发实战-准备工具安装-编译器DevEco Studio安装-arkts编程语言认识-编译器devco-鸿蒙SDK安装-模拟器环境调试-hyper虚拟化开启-全过程实战项目分享-从零开发到上线-优雅草卓伊凡
487 2
【02】鸿蒙实战应用开发-华为鸿蒙纯血操作系统Harmony OS NEXT-项目开发实战-准备工具安装-编译器DevEco Studio安装-arkts编程语言认识-编译器devco-鸿蒙SDK安装-模拟器环境调试-hyper虚拟化开启-全过程实战项目分享-从零开发到上线-优雅草卓伊凡
|
存储 弹性计算 固态存储
阿里云服务器租用价格参考,2核8G、4核16G、8核32G最新收费标准
阿里云服务器2核8G、4核16G、8核32G配置租用价格参考,2024年阿里云产品再一次降价,降价之后2核8G配置按量收费最低收费标准为0.3375元/小时,按月租用标准收费标准为136.0元/1个月。4核16G配置的阿里云服务器按量收费标准最低为0.675元/小时,按月租用标准收费标准为272.0元/1个月。8核32G配置的阿里云服务器按量收费标准最低为1.35元/小时,按月租用标准收费标准为544.0元/1个月。云服务器实例规格的地域和实例规格不同,收费标准不一样,下面是2024年阿里云服务器2核8G、4核16G、8核32G配置的最新租用收费标准。
阿里云服务器租用价格参考,2核8G、4核16G、8核32G最新收费标准
|
11月前
|
算法 Java 数据库
理解CAS算法原理
CAS(Compare and Swap,比较并交换)是一种无锁算法,用于实现多线程环境下的原子操作。它通过比较内存中的值与预期值是否相同来决定是否进行更新。JDK 5引入了基于CAS的乐观锁机制,替代了传统的synchronized独占锁,提升了并发性能。然而,CAS存在ABA问题、循环时间长开销大和只能保证单个共享变量原子性等缺点。为解决这些问题,可以使用版本号机制、合并多个变量或引入pause指令优化CPU执行效率。CAS广泛应用于JDK的原子类中,如AtomicInteger.incrementAndGet(),利用底层Unsafe库实现高效的无锁自增操作。
464 0
理解CAS算法原理
|
Java 程序员 图形学
程序员教你用代码制作飞翔的小鸟--Java小游戏,正好拿去和给女神一起玩
《飞扬的小鸟》Java实现摘要:使用IntelliJ IDEA和JDK 16开发,包含小鸟类`Bird`,处理小鸟的位置、速度和碰撞检测。代码示例展示小鸟图像的加载、绘制与旋转。同时有`Music`类用于循环播放背景音乐。游戏运行时检查小鸟是否撞到地面、柱子或星星,并实现翅膀煽动效果。简单易懂,可直接复制使用。
415 0
|
Java Linux Maven
设置 Maven 环境变量
配置Maven环境变量涉及Windows、Linux和Mac。在Windows上,需新建系统变量`MAVEN_HOME`,值为Maven安装路径,编辑`Path`添加`%MAVEN_HOME%\bin`。在Linux中,下载解压Maven后移动到`/usr/local/`,编辑`/etc/profile`添加`MAVEN_HOME`和`PATH`。在Mac上,类似Linux操作,下载解压后移动到`/usr/local/`,编辑`/etc/profile`。最后,通过`mvn -v`检查是否安装成功。
|
搜索推荐 C#
一个适用于定制个性化界面的WPF UI组件库
一个适用于定制个性化界面的WPF UI组件库
280 0
|
网络协议 Linux 文件存储
Linux系统使用Docker搭建Traefik结合内网穿透实现公网访问管理界面
Linux系统使用Docker搭建Traefik结合内网穿透实现公网访问管理界面
274 2
|
前端开发 UED
产品入门第四讲:Axure动态面板
产品入门第四讲:Axure动态面板
323 0