程序与技术分享:Caffe中Solver解析

本文涉及的产品
全局流量管理 GTM,标准版 1个月
公共DNS(含HTTPDNS解析),每月1000万次HTTP解析
云解析 DNS,旗舰版 1个月
简介: 程序与技术分享:Caffe中Solver解析

Caffe中Solver解析


1.Solver的初始化


shared_ptr

solver(caffe::SolverRegistry::CreateSolver(solver_param));


caffe.cpp中的train函数中通过上述的代码定义了一个指向Solver的shared_ptr。其中主要是通过调用SolverRegistry这个类的静态成员函数CreateSolver得到一个指向Solver的指针来构造shared_ptr类型的solver。而且由于C++多态的特性,solver是一个指向基类Solver类型的指针,通过solver这个智能指针来调用各个成员函数会调用到各个子类(SGDSolver等)的函数。


具体步骤:


(1)SolverRegistry::CreateSolver(solver_param)。


(2)通过static的gregistry【type】获取type对应的Solver的Creator函数指针。


(3)调用Creator函数。


(4)new SGDSolver(solver_param)创建solver。


SolverRegistry类源码:


class SolverRegistry


{


public:


typedef Solver (Creator)(const SolverParameter&);


typedef std::map CreatorRegistry;


static CreatorRegistry& Registry()


{


static CreatorRegistry gregistry = new CreatorRegistry();


return gregistry;


}


static void AddCreator(const string& type, Creator creator)


{


CreatorRegistry& registry = Registry();


CHECK_EQ(registry.count(type), 0)


[ "Solver type " [ type [ " already registered.";


registry【type】 = creator;


}


static Solver CreateSolver(const SolverParameter& param)


{


const string& type = param.type();


CreatorRegistry& registry = Registry();


CHECK_EQ(registry.count(type), 1) [ "Unknown solver type: " [ type


[ " (known types: " [ SolverTypeListString() [ ")";


return registry【type】(param);


}


static vector SolverTypeList()


{


CreatorRegistry& registry = Registry();


vector solver_types;


for (typename CreatorRegistry::iterator iter = registry.begin();iter != registry.end(); ++iter)


{


solver_types.push_back(iter->first);


}


return solver_types;


}


private:


SolverRegistry() {}


static string SolverTypeListString()


{


vector solver_types = SolverTypeList();


string solver_types_str;


for (vector::iterator iter = solver_types.begin();iter != solver_types.end(); ++iter)


{


if (iter != solver_types.begin())


{


solver_types_str += ", ";


}


solver_types_str += iter;


}


return solver_types_str;


}


};


SolverRegistry类的构造函数是private的,也就是用我们没有办法去构造一个这个类的变量,这个类也没有数据成员,所有的成员函数也都是static的,可以直接调用。


CreateSolver函数先定义了string类型的变量type,表示Solver的类型,然后定义了一个key类型为string,value类型为Creator的map,变量名为registry,其中Creator是一个函数指针类型,指向的函数的参数为SolverParameter类型,返回类型为Solver。如果是一个已经register过的Solver类型,那么registry.count(type)应该为1,然后通过registry这个map返回了我们需要类型的Solver的creator,并调用这个creator函数,将creator返回的Solver返回。


Registry函数中定义了一个static的变量g_registry,这个变量是一个指向CreatorRegistry这个map类型的指针,然后直接返回,因为这个变量是static的,所以即使多次调用这个函数,也只会定义一个g_registry,而且在其他地方修改这个map里的内容,。事实上各个Solver的register的过程正是向g_registry指向的那个map里添加以Solver的type为key,对应的Creator函数指针为value的内容。


Register的具体步骤:


//代码效果参考:http://hnjlyzjd.com/xl/wz_24215.html

(1)Registry_Solver_Class(SGD)。

(2)定义Creator函数,Registry_Solver_Creator。


(3)定义SolverRegistry类型的static变量,定义SolverRegistry类型的static变量。


(4)SolverRegistry::AddCreator将定义的Creator函数指针添加到static的变量gregistry(map)中。


SolverRegisterer源码:


template


class SolverRegisterer {


public:


SolverRegisterer(const string& type,


Solver (creator)(const SolverParameter&));


};


#define REGISTER_SOLVER_CREATOR(type, creator) \


static SolverRegisterer[span class="hljs-title">floatfloat

static SolverRegisterer[span class="hljs-title">doubledouble

#define REGISTER_SOLVERCLASS(type) \


template \


Solver[span class="hljs-title">Dtype</span]* Creator##type##Solver( \


const SolverParameter& param) \


{ \


return new type##Solver(param); \


} \


REGISTER_SOLVERCREATOR(type, Creator##type##Solver)


}


#endif


在sgd_solver.cpp文件末尾有REGISTER_SOLVER_CLASS(SGD),使用REGISTER_SOLVER_CLASS宏定义一个名为Creator_SGDSolver的函数,即为Creator类型的指针函数,在Creator_SGDSolver函数中调用了SGDSolver的构造函数,并返回所构造的指针变量。Creator类型的指针函数的作用:构造一个对应类型的Solver对象,将其指针返回,然后在REGISTER_SOLVER_CLASS宏里又调用了REGISTER_SOLVER_CREATOR宏,该宏调用相对应(分别定义了SolverRegisterer类模板的float和double类型的static变量)的构造函数。在SolverRegisterer的构造函数中调用了SolverRegistry类的AddCreator函数,其功能将刚才定义的Creator_SGDSolver函数的指针存到g_registry所指向的map中。类似地,所有的Solver对应的cpp文件的末尾都调用了REGISTER_SOLVER_CLASS宏来完成注册,在所有的Solver都注册之后就可以通过g_registry得到对应的Creator函数的指针,并通过调用这个Creator函数来构造对应的Solver。


2.SIGINT和SIGHUP信号的处理


Caffe在train或者test的过程中都有可能会遇到系统信号(用户按下ctrl+c或者关掉了控制的terminal),可以通过对sigint_effect和sighup_effect来设置遇到系统信号的时候希望进行的处理方式:


caffe train –solver=/path/to/solver.prototxt –sigint_effect=EFFECT –sighup_effect=EFFECT


在caffe.cpp中定义了一个GetRequesedAction函数来将设置的string类型的标志转变为枚举类型的变量:


caffe::SolverAction::Enum GetRequestedAction(const std::string& flag_value)


{


if (flag_value == "stop")


{


return caffe::SolverAction::STOP;


}


if (flag_value == "snapshot")


{


return caffe::SolverAction::SNAPSHOT;


}


if (flag_value == "none")


{


return caffe::SolverAction::NONE;


}


LOG(FATAL) [ "Invalid signal effect \""[ flag_value [ "\" was specified";


}


// SolverAction::Enum的定义


namespace SolverAction


{


enum Enum


{


NONE = 0, // Take no special action.


STOP = 1, // Stop training. snapshot_after_train controls whether a


// snapshot is created.


SNAPSHOT = 2 // Take a snapshot, and keep training.


};


}


其中SolverAction::Enum的定义在solver.hpp中,这是一个定义为枚举类型的数据类型,只有三个可能的值,分别对应了三种处理系统信号的方式:NONE(忽略信号什么都不做)/STOP(停止训练)/SNAPSHOT(保存当前的训练状态,继续训练)。在caffe.cpp中的train函数里Solver设置如何处理系统信号的代码为:


caffe::SignalHandler signal_handler(


GetRequestedAction(FLAGS_sigint_effect),


GetRequestedAction(FLAGS_sighup_effect)


);


solver->SetActionFunction(signal_handler.GetActionFunction());


通过gflags定义和解析的两个Command Line Interface的输入参数,FLAGS_sigint_effect和FLAGS_sighup_effect分别对应遇到sigint和sighup信号的处理方式,如果用户不设定,sigint的默认值为stop,sighup的默认值为snapshot。GetRequestedAction函数会将string类型的FLAGS_xx转为SolverAction::Enum类型,并用来定义一个SignalHandler类型的对象signal_handler。这部分代码都依赖于SignalHandler这个类的接口:


// header file


class SignalHandler


{


public:


// Contructor. Specify what action to take when a signal is received.


SignalHandler(SolverAction::Enum SIGINT_action,


SolverAction::Enum SIGHUP_action);


~SignalHandler();


ActionCallback GetActionFunction();


private:


SolverAction::Enum CheckForSignals() const;


SolverAction::Enum SIGINTaction;


SolverAction::Enum SIGHUPaction;


};


// source file


SignalHandler::SignalHandler(SolverAction::Enum SIGINT_action,


SolverAction::Enum SIGHUP_action):


SIGINTaction(SIGINT_action),SIGHUPaction(SIGHUP_action)


{


HookupHandler();


}


void HookupHandler()


{


if (already_hooked_up)


{


LOG(FATAL) [ "Tried to hookup signal handlers more than once.";


}


already_hooked_up = true;


struct sigaction sa;


sa.sa_handler = &handle_signal;


// ...


}


static volatile sig_atomic_t got_sigint = false;


static volatile sig_atomic_t got_sighup = false;


void handle_signal(int signal)


{


switch (signal)


{


case SIGHUP:


got_sighup = true;


break;


case SIGINT:


got_sigint = true;


break;


}


}


ActionCallback SignalHandler::GetActionFunction()


{


return boost::bind(&SignalHandler::CheckForSignals, this);


}


SolverAction::Enum SignalHandler::CheckForSignals() const


{


if (GotSIGHUP()) { return SIGHUPaction;}


if (GotSIGINT()) { return SIGINTaction;}


return SolverAction::NONE;


}


bool GotSIGINT()


{


bool result = got_sigint;


got_sigint = false;


return result;


}


bool GotSIGHUP()


{


bool result = got_sighup;


got_sighup = false;


return result;


}


// ActionCallback的含义


typedef boost::function ActionCallback;


SignalHandler类有两个数据成员,都是SolverAction::Enum类型的,分别对应sigint和sighup信号,在构造函数中,用解析FLAGS_xx得到的结果分别给两个成员赋值,然后调用了HookupHandler函数,这个函数的主要作用是定义了一个sigaction类型(应该是系统级别的代码)的对象sa,然后通过sa.sa_handler= &handle_signal来设置,当有遇到系统信号时,调用handle_signal函数来处理,即判断一下当前的信号是什么类型,如果是sigint就将全局的static变量got_sigint变为true,sighup的处理类似。


在根据用户设置(或者默认值)的参数定义了signal_handler之后,solver通过SetActionFunction来设置了如何处理系统信号。这个函数的输入为signal_handler的GetActionFunction的返回值,根据上述的代码可以看到,GetActionFunction会返回signal_handler对象的CheckForSignals函数的地址(boost::bind的具体使用请参考boost官方文档)。而在Solver的SetActionFunction函数中只是简单的把Solver的一个成员action_requestfunction赋值为输入参数的值,以当前的例子来说就是,solver对象的action_requestfunction指向了signal_handler对象的CheckForSignals函数的地址。其中的ActionCallback是一个函数指针类型,指向了参数为空,返回值为SolverAction::Enum类型的函数(boost::function具体用法参考官方文档)。


总之,通过定义一个SignalHandler类型的对象,告知系统在遇到系统信号的时候回调handle_signal函数来改变全局变量got_sigint和got_sighup的值,然后通过Solver的接口设置了其遇到系统函数将调用signal_handler的Check函数,实际上就是去判断当前是否遇到了系统信号,如果遇到某个类型的信号,就返回设置的处理方式(SolverAction::Enum类型)。


3.Solver::Solve()具体实现


Solver::Solve源码分析:


void Solver::Solve(const char resume_file)


{


// 检查当前是否是root_solver(多GPU模式下,只有root_solver才运行这一部分的代码)


CHECK(Caffe::rootsolver());


// 输出learning policy(更新学习率的策略)


LOG(INFO) [ "Solving " [ net->name();


LOG(INFO) [ "Learning Rate Policy: " [ param_.lr_policy();


// requested_earlyexit初始值为false,此时不要求在优化结束前退出


requested_earlyexit = false;


// 判断指针resume_file是否NULL,如果不是则从resume_file存储的路径里读取之前训练的状态


if (resume_file)


{


LOG(INFO) [ "Restoring previous solver status from " [ resume_file;


Restore(resumefile);


}


// 调用了Step函数,其执行了实际的逐步的迭代过程


Step(param.maxiter() - iter);


// 迭代结束或者遇到系统信号提前结束后,判断是否需要在训练结束之后snapshot,可在solver.prototxt里设置


if (param_.snapshot_aftertrain() && (!param.snapshot() || iter % param.snapshot() != 0))


{ Snapshot(); }


// 如果在Step函数的迭代过程中遇到了系统信号,且处理方式设置为STOP,requested_earlyexit会被修改为true,迭代提前结束,并输出相关信息


if (requested_earlyexit)


{


LOG(INFO) [ "Optimization stopped early.";


return;


}


// 判断是否需要输出最后的loss


if (param.display() && iter % param.display() == 0)


{


Dtype loss;


net->ForwardPrefilled(&loss);


LOG(INFO) [ "Iteration " [ iter [ ", loss = " [ loss;


}


// 判断是否需要最后Test


if (param.testinterval() && iter % param_.test_interval() == 0) {


TestAll();


}


LOG(INFO) [ "Optimization Done.";


}


Solver::Step函数源码解析:


template [span class="hljs-keyword">typename Dtype>


void Solver::Step(int iters)


{


vector<Blob> bottomvec;


// 设置开始的迭代次数(如果是从之前的snapshot恢复的,那iter等于snapshot时的迭代次数)和结束的迭代次数


const int startiter = iter;


const int stopiter = iter + iters;


// 输出的loss为前average_loss次loss的平均值,在solver.prototxt里设置,默认为1,


// losses存储前average_loss个loss,smoothed_loss为最后要输出的均值


int averageloss = this->param.average_loss();


vector losses;


Dtype smoothedloss = 0;


// 迭代


while (iter < stopiter)


{


// 清空上一次所有参数的梯度


net->ClearParamDiffs();


// 判断是否需要测试


if (param_.testinterval() && iter % param_.testinterval() == 0


&& (iter > 0 || param_.test_initialization()) && Caffe::root_solver())


{


TestAll();


// 判断是否需要提前结束迭代


if (requested_earlyexit)


{ break; }


}


for (int i = 0; i < callbacks.size(); ++i)


{


callbacks【i】->onstart();


}


// 判断当前迭代次数是否需要显示loss等信息


const bool display = param.display() && iter % param.display() == 0;


net_->set_debuginfo(display && param.debug_info());


Dtype loss = 0;


// iter_size在solver.prototxt中设置,实际上的batch_size=iter_size batch_size(网络中定义的),


// 因此每一次迭代的loss是iter_size次迭代的和,再除以iter_size,这个loss是通过调用Net::ForwardBackward函数得到的


// 在GPU的显存不够的时候设置,比如把batch_size设置为128,但是会out_of_memory


// 借助这个方法,可以设置batch_size=32,itersize=4,那实际上每次迭代还是处理了128个数据


for (int i = 0; i < param.itersize(); ++i)


{


loss += net->ForwardBackward(bottomvec);


}


loss /= param.iter_size();


// 计算要输出的smoothed_loss,如果losses里还没有存够average_loss个loss则将当前的loss插入,如果已经存够了,则将之前的替换掉


if (losses.size() < average_loss)


{


losses.push_back(loss);


int size = losses.size();


smoothed_loss = (smoothed_loss (size - 1) + loss) / size;


}


else


{


int idx = (iter_ - start_iter) % average_loss;


smoothed_loss += (loss - losses【idx】) / average_loss;


losses【idx】 = loss;


}


// 输出当前迭代的信息


if (display)


{


LOG_IF(INFO, Caffe::rootsolver()) [ "Iteration " [ iter


[ ", loss = " [ smoothedloss;


const vector& result = net->output_blobs();


int score_index = 0;


for (int j = 0; j < result.size(); ++j)


{


const Dtype result_vec = result【j】->cpu_data();


const string& outputname =


net->blobnames()【net->output_blob_indices()【j】】;


const Dtype lossweight =


net->blob_lossweights()【net->output_blob_indices()【j】】;


for (int k = 0; k count(); ++k)


{


ostringstream loss_msg_stream;


if (loss_weight)


{


loss_msg_stream [ " ( " [ loss_weight


[ " = " [ loss_weight result_vec【k】 [ " loss)";


}


LOG_IF(INFO, Caffe::root_solver()) [ " Train net output #"


[ score_index++ [ ": " [ output_name [ " = "


[ result_vec【k】 [ loss_msgstream.str();


}


}


}


for (int i = 0; i < callbacks.size(); ++i)


{


callbacks_【i】->on_gradientsready();


}


// 执行梯度的更新,其在基类Solver中没有实现,但会调用每个子类的实现


ApplyUpdate();


// 迭代次数加1


++iter;


// 调用GetRequestedAction,实际是通过action_requestfunction函数指针调用已设置好(通过SetRequestedAction)的signalhandler的CheckForSignals函数,这个函数的作用是会根据是否遇到系统信号以及信号的类型和用户设置(或者默认)的方式返回处理的方式


SolverAction::Enum request = GetRequestedAction();


// 判断当前迭代是否需要snapshot,如果request==SNAPSHOT则也需要


if ((param.snapshot() && iter % param.snapshot() == 0


&& Caffe::root_solver()) || (request == SolverAction::SNAPSHOT))


{ Snapshot(); }


// 如果request为STOP则修改requested_earlyexit为true,会提前结束迭代


if (SolverAction::STOP == request)


{


requested_earlyexit = true;


break;


}


}


}


每一组网络中的参数的更新都是在不同类型的Solver实现各自的ApplyUpdate函数中完成的,以最常用的SGD为例子来分析这个函数具体的功能:


SGDSolver::ApplyUpdate源码分析:


template


void SGDSolver::ApplyUpdate()


{


CHECK(Caffe::root_solver());


// GetLearningRate根据设置的lrpolicy来计算当前迭代的learning rate的值


Dtype rate = GetLearningRate();


// 判断是否需要输出当前的learning rate


if (this->param.display() && this->iter % this->param.display() == 0)


{


LOG(INFO) [ "Iteration " [ this->iter_ [ ", lr = " [ rate;


}


// 避免梯度爆炸,如果梯度的二范数超过了某个数值则进行scale操作,将梯度减小


ClipGradients();


// 对所有可更新的网络参数进行操作


for (int param_id = 0; paramid < this->net->learnable_params().size();++param_id)


{


// 将第param_id个参数的梯度除以iter_size,其作用是保证实际的batch_size=iter_size batch_size


Normalize(param_id);


<span class="hljs

相关文章
|
3天前
|
物联网 调度 vr&ar
鸿蒙HarmonyOS应用开发 |鸿蒙技术分享HarmonyOS Next 深度解析:分布式能力与跨设备协作实战
鸿蒙技术分享:HarmonyOS Next 深度解析 随着万物互联时代的到来,华为发布的 HarmonyOS Next 在技术架构和生态体验上实现了重大升级。本文从技术架构、生态优势和开发实践三方面深入探讨其特点,并通过跨设备笔记应用实战案例,展示其强大的分布式能力和多设备协作功能。核心亮点包括新一代微内核架构、统一开发语言 ArkTS 和多模态交互支持。开发者可借助 DevEco Studio 4.0 快速上手,体验高效、灵活的开发过程。 239个字符
144 13
鸿蒙HarmonyOS应用开发 |鸿蒙技术分享HarmonyOS Next 深度解析:分布式能力与跨设备协作实战
|
2月前
|
SQL 安全 Windows
SQL安装程序规则错误解析与解决方案
在安装SQL Server时,用户可能会遇到安装程序规则错误的问题,这些错误通常与系统配置、权限设置、依赖项缺失或版本不兼容等因素有关
|
2月前
|
XML Java 数据格式
手动开发-简单的Spring基于注解配置的程序--源码解析
手动开发-简单的Spring基于注解配置的程序--源码解析
53 0
|
2月前
|
XML Java 数据格式
手动开发-简单的Spring基于XML配置的程序--源码解析
手动开发-简单的Spring基于XML配置的程序--源码解析
87 0
|
3月前
|
设计模式 存储 算法
PHP中的设计模式:策略模式的深入解析与应用在软件开发的浩瀚海洋中,PHP以其独特的魅力和强大的功能吸引了无数开发者。作为一门历史悠久且广泛应用的编程语言,PHP不仅拥有丰富的内置函数和扩展库,还支持面向对象编程(OOP),为开发者提供了灵活而强大的工具集。在PHP的众多特性中,设计模式的应用尤为引人注目,它们如同精雕细琢的宝石,镶嵌在代码的肌理之中,让程序更加优雅、高效且易于维护。今天,我们就来深入探讨PHP中使用频率颇高的一种设计模式——策略模式。
本文旨在深入探讨PHP中的策略模式,从定义到实现,再到应用场景,全面剖析其在PHP编程中的应用价值。策略模式作为一种行为型设计模式,允许在运行时根据不同情况选择不同的算法或行为,极大地提高了代码的灵活性和可维护性。通过实例分析,本文将展示如何在PHP项目中有效利用策略模式来解决实际问题,并提升代码质量。
|
4月前
|
存储 并行计算 API
ViperGPT解析:结合视觉输入与文本查询生成和执行程序
ViperGPT是一个创新的混合视觉和语言处理模型,通过生成和执行代码来解决视觉查询问题,具有高度模块化、灵活性和优秀的外部知识查询能力。
90 0
|
5月前
|
数据可视化 持续交付 开发工具
RAD技术解析:快速开发应用程序的秘诀
**快速应用开发(RAD)**是一种始于90年代的敏捷方法,旨在通过迭代原型和反馈加速高质量软件交付。由James Martin提出,它包括需求规划、界面设计、快速构建和持续优化四阶段,以提高质量、降低风险、增强灵活性、降低成本和提升客户满意度。工具如ZohoCreator支持RAD,通过可视化工具和低代码平台促进高效开发,实现快速迭代和市场适应,降低项目失败风险,提高用户满意度。
106 9
|
1月前
|
监控 Java 应用服务中间件
高级java面试---spring.factories文件的解析源码API机制
【11月更文挑战第20天】Spring Boot是一个用于快速构建基于Spring框架的应用程序的开源框架。它通过自动配置、起步依赖和内嵌服务器等特性,极大地简化了Spring应用的开发和部署过程。本文将深入探讨Spring Boot的背景历史、业务场景、功能点以及底层原理,并通过Java代码手写模拟Spring Boot的启动过程,特别是spring.factories文件的解析源码API机制。
76 2
|
2月前
|
缓存 Java 程序员
Map - LinkedHashSet&Map源码解析
Map - LinkedHashSet&Map源码解析
78 0
|
2月前
|
算法 Java 容器
Map - HashSet & HashMap 源码解析
Map - HashSet & HashMap 源码解析
63 0

推荐镜像

更多