Image Classification with PyTorch and Albumentations (Cats vs. Dogs)

Overview: image classification on the Cats vs. Dogs dataset with PyTorch and Albumentations.

Table of Contents

Summary

Import the required libraries

Download and extract the dataset

Set the dataset directory

Download and unpack the dataset

Split into training, validation, and test sets

Define a function to visualize images and their labels

Define a PyTorch Dataset class

Define transforms for the training and validation sets with Albumentations

Define training helpers

Define training parameters

Training and validation

Train the model

Predict image labels and visualize the predictions

Full code


Summary

This example shows how to use Albumentations for image classification. We will use the "Cats vs. Dogs" dataset; the task is to detect whether an image contains a cat or a dog.


Import the required libraries

from collections import defaultdict

import copy

import random

import os

import shutil

from urllib.request import urlretrieve

import albumentations as A

from albumentations.pytorch import ToTensorV2

import cv2

import matplotlib.pyplot as plt

from tqdm import tqdm

import torch

import torch.backends.cudnn as cudnn

import torch.nn as nn

import torch.optim

from torch.utils.data import Dataset, DataLoader

import torchvision.models as models

cudnn.benchmark = True  # let cuDNN auto-tune convolution algorithms (input sizes are fixed at 128x128)

Download and extract the dataset

class TqdmUpTo(tqdm):

   def update_to(self, b=1, bsize=1, tsize=None):

       if tsize is not None:

           self.total = tsize

       self.update(b * bsize - self.n)

def download_url(url, filepath):

   directory = os.path.dirname(os.path.abspath(filepath))

   os.makedirs(directory, exist_ok=True)

   if os.path.exists(filepath):

       print("Filepath already exists. Skipping download.")

       return

   with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=os.path.basename(filepath)) as t:

       urlretrieve(url, filename=filepath, reporthook=t.update_to, data=None)

       t.total = t.n

def extract_archive(filepath):

   extract_dir = os.path.dirname(os.path.abspath(filepath))

   shutil.unpack_archive(filepath, extract_dir)

Set the dataset directory

dataset_directory = "datasets/cats-vs-dogs"

Download and unpack the dataset

filepath = os.path.join(dataset_directory, "kagglecatsanddogs_3367a.zip")

download_url(

   url="https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_3367a.zip",

   filepath=filepath,

)

extract_archive(filepath)

Split into training, validation, and test sets

Some files in the dataset are corrupted, so we will use only those image files that OpenCV can load correctly. We will use 20000 images for training, 4936 images for validation, and 10 images for testing.


root_directory = os.path.join(dataset_directory, "PetImages")

cat_directory = os.path.join(root_directory, "Cat")

dog_directory = os.path.join(root_directory, "Dog")

cat_images_filepaths = sorted([os.path.join(cat_directory, f) for f in os.listdir(cat_directory)])

dog_images_filepaths = sorted([os.path.join(dog_directory, f) for f in os.listdir(dog_directory)])

images_filepaths = [*cat_images_filepaths, *dog_images_filepaths]

correct_images_filepaths = [i for i in images_filepaths if cv2.imread(i) is not None]

random.seed(42)

random.shuffle(correct_images_filepaths)

train_images_filepaths = correct_images_filepaths[:20000]

val_images_filepaths = correct_images_filepaths[20000:-10]

test_images_filepaths = correct_images_filepaths[-10:]

print(len(train_images_filepaths), len(val_images_filepaths), len(test_images_filepaths))

20000 4936 10

Define a function to visualize images and their labels

Let's define a function that takes a list of image file paths together with their labels and visualizes them in a grid. Correct labels are colored green, and incorrectly predicted labels are colored red.


def display_image_grid(images_filepaths, predicted_labels=(), cols=5):

   rows = len(images_filepaths) // cols

   figure, ax = plt.subplots(nrows=rows, ncols=cols, figsize=(12, 6))

   for i, image_filepath in enumerate(images_filepaths):

       image = cv2.imread(image_filepath)

       image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

       true_label = os.path.normpath(image_filepath).split(os.sep)[-2]

       predicted_label = predicted_labels[i] if predicted_labels else true_label

       color = "green" if true_label == predicted_label else "red"

       ax.ravel()[i].imshow(image)

       ax.ravel()[i].set_title(predicted_label, color=color)

       ax.ravel()[i].set_axis_off()

   plt.tight_layout()

   plt.show()

display_image_grid(test_images_filepaths)

[Figure: the ten test images displayed in a grid with their ground-truth labels]



Define a PyTorch Dataset class

Next, we define a PyTorch dataset. If you are not familiar with PyTorch datasets, please refer to this tutorial - https://pytorch.org/tutorials/beginner/data_loading_tutorial.html. Our task is binary classification: the model needs to predict whether an image contains a cat or a dog. The label encodes the probability that the image contains a cat, so the correct label for a cat image is 1.0 and for a dog image is 0.0. __init__ receives an optional transform argument - a function that applies the augmentation pipeline. In __getitem__, the Dataset class then uses that function to augment an image and returns it along with the correct label.


class CatsVsDogsDataset(Dataset):

   def __init__(self, images_filepaths, transform=None):

       self.images_filepaths = images_filepaths

       self.transform = transform

   def __len__(self):

       return len(self.images_filepaths)

   def __getitem__(self, idx):

       image_filepath = self.images_filepaths[idx]

       image = cv2.imread(image_filepath)

       image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

       if os.path.normpath(image_filepath).split(os.sep)[-2] == "Cat":

           label = 1.0

       else:

           label = 0.0

       if self.transform is not None:

           image = self.transform(image=image)["image"]

       return image, label

Define transforms for the training and validation sets with Albumentations

We use Albumentations to define the augmentation pipelines for the training and validation datasets. In both pipelines we first resize the input image so that its smallest side is 160px and then take a 128px x 128px crop. For the training dataset we additionally apply more augmentations to that crop. Next we normalize the image: all pixel values are first divided by 255, so each pixel value lies in the range [0.0, 1.0]; we then subtract the mean pixel values and divide by the standard deviation. The mean and standard deviation used in the pipeline come from the ImageNet dataset, but they still transfer well to the Cats vs. Dogs dataset. After that we apply ToTensorV2 to convert the NumPy array into a PyTorch tensor, which will serve as the input to the neural network. Note that in the validation pipeline we use A.CenterCrop instead of A.RandomCrop, because we want the validation results to be deterministic (so that they do not depend on the random location of a crop).
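
To make the normalization arithmetic concrete, here is a minimal standalone sketch (plain NumPy, not part of the original example) of what A.Normalize computes for a single RGB pixel; per channel it evaluates output = (pixel / max_pixel_value - mean) / std, with max_pixel_value defaulting to 255:

import numpy as np

# ImageNet statistics, the same values passed to A.Normalize below
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

pixel = np.array([128, 128, 128], dtype=np.float32)  # one mid-gray RGB pixel
normalized = (pixel / 255.0 - mean) / std
print(normalized)  # approximately [0.074 0.205 0.426]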


train_transform = A.Compose(

   [

       A.SmallestMaxSize(max_size=160),

       A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),

       A.RandomCrop(height=128, width=128),

       A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),

       A.RandomBrightnessContrast(p=0.5),

       A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),

       ToTensorV2(),

   ]

)

train_dataset = CatsVsDogsDataset(images_filepaths=train_images_filepaths, transform=train_transform)

val_transform = A.Compose(

   [

       A.SmallestMaxSize(max_size=160),

       A.CenterCrop(height=128, width=128),

       A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),

       ToTensorV2(),

   ]

)

val_dataset = CatsVsDogsDataset(images_filepaths=val_images_filepaths, transform=val_transform)
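
As a quick sanity check (a hypothetical snippet, not in the original example), you can fetch one sample from the training dataset and confirm that the pipeline yields a 3 x 128 x 128 float tensor together with a scalar label:

image, label = train_dataset[0]
print(image.shape, image.dtype, label)
# Expected: torch.Size([3, 128, 128]) torch.float32, with label 1.0 (cat) or 0.0 (dog)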

Let's also define a function that takes a dataset and visualizes different augmentations applied to the same image.


def visualize_augmentations(dataset, idx=0, samples=10, cols=5):

   dataset = copy.deepcopy(dataset)

   dataset.transform = A.Compose([t for t in dataset.transform if not isinstance(t, (A.Normalize, ToTensorV2))])

   rows = samples // cols

   figure, ax = plt.subplots(nrows=rows, ncols=cols, figsize=(12, 6))

   for i in range(samples):

       image, _ = dataset[idx]

       ax.ravel()[i].imshow(image)

       ax.ravel()[i].set_axis_off()

   plt.tight_layout()

   plt.show()    

random.seed(42)

visualize_augmentations(train_dataset)



Define training helpers

We define helper utilities for training. calculate_accuracy takes the model predictions and the ground-truth labels and returns the accuracy of those predictions. MetricMonitor helps track metrics such as accuracy and loss during training and validation.


def calculate_accuracy(output, target):

   output = torch.sigmoid(output) >= 0.5

   target = target == 1.0

   return torch.true_divide((target == output).sum(dim=0), output.size(0)).item()

class MetricMonitor:

   def __init__(self, float_precision=3):

       self.float_precision = float_precision

       self.reset()

   def reset(self):

       self.metrics = defaultdict(lambda: {"val": 0, "count": 0, "avg": 0})

   def update(self, metric_name, val):

       metric = self.metrics[metric_name]

       metric["val"] += val

       metric["count"] += 1

       metric["avg"] = metric["val"] / metric["count"]

   def __str__(self):

       return " | ".join(

           [

               "{metric_name}: {avg:.{float_precision}f}".format(

                   metric_name=metric_name, avg=metric["avg"], float_precision=self.float_precision

               )

               for (metric_name, metric) in self.metrics.items()

           ]

       )
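
Here is a short illustrative use of both helpers with dummy values (assuming torch is imported as above; the numbers are made up for demonstration only):

# MetricMonitor keeps a running average per metric name
metric_monitor = MetricMonitor()
metric_monitor.update("Loss", 0.7)
metric_monitor.update("Loss", 0.5)
metric_monitor.update("Accuracy", 0.8)
print(metric_monitor)  # Loss: 0.600 | Accuracy: 0.800

# calculate_accuracy thresholds the logits at sigmoid(x) >= 0.5
logits = torch.tensor([[2.0], [-1.0], [0.5], [-0.3]])   # thresholded -> [1, 0, 1, 0]
targets = torch.tensor([[1.0], [0.0], [0.0], [0.0]])
print(calculate_accuracy(logits, targets))  # 0.75 (3 of 4 predictions correct)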


Define training parameters

Here we define a few training parameters such as the model architecture, learning rate, batch_size, number of epochs, etc.


params = {

   "model": "resnet50",

   "device": "cuda",

   "lr": 0.001,

   "batch_size": 64,

   "num_workers": 4,

   "epochs": 10,

}
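
Note that "device": "cuda" assumes a CUDA-capable GPU is available. If the script may also run on a CPU-only machine, a common tweak (not in the original example) is to pick the device dynamically:

# Optional: fall back to CPU when CUDA is unavailable
params["device"] = "cuda" if torch.cuda.is_available() else "cpu"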

Training and validation

model = getattr(models, params["model"])(pretrained=False, num_classes=1,)

model = model.to(params["device"])

criterion = nn.BCEWithLogitsLoss().to(params["device"])

optimizer = torch.optim.Adam(model.parameters(), lr=params["lr"])

train_loader = DataLoader(

   train_dataset, batch_size=params["batch_size"], shuffle=True, num_workers=params["num_workers"], pin_memory=True,

)

val_loader = DataLoader(

   val_dataset, batch_size=params["batch_size"], shuffle=False, num_workers=params["num_workers"], pin_memory=True,

)

def train(train_loader, model, criterion, optimizer, epoch, params):

   metric_monitor = MetricMonitor()

   model.train()

   stream = tqdm(train_loader)

   for i, (images, target) in enumerate(stream, start=1):

       images = images.to(params["device"], non_blocking=True)

       target = target.to(params["device"], non_blocking=True).float().view(-1, 1)

       output = model(images)

       loss = criterion(output, target)

       accuracy = calculate_accuracy(output, target)

       metric_monitor.update("Loss", loss.item())

       metric_monitor.update("Accuracy", accuracy)

       optimizer.zero_grad()

       loss.backward()

       optimizer.step()

       stream.set_description(

           "Epoch: {epoch}. Train.      {metric_monitor}".format(epoch=epoch, metric_monitor=metric_monitor)

       )

def validate(val_loader, model, criterion, epoch, params):

   metric_monitor = MetricMonitor()

   model.eval()

   stream = tqdm(val_loader)

   with torch.no_grad():

       for i, (images, target) in enumerate(stream, start=1):

           images = images.to(params["device"], non_blocking=True)

           target = target.to(params["device"], non_blocking=True).float().view(-1, 1)

           output = model(images)

           loss = criterion(output, target)

           accuracy = calculate_accuracy(output, target)

           metric_monitor.update("Loss", loss.item())

           metric_monitor.update("Accuracy", accuracy)

           stream.set_description(

               "Epoch: {epoch}. Validation. {metric_monitor}".format(epoch=epoch, metric_monitor=metric_monitor)

           )


Train the model


for epoch in range(1, params["epochs"] + 1):

   train(train_loader, model, criterion, optimizer, epoch, params)

   validate(val_loader, model, criterion, epoch, params)

Epoch: 1. Train.      Loss: 0.700 | Accuracy: 0.598: 100%|██████████| 313/313 [00:38<00:00,  8.04it/s]

Epoch: 1. Validation. Loss: 0.684 | Accuracy: 0.663: 100%|██████████| 78/78 [00:03<00:00, 23.46it/s]

Epoch: 2. Train.      Loss: 0.611 | Accuracy: 0.675: 100%|██████████| 313/313 [00:37<00:00,  8.24it/s]

Epoch: 2. Validation. Loss: 0.581 | Accuracy: 0.689: 100%|██████████| 78/78 [00:03<00:00, 23.25it/s]

Epoch: 3. Train.      Loss: 0.513 | Accuracy: 0.752: 100%|██████████| 313/313 [00:38<00:00,  8.22it/s]

Epoch: 3. Validation. Loss: 0.408 | Accuracy: 0.818: 100%|██████████| 78/78 [00:03<00:00, 23.61it/s]

Epoch: 4. Train.      Loss: 0.440 | Accuracy: 0.796: 100%|██████████| 313/313 [00:37<00:00,  8.24it/s]

Epoch: 4. Validation. Loss: 0.374 | Accuracy: 0.829: 100%|██████████| 78/78 [00:03<00:00, 22.89it/s]

Epoch: 5. Train.      Loss: 0.391 | Accuracy: 0.821: 100%|██████████| 313/313 [00:37<00:00,  8.25it/s]

Epoch: 5. Validation. Loss: 0.345 | Accuracy: 0.853: 100%|██████████| 78/78 [00:03<00:00, 23.03it/s]

Epoch: 6. Train.      Loss: 0.343 | Accuracy: 0.845: 100%|██████████| 313/313 [00:38<00:00,  8.22it/s]

Epoch: 6. Validation. Loss: 0.304 | Accuracy: 0.861: 100%|██████████| 78/78 [00:03<00:00, 23.88it/s]

Epoch: 7. Train.      Loss: 0.312 | Accuracy: 0.858: 100%|██████████| 313/313 [00:38<00:00,  8.23it/s]

Epoch: 7. Validation. Loss: 0.259 | Accuracy: 0.886: 100%|██████████| 78/78 [00:03<00:00, 23.29it/s]

Epoch: 8. Train.      Loss: 0.284 | Accuracy: 0.875: 100%|██████████| 313/313 [00:38<00:00,  8.21it/s]

Epoch: 8. Validation. Loss: 0.304 | Accuracy: 0.882: 100%|██████████| 78/78 [00:03<00:00, 23.81it/s]

Epoch: 9. Train.      Loss: 0.265 | Accuracy: 0.884: 100%|██████████| 313/313 [00:38<00:00,  8.18it/s]

Epoch: 9. Validation. Loss: 0.255 | Accuracy: 0.888: 100%|██████████| 78/78 [00:03<00:00, 23.78it/s]

Epoch: 10. Train.      Loss: 0.248 | Accuracy: 0.890: 100%|██████████| 313/313 [00:38<00:00,  8.21it/s]

Epoch: 10. Validation. Loss: 0.222 | Accuracy: 0.909: 100%|██████████| 78/78 [00:03<00:00, 23.90it/s]

Predict image labels and visualize the predictions

Now that we have a trained model, let's try to predict labels for some images and see whether those predictions are correct. First we create CatsVsDogsInferenceDataset, a PyTorch dataset. Its code is similar to the training and validation datasets, but the inference dataset returns only an image, not the associated label (because in the real world we usually do not have access to the true labels and want to infer them with our trained model).


class CatsVsDogsInferenceDataset(Dataset):

   def __init__(self, images_filepaths, transform=None):

       self.images_filepaths = images_filepaths

       self.transform = transform

   def __len__(self):

       return len(self.images_filepaths)

   def __getitem__(self, idx):

       image_filepath = self.images_filepaths[idx]

       image = cv2.imread(image_filepath)

       image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

       if self.transform is not None:

           image = self.transform(image=image)["image"]

       return image

test_transform = A.Compose(

   [

       A.SmallestMaxSize(max_size=160),

       A.CenterCrop(height=128, width=128),

       A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),

       ToTensorV2(),

   ]

)

test_dataset = CatsVsDogsInferenceDataset(images_filepaths=test_images_filepaths, transform=test_transform)

test_loader = DataLoader(

   test_dataset, batch_size=params["batch_size"], shuffle=False, num_workers=params["num_workers"], pin_memory=True,

)

model = model.eval()

predicted_labels = []

with torch.no_grad():

   for images in test_loader:

       images = images.to(params["device"], non_blocking=True)

       output = model(images)

       predictions = (torch.sigmoid(output) >= 0.5)[:, 0].cpu().numpy()

       predicted_labels += ["Cat" if is_cat else "Dog" for is_cat in predictions]

display_image_grid(test_images_filepaths, predicted_labels)

[Figure: the ten test images annotated with predicted labels; correct predictions are shown in green, incorrect ones in red]



Full code

The code above is correct, but its ordering is somewhat scattered, and running it directly as a script for training raises errors. Below it is reorganized with minor modifications; most notably, the top-level steps are wrapped in an if __name__ == '__main__': guard, which a multi-worker DataLoader requires when the file is executed as a script on platforms that spawn worker processes by re-importing the module (e.g., Windows).


from collections import defaultdict

import copy

import random

import os

import shutil

from urllib.request import urlretrieve

import albumentations as A

from albumentations.pytorch import ToTensorV2

import cv2

import matplotlib.pyplot as plt

from tqdm import tqdm

import torch

import torch.backends.cudnn as cudnn

import torch.nn as nn

import torch.optim

from torch.utils.data import Dataset, DataLoader

import torchvision.models as models

cudnn.benchmark = True

class TqdmUpTo(tqdm):

   def update_to(self, b=1, bsize=1, tsize=None):

       if tsize is not None:

           self.total = tsize

       self.update(b * bsize - self.n)

def download_url(url, filepath):

   directory = os.path.dirname(os.path.abspath(filepath))

   os.makedirs(directory, exist_ok=True)

   if os.path.exists(filepath):

       print("Filepath already exists. Skipping download.")

       return

   with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=os.path.basename(filepath)) as t:

       urlretrieve(url, filename=filepath, reporthook=t.update_to, data=None)

       t.total = t.n

def extract_archive(filepath):

   extract_dir = os.path.dirname(os.path.abspath(filepath))

   shutil.unpack_archive(filepath, extract_dir)

def display_image_grid(images_filepaths, predicted_labels=(), cols=5):

   rows = len(images_filepaths) // cols

   figure, ax = plt.subplots(nrows=rows, ncols=cols, figsize=(12, 6))

   for i, image_filepath in enumerate(images_filepaths):

       image = cv2.imread(image_filepath)

       image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

       true_label = os.path.normpath(image_filepath).split(os.sep)[-2]

       predicted_label = predicted_labels[i] if predicted_labels else true_label

       color = "green" if true_label == predicted_label else "red"

       ax.ravel()[i].imshow(image)

       ax.ravel()[i].set_title(predicted_label, color=color)

       ax.ravel()[i].set_axis_off()

   plt.tight_layout()

   plt.show()

class CatsVsDogsDataset(Dataset):

   def __init__(self, images_filepaths, transform=None):

       self.images_filepaths = images_filepaths

       self.transform = transform

   def __len__(self):

       return len(self.images_filepaths)

   def __getitem__(self, idx):

       image_filepath = self.images_filepaths[idx]

       image = cv2.imread(image_filepath)

       image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

       if os.path.normpath(image_filepath).split(os.sep)[-2] == "Cat":

           label = 1.0

       else:

           label = 0.0

       if self.transform is not None:

           image = self.transform(image=image)["image"]

       return image, label

def visualize_augmentations(dataset, idx=0, samples=10, cols=5):

   dataset = copy.deepcopy(dataset)

   dataset.transform = A.Compose([t for t in dataset.transform if not isinstance(t, (A.Normalize, ToTensorV2))])

   rows = samples // cols

   figure, ax = plt.subplots(nrows=rows, ncols=cols, figsize=(12, 6))

   for i in range(samples):

       image, _ = dataset[idx]

       ax.ravel()[i].imshow(image)

       ax.ravel()[i].set_axis_off()

   plt.tight_layout()

   plt.show()

def calculate_accuracy(output, target):

   output = torch.sigmoid(output) >= 0.5

   target = target == 1.0

   return torch.true_divide((target == output).sum(dim=0), output.size(0)).item()

class MetricMonitor:

   def __init__(self, float_precision=3):

       self.float_precision = float_precision

       self.reset()

   def reset(self):

       self.metrics = defaultdict(lambda: {"val": 0, "count": 0, "avg": 0})

   def update(self, metric_name, val):

       metric = self.metrics[metric_name]

       metric["val"] += val

       metric["count"] += 1

       metric["avg"] = metric["val"] / metric["count"]

   def __str__(self):

       return " | ".join(

           [

               "{metric_name}: {avg:.{float_precision}f}".format(

                   metric_name=metric_name, avg=metric["avg"], float_precision=self.float_precision

               )

               for (metric_name, metric) in self.metrics.items()

           ]

       )

def train(train_loader, model, criterion, optimizer, epoch, params):

   metric_monitor = MetricMonitor()

   model.train()

   stream = tqdm(train_loader)

   for i, (images, target) in enumerate(stream, start=1):

       images = images.to(params["device"], non_blocking=True)

       target = target.to(params["device"], non_blocking=True).float().view(-1, 1)

       output = model(images)

       loss = criterion(output, target)

       accuracy = calculate_accuracy(output, target)

       metric_monitor.update("Loss", loss.item())

       metric_monitor.update("Accuracy", accuracy)

       optimizer.zero_grad()

       loss.backward()

       optimizer.step()

       stream.set_description(

           "Epoch: {epoch}. Train.      {metric_monitor}".format(epoch=epoch, metric_monitor=metric_monitor)

       )

def validate(val_loader, model, criterion, epoch, params):

   metric_monitor = MetricMonitor()

   model.eval()

   stream = tqdm(val_loader)

   with torch.no_grad():

       for i, (images, target) in enumerate(stream, start=1):

           images = images.to(params["device"], non_blocking=True)

           target = target.to(params["device"], non_blocking=True).float().view(-1, 1)

           output = model(images)

           loss = criterion(output, target)

           accuracy = calculate_accuracy(output, target)

           metric_monitor.update("Loss", loss.item())

           metric_monitor.update("Accuracy", accuracy)

           stream.set_description(

               "Epoch: {epoch}. Validation. {metric_monitor}".format(epoch=epoch, metric_monitor=metric_monitor)

           )

class CatsVsDogsInferenceDataset(Dataset):

   def __init__(self, images_filepaths, transform=None):

       self.images_filepaths = images_filepaths

       self.transform = transform

   def __len__(self):

       return len(self.images_filepaths)

   def __getitem__(self, idx):

       image_filepath = self.images_filepaths[idx]

       image = cv2.imread(image_filepath)

       image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

       if self.transform is not None:

           image = self.transform(image=image)["image"]

       return image

if __name__ == '__main__':

   dataset_directory = "datasets/cats-vs-dogs"

   filepath = os.path.join(dataset_directory, "kagglecatsanddogs_3367a.zip")

   download_url(

       url="https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_3367a.zip",

       filepath=filepath,

   )

   extract_archive(filepath)

   root_directory = os.path.join(dataset_directory, "PetImages")

   cat_directory = os.path.join(root_directory, "Cat")

   dog_directory = os.path.join(root_directory, "Dog")

   cat_images_filepaths = sorted([os.path.join(cat_directory, f) for f in os.listdir(cat_directory)])

   dog_images_filepaths = sorted([os.path.join(dog_directory, f) for f in os.listdir(dog_directory)])

   images_filepaths = [*cat_images_filepaths, *dog_images_filepaths]

   correct_images_filepaths = [i for i in images_filepaths if cv2.imread(i) is not None]

   random.seed(42)

   random.shuffle(correct_images_filepaths)

   train_images_filepaths = correct_images_filepaths[:20000]

   val_images_filepaths = correct_images_filepaths[20000:-10]

   test_images_filepaths = correct_images_filepaths[-10:]

   print(len(train_images_filepaths), len(val_images_filepaths), len(test_images_filepaths))

   display_image_grid(test_images_filepaths)

   train_transform = A.Compose(

       [

           A.SmallestMaxSize(max_size=160),

           A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),

           A.RandomCrop(height=128, width=128),

           A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),

           A.RandomBrightnessContrast(p=0.5),

           A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),

           ToTensorV2(),

       ]

   )

   train_dataset = CatsVsDogsDataset(images_filepaths=train_images_filepaths, transform=train_transform)

   random.seed(42)

   visualize_augmentations(train_dataset)

   val_transform = A.Compose(

       [

           A.SmallestMaxSize(max_size=160),

           A.CenterCrop(height=128, width=128),

           A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),

           ToTensorV2(),

       ]

   )

   val_dataset = CatsVsDogsDataset(images_filepaths=val_images_filepaths, transform=val_transform)

   params = {

       "model": "resnet50",

       "device": "cuda",

       "lr": 0.001,

       "batch_size": 64,

       "num_workers": 4,

       "epochs": 10,

   }

   model = getattr(models, params["model"])(pretrained=False, num_classes=1, )

   model = model.to(params["device"])

   criterion = nn.BCEWithLogitsLoss().to(params["device"])

   optimizer = torch.optim.Adam(model.parameters(), lr=params["lr"])

   train_loader = DataLoader(

       train_dataset, batch_size=params["batch_size"], shuffle=True, num_workers=params["num_workers"],

       pin_memory=True,

   )

   val_loader = DataLoader(

       val_dataset, batch_size=params["batch_size"], shuffle=False, num_workers=params["num_workers"], pin_memory=True,

   )

   for epoch in range(1, params["epochs"] + 1):

       train(train_loader, model, criterion, optimizer, epoch, params)

       validate(val_loader, model, criterion, epoch, params)

   test_transform = A.Compose(

       [

           A.SmallestMaxSize(max_size=160),

           A.CenterCrop(height=128, width=128),

           A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),

           ToTensorV2(),

       ]

   )

   test_dataset = CatsVsDogsInferenceDataset(images_filepaths=test_images_filepaths, transform=test_transform)

   test_loader = DataLoader(

       test_dataset, batch_size=params["batch_size"], shuffle=False, num_workers=params["num_workers"],

       pin_memory=True,

   )

   model = model.eval()

   predicted_labels = []

   with torch.no_grad():

       for images in test_loader:

           images = images.to(params["device"], non_blocking=True)

           output = model(images)

           predictions = (torch.sigmoid(output) >= 0.5)[:, 0].cpu().numpy()

           predicted_labels += ["Cat" if is_cat else "Dog" for is_cat in predictions]

   display_image_grid(test_images_filepaths, predicted_labels)

