神经网络案例实战

简介: 使用PyTorch解决手机价格分类问题:收集包含RAM、存储等特征的手机销售数据,将价格分为4个等级的分类任务。步骤包括数据预处理、特征工程、选择神经网络模型、训练、评估和预测。模型使用Sigmoid激活的三层网络,训练时采用交叉熵损失和SGD优化器。通过调整模型结构、优化器和学习率以优化性能。

🔎我们通过一个案例详细使用PyTorch实战 ,案例背景:你创办了一家手机公司,不知道如何估算手机产品的价格。为了解决这个问题,收集了多家公司的手机销售数据:这些数据维度可以包括RAM、存储容量、屏幕尺寸、摄像头像素等。


在这个问题中,我们不需要预测实际价格,而是一个价格范围,它的范围使用 0、1、2、3 来表示,所以该问题也是一个分类问题。


🔎思路:


  1. 数据预处理:对收集到的数据进行清洗和预处理,确保数据的质量和一致性。这包括处理缺失值、异常值和重复数据等。


  1. 特征工程:从原始数据中提取有用的特征,以便用于建模。这可以包括对连续型特征进行归一化或标准化,对分类特征进行编码等。


  1. 模型选择:选择一个适合的机器学习算法来建立模型,这里我们使用神经网络模型。


  1. 模型训练:将收集到的数据划分为训练集和测试集。使用训练集来训练模型,通过调整模型的参数来最小化预测误差。


  1. 模型评估:使用测试集来评估模型的性能。常用的评估指标包括均方误差(MSE)、均方根误差(RMSE)和平均绝对误差(MAE)等。


  1. 价格预测:使用训练好的模型来预测新手机的价格。输入新手机的功能特征,模型将输出预测的价格。


构建数据集


💬数据共有 2000 条, 其中 1600 条数据作为训练集, 400 条数据用作测试集。 我们使用 sklearn 的数据集划分工作来完成。并使用 PyTorch 的 TensorDataset 来将数据集构建为 Dataset 对象,方便后期构造数据集加载对象。


def create_dataset():
 
    data = pd.read_csv('predict.csv')
 
    # 特征值和目标值
    x, y = data.iloc[:, :-1], data.iloc[:, -1]
    x = x.astype(np.float32)
    y = y.astype(np.int64)
 
    # 数据集划分
    x_train, x_valid, y_train, y_valid = train_test_split(x, y, train_size=0.8, random_state=88, stratify=y)
 
    # 构建数据集
    train_dataset = TensorDataset(torch.tensor(x_train.values), torch.tensor(y_train.values))
    valid_dataset = TensorDataset(torch.tensor(x_valid.values), torch.tensor(y_valid.values))
 
    return train_dataset, valid_dataset, x_train.shape[1], len(np.unique(y))
 
 
train_dataset, valid_dataset, input_dim, class_num = create_dataset()


💬其中 train_test_split 方法中的stratify=y参数的作用是在划分训练集和验证集时,保持类别的比例相同。这样可以确保在训练集和验证集中各类别的比例与原始数据集中的比例相同,有助于提高模型的泛化能力,防止出现一份中某个类别只有几个。


构建分类网络模型


# 构建网络模型
class PhonePriceModel(nn.Module):
 
    def __init__(self, input_dim, output_dim):
        super(PhonePriceModel, self).__init__()
 
        self.linear1 = nn.Linear(input_dim, 128) # 输入层到隐藏层的线性变换
        self.linear2 = nn.Linear(128, 256)  # 隐藏层到输出层的线性变换
        self.linear3 = nn.Linear(256, output_dim)  # 输出层到最终输出的线性变换
 
    def _activation(self, x):
        return torch.sigmoid(x)
 
    def forward(self, x):
 
        x = self._activation(self.linear1(x))
        # 通过第一层线性变换和激活函数
        x = self._activation(self.linear2(x))
        # 通过第二层线性变换和激活函数
        output = self.linear3(x)
        # 通过第三层线性变换得到最终输出
 
        return output


  • self.linear1和self.linear2之间的线性变换将输入维度从input_dim映射到128个神经元,然后再将128个神经元映射到256个神经元。
  • 我们通过self._activation方法对输入数据进行激活函数处理。具体来说,它使用Sigmoid激活函数对输入数据进行非线性变换,将线性变换后的输出映射到0和1之间。
  • 第三层: 输入为维度为 256, 输出维度为: 4


编写训练函数


def train():
 
    # 固定随机数种子
    torch.manual_seed(66)
 
    
    model = PhonePriceModel(input_dim, class_num)
    # 损失函数
    criterion = nn.CrossEntropyLoss()
    # 优化方法
    optimizer = optim.SGD(model.parameters(), lr=1e-3)
    
    num_epoch = 50
 
    for epoch_idx in range(num_epoch):
 
        # 初始化数据加载器
        dataloader = DataLoader(train_dataset, shuffle=True, batch_size=8)
        # 训练时间
        start = time.time()
        # 计算损失
        total_loss = 0.0
        total_num = 1
        # 准确率
        correct = 0
 
        for x, y in dataloader:
 
            output = model(x)
            # 计算损失
            loss = criterion(output, y)
            # 梯度清零
            optimizer.zero_grad()
            # 反向传播
            loss.backward()
            # 参数更新
            optimizer.step()
 
            total_num += len(y)
            total_loss += loss.item() * len(y)
 
        print('epoch: %4s loss: %.2f, time: %.2fs' %
              (epoch_idx + 1, total_loss / total_num, time.time() - start))
 
    # 模型保存
    torch.save(model.state_dict(), 'price-model.bin')


  • 💬要在PyTorch中查看随机数种子,可以使用torch.random.initial_seed()函数。这个函数会返回当前设置的随机数种子。


评估函数


def test():
 
    # 加载模型
    model = PhonePriceModel(input_dim, class_num)
    model.load_state_dict(torch.load('price-model.bin'))
 
    # 构建加载器
    dataloader = DataLoader(valid_dataset, batch_size=8, shuffle=False)
 
    # 评估测试集
    correct = 0
    for x, y in dataloader:
 
        output = model(x)
        y_pred = torch.argmax(output, dim=1)
        correct += (y_pred == y).sum()
 
    print('Acc: %.5f' % (correct.item() / len(valid_dataset)))


网络性能调优


  1. 对输入数据进行标准化
  2. 调整优化方法
  3. 调整学习率
  4. 增加批量归一化层
  5. 增加网络层数、神经元个数
  6. 增加训练轮数


🔎我们可以自行将优化方法由 SGD 调整为 Adam , 学习率由 1e-3 调整为 1e-4 ,对数据数据进行标准化 ,增加网络深度 。



相关文章
|
8天前
|
安全 算法 网络安全
网络安全与信息安全:构建数字世界的坚固防线在数字化浪潮席卷全球的今天,网络安全与信息安全已成为维系社会秩序、保障个人隐私和企业机密的关键防线。本文旨在深入探讨网络安全漏洞的本质、加密技术的前沿进展以及提升公众安全意识的重要性,通过一系列生动的案例和实用的建议,为读者揭示如何在日益复杂的网络环境中保护自己的数字资产。
本文聚焦于网络安全与信息安全领域的核心议题,包括网络安全漏洞的识别与防御、加密技术的应用与发展,以及公众安全意识的培养策略。通过分析近年来典型的网络安全事件,文章揭示了漏洞产生的深层原因,阐述了加密技术如何作为守护数据安全的利器,并强调了提高全社会网络安全素养的紧迫性。旨在为读者提供一套全面而实用的网络安全知识体系,助力构建更加安全的数字生活环境。
|
2月前
|
存储 SQL 安全
网络安全的盾牌:漏洞防护与加密技术的实战应用
【8月更文挑战第27天】在数字化浪潮中,信息安全成为保护个人隐私和企业资产的关键。本文深入探讨了网络安全的两大支柱——安全漏洞管理和数据加密技术,以及如何通过提升安全意识来构建坚固的防御体系。我们将从基础概念出发,逐步揭示网络攻击者如何利用安全漏洞进行入侵,介绍最新的加密算法和协议,并分享实用的安全实践技巧。最终,旨在为读者提供一套全面的网络安全解决方案,以应对日益复杂的网络威胁。
|
2天前
|
SQL 安全 算法
网络安全的盾牌与剑:漏洞防御与加密技术的实战应用
【9月更文挑战第30天】在数字时代的浪潮中,网络安全成为守护信息资产的关键防线。本文深入浅出地探讨了网络安全中的两大核心议题——安全漏洞与加密技术,并辅以实例和代码演示,旨在提升公众的安全意识和技术防护能力。
|
12天前
|
网络协议 Python
告别网络编程迷雾!Python Socket编程基础与实战,让你秒变网络达人!
在网络编程的世界里,Socket编程是连接数据与服务的关键桥梁。对于初学者,这往往是最棘手的部分。本文将用Python带你轻松入门Socket编程,从创建TCP服务器与客户端的基础搭建,到处理并发连接的实战技巧,逐步揭开网络编程的神秘面纱。通过具体的代码示例,我们将掌握Socket的基本概念与操作,让你成为网络编程的高手。无论是简单的数据传输还是复杂的并发处理,Python都能助你一臂之力。希望这篇文章成为你网络编程旅程的良好开端。
35 3
|
13天前
|
机器学习/深度学习 JSON API
HTTP协议实战演练场:Python requests库助你成为网络数据抓取大师
在数据驱动的时代,网络数据抓取对于数据分析、机器学习等至关重要。HTTP协议作为互联网通信的基石,其重要性不言而喻。Python的`requests`库凭借简洁的API和强大的功能,成为网络数据抓取的利器。本文将通过实战演练展示如何使用`requests`库进行数据抓取,包括发送GET/POST请求、处理JSON响应及添加自定义请求头等。首先,请确保已安装`requests`库,可通过`pip install requests`进行安装。接下来,我们将逐一介绍如何利用`requests`库探索网络世界,助你成为数据抓取大师。在实践过程中,务必遵守相关法律法规和网站使用条款,做到技术与道德并重。
28 2
|
20天前
|
数据采集 网络协议 API
HTTP协议大揭秘!Python requests库实战,让网络请求变得简单高效
【9月更文挑战第13天】在数字化时代,互联网成为信息传输的核心平台,HTTP协议作为基石,定义了客户端与服务器间的数据传输规则。直接处理HTTP请求复杂繁琐,但Python的`requests`库提供了一个简洁强大的接口,简化了这一过程。HTTP协议采用请求与响应模式,无状态且结构化设计,使其能灵活处理各种数据交换。
47 8
|
15天前
|
数据采集 API 开发者
🚀告别网络爬虫小白!urllib与requests联手,Python网络请求实战全攻略
在网络的广阔世界里,Python凭借其简洁的语法和强大的库支持,成为开发网络爬虫的首选语言。本文将通过实战案例,带你探索urllib和requests两大神器的魅力。urllib作为Python内置库,虽API稍显繁琐,但有助于理解HTTP请求本质;requests则简化了请求流程,使开发者更专注于业务逻辑。从基本的网页内容抓取到处理Cookies与Session,我们将逐一剖析,助你从爬虫新手成长为高手。
39 1
|
2月前
|
运维 安全 应用服务中间件
自动化运维的利器:Ansible入门与实战网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享
【8月更文挑战第30天】在当今快速发展的IT时代,自动化运维已成为提升效率、减少错误的关键。本文将介绍Ansible,一种流行的自动化运维工具,通过简单易懂的语言和实际案例,带领读者从零开始掌握Ansible的使用。我们将一起探索如何利用Ansible简化日常的运维任务,实现快速部署和管理服务器,以及如何处理常见问题。无论你是运维新手还是希望提高工作效率的资深人士,这篇文章都将为你开启自动化运维的新篇章。
|
2月前
|
Java
【实战演练】JAVA网络编程高手养成记:URL与URLConnection的实战技巧,一学就会!
【实战演练】JAVA网络编程高手养成记:URL与URLConnection的实战技巧,一学就会!
30 3
|
2月前
|
Java API UED
【实战秘籍】Spring Boot开发者的福音:掌握网络防抖动,告别无效请求,提升用户体验!
【8月更文挑战第29天】网络防抖动技术能有效处理频繁触发的事件或请求,避免资源浪费,提升系统响应速度与用户体验。本文介绍如何在Spring Boot中实现防抖动,并提供代码示例。通过使用ScheduledExecutorService,可轻松实现延迟执行功能,确保仅在用户停止输入后才触发操作,大幅减少服务器负载。此外,还可利用`@Async`注解简化异步处理逻辑。防抖动是优化应用性能的关键策略,有助于打造高效稳定的软件系统。
41 2
下一篇
无影云桌面