JetBotAI 进行数据集训练脚本

简介: JetBotAI 进行数据集训练脚本

代码如下:

import torch
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
NUM_EPOCHS = 30
BEST_MODEL_PATH = 'best_model.pth'
best_accuracy = 0.0
model = models.alexnet(pretrained=True)
dataset = datasets.ImageFolder(
    'dataset',
    transforms.Compose([
        transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
)
model.classifier[6] = torch.nn.Linear(model.classifier[6].in_features, 2)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [len(dataset) - 50, 50])
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=4
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=4
)
device = torch.device('cuda')
model = model.to(device)
for epoch in range(NUM_EPOCHS):
    for images, labels in iter(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()
    test_error_count = 0.0
    for images, labels in iter(test_loader):
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        test_error_count += float(torch.sum(torch.abs(labels - outputs.argmax(1))))
    test_accuracy = 1.0 - float(test_error_count) / float(len(test_dataset))
    print('%d: %f' % (epoch, test_accuracy))
    if test_accuracy > best_accuracy:
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        best_accuracy = test_accuracy

执行效果:

相关文章
|
3月前
|
机器学习/深度学习 Java TensorFlow
模型推理脚本
模型推理脚本可以使用各种编程语言编写,如Python、C++、Java等。在机器学习和深度学习领域中,Python是最常用的编程语言之一,因为它有许多流行的深度学习框架,如TensorFlow、PyTorch和Keras,这些框架都提供了简单易用的API来加载模型和进行模型推理。
112 5
|
存储 机器学习/深度学习 算法
MMDetection3d对KITT数据集的训练与评估介绍
MMDetection3d对KITT数据集的训练与评估介绍
1727 0
MMDetection3d对KITT数据集的训练与评估介绍
|
10天前
|
计算机视觉
数据集介绍
【8月更文挑战第9天】数据集介绍。
20 1
|
2月前
|
算法 计算机视觉
【YOLOv8训练结果评估】YOLOv8如何使用训练好的模型对验证集进行评估及评估参数详解
【YOLOv8训练结果评估】YOLOv8如何使用训练好的模型对验证集进行评估及评估参数详解
|
XML 数据挖掘 数据格式
|
3月前
|
机器学习/深度学习 缓存 PyTorch
Yolov5如何训练自定义的数据集,以及使用GPU训练,涵盖报错解决
Yolov5如何训练自定义的数据集,以及使用GPU训练,涵盖报错解决
798 0
|
10月前
|
机器学习/深度学习 算法 数据挖掘
【数据科学】Scikit-learn[Scikit-learn、加载数据、训练集与测试集数据、创建模型、模型拟合、拟合数据与模型、评估模型性能、模型调整]
【数据科学】Scikit-learn[Scikit-learn、加载数据、训练集与测试集数据、创建模型、模型拟合、拟合数据与模型、评估模型性能、模型调整]
|
存储 大数据 Linux
基于 YOLOv8 的自定义数据集训练
基于 YOLOv8 的自定义数据集训练
|
数据采集 PyTorch 算法框架/工具
Pytorch训练一个模型的步骤总结
Pytorch训练一个模型的步骤总结
211 0
|
存储 数据可视化 数据挖掘
大五人格测试数据集的探索
大五人格测试数据集的探索
1057 0
大五人格测试数据集的探索