- 代码分析:FedML库中client的工作
client端的大概工作流程为:向server注册,向server请求数据集,加载数据,创建模型,创建ModelTrainer,创建FedAVGTrainer,创建ClientManager,开始训练并与server交互
(1)向server注册
register()::向server post自己的id,通知server该client已上线,并从server获取训练的参数,如训练所使用的数据集、模型、学习率等
(2)向server请求数据集
download_and_unzip():server有预处理好的小数据集,向server请求下载数据,下载完成后解压
(3)加载数据
load_data_by_device():根据自己的id,加载对应的数据集,如client 0加载MNIST_mobile/0,client 1加载MNIST_mobile/1
(4)创建模型
示例中使用了包含一个隐藏层(784, 10)的神经网络
优化器选择了SGD
(5)创建ModelTrainer
ModelTrainer的功能有:
get_model_params():获取模型参数
set_model_params():设置模型参数
train():训练
test():测试
(6)创建FedAVGTrainer
对ModelTrainer的封装,功能有:
update_model():调用ModelTrainer的set_model_params()设置模型参数
update_dataset():根据client id选择训练数据
train():使用ModelTrainer的train()进行训练,并获取模型参数,返回模型参数和训练样本数
test():使用ModelTrainer的test()进行测试
(7)创建ClientManager
ClientManager用于client和server之间的通信,主要功能有:
register_message_receive_handlers():注册处理不同类型信息的函数,如处理初始化信息使用handle_message_init(),处理server发来的模型使用handle_message_receive_model_from_server()
handle_message_init():处理初始化信息,但该函数实际上没有被调用过
handle_message_model_receive_from_server():接收server发送的更新后的模型参数,使用FedAVGTrainer的update_model()进行模型参数的更新,然后开始下一轮训练;所有训练epoch完成后通知server
send_model_to_server():将自己的id、模型的参数和训练样本数发送给server
__train():调用FedAVGTrainer.train()进行训练,获取训练后的模型参数和训练的样本数,再调用send_model_to_server()上传模型参数和训练样本数给server
- 实验代码更新
跑实验代码进行模型训练。