JetBotAI 进行数据集训练脚本

简介: JetBotAI 进行数据集训练脚本

代码如下:

import torch
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
NUM_EPOCHS = 30
BEST_MODEL_PATH = 'best_model.pth'
best_accuracy = 0.0
model = models.alexnet(pretrained=True)
dataset = datasets.ImageFolder(
    'dataset',
    transforms.Compose([
        transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
)
model.classifier[6] = torch.nn.Linear(model.classifier[6].in_features, 2)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [len(dataset) - 50, 50])
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=4
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=4
)
device = torch.device('cuda')
model = model.to(device)
for epoch in range(NUM_EPOCHS):
    for images, labels in iter(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()
    test_error_count = 0.0
    for images, labels in iter(test_loader):
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        test_error_count += float(torch.sum(torch.abs(labels - outputs.argmax(1))))
    test_accuracy = 1.0 - float(test_error_count) / float(len(test_dataset))
    print('%d: %f' % (epoch, test_accuracy))
    if test_accuracy > best_accuracy:
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        best_accuracy = test_accuracy

执行效果:

目录
打赏
0
1
1
0
4
分享
相关文章
【yolo训练数据集】标注好的垃圾分类数据集共享
【yolo训练数据集】标注好的垃圾分类数据集共享
2699 161
【yolo训练数据集】标注好的垃圾分类数据集共享
训练数据集(一):真实场景下采集的煤矸石目标检测数据集,可直接用于YOLOv5/v6/v7/v8训练
本文介绍了一个用于煤炭与矸石分类的煤矸石目标检测数据集,包含891张训练图片和404张验证图片,分为煤炭、矸石和混合物三类。数据集已标注并划分为训练和验证集,适用于YOLOv5/v6/v7/v8训练。数据集可通过提供的链接下载。
157 1
训练数据集(一):真实场景下采集的煤矸石目标检测数据集,可直接用于YOLOv5/v6/v7/v8训练
YOLOv8训练自己的数据集+常用传参说明
YOLOv8训练自己的数据集+常用传参说明
9161 0
数据集介绍
【8月更文挑战第9天】数据集介绍。
153 1
数据集
【7月更文挑战第10天】数据集
583 1
【YOLOv8训练结果评估】YOLOv8如何使用训练好的模型对验证集进行评估及评估参数详解
【YOLOv8训练结果评估】YOLOv8如何使用训练好的模型对验证集进行评估及评估参数详解
AI助理

你好,我是AI助理

可以解答问题、推荐解决方案等