Caffe:深入分析(怎么训练)

简介: main()   首先入口函数caffe.cpp 1 int main(int argc, char** argv) { 2 ...... 3 if (argc == 2) { 4 #ifdef WITH_PYTHON_LAYER 5 try { 6 #endif ...

main() 

  首先入口函数caffe.cpp

 1 int main(int argc, char** argv) {
 2   ......
 3   if (argc == 2) {
 4 #ifdef WITH_PYTHON_LAYER
 5     try {
 6 #endif
 7       return GetBrewFunction(caffe::string(argv[1]))(); //根据输入参数确定是train还是test,采用string到函数指针的映射实现,非常巧妙
 8 #ifdef WITH_PYTHON_LAYER
 9     } catch (bp::error_already_set) {
10       PyErr_Print();
11       return 1;
12     }
13 #endif
14   } else {
15     gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/caffe");
16   }
17 }

  在main函数中GetBrewFunction函数调用了通过工厂模式生成的由string到函数指针的map

1 typedef int (*BrewFunction)();
2 typedef std::map<caffe::string, BrewFunction> BrewMap;
3 BrewMap g_brew_map;

  在train、test、device_query、time函数后面都可以看到对这些函数的register,相当于这些函数指针已经在map中存在了

1 RegisterBrewFunction(train);
2 RegisterBrewFunction(test);
3 RegisterBrewFunction(device_query);
4 RegisterBrewFunction(time);

train()

  接着是train过程

 1 // Train / Finetune a model.
 2 int train() {
 3   ......
 4   caffe::SolverParameter solver_param;
 5   caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param);//从-solver参数读取solver_param
 6   ......
 7   shared_ptr<caffe::Solver<float> >
 8       solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));//从参数创建solver,同样采用string到函数指针的映射实现,用到了工厂模式
 9 
10   if (FLAGS_snapshot.size()) {//迭代snapshot次后保存模型一次
11     LOG(INFO) << "Resuming from " << FLAGS_snapshot;
12     solver->Restore(FLAGS_snapshot.c_str());
13   } else if (FLAGS_weights.size()) {//若采用finetuning,则拷贝weight到指定模型
14     CopyLayers(solver.get(), FLAGS_weights);
15   }
16 
17   if (gpus.size() > 1) {
18     caffe::P2PSync<float> sync(solver, NULL, solver->param());
19     sync.Run(gpus);
20   } else {
21     LOG(INFO) << "Starting Optimization";
22     solver->Solve();//开始训练网络
23   }
24   LOG(INFO) << "Optimization Done.";
25   return 0;
26 }

Solver()

  看CreateSolver函数是如何构建solver和net的,CreateSolver定义在solver_factory.hpp中,首先需要知道的是solver是一个基类,继承自它的类有SGD等,下面的实现就可以根据param的type构造一个指向特定solver的指针,比如SGD。

1 static Solver<Dtype>* CreateSolver(const SolverParameter& param) {
2     const string& type = param.type();
3     CreatorRegistry& registry = Registry();
4     CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type
5         << " (known types: " << SolverTypeListString() << ")";
6     return registry[type](param);
7   }

  关键之处在于上面代码最后一行语句,它的作用是根据配置文件创建对应的Solver对象(默认为SGDSolver子类对象)。此处工厂模式和一个关键的宏REGISTER_SOLVER_CLASS(SGD)发挥了重要作用。

1 #define REGISTER_SOLVER_CLASS(type)                                              
2   template <typename Dtype>                                                      
3   Solver<Dtype>* Creator_##type##Solver(                                         
4       const SolverParameter& param)                                              
5   {                                                                              
6     return new type##Solver<Dtype>(param);                                       
7   }                                                                              
8   REGISTER_SOLVER_CREATOR(type, Creator_##type##Solver)    
9 }   

  这样一个SGDSolver对象就调用其构造函数被构造出来了。

1 explicit SGDSolver(const SolverParameter& param)
2       : Solver<Dtype>(param) { PreSolve(); }

  同时,Solver这个基类也被构造出来了,在solver.hpp里

1 explicit Solver(const SolverParameter& param,
2       const Solver* root_solver = NULL);

  Solver构造函数又会调用Init进行训练网络和测试网络的初始化,Init函数没有被声明为虚函数,不能被覆写,也就是说所有的solver都调用这个函数进行初始化。

 1 template <typename Dtype>
 2 void Solver<Dtype>::Init(const SolverParameter& param) {
 3   ......
 4   // Scaffolding code
 5   InitTrainNet();//初始化训练网络
 6   if (Caffe::root_solver()) {
 7     InitTestNets();//初始化测试网络
 8     LOG(INFO) << "Solver scaffolding done.";
 9   }
10   iter_ = 0;//迭代次数设为0
11   current_step_ = 0;
12 }

InitTrainNet()

  接下来看训练网络初始化函数InitTrainNet,具体的内容见Net的网络层的构建(源码分析)

  caffe是如何来solve的:在成员函数Solve()内部,

 1 template <typename Dtype>
 2 void Solver<Dtype>::Solve(const char* resume_file) {
 3   ......
 4   // For a network that is trained by the solver, no bottom or top vecs
 5   // should be given, and we will just provide dummy vecs.
 6   int start_iter = iter_;
 7   //开始迭代
 8   Step(param_.max_iter() - iter_);
 9   ......
10 }

Step()

  下面我们看一下Solver::Step()函数内部实现情况,具体的一次迭代过程。见Caffe参数交换源码分析

  这就是整个网络的训练过程。 

 

当神已无能为力,那便是魔渡众生
目录
相关文章
|
10月前
|
编解码
数字式电秒表、电子毫秒表,智能毫秒计,通用电脑式毫秒计,智能毫秒表
‌数字式电秒表的主要用途和作用包括在各种需要精确计时的场景中应用,如消防施工质量控制、技术检测、维护管理以及消防产品现场检查等‌。具体应用场景包括火灾自动报警系统的响应时间、水流指示器的延迟时间、电梯的迫降时间、灯具的应急工作时间等‌1。
|
Android开发 Swift iOS开发
iOS和安卓作为主流操作系统,开发者需了解两者差异以提高效率并确保优质用户体验。
【10月更文挑战第1天】随着移动互联网的发展,智能手机成为生活必需品,iOS和安卓作为主流操作系统,各有庞大的用户群。开发者需了解两者差异以提高效率并确保优质用户体验。iOS使用Swift或Objective-C开发,强调简洁直观的设计;安卓则采用Java或Kotlin,注重层次与动画。Swift和Kotlin均有现代编程特性。此外,iOS设备更易优化,而安卓需考虑更多兼容性问题。iOS应用仅能通过App Store发布,审核严格;安卓除Google Play外还可通过第三方市场发布,审核较宽松。开发者应根据需求选择合适平台,提供最佳应用体验。
357 3
|
11月前
|
编解码 人工智能 缓存
自学记录鸿蒙API 13:实现多目标识别Object Detection
多目标识别技术广泛应用于动物识别、智能相册分类和工业检测等领域。本文通过学习HarmonyOS的Object Detection API(API 13),详细介绍了如何实现一个多目标识别应用,涵盖从项目初始化、核心功能实现到用户界面设计的全过程。重点探讨了目标类别识别、边界框生成、高精度置信度等关键功能,并分享了性能优化与功能扩展的经验。最后,作者总结了学习心得,并展望了未来结合语音助手等创新应用的可能性。如果你对多目标识别感兴趣,不妨从基础功能开始,逐步实现自己的创意。
342 60
|
11月前
|
存储 JSON 缓存
【网络原理】——HTTP请求头中的属性
HTTP请求头,HOST、Content-Agent、Content-Type、User-Agent、Referer、Cookie。
|
人工智能 智能设计
阿里云logo设计入口(在线一键生成)
阿里云logo设计入口(在线一键生成)
9738 1
阿里云logo设计入口(在线一键生成)
|
SQL 存储 数据库
Hive简介、什么是Hive、为什么使用Hive、Hive的特点、Hive架构图、Hive基本组成、Hive与Hadoop的关系、Hive与传统数据库对比、Hive数据存储(来自学习资料)
1.1 Hive简介 1.1.1   什么是Hive Hive是基于Hadoop的一个数据仓库工具,可以将结构化的数据文件映射为一张数据库表,并提供类SQL查询功能。 1.1.2   为什么使用Hive Ø  直接使用hadoop所面临的问题 人员学习成本太高 项目周期要求太短 MapReduce实现复杂查询逻辑开发难度太大   Ø  为什么要使用Hive 操作接口采用类SQ
28099 0
|
机器学习/深度学习 Web App开发 算法
强化学习(Reinforcement Learning)
强化学习(Reinforcement Learning)是机器学习的一个分支,旨在让智能体(agent)通过与环境的交互学习如何做出决策以最大化累积奖励。在强化学习中,智能体通过试错的方式与环境进行交互,并根据环境的反馈(奖励或惩罚)调整自己的行为。
380 2
|
前端开发 JavaScript 测试技术
Ant Design 开源项目经验分享,你想知道的都在这儿了
如何成功的运作一个开源项目?来自Ant Design灵魂人物偏右的全干货分享。
Ant Design 开源项目经验分享,你想知道的都在这儿了
|
自然语言处理 开发者
天猫精灵技能测评实践
天猫精灵技能测评实践
11470 1
天猫精灵技能测评实践
|
存储 Android开发 iOS开发
三分钟了解Studio One6最新版二十项功能介绍及下载
Studio One是一款音乐编曲软件,是音乐工作者必不可少的创作工具,用于创建、录制、混合和掌握音乐和其他音频。无论你是第一次接触数字音乐工作站(DAW),还是第一次尝试制作属于自己的音乐,Studio One 6都能给你非凡的体验!Studio One 6新功能包括智能模板、乐谱支持歌词,全局视频轨,还有全新的声码器插件。万众期待的2022新版 Studio One 终于来了!在广受好评的5系列基础上,Studio One 6 又将给喜欢创作音乐的爱好者,带来哪些惊喜功能呢?请跟随 Studio One 中文来一探究竟!抢先体验20项全新功能吧!
2352 0