Keras-3-实例2-多分类问题

简介: Keras-3-实例2-多分类问题

在这个例子中,我们将使用Keras处理多分类问题。我们将使用Iris数据集,该数据集包含三个不同种类的鸢尾花,每个种类包含50个样本。每个数据点都有四个特征:萼片长度、萼片宽度、花瓣长度和花瓣宽度。

我们将使用一个简单的神经网络来训练模型,并使用softmax激活函数输出一个概率分布,即每个样本属于三个类别中的哪一个。我们还将在训练过程中使用交叉熵损失函数和随机梯度下降优化器。

代码示例:

载入数据集

from sklearn.datasets import load_iris
iris = load_iris()

将数据集拆分为训练集和测试集

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=42)

将标签转换为one-hot编码

from keras.utils import to_categorical
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)

创建模型

from keras.models import Sequential
from keras.layers import Dense

model = Sequential()
model.add(Dense(units=10, input_dim=4, activation='relu'))
model.add(Dense(units=3, activation='softmax'))

编译模型

model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])

训练模型

model.fit(X_train, y_train, epochs=50, batch_size=10)

评估模型

loss_and_metrics = model.evaluate(X_test, y_test, batch_size=10)
print(loss_and_metrics)
python

我们首先将数据集拆分为训练集和测试集,然后使用to_categorical函数将标签转换为one-hot编码。

接下来,我们创建一个Sequential模型,并为其添加两个Dense层。第一层包含10个神经元和relu激活函数,第二层包含3个神经元和softmax激活函数,以输出每个样本属于三个类别中的哪一个的概率分布。我们使用交叉熵损失函数和随机梯度下降优化器来编译模型,并指定metrics为accuracy。

我们使用fit函数训练模型,并在测试集上使用evaluate函数评估模型的性能。

相关文章
|
域名解析 网络协议 数据安全/隐私保护
DNS解析问题之授权RAM子账号管理指定域名如何解决
DNS解析是指将人类可读的域名转换成机器可读的IP地址的过程,它是互联网访问中不可或缺的一环;本合集将介绍DNS解析的机制、类型和相关问题的解决策略,以确保域名解析的准确性和高效性。
733 1
|
开发工具 对象存储 Android开发
对象存储oss使用问题之C++使用OSS SDK时遍历OSS上的文件时崩溃如何解决
《对象存储OSS操作报错合集》精选了用户在使用阿里云对象存储服务(OSS)过程中出现的各种常见及疑难报错情况,包括但不限于权限问题、上传下载异常、Bucket配置错误、网络连接问题、跨域资源共享(CORS)设定错误、数据一致性问题以及API调用失败等场景。为用户降低故障排查时间,确保OSS服务的稳定运行与高效利用。
621 0
|
机器学习/深度学习 数据采集 移动开发
【学习笔记】使用Keras构建CNN网络完成猫狗分类(适合初学者,简单易上手)
【学习笔记】使用Keras构建CNN网络完成猫狗分类(适合初学者,简单易上手)
【学习笔记】使用Keras构建CNN网络完成猫狗分类(适合初学者,简单易上手)
|
Kubernetes 容器 Perl
【Agones系列】Game Server的扩缩容
本文介绍了Agones中GameServer是如何进行扩缩容的
|
2天前
|
存储 弹性计算 人工智能
【2025云栖精华内容】 打造持续领先,全球覆盖的澎湃算力底座——通用计算产品发布与行业实践专场回顾
2025年9月24日,阿里云弹性计算团队多位产品、技术专家及服务器团队技术专家共同在【2025云栖大会】现场带来了《通用计算产品发布与行业实践》的专场论坛,本论坛聚焦弹性计算多款通用算力产品发布。同时,ECS云服务器安全能力、资源售卖模式、计算AI助手等用户体验关键环节也宣布升级,让用云更简单、更智能。海尔三翼鸟云服务负责人刘建锋先生作为特邀嘉宾,莅临现场分享了关于阿里云ECS g9i推动AIoT平台的场景落地实践。
【2025云栖精华内容】 打造持续领先,全球覆盖的澎湃算力底座——通用计算产品发布与行业实践专场回顾
|
4天前
|
云安全 数据采集 人工智能
古茗联名引爆全网,阿里云三层防护助力对抗黑产
阿里云三层校验+风险识别,为古茗每一杯奶茶保驾护航!
古茗联名引爆全网,阿里云三层防护助力对抗黑产
|
4天前
|
存储 机器学习/深度学习 人工智能
大模型微调技术:LoRA原理与实践
本文深入解析大语言模型微调中的关键技术——低秩自适应(LoRA)。通过分析全参数微调的计算瓶颈,详细阐述LoRA的数学原理、实现机制和优势特点。文章包含完整的PyTorch实现代码、性能对比实验以及实际应用场景,为开发者提供高效微调大模型的实践指南。
533 1
kde
|
4天前
|
人工智能 关系型数据库 PostgreSQL
n8n Docker 部署手册
n8n是一款开源工作流自动化平台,支持低代码与可编程模式,集成400+服务节点,原生支持AI与API连接,可自托管部署,助力团队构建安全高效的自动化流程。
kde
362 3