Image segmentation with the segmentation_models.pytorch and Albumentations libraries

Summary: image segmentation implemented with segmentation_models.pytorch and Albumentations

Dataset download links


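The dataset comes from the Tianchi competition "零基础入门语义分割-地表建筑物识别" (Getting Started with Semantic Segmentation: Surface Building Recognition) on Alibaba Cloud Tianchi (aliyun.com).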


| File                    | Size     | Download URL |
| ----------------------- | -------- | ------------ |
| test_a.zip              | 314.49MB | http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531872/%E5%9C%B0%E8%A1%A8%E5%BB%BA%E7%AD%91%E7%89%A9%E8%AF%86%E5%88%AB/test_a.zip |
| test_a_samplesubmit.csv | 46.39KB  | http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531872/%E5%9C%B0%E8%A1%A8%E5%BB%BA%E7%AD%91%E7%89%A9%E8%AF%86%E5%88%AB/test_a_samplesubmit.csv |
| train.zip               | 3.68GB   | http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531872/%E5%9C%B0%E8%A1%A8%E5%BB%BA%E7%AD%91%E7%89%A9%E8%AF%86%E5%88%AB/train.zip |
| train_mask.csv.zip      | 97.52MB  | http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531872/%E5%9C%B0%E8%A1%A8%E5%BB%BA%E7%AD%91%E7%89%A9%E8%AF%86%E5%88%AB/train_mask.csv.zip |

The scripts below expect the archives to be extracted under ./data, giving ./data/train/, ./data/test_a/, ./data/train_mask.csv and ./data/test_a_samplesubmit.csv.
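The masks in train_mask.csv are stored as run-length encoded (RLE) strings, which the rle_encode/rle_decode functions in both scripts convert to and from 512×512 arrays. As a minimal sketch of that format (rle_encode_list and rle_decode_list are hypothetical, pure-Python stand-ins for illustration, not part of the original code): pixels are read column by column, like NumPy's order='F', and each foreground run is written as a "start length" pair with a 1-based start. An empty string means an empty mask, which is why the training code calls fillna('').

```python
# Pure-Python illustration of the RLE format used by rle_encode/rle_decode
# (hypothetical helpers; the original scripts use NumPy instead).

def rle_encode_list(mask):
    """mask: list of rows of 0/1 values. Returns an RLE string."""
    height, width = len(mask), len(mask[0])
    # Flatten column-major: walk down each column before moving right.
    pixels = [mask[r][c] for c in range(width) for r in range(height)]
    runs = []
    prev = 0  # acts like the leading 0 padding in the NumPy version
    for i, p in enumerate(pixels + [0]):  # trailing 0 closes a final run
        if p != prev:
            runs.append(i + 1)  # 1-based position where the value changes
            prev = p
    # Turn each (start, end) pair into (start, length).
    for j in range(1, len(runs), 2):
        runs[j] -= runs[j - 1]
    return ' '.join(str(x) for x in runs)


def rle_decode_list(mask_rle, shape):
    """Inverse of rle_encode_list; shape is (height, width)."""
    height, width = shape
    flat = [0] * (height * width)
    values = [int(x) for x in mask_rle.split()]
    for start, length in zip(values[0::2], values[1::2]):
        for i in range(start - 1, start - 1 + length):  # back to 0-based
            flat[i] = 1
    # Undo the column-major flattening: flat index = c * height + r.
    return [[flat[c * height + r] for c in range(width)] for r in range(height)]


mask = [[0, 1, 0],
        [1, 1, 0]]
rle = rle_encode_list(mask)
print(rle)  # 2 3
assert rle_decode_list(rle, (2, 3)) == mask
```

Reading column-major is the easy detail to get wrong: the same mask flattened row by row would give a different string, so a decoder that uses the default C order produces transposed masks.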


Full code


Training code

#!/usr/bin/env python
# coding: utf-8
import numpy as np
import pandas as pd
import pathlib, sys, os, random, time
import cv2, gc
# from tqdm import tqdm_notebook
from tqdm import tqdm
import matplotlib.pyplot as plt
# get_ipython().run_line_magic('matplotlib', 'inline')
import warnings
warnings.filterwarnings('ignore')
from sklearn.model_selection import KFold
import albumentations as A
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
import torch.utils.data as D
from torchvision import transforms as T
EPOCHES = 120
BATCH_SIZE = 4
IMAGE_SIZE = 512
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
import logging
logging.basicConfig(filename='log_unet_sh_fold_4_s.log',
                    format='%(asctime)s - %(name)s - %(levelname)s -%(module)s:  %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S ',
                    level=logging.INFO)
def set_seeds(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
set_seeds()
def rle_encode(im):
    '''
    im: numpy array, 1 - mask, 0 - background
    Returns run length as string formatted
    '''
    pixels = im.flatten(order='F')
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)
def rle_decode(mask_rle, shape=(512, 512)):
    '''
    mask_rle: run-length as string formatted (start length)
    shape: (height,width) of array to return
    Returns numpy array, 1 - mask, 0 - background
    '''
    s = mask_rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return img.reshape(shape, order='F')
train_trfm = A.Compose([
    A.Resize(IMAGE_SIZE, IMAGE_SIZE),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(),
    # RandomContrast/RandomBrightness are deprecated in Albumentations (and
    # dropped in newer releases); RandomBrightnessContrast covers both
    A.OneOf([
        A.RandomBrightnessContrast(),
        A.RandomGamma(),
        A.ColorJitter(brightness=0.07, contrast=0.07,
                      saturation=0.1, hue=0.1, p=0.3),
    ], p=0.3),
])
# Validation only resizes: random flips/rotations would make the
# validation loss vary from epoch to epoch for reasons unrelated to the model
val_trfm = A.Compose([
    A.Resize(IMAGE_SIZE, IMAGE_SIZE),
])
class TianChiDataset(D.Dataset):
    def __init__(self, paths, rles, transform, test_mode=False):
        self.paths = paths
        self.rles = rles
        self.transform = transform
        self.test_mode = test_mode
        self.len = len(paths)
        self.as_tensor = T.Compose([
            T.ToPILImage(),
            T.Resize(IMAGE_SIZE),
            T.ToTensor(),
            T.Normalize([0.625, 0.448, 0.688],
                        [0.131, 0.177, 0.101]),
        ])
    # get data operation
    def __getitem__(self, index):
        img = cv2.imread(self.paths[index])
        if not self.test_mode:
            mask = rle_decode(self.rles[index])
            augments = self.transform(image=img, mask=mask)
            return self.as_tensor(augments['image']), augments['mask'][None]
        else:
            return self.as_tensor(img), ''
    def __len__(self):
        """
        Total number of samples in the dataset
        """
        return self.len
train_mask = pd.read_csv('./data/train_mask.csv', sep='\t', names=['name', 'mask'])
train_mask['name'] = train_mask['name'].apply(lambda x: './data/train/' + x)
img = cv2.imread(train_mask['name'].iloc[0])
mask = rle_decode(train_mask['mask'].iloc[0])
dataset = TianChiDataset(
    train_mask['name'].values,
    train_mask['mask'].fillna('').values,
    train_trfm, False
)
skf = KFold(n_splits=5)
idx = np.array(range(len(dataset)))
@torch.no_grad()
def validation(model, loader, loss_fn):
    losses = []
    model.eval()
    for image, target in loader:
        image, target = image.to(DEVICE), target.float().to(DEVICE)
        output = model(image)
        loss = loss_fn(output, target)
        losses.append(loss.item())
    return np.array(losses).mean()
class SoftDiceLoss(nn.Module):
    def __init__(self, smooth=1., dims=(-2, -1)):
        super(SoftDiceLoss, self).__init__()
        self.smooth = smooth
        self.dims = dims
    def forward(self, x, y):
        tp = (x * y).sum(self.dims)
        fp = (x * (1 - y)).sum(self.dims)
        fn = ((1 - x) * y).sum(self.dims)
        dc = (2 * tp + self.smooth) / (2 * tp + fp + fn + self.smooth)
        dc = dc.mean()
        return 1 - dc
bce_fn = nn.BCEWithLogitsLoss()  # nn.NLLLoss()
dice_fn = SoftDiceLoss()
def loss_fn(y_pred, y_true, ratio=0.8, hard=False):
    bce = bce_fn(y_pred, y_true)
    if hard:
        # threshold first, then cast to float (a bool tensor breaks `1 - x`)
        dice = dice_fn((y_pred.sigmoid() > 0.5).float(), y_true)
    else:
        dice = dice_fn(y_pred.sigmoid(), y_true)
    return ratio * bce + (1 - ratio) * dice
header = r'''
        Train | Valid
Epoch |  Loss |  Loss | Time, m
'''
#          Epoch         metrics            time
raw_line = '{:6d}' + '\u2502{:7.4f}' * 2 + '\u2502{:6.2f}'
print(header)
for fold_idx, (train_idx, valid_idx) in enumerate(skf.split(idx, idx)):
    # train only fold 4; skip the other folds
    if fold_idx != 4:
        continue
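    train_ds = D.Subset(dataset, train_idx)
    # build the validation subset from a dataset that uses val_trfm, so
    # validation images are not randomly augmented like training images
    valid_ds = D.Subset(TianChiDataset(
        train_mask['name'].values,
        train_mask['mask'].fillna('').values,
        val_trfm, False
    ), valid_idx)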
    # define training and validation data loaders
    loader = D.DataLoader(
        train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    vloader = D.DataLoader(
        valid_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    fold_model_path = 'fold{}_unet_model_new4_s.pth'.format(fold_idx)
    model = smp.Unet(
        encoder_name="efficientnet-b4",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
        encoder_weights='imagenet',  # use `imagenet` pretrained weights for encoder initialization
        in_channels=3,  # model input channels (1 for grayscale images, 3 for RGB, etc.)
        classes=1,  # model output channels (number of classes in your dataset)
    )
    # resume from an earlier checkpoint only if one exists; on a first run
    # the encoder starts from the ImageNet weights
    if os.path.exists(fold_model_path):
        model.load_state_dict(torch.load(fold_model_path, map_location='cpu'))
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-3)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=5, T_mult=1, eta_min=1e-6, last_epoch=-1)
    model.to(DEVICE)
    best_loss = 10
    for epoch in range(1, EPOCHES + 1):
        losses = []
        start_time = time.time()
        model.train()
        for image, target in tqdm(loader):
            image, target = image.to(DEVICE), target.float().to(DEVICE)
            optimizer.zero_grad()
            output = model(image)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
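        vloss = validation(model, vloader, loss_fn)
        # CosineAnnealingWarmRestarts is stepped once per epoch; its optional
        # argument is an epoch number, not a metric (unlike ReduceLROnPlateau)
        scheduler.step()
        logging.info(raw_line.format(epoch, np.array(losses).mean(), vloss,
                                     (time.time() - start_time) / 60))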
        if vloss < best_loss:
            best_loss = vloss
            torch.save(model.state_dict(), 'fold{}_unet_model_new4_s.pth'.format(fold_idx))
            print("best loss is {}".format(best_loss))
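The training loop pairs AdamW (lr=1e-4) with CosineAnnealingWarmRestarts(T_0=5, T_mult=1, eta_min=1e-6). CosineAnnealingWarmRestarts.step() takes an optional epoch value rather than a validation loss, so it is meant to be called once per epoch. To see what schedule that produces, here is a small hand-written re-derivation of the closed-form SGDR formula the scheduler follows, η = η_min + (η_max − η_min)(1 + cos(π·T_cur/T_i))/2; sgdr_lr is a hypothetical helper for illustration, not PyTorch code, and it assumes T_mult=1 and one step per epoch.

```python
import math


def sgdr_lr(epoch_index, base_lr=1e-4, eta_min=1e-6, t_0=5):
    """Learning rate for a 0-based epoch index, assuming T_mult=1
    (every cycle lasts t_0 epochs) and one scheduler step per epoch."""
    t_cur = epoch_index % t_0  # position inside the current cycle
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t_cur / t_0)) / 2


# Learning rate used in each of the first 7 (1-based) training epochs
for e in range(7):
    print(e + 1, f"{sgdr_lr(e):.2e}")
```

With these settings the rate decays from 1e-4 toward (but never quite reaching) 1e-6 over five epochs, then jumps back to 1e-4 on epochs 6, 11, 16 and so on, which gives 24 restarts over EPOCHES = 120.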

Test (inference) code

#!/usr/bin/env python
# coding: utf-8
import numpy as np
import pandas as pd
import pathlib, sys, os, random, time
import cv2, gc
from tqdm import tqdm_notebook
from tqdm import tqdm
import matplotlib.pyplot as plt
# get_ipython().run_line_magic('matplotlib', 'inline')
import warnings
warnings.filterwarnings('ignore')
from sklearn.model_selection import KFold
import albumentations as A
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
import torch.utils.data as D
from torchvision import transforms as T
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
IMAGE_SIZE = 512
trfm = T.Compose([
    T.ToPILImage(),
    T.Resize(IMAGE_SIZE),
    T.ToTensor(),
    T.Normalize([0.625, 0.448, 0.688],
                [0.131, 0.177, 0.101]),
])
subm = []
model = smp.Unet(
    encoder_name="efficientnet-b4",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights='imagenet',  # use `imagenet` pretrained weights for encoder initialization
    in_channels=3,  # model input channels (1 for grayscale images, 3 for RGB, etc.)
    classes=1,  # model output channels (number of classes in your dataset)
)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine
model.load_state_dict(torch.load("./fold4_unet_model_new4_s.pth", map_location=DEVICE))
model.eval()
model = model.to(DEVICE)
test_mask = pd.read_csv('./data/test_a_samplesubmit.csv', sep='\t', names=['name', 'mask'])
test_mask['name'] = test_mask['name'].apply(lambda x: './data/test_a/' + x)
def rle_encode(im):
    '''
    im: numpy array, 1 - mask, 0 - background
    Returns run length as string formatted
    '''
    pixels = im.flatten(order='F')
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)
def rle_decode(mask_rle, shape=(512, 512)):
    '''
    mask_rle: run-length as string formatted (start length)
    shape: (height,width) of array to return
    Returns numpy array, 1 - mask, 0 - background
    '''
    s = mask_rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return img.reshape(shape, order='F')
# plain tqdm works both in scripts and notebooks
for idx, name in enumerate(tqdm(test_mask['name'].iloc[:])):
    image = cv2.imread(name)
    image = trfm(image)
    with torch.no_grad():
        image = image.to(DEVICE)[None]  # add a batch dimension
        # run the model once; output has shape (1, 1, H, W) and holds logits
        score = model(image)[0][0]
        score_sigmoid = score.sigmoid().cpu().numpy()
        score_sigmoid = (score_sigmoid > 0.5).astype(np.uint8)
        score_sigmoid = cv2.resize(score_sigmoid, (512, 512))
    subm.append([name.split('/')[-1], rle_encode(score_sigmoid)])
subm = pd.DataFrame(subm)
subm.to_csv('./tmp.csv', index=None, header=None, sep='\t')
plt.figure(figsize=(16,8))
plt.subplot(121)
plt.imshow(rle_decode(subm[1].fillna('').iloc[0]), cmap='gray')
plt.subplot(122)
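# OpenCV loads BGR; convert to RGB so matplotlib shows the true colors
plt.imshow(cv2.cvtColor(cv2.imread('data/test_a/' + subm[0].iloc[0]), cv2.COLOR_BGR2RGB))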
plt.show()
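The test set has no public masks, so a local check has to use held-out training images, for example the validation fold, whose ground truth is in train_mask.csv. One way to do that is to compare predicted and true RLE strings with the same Dice overlap that SoftDiceLoss optimizes. The sketch below is a hypothetical helper (rle_to_pixels and dice_from_rle are not part of the original code); it assumes both strings describe masks of the same size and uses smooth=1 like SoftDiceLoss, so two empty masks score 1.0.

```python
def rle_to_pixels(mask_rle):
    """Set of 0-based flat pixel indices covered by an RLE string
    ("start length" pairs with 1-based starts)."""
    values = [int(x) for x in mask_rle.split()]
    pixels = set()
    for start, length in zip(values[0::2], values[1::2]):
        pixels.update(range(start - 1, start - 1 + length))
    return pixels


def dice_from_rle(pred_rle, true_rle, smooth=1.0):
    """Hard Dice between two RLE masks, smoothed like SoftDiceLoss."""
    pred = rle_to_pixels(pred_rle)
    true = rle_to_pixels(true_rle)
    tp = len(pred & true)
    # 2*tp + fp + fn equals |pred| + |true|
    return (2 * tp + smooth) / (len(pred) + len(true) + smooth)


# Prediction covers pixels 0-3, ground truth covers 2-5: overlap of 2
print(round(dice_from_rle("1 4", "3 4"), 4))  # 0.5556
```

Averaging this score over the images of the held-out fold gives a single number that can be compared across checkpoints or thresholds before writing tmp.csv.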