深度学习入门:使用 PyTorch 构建和训练你的第一个神经网络

本文涉及的产品
实时数仓Hologres,5000CU*H 100GB 3个月
智能开放搜索 OpenSearch行业算法版,1GB 20LCU 1个月
实时计算 Flink 版,1000CU*H 3个月
简介: 【8月更文第29天】深度学习是机器学习的一个分支,它利用多层非线性处理单元(即神经网络)来解决复杂的模式识别问题。PyTorch 是一个强大的深度学习框架,它提供了灵活的 API 和动态计算图,非常适合初学者和研究者使用。

深度学习入门:使用 PyTorch 构建和训练你的第一个神经网络

引言

深度学习是机器学习的一个分支,它利用多层非线性处理单元(即神经网络)来解决复杂的模式识别问题。PyTorch 是一个强大的深度学习框架,它提供了灵活的 API 和动态计算图,非常适合初学者和研究者使用。

安装 PyTorch

确保安装了 Python 和 pip。然后通过以下命令安装 PyTorch:

pip install torch torchvision

导入库

我们需要导入一些必要的库:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

准备数据

首先,我们需要下载并加载 MNIST 数据集:

# 设置数据转换
transform = transforms.Compose([
    transforms.ToTensor(),  # 转换成 Tensor
    transforms.Normalize((0.5,), (0.5,))  # 归一化
])

# 下载训练数据集
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)

# 下载测试数据集
test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_data, batch_size=64, shuffle=True)

定义模型

接下来定义我们的神经网络模型。这里我们使用一个简单的全连接网络:

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)  # 输入层到隐藏层
        self.fc2 = nn.Linear(128, 64)       # 隐藏层到隐藏层
        self.fc3 = nn.Linear(64, 10)        # 最后一层到输出层

    def forward(self, x):
        x = x.view(-1, 28 * 28)             # 展平输入
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = SimpleNet()

定义损失函数和优化器

选择适当的损失函数和优化器:

criterion = nn.CrossEntropyLoss()  # 多分类任务常用的损失函数
optimizer = optim.SGD(model.parameters(), lr=0.01)  # 使用随机梯度下降优化器

训练模型

现在我们可以开始训练模型:

num_epochs = 10

for epoch in range(num_epochs):
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(train_loader, 0):
        # 前向传播 + 反向传播 + 优化
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}")

测试模型

最后,我们需要评估模型在测试集上的性能:

correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        images, labels = data
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct / total} %')

以上就是使用 PyTorch 构建和训练简单神经网络的完整过程。这个例子展示了如何从零开始建立一个基本的深度学习模型,并通过实际数据集对其进行训练和测试。随着你对 PyTorch 的了解加深,你可以尝试更复杂的模型结构和更高级的技术。

目录
相关文章
|
2月前
|
前端开发 JavaScript 开发者
JavaScript:构建动态网络的引擎
JavaScript:构建动态网络的引擎
|
1月前
|
机器学习/深度学习 人工智能 PyTorch
PyTorch深度学习 ? 带你从入门到精通!!!
🌟 蒋星熠Jaxonic,深度学习探索者。三年深耕PyTorch,从基础到部署,分享模型构建、GPU加速、TorchScript优化及PyTorch 2.0新特性,助力AI开发者高效进阶。
PyTorch深度学习 ? 带你从入门到精通!!!
|
4月前
|
机器学习/深度学习 算法 量子技术
GQNN框架:让Python开发者轻松构建量子神经网络
为降低量子神经网络的研发门槛并提升其实用性,本文介绍一个名为GQNN(Generalized Quantum Neural Network)的Python开发框架。
110 4
GQNN框架:让Python开发者轻松构建量子神经网络
|
2月前
|
人工智能 监控 数据可视化
如何破解AI推理延迟难题:构建敏捷多云算力网络
本文探讨了AI企业在突破算力瓶颈后,如何构建高效、稳定的网络架构以支撑AI产品化落地。文章分析了典型AI IT架构的四个层次——流量接入层、调度决策层、推理服务层和训练算力层,并深入解析了AI架构对网络提出的三大核心挑战:跨云互联、逻辑隔离与业务识别、网络可视化与QoS控制。最终提出了一站式网络解决方案,助力AI企业实现多云调度、业务融合承载与精细化流量管理,推动AI服务高效、稳定交付。
|
1月前
|
机器学习/深度学习 分布式计算 Java
Java与图神经网络:构建企业级知识图谱与智能推理系统
图神经网络(GNN)作为处理非欧几里得数据的前沿技术,正成为企业知识管理和智能推理的核心引擎。本文深入探讨如何在Java生态中构建基于GNN的知识图谱系统,涵盖从图数据建模、GNN模型集成、分布式图计算到实时推理的全流程。通过具体的代码实现和架构设计,展示如何将先进的图神经网络技术融入传统Java企业应用,为构建下一代智能决策系统提供完整解决方案。
281 0
|
2月前
|
机器学习/深度学习 算法 搜索推荐
从零开始构建图注意力网络:GAT算法原理与数值实现详解
本文详细解析了图注意力网络(GAT)的算法原理和实现过程。GAT通过引入注意力机制解决了图卷积网络(GCN)中所有邻居节点贡献相等的局限性,让模型能够自动学习不同邻居的重要性权重。
439 0
从零开始构建图注意力网络:GAT算法原理与数值实现详解
|
4月前
|
监控 安全 Go
使用Go语言构建网络IP层安全防护
在Go语言中构建网络IP层安全防护是一项需求明确的任务,考虑到高性能、并发和跨平台的优势,Go是构建此类安全系统的合适选择。通过紧密遵循上述步骤并结合最佳实践,可以构建一个强大的网络防护系统,以保障数字环境的安全完整。
128 12
|
11月前
|
SQL 安全 网络安全
网络安全与信息安全:知识分享####
【10月更文挑战第21天】 随着数字化时代的快速发展,网络安全和信息安全已成为个人和企业不可忽视的关键问题。本文将探讨网络安全漏洞、加密技术以及安全意识的重要性,并提供一些实用的建议,帮助读者提高自身的网络安全防护能力。 ####
262 17
|
11月前
|
SQL 安全 网络安全
网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享
随着互联网的普及,网络安全问题日益突出。本文将从网络安全漏洞、加密技术和安全意识三个方面进行探讨,旨在提高读者对网络安全的认识和防范能力。通过分析常见的网络安全漏洞,介绍加密技术的基本原理和应用,以及强调安全意识的重要性,帮助读者更好地保护自己的网络信息安全。
220 10
|
11月前
|
存储 SQL 安全
网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享
随着互联网的普及,网络安全问题日益突出。本文将介绍网络安全的重要性,分析常见的网络安全漏洞及其危害,探讨加密技术在保障网络安全中的作用,并强调提高安全意识的必要性。通过本文的学习,读者将了解网络安全的基本概念和应对策略,提升个人和组织的网络安全防护能力。

推荐镜像

更多