Caffe:深入分析(怎么训练)

简介: main()   首先入口函数caffe.cpp 1 int main(int argc, char** argv) { 2 ...... 3 if (argc == 2) { 4 #ifdef WITH_PYTHON_LAYER 5 try { 6 #endif ...

main() 

  首先入口函数caffe.cpp

 1 int main(int argc, char** argv) {
 2   ......
 3   if (argc == 2) {
 4 #ifdef WITH_PYTHON_LAYER
 5     try {
 6 #endif
 7       return GetBrewFunction(caffe::string(argv[1]))(); //根据输入参数确定是train还是test,采用string到函数指针的映射实现,非常巧妙
 8 #ifdef WITH_PYTHON_LAYER
 9     } catch (bp::error_already_set) {
10       PyErr_Print();
11       return 1;
12     }
13 #endif
14   } else {
15     gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/caffe");
16   }
17 }

  在main函数中GetBrewFunction函数调用了通过工厂模式生成的由string到函数指针的map

1 typedef int (*BrewFunction)();
2 typedef std::map<caffe::string, BrewFunction> BrewMap;
3 BrewMap g_brew_map;

  在train、test、device_query、time函数后面都可以看到对这些函数的register,相当于这些函数指针已经在map中存在了

1 RegisterBrewFunction(train);
2 RegisterBrewFunction(test);
3 RegisterBrewFunction(device_query);
4 RegisterBrewFunction(time);

train()

  接着是train过程

 1 // Train / Finetune a model.
 2 int train() {
 3   ......
 4   caffe::SolverParameter solver_param;
 5   caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param);//从-solver参数读取solver_param
 6   ......
 7   shared_ptr<caffe::Solver<float> >
 8       solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));//从参数创建solver,同样采用string到函数指针的映射实现,用到了工厂模式
 9 
10   if (FLAGS_snapshot.size()) {//迭代snapshot次后保存模型一次
11     LOG(INFO) << "Resuming from " << FLAGS_snapshot;
12     solver->Restore(FLAGS_snapshot.c_str());
13   } else if (FLAGS_weights.size()) {//若采用finetuning,则拷贝weight到指定模型
14     CopyLayers(solver.get(), FLAGS_weights);
15   }
16 
17   if (gpus.size() > 1) {
18     caffe::P2PSync<float> sync(solver, NULL, solver->param());
19     sync.Run(gpus);
20   } else {
21     LOG(INFO) << "Starting Optimization";
22     solver->Solve();//开始训练网络
23   }
24   LOG(INFO) << "Optimization Done.";
25   return 0;
26 }

Solver()

  看CreateSolver函数是如何构建solver和net的,CreateSolver定义在solver_factory.hpp中,首先需要知道的是solver是一个基类,继承自它的类有SGD等,下面的实现就可以根据param的type构造一个指向特定solver的指针,比如SGD。

1 static Solver<Dtype>* CreateSolver(const SolverParameter& param) {
2     const string& type = param.type();
3     CreatorRegistry& registry = Registry();
4     CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type
5         << " (known types: " << SolverTypeListString() << ")";
6     return registry[type](param);
7   }

  关键之处在于上面代码最后一行语句,它的作用是根据配置文件创建对应的Solver对象(默认为SGDSolver子类对象)。此处工厂模式和一个关键的宏REGISTER_SOLVER_CLASS(SGD)发挥了重要作用。

1 #define REGISTER_SOLVER_CLASS(type)                                              
2   template <typename Dtype>                                                      
3   Solver<Dtype>* Creator_##type##Solver(                                         
4       const SolverParameter& param)                                              
5   {                                                                              
6     return new type##Solver<Dtype>(param);                                       
7   }                                                                              
8   REGISTER_SOLVER_CREATOR(type, Creator_##type##Solver)    
9 }   

  这样一个SGDSolver对象就调用其构造函数被构造出来了。

1 explicit SGDSolver(const SolverParameter& param)
2       : Solver<Dtype>(param) { PreSolve(); }

  同时,Solver这个基类也被构造出来了,在solver.hpp里

1 explicit Solver(const SolverParameter& param,
2       const Solver* root_solver = NULL);

  Solver构造函数又会调用Init进行训练网络和测试网络的初始化,Init函数没有被声明为虚函数,不能被覆写,也就是说所有的solver都调用这个函数进行初始化。

 1 template <typename Dtype>
 2 void Solver<Dtype>::Init(const SolverParameter& param) {
 3   ......
 4   // Scaffolding code
 5   InitTrainNet();//初始化训练网络
 6   if (Caffe::root_solver()) {
 7     InitTestNets();//初始化测试网络
 8     LOG(INFO) << "Solver scaffolding done.";
 9   }
10   iter_ = 0;//迭代次数设为0
11   current_step_ = 0;
12 }

InitTrainNet()

  接下来看训练网络初始化函数InitTrainNet,具体的内容见Net的网络层的构建(源码分析)

  caffe是如何来solve的:在成员函数Solve()内部,

 1 template <typename Dtype>
 2 void Solver<Dtype>::Solve(const char* resume_file) {
 3   ......
 4   // For a network that is trained by the solver, no bottom or top vecs
 5   // should be given, and we will just provide dummy vecs.
 6   int start_iter = iter_;
 7   //开始迭代
 8   Step(param_.max_iter() - iter_);
 9   ......
10 }

Step()

  下面我们看一下Solver::Step()函数内部实现情况,具体的一次迭代过程。见Caffe参数交换源码分析

  这就是整个网络的训练过程。 

 

当神已无能为力,那便是魔渡众生
目录
相关文章
|
6月前
|
机器学习/深度学习 算法 PyTorch
挑战Transformer的新架构Mamba解析以及Pytorch复现
今天我们来详细研究这篇论文“Mamba:具有选择性状态空间的线性时间序列建模”
1371 1
|
5天前
|
机器学习/深度学习 自然语言处理 并行计算
DeepSpeed分布式训练框架深度学习指南
【11月更文挑战第6天】随着深度学习模型规模的日益增大,训练这些模型所需的计算资源和时间成本也随之增加。传统的单机训练方式已难以应对大规模模型的训练需求。
28 3
|
9天前
|
机器学习/深度学习 并行计算 Java
谈谈分布式训练框架DeepSpeed与Megatron
【11月更文挑战第3天】随着深度学习技术的不断发展,大规模模型的训练需求日益增长。为了应对这种需求,分布式训练框架应运而生,其中DeepSpeed和Megatron是两个备受瞩目的框架。本文将深入探讨这两个框架的背景、业务场景、优缺点、主要功能及底层实现逻辑,并提供一个基于Java语言的简单demo例子,帮助读者更好地理解这些技术。
28 2
|
2月前
|
机器学习/深度学习 人工智能 监控
一文读懂deepSpeed:深度学习训练的并行化
DeepSpeed 是由微软开发的开源深度学习优化库,旨在提高大规模模型训练的效率和可扩展性。通过创新的并行化策略、内存优化技术(如 ZeRO)及混合精度训练,DeepSpeed 显著提升了训练速度并降低了资源需求。它支持多种并行方法,包括数据并行、模型并行和流水线并行,同时与 PyTorch 等主流框架无缝集成,提供了易用的 API 和丰富的文档支持。DeepSpeed 不仅大幅减少了内存占用,还通过自动混合精度训练提高了计算效率,降低了能耗。其开源特性促进了 AI 行业的整体进步,使得更多研究者和开发者能够利用先进优化技术,推动了 AI 在各个领域的广泛应用。
|
3月前
|
机器学习/深度学习 人工智能 关系型数据库
【机器学习】Qwen2大模型原理、训练及推理部署实战
【机器学习】Qwen2大模型原理、训练及推理部署实战
593 0
【机器学习】Qwen2大模型原理、训练及推理部署实战
|
3月前
|
机器学习/深度学习 数据采集 物联网
【机器学习】Google开源大模型Gemma2:原理、微调训练及推理部署实战
【机器学习】Google开源大模型Gemma2:原理、微调训练及推理部署实战
117 0
|
4月前
|
机器学习/深度学习 并行计算 TensorFlow
使用Python实现深度学习模型:分布式训练与模型并行化
【7月更文挑战第9天】 使用Python实现深度学习模型:分布式训练与模型并行化
70 1
|
5月前
|
机器学习/深度学习 人工智能 自然语言处理
【机器学习】Transformer模型大小与性能探究
【机器学习】Transformer模型大小与性能探究
362 5
|
6月前
|
机器学习/深度学习 Python
超参数优化:提升机器学习模型性能
【5月更文挑战第31天】超参数优化对提升机器学习模型性能至关重要。网格搜索和随机搜索是常见方法,Python示例展示了如何使用GridSearchCV进行网格搜索。其他高级技术包括基于梯度的优化和贝叶斯优化。优化时注意选择合适评估指标、划分训练验证集,并进行迭代调整。自动化工具可简化这一过程。超参数优化是一个持续演进的领域,对于构建高性能模型具有关键作用。
93 0
|
6月前
|
机器学习/深度学习 PyTorch 测试技术
PyTorch实战:图像分类任务的实现与优化
【4月更文挑战第17天】本文介绍了使用PyTorch实现图像分类任务的步骤,包括数据集准备(如使用CIFAR-10数据集)、构建简单的CNN模型、训练与优化模型以及测试模型性能。在训练过程中,使用了交叉熵损失和SGD优化器。此外,文章还讨论了提升模型性能的策略,如调整模型结构、数据增强、正则化和利用预训练模型。通过本文,读者可掌握基础的PyTorch图像分类实践。