创建一个训练函数

简介: 【7月更文挑战第22天】创建一个训练函数。

创建一个训练函数
def train_tf(train_data):

   # 1.获取数据
   trainx = [train_d[0] for train_d in train_data]  # list
   trainy = [train_d[1] for train_d in train_data]

   # 2.构造预测的线性回归函数:y= W * x + b
   W = tf.Variable(tf.random_uniform([1]))  # 从均匀分布中返回随机值,即[0,1)
   b = tf.Variable(tf.zeros([1]))  # 在一维数组里放一个值
   y = W * trainx + b

   # 3.判断假设的函数的好坏
   cost = tf.reduce_mean(tf.square(y - trainy))

   # 4.优化函数
   optimizer = tf.train.AdamOptimizer(0.05)
   train = optimizer.minimize(cost)

   # 5.开始训练
   with tf.Session() as sess:
         # 初始化所有变量值
         sess.run(tf.global_variables_initializer())
         # 将画图模式改为交互模式
         plt.ion()
         for k in range(1000):
               sess.run(train)
               # 构造图形结构
               # 实时地输出训练好的W和b
               if k % 50 == 0:
                     print("第", k ,"步:","cost=", sess.run(cost), "W=", 

sess.run(W), "b=", sess.run(b))
plt.cla() # 清除原有图像
plt.plot(trainx, trainy, 'co', label='train data') # 显示数据
plt.plot(trainx, sess.run(y), 'y', label='train result') # 显示拟合
plt.pause(0.01)
plt.ioff() # 关闭交互模式
plt.close() # 关闭当前窗口
print("训练完成!")

         # 输出训练好的W和b
         print("finally_cost=", sess.run(cost), "finally_W=", sess.run(W), 

"finally_b=", sess.run(b))
return sess.run(W)[0], sess.run(b)[0]
tf.reduce_mean()函数用于计算张量Tensor沿着指定数轴(Tensor的某一维度)的平均值,主要用于降维或计算结果的平均值。第4步中,用梯度下降算法找最优解,通过梯度下降法为最小化损失函数增加了相关的优化操作。在训练过程中,先实例化一个优化函数,并基于一定的学习率进行梯度优化训练,如tf.train.AdamOptimizer(),该优化函数是一个寻找全局最优点的优化算法,引入了二次方梯度校正;使用minimize()操作,不仅可以优化及更新训练的模型参数,也可以为全局步骤(Global Step)计数,函数的参数传入损失值节点cost,再启动一个外层的循环,优化器就会按照循环的次数沿着cost最小值的方向优化参数。第5步开始训练,先初始化所有变量值和操作,打开plt的交互模式,开始训练并实时显示拟合的效果。

相关文章
|
9月前
|
人工智能 自然语言处理 物联网
RoSA: 一种新的大模型参数高效微调方法
随着语言模型不断扩展到前所未有的规模,对下游任务的所有参数进行微调变得非常昂贵,PEFT方法已成为自然语言处理领域的研究热点。PEFT方法将微调限制在一小部分参数中,以很小的计算成本实现自然语言理解任务的最先进性能。
226 1
|
9月前
|
机器学习/深度学习 JSON 算法
如何在自定义数据集上训练 YOLOv8 实例分割模型
在本文中,我们将介绍微调 YOLOv8-seg 预训练模型的过程,以提高其在特定目标类别上的准确性。Ikomia API简化了计算机视觉工作流的开发过程,允许轻松尝试不同的参数以达到最佳结果。
|
4月前
|
数据可视化 Linux 网络安全
如何使用服务器训练模型
本文介绍了如何使用服务器训练模型,包括获取服务器、访问服务器、上传文件、配置环境、训练模型和下载模型等步骤。适合没有GPU或不熟悉Linux服务器的用户。通过MobaXterm工具连接服务器,使用Conda管理环境,确保训练过程顺利进行。
267 0
如何使用服务器训练模型
|
4月前
|
存储 并行计算 PyTorch
探索PyTorch:模型的定义和保存方法
探索PyTorch:模型的定义和保存方法
|
9月前
|
Python
创建模型
创建模型。
41 1
|
9月前
|
机器学习/深度学习 TensorFlow 调度
优化TensorFlow模型:超参数调整与训练技巧
【4月更文挑战第17天】本文探讨了如何优化TensorFlow模型的性能,重点介绍了超参数调整和训练技巧。超参数如学习率、批量大小和层数对模型性能至关重要。文章提到了三种超参数调整策略:网格搜索、随机搜索和贝叶斯优化。此外,还分享了训练技巧,包括学习率调度、早停、数据增强和正则化,这些都有助于防止过拟合并提高模型泛化能力。结合这些方法,可构建更高效、健壮的深度学习模型。
|
9月前
|
关系型数据库 MySQL 数据库
定义模型和模型配置
定义模型和模型配置。
53 1
|
机器学习/深度学习 数据可视化 数据挖掘
使用PyTorch构建神经网络(详细步骤讲解+注释版) 02-数据读取与训练
熟悉基础数据分析的同学应该更习惯使用Pandas库对数据进行处理,此处为了加深对PyTorch的理解,我们尝试使用PyTorch读取数据。这里面用到的包是torch.utils.data.Dataset。 在下面的代码中,分别定义了len方法与getitem方法。这两个方法都是python的内置方法,但是对类并不适用。这里通过重写方法使类也可以调用,并且自定义了getitem方法的输出
使用PyTorch构建神经网络(详细步骤讲解+注释版) 02-数据读取与训练
|
机器学习/深度学习 API 算法框架/工具
Keras 高级教程:模型微调和自定义训练循环
我们在前两篇文章中介绍了如何使用 Keras 构建和训练深度学习模型的基础和中级知识。在本篇文章中,我们将探讨一些更高级的主题,包括模型微调和自定义训练循环。
|
数据采集 PyTorch 算法框架/工具
Pytorch训练一个模型的步骤总结
Pytorch训练一个模型的步骤总结
253 0

热门文章

最新文章