Hello, you can refer to the following code for training and testing:
Training code:
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # padding=2 keeps the LeNet-style 16 * 5 * 5 feature size for 28x28 MNIST inputs
        self.conv1 = nn.Conv2d(1, 6, 5, padding=2)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = MNIST(root='./data', train=False, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=100, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)
model = MyModel()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(1, 10):
    for data in train_loader:
        inputs, labels = data
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
# Evaluate the model on the test set
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        inputs, labels = data
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print('Accuracy: {}%'.format(100.0 * correct / total))

torch.save(model.state_dict(), 'model.pkl')
Test code:
import torch
from torchvision.datasets import MNIST
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
# The MyModel class defined in the training code must also be available here,
# for example by importing it from the training script.

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
test_dataset = MNIST(root='./data', train=False, transform=transform, download=True)
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)

model = MyModel()
model.load_state_dict(torch.load('model.pkl'))
model.eval()

correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        inputs, labels = data
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print('Accuracy: {}%'.format(100.0 * correct / total))
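If you also want to run the trained model on a single image of your own, a minimal sketch is shown below; the use of PIL and the file name 'digit.png' are illustrative assumptions, not part of the code above:

# Minimal sketch: classify one image with the trained model.
# Assumes `model` and `transform` are set up as in the test code above,
# and that 'digit.png' is a 28x28 grayscale digit image (illustrative name).
from PIL import Image

img = Image.open('digit.png').convert('L')   # load as grayscale
x = transform(img).unsqueeze(0)              # same preprocessing, add batch dim -> [1, 1, 28, 28]
with torch.no_grad():
    output = model(x)
    _, predicted = torch.max(output, 1)
print('Predicted digit:', predicted.item())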
Hope the above is helpful. Feel free to follow our future updates.