JetBotAI 进行数据集训练脚本

简介: JetBotAI 进行数据集训练脚本

代码如下:

import torch
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
NUM_EPOCHS = 30
BEST_MODEL_PATH = 'best_model.pth'
best_accuracy = 0.0
model = models.alexnet(pretrained=True)
dataset = datasets.ImageFolder(
    'dataset',
    transforms.Compose([
        transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
)
model.classifier[6] = torch.nn.Linear(model.classifier[6].in_features, 2)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [len(dataset) - 50, 50])
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=4
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=4
)
device = torch.device('cuda')
model = model.to(device)
for epoch in range(NUM_EPOCHS):
    for images, labels in iter(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()
    test_error_count = 0.0
    for images, labels in iter(test_loader):
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        test_error_count += float(torch.sum(torch.abs(labels - outputs.argmax(1))))
    test_accuracy = 1.0 - float(test_error_count) / float(len(test_dataset))
    print('%d: %f' % (epoch, test_accuracy))
    if test_accuracy > best_accuracy:
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        best_accuracy = test_accuracy

执行效果:

相关文章
|
10月前
|
机器学习/深度学习 JSON 算法
如何在自定义数据集上训练 YOLOv8 实例分割模型
在本文中,我们将介绍微调 YOLOv8-seg 预训练模型的过程,以提高其在特定目标类别上的准确性。Ikomia API简化了计算机视觉工作流的开发过程,允许轻松尝试不同的参数以达到最佳结果。
|
存储 机器学习/深度学习 算法
MMDetection3d对KITT数据集的训练与评估介绍
MMDetection3d对KITT数据集的训练与评估介绍
2257 0
MMDetection3d对KITT数据集的训练与评估介绍
|
7月前
|
数据采集 人工智能 小程序
如何制作数据集并基于yolov5训练成模型并部署
这篇文章介绍了如何为YOLOv5制作数据集、训练模型、进行模型部署的整个流程,包括搜集和标注图片、创建数据集文件夹结构、编写配置文件、训练和评估模型,以及将训练好的模型部署到不同平台如ROS机器人、微信小程序和移动应用等。
如何制作数据集并基于yolov5训练成模型并部署
|
10月前
|
XML 机器学习/深度学习 数据格式
YOLOv8训练自己的数据集+常用传参说明
YOLOv8训练自己的数据集+常用传参说明
9925 0
|
XML 数据挖掘 数据格式
|
9月前
|
算法 计算机视觉
【YOLOv8训练结果评估】YOLOv8如何使用训练好的模型对验证集进行评估及评估参数详解
【YOLOv8训练结果评估】YOLOv8如何使用训练好的模型对验证集进行评估及评估参数详解
|
9月前
|
计算机视觉
【YOLOv10训练教程】如何使用YOLOv10训练自己的数据集并且推理使用
【YOLOv10训练教程】如何使用YOLOv10训练自己的数据集并且推理使用
|
10月前
|
机器学习/深度学习 缓存 PyTorch
Yolov5如何训练自定义的数据集,以及使用GPU训练,涵盖报错解决
Yolov5如何训练自定义的数据集,以及使用GPU训练,涵盖报错解决
1294 0
|
存储 大数据 Linux
基于 YOLOv8 的自定义数据集训练
基于 YOLOv8 的自定义数据集训练
|
数据采集 PyTorch 算法框架/工具
Pytorch训练一个模型的步骤总结
Pytorch训练一个模型的步骤总结
263 0