- PyTorch框架:
PyTorch是一个开源的机器学习库,由Facebook的人工智能研究团队开发,适用于Python程序,支持强大的GPU加速。PyTorch提供了两个主要功能:一是张量计算(如NumPy),二是基于这些张量的自动求导机制。这使得PyTorch非常适合进行深度学习研究和应用。PyTorch的动态计算图(也称为即时执行)是其与其他深度学习框架(如TensorFlow)的主要区别,这使得它在研究和调试方面更为灵活。 MNIST数据集:
MNIST(Modified National Institute of Standards and Technology database)是一个广泛使用的手写数字识别数据集。它包含60,000个训练样本和10,000个测试样本。每个样本都是一张28x28像素的灰度图像,代表一个0到9之间的数字。MNIST数据集通常被用作入门级的机器学习和深度学习项目,因为它既不太简单也不太复杂,非常适合初步学习和实践。
下面是使用PyTorch框架来训练一个简单的神经网络进行MNIST手写数字识别的代码示例:import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader from torchvision import datasets, transforms # 定义超参数 BATCH_SIZE = 64 EPOCHS = 5 LEARNING_RATE = 0.01 # 数据预处理:将数据转换为torch.FloatTensor,并标准化 transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ]) # 下载并加载训练数据 train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform) train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True) # 下载并加载测试数据 test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform) test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False) # 定义一个简单的神经网络模型 class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.fc1 = nn.Linear(28*28, 500) self.fc2 = nn.Linear(500, 10) def forward(self, x): x = x.view(-1, 28*28) x = F.relu(self.fc1(x)) x = self.fc2(x) return x # 实例化模型 model = Net() # 定义损失函数和优化器 criterion = nn.CrossEntropyLoss() optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE) # 训练模型 for epoch in range(EPOCHS): for images, labels in train_loader: # 前向传播 outputs = model(images) loss = criterion(outputs, labels) # 反向传播和优化 optimizer.zero_grad() loss.backward() optimizer.step() print(f'Epoch {epoch+1}/{EPOCHS}, Loss: {loss.item()}') # 测试模型 model.eval() # 将模型设置为评估模式 with torch.no_grad(): # 在这个with下,所有计算得出的tensor都不会计算梯度,也就是不会进行反向传播 correct = 0 total = 0 for images, labels in test_loader: outputs = model(images) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() print(f'Accuracy of the network on the 10000 test images: {100 * correct / total}%')