【tensorflow】连续输入的神经网络模型训练代码

简介: 【tensorflow】连续输入的神经网络模型训练代码

全部代码 - 复制即用


from sklearn.model_selection import train_test_split
import tensorflow as tf
import numpy as np
from keras import Input, Model, Sequential
from keras.layers import Dense, concatenate, Embedding, LSTM
from sklearn.preprocessing import StandardScaler
from tensorflow import keras
def get_data():
    # 设置随机种子,以确保结果可复现(可选)
    np.random.seed(0)
    # 生成随机数据
    data = np.random.rand(10000, 10)
    # 正则化数据
    scaler = StandardScaler()
    data = scaler.fit_transform(data)
    # 生成随机数据
    target = np.random.rand(10000, 1)
    return train_test_split(data, target, test_size=0.1, random_state=42)
data_train, data_val, target_train, target_val = get_data()
# 迭代轮次
train_epochs = 10
# 学习率
learning_rate = 0.0001
# 批大小
batch_size = 200
model = keras.models.Sequential([
    keras.layers.Dense(64, activation="relu", input_shape=[10]),
    keras.layers.Dense(64, activation="relu"),
    keras.layers.Dense(1)
])
model.summary()
model.compile(loss="mse", optimizer=keras.optimizers.Adam(lr=learning_rate))
history = model.fit(data_train, target_train, epochs=train_epochs, batch_size=batch_size, validation_data=(data_val, target_val))

训练输出


  模型结构如下:




.



代码介绍


 get_data函数用于生成随机的训练和验证数据集。首先使用np.random.rand生成一个形状为(10000, 10)的随机数据集,来模拟10维的连续输入,然后使用StandardScaler对数据进行标准化。再生成一个(10000,1)的target,表示最终拟合的目标分数。最后使用train_test_split函数将数据集划分为训练集和验证集。


 由于target是浮点数,所以我们这个任务就是回归任务了。


 使用keras.models.Sequential构建一个序列模型。模型由一系列层按顺序连接而成。在这个例子中,模型由三个全连接层构成。


 第一个隐藏层(keras.layers.Dense)具有64个神经元,使用ReLU激活函数,并指定输入形状为[10]。输入形状表示输入数据的维度。


 第二个隐藏层也是一个具有64个神经元的全连接层,同样使用ReLU激活函数。


 最后一层是输出层,由一个神经元组成,不使用激活函数。


 模型的结构是输入层(10维)→隐藏层(64个神经元,ReLU激活函数)→隐藏层(64个神经元,ReLU激活函数)→输出层(1个神经元)。


 最后,使用model.compile方法配置模型的损失函数和优化器。在这个例子中,损失函数设置为均方误差(Mean Squared Error,MSE),优化器选择Adam优化算法,并设置学习率为learning_rate。


 使用model.fit方法对模型进行训练。传入训练数据data_train和目标数据target_train,设置训练轮次train_epochs、批处理大小batch_size,以及验证集数据(data_val, target_val)。


 训练过程中,模型会根据给定的训练数据和目标数据进行参数更新,通过反向传播算法优化模型的权重和偏置。每个训练轮次(epoch)都会对整个训练数据集进行一次完整的训练。训练过程还会使用验证集数据对模型进行评估,以监控模型的性能和验证集上的损失。


 训练过程中的损失值和其他指标会被记录在history对象中,可以用于后续的可视化和分析。



相关文章
|
23天前
|
机器学习/深度学习 存储 算法
NoProp:无需反向传播,基于去噪原理的非全局梯度传播神经网络训练,可大幅降低内存消耗
反向传播算法虽是深度学习基石,但面临内存消耗大和并行扩展受限的问题。近期,牛津大学等机构提出NoProp方法,通过扩散模型概念,将训练重塑为分层去噪任务,无需全局前向或反向传播。NoProp包含三种变体(DT、CT、FM),具备低内存占用与高效训练优势,在CIFAR-10等数据集上达到与传统方法相当的性能。其层间解耦特性支持分布式并行训练,为无梯度深度学习提供了新方向。
85 1
NoProp:无需反向传播,基于去噪原理的非全局梯度传播神经网络训练,可大幅降低内存消耗
|
10天前
|
机器学习/深度学习 人工智能 算法
PaperCoder:一种利用大型语言模型自动生成机器学习论文代码的框架
PaperCoder是一种基于多智能体LLM框架的工具,可自动将机器学习研究论文转化为代码库。它通过规划、分析和生成三个阶段,系统性地实现从论文到代码的转化,解决当前研究中代码缺失导致的可复现性问题。实验表明,PaperCoder在自动生成高质量代码方面显著优于基线方法,并获得专家高度认可。这一工具降低了验证研究成果的门槛,推动科研透明与高效。
77 19
PaperCoder:一种利用大型语言模型自动生成机器学习论文代码的框架
|
5月前
|
机器学习/深度学习 人工智能 算法
猫狗宠物识别系统Python+TensorFlow+人工智能+深度学习+卷积网络算法
宠物识别系统使用Python和TensorFlow搭建卷积神经网络,基于37种常见猫狗数据集训练高精度模型,并保存为h5格式。通过Django框架搭建Web平台,用户上传宠物图片即可识别其名称,提供便捷的宠物识别服务。
594 55
|
2月前
|
机器学习/深度学习 人工智能 算法
基于Python深度学习的【害虫识别】系统~卷积神经网络+TensorFlow+图像识别+人工智能
害虫识别系统,本系统使用Python作为主要开发语言,基于TensorFlow搭建卷积神经网络算法,并收集了12种常见的害虫种类数据集【"蚂蚁(ants)", "蜜蜂(bees)", "甲虫(beetle)", "毛虫(catterpillar)", "蚯蚓(earthworms)", "蜚蠊(earwig)", "蚱蜢(grasshopper)", "飞蛾(moth)", "鼻涕虫(slug)", "蜗牛(snail)", "黄蜂(wasp)", "象鼻虫(weevil)"】 再使用通过搭建的算法模型对数据集进行训练得到一个识别精度较高的模型,然后保存为为本地h5格式文件。最后使用Djan
177 1
基于Python深度学习的【害虫识别】系统~卷积神经网络+TensorFlow+图像识别+人工智能
|
3月前
|
机器学习/深度学习 人工智能 开发者
DeepSeek安装部署指南,基于阿里云PAI零代码,小白也能轻松搞定!
阿里云PAI平台支持零代码一键部署DeepSeek-V3和DeepSeek-R1大模型,用户可轻松实现从训练到部署再到推理的全流程。通过PAI Model Gallery,开发者只需简单几步即可完成模型部署,享受高效便捷的AI开发体验。具体步骤包括:开通PAI服务、进入控制台选择模型、一键部署并获取调用信息。整个过程简单快捷,极大降低了使用门槛。
1264 43
|
3月前
|
机器学习/深度学习 人工智能 算法
基于Python深度学习的【蘑菇识别】系统~卷积神经网络+TensorFlow+图像识别+人工智能
蘑菇识别系统,本系统使用Python作为主要开发语言,基于TensorFlow搭建卷积神经网络算法,并收集了9种常见的蘑菇种类数据集【"香菇(Agaricus)", "毒鹅膏菌(Amanita)", "牛肝菌(Boletus)", "网状菌(Cortinarius)", "毒镰孢(Entoloma)", "湿孢菌(Hygrocybe)", "乳菇(Lactarius)", "红菇(Russula)", "松茸(Suillus)"】 再使用通过搭建的算法模型对数据集进行训练得到一个识别精度较高的模型,然后保存为为本地h5格式文件。最后使用Django框架搭建了一个Web网页平台可视化操作界面,
215 11
基于Python深度学习的【蘑菇识别】系统~卷积神经网络+TensorFlow+图像识别+人工智能
|
3月前
|
机器学习/深度学习 文件存储 异构计算
YOLOv11改进策略【模型轻量化】| 替换骨干网络为EfficientNet v2,加速训练,快速收敛
YOLOv11改进策略【模型轻量化】| 替换骨干网络为EfficientNet v2,加速训练,快速收敛
330 18
YOLOv11改进策略【模型轻量化】| 替换骨干网络为EfficientNet v2,加速训练,快速收敛
|
3月前
|
机器学习/深度学习 数据采集 运维
机器学习在网络流量预测中的应用:运维人员的智慧水晶球?
机器学习在网络流量预测中的应用:运维人员的智慧水晶球?
176 19
|
3月前
|
机器学习/深度学习 算法 数据安全/隐私保护
基于机器学习的人脸识别算法matlab仿真,对比GRNN,PNN,DNN以及BP四种网络
本项目展示了人脸识别算法的运行效果(无水印),基于MATLAB2022A开发。核心程序包含详细中文注释及操作视频。理论部分介绍了广义回归神经网络(GRNN)、概率神经网络(PNN)、深度神经网络(DNN)和反向传播(BP)神经网络在人脸识别中的应用,涵盖各算法的结构特点与性能比较。
|
3月前
|
机器学习/深度学习 数据可视化 API
DeepSeek生成对抗网络(GAN)的训练与应用
生成对抗网络(GANs)是深度学习的重要技术,能生成逼真的图像、音频和文本数据。通过生成器和判别器的对抗训练,GANs实现高质量数据生成。DeepSeek提供强大工具和API,简化GAN的训练与应用。本文介绍如何使用DeepSeek构建、训练GAN,并通过代码示例帮助掌握相关技巧,涵盖模型定义、训练过程及图像生成等环节。

热门文章

最新文章