计算机视觉PyTorch实现风格迁移

简介: 计算机视觉PyTorch实现风格迁移

神经网络风格迁移


它主要是通过神经网络,将一幅艺术风格画(style image)和一张普通的照片(content image)巧妙地融合,形成一张非常有意思的图片。


大白话说,图像往往由风格与内容组成,比如我们常常说画家的画风是怎么样的,毕加索的画风、动漫的画风。

风格迁移就是保留一张图片的内容(物体,人物),用另一张图片的色彩画图风格去填充。


风格迁移原理


在介绍原理之前先普及一个知识点:


通常将图像输出到卷积神经网络中,在神经网络第一层隐藏层通常会找出一些简单的特征,比如边缘或者颜色阴影。在神经网络深层部分的一个隐藏单元会看到一张图片更大的部分,在极端的情况下,可以假设图像中每一个像素都会影响到神经网络层更深层的输出,靠后的隐藏层可以看到更大的图片块。


也就是说在神经网络中隐藏单元里从第一层的边缘到第二层的质地再到更深层的复杂物体


原理


首先我们需要获取一张内容图片和一张风格图片;然后定义二个度量,一个度量值为内容度量值,另一个度量为风格度量值,其中内容度量值通过生成代价函数来衡量二个图片之间的内容差异程度,风格度量也通过生成代价函数衡量图片之间风格差异程度,最后建立生成图像的神经网络模型,对内容图片中的内容和风格图片的风格进行提取,以内容图片为基准将其输入建立的模型中,通过代价函数梯度下降来调整内容度量值和风格度量值,让它们趋近于最小,最后输出的图片就是内容和风格融合的图片。



1.生成图像代价函数


想要生成出我们想要的图像,就需要定义一个代价函数,通过最小化代价函数,你可以生成任何图像

我们用C表示内容图像,用S表示风格图像,用G表示想要生成的图像


图像代价函数


1687221498794.png


15d6a47055c1480ea1d1864758abd1bc.png


代价函数梯度下降优化过程,生成图像的变化:



2.内容代价函数


1687221529425.png

内容代价函数


1687221569604.png


3.风格代价函数

1687221608298.png


Gram矩阵计算公式如下

1687221625519.png

风格代价函数

1687221648024.png

风格迁移算法实现


1687221672909.png

f64832264ea241afbca8cacdc3a79484.png

VGG-16网络结构如下图所示:

aa17fee8fa154ddd9326a9c39f62f89b.png

Fast Neuarl Style训练步骤如下:


1687221725427.png


PyTorch代码实现如下:


from __future__ import print_function
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms,models
from PIL import Image
import matplotlib.pyplot as plt
from torch.autograd import Variable
import copy
#%% 图像预处理
transform=transforms.Compose([transforms.Scale([128,128]),
                              transforms.ToTensor()])
def loadimg(path=None):
    img=Image.open(path)
    img=transform(img)
    img=Variable(img)
    img=img.unsqueeze(0)
    return img
content_img=loadimg("3.jpg")
style_img=loadimg("1.jpg")
#%% 显示图片
unloader = transforms.ToPILImage()  # 重新转换成PIL图像
plt.ion()
def imshow(tensor, title=None):
    image = tensor.clone().cpu()  # 我们克隆张量以不对其进行修改
    image = image.view(3, 128, 128)  # 删除批量处理维度
    image = unloader(image)
    plt.imshow(image)
    if title is not None:
        plt.title(title)
    plt.pause(0.001) # pause a bit so that plots are updated
plt.figure()
imshow(style_img.data, title='Style Image')
plt.figure()
imshow(content_img.data, title='Content Image')
#%% 内容损失
class Content_loss(nn.Module):
    def __init__(self,target,weight):
        super(Content_loss, self).__init__()
        self.weight=weight
        self.target=target.detach()*weight
        self.loss_fn=nn.MSELoss()
    def forward(self,input):
        self.loss=self.loss_fn(input*self.weight,self.target)
        self.output = input
        return self.output
    def backward(self, retain_graph=True):
        self.loss.backward(retain_graph=retain_graph)
        return self.loss
#%% 风格损失
class Style_loss(nn.Module):
    def __init__(self, target, weight):
        super(Style_loss, self).__init__()
        self.target = target.detach() * weight
        self.weight = weight
        self.gram = Gram_matrix()
        self.loss_fn = nn.MSELoss()
    def forward(self, input):
        self.output = input.clone()
        self.G = self.gram(input)
        self.G.mul_(self.weight)
        self.loss = self.loss_fn(self.G, self.target)
        return self.output
    def backward(self, retain_graph=True):
        self.loss.backward(retain_graph=retain_graph)
        return self.loss
class Gram_matrix(nn.Module):
    def forward(self,input):
        a,b,c,d=input.size()
        feature=input.view(a*b,c*d)
        gram=torch.mm(feature, feature.t())
        return gram.div(a*b*c*d)
#%% 模型搭建
vgg=models.vgg19(pretrained=True).features
content_layer=["Conv_4"]
style_layer=['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
def get_style_model_and_losses(vgg, style_img, content_img,
                               style_weight=1000, content_weight=1,
                               content_layers=content_layer,
                               style_layers=style_layer):
    vgg = copy.deepcopy(vgg)
    # just in order to have an iterable access to or list of content/syle
    # losses
    content_losses = []
    style_losses = []
    model = nn.Sequential()  # the new Sequential module network
    gram = Gram_matrix() # we need a gram module in order to compute style targets
    # move these modules to the GPU if possible:
    i = 1
    for layer in list(vgg):
        if isinstance(layer, nn.Conv2d):
            name = "conv_" + str(i)
            model.add_module(name, layer)
            if name in content_layers:
                # add content loss:
                target = model(content_img).clone()
                content_loss = Content_loss(target, content_weight)
                model.add_module("content_loss_" + str(i), content_loss)
                content_losses.append(content_loss)
            if name in style_layers:
                # add style loss:
                target_feature = model(style_img).clone()
                target_feature_gram = gram(target_feature)
                style_loss = Style_loss(target_feature_gram, style_weight)
                model.add_module("style_loss_" + str(i), style_loss)
                style_losses.append(style_loss)
        if isinstance(layer, nn.ReLU):
            name = "relu_" + str(i)
            model.add_module(name, layer)
            if name in content_layers:
                # add content loss:
                target = model(content_img).clone()
                content_loss = Content_loss(target, content_weight)
                model.add_module("content_loss_" + str(i), content_loss)
                content_losses.append(content_loss)
            if name in style_layers:
                # add style loss:
                target_feature = model(style_img).clone()
                target_feature_gram = gram(target_feature)
                style_loss = Style_loss(target_feature_gram, style_weight)
                model.add_module("style_loss_" + str(i), style_loss)
                style_losses.append(style_loss)
            i += 1
        if isinstance(layer, nn.MaxPool2d):
            name = "pool_" + str(i)
            model.add_module(name, layer)  # ***
    return model, style_losses, content_losses
#%%输入图像
input_img = content_img.clone()
# if you want to use a white noise instead uncomment the below line:
# input_img = Variable(torch.randn(content_img.data.size())).type(dtype)
# add the original input image to the figue:
plt.figure()
imshow(input_img.data, title='Input Image')    
#%%梯度下降
def get_input_param_optimizer(input_img):
    # this line to show that input is a parameter that requires a gradient
    input_param = nn.Parameter(input_img.data)
    optimizer = torch.optim.LBFGS([input_param])
    return input_param, optimizer        
input_param,optimizer=get_input_param_optimizer(input_img)        
#%%参数优化
def run_style_transfer(cnn, content_img, style_img, input_img, num_steps=300,
                       style_weight=1000, content_weight=1):
    """Run the style transfer."""
    print('Building the style transfer model..')
    model, style_losses, content_losses = get_style_model_and_losses(cnn,
        style_img, content_img, style_weight, content_weight)
    input_param, optimizer = get_input_param_optimizer(input_img)
    print('Optimizing..')
    run = [0]
    while run[0] <= num_steps:
        def closure():
            # correct the values of updated input image
            input_param.data.clamp_(0, 1)
            optimizer.zero_grad()
            model(input_param)
            style_score = 0
            content_score = 0
            for sl in style_losses:
                style_score += sl.backward()
            for cl in content_losses:
                content_score += cl.backward()
            run[0] += 1
            if run[0] % 50 == 0:
                print("run {}:".format(run))
                print('Style Loss : {:5f} Content Loss: {:5f}'.format(
                    style_score, content_score))
            return style_score + content_score
        optimizer.step(closure)
    # a last correction...
    input_param.data.clamp_(0, 1)
    return input_param.data         
#%% 输出图像
output = run_style_transfer(vgg, content_img, style_img, input_img)
plt.figure()
imshow(output,title="Output Image")
# sphinx_gallery_thumbnail_number = 4
plt.ioff()
plt.show()

1b04f7565c19419e81d99b7f11aa6786.png9c9c642b0b774481af0295ac4207931d.png


bf4b30a73cf744a2bc85bf6469efce9f.png

相关文章
|
6月前
|
数据可视化 PyTorch 算法框架/工具
使用PyTorch搭建VGG模型进行图像风格迁移实战(附源码和数据集)
使用PyTorch搭建VGG模型进行图像风格迁移实战(附源码和数据集)
590 1
|
1月前
|
机器学习/深度学习 PyTorch 算法框架/工具
聊一聊计算机视觉中常用的注意力机制以及Pytorch代码实现
本文介绍了几种常用的计算机视觉注意力机制及其PyTorch实现,包括SENet、CBAM、BAM、ECA-Net、SA-Net、Polarized Self-Attention、Spatial Group-wise Enhance和Coordinate Attention等,每种方法都附有详细的网络结构说明和实验结果分析。通过这些注意力机制的应用,可以有效提升模型在目标检测任务上的性能。此外,作者还提供了实验数据集的基本情况及baseline模型的选择与实验结果,方便读者理解和复现。
29 0
聊一聊计算机视觉中常用的注意力机制以及Pytorch代码实现
|
5月前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】47. Pytorch图片样式迁移实战:将一张图片样式迁移至另一张图片,创作自己喜欢风格的图片【含完整源码】
【从零开始学习深度学习】47. Pytorch图片样式迁移实战:将一张图片样式迁移至另一张图片,创作自己喜欢风格的图片【含完整源码】
|
5月前
|
机器学习/深度学习 算法 PyTorch
【从零开始学习深度学习】45. Pytorch迁移学习微调方法实战:使用微调技术进行2分类图片热狗识别模型训练【含源码与数据集】
【从零开始学习深度学习】45. Pytorch迁移学习微调方法实战:使用微调技术进行2分类图片热狗识别模型训练【含源码与数据集】
|
6月前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch与迁移学习:利用预训练模型提升性能
【4月更文挑战第18天】PyTorch支持迁移学习,助力提升深度学习性能。预训练模型(如ResNet、VGG)在大规模数据集(如ImageNet)训练后,可在新任务中加速训练,提高准确率。通过选择模型、加载预训练权重、修改结构和微调,可适应不同任务需求。迁移学习节省资源,但也需考虑源任务与目标任务的相似度及超参数选择。实践案例显示,预训练模型能有效提升小数据集上的图像分类任务性能。未来,迁移学习将继续在深度学习领域发挥重要作用。
|
机器学习/深度学习 并行计算 PyTorch
迁移学习的 PyTorch 实现
迁移学习的 PyTorch 实现
|
6月前
|
机器学习/深度学习 编解码 PyTorch
Pytorch迁移学习使用MobileNet v3网络模型进行猫狗预测二分类
MobileNet v1是MobileNet系列中的第一个版本,于2017年由Google团队提出。其主要目标是设计一个轻量级的深度神经网络,能够在移动设备和嵌入式系统上进行图像分类和目标检测任务,并且具有较高的计算效率和较小的模型大小。
291 0
|
6月前
|
机器学习/深度学习 PyTorch 语音技术
Pytorch迁移学习使用Resnet50进行模型训练预测猫狗二分类
深度学习在图像分类、目标检测、语音识别等领域取得了重大突破,但是随着网络层数的增加,梯度消失和梯度爆炸问题逐渐凸显。随着层数的增加,梯度信息在反向传播过程中逐渐变小,导致网络难以收敛。同时,梯度爆炸问题也会导致网络的参数更新过大,无法正常收敛。 为了解决这些问题,ResNet提出了一个创新的思路:引入残差块(Residual Block)。残差块的设计允许网络学习残差映射,从而减轻了梯度消失问题,使得网络更容易训练。
550 0
|
6月前
|
机器学习/深度学习 PyTorch 调度
迁移学习的 PyTorch 实现
迁移学习的 PyTorch 实现
|
机器学习/深度学习 数据采集 PyTorch
计算机视觉PyTorch迁移学习 - (一)
计算机视觉PyTorch迁移学习 - (一)

热门文章

最新文章