程序与技术分享:Caffe中Solver解析

本文涉及的产品
云解析DNS,个人版 1个月
全局流量管理 GTM,标准版 1个月
公共DNS(含HTTPDNS解析),每月1000万次HTTP解析
简介: 程序与技术分享:Caffe中Solver解析

Caffe中Solver解析


1.Solver的初始化


shared_ptr

solver(caffe::SolverRegistry::CreateSolver(solver_param));


caffe.cpp中的train函数中通过上述的代码定义了一个指向Solver的shared_ptr。其中主要是通过调用SolverRegistry这个类的静态成员函数CreateSolver得到一个指向Solver的指针来构造shared_ptr类型的solver。而且由于C++多态的特性,solver是一个指向基类Solver类型的指针,通过solver这个智能指针来调用各个成员函数会调用到各个子类(SGDSolver等)的函数。


具体步骤:


(1)SolverRegistry::CreateSolver(solver_param)。


(2)通过static的gregistry【type】获取type对应的Solver的Creator函数指针。


(3)调用Creator函数。


(4)new SGDSolver(solver_param)创建solver。


SolverRegistry类源码:


class SolverRegistry


{


public:


typedef Solver (Creator)(const SolverParameter&);


typedef std::map CreatorRegistry;


static CreatorRegistry& Registry()


{


static CreatorRegistry gregistry = new CreatorRegistry();


return gregistry;


}


static void AddCreator(const string& type, Creator creator)


{


CreatorRegistry& registry = Registry();


CHECK_EQ(registry.count(type), 0)


[ "Solver type " [ type [ " already registered.";


registry【type】 = creator;


}


static Solver CreateSolver(const SolverParameter& param)


{


const string& type = param.type();


CreatorRegistry& registry = Registry();


CHECK_EQ(registry.count(type), 1) [ "Unknown solver type: " [ type


[ " (known types: " [ SolverTypeListString() [ ")";


return registry【type】(param);


}


static vector SolverTypeList()


{


CreatorRegistry& registry = Registry();


vector solver_types;


for (typename CreatorRegistry::iterator iter = registry.begin();iter != registry.end(); ++iter)


{


solver_types.push_back(iter->first);


}


return solver_types;


}


private:


SolverRegistry() {}


static string SolverTypeListString()


{


vector solver_types = SolverTypeList();


string solver_types_str;


for (vector::iterator iter = solver_types.begin();iter != solver_types.end(); ++iter)


{


if (iter != solver_types.begin())


{


solver_types_str += ", ";


}


solver_types_str += iter;


}


return solver_types_str;


}


};


SolverRegistry类的构造函数是private的,也就是用我们没有办法去构造一个这个类的变量,这个类也没有数据成员,所有的成员函数也都是static的,可以直接调用。


CreateSolver函数先定义了string类型的变量type,表示Solver的类型,然后定义了一个key类型为string,value类型为Creator的map,变量名为registry,其中Creator是一个函数指针类型,指向的函数的参数为SolverParameter类型,返回类型为Solver。如果是一个已经register过的Solver类型,那么registry.count(type)应该为1,然后通过registry这个map返回了我们需要类型的Solver的creator,并调用这个creator函数,将creator返回的Solver返回。


Registry函数中定义了一个static的变量g_registry,这个变量是一个指向CreatorRegistry这个map类型的指针,然后直接返回,因为这个变量是static的,所以即使多次调用这个函数,也只会定义一个g_registry,而且在其他地方修改这个map里的内容,。事实上各个Solver的register的过程正是向g_registry指向的那个map里添加以Solver的type为key,对应的Creator函数指针为value的内容。


Register的具体步骤:


//代码效果参考:http://hnjlyzjd.com/xl/wz_24215.html

(1)Registry_Solver_Class(SGD)。

(2)定义Creator函数,Registry_Solver_Creator。


(3)定义SolverRegistry类型的static变量,定义SolverRegistry类型的static变量。


(4)SolverRegistry::AddCreator将定义的Creator函数指针添加到static的变量gregistry(map)中。


SolverRegisterer源码:


template


class SolverRegisterer {


public:


SolverRegisterer(const string& type,


Solver (creator)(const SolverParameter&));


};


#define REGISTER_SOLVER_CREATOR(type, creator) \


static SolverRegisterer[span class="hljs-title">floatfloat

static SolverRegisterer[span class="hljs-title">doubledouble

#define REGISTER_SOLVERCLASS(type) \


template \


Solver[span class="hljs-title">Dtype</span]* Creator##type##Solver( \


const SolverParameter& param) \


{ \


return new type##Solver(param); \


} \


REGISTER_SOLVERCREATOR(type, Creator##type##Solver)


}


#endif


在sgd_solver.cpp文件末尾有REGISTER_SOLVER_CLASS(SGD),使用REGISTER_SOLVER_CLASS宏定义一个名为Creator_SGDSolver的函数,即为Creator类型的指针函数,在Creator_SGDSolver函数中调用了SGDSolver的构造函数,并返回所构造的指针变量。Creator类型的指针函数的作用:构造一个对应类型的Solver对象,将其指针返回,然后在REGISTER_SOLVER_CLASS宏里又调用了REGISTER_SOLVER_CREATOR宏,该宏调用相对应(分别定义了SolverRegisterer类模板的float和double类型的static变量)的构造函数。在SolverRegisterer的构造函数中调用了SolverRegistry类的AddCreator函数,其功能将刚才定义的Creator_SGDSolver函数的指针存到g_registry所指向的map中。类似地,所有的Solver对应的cpp文件的末尾都调用了REGISTER_SOLVER_CLASS宏来完成注册,在所有的Solver都注册之后就可以通过g_registry得到对应的Creator函数的指针,并通过调用这个Creator函数来构造对应的Solver。


2.SIGINT和SIGHUP信号的处理


Caffe在train或者test的过程中都有可能会遇到系统信号(用户按下ctrl+c或者关掉了控制的terminal),可以通过对sigint_effect和sighup_effect来设置遇到系统信号的时候希望进行的处理方式:


caffe train –solver=/path/to/solver.prototxt –sigint_effect=EFFECT –sighup_effect=EFFECT


在caffe.cpp中定义了一个GetRequesedAction函数来将设置的string类型的标志转变为枚举类型的变量:


caffe::SolverAction::Enum GetRequestedAction(const std::string& flag_value)


{


if (flag_value == "stop")


{


return caffe::SolverAction::STOP;


}


if (flag_value == "snapshot")


{


return caffe::SolverAction::SNAPSHOT;


}


if (flag_value == "none")


{


return caffe::SolverAction::NONE;


}


LOG(FATAL) [ "Invalid signal effect \""[ flag_value [ "\" was specified";


}


// SolverAction::Enum的定义


namespace SolverAction


{


enum Enum


{


NONE = 0, // Take no special action.


STOP = 1, // Stop training. snapshot_after_train controls whether a


// snapshot is created.


SNAPSHOT = 2 // Take a snapshot, and keep training.


};


}


其中SolverAction::Enum的定义在solver.hpp中,这是一个定义为枚举类型的数据类型,只有三个可能的值,分别对应了三种处理系统信号的方式:NONE(忽略信号什么都不做)/STOP(停止训练)/SNAPSHOT(保存当前的训练状态,继续训练)。在caffe.cpp中的train函数里Solver设置如何处理系统信号的代码为:


caffe::SignalHandler signal_handler(


GetRequestedAction(FLAGS_sigint_effect),


GetRequestedAction(FLAGS_sighup_effect)


);


solver->SetActionFunction(signal_handler.GetActionFunction());


通过gflags定义和解析的两个Command Line Interface的输入参数,FLAGS_sigint_effect和FLAGS_sighup_effect分别对应遇到sigint和sighup信号的处理方式,如果用户不设定,sigint的默认值为stop,sighup的默认值为snapshot。GetRequestedAction函数会将string类型的FLAGS_xx转为SolverAction::Enum类型,并用来定义一个SignalHandler类型的对象signal_handler。这部分代码都依赖于SignalHandler这个类的接口:


// header file


class SignalHandler


{


public:


// Contructor. Specify what action to take when a signal is received.


SignalHandler(SolverAction::Enum SIGINT_action,


SolverAction::Enum SIGHUP_action);


~SignalHandler();


ActionCallback GetActionFunction();


private:


SolverAction::Enum CheckForSignals() const;


SolverAction::Enum SIGINTaction;


SolverAction::Enum SIGHUPaction;


};


// source file


SignalHandler::SignalHandler(SolverAction::Enum SIGINT_action,


SolverAction::Enum SIGHUP_action):


SIGINTaction(SIGINT_action),SIGHUPaction(SIGHUP_action)


{


HookupHandler();


}


void HookupHandler()


{


if (already_hooked_up)


{


LOG(FATAL) [ "Tried to hookup signal handlers more than once.";


}


already_hooked_up = true;


struct sigaction sa;


sa.sa_handler = &handle_signal;


// ...


}


static volatile sig_atomic_t got_sigint = false;


static volatile sig_atomic_t got_sighup = false;


void handle_signal(int signal)


{


switch (signal)


{


case SIGHUP:


got_sighup = true;


break;


case SIGINT:


got_sigint = true;


break;


}


}


ActionCallback SignalHandler::GetActionFunction()


{


return boost::bind(&SignalHandler::CheckForSignals, this);


}


SolverAction::Enum SignalHandler::CheckForSignals() const


{


if (GotSIGHUP()) { return SIGHUPaction;}


if (GotSIGINT()) { return SIGINTaction;}


return SolverAction::NONE;


}


bool GotSIGINT()


{


bool result = got_sigint;


got_sigint = false;


return result;


}


bool GotSIGHUP()


{


bool result = got_sighup;


got_sighup = false;


return result;


}


// ActionCallback的含义


typedef boost::function ActionCallback;


SignalHandler类有两个数据成员,都是SolverAction::Enum类型的,分别对应sigint和sighup信号,在构造函数中,用解析FLAGS_xx得到的结果分别给两个成员赋值,然后调用了HookupHandler函数,这个函数的主要作用是定义了一个sigaction类型(应该是系统级别的代码)的对象sa,然后通过sa.sa_handler= &handle_signal来设置,当有遇到系统信号时,调用handle_signal函数来处理,即判断一下当前的信号是什么类型,如果是sigint就将全局的static变量got_sigint变为true,sighup的处理类似。


在根据用户设置(或者默认值)的参数定义了signal_handler之后,solver通过SetActionFunction来设置了如何处理系统信号。这个函数的输入为signal_handler的GetActionFunction的返回值,根据上述的代码可以看到,GetActionFunction会返回signal_handler对象的CheckForSignals函数的地址(boost::bind的具体使用请参考boost官方文档)。而在Solver的SetActionFunction函数中只是简单的把Solver的一个成员action_requestfunction赋值为输入参数的值,以当前的例子来说就是,solver对象的action_requestfunction指向了signal_handler对象的CheckForSignals函数的地址。其中的ActionCallback是一个函数指针类型,指向了参数为空,返回值为SolverAction::Enum类型的函数(boost::function具体用法参考官方文档)。


总之,通过定义一个SignalHandler类型的对象,告知系统在遇到系统信号的时候回调handle_signal函数来改变全局变量got_sigint和got_sighup的值,然后通过Solver的接口设置了其遇到系统函数将调用signal_handler的Check函数,实际上就是去判断当前是否遇到了系统信号,如果遇到某个类型的信号,就返回设置的处理方式(SolverAction::Enum类型)。


3.Solver::Solve()具体实现


Solver::Solve源码分析:


void Solver::Solve(const char resume_file)


{


// 检查当前是否是root_solver(多GPU模式下,只有root_solver才运行这一部分的代码)


CHECK(Caffe::rootsolver());


// 输出learning policy(更新学习率的策略)


LOG(INFO) [ "Solving " [ net->name();


LOG(INFO) [ "Learning Rate Policy: " [ param_.lr_policy();


// requested_earlyexit初始值为false,此时不要求在优化结束前退出


requested_earlyexit = false;


// 判断指针resume_file是否NULL,如果不是则从resume_file存储的路径里读取之前训练的状态


if (resume_file)


{


LOG(INFO) [ "Restoring previous solver status from " [ resume_file;


Restore(resumefile);


}


// 调用了Step函数,其执行了实际的逐步的迭代过程


Step(param.maxiter() - iter);


// 迭代结束或者遇到系统信号提前结束后,判断是否需要在训练结束之后snapshot,可在solver.prototxt里设置


if (param_.snapshot_aftertrain() && (!param.snapshot() || iter % param.snapshot() != 0))


{ Snapshot(); }


// 如果在Step函数的迭代过程中遇到了系统信号,且处理方式设置为STOP,requested_earlyexit会被修改为true,迭代提前结束,并输出相关信息


if (requested_earlyexit)


{


LOG(INFO) [ "Optimization stopped early.";


return;


}


// 判断是否需要输出最后的loss


if (param.display() && iter % param.display() == 0)


{


Dtype loss;


net->ForwardPrefilled(&loss);


LOG(INFO) [ "Iteration " [ iter [ ", loss = " [ loss;


}


// 判断是否需要最后Test


if (param.testinterval() && iter % param_.test_interval() == 0) {


TestAll();


}


LOG(INFO) [ "Optimization Done.";


}


Solver::Step函数源码解析:


template [span class="hljs-keyword">typename Dtype>


void Solver::Step(int iters)


{


vector<Blob> bottomvec;


// 设置开始的迭代次数(如果是从之前的snapshot恢复的,那iter等于snapshot时的迭代次数)和结束的迭代次数


const int startiter = iter;


const int stopiter = iter + iters;


// 输出的loss为前average_loss次loss的平均值,在solver.prototxt里设置,默认为1,


// losses存储前average_loss个loss,smoothed_loss为最后要输出的均值


int averageloss = this->param.average_loss();


vector losses;


Dtype smoothedloss = 0;


// 迭代


while (iter < stopiter)


{


// 清空上一次所有参数的梯度


net->ClearParamDiffs();


// 判断是否需要测试


if (param_.testinterval() && iter % param_.testinterval() == 0


&& (iter > 0 || param_.test_initialization()) && Caffe::root_solver())


{


TestAll();


// 判断是否需要提前结束迭代


if (requested_earlyexit)


{ break; }


}


for (int i = 0; i < callbacks.size(); ++i)


{


callbacks【i】->onstart();


}


// 判断当前迭代次数是否需要显示loss等信息


const bool display = param.display() && iter % param.display() == 0;


net_->set_debuginfo(display && param.debug_info());


Dtype loss = 0;


// iter_size在solver.prototxt中设置,实际上的batch_size=iter_size batch_size(网络中定义的),


// 因此每一次迭代的loss是iter_size次迭代的和,再除以iter_size,这个loss是通过调用Net::ForwardBackward函数得到的


// 在GPU的显存不够的时候设置,比如把batch_size设置为128,但是会out_of_memory


// 借助这个方法,可以设置batch_size=32,itersize=4,那实际上每次迭代还是处理了128个数据


for (int i = 0; i < param.itersize(); ++i)


{


loss += net->ForwardBackward(bottomvec);


}


loss /= param.iter_size();


// 计算要输出的smoothed_loss,如果losses里还没有存够average_loss个loss则将当前的loss插入,如果已经存够了,则将之前的替换掉


if (losses.size() < average_loss)


{


losses.push_back(loss);


int size = losses.size();


smoothed_loss = (smoothed_loss (size - 1) + loss) / size;


}


else


{


int idx = (iter_ - start_iter) % average_loss;


smoothed_loss += (loss - losses【idx】) / average_loss;


losses【idx】 = loss;


}


// 输出当前迭代的信息


if (display)


{


LOG_IF(INFO, Caffe::rootsolver()) [ "Iteration " [ iter


[ ", loss = " [ smoothedloss;


const vector& result = net->output_blobs();


int score_index = 0;


for (int j = 0; j < result.size(); ++j)


{


const Dtype result_vec = result【j】->cpu_data();


const string& outputname =


net->blobnames()【net->output_blob_indices()【j】】;


const Dtype lossweight =


net->blob_lossweights()【net->output_blob_indices()【j】】;


for (int k = 0; k count(); ++k)


{


ostringstream loss_msg_stream;


if (loss_weight)


{


loss_msg_stream [ " ( " [ loss_weight


[ " = " [ loss_weight result_vec【k】 [ " loss)";


}


LOG_IF(INFO, Caffe::root_solver()) [ " Train net output #"


[ score_index++ [ ": " [ output_name [ " = "


[ result_vec【k】 [ loss_msgstream.str();


}


}


}


for (int i = 0; i < callbacks.size(); ++i)


{


callbacks_【i】->on_gradientsready();


}


// 执行梯度的更新,其在基类Solver中没有实现,但会调用每个子类的实现


ApplyUpdate();


// 迭代次数加1


++iter;


// 调用GetRequestedAction,实际是通过action_requestfunction函数指针调用已设置好(通过SetRequestedAction)的signalhandler的CheckForSignals函数,这个函数的作用是会根据是否遇到系统信号以及信号的类型和用户设置(或者默认)的方式返回处理的方式


SolverAction::Enum request = GetRequestedAction();


// 判断当前迭代是否需要snapshot,如果request==SNAPSHOT则也需要


if ((param.snapshot() && iter % param.snapshot() == 0


&& Caffe::root_solver()) || (request == SolverAction::SNAPSHOT))


{ Snapshot(); }


// 如果request为STOP则修改requested_earlyexit为true,会提前结束迭代


if (SolverAction::STOP == request)


{


requested_earlyexit = true;


break;


}


}


}


每一组网络中的参数的更新都是在不同类型的Solver实现各自的ApplyUpdate函数中完成的,以最常用的SGD为例子来分析这个函数具体的功能:


SGDSolver::ApplyUpdate源码分析:


template


void SGDSolver::ApplyUpdate()


{


CHECK(Caffe::root_solver());


// GetLearningRate根据设置的lrpolicy来计算当前迭代的learning rate的值


Dtype rate = GetLearningRate();


// 判断是否需要输出当前的learning rate


if (this->param.display() && this->iter % this->param.display() == 0)


{


LOG(INFO) [ "Iteration " [ this->iter_ [ ", lr = " [ rate;


}


// 避免梯度爆炸,如果梯度的二范数超过了某个数值则进行scale操作,将梯度减小


ClipGradients();


// 对所有可更新的网络参数进行操作


for (int param_id = 0; paramid < this->net->learnable_params().size();++param_id)


{


// 将第param_id个参数的梯度除以iter_size,其作用是保证实际的batch_size=iter_size batch_size


Normalize(param_id);


<span class="hljs

相关文章
|
2天前
|
Java 程序员
程序技术好文:解析器组合子
程序技术好文:解析器组合子
|
1天前
|
网络协议 安全 分布式数据库
技术分享:分布式数据库DNS服务器的架构思路
技术分享:分布式数据库DNS服务器的架构思路
7 0
|
2天前
|
Java UED 开发者
JVM逃逸分析原理解析:优化Java程序性能和内存利用效率
JVM逃逸分析原理解析:优化Java程序性能和内存利用效率
|
2天前
|
自然语言处理 C语言 C++
程序与技术分享:C++写一个简单的解析器(分析C语言)
程序与技术分享:C++写一个简单的解析器(分析C语言)
|
2天前
|
Rust 安全 开发者
Rust语言的Hello, World! 程序解析
Rust语言的Hello, World! 程序解析
7 0
|
8天前
|
机器学习/深度学习 缓存 算法
netty源码解解析(4.0)-25 ByteBuf内存池:PoolArena-PoolChunk
netty源码解解析(4.0)-25 ByteBuf内存池:PoolArena-PoolChunk
|
10天前
|
XML Java 数据格式
深度解析 Spring 源码:从 BeanDefinition 源码探索 Bean 的本质
深度解析 Spring 源码:从 BeanDefinition 源码探索 Bean 的本质
23 3
|
3天前
|
Java 数据库连接 Spring
Spring 整合 MyBatis 底层源码解析
Spring 整合 MyBatis 底层源码解析
|
2天前
|
NoSQL Java Redis
【源码解析】自动配置的这些细节都不知道,别说你会 springboot
【源码解析】自动配置的这些细节都不知道,别说你会 springboot
|
9天前
|
存储 NoSQL 算法
Redis(四):del/unlink 命令源码解析
Redis(四):del/unlink 命令源码解析

推荐镜像

更多